cuda: fix remaining int32-address and nondeterministic-reduction bugs
Three CUDA bugs from the review after5b350ee/cfbd64dthat were missed by those commits: - flash_attention.cu decode_attention_bf16_kernel used atomicAdd to merge per-warp partials into smem_O — same nondeterminism pattern that5b350eealready fixed in paged_attention.cu and gemv.cu. This kernel is on the legacy forward_gpu_cache path plus the speculative bench baseline, so verify/decode parity depended on it. Replace with smem_O_warp[32][HEAD_DIM_MAX] partials reduced in fixed warp-id order. - causal_mask.cu computed the flat address as `batch_idx * rows * cols + row * cols + col` in int; batch=128 heads=28 seq=32768 already overflows int32. Promote the index to long long. - quantization/dequant_fp8.cu had `int total = num_experts * rows * cols` and `int expert_stride = rows * cols`; 32 experts × 8k × 8k overflows. Same fix pattern as the MoE dense kernels incfbd64d— 64-bit total / idx / expert_stride, and grid computed in long long.
This commit is contained in:
@@ -15,7 +15,10 @@ __global__ void causal_mask_f32(
|
||||
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (col < cols && col > row + offset) {
|
||||
scores[batch_idx * rows * cols + row * cols + col] = -INFINITY;
|
||||
// 64-bit index: batch * rows * cols overflows int32 at moderate batch
|
||||
// and long context (e.g. batch=128 * heads=28 * seq=32768).
|
||||
long long idx = ((long long)batch_idx * rows + row) * cols + col;
|
||||
scores[idx] = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +31,8 @@ __global__ void causal_mask_bf16(
|
||||
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (col < cols && col > row + offset) {
|
||||
scores[batch_idx * rows * cols + row * cols + col] = __float2bfloat16(-INFINITY);
|
||||
long long idx = ((long long)batch_idx * rows + row) * cols + col;
|
||||
scores[idx] = __float2bfloat16(-INFINITY);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -464,7 +464,7 @@ __global__ void decode_attention_bf16_kernel(
|
||||
// Shared memory for reduction
|
||||
__shared__ float smem_max[32]; // one per warp
|
||||
__shared__ float smem_sum[32];
|
||||
__shared__ float smem_O[HEAD_DIM_MAX]; // final output accumulator
|
||||
__shared__ float smem_O_warp[32][HEAD_DIM_MAX];
|
||||
|
||||
// Step 1: Block-wide max reduction
|
||||
int lane = tid & 31;
|
||||
@@ -513,35 +513,30 @@ __global__ void decode_attention_bf16_kernel(
|
||||
__syncthreads();
|
||||
global_sum = smem_sum[0];
|
||||
|
||||
// Step 4: Reduce O across block (dimension by dimension using shared mem)
|
||||
// Step 4: Reduce O across block, dim by dim. Store one partial per warp
|
||||
// and sum in warp-id order; atomicAdd made greedy decode nondeterministic
|
||||
// when logits were close (same fix pattern as paged_attention.cu / gemv.cu).
|
||||
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
|
||||
|
||||
// Process head_dim in chunks: each iteration reduces one dimension
|
||||
// Use shared memory accumulator: each warp contributes via warp reduction + atomic
|
||||
// Actually simpler: iterate over dimensions, warp reduce each, then lane0 atomicAdd to smem_O
|
||||
|
||||
// Initialize smem_O
|
||||
for (int d = tid; d < head_dim; d += DECODE_THREADS) {
|
||||
smem_O[d] = 0.0f;
|
||||
for (int i = tid; i < 32 * HEAD_DIM_MAX; i += DECODE_THREADS) {
|
||||
reinterpret_cast<float*>(smem_O_warp)[i] = 0.0f;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Each thread adds its local_O contributions via warp reduction + atomicAdd
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
float val = local_O[d];
|
||||
// Warp-level reduction
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
val += __shfl_down_sync(0xffffffff, val, offset);
|
||||
if (lane == 0) {
|
||||
atomicAdd(&smem_O[d], val);
|
||||
}
|
||||
if (lane == 0) smem_O_warp[warp_id][d] = val;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Thread 0..head_dim-1 write final output
|
||||
for (int d = tid; d < head_dim; d += DECODE_THREADS) {
|
||||
O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
|
||||
float out = 0.0f;
|
||||
for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d];
|
||||
O_ptr[d] = __float2bfloat16(out * inv_sum);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,12 +16,14 @@ __global__ void dequant_fp8e4m3_to_bf16_kernel(
|
||||
__nv_bfloat16* __restrict__ dst,
|
||||
int num_experts, int rows, int cols
|
||||
) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = num_experts * rows * cols;
|
||||
// 64-bit index: num_experts * rows * cols overflows int32 for 32 experts
|
||||
// at ~8k*8k weight matrices, same class as the MoE fix in cfbd64d.
|
||||
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
long long total = (long long)num_experts * rows * cols;
|
||||
if (idx >= total) return;
|
||||
|
||||
int expert_stride = rows * cols;
|
||||
int expert = idx / expert_stride;
|
||||
long long expert_stride = (long long)rows * cols;
|
||||
int expert = (int)(idx / expert_stride);
|
||||
float scale = scales[expert];
|
||||
float val = float(src[idx]) * scale;
|
||||
dst[idx] = __float2bfloat16(val);
|
||||
@@ -36,9 +38,9 @@ void launch_dequant_fp8e4m3_to_bf16(
|
||||
int num_experts, int rows, int cols,
|
||||
void* stream
|
||||
) {
|
||||
int total = num_experts * rows * cols;
|
||||
long long total = (long long)num_experts * rows * cols;
|
||||
int block = 256;
|
||||
int grid = (total + block - 1) / block;
|
||||
int grid = (int)((total + block - 1) / block);
|
||||
dequant_fp8e4m3_to_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_fp8_e4m3*)src,
|
||||
(const float*)scales,
|
||||
|
||||
Reference in New Issue
Block a user