Three new CUDA kernels and one rewrite:

- reshape_and_cache: scatter K/V into the paged pool in a single kernel per layer, replacing the Rust-side per-token, per-head cudaMemcpy loop. Includes both single-sequence (prefill) and batched (decode) variants.
- argmax: GPU-side BF16 argmax with warp-shuffle reduction. Greedy decode now only D2H-transfers B×4 bytes (token ids) instead of the full [B, vocab] logits tensor.
- GEMV rewrite: fused zero-init inside the K-split kernel eliminates the cudaMemsetAsync call, reducing launches from 3 to 2 per GEMV.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
93 lines
3.0 KiB
Plaintext
93 lines
3.0 KiB
Plaintext
#include <cuda_bf16.h>
|
|
#include <float.h>
|
|
#include "../common.cuh"
|
|
|
|
// Argmax along the last dim of a row-major [rows, cols] BF16 tensor.
//
// Launch layout: one block per row (blockIdx.x selects the row); out_idx
// receives one int32 index per row. blockDim.x must be a multiple of 32 and
// at most 1024: the shuffles use a full-warp mask and the per-warp staging
// arrays hold 32 entries. Static shared memory only (32 floats + 32 ints).
//
// Reduction: each thread scans a strided slice and tracks the running
// (value, index) pair, then a warp-shuffle reduce, then a single-warp
// reduce over the per-warp leaders. Tie-break: the smaller index wins, so
// the result is deterministic across launches.
//
// Every BF16 element is widened with __bfloat162float and compared in FP32.
//
// Edge cases:
//   * -inf is an ordinary candidate: a row made entirely of -inf yields
//     index 0 (smallest index among equal values).
//   * NaN never compares greater/equal and is therefore skipped.
//   * A row with no comparable element (all NaN, or cols <= 0) yields 0
//     instead of leaking the internal INT_MAX sentinel as an index.

// Reduces (val, idx) pairs across the 32 lanes of a warp; lane 0 ends up
// holding the maximum value and, among equal values, the smallest index.
// All 32 lanes of the warp must call this together.
static __device__ __forceinline__ void argmax_warp_reduce(float& val, int& idx) {
    const unsigned mask = 0xffffffffu;
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        float other_val = __shfl_down_sync(mask, val, offset);
        int other_idx = __shfl_down_sync(mask, idx, offset);
        bool take = (other_val > val) ||
                    (other_val == val && other_idx < idx);
        if (take) {
            val = other_val;
            idx = other_idx;
        }
    }
}

__global__ void argmax_bf16_kernel(
    const __nv_bfloat16* __restrict__ logits,
    int* __restrict__ out_idx,
    int cols
) {
    // Bit pattern of -infinity. Using -inf (not -FLT_MAX) as the "nothing
    // seen yet" value lets genuine -inf logits tie with the sentinel and win
    // through the smaller-index rule.
    const float neg_inf = __uint_as_float(0xff800000u);

    int row = blockIdx.x;
    // 64-bit offset so rows * cols may exceed the int range.
    const __nv_bfloat16* row_ptr = logits + (long long)row * cols;
    int tid = threadIdx.x;

    // Strided per-thread max. (neg_inf, INT_MAX) marks "no element yet".
    float local_max = neg_inf;
    int local_idx = INT_MAX;
    for (int i = tid; i < cols; i += blockDim.x) {
        float v = __bfloat162float(row_ptr[i]);
        // Indices ascend within a thread, so the equality arm only fires
        // while local_idx is still the INT_MAX sentinel (i.e. the first -inf
        // element); for all later ties the earlier index is kept.
        if (v > local_max || (v == local_max && i < local_idx)) {
            local_max = v;
            local_idx = i;
        }
    }

    // Warp-level reduce of (val, idx) pairs.
    argmax_warp_reduce(local_max, local_idx);

    // Per-warp leaders -> shared memory -> single warp final reduce.
    __shared__ float s_val[32];
    __shared__ int s_idx[32];
    int lane = tid & 31;
    int warp_id = tid >> 5;
    int num_warps = (blockDim.x + 31) >> 5;

    if (lane == 0) {
        s_val[warp_id] = local_max;
        s_idx[warp_id] = local_idx;
    }
    __syncthreads();

    if (warp_id == 0) {
        // Lanes beyond the number of warps contribute the neutral sentinel.
        float v = (lane < num_warps) ? s_val[lane] : neg_inf;
        int i = (lane < num_warps) ? s_idx[lane] : INT_MAX;
        argmax_warp_reduce(v, i);
        if (lane == 0) {
            // INT_MAX here means no element was comparable; emit index 0
            // rather than an out-of-range value.
            out_idx[row] = (i == INT_MAX) ? 0 : i;
        }
    }
}
|
|
|
|
extern "C" {
|
|
|
|
// Host launcher for argmax_bf16_kernel.
//
//   logits  - pointer passed to the kernel as const __nv_bfloat16*; read as
//             a [rows, cols] tensor
//   out_idx - pointer passed to the kernel as int*; one index written per row
//   rows    - number of rows; one thread block is launched per row
//   cols    - length of each row
//   stream  - cast to cudaStream_t and used for the launch
//
// The launch is asynchronous with respect to the host; only launch-time
// errors are checked here via CUDA_CHECK_LAST_ERROR.
void launch_argmax_bf16(const void* logits, void* out_idx,
                        int rows, int cols, void* stream) {
    // An empty batch has nothing to reduce, and a zero-sized grid is an
    // invalid launch configuration, so skip the launch entirely.
    if (rows <= 0) {
        return;
    }

    // 1024 threads/block gives 32 warps for the final reduce, matching the
    // 32-slot shared arrays inside the kernel.
    const int block = 1024;
    argmax_bf16_kernel<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)logits, (int*)out_idx, cols);
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
}
|