post-train: M2a — KV-cache incremental decode engine (token-identical)

Single-sequence KV-cache decode (xtrain-model/src/decode.rs): per-layer K/V
cache + single-token incremental forward (prefill = first prompt.len() decode
steps, one code path). Mirrors model::block_forward at the raw-Tensor level (no
autograd tape — inference needs no grads), using rope_at + decode_attention.
Cache is host-accumulated token-major f32, rebuilt per step (the honest M2a
baseline; M2b moves it device-side + batched ragged).

Gate (the M2 centerpiece): KV-cache greedy decode is TOKEN-IDENTICAL to the
naive full-recompute greedy — tests/decode_kv.rs (small GQA model, F32, 24
tokens) and corroborated on the v12 1.05B SFT checkpoint (cached eval =
naive eval byte-for-byte: format 100/100, correct 8/100).

eval_arith --cached A/Bs the two paths + reports decode tok/s. Measured on v12
(1.05B, batch 1, F32): the cache win is sequence-length-dependent —
  max_new=32   naive 108 vs cached 111 tok/s  (~1.0x; overhead-bound)
  max_new=128  naive  69 vs cached 133 tok/s  (~1.9x)
  max_new=256  naive OOM     vs cached 129 tok/s
Cached throughput stays ~constant (O(1)/token) while naive decays (O(t)/token,
O(seq^2) graph → OOM at length). Short eval prompts are overhead-bound, so the
cache matters for long rollouts (DPO/GRPO), not the arithmetic eval itself.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-30 12:00:03 +08:00
parent c88e2ab88c
commit eff26a0898
4 changed files with 348 additions and 4 deletions

View File

@@ -0,0 +1,225 @@
//! KV-cache incremental-decode engine (post-training M2a, single sequence).
//!
//! The naive sampler ([`crate::TinyTransformer`] via `train::sample::generate`)
//! re-runs the full forward over the whole growing prefix every step — O(t²) and
//! a fresh autograd graph per token. This is the inference engine that replaces it:
//! a per-layer **K/V cache** + a **single-token incremental forward** that processes
//! one new token at a time, attending to the cached keys/values.
//!
//! Built on three primitives, all gated by their own correctness tests:
//! - [`Tensor::rope_at`](xtrain_tensor::Tensor::rope_at): RoPE at the token's
//! absolute position (not row-in-tile), so cached post-RoPE K matches the full
//! forward (bit-identical, `integration::rope_at_matches_full_rope_row`).
//! - [`Tensor::decode_attention`](xtrain_tensor::Tensor::decode_attention): the
//! single-query × cached-K/V SDPA, equal to the full causal attention's last row
//! (`integration::decode_attention_matches_full_attention_last_row`).
//! - this module's per-token block forward, mirroring `model::block_forward` at the
//! raw-Tensor level (no autograd tape — inference needs no gradients).
//!
//! Correctness gate (the M2 centerpiece): KV-cache greedy decode is **token-
//! identical** to the naive full-recompute greedy (`tests/decode_kv.rs`).
//!
//! Prefill is just the first `prompt.len()` decode steps (one token at a time) —
//! one code path, at the cost of a non-batched prefill (M2b adds batched prefill +
//! ragged batch decode). The cache is host-accumulated (token-major f32) and the
//! K/V tensor is rebuilt per step; the host round-trip is small (`num_kv·head_dim`
//! floats/token/layer) and is the honest M2a baseline — M2b moves it device-side.
#![cfg(not(no_cuda))]
use crate::TinyTransformer;
use xtrain_tensor::{DType, Device, Tensor};
/// Per-layer K/V cache: token-major host accumulation. For each layer, `k[li]` and
/// `v[li]` hold `[T, num_kv, head_dim]` (f32, flattened), grown by one token's
/// `num_kv·head_dim` values per decode step. Stored f32 (an exact upcast of the
/// bf16 projection output); rebuilt to the compute dtype when forming the K/V
/// tensor, so bf16 values round-trip bit-for-bit.
struct KVCache {
k: Vec<Vec<f32>>,
v: Vec<Vec<f32>>,
}
impl KVCache {
fn new(n_layers: usize) -> Self {
Self {
k: vec![Vec::new(); n_layers],
v: vec![Vec::new(); n_layers],
}
}
/// Append one token's K/V slab (each `num_kv·head_dim` f32) to layer `li`.
fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) {
self.k[li].extend_from_slice(k_tok);
self.v[li].extend_from_slice(v_tok);
}
}
/// Linear `x @ W` in the compute dtype — mirrors `model::linear` (bf16 casts the
/// fp32-master weight to bf16 on the fly; the activation stream is already bf16).
fn linear_t(cdt: DType, x: &Tensor, w: &Tensor) -> Tensor {
match cdt {
DType::F32 => x.matmul(w),
DType::BF16 => x.matmul(&w.to_dtype(DType::BF16)),
_ => unreachable!("compute dtype must be F32/BF16"),
}
}
/// A norm/QK-norm gamma in the compute dtype — mirrors `model::norm_gamma`.
fn gamma_t(cdt: DType, g: &Tensor) -> Tensor {
match cdt {
DType::F32 => g.clone(),
DType::BF16 => g.to_dtype(DType::BF16),
_ => unreachable!("compute dtype must be F32/BF16"),
}
}
/// Greedy KV-cache decode: continue `prompt` by `max_new` tokens, argmax each step.
/// Returns the full token sequence (prompt + generated), matching the naive
/// `sample::generate` interface for `temperature == 0`. Token-identical to the
/// naive full-recompute greedy (gated by `tests/decode_kv.rs`).
pub fn generate_greedy_cached(
model: &TinyTransformer,
device: Device,
prompt: &[i32],
max_new: usize,
) -> Vec<i32> {
assert!(!prompt.is_empty(), "prompt must be non-empty");
let cfg = model.config();
let cdt = model.compute_dtype();
let n_layers = cfg.n_layers;
// params() is a stable, documented order (see TinyTransformer::params):
// [0] = embed [vocab, dim]
// [1 + li*11 .. +11] = layer li's 11 leaves, in block_params order:
// attn_norm, wq, wk, wv, q_norm, k_norm, wo, ffn_norm, w_gate, w_up, w_down
// [1 + n_layers*11] = final_norm [dim]
// [1 + n_layers*11 + 1] = lm_head [dim, vocab]
let params: Vec<Tensor> = model.params().iter().map(|p| p.value()).collect();
assert_eq!(
params.len(),
1 + n_layers * 11 + 2,
"unexpected param layout for decode"
);
let embed = &params[0];
let final_norm = &params[1 + n_layers * 11];
let lm_head = &params[1 + n_layers * 11 + 1];
let mut cache = KVCache::new(n_layers);
let mut tokens = prompt.to_vec();
// Prefill: feed each prompt token in order; the last step's logits are the
// distribution for the first generated token.
let mut logits = Vec::new();
for (pos, &tok) in prompt.iter().enumerate() {
logits = decode_step(&params, cfg, cdt, device, &mut cache, tok, pos, embed, final_norm, lm_head);
}
for _ in 0..max_new {
let next = argmax(&logits) as i32;
tokens.push(next);
let pos = tokens.len() - 1; // absolute position of the token just appended
logits = decode_step(&params, cfg, cdt, device, &mut cache, next, pos, embed, final_norm, lm_head);
}
tokens
}
/// One incremental decode step for token `tok` at absolute position `pos`: append
/// its K/V to the cache and return the next-token logits as host f32 `[vocab]`.
#[allow(clippy::too_many_arguments)]
fn decode_step(
params: &[Tensor],
cfg: &crate::Config,
cdt: DType,
device: Device,
cache: &mut KVCache,
tok: i32,
pos: usize,
embed: &Tensor,
final_norm: &Tensor,
lm_head: &Tensor,
) -> Vec<f32> {
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
let dim = cfg.dim;
let kv_dim = num_kv * hd;
let scale = 1.0 / (hd as f32).sqrt();
let (theta, eps) = (cfg.rope_theta, cfg.eps);
let n_layers = cfg.n_layers;
// Embedding (fp32 table) → activation stream in the compute dtype.
let ids = Tensor::from_slice(&[tok], &[1]).to_device(device);
let mut h = embed.embedding(&ids); // [1, dim] f32
if cdt == DType::BF16 {
h = h.to_dtype(DType::BF16);
}
for li in 0..n_layers {
let base = 1 + li * 11;
let (attn_norm, wq, wk, wv) =
(&params[base], &params[base + 1], &params[base + 2], &params[base + 3]);
let (q_norm, k_norm, wo) = (&params[base + 4], &params[base + 5], &params[base + 6]);
let (ffn_norm, w_gate, w_up, w_down) =
(&params[base + 7], &params[base + 8], &params[base + 9], &params[base + 10]);
// --- Attention sub-block (pre-norm + cached-KV attention + residual) ---
let normed = h.rms_norm(&gamma_t(cdt, attn_norm), eps).0; // [1, dim]
// Q: project → per-head QK-norm → RoPE at absolute position `pos`.
let q = linear_t(cdt, &normed, wq).reshape(&[1, nh, hd]); // [1, nh, hd]
let q = q.reshape(&[nh, hd]).rms_norm(&gamma_t(cdt, q_norm), eps).0;
let q = q.reshape(&[1, nh, hd]).rope_at(theta, pos);
let q_bh = q.reshape(&[nh, 1, hd]); // seq=1 ⇒ the head-transpose is a no-op on data
// K: same as Q (QK-norm + RoPE); cache token-major. V: project only.
let k = linear_t(cdt, &normed, wk).reshape(&[1, num_kv, hd]);
let k = k.reshape(&[num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
let k_tok = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos); // [1, num_kv, hd]
let v_tok = linear_t(cdt, &normed, wv).reshape(&[1, num_kv, hd]);
let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
cache.append(li, &k_host, &v_host);
// Rebuild the full K/V for this layer: token-major [T,num_kv,hd] → [num_kv,T,hd]
// → repeat_kv to [nh,T,hd].
let t_len = cache.k[li].len() / kv_dim;
let build = |flat: &[f32]| -> Tensor {
let bh_kv = Tensor::from_slice(flat, &[t_len, num_kv, hd])
.to_device(device)
.transpose_3d01(); // [num_kv, T, hd], f32
let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv };
if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, 1) } // [nh, T, hd]
};
let k_full = build(&cache.k[li]);
let v_full = build(&cache.v[li]);
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [nh, hd]
let attn = attn.reshape(&[1, dim]); // concat heads (nh·hd == dim)
let attn_out = linear_t(cdt, &attn, wo); // [1, dim]
h = h.add(&attn_out);
// --- MLP sub-block (pre-norm + SwiGLU + residual) ---
let normed = h.rms_norm(&gamma_t(cdt, ffn_norm), eps).0;
let gate = linear_t(cdt, &normed, w_gate);
let up = linear_t(cdt, &normed, w_up);
let act = gate.silu().mul(&up); // swiglu = silu(gate) ∘ up
let down = linear_t(cdt, &act, w_down);
h = h.add(&down);
}
let h = h.rms_norm(&gamma_t(cdt, final_norm), eps).0;
let logits = linear_t(cdt, &h, lm_head); // [1, vocab]
logits
.to_dtype(DType::F32)
.to_device(Device::Cpu)
.as_slice::<f32>()
.to_vec()
}
fn argmax(row: &[f32]) -> usize {
row.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.unwrap()
.0
}

View File

@@ -25,3 +25,8 @@ pub use config::Config;
mod model;
#[cfg(not(no_cuda))]
pub use model::{TinyTransformer, batched_ids_tensor, ids_tensor, param_to_host};
#[cfg(not(no_cuda))]
pub mod decode;
#[cfg(not(no_cuda))]
pub use decode::generate_greedy_cached;

View File

@@ -103,6 +103,10 @@ fn main() {
let n_show = flag(&args, "--show", 8usize);
let prompts_file = flag_value(&args, "--prompts-file").expect("--prompts-file is required");
let gold_file = flag_value(&args, "--gold-file").expect("--gold-file is required");
// M2: decode through the KV-cache incremental engine instead of the naive
// full-recompute sampler. Token-identical to the naive path (gated by
// tests/decode_kv.rs); this flag also lets us A/B the two for the speedup.
let use_cached = args.iter().any(|a| a == "--cached");
// Prompts: skip the `#` header / blank lines and decode escaped newlines so the
// count and order line up with the gold file.
@@ -148,18 +152,26 @@ fn main() {
xtrain_train::checkpoint::load_into(&ckpt, &model.params()).expect("load checkpoint");
println!(
"eval_arith: ckpt {} | {} prompts | max_new {}",
"eval_arith: ckpt {} | {} prompts | max_new {} | decode={}",
ckpt.display(),
prompts.len(),
max_new
max_new,
if use_cached { "kv-cache" } else { "naive" }
);
let (mut n_boxed, mut n_correct) = (0usize, 0usize);
let mut shown = 0usize;
let mut gen_tokens = 0usize;
let t0 = std::time::Instant::now();
for (prompt, &gold) in prompts.iter().zip(&golds) {
let ids: Vec<i32> = tok.encode(prompt).into_iter().map(|t| t as i32).collect();
let mut rng = 7u64;
let out = generate(&model, device, &ids, max_new, 0.0, &mut rng);
let out = if use_cached {
xtrain_model::generate_greedy_cached(&model, device, &ids, max_new)
} else {
let mut rng = 7u64;
generate(&model, device, &ids, max_new, 0.0, &mut rng)
};
gen_tokens += out.len() - ids.len();
let cont = tok.decode(&out[ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
let seg = first_answer_segment(&cont);
if parse_boxed_answer(seg).is_some() {
@@ -176,6 +188,7 @@ fn main() {
}
}
let elapsed = t0.elapsed().as_secs_f64();
let n = prompts.len() as f64;
println!(
"RESULT format(boxed)={}/{} ({:.1}%) | correct={}/{} ({:.1}%)",
@@ -186,4 +199,11 @@ fn main() {
prompts.len(),
100.0 * n_correct as f64 / n,
);
println!(
"TIMING decode={} | {:.2}s | {} gen tokens | {:.1} tok/s",
if use_cached { "kv-cache" } else { "naive" },
elapsed,
gen_tokens,
gen_tokens as f64 / elapsed,
);
}

View File

@@ -0,0 +1,94 @@
// M2a KV-cache decode engine — the token-identical correctness gate.
//
// The centerpiece M2 invariant: greedy decode through the KV-cache incremental
// engine (`xtrain_model::generate_greedy_cached`) must be TOKEN-IDENTICAL to the
// naive full-recompute greedy (`xtrain_train::sample::generate` at temperature 0),
// which re-runs the whole forward over the growing prefix each step. Same tokens ⇒
// the cache + decode-time attention + RoPE-at-position reproduce the full forward.
//
// Numerics note: a randomly-initialised model has near-uniform logits, so argmax
// can be fragile to ~1e-6 differences. This unit gate therefore runs in F32 (the
// tightest path, and the dtype the eval harness actually uses) on a small model.
// The headline gate on the trained v12 checkpoint (peaked logits → robust argmax)
// is run on the GPU box and recorded in docs/18.
#![cfg(not(no_cuda))]
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, generate_greedy_cached};
use xtrain_tensor::{DType, Device};
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
let mut state = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
})
.collect()
}
fn build(cfg: Config, device: Device, dtype: DType) -> TinyTransformer {
let mut seed = 1u64;
let m = TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
} else {
fill(n, seed, 0.08)
}
});
m.with_compute_dtype(dtype)
}
// A real GQA config (8 query / 2 kv heads → group 4) to exercise repeat_kv in the
// decode path; head_dim 16, dim 128, 4 layers.
fn gqa_cfg() -> Config {
Config::from_arch(48, 8, 16, 4, 256).with_kv_heads(2)
}
#[test]
fn kv_cache_decode_is_token_identical_to_naive_f32() {
assert!(
device::device_count().expect("device count") > 0,
"no CUDA device"
);
device::set_device(0).unwrap();
let device = Device::Cuda(0);
let model = build(gqa_cfg(), device, DType::F32);
let prompt: Vec<i32> = vec![1, 5, 9, 13, 2, 7];
let max_new = 24usize;
let mut rng = 7u64;
let naive = xtrain_train::sample::generate(&model, device, &prompt, max_new, 0.0, &mut rng);
let cached = generate_greedy_cached(&model, device, &prompt, max_new);
assert_eq!(
naive.len(),
cached.len(),
"length mismatch: naive {} vs cached {}",
naive.len(),
cached.len()
);
if naive != cached {
// Report the first divergence for debugging.
let first = naive
.iter()
.zip(&cached)
.position(|(a, b)| a != b)
.unwrap();
panic!(
"token divergence at index {first}: naive={:?} cached={:?}\nnaive ={naive:?}\ncached ={cached:?}",
naive[first], cached[first]
);
}
println!(
"KV-cache decode token-identical to naive over {} generated tokens (F32, GQA 8/2)",
max_new
);
}