Files
xserv/crates/xserv-distributed/tests/allreduce.rs

49 lines
1.6 KiB
Rust

//! 2-GPU AllReduce smoke test. Skips if fewer than 2 GPUs are present.
use half::bf16;
use std::thread;
use xserv_cuda::{GpuBuffer, device};
use xserv_distributed::{TpContext, get_unique_id};
/// Runs a sum AllReduce across two ranks (one thread per rank) and verifies
/// that every element of every rank's buffer holds the expected total.
///
/// Rank `r` fills its buffer with `r + 1`, so after the reduction each element
/// must equal `1 + 2 = 3`. The test returns early (and therefore passes) when
/// fewer than two GPUs are reported.
#[test]
fn allreduce_two_gpu_sum() {
    /// Size of one bf16 element in bytes.
    const BF16_BYTES: usize = 2;

    let world = 2usize;
    if device::device_count().unwrap_or(0) < world as i32 {
        eprintln!("skip: need >= {world} GPUs");
        return;
    }

    let id = get_unique_id();
    let n = 4096usize;

    let handles: Vec<_> = (0..world)
        .map(|rank| {
            let id = id;
            thread::spawn(move || {
                let ctx = TpContext::init(rank, world, id, rank as u32);

                // Rank r fills its buffer with (r + 1). The byte image is built
                // from the bf16 bit pattern so no pointer reinterpretation is
                // needed on the host side.
                let val = bf16::from_f32((rank + 1) as f32);
                let pattern = val.to_bits().to_ne_bytes();
                let src: Vec<u8> = (0..n).flat_map(|_| pattern).collect();

                let mut buf = GpuBuffer::alloc(n * BF16_BYTES).unwrap();
                buf.copy_from_host(src.as_slice()).unwrap();
                ctx.all_reduce_sum_bf16(&mut buf, n);

                let mut out = vec![0u8; n * BF16_BYTES];
                buf.copy_to_host(&mut out).unwrap();

                // Decode byte pairs instead of casting the `Vec<u8>` pointer to
                // `*const bf16`: a byte buffer is only guaranteed 1-byte
                // alignment, so such a cast could produce a misaligned slice.
                out.chunks_exact(BF16_BYTES)
                    .map(|pair| bf16::from_bits(u16::from_ne_bytes([pair[0], pair[1]])).to_f32())
                    .collect::<Vec<f32>>()
            })
        })
        .collect();

    // Sum over ranks of (r + 1); for world = 2 this is 1 + 2 = 3, which bf16
    // represents exactly, so an exact float comparison is valid.
    let expected: f32 = (1..=world).map(|r| r as f32).sum();

    for (rank, handle) in handles.into_iter().enumerate() {
        let reduced = handle.join().expect("rank thread panicked");
        assert_eq!(reduced.len(), n, "AllReduce(sum) rank {rank} output length");
        for (i, value) in reduced.iter().enumerate() {
            assert_eq!(*value, expected, "AllReduce(sum) rank {rank} element {i}");
        }
    }
}