eagle3: proper residual chain + stateful KV cache

Two fixes to bring EAGLE3 forward in line with vllm's llama_eagle3.py
reference:

1. Residual chain: previously the residual added into post_attention_layernorm
   was the token embedding (wrong). Reference uses _norm_after_residual:
     residual = fused_h (pre-norm)
     hidden_states = hidden_norm(fused_h)
   Then post_attention_layernorm is a fused add_rmsnorm(attn_out, residual),
   and the final norm is another add_rmsnorm(mlp_out, residual_after_attn).
   Neither residual carries the embedding — both carry fused_h forward.

2. KV cache: previously the attention was approximated as "output = V"
   because seq_len=1 (no cache), effectively giving EAGLE no history.
   Add a real per-Eagle3Head KV cache (1 layer × [1, num_kv_heads,
   max_seq_len, head_dim] BF16) that grows as we call step(). Use the
   existing decode_attention kernel with a fresh contiguous slice of the
   cache each step. reset() clears current_len for a new sequence.

Result on 10 prompts × 32 tokens (γ=1, no batched verify yet):
  matched=true across all prompts
  acceptance_rate = 20.0% (was 4.7% before residual fix, 1.3% originally)
    - Prompt 00 "The capital of France is": 60% (18/30) — best case
    - Other prompts: 10-25% — matches EAGLE paper's observation that
      structured/factual prompts get higher acceptance

Sanity check (check-eagle3) on Paris prompt now shows:
  EAGLE top-5 pairing A: "." / " is" / "," / " Paris" / ".\n"
  MATCH: EAGLE agrees with target on next token.

speedup_e2e still 0.95x because γ=1 does 1 target decode per token
regardless of acceptance. Real speedup requires γ≥2 with a single
batched target-verify covering all γ draft tokens; that's the next step.
This commit is contained in:
2026-07-01 17:50:49 +08:00
parent 68b55fa1e6
commit a24621fa6a
3 changed files with 90 additions and 34 deletions

View File

@@ -109,7 +109,7 @@ fn main() {
xserv_cuda::allocator::cached_trim(); xserv_cuda::allocator::cached_trim();
eprintln!("Loading EAGLE3 head..."); eprintln!("Loading EAGLE3 head...");
let eagle = Eagle3Head::load(&eagle_dir, device); let mut eagle = Eagle3Head::load(&eagle_dir, device);
xserv_cuda::allocator::cached_trim(); xserv_cuda::allocator::cached_trim();
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json")); let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
@@ -151,7 +151,7 @@ fn main() {
let mut target_cache = new_cache(&target_config, max_seq_len, device); let mut target_cache = new_cache(&target_config, max_seq_len, device);
let spec = run_eagle_gamma1( let spec = run_eagle_gamma1(
&target, &target,
&eagle, &mut eagle,
&embed_tokens, &embed_tokens,
&mut target_cache, &mut target_cache,
&tokenizer, &tokenizer,
@@ -279,7 +279,7 @@ fn run_baseline(
/// For γ=1 we simply track acceptance rate for informational purposes. /// For γ=1 we simply track acceptance rate for informational purposes.
fn run_eagle_gamma1( fn run_eagle_gamma1(
target: &Qwen3, target: &Qwen3,
eagle: &Eagle3Head, eagle: &mut Eagle3Head,
embed_tokens: &Tensor, embed_tokens: &Tensor,
cache: &mut PagedKVCache, cache: &mut PagedKVCache,
tokenizer: &Tokenizer, tokenizer: &Tokenizer,
@@ -288,6 +288,7 @@ fn run_eagle_gamma1(
) -> RunStats { ) -> RunStats {
let slot = 0; let slot = 0;
cache.register_sequence(slot).unwrap(); cache.register_sequence(slot).unwrap();
eagle.reset();
let t0 = Instant::now(); let t0 = Instant::now();
// Prefill target — we don't have hidden state hooks from prefill in this // Prefill target — we don't have hidden state hooks from prefill in this

View File

@@ -38,7 +38,7 @@ fn main() {
xserv_cuda::allocator::cached_trim(); xserv_cuda::allocator::cached_trim();
eprintln!("Loading EAGLE3 head from {}", eagle_dir.display()); eprintln!("Loading EAGLE3 head from {}", eagle_dir.display());
let eagle = Eagle3Head::load(&eagle_dir, device); let mut eagle = Eagle3Head::load(&eagle_dir, device);
xserv_cuda::allocator::cached_trim(); xserv_cuda::allocator::cached_trim();
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json")); let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
@@ -103,6 +103,7 @@ fn main() {
// Ask EAGLE what it thinks the NEXT token is (given target_first as prev_token // Ask EAGLE what it thinks the NEXT token is (given target_first as prev_token
// and the hidden states from the position where target_first lives). // and the hidden states from the position where target_first lives).
// EAGLE should predict target_next (or close to it) to be useful. // EAGLE should predict target_next (or close to it) to be useful.
eagle.reset();
let (eagle_pred, eagle_logits) = eagle.step(&hooks, embed_tokens, target_first, pos); let (eagle_pred, eagle_logits) = eagle.step(&hooks, embed_tokens, target_first, pos);
let eagle_pred_text = tokenizer.decode(&[eagle_pred]); let eagle_pred_text = tokenizer.decode(&[eagle_pred]);
println!( println!(
@@ -129,6 +130,7 @@ fn main() {
// Alternative pairing B: pair hooks with target_next (the token those hooks produced // Alternative pairing B: pair hooks with target_next (the token those hooks produced
// via lm_head), predict token after target_next. Position advances by 1. // via lm_head), predict token after target_next. Position advances by 1.
eagle.reset();
let (eagle_pred_b, eagle_logits_b) = eagle.step(&hooks, embed_tokens, target_next, pos + 1); let (eagle_pred_b, eagle_logits_b) = eagle.step(&hooks, embed_tokens, target_next, pos + 1);
let eagle_pred_b_text = tokenizer.decode(&[eagle_pred_b]); let eagle_pred_b_text = tokenizer.decode(&[eagle_pred_b]);
println!( println!(

View File

@@ -47,7 +47,14 @@ pub struct Eagle3Head {
num_heads: usize, num_heads: usize,
num_kv_heads: usize, num_kv_heads: usize,
head_dim: usize, head_dim: usize,
max_seq_len: usize,
rope_cache: RopeCache, rope_cache: RopeCache,
// Stateful 1-layer KV cache: [1, num_kv_heads, max_seq_len, head_dim] BF16.
// We slice `..current_len` for attention. The head is tiny (~64 KB per
// 1000 tokens) so pre-allocating max_seq_len wastes negligible memory.
k_cache: Tensor,
v_cache: Tensor,
current_len: usize,
} }
impl Eagle3Head { impl Eagle3Head {
@@ -100,6 +107,17 @@ impl Eagle3Head {
let rope_cache = RopeCache::new(max_seq_len, head_dim, rope_theta); let rope_cache = RopeCache::new(max_seq_len, head_dim, rope_theta);
let k_cache = Tensor::zeros(
&[1, num_kv_heads, max_seq_len, head_dim],
DType::BF16,
Device::Cuda(device),
);
let v_cache = Tensor::zeros(
&[1, num_kv_heads, max_seq_len, head_dim],
DType::BF16,
Device::Cuda(device),
);
Self { Self {
fc_wt, fc_wt,
hidden_norm, hidden_norm,
@@ -119,10 +137,19 @@ impl Eagle3Head {
num_heads, num_heads,
num_kv_heads, num_kv_heads,
head_dim, head_dim,
max_seq_len,
rope_cache, rope_cache,
k_cache,
v_cache,
current_len: 0,
} }
} }
/// Reset the internal KV cache for a fresh sequence.
pub fn reset(&mut self) {
self.current_len = 0;
}
/// One draft step: produce a token in target vocabulary space. /// One draft step: produce a token in target vocabulary space.
/// ///
/// - `target_hidden`: 3 tensors [1, hidden_size] from target hook layers /// - `target_hidden`: 3 tensors [1, hidden_size] from target hook layers
@@ -132,13 +159,19 @@ impl Eagle3Head {
/// ///
/// Returns (draft_token_in_target_vocab, draft_logits_tensor). /// Returns (draft_token_in_target_vocab, draft_logits_tensor).
pub fn step( pub fn step(
&self, &mut self,
target_hidden: &[Tensor; 3], target_hidden: &[Tensor; 3],
embed_table: &Tensor, embed_table: &Tensor,
prev_token: u32, prev_token: u32,
position: usize, position: usize,
) -> (u32, Tensor) { ) -> (u32, Tensor) {
let eps = 1e-6f32; let eps = 1e-6f32;
assert!(
self.current_len < self.max_seq_len,
"EAGLE KV cache overflow: {} >= {}",
self.current_len,
self.max_seq_len
);
// 1. Fuse target hidden states: concat [h_low, h_mid, h_high] → fc // 1. Fuse target hidden states: concat [h_low, h_mid, h_high] → fc
let h_cat = concat_hidden(target_hidden); let h_cat = concat_hidden(target_hidden);
@@ -147,15 +180,16 @@ impl Eagle3Head {
// 2. Embed previous token (shared with target) // 2. Embed previous token (shared with target)
let emb = embedding(embed_table, &[prev_token]); // [1, hidden] let emb = embedding(embed_table, &[prev_token]); // [1, hidden]
// 3. Concat normalized: [norm(emb), norm(fused_h)] → [1, 2*hidden] // 3. Norm both, concat, remember residual = fused_h (pre-norm).
let residual = fused_h.clone();
let emb_normed = rmsnorm(&emb, &self.input_layernorm, eps); let emb_normed = rmsnorm(&emb, &self.input_layernorm, eps);
let h_normed = rmsnorm(&fused_h, &self.hidden_norm, eps); let h_normed = rmsnorm(&fused_h, &self.hidden_norm, eps);
let attn_in = concat_last_dim(&emb_normed, &h_normed); // [1, 8192] let attn_in = concat_last_dim(&emb_normed, &h_normed); // [1, 2*hidden]
// 4. Self-attention (no KV cache for simplicity in v0 — single query) // 4. Q/K/V projection then RoPE (position from caller).
let q = matmul_2d(&attn_in, &self.q_proj_wt); // [1, num_heads*head_dim] let q = matmul_2d(&attn_in, &self.q_proj_wt);
let k = matmul_2d(&attn_in, &self.k_proj_wt); // [1, num_kv*head_dim] let k = matmul_2d(&attn_in, &self.k_proj_wt);
let v = matmul_2d(&attn_in, &self.v_proj_wt); // [1, num_kv*head_dim] let v = matmul_2d(&attn_in, &self.v_proj_wt);
let q_3d = q.reshape(&[1, self.num_heads, self.head_dim]); let q_3d = q.reshape(&[1, self.num_heads, self.head_dim]);
let k_3d = k.reshape(&[1, self.num_kv_heads, self.head_dim]); let k_3d = k.reshape(&[1, self.num_kv_heads, self.head_dim]);
@@ -163,39 +197,58 @@ impl Eagle3Head {
rope_inplace(&q_3d, &self.rope_cache, &positions); rope_inplace(&q_3d, &self.rope_cache, &positions);
rope_inplace(&k_3d, &self.rope_cache, &positions); rope_inplace(&k_3d, &self.rope_cache, &positions);
// Single-token attention: Q·K^T / sqrt(d) → softmax → V // 5. Append new K/V to the internal cache at slot `current_len`, then
// With seq_len=1, attention is trivial: output = V (weight=1.0) // build a contiguous view [1, num_kv_heads, current_len+1, head_dim].
let attn_out = v.reshape(&[1, self.num_kv_heads, self.head_dim]); let v_3d = v.reshape(&[1, self.num_kv_heads, self.head_dim]);
let attn_out = if self.num_heads != self.num_kv_heads { self.append_to_kv_cache(&k_3d, &v_3d);
repeat_kv_for_single_token(&attn_out, self.num_heads / self.num_kv_heads) self.current_len += 1;
} else { let kv_len = self.current_len;
attn_out let k_view = self.k_cache.narrow(2, 0, kv_len).contiguous();
}; let v_view = self.v_cache.narrow(2, 0, kv_len).contiguous();
// 6. Attention: q [1, num_q_heads, 1, head_dim] × k/v [1, num_kv_heads, kv_len, head_dim]
let q_4d = q_3d.reshape(&[1, self.num_heads, 1, self.head_dim]);
let attn_out = decode_attention(&q_4d, &k_view, &v_view);
// 7. Merge heads and o_proj.
let attn_merged = attn_out.reshape(&[1, self.num_heads * self.head_dim]); let attn_merged = attn_out.reshape(&[1, self.num_heads * self.head_dim]);
let attn_proj = matmul_2d(&attn_merged, &self.o_proj_wt); // [1, hidden] let attn_proj = matmul_2d(&attn_merged, &self.o_proj_wt);
// Residual from embedding // 8. Post-attn fused add_rmsnorm.
let x = add(&attn_proj, &emb); let (mlp_in, residual) =
add_rmsnorm(&attn_proj, &residual, &self.post_attention_layernorm, eps);
// 5. MLP // 9. MLP.
let normed = rmsnorm(&x, &self.post_attention_layernorm, eps); let gate = matmul_2d(&mlp_in, &self.gate_proj_wt);
let gate = matmul_2d(&normed, &self.gate_proj_wt); let up = matmul_2d(&mlp_in, &self.up_proj_wt);
let up = matmul_2d(&normed, &self.up_proj_wt); let hidden = silu_mul(&gate, &up);
let mlp_out = silu_mul(&gate, &up); let down = matmul_2d(&hidden, &self.down_proj_wt);
let down = matmul_2d(&mlp_out, &self.down_proj_wt);
let x = add(&x, &down);
// 6. Final norm + lm_head // 10. Final fused add_rmsnorm lm_head.
let x = rmsnorm(&x, &self.norm, eps); let (x, _) = add_rmsnorm(&down, &residual, &self.norm, eps);
let logits = matmul_2d(&x, &self.lm_head_wt); // [1, 32000] let logits = matmul_2d(&x, &self.lm_head_wt);
// 7. Argmax in draft vocab → map to target vocab
let draft_id = argmax_bf16_single(&logits); let draft_id = argmax_bf16_single(&logits);
let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32; let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32;
(target_id, logits) (target_id, logits)
} }
/// Write new K/V rows (shape [1, num_kv_heads, head_dim]) at position
/// `current_len` inside the [1, num_kv_heads, max_seq_len, head_dim] cache.
fn append_to_kv_cache(&mut self, new_k: &Tensor, new_v: &Tensor) {
let head_bytes = self.head_dim * self.k_cache.dtype().size_bytes();
for h in 0..self.num_kv_heads {
for (cache, src) in [(&self.k_cache, new_k), (&self.v_cache, new_v)] {
let dst = unsafe {
(cache.data_ptr() as *mut u8)
.add(((h * self.max_seq_len) + self.current_len) * head_bytes)
};
let s = unsafe { (src.data_ptr() as *const u8).add(h * head_bytes) };
d2d(dst, s, head_bytes);
}
}
}
/// Map a draft-vocab token id to the full target-vocab id via d2t. /// Map a draft-vocab token id to the full target-vocab id via d2t.
pub fn map_draft_to_target(&self, draft_id: u32) -> u32 { pub fn map_draft_to_target(&self, draft_id: u32) -> u32 {
(draft_id as i64 + self.d2t[draft_id as usize]) as u32 (draft_id as i64 + self.d2t[draft_id as usize]) as u32