Files
xtrain/crates/xtrain-model/tests/overfit.rs
Gahow Wang e3912c2380 model: tiny RoPE+RMSNorm+SwiGLU transformer + overfit test
New crate xtrain-model: a from-scratch decoder built entirely from the
autodiff op set.
- Config (tiny: dim=32, 2 layers, 2 heads, head_dim=16, ffn=64).
- TinyTransformer: embedding -> N x {pre-RMSNorm -> multi-head causal
  attention (RoPE, additive causal mask, per-head SDPA) -> residual;
  pre-RMSNorm -> SwiGLU MLP -> residual} -> final RMSNorm -> LM head.
  x@W weight convention (engine GEMM is plain A@B); dim=n_heads*head_dim.
- params()/zero_grad-able leaves for the optimizer; param_to_host export.
- overfit test: char-level bring-up (embedded text -> vocab -> shifted
  targets), minimal hand-written GD (p -= lr*grad) memorises one fixed
  batch -> loss ~0 + greedy argmax matches targets. End-to-end fwd+bwd
  correctness signal. Gated #![cfg(not(no_cuda))].

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 16:05:20 +08:00

134 lines
4.6 KiB
Rust
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// End-to-end acceptance for the Phase T5 tiny transformer: overfit one fixed
// char-level batch with a hand-written gradient-descent step and assert the loss
// collapses toward 0. This is THE signal that the whole fwd+bwd graph (embedding,
// RMSNorm, RoPE, multi-head attention, SwiGLU, LM head, cross-entropy) is wired
// correctly — a single buggy backward would stall the loss.
//
// The optimizer here is deliberately minimal (`p ← p - lr·grad`); AdamW / LR
// schedule / real data are T6. Gated behind `not(no_cuda)` (runs on dash5).
#![cfg(not(no_cuda))]
use xtrain_autodiff::tape::Var;
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, ids_tensor};
use xtrain_tensor::Device;
// Deterministic pseudo-random fill: `n` values in [-scale, scale], driven by a
// 64-bit LCG seeded from `seed`. The top 31 bits of each state are mapped to
// [0, 1] through f32. That conversion can round the largest draws up to exactly
// 1.0, so the upper bound `scale` is reachable (the range is closed, not half-open).
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_INC: u64 = 1442695040888963407;
    let denom = (1u64 << 31) as f32;
    // Scramble the seed once so that nearby seeds start from distant states.
    let mut s = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        s = s.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
        let unit = (s >> 33) as f32 / denom;
        out.push((unit - 0.5) * 2.0 * scale);
    }
    out
}
// Fail fast unless at least one CUDA device is visible, then bind device 0.
// Panics if the device count query fails, if no device is found, or if
// selecting device 0 fails.
fn require_gpu() {
    let visible = device::device_count().expect("device count");
    assert!(visible > 0, "no CUDA device");
    device::set_device(0).unwrap();
}
// Plain gradient descent over every parameter: `p ← p - lr·grad` for each
// parameter that currently holds a gradient. `zero_grad` is then called on
// every parameter, including ones that had no gradient.
fn gd_step(params: &[Var], lr: f32) {
    let neg_lr = -lr;
    params.iter().for_each(|param| {
        if let Some(grad) = param.grad() {
            let stepped = param.value().add(&grad.scale(neg_lr));
            param.set_value(stepped);
        }
        param.zero_grad();
    });
}
#[test]
fn overfit_tiny_batch() {
    require_gpu();
    let device = Device::Cuda(0);

    // --- Char-level bring-up: tiny embedded text → vocab → (input, target). ---
    let text = "hello tiny transformer world";
    let mut vocab_chars: Vec<char> = text.chars().collect();
    vocab_chars.sort_unstable();
    vocab_chars.dedup();
    let vocab = vocab_chars.len();
    // Every char of `text` is in `vocab_chars` by construction, so the lookup cannot fail.
    let stoi = |c: char| vocab_chars.iter().position(|&x| x == c).unwrap() as i32;
    let tokens: Vec<i32> = text.chars().map(stoi).collect();
    // Next-token prediction: input = tokens[..n-1], target = tokens[1..].
    let input: Vec<i32> = tokens[..tokens.len() - 1].to_vec();
    let target: Vec<i32> = tokens[1..].to_vec();
    let ids = ids_tensor(&input, device);
    let targets = ids_tensor(&target, device);

    // --- Tiny model with small-scale deterministic init. ---
    let mut cfg = Config::tiny();
    cfg.vocab = vocab;
    let mut seed = 1u64;
    let model = TinyTransformer::new(cfg, device, |shape| {
        seed = seed.wrapping_add(1);
        let n: usize = shape.iter().product();
        // 1-D parameters are treated as RMSNorm gammas and initialised near 1.
        // Everything else gets small random values.
        if shape.len() == 1 {
            fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
        } else {
            fill(n, seed, 0.08)
        }
    });
    let params = model.params();
    println!(
        "overfit: vocab={vocab} seq={} params={}",
        input.len(),
        cfg.num_params()
    );

    // Copy the scalar loss back to the host.
    let read_loss = |l: &Var| -> f32 { l.value().to_device(Device::Cpu).as_slice::<f32>()[0] };
    let lr = 0.3f32;
    let steps = 200;
    // Baseline loss from one extra forward pass. Its graph is never
    // back-propagated.
    let start = read_loss(&model.loss(&ids, &targets));
    // `last` holds the loss measured at the start of the most recent step,
    // i.e. before that step's parameter update.
    let mut last = start;
    for step in 0..steps {
        let loss = model.loss(&ids, &targets);
        last = read_loss(&loss);
        if step % 20 == 0 || step == steps - 1 {
            println!("step {step:3}: loss = {last:.6}");
        }
        loss.backward();
        gd_step(&params, lr);
    }
    println!("overfit: start loss = {start:.6} → final loss = {last:.6} ({steps} steps)");
    // A correct fwd+bwd memorises this tiny fixed batch: loss → ~0.
    assert!(
        last < 0.05,
        "overfit failed to drive loss to ~0: start {start:.4} final {last:.4}"
    );
    assert!(last < start, "loss did not decrease");

    // Sanity: greedy argmax should reproduce the target sequence after overfit.
    // This reads the logits as a flat row-major [seq, vocab] buffer, one row of
    // `vocab` logits per input position. Check that shape explicitly so a layout
    // mismatch fails with a clear message instead of an index panic.
    let logits = model.forward(&ids).value().to_device(Device::Cpu);
    let lg = logits.as_slice::<f32>();
    assert_eq!(
        lg.len(),
        target.len() * vocab,
        "logits are not a flat [seq={}, vocab={vocab}] buffer",
        target.len()
    );
    assert!(!lg.iter().any(|v| v.is_nan()), "NaN logits after training");
    let mut correct = 0usize;
    for (row, &t) in lg.chunks_exact(vocab).zip(&target) {
        // `total_cmp` is a total order, so this never panics (unlike
        // `partial_cmp(..).unwrap()`). `max_by` resolves ties to the last maximum.
        let argmax = row
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i as i32);
        if argmax == Some(t) {
            correct += 1;
        }
    }
    println!("overfit: greedy match {correct}/{}", target.len());
    assert_eq!(correct, target.len(), "did not memorise the batch");
}