kernels: fix NaN in flash-attention sinks on fully-masked window tiles

flash_attention_sinks_bf16_kernel skipped only fully-future KV tiles (the
causal `continue`); an early tile entirely outside the sliding window was
still processed with every key masked to -inf, so row_max == -INFINITY.
Folding that into the online softmax computed expf(-inf - (-inf)) = NaN,
and the next valid tile's 0*NaN correction then poisoned the whole row.

Result: the gpt-oss prefill produced all-NaN logits for any query whose
sliding window (128) starts past the first KV tile — i.e. at longer
context — collapsing generation into a single repeated token (argmax of
all-NaN logits: vocab_size-1 in bench, token 0 "!" in the chat). This was
the residual multi-turn/long-context collapse.

Fix: skip a fully-masked tile (row_max == -INFINITY) — it contributes
nothing to the softmax. The decode kernel already guards
local_max == -INFINITY, so it was unaffected.

Verified on dash5 (TP=2): the prefill that previously went all-NaN now
produces clean logits; multi-turn gpt-oss chat (e.g. a haiku after a long
prior answer) completes correctly instead of emitting "!!!!".

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-02 16:09:43 +08:00
parent ea5d8ba7ea
commit 5157b2cd30

View File

@@ -306,16 +306,27 @@ __global__ void flash_attention_sinks_bf16_kernel(
row_max = fmaxf(row_max, s); row_max = fmaxf(row_max, s);
} }
float m_new = fmaxf(m_val, row_max); // A fully-masked KV tile (every key causal- or window-masked) has
float psum = 0.0f; // row_max == -INFINITY. Folding it in computes expf(-inf - (-inf))
for (int c = 0; c < kv_tile_cols; c++) { // = NaN, and a later valid tile's 0*NaN correction then poisons the
P[c] = expf(P[c] - m_new); // whole row. This happens for sliding-window layers whenever a
psum += P[c]; // query's window starts past an early tile (the causal `continue`
// above only skips fully-future tiles, not out-of-window ones).
// A masked tile contributes nothing to the softmax — skip it.
if (row_max != -INFINITY) {
float m_new = fmaxf(m_val, row_max);
float psum = 0.0f;
for (int c = 0; c < kv_tile_cols; c++) {
P[c] = expf(P[c] - m_new);
psum += P[c];
}
float correction = expf(m_val - m_new);
l_val = correction * l_val + psum;
for (int d = 0; d < head_dim; d++) O_acc[d] *= correction;
m_val = m_new;
} else {
for (int c = 0; c < kv_tile_cols; c++) P[c] = 0.0f;
} }
float correction = expf(m_val - m_new);
l_val = correction * l_val + psum;
for (int d = 0; d < head_dim; d++) O_acc[d] *= correction;
m_val = m_new;
} }
__syncthreads(); __syncthreads();