New crate xtrain-model: a from-scratch decoder built entirely from the
autodiff op set.
- Config (tiny: dim=32, 2 layers, 2 heads, head_dim=16, ffn=64).
- TinyTransformer: embedding -> N x {pre-RMSNorm -> multi-head causal
attention (RoPE, additive causal mask, per-head SDPA) -> residual;
pre-RMSNorm -> SwiGLU MLP -> residual} -> final RMSNorm -> LM head.
x@W weight convention (engine GEMM is plain A@B); dim=n_heads*head_dim.
- params()/zero_grad-able leaves for the optimizer; param_to_host export.
- overfit test: char-level bring-up (embedded text -> vocab -> shifted
targets), minimal hand-written GD (p -= lr*grad) memorises one fixed
batch -> loss ~0 + greedy argmax matches targets. End-to-end fwd+bwd
correctness signal. Gated #![cfg(not(no_cuda))].
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
134 lines
4.6 KiB
Rust
134 lines
4.6 KiB
Rust
// End-to-end acceptance for the Phase T5 tiny transformer: overfit one fixed
// char-level batch with a hand-written gradient-descent step and assert the loss
// collapses toward 0. This is THE signal that the whole fwd+bwd graph (embedding,
// RMSNorm, RoPE, multi-head attention, SwiGLU, LM head, cross-entropy) is wired
// correctly — a single buggy backward would stall the loss.
//
// The optimizer here is deliberately minimal (`p ← p − lr·grad`); AdamW / LR
// schedule / real data are T6. Gated behind `not(no_cuda)` (runs on dash5).
|
||
#![cfg(not(no_cuda))]
|
||
|
||
use xtrain_autodiff::tape::Var;
|
||
use xtrain_cuda::device;
|
||
use xtrain_model::{Config, TinyTransformer, ids_tensor};
|
||
use xtrain_tensor::Device;
|
||
|
||
/// Produces `n` deterministic pseudo-random `f32` values, nominally in
/// `[-scale, scale)`, from a 64-bit linear congruential generator.
///
/// The same `(n, seed, scale)` always yields the same sequence. `seed` is
/// scrambled once before the first draw; every draw then advances the LCG and
/// keeps the top 31 bits of the state.
///
/// Note: the 31-bit draw is converted to `f32`, so a draw within rounding
/// distance of `2^31` can map to exactly `+scale`.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // One-off seed scramble so nearby seeds do not start from nearby states.
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    // Per-draw LCG step.
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;

    let denom = (1u64 << 31) as f32;
    let mut state = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        state = state.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
        // Top 31 bits → unit interval, then recentre and stretch to ±scale.
        let unit = (state >> 33) as f32 / denom;
        out.push((unit - 0.5) * 2.0 * scale);
    }
    out
}
|
||
|
||
/// Test precondition: panics with "no CUDA device" unless `device::device_count`
/// reports at least one device, then calls `device::set_device(0)`.
fn require_gpu() {
    let visible = device::device_count().expect("device count");
    assert!(visible > 0, "no CUDA device");
    device::set_device(0).unwrap();
}
|
||
|
||
// One GD step over every parameter: p ← p − lr·grad, then zero the grad.
|
||
fn gd_step(params: &[Var], lr: f32) {
|
||
for p in params {
|
||
if let Some(g) = p.grad() {
|
||
let updated = p.value().add(&g.scale(-lr));
|
||
p.set_value(updated);
|
||
}
|
||
p.zero_grad();
|
||
}
|
||
}
|
||
|
||
/// Overfits the tiny transformer on one fixed char-level batch and checks that
/// training actually memorises it.
///
/// Flow: build a character vocabulary from a short string, form next-token
/// (input, target) pairs, initialise the model from the deterministic `fill`
/// generator, run 200 hand-written gradient-descent steps at `lr = 0.3`, then
/// assert that (a) the final loss is below 0.05 and below the starting loss,
/// and (b) the greedy argmax over the logits reproduces every target token.
#[test]
fn overfit_tiny_batch() {
    require_gpu();
    let device = Device::Cuda(0);

    // --- Char-level bring-up: tiny embedded text → vocab → (input, target). ---
    let text = "hello tiny transformer world";
    let mut vocab_chars: Vec<char> = text.chars().collect();
    vocab_chars.sort_unstable();
    vocab_chars.dedup();
    let vocab = vocab_chars.len();
    // `vocab_chars` is sorted and deduplicated, so a binary search finds the
    // same unique index a linear scan would.
    let stoi = |c: char| {
        vocab_chars
            .binary_search(&c)
            .expect("every char of `text` is in the vocab") as i32
    };

    let tokens: Vec<i32> = text.chars().map(stoi).collect();
    // Next-token prediction: input = tokens[..n-1], target = tokens[1..].
    let input: Vec<i32> = tokens[..tokens.len() - 1].to_vec();
    let target: Vec<i32> = tokens[1..].to_vec();
    let ids = ids_tensor(&input, device);
    let targets = ids_tensor(&target, device);

    // --- Tiny model with small-scale deterministic init. ---
    let mut cfg = Config::tiny();
    cfg.vocab = vocab;
    let mut seed = 1u64;
    let model = TinyTransformer::new(cfg, device, |shape| {
        // Fresh seed per tensor so no two parameters share an init sequence.
        seed = seed.wrapping_add(1);
        let n: usize = shape.iter().product();
        // RMSNorm gammas ([dim]) init to ~1; everything else small random.
        if shape.len() == 1 {
            fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
        } else {
            fill(n, seed, 0.08)
        }
    });
    let params = model.params();
    println!(
        "overfit: vocab={vocab} seq={} params={}",
        input.len(),
        cfg.num_params()
    );

    // Copy the loss tensor to the host and read its first element.
    let read_loss = |l: &Var| -> f32 { l.value().to_device(Device::Cpu).as_slice::<f32>()[0] };

    let lr = 0.3f32;
    let steps = 200;
    let start = read_loss(&model.loss(&ids, &targets));
    let mut last = start;
    for step in 0..steps {
        let loss = model.loss(&ids, &targets);
        last = read_loss(&loss);
        if step % 20 == 0 || step == steps - 1 {
            println!("step {step:3}: loss = {last:.6}");
        }
        loss.backward();
        gd_step(&params, lr);
    }

    println!("overfit: start loss = {start:.6} → final loss = {last:.6} ({steps} steps)");
    // A correct fwd+bwd memorises this tiny fixed batch: loss → ~0.
    assert!(
        last < 0.05,
        "overfit failed to drive loss to ~0: start {start:.4} final {last:.4}"
    );
    assert!(last < start, "loss did not decrease");

    // Sanity: greedy argmax should reproduce the target sequence after overfit.
    let logits = model.forward(&ids).value().to_device(Device::Cpu);
    let lg = logits.as_slice::<f32>();
    // Assumes the logits are laid out row-major as [seq, vocab], one row per
    // target position — TODO confirm against `TinyTransformer::forward`.
    let correct = target
        .iter()
        .enumerate()
        .filter(|&(r, &t)| {
            let row = &lg[r * vocab..(r + 1) * vocab];
            let argmax = row
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).expect("logits must not be NaN"))
                .expect("vocab is non-empty")
                .0;
            // `argmax < vocab`, which is a handful of characters, so the cast
            // cannot truncate.
            argmax as i32 == t
        })
        .count();
    println!("overfit: greedy match {correct}/{}", target.len());
    assert_eq!(correct, target.len(), "did not memorise the batch");
}
|