From ea5d8ba7ea1bc35af69c570860d7acfb9d2c9aff Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Tue, 2 Jun 2026 15:39:24 +0800 Subject: [PATCH] xserv-chat: render gpt-oss multi-turn as canonical harmony (drop CoT) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Re-render the whole conversation each turn and re-prefill into a freshly cleared slot, with past assistant messages rendered as completed `final` channels (analysis dropped, terminated with <|end|> not the <|return|> stop token) — matching the model's training format and the server's builder. The previous incremental cache kept every turn's chain-of-thought plus <|return|> in context, which is out of distribution for harmony multi-turn. The generator now returns the final-channel text to feed back as history. Qwen3 keeps the incremental cache (its ChatML format is unaffected); reset_slot factors out the free+re-register. NOTE: this corrects the multi-turn *format* but does NOT cure the long-context collapse on some inputs. That is a forward-pass numerical bug (NaN / degenerate logits) reproducible in clean bench-gpt-oss independent of the chat layer — the collapse token is vocab_size-1 (201087), the all-NaN argmax tie-break. Tracked separately. Co-Authored-By: Claude Opus 4.8 --- crates/xserv-model/src/bin/xserv-chat.rs | 155 +++++++++++++++-------- 1 file changed, 105 insertions(+), 50 deletions(-) diff --git a/crates/xserv-model/src/bin/xserv-chat.rs b/crates/xserv-model/src/bin/xserv-chat.rs index 0e2c700..4dbe105 100644 --- a/crates/xserv-model/src/bin/xserv-chat.rs +++ b/crates/xserv-model/src/bin/xserv-chat.rs @@ -328,6 +328,13 @@ fn main() { eprintln!("Ready (paged KV cache, tp={world})."); eprintln!("Commands: /exit, /quit, /clear\n"); + // gpt-oss multi-turn history of (user, assistant-final) text. Harmony + // requires re-rendering the conversation each turn with prior analysis + // dropped, so the moe path re-prefills from this rather than reusing an + // incremental KV cache (which would accumulate CoT + <|return|> and collapse + // at longer context). Qwen3 ignores this and keeps the incremental cache. + let mut moe_history: Vec<(String, String)> = Vec::new(); + loop { let line = match read_line_edited("user> ") { Line::Eof => break, @@ -341,10 +348,8 @@ fn main() { match input { "/exit" | "/quit" | "exit" | "quit" => break, "/clear" => { - if let Some(h) = &tp_handle { h.send(TpCommand::Free(SLOT)); h.wait(); } - cache.free_sequence(SLOT); - if let Some(h) = &tp_handle { h.send(TpCommand::Register(SLOT)); h.wait(); } - cache.register_sequence(SLOT).expect("register chat slot"); + reset_slot(&mut cache, &tp_handle); + moe_history.clear(); eprintln!("history and KV cache cleared"); continue; } @@ -355,21 +360,47 @@ fn main() { _ => {} } + if is_moe { + // Harmony multi-turn: re-render the whole conversation (prior + // analysis dropped) and re-prefill into a freshly cleared slot. + let prompt = build_conversation_gpt_oss( + opts.system_prompt.as_deref(), + &moe_history, + input, + ); + let prompt_tokens = tokenizer.encode(&prompt); + if prompt_tokens.is_empty() { + continue; + } + if prompt_tokens.len() >= max_seq_len { + eprintln!( + "context full: conversation needs {} tokens >= max_seq_len {max_seq_len}; use /clear", + prompt_tokens.len() + ); + continue; + } + let max_new_tokens = opts.max_tokens.min(max_seq_len - prompt_tokens.len()); + reset_slot(&mut cache, &tp_handle); + print!("assistant> "); + io::stdout().flush().unwrap(); + let (_finish, answer) = generate_with_paged_cache( + &model, &mut cache, &tokenizer, &prompt_tokens, &opts.sampling, + max_new_tokens, use_color, &tp_handle, is_moe, + ); + moe_history.push((input.to_string(), answer)); + println!(); + continue; + } + + // Qwen3: incremental KV cache — only the new turn is prefilled and the + // assistant's tokens stay cached for the next turn. let include_system = cache.seq_len(SLOT) == 0; - let prompt = if is_moe { - build_turn_prompt_gpt_oss( - opts.system_prompt.as_deref(), - include_system, - input, - ) - } else { - build_turn_prompt( - opts.system_prompt.as_deref(), - include_system, - input, - opts.enable_thinking, - ) - }; + let prompt = build_turn_prompt( + opts.system_prompt.as_deref(), + include_system, + input, + opts.enable_thinking, + ); let prompt_tokens = tokenizer.encode(&prompt); if prompt_tokens.is_empty() { continue; @@ -388,7 +419,7 @@ fn main() { print!("assistant> "); io::stdout().flush().unwrap(); - let finish = generate_with_paged_cache( + let (finish, _answer) = generate_with_paged_cache( &model, &mut cache, &tokenizer, @@ -404,14 +435,21 @@ fn main() { append_after_stop(&model, &mut cache, &tokenizer, max_seq_len, token_id, &tp_handle); } Finish::Length => { - let end_text = if is_moe { "<|end|>\n" } else { "<|im_end|>\n" }; - append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, end_text, &tp_handle); + append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, "<|im_end|>\n", &tp_handle); } } println!(); } } +/// Free and re-register the single chat KV slot (clears all cached context). +fn reset_slot(cache: &mut PagedKVCache, tp: &Option) { + if let Some(h) = tp { h.send(TpCommand::Free(SLOT)); h.wait(); } + cache.free_sequence(SLOT); + if let Some(h) = tp { h.send(TpCommand::Register(SLOT)); h.wait(); } + cache.register_sequence(SLOT).expect("register chat slot"); +} + fn parse_args() -> CliOptions { let args: Vec = std::env::args().skip(1).collect(); if args.is_empty() || args.iter().any(|a| a == "--help" || a == "-h") { @@ -575,37 +613,48 @@ fn build_turn_prompt( prompt } -fn build_turn_prompt_gpt_oss( +/// Render the full gpt-oss harmony conversation for re-prefill. gpt-oss was +/// trained on this exact system-message structure (identity / knowledge cutoff +/// / current date / Reasoning level / channels — see the model's +/// chat_template.jinja `build_system_message`). A hand-rolled substitute puts +/// the model out of distribution and destabilizes channel selection. +/// +/// Harmony multi-turn drops prior chain-of-thought: past assistant messages are +/// rendered as completed `final` channels ending in `<|end|>` (not the +/// `<|return|>` stop token). Keeping the analysis + `<|return|>` of every turn +/// in context — as an incremental KV cache does — is out of distribution and +/// makes the model collapse at longer context. "Reasoning: low" keeps the +/// analysis channel short for an interactive chat. +fn build_conversation_gpt_oss( system: Option<&str>, - include_system: bool, - user_input: &str, + history: &[(String, String)], + current_user: &str, ) -> String { let mut prompt = String::new(); - if include_system { - // Canonical harmony system message. gpt-oss was trained on this exact - // structure (identity / knowledge cutoff / current date / Reasoning - // level / channels — see the model's chat_template.jinja). A hand-rolled - // substitute puts the model out of distribution: channel selection - // destabilizes and greedy decoding flips into garbage or analysis loops - // that never reach the `final` channel. "Reasoning: low" keeps the - // analysis channel short for an interactive chat. - prompt.push_str("<|start|>system<|message|>"); - prompt.push_str("You are ChatGPT, a large language model trained by OpenAI.\n"); - prompt.push_str("Knowledge cutoff: 2024-06\n"); - prompt.push_str(&format!("Current date: {}\n\n", today_ymd())); - prompt.push_str("Reasoning: low\n\n"); - prompt.push_str("# Valid channels: analysis, commentary, final. Channel must be included for every message."); - prompt.push_str("<|end|>"); - if let Some(sys) = system { - if !sys.trim().is_empty() { - prompt.push_str("<|start|>developer<|message|># Instructions\n\n"); - prompt.push_str(sys.trim()); - prompt.push_str("<|end|>"); - } + prompt.push_str("<|start|>system<|message|>"); + prompt.push_str("You are ChatGPT, a large language model trained by OpenAI.\n"); + prompt.push_str("Knowledge cutoff: 2024-06\n"); + prompt.push_str(&format!("Current date: {}\n\n", today_ymd())); + prompt.push_str("Reasoning: low\n\n"); + prompt.push_str("# Valid channels: analysis, commentary, final. Channel must be included for every message."); + prompt.push_str("<|end|>"); + if let Some(sys) = system { + if !sys.trim().is_empty() { + prompt.push_str("<|start|>developer<|message|># Instructions\n\n"); + prompt.push_str(sys.trim()); + prompt.push_str("<|end|>"); } } + for (user, assistant) in history { + prompt.push_str("<|start|>user<|message|>"); + prompt.push_str(user); + prompt.push_str("<|end|>"); + prompt.push_str("<|start|>assistant<|channel|>final<|message|>"); + prompt.push_str(assistant.trim()); + prompt.push_str("<|end|>"); + } prompt.push_str("<|start|>user<|message|>"); - prompt.push_str(user_input); + prompt.push_str(current_user); prompt.push_str("<|end|>"); prompt.push_str("<|start|>assistant"); prompt @@ -639,7 +688,7 @@ fn generate_with_paged_cache( use_color: bool, tp: &Option, is_moe: bool, -) -> Finish { +) -> (Finish, String) { let harmony_end_id = if is_moe { tokenizer.special_token_id("<|end|>") } else { None }; let harmony_channel_id = if is_moe { tokenizer.special_token_id("<|channel|>") } else { None }; let harmony_message_id = if is_moe { tokenizer.special_token_id("<|message|>") } else { None }; @@ -684,6 +733,11 @@ fn generate_with_paged_cache( let mut next = pick(&logits, sampling, &history); let mut decode_buffer = Vec::new(); let mut in_thinking = false; + // Visible answer tokens, returned for multi-turn history. For moe this is + // the final-channel content only (analysis is suppressed/gray); for Qwen3 + // it is everything printed. The caller decodes these into the assistant + // message it re-renders into the next prompt. + let mut answer_ids: Vec = Vec::new(); for _ in 0..max_tokens { let position = cache.seq_len(SLOT); @@ -697,7 +751,7 @@ fn generate_with_paged_cache( use_color, ); io::stdout().flush().unwrap(); - return Finish::Stop { token_id: next }; + return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids)); } if harmony_end_id == Some(next) { // <|end|> closes current segment; if in final channel, we're done @@ -708,7 +762,7 @@ fn generate_with_paged_cache( ); if hstate == HarmonyState::InFinal { io::stdout().flush().unwrap(); - return Finish::Stop { token_id: next }; + return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids)); } hstate = HarmonyState::Normal; next = pick(&logits, sampling, &history); @@ -767,6 +821,7 @@ fn generate_with_paged_cache( continue; } + answer_ids.push(next); print_generated_token( tokenizer, next, @@ -784,7 +839,7 @@ fn generate_with_paged_cache( use_color, ); io::stdout().flush().unwrap(); - Finish::Length + (Finish::Length, tokenizer.decode(&answer_ids)) } fn append_after_stop(