distributed: process-per-GPU launcher + worker (proc.rs)

torchrun-style process-per-GPU: launch_processes spawns one worker process per
GPU (re-exec current_exe with XTRAIN_{RANK,WORLD,LOCAL_RANK,NCCL_ID} env),
mints the ncclUniqueId once in the launcher and hex-injects it via env (no
shared FS/TCP, race-free). worker_env/run_worker read the env, bind the device
(own CUDA context), DdpContext::init + build_model + train_rank reused from T8
UNCHANGED. hex_encode/decode_unique_id are host-testable pure fns.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-18 17:48:43 +08:00
parent c470c627a7
commit ffd548b80b
2 changed files with 205 additions and 0 deletions

View File

@@ -18,8 +18,13 @@
pub mod ddp;
pub mod ffi;
pub mod proc;
pub use ddp::{DdpConfig, DdpResult, build_model, launch, train_rank};
pub use proc::{
ModelOpts, WorkerEnv, build_worker_model, hex_decode_unique_id, hex_encode_unique_id,
launch_processes, run_worker, worker_env,
};
use std::ffi::c_void;

View File

@@ -0,0 +1,200 @@
//! Process-per-GPU DDP launcher + worker (Phase T17, torchrun-style).
//!
//! T8's DDP is single-process, thread-per-GPU: N rank threads share ONE CUDA
//! primary context, so much of the driver work (kernel launch, cuBLAS handle,
//! stream queueing) serializes at the context level — the residual ~5×@8
//! non-linearity left after T11's allocator fix (see docs/10 / KI-5).
//!
//! Process-per-GPU gives each rank its OWN OS process and OWN CUDA context, so
//! those driver calls no longer queue in a shared context. Only the LAUNCH model
//! and the cross-process NCCL bootstrap change; the training step
//! (`train_rank` → grad all-reduce → local AdamW) and the consistency argument
//! are reused from T8 UNCHANGED.
//!
//! UniqueId rendezvous: the LAUNCHER (the common parent of every worker) mints
//! the `ncclUniqueId` once, hex-encodes it, and injects it into each worker's env
//! at spawn time. No shared file / TCP server / polling — the id is atomically
//! present before the child exists, so there is no "id not ready yet" race. This
//! is the simplest single-node mechanism (see docs/16).
use std::path::PathBuf;
use std::process::{Command, Stdio};
use xtrain_model::{Config, TinyTransformer};
use xtrain_tensor::{DType, Device};
use xtrain_train::data::Corpus;
use crate::ddp::{DdpConfig, DdpResult, build_model, train_rank};
use crate::ffi::NcclUniqueId;
use crate::{DdpContext, get_unique_id};
// Env keys the launcher sets on every spawned worker (torchrun-style: a worker
// detects its role by the presence of `XTRAIN_RANK`).
pub const ENV_RANK: &str = "XTRAIN_RANK";
pub const ENV_WORLD: &str = "XTRAIN_WORLD";
pub const ENV_LOCAL_RANK: &str = "XTRAIN_LOCAL_RANK";
pub const ENV_NCCL_ID: &str = "XTRAIN_NCCL_ID";
/// Hex-encode the 128-byte `ncclUniqueId` for env transport (128 B → 256 chars,
/// well under any env-var length limit). `c_char` is signed on this target, so
/// reinterpret the bytes as `u8` first.
pub fn hex_encode_unique_id(id: &NcclUniqueId) -> String {
let mut s = String::with_capacity(256);
for &b in &id.internal {
s.push_str(&format!("{:02x}", b as u8));
}
s
}
/// Inverse of [`hex_encode_unique_id`]: parse 256 hex chars back into the
/// 128-byte opaque blob. Panics on malformed input (the launcher always writes a
/// well-formed value, so a bad value means a corrupted env).
pub fn hex_decode_unique_id(hex: &str) -> NcclUniqueId {
assert_eq!(
hex.len(),
256,
"NCCL id hex must be 256 chars, got {}",
hex.len()
);
let mut id = NcclUniqueId::default();
for (i, slot) in id.internal.iter_mut().enumerate() {
let byte = u8::from_str_radix(&hex[i * 2..i * 2 + 2], 16).expect("NCCL id hex byte parse");
*slot = byte as std::os::raw::c_char;
}
id
}
/// Spawn `world` worker processes (re-exec of the current binary with the same
/// argv), each pinned to one GPU via `XTRAIN_LOCAL_RANK`, and wait for all of
/// them. The launcher mints the `ncclUniqueId` and injects it (hex) into every
/// worker's env, so the cross-process NCCL bootstrap needs no shared file/TCP.
///
/// Returns `Ok(())` iff every worker exits 0; otherwise an error naming the first
/// failing rank (so the caller — `main` / a test — can propagate a non-zero exit).
/// `extra_args` is forwarded to each worker verbatim (so all training hyper-params
/// pass straight through); the workers inherit the launcher's env (incl.
/// `CUDA_VISIBLE_DEVICES`) plus the four `XTRAIN_*` keys.
pub fn launch_processes(world: usize, extra_args: &[String]) -> Result<(), String> {
let exe = std::env::current_exe().map_err(|e| format!("current_exe: {e}"))?;
let id = get_unique_id();
let id_hex = hex_encode_unique_id(&id);
let mut children = Vec::with_capacity(world);
for rank in 0..world {
let child = Command::new(&exe)
.args(extra_args)
.env(ENV_RANK, rank.to_string())
.env(ENV_WORLD, world.to_string())
// Single node: local rank == global rank == device ordinal within the
// visible set. (Multi-node would split these; see docs/16 follow-up.)
.env(ENV_LOCAL_RANK, rank.to_string())
.env(ENV_NCCL_ID, &id_hex)
// Workers inherit stdout/stderr so rank 0's training log surfaces.
.stdout(Stdio::inherit())
.stderr(Stdio::inherit())
.spawn()
.map_err(|e| format!("spawn worker rank {rank}: {e}"))?;
children.push((rank, child));
}
let mut first_err: Option<String> = None;
for (rank, mut child) in children {
let status = child
.wait()
.map_err(|e| format!("wait worker rank {rank}: {e}"))?;
if !status.success() && first_err.is_none() {
first_err = Some(format!("worker rank {rank} exited with {status}"));
}
}
match first_err {
Some(e) => Err(e),
None => Ok(()),
}
}
/// The four `XTRAIN_*` values a worker reads from its env. Present iff this
/// process was spawned by [`launch_processes`].
pub struct WorkerEnv {
pub rank: usize,
pub world: usize,
pub local_rank: u32,
pub id: NcclUniqueId,
}
/// Read the worker env if this process is a spawned worker (i.e. `XTRAIN_RANK`
/// is set), else `None` (this process is the launcher).
pub fn worker_env() -> Option<WorkerEnv> {
let rank: usize = std::env::var(ENV_RANK).ok()?.parse().ok()?;
let world: usize = std::env::var(ENV_WORLD)
.expect("XTRAIN_WORLD set with XTRAIN_RANK")
.parse()
.expect("XTRAIN_WORLD parse");
let local_rank: u32 = std::env::var(ENV_LOCAL_RANK)
.expect("XTRAIN_LOCAL_RANK set with XTRAIN_RANK")
.parse()
.expect("XTRAIN_LOCAL_RANK parse");
let id_hex = std::env::var(ENV_NCCL_ID).expect("XTRAIN_NCCL_ID set with XTRAIN_RANK");
let id = hex_decode_unique_id(&id_hex);
Some(WorkerEnv {
rank,
world,
local_rank,
id,
})
}
/// Per-worker model construction knobs (the opt-in feature flags the launcher
/// forwards). Mirrors the closure `train_ddp` passes to the thread-per-GPU
/// `launch`, but here it runs once in this worker's own process/context.
#[derive(Clone, Copy, Default)]
pub struct ModelOpts {
pub bf16: bool,
pub recompute: bool,
pub flash: bool,
}
/// Run this worker: bind its GPU (→ its own CUDA context), init NCCL with the
/// launcher-supplied id, build its model with the deterministic init (same as
/// every rank + the single-GPU baseline), and run `train_rank`. Reuses the T8
/// training step verbatim — the only difference from thread-per-GPU is how this
/// rank was started and how it got the `UniqueId`.
///
/// `valid` is the held-out corpus for rank 0's periodic eval (pass `None` on
/// other ranks or when `cfg.eval_every == 0`).
pub fn run_worker(
env: &WorkerEnv,
cfg: Config,
opts: ModelOpts,
corpus: &Corpus,
valid: Option<&Corpus>,
dcfg: &DdpConfig,
) -> DdpResult {
// Binding the device here establishes this process's own CUDA primary context.
let ctx = DdpContext::init(env.rank, env.world, env.id, env.local_rank);
let device = Device::Cuda(env.local_rank);
let model = build_worker_model(cfg, opts, device);
let v = if env.rank == 0 { valid } else { None };
train_rank(&ctx, &model, device, corpus, v, dcfg)
}
/// Build the worker's model with the deterministic `build_model` init + the
/// opt-in feature flags. Shared by `run_worker` and the test worker.
pub fn build_worker_model(cfg: Config, opts: ModelOpts, device: Device) -> TinyTransformer {
let mut m = build_model(cfg, device);
if opts.bf16 {
m = m.with_compute_dtype(DType::BF16);
}
if opts.recompute {
m = m.with_recompute(true);
}
if opts.flash {
m = m.with_flash(true);
}
m
}
/// Convenience: the directory tests/bins can stash per-rank result dumps in
/// (a worker writes its loss/params there; the launching test reads them back).
pub fn rank_dump_path(dir: &std::path::Path, rank: usize) -> PathBuf {
dir.join(format!("rank{rank}.dump"))
}