Replace the W8A16 dequant→BF16-GEMM path with native FP8×FP8→BF16 GEMM using cuBLASLt on Blackwell (RTX 5090). Both weights (static FP8 E4M3) and activations (dynamically quantized per-row) are processed directly on FP8 tensor cores.

Key implementation details:
- cuBLASLt on Blackwell requires transA=T for FP8, so expert weights are transposed during model loading ([E,K,N] → [E,N,K])
- Per-row activation quantization kernel (absmax/448 → FP8 E4M3)
- Post-GEMM row-wise rescaling recovers per-token precision
- Per-expert loop (not batched) due to cuBLASLt FP8 scale constraints

The same FP8 quantized model files work — no re-quantization needed. Activation quantization happens dynamically at inference time.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
124 lines
3.3 KiB
Plaintext
124 lines
3.3 KiB
Plaintext
#include <cuda_bf16.h>
|
|
#include <cuda_fp8.h>
|
|
#include <float.h>
|
|
#include "../common.cuh"
|
|
|
|
// Row-wise dynamic quantization: BF16 -> FP8 E4M3 with one FP32 scale per row.
//
//   src    [num_rows, cols]  BF16 input, rows stored contiguously
//   dst    [num_rows, cols]  FP8 E4M3 output, same layout as src
//   scales [num_rows]        FP32 scale written for each row
//
// For every row:
//   absmax      = max_i |src[row, i]|
//   scale       = absmax / 448, raised to 1e-12 if smaller
//                 (448 is the FP8_E4M3_MAX constant below)
//   dst[row, i] = fp8(src[row, i] * (1 / scale))
//
// Launch layout: gridDim.x covers the rows (one block per row) and
// blockDim.x must equal QUANT_BLOCK. The column loops stride by QUANT_BLOCK
// and the reduction runs over a QUANT_BLOCK-sized static shared array, so a
// different block size would skip columns or read unwritten shared slots.

#define QUANT_BLOCK 256
#define FP8_E4M3_MAX 448.0f

__global__ void quantize_bf16_to_fp8e4m3_rowwise_kernel(
    const __nv_bfloat16* __restrict__ src,
    __nv_fp8_e4m3* __restrict__ dst,
    float* __restrict__ scales,
    int num_rows, int cols
) {
    const int row = blockIdx.x;
    // The test depends only on blockIdx, so a block exits as a whole and the
    // barriers further down are never reached by only part of a block.
    if (row >= num_rows) return;

    const int thr = threadIdx.x;

    // Row base offset computed in 64 bits so row * cols cannot wrap in int.
    const long long row_offset = (long long)row * cols;
    const __nv_bfloat16* in_row = src + row_offset;
    __nv_fp8_e4m3* out_row = dst + row_offset;

    __shared__ float partial_max[QUANT_BLOCK];

    // Phase 1a: every thread folds its strided slice of the row into a
    // private running maximum of absolute values.
    float thread_max = 0.0f;
    for (int col = thr; col < cols; col += QUANT_BLOCK) {
        const float magnitude = fabsf(__bfloat162float(in_row[col]));
        thread_max = fmaxf(thread_max, magnitude);
    }
    partial_max[thr] = thread_max;
    __syncthreads();

    // Phase 1b: shared-memory tree reduction; the active half shrinks by two
    // each round until slot 0 holds the maximum for the whole row.
    int half = QUANT_BLOCK / 2;
    while (half > 0) {
        if (thr < half) {
            partial_max[thr] = fmaxf(partial_max[thr], partial_max[thr + half]);
        }
        __syncthreads();
        half >>= 1;
    }

    // Floor the scale so an all-zero row does not produce a division by zero
    // when the reciprocal is taken.
    const float raw_scale = partial_max[0] / FP8_E4M3_MAX;
    const float scale = (raw_scale < 1e-12f) ? 1e-12f : raw_scale;
    const float inv_scale = 1.0f / scale;

    // A single thread publishes the row's scale.
    if (thr == 0) {
        scales[row] = scale;
    }

    // Phase 2: rescale each element by the reciprocal and convert to FP8.
    for (int col = thr; col < cols; col += QUANT_BLOCK) {
        const float scaled = __bfloat162float(in_row[col]) * inv_scale;
        out_row[col] = __nv_fp8_e4m3(scaled);
    }
}
|
|
|
|
// In-place per-row rescale of a BF16 matrix: data[row, i] *= scales[row].
//
//   data   [num_rows, cols]  BF16, rows stored contiguously, modified in place
//   scales [num_rows]        FP32 multiplier for each row
//
// Launch layout: gridDim.x covers the rows (one block per row). Threads step
// through the columns by blockDim.x, so any block size visits every column.
// Each product is formed in FP32 and converted back with __float2bfloat16.
__global__ void rowwise_scale_bf16_kernel(
    __nv_bfloat16* __restrict__ data,
    const float* __restrict__ scales,
    int num_rows, int cols
) {
    const int row = blockIdx.x;
    if (row < num_rows) {
        // One scale load per thread, reused for every column it handles.
        const float factor = scales[row];

        // Row base offset computed in 64 bits so row * cols cannot wrap in int.
        __nv_bfloat16* elems = data + (long long)row * cols;

        for (int col = threadIdx.x; col < cols; col += blockDim.x) {
            const float scaled = __bfloat162float(elems[col]) * factor;
            elems[col] = __float2bfloat16(scaled);
        }
    }
}
|
|
|
|
extern "C" {
|
|
|
|
// Host launcher for rowwise_scale_bf16_kernel: multiplies every row of a BF16
// matrix, in place, by that row's FP32 scale.
//
//   data      pointer passed to the kernel as __nv_bfloat16* ([num_rows, cols])
//   scales    pointer passed to the kernel as const float*   ([num_rows])
//   num_rows  number of rows; one thread block is launched per row
//   cols      number of elements per row
//   stream    cudaStream_t handed over as an opaque pointer
//
// The kernel is enqueued on `stream` and this function does not synchronize.
// Launch errors are surfaced through CUDA_CHECK_LAST_ERROR().
void launch_rowwise_scale_bf16(
    void* data, const void* scales,
    int num_rows, int cols,
    void* stream
) {
    // Nothing to scale. This also avoids launching with a zero-sized grid,
    // which the runtime rejects as an invalid configuration (for example
    // when an expert received no tokens).
    if (num_rows <= 0) return;

    const int block = 256;
    const int grid = num_rows;  // one block per row
    rowwise_scale_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)data, (const float*)scales,
        num_rows, cols
    );
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
// Host launcher for quantize_bf16_to_fp8e4m3_rowwise_kernel: quantizes each
// row of a BF16 matrix to FP8 E4M3 and writes one FP32 scale per row.
//
//   src       pointer passed to the kernel as const __nv_bfloat16* ([num_rows, cols])
//   dst       pointer passed to the kernel as __nv_fp8_e4m3*       ([num_rows, cols])
//   scales    pointer passed to the kernel as float*               ([num_rows])
//   num_rows  number of rows; one thread block is launched per row
//   cols      number of elements per row
//   stream    cudaStream_t handed over as an opaque pointer
//
// The block size is fixed at QUANT_BLOCK because the kernel's column stride
// and shared-memory reduction are written for exactly that many threads.
// The kernel is enqueued on `stream` and this function does not synchronize.
// Launch errors are surfaced through CUDA_CHECK_LAST_ERROR().
void launch_quantize_bf16_to_fp8e4m3_rowwise(
    const void* src,
    void* dst,
    void* scales,
    int num_rows, int cols,
    void* stream
) {
    // Nothing to quantize. This also avoids launching with a zero-sized grid,
    // which the runtime rejects as an invalid configuration (for example
    // when an expert received no tokens).
    if (num_rows <= 0) return;

    const int grid = num_rows;  // one block per row
    const int block = QUANT_BLOCK;
    quantize_bf16_to_fp8e4m3_rowwise_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)src,
        (__nv_fp8_e4m3*)dst,
        (float*)scales,
        num_rows, cols
    );
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
}
|