kernels: fix NaN in flash-attention sinks on fully-masked window tiles
flash_attention_sinks_bf16_kernel skipped only fully-future KV tiles (the causal `continue`); an early tile entirely outside the sliding window was still processed with every key masked to -inf, so row_max == -INFINITY. Folding that into the online softmax computed expf(-inf - (-inf)) = NaN, and the next valid tile's 0*NaN correction then poisoned the whole row. Result: the gpt-oss prefill produced all-NaN logits for any query whose sliding window (128) starts past the first KV tile — i.e. at longer context — collapsing generation into a single repeated token (argmax of all-NaN logits: vocab_size-1 in bench, token 0 "!" in the chat). This was the residual multi-turn/long-context collapse. Fix: skip a fully-masked tile (row_max == -INFINITY) — it contributes nothing to the softmax. The decode kernel already guards local_max == -INFINITY, so it was unaffected. Verified on dash5 (TP=2): the prefill that previously went all-NaN now produces clean logits; multi-turn gpt-oss chat (e.g. a haiku after a long prior answer) completes correctly instead of emitting "!!!!". Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -306,16 +306,27 @@ __global__ void flash_attention_sinks_bf16_kernel(
|
||||
row_max = fmaxf(row_max, s);
|
||||
}
|
||||
|
||||
float m_new = fmaxf(m_val, row_max);
|
||||
float psum = 0.0f;
|
||||
for (int c = 0; c < kv_tile_cols; c++) {
|
||||
P[c] = expf(P[c] - m_new);
|
||||
psum += P[c];
|
||||
// A fully-masked KV tile (every key causal- or window-masked) has
|
||||
// row_max == -INFINITY. Folding it in computes expf(-inf - (-inf))
|
||||
// = NaN, and a later valid tile's 0*NaN correction then poisons the
|
||||
// whole row. This happens for sliding-window layers whenever a
|
||||
// query's window starts past an early tile (the causal `continue`
|
||||
// above only skips fully-future tiles, not out-of-window ones).
|
||||
// A masked tile contributes nothing to the softmax — skip it.
|
||||
if (row_max != -INFINITY) {
|
||||
float m_new = fmaxf(m_val, row_max);
|
||||
float psum = 0.0f;
|
||||
for (int c = 0; c < kv_tile_cols; c++) {
|
||||
P[c] = expf(P[c] - m_new);
|
||||
psum += P[c];
|
||||
}
|
||||
float correction = expf(m_val - m_new);
|
||||
l_val = correction * l_val + psum;
|
||||
for (int d = 0; d < head_dim; d++) O_acc[d] *= correction;
|
||||
m_val = m_new;
|
||||
} else {
|
||||
for (int c = 0; c < kv_tile_cols; c++) P[c] = 0.0f;
|
||||
}
|
||||
float correction = expf(m_val - m_new);
|
||||
l_val = correction * l_val + psum;
|
||||
for (int d = 0; d < head_dim; d++) O_acc[d] *= correction;
|
||||
m_val = m_new;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
Reference in New Issue
Block a user