sampling: GPU argmax fast path for greedy decode
sample() at temperature 0 copied the full [seq, 201088] BF16 logits to the host and scanned them every token (~1 ms/token). Use the Phase 15 argmax kernel (block reduction + 4-byte D2H) when logits are BF16 on GPU; bench-gpt-oss's greedy sampler likewise. Exact-tie logits may break differently than the host scan — greedy trajectories can legitimately diverge at a tie token (GSM8K unchanged). Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -19,6 +19,18 @@ impl Default for SamplingParams {
|
|||||||
/// Uses the last position's logits. Handles both F32 and BF16 dtypes.
|
/// Uses the last position's logits. Handles both F32 and BF16 dtypes.
|
||||||
pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
|
pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
|
||||||
assert_eq!(logits.ndim(), 2);
|
assert_eq!(logits.ndim(), 2);
|
||||||
|
// Greedy fast path: GPU argmax + 4-byte D2H instead of copying the whole
|
||||||
|
// [seq, vocab] logits to the host and scanning it (~201k bf16/token).
|
||||||
|
// NaN logits lose every `>` comparison in the kernel, matching the
|
||||||
|
// NaN-safe host argmax below.
|
||||||
|
if params.temperature == 0.0
|
||||||
|
&& logits.dtype() == DType::BF16
|
||||||
|
&& matches!(logits.device(), Device::Cuda(_))
|
||||||
|
&& logits.is_contiguous()
|
||||||
|
{
|
||||||
|
let ids = xserv_kernels::argmax_bf16_to_host(logits);
|
||||||
|
return *ids.last().unwrap();
|
||||||
|
}
|
||||||
let vocab_size = logits.shape()[1];
|
let vocab_size = logits.shape()[1];
|
||||||
let seq_len = logits.shape()[0];
|
let seq_len = logits.shape()[0];
|
||||||
let logits_cpu = logits.to_device(Device::Cpu);
|
let logits_cpu = logits.to_device(Device::Cpu);
|
||||||
|
|||||||
Reference in New Issue
Block a user