Single-head attention is powerful, but it computes only a single set of attention weights per position. A word might need to simultaneously track its syntactic role (it's the subject of the verb), its semantic relationship (it's a synonym of another word), and its positional context (it's adjacent to a modifier). These are fundamentally different types of relationships, and one set of Q, K, V projections can only capture one perspective at a time.
Multi-head attention is what allows transformers to track syntax, semantics, and coreference simultaneously. Research probing BERT has shown that individual heads specialize — some track direct objects, others track positional proximity. This specialization is one of the reasons transformers have displaced earlier architectures across most NLP benchmarks.
Multi-head attention runs several attention mechanisms in parallel, each with different learned projections. Each can specialize in a different type of relationship.
The Architecture
For each of the $h$ attention heads, compute independent Q, K, V projections and run attention:

$$\text{head}_i = \text{Attention}(XW_i^Q,\; XW_i^K,\; XW_i^V)$$

- $\text{head}_i$ - the output of the $i$-th attention head
- $W_i^Q, W_i^K, W_i^V$ - learned projection matrices for head $i$ - different for each head
Each set of projection matrices is learned independently. Head 1 might learn to project in a direction that captures subject-verb agreement; head 2, coreference; head 3, positional proximity.
After computing all $h$ heads, concatenate them along the feature dimension:

$$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$

- $\text{Concat}$ - concatenation along the feature/column dimension
- $W^O$ - output projection matrix of shape $(h \cdot d_v) \times d_{model}$

The output has shape $\text{seq\_len} \times d_{model}$.
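The per-head computation and concatenation above can be sketched in NumPy. This is a minimal illustration, not a trained model: the `softmax` helper and the randomly initialized weight matrices are stand-ins for learned parameters.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # scaled dot-product attention for a single head
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)          # (seq_len, seq_len)
    return softmax(scores) @ V               # (seq_len, d_v)

rng = np.random.default_rng(0)
seq_len, d_model, h = 10, 512, 8
d_k = d_v = d_model // h

X = rng.standard_normal((seq_len, d_model))

heads = []
for i in range(h):
    # independent projections per head (random here, learned in practice)
    W_Q = rng.standard_normal((d_model, d_k)) * 0.02
    W_K = rng.standard_normal((d_model, d_k)) * 0.02
    W_V = rng.standard_normal((d_model, d_v)) * 0.02
    heads.append(attention(X @ W_Q, X @ W_K, X @ W_V))

concat = np.concatenate(heads, axis=-1)      # (seq_len, h * d_v) = (10, 512)
W_O = rng.standard_normal((h * d_v, d_model)) * 0.02
output = concat @ W_O                        # (seq_len, d_model) = (10, 512)
print(output.shape)
```

Production implementations compute all heads in one batched tensor operation rather than a Python loop, but the shapes and the concatenate-then-project structure are the same.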
Interactive example (coming soon): multi-head attention - toggle heads on/off to see how each specializes on a sample sentence.
Dimension Arithmetic
To keep multi-head attention computationally comparable to single-head, you reduce the dimensionality per head:
- Single-head with $d_{model} = 512$: use $d_k = d_v = 512$. Compute: proportional to $\text{seq\_len}^2 \times 512$.
- Multi-head with $h = 8$: set $d_k = d_v = d_{model}/h = 64$ per head.

Each head is 8x cheaper; with 8 heads, total compute is the same as single-head. After concatenation: 8 heads × 64 dimensions = 512. Apply $W^O$: back to $d_{model} = 512$.
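The arithmetic can be checked directly. This rough count tracks only the dominant score-computation term, proportional to $\text{seq\_len}^2 \times d_k$ per head, and ignores constant factors:

```python
seq_len, d_model, h = 128, 512, 8
d_k = d_model // h                           # 64 per head

# cost of the QK^T score matrix, proportional to seq_len^2 * d_k per head
single_head_cost = seq_len**2 * d_model      # one head at full width
multi_head_cost = h * seq_len**2 * d_k       # 8 heads at 1/8 width

print(single_head_cost == multi_head_cost)   # same total compute
print(h * d_k)                               # concatenation restores d_model = 512
```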
What Different Heads Learn
Interpretability researchers have probed what specific attention heads learn in trained transformers:
Syntactic heads: track grammatical structure - a verb attending strongly to its subject, or a noun to its modifying adjective, even across long-distance dependencies.
Coreference heads: connect pronouns to antecedents. "The woman said she would arrive" - a coreference head shows strong attention from "she" to "woman."
Positional heads: primarily attend to adjacent tokens - next token, previous token, start of sentence. These capture local context.
Copy heads: near-identity behavior - a token attends to itself or to an earlier instance of the same token. These help copy information through the network.
This specialization emerges from training alone - you don't program these behaviors; the heads discover them during optimization.
Typical Hyperparameters
| Model | Heads (h) | d_model | d_k per head |
|---|---|---|---|
| GPT-2 small | 12 | 768 | 64 |
| GPT-2 large | 20 | 1280 | 64 |
| BERT base | 12 | 768 | 64 |
| GPT-3 | 96 | 12288 | 128 |
Notice that $d_k$ per head is remarkably consistent. What scales is the number of heads $h$ and $d_{model}$.
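Each row of the table satisfies $d_k = d_{model}/h$, which you can verify in a couple of lines:

```python
# (heads, d_model) for each model in the table
models = {
    "GPT-2 small": (12, 768),
    "GPT-2 large": (20, 1280),
    "BERT base":   (12, 768),
    "GPT-3":       (96, 12288),
}

for name, (h, d_model) in models.items():
    print(f"{name}: d_k per head = {d_model // h}")
```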
Parameter Count
For one multi-head attention block with $h$ heads and $d_k = d_v = d_{model}/h$:

- $d_{model}$ - the model's hidden dimension - determines total parameter count
- $h$ sets of $W_i^Q$ (each $d_{model} \times d_k$): total $h \cdot d_{model} \cdot d_k = d_{model}^2$
- Same for $W_i^K$ and $W_i^V$
- Output projection $W^O$, shape $(h \cdot d_v) \times d_{model} = d_{model} \times d_{model}$: another $d_{model}^2$

For $d_{model} = 512$: $4 \times 512^2 \approx 1.05$ million parameters per attention layer.
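The count above reduces to a one-line formula (biases, when present, add a negligible amount and are ignored here):

```python
def attention_params(d_model: int) -> int:
    # W^Q, W^K, W^V across all heads each total d_model x d_model parameters,
    # plus the output projection W^O: 4 * d_model^2 in all
    return 4 * d_model ** 2

print(attention_params(512))   # about 1 million
print(attention_params(768))   # about 2.4 million (GPT-2 small / BERT base width)
```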
The Output Projection
The output projection does more than resize the concatenated output. After concatenating all $h$ head outputs (shape $\text{seq\_len} \times h \cdot d_v$), $W^O$ applies a learned linear combination that allows head information to mix across the feature dimension.

Without $W^O$, head 1's syntactic features and head 2's semantic features would stay in separate channels of the output. The output projection is what lets the model build token representations simultaneously informed by multiple relationship types — the entire point of multi-head attention.
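A quick numerical sketch of this mixing effect. The head outputs and the projection matrix are random stand-ins for learned values; the point is only to show which output channels a single head can influence before and after the projection:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, h, d_v = 4, 8, 64
d_model = h * d_v

# hypothetical head outputs (random stand-ins for real attention results)
heads = [rng.standard_normal((seq_len, d_v)) for _ in range(h)]
concat = np.concatenate(heads, axis=-1)

# perturb head 0 only and re-concatenate
heads2 = [heads[0] + 1.0] + heads[1:]
concat2 = np.concatenate(heads2, axis=-1)

# without the output projection: the change stays confined to head 0's channels
changed = np.any(concat != concat2, axis=0)
print(changed[:d_v].all(), changed[d_v:].any())   # confined to first d_v columns

# with the output projection: the change spreads across all d_model channels
W_O = rng.standard_normal((h * d_v, d_model)) * 0.02
mixed_changed = np.any(concat @ W_O != concat2 @ W_O, axis=0)
print(mixed_changed.all())                        # every output channel moved
```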