Gahow Wang 0e82b2438e test: M2d — ragged-forward + batched-op equivalence gates + throughput bench
Two exact correctness gates (composed = the end-to-end batched GRPO step == looped):
- xtrain-model forward_batched_ragged_matches_looped: forward_batched on RIGHT-padded
  ragged sequences == per-sequence single-seq forward on the real rows. fp32
  max|Δlogit| = 3.7e-7, bf16 = 0.0, both composed + flash SDPA. Pins "right-pad is
  free under causal".
- xtrain-autodiff clipped_pg_loss_batched_matches_looped: batched op == looped
  Σ_s (1/N)·clipped_pg_loss_s. loss Δ=1.5e-8, grad max|Δ|=7.5e-9 (f32).
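The "right-pad is free under causal" claim rests on one property. With a causal mask, row t scores only keys 0..=t, so pad rows appended after the last real token never enter any real row's softmax. Below is a minimal CPU sketch of that property. It assumes a single-head f64 attention, and `causal_attention` and `right_pad` are hypothetical helpers, not xtrain's `forward_batched`:

```rust
/// Single-head causal attention over row-major [seq][dim] inputs.
/// Hypothetical illustration only; q, k, v must have the same length.
pub fn causal_attention(q: &[Vec<f64>], k: &[Vec<f64>], v: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let scale = 1.0 / (q[0].len() as f64).sqrt();
    let mut out = Vec::with_capacity(q.len());
    for t in 0..q.len() {
        // Causal mask: row t only scores keys 0..=t, never anything to its right.
        let scores: Vec<f64> = (0..=t)
            .map(|j| scale * q[t].iter().zip(&k[j]).map(|(a, b)| a * b).sum::<f64>())
            .collect();
        // Numerically stable softmax over the visible keys.
        let m = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let w: Vec<f64> = scores.iter().map(|s| (s - m).exp()).collect();
        let z: f64 = w.iter().sum();
        let mut row = vec![0.0; v[0].len()];
        for (j, wj) in w.iter().enumerate() {
            for (r, vj) in row.iter_mut().zip(&v[j]) {
                *r += wj / z * vj;
            }
        }
        out.push(row);
    }
    out
}

/// Append `n_pad` right-pad rows filled with an arbitrary value.
pub fn right_pad(x: &[Vec<f64>], n_pad: usize, pad: f64) -> Vec<Vec<f64>> {
    let mut y = x.to_vec();
    y.extend(std::iter::repeat(vec![pad; x[0].len()]).take(n_pad));
    y
}
```

In this sketch the real rows execute identical arithmetic in the same order whether or not pads follow, so equality is exact. A batched GPU kernel may reorder reductions, which is one way small fp32 differences like the 3.7e-7 above can arise.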

bench_grpo_batch: weight-independent micro-bench of the per-sample training forwards
(loads v12 base as policy, N realistic ragged samples, teacher-forced argmax targets
so the closeness smoke isn't −log-amplified by random low-prob tokens). Measured on
dash5 (v12 1.05B, N=48, micro=16): capture 622→71 ms (8.7×), inner 1907→208 ms
(9.2×), training forwards 2526→280 ms (9.0×).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 23:03:09 +08:00
2026-06-15 17:14:56 +08:00

xtrain

A from-scratch Rust + CUDA LLM training engine — the sibling of xserv (the inference side). A learning project: hand-write the entire training-systems stack (autograd → backward → optimizer → training loop → distributed → mixed precision → gradient checkpointing), then use it to run a multi-version scaling study that maps the data-vs-capacity frontier for a tiny model.

Status: complete — three phases. Phase 1 = the from-scratch full stack (T1–T13) + an 8-version scaling study (v0–v8): hand-write the whole training-systems stack, then map the data-vs-capacity frontier. Phase 2 = systems-stack depth (T14–T18): hand-write the five deferred training-stack features — fused flash-attention, real GQA, gradient accumulation, process-per-GPU DDP, dropout. Phase 3 = one Chinchilla-style double-axis run (v9): dim1280 true-GQA + 6.01B FineWeb tokens, validating the v8 conclusion that data and capacity must scale together. Trains Qwen3-compatible LMs whose weights load into xserv; deterministic gates stay byte-identical, while large BF16 checkpoints are served and checked for prompt-level drift. This README is the capstone; per-topic detail lives in docs/.


What got built (from scratch, by hand)

7 crates, no ML framework — only cuBLAS / NCCL / safetensors as deliberate "heavy-lifting" borrows, the rest hand-written CUDA + Rust:

crate what's hand-written
xtrain-cuda CUDA Runtime FFI, RAII GpuBuffer, caching/pool allocator, cuBLAS (sgemm + bf16 GemmEx) bindings
xtrain-tensor tensor (dtype/shape/strides/storage), elementwise + transpose + embedding kernels
xtrain-autodiff tape autograd engine (grad accumulation), per-op backward, finite-diff grad-check, checkpoint (recompute) primitive, fused flash-attention (online-softmax) fwd/bwd, repeat_kv broadcast (GQA), dropout (counter-based device RNG + mask)
xtrain-model tiny Qwen3-style transformer (RoPE + RMSNorm + QK-norm + SwiGLU), batched forward, GQA (num_kv_heads<num_heads), residual/MLP dropout
xtrain-optim hand-written AdamW (host + GPU kernels)
xtrain-train training loop, LR schedule, grad clip, gradient accumulation, checkpoint, BPE corpus + cache, samplers, safetensors export
xtrain-distributed NCCL DDP (thread-per-GPU + torchrun-style process-per-GPU launcher / cross-process ncclUniqueId, all-reduce)
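The table describes xtrain-autodiff only as a "tape autograd engine (grad accumulation)". The sketch below is a hedged illustration of that technique, not the crate's API. It is a scalar reverse-mode tape: each op records its inputs and local partials, and `backward` walks the tape in reverse, adding (`+=`) into each parent. A value that is used twice therefore collects both contributions. `Tape`, `leaf`, `add`, `mul` and `tanh` are hypothetical names:

```rust
/// One recorded op: its inputs and the local partials d(out)/d(input).
struct Node {
    parents: Vec<(usize, f64)>, // (tape index of input, local derivative)
}

/// Minimal scalar tape (hypothetical; the real engine works on GPU tensors).
#[derive(Default)]
pub struct Tape {
    nodes: Vec<Node>,
    vals: Vec<f64>,
}

impl Tape {
    fn push(&mut self, v: f64, parents: Vec<(usize, f64)>) -> usize {
        self.nodes.push(Node { parents });
        self.vals.push(v);
        self.vals.len() - 1
    }
    /// An input with no parents.
    pub fn leaf(&mut self, v: f64) -> usize {
        self.push(v, vec![])
    }
    pub fn value(&self, i: usize) -> f64 {
        self.vals[i]
    }
    pub fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.vals[a] + self.vals[b];
        self.push(v, vec![(a, 1.0), (b, 1.0)])
    }
    pub fn mul(&mut self, a: usize, b: usize) -> usize {
        let (x, y) = (self.vals[a], self.vals[b]);
        self.push(x * y, vec![(a, y), (b, x)])
    }
    pub fn tanh(&mut self, a: usize) -> usize {
        let t = self.vals[a].tanh();
        self.push(t, vec![(a, 1.0 - t * t)])
    }
    /// Reverse sweep: seed d(out)/d(out) = 1, then walk the tape backwards,
    /// accumulating into every parent so reused values sum their grads.
    pub fn backward(&self, out: usize) -> Vec<f64> {
        let mut grad = vec![0.0; self.vals.len()];
        grad[out] = 1.0;
        for i in (0..=out).rev() {
            let g = grad[i];
            for &(p, local) in &self.nodes[i].parents {
                grad[p] += g * local;
            }
        }
        grad
    }
}
```

Take f = x·y + x at (3, 4). The reverse pass gives ∂f/∂x = y + 1 = 5, which is the two uses of x summed, and ∂f/∂y = x = 3. Ops are appended in evaluation order, so the tape is already topologically sorted and one reverse pass suffices.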

Every op's backward is verified against finite differences and against PyTorch (forward + per-parameter grads, batch > 1). Trained weights export to HF-safetensors and load into xserv (Qwen3, BF16); deterministic fixtures produce token-identical greedy output, and large checkpoints are validated end-to-end in the serving path.

The build journey — Phase 1 (T1–T13) + Phase 2 (T14–T18)

Each phase: design doc + implementation + tests + a scoped commit (see docs/ and docs/evolution.md for the per-axis changelog). Phase 1 (T1–T13) hand-built the stack and fixed the four real bottlenecks; Phase 2 (T14–T18) went back to hand-write five deferred training-stack features — see the Phase-2 summary below the table.

phase what result
T1–T2 Rust↔CUDA build chain · tensor abstraction vector-add verified · roundtrip
T3–T4 hand GEMM fwd/bwd + finite-diff · tape autograd + 11 op backwards grads vs cuBLAS 1e-7 / finite-diff
T5 tiny transformer (RoPE+RMSNorm+SwiGLU) overfit + PyTorch parity
T6 AdamW + training loop + checkpoint · GPT-2 BPE + TinyStories first coherent English
T7 cuBLAS + GPU optimizer + drop syncs ~3× (2.7K→8.5K tok/s)
T8 NCCL DDP multi-GPU (weak scaling, then)
T9 + per-head QK-norm (Qwen3-compat) + safetensors export xserv closed loop, token-identical
T10 batched multi-sequence forward (fixes KI-1) single-GPU 15–24×; MFU 0.4%→14%
T11 device caching allocator (fixes KI-5) single-GPU 2.3×; 8-GPU 461K tok/s
T12 bf16 mixed precision (fp32 master, fixes KI-2) dim768 OOM solved; −29% mem
T13 activation recompute / checkpointing (fixes KI-3) dim1024 fits; grads bit-identical
T14 fused flash-attention kernel (online softmax, no materialized N×N; opt-in --flash) peak mem −16%@1k / −23%@2k seq; flash==composed (grads/PyTorch)
T15 grouped-query attention (num_kv_heads<num_heads; repeat_kv broadcast feeds both SDPA paths; backward sums each kv head's group; --kv-heads) repeat_kv grad-check + group=1 bit-identical to MHA; GQA flash==composed; PyTorch GQA B>1; xserv closed loop with real num_key_value_heads token-identical
T16 gradient accumulation (--accum-steps; DDP all-reduces only at the boundary) equiv to N× big batch (grad 3.8e-5); same effective-64 batch 27.7GB→7.2GB (−74%)
T17 process-per-GPU DDP (torchrun-style: 1 worker process / CUDA context per GPU; launcher mints ncclUniqueId → hex env injection; train_rank reused unchanged; thread-per-GPU path kept) proc==thread loss 1.5e-7, cross-rank 1.2e-7, xserv md5 identical · measured no-op on throughput: thread 5.27× vs proc 5.31×@8 (8 GPUs 95–99% util) → residual non-linearity is NCCL/PCIe, not CUDA-context serialization (falsifies the old KI-5 hypothesis)
T18 dropout (hand counter-based device RNG + mask, inverted scaling, train/eval switch) fixed-seed grad-check; p=0 bit-identical; recompute-safe

The four performance fixes (T10–T13) each removed a real bottleneck — see docs/known-issues.md — which is where Phase 1 closed.
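The online (streaming) softmax in the T14 row can be shown on a single query row:

- visit keys/values one tile at a time;
- keep a running max m, a running denominator l and a running weighted sum acc;
- whenever the max grows, rescale l and acc by exp(m_old − m_new).

The full score row is never stored. This is a scalar CPU sketch of the forward pass only. `online_attention_row`, `naive_attention_row` and `dot` are hypothetical names, not the CUDA kernel:

```rust
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Streaming softmax attention for one query row; only (m, l, acc) are kept.
pub fn online_attention_row(q: &[f64], k: &[Vec<f64>], v: &[Vec<f64>], tile: usize) -> Vec<f64> {
    let scale = 1.0 / (q.len() as f64).sqrt();
    let mut m = f64::NEG_INFINITY; // running max of scores seen so far
    let mut l = 0.0; // running sum of exp(score - m)
    let mut acc = vec![0.0; v[0].len()]; // running sum of exp(score - m) * v
    for start in (0..k.len()).step_by(tile) {
        let end = (start + tile).min(k.len());
        let s: Vec<f64> = (start..end).map(|j| scale * dot(q, &k[j])).collect();
        let m_new = s.iter().cloned().fold(m, f64::max);
        // Rescale what was accumulated under the old max (0 on the first tile).
        let corr = (m - m_new).exp();
        l *= corr;
        for a in acc.iter_mut() {
            *a *= corr;
        }
        for (jj, sj) in s.iter().enumerate() {
            let p = (sj - m_new).exp();
            l += p;
            for (a, vj) in acc.iter_mut().zip(&v[start + jj]) {
                *a += p * vj;
            }
        }
        m = m_new;
    }
    acc.iter().map(|a| a / l).collect()
}

/// Reference: materialize the whole score row, then softmax.
pub fn naive_attention_row(q: &[f64], k: &[Vec<f64>], v: &[Vec<f64>]) -> Vec<f64> {
    let scale = 1.0 / (q.len() as f64).sqrt();
    let s: Vec<f64> = k.iter().map(|kj| scale * dot(q, kj)).collect();
    let m = s.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let w: Vec<f64> = s.iter().map(|x| (x - m).exp()).collect();
    let z: f64 = w.iter().sum();
    let mut out = vec![0.0; v[0].len()];
    for (wj, vj) in w.iter().zip(v) {
        for (o, x) in out.iter_mut().zip(vj) {
            *o += wj / z * x;
        }
    }
    out
}
```

Any tile size gives the same result up to rounding, so a kernel is free to choose tiles that fit on-chip memory. The per-row state is O(head_dim), independent of sequence length, which is where the memory saving comes from.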

Phase 2 — systems-stack depth (T14–T18)

Phase 1 fixed bottlenecks; Phase 2 went back to hand-write the five training-stack features that had been explicitly deferred earlier (project's actual goal = learn the whole stack). Each is opt-in, kept the default path bit-identical, and held a hard correctness gate:

  • T14 · fused flash-attention (docs/13-flash-attention.md) — a single hand-written kernel: online (streaming) softmax, tiled over KV, never materializes the N×N scores; flash-style backward recomputes scores + the D=ΣdO·O Jacobian simplification for dQ/dK/dV. Opt-in --flash, default off. The win is memory, not wall-clock: peak activation −16%@seq1024 / −23%@seq2048 (the saving grows with seq, since the N×N never lands), but ~2.3× slower at head-dim 64 (a hand kernel can't beat cuBLAS tensor-cores on a small head). Gate: flash == composed (loss rel 0.0, grad 4.4e-5), PyTorch B>1 7.9e-6.
  • T15 · real GQA (docs/14-gqa.md) — num_kv_heads < num_heads via a new repeat_kv broadcast op that copies K/V group = nh/num_kv times to feed both (composed + flash) SDPA paths unchanged; its backward is a deterministic group-sum (no atomics) collapsing each kv head's query-head group. Gate: repeat_kv grad-check + group=1 bit-identical to MHA (regression guard); xserv closed loop with real num_key_value_heads token-identical.
  • T16 · gradient accumulation (docs/15-grad-accum.md) — N micro-steps scaled by 1/N accumulate on the tape, then one AdamW step; DDP all-reduces only at the accumulation boundary. Decouples effective batch from activation memory: same effective batch 64, big-batch 27.7GB (OOM) → accum 4×16 7.2GB (−74%). Gate: accum=N ≡ one N× batch (grad 3.8e-5); accum=1 bit-identical.
  • T18 · dropout (docs/17-dropout.md) — a stateless counter-based device RNG (Philox-style bit-mix) → Bernoulli mask, inverted 1/(1−p) scaling in train, identity in eval; wired at the two residual sites (attn-out, mlp-out). Stateless RNG is what makes it compose bit-exactly with T13 activation recompute — the backward re-run regenerates the same mask from (seed, index). Gate: fixed-seed grad-check; p=0 bit-identical.
  • T17 · process-per-GPU (docs/16-process-per-gpu.md) — a torchrun-style launcher: one worker process + CUDA context per GPU, the launcher mints one ncclUniqueId and hex-injects it into each child's env (no shared FS/TCP, no race); the worker reuses the T8 train_rank unchanged. Built and correct (proc vs thread loss 1.5e-7, cross-rank 1.2e-7, xserv md5 identical) — but measured throughput-neutral: 8-GPU thread 491K (5.27×) vs proc 493K (5.31×), <1%. This falsifies the long-standing KI-5/T11 hypothesis that thread-per-GPU's shared CUDA context caused the residual ~5×@8; with all 8 GPUs at 95–99% util, the residual is the NCCL all-reduce + PCIe topology wall, not context serialization. The third profile-first falsification (see below).
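A stateless counter-based mask is what makes dropout recompute-safe. The keep/drop bit for element i is a pure function of (seed, i), so a recomputed backward regenerates exactly the same mask instead of storing it. Below is a minimal sketch. It assumes SplitMix64 as the bit source, computing the index-th SplitMix64 output for a seed directly from the counter. That is a swapped-in stand-in for the Philox-style mix described above. `counter_bits`, `keep` and `dropout` are hypothetical names:

```rust
const GOLDEN: u64 = 0x9E37_79B9_7F4A_7C15;

/// The (index+1)-th SplitMix64 output for `seed`, computed straight from the
/// counter: no state is carried between calls (stand-in for Philox).
pub fn counter_bits(seed: u64, index: u64) -> u64 {
    let mut z = seed.wrapping_add(index.wrapping_add(1).wrapping_mul(GOLDEN));
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

/// Keep element `index` with probability 1 - p (top 53 bits -> uniform [0, 1)).
pub fn keep(seed: u64, index: u64, p: f64) -> bool {
    let u = (counter_bits(seed, index) >> 11) as f64 / (1u64 << 53) as f64;
    u >= p
}

/// Inverted dropout: kept values are scaled by 1/(1-p) in train; eval is identity.
pub fn dropout(x: &[f32], p: f64, seed: u64, train: bool) -> Vec<f32> {
    if !train {
        return x.to_vec();
    }
    let scale = (1.0 / (1.0 - p)) as f32;
    x.iter()
        .enumerate()
        .map(|(i, &xi)| if keep(seed, i as u64, p) { xi * scale } else { 0.0 })
        .collect()
}
```

With p = 0 every element is kept and the scale is exactly 1.0, so the output is bit-identical to the input, mirroring the p=0 gate. Calling `dropout` again with the same seed reproduces the same mask, which is the property the recompute path relies on.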

The scaling study — v0 → v10

Same Qwen3-style architecture throughout; we scaled dim and data and read out val loss (full per-run detail in docs/runs/).

ver data (trained tok / epoch) dim / core params val loss axis explored
v0–v3 TinyStories (↑) 32→512 / 41K→67M 3.80 → 1.30 bring-up
v4 TinyStories 1.54ep 768 / 127M 1.17
v5 TinyStories 5.33ep 768 / 127M 1.11 data volume → saturates
v6 FineWeb-edu 1.02ep 768 / 127M 3.07* corpus swap → graduates to real text
v7 FineWeb-edu 1.45ep 768 / 127M 3.01* same subset, more epochs → near-ceiling
v8 FineWeb-edu 1.05ep 1024 / 226M 2.98* capacity → helps
v9 FineWeb-edu 6.01B / ~1ep 1280 / 357M + GQA 2.89* data + capacity → helps
v10 FineWeb-edu 6.76B / ~1ep 1280 / 357M + GQA 2.88* data-only top-up → small gain

* FineWeb-edu val is a different (harder) distribution — not comparable to the TinyStories val of v0–v5. Judge v6+ by sample quality + transfer, not the number.

Five findings

  1. Data volume saturates. TinyStories at dim768: 3.5× more tokens (v4→v5) bought only a 5% val-loss drop (1.17 → 1.11), and the curve is flat. The narrow synthetic corpus is exhausted at this model size.
  2. Corpus > more-of-the-same. Swapping TinyStories → FineWeb-edu (v5→v6) was a qualitative jump: the model went from only-writes-kid-stories to writing genuine historical/scientific expository prose. (Cost: TinyStories transfer val 1.11 → 2.75.)
  3. Capacity helps. v8 (dim1024, ~1 epoch) beats both v6 (dim768, same epoch, by 0.085) and v7 (dim768, more data, by 0.035) → the dim768 runs were partly capacity-limited.
  4. Double-axis scale helps. v9 scales both axes (dim1280/core357M + 6.01B FineWeb tokens) and beats v8 by another 0.095 val loss (~3.2%). The direction is validated, but the gain is still incremental and greedy decoding still repeats.
  5. Stop reporting a moving validation tail. v10 added one more FineWeb shard and got moving-tail val 2.8816, but appending data moves the held-out tail, so val numbers from different runs no longer share an eval set. A fixed eval v1 was created from the shard010 tail: v6/v7/v8/v9/v10 = 3.2328 / 3.1850 / 3.1515 / 2.9278 / 2.8814. Future runs should report this fixed eval first.

Meta-finding: every lever is now in the ~3% or smaller regime. Single-axis moves were exhausted by v8; v9 confirms Chinchilla-style double-axis scale works; v10 shows that a data-only top-up mostly adapts to the new shard. The next useful run should change model/context, not just append another shard.

Efficiency — throughput & MFU

The throughput story is the perf-infra report card (RTX 5090, bf16/fp32):

run v1 v2 v3 v4 v5
tok/s 3.3K (1 GPU) 3.6K (4 GPU) 26K (1 GPU) 145K (8 GPU) 217K (8 GPU)
MFU 0.4% 0.2% 14% 17% 13%
enabled by (baseline) DDP (weak) batched (T10) alloc (T11) bf16 (T12)

v1/v2 ran at <0.5% MFU — the single-sequence design left the GPU idle (launch-bound). Batched forward (T10) was the single biggest unlock (~35× MFU jump). 6ND is an accurate FLOPs count, but predicting time needs the realized MFU, which varied ~40× across versions — a fixed-MFU estimate is off by up to ~100× for the early launch-bound runs.
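The 6ND rule and the MFU argument fit in a few lines:

- training costs ≈ 6·N·D FLOPs (≈2 forward + ≈4 backward per parameter per token);
- MFU is achieved FLOP/s divided by peak FLOP/s;
- predicted time is 6ND / (MFU · peak).

The sketch below leaves the hardware peak as an input, so no RTX 5090 figure is assumed. The function names are illustrative:

```rust
/// Training FLOPs by the 6ND rule (N = parameters, D = training tokens).
pub fn train_flops(params: f64, tokens: f64) -> f64 {
    6.0 * params * tokens
}

/// Model FLOPs utilization from measured throughput.
/// `peak_flops` is the hardware peak for the dtype in use (an input, not a quoted spec).
pub fn mfu(params: f64, tokens_per_sec: f64, peak_flops: f64) -> f64 {
    6.0 * params * tokens_per_sec / peak_flops
}

/// Wall-clock seconds predicted for a run at a given realized MFU.
pub fn predicted_seconds(params: f64, tokens: f64, realized_mfu: f64, peak_flops: f64) -> f64 {
    train_flops(params, tokens) / (realized_mfu * peak_flops)
}
```

With N, D and peak held fixed, predicted time scales as 1/MFU. The same run priced at 0.4% versus 14% MFU therefore comes out 35× apart, which is why one assumed MFU cannot price both the launch-bound and the batched runs.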

Engineering lessons

  • Profile before optimizing. Three "known" fixes were falsified by measurement: (1) "bigger batch fixes DDP scaling" (real cause: single-seq launch-bound → T10); (2) "bucket the all-reduce" (real cause: per-op cudaMalloc serialization → T11 caching allocator); and (3) "process-per-GPU would fix the residual ~5×@8" (T17 — built the torchrun-style launcher and measured it throughput-neutral: the residual is the NCCL/PCIe communication wall, not shared-context serialization). All three would have been no-ops; each got measured and either reverted or recorded as a deliberate negative result instead of shipped on faith.
  • Honest correctness. QK-norm was added to match xserv's Qwen3 (not faked); every change kept a hard correctness gate, and no tolerance was ever loosened to go green. Phase 2 held the line: flash == composed SDPA (grads/PyTorch), GQA group=1 bit-identical to MHA, gradient accumulation accum=1 bit-identical, dropout p=0 bit-identical and dropout × recompute bit-exact, the default path unchanged on every feature, and the xserv closed-loop md5 byte-identical (b04fc9f9) throughout the deterministic gates.
  • The closed loop matters. Exporting to xserv and checking generated continuations caught real bugs and proved the whole stack end-to-end.

Running it

Everything trains on a remote 8× RTX 5090 box; model artifacts live in a registry (tiny-models/v0…v10). Serve any trained version in xserv:

# on the GPU box
cargo run -p xserv-model --release --bin xserv-cli -- <registry>/v10-fineweb-edu-dim1280-gqa-data6765 --max-tokens 100
# then type a prompt, e.g.  In science,

Build/test the engine itself (CUDA compiles + runs on the GPU box; host-side cargo check works anywhere via the no_cuda cfg):

export PATH=/usr/local/cuda/bin:$HOME/.cargo/bin:$PATH
cargo test --workspace            # autograd grad-checks, PyTorch parity, DDP, etc.

Doc index

  • docs/evolution.md — per-milestone changes across algorithm / architecture / infra / dataset.
  • docs/runs/README.md — the v0–v10 comparison; docs/runs/0N-*.md — per-run detail.
  • docs/00-* through docs/17-* — per-phase design docs (build chain → tensor → autograd → transformer → training → perf → distributed → export → batched → allocator → bf16 → recompute → flash-attention → GQA → grad-accum → process-per-GPU → dropout).
  • docs/known-issues.md — perf backlog (KI-1/2/3/5 fixed; process-per-GPU CLOSED = measured no-op; KI-4 = accepted modeling tradeoff).