From 0c6135aea31034c60a2bda897b9f567a10f66f59 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Sun, 31 May 2026 00:56:46 +0800 Subject: [PATCH] bench-gpt-oss: teacher-forced diagnostics + --prompt flag Add --prompt to override the fixed prompt, and two teacher-forced diagnostics: --forced runs prefill over prompt+oracle ids and reports per-position top-1 agreement; --forced-decode walks the oracle trajectory through the decode path with per-position agreement bucketed by position, to localize long-context decode divergence from the reference. Co-Authored-By: Claude Opus 4.8 --- crates/xserv-model/src/bin/bench-gpt-oss.rs | 81 ++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/crates/xserv-model/src/bin/bench-gpt-oss.rs b/crates/xserv-model/src/bin/bench-gpt-oss.rs index 98fe9e2..f80f27c 100644 --- a/crates/xserv-model/src/bin/bench-gpt-oss.rs +++ b/crates/xserv-model/src/bin/bench-gpt-oss.rs @@ -68,7 +68,8 @@ fn main() { eprintln!("[rank 0] Ready."); // Prompt - let prompt = "What is the meaning of life?"; + let prompt_arg = get_arg::(&args, "--prompt"); + let prompt = prompt_arg.as_deref().unwrap_or("What is the meaning of life?"); let token_ids = tokenizer.encode(prompt); eprintln!("Prompt ({} tokens): {prompt}", token_ids.len()); @@ -77,6 +78,84 @@ fn main() { cache.register_sequence(slot).unwrap(); broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Register(slot)); + // Teacher-forced diagnostic: prefill (prompt + forced ids) in one shot and + // report, for each forced position, whether xserv's argmax == the forced + // (oracle) next token. Removes free-running compounding so it isolates + // whether per-position logits agree with the llama.cpp trajectory. + if let Some(forced) = get_arg::(&args, "--forced") { + let forced_ids: Vec = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect(); + let mut seq = token_ids.clone(); + seq.extend_from_slice(&forced_ids); + // Workers must run the same prefill in lockstep (TP AllReduces match up). + broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: seq.clone(), slot }); + let logits = model.forward_prefill_paged(&seq, slot, &mut cache); + wait_workers(&worker_handles); + let logits_cpu = logits.to_device(Device::Cpu); + let vocab = logits.shape()[1]; + let data = logits_cpu.as_slice::(); + let plen = token_ids.len(); + let mut matches = 0usize; + let mut total = 0usize; + // position i predicts seq[i+1]; we check the forced region + for i in (plen - 1)..(seq.len() - 1) { + let row = &data[i * vocab..(i + 1) * vocab]; + let argmax = row.iter().enumerate() + .max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap()) + .map(|(j, _)| j as u32).unwrap(); + let expected = seq[i + 1]; + let ok = argmax == expected; + if ok { matches += 1; } + total += 1; + eprintln!("pos {i}: xserv_argmax={argmax} oracle={expected} {}", if ok {"OK"} else {"DIFF"}); + } + eprintln!("\nTeacher-forced top-1 agreement: {matches}/{total} = {:.1}%", + 100.0 * matches as f64 / total as f64); + broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown); + for (h, _) in worker_handles { h.join().unwrap(); } + return; + } + + // Teacher-forced DECODE diagnostic: prefill the prompt, then walk the oracle + // trajectory through the autoregressive decode path (NOT prefill), recording + // per-position top-1 agreement bucketed by position. Localizes long-context + // decode degradation (which prefill teacher-forcing cannot see). + if let Some(forced) = get_arg::(&args, "--forced-decode") { + let forced_ids: Vec = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect(); + broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: token_ids.clone(), slot }); + let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache); + wait_workers(&worker_handles); + let mut pred = sample_greedy_last(&logits); // prediction for forced[0] + let bucket = 50usize; + let mut buckets: Vec<(usize, usize)> = Vec::new(); + let (mut matches, mut total) = (0usize, 0usize); + for (i, &f) in forced_ids.iter().enumerate() { + let ok = pred == f; + matches += ok as usize; + total += 1; + let b = i / bucket; + if buckets.len() <= b { buckets.push((0, 0)); } + buckets[b].0 += ok as usize; + buckets[b].1 += 1; + // Teacher-force: feed the oracle token through the decode path. + let pos = cache.seq_len(slot); + broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode { + tokens: vec![f], positions: vec![pos], slots: vec![slot], + }); + let logits = model.forward_decode_paged(&[f], &[pos], &[slot], &mut cache); + wait_workers(&worker_handles); + pred = sample_greedy_last(&logits); + } + eprintln!("Teacher-forced DECODE agreement: {matches}/{total} = {:.1}%", + 100.0 * matches as f64 / total as f64); + for (b, (m, t)) in buckets.iter().enumerate() { + eprintln!(" pos[{:>4}..{:<4}]: {m:>3}/{t:<3} = {:.0}%", + b * bucket, b * bucket + t, 100.0 * (*m as f64) / (*t as f64)); + } + broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown); + for (h, _) in worker_handles { h.join().unwrap(); } + return; + } + // Prefill let t0 = Instant::now(); broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill {