Gahow Wang 8bd7db16e1 docs: T16 grad-accum results — evolution row + README build-journey
dash5-verified gate numbers: accum=N bit-close to N× big batch (loss
8.5e-8 / grad 3.8e-5), accum=1 bit-identical (0.0), DDP+accum matches
single-GPU (5.7e-7), memory flat (same effective batch 64: 27.7GB big →
7.2GB accum, −74%), xserv closed loop md5-identical + token-identical.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 23:52:32 +08:00


xtrain

A from-scratch Rust + CUDA LLM training engine — the sibling of xserv (the inference side). A learning project: hand-write the entire training-systems stack (autograd → backward → optimizer → training loop → distributed → mixed precision → gradient checkpointing), then use it to run a multi-version scaling study that maps the data-vs-capacity frontier for a tiny model.

Status: complete. From-scratch full stack (phases T1–T13) + an 8-version scaling ladder (v0–v8). Trains a Qwen3-compatible LM whose weights load into xserv and generate token-identical output. This README is the capstone; per-topic detail lives in docs/.


What got built (from scratch, by hand)

7 crates, no ML framework — only cuBLAS / NCCL / safetensors as deliberate "heavy-lifting" borrows, the rest hand-written CUDA + Rust:

| crate | what's hand-written |
|---|---|
| xtrain-cuda | CUDA Runtime FFI, RAII GpuBuffer, caching/pool allocator, cuBLAS (sgemm + bf16 GemmEx) bindings |
| xtrain-tensor | tensor (dtype/shape/strides/storage), elementwise + transpose + embedding kernels |
| xtrain-autodiff | tape autograd engine (grad accumulation), per-op backward, finite-diff grad-check, checkpoint (recompute) primitive |
| xtrain-model | tiny Qwen3-style transformer (RoPE + RMSNorm + QK-norm + SwiGLU), batched forward |
| xtrain-optim | hand-written AdamW (host + GPU kernels) |
| xtrain-train | training loop, LR schedule, grad clip, checkpoint, BPE corpus + cache, samplers, safetensors export |
| xtrain-distributed | NCCL DDP (thread-per-GPU, all-reduce) |
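The xtrain-optim row is a hand-written AdamW. As a rough host-side illustration of what one such update does (this is not xtrain-optim's code; the `AdamW` struct, its field names, and the PyTorch-like defaults β1 = 0.9, β2 = 0.999, ε = 1e-8 are assumptions), a single step with bias correction and decoupled weight decay can be sketched as:

```rust
/// Hypothetical AdamW state; names and defaults are illustrative, not xtrain-optim's API.
pub struct AdamW {
    pub lr: f32,
    pub beta1: f32,
    pub beta2: f32,
    pub eps: f32,
    pub weight_decay: f32,
    t: i32,      // number of steps taken so far
    m: Vec<f32>, // first-moment (mean) estimate per parameter
    v: Vec<f32>, // second-moment (uncentred variance) estimate per parameter
}

impl AdamW {
    /// Assumed PyTorch-like defaults: beta1 = 0.9, beta2 = 0.999, eps = 1e-8.
    pub fn new(n_params: usize, lr: f32, weight_decay: f32) -> Self {
        AdamW {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay,
            t: 0,
            m: vec![0.0; n_params],
            v: vec![0.0; n_params],
        }
    }

    /// One in-place AdamW step over a flat parameter slice.
    pub fn step(&mut self, params: &mut [f32], grads: &[f32]) {
        assert_eq!(params.len(), grads.len());
        assert_eq!(params.len(), self.m.len());
        self.t += 1;
        // Bias corrections undo the zero initialisation of m and v in early steps.
        let bc1 = 1.0 - self.beta1.powi(self.t);
        let bc2 = 1.0 - self.beta2.powi(self.t);
        for i in 0..params.len() {
            let g = grads[i];
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
            let m_hat = self.m[i] / bc1;
            let v_hat = self.v[i] / bc2;
            // Decoupled weight decay: shrink the weight directly, outside the adaptive term.
            params[i] *= 1.0 - self.lr * self.weight_decay;
            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
    }
}
```

Because the decay term never passes through the adaptive denominator, this differs from Adam with an L2 penalty folded into the gradient; on the first step the Adam part moves each weight by roughly `lr` in the direction opposite its gradient's sign.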

Every op's backward is verified against finite differences and against PyTorch (forward + per-parameter grads, batch > 1). Trained weights export to HF-safetensors and load into xserv (Qwen3, BF16) producing token-identical greedy output — the closed loop.
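Finite-difference checking is what makes the backward verification above mechanical. A minimal sketch, assuming a scalar tape rather than xtrain-autodiff's tensor tape (`Tape`, `build`, `finite_diff` and `grad_check` are hypothetical names): the reverse sweep pushes each node's adjoint into its inputs with `+=`, so a value used twice accumulates both contributions, and central differences give an independent estimate to compare against.

```rust
/// Hypothetical scalar Wengert-list tape; not xtrain-autodiff's API.
#[derive(Clone, Copy)]
enum Op { Leaf, Add(usize, usize), Mul(usize, usize), Tanh(usize) }

pub struct Tape { vals: Vec<f64>, ops: Vec<Op> }

impl Tape {
    pub fn new() -> Self { Tape { vals: Vec::new(), ops: Vec::new() } }
    fn push(&mut self, v: f64, op: Op) -> usize { self.vals.push(v); self.ops.push(op); self.vals.len() - 1 }
    pub fn leaf(&mut self, v: f64) -> usize { self.push(v, Op::Leaf) }
    pub fn add(&mut self, a: usize, b: usize) -> usize { let v = self.vals[a] + self.vals[b]; self.push(v, Op::Add(a, b)) }
    pub fn mul(&mut self, a: usize, b: usize) -> usize { let v = self.vals[a] * self.vals[b]; self.push(v, Op::Mul(a, b)) }
    pub fn tanh(&mut self, a: usize) -> usize { let v = self.vals[a].tanh(); self.push(v, Op::Tanh(a)) }
    pub fn value(&self, i: usize) -> f64 { self.vals[i] }

    /// Reverse sweep from `out`; gradients accumulate with `+=` (fan-out safe).
    pub fn backward(&self, out: usize) -> Vec<f64> {
        let mut grad = vec![0.0; self.vals.len()];
        grad[out] = 1.0;
        for i in (0..=out).rev() {
            let g = grad[i];
            match self.ops[i] {
                Op::Leaf => {}
                Op::Add(a, b) => { grad[a] += g; grad[b] += g; }
                Op::Mul(a, b) => { grad[a] += g * self.vals[b]; grad[b] += g * self.vals[a]; }
                Op::Tanh(a) => { let t = self.vals[i]; grad[a] += g * (1.0 - t * t); }
            }
        }
        grad
    }
}

/// Example f(x, y) = tanh(x * y) + x * x on a fresh tape: (tape, [x id, y id], output id).
pub fn build(x: f64, y: f64) -> (Tape, [usize; 2], usize) {
    let mut t = Tape::new();
    let (xi, yi) = (t.leaf(x), t.leaf(y));
    let xy = t.mul(xi, yi);
    let th = t.tanh(xy);
    let xx = t.mul(xi, xi); // x feeds this op twice, so backward must add both terms
    let out = t.add(th, xx);
    (t, [xi, yi], out)
}

/// Central differences: df/dv_k ≈ (f(v + h·e_k) - f(v - h·e_k)) / (2h).
pub fn finite_diff(inputs: [f64; 2], h: f64) -> [f64; 2] {
    let f = |v: [f64; 2]| { let (t, _, out) = build(v[0], v[1]); t.value(out) };
    let mut g = [0.0; 2];
    for k in 0..2 {
        let (mut plus, mut minus) = (inputs, inputs);
        plus[k] += h;
        minus[k] -= h;
        g[k] = (f(plus) - f(minus)) / (2.0 * h);
    }
    g
}

/// Largest relative error between tape gradients and finite differences.
pub fn grad_check(x: f64, y: f64) -> f64 {
    let (t, ids, out) = build(x, y);
    let analytic = t.backward(out);
    let numeric = finite_diff([x, y], 1e-5);
    ids.iter()
        .zip(numeric.iter())
        .map(|(&i, &n)| (analytic[i] - n).abs() / analytic[i].abs().max(n.abs()).max(1e-12))
        .fold(0.0, f64::max)
}
```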

The build journey: phases T1–T13

Each phase: design doc + implementation + tests + a scoped commit (see docs/ and docs/evolution.md for the per-axis changelog).

| phase | what | result |
|---|---|---|
| T1–T2 | Rust↔CUDA build chain · tensor abstraction | vector-add verified · roundtrip |
| T3–T4 | hand GEMM fwd/bwd + finite-diff · tape autograd + 11 op backwards | grads vs cuBLAS 1e-7 / finite-diff |
| T5 | tiny transformer (RoPE+RMSNorm+SwiGLU) | overfit + PyTorch parity |
| T6 | AdamW + training loop + checkpoint · GPT-2 BPE + TinyStories | first coherent English |
| T7 | cuBLAS + GPU optimizer + drop syncs | ~3× (2.7K→8.5K tok/s) |
| T8 | NCCL DDP | multi-GPU (weak scaling at first; root cause fixed in T10) |
| T9 | per-head QK-norm (Qwen3-compat) + safetensors export | xserv closed loop, token-identical |
| T10 | batched multi-sequence forward (fixes KI-1) | single-GPU 15–24×; MFU 0.4%→14% |
| T11 | device caching allocator (fixes KI-5) | single-GPU 2.3×; 8-GPU 461K tok/s |
| T12 | bf16 mixed precision (fp32 master, fixes KI-2) | dim768 OOM solved; −29% mem |
| T13 | activation recompute / checkpointing (fixes KI-3) | dim1024 fits; grads bit-identical |
| T16 | gradient accumulation (--accum-steps; DDP all-reduces only at the boundary) | equivalent to N× big batch (grad 3.8e-5); same effective-64 batch 27.7GB→7.2GB (−74%) |
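The T12 row keeps an fp32 master copy of the weights. The reason shows up even on a single scalar: bf16 keeps only 8 significand bits, so just below 1.0 an update of 1e-3 is smaller than half an ulp and rounds away entirely. A hedged host-side sketch, assuming plain SGD instead of AdamW, round-to-nearest-even conversion, and no NaN handling (`f32_to_bf16`, `bf16_to_f32` and `train` are hypothetical helpers, not xtrain's kernels):

```rust
/// Round an f32 to bf16 (stored as its upper 16 bits) with round-to-nearest-even.
/// NaN handling is omitted for brevity; a real kernel must special-case it.
pub fn f32_to_bf16(x: f32) -> u16 {
    let bits = x.to_bits();
    let lsb = (bits >> 16) & 1; // lowest bit that survives truncation
    let rounded = bits.wrapping_add(0x7FFF + lsb); // exact ties round to the even result
    (rounded >> 16) as u16
}

/// Widen bf16 back to f32 (exact: the low 16 bits become zero).
pub fn bf16_to_f32(h: u16) -> f32 {
    f32::from_bits((h as u32) << 16)
}

/// Apply `steps` identical SGD updates of size `lr * grad` to (a) a weight stored
/// only in bf16 and (b) an fp32 master weight. Returns (bf16-only, master).
pub fn train(w0: f32, grad: f32, lr: f32, steps: usize) -> (f32, f32) {
    let mut w_bf16 = f32_to_bf16(w0);
    let mut master = w0;
    for _ in 0..steps {
        // bf16-only: each update is re-rounded, so sub-half-ulp steps vanish.
        w_bf16 = f32_to_bf16(bf16_to_f32(w_bf16) - lr * grad);
        // fp32 master: small updates accumulate; compute would read f32_to_bf16(master).
        master -= lr * grad;
    }
    (bf16_to_f32(w_bf16), master)
}
```

With `train(1.0, 1.0, 1e-3, 100)` the bf16-only weight stays at exactly 1.0, while the master weight reaches roughly 0.9; keeping the master copy in fp32 is what lets low-precision compute coexist with tiny optimizer steps.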

The four performance fixes (T10–T13) each removed a real bottleneck; see docs/known-issues.md. Phase 2 (systems-stack depth) goes back to hand-write training-stack features that were deferred earlier. T16 is micro-batch gradient accumulation (docs/15-grad-accum.md): it decouples the effective batch from activation memory, so memory tracks the micro-batch size rather than N× it.
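The equivalence claimed for T16 follows from linearity of the mean: with N equal-sized micro-batches, the average of the micro-batch mean-gradients equals the full-batch mean-gradient. A small sketch on a 1-D least-squares model makes this concrete (the model and the names `batch_grad` / `accumulated_grad` are illustrative assumptions, not xtrain's code); only the running sum survives between micro-steps, which is why memory tracks the micro-batch.

```rust
/// Mean-squared-error gradient dL/dw for a 1-D linear model y ≈ w·x over one batch.
pub fn batch_grad(w: f64, xs: &[f64], ys: &[f64]) -> f64 {
    let n = xs.len() as f64;
    xs.iter().zip(ys).map(|(&x, &y)| 2.0 * (w * x - y) * x).sum::<f64>() / n
}

/// Gradient accumulation: split the batch into `accum_steps` equal micro-batches,
/// run "backward" on each, and add each micro-batch grad scaled by 1/accum_steps.
/// Under DDP, the all-reduce would happen once, after this loop (the boundary).
pub fn accumulated_grad(w: f64, xs: &[f64], ys: &[f64], accum_steps: usize) -> f64 {
    assert!(accum_steps > 0 && xs.len() % accum_steps == 0, "batch must split evenly");
    let micro = xs.len() / accum_steps;
    let mut acc = 0.0; // the only state carried across micro-steps
    for k in 0..accum_steps {
        let r = k * micro..(k + 1) * micro;
        acc += batch_grad(w, &xs[r.clone()], &ys[r]) / accum_steps as f64;
    }
    acc
}
```

With `accum_steps = 1` the loop reduces to a single `batch_grad` call, so the result is bit-identical to the plain path; for larger N the floating-point summation order changes, which is why the equivalence gate compares within a tolerance instead of requiring bit-identity.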

The scaling study — v0 → v8

Same Qwen3-style architecture throughout; we scaled dim and data and read out val loss (full per-run detail in docs/runs/).

| ver | data (epochs trained) | dim / core params | val loss | axis explored |
|---|---|---|---|---|
| v0–v3 | TinyStories (↑) | 32→512 / 41K→67M | 3.80 → 1.30 | bring-up |
| v4 | TinyStories, 1.54 ep | 768 / 127M | 1.17 | |
| v5 | TinyStories, 5.33 ep | 768 / 127M | 1.11 | data volume → saturates |
| v6 | FineWeb-edu, 1.02 ep | 768 / 127M | 3.07* | corpus swap → graduates to real text |
| v7 | FineWeb-edu, 1.45 ep | 768 / 127M | 3.01* | same subset, more epochs → near-ceiling |
| v8 | FineWeb-edu, 1.05 ep | 1024 / 226M | 2.98* | capacity → helps |

* FineWeb-edu val loss comes from a different (harder) distribution and is not comparable to the TinyStories val of v0–v5. Judge v6+ by sample quality and transfer, not by the number.

Three findings

  1. Data volume saturates. TinyStories at dim768: 3.5× more tokens (v4→v5) bought only a 5% lower val loss (1.17 → 1.11), and the curve went flat. The narrow synthetic corpus is exhausted at this model size.
  2. Corpus > more-of-the-same. Swapping TinyStories → FineWeb-edu (v5→v6) was a qualitative jump: the model went from only-writes-kid-stories to writing genuine historical/scientific expository prose. (Cost: TinyStories transfer val 1.11 → 2.75.)
  3. Capacity helps. v8 (dim1024, ~1 epoch) beats both v6 (dim768, also ~1 epoch) by 0.085 and v7 (dim768, more data) by 0.035, so the dim768 runs were partly capacity-limited.

Meta-finding: every single-axis lever (data volume, corpus breadth, capacity) is now worth only ~3%. Per the Chinchilla lesson, further gains require scaling data and capacity together — single-axis moves are exhausted.
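To put numbers on the Chinchilla lesson: the commonly quoted rule of thumb from Hoffmann et al. (2022) is roughly 20 training tokens per parameter at the compute optimum, with training compute C ≈ 6·N·D. The sketch below (the constant and function names are illustrative assumptions, and the ratio is only approximate) computes that budget for the study's own model sizes and shows why extra compute should be split across both axes.

```rust
/// Approximate Chinchilla-style compute-optimal ratio (rule of thumb, not exact).
pub const TOKENS_PER_PARAM: f64 = 20.0;

/// Compute-optimal token budget D ≈ 20·N for a model with `params` parameters.
pub fn optimal_tokens(params: f64) -> f64 {
    TOKENS_PER_PARAM * params
}

/// Training compute in FLOPs via the standard C ≈ 6·N·D estimate.
pub fn train_flops(params: f64, tokens: f64) -> f64 {
    6.0 * params * tokens
}

/// With C ≈ 6·N·D and D = 20·N, multiplying compute by `factor` grows both
/// N and D by sqrt(factor). Returns (new params, new token budget).
pub fn scale_together(params: f64, factor: f64) -> (f64, f64) {
    let n = params * factor.sqrt();
    (n, optimal_tokens(n))
}
```

Under this rule, `optimal_tokens(226e6)` is about 4.5B tokens for the dim1024 model, and 4× more compute means 2× the parameters and 2× the tokens, not 4× of either one.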

Efficiency — throughput & MFU

The throughput story is the perf-infra report card (RTX 5090, bf16/fp32):

| | v1 | v2 | v3 | v4 | v5 |
|---|---|---|---|---|---|
| tok/s | 3.3K (1 GPU) | 3.6K (4 GPU) | 26K (1 GPU) | 145K (8 GPU) | 217K (8 GPU) |
| MFU | 0.4% | 0.2% | 14% | 17% | 13% |
| enabled by | | DDP (weak) | batched (T10) | alloc (T11) | bf16 (T12) |

v1/v2 ran at <0.5% MFU — the single-sequence design left the GPU idle (launch-bound). Batched forward (T10) was the single biggest unlock (~35× MFU jump). 6ND is an accurate FLOPs count, but predicting time needs the realized MFU, which varied ~40× across versions — a fixed-MFU estimate is off by up to ~100× for the early launch-bound runs.
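The 6ND bookkeeping above reduces to two small formulas: MFU from measured throughput, and wall-clock time from an assumed MFU. A sketch with hypothetical function names; the per-GPU peak FLOP/s is a caller-supplied input, and no RTX 5090 figure is assumed here.

```rust
/// Model FLOPs utilisation: achieved training FLOP/s (≈ 6·N FLOPs per token,
/// forward + backward) divided by the aggregate hardware peak.
pub fn mfu(tokens_per_sec: f64, params: f64, n_gpus: f64, peak_flops_per_gpu: f64) -> f64 {
    6.0 * params * tokens_per_sec / (n_gpus * peak_flops_per_gpu)
}

/// Wall-clock seconds to train `params` parameters on `tokens` tokens at MFU `mfu_frac`.
pub fn train_seconds(params: f64, tokens: f64, n_gpus: f64, peak_flops_per_gpu: f64, mfu_frac: f64) -> f64 {
    6.0 * params * tokens / (n_gpus * peak_flops_per_gpu * mfu_frac)
}

/// Factor by which a time estimate is too optimistic when it assumes `assumed_mfu`
/// but the run actually achieves `actual_mfu` (time scales as 1 / MFU).
pub fn time_error_factor(assumed_mfu: f64, actual_mfu: f64) -> f64 {
    assumed_mfu / actual_mfu
}
```

For instance, `time_error_factor(0.17, 0.002)` is 85: planning with v4's 17% MFU for a run that actually achieves v2's 0.2% underestimates its wall-clock time by about 85×.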

Engineering lessons

  • Profile before optimizing. Two "known" perf fixes were falsified by measurement before being shipped: "bigger batch fixes DDP scaling" (real cause: single-seq launch-bound → T10) and "bucket the all-reduce" (real cause: per-op cudaMalloc serialization → T11 caching allocator). Both would have been no-ops; both got reverted + re-diagnosed instead of shipped.
  • Honest correctness. QK-norm was added to match xserv's Qwen3 (not faked); every perf change kept a hard correctness gate (recompute grads bit-identical; bf16 keeps the fp32 path untouched; the full grad-check / PyTorch / DDP / xserv suite must stay green).
  • The closed loop matters. Exporting to xserv and checking token-identical greedy output caught real bugs and proved the whole stack end-to-end.
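The T11 diagnosis in the first lesson (per-op cudaMalloc serialization) is the problem a caching allocator addresses: freed blocks go into size-bucketed free lists and are handed back on the next matching request, so steady-state steps stop reaching the driver. A host-only model of the idea, assuming exact-size buckets rounded to a 512-byte granule (`CachingAllocator` and the granule are illustrative, not xtrain-cuda's design); a counter stands in for real cudaMalloc calls and no CUDA API is used.

```rust
use std::collections::HashMap;

/// Hypothetical host-side model of a caching allocator.
pub struct CachingAllocator {
    cache: HashMap<usize, Vec<usize>>, // rounded size -> ids of cached (freed) blocks
    next_id: usize,
    pub backend_allocs: usize, // stands in for the number of cudaMalloc calls
}

impl CachingAllocator {
    pub fn new() -> Self {
        CachingAllocator { cache: HashMap::new(), next_id: 0, backend_allocs: 0 }
    }

    /// Round up to an assumed 512-byte granule so near-equal requests share a bucket.
    fn round(bytes: usize) -> usize {
        (bytes.max(1) + 511) / 512 * 512
    }

    /// Returns (block id, rounded size), reusing a cached block of the same size if any.
    pub fn alloc(&mut self, bytes: usize) -> (usize, usize) {
        let size = Self::round(bytes);
        if let Some(id) = self.cache.get_mut(&size).and_then(|v| v.pop()) {
            return (id, size); // cache hit: no backend allocation
        }
        self.backend_allocs += 1; // a real allocator would call cudaMalloc here
        self.next_id += 1;
        (self.next_id, size)
    }

    /// Return a block to the cache instead of freeing it to the driver.
    pub fn free(&mut self, block: (usize, usize)) {
        self.cache.entry(block.1).or_default().push(block.0);
    }
}
```

A loop that allocates the same tensor sizes every step settles at one backend allocation per distinct size; a production allocator would typically also need block splitting and a way to release the cache under memory pressure.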

Running it

Everything trains on a remote 8× RTX 5090 box; model artifacts live in a registry (tiny-models/v0…v8). Serve any trained version in xserv:

```bash
# on the GPU box
cargo run -p xserv-model --release --bin xserv-cli -- <registry>/v8-fineweb-edu-dim1024 --max-tokens 100
# then type a prompt, e.g. "In science,"
```

Build/test the engine itself (CUDA compiles + runs on the GPU box; host-side cargo check works anywhere via the no_cuda cfg):

```bash
export PATH=/usr/local/cuda/bin:$HOME/.cargo/bin:$PATH
cargo test --workspace            # autograd grad-checks, PyTorch parity, DDP, etc.
```

Doc index

  • docs/evolution.md: per-milestone changes across algorithm / architecture / infra / dataset.
  • docs/runs/README.md: the v0–v8 comparison; docs/runs/0N-*.md: per-run detail.
  • docs/00-* through docs/12-*: per-phase design docs (build chain → tensor → autograd → transformer → training → perf → distributed → export → batched → allocator → bf16 → recompute).
  • docs/known-issues.md: perf backlog (KI-1/2/3/5 fixed; KI-4 + process-per-GPU open).