docs: M2c — device KV cache + the bottleneck-shift finding
Implementation log (docs/18) + Phase-3 row (evolution.md): cat_seq device cache, gates hold (token-identical), and the profile-first finding: ~10% faster single-seq decode but no GRPO-step change, because the long pole shifted to the per-sample logp/PG forwards after M2b batching. Names ragged batched prefill as the next decode lever.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@@ -554,3 +554,25 @@ device-side cache (no host round-trip) is the remaining decode-engine optimizati
caching allocator (the M4 OOM). So M2b closes the decode-engine milestone (M2a single-seq + M2b
batched) and turns the rollout long-pole from "OOM/unbounded" into a bounded ~1.7× win (measured),
with the device-cache as the named next lever.

### M2c: device-side KV cache (landed; the bottleneck moved, a profile-first finding)
The named M2b follow-up: keep K/V on the GPU (`[bh,T,hd]`, an `Option<Tensor>` per layer) and
grow it by one token per step via a new `cat_seq` kernel (concat along the seq dim). This removes
the M2a/M2b per-layer **host round-trip** (`to_cpu`/`from_slice`/re-upload) *and* the
`transpose_3d01`. Both single-seq and batched decode were refactored onto it, which is cleaner
than the host `Vec` + rebuild.

**Gates hold:** `cat_seq == host concat`; `decode_kv` single-seq and `decode_batch` G-way are both
still **token-identical**; the GQA training path is unaffected.

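To make the `cat_seq == host concat` gate concrete, here is a minimal CPU-only sketch of concatenation along the seq dim for a row-major `[bh,T,hd]` buffer, checked against an index-by-index reference. It is an illustration under stated assumptions, not the repo's kernel: the real `cat_seq` runs on the device over `Tensor`s, and the names `host_concat_ref` and `append_token`, the flat `f32` buffers, and the `(buffer, T)` tuple standing in for `Option<Tensor>` are all hypothetical.

```rust
/// CPU stand-in for a seq-dim concat: `cache` is `[bh, t, hd]`, `new` is
/// `[bh, new_len, hd]` (both row-major); returns `[bh, t + new_len, hd]`.
/// Hypothetical signature; the real kernel operates on device tensors.
fn cat_seq(cache: &[f32], new: &[f32], bh: usize, t: usize, new_len: usize, hd: usize) -> Vec<f32> {
    assert_eq!(cache.len(), bh * t * hd);
    assert_eq!(new.len(), bh * new_len * hd);
    let mut out = Vec::with_capacity(bh * (t + new_len) * hd);
    for b in 0..bh {
        // For each (batch*head) row: old timesteps first, then the new ones.
        out.extend_from_slice(&cache[b * t * hd..(b + 1) * t * hd]);
        out.extend_from_slice(&new[b * new_len * hd..(b + 1) * new_len * hd]);
    }
    out
}

/// Reference "host concat" written with explicit index arithmetic, so the
/// two implementations do not share a copy path.
fn host_concat_ref(cache: &[f32], new: &[f32], bh: usize, t: usize, new_len: usize, hd: usize) -> Vec<f32> {
    let total = t + new_len;
    let mut out = vec![0.0f32; bh * total * hd];
    for b in 0..bh {
        for s in 0..total {
            for d in 0..hd {
                out[(b * total + s) * hd + d] = if s < t {
                    cache[(b * t + s) * hd + d]
                } else {
                    new[(b * new_len + (s - t)) * hd + d]
                };
            }
        }
    }
    out
}

/// Grow an optional per-layer cache by one token. `None` means the first
/// token, mirroring the `Option<Tensor>`-per-layer idea from the text.
fn append_token(cache: Option<(Vec<f32>, usize)>, new: &[f32], bh: usize, hd: usize) -> (Vec<f32>, usize) {
    match cache {
        None => (new.to_vec(), 1),
        Some((buf, t)) => (cat_seq(&buf, new, bh, t, 1, hd), t + 1),
    }
}
```

A gate in this style would append a few tokens through `append_token` and assert that the result equals the buffer built by `host_concat_ref` step by step; exact `f32` equality is reasonable here because both paths only copy values.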
**The finding (why this is a measure-first lesson, not a speedup story):** removing the host
round-trip buys **~10%** on *pure* single-seq decode (133 → 147 tok/s @128) but **does not move the
GRPO step** (~8.5 s/step, unchanged). The reason: after M2b batching, the rollout is no longer the
step's bottleneck. The per-sample **`per_token_logp` captures** (2 forwards/sample) and the
**PG-update** forwards+backwards (`model.forward`, full-sequence, per sample) now dominate. So the
long pole **shifted** from the rollout to the training-side forwards (cf. T11/T17/M2a: profile
before optimizing, because the bottleneck you fixed is not the one that remains). The device cache
is still a real, correctness-gated improvement (cleaner code, less PCIe traffic, ~10% decode); the
honest headline is that the *next* decode lever is **ragged batched prefill of the per-sample
forwards**, not the cache. The M2 decode engine is now M2a (single-seq) + M2b (batched) + M2c
(device cache), all token-identical-gated; the post-training stack remains complete, with its
bottleneck mapped.
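The reasoning above is an Amdahl-style argument, and a small sketch makes the arithmetic explicit. The function names and the 20% rollout share used in the usage note are hypothetical: the log gives the step time (~8.5 s) and the single-seq decode rates (133 → 147 tok/s) but does not state what fraction of the step the rollout occupies, and the GRPO step uses the batched decode path rather than the single-seq one.

```rust
/// Amdahl-style estimate: only the fraction `frac` of a step is sped up by
/// `speedup`; the rest of the step is unchanged. Hypothetical helper.
fn step_time_after(step_s: f64, frac: f64, speedup: f64) -> f64 {
    assert!((0.0..=1.0).contains(&frac), "frac must be in [0, 1]");
    assert!(speedup > 0.0, "speedup must be positive");
    step_s * ((1.0 - frac) + frac / speedup)
}

/// Upper bound on the whole-step speedup if the `frac` portion became free.
fn max_step_speedup(frac: f64) -> f64 {
    assert!((0.0..1.0).contains(&frac), "frac must be in [0, 1)");
    1.0 / (1.0 - frac)
}
```

For instance, if the rollout were 20% of an 8.5 s step (an assumed figure), `step_time_after(8.5, 0.2, 147.0 / 133.0)` lands a little above 8.3 s, a change that is easy to lose in run-to-run noise. The smaller the rollout share, the closer the step stays to 8.5 s, which is consistent with the "unchanged" observation.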
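@@ -105,6 +105,8 @@ After Phase 1/2 worked through the **full pre-training stack**, Phase 3 turns to **post-training infra** (

**M2b (batched KV-cache decode, landed; completes the M2 engine and fixes the rollout long pole)**: the rollout long-pole fix added after M4. The **G samples of one prompt decode in lockstep** (one forward per step runs the whole group → G× fewer kernel launches). One new primitive, `rope_pos` (a per-row absolute-position kernel; the G rows share one decode position; gates = `[0..n]` matches full rope position by position, and a uniform P matches `rope_at(P)` row by row, bit-identical). The engine is `generate_cached_batch`: `BatchKVCache` carries a G dimension, and the batched `decode_step` threads G through embed/proj/QK-norm/`rope_pos`/cache. **Two M2a pieces are reused with zero changes**: `decode_attention` was already batch-agnostic (bh=G·nh), and `repeat_kv(nh,batch=G)` broadcasts per group. Gate = G greedy rows byte-for-byte equal to single-seq (`tests/decode_batch.rs`, 8q/2kv heads to exercise batched repeat_kv). **Throughput** (v12, G6·B6, wired into train_grpo): **~8.5s/step vs single-seq ~14-16s/step ≈ 1.7×** (rollout-inclusive; short of the full G× because per_token_logp + the PG update also take time and the M2a host round-trip is still there); and **memory is steadier** (one batched forward vs G separate allocations fragmenting the allocator, the M4 OOM). ⇒ The M2 engine loop is closed (M2a single-seq + M2b batched), the rollout long pole goes from "OOM/unbounded" to a bounded ~1.7× win, and the device-side cache is the named next lever.

**M2c (device-side KV cache, landed; a profile-first finding that the bottleneck moved)**: K/V stay on the device as `[bh,T,hd]` (an `Option<Tensor>` per layer), and each step appends one token with the new `cat_seq` kernel (concat along seq). This removes the M2a/M2b per-layer **host round-trip** + `transpose_3d01`; both single-seq and batched decode were refactored onto it (cleaner than host Vec+rebuild). All gates hold: `cat_seq`==host concat, decode_kv single-seq + decode_batch batched are still **token-identical**, and the GQA training path is unaffected. **The finding (a measure-first point, not a speedup story)**: removing the host round-trip gives **pure single-seq decode +10%** (133→147 tok/s@128), but **the GRPO step does not move** (~8.5s/step), because after M2b batching the rollout is no longer the step bottleneck: **the per-sample `per_token_logp` captures (2×/sample) + the PG-update forward/backward (full-sequence `model.forward`) now dominate**. The long pole **shifted** from the rollout to the training-side forwards (as with T11/T17/M2a: profile before acting, because what you fixed is not the bottleneck that remains). The device cache is still a real, fully gated improvement (cleaner, less PCIe, decode +10%), but the next lever is **ragged batching of the per-sample forwards**, not the cache. The M2 engine is now M2a (single-seq) + M2b (batched) + M2c (device cache), all token-identical-gated; the post-training stack is complete and its bottleneck is mapped.
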
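The "ragged batching of the per-sample forwards" named above is not implemented in this commit, and the log does not say how it would be done. One common layout for ragged batches is a flat token buffer plus cumulative offsets, so that samples of different lengths share one forward without padding. The sketch below shows only that bookkeeping on the host; every name in it (`pack_ragged`, `unpack_ragged`, `positions`) is hypothetical and not part of the repo.

```rust
/// Pack variable-length token sequences into one flat buffer plus cumulative
/// offsets. `offsets` has `seqs.len() + 1` entries; sequence `i` occupies
/// `flat[offsets[i]..offsets[i + 1]]`. Hypothetical helper.
fn pack_ragged(seqs: &[Vec<u32>]) -> (Vec<u32>, Vec<usize>) {
    let total: usize = seqs.iter().map(|s| s.len()).sum();
    let mut flat = Vec::with_capacity(total);
    let mut offsets = Vec::with_capacity(seqs.len() + 1);
    offsets.push(0);
    for s in seqs {
        flat.extend_from_slice(s);
        offsets.push(flat.len());
    }
    (flat, offsets)
}

/// Split a per-token result (one value per packed token, e.g. a log-prob)
/// back into one slice per sample.
fn unpack_ragged<'a, T>(per_token: &'a [T], offsets: &[usize]) -> Vec<&'a [T]> {
    offsets.windows(2).map(|w| &per_token[w[0]..w[1]]).collect()
}

/// Position of each packed token inside its own sequence, which is what a
/// per-row position kernel would need so samples do not share positions.
fn positions(offsets: &[usize]) -> Vec<usize> {
    offsets.windows(2).flat_map(|w| 0..(w[1] - w[0])).collect()
}
```

With this layout, a single batched forward would produce one output row per packed token, and `unpack_ragged` recovers the per-sample views afterwards. Attention would additionally need a mask or per-sequence bounds so tokens never attend across sample boundaries; that part is outside this sketch.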
## 4. Perf lever ledger (see [known-issues.md](known-issues.md) for details)

- **Fixed**: KI-1 single-seq launch-bound (T10) · KI-5 per-op cudaMalloc serialization (T11) · KI-2 bf16/OOM (T12) · KI-3 activation recomputation (T13; unlocks dim1024, used by v8).