xserv-cli: expose sampling params and greedy repetition penalty
Interactive REPL used to always call sample_greedy_last on both the paged and legacy KV paths, so temperature/top-k/top-p and the repetition penalty added in the sampling module were unreachable from the CLI. - flag() helper parses --max-tokens / --temperature / --top-k / --top-p / --rep-penalty / --rep-window (defaults preserve prior behavior: temperature 0, top-p 1, penalty 1, window 512). - pick_next() dispatches to sample_greedy_penalized only when temperature==0 and rep_penalty>1, otherwise to sample(). - Both Qwen3/GPT-2 paths and the GptOss paged path share the same sampler and both feed the rolling history window used for the penalty. - Prompt input now unescapes literal "\n" so multi-turn prompts can be typed on one line.
This commit is contained in:
@@ -1,23 +1,51 @@
|
|||||||
use std::io::{self, Write};
|
use std::io::{self, Write};
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use xserv_model::{BLOCK_SIZE, KVCache, ModelConfig, PagedKVCache, loader};
|
use xserv_model::{
|
||||||
|
BLOCK_SIZE, KVCache, ModelConfig, PagedKVCache, SamplingParams, loader, sample,
|
||||||
|
sample_greedy_penalized,
|
||||||
|
};
|
||||||
use xserv_tensor::{DType, Device};
|
use xserv_tensor::{DType, Device};
|
||||||
use xserv_tokenizer::Tokenizer;
|
use xserv_tokenizer::Tokenizer;
|
||||||
|
|
||||||
|
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
|
||||||
|
args.iter()
|
||||||
|
.position(|a| a == name)
|
||||||
|
.and_then(|i| args.get(i + 1))
|
||||||
|
.and_then(|s| s.parse().ok())
|
||||||
|
.unwrap_or(default)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn pick_next(
|
||||||
|
logits: &xserv_tensor::Tensor,
|
||||||
|
sampling: &SamplingParams,
|
||||||
|
history: &[u32],
|
||||||
|
rep_penalty: f32,
|
||||||
|
) -> u32 {
|
||||||
|
if rep_penalty > 1.0 && sampling.temperature == 0.0 {
|
||||||
|
sample_greedy_penalized(logits, history, rep_penalty)
|
||||||
|
} else {
|
||||||
|
sample(logits, sampling)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let args: Vec<String> = std::env::args().collect();
|
let args: Vec<String> = std::env::args().collect();
|
||||||
if args.len() < 2 {
|
if args.len() < 2 {
|
||||||
eprintln!("Usage: xserv-cli <model-dir> [--max-tokens N]");
|
eprintln!(
|
||||||
|
"Usage: xserv-cli <model-dir> [--max-tokens N] [--temperature F] [--top-k N] [--top-p F] [--rep-penalty F] [--rep-window N]"
|
||||||
|
);
|
||||||
std::process::exit(1);
|
std::process::exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
let model_dir = PathBuf::from(&args[1]);
|
let model_dir = PathBuf::from(&args[1]);
|
||||||
let max_tokens: usize = args
|
let max_tokens = flag(&args, "--max-tokens", 100usize);
|
||||||
.iter()
|
let sampling = SamplingParams {
|
||||||
.position(|a| a == "--max-tokens")
|
temperature: flag(&args, "--temperature", 0.0f32),
|
||||||
.and_then(|i| args.get(i + 1))
|
top_k: flag(&args, "--top-k", 0usize),
|
||||||
.and_then(|s| s.parse().ok())
|
top_p: flag(&args, "--top-p", 1.0f32),
|
||||||
.unwrap_or(100);
|
};
|
||||||
|
let rep_penalty = flag(&args, "--rep-penalty", 1.0f32);
|
||||||
|
let rep_window = flag(&args, "--rep-window", 512usize);
|
||||||
|
|
||||||
xserv_cuda::device::set_device(0).unwrap();
|
xserv_cuda::device::set_device(0).unwrap();
|
||||||
let info = xserv_cuda::device::device_info(0).unwrap();
|
let info = xserv_cuda::device::device_info(0).unwrap();
|
||||||
@@ -65,7 +93,10 @@ fn main() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||||
eprintln!("Ready (KV cache, dtype={dtype}).\n");
|
eprintln!(
|
||||||
|
"Ready (KV cache, dtype={dtype}, temperature={}, top_k={}, top_p={}, rep_penalty={}, rep_window={}).\n",
|
||||||
|
sampling.temperature, sampling.top_k, sampling.top_p, rep_penalty, rep_window
|
||||||
|
);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
print!("xserv> ");
|
print!("xserv> ");
|
||||||
@@ -74,15 +105,16 @@ fn main() {
|
|||||||
if io::stdin().read_line(&mut input).unwrap() == 0 {
|
if io::stdin().read_line(&mut input).unwrap() == 0 {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
let input = input.trim();
|
let raw_input = input.trim();
|
||||||
if input.is_empty() {
|
if raw_input.is_empty() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if input == "quit" || input == "exit" {
|
if raw_input == "quit" || raw_input == "exit" {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
let input = raw_input.replace("\\n", "\n");
|
||||||
|
|
||||||
let token_ids = tokenizer.encode(input);
|
let token_ids = tokenizer.encode(&input);
|
||||||
|
|
||||||
if is_gpt_oss {
|
if is_gpt_oss {
|
||||||
// GptOss uses paged KV cache
|
// GptOss uses paged KV cache
|
||||||
@@ -106,7 +138,9 @@ fn main() {
|
|||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
let logits = model.forward_prefill_paged(&token_ids, slot, &mut paged_cache);
|
let logits = model.forward_prefill_paged(&token_ids, slot, &mut paged_cache);
|
||||||
let mut next = sample_greedy_last(&logits);
|
let mut history = token_ids.clone();
|
||||||
|
let start = history.len().saturating_sub(rep_window);
|
||||||
|
let mut next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
|
||||||
|
|
||||||
print!("{input}");
|
print!("{input}");
|
||||||
io::stdout().flush().unwrap();
|
io::stdout().flush().unwrap();
|
||||||
@@ -115,6 +149,7 @@ fn main() {
|
|||||||
let text = tokenizer.decode(&[next]);
|
let text = tokenizer.decode(&[next]);
|
||||||
print!("{text}");
|
print!("{text}");
|
||||||
io::stdout().flush().unwrap();
|
io::stdout().flush().unwrap();
|
||||||
|
history.push(next);
|
||||||
|
|
||||||
if tokenizer.eos_token_id() == Some(next) {
|
if tokenizer.eos_token_id() == Some(next) {
|
||||||
break;
|
break;
|
||||||
@@ -122,7 +157,8 @@ fn main() {
|
|||||||
|
|
||||||
let pos = paged_cache.seq_len(slot);
|
let pos = paged_cache.seq_len(slot);
|
||||||
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut paged_cache);
|
let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut paged_cache);
|
||||||
next = sample_greedy_last(&logits);
|
let start = history.len().saturating_sub(rep_window);
|
||||||
|
next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
|
||||||
}
|
}
|
||||||
println!();
|
println!();
|
||||||
paged_cache.free_sequence(slot);
|
paged_cache.free_sequence(slot);
|
||||||
@@ -145,11 +181,9 @@ fn main() {
|
|||||||
Model::Qwen3(m) => m.forward_with_cache(&token_ids, &mut cache),
|
Model::Qwen3(m) => m.forward_with_cache(&token_ids, &mut cache),
|
||||||
Model::GptOss(_) => unreachable!(),
|
Model::GptOss(_) => unreachable!(),
|
||||||
};
|
};
|
||||||
let mut next = match &model {
|
let mut history = token_ids.clone();
|
||||||
Model::GPT2(_) => xserv_model::gpt2::sample_greedy(&logits),
|
let start = history.len().saturating_sub(rep_window);
|
||||||
Model::Qwen3(_) => xserv_model::qwen3::sample_greedy(&logits),
|
let mut next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
|
||||||
Model::GptOss(_) => unreachable!(),
|
|
||||||
};
|
|
||||||
|
|
||||||
print!("{input}");
|
print!("{input}");
|
||||||
io::stdout().flush().unwrap();
|
io::stdout().flush().unwrap();
|
||||||
@@ -158,6 +192,7 @@ fn main() {
|
|||||||
let text = tokenizer.decode(&[next]);
|
let text = tokenizer.decode(&[next]);
|
||||||
print!("{text}");
|
print!("{text}");
|
||||||
io::stdout().flush().unwrap();
|
io::stdout().flush().unwrap();
|
||||||
|
history.push(next);
|
||||||
|
|
||||||
if tokenizer.eos_token_id() == Some(next) {
|
if tokenizer.eos_token_id() == Some(next) {
|
||||||
break;
|
break;
|
||||||
@@ -168,28 +203,10 @@ fn main() {
|
|||||||
Model::Qwen3(m) => m.forward_with_cache(&[next], &mut cache),
|
Model::Qwen3(m) => m.forward_with_cache(&[next], &mut cache),
|
||||||
Model::GptOss(_) => unreachable!(),
|
Model::GptOss(_) => unreachable!(),
|
||||||
};
|
};
|
||||||
next = match &model {
|
let start = history.len().saturating_sub(rep_window);
|
||||||
Model::GPT2(_) => xserv_model::gpt2::sample_greedy(&logits),
|
next = pick_next(&logits, &sampling, &history[start..], rep_penalty);
|
||||||
Model::Qwen3(_) => xserv_model::qwen3::sample_greedy(&logits),
|
|
||||||
Model::GptOss(_) => unreachable!(),
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
println!();
|
println!();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 {
|
|
||||||
use half::bf16;
|
|
||||||
assert_eq!(logits.ndim(), 2);
|
|
||||||
let logits_cpu = logits.to_device(Device::Cpu);
|
|
||||||
let vocab_size = logits.shape()[1];
|
|
||||||
let seq_len = logits.shape()[0];
|
|
||||||
let data = logits_cpu.as_slice::<bf16>();
|
|
||||||
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
|
||||||
last.iter()
|
|
||||||
.enumerate()
|
|
||||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
|
||||||
.map(|(i, _)| i as u32)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user