
Stage 2 — Supervised Fine-Tuning (SFT)

The base model can continue text, but it doesn't know it's supposed to answer you. SFT fixes that by showing it thousands of (instruction, response) pairs and training it to produce the response. The only real difference from pretraining is a per-token loss mask: we compute the loss on the assistant tokens and ignore the prompt.
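As a minimal sketch of what that mask looks like, assume a single-turn example already tokenized into prompt and response IDs. The helper build_example and the token IDs are made up for illustration; the project's real mask is built by its data pipeline.

```python
from typing import List, Tuple


def build_example(prompt_ids: List[int], response_ids: List[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: concatenate prompt + response and mark only response tokens.

    loss_mask[i] == 1 means "token i is an assistant token and should be predicted".
    """
    tokens = prompt_ids + response_ids
    loss_mask = [0] * len(prompt_ids) + [1] * len(response_ids)
    return tokens, loss_mask


tokens, loss_mask = build_example([101, 7, 8, 9], [42, 43, 2])
print(tokens)     # [101, 7, 8, 9, 42, 43, 2]
print(loss_mask)  # [0, 0, 0, 0, 1, 1, 1]
```

Because the mask is indexed by the token being *predicted*, the last prompt token (9 here) still produces a loss term: its position is where the model has to predict the first response token, 42.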

SFT masked-loss flow

flowchart LR
    H[(sft_packed.h5<br/>tokens + loss_mask)]:::store --> B[batch:<br/>tokens, mask]:::proc
    B --> M{{Transformer<br/>logits}}:::model
    M --> SH[shift: predict t+1]:::proc
    SH --> CE[token cross-entropy]:::loss
    MASK([loss_mask = 1 on<br/>assistant tokens]):::data --> CE
    CE --> AVG[average over<br/>masked tokens only]:::loss --> UPD[AdamW step]:::model
    classDef store fill:#cdece8,stroke:#16a085,stroke-width:2px,color:#0a3d33;
    classDef proc fill:#d6e8ff,stroke:#2c6fbb,stroke-width:2px,color:#0d2c52;
    classDef data fill:#d6ffd9,stroke:#27ae60,stroke-width:2px,color:#143d1a;
    classDef model fill:#ffe8a3,stroke:#d48806,stroke-width:2px,color:#5a3d00;
    classDef loss fill:#ffd6d6,stroke:#c0392b,stroke-width:2px,color:#5c1212;

The masked loss

The whole stage hinges on sft_loss. It's ordinary next-token cross-entropy, except every target position is weighted by the mask so only completion tokens count:

import torch.nn.functional as F

def sft_loss(logits, tokens, loss_mask):
    """Next-token cross-entropy averaged over assistant (loss_mask == 1) target tokens only."""
    logits = logits[:, :-1, :]        # predict token t+1 from position t (same shift as pretraining)
    targets = tokens[:, 1:]
    mask = loss_mask[:, 1:].float()   # mask of each *target* token; fp32 so the token count stays exact
    V = logits.size(-1)
    ce = F.cross_entropy(logits.reshape(-1, V).float(), targets.reshape(-1).long(), reduction="none")
    ce = ce.view(targets.shape) * mask
    return ce.sum() / mask.sum().clamp(min=1.0)     # mean over ASSISTANT tokens only
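To see the shift and the masked mean on numbers small enough to check by hand, here is a pure-Python mirror of sft_loss for a single sequence. It is an illustration only: log_softmax and masked_ce are hypothetical names, the 3-token vocabulary and logits are made up, and the real function operates on batched tensors.

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def masked_ce(logits: List[List[float]], tokens: List[int], loss_mask: List[int]) -> float:
    """Single-sequence version of sft_loss: position t predicts tokens[t + 1],
    weighted by loss_mask[t + 1]; the sum is divided by the number of counted targets."""
    total, count = 0.0, 0.0
    for t in range(len(tokens) - 1):
        target, weight = tokens[t + 1], loss_mask[t + 1]
        total += weight * -log_softmax(logits[t])[target]
        count += weight
    return total / max(count, 1.0)   # same guard as .clamp(min=1.0)


tokens = [0, 1, 2, 1]
loss_mask = [0, 0, 1, 1]             # only the last two tokens are "assistant" tokens
logits = [
    [5.0, -5.0, 0.0],                # predicts tokens[1] = 1 badly, but loss_mask[1] == 0
    [0.0, 0.0, 0.0],                 # uniform, so CE = ln 3 on tokens[2]
    [0.0, 0.0, 0.0],                 # uniform, so CE = ln 3 on tokens[3]
    [0.0, 0.0, 0.0],                 # last position has no target and is dropped by the shift
]
print(masked_ce(logits, tokens, loss_mask))   # ln 3, about 1.0986
```

An unmasked mean over the same three targets would be dragged up to about (10.0 + 2 × 1.10) / 3 ≈ 4.07 by the prompt position, which is exactly the term SFT is meant to ignore.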

The mask itself is produced at data-prep time by encode_chat (see 01_data_pipeline.md) and packed alongside the tokens. The .float() on the logits makes the cross-entropy run in float32 even when the model emits bf16 logits, so per-token losses are not rounded to bf16 precision.
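The packing step itself is not shown on this page. As a rough sketch of the idea, assume a greedy scheme that never splits an example across rows and right-pads each row to a fixed length with mask 0. The names pack_examples and pad_id, and the truncation rule, are assumptions rather than the project's code.

```python
from typing import List, Tuple

Example = Tuple[List[int], List[int]]  # (tokens, loss_mask)


def _pad(tokens: List[int], mask: List[int], seq_len: int, pad_id: int) -> Example:
    """Right-pad one row; padding is never scored (mask 0)."""
    n = seq_len - len(tokens)
    return tokens + [pad_id] * n, mask + [0] * n


def pack_examples(examples: List[Example], seq_len: int, pad_id: int = 0) -> List[Example]:
    """Hypothetical greedy packer: fill a row until the next example would overflow it."""
    rows: List[Example] = []
    cur_tok: List[int] = []
    cur_mask: List[int] = []
    for tok, mask in examples:
        tok, mask = tok[:seq_len], mask[:seq_len]   # assumption: truncate over-long examples
        if cur_tok and len(cur_tok) + len(tok) > seq_len:
            rows.append(_pad(cur_tok, cur_mask, seq_len, pad_id))
            cur_tok, cur_mask = [], []
        cur_tok += tok
        cur_mask += mask
    if cur_tok:
        rows.append(_pad(cur_tok, cur_mask, seq_len, pad_id))
    return rows


rows = pack_examples(
    [([1, 2, 3], [0, 1, 1]), ([4, 5], [0, 1]), ([6, 7, 8, 9], [0, 0, 1, 1])],
    seq_len=6,
)
for toks, mask in rows:
    print(toks, mask)
```

Because each example starts with prompt tokens whose mask is 0, the position that would predict across an example boundary (3 predicting 4 in the first row) carries no loss. This sketch does not change attention, so tokens can still attend across example boundaries within a row.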

The trainer

train_sft.py loads the pretrained base with load_backbone_from_ckpt, then runs a compact loop (autocast forward, masked loss, gradient clipping, optimizer step, cosine LR schedule) with periodic dev evaluation. One training step looks like this:

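optimizer.zero_grad(set_to_none=True)     # clear gradients from the previous step
tokens, mask, epoch = next(train_it)
with amp_autocast(cfg.amp_dtype, ctx.device):
    logits, _ = model(tokens)
    loss = sft_loss(logits, tokens, mask)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
optimizer.step()
# (cosine LR update and periodic dev evaluation omitted from this excerpt)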
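The page mentions a cosine LR schedule but doesn't specify its exact form. The following is only a sketch assuming linear warmup followed by cosine decay to a floor; cosine_lr, warmup_steps, and min_lr are hypothetical names.

```python
import math


def cosine_lr(step: int, total_steps: int, peak_lr: float,
              min_lr: float = 0.0, warmup_steps: int = 0) -> float:
    """Linear warmup to peak_lr, then cosine decay from peak_lr to min_lr over the remaining steps."""
    if warmup_steps > 0 and step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for s in (0, 5, 10, 55, 100):
    print(s, cosine_lr(s, total_steps=100, peak_lr=1e-5, warmup_steps=10))
```

In a PyTorch loop the value would be written into every param group before optimizer.step(), e.g. `for g in optimizer.param_groups: g["lr"] = cosine_lr(step, total_steps, peak_lr)`. PyTorch also ships torch.optim.lr_scheduler.CosineAnnealingLR if a built-in scheduler is preferred.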

Batches come from get_sft_batch_iterator, which shards the packed rows across DDP ranks and yields (tokens, loss_mask, epoch).
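The internals of get_sft_batch_iterator aren't shown here. A minimal pure-Python sketch of the same contract might look like the following, assuming a per-epoch shuffle seeded identically on every rank followed by a strided split, so ranks see disjoint rows. The name sft_batches, the seed parameter, and the drop-the-remainder policy are all assumptions.

```python
import random
from typing import Iterator, List, Sequence, Tuple

Row = Tuple[List[int], List[int]]  # (tokens, loss_mask) for one packed row


def sft_batches(rows: Sequence[Row], rank: int, world_size: int,
                batch_size: int, seed: int = 0) -> Iterator[Tuple[list, list, int]]:
    """Yield (tokens, loss_mask, epoch) batches forever, sharded across ranks."""
    per_rank = len(rows) // world_size
    if per_rank < batch_size:
        raise ValueError("not enough rows for one batch per rank")
    epoch = 0
    while True:
        order = list(range(len(rows)))
        random.Random(seed + epoch).shuffle(order)           # identical permutation on every rank
        mine = order[rank:per_rank * world_size:world_size]  # strided, disjoint shards of equal size
        for i in range(0, per_rank - batch_size + 1, batch_size):
            batch = [rows[j] for j in mine[i:i + batch_size]]
            yield [tok for tok, _ in batch], [m for _, m in batch], epoch
        epoch += 1


rows = [([i], [1]) for i in range(10)]
it = sft_batches(rows, rank=0, world_size=2, batch_size=2)
for _ in range(3):
    tokens, loss_mask, epoch = next(it)
    print(epoch, tokens)
```

Each rank would pass its own torch.distributed.get_rank() and get_world_size(); the shared seed is what keeps the shards disjoint without any communication between ranks.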

Run it

PYTHONPATH=. python scripts/train_sft.py                                   # single GPU
PYTHONPATH=. torchrun --standalone --nproc_per_node=2 scripts/train_sft.py # both GPUs
# tune: --lr 1e-5 --epochs 3 --batch_size 16

What the numbers mean

  • train_loss / ppl: masked cross-entropy over assistant tokens, and its perplexity. It should drop well below the base model's loss on the same tokens. To sanity-check the mechanics I ran an overfit test on 8 rows and watched the loss fall 11.0 → 4.7, confirming that gradients flow through the masked loss.
  • dev_loss: the same masked loss on a held-out split (sft_dev_packed.h5). This is the honest signal, since train_loss can also fall through memorization.
  • GSM8K dev accuracy: after SFT the model both follows instructions and emits the <answer>…</answer> format, so this should rise above the base model's score (see 08_evaluation.md).
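Two small helpers make these metrics concrete. Perplexity is exp of the mean per-token cross-entropy in nats (F.cross_entropy uses the natural log), so the overfit test's 11.0 → 4.7 drop is roughly 59,900 → 110 in perplexity terms. The answer extractor is a hypothetical sketch of scoring the <answer>…</answer> format; the evaluation actually used is covered in 08_evaluation.md, and the comma-stripping normalization here is an assumption.

```python
import math
import re
from typing import Optional


def perplexity(mean_ce_nats: float) -> float:
    """Perplexity = exp(mean per-token cross-entropy), with the loss in nats."""
    return math.exp(mean_ce_nats)


_ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def extract_answer(completion: str) -> Optional[str]:
    """Hypothetical: return the text of the last <answer>...</answer> span, or None if absent."""
    found = _ANSWER_RE.findall(completion)
    return found[-1].strip() if found else None


def is_correct(completion: str, gold: str) -> bool:
    """Hypothetical exact-match check after stripping thousands separators."""
    pred = extract_answer(completion)
    return pred is not None and pred.replace(",", "") == gold.replace(",", "")


print(round(perplexity(11.0)), round(perplexity(4.7)))     # 59874 110
print(is_correct("48 / 2 = 24 <answer>24</answer>", "24"))  # True
print(extract_answer("no tags here"))                       # None
```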

The result is saved to /ephemeral/ckpts/sft.pt and becomes the starting point for the reward model, DPO, PPO and GRPO.

➡️ Next: Stage 3 — Reward Model or jump to DPO.