Many posts have gone into detail discussing the forward pass of LSTMs (for example, the very informative post here). However, relatively few go through backpropagation, and numerical examples are even more rare. I did manage to find some good sources, and Alex Graves’ thesis was a big help, but after I answered this datascience post on LSTMs, I thought it would be worth delving into some more details.
Other blogs discuss RNNs, so I won’t get into them. Instead, I’ll focus on:
- The forward pass: how information travels through an LSTM
- The backward pass: how gradient information travels backwards through the LSTM
Let’s take a very simple example and work through the forward pass and the backward pass. I will assume that you have seen backpropagation before. If you haven’t, you may want to start here.
LSTM stands for Long Short-Term Memory. It was conceived by Hochreiter and Schmidhuber in 1997 and has been improved on since by many others. The purpose of an LSTM is time series modelling: if you have an input sequence, you may want to map it to an output sequence, a scalar value, or a class. LSTMs can help you do that. The entire LSTM model (all of the gates and the memory cell) is referred to as the LSTM cell. The basic components are an input gate, a forget gate (added after the original LSTM), an output gate, and a memory cell. Broadly:
- The input gate allows new information to flow into the network. It has parameters $W_i$, where $i$ stands for input.
- The memory cell preserves the hidden units’ information across time steps. It has parameters $W_c$, where $c$ stands for cell.
- The forget gate allows information which is no longer pertinent to be discarded. It has parameters $W_f$, where $f$ stands for forget.
- The output gate controls what information will be output to the screen and what will be propagated forward as part of the new hidden state. It has parameters $W_o$, where $o$ stands for output.
Interestingly, all of the weights have the same dimension.
This is an LSTM with two cells:
When I was first looking at LSTMs, the presentations that I saw made it difficult to think of them as feed-forward. While this diagram is a bit unconventional, I find it helpful because it illustrates that we are dealing with a special kind of recurrent neural network (which at its core is just a feed-forward neural network replicated over time).
Let’s break that image down a bit. Here is the first cell:
Green circles indicate input. In each cell, we have the current input from the time series, $x_t$, and input from the previous time step, $h_{t-1}$. The last operation in the cell is to calculate the hidden state for the next cell, $h_t$, which is at once part of the output of the current cell and the input of the next cell.
Red circles indicate the memory cell. One of the main differences between vanilla RNNs and LSTMs is the addition of the memory cell. Whereas RNNs have only the hidden state to maintain memory from previous time steps, LSTMs have the hidden state as well as this additional memory cell. This helps with the task of learning long-term dependencies, something that RNNs have struggled with. In the first cell, the memory coming from the previous time step, $c_0$, is set to 0 (although there is some recent work on initialization strategies).
Orange circles are gates. In real life, a gate can be partially open, fully open, or closed. The same idea applies here. The gates control the flow of information.
The line down the centre of some of the circles indicates that the circle (the neuron) has some net input and an activation function. We can imagine the net input coming in on the left hemisphere, and undergoing a nonlinear transformation (the activation function) on the right hemisphere. For example, the orange neuron takes a linear combination as input and outputs an activation. We’ll get into that more when we discuss the input gate. Naturally, the initial input doesn’t have a net input, so the circles for $x_1$ and $h_0$ have no dividing lines.
Here is the second cell:
This is not that much different from the first cell. The hidden state from the previous cell is $h_1$, which was the result of a calculation, so this neuron has a dividing line. All of the connections are the same. The hidden state $h_2$ is not input for the next cell because there is no next cell. We go directly from computing the hidden state to computing the output of the LSTM network. We’ll get to that.
A detailed walk-through: forward
Let’s start with a simple example with one-dimensional input and one-dimensional output.
Let’s focus on the first cell for now:
Suppose we have a scalar-valued input sequence $x = (x_1, x_2) = (0.1, 0.2)$. In English, this means the input at the beginning of the sequence is 0.1, and the input at the next time step is 0.2. Yes, this isn’t much of a sequence, but to illustrate the computation it’ll do fine. We’ll assume that we initialized our weights and biases to have the following values:
This formulation will come in handy later for backpropagation, but you can see that each row of the matrix has all of the parameters needed for one of the gates. The last row is the linear transformation associated with the output (we’ll get to that).
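To make the row-per-gate layout concrete, here is a minimal sketch in plain Python. The numbers in `W` are hypothetical placeholders, not the values used in this walk-through, and the column order (weight on $x_t$, weight on $h_{t-1}$, bias) and row order (i, c, f, o) are assumptions for illustration. The output row is left out because it only acts on the final hidden state.

```python
# Hypothetical parameters for illustration only: one row per gate, columns
# ordered as [weight on x_t, weight on h_{t-1}, bias].
GATES = ("i", "c", "f", "o")
W = [
    [0.4, 0.3, 0.1],   # input gate
    [0.2, 0.6, 0.05],  # proposal for the memory cell
    [0.5, 0.1, 0.3],   # forget gate
    [0.3, 0.2, 0.15],  # output gate
]


def net_inputs(W, x_t, h_prev):
    """Net input of every gate: each row of W dotted with z_t = [x_t, h_prev, 1]."""
    z_t = [x_t, h_prev, 1.0]
    return {g: sum(w * z for w, z in zip(row, z_t)) for g, row in zip(GATES, W)}


# First cell: x_1 = 0.1 and an initial hidden value of 0.
nets = net_inputs(W, x_t=0.1, h_prev=0.0)
```

Keeping each gate's parameters in a single row is what later lets the backward pass treat all four gates with one matrix product.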
The input gate
It seems fashionable to start at the forget gate, but anytime I walk through a model I like to start with the input.
Our input is $x_1$, the initial sequence input; $h_0$, the initial hidden value; and $c_0$, the initial memory value. This translates to $x_1 = 0.1$, $h_0 = 0$, and $c_0 = 0$. Above, I haven’t even bothered to include a column for a multiplication by the initial memory state (I would need, for example, a weight multiplying $c_0$ in the input gate’s row), since this is rarely anything but 0. People usually don’t bother considering it as part of the input, and I won’t from this point on.
The image associated with the input gate is:
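And the associated equations for the first part are:
$$\mathrm{net}_{i,1} = w_{i,x}\,x_1 + w_{i,h}\,h_0 + b_i, \qquad i_1 = \sigma(\mathrm{net}_{i,1})$$
I’ve used ‘net’ to mean the net input to the gate. We take a linear transformation of the input values. Another way to present the linear transformation (using $T$ for transpose) is
$$\mathrm{net}_{i,1} = W_i \begin{bmatrix} x_1 & h_0 & 1 \end{bmatrix}^T, \qquad W_i = \begin{bmatrix} w_{i,x} & w_{i,h} & b_i \end{bmatrix},$$
as done on that first blog I linked to.
The full computation is:
$$i_1 = \sigma(\mathrm{net}_{i,1}) = \frac{1}{1 + e^{-\mathrm{net}_{i,1}}} \approx 0.515$$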
This value can be interpreted as the probability that we will allow the new information to enter the memory cell.
The usual practice is to keep that value 0.515, that is, to keep the gate partially open. Alternatively, we could make a decision as to whether the information will go forward. That is, we could generate a value $u$ uniformly at random between 0 and 1 and, if $u < i_1$, then allow the information through by opening the input gate completely. This is referred to as making a stochastic decision. Depending on the decision, the gate would open (value 1) or close (value 0). For the purposes of this example, let’s assume the value is 1 to make the arithmetic easy.
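Both options can be sketched in a few lines of Python. This is a minimal illustration, not the post's code: the net input of 0.06 is an assumed value, chosen only because $\sigma(0.06) \approx 0.515$, and `stochastic_gate` is a hypothetical helper that samples the open/closed decision with probability equal to the gate value.

```python
import math
import random


def sigmoid(a):
    """Logistic sigmoid, mapping any real net input into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-a))


def stochastic_gate(p, rng):
    """Open the gate fully (1) with probability p, otherwise close it (0)."""
    return 1 if rng.random() < p else 0


i_soft = sigmoid(0.06)   # deterministic gate value, about 0.515
rng = random.Random(0)   # seeded so the sampling is reproducible
samples = [stochastic_gate(i_soft, rng) for _ in range(10_000)]
open_fraction = sum(samples) / len(samples)  # close to i_soft on average
```

Over many draws the gate is open about 51.5% of the time, so the stochastic version matches the soft gate only in expectation.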
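The second part of the input gate is related to the memory cell. It creates a proposal for the inclusion of the new information:
$$\mathrm{net}_{c,1} = W_c \begin{bmatrix} x_1 & h_0 & 1 \end{bmatrix}^T$$
The full computation is:
$$\tilde{c}_1 = \tanh(\mathrm{net}_{c,1})$$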
Note no stochastic decision is made here – this is the quantity associated with the input that we'll pass to the memory cell. We could make a stochastic decision using a tanh function, and that often happens, but not here. Why? Because this is the input signal! We need this part as it is.
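A small sketch of why the proposal is passed along as it is: `tanh` squashes the net input into $(-1, 1)$ but keeps its sign, so the proposal can push the memory cell down as well as up. The net inputs below are arbitrary illustrative values, and `proposal` is a hypothetical helper name.

```python
import math


def proposal(net_c):
    """Candidate memory content: tanh squashes the net input into (-1, 1)."""
    return math.tanh(net_c)


# Unlike a sigmoid gate, the proposal keeps the sign of its net input.
values = [proposal(a) for a in (-2.0, 0.0, 0.07, 2.0)]
```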
We’ll use both of these pieces together later when we update the memory cell.
The forget gate
The point of this gate is to decide what information needs to be removed from the network. For example, if you’re making a grocery list for this week’s groceries based on last week’s, and you bought 2 weeks’ worth of apples last week, then you don’t need to buy them this week. You can remove the apples from your list.
The forget gate looks like this:
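and takes similar input:
$$\mathrm{net}_{f,1} = W_f \begin{bmatrix} x_1 & h_0 & 1 \end{bmatrix}^T$$
And the computation is also similar:
$$f_1 = \sigma(\mathrm{net}_{f,1})$$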
Again, a stochastic decision could be made here as to whether the previous information should be forgotten (value 0) or allowed through (value 1). For the purposes of this example, let’s assume the value is 1.
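A minimal sketch of what the forget gate does to the old memory, assuming a hypothetical previous memory of 0.8 and a net input of 1.5. It also shows something specific to this example: in the first cell $c_0 = 0$, so $f_1$ cannot change anything yet. The helper `forget_step` is an illustrative name.

```python
import math


def sigmoid(a):
    """Logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-a))


def forget_step(c_prev, net_f):
    """Return (f, retained memory), where f = sigmoid(net_f) scales c_prev."""
    f = sigmoid(net_f)
    return f, f * c_prev


# Hypothetical values: a previous memory of 0.8 and a net input of 1.5.
f, retained = forget_step(c_prev=0.8, net_f=1.5)

# In the first cell c_0 = 0, so the forget gate has nothing to remove.
f_first, retained_first = forget_step(c_prev=0.0, net_f=1.5)
```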
The memory cell
This is the best part! We combine the new information from the input gate and remove the information we’re forgetting according to the forget gate.
The picture is:
and the update looks like this:
$$c_1 = f_1 \odot c_0 + i_1 \odot \tilde{c}_1$$
That’s a new symbol! We need an aside.
Aside: Hadamard product
The Hadamard product is an element-wise product. If we have a vector $a = [a_1, \dots, a_n]$ and a vector $b = [b_1, \dots, b_n]$, then the Hadamard product would be $a \odot b = [a_1 b_1, \dots, a_n b_n]$.
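The Hadamard product and the memory update that uses it fit in a short Python sketch (plain lists stand in for vectors; `hadamard` and `update_memory` are illustrative names). The last line reproduces the situation in this walk-through: with $c_0 = 0$ and both gates set to 1, the new memory is just the proposal.

```python
def hadamard(a, b):
    """Element-wise (Hadamard) product of two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return [x * y for x, y in zip(a, b)]


def update_memory(f, c_prev, i, c_tilde):
    """c_t = f_t (.) c_{t-1} + i_t (.) c~_t, computed element by element."""
    return [u + v for u, v in zip(hadamard(f, c_prev), hadamard(i, c_tilde))]


example = hadamard([1, 2, 3], [4, 5, 6])  # [4, 10, 18]

# Gates set to 1 and c_0 = 0: c_1 equals the proposal (0.25 is hypothetical).
c_1 = update_memory([1.0], [0.0], [1.0], [0.25])
```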
Now that we’ve updated the memory state (another name for the memory cell), we have to think about what we want to output.
The output gate
In a sequence-to-sequence mapping task, like machine translation or image captioning, we might be interested in outputting a value (to the screen or to a file) for each input we see. Here, though, we have a single scalar output – essentially a regression task.
Even though we’re not interested in sending the value to the screen, we still need to compute the output, because it becomes part of the input to the next LSTM cell.
Here’s the image to think of:
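By now you should be thoroughly bored with these equations:
$$\mathrm{net}_{o,1} = W_o \begin{bmatrix} x_1 & h_0 & 1 \end{bmatrix}^T, \qquad o_1 = \sigma(\mathrm{net}_{o,1})$$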
And we’ll make a stochastic decision as to whether we pass this output along. For the purposes of this example, let’s assume the stochastic decision results in a 1.
The hidden layer (hidden state)
I bet you were wondering when we’d get to this. The hidden layer is separate from the memory cell, but closely related. I like to think of it as the part of the memory cell that we want to ensure persists. Here’s how we do it:
$$h_1 = o_1 \odot \tanh(c_1)$$
(yes, I’m rounding. I’ve been doing that a lot.)
See what we did there? The output gate decides whether the signal from the memory cell gets sent forward as part of the input to the next LSTM cell.
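As a minimal sketch (hypothetical net input and memory value), here is the output gate deciding how much of the squashed memory leaves the cell. A soft gate lets only part of $\tanh(c_t)$ through. When the stochastic decision comes out as 1, as assumed in the text, all of it passes.

```python
import math


def sigmoid(a):
    """Logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-a))


def output_and_hidden(net_o, c_t):
    """Return (o_t, h_t) with o_t = sigmoid(net_o) and h_t = o_t * tanh(c_t)."""
    o_t = sigmoid(net_o)
    return o_t, o_t * math.tanh(c_t)


# Hypothetical net input and memory value.
o_soft, h_soft = output_and_hidden(net_o=0.18, c_t=0.5)

# Gate decision fixed to 1: the whole squashed memory is passed on.
h_open = 1 * math.tanh(0.5)
```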
The second LSTM cell
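We’ll assume the weights are shared across LSTM cells. The equations are exactly the same, but now we use $x_2$ where before we used $x_1$, $h_1$ where we used $h_0$, and $c_1$ where we used $c_0$. Let’s say we have input $x_2 = 0.2$ and target scalar value $y$. Here are all of the familiar computations written out (for the numbers, assume all stochastic gate decisions result in the signal being propagated forward, and no information is forgotten):
$$\mathrm{net}_{g,2} = W_g \begin{bmatrix} x_2 & h_1 & 1 \end{bmatrix}^T \quad \text{for } g \in \{i, c, f, o\}$$
$$i_2 = \sigma(\mathrm{net}_{i,2}), \quad \tilde{c}_2 = \tanh(\mathrm{net}_{c,2}), \quad f_2 = \sigma(\mathrm{net}_{f,2}), \quad o_2 = \sigma(\mathrm{net}_{o,2})$$
$$c_2 = f_2 \odot c_1 + i_2 \odot \tilde{c}_2, \qquad h_2 = o_2 \odot \tanh(c_2)$$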
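Putting the two cells together, here is a minimal forward-pass sketch in plain Python. It assumes hypothetical shared parameters (not the post's values), soft gates (keeping the sigmoid values rather than sampling), and $h_0 = c_0 = 0$. The names `lstm_cell` and `forward` are illustrative.

```python
import math


def sigmoid(a):
    """Logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-a))


# Hypothetical shared parameters: per gate, (w_x, w_h, b).
PARAMS = {
    "i": (0.4, 0.3, 0.1),
    "c": (0.2, 0.6, 0.05),
    "f": (0.5, 0.1, 0.3),
    "o": (0.3, 0.2, 0.15),
}


def lstm_cell(x_t, h_prev, c_prev, p=PARAMS):
    """One scalar LSTM cell with soft (non-stochastic) gates."""
    net = {g: w_x * x_t + w_h * h_prev + b for g, (w_x, w_h, b) in p.items()}
    i_t = sigmoid(net["i"])
    c_tilde = math.tanh(net["c"])
    f_t = sigmoid(net["f"])
    o_t = sigmoid(net["o"])
    c_t = f_t * c_prev + i_t * c_tilde
    h_t = o_t * math.tanh(c_t)
    return h_t, c_t


def forward(xs, h0=0.0, c0=0.0):
    """Run the same cell over the sequence, threading h and c between steps."""
    h, c = h0, c0
    states = []
    for x in xs:
        h, c = lstm_cell(x, h, c)
        states.append((h, c))
    return states


states = forward([0.1, 0.2])  # [(h_1, c_1), (h_2, c_2)]
```

Because the same `PARAMS` are used at every step, the second cell is literally the first cell called again with $(x_2, h_1, c_1)$.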
Okay, now we’ve reached the end of our sequence. It’s time to figure out the final output, $\hat{y}$, that we’re going to use for the error calculation. This depends exclusively on the hidden state (remember the memory cell is input to the hidden state, so we only need to use the hidden state to take into account the entire memory of our LSTM). Using the output row of the weight matrix, with weight $w_y$ and bias $b_y$:
$$\hat{y} = w_y\, h_2 + b_y$$
That value is our final output.
But wait, weren’t we aiming for 0.08? We need to make some changes to our model. To do that, we’ll calculate the error and backpropagate the signal to update our weights.
The error (mean squared error, or MSE, but with only one value so ‘mean’ is irrelevant):
$$E = (\hat{y} - y)^2$$
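The output layer and the error are a couple of lines each. This sketch uses a hypothetical final hidden state and output parameters, and takes the target 0.08 mentioned above. `predict` and `squared_error` are illustrative names.

```python
def predict(h_last, w_y, b_y):
    """Linear output layer: y_hat = w_y * h_2 + b_y (scalar case)."""
    return w_y * h_last + b_y


def squared_error(y_hat, y):
    """Squared error for a single prediction; with one value, this is the MSE."""
    return (y_hat - y) ** 2


# Hypothetical final hidden state and output parameters.
y_hat = predict(h_last=0.3, w_y=1.5, b_y=0.05)
E = squared_error(y_hat, 0.08)
```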
A detailed walkthrough: backpropagation through time
This is going to get messy. There is a whole bunch of chain rule going on here. There are a few paths the derivative can take through the network, which is a great thing, actually, because it means the gradient is less prone to vanishing, but we’ll get to that. First, let’s figure out how to send our error signal back. We need the derivative with respect to our weight matrix, but to get there we have to go through all of our model components. Many thanks to Alex Graves for a beautiful thesis and to the author of this very helpful blog for filling in the gaps in my knowledge here.
The first step is to calculate the gradient of the error with respect to the output:
$$\frac{\partial E}{\partial \hat{y}} = 2(\hat{y} - y)$$
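We can see the dependency on the hidden state by expanding $\hat{y}$:
$$E = (w_y h_2 + b_y - y)^2 \quad\Rightarrow\quad \frac{\partial E}{\partial h_2} = 2(\hat{y} - y)\, w_y$$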
I’ll use $\delta v$ to refer to the partial derivative of the error with respect to a quantity $v$, that is, $\delta v = \partial E / \partial v$, similar to that blog post.
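Now we need to differentiate through the hidden state to get to the next part. Alternatively, we could differentiate through the memory cell directly; that’s the second path the gradient can take. Actually, going directly through the memory cell saves a step, as shown here (pages 12 – 13). Since $h_t = o_t \odot \tanh(c_t)$ and $c_{t+1} = f_{t+1} \odot c_t + i_{t+1} \odot \tilde{c}_{t+1}$, the two paths combine as:
$$\delta o_t = \delta h_t \odot \tanh(c_t), \qquad \delta c_t = \delta h_t \odot o_t \odot \left(1 - \tanh^2(c_t)\right) + \delta c_{t+1} \odot f_{t+1}$$
At the last time step ($t = 2$) there is no later cell, so the second term is zero.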
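A quick way to convince yourself of these formulas is to compare them with finite differences. This is a minimal sketch at the last time step only, using hypothetical values for $w_y$, $b_y$, $o_2$, and $c_2$ (the target 0.08 is from the text) and the squared error $E = (\hat{y} - y)^2$. `final_step_grads` is an illustrative name.

```python
import math


def final_step_grads(y_hat, y, w_y, o_2, c_2):
    """Gradients at t = 2, where no later cell sends anything back."""
    d_yhat = 2.0 * (y_hat - y)                       # E = (y_hat - y)^2
    d_h2 = d_yhat * w_y                              # y_hat = w_y h_2 + b_y
    d_o2 = d_h2 * math.tanh(c_2)                     # h_2 = o_2 tanh(c_2)
    d_c2 = d_h2 * o_2 * (1.0 - math.tanh(c_2) ** 2)  # tanh' = 1 - tanh^2
    return d_h2, d_o2, d_c2


# Hypothetical values, used only for the finite-difference comparison.
w_y, b_y, y, o_2, c_2 = 1.5, 0.05, 0.08, 0.55, 0.4


def error(o, c):
    """E as a function of the output gate and memory cell at t = 2."""
    return (w_y * o * math.tanh(c) + b_y - y) ** 2


y_hat = w_y * o_2 * math.tanh(c_2) + b_y
d_h2, d_o2, d_c2 = final_step_grads(y_hat, y, w_y, o_2, c_2)
eps = 1e-6
num_d_o2 = (error(o_2 + eps, c_2) - error(o_2 - eps, c_2)) / (2 * eps)
num_d_c2 = (error(o_2, c_2 + eps) - error(o_2, c_2 - eps)) / (2 * eps)
```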
Now we need to go through the input and forget gates.
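The input gate:
$$\delta i_t = \delta c_t \odot \tilde{c}_t$$
The forget gate:
$$\delta f_t = \delta c_t \odot c_{t-1}$$
The proposal for the new memory state:
$$\delta \tilde{c}_t = \delta c_t \odot i_t$$
The previous cell state (the direct path through the memory cell):
$$\delta c_{t-1}\big|_{\text{direct}} = \delta c_t \odot f_t$$
The input to the proposal:
$$\delta\,\mathrm{net}_{c,t} = \delta \tilde{c}_t \odot \left(1 - \tilde{c}_t^2\right)$$
because $\tanh'(a) = 1 - \tanh^2(a)$.
The net input to the input gate:
$$\delta\,\mathrm{net}_{i,t} = \delta i_t \odot i_t \odot (1 - i_t)$$
because of the derivative of the sigmoid function, $\sigma'(a) = \sigma(a)\,(1 - \sigma(a))$.
The net input to the forget gate:
$$\delta\,\mathrm{net}_{f,t} = \delta f_t \odot f_t \odot (1 - f_t)$$
The net input to the output gate:
$$\delta\,\mathrm{net}_{o,t} = \delta o_t \odot o_t \odot (1 - o_t)$$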
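The steps above, from $\delta c_t$ and $\delta o_t$ down to the four net inputs, can be sketched as one scalar Python function. As a check, it uses hypothetical values and the artificial choice $E = c_t$, so that $\delta c_t = 1$ and $\delta o_t = 0$. Each gate gradient then has to match a finite difference of $c_t$. `net_input_grads` is an illustrative name.

```python
import math


def sigmoid(a):
    """Logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-a))


def net_input_grads(d_c, d_o, i_t, c_tilde, f_t, o_t, c_prev):
    """Push dE/dc_t and dE/do_t back to the four net inputs at step t.

    Also returns the direct contribution d_c * f_t to dE/dc_{t-1}.
    """
    d_i = d_c * c_tilde        # from c_t = f_t c_{t-1} + i_t c~_t
    d_f = d_c * c_prev
    d_ctilde = d_c * i_t
    d_net = {
        "i": d_i * i_t * (1.0 - i_t),          # sigma' = sigma (1 - sigma)
        "c": d_ctilde * (1.0 - c_tilde ** 2),  # tanh' = 1 - tanh^2
        "f": d_f * f_t * (1.0 - f_t),
        "o": d_o * o_t * (1.0 - o_t),
    }
    return d_net, d_c * f_t


# Hypothetical net inputs and previous memory.
net_i, net_c, net_f, c_prev = 0.14, 0.07, 0.35, 0.2
i_t, c_tilde, f_t = sigmoid(net_i), math.tanh(net_c), sigmoid(net_f)
d_net, d_c_prev = net_input_grads(1.0, 0.0, i_t, c_tilde, f_t, sigmoid(0.18), c_prev)


def memory(a_i, a_c, a_f):
    """c_t as a function of the three net inputs that feed it."""
    return sigmoid(a_f) * c_prev + sigmoid(a_i) * math.tanh(a_c)


eps = 1e-6
num_i = (memory(net_i + eps, net_c, net_f) - memory(net_i - eps, net_c, net_f)) / (2 * eps)
num_c = (memory(net_i, net_c + eps, net_f) - memory(net_i, net_c - eps, net_f)) / (2 * eps)
num_f = (memory(net_i, net_c, net_f + eps) - memory(net_i, net_c, net_f - eps)) / (2 * eps)
```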
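Now we need to recall our definitions from way up top:
$$\mathrm{net}_{g,t} = W_g \begin{bmatrix} x_t & h_{t-1} & 1 \end{bmatrix}^T, \qquad W_g = \begin{bmatrix} w_{g,x} & w_{g,h} & b_g \end{bmatrix}, \qquad g \in \{i, c, f, o\}$$
And let $z_t$ be the total input at time $t$: $z_t = \begin{bmatrix} x_t & h_{t-1} & 1 \end{bmatrix}^T$.
Then we can define $W$ as the matrix whose rows are $W_i, W_c, W_f, W_o$ (the output row is handled separately below), so that $\mathrm{net}_t = W z_t$, and collect all of our ‘lowest’ derivatives:
$$\delta\,\mathrm{net}_t = \begin{bmatrix} \delta\,\mathrm{net}_{i,t} & \delta\,\mathrm{net}_{c,t} & \delta\,\mathrm{net}_{f,t} & \delta\,\mathrm{net}_{o,t} \end{bmatrix}^T$$
Then our last derivatives are:
$$\delta W = \sum_{t=1}^{2} \delta\,\mathrm{net}_t \, z_t^T, \qquad \delta h_{t-1} = W_{:,h}^T\, \delta\,\mathrm{net}_t, \qquad \delta w_y = \delta\hat{y}\, h_2, \qquad \delta b_y = \delta\hat{y}$$
where $W_{:,h}$ is the column of $W$ that multiplies $h_{t-1}$, and $\delta\hat{y} = \partial E / \partial \hat{y}$.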
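To tie everything together, here is a minimal end-to-end sketch of backpropagation through time for this two-step, one-dimensional LSTM, in plain Python. The assumptions are the same as elsewhere: hypothetical parameters (not the post's values), soft gates, $h_0 = c_0 = 0$, and $E = (\hat{y} - y)^2$. The input sequence and the target 0.08 follow the text. The function names are illustrative. `numeric_grad` checks each analytic gradient against a central finite difference.

```python
import math

GATES = ("i", "c", "f", "o")


def sigmoid(a):
    """Logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-a))


def forward(params, xs, w_y, b_y):
    """Forward pass with shared weights; caches what the backward pass needs."""
    h, c = 0.0, 0.0                          # h_0 = c_0 = 0
    cache = []
    for x in xs:
        z = (x, h, 1.0)                      # z_t = [x_t, h_{t-1}, 1]
        net = {g: sum(w * v for w, v in zip(params[g], z)) for g in GATES}
        i, ct = sigmoid(net["i"]), math.tanh(net["c"])
        f, o = sigmoid(net["f"]), sigmoid(net["o"])
        c_prev, c = c, f * c + i * ct        # c_t = f_t c_{t-1} + i_t c~_t
        h = o * math.tanh(c)                 # h_t = o_t tanh(c_t)
        cache.append({"z": z, "i": i, "ct": ct, "f": f, "o": o,
                      "c_prev": c_prev, "c": c})
    return w_y * h + b_y, h, cache


def loss(params, xs, y, w_y, b_y):
    """E = (y_hat - y)^2."""
    y_hat, _, _ = forward(params, xs, w_y, b_y)
    return (y_hat - y) ** 2


def backward(params, xs, y, w_y, b_y):
    """Backpropagation through time for the scalar LSTM above."""
    y_hat, h_last, cache = forward(params, xs, w_y, b_y)
    d_yhat = 2.0 * (y_hat - y)
    d_W = {g: [0.0, 0.0, 0.0] for g in GATES}
    d_w_y, d_b_y = d_yhat * h_last, d_yhat
    d_h = d_yhat * w_y       # only the last hidden state feeds the output
    d_c_next = 0.0           # direct memory path from step t + 1 (none at the end)
    for s in reversed(cache):
        tanh_c = math.tanh(s["c"])
        d_o = d_h * tanh_c
        d_c = d_h * s["o"] * (1.0 - tanh_c ** 2) + d_c_next
        d_net = {
            "i": d_c * s["ct"] * s["i"] * (1.0 - s["i"]),
            "c": d_c * s["i"] * (1.0 - s["ct"] ** 2),
            "f": d_c * s["c_prev"] * s["f"] * (1.0 - s["f"]),
            "o": d_o * s["o"] * (1.0 - s["o"]),
        }
        for g in GATES:                      # d_W += d_net_t z_t^T
            for k in range(3):
                d_W[g][k] += d_net[g] * s["z"][k]
        d_h = sum(d_net[g] * params[g][1] for g in GATES)  # h-column of W
        d_c_next = d_c * s["f"]
    return d_W, d_w_y, d_b_y


# Hypothetical parameters; the input sequence and target follow the text.
PARAMS = {"i": [0.4, 0.3, 0.1], "c": [0.2, 0.6, 0.05],
          "f": [0.5, 0.1, 0.3], "o": [0.3, 0.2, 0.15]}
xs, y, w_y, b_y = [0.1, 0.2], 0.08, 1.5, 0.05
d_W, d_w_y, d_b_y = backward(PARAMS, xs, y, w_y, b_y)


def numeric_grad(g, k, eps=1e-6):
    """Central finite difference of the loss with respect to PARAMS[g][k]."""
    plus = {key: list(row) for key, row in PARAMS.items()}
    minus = {key: list(row) for key, row in PARAMS.items()}
    plus[g][k] += eps
    minus[g][k] -= eps
    return (loss(plus, xs, y, w_y, b_y) - loss(minus, xs, y, w_y, b_y)) / (2 * eps)
```

Since $h_0 = 0$ and $c_0 = 0$, the first step contributes nothing to the $h$-column gradients or to the forget-gate gradient, which the finite-difference check confirms.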
And there you have it! Backpropagation through an LSTM.
I used the typical activations here – namely, sigmoids and tanh – but this can also be done with ReLUs. I leave that to a future post.