docs: M2a — KV-cache decode engine results (token-identical + length-dependent speedup)
Implementation log (docs/18) + Phase-3 row (evolution.md): the two decode primitives and their gates, the engine design (host-cache baseline), the token-identical centerpiece gate, and the measured throughput baseline showing the cache win is sequence-length-dependent (~1.0x@32, ~1.9x@128, naive OOM@256). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -360,3 +360,53 @@ gap is exactly what the verifiable reward in M3 (DPO) / M4 (GRPO) is built to cl
|
||||
held-out correct > 0 confirms the checker + eval harness score real matches (not just format).
|
||||
M1 delivers the format floor + the reusable task spec / checker / eval harness — not arithmetic
|
||||
skill, which is downstream by design.
|
||||
|
||||
### M2a — KV-cache incremental-decode engine (single sequence, landed)
|
||||
|
||||
The decode engine (D3, built up front) that replaces the naive sampler — which re-runs the
|
||||
full forward over the growing prefix every step (O(t²), a fresh autograd graph per token). Two
|
||||
forward-only primitives + a raw-Tensor per-token block forward, each gated in isolation.
|
||||
|
||||
**Primitives (`xtrain-tensor`, both forward-only):**
|
||||
- `Tensor::rope_at(theta, pos0)` — RoPE at a token's *absolute* position (`pos = pos0 + row`,
|
||||
no modulo), vs the training `rope` (`pos = row % period`) which is left untouched (new CUDA
|
||||
kernel `rope_at_k` → no training-path risk). Cached K is stored post-RoPE, so it must match
|
||||
what the full forward produced at that position. **Gate:** bit-identical to the full-sequence
|
||||
rope's row `t` (`integration::rope_at_matches_full_rope_row`).
|
||||
- `Tensor::decode_attention(k, v, scale)` — single-query × cached-K/V SDPA (`[bh,1,hd]` vs
|
||||
`[bh,t,hd]`, no causal mask: the one query sees all cached keys). Composed from the existing
|
||||
strided batched GEMM + plain softmax — **no new kernel**. **Gate:** equals the full causal
|
||||
attention's last query row, max |Δ| 6e-8 (`integration::decode_attention_matches_…`).
|
||||
|
||||
**Engine (`xtrain-model/src/decode.rs`, `generate_greedy_cached`):** per-layer K/V cache +
|
||||
single-token incremental forward. Prefill = the first `prompt.len()` decode steps (one code
|
||||
path). Mirrors `model::block_forward` at the raw-Tensor level (no autograd tape — inference
|
||||
needs no grads), pulling weights via the public `params()` stable order (no model-internal
|
||||
visibility changes). The cache is host-accumulated token-major f32, rebuilt per step — the
|
||||
honest M2a baseline; M2b moves it device-side + adds batched ragged decode.
|
||||
|
||||
**Gate (the M2 centerpiece — token-identical):** KV-cache greedy decode is byte-for-byte the
|
||||
same token sequence as the naive full-recompute greedy. Verified two ways:
|
||||
- `xtrain-train/tests/decode_kv.rs` — small GQA model (8 query / 2 kv heads), F32, 24 generated
|
||||
tokens, exact token-equality. (Unit gate runs F32: a random model's near-uniform logits make
|
||||
argmax fragile to ~1e-6, so the tightest path is used; the trained model below has peaked
|
||||
logits → robust.)
|
||||
- v12 1.05B SFT checkpoint: `eval_arith --cached` produces the **identical** eval outcome to the
|
||||
naive run (format 100/100, correct 8/100) and byte-identical completions.
|
||||
|
||||
**Throughput baseline (v12 1.05B, batch 1, F32, profile-first — measured, not assumed):** the
|
||||
cache win is **sequence-length-dependent**, which is the honest systems finding here:
|
||||
|
||||
| max_new | naive | kv-cache | note |
|
||||
|---------|-------|----------|------|
|
||||
| 32 | 108 tok/s | 111 tok/s | ~1.0× — both **launch/overhead-bound** at short seq |
|
||||
| 128 | 69 tok/s | **133 tok/s** | **~1.9×** — naive's O(t²) recompute starts to bite |
|
||||
| 256 | **OOM** | 129 tok/s | naive rebuilds the O(seq²) graph every step → OOM |
|
||||
|
||||
Cached throughput stays ~constant (O(1)/token compute + constant memory); naive **decays**
|
||||
(108→69 tok/s, O(t)/token) and eventually **OOMs** (the full autograd graph per step). So at the
|
||||
short arithmetic-eval lengths the cache is overhead-bound and gives ~nothing — it matters for
|
||||
**long rollouts** (DPO pair-generation, GRPO completions), exactly where M3/M4 use it. (M2a's
|
||||
per-layer host round-trip is part of why short-seq is overhead-bound; M2b's device-side cache
|
||||
targets it.) This is the same measure-first lesson as T17 (process-per-GPU throughput-neutral):
|
||||
the win is real but only in the regime that actually stresses the bottleneck.
|
||||
|
||||
@@ -97,6 +97,8 @@ Phase 1/2 把**预训练全栈**学完后,Phase 3 转向**后训练 infra**(
|
||||
|
||||
**M1(SFT task baseline,已落地)**:可验证算术任务 + 数据生成器 + 评分器一套,host-side 9/9 单测过(masking、SFT-target 自洽 2000 样、parser 边界、种子确定性)。dash5 单卡从 v12 基座 SFT(loss 4.68→~0.34,best val 0.386)。**100 留出题 eval:格式 `\boxed{}` 习得率 base 0% → SFT 100%;算术正确率 8%。**——SFT 只买**格式**(0%→100% 干净落地),算术正确性是 base 模型本身弱项(如 `46*80` 框成 3380),正是 M3/M4 的可验证 reward 要去补的残差。一条诚实账:M1 用的是**朴素无 KV-cache 采样器**(每 token 全量 forward),100 题已经很慢——这正是 M2 解码引擎前置的动机。
|
||||
|
||||
**M2a(KV-cache 增量解码引擎,单序列,已落地)**:两个 forward-only 原语 + 裸 Tensor 逐 token block forward,各自隔离闸门。`rope_at`(绝对位置 RoPE,新 kernel,不动训练 `rope` → 训练路径零风险)逐位等于全序列 rope 的对应行;`decode_attention`(单 query × cached-K/V,由现成 strided-gemm + 普通 softmax 组合,**零新 kernel**)等于全 causal attention 末行(max|Δ| 6e-8)。引擎 `generate_greedy_cached` 镜像 `block_forward` 在 Tensor 层(无 autograd tape,推理不需梯度),靠**公开 `params()` 稳定顺序**拿权重(零 model 可见性改动)。**核心闸门 = token-identical**:与朴素全重算贪心逐 token 一致(小 GQA 单测 + v12 1.05B 上 cached eval 与 naive **逐字节相同**:format 100/100, correct 8/100)。**吞吐 baseline(v12, batch1, F32,profile-first 实测)= cache 收益随序列长度而定**:max_new 32 ≈ 持平(108 vs 111,短序列 launch 开销 bound)、128 **~1.9×**(69 vs 133)、256 naive **OOM** vs cached 129 tok/s。cached 吞吐**近恒定**(O(1)/token + 恒定显存),naive **衰减**(O(t)/token,O(seq²) 图 → OOM)。⇒ 短 eval prompt overhead-bound、cache 几乎无收益,真正受益的是**长 rollout**(DPO 造对 / GRPO completion)——与 T17(process-per-GPU 吞吐中性)同一条 measure-first 教训:收益真实,但只在真正压到瓶颈的 regime 里。M2a 的 per-layer 主机往返是短序列 overhead-bound 的一部分原因,M2b(device 端 cache + 批量 ragged)针对它。
|
||||
|
||||
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md))
|
||||
|
||||
- **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。
|
||||
|
||||
Reference in New Issue
Block a user