295 lines
9.9 KiB
Rust
295 lines
9.9 KiB
Rust
// T16 gradient-accumulation correctness gates.
|
||
//
|
||
// Gradient accumulation is mathematically EXACT: accumulating the grads of N
|
||
// micro-batches of B sequences (each micro-loss scaled by 1/N before backward,
|
||
// the tape SUM-accumulating) equals a single step over one N·B-sequence batch.
|
||
// This file makes that a closed loop on-GPU, plus the accum_steps=1 bit-identity
|
||
// regression guard.
|
||
//
|
||
// 1. accum_equiv_big_batch: same init, same N·B sequences in the same order.
|
||
// Path A = ONE batched loss over all N·B (the big-batch baseline). Path B =
|
||
// N micro-backwards of B each, scale(1/N), tape SUM. Assert loss and EVERY
|
||
// parameter grad match within fp tolerance (only the summation order differs,
|
||
// like the T8 DDP-vs-single-GPU and T13 recompute gates).
|
||
// 2. accum1_bit_identical: accum_steps=1 must reproduce the no-accum path
|
||
// bit-for-bit (the implementation skips the ×1/1 scale entirely) — every
|
||
// parameter grad max|Δ| == 0.0.
|
||
// 3. accum_train_converges: drive the real `train()` loop with accum and assert
|
||
// the per-step effective-batch loss trace tracks a big-batch baseline (errors
|
||
// stay bounded over many AdamW steps, not just one).
|
||
#![cfg(not(no_cuda))]
|
||
|
||
use xtrain_autodiff::ops;
|
||
use xtrain_cuda::device;
|
||
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
||
use xtrain_tensor::Device;
|
||
use xtrain_train::data::Corpus;
|
||
use xtrain_train::schedule::LrSchedule;
|
||
use xtrain_train::{TrainConfig, train};
|
||
|
||
// Deterministic pseudo-random fill: returns `n` values in `[-scale, scale]`.
//
// The seed is first scrambled with one multiply-add, then a 64-bit LCG is
// stepped once per element. The top 31 bits of each state are mapped to
// `[0, 1]`, recentred around zero and stretched to the requested `scale`.
// The same `(n, seed, scale)` always yields the same vector.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_ADD: u64 = 1442695040888963407;

    // 2^31: normalises the 31 retained bits to the unit interval.
    let denom = (1u64 << 31) as f32;

    let mut lcg = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(LCG_MUL).wrapping_add(LCG_ADD);
        let unit = (lcg >> 33) as f32 / denom;
        out.push((unit - 0.5) * 2.0 * scale);
    }
    out
}
|
||
|
||
fn build(cfg: Config, device: Device) -> TinyTransformer {
|
||
let mut seed = 1u64;
|
||
TinyTransformer::new(cfg, device, |shape| {
|
||
seed = seed.wrapping_add(1);
|
||
let n: usize = shape.iter().product();
|
||
if shape.len() == 1 {
|
||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||
} else {
|
||
fill(n, seed, 0.08)
|
||
}
|
||
})
|
||
}
|
||
|
||
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||
t.to_device(Device::Cpu).as_slice::<f32>().to_vec()
|
||
}
|
||
|
||
// Builds `n` deterministic (sequence, target) rows of `seq` token ids each,
// every id lying in `0..vocab`, for the equivalence tests.
//
// Both grids follow the same recipe with different strides: the id at
// (row, col) is `(row * row_step + col * col_step + offset) % vocab`, using
// (7, 3, 1) for the inputs and (5, 2, 2) for the targets.
fn make_seqs(n: usize, seq: usize, vocab: usize) -> (Vec<Vec<i32>>, Vec<Vec<i32>>) {
    let grid = |row_step: usize, col_step: usize, offset: usize| -> Vec<Vec<i32>> {
        let mut rows = Vec::with_capacity(n);
        for row in 0..n {
            let mut ids = Vec::with_capacity(seq);
            for col in 0..seq {
                ids.push(((row * row_step + col * col_step + offset) % vocab) as i32);
            }
            rows.push(ids);
        }
        rows
    };

    let inputs = grid(7, 3, 1);
    let targets = grid(5, 2, 2);
    (inputs, targets)
}
|
||
|
||
// Run one big-batch forward/backward over all `seqs` and return the grads.
|
||
fn big_batch_grads(
|
||
model: &TinyTransformer,
|
||
device: Device,
|
||
seqs: &[Vec<i32>],
|
||
tgts: &[Vec<i32>],
|
||
) -> (f32, Vec<Vec<f32>>) {
|
||
let n = seqs.len();
|
||
let ids = batched_ids_tensor(seqs, device);
|
||
let tgt = batched_ids_tensor(tgts, device);
|
||
let loss = model.loss_batched(&ids, &tgt, n);
|
||
let loss_val = host(&loss.value())[0];
|
||
loss.backward();
|
||
let grads = model
|
||
.params()
|
||
.iter()
|
||
.map(|p| host(&p.grad().expect("grad")))
|
||
.collect();
|
||
(loss_val, grads)
|
||
}
|
||
|
||
// Accumulate over `accum` micro-batches of `b` sequences (drawn in order from the
|
||
// flat `seqs`/`tgts`), scaling each micro-loss by 1/accum before backward; the
|
||
// tape SUM-accumulates. Returns the mean of the raw micro losses + accumulated grads.
|
||
fn accum_grads(
|
||
model: &TinyTransformer,
|
||
device: Device,
|
||
seqs: &[Vec<i32>],
|
||
tgts: &[Vec<i32>],
|
||
accum: usize,
|
||
b: usize,
|
||
scale: bool,
|
||
) -> (f32, Vec<Vec<f32>>) {
|
||
let mut loss_sum = 0.0f32;
|
||
for m in 0..accum {
|
||
let s = &seqs[m * b..(m + 1) * b];
|
||
let t = &tgts[m * b..(m + 1) * b];
|
||
let ids = batched_ids_tensor(s, device);
|
||
let tgt = batched_ids_tensor(t, device);
|
||
let loss = model.loss_batched(&ids, &tgt, b);
|
||
loss_sum += host(&loss.value())[0];
|
||
if scale {
|
||
ops::scale(&loss, 1.0 / accum as f32).backward();
|
||
} else {
|
||
loss.backward(); // accum==1 bit-identity path
|
||
}
|
||
}
|
||
let grads = model
|
||
.params()
|
||
.iter()
|
||
.map(|p| host(&p.grad().expect("grad")))
|
||
.collect();
|
||
(loss_sum / accum as f32, grads)
|
||
}
|
||
|
||
// Gate 1: accumulating `accum` micro-batches of `b` rows (each micro-loss scaled
// by 1/accum) must reproduce the loss and every parameter gradient of a single
// big batch of `b * accum` rows, up to floating-point summation order.
#[test]
fn accum_equiv_big_batch() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let mut cfg = Config::tiny();
    cfg.vocab = 16;
    cfg.n_layers = 3;
    let b = 2usize; // micro-batch
    let accum = 4usize; // → effective batch 8
    let seq = 6usize;
    let (seqs, tgts) = make_seqs(b * accum, seq, cfg.vocab);

    // Big-batch baseline (accum_steps=1, batch = b·accum).
    let big = build(cfg, device);
    let (big_loss, big_grads) = big_batch_grads(&big, device, &seqs, &tgts);

    // Accumulated (accum micro-batches of b, scale 1/accum), from an identically
    // initialised second model.
    let acc = build(cfg, device);
    let (acc_loss, acc_grads) = accum_grads(&acc, device, &seqs, &tgts, accum, b, true);

    let loss_rel = (big_loss - acc_loss).abs() / big_loss.abs().max(1e-4);

    // `zip` stops at the shorter side, so a mismatch in parameter count or
    // tensor size would otherwise go unchecked.
    assert_eq!(
        big_grads.len(),
        acc_grads.len(),
        "parameter count differs between big-batch and accum paths"
    );
    let mut max_grad_rel = 0.0f32;
    for (bg, ag) in big_grads.iter().zip(&acc_grads) {
        assert_eq!(
            bg.len(),
            ag.len(),
            "gradient size differs between big-batch and accum paths"
        );
        for (x, y) in bg.iter().zip(ag) {
            let rel = (x - y).abs() / x.abs().max(1e-3);
            // `f32::max` drops a NaN operand, which would let NaN/inf gradients
            // pass the gate; fold NaN into the running max explicitly so the
            // `<` assertion below fails on it.
            if rel.is_nan() || rel > max_grad_rel {
                max_grad_rel = rel;
            }
        }
    }
    println!(
        "accum=={accum}×b{b} vs big-batch{}: loss {big_loss:.6}/{acc_loss:.6} (rel {loss_rel:.2e}), \
         grad max rel {max_grad_rel:.3e}",
        b * accum
    );
    // fp summation order differs (big batch sums b·accum rows once; accum sums per
    // micro then across micros) → tight fp tol, same convention as T13 recompute.
    assert!(loss_rel < 1e-5, "loss diverged: {loss_rel:.2e}");
    assert!(
        max_grad_rel < 1e-4,
        "accum grads diverged from big batch: {max_grad_rel:.3e}"
    );
}
|
||
|
||
// Gate 2: with accum_steps=1 the accumulation path (which skips the ×1/1 scale)
// must produce gradients with max |Δ| == 0.0 against the plain no-accum
// backward.
#[test]
fn accum1_bit_identical() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let mut cfg = Config::tiny();
    cfg.vocab = 16;
    cfg.n_layers = 3;
    let b = 4usize;
    let seq = 6usize;
    let (seqs, tgts) = make_seqs(b, seq, cfg.vocab);

    // No-accum reference: one batched loss + backward (the pre-T16 path).
    let reference = build(cfg, device);
    let (_, ref_grads) = big_batch_grads(&reference, device, &seqs, &tgts);

    // accum_steps=1 path: the loop runs ONE micro-batch and (by design) skips the
    // ×1/1 scale → must be byte-for-byte identical to the reference backward.
    let accum1 = build(cfg, device);
    let (_, a1_grads) = accum_grads(&accum1, device, &seqs, &tgts, 1, b, false);

    // `zip` stops at the shorter side, so check the shapes line up first.
    assert_eq!(
        ref_grads.len(),
        a1_grads.len(),
        "parameter count differs between no-accum and accum_steps=1 paths"
    );
    let mut max_abs = 0.0f32;
    for (r, a) in ref_grads.iter().zip(&a1_grads) {
        assert_eq!(
            r.len(),
            a.len(),
            "gradient size differs between no-accum and accum_steps=1 paths"
        );
        for (x, y) in r.iter().zip(a) {
            let diff = (x - y).abs();
            // `f32::max` drops a NaN operand, so NaN gradients would leave
            // `max_abs` at 0.0 and pass; propagate NaN so the equality below fails.
            if diff.is_nan() || diff > max_abs {
                max_abs = diff;
            }
        }
    }
    println!("accum_steps=1 vs no-accum: grad max |Δ| = {max_abs:.3e}");
    assert_eq!(
        max_abs, 0.0,
        "accum_steps=1 not bit-identical to no-accum: {max_abs:.3e}"
    );
}
|
||
|
||
// A self-contained synthetic corpus (no tokenizer / data file needed).
|
||
fn synth_corpus(vocab: usize, n_tokens: usize) -> Corpus {
|
||
Corpus {
|
||
tokens: (0..n_tokens)
|
||
.map(|i| (i * 7 + 3) as i32 % vocab as i32)
|
||
.collect(),
|
||
vocab_size: vocab,
|
||
}
|
||
}
|
||
|
||
// Gate 3: drive the real `train()` loop twice from identical initial weights —
// once with batch 8 / accum 1, once with batch 2 / accum 4 — and require the
// per-step loss traces and the final parameters to stay close over the whole
// run, not just a single step.
#[test]
fn accum_train_converges() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let vocab = 64usize;
    let mut cfg = Config::tiny();
    cfg.vocab = vocab;
    cfg.n_layers = 2;
    let corpus = synth_corpus(vocab, 4096);
    let steps = 20usize;
    let seq = 32usize;

    // Same per-step RNG stream + effective batch 8 either way: the big-batch run
    // (accum=1, batch=8) and the accumulated run (accum=4, batch=2) draw the SAME
    // 8 sequences per step in the same order, so the per-step loss/grads — and thus
    // the whole AdamW trajectory — track within fp tolerance.
    let sched = LrSchedule {
        max_lr: 3e-3,
        min_lr: 3e-4,
        warmup: 3,
        total: steps,
    };
    // Both runs share every setting except the batch/accum split.
    let base = |batch, accum| TrainConfig {
        seq_len: seq,
        batch_size: batch,
        accum_steps: accum,
        steps,
        schedule: sched.clone(),
        weight_decay: 0.1,
        max_grad_norm: 1.0,
        log_every: 1_000_000,
        ckpt_path: None,
        ckpt_every: 0,
        eval_every: 0,
        eval_batches: 0,
        seed: 7,
    };

    let big_model = build(cfg, device);
    let big = train(&big_model, device, &corpus, None, &base(8, 1)).train_losses;

    let acc_model = build(cfg, device);
    let acc = train(&acc_model, device, &corpus, None, &base(2, 4)).train_losses;

    // `zip` stops at the shorter trace, so a run that recorded fewer (or more)
    // losses would otherwise be compared only on the common prefix.
    assert_eq!(
        big.len(),
        acc.len(),
        "loss traces differ in length between big-batch and accum runs"
    );
    let mut max_rel = 0.0f32;
    for (x, y) in big.iter().zip(&acc) {
        let rel = (x - y).abs() / x.abs().max(1e-6);
        // `f32::max` drops a NaN operand; fold NaN in explicitly so a NaN loss
        // fails the `<` assertion below instead of being skipped.
        if rel.is_nan() || rel > max_rel {
            max_rel = rel;
        }
    }
    // Final params should also stay close (errors don't blow up over the run).
    let mut max_pdiff = 0.0f32;
    for (p, q) in big_model.params().iter().zip(&acc_model.params()) {
        let big_vals = host(&p.value());
        let acc_vals = host(&q.value());
        assert_eq!(
            big_vals.len(),
            acc_vals.len(),
            "parameter size differs between big-batch and accum models"
        );
        for (x, y) in big_vals.iter().zip(&acc_vals) {
            let rel = (x - y).abs() / x.abs().max(1e-6);
            // Same NaN handling as for the loss trace above.
            if rel.is_nan() || rel > max_pdiff {
                max_pdiff = rel;
            }
        }
    }
    println!(
        "accum(4×2) vs big(8) over {steps} steps: loss[last] {:.6}/{:.6} max_rel {max_rel:.2e}, \
         final param max rel {max_pdiff:.2e}",
        big.last().unwrap(),
        acc.last().unwrap()
    );
    assert!(
        max_rel < 1e-3,
        "accum loss trajectory diverged: {max_rel:.3e}"
    );
    assert!(
        max_pdiff < 1e-2,
        "accum final params diverged: {max_pdiff:.3e}"
    );
}
|