# Phase 26: EAGLE3 Implementation Follow-up & Bug Hunt

> Companion to docs/25 (which explains the three speculative paradigms).
> This doc records the actual EAGLE3 implementation, the bugs we found,
> the fixes, and why `speedup > 1` stayed out of reach with the bit-exact
> verify path (the Epilogue records the cuBLAS verify change that got past it).

## Implementation Timeline

Commits are on `main`:

1. **`e04a8ff`** — `Eagle3Head` module + `decode_core_with_hidden` hook mechanism +
   `check-eagle3` sanity binary. Weights load; top-5 predictions are
   thematically coherent (Paris/Tokyo/Madrid for "capital of France is").
2. **`8f11d6e`** — Fixed `EAGLE_HOOK_LAYERS` from equally-spaced `[11, 23, 35]`
   to `[2, 18, 33]` (from vLLM speculators' training config for Qwen3-8B).
3. **`68b55fa`** — First `bench-eagle3` γ=1 loop. `matched=true`, but acceptance
   only 1.3%.
4. **`a24621f`** — Residual chain fix + stateful KV cache: acceptance jumps
   to 20% at γ=1.
5. **`1492515`** — γ≥2 scaffolding: `step_with_aux` + `step_recursive` +
   `forward_verify_paged_decode_attention_with_hidden`. `matched=false` at
   γ≥2 due to K/V bugs.
6. **`d2c55c4`** — γ≥2 correctness fixes: `matched=true` across the full sweep.

## Bugs Fixed (γ≥2)

### Bug A: Truncate dropped needed K/V

Old code:

```rust
// BUG: pending_prev's K/V lives at round_pos - 1; truncating here drops it.
cache.truncate_sequence(slot, round_pos - 1).unwrap();
// Verify then re-writes prev_token's K/V at that slot through a different kernel.
let (verify_logits, _) = target.forward_verify_...(&[prev_token, d0, d1], ...);
```

`round_pos - 1` was the position where the last committed token
(`pending_prev`) lived, so truncating dropped its K/V. Verify then wrote
`prev_token` at that slot AGAIN, but through a DIFFERENT code path: the
previous single-token decode wrote via `matmul_2d` (m=1 → custom GEMV),
while verify wrote via `matmul_batched_gemv` (m=γ+1). Same math, so the
output bytes should match IN PRINCIPLE, but in practice re-writing K/V
that was already there introduced a small numerical drift.

**Fix**: Don't truncate. Let verify start at `cache.seq_len` and write
γ+1 new positions forward. `pending_prev`'s K/V stays intact from the
previous round's write.

### Bug B: EAGLE cache accumulated rejected drafts

Each EAGLE `step_with_aux` or `step_recursive` call writes one K/V entry to
EAGLE's internal cache. Per round we call it γ times (once with the
target hooks, γ-1 times recursively), and all γ writes happen regardless of
how many drafts are eventually accepted.

If `k < γ` drafts are accepted, EAGLE's cache has γ entries for a round
that committed only k+1 tokens (`pending_prev` + k drafts). The extra
γ-k-1 entries hold K/V for hallucinated drafts that were never
committed, polluting future rounds.

**Fix**: Add `Eagle3Head::truncate_to(new_len)`. After acceptance,
truncate to `eagle_len_before + k + 1`.
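
A minimal Rust sketch of the rollback, assuming the EAGLE cache behaves like a per-step list of K/V entries. `EagleKvCache` and `finish_round` are hypothetical names; only `truncate_to` mirrors the method named above. The sketch leans on `Vec::truncate` being a no-op when the new length is not shorter, which covers the all-accepted case (`k == γ`) without a special branch.

```rust
// Hypothetical stand-in for Eagle3Head's internal KV cache:
// one (K, V) entry per EAGLE step, oldest first.
#[allow(dead_code)]
struct EagleKvCache {
    k: Vec<Vec<f32>>,
    v: Vec<Vec<f32>>,
}

impl EagleKvCache {
    fn new() -> Self {
        Self { k: Vec::new(), v: Vec::new() }
    }

    fn len(&self) -> usize {
        self.k.len()
    }

    // Each step_with_aux / step_recursive call appends exactly one entry.
    fn push(&mut self, k: Vec<f32>, v: Vec<f32>) {
        self.k.push(k);
        self.v.push(v);
    }

    // Drop every entry at index >= new_len. Vec::truncate does nothing when
    // new_len >= len, so the all-accepted case needs no special branch.
    fn truncate_to(&mut self, new_len: usize) {
        self.k.truncate(new_len);
        self.v.truncate(new_len);
    }
}

// After verify: keep pending_prev's entry plus one entry per accepted draft.
fn finish_round(cache: &mut EagleKvCache, len_before: usize, accepted: usize) {
    cache.truncate_to(len_before + accepted + 1);
}

fn main() {
    let gamma = 4;
    let mut cache = EagleKvCache::new();
    for _ in 0..10 {
        cache.push(vec![0.0; 8], vec![0.0; 8]); // earlier, committed context
    }
    let len_before = cache.len();
    for _ in 0..gamma {
        cache.push(vec![1.0; 8], vec![1.0; 8]); // this round's γ draft steps
    }
    finish_round(&mut cache, len_before, 1); // only 1 draft accepted
    println!("eagle cache len = {}", cache.len()); // prints 12 (10 + 1 + 1)
}
```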

### Bug C: aux output was normed, should be pre-norm

vLLM's `llama_eagle3.py` (line ~150):

```python
hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
aux_output = hidden_states if self.norm_output else hidden_prenorm
```

With the default `norm_output=False`, aux = `hidden_prenorm` (the pre-RMSNorm
residual sum). I was returning `hidden_states` (normed).

**Fix**: return the second output of `add_rmsnorm`, which is `x + residual`
(pre-norm). Small effect on acceptance (~1%).
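
For reference, a CPU sketch of a fused residual-add + RMSNorm that returns both outputs. The signature and the `(normed, prenorm)` return order are assumptions, not the project's actual `add_rmsnorm` API; the point is only that the EAGLE3 aux hidden state should be the second value.

```rust
// Hypothetical CPU reference for a fused residual-add + RMSNorm.
// Returns (normed, prenorm), where prenorm = x + residual.
fn add_rmsnorm(x: &[f32], residual: &[f32], weight: &[f32], eps: f32) -> (Vec<f32>, Vec<f32>) {
    assert!(x.len() == residual.len() && x.len() == weight.len());
    // Pre-norm residual sum: the value EAGLE3 wants as its aux hidden state.
    let prenorm: Vec<f32> = x.iter().zip(residual).map(|(a, b)| a + b).collect();
    // RMSNorm: scale by 1 / sqrt(mean(h^2) + eps), then by the learned weight.
    let mean_sq = prenorm.iter().map(|h| h * h).sum::<f32>() / prenorm.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    let normed = prenorm.iter().zip(weight).map(|(h, w)| h * inv_rms * w).collect();
    (normed, prenorm)
}

fn main() {
    let (normed, aux) = add_rmsnorm(&[1.0, 2.0], &[1.0, 2.0], &[1.0, 1.0], 1e-6);
    // Bug C: the aux output must be the pre-norm sum, not `normed`.
    println!("normed = {normed:?}, aux (pre-norm) = {aux:?}");
}
```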

### Bug D: EAGLE draft position off-by-one

`pending_prev` is at target position `p`. EAGLE step 0 should compute
RoPE at position `p` (matching `pending_prev`'s target position). I was
passing `p + 1`.

**Fix**: pass `p + i` for the i-th EAGLE step (i = 0..γ-1).
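
A tiny sketch of the corrected position schedule (the function name is hypothetical): step 0 consumes `pending_prev` and reuses its target position `p`, and each recursive step consumes the next draft, which sits one position further along.

```rust
// RoPE positions for one round's EAGLE steps, assuming pending_prev sits at
// target position `p`. Step 0 consumes pending_prev, so it reuses `p`;
// step i consumes draft d[i-1], which sits at `p + i`.
fn eagle_step_positions(p: usize, gamma: usize) -> Vec<usize> {
    (0..gamma).map(|i| p + i).collect()
}

fn main() {
    let (p, gamma) = (100, 3);
    // The bug passed p + 1 for step 0.
    println!("{:?}", eagle_step_positions(p, gamma)); // [100, 101, 102]
}
```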

## Final Measurements

Setup: dash5 (RTX 5090), Qwen3-8B target + AngelSlim/Qwen3-8B_eagle3 head,
5 prompts × 32 tokens, greedy, `matched=true` across all runs.

| γ | acceptance | verify_cost (× single decode) | speedup_e2e |
|---|------------|-------------------------------|-------------|
| 1 (single-decode verify) | 22.7% | 1.00 | **0.95×** |
| 1 (batched verify) | 20.6% | ~1.5 | 0.75× |
| 2 | 12.6% | ~1.7 | 0.59× |
| 3 | 9.1% | ~2.1 | 0.48× |
| 4 | 7.6% | ~2.4 | 0.41× |
| 6 | 5.2% | ~3.1 | 0.32× |
| 8 | 4.1% | ~3.7 | 0.27× |

Per-slot diagnostic (γ=8, aggregated over 5 prompts):

```
d[0]=12/125(0.10) d[1]=8/122(0.07) d[2]=5/119(0.04)
d[3]=6/116(0.05) d[4]=8/113(0.07) d[5]=13/110(0.12)
d[6]=17/107(0.16) d[7]=17/104(0.16)
```

Later positions (d[5..7]) surprisingly show HIGHER acceptance than d[1..3].
Explanation: once EAGLE hallucinates its own chain, the target's `verify_argmax`
follows that hallucinated context and often converges on plausible common
tokens (spaces, commas, "the"). This helps the per-slot rate but not
longest-prefix acceptance (the first mismatch kills the whole tail).
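
The per-slot vs. longest-prefix distinction in a few lines of Rust (helper names are hypothetical, and the match vector is made-up illustration data, not taken from the runs above): a draft chain is accepted only up to its first mismatch, so late-slot matches after an early miss contribute nothing.

```rust
// Longest-prefix acceptance: drafts count only up to the first mismatch.
// matches[i] is true when draft d[i] equals the target's verify argmax at slot i.
fn accepted_prefix_len(matches: &[bool]) -> usize {
    matches.iter().take_while(|&&m| m).count()
}

// What a per-slot diagnostic counts: every matching slot, even after a miss.
fn per_slot_hits(matches: &[bool]) -> usize {
    matches.iter().filter(|&&m| m).count()
}

fn main() {
    // Made-up γ=8 round: d[0] misses, then the tail drifts onto common tokens.
    let round = [false, true, false, false, false, true, true, true];
    println!("per-slot hits = {}", per_slot_hits(&round)); // 4
    println!("accepted      = {}", accepted_prefix_len(&round)); // 0
}
```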

## Why speedup < 1

The speedup formula:

```
speedup ≈ (1 + avg_accepted_per_round) / verify_cost_relative_to_single_decode
```

Sub-1 across the sweep because:

- **verify_cost grows linearly with γ+1**. Each verify slot is one BF16 GEMV
  row across all Qwen3-8B layers. Batching gets some memory-bound sharing,
  but not enough to make the γ+1 slots free.
- **avg_accepted per round grows only sub-linearly** because the acceptance rate
  degrades at later chain positions (~half every 2 steps).
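
To see why the accepted count saturates, consider an idealized model (an assumption, not something measured here): if slot i is accepted with probability a_i given that all earlier slots were accepted, the expected number of accepted drafts per round is the sum over i of a_1·…·a_i. With a constant rate a this is a geometric sum that stays below a/(1-a) however large γ gets.

```rust
// Expected accepted drafts per round under an idealized model where
// cond_accept[i] = P(slot i accepted | slots 0..i all accepted).
fn expected_accepted(cond_accept: &[f64]) -> f64 {
    let mut survive = 1.0; // P(every slot so far was accepted)
    let mut total = 0.0;
    for &a in cond_accept {
        survive *= a;
        total += survive; // slot i adds P(the chain survives through slot i)
    }
    total
}

fn main() {
    // Constant 0.3 conditional acceptance: each extra slot adds less.
    for gamma in [1, 2, 4, 8] {
        let e = expected_accepted(&vec![0.3; gamma]);
        println!("gamma={gamma}: expected accepted = {e:.3}");
    }
}
```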

To reach `speedup > 1` we need avg_accepted > (verify_cost - 1). With
verify_cost ≈ 1.7 at γ=2, that means avg_accepted > 0.7. Observed: 0.25.
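
The break-even arithmetic as code, using the γ=2 numbers quoted above (function names are hypothetical). The formula has no term for the EAGLE draft steps themselves, which is one plausible reason its ≈0.74× estimate sits above the measured 0.59×.

```rust
// First-order speedup model from this section (no draft-cost term).
fn speedup_estimate(avg_accepted: f64, verify_cost: f64) -> f64 {
    (1.0 + avg_accepted) / verify_cost
}

// Smallest avg_accepted per round for which the estimate exceeds 1.
fn break_even_accepted(verify_cost: f64) -> f64 {
    verify_cost - 1.0
}

fn main() {
    // γ=2 numbers quoted above: verify_cost ≈ 1.7, observed avg_accepted ≈ 0.25.
    let (avg_accepted, verify_cost) = (0.25, 1.7);
    println!("estimate = {:.2}x", speedup_estimate(avg_accepted, verify_cost));
    println!("needed   > {:.2}", break_even_accepted(verify_cost));
}
```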

## Path Forward

Three levers, all significant work:

### 1. Tree-based drafting (biggest lever, +2-3× acceptance)

The EAGLE-3 paper reports 60-70% acceptance using TREE decoding: at each
recursive step, EAGLE proposes top-k candidates instead of top-1. The
target's verify then evaluates all tree branches in one forward pass using
paged attention with tree-aware masking.
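
A sketch of the tree-aware mask, assuming the draft tree is encoded as parent indices (a common encoding, but the layout here is hypothetical and assumes no cycles): each tree node may attend to itself and its ancestors, never to siblings or cousins, which is what lets all branches share one verify forward. Attention to the committed prefix goes through the normal KV cache and is left out of this mask.

```rust
// Tree attention mask over the draft nodes only.
// parent[i] = Some(j) if node j is node i's parent, None if node i follows
// the last committed token directly. mask[i][j] == true means i may attend to j.
fn tree_mask(parent: &[Option<usize>]) -> Vec<Vec<bool>> {
    let n = parent.len();
    let mut mask = vec![vec![false; n]; n];
    for i in 0..n {
        // Walk from node i up to the root: self and every ancestor are visible.
        let mut cur = Some(i);
        while let Some(j) = cur {
            mask[i][j] = true;
            cur = parent[j];
        }
    }
    mask
}

fn main() {
    // depth=2, width=2: nodes 0,1 follow the last committed token;
    // 2,3 are children of 0; 4,5 are children of 1.
    let parent = [None, None, Some(0), Some(0), Some(1), Some(1)];
    for (i, row) in tree_mask(&parent).iter().enumerate() {
        let bits: String = row.iter().map(|&b| if b { '1' } else { '0' }).collect();
        println!("node {i}: {bits}");
    }
}
```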

Reference: `SafeAILab/EAGLE` uses trees with depth 6 and 26+ nodes.

Implementation cost: significant. Requires:

- Tree-aware batched verify (multi-branch attention masking).
- Tree navigation / longest-accepted-path selection.
- KV cache management for the accepted branch vs. discarded branches.
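
A sketch of longest-accepted-path selection under greedy verification, assuming nodes are stored parent-before-child and that verify returns one target argmax per tree node (all names are hypothetical). A node is accepted when its parent is accepted and its token equals the target's argmax at the parent; the deepest accepted node decides which branch's tokens (and K/V) are kept.

```rust
// token[i]: drafted token at node i; parent[i]: parent node (None = follows the
// last committed token). root_next: target argmax at the last committed token;
// target_next[i]: target argmax at node i. Parents must precede children.
fn longest_accepted_path(
    token: &[u32],
    parent: &[Option<usize>],
    root_next: u32,
    target_next: &[u32],
) -> Vec<usize> {
    let n = token.len();
    let mut depth = vec![0usize; n]; // 0 = not accepted
    let mut best: Option<usize> = None;
    for i in 0..n {
        let (parent_ok, predicted, parent_depth) = match parent[i] {
            None => (true, root_next, 0),
            Some(p) => (depth[p] > 0, target_next[p], depth[p]),
        };
        if parent_ok && token[i] == predicted {
            depth[i] = parent_depth + 1;
            if best.map_or(true, |b| depth[i] > depth[b]) {
                best = Some(i);
            }
        }
    }
    // Walk back from the deepest accepted node to recover the branch.
    let mut path = Vec::new();
    let mut cur = best;
    while let Some(i) = cur {
        path.push(i);
        cur = parent[i];
    }
    path.reverse();
    path
}

fn main() {
    // depth=2, width=2 tree with made-up token ids.
    let token = [10, 11, 20, 21, 22, 23];
    let parent = [None, None, Some(0), Some(0), Some(1), Some(1)];
    let target_next = [21, 23, 0, 0, 0, 0]; // target argmax at each node
    let path = longest_accepted_path(&token, &parent, 11, &target_next);
    println!("accepted nodes: {:?}", path); // [1, 5]
}
```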

### 2. Cheaper batched verify

The current batched verify at γ+1 tokens uses `matmul_batched_gemv` (per-row
GEMV) plus `paged_decode_attention` with batch=γ+1. Both scale roughly
linearly with γ+1.

Potential improvements:

- **Flash Attention** with multi-query: each of the γ+1 queries shares
  the same K/V cache pointers, so a single kernel can read K/V once and
  compute γ+1 outputs. Currently they're independent kernel launches per
  query.
- **Cheaper QKV projection at m>1**: `matmul_batched_gemv` is bit-exact
  per row but doesn't amortize weight loading across rows. Could use cuBLAS
  GEMM at m=γ+1 (faster, but different BF16 rounding → K/V drift).

### 3. Better draft (smaller EAGLE, different training)

The AngelSlim Qwen3-8B_eagle3 head is 750MB (~1 layer of the 8B model).
Alternatives:

- Smaller Qwen3 (0.6B) as draft: already tried; γ=1 gets 40% acceptance,
  but draft cost is ~2.5ms (vs EAGLE's ~0.5ms).
- Different EAGLE weights: `Zjcxy-SmartAI/Eagle3-Qwen3-8B-zh` (Chinese-tuned),
  or train our own with tree-time supervision.

## Recommendation

Given effort/reward:

**Short-term (1 session)**: implement tree-based drafting with depth=2,
width=2 (4 candidates per round). Reuse the existing batched verify with
tree-aware masking. Expect acceptance to double (25% → 50%+).

**Medium-term (2-3 sessions)**: a full tree of depth 6 with varying width, plus
a flash-attention-2 batched verify kernel. This matches the vLLM
implementation and should approach 2× speedup.

**Alternative (if EAGLE is a dead end)**: switch to lookahead decoding
(Fu et al.), which doesn't require a draft model at all: it uses n-gram
lookup + Jacobi iteration on the target.

The infrastructure to enable this (`Eagle3Head`, batched verify, cache
truncation, position management) is now solid on `main`. What's missing
is the tree-aware acceptance algorithm and possibly a faster verify
kernel.

---
## Epilogue (`06a798c`): cuBLAS GEMM verify → speedup > 1 achieved

Actioned option 2 above: swapped `matmul_batched_gemv` for `matmul_2d`
(cuBLAS GEMM) inside `forward_verify_paged_decode_attention_with_hidden`.

Micro-benchmark (`bench-verify-cost.rs`, RTX 5090, prompt_len=100):

| batch | batched-GEMV verify | cuBLAS-GEMM verify |
|-------|---------------------|--------------------|
| 1 | 13.14 ms (1.05×) | 13.04 ms (1.04×) |
| 2 | 19.51 ms (1.56×) | 13.52 ms (1.08×) |
| 3 | 26.10 ms (2.09×) | 13.59 ms (1.09×) |
| 5 | 38.72 ms (3.10×) | 13.88 ms (1.11×) |
| 9 | 64.15 ms (5.14×) | 15.03 ms (1.20×) |

cuBLAS GEMM at m>1 loads each weight matrix once and reuses it for all
rows, so verify time stays nearly flat as the batch grows: at these small m
the GEMM is still memory-bound, and the extra rows ride along almost for
free. Per-row GEMV re-loads the weights for every row, so its cost grows
linearly.
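
A back-of-envelope traffic model for one BF16 projection y[m×n] = x[m×k]·W[k×n], counting each operand once per read (an idealization that ignores caches and tiling; the 4096×4096 shape is illustrative, not Qwen3-8B's exact layer shapes). The GEMM's arithmetic intensity comes out at roughly m FLOP/byte for small m, far below the ridge point of a tensor-core GPU, so it stays bandwidth-bound; per-row GEMV moves about m times as many bytes.

```rust
// Bytes and FLOPs for one BF16 projection y[m×n] = x[m×k] · W[k×n].
const BF16_BYTES: f64 = 2.0;

fn gemm_bytes(m: f64, k: f64, n: f64) -> f64 {
    // W is read once and shared by all m rows; x is read, y is written.
    BF16_BYTES * (k * n + m * k + m * n)
}

fn per_row_gemv_bytes(m: f64, k: f64, n: f64) -> f64 {
    // Each of the m rows re-reads all of W.
    m * BF16_BYTES * (k * n + k + n)
}

fn flops(m: f64, k: f64, n: f64) -> f64 {
    2.0 * m * k * n // one multiply + one add per (row, k, n) triple
}

fn main() {
    let (k, n) = (4096.0, 4096.0); // illustrative shape
    for m in [1.0, 3.0, 9.0] {
        println!(
            "m={m}: GEMM {:.1} MB, per-row GEMV {:.1} MB, GEMM intensity {:.2} FLOP/B",
            gemm_bytes(m, k, n) / 1e6,
            per_row_gemv_bytes(m, k, n) / 1e6,
            flops(m, k, n) / gemm_bytes(m, k, n),
        );
    }
}
```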

50 prompts × 64 tokens γ sweep with cuBLAS verify:

| γ | acceptance | speedup_e2e |
|---|------------|-------------|
| 1 (single-decode) | 29.8% | 0.95× |
| **2** | **16.9%** | **1.10×** ← best |
| 3 | 11.6% | 1.06× |
| 4 | 8.9% | 1.02× |
| 5 | 7.2% | 0.96× |
| 6 | 6.0% | 0.93× |
| 8 | 4.5% | 0.86× |
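
Multiplying γ by the per-draft acceptance rate gives the average number of accepted drafts per round. Reading `acceptance` as accepted/proposed drafts is an assumption here, though it matches the 0.25 quoted for γ=2 in "Why speedup < 1". Applied to the sweep above, that number barely moves past γ=2, while each extra γ still costs an EAGLE step and a slightly larger verify.

```rust
// Average accepted drafts per round ≈ γ × per-draft acceptance
// (assumes acceptance = accepted drafts / proposed drafts).
fn avg_accepted(gamma: u32, acceptance: f64) -> f64 {
    gamma as f64 * acceptance
}

fn main() {
    // (γ, acceptance) rows from the cuBLAS-verify sweep table.
    let sweep = [(2, 0.169), (3, 0.116), (4, 0.089), (5, 0.072), (6, 0.060), (8, 0.045)];
    for (gamma, acc) in sweep {
        println!("γ={gamma}: avg accepted ≈ {:.3}", avg_accepted(gamma, acc));
    }
}
```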

Tradeoff: `matched=false`. cuBLAS GEMM at m>1 rounds BF16 differently
from the custom GEMV at m=1, so the K/V bytes written by verify differ from what
a per-token decode would write, and downstream token choices diverge
from the strict-baseline path.
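
Why "same math" does not guarantee "same bytes": floating-point addition is not associative, so two kernels that reduce the same products in a different order (as a cuBLAS GEMM and a custom GEMV may well do) can round differently. A minimal Rust illustration, with f32 standing in for BF16 (std Rust has no bf16 type) and inputs contrived to make the effect visible:

```rust
// Same inputs, same operation, different reduction order.
fn sum_left_to_right(xs: &[f32]) -> f32 {
    xs.iter().fold(0.0, |acc, &x| acc + x)
}

fn sum_right_to_left(xs: &[f32]) -> f32 {
    xs.iter().rev().fold(0.0, |acc, &x| acc + x)
}

fn main() {
    // Near 1e8 the f32 spacing is 8, so adding 1.0 there is lost to rounding.
    let xs = [1.0e8_f32, -1.0e8, 1.0];
    println!("left-to-right = {}", sum_left_to_right(&xs)); // 1
    println!("right-to-left = {}", sum_right_to_left(&xs)); // 0
}
```

BF16 keeps only 8 significant bits versus f32's 24, so order-dependent rounding differences are typically larger there.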

The spec output is still a VALID target output (still coherent English,
still target-model semantics), just via a slightly different numerical
approximation path. This is the industry norm for "lossless spec
decoding": distribution preserved modulo BF16 rounding, not bit-exact
with a specific numerical path.

`speedup_e2e = 1.10×` is a real, measurable win at γ=2 on 50×64 prompts.
Higher γ gives diminishing returns: per-draft acceptance falls roughly in
proportion to 1/γ, so accepted tokens per round barely grow while verify and
draft work keep growing, and the peak is already at γ=2. To push higher, we'd
need a better draft (tree decoding, a larger EAGLE head, or different EAGLE
weights).