model/sampling: NaN-safe argmax + optional repetition penalty
Make argmax skip NaN logits (warn once) instead of panicking the engine thread on a single NaN. Add sample_greedy_penalized() applying an HF-style repetition penalty over recent ids on the greedy path, to break greedy repetition loops on reasoning models without touching the forward pass. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -15,7 +15,7 @@ pub use gpt_oss::GptOss;
|
||||
pub use kv_cache::GpuKVCache;
|
||||
pub use paged_kv_cache::{BlockAllocator, Location, PagedKVCache, BLOCK_SIZE};
|
||||
pub use qwen3::Qwen3;
|
||||
pub use sampling::{SamplingParams, sample};
|
||||
pub use sampling::{SamplingParams, sample, sample_greedy_penalized};
|
||||
|
||||
/// Initialize GPU kernel hooks. Called automatically by model constructors,
|
||||
/// but safe to call multiple times (idempotent via OnceLock).
|
||||
|
||||
@@ -112,10 +112,51 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
|
||||
(vocab_size - 1) as u32
|
||||
}
|
||||
|
||||
fn argmax(data: &[f32]) -> u32 {
|
||||
data.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
||||
.map(|(i, _)| i as u32)
|
||||
.unwrap()
|
||||
/// Greedy argmax with a repetition penalty applied to `recent` token ids
|
||||
/// (HF-style: divide positive logits by `penalty`, multiply negative by it).
|
||||
/// `penalty <= 1.0` is a no-op. Mitigates greedy repetition loops on reasoning
|
||||
/// models without changing the forward pass. NaN-safe.
|
||||
pub fn sample_greedy_penalized(logits: &Tensor, recent: &[u32], penalty: f32) -> u32 {
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let mut last_row: Vec<f32> = match logits.dtype() {
|
||||
DType::F32 => logits_cpu.as_slice::<f32>()[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec(),
|
||||
DType::BF16 => logits_cpu.as_slice::<bf16>()[(seq_len - 1) * vocab_size..seq_len * vocab_size]
|
||||
.iter().map(|v| v.to_f32()).collect(),
|
||||
_ => panic!("unsupported dtype for sampling: {:?}", logits.dtype()),
|
||||
};
|
||||
if penalty > 1.0 {
|
||||
for &id in recent {
|
||||
let i = id as usize;
|
||||
if i < last_row.len() {
|
||||
let v = last_row[i];
|
||||
last_row[i] = if v > 0.0 { v / penalty } else { v * penalty };
|
||||
}
|
||||
}
|
||||
}
|
||||
argmax(&last_row)
|
||||
}
|
||||
|
||||
/// Returns the index of the largest non-NaN value in `data`.
///
/// NaN-safe: a single NaN logit must not crash the engine thread (a
/// `partial_cmp().unwrap()` panics on NaN), so NaN entries are skipped.
/// Ties resolve to the lowest index. When nothing compares greater than
/// negative infinity (empty slice, all NaN, all `-inf`), index 0 is returned.
///
/// A warning is written to stderr the first time a NaN is encountered; later
/// occurrences in the same process stay silent so a persistently NaN-producing
/// model does not flood the log with one line per decoded token.
fn argmax(data: &[f32]) -> u32 {
    // Process-wide gate so the NaN warning is emitted at most once.
    static NAN_WARNING: std::sync::Once = std::sync::Once::new();

    let mut best_i = 0usize;
    let mut best = f32::NEG_INFINITY;
    let mut nan_seen = false;

    for (i, &v) in data.iter().enumerate() {
        if v.is_nan() {
            nan_seen = true;
        } else if v > best {
            // Strict `>` keeps the first of several equal maxima.
            best = v;
            best_i = i;
        }
    }

    if nan_seen {
        NAN_WARNING.call_once(|| {
            eprintln!("[sampling] WARNING: NaN logits encountered in argmax");
        });
    }

    best_i as u32
}
|
||||
|
||||
Reference in New Issue
Block a user