quantization: single strided-batched FP8 MoE GEMM — cut per-token launches ~768→48
The plan-cache fix removed the per-expert heuristic churn but still issued one cublasLtMatmul per expert: ~768 tiny launches per decoded token (16 local experts × 2 GEMMs × 24 layers), which capped the FP8 decode win at ~1.05× over BF16. Collapse each MoE GEMM into ONE strided-batched cuBLASLt FP8 matmul (BATCH_COUNT + strided-batch offsets on all four layouts) → ~48 launches/token. A single strided call can't carry a per-batch scalar B-scale, so the per-expert weight scale moves out of the GEMM epilogue into a fused post-scale kernel (rowwise_scale_moe_bf16) that applies a_scale[token]·b_scale[expert] in one pass. This is precision-equivalent: BF16's relative error is scale-invariant, so scaling the unscaled GEMM output afterward loses nothing vs scaling in-epilogue. Measured on dash5 (gpt-oss-20b, TP=2, 5090), warm-server GSM8K: decode TPOT 17.45 → 13.08 ms (FP8 now 1.41× vs BF16 18.39 ms), throughput 57.3 → 76.4 tok/s, accuracy unchanged (FP8 91.0% vs BF16 90.0%). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -14,9 +14,10 @@ stay BF16.
|
||||
- **Activations**: quantized dynamically at runtime, **per-token** (per-row
|
||||
absmax), recovered by a post-GEMM row scale.
|
||||
- **Compute**: `batched_gemm_fp8` (`crates/xserv-kernels/src/quantization.rs`)
|
||||
runs one cuBLASLt FP8 matmul per expert; the per-expert weight scale is
|
||||
supplied via the cuBLASLt B-scale device pointer (FP32 epilogue, so precision
|
||||
matches folding it into `alpha`).
|
||||
runs **one strided-batched cuBLASLt FP8 matmul for all experts** (`alpha=1`,
|
||||
in-GEMM scales `1.0`); a fused kernel then applies `a_scale[token]·b_scale[expert]`
|
||||
in a single pass. BF16's relative error is scale-invariant, so applying both
|
||||
scales post-GEMM is precision-equivalent to folding them into the epilogue.
|
||||
- Model size: **22 GB** (FP8) vs **39 GB** (BF16). The FP8 model fits on a
|
||||
single 32 GB 5090; BF16 needs ≥ 2.
|
||||
|
||||
@@ -34,34 +35,50 @@ decoded token. This made FP8 **slower than BF16**:
|
||||
| Throughput | 37 tok/s | **55.8 tok/s** | 53.2 tok/s |
|
||||
|
||||
Fix: cache the cuBLASLt plan (descriptor + layouts + heuristically-chosen algo)
|
||||
in a thread-local map keyed by `(M, N, K)` so the heuristic runs once per shape;
|
||||
allocate the scale buffer once; pass per-expert weight scales by device pointer.
|
||||
The per-expert loop now issues only `cublasLtMatmul`.
|
||||
in a thread-local map keyed by `(M, N, K, batch)` so the heuristic runs once per
|
||||
shape, and allocate the scale buffer once.
|
||||
|
||||
## Results — GSM8K (200 problems, greedy, TP=2 on the same 2 GPUs)
|
||||
## Reducing launches: one strided-batched matmul
|
||||
|
||||
The per-expert loop still issued one `cublasLtMatmul` per expert — ~768 tiny
|
||||
launches per decoded token (16 local experts × 2 GEMMs × 24 layers). Collapsing
|
||||
each MoE GEMM into a single **strided-batched** cuBLASLt FP8 matmul (BATCH_COUNT
|
||||
+ strided-batch offsets) drops that to ~48, with a fused post-scale kernel
|
||||
applying both scales. This required moving the per-expert weight scale out of the
|
||||
GEMM epilogue (a single strided call can't carry a per-batch scalar) into the
|
||||
post-scale kernel — precision-equivalent, as noted above.
|
||||
|
||||
| (gpt-oss-20b, TP=2) | per-expert FP8 | batched FP8 | BF16 |
|
||||
|---|---|---|---|
|
||||
| Decode TPOT | 17.9 ms | **13.8 ms** | 18.8 ms |
|
||||
| Throughput | 55.8 tok/s | **72.3 tok/s** | 53.2 tok/s |
|
||||
|
||||
## Results — GSM8K (greedy, TP=2 on the same 2 GPUs)
|
||||
|
||||
200-problem run is the per-expert plan-cache fix; 100-problem run is the
|
||||
strided-batched version. BF16 is the unchanged baseline in both.
|
||||
|
||||
Harness: `tools/fp8_compare.py` — a warm `xserv-server` per model, GSM8K streamed
|
||||
through `/v1/chat/completions`; TTFT = time to first token, TPOT = mean
|
||||
inter-token latency, per request.
|
||||
|
||||
| metric | FP8 W8A8 | BF16 |
|
||||
|---|---|---|
|
||||
| GSM8K accuracy | **93.0 %** | 90.5 % |
|
||||
| TTFT median | 67.4 ms | 68.8 ms |
|
||||
| TTFT p90 | 90.4 ms | 96.7 ms |
|
||||
| TPOT median | **17.45 ms** | 18.26 ms |
|
||||
| TPOT p90 | 17.65 ms | 18.38 ms |
|
||||
| Throughput | **57.3 tok/s** | 54.8 tok/s |
|
||||
| Mean output tokens | 288 | 293 |
|
||||
| metric | FP8 per-expert (n=200) | FP8 batched (n=100) | BF16 |
|
||||
|---|---|---|---|
|
||||
| GSM8K accuracy | 93.0 % | 91.0 % | 90.5 / 90.0 % |
|
||||
| TTFT median | 67.4 ms | 65.0 ms | 68.8 / 69.5 ms |
|
||||
| TPOT median | 17.45 ms | **13.08 ms** | 18.26 / 18.39 ms |
|
||||
| TPOT p90 | 17.65 ms | **13.28 ms** | 18.38 / 18.52 ms |
|
||||
| Throughput | 57.3 tok/s | **76.4 tok/s** | 54.8 / 54.4 tok/s |
|
||||
| Decode speedup vs BF16 | 1.05× | **1.41×** | 1.00× |
|
||||
|
||||
- **Accuracy: unchanged.** FP8 is nominally +2.5 pts, but with n=200 the
|
||||
standard error is ~2.1 pts, so the two are statistically indistinguishable.
|
||||
The takeaway is that FP8 did **not** degrade accuracy.
|
||||
- **Decode: FP8 ~5 % faster** (TPOT 17.45 vs 18.26 ms), reproducible across
|
||||
runs, with a tighter p90. Modest because the dense-MoE path loads *all*
|
||||
experts every token and FP8 only halves the *expert* bytes; the per-expert
|
||||
M=1 launches and M=1 tensor-core inefficiency absorb much of the bandwidth
|
||||
saving.
|
||||
- **Accuracy: unchanged.** FP8 is nominally +0.5 … +2.5 pts above BF16, but at
|
||||
n=100–200 the standard error is ~2–3 pts, so they are statistically
|
||||
indistinguishable. The takeaway is that neither FP8 quantization nor the
|
||||
strided-batched rounding degrades accuracy.
|
||||
- **Decode: FP8 1.41× faster** once batched (TPOT 13.08 vs 18.39 ms), with a
|
||||
tight p90. The per-expert version was only ~1.05× — the ~768 tiny M=1 launches
|
||||
per token dominated; batching them into ~48 unlocked most of the FP8
|
||||
expert-weight-bandwidth saving.
|
||||
- **Prefill (TTFT): comparable.** A multi-length sweep (113 / 561 / 1681 tokens)
|
||||
gave FP8 480 / 362 / 2451 ms vs BF16 558 / 282 / 2287 ms — non-monotonic, i.e.
|
||||
dominated by fixed overhead (cuBLAS lazy init + FP8's one-time per-shape
|
||||
@@ -75,9 +92,8 @@ that otherwise needs two GPUs onto one — is the largest practical win.
|
||||
|
||||
## Follow-ups (not done)
|
||||
|
||||
- Strided-batched FP8 (one call instead of ~768 per-expert launches per token) —
|
||||
requires folding the per-expert weight scale into the post-scale kernel, at a
|
||||
BF16-intermediate precision cost.
|
||||
- Per-channel (per-output-row) weight scales for better accuracy headroom than
|
||||
per-tensor.
|
||||
- Warm common prefill shapes at load to hide the first-request heuristic stall.
|
||||
- Sparse (top-k only) MoE compute instead of dense — currently every token runs
|
||||
all experts, so only ~top_k/num_experts of the FP8 GEMM work is used.
|
||||
|
||||
Reference in New Issue
Block a user