Add per-run design+result docs for the two Chinchilla-axis runs that were done but never committed:
- v9 (dim1280 true-GQA, core 357M, 6.01B FineWeb tokens): double-axis scale, best moving-tail val 2.8854 (~3.2% below v8). Direction validated, gain still incremental, greedy repetition remains.
- v10 (same arch, data-only top-up to 6.765B): moving-tail 2.8816; fixed eval v1 v6→v10 = 3.2328/3.1850/3.1515/2.9278/2.8814.
Extend the comparison tables in docs/runs/README.md and docs/evolution.md to v10, and reframe README to v0–v10 with Phase 3 = the v9 double-axis run. No code changes.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Scaling Run v9: Chinchilla Double-Axis Scale: dim1280/18L true GQA (core 356.9M) + FineWeb-edu 6.01B tokens + Phase-2 stack (Design Document)
Goal
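The meta-conclusion from v8: turning the capacity axis alone helps, but only by a margin of about 3%; repeating old data alone yields only about 1.6%. To clearly beat v8, model capacity and new tokens must be scaled together rather than turning a single axis.
v9 is that double-axis point:
- Model axis: dim1024/core 226M -> dim1280/core 356.9M, with true GQA enabled (40 query heads / 10 kv heads).
- Data axis: the 2.255B-token FineWeb subset used by v6-v8 -> 6.013B tokens, adding new FineWeb-edu shards 003-009.
- System stack: the modern Phase-2 path: --flash + --accum-steps + bf16 + recompute + DDP. Dropout is set to 0, as is standard for pretraining.
v9's val is still drawn from the FineWeb-edu distribution and cannot be compared directly with the TinyStories val of v0-v5. Note: after v9 extended the cache, the default tail held-out set moved from the old tail used by v6-v8 to the end of the new shards; strict cross-run comparisons should use fixed eval v1 from here on.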
Data
| Item | Value |
|---|---|
| source | FineWeb-edu sample/10BT, original shards 000-002 + new shards 003-009 |
| token cache | data/fineweb-edu.txt.u16.bin |
| total tokens | 6,013,639,492 |
| held-out val | last 1,000,000 tokens |
| train corpus | 6,012,639,492 tokens |
| tokens consumed in training | 6,012,600,320 = 91745 steps x effective batch 256 x seq 256 |
| epoch | ~1.00 |
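The P3-DATA target was originally about 7B tokens; the shard 010 download was interrupted (curl rc=18, a partial transfer), so the corpus stopped at 6.01B. For core 356.9M, D/N is about 16.8 tokens/param, below the ideal Chinchilla value of 20 but well above v8's roughly 10.4, which makes this a clean double-axis scale point.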
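The token budget and D/N figures above can be re-derived with a few lines of arithmetic. The following is a minimal sketch in Python; the variable names are illustrative and not taken from the training code, and every input is copied from this document (the core parameter count comes from the Architecture table).

```python
# Re-derive the v9 token budget from the figures quoted in this document.
TOTAL_TOKENS = 6_013_639_492   # size of the token cache
VAL_TOKENS = 1_000_000         # tail held-out split
STEPS = 91_745
EFFECTIVE_BATCH = 256          # sequences per optimizer step
SEQ_LEN = 256
CORE_PARAMS = 356.89e6         # v9 core parameter count

train_tokens = TOTAL_TOKENS - VAL_TOKENS
consumed = STEPS * EFFECTIVE_BATCH * SEQ_LEN
epochs = consumed / train_tokens
tokens_per_param = consumed / CORE_PARAMS

print(f"train corpus    : {train_tokens:,}")
print(f"consumed tokens : {consumed:,}")
print(f"epochs          : {epochs:.4f}")
print(f"D/N             : {tokens_per_param:.2f} tokens/param")
```

The consumed-token count lands just under the train corpus size, which is why the run is reported as roughly one epoch.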
Architecture
| Item | v8 | v9 |
|---|---|---|
| dim | 1024 | 1280 |
| layers | 18 | 18 |
| query heads x head_dim | 32 x 32 | 40 x 32 |
| kv heads | 32 (MHA) | 10 (true GQA, group=4) |
| ffn | 2730 | 4096 |
| core params | 226.50M | 356.89M |
| total params | 329.42M | 485.55M |
| export tensors | 201 | 201 |
config.json writes real num_key_value_heads = 10, so xserv loads v9 as true GQA rather than MHA.
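As a sanity check on the v9 column, the parameter and tensor counts can be reproduced from the architecture numbers. The sketch below assumes a Qwen3-style decoder block: bias-free q/k/v/o projections, a three-matrix SwiGLU MLP, two RMSNorm weights per block, per-head q/k norms of size head_dim, and untied input/output embeddings. That layout is inferred from the xserv log reporting `qwen3` and from vocab=50257; it is not confirmed by the training source, and `count_params` is a hypothetical helper.

```python
# Parameter count for an ASSUMED Qwen3-style decoder layout (see lead-in).
def count_params(dim, layers, q_heads, kv_heads, head_dim, ffn, vocab):
    attn = 2 * dim * q_heads * head_dim      # q_proj + o_proj
    attn += 2 * dim * kv_heads * head_dim    # k_proj + v_proj (shrunk by GQA)
    mlp = 3 * dim * ffn                      # gate, up, down
    norms = 2 * dim + 2 * head_dim           # two block norms + q/k head norms
    core = layers * (attn + mlp + norms) + dim   # + final norm
    embed = 2 * vocab * dim                  # untied input + output embeddings
    tensors = layers * 11 + 3                # 11 per block + embed, final norm, lm_head
    return core, core + embed, tensors

core, total, tensors = count_params(
    dim=1280, layers=18, q_heads=40, kv_heads=10, head_dim=32,
    ffn=4096, vocab=50257,
)
print(f"core={core / 1e6:.2f}M total={total / 1e6:.2f}M tensors={tensors}")
```

Under these assumptions the result is consistent with the table (356.89M core, 485.55M total, 201 tensors), which suggests the assumed layout is at least close to what was trained.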
Training
| Item | Value |
|---|---|
| optimizer | hand-written AdamW, wd=0.1 |
| schedule | warmup -> cosine, max_lr 6e-4 -> min_lr 6e-5 |
| grad clip | global norm 1.0 |
| steps | 91745 |
| effective global batch | 256 (--batch 128 --accum-steps 2) |
| seq_len | 256 |
| precision | bf16 mixed precision, fp32 master |
| memory stack | activation recompute + flash-attention + gradient accumulation |
| world size | 8 x RTX 5090 |
| wall clock | 21h15m |
| steady throughput | ~78.6K tok/s |
| peak observed memory | ~17GB / GPU |
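The "warmup -> cosine" schedule in the table can be sketched as follows. This is an illustration only: the document does not state the warmup length or shape, so `warmup_steps=1_000` and the linear ramp are hypothetical choices, and `lr_at` is not a function from the training code. Only `max_lr`, `min_lr`, and the total step count come from the table.

```python
import math

def lr_at(step, total_steps=91_745, warmup_steps=1_000,
          max_lr=6e-4, min_lr=6e-5):
    """Linear warmup to max_lr, then cosine decay to min_lr.

    warmup_steps and the linear ramp are HYPOTHETICAL; the run's
    actual warmup settings are not given in this document.
    """
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return min_lr + (max_lr - min_lr) * cosine

for s in (0, 999, 1_000, 46_372, 91_745):
    print(s, f"{lr_at(s):.3e}")
```

With these settings the learning rate peaks at 6e-4 at the end of warmup and reaches 6e-5 on the final step.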
Command:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 cargo run -p xtrain-distributed --release --bin train_ddp -- \
/opt/wjh/models/gpt2/tokenizer.json data/fineweb-edu.txt \
--heads 40 --head-dim 32 --kv-heads 10 --layers 18 --ffn 4096 \
--steps 91745 --batch 128 --accum-steps 2 --seq 256 \
--max-lr 6e-4 --min-lr 6e-5 --val-tokens 1000000 --eval-every 1000 \
--eval-batches 64 --bf16 --recompute --flash --dropout 0.0 \
--ckpt /dashscope-tmp/wjh/xtrain_v9.ckpt
Results
- train loss: 11.1550 -> 2.9340
- first val: step 1000 = 5.1517
- best val: step 91000 = 2.8854
- final val: step 91745 = 2.8873
- exit code: 0
FineWeb val curve milestones:
| step | 1000 | 10000 | 20000 | 30000 | 40000 | 50000 | 60000 | 70000 | 80000 | 90000 | 91000 | final |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| val | 5.1517 | 3.4820 | 3.2953 | 3.2026 | 3.1422 | 3.0844 | 3.0148 | 2.9616 | 2.9160 | 2.8915 | 2.8854 | 2.8873 |
The curve kept improving into the last 1K-step window, then the final eval bounced slightly from 2.8854 to 2.8873. This is close to the floor for this run, but not a clear overfit failure.
Comparison
| | v6 | v7 | v8 | v9 |
|---|---|---|---|---|
| model | dim768/core127M | dim768/core127M | dim1024/core226M | dim1280/core357M + GQA |
| data | 2.29B | 3.28B same subset | 2.36B same subset | 6.01B expanded shards |
| best val | 3.0652 | 3.0149 | 2.9801 | 2.8854 |
On the run-local moving tail, v9 beats v8 by 0.0947 val loss (~3.2% relative), essentially the same size as the v6->v8 capacity gain but now on top of it. A later fixed eval v1 check still supports the same direction (v8 3.1515 -> v9 2.9278 on shard010-tail holdout), while making the moving-tail caveat explicit. This confirms the v8 prediction: double-axis scaling works. It is still an incremental gain, not a qualitative jump.
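The relative gains quoted in this section and in the Goal section follow directly from the best-val row of the table. A small sketch, with `gain` as an illustrative helper name; the usual caveat applies that these are moving-tail numbers and v9's tail differs from the v6-v8 tail.

```python
# Best moving-tail val loss per run, copied from the comparison table.
best_val = {"v6": 3.0652, "v7": 3.0149, "v8": 2.9801, "v9": 2.8854}

def gain(baseline, candidate):
    """Absolute and relative (percent) val-loss reduction vs. the baseline."""
    delta = best_val[baseline] - best_val[candidate]
    return delta, 100.0 * delta / best_val[baseline]

for base, cand in (("v6", "v7"), ("v6", "v8"), ("v8", "v9")):
    delta, pct = gain(base, cand)
    print(f"{base} -> {cand}: -{delta:.4f} val loss ({pct:.1f}% relative)")
```

This gives about 1.6% for the data-repeat step (v6 -> v7), about 2.8% for the capacity step (v6 -> v8), and about 3.2% for the double-axis step (v8 -> v9).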
Samples
xserv greedy samples (--max-tokens 60) are more coherent than the v8 examples on some prompts, but repetition remains:
[The history of] the United States is the story of the people, the places, and the events that have shaped the nation...
[In science,] the term "scientific method" is used to describe the process of gathering information and testing it...
[The most important] thing is to be aware of the symptoms and to seek medical attention...
[Water is] a natural resource that is essential for human life...
The model writes real explanatory English and the domain mix is FineWeb-like. Greedy decoding still falls into repeated clauses on
some prompts (scientific method, symptoms, and earlier fixed prompts), so the val gain is more visible in the metric than in a
dramatic sample-quality leap.
xserv validation
Registry path:
/opt/wjh/projects/tiny-models/v9-fineweb-edu-dim1280-gqa
Files:
- config.json
- model.safetensors (BF16, 201 tensors, 927MB)
- tokenizer.json
- xtrain.ckpt (fp32 master checkpoint, 1.9GB)
- RUN.md
xserv loads v9 as:
Model: qwen3, layers=18, hidden=1280, heads=40/10 kv, vocab=50257
Loaded 201 tensors
Ready (KV cache, dtype=bf16).
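The `heads=40/10 kv` part of the load log means each kv head is shared by a group of four query heads. A minimal sketch of that mapping, assuming the common contiguous grouping (query heads 0-3 share kv head 0, and so on); whether xserv uses exactly this ordering is an assumption, and `kv_head_for` is a hypothetical helper.

```python
Q_HEADS, KV_HEADS = 40, 10
GROUP = Q_HEADS // KV_HEADS   # 4 query heads per kv head

def kv_head_for(q_head: int) -> int:
    """Index of the kv head a query head reads from (contiguous grouping assumed)."""
    return q_head // GROUP

mapping = [kv_head_for(q) for q in range(Q_HEADS)]
print(mapping[:8])

# KV cache size relative to an MHA model with the same number of query heads.
kv_cache_ratio = KV_HEADS / Q_HEADS
print(kv_cache_ratio)
```

With 10 kv heads instead of 40, the KV cache per token is a quarter of what the equivalent MHA model would need.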
Token-match check against xtrain greedy (max-tokens 40):
- Once upon a time: xtrain and xserv matched through the checked continuation.
- One day: diverged after "large, dark," (very tall man vs metallic object) from BF16 greedy tie sensitivity.
- The little: same repetitive pattern, with a short BF16 path divergence.
This is the same class of BF16-vs-f32 greedy drift seen in v8; the important integration result is that xserv successfully loads
true GQA (kv_heads=10 < heads=40) and generates from the exported weights.
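A token-match check of this kind boils down to finding the first position where two greedy continuations differ. The sketch below shows one way to do that; `first_divergence` is an illustrative helper, not part of xtrain or xserv, and the token IDs are made up.

```python
def first_divergence(a, b):
    """Return the index of the first differing token, or None if the
    sequences agree over their common length."""
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    return None

# Toy token IDs, invented for illustration only.
xtrain_ids = [101, 7, 42, 9, 13]
xserv_ids = [101, 7, 42, 11, 13]
print(first_divergence(xtrain_ids, xserv_ids))  # 3
```

Once greedy decoding diverges at one position, every later token is conditioned on a different prefix, so a single BF16 near-tie can change the rest of the continuation.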