sampling: GPU argmax fast path for greedy decode

At temperature 0, sample() copied the full [seq, 201088] BF16 logits
to the host and scanned them on every token (~1 ms/token). Use the
Phase 15 argmax kernel (block reduction + 4-byte D2H) instead when the
logits are contiguous BF16 on the GPU; bench-gpt-oss's greedy sampler
does the same. Exactly tied logits may resolve to a different index
than the host scan, so greedy trajectories can legitimately diverge at
a tied token (GSM8K unchanged).
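
The tie caveat can be illustrated with a minimal sketch. It assumes the
host scan keeps the first maximum and that a block reduction merges
per-chunk winners in some other order. The kernel's actual tie rule is
not shown in this commit, so both functions below are hypothetical and
the "later chunk wins" merge is chosen only to make the effect visible.

```rust
/// Sequential host-style scan: strict `>` keeps the first maximum,
/// and NaN never wins because every comparison against NaN is false.
fn argmax_first(xs: &[f32]) -> usize {
    let mut best = 0;
    let mut best_val = f32::NEG_INFINITY;
    for (i, &x) in xs.iter().enumerate() {
        if x > best_val {
            best = i;
            best_val = x;
        }
    }
    best
}

/// Hypothetical chunked reduction: each chunk finds its own winner, then
/// winners are merged with `>=`, so a later chunk wins an exact tie.
/// This merge rule is an assumption for illustration, not the kernel's.
/// `chunk` must be non-zero.
fn argmax_chunked_last_wins(xs: &[f32], chunk: usize) -> usize {
    let mut best = 0;
    let mut best_val = f32::NEG_INFINITY;
    for (c, part) in xs.chunks(chunk).enumerate() {
        let local = argmax_first(part);
        if part[local] >= best_val {
            best = c * chunk + local;
            best_val = part[local];
        }
    }
    best
}
```

For `[1.0, 3.0, 2.0, 3.0]` with a chunk size of 2, the two functions
return different indices that hold the same value. That is the kind of
divergence described above: both answers are valid greedy choices.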

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 20:12:37 +08:00
parent 34224c7c93
commit 8414f8d1e6

@@ -19,6 +19,18 @@ impl Default for SamplingParams {
 /// Uses the last position's logits. Handles both F32 and BF16 dtypes.
 pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
     assert_eq!(logits.ndim(), 2);
+    // Greedy fast path: GPU argmax + 4-byte D2H instead of copying the whole
+    // [seq, vocab] logits to the host and scanning it (~201k bf16/token).
+    // NaN logits lose every `>` comparison in the kernel, matching the
+    // NaN-safe host argmax below.
+    if params.temperature == 0.0
+        && logits.dtype() == DType::BF16
+        && matches!(logits.device(), Device::Cuda(_))
+        && logits.is_contiguous()
+    {
+        let ids = xserv_kernels::argmax_bf16_to_host(logits);
+        return *ids.last().unwrap();
+    }
     let vocab_size = logits.shape()[1];
     let seq_len = logits.shape()[0];
     let logits_cpu = logits.to_device(Device::Cpu);
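
A host-side reference can help when checking the fast path against the
fallback. The rest of the host argmax is not visible in this hunk, so
what follows is only a sketch under assumptions: logits arrive as raw
BF16 bit patterns in a row-major [seq, vocab] buffer, and all three
function names are hypothetical rather than part of xserv.

```rust
/// Decode raw BF16 bits: BF16 is the upper 16 bits of an IEEE-754 f32.
fn bf16_bits_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Argmax over the last row of a row-major [seq_len, vocab_size] buffer.
/// Strict `>` keeps the first maximum and lets NaN lose every comparison.
fn host_argmax_last_row(logits: &[u16], seq_len: usize, vocab_size: usize) -> u32 {
    assert!(seq_len > 0 && vocab_size > 0);
    assert_eq!(logits.len(), seq_len * vocab_size);
    let row = &logits[(seq_len - 1) * vocab_size..];
    let mut best = 0u32;
    let mut best_val = f32::NEG_INFINITY;
    for (i, &bits) in row.iter().enumerate() {
        let x = bf16_bits_to_f32(bits);
        if x > best_val {
            best = i as u32;
            best_val = x;
        }
    }
    best
}

/// Bytes moved device-to-host when one full BF16 row is copied.
fn bf16_row_bytes(vocab_size: usize) -> usize {
    vocab_size * std::mem::size_of::<u16>()
}
```

With the vocabulary size from the commit message, `bf16_row_bytes(201088)`
is 402176 bytes per row, which is the copy the 4-byte transfer replaces.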