quantization: W8A8 FP8 compute via cuBLASLt tensor cores
Replace the W8A16 dequant→BF16-GEMM path with native FP8×FP8→BF16 GEMM using cuBLASLt on Blackwell (RTX 5090). Both weights (static FP8 E4M3) and activations (dynamically quantized per-row) are processed directly on FP8 tensor cores. Key implementation details: - cuBLASLt on Blackwell requires transA=T for FP8, so expert weights are transposed during model loading ([E,K,N] → [E,N,K]) - Per-row activation quantization kernel (absmax/448 → FP8 E4M3) - Post-GEMM row-wise rescaling recovers per-token precision - Per-expert loop (not batched) due to cuBLASLt FP8 scale constraints The same FP8 quantized model files work — no re-quantization needed. Activation quantization happens dynamically at inference time. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
123
csrc/quantization/quantize_fp8.cu
Normal file
123
csrc/quantization/quantize_fp8.cu
Normal file
@@ -0,0 +1,123 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <float.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// Per-row quantize BF16 → FP8 E4M3 with per-row FP32 scale output.
|
||||
//
|
||||
// Input: src [num_rows, cols] BF16
|
||||
// Output: dst [num_rows, cols] FP8 E4M3
|
||||
// scales [num_rows] FP32
|
||||
//
|
||||
// Algorithm per row:
|
||||
// absmax = max(|src[row, :]|)
|
||||
// scale = absmax / 448.0 (FP8 E4M3 max representable)
|
||||
// dst[row, i] = fp8(src[row, i] / scale)
|
||||
//
|
||||
// Grid: one block per row. Block: 256 threads.
|
||||
// Each thread handles ceil(cols / 256) elements.
|
||||
|
||||
#define QUANT_BLOCK 256
|
||||
#define FP8_E4M3_MAX 448.0f
|
||||
|
||||
__global__ void quantize_bf16_to_fp8e4m3_rowwise_kernel(
|
||||
const __nv_bfloat16* __restrict__ src,
|
||||
__nv_fp8_e4m3* __restrict__ dst,
|
||||
float* __restrict__ scales,
|
||||
int num_rows, int cols
|
||||
) {
|
||||
int row = blockIdx.x;
|
||||
if (row >= num_rows) return;
|
||||
int tid = threadIdx.x;
|
||||
|
||||
const __nv_bfloat16* row_src = src + (long long)row * cols;
|
||||
__nv_fp8_e4m3* row_dst = dst + (long long)row * cols;
|
||||
|
||||
// Step 1: Compute per-row absmax via shared-memory reduction.
|
||||
__shared__ float smem_max[QUANT_BLOCK];
|
||||
float local_max = 0.0f;
|
||||
for (int i = tid; i < cols; i += QUANT_BLOCK) {
|
||||
float v = fabsf(__bfloat162float(row_src[i]));
|
||||
local_max = fmaxf(local_max, v);
|
||||
}
|
||||
smem_max[tid] = local_max;
|
||||
__syncthreads();
|
||||
|
||||
// Block reduction
|
||||
for (int s = QUANT_BLOCK / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
smem_max[tid] = fmaxf(smem_max[tid], smem_max[tid + s]);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
float absmax = smem_max[0];
|
||||
float scale = absmax / FP8_E4M3_MAX;
|
||||
// Clamp scale to avoid div-by-zero for all-zero rows
|
||||
if (scale < 1e-12f) scale = 1e-12f;
|
||||
float inv_scale = 1.0f / scale;
|
||||
|
||||
// Thread 0 writes the scale
|
||||
if (tid == 0) {
|
||||
scales[row] = scale;
|
||||
}
|
||||
|
||||
// Step 2: Quantize each element
|
||||
for (int i = tid; i < cols; i += QUANT_BLOCK) {
|
||||
float v = __bfloat162float(row_src[i]) * inv_scale;
|
||||
row_dst[i] = __nv_fp8_e4m3(v);
|
||||
}
|
||||
}
|
||||
|
||||
// Row-wise scale: data[row, :] *= scales[row] (in-place, BF16)
|
||||
__global__ void rowwise_scale_bf16_kernel(
|
||||
__nv_bfloat16* __restrict__ data,
|
||||
const float* __restrict__ scales,
|
||||
int num_rows, int cols
|
||||
) {
|
||||
int row = blockIdx.x;
|
||||
if (row >= num_rows) return;
|
||||
int tid = threadIdx.x;
|
||||
float s = scales[row];
|
||||
__nv_bfloat16* row_data = data + (long long)row * cols;
|
||||
for (int i = tid; i < cols; i += blockDim.x) {
|
||||
float v = __bfloat162float(row_data[i]) * s;
|
||||
row_data[i] = __float2bfloat16(v);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_rowwise_scale_bf16(
|
||||
void* data, const void* scales,
|
||||
int num_rows, int cols,
|
||||
void* stream
|
||||
) {
|
||||
int block = 256;
|
||||
int grid = num_rows;
|
||||
rowwise_scale_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(__nv_bfloat16*)data, (const float*)scales,
|
||||
num_rows, cols
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
void launch_quantize_bf16_to_fp8e4m3_rowwise(
|
||||
const void* src,
|
||||
void* dst,
|
||||
void* scales,
|
||||
int num_rows, int cols,
|
||||
void* stream
|
||||
) {
|
||||
int grid = num_rows;
|
||||
int block = QUANT_BLOCK;
|
||||
quantize_bf16_to_fp8e4m3_rowwise_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)src,
|
||||
(__nv_fp8_e4m3*)dst,
|
||||
(float*)scales,
|
||||
num_rows, cols
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user