forward_batched(ids[B*S], batch)/loss_batched: run B equal-length sequences as ONE forward over flattened [B*S] ids, so every linear is one big [B*S,dim] GEMM. Attention reshapes to [B*nh,S,hd], runs the fused batched causal SDPA (per-seq mask + RoPE period=S, no cross-sequence attention), writes back [B*S,dim]. The old per-(batch,head) loop + host-round-tripping split/merge_heads + the additive causal_mask leaf are gone. forward(ids[seq]) is now forward_batched(ids,1), so the sampler / inference path (batch=1) is unchanged. +batched_ids_tensor helper. New batched.rs test: batched forward == looped single-sequence (logits identical 0.0, grads 6.4e-4, loss identical). PyTorch parity now exercises B>1 (B=2,S=4): loss 5e-8, logits 6.9e-6, all 25 param grads within rtol — verifying per-seq RoPE position + per-seq causal masking. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
143 lines
5.1 KiB
Rust
143 lines
5.1 KiB
Rust
// T10 batched-forward equivalence: a batched forward over B sequences must equal
// the old single-sequence path (run each sequence on its own, concatenate the
// logits) — both for the forward logits AND every parameter's gradient.
//
// This is THE on-GPU correctness gate for batching (no PyTorch needed): if the
// per-sequence RoPE position, per-sequence causal masking, or any flattened op
// were wrong, the batched logits/grads would drift from the looped reference.
//
// Forward equivalence: batched logits[b*S+i] == single-seq-b logits[i].
// Gradient equivalence: the batched loss is the mean over all B*S rows, i.e.
// (1/B)·Σ_b mean_i(loss_b); summing the B single-sequence losses and scaling by
// 1/B gives the SAME scalar, so their summed grads (tape fan-out) ×1/B match the
// batched grads. We check that.
|
||
#![cfg(not(no_cuda))]
|
||
|
||
use xtrain_cuda::device;
|
||
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor, ids_tensor};
|
||
use xtrain_tensor::Device;
|
||
|
||
/// Generates `n` deterministic pseudo-random `f32` values in `[-scale, scale]`.
///
/// The generator is a 64-bit linear congruential recurrence: `seed` is first
/// scrambled with one multiply-add, then the state is advanced once per output
/// element. The same `(seed, scale)` always yields the same sequence, and a
/// shorter request is a prefix of a longer one.
///
/// Returns an empty vector when `n == 0`.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // Multiplier / increment used once to derive the initial state from `seed`.
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    // Multiplier / increment of the per-element state update.
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;

    let mut state = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    (0..n)
        .map(|_| {
            state = state.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
            // Keep the top 31 bits of the state and map them onto the unit
            // interval, then recentre around zero and stretch to +/- scale.
            let unit = (state >> 33) as f32 / (1u64 << 31) as f32;
            (unit - 0.5) * 2.0 * scale
        })
        .collect()
}
|
||
|
||
fn build(cfg: Config, device: Device) -> TinyTransformer {
|
||
let mut seed = 1u64;
|
||
TinyTransformer::new(cfg, device, |shape| {
|
||
seed = seed.wrapping_add(1);
|
||
let n: usize = shape.iter().product();
|
||
if shape.len() == 1 {
|
||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||
} else {
|
||
fill(n, seed, 0.08)
|
||
}
|
||
})
|
||
}
|
||
|
||
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||
t.to_device(Device::Cpu).as_slice::<f32>().to_vec()
|
||
}
|
||
|
||
/// Returns the largest element-wise relative error of `got` against `want`.
///
/// The denominator is `|want|` floored at 1e-4 so near-zero reference values
/// do not blow the ratio up. `what` only labels the panic messages.
///
/// Panics (failing the calling test) when the slices differ in length or hold
/// a non-finite value. Without these checks `zip` would stop at the shorter
/// slice and `f32::max` would discard NaN operands, so either problem would
/// silently read as "no error".
fn max_rel_err(what: &str, got: &[f32], want: &[f32]) -> f32 {
    assert_eq!(got.len(), want.len(), "{what}: element count mismatch");
    got.iter()
        .zip(want)
        .map(|(g, w)| {
            assert!(
                g.is_finite() && w.is_finite(),
                "{what}: non-finite value (batched {g}, looped {w})"
            );
            (g - w).abs() / w.abs().max(1e-4)
        })
        .fold(0.0f32, f32::max)
}

/// One batched forward/backward over B sequences must match running each
/// sequence on its own: logits concatenated in sequence order, the mean of the
/// per-sequence losses, and the accumulated per-sequence gradients scaled by
/// 1/B.
#[test]
fn batched_matches_looped_single_sequence() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let mut cfg = Config::tiny();
    cfg.vocab = 16;
    let batch = 3usize;
    let seq = 5usize;
    // B distinct sequences (sequence-major), within vocab.
    let seqs: Vec<Vec<i32>> = (0..batch)
        .map(|b| {
            (0..seq)
                .map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
                .collect()
        })
        .collect();
    let tgts: Vec<Vec<i32>> = (0..batch)
        .map(|b| {
            (0..seq)
                .map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
                .collect()
        })
        .collect();

    // --- Batched forward: ONE pass over [B*S]. ---
    let bmodel = build(cfg, device);
    let bids = batched_ids_tensor(&seqs, device);
    let blogits = host(&bmodel.forward_batched(&bids, batch).value());

    // --- Looped reference: each sequence on its own, concatenate logits. ---
    let smodel = build(cfg, device);
    let mut slogits = Vec::with_capacity(batch * seq * cfg.vocab);
    for s in &seqs {
        let ids = ids_tensor(s, device);
        slogits.extend(host(&smodel.forward(&ids).value()));
    }

    // Forward equivalence (fp GEMM rounding only differs in summation order).
    // The helper also rejects a length mismatch and NaN/inf logits.
    let max_rel = max_rel_err("logits", &blogits, &slogits);
    println!("batched vs looped: logits max rel err = {max_rel:.3e}");
    assert!(max_rel < 1e-3, "batched logits diverged: {max_rel:.3e}");

    // --- Gradient equivalence. ---
    // Batched: loss = mean over B*S rows; one backward.
    let bparams = bmodel.params();
    let btgt = batched_ids_tensor(&tgts, device);
    let bloss = bmodel.loss_batched(&bids, &btgt, batch);
    let bloss_val = host(&bloss.value())[0];
    bloss.backward();

    // Looped: Σ_b loss_b (each a per-sequence mean), then grad ×(1/B) == batched.
    let sparams = smodel.params();
    let mut sloss_sum = 0.0f32;
    for (s, t) in seqs.iter().zip(&tgts) {
        let ids = ids_tensor(s, device);
        let tg = ids_tensor(t, device);
        let l = smodel.loss(&ids, &tg);
        sloss_sum += host(&l.value())[0];
        l.backward();
    }
    let sloss_mean = sloss_sum / batch as f32;
    println!("batched loss = {bloss_val:.6} looped mean = {sloss_mean:.6}");
    // A NaN loss fails this comparison as well, since `NaN < x` is false.
    assert!(
        (bloss_val - sloss_mean).abs() < 1e-4,
        "batched loss != looped mean"
    );

    // Both models must expose the same parameters, otherwise the zip below
    // would silently skip the unmatched tail.
    assert_eq!(bparams.len(), sparams.len(), "parameter count mismatch");
    let mut max_grad_rel = 0.0f32;
    for (bp, sp) in bparams.iter().zip(&sparams) {
        let bg = host(&bp.grad().expect("batched grad"));
        // looped grad is the SUM over B sequences; ×(1/B) recovers the mean.
        let sg: Vec<f32> = host(&sp.grad().expect("looped grad"))
            .iter()
            .map(|g| g / batch as f32)
            .collect();
        max_grad_rel = max_grad_rel.max(max_rel_err("grad", &bg, &sg));
    }
    println!("batched vs looped: grad max rel err = {max_grad_rel:.3e}");
    assert!(
        max_grad_rel < 5e-3,
        "batched grads diverged: {max_grad_rel:.3e}"
    );
}
|