dash5, gpt-oss-20b FP8, warm-server vs llama.cpp MXFP4 (6 reps): TP=2 TPOT 5.76-5.89 vs 7.42-8.45 ms (xserv 1.26-1.47x), TTFT 2.4x ahead short/medium; TP=1 5.78-5.95 vs 2.80-3.22 ms (gap 2.5x -> 2.0x, TTFT now ahead short/medium). GSM8K-50 through the graph path: 94%. Lesson recorded: graphs bought ~0.6 ms (launches were already hidden by async execution), the GPU argmax ~1 ms — measure, don't guess. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
112 lines
5.7 KiB
Markdown
# Sparse MoE decode — 1.8× over dense; beats llama.cpp at TP=2 (gpt-oss-20b, RTX 5090)

Phase 20 (`docs/20-sparse-moe.md`): decode computes only the routed top-4
experts via fused expert-indexed GEMVs (`csrc/moe/moe_sparse.cu`) instead of
the dense all-local-expert batched GEMM. FP8 weights run W8A16 (weights FP8,
activations BF16; decode is memory-bound, so tensor cores are irrelevant at
M=1); MXFP4 runs W4A16. The dense path is retained for prefill /
`num_tokens > 8` and via `XSERV_DENSE_MOE=1` for A/B.

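
To make the dense-versus-sparse distinction concrete, here is a minimal pure-Python sketch, not the kernel in `csrc/moe/moe_sparse.cu`. It assumes each expert is a single weight matrix and that routing weights are a softmax over the top-k router logits. The helper names and the sizes (32 experts, 16-dim input) are hypothetical and chosen for illustration only.

```python
import math
import random

def softmax(values):
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def gemv(matrix, vec):
    """Matrix-vector product; stands in for one expert's GEMV."""
    return [sum(w * v for w, v in zip(row, vec)) for row in matrix]

def route(x, gate_w, top_k):
    """Return (expert ids, routing weights) for the top_k router logits."""
    logits = gemv(gate_w, x)
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    return top, softmax([logits[i] for i in top])

def dense_moe(x, gate_w, experts, top_k):
    """Dense path: evaluate every expert, then zero-weight the unrouted ones."""
    top, probs = route(x, gate_w, top_k)
    weight = dict(zip(top, probs))
    out = [0.0] * len(experts[0])
    for e, expert in enumerate(experts):        # touches all expert weights
        y = gemv(expert, x)
        w = weight.get(e, 0.0)
        out = [o + w * yi for o, yi in zip(out, y)]
    return out

def sparse_moe(x, gate_w, experts, top_k):
    """Sparse path: index into the expert table and run only the routed GEMVs."""
    top, probs = route(x, gate_w, top_k)
    out = [0.0] * len(experts[0])
    for e, w in zip(top, probs):                # touches top_k expert weights only
        y = gemv(experts[e], x)
        out = [o + w * yi for o, yi in zip(out, y)]
    return out

random.seed(0)
NUM_EXPERTS, TOP_K, D_IN, D_OUT = 32, 4, 16, 8  # illustrative sizes (assumption)
x = [random.gauss(0, 1) for _ in range(D_IN)]
gate_w = [[random.gauss(0, 1) for _ in range(D_IN)] for _ in range(NUM_EXPERTS)]
experts = [[[random.gauss(0, 1) for _ in range(D_IN)] for _ in range(D_OUT)]
           for _ in range(NUM_EXPERTS)]

dense_out = dense_moe(x, gate_w, experts, TOP_K)
sparse_out = sparse_moe(x, gate_w, experts, TOP_K)
max_diff = max(abs(a - b) for a, b in zip(dense_out, sparse_out))
weight_fraction_read = TOP_K / NUM_EXPERTS
print(f"max |dense - sparse| = {max_diff:.2e}, "
      f"expert weights read = {weight_fraction_read:.3f}")
```

Both paths agree up to floating-point summation order; the difference is how many expert matrices are read per token, which is what matters when decode is memory-bound.
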
## In-process decode (bench-gpt-oss, greedy, 96 tokens)

| config | TPOT | tok/s |
|---|---|---|
| dense FP8 TP=2 (baseline) | 13.9 ms | 72 |
| **sparse FP8 TP=2** | **7.6 ms** | **132** |
| sparse MXFP4 TP=2 | 8.4 ms | 118 |
| sparse FP8 TP=1 (one 5090) | 7.8 ms | 128 |
| sparse MXFP4 TP=1 | 8.9 ms | 113 |

- Sparse FP8 = **1.8× over dense** (13.9 → 7.6 ms). Greedy output stays coherent.
- TP=1 ≈ TP=2: expert reads are now so small that the PCIe all-reduce eats the
  TP gain, so single-GPU serving becomes the attractive deployment.
- MXFP4 reads half the bytes of FP8 but stays slower: the 4-bit dequant GEMV
  has lower effective bandwidth (same fixed inefficiency seen in the dense
  MXFP4 experiments); at sparse sizes both are partly launch/latency-bound.

## Head-to-head vs llama.cpp (tools/xserv_vs_llama.py, warm servers, TP=2, GPUs 0-1, 6 reps, 256 tok)

| prompt | metric | xserv sparse FP8 | llama MXFP4 | xserv vs llama |
|---|---|---|---|---|
| short | TTFT | **35.3 ms** | 62.7 ms | 1.78× faster |
| short | TPOT | **7.32 ms** | 8.42 ms | 1.15× faster |
| medium | TTFT | **49.4 ms** | 65.0 ms | 1.32× faster |
| medium | TPOT | **7.19 ms** | 7.54 ms | 1.05× faster |
| medium | tok/s | **139.1** | 132.7 | 1.05× faster |
| long (1.6k) | TTFT | 94.1 ms | **44.7 ms** | 0.48× (llama wins) |
| long | TPOT | **7.25 ms** | 7.64 ms | 1.05× faster |

**Decode TPOT now beats llama.cpp at every prompt length** (was 2× slower:
13.1 vs 6.6 ms before sparse). Remaining loss: long-prompt TTFT. Prefill is
still the dense all-expert GEMM; sparse/grouped prefill is the next phase.

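
For readers checking the "xserv vs llama" column, the ratios can be reproduced from the TP=2 table as llama latency divided by xserv latency, so a value above 1 means xserv is faster. This is a small sketch with hypothetical helper names; the numbers are copied from the table, nothing is re-measured.

```python
# (prompt, metric) -> (xserv ms, llama ms), copied from the TP=2 table
RESULTS_MS = {
    ("short", "TTFT"): (35.3, 62.7),
    ("short", "TPOT"): (7.32, 8.42),
    ("medium", "TTFT"): (49.4, 65.0),
    ("medium", "TPOT"): (7.19, 7.54),
    ("long", "TTFT"): (94.1, 44.7),
    ("long", "TPOT"): (7.25, 7.64),
}

def speedup(xserv_ms, llama_ms):
    """Latency ratio llama/xserv: > 1 means xserv is faster."""
    return llama_ms / xserv_ms

def tokens_per_second(tpot_ms):
    """Decode throughput implied by a time-per-output-token in milliseconds."""
    return 1000.0 / tpot_ms

ratios = {key: round(speedup(*pair), 2) for key, pair in RESULTS_MS.items()}
for (prompt, metric), ratio in ratios.items():
    verdict = "xserv faster" if ratio > 1 else "llama faster"
    print(f"{prompt:6s} {metric}: {ratio:.2f}x ({verdict})")
```
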
**Post-review fixes** (same harness, rerun): removing three leftover
`cudaDeviceSynchronize` calls from the decode hot path and replacing the
CPU-tiled prefill bias-add (96 D2H/H2D round-trips per prefill) with a GPU
broadcast kernel improved both axes: TPOT 7.19–7.32 → **6.99–7.21 ms**, TTFT
short/medium/long 35/49/94 → **29/42/79 ms**. GSM8K-50: 94% (unchanged).

## TP=1 head-to-head (single 5090; server now routes gpt-oss tp=1 to the TP engine)

| prompt | metric | xserv sparse FP8 | llama MXFP4 |
|---|---|---|---|
| short | TTFT / TPOT | 42.8 ms / 7.00 ms | **34.5 ms / 3.22 ms** |
| medium | TTFT / TPOT | 57.1 ms / 7.19 ms | **37.3 ms / 2.89 ms** |
| long | TTFT / TPOT | 119.6 ms / 7.20 ms | **27.8 ms / 2.88 ms** |
| | tok/s | 139–143 | **311–347** |

**Single-GPU is llama.cpp's sweet spot and it wins 2.2–2.5× on TPOT.** Two
structural reasons, both instructive:

1. llama TP=2 (7.5–8.4 ms) is much WORSE than its TP=1 (2.9 ms): its PCIe
   cross-GPU split costs ~5 ms/token. xserv's NCCL all-reduce is cheap enough
   that TP=2 ≈ TP=1 (7.2 vs 7.0 ms), but xserv's single-GPU floor is high.
2. xserv TP=1 reads ~4.7 GB/token (experts FP8 2.4 GB + **non-expert weights
   still BF16** ~2.3 GB, half of that the 201k-vocab lm_head) ≈ 3.1 ms of pure
   HBM time; the other ~4 ms is launch overhead (~200 kernels/token, no CUDA
   graphs) + BF16 GEMV efficiency. llama reads ~1.3 GB (everything MXFP4) and
   replays the whole token as one CUDA graph.

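
The byte accounting in point 2 can be written out as a short calculation. This is a sketch under stated assumptions: the effective bandwidth is simply the one implied by the text's "4.7 GB ≈ 3.1 ms" rather than a measured figure, and the exact vocabulary size (201,088) and hidden size (2880) used for the lm_head check are assumed values that the text does not state.

```python
GB = 1e9

def hbm_floor_ms(bytes_per_token, bandwidth_bytes_per_s):
    """Lower bound on per-token time if weight reads were the only cost."""
    return bytes_per_token / bandwidth_bytes_per_s * 1e3

# Figures quoted in the text
EXPERT_FP8_GB = 2.4        # routed experts, FP8
NON_EXPERT_BF16_GB = 2.3   # attention, lm_head, etc., still BF16
MEASURED_TPOT_MS = 7.0     # xserv TP=1 decode
QUOTED_HBM_MS = 3.1        # "pure HBM time" from the text

total_gb = EXPERT_FP8_GB + NON_EXPERT_BF16_GB

# Assumption: effective bandwidth implied by "4.7 GB ~ 3.1 ms" (about 1.5 TB/s)
implied_bw = total_gb * GB / (QUOTED_HBM_MS / 1e3)

floor_ms = hbm_floor_ms(total_gb * GB, implied_bw)
non_hbm_ms = MEASURED_TPOT_MS - floor_ms   # launch overhead + GEMV inefficiency

# Sanity check on "half of that the 201k-vocab lm_head"
VOCAB = 201_088   # assumption: exact value behind "201k"
HIDDEN = 2880     # assumption: model hidden size, not stated in the text
lm_head_gb = VOCAB * HIDDEN * 2 / GB       # BF16 = 2 bytes per weight

print(f"bytes/token = {total_gb:.1f} GB, implied BW = {implied_bw / 1e12:.2f} TB/s")
print(f"HBM floor = {floor_ms:.1f} ms, remaining = {non_hbm_ms:.1f} ms")
print(f"lm_head = {lm_head_gb:.2f} GB")
```
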
## Correctness

- Greedy generations coherent across prompts (FP8/MXFP4, TP=1/2).
- Sparse FP8 is W8A16 vs dense W8A8: activations are no longer quantized, so
  tokens are not expected to be byte-identical to dense; quality is checked by
  GSM8K instead.
- **GSM8K-100 (greedy, TP=2, `tools/eval_gsm8k_fast.py`): 96/100 = 96.0%** vs
  dense FP8 91.0% / BF16 90.0%: no regression (within greedy-nondeterminism
  noise; W8A16 removes activation-quantization error so ≥ dense is expected).
  Avg 1.3 s/problem also reflects the decode speedup.

## Phase 21 update: decode CUDA graph + GPU argmax (docs/21-cuda-graph-decode.md)

The whole batch=1 decode step now replays as one CUDA graph, and greedy
sampling uses the GPU argmax kernel (4-byte D2H instead of a 402 KB logits
copy + 201k-element host scan). In-process A/B: graph −0.6 ms, GPU argmax
−1.0 ms. Warm-server head-to-head (same harness/GPUs, 6 reps):

| metric | xserv FP8 (graph) | llama MXFP4 | winner |
|---|---|---|---|
| TP=2 TPOT | **5.76–5.89 ms** (170–174 tok/s) | 7.42–8.45 ms | **xserv 1.26–1.47×** |
| TP=2 TTFT s/m/l | **25 / 28** / 51 ms | 63 / 66 / **45** ms | xserv 2.4× s/m; long: llama ahead (45 vs 51 ms) |
| TP=1 TPOT | 5.78–5.95 ms | **2.80–3.22 ms** | llama 2.0× (was 2.5×) |
| TP=1 TTFT s/m | **32 / 35 ms** | 34 / 36 ms | xserv slightly ahead |

GSM8K-50 through the graph path: 47/50 = 94% (unchanged). Note: GPU argmax
breaks exact-tie logits differently than the host scan, so greedy trajectories
can legitimately diverge at a tie token.

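
The tie-breaking note can be illustrated with a small sketch. It assumes the host scan keeps the first maximum (strict `>`), and it models a parallel argmax as a pairwise tree reduction. The tie rule used in the tree, preferring the right-hand element, is a hypothetical choice for illustration; the actual kernel's rule is not documented here.

```python
def host_scan_argmax(logits):
    """Sequential scan with strict '>': the first maximum wins on ties."""
    best = 0
    for i in range(1, len(logits)):
        if logits[i] > logits[best]:
            best = i
    return best

def tree_argmax(logits, prefer_right_on_tie=True):
    """Pairwise (tree) reduction over (index, value) pairs.

    The tie rule is a hypothetical choice; a different rule or a different
    reduction order gives a different winner when logits tie exactly.
    """
    items = list(enumerate(logits))
    while len(items) > 1:
        reduced = []
        for j in range(0, len(items) - 1, 2):
            (ia, va), (ib, vb) = items[j], items[j + 1]
            if vb > va or (vb == va and prefer_right_on_tie):
                reduced.append((ib, vb))
            else:
                reduced.append((ia, va))
        if len(items) % 2:              # odd element carries over unchanged
            reduced.append(items[-1])
        items = reduced
    return items[0][0]

tied_logits = [0.5, 2.0, -1.0, 2.0, 0.1]    # indices 1 and 3 tie exactly
host_pick = host_scan_argmax(tied_logits)
tree_pick = tree_argmax(tied_logits)
print(f"host scan picks {host_pick}, tree reduction picks {tree_pick}")
```

Both functions return a valid argmax; they only disagree when the maximum is not unique, which is why a divergent greedy trajectory at a tie token is not a correctness bug.
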
## Remaining gaps / next levers (to catch llama TP=1 at 2.8 ms)

Per-token fixed overhead is now mostly gone; the residual ~5.8 ms is
dominated by HBM bytes and kernel efficiency. In impact order:

1. **Quantize non-expert weights** (~1.5 ms): attn qkv/o + the 1.16 GB BF16
   lm_head are read every token; quantize them to FP8/MXFP4, as llama
   quantizes everything.
2. **GEMV/attention bandwidth tuning**: effective BW of the hand-written GEMVs
   is well under peak; llama's 2.8 ms implies ~85%+ efficiency on ~1.3 GB.
3. **Sparse prefill** (permute tokens by expert + grouped GEMM): long-prompt
   TTFT 51–75 ms → llama's ~30 ms territory.
4. **W4A4 FP4 tensor cores / bandwidth-tuned MXFP4 GEMV**: make 4-bit experts
   actually beat FP8.
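
The "permute tokens by expert + grouped GEMM" idea in item 3 can be sketched in pure Python as follows. This is an illustration of the general technique under simplifying assumptions (one weight matrix per expert, random routing, hypothetical function names), not the planned xserv implementation.

```python
import random
from itertools import groupby

def matvec(matrix, vec):
    return [sum(w * v for w, v in zip(row, vec)) for row in matrix]

def dense_prefill(tokens, routes, experts):
    """Reference: every token runs through every expert; unrouted weight is 0."""
    out = []
    for x, token_route in zip(tokens, routes):
        weight = dict(token_route)
        acc = [0.0] * len(experts[0])
        for e, expert in enumerate(experts):
            y = matvec(expert, x)
            w = weight.get(e, 0.0)
            acc = [a + w * yi for a, yi in zip(acc, y)]
        out.append(acc)
    return out

def grouped_prefill(tokens, routes, experts):
    """Sort (expert, token) pairs by expert so each expert sees one contiguous group."""
    pairs = [(e, t, w) for t, token_route in enumerate(routes) for e, w in token_route]
    pairs.sort(key=lambda p: p[0])                   # the permutation step
    out = [[0.0] * len(experts[0]) for _ in tokens]
    group_sizes = {}
    for e, group in groupby(pairs, key=lambda p: p[0]):
        group = list(group)
        group_sizes[e] = len(group)
        # One grouped-GEMM problem: expert e applied to its group of tokens.
        for _, t, w in group:
            y = matvec(experts[e], tokens[t])
            out[t] = [o + w * yi for o, yi in zip(out[t], y)]   # scatter back
    return out, group_sizes

random.seed(1)
NUM_TOKENS, NUM_EXPERTS, TOP_K, D_IN, D_OUT = 6, 8, 2, 5, 3    # illustrative sizes
tokens = [[random.gauss(0, 1) for _ in range(D_IN)] for _ in range(NUM_TOKENS)]
experts = [[[random.gauss(0, 1) for _ in range(D_IN)] for _ in range(D_OUT)]
           for _ in range(NUM_EXPERTS)]
routes = []
for _ in range(NUM_TOKENS):
    ids = random.sample(range(NUM_EXPERTS), TOP_K)
    raw = [random.random() + 0.1 for _ in ids]
    routes.append([(e, r / sum(raw)) for e, r in zip(ids, raw)])

reference = dense_prefill(tokens, routes, experts)
grouped, group_sizes = grouped_prefill(tokens, routes, experts)
max_diff = max(abs(a - b) for ra, rb in zip(reference, grouped) for a, b in zip(ra, rb))
expert_calls = sum(group_sizes.values())
print(f"max diff = {max_diff:.2e}; expert evaluations: "
      f"{expert_calls} grouped vs {NUM_TOKENS * NUM_EXPERTS} dense")
```

The grouped path performs `NUM_TOKENS * TOP_K` expert evaluations instead of `NUM_TOKENS * NUM_EXPERTS`, and the sort is what makes each expert's tokens contiguous so a real implementation could hand them to one GEMM per expert.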