docs: M2c — device KV cache + the bottleneck-shift finding
Implementation log (docs/18) + Phase-3 row (evolution.md): cat_seq device cache, gates hold (token-identical), and the profile-first finding — ~10% single-seq decode but no GRPO-step change because the long pole shifted to the per-sample logp/PG forwards after M2b batching. Names ragged batched prefill as the next decode lever. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -105,6 +105,8 @@ Phase 1/2 把**预训练全栈**学完后,Phase 3 转向**后训练 infra**(
|
||||
|
||||
**M2b(批量 KV-cache 解码,已落地,补全 M2 引擎 + 修 rollout 长杆)**:M4 后补的 rollout 长杆修复——一个 prompt 的 **G 个样本同步解码**(每步一次 forward 跑整组 → G× 更少 kernel 启动)。一个新原语 `rope_pos`(逐 row 绝对位置 kernel,G 行共享一个解码位置;闸门 = `[0..n]` 逐位等于全 rope、统一 P 逐行等于 `rope_at(P)`,bit-identical)。引擎 `generate_cached_batch`:`BatchKVCache` 带 G 维,批量 `decode_step` 把 G 贯穿 embed/proj/QK-norm/`rope_pos`/cache;**M2a 两件零改动复用**——`decode_attention` 本就 batch-agnostic(bh=G·nh)、`repeat_kv(nh,batch=G)` 按组广播。闸门 = G 个贪心行逐字节等于单序列(`tests/decode_batch.rs`,8q/2kv 头练 repeat_kv 批量)。**吞吐**(v12, G6·B6, 接进 train_grpo):**~8.5s/step vs 单序列 ~14-16s/step ≈ 1.7×**(rollout-inclusive;未到满 G× 因 per_token_logp + PG 更新也占时间、M2a 主机往返还在);且**显存更稳**(一次批量 forward vs G 次分配撑碎 allocator 的 M4 OOM)。⇒ M2 引擎闭环(M2a 单序列 + M2b 批量),rollout 长杆从"OOM/无界"变成有界 ~1.7× 收益,device 端 cache 是点名的下一杠杆。
|
||||
|
||||
**M2c(device 端 KV cache,已落地,瓶颈转移的 profile-first 发现)**:K/V 留 device 为 `[bh,T,hd]`(每层 `Option<Tensor>`),每步用新 `cat_seq` kernel(沿 seq 拼接)append 一个 token——去掉 M2a/M2b 每层**主机往返** + `transpose_3d01`,单序列和批量都重构到它(比 host Vec+rebuild 干净)。闸门全保:`cat_seq`==host concat、decode_kv 单序列 + decode_batch 批量仍 **token-identical**、GQA 训练路径不受影响。**发现(measure-first 的点,不是加速故事)**:去掉主机往返让**纯单序列解码 +10%**(133→147 tok/s@128),但 **GRPO step 不动**(~8.5s/step)——因为 M2b 批量化后 rollout 已不是 step 瓶颈,**per-sample `per_token_logp` 捕获(2×/样本)+ PG 更新 forward/backward(全序列 `model.forward`)成了主导**。长杆从 rollout **转移**到训练侧 forward(同 T11/T17/M2a:profile 后再动手——你修的不是剩下的瓶颈)。device cache 仍是真实、闸门齐全的改进(更干净、少 PCIe、解码 +10%),但下一杠杆是 **per-sample forward 的 ragged 批量**而非 cache。M2 引擎现 = M2a(单序列)+ M2b(批量)+ M2c(device cache),全 token-identical-gated;后训练栈完整、瓶颈已测绘。
|
||||
|
||||
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md))
|
||||
|
||||
- **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。
|
||||
|
||||
Reference in New Issue
Block a user