# xtrain

A from-scratch **Rust + CUDA** LLM **training** engine — the sibling of **xserv** (the
inference side). A learning project: hand-write the entire training-systems stack
(autograd → backward → optimizer → training loop → distributed → mixed precision →
gradient checkpointing), then use it to run a multi-version **scaling study** that maps
the data-vs-capacity frontier for a tiny model.

> **Status: complete — two phases.**
> **Phase 1** = the from-scratch full stack (T1–T13) + the v0–v8 scaling study:
> hand-write the whole training-systems stack, then map the data-vs-capacity frontier.
> **Phase 2** = systems-stack depth (T14–T18): hand-write the five deferred training-stack
> features — fused flash-attention, real GQA, gradient accumulation, process-per-GPU DDP,
> dropout. Trains a Qwen3-compatible LM whose weights load into **xserv** and generate
> **token-identical** output — the closed loop held byte-for-byte across both phases. This
> README is the capstone; per-topic detail lives in [`docs/`](docs/).

---

## What got built (from scratch, by hand)

7 crates, no ML framework — only cuBLAS / NCCL / safetensors as deliberate "heavy-lifting"
borrows, the rest hand-written CUDA + Rust:

| crate | what's hand-written |
|---|---|
| `xtrain-cuda` | CUDA Runtime FFI, RAII `GpuBuffer`, **caching/pool allocator**, cuBLAS (sgemm + bf16 GemmEx) bindings |
| `xtrain-tensor` | tensor (dtype/shape/strides/storage), elementwise + transpose + embedding kernels |
| `xtrain-autodiff` | **tape autograd engine** (grad accumulation), per-op backward, finite-diff grad-check, **checkpoint** (recompute) primitive, **fused flash-attention** (online-softmax) fwd/bwd, **`repeat_kv`** broadcast (GQA), **`dropout`** (counter-based device RNG + mask) |
| `xtrain-model` | tiny **Qwen3-style** transformer (RoPE + RMSNorm + QK-norm + SwiGLU), batched forward, **GQA** (`num_kv_heads<num_heads`), residual/MLP **dropout** |
| `xtrain-optim` | hand-written **AdamW** (host + GPU kernels) |
| `xtrain-train` | training loop, LR schedule, grad clip, **gradient accumulation**, checkpoint, BPE corpus + cache, samplers, safetensors export |
| `xtrain-distributed` | **NCCL DDP** (thread-per-GPU + torchrun-style process-per-GPU launcher / cross-process `ncclUniqueId`, all-reduce) |

Every op's backward is verified against **finite differences** and against **PyTorch**
(forward + per-parameter grads, batch > 1). Trained weights export to HF-safetensors and
load into xserv (Qwen3, BF16) producing token-identical greedy output — the closed loop.
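
To make the finite-difference check concrete, here is a minimal host-side sketch in Rust. It is not the `xtrain-autodiff` API: `numeric_grad`, `max_rel_err`, and the toy op are hypothetical names chosen for illustration, and it uses `f64` on the CPU rather than the project's GPU tensors. The idea is only that a hand-derived backward should agree with a central-difference estimate.

```rust
/// Central-difference estimate of df/dx_i for a scalar function `f`.
/// Hypothetical helper for illustration; not the xtrain API.
fn numeric_grad<F: Fn(&[f64]) -> f64>(f: F, x: &[f64], eps: f64) -> Vec<f64> {
    let mut probe = x.to_vec();
    let mut grad = Vec::with_capacity(x.len());
    for i in 0..x.len() {
        let orig = probe[i];
        probe[i] = orig + eps;
        let plus = f(&probe);
        probe[i] = orig - eps;
        let minus = f(&probe);
        probe[i] = orig; // restore before moving to the next coordinate
        grad.push((plus - minus) / (2.0 * eps));
    }
    grad
}

/// Largest relative error between an analytic and a numeric gradient.
fn max_rel_err(analytic: &[f64], numeric: &[f64]) -> f64 {
    analytic
        .iter()
        .zip(numeric)
        .map(|(a, n)| (a - n).abs() / (a.abs() + n.abs()).max(1e-12))
        .fold(0.0, f64::max)
}

/// Toy op: f(x) = sum_i x_i^2 * sin(x_i).
fn toy_forward(x: &[f64]) -> f64 {
    x.iter().map(|v| v * v * v.sin()).sum()
}

/// Its hand-derived backward: df/dx_i = 2 x_i sin(x_i) + x_i^2 cos(x_i).
fn toy_backward(x: &[f64]) -> Vec<f64> {
    x.iter().map(|v| 2.0 * v * v.sin() + v * v * v.cos()).collect()
}
```

A check then amounts to comparing `toy_backward(&x)` with `numeric_grad(toy_forward, &x, 1e-5)` and asserting that `max_rel_err` stays under a fixed tolerance.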

## The build journey — Phase 1 (T1–T13) + Phase 2 (T14–T18)

Each phase: design doc + implementation + tests + a scoped commit (see [`docs/`](docs/) and
[`docs/evolution.md`](docs/evolution.md) for the per-axis changelog). **Phase 1 (T1–T13)**
hand-built the stack and fixed the four real bottlenecks; **Phase 2 (T14–T18)** went back to
hand-write five deferred training-stack features — see the Phase-2 summary below the table.

| phase | what | result |
|---|---|---|
| T1–T2 | Rust↔CUDA build chain · tensor abstraction | vector-add verified · roundtrip |
| T3–T4 | hand GEMM fwd/bwd + finite-diff · **tape autograd** + 11 op backwards | grads vs cuBLAS 1e-7 / finite-diff |
| T5 | tiny transformer (RoPE+RMSNorm+SwiGLU) | overfit + **PyTorch parity** |
| T6 | AdamW + training loop + checkpoint · GPT-2 BPE + TinyStories | first **coherent English** |
| T7 | cuBLAS + GPU optimizer + drop syncs | ~3× (2.7K→8.5K tok/s) |
| T8 | NCCL DDP | multi-GPU (scaling still weak at this stage) |
| T9 | + per-head **QK-norm** (Qwen3-compat) + safetensors export | **xserv closed loop, token-identical** |
| **T10** | **batched multi-sequence forward** (fixes KI-1) | **single-GPU 15–24×**; MFU 0.4%→14% |
| **T11** | **device caching allocator** (fixes KI-5) | single-GPU 2.3×; **8-GPU 461K tok/s** |
| **T12** | **bf16 mixed precision** (fp32 master, fixes KI-2) | dim768 OOM solved; −29% mem |
| **T13** | **activation recompute** / checkpointing (fixes KI-3) | dim1024 fits; grads bit-identical |
| **T14** | **fused flash-attention** kernel (online softmax, no materialized N×N; opt-in `--flash`) | peak mem −16%@1k / −23%@2k seq; flash==composed (grads/PyTorch) |
| **T15** | **grouped-query attention** (`num_kv_heads<num_heads`; `repeat_kv` broadcast feeds both SDPA paths; backward sums each kv head's group; `--kv-heads`) | repeat_kv grad-check + **group=1 bit-identical to MHA**; GQA flash==composed; PyTorch GQA B>1; **xserv closed loop with real `num_key_value_heads`** token-identical |
| **T16** | **gradient accumulation** (`--accum-steps`; DDP all-reduces only at the boundary) | equiv to N× big batch (grad 3.8e-5); same effective-64 batch 27.7GB→7.2GB (−74%) |
| **T17** | **process-per-GPU** DDP (torchrun-style: 1 worker process / CUDA context per GPU; launcher mints `ncclUniqueId` → hex env injection; `train_rank` reused unchanged; thread-per-GPU path kept) | proc==thread loss 1.5e-7, cross-rank 1.2e-7, xserv md5 identical · **measured no-op on throughput**: thread 5.27× vs proc 5.31×@8 (8 GPUs 95–99% util) → residual non-linearity is NCCL/PCIe, *not* CUDA-context serialization (falsifies the old KI-5 hypothesis) |
| **T18** | **dropout** (hand counter-based device RNG + mask, inverted scaling, train/eval switch) | fixed-seed grad-check; **p=0 bit-identical**; recompute-safe |

Each of the four performance fixes (T10–T13) removed a real bottleneck (see
[`docs/known-issues.md`](docs/known-issues.md)); that is where **Phase 1** closed.

## Phase 2 — systems-stack depth (T14–T18)

Phase 1 fixed bottlenecks; Phase 2 went back to hand-write the five training-stack features
that had been **explicitly deferred** earlier (the project's actual goal is to learn the whole
stack). Each is opt-in, keeps the default path **bit-identical**, and passed a **hard
correctness gate**:

- **T14 · fused flash-attention** ([`docs/13-flash-attention.md`](docs/13-flash-attention.md)) —
  a single hand-written kernel: **online (streaming) softmax, tiled over KV, never materializes
  the `N×N` scores**; the flash-style backward recomputes the scores and uses the `D=ΣdO·O`
  softmax-Jacobian simplification to form dQ/dK/dV. Opt-in `--flash`, default off. **The win is
  memory, not wall-clock**: peak activation **−16%@seq1024 / −23%@seq2048** (grows with seq,
  since the `N×N` never lands), but **~2.3× slower** at head-dim 64 (a hand kernel can't beat
  cuBLAS tensor-cores on a small head). Gate: flash == composed (loss rel `0.0`, grad `4.4e-5`),
  PyTorch B>1 `7.9e-6`.
- **T15 · real GQA** ([`docs/14-gqa.md`](docs/14-gqa.md)) — `num_kv_heads < num_heads` via a new
  `repeat_kv` **broadcast op** that copies K/V `group = nh/num_kv` times to feed **both**
  (composed + flash) SDPA paths **unchanged**; its **backward is a deterministic group-sum**
  (no atomics) collapsing each kv head's query-head group. Gate: `repeat_kv` grad-check +
  **group=1 bit-identical to MHA** (regression guard); **xserv closed loop with real
  `num_key_value_heads`** token-identical.
- **T16 · gradient accumulation** ([`docs/15-grad-accum.md`](docs/15-grad-accum.md)) — N
  micro-steps scaled by `1/N` accumulate on the tape, then one AdamW step; DDP **all-reduces
  only at the accumulation boundary**. Decouples effective batch from activation memory: same
  effective batch 64, big-batch **27.7GB (OOM)** → accum 4×16 **7.2GB (−74%)**. Gate: `accum=N`
  ≡ one N× batch (grad `3.8e-5`); `accum=1` bit-identical.
- **T18 · dropout** ([`docs/17-dropout.md`](docs/17-dropout.md)) — a **stateless counter-based
  device RNG** (Philox-style bit-mix) → Bernoulli mask, inverted `1/(1−p)` scaling in train,
  identity in eval; wired at the two residual sites (attn-out, mlp-out). Stateless RNG is what
  makes it **compose bit-exactly with T13 activation recompute** — the backward re-run
  regenerates the *same* mask from `(seed, index)`. Gate: fixed-seed grad-check; **p=0
  bit-identical**.
- **T17 · process-per-GPU** ([`docs/16-process-per-gpu.md`](docs/16-process-per-gpu.md)) — a
  torchrun-style launcher: one worker process + CUDA context per GPU, the launcher mints one
  `ncclUniqueId` and **hex-injects it into each child's env** (no shared FS/TCP, no race); the
  worker reuses the T8 `train_rank` **unchanged**. Built and **correct** (proc vs thread loss
  `1.5e-7`, cross-rank `1.2e-7`, xserv md5 identical) — but **measured throughput-neutral**:
  8-GPU thread **491K (5.27×)** vs proc **493K (5.31×)**, `<1%`. This **falsifies** the
  long-standing KI-5/T11 hypothesis that thread-per-GPU's shared CUDA context caused the
  residual ~5×@8; with all 8 GPUs at 95–99% util, the residual is the **NCCL all-reduce + PCIe
  topology wall**, not context serialization. The third profile-first falsification (see below).
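
The online-softmax idea behind T14 can be sketched on the host for a single query row. This is an illustrative CPU version in plain Rust, not the CUDA kernel: `attention_naive`, `attention_online`, and `dot` are hypothetical names, and the `1/sqrt(d)` scaling is the usual SDPA convention assumed here. The streaming version keeps only a running max, a running denominator, and an unnormalized accumulator, so the full score row is never stored.

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Reference: materialize every score, softmax, then weight the values.
fn attention_naive(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let scores: Vec<f32> = keys.iter().map(|k| dot(q, k) * scale).collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let denom: f32 = exps.iter().sum();
    let mut out = vec![0.0; values[0].len()];
    for (e, v) in exps.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(v) {
            *o += e / denom * x;
        }
    }
    out
}

/// Streaming: visit KV in tiles of `tile` rows (`tile` must be > 0).
fn attention_online(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>], tile: usize) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let mut m = f32::NEG_INFINITY; // running max of the scores seen so far
    let mut l = 0.0f32; // running softmax denominator, relative to `m`
    let mut acc = vec![0.0f32; values[0].len()]; // unnormalized output
    for (k_tile, v_tile) in keys.chunks(tile).zip(values.chunks(tile)) {
        let scores: Vec<f32> = k_tile.iter().map(|k| dot(q, k) * scale).collect();
        let tile_max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let m_new = m.max(tile_max);
        // Rescale everything accumulated under the old max.
        let correction = (m - m_new).exp();
        l *= correction;
        for a in acc.iter_mut() {
            *a *= correction;
        }
        for (s, v) in scores.iter().zip(v_tile) {
            let p = (s - m_new).exp();
            l += p;
            for (a, x) in acc.iter_mut().zip(v) {
                *a += p * x;
            }
        }
        m = m_new;
    }
    acc.iter().map(|a| a / l).collect()
}
```

For any tile size the two functions should agree up to float rounding, which is the same shape of check as the "flash == composed" gate, at toy scale.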
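
The T15 `repeat_kv` forward and its group-sum backward are small enough to sketch directly. The version below is a hypothetical host-side stand-in, with heads stored as a `Vec` of rows rather than a strided GPU tensor, and it assumes the interleaved layout in which query head `h` reads kv head `h / group`.

```rust
/// Forward: copy each kv head `group` times, so query head `h`
/// reads kv head `h / group` (layout assumed for this sketch).
fn repeat_kv(kv: &[Vec<f32>], group: usize) -> Vec<Vec<f32>> {
    kv.iter()
        .flat_map(|head| std::iter::repeat(head.clone()).take(group))
        .collect()
}

/// Backward: the gradient of each kv head is the sum over its `group`
/// query-head copies. Each output is written by exactly one sequential
/// loop, so the summation order is fixed and no atomics are involved.
fn repeat_kv_backward(grad_out: &[Vec<f32>], group: usize) -> Vec<Vec<f32>> {
    grad_out
        .chunks(group)
        .map(|copies| {
            let mut sum = vec![0.0f32; copies[0].len()];
            for g in copies {
                for (s, x) in sum.iter_mut().zip(g) {
                    *s += x;
                }
            }
            sum
        })
        .collect()
}
```

With `group = 1` both functions reduce to a copy, which is the toy analogue of the "group=1 bit-identical to MHA" regression guard.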
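
The T16 equivalence (N micro-steps scaled by `1/N` versus one N× batch) can be seen on a one-parameter model. This sketch assumes a mean-reduced loss and equal-sized micro-batches, and uses a scalar least-squares fit in place of the transformer; `grad_mse` and `grad_accumulated` are hypothetical names, and the optimizer step is left out.

```rust
/// Gradient of the mean squared error of y ≈ w * x with respect to `w`.
fn grad_mse(w: f64, batch: &[(f64, f64)]) -> f64 {
    let sum: f64 = batch.iter().map(|(x, y)| 2.0 * (w * x - y) * x).sum();
    sum / batch.len() as f64
}

/// Accumulate over `accum_steps` equal micro-batches, scaling each
/// micro-gradient by 1/N before adding it to the running total.
fn grad_accumulated(w: f64, batch: &[(f64, f64)], accum_steps: usize) -> f64 {
    assert!(
        accum_steps > 0 && !batch.is_empty() && batch.len() % accum_steps == 0,
        "micro-batches must be non-empty and equal-sized"
    );
    let micro = batch.len() / accum_steps;
    let mut grad = 0.0;
    for chunk in batch.chunks(micro) {
        grad += grad_mse(w, chunk) / accum_steps as f64;
    }
    grad
}
```

Only one micro-batch's worth of activations is live at a time, which is where the memory saving comes from; the accumulated gradient matches the big-batch one up to float rounding, and `accum_steps = 1` reproduces it exactly.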
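
Why a stateless RNG composes with recompute (T18) is easiest to see in code. The sketch below is a host-side illustration, not the device kernel: it swaps in the SplitMix64 finalizer as the bit-mixer instead of the Philox-style mix the project uses, and `mix`, `keep`, and `dropout` are hypothetical names. The mask depends only on `(seed, index)`, so running the forward a second time regenerates it exactly.

```rust
/// Stateless mixer: maps (seed, index) to 64 pseudo-random bits.
/// SplitMix64 finalizer, used here as a stand-in (an assumption of this sketch).
fn mix(seed: u64, index: u64) -> u64 {
    let mut z = seed ^ index.wrapping_mul(0x9E37_79B9_7F4A_7C15);
    z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

/// true = keep the element. Depends only on (seed, index, p), never on call order.
fn keep(seed: u64, index: u64, p: f32) -> bool {
    // Top 24 bits give a uniform value in [0, 1) that f32 represents exactly.
    let u = (mix(seed, index) >> 40) as f32 / (1u32 << 24) as f32;
    u >= p
}

/// Inverted dropout: in training, zero with probability `p` and scale the
/// survivors by 1/(1-p); in eval (or with p = 0) return the input unchanged.
fn dropout(x: &[f32], p: f32, seed: u64, training: bool) -> Vec<f32> {
    if !training || p == 0.0 {
        return x.to_vec();
    }
    let scale = 1.0 / (1.0 - p);
    x.iter()
        .enumerate()
        .map(|(i, v)| if keep(seed, i as u64, p) { v * scale } else { 0.0 })
        .collect()
}
```

Calling `dropout` twice with the same seed gives identical output, which is the property a recompute pass relies on; a stateful generator would have advanced between the two calls.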

## The scaling study — v0 → v8

Same Qwen3-style architecture throughout; we scaled **dim** and **data** and read out val
loss (full per-run detail in [`docs/runs/`](docs/runs/)).

| ver | data (trained tok / epoch) | dim / core params | val loss | axis explored |
|---|---|---|---|---|
| v0–v3 | TinyStories (↑) | 32→512 / 41K→67M | 3.80 → 1.30 | bring-up |
| v4 | TinyStories 1.54ep | 768 / 127M | 1.17 | — |
| v5 | TinyStories 5.33ep | 768 / 127M | **1.11** | **data volume → saturates** |
| v6 | FineWeb-edu 1.02ep | 768 / 127M | 3.07\* | **corpus swap → graduates to real text** |
| v7 | FineWeb-edu 1.45ep | 768 / 127M | 3.01\* | same subset, more epochs → near-ceiling |
| **v8** | FineWeb-edu 1.05ep | **1024 / 226M** | **2.98\*** | **capacity → helps** |

\* FineWeb-edu val is a different (harder) distribution — **not comparable** to the
TinyStories val of v0–v5. Judge v6+ by sample quality + transfer, not the number.

### Three findings

1. **Data volume saturates.** TinyStories at dim768: 3.5× more tokens (v4→v5) bought only
   −5% val, curve flat. The narrow synthetic corpus is exhausted at this model size.
2. **Corpus > more-of-the-same.** Swapping TinyStories → FineWeb-edu (v5→v6) was a
   *qualitative* jump: the model went from only-writes-kid-stories to writing genuine
   historical/scientific expository prose. (Cost: TinyStories transfer val 1.11 → 2.75.)
3. **Capacity helps.** v8 (dim1024, ~1 epoch) beats both v6 (dim768, same epoch, by 0.085)
   and v7 (dim768, *more* data, by 0.035) → the dim768 runs were partly capacity-limited.

**Meta-finding:** every *single*-axis lever (data volume, corpus breadth, capacity) is now
worth only **~3%**. Per the Chinchilla lesson, further gains require scaling **data and
capacity together** — single-axis moves are exhausted.

## Efficiency — throughput & MFU

The throughput story is the perf-infra report card (RTX 5090, bf16/fp32):

| | v1 | v2 | v3 | v4 | v5 |
|---|---|---|---|---|---|
| tok/s | 3.3K (1 GPU) | 3.6K (4 GPU) | 26K (1 GPU) | 145K (8 GPU) | 217K (8 GPU) |
| **MFU** | 0.4% | 0.2% | 14% | 17% | 13% |
| enabled by | — | DDP (weak) | **batched (T10)** | **alloc (T11)** | **bf16 (T12)** |

v1/v2 ran at **<0.5% MFU** — the single-sequence design left the GPU idle (launch-bound).
**Batched forward (T10) was the single biggest unlock** (~35× MFU jump). 6ND is an accurate
FLOPs count, but predicting *time* needs the *realized* MFU, which varied ~40× across
versions — a fixed-MFU estimate is off by up to ~100× for the early launch-bound runs.
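
The 6ND arithmetic behind those MFU figures can be written down in a few lines. This is a sketch under stated assumptions: the function names are hypothetical, and the per-GPU peak FLOP/s is an input you have to supply for your own hardware and dtype; no RTX 5090 figure is assumed here, so the helpers will not reproduce the table above without the peak values that were used for it.

```rust
/// Training FLOPs under the 6ND rule: about 6 FLOPs per parameter per token.
fn train_flops(params: f64, tokens: f64) -> f64 {
    6.0 * params * tokens
}

/// Realized MFU: achieved FLOP/s divided by the aggregate hardware peak.
/// `peak_flops_per_gpu` is supplied by the caller.
fn realized_mfu(params: f64, tokens_per_sec: f64, peak_flops_per_gpu: f64, gpus: u32) -> f64 {
    train_flops(params, tokens_per_sec) / (peak_flops_per_gpu * gpus as f64)
}

/// Wall-clock estimate for a run, given the MFU you expect to realize.
fn est_seconds(params: f64, tokens: f64, mfu: f64, peak_flops_per_gpu: f64, gpus: u32) -> f64 {
    train_flops(params, tokens) / (mfu * peak_flops_per_gpu * gpus as f64)
}
```

Because `est_seconds` is inversely proportional to `mfu`, plugging in 14% for a run that actually realizes 0.4% underestimates the time by the same ~35× factor as the T10 jump, which is the reason a fixed-MFU estimate fails for the launch-bound runs.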

## Engineering lessons

- **Profile before optimizing.** *Three* "known" fixes were *falsified by measurement*: (1)
  "bigger batch fixes DDP scaling" (real cause: single-seq launch-bound → T10); (2) "bucket the
  all-reduce" (real cause: per-op `cudaMalloc` serialization → T11 caching allocator); and (3)
  "process-per-GPU would fix the residual ~5×@8" (T17 — built the torchrun-style launcher and
  measured it **throughput-neutral**: the residual is the NCCL/PCIe communication wall, not
  shared-context serialization). All three would have been no-ops; each got measured and either
  reverted or recorded as a deliberate negative result instead of shipped on faith.
- **Honest correctness.** QK-norm was *added* to match xserv's Qwen3 (not faked); every change
  kept a hard correctness gate, and **no tolerance was ever loosened to go green**. Phase 2 held
  the line: flash == composed SDPA (grads/PyTorch), GQA group=1 bit-identical to MHA, gradient
  accumulation `accum=1` bit-identical, dropout p=0 bit-identical *and* dropout × recompute
  bit-exact, the default path unchanged on every feature, and the **xserv closed-loop md5
  byte-identical (`b04fc9f9`) throughout both phases**.
- **The closed loop matters.** Exporting to xserv and checking token-identical greedy output
  caught real bugs and proved the whole stack end-to-end.

## Running it

Everything trains on a remote 8× RTX 5090 box; model artifacts live in a registry
(`tiny-models/v0…v8`). Serve any trained version in xserv:

```bash
# on the GPU box
cargo run -p xserv-model --release --bin xserv-cli -- <registry>/v8-fineweb-edu-dim1024 --max-tokens 100
# then type a prompt, e.g. In science,
```

Build/test the engine itself (CUDA compiles + runs on the GPU box; host-side `cargo check`
works anywhere via the `no_cuda` cfg):

```sh
export PATH=/usr/local/cuda/bin:$HOME/.cargo/bin:$PATH
cargo test --workspace # autograd grad-checks, PyTorch parity, DDP, etc.
```

## Doc index

- [`docs/evolution.md`](docs/evolution.md) — per-milestone changes across algorithm / architecture / infra / dataset.
- [`docs/runs/README.md`](docs/runs/README.md) — the v0–v8 comparison; [`docs/runs/0N-*.md`](docs/runs/) — per-run detail.
- [`docs/00-*` … `17-*`](docs/) — per-phase design docs (build chain → tensor → autograd → transformer → training → perf → distributed → export → batched → allocator → bf16 → recompute → flash-attention → GQA → grad-accum → process-per-GPU → dropout).
- [`docs/known-issues.md`](docs/known-issues.md) — perf backlog (KI-1/2/3/5 fixed; process-per-GPU CLOSED = measured no-op; KI-4 = accepted modeling tradeoff).