티스토리 뷰

Introduction

이 글을 읽는 분들께서는 Batch Normalization이 무엇인지, 어디 쓰는지, 어떻게 쓰는지 등에 대해 기초적인 것은 알고 있다고 생각하고 skip 하도록 하겠습니다.

Forward Pass와 Backpropagation 식을 Trace trick을 활용하여 유도하고, 실제 구현을 해보겠습니다.

대부분 CS231n assignment2의 expression을 따라갑니다. (변수, 표기 등을 의미합니다.)

Forward Pass

뭐.. 유명한 그림 있죠? 논문의 그림 그대로 긁어오겠습니다.

국룰 공식입니다. 근데, 이거 보면 잘 이해가 안 될 수 있으니(저만 그럴수도 ㅎ) 조금 더 보충설명을 하겠습니다.

이쪽 보충설명을 해야, 추후 gradient 유도할 때 안 헷갈립니다.

먼저, 지금 보여주는 이 minibatch $\mathcal{B}=\left\{x_1,\,x_2,\,\cdots,\,x_m\right\}$은 각각이 scalar입니다.

물론, 벡터화해서 생각할 수 있지만, 그 때의 벡터는 (일반적으로 생각하는) column vector가 아니라 Row vector가 됩니다. 다른 말로, 항상 쓰던 $n\times 1$의 벡터가 아닌 $1\times n$의 벡터가 된 다는 것이죠. 행렬 입장에서 보면 가로줄이라고 생각하면 됩니다.

따라서 $\gamma$와 $\beta$ 모두 하나의 값 scalar가 되구요, 계산을 할 때는 broadcasting을 진행하게 됩니다.

다른 말로, Batch normalization은 하나의 Feature에 대해 Normalizae를 실시하는 과정이 됩니다.

만약 따로 따로 gradient를 계산하는 것이라면 크게 문제는 없지만, 우리는 최대한 식을 간단히 하는 것이 목표입니다.

따라서, $\mu_{\mathcal{B}}$와 $\sigma^2_{\mathcal{B}}$도 최대한 간단하게, 더하기 없이 표현하고 싶습니다!

Expression of $\mu$ and $\sigma^2$ with matrix multiplciation

먼저, 짱돌을(?) 잘 굴리기 위해서 하나의 Feature들을 다 모은 벡터를 생각합시다.

즉, $x=\begin{bmatrix} x_1&x_2&\cdots&x_N\end{bmatrix}^T$입니다.

그러면, 평균은 각각의 값들을 모두 더한 값이 됩니다. 위의 psuedo code에도 나와있듯이 $\mu=\dfrac{1}{N} \displaystyle\sum_{i=1}^N x_i$가 됩니다. 다른 시각으로, 이걸 행렬 표현으로 바꿀 수 있을까요?

모든 것을 다 더해서 N으로 나눈 다는 것을 다르게 적으면, $$\mu=\displaystyle\sum_{i=1}^N \dfrac{1}{N} x_i$$이고, 이를 행렬 표현으로 나타나면 $\dfrac{1}{N}\underbrace{\begin{bmatrix} 1&1&\cdots&1\end{bmatrix}}_{N}x$가 되고, size는 $1\times N$과 $N \times 1$의 행렬곱이므로 $1\times 1$, 즉 scalar ($\gamma$)의 값을 뿜어내게 됩니다.

Broadcasting이라는 연산은 파이썬(or numpy)에서 편의를 위해 제공하는 기능이므로, 자체적으로 broadcasting을 진행하면 앞의 행렬곱을 $1\times N$의 벡터가 아니라 $1$을 복사한 $N\times N$의 행렬에 대해 곱을 실시하면 됩니다.

이는 $x$가 $N\times 1$가 아닌, $N\times D$이여도 문제 없이 작동됩니다. 다만, 반환 행렬의 크기는 $N \times D$가 되겠죠.

분산에 대해서도 동일하게 계산합니다. 다만, 분산은 단순히 $x_i$에 대해 계산하는 것이 아닌, 편차의 제곱에 대한 평균임을 감안하여, 계산을 해야합니다.

편의를 위해 모든 행렬의 원소가 1이고 크기가 $N\times M$의 행렬을 $\mathbf{1}_{(N,\,M)}$라 합시다. 그러면, $$\sigma^2=\dfrac{1}{N}\sum_{i=1}^N (x_i-\mu)^2$$이 되고, 합을 구하는 것은 행렬 표현으로 바꿀 수 있다고 하였으니 안에 있는 (모든 항에 대한) 제곱만 어찌 처리하면 가능합니다.

하지만 행렬 항의 제곱을 행렬 곱으로 표현하기에는 너무 어렵습니다. (뭐 어찌저찌하면 가능은 하겠죠? 차원을 올리거나 확장하거나 하는 식으로 진행하면 가능할 것입니다. 가장 대표적인 방법이 vectorized function인 $\mathrm{vec}$을 쓰는 거긴합니다.)

따라서 파이썬에서 지원하는 행렬 연산들 중 위의 문제를 해결할 수 있는 가장 간단한 연산으로는 Hadamard product가 있습니다. 다들 한번쯤은 들어보셨을 연산인데, 다른말로 element-wise product라고도 합니다. 이걸 조금 응용하면, element-wise function operation이 되고, 이거의 훌륭한 예시는 softmax가 됩니다. (How?)

For more details, please refer this link https://en.wikipedia.org/wiki/Hadamard_product_(matrices)#Analogous_operations

여기에서는 Hadamard보다는 element-wise function을 선언해서 해결합시다. $f(X)=X^2$이라고 합시다. 이를 이용하면,

$$\sigma^2=\dfrac{1}{N}\mathbf{1}_{(N,\,N)} f(X-\mu)$$이 됩니다. 이때, $\mu=\frac{1}{N}\mathbf{1}_{(N,\,N)}X$임을 기억합시다.

$X$를 다르게 표현하면 $IX$이므로, $$\sigma^2=\dfrac{1}{N}\mathbf{1}_{(N,\,N)} f((I-\mathbf{1}_{(N,\,N)}/N)X)$$입니다.

Implementation

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
def batchnorm_forward(x, gamma, beta, bn_param):
    """
    Forward pass for batch normalization.
 
    During training the sample mean and (uncorrected) sample variance are
    computed from minibatch statistics and used to normalize the incoming data.
    During training we also keep an exponentially decaying running mean of the
    mean and variance of each feature, and these averages are used to normalize
    data at test-time.
 
    At each timestep we update the running averages for mean and variance using
    an exponential decay based on the momentum parameter:
 
    running_mean = momentum * running_mean + (1 - momentum) * sample_mean
    running_var = momentum * running_var + (1 - momentum) * sample_var
 
    Note that the batch normalization paper suggests a different test-time
    behavior: they compute sample mean and variance for each feature using a
    large number of training images rather than using a running average. For
    this implementation we have chosen to use running averages instead since
    they do not require an additional estimation step; the torch7
    implementation of batch normalization also uses running averages.
 
    Input:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift paremeter of shape (D,)
    - bn_param: Dictionary with the following keys:
      - mode: 'train' or 'test'; required
      - eps: Constant for numeric stability
      - momentum: Constant for running mean / variance.
      - running_mean: Array of shape (D,) giving running mean of features
      - running_var Array of shape (D,) giving running variance of features
 
    Returns a tuple of:
    - out: of shape (N, D)
    - cache: A tuple of values needed in the backward pass
    """
    mode = bn_param["mode"]
    eps = bn_param.get("eps", 1e-5)
    momentum = bn_param.get("momentum"0.9)
 
    N, D = x.shape
    running_mean = bn_param.get("running_mean", np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get("running_var", np.zeros(D, dtype=x.dtype))
 
    out, cache = NoneNone
    if mode == "train":
        #######################################################################
        # TODO: Implement the training-time forward pass for batch norm.      #
        # Use minibatch statistics to compute the mean and variance, use      #
        # these statistics to normalize the incoming data, and scale and      #
        # shift the normalized data using gamma and beta.                     #
        #                                                                     #
        # You should store the output in the variable out. Any intermediates  #
        # that you need for the backward pass should be stored in the cache   #
        # variable.                                                           #
        #                                                                     #
        # You should also use your computed sample mean and variance together #
        # with the momentum variable to update the running mean and running   #
        # variance, storing your result in the running_mean and running_var   #
        # variables.                                                          #
        #                                                                     #
        # Note that though you should be keeping track of the running         #
        # variance, you should normalize the data based on the standard       #
        # deviation (square root of variance) instead!                        #
        # Referencing the original paper (https://arxiv.org/abs/1502.03167)   #
        # might prove to be helpful.                                          #
        #######################################################################
        # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
        sample_mean = np.mean(x, axis = 0# (D, )
        sampel_var = np.var(x, axis = 0# (D, )
        running_mean = momentum * running_mean + (1 - momentum) * sample_mean
        running_var = momentum * running_var + (1 - momentum) * sampel_var
        out = (x-sample_mean) / np.sqrt(sampel_var + eps) * gamma + beta
        cache = (x - sample_mean, 1 / (sampel_var + eps), gamma)
        # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
        #######################################################################
        #                           END OF YOUR CODE                          #
        #######################################################################
    elif mode == "test":
        #######################################################################
        # TODO: Implement the test-time forward pass for batch normalization. #
        # Use the running mean and variance to normalize the incoming data,   #
        # then scale and shift the normalized data using gamma and beta.      #
        # Store the result in the out variable.                               #
        #######################################################################
        # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
        out = (x - running_mean) / np.sqrt(running_var + eps) * gamma + beta
 
        # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
        #######################################################################
        #                          END OF YOUR CODE                           #
        #######################################################################
    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)
 
    # Store the updated running means back into bn_param
    bn_param["running_mean"= running_mean
    bn_param["running_var"= running_var
 
    return out, cache
 
cs

 

(Briefly) Trace trick and useful formula

해당 Chapter에서는 별다른 유도 없이, 단순히 식만 나열하도록 하겠습니다.

$\langle \cdot,\,\cdot\rangle$은 Frobenius inner product, $\odot$는 Hadamard product를, $\sigma(X)$는 행렬 $X$에 대한 element-wise operation function입니다.

Inner product notation

  • $\langle X,\,Y\rangle=\langle Y,\,X\rangle$
  • $\langle aX,\,Y\rangle=\langle X,\, aY \rangle=a\langle X,\,Y\rangle$
  • $\langle X,\,Y\odot Z\rangle=\langle X,\,Z\odot Y\rangle=\langle X\odot Y, Z\rangle$

Total derivate & gradient 

$$df=\langle \nabla_y f,\,dy\rangle,$$ where $f$ maps matrix to scalar. That means, $f(X)$ has scalar value.

Note that we have scalar loss function generally.

Matrix differentation rules

  • $d(X\pm Y)=dX\pm Y$, $d(XY)=(dX)Y + XdY$ (Most important)
  • $d tr(X)=tr(dX)$
  • $d(X\odot Y)=(dX)\odot Y+X\odot dY$
  • $d\sigma(X)=\sigma'(X)\odot dX$

Trace trick

  • $a=tr(a)$ for all $a\in\mathbb R$ (Since our result often scalar, this is very useful)
  • $tr(A^T)=tr(A)$
  • $tr(AB)=tr(BA)$, where $A$ and $B^T$ has same size
  • $d\sigma(X)=\sigma'(X)\odot dX$

P.S. Frobeinus inner product의 정의를 보면, 바로 직전 Chapter에서 다루었던 분산의 식을 표현할 수 있음을 알 수 있지만, 이렇게하나 저렇게하나 계산량은 비슷해서 직관적으로 볼 수 있는 함수롤 새로 정의했습니다.

Backpropagation

앞의 Forward pass를 보시면 알겠지만, 모든 원소가 1인 element는 직접적으로 활용되기보다는 N을 나누어서 활용됩니다. 혼란을 막기 위해, $O$를 모든 원소가 $1/N$인 $N\times N$행렬이라고 칭하겠습니다. 다른 말로, $O=\dfrac{1}{N} \mathbf{1}_{(N,\,N)}$입니다.

먼저, $\gamma$에 대한 Gradient부터 구해보죠.

Gamma는 $1\times D$의 row vector이고, $\hat{x}$와 Hadamard product를 시행하게 됩니다.

Python에서 구현할 경우 broadcasting이 들어가게 될 것이구요, broadcasting 없이 진행한다고 하면 row vector을 N번만큼 복사해야하겠죠.

Gradient는 각 $\hat{x}_i$에 대한 Gradient를 다 더한 것이 될 것입니다. 또한, 단순한 Hadamard product를 하기 때문에 Transpose 등을 거칠 필요 없습니다. (By Inner product (3)) 즉, upstream인 dout에 $\hat{x}$을 Hadamard product한 것을, 세로축을 기준으로 다 더해주면 됩니다. (이건 코드를 보시면 조금 더 직관적인 이해가 될 것입니다.)

$\beta$에 대한 Gradient도 동일합니다. 이도 Broadcasting이 진행이 되기 때문에, Gradient는 각 $\hat{x}_I$에 대한 Gradient의 합이 됩니다. 이때, $\beta$는 bias의 역할을 하기 때문에 단순히 upstream을 세로축 기준으로해서 다 더해주기만 하면 됩니다.

하지만.. $x$에 대한 Gradient는 굉장히 복잡합니다. 일단, 위의 표기를 따라갑시다. $y=\gamma \odot \hat{x}+\beta$입니다.

이때, $y$에 대한 gradient는 dout이고, 결과(Loss)는 Scalar이므로 Trace trick에 의해 $$dy=\langle dout,\,\gamma \odot d \hat{x}\rangle=\langle dout\odot \gamma,\,d\hat{x}\rangle$$가 됩니다. 이제, $d\hat{x}$만 어찌저찌 잘 하면 되겠군요!

식을 보면 $\dfrac{1}{\sqrt{\sigma^2+\varepsilon}}$이 있는데, 문제는 $\mu$와 $\sigma^2$이 전부 $x$에 관한 식이기 때문에, 이것에 대한 미분도 하나하나 고려를 해주어야 한다는 것입니다! 또한, $\dfrac{1}{\sqrt{\sigma^2+\varepsilon}}$과는 dot product가 아니라 Hadamard product임을 생각합시다.

이제 식을 하나하나씩 분석해봅시다!

$\hat{x}=\dfrac{x-\mu}{\sqrt{\sigma^2+\varepsilon}}$이 되고, 말했다시피 $\mu$와 $\sigma^2$ 모두 $x$에 관한 식으므로 미분을 할 때 곱미분을 시행해주어야 합니다. 수식으로 보면, 

$$\begin{aligned} dy &= \langle dout\odot \gamma,\, d\hat{x}\rangle \\ &=\left\langle dout\odot \gamma, \dfrac{1}{\sqrt{\sigma^2+\varepsilon}}\odot d((I-O)X) + (I-O)X \odot d\left(\dfrac{1}{\sqrt{\sigma^2+\varepsilon}}\right)\right\rangle\\&=\left\langle dout\odot\gamma,\,\dfrac{1}{\sqrt{\sigma^2+\varepsilon}}\odot d((I-O)X)\right\rangle + \left\langle dout\odot\gamma,\,(I-O)X\odot d\left(\dfrac{1}{\sigma^2+\varepsilon}\right)\right\rangle\\&= \left\langle dout\odot \gamma \odot \dfrac{1}{\sqrt{\sigma^2+\varepsilon}},\,d((I-O)X) \right\rangle+\left\langle dout\odot\gamma\odot (I-O)X,\,d\left(\dfrac{1}{\sigma^2+\varepsilon}\right)\right\rangle\\&=\left\langle (I-O)^T \left(dout\odot\gamma\odot\dfrac{1}{\sqrt{\sigma^2+\varepsilon}}\right),\,dX\right\rangle + \left\langle dout\odot\gamma\odot (I-O)X,\,-\dfrac{1}{2}(\sigma^2+\varepsilon)^{-3/2}\odot d\sigma^2\right\rangle\\&=\left\langle (I-O)^T \left(dout\odot\gamma\odot\dfrac{1}{\sqrt{\sigma^2+\varepsilon}}\right),\,dX\right\rangle -\dfrac{1}{2}\left\langle dout\odot\gamma\odot (I-O)X\odot (\sigma^2+\varepsilon)^{-3/2},\, d\sigma^2\right\rangle\\&=\left\langle (I-O)^T \left(dout\odot\gamma\odot\dfrac{1}{\sqrt{\sigma^2+\varepsilon}}\right),\,dX\right\rangle -\dfrac{1}{2}\left\langle dout\odot\gamma\odot (I-O)X\odot (\sigma^2+\varepsilon)^{-3/2},\, Odf((I-O)X)\right\rangle\\&=\left\langle (I-O)^T \left(dout\odot\gamma\odot\dfrac{1}{\sqrt{\sigma^2+\varepsilon}}\right),\,dX\right\rangle -\dfrac{1}{2}\left\langle O^T\left(dout\odot\gamma\odot (I-O)X\odot (\sigma^2+\varepsilon)^{-3/2}\right),\, df((I-O)X)\right\rangle\\&=\left\langle (I-O)^T \left(dout\odot\gamma\odot\dfrac{1}{\sqrt{\sigma^2+\varepsilon}}\right),\,dX\right\rangle\\&\qquad -\dfrac{1}{2}\left\langle O^T\left(dout\odot\gamma\odot (I-O)X\odot (\sigma^2+\varepsilon)^{-3/2}\right),\, 2(I-O)X\odot d(I-O)X\right\rangle\\&=\left\langle (I-O)^T \left(dout\odot\gamma\odot\dfrac{1}{\sqrt{\sigma^2+\varepsilon}}\right),\,dX\right\rangle\\&\qquad -\dfrac{1}{2}\left\langle 2(I-O)X\odot\left\{O^T\left(dout\odot\gamma\odot (I-O)X\odot (\sigma^2+\varepsilon)^{-3/2}\right)\right\},\, (I-O)dX\right\rangle\\&=\left\langle (I-O)^T \left(dout\odot\gamma\odot\dfrac{1}{\sqrt{\sigma^2+\varepsilon}}\right),\,dX\right\rangle\\&\qquad -\left\langle (I-O)^T(I-O)X\odot\left\{O^T\left(dout\odot\gamma\odot (I-O)X\odot (\sigma^2+\varepsilon)^{-3/2}\right)\right\},\, dX\right\rangle\end{aligned}$$

Implementation(Backpropagation)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def batchnorm_backward(dout, cache):
    """
    Backward pass for batch normalization.
 
    For this implementation, you should write out a computation graph for
    batch normalization on paper and propagate gradients backward through
    intermediate nodes.
 
    Inputs:
    - dout: Upstream derivatives, of shape (N, D)
    - cache: Variable of intermediates from batchnorm_forward.
 
    Returns a tuple of:
    - dx: Gradient with respect to inputs x, of shape (N, D)
    - dgamma: Gradient with respect to scale parameter gamma, of shape (D,)
    - dbeta: Gradient with respect to shift parameter beta, of shape (D,)
    """
    dx, dgamma, dbeta = NoneNoneNone
    ###########################################################################
    # TODO: Implement the backward pass for batch normalization. Store the    #
    # results in the dx, dgamma, and dbeta variables.                         #
    # Referencing the original paper (https://arxiv.org/abs/1502.03167)       #
    # might prove to be helpful.                                              #
    ###########################################################################
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    
    norm_x, var, gamma = cache
    N, _ = norm_x.shape
    I = np.identity(N)
    O = np.ones_like(I) / N
 
    dxb = dout * gamma * var**0.5
    dx = (I-O) @ (dxb - (O @ (dxb * norm_x * var)) * norm_x)
    dbeta = np.sum(dout, axis = 0)
    dgamma = np.sum(dxb / gamma * norm_x, axis = 0)
 
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    ###########################################################################
    #                             END OF YOUR CODE                            #
    ###########################################################################
 
    return dx, dgamma, dbeta
cs

 

'개인 공부' 카테고리의 다른 글

[CS231n] Assignment 1, Implement of vectorized lienar svm  (0) 2023.01.04
Graph transformer networks based text representation  (0) 2022.09.03
Softmax & Loss  (0) 2022.08.13
B. (Variational) Auto Encoder  (0) 2022.07.16
A. Attention  (0) 2022.07.16
댓글
최근에 올라온 글
공지사항
Total
Today
Yesterday
최근에 달린 댓글
링크
«   2024/04   »
1 2 3 4 5 6
7 8 9 10 11 12 13
14 15 16 17 18 19 20
21 22 23 24 25 26 27
28 29 30
글 보관함