diff --git a/crates/xserv-model/src/bin/bench-eagle3.rs b/crates/xserv-model/src/bin/bench-eagle3.rs
index 2ee2a3d..a07db72 100644
--- a/crates/xserv-model/src/bin/bench-eagle3.rs
+++ b/crates/xserv-model/src/bin/bench-eagle3.rs
@@ -93,6 +93,7 @@ fn main() {
     let prompt_count = arg_usize(&args, "--prompts", PROMPTS.len()).min(PROMPTS.len());
     let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN);
     let device = arg_usize(&args, "--device", 0) as u32;
+    let gamma = arg_usize(&args, "--gamma", 2).max(1);
 
     xserv_cuda::device::set_device(device).unwrap();
     let info = xserv_cuda::device::device_info(device).unwrap();
@@ -122,7 +123,9 @@ fn main() {
         let _ = run_baseline(&target, &mut cache, &tokenizer, &ids, 4);
         drop(cache);
     }
-    eprintln!("Warmup done. Running {prompt_count} prompts, gen_tokens={gen_tokens}\n");
+    eprintln!(
+        "Warmup done. Running {prompt_count} prompts, gen_tokens={gen_tokens}, gamma={gamma}\n"
+    );
 
     let mut baseline_total_s = 0.0f64;
     let mut baseline_tokens = 0usize;
@@ -147,17 +150,30 @@ fn main() {
         baseline_tokens += baseline.ids.len();
         drop(baseline_cache);
 
-        // Speculative with EAGLE γ=1.
+        // Speculative with EAGLE, γ from CLI.
         let mut target_cache = new_cache(&target_config, max_seq_len, device);
-        let spec = run_eagle_gamma1(
-            &target,
-            &mut eagle,
-            &embed_tokens,
-            &mut target_cache,
-            &tokenizer,
-            &ids,
-            gen_tokens,
-        );
+        let spec = if gamma == 1 {
+            run_eagle_gamma1(
+                &target,
+                &mut eagle,
+                &embed_tokens,
+                &mut target_cache,
+                &tokenizer,
+                &ids,
+                gen_tokens,
+            )
+        } else {
+            run_eagle_gamma_multi(
+                &target,
+                &mut eagle,
+                &embed_tokens,
+                &mut target_cache,
+                &tokenizer,
+                &ids,
+                gen_tokens,
+                gamma,
+            )
+        };
         spec_total_s += spec.total_s;
         spec_tokens += spec.ids.len();
         spec_accepted += spec.accepted;
@@ -336,6 +352,145 @@ fn run_eagle_gamma1(
     }
 }
 
+/// γ≥2 speculative decoding: recursive EAGLE drafting + one batched target verify.
+///
+/// State entering every round:
+/// * `prev_token` is the last committed token. Its K/V is NOT in the target
+///   cache yet, so `round_pos = cache.seq_len(slot)` is the position it takes.
+/// * `seed_hooks` are the target hidden states at position `round_pos - 1`,
+///   i.e. the row whose logits produced `prev_token`.
+///
+/// Per round:
+/// 1. Draft γ tokens with `step_with_aux` followed by `step_recursive`.
+/// 2. Verify: one target forward over `[prev_token, d0..d_{γ-1}]` (γ+1 rows).
+///    `verify_argmax[i]` is the target's greedy token for position
+///    `round_pos + 1 + i`; the longest prefix with `d[i] == verify_argmax[i]`
+///    is accepted.
+/// 3. Correction: `verify_argmax[accepted]` replaces the first rejected draft
+///    (or is a bonus token when all γ drafts matched). It becomes the next
+///    `prev_token`; its K/V is written by the next round's verify.
+/// 4. Seed the next round with row `accepted` of the verify hidden states.
+#[allow(clippy::too_many_arguments)]
+fn run_eagle_gamma_multi(
+    target: &Qwen3,
+    eagle: &mut Eagle3Head,
+    embed_tokens: &Tensor,
+    cache: &mut PagedKVCache,
+    tokenizer: &Tokenizer,
+    prompt_ids: &[u32],
+    gen_tokens: usize,
+    gamma: usize,
+) -> RunStats {
+    use xserv_model::eagle3::EAGLE_HOOK_LAYERS;
+    let slot = 0;
+    cache.register_sequence(slot).unwrap();
+    eagle.reset();
+    let t0 = Instant::now();
+
+    // Prefill the prompt; its last logits row gives the first generated token.
+    let prefill_logits = target.forward_prefill_paged(prompt_ids, slot, cache);
+    let first_token = last_argmax(&prefill_logits);
+    let mut generated = vec![first_token];
+
+    // Seed: decoding `first_token` yields both `prev_token` and the hidden
+    // states that produced it. `prev_token` itself is not cached here; the
+    // first verify pass writes its K/V.
+    let (logits, mut seed_hooks) = target_decode_with_hidden(target, first_token, cache, slot);
+    let mut prev_token = last_argmax(&logits);
+    generated.push(prev_token);
+    let mut target_steps = 2; // 1 prefill + 1 decode seed
+    let mut accepted_total = 0usize;
+    let mut proposed_total = 0usize;
+
+    while generated.len() < gen_tokens && !tokenizer.is_eos(prev_token) {
+        let round_pos = cache.seq_len(slot); // position `prev_token` will occupy
+        let remaining = gen_tokens - generated.len();
+        let round_gamma = gamma.min(remaining);
+
+        // Draft γ tokens recursively.
+        // NOTE(review): every draft step appends a row to the EAGLE head's own
+        // K/V cache and nothing here trims the rows of rejected drafts — confirm
+        // whether Eagle3Head needs a rollback between rounds.
+        let mut drafts: Vec<u32> = Vec::with_capacity(round_gamma);
+        let (d0, _, aux0) = eagle.step_with_aux(&seed_hooks, embed_tokens, prev_token, round_pos);
+        drafts.push(d0);
+        let mut prev_aux = aux0;
+        let mut prev_draft = d0;
+        for k in 1..round_gamma {
+            let (dk, _, auxk) =
+                eagle.step_recursive(prev_aux, embed_tokens, prev_draft, round_pos + k);
+            drafts.push(dk);
+            prev_aux = auxk;
+            prev_draft = dk;
+        }
+        proposed_total += round_gamma;
+
+        // Verify [prev_token, d0..d_{γ-1}] (γ+1 rows) in one target forward.
+        // Row i predicts position round_pos+1+i, so d[i] is checked against
+        // verify_argmax[i]; the extra last row is the bonus token for the
+        // all-accepted case.
+        let mut verify_input: Vec<u32> = Vec::with_capacity(round_gamma + 1);
+        verify_input.push(prev_token);
+        verify_input.extend_from_slice(&drafts);
+        let (verify_logits, verify_hooks) = target
+            .forward_verify_paged_decode_attention_with_hidden(
+                &verify_input,
+                slot,
+                cache,
+                &EAGLE_HOOK_LAYERS,
+            );
+        target_steps += 1;
+
+        // Longest matching prefix: d[i] accepted iff d[i] == verify_argmax[i].
+        let verify_argmax = argmax_rows(&verify_logits);
+        let accepted = drafts
+            .iter()
+            .zip(&verify_argmax)
+            .take_while(|(d, v)| d == v)
+            .count();
+        accepted_total += accepted;
+
+        // Keep `prev_token` plus the accepted drafts in the cache and drop the
+        // K/V rows that verify wrote for rejected drafts.
+        cache
+            .truncate_sequence(slot, round_pos + 1 + accepted)
+            .expect("truncate to accepted");
+        for &d in &drafts[..accepted] {
+            generated.push(d);
+            if tokenizer.is_eos(d) {
+                break;
+            }
+        }
+        let last = *generated.last().expect("generated holds at least two tokens");
+        if generated.len() >= gen_tokens || tokenizer.is_eos(last) {
+            break;
+        }
+
+        // The row after the accepted prefix is the target's own next token.
+        // The hidden states of that same row produced it, so they seed the
+        // next round; no extra target decode is needed.
+        let correction = verify_argmax[accepted];
+        generated.push(correction);
+        prev_token = correction;
+        seed_hooks = verify_hooks.map(|h| h.narrow(0, accepted, 1).contiguous());
+    }
+    sync_device();
+    let total_s = t0.elapsed().as_secs_f64();
+    cache.free_sequence(slot);
+    RunStats {
+        ids: generated,
+        total_s,
+        target_steps,
+        accepted: accepted_total,
+        proposed: proposed_total,
+    }
+}
+
+/// Row-wise argmax of a BF16 logits tensor, returned as host-side token ids.
+fn argmax_rows(logits: &Tensor) -> Vec<u32> {
+    xserv_kernels::argmax_bf16_to_host(logits)
+}
+
 fn target_decode_with_hidden(
     target: &Qwen3,
     token: u32,
@@ -381,5 +536,5 @@ fn arg_usize(args: &[String], flag: &str, default: usize) -> usize {
 
 fn new_cache(config: &ModelConfig, max_seq_len: usize, device: u32) -> PagedKVCache {
     let num_blocks = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE + 2;
-    PagedKVCache::new(config, num_blocks, 0, 1, num_blocks, DType::BF16, device)
+    PagedKVCache::new(config, num_blocks, 0, 16, num_blocks, DType::BF16, device)
 }
diff --git a/crates/xserv-model/src/eagle3.rs b/crates/xserv-model/src/eagle3.rs
index e00e57a..0addc71 100644
--- a/crates/xserv-model/src/eagle3.rs
+++ b/crates/xserv-model/src/eagle3.rs
@@ -165,6 +165,44 @@ impl Eagle3Head {
         prev_token: u32,
         position: usize,
     ) -> (u32, Tensor) {
+        let (id, logits, _) = self.step_with_aux(target_hidden, embed_table, prev_token, position);
+        (id, logits)
+    }
+
+    /// Like `step`, but also returns the final hidden state (aux) usable as
+    /// the fused_h for a subsequent recursive draft step via `step_recursive`.
+    pub fn step_with_aux(
+        &mut self,
+        target_hidden: &[Tensor; 3],
+        embed_table: &Tensor,
+        prev_token: u32,
+        position: usize,
+    ) -> (u32, Tensor, Tensor) {
+        // Fuse 3 target hidden states into fused_h via fc.
+        let h_cat = concat_hidden(target_hidden);
+        let fused_h = matmul_2d(&h_cat, &self.fc_wt);
+        self.forward_from_fused(fused_h, embed_table, prev_token, position)
+    }
+
+    /// Recursive draft step: reuses the previous EAGLE step's aux as fused_h,
+    /// bypassing the fc+3-hidden fusion. Used for γ≥2 chained drafts.
+    pub fn step_recursive(
+        &mut self,
+        fused_h: Tensor,
+        embed_table: &Tensor,
+        prev_token: u32,
+        position: usize,
+    ) -> (u32, Tensor, Tensor) {
+        self.forward_from_fused(fused_h, embed_table, prev_token, position)
+    }
+
+    fn forward_from_fused(
+        &mut self,
+        fused_h: Tensor,
+        embed_table: &Tensor,
+        prev_token: u32,
+        position: usize,
+    ) -> (u32, Tensor, Tensor) {
         let eps = 1e-6f32;
         assert!(
             self.current_len < self.max_seq_len,
@@ -173,20 +211,13 @@ impl Eagle3Head {
             self.max_seq_len
         );
 
-        // 1. Fuse target hidden states: concat [h_low, h_mid, h_high] → fc
-        let h_cat = concat_hidden(target_hidden);
-        let fused_h = matmul_2d(&h_cat, &self.fc_wt); // [1, hidden]
-
-        // 2. Embed previous token (shared with target)
-        let emb = embedding(embed_table, &[prev_token]); // [1, hidden]
-
-        // 3. Norm both, concat, remember residual = fused_h (pre-norm).
+        let emb = embedding(embed_table, &[prev_token]);
+        // Norm both inputs and concat; `fused_h` (pre-norm) is kept as the residual.
         let residual = fused_h.clone();
         let emb_normed = rmsnorm(&emb, &self.input_layernorm, eps);
         let h_normed = rmsnorm(&fused_h, &self.hidden_norm, eps);
-        let attn_in = concat_last_dim(&emb_normed, &h_normed); // [1, 2*hidden]
+        let attn_in = concat_last_dim(&emb_normed, &h_normed);
 
-        // 4. Q/K/V projection then RoPE (position from caller).
         let q = matmul_2d(&attn_in, &self.q_proj_wt);
         let k = matmul_2d(&attn_in, &self.k_proj_wt);
         let v = matmul_2d(&attn_in, &self.v_proj_wt);
@@ -197,8 +228,6 @@ impl Eagle3Head {
         rope_inplace(&q_3d, &self.rope_cache, &positions);
         rope_inplace(&k_3d, &self.rope_cache, &positions);
 
-        // 5. Append new K/V to the internal cache at slot `current_len`, then
-        // build a contiguous view [1, num_kv_heads, current_len+1, head_dim].
         let v_3d = v.reshape(&[1, self.num_kv_heads, self.head_dim]);
         self.append_to_kv_cache(&k_3d, &v_3d);
         self.current_len += 1;
@@ -206,31 +235,29 @@ impl Eagle3Head {
         let k_view = self.k_cache.narrow(2, 0, kv_len).contiguous();
         let v_view = self.v_cache.narrow(2, 0, kv_len).contiguous();
 
-        // 6. Attention: q [1, num_q_heads, 1, head_dim] × k/v [1, num_kv_heads, kv_len, head_dim]
         let q_4d = q_3d.reshape(&[1, self.num_heads, 1, self.head_dim]);
         let attn_out = decode_attention(&q_4d, &k_view, &v_view);
 
-        // 7. Merge heads and o_proj.
         let attn_merged = attn_out.reshape(&[1, self.num_heads * self.head_dim]);
         let attn_proj = matmul_2d(&attn_merged, &self.o_proj_wt);
 
-        // 8. Post-attn fused add_rmsnorm.
         let (mlp_in, residual) =
             add_rmsnorm(&attn_proj, &residual, &self.post_attention_layernorm, eps);
 
-        // 9. MLP.
         let gate = matmul_2d(&mlp_in, &self.gate_proj_wt);
         let up = matmul_2d(&mlp_in, &self.up_proj_wt);
         let hidden = silu_mul(&gate, &up);
         let down = matmul_2d(&hidden, &self.down_proj_wt);
 
-        // 10. Final fused add_rmsnorm → lm_head.
-        let (x, _) = add_rmsnorm(&down, &residual, &self.norm, eps);
+        // Final add_rmsnorm → lm_head. The pre-norm sum is returned as the aux:
+        // a recursive step feeds it back in as `fused_h`, which is used as the
+        // residual and normalised again by `hidden_norm`.
+        let (x, pre_norm) = add_rmsnorm(&down, &residual, &self.norm, eps);
         let logits = matmul_2d(&x, &self.lm_head_wt);
         let draft_id = argmax_bf16_single(&logits);
         let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32;
 
-        (target_id, logits)
+        (target_id, logits, pre_norm)
     }
 
     /// Write new K/V rows (shape [1, num_kv_heads, head_dim]) at position
diff --git a/crates/xserv-model/src/qwen3.rs b/crates/xserv-model/src/qwen3.rs
index 1806cb0..4707801 100644
--- a/crates/xserv-model/src/qwen3.rs
+++ b/crates/xserv-model/src/qwen3.rs
@@ -1113,6 +1113,123 @@ impl Qwen3 {
         matmul_batched_gemv(&x, &self.lm_head_t)
     }
 
+    /// Like `forward_verify_paged_decode_attention`, but also captures hidden
+    /// states at 3 layer indices (per position). Returns
+    /// (logits [new_tokens, vocab], hooks [3][new_tokens, hidden]). Used by
+    /// EAGLE3 speculative γ≥2 verify path so we can seed the next round's
+    /// EAGLE draft with target's real hidden states at the accepted position.
+    pub fn forward_verify_paged_decode_attention_with_hidden(
+        &self,
+        token_ids: &[u32],
+        slot: usize,
+        paged_cache: &mut PagedKVCache,
+        hook_layers: &[usize; 3],
+    ) -> (Tensor, [Tensor; 3]) {
+        let new_tokens = token_ids.len();
+        let pos_offset = paged_cache.seq_len(slot);
+        let num_heads = self.local_num_heads;
+        let num_kv_heads = self.local_num_kv_heads;
+        let head_dim = self.config.head_dim();
+        let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
+
+        paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
+        paged_cache.advance_seq_len(slot, new_tokens); // the bench caller truncates rejected rows
+
+        let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
+            .map(|p| p as u32)
+            .collect();
+        let kv_lens: Vec<i32> = (0..new_tokens)
+            .map(|i| (pos_offset + i + 1) as i32)
+            .collect();
+        let slots = vec![slot; new_tokens]; // one batch row per new token, all on `slot`
+        paged_cache.sync_active_batch_with_lens(&slots, &kv_lens);
+        let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
+        let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
+        let max_blocks = paged_cache.max_blocks_per_seq();
+
+        let mut x = embedding(&self.embed_tokens, token_ids);
+        let mut hooks: [Option<Tensor>; 3] = [None, None, None];
+
+        for (layer_idx, layer) in self.layers.iter().enumerate() {
+            let residual = x.clone();
+            let normed = rmsnorm(&x, &layer.input_norm, eps);
+
+            let qkv = matmul_batched_gemv(&normed, &layer.qkv_proj_wt);
+            let q_dim = num_heads * head_dim;
+            let kv_dim = num_kv_heads * head_dim;
+            let q_all = qkv.narrow(1, 0, q_dim);
+            let k_all = qkv.narrow(1, q_dim, kv_dim);
+            let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);
+
+            let q_flat = q_all
+                .contiguous()
+                .reshape(&[new_tokens * num_heads, head_dim]);
+            let k_flat = k_all
+                .contiguous()
+                .reshape(&[new_tokens * num_kv_heads, head_dim]);
+            let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
+            let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
+
+            let q_3d = q_normed.reshape(&[new_tokens, num_heads, head_dim]);
+            let k_3d = k_normed.reshape(&[new_tokens, num_kv_heads, head_dim]);
+            rope_inplace(&q_3d, &self.rope_cache, &positions);
+            rope_inplace(&k_3d, &self.rope_cache, &positions);
+
+            let v_3d = v_all
+                .contiguous()
+                .reshape(&[new_tokens, num_kv_heads, head_dim]);
+            paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, new_tokens);
+
+            let q_decode = q_3d.reshape(&[new_tokens, num_heads, 1, head_dim]);
+            let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
+            let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
+            let attn_out = xserv_kernels::paged_decode_attention(
+                &q_decode,
+                k_pool_ptr,
+                v_pool_ptr,
+                bt_ptr,
+                cl_ptr,
+                new_tokens,
+                num_heads,
+                num_kv_heads,
+                head_dim,
+                max_blocks,
+            );
+
+            let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]);
+            let attn_proj = matmul_batched_gemv(&attn_merged, &layer.o_proj_wt);
+            self.all_reduce(&attn_proj);
+
+            let (normed, x_new) =
+                xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
+            let residual = x_new.clone();
+
+            let gate_up = matmul_batched_gemv(&normed, &layer.gate_up_proj_wt);
+            let ffn_dim = gate_up.shape()[1] / 2;
+            let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
+            let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
+            let hidden_states = xserv_kernels::silu_mul(&gate, &up);
+            let down = matmul_batched_gemv(&hidden_states, &layer.down_proj_wt);
+            self.all_reduce(&down);
+            x = add_any(&residual, &down);
+
+            for (h_idx, &h_layer) in hook_layers.iter().enumerate() {
+                if layer_idx == h_layer {
+                    hooks[h_idx] = Some(x.clone()); // layer output, before the final norm
+                }
+            }
+        }
+
+        let x = rmsnorm(&x, &self.norm, eps);
+        let logits = matmul_batched_gemv(&x, &self.lm_head_t);
+        let hidden_arr = [
+            hooks[0].take().expect("hook layer 0 not reached"),
+            hooks[1].take().expect("hook layer 1 not reached"),
+            hooks[2].take().expect("hook layer 2 not reached"),
+        ];
+        (logits, hidden_arr)
+    }
+
     /// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
     pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
         let new_tokens = token_ids.len();