train: bring DDP trainer to parity with bin/train (val + checkpoint + cache + arch)
The T8 DDP path now matches the single-GPU `bin/train`: CLI-tunable arch (scaling-ladder rung), the cached token-id stream (`load_cached`), held-out val-loss eval + best-val checkpointing, and LR warmup→cosine. Rank 0 owns the val corpus and runs the no-grad eval / writes the best checkpoint (params are bit-identical across ranks). The eval/checkpoint logic is reused from `xtrain-train` (`eval_loss`, `checkpoint::save`) rather than duplicated. - DdpConfig gains eval_every / eval_batches / ckpt_path. - train_rank takes `valid: Option<&Corpus>` and returns DdpResult (losses + evals + best_val); launch threads the val corpus to rank 0 only. - bin/train_ddp reworked to the bin/train CLI (positional tokenizer/corpus + --dim/--heads/--head-dim/--layers/--ffn/--steps/--batch/--seq/--max-lr/ --val-tokens/--eval-every/--ckpt), reusing the u16 cache. - DDP correctness test updated to the new signatures (semantics unchanged). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -1,15 +1,22 @@
|
||||
//! Multi-rank DDP training launcher (Phase T8): spawn one thread per GPU, NCCL
|
||||
//! all-reduce the gradients each step, and train the tiny transformer on
|
||||
//! TinyStories. Doubles as the throughput driver — run it with 1/2/4 GPUs and
|
||||
//! read the global tok/s line.
|
||||
//! Multi-rank DDP training launcher (Phase T8 / Scaling v2): spawn one thread per
|
||||
//! GPU, NCCL all-reduce the gradients each step, and train the tiny transformer on
|
||||
//! TinyStories. At parity with the single-GPU `bin/train`: CLI-tunable arch
|
||||
//! (scaling-ladder rung), the cached token-id stream, held-out val-loss eval, LR
|
||||
//! warmup→cosine, grad clip, and best-val checkpointing. Doubles as the throughput
|
||||
//! driver — run it with 1/2/4 GPUs and read the global tok/s line.
|
||||
//!
|
||||
//! Run on dash5 (pick idle GPUs — dash5 is shared):
|
||||
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
|
||||
//! CUDA_VISIBLE_DEVICES=0,1 cargo run -p xtrain-distributed --release \
|
||||
//! --bin train_ddp -- 100 64 16
|
||||
//! Args: [steps] [seq_len] [global_batch] [tokenizer.json] [corpus.txt]
|
||||
//! The launcher uses every GPU visible to it (CUDA_VISIBLE_DEVICES selects them),
|
||||
//! so the rank devices are always 0..N within the visible set.
|
||||
//! CUDA_VISIBLE_DEVICES=1,2 cargo run -p xtrain-distributed --release \
|
||||
//! --bin train_ddp -- /opt/wjh/models/gpt2/tokenizer.json \
|
||||
//! data/tinystories-train.txt \
|
||||
//! --dim 384 --heads 12 --head-dim 32 --layers 12 --ffn 1536 \
|
||||
//! --steps 6000 --batch 32 --seq 256 --max-lr 6e-4 \
|
||||
//! --val-tokens 1000000 --eval-every 500 --ckpt /tmp/xtrain_v2.ckpt
|
||||
//!
|
||||
//! Positional: <tokenizer.json> <corpus.txt>. Everything else is a flag with a
|
||||
//! sane default. The launcher uses every GPU visible to it (CUDA_VISIBLE_DEVICES
|
||||
//! selects them), so rank devices are always 0..N within the visible set.
|
||||
|
||||
#[cfg(no_cuda)]
|
||||
fn main() {
|
||||
@@ -17,8 +24,20 @@ fn main() {
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
use std::path::PathBuf;
|
||||
|
||||
// A flag like `--dim 384`: scan argv for `name`, parse the following token.
|
||||
#[cfg(not(no_cuda))]
|
||||
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
|
||||
args.iter()
|
||||
.position(|a| a == name)
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_distributed::{DdpConfig, build_model, launch};
|
||||
use xtrain_model::Config;
|
||||
@@ -26,18 +45,49 @@ fn main() {
|
||||
use xtrain_train::schedule::LrSchedule;
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let steps: usize = args.get(1).and_then(|s| s.parse().ok()).unwrap_or(100);
|
||||
let seq_len: usize = args.get(2).and_then(|s| s.parse().ok()).unwrap_or(64);
|
||||
let batch: usize = args.get(3).and_then(|s| s.parse().ok()).unwrap_or(16);
|
||||
let tok_path = args
|
||||
.get(4)
|
||||
.map(PathBuf::from)
|
||||
// First two non-flag positionals: tokenizer.json, corpus.txt.
|
||||
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
||||
let tok_path = positionals
|
||||
.first()
|
||||
.map(|s| PathBuf::from(s.as_str()))
|
||||
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
|
||||
let corpus_path = args
|
||||
.get(5)
|
||||
.map(PathBuf::from)
|
||||
let corpus_path = positionals
|
||||
.get(1)
|
||||
.map(|s| PathBuf::from(s.as_str()))
|
||||
.unwrap_or_else(|| PathBuf::from("data/tinystories-valid-3mb.txt"));
|
||||
|
||||
// Architecture (scaling-ladder rung). Defaults = v0-baseline tiny config.
|
||||
let n_heads = flag(&args, "--heads", 2usize);
|
||||
let head_dim = flag(&args, "--head-dim", 16usize);
|
||||
let n_layers = flag(&args, "--layers", 4usize);
|
||||
let ffn = flag(&args, "--ffn", 64usize);
|
||||
// `--dim` is informational; dim is always n_heads*head_dim. Warn on mismatch.
|
||||
let dim_flag = flag(&args, "--dim", 0usize);
|
||||
if dim_flag != 0 && dim_flag != n_heads * head_dim {
|
||||
eprintln!(
|
||||
"warning: --dim {dim_flag} != heads*head_dim {}; using {}",
|
||||
n_heads * head_dim,
|
||||
n_heads * head_dim
|
||||
);
|
||||
}
|
||||
|
||||
// Optimization knobs (mirror bin/train).
|
||||
let steps: usize = flag(&args, "--steps", 100);
|
||||
let batch: usize = flag(&args, "--batch", 16);
|
||||
let seq_len: usize = flag(&args, "--seq", 64);
|
||||
let max_lr: f32 = flag(&args, "--max-lr", 3e-3);
|
||||
let min_lr: f32 = flag(&args, "--min-lr", max_lr * 0.1);
|
||||
let weight_decay: f32 = flag(&args, "--wd", 0.1);
|
||||
let max_grad_norm: f32 = flag(&args, "--clip", 1.0);
|
||||
let val_tokens: usize = flag(&args, "--val-tokens", 0);
|
||||
let eval_every: usize = flag(&args, "--eval-every", 0);
|
||||
let eval_batches: usize = flag(&args, "--eval-batches", 64);
|
||||
let ckpt: Option<PathBuf> = args
|
||||
.iter()
|
||||
.position(|a| a == "--ckpt")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.map(PathBuf::from);
|
||||
|
||||
// Use every visible GPU as a rank (CUDA_VISIBLE_DEVICES selects the set;
|
||||
// device ordinals are 0..count within it).
|
||||
let count = device::device_count().expect("device_count") as u32;
|
||||
@@ -56,23 +106,35 @@ fn main() {
|
||||
devices
|
||||
);
|
||||
|
||||
let corpus = Corpus::load(&tok_path, &corpus_path);
|
||||
// Reuse the cached token-id stream (v1's u16 cache); never re-tokenize 2GB.
|
||||
let corpus = Corpus::load_cached(&tok_path, &corpus_path);
|
||||
println!(
|
||||
"corpus: {} tokens, vocab {}",
|
||||
corpus.len(),
|
||||
corpus.vocab_size
|
||||
);
|
||||
let vocab = corpus.vocab_size;
|
||||
// Hold out a tail slice for validation (rank 0 evaluates on it).
|
||||
let (train_corpus, valid) = if val_tokens > 0 {
|
||||
let (t, v) = corpus.split_tail(val_tokens);
|
||||
println!("split: {} train tokens / {} val tokens", t.len(), v.len());
|
||||
(t, Some(v))
|
||||
} else {
|
||||
(corpus, None)
|
||||
};
|
||||
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = corpus.vocab_size;
|
||||
cfg.n_layers = 4;
|
||||
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
|
||||
println!(
|
||||
"model: dim {} layers {} heads {} ffn {} → {} params",
|
||||
"model: dim {} layers {} heads {} head_dim {} ffn {} → core {:.3}M params \
|
||||
(+ embed/lm {:.2}M = {:.2}M total)",
|
||||
cfg.dim,
|
||||
cfg.n_layers,
|
||||
cfg.n_heads,
|
||||
cfg.head_dim,
|
||||
cfg.ffn_hidden,
|
||||
cfg.num_params()
|
||||
cfg.core_params() as f32 / 1e6,
|
||||
(cfg.num_params() - cfg.core_params()) as f32 / 1e6,
|
||||
cfg.num_params() as f32 / 1e6,
|
||||
);
|
||||
|
||||
let dcfg = DdpConfig {
|
||||
@@ -80,22 +142,43 @@ fn main() {
|
||||
batch_size: batch,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr: 3e-3,
|
||||
min_lr: 3e-4,
|
||||
max_lr,
|
||||
min_lr,
|
||||
warmup: (steps / 20).max(5),
|
||||
total: steps,
|
||||
},
|
||||
weight_decay: 0.1,
|
||||
max_grad_norm: 1.0,
|
||||
log_every: 10,
|
||||
weight_decay,
|
||||
max_grad_norm,
|
||||
log_every: 50,
|
||||
seed: 42,
|
||||
eval_every,
|
||||
eval_batches,
|
||||
ckpt_path: ckpt.clone(),
|
||||
};
|
||||
|
||||
let traces = launch(&devices, &corpus, &dcfg, move |device| {
|
||||
build_model(cfg, device)
|
||||
});
|
||||
let trace = &traces[0];
|
||||
let start = trace.first().copied().unwrap_or(0.0);
|
||||
let end = trace.last().copied().unwrap_or(0.0);
|
||||
println!("loss: start {start:.4} → end {end:.4}");
|
||||
println!(
|
||||
"training: {steps} steps, seq {seq_len}, global batch {batch}, lr {max_lr:.1e}→{min_lr:.1e}, \
|
||||
eval every {eval_every}"
|
||||
);
|
||||
|
||||
let results = launch(
|
||||
&devices,
|
||||
&train_corpus,
|
||||
valid.as_ref(),
|
||||
&dcfg,
|
||||
move |device| build_model(cfg, device),
|
||||
);
|
||||
let r0 = &results[0];
|
||||
let start = r0.losses.first().copied().unwrap_or(0.0);
|
||||
let end = r0.losses.last().copied().unwrap_or(0.0);
|
||||
println!("train loss: start {start:.4} → end {end:.4}");
|
||||
if let Some(best) = r0.best_val {
|
||||
println!("best val loss: {best:.4}");
|
||||
}
|
||||
if let Some((s, v)) = r0.evals.last() {
|
||||
println!("final val loss (step {s}): {v:.4}");
|
||||
}
|
||||
if let Some(path) = &ckpt {
|
||||
println!("best-val checkpoint → {}", path.display());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
//! is exactly the single-GPU batch in the same order, so the all-reduced grad sum
|
||||
//! equals the single-GPU summed grad.
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::thread;
|
||||
use std::time::Instant;
|
||||
|
||||
@@ -19,8 +20,10 @@ use xtrain_autodiff::tape::Var;
|
||||
use xtrain_model::{Config, TinyTransformer, ids_tensor};
|
||||
use xtrain_optim::GpuAdamW;
|
||||
use xtrain_tensor::Device;
|
||||
use xtrain_train::checkpoint;
|
||||
use xtrain_train::clip::clip_grad_norm_gpu;
|
||||
use xtrain_train::data::Corpus;
|
||||
use xtrain_train::eval_loss;
|
||||
use xtrain_train::schedule::LrSchedule;
|
||||
|
||||
use crate::{DdpContext, get_unique_id};
|
||||
@@ -38,20 +41,43 @@ pub struct DdpConfig {
|
||||
pub max_grad_norm: f32,
|
||||
pub log_every: usize,
|
||||
pub seed: u64,
|
||||
/// Evaluate held-out val loss every `eval_every` steps (0 = never). Only rank
|
||||
/// 0 holds the `valid` corpus and runs the eval (no grad), mirroring
|
||||
/// `xtrain_train::TrainConfig`. The best-val model is checkpointed by rank 0
|
||||
/// (every rank's params are identical, so rank 0's are the model's).
|
||||
pub eval_every: usize,
|
||||
pub eval_batches: usize,
|
||||
/// Best-val checkpoint path (written by rank 0 when val improves). When unset,
|
||||
/// or when `eval_every == 0`, no checkpoint is written.
|
||||
pub ckpt_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
/// Outcome of a DDP run on this rank: per-step mean-loss trace plus, when
|
||||
/// `eval_every > 0`, the (step, val_loss) eval points and the best val loss
|
||||
/// (eval/best are only populated on rank 0, which owns the `valid` corpus).
|
||||
pub struct DdpResult {
|
||||
pub losses: Vec<f32>,
|
||||
pub evals: Vec<(usize, f32)>,
|
||||
pub best_val: Option<f32>,
|
||||
}
|
||||
|
||||
/// Run `cfg.steps` DDP steps on this rank's `model`/`corpus`, using `ctx` for the
|
||||
/// gradient all-reduce. Returns this rank's per-step mean-loss trace (the mean
|
||||
/// over the GLOBAL batch — every rank computes the same value because losses are
|
||||
/// all-reduced alongside the grads). The optimizer step is identical on every
|
||||
/// rank, so the parameters stay in lockstep.
|
||||
/// all-reduced alongside the grads) plus eval/best-val (rank 0 only). The
|
||||
/// optimizer step is identical on every rank, so the parameters stay in lockstep.
|
||||
///
|
||||
/// `valid` is the held-out corpus for periodic val-loss eval. Only rank 0 needs
|
||||
/// it (it runs the no-grad eval and writes the best-val checkpoint); pass `None`
|
||||
/// on the other ranks (or when `cfg.eval_every == 0`).
|
||||
pub fn train_rank(
|
||||
ctx: &DdpContext,
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
corpus: &Corpus,
|
||||
valid: Option<&Corpus>,
|
||||
cfg: &DdpConfig,
|
||||
) -> Vec<f32> {
|
||||
) -> DdpResult {
|
||||
assert_eq!(
|
||||
cfg.batch_size % ctx.world,
|
||||
0,
|
||||
@@ -63,12 +89,17 @@ pub fn train_rank(
|
||||
let mut opt = GpuAdamW::new(cfg.weight_decay);
|
||||
let mut rng = cfg.seed;
|
||||
let mut losses = Vec::with_capacity(cfg.steps);
|
||||
let mut evals = Vec::new();
|
||||
let mut best_val: Option<f32> = None;
|
||||
// Each rank reaches the global batch mean as (Σ_global / world) · (1/b_local),
|
||||
// where b_local = batch_size / world (see DdpContext::all_reduce_average_grads).
|
||||
let batch_local = cfg.batch_size / ctx.world;
|
||||
let inv_batch_local = 1.0 / batch_local as f32;
|
||||
let start = Instant::now();
|
||||
let mut tokens_seen: u64 = 0;
|
||||
// Rank 0 owns the held-out eval + best-val checkpoint (params are identical
|
||||
// across ranks, so rank 0's are the model). Other ranks never touch `valid`.
|
||||
let do_eval = ctx.rank == 0 && cfg.eval_every > 0 && valid.is_some();
|
||||
|
||||
for step in 0..cfg.steps {
|
||||
let lr = cfg.schedule.lr(step);
|
||||
@@ -114,19 +145,52 @@ pub fn train_rank(
|
||||
cfg.steps, ctx.world
|
||||
);
|
||||
}
|
||||
|
||||
// Periodic held-out eval + best-val checkpoint (rank 0 only). Mirrors the
|
||||
// single-GPU `xtrain_train::train` loop, reusing its `eval_loss` /
|
||||
// `checkpoint::save` so single-GPU and DDP share one eval/ckpt path. Other
|
||||
// ranks have nothing to do here (params are identical across ranks).
|
||||
if do_eval && ((step + 1) % cfg.eval_every == 0 || step == cfg.steps - 1) {
|
||||
let v = valid.unwrap();
|
||||
let vl = eval_loss(model, device, v, cfg.seq_len, cfg.eval_batches);
|
||||
evals.push((step, vl));
|
||||
let improved = best_val.map(|b| vl < b).unwrap_or(true);
|
||||
println!(
|
||||
" [rank0] eval @ step {step}: val loss {vl:.4}{}",
|
||||
if improved { " (best)" } else { "" }
|
||||
);
|
||||
if improved {
|
||||
best_val = Some(vl);
|
||||
if let Some(path) = &cfg.ckpt_path {
|
||||
checkpoint::save(path, ¶ms).expect("best checkpoint save");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
DdpResult {
|
||||
losses,
|
||||
evals,
|
||||
best_val,
|
||||
}
|
||||
losses
|
||||
}
|
||||
|
||||
/// Spawn `world` rank threads (one per GPU in `devices`), init NCCL, build an
|
||||
/// identical model per rank via `make_model`, and run `train_rank`. Returns each
|
||||
/// rank's loss trace (all identical). The launcher owns the thread-per-GPU model:
|
||||
/// rank 0 mints the `UniqueId`, every thread `cudaSetDevice`s its GPU, builds its
|
||||
/// `Var` graph locally (the graph is `!Send`), and joins at the end.
|
||||
/// rank's `DdpResult` (loss traces are identical; eval/best-val are on rank 0).
|
||||
/// The launcher owns the thread-per-GPU model: rank 0 mints the `UniqueId`, every
|
||||
/// thread `cudaSetDevice`s its GPU, builds its `Var` graph locally (the graph is
|
||||
/// `!Send`), and joins at the end.
|
||||
///
|
||||
/// `make_model(device)` must be deterministic — same params on every rank — for
|
||||
/// the parameters to stay consistent.
|
||||
pub fn launch<F>(devices: &[u32], corpus: &Corpus, cfg: &DdpConfig, make_model: F) -> Vec<Vec<f32>>
|
||||
/// `valid` is the held-out corpus for rank 0's periodic eval (only used when
|
||||
/// `cfg.eval_every > 0`). `make_model(device)` must be deterministic — same params
|
||||
/// on every rank — for the parameters to stay consistent.
|
||||
pub fn launch<F>(
|
||||
devices: &[u32],
|
||||
corpus: &Corpus,
|
||||
valid: Option<&Corpus>,
|
||||
cfg: &DdpConfig,
|
||||
make_model: F,
|
||||
) -> Vec<DdpResult>
|
||||
where
|
||||
F: Fn(Device) -> TinyTransformer + Send + Sync,
|
||||
{
|
||||
@@ -144,7 +208,9 @@ where
|
||||
let ctx = DdpContext::init(rank, world, id, dev);
|
||||
let device = Device::Cuda(dev);
|
||||
let model = make_model(device);
|
||||
train_rank(&ctx, &model, device, corpus, &cfg)
|
||||
// Only rank 0 holds the val corpus for eval.
|
||||
let v = if rank == 0 { valid } else { None };
|
||||
train_rank(&ctx, &model, device, corpus, v, &cfg)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
pub mod ddp;
|
||||
pub mod ffi;
|
||||
|
||||
pub use ddp::{DdpConfig, build_model, launch, train_rank};
|
||||
pub use ddp::{DdpConfig, DdpResult, build_model, launch, train_rank};
|
||||
|
||||
use std::ffi::c_void;
|
||||
|
||||
|
||||
@@ -102,6 +102,9 @@ fn ddp_matches_single_gpu_and_params_consistent() {
|
||||
max_grad_norm: 1.0,
|
||||
log_every: 1_000_000, // silence per-step logging in the test
|
||||
seed: 7,
|
||||
eval_every: 0,
|
||||
eval_batches: 0,
|
||||
ckpt_path: None,
|
||||
};
|
||||
|
||||
// Single-GPU baseline (world=1) over the global batch.
|
||||
@@ -121,13 +124,13 @@ fn ddp_matches_single_gpu_and_params_consistent() {
|
||||
let ctx = DdpContext::init(rank, world, id, dev);
|
||||
let device = Device::Cuda(dev);
|
||||
let model = build_model(cfg, device);
|
||||
let losses = train_rank(&ctx, &model, device, corpus, &dcfg);
|
||||
let res = train_rank(&ctx, &model, device, corpus, None, &dcfg);
|
||||
let host = model
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| p.value().to_device(Device::Cpu).as_slice::<f32>().to_vec())
|
||||
.collect::<Vec<_>>();
|
||||
(losses, host)
|
||||
(res.losses, host)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
@@ -224,10 +227,13 @@ fn ddp_throughput_scaling() {
|
||||
max_grad_norm: 1.0,
|
||||
log_every: 1_000_000,
|
||||
seed: 1,
|
||||
eval_every: 0,
|
||||
eval_batches: 0,
|
||||
ckpt_path: None,
|
||||
};
|
||||
let total_tokens = (steps * dcfg.batch_size * seq_len) as f64;
|
||||
let t = Instant::now();
|
||||
let _ = launch(&devices, &corpus, &dcfg, move |device| {
|
||||
let _ = launch(&devices, &corpus, None, &dcfg, move |device| {
|
||||
build_model(cfg, device)
|
||||
});
|
||||
let secs = t.elapsed().as_secs_f64();
|
||||
|
||||
Reference in New Issue
Block a user