Merge t16-grad-accum into main
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> # Conflicts: # README.md # docs/evolution.md
This commit is contained in:
10
README.md
10
README.md
@@ -51,11 +51,17 @@ Each phase: design doc + implementation + tests + a scoped commit (see [`docs/`]
|
|||||||
| **T12** | **bf16 mixed precision** (fp32 master, fixes KI-2) | dim768 OOM solved; −29% mem |
|
| **T12** | **bf16 mixed precision** (fp32 master, fixes KI-2) | dim768 OOM solved; −29% mem |
|
||||||
| **T13** | **activation recompute** / checkpointing (fixes KI-3) | dim1024 fits; grads bit-identical |
|
| **T13** | **activation recompute** / checkpointing (fixes KI-3) | dim1024 fits; grads bit-identical |
|
||||||
| **T14** | **fused flash-attention** kernel (online softmax, no materialized N×N; opt-in `--flash`) | peak mem −16%@1k / −23%@2k seq; flash==composed (grads/PyTorch) |
|
| **T14** | **fused flash-attention** kernel (online softmax, no materialized N×N; opt-in `--flash`) | peak mem −16%@1k / −23%@2k seq; flash==composed (grads/PyTorch) |
|
||||||
|
| **T16** | **gradient accumulation** (`--accum-steps`; DDP all-reduces only at the boundary) | equiv to N× big batch (grad 3.8e-5); same effective-64 batch 27.7GB→7.2GB (−74%) |
|
||||||
|
| **T18** | **dropout** (hand counter-based device RNG + mask, inverted scaling, train/eval switch) | fixed-seed grad-check; **p=0 bit-identical**; recompute-safe |
|
||||||
|
|
||||||
The four performance fixes (T10–T13) each removed a real bottleneck — see
|
The four performance fixes (T10–T13) each removed a real bottleneck — see
|
||||||
[`docs/known-issues.md`](docs/known-issues.md). **Phase 2 (systems-stack depth, T14–)**
|
[`docs/known-issues.md`](docs/known-issues.md). **Phase 2 (systems-stack depth, T14–)**
|
||||||
revisits hand-writing deferred training-stack features; T14 = the fused
|
revisits hand-writing deferred training-stack features: T14 = the fused
|
||||||
flash-attention kernel ([`docs/13-flash-attention.md`](docs/13-flash-attention.md)).
|
flash-attention kernel ([`docs/13-flash-attention.md`](docs/13-flash-attention.md));
|
||||||
|
T16 = micro-batch gradient accumulation ([`docs/15-grad-accum.md`](docs/15-grad-accum.md)),
|
||||||
|
which decouples the effective batch from activation memory (memory tracks the micro-batch,
|
||||||
|
not N×); T18 = dropout ([`docs/17-dropout.md`](docs/17-dropout.md), hand counter-based
|
||||||
|
device RNG + mask, inverted scaling, train/eval switch).
|
||||||
|
|
||||||
## The scaling study — v0 → v8
|
## The scaling study — v0 → v8
|
||||||
|
|
||||||
|
|||||||
@@ -74,6 +74,10 @@ fn main() {
|
|||||||
// Optimization knobs (mirror bin/train).
|
// Optimization knobs (mirror bin/train).
|
||||||
let steps: usize = flag(&args, "--steps", 100);
|
let steps: usize = flag(&args, "--steps", 100);
|
||||||
let batch: usize = flag(&args, "--batch", 16);
|
let batch: usize = flag(&args, "--batch", 16);
|
||||||
|
// Micro-batch gradient accumulation (Phase T16): effective global batch =
|
||||||
|
// accum_steps × batch, all-reducing only at the accumulation boundary. Default
|
||||||
|
// 1 = no accumulation (bit-identical to the pre-T16 DDP path).
|
||||||
|
let accum_steps: usize = flag(&args, "--accum-steps", 1).max(1);
|
||||||
let seq_len: usize = flag(&args, "--seq", 64);
|
let seq_len: usize = flag(&args, "--seq", 64);
|
||||||
let max_lr: f32 = flag(&args, "--max-lr", 3e-3);
|
let max_lr: f32 = flag(&args, "--max-lr", 3e-3);
|
||||||
let min_lr: f32 = flag(&args, "--min-lr", max_lr * 0.1);
|
let min_lr: f32 = flag(&args, "--min-lr", max_lr * 0.1);
|
||||||
@@ -150,6 +154,7 @@ fn main() {
|
|||||||
let dcfg = DdpConfig {
|
let dcfg = DdpConfig {
|
||||||
seq_len,
|
seq_len,
|
||||||
batch_size: batch,
|
batch_size: batch,
|
||||||
|
accum_steps,
|
||||||
steps,
|
steps,
|
||||||
schedule: LrSchedule {
|
schedule: LrSchedule {
|
||||||
max_lr,
|
max_lr,
|
||||||
@@ -167,8 +172,9 @@ fn main() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"training: {steps} steps, seq {seq_len}, global batch {batch}, lr {max_lr:.1e}→{min_lr:.1e}, \
|
"training: {steps} steps, seq {seq_len}, global batch {batch} × accum {accum_steps} = \
|
||||||
eval every {eval_every}"
|
effective global batch {}, lr {max_lr:.1e}→{min_lr:.1e}, eval every {eval_every}",
|
||||||
|
batch * accum_steps
|
||||||
);
|
);
|
||||||
|
|
||||||
if bf16 {
|
if bf16 {
|
||||||
|
|||||||
@@ -35,6 +35,13 @@ pub struct DdpConfig {
|
|||||||
pub seq_len: usize,
|
pub seq_len: usize,
|
||||||
/// Global batch size; must be divisible by the world size.
|
/// Global batch size; must be divisible by the world size.
|
||||||
pub batch_size: usize,
|
pub batch_size: usize,
|
||||||
|
/// Micro-batch gradient accumulation (Phase T16): each optimizer step
|
||||||
|
/// accumulates grads over `accum_steps` micro-batches, giving an EFFECTIVE
|
||||||
|
/// global batch of `accum_steps × batch_size`. The cross-rank all-reduce
|
||||||
|
/// fires ONLY at the accumulation boundary (after the last micro-step) —
|
||||||
|
/// intermediate micro-steps skip the NCCL collective entirely. `1` = no
|
||||||
|
/// accumulation (bit-identical to the pre-T16 DDP path).
|
||||||
|
pub accum_steps: usize,
|
||||||
pub steps: usize,
|
pub steps: usize,
|
||||||
pub schedule: LrSchedule,
|
pub schedule: LrSchedule,
|
||||||
pub weight_decay: f32,
|
pub weight_decay: f32,
|
||||||
@@ -96,6 +103,7 @@ pub fn train_rank(
|
|||||||
// (sum across ranks, /world) then gives Σ_global/(world·b_local) = Σ_global/
|
// (sum across ranks, /world) then gives Σ_global/(world·b_local) = Σ_global/
|
||||||
// B_global — already the global-batch mean — so the clip pre-scale is 1.0.
|
// B_global — already the global-batch mean — so the clip pre-scale is 1.0.
|
||||||
let batch_local = cfg.batch_size / ctx.world;
|
let batch_local = cfg.batch_size / ctx.world;
|
||||||
|
let accum = cfg.accum_steps.max(1);
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let mut tokens_seen: u64 = 0;
|
let mut tokens_seen: u64 = 0;
|
||||||
// Rank 0 owns the held-out eval + best-val checkpoint (params are identical
|
// Rank 0 owns the held-out eval + best-val checkpoint (params are identical
|
||||||
@@ -105,36 +113,51 @@ pub fn train_rank(
|
|||||||
for step in 0..cfg.steps {
|
for step in 0..cfg.steps {
|
||||||
let lr = cfg.schedule.lr(step);
|
let lr = cfg.schedule.lr(step);
|
||||||
|
|
||||||
// Draw the whole global batch from the shared RNG (same on every rank);
|
// Accumulate grads over `accum` micro-batches, then ONE optimizer step
|
||||||
// collect only this rank's shard (global index % world == rank) and run it
|
// (Phase T16). Per micro-batch: draw the whole micro global batch from the
|
||||||
// as ONE batched forward/backward. The union of shards == the single-GPU
|
// shared RNG (same on every rank), keep only this rank's shard (global index
|
||||||
// batch; each rank's backward yields its local mean (Σ_local / b_local).
|
// % world == rank), run it as ONE batched forward/backward. Each micro-loss
|
||||||
let mut inputs = Vec::with_capacity(batch_local);
|
// is scaled by 1/accum before backward (the tape SUM-accumulates the scaled
|
||||||
let mut targets_v = Vec::with_capacity(batch_local);
|
// grads across the `accum` micro-backwards) so the boundary grad equals a
|
||||||
for i in 0..cfg.batch_size {
|
// single step over an `accum × batch_size` global batch. `accum == 1` skips
|
||||||
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
// the scale → bit-identical to the pre-T16 DDP path. The cross-rank
|
||||||
if i % ctx.world == ctx.rank {
|
// all-reduce fires ONLY after the last micro-step (intermediate micro-steps
|
||||||
inputs.push(input);
|
// are local-only, no NCCL).
|
||||||
targets_v.push(target);
|
let mut local_sum = 0.0f32; // Σ over micro of (local_mean · b_local)
|
||||||
|
for _ in 0..accum {
|
||||||
|
let mut inputs = Vec::with_capacity(batch_local);
|
||||||
|
let mut targets_v = Vec::with_capacity(batch_local);
|
||||||
|
for i in 0..cfg.batch_size {
|
||||||
|
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
||||||
|
if i % ctx.world == ctx.rank {
|
||||||
|
inputs.push(input);
|
||||||
|
targets_v.push(target);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
let ids = batched_ids_tensor(&inputs, device);
|
||||||
|
let targets = batched_ids_tensor(&targets_v, device);
|
||||||
|
let loss = model.loss_batched(&ids, &targets, batch_local);
|
||||||
|
local_sum += read_scalar(&loss) * batch_local as f32; // local mean·b_local
|
||||||
|
if accum == 1 {
|
||||||
|
loss.backward();
|
||||||
|
} else {
|
||||||
|
xtrain_autodiff::ops::scale(&loss, 1.0 / accum as f32).backward();
|
||||||
|
}
|
||||||
|
tokens_seen += (batch_local * cfg.seq_len) as u64;
|
||||||
}
|
}
|
||||||
let ids = batched_ids_tensor(&inputs, device);
|
|
||||||
let targets = batched_ids_tensor(&targets_v, device);
|
|
||||||
let loss = model.loss_batched(&ids, &targets, batch_local);
|
|
||||||
let local_mean = read_scalar(&loss); // Σ_local / b_local
|
|
||||||
loss.backward();
|
|
||||||
tokens_seen += (batch_local * cfg.seq_len) as u64;
|
|
||||||
|
|
||||||
// AllReduce(sum) + /world the grads → every rank holds Σ_global/B_global
|
// Accumulation boundary: ONE AllReduce(sum) + /world over the accumulated
|
||||||
// (local means summed over ranks, /world = global mean). See note above.
|
// grads → every rank holds the effective-batch (accum·B_global) mean grad
|
||||||
|
// (the per-micro 1/accum scaling is already baked into each backward; the
|
||||||
|
// /world here is orthogonal to accum). Intermediate micro-steps issued NO
|
||||||
|
// NCCL — only this single boundary collective per optimizer step.
|
||||||
ctx.all_reduce_average_grads(¶ms);
|
ctx.all_reduce_average_grads(¶ms);
|
||||||
// Reported loss = global mean: sum the per-rank local sums (= mean·b_local)
|
// Reported loss = effective-batch mean: AllReduce(sum) the per-rank local
|
||||||
// across ranks, /B_global. With equal b_local this is mean over ranks.
|
// sums across ranks, /(accum·B_global).
|
||||||
let step_loss =
|
let step_loss = all_reduce_loss(ctx, local_sum) / (accum * cfg.batch_size) as f32;
|
||||||
all_reduce_loss(ctx, local_mean * batch_local as f32) / cfg.batch_size as f32;
|
|
||||||
losses.push(step_loss);
|
losses.push(step_loss);
|
||||||
|
|
||||||
// Grads are already the global-batch mean — just clip (pre-scale 1.0).
|
// Grads are already the effective-batch mean — just clip (pre-scale 1.0).
|
||||||
let gnorm = clip_grad_norm_gpu(¶ms, cfg.max_grad_norm, 1.0);
|
let gnorm = clip_grad_norm_gpu(¶ms, cfg.max_grad_norm, 1.0);
|
||||||
opt.step(lr, ¶ms);
|
opt.step(lr, ¶ms);
|
||||||
for p in ¶ms {
|
for p in ¶ms {
|
||||||
|
|||||||
@@ -94,6 +94,7 @@ fn ddp_matches_single_gpu_and_params_consistent() {
|
|||||||
let dcfg = DdpConfig {
|
let dcfg = DdpConfig {
|
||||||
seq_len: 32,
|
seq_len: 32,
|
||||||
batch_size: 8, // global; 4 per rank with world=2
|
batch_size: 8, // global; 4 per rank with world=2
|
||||||
|
accum_steps: 1,
|
||||||
steps,
|
steps,
|
||||||
schedule: LrSchedule {
|
schedule: LrSchedule {
|
||||||
max_lr: 3e-3,
|
max_lr: 3e-3,
|
||||||
@@ -195,6 +196,127 @@ fn ddp_matches_single_gpu_and_params_consistent() {
|
|||||||
assert!(max_sdiff < 1e-2, "DDP params diverged from single-GPU");
|
assert!(max_sdiff < 1e-2, "DDP params diverged from single-GPU");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ddp_with_accum_matches_single_gpu_big_batch() {
|
||||||
|
// T16: DDP + gradient accumulation must match a single-GPU big-batch baseline
|
||||||
|
// of the SAME effective batch. world=2, accum=2, per-rank micro-batch 2 →
|
||||||
|
// effective global batch = world·accum·b_local = 2·2·2 = 8. Compared against a
|
||||||
|
// single-GPU run with batch 8, accum 1 (the big-batch baseline). The all-reduce
|
||||||
|
// fires only at the accumulation boundary (once per optimizer step, not per
|
||||||
|
// micro-step) — enforced by the train_rank implementation; the load-bearing
|
||||||
|
// gate here is that loss + final params still match the big-batch baseline.
|
||||||
|
let world = 2usize;
|
||||||
|
if device::device_count().unwrap_or(0) < world as i32 {
|
||||||
|
eprintln!("skip: need >= {world} GPUs");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let vocab = 64usize;
|
||||||
|
let cfg = test_config(vocab);
|
||||||
|
let corpus = synth_corpus(vocab, 4096);
|
||||||
|
let steps = 20usize;
|
||||||
|
let effective_batch = 8usize; // world(2) · accum(2) · b_local(2)
|
||||||
|
let sched = LrSchedule {
|
||||||
|
max_lr: 3e-3,
|
||||||
|
min_lr: 3e-4,
|
||||||
|
warmup: 3,
|
||||||
|
total: steps,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Single-GPU big-batch baseline: world=1, accum=1, batch = effective_batch.
|
||||||
|
let baseline_cfg = DdpConfig {
|
||||||
|
seq_len: 32,
|
||||||
|
batch_size: effective_batch,
|
||||||
|
accum_steps: 1,
|
||||||
|
steps,
|
||||||
|
schedule: sched,
|
||||||
|
weight_decay: 0.1,
|
||||||
|
max_grad_norm: 1.0,
|
||||||
|
log_every: 1_000_000,
|
||||||
|
seed: 7,
|
||||||
|
eval_every: 0,
|
||||||
|
eval_batches: 0,
|
||||||
|
ckpt_path: None,
|
||||||
|
};
|
||||||
|
let (single_losses, single_params) = run_single_gpu(cfg, &corpus, &baseline_cfg);
|
||||||
|
|
||||||
|
// DDP + accumulation: world=2, accum=2 → per-rank micro-batch = batch/world = 2.
|
||||||
|
let ddp_cfg = DdpConfig {
|
||||||
|
batch_size: effective_batch / 2, // per-step global batch; ×accum = effective
|
||||||
|
accum_steps: 2,
|
||||||
|
..baseline_cfg
|
||||||
|
};
|
||||||
|
let devices = [0u32, 1u32];
|
||||||
|
let id = get_unique_id();
|
||||||
|
let results: Vec<(Vec<f32>, Vec<Vec<f32>>)> = std::thread::scope(|s| {
|
||||||
|
let handles: Vec<_> = devices
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(rank, &dev)| {
|
||||||
|
let ddp_cfg = ddp_cfg.clone();
|
||||||
|
let corpus = &corpus;
|
||||||
|
s.spawn(move || {
|
||||||
|
let ctx = DdpContext::init(rank, world, id, dev);
|
||||||
|
let device = Device::Cuda(dev);
|
||||||
|
let model = build_model(cfg, device);
|
||||||
|
let res = train_rank(&ctx, &model, device, corpus, None, &ddp_cfg);
|
||||||
|
let host = model
|
||||||
|
.params()
|
||||||
|
.iter()
|
||||||
|
.map(|p| p.value().to_device(Device::Cpu).as_slice::<f32>().to_vec())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
(res.losses, host)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
handles.into_iter().map(|h| h.join().unwrap()).collect()
|
||||||
|
});
|
||||||
|
|
||||||
|
let (ddp_losses, ddp_p0) = &results[0];
|
||||||
|
let (_, ddp_p1) = &results[1];
|
||||||
|
|
||||||
|
// (a) Loss trajectory matches the single-GPU big-batch baseline.
|
||||||
|
let mut max_rel = 0.0f32;
|
||||||
|
for (s, d) in single_losses.iter().zip(ddp_losses) {
|
||||||
|
max_rel = max_rel.max((s - d).abs() / s.abs().max(1e-6));
|
||||||
|
}
|
||||||
|
println!(
|
||||||
|
"DDP+accum(w2·a2·b2) vs single-GPU big-batch(8): single[last]={:.6} ddp[last]={:.6} max_rel={max_rel:.2e}",
|
||||||
|
single_losses.last().unwrap(),
|
||||||
|
ddp_losses.last().unwrap()
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
max_rel < 1e-3,
|
||||||
|
"DDP+accum loss diverged from big-batch baseline: {max_rel:.3e}"
|
||||||
|
);
|
||||||
|
|
||||||
|
// (b) Cross-rank parameter agreement (same KI-5 ULP tolerance as the base test).
|
||||||
|
let mut max_pdiff = 0.0f32;
|
||||||
|
for (a, b) in ddp_p0.iter().zip(ddp_p1) {
|
||||||
|
for (x, y) in a.iter().zip(b) {
|
||||||
|
max_pdiff = max_pdiff.max((x - y).abs());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println!("DDP+accum cross-rank max |param diff| = {max_pdiff:.3e}");
|
||||||
|
assert!(
|
||||||
|
max_pdiff < 1e-6,
|
||||||
|
"ranks' params drifted apart: {max_pdiff:.3e}"
|
||||||
|
);
|
||||||
|
|
||||||
|
// (c) Final params match single-GPU big-batch within fp tolerance.
|
||||||
|
let mut max_sdiff = 0.0f32;
|
||||||
|
for (a, b) in ddp_p0.iter().zip(&single_params) {
|
||||||
|
for (x, y) in a.iter().zip(b) {
|
||||||
|
max_sdiff = max_sdiff.max((x - y).abs() / y.abs().max(1e-6));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println!("DDP+accum vs single-GPU big-batch max rel |param diff| = {max_sdiff:.3e}");
|
||||||
|
assert!(
|
||||||
|
max_sdiff < 1e-2,
|
||||||
|
"DDP+accum params diverged from big-batch baseline"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn ddp_throughput_scaling() {
|
fn ddp_throughput_scaling() {
|
||||||
let max_gpus = device::device_count().unwrap_or(0) as usize;
|
let max_gpus = device::device_count().unwrap_or(0) as usize;
|
||||||
@@ -230,6 +352,7 @@ fn ddp_throughput_scaling() {
|
|||||||
let dcfg = DdpConfig {
|
let dcfg = DdpConfig {
|
||||||
seq_len,
|
seq_len,
|
||||||
batch_size: per_gpu_batch * world,
|
batch_size: per_gpu_batch * world,
|
||||||
|
accum_steps: 1,
|
||||||
steps,
|
steps,
|
||||||
schedule: LrSchedule {
|
schedule: LrSchedule {
|
||||||
max_lr: 1e-3,
|
max_lr: 1e-3,
|
||||||
|
|||||||
@@ -101,6 +101,10 @@ fn main() {
|
|||||||
// Optimization knobs.
|
// Optimization knobs.
|
||||||
let steps: usize = flag(&args, "--steps", 2000);
|
let steps: usize = flag(&args, "--steps", 2000);
|
||||||
let batch_size: usize = flag(&args, "--batch", 8);
|
let batch_size: usize = flag(&args, "--batch", 8);
|
||||||
|
// Micro-batch gradient accumulation (Phase T16): effective batch =
|
||||||
|
// accum_steps × batch, at one micro-batch's activation-memory cost. Default 1
|
||||||
|
// = no accumulation (bit-identical to the pre-T16 path).
|
||||||
|
let accum_steps: usize = flag(&args, "--accum-steps", 1).max(1);
|
||||||
let seq_len: usize = flag(&args, "--seq", 64);
|
let seq_len: usize = flag(&args, "--seq", 64);
|
||||||
let max_lr: f32 = flag(&args, "--max-lr", 3e-3);
|
let max_lr: f32 = flag(&args, "--max-lr", 3e-3);
|
||||||
let min_lr: f32 = flag(&args, "--min-lr", max_lr * 0.1);
|
let min_lr: f32 = flag(&args, "--min-lr", max_lr * 0.1);
|
||||||
@@ -208,6 +212,7 @@ fn main() {
|
|||||||
let tcfg = TrainConfig {
|
let tcfg = TrainConfig {
|
||||||
seq_len,
|
seq_len,
|
||||||
batch_size,
|
batch_size,
|
||||||
|
accum_steps,
|
||||||
steps,
|
steps,
|
||||||
schedule: LrSchedule {
|
schedule: LrSchedule {
|
||||||
max_lr,
|
max_lr,
|
||||||
@@ -226,10 +231,13 @@ fn main() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"training: {} steps, seq {}, batch {}, lr {:.1e}→{:.1e}, eval every {}",
|
"training: {} steps, seq {}, batch {} × accum {} = effective batch {}, \
|
||||||
|
lr {:.1e}→{:.1e}, eval every {}",
|
||||||
tcfg.steps,
|
tcfg.steps,
|
||||||
tcfg.seq_len,
|
tcfg.seq_len,
|
||||||
tcfg.batch_size,
|
tcfg.batch_size,
|
||||||
|
tcfg.accum_steps,
|
||||||
|
tcfg.batch_size * tcfg.accum_steps,
|
||||||
tcfg.schedule.max_lr,
|
tcfg.schedule.max_lr,
|
||||||
tcfg.schedule.min_lr,
|
tcfg.schedule.min_lr,
|
||||||
tcfg.eval_every
|
tcfg.eval_every
|
||||||
|
|||||||
@@ -27,6 +27,12 @@ use crate::schedule::LrSchedule;
|
|||||||
pub struct TrainConfig {
|
pub struct TrainConfig {
|
||||||
pub seq_len: usize,
|
pub seq_len: usize,
|
||||||
pub batch_size: usize,
|
pub batch_size: usize,
|
||||||
|
/// Micro-batch gradient accumulation (Phase T16): each optimizer step
|
||||||
|
/// accumulates grads over `accum_steps` micro-batches of `batch_size`
|
||||||
|
/// sequences, giving an EFFECTIVE batch of `accum_steps × batch_size` at the
|
||||||
|
/// activation-memory cost of a single micro-batch. `1` = no accumulation
|
||||||
|
/// (bit-identical to the pre-T16 path).
|
||||||
|
pub accum_steps: usize,
|
||||||
pub steps: usize,
|
pub steps: usize,
|
||||||
pub schedule: LrSchedule,
|
pub schedule: LrSchedule,
|
||||||
pub weight_decay: f32,
|
pub weight_decay: f32,
|
||||||
@@ -74,28 +80,43 @@ pub fn train(
|
|||||||
// Best-val checkpointing only kicks in when we actually evaluate.
|
// Best-val checkpointing only kicks in when we actually evaluate.
|
||||||
let track_best = valid.is_some() && cfg.eval_every > 0;
|
let track_best = valid.is_some() && cfg.eval_every > 0;
|
||||||
|
|
||||||
|
let accum = cfg.accum_steps.max(1);
|
||||||
for step in 0..cfg.steps {
|
for step in 0..cfg.steps {
|
||||||
let lr = cfg.schedule.lr(step);
|
let lr = cfg.schedule.lr(step);
|
||||||
|
|
||||||
// Sample `batch_size` sequences and run them as ONE batched forward/
|
// Accumulate grads over `accum` micro-batches of `batch_size` sequences,
|
||||||
// backward. The CE mean over all batch*seq rows is the batch-mean loss, so
|
// then take ONE optimizer step (Phase T16). Each micro-batch is ONE batched
|
||||||
// backward already yields the batch-mean gradient (clip pre-scale = 1.0).
|
// forward/backward; its loss is the CE mean over batch*seq rows, so backward
|
||||||
let mut inputs = Vec::with_capacity(cfg.batch_size);
|
// yields that micro-batch's mean grad. To make the SUM over `accum` micro-
|
||||||
let mut targets_v = Vec::with_capacity(cfg.batch_size);
|
// batches equal a single step over an `accum × batch` batch, each micro-loss
|
||||||
for _ in 0..cfg.batch_size {
|
// is scaled by 1/accum before backward (the tape SUM-accumulates the scaled
|
||||||
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
// grads). `accum == 1` skips the scale entirely → bit-identical to pre-T16.
|
||||||
inputs.push(input);
|
let mut step_loss_sum = 0.0f32;
|
||||||
targets_v.push(target);
|
for _ in 0..accum {
|
||||||
|
let mut inputs = Vec::with_capacity(cfg.batch_size);
|
||||||
|
let mut targets_v = Vec::with_capacity(cfg.batch_size);
|
||||||
|
for _ in 0..cfg.batch_size {
|
||||||
|
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
||||||
|
inputs.push(input);
|
||||||
|
targets_v.push(target);
|
||||||
|
}
|
||||||
|
let ids = batched_ids_tensor(&inputs, device);
|
||||||
|
let targets = batched_ids_tensor(&targets_v, device);
|
||||||
|
let loss = model.loss_batched(&ids, &targets, cfg.batch_size);
|
||||||
|
step_loss_sum += read_scalar(&loss);
|
||||||
|
if accum == 1 {
|
||||||
|
loss.backward();
|
||||||
|
} else {
|
||||||
|
xtrain_autodiff::ops::scale(&loss, 1.0 / accum as f32).backward();
|
||||||
|
}
|
||||||
|
tokens_seen += (cfg.batch_size * cfg.seq_len) as u64;
|
||||||
}
|
}
|
||||||
let ids = batched_ids_tensor(&inputs, device);
|
// Reported loss = mean over the effective batch = mean of the raw micro
|
||||||
let targets = batched_ids_tensor(&targets_v, device);
|
// losses (each is itself a micro-batch mean of equal size).
|
||||||
let loss = model.loss_batched(&ids, &targets, cfg.batch_size);
|
let step_loss = step_loss_sum / accum as f32;
|
||||||
let step_loss = read_scalar(&loss);
|
|
||||||
loss.backward();
|
|
||||||
tokens_seen += (cfg.batch_size * cfg.seq_len) as u64;
|
|
||||||
losses.push(step_loss);
|
losses.push(step_loss);
|
||||||
|
|
||||||
// Backward already produced the batch-mean gradient — just clip it.
|
// Backward already produced the effective-batch mean gradient — just clip.
|
||||||
let gnorm = clip_grad_norm_gpu(¶ms, cfg.max_grad_norm, 1.0);
|
let gnorm = clip_grad_norm_gpu(¶ms, cfg.max_grad_norm, 1.0);
|
||||||
opt.step(lr, ¶ms);
|
opt.step(lr, ¶ms);
|
||||||
for p in ¶ms {
|
for p in ¶ms {
|
||||||
|
|||||||
294
crates/xtrain-train/tests/grad_accum.rs
Normal file
294
crates/xtrain-train/tests/grad_accum.rs
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
// T16 gradient-accumulation correctness gates.
|
||||||
|
//
|
||||||
|
// Gradient accumulation is mathematically EXACT: accumulating the grads of N
|
||||||
|
// micro-batches of B sequences (each micro-loss scaled by 1/N before backward,
|
||||||
|
// the tape SUM-accumulating) equals a single step over one N·B-sequence batch.
|
||||||
|
// This file makes that a closed loop on-GPU, plus the accum_steps=1 bit-identity
|
||||||
|
// regression guard.
|
||||||
|
//
|
||||||
|
// 1. accum_equiv_big_batch: same init, same N·B sequences in the same order.
|
||||||
|
// Path A = ONE batched loss over all N·B (the big-batch baseline). Path B =
|
||||||
|
// N micro-backwards of B each, scale(1/N), tape SUM. Assert loss and EVERY
|
||||||
|
// parameter grad match within fp tolerance (only the summation order differs,
|
||||||
|
// like the T8 DDP-vs-single-GPU and T13 recompute gates).
|
||||||
|
// 2. accum1_bit_identical: accum_steps=1 must reproduce the no-accum path
|
||||||
|
// bit-for-bit (the implementation skips the ×1/1 scale entirely) — every
|
||||||
|
// parameter grad max|Δ| == 0.0.
|
||||||
|
// 3. accum_train_converges: drive the real `train()` loop with accum and assert
|
||||||
|
// the per-step effective-batch loss trace tracks a big-batch baseline (errors
|
||||||
|
// stay bounded over many AdamW steps, not just one).
|
||||||
|
#![cfg(not(no_cuda))]
|
||||||
|
|
||||||
|
use xtrain_autodiff::ops;
|
||||||
|
use xtrain_cuda::device;
|
||||||
|
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
||||||
|
use xtrain_tensor::Device;
|
||||||
|
use xtrain_train::data::Corpus;
|
||||||
|
use xtrain_train::schedule::LrSchedule;
|
||||||
|
use xtrain_train::{TrainConfig, train};
|
||||||
|
|
||||||
|
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||||
|
let mut state = seed
|
||||||
|
.wrapping_mul(2862933555777941757)
|
||||||
|
.wrapping_add(3037000493);
|
||||||
|
(0..n)
|
||||||
|
.map(|_| {
|
||||||
|
state = state
|
||||||
|
.wrapping_mul(6364136223846793005)
|
||||||
|
.wrapping_add(1442695040888963407);
|
||||||
|
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build(cfg: Config, device: Device) -> TinyTransformer {
|
||||||
|
let mut seed = 1u64;
|
||||||
|
TinyTransformer::new(cfg, device, |shape| {
|
||||||
|
seed = seed.wrapping_add(1);
|
||||||
|
let n: usize = shape.iter().product();
|
||||||
|
if shape.len() == 1 {
|
||||||
|
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||||
|
} else {
|
||||||
|
fill(n, seed, 0.08)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||||||
|
t.to_device(Device::Cpu).as_slice::<f32>().to_vec()
|
||||||
|
}
|
||||||
|
|
||||||
|
// `n` deterministic (seq, target) pairs for the equivalence tests.
|
||||||
|
fn make_seqs(n: usize, seq: usize, vocab: usize) -> (Vec<Vec<i32>>, Vec<Vec<i32>>) {
|
||||||
|
let seqs = (0..n)
|
||||||
|
.map(|b| {
|
||||||
|
(0..seq)
|
||||||
|
.map(|i| ((b * 7 + i * 3 + 1) % vocab) as i32)
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let tgts = (0..n)
|
||||||
|
.map(|b| {
|
||||||
|
(0..seq)
|
||||||
|
.map(|i| ((b * 5 + i * 2 + 2) % vocab) as i32)
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
(seqs, tgts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run one big-batch forward/backward over all `seqs` and return the grads.
|
||||||
|
fn big_batch_grads(
|
||||||
|
model: &TinyTransformer,
|
||||||
|
device: Device,
|
||||||
|
seqs: &[Vec<i32>],
|
||||||
|
tgts: &[Vec<i32>],
|
||||||
|
) -> (f32, Vec<Vec<f32>>) {
|
||||||
|
let n = seqs.len();
|
||||||
|
let ids = batched_ids_tensor(seqs, device);
|
||||||
|
let tgt = batched_ids_tensor(tgts, device);
|
||||||
|
let loss = model.loss_batched(&ids, &tgt, n);
|
||||||
|
let loss_val = host(&loss.value())[0];
|
||||||
|
loss.backward();
|
||||||
|
let grads = model
|
||||||
|
.params()
|
||||||
|
.iter()
|
||||||
|
.map(|p| host(&p.grad().expect("grad")))
|
||||||
|
.collect();
|
||||||
|
(loss_val, grads)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accumulate over `accum` micro-batches of `b` sequences (drawn in order from the
|
||||||
|
// flat `seqs`/`tgts`), scaling each micro-loss by 1/accum before backward; the
|
||||||
|
// tape SUM-accumulates. Returns the mean of the raw micro losses + accumulated grads.
|
||||||
|
fn accum_grads(
|
||||||
|
model: &TinyTransformer,
|
||||||
|
device: Device,
|
||||||
|
seqs: &[Vec<i32>],
|
||||||
|
tgts: &[Vec<i32>],
|
||||||
|
accum: usize,
|
||||||
|
b: usize,
|
||||||
|
scale: bool,
|
||||||
|
) -> (f32, Vec<Vec<f32>>) {
|
||||||
|
let mut loss_sum = 0.0f32;
|
||||||
|
for m in 0..accum {
|
||||||
|
let s = &seqs[m * b..(m + 1) * b];
|
||||||
|
let t = &tgts[m * b..(m + 1) * b];
|
||||||
|
let ids = batched_ids_tensor(s, device);
|
||||||
|
let tgt = batched_ids_tensor(t, device);
|
||||||
|
let loss = model.loss_batched(&ids, &tgt, b);
|
||||||
|
loss_sum += host(&loss.value())[0];
|
||||||
|
if scale {
|
||||||
|
ops::scale(&loss, 1.0 / accum as f32).backward();
|
||||||
|
} else {
|
||||||
|
loss.backward(); // accum==1 bit-identity path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let grads = model
|
||||||
|
.params()
|
||||||
|
.iter()
|
||||||
|
.map(|p| host(&p.grad().expect("grad")))
|
||||||
|
.collect();
|
||||||
|
(loss_sum / accum as f32, grads)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn accum_equiv_big_batch() {
|
||||||
|
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||||
|
device::set_device(0).unwrap();
|
||||||
|
let device = Device::Cuda(0);
|
||||||
|
|
||||||
|
let mut cfg = Config::tiny();
|
||||||
|
cfg.vocab = 16;
|
||||||
|
cfg.n_layers = 3;
|
||||||
|
let b = 2usize; // micro-batch
|
||||||
|
let accum = 4usize; // → effective batch 8
|
||||||
|
let seq = 6usize;
|
||||||
|
let (seqs, tgts) = make_seqs(b * accum, seq, cfg.vocab);
|
||||||
|
|
||||||
|
// Big-batch baseline (accum_steps=1, batch = b·accum).
|
||||||
|
let big = build(cfg, device);
|
||||||
|
let (big_loss, big_grads) = big_batch_grads(&big, device, &seqs, &tgts);
|
||||||
|
|
||||||
|
// Accumulated (accum micro-batches of b, scale 1/accum).
|
||||||
|
let acc = build(cfg, device);
|
||||||
|
let (acc_loss, acc_grads) = accum_grads(&acc, device, &seqs, &tgts, accum, b, true);
|
||||||
|
|
||||||
|
let loss_rel = (big_loss - acc_loss).abs() / big_loss.abs().max(1e-4);
|
||||||
|
let mut max_grad_rel = 0.0f32;
|
||||||
|
for (bg, ag) in big_grads.iter().zip(&acc_grads) {
|
||||||
|
for (x, y) in bg.iter().zip(ag) {
|
||||||
|
max_grad_rel = max_grad_rel.max((x - y).abs() / x.abs().max(1e-3));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println!(
|
||||||
|
"accum=={accum}×b{b} vs big-batch{}: loss {big_loss:.6}/{acc_loss:.6} (rel {loss_rel:.2e}), \
|
||||||
|
grad max rel {max_grad_rel:.3e}",
|
||||||
|
b * accum
|
||||||
|
);
|
||||||
|
// fp summation order differs (big batch sums b·accum rows once; accum sums per
|
||||||
|
// micro then across micros) → tight fp tol, same convention as T13 recompute.
|
||||||
|
assert!(loss_rel < 1e-5, "loss diverged: {loss_rel:.2e}");
|
||||||
|
assert!(
|
||||||
|
max_grad_rel < 1e-4,
|
||||||
|
"accum grads diverged from big batch: {max_grad_rel:.3e}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn accum1_bit_identical() {
|
||||||
|
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||||
|
device::set_device(0).unwrap();
|
||||||
|
let device = Device::Cuda(0);
|
||||||
|
|
||||||
|
let mut cfg = Config::tiny();
|
||||||
|
cfg.vocab = 16;
|
||||||
|
cfg.n_layers = 3;
|
||||||
|
let b = 4usize;
|
||||||
|
let seq = 6usize;
|
||||||
|
let (seqs, tgts) = make_seqs(b, seq, cfg.vocab);
|
||||||
|
|
||||||
|
// No-accum reference: one batched loss + backward (the pre-T16 path).
|
||||||
|
let reference = build(cfg, device);
|
||||||
|
let (_, ref_grads) = big_batch_grads(&reference, device, &seqs, &tgts);
|
||||||
|
|
||||||
|
// accum_steps=1 path: the loop runs ONE micro-batch and (by design) skips the
|
||||||
|
// ×1/1 scale → must be byte-for-byte identical to the reference backward.
|
||||||
|
let accum1 = build(cfg, device);
|
||||||
|
let (_, a1_grads) = accum_grads(&accum1, device, &seqs, &tgts, 1, b, false);
|
||||||
|
|
||||||
|
let mut max_abs = 0.0f32;
|
||||||
|
for (r, a) in ref_grads.iter().zip(&a1_grads) {
|
||||||
|
for (x, y) in r.iter().zip(a) {
|
||||||
|
max_abs = max_abs.max((x - y).abs());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println!("accum_steps=1 vs no-accum: grad max |Δ| = {max_abs:.3e}");
|
||||||
|
assert_eq!(
|
||||||
|
max_abs, 0.0,
|
||||||
|
"accum_steps=1 not bit-identical to no-accum: {max_abs:.3e}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// A self-contained synthetic corpus (no tokenizer / data file needed).
|
||||||
|
fn synth_corpus(vocab: usize, n_tokens: usize) -> Corpus {
|
||||||
|
Corpus {
|
||||||
|
tokens: (0..n_tokens)
|
||||||
|
.map(|i| (i * 7 + 3) as i32 % vocab as i32)
|
||||||
|
.collect(),
|
||||||
|
vocab_size: vocab,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn accum_train_converges() {
|
||||||
|
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||||
|
device::set_device(0).unwrap();
|
||||||
|
let device = Device::Cuda(0);
|
||||||
|
|
||||||
|
let vocab = 64usize;
|
||||||
|
let mut cfg = Config::tiny();
|
||||||
|
cfg.vocab = vocab;
|
||||||
|
cfg.n_layers = 2;
|
||||||
|
let corpus = synth_corpus(vocab, 4096);
|
||||||
|
let steps = 20usize;
|
||||||
|
let seq = 32usize;
|
||||||
|
|
||||||
|
// Same per-step RNG stream + effective batch 8 either way: the big-batch run
|
||||||
|
// (accum=1, batch=8) and the accumulated run (accum=4, batch=2) draw the SAME
|
||||||
|
// 8 sequences per step in the same order, so the per-step loss/grads — and thus
|
||||||
|
// the whole AdamW trajectory — track within fp tolerance.
|
||||||
|
let sched = LrSchedule {
|
||||||
|
max_lr: 3e-3,
|
||||||
|
min_lr: 3e-4,
|
||||||
|
warmup: 3,
|
||||||
|
total: steps,
|
||||||
|
};
|
||||||
|
let base = |batch, accum| TrainConfig {
|
||||||
|
seq_len: seq,
|
||||||
|
batch_size: batch,
|
||||||
|
accum_steps: accum,
|
||||||
|
steps,
|
||||||
|
schedule: sched.clone(),
|
||||||
|
weight_decay: 0.1,
|
||||||
|
max_grad_norm: 1.0,
|
||||||
|
log_every: 1_000_000,
|
||||||
|
ckpt_path: None,
|
||||||
|
ckpt_every: 0,
|
||||||
|
eval_every: 0,
|
||||||
|
eval_batches: 0,
|
||||||
|
seed: 7,
|
||||||
|
};
|
||||||
|
|
||||||
|
let big_model = build(cfg, device);
|
||||||
|
let big = train(&big_model, device, &corpus, None, &base(8, 1)).train_losses;
|
||||||
|
|
||||||
|
let acc_model = build(cfg, device);
|
||||||
|
let acc = train(&acc_model, device, &corpus, None, &base(2, 4)).train_losses;
|
||||||
|
|
||||||
|
let mut max_rel = 0.0f32;
|
||||||
|
for (x, y) in big.iter().zip(&acc) {
|
||||||
|
max_rel = max_rel.max((x - y).abs() / x.abs().max(1e-6));
|
||||||
|
}
|
||||||
|
// Final params should also stay close (errors don't blow up over the run).
|
||||||
|
let mut max_pdiff = 0.0f32;
|
||||||
|
for (p, q) in big_model.params().iter().zip(&acc_model.params()) {
|
||||||
|
for (x, y) in host(&p.value()).iter().zip(host(&q.value())) {
|
||||||
|
max_pdiff = max_pdiff.max((x - y).abs() / x.abs().max(1e-6));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println!(
|
||||||
|
"accum(4×2) vs big(8) over {steps} steps: loss[last] {:.6}/{:.6} max_rel {max_rel:.2e}, \
|
||||||
|
final param max rel {max_pdiff:.2e}",
|
||||||
|
big.last().unwrap(),
|
||||||
|
acc.last().unwrap()
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
max_rel < 1e-3,
|
||||||
|
"accum loss trajectory diverged: {max_rel:.3e}"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
max_pdiff < 1e-2,
|
||||||
|
"accum final params diverged: {max_pdiff:.3e}"
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -84,6 +84,7 @@ fn trains_on_tinystories() {
|
|||||||
let tcfg = TrainConfig {
|
let tcfg = TrainConfig {
|
||||||
seq_len: 64,
|
seq_len: 64,
|
||||||
batch_size: 8,
|
batch_size: 8,
|
||||||
|
accum_steps: 1,
|
||||||
steps,
|
steps,
|
||||||
schedule: LrSchedule {
|
schedule: LrSchedule {
|
||||||
max_lr: 3e-3,
|
max_lr: 3e-3,
|
||||||
|
|||||||
165
docs/15-grad-accum.md
Normal file
165
docs/15-grad-accum.md
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
# Phase T16: Gradient Accumulation — Design Document
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
在已有的训练 loop(T6/T10)与 DDP(T8)之上,加 **micro-batch 梯度累积**:把 `accum_steps=N`
|
||||||
|
个 **micro-step** 的梯度在 tape 里累加起来,再做**一次** `AdamW.step` + `zero_grad`——得到
|
||||||
|
**有效 batch = N × micro_batch** 的更新,而显存只占**一个 micro-batch** 的激活峰值(不随 N 增长)。
|
||||||
|
|
||||||
|
两条硬约束:
|
||||||
|
|
||||||
|
1. **数值等效**:`accum_steps=N`(N 个 micro-step 后一次 step)必须对住「一个 N× 大 batch
|
||||||
|
的单 step」——梯度/loss 在仓内既有容差内**逐位贴合**。这是核心等效性证明。
|
||||||
|
2. **DDP 只在累积边界通信**:`world>1` 下,N 个 micro-step 里**只在最后一个**做 all-reduce
|
||||||
|
(中间 micro-step **跳过跨卡通信**),最终喂给优化器的仍是 global 有效 batch 的均值梯度,
|
||||||
|
loss 对单卡。
|
||||||
|
|
||||||
|
并暴露 train 入口的 `--accum-steps` flag。`accum_steps=1` 必须对当前无累积路径**逐位一致**
|
||||||
|
(回归保护)。
|
||||||
|
|
||||||
|
**不做**:micro-batch 间变 LR / 变 batch(恒定 micro_batch);累积里换 dropout RNG(T18 才有
|
||||||
|
dropout);ZeRO(T17)。本 Phase 只动**优化器 step 的节奏**与 **DDP 通信门控**,复用 tape 既有
|
||||||
|
的 SUM 累加。
|
||||||
|
|
||||||
|
## Module Layout
|
||||||
|
|
||||||
|
```
|
||||||
|
crates/xtrain-train/src/
|
||||||
|
├── train_loop.rs # TrainConfig += accum_steps;inner micro-loop(缩放 loss + tape SUM)
|
||||||
|
└── bin/train.rs # 新 --accum-steps flag;打印有效 batch
|
||||||
|
|
||||||
|
crates/xtrain-distributed/src/
|
||||||
|
└── ddp.rs # DdpConfig += accum_steps;all-reduce 门控到累积边界
|
||||||
|
|
||||||
|
crates/xtrain-train/tests/
|
||||||
|
└── grad_accum.rs # 等效性硬闸门 + accum_steps=1 逐位回归(单卡)
|
||||||
|
|
||||||
|
crates/xtrain-distributed/tests/
|
||||||
|
└── ddp_correctness.rs # += DDP+accum 对单卡(复用既有 ddp_matches… 框架)
|
||||||
|
|
||||||
|
docs/15-grad-accum.md # 本文
|
||||||
|
```
|
||||||
|
|
||||||
|
无新 crate、无新 kernel、无新 autograd op——梯度累积是**纯调度**:tape 早已 SUM 累加,
|
||||||
|
缩放用既有 `ops::scale`,DDP 通信用既有 `all_reduce_average_grads`,只是改**调用节奏与门控**。
|
||||||
|
|
||||||
|
## Key Design Decisions
|
||||||
|
|
||||||
|
### ① 等效性的数学:缩放每个 micro-loss 为 `1/N`
|
||||||
|
|
||||||
|
模型的 `loss_batched` 是 **CE-mean over `batch*seq` 行**(见 `model.rs`)。设一个 micro-batch 有
|
||||||
|
`B` 序列、seq 长 `S`,记某 micro-step 那批 `B*S` 行的 per-row 梯度之和为 `Σ_micro`:
|
||||||
|
|
||||||
|
- **大 batch 基线**(有效 batch `N·B`):一次 `loss_batched(N·B 序列)` = CE-mean over `N·B·S` 行
|
||||||
|
→ backward 给 `G_big = Σ_all / (N·B·S)`,其中 `Σ_all = Σ_n Σ_micro_n`。
|
||||||
|
- **累积**(N 个 micro-step,每个 `B`):micro-step n 的 `loss_batched(B)` = CE-mean over `B·S` 行
|
||||||
|
→ 若直接 backward 得 `Σ_micro_n / (B·S)`;**N 个 backward 之间不 `zero_grad`**,tape SUM 累加 →
|
||||||
|
`Σ_n Σ_micro_n / (B·S) = Σ_all / (B·S) = N · G_big`。
|
||||||
|
|
||||||
|
差一个因子 N。修正:**每个 micro-loss 先 `ops::scale(loss, 1/N)` 再 backward**——`scale` 的
|
||||||
|
backward 把上游梯度乘 `1/N`(见 `ops.rs`),于是每个 micro 贡献 `Σ_micro_n / (N·B·S)`,
|
||||||
|
累积后 `Σ_all / (N·B·S) = G_big`,**与大 batch 逐位等效**(仅 fp 求和顺序不同 → 进容差,和
|
||||||
|
T8 DDP-vs-单卡同性质)。
|
||||||
|
|
||||||
|
> 为什么不在 clip 里用 `pre_scale=1/N`?clip 的 `pre_scale` 已被 batch-mean 占用(=1.0)。
|
||||||
|
> 在 loss 上 `scale(1/N)` 更内聚:缩放穿过既有 autograd,不碰 clip/optimizer,且 `N=1` 时
|
||||||
|
> `scale(1.0)` 的 backward 是恒等乘 1 —— 这正是 `accum_steps=1` 逐位回归的保证(见 ④)。
|
||||||
|
|
||||||
|
报告的 step-loss = N 个 micro 的**原始** loss(未缩放值)之和 / N = 有效 batch 的 mean loss,
|
||||||
|
和大 batch 的单一 mean loss 一致(同样仅求和顺序差)。
|
||||||
|
|
||||||
|
### ② 单卡 train loop:inner micro-loop
|
||||||
|
|
||||||
|
每个 optimizer step:
|
||||||
|
|
||||||
|
```text
|
||||||
|
for micro in 0..N:
|
||||||
|
抽 B 序列 → loss = loss_batched(B)
|
||||||
|
step_loss_acc += raw_loss(loss) # 累报告用的原始 loss
|
||||||
|
scale(loss, 1/N).backward() # tape SUM 累加缩放后的梯度
|
||||||
|
# —— 累积边界 ——
|
||||||
|
clip_grad_norm_gpu(params, max_norm, 1.0) # 梯度已是有效 batch 均值
|
||||||
|
opt.step(lr); zero_grad()
|
||||||
|
losses.push(step_loss_acc / N)
|
||||||
|
tokens_seen += N * B * S # 有效 batch tok
|
||||||
|
```
|
||||||
|
|
||||||
|
`accum_steps` 默认 1 → micro-loop 跑一次、`scale(loss,1.0)`、不在 micro 间 zero_grad(本就如此)
|
||||||
|
→ 与现路径完全等价。**每个 micro-step 的计算图在它自己的 backward 后即可释放**(Rust `Rc` 在
|
||||||
|
循环变量出作用域时 drop),所以**显存峰值 = 单个 micro-batch 的激活**,不随 N 增长(③ 实测)。
|
||||||
|
|
||||||
|
抽样次序保持:单卡仍是连续从 RNG 抽 `N·B` 序列;与「大 batch 抽 `N·B`」逐序列对齐,只是分 N 组
|
||||||
|
forward——并集同序,所以 `Σ_all` 的项一致。
|
||||||
|
|
||||||
|
### ③ 显存平 + 有效 batch 实测
|
||||||
|
|
||||||
|
「显存不随 N 增长」是 grad-accum 的卖点,要**实测**而非断言:固定有效 batch `E = N·B`,跑
|
||||||
|
`(N=1,B=E)`(大 batch)vs `(N=E,B=1)`(极端累积),用 `nvidia-smi`/`cudaMemGetInfo` 量峰值显存——
|
||||||
|
后者应**显著低**(少 N× 激活)。train 入口打印 `effective batch = accum_steps × batch`。
|
||||||
|
|
||||||
|
### ④ `accum_steps=1` 逐位回归
|
||||||
|
|
||||||
|
`N=1` 时 inner loop 跑一次、`scale(loss, 1.0)`。`ops::scale(_, 1.0)` 的 fwd 是
|
||||||
|
`value.scale(1.0)`、bwd 是 `grad.scale(1.0)`——数学恒等。为**绝对**逐位(连一次 `×1.0` kernel
|
||||||
|
都不引入),实现里 `N==1` 直接 `loss.backward()`(跳过 scale),与现路径**字节一致**。测试
|
||||||
|
`accum1_bit_identical_to_no_accum` 锁这条。
|
||||||
|
|
||||||
|
### ⑤ DDP:all-reduce 门控到累积边界
|
||||||
|
|
||||||
|
T8 的 `all_reduce_average_grads(params)` 每 step 调一次。grad-accum 下**只在最后一个
|
||||||
|
micro-step 之后调一次**——中间 micro-step 的 backward 只在本卡 tape 里 SUM,**不发 NCCL**。
|
||||||
|
|
||||||
|
均值的账(沿用 T8 的「通信里 /world,clip 里 /b_local」拆分,再叠加 ① 的 /N):
|
||||||
|
|
||||||
|
```text
|
||||||
|
每卡每 micro: scale(loss, 1/N).backward() → 本卡 tape SUM 该 micro 的 (Σ_micro / N)/...
|
||||||
|
N 个 micro 后, 本卡 grad = Σ_{micro∈本卡所有micro} ... = 本卡 N·B_local 行的 (1/N) 缩放和
|
||||||
|
all-reduce(sum)+/world (累积边界一次): 跨卡求和后 /world
|
||||||
|
→ 每卡持有 Σ_global,(N·B) / (N · world · ?) # 见下:用 1/N·scale 替代单卡的 1/b
|
||||||
|
clip pre_scale = 1.0
|
||||||
|
```
|
||||||
|
|
||||||
|
精确推导:每卡每 micro 的 `loss_batched(B_local)` 是 **本卡 mean over `B_local·S` 行**。
|
||||||
|
`scale(1/N)` 后 backward = `Σ_local_micro / (N · B_local · S)`。N 个 micro tape SUM →
|
||||||
|
`Σ_local_all / (N · B_local · S)`,其中 `Σ_local_all` = 本卡 `N·B_local` 行之和。
|
||||||
|
`all_reduce(sum)` 跨 world 卡 → `Σ_global_all / (N · B_local · S)`(`Σ_global_all` = 全
|
||||||
|
`world·N·B_local = N·B_global` 行之和);`/world` → `Σ_global_all / (N · B_local · S · world)`
|
||||||
|
`= Σ_global_all / (N · B_global · S)`(因 `B_global = world·B_local`)。这正是**有效 batch
|
||||||
|
`N·B_global` 的 mean 梯度**——与单卡「有效 batch `N·B_global` 的大 batch 单 step」逐位等效
|
||||||
|
(求和顺序差进容差)。
|
||||||
|
|
||||||
|
> 关键正确性点:`all_reduce_average_grads` 里的 `/world` 是按 **world** 缩放(与 N 无关);N 的
|
||||||
|
> 那个 `1/N` 已由 ① 的 `scale` 在每个 micro 的 backward 里完成。两者正交,不会互相污染。
|
||||||
|
> 单卡(`world=1`)退化:all-reduce 是 no-op,`/world=1`,只剩 ① 的 `1/N` → 与 ② 一致。
|
||||||
|
|
||||||
|
DDP 报告 loss = N 个 micro 的本卡原始 loss·B_local 之和、跨卡 all-reduce(sum)、/(N·B_global)。
|
||||||
|
|
||||||
|
### ⑥ 不变量小结
|
||||||
|
|
||||||
|
| | 单卡基线(大 batch E) | 单卡 accum(N×B=E) | DDP accum(world, N×B_local·world=E) |
|
||||||
|
|---|---|---|---|
|
||||||
|
| loss 缩放 | 无(CE-mean) | 每 micro `×1/N` | 每 micro `×1/N` |
|
||||||
|
| grad 累加 | tape SUM 一批 | tape SUM N 批 | tape SUM N 批/卡 |
|
||||||
|
| 跨卡通信 | — | — | **仅累积边界 1 次** all-reduce + /world |
|
||||||
|
| clip pre_scale | 1.0 | 1.0 | 1.0 |
|
||||||
|
| 显存峰值 | E 的激活 | **B 的激活** | **B_local 的激活** |
|
||||||
|
|
||||||
|
## 验证方法(验收,全部 dash5 实跑 capture)
|
||||||
|
|
||||||
|
GPU 测试 `#[cfg(not(no_cuda))]` 门控。
|
||||||
|
|
||||||
|
1. **等效性(核心硬闸门)** `grad_accum.rs::accum_equiv_big_batch`:同 init、同数据同序,
|
||||||
|
跑「`accum_steps=N`, micro_batch=B」与「`accum_steps=1`, batch=N·B」各一 step,断言
|
||||||
|
①loss、②**每个参数的 grad** rel-err 进 fp 容差(求和顺序差,~1e-4 量级,对齐 recompute/DDP
|
||||||
|
闸门约定)。多步版(跑 K 个 optimizer step)再断言**终参**贴合(误差不发散)。
|
||||||
|
2. **`accum_steps=1` 逐位回归** `grad_accum.rs::accum1_bit_identical`:`accum_steps=1` 与现
|
||||||
|
no-accum 路径同 init/同数据 → 每参数 grad `max|Δ| == 0.0`(④ 跳过 scale,字节一致)。
|
||||||
|
3. **DDP+accum 对单卡** `ddp_correctness.rs`(扩既有 `ddp_matches_single_gpu…`):单卡
|
||||||
|
有效 batch `E` 的大 batch baseline vs `world=2 + accum_steps=N`(每卡每 micro `B_local`,
|
||||||
|
`world·N·B_local=E`)→ loss 轨迹 `max_rel<1e-3`、跨 rank 参数一致、且 only-at-boundary 通信
|
||||||
|
(micro 间不发 NCCL,由实现保证 + 不变量推导)。
|
||||||
|
4. **显存平 + 有效 batch** :固定有效 batch,量 `(N=1,大batch)` vs `(N=大,micro=1)` 峰值显存
|
||||||
|
(后者显著低),train 入口打印 effective batch。capture nvidia-smi。
|
||||||
|
5. **全回归套**:autograd grad-check / structural / batched==looped / bf16 / recompute(逐位)/
|
||||||
|
overfit 27/27 / AdamW(GPU bit-exact + host vs torch)/ DDP loss-match + 跨 rank / **xserv
|
||||||
|
闭环 md5**——`accum_steps=1` 默认值保证全部不回归。
|
||||||
@@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 一、基建 phase(T1–T13)—— 主要动「算法」与「Infra」
|
## 一、基建 phase(T1–T13 + Phase 2 systems-depth)—— 主要动「算法」与「Infra」
|
||||||
|
|
||||||
| Phase | 维度 | 变化 | 结果 / 验证 |
|
| Phase | 维度 | 变化 | 结果 / 验证 |
|
||||||
|---|---|---|---|
|
|---|---|---|---|
|
||||||
@@ -25,6 +25,8 @@
|
|||||||
| T12 | 算法/Infra | **bf16 混合精度**(fp32 master,cuBLAS GemmEx,norm/softmax/CE 保 fp32) | dim768 OOM 解除,−29% 显存/+13% tok/s(修 KI-2) |
|
| T12 | 算法/Infra | **bf16 混合精度**(fp32 master,cuBLAS GemmEx,norm/softmax/CE 保 fp32) | dim768 OOM 解除,−29% 显存/+13% tok/s(修 KI-2) |
|
||||||
| T13 | 算法/Infra | **激活重计算**(per-block gradient checkpointing:前向 no-tape + 反向重算,`backward_seeded`) | 梯度对非重计算版**逐位一致**(0.00);dim768 31.1→14.6GB;**dim1024 batch32 OOM→16.6GB 装下**(修 KI-3,解锁 v8) |
|
| T13 | 算法/Infra | **激活重计算**(per-block gradient checkpointing:前向 no-tape + 反向重算,`backward_seeded`) | 梯度对非重计算版**逐位一致**(0.00);dim768 31.1→14.6GB;**dim1024 batch32 OOM→16.6GB 装下**(修 KI-3,解锁 v8) |
|
||||||
| T14 | 算法/Infra | **融合 flash-attention kernel**(手写单 kernel:online softmax、tiled over KV、**不物化 N×N scores**;flash 式 bwd:重算 scores + `D=ΣdO·O` 化简雅可比 + dQ/dK/dV);opt-in `--flash`,默认保 composed(Phase 2) | fwd 对 composed 6.7e-5、bwd 对 composed dQ 1.7e-5、PyTorch B>1 7.9e-6、flash==composed loss rel 0.0;**峰值显存 −16%@seq1024 / −23%@seq2048**(不物化 N×N,收益随 seq 增长);tok/s ~2.3–2.8× 慢(hd=64 小头维干不过 cuBLAS tensor-core,flash 已知权衡=胜场在显存);md5 闭环逐位一致 |
|
| T14 | 算法/Infra | **融合 flash-attention kernel**(手写单 kernel:online softmax、tiled over KV、**不物化 N×N scores**;flash 式 bwd:重算 scores + `D=ΣdO·O` 化简雅可比 + dQ/dK/dV);opt-in `--flash`,默认保 composed(Phase 2) | fwd 对 composed 6.7e-5、bwd 对 composed dQ 1.7e-5、PyTorch B>1 7.9e-6、flash==composed loss rel 0.0;**峰值显存 −16%@seq1024 / −23%@seq2048**(不物化 N×N,收益随 seq 增长);tok/s ~2.3–2.8× 慢(hd=64 小头维干不过 cuBLAS tensor-core,flash 已知权衡=胜场在显存);md5 闭环逐位一致 |
|
||||||
|
| T16 | 算法/Infra | **梯度累积**(N 个 micro-step:每个 micro-loss `×1/N` 再 backward,tape SUM 累加 → 一次 AdamW step+zero;`--accum-steps`);**DDP 只在累积边界 all-reduce**(中间 micro-step 不发 NCCL,`/world` 与 `1/N` 正交);显存随 micro 不随有效 batch | 等效大 batch**逐位贴合**(loss rel 8.5e-8、grad rel 3.8e-5);`accum=1` 逐位回归(0.00);DDP+accum 对单卡 loss 5.7e-7/跨 rank 一致;**显存平**:同有效 batch 64,big-batch 27.7GB→accum(4×16) **7.2GB(−74%)**(big-batch OOM 而 accum 装下);全回归+xserv 闭环 md5 一致 |
|
||||||
|
| T18 | 算法 | **dropout**(手写 counter-based 设备 RNG → Bernoulli mask,训练 inverted 1/(1-p) scaling、eval 恒等);新 autodiff `dropout` 算子(fwd 生成+施加 mask,bwd 用同 mask),接 residual/ffn 两处;`--dropout` flag 默认 0 | 固定 seed grad-check 过;E[out]≈input + keep≈1-p;**p=0 与无 dropout 逐位一致**;recompute(T13) 组合下梯度仍逐位一致(counter-based seed 重算复现同 mask);全回归 + xserv 闭环绿(导出/推理 dropout 关) |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -50,9 +52,9 @@
|
|||||||
|
|
||||||
## 三、各维度的累积演进(轴向看一条线怎么走的)
|
## 三、各维度的累积演进(轴向看一条线怎么走的)
|
||||||
|
|
||||||
- **算法**:手写 autograd(tape)+扇出累加 → AdamW/LR-sched/grad-clip → +QK-norm(Qwen3) → batched forward → bf16 混合精度(fp32 master) → 激活重计算(T13) → 融合 flash-attention(T14,online softmax + flash 式 bwd)。
|
- **算法**:手写 autograd(tape)+扇出累加 → AdamW/LR-sched/grad-clip → +QK-norm(Qwen3) → batched forward → bf16 混合精度(fp32 master) → 激活重计算(T13) → 融合 flash-attention(T14,online softmax + flash 式 bwd) → 梯度累积(T16,复用 tape SUM,等效大 batch 而显存随 micro) → dropout(T18,counter-based 设备 RNG + inverted scaling,train/eval 切换)。
|
||||||
- **模型架构**:固定 Qwen3-style;dim **32→256→384→512→768→1024**(v8 首拨容量轴,头数 24→32);核心参数 **41K→226M**(总 3.26M→329M)。
|
- **模型架构**:固定 Qwen3-style;dim **32→256→384→512→768→1024**(v8 首拨容量轴,头数 24→32);核心参数 **41K→226M**(总 3.26M→329M)。
|
||||||
- **Infra**:单卡 fp32 → cuBLAS/GPU-optim(T7) → NCCL DDP(T8) → batched forward(T10) → caching allocator(T11) → bf16(T12) → 激活重计算(T13,解锁 dim1024) → flash-attention(T14,不物化 N×N,attention 显存收益随 seq 增长)。吞吐 **3.3K→217K tok/s**(dim768 bf16),dim1024+重算 ~129K(重算税);MFU **0.4%→17%**(每次提升都对应一块 perf 基建,详见 known-issues + MFU 分析)。
|
- **Infra**:单卡 fp32 → cuBLAS/GPU-optim(T7) → NCCL DDP(T8) → batched forward(T10) → caching allocator(T11) → bf16(T12) → 激活重计算(T13,解锁 dim1024) → flash-attention(T14,不物化 N×N,attention 显存收益随 seq 增长) → 梯度累积(T16,DDP 只在累积边界通信,显存随 micro 不随有效 batch)。吞吐 **3.3K→217K tok/s**(dim768 bf16),dim1024+重算 ~129K(重算税);MFU **0.4%→17%**(每次提升都对应一块 perf 基建,详见 known-issues + MFU 分析)。T13/T14/T16 是三条**显存杠杆**(重计算压激活峰值、flash 不物化 N×N attention scores、梯度累积解耦有效 batch 与激活显存),可叠加放大有效 batch。
|
||||||
- **数据集**:TinyStories 3MB 切片 → 全量 TinyStories(epoch 0.01→5.33,**至饱和**)→ **v6 毕业到 FineWeb-edu 真实网页**(2.255B 语料,1.02ep)→ **v7 同子集多 epoch(1.45ep,近顶)→ v8 同子集换大模型**(dim1024,1.05ep)。tokenizer 全程 gpt2 BPE(复用 xserv-tokenizer;v6 刻意不换 tokenizer 以隔离「数据来源」变量,KI-4 留后续版本)。
|
- **数据集**:TinyStories 3MB 切片 → 全量 TinyStories(epoch 0.01→5.33,**至饱和**)→ **v6 毕业到 FineWeb-edu 真实网页**(2.255B 语料,1.02ep)→ **v7 同子集多 epoch(1.45ep,近顶)→ v8 同子集换大模型**(dim1024,1.05ep)。tokenizer 全程 gpt2 BPE(复用 xserv-tokenizer;v6 刻意不换 tokenizer 以隔离「数据来源」变量,KI-4 留后续版本)。
|
||||||
- **v5→v6 数据轴的质变**:v0–v5 都吃合成幼儿故事(TinyStories,低熵、词汇受控),v5 证明同尺寸模型在它上面已饱和;v6 第一版换成**真实教育类网页文本**(FineWeb-edu),语言种类发生质变——采样从「只会写小故事」变成「能写历史/科学/说明文」。
|
- **v5→v6 数据轴的质变**:v0–v5 都吃合成幼儿故事(TinyStories,低熵、词汇受控),v5 证明同尺寸模型在它上面已饱和;v6 第一版换成**真实教育类网页文本**(FineWeb-edu),语言种类发生质变——采样从「只会写小故事」变成「能写历史/科学/说明文」。
|
||||||
- ⚠️ **同子集多 epoch 也有天花板(v6→v7)**:v6 的 FineWeb val 才训 1.02ep、末步仍单调降,曾被读作「还没喂够」;v7 把**同一 2.255B 子集**喂到 1.45ep(多 ~1B token),FineWeb val 仅 ↓0.05(3.07→3.01)且 ~step44000 后走平、采样无质变 ⇒ **该子集在 dim768 已近天花板**。这与 v5 的 TinyStories 数据量饱和是**同一类现象**:**「重复喂老数据」边际都薄,无论是 v5 的同语料多 epoch 还是 v7 的同子集多 epoch**。真正抬天花板的是 v6「换更广的新语料」那一步——**杠杆在「更多样的新 token」,不在「同数据多读几遍」**。后续要继续降 val,必须补**新 FineWeb shards**(更多样、不重复),不是同子集加 epoch。
|
- ⚠️ **同子集多 epoch 也有天花板(v6→v7)**:v6 的 FineWeb val 才训 1.02ep、末步仍单调降,曾被读作「还没喂够」;v7 把**同一 2.255B 子集**喂到 1.45ep(多 ~1B token),FineWeb val 仅 ↓0.05(3.07→3.01)且 ~step44000 后走平、采样无质变 ⇒ **该子集在 dim768 已近天花板**。这与 v5 的 TinyStories 数据量饱和是**同一类现象**:**「重复喂老数据」边际都薄,无论是 v5 的同语料多 epoch 还是 v7 的同子集多 epoch**。真正抬天花板的是 v6「换更广的新语料」那一步——**杠杆在「更多样的新 token」,不在「同数据多读几遍」**。后续要继续降 val,必须补**新 FineWeb shards**(更多样、不重复),不是同子集加 epoch。
|
||||||
|
|||||||
Reference in New Issue
Block a user