sampling: NaN-safe sample() top-k/top-p path

partial_cmp().unwrap() in the top-k / top-p sort and softmax paths would
panic the engine thread on a single NaN logit. The greedy argmax path
is already NaN-safe. Add a one-pass NaN → -inf sweep on the extracted
last_row before temperature scaling, which is equivalent to masking the
token and keeps the sampler deterministic. Warn once when triggered so
the underlying numeric bug isn't silently hidden.
This commit is contained in:
2026-07-01 15:13:19 +08:00
parent 5f060902f6
commit ce10e4a998

View File

@@ -40,7 +40,7 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
let logits_cpu = logits.to_device(Device::Cpu); let logits_cpu = logits.to_device(Device::Cpu);
// Extract last row as f32 // Extract last row as f32
let last_row: Vec<f32> = match logits.dtype() { let mut last_row: Vec<f32> = match logits.dtype() {
DType::F32 => { DType::F32 => {
let data = logits_cpu.as_slice::<f32>(); let data = logits_cpu.as_slice::<f32>();
data[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec() data[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec()
@@ -60,6 +60,20 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
return argmax(&last_row); return argmax(&last_row);
} }
// NaN-safe: sampling path uses partial_cmp().unwrap() in top-k/top-p
// sorts and softmax; a single NaN logit would panic the engine thread.
// Replace NaN with -inf (equivalent to masking) instead.
let mut nan_seen = false;
for v in last_row.iter_mut() {
if v.is_nan() {
nan_seen = true;
*v = f32::NEG_INFINITY;
}
}
if nan_seen {
eprintln!("[sampling] WARNING: NaN logits encountered in sample()");
}
// Apply temperature // Apply temperature
let mut logits_f32: Vec<f32> = last_row.iter().map(|v| v / params.temperature).collect(); let mut logits_f32: Vec<f32> = last_row.iter().map(|v| v / params.temperature).collect();