Back Propagation in Recurrent Neural Networks

February 28, 2018


In this section, we recap some equations that were derived in detail in the post on back propagation in a fully connected layer [1]. These equations are essential for understanding back propagation in RNNs, including LSTMs. Consider a fully connected layer with $tanh$ as the non-linear activation:

\begin{align} Y &= WX + B \\ Z &= tanh(Y) \\ \begin{bmatrix} y_1 \\ y_2 \\ \end{bmatrix} &= \begin{bmatrix} w_1 & w_2 \\ w_3 & w_4 \\ \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \\ \end{bmatrix} + \begin{bmatrix} b_1 \\ b_2 \\ \end{bmatrix} \\ \begin{bmatrix} z_1 \\ z_2 \\ \end{bmatrix} &= \begin{bmatrix} tanh(y_1) \\ tanh(y_2) \\ \end{bmatrix} \end{align}

If $L$ is the loss of the network and $\frac{\partial L}{\partial Z}$ is given (back-propagated from the next layer), then

\begin{align} \newcommand{\pd}[2]{\frac{\partial #1}{\partial #2}} \pd{L}{Y} & = \begin{bmatrix} \pd{L}{y_1} \\ \pd{L}{y_2} \\ \end{bmatrix} \\ &= \begin{bmatrix} \pd{L}{z_1} (1-z_1^2) \\ \pd{L}{z_2} (1-z_2^2) \\ \end{bmatrix} \\ &= \begin{bmatrix} (1-z_1^2) \\ (1-z_2^2) \\ \end{bmatrix} \odot \begin{bmatrix} \pd{L}{z_1} \\ \pd{L}{z_2} \\ \end{bmatrix} \\ &= \begin{bmatrix} (1-z_1^2) & 0 \\ 0 & (1-z_2^2) \\ \end{bmatrix} \begin{bmatrix} \pd{L}{z_1} \\ \pd{L}{z_2} \\ \end{bmatrix} \\ &= diag(1-Z^2)\pd{L}{Z} \\ \pd{L}{W} &= \pd{L}{Y}X^T \\ \pd{L}{X} &= W^T\pd{L}{Y} \\ \end{align}
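As a sanity check on these recap formulas, the sketch below compares them with central finite differences. It assumes a toy loss $L = c^T Z$ for a fixed, made-up vector $c$ (so that $\frac{\partial L}{\partial Z} = c$) and uses plain Python lists instead of an array library; the names `forward`, `loss` and `c` are illustrative only.

```python
import math

# Toy values (hypothetical): 2x2 W, length-2 column vectors X and B.
W = [[0.5, -0.2], [0.1, 0.4]]
X = [1.0, 2.0]
B = [0.05, -0.1]
# Assumed loss L = c . Z, chosen only so that dL/dZ = c is known exactly.
c = [0.7, -1.3]

def forward(W, X):
    """Y = W X + B, Z = tanh(Y)."""
    Y = [sum(W[i][j] * X[j] for j in range(2)) + B[i] for i in range(2)]
    return Y, [math.tanh(y) for y in Y]

def loss(W, X):
    _, Z = forward(W, X)
    return sum(ci * zi for ci, zi in zip(c, Z))

# Analytic gradients from the recap equations.
_, Z = forward(W, X)
dL_dZ = c
dL_dY = [(1 - Z[i] ** 2) * dL_dZ[i] for i in range(2)]                 # diag(1 - Z^2) dL/dZ
dL_dW = [[dL_dY[i] * X[j] for j in range(2)] for i in range(2)]        # dL/dY X^T
dL_dX = [sum(W[i][j] * dL_dY[i] for i in range(2)) for j in range(2)]  # W^T dL/dY

# Independent check with central finite differences.
eps = 1e-6
num_dW = [[0.0, 0.0], [0.0, 0.0]]
for i in range(2):
    for j in range(2):
        Wp = [row[:] for row in W]
        Wm = [row[:] for row in W]
        Wp[i][j] += eps
        Wm[i][j] -= eps
        num_dW[i][j] = (loss(Wp, X) - loss(Wm, X)) / (2 * eps)
num_dX = []
for j in range(2):
    Xp, Xm = X[:], X[:]
    Xp[j] += eps
    Xm[j] -= eps
    num_dX.append((loss(W, Xp) - loss(W, Xm)) / (2 * eps))

err_W = max(abs(dL_dW[i][j] - num_dW[i][j]) for i in range(2) for j in range(2))
err_X = max(abs(a - b) for a, b in zip(dL_dX, num_dX))
print("dL/dW max abs error:", err_W)
print("dL/dX max abs error:", err_X)
```

Both errors should be tiny (limited by floating-point round-off), which is what one expects if the recap formulas are right.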


Now consider a many-to-one Recurrent Neural Network (RNN) constructed using a single vanilla RNN cell (not an LSTM cell). The computational graph of the network unrolled in time is as shown below.

Let $H_0$ be the initial hidden vector. Let $H_1$, $H_2$ and $H_3$ be the hidden vectors and $X_1$, $X_2$ and $X_3$ be the input vectors at times $t=1$, $t=2$ and $t=3$ respectively. Let $Y$ be the output vector of the network at time $t=3$. Let $h$, $x$ and $y$ be the sizes of the column vectors $H$, $X$ and $Y$ respectively, and let $W$, $U$ and $V$ be the learnable matrices of sizes $h \times h$, $h \times x$ and $y \times h$ respectively. The variables $W_1 = W_2 = W_3 = W$ and $U_1 = U_2 = U_3 = U$ are per-time-step copies of the shared matrices, introduced only to make the back propagation equations easier to write. The equations governing the forward propagation are

\begin{align} H_1 &= tanh(W_1H_0 + U_1X_1) \\ H_2 &= tanh(W_2H_1 + U_2X_2) \\ H_3 &= tanh(W_3H_2 + U_3X_3) \\ Y &= VH_3 \\ \end{align}
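A minimal plain-Python sketch of this forward pass, assuming matrices are stored as lists of rows and vectors as lists. The helper names `matvec` and `rnn_forward`, the sizes and all numeric values are hypothetical. The loop reuses the same `W` and `U` at every step, which is exactly what the copies $W_t = W$ and $U_t = U$ express.

```python
import math

def matvec(M, v):
    """Matrix (list of rows) times column vector (list)."""
    return [sum(m * vj for m, vj in zip(row, v)) for row in M]

def rnn_forward(W, U, V, H0, Xs):
    """Many-to-one vanilla RNN: H_t = tanh(W H_{t-1} + U X_t), Y = V H_T.

    The same W and U are used at every step. Returns [H_0, ..., H_T] and Y.
    """
    Hs = [H0]
    for X in Xs:
        pre = [a + b for a, b in zip(matvec(W, Hs[-1]), matvec(U, X))]
        Hs.append([math.tanh(p) for p in pre])
    return Hs, matvec(V, Hs[-1])

# Hypothetical sizes: h = 2, x = 3, y = 1.
W = [[0.5, -0.3], [0.2, 0.1]]                # h x h
U = [[0.1, 0.0, -0.2], [0.3, 0.4, 0.05]]     # h x x
V = [[1.0, -1.0]]                            # y x h
H0 = [0.0, 0.0]
Xs = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0], [0.0, 1.0, 1.0]]  # X_1, X_2, X_3

Hs, Y = rnn_forward(W, U, V, H0, Xs)
for t, H in enumerate(Hs):
    print(f"H_{t} =", H)
print("Y =", Y)
```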

Let $L$ be the loss (error) of the network that we want to minimize; $L$ is a function of $Y$ and the ground truth. During back propagation, given $\frac{\partial L}{\partial Y}$, we want to calculate $\frac{\partial L}{\partial W}$, $\frac{\partial L}{\partial U}$ and $\frac{\partial L}{\partial V}$. Note that these partial derivatives are matrices of the same shapes as $W$, $U$ and $V$ respectively (just as $\frac{\partial L}{\partial Y}$ has the same shape as $Y$), and they are used to update $W$, $U$ and $V$, e.g. by gradient descent with learning rate $\eta$:

\begin{align} W &= W - \eta\pd{L}{W} \\ U &= U - \eta\pd{L}{U} \\ V &= V - \eta\pd{L}{V} \\ \end{align}
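In code, one such update is an elementwise operation on each matrix. The sketch below assumes a hypothetical helper `gradient_step` and made-up values; it returns a new matrix rather than modifying the old one in place.

```python
def gradient_step(param, grad, eta):
    """One gradient-descent update, param - eta * grad, for a matrix stored as a list of rows."""
    return [[p - eta * g for p, g in zip(p_row, g_row)] for p_row, g_row in zip(param, grad)]

# Hypothetical values; in practice the gradient comes from back propagation.
W = [[1.0, 2.0], [3.0, 4.0]]
dL_dW = [[0.5, 0.5], [1.0, -1.0]]
W_new = gradient_step(W, dL_dW, eta=0.1)
print(W_new)
```

With standard gradient descent, all three updates should use gradients from the same backward pass, computed before any of $W$, $U$ or $V$ is modified.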


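Because the same $W$ and $U$ are reused at every time step, the gradient with respect to each shared matrix is the sum of the gradients with respect to its per-step copies. Applying the recap equations to step $t$, with $H_t$ in the role of $Z$ and $H_{t-1}$ (for $W_t$) or $X_t$ (for $U_t$) in the role of $X$, and noting that $Y = VH_3$ has no activation, gives

\begin{align} \pd{L}{W} &= \pd{L}{W_1} + \pd{L}{W_2} + \pd{L}{W_3} \\ &= diag(1-(H_1)^2)\pd{L}{H_1}H_0^T + diag(1-(H_2)^2)\pd{L}{H_2}H_1^T + diag(1-(H_3)^2)\pd{L}{H_3}H_2^T \\ \pd{L}{U} &= \pd{L}{U_1} + \pd{L}{U_2} + \pd{L}{U_3} \\ &= diag(1-(H_1)^2)\pd{L}{H_1}X_1^T + diag(1-(H_2)^2)\pd{L}{H_2}X_2^T + diag(1-(H_3)^2)\pd{L}{H_3}X_3^T \\ \pd{L}{V} &= \pd{L}{Y}H_3^T \\ \end{align}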
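Given $\frac{\partial L}{\partial H_t}$ for every step, these sums are a single loop over time. The plain-Python sketch below uses hypothetical helper names (`accumulate_grads`, `outer`, `madd`), stores matrices as lists of rows, and runs on a made-up scalar example ($h = x = y = 1$) whose values for $H_t$, $X_t$ and the upstream gradients are chosen only so the result can be checked by hand; they are not produced by a real forward pass.

```python
def outer(a, b):
    """Outer product a b^T of two vectors given as lists."""
    return [[ai * bj for bj in b] for ai in a]

def madd(A, B):
    """Elementwise sum of two matrices stored as lists of rows."""
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def accumulate_grads(Hs, Xs, dHs, dY):
    """Sum the per-step gradients of the shared matrices.

    Hs  = [H_0, ..., H_T], Xs = [X_1, ..., X_T],
    dHs = [dL/dH_1, ..., dL/dH_T], dY = dL/dY.
    """
    dW = [[0.0] * len(Hs[0]) for _ in Hs[0]]
    dU = [[0.0] * len(Xs[0]) for _ in Hs[0]]
    for t in range(1, len(Hs)):
        # diag(1 - H_t^2) dL/dH_t: gradient w.r.t. the pre-activation at step t
        g = [(1 - ht ** 2) * dht for ht, dht in zip(Hs[t], dHs[t - 1])]
        dW = madd(dW, outer(g, Hs[t - 1]))  # add dL/dW_t = g H_{t-1}^T
        dU = madd(dU, outer(g, Xs[t - 1]))  # add dL/dU_t = g X_t^T
    dV = outer(dY, Hs[-1])                  # dL/dV = dL/dY H_T^T
    return dW, dU, dV

# Made-up scalar example (h = x = y = 1) that can be checked by hand.
Hs = [[0.0], [0.5], [-0.5], [0.25]]   # H_0 .. H_3
Xs = [[1.0], [2.0], [3.0]]            # X_1 .. X_3
dHs = [[1.0], [1.0], [1.0]]           # dL/dH_1 .. dL/dH_3 (hypothetical)
dY = [2.0]
dW, dU, dV = accumulate_grads(Hs, Xs, dHs, dY)
print(dW, dU, dV)  # → [[-0.09375]] [[5.0625]] [[0.5]]
```

By hand: $\frac{\partial L}{\partial W} = 0.75 \cdot 0 + 0.75 \cdot 0.5 + 0.9375 \cdot (-0.5) = -0.09375$ and $\frac{\partial L}{\partial U} = 0.75 \cdot 1 + 0.75 \cdot 2 + 0.9375 \cdot 3 = 5.0625$.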

Therefore, we need to calculate $\frac{\partial L}{\partial H_1}$, $\frac{\partial L}{\partial H_2}$ and $\frac{\partial L}{\partial H_3}$. Now,

\begin{align} \pd{L}{H_3} &= V^T\pd{L}{Y} \\ \pd{L}{H_2} &= W^Tdiag(1-(H_3)^2)\pd{L}{H_3} \\ \pd{L}{H_1} &= W^Tdiag(1-(H_2)^2)\pd{L}{H_2} \\ &= W^Tdiag(1-(H_2)^2)W^Tdiag(1-(H_3)^2)\pd{L}{H_3} \\ (\text{if required}) \pd{L}{H_0} &= W^Tdiag(1-(H_1)^2)\pd{L}{H_1} \\ &= W^Tdiag(1-(H_1)^2)W^Tdiag(1-(H_2)^2)W^Tdiag(1-(H_3)^2)\pd{L}{H_3} \\ \end{align}
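Putting the pieces together, the sketch below implements this backward recursion for any number of steps and checks every gradient, including $\frac{\partial L}{\partial H_0}$, against central finite differences. It is plain Python with illustrative helper names (`bptt`, `matTvec`, `numeric_grad` and so on), made-up sizes and random weights, and it assumes a squared-error loss $L = \frac{1}{2}\lVert Y - T\rVert^2$ for a hypothetical target $T$, so that $\frac{\partial L}{\partial Y} = Y - T$.

```python
import math
import random

def matvec(M, v):
    return [sum(m * vj for m, vj in zip(row, v)) for row in M]

def matTvec(M, v):
    """M^T v, computed without forming M^T."""
    return [sum(M[i][j] * v[i] for i in range(len(M))) for j in range(len(M[0]))]

def outer(a, b):
    return [[ai * bj for bj in b] for ai in a]

def madd(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def forward(W, U, V, H0, Xs):
    Hs = [H0]
    for X in Xs:
        pre = [a + b for a, b in zip(matvec(W, Hs[-1]), matvec(U, X))]
        Hs.append([math.tanh(p) for p in pre])
    return Hs, matvec(V, Hs[-1])

def loss(W, U, V, H0, Xs, T):
    # Assumed loss: L = 0.5 * ||Y - T||^2, so dL/dY = Y - T.
    _, Y = forward(W, U, V, H0, Xs)
    return 0.5 * sum((yi - ti) ** 2 for yi, ti in zip(Y, T))

def bptt(W, U, V, H0, Xs, T):
    """Back propagation through time for the many-to-one RNN."""
    Hs, Y = forward(W, U, V, H0, Xs)
    dY = [yi - ti for yi, ti in zip(Y, T)]
    dW = [[0.0] * len(W[0]) for _ in W]
    dU = [[0.0] * len(U[0]) for _ in U]
    dV = outer(dY, Hs[-1])                  # dL/dV = dL/dY H_T^T
    dH = matTvec(V, dY)                     # dL/dH_T = V^T dL/dY
    for t in range(len(Xs), 0, -1):         # t = T, ..., 1
        g = [(1 - ht ** 2) * d for ht, d in zip(Hs[t], dH)]  # diag(1 - H_t^2) dL/dH_t
        dW = madd(dW, outer(g, Hs[t - 1]))  # add dL/dW_t
        dU = madd(dU, outer(g, Xs[t - 1]))  # add dL/dU_t
        dH = matTvec(W, g)                  # dL/dH_{t-1} = W^T diag(1 - H_t^2) dL/dH_t
    return dW, dU, dV, dH                   # the last dH is dL/dH_0

def numeric_grad(f, M, eps=1e-6):
    """Central differences of f() w.r.t. every entry of M (mutated, then restored)."""
    G = [[0.0] * len(M[0]) for _ in M]
    for i in range(len(M)):
        for j in range(len(M[0])):
            old = M[i][j]
            M[i][j] = old + eps
            lp = f()
            M[i][j] = old - eps
            lm = f()
            M[i][j] = old
            G[i][j] = (lp - lm) / (2 * eps)
    return G

def rand_matrix(rows, cols):
    return [[random.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
h, x, y = 3, 2, 2                           # hypothetical sizes
W, U, V = rand_matrix(h, h), rand_matrix(h, x), rand_matrix(y, h)
H0 = [0.1, -0.2, 0.05]
Xs = [[1.0, 0.5], [-0.3, 0.8], [0.2, -1.0]]  # X_1, X_2, X_3
T = [0.5, -0.5]

dW, dU, dV, dH0 = bptt(W, U, V, H0, Xs, T)

def f():
    return loss(W, U, V, H0, Xs, T)

errors = {}
# [H0] wraps the vector as a 1-row matrix that shares the same list object.
for name, analytic, param in (("W", dW, W), ("U", dU, U), ("V", dV, V), ("H0", [dH0], [H0])):
    numeric = numeric_grad(f, param)
    errors[name] = max(abs(a - n) for ra, rn in zip(analytic, numeric) for a, n in zip(ra, rn))
    print(f"dL/d{name}: max abs error = {errors[name]:.2e}")
```

Walking backwards from the last step lets each $\frac{\partial L}{\partial H_{t-1}}$ reuse the vector $diag(1-(H_t)^2)\frac{\partial L}{\partial H_t}$ that is already needed for $\frac{\partial L}{\partial W_t}$ and $\frac{\partial L}{\partial U_t}$, instead of re-multiplying the long products in the expanded form.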

Substituting the values of $\frac{\partial L}{\partial H_1}$, $\frac{\partial L}{\partial H_2}$ and $\frac{\partial L}{\partial H_3}$ gives us the values of $\pd{L}{W}$ and $\pd{L}{U}$ as a function of $W, V, H_0, H_1, H_2, H_3, X_1, X_2, X_3$ and $\pd{L}{Y}$.
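Carrying out that substitution explicitly may make the structure clearer. Writing $D_t = diag(1-(H_t)^2)$ as a shorthand introduced only here, the gradients expand to

\begin{align} \pd{L}{W} &= D_3V^T\pd{L}{Y}H_2^T + D_2W^TD_3V^T\pd{L}{Y}H_1^T + D_1W^TD_2W^TD_3V^T\pd{L}{Y}H_0^T \\ \pd{L}{U} &= D_3V^T\pd{L}{Y}X_3^T + D_2W^TD_3V^T\pd{L}{Y}X_2^T + D_1W^TD_2W^TD_3V^T\pd{L}{Y}X_1^T \\ \end{align}

Each step further back in time contributes one more factor of the form $D_tW^T$ in front of $D_3V^T\pd{L}{Y}$.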


A code implementation may be added in the future.


[1] Back Propagation in Fully Connected Layer.