kernels: fix uninitialized shared-memory read in M=1 decode GEMV

gemv_bf16_fused_kernel returned early on out-of-range columns
(`if (col >= N) return;`) BEFORE the cooperative load of x into shared
memory and the `__syncthreads()`. When N is not a multiple of GEMV_TILE_N
(128), the last column-block's out-of-range threads exited without loading
their slice of x_shared, so the in-range threads then read uninitialized
shared memory in the dot product — and __syncthreads with exited threads is
itself UB. Result: intermittent huge/garbage outputs (~1e33) that, after
the next RMSNorm, collapsed the whole forward pass to a degenerate logit
distribution (argmax → vocab_size-1, or NaN), derailing generation.

This hit every M=1 BF16 GEMV (n>=256) with n % 128 != 0 — i.e. gpt-oss
decode o_proj and the MoE projections (n=2880). q/k/v (4096) and lm_head
(201088) are 128-aligned and were unaffected, as is Qwen3 (hidden 4096),
which is why this manifested as intermittent gpt-oss-only decode failures.

Fix: all threads participate in the shared-memory load and reach the
barrier; the col>=N check moves to AFTER __syncthreads.

Verified on dash5 (TP=2): a prompt that reliably produced garbage ~70% of
runs now yields clean logits 16/16; the multi-turn Chinese chat that
collapsed mid-conversation completes coherently with 0 NaN warnings.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-02 17:18:37 +08:00
parent 5157b2cd30
commit 3b9e32e6cd

View File

@@ -28,18 +28,25 @@ __global__ void gemv_bf16_fused_kernel(
const int t = threadIdx.x;
const int col = block_n * GEMV_TILE_N + t;
if (col >= N) return;
const int k_start = block_k * GEMV_TILE_K;
const int k_end = min(k_start + GEMV_TILE_K, K);
const int k_len = k_end - k_start;
// Cooperative load of x into shared memory uses ALL threads in the block
// (indexed by t, independent of col). Threads whose column is out of range
// must still help load and reach the barrier — returning early here would
// leave part of x_shared uninitialized AND make __syncthreads divergent
// (UB). So the col>=N check happens only AFTER the load + barrier. This bug
// produced intermittent huge/garbage outputs whenever N % GEMV_TILE_N != 0
// (e.g. gpt-oss decode o_proj with N=2880), collapsing the forward pass.
__shared__ float x_shared[GEMV_TILE_K];
for (int i = t; i < k_len; i += GEMV_BLOCK) {
x_shared[i] = __bfloat162float(x[k_start + i]);
}
__syncthreads();
if (col >= N) return;
float sum = 0.0f;
for (int ki = 0; ki < k_len; ki++) {
sum += x_shared[ki] * __bfloat162float(W[(long long)(k_start + ki) * N + col]);