kernels: fix uninitialized shared-memory read in M=1 decode GEMV
gemv_bf16_fused_kernel returned early on out-of-range columns (`if (col >= N) return;`) BEFORE the cooperative load of x into shared memory and the `__syncthreads()`. When N is not a multiple of GEMV_TILE_N (128), the last column-block's out-of-range threads exited without loading their slice of x_shared, so the in-range threads then read uninitialized shared memory in the dot product — and __syncthreads with exited threads is itself UB. Result: intermittent huge/garbage outputs (~1e33) that, after the next RMSNorm, collapsed the whole forward pass to a degenerate logit distribution (argmax → vocab_size-1, or NaN), derailing generation. This hit every M=1 BF16 GEMV (n>=256) with n % 128 != 0 — i.e. gpt-oss decode o_proj and the MoE projections (n=2880). q/k/v (4096) and lm_head (201088) are 128-aligned and were unaffected, as is Qwen3 (hidden 4096), which is why this manifested as intermittent gpt-oss-only decode failures. Fix: all threads participate in the shared-memory load and reach the barrier; the col>=N check moves to AFTER __syncthreads. Verified on dash5 (TP=2): a prompt that reliably produced garbage ~70% of runs now yields clean logits 16/16; the multi-turn Chinese chat that collapsed mid-conversation completes coherently with 0 NaN warnings. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -28,18 +28,25 @@ __global__ void gemv_bf16_fused_kernel(
|
||||
const int t = threadIdx.x;
|
||||
const int col = block_n * GEMV_TILE_N + t;
|
||||
|
||||
if (col >= N) return;
|
||||
|
||||
const int k_start = block_k * GEMV_TILE_K;
|
||||
const int k_end = min(k_start + GEMV_TILE_K, K);
|
||||
const int k_len = k_end - k_start;
|
||||
|
||||
// Cooperative load of x into shared memory uses ALL threads in the block
|
||||
// (indexed by t, independent of col). Threads whose column is out of range
|
||||
// must still help load and reach the barrier — returning early here would
|
||||
// leave part of x_shared uninitialized AND make __syncthreads divergent
|
||||
// (UB). So the col>=N check happens only AFTER the load + barrier. This bug
|
||||
// produced intermittent huge/garbage outputs whenever N % GEMV_TILE_N != 0
|
||||
// (e.g. gpt-oss decode o_proj with N=2880), collapsing the forward pass.
|
||||
__shared__ float x_shared[GEMV_TILE_K];
|
||||
for (int i = t; i < k_len; i += GEMV_BLOCK) {
|
||||
x_shared[i] = __bfloat162float(x[k_start + i]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (col >= N) return;
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int ki = 0; ki < k_len; ki++) {
|
||||
sum += x_shared[ki] * __bfloat162float(W[(long long)(k_start + ki) * N + col]);
|
||||
|
||||
Reference in New Issue
Block a user