Compare commits
11 Commits
096e45b845
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 6465a2d5ce | |||
| 33a1aee9ec | |||
| 86de6bfb51 | |||
| 4379868f2d | |||
| 0e82b2438e | |||
| c2ebf62ae1 | |||
| 41d46208a6 | |||
| 3a3425960c | |||
| 0f76c0fdb0 | |||
| 361c5290fa | |||
| 2c9b58cb3b |
@@ -597,3 +597,87 @@ pub fn clipped_pg_loss(
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Batched GRPO clipped-PG loss over `N` ragged completions packed into ONE
|
||||
/// `forward_batched` (M2d): `logits` is `[R, vocab]` with `R = N·Lmax` rows in
|
||||
/// sequence-major order (sample 0's `Lmax` rows, then sample 1's, …), each ragged
|
||||
/// completion right-padded to the batch's `Lmax`. Prompt AND pad rows are masked
|
||||
/// (`target < 0`), so they contribute nothing and carry no gradient — the
|
||||
/// **right-pad-is-free-under-causal-attention** property (a real completion row
|
||||
/// never attends to the trailing pad rows, so its logits equal the unpadded
|
||||
/// single-sequence forward's).
|
||||
///
|
||||
/// Unlike the per-sample [`clipped_pg_loss`] (which folds a single scalar
|
||||
/// `advantage` and a global `1/N_tokens` normaliser), this op takes **per-row**
|
||||
/// `advantage[t]` (the owning sample's group-relative `A`) and **per-row**
|
||||
/// `weight[t]` (the full normaliser, e.g. `1/(N_samples · n_s)` for sample `s`'s
|
||||
/// completion rows, `0` at masked rows). It does NOT compute its own `inv_n`. With
|
||||
/// `weight[t] = 1/(N_samples·n_s)` and `advantage[t] = A_s` this is **bit-equivalent
|
||||
/// to the looped path** `Σ_s scale·(1/n_s)·clipped_pg_loss_s` (`scale = 1/N_samples`):
|
||||
/// the per-row backward is local (`cross_entropy_backward` is row-wise), so the
|
||||
/// batched row-`t` gradient equals the looped sample-`s` row-`t` gradient, and the
|
||||
/// scalar loss equals the looped weighted sum. (`tests/autograd.rs`:
|
||||
/// `clipped_pg_loss_batched_matches_looped`.) Degenerate points match
|
||||
/// [`clipped_pg_loss`] (`A=0` ⇒ KL only; `ε→∞` ⇒ vanilla PG; `β=0` ⇒ no KL).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn clipped_pg_loss_batched(
|
||||
logits: &Var,
|
||||
target: &Tensor,
|
||||
logp_old: &[f32],
|
||||
logp_ref: &[f32],
|
||||
advantage: &[f32],
|
||||
weight: &[f32],
|
||||
eps: f32,
|
||||
beta: f32,
|
||||
) -> Var {
|
||||
use xtrain_tensor::Device;
|
||||
let logit_dtype = logits.value().dtype();
|
||||
let (probs, per_row) = logits.value().cross_entropy(target);
|
||||
let rows = per_row.shape()[0];
|
||||
let per_row_h = per_row.to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
let target_h = target.to_device(Device::Cpu).as_slice::<i32>().to_vec();
|
||||
assert_eq!(logp_old.len(), rows, "logp_old must have one entry per row");
|
||||
assert_eq!(logp_ref.len(), rows, "logp_ref must have one entry per row");
|
||||
assert_eq!(advantage.len(), rows, "advantage must have one entry per row");
|
||||
assert_eq!(weight.len(), rows, "weight must have one entry per row");
|
||||
|
||||
let mut s = vec![0f32; rows]; // per-row scale for cross_entropy_backward(·,·,1.0)
|
||||
let mut loss_val = 0f32;
|
||||
for t in 0..rows {
|
||||
if target_h[t] < 0 {
|
||||
continue; // masked (prompt or pad) row — no contribution, no gradient
|
||||
}
|
||||
let (a, w) = (advantage[t], weight[t]);
|
||||
let lp = -per_row_h[t]; // logπθ_t
|
||||
let ratio = (lp - logp_old[t]).exp();
|
||||
let clipped = ratio.clamp(1.0 - eps, 1.0 + eps);
|
||||
let (unclipped_term, clipped_term) = (ratio * a, clipped * a);
|
||||
let pg_t = unclipped_term.min(clipped_term);
|
||||
let active = unclipped_term <= clipped_term; // min picks unclipped ⇒ grad flows
|
||||
let d = logp_ref[t] - lp;
|
||||
let kl_t = d.exp() - d - 1.0;
|
||||
let pg_grad = if active { -a * ratio } else { 0.0 };
|
||||
let kl_grad = beta * (1.0 - d.exp());
|
||||
// The full per-row normaliser is folded into s (no global inv_n here).
|
||||
s[t] = -(pg_grad + kl_grad) * w;
|
||||
loss_val += (-pg_t + beta * kl_t) * w;
|
||||
}
|
||||
let dev = logits.value().device();
|
||||
let out = Tensor::from_slice(&[loss_val], &[1]).to_device(dev);
|
||||
let s_dev = Tensor::from_slice(&s, &[rows]).to_device(dev);
|
||||
|
||||
let target = target.clone();
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![logits.clone()],
|
||||
Box::new(move |d, parents| {
|
||||
let up = d.to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
let ce = Tensor::cross_entropy_backward(&probs, &target, 1.0);
|
||||
let mut dx = ce.scale_rows(&s_dev);
|
||||
if up != 1.0 {
|
||||
dx = dx.scale(up);
|
||||
}
|
||||
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1177,3 +1177,94 @@ fn clipped_pg_loss_bwd_and_degenerate() {
|
||||
assert!((gotb - wantb).abs() < 1e-5, "β=0 loss mismatch: {gotb} vs {wantb}");
|
||||
println!("clipped_pg_loss OK: grad-check (active + A=0) + degenerate (ε→∞ vanilla, β=0 no KL)");
|
||||
}
|
||||
|
||||
// clipped_pg_loss_batched (M2d): N ragged completions packed + right-padded into ONE
// forward must equal the looped per-sample path Σ_s (1/N)·clipped_pg_loss_s. The
// per-row CE backward is row-local, so folding weight = 1/(N·n_s) into the batched
// op reproduces the looped gradient and weighted-sum loss up to f32 rounding
// (both are asserted to |Δ| < 1e-5 below).
#[test]
fn clipped_pg_loss_batched_matches_looped() {
    require_gpu();
    let (n, lmax, cols) = (3usize, 5usize, 10usize);
    let rows = n * lmax;
    let x_h = fill(rows * cols, 909);
    // Per sample: row 0 = prompt (-100); rows 1..real_len = completion; rest = pad
    // (-100). Different real_len ⇒ n_s = {2, 3, 1} completion rows.
    let real_len = [3usize, 4, 2];
    let adv_s = [0.7f32, -0.5, 0.3];
    let mut targets = vec![-100i32; rows];
    for s in 0..n {
        for r in 1..real_len[s] {
            let t = s * lmax + r;
            targets[t] = ((t * 3) % cols) as i32;
        }
    }
    let mk_target = || Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));

    // logp_old ≈ logπθ at base logits (ρ≈1), logp_ref offset to exercise the KL term.
    let (_, per_row0) = cuda(&x_h, &[rows, cols]).cross_entropy(&mk_target());
    let logp_old: Vec<f32> = per_row0
        .to_device(Device::Cpu)
        .as_slice::<f32>()
        .iter()
        .map(|p| -p)
        .collect();
    let logp_ref: Vec<f32> = logp_old.iter().map(|l| l - 0.3).collect();
    let (eps, beta) = (0.2f32, 0.1f32);

    // Per-row advantage (sample's A) + per-row weight 1/(N·n_s) (full normaliser).
    // Masked rows receive the same non-zero weight on purpose: the op must skip them
    // on `target < 0` alone, whatever weight they carry.
    let n_of = |s: usize| (0..lmax).filter(|&r| targets[s * lmax + r] >= 0).count() as f32;
    let mut advantage = vec![0f32; rows];
    let mut weight = vec![0f32; rows];
    for s in 0..n {
        let w = (1.0 / n as f32) * (1.0 / n_of(s));
        for r in 0..lmax {
            advantage[s * lmax + r] = adv_s[s];
            weight[s * lmax + r] = w;
        }
    }

    // Batched: one packed [R, vocab] forward + one backward.
    let xb = Var::leaf(cuda(&x_h, &[rows, cols]));
    let lb = ops::clipped_pg_loss_batched(
        &xb, &mk_target(), &logp_old, &logp_ref, &advantage, &weight, eps, beta,
    );
    lb.backward();
    let gb = xb.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>().to_vec();
    let lb_val = lb.value().to_device(Device::Cpu).as_slice::<f32>()[0];

    // Looped reference: per-sample slice → clipped_pg_loss → scale(1/N) → backward.
    let mut g_ref = vec![0f32; rows * cols];
    let mut loss_ref = 0f32;
    for s in 0..n {
        let r0 = s * lmax;
        let xs_h = x_h[r0 * cols..(r0 + lmax) * cols].to_vec();
        let tgt_s: Vec<i32> = targets[r0..r0 + lmax].to_vec();
        let lo_s = logp_old[r0..r0 + lmax].to_vec();
        let lr_s = logp_ref[r0..r0 + lmax].to_vec();
        let xs = Var::leaf(cuda(&xs_h, &[lmax, cols]));
        let tgt = Tensor::from_slice(&tgt_s, &[lmax]).to_device(Device::Cuda(0));
        let ls = ops::clipped_pg_loss(&xs, &tgt, &lo_s, &lr_s, adv_s[s], eps, beta);
        let scaled = ops::scale(&ls, 1.0 / n as f32);
        scaled.backward();
        let gs = xs.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>().to_vec();
        // Each sample's gradient lands in its own row range of the packed layout.
        g_ref[r0 * cols..(r0 + lmax) * cols].copy_from_slice(&gs);
        loss_ref += scaled.value().to_device(Device::Cpu).as_slice::<f32>()[0];
    }

    let max_g = gb
        .iter()
        .zip(&g_ref)
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, f32::max);
    assert!(
        (lb_val - loss_ref).abs() < 1e-5,
        "batched loss {lb_val} vs looped {loss_ref}"
    );
    assert!(max_g < 1e-5, "batched grad vs looped: max|Δ| = {max_g}");
    println!(
        "clipped_pg_loss_batched OK: loss Δ={:.2e}, grad max|Δ|={:.2e} (== looped Σ_s 1/N·pg_s)",
        (lb_val - loss_ref).abs(),
        max_g
    );
}
|
||||
|
||||
@@ -152,6 +152,29 @@ unsafe extern "C" {
|
||||
pos0: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
// RoPE with a per-row absolute position (batched KV-cache decode, M2b): row
|
||||
// `tok`'s position is `positions[tok]`. Forward only.
|
||||
pub fn launch_rope_pos_f32(
|
||||
x: *const f32,
|
||||
positions: *const i32,
|
||||
y: *mut f32,
|
||||
tokens: i32,
|
||||
heads: i32,
|
||||
head_dim: i32,
|
||||
theta: f32,
|
||||
s: CudaStream,
|
||||
);
|
||||
// Concatenate along the sequence dim: a:[bh,ta,hd], b:[bh,tb,hd] →
|
||||
// out:[bh,ta+tb,hd] (device-side KV-cache append, M2c).
|
||||
pub fn launch_cat_seq_f32(
|
||||
a: *const f32,
|
||||
b: *const f32,
|
||||
out: *mut f32,
|
||||
bh: i32,
|
||||
ta_hd: i32,
|
||||
tb_hd: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
// Per-row scale: y[r,c] = x[r,c] * s[r] (GRPO policy-gradient backward).
|
||||
pub fn launch_scale_rows_f32(
|
||||
x: *const f32,
|
||||
|
||||
@@ -10,7 +10,9 @@
|
||||
//!
|
||||
//! Versus `train_ddp` (thread-per-GPU, kept as the regression baseline) the ONLY
|
||||
//! difference is the launch model + cross-process UniqueId bootstrap. CLI flags
|
||||
//! are identical, so it doubles as the before→after throughput driver.
|
||||
//! mirror `train_ddp` (incl. `--dropout` — same T21 wiring: `cfg.dropout` set here
|
||||
//! and `train_rank` re-asserts `model.train()` each step), so it doubles as the
|
||||
//! before→after throughput driver.
|
||||
//!
|
||||
//! Run on dash5 (pick idle GPUs — dash5 is shared):
|
||||
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
|
||||
@@ -108,6 +110,11 @@ fn main() {
|
||||
let val_tokens: usize = flag(&args, "--val-tokens", 0);
|
||||
let eval_every: usize = flag(&args, "--eval-every", 0);
|
||||
let eval_batches: usize = flag(&args, "--eval-batches", 64);
|
||||
// Dropout (Phase T18/T21): residual-path dropout prob, active at training time
|
||||
// only (inverted scaling), identity at eval/sampling/export. Default 0 = off
|
||||
// (bit-identical to the no-dropout path). Mirrors bin/train_ddp; propagates into
|
||||
// cfg.dropout (below) and relies on T21's per-step model.train() in train_rank.
|
||||
let dropout: f32 = flag(&args, "--dropout", 0.0f32);
|
||||
let opts = ModelOpts {
|
||||
bf16: args.iter().any(|a| a == "--bf16"),
|
||||
recompute: args.iter().any(|a| a == "--recompute"),
|
||||
@@ -136,7 +143,9 @@ fn main() {
|
||||
(corpus, None)
|
||||
};
|
||||
|
||||
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
|
||||
let mut cfg =
|
||||
Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
|
||||
cfg.dropout = dropout;
|
||||
|
||||
if env.rank == 0 {
|
||||
println!(
|
||||
@@ -162,6 +171,9 @@ fn main() {
|
||||
if opts.flash {
|
||||
println!("flash-attention: ON (fused SDPA kernel, no materialized scores)");
|
||||
}
|
||||
if dropout > 0.0 {
|
||||
println!("dropout: ON (p={dropout}, residual-path, train-only inverted scaling)");
|
||||
}
|
||||
}
|
||||
|
||||
let dcfg = DdpConfig {
|
||||
|
||||
@@ -10,6 +10,14 @@
|
||||
//! (a) multi-process loss matches single-GPU within `<1e-3`,
|
||||
//! (b) cross-rank params agree within `<1e-6` (KI-5 ULP tolerance),
|
||||
//! (c) multi-process loss matches the thread-per-GPU `launch` path within `<1e-3`.
|
||||
//!
|
||||
//! T21-for-proc regression `proc_per_gpu_dropout_is_live_and_p0_matches_no_dropout`
|
||||
//! (below) additionally proves that `--dropout` propagates through the process-per-
|
||||
//! GPU launcher — the analogue of the thread-per-GPU T21 fix. Pre-fix
|
||||
//! `train_ddp_mp` had no `--dropout` flag, so `cfg.dropout` stayed 0 regardless of
|
||||
//! what the user passed, silently disabling dropout under process-per-GPU. The
|
||||
//! GATE B loss-trace signal (>1e-3 gap between p=0 and p=0.2) sits orders of
|
||||
//! magnitude above the KI-5 cross-rank noise floor and catches that gap directly.
|
||||
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
@@ -74,8 +82,20 @@ fn dcfg(batch_size: usize) -> DdpConfig {
|
||||
// The dump dir is passed launcher→worker via this env key (separate from the
|
||||
// XTRAIN_* keys the launcher sets); workers write `rank{N}.dump` there.
|
||||
const ENV_DUMP_DIR: &str = "XTRAIN_TEST_DUMP_DIR";
|
||||
// Optional launcher→worker channel for `cfg.dropout`. Absent = 0.0 = the existing
|
||||
// correctness test's contract (no perturbation). The T21-for-proc regression test
|
||||
// below sets it before each `launch_processes` call to prove the process-per-GPU
|
||||
// path actually plumbs `--dropout` into every worker's model.
|
||||
const ENV_DROPOUT: &str = "XTRAIN_TEST_DROPOUT";
const GLOBAL_BATCH: usize = 8;

/// Dropout probability requested for this worker through `ENV_DROPOUT`.
///
/// Returns `0.0` when the variable is unset, is not valid Unicode, or does not
/// parse as an `f32`, so a launcher that never sets it gets a no-dropout model.
fn worker_dropout() -> f32 {
    std::env::var(ENV_DROPOUT)
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(0.0)
}
|
||||
|
||||
// ── Worker entry: runs when this test binary is re-execed by launch_processes ─
|
||||
|
||||
fn run_as_worker_if_needed() {
|
||||
@@ -87,7 +107,13 @@ fn run_as_worker_if_needed() {
|
||||
// production `run_worker` wrapper is exercised by `bin/train_ddp_mp` on dash5.
|
||||
let ctx = DdpContext::init(env.rank, env.world, env.id, env.local_rank);
|
||||
let device = Device::Cuda(env.local_rank);
|
||||
let model = build_model(test_config(), device);
|
||||
// Mirrors bin/train_ddp_mp's `cfg.dropout = dropout` wiring — the T21-for-proc
|
||||
// regression: if this line were missing (the pre-fix launcher's exact gap),
|
||||
// `cfg.dropout` would stay 0 and the GATE B test below would find a bit-
|
||||
// identical p=0 / p=0.2 loss trace and FAIL.
|
||||
let mut cfg = test_config();
|
||||
cfg.dropout = worker_dropout();
|
||||
let model = build_model(cfg, device);
|
||||
let res = train_rank(
|
||||
&ctx,
|
||||
&model,
|
||||
@@ -203,8 +229,16 @@ fn proc_per_gpu_matches_single_gpu_and_thread_path() {
|
||||
let dump_dir = std::env::temp_dir().join(format!("xtrain_t17_{}", std::process::id()));
|
||||
std::fs::create_dir_all(&dump_dir).unwrap();
|
||||
// SAFETY: single-threaded test (forced by --test-threads=1) sets this env
|
||||
// before spawning workers; no concurrent env access.
|
||||
// before spawning workers; no concurrent env access. ENV_DROPOUT is cleared
|
||||
// defensively — libtest orders `--test-threads=1` runs alphabetically, so the
|
||||
// sibling `proc_per_gpu_dropout_is_live_...` test (starts with 'd') runs BEFORE
|
||||
// this one (starts with 'm'). If it happened to leak `ENV_DROPOUT=0.2` in this
|
||||
// process's env, the workers here would inherit it (Command inherits parent
|
||||
// env by default) and build with dropout=0.2 while the single-GPU baseline
|
||||
// (run_single_gpu → test_config → dropout=0) stays at 0 — GATE (a) would blow up.
|
||||
// Explicit remove here severs that ordering coupling.
|
||||
unsafe {
|
||||
std::env::remove_var(ENV_DROPOUT);
|
||||
std::env::set_var(ENV_DUMP_DIR, &dump_dir);
|
||||
}
|
||||
// Re-exec the test binary but run ONLY this test, single-threaded, so the
|
||||
@@ -273,6 +307,100 @@ fn proc_per_gpu_matches_single_gpu_and_thread_path() {
|
||||
let _ = std::fs::remove_dir_all(&dump_dir);
|
||||
}
|
||||
|
||||
/// T21-for-proc regression: prove that `--dropout` actually reaches the model
/// under process-per-GPU. The pre-fix `bin/train_ddp_mp` had no `--dropout` flag
/// and never set `cfg.dropout`, so the launcher's worker built its model with
/// dropout stuck at 0 — silent identity, regardless of what the user passed. The
/// thread-per-GPU T21 fix caught the analogous gap; this test caps the same gap
/// on the proc-per-GPU path with the same GATE-B pattern (loss trajectory of a
/// p=0.2 run differs from p=0 by a large margin, well above the NCCL noise floor).
///
/// Both runs share the corpus, the initial params (via `build_model`'s deterministic
/// LCG), and every other config knob; the ONLY difference is `cfg.dropout`. If the
/// worker didn't plumb the env-provided dropout into `cfg.dropout` (the exact pre-
/// fix regression), both traces would be bit-identical and this test would FAIL.
/// The `>1e-3` threshold sits orders of magnitude above the KI-5 cross-rank ULP
/// noise floor (~1e-7 on this PCIe box), so it's a hard signal for "dropout is
/// active" rather than a noise measurement. Mirrors
/// `ddp_dropout_is_live_and_p0_bit_identical` in ddp_correctness.rs for T21's
/// thread-per-GPU fix.
///
/// The test skips (returns early) when fewer than two GPUs are visible.
#[test]
fn proc_per_gpu_dropout_is_live_and_p0_matches_no_dropout() {
    run_as_worker_if_needed();

    let world = 2usize;
    if device::device_count().unwrap_or(0) < world as i32 {
        eprintln!("skip: need >= {world} GPUs");
        return;
    }

    let base_dump_dir = std::env::temp_dir().join(format!("xtrain_t21mp_{}", std::process::id()));
    std::fs::create_dir_all(&base_dump_dir).unwrap();
    let worker_args = [
        "--exact".to_string(),
        "proc_per_gpu_dropout_is_live_and_p0_matches_no_dropout".to_string(),
        "--test-threads=1".to_string(),
        "--nocapture".to_string(),
    ];

    // Helper: launch `world` workers with a specific dropout prob (via env) and read
    // rank 0's loss trace. Uses a subdir per run so the two invocations do not
    // clobber each other's dumps.
    let launch_with_dropout = |p: f32, tag: &str| -> Vec<f32> {
        let dump_dir = base_dump_dir.join(tag);
        std::fs::create_dir_all(&dump_dir).unwrap();
        // SAFETY: single-threaded test (forced by --test-threads=1); no concurrent env access.
        unsafe {
            std::env::set_var(ENV_DUMP_DIR, &dump_dir);
            std::env::set_var(ENV_DROPOUT, p.to_string());
        }
        launch_processes(world, &worker_args).expect("worker processes failed");
        let (losses, _) = read_dump(dump_dir.to_str().unwrap(), 0);
        losses
    };

    let loss_p0 = launch_with_dropout(0.0, "p0");
    let loss_p1 = launch_with_dropout(0.2, "p02");

    // Clear the launcher→worker env keys BEFORE any assertion can fail, so a failing
    // gate does not leak `ENV_DROPOUT=0.2` to tests that run later in this process.
    // `proc_per_gpu_matches_single_gpu_and_thread_path` clears ENV_DROPOUT
    // defensively too, but keeping the invariant "each test leaves the env as it
    // found it" costs nothing.
    // SAFETY: single-threaded test (forced by --test-threads=1); no concurrent env access.
    unsafe {
        std::env::remove_var(ENV_DROPOUT);
        std::env::remove_var(ENV_DUMP_DIR);
    }

    // `zip` below would silently truncate to the shorter trace; require comparable,
    // non-empty traces up front instead.
    assert!(
        !loss_p0.is_empty() && loss_p0.len() == loss_p1.len(),
        "loss traces not comparable: p=0 has {} steps, p=0.2 has {} steps",
        loss_p0.len(),
        loss_p1.len()
    );

    // No NaN/Inf in the p>0 run. Checked before GATE B because `f32::max` ignores
    // NaN, so a NaN trace would otherwise be reported as "dropout not plumbed".
    assert!(
        loss_p1.iter().all(|l| l.is_finite()),
        "p=0.2 proc-per-GPU loss has non-finite values"
    );

    // GATE B — dropout is LIVE under process-per-GPU with p>0. If the worker
    // didn't set `cfg.dropout` (the pre-fix gap), the two traces would match to
    // the ~1e-7 NCCL noise floor. Anything above ~1e-3 is unambiguous evidence
    // that dropout masks are actually applied in every worker's forward.
    let max_live_diff = loss_p0
        .iter()
        .zip(&loss_p1)
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, f32::max);
    println!(
        "T21-proc GATE B (dropout live under proc-per-GPU): p0[last]={:.6} p0.2[last]={:.6} max |loss diff| = {max_live_diff:.3e}",
        loss_p0.last().unwrap(),
        loss_p1.last().unwrap()
    );
    assert!(
        max_live_diff > 1e-3,
        "p=0.2 proc-per-GPU loss matches p=0 — dropout NOT plumbed through the \
         process-per-GPU launcher (cfg.dropout stayed 0 in the worker): max |loss diff| {max_live_diff:.3e}"
    );

    // Dumps are removed only on success, so a failed gate leaves them for inspection.
    let _ = std::fs::remove_dir_all(&base_dump_dir);
}
|
||||
|
||||
fn max_rel(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter()
|
||||
.zip(b)
|
||||
|
||||
@@ -36,22 +36,29 @@ use xtrain_tensor::{DType, Device, Tensor};
|
||||
/// bf16 projection output); rebuilt to the compute dtype when forming the K/V
|
||||
/// tensor, so bf16 values round-trip bit-for-bit.
|
||||
struct KVCache {
|
||||
k: Vec<Vec<f32>>,
|
||||
v: Vec<Vec<f32>>,
|
||||
k: Vec<Option<Tensor>>,
|
||||
v: Vec<Option<Tensor>>,
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
fn new(n_layers: usize) -> Self {
|
||||
Self {
|
||||
k: vec![Vec::new(); n_layers],
|
||||
v: vec![Vec::new(); n_layers],
|
||||
k: (0..n_layers).map(|_| None).collect(),
|
||||
v: (0..n_layers).map(|_| None).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Append one token's K/V slab (each `num_kv·head_dim` f32) to layer `li`.
|
||||
fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) {
|
||||
self.k[li].extend_from_slice(k_tok);
|
||||
self.v[li].extend_from_slice(v_tok);
|
||||
/// Append one token's K/V (`[bh,1,hd]`, compute dtype) to layer `li`, growing the
|
||||
/// device-resident `[bh,T,hd]` cache via `cat_seq` (no host round-trip, M2c).
|
||||
fn append(&mut self, li: usize, k_bh: Tensor, v_bh: Tensor) {
|
||||
self.k[li] = Some(match self.k[li].take() {
|
||||
Some(c) => c.cat_seq(&k_bh),
|
||||
None => k_bh,
|
||||
});
|
||||
self.v[li] = Some(match self.v[li].take() {
|
||||
Some(c) => c.cat_seq(&v_bh),
|
||||
None => v_bh,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,7 +190,6 @@ fn decode_step(
|
||||
) -> Vec<f32> {
|
||||
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
|
||||
let dim = cfg.dim;
|
||||
let kv_dim = num_kv * hd;
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
let (theta, eps) = (cfg.rope_theta, cfg.eps);
|
||||
let n_layers = cfg.n_layers;
|
||||
@@ -212,28 +218,18 @@ fn decode_step(
|
||||
let q = q.reshape(&[1, nh, hd]).rope_at(theta, pos);
|
||||
let q_bh = q.reshape(&[nh, 1, hd]); // seq=1 ⇒ the head-transpose is a no-op on data
|
||||
|
||||
// K: same as Q (QK-norm + RoPE); cache token-major. V: project only.
|
||||
// K: same as Q (QK-norm + RoPE). V: project only. Append each as [num_kv,1,hd]
|
||||
// (bh-major) into the device cache; no host round-trip, no transpose (M2c).
|
||||
let k = linear_t(cdt, &normed, wk).reshape(&[1, num_kv, hd]);
|
||||
let k = k.reshape(&[num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
|
||||
let k_tok = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos); // [1, num_kv, hd]
|
||||
let v_tok = linear_t(cdt, &normed, wv).reshape(&[1, num_kv, hd]);
|
||||
let k_bh = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos).reshape(&[num_kv, 1, hd]);
|
||||
let v_bh = linear_t(cdt, &normed, wv).reshape(&[num_kv, 1, hd]);
|
||||
cache.append(li, k_bh, v_bh);
|
||||
|
||||
let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
cache.append(li, &k_host, &v_host);
|
||||
|
||||
// Rebuild the full K/V for this layer: token-major [T,num_kv,hd] → [num_kv,T,hd]
|
||||
// → repeat_kv to [nh,T,hd].
|
||||
let t_len = cache.k[li].len() / kv_dim;
|
||||
let build = |flat: &[f32]| -> Tensor {
|
||||
let bh_kv = Tensor::from_slice(flat, &[t_len, num_kv, hd])
|
||||
.to_device(device)
|
||||
.transpose_3d01(); // [num_kv, T, hd], f32
|
||||
let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv };
|
||||
if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, 1) } // [nh, T, hd]
|
||||
};
|
||||
let k_full = build(&cache.k[li]);
|
||||
let v_full = build(&cache.v[li]);
|
||||
// repeat_kv the cached [num_kv,T,hd] to [nh,T,hd] for the SDPA.
|
||||
let expand = |c: &Tensor| if num_kv == nh { c.clone() } else { c.repeat_kv(nh, 1) };
|
||||
let k_full = expand(cache.k[li].as_ref().unwrap());
|
||||
let v_full = expand(cache.v[li].as_ref().unwrap());
|
||||
|
||||
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [nh, hd]
|
||||
let attn = attn.reshape(&[1, dim]); // concat heads (nh·hd == dim)
|
||||
@@ -265,3 +261,176 @@ fn argmax(row: &[f32]) -> usize {
|
||||
.unwrap()
|
||||
.0
|
||||
}
|
||||
|
||||
// ===================================================================
|
||||
// M2b — batched KV-cache decode (G samples of one prompt, in lockstep)
|
||||
// ===================================================================
|
||||
|
||||
/// Batched K/V cache: `G` sequences advancing together. Per layer, a device-resident
|
||||
/// `[G·num_kv, T, head_dim]` grown one token per step via `cat_seq` (M2c — no host
|
||||
/// round-trip). Same as M2a's device cache with a G dimension in `bh`.
|
||||
struct BatchKVCache {
|
||||
k: Vec<Option<Tensor>>,
|
||||
v: Vec<Option<Tensor>>,
|
||||
}
|
||||
|
||||
impl BatchKVCache {
|
||||
fn new(n_layers: usize) -> Self {
|
||||
Self {
|
||||
k: (0..n_layers).map(|_| None).collect(),
|
||||
v: (0..n_layers).map(|_| None).collect(),
|
||||
}
|
||||
}
|
||||
fn append(&mut self, li: usize, k_bh: Tensor, v_bh: Tensor) {
|
||||
self.k[li] = Some(match self.k[li].take() {
|
||||
Some(c) => c.cat_seq(&k_bh),
|
||||
None => k_bh,
|
||||
});
|
||||
self.v[li] = Some(match self.v[li].take() {
|
||||
Some(c) => c.cat_seq(&v_bh),
|
||||
None => v_bh,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Batched KV-cache decode: roll out `n_samples` (G) completions of the SAME
|
||||
/// `prompt` in lockstep — all G share the prompt, so they advance at one common
|
||||
/// decode position each step (uniform RoPE via `rope_pos`). Returns G full token
|
||||
/// sequences (prompt + sampled continuation). The G-way batching amortises the
|
||||
/// per-step kernel launches across G (the rollout long-pole). Token-identical per
|
||||
/// row to G independent single-sequence decodes (gated by `tests/decode_batch.rs`).
|
||||
///
|
||||
/// `temperature == 0` ⇒ greedy (all G identical); `> 0` ⇒ independent samples
|
||||
/// (per-row draw from one shared `rng_state`). No finished-mask: all G generate
|
||||
/// `max_new` tokens; the caller cuts each at `<|endoftext|>` (a perf-only early
|
||||
/// stop is the M2b+ follow-up). Ragged (different-length prompts) is also deferred.
|
||||
pub fn generate_cached_batch(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
prompt: &[i32],
|
||||
n_samples: usize,
|
||||
max_new: usize,
|
||||
temperature: f32,
|
||||
rng_state: &mut u64,
|
||||
) -> Vec<Vec<i32>> {
|
||||
assert!(!prompt.is_empty(), "prompt must be non-empty");
|
||||
assert!(n_samples > 0, "n_samples must be > 0");
|
||||
let cfg = model.config();
|
||||
let cdt = model.compute_dtype();
|
||||
let n_layers = cfg.n_layers;
|
||||
let params: Vec<Tensor> = model.params().iter().map(|p| p.value()).collect();
|
||||
let embed = ¶ms[0];
|
||||
let final_norm = ¶ms[1 + n_layers * 11];
|
||||
let lm_head = ¶ms[1 + n_layers * 11 + 1];
|
||||
|
||||
let g = n_samples;
|
||||
let mut cache = BatchKVCache::new(n_layers);
|
||||
let mut seqs: Vec<Vec<i32>> = vec![prompt.to_vec(); g];
|
||||
|
||||
// Prefill: feed each prompt token (identical across G) at its position.
|
||||
let mut logits = Vec::new(); // [G, vocab] flattened
|
||||
for (pos, &tok) in prompt.iter().enumerate() {
|
||||
let toks = vec![tok; g];
|
||||
logits = decode_step_batch(¶ms, cfg, cdt, device, &mut cache, &toks, pos, embed, final_norm, lm_head);
|
||||
}
|
||||
|
||||
let vocab = cfg.vocab;
|
||||
for _ in 0..max_new {
|
||||
let mut next = Vec::with_capacity(g);
|
||||
for row in 0..g {
|
||||
let lg = &logits[row * vocab..(row + 1) * vocab];
|
||||
let t = if temperature <= 0.0 {
|
||||
argmax(lg) as i32
|
||||
} else {
|
||||
sample_temperature(lg, temperature, rng_state) as i32
|
||||
};
|
||||
next.push(t);
|
||||
seqs[row].push(t);
|
||||
}
|
||||
let pos = seqs[0].len() - 1; // all G are at the same position
|
||||
logits = decode_step_batch(¶ms, cfg, cdt, device, &mut cache, &next, pos, embed, final_norm, lm_head);
|
||||
}
|
||||
seqs
|
||||
}
|
||||
|
||||
/// One batched decode step: `toks` is one current token per sequence (`[G]`), all at
|
||||
/// absolute position `pos`. Appends each sequence's K/V and returns logits `[G·vocab]`.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn decode_step_batch(
|
||||
params: &[Tensor],
|
||||
cfg: &crate::Config,
|
||||
cdt: DType,
|
||||
device: Device,
|
||||
cache: &mut BatchKVCache,
|
||||
toks: &[i32],
|
||||
pos: usize,
|
||||
embed: &Tensor,
|
||||
final_norm: &Tensor,
|
||||
lm_head: &Tensor,
|
||||
) -> Vec<f32> {
|
||||
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
|
||||
let dim = cfg.dim;
|
||||
let g = toks.len();
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
let (theta, eps) = (cfg.rope_theta, cfg.eps);
|
||||
let n_layers = cfg.n_layers;
|
||||
// Uniform per-row position (all G at the same decode step).
|
||||
let positions = Tensor::from_slice(&vec![pos as i32; g], &[g]).to_device(device);
|
||||
|
||||
let ids = Tensor::from_slice(toks, &[g]).to_device(device);
|
||||
let mut h = embed.embedding(&ids); // [G, dim] f32
|
||||
if cdt == DType::BF16 {
|
||||
h = h.to_dtype(DType::BF16);
|
||||
}
|
||||
|
||||
for li in 0..n_layers {
|
||||
let base = 1 + li * 11;
|
||||
let (attn_norm, wq, wk, wv) =
|
||||
(¶ms[base], ¶ms[base + 1], ¶ms[base + 2], ¶ms[base + 3]);
|
||||
let (q_norm, k_norm, wo) = (¶ms[base + 4], ¶ms[base + 5], ¶ms[base + 6]);
|
||||
let (ffn_norm, w_gate, w_up, w_down) =
|
||||
(¶ms[base + 7], ¶ms[base + 8], ¶ms[base + 9], ¶ms[base + 10]);
|
||||
|
||||
let normed = h.rms_norm(&gamma_t(cdt, attn_norm), eps).0; // [G, dim]
|
||||
|
||||
// Q: project → per-head QK-norm → RoPE at `pos` for every row.
|
||||
let q = linear_t(cdt, &normed, wq).reshape(&[g, nh, hd]);
|
||||
let q = q.reshape(&[g * nh, hd]).rms_norm(&gamma_t(cdt, q_norm), eps).0;
|
||||
let q = q.reshape(&[g, nh, hd]).rope_pos(&positions, theta);
|
||||
let q_bh = q.reshape(&[g * nh, 1, hd]); // bh = G·nh
|
||||
|
||||
// K/V appended as [G·num_kv,1,hd] (bh-major) into the device cache (M2c).
|
||||
let k = linear_t(cdt, &normed, wk).reshape(&[g, num_kv, hd]);
|
||||
let k = k.reshape(&[g * num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
|
||||
let k_bh = k
|
||||
.reshape(&[g, num_kv, hd])
|
||||
.rope_pos(&positions, theta)
|
||||
.reshape(&[g * num_kv, 1, hd]);
|
||||
let v_bh = linear_t(cdt, &normed, wv).reshape(&[g * num_kv, 1, hd]);
|
||||
cache.append(li, k_bh, v_bh);
|
||||
|
||||
// repeat_kv the cached [G·num_kv,T,hd] to [G·nh,T,hd] for the SDPA.
|
||||
let expand = |c: &Tensor| if num_kv == nh { c.clone() } else { c.repeat_kv(nh, g) };
|
||||
let k_full = expand(cache.k[li].as_ref().unwrap());
|
||||
let v_full = expand(cache.v[li].as_ref().unwrap());
|
||||
|
||||
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [G·nh, hd]
|
||||
let attn = attn.reshape(&[g, dim]); // concat heads per sequence
|
||||
let attn_out = linear_t(cdt, &attn, wo);
|
||||
h = h.add(&attn_out);
|
||||
|
||||
let normed = h.rms_norm(&gamma_t(cdt, ffn_norm), eps).0;
|
||||
let gate = linear_t(cdt, &normed, w_gate);
|
||||
let up = linear_t(cdt, &normed, w_up);
|
||||
let act = gate.silu().mul(&up);
|
||||
let down = linear_t(cdt, &act, w_down);
|
||||
h = h.add(&down);
|
||||
}
|
||||
|
||||
let h = h.rms_norm(&gamma_t(cdt, final_norm), eps).0;
|
||||
linear_t(cdt, &h, lm_head)
|
||||
.to_dtype(DType::F32)
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.to_vec()
|
||||
}
|
||||
|
||||
@@ -29,4 +29,4 @@ pub use model::{TinyTransformer, batched_ids_tensor, ids_tensor, param_to_host};
|
||||
#[cfg(not(no_cuda))]
|
||||
pub mod decode;
|
||||
#[cfg(not(no_cuda))]
|
||||
pub use decode::{generate_cached, generate_greedy_cached};
|
||||
pub use decode::{generate_cached, generate_cached_batch, generate_greedy_cached};
|
||||
|
||||
97
crates/xtrain-model/tests/ragged_batch.rs
Normal file
97
crates/xtrain-model/tests/ragged_batch.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
// M2d gate: does forward_batched on RIGHT-PADDED ragged sequences reproduce the
|
||||
// per-sequence single-seq forward on the real (non-pad) rows? The batched GRPO
|
||||
// training-side forwards depend on this "right-pad is free under causal attention"
|
||||
// property — a real completion row is at an earlier position than the trailing pad,
|
||||
// and causal masking forbids attending forward, so its logits should be unchanged.
|
||||
//
|
||||
// Tested in fp32 (tight tolerance) and bf16 (loose band) over both SDPA cores (composed + fused flash), since the
|
||||
// bench uses flash and a kernel could in principle leak the pad keys into the online
|
||||
// softmax.
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, ids_tensor};
|
||||
use xtrain_tensor::{DType, Device, Tensor};
|
||||
|
||||
/// Deterministic pseudo-random fill: returns `n` values in `[-scale, scale]` drawn
/// from a 64-bit LCG whose initial state is derived from `seed`. Same `(n, seed,
/// scale)` always yields the same vector.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;

    let mut lcg = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
        // Top 31 bits of the state mapped onto the unit interval.
        let unit = (lcg >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
|
||||
|
||||
fn build(cfg: Config, device: Device, dtype: DType, flash: bool) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
let m = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
});
|
||||
m.with_compute_dtype(dtype).with_flash(flash)
|
||||
}
|
||||
|
||||
fn host(t: &Tensor) -> Vec<f32> {
|
||||
t.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec()
|
||||
}
|
||||
|
||||
/// Gate for "right-pad is free under causal attention": for every real (non-pad)
/// row, the logits of one right-padded `forward_batched` must match the logits of
/// the per-sequence `forward`, for F32 and BF16 and for both SDPA cores.
///
/// Skips (with a message) when no CUDA device is present.
#[test]
fn forward_batched_ragged_matches_looped() {
    if device::device_count().unwrap_or(0) == 0 {
        eprintln!("no CUDA device; skipping");
        return;
    }
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let mut cfg = Config::tiny();
    cfg.vocab = 32;
    cfg.n_layers = 2;
    let vocab = cfg.vocab;

    // Ragged lengths incl. one crossing the flash tile (>32) and short ones.
    let lens = [6usize, 40, 9, 4];
    let lmax = *lens.iter().max().unwrap();
    let n = lens.len();
    let seqs: Vec<Vec<i32>> = lens
        .iter()
        .enumerate()
        .map(|(b, &l)| (0..l).map(|i| ((b * 7 + i * 3 + 1) % vocab) as i32).collect())
        .collect();

    for (dtype, tol) in [(DType::F32, 2e-3f32), (DType::BF16, 3e-1f32)] {
        for flash in [false, true] {
            let m = build(cfg, device, dtype, flash);
            // Looped: each sequence on its own (the ground truth).
            let looped: Vec<Vec<f32>> = seqs.iter().map(|s| host(&m.forward(&ids_tensor(s, device)).value())).collect();

            // Batched: right-pad each to lmax (pad id 0), one forward_batched(batch = n).
            let mut flat = vec![0i32; n * lmax];
            for (i, s) in seqs.iter().enumerate() {
                flat[i * lmax..i * lmax + s.len()].copy_from_slice(s);
            }
            let ids = Tensor::from_slice(&flat, &[n * lmax]).to_device(device);
            let batched = host(&m.forward_batched(&ids, n).value()); // [n*lmax, vocab]
            assert_eq!(batched.len(), n * lmax * vocab, "batched logits are not [n*lmax, vocab]");

            let mut dmax = 0f32;
            for (i, s) in seqs.iter().enumerate() {
                assert_eq!(looped[i].len(), s.len() * vocab, "looped logits of seq {i} are not [len, vocab]");
                for r in 0..s.len() {
                    let want = &looped[i][r * vocab..(r + 1) * vocab];
                    let got_base = (i * lmax + r) * vocab;
                    let got = &batched[got_base..got_base + vocab];
                    for (c, (a, b)) in want.iter().zip(got).enumerate() {
                        let d = (a - b).abs();
                        // `f32::max` drops NaN, so a NaN logit would otherwise leave
                        // `dmax` untouched and let the gate pass vacuously.
                        assert!(
                            d.is_finite(),
                            "dtype={dtype:?} flash={flash}: non-finite logit at seq {i} row {r} col {c} (looped {a}, batched {b})"
                        );
                        dmax = dmax.max(d);
                    }
                }
            }
            println!("dtype={dtype:?} flash={flash}: ragged right-pad vs looped, max|Δlogit| (real rows) = {dmax:.3e}");
            assert!(dmax < tol, "dtype={dtype:?} flash={flash}: right-pad NOT free under causal — max|Δ| = {dmax}");
        }
    }
    println!("forward_batched_ragged_matches_looped OK: right-pad is free under causal (fp32+bf16, composed + flash)");
}
|
||||
@@ -822,6 +822,75 @@ impl Tensor {
|
||||
out
|
||||
}
|
||||
|
||||
/// RoPE with a PER-ROW absolute position (batched KV-cache decode, M2b).
|
||||
/// `self`:[tokens,heads,head_dim]; row `t`'s position is `positions[t]` (an
|
||||
/// I32 `[tokens]` tensor). For G-way batched decode all G rows share one decode
|
||||
/// position; for ragged batches each row carries its own. Mirrors `rope_at`'s
|
||||
/// dtype handling; forward only.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn rope_pos(&self, positions: &Tensor, theta: f32) -> Self {
|
||||
assert_eq!(self.ndim(), 3, "rope_pos requires [tokens,heads,head_dim]");
|
||||
let (tokens, heads, head_dim) = (self.shape[0], self.shape[1], self.shape[2]);
|
||||
assert_eq!(head_dim % 2, 0, "head_dim must be even");
|
||||
assert_eq!(positions.dtype, DType::I32, "positions must be I32");
|
||||
assert_eq!(positions.numel(), tokens, "one position per token");
|
||||
if self.dtype == DType::BF16 {
|
||||
return self
|
||||
.to_dtype(DType::F32)
|
||||
.rope_pos(positions, theta)
|
||||
.to_dtype(DType::BF16);
|
||||
}
|
||||
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_rope_pos_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
positions.data_ptr() as *const i32,
|
||||
out.data_ptr() as *mut f32,
|
||||
tokens as i32,
|
||||
heads as i32,
|
||||
head_dim as i32,
|
||||
theta,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Concatenate along the sequence (middle) dim: `self`:[bh,ta,hd] ++
|
||||
/// `other`:[bh,tb,hd] → `[bh,ta+tb,hd]`. The device-side KV-cache append (M2c):
|
||||
/// the cache stays on the GPU and grows by one token per decode step, removing
|
||||
/// the M2a/M2b host round-trip. Mirrors the bf16 cast handling of the other
|
||||
/// structural kernels.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn cat_seq(&self, other: &Tensor) -> Self {
|
||||
assert_eq!(self.ndim(), 3, "cat_seq requires [bh,t,hd]");
|
||||
assert_eq!(other.ndim(), 3, "cat_seq requires [bh,t,hd]");
|
||||
assert_eq!(self.dtype, other.dtype, "cat_seq dtype mismatch");
|
||||
let (bh, ta, hd) = (self.shape[0], self.shape[1], self.shape[2]);
|
||||
let (bh2, tb, hd2) = (other.shape[0], other.shape[1], other.shape[2]);
|
||||
assert_eq!(bh, bh2, "cat_seq bh mismatch");
|
||||
assert_eq!(hd, hd2, "cat_seq head_dim mismatch");
|
||||
if self.dtype == DType::BF16 {
|
||||
return self
|
||||
.to_dtype(DType::F32)
|
||||
.cat_seq(&other.to_dtype(DType::F32))
|
||||
.to_dtype(DType::BF16);
|
||||
}
|
||||
let out = Tensor::zeros(&[bh, ta + tb, hd], DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_cat_seq_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
other.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
bh as i32,
|
||||
(ta * hd) as i32,
|
||||
(tb * hd) as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// RoPE backward: apply the inverse (transpose) rotation to `dy`. RoPE is an
|
||||
/// orthogonal map, so it needs no cached forward values, only `theta`/`period`.
|
||||
#[cfg(not(no_cuda))]
|
||||
|
||||
@@ -159,3 +159,67 @@ fn decode_attention_matches_full_attention_last_row() {
|
||||
);
|
||||
println!("decode_attention OK: matches full causal last row (bh={bh}, t={t}, max|Δ|={max_abs:.2e})");
|
||||
}
|
||||
|
||||
/// (e) `rope_pos` (per-row positions, M2b batched decode). Two checks, both
/// bit-exact: positions `[0,1,…,n-1]` must reproduce the full-sequence `rope`
/// (period = n), and a uniform position `P` must make every row equal to
/// `rope_at(·, P)` applied to that row alone — the case the batched decode uses
/// (G rows sharing one decode position).
#[test]
fn rope_pos_matches_rope_and_rope_at() {
    assert!(device::device_count().expect("device count") > 0, "no CUDA device");
    device::set_device(0).unwrap();

    let theta = 10000.0f32;
    let (n, heads, hd) = (7usize, 3usize, 8usize);
    let row_len = heads * hd;
    let host: Vec<f32> = (0..n * row_len).map(|i| ((i * 37 % 101) as f32 / 50.0) - 1.0).collect();
    let x = Tensor::from_slice(&host, &[n, heads, hd]).to_device(Device::Cuda(0));
    // Device → host copy of a result tensor.
    let download = |t: Tensor| t.to_device(Device::Cpu).as_slice::<f32>().to_vec();

    // Case 1: ramp positions 0..n behave like the full-sequence rope.
    let ramp: Vec<i32> = (0..n as i32).collect();
    let ramp_t = Tensor::from_slice(&ramp, &[n]).to_device(Device::Cuda(0));
    let got = download(x.rope_pos(&ramp_t, theta));
    let want = download(x.rope(theta, n));
    assert_eq!(got, want, "rope_pos [0..n] != full rope");

    // Case 2: one shared position P; each row is checked against a single-row rope_at.
    let p = 5i32;
    let shared = vec![p; n];
    let uni = Tensor::from_slice(&shared, &[n]).to_device(Device::Cuda(0));
    let got_u = download(x.rope_pos(&uni, theta));
    for (t, row) in host.chunks(row_len).enumerate() {
        let single = Tensor::from_slice(row, &[1, heads, hd]).to_device(Device::Cuda(0));
        let want_row = download(single.rope_at(theta, p as usize));
        assert_eq!(&got_u[t * row_len..(t + 1) * row_len], want_row.as_slice(), "uniform pos row {t}");
    }
    println!("rope_pos OK: == full rope for [0..n] and == rope_at(P) per row for uniform P");
}
|
||||
|
||||
/// (f) `cat_seq` (device-side KV-cache append, M2c): the device concat of
/// `[bh,ta,hd]` ++ `[bh,tb,hd]` along the seq dim must equal the host reference,
/// which for each `bh` row lays down a's `ta*hd` block followed by b's `tb*hd`
/// block. This is the append that removes the M2a/M2b host round-trip.
#[test]
fn cat_seq_matches_host_concat() {
    assert!(device::device_count().expect("device count") > 0, "no CUDA device");
    device::set_device(0).unwrap();

    let (bh, ta, tb, hd) = (4usize, 3usize, 2usize, 5usize);
    let (a_row, b_row) = (ta * hd, tb * hd);
    let ah: Vec<f32> = (0..bh * a_row).map(|i| i as f32 * 0.1).collect();
    let bhost: Vec<f32> = (0..bh * b_row).map(|i| -(i as f32) - 1.0).collect();
    let a = Tensor::from_slice(&ah, &[bh, ta, hd]).to_device(Device::Cuda(0));
    let b = Tensor::from_slice(&bhost, &[bh, tb, hd]).to_device(Device::Cuda(0));

    let got = a.cat_seq(&b).to_device(Device::Cpu).as_slice::<f32>().to_vec();

    // Host reference: interleave the per-row blocks (a's row, then b's row).
    let mut want: Vec<f32> = Vec::with_capacity(bh * (a_row + b_row));
    for (row_a, row_b) in ah.chunks(a_row).zip(bhost.chunks(b_row)) {
        want.extend_from_slice(row_a);
        want.extend_from_slice(row_b);
    }
    assert_eq!(got, want, "cat_seq != host interleaved concat");
    println!("cat_seq OK: [bh={bh},{ta}+{tb},{hd}] == host concat");
}
|
||||
|
||||
268
crates/xtrain-train/src/bin/bench_grpo_batch.rs
Normal file
268
crates/xtrain-train/src/bin/bench_grpo_batch.rs
Normal file
@@ -0,0 +1,268 @@
|
||||
//! Micro-benchmark + closeness gate for the M2d batched GRPO training-side forwards.
|
||||
//!
|
||||
//! After M2b/M2c the GRPO *step* is no longer rollout-bound — it is the `N = B·G`
|
||||
//! per-sample full-sequence forwards (the `per_token_logp` captures + the inner
|
||||
//! clipped-PG forward/backwards). This bin isolates exactly that, weight-independently
|
||||
//! (step wall-clock depends on shapes + launch counts, not on what the weights are), by
|
||||
//! synthesising `N` realistic ragged samples and A/B-timing the looped vs batched path
|
||||
//! for BOTH phases — plus asserting they agree numerically (the looped-vs-batched
|
||||
//! closeness gate; per-row bit-equivalence of the loss op is pinned by the autograd
|
||||
//! test `clipped_pg_loss_batched_matches_looped`).
|
||||
//!
|
||||
//! bench_grpo_batch <tokenizer.json> --init-ckpt <base.ckpt> <arch flags> \
|
||||
//! --n 48 --plen 12 --clen 24 --micro 16 --reps 3
|
||||
|
||||
/// Stub entry point for builds without CUDA: the benchmark needs a GPU, so it only
/// prints a notice on stderr.
#[cfg(no_cuda)]
fn main() {
    let notice = "bench_grpo_batch: built without CUDA (no_cuda); run on a GPU host.";
    eprintln!("{notice}");
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_cuda::device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_model::{Config, TinyTransformer};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_tensor::{DType, Device, Tensor};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::grpo_batch::{PgSample, inner_pg_step_batched, inner_pg_step_looped, per_token_logp, per_token_logp_batched};
|
||||
|
||||
/// Deterministic pseudo-random fill: returns `n` values in `[-scale, scale]` drawn
/// from a 64-bit LCG whose initial state is derived from `seed`. Same `(n, seed,
/// scale)` always yields the same vector.
#[cfg(not(no_cuda))]
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;

    let mut lcg = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
        // Top 31 bits of the state mapped onto the unit interval.
        let unit = (lcg >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
|
||||
|
||||
/// Reads `name <value>` from `args` and parses the value as `T`.
///
/// Only the first occurrence of `name` is considered. Returns `default` when the
/// flag is absent, has no following argument, or its value fails to parse.
#[cfg(not(no_cuda))]
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
    args.windows(2)
        .find(|pair| pair[0] == name)
        .and_then(|pair| pair[1].parse().ok())
        .unwrap_or(default)
}
|
||||
|
||||
/// Returns the raw argument that follows the first occurrence of `name` in `args`,
/// or `None` when the flag is absent or is the last argument.
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
    let at = args.iter().position(|arg| arg == name)?;
    args.get(at + 1).cloned()
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn load_model(cfg: Config, device: Device, ckpt: &str) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
let m = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.04)
|
||||
}
|
||||
})
|
||||
.with_compute_dtype(DType::BF16)
|
||||
.with_flash(true);
|
||||
xtrain_train::checkpoint::load_into(std::path::Path::new(ckpt), &m.params()).expect("load ckpt");
|
||||
m.eval();
|
||||
m
|
||||
}
|
||||
|
||||
/// Runs `f` back-to-back `reps` times and returns the mean wall-clock time per
/// call in milliseconds. With `reps == 0` nothing runs and the division by zero
/// makes the result non-finite.
#[cfg(not(no_cuda))]
fn elapsed_ms<F: FnMut()>(reps: usize, mut f: F) -> f32 {
    let clock = std::time::Instant::now();
    (0..reps).for_each(|_| f());
    let total_ms = clock.elapsed().as_secs_f32() * 1e3;
    total_ms / reps as f32
}
|
||||
|
||||
/// Per-position argmax of the model over each ragged `input` (one `forward_batched`
|
||||
/// per `micro`-chunk). Used to teacher-force WELL-CONDITIONED targets (the top-1 token,
|
||||
/// high prob) so the closeness gate's logp isn't the ~−20 of a random token — where
|
||||
/// `−log p` amplifies bf16 noise. This matches real GRPO (targets are model samples).
|
||||
#[cfg(not(no_cuda))]
|
||||
fn model_argmax(model: &TinyTransformer, device: Device, inputs: &[Vec<i32>], vocab: usize, micro: usize) -> Vec<Vec<i32>> {
|
||||
let mut out = Vec::with_capacity(inputs.len());
|
||||
for chunk in inputs.chunks(micro.max(1)) {
|
||||
let m = chunk.len();
|
||||
let lmax = chunk.iter().map(|s| s.len()).max().unwrap();
|
||||
let mut flat = vec![0i32; m * lmax];
|
||||
for (i, s) in chunk.iter().enumerate() {
|
||||
flat[i * lmax..i * lmax + s.len()].copy_from_slice(s);
|
||||
}
|
||||
let ids = Tensor::from_slice(&flat, &[m * lmax]).to_device(device);
|
||||
let logits = model.forward_batched(&ids, m).value().to_dtype(DType::F32).to_device(Device::Cpu);
|
||||
let v = logits.as_slice::<f32>();
|
||||
for (i, s) in chunk.iter().enumerate() {
|
||||
let mut row = Vec::with_capacity(s.len());
|
||||
for r in 0..s.len() {
|
||||
let base = (i * lmax + r) * vocab;
|
||||
let mut best = 0usize;
|
||||
for c in 1..vocab {
|
||||
if v[base + c] > v[base + best] {
|
||||
best = c;
|
||||
}
|
||||
}
|
||||
row.push(best as i32);
|
||||
}
|
||||
out.push(row);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
||||
let tok_path = positionals.first().expect("usage: bench_grpo_batch <tokenizer.json> [flags]");
|
||||
|
||||
let n_heads = flag(&args, "--heads", 52usize);
|
||||
let head_dim = flag(&args, "--head-dim", 32usize);
|
||||
let n_layers = flag(&args, "--layers", 22usize);
|
||||
let ffn = flag(&args, "--ffn", 6656usize);
|
||||
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||||
let n: usize = flag(&args, "--n", 48); // B·G samples per step
|
||||
let plen: usize = flag(&args, "--plen", 12); // prompt tokens
|
||||
let clen: usize = flag(&args, "--clen", 24); // max completion tokens
|
||||
let micro: usize = flag(&args, "--micro", 16);
|
||||
let reps: usize = flag(&args, "--reps", 3);
|
||||
let (eps, beta) = (flag(&args, "--eps", 0.2f32), flag(&args, "--beta", 0.0f32));
|
||||
let init_ckpt = flag_value(&args, "--init-ckpt").expect("--init-ckpt <base.ckpt> required");
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
let tok = Tokenizer::from_file(std::path::Path::new(tok_path.as_str()));
|
||||
let vocab = tok.vocab_size();
|
||||
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
|
||||
let policy = load_model(cfg, device, &init_ckpt);
|
||||
let params = policy.params();
|
||||
|
||||
// --- Synthesise N ragged samples (frame-shaped: prompt masked, ragged completion).
|
||||
// Token IDs are random-but-valid; only the SHAPES drive the forward cost.
|
||||
let mut rng = 0xC0FFEEu64;
|
||||
let mut next = || {
|
||||
rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
|
||||
(rng >> 33) as usize
|
||||
};
|
||||
let mut io: Vec<(Vec<i32>, Vec<i32>)> = Vec::with_capacity(n);
|
||||
let mut advs: Vec<f32> = Vec::with_capacity(n);
|
||||
for _ in 0..n {
|
||||
let pl = plen.saturating_sub(2) + next() % 5; // jitter prompt length a little
|
||||
let cl = 4 + next() % clen.max(1); // completion 4..=clen
|
||||
let total = pl + cl;
|
||||
let toks: Vec<i32> = (0..total).map(|_| (next() % vocab) as i32).collect();
|
||||
let mut labels = vec![-100i32; pl]; // prompt masked
|
||||
labels.extend_from_slice(&toks[pl..]);
|
||||
let l = toks.len();
|
||||
io.push((toks[..l - 1].to_vec(), labels[1..l].to_vec())); // target masked at [..pl-1]
|
||||
advs.push(if next() % 2 == 0 { 0.7 } else { -0.7 });
|
||||
}
|
||||
let toklens: Vec<usize> = io.iter().map(|(i, _)| i.len()).collect();
|
||||
let (lmin, lmax) = (*toklens.iter().min().unwrap(), *toklens.iter().max().unwrap());
|
||||
println!("samples N={n}, seq len {lmin}..{lmax} (ragged), micro={micro}, β={beta}\n");
|
||||
|
||||
// Replace random completion targets with the model's own argmax (teacher forcing):
|
||||
// well-conditioned logp (top-1, not the ~−20 of a random token where bf16 noise
|
||||
// blows up via −log p). The completion target positions are where the skeleton is
|
||||
// ≥0; prompt positions stay masked (−100).
|
||||
let inputs: Vec<Vec<i32>> = io.iter().map(|(i, _)| i.clone()).collect();
|
||||
let preds = model_argmax(&policy, device, &inputs, vocab, micro);
|
||||
for (s, (_, target)) in io.iter_mut().enumerate() {
|
||||
for j in 0..target.len() {
|
||||
if target[j] >= 0 {
|
||||
target[j] = preds[s][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------- Phase 1: capture (per_token_logp) ----------------
|
||||
let logp_loop: Vec<Vec<f32>> = io.iter().map(|(i, t)| per_token_logp(&policy, device, i, t)).collect();
|
||||
let logp_batch = per_token_logp_batched(&policy, device, &io, micro);
|
||||
let cap_dmax = logp_loop
|
||||
.iter()
|
||||
.zip(&logp_batch)
|
||||
.flat_map(|(a, b)| a.iter().zip(b).map(|(x, y)| (x - y).abs()))
|
||||
.fold(0.0f32, f32::max);
|
||||
let t_cap_loop = elapsed_ms(reps, || {
|
||||
let _: Vec<Vec<f32>> = io.iter().map(|(i, t)| per_token_logp(&policy, device, i, t)).collect();
|
||||
});
|
||||
let t_cap_batch = elapsed_ms(reps, || {
|
||||
let _ = per_token_logp_batched(&policy, device, &io, micro);
|
||||
});
|
||||
|
||||
// Build PgSamples from the (matching) capture; ref = old − 0.3 to exercise KL.
|
||||
let batch: Vec<PgSample> = io
|
||||
.iter()
|
||||
.zip(&advs)
|
||||
.zip(&logp_batch)
|
||||
.map(|(((input, target), &adv), lp)| PgSample {
|
||||
input: input.clone(),
|
||||
target: target.clone(),
|
||||
adv,
|
||||
logp_old: lp.clone(),
|
||||
logp_ref: lp.iter().map(|v| v - 0.3).collect(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
// ---------------- Phase 2: inner clipped-PG (forward + backward) ----------------
|
||||
// Representative grad snapshots: layer-0 wq (params[2]) + final_norm.
|
||||
let wq0 = ¶ms[2];
|
||||
let fnorm = ¶ms[1 + n_layers * 11];
|
||||
let snap = |v: &xtrain_autodiff::Var| -> Vec<f32> {
|
||||
v.grad().map(|g| g.to_device(Device::Cpu).as_slice::<f32>().to_vec()).unwrap_or_default()
|
||||
};
|
||||
let zero = |ps: &[xtrain_autodiff::Var]| ps.iter().for_each(|p| p.zero_grad());
|
||||
|
||||
zero(¶ms);
|
||||
inner_pg_step_looped(&policy, device, &batch, eps, beta);
|
||||
let (gq_loop, gn_loop) = (snap(wq0), snap(fnorm));
|
||||
zero(¶ms);
|
||||
inner_pg_step_batched(&policy, device, &batch, eps, beta, micro);
|
||||
let (gq_batch, gn_batch) = (snap(wq0), snap(fnorm));
|
||||
zero(¶ms);
|
||||
|
||||
let reldiff = |a: &[f32], b: &[f32]| -> f32 {
|
||||
let num = a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0f32, f32::max);
|
||||
let den = a.iter().map(|x| x.abs()).fold(0.0f32, f32::max).max(1e-12);
|
||||
num / den
|
||||
};
|
||||
let gq_rel = reldiff(&gq_loop, &gq_batch);
|
||||
let gn_rel = reldiff(&gn_loop, &gn_batch);
|
||||
|
||||
// Time only forward+backward — the lever. opt.step + grad-clip are identical in
|
||||
// both paths (one call over `params` after the per-sample loop), so they would
|
||||
// only add a constant; excluding them also dodges the unrelated 1B-Adam-state
|
||||
// memory wall (the M4 finding) that this diagnostic doesn't need to reproduce.
|
||||
let t_inner_loop = elapsed_ms(reps, || {
|
||||
inner_pg_step_looped(&policy, device, &batch, eps, beta);
|
||||
zero(¶ms);
|
||||
});
|
||||
let t_inner_batch = elapsed_ms(reps, || {
|
||||
inner_pg_step_batched(&policy, device, &batch, eps, beta, micro);
|
||||
zero(¶ms);
|
||||
});
|
||||
|
||||
// ---------------- Report ----------------
|
||||
let spd = |a: f32, b: f32| if b > 0.0 { a / b } else { 0.0 };
|
||||
println!("=== closeness gate (looped vs batched) ===");
|
||||
println!(" capture per_token_logp : max|Δ| = {cap_dmax:.3e}");
|
||||
println!(" inner grad wq[0] : rel|Δ| = {gq_rel:.3e}");
|
||||
println!(" inner grad final_norm : rel|Δ| = {gn_rel:.3e}");
|
||||
println!("\n=== timing (mean of {reps} reps, ms/phase) ===");
|
||||
println!(" capture : looped {t_cap_loop:8.1} batched {t_cap_batch:8.1} ({:.2}× )", spd(t_cap_loop, t_cap_batch));
|
||||
println!(" inner : looped {t_inner_loop:8.1} batched {t_inner_batch:8.1} ({:.2}× )", spd(t_inner_loop, t_inner_batch));
|
||||
let (step_loop, step_batch) = (t_cap_loop + t_inner_loop, t_cap_batch + t_inner_batch);
|
||||
println!(" STEP : looped {step_loop:8.1} batched {step_batch:8.1} ({:.2}× )", spd(step_loop, step_batch));
|
||||
|
||||
// The RIGOROUS correctness gates live in the test suite (exact, not bf16-noisy):
|
||||
// - xtrain-model forward_batched_ragged_matches_looped (forward+pad == looped)
|
||||
// - xtrain-autodiff clipped_pg_loss_batched_matches_looped (op == looped, f32)
|
||||
// This is a smoke check at the 1B/bf16 scale: single-seq vs batched GEMM differ in
|
||||
// batch-reduction order, so a loose band, with well-conditioned (argmax) targets.
|
||||
assert!(cap_dmax < 0.2, "capture closeness smoke FAILED: max|Δlogp| = {cap_dmax}");
|
||||
assert!(gq_rel < 0.2 && gn_rel < 0.2, "inner grad closeness smoke FAILED: wq {gq_rel}, fn {gn_rel}");
|
||||
println!("\nSMOKE PASS (bf16 band): batched ≈ looped; rigorous gates are the two tests above.");
|
||||
}
|
||||
@@ -23,15 +23,15 @@ fn main() {
|
||||
eprintln!("train_grpo: built without CUDA (no_cuda); run on a GPU host.");
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_autodiff::ops;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_cuda::device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_model::{Config, TinyTransformer, generate_cached, ids_tensor};
|
||||
use xtrain_model::{Config, TinyTransformer, generate_cached_batch};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_tensor::{DType, Device};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::grpo_batch::{PgSample, inner_pg_step_batched, per_token_logp_batched};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::task::{check_answer, gen_problem, GenConfig, Op};
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
@@ -117,20 +117,6 @@ fn frame(tok: &xserv_tokenizer::Tokenizer, question: &str, completion: &str) ->
|
||||
(tokens[..l - 1].to_vec(), labels[1..l].to_vec())
|
||||
}
|
||||
|
||||
/// Per-position logprob `logπ(target_t)` of a framed (input, target) pair (= −per_row
|
||||
/// of cross_entropy; masked positions are 0 and unused). No grad kept.
|
||||
#[cfg(not(no_cuda))]
|
||||
fn per_token_logp(model: &TinyTransformer, device: Device, input: &[i32], target: &[i32]) -> Vec<f32> {
|
||||
let logits = model.forward(&ids_tensor(input, device)).value();
|
||||
let (_, per_row) = logits.cross_entropy(&ids_tensor(target, device));
|
||||
per_row
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.iter()
|
||||
.map(|p| -p)
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
@@ -149,6 +135,9 @@ fn main() {
|
||||
let group: usize = flag(&args, "--group", 6);
|
||||
let n_prompts: usize = flag(&args, "--prompts", 8);
|
||||
let inner: usize = flag(&args, "--inner", 1);
|
||||
// M2d: pack the step's N=B·G ragged samples into forward_batched chunks of this
|
||||
// many samples (bounds the [chunk·Lmax, vocab] logits memory). Default = whole batch.
|
||||
let micro: usize = flag(&args, "--micro", n_prompts * group.max(1));
|
||||
let temp: f32 = flag(&args, "--temp", 1.0);
|
||||
let beta: f32 = flag(&args, "--beta", 0.04);
|
||||
let eps: f32 = flag(&args, "--eps", 0.2);
|
||||
@@ -188,16 +177,17 @@ fn main() {
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let (mut win_reward, mut win_solved, mut win_n) = (0f32, 0usize, 0usize);
|
||||
// Per-window phase timers (ms): rollout / capture / inner — to keep the step
|
||||
// decomposition honest (M2d cut the training-side forwards 9×, so the question is
|
||||
// what now dominates the step).
|
||||
let (mut t_roll, mut t_cap, mut t_inner) = (0f32, 0f32, 0f32);
|
||||
for step in 0..steps {
|
||||
// ---- Rollout: B prompts × G completions, scored, group-advantage ----
|
||||
struct Sample {
|
||||
input: Vec<i32>,
|
||||
target: Vec<i32>,
|
||||
adv: f32,
|
||||
logp_old: Vec<f32>,
|
||||
logp_ref: Vec<f32>,
|
||||
}
|
||||
let mut batch: Vec<Sample> = Vec::new();
|
||||
// Collect ALL the step's framed samples first (input, target, adv), so the
|
||||
// training-side forwards can be batched across the whole step (M2d) instead of
|
||||
// run one ragged sequence at a time.
|
||||
let t0 = std::time::Instant::now();
|
||||
let mut raw: Vec<(Vec<i32>, Vec<i32>, f32)> = Vec::new();
|
||||
for _ in 0..n_prompts {
|
||||
let p = gen_problem(&mut rng, &gcfg);
|
||||
let prompt_ids: Vec<i32> = tok
|
||||
@@ -205,12 +195,12 @@ fn main() {
|
||||
.into_iter()
|
||||
.map(|t| t as i32)
|
||||
.collect();
|
||||
// M2b batched rollout: the G samples of this prompt decode in lockstep
|
||||
// (one forward per step over the whole group → G× fewer kernel launches
|
||||
// than G sequential single-seq rollouts; the M4 rollout long-pole fix).
|
||||
let mut comps: Vec<(String, f32)> = Vec::with_capacity(group);
|
||||
for _ in 0..group {
|
||||
// KV-cache temperature rollout (M2 engine): single-row logits per
|
||||
// step → far lighter on the allocator than the naive sampler, which
|
||||
// fragments it over a long rollout (the M4 rollout long-pole).
|
||||
let out = generate_cached(&policy, device, &prompt_ids, max_new, temp, &mut rng);
|
||||
let outs = generate_cached_batch(&policy, device, &prompt_ids, group, max_new, temp, &mut rng);
|
||||
for out in &outs {
|
||||
let cont = tok.decode(&out[prompt_ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
|
||||
let seg = first_answer_segment(&cont).trim().to_string();
|
||||
let r = if check_answer(&seg, p.answer()) { 1.0 } else { 0.0 };
|
||||
@@ -230,53 +220,69 @@ fn main() {
|
||||
for (seg, r) in &comps {
|
||||
let adv = (r - mean) / (std + 1e-4);
|
||||
let (input, target) = frame(&tok, &p.question(), seg);
|
||||
let logp_old = per_token_logp(&policy, device, &input, &target);
|
||||
// β=0 ⇒ KL term drops ⇒ logp_ref unused; pass zeros (no reference model).
|
||||
let logp_ref = match &reference {
|
||||
Some(r) => per_token_logp(r, device, &input, &target),
|
||||
None => vec![0.0; logp_old.len()],
|
||||
};
|
||||
batch.push(Sample { input, target, adv, logp_old, logp_ref });
|
||||
raw.push((input, target, adv));
|
||||
}
|
||||
}
|
||||
|
||||
// ---- K inner clipped-PG epochs over the captured batch ----
|
||||
if !batch.is_empty() {
|
||||
let scale = 1.0 / batch.len() as f32;
|
||||
t_roll += t0.elapsed().as_secs_f32() * 1e3;
|
||||
|
||||
// ---- Batched capture (M2d): logπ_old (policy) + logπ_ref (frozen) over ALL
|
||||
// samples in forward_batched chunks, instead of one forward per sample. ----
|
||||
if !raw.is_empty() {
|
||||
let t1 = std::time::Instant::now();
|
||||
let io: Vec<(Vec<i32>, Vec<i32>)> = raw.iter().map(|(i, t, _)| (i.clone(), t.clone())).collect();
|
||||
let logp_old = per_token_logp_batched(&policy, device, &io, micro);
|
||||
// β=0 ⇒ KL term drops ⇒ logp_ref unused; pass zeros (no reference model).
|
||||
let logp_ref = match &reference {
|
||||
Some(r) => per_token_logp_batched(r, device, &io, micro),
|
||||
None => raw.iter().map(|(i, _, _)| vec![0.0; i.len()]).collect(),
|
||||
};
|
||||
let batch: Vec<PgSample> = raw
|
||||
.iter()
|
||||
.zip(logp_old)
|
||||
.zip(logp_ref)
|
||||
.map(|(((input, target, adv), lo), lr)| PgSample {
|
||||
input: input.clone(),
|
||||
target: target.clone(),
|
||||
adv: *adv,
|
||||
logp_old: lo,
|
||||
logp_ref: lr,
|
||||
})
|
||||
.collect();
|
||||
t_cap += t1.elapsed().as_secs_f32() * 1e3;
|
||||
|
||||
// ---- K inner clipped-PG epochs, batched over the captured samples ----
|
||||
let t2 = std::time::Instant::now();
|
||||
for _ in 0..inner {
|
||||
for s in &batch {
|
||||
let logits = policy.forward(&ids_tensor(&s.input, device));
|
||||
let loss = ops::clipped_pg_loss(
|
||||
&logits,
|
||||
&ids_tensor(&s.target, device),
|
||||
&s.logp_old,
|
||||
&s.logp_ref,
|
||||
s.adv,
|
||||
eps,
|
||||
beta,
|
||||
);
|
||||
ops::scale(&loss, scale).backward();
|
||||
}
|
||||
inner_pg_step_batched(&policy, device, &batch, eps, beta, micro);
|
||||
let _ = xtrain_train::clip::clip_grad_norm_gpu(¶ms, clip, 1.0);
|
||||
opt.step(lr, ¶ms);
|
||||
for p in ¶ms {
|
||||
p.zero_grad();
|
||||
}
|
||||
}
|
||||
t_inner += t2.elapsed().as_secs_f32() * 1e3;
|
||||
}
|
||||
|
||||
if (step + 1) % log_every == 0 || step == steps - 1 {
|
||||
let w = log_every.min(step + 1) as f32; // steps in this window
|
||||
println!(
|
||||
"step {:5}/{steps}: mean-reward {:.3} | solved {}/{} | {:.0}s",
|
||||
"step {:5}/{steps}: mean-reward {:.3} | solved {}/{} | {:.0}s | ms/step roll {:.0} cap {:.0} inner {:.0}",
|
||||
step + 1,
|
||||
win_reward / win_n.max(1) as f32,
|
||||
win_solved,
|
||||
win_n,
|
||||
start.elapsed().as_secs_f32(),
|
||||
t_roll / w,
|
||||
t_cap / w,
|
||||
t_inner / w,
|
||||
);
|
||||
win_reward = 0.0;
|
||||
win_solved = 0;
|
||||
win_n = 0;
|
||||
t_roll = 0.0;
|
||||
t_cap = 0.0;
|
||||
t_inner = 0.0;
|
||||
// Periodic save so a later OOM (naive rollout fragments the allocator —
|
||||
// the long-pole the design doc flagged) still leaves an evaluatable ckpt.
|
||||
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), ¶ms).expect("save");
|
||||
|
||||
162
crates/xtrain-train/src/grpo_batch.rs
Normal file
162
crates/xtrain-train/src/grpo_batch.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
//! Batched GRPO training-side forwards (post-training M2d). After M2b/M2c made the
|
||||
//! rollout cheap, the GRPO **step** is dominated by the per-sample full-sequence
|
||||
//! forwards: the `per_token_logp` captures (policy + reference) and the inner
|
||||
//! clipped-PG `forward`/`backward`s — each a single-sequence `forward` over a short
|
||||
//! ragged completion. This module packs the `N = B·G` ragged samples of a step into
|
||||
//! ONE `forward_batched`, amortising the per-launch overhead across N (the same win
|
||||
//! M2b gave the rollout).
|
||||
//!
|
||||
//! The enabling property: **right-padding is free under causal attention.** Pad each
|
||||
//! ragged completion on the RIGHT to the batch's `Lmax`; a real completion row is at
|
||||
//! an earlier position than the trailing pad, and causal masking forbids attending
|
||||
//! forward, so its logits are bit-identical to the unpadded single-sequence forward.
|
||||
//! The pad rows' own outputs are garbage but are masked out (`target = -100`).
|
||||
//!
|
||||
//! Both the looped (baseline) and batched paths live here so they share one source of
|
||||
//! truth — `bin/bench_grpo_batch` A/Bs them (timing + a closeness gate), and the
|
||||
//! per-row equivalence of the loss op is pinned by `clipped_pg_loss_batched_matches_looped`
|
||||
//! in `xtrain-autodiff/tests/autograd.rs`.
|
||||
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_autodiff::ops;
|
||||
use xtrain_model::{TinyTransformer, ids_tensor};
|
||||
use xtrain_tensor::{Device, Tensor};
|
||||
|
||||
/// One framed completion of a GRPO step.
///
/// Bundles the next-token `(input, target)` pair (prompt positions masked to `-100`
/// in `target`), the sample's group-relative `adv`, and the per-position rollout-time /
/// reference logprobs the clipped-PG loss needs.
#[derive(Debug, Clone, PartialEq)]
pub struct PgSample {
    /// Token ids fed to the model for this sample.
    pub input: Vec<i32>,
    /// Next-token targets; negative entries (`-100`) mark masked positions.
    pub target: Vec<i32>,
    /// Group-relative advantage shared by every position of this sample.
    pub adv: f32,
    /// Rollout-time policy logprobs, one entry per `input` position.
    pub logp_old: Vec<f32>,
    /// Reference-model logprobs, one entry per `input` position.
    pub logp_ref: Vec<f32>,
}
|
||||
|
||||
// ------------------------------- looped (baseline) -------------------------------
|
||||
|
||||
/// Per-position `logπ(target_t)` of one framed `(input, target)` pair (= `−per_row`
|
||||
/// of cross_entropy; masked positions are 0). One single-sequence forward, no grad.
|
||||
pub fn per_token_logp(model: &TinyTransformer, device: Device, input: &[i32], target: &[i32]) -> Vec<f32> {
|
||||
let logits = model.forward(&ids_tensor(input, device)).value();
|
||||
let (_, per_row) = logits.cross_entropy(&ids_tensor(target, device));
|
||||
per_row
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.iter()
|
||||
.map(|p| -p)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// One inner clipped-PG epoch the looped way: per sample, a single-sequence forward +
|
||||
/// [`ops::clipped_pg_loss`] scaled by `1/N` + backward (grads accumulate on `model`'s
|
||||
/// params). Returns the summed scaled loss. Caller does clip + opt.step + zero_grad.
|
||||
pub fn inner_pg_step_looped(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
batch: &[PgSample],
|
||||
eps: f32,
|
||||
beta: f32,
|
||||
) -> f32 {
|
||||
let scale = 1.0 / batch.len() as f32;
|
||||
let mut total = 0f32;
|
||||
for s in batch {
|
||||
let logits = model.forward(&ids_tensor(&s.input, device));
|
||||
let loss = ops::clipped_pg_loss(&logits, &ids_tensor(&s.target, device), &s.logp_old, &s.logp_ref, s.adv, eps, beta);
|
||||
let scaled = ops::scale(&loss, scale);
|
||||
total += scaled.value().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
scaled.backward();
|
||||
}
|
||||
total
|
||||
}
|
||||
|
||||
// ------------------------------- batched (M2d) -----------------------------------
|
||||
|
||||
/// Right-pad `m` ragged `i32` rows (each `≤ lmax` long) to `[m*lmax]` sequence-major,
/// filling with `pad`. Used for both the id stream (pad = 0, arbitrary) and the target
/// stream (pad = −100, ignored by cross_entropy).
///
/// # Panics
///
/// Panics if any row is longer than `lmax`. Without the check an over-long non-final
/// row would silently spill into the next row's slot instead of failing.
fn pack_i32(rows: &[&[i32]], lmax: usize, pad: i32) -> Vec<i32> {
    let mut flat = vec![pad; rows.len() * lmax];
    for (i, r) in rows.iter().enumerate() {
        assert!(
            r.len() <= lmax,
            "pack_i32: row {i} has length {} > lmax {lmax}",
            r.len()
        );
        let start = i * lmax;
        flat[start..start + r.len()].copy_from_slice(r);
    }
    flat
}
|
||||
|
||||
/// Batched [`per_token_logp`]: pack `samples` (each `(input, target)`) right-padded to
|
||||
/// `Lmax`, run ONE `forward_batched(batch = N)`, and slice each sample's `logπ` back to
|
||||
/// its real length. Equal to looping [`per_token_logp`] (right-pad is free under causal
|
||||
/// attention), to bf16 batch-reduction tolerance. `samples` are processed in chunks of
|
||||
/// `micro` (≥1) to bound the `[chunk*Lmax, vocab]` logits memory.
|
||||
pub fn per_token_logp_batched(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
samples: &[(Vec<i32>, Vec<i32>)],
|
||||
micro: usize,
|
||||
) -> Vec<Vec<f32>> {
|
||||
let mut out = Vec::with_capacity(samples.len());
|
||||
for chunk in samples.chunks(micro.max(1)) {
|
||||
let m = chunk.len();
|
||||
let lmax = chunk.iter().map(|(i, _)| i.len()).max().unwrap();
|
||||
let ins: Vec<&[i32]> = chunk.iter().map(|(i, _)| i.as_slice()).collect();
|
||||
let tgs: Vec<&[i32]> = chunk.iter().map(|(_, t)| t.as_slice()).collect();
|
||||
let ids = Tensor::from_slice(&pack_i32(&ins, lmax, 0), &[m * lmax]).to_device(device);
|
||||
let tgt = Tensor::from_slice(&pack_i32(&tgs, lmax, -100), &[m * lmax]).to_device(device);
|
||||
let logits = model.forward_batched(&ids, m).value();
|
||||
let (_, per_row) = logits.cross_entropy(&tgt);
|
||||
let pr = per_row.to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
for (i, (inp, _)) in chunk.iter().enumerate() {
|
||||
let b = i * lmax;
|
||||
out.push((0..inp.len()).map(|r| -pr[b + r]).collect());
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// One inner clipped-PG epoch, batched: pack the batch (in `micro`-sized chunks) and run
|
||||
/// ONE `forward_batched` + [`ops::clipped_pg_loss_batched`] + backward per chunk. The
|
||||
/// per-row `weight = 1/(N·n_s)` uses the GLOBAL `N = batch.len()` (not the chunk size),
|
||||
/// so chunked grad-accumulation reproduces the looped `Σ_s (1/N)(1/n_s)…` exactly.
|
||||
/// Returns the summed loss. Caller does clip + opt.step + zero_grad.
|
||||
pub fn inner_pg_step_batched(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
batch: &[PgSample],
|
||||
eps: f32,
|
||||
beta: f32,
|
||||
micro: usize,
|
||||
) -> f32 {
|
||||
let inv_n = 1.0 / batch.len() as f32;
|
||||
let mut total = 0f32;
|
||||
for chunk in batch.chunks(micro.max(1)) {
|
||||
let m = chunk.len();
|
||||
let lmax = chunk.iter().map(|s| s.input.len()).max().unwrap();
|
||||
let ins: Vec<&[i32]> = chunk.iter().map(|s| s.input.as_slice()).collect();
|
||||
let tgs: Vec<&[i32]> = chunk.iter().map(|s| s.target.as_slice()).collect();
|
||||
let ids = Tensor::from_slice(&pack_i32(&ins, lmax, 0), &[m * lmax]).to_device(device);
|
||||
let tgt = Tensor::from_slice(&pack_i32(&tgs, lmax, -100), &[m * lmax]).to_device(device);
|
||||
|
||||
let mut logp_old = vec![0f32; m * lmax];
|
||||
let mut logp_ref = vec![0f32; m * lmax];
|
||||
let mut advantage = vec![0f32; m * lmax];
|
||||
let mut weight = vec![0f32; m * lmax];
|
||||
for (i, s) in chunk.iter().enumerate() {
|
||||
let b = i * lmax;
|
||||
let li = s.input.len();
|
||||
logp_old[b..b + li].copy_from_slice(&s.logp_old);
|
||||
logp_ref[b..b + li].copy_from_slice(&s.logp_ref);
|
||||
let n_s = s.target.iter().filter(|&&t| t >= 0).count().max(1) as f32;
|
||||
let w = inv_n / n_s; // = 1/(N · n_s)
|
||||
for r in 0..lmax {
|
||||
advantage[b + r] = s.adv;
|
||||
weight[b + r] = w;
|
||||
}
|
||||
}
|
||||
let logits = model.forward_batched(&ids, m);
|
||||
let loss = ops::clipped_pg_loss_batched(&logits, &tgt, &logp_old, &logp_ref, &advantage, &weight, eps, beta);
|
||||
total += loss.value().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
loss.backward();
|
||||
}
|
||||
total
|
||||
}
|
||||
@@ -15,6 +15,8 @@ pub mod task;
|
||||
#[cfg(not(no_cuda))]
|
||||
pub mod checkpoint;
|
||||
#[cfg(not(no_cuda))]
|
||||
pub mod grpo_batch;
|
||||
#[cfg(not(no_cuda))]
|
||||
pub mod sample;
|
||||
#[cfg(not(no_cuda))]
|
||||
mod train_loop;
|
||||
|
||||
83
crates/xtrain-train/tests/decode_batch.rs
Normal file
83
crates/xtrain-train/tests/decode_batch.rs
Normal file
@@ -0,0 +1,83 @@
|
||||
// M2b batched KV-cache decode — the token-identical gate.
|
||||
//
|
||||
// Batched decode rolls out G samples of one prompt in lockstep (one common decode
|
||||
// position each step, uniform RoPE via rope_pos, KV cache carrying a G dimension).
|
||||
// Under GREEDY decoding all G rows are deterministic and must each equal the
|
||||
// single-sequence greedy decode (generate_greedy_cached, itself gated token-
|
||||
// identical to the naive sampler). This pins that the G-way batching indexes each
|
||||
// sequence's K/V correctly (no cross-row contamination) and reproduces M2a exactly.
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{generate_cached_batch, generate_greedy_cached, Config, TinyTransformer};
|
||||
use xtrain_tensor::{DType, Device};
|
||||
|
||||
/// Deterministic pseudo-random fill: `n` values derived from `seed` via a 64-bit LCG,
/// each mapped from the state's top 31 bits into roughly `[-scale, scale]`.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_ADD: u64 = 1442695040888963407;

    let mut state = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        state = state.wrapping_mul(LCG_MUL).wrapping_add(LCG_ADD);
        // Top 31 bits scaled to ~[0, 1], then recentred and stretched to ±scale.
        let unit = (state >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
|
||||
|
||||
fn build(cfg: Config, device: Device) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
})
|
||||
.with_compute_dtype(DType::F32)
|
||||
}
|
||||
|
||||
/// Gate: with temperature 0 (greedy), every row of the G-way batched decode must be
/// token-identical to the single-sequence greedy decode of the same prompt.
#[test]
fn batched_greedy_decode_matches_single_seq() {
    let n_devices = device::device_count().expect("device count");
    assert!(n_devices > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    // Real GQA (8 query / 2 kv heads → group 4) so repeat_kv(nh, batch=G) is exercised.
    let cfg = Config::from_arch(48, 8, 16, 4, 256).with_kv_heads(2);
    let model = build(cfg, device);
    let prompt: Vec<i32> = vec![3, 9, 1, 14, 5];
    let max_new = 24usize;
    let g = 5usize;

    // Reference: single-sequence greedy decode.
    let single = generate_greedy_cached(&model, device, &prompt, max_new);
    // Candidate: G rows decoded in lockstep at temperature 0.
    let mut rng = 0u64;
    let batched = generate_cached_batch(&model, device, &prompt, g, max_new, 0.0, &mut rng);

    assert_eq!(batched.len(), g, "expected {g} sample rows");
    for (row, seq) in batched.iter().enumerate() {
        assert_eq!(
            seq.len(),
            single.len(),
            "row {row} length {} vs single {}",
            seq.len(),
            single.len()
        );
        // Lengths match, so any inequality shows up as a first differing index.
        if let Some(first) = seq.iter().zip(&single).position(|(a, b)| a != b) {
            panic!(
                "batched row {row} diverges from single-seq at index {first}: {:?} vs {:?}",
                seq[first], single[first]
            );
        }
    }
    println!(
        "batched decode OK: all {g} greedy rows token-identical to single-seq over {max_new} tokens"
    );
}
|
||||
@@ -269,6 +269,52 @@ void launch_rope_at_f32(const float* x, float* y, int tokens, int heads,
|
||||
rope_at_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0);
|
||||
}
|
||||
|
||||
// RoPE with a PER-ROW absolute position (batched KV-cache decode, M2b): token row
// `blockIdx.x` is rotated by `positions[blockIdx.x]` (one i32 per token). For G-way
// batched decode all G rows share one decode position; for ragged batches each row
// carries its own. Forward only; the training rope_k is untouched.
// Layout: blockIdx.x = token, blockIdx.y = head, threadIdx.x = index of the rotated
// pair (element i is paired with element i + head_dim/2).
__global__ void rope_pos_k(const float* x, const int* positions, float* y,
                           int heads, int head_dim, float theta) {
    const int token = blockIdx.x;
    const int h = blockIdx.y;
    const int pair = threadIdx.x;
    const int half = head_dim / 2;
    if (pair < half) {
        const float inv_freq = powf(theta, -(float)(2 * pair) / (float)head_dim);
        const float angle = (float)positions[token] * inv_freq;
        const float c = cosf(angle);
        const float s = sinf(angle);
        const int lo = (token * heads + h) * head_dim + pair;
        const int hi = lo + half;
        // Load both halves before storing so both outputs use the original inputs.
        const float x_lo = x[lo];
        const float x_hi = x[hi];
        y[lo] = x_lo * c - x_hi * s;
        y[hi] = x_hi * c + x_lo * s;
    }
}
|
||||
// Host launcher for rope_pos_k: a (tokens × heads) grid with head_dim/2 threads per
// block (one thread per rotated pair), enqueued on stream `s`.
void launch_rope_pos_f32(const float* x, const int* positions, float* y,
                         int tokens, int heads, int head_dim, float theta, void* s) {
    const dim3 grid(tokens, heads);
    const int threads = head_dim / 2;
    cudaStream_t stream = (cudaStream_t)s;
    rope_pos_k<<<grid, threads, 0, stream>>>(x, positions, y, heads, head_dim, theta);
}
|
||||
|
||||
// Concatenate along the sequence (middle) dim: a:[bh,ta,hd], b:[bh,tb,hd] →
// out:[bh,ta+tb,hd] with out[:, :ta]=a, out[:, ta:]=b. This is the device-side
// KV-cache append (M2c): K/V stay on the GPU and grow by one token per step, with no
// host round-trip. One block per bh row; `ta_hd` / `tb_hd` are the flattened per-row
// element counts (ta*hd and tb*hd).
__global__ void cat_seq_k(const float* a, const float* b, float* out,
                          int ta_hd, int tb_hd) {
    const int row = blockIdx.x;  // bh row
    const int out_hd = ta_hd + tb_hd;
    const float* src_a = a + (long)row * ta_hd;
    const float* src_b = b + (long)row * tb_hd;
    float* dst = out + (long)row * out_hd;
    // Block-stride copy over the output row: the first ta_hd elements come from `a`,
    // the remainder from `b`.
    for (int j = threadIdx.x; j < out_hd; j += blockDim.x) {
        dst[j] = (j < ta_hd) ? src_a[j] : src_b[j - ta_hd];
    }
}
|
||||
// Host launcher for cat_seq_k: one 256-thread block per bh row, enqueued on stream `s`.
void launch_cat_seq_f32(const float* a, const float* b, float* out,
                        int bh, int ta_hd, int tb_hd, void* s) {
    const int threads = 256;
    cudaStream_t stream = (cudaStream_t)s;
    cat_seq_k<<<bh, threads, 0, stream>>>(a, b, out, ta_hd, tb_hd);
}
|
||||
|
||||
// Per-row scale: y[r,c] = x[r,c] * s[r]. One block per row. Used by the GRPO
|
||||
// (M4) policy-gradient backward, where each completion token's row of
|
||||
// (probs − onehot) is scaled by its own per-token coefficient.
|
||||
|
||||
@@ -522,3 +522,108 @@ leash wired, format held); the held-out flatness + the two memory/throughput wal
|
||||
reported findings. The honest end-state of the post-training arc: **a complete, correctness-gated
|
||||
SFT → KV-cache → DPO → GRPO stack** — the infrastructure learned in full, with measured, honest
|
||||
limits on what alignment can do for a capability the base model lacks.
|
||||
|
||||
### M2b — batched KV-cache decode (landed; completes the M2 engine, fixes the rollout long-pole)
|
||||
|
||||
Built after M4 (where the rollout long-pole bit hardest): decode the **G samples of one prompt in
|
||||
lockstep** — one forward per step over the whole group → G× fewer kernel launches, the deferred
|
||||
fix from M2a.
|
||||
|
||||
**One new primitive:** `rope_pos(x, positions[])` — RoPE with a *per-row* absolute position (new
|
||||
forward-only kernel), since the G batched rows share one decode position (M2a's `rope_at` does
|
||||
`pos0 + row`, wrong for a batch at a single position). **Gate:** bit-identical to the full rope
|
||||
for positions `[0..n]`, and to `rope_at(P)` per row for a uniform `P`.
|
||||
|
||||
**Engine (`generate_cached_batch`):** `BatchKVCache` carries a G dimension (`[T, G·num_kv, hd]`
|
||||
host-accumulated → `[G·num_kv, T, hd]`); the batched `decode_step` threads G through embed /
|
||||
projections / QK-norm / `rope_pos` / cache. Two M2a pieces drop in unchanged: `decode_attention`
|
||||
is already batch-agnostic (`bh = G·nh`), and `repeat_kv(nh, batch=G)` broadcasts per group. No
|
||||
finished-mask (all G generate `max_new`; the caller cuts at EOS) and no ragged-length prompts yet
|
||||
— both perf-only follow-ups.
|
||||
|
||||
**Gate (token-identical):** all G **greedy** rows are byte-identical to the single-sequence decode
|
||||
(`tests/decode_batch.rs`, 8 query / 2 kv heads → exercises the `repeat_kv` batching) — pins that
|
||||
G-way batching indexes each sequence's K/V with no cross-row contamination.
|
||||
|
||||
**Throughput (v12 1.05B, G=6·B=6, easy task, rollout wired into `train_grpo`):** ~8.5 s/step vs
|
||||
~14–16 s/step for the single-seq cached rollout — **~1.7×**, rollout-inclusive. Short of the full
|
||||
G× because (a) the per-token-logp forwards + the PG update also cost, and (b) the M2a per-layer
|
||||
**host round-trip** is still there (now G× the data in one transfer, not removed). The full
|
||||
device-side cache (no host round-trip) is the remaining decode-engine optimization. Batching also
|
||||
**stabilises memory**: one batched forward per step vs G separate allocations that fragmented the
|
||||
caching allocator (the M4 OOM). So M2b closes the decode-engine milestone (M2a single-seq + M2b
|
||||
batched) and turns the rollout long-pole from "OOM/unbounded" into a bounded ~1.7× win — measured,
|
||||
with the device-cache as the named next lever.
|
||||
|
||||
### M2c — device-side KV cache (landed; the bottleneck moved, a profile-first finding)
|
||||
|
||||
The named M2b follow-up: keep K/V on the GPU (`[bh,T,hd]`, an `Option<Tensor>` per layer) and
|
||||
grow it by one token per step via a new `cat_seq` kernel (concat along the seq dim) — removing the
|
||||
M2a/M2b per-layer **host round-trip** (`to_cpu`/`from_slice`/re-upload) *and* the `transpose_3d01`.
|
||||
Both single-seq and batched decode refactored to it (cleaner than the host `Vec` + rebuild).
|
||||
|
||||
**Gates hold:** `cat_seq == host concat`; `decode_kv` single-seq + `decode_batch` G-way both still
|
||||
**token-identical**; GQA training path unaffected.
|
||||
|
||||
**The finding (why this is a measure-first lesson, not a speedup story):** removing the host
|
||||
round-trip buys **~10%** on *pure* single-seq decode (133 → 147 tok/s @128) but **does not move the
|
||||
GRPO step** (~8.5 s/step, unchanged). Because after M2b batching, the rollout is no longer the
|
||||
step's bottleneck — the per-sample **`per_token_logp` captures** (2 forwards/sample) and the
|
||||
**PG-update** forwards+backwards (`model.forward`, full-sequence, per sample) now dominate. So the
|
||||
long pole **shifted** from the rollout to the training-side forwards (cf. T11/T17/M2a: profile
|
||||
before optimizing — the bottleneck you fixed is not the one that remains). The device cache is
|
||||
still a real, correctness-gated improvement (cleaner code, less PCIe, ~10% decode); the honest
|
||||
headline is that the *next* decode lever is **ragged batched prefill of the per-sample forwards**,
|
||||
not the cache. The M2 decode engine is now M2a (single-seq) + M2b (batched) + M2c (device cache),
|
||||
all token-identical-gated; the post-training stack remains complete with its bottleneck mapped.
|
||||
|
||||
### M2d — batch the GRPO training-side forwards (landed; the lever M2c named, + a decomposition correction)
|
||||
|
||||
M2c named the next lever: **ragged batched prefill of the per-sample training-side forwards**. Those
|
||||
forwards are the two phases that, per step, run one single-sequence `forward` per sample: the
|
||||
`per_token_logp` **captures** (logπ_old policy + logπ_ref reference) and the inner **clipped-PG**
|
||||
forward/backwards. M2d packs all `N = B·G` ragged samples of a step into ONE `forward_batched`.
|
||||
|
||||
**The enabling property — right-padding is free under causal attention.** Pad each ragged completion
|
||||
on the RIGHT to the batch's `Lmax`. A real completion row sits at an earlier position than the
|
||||
trailing pad, and causal masking forbids attending forward, so its logits are **bit-identical** to
|
||||
the unpadded single-sequence forward; the pad rows are garbage but masked out (`target = -100`). This
|
||||
is exactly why training engines pad-and-mask rather than run ragged. Two new pieces:
|
||||
- `per_token_logp_batched` (`crates/xtrain-train/src/grpo_batch.rs`): right-pad → one
|
||||
`forward_batched(batch = N)` → slice each sample's logπ back to its real length.
|
||||
- `ops::clipped_pg_loss_batched` (`crates/xtrain-autodiff/src/ops.rs`): like the per-sample
|
||||
`clipped_pg_loss`, but takes **per-row** `advantage[t]` (the owning sample's `A`) and **per-row**
|
||||
`weight[t]` (the full normaliser; the caller passes `1/(N·n_s)`). It does NOT compute its own
|
||||
`1/n_tokens`, so folding `weight = 1/(N·n_s)` reproduces the looped `Σ_s (1/N)(1/n_s)…`
|
||||
**to f32 rounding** (the per-row CE backward is row-local; the gate below measures loss Δ=1.5e-8). A `--micro` knob packs in chunks to bound
|
||||
the `[chunk·Lmax, vocab]` logits memory; the weight uses the GLOBAL `N`, so chunked
|
||||
grad-accumulation is exact. Both `train_grpo` and the bench call these shared helpers.
|
||||
|
||||
**Correctness gates (exact, not bf16-noisy):**
|
||||
- `xtrain-model::forward_batched_ragged_matches_looped` — forward_batched on right-padded ragged
|
||||
sequences == per-sequence single-seq forward on the real rows, **max|Δlogit| = 3.7e-7 (fp32) and
|
||||
0.0 (bf16)**, both composed + flash. Pins "right-pad is free".
|
||||
- `xtrain-autodiff::clipped_pg_loss_batched_matches_looped` — batched op == looped
|
||||
`Σ_s (1/N)·clipped_pg_loss_s`, **loss Δ=1.5e-8, grad max|Δ|=7.5e-9 (f32)**.
|
||||
Composed, these prove the batched GRPO step == the looped step. End-to-end: a short SFT (v12 base,
|
||||
150 steps, arith) → `train_grpo` 12 steps runs clean — **no OOM** (1B master + AdamW + batched
|
||||
activations fit with `micro=16`), mean-reward rises, the batched inner executes.
|
||||
|
||||
**Throughput (bench `bin/bench_grpo_batch`, v12 1.05B, N=48 ragged, micro=16, β=0, weight-independent):**
|
||||
|
||||
| phase (per step) | looped (single-seq) | batched (M2d) | speedup |
|
||||
|-------------------------|---------------------|---------------|---------|
|
||||
| capture `per_token_logp`| 622 ms | 71 ms | 8.7× |
|
||||
| inner clipped-PG fwd+bwd| 1907 ms | 208 ms | 9.2× |
|
||||
| **training forwards** | **2526 ms** | **280 ms** | **9.0×**|
|
||||
|
||||
**The decomposition correction (the honest finding).** M2c claimed "the per-sample training forwards
|
||||
now dominate the step." The clean per-component bench falsifies the strong form: the training
|
||||
forwards were **~2.5 s of the ~8.5 s step (~30%)** — substantial and worth the 9× win, but the
|
||||
**rollout (`generate_cached_batch`, ~6 s) was always the larger share.** After M2d cuts the training
|
||||
forwards to ~0.28 s, the step is **~95% rollout** — the long pole has swung back to the rollout. So
|
||||
M2d removes the training-forward overhang (a real, exactly-gated 9× on its component), and re-confirms
|
||||
the same measure-first lesson one more time: the next **step-level** lever is **full B×G rollout
|
||||
batching** — today only the `G` samples of each prompt decode in lockstep (M2b); the `B` prompts are
|
||||
still sequential. M2d closes the "ragged batched per-sample forwards" lever M2c named; the post-
|
||||
training stack stays complete, now with the step decomposition measured, not asserted.
|
||||
|
||||
@@ -103,6 +103,12 @@ Phase 1/2 把**预训练全栈**学完后,Phase 3 转向**后训练 infra**(
|
||||
|
||||
**M4(GRPO,在线 critic-free RL,已落地 + 两道诚实系统墙 + 一致负结果)**:新算子 `clipped_pg_loss`(per-token ρ + clip + k3 KL,反向用新增 `scale_rows` per-row 缩放 kernel;grad-check active+A=0 路径 + 退化 ε→∞ vanilla/β=0 无KL)。环 `train_grpo`:采 B prompt × rollout G → checker reward 0/1 → group-relative advantage `(r−mean)/(std+ε)`(无 critic,全对/全错组跳过)→ 存 πθ_old/πref per-token → K 内层 clipped-PG。rollout 用 **M2 引擎 + 新加的 temperature 采样**(单行 logits 比 naive `[seq,vocab]` 轻)。**先把任务改简单**:v12 SFT 在硬/易题都 ~8-9%(只会格式不会算术)→ 在 easy(操作数≤20)上从 v12 base 重训 SFT → held-out **18.7%**;但 250/600 步同样 18.7% = 1B web-text 模型从 ~550 例**不泛化加减法、只记 train**。**两道系统墙(设计文档 Risks 预言)**:① 显存——KL-leash 要 policy+reference 两个 1B fp32-master+Adam≈21GB,加激活在 32GB 5090 上不稳定 OOM → 只能 `β=0`(去掉 reference)跑完;② rollout 长杆——naive 采样增长序列撑碎 allocator,cached 采样更轻但单序列慢仍主导墙钟(~16s/step)。**结果**(easy, β=0, G6·B6, 40步, lr5e-7;150 留出 vs SFT 18.7%):reward 噪声 ~0.58-0.81(被 train 重叠抬),**format 100/100 不崩**(温和 lr 下 β=0 也没崩),**held-out 20.0%**(+1.3pp,~3% 标准误内 = 统计持平)。**M3+M4 一致教训**:模型缺底层能力时,离线偏好(DPO)和在线 RL(GRPO)**都不抬 held-out**——各自在能触及的训练分布上优化目标(被记忆抬高),装不进可泛化算法;**RL 强化模型已会的,不教算术**。**后训练弧诚实终态 = 一套完整、闸门齐全的 SFT → KV-cache → DPO → GRPO 栈**,infra 学全,并测得对齐对"base 缺失能力"能做什么的诚实边界。
|
||||
|
||||
**M2b(批量 KV-cache 解码,已落地,补全 M2 引擎 + 修 rollout 长杆)**:M4 后补的 rollout 长杆修复——一个 prompt 的 **G 个样本同步解码**(每步一次 forward 跑整组 → G× 更少 kernel 启动)。一个新原语 `rope_pos`(逐 row 绝对位置 kernel,G 行共享一个解码位置;闸门 = `[0..n]` 逐位等于全 rope、统一 P 逐行等于 `rope_at(P)`,bit-identical)。引擎 `generate_cached_batch`:`BatchKVCache` 带 G 维,批量 `decode_step` 把 G 贯穿 embed/proj/QK-norm/`rope_pos`/cache;**M2a 两件零改动复用**——`decode_attention` 本就 batch-agnostic(bh=G·nh)、`repeat_kv(nh,batch=G)` 按组广播。闸门 = G 个贪心行逐字节等于单序列(`tests/decode_batch.rs`,8q/2kv 头练 repeat_kv 批量)。**吞吐**(v12, G6·B6, 接进 train_grpo):**~8.5s/step vs 单序列 ~14-16s/step ≈ 1.7×**(rollout-inclusive;未到满 G× 因 per_token_logp + PG 更新也占时间、M2a 主机往返还在);且**显存更稳**(一次批量 forward vs G 次分配撑碎 allocator 的 M4 OOM)。⇒ M2 引擎闭环(M2a 单序列 + M2b 批量),rollout 长杆从"OOM/无界"变成有界 ~1.7× 收益,device 端 cache 是点名的下一杠杆。
|
||||
|
||||
**M2c(device 端 KV cache,已落地,瓶颈转移的 profile-first 发现)**:K/V 留 device 为 `[bh,T,hd]`(每层 `Option<Tensor>`),每步用新 `cat_seq` kernel(沿 seq 拼接)append 一个 token——去掉 M2a/M2b 每层**主机往返** + `transpose_3d01`,单序列和批量都重构到它(比 host Vec+rebuild 干净)。闸门全保:`cat_seq`==host concat、decode_kv 单序列 + decode_batch 批量仍 **token-identical**、GQA 训练路径不受影响。**发现(measure-first 的点,不是加速故事)**:去掉主机往返让**纯单序列解码 +10%**(133→147 tok/s@128),但 **GRPO step 不动**(~8.5s/step)——因为 M2b 批量化后 rollout 已不是 step 瓶颈,**per-sample `per_token_logp` 捕获(2×/样本)+ PG 更新 forward/backward(全序列 `model.forward`)成了主导**。长杆从 rollout **转移**到训练侧 forward(同 T11/T17/M2a:profile 后再动手——你修的不是剩下的瓶颈)。device cache 仍是真实、闸门齐全的改进(更干净、少 PCIe、解码 +10%),但下一杠杆是 **per-sample forward 的 ragged 批量**而非 cache。M2 引擎现 = M2a(单序列)+ M2b(批量)+ M2c(device cache),全 token-identical-gated;后训练栈完整、瓶颈已测绘。
|
||||
|
||||
**M2d(批量 GRPO 训练侧 forward,已落地,M2c 点名的杠杆 + 一处 decomposition 纠正)**:M2c 点名的下一杠杆——把每步 `N=B·G` 个 ragged 样本的训练侧 forward(`per_token_logp` 捕获 + inner clipped-PG fwd/bwd)打包进**一次 `forward_batched`**。**使能性质 = causal 下右 padding 免费**:真 completion 行位置早于尾部 pad,causal 禁止前向 attend,故真行 logits 与单序列 forward **逐位相同**,pad 行垃圾被 `target=-100` 屏蔽——这正是训练引擎 pad-and-mask 而非跑 ragged 的原因。两件新东西:`per_token_logp_batched`(右 pad → 一次 `forward_batched(N)` → 按真长切片)、`ops::clipped_pg_loss_batched`(per-row `advantage[t]` + per-row `weight[t]`,caller 传 `1/(N·n_s)`,op 不再自算 `1/n_tokens` → 折进 weight 即与 looped `Σ_s (1/N)(1/n_s)…` **逐位等价**;`--micro` 分块界定 `[chunk·Lmax,vocab]` logits 显存,weight 用全局 N 故分块梯度累积精确)。**两道精确闸门**:`forward_batched_ragged_matches_looped`(右 pad 批量 forward == 单序列,fp32 max|Δ|=3.7e-7、bf16 **0.0**,composed+flash)+ `clipped_pg_loss_batched_matches_looped`(批量 op == looped,loss Δ=1.5e-8/grad 7.5e-9,f32),复合即证端到端等价;端到端短 SFT→`train_grpo` 12 步**不 OOM**(1B master+AdamW+批量激活 micro=16 容得下)、批量 inner 执行。**吞吐(bench,v12 1.05B,N=48,micro16,权重无关)**:capture 622→71ms(8.7×)、inner 1907→208ms(9.2×)、**训练侧 forward 合计 2526→280ms(9.0×)**。**Decomposition 纠正(诚实发现)**:M2c 说"训练侧 forward 主导 step",干净分量 bench 证伪强形式——训练侧 forward 是 **~8.5s step 里的 ~2.5s(~30%)**,可观、值这 9×,但 **rollout(`generate_cached_batch` ~6s)一直是更大头**;M2d 把训练侧砍到 ~0.28s 后,step **~95% 是 rollout**,长杆又摆回 rollout。⇒ M2d 拔掉训练侧 forward 这块 overhang(分量级精确 9×),再次印证 measure-first:**step 级下一杠杆 = 全 B×G rollout 批量**(今天只有每 prompt 的 G 同步、B 个 prompt 仍串行)。后训练栈保持完整,step decomposition 现为**实测**而非断言。
|
||||
|
||||
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md))
|
||||
|
||||
- **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。
|
||||
|
||||
Reference in New Issue
Block a user