Compare commits
3 Commits
096e45b845
...
0f76c0fdb0
| Author | SHA1 | Date | |
|---|---|---|---|
| 0f76c0fdb0 | |||
| 361c5290fa | |||
| 2c9b58cb3b |
@@ -152,6 +152,18 @@ unsafe extern "C" {
|
|||||||
pos0: i32,
|
pos0: i32,
|
||||||
s: CudaStream,
|
s: CudaStream,
|
||||||
);
|
);
|
||||||
|
// RoPE with a per-row absolute position (batched KV-cache decode, M2b): row
|
||||||
|
// `tok`'s position is `positions[tok]`. Forward only.
|
||||||
|
pub fn launch_rope_pos_f32(
|
||||||
|
x: *const f32,
|
||||||
|
positions: *const i32,
|
||||||
|
y: *mut f32,
|
||||||
|
tokens: i32,
|
||||||
|
heads: i32,
|
||||||
|
head_dim: i32,
|
||||||
|
theta: f32,
|
||||||
|
s: CudaStream,
|
||||||
|
);
|
||||||
// Per-row scale: y[r,c] = x[r,c] * s[r] (GRPO policy-gradient backward).
|
// Per-row scale: y[r,c] = x[r,c] * s[r] (GRPO policy-gradient backward).
|
||||||
pub fn launch_scale_rows_f32(
|
pub fn launch_scale_rows_f32(
|
||||||
x: *const f32,
|
x: *const f32,
|
||||||
|
|||||||
@@ -265,3 +265,177 @@ fn argmax(row: &[f32]) -> usize {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
.0
|
.0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ===================================================================
|
||||||
|
// M2b — batched KV-cache decode (G samples of one prompt, in lockstep)
|
||||||
|
// ===================================================================
|
||||||
|
|
||||||
|
/// Batched K/V cache: `G` sequences advancing together. Per layer, host-accumulates
/// seq-major `[T, G·num_kv, head_dim]` (one step appends `G·num_kv·hd` f32), rebuilt
/// to `[G·num_kv, T, hd]` per step. Same host-cache shape as M2a with a G dimension.
///
/// `k[li]` / `v[li]` are the flat, append-only key / value buffers of layer `li`.
#[derive(Debug)]
struct BatchKVCache {
    k: Vec<Vec<f32>>,
    v: Vec<Vec<f32>>,
}

impl BatchKVCache {
    /// Creates an empty cache with one (initially empty) K and V buffer per layer.
    fn new(n_layers: usize) -> Self {
        Self {
            k: vec![Vec::new(); n_layers],
            v: vec![Vec::new(); n_layers],
        }
    }

    /// Appends one decode step's keys and values to layer `li`.
    ///
    /// # Panics
    /// Panics if `li` is not a valid layer index. In debug builds it also panics
    /// when `k_tok` and `v_tok` differ in length, since K and V must stay in
    /// lockstep for the per-step rebuild to see the same `T` for both.
    fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) {
        debug_assert_eq!(
            k_tok.len(),
            v_tok.len(),
            "K and V step slices must have equal length"
        );
        self.k[li].extend_from_slice(k_tok);
        self.v[li].extend_from_slice(v_tok);
    }
}
|
||||||
|
|
||||||
|
/// Batched KV-cache decode: roll out `n_samples` (G) completions of the SAME
|
||||||
|
/// `prompt` in lockstep — all G share the prompt, so they advance at one common
|
||||||
|
/// decode position each step (uniform RoPE via `rope_pos`). Returns G full token
|
||||||
|
/// sequences (prompt + sampled continuation). The G-way batching amortises the
|
||||||
|
/// per-step kernel launches across G (the rollout long-pole). Token-identical per
|
||||||
|
/// row to G independent single-sequence decodes (gated by `tests/decode_batch.rs`).
|
||||||
|
///
|
||||||
|
/// `temperature == 0` ⇒ greedy (all G identical); `> 0` ⇒ independent samples
|
||||||
|
/// (per-row draw from one shared `rng_state`). No finished-mask: all G generate
|
||||||
|
/// `max_new` tokens; the caller cuts each at `<|endoftext|>` (a perf-only early
|
||||||
|
/// stop is the M2b+ follow-up). Ragged (different-length prompts) is also deferred.
|
||||||
|
pub fn generate_cached_batch(
|
||||||
|
model: &TinyTransformer,
|
||||||
|
device: Device,
|
||||||
|
prompt: &[i32],
|
||||||
|
n_samples: usize,
|
||||||
|
max_new: usize,
|
||||||
|
temperature: f32,
|
||||||
|
rng_state: &mut u64,
|
||||||
|
) -> Vec<Vec<i32>> {
|
||||||
|
assert!(!prompt.is_empty(), "prompt must be non-empty");
|
||||||
|
assert!(n_samples > 0, "n_samples must be > 0");
|
||||||
|
let cfg = model.config();
|
||||||
|
let cdt = model.compute_dtype();
|
||||||
|
let n_layers = cfg.n_layers;
|
||||||
|
let params: Vec<Tensor> = model.params().iter().map(|p| p.value()).collect();
|
||||||
|
let embed = ¶ms[0];
|
||||||
|
let final_norm = ¶ms[1 + n_layers * 11];
|
||||||
|
let lm_head = ¶ms[1 + n_layers * 11 + 1];
|
||||||
|
|
||||||
|
let g = n_samples;
|
||||||
|
let mut cache = BatchKVCache::new(n_layers);
|
||||||
|
let mut seqs: Vec<Vec<i32>> = vec![prompt.to_vec(); g];
|
||||||
|
|
||||||
|
// Prefill: feed each prompt token (identical across G) at its position.
|
||||||
|
let mut logits = Vec::new(); // [G, vocab] flattened
|
||||||
|
for (pos, &tok) in prompt.iter().enumerate() {
|
||||||
|
let toks = vec![tok; g];
|
||||||
|
logits = decode_step_batch(¶ms, cfg, cdt, device, &mut cache, &toks, pos, embed, final_norm, lm_head);
|
||||||
|
}
|
||||||
|
|
||||||
|
let vocab = cfg.vocab;
|
||||||
|
for _ in 0..max_new {
|
||||||
|
let mut next = Vec::with_capacity(g);
|
||||||
|
for row in 0..g {
|
||||||
|
let lg = &logits[row * vocab..(row + 1) * vocab];
|
||||||
|
let t = if temperature <= 0.0 {
|
||||||
|
argmax(lg) as i32
|
||||||
|
} else {
|
||||||
|
sample_temperature(lg, temperature, rng_state) as i32
|
||||||
|
};
|
||||||
|
next.push(t);
|
||||||
|
seqs[row].push(t);
|
||||||
|
}
|
||||||
|
let pos = seqs[0].len() - 1; // all G are at the same position
|
||||||
|
logits = decode_step_batch(¶ms, cfg, cdt, device, &mut cache, &next, pos, embed, final_norm, lm_head);
|
||||||
|
}
|
||||||
|
seqs
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One batched decode step: `toks` is one current token per sequence (`[G]`), all at
|
||||||
|
/// absolute position `pos`. Appends each sequence's K/V and returns logits `[G·vocab]`.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn decode_step_batch(
|
||||||
|
params: &[Tensor],
|
||||||
|
cfg: &crate::Config,
|
||||||
|
cdt: DType,
|
||||||
|
device: Device,
|
||||||
|
cache: &mut BatchKVCache,
|
||||||
|
toks: &[i32],
|
||||||
|
pos: usize,
|
||||||
|
embed: &Tensor,
|
||||||
|
final_norm: &Tensor,
|
||||||
|
lm_head: &Tensor,
|
||||||
|
) -> Vec<f32> {
|
||||||
|
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
|
||||||
|
let dim = cfg.dim;
|
||||||
|
let g = toks.len();
|
||||||
|
let g_kv = g * num_kv; // batch·num_kv heads in the cache
|
||||||
|
let scale = 1.0 / (hd as f32).sqrt();
|
||||||
|
let (theta, eps) = (cfg.rope_theta, cfg.eps);
|
||||||
|
let n_layers = cfg.n_layers;
|
||||||
|
// Uniform per-row position (all G at the same decode step).
|
||||||
|
let positions = Tensor::from_slice(&vec![pos as i32; g], &[g]).to_device(device);
|
||||||
|
|
||||||
|
let ids = Tensor::from_slice(toks, &[g]).to_device(device);
|
||||||
|
let mut h = embed.embedding(&ids); // [G, dim] f32
|
||||||
|
if cdt == DType::BF16 {
|
||||||
|
h = h.to_dtype(DType::BF16);
|
||||||
|
}
|
||||||
|
|
||||||
|
for li in 0..n_layers {
|
||||||
|
let base = 1 + li * 11;
|
||||||
|
let (attn_norm, wq, wk, wv) =
|
||||||
|
(¶ms[base], ¶ms[base + 1], ¶ms[base + 2], ¶ms[base + 3]);
|
||||||
|
let (q_norm, k_norm, wo) = (¶ms[base + 4], ¶ms[base + 5], ¶ms[base + 6]);
|
||||||
|
let (ffn_norm, w_gate, w_up, w_down) =
|
||||||
|
(¶ms[base + 7], ¶ms[base + 8], ¶ms[base + 9], ¶ms[base + 10]);
|
||||||
|
|
||||||
|
let normed = h.rms_norm(&gamma_t(cdt, attn_norm), eps).0; // [G, dim]
|
||||||
|
|
||||||
|
// Q: project → per-head QK-norm → RoPE at `pos` for every row.
|
||||||
|
let q = linear_t(cdt, &normed, wq).reshape(&[g, nh, hd]);
|
||||||
|
let q = q.reshape(&[g * nh, hd]).rms_norm(&gamma_t(cdt, q_norm), eps).0;
|
||||||
|
let q = q.reshape(&[g, nh, hd]).rope_pos(&positions, theta);
|
||||||
|
let q_bh = q.reshape(&[g * nh, 1, hd]); // bh = G·nh
|
||||||
|
|
||||||
|
let k = linear_t(cdt, &normed, wk).reshape(&[g, num_kv, hd]);
|
||||||
|
let k = k.reshape(&[g * num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
|
||||||
|
let k_tok = k.reshape(&[g, num_kv, hd]).rope_pos(&positions, theta);
|
||||||
|
let v_tok = linear_t(cdt, &normed, wv).reshape(&[g, num_kv, hd]);
|
||||||
|
|
||||||
|
let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||||
|
let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||||
|
cache.append(li, &k_host, &v_host);
|
||||||
|
|
||||||
|
// Rebuild [T, G·num_kv, hd] → [G·num_kv, T, hd] → repeat_kv to [G·nh, T, hd].
|
||||||
|
let t_len = cache.k[li].len() / (g_kv * hd);
|
||||||
|
let build = |flat: &[f32]| -> Tensor {
|
||||||
|
let bh_kv = Tensor::from_slice(flat, &[t_len, g_kv, hd])
|
||||||
|
.to_device(device)
|
||||||
|
.transpose_3d01(); // [G·num_kv, T, hd], f32
|
||||||
|
let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv };
|
||||||
|
if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, g) } // [G·nh, T, hd]
|
||||||
|
};
|
||||||
|
let k_full = build(&cache.k[li]);
|
||||||
|
let v_full = build(&cache.v[li]);
|
||||||
|
|
||||||
|
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [G·nh, hd]
|
||||||
|
let attn = attn.reshape(&[g, dim]); // concat heads per sequence
|
||||||
|
let attn_out = linear_t(cdt, &attn, wo);
|
||||||
|
h = h.add(&attn_out);
|
||||||
|
|
||||||
|
let normed = h.rms_norm(&gamma_t(cdt, ffn_norm), eps).0;
|
||||||
|
let gate = linear_t(cdt, &normed, w_gate);
|
||||||
|
let up = linear_t(cdt, &normed, w_up);
|
||||||
|
let act = gate.silu().mul(&up);
|
||||||
|
let down = linear_t(cdt, &act, w_down);
|
||||||
|
h = h.add(&down);
|
||||||
|
}
|
||||||
|
|
||||||
|
let h = h.rms_norm(&gamma_t(cdt, final_norm), eps).0;
|
||||||
|
linear_t(cdt, &h, lm_head)
|
||||||
|
.to_dtype(DType::F32)
|
||||||
|
.to_device(Device::Cpu)
|
||||||
|
.as_slice::<f32>()
|
||||||
|
.to_vec()
|
||||||
|
}
|
||||||
|
|||||||
@@ -29,4 +29,4 @@ pub use model::{TinyTransformer, batched_ids_tensor, ids_tensor, param_to_host};
|
|||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
pub mod decode;
|
pub mod decode;
|
||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
pub use decode::{generate_cached, generate_greedy_cached};
|
pub use decode::{generate_cached, generate_cached_batch, generate_greedy_cached};
|
||||||
|
|||||||
@@ -822,6 +822,40 @@ impl Tensor {
|
|||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// RoPE with a PER-ROW absolute position (batched KV-cache decode, M2b).
|
||||||
|
/// `self`:[tokens,heads,head_dim]; row `t`'s position is `positions[t]` (an
|
||||||
|
/// I32 `[tokens]` tensor). For G-way batched decode all G rows share one decode
|
||||||
|
/// position; for ragged batches each row carries its own. Mirrors `rope_at`'s
|
||||||
|
/// dtype handling; forward only.
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
pub fn rope_pos(&self, positions: &Tensor, theta: f32) -> Self {
|
||||||
|
assert_eq!(self.ndim(), 3, "rope_pos requires [tokens,heads,head_dim]");
|
||||||
|
let (tokens, heads, head_dim) = (self.shape[0], self.shape[1], self.shape[2]);
|
||||||
|
assert_eq!(head_dim % 2, 0, "head_dim must be even");
|
||||||
|
assert_eq!(positions.dtype, DType::I32, "positions must be I32");
|
||||||
|
assert_eq!(positions.numel(), tokens, "one position per token");
|
||||||
|
if self.dtype == DType::BF16 {
|
||||||
|
return self
|
||||||
|
.to_dtype(DType::F32)
|
||||||
|
.rope_pos(positions, theta)
|
||||||
|
.to_dtype(DType::BF16);
|
||||||
|
}
|
||||||
|
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||||
|
unsafe {
|
||||||
|
xtrain_cuda::ffi::launch_rope_pos_f32(
|
||||||
|
self.data_ptr() as *const f32,
|
||||||
|
positions.data_ptr() as *const i32,
|
||||||
|
out.data_ptr() as *mut f32,
|
||||||
|
tokens as i32,
|
||||||
|
heads as i32,
|
||||||
|
head_dim as i32,
|
||||||
|
theta,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
/// RoPE backward: apply the inverse (transpose) rotation to `dy`. RoPE is an
|
/// RoPE backward: apply the inverse (transpose) rotation to `dy`. RoPE is an
|
||||||
/// orthogonal map, so it needs no cached forward values, only `theta`/`period`.
|
/// orthogonal map, so it needs no cached forward values, only `theta`/`period`.
|
||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
|
|||||||
@@ -159,3 +159,41 @@ fn decode_attention_matches_full_attention_last_row() {
|
|||||||
);
|
);
|
||||||
println!("decode_attention OK: matches full causal last row (bh={bh}, t={t}, max|Δ|={max_abs:.2e})");
|
println!("decode_attention OK: matches full causal last row (bh={bh}, t={t}, max|Δ|={max_abs:.2e})");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// (e) `rope_pos` (per-row positions, M2b batched decode): with positions
/// [0,1,…,n-1] it is bit-identical to the full-sequence `rope` (period=n); with a
/// uniform position P every row matches `rope_at(·, P)` of that single row. This is
/// the primitive the batched decode uses (G rows sharing one decode position).
///
/// Requires a CUDA device: the test fails (it does not skip) when none is present.
#[test]
fn rope_pos_matches_rope_and_rope_at() {
    assert!(device::device_count().expect("device count") > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let (n, heads, hd) = (7usize, 3usize, 8usize);
    let theta = 10000.0f32;
    // Deterministic, non-trivial input in roughly [-1, 1].
    let host: Vec<f32> = (0..n * heads * hd)
        .map(|i| ((i * 37 % 101) as f32 / 50.0) - 1.0)
        .collect();
    let x = Tensor::from_slice(&host, &[n, heads, hd]).to_device(Device::Cuda(0));

    // positions [0,1,…,n-1] ⇒ identical to the full-sequence rope.
    let seq_pos: Vec<i32> = (0..n as i32).collect();
    let pos_t = Tensor::from_slice(&seq_pos, &[n]).to_device(Device::Cuda(0));
    let got = x.rope_pos(&pos_t, theta).to_device(Device::Cpu).as_slice::<f32>().to_vec();
    let want = x.rope(theta, n).to_device(Device::Cpu).as_slice::<f32>().to_vec();
    assert_eq!(got, want, "rope_pos [0..n] != full rope");

    // uniform position P ⇒ each row matches rope_at(single row, P).
    let p = 5i32;
    let uni = Tensor::from_slice(&vec![p; n], &[n]).to_device(Device::Cuda(0));
    let got_u = x.rope_pos(&uni, theta).to_device(Device::Cpu).as_slice::<f32>().to_vec();
    let row_len = heads * hd;
    for t in 0..n {
        let row = &host[t * row_len..(t + 1) * row_len];
        let want_row = Tensor::from_slice(row, &[1, heads, hd])
            .to_device(Device::Cuda(0))
            .rope_at(theta, p as usize)
            .to_device(Device::Cpu)
            .as_slice::<f32>()
            .to_vec();
        assert_eq!(&got_u[t * row_len..(t + 1) * row_len], want_row.as_slice(), "uniform pos row {t}");
    }
    println!("rope_pos OK: == full rope for [0..n] and == rope_at(P) per row for uniform P");
}
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ use xtrain_autodiff::ops;
|
|||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
use xtrain_cuda::device;
|
use xtrain_cuda::device;
|
||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
use xtrain_model::{Config, TinyTransformer, generate_cached, ids_tensor};
|
use xtrain_model::{Config, TinyTransformer, generate_cached_batch, ids_tensor};
|
||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
use xtrain_tensor::{DType, Device};
|
use xtrain_tensor::{DType, Device};
|
||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
@@ -205,12 +205,12 @@ fn main() {
|
|||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|t| t as i32)
|
.map(|t| t as i32)
|
||||||
.collect();
|
.collect();
|
||||||
|
// M2b batched rollout: the G samples of this prompt decode in lockstep
|
||||||
|
// (one forward per step over the whole group → G× fewer kernel launches
|
||||||
|
// than G sequential single-seq rollouts; the M4 rollout long-pole fix).
|
||||||
let mut comps: Vec<(String, f32)> = Vec::with_capacity(group);
|
let mut comps: Vec<(String, f32)> = Vec::with_capacity(group);
|
||||||
for _ in 0..group {
|
let outs = generate_cached_batch(&policy, device, &prompt_ids, group, max_new, temp, &mut rng);
|
||||||
// KV-cache temperature rollout (M2 engine): single-row logits per
|
for out in &outs {
|
||||||
// step → far lighter on the allocator than the naive sampler, which
|
|
||||||
// fragments it over a long rollout (the M4 rollout long-pole).
|
|
||||||
let out = generate_cached(&policy, device, &prompt_ids, max_new, temp, &mut rng);
|
|
||||||
let cont = tok.decode(&out[prompt_ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
|
let cont = tok.decode(&out[prompt_ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
|
||||||
let seg = first_answer_segment(&cont).trim().to_string();
|
let seg = first_answer_segment(&cont).trim().to_string();
|
||||||
let r = if check_answer(&seg, p.answer()) { 1.0 } else { 0.0 };
|
let r = if check_answer(&seg, p.answer()) { 1.0 } else { 0.0 };
|
||||||
|
|||||||
83
crates/xtrain-train/tests/decode_batch.rs
Normal file
83
crates/xtrain-train/tests/decode_batch.rs
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
// M2b batched KV-cache decode — the token-identical gate.
|
||||||
|
//
|
||||||
|
// Batched decode rolls out G samples of one prompt in lockstep (one common decode
|
||||||
|
// position each step, uniform RoPE via rope_pos, KV cache carrying a G dimension).
|
||||||
|
// Under GREEDY decoding all G rows are deterministic and must each equal the
|
||||||
|
// single-sequence greedy decode (generate_greedy_cached, itself gated token-
|
||||||
|
// identical to the naive sampler). This pins that the G-way batching indexes each
|
||||||
|
// sequence's K/V correctly (no cross-row contamination) and reproduces M2a exactly.
|
||||||
|
#![cfg(not(no_cuda))]
|
||||||
|
|
||||||
|
use xtrain_cuda::device;
|
||||||
|
use xtrain_model::{generate_cached_batch, generate_greedy_cached, Config, TinyTransformer};
|
||||||
|
use xtrain_tensor::{DType, Device};
|
||||||
|
|
||||||
|
/// Deterministic pseudo-random fill: returns `n` values in `[-scale, scale]`.
///
/// The generator state is derived from `seed` by one multiply-add, then advanced
/// by a wrapping 64-bit multiply-add per element; the top 31 bits of the state
/// are mapped to `[0, 1]`, centred, and scaled. The same `(seed, scale)` always
/// yields the same sequence, and a shorter request is a prefix of a longer one.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    let mut state = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    (0..n)
        .map(|_| {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
        })
        .collect()
}
|
||||||
|
|
||||||
|
fn build(cfg: Config, device: Device) -> TinyTransformer {
|
||||||
|
let mut seed = 1u64;
|
||||||
|
TinyTransformer::new(cfg, device, |shape| {
|
||||||
|
seed = seed.wrapping_add(1);
|
||||||
|
let n: usize = shape.iter().product();
|
||||||
|
if shape.len() == 1 {
|
||||||
|
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||||
|
} else {
|
||||||
|
fill(n, seed, 0.08)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.with_compute_dtype(DType::F32)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Token-identical gate: under greedy decoding (temperature 0.0) every one of
/// the G rows returned by `generate_cached_batch` must equal the sequence
/// produced by `generate_greedy_cached` for the same prompt and `max_new`.
///
/// Requires a CUDA device: the test fails (it does not skip) when none is present.
#[test]
fn batched_greedy_decode_matches_single_seq() {
    assert!(
        device::device_count().expect("device count") > 0,
        "no CUDA device"
    );
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    // Real GQA (8 query / 2 kv heads → group 4) so repeat_kv(nh, batch=G) is exercised.
    let cfg = Config::from_arch(48, 8, 16, 4, 256).with_kv_heads(2);
    let model = build(cfg, device);
    let prompt: Vec<i32> = vec![3, 9, 1, 14, 5];
    let max_new = 24usize;
    let g = 5usize;

    let single = generate_greedy_cached(&model, device, &prompt, max_new);
    let mut rng = 0u64;
    let batched = generate_cached_batch(&model, device, &prompt, g, max_new, 0.0, &mut rng);

    assert_eq!(batched.len(), g, "expected {g} sample rows");
    for (row, seq) in batched.iter().enumerate() {
        assert_eq!(
            seq.len(),
            single.len(),
            "row {row} length {} vs single {}",
            seq.len(),
            single.len()
        );
        if seq != &single {
            // Lengths are equal and the sequences differ, so a first mismatch exists.
            let first = seq.iter().zip(&single).position(|(a, b)| a != b).unwrap();
            panic!(
                "batched row {row} diverges from single-seq at index {first}: {:?} vs {:?}",
                seq[first], single[first]
            );
        }
    }
    println!(
        "batched decode OK: all {g} greedy rows token-identical to single-seq over {max_new} tokens"
    );
}
|
||||||
@@ -269,6 +269,33 @@ void launch_rope_at_f32(const float* x, float* y, int tokens, int heads,
|
|||||||
rope_at_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0);
|
rope_at_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RoPE with a PER-ROW absolute position (batched KV-cache decode, M2b): row `tok`'s
// position is `positions[tok]` (an i32 per token). For G-way batched decode all G
// rows share one decode position; for ragged batches each row carries its own.
// Forward only; the training rope_k is untouched.
//
// Launch shape expected by the indexing below: grid = (tokens, heads), one thread
// per rotated pair (head_dim / 2 threads). Thread `i` rotates the pair
// (x[base + i], x[base + i + half]) by angle pos * theta^(-2i / head_dim).
__global__ void rope_pos_k(const float* x, const int* positions, float* y,
                           int heads, int head_dim, float theta) {
    int tok = blockIdx.x;
    int head = blockIdx.y;
    int half = head_dim / 2;
    int i = threadIdx.x;
    if (i >= half) return;
    int pos = positions[tok];
    float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
    float angle = (float)pos * freq;
    float c = cosf(angle), sn = sinf(angle);
    // Offset computed in size_t: tokens * heads * head_dim can exceed INT_MAX.
    size_t base = ((size_t)tok * (size_t)heads + (size_t)head) * (size_t)head_dim;
    float x0 = x[base + i], x1 = x[base + i + half];
    y[base + i] = x0 * c - x1 * sn;
    y[base + i + half] = x1 * c + x0 * sn;
}
|
||||||
|
// Host launcher for rope_pos_k: one block per (token, head), one thread per
// rotated pair (head_dim / 2 threads per block), enqueued on stream `s`.
//
// Empty work (no tokens, no heads, or head_dim < 2) returns without launching:
// a zero-sized grid or block is an invalid launch configuration.
// NOTE(review): head_dim / 2 above the device's per-block thread limit, or
// `heads` above the grid.y limit, would still fail to launch — confirm the
// supported head sizes at the call site.
void launch_rope_pos_f32(const float* x, const int* positions, float* y,
                         int tokens, int heads, int head_dim, float theta, void* s) {
    if (tokens <= 0 || heads <= 0 || head_dim < 2) return;
    dim3 grid(tokens, heads);
    int blk = head_dim / 2;
    rope_pos_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, positions, y, heads, head_dim, theta);
}
|
||||||
|
|
||||||
// Per-row scale: y[r,c] = x[r,c] * s[r]. One block per row. Used by the GRPO
|
// Per-row scale: y[r,c] = x[r,c] * s[r]. One block per row. Used by the GRPO
|
||||||
// (M4) policy-gradient backward, where each completion token's row of
|
// (M4) policy-gradient backward, where each completion token's row of
|
||||||
// (probs − onehot) is scaled by its own per-token coefficient.
|
// (probs − onehot) is scaled by its own per-token coefficient.
|
||||||
|
|||||||
@@ -522,3 +522,35 @@ leash wired, format held); the held-out flatness + the two memory/throughput wal
|
|||||||
reported findings. The honest end-state of the post-training arc: **a complete, correctness-gated
|
reported findings. The honest end-state of the post-training arc: **a complete, correctness-gated
|
||||||
SFT → KV-cache → DPO → GRPO stack** — the infrastructure learned in full, with measured, honest
|
SFT → KV-cache → DPO → GRPO stack** — the infrastructure learned in full, with measured, honest
|
||||||
limits on what alignment can do for a capability the base model lacks.
|
limits on what alignment can do for a capability the base model lacks.
|
||||||
|
|
||||||
|
### M2b — batched KV-cache decode (landed; completes the M2 engine, fixes the rollout long-pole)
|
||||||
|
|
||||||
|
Built after M4 (where the rollout long-pole bit hardest): decode the **G samples of one prompt in
|
||||||
|
lockstep** — one forward per step over the whole group → G× fewer kernel launches, the deferred
|
||||||
|
fix from M2a.
|
||||||
|
|
||||||
|
**One new primitive:** `rope_pos(x, positions[])` — RoPE with a *per-row* absolute position (new
|
||||||
|
forward-only kernel), since the G batched rows share one decode position (M2a's `rope_at` does
|
||||||
|
`pos0 + row`, wrong for a batch at a single position). **Gate:** bit-identical to the full rope
|
||||||
|
for positions `[0..n]`, and to `rope_at(P)` per row for a uniform `P`.
|
||||||
|
|
||||||
|
**Engine (`generate_cached_batch`):** `BatchKVCache` carries a G dimension (`[T, G·num_kv, hd]`
|
||||||
|
host-accumulated → `[G·num_kv, T, hd]`); the batched step (`decode_step_batch`) threads G through embed /
|
||||||
|
projections / QK-norm / `rope_pos` / cache. Two M2a pieces drop in unchanged: `decode_attention`
|
||||||
|
is already batch-agnostic (`bh = G·nh`), and `repeat_kv(nh, batch=G)` broadcasts per group. No
|
||||||
|
finished-mask (all G generate `max_new`; the caller cuts at EOS) and no ragged-length prompts yet
|
||||||
|
— both perf-only follow-ups.
|
||||||
|
|
||||||
|
**Gate (token-identical):** all G **greedy** rows are byte-identical to the single-sequence decode
|
||||||
|
(`tests/decode_batch.rs`, 8 query / 2 kv heads → exercises the `repeat_kv` batching) — pins that
|
||||||
|
G-way batching indexes each sequence's K/V with no cross-row contamination.
|
||||||
|
|
||||||
|
**Throughput (v12 1.05B, G=6·B=6, easy task, rollout wired into `train_grpo`):** ~8.5 s/step vs
|
||||||
|
~14–16 s/step for the single-seq cached rollout — **~1.7×**, rollout-inclusive. Short of the full
|
||||||
|
G× because (a) the per-token-logp forwards + the PG update also cost, and (b) the M2a per-layer
|
||||||
|
**host round-trip** is still there (now G× the data in one transfer, not removed). The full
|
||||||
|
device-side cache (no host round-trip) is the remaining decode-engine optimization. Batching also
|
||||||
|
**stabilises memory**: one batched forward per step vs G separate allocations that fragmented the
|
||||||
|
caching allocator (the M4 OOM). So M2b closes the decode-engine milestone (M2a single-seq + M2b
|
||||||
|
batched) and turns the rollout long-pole from "OOM/unbounded" into a bounded ~1.7× win — measured,
|
||||||
|
with the device-cache as the named next lever.
|
||||||
|
|||||||
@@ -103,6 +103,8 @@ Phase 1/2 把**预训练全栈**学完后,Phase 3 转向**后训练 infra**(
|
|||||||
|
|
||||||
**M4(GRPO,在线 critic-free RL,已落地 + 两道诚实系统墙 + 一致负结果)**:新算子 `clipped_pg_loss`(per-token ρ + clip + k3 KL,反向用新增 `scale_rows` per-row 缩放 kernel;grad-check active+A=0 路径 + 退化 ε→∞ vanilla/β=0 无KL)。环 `train_grpo`:采 B prompt × rollout G → checker reward 0/1 → group-relative advantage `(r−mean)/(std+ε)`(无 critic,全对/全错组跳过)→ 存 πθ_old/πref per-token → K 内层 clipped-PG。rollout 用 **M2 引擎 + 新加的 temperature 采样**(单行 logits 比 naive `[seq,vocab]` 轻)。**先把任务改简单**:v12 SFT 在硬/易题都 ~8-9%(只会格式不会算术)→ 在 easy(操作数≤20)上从 v12 base 重训 SFT → held-out **18.7%**;但 250/600 步同样 18.7% = 1B web-text 模型从 ~550 例**不泛化加减法、只记 train**。**两道系统墙(设计文档 Risks 预言)**:① 显存——KL-leash 要 policy+reference 两个 1B fp32-master+Adam≈21GB,加激活在 32GB 5090 上不稳定 OOM → 只能 `β=0`(去掉 reference)跑完;② rollout 长杆——naive 采样增长序列撑碎 allocator,cached 采样更轻但单序列慢仍主导墙钟(~16s/step)。**结果**(easy, β=0, G6·B6, 40步, lr5e-7;150 留出 vs SFT 18.7%):reward 噪声 ~0.58-0.81(被 train 重叠抬),**format 100/100 不崩**(温和 lr 下 β=0 也没崩),**held-out 20.0%**(+1.3pp,~3% 标准误内 = 统计持平)。**M3+M4 一致教训**:模型缺底层能力时,离线偏好(DPO)和在线 RL(GRPO)**都不抬 held-out**——各自在能触及的训练分布上优化目标(被记忆抬高),装不进可泛化算法;**RL 强化模型已会的,不教算术**。**后训练弧诚实终态 = 一套完整、闸门齐全的 SFT → KV-cache → DPO → GRPO 栈**,infra 学全,并测得对齐对"base 缺失能力"能做什么的诚实边界。
|
**M4(GRPO,在线 critic-free RL,已落地 + 两道诚实系统墙 + 一致负结果)**:新算子 `clipped_pg_loss`(per-token ρ + clip + k3 KL,反向用新增 `scale_rows` per-row 缩放 kernel;grad-check active+A=0 路径 + 退化 ε→∞ vanilla/β=0 无KL)。环 `train_grpo`:采 B prompt × rollout G → checker reward 0/1 → group-relative advantage `(r−mean)/(std+ε)`(无 critic,全对/全错组跳过)→ 存 πθ_old/πref per-token → K 内层 clipped-PG。rollout 用 **M2 引擎 + 新加的 temperature 采样**(单行 logits 比 naive `[seq,vocab]` 轻)。**先把任务改简单**:v12 SFT 在硬/易题都 ~8-9%(只会格式不会算术)→ 在 easy(操作数≤20)上从 v12 base 重训 SFT → held-out **18.7%**;但 250/600 步同样 18.7% = 1B web-text 模型从 ~550 例**不泛化加减法、只记 train**。**两道系统墙(设计文档 Risks 预言)**:① 显存——KL-leash 要 policy+reference 两个 1B fp32-master+Adam≈21GB,加激活在 32GB 5090 上不稳定 OOM → 只能 `β=0`(去掉 reference)跑完;② rollout 长杆——naive 采样增长序列撑碎 allocator,cached 采样更轻但单序列慢仍主导墙钟(~16s/step)。**结果**(easy, β=0, G6·B6, 40步, lr5e-7;150 留出 vs SFT 18.7%):reward 噪声 ~0.58-0.81(被 train 重叠抬),**format 100/100 不崩**(温和 lr 下 β=0 也没崩),**held-out 20.0%**(+1.3pp,~3% 标准误内 = 统计持平)。**M3+M4 一致教训**:模型缺底层能力时,离线偏好(DPO)和在线 RL(GRPO)**都不抬 held-out**——各自在能触及的训练分布上优化目标(被记忆抬高),装不进可泛化算法;**RL 强化模型已会的,不教算术**。**后训练弧诚实终态 = 一套完整、闸门齐全的 SFT → KV-cache → DPO → GRPO 栈**,infra 学全,并测得对齐对"base 缺失能力"能做什么的诚实边界。
|
||||||
|
|
||||||
|
**M2b(批量 KV-cache 解码,已落地,补全 M2 引擎 + 修 rollout 长杆)**:M4 后补的 rollout 长杆修复——一个 prompt 的 **G 个样本同步解码**(每步一次 forward 跑整组 → G× 更少 kernel 启动)。一个新原语 `rope_pos`(逐 row 绝对位置 kernel,G 行共享一个解码位置;闸门 = `[0..n]` 逐位等于全 rope、统一 P 逐行等于 `rope_at(P)`,bit-identical)。引擎 `generate_cached_batch`:`BatchKVCache` 带 G 维,批量 `decode_step` 把 G 贯穿 embed/proj/QK-norm/`rope_pos`/cache;**M2a 两件零改动复用**——`decode_attention` 本就 batch-agnostic(bh=G·nh)、`repeat_kv(nh,batch=G)` 按组广播。闸门 = G 个贪心行逐字节等于单序列(`tests/decode_batch.rs`,8q/2kv 头练 repeat_kv 批量)。**吞吐**(v12, G6·B6, 接进 train_grpo):**~8.5s/step vs 单序列 ~14-16s/step ≈ 1.7×**(rollout-inclusive;未到满 G× 因 per_token_logp + PG 更新也占时间、M2a 主机往返还在);且**显存更稳**(一次批量 forward vs G 次分配撑碎 allocator 的 M4 OOM)。⇒ M2 引擎闭环(M2a 单序列 + M2b 批量),rollout 长杆从"OOM/无界"变成有界 ~1.7× 收益,device 端 cache 是点名的下一杠杆。
|
||||||
|
|
||||||
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md))
|
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md))
|
||||||
|
|
||||||
- **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。
|
- **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。
|
||||||
|
|||||||
Reference in New Issue
Block a user