server: GPU argmax fast path for greedy decode

When all active sequences use temperature=0, run argmax on the GPU and
copy only the chosen token ids from device to host (D2H, ~B×4 bytes)
instead of the full [B, vocab_size] BF16 logits (~1.2 MB at B=4, Qwen3
vocab=152K). Batches that contain at least one non-greedy sequence fall
back to the existing CPU path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Gahow Wang
2026-05-30 12:50:47 +08:00
parent c679f618fd
commit 9a01c60100

View File

@@ -260,7 +260,24 @@ impl Engine {
&tokens, &positions, &slots, &mut self.paged_cache,
);
// Sample per-sequence from batched logits [B, vocab_size]
// Fast path: every active sequence is greedy → run argmax on
// the GPU and only D2H the chosen token ids (a few bytes per
// sequence) instead of the full [B, vocab_size] BF16 logits
// (~1.2 MB for B=4, Qwen3 vocab=152K).
let all_greedy = decode_indices.iter()
.all(|&i| running[i].sampling.temperature == 0.0);
if all_greedy {
let next_ids = xserv_kernels::argmax_bf16_to_host(&logits);
for (j, &i) in decode_indices.iter().enumerate() {
let next = next_ids[j];
running[i].generated_tokens.push(next);
emit_token(&self.tokenizer, &mut running[i], next);
}
} else {
// Mixed sampling: keep the CPU path for now (top-k/top-p
// sampling still runs there). Only the rows that need it
// get exercised; greedy rows could in principle reuse the
// GPU argmax but the CPU pass is short for B<=4.
let vocab_size = logits.shape()[1];
let logits_cpu = logits.to_device(xserv_tensor::Device::Cpu);
let data = logits_cpu.as_slice::<half::bf16>();
@@ -279,6 +296,7 @@ impl Engine {
emit_token(&self.tokenizer, &mut running[i], next);
}
}
}
// Step 6: Check for newly arrived requests (non-blocking)
loop {