Review notes (free text ahead of the first diff header; patch tools skip it):
 * This patch had been collapsed onto three physical lines. Line breaks,
   indentation and the position of blank context lines below are reconstructed
   so that every hunk again matches the line counts in its @@ header.
   TODO confirm whitespace and blank-line placement against the base files.
 * Text inside angle brackets had been stripped in two places and is restored:
   the reinterpret_cast target type (float*, matching the float smem_O_warp
   declaration and the 0.0f store), and the launch configuration of
   dequant_fp8e4m3_to_bf16_kernel. The latter is an unchanged context line and
   is presumably <<<grid, block, 0, (cudaStream_t)stream>>> -- verify against
   the base file.
 * NOTE(review): the zero fill of smem_O_warp looks redundant if num_warps is
   the number of warps actually in the block: every [warp][d] slot read by the
   final loop is first written by that warp's lane 0. num_warps is defined
   outside these hunks, so confirm before dropping the fill and its barrier.
 * NOTE(review): in launch_dequant_fp8e4m3_to_bf16 a total of 0 yields grid 0,
   which is not a valid launch configuration; consider returning early.

diff --git a/csrc/attention/causal_mask.cu b/csrc/attention/causal_mask.cu
index ccce701..e16bfdf 100644
--- a/csrc/attention/causal_mask.cu
+++ b/csrc/attention/causal_mask.cu
@@ -15,7 +15,10 @@ __global__ void causal_mask_f32(
     int col = blockIdx.x * blockDim.x + threadIdx.x;
 
     if (col < cols && col > row + offset) {
-        scores[batch_idx * rows * cols + row * cols + col] = -INFINITY;
+        // 64-bit index: batch * rows * cols overflows int32 at moderate batch
+        // and long context (e.g. batch=128 * heads=28 * seq=32768).
+        long long idx = ((long long)batch_idx * rows + row) * cols + col;
+        scores[idx] = -INFINITY;
     }
 }
 
@@ -28,7 +31,8 @@ __global__ void causal_mask_bf16(
     int col = blockIdx.x * blockDim.x + threadIdx.x;
 
     if (col < cols && col > row + offset) {
-        scores[batch_idx * rows * cols + row * cols + col] = __float2bfloat16(-INFINITY);
+        long long idx = ((long long)batch_idx * rows + row) * cols + col;
+        scores[idx] = __float2bfloat16(-INFINITY);
     }
 }
 
diff --git a/csrc/attention/flash_attention.cu b/csrc/attention/flash_attention.cu
index c9c69cd..1042007 100644
--- a/csrc/attention/flash_attention.cu
+++ b/csrc/attention/flash_attention.cu
@@ -464,7 +464,7 @@ __global__ void decode_attention_bf16_kernel(
     // Shared memory for reduction
     __shared__ float smem_max[32];  // one per warp
     __shared__ float smem_sum[32];
-    __shared__ float smem_O[HEAD_DIM_MAX];  // final output accumulator
+    __shared__ float smem_O_warp[32][HEAD_DIM_MAX];
 
     // Step 1: Block-wide max reduction
     int lane = tid & 31;
@@ -513,35 +513,30 @@ __global__ void decode_attention_bf16_kernel(
     __syncthreads();
     global_sum = smem_sum[0];
 
-    // Step 4: Reduce O across block (dimension by dimension using shared mem)
+    // Step 4: Reduce O across block, dim by dim. Store one partial per warp
+    // and sum in warp-id order; atomicAdd made greedy decode nondeterministic
+    // when logits were close (same fix pattern as paged_attention.cu / gemv.cu).
     float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
 
-    // Process head_dim in chunks: each iteration reduces one dimension
-    // Use shared memory accumulator: each warp contributes via warp reduction + atomic
-    // Actually simpler: iterate over dimensions, warp reduce each, then lane0 atomicAdd to smem_O
-
-    // Initialize smem_O
-    for (int d = tid; d < head_dim; d += DECODE_THREADS) {
-        smem_O[d] = 0.0f;
+    for (int i = tid; i < 32 * HEAD_DIM_MAX; i += DECODE_THREADS) {
+        reinterpret_cast<float*>(smem_O_warp)[i] = 0.0f;
     }
     __syncthreads();
 
-    // Each thread adds its local_O contributions via warp reduction + atomicAdd
     for (int d = 0; d < head_dim; d++) {
         float val = local_O[d];
-        // Warp-level reduction
         #pragma unroll
         for (int offset = 16; offset > 0; offset >>= 1)
             val += __shfl_down_sync(0xffffffff, val, offset);
-        if (lane == 0) {
-            atomicAdd(&smem_O[d], val);
-        }
+        if (lane == 0) smem_O_warp[warp_id][d] = val;
     }
     __syncthreads();
 
     // Thread 0..head_dim-1 write final output
     for (int d = tid; d < head_dim; d += DECODE_THREADS) {
-        O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
+        float out = 0.0f;
+        for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d];
+        O_ptr[d] = __float2bfloat16(out * inv_sum);
     }
 }
 
diff --git a/csrc/quantization/dequant_fp8.cu b/csrc/quantization/dequant_fp8.cu
index cd98ca5..6d74f88 100644
--- a/csrc/quantization/dequant_fp8.cu
+++ b/csrc/quantization/dequant_fp8.cu
@@ -16,12 +16,14 @@ __global__ void dequant_fp8e4m3_to_bf16_kernel(
     __nv_bfloat16* __restrict__ dst,
     int num_experts, int rows, int cols
 ) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    int total = num_experts * rows * cols;
+    // 64-bit index: num_experts * rows * cols overflows int32 for 32 experts
+    // at ~8k*8k weight matrices, same class as the MoE fix in cfbd64d.
+    long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
+    long long total = (long long)num_experts * rows * cols;
     if (idx >= total) return;
 
-    int expert_stride = rows * cols;
-    int expert = idx / expert_stride;
+    long long expert_stride = (long long)rows * cols;
+    int expert = (int)(idx / expert_stride);
     float scale = scales[expert];
     float val = float(src[idx]) * scale;
     dst[idx] = __float2bfloat16(val);
@@ -36,9 +38,9 @@ void launch_dequant_fp8e4m3_to_bf16(
     int num_experts, int rows, int cols,
     void* stream
 ) {
-    int total = num_experts * rows * cols;
+    long long total = (long long)num_experts * rows * cols;
     int block = 256;
-    int grid = (total + block - 1) / block;
+    int grid = (int)((total + block - 1) / block);
     dequant_fp8e4m3_to_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
         (const __nv_fp8_e4m3*)src,
         (const float*)scales,