- repeat_kv CUDA kernel: fwd head-block gather, bwd DETERMINISTIC group-sum (each
kv head sums its group of query-head grads; no atomics) + Tensor/ops node.
- Config gains num_kv_heads (default = n_heads → MHA); wk/wv project to kv_dim;
attention() repeat_kv-broadcasts K/V to nh heads before the UNCHANGED composed
& flash SDPA → GQA on both paths. group=1 is identity → MHA bit-identical.
- --kv-heads flag on train/train_ddp/export_safetensors/greedy_sample; export
writes real num_key_value_heads (xserv repeat_kv grouping aligned).
- Tests: repeat_kv grad-check (group>1 grad-sum + group=1 identity); model gqa.rs
(GQA flash==composed fp32/bf16, group=1 bit-identical to MHA, kv-proj shape);
parity_dump+parity.py GQA path (repeat_interleave) via XTRAIN_PARITY_KV_HEADS.
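
A minimal host-side reference for the repeat_kv semantics described above, in
the spirit of the parity.py repeat_interleave oracle. This is a sketch under
assumptions, not the CUDA kernel: heads are plain Python lists, and the function
names and layout are illustrative only.

```python
def repeat_kv(kv, group):
    """Forward: repeat each kv head `group` times, contiguously (repeat_interleave)."""
    return [list(head) for head in kv for _ in range(group)]


def repeat_kv_backward(d_out, group):
    """Backward: kv head j sums the grads of query heads j*group .. (j+1)*group-1.

    The sum runs in a fixed order with one owner per output element, which is
    what makes it deterministic (no atomics needed).
    """
    n_kv = len(d_out) // group
    d_kv = []
    for j in range(n_kv):
        block = d_out[j * group:(j + 1) * group]
        d_kv.append([sum(col) for col in zip(*block)])
    return d_kv
```

With group=1 the forward returns the input heads unchanged, which is the
MHA-identity case the tests pin down.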
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add flash_plus_dropout_grad_check_fp32 to the xtrain-model dropout tests: the two
orthogonal Phase-2 features (T14 flash-attn, T18 dropout) must still grad-check
when enabled in the same model. Both models run in train mode with p=0.2
(identical masks, since the seed is flash-independent), so the only delta is the
SDPA reduction order, which is checked against the flash-vs-composed tolerance.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Config.dropout (default 0). TinyTransformer gets a Cell<bool> training switch
(train()/eval()/with_training, default eval = safe) + a Cell<u64> step_seed bumped
once per training forward. forward_batched derives a per-layer block_seed (pure fn
of step_seed×layer) and block_forward derives two per-site seeds, inserting
ops::dropout at the attn and ffn sub-block outputs (before each residual). The
seed is a pure function of (step_seed, layer, site) so the checkpoint (T13)
recompute re-derives the same masks → grads stay exact. p=0 or eval → no dropout
node → graph bit-identical to pre-T18.
train_loop: calls model.train() at every step (so train mode is restored after
eval_loss has switched the model to eval); eval_loss runs under model.eval().
bin/train: --dropout flag → cfg.dropout. Export/sampling run in eval (the
default), so exported weights are dropout-free (xserv closed loop unaffected).
Model-level tests (dropout.rs): p=0 bit-identical to no-dropout (logits/loss/grads);
eval(p>0) == p=0 identity; train differs from eval + finite; recompute-with-dropout
grads match non-recompute (fp32 + bf16).
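
Why a pure seed function keeps recompute exact can be sketched as follows. The
hash here (blake2b) is a stand-in chosen for illustration, not the derivation
the model uses; block_seed, site_seed and the site constants are hypothetical
names.

```python
import hashlib
import struct

ATTN_SITE = 0  # hypothetical site ids for the two dropout sites in a block
FFN_SITE = 1


def _mix(a, b):
    """Deterministically mix two u64 values into one u64 (stand-in hash)."""
    digest = hashlib.blake2b(struct.pack("<QQ", a, b), digest_size=8).digest()
    return int.from_bytes(digest, "little")


def block_seed(step_seed, layer):
    """Per-layer seed: a pure function of (step_seed, layer)."""
    return _mix(step_seed, layer)


def site_seed(blk_seed, site):
    """Per-site seed: a pure function of (block_seed, site)."""
    return _mix(blk_seed, site)
```

Because no hidden RNG state advances between calls, a checkpointed block that
re-runs its forward during backward derives the same seeds and so the same masks.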
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
ops::dropout(x,p,seed): fwd runs Tensor::dropout, caches the mask in the backward
closure, bwd pushes dx=d⊙mask. p==0 returns x.clone() (no node) so the default
graph is unchanged. Tests in autograd.rs: fixed-seed finite-diff grad-check (mask
held constant across the ± perturbation — dropout is a fixed elementwise linear
map of x); E[out]≈input + keep-rate≈1-p over a seed sweep; p=0 kernel identity.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
csrc/ops/dropout.cu: counter-based RNG (splitmix64 over seed^index) → fp32
uniform → Bernoulli(keep=1-p); fwd writes out=x⊙mask + an fp32 mask buffer
(per-element 1/(1-p) or 0); bwd applies the same mask (dx=d⊙mask). fp32 + bf16
activation variants (mask fp32 in both; uniform is dtype-independent so masks
match across precisions). Stateless → re-run with same seed = same mask (T13
recompute-safe). Registered in build.rs + FFI decls.
Tensor::dropout(p,seed)->(out,mask) and Tensor::dropout_backward(d,mask) wrap the
launches (contiguous F32/BF16, default stream, per-op sync via the kernels).
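
A host-side sketch of the counter-based mask generation described above. The
splitmix64 constants are the standard ones; how the kernel turns the 64-bit hash
into an fp32 uniform (top 24 bits here) and the exact keep comparison are
assumptions, as are the function names.

```python
MASK64 = (1 << 64) - 1


def splitmix64(x):
    """Standard splitmix64 step, used as a stateless hash of the counter."""
    z = (x + 0x9E3779B97F4A7C15) & MASK64
    z = ((z ^ (z >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    z = ((z ^ (z >> 27)) * 0x94D049BB133111EB) & MASK64
    return z ^ (z >> 31)


def dropout_forward(x, p, seed):
    """Return (out, mask); mask holds 1/(1-p) for kept elements and 0 otherwise."""
    scale = 1.0 / (1.0 - p)
    mask = []
    for i in range(len(x)):
        # Assumption: top 24 bits of the hash give a uniform in [0, 1).
        u = (splitmix64(seed ^ i) >> 40) / float(1 << 24)
        mask.append(scale if u >= p else 0.0)  # keep with probability 1-p
    out = [xi * mi for xi, mi in zip(x, mask)]
    return out, mask


def dropout_backward(d, mask):
    """Backward applies the same mask: dx = d * mask."""
    return [di * mi for di, mi in zip(d, mask)]
```

Since the mask depends only on (seed, index), calling dropout_forward again with
the same seed reproduces it, which is the property T13 recompute relies on.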
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Counter-based (stateless) RNG → Bernoulli(keep=1-p) mask, inverted 1/(1-p)
scaling at train, identity at eval. New autodiff `dropout` op (fwd generates +
applies mask, bwd applies the SAME cached mask). Wired at the two residual-path
sites (attn / ffn outputs); attention-probs dropout deliberately skipped (fused
SDPA doesn't materialise probs). Documents the RNG choice, per-site deterministic
seed (so T13 recompute reproduces the same mask), train/eval switch, p=0
bit-identity, and the acceptance gates.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
- grad_accum.rs: grads accumulated over N micro-batches of size B are bit-close
to a single batch of size N·B; accum_steps=1 is bit-identical (max|Δ|==0) to
no-accum; the real train() loop with accum tracks a big-batch baseline over 20
AdamW steps.
- ddp_correctness.rs: world=2 + accum=2 matches a single-GPU big batch of
the same effective size (loss + cross-rank + vs-baseline).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Accumulate grads over N micro-batches, then one AdamW step + zero_grad,
for an effective batch of N×micro at one micro-batch's activation cost.
Each micro-loss is scaled by 1/N before backward (the tape SUM-accumulates
the scaled grads) so the boundary grad equals a single step over an N×
batch. accum==1 skips the scale → bit-identical to the pre-T16 path.
DDP: the cross-rank all-reduce fires ONLY at the accumulation boundary
(intermediate micro-steps are local-only, no NCCL); the /world average is
orthogonal to the per-micro 1/N, so the boundary grad is the effective
global-batch mean. New --accum-steps flag in both train binaries; effective
batch is printed.
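
The 1/N scaling argument can be checked on a toy problem. This is a sketch
assuming equal-sized micro-batches and a scalar model with squared-error loss;
mean_grad, accumulated_grad and ddp_boundary_grad are illustrative names, and
the all-reduce is modelled as a plain average over ranks.

```python
def mean_grad(w, batch):
    """Grad of the batch-mean loss mean((w*x - y)^2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, micro_batches):
    """Sum the grads of micro-losses that were each scaled by 1/N."""
    n = len(micro_batches)
    if n == 1:
        # accum == 1 skips the scale, so the result is bit-identical.
        return mean_grad(w, micro_batches[0])
    acc = 0.0
    for mb in micro_batches:
        acc += mean_grad(w, mb) / n  # the tape sum-accumulates scaled grads
    return acc


def ddp_boundary_grad(w, per_rank_micro_batches):
    """Each rank accumulates locally; one averaging all-reduce at the boundary."""
    local = [accumulated_grad(w, mbs) for mbs in per_rank_micro_batches]
    return sum(local) / len(local)
```

With equal micro-batch sizes, both results match the mean gradient over the full
effective batch up to floating-point reduction order.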
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Fill in the design doc's measured results (grad-check, flash==composed,
PyTorch parity, peak mem -16%/-23%, tok/s tradeoff), add the T14 row to
evolution.md (Algorithm/Infra) and the README build-journey table.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Match the trusted composed grad-check dims (seq=5<FA_TILE); the multi-tile
online-softmax path is gated by flash_bwd_matches_composed_bwd (seq=40), which
is a sharper check than finite-diff on the near-zero grads a long softmax
produces.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
autograd: flash_attention_batched_bwd (dQ/dK/dV finite-diff, seq>tile)
+ flash_matches_composed_fwd. model/tests/flash.rs: flash==composed
on-vs-off (logits/loss/every param grad), fp32 + bf16. parity_dump:
XTRAIN_PARITY_FLASH dumps the flash path for the same parity.py oracle
(PyTorch SDPA parity at B>1). train + train_ddp get the --flash flag.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
ops::flash_attention autograd node (fwd caches O(N) logsumexp instead of
O(N²) probs; bwd via Tensor::flash_attention_backward). Model gets a
use_flash bool + with_flash(bool) builder; the SDPA core in attention()
picks ops::flash_attention vs ops::attention. flash threads through
block_forward so the recompute (T13) segment also runs flash. Default
off = composed path, graph unchanged.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
csrc/ops/flash_attention.cu: a single fused fwd kernel (one block per
query row, streams KV in tiles of 32, online softmax — running max/sum
+ rescaled V accumulator, causal mask inlined, never materializes the
[bh,S,S] scores) writing out[bh,S,hd] + the per-row logsumexp L (O(N),
saved for backward). flash-style bwd: recompute scores from Q/K/V + L,
collapse the softmax Jacobian with D[i]=ΣdO·O, dQ owned per row, dK/dV
atomicAdd across rows. Tensor::flash_attention / flash_attention_backward
wrap them (bf16 upcasts Q/K/V→f32 for the kernel, same fp32-softmax
policy as composed).
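
The online-softmax recurrence for one query row can be sketched on the host as
below. This is a reference under assumptions, not the kernel: it handles a single
row, takes an explicit `scale` argument (not mentioned above), and leaves the
causal mask to the caller, who would pass only keys[:i+1] for row i.

```python
import math


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_row_reference(q, keys, values, scale):
    """Plain softmax attention for one row; materializes all scores."""
    scores = [scale * _dot(q, k) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    hd = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(hd)]
    return out, m + math.log(total)


def flash_row(q, keys, values, scale, tile=32):
    """Stream KV in tiles, keeping a running max, running sum and rescaled V accumulator."""
    hd = len(values[0])
    m, l = float("-inf"), 0.0
    acc = [0.0] * hd
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        scores = [scale * _dot(q, k) for k in k_tile]
        m_new = max(m, max(scores))
        corr = math.exp(m - m_new)  # rescales old state; 0.0 on the first tile
        probs = [math.exp(s - m_new) for s in scores]
        l = l * corr + sum(probs)
        acc = [a * corr + sum(p * v[d] for p, v in zip(probs, v_tile))
               for d, a in enumerate(acc)]
        m = m_new
    out = [a / l for a in acc]
    return out, m + math.log(l)  # per-row logsumexp, saved for backward
```

Only O(tile) scores exist at any time, and the returned logsumexp is the single
per-row scalar the backward needs to rebuild the probabilities.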
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Design doc for the hand-written single fused flash-attention kernel:
online softmax tiled over KV, NEVER materializing the [bh,S,S] score
matrix; flash-style backward (recompute scores from saved logsumexp +
D=ΣdO·O, dQ/dK/dV). Opt-in --flash; composed T10 path stays default.
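
For reference, the flash-style backward summarised above corresponds to the
standard identities below. The notation is an assumption made for this note:
α is the softmax scale, L_i the saved per-row logsumexp, and all sums over j
are restricted to j ≤ i by the causal mask.

```latex
\begin{aligned}
S_{ij} &= \alpha\, Q_i \cdot K_j, \qquad P_{ij} = \exp\!\left(S_{ij} - L_i\right) \\
D_i &= \sum_{d} dO_{id}\, O_{id} \\
dV_j &= \sum_{i} P_{ij}\, dO_i \\
dP_{ij} &= dO_i \cdot V_j \\
dS_{ij} &= P_{ij}\left(dP_{ij} - D_i\right) \\
dQ_i &= \alpha \sum_{j} dS_{ij}\, K_j, \qquad
dK_j = \alpha \sum_{i} dS_{ij}\, Q_i
\end{aligned}
```

D_i equals Σ_j P_ij dP_ij, so the softmax Jacobian collapses to one scalar per
row computed from dO and O alone, without ever storing P.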
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
v8 = capacity-axis A/B: freeze the v6/v7 2.255B FineWeb-edu subset, scale
dim768→dim1024 (core 127M→226M, +78%) via bf16 + T13 activation recompute.
8-GPU DDP, 2.36B tok (1.05 ep), ~129K tok/s (recompute tax), ~5h.
Result (same FineWeb val, v6/v7/v8 comparable): v6 3.0652 / v7 3.0149 /
v8 2.9801. Capacity helps: v8 (1.05 ep) beats v6 at the same ~1 ep by 0.085
and beats v7 (smaller model, 1.45 ep over the same old data) by 0.035 ⇒ v6/v7
were partly capacity-limited, and scaling capacity beats repeating old data. But
the gain is only ~3% (same magnitude as the data-axis single-step lever), and
v8's val was still descending at the end (not saturated).
Meta-finding: every single-axis lever (data-volume v5/v7, breadth v6,
capacity v8) is now ~3%/lever ⇒ broad diminishing returns; to progress,
scale capacity AND data together (Chinchilla, reproduced at toy scale).
- docs/runs/08-v8-fineweb-edu-dim1024.md: full capacity experiment + v7-vs-v8 samples
- docs/runs/README.md: +v8 row, v9 proposal
- docs/evolution.md: +T13 infra row, +v8 scaling row, capacity-axis & diminishing-returns notes
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Design doc for per-block gradient checkpointing (KI-3): the no-tape forward +
recompute-on-backward design, the `checkpoint` primitive, per-block wrapping,
the exactness/correctness argument (same kernels + inputs → identical grads),
composition with bf16+DDP+batched, and the verification plan (on-vs-off grad
gate + memory/throughput before→after, dim1024-fits). Bench table left as TBD
to fill after the dash5 run.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Wrap each transformer block's forward in the checkpoint primitive when
recompute is enabled (Phase T13 / KI-3). To make the block forward a pure
segment fn (no `&self` borrow, so it can re-run in the backward closure),
extract the block body + its helpers (linear / norm_gamma / attention /
swiglu_mlp) into free functions parameterised by (cfg, compute_dtype) and add
`Block::block_params()` (the 11 leaves in the params() per-block order). The
non-recompute path calls `block_forward` directly — identical graph to before.
- `TinyTransformer::with_recompute(bool)` builder (opt-in; default off keeps the
unchanged tape / bit-identical numerics).
- `--recompute` flag wired into bin/train and bin/train_ddp (DDP: each rank
checkpoints independently).
Correctness gate: tests/recompute.rs builds two identical models (recompute
on/off), runs the same batched loss+backward, and asserts the forward logits,
the loss, and EVERY parameter grad match within tight fp tol — parameterised
over fp32 and bf16 (T12 composition).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add `xtrain_autodiff::checkpoint::checkpoint(segment_fn, input, params)`, a
higher-order autograd node (à la torch.utils.checkpoint) for activation
recomputation (Phase T13 / KI-3):
- forward: run `segment_fn` on detached leaves so its internal ops are NOT
recorded on the outer tape; keep only the output value (the local sub-tape —
and thus the segment's intermediate activations — drops immediately). The
checkpoint node's parents are [input, ..params].
- backward: re-run `segment_fn` from the saved input + (unchanged) param values
into a fresh local tape, seed the recomputed output with the upstream grad,
backprop, then push the recovered input/param grads to the real parents. Local
tape drops at the end → recomputed activations freed.
Exact by construction (same deterministic kernels, same inputs) → grads match
the non-checkpointed path. Composes with bf16 (T12, same path on recompute) and
DDP (T8, per-rank).
Supporting change: `Var::backward_seeded(seed)` — backward from an explicit
non-scalar upstream grad (the segment output is generally not a scalar);
`backward()` is now the scalar wrapper that seeds ones.
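
The forward/backward contract above can be illustrated with a toy scalar
autograd. This is a sketch only: the Var class here is a stand-in written for
this note, not the xtrain_autodiff type, and it supports just add and mul.

```python
class Var:
    """Toy scalar autograd node (stand-in for illustration)."""

    def __init__(self, value, parents=(), backward_fn=None):
        self.value = value
        self.grad = 0.0
        self.parents = parents
        self.backward_fn = backward_fn  # upstream grad -> tuple of parent grads

    def __add__(self, other):
        return Var(self.value + other.value, (self, other), lambda g: (g, g))

    def __mul__(self, other):
        return Var(self.value * other.value, (self, other),
                   lambda g: (g * other.value, g * self.value))

    def backward_seeded(self, seed):
        """Backward from an explicit upstream grad instead of an implicit 1."""
        order, seen = [], set()

        def visit(v):
            if id(v) in seen:
                return
            seen.add(id(v))
            for p in v.parents:
                visit(p)
            order.append(v)

        visit(self)
        self.grad += seed
        for v in reversed(order):  # consumers before producers
            if v.backward_fn is not None:
                for p, g in zip(v.parents, v.backward_fn(v.grad)):
                    p.grad += g

    def backward(self):
        self.backward_seeded(1.0)


def checkpoint(segment_fn, x, params):
    # Forward: run on detached leaves so nothing inside is recorded outside;
    # only the output value is kept.
    out_value = segment_fn(Var(x.value), [Var(p.value) for p in params]).value

    def backward_fn(g):
        # Backward: recompute on a fresh local tape, seed it with the upstream
        # grad, and hand the recovered grads to the real parents.
        x_local = Var(x.value)
        p_local = [Var(p.value) for p in params]
        segment_fn(x_local, p_local).backward_seeded(g)
        return (x_local.grad, *[p.grad for p in p_local])

    return Var(out_value, (x, *params), backward_fn)
```

The checkpoint node's parents are exactly [input, ..params], so a parameter used
both inside and outside the segment still accumulates both contributions.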
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
v7 = same arch as v4/v5/v6 (dim768/18L, bf16, 8-GPU DDP global 256),
trained on the SAME 2.255B-token FineWeb-edu subset for 1.45 epochs (vs v6's
1.02), best FineWeb val 3.0149 (v6 3.0652). Exported + archived to
registry v7-fineweb-edu-dim768, serves in xserv (coherent expository
English, ~v6 quality).
Key finding: more epochs of the SAME subset gave only ~0.05 val drop and
the curve flattened (~step 44000) with no sampling quality gain → the
2.255B FineWeb subset is near its ceiling at dim768. Same class as v5's
TinyStories data-volume saturation: repeating old data has thin margins;
true further gains need FRESH shards (more diverse tokens), as v6's
corpus-swap (which raised the ceiling) showed.
Adds docs/runs/07-v7-*.md; updates docs/runs/README.md (+v7 row, intro
saturation note, v8 proposal) and docs/evolution.md (+v7 row, dataset-axis
ceiling note).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
v6 broadens data from TinyStories to FineWeb-edu (HuggingFaceFW/fineweb-edu
sample/10BT) while freezing the v4/v5 arch. scripts/fineweb_to_txt.py streams
the parquet text column row-group by row-group and joins docs with
<|endoftext|> so xtrain's existing Corpus loader (gpt2 BPE, u16 cache) handles
it unchanged. Corpus .txt/.parquet/.u16.bin stay dash5-only (gitignored).
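
A rough sketch of the streaming conversion described above, not the contents of
scripts/fineweb_to_txt.py. It assumes pyarrow is installed and that the column
is named "text"; the function names, and placing the separator strictly between
documents, are illustrative choices.

```python
EOT = "<|endoftext|>"


def iter_parquet_texts(path, column="text"):
    """Yield documents one row group at a time, so the file is never fully loaded."""
    import pyarrow.parquet as pq  # third-party; imported lazily

    parquet_file = pq.ParquetFile(path)
    for i in range(parquet_file.num_row_groups):
        table = parquet_file.read_row_group(i, columns=[column])
        for text in table.column(column).to_pylist():
            yield text


def joined_chunks(texts, sep=EOT):
    """Yield the documents with `sep` between consecutive ones."""
    first = True
    for text in texts:
        if not first:
            yield sep
        yield text
        first = False


def convert(parquet_path, txt_path):
    """Stream a parquet shard into a single .txt corpus file."""
    with open(txt_path, "w", encoding="utf-8") as out:
        for chunk in joined_chunks(iter_parquet_texts(parquet_path)):
            out.write(chunk)
```

Keeping the separator as a literal <|endoftext|> string lets the existing gpt2
BPE tokenizer map it to its end-of-text token without loader changes.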
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The `keep bf16 logits` change made forward_batched return bf16 logits
in bf16 mode; the bf16 test's host read must cast to f32 first.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
At vocab 50257 the logits tensor [B*S, vocab] is ~1.6GB fp32 at batch
32 — held across the whole backward. Keep it bf16: cross_entropy
upcasts the bf16 logits to fp32 internally (transient) + caches fp32
probs, and its backward casts dx back to bf16 to chain into the
bf16 lm_head matmul backward. The sampler casts bf16 logits→f32 before
the host argmax/softmax. Halves the persistent logits activation.
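
The ~1.6GB figure can be reproduced with simple arithmetic. The sequence length
of 256 is an assumption made here (it is not stated above); it is the value for
which batch 32 and vocab 50257 give roughly 1.6GB in fp32.

```python
def logits_bytes(batch, seq, vocab, bytes_per_elem):
    """Size in bytes of a dense [batch*seq, vocab] logits tensor."""
    return batch * seq * vocab * bytes_per_elem


VOCAB = 50257
BATCH = 32
SEQ = 256  # assumption: not stated in the commit message

FP32_BYTES = logits_bytes(BATCH, SEQ, VOCAB, 4)  # 1_646_821_376 bytes, ~1.65 GB
BF16_BYTES = logits_bytes(BATCH, SEQ, VOCAB, 2)  # 823_410_688 bytes, ~0.82 GB
```

Only the persistent tensor is halved; the fp32 upcast inside cross_entropy is
transient and is not counted here.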
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
docs/11-bf16-mixed-precision.md: the AMP split (bf16 linears +
activations, fp32 master / norms / softmax / RoPE / CE, no loss
scaling), the cast-op bridge, module layout, and the dual
verification gate (fp32 unchanged + bf16 looser-tol + convergence +
mem/throughput). Memory/throughput before->after to be filled from
the dash5 bench.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
- TinyTransformer::with_compute_dtype(BF16): embedding stays fp32
master then casts to bf16; each linear casts its fp32 weight to bf16
on the fly; logits cast back to fp32 for cross-entropy. Default F32
reproduces the v0-v4 forward graph bit-for-bit.
- --bf16 flag on bin/train and bin/train_ddp (off by default).
- tests/bf16.rs: same fp32 master weights run fp32 vs bf16; assert
loss/logits/grads within a loose bf16 tol, no NaN, and grads are
fp32 (master untouched).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Tensor ops dispatch on dtype: fp32 branch unchanged (bit-identical),
bf16 branch routes matmul/attention through GemmEx and elementwise
through the bf16 kernels. Norm/softmax/RoPE/cross-entropy upcast to
fp32 around the existing fp32 kernels (standard AMP: reductions/loss
fp32, matmuls bf16). Transposes route bf16 through fp32 (pure layout).
New autodiff `cast` op is the AMP bridge: forward downcasts a fp32
master leaf to bf16 for the matmul; backward upcasts the bf16 grad
back to fp32. So the fp32 leaf accumulates an fp32 grad and AdamW /
clip / DDP all-reduce stay fp32 and completely unchanged.
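
The cast bridge can be illustrated numerically on the host. This sketch emulates
bf16 by rounding an fp32 bit pattern to its top 16 bits (round-to-nearest-even,
NaN handling ignored); Fp32Leaf, cast_forward and cast_backward are hypothetical
names, and real tensors are replaced by single floats.

```python
import struct


def to_bf16(x):
    """Round a value to the nearest bf16, returned as a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding = 0x7FFF + ((bits >> 16) & 1)  # ties go to the even mantissa
    bits = (((bits + rounding) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


class Fp32Leaf:
    """fp32 master value with an fp32 gradient accumulator."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


def cast_forward(leaf):
    """Forward: downcast the fp32 master value to bf16 for the matmul."""
    return to_bf16(leaf.value)


def cast_backward(leaf, bf16_grad):
    """Backward: upcast the bf16 grad (exact) and accumulate it in fp32."""
    leaf.grad += bf16_grad
```

The master value is never overwritten by the downcast, which is why the
optimizer, clipping and all-reduce can keep operating on fp32 unchanged.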
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
v4 surfaced the concrete bf16 trigger: dim768 fp32 OOMs at per-rank batch 32
(global 256) in 32GB, forcing per-rank 16 (global 128). bf16 (which halves
activation memory) would restore the batch-256 sweet spot. Record this on KI-2;
KI-2 itself stays a deferred backlog item.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>