quantization: single strided-batched FP8 MoE GEMM — cut per-token launches ~768→48
The plan-cache fix removed the per-expert heuristic churn but still issued one cublasLtMatmul per expert: ~768 tiny launches per decoded token (16 local experts × 2 GEMMs × 24 layers), which capped the FP8 decode win at ~1.05× over BF16. Collapse each MoE GEMM into ONE strided-batched cuBLASLt FP8 matmul (BATCH_COUNT + strided-batch offsets on all four layouts) → ~48 launches/token. A single strided call can't carry a per-batch scalar B-scale, so the per-expert weight scale moves out of the GEMM epilogue into a fused post-scale kernel (rowwise_scale_moe_bf16) that applies a_scale[token]·b_scale[expert] in one pass. This is precision-equivalent: BF16's relative error is scale-invariant, so scaling the unscaled GEMM output afterward loses nothing vs scaling in-epilogue. Measured on dash5 (gpt-oss-20b, TP=2, 5090), warm-server GSM8K: decode TPOT 17.45 → 13.08 ms (FP8 now 1.41× vs BF16 18.39 ms), throughput 57.3 → 76.4 tok/s, accuracy unchanged (FP8 91.0% vs BF16 90.0%). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -86,6 +86,29 @@ __global__ void rowwise_scale_bf16_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// Fused post-GEMM dequantization for the strided-batched MoE FP8 path.
//
// The batched FP8 matmul is issued with alpha = 1 and unit scales, so its BF16
// output still lacks both quantization scales. This kernel applies them in
// place, in a single pass over the output:
//
//     data[r, :] *= a_scales[r] * b_scales[r / tokens]
//
// Rows are laid out expert-major, r = expert * tokens + token, so r / tokens
// recovers the expert index. a_scales is the per-token activation scale
// (indexed by row); b_scales is the per-expert scalar weight scale.
//
// Launch layout: one block per row (blockIdx.x selects the row); the threads
// of the block stride across that row's `cols` elements. No shared memory.
// `tokens` must be non-zero because it is used as a divisor.
__global__ void rowwise_scale_moe_bf16_kernel(
    __nv_bfloat16* __restrict__ data,
    const float* __restrict__ a_scales,
    const float* __restrict__ b_scales,
    int num_rows, int cols, int tokens
) {
    const int r = blockIdx.x;
    if (r < num_rows) {
        // Every thread of the block derives the same combined scale.
        const int expert = r / tokens;
        const float scale = a_scales[r] * b_scales[expert];

        // 64-bit offset so that r * cols cannot overflow int.
        __nv_bfloat16* const out = data + (long long)r * cols;

        int c = threadIdx.x;
        while (c < cols) {
            // Widen to float, scale, then round back to BF16.
            out[c] = __float2bfloat16(__bfloat162float(out[c]) * scale);
            c += blockDim.x;
        }
    }
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_rowwise_scale_bf16(
|
||||
@@ -102,6 +125,20 @@ void launch_rowwise_scale_bf16(
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
// Host-side launcher for rowwise_scale_moe_bf16_kernel.
//
// Scales a BF16 matrix in place, row by row, by
// a_scales[row] * b_scales[row / tokens].
//
//   data      BF16 matrix of num_rows x cols elements, modified in place.
//   a_scales  float array indexed by row.
//   b_scales  float array indexed by row / tokens.
//   num_rows  number of rows to scale; also the number of blocks launched.
//   cols      elements per row.
//   tokens    rows per b_scales entry (divisor used to pick the b_scales index).
//   stream    cudaStream_t passed as an opaque pointer.
//
// All three buffers are dereferenced by the kernel, so they must be
// device-accessible. The launch is asynchronous on `stream`; only launch
// errors are checked here.
void launch_rowwise_scale_moe_bf16(
    void* data, const void* a_scales, const void* b_scales,
    int num_rows, int cols, int tokens,
    void* stream
) {
    // Nothing to scale for an empty shape. Returning early also avoids a
    // zero-sized grid (an invalid launch configuration that would trip the
    // error check below) and a division by zero in the kernel's row / tokens.
    if (num_rows <= 0 || cols <= 0 || tokens <= 0) {
        return;
    }

    const int block = 256;      // threads striding across one row
    const int grid = num_rows;  // one block per row
    rowwise_scale_moe_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)data, (const float*)a_scales, (const float*)b_scales,
        num_rows, cols, tokens
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
void launch_quantize_bf16_to_fp8e4m3_rowwise(
|
||||
const void* src,
|
||||
void* dst,
|
||||
|
||||
Reference in New Issue
Block a user