gpt-oss: drop debug syncs from forward; GPU broadcast bias-add

Decode carried three leftover cudaDeviceSynchronize (prefill one) from
NaN debugging — the Qwen3 path has none and the logits D2H in sample()
already orders against the null stream.

add_bias for rows>1 round-tripped the bias through the CPU (D2H + host
tile loop + H2D) on every call — 96 times per prefill across q/k/v/o.
Replace with a bias_add_2d broadcast kernel.

dash5, FP8 TP=2, warm-server: TTFT 35/49/94 -> 29/42/79 ms
(short/medium/long), TPOT 7.19-7.32 -> 6.99-7.21 ms. Greedy tokens
unchanged; GSM8K-50 94%.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
2026-06-12 17:02:59 +08:00
parent 63f5599717
commit 1897b2e17a
4 changed files with 50 additions and 32 deletions

View File

@@ -87,6 +87,17 @@ __global__ void add_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b,
if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(a[idx]) + __bfloat162float(b[idx]));
}
// Row-broadcast bias add: out[r, c] = x[r, c] + bias[c]
__global__ void bias_add_2d_bf16_kernel(
const __nv_bfloat16* __restrict__ x, const __nv_bfloat16* __restrict__ bias,
__nv_bfloat16* __restrict__ out, int rows, int cols
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= rows * cols) return;
float v = __bfloat162float(x[idx]) + __bfloat162float(bias[idx % cols]);
out[idx] = __float2bfloat16(v);
}
// Element-wise mul: out = a * b
__global__ void mul_f32_kernel(const float* a, const float* b, float* out, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -159,6 +170,14 @@ void launch_add_bf16(const void* a, const void* b, void* out, int n, void* strea
(const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
CUDA_CHECK_LAST_ERROR();
}
void launch_bias_add_2d_bf16(const void* x, const void* bias, void* out, int rows, int cols, void* stream) {
int n = rows * cols;
int block = 256;
int grid = (n + block - 1) / block;
bias_add_2d_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (const __nv_bfloat16*)bias, (__nv_bfloat16*)out, rows, cols);
CUDA_CHECK_LAST_ERROR();
}
void launch_mul_f32(const void* a, const void* b, void* out, int n, void* stream) {
int block = 256;
int grid = (n + block - 1) / block;