bench-gpt-oss: teacher-forced diagnostics + --prompt flag

Add --prompt to override the fixed prompt, and two teacher-forced
diagnostics: --forced runs prefill over prompt+oracle ids and reports
per-position top-1 agreement; --forced-decode walks the oracle trajectory
through the decode path with per-position agreement bucketed by position,
to localize long-context decode divergence from the reference.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-05-31 00:56:46 +08:00
parent ffd90ce7fb
commit 0c6135aea3

View File

@@ -68,7 +68,8 @@ fn main() {
eprintln!("[rank 0] Ready.");
// Prompt
let prompt = "What is the meaning of life?";
let prompt_arg = get_arg::<String>(&args, "--prompt");
let prompt = prompt_arg.as_deref().unwrap_or("What is the meaning of life?");
let token_ids = tokenizer.encode(prompt);
eprintln!("Prompt ({} tokens): {prompt}", token_ids.len());
@@ -77,6 +78,84 @@ fn main() {
cache.register_sequence(slot).unwrap();
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Register(slot));
// Teacher-forced diagnostic: prefill (prompt + forced ids) in one shot and
// report, for each forced position, whether xserv's argmax == the forced
// (oracle) next token. Removes free-running compounding so it isolates
// whether per-position logits agree with the llama.cpp trajectory.
if let Some(forced) = get_arg::<String>(&args, "--forced") {
let forced_ids: Vec<u32> = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect();
let mut seq = token_ids.clone();
seq.extend_from_slice(&forced_ids);
// Workers must run the same prefill in lockstep (TP AllReduces match up).
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: seq.clone(), slot });
let logits = model.forward_prefill_paged(&seq, slot, &mut cache);
wait_workers(&worker_handles);
let logits_cpu = logits.to_device(Device::Cpu);
let vocab = logits.shape()[1];
let data = logits_cpu.as_slice::<half::bf16>();
let plen = token_ids.len();
let mut matches = 0usize;
let mut total = 0usize;
// position i predicts seq[i+1]; we check the forced region
for i in (plen - 1)..(seq.len() - 1) {
let row = &data[i * vocab..(i + 1) * vocab];
let argmax = row.iter().enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(j, _)| j as u32).unwrap();
let expected = seq[i + 1];
let ok = argmax == expected;
if ok { matches += 1; }
total += 1;
eprintln!("pos {i}: xserv_argmax={argmax} oracle={expected} {}", if ok {"OK"} else {"DIFF"});
}
eprintln!("\nTeacher-forced top-1 agreement: {matches}/{total} = {:.1}%",
100.0 * matches as f64 / total as f64);
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
for (h, _) in worker_handles { h.join().unwrap(); }
return;
}
// Teacher-forced DECODE diagnostic: prefill the prompt, then walk the oracle
// trajectory through the autoregressive decode path (NOT prefill), recording
// per-position top-1 agreement bucketed by position. Localizes long-context
// decode degradation (which prefill teacher-forcing cannot see).
if let Some(forced) = get_arg::<String>(&args, "--forced-decode") {
let forced_ids: Vec<u32> = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect();
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: token_ids.clone(), slot });
let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache);
wait_workers(&worker_handles);
let mut pred = sample_greedy_last(&logits); // prediction for forced[0]
let bucket = 50usize;
let mut buckets: Vec<(usize, usize)> = Vec::new();
let (mut matches, mut total) = (0usize, 0usize);
for (i, &f) in forced_ids.iter().enumerate() {
let ok = pred == f;
matches += ok as usize;
total += 1;
let b = i / bucket;
if buckets.len() <= b { buckets.push((0, 0)); }
buckets[b].0 += ok as usize;
buckets[b].1 += 1;
// Teacher-force: feed the oracle token through the decode path.
let pos = cache.seq_len(slot);
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode {
tokens: vec![f], positions: vec![pos], slots: vec![slot],
});
let logits = model.forward_decode_paged(&[f], &[pos], &[slot], &mut cache);
wait_workers(&worker_handles);
pred = sample_greedy_last(&logits);
}
eprintln!("Teacher-forced DECODE agreement: {matches}/{total} = {:.1}%",
100.0 * matches as f64 / total as f64);
for (b, (m, t)) in buckets.iter().enumerate() {
eprintln!(" pos[{:>4}..{:<4}]: {m:>3}/{t:<3} = {:.0}%",
b * bucket, b * bucket + t, 100.0 * (*m as f64) / (*t as f64));
}
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
for (h, _) in worker_handles { h.join().unwrap(); }
return;
}
// Prefill
let t0 = Instant::now();
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill {