Files
xserv/csrc/reduce/argmax.cu
Gahow Wang 13ae3de69e kernels: reshape_and_cache, GPU argmax, single-launch GEMV
Three new CUDA kernels and one rewrite:

- reshape_and_cache: scatter K/V into paged pool in a single kernel per
  layer, replacing the Rust-side per-token per-head cudaMemcpy loop.
  Includes both single-sequence (prefill) and batched (decode) variants.

- argmax: GPU-side BF16 argmax with warp-shuffle reduction. Greedy
  decode now only D2H-transfers B×4 bytes (token ids) instead of the
  full [B, vocab] logits tensor.

- GEMV rewrite: fused zero-init inside the K-split kernel eliminates
  the cudaMemsetAsync call, reducing launches from 3 to 2 per GEMV.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-30 12:50:17 +08:00

93 lines
3.0 KiB
Plaintext

#include <cuda_bf16.h>
#include <float.h>
#include <limits.h>
#include <math.h>
#include "../common.cuh"
// Argmax along the last dim of a contiguous row-major [rows, cols] BF16 tensor.
//
// Launch layout: grid.x = rows (one block per row, row = blockIdx.x).
// blockDim.x must be a multiple of 32, because every warp shuffle uses the
// full-warp mask. It must also be <= 1024, so at most 32 warps feed the
// 32-slot shared arrays. Shared memory: static only (32 floats + 32 ints).
// Output: out_idx[row] = int32 index of the row's maximum element.
//
// Reduction: each thread scans a strided slice and tracks the running
// (value, index) pair. A warp-shuffle reduce follows, then a single-warp
// reduce over the per-warp leaders. Tie-break: the smaller index wins, so the
// result is deterministic across launches.
//
// Values are widened to FP32 for comparison and shuffling. BF16 -> FP32 is
// exact, so this does not change the result relative to a BF16 compare.
//
// Special values:
//   - -inf entries participate normally. An all -inf row returns 0, the
//     smallest of the tied indices.
//   - NaN entries are skipped, because every comparison with NaN is false.
//   - If no element qualifies (row entirely NaN, or cols <= 0), the result
//     is clamped to 0 so callers never receive an out-of-range index.

// Ordering used at every reduction stage: a larger value wins, and ties go to
// the smaller index. Returns true when (v, i) should replace (best_v, best_i).
static __device__ __forceinline__ bool argmax_prefer(float v, int i,
                                                     float best_v, int best_i) {
    return v > best_v || (v == best_v && i < best_i);
}

// Full-warp (value, index) reduction; lane 0 ends up holding the warp's best.
// Requires all 32 lanes of the warp to be active (mask 0xffffffff).
static __device__ __forceinline__ void argmax_warp_reduce(float& v, int& i) {
    const unsigned full_mask = 0xffffffffu;
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        float other_v = __shfl_down_sync(full_mask, v, offset);
        int other_i = __shfl_down_sync(full_mask, i, offset);
        if (argmax_prefer(other_v, other_i, v, i)) {
            v = other_v;
            i = other_i;
        }
    }
}

__global__ void argmax_bf16_kernel(
    const __nv_bfloat16* __restrict__ logits,
    int* __restrict__ out_idx,
    int cols
) {
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const __nv_bfloat16* row_ptr = logits + (long long)row * cols;

    // Sentinel (-inf, INT_MAX) loses to every real non-NaN element, including
    // a real -inf, which has the same value but a smaller index.
    float best_v = -INFINITY;
    int best_i = INT_MAX;
    for (int c = tid; c < cols; c += blockDim.x) {
        float v = __bfloat162float(row_ptr[c]);
        if (argmax_prefer(v, c, best_v, best_i)) {
            best_v = v;
            best_i = c;
        }
    }

    // Stage 1: reduce within each warp.
    argmax_warp_reduce(best_v, best_i);

    // Stage 2: per-warp leaders -> shared memory -> warp 0 reduces them.
    __shared__ float s_val[32];
    __shared__ int s_idx[32];
    const int lane = tid & 31;
    const int warp_id = tid >> 5;
    const int num_warps = (blockDim.x + 31) >> 5;
    if (lane == 0) {
        s_val[warp_id] = best_v;
        s_idx[warp_id] = best_i;
    }
    __syncthreads();

    if (warp_id == 0) {
        // Slots beyond num_warps were never written; feed the sentinel instead.
        float v = (lane < num_warps) ? s_val[lane] : -INFINITY;
        int i = (lane < num_warps) ? s_idx[lane] : INT_MAX;
        argmax_warp_reduce(v, i);
        if (lane == 0) {
            // INT_MAX means nothing qualified (all-NaN row or cols <= 0).
            // Clamp so the consumer never indexes out of range.
            out_idx[row] = (i == INT_MAX) ? 0 : i;
        }
    }
}
extern "C" {
// C-ABI entry point: argmax over each row of a [rows, cols] BF16 tensor.
//   logits  : device pointer to rows * cols __nv_bfloat16 values (row-major,
//             assumed contiguous; TODO confirm with callers)
//   out_idx : device pointer receiving `rows` int32 indices
//   rows    : number of rows (grid size); rows <= 0 is a no-op
//   cols    : row length (vocab size)
//   stream  : cudaStream_t passed as void*; a null pointer selects the
//             default stream
// The launch is enqueued on `stream`. Launch errors are checked via
// CUDA_CHECK_LAST_ERROR (defined in common.cuh).
void launch_argmax_bf16(const void* logits, void* out_idx,
                        int rows, int cols, void* stream) {
    // An empty batch has nothing to write, and launching a zero-sized grid
    // is an invalid launch configuration.
    if (rows <= 0) {
        return;
    }
    // The kernel needs a multiple of 32 threads (full-warp shuffles) and at
    // most 1024 (<= 32 warps, matching its 32-slot shared arrays).
    constexpr int kThreadsPerBlock = 1024;
    argmax_bf16_kernel<<<rows, kThreadsPerBlock, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)logits, (int*)out_idx, cols);
    CUDA_CHECK_LAST_ERROR();
}
}