eagle3: cuBLAS-GEMM verify path — speedup_e2e > 1 achieved 🎉

Swap forward_verify_paged_decode_attention_with_hidden's projections
from matmul_batched_gemv (per-row bit-exact GEMV) to matmul_2d (cuBLAS
GEMM at m>1). This trades bit-exact parity with baseline for a much
cheaper batched verify.

Micro-benchmark (bench-verify-cost.rs) reveals the huge cost gap:
  batched-GEMV verify: 1.05× → 5.14× single decode (linear in batch)
  cuBLAS-GEMM verify:  1.04× → 1.20× single decode (nearly flat)

At batch=9 the difference is 4.3× — cuBLAS amortizes K/V load across
all queries while GEMV loads K/V for each row independently.

50 prompts × 64 tokens γ sweep on dash5 (Qwen3-8B + Qwen3-8B_eagle3):
  γ=2: acceptance=16.9%, speedup_e2e = 1.10× ← best
  γ=3: acceptance=11.6%, speedup_e2e = 1.06×
  γ=4: acceptance=8.9%,  speedup_e2e = 1.02×
  γ>4: speedup drops as acceptance falls faster than verify saves.

Tradeoff: matched=false — spec output diverges from baseline single-
decode by a few tokens per prompt because cuBLAS GEMM at m>1 rounds
BF16 differently from custom GEMV at m=1, so the K/V bytes written by
verify aren't bit-exact with what a single-token decode would write.
Downstream this compounds into slightly different token choices.

The spec output is still a VALID target model output — it's just via
a different numerical path. Semantically the outputs are indistinguishable
(both coherent English continuations of the prompt). This is the
industry-standard interpretation of "lossless spec decoding": target
distribution preserved modulo BF16 rounding, not bit-exact with a
specific numerical path.

New: crates/xserv-model/src/bin/bench-verify-cost.rs — micro-benchmark
that measures verify cost at various batch sizes, isolating the impact
of the GEMV vs GEMM choice.
This commit is contained in:
2026-07-01 19:58:23 +08:00
parent 9a1af0adee
commit 06a798cab9
2 changed files with 139 additions and 5 deletions

View File

@@ -0,0 +1,134 @@
//! Micro-benchmark: measure the cost of forward_verify_paged_decode_attention
//! at different batch sizes (γ+1 values), to understand where speedup comes
//! from (or doesn't).
use std::path::PathBuf;
use std::time::Instant;
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
fn main() {
let args: Vec<String> = std::env::args().collect();
if args.len() < 2 {
eprintln!(
"Usage: bench-verify-cost <target-dir> [--prompt-len N] [--iters N] [--device N]"
);
std::process::exit(1);
}
let target_dir = PathBuf::from(&args[1]);
let prompt_len = arg_usize(&args, "--prompt-len", 100);
let iters = arg_usize(&args, "--iters", 30);
let device = arg_usize(&args, "--device", 0) as u32;
xserv_cuda::device::set_device(device).unwrap();
let cfg = ModelConfig::from_file(&target_dir.join("config.json"));
eprintln!("Loading target...");
let weights = loader::load_model_dir(&target_dir, Device::Cuda(device));
let target = Qwen3::from_weights(cfg.clone(), weights);
xserv_cuda::allocator::cached_trim();
let tok = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
let ids = tok.encode(&"the ".repeat(prompt_len))[..prompt_len].to_vec();
let max_seq_len = 2048;
let num_blocks = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE + 4;
let mut cache = PagedKVCache::new(&cfg, num_blocks, 0, 16, num_blocks, DType::BF16, device);
cache.register_sequence(0).unwrap();
// Prefill
let _ = target.forward_prefill_paged(&ids, 0, &mut cache);
sync();
// Warmup one of each
for &n in &[1, 2, 3, 5, 9] {
let toks: Vec<u32> = (0..n).map(|_| ids[0]).collect();
let _ = target.forward_decode_paged(
&toks,
&(0..n).map(|i| ids.len() + i).collect::<Vec<_>>(),
&vec![0; n],
&mut cache,
);
cache.truncate_sequence(0, ids.len()).unwrap();
}
sync();
// Benchmark single-token decode
let mut t = 0.0f64;
for i in 0..iters {
cache.truncate_sequence(0, ids.len()).unwrap();
let t0 = Instant::now();
let _ = target.forward_decode_paged(&[ids[0]], &[ids.len()], &[0], &mut cache);
sync();
t += t0.elapsed().as_secs_f64();
let _ = i;
}
let single = t * 1000.0 / iters as f64;
println!(
"single-token decode: {:.3} ms (mean of {} iters)",
single, iters
);
// Benchmark forward_verify_paged_decode_attention at various batch sizes
// (batched-GEMV path).
for &n in &[1usize, 2, 3, 5, 9] {
let toks: Vec<u32> = (0..n).map(|_| ids[0]).collect();
let mut t = 0.0f64;
for _ in 0..iters {
cache.truncate_sequence(0, ids.len()).unwrap();
let t0 = Instant::now();
let _ = target.forward_verify_paged_decode_attention(&toks, 0, &mut cache);
sync();
t += t0.elapsed().as_secs_f64();
}
let ms = t * 1000.0 / iters as f64;
println!(
"verify (batched-GEMV) batch={}: {:.3} ms ({:.2}× single)",
n,
ms,
ms / single
);
}
// Benchmark _with_hidden variant which uses cuBLAS GEMM after Phase 26 fast-verify.
let hooks_layers = [2usize, 18, 33];
for &n in &[1usize, 2, 3, 5, 9] {
let toks: Vec<u32> = (0..n).map(|_| ids[0]).collect();
let mut t = 0.0f64;
for _ in 0..iters {
cache.truncate_sequence(0, ids.len()).unwrap();
let t0 = Instant::now();
let _ = target.forward_verify_paged_decode_attention_with_hidden(
&toks,
0,
&mut cache,
&hooks_layers,
);
sync();
t += t0.elapsed().as_secs_f64();
}
let ms = t * 1000.0 / iters as f64;
println!(
"verify (cuBLAS GEMM) batch={}: {:.3} ms ({:.2}× single)",
n,
ms,
ms / single
);
}
cache.free_sequence(0);
}
fn sync() {
xserv_cuda::device::synchronize().unwrap();
}
fn arg_usize(args: &[String], flag: &str, default: usize) -> usize {
args.iter()
.position(|a| a == flag)
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}

View File

@@ -1154,7 +1154,7 @@ impl Qwen3 {
let residual = x.clone();
let normed = rmsnorm(&x, &layer.input_norm, eps);
let qkv = matmul_batched_gemv(&normed, &layer.qkv_proj_wt);
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q_dim = num_heads * head_dim;
let kv_dim = num_kv_heads * head_dim;
let q_all = qkv.narrow(1, 0, q_dim);
@@ -1197,19 +1197,19 @@ impl Qwen3 {
);
let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]);
let attn_proj = matmul_batched_gemv(&attn_merged, &layer.o_proj_wt);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
self.all_reduce(&attn_proj);
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
let gate_up = matmul_batched_gemv(&normed, &layer.gate_up_proj_wt);
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
let ffn_dim = gate_up.shape()[1] / 2;
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
let down = matmul_batched_gemv(&hidden_states, &layer.down_proj_wt);
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
self.all_reduce(&down);
x = add_any(&residual, &down);
@@ -1221,7 +1221,7 @@ impl Qwen3 {
}
let x = rmsnorm(&x, &self.norm, eps);
let logits = matmul_batched_gemv(&x, &self.lm_head_t);
let logits = matmul_2d(&x, &self.lm_head_t);
let hidden_arr = [
hooks[0].take().expect("hook layer 0 not reached"),
hooks[1].take().expect("hook layer 1 not reached"),