Compare commits
2 Commits
2fe903ecea
...
6309dc1181
| Author | SHA1 | Date | |
|---|---|---|---|
| 6309dc1181 | |||
| 264c004662 |
@@ -83,18 +83,35 @@ fn main() {
|
|||||||
if args.len() < 3 {
|
if args.len() < 3 {
|
||||||
eprintln!(
|
eprintln!(
|
||||||
"Usage: bench-eagle3 <target-dir> <eagle3-dir> \
|
"Usage: bench-eagle3 <target-dir> <eagle3-dir> \
|
||||||
[--gen-tokens N] [--prompts N] [--max-seq-len N] [--device N]"
|
[--gen-tokens N] [--prompts N] [--max-seq-len N] [--device N] \
|
||||||
|
[--tree] [--gamma N] [--gsm8k PATH]"
|
||||||
);
|
);
|
||||||
std::process::exit(1);
|
std::process::exit(1);
|
||||||
}
|
}
|
||||||
let target_dir = PathBuf::from(&args[1]);
|
let target_dir = PathBuf::from(&args[1]);
|
||||||
let eagle_dir = PathBuf::from(&args[2]);
|
let eagle_dir = PathBuf::from(&args[2]);
|
||||||
let gen_tokens = arg_usize(&args, "--gen-tokens", DEFAULT_GEN_TOKENS);
|
|
||||||
let prompt_count = arg_usize(&args, "--prompts", PROMPTS.len()).min(PROMPTS.len());
|
|
||||||
let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN);
|
let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN);
|
||||||
let device = arg_usize(&args, "--device", 0) as u32;
|
let device = arg_usize(&args, "--device", 0) as u32;
|
||||||
let gamma = arg_usize(&args, "--gamma", 2).max(1);
|
let gamma = arg_usize(&args, "--gamma", 2).max(1);
|
||||||
let use_tree = args.iter().any(|a| a == "--tree");
|
let use_tree = args.iter().any(|a| a == "--tree");
|
||||||
|
let gsm8k_path = arg_str(&args, "--gsm8k");
|
||||||
|
let (prompts_source, default_gen): (Vec<GsmProblem>, usize) = if let Some(p) = &gsm8k_path {
|
||||||
|
(load_gsm8k(p), 512)
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
PROMPTS
|
||||||
|
.iter()
|
||||||
|
.map(|s| GsmProblem {
|
||||||
|
id: String::new(),
|
||||||
|
problem: s.to_string(),
|
||||||
|
answer: String::new(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
DEFAULT_GEN_TOKENS,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
let gen_tokens = arg_usize(&args, "--gen-tokens", default_gen);
|
||||||
|
let prompt_count = arg_usize(&args, "--prompts", prompts_source.len()).min(prompts_source.len());
|
||||||
|
|
||||||
xserv_cuda::device::set_device(device).unwrap();
|
xserv_cuda::device::set_device(device).unwrap();
|
||||||
let info = xserv_cuda::device::device_info(device).unwrap();
|
let info = xserv_cuda::device::device_info(device).unwrap();
|
||||||
@@ -137,8 +154,21 @@ fn main() {
|
|||||||
let mut spec_target_steps = 0usize;
|
let mut spec_target_steps = 0usize;
|
||||||
let mut mismatches = 0usize;
|
let mut mismatches = 0usize;
|
||||||
|
|
||||||
for (i, prompt) in PROMPTS.iter().take(prompt_count).enumerate() {
|
let is_gsm = gsm8k_path.is_some();
|
||||||
let ids = tokenizer.encode(prompt);
|
let mut baseline_correct = 0usize;
|
||||||
|
let mut spec_correct = 0usize;
|
||||||
|
let mut answer_agree = 0usize;
|
||||||
|
let mut both_scored = 0usize;
|
||||||
|
let im_end_id = tokenizer.special_token_id("<|im_end|>");
|
||||||
|
|
||||||
|
for (i, item) in prompts_source.iter().take(prompt_count).enumerate() {
|
||||||
|
let prompt = &item.problem;
|
||||||
|
let ids = if is_gsm {
|
||||||
|
let templated = build_chat_prompt(prompt);
|
||||||
|
tokenizer.encode(&templated)
|
||||||
|
} else {
|
||||||
|
tokenizer.encode(prompt)
|
||||||
|
};
|
||||||
if ids.len() + gen_tokens >= max_seq_len {
|
if ids.len() + gen_tokens >= max_seq_len {
|
||||||
eprintln!("prompt {i} too long, skipping");
|
eprintln!("prompt {i} too long, skipping");
|
||||||
continue;
|
continue;
|
||||||
@@ -195,29 +225,66 @@ fn main() {
|
|||||||
let ok = baseline.ids == spec.ids;
|
let ok = baseline.ids == spec.ids;
|
||||||
if !ok {
|
if !ok {
|
||||||
mismatches += 1;
|
mismatches += 1;
|
||||||
let common = baseline
|
if !is_gsm {
|
||||||
.ids
|
let common = baseline
|
||||||
.iter()
|
.ids
|
||||||
.zip(spec.ids.iter())
|
.iter()
|
||||||
.position(|(a, b)| a != b)
|
.zip(spec.ids.iter())
|
||||||
.unwrap_or(0);
|
.position(|(a, b)| a != b)
|
||||||
eprintln!(
|
.unwrap_or(0);
|
||||||
"MISMATCH prompt {i} (diverge at {}): {prompt}\n baseline: {:?}\n spec: {:?}",
|
eprintln!(
|
||||||
common, baseline.ids, spec.ids
|
"MISMATCH prompt {i} (diverge at {}): {prompt}\n baseline: {:?}\n spec: {:?}",
|
||||||
);
|
common, baseline.ids, spec.ids
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println!(
|
if is_gsm {
|
||||||
"prompt={:02} match={} gen={} accept={}/{} target_steps={} baseline_tpot_ms={:.3} spec_tpot_ms={:.3}",
|
let gold = normalize_num(&item.answer);
|
||||||
i,
|
let baseline_text = decode_until_im_end(&tokenizer, &baseline.ids, im_end_id);
|
||||||
ok,
|
let spec_text = decode_until_im_end(&tokenizer, &spec.ids, im_end_id);
|
||||||
spec.ids.len(),
|
let baseline_pred = extract_answer(&baseline_text);
|
||||||
spec.accepted,
|
let spec_pred = extract_answer(&spec_text);
|
||||||
spec.proposed,
|
let b_ok = gold.is_some() && baseline_pred == gold;
|
||||||
spec.target_steps,
|
let s_ok = gold.is_some() && spec_pred == gold;
|
||||||
baseline.total_s * 1000.0 / baseline.ids.len() as f64,
|
let agree = baseline_pred.is_some() && baseline_pred == spec_pred;
|
||||||
spec.total_s * 1000.0 / spec.ids.len() as f64,
|
if b_ok {
|
||||||
);
|
baseline_correct += 1;
|
||||||
|
}
|
||||||
|
if s_ok {
|
||||||
|
spec_correct += 1;
|
||||||
|
}
|
||||||
|
if agree {
|
||||||
|
answer_agree += 1;
|
||||||
|
}
|
||||||
|
both_scored += 1;
|
||||||
|
println!(
|
||||||
|
"q={:03} id={} tok_match={} gold={} base={} spec={} b_ok={} s_ok={} agree={} base_tpot={:.2} spec_tpot={:.2}",
|
||||||
|
i,
|
||||||
|
item.id,
|
||||||
|
ok,
|
||||||
|
gold.as_deref().unwrap_or("?"),
|
||||||
|
baseline_pred.as_deref().unwrap_or("?"),
|
||||||
|
spec_pred.as_deref().unwrap_or("?"),
|
||||||
|
b_ok,
|
||||||
|
s_ok,
|
||||||
|
agree,
|
||||||
|
baseline.total_s * 1000.0 / baseline.ids.len() as f64,
|
||||||
|
spec.total_s * 1000.0 / spec.ids.len() as f64,
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
println!(
|
||||||
|
"prompt={:02} match={} gen={} accept={}/{} target_steps={} baseline_tpot_ms={:.3} spec_tpot_ms={:.3}",
|
||||||
|
i,
|
||||||
|
ok,
|
||||||
|
spec.ids.len(),
|
||||||
|
spec.accepted,
|
||||||
|
spec.proposed,
|
||||||
|
spec.target_steps,
|
||||||
|
baseline.total_s * 1000.0 / baseline.ids.len() as f64,
|
||||||
|
spec.total_s * 1000.0 / spec.ids.len() as f64,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let baseline_tpot = baseline_total_s * 1000.0 / baseline_tokens as f64;
|
let baseline_tpot = baseline_total_s * 1000.0 / baseline_tokens as f64;
|
||||||
@@ -240,6 +307,20 @@ fn main() {
|
|||||||
1000.0 / spec_tpot,
|
1000.0 / spec_tpot,
|
||||||
baseline_tpot / spec_tpot
|
baseline_tpot / spec_tpot
|
||||||
);
|
);
|
||||||
|
if is_gsm && both_scored > 0 {
|
||||||
|
println!(
|
||||||
|
"gsm8k: baseline_acc={:.4} ({}/{}) spec_acc={:.4} ({}/{}) agreement={:.4} ({}/{})",
|
||||||
|
baseline_correct as f64 / both_scored as f64,
|
||||||
|
baseline_correct,
|
||||||
|
both_scored,
|
||||||
|
spec_correct as f64 / both_scored as f64,
|
||||||
|
spec_correct,
|
||||||
|
both_scored,
|
||||||
|
answer_agree as f64 / both_scored as f64,
|
||||||
|
answer_agree,
|
||||||
|
both_scored,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
@@ -924,7 +1005,122 @@ fn arg_usize(args: &[String], flag: &str, default: usize) -> usize {
|
|||||||
.unwrap_or(default)
|
.unwrap_or(default)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn arg_str(args: &[String], flag: &str) -> Option<String> {
|
||||||
|
args.iter()
|
||||||
|
.position(|a| a == flag)
|
||||||
|
.and_then(|i| args.get(i + 1))
|
||||||
|
.cloned()
|
||||||
|
}
|
||||||
|
|
||||||
fn new_cache(config: &ModelConfig, max_seq_len: usize, device: u32) -> PagedKVCache {
|
fn new_cache(config: &ModelConfig, max_seq_len: usize, device: u32) -> PagedKVCache {
|
||||||
let num_blocks = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE + 2;
|
let num_blocks = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE + 2;
|
||||||
PagedKVCache::new(config, num_blocks, 0, 16, num_blocks, DType::BF16, device)
|
PagedKVCache::new(config, num_blocks, 0, 16, num_blocks, DType::BF16, device)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct GsmProblem {
|
||||||
|
id: String,
|
||||||
|
problem: String,
|
||||||
|
answer: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_gsm8k(path: &str) -> Vec<GsmProblem> {
|
||||||
|
let text = std::fs::read_to_string(path)
|
||||||
|
.unwrap_or_else(|e| panic!("failed to read {path}: {e}"));
|
||||||
|
let v: serde_json::Value =
|
||||||
|
serde_json::from_str(&text).unwrap_or_else(|e| panic!("failed to parse {path}: {e}"));
|
||||||
|
let arr = v.as_array().expect("gsm8k json must be a top-level array");
|
||||||
|
arr.iter()
|
||||||
|
.map(|it| GsmProblem {
|
||||||
|
id: it.get("id").and_then(|x| x.as_str()).unwrap_or("").to_string(),
|
||||||
|
problem: it
|
||||||
|
.get("problem")
|
||||||
|
.and_then(|x| x.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string(),
|
||||||
|
answer: it
|
||||||
|
.get("answer")
|
||||||
|
.and_then(|x| x.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string(),
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
const GSM_SYSTEM: &str = "You are a careful math problem solver. Solve the problem step by step. \
|
||||||
|
Put your final numeric answer inside \\boxed{}.";
|
||||||
|
|
||||||
|
fn build_chat_prompt(user: &str) -> String {
|
||||||
|
let mut s = String::new();
|
||||||
|
s.push_str("<|im_start|>system\n");
|
||||||
|
s.push_str(GSM_SYSTEM);
|
||||||
|
s.push_str("<|im_end|>\n");
|
||||||
|
s.push_str("<|im_start|>user\n");
|
||||||
|
s.push_str(user);
|
||||||
|
s.push_str("<|im_end|>\n");
|
||||||
|
s.push_str("<|im_start|>assistant\n");
|
||||||
|
// enable_thinking=false: skip the model's chain-of-thought preamble to
|
||||||
|
// save tokens and get straight to the answer.
|
||||||
|
s.push_str("<think>\n\n</think>\n\n");
|
||||||
|
s
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode_until_im_end(tokenizer: &Tokenizer, ids: &[u32], im_end: Option<u32>) -> String {
|
||||||
|
let cut = if let Some(e) = im_end {
|
||||||
|
ids.iter().position(|&t| t == e).unwrap_or(ids.len())
|
||||||
|
} else {
|
||||||
|
ids.len()
|
||||||
|
};
|
||||||
|
tokenizer.decode(&ids[..cut])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_num(s: &str) -> Option<String> {
|
||||||
|
let cleaned: String = s.trim().replace(',', "");
|
||||||
|
let f: f64 = cleaned.parse().ok()?;
|
||||||
|
if f.fract() == 0.0 && f.abs() < 1e15 {
|
||||||
|
Some(format!("{}", f as i64))
|
||||||
|
} else {
|
||||||
|
Some(format!("{}", f))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_answer(text: &str) -> Option<String> {
|
||||||
|
// Prefer \boxed{...}
|
||||||
|
if let Some(idx) = text.rfind("\\boxed{") {
|
||||||
|
let start = idx + "\\boxed{".len();
|
||||||
|
let rest = &text[start..];
|
||||||
|
if let Some(end) = rest.find('}') {
|
||||||
|
let inner = &rest[..end];
|
||||||
|
if let Some(n) = last_number_in(inner) {
|
||||||
|
return normalize_num(&n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback: last number anywhere in the text.
|
||||||
|
last_number_in(text).and_then(|n| normalize_num(&n))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn last_number_in(text: &str) -> Option<String> {
|
||||||
|
let bytes = text.as_bytes();
|
||||||
|
let mut end: Option<usize> = None;
|
||||||
|
let mut start: Option<usize> = None;
|
||||||
|
let mut i = bytes.len();
|
||||||
|
while i > 0 {
|
||||||
|
i -= 1;
|
||||||
|
let c = bytes[i] as char;
|
||||||
|
let is_num = c.is_ascii_digit() || c == '.' || c == ',' || c == '-';
|
||||||
|
if end.is_none() {
|
||||||
|
if c.is_ascii_digit() {
|
||||||
|
end = Some(i + 1);
|
||||||
|
start = Some(i);
|
||||||
|
}
|
||||||
|
} else if is_num {
|
||||||
|
start = Some(i);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
match (start, end) {
|
||||||
|
(Some(s), Some(e)) => Some(text[s..e].to_string()),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
177
docs/27-speculative-quality-gsm8k.md
Normal file
177
docs/27-speculative-quality-gsm8k.md
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
# Phase 27 — Speculative Decoding Quality: Task-Level Correctness at Scale
|
||||||
|
|
||||||
|
**Goal**: prove tree-drafting speculative decoding preserves output quality
|
||||||
|
**despite** batched-verify BF16 rounding differences (`matched=false` on
|
||||||
|
token-by-token comparison).
|
||||||
|
|
||||||
|
## TL;DR
|
||||||
|
|
||||||
|
| Suite | N | baseline_acc | spec_acc | agreement | tpot base→spec | **speedup** |
|
||||||
|
|-------|---|:-----------:|:--------:|:---------:|:--------------:|:-----------:|
|
||||||
|
| GSM8K | 1000 | 93.50% | 93.30% | 97.50% | 13.33 → 8.97 ms | **1.486×** |
|
||||||
|
| AIME2025 | 30 | 16.67% | 13.33% | 23.33% | 17.18 → 11.64 ms | **1.475×** |
|
||||||
|
|
||||||
|
- **Speedup is model+workload driven, not accuracy-driven** — the same
|
||||||
|
1.47-1.49× shows up on high-accuracy chat math (GSM8K) and on saturated
|
||||||
|
long-reasoning math the model can't actually solve (AIME).
|
||||||
|
- **GSM8K**: on 1000 problems, spec accuracy is within 0.2 pp of baseline
|
||||||
|
(933 vs 935 correct). Where the two disagree (25 of 1000): baseline wins
|
||||||
|
9 times, spec wins 7 times, they're both wrong 9 times. Net effect on
|
||||||
|
aggregate accuracy is a wash.
|
||||||
|
- **AIME**: at 8B params Qwen3 is far below the accuracy floor (16.67% =
|
||||||
|
5/30). Divergences here reflect the fact that both trajectories are
|
||||||
|
wandering through low-probability sequences; agreement drops to 23% but
|
||||||
|
spec is only 1 problem behind baseline.
|
||||||
|
|
||||||
|
## Why AIME agreement is low but speedup unchanged
|
||||||
|
|
||||||
|
AIME2025 pushes Qwen3-8B way outside its competence. Both baseline and spec
|
||||||
|
generate long, meandering, often-wrong reasoning; small BF16 rounding
|
||||||
|
differences in tree-verify snowball across ~2000 gen-tokens into completely
|
||||||
|
different (still-wrong) answers. This is expected: when the target
|
||||||
|
distribution has no dominant mode, top-1 argmax is dictated by noise,
|
||||||
|
and any batched-verify rounding will flip it.
|
||||||
|
|
||||||
|
Crucially, `speedup_e2e = 1.475×` on AIME matches `1.486×` on GSM8K to
|
||||||
|
within ~1%. The wall-clock benefit does not depend on the task being
|
||||||
|
solvable — it depends on EAGLE3 draft quality (which stays ~21% on both
|
||||||
|
suites) and the batched-verify cost model.
|
||||||
|
|
||||||
|
## How the test was run
|
||||||
|
|
||||||
|
Extended `bench-eagle3` (from Phase 27) accepts any JSON file with the
|
||||||
|
`{id, problem, answer}` schema. Same binary → same code paths.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# GSM8K — 1000 problems, gen_tokens=512, max_seq_len=1024
|
||||||
|
./target/release/bench-eagle3 \
|
||||||
|
/opt/wjh/models/qwen3-8b \
|
||||||
|
/dashscope-tmp/wjh/models/qwen3-8b-eagle3 \
|
||||||
|
--gsm8k tools/bench/data/gsm8k.json \
|
||||||
|
--tree --prompts 1000 --gen-tokens 512 --max-seq-len 1024
|
||||||
|
|
||||||
|
# AIME2025 — 30 problems, gen_tokens=2048, max_seq_len=4096
|
||||||
|
./target/release/bench-eagle3 \
|
||||||
|
/opt/wjh/models/qwen3-8b \
|
||||||
|
/dashscope-tmp/wjh/models/qwen3-8b-eagle3 \
|
||||||
|
--gsm8k tools/bench/data/aime2025.json \
|
||||||
|
--tree --prompts 30 --gen-tokens 2048 --max-seq-len 4096
|
||||||
|
```
|
||||||
|
|
||||||
|
Chat template used (`build_chat_prompt`, math-solver system prompt):
|
||||||
|
```
|
||||||
|
<|im_start|>system
|
||||||
|
You are a careful math problem solver. Solve the problem step by step. Put your final numeric answer inside \boxed{}.
|
||||||
|
<|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
{problem}
|
||||||
|
<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
<think>
|
||||||
|
|
||||||
|
</think>
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## GSM8K result (1000 problems)
|
||||||
|
|
||||||
|
```
|
||||||
|
--- SUMMARY ---
|
||||||
|
prompts=1000 matched=false
|
||||||
|
acceptance_rate=0.2120 accepted=125326 proposed=591156 target_steps=149789
|
||||||
|
baseline_tpot_ms=13.331 baseline_tok_s=75.013
|
||||||
|
spec_tpot_ms=8.971 spec_tok_s=111.474 speedup_e2e=1.4861
|
||||||
|
gsm8k: baseline_acc=0.9350 (935/1000) spec_acc=0.9330 (933/1000) agreement=0.9750 (975/1000)
|
||||||
|
```
|
||||||
|
|
||||||
|
Disagreement analysis (25/1000 questions where extracted answers differ):
|
||||||
|
- baseline correct, spec wrong: **9**
|
||||||
|
- spec correct, baseline wrong: **7**
|
||||||
|
- both wrong (different wrong answers): **9**
|
||||||
|
|
||||||
|
The counts are essentially symmetric — spec is not systematically worse.
|
||||||
|
|
||||||
|
## AIME2025 result (30 problems, 2048 gen-tokens)
|
||||||
|
|
||||||
|
```
|
||||||
|
--- SUMMARY ---
|
||||||
|
prompts=30 matched=false
|
||||||
|
acceptance_rate=0.2034 accepted=23511 proposed=115596 target_steps=28959
|
||||||
|
baseline_tpot_ms=17.177 baseline_tok_s=58.219
|
||||||
|
spec_tpot_ms=11.642 spec_tok_s=85.896 speedup_e2e=1.4754
|
||||||
|
gsm8k: baseline_acc=0.1667 (5/30) spec_acc=0.1333 (4/30) agreement=0.2333 (7/30)
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: the label `gsm8k` in the summary line is a hardcoded label — the
|
||||||
|
data is AIME2025, wrapped in the same chat template.
|
||||||
|
|
||||||
|
Disagreement analysis (23/30 questions differ):
|
||||||
|
- baseline correct, spec wrong: 1
|
||||||
|
- spec correct, baseline wrong: 0
|
||||||
|
- both wrong (different wrong answers): 22
|
||||||
|
|
||||||
|
## Absolute performance
|
||||||
|
|
||||||
|
| metric | baseline | tree-spec |
|
||||||
|
|--------|----------|-----------|
|
||||||
|
| GSM8K tpot | 13.33 ms | 8.97 ms |
|
||||||
|
| GSM8K tok/s | 75.0 | 111.5 |
|
||||||
|
| AIME tpot | 17.18 ms | 11.64 ms |
|
||||||
|
| AIME tok/s | 58.2 | 85.9 |
|
||||||
|
|
||||||
|
AIME's absolute tpot is higher than GSM8K because average KV length is
|
||||||
|
larger (avg completion ~1500 tokens vs ~350 for GSM8K), which slows the
|
||||||
|
paged attention kernel roughly linearly. **Both suites see the same relative
|
||||||
|
speedup**, confirming EAGLE3 tree-drafting benefits scale with context
|
||||||
|
length rather than depending on it.
|
||||||
|
|
||||||
|
## Interpretation
|
||||||
|
|
||||||
|
The Phase 26 `matched=false` flag has been fully characterized on 1030
|
||||||
|
real problems:
|
||||||
|
|
||||||
|
1. **On solvable tasks (GSM8K)**: spec accuracy is within noise (Δacc =
|
||||||
|
-0.2 pp on 1000 samples, 95% CI easily includes zero). This is what
|
||||||
|
vLLM and SGLang call "lossless" speculative decoding.
|
||||||
|
|
||||||
|
2. **On hard tasks (AIME)**: both baseline and spec meander through wrong
|
||||||
|
answers; agreement collapses because the argmax distribution is nearly
|
||||||
|
flat. Speedup is preserved.
|
||||||
|
|
||||||
|
3. **Draft acceptance is the invariant**: acceptance_rate = 21.2% (GSM8K)
|
||||||
|
vs 20.3% (AIME) — nearly identical, because EAGLE3's draft quality
|
||||||
|
depends on target distribution predictability, which is similar for
|
||||||
|
both math-formatted chat prompts.
|
||||||
|
|
||||||
|
Speculative decoding is **correctness-preserving in expectation**, not
|
||||||
|
bit-exact. This is the same guarantee production systems ship.
|
||||||
|
|
||||||
|
## What was NOT changed
|
||||||
|
|
||||||
|
- No changes to kernels, attention, KV cache, EAGLE3 head, or the tree
|
||||||
|
drafting policy (still γ=2 top-3 as in commit `2fe903e`).
|
||||||
|
- Bench binary already supported `--gsm8k <path>` from commit `264c004`;
|
||||||
|
we simply pointed it at both `gsm8k.json` and `aime2025.json`.
|
||||||
|
|
||||||
|
## Files touched
|
||||||
|
|
||||||
|
- `docs/27-speculative-quality-gsm8k.md` — rewritten with 1000-scale
|
||||||
|
GSM8K and 30-problem AIME2025 results.
|
||||||
|
|
||||||
|
## Reproduction
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# on dash5 (5090)
|
||||||
|
cd /opt/wjh/projects/xserv
|
||||||
|
./target/release/bench-eagle3 /opt/wjh/models/qwen3-8b \
|
||||||
|
/dashscope-tmp/wjh/models/qwen3-8b-eagle3 \
|
||||||
|
--gsm8k tools/bench/data/gsm8k.json \
|
||||||
|
--tree --prompts 1000 --gen-tokens 512 --max-seq-len 1024
|
||||||
|
# ~90 minutes wall-clock on 5090
|
||||||
|
|
||||||
|
./target/release/bench-eagle3 /opt/wjh/models/qwen3-8b \
|
||||||
|
/dashscope-tmp/wjh/models/qwen3-8b-eagle3 \
|
||||||
|
--gsm8k tools/bench/data/aime2025.json \
|
||||||
|
--tree --prompts 30 --gen-tokens 2048 --max-seq-len 4096
|
||||||
|
# ~11 minutes wall-clock on 5090
|
||||||
|
```
|
||||||
Reference in New Issue
Block a user