Device-resident KV cache: keep K/V on the GPU as [bh,T,hd], grow by one token
per step via a new cat_seq kernel (concat along seq) — removes the M2a/M2b
per-layer host round-trip (to_cpu/from_slice/re-upload) AND the transpose_3d01.
Both single-seq and batched decode refactored to it; cache is Option<Tensor>
per layer (cleaner than the host Vec + rebuild).
Gates all hold: cat_seq == host concat; decode_kv single-seq + decode_batch
G-way both still TOKEN-IDENTICAL; GQA training path unaffected.
Honest measurement (the point): removing the host round-trip buys ~10% on pure
single-seq decode (133 → 147 tok/s @128) but does NOT move the GRPO step
(~8.5 s/step unchanged) — because after M2b batching the rollout is no longer
the step's bottleneck; the per-sample per_token_logp captures + the PG-update
forwards/backwards (model.forward, full-seq) now dominate. Measure-first lesson
(cf. T11/T17/M2a): the long pole shifted to the training-side forwards; the next
decode lever (ragged batched prefill) targets those, not the cache.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The rollout long-pole fix deferred from M2a: decode the G samples of one prompt
in lockstep (one forward per step over the group → G× fewer kernel launches).
- rope_pos(x, positions[]): RoPE with a per-row absolute position (new forward-
only kernel) — G rows share one decode position. Gate: == full rope for
[0..n], == rope_at(P) per row for uniform P (bit-identical).
- generate_cached_batch: BatchKVCache [T, G·num_kv, hd] + batched decode_step.
decode_attention is already batch-agnostic (bh = G·nh); repeat_kv(nh, batch=G)
broadcasts per group. No finished-mask / ragged prompts yet (perf-only / next).
- Gate (tests/decode_batch.rs): all G greedy rows token-identical to the single-
sequence decode (8 query / 2 kv heads → exercises repeat_kv batching).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The GRPO (M4) token-level loss op + the one primitive it needs:
- scale_rows(x[r,c], s[r]): per-row scale (new ~5-line CUDA kernel). The
clipped-PG backward scales each completion token's row of (probs − onehot) by
its own per-token coefficient, which cross_entropy_backward's single scalar
scale can't express.
- clipped_pg_loss(logits, target, logp_old, logp_ref, A, eps, beta): per-token
ρ_t = exp(logπθ_t − logp_old_t), L = −mean min(ρA, clip(ρ,1±ε)A) + β·mean KL
(k3 estimator), masked to completion tokens. Backward reuses the CE machinery
(probs − onehot) + scale_rows. Gates: grad-check the active PG path + the A=0
(KL-only) path; degenerate value checks ε→∞ ⇒ vanilla PG, β=0 ⇒ no KL.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Two forward-only Tensor primitives the KV-cache decode engine is built on,
each gated by an isolated correctness test:
- rope_at(theta, pos0): RoPE at an absolute position (pos = pos0 + row, no
modulo) for a single decode token, vs the training rope_k (pos = row %
period) left untouched. New forward-only CUDA kernel, no training-path risk.
Gate: bit-identical to the full-sequence rope's corresponding row.
- decode_attention(k, v, scale): single-query × cached-K/V SDPA, composed from
the existing strided batched GEMM + plain (non-causal) softmax — no new
kernel. Gate: equals the full causal attention's last query row (max |Δ| 6e-8).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Enable assistant-only supervised fine-tuning and a fixed chat-prompt eval path
used by the v12 SFT runs:
- cross_entropy ignores negative targets (-100 ignore-index), normalizing by
valid rows instead of all rows; CUDA fwd/bwd skip t<0 (ops.rs, nn.cu).
- Corpus gains optional labels + load_sft_tsv_cached: two-column TSV is
formatted as 'User: .. \nAssistant:' + answer + <|endoftext|>, prompt tokens
masked to -100 while answer+EOS are supervised; i32 label cache alongside the
u16 token cache; sample() retries windows that are fully masked; eval uses
target_window so masking applies to val loss too (data.rs, train_loop.rs).
- train + train_ddp: --sft-tsv selects the TSV loader, --init-ckpt continues
training from a base checkpoint.
- greedy_sample: --prompts-file/--prompt/--temperature for fixed chat-prompt
generation eval.
Test fixtures updated for the new Corpus.labels field; dropout.rs carries
incidental rustfmt. Not rebuilt locally (no CUDA toolchain on this checkout);
correctness rests on the documented v12 base+SFT runs on the GPU box.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add the batched-forward primitives. Linears/norms/elementwise/embedding/CE
already act on flat [rows,dim], so they work unchanged on [B*S,dim]; only
attention + RoPE need sequence awareness:
- RoPE: kernel takes a `period` (= seq len) so position = row % period, i.e.
per-sequence position on a flattened batch (period == tokens = single seq).
- Fused batched causal attention: new `Tensor::attention`/`attention_backward`
+ ops node, running QKᵀ and PV as cublasSgemmStridedBatched over the B*nh
(sequence,head) blocks (new sgemm_strided_batched binding) and a causal
softmax kernel (scale + per-row causal mask inline) — the whole attention is
3 launches regardless of B*nh, no per-head/per-seq loop, no host round-trip.
- transpose_4d12 ([B,S,nh,hd] <-> [B,nh,S,hd]) to lay out the batched heads.
grad-checks: new batched-rope, transpose_4d12, batched-attention dQ/dK/dV all
pass finite-diff (attn dK 1.5e-2, dQ 7.5e-3, dV 2.9e-4; rest tighter) alongside
the existing 12.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
add/mul/add_bias(+sum_rows)/rms_norm/silu/rope/softmax/cross_entropy,
each with its analytic backward, in csrc/ops/nn.cu (inlined warp/block
reductions). FFI declarations + nn.cu in build.rs (no_cuda gated). Tensor
gains the matching thin wrappers; DType grows I32 for cross-entropy targets.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>