# xtrain
A from-scratch **Rust + CUDA** LLM **training** engine — the sibling of **xserv** (the
inference side). A learning project: hand-write the entire training-systems stack
(autograd → backward → optimizer → training loop → distributed → mixed precision →
gradient checkpointing), then use it to run a multi-version **scaling study** that maps
the data-vs-capacity frontier for a tiny model.
> **Status: complete.** From-scratch full stack (phases T1–T13) + a scaling ladder
> (v0–v8). Trains a Qwen3-compatible LM whose weights load into **xserv** and
> generate **token-identical** greedy output. This README is the capstone; per-topic detail
> lives in [`docs/`](docs/).
---
## What got built (from scratch, by hand)
Seven crates and no ML framework: cuBLAS, NCCL, and safetensors are the only deliberate
"heavy-lifting" borrows; everything else is hand-written CUDA + Rust:

| crate | what's hand-written |
|---|---|
| `xtrain-cuda` | CUDA Runtime FFI, RAII `GpuBuffer`, **caching/pool allocator**, cuBLAS (sgemm + bf16 GemmEx) bindings |
| `xtrain-tensor` | tensor (dtype/shape/strides/storage), elementwise + transpose + embedding kernels |
| `xtrain-autodiff` | **tape autograd engine** (grad accumulation), per-op backward, finite-diff grad-check, **checkpoint** (recompute) primitive |
| `xtrain-model` | tiny **Qwen3-style** transformer (RoPE + RMSNorm + QK-norm + SwiGLU), batched forward |
| `xtrain-optim` | hand-written **AdamW** (host + GPU kernels) |
| `xtrain-train` | training loop, LR schedule, grad clip, checkpoint, BPE corpus + cache, samplers, safetensors export |
| `xtrain-distributed` | **NCCL DDP** (thread-per-GPU, all-reduce) |

Every op's backward is verified against **finite differences** and against **PyTorch**
(forward + per-parameter grads, batch > 1). Trained weights export to Hugging Face safetensors
and load into xserv (Qwen3, BF16), producing token-identical greedy output: the closed loop.
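
To make the grad-check idea concrete, here is a minimal CPU sketch, assuming scalar `f64` values instead of GPU tensors. A tape records each op with its local partials, a reverse sweep accumulates gradients with `+=` (which is why `x`, used twice, still gets the right gradient), and a central finite difference cross-checks the result. `Tape`, `grad_check`, and the op set are hypothetical names for illustration, not the `xtrain-autodiff` API.

```rust
/// One recorded op: the tape indices it read and its local partial derivatives.
#[derive(Clone, Copy)]
struct Node {
    parents: [usize; 2],
    partials: [f64; 2],
}

/// Hypothetical scalar tape (illustrative; not the xtrain-autodiff API).
struct Tape {
    nodes: Vec<Node>,
    values: Vec<f64>,
}

impl Tape {
    fn new() -> Self {
        Tape { nodes: Vec::new(), values: Vec::new() }
    }

    fn push(&mut self, value: f64, parents: [usize; 2], partials: [f64; 2]) -> usize {
        self.nodes.push(Node { parents, partials });
        self.values.push(value);
        self.values.len() - 1
    }

    /// Leaf: points at itself with zero partials, so backward adds nothing to it.
    fn var(&mut self, value: f64) -> usize {
        let id = self.values.len();
        self.push(value, [id, id], [0.0, 0.0])
    }

    fn mul(&mut self, a: usize, b: usize) -> usize {
        let (va, vb) = (self.values[a], self.values[b]);
        self.push(va * vb, [a, b], [vb, va])
    }

    fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.values[a] + self.values[b];
        self.push(v, [a, b], [1.0, 1.0])
    }

    fn tanh(&mut self, a: usize) -> usize {
        let t = self.values[a].tanh();
        self.push(t, [a, a], [1.0 - t * t, 0.0])
    }

    /// Reverse sweep: visit ops newest-first and accumulate (+=) into parents.
    fn backward(&self, out: usize) -> Vec<f64> {
        let mut grads = vec![0.0; self.values.len()];
        grads[out] = 1.0;
        for i in (0..=out).rev() {
            let (g, node) = (grads[i], self.nodes[i]);
            for k in 0..2 {
                grads[node.parents[k]] += node.partials[k] * g;
            }
        }
        grads
    }
}

/// f(x, y) = tanh(x*y + x*x); x appears twice, exercising grad accumulation.
fn f(x: f64, y: f64) -> f64 {
    (x * y + x * x).tanh()
}

/// Central finite difference, O(h^2) truncation error.
fn finite_diff(g: impl Fn(f64) -> f64, p: f64, h: f64) -> f64 {
    (g(p + h) - g(p - h)) / (2.0 * h)
}

/// Returns (tape df/dx, numeric df/dx, tape df/dy, numeric df/dy).
fn grad_check(x: f64, y: f64) -> (f64, f64, f64, f64) {
    let mut tape = Tape::new();
    let (xv, yv) = (tape.var(x), tape.var(y));
    let xy = tape.mul(xv, yv);
    let xx = tape.mul(xv, xv);
    let s = tape.add(xy, xx);
    let out = tape.tanh(s);
    let grads = tape.backward(out);
    let h = 1e-5;
    let fd_x = finite_diff(|p| f(p, y), x, h);
    let fd_y = finite_diff(|p| f(x, p), y, h);
    (grads[xv], fd_x, grads[yv], fd_y)
}

fn main() {
    let (ax, nx, ay, ny) = grad_check(0.3, -0.7);
    println!("df/dx: tape {:.10}, finite-diff {:.10}", ax, nx);
    println!("df/dy: tape {:.10}, finite-diff {:.10}", ay, ny);
    assert!((ax - nx).abs() < 1e-8 && (ay - ny).abs() < 1e-8);
}
```

The central difference is used because its error shrinks as h², so a tolerance around 1e-8 separates a correct backward from a wrong one at `f64` precision.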
## The build journey: phases T1–T13
Each phase: design doc + implementation + tests + a scoped commit (see [`docs/`](docs/) and
[`docs/evolution.md`](docs/evolution.md) for the per-axis changelog).

| phase | what | result |
|---|---|---|
| T1–T2 | Rust↔CUDA build chain · tensor abstraction | vector-add verified · roundtrip |
| T3–T4 | hand GEMM fwd/bwd + finite-diff · **tape autograd** + 11 op backwards | grads vs cuBLAS 1e-7 / finite-diff |
| T5 | tiny transformer (RoPE+RMSNorm+SwiGLU) | overfit + **PyTorch parity** |
| T6 | AdamW + training loop + checkpoint · GPT-2 BPE + TinyStories | first **coherent English** |
| T7 | cuBLAS + GPU optimizer + drop syncs | ~3× (2.7K→8.5K tok/s) |
| T8 | NCCL DDP | multi-GPU (scaling still poor at this stage) |
| T9 | + per-head **QK-norm** (Qwen3-compat) + safetensors export | **xserv closed loop, token-identical** |
| **T10** | **batched multi-sequence forward** (fixes KI-1) | **single-GPU 15–24×**; MFU 0.4%→14% |
| **T11** | **device caching allocator** (fixes KI-5) | single-GPU 2.3×; **8-GPU 461K tok/s** |
| **T12** | **bf16 mixed precision** (fp32 master, fixes KI-2) | dim768 OOM solved; 29% mem |
| **T13** | **activation recompute** / checkpointing (fixes KI-3) | dim1024 fits; grads bit-identical |

The four performance fixes (T10–T13) each removed a real bottleneck; see
[`docs/known-issues.md`](docs/known-issues.md).
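
The T11 idea can be sketched on the host side. This assumes a simple size-bucketed design (requests rounded up to a power of two) and uses a counter where a real allocator would call `cudaMalloc`; `CachingAllocator`, `Block`, and the bucketing policy are hypothetical, not the `xtrain-cuda` implementation. Freed blocks return to a per-bucket free list instead of the driver, so in a steady-state loop like the simulated one, every step after the first is served from the cache.

```rust
use std::collections::HashMap;

/// A cached device block; the id stands in for a real device pointer.
struct Block {
    id: usize,
    size: usize,
}

/// Hypothetical size-bucketed caching allocator (illustrative; not the xtrain-cuda API).
struct CachingAllocator {
    free_lists: HashMap<usize, Vec<Block>>, // bucket size -> cached free blocks
    next_id: usize,
    raw_allocs: usize, // how many times the "driver" (cudaMalloc) would be called
}

impl CachingAllocator {
    fn new() -> Self {
        CachingAllocator { free_lists: HashMap::new(), next_id: 0, raw_allocs: 0 }
    }

    /// Round up to a power of two so nearby sizes share a bucket.
    fn bucket(size: usize) -> usize {
        size.max(1).next_power_of_two()
    }

    fn alloc(&mut self, size: usize) -> Block {
        let b = Self::bucket(size);
        if let Some(block) = self.free_lists.get_mut(&b).and_then(|list| list.pop()) {
            return block; // cache hit: no driver call
        }
        // Cache miss: a real allocator would call cudaMalloc here.
        self.raw_allocs += 1;
        self.next_id += 1;
        Block { id: self.next_id, size: b }
    }

    /// Keep the block for reuse instead of calling cudaFree.
    fn free(&mut self, block: Block) {
        self.free_lists.entry(block.size).or_default().push(block);
    }
}

fn main() {
    let mut pool = CachingAllocator::new();
    // 100 simulated steps, each allocating then freeing the same three activations.
    for _step in 0..100 {
        let acts: Vec<Block> = [4096, 3000, 1 << 20].iter().map(|&s| pool.alloc(s)).collect();
        for a in acts {
            pool.free(a);
        }
    }
    println!("raw allocations: {}", pool.raw_allocs);
}
```

Here only the three first-step requests miss, so 300 allocations cost 3 driver calls. Power-of-two buckets trade up to roughly 2× internal fragmentation for simplicity and a high hit rate; finer size classes would waste less memory.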
## The scaling study — v0 → v8
Same Qwen3-style architecture throughout; we scaled **dim** and **data** and measured
validation loss (full per-run detail in [`docs/runs/`](docs/runs/)).

| ver | data (corpus, epochs trained) | dim / core params | val loss | axis explored |
|---|---|---|---|---|
| v0–v3 | TinyStories (↑) | 32→512 / 41K→67M | 3.80 → 1.30 | bring-up |
| v4 | TinyStories 1.54ep | 768 / 127M | 1.17 | — |
| v5 | TinyStories 5.33ep | 768 / 127M | **1.11** | **data volume → saturates** |
| v6 | FineWeb-edu 1.02ep | 768 / 127M | 3.07\* | **corpus swap → graduates to real text** |
| v7 | FineWeb-edu 1.45ep | 768 / 127M | 3.01\* | same subset, more epochs → near-ceiling |
| **v8** | FineWeb-edu 1.05ep | **1024 / 226M** | **2.98\*** | **capacity → helps** |

\* FineWeb-edu val is a different (harder) distribution, so it is **not comparable** to the
TinyStories val of v0–v5. Judge v6+ by sample quality + transfer, not the number.
### Three findings
1. **Data volume saturates.** TinyStories at dim768: 3.5× more tokens (v4→v5) improved val
   loss by only 5% (1.17 → 1.11), and the curve flattened. The narrow synthetic corpus is
   exhausted at this model size.
2. **Corpus > more-of-the-same.** Swapping TinyStories → FineWeb-edu (v5→v6) was a
   *qualitative* jump: the model went from writing only kids' stories to writing genuine
   historical/scientific expository prose. (Cost: TinyStories transfer val 1.11 → 2.75.)
3. **Capacity helps.** v8 (dim1024, ~1 epoch) beats both v6 (dim768, also ~1 epoch, by 0.085)
   and v7 (dim768, *more* data, by 0.035), so the dim768 runs were partly capacity-limited.

**Meta-finding:** every *single*-axis lever (data volume, corpus breadth, capacity) is now
worth only **~3%**. Per the Chinchilla lesson, further gains require scaling **data and
capacity together**; single-axis moves are exhausted.
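
For a sense of scale, the commonly cited Chinchilla heuristic puts the compute-optimal budget at roughly 20 training tokens per parameter, with training cost C ≈ 6·N·D FLOPs. A small sketch applying that heuristic to the two model sizes in the table, taking the table's core-param counts as N (an assumption); the 20:1 ratio is an approximation, not an xtrain result.

```rust
/// Approximate Chinchilla compute-optimal ratio (training tokens per parameter).
const TOKENS_PER_PARAM: f64 = 20.0;

/// Compute-optimal token budget: D ≈ 20 · N.
fn chinchilla_tokens(params: f64) -> f64 {
    TOKENS_PER_PARAM * params
}

/// Training cost: C ≈ 6 · N · D FLOPs (forward + backward).
fn train_flops(params: f64, tokens: f64) -> f64 {
    6.0 * params * tokens
}

fn main() {
    // Core-param counts from the scaling table (dim768 = v4-v7, dim1024 = v8).
    for (name, n) in [("dim768", 127e6), ("dim1024", 226e6)] {
        let d = chinchilla_tokens(n);
        println!(
            "{}: N = {:.0}M, D_opt ~ {:.2}B tokens, C ~ {:.2e} FLOPs",
            name,
            n / 1e6,
            d / 1e9,
            train_flops(n, d)
        );
    }
}
```

This gives about 2.54B tokens for the 127M model and 4.52B for the 226M model: the data target grows in proportion to capacity, which is the sense in which the two must be scaled together.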
## Efficiency — throughput & MFU
The throughput story is the perf-infra report card (RTX 5090, bf16/fp32):

| | v1 | v2 | v3 | v4 | v5 |
|---|---|---|---|---|---|
| tok/s | 3.3K (1 GPU) | 3.6K (4 GPU) | 26K (1 GPU) | 145K (8 GPU) | 217K (8 GPU) |
| **MFU** | 0.4% | 0.2% | 14% | 17% | 13% |
| enabled by | — | DDP (weak) | **batched (T10)** | **alloc (T11)** | **bf16 (T12)** |

v1/v2 ran at **<0.5% MFU**: the single-sequence design left the GPU idle (launch-bound).
**Batched forward (T10) was the single biggest unlock** (~35× MFU jump). The 6ND rule
(training FLOPs ≈ 6 × parameters N × tokens D) gives an accurate FLOPs count, but predicting
*time* needs the *realized* MFU, which varied ~40× across versions, so a fixed-MFU estimate
is off by up to ~100× for the early launch-bound runs.
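
The arithmetic can be sketched as two small functions, assuming MFU is defined as achieved model FLOP/s (6·N per token) over the aggregate peak FLOP/s of the GPUs used. The 100M-parameter model and the 1e14 FLOP/s peak are placeholders, not RTX 5090 specifications or figures from these runs; only the 0.2% and 17% MFU values come from the table.

```rust
/// MFU: achieved model FLOP/s (6 · N per token) over aggregate peak FLOP/s.
fn mfu(params: f64, tokens_per_sec: f64, gpus: f64, peak_per_gpu: f64) -> f64 {
    6.0 * params * tokens_per_sec / (gpus * peak_per_gpu)
}

/// Wall-clock estimate: 6 · N · D total FLOPs over the *realized* FLOP/s.
fn train_seconds(params: f64, tokens: f64, gpus: f64, peak_per_gpu: f64, util: f64) -> f64 {
    6.0 * params * tokens / (gpus * peak_per_gpu * util)
}

fn main() {
    // Placeholders (not measured): 100M params, 1e14 FLOP/s peak per GPU.
    let (n, peak) = (100e6, 1e14);
    println!("MFU at 20K tok/s on 1 GPU: {:.1}%", 100.0 * mfu(n, 20_000.0, 1.0, peak));

    // Same 1B-token job at the table's lowest (0.2%) and highest (17%) MFU.
    let slow = train_seconds(n, 1e9, 1.0, peak, 0.002);
    let fast = train_seconds(n, 1e9, 1.0, peak, 0.17);
    println!("{:.0} s vs {:.0} s ({:.0}x apart)", slow, fast, slow / fast);
}
```

Because time scales as 1/MFU, assuming the best run's 17% for a job that actually realized 0.2% underestimates its wall time by 17 / 0.2 = 85×, the order of the ~100× error quoted above.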
## Engineering lessons
- **Profile before optimizing.** Two "known" perf fixes were *falsified by measurement*
before being shipped: "bigger batch fixes DDP scaling" (real cause: the single-sequence,
launch-bound forward, fixed by T10 batching) and "bucket the all-reduce" (real cause: per-op
`cudaMalloc` serialization, fixed by the T11 caching allocator). Both would have been no-ops;
both were reverted and re-diagnosed instead of shipped.
- **Honest correctness.** QK-norm was *added* to match xserv's Qwen3 (not faked); every perf
change kept a hard correctness gate (recompute grads bit-identical; bf16 keeps the fp32
path untouched; the full grad-check / PyTorch / DDP / xserv suite must stay green).
- **The closed loop matters.** Exporting to xserv and checking token-identical greedy output
caught real bugs and proved the whole stack end-to-end.
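
The recompute correctness gate can be illustrated with segment-wise checkpointing on a hypothetical scalar chain `h[i+1] = tanh(w[i] * h[i])` (loss = last activation). The baseline stores every activation; the checkpointed version keeps only each k-layer segment's input and reruns that segment's forward during backward. Because the recomputation executes the same floating-point ops in the same order, `assert_eq!` (exact equality, no tolerance) holds. The function names are illustrative, not xtrain's `checkpoint` primitive.

```rust
/// Layer i: h[i+1] = tanh(w[i] * h[i]).
fn layer(w: f64, h: f64) -> f64 {
    (w * h).tanh()
}

/// Backward through one layer; returns (dL/dw, dL/dh_in).
fn layer_backward(w: f64, h_in: f64, h_out: f64, g_out: f64) -> (f64, f64) {
    let dz = g_out * (1.0 - h_out * h_out); // tanh'(z) = 1 - tanh(z)^2
    (dz * h_in, dz * w)
}

/// Baseline: keep all L + 1 activations from the forward pass.
fn grads_stored(w: &[f64], x: f64) -> Vec<f64> {
    let mut acts = vec![x];
    for &wi in w {
        let h = layer(wi, *acts.last().unwrap());
        acts.push(h);
    }
    let (mut dw, mut g) = (vec![0.0; w.len()], 1.0); // loss = last activation
    for i in (0..w.len()).rev() {
        let (dwi, gi) = layer_backward(w[i], acts[i], acts[i + 1], g);
        dw[i] = dwi;
        g = gi;
    }
    dw
}

/// Checkpointed: keep only each k-layer segment's input; recompute inside backward.
/// Returns the grads and how many activations survived the forward pass.
fn grads_checkpointed(w: &[f64], x: f64, k: usize) -> (Vec<f64>, usize) {
    assert!(k > 0, "segment length must be positive");
    let mut checkpoints = Vec::new();
    let mut h = x;
    for (i, &wi) in w.iter().enumerate() {
        if i % k == 0 {
            checkpoints.push(h); // segment boundary: save the input
        }
        h = layer(wi, h); // interior activations are discarded
    }
    let (mut dw, mut g) = (vec![0.0; w.len()], 1.0);
    for (s, &start_h) in checkpoints.iter().enumerate().rev() {
        let (start, end) = (s * k, ((s + 1) * k).min(w.len()));
        // Recompute this segment's activations from its checkpoint.
        let mut acts = vec![start_h];
        for i in start..end {
            let next = layer(w[i], *acts.last().unwrap());
            acts.push(next);
        }
        // Identical backward steps to the baseline, in the same order.
        for i in (start..end).rev() {
            let (dwi, gi) = layer_backward(w[i], acts[i - start], acts[i - start + 1], g);
            dw[i] = dwi;
            g = gi;
        }
    }
    (dw, checkpoints.len())
}

fn main() {
    let w: Vec<f64> = (0..12).map(|i| 0.5 + 0.1 * i as f64).collect();
    let (ckpt, kept) = grads_checkpointed(&w, 0.8, 4);
    // Exact equality: same ops, same order, same inputs => same bits.
    assert_eq!(grads_stored(&w, 0.8), ckpt);
    println!("grads bit-identical; kept {} of {} activations", kept, w.len() + 1);
}
```

Peak activation memory falls from L + 1 values to about L/k checkpoints plus one segment of k + 1, so choosing k ≈ √L gives the familiar O(√L) memory at the cost of one extra forward pass.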
## Running it
Everything trains on a remote 8× RTX 5090 box; model artifacts live in a registry
(`tiny-models/v0…v8`). Serve any trained version in xserv:
```bash
# on the GPU box
cargo run -p xserv-model --release --bin xserv-cli -- <registry>/v8-fineweb-edu-dim1024 --max-tokens 100
# then type a prompt, e.g. "In science,"
```
Build/test the engine itself (CUDA compiles + runs on the GPU box; host-side `cargo check`
works anywhere via the `no_cuda` cfg):
```sh
export PATH=/usr/local/cuda/bin:$HOME/.cargo/bin:$PATH
cargo test --workspace # autograd grad-checks, PyTorch parity, DDP, etc.
```
## Doc index
- [`docs/evolution.md`](docs/evolution.md): per-milestone changes across algorithm / architecture / infra / dataset.
- [`docs/runs/README.md`](docs/runs/README.md): the v0–v8 comparison; per-run detail in [`docs/runs/0N-*.md`](docs/runs/).
- [`docs/00-*` … `12-*`](docs/): per-phase design docs (build chain → tensor → autograd → transformer → training → perf → distributed → export → batched → allocator → bf16 → recompute).
- [`docs/known-issues.md`](docs/known-issues.md): perf backlog (KI-1/2/3/5 fixed; KI-4 + process-per-GPU open).