xserv-chat: render gpt-oss multi-turn as canonical harmony (drop CoT)
Re-render the whole conversation each turn and re-prefill into a freshly cleared slot, with past assistant messages rendered as completed `final` channels (analysis dropped, terminated with <|end|> not the <|return|> stop token) — matching the model's training format and the server's builder. The previous incremental cache kept every turn's chain-of-thought plus <|return|> in context, which is out of distribution for harmony multi-turn. The generator now returns the final-channel text to feed back as history. Qwen3 keeps the incremental cache (its ChatML format is unaffected); reset_slot factors out the free+re-register. NOTE: this corrects the multi-turn *format* but does NOT cure the long-context collapse on some inputs. That is a forward-pass numerical bug (NaN / degenerate logits) reproducible in clean bench-gpt-oss independent of the chat layer — the collapse token is vocab_size-1 (201087), the all-NaN argmax tie-break. Tracked separately. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -328,6 +328,13 @@ fn main() {
|
||||
eprintln!("Ready (paged KV cache, tp={world}).");
|
||||
eprintln!("Commands: /exit, /quit, /clear\n");
|
||||
|
||||
// gpt-oss multi-turn history of (user, assistant-final) text. Harmony
|
||||
// requires re-rendering the conversation each turn with prior analysis
|
||||
// dropped, so the moe path re-prefills from this rather than reusing an
|
||||
// incremental KV cache (which would accumulate CoT + <|return|> and collapse
|
||||
// at longer context). Qwen3 ignores this and keeps the incremental cache.
|
||||
let mut moe_history: Vec<(String, String)> = Vec::new();
|
||||
|
||||
loop {
|
||||
let line = match read_line_edited("user> ") {
|
||||
Line::Eof => break,
|
||||
@@ -341,10 +348,8 @@ fn main() {
|
||||
match input {
|
||||
"/exit" | "/quit" | "exit" | "quit" => break,
|
||||
"/clear" => {
|
||||
if let Some(h) = &tp_handle { h.send(TpCommand::Free(SLOT)); h.wait(); }
|
||||
cache.free_sequence(SLOT);
|
||||
if let Some(h) = &tp_handle { h.send(TpCommand::Register(SLOT)); h.wait(); }
|
||||
cache.register_sequence(SLOT).expect("register chat slot");
|
||||
reset_slot(&mut cache, &tp_handle);
|
||||
moe_history.clear();
|
||||
eprintln!("history and KV cache cleared");
|
||||
continue;
|
||||
}
|
||||
@@ -355,21 +360,47 @@ fn main() {
|
||||
_ => {}
|
||||
}
|
||||
|
||||
if is_moe {
|
||||
// Harmony multi-turn: re-render the whole conversation (prior
|
||||
// analysis dropped) and re-prefill into a freshly cleared slot.
|
||||
let prompt = build_conversation_gpt_oss(
|
||||
opts.system_prompt.as_deref(),
|
||||
&moe_history,
|
||||
input,
|
||||
);
|
||||
let prompt_tokens = tokenizer.encode(&prompt);
|
||||
if prompt_tokens.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if prompt_tokens.len() >= max_seq_len {
|
||||
eprintln!(
|
||||
"context full: conversation needs {} tokens >= max_seq_len {max_seq_len}; use /clear",
|
||||
prompt_tokens.len()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
let max_new_tokens = opts.max_tokens.min(max_seq_len - prompt_tokens.len());
|
||||
reset_slot(&mut cache, &tp_handle);
|
||||
print!("assistant> ");
|
||||
io::stdout().flush().unwrap();
|
||||
let (_finish, answer) = generate_with_paged_cache(
|
||||
&model, &mut cache, &tokenizer, &prompt_tokens, &opts.sampling,
|
||||
max_new_tokens, use_color, &tp_handle, is_moe,
|
||||
);
|
||||
moe_history.push((input.to_string(), answer));
|
||||
println!();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Qwen3: incremental KV cache — only the new turn is prefilled and the
|
||||
// assistant's tokens stay cached for the next turn.
|
||||
let include_system = cache.seq_len(SLOT) == 0;
|
||||
let prompt = if is_moe {
|
||||
build_turn_prompt_gpt_oss(
|
||||
opts.system_prompt.as_deref(),
|
||||
include_system,
|
||||
input,
|
||||
)
|
||||
} else {
|
||||
build_turn_prompt(
|
||||
opts.system_prompt.as_deref(),
|
||||
include_system,
|
||||
input,
|
||||
opts.enable_thinking,
|
||||
)
|
||||
};
|
||||
let prompt = build_turn_prompt(
|
||||
opts.system_prompt.as_deref(),
|
||||
include_system,
|
||||
input,
|
||||
opts.enable_thinking,
|
||||
);
|
||||
let prompt_tokens = tokenizer.encode(&prompt);
|
||||
if prompt_tokens.is_empty() {
|
||||
continue;
|
||||
@@ -388,7 +419,7 @@ fn main() {
|
||||
|
||||
print!("assistant> ");
|
||||
io::stdout().flush().unwrap();
|
||||
let finish = generate_with_paged_cache(
|
||||
let (finish, _answer) = generate_with_paged_cache(
|
||||
&model,
|
||||
&mut cache,
|
||||
&tokenizer,
|
||||
@@ -404,14 +435,21 @@ fn main() {
|
||||
append_after_stop(&model, &mut cache, &tokenizer, max_seq_len, token_id, &tp_handle);
|
||||
}
|
||||
Finish::Length => {
|
||||
let end_text = if is_moe { "<|end|>\n" } else { "<|im_end|>\n" };
|
||||
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, end_text, &tp_handle);
|
||||
append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, "<|im_end|>\n", &tp_handle);
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
/// Free and re-register the single chat KV slot (clears all cached context).
///
/// When `tp` holds a handle, each local cache operation is preceded by the
/// matching `TpCommand` (`Free`, then `Register`) sent through the handle and
/// followed by `h.wait()`, so the command has been waited on before the local
/// cache is touched. When `tp` is `None`, only the local `cache` is reset.
///
/// # Panics
///
/// Panics with "register chat slot" if `cache.register_sequence(SLOT)` fails.
|
||||
fn reset_slot(cache: &mut PagedKVCache, tp: &Option<TpHandle>) {
|
||||
// Release the slot: send `Free` via the TP handle (if any) and wait, then free it locally.
if let Some(h) = tp { h.send(TpCommand::Free(SLOT)); h.wait(); }
|
||||
cache.free_sequence(SLOT);
|
||||
// Re-create the slot in the same order: TP handle first, then the local cache.
if let Some(h) = tp { h.send(TpCommand::Register(SLOT)); h.wait(); }
|
||||
cache.register_sequence(SLOT).expect("register chat slot");
|
||||
}
|
||||
|
||||
fn parse_args() -> CliOptions {
|
||||
let args: Vec<String> = std::env::args().skip(1).collect();
|
||||
if args.is_empty() || args.iter().any(|a| a == "--help" || a == "-h") {
|
||||
@@ -575,37 +613,48 @@ fn build_turn_prompt(
|
||||
prompt
|
||||
}
|
||||
|
||||
fn build_turn_prompt_gpt_oss(
|
||||
/// Render the full gpt-oss harmony conversation for re-prefill. gpt-oss was
|
||||
/// trained on this exact system-message structure (identity / knowledge cutoff
|
||||
/// / current date / Reasoning level / channels — see the model's
|
||||
/// chat_template.jinja `build_system_message`). A hand-rolled substitute puts
|
||||
/// the model out of distribution and destabilizes channel selection.
|
||||
///
|
||||
/// Harmony multi-turn drops prior chain-of-thought: past assistant messages are
|
||||
/// rendered as completed `final` channels ending in `<|end|>` (not the
|
||||
/// `<|return|>` stop token). Keeping the analysis + `<|return|>` of every turn
|
||||
/// in context — as an incremental KV cache does — is out of distribution and
|
||||
/// makes the model collapse at longer context. "Reasoning: low" keeps the
|
||||
/// analysis channel short for an interactive chat.
|
||||
fn build_conversation_gpt_oss(
|
||||
system: Option<&str>,
|
||||
include_system: bool,
|
||||
user_input: &str,
|
||||
history: &[(String, String)],
|
||||
current_user: &str,
|
||||
) -> String {
|
||||
let mut prompt = String::new();
|
||||
if include_system {
|
||||
// Canonical harmony system message. gpt-oss was trained on this exact
|
||||
// structure (identity / knowledge cutoff / current date / Reasoning
|
||||
// level / channels — see the model's chat_template.jinja). A hand-rolled
|
||||
// substitute puts the model out of distribution: channel selection
|
||||
// destabilizes and greedy decoding flips into garbage or analysis loops
|
||||
// that never reach the `final` channel. "Reasoning: low" keeps the
|
||||
// analysis channel short for an interactive chat.
|
||||
prompt.push_str("<|start|>system<|message|>");
|
||||
prompt.push_str("You are ChatGPT, a large language model trained by OpenAI.\n");
|
||||
prompt.push_str("Knowledge cutoff: 2024-06\n");
|
||||
prompt.push_str(&format!("Current date: {}\n\n", today_ymd()));
|
||||
prompt.push_str("Reasoning: low\n\n");
|
||||
prompt.push_str("# Valid channels: analysis, commentary, final. Channel must be included for every message.");
|
||||
prompt.push_str("<|end|>");
|
||||
if let Some(sys) = system {
|
||||
if !sys.trim().is_empty() {
|
||||
prompt.push_str("<|start|>developer<|message|># Instructions\n\n");
|
||||
prompt.push_str(sys.trim());
|
||||
prompt.push_str("<|end|>");
|
||||
}
|
||||
prompt.push_str("<|start|>system<|message|>");
|
||||
prompt.push_str("You are ChatGPT, a large language model trained by OpenAI.\n");
|
||||
prompt.push_str("Knowledge cutoff: 2024-06\n");
|
||||
prompt.push_str(&format!("Current date: {}\n\n", today_ymd()));
|
||||
prompt.push_str("Reasoning: low\n\n");
|
||||
prompt.push_str("# Valid channels: analysis, commentary, final. Channel must be included for every message.");
|
||||
prompt.push_str("<|end|>");
|
||||
if let Some(sys) = system {
|
||||
if !sys.trim().is_empty() {
|
||||
prompt.push_str("<|start|>developer<|message|># Instructions\n\n");
|
||||
prompt.push_str(sys.trim());
|
||||
prompt.push_str("<|end|>");
|
||||
}
|
||||
}
|
||||
for (user, assistant) in history {
|
||||
prompt.push_str("<|start|>user<|message|>");
|
||||
prompt.push_str(user);
|
||||
prompt.push_str("<|end|>");
|
||||
prompt.push_str("<|start|>assistant<|channel|>final<|message|>");
|
||||
prompt.push_str(assistant.trim());
|
||||
prompt.push_str("<|end|>");
|
||||
}
|
||||
prompt.push_str("<|start|>user<|message|>");
|
||||
prompt.push_str(user_input);
|
||||
prompt.push_str(current_user);
|
||||
prompt.push_str("<|end|>");
|
||||
prompt.push_str("<|start|>assistant");
|
||||
prompt
|
||||
@@ -639,7 +688,7 @@ fn generate_with_paged_cache(
|
||||
use_color: bool,
|
||||
tp: &Option<TpHandle>,
|
||||
is_moe: bool,
|
||||
) -> Finish {
|
||||
) -> (Finish, String) {
|
||||
let harmony_end_id = if is_moe { tokenizer.special_token_id("<|end|>") } else { None };
|
||||
let harmony_channel_id = if is_moe { tokenizer.special_token_id("<|channel|>") } else { None };
|
||||
let harmony_message_id = if is_moe { tokenizer.special_token_id("<|message|>") } else { None };
|
||||
@@ -684,6 +733,11 @@ fn generate_with_paged_cache(
|
||||
let mut next = pick(&logits, sampling, &history);
|
||||
let mut decode_buffer = Vec::new();
|
||||
let mut in_thinking = false;
|
||||
// Visible answer tokens, returned for multi-turn history. For moe this is
|
||||
// the final-channel content only (analysis is suppressed/gray); for Qwen3
|
||||
// it is everything printed. The caller decodes these into the assistant
|
||||
// message it re-renders into the next prompt.
|
||||
let mut answer_ids: Vec<u32> = Vec::new();
|
||||
|
||||
for _ in 0..max_tokens {
|
||||
let position = cache.seq_len(SLOT);
|
||||
@@ -697,7 +751,7 @@ fn generate_with_paged_cache(
|
||||
use_color,
|
||||
);
|
||||
io::stdout().flush().unwrap();
|
||||
return Finish::Stop { token_id: next };
|
||||
return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids));
|
||||
}
|
||||
if harmony_end_id == Some(next) {
|
||||
// <|end|> closes current segment; if in final channel, we're done
|
||||
@@ -708,7 +762,7 @@ fn generate_with_paged_cache(
|
||||
);
|
||||
if hstate == HarmonyState::InFinal {
|
||||
io::stdout().flush().unwrap();
|
||||
return Finish::Stop { token_id: next };
|
||||
return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids));
|
||||
}
|
||||
hstate = HarmonyState::Normal;
|
||||
next = pick(&logits, sampling, &history);
|
||||
@@ -767,6 +821,7 @@ fn generate_with_paged_cache(
|
||||
continue;
|
||||
}
|
||||
|
||||
answer_ids.push(next);
|
||||
print_generated_token(
|
||||
tokenizer,
|
||||
next,
|
||||
@@ -784,7 +839,7 @@ fn generate_with_paged_cache(
|
||||
use_color,
|
||||
);
|
||||
io::stdout().flush().unwrap();
|
||||
Finish::Length
|
||||
(Finish::Length, tokenizer.decode(&answer_ids))
|
||||
}
|
||||
|
||||
fn append_after_stop(
|
||||
|
||||
Reference in New Issue
Block a user