distributed: train_ddp_mp bin (process-per-GPU launcher/worker)
Dual-mode binary self-detecting via XTRAIN_RANK: launcher spawns one worker per visible GPU forwarding full argv; worker rebuilds config from argv and runs run_worker. CLI flags identical to train_ddp (thread-per-GPU, kept), so it doubles as the before->after throughput driver. thread-per-GPU path untouched. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
203
crates/xtrain-distributed/src/bin/train_ddp_mp.rs
Normal file
203
crates/xtrain-distributed/src/bin/train_ddp_mp.rs
Normal file
@@ -0,0 +1,203 @@
|
||||
//! Process-per-GPU DDP launcher / worker (Phase T17, torchrun-style).
|
||||
//!
|
||||
//! ONE binary, two modes (it self-detects via `XTRAIN_RANK`):
|
||||
//! - **launcher** (env unset): mints the NCCL `ncclUniqueId`, then spawns one
|
||||
//! WORKER process per visible GPU, re-execing this same binary with the same
|
||||
//! argv plus `XTRAIN_{RANK,WORLD,LOCAL_RANK,NCCL_ID}` env, and waits for them.
|
||||
//! - **worker** (`XTRAIN_RANK` set): binds its GPU (→ its own CUDA context),
|
||||
//! inits NCCL with the launcher-supplied id, builds its model, runs
|
||||
//! `train_rank` — the T8 training step reused UNCHANGED.
|
||||
//!
|
||||
//! Versus `train_ddp` (thread-per-GPU, kept as the regression baseline) the ONLY
|
||||
//! difference is the launch model + cross-process UniqueId bootstrap. CLI flags
|
||||
//! are identical, so it doubles as the before→after throughput driver.
|
||||
//!
|
||||
//! Run on dash5 (pick idle GPUs — dash5 is shared):
|
||||
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
|
||||
//! CUDA_VISIBLE_DEVICES=0,1,2,3 cargo run -p xtrain-distributed --release \
|
||||
//! --bin train_ddp_mp -- /opt/wjh/models/gpt2/tokenizer.json \
|
||||
//! data/tinystories-valid-3mb.txt \
|
||||
//! --dim 384 --heads 12 --head-dim 32 --layers 12 --ffn 1536 \
|
||||
//! --steps 200 --batch 128 --seq 256
|
||||
|
||||
#[cfg(no_cuda)]
|
||||
fn main() {
|
||||
eprintln!("train_ddp_mp: built without CUDA (no_cuda); run on a GPU host (dash5).");
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
use std::path::PathBuf;
|
||||
|
||||
// A flag like `--dim 384`: scan argv for `name`, parse the following token.
|
||||
#[cfg(not(no_cuda))]
|
||||
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
|
||||
args.iter()
|
||||
.position(|a| a == name)
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_distributed::DdpConfig;
|
||||
use xtrain_distributed::proc::{ModelOpts, launch_processes, run_worker, worker_env};
|
||||
use xtrain_model::Config;
|
||||
use xtrain_train::data::Corpus;
|
||||
use xtrain_train::schedule::LrSchedule;
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
|
||||
// ── Launcher mode: no XTRAIN_RANK in env → spawn one worker per visible GPU.
|
||||
let env = worker_env();
|
||||
if env.is_none() {
|
||||
let count = device::device_count().expect("device_count");
|
||||
assert!(count > 0, "no CUDA device visible");
|
||||
let world = count as usize;
|
||||
// Forward the full argv (minus argv[0]) to each worker verbatim.
|
||||
let extra: Vec<String> = args[1..].to_vec();
|
||||
println!("DDP (process-per-GPU): launching {world} worker processes (one per visible GPU)");
|
||||
match launch_processes(world, &extra) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
eprintln!("launcher: {e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
let env = env.unwrap();
|
||||
|
||||
// ── Worker mode: build config from the forwarded argv, then train this rank.
|
||||
// First two non-flag positionals: tokenizer.json, corpus.txt.
|
||||
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
||||
let tok_path = positionals
|
||||
.first()
|
||||
.map(|s| PathBuf::from(s.as_str()))
|
||||
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
|
||||
let corpus_path = positionals
|
||||
.get(1)
|
||||
.map(|s| PathBuf::from(s.as_str()))
|
||||
.unwrap_or_else(|| PathBuf::from("data/tinystories-valid-3mb.txt"));
|
||||
|
||||
// Architecture (scaling-ladder rung). Defaults = v0-baseline tiny config.
|
||||
let n_heads = flag(&args, "--heads", 2usize);
|
||||
let head_dim = flag(&args, "--head-dim", 16usize);
|
||||
let n_layers = flag(&args, "--layers", 4usize);
|
||||
let ffn = flag(&args, "--ffn", 64usize);
|
||||
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||||
let dim_flag = flag(&args, "--dim", 0usize);
|
||||
if dim_flag != 0 && dim_flag != n_heads * head_dim {
|
||||
eprintln!(
|
||||
"warning: --dim {dim_flag} != heads*head_dim {}; using {}",
|
||||
n_heads * head_dim,
|
||||
n_heads * head_dim
|
||||
);
|
||||
}
|
||||
|
||||
// Optimization knobs (mirror train_ddp).
|
||||
let steps: usize = flag(&args, "--steps", 100);
|
||||
let batch: usize = flag(&args, "--batch", 16);
|
||||
let accum_steps: usize = flag(&args, "--accum-steps", 1).max(1);
|
||||
let seq_len: usize = flag(&args, "--seq", 64);
|
||||
let max_lr: f32 = flag(&args, "--max-lr", 3e-3);
|
||||
let min_lr: f32 = flag(&args, "--min-lr", max_lr * 0.1);
|
||||
let weight_decay: f32 = flag(&args, "--wd", 0.1);
|
||||
let max_grad_norm: f32 = flag(&args, "--clip", 1.0);
|
||||
let val_tokens: usize = flag(&args, "--val-tokens", 0);
|
||||
let eval_every: usize = flag(&args, "--eval-every", 0);
|
||||
let eval_batches: usize = flag(&args, "--eval-batches", 64);
|
||||
let opts = ModelOpts {
|
||||
bf16: args.iter().any(|a| a == "--bf16"),
|
||||
recompute: args.iter().any(|a| a == "--recompute"),
|
||||
flash: args.iter().any(|a| a == "--flash"),
|
||||
};
|
||||
let ckpt: Option<PathBuf> = args
|
||||
.iter()
|
||||
.position(|a| a == "--ckpt")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.map(PathBuf::from);
|
||||
|
||||
assert_eq!(
|
||||
batch % env.world,
|
||||
0,
|
||||
"global batch {batch} not divisible by world {}",
|
||||
env.world
|
||||
);
|
||||
|
||||
// Each worker loads the corpus independently (read-only u16 cache hit → cheap).
|
||||
let corpus = Corpus::load_cached(&tok_path, &corpus_path);
|
||||
let vocab = corpus.vocab_size;
|
||||
let (train_corpus, valid): (Corpus, Option<Corpus>) = if val_tokens > 0 {
|
||||
let (t, v) = corpus.split_tail(val_tokens);
|
||||
(t, Some(v))
|
||||
} else {
|
||||
(corpus, None)
|
||||
};
|
||||
|
||||
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
|
||||
|
||||
if env.rank == 0 {
|
||||
println!(
|
||||
"model: dim {} layers {} heads {} kv_heads {} head_dim {} ffn {} → core {:.3}M params \
|
||||
(+ embed/lm {:.2}M = {:.2}M total) | world={} mode=process-per-GPU",
|
||||
cfg.dim,
|
||||
cfg.n_layers,
|
||||
cfg.n_heads,
|
||||
cfg.num_kv_heads,
|
||||
cfg.head_dim,
|
||||
cfg.ffn_hidden,
|
||||
cfg.core_params() as f32 / 1e6,
|
||||
(cfg.num_params() - cfg.core_params()) as f32 / 1e6,
|
||||
cfg.num_params() as f32 / 1e6,
|
||||
env.world,
|
||||
);
|
||||
if opts.bf16 {
|
||||
println!("bf16 mixed precision: ON (fp32 master weights)");
|
||||
}
|
||||
if opts.recompute {
|
||||
println!("activation recompute: ON (per-block gradient checkpointing)");
|
||||
}
|
||||
if opts.flash {
|
||||
println!("flash-attention: ON (fused SDPA kernel, no materialized scores)");
|
||||
}
|
||||
}
|
||||
|
||||
let dcfg = DdpConfig {
|
||||
seq_len,
|
||||
batch_size: batch,
|
||||
accum_steps,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr,
|
||||
min_lr,
|
||||
warmup: (steps / 20).max(5),
|
||||
total: steps,
|
||||
},
|
||||
weight_decay,
|
||||
max_grad_norm,
|
||||
log_every: 50,
|
||||
seed: 42,
|
||||
eval_every,
|
||||
eval_batches,
|
||||
ckpt_path: ckpt.clone(),
|
||||
};
|
||||
|
||||
let res = run_worker(&env, cfg, opts, &train_corpus, valid.as_ref(), &dcfg);
|
||||
|
||||
if env.rank == 0 {
|
||||
let start = res.losses.first().copied().unwrap_or(0.0);
|
||||
let end = res.losses.last().copied().unwrap_or(0.0);
|
||||
println!("train loss: start {start:.4} → end {end:.4}");
|
||||
if let Some(best) = res.best_val {
|
||||
println!("best val loss: {best:.4}");
|
||||
}
|
||||
if let Some((s, v)) = res.evals.last() {
|
||||
println!("final val loss (step {s}): {v:.4}");
|
||||
}
|
||||
if let Some(path) = &ckpt {
|
||||
println!("best-val checkpoint → {}", path.display());
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user