train+ddp: micro-batch gradient accumulation (--accum-steps)
Accumulate gradients over N micro-batches, then take one AdamW step + zero_grad, giving an effective batch of N × the micro-batch size at the activation-memory cost of a single micro-batch. Each micro-loss is scaled by 1/N before backward (the tape SUM-accumulates the scaled grads), so the gradient at the accumulation boundary equals that of a single step over an N× larger batch. With accum == 1 the scale is skipped, so the result is bit-identical to the pre-T16 path.

DDP: the cross-rank all-reduce fires ONLY at the accumulation boundary (intermediate micro-steps are local-only, with no NCCL traffic); the /world average is independent of the per-micro 1/N scaling, so the boundary gradient is the mean over the effective global batch.

Adds a new --accum-steps flag to both train binaries; the effective batch is printed.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -74,6 +74,10 @@ fn main() {
|
||||
// Optimization knobs (mirror bin/train).
|
||||
let steps: usize = flag(&args, "--steps", 100);
|
||||
let batch: usize = flag(&args, "--batch", 16);
|
||||
// Micro-batch gradient accumulation (Phase T16): effective global batch =
|
||||
// accum_steps × batch, all-reducing only at the accumulation boundary. Default
|
||||
// 1 = no accumulation (bit-identical to the pre-T16 DDP path).
|
||||
let accum_steps: usize = flag(&args, "--accum-steps", 1).max(1);
|
||||
let seq_len: usize = flag(&args, "--seq", 64);
|
||||
let max_lr: f32 = flag(&args, "--max-lr", 3e-3);
|
||||
let min_lr: f32 = flag(&args, "--min-lr", max_lr * 0.1);
|
||||
@@ -147,6 +151,7 @@ fn main() {
|
||||
let dcfg = DdpConfig {
|
||||
seq_len,
|
||||
batch_size: batch,
|
||||
accum_steps,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr,
|
||||
@@ -164,8 +169,9 @@ fn main() {
|
||||
};
|
||||
|
||||
println!(
|
||||
"training: {steps} steps, seq {seq_len}, global batch {batch}, lr {max_lr:.1e}→{min_lr:.1e}, \
|
||||
eval every {eval_every}"
|
||||
"training: {steps} steps, seq {seq_len}, global batch {batch} × accum {accum_steps} = \
|
||||
effective global batch {}, lr {max_lr:.1e}→{min_lr:.1e}, eval every {eval_every}",
|
||||
batch * accum_steps
|
||||
);
|
||||
|
||||
if bf16 {
|
||||
|
||||
@@ -35,6 +35,13 @@ pub struct DdpConfig {
|
||||
pub seq_len: usize,
|
||||
/// Global batch size; must be divisible by the world size.
|
||||
pub batch_size: usize,
|
||||
/// Micro-batch gradient accumulation (Phase T16): each optimizer step
|
||||
/// accumulates grads over `accum_steps` micro-batches, giving an EFFECTIVE
|
||||
/// global batch of `accum_steps × batch_size`. The cross-rank all-reduce
|
||||
/// fires ONLY at the accumulation boundary (after the last micro-step) —
|
||||
/// intermediate micro-steps skip the NCCL collective entirely. `1` = no
|
||||
/// accumulation (bit-identical to the pre-T16 DDP path).
|
||||
pub accum_steps: usize,
|
||||
pub steps: usize,
|
||||
pub schedule: LrSchedule,
|
||||
pub weight_decay: f32,
|
||||
@@ -96,6 +103,7 @@ pub fn train_rank(
|
||||
// (sum across ranks, /world) then gives Σ_global/(world·b_local) = Σ_global/
|
||||
// B_global — already the global-batch mean — so the clip pre-scale is 1.0.
|
||||
let batch_local = cfg.batch_size / ctx.world;
|
||||
let accum = cfg.accum_steps.max(1);
|
||||
let start = Instant::now();
|
||||
let mut tokens_seen: u64 = 0;
|
||||
// Rank 0 owns the held-out eval + best-val checkpoint (params are identical
|
||||
@@ -105,10 +113,18 @@ pub fn train_rank(
|
||||
for step in 0..cfg.steps {
|
||||
let lr = cfg.schedule.lr(step);
|
||||
|
||||
// Draw the whole global batch from the shared RNG (same on every rank);
|
||||
// collect only this rank's shard (global index % world == rank) and run it
|
||||
// as ONE batched forward/backward. The union of shards == the single-GPU
|
||||
// batch; each rank's backward yields its local mean (Σ_local / b_local).
|
||||
// Accumulate grads over `accum` micro-batches, then ONE optimizer step
|
||||
// (Phase T16). Per micro-batch: draw the whole micro global batch from the
|
||||
// shared RNG (same on every rank), keep only this rank's shard (global index
|
||||
// % world == rank), run it as ONE batched forward/backward. Each micro-loss
|
||||
// is scaled by 1/accum before backward (the tape SUM-accumulates the scaled
|
||||
// grads across the `accum` micro-backwards) so the boundary grad equals a
|
||||
// single step over an `accum × batch_size` global batch. `accum == 1` skips
|
||||
// the scale → bit-identical to the pre-T16 DDP path. The cross-rank
|
||||
// all-reduce fires ONLY after the last micro-step (intermediate micro-steps
|
||||
// are local-only, no NCCL).
|
||||
let mut local_sum = 0.0f32; // Σ over micro of (local_mean · b_local)
|
||||
for _ in 0..accum {
|
||||
let mut inputs = Vec::with_capacity(batch_local);
|
||||
let mut targets_v = Vec::with_capacity(batch_local);
|
||||
for i in 0..cfg.batch_size {
|
||||
@@ -121,20 +137,27 @@ pub fn train_rank(
|
||||
let ids = batched_ids_tensor(&inputs, device);
|
||||
let targets = batched_ids_tensor(&targets_v, device);
|
||||
let loss = model.loss_batched(&ids, &targets, batch_local);
|
||||
let local_mean = read_scalar(&loss); // Σ_local / b_local
|
||||
local_sum += read_scalar(&loss) * batch_local as f32; // local mean·b_local
|
||||
if accum == 1 {
|
||||
loss.backward();
|
||||
} else {
|
||||
xtrain_autodiff::ops::scale(&loss, 1.0 / accum as f32).backward();
|
||||
}
|
||||
tokens_seen += (batch_local * cfg.seq_len) as u64;
|
||||
}
|
||||
|
||||
// AllReduce(sum) + /world the grads → every rank holds Σ_global/B_global
|
||||
// (local means summed over ranks, /world = global mean). See note above.
|
||||
// Accumulation boundary: ONE AllReduce(sum) + /world over the accumulated
|
||||
// grads → every rank holds the effective-batch (accum·B_global) mean grad
|
||||
// (the per-micro 1/accum scaling is already baked into each backward; the
|
||||
// /world here is orthogonal to accum). Intermediate micro-steps issued NO
|
||||
// NCCL — only this single boundary collective per optimizer step.
|
||||
ctx.all_reduce_average_grads(¶ms);
|
||||
// Reported loss = global mean: sum the per-rank local sums (= mean·b_local)
|
||||
// across ranks, /B_global. With equal b_local this is mean over ranks.
|
||||
let step_loss =
|
||||
all_reduce_loss(ctx, local_mean * batch_local as f32) / cfg.batch_size as f32;
|
||||
// Reported loss = effective-batch mean: AllReduce(sum) the per-rank local
|
||||
// sums across ranks, /(accum·B_global).
|
||||
let step_loss = all_reduce_loss(ctx, local_sum) / (accum * cfg.batch_size) as f32;
|
||||
losses.push(step_loss);
|
||||
|
||||
// Grads are already the global-batch mean — just clip (pre-scale 1.0).
|
||||
// Grads are already the effective-batch mean — just clip (pre-scale 1.0).
|
||||
let gnorm = clip_grad_norm_gpu(¶ms, cfg.max_grad_norm, 1.0);
|
||||
opt.step(lr, ¶ms);
|
||||
for p in ¶ms {
|
||||
|
||||
@@ -101,6 +101,10 @@ fn main() {
|
||||
// Optimization knobs.
|
||||
let steps: usize = flag(&args, "--steps", 2000);
|
||||
let batch_size: usize = flag(&args, "--batch", 8);
|
||||
// Micro-batch gradient accumulation (Phase T16): effective batch =
|
||||
// accum_steps × batch, at one micro-batch's activation-memory cost. Default 1
|
||||
// = no accumulation (bit-identical to the pre-T16 path).
|
||||
let accum_steps: usize = flag(&args, "--accum-steps", 1).max(1);
|
||||
let seq_len: usize = flag(&args, "--seq", 64);
|
||||
let max_lr: f32 = flag(&args, "--max-lr", 3e-3);
|
||||
let min_lr: f32 = flag(&args, "--min-lr", max_lr * 0.1);
|
||||
@@ -201,6 +205,7 @@ fn main() {
|
||||
let tcfg = TrainConfig {
|
||||
seq_len,
|
||||
batch_size,
|
||||
accum_steps,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr,
|
||||
@@ -219,10 +224,13 @@ fn main() {
|
||||
};
|
||||
|
||||
println!(
|
||||
"training: {} steps, seq {}, batch {}, lr {:.1e}→{:.1e}, eval every {}",
|
||||
"training: {} steps, seq {}, batch {} × accum {} = effective batch {}, \
|
||||
lr {:.1e}→{:.1e}, eval every {}",
|
||||
tcfg.steps,
|
||||
tcfg.seq_len,
|
||||
tcfg.batch_size,
|
||||
tcfg.accum_steps,
|
||||
tcfg.batch_size * tcfg.accum_steps,
|
||||
tcfg.schedule.max_lr,
|
||||
tcfg.schedule.min_lr,
|
||||
tcfg.eval_every
|
||||
|
||||
@@ -27,6 +27,12 @@ use crate::schedule::LrSchedule;
|
||||
pub struct TrainConfig {
|
||||
pub seq_len: usize,
|
||||
pub batch_size: usize,
|
||||
/// Micro-batch gradient accumulation (Phase T16): each optimizer step
|
||||
/// accumulates grads over `accum_steps` micro-batches of `batch_size`
|
||||
/// sequences, giving an EFFECTIVE batch of `accum_steps × batch_size` at the
|
||||
/// activation-memory cost of a single micro-batch. `1` = no accumulation
|
||||
/// (bit-identical to the pre-T16 path).
|
||||
pub accum_steps: usize,
|
||||
pub steps: usize,
|
||||
pub schedule: LrSchedule,
|
||||
pub weight_decay: f32,
|
||||
@@ -74,12 +80,19 @@ pub fn train(
|
||||
// Best-val checkpointing only kicks in when we actually evaluate.
|
||||
let track_best = valid.is_some() && cfg.eval_every > 0;
|
||||
|
||||
let accum = cfg.accum_steps.max(1);
|
||||
for step in 0..cfg.steps {
|
||||
let lr = cfg.schedule.lr(step);
|
||||
|
||||
// Sample `batch_size` sequences and run them as ONE batched forward/
|
||||
// backward. The CE mean over all batch*seq rows is the batch-mean loss, so
|
||||
// backward already yields the batch-mean gradient (clip pre-scale = 1.0).
|
||||
// Accumulate grads over `accum` micro-batches of `batch_size` sequences,
|
||||
// then take ONE optimizer step (Phase T16). Each micro-batch is ONE batched
|
||||
// forward/backward; its loss is the CE mean over batch*seq rows, so backward
|
||||
// yields that micro-batch's mean grad. To make the SUM over `accum` micro-
|
||||
// batches equal a single step over an `accum × batch` batch, each micro-loss
|
||||
// is scaled by 1/accum before backward (the tape SUM-accumulates the scaled
|
||||
// grads). `accum == 1` skips the scale entirely → bit-identical to pre-T16.
|
||||
let mut step_loss_sum = 0.0f32;
|
||||
for _ in 0..accum {
|
||||
let mut inputs = Vec::with_capacity(cfg.batch_size);
|
||||
let mut targets_v = Vec::with_capacity(cfg.batch_size);
|
||||
for _ in 0..cfg.batch_size {
|
||||
@@ -90,12 +103,20 @@ pub fn train(
|
||||
let ids = batched_ids_tensor(&inputs, device);
|
||||
let targets = batched_ids_tensor(&targets_v, device);
|
||||
let loss = model.loss_batched(&ids, &targets, cfg.batch_size);
|
||||
let step_loss = read_scalar(&loss);
|
||||
step_loss_sum += read_scalar(&loss);
|
||||
if accum == 1 {
|
||||
loss.backward();
|
||||
} else {
|
||||
xtrain_autodiff::ops::scale(&loss, 1.0 / accum as f32).backward();
|
||||
}
|
||||
tokens_seen += (cfg.batch_size * cfg.seq_len) as u64;
|
||||
}
|
||||
// Reported loss = mean over the effective batch = mean of the raw micro
|
||||
// losses (each is itself a micro-batch mean of equal size).
|
||||
let step_loss = step_loss_sum / accum as f32;
|
||||
losses.push(step_loss);
|
||||
|
||||
// Backward already produced the batch-mean gradient — just clip it.
|
||||
// Backward already produced the effective-batch mean gradient — just clip.
|
||||
let gnorm = clip_grad_norm_gpu(¶ms, cfg.max_grad_norm, 1.0);
|
||||
opt.step(lr, ¶ms);
|
||||
for p in ¶ms {
|
||||
|
||||
Reference in New Issue
Block a user