Three CUDA bugs from the review after 5b350ee/cfbd64d that were missed by those commits: - flash_attention.cu: decode_attention_bf16_kernel used atomicAdd to merge per-warp partials into smem_O — the same nondeterminism pattern that 5b350ee already fixed in paged_attention.cu and gemv.cu. This kernel is on the legacy forward_gpu_cache path plus the speculative bench baseline, so verify/decode parity depended on it. Replace with smem_O_warp[32][HEAD_DIM_MAX] partials reduced in fixed warp-id order. - causal_mask.cu computed the flat address as `batch_idx * rows * cols + row * cols + col` in int; batch=128 heads=28 seq=32768 already overflows int32. Promote the index to long long. - quantization/dequant_fp8.cu had `int total = num_experts * rows * cols` and `int expert_stride = rows * cols`; 32 experts × 8k × 8k overflows. Same fix pattern as the MoE dense kernels in cfbd64d — 64-bit total / idx / expert_stride, and grid computed in long long.
54 lines
1.6 KiB
Plaintext
54 lines
1.6 KiB
Plaintext
#include <cuda_bf16.h>
|
|
#include <cuda_fp8.h>
|
|
#include "../common.cuh"
|
|
|
|
// FP8 E4M3 -> BF16 dequantization with one FP32 scale per expert
// (i.e. per slice along the leading dimension).
//
//   src    : [num_experts, rows, cols]  FP8 E4M3, 1 byte per element
//   scales : [num_experts]              FP32
//   dst    : [num_experts, rows, cols]  BF16
//
// Element-wise contract:
//   dst[e, r, c] = bf16( float(src[e, r, c]) * scales[e] )
//
// Indexing: each thread derives a single flat element index from
// blockIdx.x / blockDim.x / threadIdx.x and handles at most that one element;
// threads whose index lands at or beyond the element count do nothing.
__global__ void dequant_fp8e4m3_to_bf16_kernel(
    const __nv_fp8_e4m3* __restrict__ src,
    const float* __restrict__ scales,
    __nv_bfloat16* __restrict__ dst,
    int num_experts, int rows, int cols
) {
    // Every flat-index quantity is kept in 64 bits: num_experts * rows * cols
    // overflows int32 for 32 experts at ~8k*8k weight matrices (same class of
    // bug as the MoE fix in cfbd64d).
    const long long elems_per_expert = (long long)rows * cols;
    const long long elem_count = (long long)num_experts * rows * cols;
    const long long flat = (long long)blockIdx.x * blockDim.x + threadIdx.x;

    if (flat < elem_count) {
        // Which expert slice this element belongs to selects the scale.
        // Reaching this point implies elem_count > 0, so elems_per_expert is
        // non-zero and the division is safe.
        const int e = (int)(flat / elems_per_expert);
        dst[flat] = __float2bfloat16(float(src[flat]) * scales[e]);
    }
}
|
|
|
|
extern "C" {
|
|
|
|
// Host-side launcher for dequant_fp8e4m3_to_bf16_kernel.
//
//   src     FP8 E4M3 input, forwarded to the kernel as [num_experts, rows, cols]
//   scales  FP32 per-expert scales, forwarded as [num_experts]
//   dst     BF16 output, same element count as src
//   stream  opaque handle; cast to cudaStream_t and used for the launch
//
// The pointers are passed to the kernel unchanged, so they must be
// dereferenceable from device code. The kernel is enqueued on `stream` with
// 256 threads per block and one thread per element.
// NOTE(review): no explicit synchronization is issued here; whether
// CUDA_CHECK_LAST_ERROR (from common.cuh) synchronizes is not visible from
// this file -- confirm.
void launch_dequant_fp8e4m3_to_bf16(
    const void* src,
    const void* scales,
    void* dst,
    int num_experts, int rows, int cols,
    void* stream
) {
    // 64-bit element count: the product overflows int32 for large MoE weights.
    const long long total = (long long)num_experts * rows * cols;

    // Empty tensor: nothing to convert. Returning early also avoids a
    // zero-block launch, which the CUDA runtime rejects as an invalid
    // configuration and which would otherwise be reported as an error below.
    if (total == 0) return;

    const int block = 256;
    // Ceil-divide in 64 bits before narrowing to the launch parameter type.
    const long long blocks_needed = (total + block - 1) / block;
    // NOTE(review): the narrowing cast is unchecked; a block count above
    // INT_MAX (far beyond realistic weight sizes) would need a kernel that
    // loops over elements rather than a larger grid.
    const int grid = (int)blocks_needed;

    dequant_fp8e4m3_to_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_fp8_e4m3*)src,
        (const float*)scales,
        (__nv_bfloat16*)dst,
        num_experts, rows, cols
    );
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
}
|