The rollout long-pole fix deferred from M2a: decode the G samples of one prompt in lockstep (one forward per step over the group → G× fewer kernel launches). - rope_pos(x, positions[]): RoPE with a per-row absolute position (new forward-only kernel) — G rows share one decode position. Gate: == full rope for [0..n], == rope_at(P) per row for uniform P (bit-identical). - generate_cached_batch: BatchKVCache [T, G·num_kv, hd] + batched decode_step. decode_attention is already batch-agnostic (bh = G·nh); repeat_kv(nh, batch=G) broadcasts per group. No finished-mask / ragged prompts yet (perf-only / next). - Gate (tests/decode_batch.rs): all G greedy rows token-identical to the single-sequence decode (8 query / 2 kv heads → exercises repeat_kv batching). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
84 lines
2.9 KiB
Rust
84 lines
2.9 KiB
Rust
// M2b batched KV-cache decode — the token-identical gate.
//
// Batched decode rolls out G samples of one prompt in lockstep (one common decode
// position each step, uniform RoPE via rope_pos, KV cache carrying a G dimension).
// Under GREEDY decoding all G rows are deterministic and must each equal the
// single-sequence greedy decode (generate_greedy_cached, itself gated as
// token-identical to the naive sampler). This pins that the G-way batching keeps the
// batched shapes and repeat_kv broadcasting consistent and reproduces M2a exactly.
//
// Caveat: every row starts from the same prompt and decodes greedily, so all rows hold
// identical K/V. A kernel that read another row's cache would therefore produce the same
// tokens and still pass; detecting cross-row contamination needs rows with distinct
// contents (e.g. ragged prompts or sampled decoding).
#![cfg(not(no_cuda))]
|
|
|
|
use xtrain_cuda::device;
|
|
use xtrain_model::{generate_cached_batch, generate_greedy_cached, Config, TinyTransformer};
|
|
use xtrain_tensor::{DType, Device};
|
|
|
|
/// Deterministic pseudo-random fill: `n` values spread over roughly `[-scale, scale]`.
///
/// The seed is scrambled once, then a 64-bit LCG (the MMIX multiplier/increment pair)
/// is stepped once per value. The top 31 bits of each state are mapped to `[0, 1]` and
/// then to `[-scale, scale]`. The same `(n, seed, scale)` always yields the same vector,
/// so the test model is reproducible without an RNG crate.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_ADD: u64 = 1442695040888963407;
    // 2^31: the span of the 31-bit sample taken from `lcg >> 33`.
    let denom = (1u64 << 31) as f32;

    let mut lcg = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(LCG_MUL).wrapping_add(LCG_ADD);
        let unit = (lcg >> 33) as f32 / denom;
        // Centre on zero, then stretch to the requested half-width.
        out.push((unit - 0.5) * 2.0 * scale);
    }
    out
}
|
|
|
|
fn build(cfg: Config, device: Device) -> TinyTransformer {
|
|
let mut seed = 1u64;
|
|
TinyTransformer::new(cfg, device, |shape| {
|
|
seed = seed.wrapping_add(1);
|
|
let n: usize = shape.iter().product();
|
|
if shape.len() == 1 {
|
|
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
|
} else {
|
|
fill(n, seed, 0.08)
|
|
}
|
|
})
|
|
.with_compute_dtype(DType::F32)
|
|
}
|
|
|
|
/// M2b gate: every row of a greedy batched rollout must equal the single-sequence
/// cached greedy decode, token for token.
///
/// The config is real GQA (8 query / 2 KV heads), so `repeat_kv` runs with `batch = G`.
///
/// NOTE(review): all G rows share one prompt and decode greedily, so every row carries
/// identical K/V. A kernel that read another row's cache would still pass. This gate
/// therefore pins batched shapes/broadcasting and parity with M2a, not row isolation.
/// Confirm whether a ragged or sampled follow-up test is planned to cover that.
#[test]
fn batched_greedy_decode_matches_single_seq() {
    let gpus = device::device_count().expect("device count");
    assert!(gpus > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    // Real GQA (8 query / 2 kv heads → group 4) so repeat_kv(nh, batch=G) is exercised.
    let model = build(Config::from_arch(48, 8, 16, 4, 256).with_kv_heads(2), device);
    let prompt: Vec<i32> = vec![3, 9, 1, 14, 5];
    let (max_new, g) = (24usize, 5usize);

    // Reference: the single-sequence cached greedy decode.
    let single = generate_greedy_cached(&model, device, &prompt, max_new);
    // 0.0 is presumably the sampling temperature (→ greedy), so the RNG state should not
    // influence the result — TODO confirm against generate_cached_batch.
    let mut rng = 0u64;
    let batched = generate_cached_batch(&model, device, &prompt, g, max_new, 0.0, &mut rng);

    assert_eq!(batched.len(), g, "expected {g} sample rows");
    for (row, seq) in batched.iter().enumerate() {
        assert_eq!(
            seq.len(),
            single.len(),
            "row {row} length {} vs single {}",
            seq.len(),
            single.len()
        );
        // Lengths are equal here, so the rows differ exactly when some position differs.
        if let Some(first) = seq.iter().zip(&single).position(|(a, b)| a != b) {
            panic!(
                "batched row {row} diverges from single-seq at index {first}: {:?} vs {:?}",
                seq[first], single[first]
            );
        }
    }
    println!(
        "batched decode OK: all {g} greedy rows token-identical to single-seq over {max_new} tokens"
    );
}
|