7 Commits

Author SHA1 Message Date
3a3425960c post-train: M2c — device-side KV cache (cat_seq), profile-first bottleneck shift
Device-resident KV cache: keep K/V on the GPU as [bh,T,hd], grow by one token
per step via a new cat_seq kernel (concat along seq) — removes the M2a/M2b
per-layer host round-trip (to_cpu/from_slice/re-upload) AND the transpose_3d01.
Both single-seq and batched decode refactored to it; cache is Option<Tensor>
per layer (cleaner than the host Vec + rebuild).

Gates all hold: cat_seq == host concat; decode_kv single-seq + decode_batch
G-way both still TOKEN-IDENTICAL; GQA training path unaffected.

Honest measurement (the point): removing the host round-trip buys ~10% on pure
single-seq decode (133 → 147 tok/s @128) but does NOT move the GRPO step
(~8.5 s/step unchanged) — because after M2b batching the rollout is no longer
the step's bottleneck; the per-sample per_token_logp captures + the PG-update
forwards/backwards (model.forward, full-seq) now dominate. Measure-first lesson
(cf. T11/T17/M2a): the long pole shifted to the training-side forwards; the next
decode lever (ragged batched prefill) targets those, not the cache.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 17:38:16 +08:00
2c9b58cb3b post-train: M2b — batched KV-cache decode (G-way, token-identical)
The rollout long-pole fix deferred from M2a: decode the G samples of one prompt
in lockstep (one forward per step over the group → G× fewer kernel launches).

- rope_pos(x, positions[]): RoPE with a per-row absolute position (new forward-
  only kernel) — G rows share one decode position. Gate: == full rope for
  [0..n], == rope_at(P) per row for uniform P (bit-identical).
- generate_cached_batch: BatchKVCache [T, G·num_kv, hd] + batched decode_step.
  decode_attention is already batch-agnostic (bh = G·nh); repeat_kv(nh, batch=G)
  broadcasts per group. No finished-mask / ragged prompts yet (perf-only / next).
- Gate (tests/decode_batch.rs): all G greedy rows token-identical to the single-
  sequence decode (8 query / 2 kv heads → exercises repeat_kv batching).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 17:18:54 +08:00
aaa77082ef post-train: M4 — clipped_pg_loss + scale_rows (GRPO policy-gradient op)
The GRPO (M4) token-level loss op + the one primitive it needs:

- scale_rows(x[r,c], s[r]): per-row scale (new ~5-line CUDA kernel). The
  clipped-PG backward scales each completion token's row of (probs − onehot) by
  its own per-token coefficient, which cross_entropy_backward's single scalar
  scale can't express.
- clipped_pg_loss(logits, target, logp_old, logp_ref, A, eps, beta): per-token
  ρ_t = exp(logπθ_t − logp_old_t), L = −mean min(ρA, clip(ρ,1±ε)A) + β·mean KL
  (k3 estimator), masked to completion tokens. Backward reuses the CE machinery
  (probs − onehot) + scale_rows. Gates: grad-check the active PG path + the A=0
  (KL-only) path; degenerate value checks ε→∞ ⇒ vanilla PG, β=0 ⇒ no KL.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 14:07:02 +08:00
c88e2ab88c post-train: M2 — decode primitives (rope_at + decode_attention)
Two forward-only Tensor primitives the KV-cache decode engine is built on,
each gated by an isolated correctness test:

- rope_at(theta, pos0): RoPE at an absolute position (pos = pos0 + row, no
  modulo) for a single decode token, vs the training rope_k (pos = row %
  period) left untouched. New forward-only CUDA kernel, no training-path risk.
  Gate: bit-identical to the full-sequence rope's corresponding row.
- decode_attention(k, v, scale): single-query × cached-K/V SDPA, composed from
  the existing strided batched GEMM + plain (non-causal) softmax — no new
  kernel. Gate: equals the full causal attention's last query row (max |Δ| 6e-8).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 12:00:03 +08:00
fbf4ac2917 sft: assistant-only SFT (ignore-index CE) + chat-prompt greedy eval
Enable assistant-only supervised fine-tuning and a fixed chat-prompt eval path
used by the v12 SFT runs:

- cross_entropy ignores negative targets (-100 ignore-index), normalizing by
  valid rows instead of all rows; CUDA fwd/bwd skip t<0 (ops.rs, nn.cu).
- Corpus gains optional labels + load_sft_tsv_cached: two-column TSV is
  formatted as 'User: .. \nAssistant:' + answer + <|endoftext|>, prompt tokens
  masked to -100 while answer+EOS are supervised; i32 label cache alongside the
  u16 token cache; sample() retries windows that are fully masked; eval uses
  target_window so masking applies to val loss too (data.rs, train_loop.rs).
- train + train_ddp: --sft-tsv selects the TSV loader, --init-ckpt continues
  training from a base checkpoint.
- greedy_sample: --prompts-file/--prompt/--temperature for fixed chat-prompt
  generation eval.

Test fixtures updated for the new Corpus.labels field; dropout.rs carries
incidental rustfmt. Not rebuilt locally (no CUDA toolchain on this checkout);
correctness rests on the documented v12 base+SFT runs on the GPU box.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-29 16:19:02 +08:00
7821bd9c34 autograd: batch dim for ops (flatten linears, batched attention)
Add the batched-forward primitives. Linears/norms/elementwise/embedding/CE
already act on flat [rows,dim], so they work unchanged on [B*S,dim]; only
attention + RoPE need sequence awareness:

- RoPE: kernel takes a `period` (= seq len) so position = row % period, i.e.
  per-sequence position on a flattened batch (period == tokens = single seq).
- Fused batched causal attention: new `Tensor::attention`/`attention_backward`
  + ops node, running QKᵀ and PV as cublasSgemmStridedBatched over the B*nh
  (sequence,head) blocks (new sgemm_strided_batched binding) and a causal
  softmax kernel (scale + per-row causal mask inline) — the whole attention is
  3 launches regardless of B*nh, no per-head/per-seq loop, no host round-trip.
- transpose_4d12 ([B,S,nh,hd] <-> [B,nh,S,hd]) to lay out the batched heads.

grad-checks: new batched-rope, transpose_4d12, batched-attention dQ/dK/dV all
pass finite-diff (attn dK 1.5e-2, dQ 7.5e-3, dV 2.9e-4; rest tighter) alongside
the existing 12.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 00:44:15 +08:00
5aef3742d6 ops: transformer op fwd/bwd CUDA kernels + Tensor wrappers
add/mul/add_bias(+sum_rows)/rms_norm/silu/rope/softmax/cross_entropy,
each with its analytic backward, in csrc/ops/nn.cu (inlined warp/block
reductions). FFI declarations + nn.cu in build.rs (no_cuda gated). Tensor
gains the matching thin wrappers; DType grows I32 for cross-entropy targets.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:44:09 +08:00