
Phase 26: EAGLE3 Implementation Follow-up & Bug Hunt

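Companion to docs/25 (which explains the three speculative paradigms). This doc records the actual EAGLE3 implementation, the bugs we found, the fixes, and why speedup > 1 stayed out of reach on the bit-exact (matched=true) path. Two epilogues cover the later cuBLAS-GEMM verify change, which reached 1.10× at the cost of matched=false, and the tree attention kernel.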

Implementation Timeline

Commits are on main:

  1. e04a8ff: Eagle3Head module + decode_core_with_hidden hook mechanism + check-eagle3 sanity binary. Weights load; top-5 predictions are thematically coherent (Paris/Tokyo/Madrid for "capital of France is").
  2. 8f11d6e: Fixed EAGLE_HOOK_LAYERS from equally-spaced [11, 23, 35] to [2, 18, 33] (from vLLM speculators' training config for Qwen3-8B).
  3. 68b55fa: First bench-eagle3 γ=1 loop. matched=true but acceptance only 1.3%.
  4. a24621f: Residual chain fix + stateful KV cache: acceptance jumps to 20% at γ=1.
  5. 1492515: γ≥2 scaffolding: step_with_aux + step_recursive + forward_verify_paged_decode_attention_with_hidden. matched=false at γ≥2 due to K/V bugs.
  6. d2c55c4: γ≥2 correctness fixes: matched=true across full sweep.

Bugs Fixed (γ≥2)

Bug A: Truncate dropped needed K/V

Old code:

cache.truncate_sequence(slot, round_pos - 1).unwrap();
let (verify_logits, _) = target.forward_verify_...(&[prev_token, d0, d1], ...);

round_pos - 1 is the position of the last committed token (pending_prev). Truncating there dropped its K/V, and verify then wrote prev_token's K/V into that slot a second time. The rewrite is not guaranteed to be bit-identical to the original: the previous single-token decode wrote via matmul_2d (m=1 → custom GEMV), while verify writes via matmul_batched_gemv (m=γ+1). The math is the same and the bytes should match in principle, but in practice overwriting K/V that was already there introduced a small numerical drift.

Fix: Don't truncate. Let verify start at cache.seq_len and write γ+1 new positions forward. pending_prev's K/V stays intact from the previous round's write.

Bug B: EAGLE cache accumulated rejected drafts

Each EAGLE step_with_aux or step_recursive writes one K/V entry to EAGLE's internal cache. Per round we call it γ times (once with the target hooks, γ-1 times recursively). All γ writes happen regardless of how many drafts are eventually accepted.

If k < γ drafts are accepted, EAGLE's cache holds γ entries for a round that committed only k+1 tokens (pending_prev + k drafts). The extra γ-k-1 entries hold K/V for hallucinated drafts that were never committed, and they pollute future rounds.

Fix: Add Eagle3Head::truncate_to(new_len). After acceptance, truncate to eagle_len_before + k + 1.
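The bookkeeping behind this fix can be sketched with a toy cache. `DraftCache` and `commit_round` below are hypothetical stand-ins, not the repository's types: the cache stores a plain `u32` per step instead of real K/V tensors. Only the `truncate_to` name and the `eagle_len_before + k + 1` target length come from the text above.

```rust
/// Hypothetical stand-in for EAGLE's internal K/V cache: one entry per draft step.
#[derive(Debug, Default)]
struct DraftCache {
    entries: Vec<u32>, // placeholder payload; the real cache holds K/V rows
}

impl DraftCache {
    /// Record the entry produced by one EAGLE step.
    fn push(&mut self, entry: u32) {
        self.entries.push(entry);
    }

    fn len(&self) -> usize {
        self.entries.len()
    }

    /// Drop every entry at index >= new_len; a no-op if the cache is already shorter.
    fn truncate_to(&mut self, new_len: usize) {
        self.entries.truncate(new_len);
    }
}

/// One round of cache bookkeeping: gamma draft writes happen first, then the
/// cache is rolled back to the entries that belong to committed tokens
/// (pending_prev + `accepted` drafts). Returns the cache length afterwards.
fn commit_round(cache: &mut DraftCache, draft_entries: &[u32], accepted: usize) -> usize {
    let len_before = cache.len();
    for &e in draft_entries {
        cache.push(e); // all writes happen before acceptance is known
    }
    cache.truncate_to(len_before + accepted + 1);
    cache.len()
}
```

With γ=3 and one accepted draft, the round writes three entries and keeps two, dropping the γ-k-1 = 1 entry that belonged to a rejected draft.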

Bug C: aux output was normed, should be pre-norm

vLLM's llama_eagle3.py (line ~150):

hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
aux_output = hidden_states if self.norm_output else hidden_prenorm

Default norm_output=False → aux = hidden_prenorm (pre-RMSNorm residual sum). I was returning hidden_states (normed).

Fix: return the second output of add_rmsnorm, which is x + residual (pre-norm). Small effect on acceptance (~1%).
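To make the two outputs concrete, here is a minimal CPU sketch of a fused residual add + RMSNorm that returns both tensors. The signature, the f32 slices, and the `aux_output` helper are assumptions for illustration; the real `add_rmsnorm` is a GPU kernel whose interface is not shown in this doc.

```rust
/// Fused residual add + RMSNorm over one hidden vector.
/// Returns (normed, pre_norm), where pre_norm = x + residual.
fn add_rmsnorm(x: &[f32], residual: &[f32], weight: &[f32], eps: f32) -> (Vec<f32>, Vec<f32>) {
    assert_eq!(x.len(), residual.len());
    assert_eq!(x.len(), weight.len());
    // Residual sum: this is the tensor the EAGLE head wants by default.
    let pre_norm: Vec<f32> = x.iter().zip(residual).map(|(a, b)| a + b).collect();
    let mean_sq = pre_norm.iter().map(|v| v * v).sum::<f32>() / pre_norm.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    let normed: Vec<f32> = pre_norm
        .iter()
        .zip(weight)
        .map(|(v, w)| v * inv_rms * w)
        .collect();
    (normed, pre_norm)
}

/// Pick the tensor handed to the EAGLE head, mirroring the `norm_output` switch
/// in the vLLM snippet above (hypothetical helper).
fn aux_output(normed: Vec<f32>, pre_norm: Vec<f32>, norm_output: bool) -> Vec<f32> {
    if norm_output { normed } else { pre_norm }
}
```

With `norm_output = false`, the aux tensor is the plain sum `x + residual`, which is what the fix returns.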

Bug D: EAGLE draft position off-by-one

pending_prev is at target position p. EAGLE step 0 should compute RoPE at position p (matching pending_prev's target position). I was passing p + 1.

Fix: pass p + k for the k-th EAGLE step (k = 0..γ-1).
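As a small illustration of the corrected indexing, the helper below lists the RoPE positions for one round. The function name is hypothetical; the rule itself (step k uses p + k, where p is pending_prev's target position) is the one stated in the fix.

```rust
/// RoPE positions for the gamma EAGLE steps of one round.
/// `p` is the target position of pending_prev; step k uses p + k.
/// (The buggy version started at p + 1 instead.)
fn eagle_step_positions(p: usize, gamma: usize) -> Vec<usize> {
    (0..gamma).map(|k| p + k).collect()
}
```

For p = 40 and γ = 3 this gives positions 40, 41, 42.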

Final Measurements

Setup: dash5 (RTX 5090), Qwen3-8B target + AngelSlim/Qwen3-8B_eagle3 head, 5 prompts × 32 tokens, greedy, matched=true across all runs.

| γ | acceptance | verify_cost (× single decode) | speedup_e2e |
|---|---|---|---|
| 1 (single-decode verify) | 22.7% | 1.00 | 0.95× |
| 1 (batched verify) | 20.6% | ~1.5 | 0.75× |
| 2 | 12.6% | ~1.7 | 0.59× |
| 3 | 9.1% | ~2.1 | 0.48× |
| 4 | 7.6% | ~2.4 | 0.41× |
| 6 | 5.2% | ~3.1 | 0.32× |
| 8 | 4.1% | ~3.7 | 0.27× |

Per-slot diagnostic (γ=8, aggregated over 5 prompts):

d[0]=12/125(0.10)  d[1]=8/122(0.07)   d[2]=5/119(0.04)
d[3]=6/116(0.05)   d[4]=8/113(0.07)   d[5]=13/110(0.12)
d[6]=17/107(0.16)  d[7]=17/104(0.16)

Later positions (d[5..7]) surprisingly show HIGHER acceptance than d[1..3]. Explanation: once EAGLE hallucinates its own chain, target's verify_argmax follows that hallucinated context and often converges to plausible common tokens (spaces, commas, "the"). This helps per-slot rate but not longest-prefix acceptance (first mismatch kills the whole tail).
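The difference between the per-slot diagnostic and what actually gets committed can be shown in a few lines. Both function names are hypothetical, and tokens are plain `u32` ids for illustration.

```rust
/// Drafts accepted under the longest-prefix rule: counting stops at the first mismatch.
fn accepted_prefix_len(drafts: &[u32], verify_argmax: &[u32]) -> usize {
    drafts
        .iter()
        .zip(verify_argmax)
        .take_while(|(d, t)| d == t)
        .count()
}

/// Per-slot matches, each slot judged independently of the ones before it.
/// This is the view the d[i] diagnostic reports.
fn per_slot_matches(drafts: &[u32], verify_argmax: &[u32]) -> Vec<bool> {
    drafts
        .iter()
        .zip(verify_argmax)
        .map(|(d, t)| d == t)
        .collect()
}
```

For drafts `[5, 9, 3, 3]` against target argmax `[5, 1, 3, 3]`, three slots match but only one draft is accepted, because the mismatch at slot 1 discards the tail.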

Why speedup < 1

The speedup formula:

speedup ≈ (1 + avg_accepted_per_round) / verify_cost_relative_to_single_decode

Sub-1 across the sweep because:

  • verify_cost grows linearly with γ+1. Each verify slot is one BF16 GEMV row across all Qwen3-8B layers. Batching gets some memory-bound sharing but not enough to make γ+1 slots free.
  • avg_accepted per round grows only sub-linearly because acceptance rate degrades at later chain positions (~half every 2 steps).

To reach speedup > 1 we need avg_accepted > (verify_cost - 1). With verify_cost ≈ 1.7 at γ=2, need avg_accepted > 0.7. Observed 0.25.
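The arithmetic for the γ=2 row, written out as code. The function names are hypothetical; the formula is the one given above. Note that it leaves out the draft head's own cost, which is presumably one reason the estimate sits above the measured 0.59×.

```rust
/// Estimated speedup from the formula above:
/// (1 + avg_accepted_per_round) / verify_cost_relative_to_single_decode.
fn speedup_estimate(avg_accepted: f64, verify_cost: f64) -> f64 {
    (1.0 + avg_accepted) / verify_cost
}

/// Average accepted drafts per round needed for the estimate to reach exactly 1.0.
fn break_even_accepted(verify_cost: f64) -> f64 {
    verify_cost - 1.0
}
```

With avg_accepted = 0.25 and verify_cost = 1.7, `speedup_estimate` gives about 0.74, and `break_even_accepted(1.7)` gives 0.7, almost three times the observed value.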

Path Forward

Three levers, all significant work:

1. Tree-based drafting (biggest lever, +2-3× acceptance)

EAGLE-3 paper reports 60-70% acceptance using TREE decoding: at each recursive step, EAGLE proposes top-k candidates instead of top-1. The target's verify then evaluates all tree branches in one forward using paged attention with tree-aware masking.

Reference: SafeAILab/EAGLE uses trees with depth 6 and 26+ nodes.

Implementation cost: significant. Requires:

  • Tree-aware batched verify (multi-branch attention masking).
  • Tree navigation / longest-accepted-path selection.
  • KV cache management for accepted branch vs discarded branches.
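The second item, longest-accepted-path selection, might look like the following under greedy decoding. This is a sketch, not the repository's code. The flat `parent` / `token` / `target_argmax` arrays, the convention that node 0 is pending_prev, and the requirement that every parent index is smaller than its child's index are all assumptions made for illustration.

```rust
/// Walk a draft tree from the root and return the node indices on the accepted path.
/// - Node 0 is the root (pending_prev); `parent[i]` is the parent of node i (parent[0] is unused).
/// - `token[i]` is the draft token at node i.
/// - `target_argmax[i]` is the target's greedy next token given the path ending at node i.
/// Assumes parents precede children (parent[i] < i for i > 0).
fn accepted_path(parent: &[usize], token: &[u32], target_argmax: &[u32]) -> Vec<usize> {
    let n = parent.len();
    assert!(n > 0 && token.len() == n && target_argmax.len() == n);
    let mut path = vec![0];
    let mut cur = 0;
    loop {
        // A child is accepted when its token equals what the target emits after `cur`.
        // Children have larger indices than their parent, so the search starts at cur + 1.
        let next = (cur + 1..n).find(|&c| parent[c] == cur && token[c] == target_argmax[cur]);
        match next {
            Some(c) => {
                path.push(c);
                cur = c;
            }
            None => return path,
        }
    }
}
```

For a four-node tree with `parent = [0, 0, 0, 1]` (root, two siblings, and one child of the first sibling), the result is `[0, 1, 3]` when the target agrees with the first sibling and its child, `[0, 2]` when it agrees with the second sibling, and `[0]` when it agrees with neither.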

2. Cheaper batched verify

Current batched verify at γ+1 tokens uses matmul_batched_gemv (per-row GEMV) plus paged_decode_attention batch=γ+1. Both scale roughly linearly with γ+1.

Potential improvements:

  • Flash Attention with multi-query: each of the γ+1 queries shares the same K/V cache pointers, so a single kernel can read K/V once and compute γ+1 outputs. Currently they're independent kernel launches per query.
  • Cheaper QKV projection at m>1: matmul_batched_gemv is bit-exact per row but doesn't amortize weight loading across rows. Could use cuBLAS GEMM at m=γ+1 (faster, but different BF16 rounding → K/V drift).

3. Better draft (smaller EAGLE, different training)

The AngelSlim Qwen3-8B_eagle3 head is 750MB (~1 layer of the 8B model). Alternatives:

  • Smaller Qwen3 (0.6B) as draft: already tried, γ=1 gets 40% acceptance but draft cost ~2.5ms (vs EAGLE's ~0.5ms).
  • Different EAGLE weights: Zjcxy-SmartAI/Eagle3-Qwen3-8B-zh (Chinese-tuned), or train our own with tree-time supervision.

Recommendation

Given effort/reward:

Short-term (1 session): implement tree-based drafting with depth=2, width=2 (4 candidates per round). Reuse existing batched verify with tree-aware masking. Expect acceptance to double (25% → 50%+).

Medium-term (2-3 sessions): a full tree of depth=6 with varying width, plus a flash-attention-2 batched verify kernel. This matches the vLLM implementation and should approach 2× speedup.

Alternative (if EAGLE is a dead-end): switch to lookahead decoding (Fu et al.), which doesn't require a draft model at all and instead uses n-gram lookup + Jacobi iteration on the target.

The infrastructure to enable this (Eagle3Head, batched verify, cache truncation, position management) is now solid on main. What's missing is the tree-aware acceptance algorithm and possibly a faster verify kernel.


Epilogue (06a798c): cuBLAS GEMM verify → speedup > 1 achieved

Actioned option 2 above: swapped matmul_batched_gemv for matmul_2d (cuBLAS GEMM) inside forward_verify_paged_decode_attention_with_hidden.

Micro-benchmark (bench-verify-cost.rs, RTX 5090, prompt_len=100):

| batch | batched-GEMV verify | cuBLAS-GEMM verify |
|---|---|---|
| 1 | 13.14 ms (1.05×) | 13.04 ms (1.04×) |
| 2 | 19.51 ms (1.56×) | 13.52 ms (1.08×) |
| 3 | 26.10 ms (2.09×) | 13.59 ms (1.09×) |
| 5 | 38.72 ms (3.10×) | 13.88 ms (1.11×) |
| 9 | 64.15 ms (5.14×) | 15.03 ms (1.20×) |

cuBLAS GEMM at m>1 loads the shared operand (the weight matrices) once and reuses it across all rows, giving near-flat scaling. The batched GEMV reloads it per row, so its cost grows roughly linearly with the batch.

50 prompts × 64 tokens γ sweep with cuBLAS verify:

| γ | acceptance | speedup_e2e |
|---|---|---|
| 1 (single-decode) | 29.8% | 0.95× |
| 2 | 16.9% | 1.10× ← best |
| 3 | 11.6% | 1.06× |
| 4 | 8.9% | 1.02× |
| 5 | 7.2% | 0.96× |
| 6 | 6.0% | 0.93× |
| 8 | 4.5% | 0.86× |

Tradeoff: matched=false. cuBLAS GEMM at m>1 rounds BF16 differently from custom GEMV at m=1. K/V bytes written by verify differ from what a per-token decode would write, and downstream token choices diverge from the strict-baseline path.

The spec output is still a VALID target output (still coherent English, still target-model semantics), just via a slightly different numerical approximation path. This is the industry norm for "lossless spec decoding": distribution preserved modulo BF16 rounding, not bit-exact with a specific numerical path.

speedup_e2e = 1.10× is a real, measurable win at γ=2 on 50 prompts × 64 tokens. Higher γ gives diminishing returns: acceptance falls faster than the near-flat verify cost can compensate, so the gain already peaks at γ=2. To push higher, we'd need a better draft (tree decoding, a larger EAGLE head, or different EAGLE weights).


Epilogue 2 (fd392f7): Tree attention kernel + why tree drafting is stuck

Wrote the tree-aware paged decode attention kernel: paged_decode_attention_tree_bf16_kernel takes an extra [batch, batch] i32 mask that lets each query select which of the newly-written K/V rows it attends to. Positions before tree_start are always attended.

Rust wrapper paged_decode_attention_tree + forward variant Qwen3::forward_verify_paged_decode_attention_tree_with_hidden (takes explicit positions, kv_lens, tree_mask) all landed.

Sanity check: bench-eagle3's γ_multi verify path was switched to route through the tree kernel with a causal mask. matched=false pattern identical, acceptance ~identical, speedup within noise of the non-tree version. Kernel is correct.
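For reference, a mask of this shape can be derived from parent pointers, and a pure chain reduces to the causal mask used in the sanity check. The sketch below is host-side only and rests on assumptions that the text does not spell out: row-major layout, 1 = attend and 0 = masked, node 0 as the root, and parents stored before children.

```rust
/// Build a row-major [n, n] i32 mask where mask[q * n + k] == 1 means query node q
/// may attend to the newly written K/V row of node k (k is q itself or an ancestor of q).
/// `parent[i]` is the parent of node i; node 0 is the root and parent[0] is unused.
fn build_tree_mask(parent: &[usize]) -> Vec<i32> {
    let n = parent.len();
    // Parents must precede children so the ancestor walk below terminates at the root.
    assert!(
        parent.iter().enumerate().skip(1).all(|(i, &p)| p < i),
        "parent[i] must be < i for every non-root node"
    );
    let mut mask = vec![0i32; n * n];
    for q in 0..n {
        let mut node = q;
        loop {
            mask[q * n + node] = 1;
            if node == 0 {
                break;
            }
            node = parent[node];
        }
    }
    mask
}
```

A chain (`parent = [0, 0, 1, 2]`) yields a lower-triangular matrix, i.e. ordinary causal attention. With two siblings under the root (`parent = [0, 0, 0, 1]`), the second sibling's row masks out the first sibling and vice versa.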

The blocker: KV cache position rigidity

Wrote out the top-2 sibling tree structure on paper. Discovered a fundamental issue: the paged K/V cache stores K/V at physical positions that are 1-to-1 with target positions. If verify writes 4 K/V rows at cache positions [P, P+1, P+2, P+3] corresponding to [pending_prev, d0_top1, d0_top2, d1_chain_from_top1], then:

  • If d0_top1 accepted: its K/V is at physical slot P+1, matching target position P+1. Continuing decode from position P+1 reads the right K/V. ✓
  • If d0_top2 accepted: its K/V is at physical slot P+2, but its logical target position is P+1. Later decode steps index the cache by target position, so at position P+1 they read slot P+1, which still holds the rejected d0_top1's K/V, while slot P+2 holds d0_top2's K/V where the K/V of whatever follows d0_top2 (not yet known) belongs. Continuing decode reads the wrong K/V. ✗

Fixing this requires one of:

  1. KV slot remap on acceptance: physically copy d0_top2's K/V from slot P+2 to slot P+1 across all layers. Costs one full-layer memcpy per acceptance of a non-top-1 sibling. Doable but adds ~2ms per event.
  2. Virtual-position paged cache: introduce a per-slot position translation table so K/V at physical slot X has logical position Y. Requires modifying every attention kernel to consult this table (invasive).
  3. Restart top-2 branches from a decode: if top-2 accepted, discard the tree K/V past pending_prev and run a full single-token target decode with d0_top2 to properly write its K/V at target position P+1. Costs ~1 full decode per accepted top-2, which likely eats the win.

Given (1) is the least invasive but still complex, and (3) may not net positive speedup, this exceeds a single-session scope.
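Option (1) can be sketched on a toy cache to show what the remap has to do. The layout is a deliberate simplification and an assumption: each layer is a flat, row-major `[num_slots, row_width]` f32 buffer, whereas the real paged cache goes through block tables, stores K and V separately, and lives on the GPU. `remap_slot` is a hypothetical name.

```rust
/// Copy one K/V row from physical slot `src` to slot `dst` in every layer.
/// Panics if either slot lies outside a layer's buffer.
fn remap_slot(layers: &mut [Vec<f32>], row_width: usize, src: usize, dst: usize) {
    for layer in layers.iter_mut() {
        let start = src * row_width;
        // copy_within behaves like memmove, so overlapping ranges would also be safe.
        layer.copy_within(start..start + row_width, dst * row_width);
    }
}
```

In the scenario above, accepting d0_top2 would call `remap_slot(&mut layers, row_width, P + 2, P + 1)` and then truncate the sequence to P+2 entries, so that slot P+1 holds d0_top2's K/V at its logical position.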

Concluding numbers on xserv main:

  • Best speedup: 1.10× at γ=2 (cuBLAS-GEMM verify, no tree).
  • Tree kernel + wrapper ready and correctness-verified.
  • Full tree drafting requires KV remap work (Phase 27+ scope).

Everything lands cleanly on main. Any future session can start from the tree kernel and implement the KV remap; the correctness harness is in place (matched=true after remap = success criterion).