diff --git a/crates/xserv-model/src/sampling.rs b/crates/xserv-model/src/sampling.rs index 5b0a030..2751b34 100644 --- a/crates/xserv-model/src/sampling.rs +++ b/crates/xserv-model/src/sampling.rs @@ -40,7 +40,7 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 { let logits_cpu = logits.to_device(Device::Cpu); // Extract last row as f32 - let last_row: Vec = match logits.dtype() { + let mut last_row: Vec = match logits.dtype() { DType::F32 => { let data = logits_cpu.as_slice::(); data[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec() @@ -60,6 +60,20 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 { return argmax(&last_row); } + // NaN-safe: sampling path uses partial_cmp().unwrap() in top-k/top-p + // sorts and softmax; a single NaN logit would panic the engine thread. + // Replace NaN with -inf (equivalent to masking) instead. + let mut nan_seen = false; + for v in last_row.iter_mut() { + if v.is_nan() { + nan_seen = true; + *v = f32::NEG_INFINITY; + } + } + if nan_seen { + eprintln!("[sampling] WARNING: NaN logits encountered in sample()"); + } + // Apply temperature let mut logits_f32: Vec = last_row.iter().map(|v| v / params.temperature).collect();