dist: ddp all-reduce + sharded batch
DDP training step (train_rank) on top of DdpContext: each rank advances the SAME RNG, draws the whole global batch, and runs forward+backward only on its shard (i % world == rank) so the union over ranks is the single-GPU batch in the same order. After backward, all-reduce-average the device grads, then finish the mean with clip(pre_scale = 1/b_local) -> Sigma_global/B_global, identical to the single-GPU clip(1/B). Each rank then runs its own GpuAdamW.step; same init + same averaged grad + same optimizer state keep params bit-identical across ranks. Adds a deterministic build_model (same LCG init as bin/train) shared by ranks + baseline, a per-step loss all-reduce for the reported global-mean loss, and the thread-per-GPU launch() helper (thread::scope; Var graph is !Send so each rank builds its model thread-locally, only UniqueId/config/&Corpus cross threads). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
3
Cargo.lock
generated
3
Cargo.lock
generated
@@ -211,7 +211,10 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"xtrain-autodiff",
|
||||
"xtrain-cuda",
|
||||
"xtrain-model",
|
||||
"xtrain-optim",
|
||||
"xtrain-tensor",
|
||||
"xtrain-train",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -8,3 +8,6 @@ license.workspace = true
|
||||
xtrain-cuda = { path = "../xtrain-cuda" }
|
||||
xtrain-tensor = { path = "../xtrain-tensor" }
|
||||
xtrain-autodiff = { path = "../xtrain-autodiff" }
|
||||
xtrain-model = { path = "../xtrain-model" }
|
||||
xtrain-optim = { path = "../xtrain-optim" }
|
||||
xtrain-train = { path = "../xtrain-train" }
|
||||
|
||||
204
crates/xtrain-distributed/src/ddp.rs
Normal file
204
crates/xtrain-distributed/src/ddp.rs
Normal file
@@ -0,0 +1,204 @@
|
||||
//! The DDP training step + a single-process, thread-per-GPU launcher (Phase T8).
|
||||
//!
|
||||
//! Each rank owns one GPU and one thread. Per step it processes a DISJOINT shard
|
||||
//! of the global batch, all-reduce-averages the gradients, then runs its own
|
||||
//! `GpuAdamW.step`. Identical init + identical optimizer state across ranks keep
|
||||
//! the parameters consistent — verified by the cross-rank param-identity check in
|
||||
//! the tests.
|
||||
//!
|
||||
//! Sampling matches single-GPU bit-for-bit: every rank advances the SAME RNG and
|
||||
//! draws all `B_global` sequences of a step, but only runs forward+backward on
|
||||
//! the ones assigned to it (`global index % world == rank`). The union over ranks
|
||||
//! is exactly the single-GPU batch in the same order, so the all-reduced grad sum
|
||||
//! equals the single-GPU summed grad.
|
||||
|
||||
use std::thread;
|
||||
use std::time::Instant;
|
||||
|
||||
use xtrain_autodiff::tape::Var;
|
||||
use xtrain_model::{Config, TinyTransformer, ids_tensor};
|
||||
use xtrain_optim::GpuAdamW;
|
||||
use xtrain_tensor::Device;
|
||||
use xtrain_train::clip::clip_grad_norm_gpu;
|
||||
use xtrain_train::data::Corpus;
|
||||
use xtrain_train::schedule::LrSchedule;
|
||||
|
||||
use crate::{DdpContext, get_unique_id};
|
||||
|
||||
/// Per-rank DDP training config. `batch_size` is the GLOBAL batch (split across
|
||||
/// ranks); the rest mirror `xtrain_train::TrainConfig`.
|
||||
#[derive(Clone)]
|
||||
pub struct DdpConfig {
|
||||
pub seq_len: usize,
|
||||
/// Global batch size; must be divisible by the world size.
|
||||
pub batch_size: usize,
|
||||
pub steps: usize,
|
||||
pub schedule: LrSchedule,
|
||||
pub weight_decay: f32,
|
||||
pub max_grad_norm: f32,
|
||||
pub log_every: usize,
|
||||
pub seed: u64,
|
||||
}
|
||||
|
||||
/// Run `cfg.steps` DDP steps on this rank's `model`/`corpus`, using `ctx` for the
|
||||
/// gradient all-reduce. Returns this rank's per-step mean-loss trace (the mean
|
||||
/// over the GLOBAL batch — every rank computes the same value because losses are
|
||||
/// all-reduced alongside the grads). The optimizer step is identical on every
|
||||
/// rank, so the parameters stay in lockstep.
|
||||
pub fn train_rank(
|
||||
ctx: &DdpContext,
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
corpus: &Corpus,
|
||||
cfg: &DdpConfig,
|
||||
) -> Vec<f32> {
|
||||
assert_eq!(
|
||||
cfg.batch_size % ctx.world,
|
||||
0,
|
||||
"global batch {} not divisible by world {}",
|
||||
cfg.batch_size,
|
||||
ctx.world
|
||||
);
|
||||
let params = model.params();
|
||||
let mut opt = GpuAdamW::new(cfg.weight_decay);
|
||||
let mut rng = cfg.seed;
|
||||
let mut losses = Vec::with_capacity(cfg.steps);
|
||||
// Each rank reaches the global batch mean as (Σ_global / world) · (1/b_local),
|
||||
// where b_local = batch_size / world (see DdpContext::all_reduce_average_grads).
|
||||
let batch_local = cfg.batch_size / ctx.world;
|
||||
let inv_batch_local = 1.0 / batch_local as f32;
|
||||
let start = Instant::now();
|
||||
let mut tokens_seen: u64 = 0;
|
||||
|
||||
for step in 0..cfg.steps {
|
||||
let lr = cfg.schedule.lr(step);
|
||||
|
||||
// Draw the whole global batch from the shared RNG (same on every rank);
|
||||
// run forward+backward only on this rank's shard. The tape SUMs the
|
||||
// shard's grads; the union of shards == the single-GPU batch.
|
||||
let mut local_loss_sum = 0.0f32;
|
||||
for i in 0..cfg.batch_size {
|
||||
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
||||
if i % ctx.world != ctx.rank {
|
||||
continue; // not this rank's sequence
|
||||
}
|
||||
let ids = ids_tensor(&input, device);
|
||||
let targets = ids_tensor(&target, device);
|
||||
let loss = model.loss(&ids, &targets);
|
||||
local_loss_sum += read_scalar(&loss);
|
||||
loss.backward();
|
||||
tokens_seen += cfg.seq_len as u64;
|
||||
}
|
||||
|
||||
// AllReduce(sum) + /world the grads → every rank holds Σ_global/world.
|
||||
ctx.all_reduce_average_grads(¶ms);
|
||||
// The reported loss is the global mean: average local sums across ranks.
|
||||
let step_loss = all_reduce_loss(ctx, local_loss_sum) / cfg.batch_size as f32;
|
||||
losses.push(step_loss);
|
||||
|
||||
// clip pre_scale = 1/b_local finishes the average to Σ_global/B_global,
|
||||
// identical to the single-GPU clip(pre_scale = 1/B_global).
|
||||
let gnorm = clip_grad_norm_gpu(¶ms, cfg.max_grad_norm, inv_batch_local);
|
||||
opt.step(lr, ¶ms);
|
||||
for p in ¶ms {
|
||||
p.zero_grad();
|
||||
}
|
||||
|
||||
if ctx.rank == 0 && (step % cfg.log_every == 0 || step == cfg.steps - 1) {
|
||||
let elapsed = start.elapsed().as_secs_f32();
|
||||
// Global tok/s = per-rank tok/s × world (each rank does 1/world of it).
|
||||
let tps = (tokens_seen as f32 / elapsed.max(1e-6)) * ctx.world as f32;
|
||||
println!(
|
||||
"[rank0] step {step:5}/{}: loss {step_loss:.4} lr {lr:.2e} gnorm {gnorm:.3} \
|
||||
({tps:.0} tok/s global, {} ranks)",
|
||||
cfg.steps, ctx.world
|
||||
);
|
||||
}
|
||||
}
|
||||
losses
|
||||
}
|
||||
|
||||
/// Spawn `world` rank threads (one per GPU in `devices`), init NCCL, build an
|
||||
/// identical model per rank via `make_model`, and run `train_rank`. Returns each
|
||||
/// rank's loss trace (all identical). The launcher owns the thread-per-GPU model:
|
||||
/// rank 0 mints the `UniqueId`, every thread `cudaSetDevice`s its GPU, builds its
|
||||
/// `Var` graph locally (the graph is `!Send`), and joins at the end.
|
||||
///
|
||||
/// `make_model(device)` must be deterministic — same params on every rank — for
|
||||
/// the parameters to stay consistent.
|
||||
pub fn launch<F>(devices: &[u32], corpus: &Corpus, cfg: &DdpConfig, make_model: F) -> Vec<Vec<f32>>
|
||||
where
|
||||
F: Fn(Device) -> TinyTransformer + Send + Sync,
|
||||
{
|
||||
let world = devices.len();
|
||||
let id = get_unique_id();
|
||||
|
||||
thread::scope(|s| {
|
||||
let handles: Vec<_> = devices
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(rank, &dev)| {
|
||||
let make_model = &make_model;
|
||||
let cfg = cfg.clone();
|
||||
s.spawn(move || {
|
||||
let ctx = DdpContext::init(rank, world, id, dev);
|
||||
let device = Device::Cuda(dev);
|
||||
let model = make_model(device);
|
||||
train_rank(&ctx, &model, device, corpus, &cfg)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
handles.into_iter().map(|h| h.join().unwrap()).collect()
|
||||
})
|
||||
}
|
||||
|
||||
/// AllReduce(sum) a single host scalar across ranks by round-tripping it through a
|
||||
/// one-element device buffer. Used only for the logged/returned loss, so the cost
|
||||
/// (one tiny collective per step) is negligible. Returns the summed value.
|
||||
fn all_reduce_loss(ctx: &DdpContext, local: f32) -> f32 {
|
||||
use xtrain_tensor::{DType, Tensor};
|
||||
if ctx.world == 1 {
|
||||
return local;
|
||||
}
|
||||
let device = Device::Cuda(ctx.device);
|
||||
let t = Tensor::from_slice(&[local], &[1]).to_device(device);
|
||||
ctx.all_reduce_sum_f32_ptr(t.data_ptr() as *mut std::ffi::c_void, 1);
|
||||
xtrain_cuda::device::synchronize().expect("loss all-reduce sync");
|
||||
t.to_device(Device::Cpu).as_slice::<f32>()[0]
|
||||
}
|
||||
|
||||
fn read_scalar(v: &Var) -> f32 {
|
||||
v.value().to_device(Device::Cpu).as_slice::<f32>()[0]
|
||||
}
|
||||
|
||||
/// Build a `TinyTransformer` on `device` with the SAME deterministic init the
|
||||
/// single-GPU `bin/train` uses (LCG fill, gammas ~1). Used by both the launcher
|
||||
/// and the correctness test so every rank — and the single-GPU baseline — start
|
||||
/// from bit-identical parameters. `cfg` must be identical on every call.
|
||||
pub fn build_model(cfg: Config, device: Device) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.04)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Deterministic LCG fill in [-scale, scale) — same scheme as bin/train's `fill`.
|
||||
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -16,8 +16,11 @@
|
||||
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
pub mod ddp;
|
||||
pub mod ffi;
|
||||
|
||||
pub use ddp::{DdpConfig, build_model, launch, train_rank};
|
||||
|
||||
use std::ffi::c_void;
|
||||
|
||||
use ffi::{NcclComm, NcclUniqueId};
|
||||
|
||||
Reference in New Issue
Block a user