Files
xtrain/crates/xtrain-model/src/decode.rs
Gahow Wang 2c9b58cb3b post-train: M2b — batched KV-cache decode (G-way, token-identical)
The rollout long-pole fix deferred from M2a: decode the G samples of one prompt
in lockstep (one forward per step over the group → G× fewer kernel launches).

- rope_pos(x, positions[]): RoPE with a per-row absolute position (new forward-
  only kernel) — G rows share one decode position. Gate: == full rope for
  [0..n], == rope_at(P) per row for uniform P (bit-identical).
- generate_cached_batch: BatchKVCache [T, G·num_kv, hd] + batched decode_step.
  decode_attention is already batch-agnostic (bh = G·nh); repeat_kv(nh, batch=G)
  broadcasts per group. No finished-mask / ragged prompts yet (perf-only / next).
- Gate (tests/decode_batch.rs): all G greedy rows token-identical to the single-
  sequence decode (8 query / 2 kv heads → exercises repeat_kv batching).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 17:18:54 +08:00

442 lines
18 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! KV-cache incremental-decode engine (post-training M2a single-sequence decode, plus the M2b G-way batched decode at the bottom of this file).
//!
//! The naive sampler ([`crate::TinyTransformer`] via `train::sample::generate`)
//! re-runs the full forward over the whole growing prefix every step — O(t²) and
//! a fresh autograd graph per token. This is the inference engine that replaces it:
//! a per-layer **K/V cache** + a **single-token incremental forward** that processes
//! one new token at a time, attending to the cached keys/values.
//!
//! Built on three primitives, all gated by their own correctness tests:
//! - [`Tensor::rope_at`](xtrain_tensor::Tensor::rope_at): RoPE at the token's
//! absolute position (not row-in-tile), so cached post-RoPE K matches the full
//! forward (bit-identical, `integration::rope_at_matches_full_rope_row`).
//! - [`Tensor::decode_attention`](xtrain_tensor::Tensor::decode_attention): the
//! single-query × cached-K/V SDPA, equal to the full causal attention's last row
//! (`integration::decode_attention_matches_full_attention_last_row`).
//! - this module's per-token block forward, mirroring `model::block_forward` at the
//! raw-Tensor level (no autograd tape — inference needs no gradients).
//!
//! Correctness gate (the M2 centerpiece): KV-cache greedy decode is **token-
//! identical** to the naive full-recompute greedy (`tests/decode_kv.rs`).
//!
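//! Prefill is just the first `prompt.len()` decode steps (one token at a time) —
//! one code path, at the cost of a non-batched prefill. The M2b section at the
//! bottom of this file batches G samples of one prompt per decode step, but it
//! still prefills token by token; batched prefill and ragged (different-length)
//! batch decode remain deferred. The cache is host-accumulated (token-major f32)
//! and the K/V tensor is rebuilt per step; the host round-trip is small
//! (`num_kv·head_dim` floats/token/layer) and is the honest M2a baseline. The M2b
//! batched cache is host-accumulated the same way; moving it device-side is a
//! later follow-up.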
#![cfg(not(no_cuda))]
use crate::TinyTransformer;
use xtrain_tensor::{DType, Device, Tensor};
/// Host-side K/V cache for a single sequence, one growing buffer per layer.
///
/// `k[li]` and `v[li]` are flat, token-major `[T, num_kv, head_dim]` f32 buffers;
/// every decode step appends one token's `num_kv·head_dim` values to each. The
/// values are kept as f32 (an exact upcast of the bf16 projection output) and are
/// converted back to the compute dtype when the K/V tensor is formed, so bf16
/// values round-trip bit-for-bit.
struct KVCache {
    k: Vec<Vec<f32>>,
    v: Vec<Vec<f32>>,
}

impl KVCache {
    /// Create a cache with `n_layers` empty K buffers and `n_layers` empty V buffers.
    fn new(n_layers: usize) -> Self {
        let empty_layers = || (0..n_layers).map(|_| Vec::new()).collect::<Vec<Vec<f32>>>();
        Self {
            k: empty_layers(),
            v: empty_layers(),
        }
    }

    /// Push one token's K and V slabs (each `num_kv·head_dim` f32) onto layer `li`.
    fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) {
        let (k_layer, v_layer) = (&mut self.k[li], &mut self.v[li]);
        k_layer.extend(k_tok.iter().copied());
        v_layer.extend(v_tok.iter().copied());
    }
}
/// Matrix product `x @ w` carried out in the compute dtype `cdt` (counterpart of
/// `model::linear`). For `F32` the weight is used as-is; for `BF16` the weight is
/// cast to bf16 on the fly before the matmul (the activation stream is already
/// bf16). Any other dtype is a programming error and panics.
fn linear_t(cdt: DType, x: &Tensor, w: &Tensor) -> Tensor {
    // `None` means "use the weight untouched"; `Some` holds the bf16 copy.
    let cast_weight = match cdt {
        DType::F32 => None,
        DType::BF16 => Some(w.to_dtype(DType::BF16)),
        _ => unreachable!("compute dtype must be F32/BF16"),
    };
    x.matmul(cast_weight.as_ref().unwrap_or(w))
}
/// Returns the norm / QK-norm gain `g` in the compute dtype `cdt` (counterpart of
/// `model::norm_gamma`): a clone for `F32`, a bf16 cast for `BF16`. Any other
/// dtype is a programming error and panics.
fn gamma_t(cdt: DType, g: &Tensor) -> Tensor {
    if cdt == DType::F32 {
        return g.clone();
    }
    if cdt == DType::BF16 {
        return g.to_dtype(DType::BF16);
    }
    unreachable!("compute dtype must be F32/BF16")
}
/// Greedy KV-cache decode: extends `prompt` by `max_new` tokens, taking the argmax
/// at every step.
///
/// Returns the whole sequence (prompt followed by the generated tokens), i.e. the
/// same shape of result as the naive `sample::generate` with `temperature == 0`.
/// The output is token-identical to the naive full-recompute greedy decode (gated
/// by `tests/decode_kv.rs`).
///
/// This is [`generate_cached`] called with a temperature of `0.0`.
pub fn generate_greedy_cached(
    model: &TinyTransformer,
    device: Device,
    prompt: &[i32],
    max_new: usize,
) -> Vec<i32> {
    const GREEDY_TEMPERATURE: f32 = 0.0;
    // `generate_cached` needs an RNG state argument; a local throwaway satisfies it.
    let mut throwaway_rng_state: u64 = 0;
    generate_cached(
        model,
        device,
        prompt,
        max_new,
        GREEDY_TEMPERATURE,
        &mut throwaway_rng_state,
    )
}
/// KV-cache decode with temperature sampling (`temperature <= 0` → greedy argmax,
/// matching [`generate_greedy_cached`]; otherwise sample from `softmax(logits/T)`).
///
/// The KV-cache rollout the GRPO loop uses: each step allocates only a single-row
/// `[1, vocab]` logits buffer (vs the naive sampler's `[seq, vocab]`), so it is far
/// lighter on memory + the allocator — the naive sampler fragments the caching
/// allocator over a long rollout, which is the M4 "rollout is the long pole" wall.
///
/// Returns the full token sequence: `prompt` followed by `max_new` generated
/// tokens. `rng_state` is advanced once per sampled token when `temperature > 0`
/// and is left untouched for greedy decoding.
///
/// # Panics
///
/// Panics if `prompt` is empty or if the model's parameter list does not have the
/// `1 + n_layers * 11 + 2` layout this decoder indexes into.
pub fn generate_cached(
    model: &TinyTransformer,
    device: Device,
    prompt: &[i32],
    max_new: usize,
    temperature: f32,
    rng_state: &mut u64,
) -> Vec<i32> {
    assert!(!prompt.is_empty(), "prompt must be non-empty");
    let cfg = model.config();
    let cdt = model.compute_dtype();
    let n_layers = cfg.n_layers;
    // params() is a stable, documented order (see TinyTransformer::params):
    //   [0]                     = embed [vocab, dim]
    //   [1 + li*11 .. +11]      = layer li's 11 leaves, in block_params order:
    //       attn_norm, wq, wk, wv, q_norm, k_norm, wo, ffn_norm, w_gate, w_up, w_down
    //   [1 + n_layers*11]       = final_norm [dim]
    //   [1 + n_layers*11 + 1]   = lm_head [dim, vocab]
    let params: Vec<Tensor> = model.params().iter().map(|p| p.value()).collect();
    assert_eq!(
        params.len(),
        1 + n_layers * 11 + 2,
        "unexpected param layout for decode"
    );
    let embed = &params[0];
    let final_norm = &params[1 + n_layers * 11];
    let lm_head = &params[1 + n_layers * 11 + 1];

    let mut cache = KVCache::new(n_layers);
    let mut tokens = prompt.to_vec();

    // Prefill: feed each prompt token in order; the last step's logits are the
    // distribution for the first generated token.
    let mut logits = Vec::new();
    for (pos, &tok) in prompt.iter().enumerate() {
        logits = decode_step(&params, cfg, cdt, device, &mut cache, tok, pos, embed, final_norm, lm_head);
    }

    for step in 0..max_new {
        let next = if temperature <= 0.0 {
            argmax(&logits) as i32
        } else {
            sample_temperature(&logits, temperature, rng_state) as i32
        };
        tokens.push(next);
        // The logits produced by feeding the final token would never be read, so
        // skip that forward pass instead of paying for a whole extra step.
        if step + 1 == max_new {
            break;
        }
        let pos = tokens.len() - 1; // absolute position of the token just appended
        logits = decode_step(&params, cfg, cdt, device, &mut cache, next, pos, embed, final_norm, lm_head);
    }
    tokens
}
/// Draws one index from `softmax(row / temperature)` using inverse-CDF sampling.
///
/// The logits are shifted by their maximum before exponentiation for numerical
/// stability. Randomness comes from a 64-bit LCG held in `rng_state`, which is
/// advanced exactly once per call; the scheme is the same LCG + inverse-CDF used
/// by the naive `sample::sample_temperature`.
fn sample_temperature(row: &[f32], temperature: f32, rng_state: &mut u64) -> usize {
    // Largest logit; subtracting it keeps every exponent <= 0.
    let mut peak = f32::NEG_INFINITY;
    for &logit in row {
        peak = peak.max(logit);
    }

    // Unnormalised probabilities and their total mass.
    let weights: Vec<f32> = row
        .iter()
        .map(|&logit| ((logit - peak) / temperature).exp())
        .collect();
    let total: f32 = weights.iter().sum();

    // Step the LCG, then map its high 32 bits onto [0, total].
    let advanced = rng_state
        .wrapping_mul(6364136223846793005)
        .wrapping_add(1442695040888963407);
    *rng_state = advanced;
    let threshold = ((advanced >> 32) as f32 / u32::MAX as f32) * total;

    // First index whose cumulative mass reaches the threshold; fall back to the
    // last index if none does.
    let mut running = 0.0_f32;
    weights
        .iter()
        .position(|&w| {
            running += w;
            running >= threshold
        })
        .unwrap_or_else(|| weights.len() - 1)
}
/// One incremental decode step for token `tok` at absolute position `pos`: append
/// its K/V to the cache and return the next-token logits as host f32 `[vocab]`.
///
/// `params` is the flat parameter list; layer `li` reads the 11 tensors starting at
/// index `1 + li * 11` (attn_norm, wq, wk, wv, q_norm, k_norm, wo, ffn_norm, w_gate,
/// w_up, w_down, in that order). `embed`, `final_norm` and `lm_head` are handed in
/// separately by the caller. `cache` is mutated: each layer gains exactly one
/// token's K slab and V slab per call.
///
/// Every call copies the new token's K/V to the host and then rebuilds the layer's
/// whole cached K/V tensor from the host buffer and moves it to `device`.
// The step needs the model pieces, the cache and the token/position individually,
// hence the long argument list.
#[allow(clippy::too_many_arguments)]
fn decode_step(
params: &[Tensor],
cfg: &crate::Config,
cdt: DType,
device: Device,
cache: &mut KVCache,
tok: i32,
pos: usize,
embed: &Tensor,
final_norm: &Tensor,
lm_head: &Tensor,
) -> Vec<f32> {
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
let dim = cfg.dim;
// Number of f32 values one token contributes to a layer's cache; used below to
// recover the cached length T from the flat buffer length.
let kv_dim = num_kv * hd;
// Attention scale 1/sqrt(head_dim), passed to decode_attention.
let scale = 1.0 / (hd as f32).sqrt();
let (theta, eps) = (cfg.rope_theta, cfg.eps);
let n_layers = cfg.n_layers;
// Embedding (fp32 table) → activation stream in the compute dtype.
let ids = Tensor::from_slice(&[tok], &[1]).to_device(device);
let mut h = embed.embedding(&ids); // [1, dim] f32
if cdt == DType::BF16 {
h = h.to_dtype(DType::BF16);
}
for li in 0..n_layers {
// Index of this layer's first parameter inside the flat `params` list.
let base = 1 + li * 11;
let (attn_norm, wq, wk, wv) =
(&params[base], &params[base + 1], &params[base + 2], &params[base + 3]);
let (q_norm, k_norm, wo) = (&params[base + 4], &params[base + 5], &params[base + 6]);
let (ffn_norm, w_gate, w_up, w_down) =
(&params[base + 7], &params[base + 8], &params[base + 9], &params[base + 10]);
// --- Attention sub-block (pre-norm + cached-KV attention + residual) ---
// Only element `.0` of the rms_norm result is used (presumably the normalized
// output — confirm against Tensor::rms_norm).
let normed = h.rms_norm(&gamma_t(cdt, attn_norm), eps).0; // [1, dim]
// Q: project → per-head QK-norm → RoPE at absolute position `pos`.
// NOTE(review): the [1, nh, hd] reshape is immediately replaced by the [nh, hd]
// reshape on the next line.
let q = linear_t(cdt, &normed, wq).reshape(&[1, nh, hd]); // [1, nh, hd]
let q = q.reshape(&[nh, hd]).rms_norm(&gamma_t(cdt, q_norm), eps).0;
let q = q.reshape(&[1, nh, hd]).rope_at(theta, pos);
let q_bh = q.reshape(&[nh, 1, hd]); // seq=1 ⇒ the head-transpose is a no-op on data
// K: same as Q (QK-norm + RoPE); cache token-major. V: project only.
let k = linear_t(cdt, &normed, wk).reshape(&[1, num_kv, hd]);
let k = k.reshape(&[num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
let k_tok = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos); // [1, num_kv, hd]
let v_tok = linear_t(cdt, &normed, wv).reshape(&[1, num_kv, hd]);
// Bring this token's K/V to the host as f32 and append it to the layer's cache.
let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
cache.append(li, &k_host, &v_host);
// Rebuild the full K/V for this layer: token-major [T,num_kv,hd] → [num_kv,T,hd]
// → repeat_kv to [nh,T,hd].
// T = number of tokens cached so far for this layer (including the one just appended).
let t_len = cache.k[li].len() / kv_dim;
// Shared by K and V: wrap the flat host buffer as a tensor, move it to `device`,
// transpose, cast to bf16 if that is the compute dtype, and expand the KV heads
// only when num_kv differs from nh.
let build = |flat: &[f32]| -> Tensor {
let bh_kv = Tensor::from_slice(flat, &[t_len, num_kv, hd])
.to_device(device)
.transpose_3d01(); // [num_kv, T, hd], f32
let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv };
if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, 1) } // [nh, T, hd]
};
let k_full = build(&cache.k[li]);
let v_full = build(&cache.v[li]);
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [nh, hd]
let attn = attn.reshape(&[1, dim]); // concat heads (nh·hd == dim)
let attn_out = linear_t(cdt, &attn, wo); // [1, dim]
// Residual connection around the attention sub-block.
h = h.add(&attn_out);
// --- MLP sub-block (pre-norm + SwiGLU + residual) ---
let normed = h.rms_norm(&gamma_t(cdt, ffn_norm), eps).0;
let gate = linear_t(cdt, &normed, w_gate);
let up = linear_t(cdt, &normed, w_up);
let act = gate.silu().mul(&up); // swiglu = silu(gate) ∘ up
let down = linear_t(cdt, &act, w_down);
// Residual connection around the MLP sub-block.
h = h.add(&down);
}
// Final norm + LM head, then copy the logits to the host as f32.
let h = h.rms_norm(&gamma_t(cdt, final_norm), eps).0;
let logits = linear_t(cdt, &h, lm_head); // [1, vocab]
logits
.to_dtype(DType::F32)
.to_device(Device::Cpu)
.as_slice::<f32>()
.to_vec()
}
/// Index of the largest value in `row`. When several entries compare equal to the
/// maximum, the last of them wins.
///
/// Panics if `row` is empty, or if a comparison involves NaN.
fn argmax(row: &[f32]) -> usize {
    let mut candidates = row.iter().enumerate();
    let (mut best_idx, mut best_val) = candidates.next().unwrap();
    for (idx, val) in candidates {
        // Keep the incumbent only when it is strictly greater, so ties move to
        // the later index.
        if best_val.partial_cmp(val).unwrap() != std::cmp::Ordering::Greater {
            best_idx = idx;
            best_val = val;
        }
    }
    best_idx
}
// ===================================================================
// M2b — batched KV-cache decode (G samples of one prompt, in lockstep)
// ===================================================================
/// Host-side K/V cache for `G` sequences that advance in lockstep.
///
/// For every layer, `k[li]` and `v[li]` are flat, seq-major
/// `[T, G·num_kv, head_dim]` f32 buffers; one decode step appends `G·num_kv·hd`
/// values to each, and the buffers are rebuilt into `[G·num_kv, T, hd]` tensors
/// per step. This is the M2a host-cache layout with an extra G dimension.
struct BatchKVCache {
    k: Vec<Vec<f32>>,
    v: Vec<Vec<f32>>,
}

impl BatchKVCache {
    /// Create a cache with `n_layers` empty K buffers and `n_layers` empty V buffers.
    fn new(n_layers: usize) -> Self {
        let fresh = || {
            std::iter::repeat_with(Vec::new)
                .take(n_layers)
                .collect::<Vec<Vec<f32>>>()
        };
        let k = fresh();
        let v = fresh();
        Self { k, v }
    }

    /// Push one step's K and V slabs (all G sequences) onto layer `li`.
    fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) {
        let keys = &mut self.k[li];
        keys.extend(k_tok);
        let values = &mut self.v[li];
        values.extend(v_tok);
    }
}
/// Batched KV-cache decode: roll out `n_samples` (G) completions of the SAME
/// `prompt` in lockstep — all G share the prompt, so they advance at one common
/// decode position each step (uniform RoPE via `rope_pos`). Returns G full token
/// sequences (prompt + sampled continuation). The G-way batching amortises the
/// per-step kernel launches across G (the rollout long-pole). Token-identical per
/// row to G independent single-sequence decodes (gated by `tests/decode_batch.rs`).
///
/// `temperature <= 0` ⇒ greedy (all G identical); `> 0` ⇒ independent samples
/// (per-row draw, in row order, from one shared `rng_state`). No finished-mask:
/// all G generate `max_new` tokens; the caller cuts each at `<|endoftext|>` (a
/// perf-only early stop is the M2b+ follow-up). Ragged (different-length prompts)
/// is also deferred.
///
/// # Panics
///
/// Panics if `prompt` is empty, if `n_samples` is zero, if the model's parameter
/// list does not have the `1 + n_layers * 11 + 2` layout this decoder indexes
/// into, or if a decode step returns logits that are not `G * vocab` long.
pub fn generate_cached_batch(
    model: &TinyTransformer,
    device: Device,
    prompt: &[i32],
    n_samples: usize,
    max_new: usize,
    temperature: f32,
    rng_state: &mut u64,
) -> Vec<Vec<i32>> {
    assert!(!prompt.is_empty(), "prompt must be non-empty");
    assert!(n_samples > 0, "n_samples must be > 0");
    let cfg = model.config();
    let cdt = model.compute_dtype();
    let n_layers = cfg.n_layers;
    // Same flat layout as the single-sequence decoder: [0] = embed, then 11 leaves
    // per layer, then final_norm and lm_head. Check it up front so a changed layout
    // fails loudly instead of silently picking the wrong tensors.
    let params: Vec<Tensor> = model.params().iter().map(|p| p.value()).collect();
    assert_eq!(
        params.len(),
        1 + n_layers * 11 + 2,
        "unexpected param layout for decode"
    );
    let embed = &params[0];
    let final_norm = &params[1 + n_layers * 11];
    let lm_head = &params[1 + n_layers * 11 + 1];

    let g = n_samples;
    let vocab = cfg.vocab;
    let mut cache = BatchKVCache::new(n_layers);
    let mut seqs: Vec<Vec<i32>> = vec![prompt.to_vec(); g];

    // Prefill: feed each prompt token (identical across G) at its position.
    let mut logits = Vec::new(); // [G, vocab] flattened
    for (pos, &tok) in prompt.iter().enumerate() {
        let toks = vec![tok; g];
        logits = decode_step_batch(&params, cfg, cdt, device, &mut cache, &toks, pos, embed, final_norm, lm_head);
    }

    for step in 0..max_new {
        // The per-row split below is only meaningful for a dense [G, vocab] buffer.
        assert_eq!(logits.len(), g * vocab, "batched logits must be [G, vocab]");
        let mut next = Vec::with_capacity(g);
        for (seq, row_logits) in seqs.iter_mut().zip(logits.chunks_exact(vocab)) {
            let t = if temperature <= 0.0 {
                argmax(row_logits) as i32
            } else {
                sample_temperature(row_logits, temperature, rng_state) as i32
            };
            next.push(t);
            seq.push(t);
        }
        // The logits produced by feeding the final tokens would never be read, so
        // skip that forward pass instead of paying for a whole extra batched step.
        if step + 1 == max_new {
            break;
        }
        let pos = seqs[0].len() - 1; // all G are at the same position
        logits = decode_step_batch(&params, cfg, cdt, device, &mut cache, &next, pos, embed, final_norm, lm_head);
    }
    seqs
}
/// One batched decode step: `toks` is one current token per sequence (`[G]`), all at
/// absolute position `pos`. Appends each sequence's K/V and returns logits `[G·vocab]`.
///
/// `G` is taken from `toks.len()`. `params` is the flat parameter list; layer `li`
/// reads the 11 tensors starting at index `1 + li * 11` (attn_norm, wq, wk, wv,
/// q_norm, k_norm, wo, ffn_norm, w_gate, w_up, w_down, in that order). `cache` is
/// mutated: each layer gains one step's K slab and V slab per call.
///
/// Every call copies the new K/V to the host and then rebuilds the layer's whole
/// cached K/V tensor from the host buffer and moves it to `device`.
// The step needs the model pieces, the cache and the tokens/position individually,
// hence the long argument list.
#[allow(clippy::too_many_arguments)]
fn decode_step_batch(
params: &[Tensor],
cfg: &crate::Config,
cdt: DType,
device: Device,
cache: &mut BatchKVCache,
toks: &[i32],
pos: usize,
embed: &Tensor,
final_norm: &Tensor,
lm_head: &Tensor,
) -> Vec<f32> {
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
let dim = cfg.dim;
let g = toks.len();
let g_kv = g * num_kv; // batch·num_kv heads in the cache
// Attention scale 1/sqrt(head_dim), passed to decode_attention.
let scale = 1.0 / (hd as f32).sqrt();
let (theta, eps) = (cfg.rope_theta, cfg.eps);
let n_layers = cfg.n_layers;
// Uniform per-row position (all G at the same decode step).
// NOTE(review): `pos as i32` truncates silently if `pos` exceeds i32::MAX.
let positions = Tensor::from_slice(&vec![pos as i32; g], &[g]).to_device(device);
let ids = Tensor::from_slice(toks, &[g]).to_device(device);
let mut h = embed.embedding(&ids); // [G, dim] f32
if cdt == DType::BF16 {
h = h.to_dtype(DType::BF16);
}
for li in 0..n_layers {
// Index of this layer's first parameter inside the flat `params` list.
let base = 1 + li * 11;
let (attn_norm, wq, wk, wv) =
(&params[base], &params[base + 1], &params[base + 2], &params[base + 3]);
let (q_norm, k_norm, wo) = (&params[base + 4], &params[base + 5], &params[base + 6]);
let (ffn_norm, w_gate, w_up, w_down) =
(&params[base + 7], &params[base + 8], &params[base + 9], &params[base + 10]);
// --- Attention sub-block (pre-norm + cached-KV attention + residual) ---
// Only element `.0` of the rms_norm result is used (presumably the normalized
// output — confirm against Tensor::rms_norm).
let normed = h.rms_norm(&gamma_t(cdt, attn_norm), eps).0; // [G, dim]
// Q: project → per-head QK-norm → RoPE at `pos` for every row.
// NOTE(review): the [g, nh, hd] reshape is immediately replaced by the
// [g * nh, hd] reshape on the next line.
let q = linear_t(cdt, &normed, wq).reshape(&[g, nh, hd]);
let q = q.reshape(&[g * nh, hd]).rms_norm(&gamma_t(cdt, q_norm), eps).0;
let q = q.reshape(&[g, nh, hd]).rope_pos(&positions, theta);
let q_bh = q.reshape(&[g * nh, 1, hd]); // bh = G·nh
// K: same pipeline as Q (QK-norm + RoPE). V: projection only.
let k = linear_t(cdt, &normed, wk).reshape(&[g, num_kv, hd]);
let k = k.reshape(&[g * num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
let k_tok = k.reshape(&[g, num_kv, hd]).rope_pos(&positions, theta);
let v_tok = linear_t(cdt, &normed, wv).reshape(&[g, num_kv, hd]);
// Bring this step's K/V to the host as f32 and append it to the layer's cache.
let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
cache.append(li, &k_host, &v_host);
// Rebuild [T, G·num_kv, hd] → [G·num_kv, T, hd] → repeat_kv to [G·nh, T, hd].
// T = number of steps cached so far for this layer (including the one just appended).
let t_len = cache.k[li].len() / (g_kv * hd);
// Shared by K and V: wrap the flat host buffer as a tensor, move it to `device`,
// transpose, cast to bf16 if that is the compute dtype, and expand the KV heads
// only when num_kv differs from nh.
let build = |flat: &[f32]| -> Tensor {
let bh_kv = Tensor::from_slice(flat, &[t_len, g_kv, hd])
.to_device(device)
.transpose_3d01(); // [G·num_kv, T, hd], f32
let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv };
if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, g) } // [G·nh, T, hd]
};
let k_full = build(&cache.k[li]);
let v_full = build(&cache.v[li]);
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [G·nh, hd]
let attn = attn.reshape(&[g, dim]); // concat heads per sequence
let attn_out = linear_t(cdt, &attn, wo);
// Residual connection around the attention sub-block.
h = h.add(&attn_out);
// --- MLP sub-block (pre-norm + SwiGLU + residual) ---
let normed = h.rms_norm(&gamma_t(cdt, ffn_norm), eps).0;
let gate = linear_t(cdt, &normed, w_gate);
let up = linear_t(cdt, &normed, w_up);
let act = gate.silu().mul(&up);
let down = linear_t(cdt, &act, w_down);
// Residual connection around the MLP sub-block.
h = h.add(&down);
}
// Final norm + LM head, then copy the logits to the host as f32.
let h = h.rms_norm(&gamma_t(cdt, final_norm), eps).0;
linear_t(cdt, &h, lm_head)
.to_dtype(DType::F32)
.to_device(Device::Cpu)
.as_slice::<f32>()
.to_vec()
}