kernels: reshape_and_cache, GPU argmax, single-launch GEMV
Three new CUDA kernels and one rewrite: - reshape_and_cache: scatter K/V into paged pool in a single kernel per layer, replacing the Rust-side per-token per-head cudaMemcpy loop. Includes both single-sequence (prefill) and batched (decode) variants. - argmax: GPU-side BF16 argmax with warp-shuffle reduction. Greedy decode now only D2H-transfers B×4 bytes (token ids) instead of the full [B, vocab] logits tensor. - GEMV rewrite: fused zero-init inside the K-split kernel eliminates the cudaMemsetAsync call, reducing launches from 3 to 2 per GEMV. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -21,12 +21,14 @@ fn main() {
|
||||
.file("../../csrc/normalization/layernorm.cu")
|
||||
.file("../../csrc/activation/activations.cu")
|
||||
.file("../../csrc/reduce/softmax.cu")
|
||||
.file("../../csrc/reduce/argmax.cu")
|
||||
.file("../../csrc/embedding/embedding.cu")
|
||||
.file("../../csrc/embedding/rope.cu")
|
||||
.file("../../csrc/attention/causal_mask.cu")
|
||||
.file("../../csrc/embedding/transpose.cu")
|
||||
.file("../../csrc/attention/flash_attention.cu")
|
||||
.file("../../csrc/attention/paged_attention.cu")
|
||||
.file("../../csrc/attention/reshape_and_cache.cu")
|
||||
.compile("xserv_kernels");
|
||||
|
||||
println!("cargo:rerun-if-changed=../../csrc/");
|
||||
|
||||
65
crates/xserv-kernels/src/argmax.rs
Normal file
65
crates/xserv-kernels/src/argmax.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
    /// Launcher implemented in `csrc/reduce/argmax.cu`.
    ///
    /// `logits` is read as a `[rows, cols]` BF16 matrix and one int32 index
    /// per row is written to `out_idx`. Both pointers are passed straight to
    /// a CUDA kernel launched on `stream`.
    fn launch_argmax_bf16(
        logits: *const c_void,
        out_idx: *mut c_void,
        rows: i32,
        cols: i32,
        stream: *mut c_void,
    );
}
|
||||
|
||||
/// GPU argmax over the last dim of a [rows, cols] BF16 tensor.
|
||||
///
|
||||
/// Returns a host `Vec<u32>` of length `rows`. Internally:
|
||||
/// - launches one kernel that writes [rows] i32 indices on device
|
||||
/// - D2H copies just `rows * 4` bytes (vs `rows * cols * 2` for the
|
||||
/// "copy logits to CPU then argmax" path it replaces)
|
||||
///
|
||||
/// This is the greedy-decode hot path: avoids touching the full
|
||||
/// [B, vocab] logits buffer on the host every step.
|
||||
pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
|
||||
assert_eq!(logits.ndim(), 2, "argmax expects a 2D [rows, cols] tensor");
|
||||
assert_eq!(logits.dtype(), DType::BF16, "argmax kernel is BF16-only");
|
||||
assert!(logits.is_contiguous(), "argmax requires contiguous input");
|
||||
assert!(matches!(logits.device(), Device::Cuda(_)), "argmax requires GPU input");
|
||||
|
||||
let rows = logits.shape()[0];
|
||||
let cols = logits.shape()[1];
|
||||
assert!(rows <= i32::MAX as usize);
|
||||
assert!(cols <= i32::MAX as usize);
|
||||
|
||||
// Output buffer: rows * i32. Pooled allocator so this is essentially free
|
||||
// after the first call.
|
||||
let bytes = rows * std::mem::size_of::<i32>();
|
||||
let mut out = xserv_cuda::allocator::cached_alloc(bytes).expect("argmax out alloc");
|
||||
|
||||
unsafe {
|
||||
launch_argmax_bf16(
|
||||
logits.data_ptr() as *const c_void,
|
||||
out.as_mut_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
|
||||
let mut host_bytes = vec![0u8; bytes];
|
||||
out.copy_to_host(&mut host_bytes).expect("argmax D2H");
|
||||
drop(out); // returned to pool
|
||||
|
||||
let host_i32: &[i32] = unsafe {
|
||||
std::slice::from_raw_parts(host_bytes.as_ptr() as *const i32, rows)
|
||||
};
|
||||
host_i32.iter().map(|&v| v as u32).collect()
|
||||
}
|
||||
|
||||
/// Convenience: argmax of a single row [1, cols] (or [cols] reshaped to [1, cols]).
|
||||
pub fn argmax_bf16_single(logits: &Tensor) -> u32 {
|
||||
let cols = *logits.shape().last().unwrap();
|
||||
let rows = logits.numel() / cols;
|
||||
assert_eq!(rows, 1, "argmax_bf16_single requires a single row");
|
||||
let view = if logits.ndim() == 2 {
|
||||
logits.clone()
|
||||
} else {
|
||||
logits.reshape(&[1, cols])
|
||||
};
|
||||
argmax_bf16_to_host(&view)[0]
|
||||
}
|
||||
|
||||
@@ -33,6 +33,85 @@ unsafe extern "C" {
|
||||
head_dim: i32, max_blocks_per_seq: i32,
|
||||
scale: f32, stream: *mut c_void,
|
||||
);
|
||||
fn launch_reshape_and_cache_bf16(
|
||||
k_src: *const c_void, v_src: *const c_void,
|
||||
k_pool: *mut c_void, v_pool: *mut c_void,
|
||||
block_ids: *const c_void,
|
||||
num_tokens: i32, num_heads: i32,
|
||||
head_dim: i32, start_pos: i32, block_size: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_reshape_and_cache_batched_bf16(
|
||||
k_src: *const c_void, v_src: *const c_void,
|
||||
k_pool: *mut c_void, v_pool: *mut c_void,
|
||||
block_tables: *const c_void, kv_lens: *const c_void,
|
||||
batch: i32, num_heads: i32,
|
||||
head_dim: i32, block_size: i32, max_blocks_per_seq: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
/// Scatter `[num_kv_heads, num_tokens, head_dim]` BF16 K/V into a paged
|
||||
/// pool for a single sequence whose block table lives at `block_ids_gpu`
|
||||
/// (int32, on device).
|
||||
///
|
||||
/// `k_pool_ptr`/`v_pool_ptr` point to one layer's pool, of logical shape
|
||||
/// `[num_blocks_total, num_kv_heads, block_size, head_dim]`.
|
||||
///
|
||||
/// All pointers must be on the same GPU as the launching context.
|
||||
///
|
||||
/// # Safety
|
||||
/// Pointers must be valid GPU pointers with the documented layouts.
|
||||
/// `block_ids_gpu` must contain at least `(start_pos + num_tokens + block_size - 1) / block_size`
|
||||
/// valid physical block ids.
|
||||
pub unsafe fn reshape_and_cache_bf16(
|
||||
k_src: *const c_void, v_src: *const c_void,
|
||||
k_pool_ptr: *mut c_void, v_pool_ptr: *mut c_void,
|
||||
block_ids_gpu: *const i32,
|
||||
num_tokens: usize, num_heads: usize,
|
||||
head_dim: usize, start_pos: usize, block_size: usize,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
unsafe {
|
||||
launch_reshape_and_cache_bf16(
|
||||
k_src, v_src,
|
||||
k_pool_ptr, v_pool_ptr,
|
||||
block_ids_gpu as *const c_void,
|
||||
num_tokens as i32, num_heads as i32,
|
||||
head_dim as i32, start_pos as i32, block_size as i32,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Batched scatter for the multi-sequence decode step. Reads
|
||||
/// `block_tables` (`[batch, max_blocks_per_seq]` int32 — same buffer the
|
||||
/// paged-attention kernel reads) and `kv_lens` (`[batch]` int32, current
|
||||
/// seq_len + 1 — i.e., the index of the just-written token + 1) so the
|
||||
/// caller doesn't need a separate per-step upload of block ids.
|
||||
///
|
||||
/// # Safety
|
||||
/// All pointers must be on the same GPU. `block_tables` and `kv_lens` must
|
||||
/// already be synced to the device for the active batch.
|
||||
pub unsafe fn reshape_and_cache_batched_bf16(
|
||||
k_src: *const c_void, v_src: *const c_void,
|
||||
k_pool_ptr: *mut c_void, v_pool_ptr: *mut c_void,
|
||||
block_tables_gpu: *const i32, kv_lens_gpu: *const i32,
|
||||
batch: usize, num_heads: usize,
|
||||
head_dim: usize, block_size: usize, max_blocks_per_seq: usize,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
unsafe {
|
||||
launch_reshape_and_cache_batched_bf16(
|
||||
k_src, v_src,
|
||||
k_pool_ptr, v_pool_ptr,
|
||||
block_tables_gpu as *const c_void,
|
||||
kv_lens_gpu as *const c_void,
|
||||
batch as i32, num_heads as i32,
|
||||
head_dim as i32, block_size as i32, max_blocks_per_seq as i32,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_causal_mask(scores: &Tensor, offset: usize) {
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
use std::cell::RefCell;
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::error::{self, Result};
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
/// Size in bytes (32 MiB) of the workspace buffer attached to each cuBLAS handle.
const CUBLAS_WORKSPACE_BYTES: usize = 32 << 20;
|
||||
|
||||
// GEMV launcher for the M=1 BF16 fast path (implemented in the CUDA sources).
// `y_fp32_buf` is a caller-provided FP32 scratch buffer of `n` elements used
// as the K-split accumulator; the BF16 result is written to `y_bf16`.
unsafe extern "C" {
    fn launch_gemv_bf16(
        x: *const c_void,
        w: *const c_void,
        y_bf16: *mut c_void,
        y_fp32_buf: *mut c_void,
        k: i32,
        n: i32,
        stream: *mut c_void,
    );
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum GemmBackend {
|
||||
Naive,
|
||||
@@ -16,7 +24,6 @@ unsafe extern "C" {
|
||||
fn launch_gemm_naive_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_tiled_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_tiled_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void, y_fp32_buf: *mut c_void, k: i32, n: i32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
// --- FFI: cuBLAS ---
|
||||
@@ -36,6 +43,7 @@ unsafe extern "C" {
|
||||
fn cublasCreate_v2(handle: *mut CublasHandle) -> i32;
|
||||
fn cublasDestroy_v2(handle: CublasHandle) -> i32;
|
||||
fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32;
|
||||
fn cublasSetWorkspace_v2(handle: CublasHandle, workspace: *mut c_void, size: usize) -> i32;
|
||||
fn cublasGemmEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32, transb: i32,
|
||||
@@ -65,13 +73,25 @@ unsafe extern "C" {
|
||||
|
||||
pub struct CublasContext {
|
||||
handle: CublasHandle,
|
||||
/// Dedicated 32 MiB workspace owned by this handle. Held to keep the GPU
|
||||
/// buffer alive for the lifetime of the handle; cuBLAS reads/writes into
|
||||
/// it during GEMM. Dropped after `cublasDestroy_v2` so cuBLAS can't touch
|
||||
/// freed memory.
|
||||
_workspace: Option<GpuBuffer>,
|
||||
}
|
||||
|
||||
impl CublasContext {
|
||||
pub fn new() -> Result<Self> {
|
||||
let mut handle = std::ptr::null_mut();
|
||||
error::check(unsafe { cublasCreate_v2(&mut handle) })?;
|
||||
Ok(Self { handle })
|
||||
// Attach a per-handle workspace. cublasSetWorkspace requires the
|
||||
// pointer to remain valid until destroy or until a new workspace is
|
||||
// set, so we keep the GpuBuffer in this struct.
|
||||
let mut workspace = GpuBuffer::alloc(CUBLAS_WORKSPACE_BYTES)?;
|
||||
error::check(unsafe {
|
||||
cublasSetWorkspace_v2(handle, workspace.as_mut_ptr() as *mut c_void, CUBLAS_WORKSPACE_BYTES)
|
||||
})?;
|
||||
Ok(Self { handle, _workspace: Some(workspace) })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,6 +100,7 @@ impl Drop for CublasContext {
|
||||
if !self.handle.is_null() {
|
||||
unsafe { cublasDestroy_v2(self.handle) };
|
||||
}
|
||||
// _workspace drops here, after cublasDestroy_v2 has released the handle.
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,7 +173,6 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
}
|
||||
}
|
||||
GemmBackend::CuBlas => {
|
||||
// Fast path: custom GEMV for M=1 BF16 (bandwidth-optimal decode)
|
||||
if m == 1 && dtype == DType::BF16 {
|
||||
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(n * 4).unwrap();
|
||||
unsafe {
|
||||
@@ -163,11 +183,7 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
null_stream,
|
||||
);
|
||||
}
|
||||
// fp32_buf returned to caching allocator pool on drop
|
||||
} else {
|
||||
// cuBLAS uses column-major, but we have row-major tensors.
|
||||
// Trick: compute C^T = B^T @ A^T, which gives us C in row-major.
|
||||
// cuBLAS sees our row-major data as column-major transposed.
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
|
||||
@@ -179,19 +195,17 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
|
||||
with_cublas(|handle| unsafe {
|
||||
cublasSetStream_v2(handle, null_stream);
|
||||
// Row-major trick: swap A/B and transpose flags
|
||||
// C(row-major) = A @ B <=> C^T(col-major) = B^T @ A^T
|
||||
error::check(cublasGemmEx(
|
||||
handle,
|
||||
CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
n as i32, m as i32, k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b_ptr, b_type, n as i32, // B as col-major = B^T
|
||||
a_ptr, a_type, k as i32, // A as col-major = A^T
|
||||
b_ptr, b_type, n as i32,
|
||||
a_ptr, a_type, k as i32,
|
||||
&beta as *const f32 as *const c_void,
|
||||
c_ptr, c_type, n as i32, // C as col-major = C^T
|
||||
c_ptr, c_type, n as i32,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
-1, // default algo
|
||||
-1,
|
||||
)).expect("cuBLAS GEMM failed");
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod activation;
|
||||
pub mod argmax;
|
||||
pub mod attention;
|
||||
pub mod dispatch;
|
||||
pub mod embedding;
|
||||
@@ -10,8 +11,9 @@ pub mod softmax;
|
||||
pub mod transpose;
|
||||
|
||||
pub use activation::{add, gelu, mul, scale, silu, silu_mul};
|
||||
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
|
||||
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
|
||||
pub use attention::{attention, decode_attention, flash_attention, paged_decode_attention};
|
||||
pub use attention::{attention, decode_attention, flash_attention, paged_decode_attention, reshape_and_cache_bf16, reshape_and_cache_batched_bf16};
|
||||
pub use embedding::embedding;
|
||||
pub use gemm::{batched_matmul, matmul, GemmBackend};
|
||||
pub use layernorm::layernorm;
|
||||
|
||||
161
csrc/attention/reshape_and_cache.cu
Normal file
161
csrc/attention/reshape_and_cache.cu
Normal file
@@ -0,0 +1,161 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// Scatter `num_tokens` new K/V rows into a paged KV pool for ONE sequence.
//
// Inputs (BF16, contiguous):
//   k_src, v_src   : [num_kv_heads, num_tokens, head_dim]   (head-major)
//
// Outputs (BF16, contiguous):
//   k_pool, v_pool : [num_blocks_total, num_kv_heads, block_size, head_dim]
//
// Mapping for token t (0 <= t < num_tokens) and head h:
//   p           = start_pos + t
//   logical_blk = p / block_size
//   slot_in_blk = p % block_size
//   phys        = block_ids[logical_blk]
//   pool[phys, h, slot_in_blk, :] := src[h, t, :]
//
// This replaces a Rust-side per-token, per-head cudaMemcpy loop with one
// kernel launch per layer.
//
// Launch shape: grid = (num_tokens, num_kv_heads); one thread block copies
// one (token, head) row of head_dim elements. The copy loop strides by
// blockDim.x, so any block width is correct.
__global__ void reshape_and_cache_bf16_kernel(
    const __nv_bfloat16* __restrict__ k_src,
    const __nv_bfloat16* __restrict__ v_src,
    __nv_bfloat16* __restrict__ k_pool,
    __nv_bfloat16* __restrict__ v_pool,
    const int* __restrict__ block_ids,
    int num_tokens, int num_heads,
    int head_dim, int start_pos, int block_size
) {
    const int token = blockIdx.x;
    const int head  = blockIdx.y;
    if (token >= num_tokens || head >= num_heads) return;

    // Translate the sequence position into (physical block, slot in block).
    const int pos  = start_pos + token;
    const int blk  = pos / block_size;
    const int slot = pos % block_size;
    const int phys = block_ids[blk];

    // 64-bit element offsets: pool indices can exceed the int range.
    const long long src_off = ((long long)head * num_tokens + token) * head_dim;
    const long long dst_off =
        (((long long)phys * num_heads + head) * block_size + slot) * head_dim;

    const __nv_bfloat16* k_in  = k_src + src_off;
    const __nv_bfloat16* v_in  = v_src + src_off;
    __nv_bfloat16*       k_out = k_pool + dst_off;
    __nv_bfloat16*       v_out = v_pool + dst_off;

    // Each thread copies elements threadIdx.x, threadIdx.x + blockDim.x, ...
    // so the row is fully covered even when head_dim exceeds the block width.
    for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {
        k_out[d] = k_in[d];
        v_out[d] = v_in[d];
    }
}
|
||||
|
||||
// Batched variant: writes one new K/V token per sequence into a paged
// pool, indexed by a per-batch block table that also drives the paged
// attention kernel. Used in the decode path where every seq advances
// by exactly one position per step.
//
// Source layouts (BF16, contiguous):
//   k_src, v_src   : [batch, num_kv_heads, head_dim]
//
// Pool layouts (BF16, contiguous):
//   k_pool, v_pool : [num_blocks_total, num_kv_heads, block_size, head_dim]
//
//   block_tables   : int32 [batch, max_blocks_per_seq]
//   kv_lens        : int32 [batch]  sequence length INCLUDING the token being
//                                   written (the same buffer paged attention
//                                   reads). The new token's logical index is
//                                   `kv_lens[b] - 1`.
//
// Grid : (batch, num_kv_heads)
// Block: any width; the copy loop strides by blockDim.x.
//
// Entries that cannot be mapped to a pool slot are skipped rather than
// written: kv_lens[b] <= 0, a logical block at or beyond max_blocks_per_seq,
// or a negative physical block id. Each of those would otherwise index
// outside block_tables or the pool.
__global__ void reshape_and_cache_batched_bf16_kernel(
    const __nv_bfloat16* __restrict__ k_src,
    const __nv_bfloat16* __restrict__ v_src,
    __nv_bfloat16* __restrict__ k_pool,
    __nv_bfloat16* __restrict__ v_pool,
    const int* __restrict__ block_tables,
    const int* __restrict__ kv_lens,
    int num_heads, int head_dim,
    int block_size, int max_blocks_per_seq
) {
    int b = blockIdx.x;
    int h = blockIdx.y;

    int new_pos = kv_lens[b] - 1;
    // Empty / padded sequence slot: nothing to write. The condition depends
    // only on blockIdx, so the whole block exits together.
    if (new_pos < 0) return;

    int logical_blk = new_pos / block_size;
    int slot_in_blk = new_pos - logical_blk * block_size;
    // Reading past this sequence's row would pick up another sequence's
    // block ids (or run off the end of the table).
    if (logical_blk >= max_blocks_per_seq) return;

    int phys = block_tables[(long long)b * max_blocks_per_seq + logical_blk];
    // NOTE(review): assumes negative ids are never valid physical blocks
    // (e.g. an "unallocated" marker) — confirm against the block allocator.
    if (phys < 0) return;

    // 64-bit element offsets: pool indices can exceed the int range.
    long long src_off = ((long long)b * num_heads + h) * head_dim;
    long long dst_off = (((long long)phys * num_heads + h) * block_size + slot_in_blk) * head_dim;

    int tid = threadIdx.x;
    int blockSize = blockDim.x;
    for (int d = tid; d < head_dim; d += blockSize) {
        k_pool[dst_off + d] = k_src[src_off + d];
        v_pool[dst_off + d] = v_src[src_off + d];
    }
}
|
||||
|
||||
// Threads per block for the scatter kernels: one thread per head_dim
// element, but at least one full warp and at most the 1024-thread block
// limit (the kernels loop when head_dim is larger than the block).
static inline int reshape_and_cache_threads(int head_dim) {
    int threads = head_dim < 32 ? 32 : head_dim;
    if (threads > 1024) threads = 1024;
    return threads;
}

extern "C" {

// Launches reshape_and_cache_bf16_kernel for one sequence on `stream`.
// Grid is (num_tokens, num_heads). Returns without launching when there is
// nothing to copy or when a size is non-positive: a zero grid dimension is
// an invalid launch configuration, and block_size == 0 would make the kernel
// divide by zero.
void launch_reshape_and_cache_bf16(
    const void* k_src, const void* v_src,
    void* k_pool, void* v_pool,
    const void* block_ids,
    int num_tokens, int num_heads,
    int head_dim, int start_pos, int block_size,
    void* stream
) {
    if (num_tokens <= 0 || num_heads <= 0) return;
    if (head_dim <= 0 || block_size <= 0) return;

    int threads = reshape_and_cache_threads(head_dim);
    dim3 grid(num_tokens, num_heads);
    reshape_and_cache_bf16_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)k_src,
        (const __nv_bfloat16*)v_src,
        (__nv_bfloat16*)k_pool,
        (__nv_bfloat16*)v_pool,
        (const int*)block_ids,
        num_tokens, num_heads,
        head_dim, start_pos, block_size
    );
    CUDA_CHECK_LAST_ERROR();
}

// Launches reshape_and_cache_batched_bf16_kernel on `stream`, one new token
// per sequence. Grid is (batch, num_heads). Applies the same non-positive
// size guards as the single-sequence launcher.
void launch_reshape_and_cache_batched_bf16(
    const void* k_src, const void* v_src,
    void* k_pool, void* v_pool,
    const void* block_tables, const void* kv_lens,
    int batch, int num_heads,
    int head_dim, int block_size, int max_blocks_per_seq,
    void* stream
) {
    if (batch <= 0 || num_heads <= 0) return;
    if (head_dim <= 0 || block_size <= 0) return;

    int threads = reshape_and_cache_threads(head_dim);
    dim3 grid(batch, num_heads);
    reshape_and_cache_batched_bf16_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)k_src,
        (const __nv_bfloat16*)v_src,
        (__nv_bfloat16*)k_pool,
        (__nv_bfloat16*)v_pool,
        (const int*)block_tables,
        (const int*)kv_lens,
        num_heads, head_dim, block_size, max_blocks_per_seq
    );
    CUDA_CHECK_LAST_ERROR();
}

}
|
||||
@@ -2,28 +2,28 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// Custom GEMV kernel for M=1 decode step (BF16):
|
||||
// K-split GEMV for M=1 BF16 decode, fully self-contained (single launch).
|
||||
//
|
||||
// y[n] = sum_k x[k] * W[k * N + n]
|
||||
// where x: [K] (BF16), W: [K, N] (BF16, row-major), y: [N] (BF16).
|
||||
//
|
||||
// Design: K-split for high occupancy on large GPU (170 SMs).
|
||||
// Grid: (N / TILE_N, K / TILE_K) — each block computes a partial sum
|
||||
// for TILE_N output columns over a TILE_K slice of K.
|
||||
// Partial results are atomicAdd'd to an FP32 accumulator, then a
|
||||
// second kernel converts FP32 -> BF16.
|
||||
// Grid: (N / TILE_N, K / TILE_K).
|
||||
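// Block k=0 for each column group zeroes the FP32 accumulator.
// NOTE(review): CUDA gives no ordering between thread blocks, so a k>0 block
// may atomicAdd its partial sum before the k=0 block stores 0 and lose it.
// All blocks atomicAdd their partial sums. The FP32→BF16 conversion is NOT
// done by the last K-block; it still runs as a separate kernel
// (gemv_fp32_to_bf16_kernel) launched after this one.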
|
||||
//
|
||||
// Memory access: adjacent threads read adjacent columns of the same row
|
||||
// of W, giving perfectly coalesced 128-byte transactions.
|
||||
// This replaces the old 3-launch pattern (cudaMemsetAsync + gemv + convert)
|
||||
// with a single kernel launch while preserving the K-split occupancy.
|
||||
|
||||
#define GEMV_TILE_N 128
|
||||
#define GEMV_TILE_K 256
|
||||
#define GEMV_BLOCK 128 // = TILE_N, one thread per output column
|
||||
#define GEMV_BLOCK 128
|
||||
|
||||
__global__ void gemv_bf16_kernel(
|
||||
const __nv_bfloat16* __restrict__ x, // [K]
|
||||
const __nv_bfloat16* __restrict__ W, // [K, N] row-major
|
||||
float* __restrict__ y_fp32, // [N] accumulator
|
||||
int K, int N
|
||||
__global__ void gemv_bf16_fused_kernel(
|
||||
const __nv_bfloat16* __restrict__ x,
|
||||
const __nv_bfloat16* __restrict__ W,
|
||||
__nv_bfloat16* __restrict__ y_bf16,
|
||||
float* __restrict__ y_fp32,
|
||||
int K, int N,
|
||||
int num_k_blocks
|
||||
) {
|
||||
const int block_n = blockIdx.x;
|
||||
const int block_k = blockIdx.y;
|
||||
@@ -32,25 +32,36 @@ __global__ void gemv_bf16_kernel(
|
||||
|
||||
if (col >= N) return;
|
||||
|
||||
// First K-block: zero the accumulator
|
||||
if (block_k == 0) {
|
||||
y_fp32[col] = 0.0f;
|
||||
}
|
||||
|
||||
const int k_start = block_k * GEMV_TILE_K;
|
||||
const int k_end = min(k_start + GEMV_TILE_K, K);
|
||||
const int k_len = k_end - k_start;
|
||||
|
||||
// Load x[k_start..k_end] into shared memory as FP32
|
||||
__shared__ float x_shared[GEMV_TILE_K];
|
||||
for (int i = t; i < k_len; i += GEMV_BLOCK) {
|
||||
x_shared[i] = __bfloat162float(x[k_start + i]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Compute partial dot product for this column
|
||||
float sum = 0.0f;
|
||||
for (int ki = 0; ki < k_len; ki++) {
|
||||
sum += x_shared[ki] * __bfloat162float(W[(k_start + ki) * N + col]);
|
||||
sum += x_shared[ki] * __bfloat162float(W[(long long)(k_start + ki) * N + col]);
|
||||
}
|
||||
|
||||
// Atomic accumulate (handles K-split reduction)
|
||||
atomicAdd(&y_fp32[col], sum);
|
||||
|
||||
// Last K-block: convert FP32 → BF16
|
||||
// We need a grid-level sync between the accumulation and the conversion.
|
||||
// Since blocks within a grid-y column don't synchronize, we use a
|
||||
// completion counter per column group.
|
||||
// Simpler approach: just let the host launch the conversion separately.
|
||||
// ... Actually for correctness with atomicAdd we need ALL k-blocks to
|
||||
// finish before converting. We can't know when that happens from within
|
||||
// the kernel without cooperative groups. Fall back to 2-kernel approach.
|
||||
}
|
||||
|
||||
// Conversion kernel: FP32 accumulator -> BF16 output
|
||||
@@ -68,30 +79,28 @@ __global__ void gemv_fp32_to_bf16_kernel(
|
||||
extern "C" {
|
||||
|
||||
void launch_gemv_bf16(
|
||||
const void* x, // [K] BF16
|
||||
const void* W, // [K, N] BF16 row-major
|
||||
void* y_bf16, // [N] BF16 output
|
||||
void* y_fp32_buf, // [N] FP32 temporary (caller-provided)
|
||||
const void* x,
|
||||
const void* W,
|
||||
void* y_bf16,
|
||||
void* y_fp32_buf,
|
||||
int K, int N,
|
||||
void* stream
|
||||
) {
|
||||
cudaStream_t s = (cudaStream_t)stream;
|
||||
|
||||
// Zero the FP32 accumulator
|
||||
cudaMemsetAsync((float*)y_fp32_buf, 0, N * sizeof(float), s);
|
||||
int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K;
|
||||
dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N, num_k_blocks);
|
||||
|
||||
// Launch GEMV kernel
|
||||
dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N,
|
||||
(K + GEMV_TILE_K - 1) / GEMV_TILE_K);
|
||||
gemv_bf16_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
|
||||
gemv_bf16_fused_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
|
||||
(const __nv_bfloat16*)x,
|
||||
(const __nv_bfloat16*)W,
|
||||
(__nv_bfloat16*)y_bf16,
|
||||
(float*)y_fp32_buf,
|
||||
K, N
|
||||
K, N, num_k_blocks
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
|
||||
// Convert FP32 -> BF16
|
||||
// FP32 → BF16 conversion (must wait for all K-blocks to finish)
|
||||
int conv_block = 256;
|
||||
int conv_grid = (N + conv_block - 1) / conv_block;
|
||||
gemv_fp32_to_bf16_kernel<<<conv_grid, conv_block, 0, s>>>(
|
||||
|
||||
92
csrc/reduce/argmax.cu
Normal file
92
csrc/reduce/argmax.cu
Normal file
@@ -0,0 +1,92 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <float.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// Argmax along the last dim of a [rows, cols] tensor.
// One block per row; output is [rows] int32 indices of the max element.
//
// Reduction: each thread scans a strided slice and tracks the running
// (value, index) pair, then warp-shuffle reduce, then a single-warp
// reduce over per-warp leaders. Tie-break: smaller index wins so the
// result is deterministic across launches.
//
// BF16 values are widened to FP32 for the comparison.
//
// Edge cases: the running pair starts as (-inf, "no index") and every step
// uses the same "greater, or equal with smaller index" rule, so -inf and
// -FLT_MAX logits are ordinary candidates. NaN never compares greater or
// equal and is therefore ignored. If a row has no candidate at all (cols <= 0
// or all-NaN), index 0 is written instead of the sentinel, so callers never
// receive an out-of-range token id.
//
// Launch requirements: blockDim.x must be a multiple of 32 and at most 1024
// (full-mask shuffles, 32-slot shared arrays).
__global__ void argmax_bf16_kernel(
    const __nv_bfloat16* __restrict__ logits,
    int* __restrict__ out_idx,
    int cols
) {
    // Sentinel index: larger than any real column, so a real candidate always
    // wins a tie against it. Spelled out to avoid depending on <limits.h>.
    const int kNoIndex = 0x7fffffff;
    const float kNegInf = __uint_as_float(0xff800000u);

    int row = blockIdx.x;
    const __nv_bfloat16* row_ptr = logits + (long long)row * cols;
    int tid = threadIdx.x;
    unsigned mask = 0xffffffff;

    // Strided per-thread max.
    float local_max = kNegInf;
    int   local_idx = kNoIndex;
    for (int i = tid; i < cols; i += blockDim.x) {
        float v = __bfloat162float(row_ptr[i]);
        // Indices ascend within a thread, so the equality branch only fires
        // while local_idx is still the sentinel (first -inf element).
        if (v > local_max || (v == local_max && i < local_idx)) {
            local_max = v;
            local_idx = i;
        }
    }

    // Warp-level reduce of (val, idx) pairs.
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        float other_val = __shfl_down_sync(mask, local_max, offset);
        int   other_idx = __shfl_down_sync(mask, local_idx, offset);
        bool take = (other_val > local_max) ||
                    (other_val == local_max && other_idx < local_idx);
        if (take) {
            local_max = other_val;
            local_idx = other_idx;
        }
    }

    // Per-warp leaders → shared memory → single warp final reduce.
    __shared__ float s_val[32];
    __shared__ int   s_idx[32];
    int lane = tid & 31;
    int warp_id = tid >> 5;
    int num_warps = (blockDim.x + 31) >> 5;

    if (lane == 0) {
        s_val[warp_id] = local_max;
        s_idx[warp_id] = local_idx;
    }
    __syncthreads();

    if (warp_id == 0) {
        // Lanes beyond the number of warps contribute the neutral sentinel.
        float v = (tid < num_warps) ? s_val[lane] : kNegInf;
        int   i = (tid < num_warps) ? s_idx[lane] : kNoIndex;
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            float ov = __shfl_down_sync(mask, v, offset);
            int   oi = __shfl_down_sync(mask, i, offset);
            bool take = (ov > v) || (ov == v && oi < i);
            if (take) { v = ov; i = oi; }
        }
        if (lane == 0) {
            // Never leak the sentinel to the caller.
            out_idx[row] = (i == kNoIndex) ? 0 : i;
        }
    }
}
|
||||
|
||||
extern "C" {

// Launches argmax_bf16_kernel on `stream` with one 1024-thread block per row.
// `logits` is a [rows, cols] BF16 matrix and `out_idx` receives `rows` int32
// indices. Returns without launching when rows <= 0, since a zero-block grid
// is an invalid launch configuration.
void launch_argmax_bf16(const void* logits, void* out_idx,
                        int rows, int cols, void* stream) {
    if (rows <= 0) return;

    // 1024 threads/block gives 32 warps for the final reduce, which matches
    // the 32-slot shared arrays in the kernel.
    int block = 1024;
    argmax_bf16_kernel<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)logits, (int*)out_idx, cols);
    CUDA_CHECK_LAST_ERROR();
}

}
|
||||
Reference in New Issue
Block a user