quantization: W8A8 FP8 compute via cuBLASLt tensor cores

Replace the W8A16 dequant→BF16-GEMM path with native FP8×FP8→BF16 GEMM
using cuBLASLt on Blackwell (RTX 5090). Both weights (static FP8 E4M3)
and activations (dynamically quantized per-row) are processed directly
on FP8 tensor cores.

Key implementation details:
- cuBLASLt on Blackwell requires transA=T for FP8, so expert weights
  are transposed during model loading ([E,K,N] → [E,N,K])
- Per-row activation quantization kernel (absmax/448 → FP8 E4M3)
- Post-GEMM row-wise rescaling recovers per-token precision
- Per-expert loop (not batched) due to cuBLASLt FP8 scale constraints

The same FP8 quantized model files work — no re-quantization needed.
Activation quantization happens dynamically at inference time.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-07 20:38:26 +08:00
parent 9f1fbbb98b
commit 76487b7963
4 changed files with 508 additions and 15 deletions

View File

@@ -0,0 +1,123 @@
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <float.h>
#include "../common.cuh"
// Per-row quantize BF16 → FP8 E4M3 with per-row FP32 scale output.
//
// Input: src [num_rows, cols] BF16
// Output: dst [num_rows, cols] FP8 E4M3
// scales [num_rows] FP32
//
// Algorithm per row:
//   absmax = max(|src[row, :]|)
//   scale  = max(absmax / 448.0, 1e-12)   (448 = FP8 E4M3 max representable;
//                                          the floor avoids div-by-zero on all-zero rows)
//   dst[row, i] = fp8(src[row, i] * (1 / scale))
//
// Grid: one block per row. Block: 256 threads (must equal QUANT_BLOCK).
// Each thread handles up to ceil(cols / 256) elements of its row.
// Threads per block for the row-wise quantization kernel. The kernel also uses
// this value to size its shared-memory reduction buffer and as the per-thread
// column stride, so the launch block size must match it.
#define QUANT_BLOCK 256
// Largest magnitude representable in FP8 E4M3; each row's absmax is divided by
// this value to obtain the row scale.
#define FP8_E4M3_MAX 448.0f
// Quantizes one BF16 row per thread block to FP8 E4M3 and stores that row's
// FP32 scale in scales[row].
//
// Launch contract:
//   - blockIdx.x selects the row; blocks with blockIdx.x >= num_rows exit.
//   - blockDim.x must equal QUANT_BLOCK: the column stride and the static
//     shared-memory scratch buffer are both sized by that constant.
//   - src and dst are addressed as dense row-major [num_rows, cols] buffers.
__global__ void quantize_bf16_to_fp8e4m3_rowwise_kernel(
    const __nv_bfloat16* __restrict__ src,
    __nv_fp8_e4m3* __restrict__ dst,
    float* __restrict__ scales,
    int num_rows, int cols
) {
    const int row = blockIdx.x;
    if (row >= num_rows) {
        return;  // uniform per block, so no thread is left waiting at a barrier
    }

    const int t = threadIdx.x;
    // 64-bit offset so row * cols cannot overflow int for large buffers.
    const long long row_offset = (long long)row * cols;
    const __nv_bfloat16* in_row = src + row_offset;
    __nv_fp8_e4m3* out_row = dst + row_offset;

    // Pass 1a: every thread folds its strided slice of the row into a private
    // running maximum of |x|.
    __shared__ float partial_max[QUANT_BLOCK];
    float thread_max = 0.0f;
    for (int col = t; col < cols; col += QUANT_BLOCK) {
        const float magnitude = fabsf(__bfloat162float(in_row[col]));
        thread_max = fmaxf(thread_max, magnitude);
    }
    partial_max[t] = thread_max;
    __syncthreads();

    // Pass 1b: halve the number of active slots each step until slot 0 holds
    // the maximum over the whole row.
    for (int span = QUANT_BLOCK >> 1; span >= 1; span >>= 1) {
        if (t < span) {
            partial_max[t] = fmaxf(partial_max[t], partial_max[t + span]);
        }
        __syncthreads();
    }

    // Map the row's absmax onto the FP8 E4M3 maximum. The floor keeps the
    // reciprocal finite when the row is entirely zero.
    const float raw_scale = partial_max[0] / FP8_E4M3_MAX;
    const float scale = (raw_scale < 1e-12f) ? 1e-12f : raw_scale;
    const float inv_scale = 1.0f / scale;

    // A single thread publishes the row scale.
    if (t == 0) {
        scales[row] = scale;
    }

    // Pass 2: rescale each element and convert it to FP8.
    for (int col = t; col < cols; col += QUANT_BLOCK) {
        const float scaled = __bfloat162float(in_row[col]) * inv_scale;
        out_row[col] = __nv_fp8_e4m3(scaled);
    }
}
// In-place row-wise rescale of a BF16 matrix: data[row, :] *= scales[row].
//
// blockIdx.x selects the row (blocks past num_rows do nothing) and the threads
// of the block stride across that row's columns by blockDim.x. Each element is
// widened to FP32, multiplied by the row's FP32 scale, and written back as
// BF16. data is addressed as a dense row-major [num_rows, cols] buffer.
__global__ void rowwise_scale_bf16_kernel(
    __nv_bfloat16* __restrict__ data,
    const float* __restrict__ scales,
    int num_rows, int cols
) {
    const int row = blockIdx.x;
    if (row < num_rows) {
        const float factor = scales[row];
        // 64-bit offset so row * cols cannot overflow int for large buffers.
        __nv_bfloat16* const row_ptr = data + (long long)row * cols;
        const int stride = blockDim.x;
        for (int col = threadIdx.x; col < cols; col += stride) {
            const float scaled = __bfloat162float(row_ptr[col]) * factor;
            row_ptr[col] = __float2bfloat16(scaled);
        }
    }
}
extern "C" {
// Host launcher: multiplies each row of a BF16 matrix in place by its FP32
// scale (data[row, :] *= scales[row]).
//
//   data     device pointer, BF16, addressed as [num_rows, cols] row-major
//   scales   device pointer, FP32, one entry per row
//   num_rows number of rows; one thread block is launched per row
//   cols     number of columns per row
//   stream   cudaStream_t passed as an opaque pointer
//
// The launch is asynchronous on `stream`; only launch-configuration errors
// are checked here.
void launch_rowwise_scale_bf16(
    void* data, const void* scales,
    int num_rows, int cols,
    void* stream
) {
    // Empty input is a no-op. This also avoids launching a zero-block grid,
    // which CUDA rejects as an invalid configuration (e.g. an expert that
    // received no tokens).
    if (num_rows <= 0 || cols <= 0) {
        return;
    }

    int block = 256;
    int grid = num_rows;
    rowwise_scale_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)data, (const float*)scales,
        num_rows, cols
    );
    CUDA_CHECK_LAST_ERROR();
}
// Host launcher: dynamically quantizes a BF16 matrix to FP8 E4M3 with one
// FP32 scale per row.
//
//   src      device pointer, BF16 input, addressed as [num_rows, cols] row-major
//   dst      device pointer, FP8 E4M3 output, same layout as src
//   scales   device pointer, FP32 output, one entry per row
//   num_rows number of rows; one thread block is launched per row
//   cols     number of columns per row
//   stream   cudaStream_t passed as an opaque pointer
//
// The block size is fixed to QUANT_BLOCK because the kernel's shared-memory
// reduction and column stride are sized by it. The launch is asynchronous on
// `stream`; only launch-configuration errors are checked here.
void launch_quantize_bf16_to_fp8e4m3_rowwise(
    const void* src,
    void* dst,
    void* scales,
    int num_rows, int cols,
    void* stream
) {
    // No rows means nothing to quantize. Returning early also avoids
    // launching a zero-block grid, which CUDA rejects as an invalid
    // configuration (e.g. an expert that received no tokens).
    if (num_rows <= 0) {
        return;
    }

    int grid = num_rows;
    int block = QUANT_BLOCK;
    quantize_bf16_to_fp8e4m3_rowwise_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)src,
        (__nv_fp8_e4m3*)dst,
        (float*)scales,
        num_rows, cols
    );
    CUDA_CHECK_LAST_ERROR();
}
}