model: batched forward [B,S]

forward_batched(ids[B*S], batch)/loss_batched: run B equal-length sequences as
ONE forward over flattened [B*S] ids, so every linear is one big [B*S,dim] GEMM.
Attention reshapes to [B*nh,S,hd], runs the fused batched causal SDPA (per-seq
mask + RoPE period=S, no cross-sequence attention), writes back [B*S,dim]. The
old per-(batch,head) loop + host-round-tripping split/merge_heads + the additive
causal_mask leaf are gone. forward(ids[seq]) is now forward_batched(ids,1), so
the sampler / inference path (batch=1) is unchanged.

Adds a batched_ids_tensor helper. New batched.rs test: batched forward == looped
single-sequence (logits max rel err 0.0, grads max rel err 6.4e-4, loss identical).
PyTorch parity now exercises B>1 (B=2,S=4): loss diff 5e-8, logits diff 6.9e-6, all
25 param grads within rtol — verifying per-seq RoPE position + per-seq causal masking.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-16 00:44:25 +08:00
parent 7821bd9c34
commit 5353b38402
5 changed files with 265 additions and 78 deletions

View File

@@ -24,4 +24,4 @@ pub use config::Config;
#[cfg(not(no_cuda))]
mod model;
#[cfg(not(no_cuda))]
pub use model::{TinyTransformer, ids_tensor, param_to_host};
pub use model::{TinyTransformer, batched_ids_tensor, ids_tensor, param_to_host};

View File

@@ -30,7 +30,6 @@ pub struct TinyTransformer {
blocks: Vec<Block>,
final_norm: Var, // [dim]
lm_head: Var, // [dim, vocab]
device: Device,
}
impl TinyTransformer {
@@ -72,7 +71,6 @@ impl TinyTransformer {
blocks,
final_norm,
lm_head,
device,
}
}
@@ -106,16 +104,34 @@ impl TinyTransformer {
}
/// Forward over a single sequence of token `ids` (`[seq]` I32 on this
/// model's device). Returns the logits [`Var`] of shape `[seq, vocab]`.
/// model's device). Returns the logits [`Var`] of shape `[seq, vocab]`. This
/// is the batch-1 special case of [`forward_batched`](Self::forward_batched)
/// (used by the autoregressive sampler / inference path).
pub fn forward(&self, ids: &Tensor) -> Var {
let seq = ids.shape()[0];
let mask = self.causal_mask(seq);
self.forward_batched(ids, 1)
}
let mut h = ops::embedding(&self.embed, ids); // [seq, dim]
/// Batched forward over `batch` sequences of equal length `seq`, flattened to
/// `[batch*seq]` I32 ids in sequence-major order (sequence 0's `seq` tokens,
/// then sequence 1's, …). Returns logits `[batch*seq, vocab]` in the SAME flat
/// layout. The whole graph runs on the flattened tokens so every linear
/// projection is ONE big `[batch*seq, dim] × [dim, out]` GEMM (the
/// GPU-filling win); only attention is sequence-aware (per-sequence causal
/// mask + RoPE position, NO cross-sequence attention).
pub fn forward_batched(&self, ids: &Tensor, batch: usize) -> Var {
let total = ids.shape()[0];
assert_eq!(
total % batch,
0,
"ids len {total} not divisible by batch {batch}"
);
let seq = total / batch;
let mut h = ops::embedding(&self.embed, ids); // [batch*seq, dim]
for b in &self.blocks {
// --- Attention sub-block (pre-norm + residual) ---
let normed = ops::rms_norm(&h, &b.attn_norm, self.cfg.eps);
let attn = self.attention(b, &normed, &mask, seq);
let attn = self.attention(b, &normed, batch, seq);
h = ops::add(&h, &attn);
// --- MLP sub-block (pre-norm + residual) ---
@@ -125,7 +141,7 @@ impl TinyTransformer {
}
let h = ops::rms_norm(&h, &self.final_norm, self.cfg.eps);
ops::matmul(&h, &self.lm_head) // [seq, vocab]
ops::matmul(&h, &self.lm_head) // [batch*seq, vocab]
}
/// Cross-entropy mean loss of `forward(ids)` against `targets` (`[seq]` I32).
@@ -134,76 +150,76 @@ impl TinyTransformer {
ops::cross_entropy(&logits, targets)
}
/// Multi-head causal self-attention. `x`:[seq,dim] (already normed).
fn attention(&self, b: &Block, x: &Var, mask: &Var, seq: usize) -> Var {
/// Mean cross-entropy of a flattened batch.
///
/// Runs [`forward_batched`](Self::forward_batched) on `ids` with the given
/// `batch` and scores the resulting logits against `targets` (flat
/// `[batch*seq]` I32, in the same sequence-major layout as `ids`). The mean is
/// taken over all `batch*seq` rows; because every sequence has the same
/// length, that equals the average of the per-sequence mean losses, so the
/// value matches the looped single-sequence path.
pub fn loss_batched(&self, ids: &Tensor, targets: &Tensor, batch: usize) -> Var {
    ops::cross_entropy(&self.forward_batched(ids, batch), targets)
}
/// Multi-head causal self-attention over a flattened batch. `x`:[batch*seq,dim]
/// (already normed), laid out sequence-major. The Q/K/V/O projections are big
/// `[batch*seq, dim]` GEMMs; the scaled-dot-product attention itself runs as a
/// fused BATCHED op over the `batch·n_heads` (sequence,head) blocks — each
/// attends within its own `[seq,seq]` causal window (NO cross-sequence
/// attention), with RoPE positions reset per sequence (`period = seq`). Causal
/// masking is applied inside the fused op's softmax kernel (no additive
/// `[seq,seq]` mask tensor).
fn attention(&self, b: &Block, x: &Var, batch: usize, seq: usize) -> Var {
let (nh, hd) = (self.cfg.n_heads, self.cfg.head_dim);
let total = batch * seq;
let bh = batch * nh;
let scale = 1.0 / (hd as f32).sqrt();
// Project, then lay out as per-head [seq, head_dim] tensors.
// [seq,dim] @ [dim,dim] = [seq,dim]
// reshape [seq, nh, hd]
// Project, qk-norm + RoPE, then lay out as a batched [B*nh, seq, hd] tensor.
// [B*S,dim] @ [dim,dim] = [B*S,dim]
// reshape [B*S, nh, hd]
// qk-norm per-head RMSNorm over hd (Qwen3-style; Q/K only, before RoPE)
// rope (kernel expects exactly [tokens, heads, head_dim])
// transpose [nh, seq, hd] split into nh × [seq, hd]
let to_heads = |proj: Var, norm: Option<&Var>| -> Vec<Var> {
let r = ops::reshape(&proj, &[seq, nh, hd]);
// rope [B*S, nh, hd] with per-sequence position (period = seq)
// reshape [B, S, nh, hd] → transpose(1,2) → [B, nh, S, hd] → [B*nh, S, hd]
let to_bh = |proj: Var, norm: Option<&Var>| -> Var {
let r = ops::reshape(&proj, &[total, nh, hd]);
let r = match norm {
// Per-head RMSNorm: flatten the (seq,nh) head rows, norm over hd,
// Per-head RMSNorm: flatten the (B*S,nh) head rows, norm over hd,
// restore. RoPE follows on the normed Q/K (mirrors xserv qwen3.rs).
Some(gamma) => {
let flat = ops::reshape(&r, &[seq * nh, hd]);
let flat = ops::reshape(&r, &[total * nh, hd]);
let normed = ops::rms_norm(&flat, gamma, self.cfg.eps);
let r = ops::reshape(&normed, &[seq, nh, hd]);
ops::rope(&r, self.cfg.rope_theta)
let r = ops::reshape(&normed, &[total, nh, hd]);
ops::rope(&r, self.cfg.rope_theta, seq)
}
None => r,
};
let t = ops::transpose_3d01(&r); // [nh, seq, hd]
ops::split_heads(&t)
let r = ops::reshape(&r, &[batch, seq, nh, hd]);
let t = ops::transpose_4d12(&r); // [B, nh, S, hd]
ops::reshape(&t, &[bh, seq, hd]) // [B*nh, S, hd]
};
let q = to_heads(ops::matmul(x, &b.wq), Some(&b.q_norm));
let k = to_heads(ops::matmul(x, &b.wk), Some(&b.k_norm));
let v = to_heads(ops::matmul(x, &b.wv), None);
let q = to_bh(ops::matmul(x, &b.wq), Some(&b.q_norm));
let k = to_bh(ops::matmul(x, &b.wk), Some(&b.k_norm));
let v = to_bh(ops::matmul(x, &b.wv), None);
// Per-head scaled-dot-product attention with causal mask.
let heads_out: Vec<Var> = (0..nh)
.map(|i| {
let kt = ops::transpose_2d(&k[i]); // [hd, seq]
let scores = ops::scale(&ops::matmul(&q[i], &kt), scale); // [seq,seq]
let scores = ops::add(&scores, mask); // causal
let probs = ops::softmax(&scores);
ops::matmul(&probs, &v[i]) // [seq, hd]
})
.collect();
// Fused batched causal SDPA over all B*nh (sequence,head) blocks at once
// (2 batched GEMMs + 1 causal-softmax kernel; no per-head/per-seq loop).
let out = ops::attention(&q, &k, &v, scale); // [B*nh, S, hd]
// Stack heads back: nh × [seq,hd] → [nh,seq,hd] → [seq,nh,hd] → [seq,dim].
let merged = ops::merge_heads(&heads_out); // [nh, seq, hd]
let t = ops::transpose_3d01(&merged); // [seq, nh, hd]
let concat = ops::reshape(&t, &[seq, nh * hd]); // [seq, dim]
// Back to [B*S, dim]: [B*nh,S,hd] → [B,nh,S,hd] → transpose(1,2) →
// [B,S,nh,hd] → [B*S, dim].
let out = ops::reshape(&out, &[batch, nh, seq, hd]);
let out = ops::transpose_4d12(&out); // [B, S, nh, hd]
let concat = ops::reshape(&out, &[total, nh * hd]); // [B*S, dim]
ops::matmul(&concat, &b.wo) // out projection
}
/// SwiGLU MLP: `down( silu(gate(x)) ∘ up(x) )`. `x`:[seq,dim].
/// SwiGLU MLP: `down( silu(gate(x)) ∘ up(x) )`. `x`:[batch*seq,dim].
fn swiglu_mlp(&self, b: &Block, x: &Var) -> Var {
    // Two projections of the same input into the hidden width; each result has
    // one row per row of `x` (i.e. per token) and `ffn_hidden` columns.
    let gate_proj = ops::matmul(x, &b.w_gate);
    let up_proj = ops::matmul(x, &b.w_up);
    // Gated activation: silu(gate) ∘ up.
    let hidden = ops::swiglu(&gate_proj, &up_proj);
    // Project back down to the model width (one `dim`-wide row per token).
    ops::matmul(&hidden, &b.w_down)
}
/// Additive causal mask `[seq,seq]`: 0 on/below the diagonal, 1e9 above it
/// (so softmax zeros out future positions). A constant leaf (no grad needed,
/// but harmless if it accumulates one — it has no consumers downstream of x).
fn causal_mask(&self, seq: usize) -> Var {
let mut m = vec![0.0f32; seq * seq];
for i in 0..seq {
for j in (i + 1)..seq {
m[i * seq + j] = -1.0e9;
}
}
Var::leaf(Tensor::from_slice(&m, &[seq, seq]).to_device(self.device))
}
}
/// Materialise a parameter's value back to a host `Vec<f32>` (for the GD step
@@ -216,3 +232,17 @@ pub fn param_to_host(v: &Var) -> Vec<f32> {
pub fn ids_tensor(ids: &[i32], device: Device) -> Tensor {
    // Rank-1 shape: one entry per token id.
    let shape = [ids.len()];
    Tensor::from_slice(ids, &shape).to_device(device)
}
/// Pack `batch` equal-length sequences into a single rank-1 I32 tensor of
/// length `batch*seq`, laid out sequence-major (all of sequence 0, then all of
/// sequence 1, …) — the layout `forward_batched` expects.
///
/// Each element of `seqs` is one sequence.
///
/// # Panics
///
/// Panics with `"empty batch"` if `seqs` is empty, and with a "ragged batch"
/// message if any sequence's length differs from the first one's.
pub fn batched_ids_tensor(seqs: &[Vec<i32>], device: Device) -> Tensor {
    assert!(!seqs.is_empty(), "empty batch");
    let seq = seqs[0].len();
    // Validate every row before flattening so a ragged batch never reaches the device.
    for s in seqs {
        assert_eq!(s.len(), seq, "ragged batch: sequences must be equal length");
    }
    // Row-by-row concatenation is exactly the sequence-major order.
    let flat = seqs.concat();
    Tensor::from_slice(&flat, &[flat.len()]).to_device(device)
}

View File

@@ -0,0 +1,142 @@
// T10 batched-forward equivalence: a batched forward over B sequences must equal
// the old single-sequence path (run each sequence on its own, concatenate the
// logits) — both for the forward logits AND every parameter's gradient.
//
// This is THE on-GPU correctness gate for batching (no PyTorch needed): if the
// per-sequence RoPE position, per-sequence causal masking, or any flattened op
// were wrong, the batched logits/grads would drift from the looped reference.
//
// Forward equivalence: batched logits[b*S+i] == single-seq-b logits[i].
// Gradient equivalence: the batched loss is the mean over all B*S rows, i.e.
// (1/B)·Σ_b mean_i(loss_b); summing the B single-sequence losses and scaling by
// 1/B gives the SAME scalar, so their summed grads (tape fan-out) ×1/B match the
// batched grads. We check that.
#![cfg(not(no_cuda))]
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor, ids_tensor};
use xtrain_tensor::Device;
/// Deterministic pseudo-random fill: returns `n` values in roughly
/// `[-scale, scale]`, fully determined by `seed`.
///
/// The generator is a 64-bit linear congruential recurrence with wrapping
/// arithmetic. The seed is scrambled once, then each output takes the top 31
/// bits of the next state, maps them to the unit interval, centres on zero and
/// multiplies by `scale`.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;

    let mut lcg = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
        // Top 31 bits of the state, normalised by 2^31.
        let unit = (lcg >> 33) as f32 / (1u64 << 31) as f32;
        out.push((unit - 0.5) * 2.0 * scale);
    }
    out
}
/// Construct a `TinyTransformer` with deterministic weights.
///
/// The initialiser closure bumps a running seed on every call, so each
/// parameter gets its own `fill` stream. Rank-1 parameters are set to
/// `1.0 + small noise` (scale 0.02); every other shape gets zero-centred
/// noise (scale 0.08). Two calls with the same `cfg` yield identical weights.
fn build(cfg: Config, device: Device) -> TinyTransformer {
    let mut stream = 1u64;
    TinyTransformer::new(cfg, device, |shape| {
        stream = stream.wrapping_add(1);
        let n: usize = shape.iter().product();
        match shape.len() {
            1 => fill(n, stream, 0.02)
                .into_iter()
                .map(|v| v + 1.0)
                .collect(),
            _ => fill(n, stream, 0.08),
        }
    })
}
/// Copy a tensor to the CPU and return its contents as an owned `Vec<f32>`.
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
    let on_cpu = t.to_device(Device::Cpu);
    on_cpu.as_slice::<f32>().to_vec()
}
/// Batched forward/backward must agree with running each sequence on its own.
///
/// Builds two identically-initialised models, runs one through
/// `forward_batched`/`loss_batched` and the other through per-sequence
/// `forward`/`loss`, then compares logits, the loss value and every
/// parameter's gradient. Each element-wise comparison is preceded by an
/// explicit length check: `zip` stops at the shorter side, so without those
/// checks a shape mismatch would be silently truncated and could let the test
/// pass vacuously.
#[test]
fn batched_matches_looped_single_sequence() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);
    let mut cfg = Config::tiny();
    cfg.vocab = 16;
    let batch = 3usize;
    let seq = 5usize;
    // B distinct sequences (sequence-major), within vocab.
    let seqs: Vec<Vec<i32>> = (0..batch)
        .map(|b| {
            (0..seq)
                .map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
                .collect()
        })
        .collect();
    let tgts: Vec<Vec<i32>> = (0..batch)
        .map(|b| {
            (0..seq)
                .map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
                .collect()
        })
        .collect();
    // --- Batched forward: ONE pass over [B*S]. ---
    let bmodel = build(cfg, device);
    let bids = batched_ids_tensor(&seqs, device);
    let blogits = host(&bmodel.forward_batched(&bids, batch).value());
    // --- Looped reference: each sequence on its own, concatenate logits. ---
    let smodel = build(cfg, device);
    let mut slogits = Vec::with_capacity(batch * seq * cfg.vocab);
    for s in &seqs {
        let ids = ids_tensor(s, device);
        slogits.extend(host(&smodel.forward(&ids).value()));
    }
    // Guard the zip below: differing lengths would otherwise be truncated.
    assert_eq!(
        blogits.len(),
        slogits.len(),
        "batched and looped logits have different element counts"
    );
    // Forward equivalence (fp GEMM rounding only differs in summation order).
    let max_rel = blogits
        .iter()
        .zip(&slogits)
        .map(|(b, s)| (b - s).abs() / s.abs().max(1e-4))
        .fold(0.0f32, f32::max);
    println!("batched vs looped: logits max rel err = {max_rel:.3e}");
    assert!(max_rel < 1e-3, "batched logits diverged: {max_rel:.3e}");
    // --- Gradient equivalence. ---
    // Batched: loss = mean over B*S rows; one backward.
    let bparams = bmodel.params();
    let btgt = batched_ids_tensor(&tgts, device);
    let bloss = bmodel.loss_batched(&bids, &btgt, batch);
    let bloss_val = host(&bloss.value())[0];
    bloss.backward();
    // Looped: Σ_b loss_b (each a per-sequence mean), then grad ×(1/B) == batched.
    let sparams = smodel.params();
    let mut sloss_sum = 0.0f32;
    for (s, t) in seqs.iter().zip(&tgts) {
        let ids = ids_tensor(s, device);
        let tg = ids_tensor(t, device);
        let l = smodel.loss(&ids, &tg);
        sloss_sum += host(&l.value())[0];
        l.backward();
    }
    println!(
        "batched loss = {bloss_val:.6} looped mean = {:.6}",
        sloss_sum / batch as f32
    );
    assert!(
        (bloss_val - sloss_sum / batch as f32).abs() < 1e-4,
        "batched loss != looped mean"
    );
    // Both models come from the same config, so they must expose the same
    // number of parameters; check it so the zip cannot skip any.
    assert_eq!(
        bparams.iter().count(),
        sparams.iter().count(),
        "batched and looped models expose different parameter counts"
    );
    let mut max_grad_rel = 0.0f32;
    for (bp, sp) in bparams.iter().zip(&sparams) {
        let bg = host(&bp.grad().expect("batched grad"));
        let sg = host(&sp.grad().expect("looped grad"));
        assert_eq!(
            bg.len(),
            sg.len(),
            "batched and looped grads have different element counts"
        );
        for (g_b, g_s) in bg.iter().zip(&sg) {
            // looped grad is the SUM over B sequences; ×(1/B) recovers the mean.
            let g_s = g_s / batch as f32;
            let rel = (g_b - g_s).abs() / g_s.abs().max(1e-4);
            max_grad_rel = max_grad_rel.max(rel);
        }
    }
    println!("batched vs looped: grad max rel err = {max_grad_rel:.3e}");
    assert!(
        max_grad_rel < 5e-3,
        "batched grads diverged: {max_grad_rel:.3e}"
    );
}

View File

@@ -55,10 +55,13 @@ NH = int(cfg["n_heads"])
HD = int(cfg["head_dim"])
EPS = float(cfg["eps"])
THETA = float(cfg["rope_theta"])
# Batched: B sequences of length SEQ, flattened sequence-major to [B*SEQ] ids.
B = int(cfg.get("batch", "1"))
SEQ = int(cfg["seq"])
ids = read_ids("ids.txt")
targets = read_ids("targets.txt")
SEQ = len(ids)
assert len(ids) == B * SEQ, f"ids {len(ids)} != B*SEQ {B*SEQ}"
# Load params as leaf tensors requiring grad (float64 for a clean reference).
P = {}
@@ -76,15 +79,16 @@ def rms_norm(x, gamma):
return x * torch.rsqrt(ms + EPS) * gamma
def rope(x): # x: [seq, nh, hd], position = token index, matching the kernel
def rope(x): # x: [B*SEQ, nh, hd], position = (row % SEQ) — resets per sequence
half = HD // 2
out = torch.empty_like(x)
i = torch.arange(half, dtype=torch.float64)
freq = THETA ** (-(2.0 * i) / HD) # [half]
pos = torch.arange(SEQ, dtype=torch.float64).reshape(SEQ, 1) # [seq,1]
ang = pos * freq # [seq, half]
c = torch.cos(ang).reshape(SEQ, 1, half)
s = torch.sin(ang).reshape(SEQ, 1, half)
# Position within each sequence: rows 0..SEQ for seq 0, 0..SEQ for seq 1, ...
pos = (torch.arange(B * SEQ, dtype=torch.float64) % SEQ).reshape(B * SEQ, 1)
ang = pos * freq # [B*SEQ, half]
c = torch.cos(ang).reshape(B * SEQ, 1, half)
s = torch.sin(ang).reshape(B * SEQ, 1, half)
x0 = x[..., :half]
x1 = x[..., half:]
out[..., :half] = x0 * c - x1 * s
@@ -102,26 +106,30 @@ for l in range(NL):
"ffn_norm", "w_gate", "w_up", "w_down"]})
idx = torch.tensor(ids, dtype=torch.long)
# Per-sequence causal mask (broadcast over the batch); NO cross-sequence attention.
mask = torch.triu(torch.full((SEQ, SEQ), -1.0e9, dtype=torch.float64), diagonal=1)
h = emb[idx] # [seq, dim]
h = emb[idx] # [B*SEQ, dim] (everything stays flattened, matching the Rust path)
for L in layers:
# Attention
x = rms_norm(h, L["attn_norm"])
q = (x @ L["wq"]).reshape(SEQ, NH, HD)
k = (x @ L["wk"]).reshape(SEQ, NH, HD)
v = (x @ L["wv"]).reshape(SEQ, NH, HD)
q = (x @ L["wq"]).reshape(B * SEQ, NH, HD)
k = (x @ L["wk"]).reshape(B * SEQ, NH, HD)
v = (x @ L["wv"]).reshape(B * SEQ, NH, HD)
# Per-head QK-norm (Qwen3-style), before RoPE.
q = rms_norm(q, L["q_norm"])
k = rms_norm(k, L["k_norm"])
q = rope(q).transpose(0, 1) # [nh, seq, hd]
k = rope(k).transpose(0, 1)
v = v.transpose(0, 1)
q = rope(q) # [B*SEQ, nh, hd]
k = rope(k)
# Reshape to [B, NH, SEQ, HD] so attention runs within each sequence.
q = q.reshape(B, SEQ, NH, HD).transpose(1, 2) # [B, nh, seq, hd]
k = k.reshape(B, SEQ, NH, HD).transpose(1, 2)
v = v.reshape(B, SEQ, NH, HD).transpose(1, 2)
scale = 1.0 / math.sqrt(HD)
scores = (q @ k.transpose(-1, -2)) * scale + mask # [nh, seq, seq]
scores = (q @ k.transpose(-1, -2)) * scale + mask # [B, nh, seq, seq]
probs = torch.softmax(scores, dim=-1)
out = probs @ v # [nh, seq, hd]
out = out.transpose(0, 1).reshape(SEQ, DIM) # [seq, dim]
out = probs @ v # [B, nh, seq, hd]
out = out.transpose(1, 2).reshape(B * SEQ, DIM) # [B*SEQ, dim]
attn = out @ L["wo"]
h = h + attn
# MLP
@@ -133,7 +141,7 @@ for L in layers:
h = h + mlp
h = rms_norm(h, final_norm)
logits = h @ lm_head # [seq, vocab]
logits = h @ lm_head # [B*SEQ, vocab]
loss = torch.nn.functional.cross_entropy(
logits, torch.tensor(targets, dtype=torch.long), reduction="mean")

View File

@@ -53,12 +53,17 @@ fn dump_for_parity() {
);
fs::create_dir_all(&dir).unwrap();
// Fixed config + ids (independent of any text, for reproducibility).
// Fixed config + ids (independent of any text, for reproducibility). B>1 so
// the batched forward is exercised: 2 sequences of length 4, flattened
// sequence-major to [B*S]=8 ids. Per-sequence RoPE position (resets at the
// sequence boundary) + per-sequence causal masking (no cross-sequence
// attention) are both checked against PyTorch.
let mut cfg = Config::tiny();
cfg.vocab = 12;
let ids: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6];
let batch = 2usize;
let seq = 4usize;
let ids: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6]; // [B*S], sequence-major
let targets: Vec<i32> = vec![1, 4, 1, 5, 9, 2, 6, 0];
let seq = ids.len();
// Same deterministic init as the overfit test.
let mut seed = 1u64;
@@ -83,6 +88,7 @@ fn dump_for_parity() {
writeln!(f, "ffn_hidden {}", cfg.ffn_hidden).unwrap();
writeln!(f, "eps {:e}", cfg.eps).unwrap();
writeln!(f, "rope_theta {:e}", cfg.rope_theta).unwrap();
writeln!(f, "batch {batch}").unwrap();
writeln!(f, "seq {seq}").unwrap();
}
{
@@ -105,10 +111,11 @@ fn dump_for_parity() {
write_vec(&dir, &format!("w_{name}.txt"), &param_to_host(p), &shape);
}
// Forward logits + loss, then backward → per-param grads.
// Batched forward logits + loss (B sequences as one forward), then backward
// → per-param grads.
let ids_t = ids_tensor(&ids, device);
let targets_t = ids_tensor(&targets, device);
let logits = model.forward(&ids_t);
let logits = model.forward_batched(&ids_t, batch);
write_vec(
&dir,
"logits.txt",
@@ -116,7 +123,7 @@ fn dump_for_parity() {
logits.value().shape(),
);
let loss = model.loss(&ids_t, &targets_t);
let loss = model.loss_batched(&ids_t, &targets_t, batch);
let loss_val = param_to_host(&loss)[0];
{
let mut f = fs::File::create(dir.join("loss.txt")).unwrap();