xserv-chat: render gpt-oss multi-turn as canonical harmony (drop CoT)

Re-render the whole conversation each turn and re-prefill into a freshly
cleared slot, with past assistant messages rendered as completed `final`
channels (analysis dropped, terminated with <|end|> not the <|return|>
stop token) — matching the model's training format and the server's
builder. The previous incremental cache kept every turn's chain-of-thought
plus <|return|> in context, which is out of distribution for harmony
multi-turn. The generator now returns the final-channel text to feed back
as history. Qwen3 keeps the incremental cache (its ChatML format is
unaffected); reset_slot factors out the free+re-register.

NOTE: this corrects the multi-turn *format* but does NOT cure the
long-context collapse on some inputs. That is a forward-pass numerical bug
(NaN / degenerate logits) reproducible in clean bench-gpt-oss independent
of the chat layer — the collapse token is vocab_size-1 (201087), i.e. the
index argmax falls back to as its tie-break when every logit is NaN.
Tracked separately.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-02 15:39:24 +08:00
parent c0a81c84e7
commit ea5d8ba7ea

View File

@@ -328,6 +328,13 @@ fn main() {
eprintln!("Ready (paged KV cache, tp={world}).");
eprintln!("Commands: /exit, /quit, /clear\n");
// gpt-oss multi-turn history of (user, assistant-final) text. Harmony
// requires re-rendering the conversation each turn with prior analysis
// dropped, so the moe path re-prefills from this rather than reusing an
// incremental KV cache (which would accumulate CoT + <|return|> and collapse
// at longer context). Qwen3 ignores this and keeps the incremental cache.
let mut moe_history: Vec<(String, String)> = Vec::new();
loop {
let line = match read_line_edited("user> ") {
Line::Eof => break,
@@ -341,10 +348,8 @@ fn main() {
match input {
"/exit" | "/quit" | "exit" | "quit" => break,
"/clear" => {
if let Some(h) = &tp_handle { h.send(TpCommand::Free(SLOT)); h.wait(); }
cache.free_sequence(SLOT);
if let Some(h) = &tp_handle { h.send(TpCommand::Register(SLOT)); h.wait(); }
cache.register_sequence(SLOT).expect("register chat slot");
reset_slot(&mut cache, &tp_handle);
moe_history.clear();
eprintln!("history and KV cache cleared");
continue;
}
@@ -355,21 +360,47 @@ fn main() {
_ => {}
}
if is_moe {
// Harmony multi-turn: re-render the whole conversation (prior
// analysis dropped) and re-prefill into a freshly cleared slot.
let prompt = build_conversation_gpt_oss(
opts.system_prompt.as_deref(),
&moe_history,
input,
);
let prompt_tokens = tokenizer.encode(&prompt);
if prompt_tokens.is_empty() {
continue;
}
if prompt_tokens.len() >= max_seq_len {
eprintln!(
"context full: conversation needs {} tokens >= max_seq_len {max_seq_len}; use /clear",
prompt_tokens.len()
);
continue;
}
let max_new_tokens = opts.max_tokens.min(max_seq_len - prompt_tokens.len());
reset_slot(&mut cache, &tp_handle);
print!("assistant> ");
io::stdout().flush().unwrap();
let (_finish, answer) = generate_with_paged_cache(
&model, &mut cache, &tokenizer, &prompt_tokens, &opts.sampling,
max_new_tokens, use_color, &tp_handle, is_moe,
);
moe_history.push((input.to_string(), answer));
println!();
continue;
}
// Qwen3: incremental KV cache — only the new turn is prefilled and the
// assistant's tokens stay cached for the next turn.
let include_system = cache.seq_len(SLOT) == 0;
let prompt = if is_moe {
build_turn_prompt_gpt_oss(
opts.system_prompt.as_deref(),
include_system,
input,
)
} else {
build_turn_prompt(
opts.system_prompt.as_deref(),
include_system,
input,
opts.enable_thinking,
)
};
let prompt = build_turn_prompt(
opts.system_prompt.as_deref(),
include_system,
input,
opts.enable_thinking,
);
let prompt_tokens = tokenizer.encode(&prompt);
if prompt_tokens.is_empty() {
continue;
@@ -388,7 +419,7 @@ fn main() {
print!("assistant> ");
io::stdout().flush().unwrap();
let finish = generate_with_paged_cache(
let (finish, _answer) = generate_with_paged_cache(
&model,
&mut cache,
&tokenizer,
@@ -404,14 +435,21 @@ fn main() {
append_after_stop(&model, &mut cache, &tokenizer, max_seq_len, token_id, &tp_handle);
}
Finish::Length => {
let end_text = if is_moe { "<|end|>\n" } else { "<|im_end|>\n" };
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, end_text, &tp_handle);
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, "<|im_end|>\n", &tp_handle);
}
}
println!();
}
}
/// Free and re-register the single chat KV slot (clears all cached context).
fn reset_slot(cache: &mut PagedKVCache, tp: &Option<TpHandle>) {
if let Some(h) = tp { h.send(TpCommand::Free(SLOT)); h.wait(); }
cache.free_sequence(SLOT);
if let Some(h) = tp { h.send(TpCommand::Register(SLOT)); h.wait(); }
cache.register_sequence(SLOT).expect("register chat slot");
}
fn parse_args() -> CliOptions {
let args: Vec<String> = std::env::args().skip(1).collect();
if args.is_empty() || args.iter().any(|a| a == "--help" || a == "-h") {
@@ -575,37 +613,48 @@ fn build_turn_prompt(
prompt
}
fn build_turn_prompt_gpt_oss(
/// Render the full gpt-oss harmony conversation for re-prefill. gpt-oss was
/// trained on this exact system-message structure (identity / knowledge cutoff
/// / current date / Reasoning level / channels — see the model's
/// chat_template.jinja `build_system_message`). A hand-rolled substitute puts
/// the model out of distribution and destabilizes channel selection.
///
/// Harmony multi-turn drops prior chain-of-thought: past assistant messages are
/// rendered as completed `final` channels ending in `<|end|>` (not the
/// `<|return|>` stop token). Keeping the analysis + `<|return|>` of every turn
/// in context — as an incremental KV cache does — is out of distribution and
/// makes the model collapse at longer context. "Reasoning: low" keeps the
/// analysis channel short for an interactive chat.
fn build_conversation_gpt_oss(
system: Option<&str>,
include_system: bool,
user_input: &str,
history: &[(String, String)],
current_user: &str,
) -> String {
let mut prompt = String::new();
if include_system {
// Canonical harmony system message. gpt-oss was trained on this exact
// structure (identity / knowledge cutoff / current date / Reasoning
// level / channels — see the model's chat_template.jinja). A hand-rolled
// substitute puts the model out of distribution: channel selection
// destabilizes and greedy decoding flips into garbage or analysis loops
// that never reach the `final` channel. "Reasoning: low" keeps the
// analysis channel short for an interactive chat.
prompt.push_str("<|start|>system<|message|>");
prompt.push_str("You are ChatGPT, a large language model trained by OpenAI.\n");
prompt.push_str("Knowledge cutoff: 2024-06\n");
prompt.push_str(&format!("Current date: {}\n\n", today_ymd()));
prompt.push_str("Reasoning: low\n\n");
prompt.push_str("# Valid channels: analysis, commentary, final. Channel must be included for every message.");
prompt.push_str("<|end|>");
if let Some(sys) = system {
if !sys.trim().is_empty() {
prompt.push_str("<|start|>developer<|message|># Instructions\n\n");
prompt.push_str(sys.trim());
prompt.push_str("<|end|>");
}
prompt.push_str("<|start|>system<|message|>");
prompt.push_str("You are ChatGPT, a large language model trained by OpenAI.\n");
prompt.push_str("Knowledge cutoff: 2024-06\n");
prompt.push_str(&format!("Current date: {}\n\n", today_ymd()));
prompt.push_str("Reasoning: low\n\n");
prompt.push_str("# Valid channels: analysis, commentary, final. Channel must be included for every message.");
prompt.push_str("<|end|>");
if let Some(sys) = system {
if !sys.trim().is_empty() {
prompt.push_str("<|start|>developer<|message|># Instructions\n\n");
prompt.push_str(sys.trim());
prompt.push_str("<|end|>");
}
}
for (user, assistant) in history {
prompt.push_str("<|start|>user<|message|>");
prompt.push_str(user);
prompt.push_str("<|end|>");
prompt.push_str("<|start|>assistant<|channel|>final<|message|>");
prompt.push_str(assistant.trim());
prompt.push_str("<|end|>");
}
prompt.push_str("<|start|>user<|message|>");
prompt.push_str(user_input);
prompt.push_str(current_user);
prompt.push_str("<|end|>");
prompt.push_str("<|start|>assistant");
prompt
@@ -639,7 +688,7 @@ fn generate_with_paged_cache(
use_color: bool,
tp: &Option<TpHandle>,
is_moe: bool,
) -> Finish {
) -> (Finish, String) {
let harmony_end_id = if is_moe { tokenizer.special_token_id("<|end|>") } else { None };
let harmony_channel_id = if is_moe { tokenizer.special_token_id("<|channel|>") } else { None };
let harmony_message_id = if is_moe { tokenizer.special_token_id("<|message|>") } else { None };
@@ -684,6 +733,11 @@ fn generate_with_paged_cache(
let mut next = pick(&logits, sampling, &history);
let mut decode_buffer = Vec::new();
let mut in_thinking = false;
// Visible answer tokens, returned for multi-turn history. For moe this is
// the final-channel content only (analysis is suppressed/gray); for Qwen3
// it is everything printed. The caller decodes these into the assistant
// message it re-renders into the next prompt.
let mut answer_ids: Vec<u32> = Vec::new();
for _ in 0..max_tokens {
let position = cache.seq_len(SLOT);
@@ -697,7 +751,7 @@ fn generate_with_paged_cache(
use_color,
);
io::stdout().flush().unwrap();
return Finish::Stop { token_id: next };
return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids));
}
if harmony_end_id == Some(next) {
// <|end|> closes current segment; if in final channel, we're done
@@ -708,7 +762,7 @@ fn generate_with_paged_cache(
);
if hstate == HarmonyState::InFinal {
io::stdout().flush().unwrap();
return Finish::Stop { token_id: next };
return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids));
}
hstate = HarmonyState::Normal;
next = pick(&logits, sampling, &history);
@@ -767,6 +821,7 @@ fn generate_with_paged_cache(
continue;
}
answer_ids.push(next);
print_generated_token(
tokenizer,
next,
@@ -784,7 +839,7 @@ fn generate_with_paged_cache(
use_color,
);
io::stdout().flush().unwrap();
Finish::Length
(Finish::Length, tokenizer.decode(&answer_ids))
}
fn append_after_stop(