xserv-cli: expose sampling params and greedy repetition penalty

Interactive REPL used to always call sample_greedy_last on both the
paged and legacy KV paths, so temperature/top-k/top-p and the repetition
penalty added in the sampling module were unreachable from the CLI.

- flag() helper parses --max-tokens / --temperature / --top-k / --top-p
  / --rep-penalty / --rep-window (defaults preserve prior behavior:
  temperature 0, top-p 1, penalty 1, window 512).
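- pick_next() dispatches to sample_greedy_penalized only when
  temperature==0 and rep_penalty>1, otherwise to sample(). Because
  sample() receives neither the history nor the penalty, --rep-penalty
  has no effect once temperature is greater than 0.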
- The legacy KV path (GPT-2/Qwen3) and the GptOss paged path now share the
  same sampler, and each feeds the rolling history window used for the penalty.
- Prompt input now unescapes literal "\n" so multi-turn prompts can be
  typed on one line.
This commit is contained in:
2026-07-01 14:15:50 +08:00
parent ce7229f4fe
commit f5ec10c2c3

View File

@@ -1,23 +1,51 @@
use std::io::{self, Write}; use std::io::{self, Write};
use std::path::PathBuf; use std::path::PathBuf;
use xserv_model::{BLOCK_SIZE, KVCache, ModelConfig, PagedKVCache, loader}; use xserv_model::{
BLOCK_SIZE, KVCache, ModelConfig, PagedKVCache, SamplingParams, loader, sample,
sample_greedy_penalized,
};
use xserv_tensor::{DType, Device}; use xserv_tensor::{DType, Device};
use xserv_tokenizer::Tokenizer; use xserv_tokenizer::Tokenizer;
/// Looks up the command-line value that follows `name` and parses it as `T`.
///
/// Only the first occurrence of `name` in `args` is considered.
///
/// Returns `default` when the flag is absent. When the flag is present but
/// its value is missing or does not parse as `T`, a warning is written to
/// stderr and `default` is returned, so a typo such as `--temperature O.7`
/// is reported instead of being silently ignored.
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
    let idx = match args.iter().position(|a| a == name) {
        Some(idx) => idx,
        // Flag not supplied at all: the default applies without any noise.
        None => return default,
    };

    match args.get(idx + 1) {
        Some(raw) => match raw.parse() {
            Ok(value) => value,
            Err(_) => {
                eprintln!("warning: invalid value {raw:?} for {name}; using default");
                default
            }
        },
        None => {
            // The flag was the last argument, so there is nothing to parse.
            eprintln!("warning: {name} expects a value; using default");
            default
        }
    }
}
/// Chooses the next token id from `logits`.
///
/// `sample_greedy_penalized` is used over `history` only when decoding is
/// greedy (`sampling.temperature == 0.0`) and `rep_penalty` is greater than
/// 1.0. Every other combination is delegated to `sample`, which is given
/// neither `history` nor `rep_penalty`.
fn pick_next(
    logits: &xserv_tensor::Tensor,
    sampling: &SamplingParams,
    history: &[u32],
    rep_penalty: f32,
) -> u32 {
    let greedy = sampling.temperature == 0.0;
    let penalize = rep_penalty > 1.0;

    // Anything other than "greedy with an active penalty" takes the general sampler.
    if !(penalize && greedy) {
        return sample(logits, sampling);
    }
    sample_greedy_penalized(logits, history, rep_penalty)
}
fn main() { fn main() {
let args: Vec<String> = std::env::args().collect(); let args: Vec<String> = std::env::args().collect();
if args.len() < 2 { if args.len() < 2 {
eprintln!("Usage: xserv-cli <model-dir> [--max-tokens N]"); eprintln!(
"Usage: xserv-cli <model-dir> [--max-tokens N] [--temperature F] [--top-k N] [--top-p F] [--rep-penalty F] [--rep-window N]"
);
std::process::exit(1); std::process::exit(1);
} }
let model_dir = PathBuf::from(&args[1]); let model_dir = PathBuf::from(&args[1]);
let max_tokens: usize = args let max_tokens = flag(&args, "--max-tokens", 100usize);
.iter() let sampling = SamplingParams {
.position(|a| a == "--max-tokens") temperature: flag(&args, "--temperature", 0.0f32),
.and_then(|i| args.get(i + 1)) top_k: flag(&args, "--top-k", 0usize),
.and_then(|s| s.parse().ok()) top_p: flag(&args, "--top-p", 1.0f32),
.unwrap_or(100); };
let rep_penalty = flag(&args, "--rep-penalty", 1.0f32);
let rep_window = flag(&args, "--rep-window", 512usize);
xserv_cuda::device::set_device(0).unwrap(); xserv_cuda::device::set_device(0).unwrap();
let info = xserv_cuda::device::device_info(0).unwrap(); let info = xserv_cuda::device::device_info(0).unwrap();
@@ -65,7 +93,10 @@ fn main() {
}; };
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json")); let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
eprintln!("Ready (KV cache, dtype={dtype}).\n"); eprintln!(
"Ready (KV cache, dtype={dtype}, temperature={}, top_k={}, top_p={}, rep_penalty={}, rep_window={}).\n",
sampling.temperature, sampling.top_k, sampling.top_p, rep_penalty, rep_window
);
loop { loop {
print!("xserv> "); print!("xserv> ");
@@ -74,15 +105,16 @@ fn main() {
if io::stdin().read_line(&mut input).unwrap() == 0 { if io::stdin().read_line(&mut input).unwrap() == 0 {
break; break;
} }
let input = input.trim(); let raw_input = input.trim();
if input.is_empty() { if raw_input.is_empty() {
continue; continue;
} }
if input == "quit" || input == "exit" { if raw_input == "quit" || raw_input == "exit" {
break; break;
} }
let input = raw_input.replace("\\n", "\n");
let token_ids = tokenizer.encode(input); let token_ids = tokenizer.encode(&input);
if is_gpt_oss { if is_gpt_oss {
// GptOss uses paged KV cache // GptOss uses paged KV cache
@@ -106,7 +138,9 @@ fn main() {
_ => unreachable!(), _ => unreachable!(),
}; };
let logits = model.forward_prefill_paged(&token_ids, slot, &mut paged_cache); let logits = model.forward_prefill_paged(&token_ids, slot, &mut paged_cache);
let mut next = sample_greedy_last(&logits); let mut history = token_ids.clone();
let start = history.len().saturating_sub(rep_window);
let mut next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
print!("{input}"); print!("{input}");
io::stdout().flush().unwrap(); io::stdout().flush().unwrap();
@@ -115,6 +149,7 @@ fn main() {
let text = tokenizer.decode(&[next]); let text = tokenizer.decode(&[next]);
print!("{text}"); print!("{text}");
io::stdout().flush().unwrap(); io::stdout().flush().unwrap();
history.push(next);
if tokenizer.eos_token_id() == Some(next) { if tokenizer.eos_token_id() == Some(next) {
break; break;
@@ -122,7 +157,8 @@ fn main() {
let pos = paged_cache.seq_len(slot); let pos = paged_cache.seq_len(slot);
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut paged_cache); let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut paged_cache);
next = sample_greedy_last(&logits); let start = history.len().saturating_sub(rep_window);
next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
} }
println!(); println!();
paged_cache.free_sequence(slot); paged_cache.free_sequence(slot);
@@ -145,11 +181,9 @@ fn main() {
Model::Qwen3(m) => m.forward_with_cache(&token_ids, &mut cache), Model::Qwen3(m) => m.forward_with_cache(&token_ids, &mut cache),
Model::GptOss(_) => unreachable!(), Model::GptOss(_) => unreachable!(),
}; };
let mut next = match &model { let mut history = token_ids.clone();
Model::GPT2(_) => xserv_model::gpt2::sample_greedy(&logits), let start = history.len().saturating_sub(rep_window);
Model::Qwen3(_) => xserv_model::qwen3::sample_greedy(&logits), let mut next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
Model::GptOss(_) => unreachable!(),
};
print!("{input}"); print!("{input}");
io::stdout().flush().unwrap(); io::stdout().flush().unwrap();
@@ -158,6 +192,7 @@ fn main() {
let text = tokenizer.decode(&[next]); let text = tokenizer.decode(&[next]);
print!("{text}"); print!("{text}");
io::stdout().flush().unwrap(); io::stdout().flush().unwrap();
history.push(next);
if tokenizer.eos_token_id() == Some(next) { if tokenizer.eos_token_id() == Some(next) {
break; break;
@@ -168,28 +203,10 @@ fn main() {
Model::Qwen3(m) => m.forward_with_cache(&[next], &mut cache), Model::Qwen3(m) => m.forward_with_cache(&[next], &mut cache),
Model::GptOss(_) => unreachable!(), Model::GptOss(_) => unreachable!(),
}; };
next = match &model { let start = history.len().saturating_sub(rep_window);
Model::GPT2(_) => xserv_model::gpt2::sample_greedy(&logits), next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
Model::Qwen3(_) => xserv_model::qwen3::sample_greedy(&logits),
Model::GptOss(_) => unreachable!(),
};
} }
println!(); println!();
} }
} }
} }
/// Greedy argmax over the final row of a 2-D logits tensor.
///
/// The tensor is copied to the CPU and read as `bf16`; the index of the
/// largest value in the last row is returned as the token id.
///
/// # Panics
///
/// Panics if `logits` is not 2-dimensional, if the last row is empty, or if
/// a comparison between two values is undefined (`partial_cmp` yields `None`,
/// as it does for NaN).
fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 {
    use half::bf16;

    assert_eq!(logits.ndim(), 2);
    let host = logits.to_device(Device::Cpu);
    let cols = logits.shape()[1];
    let rows = logits.shape()[0];
    let values = host.as_slice::<bf16>();

    // Only the final row matters for picking the next token.
    let final_row = &values[(rows - 1) * cols..rows * cols];

    final_row
        .iter()
        .enumerate()
        .max_by(|(_, lhs), (_, rhs)| lhs.to_f32().partial_cmp(&rhs.to_f32()).unwrap())
        .map(|(idx, _)| idx as u32)
        .unwrap()
}