speculative: Qwen3 draft-model v0 with paged verify parity

Phase 22 lands a correctness-only speculative decoding loop for Qwen3
target + Qwen3 small draft (batch=1, greedy, gamma=4). Phase 23 turns
verify logits into the authoritative acceptance signal so mirror-decode
per accepted token is no longer needed.

- paged_kv_cache: truncate_sequence(slot, new_len) shrinks a registered
  sequence, freeing whole physical blocks no longer reachable and
  leaving the slot registered. Covered by a CUDA-gated unit test.
- qwen3: forward_verify_paged_decode_attention writes the draft window
  into the target cache, runs the same paged decode attention kernel per
  draft token, and uses matmul_rows_gemv so linear layers follow the
  single-token decode BF16 rounding path.
- bench-speculative: new bench binary drives the state machine with
  --gamma / --gen-tokens / --prompts / --use-verify-logits /
  --verify-path flash|paged-decode / --dump-verify-mismatches, and
  compares baseline vs spec token sequences plus TPOT / tok/s / speedup.
- docs/22 records the decode-authoritative v0 result and dash5 numbers
  (matched=true, speedup_e2e ~0.29x, verify_decode_mismatches>0 under
  --use-verify-logits).
- docs/23 records the paged-decode verify path (matched=true,
  verify_decode_mismatches=0, 50x64 speedup_e2e ~0.44x) and the
  next-step performance TODO.
This commit is contained in:
2026-07-01 14:15:39 +08:00
parent 5b350ee5f0
commit ce7229f4fe
5 changed files with 1444 additions and 0 deletions

View File

@@ -0,0 +1,962 @@
//! Draft-model speculative decoding benchmark for Qwen3.
//!
//! v0 scope:
//! - target + draft are Qwen3-family models with the same tokenizer/vocab;
//! - batch=1;
//! - greedy exact-match acceptance;
//! - no probabilistic rejection sampling.
use half::bf16;
use std::path::{Path, PathBuf};
use std::time::Instant;
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
use xserv_tensor::{DType, Device, Tensor};
use xserv_tokenizer::Tokenizer;
/// Draft window size (tokens proposed per round) used when `--gamma` is absent.
const DEFAULT_GAMMA: usize = 4;
/// Per-prompt generation budget used when `--gen-tokens` is absent.
const DEFAULT_GEN_TOKENS: usize = 64;
/// Value of `--max-seq-len` when the flag is absent; it sizes every paged KV cache
/// and bounds `prompt_len + gen_tokens`.
const DEFAULT_MAX_SEQ_LEN: usize = 2048;
/// Target-model forward pass used to score a draft window, as selected by
/// `--verify-path`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum VerifyPath {
    /// Score the window with `forward_prefill_paged`.
    Flash,
    /// Score the window with `forward_verify_paged_decode_attention`.
    PagedDecode,
}

impl VerifyPath {
    /// Returns the command-line spelling of this variant, which is also the
    /// text printed in the benchmark report.
    fn as_str(self) -> &'static str {
        match self {
            Self::PagedDecode => "paged-decode",
            Self::Flash => "flash",
        }
    }
}
/// Fixed benchmark prompt set. `--prompts N` runs the first `N` entries, capped at
/// the length of this array, so results are comparable between runs.
const PROMPTS: [&str; 50] = [
"The capital of France is",
"Once upon a time in a land far away",
"Hello, how are you doing today",
"In a shocking finding, scientists discovered a",
"The weather today is sunny, so I decided to",
"Alan Turing was a British mathematician who",
"The best way to learn programming is",
"Artificial intelligence will change the world because",
"The history of the internet began in the",
"A good morning routine starts with",
"The stock market crashed because investors",
"Deep learning is a subset of machine learning that",
"The president of the United States announced",
"In the year 2050, humans will",
"The secret to happiness is",
"When I was a child, I used to",
"The most important scientific discovery of the century",
"Climate change is caused by",
"The recipe for chocolate cake requires",
"In conclusion, the evidence suggests that",
"The cat sat on the mat and",
"According to recent studies, exercise can",
"The first step in solving any problem is",
"Technology has transformed the way we",
"The novel begins with the protagonist",
"Education is the most powerful weapon",
"The ocean covers more than seventy percent of",
"Last night I had a dream about",
"The company announced its quarterly earnings",
"Music has the power to",
"The difference between success and failure is",
"In the beginning, there was nothing but",
"The doctor told me that I should",
"Python is a popular programming language because",
"The ancient Romans built roads that",
"A balanced diet should include",
"The movie received mixed reviews from critics",
"Space exploration has led to many",
"The teacher asked the students to",
"Global warming is one of the most",
"The bridge collapsed due to structural",
"Quantum computing promises to revolutionize",
"The new policy will affect millions of",
"During the winter months, it is important to",
"The human brain contains approximately",
"Democracy depends on the active participation of",
"The train arrived at the station exactly",
"Researchers at MIT have developed a new",
"The smartphone has become an essential part of",
"After careful consideration, the committee decided to",
];
/// Result of generating for a single prompt, produced by both `run_baseline` and
/// `run_speculative`. The baseline leaves every speculative-only counter at its
/// `Default` value (zero).
#[derive(Default)]
struct RunStats {
/// Generated token ids, in emission order (the prompt is not included).
ids: Vec<u32>,
/// Wall-clock seconds from just before prefill to after the final device sync.
total_s: f64,
/// Wall-clock seconds spent in prompt prefill, including the trailing device sync.
prefill_s: f64,
/// Wall-clock seconds spent in the generation loop, including the trailing device sync.
decode_s: f64,
/// Baseline: number of target decode calls. Speculative: the sum
/// `verify_steps + mirror_steps + commit_steps + correction_steps`.
target_steps: usize,
/// Draft tokens that were accepted.
accepted: usize,
/// Draft tokens that were proposed.
proposed: usize,
/// Number of verify forward passes over a draft window.
verify_steps: usize,
/// Number of single-token decodes used to advance the separate target verify cache.
mirror_steps: usize,
/// Number of single-token target decodes on the main target cache.
commit_steps: usize,
/// Number of times the correction token was decoded into the draft cache.
correction_steps: usize,
/// How often the verify argmax differed from the argmax of the matching target decode.
verify_decode_mismatches: usize,
}
/// Sums of the per-prompt `RunStats` over the whole benchmark run; `main` derives
/// the summary report from these. `baseline_*` fields come from `run_baseline`,
/// `spec_*` fields from `run_speculative`.
#[derive(Default)]
struct Totals {
/// Number of prompts that were benchmarked.
prompts: usize,
/// Total tokens generated by the baseline runs.
baseline_generated: usize,
/// Total tokens generated by the speculative runs.
spec_generated: usize,
baseline_total_s: f64,
baseline_prefill_s: f64,
baseline_decode_s: f64,
spec_total_s: f64,
spec_prefill_s: f64,
spec_decode_s: f64,
spec_target_steps: usize,
spec_accepted: usize,
spec_proposed: usize,
spec_verify_steps: usize,
spec_mirror_steps: usize,
spec_commit_steps: usize,
spec_correction_steps: usize,
spec_verify_decode_mismatches: usize,
/// Number of prompts whose speculative token sequence differed from the baseline.
mismatches: usize,
}
/// Benchmark entry point.
///
/// Parses the command line, loads the target and draft Qwen3 models onto one CUDA
/// device, runs a short warmup of both code paths, then for each selected prompt
/// runs the plain target baseline followed by the speculative loop and compares
/// the generated token ids. Prints one line per prompt plus a summary to stdout
/// (progress and warnings go to stderr).
///
/// Exit codes: 1 for a usage error, 2 when any prompt mismatched (or, under
/// `--use-verify-logits`, when any verify/decode mismatch was counted).
fn main() {
let args: Vec<String> = std::env::args().collect();
if args.len() < 3 {
eprintln!(
"Usage: bench-speculative <target-model-dir> <draft-model-dir> \
[--gen-tokens N] [--gamma N] [--prompts N] [--max-seq-len N] [--device N] \
[--use-verify-logits] [--verify-path flash|paged-decode] [--dump-verify-mismatches]"
);
std::process::exit(1);
}
// Positional arguments first, then optional flags looked up anywhere in argv.
let target_dir = PathBuf::from(&args[1]);
let draft_dir = PathBuf::from(&args[2]);
let gen_tokens = arg_usize(&args, "--gen-tokens", DEFAULT_GEN_TOKENS);
let gamma = arg_usize(&args, "--gamma", DEFAULT_GAMMA);
// Requests for more prompts than exist are clamped to the built-in set.
let prompt_count = arg_usize(&args, "--prompts", PROMPTS.len()).min(PROMPTS.len());
let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN);
// NOTE(review): `as u32` truncates silently for values above u32::MAX.
let device = arg_usize(&args, "--device", 0) as u32;
let use_verify_logits = args.iter().any(|a| a == "--use-verify-logits");
let verify_path = parse_verify_path(&args, use_verify_logits);
let dump_verify_mismatches = args.iter().any(|a| a == "--dump-verify-mismatches");
assert!(gen_tokens > 0, "--gen-tokens must be > 0");
assert!(gamma > 0, "--gamma must be > 0");
xserv_cuda::device::set_device(device).unwrap();
let info = xserv_cuda::device::device_info(device).unwrap();
eprintln!(
"GPU {device}: {} ({} MB free)",
info.name,
info.free_memory / 1024 / 1024
);
// Both models must be Qwen-family and agree on vocabulary size, because draft
// token ids are compared directly against target token ids.
let target_config = ModelConfig::from_file(&target_dir.join("config.json"));
let draft_config = ModelConfig::from_file(&draft_dir.join("config.json"));
assert_qwen3(&target_config, "target");
assert_qwen3(&draft_config, "draft");
assert_eq!(
target_config.vocab_size, draft_config.vocab_size,
"target and draft vocab_size must match"
);
warn_if_tokenizers_differ(&target_dir, &draft_dir);
// Only the target's tokenizer is loaded; it is used for both models.
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
if tokenizer.vocab_size() != target_config.vocab_size {
eprintln!(
"WARNING: tokenizer decoder len {} differs from config vocab_size {}; continuing because token ids come from the shared tokenizer.json",
tokenizer.vocab_size(),
target_config.vocab_size
);
}
eprintln!(
"Loading target Qwen3: layers={} hidden={} heads={}/{} vocab={}",
target_config.num_layers(),
target_config.hidden(),
target_config.num_heads(),
target_config.num_kv_heads(),
target_config.vocab_size
);
let target_weights = loader::load_model_dir(&target_dir, Device::Cuda(device));
let target = Qwen3::from_weights(target_config.clone(), target_weights);
xserv_cuda::allocator::cached_trim();
eprintln!(
"Loading draft Qwen3: layers={} hidden={} heads={}/{} vocab={}",
draft_config.num_layers(),
draft_config.hidden(),
draft_config.num_heads(),
draft_config.num_kv_heads(),
draft_config.vocab_size
);
let draft_weights = loader::load_model_dir(&draft_dir, Device::Cuda(device));
let draft = Qwen3::from_weights(draft_config.clone(), draft_weights);
xserv_cuda::allocator::cached_trim();
// Warmup: run both the baseline and the speculative path once on a tiny prompt
// with at most 4 generated tokens. The results are discarded.
let warm_ids = tokenizer.encode("warmup");
let warm_tokens = gen_tokens.min(4);
{
let mut target_cache = new_cache(&target_config, max_seq_len, device);
let _ = run_baseline(
&target,
&mut target_cache,
&tokenizer,
&warm_ids,
warm_tokens,
);
}
{
// In verify-logits mode the main target cache itself receives the multi-row
// verify pass, so it is built with `gamma` rows instead of 1.
let mut target_cache = new_cache_with_rows(
&target_config,
max_seq_len,
device,
if use_verify_logits { gamma } else { 1 },
);
let mut target_verify_cache =
new_cache_with_rows(&target_config, max_seq_len, device, gamma);
let mut draft_cache = new_cache(&draft_config, max_seq_len, device);
let _ = run_speculative(
&target,
&draft,
&mut target_cache,
&mut target_verify_cache,
&mut draft_cache,
&tokenizer,
&warm_ids,
warm_tokens,
gamma,
use_verify_logits,
verify_path,
dump_verify_mismatches,
);
}
eprintln!(
"Warmup done. Running {prompt_count} prompts, gen_tokens={gen_tokens}, gamma={gamma}, acceptance_mode={}, verify_path={}",
if use_verify_logits {
"verify_logits"
} else {
"decode"
},
verify_path.as_str()
);
let mut totals = Totals::default();
for (i, prompt) in PROMPTS.iter().take(prompt_count).enumerate() {
let ids = tokenizer.encode(prompt);
// Exits with code 2 if prompt + generation would not fit in the cache.
validate_length_budget(&ids, gen_tokens, max_seq_len, prompt);
// Fresh caches per prompt; the baseline cache is dropped before the
// speculative caches are allocated.
let mut baseline_cache = new_cache(&target_config, max_seq_len, device);
let baseline = run_baseline(&target, &mut baseline_cache, &tokenizer, &ids, gen_tokens);
drop(baseline_cache);
let mut target_cache = new_cache_with_rows(
&target_config,
max_seq_len,
device,
if use_verify_logits { gamma } else { 1 },
);
let mut target_verify_cache =
new_cache_with_rows(&target_config, max_seq_len, device, gamma);
let mut draft_cache = new_cache(&draft_config, max_seq_len, device);
let spec = run_speculative(
&target,
&draft,
&mut target_cache,
&mut target_verify_cache,
&mut draft_cache,
&tokenizer,
&ids,
gen_tokens,
gamma,
use_verify_logits,
verify_path,
dump_verify_mismatches,
);
// Correctness criterion: speculative output must equal the baseline token for token.
let matched = baseline.ids == spec.ids;
if !matched {
totals.mismatches += 1;
eprintln!("MISMATCH prompt {i}: {prompt}");
eprintln!(" baseline: {:?}", baseline.ids);
eprintln!(" spec: {:?}", spec.ids);
}
println!(
"prompt={:02} match={} gen={} accept={}/{} target_steps={} \
baseline_e2e_tpot_ms={:.3} spec_e2e_tpot_ms={:.3}",
i,
matched,
spec.ids.len(),
spec.accepted,
spec.proposed,
spec.target_steps,
per_token_ms(baseline.total_s, baseline.ids.len()),
per_token_ms(spec.total_s, spec.ids.len()),
);
totals.prompts += 1;
totals.baseline_generated += baseline.ids.len();
totals.spec_generated += spec.ids.len();
totals.baseline_total_s += baseline.total_s;
totals.baseline_prefill_s += baseline.prefill_s;
totals.baseline_decode_s += baseline.decode_s;
totals.spec_total_s += spec.total_s;
totals.spec_prefill_s += spec.prefill_s;
totals.spec_decode_s += spec.decode_s;
totals.spec_target_steps += spec.target_steps;
totals.spec_accepted += spec.accepted;
totals.spec_proposed += spec.proposed;
totals.spec_verify_steps += spec.verify_steps;
totals.spec_mirror_steps += spec.mirror_steps;
totals.spec_commit_steps += spec.commit_steps;
totals.spec_correction_steps += spec.correction_steps;
totals.spec_verify_decode_mismatches += spec.verify_decode_mismatches;
}
// The decode-phase metrics divide by the full generated-token count.
// NOTE(review): in the baseline the first token comes out of prefill, so it is
// counted here although its time is in `prefill_s` — confirm this is intended.
let baseline_decode_tokens = totals.baseline_generated;
let spec_decode_tokens = totals.spec_generated;
let acceptance = ratio(totals.spec_accepted, totals.spec_proposed);
let tokens_per_target_step = ratio(totals.spec_generated, totals.spec_target_steps);
// Verify/decode mismatches only fail the run when verify logits are authoritative.
let matched =
totals.mismatches == 0 && (!use_verify_logits || totals.spec_verify_decode_mismatches == 0);
println!("--- SUMMARY ---");
println!("prompts={} matched={matched}", totals.prompts);
println!(
"acceptance_mode={}",
if use_verify_logits {
"verify_logits"
} else {
"decode"
}
);
println!("verify_path={}", verify_path.as_str());
println!(
"acceptance_rate={:.4} accepted={} proposed={}",
acceptance, totals.spec_accepted, totals.spec_proposed
);
println!(
"tokens_per_target_step={:.4} target_steps={} verify_steps={} mirror_decode_steps={} commit_decode_steps={} correction_steps={}",
tokens_per_target_step,
totals.spec_target_steps,
totals.spec_verify_steps,
totals.spec_mirror_steps,
totals.spec_commit_steps,
totals.spec_correction_steps
);
println!(
"verify_decode_mismatches={}",
totals.spec_verify_decode_mismatches
);
println!(
"baseline_e2e_tpot_ms={:.3} baseline_e2e_tok_s={:.3}",
per_token_ms(totals.baseline_total_s, totals.baseline_generated),
tok_s(totals.baseline_generated, totals.baseline_total_s)
);
println!(
"spec_e2e_tpot_ms={:.3} spec_e2e_tok_s={:.3} speedup_e2e={:.4}",
per_token_ms(totals.spec_total_s, totals.spec_generated),
tok_s(totals.spec_generated, totals.spec_total_s),
speedup(totals.baseline_total_s, totals.spec_total_s)
);
println!(
"baseline_decode_tpot_ms={:.3} baseline_decode_tok_s={:.3}",
per_token_ms(totals.baseline_decode_s, baseline_decode_tokens),
tok_s(baseline_decode_tokens, totals.baseline_decode_s)
);
println!(
"spec_decode_tpot_ms={:.3} spec_decode_tok_s={:.3} speedup_decode={:.4}",
per_token_ms(totals.spec_decode_s, spec_decode_tokens),
tok_s(spec_decode_tokens, totals.spec_decode_s),
speedup(totals.baseline_decode_s, totals.spec_decode_s)
);
println!(
"decode_token_counts baseline={} spec={}",
baseline_decode_tokens, spec_decode_tokens
);
// A non-zero exit lets scripts detect a parity failure.
if !matched {
std::process::exit(2);
}
}
/// Generates up to `gen_tokens` tokens greedily with the target model alone.
///
/// The prompt is prefilled into `cache` at slot 0, the first token is taken from
/// the prefill logits, and further tokens are produced one decode call at a
/// time until the budget is reached or an EOS token has been emitted (the EOS
/// token itself is kept in the output). The slot is freed before returning.
///
/// Only `ids`, the three timings and `target_steps` are filled in; all other
/// `RunStats` counters stay at their default.
fn run_baseline(
    model: &Qwen3,
    cache: &mut PagedKVCache,
    tokenizer: &Tokenizer,
    prompt_ids: &[u32],
    gen_tokens: usize,
) -> RunStats {
    let slot = 0;
    cache.register_sequence(slot).expect("register target slot");

    // Prefill phase, timed up to a device sync so queued GPU work is included.
    let run_timer = Instant::now();
    let prefill_timer = Instant::now();
    let prefill_logits = model.forward_prefill_paged(prompt_ids, slot, cache);
    sync_device();
    let prefill_s = prefill_timer.elapsed().as_secs_f64();

    let mut ids = Vec::with_capacity(gen_tokens);
    let mut current = last_argmax(&prefill_logits);
    ids.push(current);

    // Decode phase: feed the latest token back in until done.
    let decode_timer = Instant::now();
    let mut target_steps = 0usize;
    loop {
        if ids.len() >= gen_tokens || tokenizer.is_eos(current) {
            break;
        }
        let position = cache.seq_len(slot);
        let step_logits = model.forward_decode_paged(&[current], &[position], &[slot], cache);
        target_steps += 1;
        current = last_argmax(&step_logits);
        ids.push(current);
    }
    sync_device();
    let decode_s = decode_timer.elapsed().as_secs_f64();
    sync_device();
    let total_s = run_timer.elapsed().as_secs_f64();

    cache.free_sequence(slot);
    RunStats {
        ids,
        total_s,
        prefill_s,
        decode_s,
        target_steps,
        ..Default::default()
    }
}
/// Runs greedy draft-model speculative decoding for one prompt and returns its
/// statistics.
///
/// Each round the draft model proposes up to `gamma` tokens (fewer near the end
/// of the budget, or when it proposes EOS), the target model scores the whole
/// window in one verify pass, and the longest prefix that matches the target's
/// greedy choice is accepted. On the first disagreement the target's own token
/// is emitted as a correction, and the draft cache is rolled back to the round
/// start and rebuilt from the accepted prefix.
///
/// Two acceptance modes exist:
/// - `use_verify_logits == true`: the verify logits decide acceptance. The
///   window is written straight into `target_cache`, which is then rolled back
///   to the accepted length. `verify_path` is not consulted, and
///   `target_verify_cache` is only registered and freed.
/// - `use_verify_logits == false`: every accepted token is re-decoded on
///   `target_cache` and that decode result is authoritative. The verify pass
///   runs on `target_verify_cache` through `verify_path`; its argmax is only
///   compared with the decode argmax to count `verify_decode_mismatches`
///   (optionally dumping top-5 logits when `dump_verify_mismatches` is set).
///
/// All three caches use slot 0 and are freed before returning.
///
/// # Panics
/// Panics if slot registration or a cache rollback fails, if the cache lengths
/// diverge at the start of a round, or if the verify pass does not return one
/// logits row per draft token.
// One benchmark entry point has to receive both models, three caches and every
// CLI switch; bundling them into a struct would not make the call sites clearer.
#[allow(clippy::too_many_arguments)]
fn run_speculative(
    target: &Qwen3,
    draft: &Qwen3,
    target_cache: &mut PagedKVCache,
    target_verify_cache: &mut PagedKVCache,
    draft_cache: &mut PagedKVCache,
    tokenizer: &Tokenizer,
    prompt_ids: &[u32],
    gen_tokens: usize,
    gamma: usize,
    use_verify_logits: bool,
    verify_path: VerifyPath,
    dump_verify_mismatches: bool,
) -> RunStats {
    let slot = 0;
    target_cache
        .register_sequence(slot)
        .expect("register target slot");
    target_verify_cache
        .register_sequence(slot)
        .expect("register target verify slot");
    draft_cache
        .register_sequence(slot)
        .expect("register draft slot");

    let t0 = Instant::now();
    let prefill_start = Instant::now();
    let target_logits = target.forward_prefill_paged(prompt_ids, slot, target_cache);
    if !use_verify_logits {
        // The separate verify cache is only used in decode-authoritative mode, so
        // it only needs the prompt there.
        let _ = target.forward_prefill_paged(prompt_ids, slot, target_verify_cache);
    }
    let draft_logits = draft.forward_prefill_paged(prompt_ids, slot, draft_cache);
    sync_device();
    let prefill_s = prefill_start.elapsed().as_secs_f64();

    // `target_next` is the target's greedy choice for the next position;
    // `draft_next` is the draft's proposal for it.
    let mut target_next = last_argmax(&target_logits);
    let mut draft_next = last_argmax(&draft_logits);
    let mut generated = Vec::with_capacity(gen_tokens);
    let mut accepted_total = 0usize;
    let mut proposed_total = 0usize;
    let mut verify_steps = 0usize;
    let mut mirror_steps = 0usize;
    let mut commit_steps = 0usize;
    let mut correction_steps = 0usize;
    let mut verify_decode_mismatches = 0usize;

    let decode_start = Instant::now();
    while generated.len() < gen_tokens {
        let remaining = gen_tokens - generated.len();
        let round_gamma = gamma.min(remaining);
        let round_start_len = target_cache.seq_len(slot);
        assert_eq!(
            round_start_len,
            draft_cache.seq_len(slot),
            "target and draft cache lengths diverged"
        );
        if !use_verify_logits {
            assert_eq!(
                round_start_len,
                target_verify_cache.seq_len(slot),
                "target verify cache length diverged"
            );
        }

        // Draft phase. A proposed EOS ends the window early and is NOT decoded
        // into the draft cache, so after such a round the draft cache holds one
        // token fewer than the number of proposed tokens.
        let mut draft_tokens = Vec::with_capacity(round_gamma);
        for _ in 0..round_gamma {
            let token = draft_next;
            draft_tokens.push(token);
            if tokenizer.is_eos(token) {
                break;
            }
            let pos = draft_cache.seq_len(slot);
            let logits = draft.forward_decode_paged(&[token], &[pos], &[slot], draft_cache);
            draft_next = last_argmax(&logits);
        }
        proposed_total += draft_tokens.len();

        if use_verify_logits {
            // ---- verify-logits-authoritative round ----
            verify_steps += 1;
            let verify_logits =
                target.forward_verify_paged_decode_attention(&draft_tokens, slot, target_cache);
            let verify_argmax = argmax_rows(&verify_logits);
            assert_eq!(
                verify_argmax.len(),
                draft_tokens.len(),
                "verify logits rows must match draft token count"
            );

            let mut accepted = 0usize;
            let mut done = false;
            while accepted < draft_tokens.len() {
                // Row i-1 of the verify logits predicts window position i;
                // position 0 is predicted by the previous round.
                let expected = if accepted > 0 {
                    verify_argmax[accepted - 1]
                } else {
                    target_next
                };
                if draft_tokens[accepted] != expected {
                    break;
                }
                let token = draft_tokens[accepted];
                generated.push(token);
                accepted_total += 1;
                accepted += 1;
                if generated.len() >= gen_tokens || tokenizer.is_eos(token) {
                    done = true;
                    break;
                }
            }
            if accepted > 0 {
                target_next = verify_argmax[accepted - 1];
            }
            // Drop the rejected tail of the window from the target cache.
            target_cache
                .truncate_sequence(slot, round_start_len + accepted)
                .unwrap();
            if done {
                // When the last accepted token is EOS, the draft cache never
                // consumed it and is one token shorter than the target cache.
                // `truncate_sequence` refuses to extend, so clamp to the shorter
                // length instead of panicking on the final round.
                let final_len = target_cache
                    .seq_len(slot)
                    .min(draft_cache.seq_len(slot));
                draft_cache.truncate_sequence(slot, final_len).unwrap();
                break;
            }
            if accepted == draft_tokens.len() {
                continue;
            }

            // `target_next` already holds the target's token for the rejected
            // position (it was refreshed above whenever anything was accepted).
            let correction = target_next;
            generated.push(correction);
            draft_cache
                .truncate_sequence(slot, round_start_len)
                .unwrap();
            replay_draft_tokens(
                draft,
                draft_cache,
                slot,
                &draft_tokens[..accepted],
                &mut draft_next,
            );
            if generated.len() >= gen_tokens || tokenizer.is_eos(correction) {
                break;
            }
            let pos = target_cache.seq_len(slot);
            let logits = target.forward_decode_paged(&[correction], &[pos], &[slot], target_cache);
            target_next = last_argmax(&logits);
            commit_steps += 1;
            let pos = draft_cache.seq_len(slot);
            let logits = draft.forward_decode_paged(&[correction], &[pos], &[slot], draft_cache);
            draft_next = last_argmax(&logits);
            correction_steps += 1;
            continue;
        }

        // ---- decode-authoritative round (use_verify_logits is false from here on) ----
        verify_steps += 1;
        let verify_logits = match verify_path {
            VerifyPath::Flash => {
                target.forward_prefill_paged(&draft_tokens, slot, target_verify_cache)
            }
            VerifyPath::PagedDecode => target.forward_verify_paged_decode_attention(
                &draft_tokens,
                slot,
                target_verify_cache,
            ),
        };
        let verify_argmax = argmax_rows(&verify_logits);
        assert_eq!(
            verify_argmax.len(),
            draft_tokens.len(),
            "verify logits rows must match draft token count"
        );
        // The verify cache is rewound immediately; accepted tokens are mirrored
        // into it one by one below.
        target_verify_cache
            .truncate_sequence(slot, round_start_len)
            .unwrap();

        let mut accepted = 0usize;
        let mut done = false;
        while accepted < draft_tokens.len() {
            // Acceptance is decided by the target's single-token decode result.
            if draft_tokens[accepted] != target_next {
                break;
            }
            let token_idx = accepted;
            let token = draft_tokens[token_idx];
            generated.push(token);
            accepted_total += 1;
            accepted += 1;
            if generated.len() >= gen_tokens || tokenizer.is_eos(token) {
                done = true;
                break;
            }
            let pos = target_cache.seq_len(slot);
            let logits = target.forward_decode_paged(&[token], &[pos], &[slot], target_cache);
            let decode_next = last_argmax(&logits);
            if verify_argmax[token_idx] != decode_next {
                verify_decode_mismatches += 1;
                eprintln!(
                    "VERIFY/DECODE MISMATCH at cache_len={} accepted_idx={}: verify={} decode={}",
                    target_cache.seq_len(slot),
                    token_idx,
                    verify_argmax[token_idx],
                    decode_next
                );
                if dump_verify_mismatches {
                    eprintln!(
                        " verify_top5={} decode_top5={}",
                        format_topk(&verify_logits, token_idx, 5),
                        format_topk(&logits, 0, 5)
                    );
                }
            }
            target_next = decode_next;
            commit_steps += 1;
            advance_target_cache(target, target_verify_cache, slot, token);
            mirror_steps += 1;
        }
        if done {
            draft_cache
                .truncate_sequence(slot, target_cache.seq_len(slot))
                .unwrap();
            target_verify_cache
                .truncate_sequence(slot, target_cache.seq_len(slot))
                .unwrap();
            break;
        }
        if accepted == draft_tokens.len() {
            continue;
        }

        let correction = target_next;
        generated.push(correction);
        draft_cache
            .truncate_sequence(slot, round_start_len)
            .unwrap();
        replay_draft_tokens(
            draft,
            draft_cache,
            slot,
            &draft_tokens[..accepted],
            &mut draft_next,
        );
        if generated.len() >= gen_tokens || tokenizer.is_eos(correction) {
            break;
        }
        let pos = target_cache.seq_len(slot);
        let logits = target.forward_decode_paged(&[correction], &[pos], &[slot], target_cache);
        target_next = last_argmax(&logits);
        commit_steps += 1;
        advance_target_cache(target, target_verify_cache, slot, correction);
        mirror_steps += 1;
        let pos = draft_cache.seq_len(slot);
        let logits = draft.forward_decode_paged(&[correction], &[pos], &[slot], draft_cache);
        draft_next = last_argmax(&logits);
        correction_steps += 1;
    }
    sync_device();
    let decode_s = decode_start.elapsed().as_secs_f64();
    sync_device();
    let total_s = t0.elapsed().as_secs_f64();

    target_cache.free_sequence(slot);
    target_verify_cache.free_sequence(slot);
    draft_cache.free_sequence(slot);

    RunStats {
        ids: generated,
        total_s,
        prefill_s,
        decode_s,
        // NOTE(review): `correction_steps` counts a draft-model decode, yet it is
        // added into `target_steps`; kept as is so reported numbers stay
        // comparable with earlier runs — confirm whether this is intended.
        target_steps: verify_steps + mirror_steps + commit_steps + correction_steps,
        accepted: accepted_total,
        proposed: proposed_total,
        verify_steps,
        mirror_steps,
        commit_steps,
        correction_steps,
        verify_decode_mismatches,
    }
}
/// Feeds `token` through a single-token target decode on `cache` at the slot's
/// current length. The resulting logits are thrown away; the call is made only
/// to move `cache` forward by that token.
fn advance_target_cache(target: &Qwen3, cache: &mut PagedKVCache, slot: usize, token: u32) {
    let next_pos = cache.seq_len(slot);
    drop(target.forward_decode_paged(&[token], &[next_pos], &[slot], cache));
}
/// Decodes `tokens` one at a time into the draft `cache`, starting at the slot's
/// current length, and stores the draft's greedy prediction after the last
/// replayed token in `next`. When `tokens` is empty, `next` is left unchanged.
fn replay_draft_tokens(
    draft: &Qwen3,
    cache: &mut PagedKVCache,
    slot: usize,
    tokens: &[u32],
    next: &mut u32,
) {
    for token in tokens.iter().copied() {
        let position = cache.seq_len(slot);
        let step_logits = draft.forward_decode_paged(&[token], &[position], &[slot], cache);
        *next = last_argmax(&step_logits);
    }
}
/// Builds a paged KV cache for single-row use; shorthand for
/// `new_cache_with_rows` with a row count of 1.
fn new_cache(config: &ModelConfig, max_seq_len: usize, device: u32) -> PagedKVCache {
    const SINGLE_ROW: usize = 1;
    new_cache_with_rows(config, max_seq_len, device, SINGLE_ROW)
}
/// Builds a BF16 paged KV cache sized for one sequence of up to `max_seq_len`
/// tokens on CUDA device `device`.
///
/// The pool gets enough blocks for one full-length sequence plus 8 spare.
/// `max_rows` (at least 1) is forwarded as the fourth constructor argument.
/// NOTE(review): presumably the maximum number of batch rows/sequences the
/// cache tables can describe at once — confirm against `PagedKVCache::new`.
fn new_cache_with_rows(
    config: &ModelConfig,
    max_seq_len: usize,
    device: u32,
    max_rows: usize,
) -> PagedKVCache {
    let blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE);
    let pool_blocks = blocks_per_seq + 8;
    let rows = max_rows.max(1);
    PagedKVCache::new(
        config,
        pool_blocks,
        0,
        rows,
        blocks_per_seq,
        DType::BF16,
        device,
    )
}
fn argmax_rows(logits: &Tensor) -> Vec<u32> {
assert_eq!(logits.ndim(), 2);
if logits.dtype() == DType::BF16
&& matches!(logits.device(), Device::Cuda(_))
&& logits.is_contiguous()
{
return xserv_kernels::argmax_bf16_to_host(logits);
}
let vocab_size = logits.shape()[1];
let rows = logits.shape()[0];
let logits_cpu = logits.to_device(Device::Cpu);
match logits.dtype() {
DType::F32 => logits_cpu
.as_slice::<f32>()
.chunks_exact(vocab_size)
.take(rows)
.map(argmax_f32)
.collect(),
DType::BF16 => logits_cpu
.as_slice::<bf16>()
.chunks_exact(vocab_size)
.take(rows)
.map(|row| {
row.iter()
.enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(i, _)| i as u32)
.unwrap()
})
.collect(),
_ => panic!("unsupported dtype for argmax: {:?}", logits.dtype()),
}
}
/// Returns the argmax of the final row of `logits`, i.e. the greedy next token
/// after the last input position. Panics if the tensor has no rows.
fn last_argmax(logits: &Tensor) -> u32 {
    argmax_rows(logits).pop().unwrap()
}
/// Returns the index of the largest value in `row`.
///
/// When the maximum occurs more than once, the last occurrence wins.
///
/// # Panics
/// Panics if `row` is empty or if two compared values cannot be ordered (NaN).
fn argmax_f32(row: &[f32]) -> u32 {
    let mut candidates = row.iter().enumerate();
    let (mut best_idx, mut best_val) = candidates.next().unwrap();
    for (idx, val) in candidates {
        // Replace the incumbent unless it is strictly greater, so ties move forward.
        if best_val.partial_cmp(val).unwrap() != std::cmp::Ordering::Greater {
            best_idx = idx;
            best_val = val;
        }
    }
    best_idx as u32
}
/// Renders the `k` largest logits of `row` as `id:value` pairs with three
/// decimals, joined by commas, in the order returned by `topk_row`.
fn format_topk(logits: &Tensor, row: usize, k: usize) -> String {
    let mut rendered: Vec<String> = Vec::new();
    for (id, val) in topk_row(logits, row, k) {
        rendered.push(format!("{id}:{val:.3}"));
    }
    rendered.join(",")
}
/// Returns the `k` largest logits of one row as `(token_id, logit)` pairs, sorted
/// from largest to smallest.
///
/// The tensor is copied to the CPU first; F32 and BF16 are supported (BF16 values
/// are widened to `f32`). If `k` exceeds the row width, the whole row is
/// returned. Values are ordered with `f32::total_cmp`, so a NaN logit cannot
/// abort this diagnostics helper.
///
/// # Panics
/// Panics if `logits` is not 2-D, if `row` is out of bounds, or if the dtype is
/// neither F32 nor BF16.
fn topk_row(logits: &Tensor, row: usize, k: usize) -> Vec<(u32, f32)> {
    assert_eq!(logits.ndim(), 2);
    let vocab_size = logits.shape()[1];
    assert!(row < logits.shape()[0], "topk row out of bounds");
    let logits_cpu = logits.to_device(Device::Cpu);
    let row_range = row * vocab_size..(row + 1) * vocab_size;
    let mut vals: Vec<(u32, f32)> = match logits.dtype() {
        DType::F32 => logits_cpu.as_slice::<f32>()[row_range]
            .iter()
            .enumerate()
            .map(|(i, &v)| (i as u32, v))
            .collect(),
        DType::BF16 => logits_cpu.as_slice::<bf16>()[row_range]
            .iter()
            .enumerate()
            .map(|(i, &v)| (i as u32, v.to_f32()))
            .collect(),
        _ => panic!("unsupported dtype for topk: {:?}", logits.dtype()),
    };

    // Descending by logit; a total order, so NaN never panics.
    let by_logit_desc = |a: &(u32, f32), b: &(u32, f32)| b.1.total_cmp(&a.1);

    // `select_nth_unstable_by` panics when its index is >= len, so only
    // partition when there is actually something to cut off.
    let k = k.min(vals.len());
    if k < vals.len() {
        vals.select_nth_unstable_by(k, by_logit_desc);
        vals.truncate(k);
    }
    vals.sort_unstable_by(by_logit_desc);
    vals
}
/// Aborts unless the config's `model_type` contains the substring `"qwen"`.
/// A missing `model_type` is treated as `"unknown"` and therefore rejected.
/// `name` ("target" / "draft") only labels the panic message.
fn assert_qwen3(config: &ModelConfig, name: &str) {
    let model_type = match config.model_type.as_deref() {
        Some(declared) => declared,
        None => "unknown",
    };
    if !model_type.contains("qwen") {
        panic!("{name} model_type must be qwen-like, got {model_type}");
    }
}
/// Prints a warning to stderr when both model directories contain a readable
/// `tokenizer.json` and the two files are not byte-identical. If either file
/// cannot be read, nothing is reported.
fn warn_if_tokenizers_differ(target_dir: &Path, draft_dir: &Path) {
    let read_tokenizer = |dir: &Path| std::fs::read(dir.join("tokenizer.json")).ok();
    let target_bytes = read_tokenizer(target_dir);
    let draft_bytes = read_tokenizer(draft_dir);

    let (Some(target_bytes), Some(draft_bytes)) = (target_bytes, draft_bytes) else {
        return;
    };
    if target_bytes != draft_bytes {
        eprintln!(
            "WARNING: target and draft tokenizer.json differ; v0 assumes identical token ids"
        );
    }
}
/// Reads the value following `flag` in `args` as a `usize`.
///
/// Returns `default` when `flag` does not occur. When the flag is present but
/// its value is missing or is not a non-negative integer, an error is printed
/// and the process exits with status 1 (the same way an unknown
/// `--verify-path` is rejected). Falling back to the default there would
/// silently benchmark a configuration the user did not ask for.
///
/// Only the first occurrence of `flag` is considered.
fn arg_usize(args: &[String], flag: &str, default: usize) -> usize {
    let Some(flag_idx) = args.iter().position(|a| a == flag) else {
        return default;
    };
    let raw = args.get(flag_idx + 1);
    match raw.and_then(|s| s.parse::<usize>().ok()) {
        Some(value) => value,
        None => {
            match raw {
                Some(text) => {
                    eprintln!("invalid value {text:?} for {flag}; expected a non-negative integer")
                }
                None => eprintln!("missing value for {flag}; expected a non-negative integer"),
            }
            std::process::exit(1);
        }
    }
}
/// Resolves `--verify-path` from the command line.
///
/// Without the flag (or without a value after it) the result depends on the
/// acceptance mode: `PagedDecode` when `use_verify_logits` is set, `Flash`
/// otherwise. An unrecognised value prints an error and exits with status 1.
fn parse_verify_path(args: &[String], use_verify_logits: bool) -> VerifyPath {
    // Skip ahead to the first occurrence of the flag, then step over it.
    let mut tail = args.iter().skip_while(|a| a.as_str() != "--verify-path");
    tail.next();

    match tail.next().map(String::as_str) {
        Some("paged-decode") => VerifyPath::PagedDecode,
        Some("flash") => VerifyPath::Flash,
        Some(value) => {
            eprintln!("unknown --verify-path {value:?}; expected flash or paged-decode");
            std::process::exit(1);
        }
        None if use_verify_logits => VerifyPath::PagedDecode,
        None => VerifyPath::Flash,
    }
}
/// Ensures that the prompt plus the requested generation fits in `max_seq_len`.
///
/// Returns normally when it fits; otherwise prints the offending sizes together
/// with the prompt text and exits the process with status 2.
fn validate_length_budget(prompt_ids: &[u32], gen_tokens: usize, max_seq_len: usize, prompt: &str) {
    let prompt_len = prompt_ids.len();
    let required = prompt_len + gen_tokens;
    if required <= max_seq_len {
        return;
    }
    eprintln!(
        "prompt requires prompt_len({}) + gen_tokens({}) = {} tokens, exceeding --max-seq-len {}: {:?}",
        prompt_len, gen_tokens, required, max_seq_len, prompt
    );
    std::process::exit(2);
}
/// Calls `xserv_cuda::device::synchronize` and panics if it reports an error.
/// The timing code calls this before reading a timer.
fn sync_device() {
    let outcome = xserv_cuda::device::synchronize();
    outcome.expect("cuda device synchronize");
}
/// Returns `num / den` as a float, or `0.0` when `den` is zero.
fn ratio(num: usize, den: usize) -> f64 {
    match den {
        0 => 0.0,
        nonzero => num as f64 / nonzero as f64,
    }
}
/// Returns how many times faster the speculative run was than the baseline
/// (`baseline_s / spec_s`), or `0.0` when `spec_s` is zero.
fn speedup(baseline_s: f64, spec_s: f64) -> f64 {
    if spec_s != 0.0 {
        return baseline_s / spec_s;
    }
    0.0
}
/// Returns throughput in tokens per second, or `0.0` when `seconds` is zero.
fn tok_s(tokens: usize, seconds: f64) -> f64 {
    if seconds != 0.0 {
        return tokens as f64 / seconds;
    }
    0.0
}
/// Returns the average time per token in milliseconds, or `0.0` when no tokens
/// were produced.
fn per_token_ms(seconds: f64, tokens: usize) -> f64 {
    match tokens {
        0 => 0.0,
        count => seconds * 1000.0 / count as f64,
    }
}

View File

@@ -486,6 +486,35 @@ impl PagedKVCache {
state.seq_len += num_tokens;
}
/// Shrinks the sequence registered at `slot` so that it holds `new_len` tokens.
///
/// Only bookkeeping is touched: trailing physical blocks that `new_len` no
/// longer reaches are returned to the allocator matching the sequence's
/// location (GPU or CPU), newest block first. The slot stays registered and its
/// first block is never released, even for `new_len == 0`. Data inside retained
/// blocks is not cleared; the shorter logical `seq_len` keeps attention from
/// reading it, and later writes to those positions overwrite it.
///
/// # Errors
/// Returns an error if `slot` is out of range, if the slot is not registered,
/// or if `new_len` is larger than the current length.
pub fn truncate_sequence(&mut self, slot: usize, new_len: usize) -> Result<(), &'static str> {
    if slot >= self.max_seqs {
        return Err("truncate_sequence: slot out of range");
    }
    let Some(state) = self.seq_states[slot].as_mut() else {
        return Err("truncate_sequence: empty slot");
    };
    if new_len > state.seq_len {
        return Err("truncate_sequence: cannot extend");
    }

    let keep_blocks = new_len.div_ceil(BLOCK_SIZE).max(1);
    let surplus = state.block_ids.len().saturating_sub(keep_blocks);
    for _ in 0..surplus {
        let block = state.block_ids.pop().expect("checked len");
        match state.location {
            Location::Gpu => self.allocator.free(block),
            Location::Cpu => self.cpu_allocator.free(block),
        }
    }
    state.seq_len = new_len;
    Ok(())
}
/// Refresh the host-side block table + context lens from `seq_states`,
/// then upload to GPU. Call once per decode step before the paged kernel.
pub fn sync_to_gpu(&mut self) {
@@ -748,6 +777,71 @@ impl PagedKVCache {
}
}
/// Unit tests for the paged KV cache bookkeeping. They need CUDA device 0 and
/// skip themselves (with a message on stderr) when it is unavailable.
#[cfg(test)]
mod tests {
use super::*;
/// Smallest Qwen3-style config that can back a cache: one layer, one head.
fn tiny_config() -> ModelConfig {
serde_json::from_value(serde_json::json!({
"model_type": "qwen3",
"hidden_size": 8,
"intermediate_size": 16,
"num_attention_heads": 1,
"num_key_value_heads": 1,
"num_hidden_layers": 1,
"vocab_size": 32,
"max_position_embeddings": 64
}))
.unwrap()
}
/// Walks one sequence through successive rollbacks and checks the logical
/// length, the number of owned blocks and the allocator's free count after each.
#[test]
fn truncate_sequence_frees_whole_blocks_and_keeps_slot_registered() {
if xserv_cuda::device::set_device(0).is_err() {
eprintln!("skipping CUDA-backed PagedKVCache test: device 0 unavailable");
return;
}
let config = tiny_config();
// NOTE(review): the assertions below imply one usable slot (slot 1 is out of
// range) and four allocatable blocks; presumably one of the 5 blocks passed
// here is reserved — confirm against `PagedKVCache::new`.
let mut cache = PagedKVCache::new(&config, 5, 0, 1, 4, DType::BF16, 0);
// Error paths before anything is registered.
assert_eq!(
cache.truncate_sequence(1, 0),
Err("truncate_sequence: slot out of range")
);
assert_eq!(
cache.truncate_sequence(0, 0),
Err("truncate_sequence: empty slot")
);
// Grow to three full blocks plus one token, which needs a fourth block.
cache.register_sequence(0).unwrap();
cache.ensure_capacity(0, BLOCK_SIZE * 3 + 1);
cache.advance_seq_len(0, BLOCK_SIZE * 3 + 1);
assert_eq!(cache.seq_len(0), BLOCK_SIZE * 3 + 1);
assert_eq!(cache.block_count(0), 4);
assert_eq!(cache.free_blocks(), 0);
// One token into the second block: two blocks stay, two are freed.
cache.truncate_sequence(0, BLOCK_SIZE + 1).unwrap();
assert_eq!(cache.seq_len(0), BLOCK_SIZE + 1);
assert_eq!(cache.block_count(0), 2);
assert_eq!(cache.free_blocks(), 2);
// Exactly one full block: the second block is no longer needed.
cache.truncate_sequence(0, BLOCK_SIZE).unwrap();
assert_eq!(cache.seq_len(0), BLOCK_SIZE);
assert_eq!(cache.block_count(0), 1);
assert_eq!(cache.free_blocks(), 3);
// Rolling back to zero keeps the slot registered and retains its first block.
cache.truncate_sequence(0, 0).unwrap();
assert_eq!(cache.seq_len(0), 0);
assert_eq!(cache.block_count(0), 1);
assert_eq!(cache.free_blocks(), 3);
// Truncation can never make a sequence longer.
assert_eq!(
cache.truncate_sequence(0, 1),
Err("truncate_sequence: cannot extend")
);
}
}
unsafe fn tensor_from_owned_buf(
buf: GpuBuffer,
shape: &[usize],

View File

@@ -884,6 +884,109 @@ impl Qwen3 {
matmul_2d(&x, &self.lm_head_t)
}
/// Paged multi-token verify path: write `token_ids` into the paged cache,
/// then verify them with the same paged decode attention kernel used by
/// single-token decode. This keeps greedy top-1 behavior aligned with
/// `forward_decode_paged` while still batching the dense projections/MLP
/// across the draft window.
///
/// `token_ids` is the draft window, `slot` the sequence it extends. On return
/// the slot's logical length has grown by `token_ids.len()`; callers that reject
/// part of the window are expected to roll the cache back themselves. The
/// result is the output of the final `lm_head_t` projection, one row per input
/// token.
///
/// NOTE(review): an empty `token_ids` is not handled specially here — confirm
/// that callers never pass one.
pub fn forward_verify_paged_decode_attention(
&self,
token_ids: &[u32],
slot: usize,
paged_cache: &mut PagedKVCache,
) -> Tensor {
let new_tokens = token_ids.len();
// Tokens already cached for this slot; the window starts at this position.
let pos_offset = paged_cache.seq_len(slot);
let num_heads = self.local_num_heads;
let num_kv_heads = self.local_num_kv_heads;
let head_dim = self.config.head_dim();
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
// Reserve room for the whole window and bump the logical length up front,
// before any layer writes its K/V.
paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
paged_cache.advance_seq_len(slot, new_tokens);
// Absolute positions of the window tokens, used for RoPE.
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
.map(|p| p as u32)
.collect();
// Row i gets a context length of pos_offset + i + 1, so presumably it attends
// to the earlier cache plus window tokens 0..=i (causal within the window).
let kv_lens: Vec<i32> = (0..new_tokens)
.map(|i| (pos_offset + i + 1) as i32)
.collect();
// Every batch row refers to the same slot, each with its own context length.
let slots = vec![slot; new_tokens];
paged_cache.sync_active_batch_with_lens(&slots, &kv_lens);
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
let max_blocks = paged_cache.max_blocks_per_seq();
let mut x = embedding(&self.embed_tokens, token_ids);
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
let normed = rmsnorm(&x, &layer.input_norm, eps);
// Fused QKV projection, computed row by row (see `matmul_rows_gemv`).
let qkv = matmul_rows_gemv(&normed, &layer.qkv_proj_wt);
let q_dim = num_heads * head_dim;
let kv_dim = num_kv_heads * head_dim;
// Split the fused output along the feature axis into Q, K and V.
let q_all = qkv.narrow(1, 0, q_dim);
let k_all = qkv.narrow(1, q_dim, kv_dim);
let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);
// Flatten to one row per (token, head) so the Q/K norms apply per head.
let q_flat = q_all
.contiguous()
.reshape(&[new_tokens * num_heads, head_dim]);
let k_flat = k_all
.contiguous()
.reshape(&[new_tokens * num_kv_heads, head_dim]);
let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
let q_3d = q_normed.reshape(&[new_tokens, num_heads, head_dim]);
let k_3d = k_normed.reshape(&[new_tokens, num_kv_heads, head_dim]);
rope_inplace(&q_3d, &self.rope_cache, &positions);
rope_inplace(&k_3d, &self.rope_cache, &positions);
let v_3d = v_all
.contiguous()
.reshape(&[new_tokens, num_kv_heads, head_dim]);
// Store this layer's K/V for the whole window before attention reads the pool.
paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, new_tokens);
// Treat each window token as its own single-query decode row.
let q_decode = q_3d.reshape(&[new_tokens, num_heads, 1, head_dim]);
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
let attn_out = xserv_kernels::paged_decode_attention(
&q_decode,
k_pool_ptr,
v_pool_ptr,
bt_ptr,
cl_ptr,
new_tokens,
num_heads,
num_kv_heads,
head_dim,
max_blocks,
);
let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]);
let attn_proj = matmul_rows_gemv(&attn_merged, &layer.o_proj_wt);
self.all_reduce(&attn_proj);
// Residual add and post-attention norm in one call; returns both the normed
// activations and the updated residual stream.
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
// Gated MLP: the fused gate/up projection is split in half, then silu_mul.
let gate_up = matmul_rows_gemv(&normed, &layer.gate_up_proj_wt);
let ffn_dim = gate_up.shape()[1] / 2;
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
let down = matmul_rows_gemv(&hidden_states, &layer.down_proj_wt);
self.all_reduce(&down);
x = add_any(&residual, &down);
}
// Final norm, then the LM head for every window row.
let x = rmsnorm(&x, &self.norm, eps);
matmul_rows_gemv(&x, &self.lm_head_t)
}
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
let new_tokens = token_ids.len();
@@ -1158,6 +1261,20 @@ fn row_view(t: &Tensor, row: usize) -> Tensor {
)
}
/// Multiplies `a` by `b` one row of `a` at a time and stacks the results.
///
/// Every row is sent through `matmul_2d` as its own single-row input, so it
/// takes the same kernel path as single-token decode. Speculative verify relies
/// on this: near-tie logits must round in BF16 exactly as they do in decode.
/// A single-row `a` is passed to `matmul_2d` directly.
///
/// # Panics
/// Panics if `a` is not 2-D or not contiguous.
fn matmul_rows_gemv(a: &Tensor, b: &Tensor) -> Tensor {
    assert_eq!(a.ndim(), 2);
    assert!(a.is_contiguous());
    match a.shape()[0] {
        1 => matmul_2d(a, b),
        row_count => {
            let mut per_row = Vec::with_capacity(row_count);
            for r in 0..row_count {
                per_row.push(matmul_2d(&row_view(a, r), b));
            }
            concat_rows(&per_row)
        }
    }
}
/// Concatenate row tensors [1, cols] into a single [B, cols] tensor via D2D memcpy.
fn concat_rows(rows: &[Tensor]) -> Tensor {
assert!(!rows.is_empty());