Files
xtrain/crates/xtrain-train/tests/grad_accum.rs
2026-06-17 23:49:04 +08:00

295 lines
9.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// T16 gradient-accumulation correctness gates.
//
// Gradient accumulation is mathematically EXACT: accumulating the grads of N
// micro-batches of B sequences (each micro-loss scaled by 1/N before backward,
// the tape SUM-accumulating) equals a single step over one N·B-sequence batch.
// This file makes that a closed loop on-GPU, plus the accum_steps=1 bit-identity
// regression guard.
//
// 1. accum_equiv_big_batch: same init, same N·B sequences in the same order.
// Path A = ONE batched loss over all N·B (the big-batch baseline). Path B =
// N micro-backwards of B each, scale(1/N), tape SUM. Assert loss and EVERY
// parameter grad match within fp tolerance (only the summation order differs,
// like the T8 DDP-vs-single-GPU and T13 recompute gates).
// 2. accum1_bit_identical: accum_steps=1 must reproduce the no-accum path
// bit-for-bit (the implementation skips the ×1/1 scale entirely) — every
// parameter grad max|Δ| == 0.0.
// 3. accum_train_converges: drive the real `train()` loop with accum and assert
// the per-step effective-batch loss trace tracks a big-batch baseline (errors
// stay bounded over many AdamW steps, not just one).
#![cfg(not(no_cuda))]
use xtrain_autodiff::ops;
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
use xtrain_tensor::Device;
use xtrain_train::data::Corpus;
use xtrain_train::schedule::LrSchedule;
use xtrain_train::{TrainConfig, train};
/// Produces `n` deterministic pseudo-random `f32` values derived from `seed`.
///
/// The seed is first scrambled with one multiply-add, then a 64-bit linear
/// congruential generator (wrapping arithmetic) is stepped once per element.
/// The top 31 bits of each state are mapped to a unit-interval value, centred
/// on zero and stretched by `2 * scale`, so outputs lie in about
/// `[-scale, scale]`. The same `(n, seed, scale)` always gives the same vector.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // Constants for the one-off seed scramble.
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    // Constants for the per-element LCG step.
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;

    let mut rng = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        rng = rng.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
        // Keep the high 31 bits and normalise by 2^31.
        let unit = (rng >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
/// Constructs a `TinyTransformer` with deterministic initial parameter values.
///
/// The initializer closure is called with each parameter's shape. Every call
/// bumps a running seed (starting from 1, so the first call uses 2) and fills
/// the parameter via `fill`:
/// - rank-1 shapes get small noise (scale 0.02) shifted by +1.0;
/// - every other shape gets noise with scale 0.08.
///
/// Two calls with the same `cfg` feed the initializer the same seed sequence.
fn build(cfg: Config, device: Device) -> TinyTransformer {
    let mut seed = 1u64;
    TinyTransformer::new(cfg, device, |shape| {
        seed = seed.wrapping_add(1);
        let numel: usize = shape.iter().product();
        let is_rank_one = shape.len() == 1;
        if !is_rank_one {
            return fill(numel, seed, 0.08);
        }
        // Rank-1 parameter: centre the noise around 1.0.
        let mut values = fill(numel, seed, 0.02);
        for v in values.iter_mut() {
            *v += 1.0;
        }
        values
    })
}
/// Copies a tensor to the CPU device and returns its contents as an owned
/// `Vec<f32>`.
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
    let on_cpu = t.to_device(Device::Cpu);
    let values: &[f32] = on_cpu.as_slice::<f32>();
    values.to_vec()
}
/// Builds `n` deterministic (sequence, target) pairs for the equivalence tests.
///
/// Both outputs hold `n` rows of `seq` token ids, each reduced modulo `vocab`:
/// - sequences: `(b * 7 + i * 3 + 1) % vocab`
/// - targets:   `(b * 5 + i * 2 + 2) % vocab`
///
/// where `b` is the row index and `i` the position within the row.
fn make_seqs(n: usize, seq: usize, vocab: usize) -> (Vec<Vec<i32>>, Vec<Vec<i32>>) {
    // Shared generator: token(b, i) = (b * row_step + i * col_step + offset) % vocab.
    let table = |row_step: usize, col_step: usize, offset: usize| -> Vec<Vec<i32>> {
        let mut rows = Vec::with_capacity(n);
        for b in 0..n {
            let mut row = Vec::with_capacity(seq);
            for i in 0..seq {
                row.push(((b * row_step + i * col_step + offset) % vocab) as i32);
            }
            rows.push(row);
        }
        rows
    };
    let seqs = table(7, 3, 1);
    let tgts = table(5, 2, 2);
    (seqs, tgts)
}
/// Runs one big-batch forward/backward over all of `seqs` and returns the grads.
///
/// Returns the scalar loss (read back to the host before `backward()` is
/// called) and one host-side gradient vector per entry of `model.params()`,
/// in that order. Panics with "grad" if a parameter has no gradient afterwards.
fn big_batch_grads(
    model: &TinyTransformer,
    device: Device,
    seqs: &[Vec<i32>],
    tgts: &[Vec<i32>],
) -> (f32, Vec<Vec<f32>>) {
    let batch = seqs.len();
    let input_ids = batched_ids_tensor(seqs, device);
    let target_ids = batched_ids_tensor(tgts, device);

    let loss = model.loss_batched(&input_ids, &target_ids, batch);
    let loss_val = host(&loss.value())[0];
    loss.backward();

    let mut grads = Vec::new();
    for param in model.params().iter() {
        grads.push(host(&param.grad().expect("grad")));
    }
    (loss_val, grads)
}
/// Accumulates gradients over `accum` micro-batches of `b` sequences each.
///
/// Micro-batch `m` is the slice `[m * b, (m + 1) * b)` of the flat
/// `seqs`/`tgts`, taken in order; slicing panics if either input is too short.
/// When `scale` is true each micro-loss is multiplied by `1 / accum` before
/// `backward()`. When false the raw loss is back-propagated, which is the
/// accum==1 bit-identity path. Gradients are read only after every
/// micro-backward has run.
///
/// Returns the mean of the raw (unscaled) micro losses and one host-side
/// gradient vector per entry of `model.params()`.
fn accum_grads(
    model: &TinyTransformer,
    device: Device,
    seqs: &[Vec<i32>],
    tgts: &[Vec<i32>],
    accum: usize,
    b: usize,
    scale: bool,
) -> (f32, Vec<Vec<f32>>) {
    let inv_accum = 1.0 / accum as f32;
    let mut raw_loss_total = 0.0f32;
    let mut start = 0usize;
    for _ in 0..accum {
        let end = start + b;
        let micro_seqs = &seqs[start..end];
        let micro_tgts = &tgts[start..end];
        start = end;

        let ids = batched_ids_tensor(micro_seqs, device);
        let tgt = batched_ids_tensor(micro_tgts, device);
        let loss = model.loss_batched(&ids, &tgt, b);
        raw_loss_total += host(&loss.value())[0];

        if scale {
            let scaled = ops::scale(&loss, inv_accum);
            scaled.backward();
        } else {
            // accum==1 bit-identity path: no scale op is applied.
            loss.backward();
        }
    }

    let mut grads = Vec::new();
    for param in model.params().iter() {
        grads.push(host(&param.grad().expect("grad")));
    }
    (raw_loss_total / accum as f32, grads)
}
/// Gate 1: accumulating `accum` micro-batches of `b` sequences (each micro-loss
/// scaled by 1/accum) must match one big batch of `b * accum` sequences, for
/// the loss and for every parameter gradient, within fp tolerance.
///
/// Both models are built by `build` with the same config, so they start from
/// the same deterministic initialization.
#[test]
fn accum_equiv_big_batch() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);
    let mut cfg = Config::tiny();
    cfg.vocab = 16;
    cfg.n_layers = 3;
    let b = 2usize; // micro-batch
    let accum = 4usize; // → effective batch 8
    let seq = 6usize;
    let (seqs, tgts) = make_seqs(b * accum, seq, cfg.vocab);
    // Big-batch baseline (accum_steps=1, batch = b·accum).
    let big = build(cfg, device);
    let (big_loss, big_grads) = big_batch_grads(&big, device, &seqs, &tgts);
    // Accumulated (accum micro-batches of b, scale 1/accum).
    let acc = build(cfg, device);
    let (acc_loss, acc_grads) = accum_grads(&acc, device, &seqs, &tgts, accum, b, true);
    let loss_rel = (big_loss - acc_loss).abs() / big_loss.abs().max(1e-4);
    // `zip` stops at the shorter side, so a missing parameter or a truncated
    // gradient would otherwise go uncompared; check the shapes explicitly.
    assert_eq!(
        big_grads.len(),
        acc_grads.len(),
        "parameter count differs between big-batch and accum paths"
    );
    let mut max_grad_rel = 0.0f32;
    for (i, (bg, ag)) in big_grads.iter().zip(&acc_grads).enumerate() {
        assert_eq!(bg.len(), ag.len(), "param {i}: grad length differs");
        for (x, y) in bg.iter().zip(ag) {
            // `f32::max` returns the non-NaN operand, so a NaN gradient would
            // leave the running maximum untouched and the gate would pass.
            assert!(
                x.is_finite() && y.is_finite(),
                "param {i}: non-finite grad (big {x}, accum {y})"
            );
            max_grad_rel = max_grad_rel.max((x - y).abs() / x.abs().max(1e-3));
        }
    }
    println!(
        "accum=={accum}×b{b} vs big-batch{}: loss {big_loss:.6}/{acc_loss:.6} (rel {loss_rel:.2e}), \
         grad max rel {max_grad_rel:.3e}",
        b * accum
    );
    // fp summation order differs (big batch sums b·accum rows once; accum sums per
    // micro then across micros) → tight fp tol, same convention as T13 recompute.
    assert!(loss_rel < 1e-5, "loss diverged: {loss_rel:.2e}");
    assert!(
        max_grad_rel < 1e-4,
        "accum grads diverged from big batch: {max_grad_rel:.3e}"
    );
}
/// Gate 2: with a single micro-batch and no scale op, the accumulation path
/// must reproduce the plain one-batch backward exactly. Every parameter
/// gradient must have max |Δ| == 0.0.
#[test]
fn accum1_bit_identical() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);
    let mut cfg = Config::tiny();
    cfg.vocab = 16;
    cfg.n_layers = 3;
    let b = 4usize;
    let seq = 6usize;
    let (seqs, tgts) = make_seqs(b, seq, cfg.vocab);
    // No-accum reference: one batched loss + backward (the pre-T16 path).
    let reference = build(cfg, device);
    let (_, ref_grads) = big_batch_grads(&reference, device, &seqs, &tgts);
    // accum_steps=1 path: the loop runs ONE micro-batch and (by design) skips the
    // ×1/1 scale → must be byte-for-byte identical to the reference backward.
    let accum1 = build(cfg, device);
    let (_, a1_grads) = accum_grads(&accum1, device, &seqs, &tgts, 1, b, false);
    // `zip` stops at the shorter side; make sure nothing is left uncompared.
    assert_eq!(
        ref_grads.len(),
        a1_grads.len(),
        "parameter count differs between no-accum and accum_steps=1 paths"
    );
    let mut max_abs = 0.0f32;
    for (i, (r, a)) in ref_grads.iter().zip(&a1_grads).enumerate() {
        assert_eq!(r.len(), a.len(), "param {i}: grad length differs");
        for (x, y) in r.iter().zip(a) {
            // `f32::max` ignores a NaN operand, so non-finite grads would keep
            // `max_abs` at 0.0 and the equality assert below would pass.
            assert!(
                x.is_finite() && y.is_finite(),
                "param {i}: non-finite grad (reference {x}, accum1 {y})"
            );
            max_abs = max_abs.max((x - y).abs());
        }
    }
    println!("accum_steps=1 vs no-accum: grad max |Δ| = {max_abs:.3e}");
    assert_eq!(
        max_abs, 0.0,
        "accum_steps=1 not bit-identical to no-accum: {max_abs:.3e}"
    );
}
/// Builds a self-contained synthetic corpus (no tokenizer / data file needed).
///
/// Token `i` is `(i * 7 + 3) as i32 % vocab as i32`, for `i` in `0..n_tokens`;
/// `vocab_size` is set to `vocab`.
fn synth_corpus(vocab: usize, n_tokens: usize) -> Corpus {
    let modulus = vocab as i32;
    let tokens = (0..n_tokens)
        .map(|pos| {
            let raw = (pos * 7 + 3) as i32;
            raw % modulus
        })
        .collect();
    Corpus {
        tokens,
        vocab_size: vocab,
    }
}
/// Gate 3: drives the real `train()` loop twice from identically built models.
/// One run uses batch 8 with accum 1, the other batch 2 with accum 4. The test
/// asserts that the recorded training-loss traces and the final parameters
/// stay within relative tolerance of each other over all steps.
#[test]
fn accum_train_converges() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);
    let vocab = 64usize;
    let mut cfg = Config::tiny();
    cfg.vocab = vocab;
    cfg.n_layers = 2;
    let corpus = synth_corpus(vocab, 4096);
    let steps = 20usize;
    let seq = 32usize;
    // Same per-step RNG stream + effective batch 8 either way: the big-batch run
    // (accum=1, batch=8) and the accumulated run (accum=4, batch=2) draw the SAME
    // 8 sequences per step in the same order, so the per-step loss/grads — and thus
    // the whole AdamW trajectory — track within fp tolerance.
    let sched = LrSchedule {
        max_lr: 3e-3,
        min_lr: 3e-4,
        warmup: 3,
        total: steps,
    };
    let base = |batch, accum| TrainConfig {
        seq_len: seq,
        batch_size: batch,
        accum_steps: accum,
        steps,
        schedule: sched.clone(),
        weight_decay: 0.1,
        max_grad_norm: 1.0,
        log_every: 1_000_000,
        ckpt_path: None,
        ckpt_every: 0,
        eval_every: 0,
        eval_batches: 0,
        seed: 7,
    };
    let big_model = build(cfg, device);
    let big = train(&big_model, device, &corpus, None, &base(8, 1)).train_losses;
    let acc_model = build(cfg, device);
    let acc = train(&acc_model, device, &corpus, None, &base(2, 4)).train_losses;
    // `zip` stops at the shorter trace; unequal lengths would mean the two
    // runs are not being compared step-for-step.
    assert_eq!(
        big.len(),
        acc.len(),
        "loss traces have different lengths"
    );
    let mut max_rel = 0.0f32;
    for (step, (x, y)) in big.iter().zip(&acc).enumerate() {
        // `f32::max` ignores a NaN operand, so a diverged (NaN/inf) loss would
        // not move `max_rel` and the tolerance assert would pass.
        assert!(
            x.is_finite() && y.is_finite(),
            "step {step}: non-finite loss (big {x}, accum {y})"
        );
        max_rel = max_rel.max((x - y).abs() / x.abs().max(1e-6));
    }
    // Final params should also stay close (errors don't blow up over the run).
    let mut max_pdiff = 0.0f32;
    for (i, (p, q)) in big_model
        .params()
        .iter()
        .zip(&acc_model.params())
        .enumerate()
    {
        for (x, y) in host(&p.value()).iter().zip(host(&q.value())) {
            // Same NaN concern as for the loss trace above.
            assert!(
                x.is_finite() && y.is_finite(),
                "param {i}: non-finite final value (big {x}, accum {y})"
            );
            max_pdiff = max_pdiff.max((x - y).abs() / x.abs().max(1e-6));
        }
    }
    println!(
        "accum(4×2) vs big(8) over {steps} steps: loss[last] {:.6}/{:.6} max_rel {max_rel:.2e}, \
         final param max rel {max_pdiff:.2e}",
        big.last().unwrap(),
        acc.last().unwrap()
    );
    assert!(
        max_rel < 1e-3,
        "accum loss trajectory diverged: {max_rel:.3e}"
    );
    assert!(
        max_pdiff < 1e-2,
        "accum final params diverged: {max_pdiff:.3e}"
    );
}