Files
xtrain/crates/xtrain-model/tests/batched.rs
Gahow Wang 5353b38402 model: batched forward [B,S]
forward_batched(ids[B*S], batch)/loss_batched: run B equal-length sequences as
ONE forward over flattened [B*S] ids, so every linear is one big [B*S,dim] GEMM.
Attention reshapes to [B*nh,S,hd], runs the fused batched causal SDPA (per-seq
mask + RoPE period=S, no cross-sequence attention), writes back [B*S,dim]. The
old per-(batch,head) loop + host-round-tripping split/merge_heads + the additive
causal_mask leaf are gone. forward(ids[seq]) is now forward_batched(ids,1), so
the sampler / inference path (batch=1) is unchanged.

+batched_ids_tensor helper. New batched.rs test: batched forward == looped
single-sequence (logits identical 0.0, grads 6.4e-4, loss identical). PyTorch
parity now exercises B>1 (B=2,S=4): loss 5e-8, logits 6.9e-6, all 25 param
grads within rtol — verifying per-seq RoPE position + per-seq causal masking.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 00:44:25 +08:00

143 lines
5.1 KiB
Rust
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// T10 batched-forward equivalence: a batched forward over B sequences must equal
// the old single-sequence path (run each sequence on its own, concatenate the
// logits) — both for the forward logits AND every parameter's gradient.
//
// This is THE on-GPU correctness gate for batching (no PyTorch needed): if the
// per-sequence RoPE position, per-sequence causal masking, or any flattened op
// were wrong, the batched logits/grads would drift from the looped reference.
//
// Forward equivalence: batched logits[b*S+i] == single-seq-b logits[i].
// Gradient equivalence: the batched loss is the mean over all B*S rows, i.e.
// (1/B)·Σ_b mean_i(loss_b), which equals the sum of the B single-sequence
// losses scaled by 1/B. So backpropagating each single-sequence loss in turn
// (the grads accumulate into each parameter across the B backward calls) and
// scaling the accumulated grads by 1/B must reproduce the batched grads. We
// check that.
#![cfg(not(no_cuda))]
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor, ids_tensor};
use xtrain_tensor::Device;
/// Deterministic pseudo-random buffer of `n` values spread over `[-scale, scale]`.
///
/// The seed is first scrambled, then advanced by a 64-bit LCG (Knuth's MMIX
/// constants) once per element. The top 31 bits of each state are mapped to
/// `[0, 1]` and then affinely to `[-scale, scale]`. The same `(n, seed, scale)`
/// always gives the same output, and a shorter call is a prefix of a longer one.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_INC: u64 = 1442695040888963407;
    let denom = (1u64 << 31) as f32;

    let mut lcg = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
        // High 31 bits -> [0, 1]; the low LCG bits have short periods.
        let unit = (lcg >> 33) as f32 / denom;
        out.push((unit - 0.5) * 2.0 * scale);
    }
    out
}
/// Builds a `TinyTransformer` with deterministic weights.
///
/// The init callback hands out seeds 2, 3, 4, … in the order it is called.
/// Two `build` calls with the same `cfg` therefore produce identical
/// parameters, assuming `TinyTransformer::new` requests tensors in a fixed
/// order (TODO confirm).
///
/// 1-D tensors get `1.0 + fill(.., 0.02)`, which keeps them near one
/// (presumably norm gains). All other tensors get `fill(.., 0.08)`.
fn build(cfg: Config, device: Device) -> TinyTransformer {
    let mut counter = 1u64;
    TinyTransformer::new(cfg, device, |shape| {
        counter = counter.wrapping_add(1);
        let numel: usize = shape.iter().product();
        match shape.len() {
            1 => fill(numel, counter, 0.02)
                .into_iter()
                .map(|x| x + 1.0)
                .collect(),
            _ => fill(numel, counter, 0.08),
        }
    })
}
/// Copies a tensor to the CPU and returns its contents as an owned `Vec<f32>`.
///
/// Reads the buffer via `as_slice::<f32>()`, so the tensor is expected to
/// hold f32 data.
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
    let on_cpu = t.to_device(Device::Cpu);
    on_cpu.as_slice::<f32>().to_vec()
}
/// Checks that a batched forward/backward over `[B*S]` ids matches running
/// each sequence on its own, in three ways:
/// - logits match row-for-row;
/// - the batched loss equals the mean of the per-sequence losses;
/// - every parameter gradient equals the accumulated per-sequence gradient
///   divided by B.
///
/// All comparisons are length-checked and NaN-safe (see `max_rel_err`), so a
/// truncated or NaN-producing batched path cannot pass.
#[test]
fn batched_matches_looped_single_sequence() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);
    let mut cfg = Config::tiny();
    cfg.vocab = 16;
    let batch = 3usize;
    let seq = 5usize;
    // B distinct sequences (sequence-major), within vocab.
    let seqs: Vec<Vec<i32>> = (0..batch)
        .map(|b| {
            (0..seq)
                .map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
                .collect()
        })
        .collect();
    let tgts: Vec<Vec<i32>> = (0..batch)
        .map(|b| {
            (0..seq)
                .map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
                .collect()
        })
        .collect();

    // --- Batched forward: ONE pass over [B*S]. ---
    let bmodel = build(cfg, device);
    let bids = batched_ids_tensor(&seqs, device);
    let blogits = host(&bmodel.forward_batched(&bids, batch).value());

    // --- Looped reference: each sequence on its own, concatenate logits. ---
    let smodel = build(cfg, device);
    let mut slogits = Vec::with_capacity(batch * seq * cfg.vocab);
    for s in &seqs {
        let ids = ids_tensor(s, device);
        slogits.extend(host(&smodel.forward(&ids).value()));
    }

    // Forward equivalence (fp GEMM rounding only differs in summation order).
    // An empty reference would make the comparison pass vacuously.
    assert!(!slogits.is_empty(), "looped forward produced no logits");
    let max_rel = max_rel_err(&blogits, &slogits, 1e-4, "logits");
    println!("batched vs looped: logits max rel err = {max_rel:.3e}");
    assert!(max_rel < 1e-3, "batched logits diverged: {max_rel:.3e}");

    // --- Gradient equivalence. ---
    // Batched: loss = mean over B*S rows; one backward.
    let bparams = bmodel.params();
    let btgt = batched_ids_tensor(&tgts, device);
    let bloss = bmodel.loss_batched(&bids, &btgt, batch);
    let bloss_val = host(&bloss.value())[0];
    bloss.backward();

    // Looped: one backward per sequence; grads accumulate across the calls, so
    // each param ends up holding the SUM over B sequences.
    let sparams = smodel.params();
    let mut sloss_sum = 0.0f32;
    for (s, t) in seqs.iter().zip(&tgts) {
        let ids = ids_tensor(s, device);
        let tg = ids_tensor(t, device);
        let l = smodel.loss(&ids, &tg);
        sloss_sum += host(&l.value())[0];
        l.backward();
    }
    let sloss_mean = sloss_sum / batch as f32;
    println!("batched loss = {bloss_val:.6} looped mean = {sloss_mean:.6}");
    // A NaN loss makes the `<` comparison false, so this assert fails on NaN.
    assert!(
        (bloss_val - sloss_mean).abs() < 1e-4,
        "batched loss != looped mean"
    );

    // Same architecture => same parameter list; a mismatch would make the
    // zip below silently skip parameters.
    assert_eq!(
        bparams.len(),
        sparams.len(),
        "batched and looped models expose different parameter counts"
    );
    assert!(!bparams.is_empty(), "model exposes no parameters");

    let mut max_grad_rel = 0.0f32;
    for (i, (bp, sp)) in bparams.iter().zip(&sparams).enumerate() {
        let bg = host(&bp.grad().expect("batched grad"));
        // Looped grad is the SUM over B sequences; dividing by B recovers the mean.
        let sg: Vec<f32> = host(&sp.grad().expect("looped grad"))
            .iter()
            .map(|g| g / batch as f32)
            .collect();
        let rel = max_rel_err(&bg, &sg, 1e-4, &format!("param {i} grad"));
        max_grad_rel = max_grad_rel.max(rel);
    }
    println!("batched vs looped: grad max rel err = {max_grad_rel:.3e}");
    assert!(
        max_grad_rel < 5e-3,
        "batched grads diverged: {max_grad_rel:.3e}"
    );
}

/// Returns the maximum element-wise relative error `|a - e| / max(|e|, floor)`
/// between two equally sized buffers.
///
/// A NaN error (a NaN input, or infinities that cancel) is reported as
/// `f32::INFINITY`. A plain `fold(0.0, f32::max)` would silently drop it,
/// because `f32::max` returns the non-NaN operand.
///
/// # Panics
/// Panics if the buffers differ in length. `zip` would otherwise compare only
/// the common prefix. `what` labels the panic message.
fn max_rel_err(actual: &[f32], expected: &[f32], floor: f32, what: &str) -> f32 {
    assert_eq!(
        actual.len(),
        expected.len(),
        "{what}: batched/looped length mismatch"
    );
    actual
        .iter()
        .zip(expected)
        .map(|(a, e)| {
            let rel = (a - e).abs() / e.abs().max(floor);
            if rel.is_nan() {
                f32::INFINITY
            } else {
                rel
            }
        })
        .fold(0.0f32, f32::max)
}