eagle3: γ=1 speculative bench + first end-to-end measurement

bench-eagle3.rs runs the full loop: prefill → for each output token, one
EAGLE draft + one target decode with hidden state hook. Measures
acceptance rate and speedup vs pure target decode.

First numbers on dash5 (10 prompts × 32 tokens, γ=1):
  matched=true (10/10)
  acceptance_rate=1.3% (4/300)  ← should be ~60-70% per EAGLE3 paper
  speedup_e2e=0.95×             ← below 1 because γ=1 does 1 target
                                  decode per output token regardless of
                                  acceptance
  target_steps=320 for 320 tokens

Positive: the plumbing is correct — target/EAGLE both run without error,
output sequences match baseline, all shapes/dtypes check out. The
sanity check earlier showed EAGLE top-5 contains thematically-plausible
tokens (Paris/Tokyo/Madrid for "capital of France is").

Negative: 1.3% acceptance means EAGLE is not currently learning to match
target's greedy top-1. Root causes to investigate:
1. Token/hook pairing convention. Paper uses (h_that_produced_t_i, t_i)
   → predicts t_{i+1}. My bench does the same but sanity check earlier
   suggested pairing might be one off.
2. Missing "training-time test" projection: EAGLE was trained to feed
   its own prev output as fused_h for the next step (γ>1 chaining).
   Currently we always use target hooks, which is what pairing A/B do
   for γ=1, but may not be aligned with training-time behavior.
3. Hook site: I capture x AFTER the residual+MLP. Paper may want x
   BEFORE, or the "hidden_states" as used by the final norm+lm_head.
   Currently the same tensor feeds into final norm during the target
   forward, so pre/post-residual is what I have — but confirming
   against reference Python impl is needed.
4. Weight loading: transposes assume [in,out] → [out,in]. Need to
   validate at least one output layer's shape against expected.

Next step (deferred to another session): download AngelSlim reference
inference code, run same prompt through it, compare intermediate
activations at each stage to isolate the discrepancy.
This commit is contained in:
2026-07-01 17:32:53 +08:00
parent 8f11d6e5cd
commit 68b55fa1e6
2 changed files with 406 additions and 2 deletions

View File

@@ -0,0 +1,384 @@
//! EAGLE3 speculative decoding benchmark (γ=1).
//!
//! Per round:
//! - EAGLE.step(prev_hooks, prev_token, pos) -> draft token d
//! - target.forward_decode_paged(d) -> logits + new hooks -> target argmax = a
//! - If d == target's argmax computed from prev-round hidden ⇒ accept d (already
//! in cache), also commit a. Otherwise reject d (roll back cache) and commit
//! only target's true answer.
//!
//! Speedup potential per round:
//! - Accept path: 1 draft (cheap) + 1 target decode → 2 tokens (2x speedup if
//! draft cost ≈ 0).
//! - Reject path: 1 draft (wasted) + 1 target decode → 1 token (1x, no speedup).
//! - Expected e2e: (1 + accept_rate) tokens per target decode; if accept_rate =
//! 0.7 and draft cost = 10% of target, speedup ≈ 1.7 / 1.1 ≈ 1.55×.
use std::path::PathBuf;
use std::time::Instant;
use xserv_model::eagle3::{EAGLE_HOOK_LAYERS, Eagle3Head};
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
use xserv_tensor::{DType, Device, Tensor};
use xserv_tokenizer::Tokenizer;
const DEFAULT_MAX_SEQ_LEN: usize = 2048;
const DEFAULT_GEN_TOKENS: usize = 64;
const PROMPTS: [&str; 50] = [
"The capital of France is",
"Once upon a time in a land far away",
"Hello, how are you doing today",
"In a shocking finding, scientists discovered a",
"The weather today is sunny, so I decided to",
"Alan Turing was a British mathematician who",
"The best way to learn programming is",
"Artificial intelligence will change the world because",
"The history of the internet began in the",
"A good morning routine starts with",
"The stock market crashed because investors",
"Deep learning is a subset of machine learning that",
"The president of the United States announced",
"In the year 2050, humans will",
"The secret to happiness is",
"When I was a child, I used to",
"The most important scientific discovery of the century",
"Climate change is caused by",
"The recipe for chocolate cake requires",
"In conclusion, the evidence suggests that",
"The cat sat on the mat and",
"According to recent studies, exercise can",
"The first step in solving any problem is",
"Technology has transformed the way we",
"The novel begins with the protagonist",
"Education is the most powerful weapon",
"The ocean covers more than seventy percent of",
"Last night I had a dream about",
"The company announced its quarterly earnings",
"Music has the power to",
"The difference between success and failure is",
"In the beginning, there was nothing but",
"The doctor told me that I should",
"Python is a popular programming language because",
"The ancient Romans built roads that",
"A balanced diet should include",
"The movie received mixed reviews from critics",
"Space exploration has led to many",
"The teacher asked the students to",
"Global warming is one of the most",
"The bridge collapsed due to structural",
"Quantum computing promises to revolutionize",
"The new policy will affect millions of",
"During the winter months, it is important to",
"The chef prepared a delicious meal using",
"Renewable energy sources include",
"The scientist conducted an experiment to",
"In every generation there are people who",
"The smartphone has become an essential part of",
"After careful consideration, the committee decided to",
];
fn main() {
let args: Vec<String> = std::env::args().collect();
if args.len() < 3 {
eprintln!(
"Usage: bench-eagle3 <target-dir> <eagle3-dir> \
[--gen-tokens N] [--prompts N] [--max-seq-len N] [--device N]"
);
std::process::exit(1);
}
let target_dir = PathBuf::from(&args[1]);
let eagle_dir = PathBuf::from(&args[2]);
let gen_tokens = arg_usize(&args, "--gen-tokens", DEFAULT_GEN_TOKENS);
let prompt_count = arg_usize(&args, "--prompts", PROMPTS.len()).min(PROMPTS.len());
let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN);
let device = arg_usize(&args, "--device", 0) as u32;
xserv_cuda::device::set_device(device).unwrap();
let info = xserv_cuda::device::device_info(device).unwrap();
eprintln!(
"GPU {device}: {} ({} MB free)",
info.name,
info.free_memory / 1024 / 1024
);
let target_config = ModelConfig::from_file(&target_dir.join("config.json"));
eprintln!("Loading target Qwen3-8B...");
let target_weights = loader::load_model_dir(&target_dir, Device::Cuda(device));
let target = Qwen3::from_weights(target_config.clone(), target_weights);
xserv_cuda::allocator::cached_trim();
eprintln!("Loading EAGLE3 head...");
let eagle = Eagle3Head::load(&eagle_dir, device);
xserv_cuda::allocator::cached_trim();
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
let embed_tokens = target.embed_tokens_tensor().clone();
// Warmup
{
let mut cache = new_cache(&target_config, max_seq_len, device);
let ids = tokenizer.encode("warmup");
let _ = run_baseline(&target, &mut cache, &tokenizer, &ids, 4);
drop(cache);
}
eprintln!("Warmup done. Running {prompt_count} prompts, gen_tokens={gen_tokens}\n");
let mut baseline_total_s = 0.0f64;
let mut baseline_tokens = 0usize;
let mut spec_total_s = 0.0f64;
let mut spec_tokens = 0usize;
let mut spec_accepted = 0usize;
let mut spec_proposed = 0usize;
let mut spec_target_steps = 0usize;
let mut mismatches = 0usize;
for (i, prompt) in PROMPTS.iter().take(prompt_count).enumerate() {
let ids = tokenizer.encode(prompt);
if ids.len() + gen_tokens >= max_seq_len {
eprintln!("prompt {i} too long, skipping");
continue;
}
// Baseline: pure target decode.
let mut baseline_cache = new_cache(&target_config, max_seq_len, device);
let baseline = run_baseline(&target, &mut baseline_cache, &tokenizer, &ids, gen_tokens);
baseline_total_s += baseline.total_s;
baseline_tokens += baseline.ids.len();
drop(baseline_cache);
// Speculative with EAGLE γ=1.
let mut target_cache = new_cache(&target_config, max_seq_len, device);
let spec = run_eagle_gamma1(
&target,
&eagle,
&embed_tokens,
&mut target_cache,
&tokenizer,
&ids,
gen_tokens,
);
spec_total_s += spec.total_s;
spec_tokens += spec.ids.len();
spec_accepted += spec.accepted;
spec_proposed += spec.proposed;
spec_target_steps += spec.target_steps;
drop(target_cache);
let ok = baseline.ids == spec.ids;
if !ok {
mismatches += 1;
let common = baseline
.ids
.iter()
.zip(spec.ids.iter())
.position(|(a, b)| a != b)
.unwrap_or(0);
eprintln!(
"MISMATCH prompt {i} (diverge at {}): {prompt}\n baseline: {:?}\n spec: {:?}",
common, baseline.ids, spec.ids
);
}
println!(
"prompt={:02} match={} gen={} accept={}/{} target_steps={} baseline_tpot_ms={:.3} spec_tpot_ms={:.3}",
i,
ok,
spec.ids.len(),
spec.accepted,
spec.proposed,
spec.target_steps,
baseline.total_s * 1000.0 / baseline.ids.len() as f64,
spec.total_s * 1000.0 / spec.ids.len() as f64,
);
}
let baseline_tpot = baseline_total_s * 1000.0 / baseline_tokens as f64;
let spec_tpot = spec_total_s * 1000.0 / spec_tokens as f64;
println!("\n--- SUMMARY ---");
println!("prompts={} matched={}", prompt_count, mismatches == 0);
let acceptance = spec_accepted as f64 / spec_proposed.max(1) as f64;
println!(
"acceptance_rate={:.4} accepted={} proposed={} target_steps={}",
acceptance, spec_accepted, spec_proposed, spec_target_steps
);
println!(
"baseline_tpot_ms={:.3} baseline_tok_s={:.3}",
baseline_tpot,
1000.0 / baseline_tpot
);
println!(
"spec_tpot_ms={:.3} spec_tok_s={:.3} speedup_e2e={:.4}",
spec_tpot,
1000.0 / spec_tpot,
baseline_tpot / spec_tpot
);
}
#[derive(Default)]
struct RunStats {
ids: Vec<u32>,
total_s: f64,
target_steps: usize,
accepted: usize,
proposed: usize,
}
fn run_baseline(
model: &Qwen3,
cache: &mut PagedKVCache,
tokenizer: &Tokenizer,
prompt_ids: &[u32],
gen_tokens: usize,
) -> RunStats {
let slot = 0;
cache.register_sequence(slot).unwrap();
let t0 = Instant::now();
let logits = model.forward_prefill_paged(prompt_ids, slot, cache);
let mut next = last_argmax(&logits);
let mut generated = vec![next];
let mut steps = 0;
while generated.len() < gen_tokens && !tokenizer.is_eos(next) {
let pos = cache.seq_len(slot);
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], cache);
next = last_argmax(&logits);
generated.push(next);
steps += 1;
}
sync_device();
let total_s = t0.elapsed().as_secs_f64();
cache.free_sequence(slot);
RunStats {
ids: generated,
total_s,
target_steps: steps + 1, // +1 for prefill
..Default::default()
}
}
/// EAGLE γ=1 speculative decoding.
///
/// Invariant: at the start of each round, we have (prev_token, prev_hooks) where
/// prev_hooks are the target hidden states at the position OF prev_token (i.e.,
/// the state that lm_head applied to yields prev_token).
///
/// Round:
/// 1. Draft: eagle.step(prev_hooks, prev_token) → draft_token d.
/// (EAGLE's pairing: state that produced prev_token, plus prev_token itself.)
/// 2. Target verify: forward_decode_paged(prev_token, position=cache.seq_len).
/// This writes K/V and returns (logits, new_hooks). target_argmax(logits) = a.
/// a is what target REALLY says after prev_token.
/// 3. Accept if d == a: commit both d (as the next token in the sequence — via
/// another target decode) and prev_hooks_of_d becomes new_hooks.
/// Wait — if d==a, then a IS the next token. Commit a; the K/V is already
/// correct. Next round: prev_token=a, prev_hooks=new_hooks.
/// For γ=1 the "accept" doesn't skip target decode; it just means the DRAFT
/// matched. Speedup comes from γ≥2 where you avoid multiple target decodes.
/// So γ=1 gives NO speedup over baseline. But it validates correctness.
///
/// For γ=1 we simply track acceptance rate for informational purposes.
fn run_eagle_gamma1(
target: &Qwen3,
eagle: &Eagle3Head,
embed_tokens: &Tensor,
cache: &mut PagedKVCache,
tokenizer: &Tokenizer,
prompt_ids: &[u32],
gen_tokens: usize,
) -> RunStats {
let slot = 0;
cache.register_sequence(slot).unwrap();
let t0 = Instant::now();
// Prefill target — we don't have hidden state hooks from prefill in this
// impl, so we run 1 decode step after prefill to seed the hooks.
let prefill_logits = target.forward_prefill_paged(prompt_ids, slot, cache);
let first_token = last_argmax(&prefill_logits);
let mut generated = vec![first_token];
// First target decode: input first_token, get new_hooks (which are the state
// that produced NEXT token, i.e., paired with the token this decode outputs).
let (logits, mut hooks) = target_decode_with_hidden(target, first_token, cache, slot);
let mut next = last_argmax(&logits);
generated.push(next);
let mut target_steps = 2; // 1 prefill + 1 decode
let mut accepted = 0usize;
let mut proposed = 0usize;
while generated.len() < gen_tokens && !tokenizer.is_eos(next) {
// Draft: EAGLE predicts token after `next`, using state that produced `next`.
let pos = cache.seq_len(slot);
let (draft, _) = eagle.step(&hooks, embed_tokens, next, pos);
proposed += 1;
// Verify: target decodes `next`, producing its true answer for position `pos`.
let (logits, new_hooks) = target_decode_with_hidden(target, next, cache, slot);
target_steps += 1;
let target_next = last_argmax(&logits);
if draft == target_next {
accepted += 1;
}
generated.push(target_next);
next = target_next;
hooks = new_hooks;
}
sync_device();
let total_s = t0.elapsed().as_secs_f64();
cache.free_sequence(slot);
RunStats {
ids: generated,
total_s,
target_steps,
accepted,
proposed,
}
}
fn target_decode_with_hidden(
target: &Qwen3,
token: u32,
cache: &mut PagedKVCache,
slot: usize,
) -> (Tensor, [Tensor; 3]) {
let position = cache.seq_len(slot);
target.decode_prepare(&[position], &[slot], cache);
let ids_gpu = upload_u32(&[token]);
let pos_gpu = upload_u32(&[position as u32]);
target.decode_core_with_hidden(
ids_gpu.as_ptr() as *const std::ffi::c_void,
pos_gpu.as_ptr() as *const std::ffi::c_void,
1,
&[slot],
cache,
&EAGLE_HOOK_LAYERS,
)
}
fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer {
let bytes = unsafe { std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4) };
let mut buf = xserv_cuda::allocator::cached_alloc(bytes.len()).unwrap();
buf.copy_from_host(bytes).unwrap();
buf
}
fn sync_device() {
xserv_cuda::device::synchronize().unwrap();
}
fn last_argmax(logits: &Tensor) -> u32 {
*xserv_kernels::argmax_bf16_to_host(logits).last().unwrap()
}
fn arg_usize(args: &[String], flag: &str, default: usize) -> usize {
args.iter()
.position(|a| a == flag)
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
fn new_cache(config: &ModelConfig, max_seq_len: usize, device: u32) -> PagedKVCache {
let num_blocks = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE + 2;
PagedKVCache::new(config, num_blocks, 0, 1, num_blocks, DType::BF16, device)
}

View File

@@ -106,7 +106,7 @@ fn main() {
let (eagle_pred, eagle_logits) = eagle.step(&hooks, embed_tokens, target_first, pos); let (eagle_pred, eagle_logits) = eagle.step(&hooks, embed_tokens, target_first, pos);
let eagle_pred_text = tokenizer.decode(&[eagle_pred]); let eagle_pred_text = tokenizer.decode(&[eagle_pred]);
println!( println!(
"EAGLE draft prediction: {} ({:?})", "EAGLE draft prediction (pairing A: prev=target_first): {} ({:?})",
eagle_pred, eagle_pred_text eagle_pred, eagle_pred_text
); );
@@ -120,7 +120,27 @@ fn main() {
} }
// Show top-5 from eagle logits (in draft vocab space, mapped to target). // Show top-5 from eagle logits (in draft vocab space, mapped to target).
print_top5(&eagle_logits, "EAGLE draft top-5", &eagle, &tokenizer); print_top5(
&eagle_logits,
"EAGLE draft top-5 (pairing A)",
&eagle,
&tokenizer,
);
// Alternative pairing B: pair hooks with target_next (the token those hooks produced
// via lm_head), predict token after target_next. Position advances by 1.
let (eagle_pred_b, eagle_logits_b) = eagle.step(&hooks, embed_tokens, target_next, pos + 1);
let eagle_pred_b_text = tokenizer.decode(&[eagle_pred_b]);
println!(
"\nEAGLE draft prediction (pairing B: prev=target_next): {} ({:?})",
eagle_pred_b, eagle_pred_b_text
);
print_top5(
&eagle_logits_b,
"EAGLE draft top-5 (pairing B)",
&eagle,
&tokenizer,
);
} }
fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer { fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer {