We know that feature scaling makes the job of gradient descent easier and allows it to converge faster. Feature scaling is performed as a pre-processing step on the dataset. But once the normalized input is fed to a deep network, each layer's input depends on the parameters of all preceding layers, so even a small change in the network parameters is amplified and shifts the input distribution of the internal layers. This is known as internal covariate shift.
Batch Normalization is an idea introduced by Ioffe & Szegedy [1] of normalizing the activations of every fully connected and convolutional layer to zero mean and unit standard deviation during training, as part of the network architecture itself. It allows us to use much higher learning rates and be less careful about network initialization.
It is implemented as a layer (with trainable parameters) that normalizes the activations of the previous layer. Backpropagation allows the network to learn whether, and to what extent, it wants the activations to be normalized. It is inserted immediately after fully connected or convolutional layers and before nonlinearities. It effectively reduces internal covariate shift in deep networks.
Advantages of BatchNorm
Improves gradient flow through very deep networks
Reduces dependency on careful initialization
Allows higher learning rates
Provides regularization and reduces dependency on dropout
Forward Propagation
In the forward pass, we calculate the mean and variance of the batch, normalize the input to zero mean and unit variance, and then scale and shift it with the learnable parameters $\gamma$ and $\beta$, respectively.
\begin{align*}
\mu_B &= \frac{1}{m}\sum_{i=1}^{m} x_i \\
\sigma_B^2 &= \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2 \\
\hat{x_i} &= \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} \\
y_i &= \gamma \hat{x_i} + \beta
\end{align*}
The implementation is simple and straightforward:

n_X, c_X, h_X, w_X = X.shape
X_flat = X.reshape(n_X, c_X * h_X * w_X)
mu = np.mean(X_flat, axis=0)
var = np.var(X_flat, axis=0)
X_norm = (X_flat - mu) / np.sqrt(var + 1e-8)
out = gamma * X_norm + beta
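As a quick sanity check on the forward pass, we can confirm that the normalized activations really do have zero mean and unit variance along the batch axis. This is a minimal sketch with random data; the shapes and the shift/scale applied to the fake batch are purely illustrative:

```python
import numpy as np

np.random.seed(0)
# Fake batch of 64 activations shaped like 3-channel 4x4 feature maps,
# deliberately shifted and scaled away from zero mean / unit variance
X = np.random.randn(64, 3, 4, 4) * 5.0 + 10.0

n_X, c_X, h_X, w_X = X.shape
X_flat = X.reshape(n_X, c_X * h_X * w_X)
mu = np.mean(X_flat, axis=0)
var = np.var(X_flat, axis=0)
X_norm = (X_flat - mu) / np.sqrt(var + 1e-8)

# Per-feature statistics after normalization: mean ~ 0, variance ~ 1
print(np.allclose(X_norm.mean(axis=0), 0.0, atol=1e-7))
print(np.allclose(X_norm.var(axis=0), 1.0, atol=1e-4))
```

The variance comes out marginally below 1 because of the ε added for numerical stability, which is why the comparison uses a small tolerance.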
Backward Propagation
For our backward pass, we need to find the gradients $\frac{\partial C}{\partial x_i}$, $\frac{\partial C}{\partial \gamma}$ and $\frac{\partial C}{\partial \beta}$. We calculate the intermediate gradients from top to bottom in the computational graph to obtain them.
\begin{align*}
\frac{\partial C}{\partial \gamma \hat{x_i}} &= \frac{\partial C}{\partial y_i} \times \frac{\partial y_i}{\partial \gamma \hat{x_i}} \\
&= \frac{\partial C}{\partial y_i} \times \frac{\partial (\gamma \hat{x_i} + \beta)}{\partial \gamma \hat{x_i}} \\
&= \frac{\partial C}{\partial y_i}
\end{align*}
\begin{align*}
\frac{\partial C}{\partial \beta} &= \frac{\partial C}{\partial y_i} \times \frac{\partial y_i}{\partial \beta} \\
&= \frac{\partial C}{\partial y_i} \times \frac{\partial(\gamma \hat{x_i} + \beta )}{\partial \beta} \\
&= \sum_{i=1}^m \frac{\partial C}{\partial y_i}
\end{align*}
\begin{align*}
\frac{\partial C}{\partial \gamma} &= \frac{\partial C}{\partial \gamma \hat{x_i}} \times \frac{\partial \gamma \hat{x_i}}{\partial \gamma} \\
&= \sum_{i=1}^m \frac{\partial C }{\partial y_i} \times \hat{x_i}
\end{align*}
We now have gradients for both learnable parameters. Next, for the input gradient,
\begin{align*}
\frac{\partial C}{\partial \hat{x_i}} &= \frac{\partial C}{\partial \gamma \hat{x_i}} \times \frac{\partial \gamma \hat{x_i}}{\partial \hat{x_i}} \\
&= \frac{\partial C}{\partial y_i} \times \gamma
\end{align*}
\begin{align*}
\frac{\partial C}{\partial \sigma_B^2} &= \frac{\partial C}{\partial \hat{x_i}} \times \frac{\partial \hat{x_i} }{\partial \sigma_B^2} \\
&= \frac{\partial C}{\partial \hat{x_i}} \times \frac{\partial \left ( \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} \right ) }{\partial \sigma_B^2 } \\
&= \sum_{i=1}^m \frac{\partial C}{\partial \hat{x_i}} \times (x_i - \mu_B) \times \frac{\partial (\sigma_B^2 + \epsilon)^{-1/2}}{\partial \sigma_B^2} \\
&= \sum_{i=1}^m \frac{\partial C}{\partial \hat{x_i}} \times (x_i - \mu_B) \times \left(-\frac{1}{2}\right) (\sigma_B^2 + \epsilon)^{-3/2}
\end{align*}
We can see from the computational graph that $\mu_B$ appears in two nodes, so we need to add up the gradients from both.
\begin{align*}
\frac{\partial C}{\partial \mu_B} &= \frac{\partial C}{\partial \hat{x_i}} \times \frac{\partial \hat{x_i}}{\partial \mu_B} + \frac{\partial C}{\partial \sigma_B^2} \times \frac{\partial \sigma_B^2}{\partial \mu_B} \\
&= \frac{\partial C}{\partial \hat{x_i} } \times \frac{\partial \left ( \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} \right ) }{\partial \mu_B} + \frac{\partial C}{\partial \sigma_B^2} \times \frac{\partial \left (\frac{1}{m} \sum_{i=1}^m (x_i - \mu_B)^2 \right) }{\partial \mu_B} \\
&= \sum_{i=1}^m \frac{\partial C}{\partial \hat{x_i}}\times \frac{-1}{\sqrt{\sigma_B^2 + \epsilon}} + \frac{\partial C}{\partial \sigma_B^2} \times \frac{1}{m}\sum_{i=1}^m -2(x_i - \mu_B)
\end{align*}
Now we have all the intermediate gradients needed to calculate the input gradient. Since $x_i$ appears in three nodes, we add up the gradients from each of them.
\begin{align*}
\frac{\partial C}{\partial x_i} &= \frac{\partial C}{\partial \hat{x_i}} \times \frac{\partial \hat{x_i}}{\partial x_i} + \frac{\partial C}{\partial \mu_B} \times \frac{\partial \mu_B}{\partial x_i} + \frac{\partial C}{\partial \sigma_B^2} \times \frac{\partial \sigma_B^2}{\partial x_i} \\
&= \frac{\partial C}{\partial \hat{x_i}} \times \frac{1}{\sqrt{\sigma_B^2 + \epsilon}} + \frac{\partial C}{\partial \mu_B} \times \frac{\partial \left(\frac{1}{m}\sum_{i=1}^m x_i\right)}{\partial x_i} + \frac{\partial C}{\partial \sigma_B^2} \times \frac{2}{m}(x_i - \mu_B) \\
&= \frac{\partial C}{\partial \hat{x_i}} \times \frac{1}{\sqrt{\sigma_B^2 + \epsilon}} + \frac{\partial C}{\partial \mu_B} \times \frac{1}{m} + \frac{\partial C}{\partial \sigma_B^2} \times \frac{2}{m}(x_i - \mu_B)
\end{align*}
Translating the gradient expressions into Python, we have our implementation of backprop through the BatchNorm layer (note that dgamma, like dbeta, must be summed over the batch axis):

n_X, c_X, h_X, w_X = X.shape
X_flat = X.reshape(n_X, c_X * h_X * w_X)
dout = dout.reshape(n_X, c_X * h_X * w_X)
X_mu = X_flat - mu
var_inv = 1. / np.sqrt(var + 1e-8)
X_norm = X_mu * var_inv
dX_norm = dout * gamma
dvar = np.sum(dX_norm * X_mu, axis=0) * -0.5 * (var + 1e-8) ** (-3 / 2)
dmu = np.sum(dX_norm * -var_inv, axis=0) + dvar * 1 / n_X * np.sum(-2. * X_mu, axis=0)
dX = (dX_norm * var_inv) + (dmu / n_X) + (dvar * 2 / n_X * X_mu)
dbeta = np.sum(dout, axis=0)
dgamma = np.sum(dout * X_norm, axis=0)
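A good way to convince ourselves that the derivation is right is a numerical gradient check: compare the analytical dX against central finite differences of the forward pass. The sketch below uses a toy scalar loss C = sum(out * R) with a random matrix R (so dout = R); the shapes, the helper bn_forward and R are all illustrative assumptions, not part of the original code:

```python
import numpy as np

np.random.seed(1)
eps = 1e-8

def bn_forward(X_flat, gamma, beta):
    # Forward pass, as derived above
    mu = X_flat.mean(axis=0)
    var = X_flat.var(axis=0)
    X_norm = (X_flat - mu) / np.sqrt(var + eps)
    return gamma * X_norm + beta

# Toy batch of 8 flattened activations with 5 features
X_flat = np.random.randn(8, 5)
gamma = np.random.randn(1, 5)
beta = np.random.randn(1, 5)
R = np.random.randn(8, 5)   # hypothetical loss: C = sum(out * R)
dout = R                    # hence dC/dout = R

# Analytical gradients, same expressions as in the text
mu = X_flat.mean(axis=0)
var = X_flat.var(axis=0)
X_mu = X_flat - mu
var_inv = 1. / np.sqrt(var + eps)
n = X_flat.shape[0]

dX_norm = dout * gamma
dvar = np.sum(dX_norm * X_mu, axis=0) * -0.5 * (var + eps) ** (-3 / 2)
dmu = np.sum(dX_norm * -var_inv, axis=0) + dvar * np.sum(-2. * X_mu, axis=0) / n
dX = dX_norm * var_inv + dmu / n + dvar * 2. / n * X_mu

# Numerical gradient by central differences
h = 1e-5
dX_num = np.zeros_like(X_flat)
for i in range(X_flat.shape[0]):
    for j in range(X_flat.shape[1]):
        Xp, Xm = X_flat.copy(), X_flat.copy()
        Xp[i, j] += h
        Xm[i, j] -= h
        dX_num[i, j] = (np.sum(bn_forward(Xp, gamma, beta) * R)
                        - np.sum(bn_forward(Xm, gamma, beta) * R)) / (2 * h)

# Maximum absolute difference should be negligible
print(np.max(np.abs(dX - dX_num)))
```

Note that the batch statistics depend on every x_i, so perturbing a single entry changes mu and var too; the check passing confirms that all three paths through the computational graph are accounted for.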
Source code
Here is the source code for the BatchNorm layer with both the forward and backward APIs implemented.
class Batchnorm():

    def __init__(self, X_dim):
        self.d_X, self.h_X, self.w_X = X_dim
        self.gamma = np.ones((1, int(np.prod(X_dim))))
        self.beta = np.zeros((1, int(np.prod(X_dim))))
        self.params = [self.gamma, self.beta]

    def forward(self, X):
        self.n_X = X.shape[0]
        self.X_shape = X.shape

        self.X_flat = X.reshape(self.n_X, -1)
        self.mu = np.mean(self.X_flat, axis=0)
        self.var = np.var(self.X_flat, axis=0)
        self.X_norm = (self.X_flat - self.mu) / np.sqrt(self.var + 1e-8)
        out = self.gamma * self.X_norm + self.beta
        return out.reshape(self.X_shape)

    def backward(self, dout):
        dout = dout.reshape(dout.shape[0], -1)

        X_mu = self.X_flat - self.mu
        var_inv = 1. / np.sqrt(self.var + 1e-8)

        dbeta = np.sum(dout, axis=0)
        dgamma = np.sum(dout * self.X_norm, axis=0)  # summed over the batch, like dbeta

        dX_norm = dout * self.gamma
        dvar = np.sum(dX_norm * X_mu, axis=0) * -0.5 * (self.var + 1e-8) ** (-3 / 2)
        dmu = np.sum(dX_norm * -var_inv, axis=0) + dvar * 1 / self.n_X * np.sum(-2. * X_mu, axis=0)
        dX = (dX_norm * var_inv) + (dmu / self.n_X) + (dvar * 2 / self.n_X * X_mu)
        dX = dX.reshape(self.X_shape)
        return dX, [dgamma, dbeta]
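Finally, here is a small usage sketch that runs the layer on a random conv-shaped batch. The class body is repeated so the snippet is self-contained, and the batch size and feature-map shape are purely illustrative:

```python
import numpy as np

# Batchnorm layer as above, repeated so this sketch runs on its own
class Batchnorm():
    def __init__(self, X_dim):
        self.d_X, self.h_X, self.w_X = X_dim
        self.gamma = np.ones((1, int(np.prod(X_dim))))
        self.beta = np.zeros((1, int(np.prod(X_dim))))
        self.params = [self.gamma, self.beta]

    def forward(self, X):
        self.n_X = X.shape[0]
        self.X_shape = X.shape
        self.X_flat = X.reshape(self.n_X, -1)
        self.mu = np.mean(self.X_flat, axis=0)
        self.var = np.var(self.X_flat, axis=0)
        self.X_norm = (self.X_flat - self.mu) / np.sqrt(self.var + 1e-8)
        out = self.gamma * self.X_norm + self.beta
        return out.reshape(self.X_shape)

    def backward(self, dout):
        dout = dout.reshape(dout.shape[0], -1)
        X_mu = self.X_flat - self.mu
        var_inv = 1. / np.sqrt(self.var + 1e-8)
        dbeta = np.sum(dout, axis=0)
        dgamma = np.sum(dout * self.X_norm, axis=0)
        dX_norm = dout * self.gamma
        dvar = np.sum(dX_norm * X_mu, axis=0) * -0.5 * (self.var + 1e-8) ** (-3 / 2)
        dmu = np.sum(dX_norm * -var_inv, axis=0) + dvar * np.sum(-2. * X_mu, axis=0) / self.n_X
        dX = (dX_norm * var_inv) + (dmu / self.n_X) + (dvar * 2 / self.n_X * X_mu)
        return dX.reshape(self.X_shape), [dgamma, dbeta]

# Usage: a batch of 16 activations shaped like 3-channel 8x8 feature maps
np.random.seed(0)
bn = Batchnorm((3, 8, 8))
X = np.random.randn(16, 3, 8, 8)

out = bn.forward(X)
dX, (dgamma, dbeta) = bn.backward(np.random.randn(*out.shape))

print(out.shape, dX.shape)  # both keep the input's (16, 3, 8, 8) shape
```

Since gamma starts at ones and beta at zeros, the first forward pass returns the plain normalized activations; dgamma and dbeta are per-feature vectors of length prod(X_dim), matching the parameter shapes up to broadcasting.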