docs: Phase 21 — decode CUDA graph + GPU argmax results

dash5, gpt-oss-20b FP8, warm-server vs llama.cpp MXFP4 (6 reps):
TP=2 TPOT 5.76-5.89 vs 7.42-8.45 ms (xserv 1.26-1.47x), TTFT 2.4x
ahead short/medium; TP=1 5.78-5.95 vs 2.80-3.22 ms (gap 2.5x -> 2.0x,
TTFT now ahead short/medium). GSM8K-50 through the graph path: 94%.
Lesson recorded: graphs bought ~0.6 ms (launches were already hidden
by async execution), the GPU argmax ~1 ms — measure, don't guess.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
2026-06-12 20:12:37 +08:00
parent 8414f8d1e6
commit 013465fc06
4 changed files with 149 additions and 20 deletions

View File

@@ -0,0 +1,111 @@
# Phase 21: gpt-oss decode CUDA Graph + GPU argmax
> 目标:消除 decode 的每 token 固定开销。Phase 20 之后 TPOT ~7ms,其中
> GPU 实际计算只占一部分,剩下是 ~200 个 kernel launch 和 per-token 的
> host 工作。本阶段把**整个 decode step 捕获成一个 CUDA graph**,每 token
> 一次 `cudaGraphLaunch` 回放;顺带把 greedy 采样换成 GPU argmax。
>
> 实现:`crates/xserv-model/src/gpt_oss_graph.rs`(~150 行)+ 三块基础设施。
## 1. CUDA Graph 是什么,为什么有约束
`cudaStreamBeginCapture` 之后,发到该 stream 的 kernel 不执行而是被**录制**;
`EndCapture + Instantiate` 得到可执行图;以后每步 `cudaGraphLaunch` 一次性
重放全部 ~200 个 kernel,host 端开销从 ~200 次 launch 降到 1 次。
代价是三条硬约束,每条都对应一个工程问题:
1. **地址稳定**:录制时烤进图里的全部指针,回放时必须仍然有效且指向正确数据;
2. **capture 期间禁止"不安全"调用**:`cudaMalloc`/同步 memcpy/`cudaDeviceSynchronize`
都会让 capture 报错(error 900);
3. **形状固定**:grid 尺寸被烤死,变 shape 就要重录。
## 2. 为什么 xserv 的 decode 本来就"差一点"就能整图捕获
逐项检查 decode step 的输入,发现绝大部分已经满足地址稳定:
| 每步会变的输入 | 地址 | 内容如何更新 |
|---|---|---|
| block table / context lens | PagedKVCache 的常驻 GPU 缓冲 ✓ | `decode_prepare` 在图外 H2D |
| KV 写入位置 | scatter kernel **从 GPU 上的 context_lens 读** ✓ | 同上 |
| attention 读取范围 | paged kernel 从同一缓冲读 ✓ | 同上 |
| MoE 专家选择 | sparse GEMV 从图内刚写的 `topk_ids` 读 ✓ | 数据依赖,天然支持 |
| token id / position | ✗ 原来是每步从 host slice 上传 | **本阶段改造点** |
也就是说,Phase 11(paged KV)和 Phase 20(sparse MoE)的"数据驱动"设计
无意中已经为 graph 化铺平了路 —— 唯二需要动的是 embedding 的 token id 和
RoPE 的 position:各加一个 device-buffer 变体(`embedding_device_ids` /
`rope_inplace_device_pos`),id/pos 存进两个常驻 4 字节缓冲,每步图外更新。
重构后的结构:
```text
forward_decode_paged = decode_prepare(host 簿记,图外)
+ upload ids/pos(图外)
+ decode_core(纯 GPU,可整段捕获)
+ advance_seq_len(host 簿记,图外)
```
## 3. 三个工程问题
### 3.1 null stream 不可捕获 → thread-local launch stream
全代码库的 kernel 都发射在 legacy null stream 上,而 capture 必须在显式
stream 上。解法:`xserv_cuda::stream` 加一个 **thread-local 当前 stream**
(默认 null,行为与从前逐字节一致),所有 kernel wrapper、cuBLAS 的
`cublasSetStream`、NCCL 的 collective 全部改读它。capture 代码用 RAII guard
(`push_stream`)把 capture stream 装进去,录完自动还原。
顺序正确性:显式 stream 以默认(blocking)方式创建,legacy stream 与其
双向隐式同步,所以图外的 H2D/采样 memcpy 与回放天然有序。
### 3.2 capture 期间禁止 cudaMalloc → "retained warmup" 二段式
中间张量来自 caching allocator;capture 中任何一次 pool miss 都会触发
`cudaMalloc` → error 900。第一版实现就栽在这里:**隔离机制自己制造了
pool miss**(capture 中释放的块被隔离,下一层同尺寸分配就找不到块了)。
解法是把同一个 step 先 eager 跑一遍、但**隔离打开**(`begin_retain`):
释放的块全部扣下不回池 → 跑完后池外恰好积累了"这一步需要的每一块";
把它们整批放回池,再开始 capture —— capture 重复完全相同的分配序列,
每次分配都命中池,一次 cudaMalloc 都不会发生。
(重复执行同一 step 是无害的:KV scatter 往同一个位置重写同样的值。)
### 3.3 回放引用的内存不能被别人拿走 → 隔离仓(quarantine)
capture 录下的中间缓冲在 host 侧早就 Drop 了,但图每次回放都会读写这些
地址。若它们回到分配池、被后续 prefill 拿走长期持有,就是双写损坏。
所以 capture 期间释放的块进入 `RetainedBlocks` 隔离仓,由 graph 对象持有,
graph 销毁时才归还 —— 这些内存在 graph 存活期内被锁定为它专用。
### 3.4 两个顺手的点
- **THREAD_LOCAL capture mode**:GLOBAL 模式下,任何线程的 cudaMalloc 都会
毒化 capture;TP 多 rank 线程并发 capture 必须用 THREAD_LOCAL。
- **NCCL 可以被捕获**:rank 内 `ncclAllReduce` 发在 capture stream 上即可,
TP=2 一次成功(各 rank 录各自的图,回放时 collective 自然配对)。
## 4. 意外的教训:launch 开销没有想象的大,argmax 才是大头
A/B 实测(in-process,FP8,96 tok):
| | TP=1 | TP=2 |
|---|---|---|
| eager + host argmax(Phase 20 末) | 7.5 ms | 7.6 ms |
| graph + host argmax | 6.9 ms | 6.9 ms |
| eager + **GPU argmax** | 6.5 ms | — |
| **graph + GPU argmax** | **5.9 ms** | **5.8 ms** |
- **graph 只省了 ~0.6ms**:decode 循环本来就是全异步的,launch 大部分被
GPU 执行掩盖,"~200 launch ≈ 4ms"的预估错了 —— 优化要测不要猜。
- **GPU argmax 省了 ~1ms**:greedy 采样原来每 token 把 [1, 201088] 的
logits(402KB)同步拷回 host、再扫描 201K 个 bf16。仓库里 Phase 15 就写好
的 argmax kernel(kernel 内归约 + 4 字节 D2H)一直没接到 `sample()` 上。
- 细节:GPU argmax 与 host `max_by` 对**完全相等**的 logits 平局取的索引
不同,greedy 轨迹会在某个平局 token 处分叉 —— 输出同样合法(GSM8K 验证)。
## 5. 结果与剩余瓶颈
`docs/benchmarks/sparse-moe.md` 的 Phase 21 小节(warm-server 对打 llama
的数字以那里为准)。剩余 TPOT 的构成:~3ms 是 HBM 字节(其中非专家权重
仍是 BF16,含 1.16GB 的 lm_head —— **Phase 22 量化它们**),其余是 GEMV
带宽效率与 attention。llama 单卡 2.9ms 的差距主要就在"全模型 4-bit"。