kernels/cuda: paged-attention kernel, dispatch, pinned host memory

CUDA layer for the paged-KV + swap work:
- csrc: new paged_attention.cu plus updates across attention/gemm/norm/
  activation/embedding/reduce kernels and common.cuh.
- xserv-kernels: new dispatch module and kernel-binding updates.
- xserv-cuda: cudaMallocHost/FreeHost bindings + PinnedBuffer (host swap
  pool backing) and offset-aware D2H/H2D copies used to move KV blocks
  between the GPU pool and pinned host memory.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-28 19:58:36 +08:00
parent 3f1c3d429a
commit 4c3f914459
27 changed files with 581 additions and 32 deletions

View File

@@ -1,5 +1,6 @@
#include <cuda_bf16.h>
#include <math.h>
#include "../common.cuh"
// GELU (tanh approximation):
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
@@ -83,6 +84,7 @@ void launch_gelu_f32(const void* x, void* out, int n, void* stream) {
int block = 256;
int grid = (n + block - 1) / block;
gelu_f32<<<grid, block, 0, (cudaStream_t)stream>>>((const float*)x, (float*)out, n);
CUDA_CHECK_LAST_ERROR();
}
void launch_gelu_bf16(const void* x, void* out, int n, void* stream) {
@@ -90,12 +92,14 @@ void launch_gelu_bf16(const void* x, void* out, int n, void* stream) {
int grid = (n + block - 1) / block;
gelu_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
CUDA_CHECK_LAST_ERROR();
}
void launch_silu_f32(const void* x, void* out, int n, void* stream) {
int block = 256;
int grid = (n + block - 1) / block;
silu_f32<<<grid, block, 0, (cudaStream_t)stream>>>((const float*)x, (float*)out, n);
CUDA_CHECK_LAST_ERROR();
}
void launch_silu_bf16(const void* x, void* out, int n, void* stream) {
@@ -103,6 +107,7 @@ void launch_silu_bf16(const void* x, void* out, int n, void* stream) {
int grid = (n + block - 1) / block;
silu_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
CUDA_CHECK_LAST_ERROR();
}
void launch_scale_f32(const void* x, void* out, float scale, int n, void* stream) {
@@ -110,6 +115,7 @@ void launch_scale_f32(const void* x, void* out, float scale, int n, void* stream
int grid = (n + block - 1) / block;
scale_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const float*)x, (float*)out, scale, n);
CUDA_CHECK_LAST_ERROR();
}
void launch_scale_bf16(const void* x, void* out, float scale, int n, void* stream) {
@@ -117,6 +123,7 @@ void launch_scale_bf16(const void* x, void* out, float scale, int n, void* strea
int grid = (n + block - 1) / block;
scale_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, scale, n);
CUDA_CHECK_LAST_ERROR();
}
void launch_add_f32(const void* a, const void* b, void* out, int n, void* stream) {
@@ -124,24 +131,28 @@ void launch_add_f32(const void* a, const void* b, void* out, int n, void* stream
int grid = (n + block - 1) / block;
add_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const float*)a, (const float*)b, (float*)out, n);
CUDA_CHECK_LAST_ERROR();
}
void launch_add_bf16(const void* a, const void* b, void* out, int n, void* stream) {
int block = 256;
int grid = (n + block - 1) / block;
add_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
CUDA_CHECK_LAST_ERROR();
}
void launch_mul_f32(const void* a, const void* b, void* out, int n, void* stream) {
int block = 256;
int grid = (n + block - 1) / block;
mul_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const float*)a, (const float*)b, (float*)out, n);
CUDA_CHECK_LAST_ERROR();
}
void launch_mul_bf16(const void* a, const void* b, void* out, int n, void* stream) {
int block = 256;
int grid = (n + block - 1) / block;
mul_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
CUDA_CHECK_LAST_ERROR();
}
void launch_silu_mul_bf16(const void* gate, const void* up, void* out, int n, void* stream) {
@@ -149,6 +160,7 @@ void launch_silu_mul_bf16(const void* gate, const void* up, void* out, int n, vo
int grid = (n + block - 1) / block;
silu_mul_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)gate, (const __nv_bfloat16*)up, (__nv_bfloat16*)out, n);
CUDA_CHECK_LAST_ERROR();
}
}

View File

@@ -1,4 +1,5 @@
#include <cuda_bf16.h>
#include "../common.cuh"
// Apply causal mask: set scores[row][col] = -inf where col > row + offset.
// offset is used for KV cache: when query starts at position `offset`,
@@ -39,6 +40,7 @@ void launch_causal_mask_f32(void* scores, int batch, int rows, int cols,
dim3 grid((cols + block - 1) / block, rows, batch);
causal_mask_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
(float*)scores, rows, cols, offset);
CUDA_CHECK_LAST_ERROR();
}
void launch_causal_mask_bf16(void* scores, int batch, int rows, int cols,
@@ -47,6 +49,7 @@ void launch_causal_mask_bf16(void* scores, int batch, int rows, int cols,
dim3 grid((cols + block - 1) / block, rows, batch);
causal_mask_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(__nv_bfloat16*)scores, rows, cols, offset);
CUDA_CHECK_LAST_ERROR();
}
}

View File

@@ -1,5 +1,6 @@
#include <cuda_bf16.h>
#include <float.h>
#include "../common.cuh"
// Flash Attention 2 forward kernel for BF16 with FP32 accumulation.
//
@@ -391,6 +392,7 @@ void launch_flash_attention_bf16(
q_len, kv_len, head_dim,
scale, causal
);
CUDA_CHECK_LAST_ERROR();
}
void launch_decode_attention_bf16(
@@ -411,6 +413,7 @@ void launch_decode_attention_bf16(
kv_len, head_dim,
scale
);
CUDA_CHECK_LAST_ERROR();
}
}

View File

@@ -0,0 +1,215 @@
#include <cuda_bf16.h>
#include <float.h>
#include "../common.cuh"
// Paged decode attention kernel for BF16 with FP32 accumulation.
//
// Reads K/V from a paged pool indexed by a per-sequence block table.
// One CUDA block per (sequence, q_head). Each block streams over the
// sequence's KV positions and accumulates attention output via online
// softmax.
//
// Layouts:
// Q [batch, num_q_heads, 1, head_dim] BF16
// K_cache [num_blocks, num_kv_heads, BLOCK_SIZE, head_dim] BF16
// V_cache same
// block_tables [max_seqs, max_blocks_per_seq] int32
// — the i-th sequence in this launch reads row
// block_tables[seq_slot[i] * stride + ...].
// For simplicity the launch passes a packed row table
// [batch, max_blocks_per_seq] (already gathered for the
// active batch) so we just index by blockIdx.x_seq.
// context_lens [batch] int32 — number of valid tokens per sequence.
//
// One CUDA block: 256 threads, head_dim <= 128.
#define PAGED_BLOCK_SIZE 16
#define PAGED_THREADS 256
#define PAGED_HEAD_DIM_MAX 128
__global__ void paged_decode_attention_bf16_kernel(
const __nv_bfloat16* __restrict__ Q,
const __nv_bfloat16* __restrict__ K_cache,
const __nv_bfloat16* __restrict__ V_cache,
__nv_bfloat16* __restrict__ O,
const int* __restrict__ block_tables, // [batch, max_blocks_per_seq]
const int* __restrict__ context_lens, // [batch]
int num_q_heads, int num_kv_heads,
int head_dim, int max_blocks_per_seq,
float scale
) {
int seq_idx = blockIdx.y; // batch dim
int q_head = blockIdx.x; // 0 .. num_q_heads-1
int tid = threadIdx.x;
int kv_len = context_lens[seq_idx];
if (kv_len <= 0) {
// Nothing to attend over; zero output for safety.
if (tid < head_dim) {
O[((long long)seq_idx * num_q_heads + q_head) * head_dim + tid] =
__float2bfloat16(0.0f);
}
return;
}
// GQA mapping
int heads_per_group = num_q_heads / num_kv_heads;
int kv_head = q_head / heads_per_group;
// Pointers
const __nv_bfloat16* Q_ptr = Q +
((long long)seq_idx * num_q_heads + q_head) * head_dim;
__nv_bfloat16* O_ptr = O +
((long long)seq_idx * num_q_heads + q_head) * head_dim;
const int* bt = block_tables + (long long)seq_idx * max_blocks_per_seq;
// Load Q vector into registers.
float q_reg[PAGED_HEAD_DIM_MAX];
for (int d = 0; d < head_dim; d++) {
q_reg[d] = __bfloat162float(Q_ptr[d]);
}
// Per-thread online softmax state.
float local_max = -INFINITY;
float local_sum = 0.0f;
float local_O[PAGED_HEAD_DIM_MAX];
for (int d = 0; d < head_dim; d++) local_O[d] = 0.0f;
int kv_stride_block = num_kv_heads * PAGED_BLOCK_SIZE * head_dim;
int kv_stride_head = PAGED_BLOCK_SIZE * head_dim;
// Each thread handles positions tid, tid+PAGED_THREADS, ...
for (int pos = tid; pos < kv_len; pos += PAGED_THREADS) {
int logical_blk = pos / PAGED_BLOCK_SIZE;
int slot_in_blk = pos % PAGED_BLOCK_SIZE;
int phys_blk = bt[logical_blk];
const __nv_bfloat16* K_pos = K_cache
+ (long long)phys_blk * kv_stride_block
+ kv_head * kv_stride_head
+ slot_in_blk * head_dim;
const __nv_bfloat16* V_pos = V_cache
+ (long long)phys_blk * kv_stride_block
+ kv_head * kv_stride_head
+ slot_in_blk * head_dim;
// dot(Q, K[pos]) * scale
float dot = 0.0f;
for (int d = 0; d < head_dim; d++) {
dot += q_reg[d] * __bfloat162float(K_pos[d]);
}
float s = dot * scale;
float new_max = fmaxf(local_max, s);
float correction = expf(local_max - new_max);
float p = expf(s - new_max);
local_sum = local_sum * correction + p;
for (int d = 0; d < head_dim; d++) local_O[d] *= correction;
// Accumulate weighted V.
for (int d = 0; d < head_dim; d++) {
local_O[d] += p * __bfloat162float(V_pos[d]);
}
local_max = new_max;
}
// ---- Block-level online softmax reduction ----
__shared__ float smem_max[32];
__shared__ float smem_sum[32];
__shared__ float smem_O[PAGED_HEAD_DIM_MAX];
int lane = tid & 31;
int warp_id = tid >> 5;
int num_warps = PAGED_THREADS >> 5;
// Step 1: block-wide max
float warp_max = local_max;
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
warp_max = fmaxf(warp_max, __shfl_down_sync(0xffffffff, warp_max, offset));
if (lane == 0) smem_max[warp_id] = warp_max;
__syncthreads();
float global_max;
if (tid == 0) {
global_max = smem_max[0];
for (int i = 1; i < num_warps; i++)
global_max = fmaxf(global_max, smem_max[i]);
smem_max[0] = global_max;
}
__syncthreads();
global_max = smem_max[0];
// Step 2: rescale local state to global_max
float rescale = (local_max == -INFINITY) ? 0.0f : expf(local_max - global_max);
local_sum *= rescale;
for (int d = 0; d < head_dim; d++) local_O[d] *= rescale;
// Step 3: reduce sum
float warp_sum = local_sum;
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset);
if (lane == 0) smem_sum[warp_id] = warp_sum;
__syncthreads();
float global_sum;
if (tid == 0) {
global_sum = 0.0f;
for (int i = 0; i < num_warps; i++) global_sum += smem_sum[i];
smem_sum[0] = global_sum;
}
__syncthreads();
global_sum = smem_sum[0];
// Step 4: reduce O across block, dim by dim
for (int d = tid; d < head_dim; d += PAGED_THREADS) smem_O[d] = 0.0f;
__syncthreads();
for (int d = 0; d < head_dim; d++) {
float val = local_O[d];
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
val += __shfl_down_sync(0xffffffff, val, offset);
if (lane == 0) atomicAdd(&smem_O[d], val);
}
__syncthreads();
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
for (int d = tid; d < head_dim; d += PAGED_THREADS) {
O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
}
}
extern "C" {
void launch_paged_decode_attention_bf16(
const void* Q,
const void* K_cache,
const void* V_cache,
void* O,
const int* block_tables,
const int* context_lens,
int batch, int num_q_heads, int num_kv_heads,
int head_dim, int max_blocks_per_seq,
float scale, void* stream
) {
dim3 grid(num_q_heads, batch);
int block = PAGED_THREADS;
paged_decode_attention_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)Q,
(const __nv_bfloat16*)K_cache,
(const __nv_bfloat16*)V_cache,
(__nv_bfloat16*)O,
block_tables, context_lens,
num_q_heads, num_kv_heads,
head_dim, max_blocks_per_seq,
scale
);
CUDA_CHECK_LAST_ERROR();
}
}

View File

@@ -48,3 +48,17 @@ __device__ __forceinline__ float block_reduce_max(float val) {
if (warp_id == 0) val = warp_reduce_max(val);
return val;
}
// --- Launch error checking (debug builds only) ---
#ifdef NDEBUG
#define CUDA_CHECK_LAST_ERROR() ((void)0)
#else
#include <cstdio>
#define CUDA_CHECK_LAST_ERROR() do { \
cudaError_t err = cudaGetLastError(); \
if (err != cudaSuccess) { \
fprintf(stderr, "CUDA kernel launch error at %s:%d: %s\n", \
__FILE__, __LINE__, cudaGetErrorString(err)); \
} \
} while(0)
#endif

View File

@@ -1,4 +1,5 @@
#include <cuda_bf16.h>
#include "../common.cuh"
// Embedding lookup: out[seq_idx] = table[token_ids[seq_idx]]
// Grid: num_tokens, Block: handles hidden_size elements per token.
@@ -7,10 +8,12 @@ __global__ void embedding_f32(
const float* __restrict__ table, // [vocab_size, hidden_size]
const int* __restrict__ token_ids, // [num_tokens]
float* __restrict__ out, // [num_tokens, hidden_size]
int hidden_size
int hidden_size,
int vocab_size
) {
int token_idx = blockIdx.x;
int tid = token_ids[token_idx];
if (tid < 0 || tid >= vocab_size) return;
const float* row = table + tid * hidden_size;
float* dst = out + token_idx * hidden_size;
@@ -23,10 +26,12 @@ __global__ void embedding_bf16(
const __nv_bfloat16* __restrict__ table,
const int* __restrict__ token_ids,
__nv_bfloat16* __restrict__ out,
int hidden_size
int hidden_size,
int vocab_size
) {
int token_idx = blockIdx.x;
int tid = token_ids[token_idx];
if (tid < 0 || tid >= vocab_size) return;
const __nv_bfloat16* row = table + tid * hidden_size;
__nv_bfloat16* dst = out + token_idx * hidden_size;
@@ -38,18 +43,20 @@ __global__ void embedding_bf16(
extern "C" {
void launch_embedding_f32(const void* table, const void* token_ids, void* out,
int num_tokens, int hidden_size, void* stream) {
int num_tokens, int hidden_size, int vocab_size, void* stream) {
int block = (hidden_size < 256) ? hidden_size : 256;
embedding_f32<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
(const float*)table, (const int*)token_ids, (float*)out, hidden_size);
(const float*)table, (const int*)token_ids, (float*)out, hidden_size, vocab_size);
CUDA_CHECK_LAST_ERROR();
}
void launch_embedding_bf16(const void* table, const void* token_ids, void* out,
int num_tokens, int hidden_size, void* stream) {
int num_tokens, int hidden_size, int vocab_size, void* stream) {
int block = (hidden_size < 256) ? hidden_size : 256;
embedding_bf16<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)table, (const int*)token_ids,
(__nv_bfloat16*)out, hidden_size);
(__nv_bfloat16*)out, hidden_size, vocab_size);
CUDA_CHECK_LAST_ERROR();
}
}

View File

@@ -1,10 +1,11 @@
#include <cuda_bf16.h>
#include <math.h>
#include "../common.cuh"
// RoPE: Rotary Position Embedding
// For each pair (x[2i], x[2i+1]) at position `pos`:
// y[2i] = x[2i] * cos - x[2i+1] * sin
// y[2i+1] = x[2i] * sin + x[2i+1] * cos
// RoPE: Rotary Position Embedding, using the Qwen/Llama rotate_half layout.
// For each dimension i in the first half at position `pos`:
// y[i] = x[i] * cos - x[i + half_dim] * sin
// y[i + half_dim] = x[i + half_dim] * cos + x[i] * sin
// where cos/sin come from precomputed cos_cache/sin_cache.
//
// cos_cache[pos][i] = cos(pos * freq[i])
@@ -35,11 +36,11 @@ __global__ void rope_f32(
float sin_val = sin_cache[pos * half_dim + pair_idx];
int base = (token_idx * num_heads + head_idx) * head_dim;
float x0 = x[base + 2 * pair_idx];
float x1 = x[base + 2 * pair_idx + 1];
float x0 = x[base + pair_idx];
float x1 = x[base + pair_idx + half_dim];
x[base + 2 * pair_idx] = x0 * cos_val - x1 * sin_val;
x[base + 2 * pair_idx + 1] = x0 * sin_val + x1 * cos_val;
x[base + pair_idx] = x0 * cos_val - x1 * sin_val;
x[base + pair_idx + half_dim] = x1 * cos_val + x0 * sin_val;
}
__global__ void rope_bf16(
@@ -61,11 +62,11 @@ __global__ void rope_bf16(
float sin_val = sin_cache[pos * half_dim + pair_idx];
int base = (token_idx * num_heads + head_idx) * head_dim;
float x0 = __bfloat162float(x[base + 2 * pair_idx]);
float x1 = __bfloat162float(x[base + 2 * pair_idx + 1]);
float x0 = __bfloat162float(x[base + pair_idx]);
float x1 = __bfloat162float(x[base + pair_idx + half_dim]);
x[base + 2 * pair_idx] = __float2bfloat16(x0 * cos_val - x1 * sin_val);
x[base + 2 * pair_idx + 1] = __float2bfloat16(x0 * sin_val + x1 * cos_val);
x[base + pair_idx] = __float2bfloat16(x0 * cos_val - x1 * sin_val);
x[base + pair_idx + half_dim] = __float2bfloat16(x1 * cos_val + x0 * sin_val);
}
// Precompute cos/sin cache on GPU
@@ -94,6 +95,7 @@ void launch_rope_f32(void* x, const void* cos_cache, const void* sin_cache,
rope_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
(float*)x, (const float*)cos_cache, (const float*)sin_cache,
(const int*)positions, num_heads, head_dim);
CUDA_CHECK_LAST_ERROR();
}
void launch_rope_bf16(void* x, const void* cos_cache, const void* sin_cache,
@@ -104,6 +106,7 @@ void launch_rope_bf16(void* x, const void* cos_cache, const void* sin_cache,
rope_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(__nv_bfloat16*)x, (const float*)cos_cache, (const float*)sin_cache,
(const int*)positions, num_heads, head_dim);
CUDA_CHECK_LAST_ERROR();
}
void launch_compute_rope_cache(void* cos_cache, void* sin_cache,
@@ -111,6 +114,7 @@ void launch_compute_rope_cache(void* cos_cache, void* sin_cache,
void* stream) {
compute_rope_cache<<<max_seq_len, half_dim, 0, (cudaStream_t)stream>>>(
(float*)cos_cache, (float*)sin_cache, max_seq_len, half_dim, theta);
CUDA_CHECK_LAST_ERROR();
}
}

View File

@@ -1,4 +1,5 @@
#include <cuda_bf16.h>
#include "../common.cuh"
// Transpose between [S, H, D] and [H, S, D] layouts (used for RoPE and attention).
// Also handles [S, H*D] → [H, S, D] (reshape_heads) and reverse (merge_heads).
@@ -169,6 +170,7 @@ void launch_reshape_heads_bf16(const void* in, void* out,
int grid = (total + block - 1) / block;
reshape_heads_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
CUDA_CHECK_LAST_ERROR();
}
void launch_merge_heads_bf16(const void* in, void* out,
@@ -178,6 +180,7 @@ void launch_merge_heads_bf16(const void* in, void* out,
int grid = (total + block - 1) / block;
merge_heads_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
CUDA_CHECK_LAST_ERROR();
}
void launch_transpose_hsd_to_shd_bf16(const void* in, void* out,
@@ -187,6 +190,7 @@ void launch_transpose_hsd_to_shd_bf16(const void* in, void* out,
int grid = (total + block - 1) / block;
transpose_hsd_to_shd_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
CUDA_CHECK_LAST_ERROR();
}
void launch_transpose_shd_to_hsd_bf16(const void* in, void* out,
@@ -196,6 +200,7 @@ void launch_transpose_shd_to_hsd_bf16(const void* in, void* out,
int grid = (total + block - 1) / block;
transpose_shd_to_hsd_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
CUDA_CHECK_LAST_ERROR();
}
void launch_repeat_kv_bf16(const void* in, void* out,
@@ -205,6 +210,7 @@ void launch_repeat_kv_bf16(const void* in, void* out,
int grid = (total + block - 1) / block;
repeat_kv_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, kv_heads, n_rep, seq_len, head_dim);
CUDA_CHECK_LAST_ERROR();
}
void launch_strided_copy_bf16(const void* in, void* out, int numel, int ndim,
@@ -217,6 +223,7 @@ void launch_strided_copy_bf16(const void* in, void* out, int numel, int ndim,
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, numel, ndim,
shape0, shape1, shape2, shape3,
in_stride0, in_stride1, in_stride2, in_stride3, in_offset);
CUDA_CHECK_LAST_ERROR();
}
void launch_strided_copy_f32(const void* in, void* out, int numel, int ndim,
@@ -229,6 +236,7 @@ void launch_strided_copy_f32(const void* in, void* out, int numel, int ndim,
(const float*)in, (float*)out, numel, ndim,
shape0, shape1, shape2, shape3,
in_stride0, in_stride1, in_stride2, in_stride3, in_offset);
CUDA_CHECK_LAST_ERROR();
}
}

View File

@@ -1,5 +1,6 @@
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include "../common.cuh"
// Custom GEMV kernel for M=1 decode step (BF16):
// y[n] = sum_k x[k] * W[k * N + n]
@@ -88,6 +89,7 @@ void launch_gemv_bf16(
(float*)y_fp32_buf,
K, N
);
CUDA_CHECK_LAST_ERROR();
// Convert FP32 -> BF16
int conv_block = 256;
@@ -97,6 +99,7 @@ void launch_gemv_bf16(
(__nv_bfloat16*)y_bf16,
N
);
CUDA_CHECK_LAST_ERROR();
}
} // extern "C"

View File

@@ -1,4 +1,5 @@
#include <cuda_bf16.h>
#include "../common.cuh"
// Naive GEMM: each thread computes one element of C.
// C[i][j] = sum_k A[i][k] * B[k][j]
@@ -46,6 +47,7 @@ void launch_gemm_naive_bf16(
gemm_naive_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
);
CUDA_CHECK_LAST_ERROR();
}
void launch_gemm_naive_f32(
@@ -57,6 +59,7 @@ void launch_gemm_naive_f32(
gemm_naive_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
(const float*)A, (const float*)B, (float*)C, M, N, K
);
CUDA_CHECK_LAST_ERROR();
}
} // extern "C"

View File

@@ -1,4 +1,5 @@
#include <cuda_bf16.h>
#include "../common.cuh"
// Tiled GEMM using shared memory.
// Each thread block loads TILE_SIZE x TILE_SIZE tiles of A and B
@@ -100,6 +101,7 @@ void launch_gemm_tiled_f32(
gemm_tiled_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
(const float*)A, (const float*)B, (float*)C, M, N, K
);
CUDA_CHECK_LAST_ERROR();
}
void launch_gemm_tiled_bf16(
@@ -111,6 +113,7 @@ void launch_gemm_tiled_bf16(
gemm_tiled_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
);
CUDA_CHECK_LAST_ERROR();
}
} // extern "C"

View File

@@ -105,6 +105,7 @@ void launch_layernorm_f32(const void* x, const void* gamma, const void* beta,
layernorm_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
(const float*)x, (const float*)gamma, (const float*)beta,
(float*)out, hidden_size, eps);
CUDA_CHECK_LAST_ERROR();
}
void launch_layernorm_bf16(const void* x, const void* gamma, const void* beta,
@@ -114,6 +115,7 @@ void launch_layernorm_bf16(const void* x, const void* gamma, const void* beta,
layernorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (const __nv_bfloat16*)gamma, (const __nv_bfloat16*)beta,
(__nv_bfloat16*)out, hidden_size, eps);
CUDA_CHECK_LAST_ERROR();
}
}

View File

@@ -111,6 +111,7 @@ void launch_rmsnorm_f32(const void* x, const void* gamma, void* out,
if (block < 32) block = 32;
rmsnorm_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
(const float*)x, (const float*)gamma, (float*)out, hidden_size, eps);
CUDA_CHECK_LAST_ERROR();
}
void launch_rmsnorm_bf16(const void* x, const void* gamma, void* out,
@@ -120,6 +121,7 @@ void launch_rmsnorm_bf16(const void* x, const void* gamma, void* out,
rmsnorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (const __nv_bfloat16*)gamma,
(__nv_bfloat16*)out, hidden_size, eps);
CUDA_CHECK_LAST_ERROR();
}
void launch_add_rmsnorm_bf16(const void* x, const void* residual, const void* gamma,
@@ -132,6 +134,7 @@ void launch_add_rmsnorm_bf16(const void* x, const void* residual, const void* ga
(const __nv_bfloat16*)gamma,
(__nv_bfloat16*)normed_out, (__nv_bfloat16*)sum_out,
hidden_size, eps);
CUDA_CHECK_LAST_ERROR();
}
}

View File

@@ -94,6 +94,7 @@ void launch_softmax_f32(const void* x, void* out, int rows, int cols, void* stre
if (block < 32) block = 32;
softmax_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
(const float*)x, (float*)out, cols);
CUDA_CHECK_LAST_ERROR();
}
void launch_softmax_bf16(const void* x, void* out, int rows, int cols, void* stream) {
@@ -101,6 +102,7 @@ void launch_softmax_bf16(const void* x, void* out, int rows, int cols, void* str
if (block < 32) block = 32;
softmax_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, cols);
CUDA_CHECK_LAST_ERROR();
}
}