From ce7229f4fe557a8b5d74b5e97b8c900ac6ecd3a3 Mon Sep 17 00:00:00 2001
From: Gahow Wang
Date: Wed, 1 Jul 2026 14:15:39 +0800
Subject: [PATCH] speculative: Qwen3 draft-model v0 with paged verify parity

Phase 22 lands a correctness-only speculative decoding loop for Qwen3
target + Qwen3 small draft (batch=1, greedy, gamma=4). Phase 23 turns
verify logits into the authoritative acceptance signal so mirror-decode
per accepted token is no longer needed.

- paged_kv_cache: truncate_sequence(slot, new_len) shrinks a registered
  sequence, freeing whole physical blocks no longer reachable and leaving
  the slot registered. Covered by a CUDA-gated unit test.
- qwen3: forward_verify_paged_decode_attention writes the draft window
  into the target cache, runs the same paged decode attention kernel per
  draft token, and uses matmul_rows_gemv so linear layers follow the
  single-token decode BF16 rounding path.
- bench-speculative: new bench binary drives the state machine with
  --gamma / --gen-tokens / --prompts / --use-verify-logits /
  --verify-path flash|paged-decode / --dump-verify-mismatches, and
  compares baseline vs spec token sequences plus TPOT / tok/s / speedup.
- docs/22 records the decode-authoritative v0 result and dash5 numbers
  (matched=true, speedup_e2e ~0.29x, verify_decode_mismatches>0 under
  --use-verify-logits).
- docs/23 records the paged-decode verify path (matched=true,
  verify_decode_mismatches=0, 50x64 speedup_e2e ~0.44x) and the next-step
  performance TODO.
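
For reference, the greedy exact-match acceptance rule that the
--use-verify-logits mode relies on can be sketched as a pure function.
This is a minimal illustration rather than the code in
bench-speculative.rs: `greedy_accept` is a hypothetical helper name, and
the sketch assumes `verify_argmax[i]` holds the target's top-1 token
after it has consumed `draft[i]`.

```rust
/// Hypothetical helper (not part of this patch): greedy exact-match acceptance.
///
/// `target_next` is the target's top-1 token before the round starts, and
/// `verify_argmax[i]` is assumed to be the target's top-1 token after it has
/// consumed `draft[i]`. Returns how many draft tokens were accepted and the
/// token the target expects next (the correction token on a rejection).
fn greedy_accept(draft: &[u32], target_next: u32, verify_argmax: &[u32]) -> (usize, u32) {
    assert_eq!(
        draft.len(),
        verify_argmax.len(),
        "one verify row per draft token"
    );
    let mut accepted = 0;
    let mut expected = target_next;
    // Accept draft tokens while each one equals what the target would emit.
    while accepted < draft.len() && draft[accepted] == expected {
        // The row for the token just accepted predicts the next position.
        expected = verify_argmax[accepted];
        accepted += 1;
    }
    (accepted, expected)
}
```

With draft = [5, 6, 7, 8], target_next = 5 and verify_argmax = [6, 7, 9, 1],
three tokens are accepted and 9 becomes the correction token. When every
draft token is accepted, the second return value is simply the target's
next token for the following round.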
--- .../xserv-model/src/bin/bench-speculative.rs | 962 ++++++++++++++++++ crates/xserv-model/src/paged_kv_cache.rs | 94 ++ crates/xserv-model/src/qwen3.rs | 117 +++ docs/22-speculative-decoding.md | 186 ++++ docs/23-speculative-verify-parity.md | 85 ++ 5 files changed, 1444 insertions(+) create mode 100644 crates/xserv-model/src/bin/bench-speculative.rs create mode 100644 docs/22-speculative-decoding.md create mode 100644 docs/23-speculative-verify-parity.md diff --git a/crates/xserv-model/src/bin/bench-speculative.rs b/crates/xserv-model/src/bin/bench-speculative.rs new file mode 100644 index 0000000..a8fef1a --- /dev/null +++ b/crates/xserv-model/src/bin/bench-speculative.rs @@ -0,0 +1,962 @@ +//! Draft-model speculative decoding benchmark for Qwen3. +//! +//! v0 scope: +//! - target + draft are Qwen3-family models with the same tokenizer/vocab; +//! - batch=1; +//! - greedy exact-match acceptance; +//! - no probabilistic rejection sampling. + +use half::bf16; +use std::path::{Path, PathBuf}; +use std::time::Instant; + +use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader}; +use xserv_tensor::{DType, Device, Tensor}; +use xserv_tokenizer::Tokenizer; + +const DEFAULT_GAMMA: usize = 4; +const DEFAULT_GEN_TOKENS: usize = 64; +const DEFAULT_MAX_SEQ_LEN: usize = 2048; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum VerifyPath { + Flash, + PagedDecode, +} + +impl VerifyPath { + fn as_str(self) -> &'static str { + match self { + VerifyPath::Flash => "flash", + VerifyPath::PagedDecode => "paged-decode", + } + } +} + +const PROMPTS: [&str; 50] = [ + "The capital of France is", + "Once upon a time in a land far away", + "Hello, how are you doing today", + "In a shocking finding, scientists discovered a", + "The weather today is sunny, so I decided to", + "Alan Turing was a British mathematician who", + "The best way to learn programming is", + "Artificial intelligence will change the world because", + "The history of the internet began in the", + "A 
good morning routine starts with", + "The stock market crashed because investors", + "Deep learning is a subset of machine learning that", + "The president of the United States announced", + "In the year 2050, humans will", + "The secret to happiness is", + "When I was a child, I used to", + "The most important scientific discovery of the century", + "Climate change is caused by", + "The recipe for chocolate cake requires", + "In conclusion, the evidence suggests that", + "The cat sat on the mat and", + "According to recent studies, exercise can", + "The first step in solving any problem is", + "Technology has transformed the way we", + "The novel begins with the protagonist", + "Education is the most powerful weapon", + "The ocean covers more than seventy percent of", + "Last night I had a dream about", + "The company announced its quarterly earnings", + "Music has the power to", + "The difference between success and failure is", + "In the beginning, there was nothing but", + "The doctor told me that I should", + "Python is a popular programming language because", + "The ancient Romans built roads that", + "A balanced diet should include", + "The movie received mixed reviews from critics", + "Space exploration has led to many", + "The teacher asked the students to", + "Global warming is one of the most", + "The bridge collapsed due to structural", + "Quantum computing promises to revolutionize", + "The new policy will affect millions of", + "During the winter months, it is important to", + "The human brain contains approximately", + "Democracy depends on the active participation of", + "The train arrived at the station exactly", + "Researchers at MIT have developed a new", + "The smartphone has become an essential part of", + "After careful consideration, the committee decided to", +]; + +#[derive(Default)] +struct RunStats { + ids: Vec<u32>, + total_s: f64, + prefill_s: f64, + decode_s: f64, + target_steps: usize, + accepted: usize, + proposed: usize, + verify_steps: 
usize, + mirror_steps: usize, + commit_steps: usize, + correction_steps: usize, + verify_decode_mismatches: usize, +} + +#[derive(Default)] +struct Totals { + prompts: usize, + baseline_generated: usize, + spec_generated: usize, + baseline_total_s: f64, + baseline_prefill_s: f64, + baseline_decode_s: f64, + spec_total_s: f64, + spec_prefill_s: f64, + spec_decode_s: f64, + spec_target_steps: usize, + spec_accepted: usize, + spec_proposed: usize, + spec_verify_steps: usize, + spec_mirror_steps: usize, + spec_commit_steps: usize, + spec_correction_steps: usize, + spec_verify_decode_mismatches: usize, + mismatches: usize, +} + +fn main() { + let args: Vec<String> = std::env::args().collect(); + if args.len() < 3 { + eprintln!( + "Usage: bench-speculative \ + [--gen-tokens N] [--gamma N] [--prompts N] [--max-seq-len N] [--device N] \ + [--use-verify-logits] [--verify-path flash|paged-decode] [--dump-verify-mismatches]" + ); + std::process::exit(1); + } + + let target_dir = PathBuf::from(&args[1]); + let draft_dir = PathBuf::from(&args[2]); + let gen_tokens = arg_usize(&args, "--gen-tokens", DEFAULT_GEN_TOKENS); + let gamma = arg_usize(&args, "--gamma", DEFAULT_GAMMA); + let prompt_count = arg_usize(&args, "--prompts", PROMPTS.len()).min(PROMPTS.len()); + let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN); + let device = arg_usize(&args, "--device", 0) as u32; + let use_verify_logits = args.iter().any(|a| a == "--use-verify-logits"); + let verify_path = parse_verify_path(&args, use_verify_logits); + let dump_verify_mismatches = args.iter().any(|a| a == "--dump-verify-mismatches"); + + assert!(gen_tokens > 0, "--gen-tokens must be > 0"); + assert!(gamma > 0, "--gamma must be > 0"); + + xserv_cuda::device::set_device(device).unwrap(); + let info = xserv_cuda::device::device_info(device).unwrap(); + eprintln!( + "GPU {device}: {} ({} MB free)", + info.name, + info.free_memory / 1024 / 1024 + ); + + let target_config = 
ModelConfig::from_file(&target_dir.join("config.json")); + let draft_config = ModelConfig::from_file(&draft_dir.join("config.json")); + assert_qwen3(&target_config, "target"); + assert_qwen3(&draft_config, "draft"); + assert_eq!( + target_config.vocab_size, draft_config.vocab_size, + "target and draft vocab_size must match" + ); + + warn_if_tokenizers_differ(&target_dir, &draft_dir); + let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json")); + if tokenizer.vocab_size() != target_config.vocab_size { + eprintln!( + "WARNING: tokenizer decoder len {} differs from config vocab_size {}; continuing because token ids come from the shared tokenizer.json", + tokenizer.vocab_size(), + target_config.vocab_size + ); + } + + eprintln!( + "Loading target Qwen3: layers={} hidden={} heads={}/{} vocab={}", + target_config.num_layers(), + target_config.hidden(), + target_config.num_heads(), + target_config.num_kv_heads(), + target_config.vocab_size + ); + let target_weights = loader::load_model_dir(&target_dir, Device::Cuda(device)); + let target = Qwen3::from_weights(target_config.clone(), target_weights); + xserv_cuda::allocator::cached_trim(); + + eprintln!( + "Loading draft Qwen3: layers={} hidden={} heads={}/{} vocab={}", + draft_config.num_layers(), + draft_config.hidden(), + draft_config.num_heads(), + draft_config.num_kv_heads(), + draft_config.vocab_size + ); + let draft_weights = loader::load_model_dir(&draft_dir, Device::Cuda(device)); + let draft = Qwen3::from_weights(draft_config.clone(), draft_weights); + xserv_cuda::allocator::cached_trim(); + + let warm_ids = tokenizer.encode("warmup"); + let warm_tokens = gen_tokens.min(4); + { + let mut target_cache = new_cache(&target_config, max_seq_len, device); + let _ = run_baseline( + &target, + &mut target_cache, + &tokenizer, + &warm_ids, + warm_tokens, + ); + } + { + let mut target_cache = new_cache_with_rows( + &target_config, + max_seq_len, + device, + if use_verify_logits { gamma } else { 1 }, + ); + 
let mut target_verify_cache = + new_cache_with_rows(&target_config, max_seq_len, device, gamma); + let mut draft_cache = new_cache(&draft_config, max_seq_len, device); + let _ = run_speculative( + &target, + &draft, + &mut target_cache, + &mut target_verify_cache, + &mut draft_cache, + &tokenizer, + &warm_ids, + warm_tokens, + gamma, + use_verify_logits, + verify_path, + dump_verify_mismatches, + ); + } + eprintln!( + "Warmup done. Running {prompt_count} prompts, gen_tokens={gen_tokens}, gamma={gamma}, acceptance_mode={}, verify_path={}", + if use_verify_logits { + "verify_logits" + } else { + "decode" + }, + verify_path.as_str() + ); + + let mut totals = Totals::default(); + for (i, prompt) in PROMPTS.iter().take(prompt_count).enumerate() { + let ids = tokenizer.encode(prompt); + validate_length_budget(&ids, gen_tokens, max_seq_len, prompt); + let mut baseline_cache = new_cache(&target_config, max_seq_len, device); + let baseline = run_baseline(&target, &mut baseline_cache, &tokenizer, &ids, gen_tokens); + drop(baseline_cache); + + let mut target_cache = new_cache_with_rows( + &target_config, + max_seq_len, + device, + if use_verify_logits { gamma } else { 1 }, + ); + let mut target_verify_cache = + new_cache_with_rows(&target_config, max_seq_len, device, gamma); + let mut draft_cache = new_cache(&draft_config, max_seq_len, device); + let spec = run_speculative( + &target, + &draft, + &mut target_cache, + &mut target_verify_cache, + &mut draft_cache, + &tokenizer, + &ids, + gen_tokens, + gamma, + use_verify_logits, + verify_path, + dump_verify_mismatches, + ); + + let matched = baseline.ids == spec.ids; + if !matched { + totals.mismatches += 1; + eprintln!("MISMATCH prompt {i}: {prompt}"); + eprintln!(" baseline: {:?}", baseline.ids); + eprintln!(" spec: {:?}", spec.ids); + } + + println!( + "prompt={:02} match={} gen={} accept={}/{} target_steps={} \ + baseline_e2e_tpot_ms={:.3} spec_e2e_tpot_ms={:.3}", + i, + matched, + spec.ids.len(), + spec.accepted, + 
spec.proposed, + spec.target_steps, + per_token_ms(baseline.total_s, baseline.ids.len()), + per_token_ms(spec.total_s, spec.ids.len()), + ); + + totals.prompts += 1; + totals.baseline_generated += baseline.ids.len(); + totals.spec_generated += spec.ids.len(); + totals.baseline_total_s += baseline.total_s; + totals.baseline_prefill_s += baseline.prefill_s; + totals.baseline_decode_s += baseline.decode_s; + totals.spec_total_s += spec.total_s; + totals.spec_prefill_s += spec.prefill_s; + totals.spec_decode_s += spec.decode_s; + totals.spec_target_steps += spec.target_steps; + totals.spec_accepted += spec.accepted; + totals.spec_proposed += spec.proposed; + totals.spec_verify_steps += spec.verify_steps; + totals.spec_mirror_steps += spec.mirror_steps; + totals.spec_commit_steps += spec.commit_steps; + totals.spec_correction_steps += spec.correction_steps; + totals.spec_verify_decode_mismatches += spec.verify_decode_mismatches; + } + + let baseline_decode_tokens = totals.baseline_generated; + let spec_decode_tokens = totals.spec_generated; + let acceptance = ratio(totals.spec_accepted, totals.spec_proposed); + let tokens_per_target_step = ratio(totals.spec_generated, totals.spec_target_steps); + let matched = + totals.mismatches == 0 && (!use_verify_logits || totals.spec_verify_decode_mismatches == 0); + + println!("--- SUMMARY ---"); + println!("prompts={} matched={matched}", totals.prompts); + println!( + "acceptance_mode={}", + if use_verify_logits { + "verify_logits" + } else { + "decode" + } + ); + println!("verify_path={}", verify_path.as_str()); + println!( + "acceptance_rate={:.4} accepted={} proposed={}", + acceptance, totals.spec_accepted, totals.spec_proposed + ); + println!( + "tokens_per_target_step={:.4} target_steps={} verify_steps={} mirror_decode_steps={} commit_decode_steps={} correction_steps={}", + tokens_per_target_step, + totals.spec_target_steps, + totals.spec_verify_steps, + totals.spec_mirror_steps, + totals.spec_commit_steps, + 
totals.spec_correction_steps + ); + println!( + "verify_decode_mismatches={}", + totals.spec_verify_decode_mismatches + ); + println!( + "baseline_e2e_tpot_ms={:.3} baseline_e2e_tok_s={:.3}", + per_token_ms(totals.baseline_total_s, totals.baseline_generated), + tok_s(totals.baseline_generated, totals.baseline_total_s) + ); + println!( + "spec_e2e_tpot_ms={:.3} spec_e2e_tok_s={:.3} speedup_e2e={:.4}", + per_token_ms(totals.spec_total_s, totals.spec_generated), + tok_s(totals.spec_generated, totals.spec_total_s), + speedup(totals.baseline_total_s, totals.spec_total_s) + ); + println!( + "baseline_decode_tpot_ms={:.3} baseline_decode_tok_s={:.3}", + per_token_ms(totals.baseline_decode_s, baseline_decode_tokens), + tok_s(baseline_decode_tokens, totals.baseline_decode_s) + ); + println!( + "spec_decode_tpot_ms={:.3} spec_decode_tok_s={:.3} speedup_decode={:.4}", + per_token_ms(totals.spec_decode_s, spec_decode_tokens), + tok_s(spec_decode_tokens, totals.spec_decode_s), + speedup(totals.baseline_decode_s, totals.spec_decode_s) + ); + println!( + "decode_token_counts baseline={} spec={}", + baseline_decode_tokens, spec_decode_tokens + ); + + if !matched { + std::process::exit(2); + } +} + +fn run_baseline( + model: &Qwen3, + cache: &mut PagedKVCache, + tokenizer: &Tokenizer, + prompt_ids: &[u32], + gen_tokens: usize, +) -> RunStats { + let slot = 0; + cache.register_sequence(slot).expect("register target slot"); + + let t0 = Instant::now(); + let prefill_start = Instant::now(); + let logits = model.forward_prefill_paged(prompt_ids, slot, cache); + sync_device(); + let prefill_s = prefill_start.elapsed().as_secs_f64(); + + let mut generated = Vec::with_capacity(gen_tokens); + let mut next = last_argmax(&logits); + generated.push(next); + + let decode_start = Instant::now(); + let mut target_steps = 0usize; + while generated.len() < gen_tokens && !tokenizer.is_eos(next) { + let pos = cache.seq_len(slot); + let logits = model.forward_decode_paged(&[next], &[pos], &[slot], 
cache); + target_steps += 1; + next = last_argmax(&logits); + generated.push(next); + } + sync_device(); + let decode_s = decode_start.elapsed().as_secs_f64(); + sync_device(); + let total_s = t0.elapsed().as_secs_f64(); + + cache.free_sequence(slot); + RunStats { + ids: generated, + total_s, + prefill_s, + decode_s, + target_steps, + ..Default::default() + } +} + +#[allow(clippy::too_many_arguments)] +fn run_speculative( + target: &Qwen3, + draft: &Qwen3, + target_cache: &mut PagedKVCache, + target_verify_cache: &mut PagedKVCache, + draft_cache: &mut PagedKVCache, + tokenizer: &Tokenizer, + prompt_ids: &[u32], + gen_tokens: usize, + gamma: usize, + use_verify_logits: bool, + verify_path: VerifyPath, + dump_verify_mismatches: bool, +) -> RunStats { + let slot = 0; + target_cache + .register_sequence(slot) + .expect("register target slot"); + target_verify_cache + .register_sequence(slot) + .expect("register target verify slot"); + draft_cache + .register_sequence(slot) + .expect("register draft slot"); + + let t0 = Instant::now(); + let prefill_start = Instant::now(); + let target_logits = target.forward_prefill_paged(prompt_ids, slot, target_cache); + if !use_verify_logits { + let _ = target.forward_prefill_paged(prompt_ids, slot, target_verify_cache); + } + let draft_logits = draft.forward_prefill_paged(prompt_ids, slot, draft_cache); + sync_device(); + let prefill_s = prefill_start.elapsed().as_secs_f64(); + + let mut target_next = last_argmax(&target_logits); + let mut draft_next = last_argmax(&draft_logits); + let mut generated = Vec::with_capacity(gen_tokens); + let mut accepted_total = 0usize; + let mut proposed_total = 0usize; + let mut verify_steps = 0usize; + let mut mirror_steps = 0usize; + let mut commit_steps = 0usize; + let mut correction_steps = 0usize; + let mut verify_decode_mismatches = 0usize; + + let decode_start = Instant::now(); + while generated.len() < gen_tokens { + let remaining = gen_tokens - generated.len(); + let round_gamma = 
gamma.min(remaining); + let round_start_len = target_cache.seq_len(slot); + assert_eq!( + round_start_len, + draft_cache.seq_len(slot), + "target and draft cache lengths diverged" + ); + if !use_verify_logits { + assert_eq!( + round_start_len, + target_verify_cache.seq_len(slot), + "target verify cache length diverged" + ); + } + + let mut draft_tokens = Vec::with_capacity(round_gamma); + for _ in 0..round_gamma { + let token = draft_next; + draft_tokens.push(token); + if tokenizer.is_eos(token) { + break; + } + let pos = draft_cache.seq_len(slot); + let logits = draft.forward_decode_paged(&[token], &[pos], &[slot], draft_cache); + draft_next = last_argmax(&logits); + } + proposed_total += draft_tokens.len(); + + if use_verify_logits { + verify_steps += 1; + let verify_logits = + target.forward_verify_paged_decode_attention(&draft_tokens, slot, target_cache); + let verify_argmax = argmax_rows(&verify_logits); + assert_eq!( + verify_argmax.len(), + draft_tokens.len(), + "verify logits rows must match draft token count" + ); + + let mut accepted = 0usize; + let mut done = false; + while accepted < draft_tokens.len() { + let expected = if accepted > 0 { + verify_argmax[accepted - 1] + } else { + target_next + }; + if draft_tokens[accepted] != expected { + break; + } + let token = draft_tokens[accepted]; + generated.push(token); + accepted_total += 1; + accepted += 1; + + if generated.len() >= gen_tokens || tokenizer.is_eos(token) { + done = true; + break; + } + } + + if accepted > 0 { + target_next = verify_argmax[accepted - 1]; + } + target_cache + .truncate_sequence(slot, round_start_len + accepted) + .unwrap(); + + if done { + draft_cache + .truncate_sequence(slot, target_cache.seq_len(slot)) + .unwrap(); + break; + } + + if accepted == draft_tokens.len() { + continue; + } + + let correction = if accepted > 0 { + verify_argmax[accepted - 1] + } else { + target_next + }; + generated.push(correction); + + draft_cache + .truncate_sequence(slot, round_start_len) + 
.unwrap(); + replay_draft_tokens( + draft, + draft_cache, + slot, + &draft_tokens[..accepted], + &mut draft_next, + ); + + if generated.len() >= gen_tokens || tokenizer.is_eos(correction) { + break; + } + + let pos = target_cache.seq_len(slot); + let logits = target.forward_decode_paged(&[correction], &[pos], &[slot], target_cache); + target_next = last_argmax(&logits); + commit_steps += 1; + + let pos = draft_cache.seq_len(slot); + let logits = draft.forward_decode_paged(&[correction], &[pos], &[slot], draft_cache); + draft_next = last_argmax(&logits); + correction_steps += 1; + continue; + } + + verify_steps += 1; + let verify_logits = match verify_path { + VerifyPath::Flash => { + target.forward_prefill_paged(&draft_tokens, slot, target_verify_cache) + } + VerifyPath::PagedDecode => target.forward_verify_paged_decode_attention( + &draft_tokens, + slot, + target_verify_cache, + ), + }; + let verify_argmax = argmax_rows(&verify_logits); + assert_eq!( + verify_argmax.len(), + draft_tokens.len(), + "verify logits rows must match draft token count" + ); + + target_verify_cache + .truncate_sequence(slot, round_start_len) + .unwrap(); + + let mut accepted = 0usize; + let mut done = false; + while accepted < draft_tokens.len() { + let expected = if use_verify_logits && accepted > 0 { + verify_argmax[accepted - 1] + } else { + target_next + }; + if draft_tokens[accepted] != expected { + break; + } + let token_idx = accepted; + let token = draft_tokens[token_idx]; + generated.push(token); + accepted_total += 1; + accepted += 1; + + if generated.len() >= gen_tokens || tokenizer.is_eos(token) { + done = true; + break; + } + + let pos = target_cache.seq_len(slot); + let logits = target.forward_decode_paged(&[token], &[pos], &[slot], target_cache); + let decode_next = last_argmax(&logits); + if verify_argmax[token_idx] != decode_next { + verify_decode_mismatches += 1; + eprintln!( + "VERIFY/DECODE MISMATCH at cache_len={} accepted_idx={}: verify={} decode={}", + 
target_cache.seq_len(slot), + token_idx, + verify_argmax[token_idx], + decode_next + ); + if dump_verify_mismatches { + eprintln!( + " verify_top5={} decode_top5={}", + format_topk(&verify_logits, token_idx, 5), + format_topk(&logits, 0, 5) + ); + } + } + target_next = decode_next; + commit_steps += 1; + + advance_target_cache(target, target_verify_cache, slot, token); + mirror_steps += 1; + } + if done { + draft_cache + .truncate_sequence(slot, target_cache.seq_len(slot)) + .unwrap(); + target_verify_cache + .truncate_sequence(slot, target_cache.seq_len(slot)) + .unwrap(); + break; + } + + if accepted == draft_tokens.len() { + continue; + } + + let correction = if use_verify_logits && accepted > 0 { + verify_argmax[accepted - 1] + } else { + target_next + }; + generated.push(correction); + + draft_cache + .truncate_sequence(slot, round_start_len) + .unwrap(); + replay_draft_tokens( + draft, + draft_cache, + slot, + &draft_tokens[..accepted], + &mut draft_next, + ); + + if generated.len() >= gen_tokens || tokenizer.is_eos(correction) { + break; + } + + let pos = target_cache.seq_len(slot); + let logits = target.forward_decode_paged(&[correction], &[pos], &[slot], target_cache); + target_next = last_argmax(&logits); + commit_steps += 1; + + advance_target_cache(target, target_verify_cache, slot, correction); + mirror_steps += 1; + + let pos = draft_cache.seq_len(slot); + let logits = draft.forward_decode_paged(&[correction], &[pos], &[slot], draft_cache); + draft_next = last_argmax(&logits); + correction_steps += 1; + } + sync_device(); + let decode_s = decode_start.elapsed().as_secs_f64(); + sync_device(); + let total_s = t0.elapsed().as_secs_f64(); + + target_cache.free_sequence(slot); + target_verify_cache.free_sequence(slot); + draft_cache.free_sequence(slot); + + RunStats { + ids: generated, + total_s, + prefill_s, + decode_s, + target_steps: verify_steps + mirror_steps + commit_steps + correction_steps, + accepted: accepted_total, + proposed: proposed_total, + 
verify_steps, + mirror_steps, + commit_steps, + correction_steps, + verify_decode_mismatches, + } +} + +fn advance_target_cache(target: &Qwen3, cache: &mut PagedKVCache, slot: usize, token: u32) { + let pos = cache.seq_len(slot); + let _ = target.forward_decode_paged(&[token], &[pos], &[slot], cache); +} + +fn replay_draft_tokens( + draft: &Qwen3, + cache: &mut PagedKVCache, + slot: usize, + tokens: &[u32], + next: &mut u32, +) { + for &token in tokens { + let pos = cache.seq_len(slot); + let logits = draft.forward_decode_paged(&[token], &[pos], &[slot], cache); + *next = last_argmax(&logits); + } +} + +fn new_cache(config: &ModelConfig, max_seq_len: usize, device: u32) -> PagedKVCache { + new_cache_with_rows(config, max_seq_len, device, 1) +} + +fn new_cache_with_rows( + config: &ModelConfig, + max_seq_len: usize, + device: u32, + max_rows: usize, +) -> PagedKVCache { + let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE); + let total_blocks = max_blocks_per_seq + 8; + PagedKVCache::new( + config, + total_blocks, + 0, + max_rows.max(1), + max_blocks_per_seq, + DType::BF16, + device, + ) +} + +fn argmax_rows(logits: &Tensor) -> Vec<u32> { + assert_eq!(logits.ndim(), 2); + if logits.dtype() == DType::BF16 + && matches!(logits.device(), Device::Cuda(_)) + && logits.is_contiguous() + { + return xserv_kernels::argmax_bf16_to_host(logits); + } + + let vocab_size = logits.shape()[1]; + let rows = logits.shape()[0]; + let logits_cpu = logits.to_device(Device::Cpu); + match logits.dtype() { + DType::F32 => logits_cpu + .as_slice::<f32>() + .chunks_exact(vocab_size) + .take(rows) + .map(argmax_f32) + .collect(), + DType::BF16 => logits_cpu + .as_slice::<bf16>() + .chunks_exact(vocab_size) + .take(rows) + .map(|row| { + row.iter() + .enumerate() + .max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap()) + .map(|(i, _)| i as u32) + .unwrap() + }) + .collect(), + _ => panic!("unsupported dtype for argmax: {:?}", logits.dtype()), + } +} + +fn last_argmax(logits: &Tensor) -> u32 { 
+ *argmax_rows(logits).last().unwrap() +} + +fn argmax_f32(row: &[f32]) -> u32 { + row.iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) + .map(|(i, _)| i as u32) + .unwrap() +} + +fn format_topk(logits: &Tensor, row: usize, k: usize) -> String { + let vals = topk_row(logits, row, k); + vals.iter() + .map(|(id, val)| format!("{id}:{val:.3}")) + .collect::<Vec<_>>() + .join(",") +} + +fn topk_row(logits: &Tensor, row: usize, k: usize) -> Vec<(u32, f32)> { + assert_eq!(logits.ndim(), 2); + let vocab_size = logits.shape()[1]; + assert!(row < logits.shape()[0], "topk row out of bounds"); + let logits_cpu = logits.to_device(Device::Cpu); + let mut vals: Vec<(u32, f32)> = match logits.dtype() { + DType::F32 => logits_cpu.as_slice::<f32>()[row * vocab_size..(row + 1) * vocab_size] + .iter() + .enumerate() + .map(|(i, &v)| (i as u32, v)) + .collect(), + DType::BF16 => logits_cpu.as_slice::<bf16>()[row * vocab_size..(row + 1) * vocab_size] + .iter() + .enumerate() + .map(|(i, &v)| (i as u32, v.to_f32())) + .collect(), + _ => panic!("unsupported dtype for topk: {:?}", logits.dtype()), + }; + vals.select_nth_unstable_by(k, |a, b| b.1.partial_cmp(&a.1).unwrap()); + vals.truncate(k); + vals.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + vals +} + +fn assert_qwen3(config: &ModelConfig, name: &str) { + let model_type = config.model_type.as_deref().unwrap_or("unknown"); + assert!( + model_type.contains("qwen"), + "{name} model_type must be qwen-like, got {model_type}" + ); +} + +fn warn_if_tokenizers_differ(target_dir: &Path, draft_dir: &Path) { + let target = std::fs::read(target_dir.join("tokenizer.json")); + let draft = std::fs::read(draft_dir.join("tokenizer.json")); + if let (Ok(target), Ok(draft)) = (target, draft) { + if target != draft { + eprintln!( + "WARNING: target and draft tokenizer.json differ; v0 assumes identical token ids" + ); + } + } +} + +fn arg_usize(args: &[String], flag: &str, default: usize) -> usize { + args.iter() + .position(|a| a == flag) 
+ .and_then(|i| args.get(i + 1)) + .and_then(|s| s.parse().ok()) + .unwrap_or(default) +} + +fn parse_verify_path(args: &[String], use_verify_logits: bool) -> VerifyPath { + let default = if use_verify_logits { + VerifyPath::PagedDecode + } else { + VerifyPath::Flash + }; + let Some(value) = args + .iter() + .position(|a| a == "--verify-path") + .and_then(|i| args.get(i + 1)) + else { + return default; + }; + match value.as_str() { + "flash" => VerifyPath::Flash, + "paged-decode" => VerifyPath::PagedDecode, + _ => { + eprintln!("unknown --verify-path {value:?}; expected flash or paged-decode"); + std::process::exit(1); + } + } +} + +fn validate_length_budget(prompt_ids: &[u32], gen_tokens: usize, max_seq_len: usize, prompt: &str) { + let required = prompt_ids.len() + gen_tokens; + if required > max_seq_len { + eprintln!( + "prompt requires prompt_len({}) + gen_tokens({}) = {} tokens, exceeding --max-seq-len {}: {:?}", + prompt_ids.len(), + gen_tokens, + required, + max_seq_len, + prompt + ); + std::process::exit(2); + } +} + +fn sync_device() { + xserv_cuda::device::synchronize().expect("cuda device synchronize"); +} + +fn ratio(num: usize, den: usize) -> f64 { + if den == 0 { + 0.0 + } else { + num as f64 / den as f64 + } +} + +fn speedup(baseline_s: f64, spec_s: f64) -> f64 { + if spec_s == 0.0 { + 0.0 + } else { + baseline_s / spec_s + } +} + +fn tok_s(tokens: usize, seconds: f64) -> f64 { + if seconds == 0.0 { + 0.0 + } else { + tokens as f64 / seconds + } +} + +fn per_token_ms(seconds: f64, tokens: usize) -> f64 { + if tokens == 0 { + 0.0 + } else { + seconds * 1000.0 / tokens as f64 + } +} diff --git a/crates/xserv-model/src/paged_kv_cache.rs b/crates/xserv-model/src/paged_kv_cache.rs index 50d6256..34f9bac 100644 --- a/crates/xserv-model/src/paged_kv_cache.rs +++ b/crates/xserv-model/src/paged_kv_cache.rs @@ -486,6 +486,35 @@ impl PagedKVCache { state.seq_len += num_tokens; } + /// Roll a registered sequence back to `new_len` tokens. 
+ /// + /// This only changes cache metadata and frees whole physical blocks that are + /// no longer reachable. Bytes inside retained blocks are left untouched; the + /// logical `seq_len` prevents attention from reading them, and later writes + /// to the same positions overwrite them. + pub fn truncate_sequence(&mut self, slot: usize, new_len: usize) -> Result<(), &'static str> { + if slot >= self.max_seqs { + return Err("truncate_sequence: slot out of range"); + } + let state = self.seq_states[slot] + .as_mut() + .ok_or("truncate_sequence: empty slot")?; + if new_len > state.seq_len { + return Err("truncate_sequence: cannot extend"); + } + + let needed_blocks = ((new_len + BLOCK_SIZE - 1) / BLOCK_SIZE).max(1); + while state.block_ids.len() > needed_blocks { + let block = state.block_ids.pop().expect("checked len"); + match state.location { + Location::Gpu => self.allocator.free(block), + Location::Cpu => self.cpu_allocator.free(block), + } + } + state.seq_len = new_len; + Ok(()) + } + /// Refresh the host-side block table + context lens from `seq_states`, /// then upload to GPU. Call once per decode step before the paged kernel. 
pub fn sync_to_gpu(&mut self) { @@ -748,6 +777,71 @@ impl PagedKVCache { } } +#[cfg(test)] +mod tests { + use super::*; + + fn tiny_config() -> ModelConfig { + serde_json::from_value(serde_json::json!({ + "model_type": "qwen3", + "hidden_size": 8, + "intermediate_size": 16, + "num_attention_heads": 1, + "num_key_value_heads": 1, + "num_hidden_layers": 1, + "vocab_size": 32, + "max_position_embeddings": 64 + })) + .unwrap() + } + + #[test] + fn truncate_sequence_frees_whole_blocks_and_keeps_slot_registered() { + if xserv_cuda::device::set_device(0).is_err() { + eprintln!("skipping CUDA-backed PagedKVCache test: device 0 unavailable"); + return; + } + + let config = tiny_config(); + let mut cache = PagedKVCache::new(&config, 5, 0, 1, 4, DType::BF16, 0); + + assert_eq!( + cache.truncate_sequence(1, 0), + Err("truncate_sequence: slot out of range") + ); + assert_eq!( + cache.truncate_sequence(0, 0), + Err("truncate_sequence: empty slot") + ); + + cache.register_sequence(0).unwrap(); + cache.ensure_capacity(0, BLOCK_SIZE * 3 + 1); + cache.advance_seq_len(0, BLOCK_SIZE * 3 + 1); + assert_eq!(cache.seq_len(0), BLOCK_SIZE * 3 + 1); + assert_eq!(cache.block_count(0), 4); + assert_eq!(cache.free_blocks(), 0); + + cache.truncate_sequence(0, BLOCK_SIZE + 1).unwrap(); + assert_eq!(cache.seq_len(0), BLOCK_SIZE + 1); + assert_eq!(cache.block_count(0), 2); + assert_eq!(cache.free_blocks(), 2); + + cache.truncate_sequence(0, BLOCK_SIZE).unwrap(); + assert_eq!(cache.seq_len(0), BLOCK_SIZE); + assert_eq!(cache.block_count(0), 1); + assert_eq!(cache.free_blocks(), 3); + + cache.truncate_sequence(0, 0).unwrap(); + assert_eq!(cache.seq_len(0), 0); + assert_eq!(cache.block_count(0), 1); + assert_eq!(cache.free_blocks(), 3); + assert_eq!( + cache.truncate_sequence(0, 1), + Err("truncate_sequence: cannot extend") + ); + } +} + unsafe fn tensor_from_owned_buf( buf: GpuBuffer, shape: &[usize], diff --git a/crates/xserv-model/src/qwen3.rs b/crates/xserv-model/src/qwen3.rs index 
2eef2a2..376d89c 100644 --- a/crates/xserv-model/src/qwen3.rs +++ b/crates/xserv-model/src/qwen3.rs @@ -884,6 +884,109 @@ impl Qwen3 { matmul_2d(&x, &self.lm_head_t) } + /// Paged multi-token verify path: write `token_ids` into the paged cache, + /// then verify them with the same paged decode attention kernel used by + /// single-token decode. This keeps greedy top-1 behavior aligned with + /// `forward_decode_paged`; the dense projections/MLP run row by row via + /// `matmul_rows_gemv` so they follow decode's BF16 rounding path. + pub fn forward_verify_paged_decode_attention( + &self, + token_ids: &[u32], + slot: usize, + paged_cache: &mut PagedKVCache, + ) -> Tensor { + let new_tokens = token_ids.len(); + let pos_offset = paged_cache.seq_len(slot); + let num_heads = self.local_num_heads; + let num_kv_heads = self.local_num_kv_heads; + let head_dim = self.config.head_dim(); + let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32; + + paged_cache.ensure_capacity(slot, pos_offset + new_tokens); + paged_cache.advance_seq_len(slot, new_tokens); + + let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens) + .map(|p| p as u32) + .collect(); + let kv_lens: Vec<i32> = (0..new_tokens) + .map(|i| (pos_offset + i + 1) as i32) + .collect(); + let slots = vec![slot; new_tokens]; + paged_cache.sync_active_batch_with_lens(&slots, &kv_lens); + let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32; + let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32; + let max_blocks = paged_cache.max_blocks_per_seq(); + + let mut x = embedding(&self.embed_tokens, token_ids); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + let residual = x.clone(); + let normed = rmsnorm(&x, &layer.input_norm, eps); + + let qkv = matmul_rows_gemv(&normed, &layer.qkv_proj_wt); + let q_dim = num_heads * head_dim; + let kv_dim = num_kv_heads * head_dim; + let q_all = qkv.narrow(1, 0, q_dim); + let k_all = qkv.narrow(1, q_dim, kv_dim); + let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim); + + let q_flat = 
q_all + .contiguous() + .reshape(&[new_tokens * num_heads, head_dim]); + let k_flat = k_all + .contiguous() + .reshape(&[new_tokens * num_kv_heads, head_dim]); + let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps); + let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps); + + let q_3d = q_normed.reshape(&[new_tokens, num_heads, head_dim]); + let k_3d = k_normed.reshape(&[new_tokens, num_kv_heads, head_dim]); + rope_inplace(&q_3d, &self.rope_cache, &positions); + rope_inplace(&k_3d, &self.rope_cache, &positions); + + let v_3d = v_all + .contiguous() + .reshape(&[new_tokens, num_kv_heads, head_dim]); + paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, new_tokens); + + let q_decode = q_3d.reshape(&[new_tokens, num_heads, 1, head_dim]); + let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void; + let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void; + let attn_out = xserv_kernels::paged_decode_attention( + &q_decode, + k_pool_ptr, + v_pool_ptr, + bt_ptr, + cl_ptr, + new_tokens, + num_heads, + num_kv_heads, + head_dim, + max_blocks, + ); + + let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]); + let attn_proj = matmul_rows_gemv(&attn_merged, &layer.o_proj_wt); + self.all_reduce(&attn_proj); + + let (normed, x_new) = + xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); + let residual = x_new.clone(); + + let gate_up = matmul_rows_gemv(&normed, &layer.gate_up_proj_wt); + let ffn_dim = gate_up.shape()[1] / 2; + let gate = gate_up.narrow(1, 0, ffn_dim).contiguous(); + let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous(); + let hidden_states = xserv_kernels::silu_mul(&gate, &up); + let down = matmul_rows_gemv(&hidden_states, &layer.down_proj_wt); + self.all_reduce(&down); + x = add_any(&residual, &down); + } + + let x = rmsnorm(&x, &self.norm, eps); + matmul_rows_gemv(&x, &self.lm_head_t) + } + /// Forward with GPU-resident KV cache and GPU transpose/reshape 
kernels. pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor { let new_tokens = token_ids.len(); @@ -1158,6 +1261,20 @@ fn row_view(t: &Tensor, row: usize) -> Tensor { ) } +/// Run a 2D matmul row by row so each row uses the same GEMV kernel as +/// single-token decode. Used by speculative verify parity, where near-tie +/// logits must follow decode's BF16 rounding path. +fn matmul_rows_gemv(a: &Tensor, b: &Tensor) -> Tensor { + assert_eq!(a.ndim(), 2); + assert!(a.is_contiguous()); + let rows = a.shape()[0]; + if rows == 1 { + return matmul_2d(a, b); + } + let out_rows: Vec<Tensor> = (0..rows).map(|i| matmul_2d(&row_view(a, i), b)).collect(); + concat_rows(&out_rows) +} + /// Concatenate row tensors [1, cols] into a single [B, cols] tensor via D2D memcpy. fn concat_rows(rows: &[Tensor]) -> Tensor { assert!(!rows.is_empty()); diff --git a/docs/22-speculative-decoding.md b/docs/22-speculative-decoding.md new file mode 100644 index 0000000..ec7186a --- /dev/null +++ b/docs/22-speculative-decoding.md @@ -0,0 +1,186 @@ +# Phase 22: Draft-Model Speculative Decoding v0 + +> Goal: implement a minimal, verifiable speculative decoding loop. For now, cover only +> a Qwen3 target plus a small Qwen3 draft with the same tokenizer, batch=1, greedy +> (`temperature=0`). This phase excludes gpt-oss, sampling rejection, and +> continuous batching integration. + +## 1. Scope + +This phase solves one narrow problem: + +- target: the existing Qwen3 paged KV path, preferably Qwen3-8B; +- draft: a small Qwen3 with the same tokenizer, e.g. Qwen3-0.6B; +- batch size: 1; +- decoding: greedy argmax; +- draft window: `gamma=4`; +- acceptance: exact-match, i.e. `target_argmax == draft_token`. + +An HTTP flag can be wired in later. v0 ships a standalone bench/CLI first because it directly reports token +parity, acceptance rate, tokens/target-step, and TPOT/tok/s, and it keeps the not-yet-stable +rollback behavior out of the server scheduling loop. + +To keep the baseline/spec comparison free of cross-prompt KV pool reuse effects, each prompt's +baseline run and speculative run both use a freshly created paged KV cache. Cache allocation happens +outside the per-run timing, so the reported TPOT/tok/s covers only model prefill/decode work. + +## 2. Why Qwen3 First + +Qwen3 is the model family in the existing code best suited to speculative v0: + +1. 
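the target already has stable `forward_prefill_paged` and `forward_decode_paged` paths; +2. the small Qwen3 shares a tokenizer with Qwen3-8B, so token ids can be compared directly; +3. Qwen3 is dense decoder-only, with none of gpt-oss's harmony format, MoE sparse path, + sliding-window, or CUDA Graph state; +4. correctness of greedy output is simple to define: the token sequence generated by spec must match pure target greedy + exactly. + +gpt-oss spec would first need to define the harmony prompt, MoE draft selection, and the interaction between graph replay and rollback, +none of which belongs to this phase. + +## 3. Algorithm + +For each prompt, set up two models and three KV states: + +```text +target model + target commit PagedKVCache +target model + target verify PagedKVCache +draft model + draft PagedKVCache +``` + +First prefill the prompt into each of the three caches. At that point every cache contains the prompt and each holds +the logits for the "next token". + +Each speculative round: + +1. the draft takes the argmax of the current draft logits and generates `gamma` draft tokens in a row; +2. each time the draft generates a token it appends to the draft KV with its own paged decode, so when the round + ends the draft cache temporarily contains the whole draft sequence; +3. the target verify cache runs one paged prefill over the full draft token sequence, exercising the + "target can verify the draft window in one pass" execution path; +4. the target verify cache immediately rolls back to the start of the round, so temporary prefill writes do not pollute + the commit cache; +5. using the target decode trajectory as the authoritative result, compare + `target_next_argmax == draft_token` from left to right and accept only the contiguous matching prefix; +6. for each accepted token, replay it once through target decode to commit the target KV and obtain the next + `target_next_argmax`; the verify cache also mirror-decodes the same token to keep its length and prefix aligned; +7. if everything matches, the draft cache already contains the full draft and the three cache lengths are realigned; +8. 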
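if rejection happens at token `k`, commit the first `k` draft tokens, then commit the target's + argmax at that position as the correction token. After the draft cache rolls back to the start of the round it replays the accepted tokens and the correction + token; the target commit/verify caches are both committed to the same prefix through the decode path. + +v0 does not use the probability correction of full speculative sampling. It only uses the small model to guess the greedy trajectory, +so the generated sequence must match pure target greedy exactly. + +The current implementation uses the decode trajectory as the commit path instead of directly keeping the KV written by target prefill. +The reason is that v0 acceptance requires the token sequence to match pure target greedy exactly; if the prefill and decode +paths differ slightly in numerics or KV write order, committing prefill KV directly would make later greedy output +drift. This conservative implementation still runs target paged prefill verification and rollback, but verify writes go to +a separate cache and do not affect the authoritative commit cache. The cost is the extra mirror decode, so little speed gain is expected; +it mainly serves to validate the draft-model speculative state machine and parity first. + +To guarantee greedy exactness, two existing sources of nondeterminism in decode also have to be made deterministic: + +- BF16 GEMV no longer uses cross-K-block `atomicAdd`; it writes K-block partials and then reduces them in a fixed + order; +- paged decode attention no longer merges warp outputs with `atomicAdd`; it writes per-warp partials + and reduces them in warp id order. + +## 4. KV Commit And Rollback + +The existing `forward_prefill_paged` writes all incoming tokens into the paged KV at once and advances +`seq_len` up front. While verifying a draft, the target verify cache therefore temporarily contains the whole draft window. + +The new cache operation performs only a logical truncation: + +```text +truncate_sequence(slot, new_len) +``` + +Constraints: + +- only `new_len <= current_len` is allowed; +- keep the physical blocks needed to cover `[0, new_len)`; +- free the surplus blocks on the right; +- do not zero stale bytes inside retained blocks, because the logical length stops attention from reading them, + and a later write to the same position overwrites the old value; +- the slot stays registered, and the first block is kept even when `new_len=0`. + +This lets both target and draft safely discard the extra KV on rejection and realign +after the correction token is decoded. + +## 5. Acceptance Criteria + +Acceptance for this phase: + +- `cargo fmt`; +- `cargo check`; +- `cargo test`; +- `bench-speculative` can load both Qwen3 models, target+draft; +- 50 prompts, greedy: the baseline target and speculative token id sequences match exactly; +- report acceptance rate, tokens/target-step, TPOT, tok/s, and speedup; +- if the draft model is missing or disk space is insufficient, report the blocker clearly instead of blindly downloading a large model. + +## 6. Validation Results + +dash5 environment: + +- GPU: RTX 5090, device 0; +- target: `/opt/wjh/models/qwen3-8b`; +- draft: `/dashscope-tmp/wjh/models/qwen3-0.6b`; +- command: `bench-speculative ... 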
--prompts 50 --gen-tokens 32 --gamma 4 --device 0`; +- log: `/dashscope-tmp/wjh/xserv-spec-default-50x32-final.log`. + +Results with the default `acceptance_mode=decode`: + +```text +prompts=50 matched=true +acceptance_rate=0.3664 accepted=1020 proposed=2784 +tokens_per_target_step=0.3639 target_steps=4397 +verify_steps=729 mirror_decode_steps=1550 commit_decode_steps=1550 correction_steps=568 +verify_decode_mismatches=10 +baseline_e2e_tpot_ms=13.123 baseline_e2e_tok_s=76.204 +spec_e2e_tpot_ms=44.867 spec_e2e_tok_s=22.288 speedup_e2e=0.2925 +baseline_decode_tpot_ms=12.638 baseline_decode_tok_s=79.127 +spec_decode_tpot_ms=43.731 spec_decode_tok_s=22.867 speedup_decode=0.2890 +decode_token_counts baseline=1600 spec=1600 +``` + +Results with the diagnostic `--use-verify-logits`: + +- command: `bench-speculative ... --prompts 10 --gen-tokens 32 --gamma 4 --device 0 --use-verify-logits`; +- log: `/dashscope-tmp/wjh/xserv-spec-verify-logits-10x32.log`; +- exit status: `2`; +- summary: `matched=false`, `verify_decode_mismatches=4`; +- prompts 0/2/7 show the baseline/spec token sequences diverging. + +Conclusion: a correctness-first speculative decoding state machine is feasible today, but +target batched prefill verify logits cannot yet serve as the greedy acceptance signal. The verify prefill path and +the per-token decode path disagree on top-1 in places; the default mode must keep treating the decode trajectory as authoritative, +so v0 is a correctness loop, not a performance optimization. + +## 7. Known Limits + +- only batch=1 is supported; +- only Qwen3-family dense models are supported; +- only greedy exact-match acceptance is supported; +- probabilistic rejection sampling is not implemented, so temperature/top-k/top-p are unsupported; +- not wired into HTTP/continuous batching; +- not combined with CUDA Graph decode; +- to guarantee greedy exactness, the current v0 also replays accepted tokens through target decode to commit them, so + it can be slower even when acceptance is high; +- draft prefill and target prefill both count toward end-to-end time, so short outputs may see no gain. + +## 8. Next Phase TODO + +If speculative decoding continues, the next phase should not start with HTTP; it should fix the verify path first: + +1. build a minimal prefill-vs-decode parity harness: fix the prompt, cache len, and draft tokens, + dump per-layer/final logits top-k, and locate whether the top-1 divergence comes from attention, GEMV, or KV write order; +2. make `--use-verify-logits` reach `matched=true` and + `verify_decode_mismatches=0` on at least 50 prompts x 64 tokens; +3. 
once parity holds, do a real multi-token target commit: either safely keep the KV written by verify prefill, + or implement a dedicated paged multi-token verify/commit kernel, avoiding the current mirror+commit + decode replay; +4. only after `speedup_e2e > 1` consider an HTTP flag, continuous batching, sampling, or + gpt-oss speculative decoding. diff --git a/docs/23-speculative-verify-parity.md b/docs/23-speculative-verify-parity.md new file mode 100644 index 0000000..e004f98 --- /dev/null +++ b/docs/23-speculative-verify-parity.md @@ -0,0 +1,85 @@ +# Phase 23: Speculative Verify Parity + +> Goal: move speculative decoding from v0's correctness-only state machine to +> "verify logits can serve as the authoritative acceptance signal". This phase still covers only a Qwen3 target plus +> a Qwen3 small draft, batch=1, greedy. + +## 1. Problem + +Phase 22's default mode uses per-token target decode as the authoritative path, so its output matches +the baseline. The diagnostic `--use-verify-logits` fails, though: when the target runs batched +prefill verify over the draft window, the top-1 of some logits disagrees with per-token decode. + +Measured top-k shows the divergence is not a large numerical error but a BF16 near-tie: + +```text +verify_top5=17689:24.500,9856:24.375,... +decode_top5=9856:24.500,17689:24.500,... +``` + +If these verify logits were used directly to accept/reject draft tokens, the greedy token sequence would diverge from pure +target decode. + +## 2. Design + +Add `Qwen3::forward_verify_paged_decode_attention`: + +1. write the draft window's K/V into the target commit cache in one pass; +2. attention uses the existing paged decode attention, with one metadata row per draft token and + a context len of `pos + 1` for each; +3. linear layers use row-wise GEMV to align with the BF16 rounding path of `forward_decode_paged`; +4. if all tokens are accepted, keep the KV written by verify as is; +5. if rejection happens at token `k`, truncate the target cache to the accepted prefix, then + decode only one correction token. + +New bench flags: + +- `--use-verify-logits`: use verify logits as the acceptance signal; defaults to the `paged-decode` + verify path; +- `--verify-path flash|paged-decode`: explicitly select the old flash prefill diagnostic or the new paged-decode + verify path; +- `--dump-verify-mismatches`: print top-k for mismatching rows, to help locate near-ties. + +## 3. Validation + +dash5: + +- GPU: RTX 5090, device 0; +- target: `/opt/wjh/models/qwen3-8b`; +- draft: `/dashscope-tmp/wjh/models/qwen3-0.6b`; +- command: `bench-speculative ... 
--prompts 50 --gen-tokens 64 --gamma 4 --device 0 --use-verify-logits`; +- log: `/dashscope-tmp/wjh/xserv-spec-inplace-verify-50x64.log`. + +Results: + +```text +prompts=50 matched=true +acceptance_mode=verify_logits +verify_path=paged-decode +acceptance_rate=0.3927 accepted=2120 proposed=5398 +tokens_per_target_step=0.9112 target_steps=3512 +verify_steps=1376 mirror_decode_steps=0 commit_decode_steps=1068 correction_steps=1068 +verify_decode_mismatches=0 +baseline_e2e_tpot_ms=13.094 baseline_e2e_tok_s=76.372 +spec_e2e_tpot_ms=30.069 spec_e2e_tok_s=33.257 speedup_e2e=0.4355 +baseline_decode_tpot_ms=12.846 baseline_decode_tok_s=77.844 +spec_decode_tpot_ms=29.731 spec_decode_tok_s=33.635 speedup_decode=0.4321 +decode_token_counts baseline=3200 spec=3200 +``` + +Compared with Phase 22's conservative decode-authoritative v0: + +- verify logits can now serve as the authoritative acceptance signal; +- `mirror_decode_steps` drops from one per accepted token to 0; +- 50x64 e2e speedup rises from about 0.29x to 0.44x; +- still below baseline, because the verify path uses row-wise GEMV for parity and draft acceptance + is only about 39%. + +## 4. Next TODO + +The next phase should shift from correctness parity to performance: + +1. replace row-GEMV with batched GEMM layer by layer, while keeping a near-tie fallback or top-k audit; +2. add a low-frequency sampled audit, `--verify-audit-decode`, to avoid running target decode every round; +3. sweep `gamma` and draft choices, recording acceptance and TPOT curves; +4. do not wire in HTTP/continuous batching/gpt-oss spec until `speedup_e2e > 1`.
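
For readers of the patch, the two rules the docs lean on most can be sketched in a few lines of Rust: greedy exact-match acceptance (accept the longest matching prefix, then take the target's token as the correction) and the block count that `truncate_sequence` keeps. This is an illustrative sketch only, not code from the patch. `RoundOutcome`, `accept_greedy`, and `blocks_after_truncate` are hypothetical names, the block size is passed in as a parameter because the real `BLOCK_SIZE` value is not shown here, and returning no correction token when every draft token matches is an assumption.

```rust
/// Result of checking one draft window against the target's greedy choices.
/// (Hypothetical type for illustration; not part of the patch.)
#[derive(Debug, PartialEq, Eq)]
struct RoundOutcome {
    /// Number of leading draft tokens that matched the target argmax.
    accepted: usize,
    /// Target argmax at the first mismatch; `None` when every draft token matched.
    correction: Option<u32>,
}

/// `target_argmax[i]` is the target's greedy token at draft position `i`,
/// i.e. given the committed prefix plus `draft[..i]`.
fn accept_greedy(draft: &[u32], target_argmax: &[u32]) -> RoundOutcome {
    assert_eq!(draft.len(), target_argmax.len());
    // Accept only the contiguous matching prefix, scanning left to right.
    let accepted = draft
        .iter()
        .zip(target_argmax)
        .take_while(|(d, t)| d == t)
        .count();
    RoundOutcome {
        accepted,
        correction: target_argmax.get(accepted).copied(),
    }
}

/// Blocks kept after truncating to `new_len`, mirroring the
/// `ceil(new_len / BLOCK_SIZE).max(1)` expression in `truncate_sequence`.
/// `block_size` is a parameter here; the real constant is an assumption.
fn blocks_after_truncate(new_len: usize, block_size: usize) -> usize {
    ((new_len + block_size - 1) / block_size).max(1)
}
```

With a draft of `[1, 2, 3, 4]` and target choices `[1, 2, 9, 4]`, two tokens are accepted and `9` becomes the correction token, so the target cache would be truncated to the round start plus two before the correction decode. The block helper reproduces the counts asserted in the patch's unit test (4, 2, 1, 1) for any positive block size.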