Files
xtrain/crates/xtrain-distributed/tests/ddp_correctness.rs
Gahow Wang abe5ceb913 test: grad-accum equivalence + accum=1 bit-identity + DDP+accum
- grad_accum.rs: accum=N×B grads bit-close to a single N·B big batch;
  accum_steps=1 bit-identical (max|Δ|==0) to no-accum; real train() loop
  with accum tracks a big-batch baseline over 20 AdamW steps.
- ddp_correctness.rs: world=2 + accum=2 matches a single-GPU big batch of
  the same effective size (loss + cross-rank + vs-baseline).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 23:45:40 +08:00

389 lines
14 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

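//! DDP acceptance (Phase T8). Gated to a GPU host: the two correctness tests
//! skip when fewer than 2 GPUs are visible, the throughput test when there is none.
//!
//! 1. **Correctness**: K steps single-GPU (world=1, global batch B) vs 2-rank DDP
//! (B/2 of the SAME data in the same order each) → loss trajectories match
//! within tight fp tolerance (it's just gradient averaging), and the two
//! ranks' parameters agree to within 1e-6 after the run (not bit-identical:
//! the all-reduce is not bit-reproducible across ranks, see KI-5 below).
//! 2. **DDP + gradient accumulation** (T16): world=2 with accum=2 vs a single-GPU
//! big batch of the same effective size → the same three checks.
//! 3. **Throughput**: 1 / 2 / 4 / 8 GPU (as many as are available) global tok/s
//! on the SAME per-GPU workload; near-linear scaling is the expectation but is
//! not asserted. Prints the table (run with `--nocapture`).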
#![cfg(not(no_cuda))]
use std::time::Instant;
use xtrain_cuda::device;
use xtrain_distributed::{DdpConfig, DdpContext, build_model, get_unique_id, launch, train_rank};
use xtrain_model::{Config, batched_ids_tensor};
use xtrain_optim::GpuAdamW;
use xtrain_tensor::Device;
use xtrain_train::clip::clip_grad_norm_gpu;
use xtrain_train::data::Corpus;
use xtrain_train::schedule::LrSchedule;
// A self-contained synthetic corpus so the test needs no tokenizer/data files.
//
// Token `i` is `(7·i + 3) mod vocab`, so every id lies in `0..vocab`. The
// reduction is done in `usize` BEFORE narrowing to `i32`: casting `7·i + 3`
// first would wrap to a negative id once it exceeds `i32::MAX`.
//
// Panics (remainder by zero) if `vocab == 0` and `n_tokens > 0`.
fn synth_corpus(vocab: usize, n_tokens: usize) -> Corpus {
    let tokens: Vec<i32> = (0..n_tokens)
        // The reduced value is < vocab, so the narrowing cast is lossless for
        // any vocabulary size that fits in an i32.
        .map(|i| ((i * 7 + 3) % vocab) as i32)
        .collect();
    Corpus {
        tokens,
        vocab_size: vocab,
    }
}
// Model configuration shared by every test in this file: `Config::tiny()` with
// the vocabulary size replaced by `vocab` and the depth pinned to two layers.
fn test_config(vocab: usize) -> Config {
    let mut tiny = Config::tiny();
    tiny.n_layers = 2;
    tiny.vocab = vocab;
    tiny
}
// Single-GPU reference run on CUDA device 0: one optimizer step per iteration,
// each over a full `dcfg.batch_size` batch evaluated as a single batched
// forward/backward. Returns (per-step loss trace, final parameters copied to host).
fn run_single_gpu(cfg: Config, corpus: &Corpus, dcfg: &DdpConfig) -> (Vec<f32>, Vec<Vec<f32>>) {
    device::set_device(0).unwrap();
    let gpu = Device::Cuda(0);
    let model = build_model(cfg, gpu);
    let params = model.params();
    let mut optimizer = GpuAdamW::new(dcfg.weight_decay);
    let mut rng_state = dcfg.seed;
    let mut loss_trace = Vec::with_capacity(dcfg.steps);

    for step in 0..dcfg.steps {
        let lr = dcfg.schedule.lr(step);

        // Draw the whole global batch up front (one `sample` call per sequence,
        // in order) so it can be run as ONE batched forward/backward, matching
        // the T10 DDP path where backward yields the global-batch mean grad.
        let (inputs, targets): (Vec<_>, Vec<_>) = (0..dcfg.batch_size)
            .map(|_| corpus.sample(dcfg.seq_len, &mut rng_state))
            .unzip();
        let input_ids = batched_ids_tensor(&inputs, gpu);
        let target_ids = batched_ids_tensor(&targets, gpu);

        let loss = model.loss_batched(&input_ids, &target_ids, dcfg.batch_size);
        // Record the loss on the host before running backward.
        let step_loss = loss.value().to_device(Device::Cpu).as_slice::<f32>()[0];
        loss_trace.push(step_loss);

        loss.backward();
        clip_grad_norm_gpu(&params, dcfg.max_grad_norm, 1.0);
        optimizer.step(lr, &params);
        for p in &params {
            p.zero_grad();
        }
    }

    // Copy every parameter back to the host for comparison by the caller.
    let host_params = params
        .iter()
        .map(|p| p.value().to_device(Device::Cpu).as_slice::<f32>().to_vec())
        .collect();
    (loss_trace, host_params)
}
#[test]
fn ddp_matches_single_gpu_and_params_consistent() {
let world = 2usize;
if device::device_count().unwrap_or(0) < world as i32 {
eprintln!("skip: need >= {world} GPUs");
return;
}
let vocab = 64usize;
let cfg = test_config(vocab);
let corpus = synth_corpus(vocab, 4096);
let steps = 20usize;
let dcfg = DdpConfig {
seq_len: 32,
batch_size: 8, // global; 4 per rank with world=2
accum_steps: 1,
steps,
schedule: LrSchedule {
max_lr: 3e-3,
min_lr: 3e-4,
warmup: 3,
total: steps,
},
weight_decay: 0.1,
max_grad_norm: 1.0,
log_every: 1_000_000, // silence per-step logging in the test
seed: 7,
eval_every: 0,
eval_batches: 0,
ckpt_path: None,
};
// Single-GPU baseline (world=1) over the global batch.
let (single_losses, single_params) = run_single_gpu(cfg, &corpus, &dcfg);
// 2-rank DDP over the SAME corpus/config; returns per-rank (losses, params).
let devices = [0u32, 1u32];
let id = get_unique_id();
let results: Vec<(Vec<f32>, Vec<Vec<f32>>)> = std::thread::scope(|s| {
let handles: Vec<_> = devices
.iter()
.enumerate()
.map(|(rank, &dev)| {
let dcfg = dcfg.clone();
let corpus = &corpus;
s.spawn(move || {
let ctx = DdpContext::init(rank, world, id, dev);
let device = Device::Cuda(dev);
let model = build_model(cfg, device);
let res = train_rank(&ctx, &model, device, corpus, None, &dcfg);
let host = model
.params()
.iter()
.map(|p| p.value().to_device(Device::Cpu).as_slice::<f32>().to_vec())
.collect::<Vec<_>>();
(res.losses, host)
})
})
.collect();
handles.into_iter().map(|h| h.join().unwrap()).collect()
});
let (ddp_losses, ddp_p0) = &results[0];
let (_, ddp_p1) = &results[1];
// (a) DDP loss trajectory matches single-GPU within tight tolerance.
let mut max_rel = 0.0f32;
for (s, d) in single_losses.iter().zip(ddp_losses) {
let rel = (s - d).abs() / s.abs().max(1e-6);
max_rel = max_rel.max(rel);
}
println!(
"DDP vs single-GPU loss: single[last]={:.6} ddp[last]={:.6} max_rel={max_rel:.2e}",
single_losses.last().unwrap(),
ddp_losses.last().unwrap()
);
assert!(
max_rel < 1e-3,
"DDP loss trajectory diverged from single-GPU: max_rel {max_rel:.3e}"
);
// (b) Cross-rank parameter identity (same init + same averaged grad + same
// optimizer state ⇒ identical params).
let mut max_pdiff = 0.0f32;
for (a, b) in ddp_p0.iter().zip(ddp_p1) {
for (x, y) in a.iter().zip(b) {
max_pdiff = max_pdiff.max((x - y).abs());
}
}
println!("cross-rank max |param diff| = {max_pdiff:.3e}");
// On this PCIe-only box, NCCL's all-reduce is not bit-reproducible run-to-run
// across ranks (algorithm/chunk choice is unstable), so cross-rank params can
// differ by a few ULP (observed ≤1.2e-7) even with identical init + averaged
// grads. The load-bearing gate is the loss-trajectory match (a, ~5.7e-7); a
// tight tolerance here, not bit-identity, is the honest invariant (KI-5).
assert!(
max_pdiff < 1e-6,
"ranks' params drifted apart: {max_pdiff:.3e}"
);
// (c) DDP final params match single-GPU final params within fp tolerance.
// Looser than (a)/(b): DDP and single-GPU differ only in the gradient SUMMATION
// ORDER (single-GPU sums B sequences in tape order; DDP sums per-rank shards
// then NCCL-sums across ranks). fp addition isn't associative, so that tiny
// per-step rounding compounds over the AdamW steps — a few e-3 relative on
// individual params is expected and benign. The loss-trajectory match (a, ~1e-7)
// and tight cross-rank agreement (b, <1e-6) are the load-bearing checks.
let mut max_sdiff = 0.0f32;
for (a, b) in ddp_p0.iter().zip(&single_params) {
for (x, y) in a.iter().zip(b) {
max_sdiff = max_sdiff.max((x - y).abs() / y.abs().max(1e-6));
}
}
println!("DDP vs single-GPU max rel |param diff| = {max_sdiff:.3e}");
assert!(max_sdiff < 1e-2, "DDP params diverged from single-GPU");
}
#[test]
fn ddp_with_accum_matches_single_gpu_big_batch() {
// T16: DDP + gradient accumulation must match a single-GPU big-batch baseline
// of the SAME effective batch. world=2, accum=2, per-rank micro-batch 2 →
// effective global batch = world·accum·b_local = 2·2·2 = 8. Compared against a
// single-GPU run with batch 8, accum 1 (the big-batch baseline). The all-reduce
// fires only at the accumulation boundary (once per optimizer step, not per
// micro-step) — enforced by the train_rank implementation; the load-bearing
// gate here is that loss + final params still match the big-batch baseline.
let world = 2usize;
if device::device_count().unwrap_or(0) < world as i32 {
eprintln!("skip: need >= {world} GPUs");
return;
}
let vocab = 64usize;
let cfg = test_config(vocab);
let corpus = synth_corpus(vocab, 4096);
let steps = 20usize;
let effective_batch = 8usize; // world(2) · accum(2) · b_local(2)
let sched = LrSchedule {
max_lr: 3e-3,
min_lr: 3e-4,
warmup: 3,
total: steps,
};
// Single-GPU big-batch baseline: world=1, accum=1, batch = effective_batch.
let baseline_cfg = DdpConfig {
seq_len: 32,
batch_size: effective_batch,
accum_steps: 1,
steps,
schedule: sched,
weight_decay: 0.1,
max_grad_norm: 1.0,
log_every: 1_000_000,
seed: 7,
eval_every: 0,
eval_batches: 0,
ckpt_path: None,
};
let (single_losses, single_params) = run_single_gpu(cfg, &corpus, &baseline_cfg);
// DDP + accumulation: world=2, accum=2 → per-rank micro-batch = batch/world = 2.
let ddp_cfg = DdpConfig {
batch_size: effective_batch / 2, // per-step global batch; ×accum = effective
accum_steps: 2,
..baseline_cfg
};
let devices = [0u32, 1u32];
let id = get_unique_id();
let results: Vec<(Vec<f32>, Vec<Vec<f32>>)> = std::thread::scope(|s| {
let handles: Vec<_> = devices
.iter()
.enumerate()
.map(|(rank, &dev)| {
let ddp_cfg = ddp_cfg.clone();
let corpus = &corpus;
s.spawn(move || {
let ctx = DdpContext::init(rank, world, id, dev);
let device = Device::Cuda(dev);
let model = build_model(cfg, device);
let res = train_rank(&ctx, &model, device, corpus, None, &ddp_cfg);
let host = model
.params()
.iter()
.map(|p| p.value().to_device(Device::Cpu).as_slice::<f32>().to_vec())
.collect::<Vec<_>>();
(res.losses, host)
})
})
.collect();
handles.into_iter().map(|h| h.join().unwrap()).collect()
});
let (ddp_losses, ddp_p0) = &results[0];
let (_, ddp_p1) = &results[1];
// (a) Loss trajectory matches the single-GPU big-batch baseline.
let mut max_rel = 0.0f32;
for (s, d) in single_losses.iter().zip(ddp_losses) {
max_rel = max_rel.max((s - d).abs() / s.abs().max(1e-6));
}
println!(
"DDP+accum(w2·a2·b2) vs single-GPU big-batch(8): single[last]={:.6} ddp[last]={:.6} max_rel={max_rel:.2e}",
single_losses.last().unwrap(),
ddp_losses.last().unwrap()
);
assert!(
max_rel < 1e-3,
"DDP+accum loss diverged from big-batch baseline: {max_rel:.3e}"
);
// (b) Cross-rank parameter agreement (same KI-5 ULP tolerance as the base test).
let mut max_pdiff = 0.0f32;
for (a, b) in ddp_p0.iter().zip(ddp_p1) {
for (x, y) in a.iter().zip(b) {
max_pdiff = max_pdiff.max((x - y).abs());
}
}
println!("DDP+accum cross-rank max |param diff| = {max_pdiff:.3e}");
assert!(
max_pdiff < 1e-6,
"ranks' params drifted apart: {max_pdiff:.3e}"
);
// (c) Final params match single-GPU big-batch within fp tolerance.
let mut max_sdiff = 0.0f32;
for (a, b) in ddp_p0.iter().zip(&single_params) {
for (x, y) in a.iter().zip(b) {
max_sdiff = max_sdiff.max((x - y).abs() / y.abs().max(1e-6));
}
}
println!("DDP+accum vs single-GPU big-batch max rel |param diff| = {max_sdiff:.3e}");
assert!(
max_sdiff < 1e-2,
"DDP+accum params diverged from big-batch baseline"
);
}
// Throughput table: runs the same per-GPU workload at world sizes 1/2/4/8 (as
// many as there are GPUs) and prints global tok/s plus speedup over world=1.
// Nothing is asserted; run with `--nocapture` to see the table.
#[test]
fn ddp_throughput_scaling() {
    let gpu_count = device::device_count().unwrap_or(0) as usize;
    if gpu_count < 1 {
        eprintln!("skip: no GPU");
        return;
    }

    // Same PER-GPU workload at each world size (batch scales with world), so the
    // per-rank cost is fixed and global tok/s should scale ~linearly. Use enough
    // steps that the one-time NCCL init + model-build overhead (which is larger at
    // world=4 and absent at world=1) amortizes — otherwise the wall-clock ratio
    // understates steady-state scaling.
    let per_gpu_batch = 8usize;
    let seq_len = 64usize;
    let steps = 150usize;
    let vocab = 256usize;
    let cfg = test_config(vocab);
    let corpus = synth_corpus(vocab, 8192);

    println!("\n=== DDP throughput scaling (per-GPU batch {per_gpu_batch}, seq {seq_len}) ===");
    println!(
        "{:>6} | {:>14} | {:>8}",
        "GPUs", "tok/s (global)", "speedup"
    );

    // tok/s of the world=1 run; the denominator of every speedup figure.
    let mut single_gpu_tps = 0.0f64;
    let candidate_worlds = [1usize, 2, 4, 8];
    for world in candidate_worlds.iter().copied().filter(|w| *w <= gpu_count) {
        let devices: Vec<u32> = (0..world as u32).collect();
        let global_batch = per_gpu_batch * world;
        let dcfg = DdpConfig {
            seq_len,
            batch_size: global_batch,
            accum_steps: 1,
            steps,
            // Constant learning rate: max_lr == min_lr.
            schedule: LrSchedule {
                max_lr: 1e-3,
                min_lr: 1e-3,
                warmup: 1,
                total: steps,
            },
            weight_decay: 0.0,
            max_grad_norm: 1.0,
            log_every: 1_000_000,
            seed: 1,
            eval_every: 0,
            eval_batches: 0,
            ckpt_path: None,
        };
        let total_tokens = (steps * global_batch * seq_len) as f64;

        // Wall-clock the whole launch, including per-rank model construction.
        let started = Instant::now();
        let _ = launch(&devices, &corpus, None, &dcfg, move |device| {
            build_model(cfg, device)
        });
        let elapsed_secs = started.elapsed().as_secs_f64();

        let tps = total_tokens / elapsed_secs;
        if world == 1 {
            single_gpu_tps = tps;
        }
        println!(
            "{:>6} | {:>14.0} | {:>7.2}x",
            world,
            tps,
            tps / single_gpu_tps.max(1e-9)
        );
    }
}