model: paged KV cache with CPU swap pool, decode graph, qwen3 updates
- paged_kv_cache: new block-paged KV cache; adds a pinned-host swap pool with
a second BlockAllocator, per-sequence Location {Gpu,Cpu}, and lossless
swap_out/swap_in (block-granular D2H/H2D) for vLLM-style preemption.
bytes_per_block helper exposes per-block cost for VRAM-based sizing.
- decode_graph: CUDA-graph decode path.
- qwen3/gpt2/kv_cache: paged prefill/decode forward + related updates.
- tokenizer/bins: BPE updates, new xserv-chat CLI, bench-qwen3 tweaks.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
use xserv_model::qwen3::sample_greedy;
|
||||
use xserv_model::{loader, GpuKVCache, ModelConfig, Qwen3};
|
||||
use xserv_model::{loader, DecodeGraphState, GpuKVCache, ModelConfig, Qwen3};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: bench-qwen3 <model-dir> [--gen-tokens N]");
|
||||
eprintln!("Usage: bench-qwen3 <model-dir> [--gen-tokens N] [--cuda-graph]");
|
||||
std::process::exit(1);
|
||||
}
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
@@ -18,6 +18,7 @@ fn main() {
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(20);
|
||||
let use_cuda_graph = args.iter().any(|a| a == "--cuda-graph");
|
||||
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
|
||||
@@ -34,6 +35,18 @@ fn main() {
|
||||
let mut cache = GpuKVCache::new(&config, 256, DType::BF16, 0);
|
||||
let _ = model.forward_gpu_cache(&ids, &mut cache);
|
||||
}
|
||||
|
||||
// CUDA Graph setup
|
||||
let layer_ptrs = model.layer_weight_ptrs();
|
||||
let (norm_w, lm_head, embed, cos, sin) = model.graph_capture_ptrs();
|
||||
let mut decode_graph = if use_cuda_graph {
|
||||
eprintln!("CUDA Graph mode enabled");
|
||||
Some(DecodeGraphState::new(&config))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut graph_captured = false;
|
||||
|
||||
eprintln!("Warmup done. Running benchmark...");
|
||||
|
||||
let prompts: Vec<&str> = vec![
|
||||
@@ -96,6 +109,12 @@ fn main() {
|
||||
|
||||
let mut cache = GpuKVCache::new(&config, 256, DType::BF16, 0);
|
||||
|
||||
// Reset graph state for new prompt
|
||||
graph_captured = false;
|
||||
if let Some(ref mut g) = decode_graph {
|
||||
g.invalidate();
|
||||
}
|
||||
|
||||
// Prefill
|
||||
let t0 = Instant::now();
|
||||
let logits = model.forward_gpu_cache(&input_ids, &mut cache);
|
||||
@@ -109,8 +128,35 @@ fn main() {
|
||||
for _ in 1..gen_tokens {
|
||||
let last = *generated.last().unwrap();
|
||||
let t_start = Instant::now();
|
||||
let logits = model.forward_gpu_cache(&[last], &mut cache);
|
||||
let next = sample_greedy(&logits);
|
||||
|
||||
let next = if let Some(ref mut graph) = decode_graph {
|
||||
if !graph_captured {
|
||||
// First decode token: run ungraphed, then capture
|
||||
let logits = model.forward_gpu_cache(&[last], &mut cache);
|
||||
graph_captured = true;
|
||||
graph.capture(&layer_ptrs, norm_w, lm_head, embed, cos, sin);
|
||||
sample_greedy(&logits)
|
||||
} else {
|
||||
// Replay captured graphs
|
||||
let pos = cache.seq_len() as u32;
|
||||
graph.execute(last, pos, &mut cache, &layer_ptrs, embed, config.vocab_size as i32, config.hidden() as i32);
|
||||
cache.advance_seq_len(1);
|
||||
// Read logits from graph buffer
|
||||
let vocab_size = config.vocab_size;
|
||||
let mut logits_bytes = vec![0u8; vocab_size * 2];
|
||||
graph.logits_buffer().copy_to_host(&mut logits_bytes).unwrap();
|
||||
let logits_data: &[half::bf16] = unsafe {
|
||||
std::slice::from_raw_parts(logits_bytes.as_ptr() as *const half::bf16, vocab_size)
|
||||
};
|
||||
logits_data.iter().enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(idx, _)| idx as u32).unwrap()
|
||||
}
|
||||
} else {
|
||||
let logits = model.forward_gpu_cache(&[last], &mut cache);
|
||||
sample_greedy(&logits)
|
||||
};
|
||||
|
||||
token_times.push(t_start.elapsed().as_micros());
|
||||
generated.push(next);
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
|
||||
419
crates/xserv-model/src/bin/xserv-chat.rs
Normal file
419
crates/xserv-model/src/bin/xserv-chat.rs
Normal file
@@ -0,0 +1,419 @@
|
||||
use std::io::{self, IsTerminal, Write};
|
||||
use std::path::PathBuf;
|
||||
|
||||
use xserv_model::{loader, sample, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
const SLOT: usize = 0;
|
||||
|
||||
struct CliOptions {
|
||||
model_dir: PathBuf,
|
||||
max_tokens: usize,
|
||||
max_seq_len: usize,
|
||||
sampling: SamplingParams,
|
||||
system_prompt: Option<String>,
|
||||
enable_thinking: bool,
|
||||
color: bool,
|
||||
}
|
||||
|
||||
/// Why a generation turn ended.
enum Finish {
    /// A stop token was sampled; `token_id` records which one.
    Stop { token_id: u32 },
    /// The per-turn token budget ran out before a stop token appeared.
    Length,
}
|
||||
|
||||
fn main() {
|
||||
let opts = parse_args();
|
||||
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
let info = xserv_cuda::device::device_info(0).unwrap();
|
||||
eprintln!(
|
||||
"GPU: {} ({} MB free)",
|
||||
info.name,
|
||||
info.free_memory / 1024 / 1024
|
||||
);
|
||||
|
||||
let config = ModelConfig::from_file(&opts.model_dir.join("config.json"));
|
||||
let model_type = config.model_type.as_deref().unwrap_or("unknown");
|
||||
if !model_type.contains("qwen") {
|
||||
eprintln!("xserv-chat currently supports Qwen-style ChatML models only; got model_type={model_type}");
|
||||
std::process::exit(2);
|
||||
}
|
||||
|
||||
let max_seq_len = opts.max_seq_len.min(config.max_seq_len()).max(1);
|
||||
eprintln!(
|
||||
"Model: {model_type}, layers={}, hidden={}, heads={}/{} kv, vocab={}, max_seq_len={}",
|
||||
config.num_layers(),
|
||||
config.hidden(),
|
||||
config.num_heads(),
|
||||
config.num_kv_heads(),
|
||||
config.vocab_size,
|
||||
max_seq_len
|
||||
);
|
||||
|
||||
eprintln!("Loading weights...");
|
||||
let weights = loader::load_model_dir(&opts.model_dir, Device::Cuda(0));
|
||||
eprintln!("Loaded {} tensors", weights.len());
|
||||
let model = Qwen3::from_weights(config.clone(), weights);
|
||||
|
||||
let tokenizer = Tokenizer::from_file(&opts.model_dir.join("tokenizer.json"));
|
||||
let mut cache = new_paged_cache(&config, max_seq_len);
|
||||
cache.register_sequence(SLOT).expect("register chat slot");
|
||||
let use_color = opts.color && io::stdout().is_terminal();
|
||||
|
||||
eprintln!("Ready (paged KV cache, persistent chat slot).");
|
||||
eprintln!("Commands: /exit, /quit, /clear\n");
|
||||
|
||||
loop {
|
||||
print!("user> ");
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
let mut input = String::new();
|
||||
if io::stdin().read_line(&mut input).unwrap() == 0 {
|
||||
break;
|
||||
}
|
||||
let input = input.trim();
|
||||
if input.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match input {
|
||||
"/exit" | "/quit" | "exit" | "quit" => break,
|
||||
"/clear" => {
|
||||
cache.free_sequence(SLOT);
|
||||
cache.register_sequence(SLOT).expect("register chat slot");
|
||||
eprintln!("history and KV cache cleared");
|
||||
continue;
|
||||
}
|
||||
"/help" => {
|
||||
print_help();
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let include_system = cache.seq_len(SLOT) == 0;
|
||||
let prompt = build_turn_prompt(
|
||||
opts.system_prompt.as_deref(),
|
||||
include_system,
|
||||
input,
|
||||
opts.enable_thinking,
|
||||
);
|
||||
let prompt_tokens = tokenizer.encode(&prompt);
|
||||
if prompt_tokens.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let used = cache.seq_len(SLOT);
|
||||
let remaining = max_seq_len.saturating_sub(used);
|
||||
if prompt_tokens.len() >= remaining {
|
||||
eprintln!(
|
||||
"context full: {used}/{max_seq_len} tokens used, new turn needs {} tokens; use /clear",
|
||||
prompt_tokens.len()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
let max_new_tokens = opts.max_tokens.min(remaining - prompt_tokens.len());
|
||||
|
||||
print!("assistant> ");
|
||||
io::stdout().flush().unwrap();
|
||||
let finish = generate_with_paged_cache(
|
||||
&model,
|
||||
&mut cache,
|
||||
&tokenizer,
|
||||
&prompt_tokens,
|
||||
&opts.sampling,
|
||||
max_new_tokens,
|
||||
use_color,
|
||||
);
|
||||
match finish {
|
||||
Finish::Stop { token_id } => {
|
||||
append_after_stop(&model, &mut cache, &tokenizer, max_seq_len, token_id);
|
||||
}
|
||||
Finish::Length => {
|
||||
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, "<|im_end|>\n");
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_args() -> CliOptions {
|
||||
let args: Vec<String> = std::env::args().skip(1).collect();
|
||||
if args.is_empty() || args.iter().any(|a| a == "--help" || a == "-h") {
|
||||
print_usage_and_exit(0);
|
||||
}
|
||||
|
||||
let mut model_dir = None;
|
||||
let mut max_tokens = 256usize;
|
||||
let mut max_seq_len = 2048usize;
|
||||
let mut temperature = 0.0f32;
|
||||
let mut top_k = 0usize;
|
||||
let mut top_p = 1.0f32;
|
||||
let mut system_prompt = None;
|
||||
let mut enable_thinking = false;
|
||||
let mut color = true;
|
||||
|
||||
let mut i = 0;
|
||||
while i < args.len() {
|
||||
match args[i].as_str() {
|
||||
"-m" | "--model" => {
|
||||
i += 1;
|
||||
model_dir = args.get(i).map(PathBuf::from);
|
||||
}
|
||||
"--max-tokens" => {
|
||||
i += 1;
|
||||
max_tokens = parse_value(&args, i, "--max-tokens");
|
||||
}
|
||||
"--max-seq-len" => {
|
||||
i += 1;
|
||||
max_seq_len = parse_value(&args, i, "--max-seq-len");
|
||||
}
|
||||
"--temperature" => {
|
||||
i += 1;
|
||||
temperature = parse_value(&args, i, "--temperature");
|
||||
}
|
||||
"--top-k" => {
|
||||
i += 1;
|
||||
top_k = parse_value(&args, i, "--top-k");
|
||||
}
|
||||
"--top-p" => {
|
||||
i += 1;
|
||||
top_p = parse_value(&args, i, "--top-p");
|
||||
}
|
||||
"--system" => {
|
||||
i += 1;
|
||||
system_prompt = args.get(i).cloned();
|
||||
if system_prompt.is_none() {
|
||||
eprintln!("missing value for --system");
|
||||
std::process::exit(2);
|
||||
}
|
||||
}
|
||||
"--think" => {
|
||||
enable_thinking = true;
|
||||
}
|
||||
"--no-color" => {
|
||||
color = false;
|
||||
}
|
||||
arg if arg.starts_with('-') => {
|
||||
eprintln!("unknown option: {arg}");
|
||||
print_usage_and_exit(2);
|
||||
}
|
||||
arg => {
|
||||
if model_dir.is_some() {
|
||||
eprintln!("unexpected extra argument: {arg}");
|
||||
print_usage_and_exit(2);
|
||||
}
|
||||
model_dir = Some(PathBuf::from(arg));
|
||||
}
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
|
||||
CliOptions {
|
||||
model_dir: model_dir.unwrap_or_else(|| {
|
||||
eprintln!("missing model directory");
|
||||
print_usage_and_exit(2);
|
||||
}),
|
||||
max_tokens: max_tokens.max(1),
|
||||
max_seq_len: max_seq_len.max(1),
|
||||
sampling: SamplingParams {
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
},
|
||||
system_prompt,
|
||||
enable_thinking,
|
||||
color,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses `args[i]` as a `T`.
///
/// If the index is out of range or the text does not parse, prints an error
/// naming the option (`name`) and exits the process with status 2.
fn parse_value<T: std::str::FromStr>(args: &[String], i: usize, name: &str) -> T {
    match args.get(i).map(|raw| raw.parse::<T>()) {
        Some(Ok(value)) => value,
        _ => {
            eprintln!("invalid or missing value for {name}");
            std::process::exit(2)
        }
    }
}
|
||||
|
||||
/// Writes the usage text to stderr and terminates the process with `code`.
fn print_usage_and_exit(code: i32) -> ! {
    const USAGE: &str = concat!(
        "Usage: xserv-chat <model-dir> [options]\n",
        "\n",
        "Options:\n",
        "\t-m, --model DIR Model directory\n",
        "\t--max-tokens N Max generated tokens per turn (default: 256)\n",
        "\t--max-seq-len N Persistent KV context length (default: 2048)\n",
        "\t--temperature F Sampling temperature, 0 = greedy (default: 0)\n",
        "\t--top-k N Top-k sampling, 0 = disabled (default: 0)\n",
        "\t--top-p F Top-p sampling (default: 1.0)\n",
        "\t--system TEXT System prompt for the first turn after start or /clear\n",
        "\t--think Let Qwen3 emit thinking; rendered gray on terminals\n",
        "\t--no-color Disable ANSI color for thinking output\n",
        "\t-h, --help Show this help",
    );
    eprintln!("{USAGE}");
    std::process::exit(code)
}
|
||||
|
||||
/// Lists the in-chat slash commands on stderr.
fn print_help() {
    let lines = [
        "Commands:",
        " /clear clear chat history and free/recreate the paged KV slot",
        " /exit quit",
        " /quit quit",
    ];
    for line in lines {
        eprintln!("{line}");
    }
}
|
||||
|
||||
fn new_paged_cache(config: &ModelConfig, max_seq_len: usize) -> PagedKVCache {
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = (max_blocks_per_seq + 1).max(2);
|
||||
// Single-slot interactive CLI: no swap pool (cpu_total_blocks = 0).
|
||||
PagedKVCache::new(config, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, 0)
|
||||
}
|
||||
|
||||
/// Renders one user turn as ChatML text, ending with an open assistant header.
///
/// * `system` – optional system prompt; it is trimmed and emitted only when
///   `include_system` is true and the trimmed text is non-empty.
/// * `user_input` – the user's message, inserted verbatim.
/// * `enable_thinking` – when false, an empty `<think>` block is appended after
///   the assistant header.
fn build_turn_prompt(
    system: Option<&str>,
    include_system: bool,
    user_input: &str,
    enable_thinking: bool,
) -> String {
    let system_text = system
        .filter(|_| include_system)
        .map(str::trim)
        .filter(|text| !text.is_empty());

    let mut pieces: Vec<&str> = Vec::with_capacity(8);
    if let Some(text) = system_text {
        pieces.push("<|im_start|>system\n");
        pieces.push(text);
        pieces.push("<|im_end|>\n");
    }
    pieces.push("<|im_start|>user\n");
    pieces.push(user_input);
    pieces.push("<|im_end|>\n");
    pieces.push("<|im_start|>assistant\n");
    if !enable_thinking {
        pieces.push("<think>\n\n</think>\n\n");
    }
    pieces.concat()
}
|
||||
|
||||
fn generate_with_paged_cache(
|
||||
model: &Qwen3,
|
||||
cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
prompt_tokens: &[u32],
|
||||
sampling: &SamplingParams,
|
||||
max_tokens: usize,
|
||||
use_color: bool,
|
||||
) -> Finish {
|
||||
let logits = model.forward_prefill_paged(prompt_tokens, SLOT, cache);
|
||||
let mut next = sample(&logits, sampling);
|
||||
let mut decode_buffer = Vec::new();
|
||||
let mut in_thinking = false;
|
||||
|
||||
for _ in 0..max_tokens {
|
||||
let position = cache.seq_len(SLOT);
|
||||
let logits = model.forward_decode_paged(&[next], &[position], &[SLOT], cache);
|
||||
if is_stop_token(tokenizer, next) {
|
||||
print_stream_text(
|
||||
&tokenizer.flush_decode_stream(&mut decode_buffer),
|
||||
in_thinking,
|
||||
use_color,
|
||||
);
|
||||
io::stdout().flush().unwrap();
|
||||
return Finish::Stop { token_id: next };
|
||||
}
|
||||
|
||||
print_generated_token(
|
||||
tokenizer,
|
||||
next,
|
||||
&mut decode_buffer,
|
||||
&mut in_thinking,
|
||||
use_color,
|
||||
);
|
||||
io::stdout().flush().unwrap();
|
||||
next = sample(&logits, sampling);
|
||||
}
|
||||
|
||||
print_stream_text(
|
||||
&tokenizer.flush_decode_stream(&mut decode_buffer),
|
||||
in_thinking,
|
||||
use_color,
|
||||
);
|
||||
io::stdout().flush().unwrap();
|
||||
Finish::Length
|
||||
}
|
||||
|
||||
fn append_after_stop(
|
||||
model: &Qwen3,
|
||||
cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
max_seq_len: usize,
|
||||
stop_token_id: u32,
|
||||
) {
|
||||
if tokenizer.special_token_id("<|im_end|>") == Some(stop_token_id) {
|
||||
append_text_to_cache(model, cache, tokenizer, max_seq_len, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
fn append_text_to_cache(
|
||||
model: &Qwen3,
|
||||
cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
max_seq_len: usize,
|
||||
text: &str,
|
||||
) {
|
||||
let tokens = tokenizer.encode(text);
|
||||
if tokens.is_empty() || cache.seq_len(SLOT) + tokens.len() > max_seq_len {
|
||||
return;
|
||||
}
|
||||
let _ = model.forward_prefill_paged(&tokens, SLOT, cache);
|
||||
}
|
||||
|
||||
fn print_generated_token(
|
||||
tokenizer: &Tokenizer,
|
||||
token_id: u32,
|
||||
decode_buffer: &mut Vec<u8>,
|
||||
in_thinking: &mut bool,
|
||||
use_color: bool,
|
||||
) {
|
||||
if tokenizer.special_token_id("<think>") == Some(token_id) {
|
||||
print_stream_text(
|
||||
&tokenizer.flush_decode_stream(decode_buffer),
|
||||
*in_thinking,
|
||||
use_color,
|
||||
);
|
||||
*in_thinking = true;
|
||||
print_stream_text("<think>", true, use_color);
|
||||
return;
|
||||
}
|
||||
|
||||
if tokenizer.special_token_id("</think>") == Some(token_id) {
|
||||
print_stream_text(
|
||||
&tokenizer.flush_decode_stream(decode_buffer),
|
||||
*in_thinking,
|
||||
use_color,
|
||||
);
|
||||
print_stream_text("</think>", true, use_color);
|
||||
*in_thinking = false;
|
||||
return;
|
||||
}
|
||||
|
||||
let text = tokenizer.decode_token_stream(token_id, decode_buffer);
|
||||
print_stream_text(&text, *in_thinking, use_color);
|
||||
}
|
||||
|
||||
/// Prints `text` to stdout without a newline; empty text prints nothing.
///
/// When both `in_thinking` and `use_color` are set, the text is wrapped in the
/// ANSI bright-black (`ESC[90m`) / reset (`ESC[0m`) sequences.
fn print_stream_text(text: &str, in_thinking: bool, use_color: bool) {
    match (text.is_empty(), in_thinking && use_color) {
        (true, _) => {}
        (false, true) => print!("\x1b[90m{text}\x1b[0m"),
        (false, false) => print!("{text}"),
    }
}
|
||||
|
||||
fn is_stop_token(tokenizer: &Tokenizer, token_id: u32) -> bool {
|
||||
tokenizer.eos_token_id() == Some(token_id)
|
||||
|| tokenizer.special_token_id("<|im_end|>") == Some(token_id)
|
||||
|| tokenizer.special_token_id("<|endoftext|>") == Some(token_id)
|
||||
|| tokenizer.special_token_id("<|end_of_text|>") == Some(token_id)
|
||||
}
|
||||
458
crates/xserv-model/src/decode_graph.rs
Normal file
458
crates/xserv-model/src/decode_graph.rs
Normal file
@@ -0,0 +1,458 @@
|
||||
//! CUDA Graph integration for batch=1 single-sequence decode.
|
||||
//!
|
||||
//! Uses a per-layer split graph approach:
|
||||
//! - Pre-attention graph: RMSNorm + QKV projections + reshape + QK-norm + RoPE
|
||||
//! - Ungraphed: KV cache append + decode attention (variable kv_len)
|
||||
//! - Post-attention graph: merge_heads + O-proj + add_rmsnorm + FFN + residual
|
||||
//! - Final graph: last RMSNorm + lm_head GEMV
|
||||
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer};
|
||||
use xserv_kernels::dispatch;
|
||||
use xserv_kernels::gemm::cublas_handle;
|
||||
|
||||
use crate::config::ModelConfig;
|
||||
use crate::kv_cache::GpuKVCache;
|
||||
|
||||
/// Pre-allocated intermediate buffers for decode (batch=1).
|
||||
/// All buffers have stable GPU addresses for CUDA Graph replay.
|
||||
struct DecodeBuffers {
|
||||
// Hidden-size buffers: [1, hidden]
|
||||
x: GpuBuffer, // running hidden state
|
||||
normed: GpuBuffer, // rmsnorm output
|
||||
attn_out: GpuBuffer, // attention output [1, num_heads, 1, head_dim]
|
||||
attn_merged: GpuBuffer, // merge_heads output [1, hidden]
|
||||
o_proj: GpuBuffer, // O projection output [1, hidden]
|
||||
normed2: GpuBuffer, // post-attn norm output [1, hidden]
|
||||
sum_out: GpuBuffer, // add_rmsnorm sum output [1, hidden]
|
||||
down: GpuBuffer, // down projection output [1, hidden]
|
||||
|
||||
// QKV projection outputs
|
||||
q_proj: GpuBuffer, // [1, num_heads * head_dim]
|
||||
k_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
|
||||
v_proj: GpuBuffer, // [1, num_kv_heads * head_dim]
|
||||
|
||||
// Reshaped: [1, H, 1, D]
|
||||
q_reshaped: GpuBuffer,
|
||||
k_reshaped: GpuBuffer,
|
||||
v_reshaped: GpuBuffer,
|
||||
|
||||
// After QK-norm (same shape as reshaped)
|
||||
q_normed: GpuBuffer,
|
||||
k_normed: GpuBuffer,
|
||||
|
||||
// RoPE transposed: [1, H, D]
|
||||
q_rope: GpuBuffer,
|
||||
k_rope: GpuBuffer,
|
||||
|
||||
// After RoPE transpose back: [1, H, 1, D]
|
||||
q_final: GpuBuffer,
|
||||
k_final: GpuBuffer,
|
||||
|
||||
// FFN intermediates
|
||||
gate: GpuBuffer, // [1, intermediate]
|
||||
up: GpuBuffer, // [1, intermediate]
|
||||
silu_out: GpuBuffer, // [1, intermediate]
|
||||
|
||||
// GEMV fp32 accumulators (separate per output dimension)
|
||||
fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs
|
||||
fp32_q: GpuBuffer, // for Q projection
|
||||
fp32_kv: GpuBuffer, // for K/V projection
|
||||
fp32_intermediate: GpuBuffer,// for gate/up projections
|
||||
fp32_vocab: GpuBuffer, // for lm_head
|
||||
|
||||
// Token ID and position (GPU-resident, updated before replay)
|
||||
token_id_gpu: GpuBuffer, // 4 bytes (u32)
|
||||
position_gpu: GpuBuffer, // 4 bytes (u32)
|
||||
|
||||
// Final output
|
||||
logits: GpuBuffer, // [1, vocab_size]
|
||||
}
|
||||
|
||||
pub struct DecodeGraphState {
|
||||
stream: CudaStream,
|
||||
buffers: DecodeBuffers,
|
||||
|
||||
// Per-layer graph pairs
|
||||
pre_attn_graphs: Vec<CudaGraph>,
|
||||
post_attn_graphs: Vec<CudaGraph>,
|
||||
final_graph: CudaGraph,
|
||||
|
||||
captured: bool,
|
||||
|
||||
// Model dimensions
|
||||
hidden: usize,
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
intermediate: usize,
|
||||
vocab_size: usize,
|
||||
num_layers: usize,
|
||||
eps: f32,
|
||||
}
|
||||
|
||||
impl DecodeGraphState {
|
||||
pub fn new(config: &ModelConfig) -> Self {
|
||||
let hidden = config.hidden();
|
||||
let num_heads = config.num_heads();
|
||||
let num_kv_heads = config.num_kv_heads();
|
||||
let head_dim = config.head_dim();
|
||||
let intermediate = config.ffn_hidden();
|
||||
let vocab_size = config.vocab_size;
|
||||
let num_layers = config.num_layers();
|
||||
let eps = config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
let es = 2usize; // BF16 = 2 bytes
|
||||
|
||||
let stream = CudaStream::new().expect("create CUDA stream for graph");
|
||||
|
||||
let alloc = |size: usize| -> GpuBuffer {
|
||||
GpuBuffer::alloc(size).expect("alloc decode graph buffer")
|
||||
};
|
||||
|
||||
let buffers = DecodeBuffers {
|
||||
x: alloc(hidden * es),
|
||||
normed: alloc(hidden * es),
|
||||
attn_out: alloc(num_heads * head_dim * es),
|
||||
attn_merged: alloc(hidden * es),
|
||||
o_proj: alloc(hidden * es),
|
||||
normed2: alloc(hidden * es),
|
||||
sum_out: alloc(hidden * es),
|
||||
down: alloc(hidden * es),
|
||||
|
||||
q_proj: alloc(num_heads * head_dim * es),
|
||||
k_proj: alloc(num_kv_heads * head_dim * es),
|
||||
v_proj: alloc(num_kv_heads * head_dim * es),
|
||||
|
||||
q_reshaped: alloc(num_heads * head_dim * es),
|
||||
k_reshaped: alloc(num_kv_heads * head_dim * es),
|
||||
v_reshaped: alloc(num_kv_heads * head_dim * es),
|
||||
|
||||
q_normed: alloc(num_heads * head_dim * es),
|
||||
k_normed: alloc(num_kv_heads * head_dim * es),
|
||||
|
||||
q_rope: alloc(num_heads * head_dim * es),
|
||||
k_rope: alloc(num_kv_heads * head_dim * es),
|
||||
|
||||
q_final: alloc(num_heads * head_dim * es),
|
||||
k_final: alloc(num_kv_heads * head_dim * es),
|
||||
|
||||
gate: alloc(intermediate * es),
|
||||
up: alloc(intermediate * es),
|
||||
silu_out: alloc(intermediate * es),
|
||||
|
||||
fp32_hidden: alloc(hidden * 4),
|
||||
fp32_q: alloc(num_heads * head_dim * 4),
|
||||
fp32_kv: alloc(num_kv_heads * head_dim * 4),
|
||||
fp32_intermediate: alloc(intermediate * 4),
|
||||
fp32_vocab: alloc(vocab_size * 4),
|
||||
|
||||
token_id_gpu: alloc(4),
|
||||
position_gpu: alloc(4),
|
||||
|
||||
logits: alloc(vocab_size * es),
|
||||
};
|
||||
|
||||
let pre_attn_graphs = (0..num_layers).map(|_| CudaGraph::new()).collect();
|
||||
let post_attn_graphs = (0..num_layers).map(|_| CudaGraph::new()).collect();
|
||||
|
||||
Self {
|
||||
stream,
|
||||
buffers,
|
||||
pre_attn_graphs,
|
||||
post_attn_graphs,
|
||||
final_graph: CudaGraph::new(),
|
||||
captured: false,
|
||||
hidden,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
intermediate,
|
||||
vocab_size,
|
||||
num_layers,
|
||||
eps,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_captured(&self) -> bool {
|
||||
self.captured
|
||||
}
|
||||
|
||||
/// Capture all per-layer graphs. Called once after the first decode step.
|
||||
pub fn capture(
|
||||
&mut self,
|
||||
layers: &[LayerWeightPtrs],
|
||||
norm_weight: *const c_void,
|
||||
lm_head_wt: *const c_void,
|
||||
_embed_table: *const c_void,
|
||||
rope_cos: *const c_void,
|
||||
rope_sin: *const c_void,
|
||||
) {
|
||||
let s = self.stream.as_raw();
|
||||
let h = self.hidden as i32;
|
||||
let nh = self.num_heads as i32;
|
||||
let nkv = self.num_kv_heads as i32;
|
||||
let hd = self.head_dim as i32;
|
||||
let inter = self.intermediate as i32;
|
||||
let vocab = self.vocab_size as i32;
|
||||
let eps = self.eps;
|
||||
|
||||
let cublas = cublas_handle();
|
||||
|
||||
// Set cuBLAS to use our stream
|
||||
unsafe { dispatch::set_cublas_stream(cublas, s); }
|
||||
|
||||
for (l, lw) in layers.iter().enumerate() {
|
||||
// === Pre-attention graph ===
|
||||
self.pre_attn_graphs[l].begin_capture(&self.stream).expect("begin pre-attn capture");
|
||||
unsafe {
|
||||
// RMSNorm
|
||||
dispatch::rmsnorm_bf16(
|
||||
self.buffers.x.as_ptr() as _, lw.input_norm, self.buffers.normed.as_mut_ptr() as _,
|
||||
1, h, eps, s,
|
||||
);
|
||||
|
||||
// Q projection (GEMV)
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _, lw.q_proj_wt, self.buffers.q_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_q.as_mut_ptr() as _,
|
||||
h, nh * hd, s,
|
||||
);
|
||||
|
||||
// K projection (GEMV)
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _, lw.k_proj_wt, self.buffers.k_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_kv.as_mut_ptr() as _,
|
||||
h, nkv * hd, s,
|
||||
);
|
||||
|
||||
// V projection (GEMV)
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _, lw.v_proj_wt, self.buffers.v_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_kv.as_mut_ptr() as _,
|
||||
h, nkv * hd, s,
|
||||
);
|
||||
|
||||
// Reshape heads: [1, H*D] -> [1, H, 1, D]
|
||||
dispatch::reshape_heads_bf16(self.buffers.q_proj.as_ptr() as _, self.buffers.q_reshaped.as_mut_ptr() as _, 1, nh, hd, s);
|
||||
dispatch::reshape_heads_bf16(self.buffers.k_proj.as_ptr() as _, self.buffers.k_reshaped.as_mut_ptr() as _, 1, nkv, hd, s);
|
||||
dispatch::reshape_heads_bf16(self.buffers.v_proj.as_ptr() as _, self.buffers.v_reshaped.as_mut_ptr() as _, 1, nkv, hd, s);
|
||||
|
||||
// QK norm (head-level rmsnorm: treat [1,H,1,D] as [H, D])
|
||||
dispatch::rmsnorm_bf16(self.buffers.q_reshaped.as_ptr() as _, lw.q_norm, self.buffers.q_normed.as_mut_ptr() as _, nh, hd, eps, s);
|
||||
dispatch::rmsnorm_bf16(self.buffers.k_reshaped.as_ptr() as _, lw.k_norm, self.buffers.k_normed.as_mut_ptr() as _, nkv, hd, eps, s);
|
||||
|
||||
// Transpose for RoPE: [1,H,1,D] -> [1,H,D]
|
||||
dispatch::transpose_hsd_to_shd_bf16(self.buffers.q_normed.as_ptr() as _, self.buffers.q_rope.as_mut_ptr() as _, 1, nh, hd, s);
|
||||
dispatch::transpose_hsd_to_shd_bf16(self.buffers.k_normed.as_ptr() as _, self.buffers.k_rope.as_mut_ptr() as _, 1, nkv, hd, s);
|
||||
|
||||
// RoPE (in-place, reads position_gpu)
|
||||
dispatch::rope_bf16(self.buffers.q_rope.as_mut_ptr() as _, rope_cos, rope_sin, self.buffers.position_gpu.as_ptr() as _, 1, nh, hd, s);
|
||||
dispatch::rope_bf16(self.buffers.k_rope.as_mut_ptr() as _, rope_cos, rope_sin, self.buffers.position_gpu.as_ptr() as _, 1, nkv, hd, s);
|
||||
|
||||
// Transpose back: [1,H,D] -> [1,H,1,D]
|
||||
dispatch::transpose_shd_to_hsd_bf16(self.buffers.q_rope.as_ptr() as _, self.buffers.q_final.as_mut_ptr() as _, 1, nh, hd, s);
|
||||
dispatch::transpose_shd_to_hsd_bf16(self.buffers.k_rope.as_ptr() as _, self.buffers.k_final.as_mut_ptr() as _, 1, nkv, hd, s);
|
||||
}
|
||||
self.pre_attn_graphs[l].end_capture(&self.stream).expect("end pre-attn capture");
|
||||
|
||||
// === Post-attention graph ===
|
||||
self.post_attn_graphs[l].begin_capture(&self.stream).expect("begin post-attn capture");
|
||||
unsafe {
|
||||
// Merge heads: [1,H,1,D] -> [1, hidden]
|
||||
// attn_out is written by ungraphed attention
|
||||
dispatch::merge_heads_bf16(self.buffers.attn_out.as_ptr() as _, self.buffers.attn_merged.as_mut_ptr() as _, 1, nh, hd, s);
|
||||
|
||||
// O projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.attn_merged.as_ptr() as _, lw.o_proj_wt, self.buffers.o_proj.as_mut_ptr() as _,
|
||||
self.buffers.fp32_hidden.as_mut_ptr() as _,
|
||||
nh * hd, h, s,
|
||||
);
|
||||
|
||||
// Fused Add+RMSNorm: normed2 = rmsnorm(o_proj + x), sum_out = o_proj + x
|
||||
dispatch::add_rmsnorm_bf16(
|
||||
self.buffers.o_proj.as_ptr() as _, self.buffers.x.as_ptr() as _, lw.post_norm,
|
||||
self.buffers.normed2.as_mut_ptr() as _, self.buffers.sum_out.as_mut_ptr() as _,
|
||||
1, h, eps, s,
|
||||
);
|
||||
|
||||
// Gate projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed2.as_ptr() as _, lw.gate_proj_wt, self.buffers.gate.as_mut_ptr() as _,
|
||||
self.buffers.fp32_intermediate.as_mut_ptr() as _,
|
||||
h, inter, s,
|
||||
);
|
||||
|
||||
// Up projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed2.as_ptr() as _, lw.up_proj_wt, self.buffers.up.as_mut_ptr() as _,
|
||||
self.buffers.fp32_intermediate.as_mut_ptr() as _,
|
||||
h, inter, s,
|
||||
);
|
||||
|
||||
// Fused SiLU x Mul
|
||||
dispatch::silu_mul_bf16(self.buffers.gate.as_ptr() as _, self.buffers.up.as_ptr() as _, self.buffers.silu_out.as_mut_ptr() as _, inter, s);
|
||||
|
||||
// Down projection
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.silu_out.as_ptr() as _, lw.down_proj_wt, self.buffers.down.as_mut_ptr() as _,
|
||||
self.buffers.fp32_hidden.as_mut_ptr() as _,
|
||||
inter, h, s,
|
||||
);
|
||||
|
||||
// x = sum_out + down (residual connection for next layer)
|
||||
dispatch::add_bf16(self.buffers.sum_out.as_ptr() as _, self.buffers.down.as_ptr() as _, self.buffers.x.as_mut_ptr() as _, h, s);
|
||||
}
|
||||
self.post_attn_graphs[l].end_capture(&self.stream).expect("end post-attn capture");
|
||||
}
|
||||
|
||||
// === Final graph: norm + lm_head ===
|
||||
self.final_graph.begin_capture(&self.stream).expect("begin final capture");
|
||||
unsafe {
|
||||
dispatch::rmsnorm_bf16(self.buffers.x.as_ptr() as _, norm_weight, self.buffers.normed.as_mut_ptr() as _, 1, h, eps, s);
|
||||
dispatch::gemv_bf16(
|
||||
self.buffers.normed.as_ptr() as _, lm_head_wt, self.buffers.logits.as_mut_ptr() as _,
|
||||
self.buffers.fp32_vocab.as_mut_ptr() as _,
|
||||
h, vocab, s,
|
||||
);
|
||||
}
|
||||
self.final_graph.end_capture(&self.stream).expect("end final capture");
|
||||
|
||||
// Reset cuBLAS back to null stream
|
||||
unsafe { dispatch::set_cublas_stream(cublas, std::ptr::null_mut()); }
|
||||
|
||||
self.captured = true;
|
||||
}
|
||||
|
||||
/// Execute a single decode step using captured graphs.
|
||||
pub fn execute(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
position: u32,
|
||||
cache: &mut GpuKVCache,
|
||||
_layers: &[LayerWeightPtrs],
|
||||
embed_table: *const c_void,
|
||||
vocab_size: i32,
|
||||
hidden_size: i32,
|
||||
) {
|
||||
assert!(self.captured, "must call capture() before execute()");
|
||||
let s = self.stream.as_raw();
|
||||
let nkv = self.num_kv_heads;
|
||||
let nh = self.num_heads;
|
||||
let hd = self.head_dim;
|
||||
let es = 2usize; // BF16
|
||||
|
||||
// Upload token ID and position to fixed GPU buffers
|
||||
self.buffers.token_id_gpu.copy_from_host(&token_id.to_le_bytes()).unwrap();
|
||||
self.buffers.position_gpu.copy_from_host(&position.to_le_bytes()).unwrap();
|
||||
|
||||
// Embedding (outside graph since token_id changes each step)
|
||||
unsafe {
|
||||
dispatch::embedding_bf16(
|
||||
embed_table,
|
||||
self.buffers.token_id_gpu.as_ptr() as _,
|
||||
self.buffers.x.as_mut_ptr() as _,
|
||||
1, hidden_size, vocab_size, s,
|
||||
);
|
||||
}
|
||||
|
||||
for l in 0..self.num_layers {
|
||||
// Pre-attention graph (norm + QKV + reshape + QK-norm + RoPE)
|
||||
self.pre_attn_graphs[l].launch(&self.stream).expect("launch pre-attn graph");
|
||||
|
||||
// Ungraphed: KV cache append
|
||||
// k_final shape: [1, num_kv_heads, 1, head_dim] (after RoPE pipeline)
|
||||
// v_reshaped shape: [1, num_kv_heads, 1, head_dim] (V skips RoPE)
|
||||
let pos = position as usize;
|
||||
|
||||
let k_buf_size = nkv * hd * es;
|
||||
let v_buf_size = nkv * hd * es;
|
||||
let shape = [1usize, nkv, 1, hd];
|
||||
|
||||
// Synchronize before accessing buffers for KV cache append
|
||||
self.stream.synchronize().expect("sync before kv cache");
|
||||
|
||||
let k_view = unsafe {
|
||||
crate::kv_cache::tensor_from_gpu_buffer_pub(
|
||||
GpuBuffer::borrow_raw(self.buffers.k_final.as_mut_ptr(), k_buf_size),
|
||||
&shape,
|
||||
xserv_tensor::DType::BF16,
|
||||
0,
|
||||
)
|
||||
};
|
||||
let v_view = unsafe {
|
||||
crate::kv_cache::tensor_from_gpu_buffer_pub(
|
||||
GpuBuffer::borrow_raw(self.buffers.v_reshaped.as_mut_ptr(), v_buf_size),
|
||||
&shape,
|
||||
xserv_tensor::DType::BF16,
|
||||
0,
|
||||
)
|
||||
};
|
||||
cache.append(l, &k_view, &v_view, 1, pos);
|
||||
|
||||
// Ungraphed: get full KV cache and run decode attention
|
||||
let (k_full, v_full) = cache.get_kv_len(l, pos + 1);
|
||||
let kv_len = (pos + 1) as i32;
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
|
||||
// Attention output written to attn_out (separate from q_final)
|
||||
unsafe {
|
||||
dispatch::decode_attention_bf16(
|
||||
self.buffers.q_final.as_ptr() as _,
|
||||
k_full.data_ptr() as _,
|
||||
v_full.data_ptr() as _,
|
||||
self.buffers.attn_out.as_mut_ptr() as _,
|
||||
1, nh as i32, nkv as i32,
|
||||
kv_len, hd as i32,
|
||||
scale, s,
|
||||
);
|
||||
}
|
||||
|
||||
// Synchronize before post-attention graph reads attn_out
|
||||
self.stream.synchronize().expect("sync before post-attn");
|
||||
|
||||
// Post-attention graph (merge + O-proj + add_rmsnorm + FFN + residual)
|
||||
self.post_attn_graphs[l].launch(&self.stream).expect("launch post-attn graph");
|
||||
}
|
||||
|
||||
// Final graph (norm + lm_head)
|
||||
self.final_graph.launch(&self.stream).expect("launch final graph");
|
||||
|
||||
// Sync to ensure logits are ready
|
||||
self.stream.synchronize().expect("sync after decode");
|
||||
}
|
||||
|
||||
/// Get the logits buffer (for reading results after execute).
|
||||
pub fn logits_buffer(&self) -> &GpuBuffer {
|
||||
&self.buffers.logits
|
||||
}
|
||||
|
||||
/// Invalidate captured graphs (e.g. when switching sequences).
|
||||
pub fn invalidate(&mut self) {
|
||||
self.captured = false;
|
||||
self.pre_attn_graphs = (0..self.num_layers).map(|_| CudaGraph::new()).collect();
|
||||
self.post_attn_graphs = (0..self.num_layers).map(|_| CudaGraph::new()).collect();
|
||||
self.final_graph = CudaGraph::new();
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE(review): manual `Send` opt-in. The struct definition is not visible
// from here; presumably it holds raw device pointers / stream handles that
// block the auto trait. Confirm it is only ever used from one thread at a time.
unsafe impl Send for DecodeGraphState {}
|
||||
|
||||
/// Raw pointers to one transformer layer's weight tensors.
///
/// Lets the graph-capture code receive the weights without taking the whole
/// model struct. The pointers are non-owning: nothing here keeps the
/// underlying tensors alive.
pub struct LayerWeightPtrs {
    // Attention block.
    pub input_norm: *const c_void,
    pub q_proj_wt: *const c_void,
    pub k_proj_wt: *const c_void,
    pub v_proj_wt: *const c_void,
    pub o_proj_wt: *const c_void,
    pub q_norm: *const c_void,
    pub k_norm: *const c_void,
    // Feed-forward block.
    pub post_norm: *const c_void,
    pub gate_proj_wt: *const c_void,
    pub up_proj_wt: *const c_void,
    pub down_proj_wt: *const c_void,
}

// NOTE(review): raw pointers are neither `Send` nor `Sync` automatically.
// These opt-ins presumably rely on the weights being immutable for the
// lifetime of the model; verify against the owner of the tensors.
unsafe impl Send for LayerWeightPtrs {}
unsafe impl Sync for LayerWeightPtrs {}
|
||||
@@ -280,45 +280,88 @@ fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor {
|
||||
fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) -> (Tensor, Tensor, Tensor) {
|
||||
let hidden = num_heads * head_dim;
|
||||
let qkv_cpu = qkv.to_device(Device::Cpu);
|
||||
let data = qkv_cpu.as_slice::<f32>();
|
||||
|
||||
let mut q_data = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
let mut k_data = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
let mut v_data = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
|
||||
for s in 0..seq_len {
|
||||
let row = &data[s * 3 * hidden..(s + 1) * 3 * hidden];
|
||||
for h in 0..num_heads {
|
||||
let src_off = h * head_dim;
|
||||
let dst_off = (h * seq_len + s) * head_dim;
|
||||
q_data[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]);
|
||||
k_data[dst_off..dst_off + head_dim].copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
|
||||
v_data[dst_off..dst_off + head_dim].copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]);
|
||||
}
|
||||
}
|
||||
|
||||
let device = qkv.device();
|
||||
let q = Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let k = Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let v = Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
(q, k, v)
|
||||
let dtype = qkv.dtype();
|
||||
|
||||
match dtype {
|
||||
DType::F32 => {
|
||||
let data = qkv_cpu.as_slice::<f32>();
|
||||
let mut q_data = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
let mut k_data = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
let mut v_data = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
for s in 0..seq_len {
|
||||
let row = &data[s * 3 * hidden..(s + 1) * 3 * hidden];
|
||||
for h in 0..num_heads {
|
||||
let src_off = h * head_dim;
|
||||
let dst_off = (h * seq_len + s) * head_dim;
|
||||
q_data[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]);
|
||||
k_data[dst_off..dst_off + head_dim].copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
|
||||
v_data[dst_off..dst_off + head_dim].copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]);
|
||||
}
|
||||
}
|
||||
let q = Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let k = Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let v = Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
(q, k, v)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = qkv_cpu.as_slice::<half::bf16>();
|
||||
let mut q_data = vec![half::bf16::ZERO; num_heads * seq_len * head_dim];
|
||||
let mut k_data = vec![half::bf16::ZERO; num_heads * seq_len * head_dim];
|
||||
let mut v_data = vec![half::bf16::ZERO; num_heads * seq_len * head_dim];
|
||||
for s in 0..seq_len {
|
||||
let row = &data[s * 3 * hidden..(s + 1) * 3 * hidden];
|
||||
for h in 0..num_heads {
|
||||
let src_off = h * head_dim;
|
||||
let dst_off = (h * seq_len + s) * head_dim;
|
||||
q_data[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]);
|
||||
k_data[dst_off..dst_off + head_dim].copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
|
||||
v_data[dst_off..dst_off + head_dim].copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]);
|
||||
}
|
||||
}
|
||||
let q = Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let k = Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let v = Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
(q, k, v)
|
||||
}
|
||||
_ => panic!("unsupported dtype {:?} in split_qkv", dtype),
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_heads(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor {
|
||||
let num_heads = x.shape()[1];
|
||||
let head_dim = x.shape()[3];
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let src = x_cpu.as_slice::<f32>();
|
||||
let device = x.device();
|
||||
let dtype = x.dtype();
|
||||
|
||||
let mut out = vec![0.0f32; seq_len * hidden];
|
||||
for s in 0..seq_len {
|
||||
for h in 0..num_heads {
|
||||
let src_off = (h * seq_len + s) * head_dim;
|
||||
let dst_off = s * hidden + h * head_dim;
|
||||
out[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]);
|
||||
match dtype {
|
||||
DType::F32 => {
|
||||
let src = x_cpu.as_slice::<f32>();
|
||||
let mut out = vec![0.0f32; seq_len * hidden];
|
||||
for s in 0..seq_len {
|
||||
for h in 0..num_heads {
|
||||
let src_off = (h * seq_len + s) * head_dim;
|
||||
let dst_off = s * hidden + h * head_dim;
|
||||
out[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(device)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let src = x_cpu.as_slice::<half::bf16>();
|
||||
let mut out = vec![half::bf16::ZERO; seq_len * hidden];
|
||||
for s in 0..seq_len {
|
||||
for h in 0..num_heads {
|
||||
let src_off = (h * seq_len + s) * head_dim;
|
||||
let dst_off = s * hidden + h * head_dim;
|
||||
out[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(device)
|
||||
}
|
||||
_ => panic!("unsupported dtype {:?} in merge_heads", dtype),
|
||||
}
|
||||
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(x.device())
|
||||
}
|
||||
|
||||
/// Greedy sampling: return the argmax token ID from the last position's logits.
|
||||
|
||||
@@ -76,6 +76,7 @@ impl GpuKVCache {
|
||||
|
||||
pub fn advance_seq_len(&mut self, new_tokens: usize) {
|
||||
self.seq_len += new_tokens;
|
||||
assert!(self.seq_len <= self.max_seq_len, "KV cache seq_len ({}) exceeds max_seq_len ({})", self.seq_len, self.max_seq_len);
|
||||
}
|
||||
|
||||
/// Get K/V cache tensors for a layer up to `seq_len` tokens: [1, num_kv_heads, seq_len, head_dim]
|
||||
@@ -85,6 +86,7 @@ impl GpuKVCache {
|
||||
}
|
||||
|
||||
pub fn get_kv_len(&mut self, layer: usize, sl: usize) -> (Tensor, Tensor) {
|
||||
assert!(sl <= self.max_seq_len, "get_kv_len: sl ({sl}) exceeds max_seq_len ({})", self.max_seq_len);
|
||||
let hd = self.head_dim;
|
||||
let nh = self.num_kv_heads;
|
||||
let es = self.elem_size;
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
pub mod config;
|
||||
pub mod decode_graph;
|
||||
pub mod gpt2;
|
||||
pub mod kv_cache;
|
||||
pub mod loader;
|
||||
pub mod paged_kv_cache;
|
||||
pub mod qwen3;
|
||||
pub mod sampling;
|
||||
|
||||
pub use config::ModelConfig;
|
||||
pub use decode_graph::{DecodeGraphState, LayerWeightPtrs};
|
||||
pub use gpt2::{GPT2, KVCache};
|
||||
pub use kv_cache::GpuKVCache;
|
||||
pub use paged_kv_cache::{BlockAllocator, Location, PagedKVCache, BLOCK_SIZE};
|
||||
pub use qwen3::Qwen3;
|
||||
pub use sampling::{SamplingParams, sample};
|
||||
|
||||
|
||||
569
crates/xserv-model/src/paged_kv_cache.rs
Normal file
569
crates/xserv-model/src/paged_kv_cache.rs
Normal file
@@ -0,0 +1,569 @@
|
||||
//! Paged KV cache: vLLM-style block-based KV cache with O(1) allocation
|
||||
//! and indirection via per-sequence block tables.
|
||||
//!
|
||||
//! Physical layout per layer:
|
||||
//! K pool: [total_blocks, num_kv_heads, BLOCK_SIZE, head_dim] BF16
|
||||
//! V pool: same
|
||||
//!
|
||||
//! Logical view per sequence: a list of physical block ids. Token at logical
|
||||
//! position p lives in block_ids[p / BLOCK_SIZE] at slot (p % BLOCK_SIZE).
|
||||
|
||||
use crate::config::ModelConfig;
|
||||
use xserv_cuda::{GpuBuffer, PinnedBuffer};
|
||||
use xserv_tensor::{DType, Tensor};
|
||||
|
||||
pub const BLOCK_SIZE: usize = 16;
|
||||
|
||||
/// LIFO free-list allocator for physical block ids; `alloc` and `free` are O(1).
///
/// Block id 0 is reserved as a "null" sentinel and is never handed out, so an
/// allocator built with `total_blocks` blocks has `total_blocks - 1` usable
/// ids. `total_blocks == 0` yields an allocator that can never allocate
/// (used when swap is disabled).
pub struct BlockAllocator {
    free_stack: Vec<u32>,
    total: usize,
}

impl BlockAllocator {
    /// Build an allocator over ids `1..total_blocks`.
    ///
    /// The stack is seeded in descending order so the first `alloc` returns 1.
    pub fn new(total_blocks: usize) -> Self {
        let free_stack: Vec<u32> = (1..total_blocks).rev().map(|id| id as u32).collect();
        Self {
            free_stack,
            total: total_blocks,
        }
    }

    /// Take one free block id, or `None` when the pool is exhausted.
    pub fn alloc(&mut self) -> Option<u32> {
        self.free_stack.pop()
    }

    /// Return `block` to the pool.
    ///
    /// Debug builds check that the id is non-zero and in range; a double
    /// free is not detected.
    pub fn free(&mut self, block: u32) {
        debug_assert!(block != 0 && (block as usize) < self.total);
        self.free_stack.push(block);
    }

    /// Number of ids currently available.
    pub fn free_count(&self) -> usize {
        self.free_stack.len()
    }

    /// Pool size the allocator was created with (includes the sentinel).
    pub fn total(&self) -> usize {
        self.total
    }

    /// Whether `n` consecutive `alloc` calls would all succeed.
    pub fn can_alloc(&self, n: usize) -> bool {
        n <= self.free_stack.len()
    }
}
|
||||
|
||||
/// Which pool a sequence's KV blocks currently occupy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Location {
    /// Blocks are in the GPU pool.
    Gpu,
    /// Blocks are in the CPU (pinned host) swap pool.
    Cpu,
}
|
||||
|
||||
/// Per-sequence state held in the cache.
|
||||
#[derive(Clone)]
|
||||
pub struct SeqState {
|
||||
/// Block ids into the GPU pool when `location == Gpu`, or into the CPU
|
||||
/// (pinned host) pool when `location == Cpu`.
|
||||
pub block_ids: Vec<u32>,
|
||||
pub seq_len: usize,
|
||||
pub location: Location,
|
||||
}
|
||||
|
||||
pub struct PagedKVCache {
|
||||
// [layer]: GpuBuffer of size total_blocks * nkv * BLOCK_SIZE * hd * elem_size
|
||||
k_pools: Vec<GpuBuffer>,
|
||||
v_pools: Vec<GpuBuffer>,
|
||||
|
||||
// CPU (pinned host) swap pools, same per-layer layout as the GPU pools but
|
||||
// sized for `cpu_total_blocks`. Empty when swap is disabled.
|
||||
cpu_k_pools: Vec<PinnedBuffer>,
|
||||
cpu_v_pools: Vec<PinnedBuffer>,
|
||||
cpu_allocator: BlockAllocator,
|
||||
|
||||
// Bytes occupied by one block within a single layer pool:
|
||||
// num_kv_heads * BLOCK_SIZE * head_dim * elem_size.
|
||||
block_bytes: usize,
|
||||
|
||||
allocator: BlockAllocator,
|
||||
seq_states: Vec<Option<SeqState>>,
|
||||
|
||||
// GPU-resident per-sequence metadata. Uploaded each step via sync_to_gpu().
|
||||
// block_table_gpu: i32 [max_seqs, max_blocks_per_seq]
|
||||
// context_lens_gpu: i32 [max_seqs]
|
||||
block_table_gpu: GpuBuffer,
|
||||
context_lens_gpu: GpuBuffer,
|
||||
// Host-side staging mirroring the GPU buffers above.
|
||||
block_table_host: Vec<i32>,
|
||||
context_lens_host: Vec<i32>,
|
||||
|
||||
// Config
|
||||
num_layers: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
elem_size: usize,
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
max_seqs: usize,
|
||||
max_blocks_per_seq: usize,
|
||||
}
|
||||
|
||||
impl PagedKVCache {
|
||||
/// Bytes occupied by all KV blocks for ONE physical block across the whole
|
||||
/// model (both K and V, all layers). Use this to size pools against VRAM.
|
||||
pub fn bytes_per_block(config: &ModelConfig, dtype: DType) -> usize {
|
||||
2 * config.num_layers()
|
||||
* config.num_kv_heads()
|
||||
* BLOCK_SIZE
|
||||
* config.head_dim()
|
||||
* dtype.size_bytes()
|
||||
}
|
||||
|
||||
/// Create a new paged cache.
|
||||
/// - `total_blocks`: total number of physical GPU blocks across all sequences.
|
||||
/// - `cpu_total_blocks`: physical blocks in the pinned-host swap pool (0 = swap off).
|
||||
/// - `max_seqs`: max number of concurrent sequences (slots), incl. swapped.
|
||||
/// - `max_blocks_per_seq`: capacity of the block table per slot
|
||||
/// (must be >= ceil(max_seq_len / BLOCK_SIZE)).
|
||||
pub fn new(
|
||||
config: &ModelConfig,
|
||||
total_blocks: usize,
|
||||
cpu_total_blocks: usize,
|
||||
max_seqs: usize,
|
||||
max_blocks_per_seq: usize,
|
||||
dtype: DType,
|
||||
device: u32,
|
||||
) -> Self {
|
||||
assert!(total_blocks >= 2, "need at least 2 blocks (one is sentinel)");
|
||||
let num_layers = config.num_layers();
|
||||
let num_kv_heads = config.num_kv_heads();
|
||||
let head_dim = config.head_dim();
|
||||
let elem_size = dtype.size_bytes();
|
||||
let block_bytes = num_kv_heads * BLOCK_SIZE * head_dim * elem_size;
|
||||
let pool_bytes = total_blocks * block_bytes;
|
||||
|
||||
let mut k_pools = Vec::with_capacity(num_layers);
|
||||
let mut v_pools = Vec::with_capacity(num_layers);
|
||||
for _ in 0..num_layers {
|
||||
let mut k = GpuBuffer::alloc(pool_bytes).expect("alloc paged K pool");
|
||||
let mut v = GpuBuffer::alloc(pool_bytes).expect("alloc paged V pool");
|
||||
k.zero().unwrap();
|
||||
v.zero().unwrap();
|
||||
k_pools.push(k);
|
||||
v_pools.push(v);
|
||||
}
|
||||
|
||||
// Pinned-host swap pools (one per layer, mirroring the GPU layout).
|
||||
let mut cpu_k_pools = Vec::new();
|
||||
let mut cpu_v_pools = Vec::new();
|
||||
if cpu_total_blocks >= 2 {
|
||||
let cpu_pool_bytes = cpu_total_blocks * block_bytes;
|
||||
for _ in 0..num_layers {
|
||||
cpu_k_pools.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU K swap pool"));
|
||||
cpu_v_pools.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU V swap pool"));
|
||||
}
|
||||
}
|
||||
let cpu_allocator = BlockAllocator::new(if cpu_total_blocks >= 2 { cpu_total_blocks } else { 0 });
|
||||
|
||||
let block_table_gpu =
|
||||
GpuBuffer::alloc(max_seqs * max_blocks_per_seq * std::mem::size_of::<i32>())
|
||||
.expect("alloc block table");
|
||||
let context_lens_gpu =
|
||||
GpuBuffer::alloc(max_seqs * std::mem::size_of::<i32>()).expect("alloc context lens");
|
||||
|
||||
let block_table_host = vec![0i32; max_seqs * max_blocks_per_seq];
|
||||
let context_lens_host = vec![0i32; max_seqs];
|
||||
|
||||
let seq_states = (0..max_seqs).map(|_| None).collect();
|
||||
|
||||
Self {
|
||||
k_pools,
|
||||
v_pools,
|
||||
cpu_k_pools,
|
||||
cpu_v_pools,
|
||||
cpu_allocator,
|
||||
block_bytes,
|
||||
allocator: BlockAllocator::new(total_blocks),
|
||||
seq_states,
|
||||
block_table_gpu,
|
||||
context_lens_gpu,
|
||||
block_table_host,
|
||||
context_lens_host,
|
||||
num_layers,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
elem_size,
|
||||
dtype,
|
||||
device,
|
||||
max_seqs,
|
||||
max_blocks_per_seq,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn num_layers(&self) -> usize { self.num_layers }
|
||||
pub fn num_kv_heads(&self) -> usize { self.num_kv_heads }
|
||||
pub fn head_dim(&self) -> usize { self.head_dim }
|
||||
pub fn dtype(&self) -> DType { self.dtype }
|
||||
pub fn max_seqs(&self) -> usize { self.max_seqs }
|
||||
pub fn max_blocks_per_seq(&self) -> usize { self.max_blocks_per_seq }
|
||||
pub fn free_blocks(&self) -> usize { self.allocator.free_count() }
|
||||
pub fn total_blocks(&self) -> usize { self.allocator.total() }
|
||||
|
||||
pub fn k_pool(&self, layer: usize) -> &GpuBuffer { &self.k_pools[layer] }
|
||||
pub fn v_pool(&self, layer: usize) -> &GpuBuffer { &self.v_pools[layer] }
|
||||
pub fn block_table_gpu(&self) -> &GpuBuffer { &self.block_table_gpu }
|
||||
pub fn context_lens_gpu(&self) -> &GpuBuffer { &self.context_lens_gpu }
|
||||
|
||||
pub fn seq_len(&self, slot: usize) -> usize {
|
||||
self.seq_states[slot].as_ref().map(|s| s.seq_len).unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn is_slot_free(&self, slot: usize) -> bool {
|
||||
self.seq_states[slot].is_none()
|
||||
}
|
||||
|
||||
/// Register a new sequence at `slot`. Allocates the first block.
|
||||
/// Returns Err(()) if no slot or no blocks are available.
|
||||
pub fn register_sequence(&mut self, slot: usize) -> Result<(), &'static str> {
|
||||
if slot >= self.max_seqs {
|
||||
return Err("slot out of range");
|
||||
}
|
||||
if self.seq_states[slot].is_some() {
|
||||
return Err("slot already in use");
|
||||
}
|
||||
let block = self.allocator.alloc().ok_or("out of blocks")?;
|
||||
self.seq_states[slot] = Some(SeqState {
|
||||
block_ids: vec![block],
|
||||
seq_len: 0,
|
||||
location: Location::Gpu,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Free all blocks for `slot` and clear the slot. Frees from whichever pool
|
||||
/// (GPU or CPU) the sequence currently lives in.
|
||||
pub fn free_sequence(&mut self, slot: usize) {
|
||||
if let Some(state) = self.seq_states[slot].take() {
|
||||
let alloc = match state.location {
|
||||
Location::Gpu => &mut self.allocator,
|
||||
Location::Cpu => &mut self.cpu_allocator,
|
||||
};
|
||||
for b in state.block_ids {
|
||||
alloc.free(b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of blocks needed to hold `seq_len + new_tokens` tokens, beyond
|
||||
/// what is currently allocated for `slot`.
|
||||
pub fn additional_blocks_needed(&self, slot: usize, new_tokens: usize) -> usize {
|
||||
let state = self.seq_states[slot].as_ref().expect("unregistered slot");
|
||||
let cur = state.block_ids.len();
|
||||
let needed_total = (state.seq_len + new_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
if needed_total > cur { needed_total - cur } else { 0 }
|
||||
}
|
||||
|
||||
/// Pre-allocate enough physical blocks in `slot` to cover positions
|
||||
/// `[0, end_pos)`. Call once before the per-layer append loop so that
|
||||
/// every layer's append uses the same block table.
|
||||
pub fn ensure_capacity(&mut self, slot: usize, end_pos: usize) {
|
||||
let state = self.seq_states[slot].as_mut().expect("unregistered slot");
|
||||
let needed_total = (end_pos + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
while state.block_ids.len() < needed_total {
|
||||
let b = self.allocator.alloc().expect("out of blocks (caller must check)");
|
||||
assert!(state.block_ids.len() < self.max_blocks_per_seq, "block table overflow");
|
||||
state.block_ids.push(b);
|
||||
}
|
||||
}
|
||||
|
||||
/// Append `num_tokens` of K/V into the paged pool for `slot` at logical
|
||||
/// position `start_pos`. Caller must have called `ensure_capacity(slot, start_pos + num_tokens)`
|
||||
/// first (or accept that this method may also extend block list).
|
||||
/// Does NOT touch `seq_len`. Call `advance_seq_len(slot, num_tokens)` after
|
||||
/// every layer has been written.
|
||||
///
|
||||
/// `k_new`, `v_new`: GPU tensors with logical shape
|
||||
/// [1, num_kv_heads, num_tokens, head_dim]
|
||||
/// stored contiguously (head-major, then tokens, then dim).
|
||||
pub fn append_tokens(
|
||||
&mut self,
|
||||
slot: usize,
|
||||
layer: usize,
|
||||
k_new: &Tensor,
|
||||
v_new: &Tensor,
|
||||
num_tokens: usize,
|
||||
start_pos: usize,
|
||||
) {
|
||||
if num_tokens == 0 { return; }
|
||||
// Make sure blocks exist for the target range.
|
||||
self.ensure_capacity(slot, start_pos + num_tokens);
|
||||
|
||||
let block_ids = self.seq_states[slot].as_ref().unwrap().block_ids.clone();
|
||||
|
||||
let nkv = self.num_kv_heads;
|
||||
let hd = self.head_dim;
|
||||
let es = self.elem_size;
|
||||
let bs = BLOCK_SIZE;
|
||||
|
||||
let k_src = k_new.storage().gpu_buffer();
|
||||
let v_src = v_new.storage().gpu_buffer();
|
||||
|
||||
let k_pool = &mut self.k_pools[layer];
|
||||
let v_pool = &mut self.v_pools[layer];
|
||||
|
||||
let mut t = 0usize;
|
||||
while t < num_tokens {
|
||||
let p = start_pos + t;
|
||||
let logical_blk = p / bs;
|
||||
let slot_in_blk = p % bs;
|
||||
let chunk = (bs - slot_in_blk).min(num_tokens - t);
|
||||
let phys = block_ids[logical_blk] as usize;
|
||||
|
||||
for h in 0..nkv {
|
||||
let src_off = (h * num_tokens + t) * hd * es;
|
||||
let dst_off = ((phys * nkv + h) * bs + slot_in_blk) * hd * es;
|
||||
let count = chunk * hd * es;
|
||||
k_pool.copy_from_device_at(k_src, src_off, dst_off, count).unwrap();
|
||||
v_pool.copy_from_device_at(v_src, src_off, dst_off, count).unwrap();
|
||||
}
|
||||
|
||||
t += chunk;
|
||||
}
|
||||
}
|
||||
|
||||
/// Advance the logical seq_len after append_tokens for ALL layers has completed.
|
||||
pub fn advance_seq_len(&mut self, slot: usize, num_tokens: usize) {
|
||||
let state = self.seq_states[slot].as_mut().expect("unregistered slot");
|
||||
state.seq_len += num_tokens;
|
||||
}
|
||||
|
||||
/// Refresh the host-side block table + context lens from `seq_states`,
|
||||
/// then upload to GPU. Call once per decode step before the paged kernel.
|
||||
pub fn sync_to_gpu(&mut self) {
|
||||
let stride = self.max_blocks_per_seq;
|
||||
for slot in 0..self.max_seqs {
|
||||
let row = &mut self.block_table_host[slot * stride..(slot + 1) * stride];
|
||||
row.fill(0);
|
||||
let len = match &self.seq_states[slot] {
|
||||
Some(s) => {
|
||||
for (i, b) in s.block_ids.iter().enumerate() {
|
||||
row[i] = *b as i32;
|
||||
}
|
||||
s.seq_len as i32
|
||||
}
|
||||
None => 0,
|
||||
};
|
||||
self.context_lens_host[slot] = len;
|
||||
}
|
||||
|
||||
self.upload_metadata();
|
||||
}
|
||||
|
||||
/// Pack the given active slots into rows 0..slots.len() of block_table_gpu
|
||||
/// and context_lens_gpu, then upload. Used by paged decode where the kernel
|
||||
/// iterates over `batch` active sequences in order.
|
||||
pub fn sync_active_batch_to_gpu(&mut self, slots: &[usize]) {
|
||||
let lens: Vec<i32> = slots
|
||||
.iter()
|
||||
.map(|&s| self.seq_states[s].as_ref().unwrap().seq_len as i32)
|
||||
.collect();
|
||||
self.sync_active_batch_with_lens(slots, &lens);
|
||||
}
|
||||
|
||||
/// Like sync_active_batch_to_gpu but uses caller-supplied kv_lens (number
|
||||
/// of valid K/V tokens to attend over per active row). Useful when the
|
||||
/// kv_len for the current step differs from the cached seq_len (e.g.
|
||||
/// before advance_seq_len has run).
|
||||
pub fn sync_active_batch_with_lens(&mut self, slots: &[usize], kv_lens: &[i32]) {
|
||||
assert_eq!(slots.len(), kv_lens.len());
|
||||
assert!(slots.len() <= self.max_seqs, "active batch exceeds max_seqs");
|
||||
let stride = self.max_blocks_per_seq;
|
||||
for row in &mut self.block_table_host {
|
||||
*row = 0;
|
||||
}
|
||||
for cl in &mut self.context_lens_host {
|
||||
*cl = 0;
|
||||
}
|
||||
for (i, &slot) in slots.iter().enumerate() {
|
||||
let s = self.seq_states[slot].as_ref().expect("unregistered slot in active batch");
|
||||
let row = &mut self.block_table_host[i * stride..(i + 1) * stride];
|
||||
for (j, b) in s.block_ids.iter().enumerate() {
|
||||
row[j] = *b as i32;
|
||||
}
|
||||
self.context_lens_host[i] = kv_lens[i];
|
||||
}
|
||||
self.upload_metadata();
|
||||
}
|
||||
|
||||
fn upload_metadata(&mut self) {
|
||||
let bt_bytes = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
self.block_table_host.as_ptr() as *const u8,
|
||||
self.block_table_host.len() * std::mem::size_of::<i32>(),
|
||||
)
|
||||
};
|
||||
self.block_table_gpu.copy_from_host(bt_bytes).unwrap();
|
||||
|
||||
let cl_bytes = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
self.context_lens_host.as_ptr() as *const u8,
|
||||
self.context_lens_host.len() * std::mem::size_of::<i32>(),
|
||||
)
|
||||
};
|
||||
self.context_lens_gpu.copy_from_host(cl_bytes).unwrap();
|
||||
}
|
||||
|
||||
/// Materialize a contiguous K/V tensor for a sequence at `layer`, shaped
|
||||
/// [1, num_kv_heads, seq_len, head_dim]. Used for prefill, where Flash
|
||||
/// Attention 2 expects contiguous K/V.
|
||||
///
|
||||
/// Allocates from the cached allocator; the returned Tensors own their storage.
|
||||
pub fn gather_kv_contiguous(&self, slot: usize, layer: usize) -> (Tensor, Tensor) {
|
||||
let state = self.seq_states[slot].as_ref().expect("unregistered slot");
|
||||
let sl = state.seq_len;
|
||||
let nkv = self.num_kv_heads;
|
||||
let hd = self.head_dim;
|
||||
let es = self.elem_size;
|
||||
let bs = BLOCK_SIZE;
|
||||
|
||||
let out_bytes = nkv * sl * hd * es;
|
||||
let mut k_dst = xserv_cuda::allocator::cached_alloc(out_bytes).expect("alloc gather K");
|
||||
let mut v_dst = xserv_cuda::allocator::cached_alloc(out_bytes).expect("alloc gather V");
|
||||
|
||||
let k_pool = &self.k_pools[layer];
|
||||
let v_pool = &self.v_pools[layer];
|
||||
|
||||
let mut p = 0usize;
|
||||
while p < sl {
|
||||
let logical_blk = p / bs;
|
||||
let slot_in_blk = p % bs;
|
||||
let chunk = (bs - slot_in_blk).min(sl - p);
|
||||
let phys = state.block_ids[logical_blk] as usize;
|
||||
|
||||
for h in 0..nkv {
|
||||
let src_off = ((phys * nkv + h) * bs + slot_in_blk) * hd * es;
|
||||
let dst_off = (h * sl + p) * hd * es;
|
||||
let count = chunk * hd * es;
|
||||
k_dst.copy_from_device_at(k_pool, src_off, dst_off, count).unwrap();
|
||||
v_dst.copy_from_device_at(v_pool, src_off, dst_off, count).unwrap();
|
||||
}
|
||||
p += chunk;
|
||||
}
|
||||
|
||||
let shape = &[1usize, nkv, sl, hd];
|
||||
let k = unsafe { tensor_from_owned_buf(k_dst, shape, self.dtype, self.device) };
|
||||
let v = unsafe { tensor_from_owned_buf(v_dst, shape, self.dtype, self.device) };
|
||||
(k, v)
|
||||
}
|
||||
|
||||
// ----- Swapping (vLLM-style preemption to pinned host memory) -----
|
||||
|
||||
pub fn free_cpu_blocks(&self) -> usize { self.cpu_allocator.free_count() }
|
||||
pub fn swap_enabled(&self) -> bool { !self.cpu_k_pools.is_empty() }
|
||||
|
||||
pub fn is_swapped(&self, slot: usize) -> bool {
|
||||
matches!(self.seq_states[slot].as_ref().map(|s| s.location), Some(Location::Cpu))
|
||||
}
|
||||
|
||||
/// Number of physical blocks currently held by `slot` (in either pool).
|
||||
pub fn block_count(&self, slot: usize) -> usize {
|
||||
self.seq_states[slot].as_ref().map(|s| s.block_ids.len()).unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Whether a swapped sequence at `slot` can be brought back (enough free GPU blocks).
|
||||
pub fn can_swap_in(&self, slot: usize) -> bool {
|
||||
self.allocator.can_alloc(self.block_count(slot))
|
||||
}
|
||||
|
||||
/// Whether the GPU sequence at `slot` can be evicted (enough free CPU blocks).
|
||||
pub fn can_swap_out(&self, slot: usize) -> bool {
|
||||
self.cpu_allocator.can_alloc(self.block_count(slot))
|
||||
}
|
||||
|
||||
/// Evict `slot`'s KV from GPU to pinned host memory and free its GPU blocks.
|
||||
/// The slot stays registered (location = Cpu); the sequence is paused.
|
||||
pub fn swap_out(&mut self, slot: usize) -> Result<(), &'static str> {
|
||||
let state = self.seq_states[slot].as_ref().ok_or("swap_out: empty slot")?;
|
||||
if state.location == Location::Cpu { return Ok(()); }
|
||||
let gpu_ids = state.block_ids.clone();
|
||||
let n = gpu_ids.len();
|
||||
if !self.cpu_allocator.can_alloc(n) { return Err("swap_out: CPU pool full"); }
|
||||
|
||||
let cpu_ids: Vec<u32> = (0..n)
|
||||
.map(|_| self.cpu_allocator.alloc().expect("checked can_alloc"))
|
||||
.collect();
|
||||
|
||||
let bb = self.block_bytes;
|
||||
for layer in 0..self.num_layers {
|
||||
for i in 0..n {
|
||||
let g_off = gpu_ids[i] as usize * bb;
|
||||
let c_off = cpu_ids[i] as usize * bb;
|
||||
self.k_pools[layer]
|
||||
.copy_to_host_at(&mut self.cpu_k_pools[layer].as_mut_slice()[c_off..c_off + bb], g_off, bb)
|
||||
.unwrap();
|
||||
self.v_pools[layer]
|
||||
.copy_to_host_at(&mut self.cpu_v_pools[layer].as_mut_slice()[c_off..c_off + bb], g_off, bb)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
for b in gpu_ids {
|
||||
self.allocator.free(b);
|
||||
}
|
||||
let state = self.seq_states[slot].as_mut().unwrap();
|
||||
state.block_ids = cpu_ids;
|
||||
state.location = Location::Cpu;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Bring `slot`'s KV back from host to GPU and free its CPU blocks.
|
||||
pub fn swap_in(&mut self, slot: usize) -> Result<(), &'static str> {
|
||||
let state = self.seq_states[slot].as_ref().ok_or("swap_in: empty slot")?;
|
||||
if state.location == Location::Gpu { return Ok(()); }
|
||||
let cpu_ids = state.block_ids.clone();
|
||||
let n = cpu_ids.len();
|
||||
if !self.allocator.can_alloc(n) { return Err("swap_in: GPU pool full"); }
|
||||
|
||||
let gpu_ids: Vec<u32> = (0..n)
|
||||
.map(|_| self.allocator.alloc().expect("checked can_alloc"))
|
||||
.collect();
|
||||
|
||||
let bb = self.block_bytes;
|
||||
for layer in 0..self.num_layers {
|
||||
for i in 0..n {
|
||||
let g_off = gpu_ids[i] as usize * bb;
|
||||
let c_off = cpu_ids[i] as usize * bb;
|
||||
self.k_pools[layer]
|
||||
.copy_from_host_at(&self.cpu_k_pools[layer].as_slice()[c_off..c_off + bb], g_off, bb)
|
||||
.unwrap();
|
||||
self.v_pools[layer]
|
||||
.copy_from_host_at(&self.cpu_v_pools[layer].as_slice()[c_off..c_off + bb], g_off, bb)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
for b in cpu_ids {
|
||||
self.cpu_allocator.free(b);
|
||||
}
|
||||
let state = self.seq_states[slot].as_mut().unwrap();
|
||||
state.block_ids = gpu_ids;
|
||||
state.location = Location::Gpu;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn tensor_from_owned_buf(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
|
||||
use smallvec::SmallVec;
|
||||
use xserv_tensor::shape::contiguous_strides;
|
||||
use xserv_tensor::storage::Storage;
|
||||
|
||||
let storage = Storage::cuda(buf, device);
|
||||
Tensor::from_storage(
|
||||
storage,
|
||||
SmallVec::from_slice(shape),
|
||||
contiguous_strides(shape),
|
||||
0,
|
||||
dtype,
|
||||
)
|
||||
}
|
||||
@@ -6,6 +6,7 @@ use xserv_tensor::{DType, Device, Tensor};
|
||||
use crate::config::ModelConfig;
|
||||
use crate::gpt2::KVCache;
|
||||
use crate::kv_cache::GpuKVCache;
|
||||
use crate::paged_kv_cache::PagedKVCache;
|
||||
|
||||
pub struct Qwen3 {
|
||||
pub config: ModelConfig,
|
||||
@@ -255,6 +256,196 @@ impl Qwen3 {
|
||||
matmul_2d(&x, &self.lm_head_t) // [B, vocab_size]
|
||||
}
|
||||
|
||||
/// Paged decode: process one token per sequence using a shared paged KV cache.
|
||||
///
|
||||
/// tokens: [B] one token per sequence
|
||||
/// positions: [B] current logical position (BEFORE this step) per sequence
|
||||
/// seq_slots: [B] slot ids in `paged_cache`
|
||||
pub fn forward_decode_paged(
|
||||
&self,
|
||||
tokens: &[u32],
|
||||
positions: &[usize],
|
||||
seq_slots: &[usize],
|
||||
paged_cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
let batch = tokens.len();
|
||||
assert_eq!(positions.len(), batch);
|
||||
assert_eq!(seq_slots.len(), batch);
|
||||
assert!(batch > 0);
|
||||
|
||||
let num_heads = self.config.num_heads();
|
||||
let num_kv_heads = self.config.num_kv_heads();
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
// Ensure all slots have enough physical blocks for this token, then
|
||||
// upload block tables + context_lens once for the whole forward (the
|
||||
// tables are identical across layers; only the layer's K/V pool changes).
|
||||
let kv_lens: Vec<i32> = positions.iter().map(|&p| (p + 1) as i32).collect();
|
||||
for (b, &slot) in seq_slots.iter().enumerate() {
|
||||
paged_cache.ensure_capacity(slot, positions[b] + 1);
|
||||
}
|
||||
paged_cache.sync_active_batch_with_lens(seq_slots, &kv_lens);
|
||||
|
||||
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
|
||||
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
|
||||
let max_blocks = paged_cache.max_blocks_per_seq();
|
||||
|
||||
// Batched embedding: [B, hidden]
|
||||
let mut x = embedding(&self.embed_tokens, tokens);
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
let q_all = matmul_2d(&normed, &layer.q_proj_wt);
|
||||
let k_all = matmul_2d(&normed, &layer.k_proj_wt);
|
||||
let v_all = matmul_2d(&normed, &layer.v_proj_wt);
|
||||
|
||||
let mut q_rows: Vec<Tensor> = Vec::with_capacity(batch);
|
||||
for b in 0..batch {
|
||||
let q_row = row_view(&q_all, b);
|
||||
let k_row = row_view(&k_all, b);
|
||||
let v_row = row_view(&v_all, b);
|
||||
|
||||
let q = xserv_kernels::reshape_heads_gpu(&q_row, 1, num_heads, head_dim);
|
||||
let k = xserv_kernels::reshape_heads_gpu(&k_row, 1, num_kv_heads, head_dim);
|
||||
let v = xserv_kernels::reshape_heads_gpu(&v_row, 1, num_kv_heads, head_dim);
|
||||
|
||||
let q = head_rmsnorm(&q, &layer.q_norm, eps);
|
||||
let k = head_rmsnorm(&k, &layer.k_norm, eps);
|
||||
|
||||
let q = xserv_kernels::transpose_for_rope_gpu(&q, 1, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_for_rope_gpu(&k, 1, num_kv_heads, head_dim);
|
||||
|
||||
let pos = [positions[b] as u32];
|
||||
rope_inplace(&q, &self.rope_cache, &pos);
|
||||
rope_inplace(&k, &self.rope_cache, &pos);
|
||||
|
||||
let q = xserv_kernels::transpose_from_rope_gpu(&q, 1, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_from_rope_gpu(&k, 1, num_kv_heads, head_dim);
|
||||
|
||||
paged_cache.append_tokens(seq_slots[b], layer_idx, &k, &v, 1, positions[b]);
|
||||
|
||||
let q_flat = xserv_kernels::merge_heads_gpu(&q, 1, num_heads, head_dim);
|
||||
q_rows.push(q_flat);
|
||||
}
|
||||
|
||||
let q_batched_2d = concat_rows(&q_rows);
|
||||
// q_batched_2d: [B, num_heads * head_dim]. Memory is [B, H, D] —
|
||||
// a plain reshape view to [B, H, 1, D] is what the paged kernel expects.
|
||||
let q_4d = q_batched_2d.reshape(&[batch, num_heads, 1, head_dim]);
|
||||
|
||||
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
|
||||
let attn_out = xserv_kernels::paged_decode_attention(
|
||||
&q_4d,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
bt_ptr,
|
||||
cl_ptr,
|
||||
batch,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
max_blocks,
|
||||
);
|
||||
|
||||
// attn_out shape [B, H, 1, D] is contiguous-equivalent to [B, H*D].
|
||||
// Plain reshape is a view; merge_heads_gpu would incorrectly swap B<->H.
|
||||
let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
|
||||
let up = matmul_2d(&normed, &layer.up_proj_wt);
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
x = add_any(&residual, &down);
|
||||
}
|
||||
|
||||
// Advance logical seq_len now that all layers have been written.
|
||||
for &slot in seq_slots {
|
||||
paged_cache.advance_seq_len(slot, 1);
|
||||
}
|
||||
|
||||
let x = rmsnorm(&x, &self.norm, eps);
|
||||
matmul_2d(&x, &self.lm_head_t)
|
||||
}
|
||||
|
||||
/// Paged prefill: write a sequence of `new_tokens` K/V into the paged
|
||||
/// cache for `slot`, run flash attention via gathered contiguous K/V.
|
||||
/// Returns logits [new_tokens, vocab_size].
|
||||
pub fn forward_prefill_paged(
|
||||
&self,
|
||||
token_ids: &[u32],
|
||||
slot: usize,
|
||||
paged_cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
let pos_offset = paged_cache.seq_len(slot);
|
||||
let num_heads = self.config.num_heads();
|
||||
let num_kv_heads = self.config.num_kv_heads();
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
// Pre-allocate enough blocks and bump seq_len up-front so per-layer
|
||||
// gather_kv_contiguous returns the freshly written K/V range.
|
||||
paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
|
||||
paged_cache.advance_seq_len(slot, new_tokens);
|
||||
|
||||
let mut x = embedding(&self.embed_tokens, token_ids);
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
let q = matmul_2d(&normed, &layer.q_proj_wt);
|
||||
let k = matmul_2d(&normed, &layer.k_proj_wt);
|
||||
let v = matmul_2d(&normed, &layer.v_proj_wt);
|
||||
|
||||
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
let v = xserv_kernels::reshape_heads_gpu(&v, new_tokens, num_kv_heads, head_dim);
|
||||
|
||||
let q = head_rmsnorm(&q, &layer.q_norm, eps);
|
||||
let k = head_rmsnorm(&k, &layer.k_norm, eps);
|
||||
|
||||
let q = xserv_kernels::transpose_for_rope_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_for_rope_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
rope_inplace(&q, &self.rope_cache, &positions);
|
||||
rope_inplace(&k, &self.rope_cache, &positions);
|
||||
let q = xserv_kernels::transpose_from_rope_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_from_rope_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
|
||||
// Write into paged pool at the original (pre-advance) position.
|
||||
paged_cache.append_tokens(slot, layer_idx, &k, &v, new_tokens, pos_offset);
|
||||
|
||||
// Gather contiguous K/V for the full sequence (seq_len already includes new_tokens).
|
||||
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
|
||||
let attn_out = flash_attention(&q, &k_full, &v_full, true);
|
||||
|
||||
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
|
||||
let up = matmul_2d(&normed, &layer.up_proj_wt);
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
x = add_any(&residual, &down);
|
||||
}
|
||||
|
||||
let x = rmsnorm(&x, &self.norm, eps);
|
||||
matmul_2d(&x, &self.lm_head_t)
|
||||
}
|
||||
|
||||
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
|
||||
pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
@@ -320,6 +511,40 @@ impl Qwen3 {
|
||||
let x = rmsnorm(&x, &self.norm, eps);
|
||||
matmul_2d(&x, &self.lm_head_t)
|
||||
}
|
||||
|
||||
/// Extract weight pointers for CUDA Graph capture.
|
||||
pub fn layer_weight_ptrs(&self) -> Vec<crate::decode_graph::LayerWeightPtrs> {
|
||||
self.layers.iter().map(|l| crate::decode_graph::LayerWeightPtrs {
|
||||
input_norm: l.input_norm.data_ptr() as *const std::ffi::c_void,
|
||||
q_proj_wt: l.q_proj_wt.data_ptr() as *const std::ffi::c_void,
|
||||
k_proj_wt: l.k_proj_wt.data_ptr() as *const std::ffi::c_void,
|
||||
v_proj_wt: l.v_proj_wt.data_ptr() as *const std::ffi::c_void,
|
||||
o_proj_wt: l.o_proj_wt.data_ptr() as *const std::ffi::c_void,
|
||||
q_norm: l.q_norm.data_ptr() as *const std::ffi::c_void,
|
||||
k_norm: l.k_norm.data_ptr() as *const std::ffi::c_void,
|
||||
post_norm: l.post_norm.data_ptr() as *const std::ffi::c_void,
|
||||
gate_proj_wt: l.gate_proj_wt.data_ptr() as *const std::ffi::c_void,
|
||||
up_proj_wt: l.up_proj_wt.data_ptr() as *const std::ffi::c_void,
|
||||
down_proj_wt: l.down_proj_wt.data_ptr() as *const std::ffi::c_void,
|
||||
}).collect()
|
||||
}
|
||||
|
||||
/// Get pointers needed for CUDA Graph capture.
|
||||
pub fn graph_capture_ptrs(&self) -> (
|
||||
*const std::ffi::c_void, // norm weight
|
||||
*const std::ffi::c_void, // lm_head_t
|
||||
*const std::ffi::c_void, // embed_tokens
|
||||
*const std::ffi::c_void, // rope cos
|
||||
*const std::ffi::c_void, // rope sin
|
||||
) {
|
||||
(
|
||||
self.norm.data_ptr() as *const std::ffi::c_void,
|
||||
self.lm_head_t.data_ptr() as *const std::ffi::c_void,
|
||||
self.embed_tokens.data_ptr() as *const std::ffi::c_void,
|
||||
self.rope_cache.cos.as_ptr() as *const std::ffi::c_void,
|
||||
self.rope_cache.sin.as_ptr() as *const std::ffi::c_void,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
@@ -41,6 +41,7 @@ enum MergeEntry {
|
||||
struct AddedToken {
|
||||
id: u32,
|
||||
content: String,
|
||||
#[allow(dead_code)]
|
||||
special: bool,
|
||||
}
|
||||
|
||||
@@ -90,21 +91,22 @@ impl Tokenizer {
|
||||
}
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
// Added tokens are matched as indivisible tokens by HF tokenizers,
|
||||
// even when their `special` flag is false (for example Qwen3's
|
||||
// <think> and </think> tokens).
|
||||
let mut special_tokens = HashMap::new();
|
||||
let mut special_token_ids = HashMap::new();
|
||||
let mut eos_token_id = None;
|
||||
for at in &tj.added_tokens {
|
||||
if at.special {
|
||||
special_tokens.insert(at.content.clone(), at.id);
|
||||
special_token_ids.insert(at.id, at.content.clone());
|
||||
decoder.resize(decoder.len().max(at.id as usize + 1), vec![]);
|
||||
decoder[at.id as usize] = at.content.as_bytes().to_vec();
|
||||
if at.content == "<|endoftext|>" || at.content == "<|end_of_text|>" {
|
||||
eos_token_id = Some(at.id);
|
||||
}
|
||||
}
|
||||
special_tokens.insert(at.content.clone(), at.id);
|
||||
special_token_ids.insert(at.id, at.content.clone());
|
||||
decoder.resize(decoder.len().max(at.id as usize + 1), vec![]);
|
||||
decoder[at.id as usize] = at.content.as_bytes().to_vec();
|
||||
}
|
||||
let eos_token_id = special_tokens
|
||||
.get("<|im_end|>")
|
||||
.or_else(|| special_tokens.get("<|end_of_text|>"))
|
||||
.or_else(|| special_tokens.get("<|endoftext|>"))
|
||||
.copied();
|
||||
|
||||
// Pre-tokenization regex
|
||||
let pre_tokenize_re = if byte_fallback {
|
||||
@@ -230,6 +232,19 @@ impl Tokenizer {
|
||||
String::from_utf8_lossy(&bytes).into_owned()
|
||||
}
|
||||
|
||||
pub fn decode_token_stream(&self, token_id: u32, pending: &mut Vec<u8>) -> String {
|
||||
if let Some(bytes) = self.decoder.get(token_id as usize) {
|
||||
pending.extend_from_slice(bytes);
|
||||
}
|
||||
take_valid_utf8(pending)
|
||||
}
|
||||
|
||||
pub fn flush_decode_stream(&self, pending: &mut Vec<u8>) -> String {
|
||||
let text = String::from_utf8_lossy(pending).into_owned();
|
||||
pending.clear();
|
||||
text
|
||||
}
|
||||
|
||||
pub fn eos_token_id(&self) -> Option<u32> {
|
||||
self.eos_token_id
|
||||
}
|
||||
@@ -250,6 +265,31 @@ fn token_str_to_bytes(s: &str) -> Vec<u8> {
|
||||
s.chars().map(|c| unicode_to_byte(c)).collect()
|
||||
}
|
||||
|
||||
/// Removes and returns all text that can be decoded from the front of `pending`.
///
/// Valid UTF-8 is returned as-is and every invalid byte sequence is replaced
/// by a single U+FFFD. Decoding continues past invalid sequences, so one call
/// drains everything that can be decided now. The only bytes left in `pending`
/// are an incomplete multi-byte sequence at the very end, which a later call
/// may complete once more bytes have been appended.
///
/// Returns an empty string when `pending` is empty or holds nothing but such
/// an incomplete sequence.
fn take_valid_utf8(pending: &mut Vec<u8>) -> String {
    let mut text = String::new();
    let mut consumed = 0;

    loop {
        match std::str::from_utf8(&pending[consumed..]) {
            Ok(rest) => {
                text.push_str(rest);
                consumed = pending.len();
                break;
            }
            Err(err) => {
                // Everything before the error offset is known-valid UTF-8, so
                // the lossy conversion borrows it without substitutions.
                let valid_end = consumed + err.valid_up_to();
                text.push_str(&String::from_utf8_lossy(&pending[consumed..valid_end]));
                consumed = valid_end;

                match err.error_len() {
                    // A definitely-invalid sequence: substitute it and keep
                    // decoding so text behind it is not held back.
                    Some(bad_len) => {
                        text.push(char::REPLACEMENT_CHARACTER);
                        consumed += bad_len;
                    }
                    // Input ended mid-character: keep the tail buffered.
                    None => break,
                }
            }
        }
    }

    pending.drain(..consumed);
    text
}
|
||||
|
||||
/// Convert a Unicode char back to the byte it represents in GPT-2 encoding.
|
||||
fn unicode_to_byte(c: char) -> u8 {
|
||||
// Build the inverse map on first use
|
||||
@@ -279,3 +319,49 @@ fn unicode_to_byte(c: char) -> u8 {
|
||||
panic!("unmapped unicode char U+{:04X} in tokenizer", c as u32)
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::{take_valid_utf8, Tokenizer};

    /// Minimal tokenizer file: empty vocab and merges, plus the Qwen-style
    /// added tokens, two of which (`<think>`, `</think>`) are not flagged special.
    const QWEN_ADDED_TOKENS_JSON: &str = r#"{
    "model": {
        "vocab": {},
        "merges": [],
        "byte_fallback": false
    },
    "added_tokens": [
        {"id":151643,"content":"<|endoftext|>","special":true},
        {"id":151644,"content":"<|im_start|>","special":true},
        {"id":151645,"content":"<|im_end|>","special":true},
        {"id":151667,"content":"<think>","special":false},
        {"id":151668,"content":"</think>","special":false}
    ]
}"#;

    /// Added tokens must encode to a single id even when not special, and
    /// `<|im_end|>` must be picked as the EOS token.
    #[test]
    fn qwen_added_tokens_are_indivisible_and_im_end_is_eos() {
        // The process id keeps the fixture name distinct per test process.
        let file_name = format!("xserv-tokenizer-test-{}.json", std::process::id());
        let json_path = std::env::temp_dir().join(file_name);
        std::fs::write(&json_path, QWEN_ADDED_TOKENS_JSON).unwrap();

        let tok = Tokenizer::from_file(&json_path);
        // Best-effort cleanup; a failed removal must not fail the test.
        let _ = std::fs::remove_file(&json_path);

        assert_eq!(tok.eos_token_id(), Some(151645));
        assert_eq!(tok.encode("<think>"), vec![151667]);
        assert_eq!(tok.encode("</think>"), vec![151668]);
        assert_eq!(tok.decode(&[151645]), "<|im_end|>");
    }

    /// A 4-byte character split across two pushes yields nothing until complete.
    #[test]
    fn stream_decode_buffers_incomplete_utf8() {
        let mut buf = vec![0xF0, 0x9F];
        assert_eq!(take_valid_utf8(&mut buf), "");

        buf.extend_from_slice(&[0x98, 0x8A, b'!']);
        assert_eq!(take_valid_utf8(&mut buf), "😊!");
        assert!(buf.is_empty());
    }
}
|
||||
|
||||
Reference in New Issue
Block a user