One-sentence summary: The training loop is four steps repeated: forward pass → compute loss → backpropagate → update parameters. In under 100 lines of code, it turns a randomly initialized model into one that can predict the next token.

Complete code repository: github.com/waylandzhang/Transformer-from-scratch


19.1 The Nature of Training

19.1.1 What Does a Model Know at Initialization?

A freshly created model has all of its parameters randomly initialized. Ask it to predict the next token and its output distribution is close to uniform over the vocabulary, so the tokens it generates are essentially noise.

# randomly initialized model
model = Model(h_params)

# input: "The agent opened a pull request", as a [1, seq_len] tensor of token IDs
input_ids = torch.tensor([tokenizer.encode("The agent opened a pull request")], dtype=torch.long)

# output: essentially random tokens
output = model.generate(input_ids)
# might produce: "The agent opened a pull request zxtq moon orbit..."

19.1.2 The Training Goal

Given large amounts of text, teach the model to predict the next token at every position:

Input:  The    agent   opened  a     pull     request
Target: agent  opened  a       pull  request  for

The model needs to learn:
- see "The"               -> predict "agent"
- see "The agent"         -> predict "opened"
- see "The agent opened"  -> predict "a"
- ...
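A minimal sketch of how one sequence yields these (context, target) pairs, in plain Python with words standing in for token IDs (a real tokenizer would produce integers). The helper name next_token_pairs is hypothetical:

```python
def next_token_pairs(tokens):
    """Pair every prefix of `tokens` with the token that follows it."""
    return [(tokens[: i + 1], tokens[i + 1]) for i in range(len(tokens) - 1)]


tokens = ["The", "agent", "opened", "a", "pull", "request", "for"]
for context, target in next_token_pairs(tokens):
    # first lines printed: "The -> agent", "The agent -> opened", ...
    print(" ".join(context), "->", target)
```

Seven tokens give six pairs: every position except the last has a known next token.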

19.1.3 The Four Training Steps

1. Forward pass:      feed input, get predictions
2. Compute loss:      how wrong are the predictions?
3. Backpropagate:     compute gradient of loss w.r.t. every parameter
4. Update parameters: move parameters in the direction that reduces loss

Repeat these four steps. Loss gradually decreases. The model gradually improves.
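The four steps can be seen end to end on a toy problem. This is a sketch, not the book's code: a one-weight nn.Linear model fits y = 2x + 1 from noisy synthetic data, using the same AdamW optimizer and the same zero_grad/backward/step sequence that train.py uses. The data, learning rate, and step count are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# toy data: y = 2x + 1 plus a little noise (a stand-in for real text batches)
x = torch.randn(64, 1)
y = 2 * x + 1 + 0.05 * torch.randn(64, 1)

model = nn.Linear(1, 1)                                    # one weight, one bias
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1)

losses = []
for step in range(300):
    pred = model(x)                          # 1. forward pass
    loss = F.mse_loss(pred, y)               # 2. compute loss
    optimizer.zero_grad(set_to_none=True)    #    clear old gradients
    loss.backward()                          # 3. backpropagate
    optimizer.step()                         # 4. update parameters
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.4f}")
print(f"learned w={model.weight.item():.2f}, b={model.bias.item():.2f}")
```

Swap the toy model for the Transformer, the synthetic data for get_batch(), and F.mse_loss for cross-entropy, and this has the same structure as the loop in 19.7.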


19.2 Hyperparameter Configuration

19.2.1 Hyperparameter Dictionary

import torch

# hyperparameter configuration
h_params = {
    # model architecture
    "d_model": 80,           # embedding dimension (small value for educational model)
    "num_blocks": 6,         # number of Transformer blocks
    "num_heads": 4,          # number of attention heads

    # training configuration
    "batch_size": 2,         # samples per training step
    "context_length": 128,   # context length (sequence length)
    "max_iters": 500,        # total training steps
    "learning_rate": 1e-3,   # learning rate

    # regularization
    "dropout": 0.1,          # Dropout probability

    # evaluation configuration
    "eval_interval": 50,     # evaluate every N steps
    "eval_iters": 10,        # batches to use per evaluation

    # device
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    # random seed (for reproducibility)
    "TORCH_SEED": 1337
}

19.2.2 Key Hyperparameters Explained

Hyperparameter   Role                                     Typical range
batch_size       samples per training step                2-32 (limited by VRAM)
context_length   how many tokens the model sees at once   128-2048
learning_rate    parameter update step size               1e-5 to 1e-3
max_iters        total training steps                     hundreds to millions
dropout          random drop probability                  0.1-0.3
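A quick sanity check on what this configuration implies. The arithmetic below is illustrative and only uses values from the dictionary above; the divisibility check assumes the usual multi-head design in which each head gets d_model // num_heads dimensions.

```python
h_params = {
    "d_model": 80,
    "num_heads": 4,
    "batch_size": 2,
    "context_length": 128,
    "max_iters": 500,
}

# usual multi-head design: each head gets d_model // num_heads dimensions,
# so d_model must divide evenly by num_heads
assert h_params["d_model"] % h_params["num_heads"] == 0
head_size = h_params["d_model"] // h_params["num_heads"]

# every step predicts one next token per position in every sequence of the batch
tokens_per_step = h_params["batch_size"] * h_params["context_length"]
total_tokens = tokens_per_step * h_params["max_iters"]

print(head_size, tokens_per_step, total_tokens)   # 20 256 128000
```

With the roughly 70,000-token training split built in 19.3 (90% of 77,919 tokens), 128,000 predicted positions means each training token serves as a target fewer than two times on average, so this run is short by design.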

19.3 Data Preparation

19.3.1 Load Raw Text

# load training data (the file name means "order product names")
with open('data/订单商品名称.csv', 'r', encoding="utf-8") as file:
    text = file.read()

print(f"Text length: {len(text):,} characters")
# output: Text length: 324,523 characters

19.3.2 Tokenization

# tokenize with TikToken
import tiktoken

tokenizer = tiktoken.get_encoding("cl100k_base")
tokenized_text = tokenizer.encode(text)

print(f"Token count: {len(tokenized_text):,}")
# output: Token count: 77,919

19.3.3 Convert to Tensor and Split Dataset

# convert to PyTorch Tensor
tokenized_text = torch.tensor(tokenized_text, dtype=torch.long, device=h_params['device'])

# 90% train, 10% validation
train_size = int(len(tokenized_text) * 0.9)
train_data = tokenized_text[:train_size]
val_data = tokenized_text[train_size:]

print(f"Training set: {len(train_data):,} tokens")
print(f"Validation set: {len(val_data):,} tokens")

19.3.4 Batch Sampling

# randomly sample a batch
def get_batch(split: str):
    """
    Sample one training batch.

    Args:
        split: 'train' or 'valid'

    Returns:
        x: input  [batch_size, context_length]
        y: target [batch_size, context_length]  (the same window offset by one position)
    """
    data = train_data if split == 'train' else val_data

    # randomly sample starting positions
    idxs = torch.randint(
        low=0,
        high=len(data) - h_params['context_length'],
        size=(h_params['batch_size'],)
    )

    # build input and target
    x = torch.stack([data[idx:idx + h_params['context_length']] for idx in idxs])
    y = torch.stack([data[idx + 1:idx + h_params['context_length'] + 1] for idx in idxs])

    return x.to(h_params['device']), y.to(h_params['device'])

19.3.5 Understanding the x and y Relationship

Assume context_length = 8

Raw data: [The, agent, opened, a, pull, request, for, review, .]
              |
x (input):  [The, agent, opened, a, pull, request, for, review]
y (target): [agent, opened, a, pull, request, for, review, .]

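y is the same window shifted one position to the right in the raw data, so y[i] is the token that follows x[i]. Because attention is causal, the model predicts y[i] from x[0..i], not from x[i] alone.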

Every training sequence therefore provides 8 training examples at once, one per position, with contexts growing from 1 token (The -> agent) to 8 tokens (The ... review -> .).
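To see the offset concretely, here is the same sampling logic as get_batch() applied to a toy tensor of consecutive integers, so every target is visibly one more than its input. The function name get_batch_from and its explicit arguments are hypothetical; they just replace the globals train_data and h_params.

```python
import torch


def get_batch_from(data: torch.Tensor, batch_size: int, context_length: int):
    """Same sampling logic as get_batch(), with the globals passed in explicitly."""
    idxs = torch.randint(low=0, high=len(data) - context_length, size=(batch_size,))
    x = torch.stack([data[i:i + context_length] for i in idxs])
    y = torch.stack([data[i + 1:i + context_length + 1] for i in idxs])
    return x, y


torch.manual_seed(0)
data = torch.arange(100)          # toy "token IDs" 0..99 make the offset easy to see
x, y = get_batch_from(data, batch_size=2, context_length=8)

print(x.shape, y.shape)           # torch.Size([2, 8]) torch.Size([2, 8])
print(x[0])
print(y[0])                       # each entry is one more than the matching x entry
```

Note that x[:, 1:] and y[:, :-1] are identical: the two windows overlap in all but one position.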


19.4 Loss Function

19.4.1 Cross-Entropy Loss

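The model outputs a score (logit) for every vocabulary token at every position; softmax turns these scores into a probability distribution. Cross-entropy loss measures the gap between prediction and reality: it is the negative log-probability the model assigns to the true next token, averaged over all positions: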

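import torch.nn.functional as F

# logits: [batch, seq, vocab_size] from the model; targets: [batch, seq] (y from get_batch)
B, T, C = logits.shape
loss = F.cross_entropy(
    input=logits.view(B * T, C),    # model predictions [batch*seq, vocab_size]
    target=targets.view(B * T)      # true targets [batch*seq]
)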
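To check what F.cross_entropy computes, this sketch compares it with the hand-written definition (the mean over all positions of -log p(correct token)) on small random tensors. The shapes are arbitrary stand-ins for real model outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 3, 5                        # batch, sequence length, vocabulary size
logits = torch.randn(B, T, V)            # stand-in for model output
targets = torch.randint(0, V, (B, T))    # stand-in for y from get_batch()

# flatten batch and time so every position is one classification example
loss = F.cross_entropy(logits.view(B * T, V), targets.view(B * T))

# the same number by hand: mean over positions of -log p(correct token)
log_probs = F.log_softmax(logits, dim=-1)
manual = -log_probs.gather(-1, targets.unsqueeze(-1)).mean()

print(loss.item(), manual.item())        # the two values match
```

The reshape is needed because F.cross_entropy expects class scores in dimension 1; flattening batch and time turns every position into an independent row.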

19.4.2 What Loss Values Mean

  • Random initialization: loss ≈ 10-11 (close to ln(vocab_size))
  • After training: loss can reach 2-4
  • Overfitting: training loss low, validation loss rising

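At initialization the logits are small and nearly equal, so the predicted distribution is close to uniform. A uniform distribution over V tokens gives the correct token probability 1/V, so the loss is -ln(1/V) = ln(V), where V is the number of entries in the model's output layer. That is about 10.8 for V ≈ 50,000 and about 11.5 for V ≈ 100,000 (roughly the size of the full cl100k_base vocabulary).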
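A quick numerical check of the ln(V) baseline: with all-zero logits the softmax is exactly uniform, so the cross-entropy should equal ln(V) whatever the targets are. The two vocabulary sizes used here, 50,000 and 100,000, are illustrative round numbers.

```python
import math

import torch
import torch.nn.functional as F


def uniform_loss(vocab_size: int, num_positions: int = 4) -> float:
    """Cross-entropy when every token gets the same score (a perfectly uniform prediction)."""
    logits = torch.zeros(num_positions, vocab_size)            # identical scores -> uniform softmax
    targets = torch.randint(0, vocab_size, (num_positions,))   # which token is correct does not matter
    return F.cross_entropy(logits, targets).item()


for V in (50_000, 100_000):
    print(V, round(uniform_loss(V), 3), round(math.log(V), 3))   # the two columns agree
```

A step-0 loss far above ln(V) usually means the initial logits are too large (the model is confidently wrong), which is worth checking before training for long.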


19.5 Evaluation Function

19.5.1 Why Evaluate Separately?

Falling training loss does not guarantee that the model generalizes: it might simply be memorizing the training set. We need to check performance on validation data the model has never seen.

19.5.2 Evaluation Code

# evaluation function
@torch.no_grad()  # skip gradient computation to save memory
def estimate_loss():
    out = {}
    model.eval()  # switch to evaluation mode (disables Dropout)

    for split in ['train', 'valid']:
        losses = torch.zeros(h_params['eval_iters'])

        for k in range(h_params['eval_iters']):
            x_batch, y_batch = get_batch(split)
            logits, loss = model(x_batch, y_batch)
            losses[k] = loss.item()

        out[split] = losses.mean()

    model.train()  # switch back to training mode
    return out

19.5.3 model.train() vs model.eval()

Mode            Dropout                      BatchNorm
model.train()   randomly drops activations   uses batch statistics
model.eval()    no dropping                  uses stored statistics

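Evaluation must use model.eval(); otherwise Dropout adds random variation to the outputs and the loss estimate becomes unreliable. Transformers normally use LayerNorm rather than BatchNorm, and LayerNorm behaves identically in both modes, so for a model like this one Dropout is the layer the switch actually affects. The @torch.no_grad() decorator is a separate matter: it skips building the computation graph, which saves memory but does not change the outputs.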
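A small demonstration of what the mode switch changes, using a standalone nn.Dropout layer (p=0.5 is chosen only to make the effect obvious). In training mode PyTorch zeroes elements at random and scales the survivors by 1/(1-p) so the expected value is unchanged; in evaluation mode the layer passes its input through untouched. model.train() and model.eval() set this flag on every submodule.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(10)

dropout.train()                  # the flag model.train() sets on every submodule
out_train = dropout(x)           # some entries are 0, the rest are 1 / (1 - 0.5) = 2
print(out_train)

dropout.eval()                   # the flag model.eval() sets on every submodule
out_eval = dropout(x)            # identity: all ones
print(out_eval)
```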


19.6 Optimizer

19.6.1 AdamW

# create optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=h_params['learning_rate']
)

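AdamW combines:

  • Momentum: keeps an exponential moving average of past gradients, which smooths the update direction
  • Adaptive learning rate: divides each update by a running average of squared gradients, so every parameter gets its own effective step size
  • Decoupled weight decay: shrinks the weights directly at each step, separately from the gradient update (unlike L2 regularization added to the loss, whose effect Adam's per-parameter scaling would distort); this helps reduce overfitting. PyTorch's AdamW uses weight_decay=0.01 by default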
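To make the three ingredients concrete, here is a sketch of one AdamW update written out by hand and checked against torch.optim.AdamW. The helper adamw_step is hypothetical; its default hyperparameters (lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2) mirror PyTorch's documented defaults, and float64 is used only so the comparison is not muddied by rounding.

```python
import torch


def adamw_step(p, g, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    """One AdamW update written out by hand (hypothetical helper, for illustration only)."""
    beta1, beta2 = betas
    p = p * (1 - lr * weight_decay)            # decoupled weight decay: shrink the weights directly
    m = beta1 * m + (1 - beta1) * g            # momentum: moving average of gradients
    v = beta2 * v + (1 - beta2) * g * g        # moving average of squared gradients
    m_hat = m / (1 - beta1 ** t)               # bias correction (m and v start at zero)
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (v_hat.sqrt() + eps)  # adaptive, per-parameter step
    return p, m, v


torch.manual_seed(0)
grads = [torch.randn(3, dtype=torch.float64) for _ in range(5)]
start = torch.randn(3, dtype=torch.float64)

# reference: PyTorch's AdamW with its default hyperparameters
param = torch.nn.Parameter(start.clone())
optimizer = torch.optim.AdamW([param])
for g in grads:
    param.grad = g.clone()
    optimizer.step()

# the hand-written version, starting from the same values
p = start.clone()
m = torch.zeros_like(start)
v = torch.zeros_like(start)
for t, g in enumerate(grads, start=1):
    p, m, v = adamw_step(p, g, m, v, t)

print(torch.allclose(p, param.detach()))   # True
```

The decay term acts on the weights directly and never enters m or v; that separation is what "decoupled" refers to.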

19.6.2 Why AdamW?

Optimizer   Pros                                     Cons
SGD         simple, good generalization              slow convergence
Adam        fast convergence                         can generalize worse
AdamW       fast convergence + good generalization   slightly more complex

AdamW is the default choice for most modern large-model training. For this educational model it also typically converges faster than plain SGD.


19.7 Training Loop

19.7.1 Complete Training Loop

# training loop
for step in range(h_params['max_iters']):

    # periodic evaluation
    if step % h_params['eval_interval'] == 0 or step == h_params['max_iters'] - 1:
        losses = estimate_loss()
        print(f'Step: {step}, '
              f'Training Loss: {losses["train"]:.3f}, '
              f'Validation Loss: {losses["valid"]:.3f}')

    # sample a batch
    xb, yb = get_batch('train')

    # 1-2. forward pass; the model also computes the loss when targets are passed
    logits, loss = model(xb, yb)

    # 3. backpropagation
    optimizer.zero_grad(set_to_none=True)  # clear gradients
    loss.backward()                         # compute gradients

    # 4. update parameters
    optimizer.step()

19.7.2 Each Step Explained

optimizer.zero_grad(): Clear the gradients from the previous step.

PyTorch accumulates gradients by default: each backward() call adds to the existing .grad values. If you do not clear them, every update uses the sum of all gradients computed so far, which is wrong. set_to_none=True sets the gradients to None instead of filling them with zeros, which saves a little memory and work; it has been the default since PyTorch 2.0.
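A tiny illustration of gradient accumulation with a single scalar parameter: calling backward() twice without clearing adds the two gradients together, and setting .grad to None (which is what zero_grad(set_to_none=True) does for each parameter) starts fresh.

```python
import torch

w = torch.tensor(3.0, requires_grad=True)

(w * w).backward()              # d(w^2)/dw = 2w = 6
first = w.grad.item()

(w * w).backward()              # second backward without clearing
accumulated = w.grad.item()     # 6 + 6 = 12: the new gradient was added to the old one

w.grad = None                   # effect of optimizer.zero_grad(set_to_none=True) on this parameter
(w * w).backward()
fresh = w.grad.item()           # back to 6

print(first, accumulated, fresh)   # 6.0 12.0 6.0
```

Accumulation is a deliberate design: it lets you sum gradients over several small batches before one optimizer.step(), but in the plain loop it must be cleared every step.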

loss.backward(): Run backpropagation through the computation graph.

This is where PyTorch's automatic differentiation earns its keep. It traces all operations from input to loss and computes the gradient of the loss with respect to every parameter, automatically.

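optimizer.step(): Apply the gradient update. In its simplest form (plain SGD) the rule is:

parameter_new = parameter_old - learning_rate × gradient

AdamW follows the same idea, but replaces the raw gradient with a momentum-smoothed, per-parameter-scaled version and also applies weight decay.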
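The plain-SGD rule can be checked directly against PyTorch: with torch.optim.SGD and no momentum or weight decay (the defaults), one step() produces parameter_old - learning_rate × gradient. The learning rate and tensors below are arbitrary.

```python
import torch

torch.manual_seed(0)
learning_rate = 0.1
param = torch.nn.Parameter(torch.randn(3))
gradient = torch.randn(3)

# the rule from the text, computed by hand before the update
expected = param.detach() - learning_rate * gradient

# the same update through PyTorch's optimizer
optimizer = torch.optim.SGD([param], lr=learning_rate)   # momentum=0, weight_decay=0 by default
param.grad = gradient.clone()
optimizer.step()

print(torch.allclose(param.detach(), expected))   # True
```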

19.8 Training Output Example

Step: 0, Training Loss: 10.847, Validation Loss: 10.852
Step: 50, Training Loss: 7.234, Validation Loss: 7.198
Step: 100, Training Loss: 5.421, Validation Loss: 5.456
Step: 150, Training Loss: 4.312, Validation Loss: 4.387
Step: 200, Training Loss: 3.876, Validation Loss: 3.921
Step: 250, Training Loss: 3.542, Validation Loss: 3.678
Step: 300, Training Loss: 3.298, Validation Loss: 3.512
Step: 350, Training Loss: 3.112, Validation Loss: 3.398
Step: 400, Training Loss: 2.987, Validation Loss: 3.287
Step: 450, Training Loss: 2.876, Validation Loss: 3.198
Step: 499, Training Loss: 2.798, Validation Loss: 3.145

What to observe:

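  • Training loss drops from ~10.8 to ~2.8, and validation loss falls with it to ~3.1: the model is genuinely learning
  • From step 100 onward, validation loss sits slightly above training loss, and the gap widens slowly (about 0.35 by step 499); this is normal, since the validation data is unseen
  • If validation loss starts rising while training loss keeps falling, the model is overfitting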
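The overfitting check in the last bullet can be automated. The function below is a hypothetical helper, not part of train.py: it flags the case where, over the last few evaluations, training loss kept falling while validation loss kept rising. The first history uses the last three evaluations from the log above; the second uses made-up numbers purely to exercise the other branch.

```python
def overfitting_signal(history, window=3):
    """
    Hypothetical helper: return True if, over the last `window` evaluations,
    training loss kept falling while validation loss kept rising.
    `history` is a list of (step, train_loss, valid_loss) tuples.
    """
    if len(history) < window:
        return False
    recent = history[-window:]
    train = [t for _, t, _ in recent]
    valid = [v for _, _, v in recent]
    train_falling = all(b < a for a, b in zip(train, train[1:]))
    valid_rising = all(b > a for a, b in zip(valid, valid[1:]))
    return train_falling and valid_rising


# last three evaluations from the run above: both losses still falling
log = [(400, 2.987, 3.287), (450, 2.876, 3.198), (499, 2.798, 3.145)]
print(overfitting_signal(log))       # False

# made-up continuation where validation loss turns upward
log_bad = [(600, 2.60, 3.10), (650, 2.50, 3.15), (700, 2.40, 3.22)]
print(overfitting_signal(log_bad))   # True
```

Requiring several consecutive evaluations avoids reacting to one noisy estimate, since estimate_loss() averages only eval_iters random batches.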

19.9 Saving the Model

19.9.1 Saving a Checkpoint

# save model
import os
import torch

os.makedirs('model', exist_ok=True)   # no error if the directory already exists

torch.save({
    'model_state_dict': model.state_dict(),
    'h_params': h_params
}, 'model/model.ckpt')

print("Model saved to model/model.ckpt")

19.9.2 What to Save

Content              Why
model.state_dict()   all model parameters
h_params             hyperparameters needed to reconstruct the model architecture

Always save the hyperparameters alongside the weights. Without them, you cannot rebuild the model to load the weights into at inference time.
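A minimal round-trip sketch of why the hyperparameters matter at load time: read h_params from the checkpoint, rebuild the model from them, then load the weights. To keep it self-contained, build_model is a hypothetical stand-in (a single nn.Linear) for Model(h_params) from model.py, and the checkpoint goes to a temporary directory instead of model/.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model(h_params):
    # hypothetical stand-in for Model(h_params) from model.py
    return nn.Linear(h_params["d_model"], h_params["d_model"])


torch.manual_seed(0)
h_params = {"d_model": 8, "device": "cpu"}
model = build_model(h_params)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.ckpt")
    torch.save({"model_state_dict": model.state_dict(), "h_params": h_params}, path)

    # inference side: hyperparameters first, then architecture, then weights
    checkpoint = torch.load(path, map_location="cpu")
    restored = build_model(checkpoint["h_params"])
    restored.load_state_dict(checkpoint["model_state_dict"])

same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)   # True
```

This is presumably also why train.py writes max_token_value into h_params before saving: a rebuilt model needs the same vocabulary size for the saved weight shapes to match.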


19.10 Complete train.py

"""
Train a Transformer model
"""
import os
import torch
import tiktoken
from model import Model

# GPU memory configuration
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
torch.cuda.empty_cache()

# hyperparameters
h_params = {
    "d_model": 80,
    "batch_size": 2,
    "context_length": 128,
    "num_blocks": 6,
    "num_heads": 4,
    "dropout": 0.1,
    "max_iters": 500,
    "learning_rate": 1e-3,
    "eval_interval": 50,
    "eval_iters": 10,
    "device": "cuda" if torch.cuda.is_available() else
              ("mps" if torch.backends.mps.is_available() else "cpu"),
    "TORCH_SEED": 1337
}
torch.manual_seed(h_params["TORCH_SEED"])

# load data
with open('data/订单商品名称.csv', 'r', encoding="utf-8") as file:
    text = file.read()

# tokenize
tokenizer = tiktoken.get_encoding("cl100k_base")
tokenized_text = tokenizer.encode(text)
max_token_value = max(tokenized_text) + 1
h_params['max_token_value'] = max_token_value
tokenized_text = torch.tensor(tokenized_text, dtype=torch.long, device=h_params['device'])

print(f"Total: {len(tokenized_text):,} tokens")

# split data
train_size = int(len(tokenized_text) * 0.9)
train_data = tokenized_text[:train_size]
val_data = tokenized_text[train_size:]

# initialize model
model = Model(h_params).to(h_params['device'])


def get_batch(split: str):
    data = train_data if split == 'train' else val_data
    idxs = torch.randint(low=0, high=len(data) - h_params['context_length'],
                         size=(h_params['batch_size'],))
    x = torch.stack([data[idx:idx + h_params['context_length']] for idx in idxs])
    y = torch.stack([data[idx + 1:idx + h_params['context_length'] + 1] for idx in idxs])
    return x.to(h_params['device']), y.to(h_params['device'])


@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'valid']:
        losses = torch.zeros(h_params['eval_iters'])
        for k in range(h_params['eval_iters']):
            x_batch, y_batch = get_batch(split)
            logits, loss = model(x_batch, y_batch)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


# training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=h_params['learning_rate'])

for step in range(h_params['max_iters']):
    if step % h_params['eval_interval'] == 0 or step == h_params['max_iters'] - 1:
        losses = estimate_loss()
        print(f'Step: {step}, Training Loss: {losses["train"]:.3f}, '
              f'Validation Loss: {losses["valid"]:.3f}')

    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# save model
os.makedirs('model', exist_ok=True)

torch.save({
    'model_state_dict': model.state_dict(),
    'h_params': h_params
}, 'model/model.ckpt')

print("Training complete. Model saved to model/model.ckpt")

19.11 Optional: WandB Training Tracking

19.11.1 What Is WandB?

Weights & Biases is a training monitoring tool. It can:

  • Visualize loss curves
  • Record hyperparameters
  • Compare across experiments

19.11.2 Integration Code

# WandB integration (optional)
import wandb

# initialize
run = wandb.init(
    project="LLMZhang_lesson_2",
    config={
        "d_model": h_params["d_model"],
        "batch_size": h_params["batch_size"],
        "context_length": h_params["context_length"],
        "max_iters": h_params["max_iters"],
        "learning_rate": h_params["learning_rate"],
    },
)

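# log right after each evaluation, where losses is computed
for step in range(h_params['max_iters']):
    if step % h_params['eval_interval'] == 0 or step == h_params['max_iters'] - 1:
        losses = estimate_loss()
        wandb.log({
            "train_loss": losses['train'].item(),
            "valid_loss": losses['valid'].item()
        }, step=step)
    ...  # rest of the training step, unchanged

wandb.finish()  # mark the run as complete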

WandB is optional for this educational model. For any run you care about repeating or comparing, it is worth the setup time.


19.12 Chapter Summary

19.12.1 Training Flow

1. Load data -> tokenize -> convert to Tensor -> split train/val

2. Training loop:
   for step in range(max_iters):
       x, y = get_batch('train')      # sample data
       logits, loss = model(x, y)     # forward pass
       optimizer.zero_grad()          # clear gradients
       loss.backward()                # backpropagation
       optimizer.step()               # update parameters

3. Save model -> torch.save()

19.12.2 Key Functions

Function          Role
get_batch()       randomly sample one batch
estimate_loss()   evaluate on train and val sets
model.train()     switch to training mode
model.eval()      switch to evaluation mode
loss.backward()   compute gradients via autodiff
optimizer.step()  update parameters

19.12.3 Core Insight

train.py is under 100 lines but implements a complete training pipeline. The core is the four-step loop: forward pass → compute loss → backpropagate → update parameters. PyTorch's automatic differentiation means you only need to define the forward pass — the backward pass follows automatically.


Chapter Checklist

After this chapter you should be able to:

  • Explain the four steps of the training loop.
  • Explain the relationship between x and y (shifted by one token).
  • Explain the difference between model.train() and model.eval().
  • Write a simple training script from scratch.

Complete Code

The complete implementation is on GitHub:

github.com/waylandzhang/Transformer-from-scratch

Includes model.py, train.py, inference.py, and a step-by-step Jupyter notebook.


See You in the Next Chapter

The model is trained. Parameters are saved to disk. Now we want to use it.

Chapter 20 writes inference.py: load the checkpoint, encode a prompt, let the model generate autoregressively, and decode the output back to text. That is the moment the model "speaks" for the first time.

Cite this page
Zhang, Wayland (2026). Chapter 19: Writing train.py - The Training Loop. In Transformer Architecture: From Intuition to Implementation. https://waylandz.com/llm-transformer-book-en/chapter-19-train-py
@incollection{zhang2026transformer_chapter_19_train_py,
  author = {Zhang, Wayland},
  title = {Chapter 19: Writing train.py - The Training Loop},
  booktitle = {Transformer Architecture: From Intuition to Implementation},
  year = {2026},
  url = {https://waylandz.com/llm-transformer-book-en/chapter-19-train-py}
}