server: GPU argmax fast path for greedy decode

When all active sequences use temperature=0, run argmax on the GPU and
copy only the resulting token ids device-to-host (~B×4 bytes) instead of
the full [B, vocab_size] BF16 logits (~1.2 MB at B=4 with Qwen3's ~152K
vocabulary). Batches that mix greedy and non-greedy sampling fall back to
the existing CPU path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Gahow Wang
2026-05-30 12:50:47 +08:00
parent c679f618fd
commit 9a01c60100

View File

@@ -260,23 +260,41 @@ impl Engine {
&tokens, &positions, &slots, &mut self.paged_cache,
);
// Sample per-sequence from batched logits [B, vocab_size]
let vocab_size = logits.shape()[1];
let logits_cpu = logits.to_device(xserv_tensor::Device::Cpu);
let data = logits_cpu.as_slice::<half::bf16>();
for (j, &i) in decode_indices.iter().enumerate() {
let row_start = j * vocab_size;
let row_logits = &data[row_start..row_start + vocab_size];
let next = if running[i].sampling.temperature == 0.0 {
row_logits.iter().enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(idx, _)| idx as u32).unwrap()
} else {
let row_tensor = xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
sample(&row_tensor, &running[i].sampling)
};
running[i].generated_tokens.push(next);
emit_token(&self.tokenizer, &mut running[i], next);
// Fast path: every active sequence is greedy → run argmax on
// the GPU and only D2H the chosen token ids (a few bytes per
// sequence) instead of the full [B, vocab_size] BF16 logits
// (~1.2 MB for B=4, Qwen3 vocab=152K).
let all_greedy = decode_indices.iter()
.all(|&i| running[i].sampling.temperature == 0.0);
if all_greedy {
let next_ids = xserv_kernels::argmax_bf16_to_host(&logits);
for (j, &i) in decode_indices.iter().enumerate() {
let next = next_ids[j];
running[i].generated_tokens.push(next);
emit_token(&self.tokenizer, &mut running[i], next);
}
} else {
// Mixed sampling: keep the CPU path for now (top-k/top-p
// sampling still runs there). Only the rows that need it
// get exercised; greedy rows could in principle reuse the
// GPU argmax but the CPU pass is short for B<=4.
let vocab_size = logits.shape()[1];
let logits_cpu = logits.to_device(xserv_tensor::Device::Cpu);
let data = logits_cpu.as_slice::<half::bf16>();
for (j, &i) in decode_indices.iter().enumerate() {
let row_start = j * vocab_size;
let row_logits = &data[row_start..row_start + vocab_size];
let next = if running[i].sampling.temperature == 0.0 {
row_logits.iter().enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(idx, _)| idx as u32).unwrap()
} else {
let row_tensor = xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
sample(&row_tensor, &running[i].sampling)
};
running[i].generated_tokens.push(next);
emit_token(&self.tokenizer, &mut running[i], next);
}
}
}