eagle3: proper residual chain + stateful KV cache
Two fixes to bring EAGLE3 forward in line with vllm's llama_eagle3.py
reference:
1. Residual chain: previously the residual added into post_attention_layernorm
was the token embedding (wrong). Reference uses _norm_after_residual:
residual = fused_h (pre-norm)
hidden_states = hidden_norm(fused_h)
Then post_attention_layernorm is a fused add_rmsnorm(attn_out, residual),
and the final norm is another add_rmsnorm(mlp_out, residual_after_attn).
Neither residual carries the embedding — both carry fused_h forward.
2. KV cache: previously the attention was approximated as "output = V"
because seq_len=1 (no cache), effectively giving EAGLE no history.
Add a real per-Eagle3Head KV cache (1 layer × [1, num_kv_heads,
max_seq_len, head_dim] BF16) that grows as we call step(). Use the
existing decode_attention kernel with a fresh contiguous slice of the
cache each step. reset() clears current_len for a new sequence.
Result on 10 prompts × 32 tokens (γ=1, no batched verify yet):
matched=true across all prompts
acceptance_rate = 20.0% (was 4.7% before residual fix, 1.3% originally)
- Prompt 00 "The capital of France is": 60% (18/30) — best case
- Other prompts: 10-25% — matches EAGLE paper's observation that
structured/factual prompts get higher acceptance
Sanity check (check-eagle3) on Paris prompt now shows:
EAGLE top-5 pairing A: "." / " is" / "," / " Paris" / ".\n"
MATCH: EAGLE agrees with target on next token.
speedup_e2e still 0.95x because γ=1 does 1 target decode per token
regardless of acceptance. Real speedup requires γ≥2 with a single
batched target-verify covering all γ draft tokens; that's the next step.
This commit is contained in:
@@ -109,7 +109,7 @@ fn main() {
|
|||||||
xserv_cuda::allocator::cached_trim();
|
xserv_cuda::allocator::cached_trim();
|
||||||
|
|
||||||
eprintln!("Loading EAGLE3 head...");
|
eprintln!("Loading EAGLE3 head...");
|
||||||
let eagle = Eagle3Head::load(&eagle_dir, device);
|
let mut eagle = Eagle3Head::load(&eagle_dir, device);
|
||||||
xserv_cuda::allocator::cached_trim();
|
xserv_cuda::allocator::cached_trim();
|
||||||
|
|
||||||
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
|
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
|
||||||
@@ -151,7 +151,7 @@ fn main() {
|
|||||||
let mut target_cache = new_cache(&target_config, max_seq_len, device);
|
let mut target_cache = new_cache(&target_config, max_seq_len, device);
|
||||||
let spec = run_eagle_gamma1(
|
let spec = run_eagle_gamma1(
|
||||||
&target,
|
&target,
|
||||||
&eagle,
|
&mut eagle,
|
||||||
&embed_tokens,
|
&embed_tokens,
|
||||||
&mut target_cache,
|
&mut target_cache,
|
||||||
&tokenizer,
|
&tokenizer,
|
||||||
@@ -279,7 +279,7 @@ fn run_baseline(
|
|||||||
/// For γ=1 we simply track acceptance rate for informational purposes.
|
/// For γ=1 we simply track acceptance rate for informational purposes.
|
||||||
fn run_eagle_gamma1(
|
fn run_eagle_gamma1(
|
||||||
target: &Qwen3,
|
target: &Qwen3,
|
||||||
eagle: &Eagle3Head,
|
eagle: &mut Eagle3Head,
|
||||||
embed_tokens: &Tensor,
|
embed_tokens: &Tensor,
|
||||||
cache: &mut PagedKVCache,
|
cache: &mut PagedKVCache,
|
||||||
tokenizer: &Tokenizer,
|
tokenizer: &Tokenizer,
|
||||||
@@ -288,6 +288,7 @@ fn run_eagle_gamma1(
|
|||||||
) -> RunStats {
|
) -> RunStats {
|
||||||
let slot = 0;
|
let slot = 0;
|
||||||
cache.register_sequence(slot).unwrap();
|
cache.register_sequence(slot).unwrap();
|
||||||
|
eagle.reset();
|
||||||
let t0 = Instant::now();
|
let t0 = Instant::now();
|
||||||
|
|
||||||
// Prefill target — we don't have hidden state hooks from prefill in this
|
// Prefill target — we don't have hidden state hooks from prefill in this
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ fn main() {
|
|||||||
xserv_cuda::allocator::cached_trim();
|
xserv_cuda::allocator::cached_trim();
|
||||||
|
|
||||||
eprintln!("Loading EAGLE3 head from {}", eagle_dir.display());
|
eprintln!("Loading EAGLE3 head from {}", eagle_dir.display());
|
||||||
let eagle = Eagle3Head::load(&eagle_dir, device);
|
let mut eagle = Eagle3Head::load(&eagle_dir, device);
|
||||||
xserv_cuda::allocator::cached_trim();
|
xserv_cuda::allocator::cached_trim();
|
||||||
|
|
||||||
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
|
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
|
||||||
@@ -103,6 +103,7 @@ fn main() {
|
|||||||
// Ask EAGLE what it thinks the NEXT token is (given target_first as prev_token
|
// Ask EAGLE what it thinks the NEXT token is (given target_first as prev_token
|
||||||
// and the hidden states from the position where target_first lives).
|
// and the hidden states from the position where target_first lives).
|
||||||
// EAGLE should predict target_next (or close to it) to be useful.
|
// EAGLE should predict target_next (or close to it) to be useful.
|
||||||
|
eagle.reset();
|
||||||
let (eagle_pred, eagle_logits) = eagle.step(&hooks, embed_tokens, target_first, pos);
|
let (eagle_pred, eagle_logits) = eagle.step(&hooks, embed_tokens, target_first, pos);
|
||||||
let eagle_pred_text = tokenizer.decode(&[eagle_pred]);
|
let eagle_pred_text = tokenizer.decode(&[eagle_pred]);
|
||||||
println!(
|
println!(
|
||||||
@@ -129,6 +130,7 @@ fn main() {
|
|||||||
|
|
||||||
// Alternative pairing B: pair hooks with target_next (the token those hooks produced
|
// Alternative pairing B: pair hooks with target_next (the token those hooks produced
|
||||||
// via lm_head), predict token after target_next. Position advances by 1.
|
// via lm_head), predict token after target_next. Position advances by 1.
|
||||||
|
eagle.reset();
|
||||||
let (eagle_pred_b, eagle_logits_b) = eagle.step(&hooks, embed_tokens, target_next, pos + 1);
|
let (eagle_pred_b, eagle_logits_b) = eagle.step(&hooks, embed_tokens, target_next, pos + 1);
|
||||||
let eagle_pred_b_text = tokenizer.decode(&[eagle_pred_b]);
|
let eagle_pred_b_text = tokenizer.decode(&[eagle_pred_b]);
|
||||||
println!(
|
println!(
|
||||||
|
|||||||
@@ -47,7 +47,14 @@ pub struct Eagle3Head {
|
|||||||
num_heads: usize,
|
num_heads: usize,
|
||||||
num_kv_heads: usize,
|
num_kv_heads: usize,
|
||||||
head_dim: usize,
|
head_dim: usize,
|
||||||
|
max_seq_len: usize,
|
||||||
rope_cache: RopeCache,
|
rope_cache: RopeCache,
|
||||||
|
// Stateful 1-layer KV cache: [1, num_kv_heads, max_seq_len, head_dim] BF16.
|
||||||
|
// We slice `..current_len` for attention. The head is tiny (~64 KB per
|
||||||
|
// 1000 tokens) so pre-allocating max_seq_len wastes negligible memory.
|
||||||
|
k_cache: Tensor,
|
||||||
|
v_cache: Tensor,
|
||||||
|
current_len: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Eagle3Head {
|
impl Eagle3Head {
|
||||||
@@ -100,6 +107,17 @@ impl Eagle3Head {
|
|||||||
|
|
||||||
let rope_cache = RopeCache::new(max_seq_len, head_dim, rope_theta);
|
let rope_cache = RopeCache::new(max_seq_len, head_dim, rope_theta);
|
||||||
|
|
||||||
|
let k_cache = Tensor::zeros(
|
||||||
|
&[1, num_kv_heads, max_seq_len, head_dim],
|
||||||
|
DType::BF16,
|
||||||
|
Device::Cuda(device),
|
||||||
|
);
|
||||||
|
let v_cache = Tensor::zeros(
|
||||||
|
&[1, num_kv_heads, max_seq_len, head_dim],
|
||||||
|
DType::BF16,
|
||||||
|
Device::Cuda(device),
|
||||||
|
);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
fc_wt,
|
fc_wt,
|
||||||
hidden_norm,
|
hidden_norm,
|
||||||
@@ -119,10 +137,19 @@ impl Eagle3Head {
|
|||||||
num_heads,
|
num_heads,
|
||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
head_dim,
|
head_dim,
|
||||||
|
max_seq_len,
|
||||||
rope_cache,
|
rope_cache,
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
current_len: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Reset the internal KV cache for a fresh sequence.
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
self.current_len = 0;
|
||||||
|
}
|
||||||
|
|
||||||
/// One draft step: produce a token in target vocabulary space.
|
/// One draft step: produce a token in target vocabulary space.
|
||||||
///
|
///
|
||||||
/// - `target_hidden`: 3 tensors [1, hidden_size] from target hook layers
|
/// - `target_hidden`: 3 tensors [1, hidden_size] from target hook layers
|
||||||
@@ -132,13 +159,19 @@ impl Eagle3Head {
|
|||||||
///
|
///
|
||||||
/// Returns (draft_token_in_target_vocab, draft_logits_tensor).
|
/// Returns (draft_token_in_target_vocab, draft_logits_tensor).
|
||||||
pub fn step(
|
pub fn step(
|
||||||
&self,
|
&mut self,
|
||||||
target_hidden: &[Tensor; 3],
|
target_hidden: &[Tensor; 3],
|
||||||
embed_table: &Tensor,
|
embed_table: &Tensor,
|
||||||
prev_token: u32,
|
prev_token: u32,
|
||||||
position: usize,
|
position: usize,
|
||||||
) -> (u32, Tensor) {
|
) -> (u32, Tensor) {
|
||||||
let eps = 1e-6f32;
|
let eps = 1e-6f32;
|
||||||
|
assert!(
|
||||||
|
self.current_len < self.max_seq_len,
|
||||||
|
"EAGLE KV cache overflow: {} >= {}",
|
||||||
|
self.current_len,
|
||||||
|
self.max_seq_len
|
||||||
|
);
|
||||||
|
|
||||||
// 1. Fuse target hidden states: concat [h_low, h_mid, h_high] → fc
|
// 1. Fuse target hidden states: concat [h_low, h_mid, h_high] → fc
|
||||||
let h_cat = concat_hidden(target_hidden);
|
let h_cat = concat_hidden(target_hidden);
|
||||||
@@ -147,15 +180,16 @@ impl Eagle3Head {
|
|||||||
// 2. Embed previous token (shared with target)
|
// 2. Embed previous token (shared with target)
|
||||||
let emb = embedding(embed_table, &[prev_token]); // [1, hidden]
|
let emb = embedding(embed_table, &[prev_token]); // [1, hidden]
|
||||||
|
|
||||||
// 3. Concat normalized: [norm(emb), norm(fused_h)] → [1, 2*hidden]
|
// 3. Norm both, concat, remember residual = fused_h (pre-norm).
|
||||||
|
let residual = fused_h.clone();
|
||||||
let emb_normed = rmsnorm(&emb, &self.input_layernorm, eps);
|
let emb_normed = rmsnorm(&emb, &self.input_layernorm, eps);
|
||||||
let h_normed = rmsnorm(&fused_h, &self.hidden_norm, eps);
|
let h_normed = rmsnorm(&fused_h, &self.hidden_norm, eps);
|
||||||
let attn_in = concat_last_dim(&emb_normed, &h_normed); // [1, 8192]
|
let attn_in = concat_last_dim(&emb_normed, &h_normed); // [1, 2*hidden]
|
||||||
|
|
||||||
// 4. Self-attention (no KV cache for simplicity in v0 — single query)
|
// 4. Q/K/V projection then RoPE (position from caller).
|
||||||
let q = matmul_2d(&attn_in, &self.q_proj_wt); // [1, num_heads*head_dim]
|
let q = matmul_2d(&attn_in, &self.q_proj_wt);
|
||||||
let k = matmul_2d(&attn_in, &self.k_proj_wt); // [1, num_kv*head_dim]
|
let k = matmul_2d(&attn_in, &self.k_proj_wt);
|
||||||
let v = matmul_2d(&attn_in, &self.v_proj_wt); // [1, num_kv*head_dim]
|
let v = matmul_2d(&attn_in, &self.v_proj_wt);
|
||||||
|
|
||||||
let q_3d = q.reshape(&[1, self.num_heads, self.head_dim]);
|
let q_3d = q.reshape(&[1, self.num_heads, self.head_dim]);
|
||||||
let k_3d = k.reshape(&[1, self.num_kv_heads, self.head_dim]);
|
let k_3d = k.reshape(&[1, self.num_kv_heads, self.head_dim]);
|
||||||
@@ -163,39 +197,58 @@ impl Eagle3Head {
|
|||||||
rope_inplace(&q_3d, &self.rope_cache, &positions);
|
rope_inplace(&q_3d, &self.rope_cache, &positions);
|
||||||
rope_inplace(&k_3d, &self.rope_cache, &positions);
|
rope_inplace(&k_3d, &self.rope_cache, &positions);
|
||||||
|
|
||||||
// Single-token attention: Q·K^T / sqrt(d) → softmax → V
|
// 5. Append new K/V to the internal cache at slot `current_len`, then
|
||||||
// With seq_len=1, attention is trivial: output = V (weight=1.0)
|
// build a contiguous view [1, num_kv_heads, current_len+1, head_dim].
|
||||||
let attn_out = v.reshape(&[1, self.num_kv_heads, self.head_dim]);
|
let v_3d = v.reshape(&[1, self.num_kv_heads, self.head_dim]);
|
||||||
let attn_out = if self.num_heads != self.num_kv_heads {
|
self.append_to_kv_cache(&k_3d, &v_3d);
|
||||||
repeat_kv_for_single_token(&attn_out, self.num_heads / self.num_kv_heads)
|
self.current_len += 1;
|
||||||
} else {
|
let kv_len = self.current_len;
|
||||||
attn_out
|
let k_view = self.k_cache.narrow(2, 0, kv_len).contiguous();
|
||||||
};
|
let v_view = self.v_cache.narrow(2, 0, kv_len).contiguous();
|
||||||
|
|
||||||
|
// 6. Attention: q [1, num_q_heads, 1, head_dim] × k/v [1, num_kv_heads, kv_len, head_dim]
|
||||||
|
let q_4d = q_3d.reshape(&[1, self.num_heads, 1, self.head_dim]);
|
||||||
|
let attn_out = decode_attention(&q_4d, &k_view, &v_view);
|
||||||
|
|
||||||
|
// 7. Merge heads and o_proj.
|
||||||
let attn_merged = attn_out.reshape(&[1, self.num_heads * self.head_dim]);
|
let attn_merged = attn_out.reshape(&[1, self.num_heads * self.head_dim]);
|
||||||
let attn_proj = matmul_2d(&attn_merged, &self.o_proj_wt); // [1, hidden]
|
let attn_proj = matmul_2d(&attn_merged, &self.o_proj_wt);
|
||||||
|
|
||||||
// Residual from embedding
|
// 8. Post-attn fused add_rmsnorm.
|
||||||
let x = add(&attn_proj, &emb);
|
let (mlp_in, residual) =
|
||||||
|
add_rmsnorm(&attn_proj, &residual, &self.post_attention_layernorm, eps);
|
||||||
|
|
||||||
// 5. MLP
|
// 9. MLP.
|
||||||
let normed = rmsnorm(&x, &self.post_attention_layernorm, eps);
|
let gate = matmul_2d(&mlp_in, &self.gate_proj_wt);
|
||||||
let gate = matmul_2d(&normed, &self.gate_proj_wt);
|
let up = matmul_2d(&mlp_in, &self.up_proj_wt);
|
||||||
let up = matmul_2d(&normed, &self.up_proj_wt);
|
let hidden = silu_mul(&gate, &up);
|
||||||
let mlp_out = silu_mul(&gate, &up);
|
let down = matmul_2d(&hidden, &self.down_proj_wt);
|
||||||
let down = matmul_2d(&mlp_out, &self.down_proj_wt);
|
|
||||||
let x = add(&x, &down);
|
|
||||||
|
|
||||||
// 6. Final norm + lm_head
|
// 10. Final fused add_rmsnorm → lm_head.
|
||||||
let x = rmsnorm(&x, &self.norm, eps);
|
let (x, _) = add_rmsnorm(&down, &residual, &self.norm, eps);
|
||||||
let logits = matmul_2d(&x, &self.lm_head_wt); // [1, 32000]
|
let logits = matmul_2d(&x, &self.lm_head_wt);
|
||||||
|
|
||||||
// 7. Argmax in draft vocab → map to target vocab
|
|
||||||
let draft_id = argmax_bf16_single(&logits);
|
let draft_id = argmax_bf16_single(&logits);
|
||||||
let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32;
|
let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32;
|
||||||
|
|
||||||
(target_id, logits)
|
(target_id, logits)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Write new K/V rows (shape [1, num_kv_heads, head_dim]) at position
|
||||||
|
/// `current_len` inside the [1, num_kv_heads, max_seq_len, head_dim] cache.
|
||||||
|
fn append_to_kv_cache(&mut self, new_k: &Tensor, new_v: &Tensor) {
|
||||||
|
let head_bytes = self.head_dim * self.k_cache.dtype().size_bytes();
|
||||||
|
for h in 0..self.num_kv_heads {
|
||||||
|
for (cache, src) in [(&self.k_cache, new_k), (&self.v_cache, new_v)] {
|
||||||
|
let dst = unsafe {
|
||||||
|
(cache.data_ptr() as *mut u8)
|
||||||
|
.add(((h * self.max_seq_len) + self.current_len) * head_bytes)
|
||||||
|
};
|
||||||
|
let s = unsafe { (src.data_ptr() as *const u8).add(h * head_bytes) };
|
||||||
|
d2d(dst, s, head_bytes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Map a draft-vocab token id to the full target-vocab id via d2t.
|
/// Map a draft-vocab token id to the full target-vocab id via d2t.
|
||||||
pub fn map_draft_to_target(&self, draft_id: u32) -> u32 {
|
pub fn map_draft_to_target(&self, draft_id: u32) -> u32 {
|
||||||
(draft_id as i64 + self.d2t[draft_id as usize]) as u32
|
(draft_id as i64 + self.d2t[draft_id as usize]) as u32
|
||||||
|
|||||||
Reference in New Issue
Block a user