phase19: MoE support — gpt-oss-20b end-to-end inference with TP=2

Add Mixture-of-Experts support for the gpt-oss-20b model (20.9B params,
32 experts × top-4 routing). Key additions:

- ModelConfig: MoE fields (num_local_experts, layer_types, sliding_window,
  attention_bias, explicit head_dim, rope_scaling, swiglu_limit)
- YaRN RoPE: RopeCache::new_yarn() with correct frequency interpolation
  and attention_scaling = 0.1*ln(factor)+1
- Custom GLU kernel: gpt_oss_glu_bf16 (clamped sigmoid gate activation)
- Paged attention with sinks + sliding window kernel variant
- GptOss model struct with expert-parallel TP (split 32 experts across ranks)
- bench-gpt-oss binary for TP inference benchmarking

Verified on dash5 with 2x RTX 5090: 63.6 tok/s decode, ~160ms TTFT.
Model generates topically-coherent output (needs chat template for quality).

Known issues:
- Custom GEMV kernel produces NaN with small N (workaround: pad to M=2)
- Prefill doesn't use attention sinks (uses standard flash attention)
- Output quality requires chat template formatting

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Gahow Wang
2026-05-30 15:18:01 +08:00
parent 46bfb59f30
commit 9ad91a4a92
12 changed files with 1390 additions and 44 deletions

View File

@@ -58,6 +58,25 @@ __global__ void silu_mul_bf16_kernel(const __nv_bfloat16* gate, const __nv_bfloa
}
}
// gpt-oss GLU: gate_up is [N, 2*D] with interleaved columns (gate=even, up=odd).
// gate = gate_up[::2].clamp(max=limit)
// up = gate_up[1::2].clamp(-limit, limit)
// glu = gate * sigmoid(gate * alpha)
// out = (up + 1) * glu
// Output: [N, D]
__global__ void gpt_oss_glu_bf16_kernel(const __nv_bfloat16* gate_up, __nv_bfloat16* out,
int n_elements, float alpha, float limit) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n_elements) {
float g = __bfloat162float(gate_up[idx * 2]);
float u = __bfloat162float(gate_up[idx * 2 + 1]);
g = fminf(g, limit);
u = fmaxf(fminf(u, limit), -limit);
float glu = g / (1.0f + expf(-g * alpha));
out[idx] = __float2bfloat16((u + 1.0f) * glu);
}
}
// Element-wise add: out = a + b
__global__ void add_f32_kernel(const float* a, const float* b, float* out, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -163,4 +182,13 @@ void launch_silu_mul_bf16(const void* gate, const void* up, void* out, int n, vo
CUDA_CHECK_LAST_ERROR();
}
void launch_gpt_oss_glu_bf16(const void* gate_up, void* out, int n_elements,
float alpha, float limit, void* stream) {
int block = 256;
int grid = (n_elements + block - 1) / block;
gpt_oss_glu_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)gate_up, (__nv_bfloat16*)out, n_elements, alpha, limit);
CUDA_CHECK_LAST_ERROR();
}
}

View File

@@ -183,6 +183,173 @@ __global__ void paged_decode_attention_bf16_kernel(
}
}
// Extended paged decode attention with attention sinks and sliding window.
// sinks: [num_q_heads] BF16 — per-head extra logit appended before softmax.
// window_size: >0 = sliding window (only attend to last `window_size` positions), 0 = full.
__global__ void paged_decode_attention_sinks_bf16_kernel(
const __nv_bfloat16* __restrict__ Q,
const __nv_bfloat16* __restrict__ K_cache,
const __nv_bfloat16* __restrict__ V_cache,
__nv_bfloat16* __restrict__ O,
const int* __restrict__ block_tables,
const int* __restrict__ context_lens,
const __nv_bfloat16* __restrict__ sinks, // [num_q_heads] or NULL
int num_q_heads, int num_kv_heads,
int head_dim, int max_blocks_per_seq,
float scale, int window_size
) {
int seq_idx = blockIdx.y;
int q_head = blockIdx.x;
int tid = threadIdx.x;
int kv_len = context_lens[seq_idx];
if (kv_len <= 0) {
if (tid < head_dim) {
O[((long long)seq_idx * num_q_heads + q_head) * head_dim + tid] =
__float2bfloat16(0.0f);
}
return;
}
int heads_per_group = num_q_heads / num_kv_heads;
int kv_head = q_head / heads_per_group;
const __nv_bfloat16* Q_ptr = Q +
((long long)seq_idx * num_q_heads + q_head) * head_dim;
__nv_bfloat16* O_ptr = O +
((long long)seq_idx * num_q_heads + q_head) * head_dim;
const int* bt = block_tables + (long long)seq_idx * max_blocks_per_seq;
// Sliding window: only attend to positions [kv_len - window_size, kv_len)
int start_pos = 0;
if (window_size > 0 && kv_len > window_size) {
start_pos = kv_len - window_size;
}
float q_reg[PAGED_HEAD_DIM_MAX];
for (int d = 0; d < head_dim; d++) {
q_reg[d] = __bfloat162float(Q_ptr[d]);
}
float local_max = -INFINITY;
float local_sum = 0.0f;
float local_O[PAGED_HEAD_DIM_MAX];
for (int d = 0; d < head_dim; d++) local_O[d] = 0.0f;
int kv_stride_block = num_kv_heads * PAGED_BLOCK_SIZE * head_dim;
int kv_stride_head = PAGED_BLOCK_SIZE * head_dim;
int attend_len = kv_len - start_pos;
for (int rel = tid; rel < attend_len; rel += PAGED_THREADS) {
int pos = start_pos + rel;
int logical_blk = pos / PAGED_BLOCK_SIZE;
int slot_in_blk = pos % PAGED_BLOCK_SIZE;
int phys_blk = bt[logical_blk];
const __nv_bfloat16* K_pos = K_cache
+ (long long)phys_blk * kv_stride_block
+ kv_head * kv_stride_head
+ slot_in_blk * head_dim;
const __nv_bfloat16* V_pos = V_cache
+ (long long)phys_blk * kv_stride_block
+ kv_head * kv_stride_head
+ slot_in_blk * head_dim;
float dot = 0.0f;
for (int d = 0; d < head_dim; d++) {
dot += q_reg[d] * __bfloat162float(K_pos[d]);
}
float s = dot * scale;
float new_max = fmaxf(local_max, s);
float correction = expf(local_max - new_max);
float p = expf(s - new_max);
local_sum = local_sum * correction + p;
for (int d = 0; d < head_dim; d++) local_O[d] *= correction;
for (int d = 0; d < head_dim; d++) {
local_O[d] += p * __bfloat162float(V_pos[d]);
}
local_max = new_max;
}
// Include the sink logit (only thread 0 handles it to avoid double-counting)
float sink_logit = -INFINITY;
if (sinks != nullptr && tid == 0) {
sink_logit = __bfloat162float(sinks[q_head]);
float new_max = fmaxf(local_max, sink_logit);
float correction = expf(local_max - new_max);
float p = expf(sink_logit - new_max);
local_sum = local_sum * correction + p;
for (int d = 0; d < head_dim; d++) local_O[d] *= correction;
// Sink absorbs probability but produces no value output (p * 0)
local_max = new_max;
}
// ---- Block-level online softmax reduction (same as base kernel) ----
__shared__ float smem_max[32];
__shared__ float smem_sum[32];
__shared__ float smem_O[PAGED_HEAD_DIM_MAX];
int lane = tid & 31;
int warp_id = tid >> 5;
int num_warps = PAGED_THREADS >> 5;
float warp_max = local_max;
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
warp_max = fmaxf(warp_max, __shfl_down_sync(0xffffffff, warp_max, offset));
if (lane == 0) smem_max[warp_id] = warp_max;
__syncthreads();
float global_max;
if (tid == 0) {
global_max = smem_max[0];
for (int i = 1; i < num_warps; i++)
global_max = fmaxf(global_max, smem_max[i]);
smem_max[0] = global_max;
}
__syncthreads();
global_max = smem_max[0];
float rescale = (local_max == -INFINITY) ? 0.0f : expf(local_max - global_max);
local_sum *= rescale;
for (int d = 0; d < head_dim; d++) local_O[d] *= rescale;
float warp_sum = local_sum;
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset);
if (lane == 0) smem_sum[warp_id] = warp_sum;
__syncthreads();
float global_sum;
if (tid == 0) {
global_sum = 0.0f;
for (int i = 0; i < num_warps; i++) global_sum += smem_sum[i];
smem_sum[0] = global_sum;
}
__syncthreads();
global_sum = smem_sum[0];
for (int d = tid; d < head_dim; d += PAGED_THREADS) smem_O[d] = 0.0f;
__syncthreads();
for (int d = 0; d < head_dim; d++) {
float val = local_O[d];
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
val += __shfl_down_sync(0xffffffff, val, offset);
if (lane == 0) atomicAdd(&smem_O[d], val);
}
__syncthreads();
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
for (int d = tid; d < head_dim; d += PAGED_THREADS) {
O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
}
}
extern "C" {
void launch_paged_decode_attention_bf16(
@@ -212,4 +379,33 @@ void launch_paged_decode_attention_bf16(
CUDA_CHECK_LAST_ERROR();
}
void launch_paged_decode_attention_sinks_bf16(
const void* Q,
const void* K_cache,
const void* V_cache,
void* O,
const int* block_tables,
const int* context_lens,
const void* sinks,
int batch, int num_q_heads, int num_kv_heads,
int head_dim, int max_blocks_per_seq,
float scale, int window_size, void* stream
) {
dim3 grid(num_q_heads, batch);
int block = PAGED_THREADS;
paged_decode_attention_sinks_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)Q,
(const __nv_bfloat16*)K_cache,
(const __nv_bfloat16*)V_cache,
(__nv_bfloat16*)O,
block_tables, context_lens,
(const __nv_bfloat16*)sinks,
num_q_heads, num_kv_heads,
head_dim, max_blocks_per_seq,
scale, window_size
);
CUDA_CHECK_LAST_ERROR();
}
}