post-train: M2c — device-side KV cache (cat_seq), profile-first bottleneck shift

Device-resident KV cache: keep K/V on the GPU as [bh,T,hd], grow by one token
per step via a new cat_seq kernel (concat along seq) — removes the M2a/M2b
per-layer host round-trip (to_cpu/from_slice/re-upload) AND the transpose_3d01.
Both single-seq and batched decode refactored to it; cache is Option<Tensor>
per layer (cleaner than the host Vec + rebuild).

Gates all hold: cat_seq == host concat; decode_kv single-seq + decode_batch
G-way both still TOKEN-IDENTICAL; GQA training path unaffected.

Honest measurement (the point): removing the host round-trip buys ~10% on pure
single-seq decode (133 → 147 tok/s @128) but does NOT move the GRPO step
(~8.5 s/step unchanged) — because after M2b batching the rollout is no longer
the step's bottleneck; the per-sample per_token_logp captures + the PG-update
forwards/backwards (model.forward, full-seq) now dominate. Measure-first lesson
(cf. T11/T17/M2a): the long pole shifted to the training-side forwards; the next
decode lever (ragged batched prefill) targets those, not the cache.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-30 17:38:16 +08:00
parent 0f76c0fdb0
commit 3a3425960c
5 changed files with 142 additions and 56 deletions

View File

@@ -36,22 +36,29 @@ use xtrain_tensor::{DType, Device, Tensor};
/// bf16 projection output); rebuilt to the compute dtype when forming the K/V
/// tensor, so bf16 values round-trip bit-for-bit.
struct KVCache {
k: Vec<Vec<f32>>,
v: Vec<Vec<f32>>,
k: Vec<Option<Tensor>>,
v: Vec<Option<Tensor>>,
}
impl KVCache {
fn new(n_layers: usize) -> Self {
Self {
k: vec![Vec::new(); n_layers],
v: vec![Vec::new(); n_layers],
k: (0..n_layers).map(|_| None).collect(),
v: (0..n_layers).map(|_| None).collect(),
}
}
/// Append one token's K/V slab (each `num_kv·head_dim` f32) to layer `li`.
fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) {
self.k[li].extend_from_slice(k_tok);
self.v[li].extend_from_slice(v_tok);
/// Append one token's K/V (`[bh,1,hd]`, compute dtype) to layer `li`, growing the
/// device-resident `[bh,T,hd]` cache via `cat_seq` (no host round-trip, M2c).
fn append(&mut self, li: usize, k_bh: Tensor, v_bh: Tensor) {
self.k[li] = Some(match self.k[li].take() {
Some(c) => c.cat_seq(&k_bh),
None => k_bh,
});
self.v[li] = Some(match self.v[li].take() {
Some(c) => c.cat_seq(&v_bh),
None => v_bh,
});
}
}
@@ -183,7 +190,6 @@ fn decode_step(
) -> Vec<f32> {
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
let dim = cfg.dim;
let kv_dim = num_kv * hd;
let scale = 1.0 / (hd as f32).sqrt();
let (theta, eps) = (cfg.rope_theta, cfg.eps);
let n_layers = cfg.n_layers;
@@ -212,28 +218,18 @@ fn decode_step(
let q = q.reshape(&[1, nh, hd]).rope_at(theta, pos);
let q_bh = q.reshape(&[nh, 1, hd]); // seq=1 ⇒ the head-transpose is a no-op on data
// K: same as Q (QK-norm + RoPE); cache token-major. V: project only.
// K: same as Q (QK-norm + RoPE). V: project only. Append each as [num_kv,1,hd]
// (bh-major) into the device cache; no host round-trip, no transpose (M2c).
let k = linear_t(cdt, &normed, wk).reshape(&[1, num_kv, hd]);
let k = k.reshape(&[num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
let k_tok = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos); // [1, num_kv, hd]
let v_tok = linear_t(cdt, &normed, wv).reshape(&[1, num_kv, hd]);
let k_bh = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos).reshape(&[num_kv, 1, hd]);
let v_bh = linear_t(cdt, &normed, wv).reshape(&[num_kv, 1, hd]);
cache.append(li, k_bh, v_bh);
let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
cache.append(li, &k_host, &v_host);
// Rebuild the full K/V for this layer: token-major [T,num_kv,hd] → [num_kv,T,hd]
// → repeat_kv to [nh,T,hd].
let t_len = cache.k[li].len() / kv_dim;
let build = |flat: &[f32]| -> Tensor {
let bh_kv = Tensor::from_slice(flat, &[t_len, num_kv, hd])
.to_device(device)
.transpose_3d01(); // [num_kv, T, hd], f32
let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv };
if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, 1) } // [nh, T, hd]
};
let k_full = build(&cache.k[li]);
let v_full = build(&cache.v[li]);
// repeat_kv the cached [num_kv,T,hd] to [nh,T,hd] for the SDPA.
let expand = |c: &Tensor| if num_kv == nh { c.clone() } else { c.repeat_kv(nh, 1) };
let k_full = expand(cache.k[li].as_ref().unwrap());
let v_full = expand(cache.v[li].as_ref().unwrap());
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [nh, hd]
let attn = attn.reshape(&[1, dim]); // concat heads (nh·hd == dim)
@@ -270,24 +266,30 @@ fn argmax(row: &[f32]) -> usize {
// M2b — batched KV-cache decode (G samples of one prompt, in lockstep)
// ===================================================================
/// Batched K/V cache: `G` sequences advancing together. Per layer, host-accumulates
/// seq-major `[T, G·num_kv, head_dim]` (one step appends `G·num_kv·hd` f32), rebuilt
/// to `[G·num_kv, T, hd]` per step. Same host-cache shape as M2a with a G dimension.
/// Batched K/V cache: `G` sequences advancing together. Per layer, a device-resident
/// `[G·num_kv, T, head_dim]` grown one token per step via `cat_seq` (M2c — no host
/// round-trip). Same as M2a's device cache with a G dimension in `bh`.
struct BatchKVCache {
k: Vec<Vec<f32>>,
v: Vec<Vec<f32>>,
k: Vec<Option<Tensor>>,
v: Vec<Option<Tensor>>,
}
impl BatchKVCache {
fn new(n_layers: usize) -> Self {
Self {
k: vec![Vec::new(); n_layers],
v: vec![Vec::new(); n_layers],
k: (0..n_layers).map(|_| None).collect(),
v: (0..n_layers).map(|_| None).collect(),
}
}
fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) {
self.k[li].extend_from_slice(k_tok);
self.v[li].extend_from_slice(v_tok);
fn append(&mut self, li: usize, k_bh: Tensor, v_bh: Tensor) {
self.k[li] = Some(match self.k[li].take() {
Some(c) => c.cat_seq(&k_bh),
None => k_bh,
});
self.v[li] = Some(match self.v[li].take() {
Some(c) => c.cat_seq(&v_bh),
None => v_bh,
});
}
}
@@ -369,7 +371,6 @@ fn decode_step_batch(
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
let dim = cfg.dim;
let g = toks.len();
let g_kv = g * num_kv; // batch·num_kv heads in the cache
let scale = 1.0 / (hd as f32).sqrt();
let (theta, eps) = (cfg.rope_theta, cfg.eps);
let n_layers = cfg.n_layers;
@@ -398,26 +399,20 @@ fn decode_step_batch(
let q = q.reshape(&[g, nh, hd]).rope_pos(&positions, theta);
let q_bh = q.reshape(&[g * nh, 1, hd]); // bh = G·nh
// K/V appended as [G·num_kv,1,hd] (bh-major) into the device cache (M2c).
let k = linear_t(cdt, &normed, wk).reshape(&[g, num_kv, hd]);
let k = k.reshape(&[g * num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
let k_tok = k.reshape(&[g, num_kv, hd]).rope_pos(&positions, theta);
let v_tok = linear_t(cdt, &normed, wv).reshape(&[g, num_kv, hd]);
let k_bh = k
.reshape(&[g, num_kv, hd])
.rope_pos(&positions, theta)
.reshape(&[g * num_kv, 1, hd]);
let v_bh = linear_t(cdt, &normed, wv).reshape(&[g * num_kv, 1, hd]);
cache.append(li, k_bh, v_bh);
let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
cache.append(li, &k_host, &v_host);
// Rebuild [T, G·num_kv, hd] → [G·num_kv, T, hd] → repeat_kv to [G·nh, T, hd].
let t_len = cache.k[li].len() / (g_kv * hd);
let build = |flat: &[f32]| -> Tensor {
let bh_kv = Tensor::from_slice(flat, &[t_len, g_kv, hd])
.to_device(device)
.transpose_3d01(); // [G·num_kv, T, hd], f32
let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv };
if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, g) } // [G·nh, T, hd]
};
let k_full = build(&cache.k[li]);
let v_full = build(&cache.v[li]);
// repeat_kv the cached [G·num_kv,T,hd] to [G·nh,T,hd] for the SDPA.
let expand = |c: &Tensor| if num_kv == nh { c.clone() } else { c.repeat_kv(nh, g) };
let k_full = expand(cache.k[li].as_ref().unwrap());
let v_full = expand(cache.v[li].as_ref().unwrap());
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [G·nh, hd]
let attn = attn.reshape(&[g, dim]); // concat heads per sequence