post-train: M2c — device-side KV cache (cat_seq), profile-first bottleneck shift
Device-resident KV cache: keep K/V on the GPU as [bh,T,hd], grow by one token per step via a new cat_seq kernel (concat along seq) — removes the M2a/M2b per-layer host round-trip (to_cpu/from_slice/re-upload) AND the transpose_3d01. Both single-seq and batched decode refactored to it; cache is Option<Tensor> per layer (cleaner than the host Vec + rebuild). Gates all hold: cat_seq == host concat; decode_kv single-seq + decode_batch G-way both still TOKEN-IDENTICAL; GQA training path unaffected. Honest measurement (the point): removing the host round-trip buys ~10% on pure single-seq decode (133 → 147 tok/s @128) but does NOT move the GRPO step (~8.5 s/step unchanged) — because after M2b batching the rollout is no longer the step's bottleneck; the per-sample per_token_logp captures + the PG-update forwards/backwards (model.forward, full-seq) now dominate. Measure-first lesson (cf. T11/T17/M2a): the long pole shifted to the training-side forwards; the next decode lever (ragged batched prefill) targets those, not the cache. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -164,6 +164,17 @@ unsafe extern "C" {
|
||||
theta: f32,
|
||||
s: CudaStream,
|
||||
);
|
||||
// Concatenate along the sequence dim: a:[bh,ta,hd], b:[bh,tb,hd] →
|
||||
// out:[bh,ta+tb,hd] (device-side KV-cache append, M2c).
|
||||
pub fn launch_cat_seq_f32(
|
||||
a: *const f32,
|
||||
b: *const f32,
|
||||
out: *mut f32,
|
||||
bh: i32,
|
||||
ta_hd: i32,
|
||||
tb_hd: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
// Per-row scale: y[r,c] = x[r,c] * s[r] (GRPO policy-gradient backward).
|
||||
pub fn launch_scale_rows_f32(
|
||||
x: *const f32,
|
||||
|
||||
@@ -36,22 +36,29 @@ use xtrain_tensor::{DType, Device, Tensor};
|
||||
/// bf16 projection output); rebuilt to the compute dtype when forming the K/V
|
||||
/// tensor, so bf16 values round-trip bit-for-bit.
|
||||
struct KVCache {
|
||||
k: Vec<Vec<f32>>,
|
||||
v: Vec<Vec<f32>>,
|
||||
k: Vec<Option<Tensor>>,
|
||||
v: Vec<Option<Tensor>>,
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
fn new(n_layers: usize) -> Self {
|
||||
Self {
|
||||
k: vec![Vec::new(); n_layers],
|
||||
v: vec![Vec::new(); n_layers],
|
||||
k: (0..n_layers).map(|_| None).collect(),
|
||||
v: (0..n_layers).map(|_| None).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Append one token's K/V slab (each `num_kv·head_dim` f32) to layer `li`.
|
||||
fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) {
|
||||
self.k[li].extend_from_slice(k_tok);
|
||||
self.v[li].extend_from_slice(v_tok);
|
||||
/// Append one token's K/V (`[bh,1,hd]`, compute dtype) to layer `li`, growing the
|
||||
/// device-resident `[bh,T,hd]` cache via `cat_seq` (no host round-trip, M2c).
|
||||
fn append(&mut self, li: usize, k_bh: Tensor, v_bh: Tensor) {
|
||||
self.k[li] = Some(match self.k[li].take() {
|
||||
Some(c) => c.cat_seq(&k_bh),
|
||||
None => k_bh,
|
||||
});
|
||||
self.v[li] = Some(match self.v[li].take() {
|
||||
Some(c) => c.cat_seq(&v_bh),
|
||||
None => v_bh,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,7 +190,6 @@ fn decode_step(
|
||||
) -> Vec<f32> {
|
||||
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
|
||||
let dim = cfg.dim;
|
||||
let kv_dim = num_kv * hd;
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
let (theta, eps) = (cfg.rope_theta, cfg.eps);
|
||||
let n_layers = cfg.n_layers;
|
||||
@@ -212,28 +218,18 @@ fn decode_step(
|
||||
let q = q.reshape(&[1, nh, hd]).rope_at(theta, pos);
|
||||
let q_bh = q.reshape(&[nh, 1, hd]); // seq=1 ⇒ the head-transpose is a no-op on data
|
||||
|
||||
// K: same as Q (QK-norm + RoPE); cache token-major. V: project only.
|
||||
// K: same as Q (QK-norm + RoPE). V: project only. Append each as [num_kv,1,hd]
|
||||
// (bh-major) into the device cache; no host round-trip, no transpose (M2c).
|
||||
let k = linear_t(cdt, &normed, wk).reshape(&[1, num_kv, hd]);
|
||||
let k = k.reshape(&[num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
|
||||
let k_tok = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos); // [1, num_kv, hd]
|
||||
let v_tok = linear_t(cdt, &normed, wv).reshape(&[1, num_kv, hd]);
|
||||
let k_bh = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos).reshape(&[num_kv, 1, hd]);
|
||||
let v_bh = linear_t(cdt, &normed, wv).reshape(&[num_kv, 1, hd]);
|
||||
cache.append(li, k_bh, v_bh);
|
||||
|
||||
let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
cache.append(li, &k_host, &v_host);
|
||||
|
||||
// Rebuild the full K/V for this layer: token-major [T,num_kv,hd] → [num_kv,T,hd]
|
||||
// → repeat_kv to [nh,T,hd].
|
||||
let t_len = cache.k[li].len() / kv_dim;
|
||||
let build = |flat: &[f32]| -> Tensor {
|
||||
let bh_kv = Tensor::from_slice(flat, &[t_len, num_kv, hd])
|
||||
.to_device(device)
|
||||
.transpose_3d01(); // [num_kv, T, hd], f32
|
||||
let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv };
|
||||
if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, 1) } // [nh, T, hd]
|
||||
};
|
||||
let k_full = build(&cache.k[li]);
|
||||
let v_full = build(&cache.v[li]);
|
||||
// repeat_kv the cached [num_kv,T,hd] to [nh,T,hd] for the SDPA.
|
||||
let expand = |c: &Tensor| if num_kv == nh { c.clone() } else { c.repeat_kv(nh, 1) };
|
||||
let k_full = expand(cache.k[li].as_ref().unwrap());
|
||||
let v_full = expand(cache.v[li].as_ref().unwrap());
|
||||
|
||||
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [nh, hd]
|
||||
let attn = attn.reshape(&[1, dim]); // concat heads (nh·hd == dim)
|
||||
@@ -270,24 +266,30 @@ fn argmax(row: &[f32]) -> usize {
|
||||
// M2b — batched KV-cache decode (G samples of one prompt, in lockstep)
|
||||
// ===================================================================
|
||||
|
||||
/// Batched K/V cache: `G` sequences advancing together. Per layer, host-accumulates
|
||||
/// seq-major `[T, G·num_kv, head_dim]` (one step appends `G·num_kv·hd` f32), rebuilt
|
||||
/// to `[G·num_kv, T, hd]` per step. Same host-cache shape as M2a with a G dimension.
|
||||
/// Batched K/V cache: `G` sequences advancing together. Per layer, a device-resident
|
||||
/// `[G·num_kv, T, head_dim]` grown one token per step via `cat_seq` (M2c — no host
|
||||
/// round-trip). Same as M2a's device cache with a G dimension in `bh`.
|
||||
struct BatchKVCache {
|
||||
k: Vec<Vec<f32>>,
|
||||
v: Vec<Vec<f32>>,
|
||||
k: Vec<Option<Tensor>>,
|
||||
v: Vec<Option<Tensor>>,
|
||||
}
|
||||
|
||||
impl BatchKVCache {
|
||||
fn new(n_layers: usize) -> Self {
|
||||
Self {
|
||||
k: vec![Vec::new(); n_layers],
|
||||
v: vec![Vec::new(); n_layers],
|
||||
k: (0..n_layers).map(|_| None).collect(),
|
||||
v: (0..n_layers).map(|_| None).collect(),
|
||||
}
|
||||
}
|
||||
fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) {
|
||||
self.k[li].extend_from_slice(k_tok);
|
||||
self.v[li].extend_from_slice(v_tok);
|
||||
fn append(&mut self, li: usize, k_bh: Tensor, v_bh: Tensor) {
|
||||
self.k[li] = Some(match self.k[li].take() {
|
||||
Some(c) => c.cat_seq(&k_bh),
|
||||
None => k_bh,
|
||||
});
|
||||
self.v[li] = Some(match self.v[li].take() {
|
||||
Some(c) => c.cat_seq(&v_bh),
|
||||
None => v_bh,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -369,7 +371,6 @@ fn decode_step_batch(
|
||||
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
|
||||
let dim = cfg.dim;
|
||||
let g = toks.len();
|
||||
let g_kv = g * num_kv; // batch·num_kv heads in the cache
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
let (theta, eps) = (cfg.rope_theta, cfg.eps);
|
||||
let n_layers = cfg.n_layers;
|
||||
@@ -398,26 +399,20 @@ fn decode_step_batch(
|
||||
let q = q.reshape(&[g, nh, hd]).rope_pos(&positions, theta);
|
||||
let q_bh = q.reshape(&[g * nh, 1, hd]); // bh = G·nh
|
||||
|
||||
// K/V appended as [G·num_kv,1,hd] (bh-major) into the device cache (M2c).
|
||||
let k = linear_t(cdt, &normed, wk).reshape(&[g, num_kv, hd]);
|
||||
let k = k.reshape(&[g * num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
|
||||
let k_tok = k.reshape(&[g, num_kv, hd]).rope_pos(&positions, theta);
|
||||
let v_tok = linear_t(cdt, &normed, wv).reshape(&[g, num_kv, hd]);
|
||||
let k_bh = k
|
||||
.reshape(&[g, num_kv, hd])
|
||||
.rope_pos(&positions, theta)
|
||||
.reshape(&[g * num_kv, 1, hd]);
|
||||
let v_bh = linear_t(cdt, &normed, wv).reshape(&[g * num_kv, 1, hd]);
|
||||
cache.append(li, k_bh, v_bh);
|
||||
|
||||
let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
cache.append(li, &k_host, &v_host);
|
||||
|
||||
// Rebuild [T, G·num_kv, hd] → [G·num_kv, T, hd] → repeat_kv to [G·nh, T, hd].
|
||||
let t_len = cache.k[li].len() / (g_kv * hd);
|
||||
let build = |flat: &[f32]| -> Tensor {
|
||||
let bh_kv = Tensor::from_slice(flat, &[t_len, g_kv, hd])
|
||||
.to_device(device)
|
||||
.transpose_3d01(); // [G·num_kv, T, hd], f32
|
||||
let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv };
|
||||
if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, g) } // [G·nh, T, hd]
|
||||
};
|
||||
let k_full = build(&cache.k[li]);
|
||||
let v_full = build(&cache.v[li]);
|
||||
// repeat_kv the cached [G·num_kv,T,hd] to [G·nh,T,hd] for the SDPA.
|
||||
let expand = |c: &Tensor| if num_kv == nh { c.clone() } else { c.repeat_kv(nh, g) };
|
||||
let k_full = expand(cache.k[li].as_ref().unwrap());
|
||||
let v_full = expand(cache.v[li].as_ref().unwrap());
|
||||
|
||||
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [G·nh, hd]
|
||||
let attn = attn.reshape(&[g, dim]); // concat heads per sequence
|
||||
|
||||
@@ -856,6 +856,41 @@ impl Tensor {
|
||||
out
|
||||
}
|
||||
|
||||
/// Concatenate along the sequence (middle) dim: `self`:[bh,ta,hd] ++
|
||||
/// `other`:[bh,tb,hd] → `[bh,ta+tb,hd]`. The device-side KV-cache append (M2c):
|
||||
/// the cache stays on the GPU and grows by one token per decode step, removing
|
||||
/// the M2a/M2b host round-trip. Mirrors the bf16 cast handling of the other
|
||||
/// structural kernels.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn cat_seq(&self, other: &Tensor) -> Self {
|
||||
assert_eq!(self.ndim(), 3, "cat_seq requires [bh,t,hd]");
|
||||
assert_eq!(other.ndim(), 3, "cat_seq requires [bh,t,hd]");
|
||||
assert_eq!(self.dtype, other.dtype, "cat_seq dtype mismatch");
|
||||
let (bh, ta, hd) = (self.shape[0], self.shape[1], self.shape[2]);
|
||||
let (bh2, tb, hd2) = (other.shape[0], other.shape[1], other.shape[2]);
|
||||
assert_eq!(bh, bh2, "cat_seq bh mismatch");
|
||||
assert_eq!(hd, hd2, "cat_seq head_dim mismatch");
|
||||
if self.dtype == DType::BF16 {
|
||||
return self
|
||||
.to_dtype(DType::F32)
|
||||
.cat_seq(&other.to_dtype(DType::F32))
|
||||
.to_dtype(DType::BF16);
|
||||
}
|
||||
let out = Tensor::zeros(&[bh, ta + tb, hd], DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_cat_seq_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
other.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
bh as i32,
|
||||
(ta * hd) as i32,
|
||||
(tb * hd) as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// RoPE backward: apply the inverse (transpose) rotation to `dy`. RoPE is an
|
||||
/// orthogonal map, so it needs no cached forward values, only `theta`/`period`.
|
||||
#[cfg(not(no_cuda))]
|
||||
|
||||
@@ -197,3 +197,29 @@ fn rope_pos_matches_rope_and_rope_at() {
|
||||
}
|
||||
println!("rope_pos OK: == full rope for [0..n] and == rope_at(P) per row for uniform P");
|
||||
}
|
||||
|
||||
/// (f) `cat_seq` (device-side KV-cache append, M2c): concatenating [bh,ta,hd] ++
|
||||
/// [bh,tb,hd] along the seq dim equals the host-side interleaved concat (per bh row,
|
||||
/// a's block then b's block). This is the device append that removes the M2a/M2b
|
||||
/// host round-trip.
|
||||
#[test]
|
||||
fn cat_seq_matches_host_concat() {
|
||||
assert!(device::device_count().expect("device count") > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let (bh, ta, tb, hd) = (4usize, 3usize, 2usize, 5usize);
|
||||
let ah: Vec<f32> = (0..bh * ta * hd).map(|i| i as f32 * 0.1).collect();
|
||||
let bhost: Vec<f32> = (0..bh * tb * hd).map(|i| -(i as f32) - 1.0).collect();
|
||||
let a = Tensor::from_slice(&ah, &[bh, ta, hd]).to_device(Device::Cuda(0));
|
||||
let b = Tensor::from_slice(&bhost, &[bh, tb, hd]).to_device(Device::Cuda(0));
|
||||
|
||||
let got = a.cat_seq(&b).to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
// Host reference: per bh row, a's ta*hd then b's tb*hd.
|
||||
let mut want = vec![0f32; bh * (ta + tb) * hd];
|
||||
for r in 0..bh {
|
||||
let (oa, ob, oo) = (r * ta * hd, r * tb * hd, r * (ta + tb) * hd);
|
||||
want[oo..oo + ta * hd].copy_from_slice(&ah[oa..oa + ta * hd]);
|
||||
want[oo + ta * hd..oo + (ta + tb) * hd].copy_from_slice(&bhost[ob..ob + tb * hd]);
|
||||
}
|
||||
assert_eq!(got, want, "cat_seq != host interleaved concat");
|
||||
println!("cat_seq OK: [bh={bh},{ta}+{tb},{hd}] == host concat");
|
||||
}
|
||||
|
||||
@@ -296,6 +296,25 @@ void launch_rope_pos_f32(const float* x, const int* positions, float* y,
|
||||
rope_pos_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, positions, y, heads, head_dim, theta);
|
||||
}
|
||||
|
||||
// Concatenate along the sequence (middle) dim: a:[bh,ta,hd], b:[bh,tb,hd] →
|
||||
// out:[bh,ta+tb,hd] with out[:, :ta]=a, out[:, ta:]=b. The device-side KV-cache
|
||||
// append (M2c): keeps K/V on the GPU and grows by one token per step, removing the
|
||||
// host round-trip the M2a/M2b host cache paid. One block per bh row.
|
||||
__global__ void cat_seq_k(const float* a, const float* b, float* out,
|
||||
int ta_hd, int tb_hd) {
|
||||
int i = blockIdx.x; // bh row
|
||||
int o_hd = ta_hd + tb_hd;
|
||||
const float* ar = a + (long)i * ta_hd;
|
||||
const float* br = b + (long)i * tb_hd;
|
||||
float* outr = out + (long)i * o_hd;
|
||||
for (int j = threadIdx.x; j < ta_hd; j += blockDim.x) outr[j] = ar[j];
|
||||
for (int j = threadIdx.x; j < tb_hd; j += blockDim.x) outr[ta_hd + j] = br[j];
|
||||
}
|
||||
void launch_cat_seq_f32(const float* a, const float* b, float* out,
|
||||
int bh, int ta_hd, int tb_hd, void* s) {
|
||||
cat_seq_k<<<bh, 256, 0, (cudaStream_t)s>>>(a, b, out, ta_hd, tb_hd);
|
||||
}
|
||||
|
||||
// Per-row scale: y[r,c] = x[r,c] * s[r]. One block per row. Used by the GRPO
|
||||
// (M4) policy-gradient backward, where each completion token's row of
|
||||
// (probs − onehot) is scaled by its own per-token coefficient.
|
||||
|
||||
Reference in New Issue
Block a user