Files
xtrain/crates/xtrain-model/src/model.rs
Gahow Wang c36cdf74d1 Merge t18-dropout into main
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

# Conflicts:
#	README.md
#	crates/xtrain-autodiff/tests/autograd.rs
#	crates/xtrain-model/src/model.rs
#	crates/xtrain-train/src/bin/train.rs
#	crates/xtrain-train/src/train_loop.rs
#	docs/evolution.md
2026-06-18 00:41:41 +08:00

519 lines
22 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! The tiny transformer forward graph + parameter container (Phase T5).
#![cfg(not(no_cuda))]
use std::cell::Cell;
use crate::config::Config;
use xtrain_autodiff::ops;
use xtrain_autodiff::tape::Var;
use xtrain_tensor::{DType, Device, Tensor};
/// One decoder block's learnable tensors.
///
/// Every field is a parameter leaf [`Var`]; the trailing comment on each field
/// gives the logical shape requested from the initialiser in
/// `TinyTransformer::new`. The declaration order below is also the order in
/// which `Block::block_params` hands the leaves to `block_forward`.
struct Block {
// Gamma of the RMSNorm applied to the block input before attention.
attn_norm: Var, // [dim]
// Query / key / value projection weights.
wq: Var, // [dim, dim]
wk: Var, // [dim, dim]
wv: Var, // [dim, dim]
// Gammas of the per-head RMSNorm applied to Q and K before RoPE.
q_norm: Var, // [head_dim] — per-head QK-norm (Qwen3-style)
k_norm: Var, // [head_dim]
// Attention output projection (applied to the re-merged heads).
wo: Var, // [dim, dim]
// Gamma of the RMSNorm applied before the MLP sub-block.
ffn_norm: Var, // [dim]
// SwiGLU MLP weights: gate and up project to `ffn_hidden`, down projects back.
w_gate: Var, // [dim, ffn_hidden]
w_up: Var, // [dim, ffn_hidden]
w_down: Var, // [ffn_hidden, dim]
}
/// A tiny RoPE+RMSNorm+SwiGLU decoder. Holds every parameter as a leaf [`Var`];
/// `forward` builds an autograd graph over them.
///
/// Besides the parameters, the struct carries the opt-in execution switches
/// (compute dtype, activation recomputation, flash attention) and the
/// interior-mutable train/eval state that drives dropout.
pub struct TinyTransformer {
/// Model hyper-parameters; copied by value into each block's forward.
cfg: Config,
/// Token embedding table, gathered by `ops::embedding` at the start of forward.
embed: Var, // [vocab, dim]
/// Decoder blocks, applied in order by `forward_batched`.
blocks: Vec<Block>,
/// Gamma of the RMSNorm applied after the last block.
final_norm: Var, // [dim]
/// Output projection from the normed hidden state to vocabulary logits.
lm_head: Var, // [dim, vocab]
/// Compute dtype for the forward graph (Phase T12). `F32` (default) = the
/// original path, bit-identical to T10/T11. `BF16` = mixed precision: the
/// parameter leaves stay fp32 (master), but each linear's weight is cast to
/// bf16 on the fly and the activation stream flows bf16 (see
/// `docs/11-bf16-mixed-precision.md`). The cast op's backward upcasts the bf16
/// weight grad back to fp32, so AdamW/clip/DDP stay fp32 and unchanged.
compute_dtype: DType,
/// Activation recomputation / gradient checkpointing (Phase T13, KI-3). When
/// `true`, each transformer block's forward runs through
/// [`xtrain_autodiff::checkpoint`]: the block's internal activations are NOT
/// kept on the tape during forward (only the block input is), and the block
/// forward is re-run during backward to recover them. Trades ~one extra forward
/// per block for a large drop in peak activation memory → lets dim1024 batch32
/// fit. Default `false` = the unchanged path (every activation stored), so
/// existing numerics are bit-identical; recompute is mathematically exact, so
/// grads match the non-checkpointed path within fp tolerance.
recompute: bool,
/// Fused flash-attention (Phase T14). When `true`, the SDPA core runs through
/// the hand-written single fused kernel ([`ops::flash_attention`]): online
/// softmax over KV tiles, the `[bh,seq,seq]` score matrix NEVER materialized,
/// backward caches only the O(N) logsumexp. Default `false` = the composed T10
/// path (`cublasSgemmStridedBatched`×2 + causal-softmax kernel, O(N²) probs),
/// so the default graph is unchanged. Mathematically the same SDPA → grads/loss
/// match the composed path within fp/bf16 tolerance. Opt-in via `--flash`.
use_flash: bool,
/// Training mode for dropout (Phase T18). `true` → the attn/MLP sub-block
/// outputs pass through `ops::dropout` (with `cfg.dropout` and a per-step,
/// per-site seed); `false` (default) → dropout is identity (eval/sampling/
/// export). `Cell` so `train()`/`eval()` flip it through `&self` (the forward
/// takes `&self`). When `cfg.dropout == 0` this flag is irrelevant — the graph
/// is bit-identical to the no-dropout path either way.
training: Cell<bool>,
/// Per-step dropout RNG seed (Phase T18). Bumped once at the start of each
/// TRAINING forward so every step draws fresh masks; combined with the layer
/// index + a per-site constant to give each dropout site its own seed. The RNG
/// is counter-based, so re-running a checkpointed block's forward in backward
/// (T13) reproduces the same seed → the same mask (recompute stays exact).
step_seed: Cell<u64>,
}
impl TinyTransformer {
    /// Build a model with parameters initialised from `init(shape) -> host data`.
    ///
    /// The caller controls initialisation (deterministic for tests / PyTorch
    /// parity). `init` receives the logical shape and returns row-major data. It
    /// is invoked once per parameter, in the same order [`params`](Self::params)
    /// reports them: embedding, each block's leaves, final norm, LM head.
    ///
    /// The model starts in eval mode with fp32 compute, no activation
    /// recomputation and the composed (non-flash) attention path.
    ///
    /// # Panics
    /// Panics if `init` returns a buffer whose length differs from the product of
    /// the requested shape.
    pub fn new(cfg: Config, device: Device, mut init: impl FnMut(&[usize]) -> Vec<f32>) -> Self {
        // Initialise one host buffer, validate its size, and upload it as a leaf.
        let mut mk = |shape: &[usize]| -> Var {
            let data = init(shape);
            assert_eq!(data.len(), shape.iter().product::<usize>(), "init size");
            Var::leaf(Tensor::from_slice(&data, shape).to_device(device))
        };
        // NOTE: the order of the `mk` calls below is part of the contract — a
        // stateful `init` (e.g. a seeded RNG) sees the shapes in exactly this order.
        let embed = mk(&[cfg.vocab, cfg.dim]);
        let blocks = (0..cfg.n_layers)
            .map(|_| Block {
                attn_norm: mk(&[cfg.dim]),
                wq: mk(&[cfg.dim, cfg.dim]),
                wk: mk(&[cfg.dim, cfg.dim]),
                wv: mk(&[cfg.dim, cfg.dim]),
                q_norm: mk(&[cfg.head_dim]),
                k_norm: mk(&[cfg.head_dim]),
                wo: mk(&[cfg.dim, cfg.dim]),
                ffn_norm: mk(&[cfg.dim]),
                w_gate: mk(&[cfg.dim, cfg.ffn_hidden]),
                w_up: mk(&[cfg.dim, cfg.ffn_hidden]),
                w_down: mk(&[cfg.ffn_hidden, cfg.dim]),
            })
            .collect();
        let final_norm = mk(&[cfg.dim]);
        let lm_head = mk(&[cfg.dim, cfg.vocab]);
        Self {
            cfg,
            embed,
            blocks,
            final_norm,
            lm_head,
            compute_dtype: DType::F32,
            recompute: false,
            use_flash: false,
            training: Cell::new(false),
            step_seed: Cell::new(0),
        }
    }

    /// The hyper-parameters this model was built with.
    pub fn config(&self) -> &Config {
        &self.cfg
    }

    /// Set the forward compute dtype (Phase T12). `BF16` enables mixed precision
    /// (fp32 master weights, bf16 linears + activations); `F32` (the default) is
    /// the unchanged full-precision path. Builder-style so existing call sites
    /// that don't opt in keep the fp32 numerics bit-for-bit.
    ///
    /// # Panics
    /// Panics if `dtype` is neither `DType::F32` nor `DType::BF16`.
    pub fn with_compute_dtype(mut self, dtype: DType) -> Self {
        assert!(
            matches!(dtype, DType::F32 | DType::BF16),
            "compute_dtype must be F32 or BF16"
        );
        self.compute_dtype = dtype;
        self
    }

    /// The dtype the forward graph computes in.
    pub fn compute_dtype(&self) -> DType {
        self.compute_dtype
    }

    /// Enable per-block activation recomputation / gradient checkpointing (Phase
    /// T13). Builder-style and opt-in; default off keeps the unchanged tape (every
    /// activation stored). On, each block's forward is wrapped in
    /// [`xtrain_autodiff::checkpoint`] — exact grads, lower peak activation memory.
    pub fn with_recompute(mut self, recompute: bool) -> Self {
        self.recompute = recompute;
        self
    }

    /// Whether per-block activation recomputation is enabled.
    pub fn recompute(&self) -> bool {
        self.recompute
    }

    /// Enable the fused flash-attention SDPA core (Phase T14). Builder-style and
    /// opt-in; default off keeps the composed T10 path (so the default graph is
    /// unchanged). On, the SDPA runs through [`ops::flash_attention`] — same SDPA
    /// math, online softmax, no materialized `[bh,seq,seq]` scores.
    pub fn with_flash(mut self, use_flash: bool) -> Self {
        self.use_flash = use_flash;
        self
    }

    /// Whether the fused flash-attention path is enabled.
    pub fn use_flash(&self) -> bool {
        self.use_flash
    }

    /// Switch to training mode (Phase T18): dropout (if `cfg.dropout > 0`) is
    /// active in subsequent forwards. The training loop calls this before stepping.
    pub fn train(&self) {
        self.training.set(true);
    }

    /// Switch to eval mode (Phase T18): dropout is identity. Held-out eval,
    /// autoregressive sampling, and weight export all run in this mode (default).
    pub fn eval(&self) {
        self.training.set(false);
    }

    /// Whether the model is currently in training mode.
    pub fn is_training(&self) -> bool {
        self.training.get()
    }

    /// Builder-style train/eval toggle (Phase T18) — handy for tests that want a
    /// model fixed in one mode. Equivalent to [`train`](Self::train) /
    /// [`eval`](Self::eval) but chains off `new(..)`.
    pub fn with_training(self, training: bool) -> Self {
        // `training` is a `Cell`, so no `mut self` is needed to flip it.
        self.training.set(training);
        self
    }

    /// All learnable parameters, in a stable order: embedding, then each block's
    /// leaves in [`Block::block_params`] order, then final norm and LM head. The
    /// optimizer (a hand-written GD step in T5, AdamW in T6) iterates this; each
    /// holds its `.grad()` after `backward()`.
    pub fn params(&self) -> Vec<Var> {
        let mut ps = vec![self.embed.clone()];
        for b in &self.blocks {
            // Single source of truth for the per-block order: the same list the
            // checkpointed block forward indexes into.
            ps.extend(b.block_params());
        }
        ps.push(self.final_norm.clone());
        ps.push(self.lm_head.clone());
        ps
    }

    /// Forward over a single sequence of token `ids` (`[seq]` I32 on this
    /// model's device). Returns the logits [`Var`] of shape `[seq, vocab]`. This
    /// is the batch-1 special case of [`forward_batched`](Self::forward_batched)
    /// (used by the autoregressive sampler / inference path).
    pub fn forward(&self, ids: &Tensor) -> Var {
        self.forward_batched(ids, 1)
    }

    /// Batched forward over `batch` sequences of equal length `seq`, flattened to
    /// `[batch*seq]` I32 ids in sequence-major order (sequence 0's `seq` tokens,
    /// then sequence 1's, …). Returns logits `[batch*seq, vocab]` in the SAME flat
    /// layout. The whole graph runs on the flattened tokens so every linear
    /// projection is ONE big `[batch*seq, dim] × [dim, out]` GEMM (the
    /// GPU-filling win); only attention is sequence-aware (per-sequence causal
    /// mask + RoPE position, NO cross-sequence attention).
    ///
    /// In training mode with `cfg.dropout > 0` this call also advances the
    /// per-step dropout seed, so two consecutive training forwards draw
    /// different masks.
    ///
    /// # Panics
    /// Panics if `batch` is zero or if the number of ids is not a multiple of
    /// `batch`.
    pub fn forward_batched(&self, ids: &Tensor, batch: usize) -> Var {
        // Reject `batch == 0` up front: `total % batch` would otherwise abort with
        // an opaque "remainder with a divisor of zero" panic.
        assert!(batch > 0, "batch must be non-zero");
        let total = ids.shape()[0];
        assert_eq!(
            total % batch,
            0,
            "ids len {total} not divisible by batch {batch}"
        );
        let seq = total / batch;

        // Dropout (T18) is active only in training mode with p>0; otherwise it is
        // identity (`ops::dropout` no-ops at p==0). Bump the per-step seed ONCE per
        // training forward so each step draws fresh masks (counter-based RNG, so a
        // checkpointed block's recompute reproduces the same seed → same mask).
        let dropout_p = if self.training.get() {
            self.cfg.dropout
        } else {
            0.0
        };
        if dropout_p > 0.0 {
            self.step_seed.set(self.step_seed.get().wrapping_add(1));
        }
        let base_seed = self.step_seed.get();

        // Embedding gathers from the fp32 master table; in bf16 mode cast the
        // activation stream to bf16 here (norms are cast to bf16 gammas too).
        let mut h = ops::embedding(&self.embed, ids); // [batch*seq, dim], fp32
        if self.compute_dtype == DType::BF16 {
            h = ops::cast(&h, DType::BF16);
        }

        for (li, b) in self.blocks.iter().enumerate() {
            // Per-layer dropout seed: a deterministic function of (base_seed,
            // layer index) — NOT a mutable counter — so the checkpoint recompute
            // (which re-derives it from the captured base_seed/li) gets the same
            // masks. The block derives its two per-site seeds from this.
            let block_seed = base_seed
                .wrapping_mul(0x100000001B3)
                .wrapping_add(li as u64);
            let params = b.block_params();
            h = if self.recompute {
                // Activation recomputation (T13): run the whole block forward inside
                // `checkpoint` so its internal activations aren't kept on the tape;
                // the block forward is re-run in backward to recover the grads. The
                // segment fn captures only `Copy` config (no borrow of `self`) and
                // receives the block's params via the slice, in `block_params` order.
                // `flash` is captured too → the recompute segment also runs flash;
                // `dropout_p`/`block_seed` are captured so the recompute re-derives
                // the same per-site dropout masks (counter-based RNG, exact).
                let (cfg, cdt, flash) = (self.cfg, self.compute_dtype, self.use_flash);
                let seg = move |x: &Var, p: &[Var]| {
                    block_forward(cfg, cdt, flash, batch, seq, dropout_p, block_seed, x, p)
                };
                xtrain_autodiff::checkpoint::checkpoint(seg, &h, &params)
            } else {
                block_forward(
                    self.cfg,
                    self.compute_dtype,
                    self.use_flash,
                    batch,
                    seq,
                    dropout_p,
                    block_seed,
                    &h,
                    &params,
                )
            };
        }

        let h = ops::rms_norm(
            &h,
            &norm_gamma(self.compute_dtype, &self.final_norm),
            self.cfg.eps,
        );
        // lm_head matmul in compute dtype. Logits stay bf16 in bf16 mode — the
        // cross_entropy op upcasts to fp32 internally (no persistent fp32 logits
        // buffer, a real saving at vocab 50257), and its backward casts dx back.
        linear(self.compute_dtype, &h, &self.lm_head) // [batch*seq, vocab]
    }

    /// Cross-entropy mean loss of `forward(ids)` against `targets` (`[seq]` I32).
    pub fn loss(&self, ids: &Tensor, targets: &Tensor) -> Var {
        let logits = self.forward(ids);
        ops::cross_entropy(&logits, targets)
    }

    /// Batched cross-entropy mean loss: `forward_batched(ids, batch)` against
    /// flat `targets` (`[batch*seq]` I32, same sequence-major layout). The CE mean
    /// is over all `batch*seq` rows — identical to averaging the per-sequence
    /// losses, so the loss value matches the looped single-sequence path.
    ///
    /// # Panics
    /// Panics under the same conditions as
    /// [`forward_batched`](Self::forward_batched).
    pub fn loss_batched(&self, ids: &Tensor, targets: &Tensor, batch: usize) -> Var {
        let logits = self.forward_batched(ids, batch);
        ops::cross_entropy(&logits, targets)
    }
}
impl Block {
    /// Clone handles to every learnable leaf of this block, in the fixed order
    /// that `block_forward` indexes them. `checkpoint` forwards this same list to
    /// the recompute closure, so the order here is load-bearing.
    fn block_params(&self) -> Vec<Var> {
        // Exhaustive destructuring: adding a field to `Block` without listing it
        // here becomes a compile error instead of a silently missing parameter.
        let Self {
            attn_norm,
            wq,
            wk,
            wv,
            q_norm,
            k_norm,
            wo,
            ffn_norm,
            w_gate,
            w_up,
            w_down,
        } = self;
        let ordered: [&Var; 11] = [
            attn_norm, wq, wk, wv, q_norm, k_norm, wo, ffn_norm, w_gate, w_up, w_down,
        ];
        ordered.iter().map(|&leaf| leaf.clone()).collect()
    }
}
/// Multiply activation `x` (already in the compute dtype) by weight `w`, an
/// fp32 master leaf.
///
/// * fp32 mode: plain `matmul(x, w)`.
/// * bf16 mode: `w` is first routed through the autograd `cast` op (whose
///   backward upcasts the grad to fp32), then multiplied.
///
/// Any other dtype is a broken invariant — `with_compute_dtype` only admits
/// F32 and BF16.
fn linear(cdt: DType, x: &Var, w: &Var) -> Var {
    let weight = match cdt {
        // Full precision: use the master weight directly, no cast node.
        DType::F32 => return ops::matmul(x, w),
        DType::BF16 => ops::cast(w, DType::BF16),
        _ => unreachable!(),
    };
    ops::matmul(x, &weight)
}
/// Present a norm (or QK-norm) gamma in the compute dtype.
///
/// In bf16 mode the fp32 master leaf goes through the `cast` op (so its grad is
/// upcast on the way back); in fp32 mode the leaf handle is returned as-is.
/// Other dtypes are unreachable because `with_compute_dtype` rejects them.
fn norm_gamma(cdt: DType, gamma: &Var) -> Var {
    match cdt {
        // `cdt` is BF16 in this arm, so it doubles as the cast target.
        DType::BF16 => ops::cast(gamma, cdt),
        DType::F32 => gamma.clone(),
        _ => unreachable!(),
    }
}
/// One transformer block's forward: pre-norm + multi-head causal attention +
/// (T18) dropout + residual, then pre-norm + SwiGLU MLP + dropout + residual.
/// Attention runs the composed or fused-flash (T14) SDPA per `flash`. Pure in
/// `(cfg, cdt, flash, batch, seq, dropout_p, block_seed, input, params)` (no
/// `&self`, all `Copy`) so it can be the segment fn of
/// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13) — the
/// recompute re-derives the same per-site seeds, so the dropout masks are
/// reproduced bit-for-bit. `dropout_p == 0` makes `ops::dropout` a no-op (the
/// graph is then identical to the pre-T18 path).
///
/// `p` is the block's leaves in [`Block::block_params`] order: `attn_norm, wq,
/// wk, wv, q_norm, k_norm, wo, ffn_norm, w_gate, w_up, w_down`. Entries past
/// the eleventh are ignored.
///
/// # Panics
/// Panics if `p` holds fewer than 11 parameters.
// The argument list mirrors everything the checkpoint segment must capture by
// value; bundling it into a struct would only move the same fields elsewhere.
#[allow(clippy::too_many_arguments)]
fn block_forward(
    cfg: Config,
    cdt: DType,
    flash: bool,
    batch: usize,
    seq: usize,
    dropout_p: f32,
    block_seed: u64,
    h: &Var,
    p: &[Var],
) -> Var {
    // Bind the leaves by position with a slice pattern so a short slice fails
    // with a descriptive message rather than a bare index-out-of-bounds.
    let (attn_norm, wq, wk, wv, q_norm, k_norm, wo, ffn_norm, w_gate, w_up, w_down) = match p {
        [attn_norm, wq, wk, wv, q_norm, k_norm, wo, ffn_norm, w_gate, w_up, w_down, ..] => (
            attn_norm, wq, wk, wv, q_norm, k_norm, wo, ffn_norm, w_gate, w_up, w_down,
        ),
        _ => panic!(
            "block_forward: expected at least 11 block params (Block::block_params order), got {}",
            p.len()
        ),
    };

    // Per-site dropout seeds (XOR a site constant into the block seed) so the two
    // residual-path dropouts draw independent masks within the same step/layer.
    let attn_seed = block_seed ^ 0x0A7700;
    let ffn_seed = block_seed ^ 0x0FF700;

    // --- Attention sub-block (pre-norm + dropout + residual) ---
    let normed = ops::rms_norm(h, &norm_gamma(cdt, attn_norm), cfg.eps);
    let attn = attention(
        cfg, cdt, flash, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo,
    );
    let attn = ops::dropout(&attn, dropout_p, attn_seed);
    let h = ops::add(h, &attn);

    // --- MLP sub-block (pre-norm + dropout + residual) ---
    let normed = ops::rms_norm(&h, &norm_gamma(cdt, ffn_norm), cfg.eps);
    let mlp = swiglu_mlp(cdt, &normed, w_gate, w_up, w_down);
    let mlp = ops::dropout(&mlp, dropout_p, ffn_seed);
    ops::add(&h, &mlp)
}
/// Multi-head causal self-attention over a flattened, sequence-major batch.
///
/// `x` is the already-normed activation `[batch*seq, dim]`. Q, K, V and the
/// output are each one `[batch*seq, dim]` projection through [`linear`]. The
/// scaled-dot-product core runs as a single batched op over the
/// `batch * n_heads` (sequence, head) groups, so no group attends outside its
/// own `[seq, seq]` window. Q and K additionally get a per-head RMSNorm
/// (Qwen3-style QK-norm) followed by RoPE, whose position period is `seq`
/// (positions restart for every sequence); V gets neither. Causal masking is
/// the SDPA op's responsibility — no additive mask tensor is built here.
// Each projection/norm weight is passed separately to match the leaf order
// unpacked in `block_forward`.
#[allow(clippy::too_many_arguments)]
fn attention(
    cfg: Config,
    cdt: DType,
    flash: bool,
    batch: usize,
    seq: usize,
    x: &Var,
    wq: &Var,
    wk: &Var,
    wv: &Var,
    q_norm: &Var,
    k_norm: &Var,
    wo: &Var,
) -> Var {
    let heads = cfg.n_heads;
    let head_dim = cfg.head_dim;
    let rows = batch * seq;
    let groups = batch * heads;
    let scale = 1.0 / (head_dim as f32).sqrt();

    // Turn a `[B*S, dim]` projection into the batched `[B*nh, S, hd]` layout the
    // SDPA op consumes. When a QK-norm gamma is supplied (Q and K), the heads are
    // RMS-normed over `hd` and rotated with RoPE first.
    let split_heads = |proj: Var, qk_gamma: Option<&Var>| -> Var {
        let mut per_head = ops::reshape(&proj, &[rows, heads, head_dim]); // [B*S, nh, hd]
        if let Some(gamma) = qk_gamma {
            // Flatten (token, head) into rows so the norm reduces over `hd` only,
            // then restore the 3-D view before applying RoPE.
            let flat = ops::reshape(&per_head, &[rows * heads, head_dim]);
            let normed = ops::rms_norm(&flat, &norm_gamma(cdt, gamma), cfg.eps);
            let restored = ops::reshape(&normed, &[rows, heads, head_dim]);
            per_head = ops::rope(&restored, cfg.rope_theta, seq);
        }
        let by_sequence = ops::reshape(&per_head, &[batch, seq, heads, head_dim]);
        let heads_first = ops::transpose_4d12(&by_sequence); // [B, nh, S, hd]
        ops::reshape(&heads_first, &[groups, seq, head_dim]) // [B*nh, S, hd]
    };

    let q = split_heads(linear(cdt, x, wq), Some(q_norm));
    let k = split_heads(linear(cdt, x, wk), Some(k_norm));
    let v = split_heads(linear(cdt, x, wv), None);

    // Causal SDPA over every (sequence, head) group: the fused flash kernel (T14)
    // when requested, otherwise the composed T10 op. Both yield [B*nh, S, hd].
    let attended = if flash {
        ops::flash_attention(&q, &k, &v, scale)
    } else {
        ops::attention(&q, &k, &v, scale)
    };

    // Undo the head split: [B*nh, S, hd] → [B, nh, S, hd] → [B, S, nh, hd] →
    // [B*S, nh*hd], then apply the output projection.
    let grouped = ops::reshape(&attended, &[batch, heads, seq, head_dim]);
    let tokens_first = ops::transpose_4d12(&grouped); // [B, S, nh, hd]
    let merged = ops::reshape(&tokens_first, &[rows, heads * head_dim]); // [B*S, dim]
    linear(cdt, &merged, wo)
}
/// SwiGLU feed-forward: `down( silu(gate(x)) ∘ up(x) )`.
///
/// `x` is `[batch*seq, dim]`; the gate and up projections widen it to
/// `[batch*seq, ffn_hidden]`, and the down projection brings the gated product
/// back to `[batch*seq, dim]`. All three projections go through [`linear`], so
/// they follow the compute dtype `cdt`.
fn swiglu_mlp(cdt: DType, x: &Var, w_gate: &Var, w_up: &Var, w_down: &Var) -> Var {
    // Arguments are evaluated left to right, so the gate projection is recorded
    // before the up projection.
    let gated = ops::swiglu(&linear(cdt, x, w_gate), &linear(cdt, x, w_up));
    linear(cdt, &gated, w_down)
}
/// Copy a parameter's current value into a host-side `Vec<f32>` (used by the GD
/// step and the PyTorch parity export).
///
/// The value is first moved to `Device::Cpu`, then read as `f32`.
pub fn param_to_host(v: &Var) -> Vec<f32> {
    let on_host = v.value().to_device(Device::Cpu);
    on_host.as_slice::<f32>().to_vec()
}
/// Wrap token `ids` in a 1-D I32 tensor of shape `[ids.len()]` and place it on
/// `device`.
pub fn ids_tensor(ids: &[i32], device: Device) -> Tensor {
    let shape = [ids.len()];
    let host = Tensor::from_slice(ids, &shape);
    host.to_device(device)
}
/// Pack `batch` equal-length sequences into one flat `[batch*seq]` I32 tensor on
/// `device`, sequence after sequence (the sequence-major layout
/// `forward_batched` expects). Each element of `seqs` is one sequence.
///
/// # Panics
/// Panics if `seqs` is empty or if any sequence's length differs from the
/// first one's.
pub fn batched_ids_tensor(seqs: &[Vec<i32>], device: Device) -> Tensor {
    assert!(!seqs.is_empty(), "empty batch");
    let seq = seqs[0].len();
    // Validate every row before copying anything; the first offender panics.
    for row in seqs {
        assert_eq!(row.len(), seq, "ragged batch: sequences must be equal length");
    }
    let flat = seqs.concat();
    Tensor::from_slice(&flat, &[flat.len()]).to_device(device)
}