model/sampling: NaN-safe argmax + optional repetition penalty
Make argmax skip NaN logits (warn once) instead of panicking the engine thread on a single NaN.

Add sample_greedy_penalized(), which applies an HF-style repetition penalty over recent token ids on the greedy path, to break greedy repetition loops on reasoning models without touching the forward pass.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -15,7 +15,7 @@ pub use gpt_oss::GptOss;
|
|||||||
pub use kv_cache::GpuKVCache;
|
pub use kv_cache::GpuKVCache;
|
||||||
pub use paged_kv_cache::{BlockAllocator, Location, PagedKVCache, BLOCK_SIZE};
|
pub use paged_kv_cache::{BlockAllocator, Location, PagedKVCache, BLOCK_SIZE};
|
||||||
pub use qwen3::Qwen3;
|
pub use qwen3::Qwen3;
|
||||||
pub use sampling::{SamplingParams, sample};
|
pub use sampling::{SamplingParams, sample, sample_greedy_penalized};
|
||||||
|
|
||||||
/// Initialize GPU kernel hooks. Called automatically by model constructors,
|
/// Initialize GPU kernel hooks. Called automatically by model constructors,
|
||||||
/// but safe to call multiple times (idempotent via OnceLock).
|
/// but safe to call multiple times (idempotent via OnceLock).
|
||||||
|
|||||||
@@ -112,10 +112,51 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
|
|||||||
(vocab_size - 1) as u32
|
(vocab_size - 1) as u32
|
||||||
}
|
}
|
||||||
|
|
||||||
fn argmax(data: &[f32]) -> u32 {
|
/// Greedy argmax with a repetition penalty applied to `recent` token ids
|
||||||
data.iter()
|
/// (HF-style: divide positive logits by `penalty`, multiply negative by it).
|
||||||
.enumerate()
|
/// `penalty <= 1.0` is a no-op. Mitigates greedy repetition loops on reasoning
|
||||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
/// models without changing the forward pass. NaN-safe.
|
||||||
.map(|(i, _)| i as u32)
|
pub fn sample_greedy_penalized(logits: &Tensor, recent: &[u32], penalty: f32) -> u32 {
|
||||||
.unwrap()
|
assert_eq!(logits.ndim(), 2);
|
||||||
|
let vocab_size = logits.shape()[1];
|
||||||
|
let seq_len = logits.shape()[0];
|
||||||
|
let logits_cpu = logits.to_device(Device::Cpu);
|
||||||
|
let mut last_row: Vec<f32> = match logits.dtype() {
|
||||||
|
DType::F32 => logits_cpu.as_slice::<f32>()[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec(),
|
||||||
|
DType::BF16 => logits_cpu.as_slice::<bf16>()[(seq_len - 1) * vocab_size..seq_len * vocab_size]
|
||||||
|
.iter().map(|v| v.to_f32()).collect(),
|
||||||
|
_ => panic!("unsupported dtype for sampling: {:?}", logits.dtype()),
|
||||||
|
};
|
||||||
|
if penalty > 1.0 {
|
||||||
|
for &id in recent {
|
||||||
|
let i = id as usize;
|
||||||
|
if i < last_row.len() {
|
||||||
|
let v = last_row[i];
|
||||||
|
last_row[i] = if v > 0.0 { v / penalty } else { v * penalty };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
argmax(&last_row)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Latched to `true` the first time `argmax` reports NaN logits, so that a
/// model emitting NaNs on every decode step cannot flood stderr.
static NAN_LOGIT_WARNED: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);

/// Returns the index of the largest non-NaN value in `data`.
///
/// NaN entries are skipped rather than compared: a `partial_cmp().unwrap()`
/// based maximum panics on NaN, and a single bad logit must not take down the
/// engine thread. The first time any call encounters a NaN, one warning is
/// written to stderr; later occurrences in the same process are silent.
///
/// Ties resolve to the lowest index. Returns 0 when `data` is empty or holds
/// no value greater than negative infinity (for example, all NaN).
fn argmax(data: &[f32]) -> u32 {
    let mut best_i = 0usize;
    let mut best = f32::NEG_INFINITY;
    let mut nan_seen = false;
    for (i, &v) in data.iter().enumerate() {
        if v.is_nan() {
            nan_seen = true;
            continue;
        }
        if v > best {
            best = v;
            best_i = i;
        }
    }
    // `swap` returns the previous value, so only the first caller to flip the
    // flag prints. Relaxed ordering suffices: the flag guards nothing but the
    // log line itself.
    if nan_seen && !NAN_LOGIT_WARNED.swap(true, std::sync::atomic::Ordering::Relaxed) {
        eprintln!("[sampling] WARNING: NaN logits encountered in argmax");
    }
    // The index is bounded by the slice length (a vocabulary size at the call
    // sites visible here), so narrowing to u32 is not expected to truncate.
    best_i as u32
}
|
||||||
|
|||||||
Reference in New Issue
Block a user