dist: ddp all-reduce + sharded batch

DDP training step (train_rank) on top of DdpContext: each rank advances the
SAME RNG, draws the whole global batch, and runs forward+backward only on its
shard (i % world == rank) so the union over ranks is the single-GPU batch in the
same order. After backward, all-reduce-average the device grads, then finish the
mean with clip(pre_scale = 1/b_local) -> Sigma_global/B_global, identical to the
single-GPU clip(1/B). Each rank then runs its own GpuAdamW.step; same init +
same averaged grad + same optimizer state keep params bit-identical across ranks.

Adds a deterministic build_model (same LCG init as bin/train) shared by ranks +
baseline, a per-step loss all-reduce for the reported global-mean loss, and the
thread-per-GPU launch() helper (thread::scope; Var graph is !Send so each rank
builds its model thread-locally, only UniqueId/config/&Corpus cross threads).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 17:15:29 +08:00
parent e27df50ca9
commit 163f567c80
4 changed files with 213 additions and 0 deletions

3
Cargo.lock generated
View File

@@ -211,7 +211,10 @@ version = "0.1.0"
dependencies = [
"xtrain-autodiff",
"xtrain-cuda",
"xtrain-model",
"xtrain-optim",
"xtrain-tensor",
"xtrain-train",
]
[[package]]

View File

@@ -8,3 +8,6 @@ license.workspace = true
xtrain-cuda = { path = "../xtrain-cuda" }
xtrain-tensor = { path = "../xtrain-tensor" }
xtrain-autodiff = { path = "../xtrain-autodiff" }
xtrain-model = { path = "../xtrain-model" }
xtrain-optim = { path = "../xtrain-optim" }
xtrain-train = { path = "../xtrain-train" }

View File

@@ -0,0 +1,204 @@
//! The DDP training step + a single-process, thread-per-GPU launcher (Phase T8).
//!
//! Each rank owns one GPU and one thread. Per step it processes a DISJOINT shard
//! of the global batch, all-reduce-averages the gradients, then runs its own
//! `GpuAdamW.step`. Identical init + identical optimizer state across ranks keep
//! the parameters consistent — verified by the cross-rank param-identity check in
//! the tests.
//!
//! Sampling matches single-GPU bit-for-bit: every rank advances the SAME RNG and
//! draws all `B_global` sequences of a step, but only runs forward+backward on
//! the ones assigned to it (`global index % world == rank`). The union over ranks
//! is exactly the single-GPU batch in the same order, so the all-reduced grad sum
//! equals the single-GPU summed grad.
use std::thread;
use std::time::Instant;
use xtrain_autodiff::tape::Var;
use xtrain_model::{Config, TinyTransformer, ids_tensor};
use xtrain_optim::GpuAdamW;
use xtrain_tensor::Device;
use xtrain_train::clip::clip_grad_norm_gpu;
use xtrain_train::data::Corpus;
use xtrain_train::schedule::LrSchedule;
use crate::{DdpContext, get_unique_id};
/// Per-rank DDP training config. `batch_size` is the GLOBAL batch (split across
/// ranks); the rest mirror `xtrain_train::TrainConfig`.
#[derive(Clone)]
pub struct DdpConfig {
pub seq_len: usize,
/// Global batch size; must be divisible by the world size.
pub batch_size: usize,
pub steps: usize,
pub schedule: LrSchedule,
pub weight_decay: f32,
pub max_grad_norm: f32,
pub log_every: usize,
pub seed: u64,
}
/// Run `cfg.steps` DDP steps on this rank's `model`/`corpus`, using `ctx` for the
/// gradient all-reduce. Returns this rank's per-step mean-loss trace (the mean
/// over the GLOBAL batch — every rank computes the same value because losses are
/// all-reduced alongside the grads). The optimizer step is identical on every
/// rank, so the parameters stay in lockstep.
pub fn train_rank(
ctx: &DdpContext,
model: &TinyTransformer,
device: Device,
corpus: &Corpus,
cfg: &DdpConfig,
) -> Vec<f32> {
assert_eq!(
cfg.batch_size % ctx.world,
0,
"global batch {} not divisible by world {}",
cfg.batch_size,
ctx.world
);
let params = model.params();
let mut opt = GpuAdamW::new(cfg.weight_decay);
let mut rng = cfg.seed;
let mut losses = Vec::with_capacity(cfg.steps);
// Each rank reaches the global batch mean as (Σ_global / world) · (1/b_local),
// where b_local = batch_size / world (see DdpContext::all_reduce_average_grads).
let batch_local = cfg.batch_size / ctx.world;
let inv_batch_local = 1.0 / batch_local as f32;
let start = Instant::now();
let mut tokens_seen: u64 = 0;
for step in 0..cfg.steps {
let lr = cfg.schedule.lr(step);
// Draw the whole global batch from the shared RNG (same on every rank);
// run forward+backward only on this rank's shard. The tape SUMs the
// shard's grads; the union of shards == the single-GPU batch.
let mut local_loss_sum = 0.0f32;
for i in 0..cfg.batch_size {
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
if i % ctx.world != ctx.rank {
continue; // not this rank's sequence
}
let ids = ids_tensor(&input, device);
let targets = ids_tensor(&target, device);
let loss = model.loss(&ids, &targets);
local_loss_sum += read_scalar(&loss);
loss.backward();
tokens_seen += cfg.seq_len as u64;
}
// AllReduce(sum) + /world the grads → every rank holds Σ_global/world.
ctx.all_reduce_average_grads(&params);
// The reported loss is the global mean: average local sums across ranks.
let step_loss = all_reduce_loss(ctx, local_loss_sum) / cfg.batch_size as f32;
losses.push(step_loss);
// clip pre_scale = 1/b_local finishes the average to Σ_global/B_global,
// identical to the single-GPU clip(pre_scale = 1/B_global).
let gnorm = clip_grad_norm_gpu(&params, cfg.max_grad_norm, inv_batch_local);
opt.step(lr, &params);
for p in &params {
p.zero_grad();
}
if ctx.rank == 0 && (step % cfg.log_every == 0 || step == cfg.steps - 1) {
let elapsed = start.elapsed().as_secs_f32();
// Global tok/s = per-rank tok/s × world (each rank does 1/world of it).
let tps = (tokens_seen as f32 / elapsed.max(1e-6)) * ctx.world as f32;
println!(
"[rank0] step {step:5}/{}: loss {step_loss:.4} lr {lr:.2e} gnorm {gnorm:.3} \
({tps:.0} tok/s global, {} ranks)",
cfg.steps, ctx.world
);
}
}
losses
}
/// Spawn `world` rank threads (one per GPU in `devices`), init NCCL, build an
/// identical model per rank via `make_model`, and run `train_rank`. Returns each
/// rank's loss trace (all identical). The launcher owns the thread-per-GPU model:
/// rank 0 mints the `UniqueId`, every thread `cudaSetDevice`s its GPU, builds its
/// `Var` graph locally (the graph is `!Send`), and joins at the end.
///
/// `make_model(device)` must be deterministic — same params on every rank — for
/// the parameters to stay consistent.
pub fn launch<F>(devices: &[u32], corpus: &Corpus, cfg: &DdpConfig, make_model: F) -> Vec<Vec<f32>>
where
F: Fn(Device) -> TinyTransformer + Send + Sync,
{
let world = devices.len();
let id = get_unique_id();
thread::scope(|s| {
let handles: Vec<_> = devices
.iter()
.enumerate()
.map(|(rank, &dev)| {
let make_model = &make_model;
let cfg = cfg.clone();
s.spawn(move || {
let ctx = DdpContext::init(rank, world, id, dev);
let device = Device::Cuda(dev);
let model = make_model(device);
train_rank(&ctx, &model, device, corpus, &cfg)
})
})
.collect();
handles.into_iter().map(|h| h.join().unwrap()).collect()
})
}
/// AllReduce(sum) a single host scalar across ranks by round-tripping it through a
/// one-element device buffer. Used only for the logged/returned loss, so the cost
/// (one tiny collective per step) is negligible. Returns the summed value.
fn all_reduce_loss(ctx: &DdpContext, local: f32) -> f32 {
use xtrain_tensor::{DType, Tensor};
if ctx.world == 1 {
return local;
}
let device = Device::Cuda(ctx.device);
let t = Tensor::from_slice(&[local], &[1]).to_device(device);
ctx.all_reduce_sum_f32_ptr(t.data_ptr() as *mut std::ffi::c_void, 1);
xtrain_cuda::device::synchronize().expect("loss all-reduce sync");
t.to_device(Device::Cpu).as_slice::<f32>()[0]
}
fn read_scalar(v: &Var) -> f32 {
v.value().to_device(Device::Cpu).as_slice::<f32>()[0]
}
/// Build a `TinyTransformer` on `device` with the SAME deterministic init the
/// single-GPU `bin/train` uses (LCG fill, gammas ~1). Used by both the launcher
/// and the correctness test so every rank — and the single-GPU baseline — start
/// from bit-identical parameters. `cfg` must be identical on every call.
pub fn build_model(cfg: Config, device: Device) -> TinyTransformer {
let mut seed = 1u64;
TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
} else {
fill(n, seed, 0.04)
}
})
}
// Deterministic LCG fill in [-scale, scale) — same scheme as bin/train's `fill`.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
let mut state = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
})
.collect()
}

View File

@@ -16,8 +16,11 @@
#![cfg(not(no_cuda))]
pub mod ddp;
pub mod ffi;
pub use ddp::{DdpConfig, build_model, launch, train_rank};
use std::ffi::c_void;
use ffi::{NcclComm, NcclUniqueId};