From ce10e4a998d96784770623e52ab676b84462f2f8 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Wed, 1 Jul 2026 15:13:19 +0800 Subject: [PATCH] sampling: NaN-safe sample() top-k/top-p path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit partial_cmp().unwrap() in the top-k / top-p sort and softmax paths would panic the engine thread on a single NaN logit. The greedy argmax path is already NaN-safe. Add a one-pass NaN → -inf sweep on the extracted last_row before temperature scaling, which is equivalent to masking the token and keeps the sampler deterministic. Warn once when triggered so the underlying numeric bug isn't silently hidden. --- crates/xserv-model/src/sampling.rs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/crates/xserv-model/src/sampling.rs b/crates/xserv-model/src/sampling.rs index 5b0a030..2751b34 100644 --- a/crates/xserv-model/src/sampling.rs +++ b/crates/xserv-model/src/sampling.rs @@ -40,7 +40,7 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 { let logits_cpu = logits.to_device(Device::Cpu); // Extract last row as f32 - let last_row: Vec = match logits.dtype() { + let mut last_row: Vec = match logits.dtype() { DType::F32 => { let data = logits_cpu.as_slice::(); data[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec() @@ -60,6 +60,20 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 { return argmax(&last_row); } + // NaN-safe: sampling path uses partial_cmp().unwrap() in top-k/top-p + // sorts and softmax; a single NaN logit would panic the engine thread. + // Replace NaN with -inf (equivalent to masking) instead. + let mut nan_seen = false; + for v in last_row.iter_mut() { + if v.is_nan() { + nan_seen = true; + *v = f32::NEG_INFINITY; + } + } + if nan_seen { + eprintln!("[sampling] WARNING: NaN logits encountered in sample()"); + } + // Apply temperature let mut logits_f32: Vec = last_row.iter().map(|v| v / params.temperature).collect();