CUDA layer for the paged-KV + swap work: - csrc: new paged_attention.cu plus updates across the attention/gemm/norm/activation/embedding/reduce kernels and common.cuh. - xserv-kernels: new dispatch module and kernel-binding updates. - xserv-cuda: cudaMallocHost/cudaFreeHost bindings + PinnedBuffer (backing for the host swap pool) and offset-aware D2H/H2D copies used to move KV blocks between the GPU pool and pinned host memory. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
243 lines
8.7 KiB
Plaintext
243 lines
8.7 KiB
Plaintext
#include <cuda_bf16.h>
|
|
#include "../common.cuh"
|
|
|
|
// Transpose between [S, H, D] and [H, S, D] layouts (used for RoPE and attention).
|
|
// Also handles [S, H*D] → [H, S, D] (reshape_heads) and the reverse, [H, S, D] → [S, H*D] (merge_heads).
|
|
|
|
// reshape_heads_bf16: split the packed hidden dimension into heads.
//   in : [S, H*D]      -- element (s, h, d) lives at s*(H*D) + h*D + d
//   out: [1, H, S, D]  -- element (s, h, d) lives at h*(S*D) + s*D + d
// Launch: 1-D grid with at least S*H*D threads in total; surplus threads exit.
// Threads are enumerated over the *source* position, so consecutive threads
// read consecutive input elements; writes are contiguous within each D-run.
// The index mapping is the same as transpose_shd_to_hsd_bf16 (an [S, H*D]
// buffer has the same memory layout as [S, H, D]); only the enumeration order
// of threads differs.
__global__ void reshape_heads_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int seq_len, int num_heads, int head_dim
) {
    const int src = blockIdx.x * blockDim.x + threadIdx.x;
    const int row_len = num_heads * head_dim;  // elements per token row of `in`
    if (src >= seq_len * row_len) return;

    // Split the source position into (token, head, element-within-head).
    const int tok  = src / row_len;
    const int col  = src - tok * row_len;      // == src % row_len
    const int head = col / head_dim;
    const int elem = col - head * head_dim;    // == col % head_dim

    const int dst = (head * seq_len + tok) * head_dim + elem;
    out[dst] = in[src];
}
|
|
|
|
// merge_heads_bf16: concatenate heads back into a packed hidden dimension
// (the inverse mapping of reshape_heads_bf16).
//   in : [1, H, S, D]  -- element (s, h, d) lives at h*(S*D) + s*D + d
//   out: [S, H*D]      -- element (s, h, d) lives at s*(H*D) + h*D + d
// Launch: 1-D grid with at least S*H*D threads in total; surplus threads exit.
// Threads are enumerated over the *destination* position, so consecutive
// threads write consecutive output elements.
__global__ void merge_heads_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int seq_len, int num_heads, int head_dim
) {
    const int dst = blockIdx.x * blockDim.x + threadIdx.x;
    const int row_len = num_heads * head_dim;  // elements per token row of `out`
    if (dst >= seq_len * row_len) return;

    // Split the destination position into (token, head, element-within-head).
    const int tok  = dst / row_len;
    const int col  = dst - tok * row_len;      // == dst % row_len
    const int head = col / head_dim;
    const int elem = col - head * head_dim;    // == col % head_dim

    out[dst] = in[(head * seq_len + tok) * head_dim + elem];
}
|
|
|
|
// transpose_hsd_to_shd_bf16 ("transpose_for_rope"): [1, H, S, D] → [S, H, D].
//   in : element (h, s, d) at h*S*D + s*D + d
//   out: element (s, h, d) at s*H*D + h*D + d
// Launch: 1-D grid with at least S*H*D threads in total; surplus threads exit.
// One thread per output element (enumerated over the destination layout).
// Same index mapping as merge_heads_bf16.
__global__ void transpose_hsd_to_shd_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int seq_len, int num_heads, int head_dim
) {
    const int out_pos = blockIdx.x * blockDim.x + threadIdx.x;
    const int count = seq_len * num_heads * head_dim;
    if (out_pos >= count) return;

    // Peel dimensions from the fastest-varying outward:
    // out_pos = (s * H + h) * D + d.
    const int d  = out_pos % head_dim;
    const int sh = out_pos / head_dim;   // s * H + h
    const int h  = sh % num_heads;
    const int s  = sh / num_heads;

    out[out_pos] = in[(h * seq_len + s) * head_dim + d];
}
|
|
|
|
// transpose_shd_to_hsd_bf16 ("transpose_from_rope"): [S, H, D] → [1, H, S, D].
//   in : element (s, h, d) at s*H*D + h*D + d
//   out: element (h, s, d) at h*S*D + s*D + d
// Launch: 1-D grid with at least S*H*D threads in total; surplus threads exit.
// One thread per output element (enumerated over the destination layout).
// Same index mapping as reshape_heads_bf16, which instead enumerates threads
// over the source layout.
__global__ void transpose_shd_to_hsd_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int seq_len, int num_heads, int head_dim
) {
    const int out_pos = blockIdx.x * blockDim.x + threadIdx.x;
    const int count = seq_len * num_heads * head_dim;
    if (out_pos >= count) return;

    // Peel dimensions from the fastest-varying outward:
    // out_pos = (h * S + s) * D + d.
    const int d  = out_pos % head_dim;
    const int hs = out_pos / head_dim;   // h * S + s
    const int s  = hs % seq_len;
    const int h  = hs / seq_len;

    out[out_pos] = in[(s * num_heads + h) * head_dim + d];
}
|
|
|
|
// repeat_kv_bf16: [1, KV_H, S, D] → [1, KV_H * n_rep, S, D].
// Output head j is a copy of input head j / n_rep, so each input head is
// replicated into n_rep consecutive output heads.
// Launch: 1-D grid with at least KV_H * n_rep * S * D threads in total;
// surplus threads exit. One thread per output element.
__global__ void repeat_kv_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int kv_heads, int n_rep, int seq_len, int head_dim
) {
    const int plane = seq_len * head_dim;  // elements in one [S, D] head slab
    const int out_pos = blockIdx.x * blockDim.x + threadIdx.x;
    if (out_pos >= kv_heads * n_rep * plane) return;

    const int dst_head = out_pos / plane;
    const int offset   = out_pos - dst_head * plane;  // == out_pos % plane
    const int src_head = dst_head / n_rep;

    out[out_pos] = in[src_head * plane + offset];
}
|
|
|
|
// ---- Generic strided copy (up to 4D) ----
|
|
// Each thread copies one element. Maps flat contiguous output index to strided input index.
|
|
// Unused dimensions are padded with shape=1, stride=0.
|
|
|
|
// strided_copy_bf16: gather a strided (up to 4-D) view of `in` into the
// contiguous buffer `out`.
//
// Launch: 1-D grid, one thread per output element; threads with idx >= numel
// exit immediately.
// Layout: `out` is written densely in row-major order over
// [shape0, shape1, shape2, shape3] (dim 3 fastest). Output element
// (i0, i1, i2, i3) is read from
//   in[in_offset + i0*in_stride0 + i1*in_stride1 + i2*in_stride2 + i3*in_stride3].
// Shapes, strides and the offset are all in elements.
// Preconditions: unused dimensions padded with shape = 1, stride = 0; shape1,
// shape2 and shape3 must be non-zero whenever numel > 0 (they are divisors).
// shape0 is not read (the outermost index is whatever remains after peeling
// dims 3..1), and `ndim` is accepted but not read by the kernel.
//
// The source index is accumulated in 64-bit: a view into a large buffer can
// have index*stride sums beyond INT_MAX even though numel, every stride and
// in_offset individually fit in int.
__global__ void strided_copy_bf16(
    const __nv_bfloat16* __restrict__ in,
    __nv_bfloat16* __restrict__ out,
    int numel,
    int ndim,
    int shape0, int shape1, int shape2, int shape3,
    int in_stride0, int in_stride1, int in_stride2, int in_stride3,
    int in_offset
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= numel) return;

    // Decompose flat output index into multi-dim indices (rightmost = fastest).
    int remaining = idx;
    int i3 = remaining % shape3; remaining /= shape3;
    int i2 = remaining % shape2; remaining /= shape2;
    int i1 = remaining % shape1; remaining /= shape1;
    int i0 = remaining;

    // 64-bit accumulation so large strided views do not overflow int.
    long long in_idx = (long long)in_offset
                     + (long long)i0 * in_stride0
                     + (long long)i1 * in_stride1
                     + (long long)i2 * in_stride2
                     + (long long)i3 * in_stride3;

    out[idx] = in[in_idx];
}
|
|
|
|
// strided_copy_f32: float32 twin of strided_copy_bf16 — gather a strided
// (up to 4-D) view of `in` into the contiguous buffer `out`.
//
// Launch: 1-D grid, one thread per output element; threads with idx >= numel
// exit immediately.
// Layout: `out` is written densely in row-major order over
// [shape0, shape1, shape2, shape3] (dim 3 fastest). Output element
// (i0, i1, i2, i3) is read from
//   in[in_offset + i0*in_stride0 + i1*in_stride1 + i2*in_stride2 + i3*in_stride3].
// Shapes, strides and the offset are all in elements.
// Preconditions: unused dimensions padded with shape = 1, stride = 0; shape1,
// shape2 and shape3 must be non-zero whenever numel > 0 (they are divisors).
// shape0 is not read, and `ndim` is accepted but not read by the kernel.
//
// The source index is accumulated in 64-bit: a view into a large buffer can
// have index*stride sums beyond INT_MAX even though numel, every stride and
// in_offset individually fit in int.
__global__ void strided_copy_f32(
    const float* __restrict__ in,
    float* __restrict__ out,
    int numel,
    int ndim,
    int shape0, int shape1, int shape2, int shape3,
    int in_stride0, int in_stride1, int in_stride2, int in_stride3,
    int in_offset
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= numel) return;

    // Decompose flat output index into multi-dim indices (rightmost = fastest).
    int remaining = idx;
    int i3 = remaining % shape3; remaining /= shape3;
    int i2 = remaining % shape2; remaining /= shape2;
    int i1 = remaining % shape1; remaining /= shape1;
    int i0 = remaining;

    // 64-bit accumulation so large strided views do not overflow int.
    long long in_idx = (long long)in_offset
                     + (long long)i0 * in_stride0
                     + (long long)i1 * in_stride1
                     + (long long)i2 * in_stride2
                     + (long long)i3 * in_stride3;

    out[idx] = in[in_idx];
}
|
|
|
|
extern "C" {

// Host-side launchers. Each takes untyped device pointers plus an opaque
// stream handle (cast to cudaStream_t), launches a 1-D grid with one thread
// per output element on that stream, then invokes CUDA_CHECK_LAST_ERROR()
// (from common.cuh).
//
// Zero-element requests return without launching: a grid of zero blocks is an
// invalid launch configuration, so an empty tensor would otherwise be reported
// as a CUDA error. Negative element counts are NOT skipped — they still reach
// the launch (with a zero-block grid) so the failure is not silently hidden.

// Threads per block used by every launcher below.
static const int kCopyBlockSize = 256;

// Number of blocks that gives each of `n` elements its own thread; 0 when
// n <= 0. Written as (n - 1) / B + 1 so it cannot overflow for n near INT_MAX
// (the usual (n + B - 1) / B form wraps negative there).
static inline int copy_grid_size(int n) {
    return n > 0 ? (n - 1) / kCopyBlockSize + 1 : 0;
}

// [S, H*D] → [1, H, S, D]; see reshape_heads_bf16.
void launch_reshape_heads_bf16(const void* in, void* out,
                               int seq_len, int num_heads, int head_dim, void* stream) {
    int total = seq_len * num_heads * head_dim;
    if (total == 0) return;  // nothing to copy; a 0-block launch is invalid
    reshape_heads_bf16<<<copy_grid_size(total), kCopyBlockSize, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
    CUDA_CHECK_LAST_ERROR();
}

// [1, H, S, D] → [S, H*D]; see merge_heads_bf16.
void launch_merge_heads_bf16(const void* in, void* out,
                             int seq_len, int num_heads, int head_dim, void* stream) {
    int total = seq_len * num_heads * head_dim;
    if (total == 0) return;
    merge_heads_bf16<<<copy_grid_size(total), kCopyBlockSize, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
    CUDA_CHECK_LAST_ERROR();
}

// [1, H, S, D] → [S, H, D]; see transpose_hsd_to_shd_bf16.
void launch_transpose_hsd_to_shd_bf16(const void* in, void* out,
                                      int seq_len, int num_heads, int head_dim, void* stream) {
    int total = seq_len * num_heads * head_dim;
    if (total == 0) return;
    transpose_hsd_to_shd_bf16<<<copy_grid_size(total), kCopyBlockSize, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
    CUDA_CHECK_LAST_ERROR();
}

// [S, H, D] → [1, H, S, D]; see transpose_shd_to_hsd_bf16.
void launch_transpose_shd_to_hsd_bf16(const void* in, void* out,
                                      int seq_len, int num_heads, int head_dim, void* stream) {
    int total = seq_len * num_heads * head_dim;
    if (total == 0) return;
    transpose_shd_to_hsd_bf16<<<copy_grid_size(total), kCopyBlockSize, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
    CUDA_CHECK_LAST_ERROR();
}

// [1, KV_H, S, D] → [1, KV_H * n_rep, S, D]; see repeat_kv_bf16.
void launch_repeat_kv_bf16(const void* in, void* out,
                           int kv_heads, int n_rep, int seq_len, int head_dim, void* stream) {
    int total = kv_heads * n_rep * seq_len * head_dim;
    if (total == 0) return;
    repeat_kv_bf16<<<copy_grid_size(total), kCopyBlockSize, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, kv_heads, n_rep, seq_len, head_dim);
    CUDA_CHECK_LAST_ERROR();
}

// Strided (up to 4-D) view → contiguous bf16 buffer; see strided_copy_bf16.
void launch_strided_copy_bf16(const void* in, void* out, int numel, int ndim,
                              int shape0, int shape1, int shape2, int shape3,
                              int in_stride0, int in_stride1, int in_stride2, int in_stride3,
                              int in_offset, void* stream) {
    if (numel == 0) return;
    strided_copy_bf16<<<copy_grid_size(numel), kCopyBlockSize, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, numel, ndim,
        shape0, shape1, shape2, shape3,
        in_stride0, in_stride1, in_stride2, in_stride3, in_offset);
    CUDA_CHECK_LAST_ERROR();
}

// Strided (up to 4-D) view → contiguous f32 buffer; see strided_copy_f32.
void launch_strided_copy_f32(const void* in, void* out, int numel, int ndim,
                             int shape0, int shape1, int shape2, int shape3,
                             int in_stride0, int in_stride1, int in_stride2, int in_stride3,
                             int in_offset, void* stream) {
    if (numel == 0) return;
    strided_copy_f32<<<copy_grid_size(numel), kCopyBlockSize, 0, (cudaStream_t)stream>>>(
        (const float*)in, (float*)out, numel, ndim,
        shape0, shape1, shape2, shape3,
        in_stride0, in_stride1, in_stride2, in_stride3, in_offset);
    CUDA_CHECK_LAST_ERROR();
}

}
|