diff --git a/crates/xserv-model/src/bin/bench-eagle3.rs b/crates/xserv-model/src/bin/bench-eagle3.rs
index 9dbc46d..2ee2a3d 100644
--- a/crates/xserv-model/src/bin/bench-eagle3.rs
+++ b/crates/xserv-model/src/bin/bench-eagle3.rs
@@ -109,7 +109,7 @@ fn main() {
     xserv_cuda::allocator::cached_trim();
 
     eprintln!("Loading EAGLE3 head...");
-    let eagle = Eagle3Head::load(&eagle_dir, device);
+    let mut eagle = Eagle3Head::load(&eagle_dir, device);
     xserv_cuda::allocator::cached_trim();
 
     let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
@@ -151,7 +151,7 @@ fn main() {
     let mut target_cache = new_cache(&target_config, max_seq_len, device);
     let spec = run_eagle_gamma1(
         &target,
-        &eagle,
+        &mut eagle,
         &embed_tokens,
         &mut target_cache,
         &tokenizer,
@@ -279,7 +279,7 @@ fn run_baseline(
 /// For γ=1 we simply track acceptance rate for informational purposes.
 fn run_eagle_gamma1(
     target: &Qwen3,
-    eagle: &Eagle3Head,
+    eagle: &mut Eagle3Head,
     embed_tokens: &Tensor,
     cache: &mut PagedKVCache,
     tokenizer: &Tokenizer,
@@ -288,6 +288,7 @@ fn run_eagle_gamma1(
 ) -> RunStats {
     let slot = 0;
     cache.register_sequence(slot).unwrap();
+    eagle.reset();
 
     let t0 = Instant::now();
     // Prefill target — we don't have hidden state hooks from prefill in this
diff --git a/crates/xserv-model/src/bin/check-eagle3.rs b/crates/xserv-model/src/bin/check-eagle3.rs
index 8e48afe..3d33230 100644
--- a/crates/xserv-model/src/bin/check-eagle3.rs
+++ b/crates/xserv-model/src/bin/check-eagle3.rs
@@ -38,7 +38,7 @@ fn main() {
     xserv_cuda::allocator::cached_trim();
 
     eprintln!("Loading EAGLE3 head from {}", eagle_dir.display());
-    let eagle = Eagle3Head::load(&eagle_dir, device);
+    let mut eagle = Eagle3Head::load(&eagle_dir, device);
     xserv_cuda::allocator::cached_trim();
 
     let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
@@ -103,6 +103,7 @@ fn main() {
     // Ask EAGLE what it thinks the NEXT token is (given target_first as prev_token
     // and the hidden states from the position where target_first lives).
     // EAGLE should predict target_next (or close to it) to be useful.
+    eagle.reset();
     let (eagle_pred, eagle_logits) = eagle.step(&hooks, embed_tokens, target_first, pos);
     let eagle_pred_text = tokenizer.decode(&[eagle_pred]);
     println!(
@@ -129,6 +130,7 @@ fn main() {
 
     // Alternative pairing B: pair hooks with target_next (the token those hooks produced
     // via lm_head), predict token after target_next. Position advances by 1.
+    eagle.reset();
     let (eagle_pred_b, eagle_logits_b) = eagle.step(&hooks, embed_tokens, target_next, pos + 1);
     let eagle_pred_b_text = tokenizer.decode(&[eagle_pred_b]);
     println!(
diff --git a/crates/xserv-model/src/eagle3.rs b/crates/xserv-model/src/eagle3.rs
index 83bfa57..e00e57a 100644
--- a/crates/xserv-model/src/eagle3.rs
+++ b/crates/xserv-model/src/eagle3.rs
@@ -47,7 +47,14 @@ pub struct Eagle3Head {
     num_heads: usize,
     num_kv_heads: usize,
     head_dim: usize,
+    max_seq_len: usize,
     rope_cache: RopeCache,
+    // Stateful 1-layer KV cache: [1, num_kv_heads, max_seq_len, head_dim] BF16.
+    // Pre-allocated once in `load`; `step` appends one row per call and
+    // attention reads only rows `..current_len` along dim 2.
+    k_cache: Tensor,
+    v_cache: Tensor,
+    current_len: usize,
 }
 
 impl Eagle3Head {
@@ -100,6 +107,17 @@ impl Eagle3Head {
 
         let rope_cache = RopeCache::new(max_seq_len, head_dim, rope_theta);
 
+        let k_cache = Tensor::zeros(
+            &[1, num_kv_heads, max_seq_len, head_dim],
+            DType::BF16,
+            Device::Cuda(device),
+        );
+        let v_cache = Tensor::zeros(
+            &[1, num_kv_heads, max_seq_len, head_dim],
+            DType::BF16,
+            Device::Cuda(device),
+        );
+
         Self {
             fc_wt,
             hidden_norm,
@@ -119,10 +137,22 @@ impl Eagle3Head {
             num_heads,
             num_kv_heads,
             head_dim,
+            max_seq_len,
             rope_cache,
+            k_cache,
+            v_cache,
+            current_len: 0,
         }
     }
 
+    /// Reset the internal KV cache for a fresh sequence.
+    ///
+    /// Only `current_len` is rewound; cached rows are left in place and are
+    /// overwritten by later `step` calls before attention reads them again.
+    pub fn reset(&mut self) {
+        self.current_len = 0;
+    }
+
     /// One draft step: produce a token in target vocabulary space.
     ///
     /// - `target_hidden`: 3 tensors [1, hidden_size] from target hook layers
@@ -132,13 +162,19 @@ impl Eagle3Head {
     ///
     /// Returns (draft_token_in_target_vocab, draft_logits_tensor).
     pub fn step(
-        &self,
+        &mut self,
         target_hidden: &[Tensor; 3],
         embed_table: &Tensor,
         prev_token: u32,
         position: usize,
     ) -> (u32, Tensor) {
         let eps = 1e-6f32;
+        assert!(
+            self.current_len < self.max_seq_len,
+            "EAGLE KV cache overflow: {} >= {}",
+            self.current_len,
+            self.max_seq_len
+        );
 
         // 1. Fuse target hidden states: concat [h_low, h_mid, h_high] → fc
         let h_cat = concat_hidden(target_hidden);
@@ -147,15 +183,16 @@ impl Eagle3Head {
         // 2. Embed previous token (shared with target)
         let emb = embedding(embed_table, &[prev_token]); // [1, hidden]
 
-        // 3. Concat normalized: [norm(emb), norm(fused_h)] → [1, 2*hidden]
+        // 3. Norm both, concat, remember residual = fused_h (pre-norm).
+        let residual = fused_h.clone();
         let emb_normed = rmsnorm(&emb, &self.input_layernorm, eps);
         let h_normed = rmsnorm(&fused_h, &self.hidden_norm, eps);
-        let attn_in = concat_last_dim(&emb_normed, &h_normed); // [1, 8192]
+        let attn_in = concat_last_dim(&emb_normed, &h_normed); // [1, 2*hidden]
 
-        // 4. Self-attention (no KV cache for simplicity in v0 — single query)
-        let q = matmul_2d(&attn_in, &self.q_proj_wt); // [1, num_heads*head_dim]
-        let k = matmul_2d(&attn_in, &self.k_proj_wt); // [1, num_kv*head_dim]
-        let v = matmul_2d(&attn_in, &self.v_proj_wt); // [1, num_kv*head_dim]
+        // 4. Q/K/V projection then RoPE (position from caller).
+        let q = matmul_2d(&attn_in, &self.q_proj_wt);
+        let k = matmul_2d(&attn_in, &self.k_proj_wt);
+        let v = matmul_2d(&attn_in, &self.v_proj_wt);
 
         let q_3d = q.reshape(&[1, self.num_heads, self.head_dim]);
         let k_3d = k.reshape(&[1, self.num_kv_heads, self.head_dim]);
@@ -163,39 +200,78 @@ impl Eagle3Head {
         rope_inplace(&q_3d, &self.rope_cache, &positions);
         rope_inplace(&k_3d, &self.rope_cache, &positions);
 
-        // Single-token attention: Q·K^T / sqrt(d) → softmax → V
-        // With seq_len=1, attention is trivial: output = V (weight=1.0)
-        let attn_out = v.reshape(&[1, self.num_kv_heads, self.head_dim]);
-        let attn_out = if self.num_heads != self.num_kv_heads {
-            repeat_kv_for_single_token(&attn_out, self.num_heads / self.num_kv_heads)
-        } else {
-            attn_out
-        };
+        // 5. Append new K/V to the internal cache at slot `current_len`, then
+        // build a contiguous view [1, num_kv_heads, current_len+1, head_dim].
+        let v_3d = v.reshape(&[1, self.num_kv_heads, self.head_dim]);
+        self.append_to_kv_cache(&k_3d, &v_3d);
+        self.current_len += 1;
+        let kv_len = self.current_len;
+        let k_view = self.k_cache.narrow(2, 0, kv_len).contiguous();
+        let v_view = self.v_cache.narrow(2, 0, kv_len).contiguous();
+
+        // 6. Attention: q [1, num_q_heads, 1, head_dim] × k/v [1, num_kv_heads, kv_len, head_dim]
+        let q_4d = q_3d.reshape(&[1, self.num_heads, 1, self.head_dim]);
+        let attn_out = decode_attention(&q_4d, &k_view, &v_view);
+
+        // 7. Merge heads and o_proj.
         let attn_merged = attn_out.reshape(&[1, self.num_heads * self.head_dim]);
-        let attn_proj = matmul_2d(&attn_merged, &self.o_proj_wt); // [1, hidden]
+        let attn_proj = matmul_2d(&attn_merged, &self.o_proj_wt);
 
-        // Residual from embedding
-        let x = add(&attn_proj, &emb);
+        // 8. Post-attn fused add_rmsnorm.
+        let (mlp_in, residual) =
+            add_rmsnorm(&attn_proj, &residual, &self.post_attention_layernorm, eps);
 
-        // 5. MLP
-        let normed = rmsnorm(&x, &self.post_attention_layernorm, eps);
-        let gate = matmul_2d(&normed, &self.gate_proj_wt);
-        let up = matmul_2d(&normed, &self.up_proj_wt);
-        let mlp_out = silu_mul(&gate, &up);
-        let down = matmul_2d(&mlp_out, &self.down_proj_wt);
-        let x = add(&x, &down);
+        // 9. MLP.
+        let gate = matmul_2d(&mlp_in, &self.gate_proj_wt);
+        let up = matmul_2d(&mlp_in, &self.up_proj_wt);
+        let hidden = silu_mul(&gate, &up);
+        let down = matmul_2d(&hidden, &self.down_proj_wt);
 
-        // 6. Final norm + lm_head
-        let x = rmsnorm(&x, &self.norm, eps);
-        let logits = matmul_2d(&x, &self.lm_head_wt); // [1, 32000]
+        // 10. Final fused add_rmsnorm → lm_head.
+        let (x, _) = add_rmsnorm(&down, &residual, &self.norm, eps);
+        let logits = matmul_2d(&x, &self.lm_head_wt);
 
-        // 7. Argmax in draft vocab → map to target vocab
         let draft_id = argmax_bf16_single(&logits);
         let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32;
-
         (target_id, logits)
     }
 
+    /// Write new K/V rows (shape [1, num_kv_heads, head_dim]) at position
+    /// `current_len` inside the [1, num_kv_heads, max_seq_len, head_dim] cache.
+    ///
+    /// Does not advance `current_len`; `step` increments it after this returns.
+    ///
+    /// # Panics
+    /// Panics if `current_len >= max_seq_len` (no free row left).
+    fn append_to_kv_cache(&mut self, new_k: &Tensor, new_v: &Tensor) {
+        // Check the bound here as well, so the raw-pointer offsets below
+        // never depend on a caller having validated `current_len`.
+        assert!(
+            self.current_len < self.max_seq_len,
+            "EAGLE KV cache overflow: {} >= {}",
+            self.current_len,
+            self.max_seq_len
+        );
+        let head_bytes = self.head_dim * self.k_cache.dtype().size_bytes();
+        for h in 0..self.num_kv_heads {
+            // Byte offsets of row `current_len` of head `h` in the cache and of
+            // head `h` in the new row; identical for K and V.
+            let dst_off = (h * self.max_seq_len + self.current_len) * head_bytes;
+            let src_off = h * head_bytes;
+            for (cache, src) in [(&self.k_cache, new_k), (&self.v_cache, new_v)] {
+                // SAFETY: h < num_kv_heads and current_len < max_seq_len (asserted
+                // above), so dst_off + head_bytes stays inside the cache buffer,
+                // assuming `Tensor::zeros` in `load` returned it contiguous.
+                let dst = unsafe { (cache.data_ptr() as *mut u8).add(dst_off) };
+                // SAFETY: relies on `src` being a contiguous
+                // [1, num_kv_heads, head_dim] tensor of the cache dtype (the shape
+                // `step` reshapes to); src_off + head_bytes is then in bounds.
+                let s = unsafe { (src.data_ptr() as *const u8).add(src_off) };
+                d2d(dst, s, head_bytes);
+            }
+        }
+    }
+
     /// Map a draft-vocab token id to the full target-vocab id via d2t.
     pub fn map_draft_to_target(&self, draft_id: u32) -> u32 {
         (draft_id as i64 + self.d2t[draft_id as usize]) as u32