gradientflow

Reading Step-by-Step: LSTMs and Their Gates

A deep dive into Long Short-Term Memory (LSTM) networks, the architecture that solved the vanishing gradient problem for sequential data using a clever system of gates.

July 3, 2025
8 min read
Module 3
LSTMs, RNNs, Deep Learning, Sequential Data, Gated Architectures

Alright, we know encoders summarise data, and we know basic networks can struggle with learning long patterns due to gradient issues. So, how did engineers first tackle encoding sequences like text, where order and long-range context matter immensely? Enter the Recurrent Neural Network (RNN) family, and specifically their more sophisticated members: Long Short-Term Memory (LSTM) networks and Gated Recurrent Units (GRUs).

Unlike feed-forward networks that process fixed-size inputs independently, RNNs are designed for sequences. They process input one step at a time (e.g., one word after another), maintaining an internal hidden state — think of it as memory — that gets updated at each step and influences the processing of the next step. This recurrence allows them to capture sequential dependencies.
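To make the recurrence concrete, here is a minimal sketch of a vanilla RNN step loop in PyTorch. The sizes, the random inputs, and the names `W_xh`, `W_hh`, and `rnn_step` are illustrative assumptions, not part of the original lesson.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size, seq_len = 3, 4, 5   # illustrative sizes

# Hypothetical layer names: one projection for the input, one for the memory
W_xh = nn.Linear(input_size, hidden_size)
W_hh = nn.Linear(hidden_size, hidden_size, bias=False)

def rnn_step(x_t, h_prev):
    """One recurrence step: mix the new input with the previous hidden state."""
    return torch.tanh(W_xh(x_t) + W_hh(h_prev))

sequence = torch.randn(seq_len, 1, input_size)  # e.g. five token embeddings
h = torch.zeros(1, hidden_size)                 # empty memory before step 1

for t, x_t in enumerate(sequence):
    h = rnn_step(x_t, h)                        # same weights reused at every step
    print(f"step {t}: h = {[round(v, 3) for v in h[0].tolist()]}")
```

The only thing carried from one step to the next is `h`, which is why everything the network remembers has to fit through that single vector.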

However, the simple RNN architecture, while elegant, suffers badly from the vanishing gradient problem we discussed in Module 1. Its memory is too short-term; information from early parts of a sequence tends to get washed out by the time the network reaches the end.
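A rough way to see that washing-out effect is to backpropagate from the last step of an untrained, default-initialised `nn.RNN` and compare how much gradient reaches the first input versus the last. The sizes and random data below are assumptions chosen only for illustration, and the exact numbers will vary with the seed and initialisation.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq_len, input_size, hidden_size = 50, 8, 16   # illustrative sizes

rnn = nn.RNN(input_size, hidden_size, batch_first=True)  # plain tanh RNN
x = torch.randn(1, seq_len, input_size, requires_grad=True)

output, _ = rnn(x)
# Backpropagate from the final step only, then inspect how much
# gradient reaches each input position.
output[:, -1, :].sum().backward()

grad_per_step = x.grad.norm(dim=2).squeeze(0)   # shape: (seq_len,)
first_grad = grad_per_step[0].item()
last_grad = grad_per_step[-1].item()
print(f"gradient norm at first step: {first_grad:.2e}")
print(f"gradient norm at last step : {last_grad:.2e}")
```

If the first-step gradient is many orders of magnitude smaller than the last-step one, the early tokens are effectively invisible to the learning signal.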


The Sandbox — Get Your Hands Dirty

How do we fix that memory problem? The best way to understand the solution is to operate it yourself.


LSTMs: Introducing Gates for Smarter Memory

Remember the rant about how simple RNNs forget almost everything? The sandbox above is the answer. The sliders are the "cure" for the RNN's terrible memory — a crucial upgrade that turns a basic RNN into an LSTM.

Imagine the simple RNN memory as a single conveyor belt where everything gets overwritten at each step. LSTMs add sophisticated gates — little control valves that regulate the flow of information onto and off this conveyor belt (called the cell state).

The three gates are the heart of the LSTM:

  • Forget Gate — Decides what old information to discard from the cell state. A value near 0 means "forget everything," near 1 means "keep everything." Useful when the topic of a conversation completely changes.
  • Input Gate — Determines what new information should be stored in the cell state. Works alongside a candidate value layer that proposes new content. Together they decide what gets written to memory.
  • Output Gate — Controls what information flows out as the hidden state (the output at this timestep). The cell state might hold a lot; the output gate decides what's actually relevant to expose right now.

These gates use sigmoid functions to make their decisions — outputting values between 0 and 1 that act as "how much to let through." This selective memory management is what allows LSTMs to capture dependencies over much longer sequences than simple RNNs.
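To get a feel for why a forget gate near 1 matters, the sketch below holds the gates at fixed values and tracks how much of a stored cell value survives 50 steps. This is a simplification: in a real LSTM the gates are learned and change with every input. The function name `surviving_memory` and the chosen gate values are assumptions for illustration.

```python
import torch

def surviving_memory(forget_value, steps=50):
    """Apply C_t = f * C_{t-1} + i * C_tilde with the input gate closed (i = 0)."""
    cell = torch.tensor([1.0])                 # something stored at step 0
    f_t = torch.tensor([forget_value])         # fixed forget gate (assumption)
    i_t = torch.tensor([0.0])                  # input gate closed: nothing new written
    candidate = torch.tensor([0.5])            # ignored because i_t is 0
    for _ in range(steps):
        cell = f_t * cell + i_t * candidate
    return cell.item()

kept = surviving_memory(0.99)   # gate almost fully open
lost = surviving_memory(0.50)   # gate half closed
print(f"forget gate 0.99 -> {kept:.3f} of the memory left after 50 steps")
print(f"forget gate 0.50 -> {lost:.1e} of the memory left after 50 steps")
```

Because the cell state is updated by element-wise scaling and addition, a gate that stays close to 1 lets a value ride the conveyor belt almost unchanged.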


Test Your Understanding

Match the gate to its job.


Expanding the View — Looking Both Ways

Standard LSTMs are like reading a book one way: they only know what came before. But often, understanding a word requires context from what comes after it. Take "bank" in a sentence like "She sat on the bank and watched the river flow past": only the later word "river" tells you it is not a financial institution.

Bidirectional LSTMs (BiLSTMs) address this by using two separate LSTMs: one reads the sequence forward, the other reads it backward. Their hidden states at each position are then combined (e.g., concatenated). The difference this makes is immediate: the visualizer below uses that exact "bank" ambiguity to show what a one-directional LSTM misses.
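In PyTorch, one way to get this is the `bidirectional=True` flag on `nn.LSTM`. The sketch below uses random tensors as stand-ins for word embeddings, and the sizes are arbitrary assumptions; it only shows how the two directions are laid out in the output.

```python
import torch
from torch import nn

torch.manual_seed(0)
embed_size, hidden_size, seq_len = 8, 16, 6   # illustrative sizes

bilstm = nn.LSTM(embed_size, hidden_size, batch_first=True, bidirectional=True)
tokens = torch.randn(1, seq_len, embed_size)  # stand-in for six word embeddings

output, (h_n, c_n) = bilstm(tokens)

# Each position gets the forward and backward hidden states concatenated.
forward_states = output[:, :, :hidden_size]    # has read tokens 0..t
backward_states = output[:, :, hidden_size:]   # has read tokens t..end

print(output.shape)   # (batch, seq_len, 2 * hidden_size)
# The forward pass finishes at the last token, the backward pass at the first.
print(torch.allclose(h_n[0], forward_states[:, -1]))
print(torch.allclose(h_n[1], backward_states[:, 0]))
```

At any position `t`, the concatenated vector combines a summary of everything before the word with a summary of everything after it.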


For the Curious: The Math, Advanced Topics & Pitfalls

LSTM Equations: The Inner Workings

Here's a peek at the math that makes the gates work. Each equation uses the current input $x_t$, the previous hidden state $h_{t-1}$, the previous cell state $C_{t-1}$, and learned weight matrices $W$ and biases $b$. The symbol $\odot$ denotes element-wise multiplication.
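Written out, one common formulation of the LSTM step looks like this. It uses the concatenated-input form $[h_{t-1}, x_t]$ so that it lines up with the PyTorch snippet that follows; $\sigma$ is the sigmoid function, and the per-gate bias names $b_f, b_i, b_C, b_o$ are a notational assumption (in the code they live inside each `nn.Linear`).

```latex
\begin{aligned}
f_t &= \sigma\left(W_f\,[h_{t-1}, x_t] + b_f\right) \\
i_t &= \sigma\left(W_i\,[h_{t-1}, x_t] + b_i\right) \\
\tilde{C}_t &= \tanh\left(W_C\,[h_{t-1}, x_t] + b_C\right) \\
o_t &= \sigma\left(W_o\,[h_{t-1}, x_t] + b_o\right) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\
h_t &= o_t \odot \tanh(C_t)
\end{aligned}
```

The first four lines compute the three gates and the candidate; the last two are the only places where the cell state is actually changed and read out.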

LSTM Gate Equations
```python
import torch
from torch import nn

input_size, hidden_size = 3, 4

# One linear layer (weights + bias) per gate; nn.LSTMCell fuses these internally
W_f = nn.Linear(input_size + hidden_size, hidden_size)  # forget
W_i = nn.Linear(input_size + hidden_size, hidden_size)  # input
W_C = nn.Linear(input_size + hidden_size, hidden_size)  # candidate
W_o = nn.Linear(input_size + hidden_size, hidden_size)  # output

x_t   = torch.randn(1, input_size)        # current token embedding
h_tm1 = torch.zeros(1, hidden_size)       # h_{t-1}
C_tm1 = torch.zeros(1, hidden_size)       # C_{t-1}

combined = torch.cat([h_tm1, x_t], dim=1) # [h_{t-1}, x_t]

f_t     = torch.sigmoid(W_f(combined))    # forget gate  → (0, 1)
i_t     = torch.sigmoid(W_i(combined))    # input gate   → (0, 1)
C_tilde = torch.tanh(W_C(combined))       # candidate    → (-1, 1)
o_t     = torch.sigmoid(W_o(combined))    # output gate  → (0, 1)
C_t     = f_t * C_tm1 + i_t * C_tilde     # new cell state
h_t     = o_t * torch.tanh(C_t)           # new hidden state

print(f"forget  f_t : {f_t.tolist()[0]}")
print(f"input   i_t : {i_t.tolist()[0]}")
print(f"cell    C_t : {C_t.tolist()[0]}")
print(f"hidden  h_t : {h_t.tolist()[0]}")
# nn.LSTMCell performs the same computation in one call (with its own weights):
# h_t, C_t = nn.LSTMCell(input_size, hidden_size)(x_t, (h_tm1, C_tm1))
```

Pause & Reflect

When might a GRU be preferable to an LSTM?
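One concrete consideration to weigh here is model size. A GRU merges the cell and hidden state and has two gates instead of three, so at equal sizes it carries fewer parameters. The sketch below simply counts them in PyTorch; the sizes are arbitrary assumptions and `count_parameters` is a helper written for this example.

```python
from torch import nn

input_size, hidden_size = 64, 128   # illustrative sizes

def count_parameters(module):
    """Total number of trainable scalars in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

lstm_params = count_parameters(nn.LSTM(input_size, hidden_size))
gru_params = count_parameters(nn.GRU(input_size, hidden_size))

print(f"LSTM parameters: {lstm_params}")
print(f"GRU parameters : {gru_params}")
print(f"GRU / LSTM     : {gru_params / lstm_params:.2f}")
```

Fewer parameters is only one factor; which architecture performs better on a given task is an empirical question.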

Hidden State Leakage
Not zeroing the hidden state between independent training sequences can leak information from previous examples. This leads to unrealistic performance during training that doesn't generalise — the model "cheats" by using information it shouldn't have access to. Always reset your hidden states between independent sequences!
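```python
import torch
import torch.nn as nn

# Toy stand-ins so the snippet runs; swap in your own data, loss, and optimizer.
lstm = nn.LSTM(input_size=10, hidden_size=32, batch_first=True)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(lstm.parameters(), lr=0.01)
# Each item is (sequence, targets): shapes (batch, steps, 10) and (batch, steps, 32)
training_data = [(torch.randn(4, 7, 10), torch.randn(4, 7, 32)) for _ in range(3)]

# ❌  Bug: hidden state carries over between unrelated sequences
hidden = None
for sequence, targets in training_data:
    output, hidden = lstm(sequence, hidden)  # previous hidden leaks in
    hidden = tuple(h.detach() for h in hidden)  # keeps backward() working; values still leak
    loss = criterion(output, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# ✓  Fix: pass nothing (or explicit zeros) before each new sequence
for sequence, targets in training_data:
    output, _ = lstm(sequence)               # hidden defaults to zeros
    # explicit alternative:
    # batch_size = sequence.size(0)
    # h0 = torch.zeros(1, batch_size, 32)
    # c0 = torch.zeros(1, batch_size, 32)
    # output, _ = lstm(sequence, (h0, c0))
    loss = criterion(output, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```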
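To check that a stale state really changes what the model sees, the sketch below runs the same sequence twice through an untrained `nn.LSTM`: once from a fresh zero state and once from the state left behind by an unrelated example. The random tensors and the variable names are assumptions for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=10, hidden_size=32, batch_first=True)

unrelated = torch.randn(1, 5, 10)   # a previous, independent example
sequence = torch.randn(1, 5, 10)    # the example we actually care about

with torch.no_grad():
    _, stale_state = lstm(unrelated)             # state left over from the old example
    clean_out, _ = lstm(sequence)                # fresh zero state
    leaked_out, _ = lstm(sequence, stale_state)  # old state carried in

max_diff = (clean_out - leaked_out).abs().max().item()
print(f"largest output difference caused by the stale state: {max_diff:.4f}")
```

Any non-zero difference means the outputs for `sequence` depend on an example that should have had no influence on it.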

Final Knowledge Check

What is the primary role of the output gate $o_t$ in an LSTM?

TL;DR
  • RNNs process sequences step-by-step but suffer from vanishing gradients (short-term memory).
  • LSTMs solve this with three gates: Forget, Input, and Output.
  • Gates learn to control information flow, enabling long-range dependencies.
  • Cell state acts as a highway for information to flow through time.
  • Bidirectional LSTMs process sequences in both directions for richer context.
Summary

Sequential encoders like RNNs, LSTMs, and GRUs process inputs one element at a time, maintaining a memory (hidden state). While basic RNNs struggle with long sequences due to vanishing gradients, LSTMs solve this with three gates that control what gets written to memory, what gets erased, and what gets read out. The result is a network that can actually hold onto context across long sequences.

My Take
LSTMs are clever. They can also be genuinely fiddly to tune. While they were a huge step forward, these days, if the task allows, I often reach for a Transformer-based approach first, even if it means potentially needing more data or GPU time. The parallelisability and attention mechanisms of Transformers often make them worth the trade-off. But for smaller datasets or streaming input where you need step-by-step processing, an LSTM is often the right call.
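As a sketch of the streaming case, the snippet below feeds one long sequence to an `nn.LSTM` in chunks, carrying the state forward between chunks, and compares the result with processing the whole sequence at once. Here the carried state is intentional because all chunks belong to the same stream. The chunk size and random data are assumptions for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=10, hidden_size=32, batch_first=True)
stream = torch.randn(1, 12, 10)     # one long sequence arriving over time

with torch.no_grad():
    full_out, _ = lstm(stream)      # reference: whole sequence at once

    state = None                    # None means "start from zeros"
    chunks = []
    for chunk in stream.split(4, dim=1):    # three chunks of four steps
        out, state = lstm(chunk, state)     # carry state within the same stream
        chunks.append(out)
    streamed_out = torch.cat(chunks, dim=1)

print(torch.allclose(full_out, streamed_out, atol=1e-5))
```

Because the state summarises everything seen so far, the model never needs to revisit earlier chunks, which is what makes this style of processing a natural fit for incoming data.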