dash5, gpt-oss-20b FP8, warm-server vs llama.cpp MXFP4 (6 reps): TP=2 TPOT 5.76-5.89 vs 7.42-8.45 ms (xserv 1.26-1.47x), TTFT 2.4x ahead short/medium; TP=1 5.78-5.95 vs 2.80-3.22 ms (gap 2.5x -> 2.0x, TTFT now ahead short/medium). GSM8K-50 through the graph path: 94%. Lesson recorded: graphs bought ~0.6 ms (launches were already hidden by async execution), the GPU argmax ~1 ms — measure, don't guess. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
5.9 KiB
Phase 21: gpt-oss decode CUDA Graph + GPU argmax
目标:消除 decode 的每 token 固定开销。Phase 20 之后 TPOT ~7ms,其中 GPU 实际计算只占一部分,剩下是 ~200 个 kernel launch 和 per-token 的 host 工作。本阶段把整个 decode step 捕获成一个 CUDA graph,每 token 一次
cudaGraphLaunch回放;顺带把 greedy 采样换成 GPU argmax。实现:
crates/xserv-model/src/gpt_oss_graph.rs(~150 行)+ 三块基础设施。
1. CUDA Graph 是什么,为什么有约束
cudaStreamBeginCapture 之后,发到该 stream 的 kernel 不执行而是被录制;
EndCapture + Instantiate 得到可执行图;以后每步 cudaGraphLaunch 一次性
重放全部 ~200 个 kernel,host 端开销从 ~200 次 launch 降到 1 次。
代价是三条硬约束,每条都对应一个工程问题:
- 地址稳定:录制时烤进图里的全部指针,回放时必须仍然有效且指向正确数据;
- capture 期间禁止"不安全"调用:
cudaMalloc/同步 memcpy/cudaDeviceSynchronize都会让 capture 报错(error 900); - 形状固定:grid 尺寸被烤死,变 shape 就要重录。
2. 为什么 xserv 的 decode 本来就"差一点"就能整图捕获
逐项检查 decode step 的输入,发现绝大部分已经满足地址稳定:
| 每步会变的输入 | 地址 | 内容如何更新 |
|---|---|---|
| block table / context lens | PagedKVCache 的常驻 GPU 缓冲 ✓ | decode_prepare 在图外 H2D |
| KV 写入位置 | scatter kernel 从 GPU 上的 context_lens 读 ✓ | 同上 |
| attention 读取范围 | paged kernel 从同一缓冲读 ✓ | 同上 |
| MoE 专家选择 | sparse GEMV 从图内刚写的 topk_ids 读 ✓ |
数据依赖,天然支持 |
| token id / position | ✗ 原来是每步从 host slice 上传 | 本阶段改造点 |
也就是说,Phase 11(paged KV)和 Phase 20(sparse MoE)的"数据驱动"设计
无意中已经为 graph 化铺平了路 —— 唯二需要动的是 embedding 的 token id 和
RoPE 的 position:各加一个 device-buffer 变体(embedding_device_ids /
rope_inplace_device_pos),id/pos 存进两个常驻 4 字节缓冲,每步图外更新。
重构后的结构:
forward_decode_paged = decode_prepare(host 簿记,图外)
+ upload ids/pos(图外)
+ decode_core(纯 GPU,可整段捕获)
+ advance_seq_len(host 簿记,图外)
3. 三个工程问题
3.1 null stream 不可捕获 → thread-local launch stream
全代码库的 kernel 都发射在 legacy null stream 上,而 capture 必须在显式
stream 上。解法:xserv_cuda::stream 加一个 thread-local 当前 stream
(默认 null,行为与从前逐字节一致),所有 kernel wrapper、cuBLAS 的
cublasSetStream、NCCL 的 collective 全部改读它。capture 代码用 RAII guard
(push_stream)把 capture stream 装进去,录完自动还原。
顺序正确性:显式 stream 以默认(blocking)方式创建,legacy stream 与其
双向隐式同步,所以图外的 H2D/采样 memcpy 与回放天然有序。
3.2 capture 期间禁止 cudaMalloc → "retained warmup" 二段式
中间张量来自 caching allocator;capture 中任何一次 pool miss 都会触发
cudaMalloc → error 900。第一版实现就栽在这里:隔离机制自己制造了
pool miss(capture 中释放的块被隔离,下一层同尺寸分配就找不到块了)。
解法是把同一个 step 先 eager 跑一遍、但隔离打开(begin_retain):
释放的块全部扣下不回池 → 跑完后池外恰好积累了"这一步需要的每一块";
把它们整批放回池,再开始 capture —— capture 重复完全相同的分配序列,
每次分配都命中池,一次 cudaMalloc 都不会发生。
(重复执行同一 step 是无害的:KV scatter 往同一个位置重写同样的值。)
3.3 回放引用的内存不能被别人拿走 → 隔离仓(quarantine)
capture 录下的中间缓冲在 host 侧早就 Drop 了,但图每次回放都会读写这些
地址。若它们回到分配池、被后续 prefill 拿走长期持有,就是双写损坏。
所以 capture 期间释放的块进入 RetainedBlocks 隔离仓,由 graph 对象持有,
graph 销毁时才归还 —— 这些内存在 graph 存活期内被锁定为它专用。
3.4 两个顺手的点
- THREAD_LOCAL capture mode:GLOBAL 模式下,任何线程的 cudaMalloc 都会 毒化 capture;TP 多 rank 线程并发 capture 必须用 THREAD_LOCAL。
- NCCL 可以被捕获:rank 内
ncclAllReduce发在 capture stream 上即可, TP=2 一次成功(各 rank 录各自的图,回放时 collective 自然配对)。
4. 意外的教训:launch 开销没有想象的大,argmax 才是大头
A/B 实测(in-process,FP8,96 tok):
| TP=1 | TP=2 | |
|---|---|---|
| eager + host argmax(Phase 20 末) | 7.5 ms | 7.6 ms |
| graph + host argmax | 6.9 ms | 6.9 ms |
| eager + GPU argmax | 6.5 ms | — |
| graph + GPU argmax | 5.9 ms | 5.8 ms |
- graph 只省了 ~0.6ms:decode 循环本来就是全异步的,launch 大部分被 GPU 执行掩盖,"~200 launch ≈ 4ms"的预估错了 —— 优化要测不要猜。
- GPU argmax 省了 ~1ms:greedy 采样原来每 token 把 [1, 201088] 的
logits(402KB)同步拷回 host、再扫描 201K 个 bf16。仓库里 Phase 15 就写好
的 argmax kernel(kernel 内归约 + 4 字节 D2H)一直没接到
sample()上。 - 细节:GPU argmax 与 host
max_by对完全相等的 logits 平局取的索引 不同,greedy 轨迹会在某个平局 token 处分叉 —— 输出同样合法(GSM8K 验证)。
5. 结果与剩余瓶颈
见 docs/benchmarks/sparse-moe.md 的 Phase 21 小节(warm-server 对打 llama
的数字以那里为准)。剩余 TPOT 的构成:~3ms 是 HBM 字节(其中非专家权重
仍是 BF16,含 1.16GB 的 lm_head —— Phase 22 量化它们),其余是 GEMV
带宽效率与 attention。llama 单卡 2.9ms 的差距主要就在"全模型 4-bit"。