Recurrent Networks
Lesson 5 ⏱ 18 min

LSTMs: cell state and gates


LSTMs: Cell State, Gates, and the Gradient Highway

Introduces the cell state as a gradient highway, derives all four LSTM gates from first principles, walks through a complete numerical forward pass, and explains why additive updates prevent vanishing gradients.



Quick refresher

Sigmoid function

The sigmoid function σ(x) = 1/(1+e^{-x}) squashes any real number to the range (0,1). It's used in gates because its output can be interpreted as a 'proportion': 0 means 'ignore completely', 1 means 'pass through fully', values in between mean 'partially.'

Example

σ(0) = 0.5 (halfway).

σ(2) ≈ 0.88 (mostly open).

σ(-2) ≈ 0.12 (mostly closed).

σ(5) ≈ 0.993 (almost fully open).
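If you want to check these numbers yourself, a few lines of Python (just a quick sketch, not part of any library) reproduce them:

import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

for x in [0, 2, -2, 5]:
    print(f"sigma({x}) = {sigmoid(x):.3f}")
# sigma(0) = 0.500, sigma(2) = 0.881, sigma(-2) = 0.119, sigma(5) = 0.993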

The vanishing gradient problem comes from multiplying many small Jacobians together over long sequences. The LSTM's solution is architectural: introduce an additional state variable that updates additively rather than multiplicatively, creating a "gradient highway" that lets information (and gradients) flow over hundreds of steps without shrinking.

LSTMs were the dominant architecture for language modeling, speech recognition, and machine translation from roughly 2014 to 2018. They powered the first wave of production neural NLP systems at Google, Baidu, and Apple. Understanding the LSTM cell is essential for reading that entire body of work.

The Key Idea: Additive Updates

In a vanilla RNN, the hidden state updates multiplicatively:

h_t = \tanh(W_h h_{t-1} + W_x x_t + b)

The gradient \partial h_t / \partial h_{t-1} involves multiplying by W_h^T and by tanh'. After T such multiplications, gradients vanish.

The LSTM introduces a cell state c_t that updates additively:

c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t

where:

c_t — cell state at time t
f_t — forget gate: values in (0,1), controls how much of c_{t-1} to retain
c_{t-1} — previous cell state
i_t — input gate: values in (0,1), controls how much new info to add
\tilde{c}_t — candidate cell state: new information to potentially add

The gradient of c_t with respect to c_{t-1} along the cell-state path is just \text{diag}(f_t) — a diagonal matrix of forget gate values. When f_t \approx 1 (remember everything), this is near-identity: gradients flow backward through c with minimal change. No matrix multiplication, no tanh derivative. This is the gradient highway.
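A minimal PyTorch autograd check makes this concrete; the gate values below are made up purely for illustration:

import torch

c_prev = torch.tensor([0.5, -1.0], requires_grad=True)
f = torch.tensor([0.9, 0.2])           # forget gate values (illustrative)
i = torch.tensor([0.3, 0.7])           # input gate values (illustrative)
c_tilde = torch.tensor([0.6, -0.4])    # candidate cell state (illustrative)

c = f * c_prev + i * c_tilde           # additive cell-state update
c.sum().backward()
print(c_prev.grad)                     # tensor([0.9000, 0.2000]) — exactly f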

The Four Gate Equations

The LSTM has four components, all computed from the concatenation [h_{t-1}; x_t] (the previous hidden state stacked with the current input):

Forget gate — what fraction of the cell state to erase:

f_t = \sigma(W_f \cdot [h_{t-1}; x_t] + b_f)

f_t — forget gate vector, values in (0,1)
W_f — forget gate weight matrix
b_f — forget gate bias

Input gate — what fraction of the candidate to write:

i_t = \sigma(W_i \cdot [h_{t-1}; x_t] + b_i)

i_t — input gate vector

Candidate cell state — what new information to potentially store:

\tilde{c}_t = \tanh(W_c \cdot [h_{t-1}; x_t] + b_c)

\tilde{c}_t — candidate values for the cell state update

Output gate — what to expose from the cell state:

o_t = \sigma(W_o \cdot [h_{t-1}; x_t] + b_o)

o_t — output gate vector

Then the updates:

c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t
h_t = o_t \odot \tanh(c_t)

h_t — LSTM hidden state, a filtered view of the cell state
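To make the four equations concrete, here is a minimal NumPy sketch of a single LSTM step; the function name lstm_step and the per-gate weight names are illustrative, not a library API:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, W_f, W_i, W_c, W_o, b_f, b_i, b_c, b_o):
    z = np.concatenate([h_prev, x])       # [h_{t-1}; x_t], shape (n+d,)
    f = sigmoid(W_f @ z + b_f)            # forget gate, values in (0,1)
    i = sigmoid(W_i @ z + b_i)            # input gate
    c_tilde = np.tanh(W_c @ z + b_c)      # candidate cell state
    o = sigmoid(W_o @ z + b_o)            # output gate
    c = f * c_prev + i * c_tilde          # additive cell-state update
    h = o * np.tanh(c)                    # hidden state: filtered view of c
    return h, c

# Example with the lesson's sizes (n=2, d=1) and small random weights:
rng = np.random.default_rng(0)
Ws = [rng.normal(size=(2, 3)) * 0.1 for _ in range(4)]
bs = [np.zeros(2) for _ in range(4)]
h1, c1 = lstm_step(np.array([1.0]), np.zeros(2), np.zeros(2), *Ws, *bs)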

Parameter Count

Each gate has its own weight matrix applied to [h_{t-1}; x_t]. If the hidden size is n and the input size is d, each weight matrix is n \times (n+d). With 4 matrices:

\text{Total params} = 4 \times n(n+d) + 4n \text{ (biases)} = 4n(n+d+1)

For n = 128, d = 50: 4 \times 128 \times 179 = 91{,}648 — 4× the parameter count of a vanilla RNN.
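You can check the arithmetic in a couple of lines; the comparison against nn.LSTM is a sketch, and its count comes out slightly higher because PyTorch stores two bias vectors per gate (bias_ih and bias_hh):

import torch.nn as nn

n, d = 128, 50
print(4 * n * (n + d + 1))                        # 91648 — the formula above

lstm = nn.LSTM(input_size=d, hidden_size=n, num_layers=1)
print(sum(p.numel() for p in lstm.parameters()))  # 92160 — extra 4n from the second bias vector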

Worked Forward Pass

Let's trace one LSTM step. Use n = 2, d = 1, starting from h_0 = [0, 0]^T, c_0 = [0, 0]^T.

Current input: x_1 = 1.0. Concatenation: [h_0; x_1] = [0, 0, 1]^T.

Forget gate (suppose W_f \cdot [h_0; x_1] + b_f = [0.8, 0.2]):

f_1 = \sigma([0.8, 0.2]) = [0.690, 0.550]

Input gate (suppose pre-activation = [-0.5, 1.2]):

i_1 = \sigma([-0.5, 1.2]) = [0.378, 0.769]

Candidate (suppose pre-activation = [0.6, -0.3]):

\tilde{c}_1 = \tanh([0.6, -0.3]) = [0.537, -0.291]

Cell state update:

c_1 = f_1 \odot c_0 + i_1 \odot \tilde{c}_1 = [0.690, 0.550] \odot [0, 0] + [0.378, 0.769] \odot [0.537, -0.291]
c_1 = [0, 0] + [0.203, -0.224] = [0.203, -0.224]

Output gate (suppose pre-activation = [0.3, -0.7]):

o_1 = \sigma([0.3, -0.7]) = [0.574, 0.332]

Hidden state:

h_1 = o_1 \odot \tanh(c_1) = [0.574, 0.332] \odot \tanh([0.203, -0.224])
= [0.574, 0.332] \odot [0.200, -0.220] = [0.115, -0.073]

The cell state c_1 and hidden state h_1 carry forward to the next step.
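The numbers above are easy to reproduce with NumPy by plugging in the assumed pre-activations directly, rather than computing them from weights:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

c0 = np.zeros(2)
f1 = sigmoid(np.array([0.8, 0.2]))          # [0.690, 0.550]
i1 = sigmoid(np.array([-0.5, 1.2]))         # [0.378, 0.769]
c_tilde1 = np.tanh(np.array([0.6, -0.3]))   # [0.537, -0.291]
o1 = sigmoid(np.array([0.3, -0.7]))         # [0.574, 0.332]

c1 = f1 * c0 + i1 * c_tilde1                # [0.203, -0.224]
h1 = o1 * np.tanh(c1)                       # [0.115, -0.073]
print(np.round(c1, 3), np.round(h1, 3))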

The Gradient Highway in Action

During backpropagation, the Jacobian of c_t with respect to c_{t-1} along the cell-state path is:

\frac{\partial c_t}{\partial c_{t-1}} = \text{diag}(f_t)

As long as f_t is close to 1, this is close to the identity matrix — gradients pass through with minimal modification. Over 100 steps:

\frac{\partial c_{100}}{\partial c_1} = \prod_{t=2}^{100} \text{diag}(f_t)

If forget gate values are 0.95: 0.95^{99} \approx 0.006. Still small — but the key difference from vanilla RNNs is that the forget gate learns to be close to 1 when long-range memory is needed. A vanilla RNN's Jacobian depends on W_h and tanh' — it can't specifically decide to preserve certain memory components over long distances.
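The arithmetic behind that comparison, with purely illustrative forget-gate values:

print(0.95 ** 99)     # ≈ 0.006 — even a mildly leaky forget gate decays over 100 steps
print(0.999 ** 99)    # ≈ 0.906 — a gate that learns to stay near 1 preserves the signal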

PyTorch Implementation

import torch
import torch.nn as nn

lstm = nn.LSTM(
    input_size=50,   # d
    hidden_size=128, # n
    num_layers=1,
    batch_first=True
)

x = torch.randn(32, 20, 50)         # [batch, seq_len, input_size]
h0 = torch.zeros(1, 32, 128)        # initial hidden state
c0 = torch.zeros(1, 32, 128)        # initial cell state

output, (hn, cn) = lstm(x, (h0, c0))
# output: [32, 20, 128] — hidden states at every step
# hn, cn: [1, 32, 128]  — final hidden and cell states
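Continuing from the snippet above (single layer, unidirectional), the last time step of output coincides with the final hidden state — a quick sanity check:

print(torch.allclose(output[:, -1, :], hn[0]))   # True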

The LSTM is powerful but adds complexity: 4 matrix multiplications per step instead of 1, plus more hyperparameters to tune. The next lesson covers the GRU — a streamlined version that achieves similar results with fewer gates.
