post-train: M2c — device-side KV cache (cat_seq), profile-first bottleneck shift

Device-resident KV cache: keep K/V on the GPU as [bh,T,hd], grow by one token
per step via a new cat_seq kernel (concat along seq) — removes the M2a/M2b
per-layer host round-trip (to_cpu/from_slice/re-upload) AND the transpose_3d01.
Both single-seq and batched decode refactored to it; cache is Option<Tensor>
per layer (cleaner than the host Vec + rebuild).

Gates all hold: cat_seq == host concat; decode_kv single-seq + decode_batch
G-way both still TOKEN-IDENTICAL; GQA training path unaffected.

Honest measurement (the point): removing the host round-trip buys ~10% on pure
single-seq decode (133 → 147 tok/s @128) but does NOT move the GRPO step
(~8.5 s/step unchanged) — because after M2b batching the rollout is no longer
the step's bottleneck; the per-sample per_token_logp captures + the PG-update
forwards/backwards (model.forward, full-seq) now dominate. Measure-first lesson
(cf. T11/T17/M2a): the long pole shifted to the training-side forwards; the next
decode lever (ragged batched prefill) targets those, not the cache.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-30 17:38:16 +08:00
parent 0f76c0fdb0
commit 3a3425960c
5 changed files with 142 additions and 56 deletions

View File

@@ -164,6 +164,17 @@ unsafe extern "C" {
theta: f32,
s: CudaStream,
);
// Concatenate along the sequence dim: a:[bh,ta,hd], b:[bh,tb,hd] →
// out:[bh,ta+tb,hd] (device-side KV-cache append, M2c).
pub fn launch_cat_seq_f32(
a: *const f32,
b: *const f32,
out: *mut f32,
bh: i32,
ta_hd: i32,
tb_hd: i32,
s: CudaStream,
);
// Per-row scale: y[r,c] = x[r,c] * s[r] (GRPO policy-gradient backward).
pub fn launch_scale_rows_f32(
x: *const f32,

View File

@@ -36,22 +36,29 @@ use xtrain_tensor::{DType, Device, Tensor};
/// bf16 projection output); rebuilt to the compute dtype when forming the K/V
/// tensor, so bf16 values round-trip bit-for-bit.
struct KVCache {
k: Vec<Vec<f32>>,
v: Vec<Vec<f32>>,
k: Vec<Option<Tensor>>,
v: Vec<Option<Tensor>>,
}
impl KVCache {
fn new(n_layers: usize) -> Self {
Self {
k: vec![Vec::new(); n_layers],
v: vec![Vec::new(); n_layers],
k: (0..n_layers).map(|_| None).collect(),
v: (0..n_layers).map(|_| None).collect(),
}
}
/// Append one token's K/V slab (each `num_kv·head_dim` f32) to layer `li`.
fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) {
self.k[li].extend_from_slice(k_tok);
self.v[li].extend_from_slice(v_tok);
/// Append one token's K/V (`[bh,1,hd]`, compute dtype) to layer `li`, growing the
/// device-resident `[bh,T,hd]` cache via `cat_seq` (no host round-trip, M2c).
fn append(&mut self, li: usize, k_bh: Tensor, v_bh: Tensor) {
self.k[li] = Some(match self.k[li].take() {
Some(c) => c.cat_seq(&k_bh),
None => k_bh,
});
self.v[li] = Some(match self.v[li].take() {
Some(c) => c.cat_seq(&v_bh),
None => v_bh,
});
}
}
@@ -183,7 +190,6 @@ fn decode_step(
) -> Vec<f32> {
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
let dim = cfg.dim;
let kv_dim = num_kv * hd;
let scale = 1.0 / (hd as f32).sqrt();
let (theta, eps) = (cfg.rope_theta, cfg.eps);
let n_layers = cfg.n_layers;
@@ -212,28 +218,18 @@ fn decode_step(
let q = q.reshape(&[1, nh, hd]).rope_at(theta, pos);
let q_bh = q.reshape(&[nh, 1, hd]); // seq=1 ⇒ the head-transpose is a no-op on data
// K: same as Q (QK-norm + RoPE); cache token-major. V: project only.
// K: same as Q (QK-norm + RoPE). V: project only. Append each as [num_kv,1,hd]
// (bh-major) into the device cache; no host round-trip, no transpose (M2c).
let k = linear_t(cdt, &normed, wk).reshape(&[1, num_kv, hd]);
let k = k.reshape(&[num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
let k_tok = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos); // [1, num_kv, hd]
let v_tok = linear_t(cdt, &normed, wv).reshape(&[1, num_kv, hd]);
let k_bh = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos).reshape(&[num_kv, 1, hd]);
let v_bh = linear_t(cdt, &normed, wv).reshape(&[num_kv, 1, hd]);
cache.append(li, k_bh, v_bh);
let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
cache.append(li, &k_host, &v_host);
// Rebuild the full K/V for this layer: token-major [T,num_kv,hd] → [num_kv,T,hd]
// → repeat_kv to [nh,T,hd].
let t_len = cache.k[li].len() / kv_dim;
let build = |flat: &[f32]| -> Tensor {
let bh_kv = Tensor::from_slice(flat, &[t_len, num_kv, hd])
.to_device(device)
.transpose_3d01(); // [num_kv, T, hd], f32
let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv };
if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, 1) } // [nh, T, hd]
};
let k_full = build(&cache.k[li]);
let v_full = build(&cache.v[li]);
// repeat_kv the cached [num_kv,T,hd] to [nh,T,hd] for the SDPA.
let expand = |c: &Tensor| if num_kv == nh { c.clone() } else { c.repeat_kv(nh, 1) };
let k_full = expand(cache.k[li].as_ref().unwrap());
let v_full = expand(cache.v[li].as_ref().unwrap());
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [nh, hd]
let attn = attn.reshape(&[1, dim]); // concat heads (nh·hd == dim)
@@ -270,24 +266,30 @@ fn argmax(row: &[f32]) -> usize {
// M2b — batched KV-cache decode (G samples of one prompt, in lockstep)
// ===================================================================
/// Batched K/V cache: `G` sequences advancing together. Per layer, host-accumulates
/// seq-major `[T, G·num_kv, head_dim]` (one step appends `G·num_kv·hd` f32), rebuilt
/// to `[G·num_kv, T, hd]` per step. Same host-cache shape as M2a with a G dimension.
/// Batched K/V cache: `G` sequences advancing together. Per layer, a device-resident
/// `[G·num_kv, T, head_dim]` grown one token per step via `cat_seq` (M2c — no host
/// round-trip). Same as M2a's device cache with a G dimension in `bh`.
struct BatchKVCache {
k: Vec<Vec<f32>>,
v: Vec<Vec<f32>>,
k: Vec<Option<Tensor>>,
v: Vec<Option<Tensor>>,
}
impl BatchKVCache {
fn new(n_layers: usize) -> Self {
Self {
k: vec![Vec::new(); n_layers],
v: vec![Vec::new(); n_layers],
k: (0..n_layers).map(|_| None).collect(),
v: (0..n_layers).map(|_| None).collect(),
}
}
fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) {
self.k[li].extend_from_slice(k_tok);
self.v[li].extend_from_slice(v_tok);
fn append(&mut self, li: usize, k_bh: Tensor, v_bh: Tensor) {
self.k[li] = Some(match self.k[li].take() {
Some(c) => c.cat_seq(&k_bh),
None => k_bh,
});
self.v[li] = Some(match self.v[li].take() {
Some(c) => c.cat_seq(&v_bh),
None => v_bh,
});
}
}
@@ -369,7 +371,6 @@ fn decode_step_batch(
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
let dim = cfg.dim;
let g = toks.len();
let g_kv = g * num_kv; // batch·num_kv heads in the cache
let scale = 1.0 / (hd as f32).sqrt();
let (theta, eps) = (cfg.rope_theta, cfg.eps);
let n_layers = cfg.n_layers;
@@ -398,26 +399,20 @@ fn decode_step_batch(
let q = q.reshape(&[g, nh, hd]).rope_pos(&positions, theta);
let q_bh = q.reshape(&[g * nh, 1, hd]); // bh = G·nh
// K/V appended as [G·num_kv,1,hd] (bh-major) into the device cache (M2c).
let k = linear_t(cdt, &normed, wk).reshape(&[g, num_kv, hd]);
let k = k.reshape(&[g * num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
let k_tok = k.reshape(&[g, num_kv, hd]).rope_pos(&positions, theta);
let v_tok = linear_t(cdt, &normed, wv).reshape(&[g, num_kv, hd]);
let k_bh = k
.reshape(&[g, num_kv, hd])
.rope_pos(&positions, theta)
.reshape(&[g * num_kv, 1, hd]);
let v_bh = linear_t(cdt, &normed, wv).reshape(&[g * num_kv, 1, hd]);
cache.append(li, k_bh, v_bh);
let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
cache.append(li, &k_host, &v_host);
// Rebuild [T, G·num_kv, hd] → [G·num_kv, T, hd] → repeat_kv to [G·nh, T, hd].
let t_len = cache.k[li].len() / (g_kv * hd);
let build = |flat: &[f32]| -> Tensor {
let bh_kv = Tensor::from_slice(flat, &[t_len, g_kv, hd])
.to_device(device)
.transpose_3d01(); // [G·num_kv, T, hd], f32
let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv };
if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, g) } // [G·nh, T, hd]
};
let k_full = build(&cache.k[li]);
let v_full = build(&cache.v[li]);
// repeat_kv the cached [G·num_kv,T,hd] to [G·nh,T,hd] for the SDPA.
let expand = |c: &Tensor| if num_kv == nh { c.clone() } else { c.repeat_kv(nh, g) };
let k_full = expand(cache.k[li].as_ref().unwrap());
let v_full = expand(cache.v[li].as_ref().unwrap());
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [G·nh, hd]
let attn = attn.reshape(&[g, dim]); // concat heads per sequence

View File

@@ -856,6 +856,41 @@ impl Tensor {
out
}
/// Concatenate along the sequence (middle) dim: `self`:[bh,ta,hd] ++
/// `other`:[bh,tb,hd] → `[bh,ta+tb,hd]`. The device-side KV-cache append (M2c):
/// the cache stays on the GPU and grows by one token per decode step, removing
/// the M2a/M2b host round-trip. Mirrors the bf16 cast handling of the other
/// structural kernels.
#[cfg(not(no_cuda))]
pub fn cat_seq(&self, other: &Tensor) -> Self {
assert_eq!(self.ndim(), 3, "cat_seq requires [bh,t,hd]");
assert_eq!(other.ndim(), 3, "cat_seq requires [bh,t,hd]");
assert_eq!(self.dtype, other.dtype, "cat_seq dtype mismatch");
let (bh, ta, hd) = (self.shape[0], self.shape[1], self.shape[2]);
let (bh2, tb, hd2) = (other.shape[0], other.shape[1], other.shape[2]);
assert_eq!(bh, bh2, "cat_seq bh mismatch");
assert_eq!(hd, hd2, "cat_seq head_dim mismatch");
if self.dtype == DType::BF16 {
return self
.to_dtype(DType::F32)
.cat_seq(&other.to_dtype(DType::F32))
.to_dtype(DType::BF16);
}
let out = Tensor::zeros(&[bh, ta + tb, hd], DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_cat_seq_f32(
self.data_ptr() as *const f32,
other.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
bh as i32,
(ta * hd) as i32,
(tb * hd) as i32,
std::ptr::null_mut(),
);
}
out
}
/// RoPE backward: apply the inverse (transpose) rotation to `dy`. RoPE is an
/// orthogonal map, so it needs no cached forward values, only `theta`/`period`.
#[cfg(not(no_cuda))]

View File

@@ -197,3 +197,29 @@ fn rope_pos_matches_rope_and_rope_at() {
}
println!("rope_pos OK: == full rope for [0..n] and == rope_at(P) per row for uniform P");
}
/// (f) `cat_seq` (device-side KV-cache append, M2c): concatenating [bh,ta,hd] ++
/// [bh,tb,hd] along the seq dim equals the host-side interleaved concat (per bh row,
/// a's block then b's block). This is the device append that removes the M2a/M2b
/// host round-trip.
#[test]
fn cat_seq_matches_host_concat() {
assert!(device::device_count().expect("device count") > 0, "no CUDA device");
device::set_device(0).unwrap();
let (bh, ta, tb, hd) = (4usize, 3usize, 2usize, 5usize);
let ah: Vec<f32> = (0..bh * ta * hd).map(|i| i as f32 * 0.1).collect();
let bhost: Vec<f32> = (0..bh * tb * hd).map(|i| -(i as f32) - 1.0).collect();
let a = Tensor::from_slice(&ah, &[bh, ta, hd]).to_device(Device::Cuda(0));
let b = Tensor::from_slice(&bhost, &[bh, tb, hd]).to_device(Device::Cuda(0));
let got = a.cat_seq(&b).to_device(Device::Cpu).as_slice::<f32>().to_vec();
// Host reference: per bh row, a's ta*hd then b's tb*hd.
let mut want = vec![0f32; bh * (ta + tb) * hd];
for r in 0..bh {
let (oa, ob, oo) = (r * ta * hd, r * tb * hd, r * (ta + tb) * hd);
want[oo..oo + ta * hd].copy_from_slice(&ah[oa..oa + ta * hd]);
want[oo + ta * hd..oo + (ta + tb) * hd].copy_from_slice(&bhost[ob..ob + tb * hd]);
}
assert_eq!(got, want, "cat_seq != host interleaved concat");
println!("cat_seq OK: [bh={bh},{ta}+{tb},{hd}] == host concat");
}

View File

@@ -296,6 +296,25 @@ void launch_rope_pos_f32(const float* x, const int* positions, float* y,
rope_pos_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, positions, y, heads, head_dim, theta);
}
// Concatenate along the sequence (middle) dim: a:[bh,ta,hd], b:[bh,tb,hd] →
// out:[bh,ta+tb,hd] with out[:, :ta]=a, out[:, ta:]=b. The device-side KV-cache
// append (M2c): keeps K/V on the GPU and grows by one token per step, removing the
// host round-trip the M2a/M2b host cache paid. One block per bh row.
__global__ void cat_seq_k(const float* a, const float* b, float* out,
int ta_hd, int tb_hd) {
int i = blockIdx.x; // bh row
int o_hd = ta_hd + tb_hd;
const float* ar = a + (long)i * ta_hd;
const float* br = b + (long)i * tb_hd;
float* outr = out + (long)i * o_hd;
for (int j = threadIdx.x; j < ta_hd; j += blockDim.x) outr[j] = ar[j];
for (int j = threadIdx.x; j < tb_hd; j += blockDim.x) outr[ta_hd + j] = br[j];
}
void launch_cat_seq_f32(const float* a, const float* b, float* out,
int bh, int ta_hd, int tb_hd, void* s) {
cat_seq_k<<<bh, 256, 0, (cudaStream_t)s>>>(a, b, out, ta_hd, tb_hd);
}
// Per-row scale: y[r,c] = x[r,c] * s[r]. One block per row. Used by the GRPO
// (M4) policy-gradient backward, where each completion token's row of
// (probs onehot) is scaled by its own per-token coefficient.