Files
xserv/csrc/gemm/tiled.cu
Gahow Wang 4c3f914459 kernels/cuda: paged-attention kernel, dispatch, pinned host memory
CUDA layer for the paged-KV + swap work:
- csrc: new paged_attention.cu plus updates across attention/gemm/norm/
  activation/embedding/reduce kernels and common.cuh.
- xserv-kernels: new dispatch module and kernel-binding updates.
- xserv-cuda: cudaMallocHost/FreeHost bindings + PinnedBuffer (host swap
  pool backing) and offset-aware D2H/H2D copies used to move KV blocks
  between the GPU pool and pinned host memory.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-28 19:58:36 +08:00

120 lines
3.3 KiB
Plaintext

#include <cuda_bf16.h>
#include "../common.cuh"
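// Tiled GEMM using shared memory.
// Each thread block computes one TILE_SIZE x TILE_SIZE tile of C. It walks the K
// dimension in TILE_SIZE steps, staging one TILE_SIZE x TILE_SIZE tile of A and one
// of B in shared memory per step; each thread then accumulates the partial dot
// product for its own output element from those staged tiles.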
#define TILE_SIZE 32
// Tiled single-precision GEMM: C = A * B.
//
// Layout (row-major, densely packed, no leading-dimension padding):
//   A is M x K, element (r, k) at A[r * K + k]
//   B is K x N, element (k, c) at B[k * N + c]
//   C is M x N, element (r, c) at C[r * N + c]
//
// Launch: blockDim = (TILE_SIZE, TILE_SIZE); gridDim.x tiles the N columns and
// gridDim.y tiles the M rows. Each thread produces one element of C. Uses
// 2 * TILE_SIZE * TILE_SIZE floats of static shared memory.
//
// Out-of-range tile entries are zero-filled, so M, N and K need not be multiples
// of TILE_SIZE. Element offsets are computed in size_t so that matrices with more
// than INT_MAX elements are still addressed correctly.
__global__ void gemm_tiled_f32(
    const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C,
    int M, int N, int K
) {
    __shared__ float As[TILE_SIZE][TILE_SIZE];
    __shared__ float Bs[TILE_SIZE][TILE_SIZE];

    const int row = blockIdx.y * TILE_SIZE + threadIdx.y;
    const int col = blockIdx.x * TILE_SIZE + threadIdx.x;
    float sum = 0.0f;

    const int num_tiles = (K + TILE_SIZE - 1) / TILE_SIZE;
    for (int t = 0; t < num_tiles; t++) {
        // Stage A(row, t*TILE_SIZE + tx); zero-pad past the edge of A.
        const int a_col = t * TILE_SIZE + threadIdx.x;
        As[threadIdx.y][threadIdx.x] = (row < M && a_col < K)
            ? A[static_cast<size_t>(row) * K + a_col]
            : 0.0f;

        // Stage B(t*TILE_SIZE + ty, col); zero-pad past the edge of B.
        const int b_row = t * TILE_SIZE + threadIdx.y;
        Bs[threadIdx.y][threadIdx.x] = (b_row < K && col < N)
            ? B[static_cast<size_t>(b_row) * N + col]
            : 0.0f;

        // Both tiles must be fully written before any thread reads them.
        __syncthreads();

        for (int k = 0; k < TILE_SIZE; k++) {
            sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        }

        // The next iteration overwrites As/Bs, so every thread must finish
        // reading the current tiles first.
        __syncthreads();
    }

    if (row < M && col < N) {
        C[static_cast<size_t>(row) * N + col] = sum;
    }
}
// Tiled bfloat16 GEMM: C = A * B with float32 accumulation.
//
// Layout (row-major, densely packed, no leading-dimension padding):
//   A is M x K, element (r, k) at A[r * K + k]
//   B is K x N, element (k, c) at B[k * N + c]
//   C is M x N, element (r, c) at C[r * N + c]
//
// Launch: blockDim = (TILE_SIZE, TILE_SIZE); gridDim.x tiles the N columns and
// gridDim.y tiles the M rows. Each thread produces one element of C. Uses
// 2 * TILE_SIZE * TILE_SIZE floats of static shared memory.
//
// Inputs are widened to float when staged into shared memory. The dot product is
// accumulated in float, and only the final value is rounded back to bfloat16
// (__float2bfloat16, round-to-nearest-even). Out-of-range tile entries are
// zero-filled, so M, N and K need not be multiples of TILE_SIZE. Element offsets
// are computed in size_t so that matrices with more than INT_MAX elements are
// still addressed correctly.
__global__ void gemm_tiled_bf16(
    const __nv_bfloat16* __restrict__ A, const __nv_bfloat16* __restrict__ B,
    __nv_bfloat16* __restrict__ C,
    int M, int N, int K
) {
    __shared__ float As[TILE_SIZE][TILE_SIZE];
    __shared__ float Bs[TILE_SIZE][TILE_SIZE];

    const int row = blockIdx.y * TILE_SIZE + threadIdx.y;
    const int col = blockIdx.x * TILE_SIZE + threadIdx.x;
    float sum = 0.0f;

    const int num_tiles = (K + TILE_SIZE - 1) / TILE_SIZE;
    for (int t = 0; t < num_tiles; t++) {
        // Stage A(row, t*TILE_SIZE + tx) as float; zero-pad past the edge of A.
        const int a_col = t * TILE_SIZE + threadIdx.x;
        As[threadIdx.y][threadIdx.x] = (row < M && a_col < K)
            ? __bfloat162float(A[static_cast<size_t>(row) * K + a_col])
            : 0.0f;

        // Stage B(t*TILE_SIZE + ty, col) as float; zero-pad past the edge of B.
        const int b_row = t * TILE_SIZE + threadIdx.y;
        Bs[threadIdx.y][threadIdx.x] = (b_row < K && col < N)
            ? __bfloat162float(B[static_cast<size_t>(b_row) * N + col])
            : 0.0f;

        // Both tiles must be fully written before any thread reads them.
        __syncthreads();

        for (int k = 0; k < TILE_SIZE; k++) {
            sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        }

        // The next iteration overwrites As/Bs, so every thread must finish
        // reading the current tiles first.
        __syncthreads();
    }

    if (row < M && col < N) {
        C[static_cast<size_t>(row) * N + col] = __float2bfloat16(sum);
    }
}
extern "C" {
// Host launcher for gemm_tiled_f32. It computes C (M x N) = A (M x K) * B (K x N),
// where all buffers are row-major float32 device pointers passed as void*.
// The kernel is enqueued on `stream`, which is a cudaStream_t passed as void*.
//
// Empty outputs (M <= 0 or N <= 0) return immediately without launching. There
// is nothing to write, and a grid with a zero dimension would be rejected by the
// runtime as an invalid launch configuration. K == 0 is launched normally, and
// every element of C is set to 0.
//
// NOTE(review): gridDim.y is limited to 65535, so M must not exceed
// 65535 * TILE_SIZE rows.
void launch_gemm_tiled_f32(
    const void* A, const void* B, void* C,
    int M, int N, int K, void* stream
) {
    if (M <= 0 || N <= 0) {
        return;
    }
    dim3 block(TILE_SIZE, TILE_SIZE);
    // x tiles the N output columns, y tiles the M output rows (matches blockIdx use in the kernel).
    dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
    gemm_tiled_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const float*)A, (const float*)B, (float*)C, M, N, K
    );
    CUDA_CHECK_LAST_ERROR();
}
// Host launcher for gemm_tiled_bf16. It computes C (M x N) = A (M x K) * B (K x N),
// where all buffers are row-major __nv_bfloat16 device pointers passed as void*.
// The kernel is enqueued on `stream`, which is a cudaStream_t passed as void*.
//
// Empty outputs (M <= 0 or N <= 0) return immediately without launching. There
// is nothing to write, and a grid with a zero dimension would be rejected by the
// runtime as an invalid launch configuration. K == 0 is launched normally, and
// every element of C is set to 0.
//
// NOTE(review): gridDim.y is limited to 65535, so M must not exceed
// 65535 * TILE_SIZE rows.
void launch_gemm_tiled_bf16(
    const void* A, const void* B, void* C,
    int M, int N, int K, void* stream
) {
    if (M <= 0 || N <= 0) {
        return;
    }
    dim3 block(TILE_SIZE, TILE_SIZE);
    // x tiles the N output columns, y tiles the M output rows (matches blockIdx use in the kernel).
    dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
    gemm_tiled_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
    );
    CUDA_CHECK_LAST_ERROR();
}
} // extern "C"