Merge t16-grad-accum into main
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> # Conflicts: # README.md # docs/evolution.md
This commit is contained in:
@@ -74,6 +74,10 @@ fn main() {
|
||||
// Optimization knobs (mirror bin/train).
|
||||
let steps: usize = flag(&args, "--steps", 100);
|
||||
let batch: usize = flag(&args, "--batch", 16);
|
||||
// Micro-batch gradient accumulation (Phase T16): effective global batch =
|
||||
// accum_steps × batch, all-reducing only at the accumulation boundary. Default
|
||||
// 1 = no accumulation (bit-identical to the pre-T16 DDP path).
|
||||
let accum_steps: usize = flag(&args, "--accum-steps", 1).max(1);
|
||||
let seq_len: usize = flag(&args, "--seq", 64);
|
||||
let max_lr: f32 = flag(&args, "--max-lr", 3e-3);
|
||||
let min_lr: f32 = flag(&args, "--min-lr", max_lr * 0.1);
|
||||
@@ -150,6 +154,7 @@ fn main() {
|
||||
let dcfg = DdpConfig {
|
||||
seq_len,
|
||||
batch_size: batch,
|
||||
accum_steps,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr,
|
||||
@@ -167,8 +172,9 @@ fn main() {
|
||||
};
|
||||
|
||||
println!(
|
||||
"training: {steps} steps, seq {seq_len}, global batch {batch}, lr {max_lr:.1e}→{min_lr:.1e}, \
|
||||
eval every {eval_every}"
|
||||
"training: {steps} steps, seq {seq_len}, global batch {batch} × accum {accum_steps} = \
|
||||
effective global batch {}, lr {max_lr:.1e}→{min_lr:.1e}, eval every {eval_every}",
|
||||
batch * accum_steps
|
||||
);
|
||||
|
||||
if bf16 {
|
||||
|
||||
@@ -35,6 +35,13 @@ pub struct DdpConfig {
|
||||
pub seq_len: usize,
|
||||
/// Global batch size; must be divisible by the world size.
|
||||
pub batch_size: usize,
|
||||
/// Micro-batch gradient accumulation (Phase T16): each optimizer step
|
||||
/// accumulates grads over `accum_steps` micro-batches, giving an EFFECTIVE
|
||||
/// global batch of `accum_steps × batch_size`. The cross-rank all-reduce
|
||||
/// fires ONLY at the accumulation boundary (after the last micro-step) —
|
||||
/// intermediate micro-steps skip the NCCL collective entirely. `1` = no
|
||||
/// accumulation (bit-identical to the pre-T16 DDP path).
|
||||
pub accum_steps: usize,
|
||||
pub steps: usize,
|
||||
pub schedule: LrSchedule,
|
||||
pub weight_decay: f32,
|
||||
@@ -96,6 +103,7 @@ pub fn train_rank(
|
||||
// (sum across ranks, /world) then gives Σ_global/(world·b_local) = Σ_global/
|
||||
// B_global — already the global-batch mean — so the clip pre-scale is 1.0.
|
||||
let batch_local = cfg.batch_size / ctx.world;
|
||||
let accum = cfg.accum_steps.max(1);
|
||||
let start = Instant::now();
|
||||
let mut tokens_seen: u64 = 0;
|
||||
// Rank 0 owns the held-out eval + best-val checkpoint (params are identical
|
||||
@@ -105,36 +113,51 @@ pub fn train_rank(
|
||||
for step in 0..cfg.steps {
|
||||
let lr = cfg.schedule.lr(step);
|
||||
|
||||
// Draw the whole global batch from the shared RNG (same on every rank);
|
||||
// collect only this rank's shard (global index % world == rank) and run it
|
||||
// as ONE batched forward/backward. The union of shards == the single-GPU
|
||||
// batch; each rank's backward yields its local mean (Σ_local / b_local).
|
||||
let mut inputs = Vec::with_capacity(batch_local);
|
||||
let mut targets_v = Vec::with_capacity(batch_local);
|
||||
for i in 0..cfg.batch_size {
|
||||
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
||||
if i % ctx.world == ctx.rank {
|
||||
inputs.push(input);
|
||||
targets_v.push(target);
|
||||
// Accumulate grads over `accum` micro-batches, then ONE optimizer step
|
||||
// (Phase T16). Per micro-batch: draw the whole micro global batch from the
|
||||
// shared RNG (same on every rank), keep only this rank's shard (global index
|
||||
// % world == rank), run it as ONE batched forward/backward. Each micro-loss
|
||||
// is scaled by 1/accum before backward (the tape SUM-accumulates the scaled
|
||||
// grads across the `accum` micro-backwards) so the boundary grad equals a
|
||||
// single step over an `accum × batch_size` global batch. `accum == 1` skips
|
||||
// the scale → bit-identical to the pre-T16 DDP path. The cross-rank
|
||||
// all-reduce fires ONLY after the last micro-step (intermediate micro-steps
|
||||
// are local-only, no NCCL).
|
||||
let mut local_sum = 0.0f32; // Σ over micro of (local_mean · b_local)
|
||||
for _ in 0..accum {
|
||||
let mut inputs = Vec::with_capacity(batch_local);
|
||||
let mut targets_v = Vec::with_capacity(batch_local);
|
||||
for i in 0..cfg.batch_size {
|
||||
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
||||
if i % ctx.world == ctx.rank {
|
||||
inputs.push(input);
|
||||
targets_v.push(target);
|
||||
}
|
||||
}
|
||||
let ids = batched_ids_tensor(&inputs, device);
|
||||
let targets = batched_ids_tensor(&targets_v, device);
|
||||
let loss = model.loss_batched(&ids, &targets, batch_local);
|
||||
local_sum += read_scalar(&loss) * batch_local as f32; // local mean·b_local
|
||||
if accum == 1 {
|
||||
loss.backward();
|
||||
} else {
|
||||
xtrain_autodiff::ops::scale(&loss, 1.0 / accum as f32).backward();
|
||||
}
|
||||
tokens_seen += (batch_local * cfg.seq_len) as u64;
|
||||
}
|
||||
let ids = batched_ids_tensor(&inputs, device);
|
||||
let targets = batched_ids_tensor(&targets_v, device);
|
||||
let loss = model.loss_batched(&ids, &targets, batch_local);
|
||||
let local_mean = read_scalar(&loss); // Σ_local / b_local
|
||||
loss.backward();
|
||||
tokens_seen += (batch_local * cfg.seq_len) as u64;
|
||||
|
||||
// AllReduce(sum) + /world the grads → every rank holds Σ_global/B_global
|
||||
// (local means summed over ranks, /world = global mean). See note above.
|
||||
// Accumulation boundary: ONE AllReduce(sum) + /world over the accumulated
|
||||
// grads → every rank holds the effective-batch (accum·B_global) mean grad
|
||||
// (the per-micro 1/accum scaling is already baked into each backward; the
|
||||
// /world here is orthogonal to accum). Intermediate micro-steps issued NO
|
||||
// NCCL — only this single boundary collective per optimizer step.
|
||||
ctx.all_reduce_average_grads(¶ms);
|
||||
// Reported loss = global mean: sum the per-rank local sums (= mean·b_local)
|
||||
// across ranks, /B_global. With equal b_local this is mean over ranks.
|
||||
let step_loss =
|
||||
all_reduce_loss(ctx, local_mean * batch_local as f32) / cfg.batch_size as f32;
|
||||
// Reported loss = effective-batch mean: AllReduce(sum) the per-rank local
|
||||
// sums across ranks, /(accum·B_global).
|
||||
let step_loss = all_reduce_loss(ctx, local_sum) / (accum * cfg.batch_size) as f32;
|
||||
losses.push(step_loss);
|
||||
|
||||
// Grads are already the global-batch mean — just clip (pre-scale 1.0).
|
||||
// Grads are already the effective-batch mean — just clip (pre-scale 1.0).
|
||||
let gnorm = clip_grad_norm_gpu(¶ms, cfg.max_grad_norm, 1.0);
|
||||
opt.step(lr, ¶ms);
|
||||
for p in ¶ms {
|
||||
|
||||
@@ -94,6 +94,7 @@ fn ddp_matches_single_gpu_and_params_consistent() {
|
||||
let dcfg = DdpConfig {
|
||||
seq_len: 32,
|
||||
batch_size: 8, // global; 4 per rank with world=2
|
||||
accum_steps: 1,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr: 3e-3,
|
||||
@@ -195,6 +196,127 @@ fn ddp_matches_single_gpu_and_params_consistent() {
|
||||
assert!(max_sdiff < 1e-2, "DDP params diverged from single-GPU");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ddp_with_accum_matches_single_gpu_big_batch() {
|
||||
// T16: DDP + gradient accumulation must match a single-GPU big-batch baseline
|
||||
// of the SAME effective batch. world=2, accum=2, per-rank micro-batch 2 →
|
||||
// effective global batch = world·accum·b_local = 2·2·2 = 8. Compared against a
|
||||
// single-GPU run with batch 8, accum 1 (the big-batch baseline). The all-reduce
|
||||
// fires only at the accumulation boundary (once per optimizer step, not per
|
||||
// micro-step) — enforced by the train_rank implementation; the load-bearing
|
||||
// gate here is that loss + final params still match the big-batch baseline.
|
||||
let world = 2usize;
|
||||
if device::device_count().unwrap_or(0) < world as i32 {
|
||||
eprintln!("skip: need >= {world} GPUs");
|
||||
return;
|
||||
}
|
||||
|
||||
let vocab = 64usize;
|
||||
let cfg = test_config(vocab);
|
||||
let corpus = synth_corpus(vocab, 4096);
|
||||
let steps = 20usize;
|
||||
let effective_batch = 8usize; // world(2) · accum(2) · b_local(2)
|
||||
let sched = LrSchedule {
|
||||
max_lr: 3e-3,
|
||||
min_lr: 3e-4,
|
||||
warmup: 3,
|
||||
total: steps,
|
||||
};
|
||||
|
||||
// Single-GPU big-batch baseline: world=1, accum=1, batch = effective_batch.
|
||||
let baseline_cfg = DdpConfig {
|
||||
seq_len: 32,
|
||||
batch_size: effective_batch,
|
||||
accum_steps: 1,
|
||||
steps,
|
||||
schedule: sched,
|
||||
weight_decay: 0.1,
|
||||
max_grad_norm: 1.0,
|
||||
log_every: 1_000_000,
|
||||
seed: 7,
|
||||
eval_every: 0,
|
||||
eval_batches: 0,
|
||||
ckpt_path: None,
|
||||
};
|
||||
let (single_losses, single_params) = run_single_gpu(cfg, &corpus, &baseline_cfg);
|
||||
|
||||
// DDP + accumulation: world=2, accum=2 → per-rank micro-batch = batch/world = 2.
|
||||
let ddp_cfg = DdpConfig {
|
||||
batch_size: effective_batch / 2, // per-step global batch; ×accum = effective
|
||||
accum_steps: 2,
|
||||
..baseline_cfg
|
||||
};
|
||||
let devices = [0u32, 1u32];
|
||||
let id = get_unique_id();
|
||||
let results: Vec<(Vec<f32>, Vec<Vec<f32>>)> = std::thread::scope(|s| {
|
||||
let handles: Vec<_> = devices
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(rank, &dev)| {
|
||||
let ddp_cfg = ddp_cfg.clone();
|
||||
let corpus = &corpus;
|
||||
s.spawn(move || {
|
||||
let ctx = DdpContext::init(rank, world, id, dev);
|
||||
let device = Device::Cuda(dev);
|
||||
let model = build_model(cfg, device);
|
||||
let res = train_rank(&ctx, &model, device, corpus, None, &ddp_cfg);
|
||||
let host = model
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| p.value().to_device(Device::Cpu).as_slice::<f32>().to_vec())
|
||||
.collect::<Vec<_>>();
|
||||
(res.losses, host)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
handles.into_iter().map(|h| h.join().unwrap()).collect()
|
||||
});
|
||||
|
||||
let (ddp_losses, ddp_p0) = &results[0];
|
||||
let (_, ddp_p1) = &results[1];
|
||||
|
||||
// (a) Loss trajectory matches the single-GPU big-batch baseline.
|
||||
let mut max_rel = 0.0f32;
|
||||
for (s, d) in single_losses.iter().zip(ddp_losses) {
|
||||
max_rel = max_rel.max((s - d).abs() / s.abs().max(1e-6));
|
||||
}
|
||||
println!(
|
||||
"DDP+accum(w2·a2·b2) vs single-GPU big-batch(8): single[last]={:.6} ddp[last]={:.6} max_rel={max_rel:.2e}",
|
||||
single_losses.last().unwrap(),
|
||||
ddp_losses.last().unwrap()
|
||||
);
|
||||
assert!(
|
||||
max_rel < 1e-3,
|
||||
"DDP+accum loss diverged from big-batch baseline: {max_rel:.3e}"
|
||||
);
|
||||
|
||||
// (b) Cross-rank parameter agreement (same KI-5 ULP tolerance as the base test).
|
||||
let mut max_pdiff = 0.0f32;
|
||||
for (a, b) in ddp_p0.iter().zip(ddp_p1) {
|
||||
for (x, y) in a.iter().zip(b) {
|
||||
max_pdiff = max_pdiff.max((x - y).abs());
|
||||
}
|
||||
}
|
||||
println!("DDP+accum cross-rank max |param diff| = {max_pdiff:.3e}");
|
||||
assert!(
|
||||
max_pdiff < 1e-6,
|
||||
"ranks' params drifted apart: {max_pdiff:.3e}"
|
||||
);
|
||||
|
||||
// (c) Final params match single-GPU big-batch within fp tolerance.
|
||||
let mut max_sdiff = 0.0f32;
|
||||
for (a, b) in ddp_p0.iter().zip(&single_params) {
|
||||
for (x, y) in a.iter().zip(b) {
|
||||
max_sdiff = max_sdiff.max((x - y).abs() / y.abs().max(1e-6));
|
||||
}
|
||||
}
|
||||
println!("DDP+accum vs single-GPU big-batch max rel |param diff| = {max_sdiff:.3e}");
|
||||
assert!(
|
||||
max_sdiff < 1e-2,
|
||||
"DDP+accum params diverged from big-batch baseline"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ddp_throughput_scaling() {
|
||||
let max_gpus = device::device_count().unwrap_or(0) as usize;
|
||||
@@ -230,6 +352,7 @@ fn ddp_throughput_scaling() {
|
||||
let dcfg = DdpConfig {
|
||||
seq_len,
|
||||
batch_size: per_gpu_batch * world,
|
||||
accum_steps: 1,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr: 1e-3,
|
||||
|
||||
@@ -101,6 +101,10 @@ fn main() {
|
||||
// Optimization knobs.
|
||||
let steps: usize = flag(&args, "--steps", 2000);
|
||||
let batch_size: usize = flag(&args, "--batch", 8);
|
||||
// Micro-batch gradient accumulation (Phase T16): effective batch =
|
||||
// accum_steps × batch, at one micro-batch's activation-memory cost. Default 1
|
||||
// = no accumulation (bit-identical to the pre-T16 path).
|
||||
let accum_steps: usize = flag(&args, "--accum-steps", 1).max(1);
|
||||
let seq_len: usize = flag(&args, "--seq", 64);
|
||||
let max_lr: f32 = flag(&args, "--max-lr", 3e-3);
|
||||
let min_lr: f32 = flag(&args, "--min-lr", max_lr * 0.1);
|
||||
@@ -208,6 +212,7 @@ fn main() {
|
||||
let tcfg = TrainConfig {
|
||||
seq_len,
|
||||
batch_size,
|
||||
accum_steps,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr,
|
||||
@@ -226,10 +231,13 @@ fn main() {
|
||||
};
|
||||
|
||||
println!(
|
||||
"training: {} steps, seq {}, batch {}, lr {:.1e}→{:.1e}, eval every {}",
|
||||
"training: {} steps, seq {}, batch {} × accum {} = effective batch {}, \
|
||||
lr {:.1e}→{:.1e}, eval every {}",
|
||||
tcfg.steps,
|
||||
tcfg.seq_len,
|
||||
tcfg.batch_size,
|
||||
tcfg.accum_steps,
|
||||
tcfg.batch_size * tcfg.accum_steps,
|
||||
tcfg.schedule.max_lr,
|
||||
tcfg.schedule.min_lr,
|
||||
tcfg.eval_every
|
||||
|
||||
@@ -27,6 +27,12 @@ use crate::schedule::LrSchedule;
|
||||
pub struct TrainConfig {
|
||||
pub seq_len: usize,
|
||||
pub batch_size: usize,
|
||||
/// Micro-batch gradient accumulation (Phase T16): each optimizer step
|
||||
/// accumulates grads over `accum_steps` micro-batches of `batch_size`
|
||||
/// sequences, giving an EFFECTIVE batch of `accum_steps × batch_size` at the
|
||||
/// activation-memory cost of a single micro-batch. `1` = no accumulation
|
||||
/// (bit-identical to the pre-T16 path).
|
||||
pub accum_steps: usize,
|
||||
pub steps: usize,
|
||||
pub schedule: LrSchedule,
|
||||
pub weight_decay: f32,
|
||||
@@ -74,28 +80,43 @@ pub fn train(
|
||||
// Best-val checkpointing only kicks in when we actually evaluate.
|
||||
let track_best = valid.is_some() && cfg.eval_every > 0;
|
||||
|
||||
let accum = cfg.accum_steps.max(1);
|
||||
for step in 0..cfg.steps {
|
||||
let lr = cfg.schedule.lr(step);
|
||||
|
||||
// Sample `batch_size` sequences and run them as ONE batched forward/
|
||||
// backward. The CE mean over all batch*seq rows is the batch-mean loss, so
|
||||
// backward already yields the batch-mean gradient (clip pre-scale = 1.0).
|
||||
let mut inputs = Vec::with_capacity(cfg.batch_size);
|
||||
let mut targets_v = Vec::with_capacity(cfg.batch_size);
|
||||
for _ in 0..cfg.batch_size {
|
||||
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
||||
inputs.push(input);
|
||||
targets_v.push(target);
|
||||
// Accumulate grads over `accum` micro-batches of `batch_size` sequences,
|
||||
// then take ONE optimizer step (Phase T16). Each micro-batch is ONE batched
|
||||
// forward/backward; its loss is the CE mean over batch*seq rows, so backward
|
||||
// yields that micro-batch's mean grad. To make the SUM over `accum` micro-
|
||||
// batches equal a single step over an `accum × batch` batch, each micro-loss
|
||||
// is scaled by 1/accum before backward (the tape SUM-accumulates the scaled
|
||||
// grads). `accum == 1` skips the scale entirely → bit-identical to pre-T16.
|
||||
let mut step_loss_sum = 0.0f32;
|
||||
for _ in 0..accum {
|
||||
let mut inputs = Vec::with_capacity(cfg.batch_size);
|
||||
let mut targets_v = Vec::with_capacity(cfg.batch_size);
|
||||
for _ in 0..cfg.batch_size {
|
||||
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
||||
inputs.push(input);
|
||||
targets_v.push(target);
|
||||
}
|
||||
let ids = batched_ids_tensor(&inputs, device);
|
||||
let targets = batched_ids_tensor(&targets_v, device);
|
||||
let loss = model.loss_batched(&ids, &targets, cfg.batch_size);
|
||||
step_loss_sum += read_scalar(&loss);
|
||||
if accum == 1 {
|
||||
loss.backward();
|
||||
} else {
|
||||
xtrain_autodiff::ops::scale(&loss, 1.0 / accum as f32).backward();
|
||||
}
|
||||
tokens_seen += (cfg.batch_size * cfg.seq_len) as u64;
|
||||
}
|
||||
let ids = batched_ids_tensor(&inputs, device);
|
||||
let targets = batched_ids_tensor(&targets_v, device);
|
||||
let loss = model.loss_batched(&ids, &targets, cfg.batch_size);
|
||||
let step_loss = read_scalar(&loss);
|
||||
loss.backward();
|
||||
tokens_seen += (cfg.batch_size * cfg.seq_len) as u64;
|
||||
// Reported loss = mean over the effective batch = mean of the raw micro
|
||||
// losses (each is itself a micro-batch mean of equal size).
|
||||
let step_loss = step_loss_sum / accum as f32;
|
||||
losses.push(step_loss);
|
||||
|
||||
// Backward already produced the batch-mean gradient — just clip it.
|
||||
// Backward already produced the effective-batch mean gradient — just clip.
|
||||
let gnorm = clip_grad_norm_gpu(¶ms, cfg.max_grad_norm, 1.0);
|
||||
opt.step(lr, ¶ms);
|
||||
for p in ¶ms {
|
||||
|
||||
294
crates/xtrain-train/tests/grad_accum.rs
Normal file
294
crates/xtrain-train/tests/grad_accum.rs
Normal file
@@ -0,0 +1,294 @@
|
||||
// T16 gradient-accumulation correctness gates.
|
||||
//
|
||||
// Gradient accumulation is mathematically EXACT: accumulating the grads of N
|
||||
// micro-batches of B sequences (each micro-loss scaled by 1/N before backward,
|
||||
// the tape SUM-accumulating) equals a single step over one N·B-sequence batch.
|
||||
// This file makes that a closed loop on-GPU, plus the accum_steps=1 bit-identity
|
||||
// regression guard.
|
||||
//
|
||||
// 1. accum_equiv_big_batch: same init, same N·B sequences in the same order.
|
||||
// Path A = ONE batched loss over all N·B (the big-batch baseline). Path B =
|
||||
// N micro-backwards of B each, scale(1/N), tape SUM. Assert loss and EVERY
|
||||
// parameter grad match within fp tolerance (only the summation order differs,
|
||||
// like the T8 DDP-vs-single-GPU and T13 recompute gates).
|
||||
// 2. accum1_bit_identical: accum_steps=1 must reproduce the no-accum path
|
||||
// bit-for-bit (the implementation skips the ×1/1 scale entirely) — every
|
||||
// parameter grad max|Δ| == 0.0.
|
||||
// 3. accum_train_converges: drive the real `train()` loop with accum and assert
|
||||
// the per-step effective-batch loss trace tracks a big-batch baseline (errors
|
||||
// stay bounded over many AdamW steps, not just one).
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_autodiff::ops;
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
||||
use xtrain_tensor::Device;
|
||||
use xtrain_train::data::Corpus;
|
||||
use xtrain_train::schedule::LrSchedule;
|
||||
use xtrain_train::{TrainConfig, train};
|
||||
|
||||
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn build(cfg: Config, device: Device) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||||
t.to_device(Device::Cpu).as_slice::<f32>().to_vec()
|
||||
}
|
||||
|
||||
// `n` deterministic (seq, target) pairs for the equivalence tests.
|
||||
fn make_seqs(n: usize, seq: usize, vocab: usize) -> (Vec<Vec<i32>>, Vec<Vec<i32>>) {
|
||||
let seqs = (0..n)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 7 + i * 3 + 1) % vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let tgts = (0..n)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 5 + i * 2 + 2) % vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
(seqs, tgts)
|
||||
}
|
||||
|
||||
// Run one big-batch forward/backward over all `seqs` and return the grads.
|
||||
fn big_batch_grads(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
seqs: &[Vec<i32>],
|
||||
tgts: &[Vec<i32>],
|
||||
) -> (f32, Vec<Vec<f32>>) {
|
||||
let n = seqs.len();
|
||||
let ids = batched_ids_tensor(seqs, device);
|
||||
let tgt = batched_ids_tensor(tgts, device);
|
||||
let loss = model.loss_batched(&ids, &tgt, n);
|
||||
let loss_val = host(&loss.value())[0];
|
||||
loss.backward();
|
||||
let grads = model
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().expect("grad")))
|
||||
.collect();
|
||||
(loss_val, grads)
|
||||
}
|
||||
|
||||
// Accumulate over `accum` micro-batches of `b` sequences (drawn in order from the
|
||||
// flat `seqs`/`tgts`), scaling each micro-loss by 1/accum before backward; the
|
||||
// tape SUM-accumulates. Returns the mean of the raw micro losses + accumulated grads.
|
||||
fn accum_grads(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
seqs: &[Vec<i32>],
|
||||
tgts: &[Vec<i32>],
|
||||
accum: usize,
|
||||
b: usize,
|
||||
scale: bool,
|
||||
) -> (f32, Vec<Vec<f32>>) {
|
||||
let mut loss_sum = 0.0f32;
|
||||
for m in 0..accum {
|
||||
let s = &seqs[m * b..(m + 1) * b];
|
||||
let t = &tgts[m * b..(m + 1) * b];
|
||||
let ids = batched_ids_tensor(s, device);
|
||||
let tgt = batched_ids_tensor(t, device);
|
||||
let loss = model.loss_batched(&ids, &tgt, b);
|
||||
loss_sum += host(&loss.value())[0];
|
||||
if scale {
|
||||
ops::scale(&loss, 1.0 / accum as f32).backward();
|
||||
} else {
|
||||
loss.backward(); // accum==1 bit-identity path
|
||||
}
|
||||
}
|
||||
let grads = model
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().expect("grad")))
|
||||
.collect();
|
||||
(loss_sum / accum as f32, grads)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accum_equiv_big_batch() {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 16;
|
||||
cfg.n_layers = 3;
|
||||
let b = 2usize; // micro-batch
|
||||
let accum = 4usize; // → effective batch 8
|
||||
let seq = 6usize;
|
||||
let (seqs, tgts) = make_seqs(b * accum, seq, cfg.vocab);
|
||||
|
||||
// Big-batch baseline (accum_steps=1, batch = b·accum).
|
||||
let big = build(cfg, device);
|
||||
let (big_loss, big_grads) = big_batch_grads(&big, device, &seqs, &tgts);
|
||||
|
||||
// Accumulated (accum micro-batches of b, scale 1/accum).
|
||||
let acc = build(cfg, device);
|
||||
let (acc_loss, acc_grads) = accum_grads(&acc, device, &seqs, &tgts, accum, b, true);
|
||||
|
||||
let loss_rel = (big_loss - acc_loss).abs() / big_loss.abs().max(1e-4);
|
||||
let mut max_grad_rel = 0.0f32;
|
||||
for (bg, ag) in big_grads.iter().zip(&acc_grads) {
|
||||
for (x, y) in bg.iter().zip(ag) {
|
||||
max_grad_rel = max_grad_rel.max((x - y).abs() / x.abs().max(1e-3));
|
||||
}
|
||||
}
|
||||
println!(
|
||||
"accum=={accum}×b{b} vs big-batch{}: loss {big_loss:.6}/{acc_loss:.6} (rel {loss_rel:.2e}), \
|
||||
grad max rel {max_grad_rel:.3e}",
|
||||
b * accum
|
||||
);
|
||||
// fp summation order differs (big batch sums b·accum rows once; accum sums per
|
||||
// micro then across micros) → tight fp tol, same convention as T13 recompute.
|
||||
assert!(loss_rel < 1e-5, "loss diverged: {loss_rel:.2e}");
|
||||
assert!(
|
||||
max_grad_rel < 1e-4,
|
||||
"accum grads diverged from big batch: {max_grad_rel:.3e}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accum1_bit_identical() {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 16;
|
||||
cfg.n_layers = 3;
|
||||
let b = 4usize;
|
||||
let seq = 6usize;
|
||||
let (seqs, tgts) = make_seqs(b, seq, cfg.vocab);
|
||||
|
||||
// No-accum reference: one batched loss + backward (the pre-T16 path).
|
||||
let reference = build(cfg, device);
|
||||
let (_, ref_grads) = big_batch_grads(&reference, device, &seqs, &tgts);
|
||||
|
||||
// accum_steps=1 path: the loop runs ONE micro-batch and (by design) skips the
|
||||
// ×1/1 scale → must be byte-for-byte identical to the reference backward.
|
||||
let accum1 = build(cfg, device);
|
||||
let (_, a1_grads) = accum_grads(&accum1, device, &seqs, &tgts, 1, b, false);
|
||||
|
||||
let mut max_abs = 0.0f32;
|
||||
for (r, a) in ref_grads.iter().zip(&a1_grads) {
|
||||
for (x, y) in r.iter().zip(a) {
|
||||
max_abs = max_abs.max((x - y).abs());
|
||||
}
|
||||
}
|
||||
println!("accum_steps=1 vs no-accum: grad max |Δ| = {max_abs:.3e}");
|
||||
assert_eq!(
|
||||
max_abs, 0.0,
|
||||
"accum_steps=1 not bit-identical to no-accum: {max_abs:.3e}"
|
||||
);
|
||||
}
|
||||
|
||||
// A self-contained synthetic corpus (no tokenizer / data file needed).
|
||||
fn synth_corpus(vocab: usize, n_tokens: usize) -> Corpus {
|
||||
Corpus {
|
||||
tokens: (0..n_tokens)
|
||||
.map(|i| (i * 7 + 3) as i32 % vocab as i32)
|
||||
.collect(),
|
||||
vocab_size: vocab,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accum_train_converges() {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
let vocab = 64usize;
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = vocab;
|
||||
cfg.n_layers = 2;
|
||||
let corpus = synth_corpus(vocab, 4096);
|
||||
let steps = 20usize;
|
||||
let seq = 32usize;
|
||||
|
||||
// Same per-step RNG stream + effective batch 8 either way: the big-batch run
|
||||
// (accum=1, batch=8) and the accumulated run (accum=4, batch=2) draw the SAME
|
||||
// 8 sequences per step in the same order, so the per-step loss/grads — and thus
|
||||
// the whole AdamW trajectory — track within fp tolerance.
|
||||
let sched = LrSchedule {
|
||||
max_lr: 3e-3,
|
||||
min_lr: 3e-4,
|
||||
warmup: 3,
|
||||
total: steps,
|
||||
};
|
||||
let base = |batch, accum| TrainConfig {
|
||||
seq_len: seq,
|
||||
batch_size: batch,
|
||||
accum_steps: accum,
|
||||
steps,
|
||||
schedule: sched.clone(),
|
||||
weight_decay: 0.1,
|
||||
max_grad_norm: 1.0,
|
||||
log_every: 1_000_000,
|
||||
ckpt_path: None,
|
||||
ckpt_every: 0,
|
||||
eval_every: 0,
|
||||
eval_batches: 0,
|
||||
seed: 7,
|
||||
};
|
||||
|
||||
let big_model = build(cfg, device);
|
||||
let big = train(&big_model, device, &corpus, None, &base(8, 1)).train_losses;
|
||||
|
||||
let acc_model = build(cfg, device);
|
||||
let acc = train(&acc_model, device, &corpus, None, &base(2, 4)).train_losses;
|
||||
|
||||
let mut max_rel = 0.0f32;
|
||||
for (x, y) in big.iter().zip(&acc) {
|
||||
max_rel = max_rel.max((x - y).abs() / x.abs().max(1e-6));
|
||||
}
|
||||
// Final params should also stay close (errors don't blow up over the run).
|
||||
let mut max_pdiff = 0.0f32;
|
||||
for (p, q) in big_model.params().iter().zip(&acc_model.params()) {
|
||||
for (x, y) in host(&p.value()).iter().zip(host(&q.value())) {
|
||||
max_pdiff = max_pdiff.max((x - y).abs() / x.abs().max(1e-6));
|
||||
}
|
||||
}
|
||||
println!(
|
||||
"accum(4×2) vs big(8) over {steps} steps: loss[last] {:.6}/{:.6} max_rel {max_rel:.2e}, \
|
||||
final param max rel {max_pdiff:.2e}",
|
||||
big.last().unwrap(),
|
||||
acc.last().unwrap()
|
||||
);
|
||||
assert!(
|
||||
max_rel < 1e-3,
|
||||
"accum loss trajectory diverged: {max_rel:.3e}"
|
||||
);
|
||||
assert!(
|
||||
max_pdiff < 1e-2,
|
||||
"accum final params diverged: {max_pdiff:.3e}"
|
||||
);
|
||||
}
|
||||
@@ -84,6 +84,7 @@ fn trains_on_tinystories() {
|
||||
let tcfg = TrainConfig {
|
||||
seq_len: 64,
|
||||
batch_size: 8,
|
||||
accum_steps: 1,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr: 3e-3,
|
||||
|
||||
Reference in New Issue
Block a user