Files
xserv/csrc/attention/causal_mask.cu
Gahow Wang 5f060902f6 cuda: fix remaining int32-address and nondeterministic-reduction bugs
Three CUDA bugs from the review after 5b350ee / cfbd64d that were missed
by those commits:

- flash_attention.cu decode_attention_bf16_kernel used atomicAdd to
  merge per-warp partials into smem_O — same nondeterminism pattern
  that 5b350ee already fixed in paged_attention.cu and gemv.cu. This
  kernel is on the legacy forward_gpu_cache path plus the speculative
  bench baseline, so verify/decode parity depended on it. Replace with
  smem_O_warp[32][HEAD_DIM_MAX] partials reduced in fixed warp-id order.
- causal_mask.cu computed the flat address as
  `batch_idx * rows * cols + row * cols + col` in int; batch=128 heads=28
  seq=32768 already overflows int32 (128·28 × 32768 × 32768 ≈ 3.8e12
  elements, far beyond INT32_MAX ≈ 2.1e9). Promote the index to long long.
- quantization/dequant_fp8.cu had `int total = num_experts * rows * cols`
  and `int expert_stride = rows * cols`; 32 experts × 8k × 8k overflows.
  Same fix pattern as the MoE dense kernels in cfbd64d — 64-bit total /
  idx / expert_stride, and grid computed in long long.
2026-07-01 15:13:07 +08:00

60 lines
1.9 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include <cuda_bf16.h>
#include "../common.cuh"
// Causal mask, fp32 variant.
//
// Writes -inf into scores[b][r][c] for every key column c > r + offset.
// `offset` accounts for a KV cache: when query row 0 sits at absolute
// position `offset`, row r may attend to key positions [0, offset + r] and
// every later column is masked out.
//
// scores: contiguous [batch, rows, cols], where batch is the flattened
// batch*heads dimension.
// Expected launch: blockIdx.z = batch index, blockIdx.y = query row,
// blockIdx.x * blockDim.x + threadIdx.x = key column. Uses no shared memory.
__global__ void causal_mask_f32(
    float* __restrict__ scores,
    int rows, int cols, int offset
) {
    const int q_row = blockIdx.y;
    const int k_col = threadIdx.x + blockIdx.x * blockDim.x;

    // Leave alone tail threads past the last column, and columns on or
    // before the (offset-shifted) diagonal.
    if (k_col >= cols || k_col <= q_row + offset) {
        return;
    }

    // 64-bit flat index: batch * rows * cols overflows int32 at moderate
    // batch and long context (e.g. batch=128 * heads=28 * seq=32768).
    const long long row_base = ((long long)blockIdx.z * rows + q_row) * cols;
    scores[row_base + k_col] = -INFINITY;
}
// Causal mask, bf16 variant: identical indexing and mask rule to
// causal_mask_f32, but stores -inf converted to __nv_bfloat16.
//
// scores: contiguous [batch, rows, cols] (batch = flattened batch*heads).
// Element [b][r][c] is set to -inf when c > r + offset. Launch layout:
// blockIdx.z = batch, blockIdx.y = query row, x-dimension covers columns.
__global__ void causal_mask_bf16(
    __nv_bfloat16* __restrict__ scores,
    int rows, int cols, int offset
) {
    const int q_row = blockIdx.y;
    const int k_col = threadIdx.x + blockIdx.x * blockDim.x;

    const bool in_range = k_col < cols;
    const bool past_diagonal = k_col > q_row + offset;
    if (!(in_range && past_diagonal)) {
        return;
    }

    // 64-bit flat index so large batch*rows*cols products cannot wrap.
    const long long row_base = ((long long)blockIdx.z * rows + q_row) * cols;
    scores[row_base + k_col] = __float2bfloat16(-INFINITY);
}
namespace {

// Thread-block width used for every causal-mask launch.
constexpr int kCausalMaskBlock = 256;

// gridDim.y and gridDim.z are limited to 65535 (gridDim.x allows 2^31 - 1).
// The kernels map query rows to blockIdx.y and batch to blockIdx.z, so a
// single launch with rows or batch above this limit would be rejected.
constexpr int kCausalMaskMaxGridYZ = 65535;

// Launch a causal-mask kernel over a contiguous [batch, rows, cols] score
// tensor, splitting the work so no launch exceeds the gridDim.y (rows) or
// gridDim.z (batch) limit. In the common case exactly one launch is issued.
//
// The kernels read the query row from blockIdx.y and the batch from
// blockIdx.z, and use `rows` as the per-batch row stride. A chunk starting
// at (b0, r0) is expressed by advancing the base pointer to element
// [b0][r0][0] and shifting `offset` by r0. That keeps the mask condition
// `col > global_row + offset` exact without changing the kernels.
template <typename T, typename Kernel>
void launch_causal_mask_chunked(Kernel kernel, T* scores, int batch, int rows,
                                int cols, int offset, cudaStream_t stream) {
    // Nothing to mask. A zero-sized grid would otherwise be reported by the
    // runtime as an invalid launch configuration.
    if (batch <= 0 || rows <= 0 || cols <= 0) {
        return;
    }

    // cols > 0 here, so this ceil-div cannot overflow int.
    const unsigned int grid_x =
        (unsigned int)((cols - 1) / kCausalMaskBlock + 1);

    for (int b0 = 0; b0 < batch;) {
        const int nb = (batch - b0 < kCausalMaskMaxGridYZ)
                           ? batch - b0 : kCausalMaskMaxGridYZ;
        for (int r0 = 0; r0 < rows;) {
            const int nr = (rows - r0 < kCausalMaskMaxGridYZ)
                               ? rows - r0 : kCausalMaskMaxGridYZ;
            // 64-bit element offset of [b0][r0][0].
            T* base = scores + ((long long)b0 * rows + r0) * cols;
            dim3 grid(grid_x, (unsigned int)nr, (unsigned int)nb);
            kernel<<<grid, kCausalMaskBlock, 0, stream>>>(
                base, rows, cols, offset + r0);
            CUDA_CHECK_LAST_ERROR();
            r0 += nr;
        }
        b0 += nb;
    }
}

}  // namespace

extern "C" {

// C-ABI entry points. `scores` is a device buffer holding
// batch * rows * cols elements laid out as [batch, rows, cols], where batch
// is the flattened batch*heads dimension. `stream` is a cudaStream_t passed
// as void*. Masking is done in place and enqueued on `stream`, with launch
// errors checked via CUDA_CHECK_LAST_ERROR.
void launch_causal_mask_f32(void* scores, int batch, int rows, int cols,
                            int offset, void* stream) {
    launch_causal_mask_chunked(causal_mask_f32, (float*)scores, batch, rows,
                               cols, offset, (cudaStream_t)stream);
}

void launch_causal_mask_bf16(void* scores, int batch, int rows, int cols,
                             int offset, void* stream) {
    launch_causal_mask_chunked(causal_mask_bf16, (__nv_bfloat16*)scores,
                               batch, rows, cols, offset,
                               (cudaStream_t)stream);
}

}