diff --git a/crates/xserv-model/src/bin/xserv-chat.rs b/crates/xserv-model/src/bin/xserv-chat.rs index 2dcabc2..e206529 100644 --- a/crates/xserv-model/src/bin/xserv-chat.rs +++ b/crates/xserv-model/src/bin/xserv-chat.rs @@ -397,6 +397,7 @@ fn main() { max_new_tokens, use_color, &tp_handle, + is_moe, ); match finish { Finish::Stop { token_id } => { @@ -607,7 +608,16 @@ fn generate_with_paged_cache( max_tokens: usize, use_color: bool, tp: &Option, + is_moe: bool, ) -> Finish { + let harmony_end_id = if is_moe { tokenizer.special_token_id("<|end|>") } else { None }; + let harmony_special: Vec = if is_moe { + ["<|channel|>", "<|start|>", "<|end|>", "<|message|>", "<|return|>"] + .iter().filter_map(|s| tokenizer.special_token_id(s)).collect() + } else { + Vec::new() + }; + if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: prompt_tokens.to_vec(), slot: SLOT }); } let logits = model.forward_prefill_paged(prompt_tokens, SLOT, cache); if let Some(h) = tp { h.wait(); } @@ -620,7 +630,7 @@ fn generate_with_paged_cache( if let Some(h) = tp { h.send(TpCommand::Decode { tokens: vec![next], positions: vec![position], slots: vec![SLOT] }); } let logits = model.forward_decode_paged(&[next], &[position], &[SLOT], cache); if let Some(h) = tp { h.wait(); } - if tokenizer.is_eos(next) { + if tokenizer.is_eos(next) || harmony_end_id == Some(next) { print_stream_text( &tokenizer.flush_decode_stream(&mut decode_buffer), in_thinking, @@ -630,6 +640,11 @@ fn generate_with_paged_cache( return Finish::Stop { token_id: next }; } + if harmony_special.contains(&next) { + next = sample(&logits, sampling); + continue; + } + print_generated_token( tokenizer, next,