Back Propagation in Batch Normalization Layer

December 29, 2017

Many popular deep neural networks use a Batch Normalization (BN) layer. While the equations for the forward pass are easy to follow, the equations for the back propagation can appear a bit intimidating. In this post, we will derive the equations for the back propagation of the BN layer. For a batch of two inputs $x_1$ and $x_2$, the equations governing the forward propagation are given by [1]:

\begin{align} \mu &= \frac{x_1 + x_2}{2} \\ \sigma^2 &= \frac{(x_1 - \mu)^2 + (x_2 - \mu)^2}{2} \\ \hat{x}_1 &= \frac{x_1 - \mu}{\sqrt{\sigma^2 + \epsilon}} \\ \hat{x}_2 &= \frac{x_2 - \mu}{\sqrt{\sigma^2 + \epsilon}} \\ y_1 &= \gamma\hat{x}_1 + \beta \\ y_2 &= \gamma\hat{x}_2 + \beta \\ \end{align}

where $\mu$ is the mean, $\sigma^2$ is the variance, $\epsilon$ is a small constant added for numerical stability, $\hat{x}_1$, $\hat{x}_2$ are the normalized intermediate values, $\gamma$, $\beta$ are learnable parameters of the BN layer and $y_1$, $y_2$ are the output values. Computational graphs are useful in understanding backward propagation, and [2] is a must-read if you are not familiar with the basics. The computational graph for the above equations is shown below.

In forward propagation, given inputs $x_1$ and $x_2$, we calculate the outputs $y_1$ and $y_2$. In backward propagation, given $\frac{\partial L}{\partial y_1}$ and $\frac{\partial L}{\partial y_2}$, we want to calculate $\frac{\partial L}{\partial x_1}$, $\frac{\partial L}{\partial x_2}$, $\frac{\partial L}{\partial \gamma}$ and $\frac{\partial L}{\partial \beta}$, where $L$ is the loss (error) of the network. In the process, we will also need the intermediate gradients $\frac{\partial L}{\partial \mu}$, $\frac{\partial L}{\partial \sigma^2}$, $\frac{\partial L}{\partial \hat{x}_1}$ and $\frac{\partial L}{\partial \hat{x}_2}$.
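The forward equations can be sketched in a few lines of plain Python. This is only an illustrative sketch: the function name `bn_forward`, the default `eps=1e-5` and the sample numbers are assumptions made for the example, not part of the derivation.

```python
import math

def bn_forward(x, gamma, beta, eps=1e-5):
    """Hypothetical helper: batch-normalize a list of scalars x."""
    m = len(x)
    mu = sum(x) / m                                   # batch mean
    var = sum((xi - mu) ** 2 for xi in x) / m         # biased batch variance
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(xi - mu) * inv_std for xi in x]         # normalized values
    y = [gamma * xh + beta for xh in x_hat]           # scale and shift
    return y, x_hat, mu, var

# Illustrative inputs (assumed values): two samples, as in the text.
y, x_hat, mu, var = bn_forward([1.0, 3.0], gamma=2.0, beta=0.5)
print(mu, var, x_hat, y)
```

The function returns the intermediate values as well as the outputs because the backward pass reuses them.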

For backward propagation, we start at the bottom of the graph and work our way to the top. Referring to the computational graph and using the chain rule of calculus, we first obtain the values of $\frac{\partial L}{\partial \beta}$ and $\frac{\partial L}{\partial \gamma}$ as follows.

\begin{align} \frac{\partial L}{\partial \beta} &= \frac{\partial L}{\partial y_1} \frac{\partial y_1}{\partial \beta} + \frac{\partial L}{\partial y_2} \frac{\partial y_2}{\partial \beta} \\ &= \frac{\partial L}{\partial y_1} + \frac{\partial L}{\partial y_2} \\ &= \sum_{i=1}^{2} \frac{\partial L}{\partial y_i} \tag{1} \end{align}

\begin{align} \frac{\partial L}{\partial \gamma} &= \frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial \gamma} + \frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial \gamma} \\ &= \frac{\partial L}{\partial y_1}\hat{x}_1 + \frac{\partial L}{\partial y_2}\hat{x}_2 \\ &= \sum_{i=1}^{2} \frac{\partial L}{\partial y_i}\hat{x}_i \tag{2} \end{align}
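Equations 1 and 2 are just two sums over the batch. A minimal sketch, assuming the upstream gradients $\frac{\partial L}{\partial y_i}$ are available as a list `dy` and the normalized values as `x_hat` (both names, the helper `bn_param_grads` and the numbers are hypothetical):

```python
def bn_param_grads(dy, x_hat):
    """Hypothetical helper: gradients of the loss w.r.t. beta and gamma."""
    dbeta = sum(dy)                                    # equation 1
    dgamma = sum(g * xh for g, xh in zip(dy, x_hat))   # equation 2
    return dbeta, dgamma

# Illustrative values (assumed): upstream gradients and normalized inputs.
dbeta, dgamma = bn_param_grads(dy=[0.1, -0.3], x_hat=[-1.0, 1.0])
print(dbeta, dgamma)
```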

Moving up the graph, we next obtain the values of $\frac{\partial L}{\partial \hat{x}_1}$ and $\frac{\partial L}{\partial \hat{x}_2}$ as follows.

\begin{align} \frac{\partial L}{\partial \hat{x}_1} &= \frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial \hat{x}_1} \\ &= \frac{\partial L}{\partial y_1}\gamma \end{align}

\begin{align} \frac{\partial L}{\partial \hat{x}_2} &= \frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial \hat{x}_2} \\ &= \frac{\partial L}{\partial y_2}\gamma \end{align}

More generally, $$ \frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i}\gamma \tag{3} $$

Next, we calculate $\frac{\partial L}{\partial \sigma^2}$ and $\frac{\partial L}{\partial \mu}$ as follows.

\begin{align} \frac{\partial L}{\partial \sigma^2} &= \frac{\partial L}{\partial \hat{x}_1}\frac{\partial \hat{x}_1}{\partial \sigma^2} + \frac{\partial L}{\partial \hat{x}_2}\frac{\partial \hat{x}_2}{\partial \sigma^2} \\ &= \sum_{i=1}^{2} \frac{\partial L}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial \sigma^2} \\ &= \sum_{i=1}^{2} \frac{\partial L}{\partial \hat{x}_i}(x_i - \mu)\frac{-1}{2}(\sigma^2 + \epsilon)^{-3/2} \tag{4} \end{align}

\begin{align} \frac{\partial L}{\partial \mu} &= \frac{\partial L}{\partial \hat{x}_1}\frac{\partial \hat{x}_1}{\partial \mu} + \frac{\partial L}{\partial \hat{x}_2}\frac{\partial \hat{x}_2}{\partial \mu} + \frac{\partial L}{\partial \sigma^2}\frac{\partial \sigma^2}{\partial \mu} \\ &= \sum_{i=1}^{2}\frac{\partial L}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial \mu} + \frac{\partial L}{\partial \sigma^2}\frac{\partial \sigma^2}{\partial \mu} \\ &= \sum_{i=1}^{2}\frac{\partial L}{\partial \hat{x}_i}\frac{-1}{\sqrt{\sigma^2 + \epsilon}} + \frac{\partial L}{\partial \sigma^2}\frac{-2(x_1 - \mu) - 2(x_2 - \mu)}{2} \\ &= \sum_{i=1}^{2}\frac{\partial L}{\partial \hat{x}_i}\frac{-1}{\sqrt{\sigma^2 + \epsilon}} + \frac{\partial L}{\partial \sigma^2}\frac{\sum_{i=1}^{2}-2(x_i - \mu)}{2} \tag{5} \end{align}
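Equations 4 and 5 can be evaluated directly. In the sketch below the two terms of equation 5 are computed separately so that each can be inspected on its own; the helper name `bn_stat_grads`, the argument names and the sample numbers are assumptions made for illustration.

```python
import math

def bn_stat_grads(x, dx_hat, eps=1e-5):
    """Hypothetical helper: gradients w.r.t. the variance and the mean."""
    m = len(x)
    mu = sum(x) / m
    var = sum((xi - mu) ** 2 for xi in x) / m
    # Equation 4: sum_i dL/dx_hat_i * (x_i - mu) * (-1/2) * (var + eps)^(-3/2)
    dvar = sum(g * (xi - mu) for g, xi in zip(dx_hat, x)) * -0.5 * (var + eps) ** -1.5
    # Equation 5, first term: direct path from mu to each x_hat_i
    dmu_direct = sum(dx_hat) * -1.0 / math.sqrt(var + eps)
    # Equation 5, second term: path from mu through the variance
    dmu_via_var = dvar * sum(-2.0 * (xi - mu) for xi in x) / m
    return dvar, dmu_direct, dmu_via_var

# Illustrative values (assumed): inputs and gradients w.r.t. x_hat.
dvar, dmu_direct, dmu_via_var = bn_stat_grads(x=[1.0, 3.0], dx_hat=[0.2, -0.6])
print(dvar, dmu_direct, dmu_via_var)
```

For these inputs `dmu_via_var` evaluates to zero, so only the direct term contributes to $\frac{\partial L}{\partial \mu}$.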

Note that the second term in equation 5 is zero because $\sum_{i=1}^{2} (x_i - \mu) = 0$. Now, we have everything required to calculate $\frac{\partial L}{\partial x_1}$ and $\frac{\partial L}{\partial x_2}$.

\begin{align} \frac{\partial L}{\partial x_1} &= \frac{\partial L}{\partial \hat{x}_1}\frac{\partial \hat{x}_1}{\partial x_1} + \frac{\partial L}{\partial \sigma^2}\frac{\partial \sigma^2}{\partial x_1} + \frac{\partial L}{\partial \mu}\frac{\partial \mu}{\partial x_1} \\ &= \frac{\partial L}{\partial \hat{x}_1}\frac{1}{\sqrt{\sigma^2 + \epsilon}} + \frac{\partial L}{\partial \sigma^2}\frac{2(x_1 - \mu)}{2} + \frac{\partial L}{\partial \mu}\frac{1}{2} \end{align}

The equation for $\frac{\partial L}{\partial x_2}$ is analogous. More generally, for a batch of $m$ inputs (with the sums in equations 1 through 5 running from $1$ to $m$ and the divisor $2$ replaced by $m$),

$$ \frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i}\frac{1}{\sqrt{\sigma^2 + \epsilon}} + \frac{\partial L}{\partial \sigma^2}\frac{2(x_i - \mu)}{m} + \frac{\partial L}{\partial \mu}\frac{1}{m} \tag{6} $$
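One way to gain confidence in equations 1 through 6 is to compare them against a numerical derivative. The sketch below is an assumption-laden illustration, not a reference implementation: the helpers `bn_forward` and `bn_backward`, the linear loss $L = \sum_i w_i y_i$ (chosen so that $\frac{\partial L}{\partial y_i} = w_i$), and all numbers are made up for the check.

```python
import math

def bn_forward(x, gamma, beta, eps=1e-5):
    """Hypothetical helper: forward pass over a list of scalars."""
    m = len(x)
    mu = sum(x) / m
    var = sum((xi - mu) ** 2 for xi in x) / m
    inv_std = 1.0 / math.sqrt(var + eps)
    return [gamma * (xi - mu) * inv_std + beta for xi in x]

def bn_backward(x, dy, gamma, eps=1e-5):
    """Hypothetical helper: backward pass following equations 1 to 6."""
    m = len(x)
    mu = sum(x) / m
    var = sum((xi - mu) ** 2 for xi in x) / m
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(xi - mu) * inv_std for xi in x]
    dbeta = sum(dy)                                             # equation 1
    dgamma = sum(g * xh for g, xh in zip(dy, x_hat))            # equation 2
    dx_hat = [g * gamma for g in dy]                            # equation 3
    # equation 4; inv_std ** 3 equals (var + eps) ** (-3/2)
    dvar = sum(g * (xi - mu) for g, xi in zip(dx_hat, x)) * -0.5 * inv_std ** 3
    # equation 5
    dmu = -inv_std * sum(dx_hat) + dvar * sum(-2.0 * (xi - mu) for xi in x) / m
    # equation 6
    dx = [g * inv_std + dvar * 2.0 * (xi - mu) / m + dmu / m
          for g, xi in zip(dx_hat, x)]
    return dx, dgamma, dbeta

# Assumed test data: m = 3 inputs and a linear loss L = sum(w_i * y_i).
x, w, gamma, beta = [0.5, -1.2, 2.0], [0.3, -0.7, 1.1], 1.5, 0.2
dx, dgamma, dbeta = bn_backward(x, w, gamma)

def loss(xs):
    return sum(wi * yi for wi, yi in zip(w, bn_forward(xs, gamma, beta)))

# Central finite differences as an independent estimate of dL/dx_i.
h = 1e-6
dx_numeric = []
for i in range(len(x)):
    x_plus = list(x)
    x_plus[i] += h
    x_minus = list(x)
    x_minus[i] -= h
    dx_numeric.append((loss(x_plus) - loss(x_minus)) / (2 * h))

max_err = max(abs(a - b) for a, b in zip(dx, dx_numeric))
print(dx, dx_numeric, max_err)
```

If the analytic and numerical gradients disagree by more than a small tolerance, the usual suspects are a dropped $\frac{1}{m}$ factor or a sign error in equation 4.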

References

  1. Sergey Ioffe, Christian Szegedy. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. arXiv.
  2. Christopher Olah. Calculus on Computational Graphs: Backpropagation.