server: GPU argmax fast path for greedy decode
When all active sequences use temperature=0, run argmax on the GPU and copy only the chosen token ids back to the host (D2H, ~B×4 bytes) instead of the full [B, vocab_size] BF16 logits (~1.2 MB at B=4 with Qwen3's ~152K vocab). Batches that contain any non-greedy sequence fall back to the existing CPU sampling path. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -260,7 +260,24 @@ impl Engine {
|
||||
&tokens, &positions, &slots, &mut self.paged_cache,
|
||||
);
|
||||
|
||||
// Sample per-sequence from batched logits [B, vocab_size]
|
||||
// Fast path: every active sequence is greedy → run argmax on
|
||||
// the GPU and only D2H the chosen token ids (a few bytes per
|
||||
// sequence) instead of the full [B, vocab_size] BF16 logits
|
||||
// (~1.2 MB for B=4, Qwen3 vocab=152K).
|
||||
let all_greedy = decode_indices.iter()
|
||||
.all(|&i| running[i].sampling.temperature == 0.0);
|
||||
if all_greedy {
|
||||
let next_ids = xserv_kernels::argmax_bf16_to_host(&logits);
|
||||
for (j, &i) in decode_indices.iter().enumerate() {
|
||||
let next = next_ids[j];
|
||||
running[i].generated_tokens.push(next);
|
||||
emit_token(&self.tokenizer, &mut running[i], next);
|
||||
}
|
||||
} else {
|
||||
// Mixed sampling: keep the CPU path for now (top-k/top-p
|
||||
// sampling still runs there). Only the rows that need it
|
||||
// get exercised; greedy rows could in principle reuse the
|
||||
// GPU argmax but the CPU pass is short for B<=4.
|
||||
let vocab_size = logits.shape()[1];
|
||||
let logits_cpu = logits.to_device(xserv_tensor::Device::Cpu);
|
||||
let data = logits_cpu.as_slice::<half::bf16>();
|
||||
@@ -279,6 +296,7 @@ impl Engine {
|
||||
emit_token(&self.tokenizer, &mut running[i], next);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 6: Check for newly arrived requests (non-blocking)
|
||||
loop {
|
||||
|
||||
Reference in New Issue
Block a user