From 2fe903eceab60201de706c13fc3929c9c3ab68f7 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Thu, 2 Jul 2026 00:24:57 +0800 Subject: [PATCH] =?UTF-8?q?eagle3:=20extend=20tree=20to=20top-3=20siblings?= =?UTF-8?q?=20=E2=80=94=20speedup=5Fe2e=20=3D=201.20=C3=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Widen the tree from 2 siblings to 3 at slot 0 (+ chain from top-1): [pending_prev, d0_top1, d0_top2, d0_top3, d1_chain] positions: [P, P+1, P+1, P+1, P+2] 5×5 tree mask enforcing sibling isolation. 50 prompts × 64 tokens on dash5: acceptance_rate = 12.1% (4 candidates/round) target_steps = 2101 (vs 2231 top-2, 2432 non-tree) spec_tpot_ms = 10.43 ms baseline_tpot_ms = 12.54 ms speedup_e2e = 1.20× (vs 1.17× top-2, 1.10× non-tree) Verify cost at batch=5: ~1.12× single decode (nearly free). The extra sibling adds ~3% additional rounds where EAGLE's top-3 matches target. --- crates/xserv-model/src/bin/bench-eagle3.rs | 71 ++++++++++++++++------ 1 file changed, 51 insertions(+), 20 deletions(-) diff --git a/crates/xserv-model/src/bin/bench-eagle3.rs b/crates/xserv-model/src/bin/bench-eagle3.rs index 2b0f47d..6a34fd0 100644 --- a/crates/xserv-model/src/bin/bench-eagle3.rs +++ b/crates/xserv-model/src/bin/bench-eagle3.rs @@ -655,7 +655,7 @@ fn run_eagle_tree( while generated.len() + 1 < gen_tokens && !tokenizer.is_eos(pending_prev) { let p = cache.seq_len(slot); - if generated.len() + 4 > gen_tokens { + if generated.len() + 5 > gen_tokens { // Not enough room for tree; fall back to committing pending_prev only. generated.push(pending_prev); break; @@ -663,31 +663,43 @@ fn run_eagle_tree( let eagle_len_before = eagle.current_len(); - // Draft: EAGLE step 0 → top-2 candidates for position P+1. + // Draft: EAGLE step 0 → top-3 candidates for position P+1. let (d0_top1, l0, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p); - let top2 = top_k_target_ids(&l0, 2, eagle); - let d0_top2 = if top2[0] == d0_top1 { top2[1] } else { top2[0] }; + let top3 = top_k_target_ids(&l0, 3, eagle); + let d0_top2 = if top3[0] != d0_top1 { top3[0] } else { top3[1] }; + let d0_top3 = *top3 + .iter() + .find(|&&x| x != d0_top1 && x != d0_top2) + .unwrap_or(&top3[2]); // Draft: EAGLE step 1 → chain from d0_top1 for position P+2. let (d1, _, _) = eagle.step_recursive(aux0, embed_tokens, d0_top1, p + 1); - proposed_total += 3; // 3 candidates proposed (d0_top1, d0_top2, d1) + proposed_total += 4; // 4 candidates: d0_top1, d0_top2, d0_top3, d1 - // Verify with tree: 4 tokens. - let verify_input = vec![pending_prev, d0_top1, d0_top2, d1]; - let positions_v: Vec = vec![p as u32, (p + 1) as u32, (p + 1) as u32, (p + 2) as u32]; + // Verify with tree: 5 tokens (top-3 siblings + chain from top-1). + let verify_input = vec![pending_prev, d0_top1, d0_top2, d0_top3, d1]; + let positions_v: Vec = vec![ + p as u32, + (p + 1) as u32, + (p + 1) as u32, + (p + 1) as u32, + (p + 2) as u32, + ]; let kv_lens_v: Vec = vec![ (p + 1) as i32, (p + 2) as i32, (p + 3) as i32, (p + 4) as i32, + (p + 5) as i32, ]; #[rustfmt::skip] let tree_mask: Vec = vec![ - 1, 0, 0, 0, - 1, 1, 0, 0, - 1, 0, 1, 0, - 1, 1, 0, 1, + 1, 0, 0, 0, 0, + 1, 1, 0, 0, 0, + 1, 0, 1, 0, 0, + 1, 0, 0, 1, 0, + 1, 1, 0, 0, 1, ]; let (verify_logits, verify_hooks) = target @@ -701,11 +713,12 @@ fn run_eagle_tree( &EAGLE_HOOK_LAYERS, ); target_steps += 1; - // cache.seq_len is now p + 4. + // cache.seq_len is now p + 5. let va = argmax_rows(&verify_logits); - // va[0] = target at P+1; va[1] = target at P+2 given ..d0_top1; - // va[2] = target at P+2 given ..d0_top2; va[3] = target at P+3 given ..d1. + // va[0]=target at P+1. va[1]=target at P+2 given ..d0_top1. + // va[2]=target at P+2 given ..d0_top2. va[3]=target at P+2 given ..d0_top3. + // va[4]=target at P+3 given ..d0_top1, d1. if d0_top1 == va[0] { // Top-1 path accepted at slot 0. @@ -716,16 +729,16 @@ fn run_eagle_tree( // Commit: pending_prev + d0_top1 + d1 = 3 tokens at [P, P+1, P+3]. // But K/V layout: [P]=pending_prev, [P+1]=d0_top1, [P+2]=d0_top2, [P+3]=d1. // We need d1 at cache position P+2 (not P+3). Remap: - cache.copy_kv_position(slot, p + 3, p + 2); + cache.copy_kv_position(slot, p + 4, p + 2); cache.truncate_sequence(slot, p + 3).unwrap(); generated.push(pending_prev); generated.push(d0_top1); generated.push(d1); - pending_prev = va[3]; - // seed_hooks = verify_hooks[3] (state that produced va[3]). + pending_prev = va[4]; + // seed_hooks = verify_hooks[4] (state that produced va[3]). let mut new_hooks: Vec = Vec::with_capacity(3); for h in verify_hooks.iter() { - new_hooks.push(h.narrow(0, 3, 1).contiguous()); + new_hooks.push(h.narrow(0, 4, 1).contiguous()); } seed_hooks = [ new_hooks.remove(0), @@ -771,8 +784,26 @@ fn run_eagle_tree( ]; // EAGLE cache: only step 0 was relevant (pending_prev). Truncate to 1. eagle.truncate_to(eagle_len_before + 1); + } else if d0_top3 == va[0] { + accepted_total += 1; + // d0_top3's K/V at physical slot p+3 -> canonical p+1. + cache.copy_kv_position(slot, p + 3, p + 1); + cache.truncate_sequence(slot, p + 2).unwrap(); + generated.push(pending_prev); + generated.push(d0_top3); + pending_prev = va[3]; + let mut new_hooks: Vec = Vec::with_capacity(3); + for h in verify_hooks.iter() { + new_hooks.push(h.narrow(0, 3, 1).contiguous()); + } + seed_hooks = [ + new_hooks.remove(0), + new_hooks.remove(0), + new_hooks.remove(0), + ]; + eagle.truncate_to(eagle_len_before + 1); } else { - // Both rejected. Commit only pending_prev. + // All rejected. Commit only pending_prev. cache.truncate_sequence(slot, p + 1).unwrap(); generated.push(pending_prev); pending_prev = va[0];