phase 5: naive multi-head attention

- Batched GEMM via cublasGemmStridedBatchedEx
- Causal mask CUDA kernel (F32 + BF16)
- Element-wise scale CUDA kernel (F32 + BF16)
- attention() composing: batched_matmul + scale + causal_mask + softmax
- Fixed to_device/contiguous infinite recursion (GPU contiguous via CPU round-trip)
- 5 attention tests passing (max_err < 3e-7 F32)
- Total: 61 tests passing across all crates

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 21:17:23 +08:00
parent c8e8153702
commit 6035ffdc0b
10 changed files with 550 additions and 12 deletions

View File

@@ -35,6 +35,16 @@ __global__ void silu_bf16(const __nv_bfloat16* x, __nv_bfloat16* out, int n) {
if (idx < n) out[idx] = __float2bfloat16(silu_f(__bfloat162float(x[idx])));
}
__global__ void scale_f32_kernel(const float* x, float* out, float scale, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) out[idx] = x[idx] * scale;
}
__global__ void scale_bf16_kernel(const __nv_bfloat16* x, __nv_bfloat16* out, float scale, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(x[idx]) * scale);
}
extern "C" {
void launch_gelu_f32(const void* x, void* out, int n, void* stream) {
@@ -63,4 +73,18 @@ void launch_silu_bf16(const void* x, void* out, int n, void* stream) {
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
}
void launch_scale_f32(const void* x, void* out, float scale, int n, void* stream) {
int block = 256;
int grid = (n + block - 1) / block;
scale_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const float*)x, (float*)out, scale, n);
}
void launch_scale_bf16(const void* x, void* out, float scale, int n, void* stream) {
int block = 256;
int grid = (n + block - 1) / block;
scale_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, scale, n);
}
}

View File

@@ -0,0 +1,53 @@
#include <cuda_bf16.h>
// Apply causal mask: set scores[row][col] = -inf where col > row + offset.
// offset is used for KV cache: when query starts at position `offset`,
// we allow attending to positions [0, offset + row].
// scores: [batch, rows, cols] (flattened batch×heads)
__global__ void causal_mask_f32(
float* __restrict__ scores,
int rows, int cols, int offset
) {
int batch_idx = blockIdx.z;
int row = blockIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (col < cols && col > row + offset) {
scores[batch_idx * rows * cols + row * cols + col] = -INFINITY;
}
}
__global__ void causal_mask_bf16(
__nv_bfloat16* __restrict__ scores,
int rows, int cols, int offset
) {
int batch_idx = blockIdx.z;
int row = blockIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (col < cols && col > row + offset) {
// BF16 doesn't have proper -inf literal, use a very large negative
scores[batch_idx * rows * cols + row * cols + col] = __float2bfloat16(-1e9f);
}
}
extern "C" {
void launch_causal_mask_f32(void* scores, int batch, int rows, int cols,
int offset, void* stream) {
int block = 256;
dim3 grid((cols + block - 1) / block, rows, batch);
causal_mask_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
(float*)scores, rows, cols, offset);
}
void launch_causal_mask_bf16(void* scores, int batch, int rows, int cols,
int offset, void* stream) {
int block = 256;
dim3 grid((cols + block - 1) / block, rows, batch);
causal_mask_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
(__nv_bfloat16*)scores, rows, cols, offset);
}
}