diff --git a/crates/xserv-model/src/lib.rs b/crates/xserv-model/src/lib.rs index fc0d910..67c3f46 100644 --- a/crates/xserv-model/src/lib.rs +++ b/crates/xserv-model/src/lib.rs @@ -15,7 +15,7 @@ pub use gpt_oss::GptOss; pub use kv_cache::GpuKVCache; pub use paged_kv_cache::{BlockAllocator, Location, PagedKVCache, BLOCK_SIZE}; pub use qwen3::Qwen3; -pub use sampling::{SamplingParams, sample}; +pub use sampling::{SamplingParams, sample, sample_greedy_penalized}; /// Initialize GPU kernel hooks. Called automatically by model constructors, /// but safe to call multiple times (idempotent via OnceLock). diff --git a/crates/xserv-model/src/sampling.rs b/crates/xserv-model/src/sampling.rs index 97bd01a..762b62e 100644 --- a/crates/xserv-model/src/sampling.rs +++ b/crates/xserv-model/src/sampling.rs @@ -112,10 +112,51 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 { (vocab_size - 1) as u32 } -fn argmax(data: &[f32]) -> u32 { - data.iter() - .enumerate() - .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) - .map(|(i, _)| i as u32) - .unwrap() +/// Greedy argmax with a repetition penalty applied to `recent` token ids +/// (HF-style: divide positive logits by `penalty`, multiply negative by it). +/// `penalty <= 1.0` is a no-op. Mitigates greedy repetition loops on reasoning +/// models without changing the forward pass. NaN-safe. +pub fn sample_greedy_penalized(logits: &Tensor, recent: &[u32], penalty: f32) -> u32 { + assert_eq!(logits.ndim(), 2); + let vocab_size = logits.shape()[1]; + let seq_len = logits.shape()[0]; + let logits_cpu = logits.to_device(Device::Cpu); + let mut last_row: Vec = match logits.dtype() { + DType::F32 => logits_cpu.as_slice::()[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec(), + DType::BF16 => logits_cpu.as_slice::()[(seq_len - 1) * vocab_size..seq_len * vocab_size] + .iter().map(|v| v.to_f32()).collect(), + _ => panic!("unsupported dtype for sampling: {:?}", logits.dtype()), + }; + if penalty > 1.0 { + for &id in recent { + let i = id as usize; + if i < last_row.len() { + let v = last_row[i]; + last_row[i] = if v > 0.0 { v / penalty } else { v * penalty }; + } + } + } + argmax(&last_row) +} + +fn argmax(data: &[f32]) -> u32 { + // NaN-safe: a single NaN logit must not crash the engine thread (a + // partial_cmp().unwrap() panics on NaN). Skip NaNs; warn once if seen. + let mut best_i = 0usize; + let mut best = f32::NEG_INFINITY; + let mut nan_seen = false; + for (i, &v) in data.iter().enumerate() { + if v.is_nan() { + nan_seen = true; + continue; + } + if v > best { + best = v; + best_i = i; + } + } + if nan_seen { + eprintln!("[sampling] WARNING: NaN logits encountered in argmax"); + } + best_i as u32 }