Files
xtrain/crates/xtrain-train/tests/decode_batch.rs
Gahow Wang 2c9b58cb3b post-train: M2b — batched KV-cache decode (G-way, token-identical)
The rollout long-pole fix deferred from M2a: decode the G samples of one prompt
in lockstep (one forward per step over the group → G× fewer kernel launches).

- rope_pos(x, positions[]): RoPE with a per-row absolute position (new forward-
  only kernel) — G rows share one decode position. Gate: == full rope for
  [0..n], == rope_at(P) per row for uniform P (bit-identical).
- generate_cached_batch: BatchKVCache [T, G·num_kv, hd] + batched decode_step.
  decode_attention is already batch-agnostic (bh = G·nh); repeat_kv(nh, batch=G)
  broadcasts per group. No finished-mask / ragged prompts yet (perf-only / next).
- Gate (tests/decode_batch.rs): all G greedy rows token-identical to the single-
  sequence decode (8 query / 2 kv heads → exercises repeat_kv batching).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 17:18:54 +08:00

84 lines
2.9 KiB
Rust

// M2b batched KV-cache decode — the token-identical gate.
//
// Batched decode rolls out G samples of one prompt in lockstep (one common decode
// position each step, uniform RoPE via rope_pos, KV cache carrying a G dimension).
// Under GREEDY decoding all G rows are deterministic and must each equal the
// single-sequence greedy decode (generate_greedy_cached, itself gated token-
// identical to the naive sampler). This pins that the G-way batching indexes each
// sequence's K/V correctly (no cross-row contamination) and reproduces M2a exactly.
#![cfg(not(no_cuda))]
use xtrain_cuda::device;
use xtrain_model::{generate_cached_batch, generate_greedy_cached, Config, TinyTransformer};
use xtrain_tensor::{DType, Device};
/// Produces `n` deterministic pseudo-random `f32` values for test weights.
///
/// The generator is a 64-bit linear congruential generator: `seed` is first
/// scrambled by one multiply-add, then every output advances the state by a
/// second multiply-add (all arithmetic wrapping). The upper 31 bits of the
/// state are mapped onto the unit interval, recentred around zero and
/// stretched, so each value lies roughly within `[-scale, scale]`.
///
/// The same `(n, seed, scale)` always yields the same vector, and a shorter
/// request is a prefix of a longer one with the same seed.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // Constants for the one-off seed scramble.
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    // Constants for the per-element state update.
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;

    let mut lcg = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
        // Keep only the high 31 bits (the low bits of an LCG are the weakest)
        // and normalise them by 2^31.
        let unit = (lcg >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
/// Constructs a `TinyTransformer` on `device` with deterministic weights.
///
/// The initialiser closure is invoked with a parameter shape and bumps a
/// running seed on every call, so each request gets its own `fill` stream:
/// - rank-1 shapes receive small noise (scale 0.02) shifted up by 1.0
///   (presumably norm gains/biases — confirm against the model definition);
/// - every other shape receives zero-centred noise with scale 0.08.
///
/// The resulting model is switched to `DType::F32` compute.
fn build(cfg: Config, device: Device) -> TinyTransformer {
    let mut seed = 1u64;
    let model = TinyTransformer::new(cfg, device, |shape| {
        // Advance first, so the first shape requested is filled from seed 2.
        seed = seed.wrapping_add(1);
        let numel = shape.iter().product::<usize>();
        match shape.len() {
            1 => fill(numel, seed, 0.02)
                .into_iter()
                .map(|w| w + 1.0)
                .collect(),
            _ => fill(numel, seed, 0.08),
        }
    });
    model.with_compute_dtype(DType::F32)
}
/// Gate: every row of a G-way batched greedy decode must be token-identical
/// to the single-sequence cached greedy decode of the same prompt.
///
/// Fails if no CUDA device is reported, if the batch does not return `g`
/// rows, if any row's length differs from the reference, or at the first
/// token index where a row diverges from the reference.
#[test]
fn batched_greedy_decode_matches_single_seq() {
    let n_devices = device::device_count().expect("device count");
    assert!(n_devices > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    // Real GQA (8 query / 2 kv heads → group 4) so repeat_kv(nh, batch=G) is exercised.
    let cfg = Config::from_arch(48, 8, 16, 4, 256).with_kv_heads(2);
    let model = build(cfg, device);

    let prompt = vec![3i32, 9, 1, 14, 5];
    let max_new: usize = 24;
    let g: usize = 5;

    // Reference: the single-sequence cached greedy decode.
    let reference = generate_greedy_cached(&model, device, &prompt, max_new);

    // Batched decode, called with 0.0 as the sixth argument and a zero-valued
    // rng state; the rows are expected to be deterministic under these inputs.
    let mut rng = 0u64;
    let rows = generate_cached_batch(&model, device, &prompt, g, max_new, 0.0, &mut rng);
    assert_eq!(rows.len(), g, "expected {g} sample rows");

    for (row, tokens) in rows.iter().enumerate() {
        assert_eq!(
            tokens.len(),
            reference.len(),
            "row {row} length {} vs single {}",
            tokens.len(),
            reference.len()
        );
        // Lengths are equal at this point, so a row differs from the reference
        // exactly when some zipped pair of tokens differs.
        if let Some(first) = tokens.iter().zip(&reference).position(|(a, b)| a != b) {
            panic!(
                "batched row {row} diverges from single-seq at index {first}: {:?} vs {:?}",
                tokens[first], reference[first]
            );
        }
    }

    println!(
        "batched decode OK: all {g} greedy rows token-identical to single-seq over {max_new} tokens"
    );
}