xserv-chat: harmony channel routing + repetition penalty for gpt-oss

- Let the model generate its own <|channel|> routing instead of forcing
  <|channel|>final<|message|> — matches the GGUF chat template behavior.
- State machine tracks harmony channels: analysis channel rendered gray,
  final channel printed normally, <|end|> stops on final channel only.
- Add repetition penalty (default 1.3 for MoE, 1.0 for Qwen) with 512
  token window to prevent greedy decode loops. Configurable via
  XSERV_REP_PENALTY and XSERV_REP_WINDOW env vars.
- Fix Length path: use <|end|> instead of <|im_end|> for gpt-oss to
  avoid poisoning the KV cache with garbage tokens on truncation.
- Server api.rs: append <|channel|>final<|message|> to the hardcoded
  gpt-oss prompt (server expects to post-process the JSON output).
- Add startswith filter to minijinja for harmony template compatibility.

Known issue: gpt-oss multi-turn NaN when total context exceeds ~256
tokens — likely a flash_attention_sinks kernel bug with sliding window
layers at large kv_len + small q_len. Single-turn and short multi-turn
conversations work correctly.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-02 12:40:17 +08:00
parent 3ee8df2c0f
commit f2e60218b4
2 changed files with 92 additions and 7 deletions

View File

@@ -4,7 +4,7 @@ use std::path::PathBuf;
use std::sync::{mpsc, Arc};
use std::thread;
use xserv_model::{loader, sample, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
use xserv_model::{loader, sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE};
use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer;
@@ -404,7 +404,8 @@ fn main() {
append_after_stop(&model, &mut cache, &tokenizer, max_seq_len, token_id, &tp_handle);
}
Finish::Length => {
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, "<|im_end|>\n", &tp_handle);
let end_text = if is_moe { "<|end|>\n" } else { "<|im_end|>\n" };
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, end_text, &tp_handle);
}
}
println!();
@@ -595,7 +596,7 @@ fn build_turn_prompt_gpt_oss(
prompt.push_str("<|start|>user<|message|>");
prompt.push_str(user_input);
prompt.push_str("<|end|>");
prompt.push_str("<|start|>assistant<|channel|>final<|message|>");
prompt.push_str("<|start|>assistant");
prompt
}
@@ -611,17 +612,42 @@ fn generate_with_paged_cache(
is_moe: bool,
) -> Finish {
let harmony_end_id = if is_moe { tokenizer.special_token_id("<|end|>") } else { None };
let harmony_channel_id = if is_moe { tokenizer.special_token_id("<|channel|>") } else { None };
let harmony_message_id = if is_moe { tokenizer.special_token_id("<|message|>") } else { None };
let harmony_special: Vec<u32> = if is_moe {
["<|channel|>", "<|start|>", "<|end|>", "<|message|>", "<|return|>"]
.iter().filter_map(|s| tokenizer.special_token_id(s)).collect()
} else {
Vec::new()
};
// Harmony channel state: "final" channel text is printed normally,
// "analysis" channel is rendered as thinking (gray). After <|channel|>
// we read the channel name tokens until <|message|>.
#[derive(PartialEq, Clone, Copy)]
enum HarmonyState { Normal, ReadingChannel, InAnalysis, InFinal }
let mut hstate = if is_moe { HarmonyState::InFinal } else { HarmonyState::Normal };
let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY").ok()
.and_then(|s| s.parse().ok())
.unwrap_or(if is_moe { 1.3 } else { 1.0 });
let rep_window: usize = std::env::var("XSERV_REP_WINDOW").ok()
.and_then(|s| s.parse().ok())
.unwrap_or(512);
let mut history: Vec<u32> = Vec::new();
let pick = |logits: &xserv_tensor::Tensor, sp: &SamplingParams, hist: &[u32]| -> u32 {
if rep_penalty > 1.0 && sp.temperature == 0.0 {
let start = hist.len().saturating_sub(rep_window);
sample_greedy_penalized(logits, &hist[start..], rep_penalty)
} else {
sample(logits, sp)
}
};
if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: prompt_tokens.to_vec(), slot: SLOT }); }
let logits = model.forward_prefill_paged(prompt_tokens, SLOT, cache);
if let Some(h) = tp { h.wait(); }
let mut next = sample(&logits, sampling);
let mut next = pick(&logits, sampling, &history);
let mut decode_buffer = Vec::new();
let mut in_thinking = false;
@@ -630,7 +656,7 @@ fn generate_with_paged_cache(
if let Some(h) = tp { h.send(TpCommand::Decode { tokens: vec![next], positions: vec![position], slots: vec![SLOT] }); }
let logits = model.forward_decode_paged(&[next], &[position], &[SLOT], cache);
if let Some(h) = tp { h.wait(); }
if tokenizer.is_eos(next) || harmony_end_id == Some(next) {
if tokenizer.is_eos(next) {
print_stream_text(
&tokenizer.flush_decode_stream(&mut decode_buffer),
in_thinking,
@@ -639,9 +665,63 @@ fn generate_with_paged_cache(
io::stdout().flush().unwrap();
return Finish::Stop { token_id: next };
}
if harmony_end_id == Some(next) {
// <|end|> closes current segment; if in final channel, we're done
print_stream_text(
&tokenizer.flush_decode_stream(&mut decode_buffer),
in_thinking,
use_color,
);
if hstate == HarmonyState::InFinal {
io::stdout().flush().unwrap();
return Finish::Stop { token_id: next };
}
hstate = HarmonyState::Normal;
next = pick(&logits, sampling, &history);
continue;
}
history.push(next);
// Harmony channel routing state machine
if harmony_channel_id == Some(next) {
decode_buffer.clear();
hstate = HarmonyState::ReadingChannel;
next = pick(&logits, sampling, &history);
continue;
}
if harmony_message_id == Some(next) {
if hstate == HarmonyState::ReadingChannel {
// Channel name was accumulated but we don't need to parse it —
// we just check via the channel_name buffer below
}
decode_buffer.clear();
next = pick(&logits, sampling, &history);
continue;
}
if hstate == HarmonyState::ReadingChannel {
// Reading channel name tokens (e.g. "final", "analysis")
let tok_text = tokenizer.decode(&[next]);
if tok_text.contains("final") {
hstate = HarmonyState::InFinal;
in_thinking = false;
} else {
hstate = HarmonyState::InAnalysis;
in_thinking = use_color; // render analysis as gray
}
next = pick(&logits, sampling, &history);
continue;
}
if harmony_special.contains(&next) {
next = sample(&logits, sampling);
next = pick(&logits, sampling, &history);
continue;
}
if hstate == HarmonyState::InAnalysis {
// Analysis channel: render as thinking (gray) if color enabled, skip if not
if use_color {
print_generated_token(tokenizer, next, &mut decode_buffer, &mut in_thinking, use_color);
}
next = pick(&logits, sampling, &history);
continue;
}
@@ -653,7 +733,7 @@ fn generate_with_paged_cache(
use_color,
);
io::stdout().flush().unwrap();
next = sample(&logits, sampling);
next = pick(&logits, sampling, &history);
}
print_stream_text(

View File

@@ -114,6 +114,11 @@ impl ChatTemplate {
env.add_function("strftime_now", strftime_now);
env.add_function("raise_exception", raise_exception);
// Python str methods used by harmony/gpt-oss templates.
env.add_filter("startswith", |s: String, prefix: String| -> bool {
s.starts_with(&prefix)
});
env.add_template("chat", &self.source)?;
let tmpl = env.get_template("chat")?;