Batch Normalization

CNN 학습시에 Batch Normalization을 사용하면 좋다는 것을 알고 있었으나, 어떤 문제를 어떻게 해결해주는지 CNN에서 동작은 어떻게 하는지 몰랐기 때문에 이번 기회에 정리를 해보려고 한다.

 

우선 배치 정규화를 통해서 아래와 같은 효과를 볼 수 있다고 한다.

 

장점

  • 학습 속도(training speed)를 빠르게 할 수 있다.
  • 가중치 초기화(weight initialization)에 대한 민감도를 감소시킨다.
  • 모델 일반화(regularization)효과가 있다.

단점

  • 이름 그대로 batch size로 normalization을 하기 때문에 batch size의 크기가 성능에 영향을 준다.
  • RNN과 같이 sequential 데이터를 처리하는 경우에 BN을 적용시키기 어렵다. 

그럼 어떤 문제 때문에 Batch Normalization이 나오게 된 것일까?

Covariate Shift

Covariate shift는 공변량 변화라고 부르며 입력 데이터의 분포가 학습할 때와 테스트할 때 다르게 나타나는 현상을 말한다. 이는 학습 데이터를 이용해 만든 모델이 테스트 데이터셋을 가지고 추론할 때 성능에 영항을 미칠 수 있다. 

예를 들어, 고양이와 강아지를 분류하는 모델을 학습시킬 때 학습 데이터로 고양이 이미지를 '러시안 블루'종만 사용하고 테스트 데이터로 '페르시안'종의 고양이를 분류하려고 한다면 모델은 잘 학습할 수 있을까? 학습 데이터의 분포와 테스트 데이터의 분포가 다르기 때문에 학습시킨 모델의 분류 성능은 떨어질 것이며 이처럼 학습 데이터와 테스트 데이터의 분포가 다른 것을 covariate shift라 부른다.

참고: https://cvml.tistory.com/5

 

 

Internal Covariate Shift

위에서 언급한 Covariate Shift(공변량 변화)가 뉴럴 네트워크 내부에서 일어나는 현상을 Internal Covariate Shift라고 한다. 즉, 매 스텝마다 hidden layer에 입력으로 들어오는 데이터의 분포가 달라지는 것을 의미하며 Internal Covariate Shift는 layer가 깊을수록 심화될 수 있다.

우선 기존에 딥러닝 학습에 있어서 겪고 있던 문제들 중 대표적으로 각 레이어의 가중치들이 역전파 학습을 진행할 때 입력부로 올라오면서 학습이 잘 되지 않는 Gradient Vanishing / Exploding 문제가 있었다. 이를 해결하기 위해 가중치를 초기화하는 방법(Xavier, He), ReLU와 같이 더 나은 활성화 함수(Activation Function) 방법, 규제화(Regularization) 방법을 제안하거나 학습률(Learning Rate)를 낮춰서 모든 레이어들의 가중치가 학습되도록 연구해왔다. 이러한 방법들을 이용함에도 레이어의 수가 많아질수록 여전히 학습이 잘 되지 않아 근본적인 문제를 해결하지 못하고 있었다.

 

BN 논문의 저자는 학습이 잘 되지 않는 근본적인 문제를 'Internal Covariate Shift'현상 때문이라고 주장을 하였고 Batch Normalization 기법이 이 현상을 해결해준다고 주장하였다.

 

Algorithm

직관적 이해

batch Normalization은 어디에(평균), 얼만큼 세게 뿌릴지(분산)을 학습시키는 것이다. 해당 Node에 대해 nonlinearity을 얼마나 살리면서 vanishing gradient를 얼마나 해결할지, AI가 알아내는 것이다.

 

아래 강아지 이미지를 오른쪽 코드를 통해 Layer을 통과시켰을 때 어떻게 진행될지 이해해보자.

input size: 3*32*32

CNN filter size: 32*3*3*3(filter 개수 * input channel 수 * weight * height), padding =1

Batch size: 32

padding이 1로 주어져있기때문에 output의 weight, height은 input의 weight과 height값과 동일하다.

Filter(3*3*3)가 input(3*32*32)를 가로로 32번, 세로로 32번 돌면서 output(32*32)를 출력한다.

이후, 이 output(32*32)를 batch normalization을 진행한다.  

그럼 이렇게 나온다!