model/sampling: NaN-safe argmax + optional repetition penalty

Make argmax skip NaN logits (warn once) instead of panicking the engine
thread on a single NaN. Add sample_greedy_penalized() applying an
HF-style repetition penalty over recent ids on the greedy path, to break
greedy repetition loops on reasoning models without touching the forward
pass.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-05-31 00:56:27 +08:00
parent e11f15e009
commit 99b212e6c1
2 changed files with 48 additions and 7 deletions

View File

@@ -15,7 +15,7 @@ pub use gpt_oss::GptOss;
pub use kv_cache::GpuKVCache;
pub use paged_kv_cache::{BlockAllocator, Location, PagedKVCache, BLOCK_SIZE};
pub use qwen3::Qwen3;
pub use sampling::{SamplingParams, sample};
pub use sampling::{SamplingParams, sample, sample_greedy_penalized};
/// Initialize GPU kernel hooks. Called automatically by model constructors,
/// but safe to call multiple times (idempotent via OnceLock).

View File

@@ -112,10 +112,51 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
(vocab_size - 1) as u32
}
fn argmax(data: &[f32]) -> u32 {
data.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(i, _)| i as u32)
.unwrap()
/// Greedy argmax over the last row of `logits`, with a repetition penalty
/// applied to the token ids listed in `recent`.
///
/// The penalty is HF-style: a positive logit is divided by `penalty`, a
/// non-positive logit is multiplied by it. Each distinct id in `recent` is
/// penalized exactly once, no matter how many times it occurs in the slice,
/// so a frequently repeated token is not penalized by `penalty^k`.
///
/// * `logits`  - 2-D tensor indexed as `[seq_len, vocab_size]`; only the last
///   row is sampled. F32 and BF16 are supported.
/// * `recent`  - token ids to penalize. Ids outside the vocabulary are ignored.
/// * `penalty` - values `<= 1.0` (and NaN) disable the penalty entirely.
///
/// Returns the index of the largest non-NaN logit after the penalty.
///
/// # Panics
///
/// Panics if `logits` is not 2-D, has zero rows, or has an unsupported dtype.
pub fn sample_greedy_penalized(logits: &Tensor, recent: &[u32], penalty: f32) -> u32 {
    assert_eq!(logits.ndim(), 2);
    let vocab_size = logits.shape()[1];
    let seq_len = logits.shape()[0];
    // Guard the `seq_len - 1` below: an empty tensor would otherwise underflow.
    assert!(seq_len > 0, "cannot sample from logits with zero rows");

    // Bounds of the last row in the flattened buffer.
    // NOTE(review): assumes the CPU copy is contiguous and row-major — confirm.
    let row_start = (seq_len - 1) * vocab_size;
    let row_end = row_start + vocab_size;

    let logits_cpu = logits.to_device(Device::Cpu);
    let mut last_row: Vec<f32> = match logits.dtype() {
        DType::F32 => logits_cpu.as_slice::<f32>()[row_start..row_end].to_vec(),
        DType::BF16 => logits_cpu.as_slice::<bf16>()[row_start..row_end]
            .iter()
            .map(|v| v.to_f32())
            .collect(),
        _ => panic!("unsupported dtype for sampling: {:?}", logits.dtype()),
    };

    if penalty > 1.0 {
        // Deduplicate first so that a token seen several times in `recent`
        // is penalized once rather than compounding the penalty per occurrence.
        let mut penalized_ids = recent.to_vec();
        penalized_ids.sort_unstable();
        penalized_ids.dedup();

        for id in penalized_ids {
            // `get_mut` silently skips ids that fall outside the vocabulary.
            if let Some(logit) = last_row.get_mut(id as usize) {
                *logit = if *logit > 0.0 {
                    *logit / penalty
                } else {
                    *logit * penalty
                };
            }
        }
    }

    argmax(&last_row)
}
/// Returns the index of the largest non-NaN value in `data`.
///
/// NaN entries are skipped rather than compared, because ordering a NaN via
/// `partial_cmp().unwrap()` would panic and take the engine thread down.
/// Ties resolve to the earliest index. If no element is strictly greater than
/// negative infinity (empty slice, all NaN, all `-inf`), index 0 is returned.
/// A warning is written to stderr on every call that encountered a NaN.
fn argmax(data: &[f32]) -> u32 {
    let mut saw_nan = false;
    let (winner, _) = data.iter().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(top_idx, top_val), (idx, &logit)| {
            if logit.is_nan() {
                // Remember that we skipped something, but keep the current best.
                saw_nan = true;
                (top_idx, top_val)
            } else if logit > top_val {
                // Strict comparison: the first maximum wins on ties.
                (idx, logit)
            } else {
                (top_idx, top_val)
            }
        },
    );

    if saw_nan {
        eprintln!("[sampling] WARNING: NaN logits encountered in argmax");
    }
    winner as u32
}