distributed: NCCL P2P primitives (PpContext + send/recv)

Add ncclSend/ncclRecv FFI and a PpContext that initializes a NCCL
communicator across P pipeline stages and hands the hidden state to
neighbour stages on the null stream. Mirrors TpContext; the collective
differs (point-to-point hand-off vs in-layer AllReduce).

tests/sendrecv.rs: 2-GPU stage0->stage1 send/recv smoke test.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-05-29 18:45:42 +08:00
parent c2362df1f1
commit 859c0cc0b6
3 changed files with 143 additions and 0 deletions

View File

@@ -45,6 +45,23 @@ unsafe extern "C" {
comm: NcclComm,
stream: CudaStream,
) -> i32;
// Point-to-point primitives for pipeline parallelism (Phase 18).
/// Raw binding for NCCL's `ncclSend`: enqueue a send of `count` elements of
/// `datatype` from the device buffer `sendbuff` to rank `peer` of `comm`, on
/// `stream`.
///
/// Per the NCCL API documentation, the peer must post a matching `ncclRecv`
/// with the same `count` and `datatype`, and `count` is a number of elements,
/// not bytes. The returned `i32` is the raw `ncclResult_t` status code; callers
/// in this crate pass it to `ffi::check`.
pub fn ncclSend(
sendbuff: *const c_void,
count: usize,
datatype: i32,
peer: i32,
comm: NcclComm,
stream: CudaStream,
) -> i32;
/// Raw binding for NCCL's `ncclRecv`: enqueue a receive of `count` elements of
/// `datatype` from rank `peer` of `comm` into the device buffer `recvbuff`, on
/// `stream`.
///
/// Per the NCCL API documentation, the peer must post a matching `ncclSend`
/// with the same `count` and `datatype`, and `count` is a number of elements,
/// not bytes. The returned `i32` is the raw `ncclResult_t` status code; callers
/// in this crate pass it to `ffi::check`.
pub fn ncclRecv(
recvbuff: *mut c_void,
count: usize,
datatype: i32,
peer: i32,
comm: NcclComm,
stream: CudaStream,
) -> i32;
pub fn ncclGroupStart() -> i32;
pub fn ncclGroupEnd() -> i32;
pub fn ncclGetErrorString(result: i32) -> *const c_char;

View File

@@ -95,3 +95,67 @@ impl Drop for TpContext {
}
}
}
/// Per-stage pipeline-parallel context: a NCCL communicator spanning all `P`
/// stages plus point-to-point send/recv of the hidden state to the neighbour
/// stages. Init is identical to `TpContext` (one comm across `world` ranks);
/// only the collective differs — PP hands off `[tokens, hidden]` between
/// consecutive stages instead of AllReducing within a layer.
///
/// The communicator is created in `init` and destroyed when the context is
/// dropped.
pub struct PpContext {
/// Index of this pipeline stage; used as this process's rank in the
/// communicator.
pub stage: usize,
/// Total number of stages, i.e. the number of ranks in the communicator.
pub world: usize,
/// Device index that `init` binds the calling thread to.
pub device: u32,
/// NCCL communicator handle produced by `ncclCommInitRank`; kept private so
/// it can only be released through `Drop`.
comm: NcclComm,
}
// SAFETY: The NCCL communicator is owned by exactly one stage thread. The
// handle is private and only used through `&self`/`&mut self` methods of the
// owning `PpContext`, so moving the context to another thread transfers that
// exclusive ownership rather than sharing the handle.
// NOTE(review): `init` binds the *calling* thread to `device`; a context moved
// to a different thread presumably needs that thread bound to the same device
// before use — confirm against callers.
unsafe impl Send for PpContext {}
impl PpContext {
    /// Convert a rank-like `usize` into the `i32` NCCL expects.
    ///
    /// A plain `as i32` cast would silently wrap out-of-range values into a
    /// bogus (possibly negative) rank; this panics with a descriptive message
    /// instead.
    fn nccl_rank(value: usize, what: &str) -> i32 {
        i32::try_from(value)
            .unwrap_or_else(|_| panic!("{what} {value} does not fit in an NCCL rank (i32)"))
    }

    /// Initialize this stage. Must be called from the thread that owns this
    /// stage's GPU; binds the thread to `device` first. All stages call this
    /// concurrently with the same `id` and `world`.
    ///
    /// `stage` is used as this rank's index in the communicator and `world` as
    /// the communicator size.
    ///
    /// # Panics
    /// Panics if `stage` or `world` does not fit in an `i32`, or if binding the
    /// device fails. Each NCCL status is passed to `ffi::check` (defined outside
    /// this block; presumably it panics on a non-success status).
    pub fn init(stage: usize, world: usize, id: NcclUniqueId, device: u32) -> Self {
        // Validate the ranks before touching NCCL so a bad argument cannot
        // leave a group call open.
        let nranks = Self::nccl_rank(world, "world");
        let rank = Self::nccl_rank(stage, "stage");

        device::set_device(device).expect("set_device");
        let mut comm: NcclComm = std::ptr::null_mut();
        ffi::check(unsafe { ffi::ncclGroupStart() }, "ncclGroupStart(init)");
        ffi::check(
            // SAFETY: `comm` is a valid out-pointer for the duration of the call.
            unsafe { ffi::ncclCommInitRank(&mut comm, nranks, id, rank) },
            "ncclCommInitRank",
        );
        ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)");
        Self { stage, world, device, comm }
    }

    /// Send `count` BF16 elements at `ptr` to `peer`, on the null stream so it is
    /// ordered after the producing matmul. Asynchronous — a later `synchronize`
    /// (the caller must do one before reusing/freeing the buffer) completes it.
    ///
    /// # Safety
    /// `ptr` must point to at least `count` BF16 elements of valid device memory.
    ///
    /// NOTE(review): this is a safe `fn` even though it carries a pointer
    /// validity contract; it is kept safe for compatibility with existing
    /// callers, but should arguably be an `unsafe fn`.
    ///
    /// # Panics
    /// Panics if `peer` does not fit in an `i32`.
    pub fn send_bf16_ptr(&self, ptr: *const c_void, count: usize, peer: usize) {
        let peer_rank = Self::nccl_rank(peer, "peer");
        ffi::check(
            // SAFETY: the caller guarantees `ptr` covers `count` BF16 elements of
            // device memory; `self.comm` was created by `init`.
            unsafe { ffi::ncclSend(ptr, count, ffi::NCCL_BF16, peer_rank, self.comm, NULL_STREAM) },
            "ncclSend",
        );
    }

    /// Receive `count` BF16 elements from `peer` into `ptr`, on the null stream.
    ///
    /// # Safety
    /// `ptr` must point to at least `count` BF16 elements of valid device memory.
    ///
    /// NOTE(review): this is a safe `fn` even though it carries a pointer
    /// validity contract; it is kept safe for compatibility with existing
    /// callers, but should arguably be an `unsafe fn`.
    ///
    /// # Panics
    /// Panics if `peer` does not fit in an `i32`.
    pub fn recv_bf16_ptr(&self, ptr: *mut c_void, count: usize, peer: usize) {
        let peer_rank = Self::nccl_rank(peer, "peer");
        ffi::check(
            // SAFETY: the caller guarantees `ptr` covers `count` writable BF16
            // elements of device memory; `self.comm` was created by `init`.
            unsafe { ffi::ncclRecv(ptr, count, ffi::NCCL_BF16, peer_rank, self.comm, NULL_STREAM) },
            "ncclRecv",
        );
    }
}
impl Drop for PpContext {
    /// Release the NCCL communicator, if one was ever created.
    ///
    /// The status returned by `ncclCommDestroy` is deliberately discarded.
    fn drop(&mut self) {
        let handle = self.comm;
        if handle.is_null() {
            // Nothing was initialized, so there is nothing to destroy.
            return;
        }
        let _ = unsafe { ffi::ncclCommDestroy(handle) };
    }
}

View File

@@ -0,0 +1,62 @@
//! 2-GPU NCCL P2P send/recv smoke test for pipeline parallelism.
//! Stage 0 sends a known vector to stage 1, which verifies it. Skips if fewer
//! than 2 GPUs are present. Mirrors `allreduce.rs` (GpuBuffer + half only —
//! this crate does not depend on xserv-tensor).
use half::bf16;
use std::ffi::c_void;
use std::thread;
use xserv_cuda::{device, GpuBuffer};
use xserv_distributed::{get_unique_id, PpContext};
/// Stage 0 fills a device buffer with a known BF16 pattern and sends it to
/// stage 1; stage 1 receives it, copies it back to the host and reports the
/// first, second and last elements, which the main thread checks.
///
/// Each stage runs on its own thread and uses the GPU whose index equals its
/// stage number. The test returns early (skip) when fewer than two GPUs are
/// reported.
#[test]
fn pp_send_recv_two_stages() {
    let world = 2usize;
    if device::device_count().unwrap_or(0) < world as i32 {
        eprintln!("skip: need >= {world} GPUs");
        return;
    }
    let id = get_unique_id();
    let n = 4096usize; // one [1, hidden]-sized hand-off
    let handles: Vec<_> = (0..world)
        .map(|stage| {
            let id = id;
            thread::spawn(move || {
                let pp = PpContext::init(stage, world, id, stage as u32);
                // BF16 is two bytes per element.
                let mut buf = GpuBuffer::alloc(n * 2).unwrap();
                if stage == 0 {
                    // Fill with a known pattern and send to stage 1.
                    let host: Vec<bf16> = (0..n).map(|i| bf16::from_f32((i % 97) as f32)).collect();
                    // SAFETY: `host` owns `n` initialized 2-byte elements; viewing
                    // them as `n * 2` bytes stays in bounds and `u8` has no
                    // alignment requirement.
                    let src = unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) };
                    buf.copy_from_host(src).unwrap();
                    pp.send_bf16_ptr(buf.as_mut_ptr() as *const c_void, n, 1);
                    device::synchronize().unwrap();
                    None
                } else {
                    // Receive into a zeroed buffer and read it back.
                    buf.copy_from_host(&vec![0u8; n * 2]).unwrap();
                    pp.recv_bf16_ptr(buf.as_mut_ptr() as *mut c_void, n, 0);
                    device::synchronize().unwrap();
                    let mut out = vec![0u8; n * 2];
                    buf.copy_to_host(&mut out).unwrap();
                    // Decode element by element instead of reinterpreting the byte
                    // buffer as `&[bf16]`: a `Vec<u8>` is only guaranteed 1-byte
                    // alignment, so that cast could create a misaligned slice.
                    // Native endianness matches how stage 0 produced the bytes.
                    let elem = |i: usize| {
                        bf16::from_bits(u16::from_ne_bytes([out[2 * i], out[2 * i + 1]])).to_f32()
                    };
                    Some((elem(0), elem(1), elem(n - 1)))
                }
            })
        })
        .collect();
    let mut checked = false;
    for h in handles {
        if let Some((first, second, last)) = h.join().unwrap() {
            assert_eq!(first, 0.0, "recv[0]");
            assert_eq!(second, 1.0, "recv[1]");
            assert_eq!(last, ((n - 1) % 97) as f32, "recv[last]");
            checked = true;
        }
    }
    assert!(checked, "stage 1 never verified the received buffer");
}