From 3a3425960c7aeaafb93ec9d175e21010b99130df Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Tue, 30 Jun 2026 17:38:16 +0800 Subject: [PATCH] =?UTF-8?q?post-train:=20M2c=20=E2=80=94=20device-side=20K?= =?UTF-8?q?V=20cache=20(cat=5Fseq),=20profile-first=20bottleneck=20shift?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Device-resident KV cache: keep K/V on the GPU as [bh,T,hd], grow by one token per step via a new cat_seq kernel (concat along seq) — removes the M2a/M2b per-layer host round-trip (to_cpu/from_slice/re-upload) AND the transpose_3d01. Both single-seq and batched decode refactored to it; cache is Option per layer (cleaner than the host Vec + rebuild). Gates all hold: cat_seq == host concat; decode_kv single-seq + decode_batch G-way both still TOKEN-IDENTICAL; GQA training path unaffected. Honest measurement (the point): removing the host round-trip buys ~10% on pure single-seq decode (133 → 147 tok/s @128) but does NOT move the GRPO step (~8.5 s/step unchanged) — because after M2b batching the rollout is no longer the step's bottleneck; the per-sample per_token_logp captures + the PG-update forwards/backwards (model.forward, full-seq) now dominate. Measure-first lesson (cf. T11/T17/M2a): the long pole shifted to the training-side forwards; the next decode lever (ragged batched prefill) targets those, not the cache. Co-Authored-By: Claude Opus 4.8 --- crates/xtrain-cuda/src/ffi.rs | 11 +++ crates/xtrain-model/src/decode.rs | 107 +++++++++++----------- crates/xtrain-tensor/src/tensor.rs | 35 +++++++ crates/xtrain-tensor/tests/integration.rs | 26 ++++++ csrc/ops/nn.cu | 19 ++++ 5 files changed, 142 insertions(+), 56 deletions(-) diff --git a/crates/xtrain-cuda/src/ffi.rs b/crates/xtrain-cuda/src/ffi.rs index 1726c91..82505ca 100644 --- a/crates/xtrain-cuda/src/ffi.rs +++ b/crates/xtrain-cuda/src/ffi.rs @@ -164,6 +164,17 @@ unsafe extern "C" { theta: f32, s: CudaStream, ); + // Concatenate along the sequence dim: a:[bh,ta,hd], b:[bh,tb,hd] → + // out:[bh,ta+tb,hd] (device-side KV-cache append, M2c). + pub fn launch_cat_seq_f32( + a: *const f32, + b: *const f32, + out: *mut f32, + bh: i32, + ta_hd: i32, + tb_hd: i32, + s: CudaStream, + ); // Per-row scale: y[r,c] = x[r,c] * s[r] (GRPO policy-gradient backward). pub fn launch_scale_rows_f32( x: *const f32, diff --git a/crates/xtrain-model/src/decode.rs b/crates/xtrain-model/src/decode.rs index f719f53..4a975c1 100644 --- a/crates/xtrain-model/src/decode.rs +++ b/crates/xtrain-model/src/decode.rs @@ -36,22 +36,29 @@ use xtrain_tensor::{DType, Device, Tensor}; /// bf16 projection output); rebuilt to the compute dtype when forming the K/V /// tensor, so bf16 values round-trip bit-for-bit. struct KVCache { - k: Vec>, - v: Vec>, + k: Vec>, + v: Vec>, } impl KVCache { fn new(n_layers: usize) -> Self { Self { - k: vec![Vec::new(); n_layers], - v: vec![Vec::new(); n_layers], + k: (0..n_layers).map(|_| None).collect(), + v: (0..n_layers).map(|_| None).collect(), } } - /// Append one token's K/V slab (each `num_kv·head_dim` f32) to layer `li`. - fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) { - self.k[li].extend_from_slice(k_tok); - self.v[li].extend_from_slice(v_tok); + /// Append one token's K/V (`[bh,1,hd]`, compute dtype) to layer `li`, growing the + /// device-resident `[bh,T,hd]` cache via `cat_seq` (no host round-trip, M2c). + fn append(&mut self, li: usize, k_bh: Tensor, v_bh: Tensor) { + self.k[li] = Some(match self.k[li].take() { + Some(c) => c.cat_seq(&k_bh), + None => k_bh, + }); + self.v[li] = Some(match self.v[li].take() { + Some(c) => c.cat_seq(&v_bh), + None => v_bh, + }); } } @@ -183,7 +190,6 @@ fn decode_step( ) -> Vec { let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads); let dim = cfg.dim; - let kv_dim = num_kv * hd; let scale = 1.0 / (hd as f32).sqrt(); let (theta, eps) = (cfg.rope_theta, cfg.eps); let n_layers = cfg.n_layers; @@ -212,28 +218,18 @@ fn decode_step( let q = q.reshape(&[1, nh, hd]).rope_at(theta, pos); let q_bh = q.reshape(&[nh, 1, hd]); // seq=1 ⇒ the head-transpose is a no-op on data - // K: same as Q (QK-norm + RoPE); cache token-major. V: project only. + // K: same as Q (QK-norm + RoPE). V: project only. Append each as [num_kv,1,hd] + // (bh-major) into the device cache; no host round-trip, no transpose (M2c). let k = linear_t(cdt, &normed, wk).reshape(&[1, num_kv, hd]); let k = k.reshape(&[num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0; - let k_tok = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos); // [1, num_kv, hd] - let v_tok = linear_t(cdt, &normed, wv).reshape(&[1, num_kv, hd]); + let k_bh = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos).reshape(&[num_kv, 1, hd]); + let v_bh = linear_t(cdt, &normed, wv).reshape(&[num_kv, 1, hd]); + cache.append(li, k_bh, v_bh); - let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::().to_vec(); - let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::().to_vec(); - cache.append(li, &k_host, &v_host); - - // Rebuild the full K/V for this layer: token-major [T,num_kv,hd] → [num_kv,T,hd] - // → repeat_kv to [nh,T,hd]. - let t_len = cache.k[li].len() / kv_dim; - let build = |flat: &[f32]| -> Tensor { - let bh_kv = Tensor::from_slice(flat, &[t_len, num_kv, hd]) - .to_device(device) - .transpose_3d01(); // [num_kv, T, hd], f32 - let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv }; - if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, 1) } // [nh, T, hd] - }; - let k_full = build(&cache.k[li]); - let v_full = build(&cache.v[li]); + // repeat_kv the cached [num_kv,T,hd] to [nh,T,hd] for the SDPA. + let expand = |c: &Tensor| if num_kv == nh { c.clone() } else { c.repeat_kv(nh, 1) }; + let k_full = expand(cache.k[li].as_ref().unwrap()); + let v_full = expand(cache.v[li].as_ref().unwrap()); let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [nh, hd] let attn = attn.reshape(&[1, dim]); // concat heads (nh·hd == dim) @@ -270,24 +266,30 @@ fn argmax(row: &[f32]) -> usize { // M2b — batched KV-cache decode (G samples of one prompt, in lockstep) // =================================================================== -/// Batched K/V cache: `G` sequences advancing together. Per layer, host-accumulates -/// seq-major `[T, G·num_kv, head_dim]` (one step appends `G·num_kv·hd` f32), rebuilt -/// to `[G·num_kv, T, hd]` per step. Same host-cache shape as M2a with a G dimension. +/// Batched K/V cache: `G` sequences advancing together. Per layer, a device-resident +/// `[G·num_kv, T, head_dim]` grown one token per step via `cat_seq` (M2c — no host +/// round-trip). Same as M2a's device cache with a G dimension in `bh`. struct BatchKVCache { - k: Vec>, - v: Vec>, + k: Vec>, + v: Vec>, } impl BatchKVCache { fn new(n_layers: usize) -> Self { Self { - k: vec![Vec::new(); n_layers], - v: vec![Vec::new(); n_layers], + k: (0..n_layers).map(|_| None).collect(), + v: (0..n_layers).map(|_| None).collect(), } } - fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) { - self.k[li].extend_from_slice(k_tok); - self.v[li].extend_from_slice(v_tok); + fn append(&mut self, li: usize, k_bh: Tensor, v_bh: Tensor) { + self.k[li] = Some(match self.k[li].take() { + Some(c) => c.cat_seq(&k_bh), + None => k_bh, + }); + self.v[li] = Some(match self.v[li].take() { + Some(c) => c.cat_seq(&v_bh), + None => v_bh, + }); } } @@ -369,7 +371,6 @@ fn decode_step_batch( let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads); let dim = cfg.dim; let g = toks.len(); - let g_kv = g * num_kv; // batch·num_kv heads in the cache let scale = 1.0 / (hd as f32).sqrt(); let (theta, eps) = (cfg.rope_theta, cfg.eps); let n_layers = cfg.n_layers; @@ -398,26 +399,20 @@ fn decode_step_batch( let q = q.reshape(&[g, nh, hd]).rope_pos(&positions, theta); let q_bh = q.reshape(&[g * nh, 1, hd]); // bh = G·nh + // K/V appended as [G·num_kv,1,hd] (bh-major) into the device cache (M2c). let k = linear_t(cdt, &normed, wk).reshape(&[g, num_kv, hd]); let k = k.reshape(&[g * num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0; - let k_tok = k.reshape(&[g, num_kv, hd]).rope_pos(&positions, theta); - let v_tok = linear_t(cdt, &normed, wv).reshape(&[g, num_kv, hd]); + let k_bh = k + .reshape(&[g, num_kv, hd]) + .rope_pos(&positions, theta) + .reshape(&[g * num_kv, 1, hd]); + let v_bh = linear_t(cdt, &normed, wv).reshape(&[g * num_kv, 1, hd]); + cache.append(li, k_bh, v_bh); - let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::().to_vec(); - let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::().to_vec(); - cache.append(li, &k_host, &v_host); - - // Rebuild [T, G·num_kv, hd] → [G·num_kv, T, hd] → repeat_kv to [G·nh, T, hd]. - let t_len = cache.k[li].len() / (g_kv * hd); - let build = |flat: &[f32]| -> Tensor { - let bh_kv = Tensor::from_slice(flat, &[t_len, g_kv, hd]) - .to_device(device) - .transpose_3d01(); // [G·num_kv, T, hd], f32 - let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv }; - if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, g) } // [G·nh, T, hd] - }; - let k_full = build(&cache.k[li]); - let v_full = build(&cache.v[li]); + // repeat_kv the cached [G·num_kv,T,hd] to [G·nh,T,hd] for the SDPA. + let expand = |c: &Tensor| if num_kv == nh { c.clone() } else { c.repeat_kv(nh, g) }; + let k_full = expand(cache.k[li].as_ref().unwrap()); + let v_full = expand(cache.v[li].as_ref().unwrap()); let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [G·nh, hd] let attn = attn.reshape(&[g, dim]); // concat heads per sequence diff --git a/crates/xtrain-tensor/src/tensor.rs b/crates/xtrain-tensor/src/tensor.rs index a954da0..ed326ab 100644 --- a/crates/xtrain-tensor/src/tensor.rs +++ b/crates/xtrain-tensor/src/tensor.rs @@ -856,6 +856,41 @@ impl Tensor { out } + /// Concatenate along the sequence (middle) dim: `self`:[bh,ta,hd] ++ + /// `other`:[bh,tb,hd] → `[bh,ta+tb,hd]`. The device-side KV-cache append (M2c): + /// the cache stays on the GPU and grows by one token per decode step, removing + /// the M2a/M2b host round-trip. Mirrors the bf16 cast handling of the other + /// structural kernels. + #[cfg(not(no_cuda))] + pub fn cat_seq(&self, other: &Tensor) -> Self { + assert_eq!(self.ndim(), 3, "cat_seq requires [bh,t,hd]"); + assert_eq!(other.ndim(), 3, "cat_seq requires [bh,t,hd]"); + assert_eq!(self.dtype, other.dtype, "cat_seq dtype mismatch"); + let (bh, ta, hd) = (self.shape[0], self.shape[1], self.shape[2]); + let (bh2, tb, hd2) = (other.shape[0], other.shape[1], other.shape[2]); + assert_eq!(bh, bh2, "cat_seq bh mismatch"); + assert_eq!(hd, hd2, "cat_seq head_dim mismatch"); + if self.dtype == DType::BF16 { + return self + .to_dtype(DType::F32) + .cat_seq(&other.to_dtype(DType::F32)) + .to_dtype(DType::BF16); + } + let out = Tensor::zeros(&[bh, ta + tb, hd], DType::F32, self.device()); + unsafe { + xtrain_cuda::ffi::launch_cat_seq_f32( + self.data_ptr() as *const f32, + other.data_ptr() as *const f32, + out.data_ptr() as *mut f32, + bh as i32, + (ta * hd) as i32, + (tb * hd) as i32, + std::ptr::null_mut(), + ); + } + out + } + /// RoPE backward: apply the inverse (transpose) rotation to `dy`. RoPE is an /// orthogonal map, so it needs no cached forward values, only `theta`/`period`. #[cfg(not(no_cuda))] diff --git a/crates/xtrain-tensor/tests/integration.rs b/crates/xtrain-tensor/tests/integration.rs index b43d0a9..78f4581 100644 --- a/crates/xtrain-tensor/tests/integration.rs +++ b/crates/xtrain-tensor/tests/integration.rs @@ -197,3 +197,29 @@ fn rope_pos_matches_rope_and_rope_at() { } println!("rope_pos OK: == full rope for [0..n] and == rope_at(P) per row for uniform P"); } + +/// (f) `cat_seq` (device-side KV-cache append, M2c): concatenating [bh,ta,hd] ++ +/// [bh,tb,hd] along the seq dim equals the host-side interleaved concat (per bh row, +/// a's block then b's block). This is the device append that removes the M2a/M2b +/// host round-trip. +#[test] +fn cat_seq_matches_host_concat() { + assert!(device::device_count().expect("device count") > 0, "no CUDA device"); + device::set_device(0).unwrap(); + let (bh, ta, tb, hd) = (4usize, 3usize, 2usize, 5usize); + let ah: Vec = (0..bh * ta * hd).map(|i| i as f32 * 0.1).collect(); + let bhost: Vec = (0..bh * tb * hd).map(|i| -(i as f32) - 1.0).collect(); + let a = Tensor::from_slice(&ah, &[bh, ta, hd]).to_device(Device::Cuda(0)); + let b = Tensor::from_slice(&bhost, &[bh, tb, hd]).to_device(Device::Cuda(0)); + + let got = a.cat_seq(&b).to_device(Device::Cpu).as_slice::().to_vec(); + // Host reference: per bh row, a's ta*hd then b's tb*hd. + let mut want = vec![0f32; bh * (ta + tb) * hd]; + for r in 0..bh { + let (oa, ob, oo) = (r * ta * hd, r * tb * hd, r * (ta + tb) * hd); + want[oo..oo + ta * hd].copy_from_slice(&ah[oa..oa + ta * hd]); + want[oo + ta * hd..oo + (ta + tb) * hd].copy_from_slice(&bhost[ob..ob + tb * hd]); + } + assert_eq!(got, want, "cat_seq != host interleaved concat"); + println!("cat_seq OK: [bh={bh},{ta}+{tb},{hd}] == host concat"); +} diff --git a/csrc/ops/nn.cu b/csrc/ops/nn.cu index b1c9bae..5ff23f7 100644 --- a/csrc/ops/nn.cu +++ b/csrc/ops/nn.cu @@ -296,6 +296,25 @@ void launch_rope_pos_f32(const float* x, const int* positions, float* y, rope_pos_k<<>>(x, positions, y, heads, head_dim, theta); } +// Concatenate along the sequence (middle) dim: a:[bh,ta,hd], b:[bh,tb,hd] → +// out:[bh,ta+tb,hd] with out[:, :ta]=a, out[:, ta:]=b. The device-side KV-cache +// append (M2c): keeps K/V on the GPU and grows by one token per step, removing the +// host round-trip the M2a/M2b host cache paid. One block per bh row. +__global__ void cat_seq_k(const float* a, const float* b, float* out, + int ta_hd, int tb_hd) { + int i = blockIdx.x; // bh row + int o_hd = ta_hd + tb_hd; + const float* ar = a + (long)i * ta_hd; + const float* br = b + (long)i * tb_hd; + float* outr = out + (long)i * o_hd; + for (int j = threadIdx.x; j < ta_hd; j += blockDim.x) outr[j] = ar[j]; + for (int j = threadIdx.x; j < tb_hd; j += blockDim.x) outr[ta_hd + j] = br[j]; +} +void launch_cat_seq_f32(const float* a, const float* b, float* out, + int bh, int ta_hd, int tb_hd, void* s) { + cat_seq_k<<>>(a, b, out, ta_hd, tb_hd); +} + // Per-row scale: y[r,c] = x[r,c] * s[r]. One block per row. Used by the GRPO // (M4) policy-gradient backward, where each completion token's row of // (probs − onehot) is scaled by its own per-token coefficient.