Maxpool Layer - Summarizing the output of Convolution Layer

Pooling layers are an important building block of CNNs. They summarize the activation maps and keep the number of network parameters low. Time to implement Maxpool!

Note: Complete source code can be found here

The pooling layer is usually placed after the convolutional layer. Its purpose is to reduce the spatial dimensions of the input volume for the following layers. Note that it only affects width and height, not depth.

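The pooling layer takes an input volume of size $W_1 \times H_1 \times D_1$.

The two hyperparameters used are:

  1. Spatial extent $F$
  2. Stride length $S$

The output volume is of size $W_2 \times H_2 \times D_2$, where $W_2 = (W_1 - F)/S + 1$, $H_2 = (H_1 - F)/S + 1$, and $D_2 = D_1$.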
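The output size along each spatial axis is (input size - spatial extent) / stride + 1, with the depth left unchanged. As a quick sanity check of that arithmetic, here is a minimal sketch. The helper name `pool_output_shape` is hypothetical (it is not part of the original source), and it assumes no zero padding.

```python
def pool_output_shape(h_in, w_in, depth, size, stride):
    """Return (depth, h_out, w_out) for a pooling layer without padding.

    Illustrative helper only; the name and signature are assumptions.
    """
    h_span, w_span = h_in - size, w_in - size
    # The window must fit, and the stride must tile the input evenly.
    if h_span < 0 or w_span < 0 or h_span % stride or w_span % stride:
        raise ValueError("window and stride do not tile the input evenly")
    return depth, h_span // stride + 1, w_span // stride + 1


# A 28x28 input with 6 channels, pooled with a 2x2 window and stride 2
print(pool_output_shape(28, 28, 6, size=2, stride=2))  # (6, 14, 14)
```

Raising an error for sizes that do not divide evenly mirrors the "Invalid dimensions!" check in the layer's constructor.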

The pooling layer computes a fixed function of its input, in our case the max function, so it has no learnable parameters.

Note: It is not common to use zero padding in pooling layers.

Forward Propagation

The max pool layer is similar to the convolution layer, but instead of performing a convolution, we select the maximum value in each receptive field of the input, save its index, and produce a summarized output volume. The implementation of the forward pass is fairly simple.

# unpack the input shape: (batch, channels, height, width)
n_X, c_X, h_X, w_X = X.shape
# treat every channel as a separate one-channel image so that im2col puts each receptive field in its own column
X_reshaped = X.reshape(n_X * c_X, 1, h_X, w_X)
X_col = im2col_indices(X_reshaped, h_filter, w_filter, padding=0, stride=stride)
# row index of the max within each column (one column per receptive field)
max_idx = np.argmax(X_col, axis=0)
out = X_col[max_idx, range(max_idx.size)]
# reshape the flat maxima to (h_out, w_out, n_X, c_X), then reorder to (n_X, c_X, h_out, w_out)
out = out.reshape(h_out, w_out, n_X, c_X).transpose(2, 3, 0, 1)
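To check the result of the im2col-based forward pass, it can help to compare it against a direct sliding-window version. The sketch below is not the implementation used in this post: the function name `maxpool_forward_naive` is hypothetical, and it assumes NumPy 1.20 or newer (for `sliding_window_view`) and an input of shape `(n, c, h, w)` with no padding.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def maxpool_forward_naive(X, size, stride):
    """Max pooling over the last two axes of X, shape (n, c, h, w)."""
    # All size x size windows at stride 1:
    # shape (n, c, h - size + 1, w - size + 1, size, size)
    windows = sliding_window_view(X, (size, size), axis=(2, 3))
    # Keep only the windows that start on a multiple of the stride.
    windows = windows[:, :, ::stride, ::stride]
    # Reduce each window to its maximum.
    return windows.max(axis=(-2, -1))


X = np.arange(16, dtype=float).reshape(1, 1, 4, 4)
# The four 2x2 windows have maxima 5, 7, 13 and 15.
print(maxpool_forward_naive(X, size=2, stride=2))
```

This version does not save the max indices, so it is only useful as a reference for the forward output, not for the backward pass.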

Backward Propagation

For the backward pass, we start with a zero matrix and write the gradient from above into the positions given by the max indices. Non-max neurons receive no gradient, because changing them slightly doesn’t change the output. So the gradient from the next layer is passed only to the neurons that achieved the max, and all other neurons get zero gradient.

dX_col = np.zeros_like(X_col)
# flatten dout so that we can index using max_idx from forward pass
dout_flat = dout.transpose(2,3,0,1).ravel()
# now fill max index with the gradient
dX_col[max_idx,range(max_idx.size)] = dout_flat

Since we stretched the images out into columns in the forward pass, we reverse this step before passing the input gradient back.

# convert the stretched columns back to image layout
dX = col2im_indices(dX_col, X_reshaped.shape, h_filter, w_filter, padding=0, stride=stride)
# restore the original (n_X, c_X, h_X, w_X) shape
dX = dX.reshape(X.shape)
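One way to gain confidence in the gradient routing is to compare it with a numerical gradient. The sketch below is a loop-based stand-in, not the im2col implementation from this post: the names `maxpool_forward_loop` and `maxpool_backward_loop` are hypothetical, inputs are assumed to have shape `(n, c, h, w)` with no padding, and ties within a window are assumed not to occur (random inputs make them very unlikely).

```python
import numpy as np


def maxpool_forward_loop(X, size, stride):
    n, c, h, w = X.shape
    h_out, w_out = (h - size) // stride + 1, (w - size) // stride + 1
    out = np.empty((n, c, h_out, w_out), dtype=X.dtype)
    for i in range(h_out):
        for j in range(w_out):
            hs, ws = i * stride, j * stride
            out[:, :, i, j] = X[:, :, hs:hs + size, ws:ws + size].max(axis=(2, 3))
    return out


def maxpool_backward_loop(X, dout, size, stride):
    n, c, h_out, w_out = dout.shape
    dX = np.zeros_like(X)
    nn, cc = np.indices((n, c))  # batch and channel index grids
    for i in range(h_out):
        for j in range(w_out):
            hs, ws = i * stride, j * stride
            win = X[:, :, hs:hs + size, ws:ws + size].reshape(n, c, -1)
            # Flat index of the max in each window, split into row and column.
            row, col = np.divmod(win.argmax(axis=2), size)
            # Accumulate, since overlapping windows can share a max element.
            np.add.at(dX, (nn, cc, hs + row, ws + col), dout[:, :, i, j])
    return dX


rng = np.random.default_rng(0)
X = rng.standard_normal((2, 3, 4, 4))
dout = rng.standard_normal((2, 3, 2, 2))
dX = maxpool_backward_loop(X, dout, size=2, stride=2)

# Central-difference gradient of the scalar sum(out * dout) with respect to X.
eps = 1e-6
num = np.zeros_like(X)
for pos in np.ndindex(*X.shape):
    old = X[pos]
    X[pos] = old + eps
    plus = (maxpool_forward_loop(X, 2, 2) * dout).sum()
    X[pos] = old - eps
    minus = (maxpool_forward_loop(X, 2, 2) * dout).sum()
    X[pos] = old
    num[pos] = (plus - minus) / (2 * eps)

print("gradients match:", np.allclose(dX, num, atol=1e-5))
```

With non-overlapping windows, `dX` has exactly one non-zero entry per window, which matches the description above: only the neuron that achieved the max receives gradient.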

Source code

Here is the source code for the Maxpool layer, with the forward and backward API implemented.

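import numpy as np
# im2col_indices and col2im_indices are assumed to be imported from the project's im2col helpers

class Maxpool():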

    def __init__(self,X_dim,size,stride):

        self.d_X, self.h_X, self.w_X = X_dim
        self.params = []

        self.size = size
        self.stride = stride
        self.h_out = (self.h_X - size)/stride + 1
        self.w_out = (self.w_X - size)/stride + 1

        if not self.h_out.is_integer() or not self.w_out.is_integer():
            raise Exception("Invalid dimensions!")
        self.h_out,self.w_out  = int(self.h_out), int(self.w_out)
        self.out_dim = (self.d_X,self.h_out,self.w_out)

    def forward(self,X):
        self.n_X = X.shape[0]
        X_reshaped = X.reshape(X.shape[0]*X.shape[1],1,X.shape[2],X.shape[3])

        self.X_col = im2col_indices(X_reshaped, self.size, self.size, padding = 0, stride = self.stride)
        self.max_indexes = np.argmax(self.X_col,axis=0)
        out = self.X_col[self.max_indexes,range(self.max_indexes.size)]

        out = out.reshape(self.h_out,self.w_out,self.n_X,self.d_X).transpose(2,3,0,1)
        return out

    def backward(self,dout):

        dX_col = np.zeros_like(self.X_col)
        # flatten the gradient
        dout_flat = dout.transpose(2,3,0,1).ravel()
        dX_col[self.max_indexes,range(self.max_indexes.size)] = dout_flat
        # get the original X_reshaped structure from col2im
        shape = (self.n_X*self.d_X,1,self.h_X,self.w_X)
        dX = col2im_indices(dX_col,shape,self.size,self.size,padding=0,stride=self.stride)
        dX = dX.reshape(self.n_X,self.d_X,self.h_X,self.w_X)
        return dX,[]