diff --git a/crates/xtrain-model/src/decode.rs b/crates/xtrain-model/src/decode.rs new file mode 100644 index 0000000..a3b2d10 --- /dev/null +++ b/crates/xtrain-model/src/decode.rs @@ -0,0 +1,225 @@ +//! KV-cache incremental-decode engine (post-training M2a, single sequence). +//! +//! The naive sampler ([`crate::TinyTransformer`] via `train::sample::generate`) +//! re-runs the full forward over the whole growing prefix every step — O(t²) and +//! a fresh autograd graph per token. This is the inference engine that replaces it: +//! a per-layer **K/V cache** + a **single-token incremental forward** that processes +//! one new token at a time, attending to the cached keys/values. +//! +//! Built on three primitives, all gated by their own correctness tests: +//! - [`Tensor::rope_at`](xtrain_tensor::Tensor::rope_at): RoPE at the token's +//! absolute position (not row-in-tile), so cached post-RoPE K matches the full +//! forward (bit-identical, `integration::rope_at_matches_full_rope_row`). +//! - [`Tensor::decode_attention`](xtrain_tensor::Tensor::decode_attention): the +//! single-query × cached-K/V SDPA, equal to the full causal attention's last row +//! (`integration::decode_attention_matches_full_attention_last_row`). +//! - this module's per-token block forward, mirroring `model::block_forward` at the +//! raw-Tensor level (no autograd tape — inference needs no gradients). +//! +//! Correctness gate (the M2 centerpiece): KV-cache greedy decode is **token- +//! identical** to the naive full-recompute greedy (`tests/decode_kv.rs`). +//! +//! Prefill is just the first `prompt.len()` decode steps (one token at a time) — +//! one code path, at the cost of a non-batched prefill (M2b adds batched prefill + +//! ragged batch decode). The cache is host-accumulated (token-major f32) and the +//! K/V tensor is rebuilt per step; the host round-trip is small (`num_kv·head_dim` +//! floats/token/layer) and is the honest M2a baseline — M2b moves it device-side. + +#![cfg(not(no_cuda))] + +use crate::TinyTransformer; +use xtrain_tensor::{DType, Device, Tensor}; + +/// Per-layer K/V cache: token-major host accumulation. For each layer, `k[li]` and +/// `v[li]` hold `[T, num_kv, head_dim]` (f32, flattened), grown by one token's +/// `num_kv·head_dim` values per decode step. Stored f32 (an exact upcast of the +/// bf16 projection output); rebuilt to the compute dtype when forming the K/V +/// tensor, so bf16 values round-trip bit-for-bit. +struct KVCache { + k: Vec>, + v: Vec>, +} + +impl KVCache { + fn new(n_layers: usize) -> Self { + Self { + k: vec![Vec::new(); n_layers], + v: vec![Vec::new(); n_layers], + } + } + + /// Append one token's K/V slab (each `num_kv·head_dim` f32) to layer `li`. + fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) { + self.k[li].extend_from_slice(k_tok); + self.v[li].extend_from_slice(v_tok); + } +} + +/// Linear `x @ W` in the compute dtype — mirrors `model::linear` (bf16 casts the +/// fp32-master weight to bf16 on the fly; the activation stream is already bf16). +fn linear_t(cdt: DType, x: &Tensor, w: &Tensor) -> Tensor { + match cdt { + DType::F32 => x.matmul(w), + DType::BF16 => x.matmul(&w.to_dtype(DType::BF16)), + _ => unreachable!("compute dtype must be F32/BF16"), + } +} + +/// A norm/QK-norm gamma in the compute dtype — mirrors `model::norm_gamma`. +fn gamma_t(cdt: DType, g: &Tensor) -> Tensor { + match cdt { + DType::F32 => g.clone(), + DType::BF16 => g.to_dtype(DType::BF16), + _ => unreachable!("compute dtype must be F32/BF16"), + } +} + +/// Greedy KV-cache decode: continue `prompt` by `max_new` tokens, argmax each step. +/// Returns the full token sequence (prompt + generated), matching the naive +/// `sample::generate` interface for `temperature == 0`. Token-identical to the +/// naive full-recompute greedy (gated by `tests/decode_kv.rs`). +pub fn generate_greedy_cached( + model: &TinyTransformer, + device: Device, + prompt: &[i32], + max_new: usize, +) -> Vec { + assert!(!prompt.is_empty(), "prompt must be non-empty"); + let cfg = model.config(); + let cdt = model.compute_dtype(); + let n_layers = cfg.n_layers; + + // params() is a stable, documented order (see TinyTransformer::params): + // [0] = embed [vocab, dim] + // [1 + li*11 .. +11] = layer li's 11 leaves, in block_params order: + // attn_norm, wq, wk, wv, q_norm, k_norm, wo, ffn_norm, w_gate, w_up, w_down + // [1 + n_layers*11] = final_norm [dim] + // [1 + n_layers*11 + 1] = lm_head [dim, vocab] + let params: Vec = model.params().iter().map(|p| p.value()).collect(); + assert_eq!( + params.len(), + 1 + n_layers * 11 + 2, + "unexpected param layout for decode" + ); + let embed = ¶ms[0]; + let final_norm = ¶ms[1 + n_layers * 11]; + let lm_head = ¶ms[1 + n_layers * 11 + 1]; + + let mut cache = KVCache::new(n_layers); + let mut tokens = prompt.to_vec(); + + // Prefill: feed each prompt token in order; the last step's logits are the + // distribution for the first generated token. + let mut logits = Vec::new(); + for (pos, &tok) in prompt.iter().enumerate() { + logits = decode_step(¶ms, cfg, cdt, device, &mut cache, tok, pos, embed, final_norm, lm_head); + } + + for _ in 0..max_new { + let next = argmax(&logits) as i32; + tokens.push(next); + let pos = tokens.len() - 1; // absolute position of the token just appended + logits = decode_step(¶ms, cfg, cdt, device, &mut cache, next, pos, embed, final_norm, lm_head); + } + tokens +} + +/// One incremental decode step for token `tok` at absolute position `pos`: append +/// its K/V to the cache and return the next-token logits as host f32 `[vocab]`. +#[allow(clippy::too_many_arguments)] +fn decode_step( + params: &[Tensor], + cfg: &crate::Config, + cdt: DType, + device: Device, + cache: &mut KVCache, + tok: i32, + pos: usize, + embed: &Tensor, + final_norm: &Tensor, + lm_head: &Tensor, +) -> Vec { + let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads); + let dim = cfg.dim; + let kv_dim = num_kv * hd; + let scale = 1.0 / (hd as f32).sqrt(); + let (theta, eps) = (cfg.rope_theta, cfg.eps); + let n_layers = cfg.n_layers; + + // Embedding (fp32 table) → activation stream in the compute dtype. + let ids = Tensor::from_slice(&[tok], &[1]).to_device(device); + let mut h = embed.embedding(&ids); // [1, dim] f32 + if cdt == DType::BF16 { + h = h.to_dtype(DType::BF16); + } + + for li in 0..n_layers { + let base = 1 + li * 11; + let (attn_norm, wq, wk, wv) = + (¶ms[base], ¶ms[base + 1], ¶ms[base + 2], ¶ms[base + 3]); + let (q_norm, k_norm, wo) = (¶ms[base + 4], ¶ms[base + 5], ¶ms[base + 6]); + let (ffn_norm, w_gate, w_up, w_down) = + (¶ms[base + 7], ¶ms[base + 8], ¶ms[base + 9], ¶ms[base + 10]); + + // --- Attention sub-block (pre-norm + cached-KV attention + residual) --- + let normed = h.rms_norm(&gamma_t(cdt, attn_norm), eps).0; // [1, dim] + + // Q: project → per-head QK-norm → RoPE at absolute position `pos`. + let q = linear_t(cdt, &normed, wq).reshape(&[1, nh, hd]); // [1, nh, hd] + let q = q.reshape(&[nh, hd]).rms_norm(&gamma_t(cdt, q_norm), eps).0; + let q = q.reshape(&[1, nh, hd]).rope_at(theta, pos); + let q_bh = q.reshape(&[nh, 1, hd]); // seq=1 ⇒ the head-transpose is a no-op on data + + // K: same as Q (QK-norm + RoPE); cache token-major. V: project only. + let k = linear_t(cdt, &normed, wk).reshape(&[1, num_kv, hd]); + let k = k.reshape(&[num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0; + let k_tok = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos); // [1, num_kv, hd] + let v_tok = linear_t(cdt, &normed, wv).reshape(&[1, num_kv, hd]); + + let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::().to_vec(); + let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::().to_vec(); + cache.append(li, &k_host, &v_host); + + // Rebuild the full K/V for this layer: token-major [T,num_kv,hd] → [num_kv,T,hd] + // → repeat_kv to [nh,T,hd]. + let t_len = cache.k[li].len() / kv_dim; + let build = |flat: &[f32]| -> Tensor { + let bh_kv = Tensor::from_slice(flat, &[t_len, num_kv, hd]) + .to_device(device) + .transpose_3d01(); // [num_kv, T, hd], f32 + let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv }; + if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, 1) } // [nh, T, hd] + }; + let k_full = build(&cache.k[li]); + let v_full = build(&cache.v[li]); + + let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [nh, hd] + let attn = attn.reshape(&[1, dim]); // concat heads (nh·hd == dim) + let attn_out = linear_t(cdt, &attn, wo); // [1, dim] + h = h.add(&attn_out); + + // --- MLP sub-block (pre-norm + SwiGLU + residual) --- + let normed = h.rms_norm(&gamma_t(cdt, ffn_norm), eps).0; + let gate = linear_t(cdt, &normed, w_gate); + let up = linear_t(cdt, &normed, w_up); + let act = gate.silu().mul(&up); // swiglu = silu(gate) ∘ up + let down = linear_t(cdt, &act, w_down); + h = h.add(&down); + } + + let h = h.rms_norm(&gamma_t(cdt, final_norm), eps).0; + let logits = linear_t(cdt, &h, lm_head); // [1, vocab] + logits + .to_dtype(DType::F32) + .to_device(Device::Cpu) + .as_slice::() + .to_vec() +} + +fn argmax(row: &[f32]) -> usize { + row.iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) + .unwrap() + .0 +} diff --git a/crates/xtrain-model/src/lib.rs b/crates/xtrain-model/src/lib.rs index 6b4c651..a5a32aa 100644 --- a/crates/xtrain-model/src/lib.rs +++ b/crates/xtrain-model/src/lib.rs @@ -25,3 +25,8 @@ pub use config::Config; mod model; #[cfg(not(no_cuda))] pub use model::{TinyTransformer, batched_ids_tensor, ids_tensor, param_to_host}; + +#[cfg(not(no_cuda))] +pub mod decode; +#[cfg(not(no_cuda))] +pub use decode::generate_greedy_cached; diff --git a/crates/xtrain-train/src/bin/eval_arith.rs b/crates/xtrain-train/src/bin/eval_arith.rs index bdba382..c0b6a12 100644 --- a/crates/xtrain-train/src/bin/eval_arith.rs +++ b/crates/xtrain-train/src/bin/eval_arith.rs @@ -103,6 +103,10 @@ fn main() { let n_show = flag(&args, "--show", 8usize); let prompts_file = flag_value(&args, "--prompts-file").expect("--prompts-file is required"); let gold_file = flag_value(&args, "--gold-file").expect("--gold-file is required"); + // M2: decode through the KV-cache incremental engine instead of the naive + // full-recompute sampler. Token-identical to the naive path (gated by + // tests/decode_kv.rs); this flag also lets us A/B the two for the speedup. + let use_cached = args.iter().any(|a| a == "--cached"); // Prompts: skip the `#` header / blank lines and decode escaped newlines so the // count and order line up with the gold file. @@ -148,18 +152,26 @@ fn main() { xtrain_train::checkpoint::load_into(&ckpt, &model.params()).expect("load checkpoint"); println!( - "eval_arith: ckpt {} | {} prompts | max_new {}", + "eval_arith: ckpt {} | {} prompts | max_new {} | decode={}", ckpt.display(), prompts.len(), - max_new + max_new, + if use_cached { "kv-cache" } else { "naive" } ); let (mut n_boxed, mut n_correct) = (0usize, 0usize); let mut shown = 0usize; + let mut gen_tokens = 0usize; + let t0 = std::time::Instant::now(); for (prompt, &gold) in prompts.iter().zip(&golds) { let ids: Vec = tok.encode(prompt).into_iter().map(|t| t as i32).collect(); - let mut rng = 7u64; - let out = generate(&model, device, &ids, max_new, 0.0, &mut rng); + let out = if use_cached { + xtrain_model::generate_greedy_cached(&model, device, &ids, max_new) + } else { + let mut rng = 7u64; + generate(&model, device, &ids, max_new, 0.0, &mut rng) + }; + gen_tokens += out.len() - ids.len(); let cont = tok.decode(&out[ids.len()..].iter().map(|&t| t as u32).collect::>()); let seg = first_answer_segment(&cont); if parse_boxed_answer(seg).is_some() { @@ -176,6 +188,7 @@ fn main() { } } + let elapsed = t0.elapsed().as_secs_f64(); let n = prompts.len() as f64; println!( "RESULT format(boxed)={}/{} ({:.1}%) | correct={}/{} ({:.1}%)", @@ -186,4 +199,11 @@ fn main() { prompts.len(), 100.0 * n_correct as f64 / n, ); + println!( + "TIMING decode={} | {:.2}s | {} gen tokens | {:.1} tok/s", + if use_cached { "kv-cache" } else { "naive" }, + elapsed, + gen_tokens, + gen_tokens as f64 / elapsed, + ); } diff --git a/crates/xtrain-train/tests/decode_kv.rs b/crates/xtrain-train/tests/decode_kv.rs new file mode 100644 index 0000000..47d1b89 --- /dev/null +++ b/crates/xtrain-train/tests/decode_kv.rs @@ -0,0 +1,94 @@ +// M2a KV-cache decode engine — the token-identical correctness gate. +// +// The centerpiece M2 invariant: greedy decode through the KV-cache incremental +// engine (`xtrain_model::generate_greedy_cached`) must be TOKEN-IDENTICAL to the +// naive full-recompute greedy (`xtrain_train::sample::generate` at temperature 0), +// which re-runs the whole forward over the growing prefix each step. Same tokens ⇒ +// the cache + decode-time attention + RoPE-at-position reproduce the full forward. +// +// Numerics note: a randomly-initialised model has near-uniform logits, so argmax +// can be fragile to ~1e-6 differences. This unit gate therefore runs in F32 (the +// tightest path, and the dtype the eval harness actually uses) on a small model. +// The headline gate on the trained v12 checkpoint (peaked logits → robust argmax) +// is run on the GPU box and recorded in docs/18. +#![cfg(not(no_cuda))] + +use xtrain_cuda::device; +use xtrain_model::{Config, TinyTransformer, generate_greedy_cached}; +use xtrain_tensor::{DType, Device}; + +fn fill(n: usize, seed: u64, scale: f32) -> Vec { + let mut state = seed + .wrapping_mul(2862933555777941757) + .wrapping_add(3037000493); + (0..n) + .map(|_| { + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + (((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale + }) + .collect() +} + +fn build(cfg: Config, device: Device, dtype: DType) -> TinyTransformer { + let mut seed = 1u64; + let m = TinyTransformer::new(cfg, device, |shape| { + seed = seed.wrapping_add(1); + let n: usize = shape.iter().product(); + if shape.len() == 1 { + fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect() + } else { + fill(n, seed, 0.08) + } + }); + m.with_compute_dtype(dtype) +} + +// A real GQA config (8 query / 2 kv heads → group 4) to exercise repeat_kv in the +// decode path; head_dim 16, dim 128, 4 layers. +fn gqa_cfg() -> Config { + Config::from_arch(48, 8, 16, 4, 256).with_kv_heads(2) +} + +#[test] +fn kv_cache_decode_is_token_identical_to_naive_f32() { + assert!( + device::device_count().expect("device count") > 0, + "no CUDA device" + ); + device::set_device(0).unwrap(); + let device = Device::Cuda(0); + + let model = build(gqa_cfg(), device, DType::F32); + let prompt: Vec = vec![1, 5, 9, 13, 2, 7]; + let max_new = 24usize; + + let mut rng = 7u64; + let naive = xtrain_train::sample::generate(&model, device, &prompt, max_new, 0.0, &mut rng); + let cached = generate_greedy_cached(&model, device, &prompt, max_new); + + assert_eq!( + naive.len(), + cached.len(), + "length mismatch: naive {} vs cached {}", + naive.len(), + cached.len() + ); + if naive != cached { + // Report the first divergence for debugging. + let first = naive + .iter() + .zip(&cached) + .position(|(a, b)| a != b) + .unwrap(); + panic!( + "token divergence at index {first}: naive={:?} cached={:?}\nnaive ={naive:?}\ncached ={cached:?}", + naive[first], cached[first] + ); + } + println!( + "KV-cache decode token-identical to naive over {} generated tokens (F32, GQA 8/2)", + max_new + ); +}