xserv/docs/benchmarks/fp8-quantization.md
Gahow Wang e631a71b68 quantization: single strided-batched FP8 MoE GEMM — cut per-token launches ~768→48
The plan-cache fix removed the per-expert heuristic churn but still issued one
cublasLtMatmul per expert: ~768 tiny launches per decoded token (16 local
experts × 2 GEMMs × 24 layers), which capped the FP8 decode win at ~1.05× over
BF16. Collapse each MoE GEMM into ONE strided-batched cuBLASLt FP8 matmul
(BATCH_COUNT + strided-batch offsets on all four layouts) → ~48 launches/token.

A single strided call can't carry a per-batch scalar B-scale, so the per-expert
weight scale moves out of the GEMM epilogue into a fused post-scale kernel
(rowwise_scale_moe_bf16) that applies a_scale[token]·b_scale[expert] in one
pass. This is precision-equivalent: BF16's relative error is scale-invariant, so
scaling the unscaled GEMM output afterward loses nothing vs scaling in-epilogue.

Measured on dash5 (gpt-oss-20b, TP=2, 5090), warm-server GSM8K:
  decode TPOT 17.45 → 13.08 ms (FP8 now 1.41× vs BF16 18.39 ms),
  throughput 57.3 → 76.4 tok/s, accuracy unchanged (FP8 91.0% vs BF16 90.0%).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-12 01:23:29 +08:00


# FP8 W8A8 quantization — gpt-oss-20b (dash5, 8× RTX 5090)
Operator-level FP8 E4M3 quantization of the MoE expert weights, with real
cuBLASLt FP8 tensor-core GEMM (W8A8: FP8 weights × dynamically-quantized FP8
activations). All other tensors (attention, router, embeddings, norms, biases)
stay BF16.
## Scheme
- **Weights** (`tools/quantize_fp8.py`): expert `gate_up_proj` / `down_proj`
quantized BF16 → FP8 E4M3 with a **per-expert scalar** scale (`absmax/448`).
Stored transposed `[E, N, K]` because cuBLASLt FP8 on Blackwell (sm120)
requires `transA=T`.
- **Activations**: quantized dynamically at runtime, **per-token** (per-row
absmax), recovered by a post-GEMM row scale.
- **Compute**: `batched_gemm_fp8` (`crates/xserv-kernels/src/quantization.rs`)
runs **one strided-batched cuBLASLt FP8 matmul for all experts** (`alpha=1`,
in-GEMM scales `1.0`); a fused kernel then applies `a_scale[token]·b_scale[expert]`
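in a single pass. BF16's relative rounding error does not depend on magnitude, so
rounding the unscaled GEMM output and scaling it afterward costs at most one extra
BF16 rounding versus folding both scales into the epilogue: precision-equivalent in
practice.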
- Model size: **22 GB** (FP8) vs **39 GB** (BF16). The FP8 model fits on a
single 32 GB 5090; BF16 needs ≥ 2.
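
A minimal sketch of the scheme's arithmetic follows. It assumes FP8 values are emulated in `f32` (Rust has no FP8 type) and that ties round away from zero rather than to even as the hardware does. `round_to_e4m3`, `absmax_scale`, and `quantize` are illustrative names, not functions from `tools/quantize_fp8.py` or `quantization.rs`. The sketch quantizes one expert's `[N, K]` weight with a single scale and each token's activation row with its own scale. It then runs an unscaled dot product and applies `a_scale[token]·b_scale[expert]` afterward.

```rust
// Sketch only: FP8 E4M3 values are emulated in f32, and ties round away from
// zero (f32::round) instead of to even. Names are illustrative, not xserv's.

/// Largest finite E4M3 value (the "E4M3FN" encoding used for inference).
const E4M3_MAX: f32 = 448.0;

/// Round `x` to the nearest E4M3 value, saturating at ±448.
fn round_to_e4m3(x: f32) -> f32 {
    let a = x.abs().min(E4M3_MAX);
    if a == 0.0 {
        return 0.0;
    }
    // Gap between neighbouring E4M3 values near `a`: subnormals (a < 2^-6) are
    // 2^-9 apart; normals keep 3 mantissa bits, so the gap is 2^(floor(log2 a) - 3).
    let step = if a < 2f32.powi(-6) {
        2f32.powi(-9)
    } else {
        let exp = ((a.to_bits() >> 23) & 0xff) as i32 - 127; // exact floor(log2 a)
        2f32.powi(exp - 3)
    };
    ((a / step).round() * step).min(E4M3_MAX).copysign(x)
}

/// Scalar scale absmax / 448, so the largest |value| maps onto E4M3_MAX.
fn absmax_scale(values: &[f32]) -> f32 {
    let absmax = values.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    if absmax == 0.0 { 1.0 } else { absmax / E4M3_MAX }
}

/// Quantize with one shared scale; returns (E4M3 grid values held in f32, scale).
fn quantize(values: &[f32]) -> (Vec<f32>, f32) {
    let scale = absmax_scale(values);
    (values.iter().map(|&v| round_to_e4m3(v / scale)).collect(), scale)
}

fn main() {
    // One expert's weight, stored transposed as [N = 2, K = 4]: a single b_scale.
    let w = [[0.5f32, -1.25, 2.0, 0.1], [3.0, 0.2, -0.7, 1.5]];
    let (wq, b_scale) = quantize(&w.concat());
    // Two tokens' activations [M = 2, K = 4]: one a_scale per token (per row).
    let x = [[1.0f32, 2.0, -0.5, 4.0], [0.01, -0.03, 0.02, 0.05]];
    for (t, row) in x.iter().enumerate() {
        let (xq, a_scale) = quantize(row);
        for n in 0..2 {
            // FP8 x FP8 dot product with unit in-GEMM scales...
            let acc: f32 = (0..4).map(|k| xq[k] * wq[n * 4 + k]).sum();
            // ...then both scales applied after the GEMM.
            let y = acc * a_scale * b_scale;
            let reference: f32 = (0..4).map(|k| row[k] * w[n][k]).sum();
            println!("token {t}, output {n}: fp8 {y:.4}, f32 reference {reference:.4}");
        }
    }
}
```

Because the second token's activations are tiny, its own per-token scale keeps them on the full E4M3 grid instead of collapsing into subnormals.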
## The performance bug that was fixed
`batched_gemm_fp8` originally rebuilt the entire cuBLASLt plan **per expert,
per GEMM, per layer, on every forward pass**. It reran the algorithm heuristic
search, created and destroyed the matmul descriptor, four matrix layouts and the
preference object, and `cudaMalloc`-ed a 4-byte scale buffer. In total that came
to roughly 1500 heuristic searches per decoded token. This made FP8 **slower than BF16**:

| | FP8 (buggy) | FP8 (fixed) | BF16 |
|---|---|---|---|
| Decode TPOT | 27.0 ms | **17.9 ms** | 18.8 ms |
| Throughput | 37 tok/s | **55.8 tok/s** | 53.2 tok/s |

Fix: cache the cuBLASLt plan (descriptor, layouts and heuristically chosen
algorithm) in a thread-local map keyed by `(M, N, K, batch)` so the heuristic runs
once per shape, and allocate the scale buffer once.
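
The caching pattern can be sketched as below, assuming a `HashMap` inside `thread_local!`. `Plan`, `build_plan`, and `get_plan` are hypothetical stand-ins with no cuBLASLt calls. A counter marks where the expensive descriptor/layout creation and heuristic search would run. The shapes in `main` are illustrative, not the model's.

```rust
use std::cell::{Cell, RefCell};
use std::collections::HashMap;

/// Stand-in for a cached cuBLASLt plan. The real one would own the matmul
/// descriptor, the four matrix layouts and the chosen algorithm; this one only
/// records the shape it was built for.
#[derive(Clone, Debug, PartialEq)]
struct Plan {
    m: usize,
    n: usize,
    k: usize,
    batch: usize,
}

/// Cache key, as in the text: (M, N, K, batch).
type PlanKey = (usize, usize, usize, usize);

thread_local! {
    static PLAN_CACHE: RefCell<HashMap<PlanKey, Plan>> = RefCell::new(HashMap::new());
    // Counts cache misses, i.e. how often the heuristic search would run.
    static HEURISTIC_RUNS: Cell<usize> = Cell::new(0);
}

/// Hypothetical expensive step: descriptor/layout setup plus heuristic search.
fn build_plan(m: usize, n: usize, k: usize, batch: usize) -> Plan {
    HEURISTIC_RUNS.with(|c| c.set(c.get() + 1));
    Plan { m, n, k, batch }
}

/// Return this thread's cached plan for the shape, building it on first use only.
fn get_plan(m: usize, n: usize, k: usize, batch: usize) -> Plan {
    PLAN_CACHE.with(|cache| {
        cache
            .borrow_mut()
            .entry((m, n, k, batch))
            .or_insert_with(|| build_plan(m, n, k, batch))
            .clone()
    })
}

fn heuristic_runs() -> usize {
    HEURISTIC_RUNS.with(|c| c.get())
}

fn main() {
    // Illustrative decode shapes (M = 1 token, 16 local experts); not the real N/K.
    let shapes = [(1, 1024, 512, 16), (1, 512, 1024, 16)];
    for _token in 0..10 {
        for _layer in 0..24 {
            for &(m, n, k, batch) in &shapes {
                let plan = get_plan(m, n, k, batch);
                debug_assert_eq!(plan, Plan { m, n, k, batch });
            }
        }
    }
    // Two distinct shapes: the heuristic ran twice instead of 10 * 24 * 2 = 480 times.
    println!("heuristic runs: {}", heuristic_runs());
}
```

A thread-local map needs no locking. The trade-off is that each thread that issues GEMMs builds its own copy of every plan.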
## Reducing launches: one strided-batched matmul
The per-expert loop still issued one `cublasLtMatmul` per expert — ~768 tiny
launches per decoded token (16 local experts × 2 GEMMs × 24 layers). Collapsing
each MoE GEMM into a single **strided-batched** cuBLASLt FP8 matmul (BATCH_COUNT
+ strided-batch offsets) drops that to ~48, with a fused post-scale kernel
applying both scales. This required moving the per-expert weight scale out of the
GEMM epilogue (a single strided call can't carry a per-batch scalar) into the
post-scale kernel — precision-equivalent, as noted above.
| (gpt-oss-20b, TP=2) | per-expert FP8 | batched FP8 | BF16 |
|---|---|---|---|
| Decode TPOT | 17.9 ms | **13.8 ms** | 18.8 ms |
| Throughput | 55.8 tok/s | **72.3 tok/s** | 53.2 tok/s |
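
The launch arithmetic and the strided layout can be sketched as follows. The sketch assumes each operand is a dense, contiguous per-expert slice: `[E, N, K]` weights, `[E, M, K]` activations, and `[E, M, N]` outputs. The strides `batched_gemm_fp8` actually sets as cuBLASLt strided-batch offsets may differ. `StridedBatch`, `launches_per_token`, and `post_scale` are hypothetical names. `post_scale` is a plain CPU reference for what the fused kernel computes, not its implementation.

```rust
/// One MoE GEMM over E experts as a strided batch: for each expert e,
/// C[e] (M x N) = A[e] (M x K) * W[e]^T, with W stored transposed as [E, N, K].
/// Assumption: every operand is a dense, contiguous per-expert slice.
struct StridedBatch {
    experts: usize,
    m: usize,
    n: usize,
    k: usize,
}

impl StridedBatch {
    /// Batch count: one GEMM instance per local expert.
    fn batch_count(&self) -> usize {
        self.experts
    }
    /// Element distance between consecutive experts' slices, per operand.
    fn weight_stride(&self) -> usize {
        self.n * self.k
    }
    fn act_stride(&self) -> usize {
        self.m * self.k
    }
    fn out_stride(&self) -> usize {
        self.m * self.n
    }
}

/// GEMM launches per decoded token: one per expert per GEMM per layer when
/// looping over experts, one per GEMM per layer when strided-batched.
fn launches_per_token(local_experts: usize, gemms_per_layer: usize, layers: usize, batched: bool) -> usize {
    let per_gemm = if batched { 1 } else { local_experts };
    per_gemm * gemms_per_layer * layers
}

/// CPU reference for the fused post-scale over an unscaled [E, M, N] output:
/// out[e][t][j] = c[e][t][j] * a_scale[t] * b_scale[e]. The real kernel writes
/// BF16, which adds one more rounding that this f32 reference omits.
fn post_scale(c: &[f32], a_scale: &[f32], b_scale: &[f32], m: usize, n: usize) -> Vec<f32> {
    c.iter()
        .enumerate()
        .map(|(i, &v)| {
            let expert = i / (m * n);
            let token = (i / n) % m;
            v * a_scale[token] * b_scale[expert]
        })
        .collect()
}

fn main() {
    // 16 local experts x 2 GEMMs x 24 layers, as in the text.
    println!("looped: {}", launches_per_token(16, 2, 24, false)); // 768
    println!("batched: {}", launches_per_token(16, 2, 24, true)); // 48

    // Toy shape: 2 experts, 1 decode token, N = 3, K = 4.
    let g = StridedBatch { experts: 2, m: 1, n: 3, k: 4 };
    println!(
        "batch {} | strides: W {}, A {}, C {}",
        g.batch_count(), g.weight_stride(), g.act_stride(), g.out_stride()
    );
    let c = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // unscaled GEMM output, [E=2, M=1, N=3]
    let out = post_scale(&c, &[0.5], &[2.0, 4.0], g.m, g.n);
    println!("{out:?}"); // [1.0, 2.0, 3.0, 8.0, 10.0, 12.0]
}
```

Only the expert index changes the weight scale, so one pass over the output handles every expert. This is why the per-expert scalar can leave the epilogue without a per-batch scale pointer.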
## Results — GSM8K (greedy, TP=2 on the same 2 GPUs)
The n=200 run measures the per-expert plan-cache version and the n=100 run the
strided-batched version. BF16, the unchanged baseline, was run alongside both
(BF16 cells read n=200 / n=100).

Harness: `tools/fp8_compare.py` starts a warm `xserv-server` per model and streams
GSM8K through `/v1/chat/completions`. TTFT is the time to first token and TPOT the
mean inter-token latency of each request. Medians and p90s are taken across
requests.

| metric | FP8 per-expert (n=200) | FP8 batched (n=100) | BF16 (n=200 / n=100) |
|---|---|---|---|
| GSM8K accuracy | 93.0 % | 91.0 % | 90.5 / 90.0 % |
| TTFT median | 67.4 ms | 65.0 ms | 68.8 / 69.5 ms |
| TPOT median | 17.45 ms | **13.08 ms** | 18.26 / 18.39 ms |
| TPOT p90 | 17.65 ms | **13.28 ms** | 18.38 / 18.52 ms |
| Throughput | 57.3 tok/s | **76.4 tok/s** | 54.8 / 54.4 tok/s |
| Decode speedup vs BF16 | 1.05× | **1.41×** | 1.00× |
- **Accuracy: unchanged.** FP8 is nominally above BF16 (+2.5 pts per-expert at
n=200, +1.0 pt batched at n=100). However, at n=100–200 the standard error of each
score is ~2–3 pts, so the runs are statistically indistinguishable. The takeaway
is that neither FP8 quantization nor the strided-batched version's post-GEMM
scaling measurably degrades accuracy.
- **Decode: FP8 1.41× faster** once batched (TPOT 13.08 vs 18.39 ms), with p90
only 0.2 ms above the median (13.28 ms). The per-expert version was only ~1.05×
because its ~768 tiny M=1 (single decode token) launches per token dominated.
Batching them into ~48 unlocked most of the FP8 expert-weight bandwidth saving.
- **Prefill (TTFT): comparable.** A multi-length sweep (113 / 561 / 1681 tokens)
gave FP8 480 / 362 / 2451 ms vs BF16 558 / 282 / 2287 ms. TTFT is non-monotonic
in prompt length (for both models the 113-token prompt took longer than the
561-token one). At these lengths it is therefore dominated by fixed overhead
(cuBLAS lazy init plus FP8's one-time per-shape heuristic), not by prefill compute.
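
The summary statistics behind the table and the bullets above can be reproduced in a few lines. The sketch assumes per-request token arrival timestamps, nearest-rank percentiles (the harness may interpolate instead), and a plain binomial standard error. `ttft_tpot`, `percentile`, and `std_err_pts` are illustrative names, not functions from `tools/fp8_compare.py`. The standard errors and speedups use the numbers reported above; the arrival times in `main` are made up.

```rust
/// Per-request TTFT and TPOT from token arrival times (ms since the request was sent).
/// Returns None when fewer than two tokens arrived (TPOT undefined).
fn ttft_tpot(arrivals_ms: &[f64]) -> Option<(f64, f64)> {
    if arrivals_ms.len() < 2 {
        return None;
    }
    let first = arrivals_ms[0];
    let last = arrivals_ms[arrivals_ms.len() - 1];
    // The mean of consecutive gaps telescopes to (last - first) / (tokens - 1).
    Some((first, (last - first) / (arrivals_ms.len() - 1) as f64))
}

/// Nearest-rank percentile (p in (0, 100]) of a non-empty sample.
fn percentile(values: &[f64], p: f64) -> f64 {
    let mut v = values.to_vec();
    v.sort_by(|a, b| a.total_cmp(b));
    let rank = ((p / 100.0) * v.len() as f64).ceil().max(1.0) as usize;
    v[rank - 1]
}

/// Binomial standard error of an accuracy p measured on n problems, in points.
fn std_err_pts(p: f64, n: f64) -> f64 {
    (p * (1.0 - p) / n).sqrt() * 100.0
}

fn main() {
    // Made-up arrival times for three requests; their TPOTs are 13.0, 13.5 and 14.0 ms.
    let requests = [
        vec![65.0, 78.0, 91.0, 104.0],
        vec![70.0, 83.5, 97.0],
        vec![60.0, 74.0, 88.0, 102.0, 116.0],
    ];
    let tpots: Vec<f64> = requests
        .iter()
        .filter_map(|r| ttft_tpot(r))
        .map(|(_, tpot)| tpot)
        .collect();
    println!("TPOT median {:.2} ms, p90 {:.2} ms", percentile(&tpots, 50.0), percentile(&tpots, 90.0));

    // Accuracy is ~90 % in every run, so one run's standard error is ~2-3 pts.
    println!("SE n=100: {:.1} pts", std_err_pts(0.9, 100.0)); // 3.0
    println!("SE n=200: {:.1} pts", std_err_pts(0.9, 200.0)); // 2.1

    // Decode speedup = BF16 TPOT / FP8 TPOT, from the paired runs in the table.
    println!("batched {:.2}x, per-expert {:.2}x", 18.39 / 13.08, 18.26 / 17.45); // 1.41x, 1.05x
}
```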
## Single-GPU (TP=1)
FP8 runs gpt-oss-20b on **one** 5090 (`bench-gpt-oss --tp 1`, GPU6): TTFT 538 ms,
TPOT 29.0 ms, 34.5 tok/s. BF16 cannot (39 GB > 32 GB). This — fitting a model
that otherwise needs two GPUs onto one — is the largest practical win.
## Follow-ups (not done)
- Per-channel (per-output-row) weight scales, for more accuracy headroom than the
current per-expert scalar (per-tensor) scale.
- Warm common prefill shapes at load to hide the first-request heuristic stall.
- Sparse (top-k only) MoE compute instead of dense. Currently every token runs
all experts, so only ~top_k/num_experts of the FP8 GEMM work is used (4/32 = 1/8
for gpt-oss-20b).
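
One possible shape for the per-channel follow-up is sketched below; it is not implemented in xserv. It assumes weights stored `[N, K]` per expert, as in the Scheme section, and gives each output row `n` its own `absmax/448` scale. Output column `n` of the GEMM is produced only by weight row `n`. A post-GEMM kernel could therefore apply `b_scale[n]` per column instead of one scalar per expert. `per_row_scales` and `post_scale_per_row` are hypothetical names.

```rust
const E4M3_MAX: f32 = 448.0;

/// Per-output-row scales for one expert's weight stored [N, K]: absmax / 448 per row.
fn per_row_scales(w: &[f32], n: usize, k: usize) -> Vec<f32> {
    assert_eq!(w.len(), n * k, "weight must be N x K");
    w.chunks(k)
        .map(|row| {
            let absmax = row.iter().fold(0.0f32, |m, v| m.max(v.abs()));
            if absmax == 0.0 { 1.0 } else { absmax / E4M3_MAX }
        })
        .collect()
}

/// Post-scale with per-row weight scales over an unscaled [M, N] output:
/// out[t][j] = c[t][j] * a_scale[t] * b_scale_rows[j].
fn post_scale_per_row(c: &[f32], a_scale: &[f32], b_scale_rows: &[f32], m: usize, n: usize) -> Vec<f32> {
    assert_eq!(c.len(), m * n, "output must be M x N");
    c.iter()
        .enumerate()
        .map(|(i, &v)| v * a_scale[i / n] * b_scale_rows[i % n])
        .collect()
}

fn main() {
    // One expert with N = 2 rows, K = 3: row 0 is large, row 1 is small.
    let w = [4.0f32, -2.0, 1.0, 0.01, 0.02, -0.005];
    let rows = per_row_scales(&w, 2, 3);
    // A per-expert scalar would be 4/448 for both rows; per row, row 1 gets
    // 0.02/448, a 200x finer grid for its small values.
    println!("per-row scales: {rows:?}");
    // Unscaled GEMM output for M = 2 tokens x N = 2, with per-token a_scale.
    let c = [10.0f32, 20.0, 30.0, 40.0];
    println!("{:?}", post_scale_per_row(&c, &[1.0, 0.5], &rows, 2, 2));
}
```

The per-expert weight scale already lives outside the GEMM epilogue. Under this sketch's assumptions, switching to per-row scales would change only the post-scale indexing, not the strided-batched matmul.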