Accumulate grads over N micro-batches, then one AdamW step + zero_grad, for an effective batch of N×micro at one micro-batch's activation cost. Each micro-loss is scaled by 1/N before backward (the tape SUM-accumulates the scaled grads) so the boundary grad equals a single step over an N× batch. accum==1 skips the scale → bit-identical to the pre-T16 path. DDP: the cross-rank all-reduce fires ONLY at the accumulation boundary (intermediate micro-steps are local-only, no NCCL); the /world average is orthogonal to the per-micro 1/N, so the boundary grad is the effective global-batch mean. New --accum-steps flag in both train binaries; effective batch is printed. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
300 lines
13 KiB
Rust
300 lines
13 KiB
Rust
//! The DDP training step + a single-process, thread-per-GPU launcher (Phase T8).
|
||
//!
|
||
//! Each rank owns one GPU and one thread. Per step it processes a DISJOINT shard
|
||
//! of the global batch, all-reduce-averages the gradients, then runs its own
|
||
//! `GpuAdamW.step`. Identical init + identical optimizer state across ranks keep
|
||
//! the parameters consistent — verified by the cross-rank param-identity check in
|
||
//! the tests.
|
||
//!
|
||
//! Sampling matches single-GPU bit-for-bit: every rank advances the SAME RNG and
|
||
//! draws all `B_global` sequences of a step, but only runs forward+backward on
|
||
//! the ones assigned to it (`global index % world == rank`). The union over ranks
|
||
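//! is exactly the single-GPU batch in the same order, so the all-reduce-averaged
//! grad equals the single-GPU batch-mean grad. With `accum_steps > 1` each rank
//! sums 1/accum-scaled micro-batch grads locally and the cross-rank all-reduce
//! runs once per optimizer step, at the accumulation boundary.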
|
||
|
||
use std::path::PathBuf;
|
||
use std::thread;
|
||
use std::time::Instant;
|
||
|
||
use xtrain_autodiff::tape::Var;
|
||
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
||
use xtrain_optim::GpuAdamW;
|
||
use xtrain_tensor::Device;
|
||
use xtrain_train::checkpoint;
|
||
use xtrain_train::clip::clip_grad_norm_gpu;
|
||
use xtrain_train::data::Corpus;
|
||
use xtrain_train::eval_loss;
|
||
use xtrain_train::schedule::LrSchedule;
|
||
|
||
use crate::{DdpContext, get_unique_id};
|
||
|
||
/// Per-rank DDP training config. `batch_size` is the GLOBAL batch (split across
|
||
/// ranks); the rest mirror `xtrain_train::TrainConfig`.
|
||
#[derive(Clone)]
|
||
pub struct DdpConfig {
|
||
pub seq_len: usize,
|
||
/// Global batch size; must be divisible by the world size.
|
||
pub batch_size: usize,
|
||
/// Micro-batch gradient accumulation (Phase T16): each optimizer step
|
||
/// accumulates grads over `accum_steps` micro-batches, giving an EFFECTIVE
|
||
/// global batch of `accum_steps × batch_size`. The cross-rank all-reduce
|
||
/// fires ONLY at the accumulation boundary (after the last micro-step) —
|
||
/// intermediate micro-steps skip the NCCL collective entirely. `1` = no
|
||
/// accumulation (bit-identical to the pre-T16 DDP path).
|
||
pub accum_steps: usize,
|
||
pub steps: usize,
|
||
pub schedule: LrSchedule,
|
||
pub weight_decay: f32,
|
||
pub max_grad_norm: f32,
|
||
pub log_every: usize,
|
||
pub seed: u64,
|
||
/// Evaluate held-out val loss every `eval_every` steps (0 = never). Only rank
|
||
/// 0 holds the `valid` corpus and runs the eval (no grad), mirroring
|
||
/// `xtrain_train::TrainConfig`. The best-val model is checkpointed by rank 0
|
||
/// (every rank's params are identical, so rank 0's are the model's).
|
||
pub eval_every: usize,
|
||
pub eval_batches: usize,
|
||
/// Best-val checkpoint path (written by rank 0 when val improves). When unset,
|
||
/// or when `eval_every == 0`, no checkpoint is written.
|
||
pub ckpt_path: Option<PathBuf>,
|
||
}
|
||
|
||
/// Outcome of a DDP run on one rank.
///
/// `losses` holds one entry per optimizer step: the mean loss over the
/// effective global batch. It is identical on every rank. When
/// `eval_every > 0`, `evals` records `(step, val_loss)` pairs and `best_val`
/// the lowest val loss seen. Both are populated only on rank 0, which owns the
/// `valid` corpus.
#[derive(Debug, Clone, PartialEq)]
pub struct DdpResult {
    pub losses: Vec<f32>,
    pub evals: Vec<(usize, f32)>,
    pub best_val: Option<f32>,
}
|
||
|
||
/// Run `cfg.steps` DDP steps on this rank's `model`/`corpus`, using `ctx` for the
|
||
/// gradient all-reduce. Returns this rank's per-step mean-loss trace (the mean
|
||
/// over the GLOBAL batch — every rank computes the same value because losses are
|
||
/// all-reduced alongside the grads) plus eval/best-val (rank 0 only). The
|
||
/// optimizer step is identical on every rank, so the parameters stay in lockstep.
|
||
///
|
||
/// `valid` is the held-out corpus for periodic val-loss eval. Only rank 0 needs
|
||
/// it (it runs the no-grad eval and writes the best-val checkpoint); pass `None`
|
||
/// on the other ranks (or when `cfg.eval_every == 0`).
|
||
pub fn train_rank(
|
||
ctx: &DdpContext,
|
||
model: &TinyTransformer,
|
||
device: Device,
|
||
corpus: &Corpus,
|
||
valid: Option<&Corpus>,
|
||
cfg: &DdpConfig,
|
||
) -> DdpResult {
|
||
assert_eq!(
|
||
cfg.batch_size % ctx.world,
|
||
0,
|
||
"global batch {} not divisible by world {}",
|
||
cfg.batch_size,
|
||
ctx.world
|
||
);
|
||
let params = model.params();
|
||
let mut opt = GpuAdamW::new(cfg.weight_decay);
|
||
let mut rng = cfg.seed;
|
||
let mut losses = Vec::with_capacity(cfg.steps);
|
||
let mut evals = Vec::new();
|
||
let mut best_val: Option<f32> = None;
|
||
// Each rank runs ONE batched forward over its b_local = batch_size/world
|
||
// sequences → backward grad = local mean (Σ_local / b_local). all_reduce_average
|
||
// (sum across ranks, /world) then gives Σ_global/(world·b_local) = Σ_global/
|
||
// B_global — already the global-batch mean — so the clip pre-scale is 1.0.
|
||
let batch_local = cfg.batch_size / ctx.world;
|
||
let accum = cfg.accum_steps.max(1);
|
||
let start = Instant::now();
|
||
let mut tokens_seen: u64 = 0;
|
||
// Rank 0 owns the held-out eval + best-val checkpoint (params are identical
|
||
// across ranks, so rank 0's are the model). Other ranks never touch `valid`.
|
||
let do_eval = ctx.rank == 0 && cfg.eval_every > 0 && valid.is_some();
|
||
|
||
for step in 0..cfg.steps {
|
||
let lr = cfg.schedule.lr(step);
|
||
|
||
// Accumulate grads over `accum` micro-batches, then ONE optimizer step
|
||
// (Phase T16). Per micro-batch: draw the whole micro global batch from the
|
||
// shared RNG (same on every rank), keep only this rank's shard (global index
|
||
// % world == rank), run it as ONE batched forward/backward. Each micro-loss
|
||
// is scaled by 1/accum before backward (the tape SUM-accumulates the scaled
|
||
// grads across the `accum` micro-backwards) so the boundary grad equals a
|
||
// single step over an `accum × batch_size` global batch. `accum == 1` skips
|
||
// the scale → bit-identical to the pre-T16 DDP path. The cross-rank
|
||
// all-reduce fires ONLY after the last micro-step (intermediate micro-steps
|
||
// are local-only, no NCCL).
|
||
let mut local_sum = 0.0f32; // Σ over micro of (local_mean · b_local)
|
||
for _ in 0..accum {
|
||
let mut inputs = Vec::with_capacity(batch_local);
|
||
let mut targets_v = Vec::with_capacity(batch_local);
|
||
for i in 0..cfg.batch_size {
|
||
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
||
if i % ctx.world == ctx.rank {
|
||
inputs.push(input);
|
||
targets_v.push(target);
|
||
}
|
||
}
|
||
let ids = batched_ids_tensor(&inputs, device);
|
||
let targets = batched_ids_tensor(&targets_v, device);
|
||
let loss = model.loss_batched(&ids, &targets, batch_local);
|
||
local_sum += read_scalar(&loss) * batch_local as f32; // local mean·b_local
|
||
if accum == 1 {
|
||
loss.backward();
|
||
} else {
|
||
xtrain_autodiff::ops::scale(&loss, 1.0 / accum as f32).backward();
|
||
}
|
||
tokens_seen += (batch_local * cfg.seq_len) as u64;
|
||
}
|
||
|
||
// Accumulation boundary: ONE AllReduce(sum) + /world over the accumulated
|
||
// grads → every rank holds the effective-batch (accum·B_global) mean grad
|
||
// (the per-micro 1/accum scaling is already baked into each backward; the
|
||
// /world here is orthogonal to accum). Intermediate micro-steps issued NO
|
||
// NCCL — only this single boundary collective per optimizer step.
|
||
ctx.all_reduce_average_grads(¶ms);
|
||
// Reported loss = effective-batch mean: AllReduce(sum) the per-rank local
|
||
// sums across ranks, /(accum·B_global).
|
||
let step_loss = all_reduce_loss(ctx, local_sum) / (accum * cfg.batch_size) as f32;
|
||
losses.push(step_loss);
|
||
|
||
// Grads are already the effective-batch mean — just clip (pre-scale 1.0).
|
||
let gnorm = clip_grad_norm_gpu(¶ms, cfg.max_grad_norm, 1.0);
|
||
opt.step(lr, ¶ms);
|
||
for p in ¶ms {
|
||
p.zero_grad();
|
||
}
|
||
|
||
if ctx.rank == 0 && (step % cfg.log_every == 0 || step == cfg.steps - 1) {
|
||
let elapsed = start.elapsed().as_secs_f32();
|
||
// Global tok/s = per-rank tok/s × world (each rank does 1/world of it).
|
||
let tps = (tokens_seen as f32 / elapsed.max(1e-6)) * ctx.world as f32;
|
||
println!(
|
||
"[rank0] step {step:5}/{}: loss {step_loss:.4} lr {lr:.2e} gnorm {gnorm:.3} \
|
||
({tps:.0} tok/s global, {} ranks)",
|
||
cfg.steps, ctx.world
|
||
);
|
||
}
|
||
|
||
// Periodic held-out eval + best-val checkpoint (rank 0 only). Mirrors the
|
||
// single-GPU `xtrain_train::train` loop, reusing its `eval_loss` /
|
||
// `checkpoint::save` so single-GPU and DDP share one eval/ckpt path. Other
|
||
// ranks have nothing to do here (params are identical across ranks).
|
||
if do_eval && ((step + 1) % cfg.eval_every == 0 || step == cfg.steps - 1) {
|
||
let v = valid.unwrap();
|
||
let vl = eval_loss(model, device, v, cfg.seq_len, cfg.eval_batches);
|
||
evals.push((step, vl));
|
||
let improved = best_val.map(|b| vl < b).unwrap_or(true);
|
||
println!(
|
||
" [rank0] eval @ step {step}: val loss {vl:.4}{}",
|
||
if improved { " (best)" } else { "" }
|
||
);
|
||
if improved {
|
||
best_val = Some(vl);
|
||
if let Some(path) = &cfg.ckpt_path {
|
||
checkpoint::save(path, ¶ms).expect("best checkpoint save");
|
||
}
|
||
}
|
||
}
|
||
}
|
||
DdpResult {
|
||
losses,
|
||
evals,
|
||
best_val,
|
||
}
|
||
}
|
||
|
||
/// Spawn `world` rank threads (one per GPU in `devices`), init NCCL, build an
|
||
/// identical model per rank via `make_model`, and run `train_rank`. Returns each
|
||
/// rank's `DdpResult` (loss traces are identical; eval/best-val are on rank 0).
|
||
/// The launcher owns the thread-per-GPU model: rank 0 mints the `UniqueId`, every
|
||
/// thread `cudaSetDevice`s its GPU, builds its `Var` graph locally (the graph is
|
||
/// `!Send`), and joins at the end.
|
||
///
|
||
/// `valid` is the held-out corpus for rank 0's periodic eval (only used when
|
||
/// `cfg.eval_every > 0`). `make_model(device)` must be deterministic — same params
|
||
/// on every rank — for the parameters to stay consistent.
|
||
pub fn launch<F>(
|
||
devices: &[u32],
|
||
corpus: &Corpus,
|
||
valid: Option<&Corpus>,
|
||
cfg: &DdpConfig,
|
||
make_model: F,
|
||
) -> Vec<DdpResult>
|
||
where
|
||
F: Fn(Device) -> TinyTransformer + Send + Sync,
|
||
{
|
||
let world = devices.len();
|
||
let id = get_unique_id();
|
||
|
||
thread::scope(|s| {
|
||
let handles: Vec<_> = devices
|
||
.iter()
|
||
.enumerate()
|
||
.map(|(rank, &dev)| {
|
||
let make_model = &make_model;
|
||
let cfg = cfg.clone();
|
||
s.spawn(move || {
|
||
let ctx = DdpContext::init(rank, world, id, dev);
|
||
let device = Device::Cuda(dev);
|
||
let model = make_model(device);
|
||
// Only rank 0 holds the val corpus for eval.
|
||
let v = if rank == 0 { valid } else { None };
|
||
train_rank(&ctx, &model, device, corpus, v, &cfg)
|
||
})
|
||
})
|
||
.collect();
|
||
handles.into_iter().map(|h| h.join().unwrap()).collect()
|
||
})
|
||
}
|
||
|
||
/// AllReduce(sum) a single host scalar across ranks by round-tripping it through a
|
||
/// one-element device buffer. Used only for the logged/returned loss, so the cost
|
||
/// (one tiny collective per step) is negligible. Returns the summed value.
|
||
fn all_reduce_loss(ctx: &DdpContext, local: f32) -> f32 {
|
||
use xtrain_tensor::Tensor;
|
||
if ctx.world == 1 {
|
||
return local;
|
||
}
|
||
let device = Device::Cuda(ctx.device);
|
||
let t = Tensor::from_slice(&[local], &[1]).to_device(device);
|
||
ctx.all_reduce_sum_f32_ptr(t.data_ptr() as *mut std::ffi::c_void, 1);
|
||
xtrain_cuda::device::synchronize().expect("loss all-reduce sync");
|
||
t.to_device(Device::Cpu).as_slice::<f32>()[0]
|
||
}
|
||
|
||
fn read_scalar(v: &Var) -> f32 {
|
||
v.value().to_device(Device::Cpu).as_slice::<f32>()[0]
|
||
}
|
||
|
||
/// Build a `TinyTransformer` on `device` with the SAME deterministic init the
|
||
/// single-GPU `bin/train` uses (LCG fill, gammas ~1). Used by both the launcher
|
||
/// and the correctness test so every rank — and the single-GPU baseline — start
|
||
/// from bit-identical parameters. `cfg` must be identical on every call.
|
||
pub fn build_model(cfg: Config, device: Device) -> TinyTransformer {
|
||
let mut seed = 1u64;
|
||
TinyTransformer::new(cfg, device, |shape| {
|
||
seed = seed.wrapping_add(1);
|
||
let n: usize = shape.iter().product();
|
||
if shape.len() == 1 {
|
||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||
} else {
|
||
fill(n, seed, 0.04)
|
||
}
|
||
})
|
||
}
|
||
|
||
// Deterministic LCG fill of `n` values in roughly [-scale, scale]. The upper
// bound can be hit because the 31-bit draw is rounded to f32. This must stay
// bit-identical to bin/train's `fill` so the DDP and single-GPU inits match.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_INC: u64 = 1442695040888963407;

    let denom = (1u64 << 31) as f32;
    let mut state = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        state = state.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
        // Top 31 bits of the state, mapped to [0, 1].
        let unit = (state >> 33) as f32 / denom;
        out.push((unit - 0.5) * 2.0 * scale);
    }
    out
}
|