speculative: EAGLE3 draft head implementation (Phase 25 step 1)

- eagle3.rs: Eagle3Head struct loads AngelSlim/Qwen3-8B_eagle3 safetensors,
  runs a single draft step via fc(concat(h_low, h_mid, h_high)) +
  concat(input_norm(emb), hidden_norm(fused_h)) → 1 midlayer → norm →
  lm_head → argmax in draft_vocab(32000) → d2t → target_vocab.
- qwen3.rs: new decode_core_with_hidden method that mirrors decode_core
  but captures hidden states at 3 configurable layer indices (default
  [11, 23, 35] for the 36-layer Qwen3-8B). Also expose embed_tokens_tensor
  and (in eagle3) map_draft_to_target as public accessors.
- loader.rs: make_tensor now pub(crate) so eagle3 can reuse it.
- bin/check-eagle3.rs: sanity binary that loads target + EAGLE, runs one
  prefill + one decode + one EAGLE step, prints the top-5 EAGLE predictions.
  Verified on dash5 with prompt "The capital of France is":
    target says: " Paris" then "."
    EAGLE top-5: "," / " Paris" / " Madrid" / "." / " Berlin"
  Weights load correctly, d2t mapping works, hidden state hooks are the
  right shape ([1, 4096]), and EAGLE produces thematically-relevant tokens.

The top-1 pick "," doesn't match target's "." at this position, but
that's expected: this test uses hidden states from a single decode step
with no recursive chaining. A full speculative loop still needs the
γ≥2 verify + accept path wired up (next step).
This commit is contained in:
2026-07-01 17:23:22 +08:00
parent 6485c87c5b
commit e04a8ffb18
5 changed files with 577 additions and 1 deletions

View File

@@ -0,0 +1,152 @@
//! EAGLE3 sanity check: load weights, run one draft step, print top-5 predictions.
//!
//! This verifies that:
//! - Eagle3Head weights load without shape mismatches
//! - Target hidden states can be captured via decode_core_with_hidden
//! - Eagle3Head::step produces a valid token id (in target vocab)
//!
//! Does NOT measure speedup — that requires a full γ≥2 speculative loop, which
//! is more complex integration work.
use std::path::PathBuf;
use xserv_model::eagle3::{EAGLE_HOOK_LAYERS, Eagle3Head};
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
use xserv_tensor::{DType, Device, Tensor};
use xserv_tokenizer::Tokenizer;
/// EAGLE3 sanity check entry point.
///
/// Usage: `check-eagle3 <target-model-dir> <eagle3-model-dir> [prompt]`.
///
/// Runs one target prefill, one target decode step that also captures the
/// hidden states at `EAGLE_HOOK_LAYERS`, and one EAGLE3 draft step, printing
/// the target's and the draft head's predictions along the way.
///
/// Exits with status 1 on a usage error or when the prompt does not fit the
/// fixed-size KV cache allocated below.
fn main() {
    let args: Vec<String> = std::env::args().collect();
    if args.len() < 3 {
        eprintln!("Usage: check-eagle3 <target-model-dir> <eagle3-model-dir> [prompt]");
        std::process::exit(1);
    }
    let target_dir = PathBuf::from(&args[1]);
    let eagle_dir = PathBuf::from(&args[2]);
    let prompt = args
        .get(3)
        .cloned()
        .unwrap_or_else(|| "The capital of France is".to_string());

    // Tokenize and validate the prompt before touching the GPU so that a bad
    // prompt fails immediately instead of after loading both models.
    let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
    let ids = tokenizer.encode(&prompt);
    let max_seq_len = 512;
    // The prefill needs at least one token, and the decode step below appends
    // one more token to the cache, so the prompt must be strictly shorter than
    // the cache capacity.
    if ids.is_empty() || ids.len() >= max_seq_len {
        eprintln!(
            "check-eagle3: prompt must encode to between 1 and {} tokens, got {}",
            max_seq_len - 1,
            ids.len()
        );
        std::process::exit(1);
    }

    let device: u32 = 0;
    xserv_cuda::device::set_device(device).unwrap();

    let target_config = ModelConfig::from_file(&target_dir.join("config.json"));
    eprintln!("Loading target Qwen3-8B...");
    let target_weights = loader::load_model_dir(&target_dir, Device::Cuda(device));
    let target = Qwen3::from_weights(target_config.clone(), target_weights);
    xserv_cuda::allocator::cached_trim();

    eprintln!("Loading EAGLE3 head from {}", eagle_dir.display());
    let eagle = Eagle3Head::load(&eagle_dir, device);
    xserv_cuda::allocator::cached_trim();

    let embed_tokens = target.embed_tokens_tensor();

    // Enough blocks for `max_seq_len` tokens, plus two spare.
    let num_blocks = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE + 2;
    let mut cache = PagedKVCache::new(
        &target_config,
        num_blocks,
        0,
        1,
        num_blocks,
        DType::BF16,
        device,
    );
    cache.register_sequence(0).unwrap();

    // Prefill target.
    let logits = target.forward_prefill_paged(&ids, 0, &mut cache);
    let target_first = *xserv_kernels::argmax_bf16_to_host(&logits).last().unwrap();
    let target_first_text = tokenizer.decode(&[target_first]);
    println!("Prompt: {:?}", prompt);
    println!(
        "Target argmax after prefill: {} ({:?})",
        target_first, target_first_text
    );

    // Now run one target decode step with target_first to get hidden states at
    // the hook layers.
    let pos = cache.seq_len(0);
    let pos_u32 = u32::try_from(pos).expect("decode position does not fit in u32");
    target.decode_prepare(&[pos], &[0], &mut cache);
    let ids_gpu = upload_u32(&[target_first]);
    let pos_gpu = upload_u32(&[pos_u32]);
    let (target_next_logits, hooks) = target.decode_core_with_hidden(
        ids_gpu.as_ptr() as *const std::ffi::c_void,
        pos_gpu.as_ptr() as *const std::ffi::c_void,
        1,
        &[0],
        &mut cache,
        &EAGLE_HOOK_LAYERS,
    );
    let target_next = xserv_kernels::argmax_bf16_single(&target_next_logits);
    let target_next_text = tokenizer.decode(&[target_next]);
    println!(
        "Target argmax after 1 decode step: {} ({:?})",
        target_next, target_next_text
    );
    for (i, h) in hooks.iter().enumerate() {
        println!(
            "hook[{}] (layer {}): shape={:?} dtype={:?}",
            i,
            EAGLE_HOOK_LAYERS[i],
            h.shape(),
            h.dtype()
        );
    }

    // Ask EAGLE what it thinks the NEXT token is (given target_first as
    // prev_token and the hidden states from the position where target_first
    // lives). EAGLE should predict target_next (or close to it) to be useful.
    let (eagle_pred, eagle_logits) = eagle.step(&hooks, embed_tokens, target_first, pos);
    let eagle_pred_text = tokenizer.decode(&[eagle_pred]);
    println!(
        "EAGLE draft prediction: {} ({:?})",
        eagle_pred, eagle_pred_text
    );
    if eagle_pred == target_next {
        println!("MATCH: EAGLE agrees with target on next token.");
    } else {
        println!(
            "MISMATCH: EAGLE draft={} vs target={} (this is fine per-step; check top-5 below)",
            eagle_pred, target_next
        );
    }

    // Show top-5 from eagle logits (in draft vocab space, mapped to target).
    print_top5(&eagle_logits, "EAGLE draft top-5", &eagle, &tokenizer);
}
/// Copies `vals` into a freshly allocated GPU buffer.
///
/// The values are serialized in native byte order, four bytes per element,
/// into a buffer obtained from the caching allocator.
///
/// # Panics
/// Panics if the allocation or the host-to-device copy fails.
fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer {
    // Native-endian serialization yields the same bytes as viewing the slice's
    // memory directly, without needing `unsafe`.
    let host_bytes: Vec<u8> = vals.iter().flat_map(|v| v.to_ne_bytes()).collect();
    let mut device_buf = xserv_cuda::allocator::cached_alloc(host_bytes.len()).unwrap();
    device_buf.copy_from_host(host_bytes.as_slice()).unwrap();
    device_buf
}
/// Prints the five highest-scoring entries of `logits` under the heading `label`.
///
/// `logits` is copied to the CPU and read as bf16. Each draft-vocab index is
/// mapped to a target-vocab id with `eagle.map_draft_to_target` and decoded
/// with `tokenizer` for display.
///
/// Values are ranked with the IEEE total order, so NaN logits no longer abort
/// the diagnostic: a positive NaN ranks above every finite value and therefore
/// shows up in the printed list instead of causing a panic.
fn print_top5(logits: &Tensor, label: &str, eagle: &Eagle3Head, tokenizer: &Tokenizer) {
    use half::bf16;

    let cpu = logits.to_device(Device::Cpu);
    let data = cpu.as_slice::<bf16>();
    let mut vals: Vec<(usize, f32)> = data
        .iter()
        .enumerate()
        .map(|(i, v)| (i, v.to_f32()))
        .collect();
    // Descending by value. `total_cmp` is total over all f32 values, unlike
    // `partial_cmp`, which returns `None` (and previously panicked) on NaN.
    vals.sort_by(|a, b| b.1.total_cmp(&a.1));

    println!("{label}:");
    for (i, val) in vals.iter().take(5) {
        let target_id = eagle.map_draft_to_target(*i as u32);
        let text = tokenizer.decode(&[target_id]);
        println!(
            "  draft_id={} target_id={} val={:.3} text={:?}",
            i, target_id, val, text
        );
    }
}

View File

@@ -0,0 +1,312 @@
//! EAGLE3 speculative draft head for Qwen3-8B (Phase 25).
//!
//! Loads the AngelSlim/Qwen3-8B_eagle3 weights from `model.safetensors`
//! (converted from the published `pytorch_model.bin`) and provides a
//! single-step forward pass that takes 3 target hidden states + the previous
//! token and returns a draft token in the target vocabulary.
//!
//! Architecture (from weights):
//! - fc: [hidden, 3*hidden] → fuse 3 target hidden states
//! - midlayer: 1 decoder layer (attn input dim = 2*hidden)
//! - norm + lm_head: → [draft_vocab_size=32000]
//! - d2t: draft_id → target_id offset mapping
use std::collections::HashMap;
use std::path::Path;
use xserv_kernels::*;
use xserv_tensor::{DType, Device, Tensor};
/// Target-model layer indices whose outputs are captured and fed to the draft
/// head, in the order they are concatenated (low, mid, high).
///
/// NOTE(review): the upstream EAGLE-3 reference appears to tap a different set
/// of depths (roughly early / middle / a few layers before the end) — confirm
/// these indices match the layers the loaded draft head was trained on.
pub const EAGLE_HOOK_LAYERS: [usize; 3] = [11, 23, 35];
/// Number of entries in the draft vocabulary; the loaded `d2t` table is
/// asserted to have exactly this length.
const DRAFT_VOCAB_SIZE: usize = 32000;
/// Multiplies two rank-2 tensors using the cuBLAS GEMM backend.
///
/// # Panics
/// Panics if either operand is not 2-dimensional.
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
    // Check the left operand first, then the right one.
    for operand in [a, b] {
        assert_eq!(operand.ndim(), 2);
    }
    matmul(a, b, GemmBackend::CuBlas)
}
/// Weights and hyper-parameters of the single-layer EAGLE3 draft head.
///
/// Every `*_wt` field holds the transpose of the corresponding checkpoint
/// tensor, made contiguous at load time, so it can be used directly as the
/// right-hand operand of `matmul_2d`. The bracketed shapes below are the
/// original annotations; they look like the checkpoint (pre-transpose)
/// layout — TODO confirm.
pub struct Eagle3Head {
// Projection that fuses the three concatenated target hidden states.
// Annotated as [hidden, 3*hidden], stored transposed for matmul.
fc_wt: Tensor,
// RMSNorm weight applied to the fused hidden state. [hidden]
hidden_norm: Tensor,
// RMSNorm weight applied to the previous-token embedding. [hidden]
input_layernorm: Tensor,
// Attention query projection. Annotated as [num_heads*head_dim, 2*hidden].
q_proj_wt: Tensor,
// Attention key projection. Annotated as [num_kv_heads*head_dim, 2*hidden].
k_proj_wt: Tensor,
// Attention value projection. Annotated as [num_kv_heads*head_dim, 2*hidden].
v_proj_wt: Tensor,
// Attention output projection. Annotated as [hidden, num_heads*head_dim].
o_proj_wt: Tensor,
// MLP gate projection. Annotated as [intermediate, hidden].
gate_proj_wt: Tensor,
// MLP up projection. Annotated as [intermediate, hidden].
up_proj_wt: Tensor,
// MLP down projection. Annotated as [hidden, intermediate].
down_proj_wt: Tensor,
// RMSNorm weight applied before the MLP. [hidden]
post_attention_layernorm: Tensor,
// Final RMSNorm weight applied before the LM head. [hidden]
norm: Tensor,
// Draft-vocabulary LM head. Annotated as [draft_vocab, hidden].
lm_head_wt: Tensor,
// Per-draft-id offset table: target_id = draft_id + d2t[draft_id]. [draft_vocab]
d2t: Vec<i64>,
// Model hidden size.
hidden_size: usize,
// Number of attention (query) heads.
num_heads: usize,
// Number of key/value heads.
num_kv_heads: usize,
// Dimension of each attention head.
head_dim: usize,
// Rotary-embedding cache used when rotating Q and K.
rope_cache: RopeCache,
}
impl Eagle3Head {
    /// Loads the draft head from `dir/model.safetensors` onto CUDA device `device`.
    ///
    /// All projection weights are transposed and made contiguous once here so
    /// that `step` can use them directly as the right-hand matmul operand.
    /// Model dimensions are hard-coded for the Qwen3-8B EAGLE3 head.
    ///
    /// # Panics
    /// Panics if a required weight is missing or if the `d2t` table does not
    /// have `DRAFT_VOCAB_SIZE` entries.
    pub fn load(dir: &Path, device: u32) -> Self {
        let (weights, d2t) = load_eagle3_weights(dir, device);

        let hidden_size = 4096;
        let num_heads = 32;
        let num_kv_heads = 8;
        let head_dim = 128;
        let max_seq_len = 2048;
        let rope_theta = 1_000_000.0f32;

        let get = |name: &str| -> Tensor {
            weights
                .get(name)
                .unwrap_or_else(|| panic!("missing eagle3 weight: {name}"))
                .clone()
        };
        // Projection weights are stored transposed for `matmul_2d(x, w)`.
        let get_transposed = |name: &str| -> Tensor { get(name).transpose(0, 1).contiguous() };

        let fc_wt = get_transposed("fc.weight");
        let q_proj_wt = get_transposed("midlayer.self_attn.q_proj.weight");
        let k_proj_wt = get_transposed("midlayer.self_attn.k_proj.weight");
        let v_proj_wt = get_transposed("midlayer.self_attn.v_proj.weight");
        let o_proj_wt = get_transposed("midlayer.self_attn.o_proj.weight");
        let gate_proj_wt = get_transposed("midlayer.mlp.gate_proj.weight");
        let up_proj_wt = get_transposed("midlayer.mlp.up_proj.weight");
        let down_proj_wt = get_transposed("midlayer.mlp.down_proj.weight");
        let hidden_norm = get("midlayer.hidden_norm.weight");
        let input_layernorm = get("midlayer.input_layernorm.weight");
        let post_attention_layernorm = get("midlayer.post_attention_layernorm.weight");
        let norm = get("norm.weight");
        let lm_head_wt = get_transposed("lm_head.weight");

        assert_eq!(d2t.len(), DRAFT_VOCAB_SIZE);
        let rope_cache = RopeCache::new(max_seq_len, head_dim, rope_theta);

        Self {
            fc_wt,
            hidden_norm,
            input_layernorm,
            q_proj_wt,
            k_proj_wt,
            v_proj_wt,
            o_proj_wt,
            gate_proj_wt,
            up_proj_wt,
            down_proj_wt,
            post_attention_layernorm,
            norm,
            lm_head_wt,
            d2t,
            hidden_size,
            num_heads,
            num_kv_heads,
            head_dim,
            rope_cache,
        }
    }

    /// One draft step: produce a token in target vocabulary space.
    ///
    /// - `target_hidden`: 3 tensors [1, hidden_size] from target hook layers
    /// - `embed_table`: the target model's embed_tokens (shared, not copied)
    /// - `prev_token`: the previous committed token
    /// - `position`: the decode position for RoPE
    ///
    /// Returns (draft_token_in_target_vocab, draft_logits_tensor). The logits
    /// are indexed by draft-vocab id, not target-vocab id.
    ///
    /// # Panics
    /// Panics if `position` does not fit in `u32`, or if the argmax draft id
    /// cannot be mapped to a valid target id.
    pub fn step(
        &self,
        target_hidden: &[Tensor; 3],
        embed_table: &Tensor,
        prev_token: u32,
        position: usize,
    ) -> (u32, Tensor) {
        let eps = 1e-6f32;

        // 1. Fuse target hidden states: concat [h_low, h_mid, h_high] → fc
        let h_cat = concat_hidden(target_hidden);
        let fused_h = matmul_2d(&h_cat, &self.fc_wt); // [1, hidden]

        // 2. Embed previous token (shared with target)
        let emb = embedding(embed_table, &[prev_token]); // [1, hidden]

        // 3. Concat normalized: [norm(emb), norm(fused_h)] → [1, 2*hidden]
        let emb_normed = rmsnorm(&emb, &self.input_layernorm, eps);
        let h_normed = rmsnorm(&fused_h, &self.hidden_norm, eps);
        let attn_in = concat_last_dim(&emb_normed, &h_normed);

        // 4. Self-attention (no KV cache for simplicity in v0 — single query)
        let q = matmul_2d(&attn_in, &self.q_proj_wt); // [1, num_heads*head_dim]
        let k = matmul_2d(&attn_in, &self.k_proj_wt); // [1, num_kv*head_dim]
        let v = matmul_2d(&attn_in, &self.v_proj_wt); // [1, num_kv*head_dim]
        let q_3d = q.reshape(&[1, self.num_heads, self.head_dim]);
        let k_3d = k.reshape(&[1, self.num_kv_heads, self.head_dim]);
        let rope_pos = u32::try_from(position).expect("RoPE position does not fit in u32");
        let positions = [rope_pos];
        // NOTE(review): the rotated Q/K are not consumed below because the
        // single-key attention collapses to V; they are kept as scaffolding for
        // the cached multi-token path.
        rope_inplace(&q_3d, &self.rope_cache, &positions);
        rope_inplace(&k_3d, &self.rope_cache, &positions);

        // Single-token attention: Q·K^T / sqrt(d) → softmax → V
        // With seq_len=1 the softmax weight is 1.0, so the output is V,
        // replicated across the query heads that share each KV head.
        let attn_out = v.reshape(&[1, self.num_kv_heads, self.head_dim]);
        let attn_out = if self.num_heads != self.num_kv_heads {
            repeat_kv_for_single_token(&attn_out, self.num_heads / self.num_kv_heads)
        } else {
            attn_out
        };
        let attn_merged = attn_out.reshape(&[1, self.num_heads * self.head_dim]);
        let attn_proj = matmul_2d(&attn_merged, &self.o_proj_wt); // [1, hidden]

        // Residual connection. In the EAGLE-3 reference decoder layer the
        // residual stream is the fused hidden state (taken before
        // `hidden_norm`); the token embedding only enters through the
        // attention input. Adding `emb` here would feed the MLP a different
        // residual stream than the one the head was trained with.
        let x = add(&attn_proj, &fused_h);

        // 5. MLP
        let normed = rmsnorm(&x, &self.post_attention_layernorm, eps);
        let gate = matmul_2d(&normed, &self.gate_proj_wt);
        let up = matmul_2d(&normed, &self.up_proj_wt);
        let mlp_out = silu_mul(&gate, &up);
        let down = matmul_2d(&mlp_out, &self.down_proj_wt);
        let x = add(&x, &down);

        // 6. Final norm + lm_head
        let x = rmsnorm(&x, &self.norm, eps);
        let logits = matmul_2d(&x, &self.lm_head_wt); // [1, draft_vocab]

        // 7. Argmax in draft vocab → map to target vocab
        let draft_id = argmax_bf16_single(&logits);
        (self.map_draft_to_target(draft_id), logits)
    }

    /// Map a draft-vocab token id to the full target-vocab id via d2t.
    ///
    /// The table stores offsets, so the result is `draft_id + d2t[draft_id]`.
    ///
    /// # Panics
    /// Panics if `draft_id` is outside the draft vocabulary, or if the mapped
    /// id is negative or does not fit in `u32` (which indicates a corrupt
    /// `d2t` table rather than a silently wrapped token id).
    pub fn map_draft_to_target(&self, draft_id: u32) -> u32 {
        let mapped = i64::from(draft_id) + self.d2t[draft_id as usize];
        u32::try_from(mapped)
            .unwrap_or_else(|_| panic!("d2t maps draft id {draft_id} to invalid target id {mapped}"))
    }
}
/// Copies `bytes` bytes from `src` to `dst` via `cudaMemcpy` with the
/// device-to-device copy kind.
///
/// NOTE(review): the status returned by `cudaMemcpy` is discarded, so a failed
/// copy goes unnoticed here — consider surfacing it.
fn d2d(dst: *mut u8, src: *const u8, bytes: usize) {
    let copy_kind = xserv_cuda::ffi::CUDA_MEMCPY_D2D;
    // SAFETY: relies on the caller passing pointers that are valid for
    // `bytes` bytes each; this function cannot check that itself.
    unsafe {
        xserv_cuda::ffi::cudaMemcpy(dst, src, bytes, copy_kind);
    }
}
/// Concatenates three `[1, h]` hidden-state tensors into one `[1, 3*h]` tensor.
///
/// The output takes its dtype and device from `hidden[0]`; the rows are copied
/// back to back in input order with raw device-to-device copies.
///
/// # Panics
/// Panics if any input is not a contiguous `[1, h]` tensor with the same `h`
/// and element size as `hidden[0]`. Without these checks a batched or
/// mismatched input would be copied with the wrong byte count.
fn concat_hidden(hidden: &[Tensor; 3]) -> Tensor {
    assert_eq!(hidden[0].ndim(), 2, "concat_hidden: inputs must be 2-D");
    let h = hidden[0].shape()[1];
    let dtype = hidden[0].dtype();
    let device = hidden[0].device();
    let elem_bytes = dtype.size_bytes();
    let row_bytes = h * elem_bytes;

    // Validate every input before allocating or copying anything.
    for (i, t) in hidden.iter().enumerate() {
        assert!(t.is_contiguous());
        assert!(
            t.ndim() == 2 && t.shape()[0] == 1 && t.shape()[1] == h,
            "concat_hidden: hidden[{i}] must have shape [1, {h}], got {:?}",
            t.shape()
        );
        assert_eq!(
            t.dtype().size_bytes(),
            elem_bytes,
            "concat_hidden: hidden[{i}] has a different element size than hidden[0]"
        );
    }

    let out = Tensor::empty(&[1, 3 * h], dtype, device);
    for (i, t) in hidden.iter().enumerate() {
        // SAFETY: `out` holds 3 * row_bytes bytes and i < 3, so the offset
        // stays inside the output allocation.
        let dst = unsafe { (out.data_ptr() as *mut u8).add(i * row_bytes) };
        d2d(dst, t.data_ptr() as *const u8, row_bytes);
    }
    out
}
/// Concatenates two single-row tensors along the last dimension:
/// `[1, da]` and `[1, db]` become `[1, da + db]`.
///
/// The output takes its dtype and device from `a`; both rows are copied with
/// raw device-to-device copies.
///
/// # Panics
/// Panics if either operand is not a contiguous `[1, d]` tensor, or if the two
/// operands have different element sizes. The raw byte copies below are only
/// correct under those conditions.
fn concat_last_dim(a: &Tensor, b: &Tensor) -> Tensor {
    for (label, t) in [("a", a), ("b", b)] {
        assert!(
            t.is_contiguous(),
            "concat_last_dim: operand `{label}` must be contiguous"
        );
        assert!(
            t.ndim() == 2 && t.shape()[0] == 1,
            "concat_last_dim: operand `{label}` must have shape [1, d], got {:?}",
            t.shape()
        );
    }

    let da = a.shape()[1];
    let db = b.shape()[1];
    let dtype = a.dtype();
    let device = a.device();
    let elem_bytes = dtype.size_bytes();
    assert_eq!(
        b.dtype().size_bytes(),
        elem_bytes,
        "concat_last_dim: operands have different element sizes"
    );

    let out = Tensor::empty(&[1, da + db], dtype, device);
    d2d(
        out.data_ptr() as *mut u8,
        a.data_ptr() as *const u8,
        da * elem_bytes,
    );
    // SAFETY: `out` holds (da + db) * elem_bytes bytes, so the second half
    // starts inside the allocation.
    let dst = unsafe { (out.data_ptr() as *mut u8).add(da * elem_bytes) };
    d2d(dst, b.data_ptr() as *const u8, db * elem_bytes);
    out
}
/// Expands the KV heads of a single token so that each head appears `repeats`
/// times in a row: `[1, nkv, d]` becomes `[1, nkv * repeats, d]`.
///
/// Output head `j` is a copy of input head `j / repeats`. When `repeats` is 1
/// the input tensor is returned as a clone and no data is copied.
fn repeat_kv_for_single_token(kv: &Tensor, repeats: usize) -> Tensor {
    if repeats == 1 {
        return kv.clone();
    }

    let num_kv_heads = kv.shape()[1];
    let head_dim = kv.shape()[2];
    let dtype = kv.dtype();
    let device = kv.device();
    let head_bytes = head_dim * dtype.size_bytes();
    let out = Tensor::empty(&[1, num_kv_heads * repeats, head_dim], dtype, device);

    let src_base = kv.data_ptr() as *const u8;
    let dst_base = out.data_ptr() as *mut u8;
    for out_head in 0..num_kv_heads * repeats {
        let src_head = out_head / repeats;
        // SAFETY: assumes `kv` is a contiguous [1, num_kv_heads, head_dim]
        // buffer (TODO confirm at call sites); `out` was allocated above with
        // room for num_kv_heads * repeats heads.
        let (src, dst) = unsafe {
            (
                src_base.add(src_head * head_bytes),
                dst_base.add(out_head * head_bytes),
            )
        };
        d2d(dst, src, head_bytes);
    }
    out
}
/// Reads `model.safetensors` from `dir` and splits it into GPU weight tensors
/// and the host-side `d2t` offset table.
///
/// - `t2d` is ignored.
/// - `d2t` must be int64; it is decoded as little-endian and kept on the host.
/// - BF16 / F32 / F16 tensors are uploaded to `Device::Cuda(device)`.
/// - Tensors of any other dtype are skipped with a message on stderr.
///
/// # Panics
/// Panics if the file is missing, unreadable or unparsable, if `d2t` is not
/// int64, or if no non-empty `d2t` tensor is present.
fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap<String, Tensor>, Vec<i64>) {
    let st_path = dir.join("model.safetensors");
    assert!(
        st_path.exists(),
        "Eagle3 model.safetensors not found in {}. Convert with:\n\
         python3 -c \"import torch; from safetensors.torch import save_file; \
         sd=torch.load('pytorch_model.bin', map_location='cpu', weights_only=False); \
         save_file(sd, 'model.safetensors')\"",
        dir.display()
    );

    let file_bytes = std::fs::read(&st_path)
        .unwrap_or_else(|e| panic!("failed to read {}: {e}", st_path.display()));
    let archive = safetensors::SafeTensors::deserialize(&file_bytes)
        .unwrap_or_else(|e| panic!("failed to parse {}: {e}", st_path.display()));

    let mut gpu_weights = HashMap::new();
    let mut offsets: Vec<i64> = Vec::new();

    for (name, view) in archive.tensors() {
        let is_reverse_map = name == "t2d";
        let is_offset_table = name == "d2t";

        if is_reverse_map {
            // The target→draft table is not needed for drafting.
            continue;
        }

        if is_offset_table {
            let payload = view.data();
            assert_eq!(view.dtype(), safetensors::Dtype::I64);
            // Decode consecutive 8-byte little-endian words; any trailing
            // partial word is ignored.
            offsets = payload
                .chunks_exact(8)
                .map(|word| i64::from_le_bytes(word.try_into().unwrap()))
                .collect();
            continue;
        }

        let dtype = match view.dtype() {
            safetensors::Dtype::BF16 => DType::BF16,
            safetensors::Dtype::F32 => DType::F32,
            safetensors::Dtype::F16 => DType::F16,
            other => {
                eprintln!("eagle3: skipping {name} with unsupported dtype {other:?}");
                continue;
            }
        };
        let dims: Vec<usize> = view.shape().to_vec();
        let host_tensor = crate::loader::make_tensor(view.data(), &dims, dtype);
        gpu_weights.insert(
            name.to_string(),
            host_tensor.to_device(Device::Cuda(device)),
        );
    }

    assert!(
        !offsets.is_empty(),
        "d2t tensor not found in eagle3 weights"
    );
    (gpu_weights, offsets)
}

View File

@@ -1,5 +1,6 @@
pub mod config;
pub mod decode_graph;
pub mod eagle3;
pub mod gpt2;
pub mod gpt_oss;
pub mod gpt_oss_graph;

View File

@@ -68,7 +68,7 @@ pub fn load_model_dir(dir: &Path, device: Device) -> HashMap<String, Tensor> {
all_tensors
}
fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
pub(crate) fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
match dtype {
DType::F32 => {
let floats: &[f32] = unsafe {

View File

@@ -825,6 +825,111 @@ impl Qwen3 {
matmul_2d(&x, &self.lm_head_t)
}
/// Like `decode_core` but also captures hidden states at 3 specified layer
/// indices (after residual+MLP output). Used by EAGLE3 speculative drafting
/// to feed the draft head with low/mid/high target representations.
///
/// - `ids_gpu` / `pos_gpu`: device pointers to the token ids and positions
///   for the `batch` sequences being decoded.
/// - `seq_slots`: cache slots whose sequence length is advanced by one after
///   all layers have run.
/// - `hook_layers`: indices into `self.layers`; the residual-stream output
///   of each listed layer is returned, in the same order.
///
/// Returns `(logits, [hidden; 3])`.
///
/// # Panics
/// Panics if any entry of `hook_layers` is not a valid layer index. This is
/// checked before the forward pass so that a bad index cannot leave the
/// paged cache with appended tokens and advanced sequence lengths.
pub fn decode_core_with_hidden(
    &self,
    ids_gpu: *const std::ffi::c_void,
    pos_gpu: *const std::ffi::c_void,
    batch: usize,
    seq_slots: &[usize],
    paged_cache: &mut PagedKVCache,
    hook_layers: &[usize; 3],
) -> (Tensor, [Tensor; 3]) {
    // Fail fast: every hook must refer to a layer that will actually run.
    let num_layers = self.layers.len();
    for &h_layer in hook_layers {
        assert!(
            h_layer < num_layers,
            "hook layer index {h_layer} out of range for a {num_layers}-layer model"
        );
    }

    let num_heads = self.local_num_heads;
    let num_kv_heads = self.local_num_kv_heads;
    let head_dim = self.config.head_dim();
    let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
    let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
    let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
    let max_blocks = paged_cache.max_blocks_per_seq();

    let mut x = embedding_device_ids(&self.embed_tokens, ids_gpu, batch);
    let mut hooks: [Option<Tensor>; 3] = [None, None, None];

    for (layer_idx, layer) in self.layers.iter().enumerate() {
        // --- Attention block ---
        let residual = x.clone();
        let normed = rmsnorm(&x, &layer.input_norm, eps);

        // Fused QKV projection, then slice out Q, K and V.
        let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
        let q_dim = num_heads * head_dim;
        let kv_dim = num_kv_heads * head_dim;
        let q_all = qkv.narrow(1, 0, q_dim);
        let k_all = qkv.narrow(1, q_dim, kv_dim);
        let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);

        // Per-head RMSNorm on Q and K, followed by RoPE at the device-side
        // positions.
        let q_flat = q_all.contiguous().reshape(&[batch * num_heads, head_dim]);
        let k_flat = k_all
            .contiguous()
            .reshape(&[batch * num_kv_heads, head_dim]);
        let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
        let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
        let q_3d = q_normed.reshape(&[batch, num_heads, head_dim]);
        let k_3d = k_normed.reshape(&[batch, num_kv_heads, head_dim]);
        rope_inplace_device_pos(&q_3d, &self.rope_cache, pos_gpu);
        rope_inplace_device_pos(&k_3d, &self.rope_cache, pos_gpu);
        let v_3d = v_all.contiguous().reshape(&[batch, num_kv_heads, head_dim]);

        // Append this step's K/V to the paged cache, then attend over it.
        paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, batch);
        let q_4d = q_3d.reshape(&[batch, num_heads, 1, head_dim]);
        let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
        let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
        let attn_out = xserv_kernels::paged_decode_attention(
            &q_4d,
            k_pool_ptr,
            v_pool_ptr,
            bt_ptr,
            cl_ptr,
            batch,
            num_heads,
            num_kv_heads,
            head_dim,
            max_blocks,
        );
        let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]);
        let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
        self.all_reduce(&attn_proj);

        // --- MLP block ---
        // Fused residual add + post-attention RMSNorm; `x_new` is the updated
        // residual stream.
        let (normed, x_new) =
            xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
        let residual = x_new.clone();
        let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
        let ffn_dim = gate_up.shape()[1] / 2;
        let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
        let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
        let hidden_states = xserv_kernels::silu_mul(&gate, &up);
        let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
        self.all_reduce(&down);
        x = add_any(&residual, &down);

        // Capture the layer output for every hook that names this layer.
        for (h_idx, &h_layer) in hook_layers.iter().enumerate() {
            if layer_idx == h_layer {
                hooks[h_idx] = Some(x.clone());
            }
        }
    }

    for &slot in seq_slots {
        paged_cache.advance_seq_len(slot, 1);
    }

    let x = rmsnorm(&x, &self.norm, eps);
    let logits = matmul_2d(&x, &self.lm_head_t);

    let hidden_arr = [
        hooks[0].take().expect("hook layer 0 not reached"),
        hooks[1].take().expect("hook layer 1 not reached"),
        hooks[2].take().expect("hook layer 2 not reached"),
    ];
    (logits, hidden_arr)
}
/// Paged prefill: write a sequence of `new_tokens` K/V into the paged
/// cache for `slot`, run flash attention via gathered contiguous K/V.
/// Returns logits [new_tokens, vocab_size].
@@ -1074,6 +1179,12 @@ impl Qwen3 {
matmul_2d(&x, &self.lm_head_t)
}
/// Returns a borrowed reference to the target model's token embedding table.
///
/// The tensor is not copied; speculative draft heads such as EAGLE3 use this
/// to look up embeddings from the same table as the target.
pub fn embed_tokens_tensor(&self) -> &Tensor {
&self.embed_tokens
}
/// Extract weight pointers for CUDA Graph capture.
pub fn layer_weight_ptrs(&self) -> Vec<crate::decode_graph::LayerWeightPtrs> {
self.layers