Commit Graph

24 Commits

Author SHA1 Message Date
3a3425960c post-train: M2c — device-side KV cache (cat_seq), profile-first bottleneck shift
Device-resident KV cache: keep K/V on the GPU as [bh,T,hd], grow by one token
per step via a new cat_seq kernel (concat along seq) — removes the M2a/M2b
per-layer host round-trip (to_cpu/from_slice/re-upload) AND the transpose_3d01.
Both single-seq and batched decode refactored to it; cache is Option<Tensor>
per layer (cleaner than the host Vec + rebuild).

Gates all hold: cat_seq == host concat; decode_kv single-seq + decode_batch
G-way both still TOKEN-IDENTICAL; GQA training path unaffected.

Honest measurement (the point): removing the host round-trip buys ~10% on pure
single-seq decode (133 → 147 tok/s @128) but does NOT move the GRPO step
(~8.5 s/step unchanged) — because after M2b batching the rollout is no longer
the step's bottleneck; the per-sample per_token_logp captures + the PG-update
forwards/backwards (model.forward, full-seq) now dominate. Measure-first lesson
(cf. T11/T17/M2a): the long pole shifted to the training-side forwards; the next
decode lever (ragged batched prefill) targets those, not the cache.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 17:38:16 +08:00
2c9b58cb3b post-train: M2b — batched KV-cache decode (G-way, token-identical)
The rollout long-pole fix deferred from M2a: decode the G samples of one prompt
in lockstep (one forward per step over the group → G× fewer kernel launches).

- rope_pos(x, positions[]): RoPE with a per-row absolute position (new forward-
  only kernel) — G rows share one decode position. Gate: == full rope for
  [0..n], == rope_at(P) per row for uniform P (bit-identical).
- generate_cached_batch: BatchKVCache [T, G·num_kv, hd] + batched decode_step.
  decode_attention is already batch-agnostic (bh = G·nh); repeat_kv(nh, batch=G)
  broadcasts per group. No finished-mask / ragged prompts yet (perf-only / next).
- Gate (tests/decode_batch.rs): all G greedy rows token-identical to the single-
  sequence decode (8 query / 2 kv heads → exercises repeat_kv batching).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 17:18:54 +08:00
7fb3b32fd9 post-train: M4 — GRPO actor-learner loop + cached temperature rollout
train_grpo: the online, critic-free RL loop — per step sample B prompts, roll
out G completions each, score with the rule-based checker (reward 0/1), compute
group-relative advantage A=(r−mean)/(std+ε), then K inner clipped_pg_loss
epochs with a KL leash to the frozen reference. Reward = pure 0/1 correctness
(KL is the format protector, the M3 collapse lesson). Tracks mean rollout reward
(the falsifiable "it learns" signal). Periodic checkpoint save.

decode: generate_cached adds temperature sampling to the KV-cache engine (M2) —
single-row [1,vocab] logits per step vs the naive sampler's [seq,vocab], far
lighter on the caching allocator (the naive sampler fragments it over a long
rollout). generate_greedy_cached now routes through it (temp 0); decode_kv
token-identical gate still passes.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 16:59:05 +08:00
eff26a0898 post-train: M2a — KV-cache incremental decode engine (token-identical)
Single-sequence KV-cache decode (xtrain-model/src/decode.rs): per-layer K/V
cache + single-token incremental forward (prefill = first prompt.len() decode
steps, one code path). Mirrors model::block_forward at the raw-Tensor level (no
autograd tape — inference needs no grads), using rope_at + decode_attention.
Cache is host-accumulated token-major f32, rebuilt per step (the honest M2a
baseline; M2b moves it device-side + batched ragged).

Gate (the M2 centerpiece): KV-cache greedy decode is TOKEN-IDENTICAL to the
naive full-recompute greedy — tests/decode_kv.rs (small GQA model, F32, 24
tokens) and corroborated on the v12 1.05B SFT checkpoint (cached eval =
naive eval byte-for-byte: format 100/100, correct 8/100).

eval_arith --cached A/Bs the two paths + reports decode tok/s. Measured on v12
(1.05B, batch 1, F32): the cache win is sequence-length-dependent —
  max_new=32   naive 108 vs cached 111 tok/s  (~1.0x; overhead-bound)
  max_new=128  naive  69 vs cached 133 tok/s  (~1.9x)
  max_new=256  naive OOM     vs cached 129 tok/s
Cached throughput stays ~constant (O(1)/token) while naive decays (O(t)/token,
O(seq^2) graph → OOM at length). Short eval prompts are overhead-bound, so the
cache matters for long rollouts (DPO/GRPO), not the arithmetic eval itself.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 12:00:03 +08:00
fbf4ac2917 sft: assistant-only SFT (ignore-index CE) + chat-prompt greedy eval
Enable assistant-only supervised fine-tuning and a fixed chat-prompt eval path
used by the v12 SFT runs:

- cross_entropy ignores negative targets (-100 ignore-index), normalizing by
  valid rows instead of all rows; CUDA fwd/bwd skip t<0 (ops.rs, nn.cu).
- Corpus gains optional labels + load_sft_tsv_cached: two-column TSV is
  formatted as 'User: .. \nAssistant:' + answer + <|endoftext|>, prompt tokens
  masked to -100 while answer+EOS are supervised; i32 label cache alongside the
  u16 token cache; sample() retries windows that are fully masked; eval uses
  target_window so masking applies to val loss too (data.rs, train_loop.rs).
- train + train_ddp: --sft-tsv selects the TSV loader, --init-ckpt continues
  training from a base checkpoint.
- greedy_sample: --prompts-file/--prompt/--temperature for fixed chat-prompt
  generation eval.

Test fixtures updated for the new Corpus.labels field; dropout.rs carries
incidental rustfmt. Not rebuilt locally (no CUDA toolchain on this checkout);
correctness rests on the documented v12 base+SFT runs on the GPU box.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-29 16:19:02 +08:00
39df0b40c1 gqa: fix kv-proj shape test param indices (embed,attn_norm precede wq)
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-18 01:38:42 +08:00
830d06ad01 gqa: real grouped-query attention (repeat_kv op + both SDPA paths + wiring + tests)
- repeat_kv CUDA kernel: fwd head-block gather, bwd DETERMINISTIC group-sum (each
  kv head sums its group of query-head grads; no atomics) + Tensor/ops node.
- Config gains num_kv_heads (default = n_heads → MHA); wk/wv project to kv_dim;
  attention() repeat_kv-broadcasts K/V to nh heads before the UNCHANGED composed
  & flash SDPA → GQA on both paths. group=1 is identity → MHA bit-identical.
- --kv-heads flag on train/train_ddp/export_safetensors/greedy_sample; export
  writes real num_key_value_heads (xserv repeat_kv grouping aligned).
- Tests: repeat_kv grad-check (group>1 grad-sum + group=1 identity); model gqa.rs
  (GQA flash==composed fp32/bf16, group=1 bit-identical to MHA, kv-proj shape);
  parity_dump+parity.py GQA path (repeat_interleave) via XTRAIN_PARITY_KV_HEADS.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-18 01:37:37 +08:00
4b6d3e0a79 test: flash+dropout cross-feature grad-check (Phase-2 integration)
Add flash_plus_dropout_grad_check_fp32 to xtrain-model dropout tests: the two
orthogonal Phase-2 features (T14 flash-attn, T18 dropout) in the same model must
still grad-check. Both models run train-mode p=0.2 (identical masks, seed is
flash-independent) so the only delta is the SDPA reduction order — checked against
the flash-vs-composed tolerance.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-18 00:43:54 +08:00
c36cdf74d1 Merge t18-dropout into main
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

# Conflicts:
#	README.md
#	crates/xtrain-autodiff/tests/autograd.rs
#	crates/xtrain-model/src/model.rs
#	crates/xtrain-train/src/bin/train.rs
#	crates/xtrain-train/src/train_loop.rs
#	docs/evolution.md
2026-06-18 00:41:41 +08:00
e625aa05dd dropout: wire into model (residual sites) + train/eval switch + flag (T18)
Config.dropout (default 0). TinyTransformer gets a Cell<bool> training switch
(train()/eval()/with_training, default eval = safe) + a Cell<u64> step_seed bumped
once per training forward. forward_batched derives a per-layer block_seed (pure fn
of step_seed×layer) and block_forward derives two per-site seeds, inserting
ops::dropout at the attn and ffn sub-block outputs (before each residual). The
seed is a pure function of (step_seed, layer, site) so the checkpoint (T13)
recompute re-derives the same masks → grads stay exact. p=0 or eval → no dropout
node → graph bit-identical to pre-T18.

train_loop: model.train() per step (restored after eval flips to eval); eval_loss
runs model.eval(). bin/train: --dropout flag → cfg.dropout. Export/sampling run in
eval (default), so exported weights are dropout-free (xserv closed loop unaffected).

Model-level tests (dropout.rs): p=0 bit-identical to no-dropout (logits/loss/grads);
eval(p>0) == p=0 identity; train differs from eval + finite; recompute-with-dropout
grads match non-recompute (fp32 + bf16).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-18 00:05:32 +08:00
9b05f4f93f test: flash==composed bf16 uses robust mean/p99 metric (repo convention)
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 23:19:08 +08:00
5f3b81ac96 test+bins: flash grad-check, flash==composed, PyTorch parity, --flash flag
autograd: flash_attention_batched_bwd (dQ/dK/dV finite-diff, seq>tile)
+ flash_matches_composed_fwd. model/tests/flash.rs: flash==composed
on-vs-off (logits/loss/every param grad), fp32 + bf16. parity_dump:
XTRAIN_PARITY_FLASH dumps the flash path for the same parity.py oracle
(PyTorch SDPA parity at B>1). train + train_ddp get the --flash flag.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 23:10:39 +08:00
0e20821633 autodiff+model: flash-attention op + --flash opt-in wiring
ops::flash_attention autograd node (fwd caches O(N) logsumexp instead of
O(N²) probs; bwd via Tensor::flash_attention_backward). Model gets a
use_flash bool + with_flash(bool) builder; the SDPA core in attention()
picks ops::flash_attention vs ops::attention. flash threads through
block_forward so the recompute (T13) segment also runs flash. Default
off = composed path, graph unchanged.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 23:10:32 +08:00
69c5f07359 docs: Phase T13 — activation recompute
Design doc for per-block gradient checkpointing (KI-3): the no-tape forward +
recompute-on-backward design, the `checkpoint` primitive, per-block wrapping,
the exactness/correctness argument (same kernels + inputs → identical grads),
composition with bf16+DDP+batched, and the verification plan (on-vs-off grad
gate + memory/throughput before→after, dim1024-fits). Bench table left as TBD
to fill after the dash5 run.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 09:45:16 +08:00
f202351be5 model: per-block activation recompute (--recompute)
Wrap each transformer block's forward in the checkpoint primitive when
recompute is enabled (Phase T13 / KI-3). To make the block forward a pure
segment fn (no `&self` borrow, so it can re-run in the backward closure),
extract the block body + its helpers (linear / norm_gamma / attention /
swiglu_mlp) into free functions parameterised by (cfg, compute_dtype) and add
`Block::block_params()` (the 11 leaves in the params() per-block order). The
non-recompute path calls `block_forward` directly — identical graph to before.

- `TinyTransformer::with_recompute(bool)` builder (opt-in; default off keeps the
  unchanged tape / bit-identical numerics).
- `--recompute` flag wired into bin/train and bin/train_ddp (DDP: each rank
  checkpoints independently).

Correctness gate: tests/recompute.rs builds two identical models (recompute
on/off), runs the same batched loss+backward, and asserts the forward logits,
the loss, and EVERY parameter grad match within tight fp tol — parameterised
over fp32 and bf16 (T12 composition).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 09:42:42 +08:00
5b7dde1736 test: bf16 test reads f32-cast logits (forward now returns bf16)
The `keep bf16 logits` change made forward_batched return bf16 logits
in bf16 mode; the bf16 test's host read must cast to f32 first.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 14:29:24 +08:00
48922cb628 perf: keep bf16 logits (no persistent fp32 logits buffer)
At vocab 50257 the logits tensor [B*S, vocab] is ~1.6GB fp32 at batch
32 — held across the whole backward. Keep it bf16: cross_entropy
upcasts the bf16 logits to fp32 internally (transient) + caches fp32
probs, and its backward casts dx back to bf16 to chain into the
bf16 lm_head matmul backward. The sampler casts bf16 logits→f32 before
the host argmax/softmax. Halves the persistent logits activation.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 14:20:48 +08:00
0a2a4dcaa8 train: --bf16 flag (fp32-master AMP) + bf16 correctness test
- TinyTransformer::with_compute_dtype(BF16): embedding stays fp32
  master then casts to bf16; each linear casts its fp32 weight to bf16
  on the fly; logits cast back to fp32 for cross-entropy. Default F32
  reproduces the v0-v4 forward graph bit-for-bit.
- --bf16 flag on bin/train and bin/train_ddp (off by default).
- tests/bf16.rs: same fp32 master weights run fp32 vs bf16; assert
  loss/logits/grads within a loose bf16 tol, no NaN, and grads are
  fp32 (master untouched).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 14:14:55 +08:00
5353b38402 model: batched forward [B,S]
forward_batched(ids[B*S], batch)/loss_batched: run B equal-length sequences as
ONE forward over flattened [B*S] ids, so every linear is one big [B*S,dim] GEMM.
Attention reshapes to [B*nh,S,hd], runs the fused batched causal SDPA (per-seq
mask + RoPE period=S, no cross-sequence attention), writes back [B*S,dim]. The
old per-(batch,head) loop + host-round-tripping split/merge_heads + the additive
causal_mask leaf are gone. forward(ids[seq]) is now forward_batched(ids,1), so
the sampler / inference path (batch=1) is unchanged.

+batched_ids_tensor helper. New batched.rs test: batched forward == looped
single-sequence (logits identical 0.0, grads 6.4e-4, loss identical). PyTorch
parity now exercises B>1 (B=2,S=4): loss 5e-8, logits 6.9e-6, all 25 param
grads within rtol — verifying per-seq RoPE position + per-seq causal masking.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 00:44:25 +08:00
15f1e526c7 train: parameterize model size (scaling ladder)
Add Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn) so the model
size is a tunable rung instead of a hardcoded tiny config, and Config::core_params()
(num_params minus the two vocab×dim tables) — the figure the ladder is sized
against (the 50257-vocab embed+lm_head adds a fixed ~25M that is not capacity).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 18:34:39 +08:00
7a4f69e430 model: add per-head QK-norm (Qwen3-compat) for xserv export
xserv's Qwen3 forward unconditionally applies per-head RMSNorm to Q and K
(q_norm/k_norm, shape [head_dim]) before RoPE — even gamma=1 is a real RMS
divide, not identity. xtrain never had this, so an exact xserv<->xtrain loop
was structurally impossible. Add it (reusing the 2D rms_norm op on the
[seq*nh, hd] head rows, inserted between reshape and rope to mirror
qwen3.rs's order) so the trained model is genuinely Qwen3-compatible.

params() inserts q_norm,k_norm after wv; num_params() counts them; the
PyTorch parity refs (parity.py / adamw_parity.py) + their name lists add the
same step so the dumps stay self-consistent.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 17:33:19 +08:00
603c85e1e0 model: silence torch parity warning (read loss before backward)
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 16:09:30 +08:00
3366f30c4d model: PyTorch parity harness (weight dump + equivalent torch model)
parity_dump.rs (#[ignore] fixture generator) dumps the model's exact
weights, ids, forward logits, loss, and per-param grads after one
backward. parity.py rebuilds the IDENTICAL model in PyTorch (same x@W
convention, RoPE rotate_half pos=row, RMSNorm, SwiGLU, causal SDPA),
runs fwd+bwd, and compares logits + every grad within rtol.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 16:07:30 +08:00
e3912c2380 model: tiny RoPE+RMSNorm+SwiGLU transformer + overfit test
New crate xtrain-model: a from-scratch decoder built entirely from the
autodiff op set.
- Config (tiny: dim=32, 2 layers, 2 heads, head_dim=16, ffn=64).
- TinyTransformer: embedding -> N x {pre-RMSNorm -> multi-head causal
  attention (RoPE, additive causal mask, per-head SDPA) -> residual;
  pre-RMSNorm -> SwiGLU MLP -> residual} -> final RMSNorm -> LM head.
  x@W weight convention (engine GEMM is plain A@B); dim=n_heads*head_dim.
- params()/zero_grad-able leaves for the optimizer; param_to_host export.
- overfit test: char-level bring-up (embedded text -> vocab -> shifted
  targets), minimal hand-written GD (p -= lr*grad) memorises one fixed
  batch -> loss ~0 + greedy argmax matches targets. End-to-end fwd+bwd
  correctness signal. Gated #![cfg(not(no_cuda))].

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 16:05:20 +08:00