bench-gpt-oss: teacher-forced diagnostics + --prompt flag
Add --prompt to override the fixed prompt, and two teacher-forced diagnostics:
- --forced runs prefill over prompt + oracle ids and reports per-position top-1 agreement;
- --forced-decode walks the oracle trajectory through the decode path and reports top-1 agreement bucketed by position, to localize long-context decode divergence from the reference.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -68,7 +68,8 @@ fn main() {
|
|||||||
eprintln!("[rank 0] Ready.");
|
eprintln!("[rank 0] Ready.");
|
||||||
|
|
||||||
// Prompt
|
// Prompt
|
||||||
let prompt = "What is the meaning of life?";
|
let prompt_arg = get_arg::<String>(&args, "--prompt");
|
||||||
|
let prompt = prompt_arg.as_deref().unwrap_or("What is the meaning of life?");
|
||||||
let token_ids = tokenizer.encode(prompt);
|
let token_ids = tokenizer.encode(prompt);
|
||||||
eprintln!("Prompt ({} tokens): {prompt}", token_ids.len());
|
eprintln!("Prompt ({} tokens): {prompt}", token_ids.len());
|
||||||
|
|
||||||
@@ -77,6 +78,84 @@ fn main() {
|
|||||||
cache.register_sequence(slot).unwrap();
|
cache.register_sequence(slot).unwrap();
|
||||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Register(slot));
|
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Register(slot));
|
||||||
|
|
||||||
|
// Teacher-forced diagnostic: prefill (prompt + forced ids) in one shot and
|
||||||
|
// report, for each forced position, whether xserv's argmax == the forced
|
||||||
|
// (oracle) next token. Removes free-running compounding so it isolates
|
||||||
|
// whether per-position logits agree with the llama.cpp trajectory.
|
||||||
|
if let Some(forced) = get_arg::<String>(&args, "--forced") {
|
||||||
|
let forced_ids: Vec<u32> = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect();
|
||||||
|
let mut seq = token_ids.clone();
|
||||||
|
seq.extend_from_slice(&forced_ids);
|
||||||
|
// Workers must run the same prefill in lockstep (TP AllReduces match up).
|
||||||
|
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: seq.clone(), slot });
|
||||||
|
let logits = model.forward_prefill_paged(&seq, slot, &mut cache);
|
||||||
|
wait_workers(&worker_handles);
|
||||||
|
let logits_cpu = logits.to_device(Device::Cpu);
|
||||||
|
let vocab = logits.shape()[1];
|
||||||
|
let data = logits_cpu.as_slice::<half::bf16>();
|
||||||
|
let plen = token_ids.len();
|
||||||
|
let mut matches = 0usize;
|
||||||
|
let mut total = 0usize;
|
||||||
|
// position i predicts seq[i+1]; we check the forced region
|
||||||
|
for i in (plen - 1)..(seq.len() - 1) {
|
||||||
|
let row = &data[i * vocab..(i + 1) * vocab];
|
||||||
|
let argmax = row.iter().enumerate()
|
||||||
|
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||||
|
.map(|(j, _)| j as u32).unwrap();
|
||||||
|
let expected = seq[i + 1];
|
||||||
|
let ok = argmax == expected;
|
||||||
|
if ok { matches += 1; }
|
||||||
|
total += 1;
|
||||||
|
eprintln!("pos {i}: xserv_argmax={argmax} oracle={expected} {}", if ok {"OK"} else {"DIFF"});
|
||||||
|
}
|
||||||
|
eprintln!("\nTeacher-forced top-1 agreement: {matches}/{total} = {:.1}%",
|
||||||
|
100.0 * matches as f64 / total as f64);
|
||||||
|
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
|
||||||
|
for (h, _) in worker_handles { h.join().unwrap(); }
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Teacher-forced DECODE diagnostic: prefill the prompt, then walk the oracle
|
||||||
|
// trajectory through the autoregressive decode path (NOT prefill), recording
|
||||||
|
// per-position top-1 agreement bucketed by position. Localizes long-context
|
||||||
|
// decode degradation (which prefill teacher-forcing cannot see).
|
||||||
|
if let Some(forced) = get_arg::<String>(&args, "--forced-decode") {
|
||||||
|
let forced_ids: Vec<u32> = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect();
|
||||||
|
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: token_ids.clone(), slot });
|
||||||
|
let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache);
|
||||||
|
wait_workers(&worker_handles);
|
||||||
|
let mut pred = sample_greedy_last(&logits); // prediction for forced[0]
|
||||||
|
let bucket = 50usize;
|
||||||
|
let mut buckets: Vec<(usize, usize)> = Vec::new();
|
||||||
|
let (mut matches, mut total) = (0usize, 0usize);
|
||||||
|
for (i, &f) in forced_ids.iter().enumerate() {
|
||||||
|
let ok = pred == f;
|
||||||
|
matches += ok as usize;
|
||||||
|
total += 1;
|
||||||
|
let b = i / bucket;
|
||||||
|
if buckets.len() <= b { buckets.push((0, 0)); }
|
||||||
|
buckets[b].0 += ok as usize;
|
||||||
|
buckets[b].1 += 1;
|
||||||
|
// Teacher-force: feed the oracle token through the decode path.
|
||||||
|
let pos = cache.seq_len(slot);
|
||||||
|
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode {
|
||||||
|
tokens: vec![f], positions: vec![pos], slots: vec![slot],
|
||||||
|
});
|
||||||
|
let logits = model.forward_decode_paged(&[f], &[pos], &[slot], &mut cache);
|
||||||
|
wait_workers(&worker_handles);
|
||||||
|
pred = sample_greedy_last(&logits);
|
||||||
|
}
|
||||||
|
eprintln!("Teacher-forced DECODE agreement: {matches}/{total} = {:.1}%",
|
||||||
|
100.0 * matches as f64 / total as f64);
|
||||||
|
for (b, (m, t)) in buckets.iter().enumerate() {
|
||||||
|
eprintln!(" pos[{:>4}..{:<4}]: {m:>3}/{t:<3} = {:.0}%",
|
||||||
|
b * bucket, b * bucket + t, 100.0 * (*m as f64) / (*t as f64));
|
||||||
|
}
|
||||||
|
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
|
||||||
|
for (h, _) in worker_handles { h.join().unwrap(); }
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Prefill
|
// Prefill
|
||||||
let t0 = Instant::now();
|
let t0 = Instant::now();
|
||||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill {
|
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill {
|
||||||
|
|||||||
Reference in New Issue
Block a user