

# xtrain
A from-scratch **Rust + CUDA** LLM **training** engine — the sibling of **xserv** (the
inference side). A learning project: hand-write the entire training-systems stack
(autograd → backward → optimizer → training loop → distributed → mixed precision →
gradient checkpointing), then use it to run a multi-version **scaling study** that maps
the data-vs-capacity frontier for a tiny model.
> **Status: complete.** From-scratch full stack (phases T1–T13) + an 8-version scaling
> ladder (v0–v8). Trains a Qwen3-compatible LM whose weights load into **xserv** and
> generate **token-identical** output. This README is the capstone; per-topic detail
> lives in [`docs/`](docs/).
---
## What got built (from scratch, by hand)
7 crates, no ML framework — only cuBLAS / NCCL / safetensors as deliberate "heavy-lifting"
borrows, the rest hand-written CUDA + Rust:
| crate | what's hand-written |
|---|---|
| `xtrain-cuda` | CUDA Runtime FFI, RAII `GpuBuffer`, **caching/pool allocator**, cuBLAS (sgemm + bf16 GemmEx) bindings |
| `xtrain-tensor` | tensor (dtype/shape/strides/storage), elementwise + transpose + embedding kernels |
| `xtrain-autodiff` | **tape autograd engine** (grad accumulation), per-op backward, finite-diff grad-check, **checkpoint** (recompute) primitive |
| `xtrain-model` | tiny **Qwen3-style** transformer (RoPE + RMSNorm + QK-norm + SwiGLU), batched forward |
| `xtrain-optim` | hand-written **AdamW** (host + GPU kernels) |
| `xtrain-train` | training loop, LR schedule, grad clip, checkpoint, BPE corpus + cache, samplers, safetensors export |
| `xtrain-distributed` | **NCCL DDP** (thread-per-GPU, all-reduce) |

Every op's backward is verified against **finite differences** and against **PyTorch**
(forward + per-parameter grads, batch > 1). Trained weights export to HF-safetensors and
load into xserv (Qwen3, BF16) producing token-identical greedy output — the closed loop.
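
The finite-difference check mentioned above can be sketched on the host in a few lines. This is a minimal illustration, not the code in `xtrain-autodiff`: the function names, the toy loss, and the f64 precision are all assumptions made for the example.

```rust
/// Central-difference estimate of d f / d x[i] for every coordinate i.
fn numeric_grad(f: &dyn Fn(&[f64]) -> f64, x: &[f64], eps: f64) -> Vec<f64> {
    let mut probe = x.to_vec();
    let mut grad = Vec::with_capacity(x.len());
    for i in 0..x.len() {
        let orig = probe[i];
        probe[i] = orig + eps;
        let plus = f(&probe);
        probe[i] = orig - eps;
        let minus = f(&probe);
        probe[i] = orig; // restore before moving to the next coordinate
        grad.push((plus - minus) / (2.0 * eps));
    }
    grad
}

/// Largest relative error between an analytic and a numeric gradient.
fn max_rel_err(analytic: &[f64], numeric: &[f64]) -> f64 {
    analytic
        .iter()
        .zip(numeric)
        .map(|(a, n)| (a - n).abs() / (a.abs() + n.abs()).max(1e-12))
        .fold(0.0_f64, f64::max)
}

/// Toy scalar loss standing in for a real op (hypothetical): sum_i x_i^2 * sin(x_i).
fn loss(x: &[f64]) -> f64 {
    x.iter().map(|v| v * v * v.sin()).sum()
}

/// Hand-derived backward for `loss`: 2x sin(x) + x^2 cos(x).
fn loss_backward(x: &[f64]) -> Vec<f64> {
    x.iter().map(|v| 2.0 * v * v.sin() + v * v * v.cos()).collect()
}
```

A check then compares `loss_backward(&x)` with `numeric_grad(&loss, &x, 1e-5)` and fails if `max_rel_err` exceeds a tolerance. A wrong backward (for example, one scaled by 1.1) shows up as a relative error several orders of magnitude above that tolerance.
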
## The build journey: phases T1–T13
Each phase: design doc + implementation + tests + a scoped commit (see [`docs/`](docs/) and
[`docs/evolution.md`](docs/evolution.md) for the per-axis changelog).
| phase | what | result |
|---|---|---|
| T1–T2 | Rust↔CUDA build chain · tensor abstraction | vector-add verified · roundtrip |
| T3–T4 | hand GEMM fwd/bwd + finite-diff · **tape autograd** + 11 op backwards | grads vs cuBLAS 1e-7 / finite-diff |
| T5 | tiny transformer (RoPE+RMSNorm+SwiGLU) | overfit + **PyTorch parity** |
| T6 | AdamW + training loop + checkpoint · GPT-2 BPE + TinyStories | first **coherent English** |
| T7 | cuBLAS + GPU optimizer + drop syncs | ~3× (2.7K→8.5K tok/s) |
| T8 | NCCL DDP | multi-GPU (scaling still weak at this stage) |
| T9 | + per-head **QK-norm** (Qwen3-compat) + safetensors export | **xserv closed loop, token-identical** |
| **T10** | **batched multi-sequence forward** (fixes KI-1) | **single-GPU 15–24×**; MFU 0.4%→14% |
| **T11** | **device caching allocator** (fixes KI-5) | single-GPU 2.3×; **8-GPU 461K tok/s** |
| **T12** | **bf16 mixed precision** (fp32 master, fixes KI-2) | dim768 OOM solved; -29% mem |
| **T13** | **activation recompute** / checkpointing (fixes KI-3) | dim1024 fits; grads bit-identical |
| **T14** | **fused flash-attention** kernel (online softmax, no materialized N×N; opt-in `--flash`) | peak mem -16%@1k / -23%@2k seq; flash==composed (grads/PyTorch) |
| **T15** | **grouped-query attention** (`num_kv_heads<num_heads`; `repeat_kv` broadcast feeds both SDPA paths; backward sums each kv head's group; `--kv-heads`) | repeat_kv grad-check + **group=1 bit-identical to MHA**; GQA flash==composed; PyTorch GQA B>1; **xserv closed loop with real `num_key_value_heads`** token-identical |
| **T16** | **gradient accumulation** (`--accum-steps`; DDP all-reduces only at the boundary) | equiv to N× big batch (grad 3.8e-5); same effective-64 batch 27.7GB→7.2GB (-74%) |
| **T18** | **dropout** (hand counter-based device RNG + mask, inverted scaling, train/eval switch) | fixed-seed grad-check; **p=0 bit-identical**; recompute-safe |
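
The online softmax named in the T14 row can be illustrated on the host. The sketch below is not the project's CUDA kernel: it assumes a single query, scalar values (head_dim = 1), and hypothetical function names, and only shows why the full score row never has to be materialized.

```rust
/// Two-pass reference: softmax(scores) · values with every score kept around.
fn attend_reference(scores: &[f32], values: &[f32]) -> f32 {
    let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let denom: f32 = weights.iter().sum();
    weights.iter().zip(values).map(|(w, v)| w * v).sum::<f32>() / denom
}

/// One-pass online softmax: consume (score, value) pairs tile by tile, keeping
/// only a running max `m`, a running denominator `l`, and a running weighted
/// sum `acc`. `tile` must be greater than zero.
fn attend_online(scores: &[f32], values: &[f32], tile: usize) -> f32 {
    let (mut m, mut l, mut acc) = (f32::NEG_INFINITY, 0.0f32, 0.0f32);
    for (s_tile, v_tile) in scores.chunks(tile).zip(values.chunks(tile)) {
        let tile_max = s_tile.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let m_new = m.max(tile_max);
        // Rescale everything accumulated under the old max to the new max.
        let scale = (m - m_new).exp();
        l *= scale;
        acc *= scale;
        for (s, v) in s_tile.iter().zip(v_tile) {
            let p = (s - m_new).exp();
            l += p;
            acc += p * v;
        }
        m = m_new;
    }
    acc / l
}
```

For finite scores, both functions agree up to floating-point rounding for any tile size, which is the same kind of "flash == composed" equivalence the table reports for the real kernel.
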

The four performance fixes (T10–T13) each removed a real bottleneck; see
[`docs/known-issues.md`](docs/known-issues.md). **Phase 2 (systems-stack depth, T14 onward)**
revisits deferred training-stack features and hand-writes them:
- **T14**: the fused flash-attention kernel ([`docs/13-flash-attention.md`](docs/13-flash-attention.md)).
- **T15**: real grouped-query attention ([`docs/14-gqa.md`](docs/14-gqa.md)): `num_kv_heads < num_heads` via a `repeat_kv` broadcast op whose backward sums each kv head's query-head group. It feeds both SDPA paths unchanged, and the default MHA path stays bit-identical.
- **T16**: micro-batch gradient accumulation ([`docs/15-grad-accum.md`](docs/15-grad-accum.md)), which decouples the effective batch from activation memory (memory tracks the micro-batch, not N×).
- **T18**: dropout ([`docs/17-dropout.md`](docs/17-dropout.md)): hand counter-based device RNG + mask, inverted scaling, train/eval switch.
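
The `repeat_kv` broadcast and its summing backward, as described for T15, can be sketched in plain Rust. This is an illustration under assumed conventions, not the project's kernel: it handles one token, lays the tensor out as `[num_kv_heads, head_dim]`, maps query head `h` to kv head `h / group`, and uses a hypothetical name for the backward function.

```rust
/// Forward: broadcast each kv head to its `group` query heads.
/// `kv` is [num_kv_heads, head_dim]; the result is [num_kv_heads * group, head_dim].
fn repeat_kv(kv: &[f32], num_kv_heads: usize, head_dim: usize, group: usize) -> Vec<f32> {
    assert_eq!(kv.len(), num_kv_heads * head_dim);
    let mut out = Vec::with_capacity(kv.len() * group);
    for kv_head in 0..num_kv_heads {
        let row = &kv[kv_head * head_dim..(kv_head + 1) * head_dim];
        for _ in 0..group {
            out.extend_from_slice(row);
        }
    }
    out
}

/// Backward: each kv head receives the sum of the gradients of its query-head group.
fn repeat_kv_backward(
    grad_out: &[f32],
    num_kv_heads: usize,
    head_dim: usize,
    group: usize,
) -> Vec<f32> {
    assert_eq!(grad_out.len(), num_kv_heads * group * head_dim);
    let mut grad_kv = vec![0.0f32; num_kv_heads * head_dim];
    for q_head in 0..num_kv_heads * group {
        let kv_head = q_head / group;
        for d in 0..head_dim {
            grad_kv[kv_head * head_dim + d] += grad_out[q_head * head_dim + d];
        }
    }
    grad_kv
}
```

With `group = 1` both functions reduce to a copy, which is the property behind the "group=1 bit-identical to MHA" gate in the table.
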
## The scaling study — v0 → v8
Same Qwen3-style architecture throughout; we scaled **dim** and **data** and read out val
loss (full per-run detail in [`docs/runs/`](docs/runs/)).
| ver | data (trained tok / epoch) | dim / core params | val loss | axis explored |
|---|---|---|---|---|
| v0–v3 | TinyStories (↑) | 32→512 / 41K→67M | 3.80 → 1.30 | bring-up |
| v4 | TinyStories 1.54ep | 768 / 127M | 1.17 | — |
| v5 | TinyStories 5.33ep | 768 / 127M | **1.11** | **data volume → saturates** |
| v6 | FineWeb-edu 1.02ep | 768 / 127M | 3.07\* | **corpus swap → graduates to real text** |
| v7 | FineWeb-edu 1.45ep | 768 / 127M | 3.01\* | same subset, more epochs → near-ceiling |
| **v8** | FineWeb-edu 1.05ep | **1024 / 226M** | **2.98\*** | **capacity → helps** |

\* FineWeb-edu val is a different (harder) distribution and is **not comparable** to the
TinyStories val of v0–v5. Judge v6+ by sample quality + transfer, not the number.
### Three findings
1. **Data volume saturates.** TinyStories at dim768: 3.5× more tokens (v4→v5) bought only
a 5% drop in val loss (1.17 → 1.11), with the curve flat. The narrow synthetic corpus is exhausted at this model size.
2. **Corpus > more-of-the-same.** Swapping TinyStories → FineWeb-edu (v5→v6) was a
*qualitative* jump: the model went from only-writes-kid-stories to writing genuine
historical/scientific expository prose. (Cost: TinyStories transfer val 1.11 → 2.75.)
3. **Capacity helps.** v8 (dim1024, ~1 epoch) beats both v6 (dim768, same epoch, by 0.085)
and v7 (dim768, *more* data, by 0.035) → the dim768 runs were partly capacity-limited.

**Meta-finding:** every *single*-axis lever (data volume, corpus breadth, capacity) is now
worth only **~3%**. Per the Chinchilla lesson, further gains require scaling **data and
capacity together**; single-axis moves are exhausted.
## Efficiency — throughput & MFU
The throughput story is the perf-infra report card (RTX 5090, bf16/fp32):
| | v1 | v2 | v3 | v4 | v5 |
|---|---|---|---|---|---|
| tok/s | 3.3K (1 GPU) | 3.6K (4 GPU) | 26K (1 GPU) | 145K (8 GPU) | 217K (8 GPU) |
| **MFU** | 0.4% | 0.2% | 14% | 17% | 13% |
| enabled by | — | DDP (weak) | **batched (T10)** | **alloc (T11)** | **bf16 (T12)** |

v1/v2 ran at **<0.5% MFU**: the single-sequence design left the GPU idle (launch-bound).
**Batched forward (T10) was the single biggest unlock** (~35× MFU jump). 6ND is an accurate
FLOPs count, but predicting *time* needs the *realized* MFU, which varied ~40× across
versions; a fixed-MFU estimate is off by up to ~100× for the early launch-bound runs.
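
The relationship between 6ND, throughput, and MFU can be written out as a small sketch. The function names are hypothetical, and `peak_flops_per_gpu` is an assumed input: this README does not state the peak figure used for the table above.

```rust
/// Training FLOPs per token under the 6ND rule: 6 * N, with N = parameter count.
fn train_flops_per_token(params: f64) -> f64 {
    6.0 * params
}

/// Realized MFU = achieved FLOP/s divided by aggregate peak FLOP/s.
fn mfu(params: f64, tokens_per_sec: f64, num_gpus: u32, peak_flops_per_gpu: f64) -> f64 {
    train_flops_per_token(params) * tokens_per_sec / (num_gpus as f64 * peak_flops_per_gpu)
}

/// Wall-clock seconds to train on `tokens` tokens at a given realized MFU.
fn train_seconds(
    params: f64,
    tokens: f64,
    num_gpus: u32,
    peak_flops_per_gpu: f64,
    realized_mfu: f64,
) -> f64 {
    train_flops_per_token(params) * tokens
        / (num_gpus as f64 * peak_flops_per_gpu * realized_mfu)
}
```

With illustrative inputs (not measurements from these runs), a 1e8-parameter model at 1e5 tok/s on one GPU with an assumed peak of 1e14 FLOP/s gives an MFU of 0.6. Because `train_seconds` divides by the realized MFU, plugging in an MFU that is 100× too optimistic yields a time estimate that is 100× too short.
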
## Engineering lessons
- **Profile before optimizing.** Two "known" perf fixes were *falsified by measurement*
before being shipped: "bigger batch fixes DDP scaling" (real cause: single-seq
launch-bound execution, fixed by T10) and "bucket the all-reduce" (real cause: per-op `cudaMalloc`
serialization, fixed by the T11 caching allocator). Both would have been no-ops; both got reverted +
re-diagnosed instead of shipped.
- **Honest correctness.** QK-norm was *added* to match xserv's Qwen3 (not faked); every perf
change kept a hard correctness gate (recompute grads bit-identical; bf16 keeps the fp32
path untouched; the full grad-check / PyTorch / DDP / xserv suite must stay green).
- **The closed loop matters.** Exporting to xserv and checking token-identical greedy output
caught real bugs and proved the whole stack end-to-end.
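
The caching-allocator idea from the first lesson (reuse freed buffers instead of calling `cudaMalloc` per op) can be sketched as a host-side simulation. Everything here is an assumption for illustration: the `Block` and `CachingPool` types, the power-of-two bucketing, and the counter standing in for real driver calls are not taken from `xtrain-cuda`.

```rust
use std::collections::HashMap;

/// Stand-in for a device buffer; `id` plays the role of a device pointer.
#[derive(Debug)]
struct Block {
    id: usize,
    bytes: usize,
}

/// Size-bucketed caching pool: released blocks are parked and reused instead
/// of being handed back to the driver.
#[derive(Default)]
struct CachingPool {
    free: HashMap<usize, Vec<Block>>,
    raw_allocs: usize, // how many times the simulated cudaMalloc ran
}

impl CachingPool {
    /// Round up to a power of two so nearby sizes share a bucket (an assumption).
    fn bucket(bytes: usize) -> usize {
        bytes.max(1).next_power_of_two()
    }

    fn alloc(&mut self, bytes: usize) -> Block {
        let size = Self::bucket(bytes);
        if let Some(block) = self.free.get_mut(&size).and_then(|list| list.pop()) {
            return block; // cache hit: no driver call
        }
        // Cache miss: a real pool would call cudaMalloc here.
        self.raw_allocs += 1;
        Block { id: self.raw_allocs, bytes: size }
    }

    fn release(&mut self, block: Block) {
        self.free.entry(block.bytes).or_default().push(block);
    }
}
```

In a training step that allocates and frees the same activation shapes every iteration, a pool of this shape makes the driver-call count stop growing after the first iteration, which is the effect the T11 fix is after.
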
## Running it
Everything trains on a remote 8× RTX 5090 box; model artifacts live in a registry
(`tiny-models/v0…v8`). Serve any trained version in xserv:
```bash
# on the GPU box
cargo run -p xserv-model --release --bin xserv-cli -- <registry>/v8-fineweb-edu-dim1024 --max-tokens 100
# then type a prompt, e.g. In science,
```
Build/test the engine itself (CUDA compiles + runs on the GPU box; host-side `cargo check`
works anywhere via the `no_cuda` cfg):
```sh
export PATH=/usr/local/cuda/bin:$HOME/.cargo/bin:$PATH
cargo test --workspace # autograd grad-checks, PyTorch parity, DDP, etc.
```
## Doc index
- [`docs/evolution.md`](docs/evolution.md): per-milestone changes across algorithm / architecture / infra / dataset.
- [`docs/runs/README.md`](docs/runs/README.md): the v0–v8 comparison; [`docs/runs/0N-*.md`](docs/runs/): per-run detail.
- [`docs/00-*` … `14-*`](docs/): per-phase design docs (build chain, tensor, autograd, transformer, training, perf, distributed, export, batched, allocator, bf16, recompute, flash-attention, GQA).
- [`docs/known-issues.md`](docs/known-issues.md): perf backlog (KI-1/2/3/5 fixed; KI-4 + process-per-GPU open).