Weight-only 4-bit for the gpt-oss MoE experts: weights are stored as MXFP4 (E2M1 + per-32-element UE8M0 block scale, tools/quantize_mxfp4.py), and a fused kernel reads the 4-bit weights and dequantizes on-chip to BF16. Decode (M=1) uses a fused dequant-GEMV (batched_gemv_mxfp4) with shared-memory activation tiling; prefill (M>1) dequantizes to BF16 then reuses the BF16 batched GEMM. MXFP4 is detected by the scale tensor's rank (3-D [E,N,K/32]) vs FP8's 1-D [E].

Verified on dash5 (gpt-oss-20b, TP=2, 5090): byte-identical greedy tokens to FP8/BF16 and the smallest footprint (13 GB vs 22 GB FP8, 39 GB BF16), which fits one 32 GB 5090 with room for KV cache.

NOT a decode speedup: the hand-written W4A16 GEMV (no tensor cores) is less efficient than cuBLASLt's FP8 tensor-core GEMM, so even at half the weight bytes decode is 17.0 ms vs FP8 13.5 ms (faster than BF16 18.8 ms); prefill regresses (350 vs 134 ms, dequant fallback).

Committed as a correct memory-optimization foundation. Beating FP8 on speed needs FP4 tensor cores (W4A4, cuBLASLt block-scaled MXFP4) or a Marlin-class kernel; see docs/benchmarks/mxfp4-and-llama-decode.md.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
# MXFP4 W4A16 + decode-speed vs llama.cpp (gpt-oss-20b, 2×RTX 5090)

## xserv vs llama.cpp: single-stream decode (TP=2, same GPUs)

`tools/xserv_vs_llama.py` streams identical prompts through each server's
OpenAI endpoint (counting llama's `reasoning_content` as real decode tokens).

| metric | xserv FP8 | llama MXFP4 |
|---|---|---|
| Decode TPOT (medium) | 13.1 ms | **6.6 ms** (2.0× faster) |
| Throughput | 76 tok/s | **151 tok/s** |
| TTFT (short/medium) | 35–50 ms | 60–63 ms |
| TTFT (long, 1.6k tok) | 94 ms | **35 ms** |

llama.cpp decodes ~2× faster. Prefill is mixed: xserv has the lower TTFT on
short/medium prompts, llama.cpp on the long (1.6k-token) prompt.

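The throughput and speedup rows follow directly from the TPOT row. A minimal sketch of that conversion, assuming single-stream decode so that tokens per second is simply the reciprocal of TPOT (the helper name `tokens_per_second` is illustrative, not part of `tools/xserv_vs_llama.py`):

```python
def tokens_per_second(tpot_ms: float) -> float:
    """Single-stream decode throughput implied by a time-per-output-token."""
    return 1000.0 / tpot_ms


# TPOT values copied from the table above.
xserv_fp8_tpot_ms = 13.1
llama_mxfp4_tpot_ms = 6.6

xserv_tps = tokens_per_second(xserv_fp8_tpot_ms)
llama_tps = tokens_per_second(llama_mxfp4_tpot_ms)
speedup = xserv_fp8_tpot_ms / llama_mxfp4_tpot_ms

print(f"xserv FP8:   {xserv_tps:.1f} tok/s")   # 76.3
print(f"llama MXFP4: {llama_tps:.1f} tok/s")   # 151.5
print(f"speedup:     {speedup:.2f}x")          # 1.98
```

The small differences from the table (76 and 151 tok/s) are rounding.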
## Why: decode is memory/comm-bound, not launch-bound

Traced + measured (not assumed):

- The 24-layer decode loop is already fully async (no per-layer syncs), so kernel
  launches hide behind GPU work; a CUDA graph would buy ~0.5–1.5 ms, not 2×.
- **TP=2→TP=4 probe**: TPOT 13.5→10.2 ms (FP8) with the *same* launch count and
  *more* NCCL, which confirms the bottleneck is **expert HBM traffic + all-reduce**,
  not launch overhead.
- Even FP8 TP=4 (10.2 ms) can't catch llama TP=2 (6.6 ms): the gap is
  *algorithmic*. llama is **sparse (top-4 of 32 experts) + 4-bit (MXFP4)**;
  xserv is **dense (all 16 local experts) + 8-bit (FP8)** → ~8× the expert bytes
  per token. Dense also makes xserv's long-prefill TTFT worse.

The two levers that close it: **sparse top-k MoE** (≈4×, the bigger structural
change) and **4-bit weights** (≈2×).

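The ~8× figure and its split into the two levers can be reproduced with back-of-the-envelope arithmetic. This is a rough sketch under simplifying assumptions, not a measurement: it counts only expert weights, treats every expert as the same size (so the per-expert parameter count cancels), and ignores the MXFP4 block scales. The helper `expert_weight_units` is a made-up name for illustration.

```python
def expert_weight_units(experts_read: int, bits_per_weight: float) -> float:
    """Relative expert-weight traffic per token: experts read times bits each."""
    return experts_read * bits_per_weight


# xserv: dense over all 16 local experts, 8-bit (FP8) weights.
xserv_units = expert_weight_units(experts_read=16, bits_per_weight=8)
# llama.cpp: top-4 routed experts, 4-bit (MXFP4) weights, scales ignored.
llama_units = expert_weight_units(experts_read=4, bits_per_weight=4)

sparsity_lever = 16 / 4   # dense local experts -> top-4
bitwidth_lever = 8 / 4    # FP8 -> 4-bit
total_ratio = xserv_units / llama_units

print(f"sparsity lever: {sparsity_lever:.0f}x")
print(f"bit-width lever: {bitwidth_lever:.0f}x")
print(f"combined expert-byte ratio: {total_ratio:.0f}x")
```

The two levers multiply (4 × 2 = 8), which is why sparse top-k routing is the larger structural win.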
## MXFP4 W4A16 (this change): correct, smallest, not yet faster than FP8

Weight-only 4-bit: expert weights are MXFP4 (E2M1 + per-32-element UE8M0 block
scale, `tools/quantize_mxfp4.py`); a fused kernel reads the 4-bit weights and
dequantizes on-chip to BF16. Decode (M=1) uses the fused dequant-GEMV
`batched_gemv_mxfp4`; prefill (M>1) dequantizes to BF16 then reuses the BF16
batched GEMM.

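To make the format concrete, here is a plain-Python reference sketch of MXFP4 dequantization and a one-row dequant-GEMV. It is not the `batched_gemv_mxfp4` kernel and may not match the byte layout written by `tools/quantize_mxfp4.py`: the low-nibble-first packing is an assumption, the function names are hypothetical, and results stay in Python floats rather than being rounded to BF16. The E2M1 value table and the UE8M0 exponent bias of 127 follow the OCP Microscaling format definitions.

```python
# Magnitudes of the 3 low bits of an E2M1 code (2 exponent bits, 1 mantissa bit).
E2M1_VALUES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
BLOCK = 32  # elements sharing one UE8M0 scale


def decode_e2m1(code: int) -> float:
    """Decode a 4-bit E2M1 code; bit 3 is the sign."""
    magnitude = E2M1_VALUES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def decode_ue8m0(scale_byte: int) -> float:
    """Decode an unsigned 8-bit exponent-only scale: 2**(e - 127); 0xFF is NaN."""
    if scale_byte == 0xFF:
        return float("nan")
    return 2.0 ** (scale_byte - 127)


def dequantize_row(packed: bytes, scales: bytes) -> list[float]:
    """Dequantize one weight row of K elements.

    packed: K/2 bytes, two codes per byte (ASSUMED: low nibble is the even element).
    scales: K/32 bytes, one UE8M0 scale per 32-element block.
    """
    row = []
    for i, byte in enumerate(packed):
        scale = decode_ue8m0(scales[(2 * i) // BLOCK])
        row.append(decode_e2m1(byte & 0x0F) * scale)
        row.append(decode_e2m1(byte >> 4) * scale)
    return row


def gemv_row(packed: bytes, scales: bytes, activations: list[float]) -> float:
    """One output element of a W4A16-style GEMV: dequantize, then dot product."""
    weights = dequantize_row(packed, scales)
    return sum(w * x for w, x in zip(weights, activations))


# One 32-element block: codes 1 (0.5) and 2 (1.0) alternating, scale 2**1.
packed_row = bytes([0x21] * 16)
row_scales = bytes([128])
result = gemv_row(packed_row, row_scales, [1.0] * 32)
print(result)  # 48.0
```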
| metric | MXFP4 W4A16 | FP8 W8A8 | BF16 |
|---|---|---|---|
| Model size | **13 GB** | 22 GB | 39 GB |
| Greedy tokens | identical | identical | baseline |
| Decode TPOT (TP=2) | 17.0 ms | **13.5 ms** | 18.8 ms |
| Decode TPOT (TP=4) | 11.8 ms | **10.2 ms** | n/a |
| Prefill TTFT | 350 ms | **134 ms** | 135 ms |

- **Correct** (byte-identical greedy tokens to FP8/BF16) and **smallest
  footprint**: fits one 32 GB 5090 with ample room for KV cache.
- **Not faster than FP8**: the hand-written W4A16 dequant-GEMV (no tensor cores)
  is less efficient than cuBLASLt's FP8 tensor-core GEMM, so even reading half
  the bytes it stays ~1.6–3.5 ms behind FP8 (3.5 ms at TP=2, 1.6 ms at TP=4).
  The TP=4 scaling (17.0→11.8 ms) shows it *is* partly memory-bound, but a fixed
  per-GEMM inefficiency dominates. Vectorized loads, hoisted scale, warp
  reduction, and shared-memory activation tiling did not change it.
- **Prefill regresses** (350 vs 134 ms) because of the dequant-to-BF16 fallback.

Committed as a **memory-optimization foundation**, not a decode speedup.

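The decode numbers in the table can be cross-checked with a few lines of arithmetic. The sketch below only restates the table: it computes the MXFP4-vs-FP8 gap at each TP degree, the TP=2→TP=4 scaling of each format, and the effective bits per MXFP4 weight once the 8-bit scale shared by each 32-element block is included. Variable names are illustrative.

```python
# Decode TPOT in ms, keyed by tensor-parallel degree (values from the table above).
tpot_ms = {
    "MXFP4 W4A16": {2: 17.0, 4: 11.8},
    "FP8 W8A8": {2: 13.5, 4: 10.2},
}

# Absolute gap between MXFP4 and FP8 at each TP degree.
gap_ms = {
    tp: tpot_ms["MXFP4 W4A16"][tp] - tpot_ms["FP8 W8A8"][tp] for tp in (2, 4)
}

# How much each format speeds up going from TP=2 to TP=4.
tp_scaling = {name: by_tp[2] / by_tp[4] for name, by_tp in tpot_ms.items()}

# 4-bit element plus one 8-bit UE8M0 scale amortized over 32 elements.
mxfp4_bits_per_weight = 4 + 8 / 32
bytes_vs_fp8 = mxfp4_bits_per_weight / 8

for tp, gap in gap_ms.items():
    print(f"TP={tp}: MXFP4 is {gap:.1f} ms behind FP8")
for name, scaling in tp_scaling.items():
    print(f"{name}: TP=2 -> TP=4 speedup {scaling:.2f}x")
print(f"MXFP4 effective bits/weight: {mxfp4_bits_per_weight}")
print(f"expert weight bytes vs FP8: {bytes_vs_fp8:.2f}x")
```

MXFP4 reads roughly half of FP8's expert bytes and scales somewhat better with TP, yet remains behind at both TP degrees, which is consistent with a per-GEMM overhead rather than pure bandwidth being the limit.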
## To make 4-bit actually win

- **FP4 tensor cores (W4A4)**: cuBLASLt block-scaled MXFP4 GEMM
  (`CUDA_R_4F_E2M1` + `CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0`, available on
  sm_120). Tensor-core throughput *at* 4-bit would beat FP8. Risk: the scale
  swizzle layout.
- A **Marlin-class W4A16 kernel** (register-blocked, async-copy pipelined).
- **Sparse top-k MoE** for the larger, llama-matching win.

FP8 (the plan-cache fix + strided-batched optimization, 1.41× over BF16) remains
xserv's best-performing quantization today.