speculative: EAGLE3 draft head implementation (Phase 25 step 1)
- eagle3.rs: Eagle3Head struct loads AngelSlim/Qwen3-8B_eagle3 safetensors,
runs a single draft step via fc(concat(h_low, h_mid, h_high)) +
concat(input_norm(emb), hidden_norm(fused_h)) → 1 midlayer → norm →
lm_head → argmax in draft_vocab(32000) → d2t → target_vocab.
- qwen3.rs: new decode_core_with_hidden method that mirrors decode_core
but captures hidden states at 3 configurable layer indices (default
[11, 23, 35] for the 36-layer Qwen3-8B). Also expose embed_tokens_tensor
and (in eagle3) map_draft_to_target as public accessors.
- loader.rs: make_tensor now pub(crate) so eagle3 can reuse it.
- bin/check-eagle3.rs: sanity binary that loads target + EAGLE, runs one
prefill + one decode + one EAGLE step, prints the top-5 EAGLE predictions.
Verified on dash5 with prompt "The capital of France is":
target says: " Paris" then "."
EAGLE top-5: "," / " Paris" / " Madrid" / "." / " Berlin"
Weights load correctly, d2t mapping works, hidden state hooks are the
right shape ([1, 4096]), and EAGLE produces thematically-relevant tokens.
The top-1 pick "," doesn't match target's "." at this position, but
that's expected: this test uses hidden states from a single decode step
with no recursive chaining. A full speculative loop still needs the
γ≥2 verify + accept path wired up (next step).
This commit is contained in:
152
crates/xserv-model/src/bin/check-eagle3.rs
Normal file
152
crates/xserv-model/src/bin/check-eagle3.rs
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
//! EAGLE3 sanity check: load weights, run one draft step, print top-5 predictions.
|
||||||
|
//!
|
||||||
|
//! This verifies that:
|
||||||
|
//! - Eagle3Head weights load without shape mismatches
|
||||||
|
//! - Target hidden states can be captured via decode_core_with_hidden
|
||||||
|
//! - Eagle3Head::step produces a valid token id (in target vocab)
|
||||||
|
//!
|
||||||
|
//! Does NOT measure speedup — that requires a full γ≥2 speculative loop, which
|
||||||
|
//! is more complex integration work.
|
||||||
|
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use xserv_model::eagle3::{EAGLE_HOOK_LAYERS, Eagle3Head};
|
||||||
|
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
|
||||||
|
use xserv_tensor::{DType, Device, Tensor};
|
||||||
|
use xserv_tokenizer::Tokenizer;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let args: Vec<String> = std::env::args().collect();
|
||||||
|
if args.len() < 3 {
|
||||||
|
eprintln!("Usage: check-eagle3 <target-model-dir> <eagle3-model-dir> [prompt]");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
let target_dir = PathBuf::from(&args[1]);
|
||||||
|
let eagle_dir = PathBuf::from(&args[2]);
|
||||||
|
let prompt = args
|
||||||
|
.get(3)
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| "The capital of France is".to_string());
|
||||||
|
let device: u32 = 0;
|
||||||
|
|
||||||
|
xserv_cuda::device::set_device(device).unwrap();
|
||||||
|
|
||||||
|
let target_config = ModelConfig::from_file(&target_dir.join("config.json"));
|
||||||
|
eprintln!("Loading target Qwen3-8B...");
|
||||||
|
let target_weights = loader::load_model_dir(&target_dir, Device::Cuda(device));
|
||||||
|
let target = Qwen3::from_weights(target_config.clone(), target_weights);
|
||||||
|
xserv_cuda::allocator::cached_trim();
|
||||||
|
|
||||||
|
eprintln!("Loading EAGLE3 head from {}", eagle_dir.display());
|
||||||
|
let eagle = Eagle3Head::load(&eagle_dir, device);
|
||||||
|
xserv_cuda::allocator::cached_trim();
|
||||||
|
|
||||||
|
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
|
||||||
|
let embed_tokens = target.embed_tokens_tensor();
|
||||||
|
|
||||||
|
let ids = tokenizer.encode(&prompt);
|
||||||
|
let max_seq_len = 512;
|
||||||
|
|
||||||
|
let num_blocks = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE + 2;
|
||||||
|
let mut cache = PagedKVCache::new(
|
||||||
|
&target_config,
|
||||||
|
num_blocks,
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
num_blocks,
|
||||||
|
DType::BF16,
|
||||||
|
device,
|
||||||
|
);
|
||||||
|
cache.register_sequence(0).unwrap();
|
||||||
|
|
||||||
|
// Prefill target.
|
||||||
|
let logits = target.forward_prefill_paged(&ids, 0, &mut cache);
|
||||||
|
let target_first = *xserv_kernels::argmax_bf16_to_host(&logits).last().unwrap();
|
||||||
|
let target_first_text = tokenizer.decode(&[target_first]);
|
||||||
|
println!("Prompt: {:?}", prompt);
|
||||||
|
println!(
|
||||||
|
"Target argmax after prefill: {} ({:?})",
|
||||||
|
target_first, target_first_text
|
||||||
|
);
|
||||||
|
|
||||||
|
// Now run one target decode step with target_first to get hidden states at the
|
||||||
|
// hook layers.
|
||||||
|
let pos = cache.seq_len(0);
|
||||||
|
target.decode_prepare(&[pos], &[0], &mut cache);
|
||||||
|
let ids_gpu = upload_u32(&[target_first]);
|
||||||
|
let pos_gpu = upload_u32(&[pos as u32]);
|
||||||
|
let (target_next_logits, hooks) = target.decode_core_with_hidden(
|
||||||
|
ids_gpu.as_ptr() as *const std::ffi::c_void,
|
||||||
|
pos_gpu.as_ptr() as *const std::ffi::c_void,
|
||||||
|
1,
|
||||||
|
&[0],
|
||||||
|
&mut cache,
|
||||||
|
&EAGLE_HOOK_LAYERS,
|
||||||
|
);
|
||||||
|
let target_next = xserv_kernels::argmax_bf16_single(&target_next_logits);
|
||||||
|
let target_next_text = tokenizer.decode(&[target_next]);
|
||||||
|
println!(
|
||||||
|
"Target argmax after 1 decode step: {} ({:?})",
|
||||||
|
target_next, target_next_text
|
||||||
|
);
|
||||||
|
|
||||||
|
for (i, h) in hooks.iter().enumerate() {
|
||||||
|
println!(
|
||||||
|
"hook[{}] (layer {}): shape={:?} dtype={:?}",
|
||||||
|
i,
|
||||||
|
EAGLE_HOOK_LAYERS[i],
|
||||||
|
h.shape(),
|
||||||
|
h.dtype()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ask EAGLE what it thinks the NEXT token is (given target_first as prev_token
|
||||||
|
// and the hidden states from the position where target_first lives).
|
||||||
|
// EAGLE should predict target_next (or close to it) to be useful.
|
||||||
|
let (eagle_pred, eagle_logits) = eagle.step(&hooks, embed_tokens, target_first, pos);
|
||||||
|
let eagle_pred_text = tokenizer.decode(&[eagle_pred]);
|
||||||
|
println!(
|
||||||
|
"EAGLE draft prediction: {} ({:?})",
|
||||||
|
eagle_pred, eagle_pred_text
|
||||||
|
);
|
||||||
|
|
||||||
|
if eagle_pred == target_next {
|
||||||
|
println!("MATCH: EAGLE agrees with target on next token.");
|
||||||
|
} else {
|
||||||
|
println!(
|
||||||
|
"MISMATCH: EAGLE draft={} vs target={} (this is fine per-step; check top-5 below)",
|
||||||
|
eagle_pred, target_next
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Show top-5 from eagle logits (in draft vocab space, mapped to target).
|
||||||
|
print_top5(&eagle_logits, "EAGLE draft top-5", &eagle, &tokenizer);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer {
|
||||||
|
let bytes = unsafe { std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4) };
|
||||||
|
let mut buf = xserv_cuda::allocator::cached_alloc(bytes.len()).unwrap();
|
||||||
|
buf.copy_from_host(bytes).unwrap();
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_top5(logits: &Tensor, label: &str, eagle: &Eagle3Head, tokenizer: &Tokenizer) {
|
||||||
|
use half::bf16;
|
||||||
|
let cpu = logits.to_device(Device::Cpu);
|
||||||
|
let data = cpu.as_slice::<bf16>();
|
||||||
|
let mut vals: Vec<(usize, f32)> = data
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, v)| (i, v.to_f32()))
|
||||||
|
.collect();
|
||||||
|
vals.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||||
|
println!("{label}:");
|
||||||
|
for (i, val) in vals.iter().take(5) {
|
||||||
|
let target_id = eagle.map_draft_to_target(*i as u32);
|
||||||
|
let text = tokenizer.decode(&[target_id]);
|
||||||
|
println!(
|
||||||
|
" draft_id={} target_id={} val={:.3} text={:?}",
|
||||||
|
i, target_id, val, text
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
312
crates/xserv-model/src/eagle3.rs
Normal file
312
crates/xserv-model/src/eagle3.rs
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
//! EAGLE3 speculative draft head for Qwen3-8B (Phase 25).
|
||||||
|
//!
|
||||||
|
//! Loads the AngelSlim/Qwen3-8B_eagle3 pytorch_model.bin and provides a
|
||||||
|
//! single-step forward pass that takes 3 target hidden states + the previous
|
||||||
|
//! token and returns a draft token in the target vocabulary.
|
||||||
|
//!
|
||||||
|
//! Architecture (from weights):
|
||||||
|
//! - fc: [hidden, 3*hidden] → fuse 3 target hidden states
|
||||||
|
//! - midlayer: 1 decoder layer (attn input dim = 2*hidden)
|
||||||
|
//! - norm + lm_head: → [draft_vocab_size=32000]
|
||||||
|
//! - d2t: draft_id → target_id offset mapping
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::Path;
|
||||||
|
use xserv_kernels::*;
|
||||||
|
use xserv_tensor::{DType, Device, Tensor};
|
||||||
|
|
||||||
|
pub const EAGLE_HOOK_LAYERS: [usize; 3] = [11, 23, 35];
|
||||||
|
const DRAFT_VOCAB_SIZE: usize = 32000;
|
||||||
|
|
||||||
|
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
|
||||||
|
assert_eq!(a.ndim(), 2);
|
||||||
|
assert_eq!(b.ndim(), 2);
|
||||||
|
matmul(a, b, GemmBackend::CuBlas)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Eagle3Head {
|
||||||
|
fc_wt: Tensor, // [hidden, 3*hidden] transposed for matmul
|
||||||
|
hidden_norm: Tensor, // [hidden]
|
||||||
|
input_layernorm: Tensor, // [hidden]
|
||||||
|
q_proj_wt: Tensor, // [num_heads*head_dim, 2*hidden]
|
||||||
|
k_proj_wt: Tensor, // [num_kv_heads*head_dim, 2*hidden]
|
||||||
|
v_proj_wt: Tensor, // [num_kv_heads*head_dim, 2*hidden]
|
||||||
|
o_proj_wt: Tensor, // [hidden, num_heads*head_dim]
|
||||||
|
gate_proj_wt: Tensor, // [intermediate, hidden]
|
||||||
|
up_proj_wt: Tensor, // [intermediate, hidden]
|
||||||
|
down_proj_wt: Tensor, // [hidden, intermediate]
|
||||||
|
post_attention_layernorm: Tensor, // [hidden]
|
||||||
|
norm: Tensor, // [hidden] final
|
||||||
|
lm_head_wt: Tensor, // [draft_vocab, hidden]
|
||||||
|
d2t: Vec<i64>, // [draft_vocab] offset mapping
|
||||||
|
hidden_size: usize,
|
||||||
|
num_heads: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
rope_cache: RopeCache,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Eagle3Head {
|
||||||
|
pub fn load(dir: &Path, device: u32) -> Self {
|
||||||
|
let (weights, d2t) = load_eagle3_weights(dir, device);
|
||||||
|
let hidden_size = 4096;
|
||||||
|
let num_heads = 32;
|
||||||
|
let num_kv_heads = 8;
|
||||||
|
let head_dim = 128;
|
||||||
|
let intermediate_size = 12288;
|
||||||
|
let max_seq_len = 2048;
|
||||||
|
let rope_theta = 1_000_000.0f32;
|
||||||
|
|
||||||
|
let get = |name: &str| -> Tensor {
|
||||||
|
weights
|
||||||
|
.get(name)
|
||||||
|
.unwrap_or_else(|| panic!("missing eagle3 weight: {name}"))
|
||||||
|
.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let fc_wt = get("fc.weight").transpose(0, 1).contiguous();
|
||||||
|
let q_proj_wt = get("midlayer.self_attn.q_proj.weight")
|
||||||
|
.transpose(0, 1)
|
||||||
|
.contiguous();
|
||||||
|
let k_proj_wt = get("midlayer.self_attn.k_proj.weight")
|
||||||
|
.transpose(0, 1)
|
||||||
|
.contiguous();
|
||||||
|
let v_proj_wt = get("midlayer.self_attn.v_proj.weight")
|
||||||
|
.transpose(0, 1)
|
||||||
|
.contiguous();
|
||||||
|
let o_proj_wt = get("midlayer.self_attn.o_proj.weight")
|
||||||
|
.transpose(0, 1)
|
||||||
|
.contiguous();
|
||||||
|
let gate_proj_wt = get("midlayer.mlp.gate_proj.weight")
|
||||||
|
.transpose(0, 1)
|
||||||
|
.contiguous();
|
||||||
|
let up_proj_wt = get("midlayer.mlp.up_proj.weight")
|
||||||
|
.transpose(0, 1)
|
||||||
|
.contiguous();
|
||||||
|
let down_proj_wt = get("midlayer.mlp.down_proj.weight")
|
||||||
|
.transpose(0, 1)
|
||||||
|
.contiguous();
|
||||||
|
let hidden_norm = get("midlayer.hidden_norm.weight");
|
||||||
|
let input_layernorm = get("midlayer.input_layernorm.weight");
|
||||||
|
let post_attention_layernorm = get("midlayer.post_attention_layernorm.weight");
|
||||||
|
let norm = get("norm.weight");
|
||||||
|
let lm_head_wt = get("lm_head.weight").transpose(0, 1).contiguous();
|
||||||
|
|
||||||
|
assert_eq!(d2t.len(), DRAFT_VOCAB_SIZE);
|
||||||
|
|
||||||
|
let rope_cache = RopeCache::new(max_seq_len, head_dim, rope_theta);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
fc_wt,
|
||||||
|
hidden_norm,
|
||||||
|
input_layernorm,
|
||||||
|
q_proj_wt,
|
||||||
|
k_proj_wt,
|
||||||
|
v_proj_wt,
|
||||||
|
o_proj_wt,
|
||||||
|
gate_proj_wt,
|
||||||
|
up_proj_wt,
|
||||||
|
down_proj_wt,
|
||||||
|
post_attention_layernorm,
|
||||||
|
norm,
|
||||||
|
lm_head_wt,
|
||||||
|
d2t,
|
||||||
|
hidden_size,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
rope_cache,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One draft step: produce a token in target vocabulary space.
|
||||||
|
///
|
||||||
|
/// - `target_hidden`: 3 tensors [1, hidden_size] from target hook layers
|
||||||
|
/// - `embed_table`: the target model's embed_tokens (shared, not copied)
|
||||||
|
/// - `prev_token`: the previous committed token
|
||||||
|
/// - `position`: the decode position for RoPE
|
||||||
|
///
|
||||||
|
/// Returns (draft_token_in_target_vocab, draft_logits_tensor).
|
||||||
|
pub fn step(
|
||||||
|
&self,
|
||||||
|
target_hidden: &[Tensor; 3],
|
||||||
|
embed_table: &Tensor,
|
||||||
|
prev_token: u32,
|
||||||
|
position: usize,
|
||||||
|
) -> (u32, Tensor) {
|
||||||
|
let eps = 1e-6f32;
|
||||||
|
|
||||||
|
// 1. Fuse target hidden states: concat [h_low, h_mid, h_high] → fc
|
||||||
|
let h_cat = concat_hidden(target_hidden);
|
||||||
|
let fused_h = matmul_2d(&h_cat, &self.fc_wt); // [1, hidden]
|
||||||
|
|
||||||
|
// 2. Embed previous token (shared with target)
|
||||||
|
let emb = embedding(embed_table, &[prev_token]); // [1, hidden]
|
||||||
|
|
||||||
|
// 3. Concat normalized: [norm(emb), norm(fused_h)] → [1, 2*hidden]
|
||||||
|
let emb_normed = rmsnorm(&emb, &self.input_layernorm, eps);
|
||||||
|
let h_normed = rmsnorm(&fused_h, &self.hidden_norm, eps);
|
||||||
|
let attn_in = concat_last_dim(&emb_normed, &h_normed); // [1, 8192]
|
||||||
|
|
||||||
|
// 4. Self-attention (no KV cache for simplicity in v0 — single query)
|
||||||
|
let q = matmul_2d(&attn_in, &self.q_proj_wt); // [1, num_heads*head_dim]
|
||||||
|
let k = matmul_2d(&attn_in, &self.k_proj_wt); // [1, num_kv*head_dim]
|
||||||
|
let v = matmul_2d(&attn_in, &self.v_proj_wt); // [1, num_kv*head_dim]
|
||||||
|
|
||||||
|
let q_3d = q.reshape(&[1, self.num_heads, self.head_dim]);
|
||||||
|
let k_3d = k.reshape(&[1, self.num_kv_heads, self.head_dim]);
|
||||||
|
let positions = [position as u32];
|
||||||
|
rope_inplace(&q_3d, &self.rope_cache, &positions);
|
||||||
|
rope_inplace(&k_3d, &self.rope_cache, &positions);
|
||||||
|
|
||||||
|
// Single-token attention: Q·K^T / sqrt(d) → softmax → V
|
||||||
|
// With seq_len=1, attention is trivial: output = V (weight=1.0)
|
||||||
|
let attn_out = v.reshape(&[1, self.num_kv_heads, self.head_dim]);
|
||||||
|
let attn_out = if self.num_heads != self.num_kv_heads {
|
||||||
|
repeat_kv_for_single_token(&attn_out, self.num_heads / self.num_kv_heads)
|
||||||
|
} else {
|
||||||
|
attn_out
|
||||||
|
};
|
||||||
|
let attn_merged = attn_out.reshape(&[1, self.num_heads * self.head_dim]);
|
||||||
|
let attn_proj = matmul_2d(&attn_merged, &self.o_proj_wt); // [1, hidden]
|
||||||
|
|
||||||
|
// Residual from embedding
|
||||||
|
let x = add(&attn_proj, &emb);
|
||||||
|
|
||||||
|
// 5. MLP
|
||||||
|
let normed = rmsnorm(&x, &self.post_attention_layernorm, eps);
|
||||||
|
let gate = matmul_2d(&normed, &self.gate_proj_wt);
|
||||||
|
let up = matmul_2d(&normed, &self.up_proj_wt);
|
||||||
|
let mlp_out = silu_mul(&gate, &up);
|
||||||
|
let down = matmul_2d(&mlp_out, &self.down_proj_wt);
|
||||||
|
let x = add(&x, &down);
|
||||||
|
|
||||||
|
// 6. Final norm + lm_head
|
||||||
|
let x = rmsnorm(&x, &self.norm, eps);
|
||||||
|
let logits = matmul_2d(&x, &self.lm_head_wt); // [1, 32000]
|
||||||
|
|
||||||
|
// 7. Argmax in draft vocab → map to target vocab
|
||||||
|
let draft_id = argmax_bf16_single(&logits);
|
||||||
|
let target_id = (draft_id as i64 + self.d2t[draft_id as usize]) as u32;
|
||||||
|
|
||||||
|
(target_id, logits)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map a draft-vocab token id to the full target-vocab id via d2t.
|
||||||
|
pub fn map_draft_to_target(&self, draft_id: u32) -> u32 {
|
||||||
|
(draft_id as i64 + self.d2t[draft_id as usize]) as u32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn d2d(dst: *mut u8, src: *const u8, bytes: usize) {
|
||||||
|
unsafe {
|
||||||
|
xserv_cuda::ffi::cudaMemcpy(dst, src, bytes, xserv_cuda::ffi::CUDA_MEMCPY_D2D);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn concat_hidden(hidden: &[Tensor; 3]) -> Tensor {
|
||||||
|
let h = hidden[0].shape()[1];
|
||||||
|
let dtype = hidden[0].dtype();
|
||||||
|
let device = hidden[0].device();
|
||||||
|
let elem_bytes = dtype.size_bytes();
|
||||||
|
let out = Tensor::empty(&[1, 3 * h], dtype, device);
|
||||||
|
for (i, t) in hidden.iter().enumerate() {
|
||||||
|
assert!(t.is_contiguous());
|
||||||
|
let dst = unsafe { (out.data_ptr() as *mut u8).add(i * h * elem_bytes) };
|
||||||
|
d2d(dst, t.data_ptr() as *const u8, h * elem_bytes);
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
fn concat_last_dim(a: &Tensor, b: &Tensor) -> Tensor {
|
||||||
|
let da = a.shape()[1];
|
||||||
|
let db = b.shape()[1];
|
||||||
|
let dtype = a.dtype();
|
||||||
|
let device = a.device();
|
||||||
|
let elem_bytes = dtype.size_bytes();
|
||||||
|
let out = Tensor::empty(&[1, da + db], dtype, device);
|
||||||
|
d2d(
|
||||||
|
out.data_ptr() as *mut u8,
|
||||||
|
a.data_ptr() as *const u8,
|
||||||
|
da * elem_bytes,
|
||||||
|
);
|
||||||
|
let dst = unsafe { (out.data_ptr() as *mut u8).add(da * elem_bytes) };
|
||||||
|
d2d(dst, b.data_ptr() as *const u8, db * elem_bytes);
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
fn repeat_kv_for_single_token(kv: &Tensor, repeats: usize) -> Tensor {
|
||||||
|
if repeats == 1 {
|
||||||
|
return kv.clone();
|
||||||
|
}
|
||||||
|
let nkv = kv.shape()[1];
|
||||||
|
let d = kv.shape()[2];
|
||||||
|
let dtype = kv.dtype();
|
||||||
|
let device = kv.device();
|
||||||
|
let head_bytes = d * dtype.size_bytes();
|
||||||
|
let out = Tensor::empty(&[1, nkv * repeats, d], dtype, device);
|
||||||
|
for h in 0..nkv {
|
||||||
|
let src = unsafe { (kv.data_ptr() as *const u8).add(h * head_bytes) };
|
||||||
|
for r in 0..repeats {
|
||||||
|
let dst = unsafe { (out.data_ptr() as *mut u8).add((h * repeats + r) * head_bytes) };
|
||||||
|
d2d(dst, src, head_bytes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load EAGLE3 weights from safetensors, handling int64 d2t specially.
|
||||||
|
fn load_eagle3_weights(dir: &Path, device: u32) -> (HashMap<String, Tensor>, Vec<i64>) {
|
||||||
|
let st_path = dir.join("model.safetensors");
|
||||||
|
assert!(
|
||||||
|
st_path.exists(),
|
||||||
|
"Eagle3 model.safetensors not found in {}. Convert with:\n\
|
||||||
|
python3 -c \"import torch; from safetensors.torch import save_file; \
|
||||||
|
sd=torch.load('pytorch_model.bin', map_location='cpu', weights_only=False); \
|
||||||
|
save_file(sd, 'model.safetensors')\"",
|
||||||
|
dir.display()
|
||||||
|
);
|
||||||
|
|
||||||
|
let data = std::fs::read(&st_path)
|
||||||
|
.unwrap_or_else(|e| panic!("failed to read {}: {e}", st_path.display()));
|
||||||
|
let st = safetensors::SafeTensors::deserialize(&data)
|
||||||
|
.unwrap_or_else(|e| panic!("failed to parse {}: {e}", st_path.display()));
|
||||||
|
|
||||||
|
let mut tensors = HashMap::new();
|
||||||
|
let mut d2t_vec: Vec<i64> = Vec::new();
|
||||||
|
|
||||||
|
for (name, view) in st.tensors() {
|
||||||
|
if name == "t2d" {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if name == "d2t" {
|
||||||
|
let raw = view.data();
|
||||||
|
assert_eq!(view.dtype(), safetensors::Dtype::I64);
|
||||||
|
let n = raw.len() / 8;
|
||||||
|
d2t_vec = (0..n)
|
||||||
|
.map(|i| i64::from_le_bytes(raw[i * 8..(i + 1) * 8].try_into().unwrap()))
|
||||||
|
.collect();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let dtype = match view.dtype() {
|
||||||
|
safetensors::Dtype::BF16 => DType::BF16,
|
||||||
|
safetensors::Dtype::F32 => DType::F32,
|
||||||
|
safetensors::Dtype::F16 => DType::F16,
|
||||||
|
other => {
|
||||||
|
eprintln!("eagle3: skipping {name} with unsupported dtype {other:?}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let shape: Vec<usize> = view.shape().to_vec();
|
||||||
|
let raw = view.data();
|
||||||
|
let t = crate::loader::make_tensor(raw, &shape, dtype);
|
||||||
|
let t = t.to_device(Device::Cuda(device));
|
||||||
|
tensors.insert(name.to_string(), t);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!d2t_vec.is_empty(),
|
||||||
|
"d2t tensor not found in eagle3 weights"
|
||||||
|
);
|
||||||
|
(tensors, d2t_vec)
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod decode_graph;
|
pub mod decode_graph;
|
||||||
|
pub mod eagle3;
|
||||||
pub mod gpt2;
|
pub mod gpt2;
|
||||||
pub mod gpt_oss;
|
pub mod gpt_oss;
|
||||||
pub mod gpt_oss_graph;
|
pub mod gpt_oss_graph;
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ pub fn load_model_dir(dir: &Path, device: Device) -> HashMap<String, Tensor> {
|
|||||||
all_tensors
|
all_tensors
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
|
pub(crate) fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
|
||||||
match dtype {
|
match dtype {
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
let floats: &[f32] = unsafe {
|
let floats: &[f32] = unsafe {
|
||||||
|
|||||||
@@ -825,6 +825,111 @@ impl Qwen3 {
|
|||||||
matmul_2d(&x, &self.lm_head_t)
|
matmul_2d(&x, &self.lm_head_t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Like `decode_core` but also captures hidden states at 3 specified layer
|
||||||
|
/// indices (after residual+MLP output). Used by EAGLE3 speculative drafting
|
||||||
|
/// to feed the draft head with low/mid/high target representations.
|
||||||
|
pub fn decode_core_with_hidden(
|
||||||
|
&self,
|
||||||
|
ids_gpu: *const std::ffi::c_void,
|
||||||
|
pos_gpu: *const std::ffi::c_void,
|
||||||
|
batch: usize,
|
||||||
|
seq_slots: &[usize],
|
||||||
|
paged_cache: &mut PagedKVCache,
|
||||||
|
hook_layers: &[usize; 3],
|
||||||
|
) -> (Tensor, [Tensor; 3]) {
|
||||||
|
let num_heads = self.local_num_heads;
|
||||||
|
let num_kv_heads = self.local_num_kv_heads;
|
||||||
|
let head_dim = self.config.head_dim();
|
||||||
|
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||||
|
|
||||||
|
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
|
||||||
|
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
|
||||||
|
let max_blocks = paged_cache.max_blocks_per_seq();
|
||||||
|
|
||||||
|
let mut x = embedding_device_ids(&self.embed_tokens, ids_gpu, batch);
|
||||||
|
let mut hooks: [Option<Tensor>; 3] = [None, None, None];
|
||||||
|
|
||||||
|
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||||
|
let residual = x.clone();
|
||||||
|
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||||
|
|
||||||
|
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
|
||||||
|
let q_dim = num_heads * head_dim;
|
||||||
|
let kv_dim = num_kv_heads * head_dim;
|
||||||
|
let q_all = qkv.narrow(1, 0, q_dim);
|
||||||
|
let k_all = qkv.narrow(1, q_dim, kv_dim);
|
||||||
|
let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);
|
||||||
|
|
||||||
|
let q_flat = q_all.contiguous().reshape(&[batch * num_heads, head_dim]);
|
||||||
|
let k_flat = k_all
|
||||||
|
.contiguous()
|
||||||
|
.reshape(&[batch * num_kv_heads, head_dim]);
|
||||||
|
let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
|
||||||
|
let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
|
||||||
|
|
||||||
|
let q_3d = q_normed.reshape(&[batch, num_heads, head_dim]);
|
||||||
|
let k_3d = k_normed.reshape(&[batch, num_kv_heads, head_dim]);
|
||||||
|
rope_inplace_device_pos(&q_3d, &self.rope_cache, pos_gpu);
|
||||||
|
rope_inplace_device_pos(&k_3d, &self.rope_cache, pos_gpu);
|
||||||
|
|
||||||
|
let v_3d = v_all.contiguous().reshape(&[batch, num_kv_heads, head_dim]);
|
||||||
|
|
||||||
|
paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, batch);
|
||||||
|
|
||||||
|
let q_4d = q_3d.reshape(&[batch, num_heads, 1, head_dim]);
|
||||||
|
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||||
|
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||||
|
let attn_out = xserv_kernels::paged_decode_attention(
|
||||||
|
&q_4d,
|
||||||
|
k_pool_ptr,
|
||||||
|
v_pool_ptr,
|
||||||
|
bt_ptr,
|
||||||
|
cl_ptr,
|
||||||
|
batch,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
max_blocks,
|
||||||
|
);
|
||||||
|
|
||||||
|
let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]);
|
||||||
|
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||||
|
self.all_reduce(&attn_proj);
|
||||||
|
|
||||||
|
let (normed, x_new) =
|
||||||
|
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||||
|
let residual = x_new.clone();
|
||||||
|
|
||||||
|
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
|
||||||
|
let ffn_dim = gate_up.shape()[1] / 2;
|
||||||
|
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
|
||||||
|
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
|
||||||
|
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||||
|
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||||
|
self.all_reduce(&down);
|
||||||
|
x = add_any(&residual, &down);
|
||||||
|
|
||||||
|
for (h_idx, &h_layer) in hook_layers.iter().enumerate() {
|
||||||
|
if layer_idx == h_layer {
|
||||||
|
hooks[h_idx] = Some(x.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for &slot in seq_slots {
|
||||||
|
paged_cache.advance_seq_len(slot, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
let x = rmsnorm(&x, &self.norm, eps);
|
||||||
|
let logits = matmul_2d(&x, &self.lm_head_t);
|
||||||
|
let hidden_arr = [
|
||||||
|
hooks[0].take().expect("hook layer 0 not reached"),
|
||||||
|
hooks[1].take().expect("hook layer 1 not reached"),
|
||||||
|
hooks[2].take().expect("hook layer 2 not reached"),
|
||||||
|
];
|
||||||
|
(logits, hidden_arr)
|
||||||
|
}
|
||||||
|
|
||||||
/// Paged prefill: write a sequence of `new_tokens` K/V into the paged
|
/// Paged prefill: write a sequence of `new_tokens` K/V into the paged
|
||||||
/// cache for `slot`, run flash attention via gathered contiguous K/V.
|
/// cache for `slot`, run flash attention via gathered contiguous K/V.
|
||||||
/// Returns logits [new_tokens, vocab_size].
|
/// Returns logits [new_tokens, vocab_size].
|
||||||
@@ -1074,6 +1179,12 @@ impl Qwen3 {
|
|||||||
matmul_2d(&x, &self.lm_head_t)
|
matmul_2d(&x, &self.lm_head_t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Reference to the target's token embedding table. Shared (not copied)
|
||||||
|
/// with speculative draft heads like EAGLE3.
|
||||||
|
pub fn embed_tokens_tensor(&self) -> &Tensor {
|
||||||
|
&self.embed_tokens
|
||||||
|
}
|
||||||
|
|
||||||
/// Extract weight pointers for CUDA Graph capture.
|
/// Extract weight pointers for CUDA Graph capture.
|
||||||
pub fn layer_weight_ptrs(&self) -> Vec<crate::decode_graph::LayerWeightPtrs> {
|
pub fn layer_weight_ptrs(&self) -> Vec<crate::decode_graph::LayerWeightPtrs> {
|
||||||
self.layers
|
self.layers
|
||||||
|
|||||||
Reference in New Issue
Block a user