Gahow Wang 5c27493a90 docs: backfill v9/v10 scaling runs + reframe README to v0–v10 / three phases
Add per-run design+result docs for the two Chinchilla-axis runs that were
done but never committed:
- v9 (dim1280 true-GQA, core 357M, 6.01B FineWeb tokens): double-axis scale,
  best moving-tail val 2.8854 (~3.2% below v8) — direction validated, gain
  still incremental, greedy repetition remains.
- v10 (same arch, data-only top-up to 6.765B): moving-tail 2.8816; fixed
  eval v1 v6→v10 = 3.2328/3.1850/3.1515/2.9278/2.8814.

Extend the comparison tables in docs/runs/README.md and docs/evolution.md to
v10, and reframe README to v0–v10 with Phase 3 = the v9 double-axis run. No
code changes.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-29 16:18:48 +08:00


# xtrain
A from-scratch **Rust + CUDA** LLM **training** engine — the sibling of **xserv** (the
inference side). A learning project: hand-write the entire training-systems stack
(autograd → backward → optimizer → training loop → distributed → mixed precision →
gradient checkpointing), then use it to run a multi-version **scaling study** that maps
the data-vs-capacity frontier for a tiny model.
> **Status: complete — three phases.**
> **Phase 1** = the from-scratch full stack (T1–T13) + an 8-version scaling study (v0–v8):
> hand-write the whole training-systems stack, then map the data-vs-capacity frontier.
> **Phase 2** = systems-stack depth (T14–T18): hand-write the five deferred training-stack
> features — fused flash-attention, real GQA, gradient accumulation, process-per-GPU DDP,
> dropout. **Phase 3** = one Chinchilla-style double-axis run (v9): dim1280 true-GQA +
> 6.01B FineWeb tokens, validating the v8 conclusion that data and capacity must scale
> together. Trains Qwen3-compatible LMs whose weights load into **xserv**; deterministic
> gates stay byte-identical, while large BF16 checkpoints are served and checked for
> prompt-level drift. This README is the capstone; per-topic detail lives in [`docs/`](docs/).
---
## What got built (from scratch, by hand)
7 crates, no ML framework — only cuBLAS / NCCL / safetensors as deliberate "heavy-lifting"
borrows, the rest hand-written CUDA + Rust:
| crate | what's hand-written |
|---|---|
| `xtrain-cuda` | CUDA Runtime FFI, RAII `GpuBuffer`, **caching/pool allocator**, cuBLAS (sgemm + bf16 GemmEx) bindings |
| `xtrain-tensor` | tensor (dtype/shape/strides/storage), elementwise + transpose + embedding kernels |
| `xtrain-autodiff` | **tape autograd engine** (grad accumulation), per-op backward, finite-diff grad-check, **checkpoint** (recompute) primitive, **fused flash-attention** (online-softmax) fwd/bwd, **`repeat_kv`** broadcast (GQA), **`dropout`** (counter-based device RNG + mask) |
| `xtrain-model` | tiny **Qwen3-style** transformer (RoPE + RMSNorm + QK-norm + SwiGLU), batched forward, **GQA** (`num_kv_heads<num_heads`), residual/MLP **dropout** |
| `xtrain-optim` | hand-written **AdamW** (host + GPU kernels) |
| `xtrain-train` | training loop, LR schedule, grad clip, **gradient accumulation**, checkpoint, BPE corpus + cache, samplers, safetensors export |
| `xtrain-distributed` | **NCCL DDP** (thread-per-GPU + torchrun-style process-per-GPU launcher / cross-process `ncclUniqueId`, all-reduce) |

Every op's backward is verified against **finite differences** and against **PyTorch**
(forward + per-parameter grads, batch > 1). Trained weights export to HF-safetensors and
load into xserv (Qwen3, BF16); deterministic fixtures produce token-identical greedy output,
and large checkpoints are validated end-to-end in the serving path.
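A central-difference check is the usual way to test a hand-written backward. Perturb each input by ±ε, difference the loss, and compare the result against the analytic gradient. The sketch below shows the idea on the host in `f64`. The name `grad_check` and its floored relative-error metric are illustrative choices, not xtrain's actual API.

```rust
/// Central-difference gradient check (illustrative sketch, not xtrain's harness).
/// `f` maps parameters to a scalar loss; `analytic` is the gradient from the backward.
/// Returns the worst relative error over all coordinates.
pub fn grad_check<F: Fn(&[f64]) -> f64>(f: F, x: &[f64], analytic: &[f64], eps: f64) -> f64 {
    let mut xp = x.to_vec();
    let mut worst = 0.0f64;
    for i in 0..x.len() {
        let orig = xp[i];
        xp[i] = orig + eps;
        let plus = f(xp.as_slice());
        xp[i] = orig - eps;
        let minus = f(xp.as_slice());
        xp[i] = orig; // restore before probing the next coordinate
        let numeric = (plus - minus) / (2.0 * eps);
        // Floor the denominator so near-zero gradients don't blow up the ratio.
        let denom = numeric.abs().max(analytic[i].abs()).max(1e-8);
        worst = worst.max((numeric - analytic[i]).abs() / denom);
    }
    worst
}
```

For a loss of `Σ xᵢ³`, the correct gradient `3xᵢ²` scores far below `1e-6`. A deliberately wrong `2xᵢ²` scores about `0.33`, and that gap is what lets the gate reject a buggy backward.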
## The build journey: Phase 1 (T1–T13) + Phase 2 (T14–T18)
Each phase: design doc + implementation + tests + a scoped commit (see [`docs/`](docs/) and
[`docs/evolution.md`](docs/evolution.md) for the per-axis changelog). **Phase 1 (T1–T13)**
hand-built the stack and fixed the four real bottlenecks; **Phase 2 (T14–T18)** went back to
hand-write five deferred training-stack features (see the Phase-2 summary below the table).
| phase | what | result |
|---|---|---|
| T1–T2 | Rust↔CUDA build chain · tensor abstraction | vector-add verified · roundtrip |
| T3–T4 | hand GEMM fwd/bwd + finite-diff · **tape autograd** + 11 op backwards | grads vs cuBLAS 1e-7 / finite-diff |
| T5 | tiny transformer (RoPE+RMSNorm+SwiGLU) | overfit + **PyTorch parity** |
| T6 | AdamW + training loop + checkpoint · GPT-2 BPE + TinyStories | first **coherent English** |
| T7 | cuBLAS + GPU optimizer + drop syncs | ~3× (2.7K→8.5K tok/s) |
| T8 | NCCL DDP | multi-GPU (weak scaling, then) |
| T9 | + per-head **QK-norm** (Qwen3-compat) + safetensors export | **xserv closed loop, token-identical** |
| **T10** | **batched multi-sequence forward** (fixes KI-1) | **single-GPU 15–24×**; MFU 0.4%→14% |
| **T11** | **device caching allocator** (fixes KI-5) | single-GPU 2.3×; **8-GPU 461K tok/s** |
| **T12** | **bf16 mixed precision** (fp32 master, fixes KI-2) | dim768 OOM solved; −29% mem |
| **T13** | **activation recompute** / checkpointing (fixes KI-3) | dim1024 fits; grads bit-identical |
| **T14** | **fused flash-attention** kernel (online softmax, no materialized N×N; opt-in `--flash`) | peak mem −16%@1k / −23%@2k seq; flash==composed (grads/PyTorch) |
| **T15** | **grouped-query attention** (`num_kv_heads<num_heads`; `repeat_kv` broadcast feeds both SDPA paths; backward sums each kv head's group; `--kv-heads`) | repeat_kv grad-check + **group=1 bit-identical to MHA**; GQA flash==composed; PyTorch GQA B>1; **xserv closed loop with real `num_key_value_heads`** token-identical |
| **T16** | **gradient accumulation** (`--accum-steps`; DDP all-reduces only at the boundary) | equiv to N× big batch (grad 3.8e-5); same effective-64 batch 27.7GB→7.2GB (−74%) |
| **T17** | **process-per-GPU** DDP (torchrun-style: 1 worker process / CUDA context per GPU; launcher mints `ncclUniqueId` → hex env injection; `train_rank` reused unchanged; thread-per-GPU path kept) | proc==thread loss 1.5e-7, cross-rank 1.2e-7, xserv md5 identical · **measured no-op on throughput**: thread 5.27× vs proc 5.31×@8 (8 GPUs 95–99% util) → residual non-linearity is NCCL/PCIe, *not* CUDA-context serialization (falsifies the old KI-5 hypothesis) |
| **T18** | **dropout** (hand counter-based device RNG + mask, inverted scaling, train/eval switch) | fixed-seed grad-check; **p=0 bit-identical**; recompute-safe |

The four performance fixes (T10–T13) each removed a real bottleneck (see
[`docs/known-issues.md`](docs/known-issues.md)), which is where **Phase 1** closed.
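The T11 fix works like this: a caching allocator rounds each request to a size class and keeps freed blocks on a per-size free list. A training loop that requests the same shapes every step therefore stops calling the driver after the first step. Below is a minimal host-side sketch with several assumptions:

- It uses a hypothetical 512-byte rounding granularity and exact-size reuse.
- A real pool also splits and coalesces blocks and tracks streams.
- `raw_allocs` counts the calls that would have gone to `cudaMalloc`.

None of these names are xtrain's.

```rust
use std::collections::HashMap;

/// Assumed rounding granularity for size classes (illustrative).
const ROUND: usize = 512;

/// Hypothetical caching pool: freed blocks are cached per rounded size and reused,
/// so only cache misses reach the (simulated) driver allocator.
pub struct CachingPool {
    cache: HashMap<usize, Vec<usize>>, // rounded size -> ids of cached free blocks
    next_id: usize,
    pub raw_allocs: usize, // number of simulated cudaMalloc calls
}

impl CachingPool {
    pub fn new() -> Self {
        Self { cache: HashMap::new(), next_id: 0, raw_allocs: 0 }
    }

    fn round_up(bytes: usize) -> usize {
        (bytes.max(1) + ROUND - 1) / ROUND * ROUND
    }

    /// Returns (block id, rounded size); reuses a cached block of the same class if any.
    pub fn alloc(&mut self, bytes: usize) -> (usize, usize) {
        let size = Self::round_up(bytes);
        if let Some(id) = self.cache.get_mut(&size).and_then(|ids| ids.pop()) {
            return (id, size); // hit: no driver call
        }
        self.raw_allocs += 1; // miss: this is where cudaMalloc would run
        self.next_id += 1;
        (self.next_id, size)
    }

    /// Returns the block to the cache instead of releasing it to the driver.
    pub fn free(&mut self, id: usize, size: usize) {
        self.cache.entry(size).or_default().push(id);
    }
}
```

Suppose a loop runs 100 steps, each allocating and freeing the same three buffer sizes. In this sketch it makes only three raw allocations in total, all on the first step.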
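T12 keeps an fp32 master copy because bf16 carries only 8 significant bits. Any update smaller than half a bf16 ulp (about 0.0039 near 1.0) rounds away entirely. The sketch emulates bf16 by round-to-nearest-even on the top 16 bits of an `f32`, with NaN and overflow handling omitted. It compares a weight stored only in bf16 against one with an fp32 master. `to_bf16` and `compare_master` are illustrative names, not xtrain functions.

```rust
/// Emulate bf16 storage: round an f32 to its top 16 bits (round-to-nearest-even).
/// NaN/overflow handling is omitted for brevity.
pub fn to_bf16(x: f32) -> f32 {
    let bits = x.to_bits();
    let lsb = (bits >> 16) & 1; // tie-break toward an even bf16 mantissa
    f32::from_bits(bits.wrapping_add(0x7FFF + lsb) & 0xFFFF_0000)
}

/// Apply `steps` updates of `delta` to a weight starting at `w0`.
/// Returns (weight stored only in bf16, fp32 master weight).
pub fn compare_master(w0: f32, delta: f32, steps: usize) -> (f32, f32) {
    let mut w_bf16 = to_bf16(w0);
    let mut master = w0;
    for _ in 0..steps {
        w_bf16 = to_bf16(w_bf16 + delta); // small updates round away
        master += delta; // fp32 keeps accumulating them
    }
    (w_bf16, master)
}
```

Take `w0 = 1.0` and 1000 updates of `1e-4`. The bf16-only weight never moves, while the fp32 master reaches about 1.1. The bf16 copy made from that master then rounds to roughly 1.1016.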
## Phase 2: systems-stack depth (T14–T18)
Phase 1 fixed bottlenecks; Phase 2 went back to hand-write the five training-stack features
that had been **explicitly deferred** earlier (the project's actual goal is to learn the whole stack).
Each is opt-in, kept the default path **bit-identical**, and held a **hard correctness gate**:
- **T14 · fused flash-attention** ([`docs/13-flash-attention.md`](docs/13-flash-attention.md)) —
a single hand-written kernel: **online (streaming) softmax, tiled over KV, never materializes
the `N×N` scores**; flash-style backward recomputes scores + the `D=ΣdO·O` Jacobian
simplification for dQ/dK/dV. Opt-in `--flash`, default off. **The win is memory, not
wall-clock**: peak activation **−16%@seq1024 / −23%@seq2048** (the saving grows with seq, since the
`N×N` never lands), but **~2.3× slower** at head-dim 64 (a hand kernel can't beat cuBLAS
tensor-cores on a small head). Gate: flash == composed (loss rel `0.0`, grad `4.4e-5`),
PyTorch B>1 `7.9e-6`.
- **T15 · real GQA** ([`docs/14-gqa.md`](docs/14-gqa.md)) — `num_kv_heads < num_heads` via a new
`repeat_kv` **broadcast op** that copies K/V `group = nh/num_kv` times to feed **both**
(composed + flash) SDPA paths **unchanged**; its **backward is a deterministic group-sum**
(no atomics) collapsing each kv head's query-head group. Gate: `repeat_kv` grad-check +
**group=1 bit-identical to MHA** (regression guard); **xserv closed loop with real
`num_key_value_heads`** token-identical.
- **T16 · gradient accumulation** ([`docs/15-grad-accum.md`](docs/15-grad-accum.md)) — N
micro-steps scaled by `1/N` accumulate on the tape, then one AdamW step; DDP **all-reduces
only at the accumulation boundary**. Decouples effective batch from activation memory: same
effective batch 64, big-batch **27.7GB (OOM)** → accum 4×16 **7.2GB (−74%)**. Gate: `accum=N`
≡ one N× batch (grad `3.8e-5`); `accum=1` bit-identical.
- **T18 · dropout** ([`docs/17-dropout.md`](docs/17-dropout.md)) — a **stateless counter-based
device RNG** (Philox-style bit-mix) → Bernoulli mask, inverted `1/(1−p)` scaling in train,
identity in eval; wired at the two residual sites (attn-out, mlp-out). Stateless RNG is what
makes it **compose bit-exactly with T13 activation recompute** — the backward re-run
regenerates the *same* mask from `(seed, index)`. Gate: fixed-seed grad-check; **p=0
bit-identical**.
- **T17 · process-per-GPU** ([`docs/16-process-per-gpu.md`](docs/16-process-per-gpu.md)) — a
torchrun-style launcher: one worker process + CUDA context per GPU, the launcher mints one
`ncclUniqueId` and **hex-injects it into each child's env** (no shared FS/TCP, no race); the
worker reuses the T8 `train_rank` **unchanged**. Built and **correct** (proc vs thread loss
`1.5e-7`, cross-rank `1.2e-7`, xserv md5 identical) — but **measured throughput-neutral**:
8-GPU thread **491K (5.27×)** vs proc **493K (5.31×)**, `<1%`. This **falsifies** the
long-standing KI-5/T11 hypothesis that thread-per-GPU's shared CUDA context caused the
residual ~5×@8; with all 8 GPUs at 95–99% util, the residual is the **NCCL all-reduce + PCIe
topology wall**, not context serialization. The third profile-first falsification (see below).
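The online-softmax recurrence behind T14 can be shown for a single query row. The idea:

- Keep a running max `m`, a running denominator `l`, and an unnormalized output accumulator.
- When a new KV tile raises the max, rescale the old partial sums by `exp(m_old - m_new)`.

This host-side sketch uses scalar loops and covers the forward pass only, with no query tiling and no backward. It checks that the streaming result matches the materialized softmax. It illustrates the technique, not the xtrain kernel.

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Reference: materialize every score for one query row, then softmax-weight V.
pub fn attend_naive(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let s: Vec<f32> = k.iter().map(|kj| dot(q, kj) * scale).collect();
    let m = s.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let w: Vec<f32> = s.iter().map(|x| (x - m).exp()).collect();
    let z: f32 = w.iter().sum();
    let mut out = vec![0.0f32; v[0].len()];
    for (wj, vj) in w.iter().zip(v) {
        for (o, x) in out.iter_mut().zip(vj) {
            *o += wj / z * x;
        }
    }
    out
}

/// Online softmax: stream K/V in tiles of `tile` rows; never hold all scores at once.
pub fn attend_online(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>], tile: usize) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let mut m = f32::NEG_INFINITY; // running max of scores seen so far
    let mut l = 0.0f32; // running softmax denominator (relative to m)
    let mut acc = vec![0.0f32; v[0].len()]; // unnormalized sum of p_j * v_j
    for (kt, vt) in k.chunks(tile).zip(v.chunks(tile)) {
        let s: Vec<f32> = kt.iter().map(|kj| dot(q, kj) * scale).collect();
        let m_new = s.iter().cloned().fold(m, f32::max);
        let correction = (m - m_new).exp(); // 0 on the first tile, since m = -inf
        l *= correction;
        for a in acc.iter_mut() {
            *a *= correction;
        }
        for (sj, vj) in s.iter().zip(vt) {
            let p = (sj - m_new).exp();
            l += p;
            for (a, x) in acc.iter_mut().zip(vj) {
                *a += p * x;
            }
        }
        m = m_new;
    }
    acc.iter().map(|a| a / l).collect() // normalize once at the end
}
```

With a tile size that does not divide the sequence length (say 3 rows over 10), the two functions agree to within f32 rounding.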
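A sketch of the T15 `repeat_kv` forward and its group-sum backward. It assumes the Hugging Face-style head layout, in which query head `h` reads kv head `h / group`. Each head is a flat `Vec<f32>`, a simplification of the real strided tensors. The backward sums each group's gradients in a fixed order, which keeps it deterministic without atomics. With `group = 1` both functions reduce to copies, mirroring the MHA regression gate.

```rust
/// Forward: expand `num_kv` K (or V) heads to `num_kv * group` heads
/// (assumed layout: query head h reads kv head h / group).
pub fn repeat_kv(kv: &[Vec<f32>], group: usize) -> Vec<Vec<f32>> {
    kv.iter()
        .flat_map(|head| std::iter::repeat(head.clone()).take(group))
        .collect()
}

/// Backward: each kv head's gradient is the sum of its query-head group's gradients,
/// accumulated in a fixed order, so the result is deterministic (no atomics).
pub fn repeat_kv_backward(d_rep: &[Vec<f32>], group: usize) -> Vec<Vec<f32>> {
    d_rep
        .chunks(group)
        .map(|heads| {
            let mut acc = vec![0.0f32; heads[0].len()];
            for head in heads {
                for (a, x) in acc.iter_mut().zip(head) {
                    *a += x;
                }
            }
            acc
        })
        .collect()
}
```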
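The T16 equivalence holds whenever the loss is a mean over equal-sized micro-batches. The claim is that N micro-batches, each scaled by 1/N, followed by one optimizer step, equal one step on an N× batch. A scalar least-squares sketch makes this checkable. `grad_mse` and `accumulated_grad` are hypothetical helpers, and the sketch does not reproduce xtrain's tape.

```rust
/// Gradient of 0.5 * mean((w*x - y)^2) with respect to the scalar weight w.
pub fn grad_mse(w: f64, xs: &[f64], ys: &[f64]) -> f64 {
    xs.iter().zip(ys).map(|(x, y)| (w * x - y) * x).sum::<f64>() / xs.len() as f64
}

/// Gradient accumulation: scale each micro-batch gradient by 1/N and sum;
/// a single optimizer step would then consume the result.
pub fn accumulated_grad(w: f64, xs: &[f64], ys: &[f64], n_micro: usize) -> f64 {
    assert_eq!(xs.len() % n_micro, 0, "sketch assumes equal-sized micro-batches");
    let mb = xs.len() / n_micro;
    let mut g = 0.0;
    for i in 0..n_micro {
        let r = i * mb..(i + 1) * mb;
        g += grad_mse(w, &xs[r.clone()], &ys[r]) / n_micro as f64;
    }
    g
}
```

With `n_micro = 1` the loop adds a single unscaled term to zero, so the result equals the big-batch gradient exactly. For larger N it agrees up to floating-point reassociation, which is the same pattern as the `accum=1` and `accum=N` gates above.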
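A counter-based RNG makes the T18 mask a pure function of `(seed, index)`. The T13 recompute pass can therefore regenerate exactly the same mask without storing it. In this sketch a splitmix64-style finalizer stands in for the Philox-style mix the text mentions; it is neither Philox nor xtrain's kernel. The top 24 bits of each hash map to a uniform value in [0, 1).

```rust
/// Stateless counter-based hash: the same (seed, index) always yields the same bits.
/// splitmix64-style finalizer, used here as a stand-in for a Philox-style mix.
pub fn mix(seed: u64, index: u64) -> u64 {
    let mut z = seed ^ index.wrapping_mul(0x9E37_79B9_7F4A_7C15);
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

/// Inverted dropout: in train mode keep each element with probability 1-p and scale
/// it by 1/(1-p); in eval mode (or p == 0) return the input unchanged.
/// The mask is never stored; it is recomputed from (seed, i) whenever needed.
pub fn dropout(x: &[f32], p: f32, seed: u64, train: bool) -> Vec<f32> {
    if !train || p == 0.0 {
        return x.to_vec();
    }
    let scale = 1.0 / (1.0 - p);
    x.iter()
        .enumerate()
        .map(|(i, &xi)| {
            // Top 24 bits of the hash -> uniform in [0, 1).
            let u = (mix(seed, i as u64) >> 40) as f32 / (1u32 << 24) as f32;
            if u < p { 0.0 } else { xi * scale }
        })
        .collect()
}
```

Calling `dropout` twice with the same seed yields identical outputs. That is the property the recompute path relies on. The `p == 0.0` early return is what makes the p=0 case bit-identical.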
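The T17 handshake needs the launcher to turn NCCL's opaque 128-byte `ncclUniqueId` (`NCCL_UNIQUE_ID_BYTES` in `nccl.h`) into something an environment variable can carry. A hex round-trip sketch follows. The env var names `XTRAIN_NCCL_ID`, `RANK` and `WORLD_SIZE` are hypothetical. The actual `ncclGetUniqueId` / `ncclCommInitRank` FFI calls are omitted.

```rust
use std::process::Command;

/// Size of NCCL's opaque `ncclUniqueId` (NCCL_UNIQUE_ID_BYTES in nccl.h).
pub const NCCL_UNIQUE_ID_BYTES: usize = 128;

/// Launcher side: lowercase hex, two characters per byte, safe to put in an env var.
pub fn encode_id(id: &[u8; NCCL_UNIQUE_ID_BYTES]) -> String {
    id.iter().map(|b| format!("{b:02x}")).collect()
}

/// Worker side: parse the hex string back into the raw id bytes.
pub fn decode_id(hex: &str) -> Option<[u8; NCCL_UNIQUE_ID_BYTES]> {
    if hex.len() != 2 * NCCL_UNIQUE_ID_BYTES || !hex.bytes().all(|c| c.is_ascii_hexdigit()) {
        return None;
    }
    let mut out = [0u8; NCCL_UNIQUE_ID_BYTES];
    for (i, byte) in out.iter_mut().enumerate() {
        *byte = u8::from_str_radix(&hex[2 * i..2 * i + 2], 16).ok()?;
    }
    Some(out)
}

/// Build (without spawning) one worker's command; env var names are hypothetical.
pub fn worker_command(exe: &str, rank: usize, world: usize, id: &[u8; NCCL_UNIQUE_ID_BYTES]) -> Command {
    let mut cmd = Command::new(exe);
    cmd.env("XTRAIN_NCCL_ID", encode_id(id))
        .env("RANK", rank.to_string())
        .env("WORLD_SIZE", world.to_string());
    cmd
}
```

A launcher would build one command per rank with the same id and `spawn()` each. Each worker would read the variable back with `std::env::var`, decode it, and pass the bytes to `ncclCommInitRank`. No shared filesystem or TCP rendezvous is involved.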
## The scaling study — v0 → v10
Same Qwen3-style architecture throughout; we scaled **dim** and **data** and read out val
loss (full per-run detail in [`docs/runs/`](docs/runs/)).
| ver | data (trained tok / epoch) | dim / core params | val loss | axis explored |
|---|---|---|---|---|
| v0–v3 | TinyStories (↑) | 32→512 / 41K→67M | 3.80 → 1.30 | bring-up |
| v4 | TinyStories 1.54ep | 768 / 127M | 1.17 | — |
| v5 | TinyStories 5.33ep | 768 / 127M | **1.11** | **data volume → saturates** |
| v6 | FineWeb-edu 1.02ep | 768 / 127M | 3.07\* | **corpus swap → graduates to real text** |
| v7 | FineWeb-edu 1.45ep | 768 / 127M | 3.01\* | same subset, more epochs → near-ceiling |
| **v8** | FineWeb-edu 1.05ep | **1024 / 226M** | **2.98\*** | **capacity → helps** |
| **v9** | FineWeb-edu 6.01B / ~1ep | **1280 / 357M + GQA** | **2.89\*** | **data + capacity → helps** |
| **v10** | FineWeb-edu 6.76B / ~1ep | **1280 / 357M + GQA** | **2.88\*** | **data-only top-up → small gain** |

\* FineWeb-edu val is a different (harder) distribution, so it is **not comparable** to the
TinyStories val of v0–v5. Judge v6+ by sample quality + transfer, not the raw number.
### Five findings
1. **Data volume saturates.** TinyStories at dim768: 3.5× more tokens (v4→v5) bought only a
~5% val-loss drop (1.17 → 1.11), and the curve went flat. The narrow synthetic corpus is exhausted at this model size.
2. **Corpus > more-of-the-same.** Swapping TinyStories → FineWeb-edu (v5→v6) was a
*qualitative* jump: the model went from only-writes-kid-stories to writing genuine
historical/scientific expository prose. (Cost: TinyStories transfer val 1.11 → 2.75.)
3. **Capacity helps.** v8 (dim1024, ~1 epoch) beats both v6 (dim768, same epoch, by 0.085)
and v7 (dim768, *more* data, by 0.035) → the dim768 runs were partly capacity-limited.
4. **Double-axis scale helps.** v9 scales both axes (dim1280/core357M + 6.01B FineWeb tokens)
and beats v8 by another 0.095 val loss (~3.2%). The direction is validated, but the gain is
still incremental and greedy decoding still repeats.
5. **Moving validation tails must stop.** v10 added one more FineWeb shard and got moving-tail
val 2.8816, but appending data moves the held-out tail. A fixed eval v1 was created from the
shard010 tail: v6/v7/v8/v9/v10 = 3.2328 / 3.1850 / 3.1515 / 2.9278 / 2.8814. Future runs
should report this fixed eval first.

**Meta-finding:** every lever is now in the **~3% or smaller** regime. Single-axis moves were
exhausted by v8; v9 confirms Chinchilla-style double-axis scale works; v10 shows a data-only
top-up mostly adapts to the new shard. The next useful run should change model/context, not just
append another shard.
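Finding 5 in code form: suppose the eval split is defined as the last x% of whatever corpus exists at training time. Then appending a shard silently changes the eval set. A fixed eval instead pins an absolute token range, chosen once and kept out of training. This minimal sketch over token-id slices uses the hypothetical names `moving_tail` and `FixedEval`.

```rust
/// Moving tail: hold out the last `frac` (0..=1) of whatever corpus exists at train time.
pub fn moving_tail(corpus: &[u32], frac: f64) -> &[u32] {
    let n_eval = (corpus.len() as f64 * frac).round() as usize;
    &corpus[corpus.len() - n_eval..]
}

/// Fixed eval: an absolute token range chosen once and excluded from all training runs.
pub struct FixedEval {
    pub start: usize,
    pub end: usize,
}

impl FixedEval {
    /// The same tokens every time, regardless of how much data was appended later.
    pub fn slice<'a>(&self, corpus: &'a [u32]) -> &'a [u32] {
        &corpus[self.start..self.end]
    }
}
```

Appending data to the corpus changes what `moving_tail` returns, so losses from two runs are measured on different tokens. `FixedEval::slice` returns the same tokens in both runs.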
## Efficiency — throughput & MFU
The throughput story is the perf-infra report card (RTX 5090, bf16/fp32):
| | v1 | v2 | v3 | v4 | v5 |
|---|---|---|---|---|---|
| tok/s | 3.3K (1 GPU) | 3.6K (4 GPU) | 26K (1 GPU) | 145K (8 GPU) | 217K (8 GPU) |
| **MFU** | 0.4% | 0.2% | 14% | 17% | 13% |
| enabled by | — | DDP (weak) | **batched (T10)** | **alloc (T11)** | **bf16 (T12)** |

v1/v2 ran at **<0.5% MFU**: the single-sequence design left the GPU idle (launch-bound).
**Batched forward (T10) was the single biggest unlock** (~35× MFU jump). 6ND is an accurate
FLOPs count, but predicting *time* needs the *realized* MFU, which varied ~40× across
versions, so a fixed-MFU estimate is off by up to ~100× for the early launch-bound runs.
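The MFU arithmetic can be written down directly. Training costs roughly 6·N FLOPs per token for N parameters (about 2N forward plus 4N backward). MFU is therefore achieved FLOP/s divided by peak, and wall-clock time for D tokens is 6ND divided by the realized rate. The function names below are illustrative. `peak_flops` is left as an input rather than a quoted RTX 5090 figure.

```rust
/// MFU from the 6·N·D rule: ~6·N FLOPs per trained token (~2N forward + ~4N backward).
pub fn mfu(params: f64, tokens_per_sec: f64, peak_flops: f64) -> f64 {
    6.0 * params * tokens_per_sec / peak_flops
}

/// Wall-clock seconds to train on `tokens` tokens at a given realized MFU.
pub fn train_seconds(params: f64, tokens: f64, peak_flops: f64, realized_mfu: f64) -> f64 {
    6.0 * params * tokens / (peak_flops * realized_mfu)
}
```

Time scales as 1/MFU, so the same 6ND budget takes 35× longer at 0.4% MFU than at 14%.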
## Engineering lessons
- **Profile before optimizing.** *Three* "known" fixes were *falsified by measurement*: (1)
"bigger batch fixes DDP scaling" (real cause: single-sequence launch-bound, fixed by T10); (2) "bucket the
all-reduce" (real cause: per-op `cudaMalloc` serialization, fixed by the T11 caching allocator); and (3)
"process-per-GPU would fix the residual ~5×@8" (T17 built the torchrun-style launcher and
measured it **throughput-neutral**: the residual is the NCCL/PCIe communication wall, not
shared-context serialization). All three would have been no-ops; each got measured and either
reverted or recorded as a deliberate negative result instead of shipped on faith.
- **Honest correctness.** QK-norm was *added* to match xserv's Qwen3 (not faked); every change
kept a hard correctness gate, and **no tolerance was ever loosened to go green**. Phase 2 held
the line: flash == composed SDPA (grads/PyTorch), GQA group=1 bit-identical to MHA, gradient
accumulation `accum=1` bit-identical, dropout p=0 bit-identical *and* dropout × recompute
bit-exact, the default path unchanged on every feature, and the **xserv closed-loop md5
byte-identical (`b04fc9f9`) throughout the deterministic gates**.
- **The closed loop matters.** Exporting to xserv and checking generated continuations caught
real bugs and proved the whole stack end-to-end.
## Running it
Everything trains on a remote 8× RTX 5090 box; model artifacts live in a registry
(`tiny-models/v0…v10`). Serve any trained version in xserv:
```bash
# on the GPU box
cargo run -p xserv-model --release --bin xserv-cli -- <registry>/v10-fineweb-edu-dim1280-gqa-data6765 --max-tokens 100
# then type a prompt, e.g. In science,
```
Build/test the engine itself (CUDA compiles + runs on the GPU box; host-side `cargo check`
works anywhere via the `no_cuda` cfg):
```sh
export PATH=/usr/local/cuda/bin:$HOME/.cargo/bin:$PATH
cargo test --workspace # autograd grad-checks, PyTorch parity, DDP, etc.
```
## Doc index
- [`docs/evolution.md`](docs/evolution.md): per-milestone changes across algorithm / architecture / infra / dataset.
- [`docs/runs/README.md`](docs/runs/README.md): the v0–v10 comparison; [`docs/runs/0N-*.md`](docs/runs/): per-run detail.
- [`docs/00-*` … `17-*`](docs/): per-phase design docs (build chain · tensor · autograd · transformer · training · perf · distributed · export · batched · allocator · bf16 · recompute · flash-attention · GQA · grad-accum · process-per-GPU · dropout).
- [`docs/known-issues.md`](docs/known-issues.md): perf backlog (KI-1/2/3/5 fixed; process-per-GPU CLOSED = measured no-op; KI-4 = accepted modeling tradeoff).