Add Mixture-of-Experts support for the gpt-oss-20b model (20.9B params, 32 experts × top-4 routing).

Key additions:
- ModelConfig: MoE fields (num_local_experts, layer_types, sliding_window, attention_bias, explicit head_dim, rope_scaling, swiglu_limit)
- YaRN RoPE: RopeCache::new_yarn() with correct frequency interpolation and attention_scaling = 0.1*ln(factor)+1
- Custom GLU kernel: gpt_oss_glu_bf16 (clamped sigmoid gate activation)
- Paged attention with sinks + sliding window kernel variant
- GptOss model struct with expert-parallel TP (split 32 experts across ranks)
- bench-gpt-oss binary for TP inference benchmarking

Verified on dash5 with 2x RTX 5090: 63.6 tok/s decode, ~160ms TTFT. The model generates topically coherent output (needs chat template for quality).

Known issues:
- Custom GEMV kernel produces NaN with small N (workaround: pad to M=2)
- Prefill doesn't use attention sinks (uses standard flash attention)
- Output quality requires chat template formatting

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
195 lines
7.4 KiB
Plaintext
195 lines
7.4 KiB
Plaintext
#include <cuda_bf16.h>
|
||
#include <math.h>
|
||
#include "../common.cuh"
|
||
|
||
// Tanh-approximation GELU, evaluated in fp32:
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// Device-only helper shared by the f32 and bf16 GELU kernels in this file.
__device__ __forceinline__ float gelu_f(float x) {
    const float kSqrt2OverPi = 0.7978845608f;  // sqrt(2 / pi)
    const float x3 = x * x * x;
    const float t = tanhf(kSqrt2OverPi * (x + 0.044715f * x3));
    return 0.5f * x * (1.0f + t);
}
|
||
|
||
// SiLU / Swish activation in fp32: silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
// Device-only helper shared by the f32 and bf16 SiLU kernels in this file.
__device__ __forceinline__ float silu_f(float x) {
    const float denom = 1.0f + expf(-x);
    return x / denom;
}
|
||
|
||
// Element-wise GELU over n floats: out[i] = gelu_f(x[i]).
// One thread per element on a 1D grid; threads past the end do nothing.
__global__ void gelu_f32(const float* x, float* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    out[i] = gelu_f(x[i]);
}
|
||
|
||
// Element-wise GELU over n bf16 values. Each element is widened to fp32,
// passed through gelu_f, and converted back to bf16 for the store.
// One thread per element on a 1D grid; threads past the end do nothing.
__global__ void gelu_bf16(const __nv_bfloat16* x, __nv_bfloat16* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    const float v = __bfloat162float(x[i]);
    out[i] = __float2bfloat16(gelu_f(v));
}
|
||
|
||
// Element-wise SiLU over n floats: out[i] = silu_f(x[i]).
// One thread per element on a 1D grid; threads past the end do nothing.
__global__ void silu_f32(const float* x, float* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    out[i] = silu_f(x[i]);
}
|
||
|
||
// Element-wise SiLU over n bf16 values. Each element is widened to fp32,
// passed through silu_f, and converted back to bf16 for the store.
// One thread per element on a 1D grid; threads past the end do nothing.
__global__ void silu_bf16(const __nv_bfloat16* x, __nv_bfloat16* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    const float v = __bfloat162float(x[i]);
    out[i] = __float2bfloat16(silu_f(v));
}
|
||
|
||
// Element-wise scaling of n floats: out[i] = x[i] * scale.
// One thread per element on a 1D grid; threads past the end do nothing.
__global__ void scale_f32_kernel(const float* x, float* out, float scale, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    const float v = x[i];
    out[i] = v * scale;
}
|
||
|
||
// Element-wise scaling of n bf16 values by an fp32 factor.
// The multiply is done in fp32 and the product is converted back to bf16.
// One thread per element on a 1D grid; threads past the end do nothing.
__global__ void scale_bf16_kernel(const __nv_bfloat16* x, __nv_bfloat16* out, float scale, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    const float v = __bfloat162float(x[i]);
    out[i] = __float2bfloat16(v * scale);
}
|
||
|
||
// Fused SiLU-and-multiply over n bf16 elements: out[i] = silu(gate[i]) * up[i].
// Both operands are widened to fp32, the SiLU and the product are computed in
// fp32, and only the final result is converted back to bf16.
// One thread per element on a 1D grid; threads past the end do nothing.
__global__ void silu_mul_bf16_kernel(const __nv_bfloat16* gate, const __nv_bfloat16* up,
                                     __nv_bfloat16* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    const float gate_val = __bfloat162float(gate[i]);
    const float up_val = __bfloat162float(up[i]);
    // silu(g) = g * sigmoid(g) = g / (1 + exp(-g))
    const float activated = gate_val / (1.0f + expf(-gate_val));
    out[i] = __float2bfloat16(activated * up_val);
}
|
||
|
||
// gpt-oss GLU: gate_up is [N, 2*D] with interleaved columns (gate=even, up=odd).
//   gate = gate_up[::2].clamp(max=limit)
//   up   = gate_up[1::2].clamp(-limit, limit)
//   glu  = gate * sigmoid(gate * alpha)
//   out  = (up + 1) * glu
// Output: [N, D]
//
// Parameters:
//   gate_up    : bf16 input holding 2 * n_elements values (gate/up pairs)
//   out        : bf16 output holding n_elements values
//   n_elements : number of output elements (one thread per output element)
//   alpha      : multiplier applied to gate inside the sigmoid
//   limit      : clamp bound (upper bound for gate, symmetric bound for up)
//
// Launch on a 1D grid covering n_elements threads; threads past the end do nothing.
__global__ void gpt_oss_glu_bf16_kernel(const __nv_bfloat16* gate_up, __nv_bfloat16* out,
                                        int n_elements, float alpha, float limit) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n_elements) return;

    // Widen before doubling: idx * 2 in int overflows once n_elements exceeds 2^30,
    // which would turn the pair offset into a negative (out-of-bounds) index.
    const size_t pair = (size_t)idx * 2;
    float g = __bfloat162float(gate_up[pair]);
    float u = __bfloat162float(gate_up[pair + 1]);

    // gate is clamped from above only; up is clamped on both sides.
    g = fminf(g, limit);
    u = fmaxf(fminf(u, limit), -limit);

    // gate * sigmoid(gate * alpha), written as g / (1 + exp(-g * alpha)).
    float glu = g / (1.0f + expf(-g * alpha));
    out[idx] = __float2bfloat16((u + 1.0f) * glu);
}
|
||
|
||
// Element-wise add over n floats: out[i] = a[i] + b[i].
// One thread per element on a 1D grid; threads past the end do nothing.
__global__ void add_f32_kernel(const float* a, const float* b, float* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    const float lhs = a[i];
    const float rhs = b[i];
    out[i] = lhs + rhs;
}
|
||
// Element-wise add over n bf16 values: out[i] = a[i] + b[i].
// Operands are widened to fp32, summed, and the sum is converted back to bf16.
// One thread per element on a 1D grid; threads past the end do nothing.
__global__ void add_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    const float lhs = __bfloat162float(a[i]);
    const float rhs = __bfloat162float(b[i]);
    out[i] = __float2bfloat16(lhs + rhs);
}
|
||
|
||
// Element-wise multiply over n floats: out[i] = a[i] * b[i].
// One thread per element on a 1D grid; threads past the end do nothing.
__global__ void mul_f32_kernel(const float* a, const float* b, float* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    const float lhs = a[i];
    const float rhs = b[i];
    out[i] = lhs * rhs;
}
|
||
// Element-wise multiply over n bf16 values: out[i] = a[i] * b[i].
// Operands are widened to fp32, multiplied, and the product is converted back to bf16.
// One thread per element on a 1D grid; threads past the end do nothing.
__global__ void mul_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    const float lhs = __bfloat162float(a[i]);
    const float rhs = __bfloat162float(b[i]);
    out[i] = __float2bfloat16(lhs * rhs);
}
|
||
|
||
extern "C" {
|
||
|
||
// Launches gelu_f32 over n float elements on the given stream.
//   x, out : pointers handed to the kernel as const float* / float*
//            (must be dereferenceable from device code)
//   n      : element count; n <= 0 is a no-op
//   stream : CUDA stream handle, cast to cudaStream_t
void launch_gelu_f32(const void* x, void* out, int n, void* stream) {
    // A zero-sized grid is an invalid launch configuration, so an empty input
    // must return before the launch instead of raising a spurious error.
    if (n <= 0) return;
    const int block = 256;
    // Ceil-div written as (n - 1) / block + 1 so it cannot overflow near INT_MAX.
    const int grid = (n - 1) / block + 1;
    gelu_f32<<<grid, block, 0, (cudaStream_t)stream>>>((const float*)x, (float*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||
|
||
// Launches gelu_bf16 over n bf16 elements on the given stream.
//   x, out : pointers handed to the kernel as const __nv_bfloat16* / __nv_bfloat16*
//            (must be dereferenceable from device code)
//   n      : element count; n <= 0 is a no-op
//   stream : CUDA stream handle, cast to cudaStream_t
void launch_gelu_bf16(const void* x, void* out, int n, void* stream) {
    // A zero-sized grid is an invalid launch configuration, so an empty input
    // must return before the launch instead of raising a spurious error.
    if (n <= 0) return;
    const int block = 256;
    // Ceil-div written as (n - 1) / block + 1 so it cannot overflow near INT_MAX.
    const int grid = (n - 1) / block + 1;
    gelu_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||
|
||
// Launches silu_f32 over n float elements on the given stream.
//   x, out : pointers handed to the kernel as const float* / float*
//            (must be dereferenceable from device code)
//   n      : element count; n <= 0 is a no-op
//   stream : CUDA stream handle, cast to cudaStream_t
void launch_silu_f32(const void* x, void* out, int n, void* stream) {
    // A zero-sized grid is an invalid launch configuration, so an empty input
    // must return before the launch instead of raising a spurious error.
    if (n <= 0) return;
    const int block = 256;
    // Ceil-div written as (n - 1) / block + 1 so it cannot overflow near INT_MAX.
    const int grid = (n - 1) / block + 1;
    silu_f32<<<grid, block, 0, (cudaStream_t)stream>>>((const float*)x, (float*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||
|
||
// Launches silu_bf16 over n bf16 elements on the given stream.
//   x, out : pointers handed to the kernel as const __nv_bfloat16* / __nv_bfloat16*
//            (must be dereferenceable from device code)
//   n      : element count; n <= 0 is a no-op
//   stream : CUDA stream handle, cast to cudaStream_t
void launch_silu_bf16(const void* x, void* out, int n, void* stream) {
    // A zero-sized grid is an invalid launch configuration, so an empty input
    // must return before the launch instead of raising a spurious error.
    if (n <= 0) return;
    const int block = 256;
    // Ceil-div written as (n - 1) / block + 1 so it cannot overflow near INT_MAX.
    const int grid = (n - 1) / block + 1;
    silu_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||
|
||
// Launches scale_f32_kernel (out[i] = x[i] * scale) over n floats on the given stream.
//   x, out : pointers handed to the kernel as const float* / float*
//            (must be dereferenceable from device code)
//   scale  : fp32 multiplier
//   n      : element count; n <= 0 is a no-op
//   stream : CUDA stream handle, cast to cudaStream_t
void launch_scale_f32(const void* x, void* out, float scale, int n, void* stream) {
    // A zero-sized grid is an invalid launch configuration, so an empty input
    // must return before the launch instead of raising a spurious error.
    if (n <= 0) return;
    const int block = 256;
    // Ceil-div written as (n - 1) / block + 1 so it cannot overflow near INT_MAX.
    const int grid = (n - 1) / block + 1;
    scale_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const float*)x, (float*)out, scale, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||
|
||
// Launches scale_bf16_kernel (out[i] = x[i] * scale) over n bf16 values on the given stream.
//   x, out : pointers handed to the kernel as const __nv_bfloat16* / __nv_bfloat16*
//            (must be dereferenceable from device code)
//   scale  : fp32 multiplier
//   n      : element count; n <= 0 is a no-op
//   stream : CUDA stream handle, cast to cudaStream_t
void launch_scale_bf16(const void* x, void* out, float scale, int n, void* stream) {
    // A zero-sized grid is an invalid launch configuration, so an empty input
    // must return before the launch instead of raising a spurious error.
    if (n <= 0) return;
    const int block = 256;
    // Ceil-div written as (n - 1) / block + 1 so it cannot overflow near INT_MAX.
    const int grid = (n - 1) / block + 1;
    scale_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (__nv_bfloat16*)out, scale, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||
|
||
// Launches add_f32_kernel (out[i] = a[i] + b[i]) over n floats on the given stream.
//   a, b, out : pointers handed to the kernel as const float* / const float* / float*
//               (must be dereferenceable from device code)
//   n         : element count; n <= 0 is a no-op
//   stream    : CUDA stream handle, cast to cudaStream_t
void launch_add_f32(const void* a, const void* b, void* out, int n, void* stream) {
    // A zero-sized grid is an invalid launch configuration, so an empty input
    // must return before the launch instead of raising a spurious error.
    if (n <= 0) return;
    const int block = 256;
    // Ceil-div written as (n - 1) / block + 1 so it cannot overflow near INT_MAX.
    const int grid = (n - 1) / block + 1;
    add_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const float*)a, (const float*)b, (float*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||
// Launches add_bf16_kernel (out[i] = a[i] + b[i]) over n bf16 values on the given stream.
//   a, b, out : pointers handed to the kernel as __nv_bfloat16 arrays
//               (must be dereferenceable from device code)
//   n         : element count; n <= 0 is a no-op
//   stream    : CUDA stream handle, cast to cudaStream_t
void launch_add_bf16(const void* a, const void* b, void* out, int n, void* stream) {
    // A zero-sized grid is an invalid launch configuration, so an empty input
    // must return before the launch instead of raising a spurious error.
    if (n <= 0) return;
    const int block = 256;
    // Ceil-div written as (n - 1) / block + 1 so it cannot overflow near INT_MAX.
    const int grid = (n - 1) / block + 1;
    add_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||
// Launches mul_f32_kernel (out[i] = a[i] * b[i]) over n floats on the given stream.
//   a, b, out : pointers handed to the kernel as const float* / const float* / float*
//               (must be dereferenceable from device code)
//   n         : element count; n <= 0 is a no-op
//   stream    : CUDA stream handle, cast to cudaStream_t
void launch_mul_f32(const void* a, const void* b, void* out, int n, void* stream) {
    // A zero-sized grid is an invalid launch configuration, so an empty input
    // must return before the launch instead of raising a spurious error.
    if (n <= 0) return;
    const int block = 256;
    // Ceil-div written as (n - 1) / block + 1 so it cannot overflow near INT_MAX.
    const int grid = (n - 1) / block + 1;
    mul_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const float*)a, (const float*)b, (float*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||
// Launches mul_bf16_kernel (out[i] = a[i] * b[i]) over n bf16 values on the given stream.
//   a, b, out : pointers handed to the kernel as __nv_bfloat16 arrays
//               (must be dereferenceable from device code)
//   n         : element count; n <= 0 is a no-op
//   stream    : CUDA stream handle, cast to cudaStream_t
void launch_mul_bf16(const void* a, const void* b, void* out, int n, void* stream) {
    // A zero-sized grid is an invalid launch configuration, so an empty input
    // must return before the launch instead of raising a spurious error.
    if (n <= 0) return;
    const int block = 256;
    // Ceil-div written as (n - 1) / block + 1 so it cannot overflow near INT_MAX.
    const int grid = (n - 1) / block + 1;
    mul_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||
|
||
// Launches silu_mul_bf16_kernel (out[i] = silu(gate[i]) * up[i]) over n bf16
// values on the given stream.
//   gate, up, out : pointers handed to the kernel as __nv_bfloat16 arrays
//                   (must be dereferenceable from device code)
//   n             : element count; n <= 0 is a no-op
//   stream        : CUDA stream handle, cast to cudaStream_t
void launch_silu_mul_bf16(const void* gate, const void* up, void* out, int n, void* stream) {
    // A zero-sized grid is an invalid launch configuration, so an empty input
    // must return before the launch instead of raising a spurious error.
    if (n <= 0) return;
    const int block = 256;
    // Ceil-div written as (n - 1) / block + 1 so it cannot overflow near INT_MAX.
    const int grid = (n - 1) / block + 1;
    silu_mul_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)gate, (const __nv_bfloat16*)up, (__nv_bfloat16*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||
|
||
// Launches gpt_oss_glu_bf16_kernel on the given stream.
//   gate_up    : pointer handed to the kernel as const __nv_bfloat16*; the kernel
//                reads 2 * n_elements interleaved gate/up values from it
//   out        : pointer handed to the kernel as __nv_bfloat16*; receives
//                n_elements results
//   n_elements : number of output elements; n_elements <= 0 is a no-op
//                (e.g. nothing to process for this call)
//   alpha      : multiplier applied to gate inside the sigmoid
//   limit      : clamp bound forwarded to the kernel
//   stream     : CUDA stream handle, cast to cudaStream_t
void launch_gpt_oss_glu_bf16(const void* gate_up, void* out, int n_elements,
                             float alpha, float limit, void* stream) {
    // A zero-sized grid is an invalid launch configuration, so an empty input
    // must return before the launch instead of raising a spurious error.
    if (n_elements <= 0) return;
    const int block = 256;
    // Ceil-div written as (n - 1) / block + 1 so it cannot overflow near INT_MAX.
    const int grid = (n_elements - 1) / block + 1;
    gpt_oss_glu_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)gate_up, (__nv_bfloat16*)out, n_elements, alpha, limit);
    CUDA_CHECK_LAST_ERROR();
}
|
||
|
||
}
|