One-sentence summary:
model.py is just all the components from earlier chapters — Embedding, Positional Encoding, Multi-Head Attention, FFN, LayerNorm — wired together with PyTorch. Each class maps directly to a formula.
Complete code repository: github.com/waylandzhang/Transformer-from-scratch
18.1 Before Writing Code: Overall Structure
18.1.1 What We Are Implementing
Model (complete model)
├── Token Embedding
├── Positional Encoding
├── N × TransformerBlock
│ ├── LayerNorm
│ ├── Multi-Head Attention
│ ├── LayerNorm
│ └── Feed Forward Network
├── Final LayerNorm
└── Output Linear (projection to vocabulary)
18.1.2 File Structure
Everything lives in a single model.py:
# model.py overall structure
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
class FeedForwardNetwork(nn.Module): # FFN
...
class Attention(nn.Module): # single attention head
...
class MultiHeadAttention(nn.Module): # multi-head attention
...
class TransformerBlock(nn.Module): # Transformer block
...
class Model(nn.Module): # complete model
...
I prefer writing this file from scratch once before reaching for higher-level libraries. After that, the libraries feel like productivity tools instead of magic curtains.
18.2 Feed Forward Network
18.2.1 Recap of FFN Structure
From Chapter 15: FFN is a two-layer fully connected network:
Input [batch, seq, d_model]
|
Linear1: d_model -> d_model × 4 (4× expansion)
|
ReLU activation
|
Linear2: d_model × 4 -> d_model (back to model width)
|
Dropout
|
Output [batch, seq, d_model]
The 4× expansion gives the model more representational capacity at each position before compressing back. FFN is where most of the model's "knowledge" is stored.
18.2.2 Code
# Feed Forward Network definition
class FeedForwardNetwork(nn.Module):
def __init__(self, d_model, dropout):
super().__init__()
self.d_model = d_model
self.dropout = dropout
self.ffn = nn.Sequential(
nn.Linear(self.d_model, self.d_model * 4), # expand 4×
nn.ReLU(), # activation
nn.Linear(self.d_model * 4, self.d_model), # compress back
nn.Dropout(self.dropout) # regularization
)
def forward(self, x):
return self.ffn(x)
18.2.3 Code Walkthrough
| Code | Role | Shape change |
|---|---|---|
nn.Linear(d_model, d_model * 4) | first linear layer | [B,T,512] → [B,T,2048] |
nn.ReLU() | non-linearity | unchanged |
nn.Linear(d_model * 4, d_model) | second linear layer | [B,T,2048] → [B,T,512] |
nn.Dropout(dropout) | random dropout for regularization | unchanged |
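The shape table can be verified directly with a quick sanity check (mine, not repository code): an FFN built from the two linear layers above preserves the input shape end to end.

```python
import torch
import torch.nn as nn

d_model = 512
ffn = nn.Sequential(
    nn.Linear(d_model, d_model * 4),   # [B, T, 512] -> [B, T, 2048]
    nn.ReLU(),
    nn.Linear(d_model * 4, d_model),   # [B, T, 2048] -> [B, T, 512]
    nn.Dropout(0.1),
)
x = torch.randn(2, 10, d_model)        # [B=2, T=10, 512]
assert ffn(x).shape == x.shape         # output shape matches input
```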
18.3 Attention (Single Head)
18.3.1 Recap of the Attention Formula
In code we need to implement:
- Generate Q, K, V via linear projection
- Compute attention scores: Q @ K^T
- Scale by √d_k
- Apply Causal Mask (prevent attending to future positions)
- Softmax normalization
- Multiply by V to produce the output
18.3.2 Code
# Single-head Scaled Dot Product Attention
class Attention(nn.Module):
def __init__(self, d_model, head_size, context_length, dropout):
super().__init__()
self.d_model = d_model
self.head_size = head_size
self.context_length = context_length
self.dropout = dropout
# linear projections for Q, K, V
self.Wq = nn.Linear(self.d_model, self.head_size, bias=False)
self.Wk = nn.Linear(self.d_model, self.head_size, bias=False)
self.Wv = nn.Linear(self.d_model, self.head_size, bias=False)
# Causal Mask: lower-triangular matrix
self.register_buffer('mask', torch.tril(torch.ones(self.context_length, self.context_length)))
self.dropout = nn.Dropout(self.dropout)
def forward(self, x):
B, T, C = x.shape # Batch, Time (seq_len), Channels (d_model)
# 1. generate Q, K, V
q = self.Wq(x) # [B, T, head_size]
k = self.Wk(x) # [B, T, head_size]
v = self.Wv(x) # [B, T, head_size]
# 2. compute attention scores Q @ K^T and scale
weights = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_size)
# weights: [B, T, T]
# 3. apply Causal Mask (set future positions to -inf)
weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
# 4. Softmax normalization
weights = F.softmax(weights, dim=-1)
# 5. Dropout
weights = self.dropout(weights)
# 6. multiply by V
output = weights @ v # [B, T, head_size]
return output
18.3.3 Key Code Notes
How the Causal Mask works:
self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))
torch.tril produces a lower-triangular matrix:
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1]]
Position i can attend to positions 0 through i only. That is what Causal means — causality preserved in time.
Why register_buffer?
The mask is not a trainable parameter — it never updates. But it needs to travel with the model to whatever device (CPU or GPU) the model is on. register_buffer is exactly the right tool for that: not a parameter, but a persistent tensor.
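A minimal illustration (my sketch, with a hypothetical `MaskOnly` module) of what register_buffer buys you: the mask appears in state_dict, so it is saved, loaded, and moved with `.to(device)`, yet contributes nothing to `parameters()`.

```python
import torch
import torch.nn as nn

class MaskOnly(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, context_length):
        super().__init__()
        self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))

m = MaskOnly(4)
assert len(list(m.parameters())) == 0   # no trainable parameters
assert 'mask' in m.state_dict()         # but persisted and moved with the module
```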
18.4 Multi-Head Attention
18.4.1 The Multi-Head Idea
Multi-Head Attention = multiple single-head attention instances running in parallel, outputs concatenated at the end.
Each head can attend to different patterns. One head might track subject-verb agreement, another might track pronoun-referent relationships. Concatenating them gives richer representations than any single head could provide.
18.4.2 Code
# Multi-Head Attention definition
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, head_size, context_length, dropout):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_size = head_size
self.context_length = context_length
self.dropout = dropout
# create multiple attention heads
self.heads = nn.ModuleList([
Attention(self.d_model, self.head_size, self.context_length, self.dropout)
for _ in range(self.num_heads)
])
# output projection Wo
self.projection_layer = nn.Linear(self.d_model, self.d_model)
self.dropout = nn.Dropout(self.dropout)
def forward(self, x):
# run all heads in parallel
head_outputs = [head(x) for head in self.heads]
# concatenate all head outputs
head_outputs = torch.cat(head_outputs, dim=-1) # [B, T, num_heads * head_size] = [B, T, d_model]
# apply output projection
out = self.dropout(self.projection_layer(head_outputs))
return out
18.4.3 Shape Tracking
Assuming d_model=512, num_heads=8, head_size=64:
Input x: [B, T, 512]
|
Each head output: [B, T, 64] # 8 heads
|
Concatenate: [B, T, 512] # 64 × 8 = 512
|
Wo projection: [B, T, 512]
|
Output: [B, T, 512]
The key relationship: head_size = d_model // num_heads
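The concatenation arithmetic can be checked with plain tensors (a sketch under the d_model=512, num_heads=8 assumption above):

```python
import torch

B, T, d_model, num_heads = 2, 5, 512, 8
head_size = d_model // num_heads                   # 64
head_outputs = [torch.randn(B, T, head_size) for _ in range(num_heads)]
merged = torch.cat(head_outputs, dim=-1)
assert merged.shape == (B, T, d_model)             # 8 × 64 = 512
```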
18.5 Paper-Version Multi-Head Attention
18.5.1 Two Implementations Compared
The implementation above is physically separate: each head has its own Wq, Wk, Wv matrices.
The original "Attention Is All You Need" paper implementation is logically separate: one large linear layer, then reshape into multiple heads.
18.5.2 Paper-Version Code
# Paper-style Multi-Head Attention (logical split)
class MultiHeadAttention_Paper(nn.Module):
def __init__(self, d_model, num_heads, head_size, context_length, dropout):
super().__init__()
self.context_length = context_length
self.d_model = d_model
self.num_heads = num_heads
self.head_size = head_size
# one large linear layer, output dim still d_model
self.Wq = nn.Linear(d_model, d_model)
self.Wk = nn.Linear(d_model, d_model)
self.Wv = nn.Linear(d_model, d_model)
self.Wo = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.tril(torch.ones(self.context_length, self.context_length)))
def split_heads(self, x):
"""Logically split into multiple heads"""
batch_size = x.shape[0]
context_length = x.shape[1]
# [B, T, d_model] -> [B, T, num_heads, head_size] -> [B, num_heads, T, head_size]
x = x.reshape(batch_size, context_length, self.num_heads, self.head_size)
x = x.permute(0, 2, 1, 3)
return x
def forward(self, x):
B, T, C = x.shape
# project then split
q = self.split_heads(self.Wq(x)) # [B, num_heads, T, head_size]
k = self.split_heads(self.Wk(x))
v = self.split_heads(self.Wv(x))
# compute attention
weights = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_size)
weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
weights = self.dropout(weights)
output = weights @ v # [B, num_heads, T, head_size]
# merge heads: [B, num_heads, T, head_size] -> [B, T, d_model]
output = output.transpose(1, 2).reshape(-1, T, C)
# output projection
output = self.Wo(output)
return output
18.5.3 Comparison
| | Physically separate | Logically separate (paper) |
|---|---|---|
| Wq/Wk/Wv count | num_heads each | 1 each |
| Parameter count | identical | identical |
| Compute efficiency | slightly lower (loop) | higher (GPU-parallel) |
| Code clarity | cleaner for learning | slightly more complex |
Why parameter counts are the same:
- Physically separate: num_heads × (d_model × head_size) = d_model × d_model
- Logically separate: d_model × d_model
In practice, the paper version is faster because it enables GPU-level parallelism. For educational purposes, the physically separate version is easier to reason about.
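The weight-count equality is a two-line calculation. One caveat the table glosses over: as written, the single-head version uses bias=False while the paper version's nn.Linear layers keep their default biases, so the two match in weights but differ by 3 × d_model bias terms. A sketch of the weight count per projection:

```python
d_model, num_heads = 512, 8
head_size = d_model // num_heads

# physically separate: each of num_heads heads owns a d_model × head_size matrix
physical = num_heads * d_model * head_size
# logically separate: one d_model × d_model matrix
logical = d_model * d_model

assert physical == logical == 262144   # holds separately for Wq, Wk, and Wv
```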
18.6 Transformer Block
18.6.1 Block Structure
Each block contains:
- LayerNorm → Multi-Head Attention → residual
- LayerNorm → FFN → residual
This is the Pre-Norm structure used by GPT-2.
18.6.2 Code
# Transformer Block definition
class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, head_size, context_length, dropout):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
self.mha = MultiHeadAttention(d_model, num_heads, head_size, context_length, dropout)
self.ffn = FeedForwardNetwork(d_model, dropout)
def forward(self, x):
# Attention + residual
x = x + self.mha(self.ln1(x))
# FFN + residual
x = x + self.ffn(self.ln2(x))
return x
18.6.3 Pre-Norm vs Post-Norm
Pre-Norm (what we use):
x = x + self.mha(self.ln1(x)) # normalize before Attention
Post-Norm (original Transformer):
x = self.ln1(x + self.mha(x)) # normalize after Attention
Pre-Norm trains more stably, especially as depth grows. This is why GPT-2, LLaMA, and most modern models use it.
18.7 The Complete Model Class
18.7.1 Model Structure
# Complete model definition
class Model(nn.Module):
def __init__(self, h_params):
super().__init__()
# read config from hyperparameter dictionary
self.context_length = h_params['context_length']
self.d_model = h_params['d_model']
self.num_blocks = h_params['num_blocks']
self.num_heads = h_params['num_heads']
self.head_size = self.d_model // self.num_heads
self.dropout = h_params['dropout']
self.device = h_params['device']
self.max_token_value = h_params['max_token_value']
# Token Embedding
self.token_embedding_lookup_table = nn.Embedding(self.max_token_value, self.d_model)
# Transformer Blocks + final LayerNorm
self.transformer_blocks = nn.Sequential(*(
[TransformerBlock(self.d_model, self.num_heads, self.head_size,
self.context_length, self.dropout)
for _ in range(self.num_blocks)] +
[nn.LayerNorm(self.d_model)]
))
# output projection layer
self.model_out_linear_layer = nn.Linear(self.d_model, self.max_token_value)
18.7.2 Forward Pass
def forward(self, idx, targets=None):
B, T = idx.shape
# 1. positional encoding (sinusoidal)
position_encoding_lookup_table = torch.zeros(self.context_length, self.d_model, device=self.device)
position = torch.arange(0, self.context_length, dtype=torch.float, device=self.device).unsqueeze(1)
div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float, device=self.device) * (-math.log(10000.0) / self.d_model))
position_encoding_lookup_table[:, 0::2] = torch.sin(position * div_term)
position_encoding_lookup_table[:, 1::2] = torch.cos(position * div_term)
position_embedding = position_encoding_lookup_table[:T, :].to(self.device)
# 2. token embedding + positional encoding
x = self.token_embedding_lookup_table(idx) + position_embedding
# 3. pass through all Transformer blocks
x = self.transformer_blocks(x)
# 4. project to vocabulary
logits = self.model_out_linear_layer(x)
# 5. if targets are provided (training mode), compute loss
if targets is not None:
B, T, C = logits.shape
logits_reshaped = logits.view(B * T, C)
targets_reshaped = targets.view(B * T)
loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)
else:
loss = None
return logits, loss
18.7.3 Key Code Notes
Positional encoding formula:
div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * (-math.log(10000.0) / self.d_model))
position_encoding_lookup_table[:, 0::2] = torch.sin(position * div_term)
position_encoding_lookup_table[:, 1::2] = torch.cos(position * div_term)
This is the sinusoidal position encoding from Chapter 5:
- even dimensions use sin
- odd dimensions use cos
- frequency decreases as dimension index increases
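A standalone check of the encoding table (my snippet, using the same d_model=512 assumption): at position 0, every even dimension is sin(0) = 0 and every odd dimension is cos(0) = 1, and all values stay in [-1, 1].

```python
import math
import torch

d_model, context_length = 512, 16
pe = torch.zeros(context_length, d_model)
position = torch.arange(0, context_length, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float)
                     * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

assert torch.all(pe[0, 0::2] == 0) and torch.all(pe[0, 1::2] == 1)
assert pe.abs().max() <= 1.0   # sin/cos values are bounded
```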
Loss function:
loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)
Cross-entropy loss measures the KL divergence between the model's predicted distribution and the true one-hot target. When a model is randomly initialized and the vocabulary has 50,000 tokens, this starts around ln(50000) ≈ 10.8. A well-trained model gets this below 3.
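The ln(vocab) starting point is easy to demonstrate (my check, not repository code): with all-zero logits the predicted distribution is uniform, so cross-entropy is exactly ln(V) regardless of the targets.

```python
import math
import torch
import torch.nn.functional as F

V = 50000
logits = torch.zeros(8, V)                     # uniform prediction over the vocabulary
targets = torch.randint(0, V, (8,))
loss = F.cross_entropy(logits, targets)
assert abs(loss.item() - math.log(V)) < 1e-3   # ln(50000) ≈ 10.82
```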
18.8 The Generation Function
18.8.1 Autoregressive Generation
def generate(self, idx, max_new_tokens=100, temperature=1.0, top_k=None):
"""
Autoregressive text generation.
Args:
idx: initial token IDs [B, T]
max_new_tokens: maximum new tokens to generate
temperature: controls output randomness
top_k: sample only from the top-k highest probability tokens
"""
for _ in range(max_new_tokens):
# 1. crop to maximum context length
idx_crop = idx[:, -self.context_length:]
# 2. forward pass
logits, loss = self.forward(idx_crop)
# 3. take last position logits only, apply temperature
logits = logits[:, -1, :] / temperature
# 4. optional: keep only top-k candidates
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
# 5. Softmax to get probabilities
probs = F.softmax(input=logits, dim=-1)
# 6. sample next token
idx_next = torch.multinomial(input=probs, num_samples=1)
# 7. append to sequence
idx = torch.cat((idx, idx_next), dim=1)
return idx
18.8.2 Temperature
Temperature from Chapter 6:
logits = logits[:, -1, :] / temperature
- T < 1: probabilities more concentrated — more deterministic
- T = 1: original distribution
- T > 1: probabilities more uniform — more random
For factual completions, use T ≈ 0.3. For creative generation, T ≈ 0.8 to 1.0.
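The effect is easy to see on a toy distribution (values are mine, for illustration): dividing logits by T < 1 sharpens the softmax, and T > 1 flattens it.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.0])
base  = F.softmax(logits,       dim=-1)   # T = 1.0: original distribution
sharp = F.softmax(logits / 0.5, dim=-1)   # T = 0.5: more deterministic
flat  = F.softmax(logits / 2.0, dim=-1)   # T = 2.0: more random
assert sharp.max() > base.max() > flat.max()
```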
18.8.3 Top-K Sampling
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
Keep only the k highest-probability tokens and set everything else to -inf. This prevents the model from generating low-probability tokens that are statistically unlikely and often incoherent.
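A toy run of the same three lines (example values are mine): with k=2, every token outside the top two ends up with probability exactly zero after the softmax.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
top_k = 2
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')    # mask everything below the k-th value
probs = F.softmax(logits, dim=-1)
assert probs[0, 2:].sum() == 0                 # only the top-2 tokens keep mass
assert abs(probs[0, :2].sum().item() - 1.0) < 1e-6
```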
18.9 Complete model.py
"""
Transformer Decoder-only base model for text generation
"""
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
class FeedForwardNetwork(nn.Module):
def __init__(self, d_model, dropout):
super().__init__()
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.ReLU(),
nn.Linear(d_model * 4, d_model),
nn.Dropout(dropout)
)
def forward(self, x):
return self.ffn(x)
class Attention(nn.Module):
def __init__(self, d_model, head_size, context_length, dropout):
super().__init__()
self.head_size = head_size
self.Wq = nn.Linear(d_model, head_size, bias=False)
self.Wk = nn.Linear(d_model, head_size, bias=False)
self.Wv = nn.Linear(d_model, head_size, bias=False)
self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))
self.dropout = nn.Dropout(dropout)
def forward(self, x):
B, T, C = x.shape
q = self.Wq(x)
k = self.Wk(x)
v = self.Wv(x)
weights = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_size)
weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
weights = self.dropout(weights)
return weights @ v
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, head_size, context_length, dropout):
super().__init__()
self.heads = nn.ModuleList([
Attention(d_model, head_size, context_length, dropout)
for _ in range(num_heads)
])
self.projection_layer = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
head_outputs = torch.cat([head(x) for head in self.heads], dim=-1)
return self.dropout(self.projection_layer(head_outputs))
class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, head_size, context_length, dropout):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
self.mha = MultiHeadAttention(d_model, num_heads, head_size, context_length, dropout)
self.ffn = FeedForwardNetwork(d_model, dropout)
def forward(self, x):
x = x + self.mha(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
class Model(nn.Module):
def __init__(self, h_params):
super().__init__()
self.context_length = h_params['context_length']
self.d_model = h_params['d_model']
self.num_blocks = h_params['num_blocks']
self.num_heads = h_params['num_heads']
self.head_size = self.d_model // self.num_heads
self.dropout = h_params['dropout']
self.device = h_params['device']
self.max_token_value = h_params['max_token_value']
self.token_embedding_lookup_table = nn.Embedding(self.max_token_value, self.d_model)
self.transformer_blocks = nn.Sequential(*(
[TransformerBlock(self.d_model, self.num_heads, self.head_size,
self.context_length, self.dropout)
for _ in range(self.num_blocks)] +
[nn.LayerNorm(self.d_model)]
))
self.model_out_linear_layer = nn.Linear(self.d_model, self.max_token_value)
def forward(self, idx, targets=None):
B, T = idx.shape
# Positional Encoding
position_encoding = torch.zeros(self.context_length, self.d_model, device=self.device)
position = torch.arange(0, self.context_length, dtype=torch.float, device=self.device).unsqueeze(1)
div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float, device=self.device) * (-math.log(10000.0) / self.d_model))
position_encoding[:, 0::2] = torch.sin(position * div_term)
position_encoding[:, 1::2] = torch.cos(position * div_term)
x = self.token_embedding_lookup_table(idx) + position_encoding[:T, :].to(self.device)
x = self.transformer_blocks(x)
logits = self.model_out_linear_layer(x)
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
else:
loss = None
return logits, loss
def generate(self, idx, max_new_tokens=100, temperature=1.0, top_k=None):
for _ in range(max_new_tokens):
idx_crop = idx[:, -self.context_length:]
logits, _ = self.forward(idx_crop)
logits = logits[:, -1, :] / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat((idx, idx_next), dim=1)
return idx
18.10 Chapter Summary
18.10.1 Code-to-Concept Mapping
| Class | Concept | Chapter |
|---|---|---|
FeedForwardNetwork | feed forward network | Ch. 7 |
Attention | single-head attention | Ch. 9-12 |
MultiHeadAttention | multi-head attention | Ch. 11 |
TransformerBlock | Transformer block | Ch. 13 |
Model | complete model | Ch. 15 |
18.10.2 Parameter Count Estimate
Assuming d_model=512, num_heads=8, num_blocks=6, vocab_size=50,000:
| Component | Formula | Parameters |
|---|---|---|
| Token Embedding | vocab × d_model | ~25.6M |
| Attention (×6) | 4 × d_model² × 6 | ~6.3M |
| FFN (×6) | 2 × d_model × 4×d_model × 6 | ~12.6M |
| Output Linear | d_model × vocab | ~25.6M |
Total: approximately 70M parameters
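The table's arithmetic can be replayed in a few lines (a formula-level estimate that, like the table, ignores biases and LayerNorm parameters):

```python
d_model, num_blocks, vocab = 512, 6, 50000

token_embedding = vocab * d_model                 # ~25.6M
attention = 4 * d_model**2 * num_blocks           # Wq, Wk, Wv, Wo per block: ~6.3M
ffn = 2 * d_model * (4 * d_model) * num_blocks    # two linear layers per block: ~12.6M
output_linear = d_model * vocab                   # ~25.6M

total = token_embedding + attention + ffn + output_linear
assert total == 70_074_368                        # ≈ 70M
```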
18.10.3 Core Insight
model.py is just the components from all the earlier chapters connected with PyTorch. Each class maps to one concept: FFN, Attention, MultiHeadAttention, TransformerBlock, Model. Understanding the concepts makes the code obvious, not the reverse.
Chapter Checklist
After this chapter you should be able to:
- Implement FeedForwardNetwork independently.
- Implement Attention (including the Causal Mask) independently.
- Implement MultiHeadAttention independently.
- Explain the difference between physically separate and logically separate MHA implementations.
- Explain the complete Model.forward() data flow.
Complete Code
The complete implementation is on GitHub: github.com/waylandzhang/Transformer-from-scratch
Includes model.py, train.py, inference.py, and a step-by-step Jupyter notebook.
See You in the Next Chapter
The model exists now, but it does not know anything — every parameter is randomly initialized. Give it a prompt and it will output noise.
Chapter 19 writes the training loop: load data, forward pass, compute loss, backprop, update parameters. By the end, the model actually learns to predict the next token.