bench-gpt-oss: teacher-forced diagnostics + --prompt flag
Add --prompt to override the fixed prompt, and two teacher-forced diagnostics: --forced runs prefill over prompt+oracle ids and reports per-position top-1 agreement; --forced-decode walks the oracle trajectory through the decode path with per-position agreement bucketed by position, to localize long-context decode divergence from the reference. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -68,7 +68,8 @@ fn main() {
|
||||
eprintln!("[rank 0] Ready.");
|
||||
|
||||
// Prompt
|
||||
let prompt = "What is the meaning of life?";
|
||||
let prompt_arg = get_arg::<String>(&args, "--prompt");
|
||||
let prompt = prompt_arg.as_deref().unwrap_or("What is the meaning of life?");
|
||||
let token_ids = tokenizer.encode(prompt);
|
||||
eprintln!("Prompt ({} tokens): {prompt}", token_ids.len());
|
||||
|
||||
@@ -77,6 +78,84 @@ fn main() {
|
||||
cache.register_sequence(slot).unwrap();
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Register(slot));
|
||||
|
||||
// Teacher-forced diagnostic: prefill (prompt + forced ids) in one shot and
|
||||
// report, for each forced position, whether xserv's argmax == the forced
|
||||
// (oracle) next token. Removes free-running compounding so it isolates
|
||||
// whether per-position logits agree with the llama.cpp trajectory.
|
||||
if let Some(forced) = get_arg::<String>(&args, "--forced") {
|
||||
let forced_ids: Vec<u32> = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect();
|
||||
let mut seq = token_ids.clone();
|
||||
seq.extend_from_slice(&forced_ids);
|
||||
// Workers must run the same prefill in lockstep (TP AllReduces match up).
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: seq.clone(), slot });
|
||||
let logits = model.forward_prefill_paged(&seq, slot, &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let vocab = logits.shape()[1];
|
||||
let data = logits_cpu.as_slice::<half::bf16>();
|
||||
let plen = token_ids.len();
|
||||
let mut matches = 0usize;
|
||||
let mut total = 0usize;
|
||||
// position i predicts seq[i+1]; we check the forced region
|
||||
for i in (plen - 1)..(seq.len() - 1) {
|
||||
let row = &data[i * vocab..(i + 1) * vocab];
|
||||
let argmax = row.iter().enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(j, _)| j as u32).unwrap();
|
||||
let expected = seq[i + 1];
|
||||
let ok = argmax == expected;
|
||||
if ok { matches += 1; }
|
||||
total += 1;
|
||||
eprintln!("pos {i}: xserv_argmax={argmax} oracle={expected} {}", if ok {"OK"} else {"DIFF"});
|
||||
}
|
||||
eprintln!("\nTeacher-forced top-1 agreement: {matches}/{total} = {:.1}%",
|
||||
100.0 * matches as f64 / total as f64);
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
|
||||
for (h, _) in worker_handles { h.join().unwrap(); }
|
||||
return;
|
||||
}
|
||||
|
||||
// Teacher-forced DECODE diagnostic: prefill the prompt, then walk the oracle
|
||||
// trajectory through the autoregressive decode path (NOT prefill), recording
|
||||
// per-position top-1 agreement bucketed by position. Localizes long-context
|
||||
// decode degradation (which prefill teacher-forcing cannot see).
|
||||
if let Some(forced) = get_arg::<String>(&args, "--forced-decode") {
|
||||
let forced_ids: Vec<u32> = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect();
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: token_ids.clone(), slot });
|
||||
let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
let mut pred = sample_greedy_last(&logits); // prediction for forced[0]
|
||||
let bucket = 50usize;
|
||||
let mut buckets: Vec<(usize, usize)> = Vec::new();
|
||||
let (mut matches, mut total) = (0usize, 0usize);
|
||||
for (i, &f) in forced_ids.iter().enumerate() {
|
||||
let ok = pred == f;
|
||||
matches += ok as usize;
|
||||
total += 1;
|
||||
let b = i / bucket;
|
||||
if buckets.len() <= b { buckets.push((0, 0)); }
|
||||
buckets[b].0 += ok as usize;
|
||||
buckets[b].1 += 1;
|
||||
// Teacher-force: feed the oracle token through the decode path.
|
||||
let pos = cache.seq_len(slot);
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode {
|
||||
tokens: vec![f], positions: vec![pos], slots: vec![slot],
|
||||
});
|
||||
let logits = model.forward_decode_paged(&[f], &[pos], &[slot], &mut cache);
|
||||
wait_workers(&worker_handles);
|
||||
pred = sample_greedy_last(&logits);
|
||||
}
|
||||
eprintln!("Teacher-forced DECODE agreement: {matches}/{total} = {:.1}%",
|
||||
100.0 * matches as f64 / total as f64);
|
||||
for (b, (m, t)) in buckets.iter().enumerate() {
|
||||
eprintln!(" pos[{:>4}..{:<4}]: {m:>3}/{t:<3} = {:.0}%",
|
||||
b * bucket, b * bucket + t, 100.0 * (*m as f64) / (*t as f64));
|
||||
}
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown);
|
||||
for (h, _) in worker_handles { h.join().unwrap(); }
|
||||
return;
|
||||
}
|
||||
|
||||
// Prefill
|
||||
let t0 = Instant::now();
|
||||
broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill {
|
||||
|
||||
Reference in New Issue
Block a user