Backpropagation Through Time: recurrent neural network training technique

July 5, 2017, 9:18 p.m. By: Pranjal Kumar


Backpropagation Through Time (BPTT) is the algorithm used to update the weights of a recurrent neural network (RNN), such as the widely used LSTM. Understanding BPTT is essential if you want to effectively frame sequence prediction problems for recurrent neural networks. You should also be aware of how BPTT affects the stability and speed of training.

The ultimate goal of the Backpropagation algorithm is to minimize the error of the network outputs.

The general algorithm is:

  1. First, present the input pattern and propagate it through the network to get the output.

  2. Then compare the predicted output to the expected output and calculate the error.

  3. Then calculate the derivatives of the error with respect to the network weights.

  4. Finally, adjust the weights in the direction that reduces the error.
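The four steps above can be sketched in a few lines of Python. This is a minimal illustration, assuming a hypothetical one-weight linear model (y_hat = w*x + b) with a mean squared error loss instead of a real multi-layer network; the function name `train_step`, the data, and the learning rate are made up for the example.

```python
def train_step(w, b, xs, ys, lr=0.05):
    """One backpropagation step for the hypothetical model y_hat = w*x + b."""
    n = len(xs)
    # 1. Forward pass: propagate each input through the model.
    preds = [w * x + b for x in xs]
    # 2. Compare predictions with the expected outputs; compute the mean squared error.
    errors = [p - y for p, y in zip(preds, ys)]
    loss = sum(e * e for e in errors) / n
    # 3. Derivatives of the loss with respect to each weight.
    dw = sum(2 * e * x for e, x in zip(errors, xs)) / n
    db = sum(2 * e for e in errors) / n
    # 4. Move each weight a small step against its derivative.
    return w - lr * dw, b - lr * db, loss


xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]  # generated by y = 2x + 1
w, b = 0.0, 0.0
losses = []
for _ in range(1000):
    w, b, loss = train_step(w, b, xs, ys)
    losses.append(loss)  # loss measured before each update
print(round(w, 3), round(b, 3))  # prints 2.0 1.0
```

Repeating the step drives the weights toward the values that generated the data. A real network applies the chain rule through every layer in step 3, but the loop structure is the same.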

The Backpropagation algorithm is suited to feedforward neural networks trained on fixed-size input-output pairs.

Backpropagation Through Time is the application of the Backpropagation training algorithm to recurrent neural networks working on sequence data such as time series. A recurrent neural network is shown one input at each timestep and predicts the corresponding output. BPTT works by unrolling all input timesteps: each timestep gets one input, one output, and one copy of the network, and all copies share the same weights. The errors are then calculated and accumulated across the timesteps, and the network is rolled back up to update the weights.
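A minimal sketch of BPTT, assuming a hypothetical RNN with a single scalar hidden unit: h_t = tanh(w*h_{t-1} + u*x_t), prediction y_hat_t = v*h_t, and a squared-error loss summed over timesteps. The forward loop unrolls one copy of the network per timestep; the backward loop walks those copies in reverse and accumulates gradients into the shared weights w, u, and v. The function name `bptt` and all values are illustrative.

```python
import math


def bptt(params, xs, ys, h0=0.0):
    """Return (loss, [dw, du, dv]) for one sequence using full BPTT."""
    w, u, v = params
    # Forward pass: unroll one copy of the network per timestep.
    hs = [h0]
    loss = 0.0
    for x, y in zip(xs, ys):
        h = math.tanh(w * hs[-1] + u * x)
        hs.append(h)
        loss += 0.5 * (v * h - y) ** 2
    # Backward pass: walk the unrolled copies from the last timestep to the first.
    dw = du = dv = 0.0
    dh_next = 0.0  # gradient reaching h_t from timestep t+1
    for t in reversed(range(len(xs))):
        h, h_prev = hs[t + 1], hs[t]
        err = v * h - ys[t]      # dL_t / dy_hat_t
        dv += err * h            # accumulate into the shared output weight
        dh = err * v + dh_next   # total gradient reaching h_t
        da = dh * (1.0 - h * h)  # back through tanh
        dw += da * h_prev        # accumulate into the shared recurrent weight
        du += da * xs[t]         # accumulate into the shared input weight
        dh_next = da * w         # pass the gradient on to h_{t-1}
    return loss, [dw, du, dv]


params = [0.5, -0.3, 0.8]
xs = [1.0, 0.5, -1.0, 2.0]
ys = [0.2, -0.1, 0.4, 0.3]
loss, grads = bptt(params, xs, ys)
# Roll back up: one gradient-descent update of the shared weights.
params = [p - 0.1 * g for p, g in zip(params, grads)]
```

Because the same three weights are reused at every timestep, the per-timestep gradients are summed before the single update at the end. A finite-difference check is a simple way to confirm a hand-written backward pass like this one.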

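One disadvantage of BPTT is that its computational cost grows with the number of timesteps: a sequence of thousands of timesteps requires that many steps of derivatives for a single weight update. Such long chains of derivatives can also make the gradients vanish or explode, which makes learning slow and noisy. The high cost of each parameter update makes full BPTT impractical for a large number of iterations.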
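To see why long unrolled sequences are costly and can make learning noisy, consider a hypothetical linear recurrence h_t = w*h_{t-1}. By the chain rule, the derivative of h_T with respect to h_0 is the product of T per-step derivatives, that is w**T. The sketch below computes that product one timestep at a time: the loop gets longer as T grows, and the product shrinks toward zero when |w| < 1 or grows rapidly when |w| > 1. The function `gradient_through_time` is illustrative only.

```python
def gradient_through_time(w, steps):
    """d h_T / d h_0 for the hypothetical recurrence h_t = w * h_{t-1}."""
    grad = 1.0
    for _ in range(steps):  # one chain-rule factor per unrolled timestep
        grad *= w           # dh_t / dh_{t-1} = w
    return grad


for steps in (10, 100, 1000):
    print(steps, gradient_through_time(0.9, steps), gradient_through_time(1.1, steps))
```

In a real RNN each factor is a Jacobian rather than a scalar, but the same repeated multiplication is what makes very long BPTT chains both expensive and numerically unstable.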

This is where Truncated Backpropagation Through Time (TBPTT) saves the day. TBPTT is a modified version of the BPTT algorithm for recurrent neural networks in which the sequence is processed one timestep at a time and, periodically, a BPTT update is performed back over a fixed number of timesteps.

The basic Truncated Backpropagation algorithm is:

  1. First, present a sequence of k1 timesteps of input and output pairs to the network.

  2. Then unroll the network and calculate and accumulate the errors across k2 timesteps.

  3. Finally, roll the network back up and update the weights.

As you can see, TBPTT needs two parameters, k1 and k2. k1 is the number of forward-pass timesteps between updates; because it sets how often the weights are updated, it influences how fast or slow training is. k2 is the number of timesteps over which BPTT is applied; it should be large enough to capture the temporal structure in the problem for the network to learn.
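A minimal TBPTT(k1, k2) sketch, assuming a hypothetical scalar RNN (h_t = tanh(w*h_{t-1} + u*x_t), y_hat_t = v*h_t, squared-error loss). The sequence is fed forward one timestep at a time with the hidden state carried over; every k1 steps, gradients are backpropagated through at most the last k2 timesteps and the weights are updated. Following the description above, the sketch accumulates the errors of every timestep inside the k2 window; other formulations count only the most recent errors. The name `tbptt_train` and the example data are assumptions.

```python
import math


def tbptt_train(params, xs, ys, k1, k2, lr=0.05):
    """TBPTT(k1, k2) sketch for a hypothetical scalar RNN; returns (params, n_updates)."""
    w, u, v = params
    hs = [0.0]  # hs[t + 1] is the hidden state after input t
    n_updates = 0
    for t, x in enumerate(xs):
        # 1. Forward one timestep; the hidden state carries over between updates.
        hs.append(math.tanh(w * hs[-1] + u * x))
        if (t + 1) % k1 != 0:
            continue
        # 2. Every k1 steps, unroll back over at most the last k2 timesteps
        #    and accumulate the gradient of their squared errors.
        #    (Stored states may come from older weights, as in online TBPTT.)
        start = max(0, t + 1 - k2)
        dw = du = dv = 0.0
        dh_next = 0.0
        for s in range(t, start - 1, -1):
            h, h_prev = hs[s + 1], hs[s]
            err = v * h - ys[s]
            dv += err * h
            da = (err * v + dh_next) * (1.0 - h * h)
            dw += da * h_prev
            du += da * xs[s]
            dh_next = da * w  # gradient into h_{s-1}; discarded past the window
        # 3. Update the shared weights, then keep processing the sequence.
        w, u, v = w - lr * dw, u - lr * du, v - lr * dv
        n_updates += 1
    return [w, u, v], n_updates


xs = [math.sin(0.3 * t) for t in range(60)]
ys = [math.sin(0.3 * (t + 1)) for t in range(60)]  # hypothetical task: predict the next value
params, n = tbptt_train([0.1, 0.5, 0.5], xs, ys, k1=5, k2=10)
print(n)  # prints 12: one update every k1 = 5 steps over 60 steps
```

Setting k1 and k2 both equal to the sequence length reduces this to a single full BPTT update, while a small k2 cuts each backward pass short, making updates cheaper at the cost of ignoring longer-range dependencies.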

Backpropagation in 5 Minutes

Video Source: Siraj Raval
