Files
xtrain/crates/xtrain-distributed/src/ddp.rs
Gahow Wang 7a03b0054a train+ddp: micro-batch gradient accumulation (--accum-steps)
Accumulate grads over N micro-batches, then one AdamW step + zero_grad,
for an effective batch of N×micro at one micro-batch's activation cost.
Each micro-loss is scaled by 1/N before backward (the tape SUM-accumulates
the scaled grads) so the boundary grad equals a single step over an N×
batch. accum==1 skips the scale → bit-identical to the pre-T16 path.

DDP: the cross-rank all-reduce fires ONLY at the accumulation boundary
(intermediate micro-steps are local-only, no NCCL); the /world average is
orthogonal to the per-micro 1/N, so the boundary grad is the effective
global-batch mean. New --accum-steps flag in both train binaries; effective
batch is printed.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 23:45:33 +08:00

300 lines
13 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! The DDP training step + a single-process, thread-per-GPU launcher (Phase T8).
//!
//! Each rank owns one GPU and one thread. Per optimizer step it processes a
//! DISJOINT shard of every global micro-batch (`accum_steps` of them; one when
//! accumulation is off), all-reduce-averages the accumulated gradients once, then
//! runs its own `GpuAdamW.step`. Identical init + identical optimizer state across ranks keep
//! the parameters consistent — verified by the cross-rank param-identity check in
//! the tests.
//!
//! Sampling matches single-GPU bit-for-bit: every rank advances the SAME RNG and
//! draws all `B_global` sequences of a step, but only runs forward+backward on
//! the ones assigned to it (`global index % world == rank`). The union over ranks
//! is exactly the single-GPU batch in the same order, so the all-reduced grad sum
//! equals the single-GPU summed grad.
use std::path::PathBuf;
use std::thread;
use std::time::Instant;
use xtrain_autodiff::tape::Var;
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
use xtrain_optim::GpuAdamW;
use xtrain_tensor::Device;
use xtrain_train::checkpoint;
use xtrain_train::clip::clip_grad_norm_gpu;
use xtrain_train::data::Corpus;
use xtrain_train::eval_loss;
use xtrain_train::schedule::LrSchedule;
use crate::{DdpContext, get_unique_id};
/// Per-rank DDP training config. `batch_size` is the GLOBAL batch (split across
/// ranks); the rest mirror `xtrain_train::TrainConfig`.
///
/// The struct is `Clone` because the launcher hands each rank thread its own copy.
#[derive(Clone)]
pub struct DdpConfig {
/// Sequence length passed to `Corpus::sample` for every training draw and to
/// `eval_loss` for held-out evaluation.
pub seq_len: usize,
/// Global batch size; must be divisible by the world size.
pub batch_size: usize,
/// Micro-batch gradient accumulation (Phase T16): each optimizer step
/// accumulates grads over `accum_steps` micro-batches, giving an EFFECTIVE
/// global batch of `accum_steps × batch_size`. The cross-rank all-reduce
/// fires ONLY at the accumulation boundary (after the last micro-step) —
/// intermediate micro-steps skip the NCCL collective entirely. `1` = no
/// accumulation (bit-identical to the pre-T16 DDP path). `train_rank` clamps
/// this with `.max(1)`, so `0` behaves like `1`.
pub accum_steps: usize,
/// Number of optimizer steps `train_rank` runs (one loss-trace entry per step).
pub steps: usize,
/// Learning-rate schedule; queried once per optimizer step as `schedule.lr(step)`.
pub schedule: LrSchedule,
/// Weight-decay coefficient handed to `GpuAdamW::new`.
pub weight_decay: f32,
/// Threshold handed to `clip_grad_norm_gpu` after the gradient all-reduce.
pub max_grad_norm: f32,
/// Rank 0 prints a progress line whenever `step % log_every == 0`, and always
/// on the final step.
pub log_every: usize,
/// Initial state of the sampling RNG. Every rank starts from the same value
/// and advances it identically, so all ranks draw the same global batch.
pub seed: u64,
/// Evaluate held-out val loss every `eval_every` steps (0 = never). Only rank
/// 0 holds the `valid` corpus and runs the eval (no grad), mirroring
/// `xtrain_train::TrainConfig`. The best-val model is checkpointed by rank 0
/// (every rank's params are identical, so rank 0's are the model's).
pub eval_every: usize,
/// Batch count forwarded to `eval_loss` for each held-out evaluation.
pub eval_batches: usize,
/// Best-val checkpoint path (written by rank 0 when val improves). When unset,
/// or when `eval_every == 0`, no checkpoint is written.
pub ckpt_path: Option<PathBuf>,
}
/// Outcome of a DDP run on this rank: per-step mean-loss trace plus, when
/// `eval_every > 0`, the (step, val_loss) eval points and the best val loss
/// (eval/best are only populated on rank 0, which owns the `valid` corpus).
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers and tests can print,
/// copy and compare per-rank results directly (all fields are plain std types).
#[derive(Debug, Clone, PartialEq)]
pub struct DdpResult {
    /// Effective-batch mean training loss, one entry per optimizer step.
    pub losses: Vec<f32>,
    /// `(step, val_loss)` pairs, where `step` is the 0-based optimizer step at
    /// which the held-out evaluation ran. Empty on ranks that do not evaluate.
    pub evals: Vec<(usize, f32)>,
    /// Lowest validation loss observed, or `None` if no evaluation ran.
    pub best_val: Option<f32>,
}
/// Run `cfg.steps` DDP optimizer steps on this rank's `model`/`corpus`, using
/// `ctx` for the gradient all-reduce.
///
/// Returns this rank's per-step mean-loss trace (the mean over the EFFECTIVE
/// global batch — every rank computes the same value because the per-rank loss
/// sums are all-reduced) plus eval/best-val (rank 0 only). The optimizer step is
/// identical on every rank, so the parameters stay in lockstep.
///
/// `valid` is the held-out corpus for periodic val-loss eval. Only rank 0 needs
/// it (it runs the no-grad eval and writes the best-val checkpoint); pass `None`
/// on the other ranks (or when `cfg.eval_every == 0`).
///
/// Logging: rank 0 prints a progress line every `cfg.log_every` steps and on the
/// final step. `cfg.log_every == 0` means "final step only" (it used to panic
/// with a remainder-by-zero).
///
/// # Panics
///
/// Panics if `cfg.batch_size` is not divisible by `ctx.world`, or if writing the
/// best-val checkpoint fails.
pub fn train_rank(
    ctx: &DdpContext,
    model: &TinyTransformer,
    device: Device,
    corpus: &Corpus,
    valid: Option<&Corpus>,
    cfg: &DdpConfig,
) -> DdpResult {
    assert_eq!(
        cfg.batch_size % ctx.world,
        0,
        "global batch {} not divisible by world {}",
        cfg.batch_size,
        ctx.world
    );
    let params = model.params();
    let mut opt = GpuAdamW::new(cfg.weight_decay);
    let mut rng = cfg.seed;
    let mut losses = Vec::with_capacity(cfg.steps);
    let mut evals = Vec::new();
    let mut best_val: Option<f32> = None;
    // Each rank runs ONE batched forward over its b_local = batch_size/world
    // sequences → backward grad = local mean (Σ_local / b_local). all_reduce_average
    // (sum across ranks, /world) then gives Σ_global/(world·b_local) = Σ_global/
    // B_global — already the global-batch mean — so the clip pre-scale is 1.0.
    let batch_local = cfg.batch_size / ctx.world;
    // `accum_steps == 0` is treated as "no accumulation".
    let accum = cfg.accum_steps.max(1);
    let start = Instant::now();
    let mut tokens_seen: u64 = 0;
    // Rank 0 owns the held-out eval + best-val checkpoint (params are identical
    // across ranks, so rank 0's are the model). `eval_corpus` is `Some` only when
    // this rank should evaluate, which also guarantees `cfg.eval_every > 0` below.
    let eval_corpus = if ctx.rank == 0 && cfg.eval_every > 0 {
        valid
    } else {
        None
    };
    for step in 0..cfg.steps {
        let is_last_step = step + 1 == cfg.steps;
        let lr = cfg.schedule.lr(step);
        // Accumulate grads over `accum` micro-batches, then ONE optimizer step
        // (Phase T16). Per micro-batch: draw the whole micro global batch from the
        // shared RNG (same on every rank), keep only this rank's shard (global index
        // % world == rank), run it as ONE batched forward/backward. Each micro-loss
        // is scaled by 1/accum before backward (the tape SUM-accumulates the scaled
        // grads across the `accum` micro-backwards) so the boundary grad equals a
        // single step over an `accum × batch_size` global batch. `accum == 1` skips
        // the scale → bit-identical to the pre-T16 DDP path. The cross-rank
        // all-reduce fires ONLY after the last micro-step (intermediate micro-steps
        // are local-only, no NCCL).
        let mut local_sum = 0.0f32; // Σ over micro of (local_mean · b_local)
        for _ in 0..accum {
            let mut inputs = Vec::with_capacity(batch_local);
            let mut targets_v = Vec::with_capacity(batch_local);
            for i in 0..cfg.batch_size {
                // Every rank draws every sample so the RNG streams stay aligned;
                // only this rank's share is kept.
                let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
                if i % ctx.world == ctx.rank {
                    inputs.push(input);
                    targets_v.push(target);
                }
            }
            let ids = batched_ids_tensor(&inputs, device);
            let targets = batched_ids_tensor(&targets_v, device);
            let loss = model.loss_batched(&ids, &targets, batch_local);
            local_sum += read_scalar(&loss) * batch_local as f32; // local mean·b_local
            if accum == 1 {
                loss.backward();
            } else {
                xtrain_autodiff::ops::scale(&loss, 1.0 / accum as f32).backward();
            }
            tokens_seen += (batch_local * cfg.seq_len) as u64;
        }
        // Accumulation boundary: ONE AllReduce(sum) + /world over the accumulated
        // grads → every rank holds the effective-batch (accum·B_global) mean grad
        // (the per-micro 1/accum scaling is already baked into each backward; the
        // /world here is orthogonal to accum). Intermediate micro-steps issued NO
        // NCCL — only this single boundary collective per optimizer step.
        ctx.all_reduce_average_grads(&params);
        // Reported loss = effective-batch mean: AllReduce(sum) the per-rank local
        // sums across ranks, /(accum·B_global).
        let step_loss = all_reduce_loss(ctx, local_sum) / (accum * cfg.batch_size) as f32;
        losses.push(step_loss);
        // Grads are already the effective-batch mean — just clip (pre-scale 1.0).
        let gnorm = clip_grad_norm_gpu(&params, cfg.max_grad_norm, 1.0);
        opt.step(lr, &params);
        for p in &params {
            p.zero_grad();
        }
        // Guard the modulo: `log_every == 0` must not divide by zero. The final
        // step is always reported so a run never ends silently.
        let on_log_interval = cfg.log_every > 0 && step % cfg.log_every == 0;
        if ctx.rank == 0 && (on_log_interval || is_last_step) {
            let elapsed = start.elapsed().as_secs_f32();
            // Global tok/s = per-rank tok/s × world (each rank does 1/world of it).
            let tps = (tokens_seen as f32 / elapsed.max(1e-6)) * ctx.world as f32;
            println!(
                "[rank0] step {step:5}/{}: loss {step_loss:.4} lr {lr:.2e} gnorm {gnorm:.3} \
                 ({tps:.0} tok/s global, {} ranks)",
                cfg.steps, ctx.world
            );
        }
        // Periodic held-out eval + best-val checkpoint (rank 0 only). Mirrors the
        // single-GPU `xtrain_train::train` loop, reusing its `eval_loss` /
        // `checkpoint::save` so single-GPU and DDP share one eval/ckpt path. Other
        // ranks have nothing to do here (params are identical across ranks).
        if let Some(v) = eval_corpus {
            if (step + 1) % cfg.eval_every == 0 || is_last_step {
                let vl = eval_loss(model, device, v, cfg.seq_len, cfg.eval_batches);
                evals.push((step, vl));
                let improved = best_val.map(|b| vl < b).unwrap_or(true);
                println!(
                    " [rank0] eval @ step {step}: val loss {vl:.4}{}",
                    if improved { " (best)" } else { "" }
                );
                if improved {
                    best_val = Some(vl);
                    if let Some(path) = &cfg.ckpt_path {
                        checkpoint::save(path, &params).expect("best checkpoint save");
                    }
                }
            }
        }
    }
    DdpResult {
        losses,
        evals,
        best_val,
    }
}
/// Single-process, thread-per-GPU launcher: start one scoped thread per entry of
/// `devices`, run `train_rank` on each, and return the per-rank `DdpResult`s in
/// rank order (rank = index into `devices`).
///
/// One unique id is obtained via `get_unique_id()` before any thread starts and
/// handed to every rank's `DdpContext::init`. Each thread builds its own model
/// with `make_model(device)` inside the thread (the original notes the `Var`
/// graph is `!Send`), so `make_model` must be deterministic — same params on
/// every rank — for the parameters to stay consistent.
///
/// `valid` is forwarded to rank 0 only; every other rank receives `None`.
///
/// All rank threads are spawned before any is joined. A panic in any rank
/// thread is re-raised here by the `join().unwrap()`.
pub fn launch<F>(
    devices: &[u32],
    corpus: &Corpus,
    valid: Option<&Corpus>,
    cfg: &DdpConfig,
    make_model: F,
) -> Vec<DdpResult>
where
    F: Fn(Device) -> TinyTransformer + Send + Sync,
{
    let world = devices.len();
    let id = get_unique_id();
    thread::scope(|s| {
        // Threads share the factory by reference; `F: Sync` makes that sound.
        let factory = &make_model;

        // Phase 1: start every rank so they can all take part in the collectives.
        let mut workers = Vec::with_capacity(world);
        for (rank, &dev) in devices.iter().enumerate() {
            let rank_cfg = cfg.clone();
            workers.push(s.spawn(move || {
                let ctx = DdpContext::init(rank, world, id, dev);
                let device = Device::Cuda(dev);
                let model = factory(device);
                // Only rank 0 holds the val corpus for eval.
                let held_out = if rank == 0 { valid } else { None };
                train_rank(&ctx, &model, device, corpus, held_out, &rank_cfg)
            }));
        }

        // Phase 2: collect results in rank order.
        let mut results = Vec::with_capacity(world);
        for worker in workers {
            results.push(worker.join().unwrap());
        }
        results
    })
}
/// Sum one host-side `f32` over all ranks and return the total.
///
/// The value is staged in a one-element tensor on this rank's CUDA device,
/// reduced in place through `ctx.all_reduce_sum_f32_ptr`, and copied back to the
/// host after a device synchronize. It is used only for the logged/returned
/// loss, so the cost is one tiny collective per optimizer step. With a single
/// rank the input is returned untouched and no device work is issued.
fn all_reduce_loss(ctx: &DdpContext, local: f32) -> f32 {
    use xtrain_tensor::Tensor;

    if ctx.world != 1 {
        let staged = Tensor::from_slice(&[local], &[1]).to_device(Device::Cuda(ctx.device));
        let raw = staged.data_ptr() as *mut std::ffi::c_void;
        ctx.all_reduce_sum_f32_ptr(raw, 1);
        // Wait for the device before reading the reduced value back.
        xtrain_cuda::device::synchronize().expect("loss all-reduce sync");
        let total = staged.to_device(Device::Cpu).as_slice::<f32>()[0];
        total
    } else {
        local
    }
}
/// Copy `v`'s value to the CPU and return its first `f32` element.
/// Used on loss variables, which the caller treats as scalars.
fn read_scalar(v: &Var) -> f32 {
    let host = v.value().to_device(Device::Cpu);
    let scalar = host.as_slice::<f32>()[0];
    scalar
}
/// Construct a `TinyTransformer` on `device` with a fully deterministic init.
///
/// The initializer closure is called once per parameter tensor with its shape.
/// A counter starting at 1 is bumped before each call and used as the `fill`
/// seed, so the values depend only on the order and shapes of the tensors:
/// * 1-D tensors get `1.0 + fill(.., 0.02)` (the original describes these as
///   gammas ~1);
/// * every other tensor gets `fill(.., 0.04)`.
///
/// Per the original note this matches the init used by the single-GPU
/// `bin/train`, so every rank — and the single-GPU baseline — start from
/// bit-identical parameters, provided `cfg` is identical on every call.
pub fn build_model(cfg: Config, device: Device) -> TinyTransformer {
    let mut seed = 1u64;
    TinyTransformer::new(cfg, device, |shape| {
        seed = seed.wrapping_add(1);
        let numel: usize = shape.iter().product();
        let is_vector = shape.len() == 1;
        let mut values = fill(numel, seed, if is_vector { 0.02 } else { 0.04 });
        if is_vector {
            // Centre 1-D parameters on 1.0 instead of 0.0.
            for v in values.iter_mut() {
                *v += 1.0;
            }
        }
        values
    })
}
/// Produce `n` deterministic pseudo-random values centred on zero with magnitude
/// at most `scale` — same scheme as bin/train's `fill`.
///
/// `seed` is scrambled once into the starting state; each output then advances
/// a 64-bit linear congruential generator (wrapping arithmetic) and maps the top
/// 31 bits of the state onto the output range.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_ADD: u64 = 1442695040888963407;

    let mut state = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        state = state.wrapping_mul(LCG_MUL).wrapping_add(LCG_ADD);
        // Top 31 bits of the state, normalised by 2^31.
        let unit = (state >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}