dist: coalesce grads into buckets for all-reduce (KI-5)

Replace the per-parameter eager all-reduce (~150 tiny serial NCCL calls
for dim512, DDP's dominant cost after T10's batched forward) with a
coalesced bucketed all-reduce: pack grads into a few large contiguous
scratch buffers, all-reduce each bucket once (fused via ncclGroupStart/
End), fold the 1/world average into one per-bucket scale, unpack back.

The packed buffer is the concatenation of the grad tensors, so NCCL's
element-wise sum over a bucket equals the per-tensor sums — bit-identical
to the un-bucketed path; only launch/latency overhead is removed. DDP
cross-rank param identity + loss-match are preserved.

Adds xtrain_cuda::device::copy_d2d (cudaMemcpy D2D) for the pack/unpack.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-16 09:09:44 +08:00
parent a78502e0f0
commit b8b58212dc
3 changed files with 99 additions and 23 deletions

View File

@@ -14,3 +14,15 @@ pub fn set_device(device: u32) -> Result<()> {
pub fn synchronize() -> Result<()> {
error::check(unsafe { ffi::cudaDeviceSynchronize() })
}
/// Device-to-device copy of `count` bytes (`dst <- src`) on the same GPU. Issued
/// on the null stream (like every other xtrain kernel), so it orders with the
/// surrounding work. Used by the DDP bucketed all-reduce to pack/unpack grads
/// into a flat scratch buffer.
///
/// # Safety
/// `dst`/`src` must point to at least `count` valid bytes of device memory on the
/// current device, with no overlap.
pub unsafe fn copy_d2d(dst: *mut u8, src: *const u8, count: usize) -> Result<()> {
error::check(unsafe { ffi::cudaMemcpy(dst, src, count, ffi::CUDA_MEMCPY_D2D) })
}

View File

@@ -5,6 +5,7 @@ pub type CudaStream = *mut c_void;
// Copy-direction selectors handed to `cudaMemcpy` as its final argument.
// NOTE(review): values presumably mirror the CUDA runtime's `cudaMemcpyKind`
// enum (HostToDevice = 1, DeviceToHost = 2, DeviceToDevice = 3) — confirm
// against the CUDA headers the crate links with.
/// Host-to-device copy direction.
pub const CUDA_MEMCPY_H2D: i32 = 1;
/// Device-to-host copy direction.
pub const CUDA_MEMCPY_D2H: i32 = 2;
/// Device-to-device copy direction.
pub const CUDA_MEMCPY_D2D: i32 = 3;
/// Status code reported by a CUDA call that succeeded.
pub const CUDA_SUCCESS: i32 = 0;
/// Status code for a failed allocation (presumably `cudaErrorMemoryAllocation` — confirm).
pub const CUDA_ERROR_OUT_OF_MEMORY: i32 = 2;

View File

@@ -4,10 +4,11 @@
//! rank thread binds its device, builds its own model (xtrain's `Var` graph is
//! `Rc`-based and not `Send`, so it must be constructed thread-locally — only the
//! `UniqueId` and scalar config cross the thread boundary), processes a disjoint
//! shard of the global batch, then AllReduces every parameter's `.grad()` device
//! buffer in place, averages by world size, and runs its own `GpuAdamW.step`.
//! Identical init + identical optimizer state across ranks keeps the parameters
//! consistent without ever re-syncing the weights.
//! shard of the global batch, then **coalesces every parameter's `.grad()` into a
//! few large buckets and all-reduces each bucket once** (Phase T11 — see
//! `all_reduce_average_grads`), averages by world size, and runs its own
//! `GpuAdamW.step`. Identical init + identical optimizer state across ranks keeps
//! the parameters consistent without ever re-syncing the weights.
//!
//! NCCL is issued on the legacy null stream — every xtrain kernel launches on the
//! null stream (`std::ptr::null_mut()`), so the AllReduce stays correctly ordered
@@ -26,6 +27,7 @@ use std::ffi::c_void;
use ffi::{NcclComm, NcclUniqueId};
use xtrain_autodiff::tape::Var;
use xtrain_cuda::device;
use xtrain_tensor::{Device, Tensor};
pub use ffi::NcclUniqueId as UniqueId;
@@ -101,7 +103,7 @@ impl DdpContext {
}
/// AllReduce every parameter's `.grad()` across ranks and divide by `world`,
/// the one collective DDP needs per step.
/// the one collective DDP needs per step — **coalesced (bucketed)**.
///
/// Each rank ran forward+backward on its own shard of `b` sequences, so
/// `.grad()` holds the SUM over that shard (the tape's fan-out rule). After
@@ -112,38 +114,99 @@ impl DdpContext {
/// mean gradient the single-GPU loop computes from a batch of `B_global`.
/// Params without a grad are skipped.
///
/// A single-process group barrier is unnecessary: the all-reduces serialize
/// on the comm, and the in-place scale runs on the same null stream after.
/// **Coalescing (KI-5 fix, Phase T11)**: instead of one tiny `ncclAllReduce`
/// per parameter tensor (~150 serial launches for dim512 → DDP's dominant cost
/// once T10's batched forward made compute fast), pack the grads into a few
/// large contiguous scratch buckets and all-reduce each bucket ONCE. The packed
/// buffer is just the concatenation of the grad tensors, so NCCL's element-wise
/// sum over a bucket equals the per-tensor sums — the result is **bit-identical**
/// to the un-bucketed path; only the launch/latency overhead is removed. The
/// `1/world` average folds into one per-bucket scale. Each bucket's all-reduce
/// is issued inside its own `ncclGroupStart/End` pair, so buckets are launched
/// one after another (one collective per group) rather than fused together.
pub fn all_reduce_average_grads(&self, params: &[Var]) {
// NOTE(review): this listing appears to be a rendered diff in which lines of
// the earlier per-parameter path and of the bucketed path are shown together
// without +/- markers; the review notes below concern the bucketed path.
// A single-rank group has nothing to reduce.
if self.world == 1 {
return;
}
// 1. Sum every grad across ranks (in place, on the null stream).
for p in params {
if let Some(g) = p.grad() {
let n = g.numel();
self.all_reduce_sum_f32_ptr(g.data_ptr() as *mut c_void, n);
}
// Collect this step's grads (in `params()` order) and plan buckets.
let grads: Vec<Tensor> = params.iter().filter_map(|p| p.grad()).collect();
if grads.is_empty() {
return;
}
// 2. Average: scale each summed grad by 1/world (null-stream kernel,
// ordered after the AllReduce that produced it).
let buckets = plan_buckets(&grads, BUCKET_CAP_ELEMS);
let inv_world = 1.0 / self.world as f32;
for p in params {
if let Some(g) = p.grad() {
// This local is a value, so it does not hide the `device::` module path used
// further down (values and modules live in separate namespaces).
let device = Device::Cuda(self.device);
for bucket in &buckets {
// Element count of the whole bucket = length of its flat scratch buffer.
let total: usize = bucket.iter().map(|g| g.numel()).sum();
// Flat scratch buffer for this bucket (fully overwritten by the pack
// below; `cudaFree` on drop synchronizes, so it outlives its copies).
// NOTE(review): presumably `Tensor::zeros` also clears the buffer, which the
// pack makes redundant, and a fresh allocation happens per bucket per call —
// a cached scratch buffer may be worth it; confirm with a profile.
let flat = Tensor::zeros(&[total], xtrain_tensor::DType::F32, device);
let flat_ptr = flat.data_ptr() as *mut u8;
// Pack: D2D-copy each grad into the bucket at its running offset.
let mut off = 0usize;
for g in bucket {
// 4 = bytes per f32 element; assumes every grad is F32 like the scratch
// buffer — TODO confirm (a different dtype would mis-size the copy).
let bytes = g.numel() * 4;
unsafe {
xtrain_cuda::ffi::launch_scale_inplace_f32(
g.data_ptr() as *mut f32,
inv_world,
g.numel() as i32,
std::ptr::null_mut(),
);
device::copy_d2d(flat_ptr.add(off), g.data_ptr(), bytes)
.expect("pack grad bucket");
}
off += bytes;
}
// One AllReduce(sum) over the whole bucket (fused via the group), then
// one scale by 1/world — same math as per-tensor, far fewer launches.
// NOTE(review): the group below opens and closes around a single all-reduce
// inside the per-bucket loop, so every bucket forms its own group; buckets
// are not grouped with one another.
// NOTE(review): matching the per-tensor result bit-for-bit presumably relies
// on NCCL adding the ranks' contributions in the same order for every element
// regardless of buffer size — confirm for world sizes above 2.
ffi::check(unsafe { ffi::ncclGroupStart() }, "ncclGroupStart(bucket)");
self.all_reduce_sum_f32_ptr(flat_ptr as *mut c_void, total);
ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(bucket)");
unsafe {
// NOTE(review): `total as i32` wraps silently above i32::MAX elements. Buckets
// holding several grads are bounded by the cap, so only a single oversized
// grad could reach that size.
xtrain_cuda::ffi::launch_scale_inplace_f32(
flat_ptr as *mut f32,
inv_world,
total as i32,
std::ptr::null_mut(),
);
}
// Unpack: D2D-copy each averaged slice back into its grad tensor.
let mut off = 0usize;
for g in bucket {
let bytes = g.numel() * 4;
unsafe {
device::copy_d2d(g.data_ptr() as *mut u8, flat_ptr.add(off), bytes)
.expect("unpack grad bucket");
}
off += bytes;
}
}
// Block the host until the device has finished the work queued above.
device::synchronize().expect("grad all-reduce sync failed");
}
}
/// Bucket capacity counted in `f32` elements: 25 MiB worth of 4-byte values,
/// i.e. 6_553_600 elements. Large enough to spread NCCL launch latency over many
/// parameters while keeping the scratch allocation modest. Correctness does not
/// depend on the exact figure.
const BUCKET_CAP_ELEMS: usize = (25usize << 20) / std::mem::size_of::<f32>();
/// Splits `grads` into consecutive buckets, greedily and without reordering.
///
/// A bucket is closed as soon as adding the next grad would lift its element
/// count above `cap`, so every bucket totals at most `cap` elements. The one
/// exception is a grad that is itself larger than `cap`: it is never split and
/// shares its bucket with no other non-empty grad. Because input order is kept,
/// every rank derives the same packing offsets from the same grad list.
///
/// Returns the buckets in order; each holds clones of the grad handles it covers.
/// An empty `grads` slice yields no buckets.
fn plan_buckets(grads: &[Tensor], cap: usize) -> Vec<Vec<Tensor>> {
    let mut sealed: Vec<Vec<Tensor>> = Vec::new();
    let mut open: Vec<Tensor> = Vec::new();
    let mut open_elems = 0usize;

    for grad in grads.iter() {
        let elems = grad.numel();
        // Close the open bucket first when this grad would push it past `cap`.
        // A bucket that still counts zero elements is never closed here, which
        // is what lets an oversized grad land somewhere.
        let must_seal = open_elems > 0 && open_elems + elems > cap;
        if must_seal {
            sealed.push(std::mem::take(&mut open));
            open_elems = 0;
        }
        open.push(grad.clone());
        open_elems += elems;
    }

    // Flush whatever is still open after the last grad.
    if open.is_empty() {
        return sealed;
    }
    sealed.push(open);
    sealed
}
impl Drop for DdpContext {
fn drop(&mut self) {
if !self.comm.is_null() {