dist: coalesce grads into buckets for all-reduce (KI-5)
Replace the per-parameter eager all-reduce (~150 tiny serial NCCL calls for dim512, DDP's dominant cost after T10's batched forward) with a coalesced bucketed all-reduce: pack grads into a few large contiguous scratch buffers, all-reduce each bucket once (fused via ncclGroupStart/End), fold the 1/world average into one per-bucket scale, unpack back. The packed buffer is the concatenation of the grad tensors, so NCCL's element-wise sum over a bucket equals the per-tensor sums — bit-identical to the un-bucketed path; only launch/latency overhead is removed. DDP cross-rank param identity + loss-match are preserved. Adds xtrain_cuda::device::copy_d2d (cudaMemcpy D2D) for the pack/unpack. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -14,3 +14,15 @@ pub fn set_device(device: u32) -> Result<()> {
|
||||
pub fn synchronize() -> Result<()> {
|
||||
error::check(unsafe { ffi::cudaDeviceSynchronize() })
|
||||
}
|
||||
|
||||
/// Device-to-device copy of `count` bytes (`dst <- src`) on the same GPU. Issued
|
||||
/// on the null stream (like every other xtrain kernel), so it orders with the
|
||||
/// surrounding work. Used by the DDP bucketed all-reduce to pack/unpack grads
|
||||
/// into a flat scratch buffer.
|
||||
///
|
||||
/// # Safety
|
||||
/// `dst`/`src` must point to at least `count` valid bytes of device memory on the
|
||||
/// current device, with no overlap.
|
||||
pub unsafe fn copy_d2d(dst: *mut u8, src: *const u8, count: usize) -> Result<()> {
|
||||
error::check(unsafe { ffi::cudaMemcpy(dst, src, count, ffi::CUDA_MEMCPY_D2D) })
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ pub type CudaStream = *mut c_void;
|
||||
|
||||
pub const CUDA_MEMCPY_H2D: i32 = 1;
|
||||
pub const CUDA_MEMCPY_D2H: i32 = 2;
|
||||
pub const CUDA_MEMCPY_D2D: i32 = 3;
|
||||
|
||||
pub const CUDA_SUCCESS: i32 = 0;
|
||||
pub const CUDA_ERROR_OUT_OF_MEMORY: i32 = 2;
|
||||
|
||||
@@ -4,10 +4,11 @@
|
||||
//! rank thread binds its device, builds its own model (xtrain's `Var` graph is
|
||||
//! `Rc`-based and not `Send`, so it must be constructed thread-locally — only the
|
||||
//! `UniqueId` and scalar config cross the thread boundary), processes a disjoint
|
||||
//! shard of the global batch, then AllReduces every parameter's `.grad()` device
|
||||
//! buffer in place, averages by world size, and runs its own `GpuAdamW.step`.
|
||||
//! Identical init + identical optimizer state across ranks keeps the parameters
|
||||
//! consistent without ever re-syncing the weights.
|
||||
//! shard of the global batch, then **coalesces every parameter's `.grad()` into a
|
||||
//! few large buckets and all-reduces each bucket once** (Phase T11 — see
|
||||
//! `all_reduce_average_grads`), averages by world size, and runs its own
|
||||
//! `GpuAdamW.step`. Identical init + identical optimizer state across ranks keeps
|
||||
//! the parameters consistent without ever re-syncing the weights.
|
||||
//!
|
||||
//! NCCL is issued on the legacy null stream — every xtrain kernel launches on the
|
||||
//! null stream (`std::ptr::null_mut()`), so the AllReduce stays correctly ordered
|
||||
@@ -26,6 +27,7 @@ use std::ffi::c_void;
|
||||
use ffi::{NcclComm, NcclUniqueId};
|
||||
use xtrain_autodiff::tape::Var;
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_tensor::{Device, Tensor};
|
||||
|
||||
pub use ffi::NcclUniqueId as UniqueId;
|
||||
|
||||
@@ -101,7 +103,7 @@ impl DdpContext {
|
||||
}
|
||||
|
||||
/// AllReduce every parameter's `.grad()` across ranks and divide by `world`,
|
||||
/// the one collective DDP needs per step.
|
||||
/// the one collective DDP needs per step — **coalesced (bucketed)**.
|
||||
///
|
||||
/// Each rank ran forward+backward on its own shard of `b` sequences, so
|
||||
/// `.grad()` holds the SUM over that shard (the tape's fan-out rule). After
|
||||
@@ -112,38 +114,99 @@ impl DdpContext {
|
||||
/// mean gradient the single-GPU loop computes from a batch of `B_global`.
|
||||
/// Params without a grad are skipped.
|
||||
///
|
||||
/// A single-process group barrier is unnecessary: the all-reduces serialize
|
||||
/// on the comm, and the in-place scale runs on the same null stream after.
|
||||
/// **Coalescing (KI-5 fix, Phase T11)**: instead of one tiny `ncclAllReduce`
|
||||
/// per parameter tensor (~150 serial launches for dim512 → DDP's dominant cost
|
||||
/// once T10's batched forward made compute fast), pack the grads into a few
|
||||
/// large contiguous scratch buckets and all-reduce each bucket ONCE. The packed
|
||||
/// buffer is just the concatenation of the grad tensors, so NCCL's element-wise
|
||||
/// sum over a bucket equals the per-tensor sums — the result is **bit-identical**
|
||||
/// to the un-bucketed path; only the launch/latency overhead is removed. The
|
||||
/// `1/world` average folds into one per-bucket scale. The per-bucket all-reduces
|
||||
/// are each wrapped in their own `ncclGroupStart/End` pair, one group per bucket.
|
||||
pub fn all_reduce_average_grads(&self, params: &[Var]) {
|
||||
if self.world == 1 {
|
||||
return;
|
||||
}
|
||||
// 1. Sum every grad across ranks (in place, on the null stream).
|
||||
for p in params {
|
||||
if let Some(g) = p.grad() {
|
||||
let n = g.numel();
|
||||
self.all_reduce_sum_f32_ptr(g.data_ptr() as *mut c_void, n);
|
||||
}
|
||||
// Collect this step's grads (in `params()` order) and plan buckets.
|
||||
let grads: Vec<Tensor> = params.iter().filter_map(|p| p.grad()).collect();
|
||||
if grads.is_empty() {
|
||||
return;
|
||||
}
|
||||
// 2. Average: scale each summed grad by 1/world (null-stream kernel,
|
||||
// ordered after the AllReduce that produced it).
|
||||
let buckets = plan_buckets(&grads, BUCKET_CAP_ELEMS);
|
||||
|
||||
let inv_world = 1.0 / self.world as f32;
|
||||
for p in params {
|
||||
if let Some(g) = p.grad() {
|
||||
let device = Device::Cuda(self.device);
|
||||
for bucket in &buckets {
|
||||
let total: usize = bucket.iter().map(|g| g.numel()).sum();
|
||||
// Flat scratch buffer for this bucket (fully overwritten by the pack
|
||||
// below; `cudaFree` on drop synchronizes, so it outlives its copies).
|
||||
let flat = Tensor::zeros(&[total], xtrain_tensor::DType::F32, device);
|
||||
let flat_ptr = flat.data_ptr() as *mut u8;
|
||||
// Pack: D2D-copy each grad into the bucket at its running offset.
|
||||
let mut off = 0usize;
|
||||
for g in bucket {
|
||||
let bytes = g.numel() * 4;
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_scale_inplace_f32(
|
||||
g.data_ptr() as *mut f32,
|
||||
inv_world,
|
||||
g.numel() as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
device::copy_d2d(flat_ptr.add(off), g.data_ptr(), bytes)
|
||||
.expect("pack grad bucket");
|
||||
}
|
||||
off += bytes;
|
||||
}
|
||||
// One AllReduce(sum) over the whole bucket (fused via the group), then
|
||||
// one scale by 1/world — same math as per-tensor, far fewer launches.
|
||||
ffi::check(unsafe { ffi::ncclGroupStart() }, "ncclGroupStart(bucket)");
|
||||
self.all_reduce_sum_f32_ptr(flat_ptr as *mut c_void, total);
|
||||
ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(bucket)");
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_scale_inplace_f32(
|
||||
flat_ptr as *mut f32,
|
||||
inv_world,
|
||||
total as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
// Unpack: D2D-copy each averaged slice back into its grad tensor.
|
||||
let mut off = 0usize;
|
||||
for g in bucket {
|
||||
let bytes = g.numel() * 4;
|
||||
unsafe {
|
||||
device::copy_d2d(g.data_ptr() as *mut u8, flat_ptr.add(off), bytes)
|
||||
.expect("unpack grad bucket");
|
||||
}
|
||||
off += bytes;
|
||||
}
|
||||
}
|
||||
device::synchronize().expect("grad all-reduce sync failed");
|
||||
}
|
||||
}
|
||||
|
||||
/// Target bucket size in F32 elements (~25 MB). Big enough to amortize NCCL
|
||||
/// launch latency across many params, small enough that the scratch allocation
|
||||
/// stays modest. The exact value is not load-bearing for correctness.
|
||||
const BUCKET_CAP_ELEMS: usize = 25 * 1024 * 1024 / 4;
|
||||
|
||||
/// Greedily group `grads` (in order) into buckets whose total element count stays
|
||||
/// under `cap` — except a single grad larger than `cap`, which gets its own
|
||||
/// bucket. Order is preserved so packing offsets are deterministic across ranks.
|
||||
fn plan_buckets(grads: &[Tensor], cap: usize) -> Vec<Vec<Tensor>> {
|
||||
let mut buckets: Vec<Vec<Tensor>> = Vec::new();
|
||||
let mut cur: Vec<Tensor> = Vec::new();
|
||||
let mut cur_n = 0usize;
|
||||
for g in grads {
|
||||
let n = g.numel();
|
||||
if cur_n > 0 && cur_n + n > cap {
|
||||
buckets.push(std::mem::take(&mut cur));
|
||||
cur_n = 0;
|
||||
}
|
||||
cur.push(g.clone());
|
||||
cur_n += n;
|
||||
}
|
||||
if !cur.is_empty() {
|
||||
buckets.push(cur);
|
||||
}
|
||||
buckets
|
||||
}
|
||||
|
||||
impl Drop for DdpContext {
|
||||
fn drop(&mut self) {
|
||||
if !self.comm.is_null() {
|
||||
|
||||
Reference in New Issue
Block a user