
Inference & Chat

Training is only satisfying if you can actually talk to the result. The original generate_text.py does raw continuation for the base model, but it is hard-wired to the legacy config and has no chat template. So I added a small inference layer that loads any stage checkpoint (base / SFT / DPO / PPO / GRPO) and prompts it in the format that stage expects.

Inference / chat flow

flowchart LR
    CK[(any checkpoint)]:::ckpt --> LD[load_model_from_ckpt<br/>dims from stored cfg]:::proc
    LD --> MODE{chat or raw?}:::proc
    MODE -->|instruction model| CT[wrap in chat template]:::proc
    MODE -->|base model| RAW[raw prefix]:::proc
    CT --> GEN{{generate<br/>temperature / top-p / greedy}}:::model
    RAW --> GEN
    GEN --> DEC([decode → reply]):::eval
    classDef ckpt fill:#eeeeee,stroke:#555,stroke-width:2px,color:#222;
    classDef proc fill:#d6e8ff,stroke:#2c6fbb,stroke-width:2px,color:#0d2c52;
    classDef model fill:#ffe8a3,stroke:#d48806,stroke-width:2px,color:#5a3d00;
    classDef eval fill:#e8d6ff,stroke:#8e44ad,stroke-width:2px,color:#3d1a5a;

Load any checkpoint by its stored config

load_model_from_ckpt reads the model dimensions from the checkpoint's saved cfg, so you never re-specify n_embed/n_blocks, and it tolerates DDP / reward-head key prefixes:

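import torch

ck = torch.load(ckpt_path, map_location="cpu", weights_only=False)  # trusted file: full unpickle (cfg + weights)
cfg = {**(ck.get("cfg") or {}), **(overrides or {})}                 # stored dims first, explicit overrides win
model = Transformer(n_head=cfg["n_head"], n_embed=cfg["n_embed"], ...)
state = ...  # the checkpoint's weight dict (its key inside ck is not shown here)
# strip the DDP "module." prefix first, then a reward-model "transformer." wrapper prefix
state = {k.removeprefix("module.").removeprefix("transformer."): v for k, v in state.items()}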
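The two dictionary steps above (config merge and key-prefix stripping) can be checked without torch. The following is a minimal sketch: merge_cfg and normalize_keys are hypothetical helper names, not functions from the repository, and the example dimensions are illustrative values only.

```python
from typing import Any, Mapping, Optional


def merge_cfg(stored: Optional[Mapping[str, Any]],
              overrides: Optional[Mapping[str, Any]] = None) -> dict:
    """Stored checkpoint config first; explicit overrides win on conflicts."""
    return {**(stored or {}), **(overrides or {})}


def normalize_keys(state: Mapping[str, Any]) -> dict:
    """Strip a DDP 'module.' prefix, then a wrapper 'transformer.' prefix.

    Order matters: if a checkpoint came from a DDP-wrapped module whose
    transformer sits under a .transformer attribute, its keys look like
    'module.transformer.<name>', so 'module.' has to come off first.
    str.removeprefix needs Python 3.9+.
    """
    return {k.removeprefix("module.").removeprefix("transformer."): v
            for k, v in state.items()}


if __name__ == "__main__":
    cfg = merge_cfg({"n_head": 12, "n_embed": 768, "n_blocks": 12}, {"n_blocks": 6})
    print(cfg)  # {'n_head': 12, 'n_embed': 768, 'n_blocks': 6}
    state = {"module.transformer.lm_head.weight": 0, "module.ln_f.bias": 1, "wte.weight": 2}
    print(normalize_keys(state))  # {'lm_head.weight': 0, 'ln_f.bias': 1, 'wte.weight': 2}
```

Because each prefix is removed at most once and only at the start of the key, plain keys such as wte.weight pass through unchanged.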

Chat vs raw

generate_reply has two modes, reusing the same tested generation core as training/eval (batched_generate):

  • chat (default) — wraps your text in the chat template (optionally with a system message) and returns the decoded assistant turn. Use this for SFT/DPO/PPO/GRPO checkpoints.
  • raw (--raw) — treats your text as a prefix and returns the base model's continuation (no template). Use this for base_pretrained.pt.

if raw:
    ids = get_tokenizer().encode_ordinary(user_text)
else:
    ids = encode_prompt([{"role": "user", "content": user_text}])   # ...ends at <|assistant|>
out = batched_generate(model, [ids], max_new_tokens, device=device,
                       temperature=temperature, top_k=top_k, top_p=top_p, greedy=greedy)
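Here is a string-level sketch of the chat/raw split for readers who have not seen encode_prompt. It is hypothetical: only the message shape ({"role", "content"}) and the closing <|assistant|> tag come from the code above. The <|system|> and <|user|> tags, the newline layout, and the helper names are assumptions, and the real encode_prompt returns token ids rather than a string.

```python
from typing import Dict, List, Optional

Message = Dict[str, str]

# Hypothetical tag layout; only "<|assistant|>" is confirmed by the code above.
ROLE_TAGS = {"system": "<|system|>", "user": "<|user|>", "assistant": "<|assistant|>"}


def build_messages(user_text: str, system: Optional[str] = None) -> List[Message]:
    """Chat mode: an optional system message, then the user turn."""
    msgs: List[Message] = []
    if system:
        msgs.append({"role": "system", "content": system})
    msgs.append({"role": "user", "content": user_text})
    return msgs


def render_chat(messages: List[Message]) -> str:
    """Render each turn, then end on an open assistant tag so the model writes the reply."""
    parts = [f"{ROLE_TAGS[m['role']]}\n{m['content']}\n" for m in messages]
    return "".join(parts) + ROLE_TAGS["assistant"]


def prompt_text(user_text: str, raw: bool, system: Optional[str] = None) -> str:
    """Raw mode passes the prefix through untouched; chat mode wraps it in the template."""
    return user_text if raw else render_chat(build_messages(user_text, system))


if __name__ == "__main__":
    print(prompt_text("Once upon a time", raw=True))
    print(prompt_text("What is 13 + 29?", raw=False, system="Answer briefly."))
```

Keeping both paths behind one flag mirrors the if raw / else branch above. In either mode the result feeds the same generation call.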

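Decoding is defensive: decode drops the EOT terminator and any padding-vocab ids. The model's vocabulary is padded to 50304, but r50k_base only decodes ordinary ids 0–50255 (50256 is its <|endoftext|> special token), so anything the model emits above that range is discarded before decoding.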
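The filtering step can be sketched without the tokenizer. clean_ids is a hypothetical helper. It assumes EOT is id 50256 and treats every id above 50255 as undecodable. The surviving ids would then go to the tokenizer's decode.

```python
from typing import Iterable, List

EOT_ID = 50256           # assumed EOT id (r50k_base "<|endoftext|>")
ORDINARY_LIMIT = 50256   # r50k_base ordinary ids are 0..50255
PADDED_VOCAB = 50304     # model vocab; ids 50257..50303 are padding rows


def clean_ids(ids: Iterable[int], eot_id: int = EOT_ID,
              limit: int = ORDINARY_LIMIT) -> List[int]:
    """Keep only ids the tokenizer can turn into ordinary text.

    The explicit EOT check keeps this correct even if `limit` is raised.
    Padding ids can still be sampled because the embedding/output layer
    has rows for them, so they are filtered here rather than trusted away.
    """
    return [i for i in ids if i != eot_id and 0 <= i < limit]


if __name__ == "__main__":
    print(clean_ids([464, 3290, EOT_ID, PADDED_VOCAB - 1]))  # [464, 3290]
```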

The CLI

scripts/chat.py is one-shot or an interactive REPL:

# instruction-tuned models (chat template applied automatically)
PYTHONPATH=. python scripts/chat.py --ckpt /ephemeral/ckpts/sft.pt  --prompt "What is 13 + 29?"
PYTHONPATH=. python scripts/chat.py --ckpt /ephemeral/ckpts/grpo.pt --prompt "..." --greedy
# base-model continuation
PYTHONPATH=. python scripts/chat.py --ckpt /ephemeral/ckpts/base_pretrained.pt --raw --prompt "Once upon a time"
# interactive REPL (omit --prompt)
PYTHONPATH=. python scripts/chat.py --ckpt /ephemeral/ckpts/sft.pt

Sampling controls: --temperature, --top_p, --top_k, or --greedy for deterministic argmax. Runs on --device cuda or cpu (both verified).
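The flags described above map naturally onto argparse. This sketches the interface only and is not the actual scripts/chat.py. The flag names come from the text. Every default value, the empty-line exit from the REPL, and the injected reply_fn (standing in for generate_reply) are assumptions.

```python
import argparse
from typing import Callable, List


def build_parser() -> argparse.ArgumentParser:
    # Flag names follow the text above; every default value is an assumption.
    p = argparse.ArgumentParser(description="Chat with a stage checkpoint (sketch)")
    p.add_argument("--ckpt", required=True, help="path to a .pt checkpoint")
    p.add_argument("--prompt", default=None, help="one-shot prompt; omit for a REPL")
    p.add_argument("--raw", action="store_true", help="base-model continuation, no template")
    p.add_argument("--greedy", action="store_true", help="deterministic argmax decoding")
    p.add_argument("--temperature", type=float, default=0.8)
    p.add_argument("--top_p", type=float, default=None)
    p.add_argument("--top_k", type=int, default=None)
    p.add_argument("--device", choices=["cuda", "cpu"], default="cpu")
    return p


def run(argv: List[str],
        reply_fn: Callable[[str, argparse.Namespace], str],
        input_fn: Callable[[str], str] = input) -> List[str]:
    """One-shot if --prompt is given; otherwise a REPL until EOF or an empty line."""
    args = build_parser().parse_args(argv)
    if args.prompt is not None:
        replies = [reply_fn(args.prompt, args)]
        print(replies[0])
        return replies
    replies: List[str] = []
    while True:
        try:
            text = input_fn("> ")
        except EOFError:
            break
        if not text.strip():
            break
        reply = reply_fn(text, args)
        print(reply)
        replies.append(reply)
    return replies
```

Injecting reply_fn and input_fn keeps the loop testable without loading a model.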

Sampling knobs, briefly

  • greedy — reproducible, best for eval / math (--greedy).
  • temperature — higher = more random; ~0.7–1.0 for open-ended chat.
  • top_p / top_k — nucleus / top-k truncation to cut the long tail of unlikely tokens.
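The following pure-Python sketch makes the knobs concrete by running one sampling step over a list of logits. It illustrates the techniques and is not the batched_generate implementation. sample_next is a hypothetical name. In this sketch, top-k is applied before top-p, and greedy short-circuits everything else.

```python
import math
import random
from typing import List, Optional


def sample_next(logits: List[float], temperature: float = 1.0,
                top_k: Optional[int] = None, top_p: Optional[float] = None,
                greedy: bool = False, rng: Optional[random.Random] = None) -> int:
    """Pick one token id from raw logits."""
    if greedy:
        # Deterministic argmax: ignores temperature and truncation entirely.
        return max(range(len(logits)), key=lambda i: logits[i])
    if temperature <= 0:
        raise ValueError("temperature must be > 0 (use greedy=True for argmax)")
    # Temperature scaling, then a numerically stable softmax.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Candidate ids from most to least likely.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k is not None and top_k > 0:
        order = order[:top_k]  # top-k: keep only the k most likely ids
    if top_p is not None and 0 < top_p < 1:
        mass = sum(probs[i] for i in order)  # renormalize over the top-k survivors
        kept, cum = [], 0.0
        for i in order:  # nucleus: smallest prefix whose mass reaches top_p
            kept.append(i)
            cum += probs[i] / mass
            if cum >= top_p:
                break
        order = kept
    # random.choices normalizes the weights, so survivors are implicitly renormalized.
    rng = rng or random.Random()
    return rng.choices(order, weights=[probs[i] for i in order], k=1)[0]
```

Lower temperature sharpens the distribution before truncation. As top_p or top_k shrinks, sampling approaches the greedy choice.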

That's the full loop: pretrain → align → reason → measure → chat.