train: real batched step (drop loop+SUM)

Feed a real batch of B sequences as ONE batched forward/backward, replacing the
"loop B times + let the tape SUM grads + clip ×1/B" hack. CE mean over B*S rows
is already the batch-mean loss, so backward yields the batch-mean gradient
directly → clip pre-scale = 1.0.

DDP stays equivalent: each rank runs one batched forward over its b_local =
B_global/world sequences (local-mean grad Σ_local/b_local); all_reduce_average
(sum across ranks /world) = Σ_global/B_global = global batch-mean → clip
pre-scale 1.0. The ddp_correctness single-GPU baseline batches the same way.
DDP loss matches single-GPU 5.7e-7, cross-rank params bit-identical (0.0).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-16 00:44:33 +08:00
parent 5353b38402
commit 25b032445d
3 changed files with 65 additions and 51 deletions

View File

@@ -17,7 +17,7 @@ use std::thread;
use std::time::Instant;
use xtrain_autodiff::tape::Var;
use xtrain_model::{Config, TinyTransformer, ids_tensor};
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
use xtrain_optim::GpuAdamW;
use xtrain_tensor::Device;
use xtrain_train::checkpoint;
@@ -91,10 +91,11 @@ pub fn train_rank(
let mut losses = Vec::with_capacity(cfg.steps);
let mut evals = Vec::new();
let mut best_val: Option<f32> = None;
// Each rank reaches the global batch mean as (Σ_global / world) · (1/b_local),
// where b_local = batch_size / world (see DdpContext::all_reduce_average_grads).
// Each rank runs ONE batched forward over its b_local = batch_size/world
// sequences → backward grad = local mean (Σ_local / b_local). all_reduce_average
// (sum across ranks, /world) then gives Σ_global/(world·b_local) = Σ_global/
// B_global — already the global-batch mean — so the clip pre-scale is 1.0.
let batch_local = cfg.batch_size / ctx.world;
let inv_batch_local = 1.0 / batch_local as f32;
let start = Instant::now();
let mut tokens_seen: u64 = 0;
// Rank 0 owns the held-out eval + best-val checkpoint (params are identical
@@ -105,31 +106,36 @@ pub fn train_rank(
let lr = cfg.schedule.lr(step);
// Draw the whole global batch from the shared RNG (same on every rank);
// run forward+backward only on this rank's shard. The tape SUMs the
// shard's grads; the union of shards == the single-GPU batch.
let mut local_loss_sum = 0.0f32;
// collect only this rank's shard (global index % world == rank) and run it
// as ONE batched forward/backward. The union of shards == the single-GPU
// batch; each rank's backward yields its local mean (Σ_local / b_local).
let mut inputs = Vec::with_capacity(batch_local);
let mut targets_v = Vec::with_capacity(batch_local);
for i in 0..cfg.batch_size {
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
if i % ctx.world != ctx.rank {
continue; // not this rank's sequence
if i % ctx.world == ctx.rank {
inputs.push(input);
targets_v.push(target);
}
let ids = ids_tensor(&input, device);
let targets = ids_tensor(&target, device);
let loss = model.loss(&ids, &targets);
local_loss_sum += read_scalar(&loss);
}
let ids = batched_ids_tensor(&inputs, device);
let targets = batched_ids_tensor(&targets_v, device);
let loss = model.loss_batched(&ids, &targets, batch_local);
let local_mean = read_scalar(&loss); // Σ_local / b_local
loss.backward();
tokens_seen += cfg.seq_len as u64;
}
tokens_seen += (batch_local * cfg.seq_len) as u64;
// AllReduce(sum) + /world the grads → every rank holds Σ_global/world.
// AllReduce(sum) + /world the grads → every rank holds Σ_global/B_global
// (local means summed over ranks, /world = global mean). See note above.
ctx.all_reduce_average_grads(&params);
// The reported loss is the global mean: average local sums across ranks.
let step_loss = all_reduce_loss(ctx, local_loss_sum) / cfg.batch_size as f32;
// Reported loss = global mean: sum the per-rank local sums (= mean·b_local)
// across ranks, /B_global. With equal b_local this is mean over ranks.
let step_loss =
all_reduce_loss(ctx, local_mean * batch_local as f32) / cfg.batch_size as f32;
losses.push(step_loss);
// clip pre_scale = 1/b_local finishes the average to Σ_global/B_global,
// identical to the single-GPU clip(pre_scale = 1/B_global).
let gnorm = clip_grad_norm_gpu(&params, cfg.max_grad_norm, inv_batch_local);
// Grads are already the global-batch mean — just clip (pre-scale 1.0).
let gnorm = clip_grad_norm_gpu(&params, cfg.max_grad_norm, 1.0);
opt.step(lr, &params);
for p in &params {
p.zero_grad();

View File

@@ -13,7 +13,7 @@ use std::time::Instant;
use xtrain_cuda::device;
use xtrain_distributed::{DdpConfig, DdpContext, build_model, get_unique_id, launch, train_rank};
use xtrain_model::{Config, ids_tensor};
use xtrain_model::{Config, batched_ids_tensor};
use xtrain_optim::GpuAdamW;
use xtrain_tensor::Device;
use xtrain_train::clip::clip_grad_norm_gpu;
@@ -47,22 +47,25 @@ fn run_single_gpu(cfg: Config, corpus: &Corpus, dcfg: &DdpConfig) -> (Vec<f32>,
let params = model.params();
let mut opt = GpuAdamW::new(dcfg.weight_decay);
let mut rng = dcfg.seed;
let inv_batch = 1.0 / dcfg.batch_size as f32;
let mut losses = Vec::new();
for step in 0..dcfg.steps {
let lr = dcfg.schedule.lr(step);
let mut loss_sum = 0.0f32;
// Sample the whole global batch and run it as ONE batched forward/backward
// (matches the T10 DDP path: backward yields the global-batch mean grad).
let mut inputs = Vec::with_capacity(dcfg.batch_size);
let mut targets_v = Vec::with_capacity(dcfg.batch_size);
for _ in 0..dcfg.batch_size {
let (input, target) = corpus.sample(dcfg.seq_len, &mut rng);
let ids = ids_tensor(&input, device);
let targets = ids_tensor(&target, device);
let loss = model.loss(&ids, &targets);
loss_sum += loss.value().to_device(Device::Cpu).as_slice::<f32>()[0];
loss.backward();
inputs.push(input);
targets_v.push(target);
}
losses.push(loss_sum * inv_batch);
clip_grad_norm_gpu(&params, dcfg.max_grad_norm, inv_batch);
let ids = batched_ids_tensor(&inputs, device);
let targets = batched_ids_tensor(&targets_v, device);
let loss = model.loss_batched(&ids, &targets, dcfg.batch_size);
losses.push(loss.value().to_device(Device::Cpu).as_slice::<f32>()[0]);
loss.backward();
clip_grad_norm_gpu(&params, dcfg.max_grad_norm, 1.0);
opt.step(lr, &params);
for p in &params {
p.zero_grad();

View File

@@ -1,18 +1,20 @@
//! The training loop: sample sequences → forward `loss` → backward → grad clip
//! (with batch averaging) → AdamW step → zero grads; with an LR schedule,
//! periodic loss logging, and periodic checkpointing.
//! The training loop: sample a batch of sequences → ONE batched forward `loss` →
//! backward → grad clip → AdamW step → zero grads; with an LR schedule, periodic
//! loss logging, and periodic checkpointing.
//!
//! The T5 model is single-sequence, so a "batch" of `batch_size` sequences is
//! handled by running forward+backward on each and letting the tape SUM their
//! grads (its fan-out rule); the clip pass then multiplies by `1/batch_size` to
//! recover the batch-mean gradient before clipping + the optimizer step.
//! Since T10 the model is batched (`loss_batched`): `batch_size` sequences are
//! flattened to `[batch*seq]` and run as a SINGLE forward/backward, so the linear
//! projections become big `[batch*seq, dim]` GEMMs that fill the GPU. The
//! cross-entropy mean is over all `batch*seq` rows — already the batch-mean loss,
//! so backward yields the batch-mean gradient directly (clip pre-scale = 1.0; no
//! more "loop B times + SUM + ×1/batch" hack).
#![cfg(not(no_cuda))]
use std::path::PathBuf;
use std::time::Instant;
use xtrain_model::{TinyTransformer, ids_tensor};
use xtrain_model::{TinyTransformer, batched_ids_tensor, ids_tensor};
use xtrain_optim::GpuAdamW;
use xtrain_tensor::Device;
@@ -67,7 +69,6 @@ pub fn train(
let mut losses = Vec::with_capacity(cfg.steps);
let mut evals = Vec::new();
let mut best_val: Option<f32> = None;
let inv_batch = 1.0 / cfg.batch_size as f32;
let start = Instant::now();
let mut tokens_seen: u64 = 0;
// Best-val checkpointing only kicks in when we actually evaluate.
@@ -76,22 +77,26 @@ pub fn train(
for step in 0..cfg.steps {
let lr = cfg.schedule.lr(step);
// Accumulate grads over `batch_size` sequences (tape SUMs them).
let mut step_loss = 0.0f32;
// Sample `batch_size` sequences and run them as ONE batched forward/
// backward. The CE mean over all batch*seq rows is the batch-mean loss, so
// backward already yields the batch-mean gradient (clip pre-scale = 1.0).
let mut inputs = Vec::with_capacity(cfg.batch_size);
let mut targets_v = Vec::with_capacity(cfg.batch_size);
for _ in 0..cfg.batch_size {
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
let ids = ids_tensor(&input, device);
let targets = ids_tensor(&target, device);
let loss = model.loss(&ids, &targets);
step_loss += read_scalar(&loss);
loss.backward();
tokens_seen += cfg.seq_len as u64;
inputs.push(input);
targets_v.push(target);
}
step_loss *= inv_batch;
let ids = batched_ids_tensor(&inputs, device);
let targets = batched_ids_tensor(&targets_v, device);
let loss = model.loss_batched(&ids, &targets, cfg.batch_size);
let step_loss = read_scalar(&loss);
loss.backward();
tokens_seen += (cfg.batch_size * cfg.seq_len) as u64;
losses.push(step_loss);
// Average the summed grads (×1/batch) and clip to the global norm.
let gnorm = clip_grad_norm_gpu(&params, cfg.max_grad_norm, inv_batch);
// Backward already produced the batch-mean gradient — just clip it.
let gnorm = clip_grad_norm_gpu(&params, cfg.max_grad_norm, 1.0);
opt.step(lr, &params);
for p in &params {
p.zero_grad();