Files
xtrain/crates/xtrain-model/src/decode.rs
Gahow Wang 3a3425960c post-train: M2c — device-side KV cache (cat_seq), profile-first bottleneck shift
Device-resident KV cache: keep K/V on the GPU as [bh,T,hd], grow by one token
per step via a new cat_seq kernel (concat along seq) — removes the M2a/M2b
per-layer host round-trip (to_cpu/from_slice/re-upload) AND the transpose_3d01.
Both single-seq and batched decode refactored to it; cache is Option<Tensor>
per layer (cleaner than the host Vec + rebuild).

Gates all hold: cat_seq == host concat; decode_kv single-seq + decode_batch
G-way both still TOKEN-IDENTICAL; GQA training path unaffected.

Honest measurement (the point): removing the host round-trip buys ~10% on pure
single-seq decode (133 → 147 tok/s @128) but does NOT move the GRPO step
(~8.5 s/step unchanged) — because after M2b batching the rollout is no longer
the step's bottleneck; the per-sample per_token_logp captures + the PG-update
forwards/backwards (model.forward, full-seq) now dominate. Measure-first lesson
(cf. T11/T17/M2a): the long pole shifted to the training-side forwards; the next
decode lever (ragged batched prefill) targets those, not the cache.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 17:38:16 +08:00

437 lines
18 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! KV-cache incremental-decode engine (post-training M2a single-sequence, M2b G-way batched, M2c device-resident cache).
//!
//! The naive sampler ([`crate::TinyTransformer`] via `train::sample::generate`)
//! re-runs the full forward over the whole growing prefix every step — O(t²) and
//! a fresh autograd graph per token. This is the inference engine that replaces it:
//! a per-layer **K/V cache** + a **single-token incremental forward** that processes
//! one new token at a time, attending to the cached keys/values.
//!
//! Built on three primitives, all gated by their own correctness tests:
//! - [`Tensor::rope_at`](xtrain_tensor::Tensor::rope_at): RoPE at the token's
//! absolute position (not row-in-tile), so cached post-RoPE K matches the full
//! forward (bit-identical, `integration::rope_at_matches_full_rope_row`).
//! - [`Tensor::decode_attention`](xtrain_tensor::Tensor::decode_attention): the
//! single-query × cached-K/V SDPA, equal to the full causal attention's last row
//! (`integration::decode_attention_matches_full_attention_last_row`).
//! - this module's per-token block forward, mirroring `model::block_forward` at the
//! raw-Tensor level (no autograd tape — inference needs no gradients).
//!
//! Correctness gate (the M2 centerpiece): KV-cache greedy decode is **token-
//! identical** to the naive full-recompute greedy (`tests/decode_kv.rs`).
//!
//! Prefill is just the first `prompt.len()` decode steps (one token at a time) —
//! one code path, at the cost of a non-batched prefill (batched/ragged prefill is
//! still deferred). The K/V cache is device-resident: per layer a `[bh, T, head_dim]`
//! tensor in the compute dtype, grown by one token per step via `cat_seq` (M2c), so
//! the M2a host round-trip (download, host concat, re-upload) is gone.
#![cfg(not(no_cuda))]
use crate::TinyTransformer;
use xtrain_tensor::{DType, Device, Tensor};
/// Per-layer K/V cache, resident on the device.
///
/// For each layer `li`, `k[li]` / `v[li]` start as `None` and, once the first token
/// has been appended, hold a `[bh, T, head_dim]` tensor in the compute dtype
/// (`bh == num_kv` for single-sequence decode). Each decode step grows them by one
/// token along the sequence axis via `cat_seq`, with no host round-trip (M2c).
struct KVCache {
    /// Per-layer cached keys (already QK-normed and RoPE-rotated by the caller).
    k: Vec<Option<Tensor>>,
    /// Per-layer cached values.
    v: Vec<Option<Tensor>>,
}

impl KVCache {
    /// An empty cache with one unfilled slot per layer.
    fn new(n_layers: usize) -> Self {
        Self {
            k: (0..n_layers).map(|_| None).collect(),
            v: (0..n_layers).map(|_| None).collect(),
        }
    }

    /// Append one token's K/V (`[bh,1,hd]`, compute dtype) to layer `li`, growing the
    /// device-resident `[bh,T,hd]` cache via `cat_seq` (no host round-trip, M2c).
    fn append(&mut self, li: usize, k_bh: Tensor, v_bh: Tensor) {
        Self::grow(&mut self.k[li], k_bh);
        Self::grow(&mut self.v[li], v_bh);
    }

    /// Concatenate `step` onto `slot` along the sequence axis, or seed an empty slot.
    fn grow(slot: &mut Option<Tensor>, step: Tensor) {
        *slot = Some(match slot.take() {
            Some(cached) => cached.cat_seq(&step),
            None => step,
        });
    }
}
/// `x @ W` evaluated in the compute dtype `cdt`, mirroring `model::linear`.
///
/// Weights are kept as fp32 masters. Under BF16 the weight is cast on the fly,
/// because the activation stream `x` is already BF16.
///
/// # Panics
/// If `cdt` is neither `F32` nor `BF16`.
fn linear_t(cdt: DType, x: &Tensor, w: &Tensor) -> Tensor {
    let weight_bf16 = match cdt {
        DType::F32 => return x.matmul(w),
        DType::BF16 => w.to_dtype(DType::BF16),
        _ => unreachable!("compute dtype must be F32/BF16"),
    };
    x.matmul(&weight_bf16)
}
/// Return a norm / QK-norm gain vector in the compute dtype, mirroring
/// `model::norm_gamma`: a clone under F32, a BF16 cast under BF16.
///
/// # Panics
/// If `cdt` is neither `F32` nor `BF16`.
fn gamma_t(cdt: DType, g: &Tensor) -> Tensor {
    let cast_to_bf16 = match cdt {
        DType::F32 => false,
        DType::BF16 => true,
        _ => unreachable!("compute dtype must be F32/BF16"),
    };
    if cast_to_bf16 {
        g.to_dtype(DType::BF16)
    } else {
        g.clone()
    }
}
/// Greedy (argmax) KV-cache decode: extend `prompt` by `max_new` tokens.
///
/// Returns prompt + generated tokens, the same shape as the naive
/// `sample::generate` at `temperature == 0`. Gated token-identical to the naive
/// full-recompute greedy by `tests/decode_kv.rs`.
pub fn generate_greedy_cached(
    model: &TinyTransformer,
    device: Device,
    prompt: &[i32],
    max_new: usize,
) -> Vec<i32> {
    // Temperature 0 takes the argmax branch, so the sampler RNG is never advanced;
    // any seed works here.
    let mut unused_rng = 0u64;
    generate_cached(model, device, prompt, max_new, 0.0, &mut unused_rng)
}
/// KV-cache decode with temperature sampling.
///
/// `temperature <= 0` gives greedy argmax, matching [`generate_greedy_cached`].
/// Otherwise each token is sampled from `softmax(logits/T)`, advancing `rng_state`
/// once per sampled token. Returns the full sequence: `prompt` followed by
/// `max_new` generated tokens.
///
/// This is the KV-cache rollout the GRPO loop uses. Each step allocates only a
/// single-row `[1, vocab]` logits buffer (vs the naive sampler's `[seq, vocab]`), so
/// it is far lighter on memory and on the allocator. The naive sampler fragments the
/// caching allocator over a long rollout, which is the M4 "rollout is the long pole"
/// wall.
///
/// Runs exactly `prompt.len() + max_new - 1` decode steps (none when `max_new == 0`).
/// The last generated token is never fed back, since its logits would be discarded.
///
/// # Panics
/// If `prompt` is empty, or the model's parameter layout is not the
/// `1 + n_layers * 11 + 2` order described below.
pub fn generate_cached(
    model: &TinyTransformer,
    device: Device,
    prompt: &[i32],
    max_new: usize,
    temperature: f32,
    rng_state: &mut u64,
) -> Vec<i32> {
    assert!(!prompt.is_empty(), "prompt must be non-empty");
    let cfg = model.config();
    let cdt = model.compute_dtype();
    let n_layers = cfg.n_layers;
    // params() is a stable, documented order (see TinyTransformer::params):
    //   [0]                    = embed [vocab, dim]
    //   [1 + li*11 .. +11]     = layer li's 11 leaves, in block_params order:
    //       attn_norm, wq, wk, wv, q_norm, k_norm, wo, ffn_norm, w_gate, w_up, w_down
    //   [1 + n_layers*11]      = final_norm [dim]
    //   [1 + n_layers*11 + 1]  = lm_head [dim, vocab]
    let params: Vec<Tensor> = model.params().iter().map(|p| p.value()).collect();
    assert_eq!(
        params.len(),
        1 + n_layers * 11 + 2,
        "unexpected param layout for decode"
    );

    let mut tokens = Vec::with_capacity(prompt.len() + max_new);
    tokens.extend_from_slice(prompt);
    if max_new == 0 {
        // Nothing to generate: skip the prefill forward entirely.
        return tokens;
    }

    let embed = &params[0];
    let final_norm = &params[1 + n_layers * 11];
    let lm_head = &params[1 + n_layers * 11 + 1];
    let mut cache = KVCache::new(n_layers);

    // Prefill: feed each prompt token in order. The last step's logits are the
    // distribution for the first generated token.
    let mut logits = Vec::new();
    for (pos, &tok) in prompt.iter().enumerate() {
        logits = decode_step(&params, cfg, cdt, device, &mut cache, tok, pos, embed, final_norm, lm_head);
    }

    for step in 0..max_new {
        let next = if temperature <= 0.0 {
            argmax(&logits) as i32
        } else {
            sample_temperature(&logits, temperature, rng_state) as i32
        };
        tokens.push(next);
        if step + 1 == max_new {
            // Final token: its logits would never be read, so skip the forward.
            break;
        }
        let pos = tokens.len() - 1; // absolute position of the token just appended
        logits = decode_step(&params, cfg, cdt, device, &mut cache, next, pos, embed, final_norm, lm_head);
    }
    tokens
}
/// Sample a token index from `softmax(row / temperature)`.
///
/// The computation is numerically stable: logits are shifted by their max before
/// exponentiation. The draw arithmetic is intentionally unchanged; it is the same
/// LCG + inverse-CDF scheme as the naive `sample::sample_temperature`.
///
/// `rng_state` is advanced by exactly one LCG step per call.
///
/// # Panics
/// If `row` is empty. Previously this underflowed `exps.len() - 1`, which wrapped
/// to `usize::MAX` in release builds.
fn sample_temperature(row: &[f32], temperature: f32, rng_state: &mut u64) -> usize {
    assert!(!row.is_empty(), "cannot sample from an empty logits row");
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|&x| ((x - max) / temperature).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // Knuth MMIX LCG step. The high 32 bits are used because they are the
    // better-mixed half of a power-of-two-modulus LCG.
    *rng_state = rng_state
        .wrapping_mul(6364136223846793005)
        .wrapping_add(1442695040888963407);
    let r = ((*rng_state >> 32) as f32 / u32::MAX as f32) * sum;
    let mut acc = 0.0;
    for (i, &e) in exps.iter().enumerate() {
        acc += e;
        if acc >= r {
            return i;
        }
    }
    // Only reachable when the running sum or `r` is NaN (e.g. NaN logits or an
    // all-`-inf` row). Every comparison fails, so fall back to the last index.
    exps.len() - 1
}
/// One incremental decode step for token `tok` at absolute position `pos`: append
/// its K/V to the cache and return the next-token logits as host f32 `[vocab]`.
///
/// Runs the whole layer stack for a single token at the raw-`Tensor` level (no
/// autograd tape), mirroring `model::block_forward`. Each layer does a pre-norm
/// attention over the cached K/V plus residual, then a pre-norm SwiGLU MLP plus
/// residual.
///
/// `params` must follow the `TinyTransformer::params` layout: 11 leaves per layer,
/// starting at index 1. `embed`, `final_norm` and `lm_head` are passed in separately.
///
/// NOTE(review): nothing here checks that `pos` equals the number of tokens already
/// in `cache`. Callers pass consecutive positions starting at 0, and RoPE
/// correctness presumably depends on that.
#[allow(clippy::too_many_arguments)]
fn decode_step(
    params: &[Tensor],
    cfg: &crate::Config,
    cdt: DType,
    device: Device,
    cache: &mut KVCache,
    tok: i32,
    pos: usize,
    embed: &Tensor,
    final_norm: &Tensor,
    lm_head: &Tensor,
) -> Vec<f32> {
    let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
    let dim = cfg.dim;
    // Standard 1/sqrt(head_dim) SDPA scaling.
    let scale = 1.0 / (hd as f32).sqrt();
    let (theta, eps) = (cfg.rope_theta, cfg.eps);
    let n_layers = cfg.n_layers;
    // Embedding (fp32 table) → activation stream in the compute dtype.
    let ids = Tensor::from_slice(&[tok], &[1]).to_device(device);
    let mut h = embed.embedding(&ids); // [1, dim] f32
    if cdt == DType::BF16 {
        h = h.to_dtype(DType::BF16);
    }
    for li in 0..n_layers {
        // This layer's 11 leaves, in block_params order.
        let base = 1 + li * 11;
        let (attn_norm, wq, wk, wv) =
            (&params[base], &params[base + 1], &params[base + 2], &params[base + 3]);
        let (q_norm, k_norm, wo) = (&params[base + 4], &params[base + 5], &params[base + 6]);
        let (ffn_norm, w_gate, w_up, w_down) =
            (&params[base + 7], &params[base + 8], &params[base + 9], &params[base + 10]);
        // --- Attention sub-block (pre-norm + cached-KV attention + residual) ---
        let normed = h.rms_norm(&gamma_t(cdt, attn_norm), eps).0; // [1, dim]
        // Q: project → per-head QK-norm → RoPE at absolute position `pos`.
        let q = linear_t(cdt, &normed, wq).reshape(&[1, nh, hd]); // [1, nh, hd]
        // Flatten to [nh, hd] so rms_norm normalises each head's vector independently.
        let q = q.reshape(&[nh, hd]).rms_norm(&gamma_t(cdt, q_norm), eps).0;
        let q = q.reshape(&[1, nh, hd]).rope_at(theta, pos);
        let q_bh = q.reshape(&[nh, 1, hd]); // seq=1 ⇒ the head-transpose is a no-op on data
        // K: same as Q (QK-norm + RoPE). V: project only. Append each as [num_kv,1,hd]
        // (bh-major) into the device cache; no host round-trip, no transpose (M2c).
        let k = linear_t(cdt, &normed, wk).reshape(&[1, num_kv, hd]);
        let k = k.reshape(&[num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
        let k_bh = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos).reshape(&[num_kv, 1, hd]);
        let v_bh = linear_t(cdt, &normed, wv).reshape(&[num_kv, 1, hd]);
        cache.append(li, k_bh, v_bh);
        // repeat_kv the cached [num_kv,T,hd] to [nh,T,hd] for the SDPA. This is skipped
        // when there is no GQA sharing (num_kv == nh). The clone is presumably a cheap
        // handle copy; TODO confirm.
        let expand = |c: &Tensor| if num_kv == nh { c.clone() } else { c.repeat_kv(nh, 1) };
        // Both slots were just populated by `append` above, so the unwraps cannot fail.
        let k_full = expand(cache.k[li].as_ref().unwrap());
        let v_full = expand(cache.v[li].as_ref().unwrap());
        let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [nh, hd]
        let attn = attn.reshape(&[1, dim]); // concat heads (nh·hd == dim)
        let attn_out = linear_t(cdt, &attn, wo); // [1, dim]
        h = h.add(&attn_out);
        // --- MLP sub-block (pre-norm + SwiGLU + residual) ---
        let normed = h.rms_norm(&gamma_t(cdt, ffn_norm), eps).0;
        let gate = linear_t(cdt, &normed, w_gate);
        let up = linear_t(cdt, &normed, w_up);
        let act = gate.silu().mul(&up); // swiglu = silu(gate) ∘ up
        let down = linear_t(cdt, &act, w_down);
        h = h.add(&down);
    }
    // Final norm + LM head, then bring the single logits row back to the host as f32.
    let h = h.rms_norm(&gamma_t(cdt, final_norm), eps).0;
    let logits = linear_t(cdt, &h, lm_head); // [1, vocab]
    logits
        .to_dtype(DType::F32)
        .to_device(Device::Cpu)
        .as_slice::<f32>()
        .to_vec()
}
/// Index of the largest value in `row`, used for greedy token selection.
///
/// Ties resolve to the LAST maximal index, matching `Iterator::max_by`. The greedy
/// token-identity gates rely on this.
///
/// # Panics
/// If `row` is empty, or if `row` has more than one element and contains a NaN.
fn argmax(row: &[f32]) -> usize {
    let mut best: Option<(usize, f32)> = None;
    for (idx, &val) in row.iter().enumerate() {
        let take_new = match best {
            None => true,
            // Replace unless the current best is strictly greater, so equal maxima
            // resolve to the later index.
            Some((_, best_val)) => {
                best_val.partial_cmp(&val).unwrap() != std::cmp::Ordering::Greater
            }
        };
        if take_new {
            best = Some((idx, val));
        }
    }
    best.unwrap().0
}
// ===================================================================
// M2b — batched KV-cache decode (G samples of one prompt, in lockstep)
// ===================================================================
/// Batched K/V cache: `G` sequences advancing together. Per layer, a device-resident
/// `[G·num_kv, T, head_dim]` grown one token per step via `cat_seq` (M2c — no host
/// round-trip). Same as M2a's device cache with a G dimension in `bh`.
struct BatchKVCache {
k: Vec<Option<Tensor>>,
v: Vec<Option<Tensor>>,
}
impl BatchKVCache {
fn new(n_layers: usize) -> Self {
Self {
k: (0..n_layers).map(|_| None).collect(),
v: (0..n_layers).map(|_| None).collect(),
}
}
fn append(&mut self, li: usize, k_bh: Tensor, v_bh: Tensor) {
self.k[li] = Some(match self.k[li].take() {
Some(c) => c.cat_seq(&k_bh),
None => k_bh,
});
self.v[li] = Some(match self.v[li].take() {
Some(c) => c.cat_seq(&v_bh),
None => v_bh,
});
}
}
/// Batched KV-cache decode: roll out `n_samples` (G) completions of the SAME
/// `prompt` in lockstep.
///
/// All G share the prompt, so they advance at one common decode position each step
/// (uniform RoPE via `rope_pos`). Returns G full token sequences (prompt + sampled
/// continuation). The G-way batching amortises the per-step kernel launches across
/// G (the rollout long pole). Each row is token-identical to G independent
/// single-sequence decodes (gated by `tests/decode_batch.rs`).
///
/// `temperature <= 0` is greedy, so all G rows are identical. `> 0` draws
/// independent samples: each step samples rows 0..G in order from one shared
/// `rng_state`.
///
/// There is no finished-mask: all G generate `max_new` tokens, and the caller cuts
/// each at `<|endoftext|>` (a perf-only early stop is the M2b+ follow-up). Ragged
/// (different-length) prompts are also deferred.
///
/// Runs `prompt.len() + max_new - 1` batched steps (none when `max_new == 0`). The
/// last sampled tokens are never fed back, since their logits would be discarded.
///
/// # Panics
/// If `prompt` is empty, `n_samples == 0`, or the model's parameter layout is not
/// `1 + n_layers * 11 + 2` (embed, 11 leaves per layer, final_norm, lm_head).
pub fn generate_cached_batch(
    model: &TinyTransformer,
    device: Device,
    prompt: &[i32],
    n_samples: usize,
    max_new: usize,
    temperature: f32,
    rng_state: &mut u64,
) -> Vec<Vec<i32>> {
    assert!(!prompt.is_empty(), "prompt must be non-empty");
    assert!(n_samples > 0, "n_samples must be > 0");
    let cfg = model.config();
    let cdt = model.compute_dtype();
    let n_layers = cfg.n_layers;
    let params: Vec<Tensor> = model.params().iter().map(|p| p.value()).collect();
    // Same layout guard as the single-sequence path: decode_step_batch indexes
    // params by the documented TinyTransformer::params order.
    assert_eq!(
        params.len(),
        1 + n_layers * 11 + 2,
        "unexpected param layout for decode"
    );
    let g = n_samples;

    let mut seqs: Vec<Vec<i32>> = (0..g)
        .map(|_| {
            let mut seq = Vec::with_capacity(prompt.len() + max_new);
            seq.extend_from_slice(prompt);
            seq
        })
        .collect();
    if max_new == 0 {
        // Nothing to generate: skip the prefill forward entirely.
        return seqs;
    }

    let embed = &params[0];
    let final_norm = &params[1 + n_layers * 11];
    let lm_head = &params[1 + n_layers * 11 + 1];
    let mut cache = BatchKVCache::new(n_layers);

    // Prefill: feed each prompt token (identical across G) at its position.
    let mut logits = Vec::new(); // [G, vocab] flattened, row-major
    let mut toks = vec![0i32; g];
    for (pos, &tok) in prompt.iter().enumerate() {
        toks.fill(tok);
        logits = decode_step_batch(&params, cfg, cdt, device, &mut cache, &toks, pos, embed, final_norm, lm_head);
    }

    let vocab = cfg.vocab;
    let mut next: Vec<i32> = Vec::with_capacity(g);
    for step in 0..max_new {
        next.clear();
        for (row, seq) in seqs.iter_mut().enumerate() {
            let lg = &logits[row * vocab..(row + 1) * vocab];
            let t = if temperature <= 0.0 {
                argmax(lg) as i32
            } else {
                sample_temperature(lg, temperature, rng_state) as i32
            };
            next.push(t);
            seq.push(t);
        }
        if step + 1 == max_new {
            // Final tokens: their logits would never be read, so skip the forward.
            break;
        }
        let pos = seqs[0].len() - 1; // all G are at the same position
        logits = decode_step_batch(&params, cfg, cdt, device, &mut cache, &next, pos, embed, final_norm, lm_head);
    }
    seqs
}
/// One batched decode step: `toks` is one current token per sequence (`[G]`), all at
/// absolute position `pos`. Appends each sequence's K/V and returns logits `[G·vocab]`.
///
/// The returned logits are row-major: sequence `r` occupies `r*vocab..(r+1)*vocab`.
/// The body is the batched twin of the single-sequence `decode_step`, with G folded
/// into the leading (`bh`) dimension of every per-head tensor. `params` must follow
/// the `TinyTransformer::params` layout (11 leaves per layer starting at index 1).
///
/// NOTE(review): as in the single-sequence path, nothing checks that `pos` matches
/// the number of tokens already cached.
#[allow(clippy::too_many_arguments)]
fn decode_step_batch(
    params: &[Tensor],
    cfg: &crate::Config,
    cdt: DType,
    device: Device,
    cache: &mut BatchKVCache,
    toks: &[i32],
    pos: usize,
    embed: &Tensor,
    final_norm: &Tensor,
    lm_head: &Tensor,
) -> Vec<f32> {
    let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
    let dim = cfg.dim;
    let g = toks.len();
    // Standard 1/sqrt(head_dim) SDPA scaling.
    let scale = 1.0 / (hd as f32).sqrt();
    let (theta, eps) = (cfg.rope_theta, cfg.eps);
    let n_layers = cfg.n_layers;
    // Uniform per-row position (all G at the same decode step). rope_pos takes a
    // per-row position tensor, so this is filled with `pos` for every row.
    let positions = Tensor::from_slice(&vec![pos as i32; g], &[g]).to_device(device);
    let ids = Tensor::from_slice(toks, &[g]).to_device(device);
    let mut h = embed.embedding(&ids); // [G, dim] f32
    if cdt == DType::BF16 {
        h = h.to_dtype(DType::BF16);
    }
    for li in 0..n_layers {
        // This layer's 11 leaves, in block_params order.
        let base = 1 + li * 11;
        let (attn_norm, wq, wk, wv) =
            (&params[base], &params[base + 1], &params[base + 2], &params[base + 3]);
        let (q_norm, k_norm, wo) = (&params[base + 4], &params[base + 5], &params[base + 6]);
        let (ffn_norm, w_gate, w_up, w_down) =
            (&params[base + 7], &params[base + 8], &params[base + 9], &params[base + 10]);
        // --- Attention sub-block (pre-norm + cached-KV attention + residual) ---
        let normed = h.rms_norm(&gamma_t(cdt, attn_norm), eps).0; // [G, dim]
        // Q: project → per-head QK-norm → RoPE at `pos` for every row.
        let q = linear_t(cdt, &normed, wq).reshape(&[g, nh, hd]);
        // Flatten to [G·nh, hd] so rms_norm normalises each head's vector independently.
        let q = q.reshape(&[g * nh, hd]).rms_norm(&gamma_t(cdt, q_norm), eps).0;
        let q = q.reshape(&[g, nh, hd]).rope_pos(&positions, theta);
        let q_bh = q.reshape(&[g * nh, 1, hd]); // bh = G·nh
        // K/V appended as [G·num_kv,1,hd] (bh-major) into the device cache (M2c).
        let k = linear_t(cdt, &normed, wk).reshape(&[g, num_kv, hd]);
        let k = k.reshape(&[g * num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
        let k_bh = k
            .reshape(&[g, num_kv, hd])
            .rope_pos(&positions, theta)
            .reshape(&[g * num_kv, 1, hd]);
        let v_bh = linear_t(cdt, &normed, wv).reshape(&[g * num_kv, 1, hd]);
        cache.append(li, k_bh, v_bh);
        // repeat_kv the cached [G·num_kv,T,hd] to [G·nh,T,hd] for the SDPA. This is
        // skipped when there is no GQA sharing (num_kv == nh). The second argument is
        // presumably the batch count G, so heads are repeated within each sequence;
        // TODO confirm against repeat_kv.
        let expand = |c: &Tensor| if num_kv == nh { c.clone() } else { c.repeat_kv(nh, g) };
        // Both slots were just populated by `append` above, so the unwraps cannot fail.
        let k_full = expand(cache.k[li].as_ref().unwrap());
        let v_full = expand(cache.v[li].as_ref().unwrap());
        let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [G·nh, hd]
        let attn = attn.reshape(&[g, dim]); // concat heads per sequence
        let attn_out = linear_t(cdt, &attn, wo);
        h = h.add(&attn_out);
        // --- MLP sub-block (pre-norm + SwiGLU + residual) ---
        let normed = h.rms_norm(&gamma_t(cdt, ffn_norm), eps).0;
        let gate = linear_t(cdt, &normed, w_gate);
        let up = linear_t(cdt, &normed, w_up);
        let act = gate.silu().mul(&up); // swiglu = silu(gate) ∘ up
        let down = linear_t(cdt, &act, w_down);
        h = h.add(&down);
    }
    // Final norm + LM head, then bring the [G, vocab] logits back to the host as f32.
    let h = h.rms_norm(&gamma_t(cdt, final_norm), eps).0;
    linear_t(cdt, &h, lm_head)
        .to_dtype(DType::F32)
        .to_device(Device::Cpu)
        .as_slice::<f32>()
        .to_vec()
}