Note: The complete source code can be found here: https://github.com/parasdahal/deepnet
We know that feature scaling makes the job of gradient descent easier and allows it to converge faster. Feature scaling is performed as a pre-processing step on the dataset. But once the normalized input is fed to a deep network, the input of each layer depends on the parameters of all the layers before it, so even a small change in the network parameters is amplified and shifts the distribution of the inputs to the internal layers. This is known as internal covariate shift.
Batch Normalization is an idea introduced by Ioffe & Szegedy (2015) to reduce this internal covariate shift.
It is implemented as a layer with trainable parameters that normalizes the activations of the previous layer. Through backpropagation, the network learns whether, and to what extent, it wants the activations to be normalized. The layer is inserted immediately after fully connected or convolutional layers and before the nonlinearities, and it effectively reduces internal covariate shift in deep networks.
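To see why the learnable scale and shift let the network decide how much normalization to keep, note that choosing $\gamma = \sqrt{\sigma^2 + \epsilon}$ and $\beta = \mu$ undoes the normalization and turns the layer into an identity map on the batch. The sketch below is a minimal illustration, assuming a 2-D input of shape (N, D); `batchnorm_forward` is a hypothetical helper, not part of deepnet.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-8):
    """Training-mode BatchNorm over a 2-D batch x of shape (N, D)."""
    mu = x.mean(axis=0)                     # per-feature batch mean
    var = x.var(axis=0)                     # per-feature (biased) batch variance
    x_norm = (x - mu) / np.sqrt(var + eps)  # zero mean, unit variance per feature
    return gamma * x_norm + beta            # learnable scale and shift

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(64, 5))

# gamma = 1, beta = 0: the layer fully normalizes its input
y_norm = batchnorm_forward(x, np.ones(5), np.zeros(5))

# gamma = sqrt(var + eps), beta = mu: scale and shift cancel the normalization,
# so the layer reproduces its input on this batch
gamma_id = np.sqrt(x.var(axis=0) + 1e-8)
beta_id = x.mean(axis=0)
y_id = batchnorm_forward(x, gamma_id, beta_id)

print(np.allclose(y_norm.mean(axis=0), 0.0))  # True
print(np.allclose(y_id, x))                   # True
```

Gradient descent can move $\gamma$ and $\beta$ anywhere between these two extremes, which is what lets the network learn how much normalization it actually wants.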
Advantages of BatchNorm
In the forward pass, we calculate the mean and variance of the batch, normalize the input to have zero mean and unit variance, and then scale and shift it with the learnable parameters $\gamma$ and $\beta$, respectively.
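Written out for a mini-batch of $m$ inputs $x_1, \dots, x_m$ (a reconstruction that matches the NumPy code in this section, with $\epsilon$ the small constant `1e-8` that guards against division by zero), the forward pass is:

$$\mu = \frac{1}{m}\sum_{i=1}^{m} x_i$$

$$\sigma^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu)^2$$

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

$$y_i = \gamma\,\hat{x}_i + \beta$$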
The implementation is very simple and straightforward:
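```python
import numpy as np

# X has shape (N, C, H, W); gamma and beta have shape (1, C*H*W)
n_X, c_X, h_X, w_X = X.shape
# flatten each example so every activation becomes one feature column
X_flat = X.reshape(n_X, c_X*h_X*w_X)

# per-feature mean and variance over the batch
mu = np.mean(X_flat, axis=0)
var = np.var(X_flat, axis=0)

# normalize, then scale and shift with the learnable parameters
X_norm = (X_flat - mu) / np.sqrt(var + 1e-8)
out = gamma * X_norm + beta
```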
For the backward pass, we need to find the gradients $\frac{\partial L}{\partial \gamma}$, $\frac{\partial L}{\partial \beta}$ and $\frac{\partial L}{\partial X}$. We compute the intermediate gradients from top to bottom in the computational graph to obtain them.
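Given the upstream gradient $\frac{\partial L}{\partial y_i}$ for each output $y_i = \gamma\hat{x}_i + \beta$, the two parameter gradients sum over the batch (these match `dbeta` and `dgamma` in the code further down):

$$\frac{\partial L}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}$$

$$\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}\,\hat{x}_i$$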
Now we have the gradients for both learnable parameters. Next, we work toward the input gradient.
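The first two intermediate gradients are with respect to the normalized input $\hat{x}_i$ and the batch variance $\sigma^2$ (matching `dX_norm` and `dvar` in the code, with $m$ the batch size and $\epsilon$ the small stability constant):

$$\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i}\,\gamma$$

$$\frac{\partial L}{\partial \sigma^2} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\,(x_i - \mu)\cdot\left(-\frac{1}{2}\right)(\sigma^2 + \epsilon)^{-3/2}$$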
As the computational graph shows, $\mu$ feeds into two nodes (the centered input $x_i - \mu$ and the variance $\sigma^2$), so we add up the gradients from both nodes.
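Summing the contribution through $\hat{x}_i$ and the contribution through $\sigma^2$ gives the mean gradient (this is `dmu` in the code):

$$\frac{\partial L}{\partial \mu} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\cdot\frac{-1}{\sqrt{\sigma^2 + \epsilon}} + \frac{\partial L}{\partial \sigma^2}\cdot\frac{1}{m}\sum_{i=1}^{m} -2(x_i - \mu)$$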
Now we have all the intermediate gradients needed for the input gradient. Since $x_i$ feeds into three nodes ($\hat{x}_i$, $\sigma^2$ and $\mu$), we add up the gradients from each of them.
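Adding the three paths (directly through $\hat{x}_i$, through $\sigma^2$, and through $\mu$) gives the input gradient, which corresponds term by term to the `dX` line in the code:

$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i}\cdot\frac{1}{\sqrt{\sigma^2 + \epsilon}} + \frac{\partial L}{\partial \sigma^2}\cdot\frac{2(x_i - \mu)}{m} + \frac{\partial L}{\partial \mu}\cdot\frac{1}{m}$$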
Translating the gradient expressions into Python, we get our implementation of backprop through the BatchNorm layer:
```python
import numpy as np

# X and dout have shape (N, C, H, W); mu, var, X_norm and gamma come from the forward pass
n_X, c_X, h_X, w_X = X.shape
# flatten the inputs and dout
X_flat = X.reshape(n_X, c_X*h_X*w_X)
dout = dout.reshape(n_X, c_X*h_X*w_X)

X_mu = X_flat - mu
var_inv = 1. / np.sqrt(var + 1e-8)

dX_norm = dout * gamma
dvar = np.sum(dX_norm * X_mu, axis=0) * -0.5 * (var + 1e-8)**(-1.5)
dmu = np.sum(dX_norm * -var_inv, axis=0) + dvar * 1/n_X * np.sum(-2. * X_mu, axis=0)

dX = (dX_norm * var_inv) + (dmu / n_X) + (dvar * 2/n_X * X_mu)
dbeta = np.sum(dout, axis=0)
# gamma is shared across the batch, so its gradient is summed over axis 0
dgamma = np.sum(dout * X_norm, axis=0)
```
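A hand-derived backward pass is easy to get subtly wrong, so it is worth checking against finite differences. The sketch below is not part of deepnet: `bn_forward`, `bn_backward` and `numerical_grad` are hypothetical helpers, the input is assumed to be already flattened to shape (N, D), and the loss is an arbitrary weighted sum $L = \sum \text{out} \cdot w$ chosen so that $\frac{\partial L}{\partial \text{out}} = w$.

```python
import numpy as np

EPS = 1e-8

def bn_forward(x, gamma, beta):
    """Training-mode BatchNorm on a 2-D batch x of shape (N, D)."""
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    x_norm = (x - mu) / np.sqrt(var + EPS)
    return gamma * x_norm + beta, (x, x_norm, mu, var, gamma)

def bn_backward(dout, cache):
    """Staged backward pass: dx_norm -> dvar -> dmu -> dx."""
    x, x_norm, mu, var, gamma = cache
    n = x.shape[0]
    x_mu = x - mu
    var_inv = 1.0 / np.sqrt(var + EPS)
    dx_norm = dout * gamma
    dvar = np.sum(dx_norm * x_mu, axis=0) * -0.5 * (var + EPS) ** -1.5
    dmu = np.sum(-dx_norm * var_inv, axis=0) + dvar * np.mean(-2.0 * x_mu, axis=0)
    dx = dx_norm * var_inv + dmu / n + dvar * 2.0 * x_mu / n
    return dx, np.sum(dout * x_norm, axis=0), np.sum(dout, axis=0)

def numerical_grad(f, a, h=1e-5):
    """Central-difference gradient of scalar f() w.r.t. array a (perturbed in place, then restored)."""
    grad = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        old = a[idx]
        a[idx] = old + h
        f_plus = f()
        a[idx] = old - h
        f_minus = f()
        a[idx] = old
        grad[idx] = (f_plus - f_minus) / (2 * h)
    return grad

rng = np.random.default_rng(1)
x = rng.normal(size=(8, 4))
gamma = rng.normal(size=4)
beta = rng.normal(size=4)
w = rng.normal(size=(8, 4))  # L = sum(out * w), so dL/dout = w

def loss():
    return np.sum(bn_forward(x, gamma, beta)[0] * w)

_, cache = bn_forward(x, gamma, beta)
dx, dgamma, dbeta = bn_backward(w, cache)

# each line should print True if the analytic gradient matches the numerical one
for name, analytic, param in [("x", dx, x), ("gamma", dgamma, gamma), ("beta", dbeta, beta)]:
    numeric = numerical_grad(loss, param)
    print(name, np.allclose(analytic, numeric, rtol=1e-5, atol=1e-6))
```

The absolute tolerance matters here: some entries of `dx` can be close to zero, where a purely relative comparison would be dominated by finite-difference noise.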
Here is the source code for the BatchNorm layer with both the forward and backward API implemented.
```python
import numpy as np

class Batchnorm():

    def __init__(self, X_dim):
        # X_dim is the per-example input shape (channels, height, width)
        self.d_X, self.h_X, self.w_X = X_dim
        # one scale and one shift per activation (C*H*W of each)
        self.gamma = np.ones((1, int(np.prod(X_dim))))
        self.beta = np.zeros((1, int(np.prod(X_dim))))
        self.params = [self.gamma, self.beta]

    def forward(self, X):
        # training-mode forward pass: normalizes with the current batch statistics
        self.n_X = X.shape[0]
        self.X_shape = X.shape

        self.X_flat = X.ravel().reshape(self.n_X, -1)
        self.mu = np.mean(self.X_flat, axis=0)
        self.var = np.var(self.X_flat, axis=0)
        self.X_norm = (self.X_flat - self.mu) / np.sqrt(self.var + 1e-8)
        out = self.gamma * self.X_norm + self.beta

        return out.reshape(self.X_shape)

    def backward(self, dout):
        # flatten the upstream gradient to (N, C*H*W) to match the cached values
        dout = dout.ravel().reshape(dout.shape[0], -1)
        X_mu = self.X_flat - self.mu
        var_inv = 1. / np.sqrt(self.var + 1e-8)

        dbeta = np.sum(dout, axis=0)
        dgamma = np.sum(dout * self.X_norm, axis=0)

        dX_norm = dout * self.gamma
        dvar = np.sum(dX_norm * X_mu, axis=0) * -0.5 * (self.var + 1e-8)**(-1.5)
        dmu = np.sum(dX_norm * -var_inv, axis=0) + dvar * 1/self.n_X * np.sum(-2. * X_mu, axis=0)

        dX = (dX_norm * var_inv) + (dmu / self.n_X) + (dvar * 2/self.n_X * X_mu)
        dX = dX.reshape(self.X_shape)

        return dX, [dgamma, dbeta]
```
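The staged backward pass can also be collapsed algebraically. Because $\sum_i (x_i - \mu) = 0$, the second term of $\frac{\partial L}{\partial \mu}$ vanishes, and substituting the remaining terms gives a single expression in $\hat{x}_i$:

$$\frac{\partial L}{\partial x_i} = \frac{1}{m\sqrt{\sigma^2 + \epsilon}}\left(m\,\frac{\partial L}{\partial \hat{x}_i} - \sum_{j=1}^{m}\frac{\partial L}{\partial \hat{x}_j} - \hat{x}_i\sum_{j=1}^{m}\frac{\partial L}{\partial \hat{x}_j}\,\hat{x}_j\right)$$

A minimal sketch comparing the two forms, assuming a 2-D (N, D) input; `dx_staged` and `dx_compact` are hypothetical names used only for this illustration.

```python
import numpy as np

EPS = 1e-8

def dx_staged(dout, x, gamma):
    """Step-by-step input gradient, mirroring the staged backward pass (2-D input)."""
    n = x.shape[0]
    mu, var = x.mean(axis=0), x.var(axis=0)
    x_mu = x - mu
    var_inv = 1.0 / np.sqrt(var + EPS)
    dx_norm = dout * gamma
    dvar = np.sum(dx_norm * x_mu, axis=0) * -0.5 * (var + EPS) ** -1.5
    dmu = np.sum(-dx_norm * var_inv, axis=0) + dvar * np.mean(-2.0 * x_mu, axis=0)
    return dx_norm * var_inv + dmu / n + dvar * 2.0 * x_mu / n

def dx_compact(dout, x, gamma):
    """Same gradient in one expression, using sum_i (x_i - mu) = 0."""
    n = x.shape[0]
    var_inv = 1.0 / np.sqrt(x.var(axis=0) + EPS)
    x_norm = (x - x.mean(axis=0)) * var_inv
    dx_norm = dout * gamma
    return (var_inv / n) * (
        n * dx_norm
        - np.sum(dx_norm, axis=0)
        - x_norm * np.sum(dx_norm * x_norm, axis=0)
    )

rng = np.random.default_rng(2)
x = rng.normal(size=(16, 3))
dout = rng.normal(size=(16, 3))
gamma = rng.normal(size=3)
print(np.allclose(dx_staged(dout, x, gamma), dx_compact(dout, x, gamma)))  # True
```

The compact form needs fewer temporaries, while the staged form maps one-to-one onto the computational graph and is easier to debug.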