Reading Step-by-Step: LSTMs and Their Gates
A deep dive into Long Short-Term Memory (LSTM) networks, the architecture that solved the vanishing gradient problem for sequential data using a clever system of gates.
Alright, we know encoders summarise data, and we know basic networks can struggle with learning long patterns due to gradient issues. So, how did engineers first tackle encoding sequences like text, where order and long-range context matter immensely? Enter the Recurrent Neural Network (RNN) family, and specifically their more sophisticated members: Long Short-Term Memory (LSTM) networks and Gated Recurrent Units (GRUs).
Unlike feed-forward networks that process fixed-size inputs independently, RNNs are designed for sequences. They process input one step at a time (e.g., one word after another), maintaining an internal hidden state — think of it as memory — that gets updated at each step and influences the processing of the next step. This recurrence allows them to capture sequential dependencies.
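To make the recurrence concrete, here is a minimal sketch of a single-unit ("scalar") vanilla RNN in plain Python. The weights `w_x`, `w_h`, `b` and the helper names `rnn_step` / `run_rnn` are hypothetical placeholders chosen for illustration, not learned values or a library API. A real RNN uses weight matrices and vector-valued hidden states.

```python
import math

def rnn_step(x_t, h_prev, w_x=0.8, w_h=0.5, b=0.0):
    """One step of a scalar vanilla RNN: h_t = tanh(w_x * x_t + w_h * h_{t-1} + b).
    The default weights are hypothetical, hand-picked values."""
    return math.tanh(w_x * x_t + w_h * h_prev + b)

def run_rnn(sequence, h0=0.0):
    """Process the sequence one element at a time, returning every hidden state."""
    h = h0
    states = []
    for x_t in sequence:      # strictly in order, one step per element
        h = rnn_step(x_t, h)  # the new state depends on the input AND the old state
        states.append(h)
    return states

print(run_rnn([1.0, 0.0, 0.0]))
```

Feeding a single 1.0 followed by zeros shows the hidden state carrying that input forward even when the later inputs are zero, but shrinking at every step, which previews the memory problem.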
However, the simple RNN architecture, while elegant, suffers badly from the vanishing gradient problem we discussed in Module 1. Its memory is too short-term; information from early parts of a sequence tends to get washed out by the time the network reaches the end.
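A rough numeric sketch of why early information washes out. In a scalar tanh RNN with hypothetical weights (`w_h = 0.5`), the chain rule gives dh_t/dh_{t-1} = w_h · (1 − h_t²). The gradient of a late state with respect to the initial state is therefore a product of these per-step factors. Here each factor is at most 0.5, so the product shrinks geometrically. This is a simplified illustration with one unit and fixed weights, not a full backpropagation-through-time implementation. With |w_h| well above 1, the same product can grow instead.

```python
import math

def grad_wrt_first_state(sequence, w_x=0.8, w_h=0.5, h0=0.0):
    """Return d h_t / d h_0 for every step t of a scalar tanh RNN
    (hypothetical weights). Each step contributes a factor w_h * (1 - h_t**2)."""
    h, grad = h0, 1.0
    grads = []
    for x_t in sequence:
        h = math.tanh(w_x * x_t + w_h * h)
        grad *= w_h * (1.0 - h * h)  # local derivative of this step, chained
        grads.append(grad)
    return grads

grads = grad_wrt_first_state([1.0] * 20)
for t in (1, 5, 10, 20):
    print(f"step {t:2d}: dh_t/dh_0 = {grads[t - 1]:.2e}")
```

After 20 steps the gradient is bounded by 0.5^20 (about 1e-6), so a learning signal from the end of the sequence barely reaches the first state.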
LSTMs: Introducing Gates for Smarter Memory
Remember the rant about how simple RNNs forget almost everything? The cure for that terrible memory is a crucial upgrade that turns a basic RNN into an LSTM: gates.
Imagine the simple RNN's memory as a single conveyor belt where everything gets overwritten at each step. LSTMs add a separate conveyor belt, called the cell state, along with sophisticated gates: little control valves that regulate the flow of information onto and off this belt.
The three gates are the heart of the LSTM:
- Forget Gate — Decides what old information to discard from the cell state. A value near 0 means "forget everything," near 1 means "keep everything." Useful when the topic of a conversation completely changes.
- Input Gate — Determines what new information should be stored in the cell state. Works alongside a candidate value layer that proposes new content. Together they decide what gets written to memory.
- Output Gate — Controls what information flows out as the hidden state (the output at this timestep). The cell state might hold a lot; the output gate decides what's actually relevant to expose right now.
These gates use sigmoid functions to make their decisions — outputting values between 0 and 1 that act as "how much to let through." This selective memory management is what allows LSTMs to capture dependencies over much longer sequences than simple RNNs.
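Here is a minimal sketch of the gate arithmetic in plain Python. It assumes a hypothetical three-dimensional cell state and hand-picked gate pre-activations, whereas a trained LSTM computes these from the current input and the previous hidden state. The sigmoid is applied element by element, followed by the cell update C_t = f_t * C_{t-1} + i_t * C~_t.

```python
import math

def sigmoid(z):
    """Squash any real number into (0, 1): the 'how much to let through' knob."""
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical 3-dimensional cell state and gate pre-activations
old_cell = [0.8, -0.5, 0.3]
forget_pre = [6.0, 0.0, -6.0]   # keep / halve / erase
input_pre = [-6.0, -6.0, 6.0]   # write ~nothing / ~nothing / a lot
candidate = [0.9, 0.9, -0.7]    # proposed new content, already in [-1, 1] as after tanh

f = [sigmoid(z) for z in forget_pre]  # forget gate values
i = [sigmoid(z) for z in input_pre]   # input gate values

# C_t = f * C_{t-1} + i * C~_t, computed element by element
new_cell = [fj * cj + ij * gj for fj, cj, ij, gj in zip(f, old_cell, i, candidate)]

for dim, (fj, ij, c) in enumerate(zip(f, i, new_cell)):
    print(f"dim {dim}: forget={fj:.3f} input={ij:.3f} new cell={c:+.3f}")
```

Dimension 0 is kept almost untouched, dimension 1 is roughly halved, and dimension 2 is almost entirely replaced by its candidate value. Each dimension is gated independently.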
Expanding the View — Looking Both Ways
Standard LSTMs are like reading a book one way: they only know what came before. But often, understanding a word requires context from what comes after it.
Bidirectional LSTMs (BiLSTMs) address this by using two separate LSTMs: one reads the sequence forward, the other reads it backward. Their hidden states at each position are then combined (e.g., concatenated). The difference this makes is immediate. Take the ambiguous word "bank" in "I sat on the bank of the river". When a forward-only LSTM reaches "bank", it has not yet seen "river", so it cannot tell a riverbank from a financial bank. The backward LSTM, by contrast, has already read the rest of the sentence.
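Here is a plain-Python sketch of the bidirectional idea. It assumes a single-unit (hidden size 1) LSTM cell with hypothetical hand-picked weights. `lstm_step`, `run_lstm`, `bilstm`, `W_FWD` and `W_BWD` are illustrative names, not a library API. In PyTorch the same structure is available through `nn.LSTM(..., bidirectional=True)`, whose output feature size is twice `hidden_size`.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_step(x, h, c, w):
    """One step of a scalar LSTM cell (hidden size 1).
    `w` maps each gate to a hypothetical (weight_x, weight_h, bias) triple."""
    def pre(gate):
        wx, wh, b = w[gate]
        return wx * x + wh * h + b
    f = sigmoid(pre("f"))          # forget gate
    i = sigmoid(pre("i"))          # input gate
    c_tilde = math.tanh(pre("c"))  # candidate value
    o = sigmoid(pre("o"))          # output gate
    c = f * c + i * c_tilde        # new cell state
    h = o * math.tanh(c)           # new hidden state
    return h, c

def run_lstm(xs, w):
    """Run the cell left to right, returning the hidden state at every position."""
    h, c, hs = 0.0, 0.0, []
    for x in xs:
        h, c = lstm_step(x, h, c, w)
        hs.append(h)
    return hs

def bilstm(xs, w_fwd, w_bwd):
    """Two independent LSTMs: one reads forward, one reads the reversed sequence.
    Backward outputs are flipped back so position k lines up in both, then paired."""
    fwd = run_lstm(xs, w_fwd)
    bwd = run_lstm(xs[::-1], w_bwd)[::-1]
    return [(hf, hb) for hf, hb in zip(fwd, bwd)]  # concatenation of two 1-d states

# Hypothetical, hand-picked weights (a trained model would learn these)
W_FWD = {"f": (0.5, 0.3, 1.0), "i": (0.8, 0.2, 0.0), "c": (1.0, 0.5, 0.0), "o": (0.6, 0.1, 0.0)}
W_BWD = {"f": (0.4, 0.2, 1.0), "i": (0.7, 0.3, 0.0), "c": (0.9, 0.4, 0.0), "o": (0.5, 0.2, 0.0)}

a = bilstm([0.5, -1.0, 2.0], W_FWD, W_BWD)
b = bilstm([0.5, -1.0, -2.0], W_FWD, W_BWD)  # only the LAST token differs
print("position 0, sequence a:", a[0])
print("position 0, sequence b:", b[0])
```

Changing only the last token leaves the forward state at position 0 unchanged but alters the backward state there. That is exactly the future context a forward-only LSTM cannot see.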
For the Curious: The Math, Advanced Topics & Pitfalls
LSTM Equations: The Inner Workings
Here's a peek at the math that makes the gates work. Each equation uses the current input $x_t$, the previous hidden state $h_{t-1}$, the previous cell state $C_{t-1}$, and learned weight matrices $W_f, W_i, W_C, W_o$ and biases $b_f, b_i, b_C, b_o$. The symbol $\odot$ denotes element-wise multiplication.
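For reference, here are the gate equations in standard notation. $[h_{t-1}, x_t]$ denotes the concatenation that the code builds as `combined`, and $\sigma$ is the sigmoid. This is the common formulation with one weight matrix and bias per gate, which is the form the code below implements. Variants such as peephole LSTMs add extra terms.

```latex
\begin{aligned}
f_t &= \sigma\left(W_f\,[h_{t-1}, x_t] + b_f\right) \\
i_t &= \sigma\left(W_i\,[h_{t-1}, x_t] + b_i\right) \\
\tilde{C}_t &= \tanh\left(W_C\,[h_{t-1}, x_t] + b_C\right) \\
o_t &= \sigma\left(W_o\,[h_{t-1}, x_t] + b_o\right) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\
h_t &= o_t \odot \tanh(C_t)
\end{aligned}
```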
import torch
from torch import nn
# input_size = 3, hidden_size = 4
input_size, hidden_size = 3, 4
# One weight matrix per gate (normally fused inside nn.LSTMCell)
W_f = nn.Linear(input_size + hidden_size, hidden_size) # forget
W_i = nn.Linear(input_size + hidden_size, hidden_size) # input
W_C = nn.Linear(input_size + hidden_size, hidden_size) # candidate
W_o = nn.Linear(input_size + hidden_size, hidden_size) # output
x_t = torch.randn(1, input_size) # current token embedding
h_tm1 = torch.zeros(1, hidden_size) # h_{t-1}
C_tm1 = torch.zeros(1, hidden_size) # C_{t-1}
combined = torch.cat([h_tm1, x_t], dim=1) # [h_{t-1}, x_t]
f_t = torch.sigmoid(W_f(combined)) # forget gate → [0, 1]
i_t = torch.sigmoid(W_i(combined)) # input gate → [0, 1]
C_tilde = torch.tanh(W_C(combined)) # candidate → [-1, 1]
o_t = torch.sigmoid(W_o(combined)) # output gate → [0, 1]
C_t = f_t * C_tm1 + i_t * C_tilde # new cell state
h_t = o_t * torch.tanh(C_t) # new hidden state
print(f"forget f_t : {f_t.tolist()[0][:4]}")
print(f"input i_t : {i_t.tolist()[0][:4]}")
print(f"cell C_t : {C_t.tolist()[0][:4]}")
print(f"hidden h_t : {h_t.tolist()[0][:4]}")
# nn.LSTMCell does all of the above in one call, with its own fused and
# freshly initialised weights (so its numbers won't match the ones printed above):
# h_t, C_t = nn.LSTMCell(input_size, hidden_size)(x_t, (h_tm1, C_tm1))
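A quick parameter count makes the cost of those four gates concrete. This is plain arithmetic for the sizes used above (`input_size = 3`, `hidden_size = 4`). The `nn.LSTMCell` shapes follow PyTorch's documented `weight_ih`, `weight_hh`, `bias_ih` and `bias_hh` attributes. The vanilla-RNN comparison is an assumption for illustration: a single linear layer over $[h_{t-1}, x_t]$.

```python
# Plain arithmetic, no PyTorch needed
input_size, hidden_size = 3, 4

def linear_params(n_in, n_out):
    """Parameters of nn.Linear(n_in, n_out): weight matrix plus bias vector."""
    return n_in * n_out + n_out

# Manual cell above: four nn.Linear(input_size + hidden_size, hidden_size) layers
manual = 4 * linear_params(input_size + hidden_size, hidden_size)

# nn.LSTMCell: weight_ih (4H x I), weight_hh (4H x H), bias_ih (4H), bias_hh (4H)
fused = (4 * hidden_size * input_size + 4 * hidden_size * hidden_size
         + 4 * hidden_size + 4 * hidden_size)

# Vanilla RNN cell sketched as a single linear layer over [h_{t-1}, x_t]
rnn = linear_params(input_size + hidden_size, hidden_size)

print(f"manual LSTM: {manual}, nn.LSTMCell: {fused}, vanilla RNN: {rnn}")
```

The manual cell has exactly four times the parameters of an equally sized vanilla RNN. The fused cell has 16 more because PyTorch keeps two bias vectors (`bias_ih` and `bias_hh`), which is mathematically redundant but harmless.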
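A common pitfall is carrying the hidden state over between unrelated sequences. Each independent sequence should start from a fresh, zero-initialised state.
import torch
import torch.nn as nn
# Toy setup with hypothetical random data so the snippet runs end to end
batch_size, seq_len, input_size, hidden_size = 2, 5, 10, 32
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(lstm.parameters(), lr=0.01)
# Each item is an unrelated (sequence, targets) pair
training_data = [(torch.randn(batch_size, seq_len, input_size),
                  torch.randn(batch_size, seq_len, hidden_size)) for _ in range(3)]
# ❌ Bug: hidden state carries over between unrelated sequences.
# Left commented out: besides leaking context, the second backward() would try to
# backprop into the previous iteration's already-freed graph and typically raise a RuntimeError.
# hidden = None
# for sequence, targets in training_data:
#     output, hidden = lstm(sequence, hidden)  # previous hidden leaks in
#     loss = criterion(output, targets)
#     optimizer.zero_grad(); loss.backward(); optimizer.step()
# ✓ Fix: pass nothing (or explicit zeros) before each new sequence
for sequence, targets in training_data:
    output, _ = lstm(sequence)  # (h0, c0) default to zeros
    # explicit alternative; h0/c0 are (num_layers * num_directions, batch, hidden)
    # even when batch_first=True:
    # h0 = torch.zeros(1, batch_size, hidden_size)
    # c0 = torch.zeros(1, batch_size, hidden_size)
    # output, _ = lstm(sequence, (h0, c0))
    loss = criterion(output, targets)
    optimizer.zero_grad(); loss.backward(); optimizer.step()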
- RNNs process sequences step-by-step but suffer from vanishing gradients (short-term memory).
- LSTMs solve this with three gates: Forget, Input, and Output.
- Gates learn to control information flow, enabling long-range dependencies.
- Cell state acts as a highway for information to flow through time.
- Bidirectional LSTMs process sequences in both directions for richer context.
Sequential encoders like RNNs, LSTMs, and GRUs process inputs one element at a time, maintaining a memory (hidden state). While basic RNNs struggle with long sequences due to vanishing gradients, LSTMs solve this with three gates that control what gets written to memory, what gets erased, and what gets read out. The result is a network that can actually hold onto context across long sequences.