One-sentence summary: model.py is just all the components from earlier chapters — Embedding, Positional Encoding, Multi-Head Attention, FFN, LayerNorm — wired together with PyTorch. Each class maps directly to a formula.

Complete code repository: github.com/waylandzhang/Transformer-from-scratch


18.1 Before Writing Code: Overall Structure

18.1.1 What We Are Implementing

Model (complete model)
├── Token Embedding
├── Positional Encoding
├── N × TransformerBlock
│   ├── LayerNorm
│   ├── Multi-Head Attention
│   ├── LayerNorm
│   └── Feed Forward Network
├── Final LayerNorm
└── Output Linear (projection to vocabulary)

18.1.2 File Structure

Everything lives in a single model.py:

# model.py overall structure
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

class FeedForwardNetwork(nn.Module):     # FFN
    ...

class Attention(nn.Module):              # single attention head
    ...

class MultiHeadAttention(nn.Module):     # multi-head attention
    ...

class TransformerBlock(nn.Module):       # Transformer block
    ...

class Model(nn.Module):                  # complete model
    ...

I prefer writing this file from scratch once before reaching for higher-level libraries. After that, the libraries feel like productivity tools instead of magic curtains.


18.2 Feed Forward Network

18.2.1 Recap of FFN Structure

From Chapter 15: FFN is a two-layer fully connected network:

Input [batch, seq, d_model]
     |
Linear1: d_model -> d_model × 4    (4× expansion)
     |
ReLU activation
     |
Linear2: d_model × 4 -> d_model    (back to model width)
     |
Dropout
     |
Output [batch, seq, d_model]

The 4× expansion gives the model more representational capacity at each position before compressing back. FFN is where most of the model's "knowledge" is stored.

18.2.2 Code

# Feed Forward Network definition
class FeedForwardNetwork(nn.Module):
    def __init__(self, d_model, dropout):
        super().__init__()
        self.d_model = d_model
        self.dropout = dropout
        self.ffn = nn.Sequential(
            nn.Linear(self.d_model, self.d_model * 4),  # expand 
            nn.ReLU(),                                   # activation
            nn.Linear(self.d_model * 4, self.d_model),  # compress back
            nn.Dropout(self.dropout)                     # regularization
        )

    def forward(self, x):
        return self.ffn(x)

18.2.3 Code Walkthrough

Code | Role | Shape change
nn.Linear(d_model, d_model * 4) | first linear layer | [B, T, 512] → [B, T, 2048]
nn.ReLU() | non-linearity | unchanged
nn.Linear(d_model * 4, d_model) | second linear layer | [B, T, 2048] → [B, T, 512]
nn.Dropout(dropout) | random dropout for regularization | unchanged
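
A quick shape check of the FFN, restated inline so the snippet is self-contained (d_model=512 and dropout=0.1 are just the running example's values):

```python
import torch
import torch.nn as nn

# Minimal restatement of FeedForwardNetwork for a standalone shape check
ffn = nn.Sequential(
    nn.Linear(512, 512 * 4),   # expand: 512 -> 2048
    nn.ReLU(),                 # non-linearity
    nn.Linear(512 * 4, 512),   # compress: 2048 -> 512
    nn.Dropout(0.1),           # regularization
)

x = torch.randn(2, 16, 512)    # [B=2, T=16, d_model=512]
y = ffn(x)
print(y.shape)                 # torch.Size([2, 16, 512]) -- shape preserved
```

The input and output shapes match, which is what lets the block sit inside a residual connection later.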

18.3 Attention (Single Head)

18.3.1 Recap of the Attention Formula

Attention(Q, K, V) = softmax(QK^T / √d_k) V

In code we need to implement:

  1. Generate Q, K, V via linear projection
  2. Compute attention scores: Q @ K^T
  3. Scale by √d_k
  4. Apply Causal Mask (prevent attending to future positions)
  5. Softmax normalization
  6. Multiply by V to produce the output

18.3.2 Code

# Single-head Scaled Dot Product Attention
class Attention(nn.Module):
    def __init__(self, d_model, head_size, context_length, dropout):
        super().__init__()
        self.d_model = d_model
        self.head_size = head_size
        self.context_length = context_length
        self.dropout = dropout

        # linear projections for Q, K, V
        self.Wq = nn.Linear(self.d_model, self.head_size, bias=False)
        self.Wk = nn.Linear(self.d_model, self.head_size, bias=False)
        self.Wv = nn.Linear(self.d_model, self.head_size, bias=False)

        # Causal Mask: lower-triangular matrix
        self.register_buffer('mask', torch.tril(torch.ones(self.context_length, self.context_length)))

        self.dropout = nn.Dropout(self.dropout)

    def forward(self, x):
        B, T, C = x.shape  # Batch, Time (seq_len), Channels (d_model)

        # 1. generate Q, K, V
        q = self.Wq(x)  # [B, T, head_size]
        k = self.Wk(x)  # [B, T, head_size]
        v = self.Wv(x)  # [B, T, head_size]

        # 2. compute attention scores Q @ K^T and scale
        weights = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_size)
        # weights: [B, T, T]

        # 3. apply Causal Mask (set future positions to -inf)
        weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))

        # 4. Softmax normalization
        weights = F.softmax(weights, dim=-1)

        # 5. Dropout
        weights = self.dropout(weights)

        # 6. multiply by V
        output = weights @ v  # [B, T, head_size]

        return output

18.3.3 Key Code Notes

How the Causal Mask works:

self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))

torch.tril produces a lower-triangular matrix:

[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]

Position i can attend only to positions 0 through i. That is what "causal" means: information flows forward in time, and no position ever sees the future.
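
A tiny standalone demo of the mask in action, using all-equal raw scores so the effect is easy to read:

```python
import torch
import torch.nn.functional as F

T = 4
mask = torch.tril(torch.ones(T, T))            # lower-triangular causal mask
scores = torch.zeros(T, T)                     # pretend all raw scores are equal
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)
print(weights)
# Row i spreads probability uniformly over positions 0..i only:
# row 0 -> [1, 0, 0, 0], row 1 -> [0.5, 0.5, 0, 0], ...
```

Masked positions receive exactly zero weight after softmax, because exp(-inf) = 0.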

Why register_buffer?

The mask is not a trainable parameter — it never updates. But it needs to travel with the model to whatever device (CPU or GPU) the model is on. register_buffer is exactly the right tool for that: not a parameter, but a persistent tensor.
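
A minimal sketch of the distinction (the Masked class here is a throwaway example, not part of model.py):

```python
import torch
import torch.nn as nn

class Masked(nn.Module):
    def __init__(self, T=4):
        super().__init__()
        # registered as a buffer: saved, moved with .to(device), but not trained
        self.register_buffer('mask', torch.tril(torch.ones(T, T)))

m = Masked()
print([n for n, _ in m.named_parameters()])  # [] -- no trainable parameters
print([n for n, _ in m.named_buffers()])     # ['mask'] -- but the tensor is tracked
print('mask' in m.state_dict())              # True -- and it is saved with the model
```

Had we stored the mask as a plain attribute (self.mask = torch.tril(...)), model.to('cuda') would leave it on the CPU and the masked_fill in forward() would fail with a device mismatch.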


18.4 Multi-Head Attention

18.4.1 The Multi-Head Idea

Multi-Head Attention = multiple single-head attention instances running in parallel, outputs concatenated at the end.

Each head can attend to different patterns. One head might track subject-verb agreement, another might track pronoun-referent relationships. Concatenating them gives richer representations than any single head could provide.

18.4.2 Code

# Multi-Head Attention definition
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, head_size, context_length, dropout):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_size = head_size
        self.context_length = context_length
        self.dropout = dropout

        # create multiple attention heads
        self.heads = nn.ModuleList([
            Attention(self.d_model, self.head_size, self.context_length, self.dropout)
            for _ in range(self.num_heads)
        ])

        # output projection Wo
        self.projection_layer = nn.Linear(self.d_model, self.d_model)
        self.dropout = nn.Dropout(self.dropout)

    def forward(self, x):
        # run all heads in parallel
        head_outputs = [head(x) for head in self.heads]

        # concatenate all head outputs
        head_outputs = torch.cat(head_outputs, dim=-1)  # [B, T, num_heads * head_size] = [B, T, d_model]

        # apply output projection
        out = self.dropout(self.projection_layer(head_outputs))

        return out

18.4.3 Shape Tracking

Assuming d_model=512, num_heads=8, head_size=64:

Input x: [B, T, 512]
     |
Each head output: [B, T, 64]  # 8 heads
     |
Concatenate: [B, T, 512]  # 64 × 8 = 512
     |
Wo projection: [B, T, 512]
     |
Output: [B, T, 512]

The key relationship: head_size = d_model // num_heads
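
A quick sanity check of that relationship, concatenating num_heads dummy head outputs with the shapes from the example above:

```python
import torch

d_model, num_heads = 512, 8
head_size = d_model // num_heads               # 64
assert head_size * num_heads == d_model        # concatenation restores d_model

B, T = 2, 16
head_outputs = [torch.randn(B, T, head_size) for _ in range(num_heads)]
concat = torch.cat(head_outputs, dim=-1)       # stack heads along the last dim
print(concat.shape)                            # torch.Size([2, 16, 512])
```

If d_model were not divisible by num_heads, the concatenated width would no longer equal d_model and the Wo projection would fail.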


18.5 Paper-Version Multi-Head Attention

18.5.1 Two Implementations Compared

The implementation above is physically separate: each head has its own Wq, Wk, Wv matrices.

The original "Attention Is All You Need" paper implementation is logically separate: one large linear layer, then reshape into multiple heads.

18.5.2 Paper-Version Code

# Paper-style Multi-Head Attention (logical split)
class MultiHeadAttention_Paper(nn.Module):
    def __init__(self, d_model, num_heads, head_size, context_length, dropout):
        super().__init__()
        self.context_length = context_length
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_size = head_size

        # one large linear layer, output dim still d_model
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.tril(torch.ones(self.context_length, self.context_length)))

    def split_heads(self, x):
        """Logically split into multiple heads"""
        batch_size = x.shape[0]
        context_length = x.shape[1]
        # [B, T, d_model] -> [B, T, num_heads, head_size] -> [B, num_heads, T, head_size]
        x = x.reshape(batch_size, context_length, self.num_heads, self.head_size)
        x = x.permute(0, 2, 1, 3)
        return x

    def forward(self, x):
        B, T, C = x.shape

        # project then split
        q = self.split_heads(self.Wq(x))  # [B, num_heads, T, head_size]
        k = self.split_heads(self.Wk(x))
        v = self.split_heads(self.Wv(x))

        # compute attention
        weights = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_size)
        weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)

        output = weights @ v  # [B, num_heads, T, head_size]

        # merge heads: [B, num_heads, T, head_size] -> [B, T, d_model]
        output = output.transpose(1, 2).reshape(-1, T, C)

        # output projection
        output = self.Wo(output)

        return output

18.5.3 Comparison

Aspect | Physically separate | Logically separate (paper)
Wq/Wk/Wv count | num_heads each | 1 each
Parameter count | identical | identical
Compute efficiency | slightly lower (loop) | higher (GPU-parallel)
Code clarity | cleaner for learning | slightly more complex

Why parameter counts are the same (counting the Q/K/V weight matrices, ignoring biases):

  • Physically separate: num_heads × 3 × (d_model × head_size) = 3 × d_model²
  • Logically separate: 3 × (d_model × d_model) = 3 × d_model²

In practice, the paper version is faster because it enables GPU-level parallelism. For educational purposes, the physically separate version is easier to reason about.
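
The equivalence is easy to check with a little arithmetic (weight matrices only; the paper-style nn.Linear layers above also carry small bias vectors):

```python
d_model, num_heads = 512, 8
head_size = d_model // num_heads               # 64

# Physically separate: each head has its own Wq, Wk, Wv of shape [d_model, head_size]
per_head = 3 * d_model * head_size             # 3 projections per head
separate = num_heads * per_head                # 8 heads

# Logically separate: three fused [d_model, d_model] projections
fused = 3 * d_model * d_model

print(separate, fused, separate == fused)      # 786432 786432 True
```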


18.6 Transformer Block

18.6.1 Block Structure

Transformer Block: Pre-Norm residual structure

Each block contains:

  1. LayerNorm → Multi-Head Attention → residual
  2. LayerNorm → FFN → residual

This is the Pre-Norm structure used by GPT-2.

18.6.2 Code

# Transformer Block definition
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, head_size, context_length, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.mha = MultiHeadAttention(d_model, num_heads, head_size, context_length, dropout)
        self.ffn = FeedForwardNetwork(d_model, dropout)

    def forward(self, x):
        # Attention + residual
        x = x + self.mha(self.ln1(x))
        # FFN + residual
        x = x + self.ffn(self.ln2(x))
        return x

18.6.3 Pre-Norm vs Post-Norm

Pre-Norm (what we use):

x = x + self.mha(self.ln1(x))  # normalize before Attention

Post-Norm (original Transformer):

x = self.ln1(x + self.mha(x))  # normalize after Attention

Pre-Norm trains more stably, especially for deep stacks, which is why GPT-2, LLaMA, and most modern LLMs use it.


18.7 The Complete Model Class

18.7.1 Model Structure

# Complete model definition
class Model(nn.Module):
    def __init__(self, h_params):
        super().__init__()
        # read config from hyperparameter dictionary
        self.context_length = h_params['context_length']
        self.d_model = h_params['d_model']
        self.num_blocks = h_params['num_blocks']
        self.num_heads = h_params['num_heads']
        self.head_size = self.d_model // self.num_heads
        self.dropout = h_params['dropout']
        self.device = h_params['device']
        self.max_token_value = h_params['max_token_value']

        # Token Embedding
        self.token_embedding_lookup_table = nn.Embedding(self.max_token_value, self.d_model)

        # Transformer Blocks + final LayerNorm
        self.transformer_blocks = nn.Sequential(*(
            [TransformerBlock(self.d_model, self.num_heads, self.head_size,
                              self.context_length, self.dropout)
             for _ in range(self.num_blocks)] +
            [nn.LayerNorm(self.d_model)]
        ))

        # output projection layer
        self.model_out_linear_layer = nn.Linear(self.d_model, self.max_token_value)

18.7.2 Forward Pass

def forward(self, idx, targets=None):
    B, T = idx.shape

    # 1. positional encoding (sinusoidal)
    position_encoding_lookup_table = torch.zeros(self.context_length, self.d_model, device=self.device)
    position = torch.arange(0, self.context_length, dtype=torch.float, device=self.device).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float, device=self.device) * (-math.log(10000.0) / self.d_model))
    position_encoding_lookup_table[:, 0::2] = torch.sin(position * div_term)
    position_encoding_lookup_table[:, 1::2] = torch.cos(position * div_term)
    position_embedding = position_encoding_lookup_table[:T, :].to(self.device)

    # 2. token embedding + positional encoding
    x = self.token_embedding_lookup_table(idx) + position_embedding

    # 3. pass through all Transformer blocks
    x = self.transformer_blocks(x)

    # 4. project to vocabulary
    logits = self.model_out_linear_layer(x)

    # 5. if targets are provided (training mode), compute loss
    if targets is not None:
        B, T, C = logits.shape
        logits_reshaped = logits.view(B * T, C)
        targets_reshaped = targets.view(B * T)
        loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)
    else:
        loss = None

    return logits, loss

18.7.3 Key Code Notes

Positional encoding formula:

div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * (-math.log(10000.0) / self.d_model))
position_encoding_lookup_table[:, 0::2] = torch.sin(position * div_term)
position_encoding_lookup_table[:, 1::2] = torch.cos(position * div_term)

This is the sinusoidal position encoding from Chapter 5:

  • even dimensions use sin
  • odd dimensions use cos
  • frequency decreases as dimension index increases
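
The recipe can be checked standalone; at position 0 every even dimension is sin(0) = 0 and every odd dimension is cos(0) = 1:

```python
import math
import torch

d_model, context_length = 512, 16
pe = torch.zeros(context_length, d_model)
position = torch.arange(0, context_length, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float)
                     * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)   # even dims: sin
pe[:, 1::2] = torch.cos(position * div_term)   # odd dims: cos

print(pe[0, :4])   # tensor([0., 1., 0., 1.]) -- position 0 alternates 0/1
```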

Loss function:

loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)

Cross-entropy loss is the negative log-probability the model assigns to the correct token (for a one-hot target this coincides with the KL divergence between the predicted distribution and the target). When a model is randomly initialized and the vocabulary has 50,000 tokens, the loss starts around ln(50,000) ≈ 10.8. A well-trained model gets it below 3.
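
That starting point is easy to verify: with all-equal (fully uncertain) logits, cross-entropy equals ln(vocab) regardless of the target:

```python
import math
import torch
import torch.nn.functional as F

vocab = 50_000
logits = torch.zeros(8, vocab)                 # 8 positions, uniform over the vocab
targets = torch.randint(0, vocab, (8,))        # arbitrary target tokens
loss = F.cross_entropy(logits, targets)
print(loss.item(), math.log(vocab))            # both ~10.82
```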


18.8 The Generation Function

18.8.1 Autoregressive Generation

def generate(self, idx, max_new_tokens=100, temperature=1.0, top_k=None):
    """
    Autoregressive text generation.

    Args:
        idx: initial token IDs [B, T]
        max_new_tokens: maximum new tokens to generate
        temperature: controls output randomness
        top_k: sample only from the top-k highest probability tokens
    """
    for _ in range(max_new_tokens):
        # 1. crop to maximum context length
        idx_crop = idx[:, -self.context_length:]

        # 2. forward pass
        logits, loss = self.forward(idx_crop)

        # 3. take last position logits only, apply temperature
        logits = logits[:, -1, :] / temperature

        # 4. optional: keep only top-k candidates
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')

        # 5. Softmax to get probabilities
        probs = F.softmax(input=logits, dim=-1)

        # 6. sample next token
        idx_next = torch.multinomial(input=probs, num_samples=1)

        # 7. append to sequence
        idx = torch.cat((idx, idx_next), dim=1)

    return idx

18.8.2 Temperature

Temperature from Chapter 6:

logits = logits[:, -1, :] / temperature

  • T < 1: probabilities more concentrated — more deterministic
  • T = 1: original distribution
  • T > 1: probabilities more uniform — more random

For factual completions, use T ≈ 0.3. For creative generation, T ≈ 0.8 to 1.0.
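
A small demo of the effect on a made-up three-token logit vector:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.1])         # hypothetical raw scores
probs = {t: F.softmax(logits / t, dim=-1) for t in (0.3, 1.0, 2.0)}
for t, p in probs.items():
    print(t, p)
# Low T pushes almost all mass onto the top token; high T flattens the distribution.
```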

18.8.3 Top-K Sampling

if top_k is not None:
    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
    logits[logits < v[:, [-1]]] = -float('Inf')

Keep only the k highest-probability tokens and set everything else to -inf, so sampling never falls into the long tail of low-probability tokens, which tends to produce incoherent text.
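
The same filtering logic, run standalone on a made-up logits row (using masked_fill in place of the in-place indexing above, which is equivalent):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[4.0, 3.0, 2.0, 1.0, 0.5]])   # hypothetical scores
top_k = 2
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
# v[:, [-1]] is the k-th largest value; everything below it is suppressed
logits = logits.masked_fill(logits < v[:, [-1]], float('-inf'))
probs = F.softmax(logits, dim=-1)
print(probs)   # only the two largest logits keep nonzero probability
```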


18.9 Complete model.py

"""
Transformer Decoder-only base model for text generation
"""
import math
import torch
import torch.nn as nn
from torch.nn import functional as F


class FeedForwardNetwork(nn.Module):
    def __init__(self, d_model, dropout):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.ffn(x)


class Attention(nn.Module):
    def __init__(self, d_model, head_size, context_length, dropout):
        super().__init__()
        self.head_size = head_size
        self.Wq = nn.Linear(d_model, head_size, bias=False)
        self.Wk = nn.Linear(d_model, head_size, bias=False)
        self.Wv = nn.Linear(d_model, head_size, bias=False)
        self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        q = self.Wq(x)
        k = self.Wk(x)
        v = self.Wv(x)
        weights = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_size)
        weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)
        return weights @ v


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, head_size, context_length, dropout):
        super().__init__()
        self.heads = nn.ModuleList([
            Attention(d_model, head_size, context_length, dropout)
            for _ in range(num_heads)
        ])
        self.projection_layer = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        head_outputs = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.dropout(self.projection_layer(head_outputs))


class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, head_size, context_length, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.mha = MultiHeadAttention(d_model, num_heads, head_size, context_length, dropout)
        self.ffn = FeedForwardNetwork(d_model, dropout)

    def forward(self, x):
        x = x + self.mha(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class Model(nn.Module):
    def __init__(self, h_params):
        super().__init__()
        self.context_length = h_params['context_length']
        self.d_model = h_params['d_model']
        self.num_blocks = h_params['num_blocks']
        self.num_heads = h_params['num_heads']
        self.head_size = self.d_model // self.num_heads
        self.dropout = h_params['dropout']
        self.device = h_params['device']
        self.max_token_value = h_params['max_token_value']

        self.token_embedding_lookup_table = nn.Embedding(self.max_token_value, self.d_model)
        self.transformer_blocks = nn.Sequential(*(
            [TransformerBlock(self.d_model, self.num_heads, self.head_size,
                              self.context_length, self.dropout)
             for _ in range(self.num_blocks)] +
            [nn.LayerNorm(self.d_model)]
        ))
        self.model_out_linear_layer = nn.Linear(self.d_model, self.max_token_value)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # Positional Encoding
        position_encoding = torch.zeros(self.context_length, self.d_model, device=self.device)
        position = torch.arange(0, self.context_length, dtype=torch.float, device=self.device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float, device=self.device) * (-math.log(10000.0) / self.d_model))
        position_encoding[:, 0::2] = torch.sin(position * div_term)
        position_encoding[:, 1::2] = torch.cos(position * div_term)

        x = self.token_embedding_lookup_table(idx) + position_encoding[:T, :].to(self.device)
        x = self.transformer_blocks(x)
        logits = self.model_out_linear_layer(x)

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            loss = None
        return logits, loss

    def generate(self, idx, max_new_tokens=100, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_crop = idx[:, -self.context_length:]
            logits, _ = self.forward(idx_crop)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

18.10 Chapter Summary

18.10.1 Code-to-Concept Mapping

Class | Concept | Chapter
FeedForwardNetwork | feed forward network | Ch. 7
Attention | single-head attention | Ch. 9-12
MultiHeadAttention | multi-head attention | Ch. 11
TransformerBlock | Transformer block | Ch. 13
Model | complete model | Ch. 15

18.10.2 Parameter Count Estimate

Assuming d_model=512, num_heads=8, num_blocks=6, vocab_size=50,000:

Component | Formula | Parameters
Token Embedding | vocab × d_model | ~25.6M
Attention (×6) | 4 × d_model² × 6 | ~6.3M
FFN (×6) | 2 × d_model × (4 × d_model) × 6 | ~12.6M
Output Linear | d_model × vocab | ~25.6M

Total: approximately 70M parameters
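
The table's arithmetic, spelled out (weight matrices only; biases and LayerNorm parameters add a small amount on top):

```python
d_model, num_blocks, vocab = 512, 6, 50_000

token_embedding = vocab * d_model                   # 25,600,000
attention = 4 * d_model * d_model * num_blocks      # Wq, Wk, Wv, Wo per block: 6,291,456
ffn = 2 * d_model * (4 * d_model) * num_blocks      # two linear layers per block: 12,582,912
output_linear = d_model * vocab                     # 25,600,000

total = token_embedding + attention + ffn + output_linear
print(f"{total / 1e6:.1f}M")                        # 70.1M
```

Note that the two 25.6M terms are the embedding and output projection; weight tying (sharing one matrix for both) would cut the total by roughly a third.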

18.10.3 Core Insight

model.py is just the components from all the earlier chapters connected with PyTorch. Each class maps to one concept: FFN, Attention, MultiHeadAttention, TransformerBlock, Model. Understanding the concepts makes the code obvious, not the reverse.


Chapter Checklist

After this chapter you should be able to:

  • Implement FeedForwardNetwork independently.
  • Implement Attention (including Causal Mask) independently.
  • Implement MultiHeadAttention independently.
  • Explain the difference between physically separate and logically separate MHA implementations.
  • Explain the complete Model.forward() data flow.

Complete Code

The complete implementation is on GitHub:

github.com/waylandzhang/Transformer-from-scratch

Includes model.py, train.py, inference.py, and a step-by-step Jupyter notebook.


See You in the Next Chapter

The model exists now, but it does not know anything — every parameter is randomly initialized. Give it a prompt and it will output noise.

Chapter 19 writes the training loop: load data, forward pass, compute loss, backprop, update parameters. By the end, the model actually learns to predict the next token.

Cite this page
Zhang, Wayland (2026). Chapter 18: Writing model.py - Model Definition. In Transformer Architecture: From Intuition to Implementation. https://waylandz.com/llm-transformer-book-en/chapter-18-model-py
@incollection{zhang2026transformer_chapter_18_model_py,
  author = {Zhang, Wayland},
  title = {Chapter 18: Writing model.py - Model Definition},
  booktitle = {Transformer Architecture: From Intuition to Implementation},
  year = {2026},
  url = {https://waylandz.com/llm-transformer-book-en/chapter-18-model-py}
}