Compare commits
2 Commits
2fe903ecea
...
6309dc1181
| Author | SHA1 | Date | |
|---|---|---|---|
| 6309dc1181 | |||
| 264c004662 |
@@ -83,18 +83,35 @@ fn main() {
|
||||
if args.len() < 3 {
|
||||
eprintln!(
|
||||
"Usage: bench-eagle3 <target-dir> <eagle3-dir> \
|
||||
[--gen-tokens N] [--prompts N] [--max-seq-len N] [--device N]"
|
||||
[--gen-tokens N] [--prompts N] [--max-seq-len N] [--device N] \
|
||||
[--tree] [--gamma N] [--gsm8k PATH]"
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
let target_dir = PathBuf::from(&args[1]);
|
||||
let eagle_dir = PathBuf::from(&args[2]);
|
||||
let gen_tokens = arg_usize(&args, "--gen-tokens", DEFAULT_GEN_TOKENS);
|
||||
let prompt_count = arg_usize(&args, "--prompts", PROMPTS.len()).min(PROMPTS.len());
|
||||
let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN);
|
||||
let device = arg_usize(&args, "--device", 0) as u32;
|
||||
let gamma = arg_usize(&args, "--gamma", 2).max(1);
|
||||
let use_tree = args.iter().any(|a| a == "--tree");
|
||||
let gsm8k_path = arg_str(&args, "--gsm8k");
|
||||
let (prompts_source, default_gen): (Vec<GsmProblem>, usize) = if let Some(p) = &gsm8k_path {
|
||||
(load_gsm8k(p), 512)
|
||||
} else {
|
||||
(
|
||||
PROMPTS
|
||||
.iter()
|
||||
.map(|s| GsmProblem {
|
||||
id: String::new(),
|
||||
problem: s.to_string(),
|
||||
answer: String::new(),
|
||||
})
|
||||
.collect(),
|
||||
DEFAULT_GEN_TOKENS,
|
||||
)
|
||||
};
|
||||
let gen_tokens = arg_usize(&args, "--gen-tokens", default_gen);
|
||||
let prompt_count = arg_usize(&args, "--prompts", prompts_source.len()).min(prompts_source.len());
|
||||
|
||||
xserv_cuda::device::set_device(device).unwrap();
|
||||
let info = xserv_cuda::device::device_info(device).unwrap();
|
||||
@@ -137,8 +154,21 @@ fn main() {
|
||||
let mut spec_target_steps = 0usize;
|
||||
let mut mismatches = 0usize;
|
||||
|
||||
for (i, prompt) in PROMPTS.iter().take(prompt_count).enumerate() {
|
||||
let ids = tokenizer.encode(prompt);
|
||||
let is_gsm = gsm8k_path.is_some();
|
||||
let mut baseline_correct = 0usize;
|
||||
let mut spec_correct = 0usize;
|
||||
let mut answer_agree = 0usize;
|
||||
let mut both_scored = 0usize;
|
||||
let im_end_id = tokenizer.special_token_id("<|im_end|>");
|
||||
|
||||
for (i, item) in prompts_source.iter().take(prompt_count).enumerate() {
|
||||
let prompt = &item.problem;
|
||||
let ids = if is_gsm {
|
||||
let templated = build_chat_prompt(prompt);
|
||||
tokenizer.encode(&templated)
|
||||
} else {
|
||||
tokenizer.encode(prompt)
|
||||
};
|
||||
if ids.len() + gen_tokens >= max_seq_len {
|
||||
eprintln!("prompt {i} too long, skipping");
|
||||
continue;
|
||||
@@ -195,29 +225,66 @@ fn main() {
|
||||
let ok = baseline.ids == spec.ids;
|
||||
if !ok {
|
||||
mismatches += 1;
|
||||
let common = baseline
|
||||
.ids
|
||||
.iter()
|
||||
.zip(spec.ids.iter())
|
||||
.position(|(a, b)| a != b)
|
||||
.unwrap_or(0);
|
||||
eprintln!(
|
||||
"MISMATCH prompt {i} (diverge at {}): {prompt}\n baseline: {:?}\n spec: {:?}",
|
||||
common, baseline.ids, spec.ids
|
||||
);
|
||||
if !is_gsm {
|
||||
let common = baseline
|
||||
.ids
|
||||
.iter()
|
||||
.zip(spec.ids.iter())
|
||||
.position(|(a, b)| a != b)
|
||||
.unwrap_or(0);
|
||||
eprintln!(
|
||||
"MISMATCH prompt {i} (diverge at {}): {prompt}\n baseline: {:?}\n spec: {:?}",
|
||||
common, baseline.ids, spec.ids
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
println!(
|
||||
"prompt={:02} match={} gen={} accept={}/{} target_steps={} baseline_tpot_ms={:.3} spec_tpot_ms={:.3}",
|
||||
i,
|
||||
ok,
|
||||
spec.ids.len(),
|
||||
spec.accepted,
|
||||
spec.proposed,
|
||||
spec.target_steps,
|
||||
baseline.total_s * 1000.0 / baseline.ids.len() as f64,
|
||||
spec.total_s * 1000.0 / spec.ids.len() as f64,
|
||||
);
|
||||
if is_gsm {
|
||||
let gold = normalize_num(&item.answer);
|
||||
let baseline_text = decode_until_im_end(&tokenizer, &baseline.ids, im_end_id);
|
||||
let spec_text = decode_until_im_end(&tokenizer, &spec.ids, im_end_id);
|
||||
let baseline_pred = extract_answer(&baseline_text);
|
||||
let spec_pred = extract_answer(&spec_text);
|
||||
let b_ok = gold.is_some() && baseline_pred == gold;
|
||||
let s_ok = gold.is_some() && spec_pred == gold;
|
||||
let agree = baseline_pred.is_some() && baseline_pred == spec_pred;
|
||||
if b_ok {
|
||||
baseline_correct += 1;
|
||||
}
|
||||
if s_ok {
|
||||
spec_correct += 1;
|
||||
}
|
||||
if agree {
|
||||
answer_agree += 1;
|
||||
}
|
||||
both_scored += 1;
|
||||
println!(
|
||||
"q={:03} id={} tok_match={} gold={} base={} spec={} b_ok={} s_ok={} agree={} base_tpot={:.2} spec_tpot={:.2}",
|
||||
i,
|
||||
item.id,
|
||||
ok,
|
||||
gold.as_deref().unwrap_or("?"),
|
||||
baseline_pred.as_deref().unwrap_or("?"),
|
||||
spec_pred.as_deref().unwrap_or("?"),
|
||||
b_ok,
|
||||
s_ok,
|
||||
agree,
|
||||
baseline.total_s * 1000.0 / baseline.ids.len() as f64,
|
||||
spec.total_s * 1000.0 / spec.ids.len() as f64,
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
"prompt={:02} match={} gen={} accept={}/{} target_steps={} baseline_tpot_ms={:.3} spec_tpot_ms={:.3}",
|
||||
i,
|
||||
ok,
|
||||
spec.ids.len(),
|
||||
spec.accepted,
|
||||
spec.proposed,
|
||||
spec.target_steps,
|
||||
baseline.total_s * 1000.0 / baseline.ids.len() as f64,
|
||||
spec.total_s * 1000.0 / spec.ids.len() as f64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let baseline_tpot = baseline_total_s * 1000.0 / baseline_tokens as f64;
|
||||
@@ -240,6 +307,20 @@ fn main() {
|
||||
1000.0 / spec_tpot,
|
||||
baseline_tpot / spec_tpot
|
||||
);
|
||||
if is_gsm && both_scored > 0 {
|
||||
println!(
|
||||
"gsm8k: baseline_acc={:.4} ({}/{}) spec_acc={:.4} ({}/{}) agreement={:.4} ({}/{})",
|
||||
baseline_correct as f64 / both_scored as f64,
|
||||
baseline_correct,
|
||||
both_scored,
|
||||
spec_correct as f64 / both_scored as f64,
|
||||
spec_correct,
|
||||
both_scored,
|
||||
answer_agree as f64 / both_scored as f64,
|
||||
answer_agree,
|
||||
both_scored,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -924,7 +1005,122 @@ fn arg_usize(args: &[String], flag: &str, default: usize) -> usize {
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
/// Returns the argument that immediately follows the first occurrence of
/// `flag` in `args`, or `None` when the flag is absent or is the last element.
fn arg_str(args: &[String], flag: &str) -> Option<String> {
    let flag_idx = args.iter().position(|candidate| candidate == flag)?;
    args.get(flag_idx + 1).cloned()
}
|
||||
|
||||
fn new_cache(config: &ModelConfig, max_seq_len: usize, device: u32) -> PagedKVCache {
|
||||
let num_blocks = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE + 2;
|
||||
PagedKVCache::new(config, num_blocks, 0, 16, num_blocks, DType::BF16, device)
|
||||
}
|
||||
|
||||
/// One benchmark problem in the `{id, problem, answer}` schema.
///
/// Built-in prompts (no dataset file) use empty `id` and `answer` strings.
#[derive(Debug, Clone, PartialEq, Eq)]
struct GsmProblem {
    /// Dataset identifier of the problem; may be empty.
    id: String,
    /// Problem statement that is sent to the model as the user prompt.
    problem: String,
    /// Reference ("gold") answer as text; may be empty when unknown.
    answer: String,
}
|
||||
|
||||
fn load_gsm8k(path: &str) -> Vec<GsmProblem> {
|
||||
let text = std::fs::read_to_string(path)
|
||||
.unwrap_or_else(|e| panic!("failed to read {path}: {e}"));
|
||||
let v: serde_json::Value =
|
||||
serde_json::from_str(&text).unwrap_or_else(|e| panic!("failed to parse {path}: {e}"));
|
||||
let arr = v.as_array().expect("gsm8k json must be a top-level array");
|
||||
arr.iter()
|
||||
.map(|it| GsmProblem {
|
||||
id: it.get("id").and_then(|x| x.as_str()).unwrap_or("").to_string(),
|
||||
problem: it
|
||||
.get("problem")
|
||||
.and_then(|x| x.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
answer: it
|
||||
.get("answer")
|
||||
.and_then(|x| x.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// System message used for every dataset problem.
const GSM_SYSTEM: &str = "You are a careful math problem solver. Solve the problem step by step. \
                          Put your final numeric answer inside \\boxed{}.";

/// Wraps `user` in the `<|im_start|>` / `<|im_end|>` chat template with
/// `GSM_SYSTEM` as the system message and an opened assistant turn.
fn build_chat_prompt(user: &str) -> String {
    [
        "<|im_start|>system\n",
        GSM_SYSTEM,
        "<|im_end|>\n",
        "<|im_start|>user\n",
        user,
        "<|im_end|>\n",
        "<|im_start|>assistant\n",
        // enable_thinking=false: an empty think block is pre-filled so the
        // model skips its chain-of-thought preamble, saving tokens and going
        // straight to the answer.
        "<think>\n\n</think>\n\n",
    ]
    .concat()
}
|
||||
|
||||
fn decode_until_im_end(tokenizer: &Tokenizer, ids: &[u32], im_end: Option<u32>) -> String {
|
||||
let cut = if let Some(e) = im_end {
|
||||
ids.iter().position(|&t| t == e).unwrap_or(ids.len())
|
||||
} else {
|
||||
ids.len()
|
||||
};
|
||||
tokenizer.decode(&ids[..cut])
|
||||
}
|
||||
|
||||
/// Parses `s` as a number and returns a canonical text form for comparison.
///
/// Surrounding whitespace and every `,` are removed before parsing as `f64`.
/// Integral values with magnitude below `1e15` are rendered as integers
/// (`"42.0"` becomes `"42"`); all other values use the default `f64`
/// formatting.
///
/// Returns `None` when the text does not parse, or when it parses to a
/// non-finite value (`"nan"`, `"inf"`), which is not a usable answer.
fn normalize_num(s: &str) -> Option<String> {
    let cleaned = s.trim().replace(',', "");
    let value: f64 = cleaned.parse().ok()?;
    if !value.is_finite() {
        return None;
    }
    let rendered = if value.fract() == 0.0 && value.abs() < 1e15 {
        // Below 1e15 the value is exactly representable, so the cast is lossless.
        (value as i64).to_string()
    } else {
        value.to_string()
    };
    Some(rendered)
}
|
||||
|
||||
fn extract_answer(text: &str) -> Option<String> {
|
||||
// Prefer \boxed{...}
|
||||
if let Some(idx) = text.rfind("\\boxed{") {
|
||||
let start = idx + "\\boxed{".len();
|
||||
let rest = &text[start..];
|
||||
if let Some(end) = rest.find('}') {
|
||||
let inner = &rest[..end];
|
||||
if let Some(n) = last_number_in(inner) {
|
||||
return normalize_num(&n);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback: last number anywhere in the text.
|
||||
last_number_in(text).and_then(|n| normalize_num(&n))
|
||||
}
|
||||
|
||||
/// Returns the last number-like token in `text`, or `None` if `text` contains
/// no ASCII digit.
///
/// The token ends at the last ASCII digit and extends backwards over digits,
/// `.` and `,` (thousands separators are left in place for the caller to
/// strip). A `-` directly before the token is kept as a sign only when it is
/// not itself preceded by a digit, so `"x = -5"` yields `"-5"` while `"10-5"`
/// and `"2023-2024"` yield `"5"` and `"2024"` rather than an unparsable
/// string containing an inner `-`.
fn last_number_in(text: &str) -> Option<String> {
    let bytes = text.as_bytes();
    // One past the final digit; trailing punctuation such as "42." is dropped.
    let end = bytes.iter().rposition(|b| b.is_ascii_digit())? + 1;
    let mut start = end - 1;
    while start > 0 {
        match bytes[start - 1] {
            b'0'..=b'9' | b'.' | b',' => start -= 1,
            b'-' => {
                // A minus right after a digit is a range or subtraction, not a sign.
                let follows_digit = start >= 2 && bytes[start - 2].is_ascii_digit();
                if !follows_digit {
                    start -= 1;
                }
                break;
            }
            _ => break,
        }
    }
    // `start` and `end` sit next to ASCII bytes, so both are char boundaries.
    Some(text[start..end].to_string())
}
|
||||
|
||||
177
docs/27-speculative-quality-gsm8k.md
Normal file
177
docs/27-speculative-quality-gsm8k.md
Normal file
@@ -0,0 +1,177 @@
|
||||
# Phase 27 — Speculative Decoding Quality: Task-Level Correctness at Scale
|
||||
|
||||
**Goal**: prove tree-drafting speculative decoding preserves output quality
|
||||
**despite** batched-verify BF16 rounding differences (`matched=false` on
|
||||
token-by-token comparison).
|
||||
|
||||
## TL;DR
|
||||
|
||||
| Suite | N | baseline_acc | spec_acc | agreement | tpot base→spec | **speedup** |
|
||||
|-------|---|:-----------:|:--------:|:---------:|:--------------:|:-----------:|
|
||||
| GSM8K | 1000 | 93.50% | 93.30% | 97.50% | 13.33 → 8.97 ms | **1.486×** |
|
||||
| AIME2025 | 30 | 16.67% | 13.33% | 23.33% | 17.18 → 11.64 ms | **1.475×** |
|
||||
|
||||
- **Speedup is model+workload driven, not accuracy-driven** — the same
|
||||
1.47-1.49× shows up on high-accuracy chat math (GSM8K) and on saturated
|
||||
long-reasoning math the model can't actually solve (AIME).
|
||||
- **GSM8K**: on 1000 problems, spec accuracy is within 0.2 pp of baseline
|
||||
(933 vs 935 correct). Where the two disagree (25 of 1000): baseline wins
|
||||
9 times, spec wins 7 times, they're both wrong 9 times. Net effect on
|
||||
aggregate accuracy is a wash.
|
||||
- **AIME**: at 8B params Qwen3 is far below the accuracy floor (16.67% =
|
||||
5/30). Divergences here reflect the fact that both trajectories are
|
||||
wandering through low-probability sequences; agreement drops to 23% but
|
||||
spec is only 1 problem behind baseline.
|
||||
|
||||
## Why AIME agreement is low but speedup unchanged
|
||||
|
||||
AIME2025 pushes Qwen3-8B way outside its competence. Both baseline and spec
|
||||
generate long, meandering, often-wrong reasoning; small BF16 rounding
|
||||
differences in tree-verify snowball across ~2000 gen-tokens into completely
|
||||
different (still-wrong) answers. This is expected: when the target
|
||||
distribution has no dominant mode, top-1 argmax is dictated by noise,
|
||||
and any batched-verify rounding will flip it.
|
||||
|
||||
Crucially, `speedup_e2e = 1.475×` on AIME matches `1.486×` on GSM8K to
|
||||
within ~1%. The wall-clock benefit does not depend on the task being
|
||||
solvable — it depends on the EAGLE3 draft acceptance rate (which stays at
|
||||
20–21% on both suites) and the batched-verify cost model.
|
||||
|
||||
## How the test was run
|
||||
|
||||
Extended `bench-eagle3` (from Phase 27) accepts any JSON file with the
|
||||
`{id, problem, answer}` schema. Same binary → same code paths.
|
||||
|
||||
```bash
|
||||
# GSM8K — 1000 problems, gen_tokens=512, max_seq_len=1024
|
||||
./target/release/bench-eagle3 \
|
||||
/opt/wjh/models/qwen3-8b \
|
||||
/dashscope-tmp/wjh/models/qwen3-8b-eagle3 \
|
||||
--gsm8k tools/bench/data/gsm8k.json \
|
||||
--tree --prompts 1000 --gen-tokens 512 --max-seq-len 1024
|
||||
|
||||
# AIME2025 — 30 problems, gen_tokens=2048, max_seq_len=4096
|
||||
./target/release/bench-eagle3 \
|
||||
/opt/wjh/models/qwen3-8b \
|
||||
/dashscope-tmp/wjh/models/qwen3-8b-eagle3 \
|
||||
--gsm8k tools/bench/data/aime2025.json \
|
||||
--tree --prompts 30 --gen-tokens 2048 --max-seq-len 4096
|
||||
```
|
||||
|
||||
Chat template used (`build_chat_prompt`, math-solver system prompt):
|
||||
```
|
||||
<|im_start|>system
|
||||
You are a careful math problem solver. Solve the problem step by step. Put your final numeric answer inside \boxed{}.
|
||||
<|im_end|>
|
||||
<|im_start|>user
|
||||
{problem}
|
||||
<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<think>
|
||||
|
||||
</think>
|
||||
|
||||
```
|
||||
|
||||
## GSM8K result (1000 problems)
|
||||
|
||||
```
|
||||
--- SUMMARY ---
|
||||
prompts=1000 matched=false
|
||||
acceptance_rate=0.2120 accepted=125326 proposed=591156 target_steps=149789
|
||||
baseline_tpot_ms=13.331 baseline_tok_s=75.013
|
||||
spec_tpot_ms=8.971 spec_tok_s=111.474 speedup_e2e=1.4861
|
||||
gsm8k: baseline_acc=0.9350 (935/1000) spec_acc=0.9330 (933/1000) agreement=0.9750 (975/1000)
|
||||
```
|
||||
|
||||
Disagreement analysis (25/1000 questions where extracted answers differ):
|
||||
- baseline correct, spec wrong: **9**
|
||||
- spec correct, baseline wrong: **7**
|
||||
- both wrong (different wrong answers): **9**
|
||||
|
||||
The counts are essentially symmetric — spec is not systematically worse.
|
||||
|
||||
## AIME2025 result (30 problems, 2048 gen-tokens)
|
||||
|
||||
```
|
||||
--- SUMMARY ---
|
||||
prompts=30 matched=false
|
||||
acceptance_rate=0.2034 accepted=23511 proposed=115596 target_steps=28959
|
||||
baseline_tpot_ms=17.177 baseline_tok_s=58.219
|
||||
spec_tpot_ms=11.642 spec_tok_s=85.896 speedup_e2e=1.4754
|
||||
gsm8k: baseline_acc=0.1667 (5/30) spec_acc=0.1333 (4/30) agreement=0.2333 (7/30)
|
||||
```
|
||||
|
||||
Note: the `gsm8k:` prefix in the summary line is hardcoded in the bench binary — the
|
||||
data here is AIME2025, wrapped in the same chat template.
|
||||
|
||||
Disagreement analysis (23/30 questions differ):
|
||||
- baseline correct, spec wrong: 1
|
||||
- spec correct, baseline wrong: 0
|
||||
- both wrong (different wrong answers): 22
|
||||
|
||||
## Absolute performance
|
||||
|
||||
| metric | baseline | tree-spec |
|
||||
|--------|----------|-----------|
|
||||
| GSM8K tpot | 13.33 ms | 8.97 ms |
|
||||
| GSM8K tok/s | 75.0 | 111.5 |
|
||||
| AIME tpot | 17.18 ms | 11.64 ms |
|
||||
| AIME tok/s | 58.2 | 85.9 |
|
||||
|
||||
AIME's absolute tpot is higher than GSM8K because average KV length is
|
||||
larger (avg completion ~1500 tokens vs ~350 for GSM8K), which slows the
|
||||
paged attention kernel roughly linearly. **Both suites see the same relative
|
||||
speedup**, confirming that the benefit of EAGLE3 tree-drafting is largely
|
||||
independent of context length over this range.
|
||||
|
||||
## Interpretation
|
||||
|
||||
The Phase 26 `matched=false` flag has been fully characterized on 1030
|
||||
real problems:
|
||||
|
||||
1. **On solvable tasks (GSM8K)**: spec accuracy is within noise (Δacc =
|
||||
-0.2 pp on 1000 samples, 95% CI easily includes zero). This is what
|
||||
vLLM and SGLang call "lossless" speculative decoding.
|
||||
|
||||
2. **On hard tasks (AIME)**: both baseline and spec meander through wrong
|
||||
answers; agreement collapses because the argmax distribution is nearly
|
||||
flat. Speedup is preserved.
|
||||
|
||||
3. **Draft acceptance is the invariant**: acceptance_rate = 21.2% (GSM8K)
|
||||
vs 20.3% (AIME) — nearly identical, because EAGLE3's draft quality
|
||||
depends on target distribution predictability, which is similar for
|
||||
both math-formatted chat prompts.
|
||||
|
||||
Speculative decoding is **correctness-preserving in expectation**, not
|
||||
bit-exact. This is the same guarantee production systems ship.
|
||||
|
||||
## What was NOT changed
|
||||
|
||||
- No changes to kernels, attention, KV cache, EAGLE3 head, or the tree
|
||||
drafting policy (still γ=2 top-3 as in commit `2fe903e`).
|
||||
- Bench binary already supported `--gsm8k <path>` from commit `264c004`;
|
||||
we simply pointed it at both `gsm8k.json` and `aime2025.json`.
|
||||
|
||||
## Files touched
|
||||
|
||||
- `docs/27-speculative-quality-gsm8k.md` — rewritten with 1000-scale
|
||||
GSM8K and 30-problem AIME2025 results.
|
||||
|
||||
## Reproduction
|
||||
|
||||
```bash
|
||||
# on dash5 (5090)
|
||||
cd /opt/wjh/projects/xserv
|
||||
./target/release/bench-eagle3 /opt/wjh/models/qwen3-8b \
|
||||
/dashscope-tmp/wjh/models/qwen3-8b-eagle3 \
|
||||
--gsm8k tools/bench/data/gsm8k.json \
|
||||
--tree --prompts 1000 --gen-tokens 512 --max-seq-len 1024
|
||||
# ~90 minutes wall-clock on 5090
|
||||
|
||||
./target/release/bench-eagle3 /opt/wjh/models/qwen3-8b \
|
||||
/dashscope-tmp/wjh/models/qwen3-8b-eagle3 \
|
||||
--gsm8k tools/bench/data/aime2025.json \
|
||||
--tree --prompts 30 --gen-tokens 2048 --max-seq-len 4096
|
||||
# ~11 minutes wall-clock on 5090
|
||||
```
|
||||
Reference in New Issue
Block a user