"Convolutional" layer gradients, done right

July 14th, 2025

In this note I want to provide a clear explanation of how the backward pass of a convolutional layer in a Convolutional Neural Network (CNN) works. The goal is for this to be a complete reference, though each section is designed to be self-contained, so you can jump directly to topics of interest.

I'll start with a reminder of the forward pass, without going into too much detail, then cover the backward pass showing the analytical derivation and how that converts into a naive but clear vectorized implementation.

While I get the impression that CNNs are out of favour these days, when I was studying them I found the lack of quality explanations of how backpropagation through a convolutional layer works rather frustrating, so I'm writing my own. There are a few important things that I felt other explanations lacked but that I will endeavour to elucidate:

Please write to me with any feedback, errata, or thoughts via email or X.

Forward Pass

The Cross-Correlation "Misnomer"

Despite being called "convolution," the operation that convolutional neural networks actually implement is cross-correlation. That is, under the hood libraries like PyTorch actually implement cross-correlation, not convolution, so let's start by understanding that. The difference is in how we handle the indices:

True Convolution: \[\text{convolution: } (X * W)[i,j] = \sum_{h} \sum_{w} X[i-h, j-w] \cdot W[h,w]\]

Cross-Correlation (what CNNs actually do): \[\text{cross-correlation: } (X \star W)[i,j] = \sum_{h} \sum_{w} X[i+h, j+w] \cdot W[h,w]\]

where \(h\) and \(w\) are the row and column indices within the filter, and \((i,j)\) indexes the output.

Notice how convolution subtracts the indices (\(i-h, j-w\)) while cross-correlation adds them (\(i+h, j+w\)). This means convolution flips the filter before sliding it, while cross-correlation slides the filter as-is. The summation over \(h\) and \(w\) is what slides the filter over the input image: at each output position \((i,j)\), we compute the dot product between the filter and the corresponding input window.

This naming is just a historical misnomer in deep learning: by the time people realized the terminology was technically incorrect, it was too late to change it without causing massive confusion. However, understanding the difference is important because it'll show up when we derive the backward pass equations.
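To make the distinction concrete, here is a minimal NumPy sketch (my own toy helpers, restricted to the "valid" output region, using the same input and filter as the walkthrough below) showing that true convolution is the same as cross-correlation with the filter rotated 180°:

import numpy as np

X = np.arange(1.0, 26.0).reshape(5, 5)                       # 5x5 input
W = np.array([[1., 0., -1.], [2., 0., -2.], [1., 0., -1.]])  # 3x3 filter

def cross_correlate(X, W):
    """Slide W over X as-is: out[i, j] = sum_{h,w} X[i+h, j+w] * W[h, w]."""
    HH, WW = W.shape
    return np.array([[np.sum(X[i:i+HH, j:j+WW] * W)
                      for j in range(X.shape[1] - WW + 1)]
                     for i in range(X.shape[0] - HH + 1)])

def convolve(X, W):
    """True convolution: the indices are subtracted, i.e. the filter is flipped."""
    HH, WW = W.shape
    out = np.zeros((X.shape[0] - HH + 1, X.shape[1] - WW + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            for h in range(HH):
                for w in range(WW):
                    # offset by (HH-1, WW-1) so the subtracted indices stay in bounds
                    out[i, j] += X[i + HH - 1 - h, j + WW - 1 - w] * W[h, w]
    return out

print(np.allclose(convolve(X, W), cross_correlate(X, W[::-1, ::-1])))  # True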

Forward Operation Walkthrough

Let's work through a concrete example to see how the forward pass operates. We'll use the 2D case for mathematical clarity, but remember that in practice we implement the 3D case (multiple channels) by simply adding a summation over the channel dimension.

Setup:

Input Matrix: \[X = \begin{bmatrix} 1 & 2 & 3 & 4 & 5 \\ 6 & 7 & 8 & 9 & 10 \\ 11 & 12 & 13 & 14 & 15 \\ 16 & 17 & 18 & 19 & 20 \\ 21 & 22 & 23 & 24 & 25 \end{bmatrix}\]

Filter: \[W = \begin{bmatrix} 1 & 0 & -1 \\ 2 & 0 & -2 \\ 1 & 0 & -1 \end{bmatrix}\]

Computing Output Position (0,0): We take the \(\textcolor{red}{3 \times 3}\) window from the top-left of the input:

\[X_{0:3,0:3} = \begin{bmatrix} \textcolor{red}{1} & \textcolor{red}{2} & \textcolor{red}{3} & 4 & 5 \\ \textcolor{red}{6} & \textcolor{red}{7} & \textcolor{red}{8} & 9 & 10 \\ \textcolor{red}{11} & \textcolor{red}{12} & \textcolor{red}{13} & 14 & 15 \\ 16 & 17 & 18 & 19 & 20 \\ 21 & 22 & 23 & 24 & 25 \end{bmatrix}\] The output value is computed as: \[Y[0,0] = \sum_{i,j} X_{0:3,0:3}[i,j] \cdot W[i,j]\] \[Y[0,0] = (1 \cdot 1) + (2 \cdot 0) + (3 \cdot -1) + (6 \cdot 2) + (7 \cdot 0) + (8 \cdot -2) + (11 \cdot 1) + (12 \cdot 0) + (13 \cdot -1) = -8\]

Computing Output Position (0,1): Now we slide the filter one position to the right: \[X_{0:3,1:4} = \begin{bmatrix} 1 & \textcolor{blue}{2} & \textcolor{blue}{3} & \textcolor{blue}{4} & 5 \\ 6 & \textcolor{blue}{7} & \textcolor{blue}{8} & \textcolor{blue}{9} & 10 \\ 11 & \textcolor{blue}{12} & \textcolor{blue}{13} & \textcolor{blue}{14} & 15 \\ 16 & 17 & 18 & 19 & 20 \\ 21 & 22 & 23 & 24 & 25 \end{bmatrix}\] \[Y[0,1] = (2 \cdot 1) + (3 \cdot 0) + (4 \cdot -1) + (7 \cdot 2) + (8 \cdot 0) + (9 \cdot -2) + (12 \cdot 1) + (13 \cdot 0) + (14 \cdot -1) = -8\]

This process continues for all valid positions, creating our \(3 \times 3\) output feature map.
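If you'd like to verify the remaining positions, here's a tiny standalone sketch that computes the full feature map for this example. Every entry comes out as -8, because this particular input increases by exactly 1 per column and W is a Sobel-style horizontal-difference filter, so every window sees the same horizontal change:

import numpy as np

X = np.arange(1.0, 26.0).reshape(5, 5)
W = np.array([[1., 0., -1.], [2., 0., -2.], [1., 0., -1.]])

# One entry per valid placement of the 3x3 filter on the 5x5 input
Y = np.array([[np.sum(X[i:i+3, j:j+3] * W) for j in range(3)] for i in range(3)])
print(Y)
# [[-8. -8. -8.]
#  [-8. -8. -8.]
#  [-8. -8. -8.]]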

Implementation

The cross-correlation equation tells us the conceptual operation, but how does this translate to actual code? Let's walk through the implementation step by step.

Creating the Output Matrix: First, we need to define an output matrix to store our convolutional results. To determine its shape, we ask: "How many valid positions can we place our \(HH \times WW\) filter on our \(H \times W\) input?"

With no padding, we can place a \(3 \times 3\) filter on a \(5 \times 5\) input in exactly 3 positions horizontally and 3 positions vertically (positions 0, 1, 2). The general formula counts these valid positions:

h_out = 1 + (H + 2 * padding - HH) // stride
w_out = 1 + (W + 2 * padding - WW) // stride

Starting from position 0, we can go up to position \((H - HH)\), giving us \((H - HH) + 1\) total positions.
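For instance, plugging in the numbers from our example (plus a hypothetical stride of 2 for comparison):

H, HH, padding = 5, 3, 0

print(1 + (H + 2 * padding - HH) // 1)   # stride 1 -> 3 valid positions per dimension
print(1 + (H + 2 * padding - HH) // 2)   # stride 2 -> 2 valid positions per dimension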

Complete nested loop:

for n in range(N):                 # loop over input images
    for f in range(F):             # loop over filters  
        for i in range(h_out):     # loop over output height positions
            for j in range(w_out): # loop over output width positions
                h_start = i * stride
                w_start = j * stride
                h_end = h_start + HH
                w_end = w_start + WW

Our mathematical equation only showed the \((i,j)\) loops because we were working with a single input image and single filter for clarity. In practice, we need the outer loops to process multiple images in a batch (the \(n\) loop) and apply multiple filters to each image (the \(f\) loop).

Each \((i,j)\) represents an output position, and (h_start, w_start) represents where to place the top-left corner of our filter on the input. With stride=1, these are the same, but with stride=2, we'd have h_start = 0, 2, 4, ...

Vectorized window operation:

window = x[n, :, h_start:h_end, w_start:w_end]
out[n, f, i, j] = np.sum(window * w[f]) + b[f]

This is the vectorized version of our double summation. The indexing x[n, :, h_start:h_end, w_start:w_end] selects:

- image n in the batch,
- all C channels (the :),
- the rows h_start:h_end and columns w_start:w_end, i.e. the spatial window sitting under the filter.

Let's see this in action with our example. When computing output position \((0,0)\):

# h_start=0, w_start=0, so we extract:
window = x[0, :, 0:3, 0:3]  # The top-left 3x3 window

This gives us: \[\text{window} = \begin{bmatrix} 1 & 2 & 3 \\ 6 & 7 & 8 \\ 11 & 12 & 13 \end{bmatrix}, \quad W = \begin{bmatrix} 1 & 0 & -1 \\ 2 & 0 & -2 \\ 1 & 0 & -1 \end{bmatrix}\]

The operation window * w[f] performs element-wise multiplication: \[\text{window} \odot W = \begin{bmatrix} 1 \cdot 1 & 2 \cdot 0 & 3 \cdot (-1) \\ 6 \cdot 2 & 7 \cdot 0 & 8 \cdot (-2) \\ 11 \cdot 1 & 12 \cdot 0 & 13 \cdot (-1) \end{bmatrix} = \begin{bmatrix} 1 & 0 & -3 \\ 12 & 0 & -16 \\ 11 & 0 & -13 \end{bmatrix}\]

Then np.sum() collapses this to a scalar: \[\text{np.sum}(\text{window} \odot W) = 1 + 0 + (-3) + 12 + 0 + (-16) + 11 + 0 + (-13) = -8\] Vectorization means that NumPy handles all the indexing automatically: we don't need to write explicit loops over \(h\) and \(w\) positions within the filter. The vectorized operation window * w[f] performs the \(X[i+h, j+w] \cdot W[h,w]\) multiplication, and np.sum() handles the \(\sum_h \sum_w\) summation.

Code

Here's the naive (but clear) vectorized implementation of the forward pass:

import numpy as np

def conv_forward_naive(x, w, b, conv_param):
    """
    A naive implementation of the forward pass for a convolutional layer.
    
    Input:
    - x: Input data of shape (N, C, H, W)
    - w: Filter weights of shape (F, C, HH, WW)  
    - b: Biases, of shape (F,)
    - conv_param: Dictionary with 'stride' and 'pad' keys
    
    Returns:
    - out: Output data, of shape (N, F, H', W')
    - cache: (x, w, b, conv_param)
    """
    padding = conv_param['pad']
    stride = conv_param['stride']
    N, C, H, W = x.shape
    F, _, HH, WW = w.shape
    
    # Apply padding
    x = np.pad(x, ((0,0), (0,0), (padding, padding), (padding, padding)))
    
    # Calculate output dimensions
    h_out = 1 + (H + 2 * padding - HH) // stride
    w_out = 1 + (W + 2 * padding - WW) // stride
    out = np.zeros((N, F, h_out, w_out))
    
    # The nested loops implement the sliding window operation
    for n in range(N):          # loop over input images
        for f in range(F):      # loop over filters
            for i in range(h_out):   # loop over height indices
                for j in range(w_out): # loop over width indices
                    # Calculate window position
                    h_start = i * stride
                    w_start = j * stride
                    h_end = h_start + HH
                    w_end = w_start + WW
                    
                    # Extract the input window that matches filter dimensions
                    window = x[n, :, h_start:h_end, w_start:w_end]
                    
                    # Compute cross-correlation: element-wise multiply and sum
                    out[n, f, i, j] = np.sum(window * w[f]) + b[f]
    
    cache = (x, w, b, conv_param)
    return out, cache

The key insight is that each output position is computed by taking the element-wise product of the filter with the corresponding input window, then summing all elements. This is exactly the cross-correlation operation we described mathematically.

Note: While we work with 2D examples for clarity, the code naturally handles the 3D case (multiple input channels) because np.sum(window * w[f]) automatically sums over all dimensions, including the channel dimension. We also omit stride from our mathematical equations for simplicity, but it's implemented in the code. The complete mathematical extension including both multiple channels and stride is:

\[Y[n,f,i,j] = \sum_{c} \sum_{h} \sum_{w} X[n,c,i \cdot \text{stride} + h, j \cdot \text{stride} + w] \cdot W[f,c,h,w] + b[f]\]

where \(c\) represents the channel dimension and stride determines the step size between output positions. For the remainder of this post, we'll assume stride=1 in our equations for clarity.
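As a quick sanity check (this is my own test harness, and it assumes SciPy is installed), the naive forward pass above can be compared against SciPy's correlate2d for stride 1 and no padding:

import numpy as np
from scipy.signal import correlate2d  # assumes SciPy is available

np.random.seed(0)
x = np.random.randn(1, 3, 5, 5)   # one image, three channels
w = np.random.randn(2, 3, 3, 3)   # two filters
b = np.random.randn(2)

out, _ = conv_forward_naive(x, w, b, {'stride': 1, 'pad': 0})

# Reference: per-channel cross-correlation, summed over channels, plus the bias
ref = np.zeros_like(out)
for f in range(2):
    ref[0, f] = sum(correlate2d(x[0, c], w[f, c], mode='valid') for c in range(3)) + b[f]

print(np.allclose(out, ref))  # True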

Backward Pass

Gradient with respect to input (dx)

All right, now let's derive the analytical gradient of the loss with respect to the input. In the forward pass, we computed: \[Y[i,j] = \sum_{h=0}^{HH-1} \sum_{w=0}^{WW-1} X[i+h, j+w] \cdot W[h,w]\]

In backpropagation, we receive the gradient of the loss with respect to the output, \(\frac{\partial L}{\partial Y}\) (which we call dout), and need to compute \(\frac{\partial L}{\partial X}\).

Using the chain rule: \[\frac{\partial L}{\partial X[a,b]} = \sum_{i} \sum_{j} \frac{\partial L}{\partial Y[i,j]} \cdot \frac{\partial Y[i,j]}{\partial X[a,b]}\]

The key question is: which output positions \(Y[i,j]\) depend on input position \(X[a,b]\)?

Looking at our forward pass equation, \(Y[i,j]\) uses \(X[a,b]\) if \(a = i+h\) and \(b = j+w\) for some filter positions \((h,w)\). When this happens, we can compute the partial:

\[\frac{\partial Y[i,j]}{\partial X[a,b]} = \frac{\partial}{\partial X[a,b]} \left[ \sum_{h=0}^{HH-1} \sum_{w=0}^{WW-1} X[i+h, j+w] \cdot W[h,w] \right]\]

Since \(X[a,b]\) appears in this sum only when \(a = i+h\) and \(b = j+w\), the partial derivative picks out the coefficient of \(X[a,b]\): \[\frac{\partial Y[i,j]}{\partial X[a,b]} = W[h,w] \text{ where } h = a-i \text{ and } w = b-j\] Therefore: \[\frac{\partial Y[i,j]}{\partial X[a,b]} = W[a-i, b-j]\] Let's clarify what each variable represents:

- \((a,b)\): the input position whose gradient we are computing,
- \((i,j)\): an output position that may have used \(X[a,b]\),
- \((h,w) = (a-i, b-j)\): the position within the filter that lined up with \(X[a,b]\) when computing \(Y[i,j]\).

So \(W[a-i, b-j]\) tells us: "What filter weight was multiplied by input position \((a,b)\) when computing output position \((i,j)\)?"

Therefore: \[\frac{\partial L}{\partial X[a,b]} = \sum_{i}\sum_{j} \frac{\partial L}{\partial Y[i,j]} \cdot W[a-i, b-j]\]

Now, technically we need to be careful about which \((i,j)\) values to sum over (since not all output positions use every input position), but this leads to complex bounds checking that makes the equation unwieldy. The beauty of the padding approach we'll show next is that it eliminates the need for explicit bounds checking entirely: the padding and convolution structure handle all the boundary conditions automatically.

Recognizing the convolution

Let's substitute \(h = a-i\) and \(w = b-j\) to reveal the structure:

\[\frac{\partial L}{\partial X[a,b]} = \sum_{h} \sum_{w} \frac{\partial L}{\partial Y[a-h, b-w]} \cdot W[h,w]\]

This is exactly the definition of convolution! You can see it by comparing with the standard convolution equation:

\[(X * W)[i,j] = \sum_{h} \sum_{w} X[i-h, j-w] \cdot W[h,w]\]

This means that our gradient computation can be rewritten as sliding the flipped filter over the upstream gradient, since that's what convolution does, whereas cross-correlation slides the non-flipped filter, as we saw at the beginning.

Therefore, our gradient computation is: \[\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} * W_{flipped}\]

where \(W_{flipped}\) is the filter rotated 180 degrees (flipped both horizontally and vertically).

Matching the shapes of x and dx

But we have a problem: the dimensions don't match up correctly. Let's see why with our example:

If we try to convolve \(\frac{\partial L}{\partial Y}\) (3×3) with \(W\) (3×3), we get: \[\text{output size} = 3 - 3 + 1 = 1 \times 1\]

But we need \(\frac{\partial L}{\partial X}\) to be \(5 \times 5\) to match the original input!

Solution: Pad the gradient by (filter_size - 1). We need to pad \(\frac{\partial L}{\partial Y}\) by \((HH-1) = 2\) pixels on each side, transforming it from \(3 \times 3\) to \(7 \times 7\):

\[\frac{\partial L}{\partial Y}_{padded} = \begin{bmatrix} \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} \\ \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} \\ \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{red}{dY[0,0]} & \textcolor{red}{dY[0,1]} & \textcolor{red}{dY[0,2]} & \textcolor{blue}{0} & \textcolor{blue}{0} \\ \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{red}{dY[1,0]} & \textcolor{red}{dY[1,1]} & \textcolor{red}{dY[1,2]} & \textcolor{blue}{0} & \textcolor{blue}{0} \\ \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{red}{dY[2,0]} & \textcolor{red}{dY[2,1]} & \textcolor{red}{dY[2,2]} & \textcolor{blue}{0} & \textcolor{blue}{0} \\ \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} \\ \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} & \textcolor{blue}{0} \end{bmatrix}\]

Now convolving this \(7 \times 7\) padded gradient with our \(3 \times 3\) filter gives us: \[\text{output size} = 7 - 3 + 1 = 5 \times 5\]

which matches our original input size.

Explaining the padding visually

Let's see why this padding is necessary by examining what happens to a single input position. Consider \(X[0,0]\) (top-left corner of input):

Our example matrices: \[X = \begin{bmatrix} \textcolor{red}{1} & 2 & 3 & 4 & 5 \\ 6 & 7 & 8 & 9 & 10 \\ 11 & 12 & 13 & 14 & 15 \\ 16 & 17 & 18 & 19 & 20 \\ 21 & 22 & 23 & 24 & 25 \end{bmatrix}, \quad W = \begin{bmatrix} 1 & 0 & -1 \\ 2 & 0 & -2 \\ 1 & 0 & -1 \end{bmatrix}, \quad \frac{\partial L}{\partial Y} = \begin{bmatrix} \textcolor{red}{a} & b & c \\ d & e & f \\ g & h & i \end{bmatrix}\]

Without the padding: To compute \(\frac{\partial L}{\partial X[0,0]}\), we need to convolve with the flipped filter: \[W_{flipped} = \begin{bmatrix} -1 & 0 & 1 \\ -2 & 0 & 2 \\ -1 & 0 & 1 \end{bmatrix}\]

We'd need to place this flipped filter so its bottom-right corner aligns with \(\frac{\partial L}{\partial Y[0,0]}\): \[\begin{bmatrix} \textcolor{gray}{?} & \textcolor{gray}{?} & \textcolor{gray}{?} \\ \textcolor{gray}{?} & \textcolor{gray}{?} & \textcolor{gray}{?} \\ \textcolor{gray}{?} & \textcolor{gray}{?} & \textcolor{red}{1} \end{bmatrix} \text{ trying to align with } \begin{bmatrix} \textcolor{red}{a} & b & c \\ d & e & f \\ g & h & i \end{bmatrix}\]

But the \(\textcolor{gray}{?}\) positions fall outside our \(3 \times 3\) gradient matrix if we try to overlay the red \(1\) on the red \(a\).

With the padding: We pad \(\frac{\partial L}{\partial Y}\) by \((3-1) = 2\) on each side: \[\frac{\partial L}{\partial Y}_{padded} = \begin{bmatrix} 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & \textcolor{red}{a} & b & c & 0 & 0 \\ 0 & 0 & d & e & f & 0 & 0 \\ 0 & 0 & g & h & i & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 \end{bmatrix}\]

Now we can place the flipped filter with its bottom-right corner at position \((2,2)\) where \(\frac{\partial L}{\partial Y[0,0]}\) is located:

\[\begin{bmatrix} -1 & 0 & 1 \\ -2 & 0 & 2 \\ -1 & 0 & \textcolor{red}{1} \end{bmatrix} \text{ overlaid on } \begin{bmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & \textcolor{red}{a} \end{bmatrix}\]

The convolution gives us: \[\frac{\partial L}{\partial X[0,0]} = (-1) \cdot 0 + 0 \cdot 0 + 1 \cdot 0 + (-2) \cdot 0 + 0 \cdot 0 + 2 \cdot 0 + (-1) \cdot 0 + 0 \cdot 0 + 1 \cdot a = a\]

Perfect! This matches our expectation that \(\frac{\partial L}{\partial X[0,0]}\) should only receive contributions from \(\frac{\partial L}{\partial Y[0,0]}\).

With padding, we can place the flipped filter such that its bottom-right corner aligns with \(\frac{\partial L}{\partial Y[0,0]}\) (which is now at position \((2,2)\) in the padded matrix), allowing us to compute the gradient correctly.
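Before turning this into layer code, here is a small self-contained check, with concrete numbers standing in for \(a\) through \(i\), that the pad-then-convolve-with-the-flipped-filter recipe reproduces the chain-rule sum exactly:

import numpy as np

X = np.arange(1.0, 26.0).reshape(5, 5)
W = np.array([[1., 0., -1.], [2., 0., -2.], [1., 0., -1.]])
dY = np.arange(1.0, 10.0).reshape(3, 3)      # concrete stand-ins for a..i

# Route 1: pad dY by 2, flip W, then slide the flipped filter over it
dY_pad = np.pad(dY, 2)
W_flip = W[::-1, ::-1]
dx_conv = np.array([[np.sum(dY_pad[a:a+3, b:b+3] * W_flip) for b in range(5)]
                    for a in range(5)])

# Route 2: the chain-rule sum dX[a,b] = sum_{i,j} dY[i,j] * W[a-i, b-j]
dx_ref = np.zeros((5, 5))
for a in range(5):
    for b in range(5):
        for i in range(3):
            for j in range(3):
                if 0 <= a - i < 3 and 0 <= b - j < 3:
                    dx_ref[a, b] += dY[i, j] * W[a - i, b - j]

print(np.allclose(dx_conv, dx_ref))  # True
print(dx_conv[0, 0], dY[0, 0])       # both 1.0: dX[0,0] only sees dY[0,0] (times W[0,0] = 1)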

Computing dx

The mathematical insight translates directly to code, but let's break down each step to understand how the implementation works.

Step 1: Padding the upstream gradient
dout_padded = np.pad(dout, ((0,0), (0,0), (HH-1, HH-1), (WW-1, WW-1)))

Note that dout represents the upstream gradient \(\frac{\partial L}{\partial Y}\). We pad dout by (HH-1, WW-1) = (2, 2) on each side. Let's see this transformation:

Before padding (3×3):
\[\frac{\partial L}{\partial Y} = \begin{bmatrix} a & b & c \\ d & e & f \\ g & h & i \end{bmatrix}\]

After padding (7×7):
\[\frac{\partial L}{\partial Y}_{padded} = \begin{bmatrix} 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & \textcolor{red}{a} & \textcolor{red}{b} & \textcolor{red}{c} & 0 & 0 \\ 0 & 0 & \textcolor{red}{d} & \textcolor{red}{e} & \textcolor{red}{f} & 0 & 0 \\ 0 & 0 & \textcolor{red}{g} & \textcolor{red}{h} & \textcolor{red}{i} & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 \end{bmatrix}\]

The padding parameters ((0,0), (0,0), (HH-1, HH-1), (WW-1, WW-1)) mean:

- (0,0): no padding on the batch dimension,
- (0,0): no padding on the filter dimension,
- (HH-1, HH-1): pad the height by HH-1 = 2 on the top and bottom,
- (WW-1, WW-1): pad the width by WW-1 = 2 on the left and right.

Step 1.5: Upsampling for stride > 1

Before we can convolve with the flipped filter, we need to handle the stride from the forward pass. Note that this step is only required when the forward pass used a stride > 1. Let me show you exactly why it's necessary with a concrete example.

The dimensional mismatch problem: Consider a forward pass over our \(5 \times 5\) input with the same \(3 \times 3\) filter, but with stride=2.

In the forward pass with stride=2, we only computed outputs whose windows start at input positions (0,0), (0,2), (2,0), and (2,2), creating a 2×2 output.

In the backward pass, we receive dout that's 2×2: \[\text{dout} = \begin{bmatrix} a & b \\ c & d \end{bmatrix}\]

Following our normal process without upsampling:

  1. Pad dout by (3-1)=2: creates 6×6 padded gradient
  2. Convolve 6×6 with 3×3 filter: produces 4×4 result
  3. But we need 5×5 to match original input!

Without upsampling: After padding the 2×2 dout, we get: \[\text{dout}_{padded} = \begin{bmatrix} 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & a & b & 0 & 0 \\ 0 & 0 & c & d & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 \end{bmatrix}\]

Convolving this 6×6 matrix with our 3×3 filter produces only a 4×4 result: \[\text{dx} = \begin{bmatrix} ? & ? & ? & ? \\ ? & ? & ? & ? \\ ? & ? & ? & ? \\ ? & ? & ? & ? \end{bmatrix}\]

But we need 5×5 to match our original input!

With upsampling: First, upsample dout to 3×3 by inserting zeros at non-stride positions: \[\text{dout}_{upsampled} = \begin{bmatrix} a & 0 & b \\ 0 & 0 & 0 \\ c & 0 & d \end{bmatrix}\]

Now pad by (3-1)=2 to get 7×7: \[\text{dout}_{upsampled,padded} = \begin{bmatrix} 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & a & 0 & b & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & c & 0 & d & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 \end{bmatrix}\]

Convolving this 7×7 matrix with our 3×3 filter gives us the correct 5×5 result: \[\text{dx} = \begin{bmatrix} ? & ? & ? & ? & ? \\ ? & ? & ? & ? & ? \\ ? & ? & ? & ? & ? \\ ? & ? & ? & ? & ? \\ ? & ? & ? & ? & ? \end{bmatrix}\]

The upsampling step recreates the "sparse" pattern of the forward pass. When stride=2, the forward pass only computed outputs at every 2nd position. By placing our gradient values at these same sparse positions (with zeros in between), we ensure that the backward convolution correctly maps gradients back to the right input positions.

Code

# Upsample dout to account for stride
if stride > 1:
    # Calculate upsampled dimensions
    H_up = (H_out - 1) * stride + 1
    W_up = (W_out - 1) * stride + 1
    
    # Create upsampled array filled with zeros
    dout_upsampled = np.zeros((N, F, H_up, W_up))
    
    # Place original dout values at strided positions
    dout_upsampled[:, :, ::stride, ::stride] = dout
    
    # Replace dout with upsampled version
    dout = dout_upsampled

Step 2: Flipping the filter
w_flipped = w[:, :, ::-1, ::-1]  # Shape: (F, C, HH, WW)

The slice notation ::-1 reverses the array. Applied to both spatial dimensions, this rotates the filter 180°:

Original filter:
\[W = \begin{bmatrix} 1 & 0 & -1 \\ 2 & 0 & -2 \\ 1 & 0 & -1 \end{bmatrix}\]

Flipped filter:
\[W_{flipped} = \begin{bmatrix} -1 & 0 & 1 \\ -2 & 0 & 2 \\ -1 & 0 & 1 \end{bmatrix}\]

Step 3: Looping over i,j input positions
for i in range(H):        # loop over input height positions (0 to 4)
    for j in range(W):    # loop over input width positions (0 to 4)
        h_start = i       # after upsampling, the backward pass slides one step at a time
        w_start = j       # (so no multiplication by stride here)
        h_end = h_start + HH  # = i + 3
        w_end = w_start + WW  # = j + 3

Unlike the forward pass where we loop over output positions, here we loop over input positions because we're computing gradients for each input position. Each \((i,j)\) represents a position in the original input \(X\) for which we need to compute \(\frac{\partial L}{\partial X[i,j]}\); this directly corresponds to the \((a,b)\) in our mathematical equation:

\[\frac{\partial L}{\partial X[a,b]} = \sum_{h} \sum_{w} \frac{\partial L}{\partial Y[a-h, b-w]} \cdot W[h,w]\]

So when the code processes \((i,j) = (1,1)\), we're computing the gradient for input position \(X[1,1]\), which means \(a=1, b=1\) in our equation.

Step 4: Extracting the right window of dout
dout_window = dout_padded[:, :, h_start:h_end, w_start:w_end]  # Shape: (N, F, HH, WW)

For each input position \((i,j)\), we extract a window from the padded gradient that contains all the \(\frac{\partial L}{\partial Y[a-h, b-w]}\) terms needed by our equation. Instead of manually computing \((a-h, b-w)\) for each \((h,w)\) pair, we extract a \(3 \times 3\) window that contains all these values at once. Let's see what this looks like when computing the gradient for input position \((0,0)\):

# When i=0, j=0: h_start=0, w_start=0, h_end=3, w_end=3
dout_window = dout_padded[:, :, 0:3, 0:3]

This extracts the top-left 3×3 window from the padded gradient:
\[\text{window} = \begin{bmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & \textcolor{red}{a} \end{bmatrix}\]

When computing the gradient for input position \((1,1)\):

# When i=1, j=1: h_start=1, w_start=1, h_end=4, w_end=4
dout_window = dout_padded[:, :, 1:4, 1:4]

This extracts a different 3×3 window:
\[\text{window} = \begin{bmatrix} 0 & 0 & 0 \\ 0 & \textcolor{red}{a} & \textcolor{red}{b} \\ 0 & \textcolor{red}{d} & \textcolor{red}{e} \end{bmatrix}\]

Step 5: Convolving the flipped filter over the dout window
# Convolve dout window with flipped filter and sum over spatial dimensions
result = np.sum(dout_window[:, :, None, :, :] * w_flipped[None, :, :, :, :], axis=(3, 4))

This is the most complex line. Let's break down the broadcasting:

- dout_window has shape (N, F, HH, WW); inserting None gives (N, F, 1, HH, WW),
- w_flipped has shape (F, C, HH, WW); inserting None gives (1, F, C, HH, WW),
- the element-wise product broadcasts to shape (N, F, C, HH, WW),
- np.sum(..., axis=(3, 4)) sums out the two spatial dimensions, leaving shape (N, F, C).

Let's visualize this with our example when computing the gradient for input position \((1,1)\):

The window extracted: When i=1, j=1, we extract dout_padded[:, :, 1:4, 1:4]: \[\text{window} = \begin{bmatrix} 0 & 0 & 0 \\ 0 & \textcolor{red}{a} & \textcolor{red}{b} \\ 0 & \textcolor{red}{d} & \textcolor{red}{e} \end{bmatrix}\] Before spatial summation: The element-wise product gives us (focusing on one batch, one filter, one channel): \[\text{window} \odot W_{flipped} = \begin{bmatrix} 0 & 0 & 0 \\ 0 & a & b \\ 0 & d & e \end{bmatrix} \odot \begin{bmatrix} -1 & 0 & 1 \\ -2 & 0 & 2 \\ -1 & 0 & 1 \end{bmatrix} = \begin{bmatrix} 0 & 0 & 0 \\ 0 & 0 & \textcolor{red}{2b} \\ 0 & 0 & \textcolor{red}{e} \end{bmatrix}\]

After np.sum(..., axis=(3, 4)): We sum over the spatial dimensions (the 3×3 matrix): \[\text{result[n, f, c]} = 0 + 0 + 0 + 0 + 0 + 2b + 0 + 0 + e = 2b + e\]

This tells us: "Filter \(f\) and channel \(c\) contribute \((2b + e)\) to the gradient at input position \((1,1)\) from image \(n\)."

Step 6: Accumulating filter contributions
dx[:, :, i, j] += np.sum(result, axis=1)

The result array has shape (N, F, C), but we need to accumulate into dx[:, :, i, j] which has shape (N, C). We do this by summing over the filter dimension (axis=1).

Put conceptually: each filter used input position \((1,1)\) for several of its outputs in the forward pass, so during backpropagation we need to collect the gradients flowing back from every filter that used this input position and sum them up to get the total gradient for it.

Let's see this with a concrete example. Suppose we have 2 filters and 1 channel, and we're computing the gradient for position \((1,1)\):

Filter 0 (our original filter): \[W_0 = \begin{bmatrix} 1 & 0 & -1 \\ 2 & 0 & -2 \\ 1 & 0 & -1 \end{bmatrix}, \quad W_{0,flipped} = \begin{bmatrix} -1 & 0 & 1 \\ -2 & 0 & 2 \\ -1 & 0 & 1 \end{bmatrix}\] Filter 1 (a different filter): \[W_1 = \begin{bmatrix} 0 & 1 & 0 \\ 1 & 1 & 1 \\ 0 & 1 & 0 \end{bmatrix}, \quad W_{1,flipped} = \begin{bmatrix} 0 & 1 & 0 \\ 1 & 1 & 1 \\ 0 & 1 & 0 \end{bmatrix}\] Computing contributions from each filter:

For Filter 0: \[\text{window} \odot W_{0,flipped} = \begin{bmatrix} 0 & 0 & 0 \\ 0 & a & b \\ 0 & d & e \end{bmatrix} \odot \begin{bmatrix} -1 & 0 & 1 \\ -2 & 0 & 2 \\ -1 & 0 & 1 \end{bmatrix} = \begin{bmatrix} 0 & 0 & 0 \\ 0 & 0 & 2b \\ 0 & 0 & e \end{bmatrix}\] Sum: \(\text{result[n, 0, c]} = 2b + e\)

For Filter 1: \[\text{window} \odot W_{1,flipped} = \begin{bmatrix} 0 & 0 & 0 \\ 0 & a & b \\ 0 & d & e \end{bmatrix} \odot \begin{bmatrix} 0 & 1 & 0 \\ 1 & 1 & 1 \\ 0 & 1 & 0 \end{bmatrix} = \begin{bmatrix} 0 & 0 & 0 \\ 0 & a & b \\ 0 & d & 0 \end{bmatrix}\] Sum: \(\text{result[n, 1, c]} = a + b + d\)

So, in the end, this line of code sums over the filter dimension as follows: \[\text{dx[n, c, 1, 1]} += \text{result[n, 0, c]} + \text{result[n, 1, c]} = (2b + e) + (a + b + d) = a + 3b + d + e\] The += operator matters because we are accumulating contributions into dx rather than overwriting them.

Step 7: Removing original padding
if pad > 0:
    dx = dx[:, :, pad:-pad, pad:-pad]

If the original forward pass used padding, it expanded the input (e.g., from 5×5 to 7×7 with pad=1). Our backward pass must compute gradients with respect to this same expanded input because that's what the forward pass filters actually operated on; the cached input x in our code is already the padded version. After computing gradients for the entire padded space, we extract only the portion corresponding to the original input dimensions, since those are the only gradients we need to propagate backward. In our example with no padding, this step doesn't apply.

Final code
# Initialize dx with zeros
dx = np.zeros((N, C, H, W))

# Save original dout for dw and db computations
dout_orig = dout
H_out_orig, W_out_orig = H_out, W_out

# Upsample dout to account for stride
if stride > 1:
    # Calculate upsampled dimensions
    H_up = (H_out - 1) * stride + 1
    W_up = (W_out - 1) * stride + 1
    
    # Create upsampled array filled with zeros
    dout_upsampled = np.zeros((N, F, H_up, W_up))
    
    # Place original dout values at strided positions
    dout_upsampled[:, :, ::stride, ::stride] = dout
    
    # Replace dout with upsampled version
    dout = dout_upsampled

# Pad dout by (HH-1, WW-1) on each side
dout_padded = np.pad(dout, ((0,0), (0,0), (HH-1, HH-1), (WW-1, WW-1)))

# Flip the filter (rotate 180 degrees)
w_flipped = w[:, :, ::-1, ::-1]  # Shape: (F, C, HH, WW)

# Loop over each input position
for i in range(H):
    for j in range(W):
        # After upsampling, the backward pass slides over dout_padded one
        # step at a time, so the window for input position (i, j) starts at (i, j)
        h_start = i
        w_start = j
        h_end = h_start + HH
        w_end = w_start + WW
        
        # Extract window from padded gradient
        dout_window = dout_padded[:, :, h_start:h_end, w_start:w_end]  # Shape: (N, F, HH, WW)
        
        # Convolve dout window with flipped filter and sum over spatial dimensions
        result = np.sum(dout_window[:, :, None, :, :] * w_flipped[None, :, :, :, :], axis=(3, 4))
        
        # Collect contributions from all filters that used this input position
        dx[:, :, i, j] += np.sum(result, axis=1)

# Remove original padding if it was used
if pad > 0:
    dx = dx[:, :, pad:-pad, pad:-pad]

Gradient with respect to the weights (dw)

Now let's derive the gradient with respect to the filter weights. Starting from our forward pass equation: \[Y[i,j] = \sum_{h=0}^{HH-1} \sum_{w=0}^{WW-1} X[i+h, j+w] \cdot W[h,w]\]

Using the chain rule to find how the loss depends on each weight:

\[\frac{\partial L}{\partial W[h,w]} = \sum_{i} \sum_{j} \frac{\partial L}{\partial Y[i,j]} \cdot \frac{\partial Y[i,j]}{\partial W[h,w]}\]

From our forward pass equation, we can see that each weight \(W[h,w]\) appears in the computation of every output position \(Y[i,j]\), multiplied by the corresponding input value \(X[i+h, j+w]\). Therefore:

\[\frac{\partial Y[i,j]}{\partial W[h,w]} = X[i+h, j+w]\]

Substituting this back in we get:

\[\frac{\partial L}{\partial W[h,w]} = \sum_{i} \sum_{j} \frac{\partial L}{\partial Y[i,j]} \cdot X[i+h, j+w]\]

Let's clarify what each variable represents:

- \((h,w)\): the filter weight whose gradient we are computing,
- \((i,j)\): an output position (every output position uses every weight),
- \(X[i+h, j+w]\): the input value that was multiplied by \(W[h,w]\) when computing \(Y[i,j]\).

So the equation tells us: "To get the gradient for weight \(W[h,w]\), sum over all output positions \((i,j)\) that used this weight, and for each such position, multiply the output gradient \(\frac{\partial L}{\partial Y[i,j]}\) by the input value \(X[i+h, j+w]\) that was paired with this weight."

It's just cross-correlation

We can rearrange this equation to make the structure more apparent:

\[\frac{\partial L}{\partial W[h,w]} = \sum_{i} \sum_{j} X[i+h, j+w] \cdot \frac{\partial L}{\partial Y[i,j]}\]

Similarly to the input gradient section, you'll notice that this looks awfully familiar. That's because it is actually just a cross-correlation, which is defined by the standard equation:

\[(X \star Y)[h,w] = \sum_{i} \sum_{j} X[i+h, j+w] \cdot Y[i,j]\]

Therefore, our gradient computation is:

\[\frac{\partial L}{\partial W} = X \star \frac{\partial L}{\partial Y}\]

In cross-correlation, we slide the input \(X\) over the output gradient \(\frac{\partial L}{\partial Y}\) without flipping either one. This is different from the \(dx\) case where we needed to flip the filter to perform convolution because a convolution is just a cross-correlation, except with a flipped filter.
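Here's a quick numeric check of that claim, reusing the 5×5 input from before with an all-ones matrix standing in for the upstream gradient:

import numpy as np

X = np.arange(1.0, 26.0).reshape(5, 5)
dY = np.ones((3, 3))                          # stand-in upstream gradient

# dW[h, w] = sum_{i,j} X[i+h, j+w] * dY[i, j]  --  cross-correlate X with dY
dW = np.array([[np.sum(X[h:h+3, w:w+3] * dY) for w in range(3)] for h in range(3)])
print(dW.shape)   # (3, 3): the same shape as the filter
print(dW[0, 0])   # 63.0 = 1+2+3+6+7+8+11+12+13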

Computing dw

OK, now let's work through the code for computing \(\frac{\partial L}{\partial W}\), using an example, assuming the input \(X\) and the output gradient \(\frac{\partial L}{\partial Y}\) are: \[X = \begin{bmatrix} 1 & 2 & 3 & 4 & 5 \\ 6 & 7 & 8 & 9 & 10 \\ 11 & 12 & 13 & 14 & 15 \\ 16 & 17 & 18 & 19 & 20 \\ 21 & 22 & 23 & 24 & 25 \end{bmatrix} \hspace{3em} \frac{\partial L}{\partial Y} = \begin{bmatrix} a & b & c \\ d & e & f \\ g & h & i \end{bmatrix}\]

Step 1: Looping over output positions
for i in range(H_out_orig):    # loop over output height positions (0 to 2)
    for j in range(W_out_orig):  # loop over output width positions (0 to 2)
        h_start = i * stride  # = i (since stride=1)
        w_start = j * stride  # = j (since stride=1)
        h_end = h_start + HH  # = i + 3
        w_end = w_start + WW  # = j + 3

Unlike the dx computation where we looped over input positions, here we loop over output positions because our equation sums over all output positions \((i,j)\) that used each weight. Each iteration processes one output position and accumulates its contribution to all filter weights.

Step 2: Extracting the input window
x_slice = x[:, :, h_start:h_end, w_start:w_end]  # Shape: (N, C, HH, WW)

This extracts the input window that was used to compute output position \((i,j)\). This corresponds to the \(X[i+h, j+w]\) terms in our equation.

Let's see what this looks like when processing output position \((0,0)\):

# When i=0, j=0: h_start=0, w_start=0, h_end=3, w_end=3
x_slice = x[:, :, 0:3, 0:3]

This extracts the top-left 3×3 window from the input: \[\text{x slice} = \begin{bmatrix} {1} & {2} & {3} \\ {6} & {7} & {8} \\ {11} & {12} & {13} \end{bmatrix}\]

Step 3: Extracting the output gradient
dout_slice = dout_orig[:, :, i, j]  # Shape: (N, F)

This extracts the gradient with respect to output position \((i,j)\) for all batch samples and all filters. This corresponds to the \(\frac{\partial L}{\partial Y[i,j]}\) term in our equation.

When processing output position \((0,0)\):

dout_slice = dout_orig[:, :, 0, 0]  # Contains value 'a' for all batch samples and filters

When processing output position \((0,1)\):

dout_slice = dout_orig[:, :, 0, 1]  # Contains value 'b' for all batch samples and filters

Step 4: Broadcasting and multiplying
# dout_slice: (N, F) -> (N, F, 1, 1, 1)
# x_slice: (N, C, HH, WW) -> (N, 1, C, HH, WW)
# result: (N, F, C, HH, WW)
result = dout_slice[:, :, None, None, None] * x_slice[:, None, :, :, :]

This is the core of the cross-correlation computation. The broadcasting implements the multiplication \(\frac{\partial L}{\partial Y[i,j]} \cdot X[i+h, j+w]\) for all combinations of filters, channels, and spatial positions.

Let's break down what happens when processing output position \((0,0)\) with our example:

Before broadcasting:

- dout_slice has shape (N, F) and, for this position, holds the value a for every sample and filter,
- x_slice has shape (N, C, 3, 3) and holds the top-left window of the input.

After broadcasting:

- dout_slice[:, :, None, None, None] has shape (N, F, 1, 1, 1),
- x_slice[:, None, :, :, :] has shape (N, 1, C, 3, 3),
- their product has shape (N, F, C, 3, 3): one scaled copy of the input window per (sample, filter, channel).

Visually (focusing on one batch sample, one filter, one channel): The dout_slice value (scalar a) gets multiplied by every element in the 3×3 input window: \[\text{result[n, f, c]} = a \cdot \begin{bmatrix} 1 & 2 & 3 \\ 6 & 7 & 8 \\ 11 & 12 & 13 \end{bmatrix} = \begin{bmatrix} a \cdot 1 & a \cdot 2 & a \cdot 3 \\ a \cdot 6 & a \cdot 7 & a \cdot 8 \\ a \cdot 11 & a \cdot 12 & a \cdot 13 \end{bmatrix}\] This tells us: "Output position \((0,0)\) contributed gradient a to the loss, so each weight position \((h,w)\) receives a gradient contribution of a times the input value that was paired with that weight."
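As with the dx step, a quick shape check with made-up sizes shows how the broadcasting lines up (and what the batch sum in the next step will produce):

import numpy as np

N, F, C, HH, WW = 2, 3, 4, 3, 3
dout_slice = np.ones((N, F))
x_slice = np.ones((N, C, HH, WW))

result = dout_slice[:, :, None, None, None] * x_slice[:, None, :, :, :]
print(result.shape)                     # (2, 3, 4, 3, 3) == (N, F, C, HH, WW)
print(np.sum(result, axis=0).shape)     # (3, 4, 3, 3)    == (F, C, HH, WW)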

Step 5: Accumulating over the batch dimension
dw += np.sum(result, axis=0)  # Shape: (F, C, HH, WW)

The result array has shape (N, F, C, HH, WW), but our filter weights have shape (F, C, HH, WW). We sum over the batch dimension (axis=0) to accumulate gradients from all samples in the batch.

This is necessary because each time we process an output position \((i,j)\), we compute how that specific output position contributes to the gradient of each weight. But each weight is used by multiple output positions, so we need to accumulate all these contributions.

Tracing an example

Let's trace what happens to weight \(W[0,0]\) (top-left corner of the filter):

When processing output position \((0,0)\): \(W[0,0]\) was multiplied by \(X[0,0] = 1\), and the upstream gradient is \(a\), so this position contributes \(a \cdot 1 = a\).

When processing output position \((0,1)\): \(W[0,0]\) was multiplied by \(X[0,1] = 2\), and the upstream gradient is \(b\), so this position contributes \(b \cdot 2 = 2b\).

Continuing this pattern through all 9 output positions, we get:

Output Position   Input Value for W[0,0]   Output Gradient   Contribution to W[0,0]
(0,0)             X[0,0] = 1               a                 a × 1 = a
(0,1)             X[0,1] = 2               b                 b × 2 = 2b
(0,2)             X[0,2] = 3               c                 c × 3 = 3c
(1,0)             X[1,0] = 6               d                 d × 6 = 6d
(1,1)             X[1,1] = 7               e                 e × 7 = 7e
(1,2)             X[1,2] = 8               f                 f × 8 = 8f
(2,0)             X[2,0] = 11              g                 g × 11 = 11g
(2,1)             X[2,1] = 12              h                 h × 12 = 12h
(2,2)             X[2,2] = 13              i                 i × 13 = 13i

\[\frac{\partial L}{\partial W[0,0]} = a + 2b + 3c + 6d + 7e + 8f + 11g + 12h + 13i\] The += operator automatically accumulates these contributions as we process each output position.
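Plugging concrete numbers in for \(a\) through \(i\) (say 1 through 9) makes it easy to verify the table numerically:

import numpy as np

X = np.arange(1.0, 26.0).reshape(5, 5)
dY = np.arange(1.0, 10.0).reshape(3, 3)   # a=1, b=2, ..., i=9

# Accumulate W[0,0]'s gradient over all nine output positions, as the loop does
dW00 = sum(dY[i, j] * X[i + 0, j + 0] for i in range(3) for j in range(3))

# Closed form from the table: a + 2b + 3c + 6d + 7e + 8f + 11g + 12h + 13i
closed_form = 1*1 + 2*2 + 3*3 + 6*4 + 7*5 + 8*6 + 11*7 + 12*8 + 13*9  # X value * gradient
print(dW00, closed_form)   # 411.0 411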

Final code
# Initialize dw with zeros
dw = np.zeros_like(w)  # Shape: (F, C, HH, WW)

# Loop over each output position
for i in range(H_out_orig):
    for j in range(W_out_orig):
        h_start = i * stride
        w_start = j * stride
        h_end = h_start + HH
        w_end = w_start + WW

        # Extract the input window that was used for this output position
        x_slice = x[:, :, h_start:h_end, w_start:w_end]  # Shape: (N, C, HH, WW)

        # Extract the gradient for this output position
        dout_slice = dout_orig[:, :, i, j]  # Shape: (N, F)

        # Compute the contribution to all weight gradients
        # Broadcasting: (N, F, 1, 1, 1) * (N, 1, C, HH, WW) = (N, F, C, HH, WW)
        result = dout_slice[:, :, None, None, None] * x_slice[:, None, :, :, :]

        # Accumulate over the batch dimension
        dw += np.sum(result, axis=0)  # Shape: (F, C, HH, WW)

Gradient with respect to the bias (db)

The bias gradient is the most straightforward of the three. Let's start with our forward pass equation that includes bias:

\[Y[i,j] = \sum_{h=0}^{HH-1} \sum_{w=0}^{WW-1} X[i+h, j+w] \cdot W[h,w] + b\]

Using the chain rule to find how the loss depends on the bias, we get:

\[\frac{\partial L}{\partial b} = \sum_{i} \sum_{j} \frac{\partial L}{\partial Y[i,j]} \cdot \frac{\partial Y[i,j]}{\partial b}\]

Since the bias \(b\) is simply added to each output position, we can treat the weighted sum over the input (the \(X \cdot W\) term) as a constant with respect to \(b\), and we have:

\[\frac{\partial Y[i,j]}{\partial b} = 1\]

Therefore: \[\frac{\partial L}{\partial b} = \sum_{i} \sum_{j} \frac{\partial L}{\partial Y[i,j]} \cdot 1 = \sum_{i} \sum_{j} \frac{\partial L}{\partial Y[i,j]}\]

This means that the bias affects every output position equally by simply being added to each one. During backpropagation, the gradient flows back through this addition unchanged, so the bias gradient is just the sum of all upstream gradients for that filter.

Computing db

Fortunately, the implementation is refreshingly simple:

# db: Sum over all dimensions except F
db = np.sum(dout_orig, axis=(0, 2, 3))  # Shape: (F,)

Visually, if we have our output gradient: \[\frac{\partial L}{\partial Y} = \begin{bmatrix} a & b & c \\ d & e & f \\ g & h & i \end{bmatrix}\] then for a single filter, the bias gradient is: \[\frac{\partial L}{\partial b} = a + b + c + d + e + f + g + h + i\] The bias "collects" gradients from every output position because it contributed equally to each one during the forward pass.
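To wrap up, a good way to convince yourself that all of this is correct is a numerical gradient check. The sketch below assumes you've assembled the dx, dw, and db snippets above into a function conv_backward_naive(dout, cache) that returns (dx, dw, db); that name, and the numerical_grad helper, are mine rather than anything from a library:

import numpy as np

def numerical_grad(f, x, dout, h=1e-6):
    """Central differences of sum(f(x) * dout) with respect to x."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        old = x[idx]
        x[idx] = old + h
        pos = np.sum(f(x) * dout)
        x[idx] = old - h
        neg = np.sum(f(x) * dout)
        x[idx] = old            # restore the original value
        grad[idx] = (pos - neg) / (2 * h)
    return grad

np.random.seed(1)
x = np.random.randn(2, 3, 5, 5)
w = np.random.randn(2, 3, 3, 3)
b = np.random.randn(2)
conv_param = {'stride': 2, 'pad': 1}

out, cache = conv_forward_naive(x, w, b, conv_param)
dout = np.random.randn(*out.shape)
# conv_backward_naive is assumed to be the snippets above assembled into one function
dx, dw, db = conv_backward_naive(dout, cache)

dx_num = numerical_grad(lambda x_: conv_forward_naive(x_, w, b, conv_param)[0], x, dout)
dw_num = numerical_grad(lambda w_: conv_forward_naive(x, w_, b, conv_param)[0], w, dout)
print(np.max(np.abs(dx - dx_num)))  # should be tiny, around 1e-8 or smaller
print(np.max(np.abs(dw - dw_num)))  # likewise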

The end

OK, that's all! I know that was super long and maybe verbose at times, but in this case I prefer to over-explain rather than under-explain. I wanted this to be as complete a reference to understanding the backward pass of the convolutional layer as I could reasonably make it.

If I fell short and this didn't help you, here are some other materials that I benefited from that may help:

Finally, thank you for reading. If you actually read this material, it's likely that you're a beginner in deep learning who is interested in learning it deeply, which I think is commendable. I would love to hear from you if you have any feedback, thoughts, or errata over email or X.

Good luck :)

Namra