test: M2d — ragged-forward + batched-op equivalence gates + throughput bench
Two exact correctness gates (composed = the end-to-end batched GRPO step == looped): - xtrain-model forward_batched_ragged_matches_looped: forward_batched on RIGHT-padded ragged sequences == per-sequence single-seq forward on the real rows. fp32 max|Δlogit| = 3.7e-7, bf16 = 0.0, both composed + flash SDPA. Pins "right-pad is free under causal". - xtrain-autodiff clipped_pg_loss_batched_matches_looped: batched op == looped Σ_s (1/N)·clipped_pg_loss_s. loss Δ=1.5e-8, grad max|Δ|=7.5e-9 (f32). bench_grpo_batch: weight-independent micro-bench of the per-sample training forwards (loads v12 base as policy, N realistic ragged samples, teacher-forced argmax targets so the closeness smoke isn't −log-amplified by random low-prob tokens). Measured on dash5 (v12 1.05B, N=48, micro=16): capture 622→71 ms (8.7×), inner 1907→208 ms (9.2×), training forwards 2526→280 ms (9.0×). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -1177,3 +1177,94 @@ fn clipped_pg_loss_bwd_and_degenerate() {
|
||||
assert!((gotb - wantb).abs() < 1e-5, "β=0 loss mismatch: {gotb} vs {wantb}");
|
||||
println!("clipped_pg_loss OK: grad-check (active + A=0) + degenerate (ε→∞ vanilla, β=0 no KL)");
|
||||
}
|
||||
|
||||
// clipped_pg_loss_batched (M2d) equivalence gate. N ragged completions are packed and
// right-padded into ONE [rows, cols] logits matrix; the batched op, fed a per-row
// weight of 1/(N·n_s), must reproduce the looped reference Σ_s (1/N)·clipped_pg_loss_s
// in both the scalar loss and the gradient w.r.t. the logits. Agreement is asserted
// to 1e-5 absolute on the f32 path.
#[test]
fn clipped_pg_loss_batched_matches_looped() {
    require_gpu();
    let (n, lmax, cols) = (3usize, 5usize, 10usize);
    let rows = n * lmax;
    let logits_h = fill(rows * cols, 909);

    // Target layout per sample: row 0 is the prompt (-100), rows 1..real_len carry a
    // completion token, and everything after that is padding (-100). The differing
    // real_len values give n_s = {2, 3, 1} completion rows.
    let real_len = [3usize, 4, 2];
    let adv_s = [0.7f32, -0.5, 0.3];
    let mut targets = vec![-100i32; rows];
    for (s, &len) in real_len.iter().enumerate() {
        let first = s * lmax;
        for t in first + 1..first + len {
            targets[t] = ((t * 3) % cols) as i32;
        }
    }
    let mk_target = || Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));

    // logp_old is the negated per-row cross-entropy at the base logits (so ρ≈1);
    // logp_ref is shifted by 0.3 so that the KL term is actually exercised.
    let (_, ce_rows) = cuda(&logits_h, &[rows, cols]).cross_entropy(&mk_target());
    let ce_host = ce_rows.to_device(Device::Cpu);
    let logp_old: Vec<f32> = ce_host.as_slice::<f32>().iter().map(|ce| -ce).collect();
    let logp_ref: Vec<f32> = logp_old.iter().map(|&lp| lp - 0.3).collect();
    let (eps, beta) = (0.2f32, 0.1f32);

    // Row-wise advantage (the owning sample's A) and row-wise weight 1/(N·n_s), where
    // n_s counts the sample's rows whose target is not masked (>= 0).
    let mut advantage = vec![0f32; rows];
    let mut weight = vec![0f32; rows];
    for s in 0..n {
        let span = s * lmax..(s + 1) * lmax;
        let n_s = targets[span.clone()].iter().filter(|&&t| t >= 0).count() as f32;
        let w = (1.0 / n as f32) * (1.0 / n_s);
        advantage[span.clone()].fill(adv_s[s]);
        weight[span].fill(w);
    }

    // Batched path: a single packed forward over all rows followed by one backward.
    let xb = Var::leaf(cuda(&logits_h, &[rows, cols]));
    let lb = ops::clipped_pg_loss_batched(
        &xb,
        &mk_target(),
        &logp_old,
        &logp_ref,
        &advantage,
        &weight,
        eps,
        beta,
    );
    lb.backward();
    let gb = xb.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>().to_vec();
    let lb_val = lb.value().to_device(Device::Cpu).as_slice::<f32>()[0];

    // Looped reference: slice out each sample, run clipped_pg_loss on it, scale the
    // result by 1/N, backprop, and stitch the per-sample gradients back together.
    let mut g_ref = vec![0f32; rows * cols];
    let mut loss_ref = 0f32;
    for (s, &adv) in adv_s.iter().enumerate() {
        let row_span = s * lmax..(s + 1) * lmax;
        let elem_span = row_span.start * cols..row_span.end * cols;

        let xs_h = logits_h[elem_span.clone()].to_vec();
        let tgt_s: Vec<i32> = targets[row_span.clone()].to_vec();
        let lo_s = logp_old[row_span.clone()].to_vec();
        let lr_s = logp_ref[row_span].to_vec();

        let xs = Var::leaf(cuda(&xs_h, &[lmax, cols]));
        let tgt = Tensor::from_slice(&tgt_s, &[lmax]).to_device(Device::Cuda(0));
        let ls = ops::clipped_pg_loss(&xs, &tgt, &lo_s, &lr_s, adv, eps, beta);
        let scaled = ops::scale(&ls, 1.0 / n as f32);
        scaled.backward();

        let gs = xs.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>().to_vec();
        g_ref[elem_span].copy_from_slice(&gs);
        loss_ref += scaled.value().to_device(Device::Cpu).as_slice::<f32>()[0];
    }

    // Largest element-wise gradient disagreement between the two paths.
    let max_g = gb
        .iter()
        .zip(&g_ref)
        .fold(0.0f32, |worst, (a, b)| worst.max((a - b).abs()));
    assert!(
        (lb_val - loss_ref).abs() < 1e-5,
        "batched loss {lb_val} vs looped {loss_ref}"
    );
    assert!(max_g < 1e-5, "batched grad vs looped: max|Δ| = {max_g}");
    println!(
        "clipped_pg_loss_batched OK: loss Δ={:.2e}, grad max|Δ|={:.2e} (== looped Σ_s 1/N·pg_s)",
        (lb_val - loss_ref).abs(),
        max_g
    );
}
|
||||
|
||||
97
crates/xtrain-model/tests/ragged_batch.rs
Normal file
97
crates/xtrain-model/tests/ragged_batch.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
// M2d gate: does forward_batched on RIGHT-PADDED ragged sequences reproduce the
// per-sequence single-seq forward on the real (non-pad) rows? The batched GRPO
// training-side forwards depend on this "right-pad is free under causal attention"
// property — a real completion row is at an earlier position than the trailing pad,
// and causal masking forbids attending forward, so its logits should be unchanged.
//
// Tested in fp32 (tolerance 2e-3) and bf16 (tolerance 3e-1) over both SDPA cores
// (composed + fused flash), since the bench uses flash and a kernel could in
// principle leak the pad keys into the online softmax.
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, ids_tensor};
|
||||
use xtrain_tensor::{DType, Device, Tensor};
|
||||
|
||||
/// Deterministic pseudo-random fill: returns `n` values in `[-scale, scale]`.
///
/// The generator is a 64-bit linear-congruential sequence whose starting state is
/// derived from `seed`; each output takes the top 31 bits of the state, maps them to
/// the unit interval, then recentres and rescales. The same `(n, seed, scale)` always
/// yields the same vector.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    let mut state = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493);
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        // Top 31 bits → [0, 1], then shift/scale to [-scale, scale].
        let unit = (state >> 33) as f32 / (1u64 << 31) as f32;
        out.push((unit - 0.5) * 2.0 * scale);
    }
    out
}
|
||||
|
||||
fn build(cfg: Config, device: Device, dtype: DType, flash: bool) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
let m = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
});
|
||||
m.with_compute_dtype(dtype).with_flash(flash)
|
||||
}
|
||||
|
||||
fn host(t: &Tensor) -> Vec<f32> {
|
||||
t.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec()
|
||||
}
|
||||
|
||||
/// Gate: logits of the real (non-pad) rows from one right-padded `forward_batched`
/// call must match the logits from running every sequence alone through `forward`.
/// Checked for F32 and BF16, each with the composed and the flash SDPA core; skipped
/// (with a message) when no CUDA device is present.
#[test]
fn forward_batched_ragged_matches_looped() {
    if device::device_count().unwrap_or(0) == 0 {
        eprintln!("no CUDA device; skipping");
        return;
    }
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let mut cfg = Config::tiny();
    cfg.vocab = 32;
    cfg.n_layers = 2;
    let vocab = cfg.vocab;

    // Ragged lengths incl. one crossing the flash tile (>32) and short ones.
    let lens = [6usize, 40, 9, 4];
    let lmax = *lens.iter().max().unwrap();
    let n = lens.len();
    let seqs: Vec<Vec<i32>> = lens
        .iter()
        .enumerate()
        .map(|(b, &l)| (0..l).map(|i| ((b * 7 + i * 3 + 1) % vocab) as i32).collect())
        .collect();

    // Right-padded id matrix (pad id 0), one row of `lmax` ids per sequence. It does
    // not depend on dtype or SDPA core, so it is assembled once up front.
    let mut flat = vec![0i32; n * lmax];
    for (row, s) in flat.chunks_mut(lmax).zip(&seqs) {
        row[..s.len()].copy_from_slice(s);
    }

    for (dtype, tol) in [(DType::F32, 2e-3f32), (DType::BF16, 3e-1f32)] {
        for flash in [false, true] {
            let m = build(cfg, device, dtype, flash);

            // Ground truth: each sequence forwarded by itself.
            let looped: Vec<Vec<f32>> = seqs
                .iter()
                .map(|s| host(&m.forward(&ids_tensor(s, device)).value()))
                .collect();

            // Candidate: a single forward_batched(batch = n) over the padded matrix.
            let ids = Tensor::from_slice(&flat, &[n * lmax]).to_device(device);
            let batched = host(&m.forward_batched(&ids, n).value()); // [n*lmax, vocab]

            // Worst absolute logit difference over the real rows only; the real rows
            // of sequence i are contiguous, starting at row i*lmax of the batch.
            let mut dmax = 0f32;
            for (i, (s, want)) in seqs.iter().zip(&looped).enumerate() {
                let real = s.len() * vocab;
                let start = i * lmax * vocab;
                let got = &batched[start..start + real];
                for (a, b) in want[..real].iter().zip(got) {
                    dmax = dmax.max((a - b).abs());
                }
            }
            println!("dtype={dtype:?} flash={flash}: ragged right-pad vs looped, max|Δlogit| (real rows) = {dmax:.3e}");
            assert!(dmax < tol, "dtype={dtype:?} flash={flash}: right-pad NOT free under causal — max|Δ| = {dmax}");
        }
    }
    println!("forward_batched_ragged_matches_looped OK: right-pad is free under causal (fp32+bf16, composed + flash)");
}
|
||||
268
crates/xtrain-train/src/bin/bench_grpo_batch.rs
Normal file
268
crates/xtrain-train/src/bin/bench_grpo_batch.rs
Normal file
@@ -0,0 +1,268 @@
|
||||
//! Micro-benchmark + closeness gate for the M2d batched GRPO training-side forwards.
|
||||
//!
|
||||
//! After M2b/M2c the GRPO *step* is no longer rollout-bound — it is the `N = B·G`
|
||||
//! per-sample full-sequence forwards (the `per_token_logp` captures + the inner
|
||||
//! clipped-PG forward/backwards). This bin isolates exactly that, weight-independently
|
||||
//! (step wall-clock depends on shapes + launch counts, not on what the weights are), by
|
||||
//! synthesising `N` realistic ragged samples and A/B-timing the looped vs batched path
|
||||
//! for BOTH phases — plus asserting they agree numerically (the looped-vs-batched
|
||||
//! closeness gate; equivalence of the loss op itself, to 1e-5 in f32, is pinned by the
//! autograd test `clipped_pg_loss_batched_matches_looped`).
|
||||
//!
|
||||
//! bench_grpo_batch <tokenizer.json> --init-ckpt <base.ckpt> <arch flags> \
|
||||
//! --n 48 --plen 12 --clen 24 --micro 16 --reps 3
|
||||
|
||||
/// Stub entry point for builds configured with `no_cuda`: prints a notice to stderr
/// and exits without doing any work.
#[cfg(no_cuda)]
fn main() {
    let notice = "bench_grpo_batch: built without CUDA (no_cuda); run on a GPU host.";
    eprintln!("{notice}");
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_cuda::device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_model::{Config, TinyTransformer};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_tensor::{DType, Device, Tensor};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::grpo_batch::{PgSample, inner_pg_step_batched, inner_pg_step_looped, per_token_logp, per_token_logp_batched};
|
||||
|
||||
/// Deterministic pseudo-random fill: `n` values in `[-scale, scale]`, produced by a
/// 64-bit linear-congruential generator whose initial state is derived from `seed`.
/// Identical arguments always produce an identical vector.
#[cfg(not(no_cuda))]
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    let mut state = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493);
    let draw = move || {
        state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        // Top 31 bits of the state mapped to [0, 1], then recentred and rescaled.
        let unit = (state >> 33) as f32 / (1u64 << 31) as f32;
        (unit - 0.5) * 2.0 * scale
    };
    std::iter::repeat_with(draw).take(n).collect()
}
|
||||
|
||||
/// Looks up `name` in `args` and parses the argument that follows it as `T`.
///
/// Only the first occurrence of `name` is considered. Returns `default` when the flag
/// is absent, when it is the last argument (no value follows), or when the value does
/// not parse as `T`.
#[cfg(not(no_cuda))]
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
    let raw = match args.iter().position(|a| a == name) {
        Some(at) => args.get(at + 1),
        None => None,
    };
    match raw.map(|s| s.parse::<T>()) {
        Some(Ok(value)) => value,
        _ => default,
    }
}
|
||||
|
||||
/// Returns a copy of the argument immediately following the first occurrence of
/// `name`, or `None` if the flag is absent or is the final argument.
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
    let at = args.iter().position(|a| a == name)?;
    args.get(at + 1).cloned()
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn load_model(cfg: Config, device: Device, ckpt: &str) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
let m = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.04)
|
||||
}
|
||||
})
|
||||
.with_compute_dtype(DType::BF16)
|
||||
.with_flash(true);
|
||||
xtrain_train::checkpoint::load_into(std::path::Path::new(ckpt), &m.params()).expect("load ckpt");
|
||||
m.eval();
|
||||
m
|
||||
}
|
||||
|
||||
/// Runs `f` exactly `reps` times back-to-back and returns the mean wall-clock time per
/// call, in milliseconds.
///
/// Returns `0.0` when `reps == 0`: nothing is run, and dividing the (tiny) elapsed
/// time by zero would otherwise yield `inf`/`NaN` and poison the speed-up report.
///
/// NOTE(review): this measures host wall-clock only. It is accurate for GPU work only
/// if `f` itself blocks until the device has finished (e.g. by copying a result back
/// to the host) — confirm for each timed closure.
#[cfg(not(no_cuda))]
fn elapsed_ms<F: FnMut()>(reps: usize, mut f: F) -> f32 {
    if reps == 0 {
        return 0.0;
    }
    let start = std::time::Instant::now();
    for _ in 0..reps {
        f();
    }
    start.elapsed().as_secs_f32() * 1e3 / reps as f32
}
|
||||
|
||||
/// Per-position argmax of the model over each ragged `input` (one `forward_batched`
|
||||
/// per `micro`-chunk). Used to teacher-force WELL-CONDITIONED targets (the top-1 token,
|
||||
/// high prob) so the closeness gate's logp isn't the ~−20 of a random token — where
|
||||
/// `−log p` amplifies bf16 noise. This matches real GRPO (targets are model samples).
|
||||
#[cfg(not(no_cuda))]
|
||||
fn model_argmax(model: &TinyTransformer, device: Device, inputs: &[Vec<i32>], vocab: usize, micro: usize) -> Vec<Vec<i32>> {
|
||||
let mut out = Vec::with_capacity(inputs.len());
|
||||
for chunk in inputs.chunks(micro.max(1)) {
|
||||
let m = chunk.len();
|
||||
let lmax = chunk.iter().map(|s| s.len()).max().unwrap();
|
||||
let mut flat = vec![0i32; m * lmax];
|
||||
for (i, s) in chunk.iter().enumerate() {
|
||||
flat[i * lmax..i * lmax + s.len()].copy_from_slice(s);
|
||||
}
|
||||
let ids = Tensor::from_slice(&flat, &[m * lmax]).to_device(device);
|
||||
let logits = model.forward_batched(&ids, m).value().to_dtype(DType::F32).to_device(Device::Cpu);
|
||||
let v = logits.as_slice::<f32>();
|
||||
for (i, s) in chunk.iter().enumerate() {
|
||||
let mut row = Vec::with_capacity(s.len());
|
||||
for r in 0..s.len() {
|
||||
let base = (i * lmax + r) * vocab;
|
||||
let mut best = 0usize;
|
||||
for c in 1..vocab {
|
||||
if v[base + c] > v[base + best] {
|
||||
best = c;
|
||||
}
|
||||
}
|
||||
row.push(best as i32);
|
||||
}
|
||||
out.push(row);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
||||
let tok_path = positionals.first().expect("usage: bench_grpo_batch <tokenizer.json> [flags]");
|
||||
|
||||
let n_heads = flag(&args, "--heads", 52usize);
|
||||
let head_dim = flag(&args, "--head-dim", 32usize);
|
||||
let n_layers = flag(&args, "--layers", 22usize);
|
||||
let ffn = flag(&args, "--ffn", 6656usize);
|
||||
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||||
let n: usize = flag(&args, "--n", 48); // B·G samples per step
|
||||
let plen: usize = flag(&args, "--plen", 12); // prompt tokens
|
||||
let clen: usize = flag(&args, "--clen", 24); // max completion tokens
|
||||
let micro: usize = flag(&args, "--micro", 16);
|
||||
let reps: usize = flag(&args, "--reps", 3);
|
||||
let (eps, beta) = (flag(&args, "--eps", 0.2f32), flag(&args, "--beta", 0.0f32));
|
||||
let init_ckpt = flag_value(&args, "--init-ckpt").expect("--init-ckpt <base.ckpt> required");
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
let tok = Tokenizer::from_file(std::path::Path::new(tok_path.as_str()));
|
||||
let vocab = tok.vocab_size();
|
||||
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
|
||||
let policy = load_model(cfg, device, &init_ckpt);
|
||||
let params = policy.params();
|
||||
|
||||
// --- Synthesise N ragged samples (frame-shaped: prompt masked, ragged completion).
|
||||
// Token IDs are random-but-valid; only the SHAPES drive the forward cost.
|
||||
let mut rng = 0xC0FFEEu64;
|
||||
let mut next = || {
|
||||
rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
|
||||
(rng >> 33) as usize
|
||||
};
|
||||
let mut io: Vec<(Vec<i32>, Vec<i32>)> = Vec::with_capacity(n);
|
||||
let mut advs: Vec<f32> = Vec::with_capacity(n);
|
||||
for _ in 0..n {
|
||||
let pl = plen.saturating_sub(2) + next() % 5; // jitter prompt length a little
|
||||
let cl = 4 + next() % clen.max(1); // completion 4..=clen
|
||||
let total = pl + cl;
|
||||
let toks: Vec<i32> = (0..total).map(|_| (next() % vocab) as i32).collect();
|
||||
let mut labels = vec![-100i32; pl]; // prompt masked
|
||||
labels.extend_from_slice(&toks[pl..]);
|
||||
let l = toks.len();
|
||||
io.push((toks[..l - 1].to_vec(), labels[1..l].to_vec())); // target masked at [..pl-1]
|
||||
advs.push(if next() % 2 == 0 { 0.7 } else { -0.7 });
|
||||
}
|
||||
let toklens: Vec<usize> = io.iter().map(|(i, _)| i.len()).collect();
|
||||
let (lmin, lmax) = (*toklens.iter().min().unwrap(), *toklens.iter().max().unwrap());
|
||||
println!("samples N={n}, seq len {lmin}..{lmax} (ragged), micro={micro}, β={beta}\n");
|
||||
|
||||
// Replace random completion targets with the model's own argmax (teacher forcing):
|
||||
// well-conditioned logp (top-1, not the ~−20 of a random token where bf16 noise
|
||||
// blows up via −log p). The completion target positions are where the skeleton is
|
||||
// ≥0; prompt positions stay masked (−100).
|
||||
let inputs: Vec<Vec<i32>> = io.iter().map(|(i, _)| i.clone()).collect();
|
||||
let preds = model_argmax(&policy, device, &inputs, vocab, micro);
|
||||
for (s, (_, target)) in io.iter_mut().enumerate() {
|
||||
for j in 0..target.len() {
|
||||
if target[j] >= 0 {
|
||||
target[j] = preds[s][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------- Phase 1: capture (per_token_logp) ----------------
|
||||
let logp_loop: Vec<Vec<f32>> = io.iter().map(|(i, t)| per_token_logp(&policy, device, i, t)).collect();
|
||||
let logp_batch = per_token_logp_batched(&policy, device, &io, micro);
|
||||
let cap_dmax = logp_loop
|
||||
.iter()
|
||||
.zip(&logp_batch)
|
||||
.flat_map(|(a, b)| a.iter().zip(b).map(|(x, y)| (x - y).abs()))
|
||||
.fold(0.0f32, f32::max);
|
||||
let t_cap_loop = elapsed_ms(reps, || {
|
||||
let _: Vec<Vec<f32>> = io.iter().map(|(i, t)| per_token_logp(&policy, device, i, t)).collect();
|
||||
});
|
||||
let t_cap_batch = elapsed_ms(reps, || {
|
||||
let _ = per_token_logp_batched(&policy, device, &io, micro);
|
||||
});
|
||||
|
||||
// Build PgSamples from the (matching) capture; ref = old − 0.3 to exercise KL.
|
||||
let batch: Vec<PgSample> = io
|
||||
.iter()
|
||||
.zip(&advs)
|
||||
.zip(&logp_batch)
|
||||
.map(|(((input, target), &adv), lp)| PgSample {
|
||||
input: input.clone(),
|
||||
target: target.clone(),
|
||||
adv,
|
||||
logp_old: lp.clone(),
|
||||
logp_ref: lp.iter().map(|v| v - 0.3).collect(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
// ---------------- Phase 2: inner clipped-PG (forward + backward) ----------------
|
||||
// Representative grad snapshots: layer-0 wq (params[2]) + final_norm.
|
||||
let wq0 = ¶ms[2];
|
||||
let fnorm = ¶ms[1 + n_layers * 11];
|
||||
let snap = |v: &xtrain_autodiff::Var| -> Vec<f32> {
|
||||
v.grad().map(|g| g.to_device(Device::Cpu).as_slice::<f32>().to_vec()).unwrap_or_default()
|
||||
};
|
||||
let zero = |ps: &[xtrain_autodiff::Var]| ps.iter().for_each(|p| p.zero_grad());
|
||||
|
||||
zero(¶ms);
|
||||
inner_pg_step_looped(&policy, device, &batch, eps, beta);
|
||||
let (gq_loop, gn_loop) = (snap(wq0), snap(fnorm));
|
||||
zero(¶ms);
|
||||
inner_pg_step_batched(&policy, device, &batch, eps, beta, micro);
|
||||
let (gq_batch, gn_batch) = (snap(wq0), snap(fnorm));
|
||||
zero(¶ms);
|
||||
|
||||
let reldiff = |a: &[f32], b: &[f32]| -> f32 {
|
||||
let num = a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0f32, f32::max);
|
||||
let den = a.iter().map(|x| x.abs()).fold(0.0f32, f32::max).max(1e-12);
|
||||
num / den
|
||||
};
|
||||
let gq_rel = reldiff(&gq_loop, &gq_batch);
|
||||
let gn_rel = reldiff(&gn_loop, &gn_batch);
|
||||
|
||||
// Time only forward+backward — the lever. opt.step + grad-clip are identical in
|
||||
// both paths (one call over `params` after the per-sample loop), so they would
|
||||
// only add a constant; excluding them also dodges the unrelated 1B-Adam-state
|
||||
// memory wall (the M4 finding) that this diagnostic doesn't need to reproduce.
|
||||
let t_inner_loop = elapsed_ms(reps, || {
|
||||
inner_pg_step_looped(&policy, device, &batch, eps, beta);
|
||||
zero(¶ms);
|
||||
});
|
||||
let t_inner_batch = elapsed_ms(reps, || {
|
||||
inner_pg_step_batched(&policy, device, &batch, eps, beta, micro);
|
||||
zero(¶ms);
|
||||
});
|
||||
|
||||
// ---------------- Report ----------------
|
||||
let spd = |a: f32, b: f32| if b > 0.0 { a / b } else { 0.0 };
|
||||
println!("=== closeness gate (looped vs batched) ===");
|
||||
println!(" capture per_token_logp : max|Δ| = {cap_dmax:.3e}");
|
||||
println!(" inner grad wq[0] : rel|Δ| = {gq_rel:.3e}");
|
||||
println!(" inner grad final_norm : rel|Δ| = {gn_rel:.3e}");
|
||||
println!("\n=== timing (mean of {reps} reps, ms/phase) ===");
|
||||
println!(" capture : looped {t_cap_loop:8.1} batched {t_cap_batch:8.1} ({:.2}× )", spd(t_cap_loop, t_cap_batch));
|
||||
println!(" inner : looped {t_inner_loop:8.1} batched {t_inner_batch:8.1} ({:.2}× )", spd(t_inner_loop, t_inner_batch));
|
||||
let (step_loop, step_batch) = (t_cap_loop + t_inner_loop, t_cap_batch + t_inner_batch);
|
||||
println!(" STEP : looped {step_loop:8.1} batched {step_batch:8.1} ({:.2}× )", spd(step_loop, step_batch));
|
||||
|
||||
// The RIGOROUS correctness gates live in the test suite (exact, not bf16-noisy):
|
||||
// - xtrain-model forward_batched_ragged_matches_looped (forward+pad == looped)
|
||||
// - xtrain-autodiff clipped_pg_loss_batched_matches_looped (op == looped, f32)
|
||||
// This is a smoke check at the 1B/bf16 scale: single-seq vs batched GEMM differ in
|
||||
// batch-reduction order, so a loose band, with well-conditioned (argmax) targets.
|
||||
assert!(cap_dmax < 0.2, "capture closeness smoke FAILED: max|Δlogp| = {cap_dmax}");
|
||||
assert!(gq_rel < 0.2 && gn_rel < 0.2, "inner grad closeness smoke FAILED: wq {gq_rel}, fn {gn_rel}");
|
||||
println!("\nSMOKE PASS (bf16 band): batched ≈ looped; rigorous gates are the two tests above.");
|
||||
}
|
||||
Reference in New Issue
Block a user