From 163f567c802ad159dbf496b914b430957c84882f Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Mon, 15 Jun 2026 17:15:29 +0800 Subject: [PATCH] dist: ddp all-reduce + sharded batch DDP training step (train_rank) on top of DdpContext: each rank advances the SAME RNG, draws the whole global batch, and runs forward+backward only on its shard (i % world == rank) so the union over ranks is the single-GPU batch in the same order. After backward, all-reduce-average the device grads, then finish the mean with clip(pre_scale = 1/b_local) -> Sigma_global/B_global, identical to the single-GPU clip(1/B). Each rank then runs its own GpuAdamW.step; same init + same averaged grad + same optimizer state keep params bit-identical across ranks. Adds a deterministic build_model (same LCG init as bin/train) shared by ranks + baseline, a per-step loss all-reduce for the reported global-mean loss, and the thread-per-GPU launch() helper (thread::scope; Var graph is !Send so each rank builds its model thread-locally, only UniqueId/config/&Corpus cross threads). Co-Authored-By: Claude Opus 4.8 --- Cargo.lock | 3 + crates/xtrain-distributed/Cargo.toml | 3 + crates/xtrain-distributed/src/ddp.rs | 204 +++++++++++++++++++++++++++ crates/xtrain-distributed/src/lib.rs | 3 + 4 files changed, 213 insertions(+) create mode 100644 crates/xtrain-distributed/src/ddp.rs diff --git a/Cargo.lock b/Cargo.lock index 5887d45..e3512af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -211,7 +211,10 @@ version = "0.1.0" dependencies = [ "xtrain-autodiff", "xtrain-cuda", + "xtrain-model", + "xtrain-optim", "xtrain-tensor", + "xtrain-train", ] [[package]] diff --git a/crates/xtrain-distributed/Cargo.toml b/crates/xtrain-distributed/Cargo.toml index 4d55dd2..521dd58 100644 --- a/crates/xtrain-distributed/Cargo.toml +++ b/crates/xtrain-distributed/Cargo.toml @@ -8,3 +8,6 @@ license.workspace = true xtrain-cuda = { path = "../xtrain-cuda" } xtrain-tensor = { path = "../xtrain-tensor" } xtrain-autodiff = { path = "../xtrain-autodiff" } +xtrain-model = { path = "../xtrain-model" } +xtrain-optim = { path = "../xtrain-optim" } +xtrain-train = { path = "../xtrain-train" } diff --git a/crates/xtrain-distributed/src/ddp.rs b/crates/xtrain-distributed/src/ddp.rs new file mode 100644 index 0000000..c79583a --- /dev/null +++ b/crates/xtrain-distributed/src/ddp.rs @@ -0,0 +1,204 @@ +//! The DDP training step + a single-process, thread-per-GPU launcher (Phase T8). +//! +//! Each rank owns one GPU and one thread. Per step it processes a DISJOINT shard +//! of the global batch, all-reduce-averages the gradients, then runs its own +//! `GpuAdamW.step`. Identical init + identical optimizer state across ranks keep +//! the parameters consistent — verified by the cross-rank param-identity check in +//! the tests. +//! +//! Sampling matches single-GPU bit-for-bit: every rank advances the SAME RNG and +//! draws all `B_global` sequences of a step, but only runs forward+backward on +//! the ones assigned to it (`global index % world == rank`). The union over ranks +//! is exactly the single-GPU batch in the same order, so the all-reduced grad sum +//! equals the single-GPU summed grad. + +use std::thread; +use std::time::Instant; + +use xtrain_autodiff::tape::Var; +use xtrain_model::{Config, TinyTransformer, ids_tensor}; +use xtrain_optim::GpuAdamW; +use xtrain_tensor::Device; +use xtrain_train::clip::clip_grad_norm_gpu; +use xtrain_train::data::Corpus; +use xtrain_train::schedule::LrSchedule; + +use crate::{DdpContext, get_unique_id}; + +/// Per-rank DDP training config. `batch_size` is the GLOBAL batch (split across +/// ranks); the rest mirror `xtrain_train::TrainConfig`. +#[derive(Clone)] +pub struct DdpConfig { + pub seq_len: usize, + /// Global batch size; must be divisible by the world size. + pub batch_size: usize, + pub steps: usize, + pub schedule: LrSchedule, + pub weight_decay: f32, + pub max_grad_norm: f32, + pub log_every: usize, + pub seed: u64, +} + +/// Run `cfg.steps` DDP steps on this rank's `model`/`corpus`, using `ctx` for the +/// gradient all-reduce. Returns this rank's per-step mean-loss trace (the mean +/// over the GLOBAL batch — every rank computes the same value because losses are +/// all-reduced alongside the grads). The optimizer step is identical on every +/// rank, so the parameters stay in lockstep. +pub fn train_rank( + ctx: &DdpContext, + model: &TinyTransformer, + device: Device, + corpus: &Corpus, + cfg: &DdpConfig, +) -> Vec { + assert_eq!( + cfg.batch_size % ctx.world, + 0, + "global batch {} not divisible by world {}", + cfg.batch_size, + ctx.world + ); + let params = model.params(); + let mut opt = GpuAdamW::new(cfg.weight_decay); + let mut rng = cfg.seed; + let mut losses = Vec::with_capacity(cfg.steps); + // Each rank reaches the global batch mean as (Σ_global / world) · (1/b_local), + // where b_local = batch_size / world (see DdpContext::all_reduce_average_grads). + let batch_local = cfg.batch_size / ctx.world; + let inv_batch_local = 1.0 / batch_local as f32; + let start = Instant::now(); + let mut tokens_seen: u64 = 0; + + for step in 0..cfg.steps { + let lr = cfg.schedule.lr(step); + + // Draw the whole global batch from the shared RNG (same on every rank); + // run forward+backward only on this rank's shard. The tape SUMs the + // shard's grads; the union of shards == the single-GPU batch. + let mut local_loss_sum = 0.0f32; + for i in 0..cfg.batch_size { + let (input, target) = corpus.sample(cfg.seq_len, &mut rng); + if i % ctx.world != ctx.rank { + continue; // not this rank's sequence + } + let ids = ids_tensor(&input, device); + let targets = ids_tensor(&target, device); + let loss = model.loss(&ids, &targets); + local_loss_sum += read_scalar(&loss); + loss.backward(); + tokens_seen += cfg.seq_len as u64; + } + + // AllReduce(sum) + /world the grads → every rank holds Σ_global/world. + ctx.all_reduce_average_grads(¶ms); + // The reported loss is the global mean: average local sums across ranks. + let step_loss = all_reduce_loss(ctx, local_loss_sum) / cfg.batch_size as f32; + losses.push(step_loss); + + // clip pre_scale = 1/b_local finishes the average to Σ_global/B_global, + // identical to the single-GPU clip(pre_scale = 1/B_global). + let gnorm = clip_grad_norm_gpu(¶ms, cfg.max_grad_norm, inv_batch_local); + opt.step(lr, ¶ms); + for p in ¶ms { + p.zero_grad(); + } + + if ctx.rank == 0 && (step % cfg.log_every == 0 || step == cfg.steps - 1) { + let elapsed = start.elapsed().as_secs_f32(); + // Global tok/s = per-rank tok/s × world (each rank does 1/world of it). + let tps = (tokens_seen as f32 / elapsed.max(1e-6)) * ctx.world as f32; + println!( + "[rank0] step {step:5}/{}: loss {step_loss:.4} lr {lr:.2e} gnorm {gnorm:.3} \ + ({tps:.0} tok/s global, {} ranks)", + cfg.steps, ctx.world + ); + } + } + losses +} + +/// Spawn `world` rank threads (one per GPU in `devices`), init NCCL, build an +/// identical model per rank via `make_model`, and run `train_rank`. Returns each +/// rank's loss trace (all identical). The launcher owns the thread-per-GPU model: +/// rank 0 mints the `UniqueId`, every thread `cudaSetDevice`s its GPU, builds its +/// `Var` graph locally (the graph is `!Send`), and joins at the end. +/// +/// `make_model(device)` must be deterministic — same params on every rank — for +/// the parameters to stay consistent. +pub fn launch(devices: &[u32], corpus: &Corpus, cfg: &DdpConfig, make_model: F) -> Vec> +where + F: Fn(Device) -> TinyTransformer + Send + Sync, +{ + let world = devices.len(); + let id = get_unique_id(); + + thread::scope(|s| { + let handles: Vec<_> = devices + .iter() + .enumerate() + .map(|(rank, &dev)| { + let make_model = &make_model; + let cfg = cfg.clone(); + s.spawn(move || { + let ctx = DdpContext::init(rank, world, id, dev); + let device = Device::Cuda(dev); + let model = make_model(device); + train_rank(&ctx, &model, device, corpus, &cfg) + }) + }) + .collect(); + handles.into_iter().map(|h| h.join().unwrap()).collect() + }) +} + +/// AllReduce(sum) a single host scalar across ranks by round-tripping it through a +/// one-element device buffer. Used only for the logged/returned loss, so the cost +/// (one tiny collective per step) is negligible. Returns the summed value. +fn all_reduce_loss(ctx: &DdpContext, local: f32) -> f32 { + use xtrain_tensor::{DType, Tensor}; + if ctx.world == 1 { + return local; + } + let device = Device::Cuda(ctx.device); + let t = Tensor::from_slice(&[local], &[1]).to_device(device); + ctx.all_reduce_sum_f32_ptr(t.data_ptr() as *mut std::ffi::c_void, 1); + xtrain_cuda::device::synchronize().expect("loss all-reduce sync"); + t.to_device(Device::Cpu).as_slice::()[0] +} + +fn read_scalar(v: &Var) -> f32 { + v.value().to_device(Device::Cpu).as_slice::()[0] +} + +/// Build a `TinyTransformer` on `device` with the SAME deterministic init the +/// single-GPU `bin/train` uses (LCG fill, gammas ~1). Used by both the launcher +/// and the correctness test so every rank — and the single-GPU baseline — start +/// from bit-identical parameters. `cfg` must be identical on every call. +pub fn build_model(cfg: Config, device: Device) -> TinyTransformer { + let mut seed = 1u64; + TinyTransformer::new(cfg, device, |shape| { + seed = seed.wrapping_add(1); + let n: usize = shape.iter().product(); + if shape.len() == 1 { + fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect() + } else { + fill(n, seed, 0.04) + } + }) +} + +// Deterministic LCG fill in [-scale, scale) — same scheme as bin/train's `fill`. +fn fill(n: usize, seed: u64, scale: f32) -> Vec { + let mut state = seed + .wrapping_mul(2862933555777941757) + .wrapping_add(3037000493); + (0..n) + .map(|_| { + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + (((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale + }) + .collect() +} diff --git a/crates/xtrain-distributed/src/lib.rs b/crates/xtrain-distributed/src/lib.rs index e16e59b..0c2f1b9 100644 --- a/crates/xtrain-distributed/src/lib.rs +++ b/crates/xtrain-distributed/src/lib.rs @@ -16,8 +16,11 @@ #![cfg(not(no_cuda))] +pub mod ddp; pub mod ffi; +pub use ddp::{DdpConfig, build_model, launch, train_rank}; + use std::ffi::c_void; use ffi::{NcclComm, NcclUniqueId};