Single-sequence KV-cache decode (xtrain-model/src/decode.rs): per-layer K/V cache + single-token incremental forward (prefill = the first prompt.len() decode steps, one code path). Mirrors model::block_forward at the raw-Tensor level (no autograd tape — inference needs no grads), using rope_at + decode_attention. The cache is host-accumulated, token-major f32, rebuilt per step (the honest M2a baseline; M2b moves it device-side + batched ragged). Gate (the M2 centerpiece): KV-cache greedy decode is TOKEN-IDENTICAL to the naive full-recompute greedy decode — checked by tests/decode_kv.rs (small GQA model, F32, 24 tokens) and corroborated on the v12 1.05B SFT checkpoint (cached eval = naive eval byte-for-byte: format 100/100, correct 8/100). eval_arith --cached A/Bs the two paths and reports decode tok/s. Measured on v12 (1.05B, batch 1, F32), the cache win is sequence-length-dependent: max_new=32: naive 108 vs cached 111 tok/s (~1.0x; overhead-bound); max_new=128: naive 69 vs cached 133 tok/s (~1.9x); max_new=256: naive OOM vs cached 129 tok/s. Cached throughput stays ~constant (O(1)/token) while naive decays (O(t)/token, O(seq^2) graph → OOM at length). Short eval prompts are overhead-bound, so the cache matters for long rollouts (DPO/GRPO), not the arithmetic eval itself. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
210 lines
7.5 KiB
Rust
210 lines
7.5 KiB
Rust
//! Verifiable-task eval (post-training, M1+). Load a checkpoint, greedily generate an
|
|
//! answer for each held-out arithmetic prompt, parse the `\boxed{}` answer, and report
|
|
//! the exact-match pass-rate against the gold file. Two signals are printed:
|
|
//! **format** (fraction that emitted any boxed integer) and **correctness** (fraction
|
|
//! whose boxed answer matches gold). This is the M1 format-baseline metric and the
|
|
//! reusable verifiable-eval harness for M3 (DPO) / M4 (GRPO).
|
|
//!
|
|
//! eval_arith <ckpt> <tokenizer.json> --heads 52 --head-dim 32 --kv-heads 13 \
|
|
//! --layers 22 --ffn 6656 \
|
|
//! --prompts-file <dir>/arith_eval_prompts.txt \
|
|
//! --gold-file <dir>/arith_eval_gold.txt --max-tokens 48 --show 8
|
|
|
|
/// Entry point for builds configured with `no_cuda`: writes a one-line notice to
/// stderr and returns without running any evaluation.
#[cfg(no_cuda)]
fn main() {
    let notice = "eval_arith: built without CUDA (no_cuda); run on a GPU host (dash5).";
    eprintln!("{notice}");
}
|
|
|
|
#[cfg(not(no_cuda))]
|
|
use std::path::PathBuf;
|
|
#[cfg(not(no_cuda))]
|
|
use xtrain_cuda::device;
|
|
#[cfg(not(no_cuda))]
|
|
use xtrain_model::{Config, TinyTransformer};
|
|
#[cfg(not(no_cuda))]
|
|
use xtrain_tensor::Device;
|
|
#[cfg(not(no_cuda))]
|
|
use xtrain_train::sample::generate;
|
|
#[cfg(not(no_cuda))]
|
|
use xtrain_train::task::{check_answer, parse_boxed_answer};
|
|
|
|
/// Produces `n` deterministic pseudo-random `f32` values from a 64-bit LCG seeded
/// with `seed`. Each value is `(u - 0.5) * 2.0 * scale`, where `u` is the top 31
/// bits of the LCG state divided by 2^31.
///
/// This is the same init scheme as bin/train.rs / bin/greedy_sample.rs. The values
/// only shape the tensors; the loaded checkpoint overwrites them.
#[cfg(not(no_cuda))]
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // Constants used once to scramble the caller's seed into the initial state.
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    // Constants applied on every step of the generator.
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;

    let denom = (1u64 << 31) as f32;
    let mut lcg = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
        // Keep only the high 31 bits of the state.
        let unit = (lcg >> 33) as f32 / denom;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
|
|
|
|
/// Reads the value of command-line flag `name` and parses it as `T`.
///
/// The value is the argument immediately after the first occurrence of `name` in
/// `args`. Returns `default` when the flag is absent, when it is the last argument
/// (no value follows), or when the value does not parse as `T`. In the last case a
/// warning is printed to stderr so a mistyped value (e.g. `--layers 2x2`) does not
/// silently run with the default.
#[cfg(not(no_cuda))]
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
    let raw = args
        .iter()
        .position(|a| a == name)
        .and_then(|i| args.get(i + 1));
    match raw {
        None => default,
        Some(text) => match text.parse() {
            Ok(value) => value,
            // `T::Err` is not required to be printable, so report the raw text only.
            Err(_) => {
                eprintln!("eval_arith: ignoring invalid value {text:?} for {name}; using the default");
                default
            }
        },
    }
}
|
|
|
|
/// Returns an owned copy of the argument that follows the first occurrence of
/// `name` in `args`, or `None` when the flag is absent or is the final argument.
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
    let idx = args.iter().position(|arg| arg == name)?;
    args.get(idx + 1).cloned()
}
|
|
|
|
/// Expands the two-character sequences `\n` and `\t` (backslash followed by `n` or
/// `t`) into a real newline and tab. Every other character, including a backslash
/// followed by anything else, is copied through unchanged.
#[cfg(not(no_cuda))]
fn decode_escapes(s: &str) -> String {
    let mut decoded = String::with_capacity(s.len());
    let mut chars = s.chars().peekable();
    while let Some(current) = chars.next() {
        // Only a backslash can start an escape; look one character ahead.
        let expansion = if current == '\\' {
            match chars.peek() {
                Some('n') => Some('\n'),
                Some('t') => Some('\t'),
                _ => None,
            }
        } else {
            None
        };
        match expansion {
            Some(ch) => {
                // Consume the `n` / `t` that completed the escape.
                chars.next();
                decoded.push(ch);
            }
            None => decoded.push(current),
        }
    }
    decoded
}
|
|
|
|
/// Isolates the first answer "turn" of a generated continuation.
///
/// The sampler has no EOS stop, so the model keeps generating past the answer.
/// The text is cut at the first `<|endoftext|>` marker and then at the first
/// newline; the arithmetic answer is a single line, so what remains is that line.
#[cfg(not(no_cuda))]
fn first_answer_segment(continuation: &str) -> &str {
    let before_eos = match continuation.find("<|endoftext|>") {
        Some(at) => &continuation[..at],
        None => continuation,
    };
    match before_eos.find('\n') {
        Some(at) => &before_eos[..at],
        None => before_eos,
    }
}
|
|
|
|
#[cfg(not(no_cuda))]
|
|
fn main() {
|
|
use xserv_tokenizer::Tokenizer;
|
|
|
|
let args: Vec<String> = std::env::args().collect();
|
|
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
|
let ckpt = positionals
|
|
.first()
|
|
.map(|s| PathBuf::from(s.as_str()))
|
|
.expect("usage: eval_arith <ckpt> <tokenizer.json> [flags]");
|
|
let tok_path = positionals
|
|
.get(1)
|
|
.map(|s| PathBuf::from(s.as_str()))
|
|
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
|
|
|
|
let n_heads = flag(&args, "--heads", 52usize);
|
|
let head_dim = flag(&args, "--head-dim", 32usize);
|
|
let n_layers = flag(&args, "--layers", 22usize);
|
|
let ffn = flag(&args, "--ffn", 6656usize);
|
|
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
|
let max_new = flag(&args, "--max-tokens", 48usize);
|
|
let n_show = flag(&args, "--show", 8usize);
|
|
let prompts_file = flag_value(&args, "--prompts-file").expect("--prompts-file is required");
|
|
let gold_file = flag_value(&args, "--gold-file").expect("--gold-file is required");
|
|
// M2: decode through the KV-cache incremental engine instead of the naive
|
|
// full-recompute sampler. Token-identical to the naive path (gated by
|
|
// tests/decode_kv.rs); this flag also lets us A/B the two for the speedup.
|
|
let use_cached = args.iter().any(|a| a == "--cached");
|
|
|
|
// Prompts: skip the `#` header / blank lines and decode escaped newlines so the
|
|
// count and order line up with the gold file.
|
|
let prompts: Vec<String> = std::fs::read_to_string(&prompts_file)
|
|
.unwrap_or_else(|e| panic!("read prompts {prompts_file}: {e}"))
|
|
.lines()
|
|
.map(str::trim)
|
|
.filter(|l| !l.is_empty() && !l.starts_with('#'))
|
|
.map(decode_escapes)
|
|
.collect();
|
|
let golds: Vec<i64> = std::fs::read_to_string(&gold_file)
|
|
.unwrap_or_else(|e| panic!("read gold {gold_file}: {e}"))
|
|
.lines()
|
|
.map(str::trim)
|
|
.filter(|l| !l.is_empty())
|
|
.map(|l| l.parse::<i64>().expect("gold line not an integer"))
|
|
.collect();
|
|
assert_eq!(
|
|
prompts.len(),
|
|
golds.len(),
|
|
"prompt/gold count mismatch ({} vs {})",
|
|
prompts.len(),
|
|
golds.len()
|
|
);
|
|
|
|
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
|
device::set_device(0).unwrap();
|
|
let device = Device::Cuda(0);
|
|
|
|
let tok = Tokenizer::from_file(&tok_path);
|
|
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn)
|
|
.with_kv_heads(kv_heads);
|
|
let mut seed = 1u64;
|
|
let model = TinyTransformer::new(cfg, device, |shape| {
|
|
seed = seed.wrapping_add(1);
|
|
let n: usize = shape.iter().product();
|
|
if shape.len() == 1 {
|
|
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
|
} else {
|
|
fill(n, seed, 0.04)
|
|
}
|
|
});
|
|
xtrain_train::checkpoint::load_into(&ckpt, &model.params()).expect("load checkpoint");
|
|
|
|
println!(
|
|
"eval_arith: ckpt {} | {} prompts | max_new {} | decode={}",
|
|
ckpt.display(),
|
|
prompts.len(),
|
|
max_new,
|
|
if use_cached { "kv-cache" } else { "naive" }
|
|
);
|
|
|
|
let (mut n_boxed, mut n_correct) = (0usize, 0usize);
|
|
let mut shown = 0usize;
|
|
let mut gen_tokens = 0usize;
|
|
let t0 = std::time::Instant::now();
|
|
for (prompt, &gold) in prompts.iter().zip(&golds) {
|
|
let ids: Vec<i32> = tok.encode(prompt).into_iter().map(|t| t as i32).collect();
|
|
let out = if use_cached {
|
|
xtrain_model::generate_greedy_cached(&model, device, &ids, max_new)
|
|
} else {
|
|
let mut rng = 7u64;
|
|
generate(&model, device, &ids, max_new, 0.0, &mut rng)
|
|
};
|
|
gen_tokens += out.len() - ids.len();
|
|
let cont = tok.decode(&out[ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
|
|
let seg = first_answer_segment(&cont);
|
|
if parse_boxed_answer(seg).is_some() {
|
|
n_boxed += 1;
|
|
}
|
|
let ok = check_answer(seg, gold);
|
|
if ok {
|
|
n_correct += 1;
|
|
}
|
|
if shown < n_show {
|
|
let q = prompt.replace('\n', " ");
|
|
println!(" [{}] gold={gold} got={seg:?} {}", q, if ok { "OK" } else { "x" });
|
|
shown += 1;
|
|
}
|
|
}
|
|
|
|
let elapsed = t0.elapsed().as_secs_f64();
|
|
let n = prompts.len() as f64;
|
|
println!(
|
|
"RESULT format(boxed)={}/{} ({:.1}%) | correct={}/{} ({:.1}%)",
|
|
n_boxed,
|
|
prompts.len(),
|
|
100.0 * n_boxed as f64 / n,
|
|
n_correct,
|
|
prompts.len(),
|
|
100.0 * n_correct as f64 / n,
|
|
);
|
|
println!(
|
|
"TIMING decode={} | {:.2}s | {} gen tokens | {:.1} tok/s",
|
|
if use_cached { "kv-cache" } else { "naive" },
|
|
elapsed,
|
|
gen_tokens,
|
|
gen_tokens as f64 / elapsed,
|
|
);
|
|
}
|