diff --git a/Cargo.toml b/Cargo.toml
index c8eb939..e2145f7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -7,6 +7,7 @@ members = [
     "crates/xserv-model",
     "crates/xserv-tokenizer",
     "crates/xserv-server",
+    "crates/xserv-distributed",
 ]
 
 [workspace.package]
diff --git a/crates/xserv-distributed/Cargo.toml b/crates/xserv-distributed/Cargo.toml
new file mode 100644
index 0000000..0932443
--- /dev/null
+++ b/crates/xserv-distributed/Cargo.toml
@@ -0,0 +1,8 @@
+[package]
+name = "xserv-distributed"
+version.workspace = true
+edition.workspace = true
+
+[dependencies]
+xserv-cuda = { path = "../xserv-cuda" }
+half.workspace = true
diff --git a/crates/xserv-distributed/build.rs b/crates/xserv-distributed/build.rs
new file mode 100644
index 0000000..d4f3b53
--- /dev/null
+++ b/crates/xserv-distributed/build.rs
@@ -0,0 +1,22 @@
+//! Build script: tells Cargo where to find and how to link NCCL and the CUDA runtime.
+
+use std::env;
+
+fn main() {
+    // Re-run only when the script or the CUDA location variables change;
+    // without these directives Cargo re-runs the script on any package change.
+    println!("cargo:rerun-if-changed=build.rs");
+    println!("cargo:rerun-if-env-changed=CUDA_HOME");
+    println!("cargo:rerun-if-env-changed=CUDA_PATH");
+
+    // CUDA toolkit root: CUDA_HOME, then CUDA_PATH, then the conventional default.
+    let cuda_path = env::var("CUDA_HOME")
+        .or_else(|_| env::var("CUDA_PATH"))
+        .unwrap_or_else(|_| "/usr/local/cuda".to_string());
+
+    println!("cargo:rustc-link-search=native={cuda_path}/lib64");
+    // NCCL is typically installed as a system library.
+    println!("cargo:rustc-link-search=native=/usr/lib/x86_64-linux-gnu");
+    println!("cargo:rustc-link-lib=dylib=nccl");
+    println!("cargo:rustc-link-lib=dylib=cudart");
+}
diff --git a/crates/xserv-distributed/src/ffi.rs b/crates/xserv-distributed/src/ffi.rs
new file mode 100644
index 0000000..2b713b4
--- /dev/null
+++ b/crates/xserv-distributed/src/ffi.rs
@@ -0,0 +1,76 @@
+//! Minimal NCCL FFI bindings (hand-written, like the CUDA bindings).
+//! Only the collectives we need for tensor parallelism.
+
+use std::ffi::c_void;
+use std::os::raw::c_char;
+use xserv_cuda::ffi::CudaStream;
+
+/// Opaque NCCL communicator handle (`ncclComm_t`).
+pub type NcclComm = *mut c_void;
+
+/// `ncclUniqueId` is a 128-byte opaque blob shared from rank 0 to all ranks.
+#[repr(C)]
+#[derive(Clone, Copy)]
+pub struct NcclUniqueId {
+    pub internal: [c_char; 128],
+}
+
+impl Default for NcclUniqueId {
+    /// All-zero placeholder, meant to be overwritten by `ncclGetUniqueId`.
+    fn default() -> Self {
+        Self { internal: [0; 128] }
+    }
+}
+
+// ncclDataType_t (subset)
+pub const NCCL_FLOAT32: i32 = 7;
+pub const NCCL_BF16: i32 = 9;
+
+// ncclRedOp_t
+pub const NCCL_SUM: i32 = 0;
+
+// ncclResult_t
+pub const NCCL_SUCCESS: i32 = 0;
+
+unsafe extern "C" {
+    pub fn ncclGetUniqueId(uid: *mut NcclUniqueId) -> i32;
+    // ncclUniqueId is passed BY VALUE (a 128-byte struct) per the NCCL ABI.
+    pub fn ncclCommInitRank(comm: *mut NcclComm, nranks: i32, commid: NcclUniqueId, rank: i32) -> i32;
+    pub fn ncclCommDestroy(comm: NcclComm) -> i32;
+    pub fn ncclAllReduce(
+        sendbuff: *const c_void,
+        recvbuff: *mut c_void,
+        count: usize,
+        datatype: i32,
+        op: i32,
+        comm: NcclComm,
+        stream: CudaStream,
+    ) -> i32;
+    pub fn ncclGroupStart() -> i32;
+    pub fn ncclGroupEnd() -> i32;
+    pub fn ncclGetErrorString(result: i32) -> *const c_char;
+}
+
+/// Human-readable description of an NCCL result code; falls back to
+/// `"nccl error <code>"` if NCCL returns a null pointer.
+pub fn err_string(result: i32) -> String {
+    // SAFETY: the call takes only an integer by value.
+    let p = unsafe { ncclGetErrorString(result) };
+    if p.is_null() {
+        return format!("nccl error {result}");
+    }
+    // SAFETY: `p` is non-null and NCCL returns a NUL-terminated C string; it is
+    // copied into an owned `String` before this function returns.
+    unsafe { std::ffi::CStr::from_ptr(p) }
+        .to_string_lossy()
+        .into_owned()
+}
+
+/// Assert that an NCCL call succeeded.
+///
+/// # Panics
+/// Panics with `"{what} failed: <NCCL error string>"` if `result` is not
+/// `NCCL_SUCCESS`.
+pub fn check(result: i32, what: &str) {
+    assert_eq!(result, NCCL_SUCCESS, "{what} failed: {}", err_string(result));
+}
diff --git a/crates/xserv-distributed/src/lib.rs b/crates/xserv-distributed/src/lib.rs
new file mode 100644
index 0000000..0feeea5
--- /dev/null
+++ b/crates/xserv-distributed/src/lib.rs
@@ -0,0 +1,132 @@
+//! Tensor-parallel primitives for xserv.
+//!
+//! Process model: one OS thread per TP rank, each bound to one GPU. NCCL is
+//! used for the collective (AllReduce); a hand-rolled P2P AllReduce may replace
+//! it later as a learning exercise (see docs/17-tensor-parallelism.md).
+
+pub mod ffi;
+
+use std::ffi::c_void;
+
+use ffi::{NcclComm, NcclUniqueId};
+use xserv_cuda::device;
+use xserv_cuda::GpuBuffer;
+
+pub use ffi::NcclUniqueId as UniqueId;
+
+/// The CUDA "null" (default) stream. The model's kernels and cuBLAS calls run
+/// on it, so issuing NCCL on the same stream keeps AllReduce correctly ordered
+/// after the producing matmul and before the consuming kernel — no extra sync.
+const NULL_STREAM: xserv_cuda::ffi::CudaStream = std::ptr::null_mut();
+
+/// Generate a unique id on one rank (typically rank 0) and broadcast the bytes
+/// to all ranks out-of-band (e.g. via a shared variable across threads).
+///
+/// # Panics
+/// Panics if `ncclGetUniqueId` reports an error.
+pub fn get_unique_id() -> NcclUniqueId {
+    let mut id = NcclUniqueId::default();
+    // SAFETY: `id` is a live, writable `NcclUniqueId` for the whole call.
+    ffi::check(unsafe { ffi::ncclGetUniqueId(&mut id) }, "ncclGetUniqueId");
+    id
+}
+
+/// Per-rank tensor-parallel context: this rank's NCCL communicator plus its
+/// rank, world size and device ordinal. Collectives are issued on the CUDA
+/// null stream (`NULL_STREAM`); the context does not own a stream of its own.
+pub struct TpContext {
+    /// Index of this rank within the communicator.
+    pub rank: usize,
+    /// Total number of ranks in the communicator.
+    pub world: usize,
+    /// CUDA device ordinal passed to `init`.
+    pub device: u32,
+    /// NCCL communicator handle, destroyed on drop.
+    comm: NcclComm,
+}
+
+// SAFETY: the NCCL communicator is owned by exactly one rank thread.
+unsafe impl Send for TpContext {}
+
+impl TpContext {
+    /// Initialize this rank. Must be called from the thread that will own this
+    /// rank's GPU work; binds the thread to `device` first. All ranks must call
+    /// this concurrently with the same `id` and `world`.
+    ///
+    /// # Panics
+    /// Panics if `rank >= world`, if `world` does not fit in an `i32` (the type
+    /// NCCL takes for rank counts), if the device cannot be selected, or if an
+    /// NCCL call reports an error.
+    pub fn init(rank: usize, world: usize, id: NcclUniqueId, device: u32) -> Self {
+        // Validate up front: an `as i32` cast would silently truncate an
+        // out-of-range value instead of reporting it.
+        assert!(rank < world, "rank {rank} out of range for world size {world}");
+        let nranks = i32::try_from(world).expect("world size must fit in i32");
+        let nccl_rank = i32::try_from(rank).expect("rank < world, so it fits in i32");
+
+        device::set_device(device).expect("set_device");
+        let mut comm: NcclComm = std::ptr::null_mut();
+        // Wrap the concurrent inits in a group so they rendezvous without deadlock.
+        ffi::check(unsafe { ffi::ncclGroupStart() }, "ncclGroupStart(init)");
+        ffi::check(
+            // SAFETY: `comm` is a live out-pointer; `id` is passed by value.
+            unsafe { ffi::ncclCommInitRank(&mut comm, nranks, id, nccl_rank) },
+            "ncclCommInitRank",
+        );
+        ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)");
+        Self { rank, world, device, comm }
+    }
+
+    /// In-place AllReduce(sum) over `count` BF16 elements in `buf`.
+    ///
+    /// NOTE(review): `count` is not checked against the size of `buf`; the
+    /// caller must ensure `buf` holds at least `count` BF16 elements.
+    pub fn all_reduce_sum_bf16(&self, buf: &mut GpuBuffer, count: usize) {
+        self.all_reduce_sum_bf16_ptr(buf.as_mut_ptr() as *mut c_void, count);
+    }
+
+    /// In-place AllReduce(sum) directly on a device pointer (`count` BF16 elems),
+    /// issued on the null stream so it is ordered with the model's kernels.
+    /// Asynchronous: a later sync (e.g. the D2H logits copy) completes it.
+    /// Returns immediately without calling NCCL when `world == 1`.
+    ///
+    /// # Safety
+    /// `ptr` must point to at least `count` BF16 elements of valid device memory
+    /// on this rank's device. The reduction is in-place (send == recv).
+    ///
+    /// NOTE(review): the function is declared safe despite this contract; it
+    /// is left that way so existing callers keep compiling.
+    ///
+    /// # Panics
+    /// Panics if `ncclAllReduce` reports an error.
+    pub fn all_reduce_sum_bf16_ptr(&self, ptr: *mut c_void, count: usize) {
+        if self.world == 1 {
+            return; // nothing to reduce
+        }
+        ffi::check(
+            unsafe {
+                ffi::ncclAllReduce(
+                    ptr as *const c_void,
+                    ptr,
+                    count,
+                    ffi::NCCL_BF16,
+                    ffi::NCCL_SUM,
+                    self.comm,
+                    NULL_STREAM,
+                )
+            },
+            "ncclAllReduce",
+        );
+    }
+}
+
+impl Drop for TpContext {
+    /// Destroy the communicator. The NCCL result is deliberately discarded:
+    /// panicking inside `drop` during unwinding would abort the process.
+    fn drop(&mut self) {
+        if !self.comm.is_null() {
+            // SAFETY: `comm` came from `ncclCommInitRank` and is destroyed only here.
+            let _ = unsafe { ffi::ncclCommDestroy(self.comm) };
+        }
+    }
+}
diff --git a/crates/xserv-distributed/tests/allreduce.rs b/crates/xserv-distributed/tests/allreduce.rs
new file mode 100644
index 0000000..636a149
--- /dev/null
+++ b/crates/xserv-distributed/tests/allreduce.rs
@@ -0,0 +1,60 @@
+//! 2-GPU AllReduce smoke test. Skips if fewer than 2 GPUs are present.
+
+use half::bf16;
+use std::thread;
+use xserv_cuda::{device, GpuBuffer};
+use xserv_distributed::{get_unique_id, TpContext};
+
+#[test]
+fn allreduce_two_gpu_sum() {
+    let world = 2usize;
+    if device::device_count().unwrap_or(0) < world as i32 {
+        eprintln!("skip: need >= {world} GPUs");
+        return;
+    }
+
+    let id = get_unique_id();
+    let n = 4096usize;
+    let bytes = n * std::mem::size_of::<bf16>();
+
+    let handles: Vec<_> = (0..world)
+        .map(|rank| {
+            let id = id;
+            thread::spawn(move || {
+                let ctx = TpContext::init(rank, world, id, rank as u32);
+
+                // Rank r fills its buffer with (r + 1).
+                let val = bf16::from_f32((rank + 1) as f32);
+                let host = vec![val; n];
+                // SAFETY: `host` owns `n` initialized bf16 values (`bytes` bytes),
+                // and `u8` has no alignment requirement.
+                let src = unsafe {
+                    std::slice::from_raw_parts(host.as_ptr() as *const u8, bytes)
+                };
+                let mut buf = GpuBuffer::alloc(bytes).unwrap();
+                buf.copy_from_host(src).unwrap();
+
+                ctx.all_reduce_sum_bf16(&mut buf, n);
+
+                // Read back into a bf16-typed vector viewed as bytes. The reverse
+                // (viewing a `Vec<u8>` as `&[bf16]`) is unsound: a byte allocation
+                // is not guaranteed to be 2-byte aligned.
+                let mut out = vec![bf16::from_f32(0.0); n];
+                // SAFETY: `out` owns `bytes` writable bytes, and every bit pattern
+                // is a valid bf16, so byte-wise writes keep it initialized.
+                let dst = unsafe {
+                    std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut u8, bytes)
+                };
+                buf.copy_to_host(dst).unwrap();
+                (out[0].to_f32(), out[n - 1].to_f32())
+            })
+        })
+        .collect();
+
+    // sum over ranks of (r+1) = 1 + 2 = 3
+    for h in handles {
+        let (first, last) = h.join().unwrap();
+        assert_eq!(first, 3.0, "AllReduce(sum) first element");
+        assert_eq!(last, 3.0, "AllReduce(sum) last element");
+    }
+}