xserv/docs/benchmarks/mxfp4-and-llama-decode.md
Gahow Wang d33220498a quantization: MXFP4 W4A16 expert weights (memory-optimization foundation)
Weight-only 4-bit for the gpt-oss MoE experts: weights stored MXFP4 (E2M1 +
per-32-element UE8M0 block scale, tools/quantize_mxfp4.py), a fused kernel reads
the 4-bit weights and dequantizes on-chip to BF16. Decode (M=1) uses a fused
dequant-GEMV (batched_gemv_mxfp4) with shared-memory activation tiling; prefill
(M>1) dequantizes to BF16 then reuses the BF16 batched GEMM. MXFP4 is detected
by the scale tensor's rank (3-D [E,N,K/32]) vs FP8's 1-D [E].

Verified on dash5 (gpt-oss-20b, TP=2, 5090): byte-identical greedy tokens to
FP8/BF16, smallest footprint (13 GB vs 22 GB FP8, 39 GB BF16) — fits one 32 GB
5090 with room for KV cache.

NOT a decode speedup: the hand-written W4A16 GEMV (no tensor cores) is less
efficient than cuBLASLt's FP8 tensor-core GEMM, so even at half the weight bytes
decode is 17.0 ms vs FP8 13.5 ms (faster than BF16 18.8 ms); prefill regresses
(350 vs 134 ms, dequant fallback). Committed as a correct memory-optimization
foundation. Beating FP8 on speed needs FP4 tensor cores (W4A4, cuBLASLt
block-scaled MXFP4) or a Marlin-class kernel; see
docs/benchmarks/mxfp4-and-llama-decode.md.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-12 15:01:42 +08:00

# MXFP4 W4A16 + decode-speed vs llama.cpp (gpt-oss-20b, 2×RTX 5090)
## xserv vs llama.cpp — single-stream decode (TP=2, same GPUs)
`tools/xserv_vs_llama.py` streams identical prompts through each server's
OpenAI endpoint (counting llama's `reasoning_content` as real decode tokens).
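
As a rough illustration of how TTFT and TPOT can be derived from such a stream, the sketch below works on already-collected `(arrival_time, delta)` pairs. It is not the code in `tools/xserv_vs_llama.py`: the function name `stream_metrics`, the event layout, and the simplification that one non-empty delta equals one decode token are all assumptions made for this example.

```python
def stream_metrics(request_start, events):
    """Compute (ttft_ms, tpot_ms) from streamed chat-completion deltas.

    events: list of (arrival_time_seconds, delta_dict) in arrival order.
    A delta counts as a decode token if it carries text in either
    `content` or `reasoning_content` (assumption: one delta == one token).
    """
    token_times = [
        t for t, delta in events
        if delta.get("content") or delta.get("reasoning_content")
    ]
    if not token_times:
        raise ValueError("stream produced no tokens")
    # Time to first token: request send -> first text-bearing delta.
    ttft_ms = (token_times[0] - request_start) * 1000.0
    # Time per output token: mean gap between consecutive tokens.
    if len(token_times) > 1:
        span = token_times[-1] - token_times[0]
        tpot_ms = span / (len(token_times) - 1) * 1000.0
    else:
        tpot_ms = None
    return ttft_ms, tpot_ms


# Hypothetical stream: a role-only delta, two reasoning tokens, one answer token.
example_events = [
    (0.25, {"role": "assistant"}),          # no text, ignored
    (0.5, {"reasoning_content": "Let"}),
    (0.75, {"reasoning_content": " me"}),
    (1.0, {"content": "Hi"}),
]
ttft_ms, tpot_ms = stream_metrics(0.0, example_events)
```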
| metric | xserv FP8 | llama MXFP4 |
|---|---|---|
| Decode TPOT (medium) | 13.1 ms | **6.6 ms** (2.0× faster) |
| Throughput | 76 tok/s | **151 tok/s** |
| TTFT (short/medium) | 35–50 ms | 60–63 ms |
| TTFT (long, 1.6k tok) | 94 ms | **35 ms** |

llama.cpp decodes ~2× faster. Prefill is mixed: xserv has the lower TTFT on short/medium prompts, while llama.cpp is ~2.7× faster on the long prompt.
## Why — decode is memory/comm-bound, not launch-bound
Traced + measured (not assumed):
- The 24-layer decode loop is already fully async (no per-layer syncs), so kernel
launches hide behind GPU work; a CUDA graph would buy ~0.5–1.5 ms, not 2×.
- **TP=2→TP=4 probe**: TPOT 13.5→10.2 ms (FP8) with the *same* launch count and
*more* NCCL — confirms the bottleneck is **expert HBM traffic + all-reduce**,
not launch overhead.
- Even FP8 TP=4 (10.2 ms) can't catch llama TP=2 (6.6 ms): the gap is
*algorithmic*. llama is **sparse (top-4 of 32 experts) + 4-bit (MXFP4)**;
xserv is **dense (all 16 local experts) + 8-bit (FP8)** → ~8× the expert bytes
per token. Dense also makes xserv's long-prefill TTFT worse.

The two levers that close it: **sparse top-k MoE** (≈4×, the bigger structural
change) and **4-bit weights** (≈2×).
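
The ~8× figure can be reproduced with back-of-envelope arithmetic. The sketch below uses only the counts quoted above (16 local experts at 8 bits vs top-4 experts at 4 bits); the helper name `expert_bits_per_token` is made up for this example, and the model ignores everything except expert weight reads. The second ratio adds the one UE8M0 scale byte per 32 weights, which makes MXFP4 effectively 4.25 bits per weight.

```python
def expert_bits_per_token(experts_read, bits_per_weight):
    """Relative expert-weight traffic per token, in units of one expert's
    parameter count (the per-expert size cancels out of the ratios)."""
    return experts_read * bits_per_weight


# xserv today: dense over all 16 local experts, FP8 weights.
xserv_fp8 = expert_bits_per_token(16, 8)

# llama.cpp: sparse top-4, MXFP4 weights (nominal 4 bits).
llama_nominal = expert_bits_per_token(4, 4)

# Same, counting one 8-bit block scale per 32 weights: 4 + 8/32 = 4.25 bits.
llama_with_scales = expert_bits_per_token(4, 4 + 8 / 32)

sparsity_lever = 16 / 4                              # dense -> top-4
bitwidth_lever = 8 / 4                               # FP8 -> 4-bit
ratio_nominal = xserv_fp8 / llama_nominal            # both levers combined
ratio_with_scales = xserv_fp8 / llama_with_scales    # slightly under 8x

print(f"sparsity lever: {sparsity_lever:.1f}x, bit-width lever: {bitwidth_lever:.1f}x")
print(f"combined: {ratio_nominal:.1f}x nominal, {ratio_with_scales:.2f}x with block scales")
```

The two levers multiply, which is why neither alone closes the gap.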
## MXFP4 W4A16 (this change) — correct, smallest, not yet faster than FP8
Weight-only 4-bit: expert weights are MXFP4 (E2M1 + per-32 UE8M0 scale,
`tools/quantize_mxfp4.py`); a fused kernel reads the 4-bit weights and
dequantizes on-chip to BF16. Decode uses `batched_gemv_mxfp4`; prefill (M>1)
dequantizes to BF16 then reuses the BF16 batched GEMM.
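
For readers unfamiliar with the format, here is a minimal reference dequantizer for one MXFP4 block, written in plain Python rather than as a kernel. The E2M1 value table and the UE8M0 scale (a power of two with bias 127, `0xFF` reserved for NaN) follow the OCP Microscaling definitions; the packing order (low nibble first) and the function names are assumptions for this sketch and may differ from what `tools/quantize_mxfp4.py` and `batched_gemv_mxfp4` actually do.

```python
# Magnitudes of the 8 E2M1 codes (1 sign bit, 2 exponent bits, 1 mantissa bit).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble):
    """Decode one 4-bit E2M1 code to a float."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def decode_ue8m0(byte):
    """Decode an unsigned 8-bit exponent-only scale: 2 ** (byte - 127)."""
    if byte == 0xFF:
        return float("nan")  # reserved NaN encoding
    return 2.0 ** (byte - 127)


def dequantize_block(packed, scale_byte):
    """Expand 16 packed bytes plus one scale byte into 32 floats.

    Assumption: each byte holds two weights, low nibble first.
    """
    if len(packed) != 16:
        raise ValueError("an MXFP4 block is 32 weights = 16 packed bytes")
    scale = decode_ue8m0(scale_byte)
    out = []
    for byte in packed:
        out.append(decode_e2m1(byte & 0x0F) * scale)
        out.append(decode_e2m1(byte >> 4) * scale)
    return out


# One block: first byte 0x21 (codes 1 and 2), second byte 0xF7 (codes 7 and 15),
# remaining bytes zero, with scale byte 128 (a factor of 2).
block = dequantize_block(bytes([0x21, 0xF7] + [0] * 14), 128)
```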
| | MXFP4 W4A16 | FP8 W8A8 | BF16 |
|---|---|---|---|
| Model size | **13 GB** | 22 GB | 39 GB |
| Greedy tokens | identical | identical | baseline |
| Decode TPOT (TP=2) | 17.0 ms | **13.5 ms** | 18.8 ms |
| Decode TPOT (TP=4) | 11.8 ms | **10.2 ms** | — |
| Prefill TTFT | 350 ms | **134 ms** | 135 ms |
- **Correct** (byte-identical greedy tokens to FP8/BF16) and **smallest
footprint** — fits one 32 GB 5090 with ample room for KV cache.
- **Not faster than FP8**: the hand-written W4A16 dequant-GEMV (no tensor cores)
is less efficient than cuBLASLt's FP8 tensor-core GEMM, so even reading half
the bytes it stays ~2–3.5 ms behind FP8 at every TP. The TP=4 scaling
(17→11.8) shows it *is* partly memory-bound; a fixed per-GEMM inefficiency
dominates. Vectorized loads, hoisted scale, warp reduction, and shared-memory
activation tiling did not change it.
- **Prefill regresses** (350 vs 134 ms) — the dequant-to-BF16 fallback.

Committed as a **memory-optimization foundation**, not a decode speedup.
## To make 4-bit actually win
- **FP4 tensor cores (W4A4)** — cuBLASLt block-scaled MXFP4 GEMM
(`CUDA_R_4F_E2M1` + `CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0`, available on
sm_120). Tensor-core throughput *at* 4-bit would beat FP8. Risk: the scale
swizzle layout.
- A **Marlin-class W4A16 kernel** (register-blocked, async-copy pipelined).
- **Sparse top-k MoE** for the larger, llama-matching win.
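
To make the last item concrete, the following is a minimal sketch of top-k expert routing, not xserv's or llama.cpp's implementation. The function names are hypothetical, and normalizing with a softmax over only the selected logits is an assumption about the routing rule. The point it illustrates is that only `k` experts are evaluated per token, so only their weights need to be read.

```python
import math


def route_top_k(router_logits, k=4):
    """Return [(expert_index, weight)] for the k largest router logits.

    Assumption: weights are a softmax over the selected logits only.
    """
    top = sorted(range(len(router_logits)),
                 key=lambda i: router_logits[i], reverse=True)[:k]
    peak = max(router_logits[i] for i in top)  # subtract for numerical stability
    exps = [math.exp(router_logits[i] - peak) for i in top]
    total = sum(exps)
    return [(i, e / total) for i, e in zip(top, exps)]


def sparse_moe(x, router_logits, experts, k=4):
    """Weighted sum over the routed experts; the others are never called."""
    return sum(w * experts[i](x) for i, w in route_top_k(router_logits, k))


# Toy setup: 32 scalar "experts" that record when they are invoked.
calls = []


def make_expert(i):
    def expert(x):
        calls.append(i)
        return (i + 1) * x
    return expert


experts = [make_expert(i) for i in range(32)]
logits = [0.0] * 32
for i in (3, 7, 11, 19):
    logits[i] = 5.0  # these four experts win the routing

y = sparse_moe(2.0, logits, experts, k=4)
print(f"experts evaluated: {sorted(calls)} of {len(experts)}")
```

In the toy run only 4 of the 32 experts are invoked, which is the ≈4× reduction relative to evaluating all 16 local experts.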

FP8 (the plan-cache fix + strided-batched optimization, 1.41× over BF16) remains
xserv's best-performing quantization today.