eagle3: coverage + top-3 diagnostic; acceptance ceiling analysis

Add t2d bool tensor loading and per-slot top-3 rate tracking to
bench-eagle3 so we can distinguish three failure modes:
- Not covered: target's argmax not in EAGLE's 32k-vocab (upper bound).
- Not top-3: target's argmax not in EAGLE's top-3 (drafting quality).
- Not top-1: target's argmax not EAGLE's argmax (final acceptance rule).

Measured on 50 prompts × 64 tokens γ=2:
  d[0]: correct=27%, top-3=42%, covered=98% → EAGLE covers vocab well
                                              but often ranks target
                                              answer below top-1.
  d[1]: correct=9%,  top-3=17%, covered=100% → recursive draft even
                                               weaker.

Coverage is essentially not a bottleneck (98%+). The bottleneck is
that EAGLE ranks the true target answer first only ~27% of the time at
slot 0. The top-3 rate (~42%) shows the correct answer is often in
EAGLE's distribution but is not the highest-scored candidate.

Exploiting the top-3 headroom would require tree-based verify (multiple
candidates per position, with tree-aware attention masking so that each
candidate attends only to its own branch, not to its siblings). The
current paged_decode_attention writes K/V at unique per-batch positions
and does not support tree causal masks.

Speedup formula analysis (from bench-verify-cost):
  γ=2: verify_cost=1.11×, round_yield=1.34 → theoretical speedup=1.21×,
       observed 1.10× (0.11× lost to EAGLE draft cost + bookkeeping).
  γ=4: verify_cost=1.12×, round_yield=1.36 → theoretical=1.21×,
       observed 1.02×.

Current numbers are near-optimal given measured acceptance. Further
gains require either tree drafting (unlocks top-K acceptance) or a
better-trained EAGLE head. Neither is a small change.
This commit is contained in:
2026-07-01 20:19:28 +08:00
parent cc3bc2188c
commit 10a98539d0
2 changed files with 74 additions and 13 deletions

View File

@@ -400,6 +400,8 @@ fn run_eagle_gamma_multi(
let mut accepted_total = 0usize;
let mut proposed_total = 0usize;
let mut per_slot_correct: Vec<usize> = vec![0; gamma];
let mut per_slot_covered: Vec<usize> = vec![0; gamma];
let mut per_slot_top3: Vec<usize> = vec![0; gamma];
let mut per_slot_total: Vec<usize> = vec![0; gamma];
while generated.len() + 1 < gen_tokens && !tokenizer.is_eos(pending_prev) {
@@ -416,13 +418,16 @@ fn run_eagle_gamma_multi(
// Snapshot EAGLE's cache len so we can roll back rejected drafts' K/V.
let eagle_len_before = eagle.current_len();
let mut drafts: Vec<u32> = Vec::with_capacity(round_gamma);
let (d0, _, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p);
let mut draft_logits: Vec<Tensor> = Vec::with_capacity(round_gamma);
let (d0, l0, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p);
drafts.push(d0);
draft_logits.push(l0);
let mut prev_aux = aux0;
let mut prev_draft = d0;
for k in 1..round_gamma {
let (dk, _, auxk) = eagle.step_recursive(prev_aux, embed_tokens, prev_draft, p + k);
let (dk, lk, auxk) = eagle.step_recursive(prev_aux, embed_tokens, prev_draft, p + k);
drafts.push(dk);
draft_logits.push(lk);
prev_aux = auxk;
prev_draft = dk;
}
@@ -454,12 +459,22 @@ fn run_eagle_gamma_multi(
k += 1;
}
accepted_total += k;
// Per-slot diagnostic: independent acceptance rate per position.
// Per-slot diagnostic: independent acceptance rate per position AND
// whether verify's argmax is even representable in EAGLE's draft vocab
// (coverage upper-bound on acceptance).
for (i, &d) in drafts.iter().enumerate() {
if d == verify_argmax[i] {
per_slot_correct[i] += 1;
}
if eagle.target_id_in_draft_vocab(verify_argmax[i]) {
per_slot_covered[i] += 1;
}
per_slot_total[i] += 1;
// Top-3 diagnostic: is verify_argmax[i] in EAGLE's top-3 for slot i?
let top3 = top_k_target_ids(&draft_logits[i], 3, eagle);
if top3.contains(&verify_argmax[i]) {
per_slot_top3[i] += 1;
}
}
// EAGLE wrote γ K/V entries this round at slots
@@ -539,16 +554,27 @@ fn run_eagle_gamma_multi(
generated.push(pending_prev);
}
eprint!("[per-slot d[i] correct/total: ");
eprint!("[per-slot i=correct/top3/covered/total (c,t3,cov): ");
for i in 0..gamma {
let rate = if per_slot_total[i] > 0 {
per_slot_correct[i] as f64 / per_slot_total[i] as f64
let (c_rate, t3_rate, cov_rate) = if per_slot_total[i] > 0 {
(
per_slot_correct[i] as f64 / per_slot_total[i] as f64,
per_slot_top3[i] as f64 / per_slot_total[i] as f64,
per_slot_covered[i] as f64 / per_slot_total[i] as f64,
)
} else {
0.0
(0.0, 0.0, 0.0)
};
eprint!(
"{}={}/{}({:.2}) ",
i, per_slot_correct[i], per_slot_total[i], rate
"{}={}/{}/{}/{}({:.2},{:.2},{:.2}) ",
i,
per_slot_correct[i],
per_slot_top3[i],
per_slot_covered[i],
per_slot_total[i],
c_rate,
t3_rate,
cov_rate
);
}
eprintln!("]");
@@ -585,6 +611,22 @@ fn finalize(
}
}
/// Return the target-vocab ids of the `k` highest-scoring entries of a
/// `[1, draft_vocab_size]` BF16 logits tensor.
///
/// The tensor is copied to the CPU and partially sorted there (draft_vocab=32k
/// is small enough). Each selected draft index is translated with
/// `Eagle3Head::map_draft_to_target`.
///
/// * `k` is clamped to the number of logits, so `k == 0`, an empty tensor, or
///   `k >= draft_vocab_size` yield a shorter (possibly empty) result instead
///   of panicking.
/// * NaN logits are ranked as negative infinity, so they neither abort the
///   comparison nor show up as spurious top candidates.
/// * The returned ids are the top-`k` set; their relative order is
///   unspecified.
fn top_k_target_ids(logits: &Tensor, k: usize, eagle: &Eagle3Head) -> Vec<u32> {
    use half::bf16;
    let cpu = logits.to_device(Device::Cpu);
    let data = cpu.as_slice::<bf16>();

    // Widen to f32 once instead of inside every comparison, and neutralise
    // NaN so the ordering below is total.
    let scores: Vec<f32> = (0..data.len())
        .map(|i| {
            let v = data[i].to_f32();
            if v.is_nan() {
                f32::NEG_INFINITY
            } else {
                v
            }
        })
        .collect();

    let k = k.min(scores.len());
    if k == 0 {
        return Vec::new();
    }

    let mut idx: Vec<usize> = (0..scores.len()).collect();
    // `select_nth_unstable_by` panics when the pivot index is out of range;
    // when k == len every index is already part of the answer.
    if k < idx.len() {
        // Descending comparator: after the call, idx[..k] holds the k largest.
        idx.select_nth_unstable_by(k, |&a, &b| scores[b].total_cmp(&scores[a]));
    }

    idx[..k]
        .iter()
        .map(|&i| eagle.map_draft_to_target(i as u32))
        .collect()
}
/// Thin wrapper: forwards `logits` to `xserv_kernels::argmax_bf16_to_host`
/// and returns its result unchanged.
///
/// NOTE(review): presumably yields one argmax id per row of a BF16 logits
/// tensor, copied back to the host — confirm against the kernel's docs.
fn argmax_rows(logits: &Tensor) -> Vec<u32> {
xserv_kernels::argmax_bf16_to_host(logits)
}

View File

@@ -43,6 +43,10 @@ pub struct Eagle3Head {
norm: Tensor, // [hidden] final
lm_head_wt: Tensor, // [draft_vocab, hidden]
d2t: Vec<i64>, // [draft_vocab] offset mapping
/// t2d[target_id] = true iff target_id has a corresponding draft-vocab id
/// (i.e. can potentially be produced by EAGLE). Used to measure the
/// coverage cap on acceptance.
t2d: Vec<bool>,
hidden_size: usize,
num_heads: usize,
num_kv_heads: usize,
@@ -59,7 +63,7 @@ pub struct Eagle3Head {
impl Eagle3Head {
pub fn load(dir: &Path, device: u32) -> Self {
let (weights, d2t) = load_eagle3_weights(dir, device);
let (weights, d2t, t2d) = load_eagle3_weights(dir, device);
let hidden_size = 4096;
let num_heads = 32;
let num_kv_heads = 8;
@@ -133,6 +137,7 @@ impl Eagle3Head {
norm,
lm_head_wt,
d2t,
t2d,
hidden_size,
num_heads,
num_kv_heads,
@@ -290,6 +295,12 @@ impl Eagle3Head {
/// Translates a draft-vocabulary id into a target-vocabulary id by adding
/// the per-id offset stored in `d2t`.
///
/// Panics if `draft_id` is not a valid index into `d2t`.
pub fn map_draft_to_target(&self, draft_id: u32) -> u32 {
    let offset = self.d2t[draft_id as usize];
    let target = i64::from(draft_id) + offset;
    target as u32
}
/// Reports whether `target_id` is flagged in the `t2d` table, i.e. whether
/// the draft vocabulary contains a counterpart for it. Ids that fall outside
/// the table are reported as not covered.
pub fn target_id_in_draft_vocab(&self, target_id: u32) -> bool {
    match self.t2d.get(target_id as usize) {
        Some(&covered) => covered,
        None => false,
    }
}
}
fn d2d(dst: *mut u8, src: *const u8, bytes: usize) {
@@ -349,8 +360,8 @@ fn repeat_kv_for_single_token(kv: &Tensor, repeats: usize) -> Tensor {
out
}
/// Load EAGLE3 weights from safetensors, handling int64 d2t specially.
fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap<String, Tensor>, Vec<i64>) {
/// Load EAGLE3 weights from safetensors, handling int64 d2t + bool t2d specially.
fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap<String, Tensor>, Vec<i64>, Vec<bool>) {
let st_path = dir.join("model.safetensors");
assert!(
st_path.exists(),
@@ -368,9 +379,13 @@ fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap<String, Tensor>, Vec
let mut tensors = HashMap::new();
let mut d2t_vec: Vec<i64> = Vec::new();
let mut t2d_vec: Vec<bool> = Vec::new();
for (name, view) in st.tensors() {
if name == "t2d" {
let raw = view.data();
assert_eq!(view.dtype(), safetensors::Dtype::BOOL);
t2d_vec = raw.iter().map(|&b| b != 0).collect();
continue;
}
if name == "d2t" {
@@ -402,5 +417,9 @@ fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap<String, Tensor>, Vec
!d2t_vec.is_empty(),
"d2t tensor not found in eagle3 weights"
);
(tensors, d2t_vec)
assert!(
!t2d_vec.is_empty(),
"t2d tensor not found in eagle3 weights"
);
(tensors, d2t_vec, t2d_vec)
}