64 lines
2.5 KiB
Rust
64 lines
2.5 KiB
Rust
//! 2-GPU NCCL P2P send/recv smoke test for pipeline parallelism.
//! Stage 0 sends a known vector to stage 1, which verifies it. Skips if fewer
//! than 2 GPUs are present. Mirrors `allreduce.rs` (GpuBuffer + half only —
//! this crate does not depend on xserv-tensor).
|
|
|
|
use half::bf16;
|
|
use std::ffi::c_void;
|
|
use std::thread;
|
|
use xserv_cuda::{GpuBuffer, device};
|
|
use xserv_distributed::{PpContext, get_unique_id};
|
|
|
|
/// Two-stage point-to-point smoke test for the pipeline-parallel context.
///
/// One thread is spawned per stage. Stage 0 uploads a deterministic bf16
/// pattern (`i % 97`) to its device buffer and sends it to stage 1. Stage 1
/// receives into a zeroed device buffer, copies the result back to the host
/// and hands it to the main thread, which compares *every* element against the
/// expected pattern.
///
/// The test returns early (skip) when fewer than two devices are reported.
#[test]
fn pp_send_recv_two_stages() {
    /// Expected value at index `i`. The values 0..=96 are small integers and
    /// therefore exactly representable in bf16, so equality checks are exact.
    fn pattern(i: usize) -> bf16 {
        bf16::from_f32((i % 97) as f32)
    }

    let world = 2usize;
    if device::device_count().unwrap_or(0) < world as i32 {
        eprintln!("skip: need >= {world} GPUs");
        return;
    }

    let id = get_unique_id();
    let n = 4096usize; // one [1, hidden]-sized hand-off
    let elem_bytes = std::mem::size_of::<bf16>();
    let n_bytes = n * elem_bytes;

    let handles: Vec<_> = (0..world)
        .map(|stage| {
            thread::spawn(move || {
                let pp = PpContext::init(stage, world, id, stage as u32);
                let mut buf = GpuBuffer::alloc(n_bytes).unwrap();

                if stage == 0 {
                    // Encode the pattern as native-endian bytes without any
                    // pointer reinterpretation, then upload and send it.
                    let host: Vec<u8> = (0..n)
                        .flat_map(|i| pattern(i).to_bits().to_ne_bytes())
                        .collect();
                    buf.copy_from_host(&host).unwrap();
                    pp.send_bf16_ptr(buf.as_mut_ptr() as *const c_void, n, 1);
                    device::synchronize().unwrap();
                    None
                } else {
                    // Start from zeros so stale device memory cannot
                    // accidentally satisfy the check.
                    buf.copy_from_host(&vec![0u8; n_bytes]).unwrap();
                    pp.recv_bf16_ptr(buf.as_mut_ptr() as *mut c_void, n, 0);
                    device::synchronize().unwrap();

                    let mut out = vec![0u8; n_bytes];
                    buf.copy_to_host(&mut out).unwrap();

                    // Decode byte pairs instead of casting the `u8` pointer to
                    // `*const bf16`: a `Vec<u8>` is only guaranteed 1-byte
                    // alignment, so the cast would not be sound.
                    let received: Vec<bf16> = out
                        .chunks_exact(elem_bytes)
                        .map(|pair| bf16::from_bits(u16::from_ne_bytes([pair[0], pair[1]])))
                        .collect();
                    Some(received)
                }
            })
        })
        .collect();

    let mut checked = false;
    for h in handles {
        // Only the receiving stage returns data; the sender yields `None`.
        if let Some(received) = h.join().expect("stage thread panicked") {
            assert_eq!(received.len(), n, "recv length");
            for (i, got) in received.iter().enumerate() {
                assert_eq!(got.to_f32(), (i % 97) as f32, "recv[{i}]");
            }
            checked = true;
        }
    }
    assert!(checked, "stage 1 never verified the received buffer");
}
|