Attention & Transformers
Lesson 4 ⏱ 12 min

Multi-head attention

Video coming soon

Multi-Head Attention - Parallel Specialization

Animates how h parallel attention heads each learn different projection matrices, how their outputs are concatenated, and why the output projection mixes information across heads.

⏱ ~7 min

🧮 Quick refresher

Matrix concatenation and projection

Concatenating matrices along a dimension stacks them side by side. A projection matrix (linear layer) then maps the larger concatenated representation back to the original size.

Example

3 heads each producing (seq_len x 64) outputs, concatenated: (seq_len x 192).

Output projection W_O (192 x 192) maps back to (seq_len x 192).
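To make this concrete, here is a minimal NumPy sketch of the refresher example (the array contents are random and purely illustrative):

```python
import numpy as np

seq_len, d_head, n_heads = 5, 64, 3                 # sizes from the example above

# Three head outputs, each (seq_len x 64)
heads = [np.random.randn(seq_len, d_head) for _ in range(n_heads)]

concat = np.concatenate(heads, axis=-1)             # stacked side by side: (5, 192)
W_O = np.random.randn(n_heads * d_head, n_heads * d_head)   # projection, (192 x 192)
out = concat @ W_O                                  # back to (5, 192)

print(concat.shape, out.shape)                      # (5, 192) (5, 192)
```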

Single-head attention is powerful, but it computes only one set of attention weights per token. A word might need to simultaneously track its syntactic role (it's the subject of the verb), its semantic relationship (it's a synonym of another word), and its positional context (it's adjacent to a modifier). These are fundamentally different types of relationships, and one set of Q, K, V projections can only capture one perspective at a time.

Multi-head attention is what allows transformers to track syntax, semantics, and coreference simultaneously. Research probing BERT has shown that individual heads specialize: some track direct objects, others track positional proximity. This specialization is a key reason transformers displaced earlier architectures across a wide range of NLP benchmarks.

Multi-head attention runs several attention mechanisms in parallel, each with different learned projections. Each can specialize in a different type of relationship.

The Architecture

For each of the h attention heads, compute independent Q, K, V projections and run attention:

$$\text{head}_i = \text{Attention}(X W_i^Q,\; X W_i^K,\; X W_i^V)$$

$\text{head}_i$: the output of the $i$-th attention head
$W_i^Q, W_i^K, W_i^V$: learned projection matrices for head $i$, different for each head

Each set of projection matrices is learned independently. Head 1 might learn to project in a direction that captures subject-verb agreement; head 2 coreference; head 3 positional proximity.

After computing all h heads, concatenate them along the feature dimension:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W_O$$

$\text{Concat}$: concatenation along the feature/column dimension
$W_O$: output projection matrix of shape $(h \cdot d_v) \times d_{\text{model}}$

Since each head outputs $d_v$ dimensions, $W_O$ has shape $(h \cdot d_v) \times d_{\text{model}}$.
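Putting both formulas together, here is a compact sketch in PyTorch. It uses one fused $(d_{\text{model}} \times d_{\text{model}})$ projection each for Q, K, and V and then splits the result into $h$ heads of width $d_k$, which is equivalent to $h$ separate per-head matrices stacked side by side. The function name, weights, and sizes are illustrative, not a specific library's API:

```python
import math
import torch

def multi_head_attention(X, W_q, W_k, W_v, W_o, h):
    """X: (T, d_model); W_q, W_k, W_v, W_o: (d_model, d_model); h: number of heads."""
    T, d_model = X.shape
    d_k = d_model // h

    # Project once, then split into h heads of width d_k (equivalent to h separate W_i^Q etc.)
    Q = (X @ W_q).view(T, h, d_k).transpose(0, 1)       # (h, T, d_k)
    K = (X @ W_k).view(T, h, d_k).transpose(0, 1)       # (h, T, d_k)
    V = (X @ W_v).view(T, h, d_k).transpose(0, 1)       # (h, T, d_k)

    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # (h, T, T): each head's own attention
    weights = torch.softmax(scores, dim=-1)
    heads = weights @ V                                  # (h, T, d_k): head_1 ... head_h

    concat = heads.transpose(0, 1).reshape(T, h * d_k)   # Concat along the feature dimension
    return concat @ W_o                                   # output projection W_O -> (T, d_model)

# Illustrative usage
T, d_model, h = 10, 512, 8
X = torch.randn(T, d_model)
W_q, W_k, W_v, W_o = (torch.randn(d_model, d_model) / d_model**0.5 for _ in range(4))
print(multi_head_attention(X, W_q, W_k, W_v, W_o, h).shape)   # torch.Size([10, 512])
```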

Interactive example

Multi-head attention - toggle heads on/off to see how each specializes on a sample sentence

Coming soon

Dimension Arithmetic

To keep multi-head attention computationally comparable to single-head, you reduce the dimensionality per head:

  • Single-head with $d_{\text{model}} = 512$: use $d_k = 512$. Compute is proportional to $\text{seq\_len}^2 \times 512$.
  • Multi-head with $h = 8$: set $d_k = d_v = d_{\text{model}} / h = 512 / 8 = 64$ per head.

Each head is 8× cheaper; with 8 heads, total compute is the same as single-head. After concatenation: 8 heads × 64 dimensions = 512. Applying $W_O$ maps back to $d_{\text{model}}$.
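A quick sanity check of this arithmetic in Python (sizes taken from the bullets above; the sequence length is illustrative):

```python
d_model, h = 512, 8
d_k = d_model // h                       # 64 per head
assert h * d_k == d_model                # concatenation restores d_model

seq_len = 128                            # illustrative sequence length
single_head = seq_len**2 * d_model       # proportional cost of one 512-wide head
multi_head = h * (seq_len**2 * d_k)      # 8 heads, each 8x cheaper
assert single_head == multi_head         # same total compute
```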

What Different Heads Learn

Interpretability researchers have probed what specific attention heads learn in trained transformers:

Syntactic heads: track grammatical structure - a verb attending strongly to its subject, or a noun to its modifying adjective, even across long-distance dependencies.

Coreference heads: connect pronouns to antecedents. "The woman said she would arrive" - a coreference head shows strong attention from "she" back to "woman."

Positional heads: primarily attend to adjacent tokens - next token, previous token, start of sentence. These capture local context.

Copy heads: near-identity behavior - a token attends to itself or to an earlier instance of the same token. These help copy information through the network.

This specialization emerges from training alone; you don't program these behaviors.

Typical Hyperparameters

| Model       | Heads (h) | d_model | d_k per head |
|-------------|-----------|---------|--------------|
| GPT-2 small | 12        | 768     | 64           |
| GPT-2 large | 20        | 1280    | 64           |
| BERT base   | 12        | 768     | 64           |
| GPT-3       | 96        | 12288   | 128          |

Notice that $d_k = 64$ is remarkably consistent. What scales is the number of heads and $d_{\text{model}}$.

Parameter Count

For one multi-head attention block with $h$ heads, hidden dimension $d_{\text{model}}$, and $d_k = d_v = d_{\text{model}} / h$:

$$\text{Total} = 4 \times d_{\text{model}}^2$$

$d_{\text{model}}$: the model's hidden dimension, which determines the total parameter count

  • $h$ sets of $W_i^Q$: total $= h \times d_{\text{model}} \times d_k = d_{\text{model}}^2$
  • Same for $W_i^K$ and $W_i^V$
  • Output projection $W_O$: $d_{\text{model}}^2$

For $d_{\text{model}} = 768$: $4 \times 768^2 \approx 2.36\text{M}$ parameters per attention layer.
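A short check of this count, using the BERT-base-like sizes from the table above:

```python
d_model, h = 768, 12
d_k = d_model // h                        # 64

qkv_params = 3 * h * d_model * d_k        # all W_i^Q, W_i^K, W_i^V across heads
out_params = d_model * d_model            # W_O
total = qkv_params + out_params

print(total, 4 * d_model**2)              # 2359296 2359296  (~2.36M)
```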

The Output Projection

The output projection does more than resize the concatenated output. After concatenating all $h$ head outputs (shape $T \times h \cdot d_v$), $W_O$ applies a learned linear combination that allows head information to mix across the feature dimension.

Without $W_O$, head 1's syntactic features and head 2's semantic features would stay in separate channels of the output. The output projection is what lets the model build token representations simultaneously informed by multiple relationship types, which is the entire point of multi-head attention.
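A tiny sketch of that effect: before $W_O$, the concatenated matrix is block-structured, with each group of $d_v$ columns coming from exactly one head; after multiplying by a (here random, purely illustrative) $W_O$, every output feature combines information from all heads:

```python
import torch

T, h, d_v = 4, 2, 3
d_model = h * d_v

# Fill head i's output entirely with the value i+1 so its columns are easy to spot
heads = [torch.full((T, d_v), float(i + 1)) for i in range(h)]
concat = torch.cat(heads, dim=-1)     # columns 0-2 come only from head 1, columns 3-5 only from head 2

W_O = torch.randn(d_model, d_model)
mixed = concat @ W_O                  # every column of the result mixes both heads

print(concat[0])   # tensor([1., 1., 1., 2., 2., 2.])
print(mixed[0])    # each entry is a weighted combination of head-1 and head-2 features
```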

Quiz

1 / 3

The main motivation for multi-head attention over single-head attention is...