From e27df50ca9337823b683c7db0a16dfe46bab7a99 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Mon, 15 Jun 2026 17:14:56 +0800 Subject: [PATCH] dist: nccl ffi + comm bootstrap MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New crate xtrain-distributed (mirrors xserv-distributed): hand-written NCCL FFI (GetUniqueId / CommInitRank / AllReduce / CommDestroy / Group{Start,End}, ncclUniqueId passed by value per the NCCL ABI) and a safe DdpContext wrapper — rank 0 mints the UniqueId, every rank inits its communicator under a group, and all_reduce_average_grads in-place AllReduce(sum)s each param's .grad() device buffer then scales by 1/world (reuses T7's scale_inplace kernel). AllReduce runs on the null stream so it orders with the model's kernels (no extra barrier). build.rs follows the per-crate convention: no nvcc -> no_cuda cfg (crate compiles to empty, cargo check passes host-side); with nvcc, links -lnccl -lcudart like xserv-distributed's build.rs. 
Co-Authored-By: Claude Opus 4.8 --- Cargo.lock | 9 ++ Cargo.toml | 1 + crates/xtrain-distributed/Cargo.toml | 10 ++ crates/xtrain-distributed/build.rs | 33 ++++++ crates/xtrain-distributed/src/ffi.rs | 76 ++++++++++++++ crates/xtrain-distributed/src/lib.rs | 150 +++++++++++++++++++++++++++ 6 files changed, 279 insertions(+) create mode 100644 crates/xtrain-distributed/Cargo.toml create mode 100644 crates/xtrain-distributed/build.rs create mode 100644 crates/xtrain-distributed/src/ffi.rs create mode 100644 crates/xtrain-distributed/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index a12efd4..5887d45 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -205,6 +205,15 @@ dependencies = [ "cc", ] +[[package]] +name = "xtrain-distributed" +version = "0.1.0" +dependencies = [ + "xtrain-autodiff", + "xtrain-cuda", + "xtrain-tensor", +] + [[package]] name = "xtrain-model" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index fc2329c..8c95f13 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "crates/xtrain-model", "crates/xtrain-optim", "crates/xtrain-train", + "crates/xtrain-distributed", ] [workspace.package] diff --git a/crates/xtrain-distributed/Cargo.toml b/crates/xtrain-distributed/Cargo.toml new file mode 100644 index 0000000..4d55dd2 --- /dev/null +++ b/crates/xtrain-distributed/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "xtrain-distributed" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +xtrain-cuda = { path = "../xtrain-cuda" } +xtrain-tensor = { path = "../xtrain-tensor" } +xtrain-autodiff = { path = "../xtrain-autodiff" } diff --git a/crates/xtrain-distributed/build.rs b/crates/xtrain-distributed/build.rs new file mode 100644 index 0000000..ea4c4f2 --- /dev/null +++ b/crates/xtrain-distributed/build.rs @@ -0,0 +1,33 @@ +use std::env; +use std::path::Path; +use std::process::Command; + +// Mirror the per-crate convention (see xtrain-cuda/build.rs): with no nvcc/GPU +// locally, emit `no_cuda` 
so the crate compiles to empty (lib.rs is cfg-gated; +// nothing is linked or run). On dash5, link NCCL exactly like xserv-distributed's build.rs. +fn main() { + println!("cargo:rustc-check-cfg=cfg(no_cuda)"); + + let cuda_path = env::var("CUDA_HOME") + .or_else(|_| env::var("CUDA_PATH")) + .unwrap_or_else(|_| "/usr/local/cuda".to_string()); + + if !nvcc_available(&cuda_path) { + println!("cargo:warning=nvcc not found — skipping NCCL link (host-only build)."); + println!("cargo:rustc-cfg=no_cuda"); + return; + } + + println!("cargo:rustc-link-search=native={cuda_path}/lib64"); + // NCCL is installed as a system library on dash5. + println!("cargo:rustc-link-search=native=/usr/lib/x86_64-linux-gnu"); + println!("cargo:rustc-link-lib=dylib=nccl"); + println!("cargo:rustc-link-lib=dylib=cudart"); +} + +fn nvcc_available(cuda_path: &str) -> bool { + if Command::new("nvcc").arg("--version").output().is_ok() { + return true; + } + Path::new(&format!("{cuda_path}/bin/nvcc")).exists() +} diff --git a/crates/xtrain-distributed/src/ffi.rs b/crates/xtrain-distributed/src/ffi.rs new file mode 100644 index 0000000..4a8ae04 --- /dev/null +++ b/crates/xtrain-distributed/src/ffi.rs @@ -0,0 +1,76 @@ +//! Minimal NCCL FFI bindings (hand-written, like the CUDA bindings in +//! xtrain-cuda). Only the collectives data-parallel training needs: +//! unique-id creation, communicator init/destroy, and AllReduce. Mirrors +//! xserv-distributed's FFI. + +use std::ffi::c_void; +use std::os::raw::c_char; +use xtrain_cuda::ffi::CudaStream; + +/// Opaque NCCL communicator handle (`ncclComm_t`). +pub type NcclComm = *mut c_void; + +/// `ncclUniqueId` is a 128-byte opaque blob shared from rank 0 to every rank. +#[repr(C)] +#[derive(Clone, Copy)] +pub struct NcclUniqueId { + pub internal: [c_char; 128], +} + +impl Default for NcclUniqueId { + fn default() -> Self { + Self { internal: [0; 128] } + } +} + +// ncclDataType_t (subset) — DDP all-reduces fp32 gradients. 
+pub const NCCL_FLOAT32: i32 = 7; + +// ncclRedOp_t +pub const NCCL_SUM: i32 = 0; + +// ncclResult_t +pub const NCCL_SUCCESS: i32 = 0; + +unsafe extern "C" { + pub fn ncclGetUniqueId(uid: *mut NcclUniqueId) -> i32; + // ncclUniqueId is passed BY VALUE (a 128-byte struct) per the NCCL ABI. + pub fn ncclCommInitRank( + comm: *mut NcclComm, + nranks: i32, + commid: NcclUniqueId, + rank: i32, + ) -> i32; + pub fn ncclCommDestroy(comm: NcclComm) -> i32; + pub fn ncclAllReduce( + sendbuff: *const c_void, + recvbuff: *mut c_void, + count: usize, + datatype: i32, + op: i32, + comm: NcclComm, + stream: CudaStream, + ) -> i32; + pub fn ncclGroupStart() -> i32; + pub fn ncclGroupEnd() -> i32; + pub fn ncclGetErrorString(result: i32) -> *const c_char; +} + +pub fn err_string(result: i32) -> String { + unsafe { + let p = ncclGetErrorString(result); + if p.is_null() { + return format!("nccl error {result}"); + } + std::ffi::CStr::from_ptr(p).to_string_lossy().into_owned() + } +} + +pub fn check(result: i32, what: &str) { + assert_eq!( + result, + NCCL_SUCCESS, + "{what} failed: {}", + err_string(result) + ); +} diff --git a/crates/xtrain-distributed/src/lib.rs b/crates/xtrain-distributed/src/lib.rs new file mode 100644 index 0000000..e16e59b --- /dev/null +++ b/crates/xtrain-distributed/src/lib.rs @@ -0,0 +1,150 @@ +//! Distributed data-parallel (DDP) primitives for xtrain (Phase T8). +//! +//! Launch model: **one OS thread per GPU** (same as xserv-distributed). Each +//! rank thread binds its device, builds its own model (xtrain's `Var` graph is +//! `Rc`-based and not `Send`, so it must be constructed thread-locally — only the +//! `UniqueId` and scalar config cross the thread boundary), processes a disjoint +//! shard of the global batch, then AllReduces every parameter's `.grad()` device +//! buffer in place, averages by world size, and runs its own `GpuAdamW.step`. +//! Identical init + identical optimizer state across ranks keeps the parameters +//! 
consistent without ever re-syncing the weights. +//! +//! NCCL is issued on the legacy null stream — every xtrain kernel launches on the +//! null stream (`std::ptr::null_mut()`), so the AllReduce stays correctly ordered +//! after the producing backward kernels and before the consuming optimizer step, +//! with no extra synchronization. + +#![cfg(not(no_cuda))] + +pub mod ffi; + +use std::ffi::c_void; + +use ffi::{NcclComm, NcclUniqueId}; +use xtrain_autodiff::tape::Var; +use xtrain_cuda::device; + +pub use ffi::NcclUniqueId as UniqueId; + +/// Generate a unique id on one rank (rank 0) and share the raw bytes to every +/// other rank out-of-band — across threads it is just a `Copy` struct moved into +/// each rank closure; across processes it would be written to a file/env. +pub fn get_unique_id() -> NcclUniqueId { + let mut id = NcclUniqueId::default(); + ffi::check(unsafe { ffi::ncclGetUniqueId(&mut id) }, "ncclGetUniqueId"); + id +} + +/// Per-rank data-parallel context: the NCCL communicator plus this rank's +/// identity. AllReduce is in-place on the null stream. +pub struct DdpContext { + pub rank: usize, + pub world: usize, + pub device: u32, + comm: NcclComm, +} + +// The communicator is owned by exactly one rank thread. +unsafe impl Send for DdpContext {} + +impl DdpContext { + /// Initialize this rank. Must run on the thread that will own this rank's GPU + /// work; binds the thread to `device` first. All ranks call this concurrently + /// with the same `id` and `world` — the group wrapper lets the concurrent + /// inits rendezvous without deadlock. 
+ pub fn init(rank: usize, world: usize, id: NcclUniqueId, device: u32) -> Self { + device::set_device(device).expect("set_device"); + let mut comm: NcclComm = std::ptr::null_mut(); + ffi::check(unsafe { ffi::ncclGroupStart() }, "ncclGroupStart(init)"); + ffi::check( + unsafe { ffi::ncclCommInitRank(&mut comm, world as i32, id, rank as i32) }, + "ncclCommInitRank", + ); + ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)"); + Self { + rank, + world, + device, + comm, + } + } + + /// In-place AllReduce(sum) over `count` F32 elements at a raw device pointer, + /// issued on the null stream (so it orders with this rank's kernels). The + /// reduction is asynchronous: later null-stream kernels are ordered after it, + /// and a host-side sync (the caller's) waits for it to finish. + /// + /// # Safety + /// `ptr` must point to at least `count` valid F32 device elements on this rank's + /// device (in-place: send == recv). NOTE(review): not an `unsafe fn`; caller-enforced. + pub fn all_reduce_sum_f32_ptr(&self, ptr: *mut c_void, count: usize) { + if self.world == 1 { + return; // nothing to reduce + } + ffi::check( + unsafe { + ffi::ncclAllReduce( + ptr as *const c_void, + ptr, + count, + ffi::NCCL_FLOAT32, + ffi::NCCL_SUM, + self.comm, + std::ptr::null_mut(), + ) + }, + "ncclAllReduce", + ); + } + + /// AllReduce every parameter's `.grad()` across ranks and divide by `world`, + /// the one collective DDP needs per step. + /// + /// Each rank ran forward+backward on its own shard of `b` sequences, so + /// `.grad()` holds the SUM over that shard (the tape's fan-out rule). After + /// `AllReduce(sum)` every rank holds `Σ_global` (the sum over all `world·b` + /// sequences); dividing by `world` leaves `Σ_global / world`. The DDP train + /// loop's clip pass then applies the remaining `1/b` (`pre_scale = 1/b_local`), + /// giving `Σ_global / (world·b) = Σ_global / B_global` — the same mean gradient + /// (up to floating-point rounding) the single-GPU loop computes from a batch of `B_global`. + /// Params without a grad are skipped. 
+ /// + /// A single-process group barrier is unnecessary: the all-reduces serialize + /// on the comm, and the in-place scale runs on the same null stream after. + pub fn all_reduce_average_grads(&self, params: &[Var]) { + if self.world == 1 { + return; + } + // 1. Sum every grad across ranks (in place, on the null stream). + for p in params { + if let Some(g) = p.grad() { + let n = g.numel(); + self.all_reduce_sum_f32_ptr(g.data_ptr() as *mut c_void, n); + } + } + // 2. Average: scale each summed grad by 1/world (null-stream kernel, ordered after + // the AllReduce that produced it). The i32 count is checked, never truncated. + let inv_world = 1.0 / self.world as f32; + for p in params { + if let Some(g) = p.grad() { + unsafe { + xtrain_cuda::ffi::launch_scale_inplace_f32( + g.data_ptr() as *mut f32, + inv_world, + i32::try_from(g.numel()).expect("grad numel exceeds i32::MAX"), + std::ptr::null_mut(), + ); + } + } + } + device::synchronize().expect("grad all-reduce sync failed"); + } +} + +impl Drop for DdpContext { + fn drop(&mut self) { + if !self.comm.is_null() { + unsafe { ffi::ncclCommDestroy(self.comm) }; + } + } +}