model: batched forward [B,S]
forward_batched(ids[B*S], batch)/loss_batched: run B equal-length sequences as ONE forward over flattened [B*S] ids, so every linear is one big [B*S,dim] GEMM. Attention reshapes to [B*nh,S,hd], runs the fused batched causal SDPA (per-seq mask + RoPE period=S, no cross-sequence attention), writes back [B*S,dim]. The old per-(batch,head) loop + host-round-tripping split/merge_heads + the additive causal_mask leaf are gone. forward(ids[seq]) is now forward_batched(ids,1), so the sampler / inference path (batch=1) is unchanged. +batched_ids_tensor helper. New batched.rs test: batched forward == looped single-sequence (logits identical 0.0, grads 6.4e-4, loss identical). PyTorch parity now exercises B>1 (B=2,S=4): loss 5e-8, logits 6.9e-6, all 25 param grads within rtol — verifying per-seq RoPE position + per-seq causal masking. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -24,4 +24,4 @@ pub use config::Config;
|
||||
#[cfg(not(no_cuda))]
|
||||
mod model;
|
||||
#[cfg(not(no_cuda))]
|
||||
pub use model::{TinyTransformer, ids_tensor, param_to_host};
|
||||
pub use model::{TinyTransformer, batched_ids_tensor, ids_tensor, param_to_host};
|
||||
|
||||
@@ -30,7 +30,6 @@ pub struct TinyTransformer {
|
||||
blocks: Vec<Block>,
|
||||
final_norm: Var, // [dim]
|
||||
lm_head: Var, // [dim, vocab]
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl TinyTransformer {
|
||||
@@ -72,7 +71,6 @@ impl TinyTransformer {
|
||||
blocks,
|
||||
final_norm,
|
||||
lm_head,
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,16 +104,34 @@ impl TinyTransformer {
|
||||
}
|
||||
|
||||
/// Forward over a single sequence of token `ids` (`[seq]` I32 on this
|
||||
/// model's device). Returns the logits [`Var`] of shape `[seq, vocab]`.
|
||||
/// model's device). Returns the logits [`Var`] of shape `[seq, vocab]`. This
|
||||
/// is the batch-1 special case of [`forward_batched`](Self::forward_batched)
|
||||
/// (used by the autoregressive sampler / inference path).
|
||||
pub fn forward(&self, ids: &Tensor) -> Var {
|
||||
let seq = ids.shape()[0];
|
||||
let mask = self.causal_mask(seq);
|
||||
self.forward_batched(ids, 1)
|
||||
}
|
||||
|
||||
let mut h = ops::embedding(&self.embed, ids); // [seq, dim]
|
||||
/// Batched forward over `batch` sequences of equal length `seq`, flattened to
|
||||
/// `[batch*seq]` I32 ids in sequence-major order (sequence 0's `seq` tokens,
|
||||
/// then sequence 1's, …). Returns logits `[batch*seq, vocab]` in the SAME flat
|
||||
/// layout. The whole graph runs on the flattened tokens so every linear
|
||||
/// projection is ONE big `[batch*seq, dim] × [dim, out]` GEMM (the
|
||||
/// GPU-filling win); only attention is sequence-aware (per-sequence causal
|
||||
/// mask + RoPE position, NO cross-sequence attention).
|
||||
pub fn forward_batched(&self, ids: &Tensor, batch: usize) -> Var {
|
||||
let total = ids.shape()[0];
|
||||
assert_eq!(
|
||||
total % batch,
|
||||
0,
|
||||
"ids len {total} not divisible by batch {batch}"
|
||||
);
|
||||
let seq = total / batch;
|
||||
|
||||
let mut h = ops::embedding(&self.embed, ids); // [batch*seq, dim]
|
||||
for b in &self.blocks {
|
||||
// --- Attention sub-block (pre-norm + residual) ---
|
||||
let normed = ops::rms_norm(&h, &b.attn_norm, self.cfg.eps);
|
||||
let attn = self.attention(b, &normed, &mask, seq);
|
||||
let attn = self.attention(b, &normed, batch, seq);
|
||||
h = ops::add(&h, &attn);
|
||||
|
||||
// --- MLP sub-block (pre-norm + residual) ---
|
||||
@@ -125,7 +141,7 @@ impl TinyTransformer {
|
||||
}
|
||||
|
||||
let h = ops::rms_norm(&h, &self.final_norm, self.cfg.eps);
|
||||
ops::matmul(&h, &self.lm_head) // [seq, vocab]
|
||||
ops::matmul(&h, &self.lm_head) // [batch*seq, vocab]
|
||||
}
|
||||
|
||||
/// Cross-entropy mean loss of `forward(ids)` against `targets` (`[seq]` I32).
|
||||
@@ -134,76 +150,76 @@ impl TinyTransformer {
|
||||
ops::cross_entropy(&logits, targets)
|
||||
}
|
||||
|
||||
/// Multi-head causal self-attention. `x`:[seq,dim] (already normed).
|
||||
fn attention(&self, b: &Block, x: &Var, mask: &Var, seq: usize) -> Var {
|
||||
/// Batched cross-entropy mean loss: `forward_batched(ids, batch)` against
|
||||
/// flat `targets` (`[batch*seq]` I32, same sequence-major layout). The CE mean
|
||||
/// is over all `batch*seq` rows — identical to averaging the per-sequence
|
||||
/// losses, so the loss value matches the looped single-sequence path.
|
||||
pub fn loss_batched(&self, ids: &Tensor, targets: &Tensor, batch: usize) -> Var {
|
||||
let logits = self.forward_batched(ids, batch);
|
||||
ops::cross_entropy(&logits, targets)
|
||||
}
|
||||
|
||||
/// Multi-head causal self-attention over a flattened batch. `x`:[batch*seq,dim]
|
||||
/// (already normed), laid out sequence-major. The Q/K/V/O projections are big
|
||||
/// `[batch*seq, dim]` GEMMs; the scaled-dot-product attention itself runs as a
|
||||
/// fused BATCHED op over the `batch·n_heads` (sequence,head) blocks — each
|
||||
/// attends within its own `[seq,seq]` causal window (NO cross-sequence
|
||||
/// attention), with RoPE positions reset per sequence (`period = seq`). Causal
|
||||
/// masking is applied inside the fused op's softmax kernel (no additive
|
||||
/// `[seq,seq]` mask tensor).
|
||||
fn attention(&self, b: &Block, x: &Var, batch: usize, seq: usize) -> Var {
|
||||
let (nh, hd) = (self.cfg.n_heads, self.cfg.head_dim);
|
||||
let total = batch * seq;
|
||||
let bh = batch * nh;
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
|
||||
// Project, then lay out as per-head [seq, head_dim] tensors.
|
||||
// [seq,dim] @ [dim,dim] = [seq,dim]
|
||||
// reshape [seq, nh, hd]
|
||||
// Project, qk-norm + RoPE, then lay out as a batched [B*nh, seq, hd] tensor.
|
||||
// [B*S,dim] @ [dim,dim] = [B*S,dim]
|
||||
// reshape [B*S, nh, hd]
|
||||
// qk-norm per-head RMSNorm over hd (Qwen3-style; Q/K only, before RoPE)
|
||||
// rope (kernel expects exactly [tokens, heads, head_dim])
|
||||
// transpose [nh, seq, hd] → split into nh × [seq, hd]
|
||||
let to_heads = |proj: Var, norm: Option<&Var>| -> Vec<Var> {
|
||||
let r = ops::reshape(&proj, &[seq, nh, hd]);
|
||||
// rope [B*S, nh, hd] with per-sequence position (period = seq)
|
||||
// reshape [B, S, nh, hd] → transpose(1,2) → [B, nh, S, hd] → [B*nh, S, hd]
|
||||
let to_bh = |proj: Var, norm: Option<&Var>| -> Var {
|
||||
let r = ops::reshape(&proj, &[total, nh, hd]);
|
||||
let r = match norm {
|
||||
// Per-head RMSNorm: flatten the (seq,nh) head rows, norm over hd,
|
||||
// Per-head RMSNorm: flatten the (B*S,nh) head rows, norm over hd,
|
||||
// restore. RoPE follows on the normed Q/K (mirrors xserv qwen3.rs).
|
||||
Some(gamma) => {
|
||||
let flat = ops::reshape(&r, &[seq * nh, hd]);
|
||||
let flat = ops::reshape(&r, &[total * nh, hd]);
|
||||
let normed = ops::rms_norm(&flat, gamma, self.cfg.eps);
|
||||
let r = ops::reshape(&normed, &[seq, nh, hd]);
|
||||
ops::rope(&r, self.cfg.rope_theta)
|
||||
let r = ops::reshape(&normed, &[total, nh, hd]);
|
||||
ops::rope(&r, self.cfg.rope_theta, seq)
|
||||
}
|
||||
None => r,
|
||||
};
|
||||
let t = ops::transpose_3d01(&r); // [nh, seq, hd]
|
||||
ops::split_heads(&t)
|
||||
let r = ops::reshape(&r, &[batch, seq, nh, hd]);
|
||||
let t = ops::transpose_4d12(&r); // [B, nh, S, hd]
|
||||
ops::reshape(&t, &[bh, seq, hd]) // [B*nh, S, hd]
|
||||
};
|
||||
|
||||
let q = to_heads(ops::matmul(x, &b.wq), Some(&b.q_norm));
|
||||
let k = to_heads(ops::matmul(x, &b.wk), Some(&b.k_norm));
|
||||
let v = to_heads(ops::matmul(x, &b.wv), None);
|
||||
let q = to_bh(ops::matmul(x, &b.wq), Some(&b.q_norm));
|
||||
let k = to_bh(ops::matmul(x, &b.wk), Some(&b.k_norm));
|
||||
let v = to_bh(ops::matmul(x, &b.wv), None);
|
||||
|
||||
// Per-head scaled-dot-product attention with causal mask.
|
||||
let heads_out: Vec<Var> = (0..nh)
|
||||
.map(|i| {
|
||||
let kt = ops::transpose_2d(&k[i]); // [hd, seq]
|
||||
let scores = ops::scale(&ops::matmul(&q[i], &kt), scale); // [seq,seq]
|
||||
let scores = ops::add(&scores, mask); // causal
|
||||
let probs = ops::softmax(&scores);
|
||||
ops::matmul(&probs, &v[i]) // [seq, hd]
|
||||
})
|
||||
.collect();
|
||||
// Fused batched causal SDPA over all B*nh (sequence,head) blocks at once
|
||||
// (2 batched GEMMs + 1 causal-softmax kernel; no per-head/per-seq loop).
|
||||
let out = ops::attention(&q, &k, &v, scale); // [B*nh, S, hd]
|
||||
|
||||
// Stack heads back: nh × [seq,hd] → [nh,seq,hd] → [seq,nh,hd] → [seq,dim].
|
||||
let merged = ops::merge_heads(&heads_out); // [nh, seq, hd]
|
||||
let t = ops::transpose_3d01(&merged); // [seq, nh, hd]
|
||||
let concat = ops::reshape(&t, &[seq, nh * hd]); // [seq, dim]
|
||||
// Back to [B*S, dim]: [B*nh,S,hd] → [B,nh,S,hd] → transpose(1,2) →
|
||||
// [B,S,nh,hd] → [B*S, dim].
|
||||
let out = ops::reshape(&out, &[batch, nh, seq, hd]);
|
||||
let out = ops::transpose_4d12(&out); // [B, S, nh, hd]
|
||||
let concat = ops::reshape(&out, &[total, nh * hd]); // [B*S, dim]
|
||||
ops::matmul(&concat, &b.wo) // out projection
|
||||
}
|
||||
|
||||
/// SwiGLU MLP: `down( silu(gate(x)) ∘ up(x) )`. `x`:[seq,dim].
|
||||
/// SwiGLU MLP: `down( silu(gate(x)) ∘ up(x) )`. `x`:[batch*seq,dim].
|
||||
fn swiglu_mlp(&self, b: &Block, x: &Var) -> Var {
|
||||
let gate = ops::matmul(x, &b.w_gate); // [seq, ffn_hidden]
|
||||
let up = ops::matmul(x, &b.w_up); // [seq, ffn_hidden]
|
||||
let act = ops::swiglu(&gate, &up); // silu(gate) ∘ up
|
||||
ops::matmul(&act, &b.w_down) // [seq, dim]
|
||||
}
|
||||
|
||||
/// Additive causal mask `[seq,seq]`: 0 on/below the diagonal, −1e9 above it
|
||||
/// (so softmax zeros out future positions). A constant leaf (no grad needed,
|
||||
/// but harmless if it accumulates one — it has no consumers downstream of x).
|
||||
fn causal_mask(&self, seq: usize) -> Var {
|
||||
let mut m = vec![0.0f32; seq * seq];
|
||||
for i in 0..seq {
|
||||
for j in (i + 1)..seq {
|
||||
m[i * seq + j] = -1.0e9;
|
||||
}
|
||||
}
|
||||
Var::leaf(Tensor::from_slice(&m, &[seq, seq]).to_device(self.device))
|
||||
}
|
||||
}
|
||||
|
||||
/// Materialise a parameter's value back to a host `Vec<f32>` (for the GD step
|
||||
@@ -216,3 +232,17 @@ pub fn param_to_host(v: &Var) -> Vec<f32> {
|
||||
pub fn ids_tensor(ids: &[i32], device: Device) -> Tensor {
|
||||
Tensor::from_slice(ids, &[ids.len()]).to_device(device)
|
||||
}
|
||||
|
||||
/// Flatten `batch` equal-length sequences into one `[batch*seq]` I32 tensor in
|
||||
/// sequence-major order (the layout `forward_batched` expects). Each row of
|
||||
/// `seqs` is one sequence; all must have the same length.
|
||||
pub fn batched_ids_tensor(seqs: &[Vec<i32>], device: Device) -> Tensor {
|
||||
assert!(!seqs.is_empty(), "empty batch");
|
||||
let seq = seqs[0].len();
|
||||
let mut flat = Vec::with_capacity(seqs.len() * seq);
|
||||
for s in seqs {
|
||||
assert_eq!(s.len(), seq, "ragged batch: sequences must be equal length");
|
||||
flat.extend_from_slice(s);
|
||||
}
|
||||
Tensor::from_slice(&flat, &[flat.len()]).to_device(device)
|
||||
}
|
||||
|
||||
142
crates/xtrain-model/tests/batched.rs
Normal file
142
crates/xtrain-model/tests/batched.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
// T10 batched-forward equivalence: a batched forward over B sequences must equal
|
||||
// the old single-sequence path (run each sequence on its own, concatenate the
|
||||
// logits) — both for the forward logits AND every parameter's gradient.
|
||||
//
|
||||
// This is THE on-GPU correctness gate for batching (no PyTorch needed): if the
|
||||
// per-sequence RoPE position, per-sequence causal masking, or any flattened op
|
||||
// were wrong, the batched logits/grads would drift from the looped reference.
|
||||
//
|
||||
// Forward equivalence: batched logits[b*S+i] == single-seq-b logits[i].
|
||||
// Gradient equivalence: the batched loss is the mean over all B*S rows, i.e.
|
||||
// (1/B)·Σ_b mean_i(loss_b); summing the B single-sequence losses and scaling by
|
||||
// 1/B gives the SAME scalar, so their summed grads (tape fan-out) ×1/B match the
|
||||
// batched grads. We check that.
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor, ids_tensor};
|
||||
use xtrain_tensor::Device;
|
||||
|
||||
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn build(cfg: Config, device: Device) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||||
t.to_device(Device::Cpu).as_slice::<f32>().to_vec()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn batched_matches_looped_single_sequence() {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 16;
|
||||
let batch = 3usize;
|
||||
let seq = 5usize;
|
||||
// B distinct sequences (sequence-major), within vocab.
|
||||
let seqs: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let tgts: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
// --- Batched forward: ONE pass over [B*S]. ---
|
||||
let bmodel = build(cfg, device);
|
||||
let bids = batched_ids_tensor(&seqs, device);
|
||||
let blogits = host(&bmodel.forward_batched(&bids, batch).value());
|
||||
|
||||
// --- Looped reference: each sequence on its own, concatenate logits. ---
|
||||
let smodel = build(cfg, device);
|
||||
let mut slogits = Vec::with_capacity(batch * seq * cfg.vocab);
|
||||
for s in &seqs {
|
||||
let ids = ids_tensor(s, device);
|
||||
slogits.extend(host(&smodel.forward(&ids).value()));
|
||||
}
|
||||
|
||||
// Forward equivalence (fp GEMM rounding only differs in summation order).
|
||||
let max_rel = blogits
|
||||
.iter()
|
||||
.zip(&slogits)
|
||||
.map(|(b, s)| (b - s).abs() / s.abs().max(1e-4))
|
||||
.fold(0.0f32, f32::max);
|
||||
println!("batched vs looped: logits max rel err = {max_rel:.3e}");
|
||||
assert!(max_rel < 1e-3, "batched logits diverged: {max_rel:.3e}");
|
||||
|
||||
// --- Gradient equivalence. ---
|
||||
// Batched: loss = mean over B*S rows; one backward.
|
||||
let bparams = bmodel.params();
|
||||
let btgt = batched_ids_tensor(&tgts, device);
|
||||
let bloss = bmodel.loss_batched(&bids, &btgt, batch);
|
||||
let bloss_val = host(&bloss.value())[0];
|
||||
bloss.backward();
|
||||
|
||||
// Looped: Σ_b loss_b (each a per-sequence mean), then grad ×(1/B) == batched.
|
||||
let sparams = smodel.params();
|
||||
let mut sloss_sum = 0.0f32;
|
||||
for (s, t) in seqs.iter().zip(&tgts) {
|
||||
let ids = ids_tensor(s, device);
|
||||
let tg = ids_tensor(t, device);
|
||||
let l = smodel.loss(&ids, &tg);
|
||||
sloss_sum += host(&l.value())[0];
|
||||
l.backward();
|
||||
}
|
||||
println!(
|
||||
"batched loss = {bloss_val:.6} looped mean = {:.6}",
|
||||
sloss_sum / batch as f32
|
||||
);
|
||||
assert!(
|
||||
(bloss_val - sloss_sum / batch as f32).abs() < 1e-4,
|
||||
"batched loss != looped mean"
|
||||
);
|
||||
|
||||
let mut max_grad_rel = 0.0f32;
|
||||
for (bp, sp) in bparams.iter().zip(&sparams) {
|
||||
let bg = host(&bp.grad().expect("batched grad"));
|
||||
let sg = host(&sp.grad().expect("looped grad"));
|
||||
for (g_b, g_s) in bg.iter().zip(&sg) {
|
||||
// looped grad is the SUM over B sequences; ×(1/B) recovers the mean.
|
||||
let g_s = g_s / batch as f32;
|
||||
let rel = (g_b - g_s).abs() / g_s.abs().max(1e-4);
|
||||
max_grad_rel = max_grad_rel.max(rel);
|
||||
}
|
||||
}
|
||||
println!("batched vs looped: grad max rel err = {max_grad_rel:.3e}");
|
||||
assert!(
|
||||
max_grad_rel < 5e-3,
|
||||
"batched grads diverged: {max_grad_rel:.3e}"
|
||||
);
|
||||
}
|
||||
@@ -55,10 +55,13 @@ NH = int(cfg["n_heads"])
|
||||
HD = int(cfg["head_dim"])
|
||||
EPS = float(cfg["eps"])
|
||||
THETA = float(cfg["rope_theta"])
|
||||
# Batched: B sequences of length SEQ, flattened sequence-major to [B*SEQ] ids.
|
||||
B = int(cfg.get("batch", "1"))
|
||||
SEQ = int(cfg["seq"])
|
||||
|
||||
ids = read_ids("ids.txt")
|
||||
targets = read_ids("targets.txt")
|
||||
SEQ = len(ids)
|
||||
assert len(ids) == B * SEQ, f"ids {len(ids)} != B*SEQ {B*SEQ}"
|
||||
|
||||
# Load params as leaf tensors requiring grad (float64 for a clean reference).
|
||||
P = {}
|
||||
@@ -76,15 +79,16 @@ def rms_norm(x, gamma):
|
||||
return x * torch.rsqrt(ms + EPS) * gamma
|
||||
|
||||
|
||||
def rope(x): # x: [seq, nh, hd], position = token index, matching the kernel
|
||||
def rope(x): # x: [B*SEQ, nh, hd], position = (row % SEQ) — resets per sequence
|
||||
half = HD // 2
|
||||
out = torch.empty_like(x)
|
||||
i = torch.arange(half, dtype=torch.float64)
|
||||
freq = THETA ** (-(2.0 * i) / HD) # [half]
|
||||
pos = torch.arange(SEQ, dtype=torch.float64).reshape(SEQ, 1) # [seq,1]
|
||||
ang = pos * freq # [seq, half]
|
||||
c = torch.cos(ang).reshape(SEQ, 1, half)
|
||||
s = torch.sin(ang).reshape(SEQ, 1, half)
|
||||
# Position within each sequence: rows 0..SEQ for seq 0, 0..SEQ for seq 1, ...
|
||||
pos = (torch.arange(B * SEQ, dtype=torch.float64) % SEQ).reshape(B * SEQ, 1)
|
||||
ang = pos * freq # [B*SEQ, half]
|
||||
c = torch.cos(ang).reshape(B * SEQ, 1, half)
|
||||
s = torch.sin(ang).reshape(B * SEQ, 1, half)
|
||||
x0 = x[..., :half]
|
||||
x1 = x[..., half:]
|
||||
out[..., :half] = x0 * c - x1 * s
|
||||
@@ -102,26 +106,30 @@ for l in range(NL):
|
||||
"ffn_norm", "w_gate", "w_up", "w_down"]})
|
||||
|
||||
idx = torch.tensor(ids, dtype=torch.long)
|
||||
# Per-sequence causal mask (broadcast over the batch); NO cross-sequence attention.
|
||||
mask = torch.triu(torch.full((SEQ, SEQ), -1.0e9, dtype=torch.float64), diagonal=1)
|
||||
|
||||
h = emb[idx] # [seq, dim]
|
||||
h = emb[idx] # [B*SEQ, dim] (everything stays flattened, matching the Rust path)
|
||||
for L in layers:
|
||||
# Attention
|
||||
x = rms_norm(h, L["attn_norm"])
|
||||
q = (x @ L["wq"]).reshape(SEQ, NH, HD)
|
||||
k = (x @ L["wk"]).reshape(SEQ, NH, HD)
|
||||
v = (x @ L["wv"]).reshape(SEQ, NH, HD)
|
||||
q = (x @ L["wq"]).reshape(B * SEQ, NH, HD)
|
||||
k = (x @ L["wk"]).reshape(B * SEQ, NH, HD)
|
||||
v = (x @ L["wv"]).reshape(B * SEQ, NH, HD)
|
||||
# Per-head QK-norm (Qwen3-style), before RoPE.
|
||||
q = rms_norm(q, L["q_norm"])
|
||||
k = rms_norm(k, L["k_norm"])
|
||||
q = rope(q).transpose(0, 1) # [nh, seq, hd]
|
||||
k = rope(k).transpose(0, 1)
|
||||
v = v.transpose(0, 1)
|
||||
q = rope(q) # [B*SEQ, nh, hd]
|
||||
k = rope(k)
|
||||
# Reshape to [B, NH, SEQ, HD] so attention runs within each sequence.
|
||||
q = q.reshape(B, SEQ, NH, HD).transpose(1, 2) # [B, nh, seq, hd]
|
||||
k = k.reshape(B, SEQ, NH, HD).transpose(1, 2)
|
||||
v = v.reshape(B, SEQ, NH, HD).transpose(1, 2)
|
||||
scale = 1.0 / math.sqrt(HD)
|
||||
scores = (q @ k.transpose(-1, -2)) * scale + mask # [nh, seq, seq]
|
||||
scores = (q @ k.transpose(-1, -2)) * scale + mask # [B, nh, seq, seq]
|
||||
probs = torch.softmax(scores, dim=-1)
|
||||
out = probs @ v # [nh, seq, hd]
|
||||
out = out.transpose(0, 1).reshape(SEQ, DIM) # [seq, dim]
|
||||
out = probs @ v # [B, nh, seq, hd]
|
||||
out = out.transpose(1, 2).reshape(B * SEQ, DIM) # [B*SEQ, dim]
|
||||
attn = out @ L["wo"]
|
||||
h = h + attn
|
||||
# MLP
|
||||
@@ -133,7 +141,7 @@ for L in layers:
|
||||
h = h + mlp
|
||||
|
||||
h = rms_norm(h, final_norm)
|
||||
logits = h @ lm_head # [seq, vocab]
|
||||
logits = h @ lm_head # [B*SEQ, vocab]
|
||||
|
||||
loss = torch.nn.functional.cross_entropy(
|
||||
logits, torch.tensor(targets, dtype=torch.long), reduction="mean")
|
||||
|
||||
@@ -53,12 +53,17 @@ fn dump_for_parity() {
|
||||
);
|
||||
fs::create_dir_all(&dir).unwrap();
|
||||
|
||||
// Fixed config + ids (independent of any text, for reproducibility).
|
||||
// Fixed config + ids (independent of any text, for reproducibility). B>1 so
|
||||
// the batched forward is exercised: 2 sequences of length 4, flattened
|
||||
// sequence-major to [B*S]=8 ids. Per-sequence RoPE position (resets at the
|
||||
// sequence boundary) + per-sequence causal masking (no cross-sequence
|
||||
// attention) are both checked against PyTorch.
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 12;
|
||||
let ids: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6];
|
||||
let batch = 2usize;
|
||||
let seq = 4usize;
|
||||
let ids: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6]; // [B*S], sequence-major
|
||||
let targets: Vec<i32> = vec![1, 4, 1, 5, 9, 2, 6, 0];
|
||||
let seq = ids.len();
|
||||
|
||||
// Same deterministic init as the overfit test.
|
||||
let mut seed = 1u64;
|
||||
@@ -83,6 +88,7 @@ fn dump_for_parity() {
|
||||
writeln!(f, "ffn_hidden {}", cfg.ffn_hidden).unwrap();
|
||||
writeln!(f, "eps {:e}", cfg.eps).unwrap();
|
||||
writeln!(f, "rope_theta {:e}", cfg.rope_theta).unwrap();
|
||||
writeln!(f, "batch {batch}").unwrap();
|
||||
writeln!(f, "seq {seq}").unwrap();
|
||||
}
|
||||
{
|
||||
@@ -105,10 +111,11 @@ fn dump_for_parity() {
|
||||
write_vec(&dir, &format!("w_{name}.txt"), ¶m_to_host(p), &shape);
|
||||
}
|
||||
|
||||
// Forward logits + loss, then backward → per-param grads.
|
||||
// Batched forward logits + loss (B sequences as one forward), then backward
|
||||
// → per-param grads.
|
||||
let ids_t = ids_tensor(&ids, device);
|
||||
let targets_t = ids_tensor(&targets, device);
|
||||
let logits = model.forward(&ids_t);
|
||||
let logits = model.forward_batched(&ids_t, batch);
|
||||
write_vec(
|
||||
&dir,
|
||||
"logits.txt",
|
||||
@@ -116,7 +123,7 @@ fn dump_for_parity() {
|
||||
logits.value().shape(),
|
||||
);
|
||||
|
||||
let loss = model.loss(&ids_t, &targets_t);
|
||||
let loss = model.loss_batched(&ids_t, &targets_t, batch);
|
||||
let loss_val = param_to_host(&loss)[0];
|
||||
{
|
||||
let mut f = fs::File::create(dir.join("loss.txt")).unwrap();
|
||||
|
||||
Reference in New Issue
Block a user