//! KV-cache incremental-decode engine (post-training M2a, single sequence). //! //! The naive sampler ([`crate::TinyTransformer`] via `train::sample::generate`) //! re-runs the full forward over the whole growing prefix every step — O(t²) and //! a fresh autograd graph per token. This is the inference engine that replaces it: //! a per-layer **K/V cache** + a **single-token incremental forward** that processes //! one new token at a time, attending to the cached keys/values. //! //! Built on three primitives, all gated by their own correctness tests: //! - [`Tensor::rope_at`](xtrain_tensor::Tensor::rope_at): RoPE at the token's //! absolute position (not row-in-tile), so cached post-RoPE K matches the full //! forward (bit-identical, `integration::rope_at_matches_full_rope_row`). //! - [`Tensor::decode_attention`](xtrain_tensor::Tensor::decode_attention): the //! single-query × cached-K/V SDPA, equal to the full causal attention's last row //! (`integration::decode_attention_matches_full_attention_last_row`). //! - this module's per-token block forward, mirroring `model::block_forward` at the //! raw-Tensor level (no autograd tape — inference needs no gradients). //! //! Correctness gate (the M2 centerpiece): KV-cache greedy decode is **token- //! identical** to the naive full-recompute greedy (`tests/decode_kv.rs`). //! //! Prefill is just the first `prompt.len()` decode steps (one token at a time) — //! one code path, at the cost of a non-batched prefill (M2b adds batched prefill + //! ragged batch decode). The cache is host-accumulated (token-major f32) and the //! K/V tensor is rebuilt per step; the host round-trip is small (`num_kv·head_dim` //! floats/token/layer) and is the honest M2a baseline — M2b moves it device-side. #![cfg(not(no_cuda))] use crate::TinyTransformer; use xtrain_tensor::{DType, Device, Tensor}; /// Per-layer K/V cache: token-major host accumulation. For each layer, `k[li]` and /// `v[li]` hold `[T, num_kv, head_dim]` (f32, flattened), grown by one token's /// `num_kv·head_dim` values per decode step. Stored f32 (an exact upcast of the /// bf16 projection output); rebuilt to the compute dtype when forming the K/V /// tensor, so bf16 values round-trip bit-for-bit. struct KVCache { k: Vec>, v: Vec>, } impl KVCache { fn new(n_layers: usize) -> Self { Self { k: (0..n_layers).map(|_| None).collect(), v: (0..n_layers).map(|_| None).collect(), } } /// Append one token's K/V (`[bh,1,hd]`, compute dtype) to layer `li`, growing the /// device-resident `[bh,T,hd]` cache via `cat_seq` (no host round-trip, M2c). fn append(&mut self, li: usize, k_bh: Tensor, v_bh: Tensor) { self.k[li] = Some(match self.k[li].take() { Some(c) => c.cat_seq(&k_bh), None => k_bh, }); self.v[li] = Some(match self.v[li].take() { Some(c) => c.cat_seq(&v_bh), None => v_bh, }); } } /// Linear `x @ W` in the compute dtype — mirrors `model::linear` (bf16 casts the /// fp32-master weight to bf16 on the fly; the activation stream is already bf16). fn linear_t(cdt: DType, x: &Tensor, w: &Tensor) -> Tensor { match cdt { DType::F32 => x.matmul(w), DType::BF16 => x.matmul(&w.to_dtype(DType::BF16)), _ => unreachable!("compute dtype must be F32/BF16"), } } /// A norm/QK-norm gamma in the compute dtype — mirrors `model::norm_gamma`. fn gamma_t(cdt: DType, g: &Tensor) -> Tensor { match cdt { DType::F32 => g.clone(), DType::BF16 => g.to_dtype(DType::BF16), _ => unreachable!("compute dtype must be F32/BF16"), } } /// Greedy KV-cache decode: continue `prompt` by `max_new` tokens, argmax each step. /// Returns the full token sequence (prompt + generated), matching the naive /// `sample::generate` interface for `temperature == 0`. Token-identical to the /// naive full-recompute greedy (gated by `tests/decode_kv.rs`). pub fn generate_greedy_cached( model: &TinyTransformer, device: Device, prompt: &[i32], max_new: usize, ) -> Vec { let mut rng = 0u64; generate_cached(model, device, prompt, max_new, 0.0, &mut rng) } /// KV-cache decode with temperature sampling (`temperature == 0` → greedy argmax, /// matching [`generate_greedy_cached`]; otherwise sample from `softmax(logits/T)`). /// The KV-cache rollout the GRPO loop uses: each step allocates only a single-row /// `[1, vocab]` logits buffer (vs the naive sampler's `[seq, vocab]`), so it is far /// lighter on memory + the allocator — the naive sampler fragments the caching /// allocator over a long rollout, which is the M4 "rollout is the long pole" wall. pub fn generate_cached( model: &TinyTransformer, device: Device, prompt: &[i32], max_new: usize, temperature: f32, rng_state: &mut u64, ) -> Vec { assert!(!prompt.is_empty(), "prompt must be non-empty"); let cfg = model.config(); let cdt = model.compute_dtype(); let n_layers = cfg.n_layers; // params() is a stable, documented order (see TinyTransformer::params): // [0] = embed [vocab, dim] // [1 + li*11 .. +11] = layer li's 11 leaves, in block_params order: // attn_norm, wq, wk, wv, q_norm, k_norm, wo, ffn_norm, w_gate, w_up, w_down // [1 + n_layers*11] = final_norm [dim] // [1 + n_layers*11 + 1] = lm_head [dim, vocab] let params: Vec = model.params().iter().map(|p| p.value()).collect(); assert_eq!( params.len(), 1 + n_layers * 11 + 2, "unexpected param layout for decode" ); let embed = ¶ms[0]; let final_norm = ¶ms[1 + n_layers * 11]; let lm_head = ¶ms[1 + n_layers * 11 + 1]; let mut cache = KVCache::new(n_layers); let mut tokens = prompt.to_vec(); // Prefill: feed each prompt token in order; the last step's logits are the // distribution for the first generated token. let mut logits = Vec::new(); for (pos, &tok) in prompt.iter().enumerate() { logits = decode_step(¶ms, cfg, cdt, device, &mut cache, tok, pos, embed, final_norm, lm_head); } for _ in 0..max_new { let next = if temperature <= 0.0 { argmax(&logits) as i32 } else { sample_temperature(&logits, temperature, rng_state) as i32 }; tokens.push(next); let pos = tokens.len() - 1; // absolute position of the token just appended logits = decode_step(¶ms, cfg, cdt, device, &mut cache, next, pos, embed, final_norm, lm_head); } tokens } /// Sample a token from `softmax(logits / temperature)` (numerically stable). Same /// LCG + inverse-CDF scheme as the naive `sample::sample_temperature`. fn sample_temperature(row: &[f32], temperature: f32, rng_state: &mut u64) -> usize { let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max); let exps: Vec = row.iter().map(|&x| ((x - max) / temperature).exp()).collect(); let sum: f32 = exps.iter().sum(); *rng_state = rng_state .wrapping_mul(6364136223846793005) .wrapping_add(1442695040888963407); let r = ((*rng_state >> 32) as f32 / u32::MAX as f32) * sum; let mut acc = 0.0; for (i, &e) in exps.iter().enumerate() { acc += e; if acc >= r { return i; } } exps.len() - 1 } /// One incremental decode step for token `tok` at absolute position `pos`: append /// its K/V to the cache and return the next-token logits as host f32 `[vocab]`. #[allow(clippy::too_many_arguments)] fn decode_step( params: &[Tensor], cfg: &crate::Config, cdt: DType, device: Device, cache: &mut KVCache, tok: i32, pos: usize, embed: &Tensor, final_norm: &Tensor, lm_head: &Tensor, ) -> Vec { let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads); let dim = cfg.dim; let scale = 1.0 / (hd as f32).sqrt(); let (theta, eps) = (cfg.rope_theta, cfg.eps); let n_layers = cfg.n_layers; // Embedding (fp32 table) → activation stream in the compute dtype. let ids = Tensor::from_slice(&[tok], &[1]).to_device(device); let mut h = embed.embedding(&ids); // [1, dim] f32 if cdt == DType::BF16 { h = h.to_dtype(DType::BF16); } for li in 0..n_layers { let base = 1 + li * 11; let (attn_norm, wq, wk, wv) = (¶ms[base], ¶ms[base + 1], ¶ms[base + 2], ¶ms[base + 3]); let (q_norm, k_norm, wo) = (¶ms[base + 4], ¶ms[base + 5], ¶ms[base + 6]); let (ffn_norm, w_gate, w_up, w_down) = (¶ms[base + 7], ¶ms[base + 8], ¶ms[base + 9], ¶ms[base + 10]); // --- Attention sub-block (pre-norm + cached-KV attention + residual) --- let normed = h.rms_norm(&gamma_t(cdt, attn_norm), eps).0; // [1, dim] // Q: project → per-head QK-norm → RoPE at absolute position `pos`. let q = linear_t(cdt, &normed, wq).reshape(&[1, nh, hd]); // [1, nh, hd] let q = q.reshape(&[nh, hd]).rms_norm(&gamma_t(cdt, q_norm), eps).0; let q = q.reshape(&[1, nh, hd]).rope_at(theta, pos); let q_bh = q.reshape(&[nh, 1, hd]); // seq=1 ⇒ the head-transpose is a no-op on data // K: same as Q (QK-norm + RoPE). V: project only. Append each as [num_kv,1,hd] // (bh-major) into the device cache; no host round-trip, no transpose (M2c). let k = linear_t(cdt, &normed, wk).reshape(&[1, num_kv, hd]); let k = k.reshape(&[num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0; let k_bh = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos).reshape(&[num_kv, 1, hd]); let v_bh = linear_t(cdt, &normed, wv).reshape(&[num_kv, 1, hd]); cache.append(li, k_bh, v_bh); // repeat_kv the cached [num_kv,T,hd] to [nh,T,hd] for the SDPA. let expand = |c: &Tensor| if num_kv == nh { c.clone() } else { c.repeat_kv(nh, 1) }; let k_full = expand(cache.k[li].as_ref().unwrap()); let v_full = expand(cache.v[li].as_ref().unwrap()); let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [nh, hd] let attn = attn.reshape(&[1, dim]); // concat heads (nh·hd == dim) let attn_out = linear_t(cdt, &attn, wo); // [1, dim] h = h.add(&attn_out); // --- MLP sub-block (pre-norm + SwiGLU + residual) --- let normed = h.rms_norm(&gamma_t(cdt, ffn_norm), eps).0; let gate = linear_t(cdt, &normed, w_gate); let up = linear_t(cdt, &normed, w_up); let act = gate.silu().mul(&up); // swiglu = silu(gate) ∘ up let down = linear_t(cdt, &act, w_down); h = h.add(&down); } let h = h.rms_norm(&gamma_t(cdt, final_norm), eps).0; let logits = linear_t(cdt, &h, lm_head); // [1, vocab] logits .to_dtype(DType::F32) .to_device(Device::Cpu) .as_slice::() .to_vec() } fn argmax(row: &[f32]) -> usize { row.iter() .enumerate() .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) .unwrap() .0 } // =================================================================== // M2b — batched KV-cache decode (G samples of one prompt, in lockstep) // =================================================================== /// Batched K/V cache: `G` sequences advancing together. Per layer, a device-resident /// `[G·num_kv, T, head_dim]` grown one token per step via `cat_seq` (M2c — no host /// round-trip). Same as M2a's device cache with a G dimension in `bh`. struct BatchKVCache { k: Vec>, v: Vec>, } impl BatchKVCache { fn new(n_layers: usize) -> Self { Self { k: (0..n_layers).map(|_| None).collect(), v: (0..n_layers).map(|_| None).collect(), } } fn append(&mut self, li: usize, k_bh: Tensor, v_bh: Tensor) { self.k[li] = Some(match self.k[li].take() { Some(c) => c.cat_seq(&k_bh), None => k_bh, }); self.v[li] = Some(match self.v[li].take() { Some(c) => c.cat_seq(&v_bh), None => v_bh, }); } } /// Batched KV-cache decode: roll out `n_samples` (G) completions of the SAME /// `prompt` in lockstep — all G share the prompt, so they advance at one common /// decode position each step (uniform RoPE via `rope_pos`). Returns G full token /// sequences (prompt + sampled continuation). The G-way batching amortises the /// per-step kernel launches across G (the rollout long-pole). Token-identical per /// row to G independent single-sequence decodes (gated by `tests/decode_batch.rs`). /// /// `temperature == 0` ⇒ greedy (all G identical); `> 0` ⇒ independent samples /// (per-row draw from one shared `rng_state`). No finished-mask: all G generate /// `max_new` tokens; the caller cuts each at `<|endoftext|>` (a perf-only early /// stop is the M2b+ follow-up). Ragged (different-length prompts) is also deferred. pub fn generate_cached_batch( model: &TinyTransformer, device: Device, prompt: &[i32], n_samples: usize, max_new: usize, temperature: f32, rng_state: &mut u64, ) -> Vec> { assert!(!prompt.is_empty(), "prompt must be non-empty"); assert!(n_samples > 0, "n_samples must be > 0"); let cfg = model.config(); let cdt = model.compute_dtype(); let n_layers = cfg.n_layers; let params: Vec = model.params().iter().map(|p| p.value()).collect(); let embed = ¶ms[0]; let final_norm = ¶ms[1 + n_layers * 11]; let lm_head = ¶ms[1 + n_layers * 11 + 1]; let g = n_samples; let mut cache = BatchKVCache::new(n_layers); let mut seqs: Vec> = vec![prompt.to_vec(); g]; // Prefill: feed each prompt token (identical across G) at its position. let mut logits = Vec::new(); // [G, vocab] flattened for (pos, &tok) in prompt.iter().enumerate() { let toks = vec![tok; g]; logits = decode_step_batch(¶ms, cfg, cdt, device, &mut cache, &toks, pos, embed, final_norm, lm_head); } let vocab = cfg.vocab; for _ in 0..max_new { let mut next = Vec::with_capacity(g); for row in 0..g { let lg = &logits[row * vocab..(row + 1) * vocab]; let t = if temperature <= 0.0 { argmax(lg) as i32 } else { sample_temperature(lg, temperature, rng_state) as i32 }; next.push(t); seqs[row].push(t); } let pos = seqs[0].len() - 1; // all G are at the same position logits = decode_step_batch(¶ms, cfg, cdt, device, &mut cache, &next, pos, embed, final_norm, lm_head); } seqs } /// One batched decode step: `toks` is one current token per sequence (`[G]`), all at /// absolute position `pos`. Appends each sequence's K/V and returns logits `[G·vocab]`. #[allow(clippy::too_many_arguments)] fn decode_step_batch( params: &[Tensor], cfg: &crate::Config, cdt: DType, device: Device, cache: &mut BatchKVCache, toks: &[i32], pos: usize, embed: &Tensor, final_norm: &Tensor, lm_head: &Tensor, ) -> Vec { let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads); let dim = cfg.dim; let g = toks.len(); let scale = 1.0 / (hd as f32).sqrt(); let (theta, eps) = (cfg.rope_theta, cfg.eps); let n_layers = cfg.n_layers; // Uniform per-row position (all G at the same decode step). let positions = Tensor::from_slice(&vec![pos as i32; g], &[g]).to_device(device); let ids = Tensor::from_slice(toks, &[g]).to_device(device); let mut h = embed.embedding(&ids); // [G, dim] f32 if cdt == DType::BF16 { h = h.to_dtype(DType::BF16); } for li in 0..n_layers { let base = 1 + li * 11; let (attn_norm, wq, wk, wv) = (¶ms[base], ¶ms[base + 1], ¶ms[base + 2], ¶ms[base + 3]); let (q_norm, k_norm, wo) = (¶ms[base + 4], ¶ms[base + 5], ¶ms[base + 6]); let (ffn_norm, w_gate, w_up, w_down) = (¶ms[base + 7], ¶ms[base + 8], ¶ms[base + 9], ¶ms[base + 10]); let normed = h.rms_norm(&gamma_t(cdt, attn_norm), eps).0; // [G, dim] // Q: project → per-head QK-norm → RoPE at `pos` for every row. let q = linear_t(cdt, &normed, wq).reshape(&[g, nh, hd]); let q = q.reshape(&[g * nh, hd]).rms_norm(&gamma_t(cdt, q_norm), eps).0; let q = q.reshape(&[g, nh, hd]).rope_pos(&positions, theta); let q_bh = q.reshape(&[g * nh, 1, hd]); // bh = G·nh // K/V appended as [G·num_kv,1,hd] (bh-major) into the device cache (M2c). let k = linear_t(cdt, &normed, wk).reshape(&[g, num_kv, hd]); let k = k.reshape(&[g * num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0; let k_bh = k .reshape(&[g, num_kv, hd]) .rope_pos(&positions, theta) .reshape(&[g * num_kv, 1, hd]); let v_bh = linear_t(cdt, &normed, wv).reshape(&[g * num_kv, 1, hd]); cache.append(li, k_bh, v_bh); // repeat_kv the cached [G·num_kv,T,hd] to [G·nh,T,hd] for the SDPA. let expand = |c: &Tensor| if num_kv == nh { c.clone() } else { c.repeat_kv(nh, g) }; let k_full = expand(cache.k[li].as_ref().unwrap()); let v_full = expand(cache.v[li].as_ref().unwrap()); let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [G·nh, hd] let attn = attn.reshape(&[g, dim]); // concat heads per sequence let attn_out = linear_t(cdt, &attn, wo); h = h.add(&attn_out); let normed = h.rms_norm(&gamma_t(cdt, ffn_norm), eps).0; let gate = linear_t(cdt, &normed, w_gate); let up = linear_t(cdt, &normed, w_up); let act = gate.silu().mul(&up); let down = linear_t(cdt, &act, w_down); h = h.add(&down); } let h = h.rms_norm(&gamma_t(cdt, final_norm), eps).0; linear_t(cdt, &h, lm_head) .to_dtype(DType::F32) .to_device(Device::Cpu) .as_slice::() .to_vec() }