4 Commits

Author SHA1 Message Date
a67753f516 softmax: cap block size at 512 threads
launch_softmax_{f32,bf16} clamped block to 1024 threads when cols was
larger. Halving the ceiling to 512 keeps two blocks per SM resident on
the large vocab kernels that dominate speculative verify workloads
without changing rows/block indexing, and never exceeds cols.
2026-07-01 14:16:32 +08:00
f5ec10c2c3 xserv-cli: expose sampling params and greedy repetition penalty
Interactive REPL used to always call sample_greedy_last on both the
paged and legacy KV paths, so temperature/top-k/top-p and the repetition
penalty added in the sampling module were unreachable from the CLI.

- flag() helper parses --max-tokens / --temperature / --top-k / --top-p
  / --rep-penalty / --rep-window (defaults preserve prior behavior:
  temperature 0, top-p 1, penalty 1, window 512).
- pick_next() dispatches to sample_greedy_penalized only when
  temperature==0 and rep_penalty>1, otherwise to sample().
- Both Qwen3/GPT-2 paths and the GptOss paged path share the same
  sampler and both feed the rolling history window used for the penalty.
- Prompt input now unescapes literal "\n" so multi-turn prompts can be
  typed on one line.
2026-07-01 14:16:31 +08:00
ce7229f4fe speculative: Qwen3 draft-model v0 with paged verify parity
Phase 22 lands a correctness-only speculative decoding loop for Qwen3
target + Qwen3 small draft (batch=1, greedy, gamma=4). Phase 23 turns
verify logits into the authoritative acceptance signal so mirror-decode
per accepted token is no longer needed.

- paged_kv_cache: truncate_sequence(slot, new_len) shrinks a registered
  sequence, freeing whole physical blocks no longer reachable and
  leaving the slot registered. Covered by a CUDA-gated unit test.
- qwen3: forward_verify_paged_decode_attention writes the draft window
  into the target cache, runs the same paged decode attention kernel per
  draft token, and uses matmul_rows_gemv so linear layers follow the
  single-token decode BF16 rounding path.
- bench-speculative: new bench binary drives the state machine with
  --gamma / --gen-tokens / --prompts / --use-verify-logits /
  --verify-path flash|paged-decode / --dump-verify-mismatches, and
  compares baseline vs spec token sequences plus TPOT / tok/s / speedup.
- docs/22 records the decode-authoritative v0 result and dash5 numbers
  (matched=true, speedup_e2e ~0.29x, verify_decode_mismatches>0 under
  --use-verify-logits).
- docs/23 records the paged-decode verify path (matched=true,
  verify_decode_mismatches=0, 50x64 speedup_e2e ~0.44x) and the
  next-step performance TODO.
2026-07-01 14:16:30 +08:00
5b350ee5f0 cuda: deterministic BF16 gemv + paged attention reductions
BF16 greedy decode was sensitive to inter-block scheduling when logits
were close, which broke speculative-decoding verify-vs-decode parity.

- gemv.cu: write per-K-block partials, then reduce in fixed block order
  in a second kernel instead of atomicAdd across K-blocks. Scratch
  buffer size is now n * ceil(k / GEMV_TILE_K); gemv_scratch_elems()
  exposes this to callers, and decode_graph.rs sizes fp32_hidden/q/kv/
  intermediate/vocab from it.
- paged_attention.cu: replace atomicAdd merge of warp outputs with
  per-warp shared partials reduced in warp-id order for both the base
  and sinks kernels.
2026-07-01 14:16:28 +08:00
11 changed files with 1562 additions and 85 deletions

View File

@@ -5,6 +5,7 @@ use xserv_cuda::error::{self, Result};
use xserv_tensor::{DType, Device, Tensor};
const CUBLAS_WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
const GEMV_TILE_K: usize = 256;
// GEMV: single-kernel, no FP32 temp buffer needed
unsafe extern "C" {
@@ -26,6 +27,10 @@ pub enum GemmBackend {
CuBlas,
}
pub fn gemv_scratch_elems(k: usize, n: usize) -> usize {
n * k.div_ceil(GEMV_TILE_K)
}
// --- FFI: custom CUDA kernels ---
unsafe extern "C" {
fn launch_gemm_naive_f32(
@@ -274,7 +279,8 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
},
GemmBackend::CuBlas => {
if m == 1 && dtype == DType::BF16 && n >= 256 {
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(n * 4).unwrap();
let mut fp32_buf =
xserv_cuda::allocator::cached_alloc(gemv_scratch_elems(k, n) * 4).unwrap();
unsafe {
launch_gemv_bf16(
a_ptr,

View File

@@ -0,0 +1,962 @@
//! Draft-model speculative decoding benchmark for Qwen3.
//!
//! v0 scope:
//! - target + draft are Qwen3-family models with the same tokenizer/vocab;
//! - batch=1;
//! - greedy exact-match acceptance;
//! - no probabilistic rejection sampling.
use half::bf16;
use std::path::{Path, PathBuf};
use std::time::Instant;
use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader};
use xserv_tensor::{DType, Device, Tensor};
use xserv_tokenizer::Tokenizer;
const DEFAULT_GAMMA: usize = 4;
const DEFAULT_GEN_TOKENS: usize = 64;
const DEFAULT_MAX_SEQ_LEN: usize = 2048;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum VerifyPath {
Flash,
PagedDecode,
}
impl VerifyPath {
fn as_str(self) -> &'static str {
match self {
VerifyPath::Flash => "flash",
VerifyPath::PagedDecode => "paged-decode",
}
}
}
const PROMPTS: [&str; 50] = [
"The capital of France is",
"Once upon a time in a land far away",
"Hello, how are you doing today",
"In a shocking finding, scientists discovered a",
"The weather today is sunny, so I decided to",
"Alan Turing was a British mathematician who",
"The best way to learn programming is",
"Artificial intelligence will change the world because",
"The history of the internet began in the",
"A good morning routine starts with",
"The stock market crashed because investors",
"Deep learning is a subset of machine learning that",
"The president of the United States announced",
"In the year 2050, humans will",
"The secret to happiness is",
"When I was a child, I used to",
"The most important scientific discovery of the century",
"Climate change is caused by",
"The recipe for chocolate cake requires",
"In conclusion, the evidence suggests that",
"The cat sat on the mat and",
"According to recent studies, exercise can",
"The first step in solving any problem is",
"Technology has transformed the way we",
"The novel begins with the protagonist",
"Education is the most powerful weapon",
"The ocean covers more than seventy percent of",
"Last night I had a dream about",
"The company announced its quarterly earnings",
"Music has the power to",
"The difference between success and failure is",
"In the beginning, there was nothing but",
"The doctor told me that I should",
"Python is a popular programming language because",
"The ancient Romans built roads that",
"A balanced diet should include",
"The movie received mixed reviews from critics",
"Space exploration has led to many",
"The teacher asked the students to",
"Global warming is one of the most",
"The bridge collapsed due to structural",
"Quantum computing promises to revolutionize",
"The new policy will affect millions of",
"During the winter months, it is important to",
"The human brain contains approximately",
"Democracy depends on the active participation of",
"The train arrived at the station exactly",
"Researchers at MIT have developed a new",
"The smartphone has become an essential part of",
"After careful consideration, the committee decided to",
];
#[derive(Default)]
struct RunStats {
ids: Vec<u32>,
total_s: f64,
prefill_s: f64,
decode_s: f64,
target_steps: usize,
accepted: usize,
proposed: usize,
verify_steps: usize,
mirror_steps: usize,
commit_steps: usize,
correction_steps: usize,
verify_decode_mismatches: usize,
}
#[derive(Default)]
struct Totals {
prompts: usize,
baseline_generated: usize,
spec_generated: usize,
baseline_total_s: f64,
baseline_prefill_s: f64,
baseline_decode_s: f64,
spec_total_s: f64,
spec_prefill_s: f64,
spec_decode_s: f64,
spec_target_steps: usize,
spec_accepted: usize,
spec_proposed: usize,
spec_verify_steps: usize,
spec_mirror_steps: usize,
spec_commit_steps: usize,
spec_correction_steps: usize,
spec_verify_decode_mismatches: usize,
mismatches: usize,
}
fn main() {
let args: Vec<String> = std::env::args().collect();
if args.len() < 3 {
eprintln!(
"Usage: bench-speculative <target-model-dir> <draft-model-dir> \
[--gen-tokens N] [--gamma N] [--prompts N] [--max-seq-len N] [--device N] \
[--use-verify-logits] [--verify-path flash|paged-decode] [--dump-verify-mismatches]"
);
std::process::exit(1);
}
let target_dir = PathBuf::from(&args[1]);
let draft_dir = PathBuf::from(&args[2]);
let gen_tokens = arg_usize(&args, "--gen-tokens", DEFAULT_GEN_TOKENS);
let gamma = arg_usize(&args, "--gamma", DEFAULT_GAMMA);
let prompt_count = arg_usize(&args, "--prompts", PROMPTS.len()).min(PROMPTS.len());
let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN);
let device = arg_usize(&args, "--device", 0) as u32;
let use_verify_logits = args.iter().any(|a| a == "--use-verify-logits");
let verify_path = parse_verify_path(&args, use_verify_logits);
let dump_verify_mismatches = args.iter().any(|a| a == "--dump-verify-mismatches");
assert!(gen_tokens > 0, "--gen-tokens must be > 0");
assert!(gamma > 0, "--gamma must be > 0");
xserv_cuda::device::set_device(device).unwrap();
let info = xserv_cuda::device::device_info(device).unwrap();
eprintln!(
"GPU {device}: {} ({} MB free)",
info.name,
info.free_memory / 1024 / 1024
);
let target_config = ModelConfig::from_file(&target_dir.join("config.json"));
let draft_config = ModelConfig::from_file(&draft_dir.join("config.json"));
assert_qwen3(&target_config, "target");
assert_qwen3(&draft_config, "draft");
assert_eq!(
target_config.vocab_size, draft_config.vocab_size,
"target and draft vocab_size must match"
);
warn_if_tokenizers_differ(&target_dir, &draft_dir);
let tokenizer = Tokenizer::from_file(&target_dir.join("tokenizer.json"));
if tokenizer.vocab_size() != target_config.vocab_size {
eprintln!(
"WARNING: tokenizer decoder len {} differs from config vocab_size {}; continuing because token ids come from the shared tokenizer.json",
tokenizer.vocab_size(),
target_config.vocab_size
);
}
eprintln!(
"Loading target Qwen3: layers={} hidden={} heads={}/{} vocab={}",
target_config.num_layers(),
target_config.hidden(),
target_config.num_heads(),
target_config.num_kv_heads(),
target_config.vocab_size
);
let target_weights = loader::load_model_dir(&target_dir, Device::Cuda(device));
let target = Qwen3::from_weights(target_config.clone(), target_weights);
xserv_cuda::allocator::cached_trim();
eprintln!(
"Loading draft Qwen3: layers={} hidden={} heads={}/{} vocab={}",
draft_config.num_layers(),
draft_config.hidden(),
draft_config.num_heads(),
draft_config.num_kv_heads(),
draft_config.vocab_size
);
let draft_weights = loader::load_model_dir(&draft_dir, Device::Cuda(device));
let draft = Qwen3::from_weights(draft_config.clone(), draft_weights);
xserv_cuda::allocator::cached_trim();
let warm_ids = tokenizer.encode("warmup");
let warm_tokens = gen_tokens.min(4);
{
let mut target_cache = new_cache(&target_config, max_seq_len, device);
let _ = run_baseline(
&target,
&mut target_cache,
&tokenizer,
&warm_ids,
warm_tokens,
);
}
{
let mut target_cache = new_cache_with_rows(
&target_config,
max_seq_len,
device,
if use_verify_logits { gamma } else { 1 },
);
let mut target_verify_cache =
new_cache_with_rows(&target_config, max_seq_len, device, gamma);
let mut draft_cache = new_cache(&draft_config, max_seq_len, device);
let _ = run_speculative(
&target,
&draft,
&mut target_cache,
&mut target_verify_cache,
&mut draft_cache,
&tokenizer,
&warm_ids,
warm_tokens,
gamma,
use_verify_logits,
verify_path,
dump_verify_mismatches,
);
}
eprintln!(
"Warmup done. Running {prompt_count} prompts, gen_tokens={gen_tokens}, gamma={gamma}, acceptance_mode={}, verify_path={}",
if use_verify_logits {
"verify_logits"
} else {
"decode"
},
verify_path.as_str()
);
let mut totals = Totals::default();
for (i, prompt) in PROMPTS.iter().take(prompt_count).enumerate() {
let ids = tokenizer.encode(prompt);
validate_length_budget(&ids, gen_tokens, max_seq_len, prompt);
let mut baseline_cache = new_cache(&target_config, max_seq_len, device);
let baseline = run_baseline(&target, &mut baseline_cache, &tokenizer, &ids, gen_tokens);
drop(baseline_cache);
let mut target_cache = new_cache_with_rows(
&target_config,
max_seq_len,
device,
if use_verify_logits { gamma } else { 1 },
);
let mut target_verify_cache =
new_cache_with_rows(&target_config, max_seq_len, device, gamma);
let mut draft_cache = new_cache(&draft_config, max_seq_len, device);
let spec = run_speculative(
&target,
&draft,
&mut target_cache,
&mut target_verify_cache,
&mut draft_cache,
&tokenizer,
&ids,
gen_tokens,
gamma,
use_verify_logits,
verify_path,
dump_verify_mismatches,
);
let matched = baseline.ids == spec.ids;
if !matched {
totals.mismatches += 1;
eprintln!("MISMATCH prompt {i}: {prompt}");
eprintln!(" baseline: {:?}", baseline.ids);
eprintln!(" spec: {:?}", spec.ids);
}
println!(
"prompt={:02} match={} gen={} accept={}/{} target_steps={} \
baseline_e2e_tpot_ms={:.3} spec_e2e_tpot_ms={:.3}",
i,
matched,
spec.ids.len(),
spec.accepted,
spec.proposed,
spec.target_steps,
per_token_ms(baseline.total_s, baseline.ids.len()),
per_token_ms(spec.total_s, spec.ids.len()),
);
totals.prompts += 1;
totals.baseline_generated += baseline.ids.len();
totals.spec_generated += spec.ids.len();
totals.baseline_total_s += baseline.total_s;
totals.baseline_prefill_s += baseline.prefill_s;
totals.baseline_decode_s += baseline.decode_s;
totals.spec_total_s += spec.total_s;
totals.spec_prefill_s += spec.prefill_s;
totals.spec_decode_s += spec.decode_s;
totals.spec_target_steps += spec.target_steps;
totals.spec_accepted += spec.accepted;
totals.spec_proposed += spec.proposed;
totals.spec_verify_steps += spec.verify_steps;
totals.spec_mirror_steps += spec.mirror_steps;
totals.spec_commit_steps += spec.commit_steps;
totals.spec_correction_steps += spec.correction_steps;
totals.spec_verify_decode_mismatches += spec.verify_decode_mismatches;
}
let baseline_decode_tokens = totals.baseline_generated;
let spec_decode_tokens = totals.spec_generated;
let acceptance = ratio(totals.spec_accepted, totals.spec_proposed);
let tokens_per_target_step = ratio(totals.spec_generated, totals.spec_target_steps);
let matched =
totals.mismatches == 0 && (!use_verify_logits || totals.spec_verify_decode_mismatches == 0);
println!("--- SUMMARY ---");
println!("prompts={} matched={matched}", totals.prompts);
println!(
"acceptance_mode={}",
if use_verify_logits {
"verify_logits"
} else {
"decode"
}
);
println!("verify_path={}", verify_path.as_str());
println!(
"acceptance_rate={:.4} accepted={} proposed={}",
acceptance, totals.spec_accepted, totals.spec_proposed
);
println!(
"tokens_per_target_step={:.4} target_steps={} verify_steps={} mirror_decode_steps={} commit_decode_steps={} correction_steps={}",
tokens_per_target_step,
totals.spec_target_steps,
totals.spec_verify_steps,
totals.spec_mirror_steps,
totals.spec_commit_steps,
totals.spec_correction_steps
);
println!(
"verify_decode_mismatches={}",
totals.spec_verify_decode_mismatches
);
println!(
"baseline_e2e_tpot_ms={:.3} baseline_e2e_tok_s={:.3}",
per_token_ms(totals.baseline_total_s, totals.baseline_generated),
tok_s(totals.baseline_generated, totals.baseline_total_s)
);
println!(
"spec_e2e_tpot_ms={:.3} spec_e2e_tok_s={:.3} speedup_e2e={:.4}",
per_token_ms(totals.spec_total_s, totals.spec_generated),
tok_s(totals.spec_generated, totals.spec_total_s),
speedup(totals.baseline_total_s, totals.spec_total_s)
);
println!(
"baseline_decode_tpot_ms={:.3} baseline_decode_tok_s={:.3}",
per_token_ms(totals.baseline_decode_s, baseline_decode_tokens),
tok_s(baseline_decode_tokens, totals.baseline_decode_s)
);
println!(
"spec_decode_tpot_ms={:.3} spec_decode_tok_s={:.3} speedup_decode={:.4}",
per_token_ms(totals.spec_decode_s, spec_decode_tokens),
tok_s(spec_decode_tokens, totals.spec_decode_s),
speedup(totals.baseline_decode_s, totals.spec_decode_s)
);
println!(
"decode_token_counts baseline={} spec={}",
baseline_decode_tokens, spec_decode_tokens
);
if !matched {
std::process::exit(2);
}
}
fn run_baseline(
model: &Qwen3,
cache: &mut PagedKVCache,
tokenizer: &Tokenizer,
prompt_ids: &[u32],
gen_tokens: usize,
) -> RunStats {
let slot = 0;
cache.register_sequence(slot).expect("register target slot");
let t0 = Instant::now();
let prefill_start = Instant::now();
let logits = model.forward_prefill_paged(prompt_ids, slot, cache);
sync_device();
let prefill_s = prefill_start.elapsed().as_secs_f64();
let mut generated = Vec::with_capacity(gen_tokens);
let mut next = last_argmax(&logits);
generated.push(next);
let decode_start = Instant::now();
let mut target_steps = 0usize;
while generated.len() < gen_tokens && !tokenizer.is_eos(next) {
let pos = cache.seq_len(slot);
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], cache);
target_steps += 1;
next = last_argmax(&logits);
generated.push(next);
}
sync_device();
let decode_s = decode_start.elapsed().as_secs_f64();
sync_device();
let total_s = t0.elapsed().as_secs_f64();
cache.free_sequence(slot);
RunStats {
ids: generated,
total_s,
prefill_s,
decode_s,
target_steps,
..Default::default()
}
}
#[allow(clippy::too_many_arguments)]
fn run_speculative(
target: &Qwen3,
draft: &Qwen3,
target_cache: &mut PagedKVCache,
target_verify_cache: &mut PagedKVCache,
draft_cache: &mut PagedKVCache,
tokenizer: &Tokenizer,
prompt_ids: &[u32],
gen_tokens: usize,
gamma: usize,
use_verify_logits: bool,
verify_path: VerifyPath,
dump_verify_mismatches: bool,
) -> RunStats {
let slot = 0;
target_cache
.register_sequence(slot)
.expect("register target slot");
target_verify_cache
.register_sequence(slot)
.expect("register target verify slot");
draft_cache
.register_sequence(slot)
.expect("register draft slot");
let t0 = Instant::now();
let prefill_start = Instant::now();
let target_logits = target.forward_prefill_paged(prompt_ids, slot, target_cache);
if !use_verify_logits {
let _ = target.forward_prefill_paged(prompt_ids, slot, target_verify_cache);
}
let draft_logits = draft.forward_prefill_paged(prompt_ids, slot, draft_cache);
sync_device();
let prefill_s = prefill_start.elapsed().as_secs_f64();
let mut target_next = last_argmax(&target_logits);
let mut draft_next = last_argmax(&draft_logits);
let mut generated = Vec::with_capacity(gen_tokens);
let mut accepted_total = 0usize;
let mut proposed_total = 0usize;
let mut verify_steps = 0usize;
let mut mirror_steps = 0usize;
let mut commit_steps = 0usize;
let mut correction_steps = 0usize;
let mut verify_decode_mismatches = 0usize;
let decode_start = Instant::now();
while generated.len() < gen_tokens {
let remaining = gen_tokens - generated.len();
let round_gamma = gamma.min(remaining);
let round_start_len = target_cache.seq_len(slot);
assert_eq!(
round_start_len,
draft_cache.seq_len(slot),
"target and draft cache lengths diverged"
);
if !use_verify_logits {
assert_eq!(
round_start_len,
target_verify_cache.seq_len(slot),
"target verify cache length diverged"
);
}
let mut draft_tokens = Vec::with_capacity(round_gamma);
for _ in 0..round_gamma {
let token = draft_next;
draft_tokens.push(token);
if tokenizer.is_eos(token) {
break;
}
let pos = draft_cache.seq_len(slot);
let logits = draft.forward_decode_paged(&[token], &[pos], &[slot], draft_cache);
draft_next = last_argmax(&logits);
}
proposed_total += draft_tokens.len();
if use_verify_logits {
verify_steps += 1;
let verify_logits =
target.forward_verify_paged_decode_attention(&draft_tokens, slot, target_cache);
let verify_argmax = argmax_rows(&verify_logits);
assert_eq!(
verify_argmax.len(),
draft_tokens.len(),
"verify logits rows must match draft token count"
);
let mut accepted = 0usize;
let mut done = false;
while accepted < draft_tokens.len() {
let expected = if accepted > 0 {
verify_argmax[accepted - 1]
} else {
target_next
};
if draft_tokens[accepted] != expected {
break;
}
let token = draft_tokens[accepted];
generated.push(token);
accepted_total += 1;
accepted += 1;
if generated.len() >= gen_tokens || tokenizer.is_eos(token) {
done = true;
break;
}
}
if accepted > 0 {
target_next = verify_argmax[accepted - 1];
}
target_cache
.truncate_sequence(slot, round_start_len + accepted)
.unwrap();
if done {
draft_cache
.truncate_sequence(slot, target_cache.seq_len(slot))
.unwrap();
break;
}
if accepted == draft_tokens.len() {
continue;
}
let correction = if accepted > 0 {
verify_argmax[accepted - 1]
} else {
target_next
};
generated.push(correction);
draft_cache
.truncate_sequence(slot, round_start_len)
.unwrap();
replay_draft_tokens(
draft,
draft_cache,
slot,
&draft_tokens[..accepted],
&mut draft_next,
);
if generated.len() >= gen_tokens || tokenizer.is_eos(correction) {
break;
}
let pos = target_cache.seq_len(slot);
let logits = target.forward_decode_paged(&[correction], &[pos], &[slot], target_cache);
target_next = last_argmax(&logits);
commit_steps += 1;
let pos = draft_cache.seq_len(slot);
let logits = draft.forward_decode_paged(&[correction], &[pos], &[slot], draft_cache);
draft_next = last_argmax(&logits);
correction_steps += 1;
continue;
}
verify_steps += 1;
let verify_logits = match verify_path {
VerifyPath::Flash => {
target.forward_prefill_paged(&draft_tokens, slot, target_verify_cache)
}
VerifyPath::PagedDecode => target.forward_verify_paged_decode_attention(
&draft_tokens,
slot,
target_verify_cache,
),
};
let verify_argmax = argmax_rows(&verify_logits);
assert_eq!(
verify_argmax.len(),
draft_tokens.len(),
"verify logits rows must match draft token count"
);
target_verify_cache
.truncate_sequence(slot, round_start_len)
.unwrap();
let mut accepted = 0usize;
let mut done = false;
while accepted < draft_tokens.len() {
let expected = if use_verify_logits && accepted > 0 {
verify_argmax[accepted - 1]
} else {
target_next
};
if draft_tokens[accepted] != expected {
break;
}
let token_idx = accepted;
let token = draft_tokens[token_idx];
generated.push(token);
accepted_total += 1;
accepted += 1;
if generated.len() >= gen_tokens || tokenizer.is_eos(token) {
done = true;
break;
}
let pos = target_cache.seq_len(slot);
let logits = target.forward_decode_paged(&[token], &[pos], &[slot], target_cache);
let decode_next = last_argmax(&logits);
if verify_argmax[token_idx] != decode_next {
verify_decode_mismatches += 1;
eprintln!(
"VERIFY/DECODE MISMATCH at cache_len={} accepted_idx={}: verify={} decode={}",
target_cache.seq_len(slot),
token_idx,
verify_argmax[token_idx],
decode_next
);
if dump_verify_mismatches {
eprintln!(
" verify_top5={} decode_top5={}",
format_topk(&verify_logits, token_idx, 5),
format_topk(&logits, 0, 5)
);
}
}
target_next = decode_next;
commit_steps += 1;
advance_target_cache(target, target_verify_cache, slot, token);
mirror_steps += 1;
}
if done {
draft_cache
.truncate_sequence(slot, target_cache.seq_len(slot))
.unwrap();
target_verify_cache
.truncate_sequence(slot, target_cache.seq_len(slot))
.unwrap();
break;
}
if accepted == draft_tokens.len() {
continue;
}
let correction = if use_verify_logits && accepted > 0 {
verify_argmax[accepted - 1]
} else {
target_next
};
generated.push(correction);
draft_cache
.truncate_sequence(slot, round_start_len)
.unwrap();
replay_draft_tokens(
draft,
draft_cache,
slot,
&draft_tokens[..accepted],
&mut draft_next,
);
if generated.len() >= gen_tokens || tokenizer.is_eos(correction) {
break;
}
let pos = target_cache.seq_len(slot);
let logits = target.forward_decode_paged(&[correction], &[pos], &[slot], target_cache);
target_next = last_argmax(&logits);
commit_steps += 1;
advance_target_cache(target, target_verify_cache, slot, correction);
mirror_steps += 1;
let pos = draft_cache.seq_len(slot);
let logits = draft.forward_decode_paged(&[correction], &[pos], &[slot], draft_cache);
draft_next = last_argmax(&logits);
correction_steps += 1;
}
sync_device();
let decode_s = decode_start.elapsed().as_secs_f64();
sync_device();
let total_s = t0.elapsed().as_secs_f64();
target_cache.free_sequence(slot);
target_verify_cache.free_sequence(slot);
draft_cache.free_sequence(slot);
RunStats {
ids: generated,
total_s,
prefill_s,
decode_s,
target_steps: verify_steps + mirror_steps + commit_steps + correction_steps,
accepted: accepted_total,
proposed: proposed_total,
verify_steps,
mirror_steps,
commit_steps,
correction_steps,
verify_decode_mismatches,
}
}
fn advance_target_cache(target: &Qwen3, cache: &mut PagedKVCache, slot: usize, token: u32) {
let pos = cache.seq_len(slot);
let _ = target.forward_decode_paged(&[token], &[pos], &[slot], cache);
}
fn replay_draft_tokens(
draft: &Qwen3,
cache: &mut PagedKVCache,
slot: usize,
tokens: &[u32],
next: &mut u32,
) {
for &token in tokens {
let pos = cache.seq_len(slot);
let logits = draft.forward_decode_paged(&[token], &[pos], &[slot], cache);
*next = last_argmax(&logits);
}
}
fn new_cache(config: &ModelConfig, max_seq_len: usize, device: u32) -> PagedKVCache {
new_cache_with_rows(config, max_seq_len, device, 1)
}
fn new_cache_with_rows(
config: &ModelConfig,
max_seq_len: usize,
device: u32,
max_rows: usize,
) -> PagedKVCache {
let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE);
let total_blocks = max_blocks_per_seq + 8;
PagedKVCache::new(
config,
total_blocks,
0,
max_rows.max(1),
max_blocks_per_seq,
DType::BF16,
device,
)
}
fn argmax_rows(logits: &Tensor) -> Vec<u32> {
assert_eq!(logits.ndim(), 2);
if logits.dtype() == DType::BF16
&& matches!(logits.device(), Device::Cuda(_))
&& logits.is_contiguous()
{
return xserv_kernels::argmax_bf16_to_host(logits);
}
let vocab_size = logits.shape()[1];
let rows = logits.shape()[0];
let logits_cpu = logits.to_device(Device::Cpu);
match logits.dtype() {
DType::F32 => logits_cpu
.as_slice::<f32>()
.chunks_exact(vocab_size)
.take(rows)
.map(argmax_f32)
.collect(),
DType::BF16 => logits_cpu
.as_slice::<bf16>()
.chunks_exact(vocab_size)
.take(rows)
.map(|row| {
row.iter()
.enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(i, _)| i as u32)
.unwrap()
})
.collect(),
_ => panic!("unsupported dtype for argmax: {:?}", logits.dtype()),
}
}
fn last_argmax(logits: &Tensor) -> u32 {
*argmax_rows(logits).last().unwrap()
}
fn argmax_f32(row: &[f32]) -> u32 {
row.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(i, _)| i as u32)
.unwrap()
}
fn format_topk(logits: &Tensor, row: usize, k: usize) -> String {
let vals = topk_row(logits, row, k);
vals.iter()
.map(|(id, val)| format!("{id}:{val:.3}"))
.collect::<Vec<_>>()
.join(",")
}
fn topk_row(logits: &Tensor, row: usize, k: usize) -> Vec<(u32, f32)> {
assert_eq!(logits.ndim(), 2);
let vocab_size = logits.shape()[1];
assert!(row < logits.shape()[0], "topk row out of bounds");
let logits_cpu = logits.to_device(Device::Cpu);
let mut vals: Vec<(u32, f32)> = match logits.dtype() {
DType::F32 => logits_cpu.as_slice::<f32>()[row * vocab_size..(row + 1) * vocab_size]
.iter()
.enumerate()
.map(|(i, &v)| (i as u32, v))
.collect(),
DType::BF16 => logits_cpu.as_slice::<bf16>()[row * vocab_size..(row + 1) * vocab_size]
.iter()
.enumerate()
.map(|(i, &v)| (i as u32, v.to_f32()))
.collect(),
_ => panic!("unsupported dtype for topk: {:?}", logits.dtype()),
};
vals.select_nth_unstable_by(k, |a, b| b.1.partial_cmp(&a.1).unwrap());
vals.truncate(k);
vals.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
vals
}
fn assert_qwen3(config: &ModelConfig, name: &str) {
let model_type = config.model_type.as_deref().unwrap_or("unknown");
assert!(
model_type.contains("qwen"),
"{name} model_type must be qwen-like, got {model_type}"
);
}
fn warn_if_tokenizers_differ(target_dir: &Path, draft_dir: &Path) {
let target = std::fs::read(target_dir.join("tokenizer.json"));
let draft = std::fs::read(draft_dir.join("tokenizer.json"));
if let (Ok(target), Ok(draft)) = (target, draft) {
if target != draft {
eprintln!(
"WARNING: target and draft tokenizer.json differ; v0 assumes identical token ids"
);
}
}
}
fn arg_usize(args: &[String], flag: &str, default: usize) -> usize {
args.iter()
.position(|a| a == flag)
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
fn parse_verify_path(args: &[String], use_verify_logits: bool) -> VerifyPath {
let default = if use_verify_logits {
VerifyPath::PagedDecode
} else {
VerifyPath::Flash
};
let Some(value) = args
.iter()
.position(|a| a == "--verify-path")
.and_then(|i| args.get(i + 1))
else {
return default;
};
match value.as_str() {
"flash" => VerifyPath::Flash,
"paged-decode" => VerifyPath::PagedDecode,
_ => {
eprintln!("unknown --verify-path {value:?}; expected flash or paged-decode");
std::process::exit(1);
}
}
}
fn validate_length_budget(prompt_ids: &[u32], gen_tokens: usize, max_seq_len: usize, prompt: &str) {
let required = prompt_ids.len() + gen_tokens;
if required > max_seq_len {
eprintln!(
"prompt requires prompt_len({}) + gen_tokens({}) = {} tokens, exceeding --max-seq-len {}: {:?}",
prompt_ids.len(),
gen_tokens,
required,
max_seq_len,
prompt
);
std::process::exit(2);
}
}
fn sync_device() {
xserv_cuda::device::synchronize().expect("cuda device synchronize");
}
fn ratio(num: usize, den: usize) -> f64 {
if den == 0 {
0.0
} else {
num as f64 / den as f64
}
}
fn speedup(baseline_s: f64, spec_s: f64) -> f64 {
if spec_s == 0.0 {
0.0
} else {
baseline_s / spec_s
}
}
fn tok_s(tokens: usize, seconds: f64) -> f64 {
if seconds == 0.0 {
0.0
} else {
tokens as f64 / seconds
}
}
fn per_token_ms(seconds: f64, tokens: usize) -> f64 {
if tokens == 0 {
0.0
} else {
seconds * 1000.0 / tokens as f64
}
}

View File

@@ -1,23 +1,51 @@
use std::io::{self, Write};
use std::path::PathBuf;
use xserv_model::{BLOCK_SIZE, KVCache, ModelConfig, PagedKVCache, loader};
use xserv_model::{
BLOCK_SIZE, KVCache, ModelConfig, PagedKVCache, SamplingParams, loader, sample,
sample_greedy_penalized,
};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
args.iter()
.position(|a| a == name)
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
fn pick_next(
logits: &xserv_tensor::Tensor,
sampling: &SamplingParams,
history: &[u32],
rep_penalty: f32,
) -> u32 {
if rep_penalty > 1.0 && sampling.temperature == 0.0 {
sample_greedy_penalized(logits, history, rep_penalty)
} else {
sample(logits, sampling)
}
}
fn main() {
let args: Vec<String> = std::env::args().collect();
if args.len() < 2 {
eprintln!("Usage: xserv-cli <model-dir> [--max-tokens N]");
eprintln!(
"Usage: xserv-cli <model-dir> [--max-tokens N] [--temperature F] [--top-k N] [--top-p F] [--rep-penalty F] [--rep-window N]"
);
std::process::exit(1);
}
let model_dir = PathBuf::from(&args[1]);
let max_tokens: usize = args
.iter()
.position(|a| a == "--max-tokens")
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(100);
let max_tokens = flag(&args, "--max-tokens", 100usize);
let sampling = SamplingParams {
temperature: flag(&args, "--temperature", 0.0f32),
top_k: flag(&args, "--top-k", 0usize),
top_p: flag(&args, "--top-p", 1.0f32),
};
let rep_penalty = flag(&args, "--rep-penalty", 1.0f32);
let rep_window = flag(&args, "--rep-window", 512usize);
xserv_cuda::device::set_device(0).unwrap();
let info = xserv_cuda::device::device_info(0).unwrap();
@@ -65,7 +93,10 @@ fn main() {
};
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
eprintln!("Ready (KV cache, dtype={dtype}).\n");
eprintln!(
"Ready (KV cache, dtype={dtype}, temperature={}, top_k={}, top_p={}, rep_penalty={}, rep_window={}).\n",
sampling.temperature, sampling.top_k, sampling.top_p, rep_penalty, rep_window
);
loop {
print!("xserv> ");
@@ -74,15 +105,16 @@ fn main() {
if io::stdin().read_line(&mut input).unwrap() == 0 {
break;
}
let input = input.trim();
if input.is_empty() {
let raw_input = input.trim();
if raw_input.is_empty() {
continue;
}
if input == "quit" || input == "exit" {
if raw_input == "quit" || raw_input == "exit" {
break;
}
let input = raw_input.replace("\\n", "\n");
let token_ids = tokenizer.encode(input);
let token_ids = tokenizer.encode(&input);
if is_gpt_oss {
// GptOss uses paged KV cache
@@ -106,7 +138,9 @@ fn main() {
_ => unreachable!(),
};
let logits = model.forward_prefill_paged(&token_ids, slot, &mut paged_cache);
let mut next = sample_greedy_last(&logits);
let mut history = token_ids.clone();
let start = history.len().saturating_sub(rep_window);
let mut next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
print!("{input}");
io::stdout().flush().unwrap();
@@ -115,6 +149,7 @@ fn main() {
let text = tokenizer.decode(&[next]);
print!("{text}");
io::stdout().flush().unwrap();
history.push(next);
if tokenizer.eos_token_id() == Some(next) {
break;
@@ -122,7 +157,8 @@ fn main() {
let pos = paged_cache.seq_len(slot);
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut paged_cache);
next = sample_greedy_last(&logits);
let start = history.len().saturating_sub(rep_window);
next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
}
println!();
paged_cache.free_sequence(slot);
@@ -145,11 +181,9 @@ fn main() {
Model::Qwen3(m) => m.forward_with_cache(&token_ids, &mut cache),
Model::GptOss(_) => unreachable!(),
};
let mut next = match &model {
Model::GPT2(_) => xserv_model::gpt2::sample_greedy(&logits),
Model::Qwen3(_) => xserv_model::qwen3::sample_greedy(&logits),
Model::GptOss(_) => unreachable!(),
};
let mut history = token_ids.clone();
let start = history.len().saturating_sub(rep_window);
let mut next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
print!("{input}");
io::stdout().flush().unwrap();
@@ -158,6 +192,7 @@ fn main() {
let text = tokenizer.decode(&[next]);
print!("{text}");
io::stdout().flush().unwrap();
history.push(next);
if tokenizer.eos_token_id() == Some(next) {
break;
@@ -168,28 +203,10 @@ fn main() {
Model::Qwen3(m) => m.forward_with_cache(&[next], &mut cache),
Model::GptOss(_) => unreachable!(),
};
next = match &model {
Model::GPT2(_) => xserv_model::gpt2::sample_greedy(&logits),
Model::Qwen3(_) => xserv_model::qwen3::sample_greedy(&logits),
Model::GptOss(_) => unreachable!(),
};
let start = history.len().saturating_sub(rep_window);
next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
}
println!();
}
}
}
fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 {
use half::bf16;
assert_eq!(logits.ndim(), 2);
let logits_cpu = logits.to_device(Device::Cpu);
let vocab_size = logits.shape()[1];
let seq_len = logits.shape()[0];
let data = logits_cpu.as_slice::<bf16>();
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
last.iter()
.enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(i, _)| i as u32)
.unwrap()
}

View File

@@ -9,7 +9,7 @@
use std::ffi::c_void;
use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer};
use xserv_kernels::dispatch;
use xserv_kernels::gemm::cublas_handle;
use xserv_kernels::gemm::{cublas_handle, gemv_scratch_elems};
use crate::config::ModelConfig;
use crate::kv_cache::GpuKVCache;
@@ -54,7 +54,7 @@ struct DecodeBuffers {
up: GpuBuffer, // [1, intermediate]
silu_out: GpuBuffer, // [1, intermediate]
// GEMV fp32 accumulators (separate per output dimension)
// GEMV fp32 scratch for deterministic K-block partials.
fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs
fp32_q: GpuBuffer, // for Q projection
fp32_kv: GpuBuffer, // for K/V projection
@@ -140,11 +140,14 @@ impl DecodeGraphState {
up: alloc(intermediate * es),
silu_out: alloc(intermediate * es),
fp32_hidden: alloc(hidden * 4),
fp32_q: alloc(num_heads * head_dim * 4),
fp32_kv: alloc(num_kv_heads * head_dim * 4),
fp32_intermediate: alloc(intermediate * 4),
fp32_vocab: alloc(vocab_size * 4),
fp32_hidden: alloc(
gemv_scratch_elems(hidden, hidden).max(gemv_scratch_elems(intermediate, hidden))
* 4,
),
fp32_q: alloc(gemv_scratch_elems(hidden, num_heads * head_dim) * 4),
fp32_kv: alloc(gemv_scratch_elems(hidden, num_kv_heads * head_dim) * 4),
fp32_intermediate: alloc(gemv_scratch_elems(hidden, intermediate) * 4),
fp32_vocab: alloc(gemv_scratch_elems(hidden, vocab_size) * 4),
token_id_gpu: alloc(4),
position_gpu: alloc(4),

View File

@@ -486,6 +486,35 @@ impl PagedKVCache {
state.seq_len += num_tokens;
}
/// Roll a registered sequence back to `new_len` tokens.
///
/// This only changes cache metadata and frees whole physical blocks that are
/// no longer reachable. Bytes inside retained blocks are left untouched; the
/// logical `seq_len` prevents attention from reading them, and later writes
/// to the same positions overwrite them.
pub fn truncate_sequence(&mut self, slot: usize, new_len: usize) -> Result<(), &'static str> {
if slot >= self.max_seqs {
return Err("truncate_sequence: slot out of range");
}
let state = self.seq_states[slot]
.as_mut()
.ok_or("truncate_sequence: empty slot")?;
if new_len > state.seq_len {
return Err("truncate_sequence: cannot extend");
}
let needed_blocks = ((new_len + BLOCK_SIZE - 1) / BLOCK_SIZE).max(1);
while state.block_ids.len() > needed_blocks {
let block = state.block_ids.pop().expect("checked len");
match state.location {
Location::Gpu => self.allocator.free(block),
Location::Cpu => self.cpu_allocator.free(block),
}
}
state.seq_len = new_len;
Ok(())
}
/// Refresh the host-side block table + context lens from `seq_states`,
/// then upload to GPU. Call once per decode step before the paged kernel.
pub fn sync_to_gpu(&mut self) {
@@ -748,6 +777,71 @@ impl PagedKVCache {
}
}
#[cfg(test)]
mod tests {
use super::*;
fn tiny_config() -> ModelConfig {
serde_json::from_value(serde_json::json!({
"model_type": "qwen3",
"hidden_size": 8,
"intermediate_size": 16,
"num_attention_heads": 1,
"num_key_value_heads": 1,
"num_hidden_layers": 1,
"vocab_size": 32,
"max_position_embeddings": 64
}))
.unwrap()
}
#[test]
fn truncate_sequence_frees_whole_blocks_and_keeps_slot_registered() {
if xserv_cuda::device::set_device(0).is_err() {
eprintln!("skipping CUDA-backed PagedKVCache test: device 0 unavailable");
return;
}
let config = tiny_config();
let mut cache = PagedKVCache::new(&config, 5, 0, 1, 4, DType::BF16, 0);
assert_eq!(
cache.truncate_sequence(1, 0),
Err("truncate_sequence: slot out of range")
);
assert_eq!(
cache.truncate_sequence(0, 0),
Err("truncate_sequence: empty slot")
);
cache.register_sequence(0).unwrap();
cache.ensure_capacity(0, BLOCK_SIZE * 3 + 1);
cache.advance_seq_len(0, BLOCK_SIZE * 3 + 1);
assert_eq!(cache.seq_len(0), BLOCK_SIZE * 3 + 1);
assert_eq!(cache.block_count(0), 4);
assert_eq!(cache.free_blocks(), 0);
cache.truncate_sequence(0, BLOCK_SIZE + 1).unwrap();
assert_eq!(cache.seq_len(0), BLOCK_SIZE + 1);
assert_eq!(cache.block_count(0), 2);
assert_eq!(cache.free_blocks(), 2);
cache.truncate_sequence(0, BLOCK_SIZE).unwrap();
assert_eq!(cache.seq_len(0), BLOCK_SIZE);
assert_eq!(cache.block_count(0), 1);
assert_eq!(cache.free_blocks(), 3);
cache.truncate_sequence(0, 0).unwrap();
assert_eq!(cache.seq_len(0), 0);
assert_eq!(cache.block_count(0), 1);
assert_eq!(cache.free_blocks(), 3);
assert_eq!(
cache.truncate_sequence(0, 1),
Err("truncate_sequence: cannot extend")
);
}
}
unsafe fn tensor_from_owned_buf(
buf: GpuBuffer,
shape: &[usize],

View File

@@ -884,6 +884,109 @@ impl Qwen3 {
matmul_2d(&x, &self.lm_head_t)
}
/// Paged multi-token verify path: write `token_ids` into the paged cache,
/// then verify them with the same paged decode attention kernel used by
/// single-token decode. This keeps greedy top-1 behavior aligned with
/// `forward_decode_paged` while still batching the dense projections/MLP
/// across the draft window.
pub fn forward_verify_paged_decode_attention(
&self,
token_ids: &[u32],
slot: usize,
paged_cache: &mut PagedKVCache,
) -> Tensor {
let new_tokens = token_ids.len();
let pos_offset = paged_cache.seq_len(slot);
let num_heads = self.local_num_heads;
let num_kv_heads = self.local_num_kv_heads;
let head_dim = self.config.head_dim();
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
paged_cache.advance_seq_len(slot, new_tokens);
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens)
.map(|p| p as u32)
.collect();
let kv_lens: Vec<i32> = (0..new_tokens)
.map(|i| (pos_offset + i + 1) as i32)
.collect();
let slots = vec![slot; new_tokens];
paged_cache.sync_active_batch_with_lens(&slots, &kv_lens);
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
let max_blocks = paged_cache.max_blocks_per_seq();
let mut x = embedding(&self.embed_tokens, token_ids);
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
let normed = rmsnorm(&x, &layer.input_norm, eps);
let qkv = matmul_rows_gemv(&normed, &layer.qkv_proj_wt);
let q_dim = num_heads * head_dim;
let kv_dim = num_kv_heads * head_dim;
let q_all = qkv.narrow(1, 0, q_dim);
let k_all = qkv.narrow(1, q_dim, kv_dim);
let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);
let q_flat = q_all
.contiguous()
.reshape(&[new_tokens * num_heads, head_dim]);
let k_flat = k_all
.contiguous()
.reshape(&[new_tokens * num_kv_heads, head_dim]);
let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
let q_3d = q_normed.reshape(&[new_tokens, num_heads, head_dim]);
let k_3d = k_normed.reshape(&[new_tokens, num_kv_heads, head_dim]);
rope_inplace(&q_3d, &self.rope_cache, &positions);
rope_inplace(&k_3d, &self.rope_cache, &positions);
let v_3d = v_all
.contiguous()
.reshape(&[new_tokens, num_kv_heads, head_dim]);
paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, new_tokens);
let q_decode = q_3d.reshape(&[new_tokens, num_heads, 1, head_dim]);
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
let attn_out = xserv_kernels::paged_decode_attention(
&q_decode,
k_pool_ptr,
v_pool_ptr,
bt_ptr,
cl_ptr,
new_tokens,
num_heads,
num_kv_heads,
head_dim,
max_blocks,
);
let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]);
let attn_proj = matmul_rows_gemv(&attn_merged, &layer.o_proj_wt);
self.all_reduce(&attn_proj);
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
let gate_up = matmul_rows_gemv(&normed, &layer.gate_up_proj_wt);
let ffn_dim = gate_up.shape()[1] / 2;
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
let down = matmul_rows_gemv(&hidden_states, &layer.down_proj_wt);
self.all_reduce(&down);
x = add_any(&residual, &down);
}
let x = rmsnorm(&x, &self.norm, eps);
matmul_rows_gemv(&x, &self.lm_head_t)
}
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
let new_tokens = token_ids.len();
@@ -1158,6 +1261,20 @@ fn row_view(t: &Tensor, row: usize) -> Tensor {
)
}
/// Run a 2D matmul row by row so each row uses the same GEMV kernel as
/// single-token decode. Used by speculative verify parity, where near-tie
/// logits must follow decode's BF16 rounding path.
fn matmul_rows_gemv(a: &Tensor, b: &Tensor) -> Tensor {
assert_eq!(a.ndim(), 2);
assert!(a.is_contiguous());
let rows = a.shape()[0];
if rows == 1 {
return matmul_2d(a, b);
}
let out_rows: Vec<Tensor> = (0..rows).map(|i| matmul_2d(&row_view(a, i), b)).collect();
concat_rows(&out_rows)
}
/// Concatenate row tensors [1, cols] into a single [B, cols] tensor via D2D memcpy.
fn concat_rows(rows: &[Tensor]) -> Tensor {
assert!(!rows.is_empty());

View File

@@ -118,7 +118,7 @@ __global__ void paged_decode_attention_bf16_kernel(
// ---- Block-level online softmax reduction ----
__shared__ float smem_max[32];
__shared__ float smem_sum[32];
__shared__ float smem_O[PAGED_HEAD_DIM_MAX];
__shared__ float smem_O_warp[32][PAGED_HEAD_DIM_MAX];
int lane = tid & 31;
int warp_id = tid >> 5;
@@ -164,8 +164,12 @@ __global__ void paged_decode_attention_bf16_kernel(
__syncthreads();
global_sum = smem_sum[0];
// Step 4: reduce O across block, dim by dim
for (int d = tid; d < head_dim; d += PAGED_THREADS) smem_O[d] = 0.0f;
// Step 4: reduce O across block, dim by dim. Store one partial per warp
// and sum in warp-id order; atomicAdd made greedy decode nondeterministic
// when logits were close.
for (int i = tid; i < 32 * PAGED_HEAD_DIM_MAX; i += PAGED_THREADS) {
reinterpret_cast<float*>(smem_O_warp)[i] = 0.0f;
}
__syncthreads();
for (int d = 0; d < head_dim; d++) {
@@ -173,13 +177,15 @@ __global__ void paged_decode_attention_bf16_kernel(
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
val += __shfl_down_sync(0xffffffff, val, offset);
if (lane == 0) atomicAdd(&smem_O[d], val);
if (lane == 0) smem_O_warp[warp_id][d] = val;
}
__syncthreads();
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
for (int d = tid; d < head_dim; d += PAGED_THREADS) {
O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
float out = 0.0f;
for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d];
O_ptr[d] = __float2bfloat16(out * inv_sum);
}
}
@@ -289,7 +295,7 @@ __global__ void paged_decode_attention_sinks_bf16_kernel(
// ---- Block-level online softmax reduction (same as base kernel) ----
__shared__ float smem_max[32];
__shared__ float smem_sum[32];
__shared__ float smem_O[PAGED_HEAD_DIM_MAX];
__shared__ float smem_O_warp[32][PAGED_HEAD_DIM_MAX];
int lane = tid & 31;
int warp_id = tid >> 5;
@@ -332,7 +338,9 @@ __global__ void paged_decode_attention_sinks_bf16_kernel(
__syncthreads();
global_sum = smem_sum[0];
for (int d = tid; d < head_dim; d += PAGED_THREADS) smem_O[d] = 0.0f;
for (int i = tid; i < 32 * PAGED_HEAD_DIM_MAX; i += PAGED_THREADS) {
reinterpret_cast<float*>(smem_O_warp)[i] = 0.0f;
}
__syncthreads();
for (int d = 0; d < head_dim; d++) {
@@ -340,13 +348,15 @@ __global__ void paged_decode_attention_sinks_bf16_kernel(
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
val += __shfl_down_sync(0xffffffff, val, offset);
if (lane == 0) atomicAdd(&smem_O[d], val);
if (lane == 0) smem_O_warp[warp_id][d] = val;
}
__syncthreads();
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
for (int d = tid; d < head_dim; d += PAGED_THREADS) {
O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
float out = 0.0f;
for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d];
O_ptr[d] = __float2bfloat16(out * inv_sum);
}
}

View File

@@ -6,22 +6,20 @@
//
// y[n] = sum_k x[k] * W[k * N + n]
//
// Grid: (N / TILE_N, K / TILE_K).
// All blocks atomicAdd their partial sums into a pre-zeroed FP32 buffer.
// A separate conversion kernel writes the final BF16 output.
// Launch sequence: cudaMemsetAsync(fp32) → accumulation kernel → convert kernel.
// Grid: (N / TILE_N, K / TILE_K) partials, followed by a deterministic
// fixed-order reduction over K blocks. The previous implementation used
// atomicAdd into y_fp32[col]; that made BF16 greedy decode sensitive to
// inter-block scheduling when logits were close.
#define GEMV_TILE_N 128
#define GEMV_TILE_K 256
#define GEMV_BLOCK 128
__global__ void gemv_bf16_fused_kernel(
__global__ void gemv_bf16_partial_kernel(
const __nv_bfloat16* __restrict__ x,
const __nv_bfloat16* __restrict__ W,
__nv_bfloat16* __restrict__ y_bf16,
float* __restrict__ y_fp32,
int K, int N,
int num_k_blocks
float* __restrict__ partials,
int K, int N
) {
const int block_n = blockIdx.x;
const int block_k = blockIdx.y;
@@ -52,18 +50,22 @@ __global__ void gemv_bf16_fused_kernel(
sum += x_shared[ki] * __bfloat162float(W[(long long)(k_start + ki) * N + col]);
}
atomicAdd(&y_fp32[col], sum);
partials[(long long)block_k * N + col] = sum;
}
// Conversion kernel: FP32 accumulator -> BF16 output
__global__ void gemv_fp32_to_bf16_kernel(
const float* __restrict__ src,
__global__ void gemv_reduce_to_bf16_kernel(
const float* __restrict__ partials,
__nv_bfloat16* __restrict__ dst,
int n
int n,
int num_k_blocks
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
dst[idx] = __float2bfloat16(src[idx]);
float sum = 0.0f;
for (int kb = 0; kb < num_k_blocks; kb++) {
sum += partials[(long long)kb * n + idx];
}
dst[idx] = __float2bfloat16(sum);
}
}
@@ -79,30 +81,25 @@ void launch_gemv_bf16(
) {
cudaStream_t s = (cudaStream_t)stream;
// Zero the FP32 accumulator BEFORE the kernel — the kernel uses atomicAdd
// across K-blocks with no inter-block ordering, so the buffer must be
// pre-zeroed to avoid accumulating on stale data.
cudaMemsetAsync(y_fp32_buf, 0, (size_t)N * sizeof(float), s);
int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K;
dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N, num_k_blocks);
gemv_bf16_fused_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
gemv_bf16_partial_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
(const __nv_bfloat16*)x,
(const __nv_bfloat16*)W,
(__nv_bfloat16*)y_bf16,
(float*)y_fp32_buf,
K, N, num_k_blocks
K, N
);
CUDA_CHECK_LAST_ERROR();
// FP32 → BF16 conversion (must wait for all K-blocks to finish)
// Fixed-order FP32 reduction over K blocks, then BF16 conversion.
int conv_block = 256;
int conv_grid = (N + conv_block - 1) / conv_block;
gemv_fp32_to_bf16_kernel<<<conv_grid, conv_block, 0, s>>>(
gemv_reduce_to_bf16_kernel<<<conv_grid, conv_block, 0, s>>>(
(const float*)y_fp32_buf,
(__nv_bfloat16*)y_bf16,
N
N,
num_k_blocks
);
CUDA_CHECK_LAST_ERROR();
}

View File

@@ -90,7 +90,7 @@ __global__ void softmax_bf16(
extern "C" {
void launch_softmax_f32(const void* x, void* out, int rows, int cols, void* stream) {
int block = (cols < 1024) ? cols : 1024;
int block = (cols < 512) ? cols : 512;
if (block < 32) block = 32;
softmax_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
(const float*)x, (float*)out, cols);
@@ -98,7 +98,7 @@ void launch_softmax_f32(const void* x, void* out, int rows, int cols, void* stre
}
void launch_softmax_bf16(const void* x, void* out, int rows, int cols, void* stream) {
int block = (cols < 1024) ? cols : 1024;
int block = (cols < 512) ? cols : 512;
if (block < 32) block = 32;
softmax_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, cols);

View File

@@ -0,0 +1,186 @@
# Phase 22: Draft-Model Speculative Decoding v0
> 目标:实现一个可验证的 speculative decoding 最小闭环。先只覆盖
> Qwen3 target + 同 tokenizer 的小 Qwen3 draft、batch=1、greedy
> (`temperature=0`)。本阶段不做 gpt-oss,不做 sampling rejection,不接入
> continuous batching。
## 1. Scope
本阶段只解决一个窄问题:
- target:现有 Qwen3 paged KV 路径,优先 Qwen3-8B;
- draft:同 tokenizer 的小 Qwen3,例如 Qwen3-0.6B;
- batch size:1;
- decoding:greedy argmax;
- draft window:`gamma=4`;
- acceptance:exact-match,即 `target_argmax == draft_token`
HTTP flag 可以后续接入。v0 先提供独立 bench/CLI,因为它能直接输出 token
一致性、acceptance rate、tokens/target-step、TPOT/tok/s,也避免把尚未稳定的
rollback 行为放进服务端调度循环。
bench 为了让 baseline/spec 对比不受跨 prompt KV pool 复用影响,每个 prompt 的
baseline run 和 speculative run 都使用新建的 paged KV cache。cache 分配发生在
单次 run 的计时外,输出的 TPOT/tok/s 只覆盖模型 prefill/decode 工作。
## 2. Why Qwen3 First
Qwen3 是现有代码里最适合作为 speculative v0 的模型族:
1. target 已有稳定的 `forward_prefill_paged``forward_decode_paged`;
2. 小 Qwen3 与 Qwen3-8B 共享 tokenizer,可以直接比较 token id;
3. Qwen3 是 dense decoder-only,没有 gpt-oss 的 harmony 格式、MoE sparse 路径、
sliding-window 或 CUDA Graph 状态;
4. greedy 输出的正确性定义简单:只要 spec 生成的 token 序列与纯 target greedy
完全一致即可。
gpt-oss spec 需要先定义 harmony prompt、MoE draft 选择、graph replay 与 rollback
的交互,这些都不属于本阶段。
## 3. Algorithm
对每个 prompt 建两套模型、三套 KV 状态:
```text
target model + target commit PagedKVCache
target model + target verify PagedKVCache
draft model + draft PagedKVCache
```
先把 prompt 分别 prefill 到三套 cache。此时 cache 都包含 prompt,并各自持有
"下一个 token" 的 logits。
每个 speculative round:
1. draft 从当前 draft logits 取 argmax,连续生成 `gamma` 个 draft token;
2. draft 每生成一个 token 就用自己的 paged decode append 到 draft KV,所以 round
结束时 draft cache 暂时包含整个草稿序列;
3. target verify cache 对完整 draft token 序列调用一次 paged prefill,覆盖
"target 可一次验证草稿窗口" 这条执行路径;
4. target verify cache 立刻 rollback 到 round 起点,避免把 prefill 临时写入污染
commit cache;
5. 用 target decode 轨迹作为权威结果,从左到右比较
`target_next_argmax == draft_token`,只接受连续匹配前缀;
6. 对每个接受 token,用 target decode 重放一次来提交 target KV,并得到下一步
`target_next_argmax`;verify cache 也 mirror decode 同一个 token,保持长度与 prefix 对齐;
7. 若全部匹配,draft cache 已经包含完整草稿,三套 cache 长度重新对齐;
8. 若在第 `k` 个 token 拒绝,提交前 `k` 个 draft token,再提交 target 在该位置的
argmax 作为修正 token。draft cache rollback 到 round 起点后重放接受 token 和修正
token,target commit/verify cache 都由 decode 路径提交到同一 prefix。
v0 不使用完整 speculative sampling 的概率校正。它只利用小模型猜测 greedy 轨迹,
因此生成序列必须与纯 target greedy 完全一致。
当前实现选择 decode 轨迹作为提交路径,而不是直接保留 target prefill 写入的 KV。
原因是 v0 验收要求 token 序列与纯 target greedy 完全一致;如果 prefill 和 decode
路径在数值或 KV 写入顺序上存在细微差异,直接提交 prefill KV 会让后续 greedy 输出
漂移。这个保守实现仍会执行 target paged prefill 验证和 rollback,但 verify 写入放在
独立 cache,不会影响权威 commit cache。代价是额外 mirror decode,速度收益预期较差,
主要用于先验证 draft-model speculative 的状态机和一致性。
为保证 greedy exactness,decode 里两个原有非确定点也需要固定:
- BF16 GEMV 不再用跨 K-block `atomicAdd`;改为写 K-block partials,再按固定顺序
reduce;
- paged decode attention 不再用 `atomicAdd` 合并 warp 输出;改为 per-warp partials
后按 warp id 顺序 reduce。
## 4. KV Commit And Rollback
现有 `forward_prefill_paged` 会一次性把传入 token 写进 paged KV,并提前推进
`seq_len`。验证草稿时 target verify cache 因此会临时包含整个 draft window。
新增的 cache 操作只做逻辑截断:
```text
truncate_sequence(slot, new_len)
```
约束:
- 只允许 `new_len <= current_len`;
- 保留覆盖 `[0, new_len)` 所需的物理 block;
- 释放右侧多余 block;
- 不清零仍在保留 block 内的旧字节,因为后续逻辑长度会阻止 attention 读取它们,
同一位置再次写入时会覆盖旧值;
- slot 仍保持 registered,`new_len=0` 时也保留第一个 block。
这让 target 和 draft 都能在拒绝时安全丢弃多写 KV,并在修正 token decode 后重新
对齐。
## 5. Acceptance Criteria
本阶段验收:
- `cargo fmt`;
- `cargo check`;
- `cargo test`;
- `bench-speculative` 可加载 target+draft 两套 Qwen3;
- 50 prompts,greedy,baseline target 与 speculative token id 序列完全一致;
- 输出 acceptance rate、tokens/target-step、TPOT、tok/s 和 speedup;
- 若 draft 模型缺失或磁盘不足,明确报告阻塞条件,不盲目下载大模型。
## 6. Validation Results
dash5 环境:
- GPU:RTX 5090,device 0;
- target:`/opt/wjh/models/qwen3-8b`;
- draft:`/dashscope-tmp/wjh/models/qwen3-0.6b`;
- command:`bench-speculative ... --prompts 50 --gen-tokens 32 --gamma 4 --device 0`;
- log:`/dashscope-tmp/wjh/xserv-spec-default-50x32-final.log`
默认 `acceptance_mode=decode` 的结果:
```text
prompts=50 matched=true
acceptance_rate=0.3664 accepted=1020 proposed=2784
tokens_per_target_step=0.3639 target_steps=4397
verify_steps=729 mirror_decode_steps=1550 commit_decode_steps=1550 correction_steps=568
verify_decode_mismatches=10
baseline_e2e_tpot_ms=13.123 baseline_e2e_tok_s=76.204
spec_e2e_tpot_ms=44.867 spec_e2e_tok_s=22.288 speedup_e2e=0.2925
baseline_decode_tpot_ms=12.638 baseline_decode_tok_s=79.127
spec_decode_tpot_ms=43.731 spec_decode_tok_s=22.867 speedup_decode=0.2890
decode_token_counts baseline=1600 spec=1600
```
诊断 `--use-verify-logits` 的结果:
- command:`bench-speculative ... --prompts 10 --gen-tokens 32 --gamma 4 --device 0 --use-verify-logits`;
- log:`/dashscope-tmp/wjh/xserv-spec-verify-logits-10x32.log`;
- exit status:`2`;
- summary:`matched=false`, `verify_decode_mismatches=4`;
- prompt 0/2/7 出现 baseline/spec token 序列分叉。
结论:当前可以做 correctness-first 的 speculative decoding 状态机,但还不能把
target batched prefill verify logits 作为 greedy 接受依据。verify prefill 路径与
逐 token decode 路径存在 top-1 不一致;默认模式必须继续以 decode 轨迹为权威,
因此 v0 是正确性闭环,不是性能优化。
## 7. Known Limits
- 只支持 batch=1;
- 只支持 Qwen3-family dense models;
- 只支持 greedy exact-match acceptance;
- 未实现 probabilistic rejection sampling,所以 temperature/top-k/top-p 不支持;
- 未接 HTTP/continuous batching;
- 未与 CUDA Graph decode 结合;
- 当前 v0 为保证 greedy exactness,接受 token 也会用 target decode 重放提交,因此
即使 acceptance 高也可能变慢;
- draft prefill 和 target prefill 都会计入端到端耗时,短输出可能没有收益。
## 8. Next Phase TODO
如果继续 speculative decoding,下一阶段不要先接 HTTP,应先解决 verify 路径:
1. 做最小 prefill-vs-decode parity harness:固定 prompt、cache len、draft token,
dump 每层/最终 logits 的 top-k,定位 top-1 分叉来自 attention、GEMV 还是 KV 写入顺序;
2.`--use-verify-logits` 在至少 50 prompts x 64 tokens 下 `matched=true`
`verify_decode_mismatches=0`;
3. parity 过后再做真正 multi-token target commit:要么安全保留 verify prefill 写入的
KV,要么实现专用 paged multi-token verify/commit kernel,避免当前的 mirror+commit
decode 重放;
4. 只有 `speedup_e2e > 1` 后再考虑 HTTP flag、continuous batching、sampling 或
gpt-oss speculative decoding。

View File

@@ -0,0 +1,85 @@
# Phase 23: Speculative Verify Parity
> 目标:把 speculative decoding 从 v0 的 correctness-only 状态机推进到
> "verify logits 可作为权威接受依据"。本阶段仍只覆盖 Qwen3 target +
> Qwen3 small draft、batch=1、greedy。
## 1. Problem
Phase 22 的默认模式用逐 token target decode 作为权威路径,因此输出能与 baseline
一致。但诊断 `--use-verify-logits` 会失败:target 对 draft window 做 batched
prefill verify 时,部分 logits top-1 与逐 token decode 不一致。
实测 top-k 显示分叉不是大幅数值错误,而是 BF16 near-tie:
```text
verify_top5=17689:24.500,9856:24.375,...
decode_top5=9856:24.500,17689:24.500,...
```
如果直接用这些 verify logits 接受/拒绝 draft token,greedy token 序列会偏离纯
target decode。
## 2. Design
新增 `Qwen3::forward_verify_paged_decode_attention`:
1. 在 target commit cache 上一次写入 draft window 的 K/V;
2. attention 使用现有 paged decode attention,每个 draft token 对应一行 metadata,
context lens 分别为 `pos + 1`;
3. 线性层使用逐行 GEMV,与 `forward_decode_paged` 的 BF16 rounding path 对齐;
4. 若 token 全接受,直接保留 verify 写入的 KV;
5. 若在第 `k` 个 token 拒绝,把 target cache truncate 到 accepted prefix,再只
decode 一个 correction token。
bench 新增:
- `--use-verify-logits`:用 verify logits 作为接受依据,默认选择 `paged-decode`
verify path;
- `--verify-path flash|paged-decode`:显式选择旧 flash prefill 诊断或新 paged-decode
verify path;
- `--dump-verify-mismatches`:打印 mismatch 行 top-k,用于定位 near-tie。
## 3. Validation
dash5:
- GPU:RTX 5090,device 0;
- target:`/opt/wjh/models/qwen3-8b`;
- draft:`/dashscope-tmp/wjh/models/qwen3-0.6b`;
- command:`bench-speculative ... --prompts 50 --gen-tokens 64 --gamma 4 --device 0 --use-verify-logits`;
- log:`/dashscope-tmp/wjh/xserv-spec-inplace-verify-50x64.log`
结果:
```text
prompts=50 matched=true
acceptance_mode=verify_logits
verify_path=paged-decode
acceptance_rate=0.3927 accepted=2120 proposed=5398
tokens_per_target_step=0.9112 target_steps=3512
verify_steps=1376 mirror_decode_steps=0 commit_decode_steps=1068 correction_steps=1068
verify_decode_mismatches=0
baseline_e2e_tpot_ms=13.094 baseline_e2e_tok_s=76.372
spec_e2e_tpot_ms=30.069 spec_e2e_tok_s=33.257 speedup_e2e=0.4355
baseline_decode_tpot_ms=12.846 baseline_decode_tok_s=77.844
spec_decode_tpot_ms=29.731 spec_decode_tok_s=33.635 speedup_decode=0.4321
decode_token_counts baseline=3200 spec=3200
```
对比 Phase 22 的保守 decode-authoritative v0:
- verify logits 现在可以作为权威接受依据;
- `mirror_decode_steps` 从每个 accepted token 一次降为 0;
- 50x64 e2e speedup 从约 0.29x 提升到 0.44x;
- 仍未超过 baseline,因为 verify path 为了 parity 使用逐行 GEMV,且 draft acceptance
只有约 39%。
## 4. Next TODO
下一阶段要从 correctness parity 转向性能:
1. 逐层替换 row-GEMV 为 batched GEMM,同时保留 near-tie fallback 或 top-k audit;
2. 加一个 `--verify-audit-decode` 低频抽样审计,避免每轮都做 target decode;
3.`gamma` 与 draft 选择,记录 acceptance 与 TPOT 曲线;
4. `speedup_e2e > 1` 前不接 HTTP/continuous batching/gpt-oss spec。