eagle3: γ=1 speculative bench + first end-to-end measurement
bench-eagle3.rs runs the full loop: prefill → for each output token, one
EAGLE draft + one target decode with hidden state hook. Measures
acceptance rate and speedup vs pure target decode.
First numbers on dash5 (10 prompts × 32 tokens, γ=1):
  matched=true (10/10)
  acceptance_rate=1.3% (4/300)   ← should be ~60-70% per EAGLE3 paper
  speedup_e2e=0.95×              ← below 1 because γ=1 does 1 target decode
                                   per output token regardless of acceptance
  target_steps=320 for 320 tokens
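
The 0.95× figure fits a simple cost model. The following is a rough sketch, not part of bench-eagle3.rs: it assumes each round costs one target decode plus a draft costing a fraction `draft_cost` of a target decode, and that an accepted draft yields an extra token only when the implementation actually skips a target decode on accept. The function names are hypothetical.

```rust
/// Expected speedup of γ=1 speculation over pure target decode.
///
/// `accept_rate`: fraction of drafts that match the target's greedy token.
/// `draft_cost`: cost of one draft step relative to one target decode.
/// `skips_on_accept`: true if an accepted draft saves a target decode
/// (the measurement-only loop in this commit corresponds to `false`).
fn expected_speedup(accept_rate: f64, draft_cost: f64, skips_on_accept: bool) -> f64 {
    let tokens_per_round = if skips_on_accept { 1.0 + accept_rate } else { 1.0 };
    tokens_per_round / (1.0 + draft_cost)
}

/// Relative per-round overhead implied by a measured speedup when no target
/// decode is ever skipped (speedup = 1 / (1 + overhead)).
fn implied_overhead(measured_speedup: f64) -> f64 {
    1.0 / measured_speedup - 1.0
}
```

Under this model, the measured 0.95× implies a per-round overhead of about 5% of a target decode, which would cover both the draft step and any extra cost of the hidden-state hook path.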
Positive: the plumbing works. Target and EAGLE both run without error,
output sequences match baseline, and all shapes/dtypes check out. The
earlier sanity check showed that EAGLE's top-5 contains thematically
plausible tokens (Paris/Tokyo/Madrid for "capital of France is").
Negative: 1.3% acceptance means EAGLE's draft is almost never matching
the target's greedy top-1. Root causes to investigate:
  1. Token/hook pairing convention. Paper uses (h_that_produced_t_i, t_i)
     → predicts t_{i+1}. My bench does the same but sanity check earlier
     suggested pairing might be one off.
  2. Missing "training-time test" projection: EAGLE was trained to feed
     its own prev output as fused_h for the next step (γ>1 chaining).
     Currently we always use target hooks, which is what pairing A/B do
     for γ=1, but may not be aligned with training-time behavior.
  3. Hook site: I capture x AFTER the residual+MLP. Paper may want x
     BEFORE, or the "hidden_states" as used by the final norm+lm_head.
     Currently the same tensor feeds into final norm during the target
     forward, so pre/post-residual is what I have — but confirming
     against reference Python impl is needed.
  4. Weight loading: transposes assume [in,out] → [out,in]. Need to
     validate at least one output layer's shape against expected.
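
One cheap probe for hypothesis 1 is to log the draft and target tokens for each round and count matches at a few relative offsets. If the pairing is off by one, the match rate at offset -1 or +1 should be much higher than at offset 0. The sketch below is a minimal illustration on plain token-id slices; the function name and the logging it presupposes are hypothetical and not part of this commit.

```rust
/// Counts how often the draft from round `i` equals the target token from
/// round `i + offset`. Returns (matches, comparisons made).
///
/// `drafts[i]` is the EAGLE draft proposed in round `i`;
/// `targets[i]` is the target's greedy token committed in round `i`.
fn matches_at_offset(drafts: &[u32], targets: &[u32], offset: isize) -> (usize, usize) {
    let mut matches = 0;
    let mut compared = 0;
    for (i, &draft) in drafts.iter().enumerate() {
        let j = i as isize + offset;
        // Skip rounds whose shifted index falls outside the target log.
        if j < 0 || j as usize >= targets.len() {
            continue;
        }
        compared += 1;
        if draft == targets[j as usize] {
            matches += 1;
        }
    }
    (matches, compared)
}
```

A high count at offset -1 would suggest the draft is predicting the token the target already committed one round earlier; a high count at +1 would suggest it is one step ahead.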
Next step (deferred to another session): download AngelSlim reference
inference code, run same prompt through it, compare intermediate
activations at each stage to isolate the discrepancy.
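
For the activation comparison, a per-stage summary such as maximum absolute difference and cosine similarity is usually enough to find the first stage that diverges. The sketch below assumes both implementations can dump a stage's activations as `f32` slices of equal length; the function name is hypothetical and the dump format is left open.

```rust
/// Compares two activation dumps for the same stage.
/// Returns `Some((max_abs_diff, cosine_similarity))`, or `None` if the
/// slices are empty or differ in length (a shape mismatch is itself a finding).
fn compare_activations(ours: &[f32], reference: &[f32]) -> Option<(f32, f32)> {
    if ours.is_empty() || ours.len() != reference.len() {
        return None;
    }
    let mut max_abs = 0.0f32;
    // Accumulate in f64 so long vectors do not lose precision.
    let (mut dot, mut norm_a, mut norm_b) = (0.0f64, 0.0f64, 0.0f64);
    for (&a, &b) in ours.iter().zip(reference.iter()) {
        max_abs = max_abs.max((a - b).abs());
        dot += a as f64 * b as f64;
        norm_a += a as f64 * a as f64;
        norm_b += b as f64 * b as f64;
    }
    let denom = (norm_a * norm_b).sqrt();
    // Define the similarity as 0 when either vector is all zeros.
    let cosine = if denom > 0.0 { (dot / denom) as f32 } else { 0.0 };
    Some((max_abs, cosine))
}
```

Running this on the embedding output, the fused hook input, each EAGLE layer output, and the draft logits, in that order, should show where the two implementations first disagree.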
crates/xserv-model/src/bin/bench-eagle3.rs
Normal file
@@ -0,0 +1,384 @@
//! EAGLE3 speculative decoding benchmark (γ=1).
//!
//! Per round:
//!   - EAGLE.step(prev_hooks, prev_token, pos) -> draft token d
//!   - target.forward_decode_paged(d) -> logits + new hooks -> target argmax = a
//!   - If d == target's argmax computed from prev-round hidden ⇒ accept d (already
//!     in cache), also commit a. Otherwise reject d (roll back cache) and commit
//!     only target's true answer.
//!
//! Speedup potential per round:
//!   - Accept path: 1 draft (cheap) + 1 target decode → 2 tokens (2x speedup if
//!     draft cost ≈ 0).
//!   - Reject path: 1 draft (wasted) + 1 target decode → 1 token (1x, no speedup).
//!   - Expected e2e: (1 + accept_rate) tokens per target decode; if accept_rate =
//!     0.7 and draft cost = 10% of target, speedup ≈ 1.7 / 1.1 ≈ 1.55×.

use std::path::PathBuf;
use std::time::Instant;

use xserv_model::eagle3::{EAGLE_HOOK_LAYERS, Eagle3Head};
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
use xserv_tensor::{DType, Device, Tensor};
use xserv_tokenizer::Tokenizer;

const DEFAULT_MAX_SEQ_LEN: usize = 2048;
const DEFAULT_GEN_TOKENS: usize = 64;

const PROMPTS: [&str; 50] = [
    "The capital of France is",
    "Once upon a time in a land far away",
    "Hello, how are you doing today",
    "In a shocking finding, scientists discovered a",
    "The weather today is sunny, so I decided to",
    "Alan Turing was a British mathematician who",
    "The best way to learn programming is",
    "Artificial intelligence will change the world because",
    "The history of the internet began in the",
    "A good morning routine starts with",
    "The stock market crashed because investors",
    "Deep learning is a subset of machine learning that",
    "The president of the United States announced",
    "In the year 2050, humans will",
    "The secret to happiness is",
    "When I was a child, I used to",
    "The most important scientific discovery of the century",
    "Climate change is caused by",
    "The recipe for chocolate cake requires",
    "In conclusion, the evidence suggests that",
    "The cat sat on the mat and",
    "According to recent studies, exercise can",
    "The first step in solving any problem is",
    "Technology has transformed the way we",
    "The novel begins with the protagonist",
    "Education is the most powerful weapon",
    "The ocean covers more than seventy percent of",
    "Last night I had a dream about",
    "The company announced its quarterly earnings",
    "Music has the power to",
    "The difference between success and failure is",
    "In the beginning, there was nothing but",
    "The doctor told me that I should",
    "Python is a popular programming language because",
    "The ancient Romans built roads that",
    "A balanced diet should include",
    "The movie received mixed reviews from critics",
    "Space exploration has led to many",
    "The teacher asked the students to",
    "Global warming is one of the most",
    "The bridge collapsed due to structural",
    "Quantum computing promises to revolutionize",
    "The new policy will affect millions of",
    "During the winter months, it is important to",
    "The chef prepared a delicious meal using",
    "Renewable energy sources include",
    "The scientist conducted an experiment to",
    "In every generation there are people who",
    "The smartphone has become an essential part of",
    "After careful consideration, the committee decided to",
];

fn main() {
    let args: Vec<String> = std::env::args().collect();
    if args.len() < 3 {
        eprintln!(
            "Usage: bench-eagle3 <target-dir> <eagle3-dir> \
             [--gen-tokens N] [--prompts N] [--max-seq-len N] [--device N]"
        );
        std::process::exit(1);
    }
    let target_dir = PathBuf::from(&args[1]);
    let eagle_dir = PathBuf::from(&args[2]);
    let gen_tokens = arg_usize(&args, "--gen-tokens", DEFAULT_GEN_TOKENS);
    let prompt_count = arg_usize(&args, "--prompts", PROMPTS.len()).min(PROMPTS.len());
    let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN);
    let device = arg_usize(&args, "--device", 0) as u32;

    xserv_cuda::device::set_device(device).unwrap();
    let info = xserv_cuda::device::device_info(device).unwrap();
    eprintln!(
        "GPU {device}: {} ({} MB free)",
        info.name,
        info.free_memory / 1024 / 1024
    );

    let target_config = ModelConfig::from_file(&target_dir.join("config.json"));
    eprintln!("Loading target Qwen3-8B...");
    let target_weights = loader::load_model_dir(&target_dir, Device::Cuda(device));
    let target = Qwen3::from_weights(target_config.clone(), target_weights);
    xserv_cuda::allocator::cached_trim();

    eprintln!("Loading EAGLE3 head...");
    let eagle = Eagle3Head::load(&eagle_dir, device);
    xserv_cuda::allocator::cached_trim();

    let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
    let embed_tokens = target.embed_tokens_tensor().clone();

    // Warmup
    {
        let mut cache = new_cache(&target_config, max_seq_len, device);
        let ids = tokenizer.encode("warmup");
        let _ = run_baseline(&target, &mut cache, &tokenizer, &ids, 4);
        drop(cache);
    }
    eprintln!("Warmup done. Running {prompt_count} prompts, gen_tokens={gen_tokens}\n");

    let mut baseline_total_s = 0.0f64;
    let mut baseline_tokens = 0usize;
    let mut spec_total_s = 0.0f64;
    let mut spec_tokens = 0usize;
    let mut spec_accepted = 0usize;
    let mut spec_proposed = 0usize;
    let mut spec_target_steps = 0usize;
    let mut mismatches = 0usize;

    for (i, prompt) in PROMPTS.iter().take(prompt_count).enumerate() {
        let ids = tokenizer.encode(prompt);
        if ids.len() + gen_tokens >= max_seq_len {
            eprintln!("prompt {i} too long, skipping");
            continue;
        }

        // Baseline: pure target decode.
        let mut baseline_cache = new_cache(&target_config, max_seq_len, device);
        let baseline = run_baseline(&target, &mut baseline_cache, &tokenizer, &ids, gen_tokens);
        baseline_total_s += baseline.total_s;
        baseline_tokens += baseline.ids.len();
        drop(baseline_cache);

        // Speculative with EAGLE γ=1.
        let mut target_cache = new_cache(&target_config, max_seq_len, device);
        let spec = run_eagle_gamma1(
            &target,
            &eagle,
            &embed_tokens,
            &mut target_cache,
            &tokenizer,
            &ids,
            gen_tokens,
        );
        spec_total_s += spec.total_s;
        spec_tokens += spec.ids.len();
        spec_accepted += spec.accepted;
        spec_proposed += spec.proposed;
        spec_target_steps += spec.target_steps;
        drop(target_cache);

        let ok = baseline.ids == spec.ids;
        if !ok {
            mismatches += 1;
            // Index of the first differing token. If one sequence is a strict
            // prefix of the other, report the shorter length rather than 0.
            let common = baseline
                .ids
                .iter()
                .zip(spec.ids.iter())
                .position(|(a, b)| a != b)
                .unwrap_or_else(|| baseline.ids.len().min(spec.ids.len()));
            eprintln!(
                "MISMATCH prompt {i} (diverge at {}): {prompt}\n baseline: {:?}\n spec: {:?}",
                common, baseline.ids, spec.ids
            );
        }

        println!(
            "prompt={:02} match={} gen={} accept={}/{} target_steps={} baseline_tpot_ms={:.3} spec_tpot_ms={:.3}",
            i,
            ok,
            spec.ids.len(),
            spec.accepted,
            spec.proposed,
            spec.target_steps,
            baseline.total_s * 1000.0 / baseline.ids.len() as f64,
            spec.total_s * 1000.0 / spec.ids.len() as f64,
        );
    }

    let baseline_tpot = baseline_total_s * 1000.0 / baseline_tokens as f64;
    let spec_tpot = spec_total_s * 1000.0 / spec_tokens as f64;
    println!("\n--- SUMMARY ---");
    println!("prompts={} matched={}", prompt_count, mismatches == 0);
    let acceptance = spec_accepted as f64 / spec_proposed.max(1) as f64;
    println!(
        "acceptance_rate={:.4} accepted={} proposed={} target_steps={}",
        acceptance, spec_accepted, spec_proposed, spec_target_steps
    );
    println!(
        "baseline_tpot_ms={:.3} baseline_tok_s={:.3}",
        baseline_tpot,
        1000.0 / baseline_tpot
    );
    println!(
        "spec_tpot_ms={:.3} spec_tok_s={:.3} speedup_e2e={:.4}",
        spec_tpot,
        1000.0 / spec_tpot,
        baseline_tpot / spec_tpot
    );
}

#[derive(Default)]
struct RunStats {
    ids: Vec<u32>,
    total_s: f64,
    target_steps: usize,
    accepted: usize,
    proposed: usize,
}

fn run_baseline(
    model: &Qwen3,
    cache: &mut PagedKVCache,
    tokenizer: &Tokenizer,
    prompt_ids: &[u32],
    gen_tokens: usize,
) -> RunStats {
    let slot = 0;
    cache.register_sequence(slot).unwrap();
    let t0 = Instant::now();
    let logits = model.forward_prefill_paged(prompt_ids, slot, cache);
    let mut next = last_argmax(&logits);
    let mut generated = vec![next];
    let mut steps = 0;
    while generated.len() < gen_tokens && !tokenizer.is_eos(next) {
        let pos = cache.seq_len(slot);
        let logits = model.forward_decode_paged(&[next], &[pos], &[slot], cache);
        next = last_argmax(&logits);
        generated.push(next);
        steps += 1;
    }
    sync_device();
    let total_s = t0.elapsed().as_secs_f64();
    cache.free_sequence(slot);
    RunStats {
        ids: generated,
        total_s,
        target_steps: steps + 1, // +1 for prefill
        ..Default::default()
    }
}

/// EAGLE γ=1 speculative decoding (acceptance measurement only).
///
/// Invariant: at the start of each round, we have (prev_token, prev_hooks) where
/// prev_hooks are the target hidden states at the position OF prev_token (i.e.,
/// the state that lm_head applied to yields prev_token).
///
/// Round:
///   1. Draft: eagle.step(prev_hooks, prev_token) → draft_token d.
///      (EAGLE's pairing: state that produced prev_token, plus prev_token itself.)
///   2. Target verify: forward_decode_paged(prev_token, position=cache.seq_len).
///      This writes K/V and returns (logits, new_hooks). target_argmax(logits) = a.
///      a is what target REALLY says after prev_token.
///   3. Compare: if d == a, count the draft as accepted. Either way, commit a;
///      the K/V written in step 2 is already correct.
///      Next round: prev_token=a, prev_hooks=new_hooks.
///
/// For γ=1 the "accept" doesn't skip target decode; it just means the DRAFT
/// matched. Speedup comes from γ≥2 where you avoid multiple target decodes.
/// So γ=1 gives NO speedup over baseline, but it validates correctness, and
/// the acceptance rate is tracked for informational purposes.
fn run_eagle_gamma1(
    target: &Qwen3,
    eagle: &Eagle3Head,
    embed_tokens: &Tensor,
    cache: &mut PagedKVCache,
    tokenizer: &Tokenizer,
    prompt_ids: &[u32],
    gen_tokens: usize,
) -> RunStats {
    let slot = 0;
    cache.register_sequence(slot).unwrap();
    let t0 = Instant::now();

    // Prefill target. We don't have hidden state hooks from prefill in this
    // impl, so we run 1 decode step after prefill to seed the hooks.
    let prefill_logits = target.forward_prefill_paged(prompt_ids, slot, cache);
    let first_token = last_argmax(&prefill_logits);
    let mut generated = vec![first_token];

    // First target decode: input first_token, get new_hooks (which are the state
    // that produced NEXT token, i.e., paired with the token this decode outputs).
    // NOTE: this seed decode is unconditional, so if gen_tokens < 2 or first_token
    // is EOS, this path emits one more token than run_baseline.
    let (logits, mut hooks) = target_decode_with_hidden(target, first_token, cache, slot);
    let mut next = last_argmax(&logits);
    generated.push(next);
    let mut target_steps = 2; // 1 prefill + 1 decode
    let mut accepted = 0usize;
    let mut proposed = 0usize;

    while generated.len() < gen_tokens && !tokenizer.is_eos(next) {
        // Draft: EAGLE predicts token after `next`, using state that produced `next`.
        let pos = cache.seq_len(slot);
        let (draft, _) = eagle.step(&hooks, embed_tokens, next, pos);
        proposed += 1;

        // Verify: target decodes `next` at position `pos`, producing its true
        // answer for the token that follows `next`.
        let (logits, new_hooks) = target_decode_with_hidden(target, next, cache, slot);
        target_steps += 1;
        let target_next = last_argmax(&logits);

        if draft == target_next {
            accepted += 1;
        }
        generated.push(target_next);
        next = target_next;
        hooks = new_hooks;
    }
    sync_device();
    let total_s = t0.elapsed().as_secs_f64();
    cache.free_sequence(slot);
    RunStats {
        ids: generated,
        total_s,
        target_steps,
        accepted,
        proposed,
    }
}

fn target_decode_with_hidden(
    target: &Qwen3,
    token: u32,
    cache: &mut PagedKVCache,
    slot: usize,
) -> (Tensor, [Tensor; 3]) {
    let position = cache.seq_len(slot);
    target.decode_prepare(&[position], &[slot], cache);
    let ids_gpu = upload_u32(&[token]);
    let pos_gpu = upload_u32(&[position as u32]);
    target.decode_core_with_hidden(
        ids_gpu.as_ptr() as *const std::ffi::c_void,
        pos_gpu.as_ptr() as *const std::ffi::c_void,
        1,
        &[slot],
        cache,
        &EAGLE_HOOK_LAYERS,
    )
}

fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer {
    let bytes = unsafe { std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4) };
    let mut buf = xserv_cuda::allocator::cached_alloc(bytes.len()).unwrap();
    buf.copy_from_host(bytes).unwrap();
    buf
}

fn sync_device() {
    xserv_cuda::device::synchronize().unwrap();
}

fn last_argmax(logits: &Tensor) -> u32 {
    *xserv_kernels::argmax_bf16_to_host(logits).last().unwrap()
}

fn arg_usize(args: &[String], flag: &str, default: usize) -> usize {
    args.iter()
        .position(|a| a == flag)
        .and_then(|i| args.get(i + 1))
        .and_then(|s| s.parse().ok())
        .unwrap_or(default)
}

fn new_cache(config: &ModelConfig, max_seq_len: usize, device: u32) -> PagedKVCache {
    let num_blocks = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE + 2;
    PagedKVCache::new(config, num_blocks, 0, 1, num_blocks, DType::BF16, device)
}
@@ -106,7 +106,7 @@ fn main() {
     let (eagle_pred, eagle_logits) = eagle.step(&hooks, embed_tokens, target_first, pos);
     let eagle_pred_text = tokenizer.decode(&[eagle_pred]);
     println!(
-        "EAGLE draft prediction: {} ({:?})",
+        "EAGLE draft prediction (pairing A: prev=target_first): {} ({:?})",
         eagle_pred, eagle_pred_text
     );
 
@@ -120,7 +120,27 @@ fn main() {
     }
 
     // Show top-5 from eagle logits (in draft vocab space, mapped to target).
-    print_top5(&eagle_logits, "EAGLE draft top-5", &eagle, &tokenizer);
+    print_top5(
+        &eagle_logits,
+        "EAGLE draft top-5 (pairing A)",
+        &eagle,
+        &tokenizer,
+    );
+
+    // Alternative pairing B: pair hooks with target_next (the token those hooks produced
+    // via lm_head), predict token after target_next. Position advances by 1.
+    let (eagle_pred_b, eagle_logits_b) = eagle.step(&hooks, embed_tokens, target_next, pos + 1);
+    let eagle_pred_b_text = tokenizer.decode(&[eagle_pred_b]);
+    println!(
+        "\nEAGLE draft prediction (pairing B: prev=target_next): {} ({:?})",
+        eagle_pred_b, eagle_pred_b_text
+    );
+    print_top5(
+        &eagle_logits_b,
+        "EAGLE draft top-5 (pairing B)",
+        &eagle,
+        &tokenizer,
+    );
 }
 
 fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer {
|
||||