# Back Propagation in Recurrent Neural Networks

## Background

In this section, we recap some equations that were derived in detail in the post on back propagation in a fully connected layer. These equations are essential to understanding back propagation in RNNs, including LSTMs. Consider a fully connected layer with $tanh$ as the non-linear activation, e.g.

\begin{align} Y &= WX + B \\ Z &= tanh(Y) \\ \begin{bmatrix} y_1 \\ y_2 \\ \end{bmatrix} &= \begin{bmatrix} w_1 & w_2 \\ w_3 & w_4 \\ \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \\ \end{bmatrix} + \begin{bmatrix} b_1 \\ b_2 \\ \end{bmatrix} \\ \begin{bmatrix} z_1 \\ z_2 \\ \end{bmatrix} &= \begin{bmatrix} tanh(y_1) \\ tanh(y_2) \\ \end{bmatrix} \end{align}
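As a concrete illustration, here is a minimal NumPy sketch of this forward pass; the $2 \times 2$ shapes and the names `W`, `X`, `B`, `Y`, `Z` simply mirror the equations above and are otherwise arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((2, 2))   # weight matrix
B = rng.standard_normal((2, 1))   # bias column vector
X = rng.standard_normal((2, 1))   # input column vector

Y = W @ X + B                     # affine part: Y = WX + B
Z = np.tanh(Y)                    # elementwise non-linearity: Z = tanh(Y)
```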

If $L$ is the loss of the network and $\frac{\partial L}{\partial Z}$ is given (propagated back from the layer that follows this one), then

\begin{align} \newcommand{\pd}[2]{\frac{\partial #1}{\partial #2}} \pd{L}{Y} & = \begin{bmatrix} \pd{L}{y_1} \\ \pd{L}{y_2} \\ \end{bmatrix} \\ &= \begin{bmatrix} \pd{L}{z_1} (1-z_1^2) \\ \pd{L}{z_2} (1-z_2^2) \\ \end{bmatrix} \\ &= \begin{bmatrix} (1-z_1^2) \\ (1-z_2^2) \\ \end{bmatrix} \odot \begin{bmatrix} \pd{L}{z_1} \\ \pd{L}{z_2} \\ \end{bmatrix} \\ &= \begin{bmatrix} (1-z_1^2) & 0 \\ 0 & (1-z_2^2) \\ \end{bmatrix} \begin{bmatrix} \pd{L}{z_1} \\ \pd{L}{z_2} \\ \end{bmatrix} \\ &= diag(1-Z^2)\pd{L}{Z} \\ \pd{L}{W} &= \pd{L}{Y}X^T \\ \pd{L}{X} &= W^T\pd{L}{Y} \\ \end{align}
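The corresponding backward pass can be sketched the same way; `dZ` below is a stand-in for $\frac{\partial L}{\partial Z}$ handed back by the layer that follows this one.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((2, 2))
B = rng.standard_normal((2, 1))
X = rng.standard_normal((2, 1))
Z = np.tanh(W @ X + B)            # forward pass from the previous sketch

dZ = np.ones((2, 1))              # stand-in for dL/dZ from the next layer

dY = (1.0 - Z**2) * dZ            # dL/dY = diag(1 - Z^2) dL/dZ, written elementwise
dW = dY @ X.T                     # dL/dW = dL/dY X^T
dX = W.T @ dY                     # dL/dX = W^T dL/dY
dB = dY                           # dL/dB = dL/dY (included for completeness)
```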

## Many-to-one

Now consider a many-to-one Recurrent Neural Network (RNN) constructed using a single vanilla RNN cell (not an LSTM cell). The computational graph of the network unrolled in time is shown below. Let $H_0$ be the initial hidden vector. Let $H_1$, $H_2$ and $H_3$ be the hidden vectors and $X_1$, $X_2$ and $X_3$ be the input vectors at times $t=1$, $t=2$ and $t=3$ respectively. Let $Y$ be the output vector of the network at time $t=3$. Let $h$, $x$ and $y$ be the sizes of the column vectors $H$, $X$ and $Y$ respectively, and let $W$, $U$ and $V$ be the learnable matrices of sizes $h \times h$, $h \times x$ and $y \times h$ respectively. Let $W_1 = W_2 = W_3 = W$ and $U_1 = U_2 = U_3 = U$ be dummy variables introduced to make the back propagation equations easier to write. The equations governing the forward propagation are given by,

\begin{align} H_1 &= tanh(W_1H_0 + U_1X_1) \\ H_2 &= tanh(W_2H_1 + U_2X_2) \\ H_3 &= tanh(W_3H_2 + U_3X_3) \\ Y &= VH_3 \\ \end{align}
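A minimal NumPy sketch of this unrolled forward pass is given below; the sizes `h`, `x`, `y` and the random initialization are illustrative only.

```python
import numpy as np

h, x, y = 4, 3, 2                       # illustrative sizes of H, X and Y
rng = np.random.default_rng(1)
W = rng.standard_normal((h, h)) * 0.1   # hidden-to-hidden, shape h x h
U = rng.standard_normal((h, x)) * 0.1   # input-to-hidden, shape h x x
V = rng.standard_normal((y, h)) * 0.1   # hidden-to-output, shape y x h

H0 = np.zeros((h, 1))                   # initial hidden vector
X1, X2, X3 = (rng.standard_normal((x, 1)) for _ in range(3))

# The same W and U are reused at every time step (W_1 = W_2 = W_3 = W, etc.).
H1 = np.tanh(W @ H0 + U @ X1)
H2 = np.tanh(W @ H1 + U @ X2)
H3 = np.tanh(W @ H2 + U @ X3)
Y  = V @ H3
```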

Let $L$ be the loss (error) of the network that we want to minimize; $L$ is a function of $Y$ and the ground truth. During backward propagation, given $\frac{\partial L}{\partial Y}$, we want to calculate $\frac{\partial L}{\partial W}$, $\frac{\partial L}{\partial U}$ and $\frac{\partial L}{\partial V}$. Note that these partial derivatives are matrices of the same shape as $W$, $U$ and $V$ respectively, and are used to update the matrices $W$, $U$ and $V$, e.g.

\begin{align} W &= W - \eta\pd{L}{W} \\ U &= U - \eta\pd{L}{U} \\ V &= V - \eta\pd{L}{V} \\ \end{align}

Now,

\begin{align} \pd{L}{W} &= \pd{L}{W_1} + \pd{L}{W_2} + \pd{L}{W_3} \\ &= diag(1-(H_1)^2)\pd{L}{H_1}H_0^T + diag(1-(H_2)^2)\pd{L}{H_2}H_1^T + diag(1-(H_3)^2)\pd{L}{H_3}H_2^T \\ \pd{L}{U} &= \pd{L}{U_1} + \pd{L}{U_2} + \pd{L}{U_3} \\ &= diag(1-(H_1)^2)\pd{L}{H_1}X_1^T + diag(1-(H_2)^2)\pd{L}{H_2}X_2^T + diag(1-(H_3)^2)\pd{L}{H_3}X_3^T \\ \pd{L}{V} &= \pd{L}{Y}H_3^T \\ \end{align}
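The summation structure of $\frac{\partial L}{\partial W}$ and $\frac{\partial L}{\partial U}$ can be sketched as an accumulation over the unrolled time steps. In the snippet below the hidden states, inputs and $\frac{\partial L}{\partial H_t}$ are random stand-ins; how the latter are actually obtained is the subject of the next set of equations.

```python
import numpy as np

h, x = 4, 3
rng = np.random.default_rng(2)
Hs  = [rng.standard_normal((h, 1)) for _ in range(4)]   # stand-ins for H0..H3
Xs  = [rng.standard_normal((x, 1)) for _ in range(3)]   # stand-ins for X1..X3
dHs = [rng.standard_normal((h, 1)) for _ in range(3)]   # stand-ins for dL/dH1..dL/dH3

# dL/dW and dL/dU are sums over time steps, because the same W and U
# are reused at every step of the unrolled graph.
dW = np.zeros((h, h))
dU = np.zeros((h, x))
for t in range(3):
    dPre = (1.0 - Hs[t + 1]**2) * dHs[t]   # diag(1 - H_t^2) dL/dH_t
    dW += dPre @ Hs[t].T                   # ... times H_{t-1}^T
    dU += dPre @ Xs[t].T                   # ... times X_t^T
```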

Therefore, we need to calculate $\frac{\partial L}{\partial H_1}$, $\frac{\partial L}{\partial H_2}$ and $\frac{\partial L}{\partial H_3}$. Now,

\begin{align} \pd{L}{H_3} &= V^T\pd{L}{Y} \\ \pd{L}{H_2} &= W^Tdiag(1-(H_3)^2)\pd{L}{H_3} \\ \pd{L}{H_1} &= W^Tdiag(1-(H_2)^2)\pd{L}{H_2} \\ &= W^Tdiag(1-(H_2)^2)W^Tdiag(1-(H_3)^2)\pd{L}{H_3} \\ (\text{if required}) \pd{L}{H_0} &= W^Tdiag(1-(H_1)^2)\pd{L}{H_1} \\ &= W^Tdiag(1-(H_1)^2)W^Tdiag(1-(H_2)^2)W^Tdiag(1-(H_3)^2)\pd{L}{H_3} \\ \end{align}
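Continuing the sketch, the hidden-state gradients are obtained by walking backwards through time; `dY` is a stand-in for $\frac{\partial L}{\partial Y}$.

```python
import numpy as np

h, x, y = 4, 3, 2
rng = np.random.default_rng(1)
W = rng.standard_normal((h, h)) * 0.1
U = rng.standard_normal((h, x)) * 0.1
V = rng.standard_normal((y, h)) * 0.1

H0 = np.zeros((h, 1))
X1, X2, X3 = (rng.standard_normal((x, 1)) for _ in range(3))
H1 = np.tanh(W @ H0 + U @ X1)                # forward pass, as before
H2 = np.tanh(W @ H1 + U @ X2)
H3 = np.tanh(W @ H2 + U @ X3)

dY = np.ones((y, 1))                         # stand-in for dL/dY

dH3 = V.T @ dY                               # dL/dH3 = V^T dL/dY
dH2 = W.T @ ((1.0 - H3**2) * dH3)            # dL/dH2 = W^T diag(1 - H3^2) dL/dH3
dH1 = W.T @ ((1.0 - H2**2) * dH2)            # dL/dH1 = W^T diag(1 - H2^2) dL/dH2
dH0 = W.T @ ((1.0 - H1**2) * dH1)            # only needed if H0 is itself learnable
```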

Substituting the values of $\frac{\partial L}{\partial H_1}$, $\frac{\partial L}{\partial H_2}$ and $\frac{\partial L}{\partial H_3}$ gives us $\pd{L}{W}$ and $\pd{L}{U}$ as functions of $W, V, H_0, H_1, H_2, H_3, X_1, X_2, X_3$ and $\pd{L}{Y}$.
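Putting everything together, the sketch below runs the full backward pass of the three-step network and checks one entry of $\frac{\partial L}{\partial W}$ against a numerical gradient. The squared-error loss $L = \frac{1}{2}\lVert Y - T\rVert^2$ (so that $\frac{\partial L}{\partial Y} = Y - T$) is only an assumption made here so that $L$ is a concrete scalar; any differentiable loss would do.

```python
import numpy as np

h, x, y = 4, 3, 2
rng = np.random.default_rng(3)
W = rng.standard_normal((h, h)) * 0.1
U = rng.standard_normal((h, x)) * 0.1
V = rng.standard_normal((y, h)) * 0.1
H0 = np.zeros((h, 1))
Xs = [rng.standard_normal((x, 1)) for _ in range(3)]   # X1, X2, X3
T  = rng.standard_normal((y, 1))                       # ground truth

def loss(W, U, V):
    """Forward pass followed by the assumed loss L = 0.5 * ||Y - T||^2."""
    H = H0
    for X in Xs:
        H = np.tanh(W @ H + U @ X)
    return 0.5 * float(np.sum((V @ H - T) ** 2))

# Forward pass, keeping the hidden states for the backward pass.
Hs = [H0]
for X in Xs:
    Hs.append(np.tanh(W @ Hs[-1] + U @ X))
Y = V @ Hs[-1]

# Backward pass, using the equations derived above.
dY = Y - T                           # dL/dY for the assumed squared-error loss
dV = dY @ Hs[-1].T                   # dL/dV = dL/dY H3^T
dW = np.zeros_like(W)
dU = np.zeros_like(U)
dH = V.T @ dY                        # dL/dH3 = V^T dL/dY
for t in range(3, 0, -1):
    dPre = (1.0 - Hs[t] ** 2) * dH   # diag(1 - H_t^2) dL/dH_t
    dW += dPre @ Hs[t - 1].T
    dU += dPre @ Xs[t - 1].T
    dH = W.T @ dPre                  # dL/dH_{t-1}

# Numerical check of one entry of dL/dW via central differences.
eps = 1e-6
Wp, Wm = W.copy(), W.copy()
Wp[0, 0] += eps
Wm[0, 0] -= eps
numerical = (loss(Wp, U, V) - loss(Wm, U, V)) / (2 * eps)
print(abs(numerical - dW[0, 0]))     # should print a very small number
```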

## Many-to-many

Maybe in the future.