//! 2-GPU AllReduce smoke test. Skips if fewer than 2 GPUs are present.
|
|
|
|
use half::bf16;
|
|
use std::thread;
|
|
use xserv_cuda::{GpuBuffer, device};
|
|
use xserv_distributed::{TpContext, get_unique_id};
|
|
|
|
/// Smoke-tests a two-rank sum AllReduce over bf16 data, using one thread per rank.
///
/// Rank `r` uploads a buffer of `n` bf16 values all equal to `r + 1`, runs
/// `all_reduce_sum_bf16` in place, and reads the buffer back. With two ranks every
/// element must therefore come back as `1 + 2 = 3`; the first and last elements of
/// each rank's result are checked.
///
/// The test returns early (printing a skip notice to stderr) when the reported
/// device count is below the world size.
#[test]
fn allreduce_two_gpu_sum() {
    let world = 2usize;
    // An unavailable device count is treated as zero, which also takes the skip path.
    if device::device_count().unwrap_or(0) < world as i32 {
        eprintln!("skip: need >= {world} GPUs");
        return;
    }

    let id = get_unique_id();
    let n = 4096usize;
    let elem_size = std::mem::size_of::<bf16>();

    let handles: Vec<_> = (0..world)
        .map(|rank| {
            // Give each rank's thread its own copy of the shared id.
            let id = id;
            thread::spawn(move || {
                // NOTE(review): the last argument is presumably the device index
                // for this rank — confirm against `TpContext::init`.
                let ctx = TpContext::init(rank, world, id, rank as u32);

                // Rank r fills its buffer with (r + 1).
                let val = bf16::from_f32((rank + 1) as f32);
                // Serialise through the bit pattern in native byte order. This yields
                // the same bytes as viewing a `[bf16]` as raw memory, without `unsafe`.
                let host: Vec<u8> = val.to_bits().to_ne_bytes().repeat(n);
                debug_assert_eq!(host.len(), n * elem_size);

                let mut buf = GpuBuffer::alloc(host.len()).unwrap();
                buf.copy_from_host(host.as_slice()).unwrap();

                ctx.all_reduce_sum_bf16(&mut buf, n);

                let mut out = vec![0u8; n * elem_size];
                buf.copy_to_host(&mut out).unwrap();

                // Decode element `i` from its two bytes instead of casting the byte
                // pointer to `*const bf16`: a `Vec<u8>` is only guaranteed 1-byte
                // alignment, so that cast would break `from_raw_parts`' contract.
                let element = |i: usize| {
                    let at = i * elem_size;
                    bf16::from_bits(u16::from_ne_bytes([out[at], out[at + 1]])).to_f32()
                };
                (element(0), element(n - 1))
            })
        })
        .collect();

    // sum over ranks of (r+1) = 1 + 2 = 3
    for h in handles {
        let (first, last) = h.join().unwrap();
        assert_eq!(first, 3.0, "AllReduce(sum) first element");
        assert_eq!(last, 3.0, "AllReduce(sum) last element");
    }
}
|