eagle3: extend tree to top-3 siblings — speedup_e2e = 1.20×

Widen the tree from 2 siblings to 3 at slot 0 (+ chain from top-1):
  [pending_prev, d0_top1, d0_top2, d0_top3, d1_chain]
  positions:    [P,       P+1,    P+1,    P+1,    P+2]
  5×5 tree mask enforcing sibling isolation.

50 prompts × 64 tokens on dash5:
  acceptance_rate = 12.1% (4 candidates/round)
  target_steps = 2101 (vs 2231 top-2, 2432 non-tree)
  spec_tpot_ms = 10.43 ms
  baseline_tpot_ms = 12.54 ms
  speedup_e2e = 1.20× (vs 1.17× top-2, 1.10× non-tree)

Verification cost at batch=5: ~1.12× a single decode (nearly free). The
extra sibling yields ~3% more rounds in which EAGLE's top-3 candidate
matches the target.
This commit is contained in:
2026-07-02 00:24:57 +08:00
parent aac9ace144
commit 2fe903ecea

View File

@@ -655,7 +655,7 @@ fn run_eagle_tree(
while generated.len() + 1 < gen_tokens && !tokenizer.is_eos(pending_prev) { while generated.len() + 1 < gen_tokens && !tokenizer.is_eos(pending_prev) {
let p = cache.seq_len(slot); let p = cache.seq_len(slot);
if generated.len() + 4 > gen_tokens { if generated.len() + 5 > gen_tokens {
// Not enough room for tree; fall back to committing pending_prev only. // Not enough room for tree; fall back to committing pending_prev only.
generated.push(pending_prev); generated.push(pending_prev);
break; break;
@@ -663,31 +663,43 @@ fn run_eagle_tree(
let eagle_len_before = eagle.current_len(); let eagle_len_before = eagle.current_len();
// Draft: EAGLE step 0 → top-2 candidates for position P+1. // Draft: EAGLE step 0 → top-3 candidates for position P+1.
let (d0_top1, l0, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p); let (d0_top1, l0, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p);
let top2 = top_k_target_ids(&l0, 2, eagle); let top3 = top_k_target_ids(&l0, 3, eagle);
let d0_top2 = if top2[0] == d0_top1 { top2[1] } else { top2[0] }; let d0_top2 = if top3[0] != d0_top1 { top3[0] } else { top3[1] };
let d0_top3 = *top3
.iter()
.find(|&&x| x != d0_top1 && x != d0_top2)
.unwrap_or(&top3[2]);
// Draft: EAGLE step 1 → chain from d0_top1 for position P+2. // Draft: EAGLE step 1 → chain from d0_top1 for position P+2.
let (d1, _, _) = eagle.step_recursive(aux0, embed_tokens, d0_top1, p + 1); let (d1, _, _) = eagle.step_recursive(aux0, embed_tokens, d0_top1, p + 1);
proposed_total += 3; // 3 candidates proposed (d0_top1, d0_top2, d1) proposed_total += 4; // 4 candidates: d0_top1, d0_top2, d0_top3, d1
// Verify with tree: 4 tokens. // Verify with tree: 5 tokens (top-3 siblings + chain from top-1).
let verify_input = vec![pending_prev, d0_top1, d0_top2, d1]; let verify_input = vec![pending_prev, d0_top1, d0_top2, d0_top3, d1];
let positions_v: Vec<u32> = vec![p as u32, (p + 1) as u32, (p + 1) as u32, (p + 2) as u32]; let positions_v: Vec<u32> = vec![
p as u32,
(p + 1) as u32,
(p + 1) as u32,
(p + 1) as u32,
(p + 2) as u32,
];
let kv_lens_v: Vec<i32> = vec![ let kv_lens_v: Vec<i32> = vec![
(p + 1) as i32, (p + 1) as i32,
(p + 2) as i32, (p + 2) as i32,
(p + 3) as i32, (p + 3) as i32,
(p + 4) as i32, (p + 4) as i32,
(p + 5) as i32,
]; ];
#[rustfmt::skip] #[rustfmt::skip]
let tree_mask: Vec<i32> = vec![ let tree_mask: Vec<i32> = vec![
1, 0, 0, 0, 1, 0, 0, 0, 0,
1, 1, 0, 0, 1, 1, 0, 0, 0,
1, 0, 1, 0, 1, 0, 1, 0, 0,
1, 1, 0, 1, 1, 0, 0, 1, 0,
1, 1, 0, 0, 1,
]; ];
let (verify_logits, verify_hooks) = target let (verify_logits, verify_hooks) = target
@@ -701,11 +713,12 @@ fn run_eagle_tree(
&EAGLE_HOOK_LAYERS, &EAGLE_HOOK_LAYERS,
); );
target_steps += 1; target_steps += 1;
// cache.seq_len is now p + 4. // cache.seq_len is now p + 5.
let va = argmax_rows(&verify_logits); let va = argmax_rows(&verify_logits);
// va[0] = target at P+1; va[1] = target at P+2 given ..d0_top1; // va[0]=target at P+1. va[1]=target at P+2 given ..d0_top1.
// va[2] = target at P+2 given ..d0_top2; va[3] = target at P+3 given ..d1. // va[2]=target at P+2 given ..d0_top2. va[3]=target at P+2 given ..d0_top3.
// va[4]=target at P+3 given ..d0_top1, d1.
if d0_top1 == va[0] { if d0_top1 == va[0] {
// Top-1 path accepted at slot 0. // Top-1 path accepted at slot 0.
@@ -716,16 +729,16 @@ fn run_eagle_tree(
// Commit: pending_prev + d0_top1 + d1 = 3 tokens at [P, P+1, P+3]. // Commit: pending_prev + d0_top1 + d1 = 3 tokens at [P, P+1, P+4].
// But K/V layout: [P]=pending_prev, [P+1]=d0_top1, [P+2]=d0_top2, [P+3]=d1. // But K/V layout: [P]=pending_prev, [P+1]=d0_top1, [P+2]=d0_top2, [P+3]=d0_top3, [P+4]=d1.
// We need d1 at cache position P+2 (not P+3). Remap: // We need d1 at cache position P+2 (not P+4). Remap:
cache.copy_kv_position(slot, p + 3, p + 2); cache.copy_kv_position(slot, p + 4, p + 2);
cache.truncate_sequence(slot, p + 3).unwrap(); cache.truncate_sequence(slot, p + 3).unwrap();
generated.push(pending_prev); generated.push(pending_prev);
generated.push(d0_top1); generated.push(d0_top1);
generated.push(d1); generated.push(d1);
pending_prev = va[3]; pending_prev = va[4];
// seed_hooks = verify_hooks[3] (state that produced va[3]). // seed_hooks = verify_hooks[4] (state that produced va[4]).
let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3); let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3);
for h in verify_hooks.iter() { for h in verify_hooks.iter() {
new_hooks.push(h.narrow(0, 3, 1).contiguous()); new_hooks.push(h.narrow(0, 4, 1).contiguous());
} }
seed_hooks = [ seed_hooks = [
new_hooks.remove(0), new_hooks.remove(0),
@@ -771,8 +784,26 @@ fn run_eagle_tree(
]; ];
// EAGLE cache: only step 0 was relevant (pending_prev). Truncate to 1. // EAGLE cache: only step 0 was relevant (pending_prev). Truncate to 1.
eagle.truncate_to(eagle_len_before + 1); eagle.truncate_to(eagle_len_before + 1);
} else if d0_top3 == va[0] {
accepted_total += 1;
// d0_top3's K/V at physical slot p+3 -> canonical p+1.
cache.copy_kv_position(slot, p + 3, p + 1);
cache.truncate_sequence(slot, p + 2).unwrap();
generated.push(pending_prev);
generated.push(d0_top3);
pending_prev = va[3];
let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3);
for h in verify_hooks.iter() {
new_hooks.push(h.narrow(0, 3, 1).contiguous());
}
seed_hooks = [
new_hooks.remove(0),
new_hooks.remove(0),
new_hooks.remove(0),
];
eagle.truncate_to(eagle_len_before + 1);
} else { } else {
// Both rejected. Commit only pending_prev. // All rejected. Commit only pending_prev.
cache.truncate_sequence(slot, p + 1).unwrap(); cache.truncate_sequence(slot, p + 1).unwrap();
generated.push(pending_prev); generated.push(pending_prev);
pending_prev = va[0]; pending_prev = va[0];