distributed: NCCL tensor-parallel primitives (TpContext + AllReduce)
New xserv-distributed crate: hand-written NCCL FFI, TpContext (one rank per thread, bound to one GPU), and in-place BF16 AllReduce on the null stream so it orders naturally with the model's kernels. 2-GPU AllReduce test included. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -7,6 +7,7 @@ members = [
|
|||||||
"crates/xserv-model",
|
"crates/xserv-model",
|
||||||
"crates/xserv-tokenizer",
|
"crates/xserv-tokenizer",
|
||||||
"crates/xserv-server",
|
"crates/xserv-server",
|
||||||
|
"crates/xserv-distributed",
|
||||||
]
|
]
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
|
|||||||
8
crates/xserv-distributed/Cargo.toml
Normal file
8
crates/xserv-distributed/Cargo.toml
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
[package]
|
||||||
|
name = "xserv-distributed"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
xserv-cuda = { path = "../xserv-cuda" }
|
||||||
|
half.workspace = true
|
||||||
13
crates/xserv-distributed/build.rs
Normal file
13
crates/xserv-distributed/build.rs
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
use std::env;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let cuda_path = env::var("CUDA_HOME")
|
||||||
|
.or_else(|_| env::var("CUDA_PATH"))
|
||||||
|
.unwrap_or_else(|_| "/usr/local/cuda".to_string());
|
||||||
|
|
||||||
|
println!("cargo:rustc-link-search=native={cuda_path}/lib64");
|
||||||
|
// NCCL is typically installed as a system library.
|
||||||
|
println!("cargo:rustc-link-search=native=/usr/lib/x86_64-linux-gnu");
|
||||||
|
println!("cargo:rustc-link-lib=dylib=nccl");
|
||||||
|
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||||
|
}
|
||||||
65
crates/xserv-distributed/src/ffi.rs
Normal file
65
crates/xserv-distributed/src/ffi.rs
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
//! Minimal NCCL FFI bindings (hand-written, like the CUDA bindings).
|
||||||
|
//! Only the collectives we need for tensor parallelism.
|
||||||
|
|
||||||
|
use std::ffi::c_void;
|
||||||
|
use std::os::raw::c_char;
|
||||||
|
use xserv_cuda::ffi::CudaStream;
|
||||||
|
|
||||||
|
/// Opaque NCCL communicator handle (`ncclComm_t`).
|
||||||
|
pub type NcclComm = *mut c_void;
|
||||||
|
|
||||||
|
/// `ncclUniqueId` is a 128-byte opaque blob shared from rank 0 to all ranks.
|
||||||
|
#[repr(C)]
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct NcclUniqueId {
|
||||||
|
pub internal: [c_char; 128],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for NcclUniqueId {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self { internal: [0; 128] }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ncclDataType_t (subset)
|
||||||
|
pub const NCCL_FLOAT32: i32 = 7;
|
||||||
|
pub const NCCL_BF16: i32 = 9;
|
||||||
|
|
||||||
|
// ncclRedOp_t
|
||||||
|
pub const NCCL_SUM: i32 = 0;
|
||||||
|
|
||||||
|
// ncclResult_t
|
||||||
|
pub const NCCL_SUCCESS: i32 = 0;
|
||||||
|
|
||||||
|
unsafe extern "C" {
|
||||||
|
pub fn ncclGetUniqueId(uid: *mut NcclUniqueId) -> i32;
|
||||||
|
// ncclUniqueId is passed BY VALUE (a 128-byte struct) per the NCCL ABI.
|
||||||
|
pub fn ncclCommInitRank(comm: *mut NcclComm, nranks: i32, commid: NcclUniqueId, rank: i32) -> i32;
|
||||||
|
pub fn ncclCommDestroy(comm: NcclComm) -> i32;
|
||||||
|
pub fn ncclAllReduce(
|
||||||
|
sendbuff: *const c_void,
|
||||||
|
recvbuff: *mut c_void,
|
||||||
|
count: usize,
|
||||||
|
datatype: i32,
|
||||||
|
op: i32,
|
||||||
|
comm: NcclComm,
|
||||||
|
stream: CudaStream,
|
||||||
|
) -> i32;
|
||||||
|
pub fn ncclGroupStart() -> i32;
|
||||||
|
pub fn ncclGroupEnd() -> i32;
|
||||||
|
pub fn ncclGetErrorString(result: i32) -> *const c_char;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn err_string(result: i32) -> String {
|
||||||
|
unsafe {
|
||||||
|
let p = ncclGetErrorString(result);
|
||||||
|
if p.is_null() {
|
||||||
|
return format!("nccl error {result}");
|
||||||
|
}
|
||||||
|
std::ffi::CStr::from_ptr(p).to_string_lossy().into_owned()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check(result: i32, what: &str) {
|
||||||
|
assert_eq!(result, NCCL_SUCCESS, "{what} failed: {}", err_string(result));
|
||||||
|
}
|
||||||
97
crates/xserv-distributed/src/lib.rs
Normal file
97
crates/xserv-distributed/src/lib.rs
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
//! Tensor-parallel primitives for xserv.
|
||||||
|
//!
|
||||||
|
//! Process model: one OS thread per TP rank, each bound to one GPU. NCCL is
|
||||||
|
//! used for the collective (AllReduce); a hand-rolled P2P AllReduce may replace
|
||||||
|
//! it later as a learning exercise (see docs/17-tensor-parallelism.md).
|
||||||
|
|
||||||
|
pub mod ffi;
|
||||||
|
|
||||||
|
use std::ffi::c_void;
|
||||||
|
|
||||||
|
use ffi::{NcclComm, NcclUniqueId};
|
||||||
|
use xserv_cuda::device;
|
||||||
|
use xserv_cuda::GpuBuffer;
|
||||||
|
|
||||||
|
pub use ffi::NcclUniqueId as UniqueId;
|
||||||
|
|
||||||
|
/// The CUDA "null" (default) stream. The model's kernels and cuBLAS calls run
|
||||||
|
/// on it, so issuing NCCL on the same stream keeps AllReduce correctly ordered
|
||||||
|
/// after the producing matmul and before the consuming kernel — no extra sync.
|
||||||
|
const NULL_STREAM: xserv_cuda::ffi::CudaStream = std::ptr::null_mut();
|
||||||
|
|
||||||
|
/// Generate a unique id on one rank (typically rank 0) and broadcast the bytes
|
||||||
|
/// to all ranks out-of-band (e.g. via a shared variable across threads).
|
||||||
|
pub fn get_unique_id() -> NcclUniqueId {
|
||||||
|
let mut id = NcclUniqueId::default();
|
||||||
|
ffi::check(unsafe { ffi::ncclGetUniqueId(&mut id) }, "ncclGetUniqueId");
|
||||||
|
id
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Per-rank tensor-parallel context: NCCL communicator + a dedicated stream.
|
||||||
|
pub struct TpContext {
|
||||||
|
pub rank: usize,
|
||||||
|
pub world: usize,
|
||||||
|
pub device: u32,
|
||||||
|
comm: NcclComm,
|
||||||
|
}
|
||||||
|
|
||||||
|
// The NCCL communicator is owned by exactly one rank thread.
|
||||||
|
unsafe impl Send for TpContext {}
|
||||||
|
|
||||||
|
impl TpContext {
|
||||||
|
/// Initialize this rank. Must be called from the thread that will own this
|
||||||
|
/// rank's GPU work; binds the thread to `device` first. All ranks must call
|
||||||
|
/// this concurrently with the same `id` and `world`.
|
||||||
|
pub fn init(rank: usize, world: usize, id: NcclUniqueId, device: u32) -> Self {
|
||||||
|
device::set_device(device).expect("set_device");
|
||||||
|
let mut comm: NcclComm = std::ptr::null_mut();
|
||||||
|
// Wrap the concurrent inits in a group so they rendezvous without deadlock.
|
||||||
|
ffi::check(unsafe { ffi::ncclGroupStart() }, "ncclGroupStart(init)");
|
||||||
|
ffi::check(
|
||||||
|
unsafe { ffi::ncclCommInitRank(&mut comm, world as i32, id, rank as i32) },
|
||||||
|
"ncclCommInitRank",
|
||||||
|
);
|
||||||
|
ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)");
|
||||||
|
Self { rank, world, device, comm }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// In-place AllReduce(sum) over `count` BF16 elements in `buf`.
|
||||||
|
pub fn all_reduce_sum_bf16(&self, buf: &mut GpuBuffer, count: usize) {
|
||||||
|
self.all_reduce_sum_bf16_ptr(buf.as_mut_ptr() as *mut c_void, count);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// In-place AllReduce(sum) directly on a device pointer (`count` BF16 elems),
|
||||||
|
/// issued on the null stream so it is ordered with the model's kernels.
|
||||||
|
/// Asynchronous: a later sync (e.g. the D2H logits copy) completes it.
|
||||||
|
///
|
||||||
|
/// # Safety
|
||||||
|
/// `ptr` must point to at least `count` BF16 elements of valid device memory
|
||||||
|
/// on this rank's device. The reduction is in-place (send == recv).
|
||||||
|
pub fn all_reduce_sum_bf16_ptr(&self, ptr: *mut c_void, count: usize) {
|
||||||
|
if self.world == 1 {
|
||||||
|
return; // nothing to reduce
|
||||||
|
}
|
||||||
|
ffi::check(
|
||||||
|
unsafe {
|
||||||
|
ffi::ncclAllReduce(
|
||||||
|
ptr as *const c_void,
|
||||||
|
ptr,
|
||||||
|
count,
|
||||||
|
ffi::NCCL_BF16,
|
||||||
|
ffi::NCCL_SUM,
|
||||||
|
self.comm,
|
||||||
|
NULL_STREAM,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
"ncclAllReduce",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for TpContext {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if !self.comm.is_null() {
|
||||||
|
unsafe { ffi::ncclCommDestroy(self.comm) };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
50
crates/xserv-distributed/tests/allreduce.rs
Normal file
50
crates/xserv-distributed/tests/allreduce.rs
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
//! 2-GPU AllReduce smoke test. Skips if fewer than 2 GPUs are present.
|
||||||
|
|
||||||
|
use half::bf16;
|
||||||
|
use std::thread;
|
||||||
|
use xserv_cuda::{device, GpuBuffer};
|
||||||
|
use xserv_distributed::{get_unique_id, TpContext};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn allreduce_two_gpu_sum() {
|
||||||
|
let world = 2usize;
|
||||||
|
if device::device_count().unwrap_or(0) < world as i32 {
|
||||||
|
eprintln!("skip: need >= {world} GPUs");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let id = get_unique_id();
|
||||||
|
let n = 4096usize;
|
||||||
|
|
||||||
|
let handles: Vec<_> = (0..world)
|
||||||
|
.map(|rank| {
|
||||||
|
let id = id;
|
||||||
|
thread::spawn(move || {
|
||||||
|
let ctx = TpContext::init(rank, world, id, rank as u32);
|
||||||
|
|
||||||
|
// Rank r fills its buffer with (r + 1).
|
||||||
|
let val = bf16::from_f32((rank + 1) as f32);
|
||||||
|
let host = vec![val; n];
|
||||||
|
let src = unsafe {
|
||||||
|
std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2)
|
||||||
|
};
|
||||||
|
let mut buf = GpuBuffer::alloc(n * 2).unwrap();
|
||||||
|
buf.copy_from_host(src).unwrap();
|
||||||
|
|
||||||
|
ctx.all_reduce_sum_bf16(&mut buf, n);
|
||||||
|
|
||||||
|
let mut out = vec![0u8; n * 2];
|
||||||
|
buf.copy_to_host(&mut out).unwrap();
|
||||||
|
let res = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const bf16, n) };
|
||||||
|
(res[0].to_f32(), res[n - 1].to_f32())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// sum over ranks of (r+1) = 1 + 2 = 3
|
||||||
|
for h in handles {
|
||||||
|
let (first, last) = h.join().unwrap();
|
||||||
|
assert_eq!(first, 3.0, "AllReduce(sum) first element");
|
||||||
|
assert_eq!(last, 3.0, "AllReduce(sum) last element");
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user