sampling: NaN-safe sample() top-k/top-p path
partial_cmp().unwrap() in the top-k / top-p sort and softmax paths would panic the engine thread on a single NaN logit. The greedy argmax path is already NaN-safe. Add a one-pass NaN → -inf sweep on the extracted last_row before temperature scaling, which is equivalent to masking the token and keeps the sampler deterministic. Warn once when triggered so the underlying numeric bug isn't silently hidden.
This commit is contained in:
@@ -40,7 +40,7 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
|
||||
// Extract last row as f32
|
||||
let last_row: Vec<f32> = match logits.dtype() {
|
||||
let mut last_row: Vec<f32> = match logits.dtype() {
|
||||
DType::F32 => {
|
||||
let data = logits_cpu.as_slice::<f32>();
|
||||
data[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec()
|
||||
@@ -60,6 +60,20 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
|
||||
return argmax(&last_row);
|
||||
}
|
||||
|
||||
// NaN-safe: sampling path uses partial_cmp().unwrap() in top-k/top-p
|
||||
// sorts and softmax; a single NaN logit would panic the engine thread.
|
||||
// Replace NaN with -inf (equivalent to masking) instead.
|
||||
let mut nan_seen = false;
|
||||
for v in last_row.iter_mut() {
|
||||
if v.is_nan() {
|
||||
nan_seen = true;
|
||||
*v = f32::NEG_INFINITY;
|
||||
}
|
||||
}
|
||||
if nan_seen {
|
||||
eprintln!("[sampling] WARNING: NaN logits encountered in sample()");
|
||||
}
|
||||
|
||||
// Apply temperature
|
||||
let mut logits_f32: Vec<f32> = last_row.iter().map(|v| v / params.temperature).collect();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user