post-train: M2b — batched KV-cache decode (G-way, token-identical)

The rollout long-pole fix deferred from M2a: decode the G samples of one prompt
in lockstep (one forward per step over the group → G× fewer kernel launches).

- rope_pos(x, positions[]): RoPE with a per-row absolute position (new forward-
  only kernel) — G rows share one decode position. Gate: == full rope for
  [0..n], == rope_at(P) per row for uniform P (bit-identical).
- generate_cached_batch: BatchKVCache [T, G·num_kv, hd] + batched decode_step.
  decode_attention is already batch-agnostic (bh = G·nh); repeat_kv(nh, batch=G)
  broadcasts per group. No finished-mask / ragged prompts yet (perf-only / next).
- Gate (tests/decode_batch.rs): all G greedy rows token-identical to the single-
  sequence decode (8 query / 2 kv heads → exercises repeat_kv batching).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-30 17:18:54 +08:00
parent 096e45b845
commit 2c9b58cb3b
7 changed files with 369 additions and 1 deletions

View File

@@ -152,6 +152,18 @@ unsafe extern "C" {
pos0: i32,
s: CudaStream,
);
// RoPE with a per-row absolute position (batched KV-cache decode, M2b): row
// `tok`'s position is `positions[tok]`. Forward only.
pub fn launch_rope_pos_f32(
x: *const f32,
positions: *const i32,
y: *mut f32,
tokens: i32,
heads: i32,
head_dim: i32,
theta: f32,
s: CudaStream,
);
// Per-row scale: y[r,c] = x[r,c] * s[r] (GRPO policy-gradient backward).
pub fn launch_scale_rows_f32(
x: *const f32,

View File

@@ -265,3 +265,177 @@ fn argmax(row: &[f32]) -> usize {
.unwrap()
.0
}
// ===================================================================
// M2b — batched KV-cache decode (G samples of one prompt, in lockstep)
// ===================================================================
/// Batched K/V cache: `G` sequences advancing together. Per layer, host-accumulates
/// seq-major `[T, G·num_kv, head_dim]` (one step appends `G·num_kv·hd` f32), rebuilt
/// to `[G·num_kv, T, hd]` per step. Same host-cache shape as M2a with a G dimension.
struct BatchKVCache {
k: Vec<Vec<f32>>,
v: Vec<Vec<f32>>,
}
impl BatchKVCache {
fn new(n_layers: usize) -> Self {
Self {
k: vec![Vec::new(); n_layers],
v: vec![Vec::new(); n_layers],
}
}
fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) {
self.k[li].extend_from_slice(k_tok);
self.v[li].extend_from_slice(v_tok);
}
}
/// Batched KV-cache decode: roll out `n_samples` (G) completions of the SAME
/// `prompt` in lockstep — all G share the prompt, so they advance at one common
/// decode position each step (uniform RoPE via `rope_pos`). Returns G full token
/// sequences (prompt + sampled continuation). The G-way batching amortises the
/// per-step kernel launches across G (the rollout long-pole). Token-identical per
/// row to G independent single-sequence decodes (gated by `tests/decode_batch.rs`).
///
/// `temperature == 0` ⇒ greedy (all G identical); `> 0` ⇒ independent samples
/// (per-row draw from one shared `rng_state`). No finished-mask: all G generate
/// `max_new` tokens; the caller cuts each at `<|endoftext|>` (a perf-only early
/// stop is the M2b+ follow-up). Ragged (different-length prompts) is also deferred.
pub fn generate_cached_batch(
model: &TinyTransformer,
device: Device,
prompt: &[i32],
n_samples: usize,
max_new: usize,
temperature: f32,
rng_state: &mut u64,
) -> Vec<Vec<i32>> {
assert!(!prompt.is_empty(), "prompt must be non-empty");
assert!(n_samples > 0, "n_samples must be > 0");
let cfg = model.config();
let cdt = model.compute_dtype();
let n_layers = cfg.n_layers;
let params: Vec<Tensor> = model.params().iter().map(|p| p.value()).collect();
let embed = &params[0];
let final_norm = &params[1 + n_layers * 11];
let lm_head = &params[1 + n_layers * 11 + 1];
let g = n_samples;
let mut cache = BatchKVCache::new(n_layers);
let mut seqs: Vec<Vec<i32>> = vec![prompt.to_vec(); g];
// Prefill: feed each prompt token (identical across G) at its position.
let mut logits = Vec::new(); // [G, vocab] flattened
for (pos, &tok) in prompt.iter().enumerate() {
let toks = vec![tok; g];
logits = decode_step_batch(&params, cfg, cdt, device, &mut cache, &toks, pos, embed, final_norm, lm_head);
}
let vocab = cfg.vocab;
for _ in 0..max_new {
let mut next = Vec::with_capacity(g);
for row in 0..g {
let lg = &logits[row * vocab..(row + 1) * vocab];
let t = if temperature <= 0.0 {
argmax(lg) as i32
} else {
sample_temperature(lg, temperature, rng_state) as i32
};
next.push(t);
seqs[row].push(t);
}
let pos = seqs[0].len() - 1; // all G are at the same position
logits = decode_step_batch(&params, cfg, cdt, device, &mut cache, &next, pos, embed, final_norm, lm_head);
}
seqs
}
/// One batched decode step: `toks` is one current token per sequence (`[G]`), all at
/// absolute position `pos`. Appends each sequence's K/V and returns logits `[G·vocab]`.
#[allow(clippy::too_many_arguments)]
fn decode_step_batch(
params: &[Tensor],
cfg: &crate::Config,
cdt: DType,
device: Device,
cache: &mut BatchKVCache,
toks: &[i32],
pos: usize,
embed: &Tensor,
final_norm: &Tensor,
lm_head: &Tensor,
) -> Vec<f32> {
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
let dim = cfg.dim;
let g = toks.len();
let g_kv = g * num_kv; // batch·num_kv heads in the cache
let scale = 1.0 / (hd as f32).sqrt();
let (theta, eps) = (cfg.rope_theta, cfg.eps);
let n_layers = cfg.n_layers;
// Uniform per-row position (all G at the same decode step).
let positions = Tensor::from_slice(&vec![pos as i32; g], &[g]).to_device(device);
let ids = Tensor::from_slice(toks, &[g]).to_device(device);
let mut h = embed.embedding(&ids); // [G, dim] f32
if cdt == DType::BF16 {
h = h.to_dtype(DType::BF16);
}
for li in 0..n_layers {
let base = 1 + li * 11;
let (attn_norm, wq, wk, wv) =
(&params[base], &params[base + 1], &params[base + 2], &params[base + 3]);
let (q_norm, k_norm, wo) = (&params[base + 4], &params[base + 5], &params[base + 6]);
let (ffn_norm, w_gate, w_up, w_down) =
(&params[base + 7], &params[base + 8], &params[base + 9], &params[base + 10]);
let normed = h.rms_norm(&gamma_t(cdt, attn_norm), eps).0; // [G, dim]
// Q: project → per-head QK-norm → RoPE at `pos` for every row.
let q = linear_t(cdt, &normed, wq).reshape(&[g, nh, hd]);
let q = q.reshape(&[g * nh, hd]).rms_norm(&gamma_t(cdt, q_norm), eps).0;
let q = q.reshape(&[g, nh, hd]).rope_pos(&positions, theta);
let q_bh = q.reshape(&[g * nh, 1, hd]); // bh = G·nh
let k = linear_t(cdt, &normed, wk).reshape(&[g, num_kv, hd]);
let k = k.reshape(&[g * num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
let k_tok = k.reshape(&[g, num_kv, hd]).rope_pos(&positions, theta);
let v_tok = linear_t(cdt, &normed, wv).reshape(&[g, num_kv, hd]);
let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
cache.append(li, &k_host, &v_host);
// Rebuild [T, G·num_kv, hd] → [G·num_kv, T, hd] → repeat_kv to [G·nh, T, hd].
let t_len = cache.k[li].len() / (g_kv * hd);
let build = |flat: &[f32]| -> Tensor {
let bh_kv = Tensor::from_slice(flat, &[t_len, g_kv, hd])
.to_device(device)
.transpose_3d01(); // [G·num_kv, T, hd], f32
let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv };
if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, g) } // [G·nh, T, hd]
};
let k_full = build(&cache.k[li]);
let v_full = build(&cache.v[li]);
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [G·nh, hd]
let attn = attn.reshape(&[g, dim]); // concat heads per sequence
let attn_out = linear_t(cdt, &attn, wo);
h = h.add(&attn_out);
let normed = h.rms_norm(&gamma_t(cdt, ffn_norm), eps).0;
let gate = linear_t(cdt, &normed, w_gate);
let up = linear_t(cdt, &normed, w_up);
let act = gate.silu().mul(&up);
let down = linear_t(cdt, &act, w_down);
h = h.add(&down);
}
let h = h.rms_norm(&gamma_t(cdt, final_norm), eps).0;
linear_t(cdt, &h, lm_head)
.to_dtype(DType::F32)
.to_device(Device::Cpu)
.as_slice::<f32>()
.to_vec()
}

View File

@@ -29,4 +29,4 @@ pub use model::{TinyTransformer, batched_ids_tensor, ids_tensor, param_to_host};
#[cfg(not(no_cuda))]
pub mod decode;
#[cfg(not(no_cuda))]
pub use decode::{generate_cached, generate_greedy_cached};
pub use decode::{generate_cached, generate_cached_batch, generate_greedy_cached};

View File

@@ -822,6 +822,40 @@ impl Tensor {
out
}
/// RoPE with a PER-ROW absolute position (batched KV-cache decode, M2b).
/// `self`:[tokens,heads,head_dim]; row `t`'s position is `positions[t]` (an
/// I32 `[tokens]` tensor). For G-way batched decode all G rows share one decode
/// position; for ragged batches each row carries its own. Mirrors `rope_at`'s
/// dtype handling; forward only.
#[cfg(not(no_cuda))]
pub fn rope_pos(&self, positions: &Tensor, theta: f32) -> Self {
assert_eq!(self.ndim(), 3, "rope_pos requires [tokens,heads,head_dim]");
let (tokens, heads, head_dim) = (self.shape[0], self.shape[1], self.shape[2]);
assert_eq!(head_dim % 2, 0, "head_dim must be even");
assert_eq!(positions.dtype, DType::I32, "positions must be I32");
assert_eq!(positions.numel(), tokens, "one position per token");
if self.dtype == DType::BF16 {
return self
.to_dtype(DType::F32)
.rope_pos(positions, theta)
.to_dtype(DType::BF16);
}
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_rope_pos_f32(
self.data_ptr() as *const f32,
positions.data_ptr() as *const i32,
out.data_ptr() as *mut f32,
tokens as i32,
heads as i32,
head_dim as i32,
theta,
std::ptr::null_mut(),
);
}
out
}
/// RoPE backward: apply the inverse (transpose) rotation to `dy`. RoPE is an
/// orthogonal map, so it needs no cached forward values, only `theta`/`period`.
#[cfg(not(no_cuda))]

View File

@@ -159,3 +159,41 @@ fn decode_attention_matches_full_attention_last_row() {
);
println!("decode_attention OK: matches full causal last row (bh={bh}, t={t}, max|Δ|={max_abs:.2e})");
}
/// (e) `rope_pos` (per-row positions, M2b batched decode): with positions
/// [0,1,…,n-1] it is bit-identical to the full-sequence `rope` (period=n); with a
/// uniform position P every row matches `rope_at(·, P)` of that single row. This is
/// the primitive the batched decode uses (G rows sharing one decode position).
#[test]
fn rope_pos_matches_rope_and_rope_at() {
assert!(device::device_count().expect("device count") > 0, "no CUDA device");
device::set_device(0).unwrap();
let (n, heads, hd) = (7usize, 3usize, 8usize);
let theta = 10000.0f32;
let host: Vec<f32> = (0..n * heads * hd).map(|i| ((i * 37 % 101) as f32 / 50.0) - 1.0).collect();
let x = Tensor::from_slice(&host, &[n, heads, hd]).to_device(Device::Cuda(0));
// positions [0,1,…,n-1] ⇒ identical to the full-sequence rope.
let seq_pos: Vec<i32> = (0..n as i32).collect();
let pos_t = Tensor::from_slice(&seq_pos, &[n]).to_device(Device::Cuda(0));
let got = x.rope_pos(&pos_t, theta).to_device(Device::Cpu).as_slice::<f32>().to_vec();
let want = x.rope(theta, n).to_device(Device::Cpu).as_slice::<f32>().to_vec();
assert_eq!(got, want, "rope_pos [0..n] != full rope");
// uniform position P ⇒ each row matches rope_at(single row, P).
let p = 5i32;
let uni = Tensor::from_slice(&vec![p; n], &[n]).to_device(Device::Cuda(0));
let got_u = x.rope_pos(&uni, theta).to_device(Device::Cpu).as_slice::<f32>().to_vec();
let row_len = heads * hd;
for t in 0..n {
let row = &host[t * row_len..(t + 1) * row_len];
let want_row = Tensor::from_slice(row, &[1, heads, hd])
.to_device(Device::Cuda(0))
.rope_at(theta, p as usize)
.to_device(Device::Cpu)
.as_slice::<f32>()
.to_vec();
assert_eq!(&got_u[t * row_len..(t + 1) * row_len], want_row.as_slice(), "uniform pos row {t}");
}
println!("rope_pos OK: == full rope for [0..n] and == rope_at(P) per row for uniform P");
}

View File

@@ -0,0 +1,83 @@
// M2b batched KV-cache decode — the token-identical gate.
//
// Batched decode rolls out G samples of one prompt in lockstep (one common decode
// position each step, uniform RoPE via rope_pos, KV cache carrying a G dimension).
// Under GREEDY decoding all G rows are deterministic and must each equal the
// single-sequence greedy decode (generate_greedy_cached, itself gated token-
// identical to the naive sampler). This pins that the G-way batching indexes each
// sequence's K/V correctly (no cross-row contamination) and reproduces M2a exactly.
#![cfg(not(no_cuda))]
use xtrain_cuda::device;
use xtrain_model::{generate_cached_batch, generate_greedy_cached, Config, TinyTransformer};
use xtrain_tensor::{DType, Device};
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
let mut state = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
})
.collect()
}
fn build(cfg: Config, device: Device) -> TinyTransformer {
let mut seed = 1u64;
TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
} else {
fill(n, seed, 0.08)
}
})
.with_compute_dtype(DType::F32)
}
#[test]
fn batched_greedy_decode_matches_single_seq() {
assert!(
device::device_count().expect("device count") > 0,
"no CUDA device"
);
device::set_device(0).unwrap();
let device = Device::Cuda(0);
// Real GQA (8 query / 2 kv heads → group 4) so repeat_kv(nh, batch=G) is exercised.
let cfg = Config::from_arch(48, 8, 16, 4, 256).with_kv_heads(2);
let model = build(cfg, device);
let prompt: Vec<i32> = vec![3, 9, 1, 14, 5];
let max_new = 24usize;
let g = 5usize;
let single = generate_greedy_cached(&model, device, &prompt, max_new);
let mut rng = 0u64;
let batched = generate_cached_batch(&model, device, &prompt, g, max_new, 0.0, &mut rng);
assert_eq!(batched.len(), g, "expected {g} sample rows");
for (row, seq) in batched.iter().enumerate() {
assert_eq!(
seq.len(),
single.len(),
"row {row} length {} vs single {}",
seq.len(),
single.len()
);
if seq != &single {
let first = seq.iter().zip(&single).position(|(a, b)| a != b).unwrap();
panic!(
"batched row {row} diverges from single-seq at index {first}: {:?} vs {:?}",
seq[first], single[first]
);
}
}
println!(
"batched decode OK: all {g} greedy rows token-identical to single-seq over {max_new} tokens"
);
}