xserv/docs/benchmarks/fp8-quantization.md
Gahow Wang e631a71b68 quantization: single strided-batched FP8 MoE GEMM — cut per-token launches ~768→48
The plan-cache fix removed the per-expert heuristic churn but still issued one
cublasLtMatmul per expert: ~768 tiny launches per decoded token (16 local
experts × 2 GEMMs × 24 layers), which capped the FP8 decode win at ~1.05× over
BF16. Collapse each MoE GEMM into ONE strided-batched cuBLASLt FP8 matmul
(BATCH_COUNT + strided-batch offsets on all four layouts) → ~48 launches/token.

A single strided call can't carry a per-batch scalar B-scale, so the per-expert
weight scale moves out of the GEMM epilogue into a fused post-scale kernel
(rowwise_scale_moe_bf16) that applies a_scale[token]·b_scale[expert] in one
pass. This is precision-equivalent: BF16's relative error is scale-invariant, so
scaling the unscaled GEMM output afterward loses nothing vs scaling in-epilogue.

Measured on dash5 (gpt-oss-20b, TP=2, 5090), warm-server GSM8K:
  decode TPOT 17.45 → 13.08 ms (FP8 now 1.41× vs BF16 18.39 ms),
  throughput 57.3 → 76.4 tok/s, accuracy unchanged (FP8 91.0% vs BF16 90.0%).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-12 01:23:29 +08:00


FP8 W8A8 quantization — gpt-oss-20b (dash5, 8× RTX 5090)

Operator-level FP8 E4M3 quantization of the MoE expert weights, with real cuBLASLt FP8 tensor-core GEMM (W8A8: FP8 weights × dynamically-quantized FP8 activations). All other tensors (attention, router, embeddings, norms, biases) stay BF16.

Scheme

  • Weights (tools/quantize_fp8.py): expert gate_up_proj / down_proj quantized BF16 → FP8 E4M3 with a per-expert scalar scale (absmax/448). Stored transposed [E, N, K] because cuBLASLt FP8 on Blackwell (sm120) requires transA=T.
  • Activations: quantized dynamically at runtime, per-token (per-row absmax), recovered by a post-GEMM row scale.
  • Compute: batched_gemm_fp8 (crates/xserv-kernels/src/quantization.rs) runs one strided-batched cuBLASLt FP8 matmul for all experts (alpha=1, in-GEMM scales 1.0). A fused kernel then applies a_scale[token]·b_scale[expert] in a single pass. BF16's relative rounding error is scale-invariant, and BF16 has FP32's exponent range, so the unscaled output stays in range. Applying both scales post-GEMM therefore costs at most one extra BF16 rounding, which makes it effectively precision-equivalent to folding them into the epilogue.
  • Model size: 22 GB (FP8) vs 39 GB (BF16). The FP8 model fits on a single 32 GB 5090; BF16 needs ≥ 2.
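Below is a minimal pure-Python sketch of the scheme above. It simulates E4M3 and BF16 rounding in software instead of using FP8 hardware and cuBLASLt, and every function name is hypothetical, not xserv's code. The sketch does three things:

- quantizes one expert's [N, K] weights with a single absmax/448 scale, and each token's activations with their own scale;
- forms the unscaled alpha=1 product;
- applies a_scale[token]·b_scale[expert] in one of two places: before the single output rounding (epilogue), or after a BF16 round-trip (separate post-GEMM pass).

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def round_e4m3(x: float) -> float:
    """Nearest FP8 E4M3 value (ties-to-even), saturating at +/-448.
    Software simulation: 3 mantissa bits, min normal 2**-6, subnormal step 2**-9."""
    if x == 0.0:
        return 0.0
    _, exp = math.frexp(abs(x))       # |x| = m * 2**exp with m in [0.5, 1)
    e = max(exp - 1, -6)              # unbiased exponent, floored at the subnormal range
    step = 2.0 ** (e - 3)             # spacing of E4M3 values near |x|
    q = round(abs(x) / step) * step   # built-in round() is ties-to-even
    return math.copysign(min(q, E4M3_MAX), x)

def round_bf16(x: float) -> float:
    """Nearest BF16 value (8 significant bits); overflow/underflow not modelled."""
    if x == 0.0:
        return 0.0
    m, exp = math.frexp(x)            # |m| in [0.5, 1), so m * 256 keeps 8 significant bits
    return math.ldexp(round(m * 256) / 256, exp)

def quantize(vals):
    """One scalar scale for the whole list: scale = absmax / 448."""
    absmax = max(abs(v) for v in vals)
    scale = absmax / E4M3_MAX if absmax > 0 else 1.0
    return [round_e4m3(v / scale) for v in vals], scale

def fp8_expert_gemm(x, w_nk):
    """W8A8 for one expert. x: [M][K] activations; w_nk: [N][K] weights (stored transposed).
    Returns the unscaled (alpha=1) accumulator plus per-token and per-expert scales."""
    k = len(w_nk[0])
    flat_q, b_scale = quantize([v for row in w_nk for v in row])   # per-expert scalar scale
    qw = [flat_q[i:i + k] for i in range(0, len(flat_q), k)]
    quantized_rows = [quantize(row) for row in x]                   # per-token dynamic scales
    qa = [q for q, _ in quantized_rows]
    a_scale = [s for _, s in quantized_rows]
    acc = [[sum(a * b for a, b in zip(arow, wrow)) for wrow in qw] for arow in qa]
    return acc, a_scale, b_scale

def scale_in_epilogue(acc, a_scale, b_scale):
    """Both scales applied before the single BF16 rounding of the output."""
    return [[round_bf16(v * a_scale[t] * b_scale) for v in row]
            for t, row in enumerate(acc)]

def scale_after_gemm(acc, a_scale, b_scale):
    """GEMM writes unscaled BF16; a separate pass then applies a_scale[t] * b_scale."""
    return [[round_bf16(round_bf16(v) * a_scale[t] * b_scale) for v in row]
            for t, row in enumerate(acc)]

x = [[0.5, -1.25, 2.0, 0.1], [3.0, 0.2, -0.7, 1.1]]               # 2 tokens, K=4
w = [[0.01, -0.02, 0.03, 0.005],                                   # N=3 output rows
     [0.04, 0.0, -0.01, 0.02],
     [-0.03, 0.015, 0.02, -0.01]]
acc, a_scale, b_scale = fp8_expert_gemm(x, w)
y_epilogue = scale_in_epilogue(acc, a_scale, b_scale)
y_post = scale_after_gemm(acc, a_scale, b_scale)
```

The post-GEMM variant rounds twice, so an individual output can differ from the epilogue variant by about one BF16 ulp. Both variants stay within the FP8 quantization error of the exact product.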

The performance bug that was fixed

batched_gemm_fp8 originally rebuilt the entire cuBLASLt plan per expert, per GEMM, per layer, on every forward pass — running the algo heuristic search, creating/destroying the descriptor + 4 layouts + preference, and cudaMalloc-ing a 4-byte scale buffer — roughly 1500 heuristic searches per decoded token. This made FP8 slower than BF16:

| metric | FP8 (buggy) | FP8 (fixed) | BF16 |
|---|---|---|---|
| Decode TPOT | 27.0 ms | 17.9 ms | 18.8 ms |
| Throughput | 37 tok/s | 55.8 tok/s | 53.2 tok/s |

Fix: cache the cuBLASLt plan (descriptor + layouts + heuristically-chosen algo) in a thread-local map keyed by (M, N, K, batch) so the heuristic runs once per shape, and allocate the scale buffer once.
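The plan cache can be sketched in Python as a per-thread dictionary keyed by (M, N, K, batch). This is an illustration under stated assumptions:

- Plan, get_plan and fake_build are hypothetical names.
- The builder stands in for descriptor/layout creation plus the heuristic query.
- The shapes in the loop are placeholders, not gpt-oss-20b's real dimensions.

```python
import threading
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

Shape = Tuple[int, int, int, int]  # (M, N, K, batch)

@dataclass(frozen=True)
class Plan:
    """Hypothetical stand-in for matmul descriptor + 4 layouts + heuristically chosen algo."""
    shape: Shape
    algo_id: int

_tls = threading.local()  # each thread sees its own `plans` attribute

def get_plan(shape: Shape, build: Callable[[Shape], Plan]) -> Plan:
    """Return this thread's cached plan for `shape`, running `build` only on first use."""
    cache: Optional[Dict[Shape, Plan]] = getattr(_tls, "plans", None)
    if cache is None:            # first call on this thread
        cache = {}
        _tls.plans = cache
    plan = cache.get(shape)
    if plan is None:             # miss: run the expensive build once for this shape
        plan = build(shape)
        cache[shape] = plan
    return plan

builds: List[Shape] = []

def fake_build(shape: Shape) -> Plan:
    # Real code would create descriptors/layouts and call cublasLtMatmulAlgoGetHeuristic here.
    builds.append(shape)
    return Plan(shape, algo_id=len(builds))

# Illustrative decode loop: 10 tokens x 24 layers x 2 MoE GEMMs (placeholder dims).
for _token in range(10):
    for _layer in range(24):
        get_plan((1, 256, 128, 16), fake_build)   # "gate_up" placeholder shape
        get_plan((1, 128, 128, 16), fake_build)   # "down" placeholder shape
```

Because the key includes M, every new prefill length is a cache miss, which is the first-request heuristic stall listed under the follow-ups. Decode (M=1) hits the same keys on every token.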

Reducing launches: one strided-batched matmul

The per-expert loop still issued one cublasLtMatmul per expert: ~768 tiny launches per decoded token (16 local experts × 2 GEMMs × 24 layers). Collapsing each MoE GEMM into a single strided-batched cuBLASLt FP8 matmul (BATCH_COUNT + strided-batch offsets) drops that to ~48 GEMM launches (2 GEMMs × 24 layers), with a fused post-scale kernel applying both scales. This required moving the per-expert weight scale out of the GEMM epilogue and into the post-scale kernel, because a single strided call can't carry a per-batch scalar. As noted above, the move is precision-equivalent.
| (gpt-oss-20b, TP=2) | per-expert FP8 | batched FP8 | BF16 |
|---|---|---|---|
| Decode TPOT | 17.9 ms | 13.8 ms | 18.8 ms |
| Throughput | 55.8 tok/s | 72.3 tok/s | 53.2 tok/s |
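The strided-batched call and the fused post-scale can be modelled over flat row-major buffers. The sketch below assumes one activation scale per token, shared across experts, matching the a_scale[token] indexing in the text. strided_batched_gemm and rowwise_scale_moe are hypothetical Python stand-ins for the cuBLASLt matmul (BATCH_COUNT plus per-layout strided-batch offsets) and the rowwise_scale_moe_bf16 kernel; they are not that code.

```python
from typing import List, Sequence

def strided_batched_gemm(a: Sequence[float], b: Sequence[float], m: int, n: int, k: int,
                         batch_count: int, stride_a: int, stride_b: int,
                         stride_c: int) -> List[float]:
    """One call computing C[i] = A[i] @ B[i]^T for every batch i (alpha = 1, no scales).

    A[i] is [m, k] row-major at offset i * stride_a, B[i] is [n, k] row-major
    (weights stored [E, N, K]) at offset i * stride_b, C[i] is [m, n] at i * stride_c.
    """
    c = [0.0] * (stride_c * (batch_count - 1) + m * n)
    for i in range(batch_count):
        oa, ob, oc = i * stride_a, i * stride_b, i * stride_c
        for t in range(m):
            for j in range(n):
                c[oc + t * n + j] = sum(a[oa + t * k + p] * b[ob + j * k + p]
                                        for p in range(k))
    return c

def rowwise_scale_moe(c: Sequence[float], a_scale: Sequence[float],
                      b_scale: Sequence[float], m: int, n: int,
                      stride_c: int) -> List[float]:
    """Single post-scale pass over all experts: C[e][t][j] *= a_scale[t] * b_scale[e]."""
    out = list(c)
    for e, sb in enumerate(b_scale):
        for t in range(m):
            row = e * stride_c + t * n
            for j in range(n):
                out[row + j] *= a_scale[t] * sb
    return out

def gemm_launches_per_token(local_experts: int, gemms_per_layer: int, layers: int,
                            batched: bool) -> int:
    """Per-expert: one launch per expert per GEMM; batched: one launch per GEMM."""
    return gemms_per_layer * layers * (1 if batched else local_experts)
```

Each batch i is addressed purely as a base pointer plus i times a stride. That is why a per-batch scalar such as b_scale[expert] has no place in the single call and moves to the post-scale pass. For gpt-oss-20b at TP=2:

- `gemm_launches_per_token(16, 2, 24, batched=False)` gives 768.
- `batched=True` gives 48.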

Results — GSM8K (greedy, TP=2 on the same 2 GPUs)

The 200-problem run uses the per-expert plan-cache build, and the 100-problem run uses the strided-batched build. BF16 is the unchanged baseline, rerun alongside each, so its cells read n=200 / n=100.

Harness: tools/fp8_compare.py runs a warm xserv-server per model and streams GSM8K through /v1/chat/completions. TTFT (time to first token) and TPOT (mean inter-token latency) are measured per request.
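The per-request metrics described above can be computed from stream arrival timestamps roughly as follows. The sketch assumes that TPOT is the mean gap between consecutive streamed tokens and that p90 uses an inclusive percentile. tools/fp8_compare.py may differ on both points, and ttft_tpot and summarize are hypothetical names.

```python
from statistics import mean, median, quantiles
from typing import Dict, List, Sequence, Tuple

def ttft_tpot(request_start: float, token_times: Sequence[float]) -> Tuple[float, float]:
    """Per-request metrics from streamed-token arrival times (seconds).

    TTFT = first token arrival - request start.
    TPOT = mean gap between consecutive tokens (NaN if only one token arrived).
    """
    ttft = token_times[0] - request_start
    gaps = [later - earlier for earlier, later in zip(token_times, token_times[1:])]
    return ttft, (mean(gaps) if gaps else float("nan"))

def summarize(requests: List[Tuple[float, Sequence[float]]]) -> Dict[str, float]:
    """Medians across requests plus an inclusive-method p90 of per-request TPOT, in ms."""
    per_req = [ttft_tpot(start, times) for start, times in requests]
    ttfts = [1000 * t for t, _ in per_req]
    tpots = [1000 * p for _, p in per_req]
    return {
        "ttft_median_ms": median(ttfts),
        "tpot_median_ms": median(tpots),
        "tpot_p90_ms": quantiles(tpots, n=10, method="inclusive")[-1],
    }
```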

| metric | FP8 per-expert (n=200) | FP8 batched (n=100) | BF16 (n=200 / n=100) |
|---|---|---|---|
| GSM8K accuracy | 93.0 % | 91.0 % | 90.5 / 90.0 % |
| TTFT median | 67.4 ms | 65.0 ms | 68.8 / 69.5 ms |
| TPOT median | 17.45 ms | 13.08 ms | 18.26 / 18.39 ms |
| TPOT p90 | 17.65 ms | 13.28 ms | 18.38 / 18.52 ms |
| Throughput | 57.3 tok/s | 76.4 tok/s | 54.8 / 54.4 tok/s |
| Decode speedup vs BF16 | 1.05× | 1.41× | 1.00× |
  • Accuracy: unchanged. FP8 is nominally +2.5 pts (n=200) and +1.0 pt (n=100) above the matching BF16 run. At n=100–200, however, the standard error of each accuracy figure is ~2–3 pts, so the two are statistically indistinguishable. The takeaway is that neither FP8 quantization nor the strided-batched path's post-GEMM rescaling degrades accuracy.
  • Decode: FP8 is 1.41× faster once batched (TPOT 13.08 vs 18.39 ms), with a tight p90 (13.28 ms). The per-expert version was only ~1.05× faster because the ~768 tiny M=1 launches per token dominated. Batching them into ~48 unlocked most of the saving from FP8's halved expert-weight bandwidth.
  • Prefill (TTFT): comparable. A multi-length sweep (113 / 561 / 1681 tokens) gave FP8 480 / 362 / 2451 ms vs BF16 558 / 282 / 2287 ms — non-monotonic, i.e. dominated by fixed overhead (cuBLAS lazy init + FP8's one-time per-shape heuristic), not prefill compute, at these lengths.
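The standard-error and speedup figures above are simple arithmetic, reproduced in the sketch below. It assumes a binomial model and treats the FP8 and BF16 runs as independent samples. Since both models answer the same problems, a paired test such as McNemar's would be the more precise choice.

```python
import math

def binomial_se(p: float, n: int) -> float:
    """Standard error of an accuracy p measured on n independent problems."""
    return math.sqrt(p * (1.0 - p) / n)

def diff_se(p1: float, p2: float, n: int) -> float:
    """SE of p1 - p2, treating the two runs as independent samples of size n."""
    return math.sqrt(binomial_se(p1, n) ** 2 + binomial_se(p2, n) ** 2)

def decode_speedup(bf16_tpot_ms: float, fp8_tpot_ms: float) -> float:
    """Speedup as a ratio of median TPOTs (lower TPOT = faster decode)."""
    return bf16_tpot_ms / fp8_tpot_ms

# GSM8K accuracies and median TPOTs taken from the results table.
for label, fp8_acc, bf16_acc, n, fp8_tpot, bf16_tpot in [
    ("per-expert", 0.930, 0.905, 200, 17.45, 18.26),
    ("batched", 0.910, 0.900, 100, 13.08, 18.39),
]:
    gap = 100 * (fp8_acc - bf16_acc)
    print(f"{label}: gap {gap:+.1f} pts, "
          f"SE(gap) {100 * diff_se(fp8_acc, bf16_acc, n):.1f} pts, "
          f"speedup {decode_speedup(bf16_tpot, fp8_tpot):.2f}x")
```

In both runs the FP8 accuracy gap is smaller than one standard error of the difference. The speedup ratios reproduce the 1.05× and 1.41× in the table.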

Single-GPU (TP=1)

FP8 runs gpt-oss-20b on one 5090 (bench-gpt-oss --tp 1, GPU6): TTFT 538 ms, TPOT 29.0 ms, 34.5 tok/s. BF16 cannot (39 GB > 32 GB). This — fitting a model that otherwise needs two GPUs onto one — is the largest practical win.

Follow-ups (not done)

  • Per-channel (per-output-row) weight scales, for more accuracy headroom than the current per-expert scalar (per-tensor) scale.
  • Warm common prefill shapes at load to hide the first-request heuristic stall.
  • Sparse (top-k only) MoE compute instead of dense. Currently every token runs all experts, so only ~top_k/num_experts of the FP8 GEMM work is used (4/32 = 1/8 for gpt-oss-20b).
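A sparse (top-k only) path would need a dispatch step that groups tokens by their top-k experts. The sketch below is a hypothetical illustration of that step, not an xserv implementation. It ignores the routing weights and the gather/scatter of token rows, and topk_dispatch and gemm_rows are made-up names.

```python
import heapq
from collections import defaultdict
from typing import Dict, List, Sequence

def topk_dispatch(router_logits: Sequence[Sequence[float]], k: int) -> Dict[int, List[int]]:
    """Map each expert to the tokens routed to it (top-k experts per token by router logit)."""
    groups: Dict[int, List[int]] = defaultdict(list)
    for token, logits in enumerate(router_logits):
        for expert in heapq.nlargest(k, range(len(logits)), key=logits.__getitem__):
            groups[expert].append(token)
    return dict(groups)

def gemm_rows(groups: Dict[int, List[int]]) -> int:
    """Rows a sparse path would feed the expert GEMMs (tokens x k in total)."""
    return sum(len(tokens) for tokens in groups.values())

# Two tokens, four experts, top-2 routing.
groups = topk_dispatch([[0.1, 0.9, 0.5, -1.0], [2.0, 0.0, 0.3, 0.1]], k=2)
```

Here the expert GEMMs would see 4 rows instead of the 2 × 4 = 8 a dense pass computes. Experts that receive no tokens (expert 3 in this example) would be skipped entirely.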