Add Mixture-of-Experts support for the gpt-oss-20b model (20.9B params, 32 experts × top-4 routing). Key additions: - ModelConfig: MoE fields (num_local_experts, layer_types, sliding_window, attention_bias, explicit head_dim, rope_scaling, swiglu_limit) - YaRN RoPE: RopeCache::new_yarn() with correct frequency interpolation and attention_scaling = 0.1*ln(factor)+1 - Custom GLU kernel: gpt_oss_glu_bf16 (clamped sigmoid gate activation) - Paged attention with sinks + sliding window kernel variant - GptOss model struct with expert-parallel TP (split 32 experts across ranks) - bench-gpt-oss binary for TP inference benchmarking Verified on dash5 with 2x RTX 5090: 63.6 tok/s decode, ~160ms TTFT. Model generates topically-coherent output (needs chat template for quality). Known issues: - Custom GEMV kernel produces NaN with small N (workaround: pad to M=2) - Prefill doesn't use attention sinks (uses standard flash attention) - Output quality requires chat template formatting Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
412 lines
14 KiB
Plaintext
412 lines
14 KiB
Plaintext
#include <cuda_bf16.h>
#include <float.h>
#include <stdio.h>

#include "../common.cuh"
|
|
|
|
// Paged decode attention kernel for BF16 with FP32 accumulation.
|
|
//
|
|
// Reads K/V from a paged pool indexed by a per-sequence block table.
|
|
// One CUDA block per (sequence, q_head). Each block streams over the
|
|
// sequence's KV positions and accumulates attention output via online
|
|
// softmax.
|
|
//
|
|
// Layouts:
|
|
// Q [batch, num_q_heads, 1, head_dim] BF16
|
|
// K_cache [num_blocks, num_kv_heads, BLOCK_SIZE, head_dim] BF16
|
|
// V_cache same
|
|
// block_tables [batch, max_blocks_per_seq] int32
//              — packed rows already gathered for the active batch:
//              row i lists the physical KV-cache blocks of the i-th
//              sequence in this launch, which is handled by the CUDA
//              blocks with blockIdx.y == i.
|
|
// context_lens [batch] int32 — number of valid tokens per sequence.
|
|
//
|
|
// Launch: grid = (num_q_heads, batch), 256 threads per block; head_dim must be <= 128.
|
|
|
|
// Tokens per KV-cache page; must match the BLOCK_SIZE axis of the
// [num_blocks, num_kv_heads, BLOCK_SIZE, head_dim] K/V cache layout.
#define PAGED_BLOCK_SIZE 16
// Threads per CUDA block for the paged decode kernels.
#define PAGED_THREADS 256
// Largest head_dim the kernels' fixed-size per-thread / shared buffers hold.
#define PAGED_HEAD_DIM_MAX 128

// Compile-time guards for invariants the kernels below rely on.
static_assert(PAGED_BLOCK_SIZE > 0, "PAGED_BLOCK_SIZE must be positive");
static_assert(PAGED_THREADS % 32 == 0,
              "PAGED_THREADS must be a whole number of warps (full-mask warp shuffles)");
static_assert(PAGED_THREADS / 32 <= 32,
              "per-warp reduction scratch (smem_max/smem_sum) holds at most 32 warps");
|
|
|
|
// Shared body of the paged decode attention kernels (BF16 in, FP32 accumulate).
//
// Launch contract (set up by the host launchers):
//   blockIdx.x = query head (0 .. num_q_heads-1)
//   blockIdx.y = sequence index within the launch (row of block_tables)
//   blockDim.x = PAGED_THREADS
//
// Each thread visits the attended KV positions start_pos + tid,
// start_pos + tid + PAGED_THREADS, ... and keeps its own online-softmax state
// (running max, running sum, unnormalised output). The block then merges those
// states through shared memory and writes one head_dim-long vector of O.
//
//   sinks       — optional [num_q_heads] BF16 per-head logits that take part in
//                 the softmax normalisation but carry no value vector; nullptr
//                 disables them. The sink logit is NOT multiplied by `scale`.
//   window_size — > 0 restricts attention to positions
//                 [kv_len - window_size, kv_len); <= 0 attends to all kv_len.
//
// Preconditions (not checked here): head_dim <= PAGED_HEAD_DIM_MAX,
// num_q_heads is a multiple of num_kv_heads, and block_tables holds a valid
// physical block for every logical block touched by the attended positions.
static __device__ __forceinline__ void paged_decode_attention_bf16_impl(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K_cache,
    const __nv_bfloat16* __restrict__ V_cache,
    __nv_bfloat16* __restrict__ O,
    const int* __restrict__ block_tables,
    const int* __restrict__ context_lens,
    const __nv_bfloat16* __restrict__ sinks,
    int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    float scale, int window_size
) {
    constexpr int num_warps = PAGED_THREADS / 32;

    const int seq_idx = blockIdx.y;
    const int q_head = blockIdx.x;
    const int tid = threadIdx.x;

    const long long head_offset =
        ((long long)seq_idx * num_q_heads + q_head) * head_dim;
    __nv_bfloat16* O_ptr = O + head_offset;

    const int kv_len = context_lens[seq_idx];
    if (kv_len <= 0) {
        // Nothing to attend over; zero the output for safety. kv_len is the
        // same for every thread of the block, so this return is block-uniform
        // and no thread is left waiting at a later __syncthreads().
        for (int d = tid; d < head_dim; d += PAGED_THREADS)
            O_ptr[d] = __float2bfloat16(0.0f);
        return;
    }

    // GQA: each run of heads_per_group consecutive query heads shares a KV head.
    const int heads_per_group = num_q_heads / num_kv_heads;
    const int kv_head = q_head / heads_per_group;

    const __nv_bfloat16* Q_ptr = Q + head_offset;
    const int* bt = block_tables + (long long)seq_idx * max_blocks_per_seq;

    // Sliding window: attend only to positions [kv_len - window_size, kv_len).
    int start_pos = 0;
    if (window_size > 0 && kv_len > window_size) start_pos = kv_len - window_size;

    float q_reg[PAGED_HEAD_DIM_MAX];
    for (int d = 0; d < head_dim; d++) q_reg[d] = __bfloat162float(Q_ptr[d]);

    // Per-thread online softmax state.
    float local_max = -INFINITY;
    float local_sum = 0.0f;
    float local_O[PAGED_HEAD_DIM_MAX];
    for (int d = 0; d < head_dim; d++) local_O[d] = 0.0f;

    const int kv_stride_block = num_kv_heads * PAGED_BLOCK_SIZE * head_dim;
    const int kv_stride_head = PAGED_BLOCK_SIZE * head_dim;

    for (int pos = start_pos + tid; pos < kv_len; pos += PAGED_THREADS) {
        // Translate the logical position into (physical page, slot in page).
        const int phys_blk = bt[pos / PAGED_BLOCK_SIZE];
        const long long kv_offset = (long long)phys_blk * kv_stride_block
                                  + kv_head * kv_stride_head
                                  + (pos % PAGED_BLOCK_SIZE) * head_dim;
        const __nv_bfloat16* K_pos = K_cache + kv_offset;
        const __nv_bfloat16* V_pos = V_cache + kv_offset;

        // s = dot(Q, K[pos]) * scale
        float dot = 0.0f;
        for (int d = 0; d < head_dim; d++)
            dot += q_reg[d] * __bfloat162float(K_pos[d]);
        const float s = dot * scale;

        // Online softmax update: rescale the old state to the new max, then
        // add this position's weighted V.
        const float new_max = fmaxf(local_max, s);
        const float correction = expf(local_max - new_max);
        const float p = expf(s - new_max);
        local_sum = local_sum * correction + p;
        for (int d = 0; d < head_dim; d++) {
            local_O[d] *= correction;
            local_O[d] += p * __bfloat162float(V_pos[d]);
        }
        local_max = new_max;
    }

    // Attention sink: an extra per-head logit in the softmax denominator with a
    // zero value vector. Only thread 0 folds it in, so it is counted once.
    // Thread 0 always visited start_pos (< kv_len), so local_max is finite here.
    if (sinks != nullptr && tid == 0) {
        const float sink_logit = __bfloat162float(sinks[q_head]);
        const float new_max = fmaxf(local_max, sink_logit);
        const float correction = expf(local_max - new_max);
        local_sum = local_sum * correction + expf(sink_logit - new_max);
        for (int d = 0; d < head_dim; d++) local_O[d] *= correction;
        local_max = new_max;
    }

    // ---- Block-level merge of the per-thread online softmax states ----
    __shared__ float smem_max[32];
    __shared__ float smem_sum[32];
    // One partial output row per warp; summed in fixed warp order below so the
    // result is bit-reproducible (shared-memory atomicAdd ordering is not).
    __shared__ float smem_O[num_warps][PAGED_HEAD_DIM_MAX];

    const int lane = tid & 31;
    const int warp_id = tid >> 5;

    // Step 1: block-wide maximum of the per-thread running maxima.
    float warp_max = local_max;
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_max = fmaxf(warp_max, __shfl_down_sync(0xffffffff, warp_max, offset));
    if (lane == 0) smem_max[warp_id] = warp_max;
    __syncthreads();
    if (tid == 0) {
        float m = smem_max[0];
        for (int w = 1; w < num_warps; w++) m = fmaxf(m, smem_max[w]);
        smem_max[0] = m;
    }
    __syncthreads();
    const float global_max = smem_max[0];

    // Step 2: rescale this thread's state to the block-wide max. Threads that
    // visited no position still hold local_max == -inf and contribute zero.
    const float rescale = (local_max == -INFINITY) ? 0.0f : expf(local_max - global_max);
    local_sum *= rescale;
    for (int d = 0; d < head_dim; d++) local_O[d] *= rescale;

    // Step 3: block-wide sum of the rescaled softmax denominators.
    float warp_sum = local_sum;
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset);
    if (lane == 0) smem_sum[warp_id] = warp_sum;
    __syncthreads();
    if (tid == 0) {
        float total = 0.0f;
        for (int w = 0; w < num_warps; w++) total += smem_sum[w];
        smem_sum[0] = total;
    }
    __syncthreads();
    const float global_sum = smem_sum[0];

    // Step 4: warp-reduce each output dim, park one partial per warp, then sum
    // the partials in fixed warp order and normalise.
    for (int d = 0; d < head_dim; d++) {
        float val = local_O[d];
#pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1)
            val += __shfl_down_sync(0xffffffff, val, offset);
        if (lane == 0) smem_O[warp_id][d] = val;
    }
    __syncthreads();

    const float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
    for (int d = tid; d < head_dim; d += PAGED_THREADS) {
        float acc = 0.0f;
        for (int w = 0; w < num_warps; w++) acc += smem_O[w][d];
        O_ptr[d] = __float2bfloat16(acc * inv_sum);
    }
}

// Paged decode attention over the full context of each sequence.
// grid = (num_q_heads, batch), block = PAGED_THREADS threads. See
// paged_decode_attention_bf16_impl for the algorithm and preconditions.
__global__ void __launch_bounds__(PAGED_THREADS) paged_decode_attention_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K_cache,
    const __nv_bfloat16* __restrict__ V_cache,
    __nv_bfloat16* __restrict__ O,
    const int* __restrict__ block_tables,   // [batch, max_blocks_per_seq]
    const int* __restrict__ context_lens,   // [batch]
    int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    float scale
) {
    paged_decode_attention_bf16_impl(
        Q, K_cache, V_cache, O, block_tables, context_lens,
        /*sinks=*/nullptr, num_q_heads, num_kv_heads,
        head_dim, max_blocks_per_seq, scale, /*window_size=*/0);
}

// Extended paged decode attention with attention sinks and a sliding window.
//   sinks       : [num_q_heads] BF16 per-head extra logit joined to the softmax
//                 (absorbs probability, contributes no value), or NULL.
//   window_size : > 0 = only attend to the last `window_size` positions, 0 = full.
// grid = (num_q_heads, batch), block = PAGED_THREADS threads.
__global__ void __launch_bounds__(PAGED_THREADS) paged_decode_attention_sinks_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K_cache,
    const __nv_bfloat16* __restrict__ V_cache,
    __nv_bfloat16* __restrict__ O,
    const int* __restrict__ block_tables,
    const int* __restrict__ context_lens,
    const __nv_bfloat16* __restrict__ sinks,  // [num_q_heads] or NULL
    int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    float scale, int window_size
) {
    paged_decode_attention_bf16_impl(
        Q, K_cache, V_cache, O, block_tables, context_lens,
        sinks, num_q_heads, num_kv_heads,
        head_dim, max_blocks_per_seq, scale, window_size);
}
|
|
|
|
// Host-side precondition check shared by the paged decode launchers.
//
// The kernels size their per-thread and shared buffers by PAGED_HEAD_DIM_MAX,
// map query heads to KV heads by integer division, and put the batch on
// gridDim.y. Launching outside those limits would overrun fixed-size arrays,
// divide by zero, or fail with an invalid-configuration error, so the problem
// is reported on stderr and the launch is skipped.
//
// Returns true when the launch parameters are usable.
static bool paged_decode_params_valid(const char* launcher, int batch,
                                      int num_q_heads, int num_kv_heads,
                                      int head_dim) {
    if (head_dim <= 0 || head_dim > PAGED_HEAD_DIM_MAX) {
        fprintf(stderr, "%s: head_dim=%d is outside the supported range [1, %d]\n",
                launcher, head_dim, PAGED_HEAD_DIM_MAX);
        return false;
    }
    if (num_kv_heads <= 0 || num_q_heads % num_kv_heads != 0) {
        fprintf(stderr,
                "%s: num_q_heads=%d must be a positive multiple of num_kv_heads=%d\n",
                launcher, num_q_heads, num_kv_heads);
        return false;
    }
    if (batch > 65535) {
        fprintf(stderr, "%s: batch=%d exceeds the gridDim.y limit of 65535\n",
                launcher, batch);
        return false;
    }
    return true;
}

extern "C" {

// Launches paged_decode_attention_bf16_kernel on `stream` (a cudaStream_t
// passed as void*): grid = (num_q_heads, batch), PAGED_THREADS threads/block.
// An empty batch (or zero heads) is a no-op; a zero-sized grid would
// otherwise be rejected as an invalid launch configuration.
void launch_paged_decode_attention_bf16(
    const void* Q,
    const void* K_cache,
    const void* V_cache,
    void* O,
    const int* block_tables,
    const int* context_lens,
    int batch, int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    float scale, void* stream
) {
    if (batch <= 0 || num_q_heads <= 0) return;
    if (!paged_decode_params_valid("launch_paged_decode_attention_bf16",
                                   batch, num_q_heads, num_kv_heads, head_dim))
        return;

    dim3 grid(num_q_heads, batch);
    int block = PAGED_THREADS;

    paged_decode_attention_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K_cache,
        (const __nv_bfloat16*)V_cache,
        (__nv_bfloat16*)O,
        block_tables, context_lens,
        num_q_heads, num_kv_heads,
        head_dim, max_blocks_per_seq,
        scale
    );
    CUDA_CHECK_LAST_ERROR();
}

// Launches paged_decode_attention_sinks_bf16_kernel on `stream`.
// `sinks` may be NULL; window_size <= 0 attends over the full context.
// Same grid/block shape and empty-batch handling as the launcher above.
void launch_paged_decode_attention_sinks_bf16(
    const void* Q,
    const void* K_cache,
    const void* V_cache,
    void* O,
    const int* block_tables,
    const int* context_lens,
    const void* sinks,
    int batch, int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    float scale, int window_size, void* stream
) {
    if (batch <= 0 || num_q_heads <= 0) return;
    if (!paged_decode_params_valid("launch_paged_decode_attention_sinks_bf16",
                                   batch, num_q_heads, num_kv_heads, head_dim))
        return;

    dim3 grid(num_q_heads, batch);
    int block = PAGED_THREADS;

    paged_decode_attention_sinks_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K_cache,
        (const __nv_bfloat16*)V_cache,
        (__nv_bfloat16*)O,
        block_tables, context_lens,
        (const __nv_bfloat16*)sinks,
        num_q_heads, num_kv_heads,
        head_dim, max_blocks_per_seq,
        scale, window_size
    );
    CUDA_CHECK_LAST_ERROR();
}

}
|