train: bring DDP trainer to parity with bin/train (val + checkpoint + cache + arch)

The T8 DDP path now matches the single-GPU `bin/train`: CLI-tunable arch
(scaling-ladder rung), the cached token-id stream (`load_cached`), held-out
val-loss eval + best-val checkpointing, and LR warmup→cosine. Rank 0 owns the
val corpus and runs the no-grad eval / writes the best checkpoint (params are
bit-identical across ranks). The eval/checkpoint logic is reused from
`xtrain-train` (`eval_loss`, `checkpoint::save`) rather than duplicated.

- DdpConfig gains eval_every / eval_batches / ckpt_path.
- train_rank takes `valid: Option<&Corpus>` and returns DdpResult
  (losses + evals + best_val); launch threads the val corpus to rank 0 only.
- bin/train_ddp reworked to the bin/train CLI (positional tokenizer/corpus +
  --dim/--heads/--head-dim/--layers/--ffn/--steps/--batch/--seq/--max-lr/
  --val-tokens/--eval-every/--ckpt), reusing the u16 cache.
- DDP correctness test updated to the new signatures (semantics unchanged).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 19:34:40 +08:00
parent 264660527f
commit 7090b475fb
4 changed files with 207 additions and 52 deletions

View File

@@ -1,24 +1,43 @@
//! Multi-rank DDP training launcher (Phase T8): spawn one thread per GPU, NCCL
//! all-reduce the gradients each step, and train the tiny transformer on
//! TinyStories. Doubles as the throughput driver — run it with 1/2/4 GPUs and
//! read the global tok/s line.
//! Multi-rank DDP training launcher (Phase T8 / Scaling v2): spawn one thread per
//! GPU, NCCL all-reduce the gradients each step, and train the tiny transformer on
//! TinyStories. At parity with the single-GPU `bin/train`: CLI-tunable arch
//! (scaling-ladder rung), the cached token-id stream, held-out val-loss eval, LR
//! warmup→cosine, grad clip, and best-val checkpointing. Doubles as the throughput
//! driver — run it with 1/2/4 GPUs and read the global tok/s line.
//!
//! Run on dash5 (pick idle GPUs — dash5 is shared):
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
//! CUDA_VISIBLE_DEVICES=0,1 cargo run -p xtrain-distributed --release \
//! --bin train_ddp -- 100 64 16
//! Args: [steps] [seq_len] [global_batch] [tokenizer.json] [corpus.txt]
//! The launcher uses every GPU visible to it (CUDA_VISIBLE_DEVICES selects them),
//! so the rank devices are always 0..N within the visible set.
//! CUDA_VISIBLE_DEVICES=1,2 cargo run -p xtrain-distributed --release \
//! --bin train_ddp -- /opt/wjh/models/gpt2/tokenizer.json \
//! data/tinystories-train.txt \
//! --dim 384 --heads 12 --head-dim 32 --layers 12 --ffn 1536 \
//! --steps 6000 --batch 32 --seq 256 --max-lr 6e-4 \
//! --val-tokens 1000000 --eval-every 500 --ckpt /tmp/xtrain_v2.ckpt
//!
//! Positional: <tokenizer.json> <corpus.txt>. Everything else is a flag with a
//! sane default. The launcher uses every GPU visible to it (CUDA_VISIBLE_DEVICES
//! selects them), so rank devices are always 0..N within the visible set.
#[cfg(no_cuda)]
fn main() {
eprintln!("train_ddp: built without CUDA (no_cuda); run on a GPU host (dash5).");
}
#[cfg(not(no_cuda))]
use std::path::PathBuf;
// A flag like `--dim 384`: scan argv for `name`, parse the following token.
#[cfg(not(no_cuda))]
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
args.iter()
.position(|a| a == name)
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
#[cfg(not(no_cuda))]
fn main() {
use std::path::PathBuf;
use xtrain_cuda::device;
use xtrain_distributed::{DdpConfig, build_model, launch};
use xtrain_model::Config;
@@ -26,18 +45,49 @@ fn main() {
use xtrain_train::schedule::LrSchedule;
let args: Vec<String> = std::env::args().collect();
let steps: usize = args.get(1).and_then(|s| s.parse().ok()).unwrap_or(100);
let seq_len: usize = args.get(2).and_then(|s| s.parse().ok()).unwrap_or(64);
let batch: usize = args.get(3).and_then(|s| s.parse().ok()).unwrap_or(16);
let tok_path = args
.get(4)
.map(PathBuf::from)
// First two non-flag positionals: tokenizer.json, corpus.txt.
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
let tok_path = positionals
.first()
.map(|s| PathBuf::from(s.as_str()))
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
let corpus_path = args
.get(5)
.map(PathBuf::from)
let corpus_path = positionals
.get(1)
.map(|s| PathBuf::from(s.as_str()))
.unwrap_or_else(|| PathBuf::from("data/tinystories-valid-3mb.txt"));
// Architecture (scaling-ladder rung). Defaults = v0-baseline tiny config.
let n_heads = flag(&args, "--heads", 2usize);
let head_dim = flag(&args, "--head-dim", 16usize);
let n_layers = flag(&args, "--layers", 4usize);
let ffn = flag(&args, "--ffn", 64usize);
// `--dim` is informational; dim is always n_heads*head_dim. Warn on mismatch.
let dim_flag = flag(&args, "--dim", 0usize);
if dim_flag != 0 && dim_flag != n_heads * head_dim {
eprintln!(
"warning: --dim {dim_flag} != heads*head_dim {}; using {}",
n_heads * head_dim,
n_heads * head_dim
);
}
// Optimization knobs (mirror bin/train).
let steps: usize = flag(&args, "--steps", 100);
let batch: usize = flag(&args, "--batch", 16);
let seq_len: usize = flag(&args, "--seq", 64);
let max_lr: f32 = flag(&args, "--max-lr", 3e-3);
let min_lr: f32 = flag(&args, "--min-lr", max_lr * 0.1);
let weight_decay: f32 = flag(&args, "--wd", 0.1);
let max_grad_norm: f32 = flag(&args, "--clip", 1.0);
let val_tokens: usize = flag(&args, "--val-tokens", 0);
let eval_every: usize = flag(&args, "--eval-every", 0);
let eval_batches: usize = flag(&args, "--eval-batches", 64);
let ckpt: Option<PathBuf> = args
.iter()
.position(|a| a == "--ckpt")
.and_then(|i| args.get(i + 1))
.map(PathBuf::from);
// Use every visible GPU as a rank (CUDA_VISIBLE_DEVICES selects the set;
// device ordinals are 0..count within it).
let count = device::device_count().expect("device_count") as u32;
@@ -56,23 +106,35 @@ fn main() {
devices
);
let corpus = Corpus::load(&tok_path, &corpus_path);
// Reuse the cached token-id stream (v1's u16 cache); never re-tokenize 2GB.
let corpus = Corpus::load_cached(&tok_path, &corpus_path);
println!(
"corpus: {} tokens, vocab {}",
corpus.len(),
corpus.vocab_size
);
let vocab = corpus.vocab_size;
// Hold out a tail slice for validation (rank 0 evaluates on it).
let (train_corpus, valid) = if val_tokens > 0 {
let (t, v) = corpus.split_tail(val_tokens);
println!("split: {} train tokens / {} val tokens", t.len(), v.len());
(t, Some(v))
} else {
(corpus, None)
};
let mut cfg = Config::tiny();
cfg.vocab = corpus.vocab_size;
cfg.n_layers = 4;
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
println!(
"model: dim {} layers {} heads {} ffn {}{} params",
"model: dim {} layers {} heads {} head_dim {} ffn {}core {:.3}M params \
(+ embed/lm {:.2}M = {:.2}M total)",
cfg.dim,
cfg.n_layers,
cfg.n_heads,
cfg.head_dim,
cfg.ffn_hidden,
cfg.num_params()
cfg.core_params() as f32 / 1e6,
(cfg.num_params() - cfg.core_params()) as f32 / 1e6,
cfg.num_params() as f32 / 1e6,
);
let dcfg = DdpConfig {
@@ -80,22 +142,43 @@ fn main() {
batch_size: batch,
steps,
schedule: LrSchedule {
max_lr: 3e-3,
min_lr: 3e-4,
max_lr,
min_lr,
warmup: (steps / 20).max(5),
total: steps,
},
weight_decay: 0.1,
max_grad_norm: 1.0,
log_every: 10,
weight_decay,
max_grad_norm,
log_every: 50,
seed: 42,
eval_every,
eval_batches,
ckpt_path: ckpt.clone(),
};
let traces = launch(&devices, &corpus, &dcfg, move |device| {
build_model(cfg, device)
});
let trace = &traces[0];
let start = trace.first().copied().unwrap_or(0.0);
let end = trace.last().copied().unwrap_or(0.0);
println!("loss: start {start:.4} → end {end:.4}");
println!(
"training: {steps} steps, seq {seq_len}, global batch {batch}, lr {max_lr:.1e}{min_lr:.1e}, \
eval every {eval_every}"
);
let results = launch(
&devices,
&train_corpus,
valid.as_ref(),
&dcfg,
move |device| build_model(cfg, device),
);
let r0 = &results[0];
let start = r0.losses.first().copied().unwrap_or(0.0);
let end = r0.losses.last().copied().unwrap_or(0.0);
println!("train loss: start {start:.4} → end {end:.4}");
if let Some(best) = r0.best_val {
println!("best val loss: {best:.4}");
}
if let Some((s, v)) = r0.evals.last() {
println!("final val loss (step {s}): {v:.4}");
}
if let Some(path) = &ckpt {
println!("best-val checkpoint → {}", path.display());
}
}

View File

@@ -12,6 +12,7 @@
//! is exactly the single-GPU batch in the same order, so the all-reduced grad sum
//! equals the single-GPU summed grad.
use std::path::PathBuf;
use std::thread;
use std::time::Instant;
@@ -19,8 +20,10 @@ use xtrain_autodiff::tape::Var;
use xtrain_model::{Config, TinyTransformer, ids_tensor};
use xtrain_optim::GpuAdamW;
use xtrain_tensor::Device;
use xtrain_train::checkpoint;
use xtrain_train::clip::clip_grad_norm_gpu;
use xtrain_train::data::Corpus;
use xtrain_train::eval_loss;
use xtrain_train::schedule::LrSchedule;
use crate::{DdpContext, get_unique_id};
@@ -38,20 +41,43 @@ pub struct DdpConfig {
pub max_grad_norm: f32,
pub log_every: usize,
pub seed: u64,
/// Evaluate held-out val loss every `eval_every` steps (0 = never). Only rank
/// 0 holds the `valid` corpus and runs the eval (no grad), mirroring
/// `xtrain_train::TrainConfig`. The best-val model is checkpointed by rank 0
/// (every rank's params are identical, so rank 0's are the model's).
pub eval_every: usize,
pub eval_batches: usize,
/// Best-val checkpoint path (written by rank 0 when val improves). When unset,
/// or when `eval_every == 0`, no checkpoint is written.
pub ckpt_path: Option<PathBuf>,
}
/// Outcome of a DDP run on this rank: per-step mean-loss trace plus, when
/// `eval_every > 0`, the (step, val_loss) eval points and the best val loss
/// (eval/best are only populated on rank 0, which owns the `valid` corpus).
pub struct DdpResult {
pub losses: Vec<f32>,
pub evals: Vec<(usize, f32)>,
pub best_val: Option<f32>,
}
/// Run `cfg.steps` DDP steps on this rank's `model`/`corpus`, using `ctx` for the
/// gradient all-reduce. Returns this rank's per-step mean-loss trace (the mean
/// over the GLOBAL batch — every rank computes the same value because losses are
/// all-reduced alongside the grads). The optimizer step is identical on every
/// rank, so the parameters stay in lockstep.
/// all-reduced alongside the grads) plus eval/best-val (rank 0 only). The
/// optimizer step is identical on every rank, so the parameters stay in lockstep.
///
/// `valid` is the held-out corpus for periodic val-loss eval. Only rank 0 needs
/// it (it runs the no-grad eval and writes the best-val checkpoint); pass `None`
/// on the other ranks (or when `cfg.eval_every == 0`).
pub fn train_rank(
ctx: &DdpContext,
model: &TinyTransformer,
device: Device,
corpus: &Corpus,
valid: Option<&Corpus>,
cfg: &DdpConfig,
) -> Vec<f32> {
) -> DdpResult {
assert_eq!(
cfg.batch_size % ctx.world,
0,
@@ -63,12 +89,17 @@ pub fn train_rank(
let mut opt = GpuAdamW::new(cfg.weight_decay);
let mut rng = cfg.seed;
let mut losses = Vec::with_capacity(cfg.steps);
let mut evals = Vec::new();
let mut best_val: Option<f32> = None;
// Each rank reaches the global batch mean as (Σ_global / world) · (1/b_local),
// where b_local = batch_size / world (see DdpContext::all_reduce_average_grads).
let batch_local = cfg.batch_size / ctx.world;
let inv_batch_local = 1.0 / batch_local as f32;
let start = Instant::now();
let mut tokens_seen: u64 = 0;
// Rank 0 owns the held-out eval + best-val checkpoint (params are identical
// across ranks, so rank 0's are the model). Other ranks never touch `valid`.
let do_eval = ctx.rank == 0 && cfg.eval_every > 0 && valid.is_some();
for step in 0..cfg.steps {
let lr = cfg.schedule.lr(step);
@@ -114,19 +145,52 @@ pub fn train_rank(
cfg.steps, ctx.world
);
}
// Periodic held-out eval + best-val checkpoint (rank 0 only). Mirrors the
// single-GPU `xtrain_train::train` loop, reusing its `eval_loss` /
// `checkpoint::save` so single-GPU and DDP share one eval/ckpt path. Other
// ranks have nothing to do here (params are identical across ranks).
if do_eval && ((step + 1) % cfg.eval_every == 0 || step == cfg.steps - 1) {
let v = valid.unwrap();
let vl = eval_loss(model, device, v, cfg.seq_len, cfg.eval_batches);
evals.push((step, vl));
let improved = best_val.map(|b| vl < b).unwrap_or(true);
println!(
" [rank0] eval @ step {step}: val loss {vl:.4}{}",
if improved { " (best)" } else { "" }
);
if improved {
best_val = Some(vl);
if let Some(path) = &cfg.ckpt_path {
checkpoint::save(path, &params).expect("best checkpoint save");
}
}
}
}
DdpResult {
losses,
evals,
best_val,
}
losses
}
/// Spawn `world` rank threads (one per GPU in `devices`), init NCCL, build an
/// identical model per rank via `make_model`, and run `train_rank`. Returns each
/// rank's loss trace (all identical). The launcher owns the thread-per-GPU model:
/// rank 0 mints the `UniqueId`, every thread `cudaSetDevice`s its GPU, builds its
/// `Var` graph locally (the graph is `!Send`), and joins at the end.
/// rank's `DdpResult` (loss traces are identical; eval/best-val are on rank 0).
/// The launcher owns the thread-per-GPU model: rank 0 mints the `UniqueId`, every
/// thread `cudaSetDevice`s its GPU, builds its `Var` graph locally (the graph is
/// `!Send`), and joins at the end.
///
/// `make_model(device)` must be deterministic — same params on every rank — for
/// the parameters to stay consistent.
pub fn launch<F>(devices: &[u32], corpus: &Corpus, cfg: &DdpConfig, make_model: F) -> Vec<Vec<f32>>
/// `valid` is the held-out corpus for rank 0's periodic eval (only used when
/// `cfg.eval_every > 0`). `make_model(device)` must be deterministic — same params
/// on every rank — for the parameters to stay consistent.
pub fn launch<F>(
devices: &[u32],
corpus: &Corpus,
valid: Option<&Corpus>,
cfg: &DdpConfig,
make_model: F,
) -> Vec<DdpResult>
where
F: Fn(Device) -> TinyTransformer + Send + Sync,
{
@@ -144,7 +208,9 @@ where
let ctx = DdpContext::init(rank, world, id, dev);
let device = Device::Cuda(dev);
let model = make_model(device);
train_rank(&ctx, &model, device, corpus, &cfg)
// Only rank 0 holds the val corpus for eval.
let v = if rank == 0 { valid } else { None };
train_rank(&ctx, &model, device, corpus, v, &cfg)
})
})
.collect();

View File

@@ -19,7 +19,7 @@
pub mod ddp;
pub mod ffi;
pub use ddp::{DdpConfig, build_model, launch, train_rank};
pub use ddp::{DdpConfig, DdpResult, build_model, launch, train_rank};
use std::ffi::c_void;

View File

@@ -102,6 +102,9 @@ fn ddp_matches_single_gpu_and_params_consistent() {
max_grad_norm: 1.0,
log_every: 1_000_000, // silence per-step logging in the test
seed: 7,
eval_every: 0,
eval_batches: 0,
ckpt_path: None,
};
// Single-GPU baseline (world=1) over the global batch.
@@ -121,13 +124,13 @@ fn ddp_matches_single_gpu_and_params_consistent() {
let ctx = DdpContext::init(rank, world, id, dev);
let device = Device::Cuda(dev);
let model = build_model(cfg, device);
let losses = train_rank(&ctx, &model, device, corpus, &dcfg);
let res = train_rank(&ctx, &model, device, corpus, None, &dcfg);
let host = model
.params()
.iter()
.map(|p| p.value().to_device(Device::Cpu).as_slice::<f32>().to_vec())
.collect::<Vec<_>>();
(losses, host)
(res.losses, host)
})
})
.collect();
@@ -224,10 +227,13 @@ fn ddp_throughput_scaling() {
max_grad_norm: 1.0,
log_every: 1_000_000,
seed: 1,
eval_every: 0,
eval_batches: 0,
ckpt_path: None,
};
let total_tokens = (steps * dcfg.batch_size * seq_len) as f64;
let t = Instant::now();
let _ = launch(&devices, &corpus, &dcfg, move |device| {
let _ = launch(&devices, &corpus, None, &dcfg, move |device| {
build_model(cfg, device)
});
let secs = t.elapsed().as_secs_f64();