xserv-chat: harmony channel routing + repetition penalty for gpt-oss
- Let the model generate its own <|channel|> routing instead of forcing <|channel|>final<|message|> — matches the GGUF chat template behavior. - State machine tracks harmony channels: analysis channel rendered gray, final channel printed normally, <|end|> stops on final channel only. - Add repetition penalty (default 1.3 for MoE, 1.0 for Qwen) with 512 token window to prevent greedy decode loops. Configurable via XSERV_REP_PENALTY and XSERV_REP_WINDOW env vars. - Fix Length path: use <|end|> instead of <|im_end|> for gpt-oss to avoid poisoning the KV cache with garbage tokens on truncation. - Server api.rs: append <|channel|>final<|message|> to the hardcoded gpt-oss prompt (server expects to post-process the JSON output). - Add startswith filter to minijinja for harmony template compatibility. Known issue: gpt-oss multi-turn NaN when total context exceeds ~256 tokens — likely a flash_attention_sinks kernel bug with sliding window layers at large kv_len + small q_len. Single-turn and short multi-turn conversations work correctly. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -4,7 +4,7 @@ use std::path::PathBuf;
|
||||
use std::sync::{mpsc, Arc};
|
||||
use std::thread;
|
||||
|
||||
use xserv_model::{loader, sample, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
|
||||
use xserv_model::{loader, sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
@@ -404,7 +404,8 @@ fn main() {
|
||||
append_after_stop(&model, &mut cache, &tokenizer, max_seq_len, token_id, &tp_handle);
|
||||
}
|
||||
Finish::Length => {
|
||||
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, "<|im_end|>\n", &tp_handle);
|
||||
let end_text = if is_moe { "<|end|>\n" } else { "<|im_end|>\n" };
|
||||
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, end_text, &tp_handle);
|
||||
}
|
||||
}
|
||||
println!();
|
||||
@@ -595,7 +596,7 @@ fn build_turn_prompt_gpt_oss(
|
||||
prompt.push_str("<|start|>user<|message|>");
|
||||
prompt.push_str(user_input);
|
||||
prompt.push_str("<|end|>");
|
||||
prompt.push_str("<|start|>assistant<|channel|>final<|message|>");
|
||||
prompt.push_str("<|start|>assistant");
|
||||
prompt
|
||||
}
|
||||
|
||||
@@ -611,17 +612,42 @@ fn generate_with_paged_cache(
|
||||
is_moe: bool,
|
||||
) -> Finish {
|
||||
let harmony_end_id = if is_moe { tokenizer.special_token_id("<|end|>") } else { None };
|
||||
let harmony_channel_id = if is_moe { tokenizer.special_token_id("<|channel|>") } else { None };
|
||||
let harmony_message_id = if is_moe { tokenizer.special_token_id("<|message|>") } else { None };
|
||||
let harmony_special: Vec<u32> = if is_moe {
|
||||
["<|channel|>", "<|start|>", "<|end|>", "<|message|>", "<|return|>"]
|
||||
.iter().filter_map(|s| tokenizer.special_token_id(s)).collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
// Harmony channel state: "final" channel text is printed normally,
|
||||
// "analysis" channel is rendered as thinking (gray). After <|channel|>
|
||||
// we read the channel name tokens until <|message|>.
|
||||
#[derive(PartialEq, Clone, Copy)]
|
||||
enum HarmonyState { Normal, ReadingChannel, InAnalysis, InFinal }
|
||||
let mut hstate = if is_moe { HarmonyState::InFinal } else { HarmonyState::Normal };
|
||||
|
||||
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY").ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(if is_moe { 1.3 } else { 1.0 });
|
||||
let rep_window: usize = std::env::var("XSERV_REP_WINDOW").ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(512);
|
||||
let mut history: Vec<u32> = Vec::new();
|
||||
|
||||
let pick = |logits: &xserv_tensor::Tensor, sp: &SamplingParams, hist: &[u32]| -> u32 {
|
||||
if rep_penalty > 1.0 && sp.temperature == 0.0 {
|
||||
let start = hist.len().saturating_sub(rep_window);
|
||||
sample_greedy_penalized(logits, &hist[start..], rep_penalty)
|
||||
} else {
|
||||
sample(logits, sp)
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: prompt_tokens.to_vec(), slot: SLOT }); }
|
||||
let logits = model.forward_prefill_paged(prompt_tokens, SLOT, cache);
|
||||
if let Some(h) = tp { h.wait(); }
|
||||
let mut next = sample(&logits, sampling);
|
||||
let mut next = pick(&logits, sampling, &history);
|
||||
let mut decode_buffer = Vec::new();
|
||||
let mut in_thinking = false;
|
||||
|
||||
@@ -630,7 +656,7 @@ fn generate_with_paged_cache(
|
||||
if let Some(h) = tp { h.send(TpCommand::Decode { tokens: vec![next], positions: vec![position], slots: vec![SLOT] }); }
|
||||
let logits = model.forward_decode_paged(&[next], &[position], &[SLOT], cache);
|
||||
if let Some(h) = tp { h.wait(); }
|
||||
if tokenizer.is_eos(next) || harmony_end_id == Some(next) {
|
||||
if tokenizer.is_eos(next) {
|
||||
print_stream_text(
|
||||
&tokenizer.flush_decode_stream(&mut decode_buffer),
|
||||
in_thinking,
|
||||
@@ -639,9 +665,63 @@ fn generate_with_paged_cache(
|
||||
io::stdout().flush().unwrap();
|
||||
return Finish::Stop { token_id: next };
|
||||
}
|
||||
if harmony_end_id == Some(next) {
|
||||
// <|end|> closes current segment; if in final channel, we're done
|
||||
print_stream_text(
|
||||
&tokenizer.flush_decode_stream(&mut decode_buffer),
|
||||
in_thinking,
|
||||
use_color,
|
||||
);
|
||||
if hstate == HarmonyState::InFinal {
|
||||
io::stdout().flush().unwrap();
|
||||
return Finish::Stop { token_id: next };
|
||||
}
|
||||
hstate = HarmonyState::Normal;
|
||||
next = pick(&logits, sampling, &history);
|
||||
continue;
|
||||
}
|
||||
|
||||
history.push(next);
|
||||
|
||||
// Harmony channel routing state machine
|
||||
if harmony_channel_id == Some(next) {
|
||||
decode_buffer.clear();
|
||||
hstate = HarmonyState::ReadingChannel;
|
||||
next = pick(&logits, sampling, &history);
|
||||
continue;
|
||||
}
|
||||
if harmony_message_id == Some(next) {
|
||||
if hstate == HarmonyState::ReadingChannel {
|
||||
// Channel name was accumulated but we don't need to parse it —
|
||||
// we just check via the channel_name buffer below
|
||||
}
|
||||
decode_buffer.clear();
|
||||
next = pick(&logits, sampling, &history);
|
||||
continue;
|
||||
}
|
||||
if hstate == HarmonyState::ReadingChannel {
|
||||
// Reading channel name tokens (e.g. "final", "analysis")
|
||||
let tok_text = tokenizer.decode(&[next]);
|
||||
if tok_text.contains("final") {
|
||||
hstate = HarmonyState::InFinal;
|
||||
in_thinking = false;
|
||||
} else {
|
||||
hstate = HarmonyState::InAnalysis;
|
||||
in_thinking = use_color; // render analysis as gray
|
||||
}
|
||||
next = pick(&logits, sampling, &history);
|
||||
continue;
|
||||
}
|
||||
if harmony_special.contains(&next) {
|
||||
next = sample(&logits, sampling);
|
||||
next = pick(&logits, sampling, &history);
|
||||
continue;
|
||||
}
|
||||
if hstate == HarmonyState::InAnalysis {
|
||||
// Analysis channel: render as thinking (gray) if color enabled, skip if not
|
||||
if use_color {
|
||||
print_generated_token(tokenizer, next, &mut decode_buffer, &mut in_thinking, use_color);
|
||||
}
|
||||
next = pick(&logits, sampling, &history);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -653,7 +733,7 @@ fn generate_with_paged_cache(
|
||||
use_color,
|
||||
);
|
||||
io::stdout().flush().unwrap();
|
||||
next = sample(&logits, sampling);
|
||||
next = pick(&logits, sampling, &history);
|
||||
}
|
||||
|
||||
print_stream_text(
|
||||
|
||||
@@ -114,6 +114,11 @@ impl ChatTemplate {
|
||||
env.add_function("strftime_now", strftime_now);
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
|
||||
// Python str methods used by harmony/gpt-oss templates.
|
||||
env.add_filter("startswith", |s: String, prefix: String| -> bool {
|
||||
s.starts_with(&prefix)
|
||||
});
|
||||
|
||||
env.add_template("chat", &self.source)?;
|
||||
let tmpl = env.get_template("chat")?;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user