speculative: Qwen3 draft-model v0 with paged verify parity
Phase 22 lands a correctness-only speculative decoding loop for Qwen3 target + Qwen3 small draft (batch=1, greedy, gamma=4). Phase 23 turns verify logits into the authoritative acceptance signal so mirror-decode per accepted token is no longer needed. - paged_kv_cache: truncate_sequence(slot, new_len) shrinks a registered sequence, freeing whole physical blocks no longer reachable and leaving the slot registered. Covered by a CUDA-gated unit test. - qwen3: forward_verify_paged_decode_attention writes the draft window into the target cache, runs the same paged decode attention kernel per draft token, and uses matmul_rows_gemv so linear layers follow the single-token decode BF16 rounding path. - bench-speculative: new bench binary drives the state machine with --gamma / --gen-tokens / --prompts / --use-verify-logits / --verify-path flash|paged-decode / --dump-verify-mismatches, and compares baseline vs spec token sequences plus TPOT / tok/s / speedup. - docs/22 records the decode-authoritative v0 result and dash5 numbers (matched=true, speedup_e2e ~0.29x, verify_decode_mismatches>0 under --use-verify-logits). - docs/23 records the paged-decode verify path (matched=true, verify_decode_mismatches=0, 50x64 speedup_e2e ~0.44x) and the next-step performance TODO.
This commit is contained in:
962
crates/xserv-model/src/bin/bench-speculative.rs
Normal file
962
crates/xserv-model/src/bin/bench-speculative.rs
Normal file
@@ -0,0 +1,962 @@
|
||||
//! Draft-model speculative decoding benchmark for Qwen3.
|
||||
//!
|
||||
//! v0 scope:
|
||||
//! - target + draft are Qwen3-family models with the same tokenizer/vocab;
|
||||
//! - batch=1;
|
||||
//! - greedy exact-match acceptance;
|
||||
//! - no probabilistic rejection sampling.
|
||||
|
||||
use half::bf16;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Instant;
|
||||
|
||||
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
const DEFAULT_GAMMA: usize = 4;
|
||||
const DEFAULT_GEN_TOKENS: usize = 64;
|
||||
const DEFAULT_MAX_SEQ_LEN: usize = 2048;
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
enum VerifyPath {
|
||||
Flash,
|
||||
PagedDecode,
|
||||
}
|
||||
|
||||
impl VerifyPath {
|
||||
fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
VerifyPath::Flash => "flash",
|
||||
VerifyPath::PagedDecode => "paged-decode",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const PROMPTS: [&str; 50] = [
|
||||
"The capital of France is",
|
||||
"Once upon a time in a land far away",
|
||||
"Hello, how are you doing today",
|
||||
"In a shocking finding, scientists discovered a",
|
||||
"The weather today is sunny, so I decided to",
|
||||
"Alan Turing was a British mathematician who",
|
||||
"The best way to learn programming is",
|
||||
"Artificial intelligence will change the world because",
|
||||
"The history of the internet began in the",
|
||||
"A good morning routine starts with",
|
||||
"The stock market crashed because investors",
|
||||
"Deep learning is a subset of machine learning that",
|
||||
"The president of the United States announced",
|
||||
"In the year 2050, humans will",
|
||||
"The secret to happiness is",
|
||||
"When I was a child, I used to",
|
||||
"The most important scientific discovery of the century",
|
||||
"Climate change is caused by",
|
||||
"The recipe for chocolate cake requires",
|
||||
"In conclusion, the evidence suggests that",
|
||||
"The cat sat on the mat and",
|
||||
"According to recent studies, exercise can",
|
||||
"The first step in solving any problem is",
|
||||
"Technology has transformed the way we",
|
||||
"The novel begins with the protagonist",
|
||||
"Education is the most powerful weapon",
|
||||
"The ocean covers more than seventy percent of",
|
||||
"Last night I had a dream about",
|
||||
"The company announced its quarterly earnings",
|
||||
"Music has the power to",
|
||||
"The difference between success and failure is",
|
||||
"In the beginning, there was nothing but",
|
||||
"The doctor told me that I should",
|
||||
"Python is a popular programming language because",
|
||||
"The ancient Romans built roads that",
|
||||
"A balanced diet should include",
|
||||
"The movie received mixed reviews from critics",
|
||||
"Space exploration has led to many",
|
||||
"The teacher asked the students to",
|
||||
"Global warming is one of the most",
|
||||
"The bridge collapsed due to structural",
|
||||
"Quantum computing promises to revolutionize",
|
||||
"The new policy will affect millions of",
|
||||
"During the winter months, it is important to",
|
||||
"The human brain contains approximately",
|
||||
"Democracy depends on the active participation of",
|
||||
"The train arrived at the station exactly",
|
||||
"Researchers at MIT have developed a new",
|
||||
"The smartphone has become an essential part of",
|
||||
"After careful consideration, the committee decided to",
|
||||
];
|
||||
|
||||
#[derive(Default)]
|
||||
struct RunStats {
|
||||
ids: Vec<u32>,
|
||||
total_s: f64,
|
||||
prefill_s: f64,
|
||||
decode_s: f64,
|
||||
target_steps: usize,
|
||||
accepted: usize,
|
||||
proposed: usize,
|
||||
verify_steps: usize,
|
||||
mirror_steps: usize,
|
||||
commit_steps: usize,
|
||||
correction_steps: usize,
|
||||
verify_decode_mismatches: usize,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct Totals {
|
||||
prompts: usize,
|
||||
baseline_generated: usize,
|
||||
spec_generated: usize,
|
||||
baseline_total_s: f64,
|
||||
baseline_prefill_s: f64,
|
||||
baseline_decode_s: f64,
|
||||
spec_total_s: f64,
|
||||
spec_prefill_s: f64,
|
||||
spec_decode_s: f64,
|
||||
spec_target_steps: usize,
|
||||
spec_accepted: usize,
|
||||
spec_proposed: usize,
|
||||
spec_verify_steps: usize,
|
||||
spec_mirror_steps: usize,
|
||||
spec_commit_steps: usize,
|
||||
spec_correction_steps: usize,
|
||||
spec_verify_decode_mismatches: usize,
|
||||
mismatches: usize,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 3 {
|
||||
eprintln!(
|
||||
"Usage: bench-speculative <target-model-dir> <draft-model-dir> \
|
||||
[--gen-tokens N] [--gamma N] [--prompts N] [--max-seq-len N] [--device N] \
|
||||
[--use-verify-logits] [--verify-path flash|paged-decode] [--dump-verify-mismatches]"
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let target_dir = PathBuf::from(&args[1]);
|
||||
let draft_dir = PathBuf::from(&args[2]);
|
||||
let gen_tokens = arg_usize(&args, "--gen-tokens", DEFAULT_GEN_TOKENS);
|
||||
let gamma = arg_usize(&args, "--gamma", DEFAULT_GAMMA);
|
||||
let prompt_count = arg_usize(&args, "--prompts", PROMPTS.len()).min(PROMPTS.len());
|
||||
let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN);
|
||||
let device = arg_usize(&args, "--device", 0) as u32;
|
||||
let use_verify_logits = args.iter().any(|a| a == "--use-verify-logits");
|
||||
let verify_path = parse_verify_path(&args, use_verify_logits);
|
||||
let dump_verify_mismatches = args.iter().any(|a| a == "--dump-verify-mismatches");
|
||||
|
||||
assert!(gen_tokens > 0, "--gen-tokens must be > 0");
|
||||
assert!(gamma > 0, "--gamma must be > 0");
|
||||
|
||||
xserv_cuda::device::set_device(device).unwrap();
|
||||
let info = xserv_cuda::device::device_info(device).unwrap();
|
||||
eprintln!(
|
||||
"GPU {device}: {} ({} MB free)",
|
||||
info.name,
|
||||
info.free_memory / 1024 / 1024
|
||||
);
|
||||
|
||||
let target_config = ModelConfig::from_file(&target_dir.join("config.json"));
|
||||
let draft_config = ModelConfig::from_file(&draft_dir.join("config.json"));
|
||||
assert_qwen3(&target_config, "target");
|
||||
assert_qwen3(&draft_config, "draft");
|
||||
assert_eq!(
|
||||
target_config.vocab_size, draft_config.vocab_size,
|
||||
"target and draft vocab_size must match"
|
||||
);
|
||||
|
||||
warn_if_tokenizers_differ(&target_dir, &draft_dir);
|
||||
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
|
||||
if tokenizer.vocab_size() != target_config.vocab_size {
|
||||
eprintln!(
|
||||
"WARNING: tokenizer decoder len {} differs from config vocab_size {}; continuing because token ids come from the shared tokenizer.json",
|
||||
tokenizer.vocab_size(),
|
||||
target_config.vocab_size
|
||||
);
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"Loading target Qwen3: layers={} hidden={} heads={}/{} vocab={}",
|
||||
target_config.num_layers(),
|
||||
target_config.hidden(),
|
||||
target_config.num_heads(),
|
||||
target_config.num_kv_heads(),
|
||||
target_config.vocab_size
|
||||
);
|
||||
let target_weights = loader::load_model_dir(&target_dir, Device::Cuda(device));
|
||||
let target = Qwen3::from_weights(target_config.clone(), target_weights);
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
|
||||
eprintln!(
|
||||
"Loading draft Qwen3: layers={} hidden={} heads={}/{} vocab={}",
|
||||
draft_config.num_layers(),
|
||||
draft_config.hidden(),
|
||||
draft_config.num_heads(),
|
||||
draft_config.num_kv_heads(),
|
||||
draft_config.vocab_size
|
||||
);
|
||||
let draft_weights = loader::load_model_dir(&draft_dir, Device::Cuda(device));
|
||||
let draft = Qwen3::from_weights(draft_config.clone(), draft_weights);
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
|
||||
let warm_ids = tokenizer.encode("warmup");
|
||||
let warm_tokens = gen_tokens.min(4);
|
||||
{
|
||||
let mut target_cache = new_cache(&target_config, max_seq_len, device);
|
||||
let _ = run_baseline(
|
||||
&target,
|
||||
&mut target_cache,
|
||||
&tokenizer,
|
||||
&warm_ids,
|
||||
warm_tokens,
|
||||
);
|
||||
}
|
||||
{
|
||||
let mut target_cache = new_cache_with_rows(
|
||||
&target_config,
|
||||
max_seq_len,
|
||||
device,
|
||||
if use_verify_logits { gamma } else { 1 },
|
||||
);
|
||||
let mut target_verify_cache =
|
||||
new_cache_with_rows(&target_config, max_seq_len, device, gamma);
|
||||
let mut draft_cache = new_cache(&draft_config, max_seq_len, device);
|
||||
let _ = run_speculative(
|
||||
&target,
|
||||
&draft,
|
||||
&mut target_cache,
|
||||
&mut target_verify_cache,
|
||||
&mut draft_cache,
|
||||
&tokenizer,
|
||||
&warm_ids,
|
||||
warm_tokens,
|
||||
gamma,
|
||||
use_verify_logits,
|
||||
verify_path,
|
||||
dump_verify_mismatches,
|
||||
);
|
||||
}
|
||||
eprintln!(
|
||||
"Warmup done. Running {prompt_count} prompts, gen_tokens={gen_tokens}, gamma={gamma}, acceptance_mode={}, verify_path={}",
|
||||
if use_verify_logits {
|
||||
"verify_logits"
|
||||
} else {
|
||||
"decode"
|
||||
},
|
||||
verify_path.as_str()
|
||||
);
|
||||
|
||||
let mut totals = Totals::default();
|
||||
for (i, prompt) in PROMPTS.iter().take(prompt_count).enumerate() {
|
||||
let ids = tokenizer.encode(prompt);
|
||||
validate_length_budget(&ids, gen_tokens, max_seq_len, prompt);
|
||||
let mut baseline_cache = new_cache(&target_config, max_seq_len, device);
|
||||
let baseline = run_baseline(&target, &mut baseline_cache, &tokenizer, &ids, gen_tokens);
|
||||
drop(baseline_cache);
|
||||
|
||||
let mut target_cache = new_cache_with_rows(
|
||||
&target_config,
|
||||
max_seq_len,
|
||||
device,
|
||||
if use_verify_logits { gamma } else { 1 },
|
||||
);
|
||||
let mut target_verify_cache =
|
||||
new_cache_with_rows(&target_config, max_seq_len, device, gamma);
|
||||
let mut draft_cache = new_cache(&draft_config, max_seq_len, device);
|
||||
let spec = run_speculative(
|
||||
&target,
|
||||
&draft,
|
||||
&mut target_cache,
|
||||
&mut target_verify_cache,
|
||||
&mut draft_cache,
|
||||
&tokenizer,
|
||||
&ids,
|
||||
gen_tokens,
|
||||
gamma,
|
||||
use_verify_logits,
|
||||
verify_path,
|
||||
dump_verify_mismatches,
|
||||
);
|
||||
|
||||
let matched = baseline.ids == spec.ids;
|
||||
if !matched {
|
||||
totals.mismatches += 1;
|
||||
eprintln!("MISMATCH prompt {i}: {prompt}");
|
||||
eprintln!(" baseline: {:?}", baseline.ids);
|
||||
eprintln!(" spec: {:?}", spec.ids);
|
||||
}
|
||||
|
||||
println!(
|
||||
"prompt={:02} match={} gen={} accept={}/{} target_steps={} \
|
||||
baseline_e2e_tpot_ms={:.3} spec_e2e_tpot_ms={:.3}",
|
||||
i,
|
||||
matched,
|
||||
spec.ids.len(),
|
||||
spec.accepted,
|
||||
spec.proposed,
|
||||
spec.target_steps,
|
||||
per_token_ms(baseline.total_s, baseline.ids.len()),
|
||||
per_token_ms(spec.total_s, spec.ids.len()),
|
||||
);
|
||||
|
||||
totals.prompts += 1;
|
||||
totals.baseline_generated += baseline.ids.len();
|
||||
totals.spec_generated += spec.ids.len();
|
||||
totals.baseline_total_s += baseline.total_s;
|
||||
totals.baseline_prefill_s += baseline.prefill_s;
|
||||
totals.baseline_decode_s += baseline.decode_s;
|
||||
totals.spec_total_s += spec.total_s;
|
||||
totals.spec_prefill_s += spec.prefill_s;
|
||||
totals.spec_decode_s += spec.decode_s;
|
||||
totals.spec_target_steps += spec.target_steps;
|
||||
totals.spec_accepted += spec.accepted;
|
||||
totals.spec_proposed += spec.proposed;
|
||||
totals.spec_verify_steps += spec.verify_steps;
|
||||
totals.spec_mirror_steps += spec.mirror_steps;
|
||||
totals.spec_commit_steps += spec.commit_steps;
|
||||
totals.spec_correction_steps += spec.correction_steps;
|
||||
totals.spec_verify_decode_mismatches += spec.verify_decode_mismatches;
|
||||
}
|
||||
|
||||
let baseline_decode_tokens = totals.baseline_generated;
|
||||
let spec_decode_tokens = totals.spec_generated;
|
||||
let acceptance = ratio(totals.spec_accepted, totals.spec_proposed);
|
||||
let tokens_per_target_step = ratio(totals.spec_generated, totals.spec_target_steps);
|
||||
let matched =
|
||||
totals.mismatches == 0 && (!use_verify_logits || totals.spec_verify_decode_mismatches == 0);
|
||||
|
||||
println!("--- SUMMARY ---");
|
||||
println!("prompts={} matched={matched}", totals.prompts);
|
||||
println!(
|
||||
"acceptance_mode={}",
|
||||
if use_verify_logits {
|
||||
"verify_logits"
|
||||
} else {
|
||||
"decode"
|
||||
}
|
||||
);
|
||||
println!("verify_path={}", verify_path.as_str());
|
||||
println!(
|
||||
"acceptance_rate={:.4} accepted={} proposed={}",
|
||||
acceptance, totals.spec_accepted, totals.spec_proposed
|
||||
);
|
||||
println!(
|
||||
"tokens_per_target_step={:.4} target_steps={} verify_steps={} mirror_decode_steps={} commit_decode_steps={} correction_steps={}",
|
||||
tokens_per_target_step,
|
||||
totals.spec_target_steps,
|
||||
totals.spec_verify_steps,
|
||||
totals.spec_mirror_steps,
|
||||
totals.spec_commit_steps,
|
||||
totals.spec_correction_steps
|
||||
);
|
||||
println!(
|
||||
"verify_decode_mismatches={}",
|
||||
totals.spec_verify_decode_mismatches
|
||||
);
|
||||
println!(
|
||||
"baseline_e2e_tpot_ms={:.3} baseline_e2e_tok_s={:.3}",
|
||||
per_token_ms(totals.baseline_total_s, totals.baseline_generated),
|
||||
tok_s(totals.baseline_generated, totals.baseline_total_s)
|
||||
);
|
||||
println!(
|
||||
"spec_e2e_tpot_ms={:.3} spec_e2e_tok_s={:.3} speedup_e2e={:.4}",
|
||||
per_token_ms(totals.spec_total_s, totals.spec_generated),
|
||||
tok_s(totals.spec_generated, totals.spec_total_s),
|
||||
speedup(totals.baseline_total_s, totals.spec_total_s)
|
||||
);
|
||||
println!(
|
||||
"baseline_decode_tpot_ms={:.3} baseline_decode_tok_s={:.3}",
|
||||
per_token_ms(totals.baseline_decode_s, baseline_decode_tokens),
|
||||
tok_s(baseline_decode_tokens, totals.baseline_decode_s)
|
||||
);
|
||||
println!(
|
||||
"spec_decode_tpot_ms={:.3} spec_decode_tok_s={:.3} speedup_decode={:.4}",
|
||||
per_token_ms(totals.spec_decode_s, spec_decode_tokens),
|
||||
tok_s(spec_decode_tokens, totals.spec_decode_s),
|
||||
speedup(totals.baseline_decode_s, totals.spec_decode_s)
|
||||
);
|
||||
println!(
|
||||
"decode_token_counts baseline={} spec={}",
|
||||
baseline_decode_tokens, spec_decode_tokens
|
||||
);
|
||||
|
||||
if !matched {
|
||||
std::process::exit(2);
|
||||
}
|
||||
}
|
||||
|
||||
fn run_baseline(
|
||||
model: &Qwen3,
|
||||
cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
prompt_ids: &[u32],
|
||||
gen_tokens: usize,
|
||||
) -> RunStats {
|
||||
let slot = 0;
|
||||
cache.register_sequence(slot).expect("register target slot");
|
||||
|
||||
let t0 = Instant::now();
|
||||
let prefill_start = Instant::now();
|
||||
let logits = model.forward_prefill_paged(prompt_ids, slot, cache);
|
||||
sync_device();
|
||||
let prefill_s = prefill_start.elapsed().as_secs_f64();
|
||||
|
||||
let mut generated = Vec::with_capacity(gen_tokens);
|
||||
let mut next = last_argmax(&logits);
|
||||
generated.push(next);
|
||||
|
||||
let decode_start = Instant::now();
|
||||
let mut target_steps = 0usize;
|
||||
while generated.len() < gen_tokens && !tokenizer.is_eos(next) {
|
||||
let pos = cache.seq_len(slot);
|
||||
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], cache);
|
||||
target_steps += 1;
|
||||
next = last_argmax(&logits);
|
||||
generated.push(next);
|
||||
}
|
||||
sync_device();
|
||||
let decode_s = decode_start.elapsed().as_secs_f64();
|
||||
sync_device();
|
||||
let total_s = t0.elapsed().as_secs_f64();
|
||||
|
||||
cache.free_sequence(slot);
|
||||
RunStats {
|
||||
ids: generated,
|
||||
total_s,
|
||||
prefill_s,
|
||||
decode_s,
|
||||
target_steps,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_speculative(
|
||||
target: &Qwen3,
|
||||
draft: &Qwen3,
|
||||
target_cache: &mut PagedKVCache,
|
||||
target_verify_cache: &mut PagedKVCache,
|
||||
draft_cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
prompt_ids: &[u32],
|
||||
gen_tokens: usize,
|
||||
gamma: usize,
|
||||
use_verify_logits: bool,
|
||||
verify_path: VerifyPath,
|
||||
dump_verify_mismatches: bool,
|
||||
) -> RunStats {
|
||||
let slot = 0;
|
||||
target_cache
|
||||
.register_sequence(slot)
|
||||
.expect("register target slot");
|
||||
target_verify_cache
|
||||
.register_sequence(slot)
|
||||
.expect("register target verify slot");
|
||||
draft_cache
|
||||
.register_sequence(slot)
|
||||
.expect("register draft slot");
|
||||
|
||||
let t0 = Instant::now();
|
||||
let prefill_start = Instant::now();
|
||||
let target_logits = target.forward_prefill_paged(prompt_ids, slot, target_cache);
|
||||
if !use_verify_logits {
|
||||
let _ = target.forward_prefill_paged(prompt_ids, slot, target_verify_cache);
|
||||
}
|
||||
let draft_logits = draft.forward_prefill_paged(prompt_ids, slot, draft_cache);
|
||||
sync_device();
|
||||
let prefill_s = prefill_start.elapsed().as_secs_f64();
|
||||
|
||||
let mut target_next = last_argmax(&target_logits);
|
||||
let mut draft_next = last_argmax(&draft_logits);
|
||||
let mut generated = Vec::with_capacity(gen_tokens);
|
||||
let mut accepted_total = 0usize;
|
||||
let mut proposed_total = 0usize;
|
||||
let mut verify_steps = 0usize;
|
||||
let mut mirror_steps = 0usize;
|
||||
let mut commit_steps = 0usize;
|
||||
let mut correction_steps = 0usize;
|
||||
let mut verify_decode_mismatches = 0usize;
|
||||
|
||||
let decode_start = Instant::now();
|
||||
while generated.len() < gen_tokens {
|
||||
let remaining = gen_tokens - generated.len();
|
||||
let round_gamma = gamma.min(remaining);
|
||||
let round_start_len = target_cache.seq_len(slot);
|
||||
assert_eq!(
|
||||
round_start_len,
|
||||
draft_cache.seq_len(slot),
|
||||
"target and draft cache lengths diverged"
|
||||
);
|
||||
if !use_verify_logits {
|
||||
assert_eq!(
|
||||
round_start_len,
|
||||
target_verify_cache.seq_len(slot),
|
||||
"target verify cache length diverged"
|
||||
);
|
||||
}
|
||||
|
||||
let mut draft_tokens = Vec::with_capacity(round_gamma);
|
||||
for _ in 0..round_gamma {
|
||||
let token = draft_next;
|
||||
draft_tokens.push(token);
|
||||
if tokenizer.is_eos(token) {
|
||||
break;
|
||||
}
|
||||
let pos = draft_cache.seq_len(slot);
|
||||
let logits = draft.forward_decode_paged(&[token], &[pos], &[slot], draft_cache);
|
||||
draft_next = last_argmax(&logits);
|
||||
}
|
||||
proposed_total += draft_tokens.len();
|
||||
|
||||
if use_verify_logits {
|
||||
verify_steps += 1;
|
||||
let verify_logits =
|
||||
target.forward_verify_paged_decode_attention(&draft_tokens, slot, target_cache);
|
||||
let verify_argmax = argmax_rows(&verify_logits);
|
||||
assert_eq!(
|
||||
verify_argmax.len(),
|
||||
draft_tokens.len(),
|
||||
"verify logits rows must match draft token count"
|
||||
);
|
||||
|
||||
let mut accepted = 0usize;
|
||||
let mut done = false;
|
||||
while accepted < draft_tokens.len() {
|
||||
let expected = if accepted > 0 {
|
||||
verify_argmax[accepted - 1]
|
||||
} else {
|
||||
target_next
|
||||
};
|
||||
if draft_tokens[accepted] != expected {
|
||||
break;
|
||||
}
|
||||
let token = draft_tokens[accepted];
|
||||
generated.push(token);
|
||||
accepted_total += 1;
|
||||
accepted += 1;
|
||||
|
||||
if generated.len() >= gen_tokens || tokenizer.is_eos(token) {
|
||||
done = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if accepted > 0 {
|
||||
target_next = verify_argmax[accepted - 1];
|
||||
}
|
||||
target_cache
|
||||
.truncate_sequence(slot, round_start_len + accepted)
|
||||
.unwrap();
|
||||
|
||||
if done {
|
||||
draft_cache
|
||||
.truncate_sequence(slot, target_cache.seq_len(slot))
|
||||
.unwrap();
|
||||
break;
|
||||
}
|
||||
|
||||
if accepted == draft_tokens.len() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let correction = if accepted > 0 {
|
||||
verify_argmax[accepted - 1]
|
||||
} else {
|
||||
target_next
|
||||
};
|
||||
generated.push(correction);
|
||||
|
||||
draft_cache
|
||||
.truncate_sequence(slot, round_start_len)
|
||||
.unwrap();
|
||||
replay_draft_tokens(
|
||||
draft,
|
||||
draft_cache,
|
||||
slot,
|
||||
&draft_tokens[..accepted],
|
||||
&mut draft_next,
|
||||
);
|
||||
|
||||
if generated.len() >= gen_tokens || tokenizer.is_eos(correction) {
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = target_cache.seq_len(slot);
|
||||
let logits = target.forward_decode_paged(&[correction], &[pos], &[slot], target_cache);
|
||||
target_next = last_argmax(&logits);
|
||||
commit_steps += 1;
|
||||
|
||||
let pos = draft_cache.seq_len(slot);
|
||||
let logits = draft.forward_decode_paged(&[correction], &[pos], &[slot], draft_cache);
|
||||
draft_next = last_argmax(&logits);
|
||||
correction_steps += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
verify_steps += 1;
|
||||
let verify_logits = match verify_path {
|
||||
VerifyPath::Flash => {
|
||||
target.forward_prefill_paged(&draft_tokens, slot, target_verify_cache)
|
||||
}
|
||||
VerifyPath::PagedDecode => target.forward_verify_paged_decode_attention(
|
||||
&draft_tokens,
|
||||
slot,
|
||||
target_verify_cache,
|
||||
),
|
||||
};
|
||||
let verify_argmax = argmax_rows(&verify_logits);
|
||||
assert_eq!(
|
||||
verify_argmax.len(),
|
||||
draft_tokens.len(),
|
||||
"verify logits rows must match draft token count"
|
||||
);
|
||||
|
||||
target_verify_cache
|
||||
.truncate_sequence(slot, round_start_len)
|
||||
.unwrap();
|
||||
|
||||
let mut accepted = 0usize;
|
||||
let mut done = false;
|
||||
while accepted < draft_tokens.len() {
|
||||
let expected = if use_verify_logits && accepted > 0 {
|
||||
verify_argmax[accepted - 1]
|
||||
} else {
|
||||
target_next
|
||||
};
|
||||
if draft_tokens[accepted] != expected {
|
||||
break;
|
||||
}
|
||||
let token_idx = accepted;
|
||||
let token = draft_tokens[token_idx];
|
||||
generated.push(token);
|
||||
accepted_total += 1;
|
||||
accepted += 1;
|
||||
|
||||
if generated.len() >= gen_tokens || tokenizer.is_eos(token) {
|
||||
done = true;
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = target_cache.seq_len(slot);
|
||||
let logits = target.forward_decode_paged(&[token], &[pos], &[slot], target_cache);
|
||||
let decode_next = last_argmax(&logits);
|
||||
if verify_argmax[token_idx] != decode_next {
|
||||
verify_decode_mismatches += 1;
|
||||
eprintln!(
|
||||
"VERIFY/DECODE MISMATCH at cache_len={} accepted_idx={}: verify={} decode={}",
|
||||
target_cache.seq_len(slot),
|
||||
token_idx,
|
||||
verify_argmax[token_idx],
|
||||
decode_next
|
||||
);
|
||||
if dump_verify_mismatches {
|
||||
eprintln!(
|
||||
" verify_top5={} decode_top5={}",
|
||||
format_topk(&verify_logits, token_idx, 5),
|
||||
format_topk(&logits, 0, 5)
|
||||
);
|
||||
}
|
||||
}
|
||||
target_next = decode_next;
|
||||
commit_steps += 1;
|
||||
|
||||
advance_target_cache(target, target_verify_cache, slot, token);
|
||||
mirror_steps += 1;
|
||||
}
|
||||
if done {
|
||||
draft_cache
|
||||
.truncate_sequence(slot, target_cache.seq_len(slot))
|
||||
.unwrap();
|
||||
target_verify_cache
|
||||
.truncate_sequence(slot, target_cache.seq_len(slot))
|
||||
.unwrap();
|
||||
break;
|
||||
}
|
||||
|
||||
if accepted == draft_tokens.len() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let correction = if use_verify_logits && accepted > 0 {
|
||||
verify_argmax[accepted - 1]
|
||||
} else {
|
||||
target_next
|
||||
};
|
||||
generated.push(correction);
|
||||
|
||||
draft_cache
|
||||
.truncate_sequence(slot, round_start_len)
|
||||
.unwrap();
|
||||
replay_draft_tokens(
|
||||
draft,
|
||||
draft_cache,
|
||||
slot,
|
||||
&draft_tokens[..accepted],
|
||||
&mut draft_next,
|
||||
);
|
||||
|
||||
if generated.len() >= gen_tokens || tokenizer.is_eos(correction) {
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = target_cache.seq_len(slot);
|
||||
let logits = target.forward_decode_paged(&[correction], &[pos], &[slot], target_cache);
|
||||
target_next = last_argmax(&logits);
|
||||
commit_steps += 1;
|
||||
|
||||
advance_target_cache(target, target_verify_cache, slot, correction);
|
||||
mirror_steps += 1;
|
||||
|
||||
let pos = draft_cache.seq_len(slot);
|
||||
let logits = draft.forward_decode_paged(&[correction], &[pos], &[slot], draft_cache);
|
||||
draft_next = last_argmax(&logits);
|
||||
correction_steps += 1;
|
||||
}
|
||||
sync_device();
|
||||
let decode_s = decode_start.elapsed().as_secs_f64();
|
||||
sync_device();
|
||||
let total_s = t0.elapsed().as_secs_f64();
|
||||
|
||||
target_cache.free_sequence(slot);
|
||||
target_verify_cache.free_sequence(slot);
|
||||
draft_cache.free_sequence(slot);
|
||||
|
||||
RunStats {
|
||||
ids: generated,
|
||||
total_s,
|
||||
prefill_s,
|
||||
decode_s,
|
||||
target_steps: verify_steps + mirror_steps + commit_steps + correction_steps,
|
||||
accepted: accepted_total,
|
||||
proposed: proposed_total,
|
||||
verify_steps,
|
||||
mirror_steps,
|
||||
commit_steps,
|
||||
correction_steps,
|
||||
verify_decode_mismatches,
|
||||
}
|
||||
}
|
||||
|
||||
fn advance_target_cache(target: &Qwen3, cache: &mut PagedKVCache, slot: usize, token: u32) {
|
||||
let pos = cache.seq_len(slot);
|
||||
let _ = target.forward_decode_paged(&[token], &[pos], &[slot], cache);
|
||||
}
|
||||
|
||||
fn replay_draft_tokens(
|
||||
draft: &Qwen3,
|
||||
cache: &mut PagedKVCache,
|
||||
slot: usize,
|
||||
tokens: &[u32],
|
||||
next: &mut u32,
|
||||
) {
|
||||
for &token in tokens {
|
||||
let pos = cache.seq_len(slot);
|
||||
let logits = draft.forward_decode_paged(&[token], &[pos], &[slot], cache);
|
||||
*next = last_argmax(&logits);
|
||||
}
|
||||
}
|
||||
|
||||
fn new_cache(config: &ModelConfig, max_seq_len: usize, device: u32) -> PagedKVCache {
|
||||
new_cache_with_rows(config, max_seq_len, device, 1)
|
||||
}
|
||||
|
||||
fn new_cache_with_rows(
|
||||
config: &ModelConfig,
|
||||
max_seq_len: usize,
|
||||
device: u32,
|
||||
max_rows: usize,
|
||||
) -> PagedKVCache {
|
||||
let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE);
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
PagedKVCache::new(
|
||||
config,
|
||||
total_blocks,
|
||||
0,
|
||||
max_rows.max(1),
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
device,
|
||||
)
|
||||
}
|
||||
|
||||
fn argmax_rows(logits: &Tensor) -> Vec<u32> {
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
if logits.dtype() == DType::BF16
|
||||
&& matches!(logits.device(), Device::Cuda(_))
|
||||
&& logits.is_contiguous()
|
||||
{
|
||||
return xserv_kernels::argmax_bf16_to_host(logits);
|
||||
}
|
||||
|
||||
let vocab_size = logits.shape()[1];
|
||||
let rows = logits.shape()[0];
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
match logits.dtype() {
|
||||
DType::F32 => logits_cpu
|
||||
.as_slice::<f32>()
|
||||
.chunks_exact(vocab_size)
|
||||
.take(rows)
|
||||
.map(argmax_f32)
|
||||
.collect(),
|
||||
DType::BF16 => logits_cpu
|
||||
.as_slice::<bf16>()
|
||||
.chunks_exact(vocab_size)
|
||||
.take(rows)
|
||||
.map(|row| {
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(i, _)| i as u32)
|
||||
.unwrap()
|
||||
})
|
||||
.collect(),
|
||||
_ => panic!("unsupported dtype for argmax: {:?}", logits.dtype()),
|
||||
}
|
||||
}
|
||||
|
||||
fn last_argmax(logits: &Tensor) -> u32 {
|
||||
*argmax_rows(logits).last().unwrap()
|
||||
}
|
||||
|
||||
fn argmax_f32(row: &[f32]) -> u32 {
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
||||
.map(|(i, _)| i as u32)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn format_topk(logits: &Tensor, row: usize, k: usize) -> String {
|
||||
let vals = topk_row(logits, row, k);
|
||||
vals.iter()
|
||||
.map(|(id, val)| format!("{id}:{val:.3}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
}
|
||||
|
||||
fn topk_row(logits: &Tensor, row: usize, k: usize) -> Vec<(u32, f32)> {
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
let vocab_size = logits.shape()[1];
|
||||
assert!(row < logits.shape()[0], "topk row out of bounds");
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let mut vals: Vec<(u32, f32)> = match logits.dtype() {
|
||||
DType::F32 => logits_cpu.as_slice::<f32>()[row * vocab_size..(row + 1) * vocab_size]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &v)| (i as u32, v))
|
||||
.collect(),
|
||||
DType::BF16 => logits_cpu.as_slice::<bf16>()[row * vocab_size..(row + 1) * vocab_size]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &v)| (i as u32, v.to_f32()))
|
||||
.collect(),
|
||||
_ => panic!("unsupported dtype for topk: {:?}", logits.dtype()),
|
||||
};
|
||||
vals.select_nth_unstable_by(k, |a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
vals.truncate(k);
|
||||
vals.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
vals
|
||||
}
|
||||
|
||||
fn assert_qwen3(config: &ModelConfig, name: &str) {
|
||||
let model_type = config.model_type.as_deref().unwrap_or("unknown");
|
||||
assert!(
|
||||
model_type.contains("qwen"),
|
||||
"{name} model_type must be qwen-like, got {model_type}"
|
||||
);
|
||||
}
|
||||
|
||||
fn warn_if_tokenizers_differ(target_dir: &Path, draft_dir: &Path) {
|
||||
let target = std::fs::read(target_dir.join("tokenizer.json"));
|
||||
let draft = std::fs::read(draft_dir.join("tokenizer.json"));
|
||||
if let (Ok(target), Ok(draft)) = (target, draft) {
|
||||
if target != draft {
|
||||
eprintln!(
|
||||
"WARNING: target and draft tokenizer.json differ; v0 assumes identical token ids"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn arg_usize(args: &[String], flag: &str, default: usize) -> usize {
|
||||
args.iter()
|
||||
.position(|a| a == flag)
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn parse_verify_path(args: &[String], use_verify_logits: bool) -> VerifyPath {
|
||||
let default = if use_verify_logits {
|
||||
VerifyPath::PagedDecode
|
||||
} else {
|
||||
VerifyPath::Flash
|
||||
};
|
||||
let Some(value) = args
|
||||
.iter()
|
||||
.position(|a| a == "--verify-path")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
else {
|
||||
return default;
|
||||
};
|
||||
match value.as_str() {
|
||||
"flash" => VerifyPath::Flash,
|
||||
"paged-decode" => VerifyPath::PagedDecode,
|
||||
_ => {
|
||||
eprintln!("unknown --verify-path {value:?}; expected flash or paged-decode");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_length_budget(prompt_ids: &[u32], gen_tokens: usize, max_seq_len: usize, prompt: &str) {
|
||||
let required = prompt_ids.len() + gen_tokens;
|
||||
if required > max_seq_len {
|
||||
eprintln!(
|
||||
"prompt requires prompt_len({}) + gen_tokens({}) = {} tokens, exceeding --max-seq-len {}: {:?}",
|
||||
prompt_ids.len(),
|
||||
gen_tokens,
|
||||
required,
|
||||
max_seq_len,
|
||||
prompt
|
||||
);
|
||||
std::process::exit(2);
|
||||
}
|
||||
}
|
||||
|
||||
fn sync_device() {
|
||||
xserv_cuda::device::synchronize().expect("cuda device synchronize");
|
||||
}
|
||||
|
||||
fn ratio(num: usize, den: usize) -> f64 {
|
||||
if den == 0 {
|
||||
0.0
|
||||
} else {
|
||||
num as f64 / den as f64
|
||||
}
|
||||
}
|
||||
|
||||
fn speedup(baseline_s: f64, spec_s: f64) -> f64 {
|
||||
if spec_s == 0.0 {
|
||||
0.0
|
||||
} else {
|
||||
baseline_s / spec_s
|
||||
}
|
||||
}
|
||||
|
||||
fn tok_s(tokens: usize, seconds: f64) -> f64 {
|
||||
if seconds == 0.0 {
|
||||
0.0
|
||||
} else {
|
||||
tokens as f64 / seconds
|
||||
}
|
||||
}
|
||||
|
||||
fn per_token_ms(seconds: f64, tokens: usize) -> f64 {
|
||||
if tokens == 0 {
|
||||
0.0
|
||||
} else {
|
||||
seconds * 1000.0 / tokens as f64
|
||||
}
|
||||
}
|
||||
@@ -486,6 +486,35 @@ impl PagedKVCache {
|
||||
state.seq_len += num_tokens;
|
||||
}
|
||||
|
||||
/// Roll a registered sequence back to `new_len` tokens.
|
||||
///
|
||||
/// This only changes cache metadata and frees whole physical blocks that are
|
||||
/// no longer reachable. Bytes inside retained blocks are left untouched; the
|
||||
/// logical `seq_len` prevents attention from reading them, and later writes
|
||||
/// to the same positions overwrite them.
|
||||
pub fn truncate_sequence(&mut self, slot: usize, new_len: usize) -> Result<(), &'static str> {
|
||||
if slot >= self.max_seqs {
|
||||
return Err("truncate_sequence: slot out of range");
|
||||
}
|
||||
let state = self.seq_states[slot]
|
||||
.as_mut()
|
||||
.ok_or("truncate_sequence: empty slot")?;
|
||||
if new_len > state.seq_len {
|
||||
return Err("truncate_sequence: cannot extend");
|
||||
}
|
||||
|
||||
let needed_blocks = ((new_len + BLOCK_SIZE - 1) / BLOCK_SIZE).max(1);
|
||||
while state.block_ids.len() > needed_blocks {
|
||||
let block = state.block_ids.pop().expect("checked len");
|
||||
match state.location {
|
||||
Location::Gpu => self.allocator.free(block),
|
||||
Location::Cpu => self.cpu_allocator.free(block),
|
||||
}
|
||||
}
|
||||
state.seq_len = new_len;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Refresh the host-side block table + context lens from `seq_states`,
|
||||
/// then upload to GPU. Call once per decode step before the paged kernel.
|
||||
pub fn sync_to_gpu(&mut self) {
|
||||
@@ -748,6 +777,71 @@ impl PagedKVCache {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn tiny_config() -> ModelConfig {
|
||||
serde_json::from_value(serde_json::json!({
|
||||
"model_type": "qwen3",
|
||||
"hidden_size": 8,
|
||||
"intermediate_size": 16,
|
||||
"num_attention_heads": 1,
|
||||
"num_key_value_heads": 1,
|
||||
"num_hidden_layers": 1,
|
||||
"vocab_size": 32,
|
||||
"max_position_embeddings": 64
|
||||
}))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_sequence_frees_whole_blocks_and_keeps_slot_registered() {
|
||||
if xserv_cuda::device::set_device(0).is_err() {
|
||||
eprintln!("skipping CUDA-backed PagedKVCache test: device 0 unavailable");
|
||||
return;
|
||||
}
|
||||
|
||||
let config = tiny_config();
|
||||
let mut cache = PagedKVCache::new(&config, 5, 0, 1, 4, DType::BF16, 0);
|
||||
|
||||
assert_eq!(
|
||||
cache.truncate_sequence(1, 0),
|
||||
Err("truncate_sequence: slot out of range")
|
||||
);
|
||||
assert_eq!(
|
||||
cache.truncate_sequence(0, 0),
|
||||
Err("truncate_sequence: empty slot")
|
||||
);
|
||||
|
||||
cache.register_sequence(0).unwrap();
|
||||
cache.ensure_capacity(0, BLOCK_SIZE * 3 + 1);
|
||||
cache.advance_seq_len(0, BLOCK_SIZE * 3 + 1);
|
||||
assert_eq!(cache.seq_len(0), BLOCK_SIZE * 3 + 1);
|
||||
assert_eq!(cache.block_count(0), 4);
|
||||
assert_eq!(cache.free_blocks(), 0);
|
||||
|
||||
cache.truncate_sequence(0, BLOCK_SIZE + 1).unwrap();
|
||||
assert_eq!(cache.seq_len(0), BLOCK_SIZE + 1);
|
||||
assert_eq!(cache.block_count(0), 2);
|
||||
assert_eq!(cache.free_blocks(), 2);
|
||||
|
||||
cache.truncate_sequence(0, BLOCK_SIZE).unwrap();
|
||||
assert_eq!(cache.seq_len(0), BLOCK_SIZE);
|
||||
assert_eq!(cache.block_count(0), 1);
|
||||
assert_eq!(cache.free_blocks(), 3);
|
||||
|
||||
cache.truncate_sequence(0, 0).unwrap();
|
||||
assert_eq!(cache.seq_len(0), 0);
|
||||
assert_eq!(cache.block_count(0), 1);
|
||||
assert_eq!(cache.free_blocks(), 3);
|
||||
assert_eq!(
|
||||
cache.truncate_sequence(0, 1),
|
||||
Err("truncate_sequence: cannot extend")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn tensor_from_owned_buf(
|
||||
buf: GpuBuffer,
|
||||
shape: &[usize],
|
||||
|
||||
@@ -884,6 +884,109 @@ impl Qwen3 {
|
||||
matmul_2d(&x, &self.lm_head_t)
|
||||
}
|
||||
|
||||
/// Paged multi-token verify path: write `token_ids` into the paged cache,
|
||||
/// then verify them with the same paged decode attention kernel used by
|
||||
/// single-token decode. This keeps greedy top-1 behavior aligned with
|
||||
/// `forward_decode_paged` while still batching the dense projections/MLP
|
||||
/// across the draft window.
|
||||
pub fn forward_verify_paged_decode_attention(
|
||||
&self,
|
||||
token_ids: &[u32],
|
||||
slot: usize,
|
||||
paged_cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
let pos_offset = paged_cache.seq_len(slot);
|
||||
let num_heads = self.local_num_heads;
|
||||
let num_kv_heads = self.local_num_kv_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
|
||||
paged_cache.advance_seq_len(slot, new_tokens);
|
||||
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
|
||||
.map(|p| p as u32)
|
||||
.collect();
|
||||
let kv_lens: Vec<i32> = (0..new_tokens)
|
||||
.map(|i| (pos_offset + i + 1) as i32)
|
||||
.collect();
|
||||
let slots = vec![slot; new_tokens];
|
||||
paged_cache.sync_active_batch_with_lens(&slots, &kv_lens);
|
||||
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
|
||||
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
|
||||
let max_blocks = paged_cache.max_blocks_per_seq();
|
||||
|
||||
let mut x = embedding(&self.embed_tokens, token_ids);
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
let qkv = matmul_rows_gemv(&normed, &layer.qkv_proj_wt);
|
||||
let q_dim = num_heads * head_dim;
|
||||
let kv_dim = num_kv_heads * head_dim;
|
||||
let q_all = qkv.narrow(1, 0, q_dim);
|
||||
let k_all = qkv.narrow(1, q_dim, kv_dim);
|
||||
let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);
|
||||
|
||||
let q_flat = q_all
|
||||
.contiguous()
|
||||
.reshape(&[new_tokens * num_heads, head_dim]);
|
||||
let k_flat = k_all
|
||||
.contiguous()
|
||||
.reshape(&[new_tokens * num_kv_heads, head_dim]);
|
||||
let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
|
||||
let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
|
||||
|
||||
let q_3d = q_normed.reshape(&[new_tokens, num_heads, head_dim]);
|
||||
let k_3d = k_normed.reshape(&[new_tokens, num_kv_heads, head_dim]);
|
||||
rope_inplace(&q_3d, &self.rope_cache, &positions);
|
||||
rope_inplace(&k_3d, &self.rope_cache, &positions);
|
||||
|
||||
let v_3d = v_all
|
||||
.contiguous()
|
||||
.reshape(&[new_tokens, num_kv_heads, head_dim]);
|
||||
paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, new_tokens);
|
||||
|
||||
let q_decode = q_3d.reshape(&[new_tokens, num_heads, 1, head_dim]);
|
||||
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let attn_out = xserv_kernels::paged_decode_attention(
|
||||
&q_decode,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
bt_ptr,
|
||||
cl_ptr,
|
||||
new_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
max_blocks,
|
||||
);
|
||||
|
||||
let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]);
|
||||
let attn_proj = matmul_rows_gemv(&attn_merged, &layer.o_proj_wt);
|
||||
self.all_reduce(&attn_proj);
|
||||
|
||||
let (normed, x_new) =
|
||||
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate_up = matmul_rows_gemv(&normed, &layer.gate_up_proj_wt);
|
||||
let ffn_dim = gate_up.shape()[1] / 2;
|
||||
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
|
||||
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_rows_gemv(&hidden_states, &layer.down_proj_wt);
|
||||
self.all_reduce(&down);
|
||||
x = add_any(&residual, &down);
|
||||
}
|
||||
|
||||
let x = rmsnorm(&x, &self.norm, eps);
|
||||
matmul_rows_gemv(&x, &self.lm_head_t)
|
||||
}
|
||||
|
||||
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
|
||||
pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
@@ -1158,6 +1261,20 @@ fn row_view(t: &Tensor, row: usize) -> Tensor {
|
||||
)
|
||||
}
|
||||
|
||||
/// Run a 2D matmul row by row so each row uses the same GEMV kernel as
|
||||
/// single-token decode. Used by speculative verify parity, where near-tie
|
||||
/// logits must follow decode's BF16 rounding path.
|
||||
fn matmul_rows_gemv(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert_eq!(a.ndim(), 2);
|
||||
assert!(a.is_contiguous());
|
||||
let rows = a.shape()[0];
|
||||
if rows == 1 {
|
||||
return matmul_2d(a, b);
|
||||
}
|
||||
let out_rows: Vec<Tensor> = (0..rows).map(|i| matmul_2d(&row_view(a, i), b)).collect();
|
||||
concat_rows(&out_rows)
|
||||
}
|
||||
|
||||
/// Concatenate row tensors [1, cols] into a single [B, cols] tensor via D2D memcpy.
|
||||
fn concat_rows(rows: &[Tensor]) -> Tensor {
|
||||
assert!(!rows.is_empty());
|
||||
|
||||
186
docs/22-speculative-decoding.md
Normal file
186
docs/22-speculative-decoding.md
Normal file
@@ -0,0 +1,186 @@
|
||||
# Phase 22: Draft-Model Speculative Decoding v0
|
||||
|
||||
> 目标:实现一个可验证的 speculative decoding 最小闭环。先只覆盖
|
||||
> Qwen3 target + 同 tokenizer 的小 Qwen3 draft、batch=1、greedy
|
||||
> (`temperature=0`)。本阶段不做 gpt-oss,不做 sampling rejection,不接入
|
||||
> continuous batching。
|
||||
|
||||
## 1. Scope
|
||||
|
||||
本阶段只解决一个窄问题:
|
||||
|
||||
- target:现有 Qwen3 paged KV 路径,优先 Qwen3-8B;
|
||||
- draft:同 tokenizer 的小 Qwen3,例如 Qwen3-0.6B;
|
||||
- batch size:1;
|
||||
- decoding:greedy argmax;
|
||||
- draft window:`gamma=4`;
|
||||
- acceptance:exact-match,即 `target_argmax == draft_token`。
|
||||
|
||||
HTTP flag 可以后续接入。v0 先提供独立 bench/CLI,因为它能直接输出 token
|
||||
一致性、acceptance rate、tokens/target-step、TPOT/tok/s,也避免把尚未稳定的
|
||||
rollback 行为放进服务端调度循环。
|
||||
|
||||
bench 为了让 baseline/spec 对比不受跨 prompt KV pool 复用影响,每个 prompt 的
|
||||
baseline run 和 speculative run 都使用新建的 paged KV cache。cache 分配发生在
|
||||
单次 run 的计时外,输出的 TPOT/tok/s 只覆盖模型 prefill/decode 工作。
|
||||
|
||||
## 2. Why Qwen3 First
|
||||
|
||||
Qwen3 是现有代码里最适合作为 speculative v0 的模型族:
|
||||
|
||||
1. target 已有稳定的 `forward_prefill_paged` 和 `forward_decode_paged`;
|
||||
2. 小 Qwen3 与 Qwen3-8B 共享 tokenizer,可以直接比较 token id;
|
||||
3. Qwen3 是 dense decoder-only,没有 gpt-oss 的 harmony 格式、MoE sparse 路径、
|
||||
sliding-window 或 CUDA Graph 状态;
|
||||
4. greedy 输出的正确性定义简单:只要 spec 生成的 token 序列与纯 target greedy
|
||||
完全一致即可。
|
||||
|
||||
gpt-oss spec 需要先定义 harmony prompt、MoE draft 选择、graph replay 与 rollback
|
||||
的交互,这些都不属于本阶段。
|
||||
|
||||
## 3. Algorithm
|
||||
|
||||
对每个 prompt 建两套模型、三套 KV 状态:
|
||||
|
||||
```text
|
||||
target model + target commit PagedKVCache
|
||||
target model + target verify PagedKVCache
|
||||
draft model + draft PagedKVCache
|
||||
```
|
||||
|
||||
先把 prompt 分别 prefill 到三套 cache。此时 cache 都包含 prompt,并各自持有
|
||||
"下一个 token" 的 logits。
|
||||
|
||||
每个 speculative round:
|
||||
|
||||
1. draft 从当前 draft logits 取 argmax,连续生成 `gamma` 个 draft token;
|
||||
2. draft 每生成一个 token 就用自己的 paged decode append 到 draft KV,所以 round
|
||||
结束时 draft cache 暂时包含整个草稿序列;
|
||||
3. target verify cache 对完整 draft token 序列调用一次 paged prefill,覆盖
|
||||
"target 可一次验证草稿窗口" 这条执行路径;
|
||||
4. target verify cache 立刻 rollback 到 round 起点,避免把 prefill 临时写入污染
|
||||
commit cache;
|
||||
5. 用 target decode 轨迹作为权威结果,从左到右比较
|
||||
`target_next_argmax == draft_token`,只接受连续匹配前缀;
|
||||
6. 对每个接受 token,用 target decode 重放一次来提交 target KV,并得到下一步
|
||||
`target_next_argmax`;verify cache 也 mirror decode 同一个 token,保持长度与 prefix 对齐;
|
||||
7. 若全部匹配,draft cache 已经包含完整草稿,三套 cache 长度重新对齐;
|
||||
8. 若在第 `k` 个 token 拒绝,提交前 `k` 个 draft token,再提交 target 在该位置的
|
||||
argmax 作为修正 token。draft cache rollback 到 round 起点后重放接受 token 和修正
|
||||
token,target commit/verify cache 都由 decode 路径提交到同一 prefix。
|
||||
|
||||
v0 不使用完整 speculative sampling 的概率校正。它只利用小模型猜测 greedy 轨迹,
|
||||
因此生成序列必须与纯 target greedy 完全一致。
|
||||
|
||||
当前实现选择 decode 轨迹作为提交路径,而不是直接保留 target prefill 写入的 KV。
|
||||
原因是 v0 验收要求 token 序列与纯 target greedy 完全一致;如果 prefill 和 decode
|
||||
路径在数值或 KV 写入顺序上存在细微差异,直接提交 prefill KV 会让后续 greedy 输出
|
||||
漂移。这个保守实现仍会执行 target paged prefill 验证和 rollback,但 verify 写入放在
|
||||
独立 cache,不会影响权威 commit cache。代价是额外 mirror decode,速度收益预期较差,
|
||||
主要用于先验证 draft-model speculative 的状态机和一致性。
|
||||
|
||||
为保证 greedy exactness,decode 里两个原有非确定点也需要固定:
|
||||
|
||||
- BF16 GEMV 不再用跨 K-block `atomicAdd`;改为写 K-block partials,再按固定顺序
|
||||
reduce;
|
||||
- paged decode attention 不再用 `atomicAdd` 合并 warp 输出;改为 per-warp partials
|
||||
后按 warp id 顺序 reduce。
|
||||
|
||||
## 4. KV Commit And Rollback
|
||||
|
||||
现有 `forward_prefill_paged` 会一次性把传入 token 写进 paged KV,并提前推进
|
||||
`seq_len`。验证草稿时 target verify cache 因此会临时包含整个 draft window。
|
||||
|
||||
新增的 cache 操作只做逻辑截断:
|
||||
|
||||
```text
|
||||
truncate_sequence(slot, new_len)
|
||||
```
|
||||
|
||||
约束:
|
||||
|
||||
- 只允许 `new_len <= current_len`;
|
||||
- 保留覆盖 `[0, new_len)` 所需的物理 block;
|
||||
- 释放右侧多余 block;
|
||||
- 不清零仍在保留 block 内的旧字节,因为后续逻辑长度会阻止 attention 读取它们,
|
||||
同一位置再次写入时会覆盖旧值;
|
||||
- slot 仍保持 registered,`new_len=0` 时也保留第一个 block。
|
||||
|
||||
这让 target 和 draft 都能在拒绝时安全丢弃多写 KV,并在修正 token decode 后重新
|
||||
对齐。
|
||||
|
||||
## 5. Acceptance Criteria
|
||||
|
||||
本阶段验收:
|
||||
|
||||
- `cargo fmt`;
|
||||
- `cargo check`;
|
||||
- `cargo test`;
|
||||
- `bench-speculative` 可加载 target+draft 两套 Qwen3;
|
||||
- 50 prompts,greedy,baseline target 与 speculative token id 序列完全一致;
|
||||
- 输出 acceptance rate、tokens/target-step、TPOT、tok/s 和 speedup;
|
||||
- 若 draft 模型缺失或磁盘不足,明确报告阻塞条件,不盲目下载大模型。
|
||||
|
||||
## 6. Validation Results
|
||||
|
||||
dash5 环境:
|
||||
|
||||
- GPU:RTX 5090,device 0;
|
||||
- target:`/opt/wjh/models/qwen3-8b`;
|
||||
- draft:`/dashscope-tmp/wjh/models/qwen3-0.6b`;
|
||||
- command:`bench-speculative ... --prompts 50 --gen-tokens 32 --gamma 4 --device 0`;
|
||||
- log:`/dashscope-tmp/wjh/xserv-spec-default-50x32-final.log`。
|
||||
|
||||
默认 `acceptance_mode=decode` 的结果:
|
||||
|
||||
```text
|
||||
prompts=50 matched=true
|
||||
acceptance_rate=0.3664 accepted=1020 proposed=2784
|
||||
tokens_per_target_step=0.3639 target_steps=4397
|
||||
verify_steps=729 mirror_decode_steps=1550 commit_decode_steps=1550 correction_steps=568
|
||||
verify_decode_mismatches=10
|
||||
baseline_e2e_tpot_ms=13.123 baseline_e2e_tok_s=76.204
|
||||
spec_e2e_tpot_ms=44.867 spec_e2e_tok_s=22.288 speedup_e2e=0.2925
|
||||
baseline_decode_tpot_ms=12.638 baseline_decode_tok_s=79.127
|
||||
spec_decode_tpot_ms=43.731 spec_decode_tok_s=22.867 speedup_decode=0.2890
|
||||
decode_token_counts baseline=1600 spec=1600
|
||||
```
|
||||
|
||||
诊断 `--use-verify-logits` 的结果:
|
||||
|
||||
- command:`bench-speculative ... --prompts 10 --gen-tokens 32 --gamma 4 --device 0 --use-verify-logits`;
|
||||
- log:`/dashscope-tmp/wjh/xserv-spec-verify-logits-10x32.log`;
|
||||
- exit status:`2`;
|
||||
- summary:`matched=false`, `verify_decode_mismatches=4`;
|
||||
- prompt 0/2/7 出现 baseline/spec token 序列分叉。
|
||||
|
||||
结论:当前可以做 correctness-first 的 speculative decoding 状态机,但还不能把
|
||||
target batched prefill verify logits 作为 greedy 接受依据。verify prefill 路径与
|
||||
逐 token decode 路径存在 top-1 不一致;默认模式必须继续以 decode 轨迹为权威,
|
||||
因此 v0 是正确性闭环,不是性能优化。
|
||||
|
||||
## 7. Known Limits
|
||||
|
||||
- 只支持 batch=1;
|
||||
- 只支持 Qwen3-family dense models;
|
||||
- 只支持 greedy exact-match acceptance;
|
||||
- 未实现 probabilistic rejection sampling,所以 temperature/top-k/top-p 不支持;
|
||||
- 未接 HTTP/continuous batching;
|
||||
- 未与 CUDA Graph decode 结合;
|
||||
- 当前 v0 为保证 greedy exactness,接受 token 也会用 target decode 重放提交,因此
|
||||
即使 acceptance 高也可能变慢;
|
||||
- draft prefill 和 target prefill 都会计入端到端耗时,短输出可能没有收益。
|
||||
|
||||
## 8. Next Phase TODO
|
||||
|
||||
如果继续 speculative decoding,下一阶段不要先接 HTTP,应先解决 verify 路径:
|
||||
|
||||
1. 做最小 prefill-vs-decode parity harness:固定 prompt、cache len、draft token,
|
||||
dump 每层/最终 logits 的 top-k,定位 top-1 分叉来自 attention、GEMV 还是 KV 写入顺序;
|
||||
2. 让 `--use-verify-logits` 在至少 50 prompts x 64 tokens 下 `matched=true` 且
|
||||
`verify_decode_mismatches=0`;
|
||||
3. parity 过后再做真正 multi-token target commit:要么安全保留 verify prefill 写入的
|
||||
KV,要么实现专用 paged multi-token verify/commit kernel,避免当前的 mirror+commit
|
||||
decode 重放;
|
||||
4. 只有 `speedup_e2e > 1` 后再考虑 HTTP flag、continuous batching、sampling 或
|
||||
gpt-oss speculative decoding。
|
||||
85
docs/23-speculative-verify-parity.md
Normal file
85
docs/23-speculative-verify-parity.md
Normal file
@@ -0,0 +1,85 @@
|
||||
# Phase 23: Speculative Verify Parity
|
||||
|
||||
> 目标:把 speculative decoding 从 v0 的 correctness-only 状态机推进到
|
||||
> "verify logits 可作为权威接受依据"。本阶段仍只覆盖 Qwen3 target +
|
||||
> Qwen3 small draft、batch=1、greedy。
|
||||
|
||||
## 1. Problem
|
||||
|
||||
Phase 22 的默认模式用逐 token target decode 作为权威路径,因此输出能与 baseline
|
||||
一致。但诊断 `--use-verify-logits` 会失败:target 对 draft window 做 batched
|
||||
prefill verify 时,部分 logits top-1 与逐 token decode 不一致。
|
||||
|
||||
实测 top-k 显示分叉不是大幅数值错误,而是 BF16 near-tie:
|
||||
|
||||
```text
|
||||
verify_top5=17689:24.500,9856:24.375,...
|
||||
decode_top5=9856:24.500,17689:24.500,...
|
||||
```
|
||||
|
||||
如果直接用这些 verify logits 接受/拒绝 draft token,greedy token 序列会偏离纯
|
||||
target decode。
|
||||
|
||||
## 2. Design
|
||||
|
||||
新增 `Qwen3::forward_verify_paged_decode_attention`:
|
||||
|
||||
1. 在 target commit cache 上一次写入 draft window 的 K/V;
|
||||
2. attention 使用现有 paged decode attention,每个 draft token 对应一行 metadata,
|
||||
context lens 分别为 `pos + 1`;
|
||||
3. 线性层使用逐行 GEMV,与 `forward_decode_paged` 的 BF16 rounding path 对齐;
|
||||
4. 若 token 全接受,直接保留 verify 写入的 KV;
|
||||
5. 若在第 `k` 个 token 拒绝,把 target cache truncate 到 accepted prefix,再只
|
||||
decode 一个 correction token。
|
||||
|
||||
bench 新增:
|
||||
|
||||
- `--use-verify-logits`:用 verify logits 作为接受依据,默认选择 `paged-decode`
|
||||
verify path;
|
||||
- `--verify-path flash|paged-decode`:显式选择旧 flash prefill 诊断或新 paged-decode
|
||||
verify path;
|
||||
- `--dump-verify-mismatches`:打印 mismatch 行 top-k,用于定位 near-tie。
|
||||
|
||||
## 3. Validation
|
||||
|
||||
dash5:
|
||||
|
||||
- GPU:RTX 5090,device 0;
|
||||
- target:`/opt/wjh/models/qwen3-8b`;
|
||||
- draft:`/dashscope-tmp/wjh/models/qwen3-0.6b`;
|
||||
- command:`bench-speculative ... --prompts 50 --gen-tokens 64 --gamma 4 --device 0 --use-verify-logits`;
|
||||
- log:`/dashscope-tmp/wjh/xserv-spec-inplace-verify-50x64.log`。
|
||||
|
||||
结果:
|
||||
|
||||
```text
|
||||
prompts=50 matched=true
|
||||
acceptance_mode=verify_logits
|
||||
verify_path=paged-decode
|
||||
acceptance_rate=0.3927 accepted=2120 proposed=5398
|
||||
tokens_per_target_step=0.9112 target_steps=3512
|
||||
verify_steps=1376 mirror_decode_steps=0 commit_decode_steps=1068 correction_steps=1068
|
||||
verify_decode_mismatches=0
|
||||
baseline_e2e_tpot_ms=13.094 baseline_e2e_tok_s=76.372
|
||||
spec_e2e_tpot_ms=30.069 spec_e2e_tok_s=33.257 speedup_e2e=0.4355
|
||||
baseline_decode_tpot_ms=12.846 baseline_decode_tok_s=77.844
|
||||
spec_decode_tpot_ms=29.731 spec_decode_tok_s=33.635 speedup_decode=0.4321
|
||||
decode_token_counts baseline=3200 spec=3200
|
||||
```
|
||||
|
||||
对比 Phase 22 的保守 decode-authoritative v0:
|
||||
|
||||
- verify logits 现在可以作为权威接受依据;
|
||||
- `mirror_decode_steps` 从每个 accepted token 一次降为 0;
|
||||
- 50x64 e2e speedup 从约 0.29x 提升到 0.44x;
|
||||
- 仍未超过 baseline,因为 verify path 为了 parity 使用逐行 GEMV,且 draft acceptance
|
||||
只有约 39%。
|
||||
|
||||
## 4. Next TODO
|
||||
|
||||
下一阶段要从 correctness parity 转向性能:
|
||||
|
||||
1. 逐层替换 row-GEMV 为 batched GEMM,同时保留 near-tie fallback 或 top-k audit;
|
||||
2. 加一个 `--verify-audit-decode` 低频抽样审计,避免每轮都做 target decode;
|
||||
3. 扫 `gamma` 与 draft 选择,记录 acceptance 与 TPOT 曲线;
|
||||
4. `speedup_e2e > 1` 前不接 HTTP/continuous batching/gpt-oss spec。
|
||||
Reference in New Issue
Block a user