speculative: Qwen3 draft-model v0 with paged verify parity
Phase 22 lands a correctness-only speculative decoding loop for Qwen3 target + Qwen3 small draft (batch=1, greedy, gamma=4). Phase 23 turns verify logits into the authoritative acceptance signal so mirror-decode per accepted token is no longer needed. - paged_kv_cache: truncate_sequence(slot, new_len) shrinks a registered sequence, freeing whole physical blocks no longer reachable and leaving the slot registered. Covered by a CUDA-gated unit test. - qwen3: forward_verify_paged_decode_attention writes the draft window into the target cache, runs the same paged decode attention kernel per draft token, and uses matmul_rows_gemv so linear layers follow the single-token decode BF16 rounding path. - bench-speculative: new bench binary drives the state machine with --gamma / --gen-tokens / --prompts / --use-verify-logits / --verify-path flash|paged-decode / --dump-verify-mismatches, and compares baseline vs spec token sequences plus TPOT / tok/s / speedup. - docs/22 records the decode-authoritative v0 result and dash5 numbers (matched=true, speedup_e2e ~0.29x, verify_decode_mismatches>0 under --use-verify-logits). - docs/23 records the paged-decode verify path (matched=true, verify_decode_mismatches=0, 50x64 speedup_e2e ~0.44x) and the next-step performance TODO.
This commit is contained in:
962
crates/xserv-model/src/bin/bench-speculative.rs
Normal file
962
crates/xserv-model/src/bin/bench-speculative.rs
Normal file
@@ -0,0 +1,962 @@
|
||||
//! Draft-model speculative decoding benchmark for Qwen3.
|
||||
//!
|
||||
//! v0 scope:
|
||||
//! - target + draft are Qwen3-family models with the same tokenizer/vocab;
|
||||
//! - batch=1;
|
||||
//! - greedy exact-match acceptance;
|
||||
//! - no probabilistic rejection sampling.
|
||||
|
||||
use half::bf16;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Instant;
|
||||
|
||||
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
const DEFAULT_GAMMA: usize = 4;
|
||||
const DEFAULT_GEN_TOKENS: usize = 64;
|
||||
const DEFAULT_MAX_SEQ_LEN: usize = 2048;
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
enum VerifyPath {
|
||||
Flash,
|
||||
PagedDecode,
|
||||
}
|
||||
|
||||
impl VerifyPath {
|
||||
fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
VerifyPath::Flash => "flash",
|
||||
VerifyPath::PagedDecode => "paged-decode",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const PROMPTS: [&str; 50] = [
|
||||
"The capital of France is",
|
||||
"Once upon a time in a land far away",
|
||||
"Hello, how are you doing today",
|
||||
"In a shocking finding, scientists discovered a",
|
||||
"The weather today is sunny, so I decided to",
|
||||
"Alan Turing was a British mathematician who",
|
||||
"The best way to learn programming is",
|
||||
"Artificial intelligence will change the world because",
|
||||
"The history of the internet began in the",
|
||||
"A good morning routine starts with",
|
||||
"The stock market crashed because investors",
|
||||
"Deep learning is a subset of machine learning that",
|
||||
"The president of the United States announced",
|
||||
"In the year 2050, humans will",
|
||||
"The secret to happiness is",
|
||||
"When I was a child, I used to",
|
||||
"The most important scientific discovery of the century",
|
||||
"Climate change is caused by",
|
||||
"The recipe for chocolate cake requires",
|
||||
"In conclusion, the evidence suggests that",
|
||||
"The cat sat on the mat and",
|
||||
"According to recent studies, exercise can",
|
||||
"The first step in solving any problem is",
|
||||
"Technology has transformed the way we",
|
||||
"The novel begins with the protagonist",
|
||||
"Education is the most powerful weapon",
|
||||
"The ocean covers more than seventy percent of",
|
||||
"Last night I had a dream about",
|
||||
"The company announced its quarterly earnings",
|
||||
"Music has the power to",
|
||||
"The difference between success and failure is",
|
||||
"In the beginning, there was nothing but",
|
||||
"The doctor told me that I should",
|
||||
"Python is a popular programming language because",
|
||||
"The ancient Romans built roads that",
|
||||
"A balanced diet should include",
|
||||
"The movie received mixed reviews from critics",
|
||||
"Space exploration has led to many",
|
||||
"The teacher asked the students to",
|
||||
"Global warming is one of the most",
|
||||
"The bridge collapsed due to structural",
|
||||
"Quantum computing promises to revolutionize",
|
||||
"The new policy will affect millions of",
|
||||
"During the winter months, it is important to",
|
||||
"The human brain contains approximately",
|
||||
"Democracy depends on the active participation of",
|
||||
"The train arrived at the station exactly",
|
||||
"Researchers at MIT have developed a new",
|
||||
"The smartphone has become an essential part of",
|
||||
"After careful consideration, the committee decided to",
|
||||
];
|
||||
|
||||
#[derive(Default)]
|
||||
struct RunStats {
|
||||
ids: Vec<u32>,
|
||||
total_s: f64,
|
||||
prefill_s: f64,
|
||||
decode_s: f64,
|
||||
target_steps: usize,
|
||||
accepted: usize,
|
||||
proposed: usize,
|
||||
verify_steps: usize,
|
||||
mirror_steps: usize,
|
||||
commit_steps: usize,
|
||||
correction_steps: usize,
|
||||
verify_decode_mismatches: usize,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct Totals {
|
||||
prompts: usize,
|
||||
baseline_generated: usize,
|
||||
spec_generated: usize,
|
||||
baseline_total_s: f64,
|
||||
baseline_prefill_s: f64,
|
||||
baseline_decode_s: f64,
|
||||
spec_total_s: f64,
|
||||
spec_prefill_s: f64,
|
||||
spec_decode_s: f64,
|
||||
spec_target_steps: usize,
|
||||
spec_accepted: usize,
|
||||
spec_proposed: usize,
|
||||
spec_verify_steps: usize,
|
||||
spec_mirror_steps: usize,
|
||||
spec_commit_steps: usize,
|
||||
spec_correction_steps: usize,
|
||||
spec_verify_decode_mismatches: usize,
|
||||
mismatches: usize,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 3 {
|
||||
eprintln!(
|
||||
"Usage: bench-speculative <target-model-dir> <draft-model-dir> \
|
||||
[--gen-tokens N] [--gamma N] [--prompts N] [--max-seq-len N] [--device N] \
|
||||
[--use-verify-logits] [--verify-path flash|paged-decode] [--dump-verify-mismatches]"
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let target_dir = PathBuf::from(&args[1]);
|
||||
let draft_dir = PathBuf::from(&args[2]);
|
||||
let gen_tokens = arg_usize(&args, "--gen-tokens", DEFAULT_GEN_TOKENS);
|
||||
let gamma = arg_usize(&args, "--gamma", DEFAULT_GAMMA);
|
||||
let prompt_count = arg_usize(&args, "--prompts", PROMPTS.len()).min(PROMPTS.len());
|
||||
let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN);
|
||||
let device = arg_usize(&args, "--device", 0) as u32;
|
||||
let use_verify_logits = args.iter().any(|a| a == "--use-verify-logits");
|
||||
let verify_path = parse_verify_path(&args, use_verify_logits);
|
||||
let dump_verify_mismatches = args.iter().any(|a| a == "--dump-verify-mismatches");
|
||||
|
||||
assert!(gen_tokens > 0, "--gen-tokens must be > 0");
|
||||
assert!(gamma > 0, "--gamma must be > 0");
|
||||
|
||||
xserv_cuda::device::set_device(device).unwrap();
|
||||
let info = xserv_cuda::device::device_info(device).unwrap();
|
||||
eprintln!(
|
||||
"GPU {device}: {} ({} MB free)",
|
||||
info.name,
|
||||
info.free_memory / 1024 / 1024
|
||||
);
|
||||
|
||||
let target_config = ModelConfig::from_file(&target_dir.join("config.json"));
|
||||
let draft_config = ModelConfig::from_file(&draft_dir.join("config.json"));
|
||||
assert_qwen3(&target_config, "target");
|
||||
assert_qwen3(&draft_config, "draft");
|
||||
assert_eq!(
|
||||
target_config.vocab_size, draft_config.vocab_size,
|
||||
"target and draft vocab_size must match"
|
||||
);
|
||||
|
||||
warn_if_tokenizers_differ(&target_dir, &draft_dir);
|
||||
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
|
||||
if tokenizer.vocab_size() != target_config.vocab_size {
|
||||
eprintln!(
|
||||
"WARNING: tokenizer decoder len {} differs from config vocab_size {}; continuing because token ids come from the shared tokenizer.json",
|
||||
tokenizer.vocab_size(),
|
||||
target_config.vocab_size
|
||||
);
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"Loading target Qwen3: layers={} hidden={} heads={}/{} vocab={}",
|
||||
target_config.num_layers(),
|
||||
target_config.hidden(),
|
||||
target_config.num_heads(),
|
||||
target_config.num_kv_heads(),
|
||||
target_config.vocab_size
|
||||
);
|
||||
let target_weights = loader::load_model_dir(&target_dir, Device::Cuda(device));
|
||||
let target = Qwen3::from_weights(target_config.clone(), target_weights);
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
|
||||
eprintln!(
|
||||
"Loading draft Qwen3: layers={} hidden={} heads={}/{} vocab={}",
|
||||
draft_config.num_layers(),
|
||||
draft_config.hidden(),
|
||||
draft_config.num_heads(),
|
||||
draft_config.num_kv_heads(),
|
||||
draft_config.vocab_size
|
||||
);
|
||||
let draft_weights = loader::load_model_dir(&draft_dir, Device::Cuda(device));
|
||||
let draft = Qwen3::from_weights(draft_config.clone(), draft_weights);
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
|
||||
let warm_ids = tokenizer.encode("warmup");
|
||||
let warm_tokens = gen_tokens.min(4);
|
||||
{
|
||||
let mut target_cache = new_cache(&target_config, max_seq_len, device);
|
||||
let _ = run_baseline(
|
||||
&target,
|
||||
&mut target_cache,
|
||||
&tokenizer,
|
||||
&warm_ids,
|
||||
warm_tokens,
|
||||
);
|
||||
}
|
||||
{
|
||||
let mut target_cache = new_cache_with_rows(
|
||||
&target_config,
|
||||
max_seq_len,
|
||||
device,
|
||||
if use_verify_logits { gamma } else { 1 },
|
||||
);
|
||||
let mut target_verify_cache =
|
||||
new_cache_with_rows(&target_config, max_seq_len, device, gamma);
|
||||
let mut draft_cache = new_cache(&draft_config, max_seq_len, device);
|
||||
let _ = run_speculative(
|
||||
&target,
|
||||
&draft,
|
||||
&mut target_cache,
|
||||
&mut target_verify_cache,
|
||||
&mut draft_cache,
|
||||
&tokenizer,
|
||||
&warm_ids,
|
||||
warm_tokens,
|
||||
gamma,
|
||||
use_verify_logits,
|
||||
verify_path,
|
||||
dump_verify_mismatches,
|
||||
);
|
||||
}
|
||||
eprintln!(
|
||||
"Warmup done. Running {prompt_count} prompts, gen_tokens={gen_tokens}, gamma={gamma}, acceptance_mode={}, verify_path={}",
|
||||
if use_verify_logits {
|
||||
"verify_logits"
|
||||
} else {
|
||||
"decode"
|
||||
},
|
||||
verify_path.as_str()
|
||||
);
|
||||
|
||||
let mut totals = Totals::default();
|
||||
for (i, prompt) in PROMPTS.iter().take(prompt_count).enumerate() {
|
||||
let ids = tokenizer.encode(prompt);
|
||||
validate_length_budget(&ids, gen_tokens, max_seq_len, prompt);
|
||||
let mut baseline_cache = new_cache(&target_config, max_seq_len, device);
|
||||
let baseline = run_baseline(&target, &mut baseline_cache, &tokenizer, &ids, gen_tokens);
|
||||
drop(baseline_cache);
|
||||
|
||||
let mut target_cache = new_cache_with_rows(
|
||||
&target_config,
|
||||
max_seq_len,
|
||||
device,
|
||||
if use_verify_logits { gamma } else { 1 },
|
||||
);
|
||||
let mut target_verify_cache =
|
||||
new_cache_with_rows(&target_config, max_seq_len, device, gamma);
|
||||
let mut draft_cache = new_cache(&draft_config, max_seq_len, device);
|
||||
let spec = run_speculative(
|
||||
&target,
|
||||
&draft,
|
||||
&mut target_cache,
|
||||
&mut target_verify_cache,
|
||||
&mut draft_cache,
|
||||
&tokenizer,
|
||||
&ids,
|
||||
gen_tokens,
|
||||
gamma,
|
||||
use_verify_logits,
|
||||
verify_path,
|
||||
dump_verify_mismatches,
|
||||
);
|
||||
|
||||
let matched = baseline.ids == spec.ids;
|
||||
if !matched {
|
||||
totals.mismatches += 1;
|
||||
eprintln!("MISMATCH prompt {i}: {prompt}");
|
||||
eprintln!(" baseline: {:?}", baseline.ids);
|
||||
eprintln!(" spec: {:?}", spec.ids);
|
||||
}
|
||||
|
||||
println!(
|
||||
"prompt={:02} match={} gen={} accept={}/{} target_steps={} \
|
||||
baseline_e2e_tpot_ms={:.3} spec_e2e_tpot_ms={:.3}",
|
||||
i,
|
||||
matched,
|
||||
spec.ids.len(),
|
||||
spec.accepted,
|
||||
spec.proposed,
|
||||
spec.target_steps,
|
||||
per_token_ms(baseline.total_s, baseline.ids.len()),
|
||||
per_token_ms(spec.total_s, spec.ids.len()),
|
||||
);
|
||||
|
||||
totals.prompts += 1;
|
||||
totals.baseline_generated += baseline.ids.len();
|
||||
totals.spec_generated += spec.ids.len();
|
||||
totals.baseline_total_s += baseline.total_s;
|
||||
totals.baseline_prefill_s += baseline.prefill_s;
|
||||
totals.baseline_decode_s += baseline.decode_s;
|
||||
totals.spec_total_s += spec.total_s;
|
||||
totals.spec_prefill_s += spec.prefill_s;
|
||||
totals.spec_decode_s += spec.decode_s;
|
||||
totals.spec_target_steps += spec.target_steps;
|
||||
totals.spec_accepted += spec.accepted;
|
||||
totals.spec_proposed += spec.proposed;
|
||||
totals.spec_verify_steps += spec.verify_steps;
|
||||
totals.spec_mirror_steps += spec.mirror_steps;
|
||||
totals.spec_commit_steps += spec.commit_steps;
|
||||
totals.spec_correction_steps += spec.correction_steps;
|
||||
totals.spec_verify_decode_mismatches += spec.verify_decode_mismatches;
|
||||
}
|
||||
|
||||
let baseline_decode_tokens = totals.baseline_generated;
|
||||
let spec_decode_tokens = totals.spec_generated;
|
||||
let acceptance = ratio(totals.spec_accepted, totals.spec_proposed);
|
||||
let tokens_per_target_step = ratio(totals.spec_generated, totals.spec_target_steps);
|
||||
let matched =
|
||||
totals.mismatches == 0 && (!use_verify_logits || totals.spec_verify_decode_mismatches == 0);
|
||||
|
||||
println!("--- SUMMARY ---");
|
||||
println!("prompts={} matched={matched}", totals.prompts);
|
||||
println!(
|
||||
"acceptance_mode={}",
|
||||
if use_verify_logits {
|
||||
"verify_logits"
|
||||
} else {
|
||||
"decode"
|
||||
}
|
||||
);
|
||||
println!("verify_path={}", verify_path.as_str());
|
||||
println!(
|
||||
"acceptance_rate={:.4} accepted={} proposed={}",
|
||||
acceptance, totals.spec_accepted, totals.spec_proposed
|
||||
);
|
||||
println!(
|
||||
"tokens_per_target_step={:.4} target_steps={} verify_steps={} mirror_decode_steps={} commit_decode_steps={} correction_steps={}",
|
||||
tokens_per_target_step,
|
||||
totals.spec_target_steps,
|
||||
totals.spec_verify_steps,
|
||||
totals.spec_mirror_steps,
|
||||
totals.spec_commit_steps,
|
||||
totals.spec_correction_steps
|
||||
);
|
||||
println!(
|
||||
"verify_decode_mismatches={}",
|
||||
totals.spec_verify_decode_mismatches
|
||||
);
|
||||
println!(
|
||||
"baseline_e2e_tpot_ms={:.3} baseline_e2e_tok_s={:.3}",
|
||||
per_token_ms(totals.baseline_total_s, totals.baseline_generated),
|
||||
tok_s(totals.baseline_generated, totals.baseline_total_s)
|
||||
);
|
||||
println!(
|
||||
"spec_e2e_tpot_ms={:.3} spec_e2e_tok_s={:.3} speedup_e2e={:.4}",
|
||||
per_token_ms(totals.spec_total_s, totals.spec_generated),
|
||||
tok_s(totals.spec_generated, totals.spec_total_s),
|
||||
speedup(totals.baseline_total_s, totals.spec_total_s)
|
||||
);
|
||||
println!(
|
||||
"baseline_decode_tpot_ms={:.3} baseline_decode_tok_s={:.3}",
|
||||
per_token_ms(totals.baseline_decode_s, baseline_decode_tokens),
|
||||
tok_s(baseline_decode_tokens, totals.baseline_decode_s)
|
||||
);
|
||||
println!(
|
||||
"spec_decode_tpot_ms={:.3} spec_decode_tok_s={:.3} speedup_decode={:.4}",
|
||||
per_token_ms(totals.spec_decode_s, spec_decode_tokens),
|
||||
tok_s(spec_decode_tokens, totals.spec_decode_s),
|
||||
speedup(totals.baseline_decode_s, totals.spec_decode_s)
|
||||
);
|
||||
println!(
|
||||
"decode_token_counts baseline={} spec={}",
|
||||
baseline_decode_tokens, spec_decode_tokens
|
||||
);
|
||||
|
||||
if !matched {
|
||||
std::process::exit(2);
|
||||
}
|
||||
}
|
||||
|
||||
fn run_baseline(
|
||||
model: &Qwen3,
|
||||
cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
prompt_ids: &[u32],
|
||||
gen_tokens: usize,
|
||||
) -> RunStats {
|
||||
let slot = 0;
|
||||
cache.register_sequence(slot).expect("register target slot");
|
||||
|
||||
let t0 = Instant::now();
|
||||
let prefill_start = Instant::now();
|
||||
let logits = model.forward_prefill_paged(prompt_ids, slot, cache);
|
||||
sync_device();
|
||||
let prefill_s = prefill_start.elapsed().as_secs_f64();
|
||||
|
||||
let mut generated = Vec::with_capacity(gen_tokens);
|
||||
let mut next = last_argmax(&logits);
|
||||
generated.push(next);
|
||||
|
||||
let decode_start = Instant::now();
|
||||
let mut target_steps = 0usize;
|
||||
while generated.len() < gen_tokens && !tokenizer.is_eos(next) {
|
||||
let pos = cache.seq_len(slot);
|
||||
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], cache);
|
||||
target_steps += 1;
|
||||
next = last_argmax(&logits);
|
||||
generated.push(next);
|
||||
}
|
||||
sync_device();
|
||||
let decode_s = decode_start.elapsed().as_secs_f64();
|
||||
sync_device();
|
||||
let total_s = t0.elapsed().as_secs_f64();
|
||||
|
||||
cache.free_sequence(slot);
|
||||
RunStats {
|
||||
ids: generated,
|
||||
total_s,
|
||||
prefill_s,
|
||||
decode_s,
|
||||
target_steps,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_speculative(
|
||||
target: &Qwen3,
|
||||
draft: &Qwen3,
|
||||
target_cache: &mut PagedKVCache,
|
||||
target_verify_cache: &mut PagedKVCache,
|
||||
draft_cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
prompt_ids: &[u32],
|
||||
gen_tokens: usize,
|
||||
gamma: usize,
|
||||
use_verify_logits: bool,
|
||||
verify_path: VerifyPath,
|
||||
dump_verify_mismatches: bool,
|
||||
) -> RunStats {
|
||||
let slot = 0;
|
||||
target_cache
|
||||
.register_sequence(slot)
|
||||
.expect("register target slot");
|
||||
target_verify_cache
|
||||
.register_sequence(slot)
|
||||
.expect("register target verify slot");
|
||||
draft_cache
|
||||
.register_sequence(slot)
|
||||
.expect("register draft slot");
|
||||
|
||||
let t0 = Instant::now();
|
||||
let prefill_start = Instant::now();
|
||||
let target_logits = target.forward_prefill_paged(prompt_ids, slot, target_cache);
|
||||
if !use_verify_logits {
|
||||
let _ = target.forward_prefill_paged(prompt_ids, slot, target_verify_cache);
|
||||
}
|
||||
let draft_logits = draft.forward_prefill_paged(prompt_ids, slot, draft_cache);
|
||||
sync_device();
|
||||
let prefill_s = prefill_start.elapsed().as_secs_f64();
|
||||
|
||||
let mut target_next = last_argmax(&target_logits);
|
||||
let mut draft_next = last_argmax(&draft_logits);
|
||||
let mut generated = Vec::with_capacity(gen_tokens);
|
||||
let mut accepted_total = 0usize;
|
||||
let mut proposed_total = 0usize;
|
||||
let mut verify_steps = 0usize;
|
||||
let mut mirror_steps = 0usize;
|
||||
let mut commit_steps = 0usize;
|
||||
let mut correction_steps = 0usize;
|
||||
let mut verify_decode_mismatches = 0usize;
|
||||
|
||||
let decode_start = Instant::now();
|
||||
while generated.len() < gen_tokens {
|
||||
let remaining = gen_tokens - generated.len();
|
||||
let round_gamma = gamma.min(remaining);
|
||||
let round_start_len = target_cache.seq_len(slot);
|
||||
assert_eq!(
|
||||
round_start_len,
|
||||
draft_cache.seq_len(slot),
|
||||
"target and draft cache lengths diverged"
|
||||
);
|
||||
if !use_verify_logits {
|
||||
assert_eq!(
|
||||
round_start_len,
|
||||
target_verify_cache.seq_len(slot),
|
||||
"target verify cache length diverged"
|
||||
);
|
||||
}
|
||||
|
||||
let mut draft_tokens = Vec::with_capacity(round_gamma);
|
||||
for _ in 0..round_gamma {
|
||||
let token = draft_next;
|
||||
draft_tokens.push(token);
|
||||
if tokenizer.is_eos(token) {
|
||||
break;
|
||||
}
|
||||
let pos = draft_cache.seq_len(slot);
|
||||
let logits = draft.forward_decode_paged(&[token], &[pos], &[slot], draft_cache);
|
||||
draft_next = last_argmax(&logits);
|
||||
}
|
||||
proposed_total += draft_tokens.len();
|
||||
|
||||
if use_verify_logits {
|
||||
verify_steps += 1;
|
||||
let verify_logits =
|
||||
target.forward_verify_paged_decode_attention(&draft_tokens, slot, target_cache);
|
||||
let verify_argmax = argmax_rows(&verify_logits);
|
||||
assert_eq!(
|
||||
verify_argmax.len(),
|
||||
draft_tokens.len(),
|
||||
"verify logits rows must match draft token count"
|
||||
);
|
||||
|
||||
let mut accepted = 0usize;
|
||||
let mut done = false;
|
||||
while accepted < draft_tokens.len() {
|
||||
let expected = if accepted > 0 {
|
||||
verify_argmax[accepted - 1]
|
||||
} else {
|
||||
target_next
|
||||
};
|
||||
if draft_tokens[accepted] != expected {
|
||||
break;
|
||||
}
|
||||
let token = draft_tokens[accepted];
|
||||
generated.push(token);
|
||||
accepted_total += 1;
|
||||
accepted += 1;
|
||||
|
||||
if generated.len() >= gen_tokens || tokenizer.is_eos(token) {
|
||||
done = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if accepted > 0 {
|
||||
target_next = verify_argmax[accepted - 1];
|
||||
}
|
||||
target_cache
|
||||
.truncate_sequence(slot, round_start_len + accepted)
|
||||
.unwrap();
|
||||
|
||||
if done {
|
||||
draft_cache
|
||||
.truncate_sequence(slot, target_cache.seq_len(slot))
|
||||
.unwrap();
|
||||
break;
|
||||
}
|
||||
|
||||
if accepted == draft_tokens.len() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let correction = if accepted > 0 {
|
||||
verify_argmax[accepted - 1]
|
||||
} else {
|
||||
target_next
|
||||
};
|
||||
generated.push(correction);
|
||||
|
||||
draft_cache
|
||||
.truncate_sequence(slot, round_start_len)
|
||||
.unwrap();
|
||||
replay_draft_tokens(
|
||||
draft,
|
||||
draft_cache,
|
||||
slot,
|
||||
&draft_tokens[..accepted],
|
||||
&mut draft_next,
|
||||
);
|
||||
|
||||
if generated.len() >= gen_tokens || tokenizer.is_eos(correction) {
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = target_cache.seq_len(slot);
|
||||
let logits = target.forward_decode_paged(&[correction], &[pos], &[slot], target_cache);
|
||||
target_next = last_argmax(&logits);
|
||||
commit_steps += 1;
|
||||
|
||||
let pos = draft_cache.seq_len(slot);
|
||||
let logits = draft.forward_decode_paged(&[correction], &[pos], &[slot], draft_cache);
|
||||
draft_next = last_argmax(&logits);
|
||||
correction_steps += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
verify_steps += 1;
|
||||
let verify_logits = match verify_path {
|
||||
VerifyPath::Flash => {
|
||||
target.forward_prefill_paged(&draft_tokens, slot, target_verify_cache)
|
||||
}
|
||||
VerifyPath::PagedDecode => target.forward_verify_paged_decode_attention(
|
||||
&draft_tokens,
|
||||
slot,
|
||||
target_verify_cache,
|
||||
),
|
||||
};
|
||||
let verify_argmax = argmax_rows(&verify_logits);
|
||||
assert_eq!(
|
||||
verify_argmax.len(),
|
||||
draft_tokens.len(),
|
||||
"verify logits rows must match draft token count"
|
||||
);
|
||||
|
||||
target_verify_cache
|
||||
.truncate_sequence(slot, round_start_len)
|
||||
.unwrap();
|
||||
|
||||
let mut accepted = 0usize;
|
||||
let mut done = false;
|
||||
while accepted < draft_tokens.len() {
|
||||
let expected = if use_verify_logits && accepted > 0 {
|
||||
verify_argmax[accepted - 1]
|
||||
} else {
|
||||
target_next
|
||||
};
|
||||
if draft_tokens[accepted] != expected {
|
||||
break;
|
||||
}
|
||||
let token_idx = accepted;
|
||||
let token = draft_tokens[token_idx];
|
||||
generated.push(token);
|
||||
accepted_total += 1;
|
||||
accepted += 1;
|
||||
|
||||
if generated.len() >= gen_tokens || tokenizer.is_eos(token) {
|
||||
done = true;
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = target_cache.seq_len(slot);
|
||||
let logits = target.forward_decode_paged(&[token], &[pos], &[slot], target_cache);
|
||||
let decode_next = last_argmax(&logits);
|
||||
if verify_argmax[token_idx] != decode_next {
|
||||
verify_decode_mismatches += 1;
|
||||
eprintln!(
|
||||
"VERIFY/DECODE MISMATCH at cache_len={} accepted_idx={}: verify={} decode={}",
|
||||
target_cache.seq_len(slot),
|
||||
token_idx,
|
||||
verify_argmax[token_idx],
|
||||
decode_next
|
||||
);
|
||||
if dump_verify_mismatches {
|
||||
eprintln!(
|
||||
" verify_top5={} decode_top5={}",
|
||||
format_topk(&verify_logits, token_idx, 5),
|
||||
format_topk(&logits, 0, 5)
|
||||
);
|
||||
}
|
||||
}
|
||||
target_next = decode_next;
|
||||
commit_steps += 1;
|
||||
|
||||
advance_target_cache(target, target_verify_cache, slot, token);
|
||||
mirror_steps += 1;
|
||||
}
|
||||
if done {
|
||||
draft_cache
|
||||
.truncate_sequence(slot, target_cache.seq_len(slot))
|
||||
.unwrap();
|
||||
target_verify_cache
|
||||
.truncate_sequence(slot, target_cache.seq_len(slot))
|
||||
.unwrap();
|
||||
break;
|
||||
}
|
||||
|
||||
if accepted == draft_tokens.len() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let correction = if use_verify_logits && accepted > 0 {
|
||||
verify_argmax[accepted - 1]
|
||||
} else {
|
||||
target_next
|
||||
};
|
||||
generated.push(correction);
|
||||
|
||||
draft_cache
|
||||
.truncate_sequence(slot, round_start_len)
|
||||
.unwrap();
|
||||
replay_draft_tokens(
|
||||
draft,
|
||||
draft_cache,
|
||||
slot,
|
||||
&draft_tokens[..accepted],
|
||||
&mut draft_next,
|
||||
);
|
||||
|
||||
if generated.len() >= gen_tokens || tokenizer.is_eos(correction) {
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = target_cache.seq_len(slot);
|
||||
let logits = target.forward_decode_paged(&[correction], &[pos], &[slot], target_cache);
|
||||
target_next = last_argmax(&logits);
|
||||
commit_steps += 1;
|
||||
|
||||
advance_target_cache(target, target_verify_cache, slot, correction);
|
||||
mirror_steps += 1;
|
||||
|
||||
let pos = draft_cache.seq_len(slot);
|
||||
let logits = draft.forward_decode_paged(&[correction], &[pos], &[slot], draft_cache);
|
||||
draft_next = last_argmax(&logits);
|
||||
correction_steps += 1;
|
||||
}
|
||||
sync_device();
|
||||
let decode_s = decode_start.elapsed().as_secs_f64();
|
||||
sync_device();
|
||||
let total_s = t0.elapsed().as_secs_f64();
|
||||
|
||||
target_cache.free_sequence(slot);
|
||||
target_verify_cache.free_sequence(slot);
|
||||
draft_cache.free_sequence(slot);
|
||||
|
||||
RunStats {
|
||||
ids: generated,
|
||||
total_s,
|
||||
prefill_s,
|
||||
decode_s,
|
||||
target_steps: verify_steps + mirror_steps + commit_steps + correction_steps,
|
||||
accepted: accepted_total,
|
||||
proposed: proposed_total,
|
||||
verify_steps,
|
||||
mirror_steps,
|
||||
commit_steps,
|
||||
correction_steps,
|
||||
verify_decode_mismatches,
|
||||
}
|
||||
}
|
||||
|
||||
fn advance_target_cache(target: &Qwen3, cache: &mut PagedKVCache, slot: usize, token: u32) {
|
||||
let pos = cache.seq_len(slot);
|
||||
let _ = target.forward_decode_paged(&[token], &[pos], &[slot], cache);
|
||||
}
|
||||
|
||||
fn replay_draft_tokens(
|
||||
draft: &Qwen3,
|
||||
cache: &mut PagedKVCache,
|
||||
slot: usize,
|
||||
tokens: &[u32],
|
||||
next: &mut u32,
|
||||
) {
|
||||
for &token in tokens {
|
||||
let pos = cache.seq_len(slot);
|
||||
let logits = draft.forward_decode_paged(&[token], &[pos], &[slot], cache);
|
||||
*next = last_argmax(&logits);
|
||||
}
|
||||
}
|
||||
|
||||
fn new_cache(config: &ModelConfig, max_seq_len: usize, device: u32) -> PagedKVCache {
|
||||
new_cache_with_rows(config, max_seq_len, device, 1)
|
||||
}
|
||||
|
||||
fn new_cache_with_rows(
|
||||
config: &ModelConfig,
|
||||
max_seq_len: usize,
|
||||
device: u32,
|
||||
max_rows: usize,
|
||||
) -> PagedKVCache {
|
||||
let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE);
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
PagedKVCache::new(
|
||||
config,
|
||||
total_blocks,
|
||||
0,
|
||||
max_rows.max(1),
|
||||
max_blocks_per_seq,
|
||||
DType::BF16,
|
||||
device,
|
||||
)
|
||||
}
|
||||
|
||||
fn argmax_rows(logits: &Tensor) -> Vec<u32> {
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
if logits.dtype() == DType::BF16
|
||||
&& matches!(logits.device(), Device::Cuda(_))
|
||||
&& logits.is_contiguous()
|
||||
{
|
||||
return xserv_kernels::argmax_bf16_to_host(logits);
|
||||
}
|
||||
|
||||
let vocab_size = logits.shape()[1];
|
||||
let rows = logits.shape()[0];
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
match logits.dtype() {
|
||||
DType::F32 => logits_cpu
|
||||
.as_slice::<f32>()
|
||||
.chunks_exact(vocab_size)
|
||||
.take(rows)
|
||||
.map(argmax_f32)
|
||||
.collect(),
|
||||
DType::BF16 => logits_cpu
|
||||
.as_slice::<bf16>()
|
||||
.chunks_exact(vocab_size)
|
||||
.take(rows)
|
||||
.map(|row| {
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(i, _)| i as u32)
|
||||
.unwrap()
|
||||
})
|
||||
.collect(),
|
||||
_ => panic!("unsupported dtype for argmax: {:?}", logits.dtype()),
|
||||
}
|
||||
}
|
||||
|
||||
fn last_argmax(logits: &Tensor) -> u32 {
|
||||
*argmax_rows(logits).last().unwrap()
|
||||
}
|
||||
|
||||
fn argmax_f32(row: &[f32]) -> u32 {
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
||||
.map(|(i, _)| i as u32)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn format_topk(logits: &Tensor, row: usize, k: usize) -> String {
|
||||
let vals = topk_row(logits, row, k);
|
||||
vals.iter()
|
||||
.map(|(id, val)| format!("{id}:{val:.3}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
}
|
||||
|
||||
fn topk_row(logits: &Tensor, row: usize, k: usize) -> Vec<(u32, f32)> {
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
let vocab_size = logits.shape()[1];
|
||||
assert!(row < logits.shape()[0], "topk row out of bounds");
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let mut vals: Vec<(u32, f32)> = match logits.dtype() {
|
||||
DType::F32 => logits_cpu.as_slice::<f32>()[row * vocab_size..(row + 1) * vocab_size]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &v)| (i as u32, v))
|
||||
.collect(),
|
||||
DType::BF16 => logits_cpu.as_slice::<bf16>()[row * vocab_size..(row + 1) * vocab_size]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &v)| (i as u32, v.to_f32()))
|
||||
.collect(),
|
||||
_ => panic!("unsupported dtype for topk: {:?}", logits.dtype()),
|
||||
};
|
||||
vals.select_nth_unstable_by(k, |a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
vals.truncate(k);
|
||||
vals.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
vals
|
||||
}
|
||||
|
||||
fn assert_qwen3(config: &ModelConfig, name: &str) {
|
||||
let model_type = config.model_type.as_deref().unwrap_or("unknown");
|
||||
assert!(
|
||||
model_type.contains("qwen"),
|
||||
"{name} model_type must be qwen-like, got {model_type}"
|
||||
);
|
||||
}
|
||||
|
||||
fn warn_if_tokenizers_differ(target_dir: &Path, draft_dir: &Path) {
|
||||
let target = std::fs::read(target_dir.join("tokenizer.json"));
|
||||
let draft = std::fs::read(draft_dir.join("tokenizer.json"));
|
||||
if let (Ok(target), Ok(draft)) = (target, draft) {
|
||||
if target != draft {
|
||||
eprintln!(
|
||||
"WARNING: target and draft tokenizer.json differ; v0 assumes identical token ids"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn arg_usize(args: &[String], flag: &str, default: usize) -> usize {
|
||||
args.iter()
|
||||
.position(|a| a == flag)
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn parse_verify_path(args: &[String], use_verify_logits: bool) -> VerifyPath {
|
||||
let default = if use_verify_logits {
|
||||
VerifyPath::PagedDecode
|
||||
} else {
|
||||
VerifyPath::Flash
|
||||
};
|
||||
let Some(value) = args
|
||||
.iter()
|
||||
.position(|a| a == "--verify-path")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
else {
|
||||
return default;
|
||||
};
|
||||
match value.as_str() {
|
||||
"flash" => VerifyPath::Flash,
|
||||
"paged-decode" => VerifyPath::PagedDecode,
|
||||
_ => {
|
||||
eprintln!("unknown --verify-path {value:?}; expected flash or paged-decode");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_length_budget(prompt_ids: &[u32], gen_tokens: usize, max_seq_len: usize, prompt: &str) {
|
||||
let required = prompt_ids.len() + gen_tokens;
|
||||
if required > max_seq_len {
|
||||
eprintln!(
|
||||
"prompt requires prompt_len({}) + gen_tokens({}) = {} tokens, exceeding --max-seq-len {}: {:?}",
|
||||
prompt_ids.len(),
|
||||
gen_tokens,
|
||||
required,
|
||||
max_seq_len,
|
||||
prompt
|
||||
);
|
||||
std::process::exit(2);
|
||||
}
|
||||
}
|
||||
|
||||
fn sync_device() {
|
||||
xserv_cuda::device::synchronize().expect("cuda device synchronize");
|
||||
}
|
||||
|
||||
fn ratio(num: usize, den: usize) -> f64 {
|
||||
if den == 0 {
|
||||
0.0
|
||||
} else {
|
||||
num as f64 / den as f64
|
||||
}
|
||||
}
|
||||
|
||||
fn speedup(baseline_s: f64, spec_s: f64) -> f64 {
|
||||
if spec_s == 0.0 {
|
||||
0.0
|
||||
} else {
|
||||
baseline_s / spec_s
|
||||
}
|
||||
}
|
||||
|
||||
fn tok_s(tokens: usize, seconds: f64) -> f64 {
|
||||
if seconds == 0.0 {
|
||||
0.0
|
||||
} else {
|
||||
tokens as f64 / seconds
|
||||
}
|
||||
}
|
||||
|
||||
fn per_token_ms(seconds: f64, tokens: usize) -> f64 {
|
||||
if tokens == 0 {
|
||||
0.0
|
||||
} else {
|
||||
seconds * 1000.0 / tokens as f64
|
||||
}
|
||||
}
|
||||
@@ -486,6 +486,35 @@ impl PagedKVCache {
|
||||
state.seq_len += num_tokens;
|
||||
}
|
||||
|
||||
/// Roll a registered sequence back to `new_len` tokens.
|
||||
///
|
||||
/// This only changes cache metadata and frees whole physical blocks that are
|
||||
/// no longer reachable. Bytes inside retained blocks are left untouched; the
|
||||
/// logical `seq_len` prevents attention from reading them, and later writes
|
||||
/// to the same positions overwrite them.
|
||||
pub fn truncate_sequence(&mut self, slot: usize, new_len: usize) -> Result<(), &'static str> {
|
||||
if slot >= self.max_seqs {
|
||||
return Err("truncate_sequence: slot out of range");
|
||||
}
|
||||
let state = self.seq_states[slot]
|
||||
.as_mut()
|
||||
.ok_or("truncate_sequence: empty slot")?;
|
||||
if new_len > state.seq_len {
|
||||
return Err("truncate_sequence: cannot extend");
|
||||
}
|
||||
|
||||
let needed_blocks = ((new_len + BLOCK_SIZE - 1) / BLOCK_SIZE).max(1);
|
||||
while state.block_ids.len() > needed_blocks {
|
||||
let block = state.block_ids.pop().expect("checked len");
|
||||
match state.location {
|
||||
Location::Gpu => self.allocator.free(block),
|
||||
Location::Cpu => self.cpu_allocator.free(block),
|
||||
}
|
||||
}
|
||||
state.seq_len = new_len;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Refresh the host-side block table + context lens from `seq_states`,
|
||||
/// then upload to GPU. Call once per decode step before the paged kernel.
|
||||
pub fn sync_to_gpu(&mut self) {
|
||||
@@ -748,6 +777,71 @@ impl PagedKVCache {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn tiny_config() -> ModelConfig {
|
||||
serde_json::from_value(serde_json::json!({
|
||||
"model_type": "qwen3",
|
||||
"hidden_size": 8,
|
||||
"intermediate_size": 16,
|
||||
"num_attention_heads": 1,
|
||||
"num_key_value_heads": 1,
|
||||
"num_hidden_layers": 1,
|
||||
"vocab_size": 32,
|
||||
"max_position_embeddings": 64
|
||||
}))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_sequence_frees_whole_blocks_and_keeps_slot_registered() {
|
||||
if xserv_cuda::device::set_device(0).is_err() {
|
||||
eprintln!("skipping CUDA-backed PagedKVCache test: device 0 unavailable");
|
||||
return;
|
||||
}
|
||||
|
||||
let config = tiny_config();
|
||||
let mut cache = PagedKVCache::new(&config, 5, 0, 1, 4, DType::BF16, 0);
|
||||
|
||||
assert_eq!(
|
||||
cache.truncate_sequence(1, 0),
|
||||
Err("truncate_sequence: slot out of range")
|
||||
);
|
||||
assert_eq!(
|
||||
cache.truncate_sequence(0, 0),
|
||||
Err("truncate_sequence: empty slot")
|
||||
);
|
||||
|
||||
cache.register_sequence(0).unwrap();
|
||||
cache.ensure_capacity(0, BLOCK_SIZE * 3 + 1);
|
||||
cache.advance_seq_len(0, BLOCK_SIZE * 3 + 1);
|
||||
assert_eq!(cache.seq_len(0), BLOCK_SIZE * 3 + 1);
|
||||
assert_eq!(cache.block_count(0), 4);
|
||||
assert_eq!(cache.free_blocks(), 0);
|
||||
|
||||
cache.truncate_sequence(0, BLOCK_SIZE + 1).unwrap();
|
||||
assert_eq!(cache.seq_len(0), BLOCK_SIZE + 1);
|
||||
assert_eq!(cache.block_count(0), 2);
|
||||
assert_eq!(cache.free_blocks(), 2);
|
||||
|
||||
cache.truncate_sequence(0, BLOCK_SIZE).unwrap();
|
||||
assert_eq!(cache.seq_len(0), BLOCK_SIZE);
|
||||
assert_eq!(cache.block_count(0), 1);
|
||||
assert_eq!(cache.free_blocks(), 3);
|
||||
|
||||
cache.truncate_sequence(0, 0).unwrap();
|
||||
assert_eq!(cache.seq_len(0), 0);
|
||||
assert_eq!(cache.block_count(0), 1);
|
||||
assert_eq!(cache.free_blocks(), 3);
|
||||
assert_eq!(
|
||||
cache.truncate_sequence(0, 1),
|
||||
Err("truncate_sequence: cannot extend")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn tensor_from_owned_buf(
|
||||
buf: GpuBuffer,
|
||||
shape: &[usize],
|
||||
|
||||
@@ -884,6 +884,109 @@ impl Qwen3 {
|
||||
matmul_2d(&x, &self.lm_head_t)
|
||||
}
|
||||
|
||||
/// Paged multi-token verify path: write `token_ids` into the paged cache,
|
||||
/// then verify them with the same paged decode attention kernel used by
|
||||
/// single-token decode. This keeps greedy top-1 behavior aligned with
|
||||
/// `forward_decode_paged` while still batching the dense projections/MLP
|
||||
/// across the draft window.
|
||||
pub fn forward_verify_paged_decode_attention(
|
||||
&self,
|
||||
token_ids: &[u32],
|
||||
slot: usize,
|
||||
paged_cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
let pos_offset = paged_cache.seq_len(slot);
|
||||
let num_heads = self.local_num_heads;
|
||||
let num_kv_heads = self.local_num_kv_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
|
||||
paged_cache.advance_seq_len(slot, new_tokens);
|
||||
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
|
||||
.map(|p| p as u32)
|
||||
.collect();
|
||||
let kv_lens: Vec<i32> = (0..new_tokens)
|
||||
.map(|i| (pos_offset + i + 1) as i32)
|
||||
.collect();
|
||||
let slots = vec![slot; new_tokens];
|
||||
paged_cache.sync_active_batch_with_lens(&slots, &kv_lens);
|
||||
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
|
||||
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
|
||||
let max_blocks = paged_cache.max_blocks_per_seq();
|
||||
|
||||
let mut x = embedding(&self.embed_tokens, token_ids);
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
let qkv = matmul_rows_gemv(&normed, &layer.qkv_proj_wt);
|
||||
let q_dim = num_heads * head_dim;
|
||||
let kv_dim = num_kv_heads * head_dim;
|
||||
let q_all = qkv.narrow(1, 0, q_dim);
|
||||
let k_all = qkv.narrow(1, q_dim, kv_dim);
|
||||
let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);
|
||||
|
||||
let q_flat = q_all
|
||||
.contiguous()
|
||||
.reshape(&[new_tokens * num_heads, head_dim]);
|
||||
let k_flat = k_all
|
||||
.contiguous()
|
||||
.reshape(&[new_tokens * num_kv_heads, head_dim]);
|
||||
let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
|
||||
let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
|
||||
|
||||
let q_3d = q_normed.reshape(&[new_tokens, num_heads, head_dim]);
|
||||
let k_3d = k_normed.reshape(&[new_tokens, num_kv_heads, head_dim]);
|
||||
rope_inplace(&q_3d, &self.rope_cache, &positions);
|
||||
rope_inplace(&k_3d, &self.rope_cache, &positions);
|
||||
|
||||
let v_3d = v_all
|
||||
.contiguous()
|
||||
.reshape(&[new_tokens, num_kv_heads, head_dim]);
|
||||
paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, new_tokens);
|
||||
|
||||
let q_decode = q_3d.reshape(&[new_tokens, num_heads, 1, head_dim]);
|
||||
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let attn_out = xserv_kernels::paged_decode_attention(
|
||||
&q_decode,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
bt_ptr,
|
||||
cl_ptr,
|
||||
new_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
max_blocks,
|
||||
);
|
||||
|
||||
let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]);
|
||||
let attn_proj = matmul_rows_gemv(&attn_merged, &layer.o_proj_wt);
|
||||
self.all_reduce(&attn_proj);
|
||||
|
||||
let (normed, x_new) =
|
||||
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate_up = matmul_rows_gemv(&normed, &layer.gate_up_proj_wt);
|
||||
let ffn_dim = gate_up.shape()[1] / 2;
|
||||
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
|
||||
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_rows_gemv(&hidden_states, &layer.down_proj_wt);
|
||||
self.all_reduce(&down);
|
||||
x = add_any(&residual, &down);
|
||||
}
|
||||
|
||||
let x = rmsnorm(&x, &self.norm, eps);
|
||||
matmul_rows_gemv(&x, &self.lm_head_t)
|
||||
}
|
||||
|
||||
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
|
||||
pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
@@ -1158,6 +1261,20 @@ fn row_view(t: &Tensor, row: usize) -> Tensor {
|
||||
)
|
||||
}
|
||||
|
||||
/// Run a 2D matmul row by row so each row uses the same GEMV kernel as
|
||||
/// single-token decode. Used by speculative verify parity, where near-tie
|
||||
/// logits must follow decode's BF16 rounding path.
|
||||
fn matmul_rows_gemv(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert_eq!(a.ndim(), 2);
|
||||
assert!(a.is_contiguous());
|
||||
let rows = a.shape()[0];
|
||||
if rows == 1 {
|
||||
return matmul_2d(a, b);
|
||||
}
|
||||
let out_rows: Vec<Tensor> = (0..rows).map(|i| matmul_2d(&row_view(a, i), b)).collect();
|
||||
concat_rows(&out_rows)
|
||||
}
|
||||
|
||||
/// Concatenate row tensors [1, cols] into a single [B, cols] tensor via D2D memcpy.
|
||||
fn concat_rows(rows: &[Tensor]) -> Tensor {
|
||||
assert!(!rows.is_empty());
|
||||
|
||||
Reference in New Issue
Block a user