libtest with --test-threads=1 (the documented invariant for this file's DDP
tests) runs tests alphabetically. The new
proc_per_gpu_dropout_is_live_and_p0_matches_no_dropout ('d') runs BEFORE
proc_per_gpu_matches_single_gpu_and_thread_path ('m'). It sets ENV_DROPOUT=0.2
via std::env::set_var; if left in place, the correctness test's spawned workers
would inherit it (Command inherits parent env by default) and build with
cfg.dropout=0.2 while its single-GPU baseline (run_single_gpu → test_config →
dropout=0) stays at 0 — GATE (a) `max_rel_single < 1e-3` would blow up by
orders of magnitude.
Two defenses:
- correctness test remove_var(ENV_DROPOUT) before spawn (belt): even if the
dropout test forgot to clean up, this test starts from a clean env.
- dropout test remove_var(ENV_DROPOUT, ENV_DUMP_DIR) at exit (suspenders):
keep the invariant "each test leaves the env as it found it" so any future
test added after these two starts clean too.
Same --test-threads=1 SAFETY comment applies (no concurrent env access).
Analogue of the ddp_dropout_is_live_and_p0_bit_identical test (T21, thread-per-
GPU) for the process-per-GPU launcher. Runs launch_processes twice on the same
corpus / init / config with the ONLY difference being cfg.dropout (passed
launcher→worker via a new XTRAIN_TEST_DROPOUT env — worker re-execs cannot
inherit argv changes), reads rank 0's loss trajectory from both runs, and
asserts GATE B: max |loss diff| > 1e-3.
The threshold sits ~4 orders of magnitude above this box's KI-5 cross-rank NCCL
noise floor (~1e-7), so it is an unambiguous "dropout mask is applied" signal,
not a noise measurement. Pre-fix (missing cfg.dropout = ... in the worker /
launcher, exactly the gap the paired launcher commit closes) both traces are
bit-identical and this test FAILs.
Also wires ENV_DROPOUT into the shared worker entry so the existing correctness
test's contract is unchanged (absent env → 0.0 → same synth run as before).
p0/ and p02/ subdirs isolate the two invocations' dumps.
T21 fixed --dropout under thread-per-GPU (train_ddp): added the flag, set
cfg.dropout, and made train_rank re-assert model.train() each step so the
training forward stays live across periodic eval flips. The process-per-GPU
launcher (train_ddp_mp) was left out: it never parsed --dropout, so cfg.dropout
stayed at Config::from_arch's 0.0 default, and the worker's model built with
dropout permanently disabled — silently, regardless of what the user passed.
The gap is the exact same launcher-wiring class the V9-PILOT caught: op-level
+ single-GPU tests pass, the DDP-thread T21 regression test passes, but the
proc-per-GPU launcher path was never exercised end-to-end with dropout>0.
Mirror bin/train_ddp exactly: parse --dropout (default 0, bit-identical
default), set cfg.dropout before build_model, print an ON banner on rank 0.
train_rank's per-step model.train() from T21 is reused unchanged (proc-per-GPU
uses the same train_rank).
Follow-up test that exercises this wiring end-to-end (GATE B loss-trace
divergence between p=0 and p=0.2 under process-per-GPU) lands in the next
commit.
Two exact correctness gates (composed = the end-to-end batched GRPO step == looped):
- xtrain-model forward_batched_ragged_matches_looped: forward_batched on RIGHT-padded
ragged sequences == per-sequence single-seq forward on the real rows. fp32
max|Δlogit| = 3.7e-7, bf16 = 0.0, both composed + flash SDPA. Pins "right-pad is
free under causal".
- xtrain-autodiff clipped_pg_loss_batched_matches_looped: batched op == looped
Σ_s (1/N)·clipped_pg_loss_s. loss Δ=1.5e-8, grad max|Δ|=7.5e-9 (f32).
bench_grpo_batch: weight-independent micro-bench of the per-sample training forwards
(loads v12 base as policy, N realistic ragged samples, teacher-forced argmax targets
so the closeness smoke isn't −log-amplified by random low-prob tokens). Measured on
dash5 (v12 1.05B, N=48, micro=16): capture 622→71 ms (8.7×), inner 1907→208 ms
(9.2×), training forwards 2526→280 ms (9.0×).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
After M2b/M2c made the rollout cheap, the GRPO step is dominated by the per-sample
single-sequence training-side forwards: the per_token_logp captures (policy +
reference) and the inner clipped-PG forward/backwards. M2d packs all N=B·G ragged
samples of a step into ONE forward_batched.
Enabling property — right-padding is free under causal attention: a real completion
row sits at an earlier position than the trailing pad, and causal masking forbids
attending forward, so its logits equal the unpadded single-sequence forward; pad
rows are masked out (target=-100).
- ops::clipped_pg_loss_batched: like clipped_pg_loss but takes per-row advantage[t]
(the owning sample's A) and per-row weight[t] (the full normaliser). It does NOT
compute its own 1/n_tokens, so the caller passing weight=1/(N·n_s) reproduces the
looped Σ_s (1/N)(1/n_s)·clipped_pg_loss_s bit-for-bit (per-row CE backward is
row-local).
- grpo_batch.rs (shared module): per_token_logp_batched (right-pad → one
forward_batched(N) → slice back to real length) + looped baselines +
inner_pg_step_{looped,batched}. A --micro knob chunks the pack to bound the
[chunk·Lmax, vocab] logits memory; weight uses the GLOBAL N so chunked
grad-accumulation stays exact.
- train_grpo restructured to collect-all-samples-then-batch; per-window phase timers
(rollout / capture / inner) to keep the step decomposition honest. Default micro =
B·G; bench-measured 9× on the training forwards.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Device-resident KV cache: keep K/V on the GPU as [bh,T,hd], grow by one token
per step via a new cat_seq kernel (concat along seq) — removes the M2a/M2b
per-layer host round-trip (to_cpu/from_slice/re-upload) AND the transpose_3d01.
Both single-seq and batched decode refactored to it; cache is Option<Tensor>
per layer (cleaner than the host Vec + rebuild).
Gates all hold: cat_seq == host concat; decode_kv single-seq + decode_batch
G-way both still TOKEN-IDENTICAL; GQA training path unaffected.
Honest measurement (the point): removing the host round-trip buys ~10% on pure
single-seq decode (133 → 147 tok/s @128) but does NOT move the GRPO step
(~8.5 s/step unchanged) — because after M2b batching the rollout is no longer
the step's bottleneck; the per-sample per_token_logp captures + the PG-update
forwards/backwards (model.forward, full-seq) now dominate. Measure-first lesson
(cf. T11/T17/M2a): the long pole shifted to the training-side forwards; the next
decode lever (ragged batched prefill) targets those, not the cache.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
train_grpo rolls out a prompt's G samples with one generate_cached_batch call
instead of G sequential generate_cached calls. Measured on v12 1.05B (G=6, B=6,
easy task): ~8.5 s/step vs ~14-16 s/step single-seq cached — ~1.7× (rollout-
inclusive; short of G× because per_token_logp + the PG update also cost, and the
M2a host round-trip remains). Also more stable memory: one batched forward per
step vs G allocations that fragment the caching allocator.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The rollout long-pole fix deferred from M2a: decode the G samples of one prompt
in lockstep (one forward per step over the group → G× fewer kernel launches).
- rope_pos(x, positions[]): RoPE with a per-row absolute position (new forward-
only kernel) — G rows share one decode position. Gate: == full rope for
[0..n], == rope_at(P) per row for uniform P (bit-identical).
- generate_cached_batch: BatchKVCache [T, G·num_kv, hd] + batched decode_step.
decode_attention is already batch-agnostic (bh = G·nh); repeat_kv(nh, batch=G)
broadcasts per group. No finished-mask / ragged prompts yet (perf-only / next).
- Gate (tests/decode_batch.rs): all G greedy rows token-identical to the single-
sequence decode (8 query / 2 kv heads → exercises repeat_kv batching).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
train_grpo: the online, critic-free RL loop — per step sample B prompts, roll
out G completions each, score with the rule-based checker (reward 0/1), compute
group-relative advantage A=(r−mean)/(std+ε), then K inner clipped_pg_loss
epochs with a KL leash to the frozen reference. Reward = pure 0/1 correctness
(KL is the format protector, the M3 collapse lesson). Tracks mean rollout reward
(the falsifiable "it learns" signal). Periodic checkpoint save.
decode: generate_cached adds temperature sampling to the KV-cache engine (M2) —
single-row [1,vocab] logits per step vs the naive sampler's [seq,vocab], far
lighter on the caching allocator (the naive sampler fragments it over a long
rollout). generate_greedy_cached now routes through it (temp 0); decode_kv
token-identical gate still passes.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The GRPO (M4) token-level loss op + the one primitive it needs:
- scale_rows(x[r,c], s[r]): per-row scale (new ~5-line CUDA kernel). The
clipped-PG backward scales each completion token's row of (probs − onehot) by
its own per-token coefficient, which cross_entropy_backward's single scalar
scale can't express.
- clipped_pg_loss(logits, target, logp_old, logp_ref, A, eps, beta): per-token
ρ_t = exp(logπθ_t − logp_old_t), L = −mean min(ρA, clip(ρ,1±ε)A) + β·mean KL
(k3 estimator), masked to completion tokens. Backward reuses the CE machinery
(probs − onehot) + scale_rows. Gates: grad-check the active PG path + the A=0
(KL-only) path; degenerate value checks ε→∞ ⇒ vanilla PG, β=0 ⇒ no KL.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
gen_dpo_pairs: chosen = gold answer, rejected = the SFT model's own greedy
(KV-cache engine, M2a) completion when it's a format-valid WRONG boxed answer —
a hard negative from the model's distribution. ~8% of prompts skipped (greedy
correct). Writes question<TAB>chosen<TAB>rejected (bare, SFT-framed at train).
train_dpo: loads the SFT ckpt as policy AND frozen reference; precomputes the
reference logprobs ONCE (policy==ref) and caches them (one resident model). Each
step forwards the policy on chosen+rejected, seq_logprob each, minimises
dpo_loss; the two forwards share params so backward accumulates both branches.
Tracks reward margin + preference accuracy (the doc-13 "don't trust loss alone"
health signal). Loss starts at exactly log2 (Δ=0 at init) — a built-in check.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Two new ops for DPO (M3), both reusing existing kernels (no new CUDA):
- seq_logprob(logits, target): Σ log πθ(target) over non-ignored (target≥0)
positions — the per-sequence logprob DPO compares between policy and
reference. = −Σ per_row of cross_entropy (ignored rows already 0, like SFT
masking); backward = cross_entropy_backward(probs, target, −upstream) (sum,
no mean division). Gate: finite-diff grad-check with a -100 completion mask.
- dpo_loss(lpθ_chosen, lpθ_rejected, lpref_chosen, lpref_rejected, β): scalar
L = −log σ(Δ) = softplus(−Δ) with the two policy logprobs as parents (ref
logprobs constant). Gate: grad-check both parents + degenerate points
(policy==ref ⇒ Δ=0, L=log2, grads ∓β/2; β=0 ⇒ grads 0). Same formula as TRL.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Single-sequence KV-cache decode (xtrain-model/src/decode.rs): per-layer K/V
cache + single-token incremental forward (prefill = first prompt.len() decode
steps, one code path). Mirrors model::block_forward at the raw-Tensor level (no
autograd tape — inference needs no grads), using rope_at + decode_attention.
Cache is host-accumulated token-major f32, rebuilt per step (the honest M2a
baseline; M2b moves it device-side + batched ragged).
Gate (the M2 centerpiece): KV-cache greedy decode is TOKEN-IDENTICAL to the
naive full-recompute greedy — tests/decode_kv.rs (small GQA model, F32, 24
tokens) and corroborated on the v12 1.05B SFT checkpoint (cached eval =
naive eval byte-for-byte: format 100/100, correct 8/100).
eval_arith --cached A/Bs the two paths + reports decode tok/s. Measured on v12
(1.05B, batch 1, F32): the cache win is sequence-length-dependent —
max_new=32 naive 108 vs cached 111 tok/s (~1.0x; overhead-bound)
max_new=128 naive 69 vs cached 133 tok/s (~1.9x)
max_new=256 naive OOM vs cached 129 tok/s
Cached throughput stays ~constant (O(1)/token) while naive decays (O(t)/token,
O(seq^2) graph → OOM at length). Short eval prompts are overhead-bound, so the
cache matters for long rollouts (DPO/GRPO), not the arithmetic eval itself.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Two forward-only Tensor primitives the KV-cache decode engine is built on,
each gated by an isolated correctness test:
- rope_at(theta, pos0): RoPE at an absolute position (pos = pos0 + row, no
modulo) for a single decode token, vs the training rope_k (pos = row %
period) left untouched. New forward-only CUDA kernel, no training-path risk.
Gate: bit-identical to the full-sequence rope's corresponding row.
- decode_attention(k, v, scale): single-query × cached-K/V SDPA, composed from
the existing strided batched GEMM + plain (non-causal) softmax — no new
kernel. Gate: equals the full causal attention's last query row (max |Δ| 6e-8).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
eval_arith: load ckpt, greedy-generate per held-out prompt, parse \boxed{}
via the shared task checker, report format(boxed) + correctness pass-rates.
Reused as the verifiable-eval harness for M3 (DPO) / M4 (GRPO).
M1 result (100 held-out prompts, v12 1.05B base): SFT moves answer-format
adherence 0% -> 100%, arithmetic correctness 8% -- the intended split (SFT
buys the format; correctness is the verifiable-reward job of M3/M4). Logged
in docs/18 implementation log + a Phase-3 row in docs/evolution.md.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The default operand ranges (max_add=99, max_mul=12) gave only ~20k unique
problems, so 'gen_arith_task --n 20000 --eval 500' (a) made train dedup
pathologically slow near saturation and (b) made the disjoint-eval loop never
terminate. A background run stalled after ~10k train rows with no eval files.
Fix (root cause, not a workaround):
- enlarge default ranges to max_add=999, max_mul=99 (~2.01M key space) so 20k+
requests are a tiny fraction and dedup stays trivial;
- add unique_space() + a generator guard that errors clearly when n+eval exceeds
80% of the key space, instead of looping forever.
Verified: cargo test 10/10; full 20000/500 gen now 0.2s, all 3 files, 0
train/eval leakage; guard panics on an oversized (--max-add 99) request.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
First post-training milestone (docs/18). Lands the verifiable task + its data
pipeline, all verified host-side (no CUDA); the SFT run itself reuses the
existing --sft-tsv path on the GPU box.
- task.rs: the shared task spec — two-operand integer arithmetic, answer in
\boxed{N}, with parse_boxed_answer + check_answer (exact-match rule-based
reward). One module reused by M1 (SFT data), M3 (DPO pairs), M4 (GRPO reward).
- gen_arith_task bin: writes arith_sft.tsv (--sft-tsv format) + held-out
arith_eval_prompts.txt (greedy_sample format) + arith_eval_gold.txt; train
deduped, eval disjoint from train.
- data.rs: extract assistant-only masking into a pure, testable sft_row()
(behavior-preserving; single-turn bit-identical to fbf4ac2).
Gate (verified locally, no_cuda): cargo test -p xtrain-train --lib = 9/9 pass
(masking, SFT-target self-consistency over 2000 samples, parser edges, seed
determinism); a 200/50 gen run = clean 2-col TSV, correct gold incl. negatives,
0 train/eval leakage. SFT training run + format-eval pending on dash5.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Enable assistant-only supervised fine-tuning and a fixed chat-prompt eval path
used by the v12 SFT runs:
- cross_entropy ignores negative targets (-100 ignore-index), normalizing by
valid rows instead of all rows; CUDA fwd/bwd skip t<0 (ops.rs, nn.cu).
- Corpus gains optional labels + load_sft_tsv_cached: two-column TSV is
formatted as 'User: .. \nAssistant:' + answer + <|endoftext|>, prompt tokens
masked to -100 while answer+EOS are supervised; i32 label cache alongside the
u16 token cache; sample() retries windows that are fully masked; eval uses
target_window so masking applies to val loss too (data.rs, train_loop.rs).
- train + train_ddp: --sft-tsv selects the TSV loader, --init-ckpt continues
training from a base checkpoint.
- greedy_sample: --prompts-file/--prompt/--temperature for fixed chat-prompt
generation eval.
Test fixtures updated for the new Corpus.labels field; dropout.rs carries
incidental rustfmt. Not rebuilt locally (no CUDA toolchain on this checkout);
correctness rests on the documented v12 base+SFT runs on the GPU box.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Adds ddp_dropout_is_live_and_p0_bit_identical, run via the real launcher
path (DdpContext::init + train_rank). It would have caught the original bug:
- GATE A (world=1, ONE step — the deterministic scope): the p=0 FORWARD is
byte-identical to no-dropout (ops::dropout(p=0) is a graph no-op) so the
step loss is BIT-IDENTICAL (== 0.0). At world=1 the NCCL all-reduce
short-circuits and one step has no optimizer-state compounding; the only
residual non-determinism is the engine's atomicAdd backward-reduction
order (the documented fresh-train md5 caveat — dropout-independent), so the
post-step params are checked against that tight ULP floor (< 1e-7).
- GATE A2 (world=2): p=0 matches a separate no-dropout baseline within NCCL's
run-to-run ULP noise (< 1e-6, KI-5 — the all-reduce is not bit-reproducible
on this PCIe box). Enabling dropout=0 doesn't perturb the DDP path beyond it.
- GATE B (world=2): a p=0.2 run's loss trace DIFFERS by > 1e-3 from p=0 —
orders of magnitude above every noise floor here (~3e-2 observed). On the
pre-T21 code the model stays in eval mode, so p=0.2 would be an identity and
the trace would match p=0 at the noise floor — this gate fails. (Verified by
simulating the bug: with model.train() removed, GATE B drops to 2.4e-7.)
- GATE C: a dedicated no-eval run ends with model.is_training() == true,
direct proof that train_rank called model.train().
- p>0 run is finite (no NaN/Inf).
eval_every < steps so a periodic eval fires mid-run (flipping to eval mode),
exercising the per-step model.train() restore discipline the pilot called out.
Run with --test-threads=1 like the other DDP tests (shared-GPU deadlock).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
V9-PILOT caught a launcher-level integration gap: T18 wired dropout into
the single-GPU bin/train, but the DDP path never did. train_ddp had no
--dropout flag and never set cfg.dropout, and ddp.rs::train_rank never
called model.train() — so under DDP every forward ran in the default eval
mode and dropout was a silent identity, regardless of config.
Fix, mirroring the single-GPU train/eval discipline:
- train_ddp.rs: add a --dropout <p> flag (default 0 = off, matching the
prior behavior) and set cfg.dropout from it; log it when on.
- ddp.rs::train_rank: call model.train() at the start of each step (before
the micro-batch loop). eval_loss() flips the model to eval mode and does
not restore it, so re-asserting train() each step keeps dropout live
across eval boundaries.
--dropout 0 (default) is bit-identical to the prior DDP path: cfg.dropout
stays 0 and ops::dropout(p=0) is a clone no-op regardless of training mode.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Dual-mode binary self-detecting via XTRAIN_RANK: launcher spawns one worker
per visible GPU forwarding full argv; worker rebuilds config from argv and runs
run_worker. CLI flags identical to train_ddp (thread-per-GPU, kept), so it
doubles as the before->after throughput driver. thread-per-GPU path untouched.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
torchrun-style process-per-GPU: launch_processes spawns one worker process per
GPU (re-exec current_exe with XTRAIN_{RANK,WORLD,LOCAL_RANK,NCCL_ID} env),
mints the ncclUniqueId once in the launcher and hex-injects it via env (no
shared FS/TCP, race-free). worker_env/run_worker read the env, bind the device
(own CUDA context), DdpContext::init + build_model + train_rank reused from T8
UNCHANGED. hex_encode/decode_unique_id are host-testable pure fns.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
- repeat_kv CUDA kernel: fwd head-block gather, bwd DETERMINISTIC group-sum (each
kv head sums its group of query-head grads; no atomics) + Tensor/ops node.
- Config gains num_kv_heads (default = n_heads → MHA); wk/wv project to kv_dim;
attention() repeat_kv-broadcasts K/V to nh heads before the UNCHANGED composed
& flash SDPA → GQA on both paths. group=1 is identity → MHA bit-identical.
- --kv-heads flag on train/train_ddp/export_safetensors/greedy_sample; export
writes real num_key_value_heads (xserv repeat_kv grouping aligned).
- Tests: repeat_kv grad-check (group>1 grad-sum + group=1 identity); model gqa.rs
(GQA flash==composed fp32/bf16, group=1 bit-identical to MHA, kv-proj shape);
parity_dump+parity.py GQA path (repeat_interleave) via XTRAIN_PARITY_KV_HEADS.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add flash_plus_dropout_grad_check_fp32 to xtrain-model dropout tests: the two
orthogonal Phase-2 features (T14 flash-attn, T18 dropout) in the same model must
still grad-check. Both models run train-mode p=0.2 (identical masks, seed is
flash-independent) so the only delta is the SDPA reduction order — checked against
the flash-vs-composed tolerance.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Config.dropout (default 0). TinyTransformer gets a Cell<bool> training switch
(train()/eval()/with_training, default eval = safe) + a Cell<u64> step_seed bumped
once per training forward. forward_batched derives a per-layer block_seed (pure fn
of step_seed×layer) and block_forward derives two per-site seeds, inserting
ops::dropout at the attn and ffn sub-block outputs (before each residual). The
seed is a pure function of (step_seed, layer, site) so the checkpoint (T13)
recompute re-derives the same masks → grads stay exact. p=0 or eval → no dropout
node → graph bit-identical to pre-T18.
train_loop: model.train() per step (restored after eval flips to eval); eval_loss
runs model.eval(). bin/train: --dropout flag → cfg.dropout. Export/sampling run in
eval (default), so exported weights are dropout-free (xserv closed loop unaffected).
Model-level tests (dropout.rs): p=0 bit-identical to no-dropout (logits/loss/grads);
eval(p>0) == p=0 identity; train differs from eval + finite; recompute-with-dropout
grads match non-recompute (fp32 + bf16).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
ops::dropout(x,p,seed): fwd runs Tensor::dropout, caches the mask in the backward
closure, bwd pushes dx=d⊙mask. p==0 returns x.clone() (no node) so the default
graph is unchanged. Tests in autograd.rs: fixed-seed finite-diff grad-check (mask
held constant across the ± perturbation — dropout is a fixed elementwise linear
map of x); E[out]≈input + keep-rate≈1-p over a seed sweep; p=0 kernel identity.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
csrc/ops/dropout.cu: counter-based RNG (splitmix64 over seed^index) → fp32
uniform → Bernoulli(keep=1-p); fwd writes out=x⊙mask + an fp32 mask buffer
(per-element 1/(1-p) or 0); bwd applies the same mask (dx=d⊙mask). fp32 + bf16
activation variants (mask fp32 in both; uniform is dtype-independent so masks
match across precisions). Stateless → re-run with same seed = same mask (T13
recompute-safe). Registered in build.rs + FFI decls.
Tensor::dropout(p,seed)->(out,mask) and Tensor::dropout_backward(d,mask) wrap the
launches (contiguous F32/BF16, default stream, per-op sync via the kernels).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
- grad_accum.rs: accum=N×B grads bit-close to a single N·B big batch;
accum_steps=1 bit-identical (max|Δ|==0) to no-accum; real train() loop
with accum tracks a big-batch baseline over 20 AdamW steps.
- ddp_correctness.rs: world=2 + accum=2 matches a single-GPU big batch of
the same effective size (loss + cross-rank + vs-baseline).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Accumulate grads over N micro-batches, then one AdamW step + zero_grad,
for an effective batch of N×micro at one micro-batch's activation cost.
Each micro-loss is scaled by 1/N before backward (the tape SUM-accumulates
the scaled grads) so the boundary grad equals a single step over an N×
batch. accum==1 skips the scale → bit-identical to the pre-T16 path.
DDP: the cross-rank all-reduce fires ONLY at the accumulation boundary
(intermediate micro-steps are local-only, no NCCL); the /world average is
orthogonal to the per-micro 1/N, so the boundary grad is the effective
global-batch mean. New --accum-steps flag in both train binaries; effective
batch is printed.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Match the trusted composed grad-check dims (seq=5<FA_TILE); the multi-tile
online-softmax path is gated by flash_bwd_matches_composed_bwd (seq=40),
sharper than finite-diff on the near-zero grads a long softmax produces.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
autograd: flash_attention_batched_bwd (dQ/dK/dV finite-diff, seq>tile)
+ flash_matches_composed_fwd. model/tests/flash.rs: flash==composed
on-vs-off (logits/loss/every param grad), fp32 + bf16. parity_dump:
XTRAIN_PARITY_FLASH dumps the flash path for the same parity.py oracle
(PyTorch SDPA parity at B>1). train + train_ddp get the --flash flag.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
ops::flash_attention autograd node (fwd caches O(N) logsumexp instead of
O(N²) probs; bwd via Tensor::flash_attention_backward). Model gets a
use_flash bool + with_flash(bool) builder; the SDPA core in attention()
picks ops::flash_attention vs ops::attention. flash threads through
block_forward so the recompute (T13) segment also runs flash. Default
off = composed path, graph unchanged.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
csrc/ops/flash_attention.cu: a single fused fwd kernel (one block per
query row, streams KV in tiles of 32, online softmax — running max/sum
+ rescaled V accumulator, causal mask inlined, never materializes the
[bh,S,S] scores) writing out[bh,S,hd] + the per-row logsumexp L (O(N),
saved for backward). flash-style bwd: recompute scores from Q/K/V + L,
collapse the softmax Jacobian with D[i]=ΣdO·O, dQ owned per row, dK/dV
atomicAdd across rows. Tensor::flash_attention / flash_attention_backward
wrap them (bf16 upcasts Q/K/V→f32 for the kernel, same fp32-softmax
policy as composed).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Design doc for per-block gradient checkpointing (KI-3): the no-tape forward +
recompute-on-backward design, the `checkpoint` primitive, per-block wrapping,
the exactness/correctness argument (same kernels + inputs → identical grads),
composition with bf16+DDP+batched, and the verification plan (on-vs-off grad
gate + memory/throughput before→after, dim1024-fits). Bench table left as TBD
to fill after the dash5 run.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Wrap each transformer block's forward in the checkpoint primitive when
recompute is enabled (Phase T13 / KI-3). To make the block forward a pure
segment fn (no `&self` borrow, so it can re-run in the backward closure),
extract the block body + its helpers (linear / norm_gamma / attention /
swiglu_mlp) into free functions parameterised by (cfg, compute_dtype) and add
`Block::block_params()` (the 11 leaves in the params() per-block order). The
non-recompute path calls `block_forward` directly — identical graph to before.
- `TinyTransformer::with_recompute(bool)` builder (opt-in; default off keeps the
unchanged tape / bit-identical numerics).
- `--recompute` flag wired into bin/train and bin/train_ddp (DDP: each rank
checkpoints independently).
Correctness gate: tests/recompute.rs builds two identical models (recompute
on/off), runs the same batched loss+backward, and asserts the forward logits,
the loss, and EVERY parameter grad match within tight fp tol — parameterised
over fp32 and bf16 (T12 composition).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add `xtrain_autodiff::checkpoint::checkpoint(segment_fn, input, params)`, a
higher-order autograd node (à la torch.utils.checkpoint) for activation
recomputation (Phase T13 / KI-3):
- forward: run `segment_fn` on detached leaves so its internal ops are NOT
recorded on the outer tape; keep only the output value (the local sub-tape —
and thus the segment's intermediate activations — drops immediately). The
checkpoint node's parents are [input, ..params].
- backward: re-run `segment_fn` from the saved input + (unchanged) param values
into a fresh local tape, seed the recomputed output with the upstream grad,
backprop, then push the recovered input/param grads to the real parents. Local
tape drops at the end → recomputed activations freed.
Exact by construction (same deterministic kernels, same inputs) → grads match
the non-checkpointed path. Composes with bf16 (T12, same path on recompute) and
DDP (T8, per-rank).
Supporting change: `Var::backward_seeded(seed)` — backward from an explicit
non-scalar upstream grad (the segment output is generally not a scalar);
`backward()` is now the scalar wrapper that seeds ones.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The `keep bf16 logits` change made forward_batched return bf16 logits
in bf16 mode; the bf16 test's host read must cast to f32 first.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
At vocab 50257 the logits tensor [B*S, vocab] is ~1.6GB fp32 at batch
32 — held across the whole backward. Keep it bf16: cross_entropy
upcasts the bf16 logits to fp32 internally (transient) + caches fp32
probs, and its backward casts dx back to bf16 to chain into the
bf16 lm_head matmul backward. The sampler casts bf16 logits→f32 before
the host argmax/softmax. Halves the persistent logits activation.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
- TinyTransformer::with_compute_dtype(BF16): embedding stays fp32
master then casts to bf16; each linear casts its fp32 weight to bf16
on the fly; logits cast back to fp32 for cross-entropy. Default F32
reproduces the v0-v4 forward graph bit-for-bit.
- --bf16 flag on bin/train and bin/train_ddp (off by default).
- tests/bf16.rs: same fp32 master weights run fp32 vs bf16; assert
loss/logits/grads within a loose bf16 tol, no NaN, and grads are
fp32 (master untouched).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Tensor ops dispatch on dtype: fp32 branch unchanged (bit-identical),
bf16 branch routes matmul/attention through GemmEx and elementwise
through the bf16 kernels. Norm/softmax/RoPE/cross-entropy upcast to
fp32 around the existing fp32 kernels (standard AMP: reductions/loss
fp32, matmuls bf16). Transposes route bf16 through fp32 (pure layout).
New autodiff `cast` op is the AMP bridge: forward downcasts a fp32
master leaf to bf16 for the matmul; backward upcasts the bf16 grad
back to fp32. So the fp32 leaf accumulates an fp32 grad and AdamW /
clip / DDP all-reduce stay fp32 and completely unchanged.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>