Compare commits
3 Commits
1574e21d89
...
b39e6e7110
| Author | SHA1 | Date | |
|---|---|---|---|
| b39e6e7110 | |||
| eff26a0898 | |||
| c88e2ab88c |
@@ -139,6 +139,19 @@ unsafe extern "C" {
|
||||
period: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
// RoPE at an absolute position offset (KV-cache decode, forward only): row
|
||||
// `tok`'s position is `pos0 + tok` (no modulo). For a single decode token
|
||||
// (tokens == 1) the one row sits at absolute position `pos0`.
|
||||
pub fn launch_rope_at_f32(
|
||||
x: *const f32,
|
||||
y: *mut f32,
|
||||
tokens: i32,
|
||||
heads: i32,
|
||||
head_dim: i32,
|
||||
theta: f32,
|
||||
pos0: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_rope_dx_f32(
|
||||
dy: *const f32,
|
||||
dx: *mut f32,
|
||||
|
||||
225
crates/xtrain-model/src/decode.rs
Normal file
225
crates/xtrain-model/src/decode.rs
Normal file
@@ -0,0 +1,225 @@
|
||||
//! KV-cache incremental-decode engine (post-training M2a, single sequence).
|
||||
//!
|
||||
//! The naive sampler ([`crate::TinyTransformer`] via `train::sample::generate`)
|
||||
//! re-runs the full forward over the whole growing prefix every step — O(t²) and
|
||||
//! a fresh autograd graph per token. This is the inference engine that replaces it:
|
||||
//! a per-layer **K/V cache** + a **single-token incremental forward** that processes
|
||||
//! one new token at a time, attending to the cached keys/values.
|
||||
//!
|
||||
//! Built on three primitives, all gated by their own correctness tests:
|
||||
//! - [`Tensor::rope_at`](xtrain_tensor::Tensor::rope_at): RoPE at the token's
|
||||
//! absolute position (not row-in-tile), so cached post-RoPE K matches the full
|
||||
//! forward (bit-identical, `integration::rope_at_matches_full_rope_row`).
|
||||
//! - [`Tensor::decode_attention`](xtrain_tensor::Tensor::decode_attention): the
|
||||
//! single-query × cached-K/V SDPA, equal to the full causal attention's last row
|
||||
//! (`integration::decode_attention_matches_full_attention_last_row`).
|
||||
//! - this module's per-token block forward, mirroring `model::block_forward` at the
|
||||
//! raw-Tensor level (no autograd tape — inference needs no gradients).
|
||||
//!
|
||||
//! Correctness gate (the M2 centerpiece): KV-cache greedy decode is **token-
|
||||
//! identical** to the naive full-recompute greedy (`tests/decode_kv.rs`).
|
||||
//!
|
||||
//! Prefill is just the first `prompt.len()` decode steps (one token at a time) —
|
||||
//! one code path, at the cost of a non-batched prefill (M2b adds batched prefill +
|
||||
//! ragged batch decode). The cache is host-accumulated (token-major f32) and the
//! K/V tensor is rebuilt per step: each step downloads one token's `num_kv·head_dim`
//! floats per layer and re-uploads that layer's whole `[T, num_kv, head_dim]` cache.
//! This is the honest M2a baseline — M2b moves the cache device-side.
|
||||
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use crate::TinyTransformer;
|
||||
use xtrain_tensor::{DType, Device, Tensor};
|
||||
|
||||
/// Host-side key/value store with one pair of flat buffers per transformer layer.
///
/// `k[li]` and `v[li]` are token-major `f32` buffers laid out as
/// `[T, num_kv, head_dim]` (flattened); every decode step appends one token's
/// `num_kv·head_dim` values to each. Values are kept as f32 (an exact upcast of
/// the bf16 projection output) and converted back to the compute dtype when the
/// K/V tensor is formed, so bf16 values round-trip bit-for-bit.
struct KVCache {
    k: Vec<Vec<f32>>,
    v: Vec<Vec<f32>>,
}

impl KVCache {
    /// Build a cache holding `n_layers` empty key buffers and `n_layers` empty
    /// value buffers.
    fn new(n_layers: usize) -> Self {
        let empty_layers = || -> Vec<Vec<f32>> { (0..n_layers).map(|_| Vec::new()).collect() };
        KVCache {
            k: empty_layers(),
            v: empty_layers(),
        }
    }

    /// Push one token's key slab and value slab onto the end of layer `li`'s
    /// buffers. Panics if `li` is not a valid layer index.
    fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) {
        let (keys, values) = (&mut self.k[li], &mut self.v[li]);
        keys.extend_from_slice(k_tok);
        values.extend_from_slice(v_tok);
    }
}
|
||||
|
||||
/// Linear `x @ W` in the compute dtype — mirrors `model::linear` (bf16 casts the
|
||||
/// fp32-master weight to bf16 on the fly; the activation stream is already bf16).
|
||||
fn linear_t(cdt: DType, x: &Tensor, w: &Tensor) -> Tensor {
|
||||
match cdt {
|
||||
DType::F32 => x.matmul(w),
|
||||
DType::BF16 => x.matmul(&w.to_dtype(DType::BF16)),
|
||||
_ => unreachable!("compute dtype must be F32/BF16"),
|
||||
}
|
||||
}
|
||||
|
||||
/// A norm/QK-norm gamma in the compute dtype — mirrors `model::norm_gamma`.
|
||||
fn gamma_t(cdt: DType, g: &Tensor) -> Tensor {
|
||||
match cdt {
|
||||
DType::F32 => g.clone(),
|
||||
DType::BF16 => g.to_dtype(DType::BF16),
|
||||
_ => unreachable!("compute dtype must be F32/BF16"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Greedy KV-cache decode: continue `prompt` by `max_new` tokens, argmax each step.
|
||||
/// Returns the full token sequence (prompt + generated), matching the naive
|
||||
/// `sample::generate` interface for `temperature == 0`. Token-identical to the
|
||||
/// naive full-recompute greedy (gated by `tests/decode_kv.rs`).
|
||||
pub fn generate_greedy_cached(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
prompt: &[i32],
|
||||
max_new: usize,
|
||||
) -> Vec<i32> {
|
||||
assert!(!prompt.is_empty(), "prompt must be non-empty");
|
||||
let cfg = model.config();
|
||||
let cdt = model.compute_dtype();
|
||||
let n_layers = cfg.n_layers;
|
||||
|
||||
// params() is a stable, documented order (see TinyTransformer::params):
|
||||
// [0] = embed [vocab, dim]
|
||||
// [1 + li*11 .. +11] = layer li's 11 leaves, in block_params order:
|
||||
// attn_norm, wq, wk, wv, q_norm, k_norm, wo, ffn_norm, w_gate, w_up, w_down
|
||||
// [1 + n_layers*11] = final_norm [dim]
|
||||
// [1 + n_layers*11 + 1] = lm_head [dim, vocab]
|
||||
let params: Vec<Tensor> = model.params().iter().map(|p| p.value()).collect();
|
||||
assert_eq!(
|
||||
params.len(),
|
||||
1 + n_layers * 11 + 2,
|
||||
"unexpected param layout for decode"
|
||||
);
|
||||
let embed = ¶ms[0];
|
||||
let final_norm = ¶ms[1 + n_layers * 11];
|
||||
let lm_head = ¶ms[1 + n_layers * 11 + 1];
|
||||
|
||||
let mut cache = KVCache::new(n_layers);
|
||||
let mut tokens = prompt.to_vec();
|
||||
|
||||
// Prefill: feed each prompt token in order; the last step's logits are the
|
||||
// distribution for the first generated token.
|
||||
let mut logits = Vec::new();
|
||||
for (pos, &tok) in prompt.iter().enumerate() {
|
||||
logits = decode_step(¶ms, cfg, cdt, device, &mut cache, tok, pos, embed, final_norm, lm_head);
|
||||
}
|
||||
|
||||
for _ in 0..max_new {
|
||||
let next = argmax(&logits) as i32;
|
||||
tokens.push(next);
|
||||
let pos = tokens.len() - 1; // absolute position of the token just appended
|
||||
logits = decode_step(¶ms, cfg, cdt, device, &mut cache, next, pos, embed, final_norm, lm_head);
|
||||
}
|
||||
tokens
|
||||
}
|
||||
|
||||
/// One incremental decode step for token `tok` at absolute position `pos`: append
|
||||
/// its K/V to the cache and return the next-token logits as host f32 `[vocab]`.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn decode_step(
|
||||
params: &[Tensor],
|
||||
cfg: &crate::Config,
|
||||
cdt: DType,
|
||||
device: Device,
|
||||
cache: &mut KVCache,
|
||||
tok: i32,
|
||||
pos: usize,
|
||||
embed: &Tensor,
|
||||
final_norm: &Tensor,
|
||||
lm_head: &Tensor,
|
||||
) -> Vec<f32> {
|
||||
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
|
||||
let dim = cfg.dim;
|
||||
let kv_dim = num_kv * hd;
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
let (theta, eps) = (cfg.rope_theta, cfg.eps);
|
||||
let n_layers = cfg.n_layers;
|
||||
|
||||
// Embedding (fp32 table) → activation stream in the compute dtype.
|
||||
let ids = Tensor::from_slice(&[tok], &[1]).to_device(device);
|
||||
let mut h = embed.embedding(&ids); // [1, dim] f32
|
||||
if cdt == DType::BF16 {
|
||||
h = h.to_dtype(DType::BF16);
|
||||
}
|
||||
|
||||
for li in 0..n_layers {
|
||||
let base = 1 + li * 11;
|
||||
let (attn_norm, wq, wk, wv) =
|
||||
(¶ms[base], ¶ms[base + 1], ¶ms[base + 2], ¶ms[base + 3]);
|
||||
let (q_norm, k_norm, wo) = (¶ms[base + 4], ¶ms[base + 5], ¶ms[base + 6]);
|
||||
let (ffn_norm, w_gate, w_up, w_down) =
|
||||
(¶ms[base + 7], ¶ms[base + 8], ¶ms[base + 9], ¶ms[base + 10]);
|
||||
|
||||
// --- Attention sub-block (pre-norm + cached-KV attention + residual) ---
|
||||
let normed = h.rms_norm(&gamma_t(cdt, attn_norm), eps).0; // [1, dim]
|
||||
|
||||
// Q: project → per-head QK-norm → RoPE at absolute position `pos`.
|
||||
let q = linear_t(cdt, &normed, wq).reshape(&[1, nh, hd]); // [1, nh, hd]
|
||||
let q = q.reshape(&[nh, hd]).rms_norm(&gamma_t(cdt, q_norm), eps).0;
|
||||
let q = q.reshape(&[1, nh, hd]).rope_at(theta, pos);
|
||||
let q_bh = q.reshape(&[nh, 1, hd]); // seq=1 ⇒ the head-transpose is a no-op on data
|
||||
|
||||
// K: same as Q (QK-norm + RoPE); cache token-major. V: project only.
|
||||
let k = linear_t(cdt, &normed, wk).reshape(&[1, num_kv, hd]);
|
||||
let k = k.reshape(&[num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
|
||||
let k_tok = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos); // [1, num_kv, hd]
|
||||
let v_tok = linear_t(cdt, &normed, wv).reshape(&[1, num_kv, hd]);
|
||||
|
||||
let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
cache.append(li, &k_host, &v_host);
|
||||
|
||||
// Rebuild the full K/V for this layer: token-major [T,num_kv,hd] → [num_kv,T,hd]
|
||||
// → repeat_kv to [nh,T,hd].
|
||||
let t_len = cache.k[li].len() / kv_dim;
|
||||
let build = |flat: &[f32]| -> Tensor {
|
||||
let bh_kv = Tensor::from_slice(flat, &[t_len, num_kv, hd])
|
||||
.to_device(device)
|
||||
.transpose_3d01(); // [num_kv, T, hd], f32
|
||||
let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv };
|
||||
if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, 1) } // [nh, T, hd]
|
||||
};
|
||||
let k_full = build(&cache.k[li]);
|
||||
let v_full = build(&cache.v[li]);
|
||||
|
||||
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [nh, hd]
|
||||
let attn = attn.reshape(&[1, dim]); // concat heads (nh·hd == dim)
|
||||
let attn_out = linear_t(cdt, &attn, wo); // [1, dim]
|
||||
h = h.add(&attn_out);
|
||||
|
||||
// --- MLP sub-block (pre-norm + SwiGLU + residual) ---
|
||||
let normed = h.rms_norm(&gamma_t(cdt, ffn_norm), eps).0;
|
||||
let gate = linear_t(cdt, &normed, w_gate);
|
||||
let up = linear_t(cdt, &normed, w_up);
|
||||
let act = gate.silu().mul(&up); // swiglu = silu(gate) ∘ up
|
||||
let down = linear_t(cdt, &act, w_down);
|
||||
h = h.add(&down);
|
||||
}
|
||||
|
||||
let h = h.rms_norm(&gamma_t(cdt, final_norm), eps).0;
|
||||
let logits = linear_t(cdt, &h, lm_head); // [1, vocab]
|
||||
logits
|
||||
.to_dtype(DType::F32)
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.to_vec()
|
||||
}
|
||||
|
||||
/// Index of the largest value in `row`.
///
/// When several entries tie for the maximum, the index of the last one is
/// returned. Panics if `row` is empty, or if a comparison involves NaN.
fn argmax(row: &[f32]) -> usize {
    let (best_idx, _) = row
        .iter()
        .enumerate()
        .reduce(|lead, cand| {
            // Keep the current leader only when it is strictly greater, so that
            // equal values hand the lead to the later index.
            let ord = lead.1.partial_cmp(cand.1).unwrap();
            if ord == std::cmp::Ordering::Greater {
                lead
            } else {
                cand
            }
        })
        .unwrap();
    best_idx
}
|
||||
@@ -25,3 +25,8 @@ pub use config::Config;
|
||||
mod model;
|
||||
#[cfg(not(no_cuda))]
|
||||
pub use model::{TinyTransformer, batched_ids_tensor, ids_tensor, param_to_host};
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
pub mod decode;
|
||||
#[cfg(not(no_cuda))]
|
||||
pub use decode::generate_greedy_cached;
|
||||
|
||||
@@ -790,6 +790,38 @@ impl Tensor {
|
||||
out
|
||||
}
|
||||
|
||||
/// RoPE at an absolute position offset (KV-cache decode, forward only).
|
||||
/// `self`:[tokens,heads,head_dim]; row `r`'s position is `pos0 + r` (no
|
||||
/// modulo). For a single new decode token pass `tokens == 1` → the one row is
|
||||
/// rotated at absolute position `pos0`. Mirrors [`rope`](Self::rope)'s dtype
|
||||
/// handling (bf16 → f32 → bf16); no backward (inference path).
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn rope_at(&self, theta: f32, pos0: usize) -> Self {
|
||||
assert_eq!(self.ndim(), 3, "rope_at requires [tokens,heads,head_dim]");
|
||||
let (tokens, heads, head_dim) = (self.shape[0], self.shape[1], self.shape[2]);
|
||||
assert_eq!(head_dim % 2, 0, "head_dim must be even");
|
||||
if self.dtype == DType::BF16 {
|
||||
return self
|
||||
.to_dtype(DType::F32)
|
||||
.rope_at(theta, pos0)
|
||||
.to_dtype(DType::BF16);
|
||||
}
|
||||
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_rope_at_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
tokens as i32,
|
||||
heads as i32,
|
||||
head_dim as i32,
|
||||
theta,
|
||||
pos0 as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// RoPE backward: apply the inverse (transpose) rotation to `dy`. RoPE is an
|
||||
/// orthogonal map, so it needs no cached forward values, only `theta`/`period`.
|
||||
#[cfg(not(no_cuda))]
|
||||
@@ -1076,6 +1108,76 @@ impl Tensor {
|
||||
(out, probs)
|
||||
}
|
||||
|
||||
/// Decode-time (incremental) attention: a SINGLE query position against a
|
||||
/// cached K/V of length `t` (KV-cache decode, forward only). `self` = Q
|
||||
/// `[bh,1,head_dim]`; `k`,`v` = `[bh,t,head_dim]`, already repeat_kv-expanded
|
||||
/// to `bh` heads. Returns out `[bh,head_dim]` (= `[bh,1,head_dim]` flattened).
|
||||
///
|
||||
/// No causal mask is needed — the one query sits at the end, so every cached
|
||||
/// key (positions `0..t`) is visible. This is exactly the LAST query row of the
|
||||
/// full causal [`attention`](Self::attention), so KV-cache greedy decode is
|
||||
/// token-identical to full recompute. Softmax is computed in f32 (matching the
|
||||
/// causal path) with `scale` folded in before the exponentials.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn decode_attention(&self, k: &Tensor, v: &Tensor, scale: f32) -> Self {
|
||||
assert_eq!(self.ndim(), 3, "decode_attention Q must be [bh,1,head_dim]");
|
||||
assert_eq!(self.shape[1], 1, "decode_attention Q seq must be 1");
|
||||
assert_eq!(k.ndim(), 3, "decode_attention K must be [bh,t,head_dim]");
|
||||
assert_eq!(k.shape(), v.shape(), "K/V shape mismatch");
|
||||
assert_eq!(self.dtype, k.dtype, "Q/K dtype mismatch");
|
||||
assert_eq!(self.dtype, v.dtype, "Q/V dtype mismatch");
|
||||
let (bh, hd) = (self.shape[0], self.shape[2]);
|
||||
assert_eq!(k.shape[0], bh, "Q/K batch-head mismatch");
|
||||
assert_eq!(k.shape[2], hd, "Q/K head_dim mismatch");
|
||||
let t = k.shape[1]; // cached length
|
||||
let dt = self.dtype;
|
||||
let dev = self.device();
|
||||
|
||||
// scores[bh,1,t] = Q[bh,1,hd] · Kᵀ[bh,hd,t] (per-head batched GEMM).
|
||||
// [bh,1,t] is stored identically to [bh,t]; allocate 2D so the rowwise
|
||||
// softmax can run without a reshape.
|
||||
let scores = Tensor::zeros(&[bh, t], dt, dev);
|
||||
strided_batched_gemm(
|
||||
dt,
|
||||
false,
|
||||
true,
|
||||
1,
|
||||
t,
|
||||
hd,
|
||||
self.data_ptr(),
|
||||
hd,
|
||||
k.data_ptr(),
|
||||
t * hd,
|
||||
scores.data_ptr(),
|
||||
t,
|
||||
bh,
|
||||
);
|
||||
// probs = softmax(scale · scores) over the t keys (f32, like the causal path).
|
||||
let probs = scores
|
||||
.to_dtype(DType::F32)
|
||||
.scale(scale)
|
||||
.softmax()
|
||||
.to_dtype(dt);
|
||||
// out[bh,1,hd] = probs[bh,1,t] · V[bh,t,hd].
|
||||
let out = Tensor::zeros(&[bh, hd], dt, dev);
|
||||
strided_batched_gemm(
|
||||
dt,
|
||||
false,
|
||||
false,
|
||||
1,
|
||||
hd,
|
||||
t,
|
||||
probs.data_ptr(),
|
||||
t,
|
||||
v.data_ptr(),
|
||||
t * hd,
|
||||
out.data_ptr(),
|
||||
hd,
|
||||
bh,
|
||||
);
|
||||
out
|
||||
}
|
||||
|
||||
/// Backward of [`attention`](Self::attention). Inputs: forward `q`,`k`,`v`,
|
||||
/// the cached `probs`, the upstream `dout` (all batched `[bh,seq,*]`), and the
|
||||
/// same `scale`. Returns `(dq, dk, dv)`.
|
||||
|
||||
@@ -56,3 +56,106 @@ fn elementwise_scale_kernel() {
|
||||
r.len()
|
||||
);
|
||||
}
|
||||
|
||||
/// (c) `rope_at` (decode-time RoPE at an absolute position) must reproduce, bit
/// for bit, the matching row of the full-sequence `rope`. The KV cache stores
/// post-RoPE keys, so a single token rotated at position `t` has to equal row `t`
/// of the full forward for cached decode to stay token-identical.
#[test]
fn rope_at_matches_full_rope_row() {
    let n_devices = device::device_count().expect("device count");
    assert!(n_devices > 0, "no CUDA device");
    device::set_device(0).unwrap();

    let (n, heads, hd) = (7usize, 3usize, 8usize);
    let theta = 10000.0f32;
    let row_len = heads * hd;

    // Deterministic pseudo-random fill in [-1, 1).
    let host: Vec<f32> = (0..n * row_len)
        .map(|i| ((i * 37 % 101) as f32 / 50.0) - 1.0)
        .collect();

    // Reference: rope over the whole sequence (period = n → row r gets position r).
    let roped_full = Tensor::from_slice(&host, &[n, heads, hd])
        .to_device(Device::Cuda(0))
        .rope(theta, n)
        .to_device(Device::Cpu)
        .as_slice::<f32>()
        .to_vec();

    // Rotate every row on its own at its absolute position and compare exactly.
    for (t, row) in host.chunks_exact(row_len).enumerate() {
        let single = Tensor::from_slice(row, &[1, heads, hd]).to_device(Device::Cuda(0));
        let roped_row = single
            .rope_at(theta, t)
            .to_device(Device::Cpu)
            .as_slice::<f32>()
            .to_vec();
        let expect = &roped_full[t * row_len..(t + 1) * row_len];
        assert_eq!(
            roped_row.as_slice(),
            expect,
            "rope_at(pos0={t}) != full rope row {t}"
        );
    }
    println!("rope_at OK: bit-identical to full rope across {n} positions");
}
|
||||
|
||||
/// (d) `decode_attention` (one query against cached K/V, no mask) must agree with
/// the LAST query row of the full causal `attention`. The incremental path has to
/// reproduce what the full-recompute forward yields for the final position. The
/// comparison uses a floating-point tolerance (different softmax kernel and
/// reduction order), not bit equality.
#[test]
fn decode_attention_matches_full_attention_last_row() {
    let n_devices = device::device_count().expect("device count");
    assert!(n_devices > 0, "no CUDA device");
    device::set_device(0).unwrap();

    let (bh, t, hd) = (6usize, 5usize, 8usize);
    let scale = 1.0 / (hd as f32).sqrt();
    let n = bh * t * hd;

    // Deterministic pseudo-random ramps, one per operand.
    let ramp = |mul: usize, modulus: usize, div: f32| -> Vec<f32> {
        (0..n)
            .map(|i| ((i * mul % modulus) as f32 / div) - 1.0)
            .collect()
    };
    let qh = ramp(31, 97, 48.0);
    let kh = ramp(53, 89, 44.0);
    let vh = ramp(17, 83, 41.0);
    let on_gpu = |data: &[f32]| Tensor::from_slice(data, &[bh, t, hd]).to_device(Device::Cuda(0));
    let q = on_gpu(&qh);
    let k = on_gpu(&kh);
    let v = on_gpu(&vh);

    // Reference: full causal attention; each head's last query row is the target.
    let (full, _) = q.attention(&k, &v, scale);
    let full_h = full.to_device(Device::Cpu).as_slice::<f32>().to_vec();

    // Decode input: gather each head's last query row into Q_last [bh,1,hd].
    let ql: Vec<f32> = (0..bh)
        .flat_map(|b| {
            let src = (b * t + (t - 1)) * hd;
            qh[src..src + hd].iter().copied()
        })
        .collect();
    let q_last = Tensor::from_slice(&ql, &[bh, 1, hd]).to_device(Device::Cuda(0));
    let dec = q_last
        .decode_attention(&k, &v, scale)
        .to_device(Device::Cpu)
        .as_slice::<f32>()
        .to_vec();
    assert_eq!(dec.len(), bh * hd, "decode out shape");

    // Largest absolute deviation between the decode output and the reference row.
    let max_abs = (0..bh)
        .flat_map(|b| (0..hd).map(move |d| (b, d)))
        .fold(0f32, |worst, (b, d)| {
            let got = dec[b * hd + d];
            let exp = full_h[(b * t + (t - 1)) * hd + d];
            worst.max((got - exp).abs())
        });
    assert!(
        max_abs < 1e-4,
        "decode_attention vs full last-row max abs diff {max_abs} exceeds 1e-4"
    );
    println!("decode_attention OK: matches full causal last row (bh={bh}, t={t}, max|Δ|={max_abs:.2e})");
}
|
||||
|
||||
@@ -103,6 +103,10 @@ fn main() {
|
||||
let n_show = flag(&args, "--show", 8usize);
|
||||
let prompts_file = flag_value(&args, "--prompts-file").expect("--prompts-file is required");
|
||||
let gold_file = flag_value(&args, "--gold-file").expect("--gold-file is required");
|
||||
// M2: decode through the KV-cache incremental engine instead of the naive
|
||||
// full-recompute sampler. Token-identical to the naive path (gated by
|
||||
// tests/decode_kv.rs); this flag also lets us A/B the two for the speedup.
|
||||
let use_cached = args.iter().any(|a| a == "--cached");
|
||||
|
||||
// Prompts: skip the `#` header / blank lines and decode escaped newlines so the
|
||||
// count and order line up with the gold file.
|
||||
@@ -148,18 +152,26 @@ fn main() {
|
||||
xtrain_train::checkpoint::load_into(&ckpt, &model.params()).expect("load checkpoint");
|
||||
|
||||
println!(
|
||||
"eval_arith: ckpt {} | {} prompts | max_new {}",
|
||||
"eval_arith: ckpt {} | {} prompts | max_new {} | decode={}",
|
||||
ckpt.display(),
|
||||
prompts.len(),
|
||||
max_new
|
||||
max_new,
|
||||
if use_cached { "kv-cache" } else { "naive" }
|
||||
);
|
||||
|
||||
let (mut n_boxed, mut n_correct) = (0usize, 0usize);
|
||||
let mut shown = 0usize;
|
||||
let mut gen_tokens = 0usize;
|
||||
let t0 = std::time::Instant::now();
|
||||
for (prompt, &gold) in prompts.iter().zip(&golds) {
|
||||
let ids: Vec<i32> = tok.encode(prompt).into_iter().map(|t| t as i32).collect();
|
||||
let mut rng = 7u64;
|
||||
let out = generate(&model, device, &ids, max_new, 0.0, &mut rng);
|
||||
let out = if use_cached {
|
||||
xtrain_model::generate_greedy_cached(&model, device, &ids, max_new)
|
||||
} else {
|
||||
let mut rng = 7u64;
|
||||
generate(&model, device, &ids, max_new, 0.0, &mut rng)
|
||||
};
|
||||
gen_tokens += out.len() - ids.len();
|
||||
let cont = tok.decode(&out[ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
|
||||
let seg = first_answer_segment(&cont);
|
||||
if parse_boxed_answer(seg).is_some() {
|
||||
@@ -176,6 +188,7 @@ fn main() {
|
||||
}
|
||||
}
|
||||
|
||||
let elapsed = t0.elapsed().as_secs_f64();
|
||||
let n = prompts.len() as f64;
|
||||
println!(
|
||||
"RESULT format(boxed)={}/{} ({:.1}%) | correct={}/{} ({:.1}%)",
|
||||
@@ -186,4 +199,11 @@ fn main() {
|
||||
prompts.len(),
|
||||
100.0 * n_correct as f64 / n,
|
||||
);
|
||||
println!(
|
||||
"TIMING decode={} | {:.2}s | {} gen tokens | {:.1} tok/s",
|
||||
if use_cached { "kv-cache" } else { "naive" },
|
||||
elapsed,
|
||||
gen_tokens,
|
||||
gen_tokens as f64 / elapsed,
|
||||
);
|
||||
}
|
||||
|
||||
94
crates/xtrain-train/tests/decode_kv.rs
Normal file
94
crates/xtrain-train/tests/decode_kv.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
// M2a KV-cache decode engine — the token-identical correctness gate.
|
||||
//
|
||||
// The centerpiece M2 invariant: greedy decode through the KV-cache incremental
|
||||
// engine (`xtrain_model::generate_greedy_cached`) must be TOKEN-IDENTICAL to the
|
||||
// naive full-recompute greedy (`xtrain_train::sample::generate` at temperature 0),
|
||||
// which re-runs the whole forward over the growing prefix each step. Same tokens ⇒
|
||||
// the cache + decode-time attention + RoPE-at-position reproduce the full forward.
|
||||
//
|
||||
// Numerics note: a randomly-initialised model has near-uniform logits, so argmax
|
||||
// can be fragile to ~1e-6 differences. This unit gate therefore runs in F32 (the
|
||||
// tightest path, and the dtype the eval harness actually uses) on a small model.
|
||||
// The headline gate on the trained v12 checkpoint (peaked logits → robust argmax)
|
||||
// is run on the GPU box and recorded in docs/18.
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, generate_greedy_cached};
|
||||
use xtrain_tensor::{DType, Device};
|
||||
|
||||
/// Produce `n` deterministic pseudo-random values derived from `seed`, spread
/// around zero and scaled by `scale`.
///
/// The state is a 64-bit linear congruential generator; each output takes the
/// state's upper bits, maps them to roughly `[0, 1]`, recentres to `[-1, 1]` and
/// multiplies by `scale`.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    let mut state = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        let unit = (state >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
|
||||
|
||||
fn build(cfg: Config, device: Device, dtype: DType) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
let m = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
});
|
||||
m.with_compute_dtype(dtype)
|
||||
}
|
||||
|
||||
// A real GQA config (8 query / 2 kv heads → group 4) to exercise repeat_kv in the
|
||||
// decode path; head_dim 16, dim 128, 4 layers.
|
||||
fn gqa_cfg() -> Config {
|
||||
Config::from_arch(48, 8, 16, 4, 256).with_kv_heads(2)
|
||||
}
|
||||
|
||||
// Gate: greedy decode through the KV-cache engine must yield exactly the same
// token sequence as the naive full-recompute sampler at temperature 0 (F32, GQA).
#[test]
fn kv_cache_decode_is_token_identical_to_naive_f32() {
    let n_devices = device::device_count().expect("device count");
    assert!(n_devices > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let model = build(gqa_cfg(), device, DType::F32);
    let prompt: Vec<i32> = vec![1, 5, 9, 13, 2, 7];
    let max_new = 24usize;

    // Run both decoders on the same model and prompt.
    let mut rng = 7u64;
    let naive = xtrain_train::sample::generate(&model, device, &prompt, max_new, 0.0, &mut rng);
    let cached = generate_greedy_cached(&model, device, &prompt, max_new);

    assert_eq!(
        naive.len(),
        cached.len(),
        "length mismatch: naive {} vs cached {}",
        naive.len(),
        cached.len()
    );
    // Lengths are equal here, so the sequences differ exactly when some position
    // differs; report the first such position for debugging.
    let first_divergence = naive.iter().zip(&cached).position(|(a, b)| a != b);
    if let Some(first) = first_divergence {
        panic!(
            "token divergence at index {first}: naive={:?} cached={:?}\nnaive ={naive:?}\ncached ={cached:?}",
            naive[first], cached[first]
        );
    }
    println!(
        "KV-cache decode token-identical to naive over {} generated tokens (F32, GQA 8/2)",
        max_new
    );
}
|
||||
@@ -242,6 +242,33 @@ void launch_rope_f32(const float* x, float* y, int tokens, int heads,
|
||||
rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, period);
|
||||
}
|
||||
|
||||
// Decode-time RoPE kernel (forward only). One block per (token row, head), one
// thread per rotated pair: thread i rotates elements i and i + head_dim/2 of its
// head (rotate_half layout). Unlike rope_k, whose position is tok % period, the
// position here is the absolute value pos0 + tok, so a single new decode token
// (tok == 0) is rotated at position pos0. rope_k itself is not modified.
__global__ void rope_at_k(const float* x, float* y, int heads, int head_dim,
                          float theta, int pos0) {
    const int tok = blockIdx.x;
    const int head = blockIdx.y;
    const int half = head_dim / 2;
    const int i = threadIdx.x;
    if (i < half) {
        const int pos = pos0 + tok;
        const float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
        const float angle = (float)pos * freq;
        const float c = cosf(angle), sn = sinf(angle);
        const int base = (tok * heads + head) * head_dim;
        const float x0 = x[base + i], x1 = x[base + i + half];
        y[base + i] = x0 * c - x1 * sn;
        y[base + i + half] = x1 * c + x0 * sn;
    }
}
|
||||
// Host launcher for rope_at_k: a (tokens x heads) grid with head_dim / 2 threads
// per block, enqueued on stream `s` (passed as an opaque pointer).
void launch_rope_at_f32(const float* x, float* y, int tokens, int heads,
                        int head_dim, float theta, int pos0, void* s) {
    const dim3 blocks(tokens, heads);
    const int threads = head_dim / 2;
    cudaStream_t stream = (cudaStream_t)s;
    rope_at_k<<<blocks, threads, 0, stream>>>(x, y, heads, head_dim, theta, pos0);
}
|
||||
|
||||
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
|
||||
float theta, int period) {
|
||||
int tok = blockIdx.x;
|
||||
|
||||
@@ -360,3 +360,53 @@ gap is exactly what the verifiable reward in M3 (DPO) / M4 (GRPO) is built to cl
|
||||
held-out correct > 0 confirms the checker + eval harness score real matches (not just format).
|
||||
M1 delivers the format floor + the reusable task spec / checker / eval harness — not arithmetic
|
||||
skill, which is downstream by design.
|
||||
|
||||
### M2a — KV-cache incremental-decode engine (single sequence, landed)
|
||||
|
||||
The decode engine (D3, built up front) that replaces the naive sampler — which re-runs the
|
||||
full forward over the growing prefix every step (O(t²), a fresh autograd graph per token). Two
|
||||
forward-only primitives + a raw-Tensor per-token block forward, each gated in isolation.
|
||||
|
||||
**Primitives (`xtrain-tensor`, both forward-only):**
|
||||
- `Tensor::rope_at(theta, pos0)` — RoPE at a token's *absolute* position (`pos = pos0 + row`,
|
||||
no modulo), vs the training `rope` (`pos = row % period`) which is left untouched (new CUDA
|
||||
kernel `rope_at_k` → no training-path risk). Cached K is stored post-RoPE, so it must match
|
||||
what the full forward produced at that position. **Gate:** bit-identical to the full-sequence
|
||||
rope's row `t` (`integration::rope_at_matches_full_rope_row`).
|
||||
- `Tensor::decode_attention(k, v, scale)` — single-query × cached-K/V SDPA (`[bh,1,hd]` vs
|
||||
`[bh,t,hd]`, no causal mask: the one query sees all cached keys). Composed from the existing
|
||||
strided batched GEMM + plain softmax — **no new kernel**. **Gate:** equals the full causal
|
||||
attention's last query row, max |Δ| 6e-8 (`integration::decode_attention_matches_…`).
|
||||
|
||||
**Engine (`xtrain-model/src/decode.rs`, `generate_greedy_cached`):** per-layer K/V cache +
|
||||
single-token incremental forward. Prefill = the first `prompt.len()` decode steps (one code
|
||||
path). Mirrors `model::block_forward` at the raw-Tensor level (no autograd tape — inference
|
||||
needs no grads), pulling weights via the public `params()` stable order (no model-internal
|
||||
visibility changes). The cache is host-accumulated token-major f32, rebuilt per step — the
|
||||
honest M2a baseline; M2b moves it device-side + adds batched ragged decode.
|
||||
|
||||
**Gate (the M2 centerpiece — token-identical):** KV-cache greedy decode is byte-for-byte the
|
||||
same token sequence as the naive full-recompute greedy. Verified two ways:
|
||||
- `xtrain-train/tests/decode_kv.rs` — small GQA model (8 query / 2 kv heads), F32, 24 generated
|
||||
tokens, exact token-equality. (Unit gate runs F32: a random model's near-uniform logits make
|
||||
argmax fragile to ~1e-6, so the tightest path is used; the trained model below has peaked
|
||||
logits → robust.)
|
||||
- v12 1.05B SFT checkpoint: `eval_arith --cached` produces the **identical** eval outcome to the
|
||||
naive run (format 100/100, correct 8/100) and byte-identical completions.
|
||||
|
||||
**Throughput baseline (v12 1.05B, batch 1, F32, profile-first — measured, not assumed):** the
|
||||
cache win is **sequence-length-dependent**, which is the honest systems finding here:
|
||||
|
||||
| max_new | naive | kv-cache | note |
|
||||
|---------|-------|----------|------|
|
||||
| 32 | 108 tok/s | 111 tok/s | ~1.0× — both **launch/overhead-bound** at short seq |
|
||||
| 128 | 69 tok/s | **133 tok/s** | **~1.9×** — naive's O(t²) recompute starts to bite |
|
||||
| 256 | **OOM** | 129 tok/s | naive rebuilds the O(seq²) graph every step → OOM |
|
||||
|
||||
Cached throughput stays ~constant (O(1)/token compute + constant memory); naive **decays**
|
||||
(108→69 tok/s, O(t)/token) and eventually **OOMs** (the full autograd graph per step). So at the
|
||||
short arithmetic-eval lengths the cache is overhead-bound and gives ~nothing — it matters for
|
||||
**long rollouts** (DPO pair-generation, GRPO completions), exactly where M3/M4 use it. (M2a's
|
||||
per-layer host round-trip is part of why short-seq is overhead-bound; M2b's device-side cache
|
||||
targets it.) This is the same measure-first lesson as T17 (process-per-GPU throughput-neutral):
|
||||
the win is real but only in the regime that actually stresses the bottleneck.
|
||||
|
||||
@@ -97,6 +97,8 @@ Phase 1/2 把**预训练全栈**学完后,Phase 3 转向**后训练 infra**(
|
||||
|
||||
**M1(SFT task baseline,已落地)**:可验证算术任务 + 数据生成器 + 评分器一套,host-side 9/9 单测过(masking、SFT-target 自洽 2000 样、parser 边界、种子确定性)。dash5 单卡从 v12 基座 SFT(loss 4.68→~0.34,best val 0.386)。**100 留出题 eval:格式 `\boxed{}` 习得率 base 0% → SFT 100%;算术正确率 8%。**——SFT 只买**格式**(0%→100% 干净落地),算术正确性是 base 模型本身弱项(如 `46*80` 框成 3380),正是 M3/M4 的可验证 reward 要去补的残差。一条诚实账:M1 用的是**朴素无 KV-cache 采样器**(每 token 全量 forward),100 题已经很慢——这正是 M2 解码引擎前置的动机。
|
||||
|
||||
**M2a(KV-cache 增量解码引擎,单序列,已落地)**:两个 forward-only 原语 + 裸 Tensor 逐 token block forward,各自隔离闸门。`rope_at`(绝对位置 RoPE,新 kernel,不动训练 `rope` → 训练路径零风险)逐位等于全序列 rope 的对应行;`decode_attention`(单 query × cached-K/V,由现成 strided-gemm + 普通 softmax 组合,**零新 kernel**)等于全 causal attention 末行(max|Δ| 6e-8)。引擎 `generate_greedy_cached` 镜像 `block_forward` 在 Tensor 层(无 autograd tape,推理不需梯度),靠**公开 `params()` 稳定顺序**拿权重(零 model 可见性改动)。**核心闸门 = token-identical**:与朴素全重算贪心逐 token 一致(小 GQA 单测 + v12 1.05B 上 cached eval 与 naive **逐字节相同**:format 100/100, correct 8/100)。**吞吐 baseline(v12, batch1, F32,profile-first 实测)= cache 收益随序列长度而定**:max_new 32 ≈ 持平(108 vs 111,短序列 launch 开销 bound)、128 **~1.9×**(69 vs 133)、256 naive **OOM** vs cached 129 tok/s。cached 吞吐**近恒定**(O(1)/token + 恒定显存),naive **衰减**(O(t)/token,O(seq²) 图 → OOM)。⇒ 短 eval prompt overhead-bound、cache 几乎无收益,真正受益的是**长 rollout**(DPO 造对 / GRPO completion)——与 T17(process-per-GPU 吞吐中性)同一条 measure-first 教训:收益真实,但只在真正压到瓶颈的 regime 里。M2a 的 per-layer 主机往返是短序列 overhead-bound 的一部分原因,M2b(device 端 cache + 批量 ragged)针对它。
|
||||
|
||||
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md))
|
||||
|
||||
- **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。
|
||||
|
||||
Reference in New Issue
Block a user