eagle3: extend tree to top-3 siblings — speedup_e2e = 1.20×
Widen the tree from 2 siblings to 3 at slot 0 (+ chain from top-1): [pending_prev, d0_top1, d0_top2, d0_top3, d1_chain] positions: [P, P+1, P+1, P+1, P+2] 5×5 tree mask enforcing sibling isolation. 50 prompts × 64 tokens on dash5: acceptance_rate = 12.1% (4 candidates/round) target_steps = 2101 (vs 2231 top-2, 2432 non-tree) spec_tpot_ms = 10.43 ms baseline_tpot_ms = 12.54 ms speedup_e2e = 1.20× (vs 1.17× top-2, 1.10× non-tree) Verify cost at batch=5: ~1.12× single decode (nearly free). The extra sibling adds ~3% additional rounds where EAGLE's top-3 matches target.
This commit is contained in:
@@ -655,7 +655,7 @@ fn run_eagle_tree(
|
|||||||
|
|
||||||
while generated.len() + 1 < gen_tokens && !tokenizer.is_eos(pending_prev) {
|
while generated.len() + 1 < gen_tokens && !tokenizer.is_eos(pending_prev) {
|
||||||
let p = cache.seq_len(slot);
|
let p = cache.seq_len(slot);
|
||||||
if generated.len() + 4 > gen_tokens {
|
if generated.len() + 5 > gen_tokens {
|
||||||
// Not enough room for tree; fall back to committing pending_prev only.
|
// Not enough room for tree; fall back to committing pending_prev only.
|
||||||
generated.push(pending_prev);
|
generated.push(pending_prev);
|
||||||
break;
|
break;
|
||||||
@@ -663,31 +663,43 @@ fn run_eagle_tree(
|
|||||||
|
|
||||||
let eagle_len_before = eagle.current_len();
|
let eagle_len_before = eagle.current_len();
|
||||||
|
|
||||||
// Draft: EAGLE step 0 → top-2 candidates for position P+1.
|
// Draft: EAGLE step 0 → top-3 candidates for position P+1.
|
||||||
let (d0_top1, l0, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p);
|
let (d0_top1, l0, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p);
|
||||||
let top2 = top_k_target_ids(&l0, 2, eagle);
|
let top3 = top_k_target_ids(&l0, 3, eagle);
|
||||||
let d0_top2 = if top2[0] == d0_top1 { top2[1] } else { top2[0] };
|
let d0_top2 = if top3[0] != d0_top1 { top3[0] } else { top3[1] };
|
||||||
|
let d0_top3 = *top3
|
||||||
|
.iter()
|
||||||
|
.find(|&&x| x != d0_top1 && x != d0_top2)
|
||||||
|
.unwrap_or(&top3[2]);
|
||||||
|
|
||||||
// Draft: EAGLE step 1 → chain from d0_top1 for position P+2.
|
// Draft: EAGLE step 1 → chain from d0_top1 for position P+2.
|
||||||
let (d1, _, _) = eagle.step_recursive(aux0, embed_tokens, d0_top1, p + 1);
|
let (d1, _, _) = eagle.step_recursive(aux0, embed_tokens, d0_top1, p + 1);
|
||||||
|
|
||||||
proposed_total += 3; // 3 candidates proposed (d0_top1, d0_top2, d1)
|
proposed_total += 4; // 4 candidates: d0_top1, d0_top2, d0_top3, d1
|
||||||
|
|
||||||
// Verify with tree: 4 tokens.
|
// Verify with tree: 5 tokens (top-3 siblings + chain from top-1).
|
||||||
let verify_input = vec![pending_prev, d0_top1, d0_top2, d1];
|
let verify_input = vec![pending_prev, d0_top1, d0_top2, d0_top3, d1];
|
||||||
let positions_v: Vec<u32> = vec![p as u32, (p + 1) as u32, (p + 1) as u32, (p + 2) as u32];
|
let positions_v: Vec<u32> = vec![
|
||||||
|
p as u32,
|
||||||
|
(p + 1) as u32,
|
||||||
|
(p + 1) as u32,
|
||||||
|
(p + 1) as u32,
|
||||||
|
(p + 2) as u32,
|
||||||
|
];
|
||||||
let kv_lens_v: Vec<i32> = vec![
|
let kv_lens_v: Vec<i32> = vec![
|
||||||
(p + 1) as i32,
|
(p + 1) as i32,
|
||||||
(p + 2) as i32,
|
(p + 2) as i32,
|
||||||
(p + 3) as i32,
|
(p + 3) as i32,
|
||||||
(p + 4) as i32,
|
(p + 4) as i32,
|
||||||
|
(p + 5) as i32,
|
||||||
];
|
];
|
||||||
#[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
let tree_mask: Vec<i32> = vec![
|
let tree_mask: Vec<i32> = vec![
|
||||||
1, 0, 0, 0,
|
1, 0, 0, 0, 0,
|
||||||
1, 1, 0, 0,
|
1, 1, 0, 0, 0,
|
||||||
1, 0, 1, 0,
|
1, 0, 1, 0, 0,
|
||||||
1, 1, 0, 1,
|
1, 0, 0, 1, 0,
|
||||||
|
1, 1, 0, 0, 1,
|
||||||
];
|
];
|
||||||
|
|
||||||
let (verify_logits, verify_hooks) = target
|
let (verify_logits, verify_hooks) = target
|
||||||
@@ -701,11 +713,12 @@ fn run_eagle_tree(
|
|||||||
&EAGLE_HOOK_LAYERS,
|
&EAGLE_HOOK_LAYERS,
|
||||||
);
|
);
|
||||||
target_steps += 1;
|
target_steps += 1;
|
||||||
// cache.seq_len is now p + 4.
|
// cache.seq_len is now p + 5.
|
||||||
|
|
||||||
let va = argmax_rows(&verify_logits);
|
let va = argmax_rows(&verify_logits);
|
||||||
// va[0] = target at P+1; va[1] = target at P+2 given ..d0_top1;
|
// va[0]=target at P+1. va[1]=target at P+2 given ..d0_top1.
|
||||||
// va[2] = target at P+2 given ..d0_top2; va[3] = target at P+3 given ..d1.
|
// va[2]=target at P+2 given ..d0_top2. va[3]=target at P+2 given ..d0_top3.
|
||||||
|
// va[4]=target at P+3 given ..d0_top1, d1.
|
||||||
|
|
||||||
if d0_top1 == va[0] {
|
if d0_top1 == va[0] {
|
||||||
// Top-1 path accepted at slot 0.
|
// Top-1 path accepted at slot 0.
|
||||||
@@ -716,16 +729,16 @@ fn run_eagle_tree(
|
|||||||
// Commit: pending_prev + d0_top1 + d1 = 3 tokens at [P, P+1, P+3].
|
// Commit: pending_prev + d0_top1 + d1 = 3 tokens at [P, P+1, P+3].
|
||||||
// But K/V layout: [P]=pending_prev, [P+1]=d0_top1, [P+2]=d0_top2, [P+3]=d1.
|
// But K/V layout: [P]=pending_prev, [P+1]=d0_top1, [P+2]=d0_top2, [P+3]=d1.
|
||||||
// We need d1 at cache position P+2 (not P+3). Remap:
|
// We need d1 at cache position P+2 (not P+3). Remap:
|
||||||
cache.copy_kv_position(slot, p + 3, p + 2);
|
cache.copy_kv_position(slot, p + 4, p + 2);
|
||||||
cache.truncate_sequence(slot, p + 3).unwrap();
|
cache.truncate_sequence(slot, p + 3).unwrap();
|
||||||
generated.push(pending_prev);
|
generated.push(pending_prev);
|
||||||
generated.push(d0_top1);
|
generated.push(d0_top1);
|
||||||
generated.push(d1);
|
generated.push(d1);
|
||||||
pending_prev = va[3];
|
pending_prev = va[4];
|
||||||
// seed_hooks = verify_hooks[3] (state that produced va[3]).
|
// seed_hooks = verify_hooks[4] (state that produced va[3]).
|
||||||
let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3);
|
let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3);
|
||||||
for h in verify_hooks.iter() {
|
for h in verify_hooks.iter() {
|
||||||
new_hooks.push(h.narrow(0, 3, 1).contiguous());
|
new_hooks.push(h.narrow(0, 4, 1).contiguous());
|
||||||
}
|
}
|
||||||
seed_hooks = [
|
seed_hooks = [
|
||||||
new_hooks.remove(0),
|
new_hooks.remove(0),
|
||||||
@@ -771,8 +784,26 @@ fn run_eagle_tree(
|
|||||||
];
|
];
|
||||||
// EAGLE cache: only step 0 was relevant (pending_prev). Truncate to 1.
|
// EAGLE cache: only step 0 was relevant (pending_prev). Truncate to 1.
|
||||||
eagle.truncate_to(eagle_len_before + 1);
|
eagle.truncate_to(eagle_len_before + 1);
|
||||||
|
} else if d0_top3 == va[0] {
|
||||||
|
accepted_total += 1;
|
||||||
|
// d0_top3's K/V at physical slot p+3 -> canonical p+1.
|
||||||
|
cache.copy_kv_position(slot, p + 3, p + 1);
|
||||||
|
cache.truncate_sequence(slot, p + 2).unwrap();
|
||||||
|
generated.push(pending_prev);
|
||||||
|
generated.push(d0_top3);
|
||||||
|
pending_prev = va[3];
|
||||||
|
let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3);
|
||||||
|
for h in verify_hooks.iter() {
|
||||||
|
new_hooks.push(h.narrow(0, 3, 1).contiguous());
|
||||||
|
}
|
||||||
|
seed_hooks = [
|
||||||
|
new_hooks.remove(0),
|
||||||
|
new_hooks.remove(0),
|
||||||
|
new_hooks.remove(0),
|
||||||
|
];
|
||||||
|
eagle.truncate_to(eagle_len_before + 1);
|
||||||
} else {
|
} else {
|
||||||
// Both rejected. Commit only pending_prev.
|
// All rejected. Commit only pending_prev.
|
||||||
cache.truncate_sequence(slot, p + 1).unwrap();
|
cache.truncate_sequence(slot, p + 1).unwrap();
|
||||||
generated.push(pending_prev);
|
generated.push(pending_prev);
|
||||||
pending_prev = va[0];
|
pending_prev = va[0];
|
||||||
|
|||||||
Reference in New Issue
Block a user