distributed: NCCL P2P primitives (PpContext + send/recv)
Add ncclSend/ncclRecv FFI and a PpContext that initializes a NCCL communicator across P pipeline stages and hands the hidden state to neighbour stages on the null stream. Mirrors TpContext; the collective differs (point-to-point hand-off vs in-layer AllReduce). tests/sendrecv.rs: 2-GPU stage0->stage1 send/recv smoke test. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -45,6 +45,23 @@ unsafe extern "C" {
|
||||
comm: NcclComm,
|
||||
stream: CudaStream,
|
||||
) -> i32;
|
||||
// Point-to-point primitives for pipeline parallelism (Phase 18).
|
||||
pub fn ncclSend(
|
||||
sendbuff: *const c_void,
|
||||
count: usize,
|
||||
datatype: i32,
|
||||
peer: i32,
|
||||
comm: NcclComm,
|
||||
stream: CudaStream,
|
||||
) -> i32;
|
||||
pub fn ncclRecv(
|
||||
recvbuff: *mut c_void,
|
||||
count: usize,
|
||||
datatype: i32,
|
||||
peer: i32,
|
||||
comm: NcclComm,
|
||||
stream: CudaStream,
|
||||
) -> i32;
|
||||
pub fn ncclGroupStart() -> i32;
|
||||
pub fn ncclGroupEnd() -> i32;
|
||||
pub fn ncclGetErrorString(result: i32) -> *const c_char;
|
||||
|
||||
@@ -95,3 +95,67 @@ impl Drop for TpContext {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-stage pipeline-parallel context: a NCCL communicator spanning all `P`
|
||||
/// stages plus point-to-point send/recv of the hidden state to the neighbour
|
||||
/// stages. Init is identical to `TpContext` (one comm across `world` ranks);
|
||||
/// only the collective differs — PP hands off `[tokens, hidden]` between
|
||||
/// consecutive stages instead of AllReducing within a layer.
|
||||
pub struct PpContext {
|
||||
pub stage: usize,
|
||||
pub world: usize,
|
||||
pub device: u32,
|
||||
comm: NcclComm,
|
||||
}
|
||||
|
||||
// The NCCL communicator is owned by exactly one stage thread.
|
||||
unsafe impl Send for PpContext {}
|
||||
|
||||
impl PpContext {
|
||||
/// Initialize this stage. Must be called from the thread that owns this
|
||||
/// stage's GPU; binds the thread to `device` first. All stages call this
|
||||
/// concurrently with the same `id` and `world`.
|
||||
pub fn init(stage: usize, world: usize, id: NcclUniqueId, device: u32) -> Self {
|
||||
device::set_device(device).expect("set_device");
|
||||
let mut comm: NcclComm = std::ptr::null_mut();
|
||||
ffi::check(unsafe { ffi::ncclGroupStart() }, "ncclGroupStart(init)");
|
||||
ffi::check(
|
||||
unsafe { ffi::ncclCommInitRank(&mut comm, world as i32, id, stage as i32) },
|
||||
"ncclCommInitRank",
|
||||
);
|
||||
ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)");
|
||||
Self { stage, world, device, comm }
|
||||
}
|
||||
|
||||
/// Send `count` BF16 elements at `ptr` to `peer`, on the null stream so it is
|
||||
/// ordered after the producing matmul. Asynchronous — a later `synchronize`
|
||||
/// (the caller must do one before reusing/freeing the buffer) completes it.
|
||||
///
|
||||
/// # Safety
|
||||
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
|
||||
pub fn send_bf16_ptr(&self, ptr: *const c_void, count: usize, peer: usize) {
|
||||
ffi::check(
|
||||
unsafe { ffi::ncclSend(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, NULL_STREAM) },
|
||||
"ncclSend",
|
||||
);
|
||||
}
|
||||
|
||||
/// Receive `count` BF16 elements from `peer` into `ptr`, on the null stream.
|
||||
///
|
||||
/// # Safety
|
||||
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
|
||||
pub fn recv_bf16_ptr(&self, ptr: *mut c_void, count: usize, peer: usize) {
|
||||
ffi::check(
|
||||
unsafe { ffi::ncclRecv(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, NULL_STREAM) },
|
||||
"ncclRecv",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PpContext {
|
||||
fn drop(&mut self) {
|
||||
if !self.comm.is_null() {
|
||||
unsafe { ffi::ncclCommDestroy(self.comm) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
62
crates/xserv-distributed/tests/sendrecv.rs
Normal file
62
crates/xserv-distributed/tests/sendrecv.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
//! 2-GPU NCCL P2P send/recv smoke test for pipeline parallelism.
|
||||
//! Stage 0 sends a known vector to stage 1, which verifies it. Skips if fewer
|
||||
//! than 2 GPUs are present. Mirrors `allreduce.rs` (GpuBuffer + half only —
|
||||
//! this crate does not depend on xserv-tensor).
|
||||
|
||||
use half::bf16;
|
||||
use std::ffi::c_void;
|
||||
use std::thread;
|
||||
use xserv_cuda::{device, GpuBuffer};
|
||||
use xserv_distributed::{get_unique_id, PpContext};
|
||||
|
||||
#[test]
|
||||
fn pp_send_recv_two_stages() {
|
||||
let world = 2usize;
|
||||
if device::device_count().unwrap_or(0) < world as i32 {
|
||||
eprintln!("skip: need >= {world} GPUs");
|
||||
return;
|
||||
}
|
||||
|
||||
let id = get_unique_id();
|
||||
let n = 4096usize; // one [1, hidden]-sized hand-off
|
||||
|
||||
let handles: Vec<_> = (0..world)
|
||||
.map(|stage| {
|
||||
let id = id;
|
||||
thread::spawn(move || {
|
||||
let pp = PpContext::init(stage, world, id, stage as u32);
|
||||
let mut buf = GpuBuffer::alloc(n * 2).unwrap();
|
||||
|
||||
if stage == 0 {
|
||||
// Fill with a known pattern and send to stage 1.
|
||||
let host: Vec<bf16> = (0..n).map(|i| bf16::from_f32((i % 97) as f32)).collect();
|
||||
let src = unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
|
||||
buf.copy_from_host(src).unwrap();
|
||||
pp.send_bf16_ptr(buf.as_mut_ptr() as *const c_void, n, 1);
|
||||
device::synchronize().unwrap();
|
||||
None
|
||||
} else {
|
||||
// Receive into a zeroed buffer and read it back.
|
||||
buf.copy_from_host(&vec![0u8; n * 2]).unwrap();
|
||||
pp.recv_bf16_ptr(buf.as_mut_ptr() as *mut c_void, n, 0);
|
||||
device::synchronize().unwrap();
|
||||
let mut out = vec![0u8; n * 2];
|
||||
buf.copy_to_host(&mut out).unwrap();
|
||||
let res = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const bf16, n) };
|
||||
Some((res[0].to_f32(), res[1].to_f32(), res[n - 1].to_f32()))
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut checked = false;
|
||||
for h in handles {
|
||||
if let Some((first, second, last)) = h.join().unwrap() {
|
||||
assert_eq!(first, 0.0, "recv[0]");
|
||||
assert_eq!(second, 1.0, "recv[1]");
|
||||
assert_eq!(last, ((n - 1) % 97) as f32, "recv[last]");
|
||||
checked = true;
|
||||
}
|
||||
}
|
||||
assert!(checked, "stage 1 never verified the received buffer");
|
||||
}
|
||||
Reference in New Issue
Block a user