sampling: GPU argmax fast path for greedy decode
sample() at temperature 0 copied the full [seq, 201088] BF16 logits to the host and scanned them every token (~1 ms/token). Use the Phase 15 argmax kernel (block reduction + 4-byte D2H) when logits are BF16 on GPU; bench-gpt-oss's greedy sampler likewise. Exact-tie logits may break differently than the host scan — greedy trajectories can legitimately diverge at a tie token (GSM8K unchanged). Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -19,6 +19,18 @@ impl Default for SamplingParams {
|
||||
/// Uses the last position's logits. Handles both F32 and BF16 dtypes.
|
||||
pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
// Greedy fast path: GPU argmax + 4-byte D2H instead of copying the whole
|
||||
// [seq, vocab] logits to the host and scanning it (~201k bf16/token).
|
||||
// NaN logits lose every `>` comparison in the kernel, matching the
|
||||
// NaN-safe host argmax below.
|
||||
if params.temperature == 0.0
|
||||
&& logits.dtype() == DType::BF16
|
||||
&& matches!(logits.device(), Device::Cuda(_))
|
||||
&& logits.is_contiguous()
|
||||
{
|
||||
let ids = xserv_kernels::argmax_bf16_to_host(logits);
|
||||
return *ids.last().unwrap();
|
||||
}
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
|
||||
Reference in New Issue
Block a user