sampling: NaN-safe sample() top-k/top-p path
partial_cmp().unwrap() in the top-k / top-p sort and softmax paths would panic the engine thread on a single NaN logit. The greedy argmax path is already NaN-safe. Add a one-pass NaN → -inf sweep on the extracted last_row before temperature scaling, which is equivalent to masking the token and keeps the sampler deterministic. Warn once when triggered so the underlying numeric bug isn't silently hidden.
This commit is contained in:
@@ -40,7 +40,7 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
|
|||||||
let logits_cpu = logits.to_device(Device::Cpu);
|
let logits_cpu = logits.to_device(Device::Cpu);
|
||||||
|
|
||||||
// Extract last row as f32
|
// Extract last row as f32
|
||||||
let last_row: Vec<f32> = match logits.dtype() {
|
let mut last_row: Vec<f32> = match logits.dtype() {
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
let data = logits_cpu.as_slice::<f32>();
|
let data = logits_cpu.as_slice::<f32>();
|
||||||
data[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec()
|
data[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec()
|
||||||
@@ -60,6 +60,20 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
|
|||||||
return argmax(&last_row);
|
return argmax(&last_row);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NaN-safe: sampling path uses partial_cmp().unwrap() in top-k/top-p
|
||||||
|
// sorts and softmax; a single NaN logit would panic the engine thread.
|
||||||
|
// Replace NaN with -inf (equivalent to masking) instead.
|
||||||
|
let mut nan_seen = false;
|
||||||
|
for v in last_row.iter_mut() {
|
||||||
|
if v.is_nan() {
|
||||||
|
nan_seen = true;
|
||||||
|
*v = f32::NEG_INFINITY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if nan_seen {
|
||||||
|
eprintln!("[sampling] WARNING: NaN logits encountered in sample()");
|
||||||
|
}
|
||||||
|
|
||||||
// Apply temperature
|
// Apply temperature
|
||||||
let mut logits_f32: Vec<f32> = last_row.iter().map(|v| v / params.temperature).collect();
|
let mut logits_f32: Vec<f32> = last_row.iter().map(|v| v / params.temperature).collect();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user