quantization: single strided-batched FP8 MoE GEMM — cut per-token launches ~768→48

The plan-cache fix removed the per-expert heuristic churn but still issued one
cublasLtMatmul per expert: ~768 tiny launches per decoded token (16 local
experts × 2 GEMMs × 24 layers), which capped the FP8 decode win at ~1.05× over
BF16. Collapse each MoE GEMM into ONE strided-batched cuBLASLt FP8 matmul
(BATCH_COUNT + strided-batch offsets on all four layouts) → ~48 launches/token.

A single strided call can't carry a per-batch scalar B-scale, so the per-expert
weight scale moves out of the GEMM epilogue into a fused post-scale kernel
(rowwise_scale_moe_bf16) that applies a_scale[token]·b_scale[expert] in one
pass. This is precision-equivalent: BF16's relative error is scale-invariant, so
scaling the unscaled GEMM output afterward loses nothing vs scaling in-epilogue.

Measured on dash5 (gpt-oss-20b, TP=2, 5090), warm-server GSM8K:
  decode TPOT 17.45 → 13.08 ms (FP8 now 1.41× vs BF16 18.39 ms),
  throughput 57.3 → 76.4 tok/s, accuracy unchanged (FP8 91.0% vs BF16 90.0%).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-12 01:23:29 +08:00
parent 24c49c31c2
commit e631a71b68
3 changed files with 150 additions and 94 deletions

View File

@@ -86,6 +86,29 @@ __global__ void rowwise_scale_bf16_kernel(
}
}
// Combined dequant scale for batched MoE FP8 GEMM output.
// data[row, :] *= a_scales[row] * b_scales[row / tokens]
// where row = expert * tokens + token. a_scales is the per-token activation
// scale; b_scales is the per-expert scalar weight scale. Lets a single
// strided-batched FP8 matmul (alpha=1, scales=1) recover the real result in
// one pass instead of folding the weight scale into a per-expert GEMM call.
__global__ void rowwise_scale_moe_bf16_kernel(
__nv_bfloat16* __restrict__ data,
const float* __restrict__ a_scales,
const float* __restrict__ b_scales,
int num_rows, int cols, int tokens
) {
int row = blockIdx.x;
if (row >= num_rows) return;
int tid = threadIdx.x;
float s = a_scales[row] * b_scales[row / tokens];
__nv_bfloat16* row_data = data + (long long)row * cols;
for (int i = tid; i < cols; i += blockDim.x) {
float v = __bfloat162float(row_data[i]) * s;
row_data[i] = __float2bfloat16(v);
}
}
extern "C" {
void launch_rowwise_scale_bf16(
@@ -102,6 +125,20 @@ void launch_rowwise_scale_bf16(
CUDA_CHECK_LAST_ERROR();
}
void launch_rowwise_scale_moe_bf16(
void* data, const void* a_scales, const void* b_scales,
int num_rows, int cols, int tokens,
void* stream
) {
int block = 256;
int grid = num_rows;
rowwise_scale_moe_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(__nv_bfloat16*)data, (const float*)a_scales, (const float*)b_scales,
num_rows, cols, tokens
);
CUDA_CHECK_LAST_ERROR();
}
void launch_quantize_bf16_to_fp8e4m3_rowwise(
const void* src,
void* dst,