eagle3: tree drafting with top-2 siblings — speedup_e2e = 1.17× 🎉
Implements the full tree speculative drafting loop using the
copy_kv_position primitive from the previous commit.

Tree structure per round (4 verify tokens):
  [pending_prev, d0_top1, d0_top2, d1_chain_from_top1]
  positions: [P, P+1, P+1, P+2]
  tree_mask: row0=[1000] row1=[1100] row2=[1010] row3=[1101]

Acceptance logic:
  - d0_top1 matches target → check d1 chain → commit 2 or 3 tokens.
  - d0_top2 matches target → copy_kv_position(P+2→P+1) + commit 2.
  - Neither → commit pending_prev only.

50 prompts × 64 tokens on dash5 (Qwen3-8B + AngelSlim EAGLE3):
  acceptance_rate  = 14.1% (vs 11.3% non-tree γ=2)
  target_steps     = 2231 (vs 2432 non-tree)
  baseline_tpot_ms = 12.51, spec_tpot_ms = 10.68
  speedup_e2e      = 1.17× (vs 1.10× non-tree)

The top-2 sibling adds ~3% absolute acceptance, which translates to ~7%
additional speedup. The copy_kv_position cost is negligible (<6μs).

CLI: bench-eagle3 --tree enables the tree path.
This commit is contained in:
@@ -151,13 +151,19 @@ fn main() {
|
||||
baseline_tokens += baseline.ids.len();
|
||||
drop(baseline_cache);
|
||||
|
||||
// Speculative with EAGLE, γ from CLI. Verify uses the tree kernel with
|
||||
// a causal mask (equivalent to non-tree behavior); a real tree
|
||||
// (siblings sharing target positions) would require KV cache slot
|
||||
// remap after acceptance, which is out of scope for this iteration.
|
||||
let _ = use_tree; // reserved for future tree drafting
|
||||
// Speculative with EAGLE.
|
||||
let mut target_cache = new_cache(&target_config, max_seq_len, device);
|
||||
let spec = if gamma == 1 {
|
||||
let spec = if use_tree {
|
||||
run_eagle_tree(
|
||||
&target,
|
||||
&mut eagle,
|
||||
&embed_tokens,
|
||||
&mut target_cache,
|
||||
&tokenizer,
|
||||
&ids,
|
||||
gen_tokens,
|
||||
)
|
||||
} else if gamma == 1 {
|
||||
run_eagle_gamma1(
|
||||
&target,
|
||||
&mut eagle,
|
||||
@@ -609,6 +615,200 @@ fn run_eagle_gamma_multi(
|
||||
)
|
||||
}
|
||||
|
||||
/// Tree drafting with top-2 siblings at slot 0 + chain from top-1.
|
||||
///
|
||||
/// Tree structure per round (4 verify tokens):
|
||||
/// [pending_prev, d0_top1, d0_top2, d1_chain_from_top1]
|
||||
/// positions: [P, P+1, P+1, P+2]
|
||||
/// tree_mask: [[1,0,0,0],[1,1,0,0],[1,0,1,0],[1,1,0,1]]
|
||||
///
|
||||
/// Acceptance:
|
||||
/// - d0_top1 matches target → check d1 → commit 2 or 3 tokens.
|
||||
/// - d0_top2 matches target → copy_kv_position(P+2 → P+1) + commit 2 tokens.
|
||||
/// - Neither → commit pending_prev only (1 token).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_eagle_tree(
|
||||
target: &Qwen3,
|
||||
eagle: &mut Eagle3Head,
|
||||
embed_tokens: &Tensor,
|
||||
cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
prompt_ids: &[u32],
|
||||
gen_tokens: usize,
|
||||
) -> RunStats {
|
||||
use xserv_model::eagle3::EAGLE_HOOK_LAYERS;
|
||||
let slot = 0;
|
||||
cache.register_sequence(slot).unwrap();
|
||||
eagle.reset();
|
||||
let t0 = Instant::now();
|
||||
|
||||
// Prefill + seed decode for initial hooks.
|
||||
let prefill_logits = target.forward_prefill_paged(prompt_ids, slot, cache);
|
||||
let first_token = last_argmax(&prefill_logits);
|
||||
let (seed_logits, seed_hooks_v) = target_decode_with_hidden(target, first_token, cache, slot);
|
||||
let mut generated = vec![first_token];
|
||||
let mut pending_prev = last_argmax(&seed_logits);
|
||||
let mut seed_hooks: [Tensor; 3] = seed_hooks_v;
|
||||
let mut target_steps = 2usize;
|
||||
let mut accepted_total = 0usize;
|
||||
let mut proposed_total = 0usize;
|
||||
|
||||
while generated.len() + 1 < gen_tokens && !tokenizer.is_eos(pending_prev) {
|
||||
let p = cache.seq_len(slot);
|
||||
if generated.len() + 4 > gen_tokens {
|
||||
// Not enough room for tree; fall back to committing pending_prev only.
|
||||
generated.push(pending_prev);
|
||||
break;
|
||||
}
|
||||
|
||||
let eagle_len_before = eagle.current_len();
|
||||
|
||||
// Draft: EAGLE step 0 → top-2 candidates for position P+1.
|
||||
let (d0_top1, l0, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, pending_prev, p);
|
||||
let top2 = top_k_target_ids(&l0, 2, eagle);
|
||||
let d0_top2 = if top2[0] == d0_top1 { top2[1] } else { top2[0] };
|
||||
|
||||
// Draft: EAGLE step 1 → chain from d0_top1 for position P+2.
|
||||
let (d1, _, _) = eagle.step_recursive(aux0, embed_tokens, d0_top1, p + 1);
|
||||
|
||||
proposed_total += 3; // 3 candidates proposed (d0_top1, d0_top2, d1)
|
||||
|
||||
// Verify with tree: 4 tokens.
|
||||
let verify_input = vec![pending_prev, d0_top1, d0_top2, d1];
|
||||
let positions_v: Vec<u32> = vec![p as u32, (p + 1) as u32, (p + 1) as u32, (p + 2) as u32];
|
||||
let kv_lens_v: Vec<i32> = vec![
|
||||
(p + 1) as i32,
|
||||
(p + 2) as i32,
|
||||
(p + 3) as i32,
|
||||
(p + 4) as i32,
|
||||
];
|
||||
#[rustfmt::skip]
|
||||
let tree_mask: Vec<i32> = vec![
|
||||
1, 0, 0, 0,
|
||||
1, 1, 0, 0,
|
||||
1, 0, 1, 0,
|
||||
1, 1, 0, 1,
|
||||
];
|
||||
|
||||
let (verify_logits, verify_hooks) = target
|
||||
.forward_verify_paged_decode_attention_tree_with_hidden(
|
||||
&verify_input,
|
||||
&positions_v,
|
||||
&kv_lens_v,
|
||||
&tree_mask,
|
||||
slot,
|
||||
cache,
|
||||
&EAGLE_HOOK_LAYERS,
|
||||
);
|
||||
target_steps += 1;
|
||||
// cache.seq_len is now p + 4.
|
||||
|
||||
let va = argmax_rows(&verify_logits);
|
||||
// va[0] = target at P+1; va[1] = target at P+2 given ..d0_top1;
|
||||
// va[2] = target at P+2 given ..d0_top2; va[3] = target at P+3 given ..d1.
|
||||
|
||||
if d0_top1 == va[0] {
|
||||
// Top-1 path accepted at slot 0.
|
||||
accepted_total += 1;
|
||||
if d1 == va[1] {
|
||||
// Chain also accepted.
|
||||
accepted_total += 1;
|
||||
// Commit: pending_prev + d0_top1 + d1 = 3 tokens at [P, P+1, P+3].
|
||||
// But K/V layout: [P]=pending_prev, [P+1]=d0_top1, [P+2]=d0_top2, [P+3]=d1.
|
||||
// We need d1 at cache position P+2 (not P+3). Remap:
|
||||
cache.copy_kv_position(slot, p + 3, p + 2);
|
||||
cache.truncate_sequence(slot, p + 3).unwrap();
|
||||
generated.push(pending_prev);
|
||||
generated.push(d0_top1);
|
||||
generated.push(d1);
|
||||
pending_prev = va[3];
|
||||
// seed_hooks = verify_hooks[3] (state that produced va[3]).
|
||||
let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3);
|
||||
for h in verify_hooks.iter() {
|
||||
new_hooks.push(h.narrow(0, 3, 1).contiguous());
|
||||
}
|
||||
seed_hooks = [
|
||||
new_hooks.remove(0),
|
||||
new_hooks.remove(0),
|
||||
new_hooks.remove(0),
|
||||
];
|
||||
eagle.truncate_to(eagle_len_before + 2); // keep pending_prev + d0_top1
|
||||
} else {
|
||||
// Only d0_top1 accepted, d1 rejected.
|
||||
// K/V [P]=pp, [P+1]=d0_top1 — correct. Drop [P+2], [P+3].
|
||||
cache.truncate_sequence(slot, p + 2).unwrap();
|
||||
generated.push(pending_prev);
|
||||
generated.push(d0_top1);
|
||||
pending_prev = va[1];
|
||||
let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3);
|
||||
for h in verify_hooks.iter() {
|
||||
new_hooks.push(h.narrow(0, 1, 1).contiguous());
|
||||
}
|
||||
seed_hooks = [
|
||||
new_hooks.remove(0),
|
||||
new_hooks.remove(0),
|
||||
new_hooks.remove(0),
|
||||
];
|
||||
eagle.truncate_to(eagle_len_before + 2);
|
||||
}
|
||||
} else if d0_top2 == va[0] {
|
||||
// Top-2 sibling accepted! K/V for d0_top2 is at cache position P+2.
|
||||
// Copy it to P+1 (canonical position for the accepted token).
|
||||
accepted_total += 1;
|
||||
cache.copy_kv_position(slot, p + 2, p + 1);
|
||||
cache.truncate_sequence(slot, p + 2).unwrap();
|
||||
generated.push(pending_prev);
|
||||
generated.push(d0_top2);
|
||||
pending_prev = va[2];
|
||||
let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3);
|
||||
for h in verify_hooks.iter() {
|
||||
new_hooks.push(h.narrow(0, 2, 1).contiguous());
|
||||
}
|
||||
seed_hooks = [
|
||||
new_hooks.remove(0),
|
||||
new_hooks.remove(0),
|
||||
new_hooks.remove(0),
|
||||
];
|
||||
// EAGLE cache: only step 0 was relevant (pending_prev). Truncate to 1.
|
||||
eagle.truncate_to(eagle_len_before + 1);
|
||||
} else {
|
||||
// Both rejected. Commit only pending_prev.
|
||||
cache.truncate_sequence(slot, p + 1).unwrap();
|
||||
generated.push(pending_prev);
|
||||
pending_prev = va[0];
|
||||
let mut new_hooks: Vec<Tensor> = Vec::with_capacity(3);
|
||||
for h in verify_hooks.iter() {
|
||||
new_hooks.push(h.narrow(0, 0, 1).contiguous());
|
||||
}
|
||||
seed_hooks = [
|
||||
new_hooks.remove(0),
|
||||
new_hooks.remove(0),
|
||||
new_hooks.remove(0),
|
||||
];
|
||||
eagle.truncate_to(eagle_len_before + 1);
|
||||
}
|
||||
|
||||
if generated.len() >= gen_tokens || tokenizer.is_eos(*generated.last().unwrap()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Commit remaining pending_prev if needed.
|
||||
if generated.len() < gen_tokens {
|
||||
generated.push(pending_prev);
|
||||
}
|
||||
|
||||
finalize(
|
||||
generated,
|
||||
t0,
|
||||
target_steps,
|
||||
accepted_total,
|
||||
proposed_total,
|
||||
cache,
|
||||
slot,
|
||||
)
|
||||
}
|
||||
|
||||
fn finalize(
|
||||
generated: Vec<u32>,
|
||||
t0: Instant,
|
||||
|
||||
Reference in New Issue
Block a user