eagle3: proper residual chain + stateful KV cache
Two fixes to bring EAGLE3 forward in line with vllm's llama_eagle3.py
reference:
1. Residual chain: previously the residual added into post_attention_layernorm
was the token embedding (wrong). Reference uses _norm_after_residual:
residual = fused_h (pre-norm)
hidden_states = hidden_norm(fused_h)
Then post_attention_layernorm is a fused add_rmsnorm(attn_out, residual),
and the final norm is another add_rmsnorm(mlp_out, residual_after_attn).
Neither residual carries the embedding — both carry fused_h forward.
2. KV cache: previously the attention was approximated as "output = V"
because seq_len=1 (no cache), effectively giving EAGLE no history.
Add a real per-Eagle3Head KV cache (1 layer × [1, num_kv_heads,
max_seq_len, head_dim] BF16) that grows as we call step(). Use the
existing decode_attention kernel with a fresh contiguous slice of the
cache each step. reset() clears current_len for a new sequence.
Result on 10 prompts × 32 tokens (γ=1, no batched verify yet):
matched=true across all prompts
acceptance_rate = 20.0% (was 4.7% before residual fix, 1.3% originally)
- Prompt 00 "The capital of France is": 60% (18/30) — best case
- Other prompts: 10-25% — matches EAGLE paper's observation that
structured/factual prompts get higher acceptance
Sanity check (check-eagle3) on Paris prompt now shows:
EAGLE top-5 pairing A: "." / " is" / "," / " Paris" / ".\n"
MATCH: EAGLE agrees with target on next token.
speedup_e2e still 0.95x because γ=1 does 1 target decode per token
regardless of acceptance. Real speedup requires γ≥2 with a single
batched target-verify covering all γ draft tokens; that's the next step.
This commit is contained in:
@@ -109,7 +109,7 @@ fn main() {
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
|
||||
eprintln!("Loading EAGLE3 head...");
|
||||
let eagle = Eagle3Head::load(&eagle_dir, device);
|
||||
let mut eagle = Eagle3Head::load(&eagle_dir, device);
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
|
||||
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
|
||||
@@ -151,7 +151,7 @@ fn main() {
|
||||
let mut target_cache = new_cache(&target_config, max_seq_len, device);
|
||||
let spec = run_eagle_gamma1(
|
||||
&target,
|
||||
&eagle,
|
||||
&mut eagle,
|
||||
&embed_tokens,
|
||||
&mut target_cache,
|
||||
&tokenizer,
|
||||
@@ -279,7 +279,7 @@ fn run_baseline(
|
||||
/// For γ=1 we simply track acceptance rate for informational purposes.
|
||||
fn run_eagle_gamma1(
|
||||
target: &Qwen3,
|
||||
eagle: &Eagle3Head,
|
||||
eagle: &mut Eagle3Head,
|
||||
embed_tokens: &Tensor,
|
||||
cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
@@ -288,6 +288,7 @@ fn run_eagle_gamma1(
|
||||
) -> RunStats {
|
||||
let slot = 0;
|
||||
cache.register_sequence(slot).unwrap();
|
||||
eagle.reset();
|
||||
let t0 = Instant::now();
|
||||
|
||||
// Prefill target — we don't have hidden state hooks from prefill in this
|
||||
|
||||
@@ -38,7 +38,7 @@ fn main() {
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
|
||||
eprintln!("Loading EAGLE3 head from {}", eagle_dir.display());
|
||||
let eagle = Eagle3Head::load(&eagle_dir, device);
|
||||
let mut eagle = Eagle3Head::load(&eagle_dir, device);
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
|
||||
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
|
||||
@@ -103,6 +103,7 @@ fn main() {
|
||||
// Ask EAGLE what it thinks the NEXT token is (given target_first as prev_token
|
||||
// and the hidden states from the position where target_first lives).
|
||||
// EAGLE should predict target_next (or close to it) to be useful.
|
||||
eagle.reset();
|
||||
let (eagle_pred, eagle_logits) = eagle.step(&hooks, embed_tokens, target_first, pos);
|
||||
let eagle_pred_text = tokenizer.decode(&[eagle_pred]);
|
||||
println!(
|
||||
@@ -129,6 +130,7 @@ fn main() {
|
||||
|
||||
// Alternative pairing B: pair hooks with target_next (the token those hooks produced
|
||||
// via lm_head), predict token after target_next. Position advances by 1.
|
||||
eagle.reset();
|
||||
let (eagle_pred_b, eagle_logits_b) = eagle.step(&hooks, embed_tokens, target_next, pos + 1);
|
||||
let eagle_pred_b_text = tokenizer.decode(&[eagle_pred_b]);
|
||||
println!(
|
||||
|
||||
@@ -47,7 +47,14 @@ pub struct Eagle3Head {
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
max_seq_len: usize,
|
||||
rope_cache: RopeCache,
|
||||
// Stateful 1-layer KV cache: [1, num_kv_heads, max_seq_len, head_dim] BF16.
|
||||
// We slice `..current_len` for attention. The head is tiny (~64 KB per
|
||||
// 1000 tokens) so pre-allocating max_seq_len wastes negligible memory.
|
||||
k_cache: Tensor,
|
||||
v_cache: Tensor,
|
||||
current_len: usize,
|
||||
}
|
||||
|
||||
impl Eagle3Head {
|
||||
@@ -100,6 +107,17 @@ impl Eagle3Head {
|
||||
|
||||
let rope_cache = RopeCache::new(max_seq_len, head_dim, rope_theta);
|
||||
|
||||
let k_cache = Tensor::zeros(
|
||||
&[1, num_kv_heads, max_seq_len, head_dim],
|
||||
DType::BF16,
|
||||
Device::Cuda(device),
|
||||
);
|
||||
let v_cache = Tensor::zeros(
|
||||
&[1, num_kv_heads, max_seq_len, head_dim],
|
||||
DType::BF16,
|
||||
Device::Cuda(device),
|
||||
);
|
||||
|
||||
Self {
|
||||
fc_wt,
|
||||
hidden_norm,
|
||||
@@ -119,10 +137,19 @@ impl Eagle3Head {
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
max_seq_len,
|
||||
rope_cache,
|
||||
k_cache,
|
||||
v_cache,
|
||||
current_len: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the internal KV cache for a fresh sequence.
|
||||
pub fn reset(&mut self) {
|
||||
self.current_len = 0;
|
||||
}
|
||||
|
||||
/// One draft step: produce a token in target vocabulary space.
|
||||
///
|
||||
/// - `target_hidden`: 3 tensors [1, hidden_size] from target hook layers
|
||||
@@ -132,13 +159,19 @@ impl Eagle3Head {
|
||||
///
|
||||
/// Returns (draft_token_in_target_vocab, draft_logits_tensor).
|
||||
pub fn step(
|
||||
&self,
|
||||
&mut self,
|
||||
target_hidden: &[Tensor; 3],
|
||||
embed_table: &Tensor,
|
||||
prev_token: u32,
|
||||
position: usize,
|
||||
) -> (u32, Tensor) {
|
||||
let eps = 1e-6f32;
|
||||
assert!(
|
||||
self.current_len < self.max_seq_len,
|
||||
"EAGLE KV cache overflow: {} >= {}",
|
||||
self.current_len,
|
||||
self.max_seq_len
|
||||
);
|
||||
|
||||
// 1. Fuse target hidden states: concat [h_low, h_mid, h_high] → fc
|
||||
let h_cat = concat_hidden(target_hidden);
|
||||
@@ -147,15 +180,16 @@ impl Eagle3Head {
|
||||
// 2. Embed previous token (shared with target)
|
||||
let emb = embedding(embed_table, &[prev_token]); // [1, hidden]
|
||||
|
||||
// 3. Concat normalized: [norm(emb), norm(fused_h)] → [1, 2*hidden]
|
||||
// 3. Norm both, concat, remember residual = fused_h (pre-norm).
|
||||
let residual = fused_h.clone();
|
||||
let emb_normed = rmsnorm(&emb, &self.input_layernorm, eps);
|
||||
let h_normed = rmsnorm(&fused_h, &self.hidden_norm, eps);
|
||||
let attn_in = concat_last_dim(&emb_normed, &h_normed); // [1, 8192]
|
||||
let attn_in = concat_last_dim(&emb_normed, &h_normed); // [1, 2*hidden]
|
||||
|
||||
// 4. Self-attention (no KV cache for simplicity in v0 — single query)
|
||||
let q = matmul_2d(&attn_in, &self.q_proj_wt); // [1, num_heads*head_dim]
|
||||
let k = matmul_2d(&attn_in, &self.k_proj_wt); // [1, num_kv*head_dim]
|
||||
let v = matmul_2d(&attn_in, &self.v_proj_wt); // [1, num_kv*head_dim]
|
||||
// 4. Q/K/V projection then RoPE (position from caller).
|
||||
let q = matmul_2d(&attn_in, &self.q_proj_wt);
|
||||
let k = matmul_2d(&attn_in, &self.k_proj_wt);
|
||||
let v = matmul_2d(&attn_in, &self.v_proj_wt);
|
||||
|
||||
let q_3d = q.reshape(&[1, self.num_heads, self.head_dim]);
|
||||
let k_3d = k.reshape(&[1, self.num_kv_heads, self.head_dim]);
|
||||
@@ -163,39 +197,58 @@ impl Eagle3Head {
|
||||
rope_inplace(&q_3d, &self.rope_cache, &positions);
|
||||
rope_inplace(&k_3d, &self.rope_cache, &positions);
|
||||
|
||||
// Single-token attention: Q·K^T / sqrt(d) → softmax → V
|
||||
// With seq_len=1, attention is trivial: output = V (weight=1.0)
|
||||
let attn_out = v.reshape(&[1, self.num_kv_heads, self.head_dim]);
|
||||
let attn_out = if self.num_heads != self.num_kv_heads {
|
||||
repeat_kv_for_single_token(&attn_out, self.num_heads / self.num_kv_heads)
|
||||
} else {
|
||||
attn_out
|
||||
};
|
||||
// 5. Append new K/V to the internal cache at slot `current_len`, then
|
||||
// build a contiguous view [1, num_kv_heads, current_len+1, head_dim].
|
||||
let v_3d = v.reshape(&[1, self.num_kv_heads, self.head_dim]);
|
||||
self.append_to_kv_cache(&k_3d, &v_3d);
|
||||
self.current_len += 1;
|
||||
let kv_len = self.current_len;
|
||||
let k_view = self.k_cache.narrow(2, 0, kv_len).contiguous();
|
||||
let v_view = self.v_cache.narrow(2, 0, kv_len).contiguous();
|
||||
|
||||
// 6. Attention: q [1, num_q_heads, 1, head_dim] × k/v [1, num_kv_heads, kv_len, head_dim]
|
||||
let q_4d = q_3d.reshape(&[1, self.num_heads, 1, self.head_dim]);
|
||||
let attn_out = decode_attention(&q_4d, &k_view, &v_view);
|
||||
|
||||
// 7. Merge heads and o_proj.
|
||||
let attn_merged = attn_out.reshape(&[1, self.num_heads * self.head_dim]);
|
||||
let attn_proj = matmul_2d(&attn_merged, &self.o_proj_wt); // [1, hidden]
|
||||
let attn_proj = matmul_2d(&attn_merged, &self.o_proj_wt);
|
||||
|
||||
// Residual from embedding
|
||||
let x = add(&attn_proj, &emb);
|
||||
// 8. Post-attn fused add_rmsnorm.
|
||||
let (mlp_in, residual) =
|
||||
add_rmsnorm(&attn_proj, &residual, &self.post_attention_layernorm, eps);
|
||||
|
||||
// 5. MLP
|
||||
let normed = rmsnorm(&x, &self.post_attention_layernorm, eps);
|
||||
let gate = matmul_2d(&normed, &self.gate_proj_wt);
|
||||
let up = matmul_2d(&normed, &self.up_proj_wt);
|
||||
let mlp_out = silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&mlp_out, &self.down_proj_wt);
|
||||
let x = add(&x, &down);
|
||||
// 9. MLP.
|
||||
let gate = matmul_2d(&mlp_in, &self.gate_proj_wt);
|
||||
let up = matmul_2d(&mlp_in, &self.up_proj_wt);
|
||||
let hidden = silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden, &self.down_proj_wt);
|
||||
|
||||
// 6. Final norm + lm_head
|
||||
let x = rmsnorm(&x, &self.norm, eps);
|
||||
let logits = matmul_2d(&x, &self.lm_head_wt); // [1, 32000]
|
||||
// 10. Final fused add_rmsnorm → lm_head.
|
||||
let (x, _) = add_rmsnorm(&down, &residual, &self.norm, eps);
|
||||
let logits = matmul_2d(&x, &self.lm_head_wt);
|
||||
|
||||
// 7. Argmax in draft vocab → map to target vocab
|
||||
let draft_id = argmax_bf16_single(&logits);
|
||||
let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32;
|
||||
|
||||
(target_id, logits)
|
||||
}
|
||||
|
||||
/// Write new K/V rows (shape [1, num_kv_heads, head_dim]) at position
|
||||
/// `current_len` inside the [1, num_kv_heads, max_seq_len, head_dim] cache.
|
||||
fn append_to_kv_cache(&mut self, new_k: &Tensor, new_v: &Tensor) {
|
||||
let head_bytes = self.head_dim * self.k_cache.dtype().size_bytes();
|
||||
for h in 0..self.num_kv_heads {
|
||||
for (cache, src) in [(&self.k_cache, new_k), (&self.v_cache, new_v)] {
|
||||
let dst = unsafe {
|
||||
(cache.data_ptr() as *mut u8)
|
||||
.add(((h * self.max_seq_len) + self.current_len) * head_bytes)
|
||||
};
|
||||
let s = unsafe { (src.data_ptr() as *const u8).add(h * head_bytes) };
|
||||
d2d(dst, s, head_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Map a draft-vocab token id to the full target-vocab id via d2t.
|
||||
pub fn map_draft_to_target(&self, draft_id: u32) -> u32 {
|
||||
(draft_id as i64 + self.d2t[draft_id as usize]) as u32
|
||||
|
||||
Reference in New Issue
Block a user