eagle3: tree drafting with top-2 siblings — speedup_e2e = 1.17× 🎉

Implements the full tree speculative drafting loop using the
copy_kv_position primitive from the previous commit.

Tree structure per round (4 verify tokens):
  [pending_prev, d0_top1, d0_top2, d1_chain_from_top1]
  positions:    [P,       P+1,    P+1,    P+2]
  tree_mask:    row0=[1000] row1=[1100] row2=[1010] row3=[1101]

Acceptance logic:
- d0_top1 matches target → check d1 chain → commit 2 or 3 tokens.
- d0_top2 matches target → copy_kv_position(P+2→P+1) + commit 2.
- Neither → commit pending_prev only.

50 prompts × 64 tokens on dash5 (Qwen3-8B + AngelSlim EAGLE3):
  acceptance_rate = 14.1% (vs 11.3% non-tree γ=2)
  target_steps = 2231 (vs 2432 non-tree)
  baseline_tpot_ms = 12.51, spec_tpot_ms = 10.68
  speedup_e2e = 1.17× (vs 1.10× non-tree)

The top-2 sibling adds ~3% absolute acceptance, which translates to
~7% additional speedup. The copy_kv_position cost is negligible (<6μs).

CLI: bench-eagle3 --tree enables the tree path.
This commit is contained in:
2026-07-02 00:09:30 +08:00
parent 6da0972740
commit aac9ace144

View File

@@ -151,13 +151,19 @@ fn main() {
baseline_tokens += baseline.ids.len();
drop(baseline_cache);
// Speculative with EAGLE, γ from CLI. Verify uses the tree kernel with
// a causal mask (equivalent to non-tree behavior); a real tree
// (siblings sharing target positions) would require KV cache slot
// remap after acceptance, which is out of scope for this iteration.
let _ = use_tree; // reserved for future tree drafting
// Speculative with EAGLE.
let mut target_cache = new_cache(&target_config, max_seq_len, device);
let spec = if gamma == 1 {
let spec = if use_tree {
run_eagle_tree(
&target,
&mut eagle,
&embed_tokens,
&mut target_cache,
&tokenizer,
&ids,
gen_tokens,
)
} else if gamma == 1 {
run_eagle_gamma1(
&target,
&mut eagle,
@@ -609,6 +615,200 @@ fn run_eagle_gamma_multi(
)
}
/// Tree drafting with top-2 siblings at slot 0 + chain from top-1.
///
/// Tree structure per round (4 verify tokens):
/// [pending_prev, d0_top1, d0_top2, d1_chain_from_top1]
/// positions: [P, P+1, P+1, P+2]
/// tree_mask: [[1,0,0,0],[1,1,0,0],[1,0,1,0],[1,1,0,1]]
///
/// Acceptance:
/// - d0_top1 matches target → check d1 → commit 2 or 3 tokens.
/// - d0_top2 matches target → copy_kv_position(P+2 → P+1) + commit 2 tokens.
/// - Neither → commit pending_prev only (1 token).
#[allow(clippy::too_many_arguments)]
fn run_eagle_tree(
target: &Qwen3,
eagle: &mut Eagle3Head,
embed_tokens: &Tensor,
cache: &mut PagedKVCache,
tokenizer: &Tokenizer,
prompt_ids: &[u32],
gen_tokens: usize,
) -> RunStats {
use xserv_model::eagle3::EAGLE_HOOK_LAYERS;
let slot = 0;
cache.register_sequence(slot).unwrap();
eagle.reset();
let t0 = Instant::now();
// Prefill + seed decode for initial hooks.
let prefill_logits = target.forward_prefill_paged(prompt_ids, slot, cache);
let first_token = last_argmax(&prefill_logits);
let (seed_logits, seed_hooks_v) = target_decode_with_hidden(target, first_token, cache, slot);
let mut generated = vec![first_token];
let mut pending_prev = last_argmax(&seed_logits);
let mut seed_hooks: [Tensor; 3] = seed_hooks_v;
let mut target_steps = 2usize;
let mut accepted_total = 0usize;
let mut proposed_total = 0usize;
while generated.len() + 1 < gen_tokens && !tokenizer.is_eos(pending_prev) {
let p = cache.seq_len(slot);
if generated.len() + 4 > gen_tokens {
// Not enough room for tree; fall back to committing pending_prev only.
generated.push(pending_prev);
break;
}
let eagle_len_before = eagle.current_len();
// Draft: EAGLE step 0 → top-2 candidates for position P+1.
let (d0_top1, l0, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p);
let top2 = top_k_target_ids(&l0, 2, eagle);
let d0_top2 = if top2[0] == d0_top1 { top2[1] } else { top2[0] };
// Draft: EAGLE step 1 → chain from d0_top1 for position P+2.
let (d1, _, _) = eagle.step_recursive(aux0, embed_tokens, d0_top1, p + 1);
proposed_total += 3; // 3 candidates proposed (d0_top1, d0_top2, d1)
// Verify with tree: 4 tokens.
let verify_input = vec![pending_prev, d0_top1, d0_top2, d1];
let positions_v: Vec<u32> = vec![p as u32, (p + 1) as u32, (p + 1) as u32, (p + 2) as u32];
let kv_lens_v: Vec<i32> = vec![
(p + 1) as i32,
(p + 2) as i32,
(p + 3) as i32,
(p + 4) as i32,
];
#[rustfmt::skip]
let tree_mask: Vec<i32> = vec![
1, 0, 0, 0,
1, 1, 0, 0,
1, 0, 1, 0,
1, 1, 0, 1,
];
let (verify_logits, verify_hooks) = target
.forward_verify_paged_decode_attention_tree_with_hidden(
&verify_input,
&positions_v,
&kv_lens_v,
&tree_mask,
slot,
cache,
&EAGLE_HOOK_LAYERS,
);
target_steps += 1;
// cache.seq_len is now p + 4.
let va = argmax_rows(&verify_logits);
// va[0] = target at P+1; va[1] = target at P+2 given ..d0_top1;
// va[2] = target at P+2 given ..d0_top2; va[3] = target at P+3 given ..d1.
if d0_top1 == va[0] {
// Top-1 path accepted at slot 0.
accepted_total += 1;
if d1 == va[1] {
// Chain also accepted.
accepted_total += 1;
// Commit: pending_prev + d0_top1 + d1 = 3 tokens at [P, P+1, P+3].
// But K/V layout: [P]=pending_prev, [P+1]=d0_top1, [P+2]=d0_top2, [P+3]=d1.
// We need d1 at cache position P+2 (not P+3). Remap:
cache.copy_kv_position(slot, p + 3, p + 2);
cache.truncate_sequence(slot, p + 3).unwrap();
generated.push(pending_prev);
generated.push(d0_top1);
generated.push(d1);
pending_prev = va[3];
// seed_hooks = verify_hooks[3] (state that produced va[3]).
let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3);
for h in verify_hooks.iter() {
new_hooks.push(h.narrow(0, 3, 1).contiguous());
}
seed_hooks = [
new_hooks.remove(0),
new_hooks.remove(0),
new_hooks.remove(0),
];
eagle.truncate_to(eagle_len_before + 2); // keep pending_prev + d0_top1
} else {
// Only d0_top1 accepted, d1 rejected.
// K/V [P]=pp, [P+1]=d0_top1 — correct. Drop [P+2], [P+3].
cache.truncate_sequence(slot, p + 2).unwrap();
generated.push(pending_prev);
generated.push(d0_top1);
pending_prev = va[1];
let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3);
for h in verify_hooks.iter() {
new_hooks.push(h.narrow(0, 1, 1).contiguous());
}
seed_hooks = [
new_hooks.remove(0),
new_hooks.remove(0),
new_hooks.remove(0),
];
eagle.truncate_to(eagle_len_before + 2);
}
} else if d0_top2 == va[0] {
// Top-2 sibling accepted! K/V for d0_top2 is at cache position P+2.
// Copy it to P+1 (canonical position for the accepted token).
accepted_total += 1;
cache.copy_kv_position(slot, p + 2, p + 1);
cache.truncate_sequence(slot, p + 2).unwrap();
generated.push(pending_prev);
generated.push(d0_top2);
pending_prev = va[2];
let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3);
for h in verify_hooks.iter() {
new_hooks.push(h.narrow(0, 2, 1).contiguous());
}
seed_hooks = [
new_hooks.remove(0),
new_hooks.remove(0),
new_hooks.remove(0),
];
// EAGLE cache: only step 0 was relevant (pending_prev). Truncate to 1.
eagle.truncate_to(eagle_len_before + 1);
} else {
// Both rejected. Commit only pending_prev.
cache.truncate_sequence(slot, p + 1).unwrap();
generated.push(pending_prev);
pending_prev = va[0];
let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3);
for h in verify_hooks.iter() {
new_hooks.push(h.narrow(0, 0, 1).contiguous());
}
seed_hooks = [
new_hooks.remove(0),
new_hooks.remove(0),
new_hooks.remove(0),
];
eagle.truncate_to(eagle_len_before + 1);
}
if generated.len() >= gen_tokens || tokenizer.is_eos(*generated.last().unwrap()) {
break;
}
}
// Commit remaining pending_prev if needed.
if generated.len() < gen_tokens {
generated.push(pending_prev);
}
finalize(
generated,
t0,
target_steps,
accepted_total,
proposed_total,
cache,
slot,
)
}
fn finalize(
generated: Vec<u32>,
t0: Instant,