distributed: NCCL tensor-parallel primitives (TpContext + AllReduce)

New xserv-distributed crate: hand-written NCCL FFI, TpContext (one rank per
thread, bound to one GPU), and in-place BF16 AllReduce on the null stream so
it orders naturally with the model's kernels. 2-GPU AllReduce test included.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-29 11:10:14 +08:00
parent 76fffb3b68
commit 453520d622
6 changed files with 234 additions and 0 deletions

View File

@@ -7,6 +7,7 @@ members = [
"crates/xserv-model",
"crates/xserv-tokenizer",
"crates/xserv-server",
"crates/xserv-distributed",
]
[workspace.package]

View File

@@ -0,0 +1,8 @@
[package]
name = "xserv-distributed"
version.workspace = true
edition.workspace = true
[dependencies]
xserv-cuda = { path = "../xserv-cuda" }
half.workspace = true

View File

@@ -0,0 +1,13 @@
use std::env;
/// Build script: emits the Cargo directives needed to link against the CUDA
/// runtime (`cudart`) and NCCL (`nccl`).
///
/// The CUDA toolkit root is taken from `CUDA_HOME`, then `CUDA_PATH`, and
/// falls back to `/usr/local/cuda` when neither is set.
fn main() {
    // Declare the script's inputs explicitly. Without these, changing
    // CUDA_HOME / CUDA_PATH would not re-run the script and the previously
    // emitted link search path would be reused.
    println!("cargo:rerun-if-changed=build.rs");
    println!("cargo:rerun-if-env-changed=CUDA_HOME");
    println!("cargo:rerun-if-env-changed=CUDA_PATH");

    let cuda_path = env::var("CUDA_HOME")
        .or_else(|_| env::var("CUDA_PATH"))
        .unwrap_or_else(|_| "/usr/local/cuda".to_string());
    println!("cargo:rustc-link-search=native={cuda_path}/lib64");
    // NCCL is typically installed as a system library.
    println!("cargo:rustc-link-search=native=/usr/lib/x86_64-linux-gnu");
    println!("cargo:rustc-link-lib=dylib=nccl");
    println!("cargo:rustc-link-lib=dylib=cudart");
}

View File

@@ -0,0 +1,65 @@
//! Minimal NCCL FFI bindings (hand-written, like the CUDA bindings).
//! Only the collectives we need for tensor parallelism.
use std::ffi::c_void;
use std::os::raw::c_char;
use xserv_cuda::ffi::CudaStream;
/// Opaque NCCL communicator handle (`ncclComm_t`).
///
/// Rust never looks through this pointer; it is only produced by
/// `ncclCommInitRank` and handed back to the other NCCL entry points.
pub type NcclComm = *mut c_void;
/// Rust mirror of NCCL's `ncclUniqueId`: a 128-byte opaque blob that is created
/// on one rank (normally rank 0) and handed to every other rank.
///
/// `#[repr(C)]` keeps the layout identical to the C struct so the value can be
/// passed across the FFI boundary; `Copy` lets each rank take its own copy.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct NcclUniqueId {
    /// Raw id bytes. The contents are meaningful only to NCCL.
    pub internal: [c_char; 128],
}

impl Default for NcclUniqueId {
    /// Returns an all-zero id, intended as a placeholder to be overwritten by
    /// `ncclGetUniqueId`.
    fn default() -> Self {
        let internal = [0 as c_char; 128];
        NcclUniqueId { internal }
    }
}
// ncclDataType_t (subset) — numeric values must match the enum in `nccl.h`.
/// 32-bit float element type (`ncclFloat32`).
pub const NCCL_FLOAT32: i32 = 7;
/// bfloat16 element type (`ncclBfloat16`).
pub const NCCL_BF16: i32 = 9;
// ncclRedOp_t
/// Element-wise sum reduction (`ncclSum`).
pub const NCCL_SUM: i32 = 0;
// ncclResult_t
/// Status code for success (`ncclSuccess`); `check` treats every other value as failure.
pub const NCCL_SUCCESS: i32 = 0;
// Raw NCCL entry points. Every function returns an `ncclResult_t` as `i32`
// (compare against `NCCL_SUCCESS`), except `ncclGetErrorString`.
unsafe extern "C" {
/// Writes a freshly generated unique id into `*uid`.
pub fn ncclGetUniqueId(uid: *mut NcclUniqueId) -> i32;
// ncclUniqueId is passed BY VALUE (a 128-byte struct) per the NCCL ABI.
/// Creates the communicator for `rank` out of `nranks` and stores it in `*comm`.
pub fn ncclCommInitRank(comm: *mut NcclComm, nranks: i32, commid: NcclUniqueId, rank: i32) -> i32;
/// Destroys a communicator previously produced by `ncclCommInitRank`.
pub fn ncclCommDestroy(comm: NcclComm) -> i32;
/// Reduces `count` elements of `datatype` from `sendbuff` with `op` across all
/// ranks of `comm`, writing the result to `recvbuff`; the call is enqueued on
/// `stream`. This crate passes the same pointer for both buffers (in-place).
pub fn ncclAllReduce(
sendbuff: *const c_void,
recvbuff: *mut c_void,
count: usize,
datatype: i32,
op: i32,
comm: NcclComm,
stream: CudaStream,
) -> i32;
/// Opens an NCCL group; calls issued until `ncclGroupEnd` are batched.
pub fn ncclGroupStart() -> i32;
/// Closes the group opened by `ncclGroupStart`.
pub fn ncclGroupEnd() -> i32;
/// Returns a C string describing `result`; `err_string` copies it into a `String`.
pub fn ncclGetErrorString(result: i32) -> *const c_char;
}
/// Converts an NCCL status code into an owned, human-readable message.
///
/// Asks NCCL for its description of `result`; if NCCL returns a null pointer,
/// falls back to a generic `"nccl error <code>"` string. Bytes that are not
/// valid UTF-8 are replaced lossily.
pub fn err_string(result: i32) -> String {
    // SAFETY: the call takes a plain integer and has no pointer preconditions.
    let msg = unsafe { ncclGetErrorString(result) };
    if msg.is_null() {
        format!("nccl error {result}")
    } else {
        // SAFETY: `msg` is non-null; it is assumed to be a NUL-terminated C
        // string owned by NCCL that stays valid while we copy it.
        unsafe { std::ffi::CStr::from_ptr(msg) }
            .to_string_lossy()
            .into_owned()
    }
}
/// Asserts that an NCCL call succeeded.
///
/// `result` is the status returned by the NCCL function and `what` is a short
/// label for the operation, used as the prefix of the panic message.
///
/// # Panics
/// Panics if `result` is not [`NCCL_SUCCESS`]; the message contains `what` and
/// NCCL's own description of the error. Thanks to `#[track_caller]` the
/// reported panic location is the call site of `check`, not this helper.
#[track_caller]
pub fn check(result: i32, what: &str) {
    assert_eq!(result, NCCL_SUCCESS, "{what} failed: {}", err_string(result));
}

View File

@@ -0,0 +1,97 @@
//! Tensor-parallel primitives for xserv.
//!
//! Process model: one OS thread per TP rank, each bound to one GPU. NCCL is
//! used for the collective (AllReduce); a hand-rolled P2P AllReduce may replace
//! it later as a learning exercise (see docs/17-tensor-parallelism.md).
pub mod ffi;
use std::ffi::c_void;
use ffi::{NcclComm, NcclUniqueId};
use xserv_cuda::device;
use xserv_cuda::GpuBuffer;
pub use ffi::NcclUniqueId as UniqueId;
/// The CUDA "null" (default) stream, represented by a null stream handle.
///
/// The model's kernels and cuBLAS calls run on it, so issuing NCCL on the same
/// stream keeps AllReduce correctly ordered after the producing matmul and
/// before the consuming kernel — no extra sync.
const NULL_STREAM: xserv_cuda::ffi::CudaStream = std::ptr::null_mut();
/// Creates a fresh NCCL unique id.
///
/// Call this on a single rank (typically rank 0) and share the returned value
/// with every other rank out-of-band (e.g. through a variable captured by the
/// rank threads); all ranks then pass the same id to [`TpContext::init`].
///
/// # Panics
/// Panics if `ncclGetUniqueId` reports an error.
pub fn get_unique_id() -> NcclUniqueId {
    let mut uid = NcclUniqueId::default();
    // SAFETY: `uid` is a live, properly sized `NcclUniqueId` that NCCL may write.
    let status = unsafe { ffi::ncclGetUniqueId(&mut uid) };
    ffi::check(status, "ncclGetUniqueId");
    uid
}
/// Per-rank tensor-parallel context: the rank's identity plus its NCCL
/// communicator. No stream is stored; collectives are issued on the CUDA null
/// stream.
pub struct TpContext {
/// This rank's index, as passed to `ncclCommInitRank`.
pub rank: usize,
/// Total number of ranks in the communicator.
pub world: usize,
/// CUDA device ordinal the owning thread was bound to in `init`.
pub device: u32,
// Communicator created in `init` and destroyed in `Drop`.
comm: NcclComm,
}
// The NCCL communicator is owned by exactly one rank thread.
// SAFETY: the raw `NcclComm` pointer is what prevents the auto-`Send` impl; the
// handle is never aliased by another `TpContext`.
// NOTE(review): this relies on NCCL permitting a communicator to be used from a
// thread other than the one that created it — confirm against the NCCL docs.
unsafe impl Send for TpContext {}
impl TpContext {
    /// Initialize this rank. Must be called from the thread that will own this
    /// rank's GPU work; binds the thread to `device` first. All ranks must call
    /// this concurrently with the same `id` and `world`.
    ///
    /// # Panics
    /// Panics if `rank >= world`, if `world` does not fit in NCCL's `int`
    /// rank count, if binding the device fails, or if any NCCL call reports
    /// an error.
    pub fn init(rank: usize, world: usize, id: NcclUniqueId, device: u32) -> Self {
        // Validate before narrowing: a bare `as i32` would silently truncate
        // out-of-range values and hand NCCL a different rank/world.
        assert!(
            rank < world,
            "TpContext::init: rank {rank} out of range for world size {world}"
        );
        assert!(
            world <= i32::MAX as usize,
            "TpContext::init: world size {world} does not fit in NCCL's int rank count"
        );
        // Lossless: rank < world <= i32::MAX.
        let (nranks, this_rank) = (world as i32, rank as i32);

        device::set_device(device).expect("set_device");
        let mut comm: NcclComm = std::ptr::null_mut();
        // Wrap the concurrent inits in a group so they rendezvous without deadlock.
        // NOTE(review): NCCL documents group calls mainly for one thread driving
        // several communicators; with one rank per thread this is expected to be
        // harmless — confirm.
        // SAFETY: no pointer arguments.
        ffi::check(unsafe { ffi::ncclGroupStart() }, "ncclGroupStart(init)");
        ffi::check(
            // SAFETY: `comm` is a live local that NCCL may write the handle into;
            // `id` is passed by value as the NCCL ABI requires.
            unsafe { ffi::ncclCommInitRank(&mut comm, nranks, id, this_rank) },
            "ncclCommInitRank",
        );
        // SAFETY: closes the group opened above.
        ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)");
        Self { rank, world, device, comm }
    }

    /// In-place AllReduce(sum) over `count` BF16 elements in `buf`.
    ///
    /// `buf` must hold at least `count` BF16 elements (`2 * count` bytes); this
    /// is not checked here.
    pub fn all_reduce_sum_bf16(&self, buf: &mut GpuBuffer, count: usize) {
        self.all_reduce_sum_bf16_ptr(buf.as_mut_ptr() as *mut c_void, count);
    }

    /// In-place AllReduce(sum) directly on a device pointer (`count` BF16 elems),
    /// issued on the null stream so it is ordered with the model's kernels.
    /// Asynchronous: a later sync (e.g. the D2H logits copy) completes it.
    ///
    /// Returns immediately without touching `ptr` when `world == 1`.
    ///
    /// # Safety
    /// `ptr` must point to at least `count` BF16 elements of valid device memory
    /// on this rank's device. The reduction is in-place (send == recv).
    ///
    /// This function is not declared `unsafe fn` (kept for API compatibility),
    /// so the compiler does not enforce the contract above — callers must.
    ///
    /// # Panics
    /// Panics if `ncclAllReduce` reports an error.
    pub fn all_reduce_sum_bf16_ptr(&self, ptr: *mut c_void, count: usize) {
        if self.world == 1 {
            return; // nothing to reduce
        }
        ffi::check(
            // SAFETY: `self.comm` was produced by `ncclCommInitRank` in `init`
            // and is destroyed only in `Drop`; validity of `ptr`/`count` is the
            // caller's obligation (see `# Safety`).
            unsafe {
                ffi::ncclAllReduce(
                    ptr as *const c_void,
                    ptr,
                    count,
                    ffi::NCCL_BF16,
                    ffi::NCCL_SUM,
                    self.comm,
                    NULL_STREAM,
                )
            },
            "ncclAllReduce",
        );
    }
}
impl Drop for TpContext {
    /// Destroys the NCCL communicator, if one was created.
    ///
    /// A failure is reported on stderr instead of being silently discarded;
    /// it never panics, since panicking inside `drop` could abort the process.
    fn drop(&mut self) {
        if self.comm.is_null() {
            return;
        }
        // SAFETY: `self.comm` is a communicator obtained from `ncclCommInitRank`
        // and has not been destroyed yet (it is nulled right after this call).
        let status = unsafe { ffi::ncclCommDestroy(self.comm) };
        self.comm = std::ptr::null_mut();
        if status != ffi::NCCL_SUCCESS {
            eprintln!("ncclCommDestroy failed: {}", ffi::err_string(status));
        }
    }
}

View File

@@ -0,0 +1,50 @@
//! 2-GPU AllReduce smoke test. Skips if fewer than 2 GPUs are present.
use half::bf16;
use std::thread;
use xserv_cuda::{device, GpuBuffer};
use xserv_distributed::{get_unique_id, TpContext};
/// Spawns one thread per rank on two GPUs, fills rank `r`'s buffer with
/// `r + 1`, runs an in-place BF16 AllReduce(sum), and checks that the first and
/// last elements on every rank equal `1 + 2 = 3`. Skips when fewer than two
/// GPUs are available.
#[test]
fn allreduce_two_gpu_sum() {
    let world = 2usize;
    if device::device_count().unwrap_or(0) < world as i32 {
        eprintln!("skip: need >= {world} GPUs");
        return;
    }
    let id = get_unique_id();
    let n = 4096usize;
    let handles: Vec<_> = (0..world)
        .map(|rank| {
            // `id`, `world` and `n` are `Copy`, so each thread gets its own copy.
            thread::spawn(move || {
                let ctx = TpContext::init(rank, world, id, rank as u32);
                // Rank r fills its buffer with (r + 1).
                let val = bf16::from_f32((rank + 1) as f32);
                // Native-endian byte image of `n` copies of `val`, built safely.
                let src: Vec<u8> = val.to_bits().to_ne_bytes().repeat(n);
                let mut buf = GpuBuffer::alloc(n * 2).unwrap();
                buf.copy_from_host(&src[..]).unwrap();
                ctx.all_reduce_sum_bf16(&mut buf, n);
                let mut out = vec![0u8; n * 2];
                buf.copy_to_host(&mut out).unwrap();
                // Decode from bytes instead of casting the `Vec<u8>` pointer to
                // `*const bf16`: a byte buffer is only guaranteed 1-byte
                // alignment, so that cast would be undefined behaviour.
                let decode = |pair: &[u8]| {
                    bf16::from_bits(u16::from_ne_bytes([pair[0], pair[1]])).to_f32()
                };
                (decode(&out[..2]), decode(&out[2 * (n - 1)..]))
            })
        })
        .collect();
    // sum over ranks of (r+1) = 1 + 2 = 3
    for h in handles {
        let (first, last) = h.join().unwrap();
        assert_eq!(first, 3.0, "AllReduce(sum) first element");
        assert_eq!(last, 3.0, "AllReduce(sum) last element");
    }
}