eagle3: coverage + top-3 diagnostic; acceptance ceiling analysis
Add t2d bool tensor loading and per-slot top-3 rate tracking to
bench-eagle3 so we can distinguish three failure modes:
- Not covered: target's argmax not in EAGLE's 32k-vocab (upper bound).
- Not top-3: target's argmax not in EAGLE's top-3 (drafting quality).
- Not top-1: target's argmax not EAGLE's argmax (final acceptance rule).
Measured on 50 prompts × 64 tokens at γ=2:
d[0]: correct=27%, top-3=42%, covered=98% → EAGLE covers the vocab well
but often ranks the target's answer below top-1.
d[1]: correct=9%, top-3=17%, covered=100% → the recursive draft is even
weaker.
Coverage is essentially not a bottleneck (98%+). The bottleneck is
that EAGLE ranks the target's argmax first only ~27% of the time at slot
0. The top-3 rate (~42%) shows the correct answer is often in EAGLE's
distribution but is not the highest-scored candidate.
Exploiting the top-3 headroom would require tree-based verification
(multiple candidates per position, tree-aware attention masking), where
each candidate attends only to its own branch, not to its siblings. The
current paged_decode_attention writes K/V at unique per-batch positions
and does not support tree causal masks.
Speedup formula analysis (from bench-verify-cost), with
theoretical speedup = round_yield / verify_cost:
γ=2: verify_cost=1.11×, round_yield=1.34 → theoretical speedup=1.21×,
observed 1.10× (0.11× lost to EAGLE draft cost + bookkeeping).
γ=4: verify_cost=1.12×, round_yield=1.36 → theoretical speedup=1.21×,
observed 1.02× (0.19× lost).
Current numbers are near-optimal given measured acceptance. Further
gains require either tree drafting (unlocks top-K acceptance) or a
better-trained EAGLE head. Neither is a small change.
This commit is contained in:
@@ -400,6 +400,8 @@ fn run_eagle_gamma_multi(
|
||||
let mut accepted_total = 0usize;
|
||||
let mut proposed_total = 0usize;
|
||||
let mut per_slot_correct: Vec<usize> = vec![0; gamma];
|
||||
let mut per_slot_covered: Vec<usize> = vec![0; gamma];
|
||||
let mut per_slot_top3: Vec<usize> = vec![0; gamma];
|
||||
let mut per_slot_total: Vec<usize> = vec![0; gamma];
|
||||
|
||||
while generated.len() + 1 < gen_tokens && !tokenizer.is_eos(pending_prev) {
|
||||
@@ -416,13 +418,16 @@ fn run_eagle_gamma_multi(
|
||||
// Snapshot EAGLE's cache len so we can roll back rejected drafts' K/V.
|
||||
let eagle_len_before = eagle.current_len();
|
||||
let mut drafts: Vec<u32> = Vec::with_capacity(round_gamma);
|
||||
let (d0, _, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p);
|
||||
let mut draft_logits: Vec<Tensor> = Vec::with_capacity(round_gamma);
|
||||
let (d0, l0, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p);
|
||||
drafts.push(d0);
|
||||
draft_logits.push(l0);
|
||||
let mut prev_aux = aux0;
|
||||
let mut prev_draft = d0;
|
||||
for k in 1..round_gamma {
|
||||
let (dk, _, auxk) = eagle.step_recursive(prev_aux, embed_tokens, prev_draft, p + k);
|
||||
let (dk, lk, auxk) = eagle.step_recursive(prev_aux, embed_tokens, prev_draft, p + k);
|
||||
drafts.push(dk);
|
||||
draft_logits.push(lk);
|
||||
prev_aux = auxk;
|
||||
prev_draft = dk;
|
||||
}
|
||||
@@ -454,12 +459,22 @@ fn run_eagle_gamma_multi(
|
||||
k += 1;
|
||||
}
|
||||
accepted_total += k;
|
||||
// Per-slot diagnostic: independent acceptance rate per position.
|
||||
// Per-slot diagnostic: independent acceptance rate per position AND
|
||||
// whether verify's argmax is even representable in EAGLE's draft vocab
|
||||
// (coverage upper-bound on acceptance).
|
||||
for (i, &d) in drafts.iter().enumerate() {
|
||||
if d == verify_argmax[i] {
|
||||
per_slot_correct[i] += 1;
|
||||
}
|
||||
if eagle.target_id_in_draft_vocab(verify_argmax[i]) {
|
||||
per_slot_covered[i] += 1;
|
||||
}
|
||||
per_slot_total[i] += 1;
|
||||
// Top-3 diagnostic: is verify_argmax[i] in EAGLE's top-3 for slot i?
|
||||
let top3 = top_k_target_ids(&draft_logits[i], 3, eagle);
|
||||
if top3.contains(&verify_argmax[i]) {
|
||||
per_slot_top3[i] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// EAGLE wrote γ K/V entries this round at slots
|
||||
@@ -539,16 +554,27 @@ fn run_eagle_gamma_multi(
|
||||
generated.push(pending_prev);
|
||||
}
|
||||
|
||||
eprint!("[per-slot d[i] correct/total: ");
|
||||
eprint!("[per-slot i=correct/top3/covered/total (c,t3,cov): ");
|
||||
for i in 0..gamma {
|
||||
let rate = if per_slot_total[i] > 0 {
|
||||
per_slot_correct[i] as f64 / per_slot_total[i] as f64
|
||||
let (c_rate, t3_rate, cov_rate) = if per_slot_total[i] > 0 {
|
||||
(
|
||||
per_slot_correct[i] as f64 / per_slot_total[i] as f64,
|
||||
per_slot_top3[i] as f64 / per_slot_total[i] as f64,
|
||||
per_slot_covered[i] as f64 / per_slot_total[i] as f64,
|
||||
)
|
||||
} else {
|
||||
0.0
|
||||
(0.0, 0.0, 0.0)
|
||||
};
|
||||
eprint!(
|
||||
"{}={}/{}({:.2}) ",
|
||||
i, per_slot_correct[i], per_slot_total[i], rate
|
||||
"{}={}/{}/{}/{}({:.2},{:.2},{:.2}) ",
|
||||
i,
|
||||
per_slot_correct[i],
|
||||
per_slot_top3[i],
|
||||
per_slot_covered[i],
|
||||
per_slot_total[i],
|
||||
c_rate,
|
||||
t3_rate,
|
||||
cov_rate
|
||||
);
|
||||
}
|
||||
eprintln!("]");
|
||||
@@ -585,6 +611,22 @@ fn finalize(
|
||||
}
|
||||
}
|
||||
|
||||
/// Return top-k target-vocab ids from a [1, draft_vocab_size] BF16 logits
|
||||
/// tensor. Uses partial sort on CPU (draft_vocab=32k is small enough).
|
||||
fn top_k_target_ids(logits: &Tensor, k: usize, eagle: &Eagle3Head) -> Vec<u32> {
|
||||
use half::bf16;
|
||||
let cpu = logits.to_device(Device::Cpu);
|
||||
let data = cpu.as_slice::<bf16>();
|
||||
let mut idx: Vec<usize> = (0..data.len()).collect();
|
||||
idx.select_nth_unstable_by(k, |&a, &b| {
|
||||
data[b].to_f32().partial_cmp(&data[a].to_f32()).unwrap()
|
||||
});
|
||||
idx[..k]
|
||||
.iter()
|
||||
.map(|&i| eagle.map_draft_to_target(i as u32))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Thin wrapper that forwards `logits` to
/// `xserv_kernels::argmax_bf16_to_host` and returns its result unchanged.
// NOTE(review): going by the helper's name this is presumably a BF16 argmax
// whose ids are copied back to the host — confirm against xserv_kernels.
fn argmax_rows(logits: &Tensor) -> Vec<u32> {
|
||||
xserv_kernels::argmax_bf16_to_host(logits)
|
||||
}
|
||||
|
||||
@@ -43,6 +43,10 @@ pub struct Eagle3Head {
|
||||
norm: Tensor, // [hidden] final
|
||||
lm_head_wt: Tensor, // [draft_vocab, hidden]
|
||||
d2t: Vec<i64>, // [draft_vocab] offset mapping
|
||||
/// t2d[target_id] = true iff target_id has a corresponding draft-vocab id
|
||||
/// (i.e. can potentially be produced by EAGLE). Used to measure the
|
||||
/// coverage cap on acceptance.
|
||||
t2d: Vec<bool>,
|
||||
hidden_size: usize,
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
@@ -59,7 +63,7 @@ pub struct Eagle3Head {
|
||||
|
||||
impl Eagle3Head {
|
||||
pub fn load(dir: &Path, device: u32) -> Self {
|
||||
let (weights, d2t) = load_eagle3_weights(dir, device);
|
||||
let (weights, d2t, t2d) = load_eagle3_weights(dir, device);
|
||||
let hidden_size = 4096;
|
||||
let num_heads = 32;
|
||||
let num_kv_heads = 8;
|
||||
@@ -133,6 +137,7 @@ impl Eagle3Head {
|
||||
norm,
|
||||
lm_head_wt,
|
||||
d2t,
|
||||
t2d,
|
||||
hidden_size,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
@@ -290,6 +295,12 @@ impl Eagle3Head {
|
||||
pub fn map_draft_to_target(&self, draft_id: u32) -> u32 {
|
||||
(draft_id as i64 + self.d2t[draft_id as usize]) as u32
|
||||
}
|
||||
|
||||
/// Returns true iff `target_id` is representable in the draft vocabulary
|
||||
/// (i.e., EAGLE could in principle produce it).
|
||||
pub fn target_id_in_draft_vocab(&self, target_id: u32) -> bool {
|
||||
self.t2d.get(target_id as usize).copied().unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
fn d2d(dst: *mut u8, src: *const u8, bytes: usize) {
|
||||
@@ -349,8 +360,8 @@ fn repeat_kv_for_single_token(kv: &Tensor, repeats: usize) -> Tensor {
|
||||
out
|
||||
}
|
||||
|
||||
/// Load EAGLE3 weights from safetensors, handling int64 d2t specially.
|
||||
fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap<String, Tensor>, Vec<i64>) {
|
||||
/// Load EAGLE3 weights from safetensors, handling int64 d2t + bool t2d specially.
|
||||
fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap<String, Tensor>, Vec<i64>, Vec<bool>) {
|
||||
let st_path = dir.join("model.safetensors");
|
||||
assert!(
|
||||
st_path.exists(),
|
||||
@@ -368,9 +379,13 @@ fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap<String, Tensor>, Vec
|
||||
|
||||
let mut tensors = HashMap::new();
|
||||
let mut d2t_vec: Vec<i64> = Vec::new();
|
||||
let mut t2d_vec: Vec<bool> = Vec::new();
|
||||
|
||||
for (name, view) in st.tensors() {
|
||||
if name == "t2d" {
|
||||
let raw = view.data();
|
||||
assert_eq!(view.dtype(), safetensors::Dtype::BOOL);
|
||||
t2d_vec = raw.iter().map(|&b| b != 0).collect();
|
||||
continue;
|
||||
}
|
||||
if name == "d2t" {
|
||||
@@ -402,5 +417,9 @@ fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap<String, Tensor>, Vec
|
||||
!d2t_vec.is_empty(),
|
||||
"d2t tensor not found in eagle3 weights"
|
||||
);
|
||||
(tensors, d2t_vec)
|
||||
assert!(
|
||||
!t2d_vec.is_empty(),
|
||||
"t2d tensor not found in eagle3 weights"
|
||||
);
|
||||
(tensors, d2t_vec, t2d_vec)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user