From 8414f8d1e6ec7d144878e20d8c344e427caf8c27 Mon Sep 17 00:00:00 2001
From: Gahow Wang
Date: Fri, 12 Jun 2026 20:12:37 +0800
Subject: [PATCH] sampling: GPU argmax fast path for greedy decode
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

sample() at temperature 0 copied the full [seq, 201088] BF16 logits to
the host and scanned them every token (~1 ms/token). Use the Phase 15
argmax kernel (block reduction + 4-byte D2H) when logits are BF16 on
GPU; bench-gpt-oss's greedy sampler likewise. Exact-tie logits may break
differently than the host scan — greedy trajectories can legitimately
diverge at a tie token (GSM8K unchanged).

Co-Authored-By: Claude Fable 5
---
 crates/xserv-model/src/sampling.rs | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/crates/xserv-model/src/sampling.rs b/crates/xserv-model/src/sampling.rs
index 762b62e..efa2c4c 100644
--- a/crates/xserv-model/src/sampling.rs
+++ b/crates/xserv-model/src/sampling.rs
@@ -19,6 +19,18 @@ impl Default for SamplingParams {
 /// Uses the last position's logits. Handles both F32 and BF16 dtypes.
 pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
     assert_eq!(logits.ndim(), 2);
+    // Greedy fast path: GPU argmax + 4-byte D2H instead of copying the whole
+    // [seq, vocab] logits to the host and scanning it (~201k bf16/token).
+    // NaN logits lose every `>` comparison in the kernel, matching the
+    // NaN-safe host argmax below.
+    if params.temperature == 0.0
+        && logits.dtype() == DType::BF16
+        && matches!(logits.device(), Device::Cuda(_))
+        && logits.is_contiguous()
+    {
+        let ids = xserv_kernels::argmax_bf16_to_host(logits);
+        return *ids.last().unwrap();
+    }
     let vocab_size = logits.shape()[1];
     let seq_len = logits.shape()[0];
     let logits_cpu = logits.to_device(Device::Cpu);