xserv-chat: support gpt-oss-20b with TP; fix GEMV precision bug

- Add ChatModel enum dispatching between Qwen3 and GptOss based on
  config.is_moe(), following the TP engine pattern.
- Add --tp N flag for tensor-parallel inference (required for 39GB
  gpt-oss-20b which doesn't fit on a single 32GB GPU).
- Add gpt-oss harmony chat template with channel/message format.
- Replace hardcoded is_stop_token() with tokenizer.is_eos() for
  multi-model EOS support.
- Restore gpt-oss hardcoded prompt template in server api.rs, lost
  during the Jinja template refactor.
- Fix GEMV race condition: the K-split kernel zeroed the FP32
  accumulator inside the kernel (block k=0) while other blocks
  atomicAdd'd concurrently. Pre-zero with cudaMemsetAsync instead.
- Update benchmark docs with post-fix results.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-02 00:58:10 +08:00
parent 1d0ec32e8d
commit ae08896f46
4 changed files with 278 additions and 64 deletions

View File

@@ -1,16 +1,104 @@
use std::io::{self, IsTerminal, Read, Write};
use std::path::PathBuf;
use xserv_model::{loader, sample, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
use std::sync::{mpsc, Arc};
use std::thread;
use xserv_model::{loader, sample, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
enum ChatModel {
Qwen3(Qwen3),
GptOss(GptOss),
}
impl ChatModel {
fn forward_prefill_paged(&self, tokens: &[u32], slot: usize, cache: &mut PagedKVCache) -> xserv_tensor::Tensor {
match self {
ChatModel::Qwen3(m) => m.forward_prefill_paged(tokens, slot, cache),
ChatModel::GptOss(m) => m.forward_prefill_paged(tokens, slot, cache),
}
}
fn forward_decode_paged(&self, tokens: &[u32], positions: &[usize], slots: &[usize], cache: &mut PagedKVCache) -> xserv_tensor::Tensor {
match self {
ChatModel::Qwen3(m) => m.forward_decode_paged(tokens, positions, slots, cache),
ChatModel::GptOss(m) => m.forward_decode_paged(tokens, positions, slots, cache),
}
}
}
// TP worker infrastructure (reused from tp_engine pattern)
#[derive(Clone)]
enum TpCommand {
Register(usize),
Free(usize),
Prefill { tokens: Vec<u32>, slot: usize },
Decode { tokens: Vec<u32>, positions: Vec<usize>, slots: Vec<usize> },
}
struct TpHandle {
cmd_txs: Vec<mpsc::Sender<TpCommand>>,
ack_rx: mpsc::Receiver<()>,
}
impl TpHandle {
fn send(&self, cmd: TpCommand) {
for tx in &self.cmd_txs {
tx.send(cmd.clone()).ok();
}
}
fn wait(&self) {
for _ in 0..self.cmd_txs.len() {
self.ack_rx.recv().ok();
}
}
}
fn tp_worker_loop(
rank: usize, world: usize,
id: xserv_distributed::UniqueId,
model_dir: std::path::PathBuf,
config: ModelConfig,
max_seq_len: usize,
cmd_rx: mpsc::Receiver<TpCommand>,
ack_tx: mpsc::Sender<()>,
) {
let tp = Arc::new(xserv_distributed::TpContext::init(rank, world, id, rank as u32));
let weights = loader::load_model_dir(&model_dir, Device::Cpu);
let model = if config.is_moe() {
ChatModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp)))
} else {
ChatModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp)))
};
let local_kv = config.num_kv_heads() / world;
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
let total_blocks = max_blocks_per_seq + 8;
let mut cache = PagedKVCache::new_tp(
&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, rank as u32,
);
while let Ok(cmd) = cmd_rx.recv() {
match cmd {
TpCommand::Register(slot) => { let _ = cache.register_sequence(slot); }
TpCommand::Free(slot) => cache.free_sequence(slot),
TpCommand::Prefill { tokens, slot } => {
let _ = model.forward_prefill_paged(&tokens, slot, &mut cache);
}
TpCommand::Decode { tokens, positions, slots } => {
let _ = model.forward_decode_paged(&tokens, &positions, &slots, &mut cache);
}
}
let _ = ack_tx.send(());
}
}
const SLOT: usize = 0;
struct CliOptions {
model_dir: PathBuf,
max_tokens: usize,
max_seq_len: usize,
tp: usize,
sampling: SamplingParams,
system_prompt: Option<String>,
enable_thinking: bool,
@@ -168,14 +256,12 @@ fn main() {
let config = ModelConfig::from_file(&opts.model_dir.join("config.json"));
let model_type = config.model_type.as_deref().unwrap_or("unknown");
if !model_type.contains("qwen") {
eprintln!("xserv-chat currently supports Qwen-style ChatML models only; got model_type={model_type}");
std::process::exit(2);
}
let is_moe = config.is_moe();
let max_seq_len = opts.max_seq_len.min(config.max_seq_len()).max(1);
eprintln!(
"Model: {model_type}, layers={}, hidden={}, heads={}/{} kv, vocab={}, max_seq_len={}",
"Model: {model_type}{}, layers={}, hidden={}, heads={}/{} kv, vocab={}, max_seq_len={}",
if is_moe { " (MoE)" } else { "" },
config.num_layers(),
config.hidden(),
config.num_heads(),
@@ -184,17 +270,62 @@ fn main() {
max_seq_len
);
eprintln!("Loading weights...");
let weights = loader::load_model_dir(&opts.model_dir, Device::Cuda(0));
eprintln!("Loaded {} tensors", weights.len());
let model = Qwen3::from_weights(config.clone(), weights);
let world = opts.tp;
if world > 1 {
assert!(
config.num_kv_heads() % world == 0,
"num_kv_heads {} not divisible by tp {world}", config.num_kv_heads()
);
}
let (model, mut cache, tp_handle) = if world > 1 {
let id = xserv_distributed::get_unique_id();
let (ack_tx, ack_rx) = mpsc::channel::<()>();
let mut cmd_txs = Vec::new();
for rank in 1..world {
let (ctx_tx, ctx_rx) = mpsc::channel::<TpCommand>();
cmd_txs.push(ctx_tx);
let ack_tx = ack_tx.clone();
let model_dir = opts.model_dir.clone();
let config = config.clone();
thread::spawn(move || {
tp_worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx);
});
}
eprintln!("Loading weights (tp={world})...");
let tp = Arc::new(xserv_distributed::TpContext::init(0, world, id, 0));
let weights = loader::load_model_dir(&opts.model_dir, Device::Cpu);
eprintln!("Loaded {} tensors", weights.len());
let m = if is_moe {
ChatModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp)))
} else {
ChatModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp)))
};
let local_kv = config.num_kv_heads() / world;
let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
let total_blocks = max_blocks_per_seq + 8;
let c = PagedKVCache::new_tp(&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, 0);
let h = TpHandle { cmd_txs, ack_rx };
(m, c, Some(h))
} else {
eprintln!("Loading weights...");
let weights = loader::load_model_dir(&opts.model_dir, Device::Cuda(0));
eprintln!("Loaded {} tensors", weights.len());
let m = if is_moe {
ChatModel::GptOss(GptOss::from_weights(config.clone(), weights))
} else {
ChatModel::Qwen3(Qwen3::from_weights(config.clone(), weights))
};
let c = new_paged_cache(&config, max_seq_len);
(m, c, None)
};
let tokenizer = Tokenizer::from_file(&opts.model_dir.join("tokenizer.json"));
let mut cache = new_paged_cache(&config, max_seq_len);
if let Some(h) = &tp_handle { h.send(TpCommand::Register(SLOT)); h.wait(); }
cache.register_sequence(SLOT).expect("register chat slot");
let use_color = opts.color && io::stdout().is_terminal();
eprintln!("Ready (paged KV cache, persistent chat slot).");
eprintln!("Ready (paged KV cache, tp={world}).");
eprintln!("Commands: /exit, /quit, /clear\n");
loop {
@@ -210,7 +341,9 @@ fn main() {
match input {
"/exit" | "/quit" | "exit" | "quit" => break,
"/clear" => {
if let Some(h) = &tp_handle { h.send(TpCommand::Free(SLOT)); h.wait(); }
cache.free_sequence(SLOT);
if let Some(h) = &tp_handle { h.send(TpCommand::Register(SLOT)); h.wait(); }
cache.register_sequence(SLOT).expect("register chat slot");
eprintln!("history and KV cache cleared");
continue;
@@ -223,12 +356,20 @@ fn main() {
}
let include_system = cache.seq_len(SLOT) == 0;
let prompt = build_turn_prompt(
opts.system_prompt.as_deref(),
include_system,
input,
opts.enable_thinking,
);
let prompt = if is_moe {
build_turn_prompt_gpt_oss(
opts.system_prompt.as_deref(),
include_system,
input,
)
} else {
build_turn_prompt(
opts.system_prompt.as_deref(),
include_system,
input,
opts.enable_thinking,
)
};
let prompt_tokens = tokenizer.encode(&prompt);
if prompt_tokens.is_empty() {
continue;
@@ -255,13 +396,14 @@ fn main() {
&opts.sampling,
max_new_tokens,
use_color,
&tp_handle,
);
match finish {
Finish::Stop { token_id } => {
append_after_stop(&model, &mut cache, &tokenizer, max_seq_len, token_id);
append_after_stop(&model, &mut cache, &tokenizer, max_seq_len, token_id, &tp_handle);
}
Finish::Length => {
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, "<|im_end|>\n");
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, "<|im_end|>\n", &tp_handle);
}
}
println!();
@@ -277,6 +419,7 @@ fn parse_args() -> CliOptions {
let mut model_dir = None;
let mut max_tokens = 256usize;
let mut max_seq_len = 2048usize;
let mut tp = 1usize;
let mut temperature = 0.0f32;
let mut top_k = 0usize;
let mut top_p = 1.0f32;
@@ -299,6 +442,10 @@ fn parse_args() -> CliOptions {
i += 1;
max_seq_len = parse_value(&args, i, "--max-seq-len");
}
"--tp" => {
i += 1;
tp = parse_value(&args, i, "--tp");
}
"--temperature" => {
i += 1;
temperature = parse_value(&args, i, "--temperature");
@@ -347,6 +494,7 @@ fn parse_args() -> CliOptions {
}),
max_tokens: max_tokens.max(1),
max_seq_len: max_seq_len.max(1),
tp: tp.max(1),
sampling: SamplingParams {
temperature,
top_k,
@@ -373,6 +521,7 @@ fn print_usage_and_exit(code: i32) -> ! {
\t-m, --model DIR Model directory\n\
\t--max-tokens N Max generated tokens per turn (default: 256)\n\
\t--max-seq-len N Persistent KV context length (default: 2048)\n\
\t--tp N Tensor parallelism degree (default: 1)\n\
\t--temperature F Sampling temperature, 0 = greedy (default: 0)\n\
\t--top-k N Top-k sampling, 0 = disabled (default: 0)\n\
\t--top-p F Top-p sampling (default: 1.0)\n\
@@ -424,24 +573,54 @@ fn build_turn_prompt(
prompt
}
fn build_turn_prompt_gpt_oss(
system: Option<&str>,
include_system: bool,
user_input: &str,
) -> String {
let mut prompt = String::new();
if include_system {
prompt.push_str("<|start|>system<|message|>");
prompt.push_str("You are a helpful assistant.\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.");
prompt.push_str("<|end|>");
if let Some(sys) = system {
if !sys.trim().is_empty() {
prompt.push_str("<|start|>developer<|message|># Instructions\n\n");
prompt.push_str(sys.trim());
prompt.push_str("<|end|>");
}
}
}
prompt.push_str("<|start|>user<|message|>");
prompt.push_str(user_input);
prompt.push_str("<|end|>");
prompt.push_str("<|start|>assistant<|channel|>final<|message|>");
prompt
}
fn generate_with_paged_cache(
model: &Qwen3,
model: &ChatModel,
cache: &mut PagedKVCache,
tokenizer: &Tokenizer,
prompt_tokens: &[u32],
sampling: &SamplingParams,
max_tokens: usize,
use_color: bool,
tp: &Option<TpHandle>,
) -> Finish {
if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: prompt_tokens.to_vec(), slot: SLOT }); }
let logits = model.forward_prefill_paged(prompt_tokens, SLOT, cache);
if let Some(h) = tp { h.wait(); }
let mut next = sample(&logits, sampling);
let mut decode_buffer = Vec::new();
let mut in_thinking = false;
for _ in 0..max_tokens {
let position = cache.seq_len(SLOT);
if let Some(h) = tp { h.send(TpCommand::Decode { tokens: vec![next], positions: vec![position], slots: vec![SLOT] }); }
let logits = model.forward_decode_paged(&[next], &[position], &[SLOT], cache);
if is_stop_token(tokenizer, next) {
if let Some(h) = tp { h.wait(); }
if tokenizer.is_eos(next) {
print_stream_text(
&tokenizer.flush_decode_stream(&mut decode_buffer),
in_thinking,
@@ -472,29 +651,31 @@ fn generate_with_paged_cache(
}
fn append_after_stop(
model: &Qwen3,
model: &ChatModel,
cache: &mut PagedKVCache,
tokenizer: &Tokenizer,
max_seq_len: usize,
stop_token_id: u32,
_stop_token_id: u32,
tp: &Option<TpHandle>,
) {
if tokenizer.special_token_id("<|im_end|>") == Some(stop_token_id) {
append_text_to_cache(model, cache, tokenizer, max_seq_len, "\n");
}
append_text_to_cache(model, cache, tokenizer, max_seq_len, "\n", tp);
}
fn append_text_to_cache(
model: &Qwen3,
model: &ChatModel,
cache: &mut PagedKVCache,
tokenizer: &Tokenizer,
max_seq_len: usize,
text: &str,
tp: &Option<TpHandle>,
) {
let tokens = tokenizer.encode(text);
if tokens.is_empty() || cache.seq_len(SLOT) + tokens.len() > max_seq_len {
return;
}
if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: tokens.clone(), slot: SLOT }); }
let _ = model.forward_prefill_paged(&tokens, SLOT, cache);
if let Some(h) = tp { h.wait(); }
}
fn print_generated_token(
@@ -541,9 +722,3 @@ fn print_stream_text(text: &str, in_thinking: bool, use_color: bool) {
}
}
fn is_stop_token(tokenizer: &Tokenizer, token_id: u32) -> bool {
tokenizer.eos_token_id() == Some(token_id)
|| tokenizer.special_token_id("<|im_end|>") == Some(token_id)
|| tokenizer.special_token_id("<|endoftext|>") == Some(token_id)
|| tokenizer.special_token_id("<|end_of_text|>") == Some(token_id)
}

View File

@@ -169,8 +169,10 @@ fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
// ---------------------------------------------------------------------------
fn build_prompt_hardcoded(messages: &[Message], model_type: &str) -> String {
if model_type == "gpt_oss" {
return build_prompt_gpt_oss(messages);
}
// Default: Qwen3 ChatML format
let _ = model_type;
let mut prompt = String::new();
for msg in messages {
match msg.role.as_str() {
@@ -189,6 +191,41 @@ fn build_prompt_hardcoded(messages: &[Message], model_type: &str) -> String {
prompt
}
fn build_prompt_gpt_oss(messages: &[Message]) -> String {
let mut prompt = String::new();
prompt.push_str("<|start|>system<|message|>");
prompt.push_str("You are a helpful assistant.\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.");
prompt.push_str("<|end|>");
let dev_instructions: String = messages
.iter()
.filter(|m| m.role == "system")
.map(|m| m.content.as_str())
.collect::<Vec<_>>()
.join("\n\n");
if !dev_instructions.is_empty() {
prompt.push_str("<|start|>developer<|message|># Instructions\n\n");
prompt.push_str(&dev_instructions);
prompt.push_str("<|end|>");
}
for msg in messages {
match msg.role.as_str() {
"user" => {
prompt.push_str("<|start|>user<|message|>");
prompt.push_str(&msg.content);
prompt.push_str("<|end|>");
}
"assistant" => {
prompt.push_str("<|start|>assistant<|channel|>final<|message|>");
prompt.push_str(&msg.content);
prompt.push_str("<|end|>");
}
_ => {}
}
}
prompt.push_str("<|start|>assistant<|channel|>final<|message|>");
prompt
}
// ---------------------------------------------------------------------------
// HTTP handlers
// ---------------------------------------------------------------------------