From f5ec10c2c307042925f2f917e71afb2f2f4c84e4 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Wed, 1 Jul 2026 14:15:50 +0800 Subject: [PATCH] xserv-cli: expose sampling params and greedy repetition penalty Interactive REPL used to always call sample_greedy_last on both the paged and legacy KV paths, so temperature/top-k/top-p and the repetition penalty added in the sampling module were unreachable from the CLI. - flag() helper parses --max-tokens / --temperature / --top-k / --top-p / --rep-penalty / --rep-window (defaults preserve prior behavior: temperature 0, top-p 1, penalty 1, window 512). - pick_next() dispatches to sample_greedy_penalized only when temperature==0 and rep_penalty>1, otherwise to sample(). - Both Qwen3/GPT-2 paths and the GptOss paged path share the same sampler and both feed the rolling history window used for the penalty. - Prompt input now unescapes literal "\n" so multi-turn prompts can be typed on one line. --- crates/xserv-model/src/bin/xserv-cli.rs | 97 +++++++++++++++---------- 1 file changed, 57 insertions(+), 40 deletions(-) diff --git a/crates/xserv-model/src/bin/xserv-cli.rs b/crates/xserv-model/src/bin/xserv-cli.rs index fd49c82..f5fe4f9 100644 --- a/crates/xserv-model/src/bin/xserv-cli.rs +++ b/crates/xserv-model/src/bin/xserv-cli.rs @@ -1,23 +1,51 @@ use std::io::{self, Write}; use std::path::PathBuf; -use xserv_model::{BLOCK_SIZE, KVCache, ModelConfig, PagedKVCache, loader}; +use xserv_model::{ + BLOCK_SIZE, KVCache, ModelConfig, PagedKVCache, SamplingParams, loader, sample, + sample_greedy_penalized, +}; use xserv_tensor::{DType, Device}; use xserv_tokenizer::Tokenizer; +fn flag(args: &[String], name: &str, default: T) -> T { + args.iter() + .position(|a| a == name) + .and_then(|i| args.get(i + 1)) + .and_then(|s| s.parse().ok()) + .unwrap_or(default) +} + +fn pick_next( + logits: &xserv_tensor::Tensor, + sampling: &SamplingParams, + history: &[u32], + rep_penalty: f32, +) -> u32 { + if rep_penalty > 1.0 && sampling.temperature == 0.0 { + sample_greedy_penalized(logits, history, rep_penalty) + } else { + sample(logits, sampling) + } +} + fn main() { let args: Vec = std::env::args().collect(); if args.len() < 2 { - eprintln!("Usage: xserv-cli [--max-tokens N]"); + eprintln!( + "Usage: xserv-cli [--max-tokens N] [--temperature F] [--top-k N] [--top-p F] [--rep-penalty F] [--rep-window N]" + ); std::process::exit(1); } let model_dir = PathBuf::from(&args[1]); - let max_tokens: usize = args - .iter() - .position(|a| a == "--max-tokens") - .and_then(|i| args.get(i + 1)) - .and_then(|s| s.parse().ok()) - .unwrap_or(100); + let max_tokens = flag(&args, "--max-tokens", 100usize); + let sampling = SamplingParams { + temperature: flag(&args, "--temperature", 0.0f32), + top_k: flag(&args, "--top-k", 0usize), + top_p: flag(&args, "--top-p", 1.0f32), + }; + let rep_penalty = flag(&args, "--rep-penalty", 1.0f32); + let rep_window = flag(&args, "--rep-window", 512usize); xserv_cuda::device::set_device(0).unwrap(); let info = xserv_cuda::device::device_info(0).unwrap(); @@ -65,7 +93,10 @@ fn main() { }; let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json")); - eprintln!("Ready (KV cache, dtype={dtype}).\n"); + eprintln!( + "Ready (KV cache, dtype={dtype}, temperature={}, top_k={}, top_p={}, rep_penalty={}, rep_window={}).\n", + sampling.temperature, sampling.top_k, sampling.top_p, rep_penalty, rep_window + ); loop { print!("xserv> "); @@ -74,15 +105,16 @@ fn main() { if io::stdin().read_line(&mut input).unwrap() == 0 { break; } - let input = input.trim(); - if input.is_empty() { + let raw_input = input.trim(); + if raw_input.is_empty() { continue; } - if input == "quit" || input == "exit" { + if raw_input == "quit" || raw_input == "exit" { break; } + let input = raw_input.replace("\\n", "\n"); - let token_ids = tokenizer.encode(input); + let token_ids = tokenizer.encode(&input); if is_gpt_oss { // GptOss uses paged KV cache @@ -106,7 +138,9 @@ fn main() { _ => unreachable!(), }; let logits = model.forward_prefill_paged(&token_ids, slot, &mut paged_cache); - let mut next = sample_greedy_last(&logits); + let mut history = token_ids.clone(); + let start = history.len().saturating_sub(rep_window); + let mut next = pick_next(&logits, &sampling, &history[start..], rep_penalty); print!("{input}"); io::stdout().flush().unwrap(); @@ -115,6 +149,7 @@ fn main() { let text = tokenizer.decode(&[next]); print!("{text}"); io::stdout().flush().unwrap(); + history.push(next); if tokenizer.eos_token_id() == Some(next) { break; @@ -122,7 +157,8 @@ fn main() { let pos = paged_cache.seq_len(slot); let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut paged_cache); - next = sample_greedy_last(&logits); + let start = history.len().saturating_sub(rep_window); + next = pick_next(&logits, &sampling, &history[start..], rep_penalty); } println!(); paged_cache.free_sequence(slot); @@ -145,11 +181,9 @@ fn main() { Model::Qwen3(m) => m.forward_with_cache(&token_ids, &mut cache), Model::GptOss(_) => unreachable!(), }; - let mut next = match &model { - Model::GPT2(_) => xserv_model::gpt2::sample_greedy(&logits), - Model::Qwen3(_) => xserv_model::qwen3::sample_greedy(&logits), - Model::GptOss(_) => unreachable!(), - }; + let mut history = token_ids.clone(); + let start = history.len().saturating_sub(rep_window); + let mut next = pick_next(&logits, &sampling, &history[start..], rep_penalty); print!("{input}"); io::stdout().flush().unwrap(); @@ -158,6 +192,7 @@ fn main() { let text = tokenizer.decode(&[next]); print!("{text}"); io::stdout().flush().unwrap(); + history.push(next); if tokenizer.eos_token_id() == Some(next) { break; @@ -168,28 +203,10 @@ fn main() { Model::Qwen3(m) => m.forward_with_cache(&[next], &mut cache), Model::GptOss(_) => unreachable!(), }; - next = match &model { - Model::GPT2(_) => xserv_model::gpt2::sample_greedy(&logits), - Model::Qwen3(_) => xserv_model::qwen3::sample_greedy(&logits), - Model::GptOss(_) => unreachable!(), - }; + let start = history.len().saturating_sub(rep_window); + next = pick_next(&logits, &sampling, &history[start..], rep_penalty); } println!(); } } } - -fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 { - use half::bf16; - assert_eq!(logits.ndim(), 2); - let logits_cpu = logits.to_device(Device::Cpu); - let vocab_size = logits.shape()[1]; - let seq_len = logits.shape()[0]; - let data = logits_cpu.as_slice::(); - let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size]; - last.iter() - .enumerate() - .max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap()) - .map(|(i, _)| i as u32) - .unwrap() -}