# FP8 W8A8 quantization: gpt-oss-20b (dash5, 8× RTX 5090)

Operator-level FP8 E4M3 quantization of the MoE expert weights, with real cuBLASLt FP8 tensor-core GEMM (W8A8: FP8 weights × dynamically quantized FP8 activations). All other tensors (attention, router, embeddings, norms, biases) stay BF16.

## Scheme

- **Weights** (`tools/quantize_fp8.py`): expert `gate_up_proj` / `down_proj` quantized BF16 → FP8 E4M3 with a **per-expert scalar** scale (`absmax/448`). Stored transposed `[E, N, K]` because cuBLASLt FP8 on Blackwell (sm120) requires `transA=T`.
- **Activations**: quantized dynamically at runtime, **per-token** (per-row absmax), recovered by a post-GEMM row scale.
- **Compute**: `batched_gemm_fp8` (`crates/xserv-kernels/src/quantization.rs`) runs **one strided-batched cuBLASLt FP8 matmul for all experts** (`alpha=1`, in-GEMM scales `1.0`); a fused kernel then applies `a_scale[token]·b_scale[expert]` in a single pass. BF16's relative error is scale-invariant, so applying both scales post-GEMM is precision-equivalent to folding them into the epilogue.
- Model size: **22 GB** (FP8) vs **39 GB** (BF16). The FP8 model fits on a single 32 GB 5090; BF16 needs ≥ 2.

## The performance bug that was fixed

`batched_gemm_fp8` originally rebuilt the entire cuBLASLt plan **per expert, per GEMM, per layer, on every forward pass**: it ran the algo heuristic search, created and destroyed the descriptor + 4 layouts + preference, and `cudaMalloc`-ed a 4-byte scale buffer. That amounts to roughly 1500 heuristic searches per decoded token, which made FP8 **slower than BF16**:

| | FP8 (buggy) | FP8 (fixed) | BF16 |
|---|---|---|---|
| Decode TPOT | 27.0 ms | **17.9 ms** | 18.8 ms |
| Throughput | 37 tok/s | **55.8 tok/s** | 53.2 tok/s |

Fix: cache the cuBLASLt plan (descriptor + layouts + heuristically chosen algo) in a thread-local map keyed by `(M, N, K, batch)` so the heuristic runs once per shape, and allocate the scale buffer once.

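
The plan-cache fix can be illustrated with a minimal, CPU-only sketch. It is not the code in `crates/xserv-kernels`: `Plan`, `build_plan`, `plan_for`, and `heuristic_runs` are hypothetical names, and `Plan` merely stands in for the cached cuBLASLt state (descriptor, layouts, chosen algo). The only point shown is that a thread-local map keyed by `(M, N, K, batch)` makes the expensive build step run once per shape.

```rust
use std::cell::{Cell, RefCell};
use std::collections::HashMap;
use std::rc::Rc;

/// GEMM shape used as the cache key: (M, N, K, batch).
type ShapeKey = (usize, usize, usize, usize);

/// Hypothetical stand-in for the cached cuBLASLt state
/// (matmul descriptor, matrix layouts, heuristically chosen algo).
#[derive(Debug)]
struct Plan {
    key: ShapeKey,
    algo_id: u32,
}

thread_local! {
    /// One cache per thread, so no locking is needed.
    static PLAN_CACHE: RefCell<HashMap<ShapeKey, Rc<Plan>>> = RefCell::new(HashMap::new());
    /// Counts how often the expensive build step actually ran on this thread.
    static HEURISTIC_RUNS: Cell<usize> = Cell::new(0);
}

/// Placeholder for the expensive part: in the real kernel wrapper this is
/// where descriptor/layout creation and the heuristic search would happen.
fn build_plan(key: ShapeKey) -> Plan {
    HEURISTIC_RUNS.with(|c| c.set(c.get() + 1));
    Plan { key, algo_id: 0 }
}

/// Returns the cached plan for a shape, building it only on first use.
fn plan_for(m: usize, n: usize, k: usize, batch: usize) -> Rc<Plan> {
    let key = (m, n, k, batch);
    PLAN_CACHE.with(|cache| {
        let mut map = cache.borrow_mut();
        let plan = map.entry(key).or_insert_with(|| Rc::new(build_plan(key)));
        Rc::clone(plan)
    })
}

/// Number of plan builds performed on the current thread.
fn heuristic_runs() -> usize {
    HEURISTIC_RUNS.with(|c| c.get())
}
```

Calling `plan_for` repeatedly with the same shape returns the same `Rc<Plan>` and leaves `heuristic_runs()` unchanged after the first call; a new `M` (for example a different prefill length) triggers exactly one more build. The shapes you pass are up to the caller; none are prescribed here.
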
## Reducing launches: one strided-batched matmul

The per-expert loop still issued one `cublasLtMatmul` per expert: ~768 tiny launches per decoded token (16 local experts × 2 GEMMs × 24 layers). Collapsing each MoE GEMM into a single **strided-batched** cuBLASLt FP8 matmul (BATCH_COUNT + strided-batch offsets) drops that to ~48 (2 GEMMs × 24 layers), with a fused post-scale kernel applying both scales. This required moving the per-expert weight scale out of the GEMM epilogue (a single strided call can't carry a per-batch scalar) into the post-scale kernel, which is precision-equivalent, as noted above.

| (gpt-oss-20b, TP=2) | per-expert FP8 | batched FP8 | BF16 |
|---|---|---|---|
| Decode TPOT | 17.9 ms | **13.8 ms** | 18.8 ms |
| Throughput | 55.8 tok/s | **72.3 tok/s** | 53.2 tok/s |

## Results: GSM8K (greedy, TP=2 on the same 2 GPUs)

The 200-problem run uses the per-expert plan-cache fix; the 100-problem run uses the strided-batched version. BF16 is the unchanged baseline in both, so each BF16 cell lists the n=200 / n=100 values. Harness: `tools/fp8_compare.py` starts a warm `xserv-server` per model and streams GSM8K through `/v1/chat/completions`; TTFT = time to first token, TPOT = mean inter-token latency, per request.

| metric | FP8 per-expert (n=200) | FP8 batched (n=100) | BF16 (n=200 / n=100) |
|---|---|---|---|
| GSM8K accuracy | 93.0 % | 91.0 % | 90.5 / 90.0 % |
| TTFT median | 67.4 ms | 65.0 ms | 68.8 / 69.5 ms |
| TPOT median | 17.45 ms | **13.08 ms** | 18.26 / 18.39 ms |
| TPOT p90 | 17.65 ms | **13.28 ms** | 18.38 / 18.52 ms |
| Throughput | 57.3 tok/s | **76.4 tok/s** | 54.8 / 54.4 tok/s |
| Decode speedup vs BF16 | 1.05× | **1.41×** | 1.00× |

- **Accuracy: unchanged.** FP8 is nominally +1.0 to +2.5 pts above BF16 in the matched runs, but at n=100–200 the standard error is ~2–3 pts, so they are statistically indistinguishable. The takeaway is that neither FP8 quantization nor the strided-batched rounding measurably degrades accuracy.
- **Decode: FP8 1.41× faster** once batched (TPOT 13.08 vs 18.39 ms), with a tight p90.
  The per-expert version was only ~1.05×: the ~768 tiny M=1 launches per token dominated, and batching them into ~48 unlocked most of the FP8 expert-weight-bandwidth saving.
- **Prefill (TTFT): comparable.** A multi-length sweep (113 / 561 / 1681 tokens) gave FP8 480 / 362 / 2451 ms vs BF16 558 / 282 / 2287 ms. The results are non-monotonic, i.e. dominated by fixed overhead (cuBLAS lazy init + FP8's one-time per-shape heuristic), not prefill compute, at these lengths.

## Single-GPU (TP=1)

FP8 runs gpt-oss-20b on **one** 5090 (`bench-gpt-oss --tp 1`, GPU6): TTFT 538 ms, TPOT 29.0 ms, 34.5 tok/s. BF16 cannot (39 GB > 32 GB). Fitting a model that otherwise needs two GPUs onto one is the largest practical win.

## Follow-ups (not done)

- Per-channel (per-output-row) weight scales for better accuracy headroom than per-tensor.
- Warm common prefill shapes at load to hide the first-request heuristic stall.
- Sparse (top-k only) MoE compute instead of dense. Currently every token runs all experts, so only ~top_k/num_experts of the FP8 GEMM work is used.
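
To make the first follow-up concrete, here is a small sketch of how per-output-row scales would differ from the current single scalar per expert tensor. It is an illustration only: the function names are hypothetical, weights are assumed to be a row-major `[N, K]` slice of `f32` (the real weights are BF16), and only the scales are computed, with no FP8 rounding. Both variants use the same `absmax/448` rule described in the Scheme section.

```rust
/// Largest finite magnitude representable in FP8 E4M3.
const FP8_E4M3_MAX: f32 = 448.0;

/// Maximum absolute value of a slice (0.0 for an empty slice).
fn absmax(values: &[f32]) -> f32 {
    values.iter().fold(0.0_f32, |acc, v| acc.max(v.abs()))
}

/// Current scheme: one scalar scale for the whole expert weight tensor.
fn per_tensor_scale(weights: &[f32]) -> f32 {
    absmax(weights) / FP8_E4M3_MAX
}

/// Follow-up idea: one scale per output row of a row-major [n, k] matrix.
/// Rows with small magnitudes get a smaller scale and therefore use more
/// of the FP8 range. Note: an all-zero row yields a scale of 0.0, which a
/// real quantizer would have to guard against before dividing.
fn per_row_scales(weights: &[f32], k: usize) -> Vec<f32> {
    assert!(k > 0 && weights.len() % k == 0, "weights must be [n, k]");
    weights
        .chunks(k)
        .map(|row| absmax(row) / FP8_E4M3_MAX)
        .collect()
}
```

For a matrix whose rows differ widely in magnitude, `per_tensor_scale` is set by the single largest element, while `per_row_scales` returns a much smaller scale for the small-magnitude rows. The cost is that the post-scale kernel would need a per-row weight scale rather than one `b_scale[expert]` scalar.
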