Files
xserv/docs/20-sparse-moe.md
Gahow Wang fb20178992 moe: sparse top-k decode — compute only routed experts (1.8x, beats llama TP=2)
Dense MoE replicated x across all 16 local experts and ran the full
batched GEMM, reading every expert's weights per token; the weighted
sum then discarded 12 of 16 results. Decode is memory-bound, so this
was ~8x wasted expert bytes — the entire decode gap vs llama.cpp.

New fused expert-indexed GEMVs (csrc/moe/moe_sparse.cu) read
topk_ids on-device (no host sync) and early-return block-uniformly
for experts other ranks own. FP8 runs W8A16 (activations stay BF16 —
tensor cores are irrelevant at M=1, and activation quantization error
disappears); MXFP4 runs W4A16. Per-expert bias + scale fused into the
GEMV epilogue; slot-indexed weighted sum skips (never multiplies)
unwritten non-local slots. Dense path retained for num_tokens > 8
(prefill) and via XSERV_DENSE_MOE=1 for A/B.

dash5 (RTX 5090), gpt-oss-20b FP8, TP=2: decode TPOT 13.9 -> 7.6 ms.
Warm-server vs llama.cpp MXFP4 TP=2: TPOT 7.19-7.32 vs 7.54-8.42 ms —
first config where xserv wins decode outright. GSM8K-100: 96% (dense
FP8: 91%). llama TP=1 (2.9 ms) remains ahead: next levers are decode
CUDA graphs, non-expert quantization, sparse prefill (docs/20).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 16:29:10 +08:00

161 lines
7.9 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Phase 20: Sparse MoE Decode — 只算被路由到的专家
> 目标:消除 dense MoE 的无效权重读取,decode TPOT 追上并超过 llama.cpp。
> 前置:Phase 19(gpt-oss MoE 正确性)、FP8 W8A8 / MXFP4 W4A16 量化
> (见 `docs/benchmarks/fp8-quantization.md`、`docs/benchmarks/mxfp4-and-llama-decode.md`)。
## 1. 现状:dense MoE 在浪费什么
gpt-oss-20b 是 32 专家 top-4 的 MoE:router 给每个 token 选 4 个专家,
理论上每 token 只需要读 4/32 = 12.5% 的专家权重。但 `moe_forward`
(`crates/xserv-model/src/gpt_oss.rs`)目前是 **dense** 实现:
```text
1. router GEMV [T, 2880] → [T, 32]
2. topk_softmax (GPU) → topk_ids [T,4], topk_weights [T,4]
3. moe_replicate x 复制 16 份 → [16, T, 2880] ← 浪费开始
4. batched GEMM gate_up 全部 16 个本地专家都算 ← 读 16 份权重
5. bias + GLU
6. batched GEMM down 全部 16 个本地专家都算 ← 读 16 份权重
7. bias
8. moe_weighted_sum 只挑出 top-4 加权求和,其余 12 个全部丢弃
9. all-reduce
```
为什么当初这么写:batched GEMM(cuBLAS strided-batched)要求规则的
`[E, T, K]` 形状;top-4 的专家编号在 **GPU** 上(`topk_ids`),host 不知道
该挑哪几个,挑了形状也不规则。dense 是"先把正确性做出来"的合理起点,
但每 token 把 16 个专家的权重从 HBM 全部读一遍。
### 字节账本(decode,每 token,TP=2 每卡 16 个本地专家)
每层每专家:gate_up `[2880, 5760]` + down `[2880, 2880]` ≈ 24.9 M 参数。
| 方案 | 每卡每 token 专家字节 | 相对量 |
|---|---|---|
| xserv dense FP8(现状) | 16 × 24.9 MB × 24 层 ≈ **9.6 GB** | 1× |
| xserv sparse FP8(本阶段) | ~2 × 24.9 MB × 24 层 ≈ **1.2 GB** | 1/8 |
| llama.cpp sparse MXFP4 | ~2 × 12.5 MB × 24 层 ≈ **0.6 GB** | 1/16 |
(top-4 均匀散落在 2 张卡上,期望每卡 2 个命中;严格说每层取的是
两卡命中数的 max,期望 ≈ 2.6,仍是 ~6-8× 的节省。)
实测旁证:FP8 dense TP=2 TPOT 13.1 ms,其中专家 GEMM ≈ 9.6 GB ÷ ~1 TB/s
≈ 9.5 ms,其余(attention、qkv/o、lm_head、48 次 PCIe all-reduce)≈ 3.5 ms。
**专家权重读取占 TPOT 的 ~3/4,这就是与 llama.cpp(6.6 ms)的全部差距。**
## 2. Roofline:M=1 时为什么"省字节 = 省时间"
decode 的 GEMV(M=1)每读 1 字节 FP8 权重只做 2 FLOP(乘加)。
RTX 5090:HBM ~1.8 TB/s,BF16 算力 ~210 TFLOPS —— 算强比(arithmetic
intensity)需要 ~100 FLOP/byte 才能喂饱算力,GEMV 只有 2。结论:
1. **decode 完全 memory-bound**,tensor core 帮不上忙 → 手写 W8A16 GEMV
(权重 FP8、激活保持 BF16)不会输给 cuBLASLt 的 W8A8 tensor-core GEMM,
还省掉激活量化 kernel,精度更好(激活不再有量化误差)。
2. 优化只有一个方向:**少读字节**。sparse(×8)与 4-bit(×2)正交,
可叠加。本阶段先做 sparse,FP8 与 MXFP4 两种权重格式都支持。
## 3. Sparse 设计:让 kernel 自己按 topk_ids 索引权重
关键观察:`topk_ids` 本来就在 GPU 上。不需要 host 知道选了谁 ——
**让 GEMV kernel 的每个 block 自己读 `topk_ids[token, slot]`,
直接寻址到对应专家的权重**,不命中本卡就整块退出。零 host 同步,
管线保持完全异步(这是之前排查过的:decode 循环无 per-layer sync)。
新数据流(`num_tokens ≤ 8` 时启用):
```text
x [T, 2880]
├─ router → topk_ids/weights [T, 4] (不变)
├─ sparse GEMV gate_up → [T, 4, 5760] bias 已融合,非本地 slot 不写
├─ GLU → [T*4, 2880]
├─ sparse GEMV down → [T, 4, 2880] bias 已融合,非本地 slot 不写
└─ weighted_sum_sparse → [T, 2880] 只累加本地 slot
all-reduce (不变)
```
`moe_replicate` 和独立的 bias kernel 在 sparse 路径下消失;FP8 路径还省掉
`quantize_bf16_to_fp8_rowwise`
### Kernel 设计(`csrc/moe/moe_sparse.cu`)
`moe_sparse_gemv_{fp8,mxfp4}_bf16_kernel`:
- **grid = (N/8, top_k, tokens)**,block = 8 warp × 32 lane。
每个 block 负责一个 (token, slot) 的 8 个输出列,**一个 warp 算一个输出**。
- block 先读 `eid = topk_ids[token*top_k + slot]`,折算 `lid = eid - expert_start`;
不在 `[0, local_experts)` 就整块 return。
- 命中的 block 把激活行(K=2880 个 BF16 → float)协作搬进 shared memory
(11.25 KB),`__syncthreads()`,然后每 warp 沿 K 维做点积:
每 lane 一次 `uint4` 读 16 字节权重(FP8 = 16 个权重,MXFP4 = 32 个 nibble),
warp 内 32 lane 连续 → 512B coalesced 事务。
- epilogue(lane 0):`y = acc * w_scale[lid] + bias[lid, n]` —— per-expert
scale 和 bias 都融合在这里,与 dense 路径的"GEMM → bias add → 路由加权"
语义逐位等价(HF 参考实现也是先加 bias 再乘路由权重)。
- gate_up 与 down 共用同一个 kernel,用 `x_per_slot` 区分激活寻址:
gate_up 时 4 个 slot 共享 `x[token]`;down 时各读自己的 `act[token*4+slot]`
### 两个容易写错的安全点
1. **early-return 必须 block-uniform。** Phase 19 的 GEMV 垃圾输出 bug
(commit `3b9e32e`)正是"部分线程在 `__syncthreads()` 之前 return"导致
读未初始化 shared memory。这里的 return 发生在 smem 装载**之前**,且整个
block 基于同一个 `topk_ids` 值统一退出 —— 没有 divergence,合法且安全。
2. **weighted-sum 对非本地 slot 必须"跳过",不能"乘 0"。** 非本地 slot 的
GEMV 输出从未被写入(未初始化显存,可能是 NaN 位型),GLU 也会在上面算出
垃圾。`NaN × 0 = NaN`,所以求和 kernel 用 `if (local) sum += w*v` 跳过,
垃圾永远不进入数据流(dense 路径的 `moe_weighted_sum` 同理)。
## 4. 为什么 prefill 保持 dense
dense batched GEMM 把 16 份权重读**一次**,服务全部 M 个 token;
sparse GEMV 是**每 token** 重读自己的 ~2 份。字节交叉点:
```text
sparse 读 M × 2 份 vs dense 读 16 份 → M ≈ 8 (TP=2)
```
M > 8 后 dense 更省(且 GEMM 是 compute-bound,tensor core 开始有用)。
所以 sparse 只在 `num_tokens ≤ 8` 启用 —— 覆盖 decode(连续批合并的
多请求 decode 也是小 M)和极短的 re-prefill。真正的 sparse prefill
(按专家对 token 做 permute/gather 的 grouped GEMM,vLLM 的做法)是
后续阶段,主要收益在长 prompt TTFT。
## 5. 实测结果(2026-06-12,完整数据见 `docs/benchmarks/sparse-moe.md`)
In-process decode(bench-gpt-oss,greedy 96 tok):
| | TPOT | tok/s |
|---|---|---|
| dense FP8 TP=2(基线) | 13.9 ms | 72 |
| **sparse FP8 TP=2** | **7.6 ms(1.8×)** | **132** |
| sparse MXFP4 TP=2 | 8.4 ms | 118 |
| sparse FP8 TP=1(单卡) | 7.8 ms | 128 |
Warm-server 对打 llama.cpp(`tools/xserv_vs_llama.py`):
- **TP=2 vs TP=2:xserv 首次全面反超** —— TPOT 7.19-7.32 ms vs llama
7.54-8.42 ms;短/中 prompt TTFT 也领先(35/49 vs 63/65 ms)。
- **TP=1 vs TP=1:llama 大胜**(2.88-3.22 ms vs 7.0-7.2 ms,347 vs 140
tok/s)。单卡才是 llama 的最优配置:它的跨卡 split 在 PCIe 上每 token
损失 ~5 ms,而单卡时它"全模型 4-bit + CUDA graph 整 token 回放"的
优势全部兑现。xserv 的残余 ~7 ms ≈ ~3 ms HBM(其中非专家权重还是
BF16,含 1.16 GB 的 lm_head)+ ~4 ms 启动开销(~200 个 kernel
launch/token,无 CUDA graph)。
- **正确性:GSM8K-100 = 96%**(dense FP8 91% / BF16 90%,greedy 噪声内,
无回归)。
教训:之前"CUDA graph ≈ 无用(~0.5-1.5ms)"的结论是相对 13 ms 的
dense TPOT 而言;专家成本砍掉后,launch 开销变成了最大的单项。
## 6. 下一阶段(按收益排序)
1. **decode CUDA graph**(~2-4 ms):当前最大单项。
2. **非专家权重量化**(~1-1.5 ms):qkv/o + lm_head 仍是 BF16,每 token
白读 ~2.3 GB;llama 是全模型 4-bit。
3. **sparse prefill**(grouped GEMM):长 prompt TTFT 94-120 ms → llama
的 ~30 ms 量级。
4. **W4A4 FP4 tensor core / 带宽调优的 MXFP4 GEMV**:让 4-bit 专家真正
快过 FP8(目前 8.4 vs 7.6 ms,GEMV 效率抵消了字节优势)。