xserv-chat: support gpt-oss-20b with TP; fix GEMV precision bug
- Add ChatModel enum dispatching between Qwen3 and GptOss based on config.is_moe(), following the TP engine pattern.
- Add --tp N flag for tensor-parallel inference (required for the 39GB gpt-oss-20b, which doesn't fit on a single 32GB GPU).
- Add gpt-oss harmony chat template with channel/message format.
- Replace hardcoded is_stop_token() with tokenizer.is_eos() for multi-model EOS support.
- Restore gpt-oss hardcoded prompt template in server api.rs, lost during the Jinja template refactor.
- Fix GEMV race condition: the K-split kernel zeroed the FP32 accumulator inside the kernel (block k=0) while other blocks atomicAdd'd concurrently. Pre-zero with cudaMemsetAsync instead.
- Update benchmark docs with post-fix results.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,16 +1,104 @@
|
||||
use std::io::{self, IsTerminal, Read, Write};
|
||||
use std::path::PathBuf;
|
||||
|
||||
use xserv_model::{loader, sample, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
|
||||
use std::sync::{mpsc, Arc};
|
||||
use std::thread;
|
||||
|
||||
use xserv_model::{loader, sample, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
enum ChatModel {
|
||||
Qwen3(Qwen3),
|
||||
GptOss(GptOss),
|
||||
}
|
||||
|
||||
impl ChatModel {
|
||||
fn forward_prefill_paged(&self, tokens: &[u32], slot: usize, cache: &mut PagedKVCache) -> xserv_tensor::Tensor {
|
||||
match self {
|
||||
ChatModel::Qwen3(m) => m.forward_prefill_paged(tokens, slot, cache),
|
||||
ChatModel::GptOss(m) => m.forward_prefill_paged(tokens, slot, cache),
|
||||
}
|
||||
}
|
||||
fn forward_decode_paged(&self, tokens: &[u32], positions: &[usize], slots: &[usize], cache: &mut PagedKVCache) -> xserv_tensor::Tensor {
|
||||
match self {
|
||||
ChatModel::Qwen3(m) => m.forward_decode_paged(tokens, positions, slots, cache),
|
||||
ChatModel::GptOss(m) => m.forward_decode_paged(tokens, positions, slots, cache),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TP worker infrastructure (reused from tp_engine pattern)
|
||||
/// A command broadcast from the driving thread to each tensor-parallel worker.
///
/// Every variant corresponds to one operation the worker loop performs on its
/// own model and KV cache.
#[derive(Clone)]
enum TpCommand {
    /// Register the given sequence slot in the worker's cache.
    Register(usize),
    /// Free the given sequence slot in the worker's cache.
    Free(usize),
    /// Run a paged prefill over `tokens` for `slot`.
    Prefill {
        tokens: Vec<u32>,
        slot: usize,
    },
    /// Run a paged decode step with the given tokens, positions and slots.
    Decode {
        tokens: Vec<u32>,
        positions: Vec<usize>,
        slots: Vec<usize>,
    },
}
|
||||
|
||||
struct TpHandle {
|
||||
cmd_txs: Vec<mpsc::Sender<TpCommand>>,
|
||||
ack_rx: mpsc::Receiver<()>,
|
||||
}
|
||||
|
||||
impl TpHandle {
|
||||
fn send(&self, cmd: TpCommand) {
|
||||
for tx in &self.cmd_txs {
|
||||
tx.send(cmd.clone()).ok();
|
||||
}
|
||||
}
|
||||
fn wait(&self) {
|
||||
for _ in 0..self.cmd_txs.len() {
|
||||
self.ack_rx.recv().ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn tp_worker_loop(
|
||||
rank: usize, world: usize,
|
||||
id: xserv_distributed::UniqueId,
|
||||
model_dir: std::path::PathBuf,
|
||||
config: ModelConfig,
|
||||
max_seq_len: usize,
|
||||
cmd_rx: mpsc::Receiver<TpCommand>,
|
||||
ack_tx: mpsc::Sender<()>,
|
||||
) {
|
||||
let tp = Arc::new(xserv_distributed::TpContext::init(rank, world, id, rank as u32));
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
|
||||
let model = if config.is_moe() {
|
||||
ChatModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp)))
|
||||
} else {
|
||||
ChatModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp)))
|
||||
};
|
||||
let local_kv = config.num_kv_heads() / world;
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
let mut cache = PagedKVCache::new_tp(
|
||||
&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, rank as u32,
|
||||
);
|
||||
while let Ok(cmd) = cmd_rx.recv() {
|
||||
match cmd {
|
||||
TpCommand::Register(slot) => { let _ = cache.register_sequence(slot); }
|
||||
TpCommand::Free(slot) => cache.free_sequence(slot),
|
||||
TpCommand::Prefill { tokens, slot } => {
|
||||
let _ = model.forward_prefill_paged(&tokens, slot, &mut cache);
|
||||
}
|
||||
TpCommand::Decode { tokens, positions, slots } => {
|
||||
let _ = model.forward_decode_paged(&tokens, &positions, &slots, &mut cache);
|
||||
}
|
||||
}
|
||||
let _ = ack_tx.send(());
|
||||
}
|
||||
}
|
||||
|
||||
/// Sequence slot index used for the chat session's KV cache entries.
const SLOT: usize = 0;
|
||||
|
||||
struct CliOptions {
|
||||
model_dir: PathBuf,
|
||||
max_tokens: usize,
|
||||
max_seq_len: usize,
|
||||
tp: usize,
|
||||
sampling: SamplingParams,
|
||||
system_prompt: Option<String>,
|
||||
enable_thinking: bool,
|
||||
@@ -168,14 +256,12 @@ fn main() {
|
||||
|
||||
let config = ModelConfig::from_file(&opts.model_dir.join("config.json"));
|
||||
let model_type = config.model_type.as_deref().unwrap_or("unknown");
|
||||
if !model_type.contains("qwen") {
|
||||
eprintln!("xserv-chat currently supports Qwen-style ChatML models only; got model_type={model_type}");
|
||||
std::process::exit(2);
|
||||
}
|
||||
let is_moe = config.is_moe();
|
||||
|
||||
let max_seq_len = opts.max_seq_len.min(config.max_seq_len()).max(1);
|
||||
eprintln!(
|
||||
"Model: {model_type}, layers={}, hidden={}, heads={}/{} kv, vocab={}, max_seq_len={}",
|
||||
"Model: {model_type}{}, layers={}, hidden={}, heads={}/{} kv, vocab={}, max_seq_len={}",
|
||||
if is_moe { " (MoE)" } else { "" },
|
||||
config.num_layers(),
|
||||
config.hidden(),
|
||||
config.num_heads(),
|
||||
@@ -184,17 +270,62 @@ fn main() {
|
||||
max_seq_len
|
||||
);
|
||||
|
||||
eprintln!("Loading weights...");
|
||||
let weights = loader::load_model_dir(&opts.model_dir, Device::Cuda(0));
|
||||
eprintln!("Loaded {} tensors", weights.len());
|
||||
let model = Qwen3::from_weights(config.clone(), weights);
|
||||
let world = opts.tp;
|
||||
if world > 1 {
|
||||
assert!(
|
||||
config.num_kv_heads() % world == 0,
|
||||
"num_kv_heads {} not divisible by tp {world}", config.num_kv_heads()
|
||||
);
|
||||
}
|
||||
|
||||
let (model, mut cache, tp_handle) = if world > 1 {
|
||||
let id = xserv_distributed::get_unique_id();
|
||||
let (ack_tx, ack_rx) = mpsc::channel::<()>();
|
||||
let mut cmd_txs = Vec::new();
|
||||
for rank in 1..world {
|
||||
let (ctx_tx, ctx_rx) = mpsc::channel::<TpCommand>();
|
||||
cmd_txs.push(ctx_tx);
|
||||
let ack_tx = ack_tx.clone();
|
||||
let model_dir = opts.model_dir.clone();
|
||||
let config = config.clone();
|
||||
thread::spawn(move || {
|
||||
tp_worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx);
|
||||
});
|
||||
}
|
||||
eprintln!("Loading weights (tp={world})...");
|
||||
let tp = Arc::new(xserv_distributed::TpContext::init(0, world, id, 0));
|
||||
let weights = loader::load_model_dir(&opts.model_dir, Device::Cpu);
|
||||
eprintln!("Loaded {} tensors", weights.len());
|
||||
let m = if is_moe {
|
||||
ChatModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp)))
|
||||
} else {
|
||||
ChatModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp)))
|
||||
};
|
||||
let local_kv = config.num_kv_heads() / world;
|
||||
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let total_blocks = max_blocks_per_seq + 8;
|
||||
let c = PagedKVCache::new_tp(&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, 0);
|
||||
let h = TpHandle { cmd_txs, ack_rx };
|
||||
(m, c, Some(h))
|
||||
} else {
|
||||
eprintln!("Loading weights...");
|
||||
let weights = loader::load_model_dir(&opts.model_dir, Device::Cuda(0));
|
||||
eprintln!("Loaded {} tensors", weights.len());
|
||||
let m = if is_moe {
|
||||
ChatModel::GptOss(GptOss::from_weights(config.clone(), weights))
|
||||
} else {
|
||||
ChatModel::Qwen3(Qwen3::from_weights(config.clone(), weights))
|
||||
};
|
||||
let c = new_paged_cache(&config, max_seq_len);
|
||||
(m, c, None)
|
||||
};
|
||||
|
||||
let tokenizer = Tokenizer::from_file(&opts.model_dir.join("tokenizer.json"));
|
||||
let mut cache = new_paged_cache(&config, max_seq_len);
|
||||
if let Some(h) = &tp_handle { h.send(TpCommand::Register(SLOT)); h.wait(); }
|
||||
cache.register_sequence(SLOT).expect("register chat slot");
|
||||
let use_color = opts.color && io::stdout().is_terminal();
|
||||
|
||||
eprintln!("Ready (paged KV cache, persistent chat slot).");
|
||||
eprintln!("Ready (paged KV cache, tp={world}).");
|
||||
eprintln!("Commands: /exit, /quit, /clear\n");
|
||||
|
||||
loop {
|
||||
@@ -210,7 +341,9 @@ fn main() {
|
||||
match input {
|
||||
"/exit" | "/quit" | "exit" | "quit" => break,
|
||||
"/clear" => {
|
||||
if let Some(h) = &tp_handle { h.send(TpCommand::Free(SLOT)); h.wait(); }
|
||||
cache.free_sequence(SLOT);
|
||||
if let Some(h) = &tp_handle { h.send(TpCommand::Register(SLOT)); h.wait(); }
|
||||
cache.register_sequence(SLOT).expect("register chat slot");
|
||||
eprintln!("history and KV cache cleared");
|
||||
continue;
|
||||
@@ -223,12 +356,20 @@ fn main() {
|
||||
}
|
||||
|
||||
let include_system = cache.seq_len(SLOT) == 0;
|
||||
let prompt = build_turn_prompt(
|
||||
opts.system_prompt.as_deref(),
|
||||
include_system,
|
||||
input,
|
||||
opts.enable_thinking,
|
||||
);
|
||||
let prompt = if is_moe {
|
||||
build_turn_prompt_gpt_oss(
|
||||
opts.system_prompt.as_deref(),
|
||||
include_system,
|
||||
input,
|
||||
)
|
||||
} else {
|
||||
build_turn_prompt(
|
||||
opts.system_prompt.as_deref(),
|
||||
include_system,
|
||||
input,
|
||||
opts.enable_thinking,
|
||||
)
|
||||
};
|
||||
let prompt_tokens = tokenizer.encode(&prompt);
|
||||
if prompt_tokens.is_empty() {
|
||||
continue;
|
||||
@@ -255,13 +396,14 @@ fn main() {
|
||||
&opts.sampling,
|
||||
max_new_tokens,
|
||||
use_color,
|
||||
&tp_handle,
|
||||
);
|
||||
match finish {
|
||||
Finish::Stop { token_id } => {
|
||||
append_after_stop(&model, &mut cache, &tokenizer, max_seq_len, token_id);
|
||||
append_after_stop(&model, &mut cache, &tokenizer, max_seq_len, token_id, &tp_handle);
|
||||
}
|
||||
Finish::Length => {
|
||||
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, "<|im_end|>\n");
|
||||
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, "<|im_end|>\n", &tp_handle);
|
||||
}
|
||||
}
|
||||
println!();
|
||||
@@ -277,6 +419,7 @@ fn parse_args() -> CliOptions {
|
||||
let mut model_dir = None;
|
||||
let mut max_tokens = 256usize;
|
||||
let mut max_seq_len = 2048usize;
|
||||
let mut tp = 1usize;
|
||||
let mut temperature = 0.0f32;
|
||||
let mut top_k = 0usize;
|
||||
let mut top_p = 1.0f32;
|
||||
@@ -299,6 +442,10 @@ fn parse_args() -> CliOptions {
|
||||
i += 1;
|
||||
max_seq_len = parse_value(&args, i, "--max-seq-len");
|
||||
}
|
||||
"--tp" => {
|
||||
i += 1;
|
||||
tp = parse_value(&args, i, "--tp");
|
||||
}
|
||||
"--temperature" => {
|
||||
i += 1;
|
||||
temperature = parse_value(&args, i, "--temperature");
|
||||
@@ -347,6 +494,7 @@ fn parse_args() -> CliOptions {
|
||||
}),
|
||||
max_tokens: max_tokens.max(1),
|
||||
max_seq_len: max_seq_len.max(1),
|
||||
tp: tp.max(1),
|
||||
sampling: SamplingParams {
|
||||
temperature,
|
||||
top_k,
|
||||
@@ -373,6 +521,7 @@ fn print_usage_and_exit(code: i32) -> ! {
|
||||
\t-m, --model DIR Model directory\n\
|
||||
\t--max-tokens N Max generated tokens per turn (default: 256)\n\
|
||||
\t--max-seq-len N Persistent KV context length (default: 2048)\n\
|
||||
\t--tp N Tensor parallelism degree (default: 1)\n\
|
||||
\t--temperature F Sampling temperature, 0 = greedy (default: 0)\n\
|
||||
\t--top-k N Top-k sampling, 0 = disabled (default: 0)\n\
|
||||
\t--top-p F Top-p sampling (default: 1.0)\n\
|
||||
@@ -424,24 +573,54 @@ fn build_turn_prompt(
|
||||
prompt
|
||||
}
|
||||
|
||||
/// Builds the gpt-oss prompt text for a single chat turn.
///
/// When `include_system` is true, the output starts with the fixed system
/// message. If `system` also contains non-whitespace text, a developer message
/// carrying that text (trimmed) follows. The user message comes next, written
/// verbatim, and the prompt ends with an opened assistant message on the
/// `final` channel.
fn build_turn_prompt_gpt_oss(
    system: Option<&str>,
    include_system: bool,
    user_input: &str,
) -> String {
    // Fixed system message, identical for every conversation.
    const SYSTEM_BLOCK: &str = concat!(
        "<|start|>system<|message|>",
        "You are a helpful assistant.\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.",
        "<|end|>",
    );

    let mut out = String::new();

    if include_system {
        out += SYSTEM_BLOCK;

        // Only emit a developer message when there is something to say.
        let instructions = system.map(str::trim).filter(|text| !text.is_empty());
        if let Some(text) = instructions {
            out += "<|start|>developer<|message|># Instructions\n\n";
            out += text;
            out += "<|end|>";
        }
    }

    out += "<|start|>user<|message|>";
    out += user_input;
    out += "<|end|>";
    out += "<|start|>assistant<|channel|>final<|message|>";
    out
}
|
||||
|
||||
fn generate_with_paged_cache(
|
||||
model: &Qwen3,
|
||||
model: &ChatModel,
|
||||
cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
prompt_tokens: &[u32],
|
||||
sampling: &SamplingParams,
|
||||
max_tokens: usize,
|
||||
use_color: bool,
|
||||
tp: &Option<TpHandle>,
|
||||
) -> Finish {
|
||||
if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: prompt_tokens.to_vec(), slot: SLOT }); }
|
||||
let logits = model.forward_prefill_paged(prompt_tokens, SLOT, cache);
|
||||
if let Some(h) = tp { h.wait(); }
|
||||
let mut next = sample(&logits, sampling);
|
||||
let mut decode_buffer = Vec::new();
|
||||
let mut in_thinking = false;
|
||||
|
||||
for _ in 0..max_tokens {
|
||||
let position = cache.seq_len(SLOT);
|
||||
if let Some(h) = tp { h.send(TpCommand::Decode { tokens: vec![next], positions: vec![position], slots: vec![SLOT] }); }
|
||||
let logits = model.forward_decode_paged(&[next], &[position], &[SLOT], cache);
|
||||
if is_stop_token(tokenizer, next) {
|
||||
if let Some(h) = tp { h.wait(); }
|
||||
if tokenizer.is_eos(next) {
|
||||
print_stream_text(
|
||||
&tokenizer.flush_decode_stream(&mut decode_buffer),
|
||||
in_thinking,
|
||||
@@ -472,29 +651,31 @@ fn generate_with_paged_cache(
|
||||
}
|
||||
|
||||
fn append_after_stop(
|
||||
model: &Qwen3,
|
||||
model: &ChatModel,
|
||||
cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
max_seq_len: usize,
|
||||
stop_token_id: u32,
|
||||
_stop_token_id: u32,
|
||||
tp: &Option<TpHandle>,
|
||||
) {
|
||||
if tokenizer.special_token_id("<|im_end|>") == Some(stop_token_id) {
|
||||
append_text_to_cache(model, cache, tokenizer, max_seq_len, "\n");
|
||||
}
|
||||
append_text_to_cache(model, cache, tokenizer, max_seq_len, "\n", tp);
|
||||
}
|
||||
|
||||
fn append_text_to_cache(
|
||||
model: &Qwen3,
|
||||
model: &ChatModel,
|
||||
cache: &mut PagedKVCache,
|
||||
tokenizer: &Tokenizer,
|
||||
max_seq_len: usize,
|
||||
text: &str,
|
||||
tp: &Option<TpHandle>,
|
||||
) {
|
||||
let tokens = tokenizer.encode(text);
|
||||
if tokens.is_empty() || cache.seq_len(SLOT) + tokens.len() > max_seq_len {
|
||||
return;
|
||||
}
|
||||
if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: tokens.clone(), slot: SLOT }); }
|
||||
let _ = model.forward_prefill_paged(&tokens, SLOT, cache);
|
||||
if let Some(h) = tp { h.wait(); }
|
||||
}
|
||||
|
||||
fn print_generated_token(
|
||||
@@ -541,9 +722,3 @@ fn print_stream_text(text: &str, in_thinking: bool, use_color: bool) {
|
||||
}
|
||||
}
|
||||
|
||||
fn is_stop_token(tokenizer: &Tokenizer, token_id: u32) -> bool {
|
||||
tokenizer.eos_token_id() == Some(token_id)
|
||||
|| tokenizer.special_token_id("<|im_end|>") == Some(token_id)
|
||||
|| tokenizer.special_token_id("<|endoftext|>") == Some(token_id)
|
||||
|| tokenizer.special_token_id("<|end_of_text|>") == Some(token_id)
|
||||
}
|
||||
|
||||
@@ -169,8 +169,10 @@ fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn build_prompt_hardcoded(messages: &[Message], model_type: &str) -> String {
|
||||
if model_type == "gpt_oss" {
|
||||
return build_prompt_gpt_oss(messages);
|
||||
}
|
||||
// Default: Qwen3 ChatML format
|
||||
let _ = model_type;
|
||||
let mut prompt = String::new();
|
||||
for msg in messages {
|
||||
match msg.role.as_str() {
|
||||
@@ -189,6 +191,41 @@ fn build_prompt_hardcoded(messages: &[Message], model_type: &str) -> String {
|
||||
prompt
|
||||
}
|
||||
|
||||
fn build_prompt_gpt_oss(messages: &[Message]) -> String {
|
||||
let mut prompt = String::new();
|
||||
prompt.push_str("<|start|>system<|message|>");
|
||||
prompt.push_str("You are a helpful assistant.\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.");
|
||||
prompt.push_str("<|end|>");
|
||||
let dev_instructions: String = messages
|
||||
.iter()
|
||||
.filter(|m| m.role == "system")
|
||||
.map(|m| m.content.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n");
|
||||
if !dev_instructions.is_empty() {
|
||||
prompt.push_str("<|start|>developer<|message|># Instructions\n\n");
|
||||
prompt.push_str(&dev_instructions);
|
||||
prompt.push_str("<|end|>");
|
||||
}
|
||||
for msg in messages {
|
||||
match msg.role.as_str() {
|
||||
"user" => {
|
||||
prompt.push_str("<|start|>user<|message|>");
|
||||
prompt.push_str(&msg.content);
|
||||
prompt.push_str("<|end|>");
|
||||
}
|
||||
"assistant" => {
|
||||
prompt.push_str("<|start|>assistant<|channel|>final<|message|>");
|
||||
prompt.push_str(&msg.content);
|
||||
prompt.push_str("<|end|>");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
prompt.push_str("<|start|>assistant<|channel|>final<|message|>");
|
||||
prompt
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HTTP handlers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user