Three CUDA bugs from the review after 5b350ee/cfbd64d that were missed by those commits:
- flash_attention.cu: decode_attention_bf16_kernel used atomicAdd to merge per-warp partials into smem_O — the same nondeterminism pattern that 5b350ee already fixed in paged_attention.cu and gemv.cu. This kernel is on the legacy forward_gpu_cache path plus the speculative bench baseline, so verify/decode parity depended on it. Replace with smem_O_warp[32][HEAD_DIM_MAX] partials reduced in fixed warp-id order.
- causal_mask.cu computed the flat address as `batch_idx * rows * cols + row * cols + col` in int; batch=128 heads=28 seq=32768 already overflows int32. Promote the index to long long.
- quantization/dequant_fp8.cu had `int total = num_experts * rows * cols` and `int expert_stride = rows * cols`; 32 experts × 8k × 8k overflows. Same fix pattern as the MoE dense kernels in cfbd64d — 64-bit total / idx / expert_stride, and grid computed in long long.
60 lines
1.9 KiB
Plaintext
60 lines
1.9 KiB
Plaintext
#include <cuda_bf16.h>
|
||
#include "../common.cuh"
|
||
|
||
// Causal masking for float scores: every element with col > row + offset is
// overwritten with -inf; all other elements are left untouched.
//
// `offset` supports KV-cache decoding: when the query block starts at absolute
// position `offset`, query row `row` may attend to key positions
// [0, offset + row].
//
// scores: [batch, rows, cols], where batch is the flattened batch×heads axis.
//
// Launch layout: one thread per column (x dimension), blockIdx.y selects the
// row and blockIdx.z selects the batch slice.
__global__ void causal_mask_f32(
    float* __restrict__ scores,
    int rows, int cols, int offset
) {
    const int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (c >= cols) {
        return;  // tail threads of the last block in x
    }

    const int r = blockIdx.y;
    if (c <= r + offset) {
        return;  // position is visible to this query row; keep the score
    }

    // Row base is computed in 64-bit: batch * rows * cols exceeds int32 at
    // moderate batch and long context (e.g. batch=128 * heads=28 * seq=32768).
    const long long slice = blockIdx.z;
    float* row_ptr = scores + (slice * rows + r) * cols;
    row_ptr[c] = -INFINITY;
}
|
||
|
||
// bfloat16 causal masking: every element with col > row + offset is
// overwritten with -inf (converted to bf16); all other elements are left
// untouched.
//
// Launch layout: one thread per column (x dimension), blockIdx.y selects the
// row and blockIdx.z selects the batch slice of a [batch, rows, cols] buffer.
__global__ void causal_mask_bf16(
    __nv_bfloat16* __restrict__ scores,
    int rows, int cols, int offset
) {
    const int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (c >= cols) {
        return;  // tail threads of the last block in x
    }

    const int r = blockIdx.y;
    if (c <= r + offset) {
        return;  // position is visible to this query row; keep the score
    }

    // 64-bit row base so that batch * rows * cols cannot wrap an int32.
    const long long slice = blockIdx.z;
    __nv_bfloat16* row_ptr = scores + (slice * rows + r) * cols;
    row_ptr[c] = __float2bfloat16(-INFINITY);
}
|
||
|
||
extern "C" {
|
||
|
||
// Host launcher for causal_mask_f32.
//
//   scores : device pointer to float scores laid out as [batch, rows, cols]
//   batch, rows, cols : tensor extents; if any of them is <= 0 the call is a
//                       no-op
//   offset : mask offset forwarded to the kernel (col > row + offset is masked)
//   stream : cudaStream_t passed as an opaque pointer
//
// The launch is asynchronous on `stream`; launch errors are reported through
// CUDA_CHECK_LAST_ERROR().
void launch_causal_mask_f32(void* scores, int batch, int rows, int cols,
                            int offset, void* stream) {
    // An empty tensor has nothing to mask, and a zero-sized grid dimension
    // would be rejected as an invalid launch configuration.
    if (batch <= 0 || rows <= 0 || cols <= 0) {
        return;
    }

    const int block = 256;
    // gridDim.y and gridDim.z are limited to 65535, so rows / batch beyond
    // that are covered by several launches over tiles of the tensor.
    const int max_grid_yz = 65535;
    // ceil(cols / block) in 64-bit so cols close to INT_MAX cannot wrap.
    const unsigned int grid_x =
        (unsigned int)(((long long)cols + block - 1) / block);

    float* base = (float*)scores;
    cudaStream_t cuda_stream = (cudaStream_t)stream;

    int row0 = 0;
    while (row0 < rows) {
        const int rows_left = rows - row0;
        const int row_tile = rows_left < max_grid_yz ? rows_left : max_grid_yz;

        int batch0 = 0;
        while (batch0 < batch) {
            const int batch_left = batch - batch0;
            const int batch_tile =
                batch_left < max_grid_yz ? batch_left : max_grid_yz;

            // The kernel addresses (z * rows + y) * cols + x relative to the
            // pointer it receives, so shifting the pointer to the tile origin
            // and adding row0 to the offset reproduces the global indexing.
            float* tile = base + ((long long)batch0 * rows + row0) * cols;
            dim3 grid(grid_x, row_tile, batch_tile);
            causal_mask_f32<<<grid, block, 0, cuda_stream>>>(
                tile, rows, cols, offset + row0);
            CUDA_CHECK_LAST_ERROR();

            batch0 += batch_tile;
        }
        row0 += row_tile;
    }
}
|
||
|
||
// Host launcher for causal_mask_bf16.
//
//   scores : device pointer to __nv_bfloat16 scores laid out as
//            [batch, rows, cols]
//   batch, rows, cols : tensor extents; if any of them is <= 0 the call is a
//                       no-op
//   offset : mask offset forwarded to the kernel (col > row + offset is masked)
//   stream : cudaStream_t passed as an opaque pointer
//
// The launch is asynchronous on `stream`; launch errors are reported through
// CUDA_CHECK_LAST_ERROR().
void launch_causal_mask_bf16(void* scores, int batch, int rows, int cols,
                             int offset, void* stream) {
    // An empty tensor has nothing to mask, and a zero-sized grid dimension
    // would be rejected as an invalid launch configuration.
    if (batch <= 0 || rows <= 0 || cols <= 0) {
        return;
    }

    const int block = 256;
    // gridDim.y and gridDim.z are limited to 65535, so rows / batch beyond
    // that are covered by several launches over tiles of the tensor.
    const int max_grid_yz = 65535;
    // ceil(cols / block) in 64-bit so cols close to INT_MAX cannot wrap.
    const unsigned int grid_x =
        (unsigned int)(((long long)cols + block - 1) / block);

    __nv_bfloat16* base = (__nv_bfloat16*)scores;
    cudaStream_t cuda_stream = (cudaStream_t)stream;

    int row0 = 0;
    while (row0 < rows) {
        const int rows_left = rows - row0;
        const int row_tile = rows_left < max_grid_yz ? rows_left : max_grid_yz;

        int batch0 = 0;
        while (batch0 < batch) {
            const int batch_left = batch - batch0;
            const int batch_tile =
                batch_left < max_grid_yz ? batch_left : max_grid_yz;

            // The kernel addresses (z * rows + y) * cols + x relative to the
            // pointer it receives, so shifting the pointer to the tile origin
            // and adding row0 to the offset reproduces the global indexing.
            __nv_bfloat16* tile =
                base + ((long long)batch0 * rows + row0) * cols;
            dim3 grid(grid_x, row_tile, batch_tile);
            causal_mask_bf16<<<grid, block, 0, cuda_stream>>>(
                tile, rows, cols, offset + row0);
            CUDA_CHECK_LAST_ERROR();

            batch0 += batch_tile;
        }
        row0 += row_tile;
    }
}
|
||
|
||
}
|