Compare commits
3 Commits
a1370446fe
...
7a1fba95b5
| Author | SHA1 | Date |
|---|---|---|
| | 7a1fba95b5 | |
| | fbf4ac2917 | |
| | 5c27493a90 | |
45
README.md
@@ -6,14 +6,16 @@ inference side). A learning project: hand-write the entire training-systems stac
|
||||
gradient checkpointing), then use it to run a multi-version **scaling study** that maps
|
||||
the data-vs-capacity frontier for a tiny model.
|
||||
|
||||
> **Status: complete — two phases.**
|
||||
> **Status: complete — three phases.**
|
||||
> **Phase 1** = the from-scratch full stack (T1–T13) + an 8-version scaling study (v0–v8):
|
||||
> hand-write the whole training-systems stack, then map the data-vs-capacity frontier.
|
||||
> **Phase 2** = systems-stack depth (T14–T18): hand-write the five deferred training-stack
|
||||
> features — fused flash-attention, real GQA, gradient accumulation, process-per-GPU DDP,
|
||||
> dropout. Trains a Qwen3-compatible LM whose weights load into **xserv** and generate
|
||||
> **token-identical** output — the closed loop held byte-for-byte across both phases. This
|
||||
> README is the capstone; per-topic detail lives in [`docs/`](docs/).
|
||||
> dropout. **Phase 3** = one Chinchilla-style double-axis run (v9): dim1280 true-GQA +
|
||||
> 6.01B FineWeb tokens, validating the v8 conclusion that data and capacity must scale
|
||||
> together. Trains Qwen3-compatible LMs whose weights load into **xserv**; deterministic
|
||||
> gates stay byte-identical, while large BF16 checkpoints are served and checked for
|
||||
> prompt-level drift. This README is the capstone; per-topic detail lives in [`docs/`](docs/).
|
||||
|
||||
---
|
||||
|
||||
@@ -34,7 +36,8 @@ borrows, the rest hand-written CUDA + Rust:
|
||||
|
||||
Every op's backward is verified against **finite differences** and against **PyTorch**
|
||||
(forward + per-parameter grads, batch > 1). Trained weights export to HF-safetensors and
|
||||
load into xserv (Qwen3, BF16) producing token-identical greedy output — the closed loop.
|
||||
load into xserv (Qwen3, BF16); deterministic fixtures produce token-identical greedy output,
|
||||
and large checkpoints are validated end-to-end in the serving path.
|
||||
|
||||
## The build journey — Phase 1 (T1–T13) + Phase 2 (T14–T18)
|
||||
|
||||
@@ -106,7 +109,7 @@ Each is opt-in, kept the default path **bit-identical**, and held a **hard corre
|
||||
residual ~5×@8; with all 8 GPUs at 95–99% util, the residual is the **NCCL all-reduce + PCIe
|
||||
topology wall**, not context serialization. The third profile-first falsification (see below).
|
||||
|
||||
## The scaling study — v0 → v8
|
||||
## The scaling study — v0 → v10
|
||||
|
||||
Same Qwen3-style architecture throughout; we scaled **dim** and **data** and read out val
|
||||
loss (full per-run detail in [`docs/runs/`](docs/runs/)).
|
||||
@@ -119,11 +122,13 @@ loss (full per-run detail in [`docs/runs/`](docs/runs/)).
|
||||
| v6 | FineWeb-edu 1.02ep | 768 / 127M | 3.07\* | **corpus swap → graduates to real text** |
|
||||
| v7 | FineWeb-edu 1.45ep | 768 / 127M | 3.01\* | same subset, more epochs → near-ceiling |
|
||||
| **v8** | FineWeb-edu 1.05ep | **1024 / 226M** | **2.98\*** | **capacity → helps** |
|
||||
| **v9** | FineWeb-edu 6.01B / ~1ep | **1280 / 357M + GQA** | **2.89\*** | **data + capacity → helps** |
|
||||
| **v10** | FineWeb-edu 6.76B / ~1ep | **1280 / 357M + GQA** | **2.88\*** | **data-only top-up → small gain** |
|
||||
|
||||
\* FineWeb-edu val is a different (harder) distribution — **not comparable** to the
|
||||
TinyStories val of v0–v5. Judge v6+ by sample quality + transfer, not the number.
|
||||
|
||||
### Three findings
|
||||
### Five findings
|
||||
|
||||
1. **Data volume saturates.** TinyStories at dim768: 3.5× more tokens (v4→v5) bought only
|
||||
−5% val, curve flat. The narrow synthetic corpus is exhausted at this model size.
|
||||
@@ -132,10 +137,18 @@ TinyStories val of v0–v5. Judge v6+ by sample quality + transfer, not the numb
|
||||
historical/scientific expository prose. (Cost: TinyStories transfer val 1.11 → 2.75.)
|
||||
3. **Capacity helps.** v8 (dim1024, ~1 epoch) beats both v6 (dim768, same epoch, by 0.085)
|
||||
and v7 (dim768, *more* data, by 0.035) → the dim768 runs were partly capacity-limited.
|
||||
4. **Double-axis scale helps.** v9 scales both axes (dim1280/core357M + 6.01B FineWeb tokens)
|
||||
and beats v8 by another 0.095 val loss (~3.2%). The direction is validated, but the gain is
|
||||
still incremental and greedy decoding still repeats.
|
||||
5. **Moving validation tails must stop.** v10 added one more FineWeb shard and got moving-tail
|
||||
val 2.8816, but appending data moves the held-out tail. A fixed eval v1 was created from the
|
||||
shard010 tail: v6/v7/v8/v9/v10 = 3.2328 / 3.1850 / 3.1515 / 2.9278 / 2.8814. Future runs
|
||||
should report this fixed eval first.
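As a worked check of the fixed-eval numbers in finding 5, here is a minimal Rust sketch that computes the relative loss reduction between consecutive runs. The names `relative_drop`, `FIXED_EVAL`, and `step_drops` are illustrative and not part of the repository; only the five loss values come from the text above.

```rust
/// Relative reduction in loss going from `prev` to `next`, as a fraction of `prev`.
/// Hypothetical helper for illustration only.
fn relative_drop(prev: f64, next: f64) -> f64 {
    (prev - next) / prev
}

/// Fixed eval v1 losses for v6..v10, copied from finding 5.
const FIXED_EVAL: [(&str, f64); 5] = [
    ("v6", 3.2328),
    ("v7", 3.1850),
    ("v8", 3.1515),
    ("v9", 2.9278),
    ("v10", 2.8814),
];

/// Pairs each run (from v7 on) with the relative drop from its predecessor.
fn step_drops() -> Vec<(&'static str, f64)> {
    FIXED_EVAL
        .windows(2)
        .map(|w| (w[1].0, relative_drop(w[0].1, w[1].1)))
        .collect()
}
```

On these five values the v8 → v9 step comes out at roughly 7%, while every other step is between 1% and 2%. Those figures are measured on the fixed eval and are not directly comparable to the moving-tail percentages quoted elsewhere in this README.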
|
||||
|
||||
**Meta-finding:** every *single*-axis lever (data volume, corpus breadth, capacity) is now
|
||||
worth only **~3%**. Per the Chinchilla lesson, further gains require scaling **data and
|
||||
capacity together** — single-axis moves are exhausted.
|
||||
**Meta-finding:** every lever is now in the **~3% or smaller** regime. Single-axis moves were
|
||||
exhausted by v8; v9 confirms Chinchilla-style double-axis scale works; v10 shows a data-only
|
||||
top-up mostly adapts to the new shard. The next useful run should change model/context, not just
|
||||
append another shard.
|
||||
|
||||
## Efficiency — throughput & MFU
|
||||
|
||||
@@ -166,18 +179,18 @@ versions — a fixed-MFU estimate is off by up to ~100× for the early launch-bo
|
||||
the line: flash == composed SDPA (grads/PyTorch), GQA group=1 bit-identical to MHA, gradient
|
||||
accumulation `accum=1` bit-identical, dropout p=0 bit-identical *and* dropout × recompute
|
||||
bit-exact, the default path unchanged on every feature, and the **xserv closed-loop md5
|
||||
byte-identical (`b04fc9f9`) throughout both phases**.
|
||||
- **The closed loop matters.** Exporting to xserv and checking token-identical greedy output
|
||||
caught real bugs and proved the whole stack end-to-end.
|
||||
byte-identical (`b04fc9f9`) throughout the deterministic gates**.
|
||||
- **The closed loop matters.** Exporting to xserv and checking generated continuations caught
|
||||
real bugs and proved the whole stack end-to-end.
|
||||
|
||||
## Running it
|
||||
|
||||
Everything trains on a remote 8× RTX 5090 box; model artifacts live in a registry
|
||||
(`tiny-models/v0…v8`). Serve any trained version in xserv:
|
||||
(`tiny-models/v0…v10`). Serve any trained version in xserv:
|
||||
|
||||
```bash
|
||||
# on the GPU box
|
||||
cargo run -p xserv-model --release --bin xserv-cli -- <registry>/v8-fineweb-edu-dim1024 --max-tokens 100
|
||||
cargo run -p xserv-model --release --bin xserv-cli -- <registry>/v10-fineweb-edu-dim1280-gqa-data6765 --max-tokens 100
|
||||
# then type a prompt, e.g. In science,
|
||||
```
|
||||
|
||||
@@ -192,6 +205,6 @@ cargo test --workspace # autograd grad-checks, PyTorch parity, DDP, e
|
||||
## Doc index
|
||||
|
||||
- [`docs/evolution.md`](docs/evolution.md) — per-milestone changes across algorithm / architecture / infra / dataset.
|
||||
- [`docs/runs/README.md`](docs/runs/README.md) — the v0–v8 comparison; [`docs/runs/0N-*.md`](docs/runs/) — per-run detail.
|
||||
- [`docs/runs/README.md`](docs/runs/README.md) — the v0–v10 comparison; [`docs/runs/0N-*.md`](docs/runs/) — per-run detail.
|
||||
- [`docs/00-*` … `17-*`](docs/) — per-phase design docs (build chain → tensor → autograd → transformer → training → perf → distributed → export → batched → allocator → bf16 → recompute → flash-attention → GQA → grad-accum → process-per-GPU → dropout).
|
||||
- [`docs/known-issues.md`](docs/known-issues.md) — perf backlog (KI-1/2/3/5 fixed; process-per-GPU CLOSED = measured no-op; KI-4 = accepted modeling tradeoff).
|
||||
|
||||
@@ -398,7 +398,8 @@ pub fn repeat_kv(kv: &Var, nh: usize, batch: usize) -> Var {
|
||||
}
|
||||
|
||||
/// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per
|
||||
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
|
||||
/// row. Negative targets are ignored, which is useful for assistant-only SFT
|
||||
/// masks. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/valid_rows`,
|
||||
/// scaled by the upstream scalar grad.
|
||||
pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
|
||||
// CE math is fp32 (cross_entropy upcasts bf16 logits internally + caches fp32
|
||||
@@ -407,10 +408,22 @@ pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
|
||||
// fp32 logits buffer) is a real activation-memory saving at large vocab.
|
||||
let logit_dtype = x.value().dtype();
|
||||
let (probs, per_row) = x.value().cross_entropy(target);
|
||||
let rows = x.value().shape()[0];
|
||||
let cols = x.value().shape()[1] as i32;
|
||||
let target_host = target.to_device(xtrain_tensor::Device::Cpu);
|
||||
let valid_rows = target_host
|
||||
.as_slice::<i32>()
|
||||
.iter()
|
||||
.filter(|&&t| {
|
||||
if t >= cols {
|
||||
panic!("cross_entropy target {t} out of vocab range {cols}");
|
||||
}
|
||||
t >= 0
|
||||
})
|
||||
.count()
|
||||
.max(1);
|
||||
// Mean loss as a host scalar wrapped back into a [1] tensor.
|
||||
let mean = per_row.to_device(xtrain_tensor::Device::Cpu);
|
||||
let mean_val: f32 = mean.as_slice::<f32>().iter().sum::<f32>() / rows as f32;
|
||||
let mean_val: f32 = mean.as_slice::<f32>().iter().sum::<f32>() / valid_rows as f32;
|
||||
let loss = Tensor::from_slice(&[mean_val], &[1]).to_device(x.value().device());
|
||||
|
||||
let target = target.clone();
|
||||
@@ -420,7 +433,7 @@ pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
|
||||
Box::new(move |d, parents| {
|
||||
// `d` is the scalar upstream grad (1.0 when this is the loss root).
|
||||
let upstream = d.to_device(xtrain_tensor::Device::Cpu).as_slice::<f32>()[0];
|
||||
let scale = upstream / rows as f32;
|
||||
let scale = upstream / valid_rows as f32;
|
||||
let dx = Tensor::cross_entropy_backward(&probs, &target, scale);
|
||||
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
|
||||
}),
|
||||
|
||||
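The masked mean in this hunk can be illustrated with a small CPU reference. This is a sketch under stated assumptions, not the repository's implementation (which runs in CUDA kernels): the function name `masked_cross_entropy` and the flat row-major `logits` layout are hypothetical. It follows the rule documented above: negative targets contribute no loss and no gradient, and both the loss and `dx` are divided by the number of valid rows, clamped to at least 1.

```rust
/// CPU reference for mean cross-entropy with ignored (negative) targets.
/// `logits` is row-major `[rows, cols]`; returns (mean loss, gradient w.r.t. logits).
/// Hypothetical helper for illustration only.
fn masked_cross_entropy(logits: &[f32], targets: &[i32], cols: usize) -> (f32, Vec<f32>) {
    assert_eq!(logits.len(), targets.len() * cols, "shape mismatch");
    // Count supervised rows; clamp to 1 so an all-ignored batch yields 0, not NaN.
    let valid_rows = targets.iter().filter(|&&t| t >= 0).count().max(1) as f32;
    let mut loss_sum = 0.0f32;
    let mut grad = vec![0.0f32; logits.len()];
    for (r, &t) in targets.iter().enumerate() {
        if t < 0 {
            continue; // ignored row: zero loss, zero gradient
        }
        let t = t as usize;
        assert!(t < cols, "target {} out of vocab range {}", t, cols);
        let row = &logits[r * cols..(r + 1) * cols];
        // Numerically stable softmax: subtract the row max before exponentiating.
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = row.iter().map(|&v| (v - max).exp()).collect();
        let denom: f32 = exps.iter().sum();
        for (c, e) in exps.iter().enumerate() {
            let p = e / denom;
            let onehot = if c == t { 1.0 } else { 0.0 };
            grad[r * cols + c] = (p - onehot) / valid_rows;
            if c == t {
                loss_sum -= p.ln();
            }
        }
    }
    (loss_sum / valid_rows, grad)
}
```

Dividing by the valid-row count rather than the total row count keeps the loss scale independent of how much of a batch is masked, which matters when prompt tokens make up most of an SFT window.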
@@ -88,6 +88,7 @@ fn main() {
|
||||
let val_tokens: usize = flag(&args, "--val-tokens", 0);
|
||||
let eval_every: usize = flag(&args, "--eval-every", 0);
|
||||
let eval_batches: usize = flag(&args, "--eval-batches", 64);
|
||||
let sft_tsv = args.iter().any(|a| a == "--sft-tsv");
|
||||
// Dropout (Phase T18/T21): residual-path dropout prob, active at training time
|
||||
// only (inverted scaling), identity at eval/sampling/export. Default 0 = off
|
||||
// (forward graph bit-identical to the no-dropout path). Mirrors bin/train; the
|
||||
@@ -109,6 +110,11 @@ fn main() {
|
||||
.position(|a| a == "--ckpt")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.map(PathBuf::from);
|
||||
let init_ckpt: Option<PathBuf> = args
|
||||
.iter()
|
||||
.position(|a| a == "--init-ckpt")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.map(PathBuf::from);
|
||||
|
||||
// Use every visible GPU as a rank (CUDA_VISIBLE_DEVICES selects the set;
|
||||
// device ordinals are 0..count within it).
|
||||
@@ -129,12 +135,19 @@ fn main() {
|
||||
);
|
||||
|
||||
// Reuse the cached token-id stream (v1's u16 cache); never re-tokenize 2GB.
|
||||
let corpus = Corpus::load_cached(&tok_path, &corpus_path);
|
||||
let corpus = if sft_tsv {
|
||||
Corpus::load_sft_tsv_cached(&tok_path, &corpus_path)
|
||||
} else {
|
||||
Corpus::load_cached(&tok_path, &corpus_path)
|
||||
};
|
||||
println!(
|
||||
"corpus: {} tokens, vocab {}",
|
||||
corpus.len(),
|
||||
corpus.vocab_size
|
||||
);
|
||||
if sft_tsv {
|
||||
println!("SFT TSV: ON (assistant-only loss via ignore-index labels)");
|
||||
}
|
||||
let vocab = corpus.vocab_size;
|
||||
// Hold out a tail slice for validation (rank 0 evaluates on it).
|
||||
let (train_corpus, valid) = if val_tokens > 0 {
|
||||
@@ -200,6 +213,10 @@ fn main() {
|
||||
if dropout > 0.0 {
|
||||
println!("dropout: ON (p={dropout}, residual-path, train-only inverted scaling)");
|
||||
}
|
||||
if let Some(path) = &init_ckpt {
|
||||
println!("init checkpoint: {}", path.display());
|
||||
}
|
||||
let init_ckpt_for_ranks = init_ckpt.clone();
|
||||
let results = launch(
|
||||
&devices,
|
||||
&train_corpus,
|
||||
@@ -216,6 +233,10 @@ fn main() {
|
||||
if flash {
|
||||
m = m.with_flash(true);
|
||||
}
|
||||
if let Some(path) = &init_ckpt_for_ranks {
|
||||
xtrain_train::checkpoint::load_into(path, &m.params())
|
||||
.expect("load init checkpoint");
|
||||
}
|
||||
m
|
||||
},
|
||||
);
|
||||
|
||||
@@ -27,6 +27,7 @@ fn synth_corpus(vocab: usize, n_tokens: usize) -> Corpus {
|
||||
.collect();
|
||||
Corpus {
|
||||
tokens,
|
||||
labels: None,
|
||||
vocab_size: vocab,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ fn synth_corpus() -> Corpus {
|
||||
.collect();
|
||||
Corpus {
|
||||
tokens,
|
||||
labels: None,
|
||||
vocab_size: VOCAB,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,10 +66,18 @@ fn tiny_cfg(dropout: f32) -> Config {
|
||||
fn batch_data(cfg: &Config, device: Device) -> (xtrain_tensor::Tensor, xtrain_tensor::Tensor) {
|
||||
let (batch, seq) = (3usize, 6usize);
|
||||
let seqs: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect())
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let tgts: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect())
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
(
|
||||
batched_ids_tensor(&seqs, device),
|
||||
@@ -94,7 +102,11 @@ fn fwd_bwd(
|
||||
let loss = m.loss_batched(ids, tgt, batch);
|
||||
let loss_val = host(&loss.value())[0];
|
||||
loss.backward();
|
||||
let grads: Vec<Vec<f32>> = m.params().iter().map(|p| host(&p.grad().unwrap())).collect();
|
||||
let grads: Vec<Vec<f32>> = m
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().unwrap()))
|
||||
.collect();
|
||||
(logits, loss_val, grads)
|
||||
}
|
||||
|
||||
@@ -186,7 +198,9 @@ fn recompute_with_dropout(dtype: DType, grad_tol: f32) {
|
||||
|
||||
// Both models: same init, train mode, p=0.2. step_seed starts at 0 and bumps
|
||||
// to 1 on the first training forward in BOTH, so they draw the same masks.
|
||||
let off = build(cfg, device).with_compute_dtype(dtype).with_training(true);
|
||||
let off = build(cfg, device)
|
||||
.with_compute_dtype(dtype)
|
||||
.with_training(true);
|
||||
let on = build(cfg, device)
|
||||
.with_compute_dtype(dtype)
|
||||
.with_recompute(true)
|
||||
@@ -194,11 +208,19 @@ fn recompute_with_dropout(dtype: DType, grad_tol: f32) {
|
||||
|
||||
let off_loss = off.loss_batched(&ids, &tgt, batch);
|
||||
off_loss.backward();
|
||||
let off_grads: Vec<Vec<f32>> = off.params().iter().map(|p| host(&p.grad().unwrap())).collect();
|
||||
let off_grads: Vec<Vec<f32>> = off
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().unwrap()))
|
||||
.collect();
|
||||
|
||||
let on_loss = on.loss_batched(&ids, &tgt, batch);
|
||||
on_loss.backward();
|
||||
let on_grads: Vec<Vec<f32>> = on.params().iter().map(|p| host(&p.grad().unwrap())).collect();
|
||||
let on_grads: Vec<Vec<f32>> = on
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().unwrap()))
|
||||
.collect();
|
||||
|
||||
let mut max_rel = 0.0f32;
|
||||
for (a, b) in off_grads.iter().flatten().zip(on_grads.iter().flatten()) {
|
||||
@@ -240,10 +262,18 @@ fn flash_plus_dropout_grad_check_fp32() {
|
||||
cfg.dropout = 0.2;
|
||||
let seq = 40usize;
|
||||
let seqs: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect())
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let tgts: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect())
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let ids = batched_ids_tensor(&seqs, device);
|
||||
let tgt = batched_ids_tensor(&tgts, device);
|
||||
@@ -277,7 +307,16 @@ fn flash_plus_dropout_grad_check_fp32() {
|
||||
);
|
||||
// Same tolerances as the flash-vs-composed gate (flash.rs run_fp32): flash
|
||||
// differs from composed only by reduction order; dropout masks are identical.
|
||||
assert!(logit_rel < 1e-3, "[F32] flash+dropout logits diverged: {logit_rel:.2e}");
|
||||
assert!(loss_rel < 1e-3, "[F32] flash+dropout loss diverged: {loss_rel:.2e}");
|
||||
assert!(grad_rel < 2e-2, "[F32] flash+dropout grads diverged: {grad_rel:.3e}");
|
||||
assert!(
|
||||
logit_rel < 1e-3,
|
||||
"[F32] flash+dropout logits diverged: {logit_rel:.2e}"
|
||||
);
|
||||
assert!(
|
||||
loss_rel < 1e-3,
|
||||
"[F32] flash+dropout loss diverged: {loss_rel:.2e}"
|
||||
);
|
||||
assert!(
|
||||
grad_rel < 2e-2,
|
||||
"[F32] flash+dropout grads diverged: {grad_rel:.3e}"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -7,7 +7,8 @@
|
||||
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
|
||||
//! cargo run -p xtrain-train --release --bin greedy_sample -- \
|
||||
//! /tmp/xtrain_v4.ckpt /opt/wjh/models/gpt2/tokenizer.json \
|
||||
//! --heads 24 --head-dim 32 --layers 18 --ffn 2048
|
||||
//! --heads 24 --head-dim 32 --layers 18 --ffn 2048 \
|
||||
//! --prompts-file scripts/chat_alpha_fixed_prompts.txt --max-tokens 120
|
||||
|
||||
#[cfg(no_cuda)]
|
||||
fn main() {
|
||||
@@ -52,6 +53,60 @@ fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
/// Value following the first occurrence of `name`, if any.
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
    args.iter()
        .position(|a| a == name)
        .and_then(|i| args.get(i + 1))
        .cloned()
}

/// Values following every occurrence of `name` (for repeatable flags).
#[cfg(not(no_cuda))]
fn flag_values(args: &[String], name: &str) -> Vec<String> {
    args.iter()
        .enumerate()
        .filter_map(|(i, a)| {
            if a == name {
                args.get(i + 1).cloned()
            } else {
                None
            }
        })
        .collect()
}

/// Decode literal `\n` and `\t` escapes in a prompt.
#[cfg(not(no_cuda))]
fn decode_prompt_escapes(s: &str) -> String {
    s.replace("\\n", "\n").replace("\\t", "\t")
}

/// Collect prompts from `--prompts-file` first, then every `--prompt`; fall back
/// to the three built-in prompts when neither is given.
#[cfg(not(no_cuda))]
fn load_prompts(args: &[String]) -> Vec<String> {
    let mut prompts = Vec::new();
    if let Some(path) = flag_value(args, "--prompts-file") {
        let text = std::fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("failed to read prompts file {path}: {e}"));
        prompts.extend(
            text.lines()
                .map(str::trim)
                .filter(|line| !line.is_empty() && !line.starts_with('#'))
                .map(decode_prompt_escapes),
        );
    }
    prompts.extend(
        flag_values(args, "--prompt")
            .into_iter()
            .map(|p| decode_prompt_escapes(&p)),
    );
    if prompts.is_empty() {
        prompts = ["Once upon a time", "One day", "The little"]
            .into_iter()
            .map(String::from)
            .collect();
    }
    prompts
}
|
||||
|
||||
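The prompts-file handling in `load_prompts` can be exercised without touching the filesystem. The sketch below reimplements just the line-parsing step on an in-memory string; `parse_prompt_lines` and `decode_escapes` are hypothetical names, and the sample file contents are made up for illustration.

```rust
/// Decode the two literal escapes a prompt line may carry (`\n`, `\t`).
/// Hypothetical helper for illustration only.
fn decode_escapes(s: &str) -> String {
    s.replace("\\n", "\n").replace("\\t", "\t")
}

/// Parse prompts-file text: one prompt per line, with blank lines and
/// lines starting with `#` skipped, and escapes decoded after trimming.
fn parse_prompt_lines(text: &str) -> Vec<String> {
    text.lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        .map(decode_escapes)
        .collect()
}
```

Because each line is trimmed before the escapes are decoded, a prompt cannot end in a literal space, but a trailing `\n` escape survives: at trim time it is still the two characters backslash and `n`.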
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
@@ -75,6 +130,8 @@ fn main() {
|
||||
// GQA (Phase T15): num K/V heads (must match the ckpt; default = --heads).
|
||||
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||||
let max_new = flag(&args, "--max-tokens", 40usize);
|
||||
let temperature = flag(&args, "--temperature", 0.0f32);
|
||||
let prompts = load_prompts(&args);
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
@@ -106,11 +163,16 @@ fn main() {
|
||||
});
|
||||
xtrain_train::checkpoint::load_into(&ckpt, &model.params()).expect("load checkpoint");
|
||||
|
||||
let prompts = ["Once upon a time", "One day", "The little"];
|
||||
println!(
|
||||
"decode: prompts={} max_new={} temperature={}",
|
||||
prompts.len(),
|
||||
max_new,
|
||||
temperature
|
||||
);
|
||||
for p in prompts {
|
||||
let ids: Vec<i32> = tok.encode(p).into_iter().map(|t| t as i32).collect();
|
||||
let ids: Vec<i32> = tok.encode(&p).into_iter().map(|t| t as i32).collect();
|
||||
let mut rng = 7u64;
|
||||
let out = generate(&model, device, &ids, max_new, 0.0, &mut rng);
|
||||
let out = generate(&model, device, &ids, max_new, temperature, &mut rng);
|
||||
let text = tok.decode(&out.iter().map(|&t| t as u32).collect::<Vec<_>>());
|
||||
println!("[{p}] → {text}");
|
||||
}
|
||||
|
||||
@@ -115,6 +115,7 @@ fn main() {
|
||||
let val_tokens: usize = flag(&args, "--val-tokens", 0);
|
||||
let eval_every: usize = flag(&args, "--eval-every", 0);
|
||||
let eval_batches: usize = flag(&args, "--eval-batches", 64);
|
||||
let sft_tsv = args.iter().any(|a| a == "--sft-tsv");
|
||||
// Dropout (Phase T18): residual-path dropout prob, active at training time
|
||||
// only (inverted scaling), identity at eval/sampling/export. Default 0 = off
|
||||
// (forward graph bit-identical to the no-dropout path).
|
||||
@@ -136,6 +137,11 @@ fn main() {
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "/tmp/xtrain_tinystories.ckpt".to_string()),
|
||||
);
|
||||
let init_ckpt: Option<PathBuf> = args
|
||||
.iter()
|
||||
.position(|a| a == "--init-ckpt")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.map(PathBuf::from);
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
@@ -146,12 +152,19 @@ fn main() {
|
||||
tok_path.display(),
|
||||
corpus_path.display()
|
||||
);
|
||||
let corpus = Corpus::load_cached(&tok_path, &corpus_path);
|
||||
let corpus = if sft_tsv {
|
||||
Corpus::load_sft_tsv_cached(&tok_path, &corpus_path)
|
||||
} else {
|
||||
Corpus::load_cached(&tok_path, &corpus_path)
|
||||
};
|
||||
println!(
|
||||
"corpus: {} tokens, vocab {}",
|
||||
corpus.len(),
|
||||
corpus.vocab_size
|
||||
);
|
||||
if sft_tsv {
|
||||
println!("SFT TSV: ON (assistant-only loss via ignore-index labels)");
|
||||
}
|
||||
let vocab = corpus.vocab_size;
|
||||
// Hold out a tail slice for validation (if requested and the corpus is big).
|
||||
let (train_corpus, valid) = if val_tokens > 0 {
|
||||
@@ -206,6 +219,10 @@ fn main() {
|
||||
if dropout > 0.0 {
|
||||
println!("dropout: ON (p={dropout}, residual-path, train-only inverted scaling)");
|
||||
}
|
||||
if let Some(path) = &init_ckpt {
|
||||
xtrain_train::checkpoint::load_into(path, &model.params()).expect("load init checkpoint");
|
||||
println!("init checkpoint: loaded {}", path.display());
|
||||
}
|
||||
|
||||
// Eval-only mode: load a checkpoint and score it on the held-out val set, then
|
||||
// exit. Used to put an EXISTING model (e.g. v0) and a new one on the same
|
||||
|
||||
@@ -15,6 +15,7 @@ use xserv_tokenizer::Tokenizer;
|
||||
/// A tokenized corpus: one flat stream of token ids, plus the vocab size.
|
||||
pub struct Corpus {
|
||||
pub tokens: Vec<i32>,
|
||||
pub labels: Option<Vec<i32>>,
|
||||
pub vocab_size: usize,
|
||||
}
|
||||
|
||||
@@ -33,6 +34,7 @@ impl Corpus {
|
||||
let ids: Vec<i32> = tok.encode(text).into_iter().map(|t| t as i32).collect();
|
||||
Self {
|
||||
tokens: ids,
|
||||
labels: None,
|
||||
vocab_size: tok.vocab_size(),
|
||||
}
|
||||
}
|
||||
@@ -52,7 +54,11 @@ impl Corpus {
|
||||
tokens.len(),
|
||||
cache.display()
|
||||
);
|
||||
return Self { tokens, vocab_size };
|
||||
return Self {
|
||||
tokens,
|
||||
labels: None,
|
||||
vocab_size,
|
||||
};
|
||||
}
|
||||
let me = Self::load(tokenizer_path, corpus_path);
|
||||
write_u16_cache(&cache, &me.tokens);
|
||||
@@ -64,22 +70,104 @@ impl Corpus {
|
||||
me
|
||||
}
|
||||
|
||||
    /// Load assistant-only SFT data from a two-column TSV:
    ///
    /// ```text
    /// user<TAB>assistant
    /// ```
    ///
    /// Literal `\n` and `\t` escapes are decoded. Each row is formatted as
    /// `User: ...\nAssistant:` followed by ` ` + answer + `\n<|endoftext|>`.
    /// Labels are `-100` for prompt tokens and the token id itself for
    /// answer/EOS tokens, so the cross-entropy op ignores prompt rows while
    /// still training the assistant answer and stop token.
    pub fn load_sft_tsv_cached(tokenizer_path: &Path, corpus_path: &Path) -> Self {
        let token_cache = cache_path(corpus_path);
        let label_cache = label_cache_path(corpus_path);
        let vocab_size = Tokenizer::from_file(tokenizer_path).vocab_size();
        if token_cache.exists() && label_cache.exists() {
            let tokens = read_u16_cache(&token_cache);
            let labels = read_i32_cache(&label_cache);
            assert_eq!(
                tokens.len(),
                labels.len(),
                "SFT cache token/label length mismatch"
            );
            println!(
                "corpus: read {} cached SFT tokens from {} (+ labels {})",
                tokens.len(),
                token_cache.display(),
                label_cache.display()
            );
            return Self {
                tokens,
                labels: Some(labels),
                vocab_size,
            };
        }

        let tok = Tokenizer::from_file(tokenizer_path);
        let text = std::fs::read_to_string(corpus_path)
            .unwrap_or_else(|e| panic!("failed to read SFT corpus {}: {e}", corpus_path.display()));
        let mut tokens = Vec::new();
        let mut labels = Vec::new();
        for (lineno, line) in text.lines().enumerate() {
            if line.trim().is_empty() {
                continue;
            }
            let (user, assistant) = line
                .split_once('\t')
                .unwrap_or_else(|| panic!("SFT TSV line {} missing tab", lineno + 1));
            let user = decode_tsv_escapes(user);
            let assistant = decode_tsv_escapes(assistant);
            let prompt = format!("User: {user}\nAssistant:");
            let answer = format!(" {assistant}\n<|endoftext|>");
            let prompt_ids: Vec<i32> = tok.encode(&prompt).into_iter().map(|t| t as i32).collect();
            let answer_ids: Vec<i32> = tok.encode(&answer).into_iter().map(|t| t as i32).collect();
            labels.extend(std::iter::repeat(-100).take(prompt_ids.len()));
            labels.extend(answer_ids.iter().copied());
            tokens.extend(prompt_ids);
            tokens.extend(answer_ids);
        }
        assert_eq!(tokens.len(), labels.len(), "SFT tokens/labels mismatch");
        write_u16_cache(&token_cache, &tokens);
        write_i32_cache(&label_cache, &labels);
        println!(
            "corpus: tokenized {} SFT tokens → cached to {} (+ labels {})",
            tokens.len(),
            token_cache.display(),
            label_cache.display()
        );
        Self {
            tokens,
            labels: Some(labels),
            vocab_size: tok.vocab_size(),
        }
    }
|
||||
|
||||
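The label layout produced by `load_sft_tsv_cached` is easier to see on a single row. The sketch below builds the token and label streams for one user/assistant pair. It assumes a toy one-id-per-byte encoder (`toy_encode`) in place of the real tokenizer, so `<|endoftext|>` is spelled out byte by byte here rather than mapped to a single special token; `build_sft_row` and `IGNORE` are likewise hypothetical names.

```rust
/// Label value that marks a position as excluded from the loss.
const IGNORE: i32 = -100;

/// Toy stand-in for a real tokenizer: one id per byte. Purely illustrative.
fn toy_encode(text: &str) -> Vec<i32> {
    text.bytes().map(|b| b as i32).collect()
}

/// Build (tokens, labels) for one user/assistant pair.
/// Prompt positions get IGNORE; answer positions get their own token id.
fn build_sft_row(user: &str, assistant: &str) -> (Vec<i32>, Vec<i32>) {
    let prompt_ids = toy_encode(&format!("User: {}\nAssistant:", user));
    let answer_ids = toy_encode(&format!(" {}\n<|endoftext|>", assistant));
    let mut labels = vec![IGNORE; prompt_ids.len()];
    labels.extend(answer_ids.iter().copied());
    let mut tokens = prompt_ids;
    tokens.extend(answer_ids);
    (tokens, labels)
}
```

The two streams are aligned position for position. The one-step shift that turns them into next-token targets happens later, when a training window reads labels from `start + 1`.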
/// Split off the last `n` tokens as a held-out validation corpus, leaving the
|
||||
/// rest as the train corpus. Returns `(train, valid)`. Used for periodic val
|
||||
/// loss during training without leaking the eval window into training.
|
||||
pub fn split_tail(self, n: usize) -> (Self, Self) {
|
||||
let n = n.min(self.tokens.len() / 10); // never hand off more than 10%
|
||||
let cut = self.tokens.len() - n;
|
||||
let valid = self.tokens[cut..].to_vec();
|
||||
let valid_tokens = self.tokens[cut..].to_vec();
|
||||
let valid_labels = self.labels.as_ref().map(|labels| labels[cut..].to_vec());
|
||||
let mut train = self.tokens;
|
||||
train.truncate(cut);
|
||||
let train_labels = self.labels.map(|mut labels| {
|
||||
labels.truncate(cut);
|
||||
labels
|
||||
});
|
||||
(
|
||||
Self {
|
||||
tokens: train,
|
||||
labels: train_labels,
|
||||
vocab_size: self.vocab_size,
|
||||
},
|
||||
Self {
|
||||
tokens: valid,
|
||||
tokens: valid_tokens,
|
||||
labels: valid_labels,
|
||||
vocab_size: self.vocab_size,
|
||||
},
|
||||
)
|
||||
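The key property of the updated `split_tail` is that tokens and labels are cut at the same index. A minimal sketch of that invariant follows, using a hypothetical `Stream` struct in place of `Corpus` and `Vec::split_off` in place of the `to_vec` + `truncate` pair the source uses; the 10% cap is copied from the source.

```rust
/// Hypothetical stand-in for `Corpus`, reduced to the two parallel streams.
#[derive(Debug, PartialEq)]
struct Stream {
    tokens: Vec<i32>,
    labels: Option<Vec<i32>>,
}

/// Split off the last `n` positions (capped at 10% of the stream) as a
/// validation stream, cutting tokens and labels at the same index.
/// Assumes `labels`, when present, has the same length as `tokens`.
fn split_tail(mut s: Stream, n: usize) -> (Stream, Stream) {
    let n = n.min(s.tokens.len() / 10);
    let cut = s.tokens.len() - n;
    let valid = Stream {
        tokens: s.tokens.split_off(cut),
        labels: s.labels.as_mut().map(|l| l.split_off(cut)),
    };
    (s, valid)
}
```

Requesting 5 tokens from a 20-token stream yields a 2-token validation tail, because the cap takes precedence over the requested size.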
@@ -101,11 +189,27 @@ impl Corpus {
|
||||
pub fn sample(&self, seq: usize, rng_state: &mut u64) -> (Vec<i32>, Vec<i32>) {
|
||||
assert!(self.tokens.len() > seq + 1, "corpus shorter than a window");
|
||||
let max_start = self.tokens.len() - seq - 1;
|
||||
let start = (next_rand(rng_state) % (max_start as u64 + 1)) as usize;
|
||||
let mut start = (next_rand(rng_state) % (max_start as u64 + 1)) as usize;
|
||||
if let Some(labels) = &self.labels {
|
||||
for _ in 0..16 {
|
||||
if labels[start + 1..start + seq + 1].iter().any(|&t| t >= 0) {
|
||||
break;
|
||||
}
|
||||
start = (next_rand(rng_state) % (max_start as u64 + 1)) as usize;
|
||||
}
|
||||
}
|
||||
let input = self.tokens[start..start + seq].to_vec();
|
||||
let target = self.tokens[start + 1..start + seq + 1].to_vec();
|
||||
let target = self.target_window(start, seq);
|
||||
(input, target)
|
||||
}
|
||||
|
||||
/// Deterministic target labels for an input window starting at `start`.
|
||||
pub fn target_window(&self, start: usize, seq: usize) -> Vec<i32> {
|
||||
match &self.labels {
|
||||
Some(labels) => labels[start + 1..start + seq + 1].to_vec(),
|
||||
None => self.tokens[start + 1..start + seq + 1].to_vec(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
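The window logic in `sample` and `target_window` can be sketched on plain slices. The free functions below are hypothetical rewrites for illustration (the source versions are methods on `Corpus`): `target_window` returns the shifted labels when a label stream exists and the shifted tokens otherwise, and `has_supervised_target` is the predicate the resampling loop checks.

```rust
/// Targets for the input window `tokens[start..start + seq]`: the next-token ids,
/// or the shifted labels when an explicit label stream is present.
/// Hypothetical free-function form for illustration only.
fn target_window(tokens: &[i32], labels: Option<&[i32]>, start: usize, seq: usize) -> Vec<i32> {
    let src = labels.unwrap_or(tokens);
    src[start + 1..start + seq + 1].to_vec()
}

/// True when the window starting at `start` supervises at least one position,
/// i.e. at least one shifted label is non-negative.
fn has_supervised_target(labels: &[i32], start: usize, seq: usize) -> bool {
    labels[start + 1..start + seq + 1].iter().any(|&t| t >= 0)
}
```

In the source, `sample` redraws the start position up to 16 times when this predicate fails and then keeps whatever window it has, so a fully masked window is still possible. The clamp to at least one valid row in `cross_entropy` covers that case.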
/// Drop a leading partial line (before the first newline) and everything after
|
||||
@@ -127,6 +231,12 @@ fn cache_path(corpus_path: &Path) -> PathBuf {
|
||||
PathBuf::from(s)
|
||||
}
|
||||
|
||||
fn label_cache_path(corpus_path: &Path) -> PathBuf {
|
||||
let mut s = corpus_path.as_os_str().to_os_string();
|
||||
s.push(".labels.i32.bin");
|
||||
PathBuf::from(s)
|
||||
}
|
||||
|
||||
/// Read a flat little-endian `[u16]` cache into an `i32` id stream.
|
||||
fn read_u16_cache(path: &Path) -> Vec<i32> {
|
||||
let mut r = BufReader::new(
|
||||
@@ -140,6 +250,18 @@ fn read_u16_cache(path: &Path) -> Vec<i32> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Read a flat little-endian `[i32]` label cache.
fn read_i32_cache(path: &Path) -> Vec<i32> {
    let mut r = BufReader::new(
        std::fs::File::open(path).unwrap_or_else(|e| panic!("open cache {}: {e}", path.display())),
    );
    let mut buf = Vec::new();
    r.read_to_end(&mut buf).expect("read cache");
    assert!(
        buf.len() % 4 == 0,
        "corrupt i32 cache (byte count not a multiple of 4)"
    );
    buf.chunks_exact(4)
        .map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect()
}
|
||||
|
||||
/// Write an id stream as a flat little-endian `[u16]` cache. Ids must fit in u16
|
||||
/// (GPT-2 vocab = 50257 < 65536); asserts otherwise.
|
||||
fn write_u16_cache(path: &Path, tokens: &[i32]) {
|
||||
@@ -154,6 +276,21 @@ fn write_u16_cache(path: &Path, tokens: &[i32]) {
|
||||
w.flush().expect("flush cache");
|
||||
}
|
||||
|
||||
/// Write labels as a flat little-endian `[i32]` cache. Unlike token ids, labels
/// can be negative (the `-100` ignore marker), so they do not fit the u16 format.
fn write_i32_cache(path: &Path, labels: &[i32]) {
    let mut w = BufWriter::new(
        std::fs::File::create(path)
            .unwrap_or_else(|e| panic!("create cache {}: {e}", path.display())),
    );
    for &t in labels {
        w.write_all(&t.to_le_bytes()).expect("write cache");
    }
    w.flush().expect("flush cache");
}

/// Decode literal `\n` and `\t` escapes in a TSV field.
fn decode_tsv_escapes(s: &str) -> String {
    s.replace("\\n", "\n").replace("\\t", "\t")
}
|
||||
|
||||
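The label cache layout is a flat run of little-endian `i32` values. The sketch below round-trips that layout in memory; `encode_i32_le` and `decode_i32_le` are hypothetical helpers, and returning `None` on a bad length is a choice made for this example (the source asserts instead).

```rust
/// Serialize labels as flat little-endian i32 bytes, 4 bytes per label.
/// Hypothetical helper for illustration only.
fn encode_i32_le(labels: &[i32]) -> Vec<u8> {
    labels.iter().flat_map(|t| t.to_le_bytes()).collect()
}

/// Inverse of `encode_i32_le`; returns None if the byte count is not a multiple of 4.
fn decode_i32_le(bytes: &[u8]) -> Option<Vec<i32>> {
    if bytes.len() % 4 != 0 {
        return None;
    }
    Some(
        bytes
            .chunks_exact(4)
            .map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect(),
    )
}
```

The ignore marker `-100` encodes as the bytes `9c ff ff ff`, which shows why the label stream needs a signed 32-bit format: a negative value has no representation in the `u16` token cache.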
/// Tiny LCG (same constants as the model tests' deterministic fill) so dataset
|
||||
/// sampling is reproducible from a single u64 seed.
|
||||
fn next_rand(state: &mut u64) -> u64 {
|
||||
|
||||
@@ -207,7 +207,7 @@ pub fn eval_loss(
|
||||
break;
|
||||
}
|
||||
let input: Vec<i32> = valid.tokens[s..s + seq].to_vec();
|
||||
let target: Vec<i32> = valid.tokens[s + 1..s + seq + 1].to_vec();
|
||||
let target = valid.target_window(s, seq);
|
||||
let ids = ids_tensor(&input, device);
|
||||
let targets = ids_tensor(&target, device);
|
||||
let loss = model.loss(&ids, &targets);
|
||||
|
||||
@@ -216,6 +216,7 @@ fn synth_corpus(vocab: usize, n_tokens: usize) -> Corpus {
|
||||
tokens: (0..n_tokens)
|
||||
.map(|i| (i * 7 + 3) as i32 % vocab as i32)
|
||||
.collect(),
|
||||
labels: None,
|
||||
vocab_size: vocab,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -338,7 +338,7 @@ __global__ void cross_entropy_fwd_k(const float* x, const int* target,
|
||||
for (int c = threadIdx.x; c < cols; c += blockDim.x) pr[c] *= inv;
|
||||
if (threadIdx.x == 0) {
|
||||
int t = target[r];
|
||||
-        loss[r] = -logf(pr[t]);
+        loss[r] = t < 0 ? 0.0f : -logf(pr[t]);
|
||||
}
|
||||
}
|
||||
void launch_cross_entropy_fwd_f32(const float* x, const int* target,
|
||||
@@ -354,9 +354,14 @@ __global__ void cross_entropy_dx_k(const float* probs, const int* target,
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= rows * cols) return;
|
||||
int r = i / cols, c = i % cols;
|
||||
-    float g = probs[i] - (c == target[r] ? 1.0f : 0.0f);
+    int t = target[r];
+    if (t < 0) {
+        dx[i] = 0.0f;
+    } else {
+        float g = probs[i] - (c == t ? 1.0f : 0.0f);
     dx[i] = g * scale;
+    }
|
||||
}
|
||||
void launch_cross_entropy_dx_f32(const float* probs, const int* target,
|
||||
float* dx, int rows, int cols, float scale, void* s) {
|
||||
int n = rows * cols, blk = 256, grid = (n + blk - 1) / blk;
|
||||
|
||||
@@ -33,9 +33,9 @@
|
||||
|
||||
---
|
||||
|
||||
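-## 2. Scaling runs (v0–v8): mainly changing "model architecture" and "dataset"
+## 2. Scaling runs (v0–v10): mainly changing "model architecture" and "dataset"

-The architecture is always **Qwen3-style** (RoPE + RMSNorm + QK-norm + SwiGLU, GPT-2 vocab of 50257), with dim/layers/heads scaled up version by version (v8 is the first to move the capacity axis, to dim1024); the other dimensions change per version as follows:
+The architecture is always **Qwen3-style** (RoPE + RMSNorm + QK-norm + SwiGLU, GPT-2 vocab of 50257), with dim/layers/heads scaled up version by version (v8 is the first to move the capacity axis, to dim1024; v9 is the dim1280 + true-GQA double-axis point; v10 fixes the architecture and extends only the data axis); the other dimensions change per version as follows:

| ver | Model architecture (dim/layers/heads·hd · core/total params) | Dataset (corpus · trained tokens · epochs) | Algorithm/precision | Infra (GPU · throughput) | Result (val · notes) |
|---|---|---|---|---|---|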
|
||||
@@ -48,22 +48,27 @@
|
||||
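| v6 | dim768/18L (same as v4/v5) | **FineWeb-edu** real web text · 2.29B · 1.02ep | bf16 | 8 GPU · 218K | val **3.07**: ⚠️ **FineWeb held-out set, not comparable with v0–v5** (real web text has higher entropy, so ~3.0 is expected); judged by sample quality + transfer. First version to leave TinyStories, with a **qualitative change in the kind of language** (short stories → real expository text); transfer → TinyStories val 2.75 (v5 native 1.11), so purely general data has a cost on a narrow distribution; val was still decreasing monotonically at the final step = not saturated |
| v7 | dim768/18L (same as v4/v5/v6) | **same FineWeb-edu subset as v6** (no new data) · 3.28B · **1.45ep** | bf16 | 8 GPU · 218K | val **3.01** (comparable with v6): ⚠️ **multiple epochs on the same subset are near the ceiling**: the only variable is epochs (1.02→1.45); feeding ~1B more tokens lowers val by only 0.05, the curve flattens after ~step 44000, and samples show no qualitative change. Same kind of effect as v5's TinyStories data-volume saturation (repeating old data has thin marginal returns); genuinely more data requires **new shards** |
| v8 | **dim1024**/18L/**32h** · **226M/329M** (+78% capacity, ffn 2730) | **same FineWeb-edu subset as v6/v7** (no new data) · 2.36B · **1.05ep** | bf16 **+ activation recompute (T13)** | 8 GPU · 129K (recompute tax) | val **2.98** (comparable with v6/v7): ⭐ **capacity-axis A/B: capacity helps**: the only variable is dim768→dim1024; at the same ~1ep, v6 3.07→**2.98** (↓0.085), and v8 (1.05ep) < v7 (1.45ep, more old data) 3.01 ⇒ scaling capacity > repeating old data ⇒ v6/v7 were partly capacity-limited. ⚠️ But the gain is only ~3% and val was **still falling, unsaturated** at the final step ⇒ **a single step on either single axis (data/capacity) now yields ~3% per lever = diminishing returns across the board; both axes must scale together (Chinchilla)** |
+| v9 | **dim1280**/18L/**40h/10kv GQA** · **357M/486M** (ffn 4096) | **FineWeb-edu extended shards 000-009** · **6.01B** · **~1.00ep** | bf16 + recompute + **flash + grad-accum + true GQA** | 8 GPU · **78.6K** (21.25h) | val **2.8854** (comparable with v6-v8): ✅ **the double-axis Chinchilla point works**: capacity goes from v8's 226M→357M while data goes from the 2.255B subset→6.013B tokens, and best val drops a further **0.0947 (~3.2%)** below v8. Samples of real expository text are somewhat more stable, but greedy repetition is still obvious; the gain remains a solid increment rather than a qualitative change |
+| v10 | **same as v9** | **FineWeb-edu extended shards 000-010** · **6.765B** · **~1.00ep** | bf16 + recompute + flash + grad-accum + true GQA | 8 GPU · **79.0K** (23.86h) | moving-tail val **2.8816**; on fixed eval v1, v9 **2.9278**→v10 **2.8814**. Conclusion: adding shard010 helps on the new distribution, but extending only the data axis does not fix greedy repetition; going forward the eval set should be fixed, and a larger model + longer context should be tried first |

-> Trained tokens = steps×batch×seq (not the dataset size). val uses the same 1M-token TinyStories held-out set (v0–v5 comparable; from v6 on the held-out set is FineWeb-edu, a different distribution that is not comparable with v0–v5; v6/v7/v8 share one FineWeb held-out set and are comparable with each other: 3.07/3.01/2.98).
+> Trained tokens = steps×batch×seq (not the dataset size). val for v0–v5 is the same 1M-token TinyStories held-out set. From v6 on it switches to FineWeb-edu,
+> and the new shards appended in v9/v10 move the default tail-heldout; strict cross-version comparison uses fixed eval v1 (shard010 tail 1M):
+> v6/v7/v8/v9/v10 = **3.2328 / 3.1850 / 3.1515 / 2.9278 / 2.8814**.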
|
||||
|
||||
---
|
||||
|
||||
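## 3. Cumulative evolution per dimension (how each axis progressed)

- **Algorithm**: hand-written autograd (tape) + fan-out accumulation → AdamW/LR-sched/grad-clip → +QK-norm (Qwen3) → batched forward → bf16 mixed precision (fp32 master) → activation recompute (T13) → fused flash-attention (T14, online softmax + flash-style bwd) → gradient accumulation (T16, reuses tape SUM; large effective batch while memory scales with the micro-batch) → dropout (T18, counter-based device RNG + inverted scaling, train/eval switch).
-- **Model architecture**: fixed Qwen3-style; dim **32→256→384→512→768→1024** (v8 first moves the capacity axis, heads 24→32); core params **41K→226M** (total 3.26M→329M). +QK-norm (T9, Qwen3-compatible) → **true GQA (T15, `num_kv_heads<num_heads`, repeat_kv broadcast + within-group gradient sum; the default `num_kv_heads` = nh reproduces MHA bit-for-bit)**: the architecture now covers the modern LLM standard set (MHA/GQA/MQA along a single `num_kv_heads` axis), both SDPA paths (composed/flash) share the same broadcast, and the export writes the real `num_key_value_heads` with the xserv loop closed.
+- **Model architecture**: fixed Qwen3-style; dim **32→256→384→512→768→1024→1280** (v8 first moves the capacity axis, v9 moves to dim1280); core params **41K→357M** (total 3.26M→486M). +QK-norm (T9, Qwen3-compatible) → **true GQA (T15, `num_kv_heads<num_heads`, repeat_kv broadcast + within-group gradient sum; the default `num_kv_heads` = nh reproduces MHA bit-for-bit; v9 uses 40 query / 10 kv)**: the architecture now covers the modern LLM standard set (MHA/GQA/MQA along a single `num_kv_heads` axis), both SDPA paths (composed/flash) share the same broadcast, and the export writes the real `num_key_value_heads` with the xserv loop closed.
- **Infra**: single-GPU fp32 → cuBLAS/GPU-optim (T7) → NCCL DDP (T8) → batched forward (T10) → caching allocator (T11) → bf16 (T12) → activation recompute (T13, unlocks dim1024) → flash-attention (T14, does not materialize N×N; the attention memory saving grows with seq) → gradient accumulation (T16, DDP communicates only at accumulation boundaries; memory scales with the micro-batch, not the effective batch) → process-per-GPU (T17, torchrun-style independent processes/CUDA contexts, reusing T8's train_rank with zero changes). Throughput **3.3K→217K tok/s** (dim768 bf16), ~129K for dim1024 + recompute (recompute tax); MFU **0.4%→17%** (each gain maps to one piece of perf infrastructure; see known-issues + the MFU analysis). T13/T14/T16 are three **memory levers** (recompute lowers the activation peak, flash avoids materializing N×N attention scores, gradient accumulation decouples effective batch from activation memory) that stack to enlarge the effective batch. **T17 measured = a recorded negative result**: at this scale process-per-GPU is throughput-**neutral** (thread ~5.27× vs proc ~5.31× @8, a <1% difference within noise), and all 8 GPUs sit at 95–99% util ⇒ the residual non-linearity is the NCCL/PCIe communication wall, **not** single-context serialization. This measurement refutes the long-standing conjecture in the KI-5/T11 docs that "process-per-GPU is the fix for residual serialization" (same methodology as T11's falsification of "bucketed all-reduce").
-- **Dataset**: 3MB TinyStories slice → full TinyStories (epoch 0.01→5.33, **to saturation**) → **v6 graduates to real FineWeb-edu web text** (2.255B corpus, 1.02ep) → **v7 multiple epochs on the same subset (1.45ep, near the ceiling) → v8 same subset with a larger model** (dim1024, 1.05ep). The tokenizer is GPT-2 BPE throughout (reusing xserv-tokenizer; v6 deliberately keeps the tokenizer to isolate the "data source" variable, with KI-4 deferred to a later version).
+- **Dataset**: 3MB TinyStories slice → full TinyStories (epoch 0.01→5.33, **to saturation**) → **v6 graduates to real FineWeb-edu web text** (2.255B corpus, 1.02ep) → **v7 multiple epochs on the same subset (1.45ep, near the ceiling) → v8 same subset with a larger model** (dim1024, 1.05ep) → **v9 extends to new FineWeb shards, 6.013B tokens, and scales the model at the same time** → **v10 adds shard010 to reach 6.765B tokens (data axis only)**. The tokenizer is GPT-2 BPE throughout (reusing xserv-tokenizer; keeping the closed loop takes priority, KI-4 accepted).
- **The qualitative change on the v5→v6 data axis**: v0–v5 all train on synthetic children's stories (TinyStories, low entropy, controlled vocabulary), and v5 showed that a model of that size had saturated on it; v6 is the first version to switch to **real educational web text** (FineWeb-edu), a qualitative change in the kind of language: samples go from "can only write short stories" to "can write history/science/expository text".
- ⚠️ **Multiple epochs on the same subset also hit a ceiling (v6→v7)**: v6's FineWeb val came from only 1.02ep and was still decreasing monotonically at the final step, which was read as "not fed enough yet"; v7 trained on **the same 2.255B subset** to 1.45ep (~1B more tokens), and FineWeb val dropped by only 0.05 (3.07→3.01), flattened after ~step 44000, and samples showed no qualitative change ⇒ **this subset is near its ceiling at dim768**. This is the **same kind of effect** as v5's TinyStories data-volume saturation: **"re-feeding old data" has thin marginal returns, whether as v5's extra epochs on the same corpus or v7's extra epochs on the same subset**. What actually raised the ceiling was v6's step of "switching to a broader new corpus": **the leverage is in "more diverse new tokens", not in "reading the same data more times"**. Lowering val further requires adding **new FineWeb shards** (more diverse, non-repeating), not more epochs on the same subset.
-- ⚠️ **val comparability**: val for v0–v5 is the same 1M TinyStories held-out set (mutually comparable); **from v6 on the held-out set is FineWeb-edu, a different distribution, so val cannot be compared in magnitude with v0–v5 (~1.1)**: real web text has higher entropy, and ~3.0 is expected rather than a regression; **v6/v7/v8 share one FineWeb held-out set and are comparable with each other** (3.07→3.01→2.98). v6 is additionally judged by sample quality + **transfer eval** (v6→TinyStories val 2.75 vs v5 native 1.11, quantifying "the cost of purely general data on a narrow distribution").
+- ⚠️ **val comparability**: val for v0–v5 is the same 1M TinyStories held-out set (mutually comparable); **from v6 on the held-out set is FineWeb-edu, a different distribution, so val cannot be compared in magnitude with v0–v5 (~1.1)**: real web text has higher entropy, and ~3.0 is expected rather than a regression. After v9/v10 append shards, the default tail-heldout moves, so the moving-tail best alone is no longer sufficient. A fixed eval v1 (shard010 tail 1M) is established for later runs: v6/v7/v8/v9/v10 = **3.2328 / 3.1850 / 3.1515 / 2.9278 / 2.8814**.
- ⭐ **The capacity axis helps, but also only ~3% (v8)**: v6/v7 "could not digest more data" at dim768, and v8 answered "has the data been seen enough, or is capacity insufficient" with the cleanest A/B: **freeze the data subset and change only dim768→dim1024 (core 127M→226M, +78%)**; at the same ~1 epoch FineWeb val goes **3.07→2.98 (↓0.085)**, and v8 (1.05ep) is also below the 3.01 of v7 (1.45ep, more old data). ⇒ **Capacity helps, and v6/v7 were partly capacity-limited (not purely data-saturated)**; scaling capacity is worth more than "feeding a small model more old data". **But the gain is only ~3%**, the same order as a single step on the data axis.
- 🧭 **Meta-conclusion: a single step on any single axis now yields ~3% per lever = diminishing returns across the board; both axes must scale together (a small-scale Chinchilla reproduction)**: taking the three axes together, namely the data-volume axis (v5/v7, more epochs on the same subset, saturated, ~1.6–5% per step), the data-breadth axis (v6's corpus switch, a one-time distribution-change dividend), and the capacity axis (v8, helpful but ~3%), **by v8 a single step on any single axis has converged to ~3% per lever**. And v8's +78% capacity is paired with the same 2.36B tokens while val was still falling at the final step ⇒ data immediately becomes the new bottleneck. ⇒ **To keep improving, capacity and data must scale together in a matched way rather than pushing one axis hard on its own**, which is exactly Chinchilla reproduced at this toy scale.
+- ✅ **Scaling both axes together works (v9)**: v9 implements v8's proposal: model core 226M→357M, data 2.255B subset→6.013B tokens (6.012B actually trained), best FineWeb val **2.9801→2.8854**, a further drop of **0.0947 (~3.2%)**. This confirms that the Chinchilla-style double-axis direction is right; but the gain is still a solid ~3%-level increment and greedy repetition remains, which shows that at small scale a "better val" has not yet fully turned into a visible qualitative change.
+- 📌 **Extending only the data axis has limited marginal return (v10)**: v10 keeps the v9 architecture and only adds shard010 to reach 6.765B tokens. On fixed eval v1, v9 2.9278→v10 2.8814, so the new shard's distribution was learned; but the moving-tail best only moves from 2.8854→2.8816 and greedy repetition is unchanged. The more valuable next step is to change the model/context rather than keep adding data shard by shard.

## 3.5. Phase 2 systems-stack depth synthesis (the five T14–T18 features organized along the four dimensions)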
|
||||
|
||||
|
||||
152
docs/runs/09-v9-fineweb-edu-dim1280-gqa.md
Normal file
@@ -0,0 +1,152 @@
|
||||
# Scaling Run v9: Chinchilla double axis, dim1280/18L true GQA (core 356.9M) + FineWeb-edu 6.01B tokens + Phase-2 stack (Design Document)
|
||||
|
||||
## Goal
|
||||
|
||||
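v8's meta-conclusion was: moving the capacity axis alone helps, but with only about a 3% marginal gain; repeating old data alone gives only about 1.6%. To clearly surpass
v8, **model capacity + new tokens** must be scaled together rather than moving just one axis.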
|
||||
|
||||
v9 is that double-axis point:
|
||||
|
||||
1. **Model axis**: dim1024/core 226M -> **dim1280/core 356.9M**, with true GQA enabled (40 query heads / 10 kv heads).
2. **Data axis**: the 2.255B FineWeb subset of v6-v8 -> **6.013B tokens**, adding new FineWeb-edu shards 003-009.
3. **Systems stack**: the modern Phase-2 path: `--flash + --accum-steps + bf16 + recompute + DDP`. Dropout is set to 0, as in standard pretraining.
|
||||
|
||||
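> v9's val is still on the FineWeb-edu distribution and cannot be compared directly with the TinyStories val of v0-v5. Note: after v9 extends the cache, the default
> tail-heldout has moved from the old tail of v6-v8 to the end of the new shards; strict cross-version comparison later relies on fixed eval v1.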
|
||||
|
||||
## Data
|
||||
|
||||
| Item | Value |
|----|----|
| Source | FineWeb-edu `sample/10BT`, original shards 000-002 + new shards 003-009 |
| token cache | `data/fineweb-edu.txt.u16.bin` |
| Total tokens | **6,013,639,492** |
| held-out val | last **1,000,000** tokens |
| train corpus | 6,012,639,492 tokens |
| Tokens consumed in training | **6,012,600,320** = 91745 steps x effective batch 256 x seq 256 |
| epoch | ~1.00 |
|
||||
|
||||
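The P3-DATA target was originally about 7B tokens; the shard 010 download was interrupted with `curl rc=18`, so the cache ended at 6.01B. For core 356.9M,
D/N is about **16.8 tokens/param**, below the ideal Chinchilla 20 but far above v8's roughly 10.4, making this a clean double-axis scale point.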
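
The token and D/N figures in this section can be reproduced with a short sketch. The function names are illustrative only and are not part of xtrain; the inputs are the step count, batch, sequence length, and core parameter count reported in this document.

```rust
/// Tokens consumed by training: steps x effective batch x sequence length.
fn consumed_tokens(steps: u64, effective_batch: u64, seq_len: u64) -> u64 {
    steps * effective_batch * seq_len
}

/// Data-to-parameter ratio D/N, in tokens per core parameter.
fn tokens_per_param(tokens: u64, core_params: u64) -> f64 {
    tokens as f64 / core_params as f64
}
```

With the v9 values (91745 steps, batch 256, seq 256, core 356.89M) this yields 6,012,600,320 consumed tokens and a ratio of roughly 16.8; with v8's 2.36B trained tokens and 226.5M core parameters the ratio is roughly 10.4.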
|
||||
|
||||
## Architecture
|
||||
|
||||
| 项 | v8 | **v9** |
|
||||
|----|----|----|
|
||||
| dim | 1024 | **1280** |
|
||||
| layers | 18 | 18 |
|
||||
| query heads x head_dim | 32 x 32 | **40 x 32** |
|
||||
| kv heads | 32 (MHA) | **10 (true GQA, group=4)** |
|
||||
| ffn | 2730 | **4096** |
|
||||
| core params | 226.50M | **356.89M** |
|
||||
| total params | 329.42M | **485.55M** |
|
||||
| export tensors | 201 | **201** |
|
||||
|
||||
`config.json` writes real `num_key_value_heads = 10`, so xserv loads v9 as true GQA rather than MHA.
|
||||
|
||||
## Training
|
||||
|
||||
| Item | Value |
|
||||
|----|----|
|
||||
| optimizer | hand-written AdamW, wd=0.1 |
|
||||
| schedule | warmup -> cosine, max_lr 6e-4 -> min_lr 6e-5 |
|
||||
| grad clip | global norm 1.0 |
|
||||
| steps | **91745** |
|
||||
| effective global batch | **256** (`--batch 128 --accum-steps 2`) |
|
||||
| seq_len | 256 |
|
||||
| precision | bf16 mixed precision, fp32 master |
|
||||
| memory stack | activation recompute + flash-attention + gradient accumulation |
|
||||
| world size | 8 x RTX 5090 |
|
||||
| wall clock | **21h15m** |
|
||||
| steady throughput | **~78.6K tok/s** |
|
||||
| peak observed memory | ~17GB / GPU |
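
The `warmup -> cosine` schedule row can be illustrated with a minimal sketch. This is a common formulation, not necessarily the exact one in xtrain, and `warmup_steps` is a hypothetical parameter because the run's warmup length is not stated in this document.

```rust
/// Learning rate at `step`: linear warmup to `max_lr`, then cosine decay to `min_lr`.
/// `warmup_steps` is an assumed parameter; the actual value used by the run is not given here.
fn lr_at(step: u64, total_steps: u64, warmup_steps: u64, max_lr: f64, min_lr: f64) -> f64 {
    if step < warmup_steps {
        // Linear ramp that reaches max_lr on the last warmup step.
        return max_lr * (step + 1) as f64 / warmup_steps as f64;
    }
    let decay_steps = total_steps.saturating_sub(warmup_steps).max(1) as f64;
    let elapsed = step.saturating_sub(warmup_steps) as f64;
    let progress = (elapsed / decay_steps).min(1.0);
    // Cosine factor goes from 1 at the start of decay to 0 at the end.
    let cosine = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
    min_lr + (max_lr - min_lr) * cosine
}
```

For example, `lr_at(91_745, 91_745, 2_000, 6e-4, 6e-5)` evaluates the end of the decay, where the rate has reached `min_lr`.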
|
||||
|
||||
Command:
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 cargo run -p xtrain-distributed --release --bin train_ddp -- \
|
||||
/opt/wjh/models/gpt2/tokenizer.json data/fineweb-edu.txt \
|
||||
--heads 40 --head-dim 32 --kv-heads 10 --layers 18 --ffn 4096 \
|
||||
--steps 91745 --batch 128 --accum-steps 2 --seq 256 \
|
||||
--max-lr 6e-4 --min-lr 6e-5 --val-tokens 1000000 --eval-every 1000 \
|
||||
--eval-batches 64 --bf16 --recompute --flash --dropout 0.0 \
|
||||
--ckpt /dashscope-tmp/wjh/xtrain_v9.ckpt
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
- train loss: **11.1550 -> 2.9340**
|
||||
- first val: step 1000 = **5.1517**
|
||||
- best val: step 91000 = **2.8854**
|
||||
- final val: step 91745 = **2.8873**
|
||||
- exit code: **0**
|
||||
|
||||
FineWeb val curve milestones:
|
||||
|
||||
| step | 1000 | 10000 | 20000 | 30000 | 40000 | 50000 | 60000 | 70000 | 80000 | 90000 | 91000 | final |
|
||||
|------|------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|
|
||||
| val | 5.1517 | 3.4820 | 3.2953 | 3.2026 | 3.1422 | 3.0844 | 3.0148 | 2.9616 | 2.9160 | 2.8915 | **2.8854** | 2.8873 |
|
||||
|
||||
The curve kept improving into the last 1K-step window, then the final eval bounced slightly from 2.8854 to 2.8873. This is close to
|
||||
the floor for this run, but not a clear overfit failure.
|
||||
|
||||
## Comparison
|
||||
|
||||
| | v6 | v7 | v8 | **v9** |
|
||||
|---|---|---|---|---|
|
||||
| model | dim768/core127M | dim768/core127M | dim1024/core226M | **dim1280/core357M + GQA** |
|
||||
| data | 2.29B | 3.28B same subset | 2.36B same subset | **6.01B expanded shards** |
|
||||
| best val | 3.0652 | 3.0149 | 2.9801 | **2.8854** |
|
||||
|
||||
On the run-local moving tail, v9 beats v8 by **0.0947** val loss (~3.2% relative), essentially the same size as the
|
||||
v6->v8 capacity gain but now on top of it. A later fixed eval v1 check still supports the same direction
|
||||
(v8 3.1515 -> v9 2.9278 on shard010-tail holdout), while making the moving-tail caveat explicit. This confirms
|
||||
the v8 prediction: **scaling both axes works**. It is still an incremental gain, not a qualitative jump.
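
The absolute and relative gains quoted above follow from the reported losses. The helper below is an illustrative sketch, not a function from the repository.

```rust
/// Absolute and relative reduction in validation loss from `before` to `after`.
fn val_gain(before: f64, after: f64) -> (f64, f64) {
    let abs = before - after;
    (abs, abs / before)
}
```

For the moving-tail pair (2.9801 -> 2.8854) this gives 0.0947, about 3.2%. Applied to the fixed eval v1 pair (3.1515 -> 2.9278) the same function gives 0.2237, about 7.1%.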
|
||||
|
||||
## Samples
|
||||
|
||||
xserv greedy samples (`--max-tokens 60`) are more coherent than the v8 examples on some prompts, but repetition remains:
|
||||
|
||||
```text
|
||||
[The history of] the United States is the story of the people, the places, and the events that have shaped the nation...
|
||||
[In science,] the term "scientific method" is used to describe the process of gathering information and testing it...
|
||||
[The most important] thing is to be aware of the symptoms and to seek medical attention...
|
||||
[Water is] a natural resource that is essential for human life...
|
||||
```
|
||||
|
||||
The model writes real explanatory English and the domain mix is FineWeb-like. Greedy decoding still falls into repeated clauses on
|
||||
some prompts (`scientific method`, symptoms, and earlier fixed prompts), so the val gain is more visible in the metric than in a
|
||||
dramatic sample-quality leap.
|
||||
|
||||
## xserv validation
|
||||
|
||||
Registry path:
|
||||
|
||||
```text
|
||||
/opt/wjh/projects/tiny-models/v9-fineweb-edu-dim1280-gqa
|
||||
```
|
||||
|
||||
Files:
|
||||
|
||||
- `config.json`
|
||||
- `model.safetensors` (BF16, 201 tensors, 927MB)
|
||||
- `tokenizer.json`
|
||||
- `xtrain.ckpt` (fp32 master checkpoint, 1.9GB)
|
||||
- `RUN.md`
|
||||
|
||||
xserv loads v9 as:
|
||||
|
||||
```text
|
||||
Model: qwen3, layers=18, hidden=1280, heads=40/10 kv, vocab=50257
|
||||
Loaded 201 tensors
|
||||
Ready (KV cache, dtype=bf16).
|
||||
```
|
||||
|
||||
Token-match check against xtrain greedy (`max-tokens 40`):
|
||||
|
||||
- `Once upon a time`: xtrain and xserv matched through the checked continuation.
|
||||
- `One day`: diverged after "large, dark," (`very tall man` vs `metallic object`) from BF16 greedy tie sensitivity.
|
||||
- `The little`: same repetitive pattern, with a short BF16 path divergence.
|
||||
|
||||
This is the same class of BF16-vs-f32 greedy drift seen in v8; the important integration result is that xserv successfully loads
|
||||
true GQA (`kv_heads=10 < heads=40`) and generates from the exported weights.
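
A token-match check of this kind reduces to finding the first position where two greedy token streams differ. The sketch below assumes both continuations are already available as token-id slices; the function name is hypothetical and is not the tool used for the check above.

```rust
/// Index of the first position where two token streams differ,
/// or `None` if they agree over the whole compared prefix.
/// Only the common prefix is compared when the lengths differ.
fn first_divergence(a: &[u32], b: &[u32]) -> Option<usize> {
    a.iter().zip(b.iter()).position(|(x, y)| x != y)
}
```

A `None` result corresponds to the "matched through the checked continuation" case; `Some(i)` gives the token index at which a BF16 tie flips the greedy path.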
|
||||
200
docs/runs/10-v10-fineweb-edu-dim1280-gqa-data6765.md
Normal file
@@ -0,0 +1,200 @@
|
||||
# Scaling Run v10: Data-axis follow-up — dim1280/18L true GQA + FineWeb-edu 6.765B token — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
v9 showed that double-axis scaling (a larger model + more new tokens) works: best val fell from v8's 2.9801 to 2.8854.
But v9 had only 6.013B tokens, a D/N of about 16.8, below the empirical Chinchilla value of 20. v10's goal is narrow:
|
||||
|
||||
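1. **Data axis only**: add the FineWeb-edu shard010 whose download was interrupted in v9, taking the cache from 6.013B to 6.765B.
2. **Architecture unchanged**: reuse v9's dim1280 / 18L / 40q-10kv GQA / ffn4096 exactly.
3. **Measure the marginal gain**: check whether moving D/N from 16.8 to 18.95 can still lower val significantly.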
|
||||
|
||||
## Data
|
||||
|
||||
| Item | Value |
|----|----|
| Source | FineWeb-edu `sample/10BT`, shards 000-010 |
| token cache | `data/fineweb-edu.txt.u16.bin` |
| Total tokens | **6,765,333,808** |
| held-out val | last **1,000,000** tokens |
| train corpus | 6,764,333,808 tokens |
| Tokens consumed in training | **6,764,298,240** = 103215 steps x effective batch 256 x seq 256 |
| epoch | ~1.00 |
|
||||
|
||||
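Important caveat: xtrain's current training entry point uses "the last 1M tokens of the whole cache" as held-out. After appending shard010, v10's val tail
and v9's val tail are no longer the same slice. So the 2.8854 originally reported for v9 and the 2.8816 originally reported for v10 cannot be treated as a strict
comparison on the same validation set.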
|
||||
|
||||
To resolve this, a fixed eval set was created in this round:
|
||||
|
||||
```text
|
||||
/dashscope-tmp/wjh/xtrain_fixed_eval_v1/fineweb-fixed-eval-v1.txt.u16.bin
|
||||
```
|
||||
|
||||
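It contains the last 11M tokens of shard010; the first 10M tokens are there only so that the existing `split_tail(val_tokens=1M)` can be reused, and the part actually
evaluated is the final 1M tokens. This fixed eval v1 is unseen data for v6-v9; for v10 it was also held out during training.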
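
The tail split that the fixed eval set relies on can be sketched as follows. This is a hypothetical stand-in for xtrain's `split_tail`, whose real signature is not shown in this document; it only illustrates why an 11M-token file with `val_tokens = 1M` evaluates the final 1M tokens.

```rust
/// Split a token stream into (train, held-out tail).
/// Hypothetical stand-in for xtrain's `split_tail`; the real signature may differ.
fn split_tail(tokens: &[u16], val_tokens: usize) -> (&[u16], &[u16]) {
    assert!(val_tokens <= tokens.len(), "held-out tail larger than corpus");
    tokens.split_at(tokens.len() - val_tokens)
}
```

Because the split is taken from the end of whatever cache is loaded, appending a shard to the training cache moves the held-out slice, while a separate fixed file keeps it pinned.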
|
||||
|
||||
## Architecture
|
||||
|
||||
v10 is identical to v9:

| Item | Value |
|----|----|
|
||||
| dim | 1280 |
|
||||
| layers | 18 |
|
||||
| query heads x head_dim | 40 x 32 |
|
||||
| kv heads | 10 (true GQA, group=4) |
|
||||
| ffn | 4096 |
|
||||
| core params | 356.89M |
|
||||
| total params | 485.55M |
|
||||
| export tensors | 201 |
|
||||
|
||||
## Training
|
||||
|
||||
| Item | Value |
|
||||
|----|----|
|
||||
| optimizer | hand-written AdamW, wd=0.1 |
|
||||
| schedule | warmup -> cosine, max_lr 6e-4 -> min_lr 6e-5 |
|
||||
| grad clip | global norm 1.0 |
|
||||
| steps | **103215** |
|
||||
| effective global batch | **256** (`--batch 128 --accum-steps 2`) |
|
||||
| seq_len | 256 |
|
||||
| precision | bf16 mixed precision, fp32 master |
|
||||
| memory stack | activation recompute + flash-attention + gradient accumulation |
|
||||
| world size | 8 x RTX 5090 |
|
||||
| wall clock | **23h51m** |
|
||||
| steady throughput | **~79.0K tok/s** |
|
||||
| peak observed memory | ~17GB / GPU |
|
||||
|
||||
Command:
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 cargo run -p xtrain-distributed --release --bin train_ddp -- \
|
||||
/opt/wjh/models/gpt2/tokenizer.json data/fineweb-edu.txt \
|
||||
--heads 40 --head-dim 32 --kv-heads 10 --layers 18 --ffn 4096 \
|
||||
--steps 103215 --batch 128 --accum-steps 2 --seq 256 \
|
||||
--max-lr 6e-4 --min-lr 6e-5 --val-tokens 1000000 --eval-every 1000 \
|
||||
--eval-batches 64 --bf16 --recompute --flash --dropout 0.0 \
|
||||
--ckpt /dashscope-tmp/wjh/xtrain_v10.ckpt
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
- train loss: **11.1575 -> 2.9000**
|
||||
- first val: step 999 = **5.3048**
|
||||
- best val: step 103214 = **2.8816**
|
||||
- final val: step 103214 = **2.8816**
|
||||
- exit code: **0**
|
||||
|
||||
FineWeb moving-tail val milestones:
|
||||
|
||||
| step | 999 | 9999 | 19999 | 29999 | 39999 | 49999 | 59999 | 69999 | 79999 | 89999 | 99999 | final |
|
||||
|------|-----|------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|
|
||||
| val | 5.3048 | 3.5622 | 3.3282 | 3.2450 | 3.1886 | 3.1342 | 3.0714 | 3.0202 | 2.9724 | 2.9236 | 2.8950 | **2.8816** |
|
||||
|
||||
The curve still improved at the final eval. There is no overfit signal in this run.
|
||||
|
||||
## Fixed Eval V1
|
||||
|
||||
Fixed eval v1 (`shard010 tail 1M`, seq256, 64 eval batches):
|
||||
|
||||
| version | fixed eval v1 |
|
||||
|---------|---------------|
|
||||
| v6 | 3.2328 |
|
||||
| v7 | 3.1850 |
|
||||
| v8 | 3.1515 |
|
||||
| v9 | 2.9278 |
|
||||
| **v10** | **2.8814** |
|
||||
|
||||
This is the cleanest cross-version result in the v10 round. It says:
|
||||
|
||||
- v9's double-axis gain transfers to a shard010 holdout: v8 3.1515 -> v9 2.9278.
|
||||
- v10 further improves on the new shard010 distribution: v9 2.9278 -> v10 2.8814.
|
||||
- The apparent v9 moving-tail 2.8854 -> v10 moving-tail 2.8816 delta is tiny and not strict apples-to-apples.
|
||||
|
||||
## Decoding
|
||||
|
||||
Greedy decoding still repeats. Fixed prompts from xtrain:
|
||||
|
||||
```text
|
||||
[Once upon a time] there was a king who had a daughter. She was beautiful and beautiful...
|
||||
[The little] The little boy was a little boy. The little boy was a little boy...
|
||||
[One day] I was walking down the street and I saw a man with a dog...
|
||||
```
|
||||
|
||||
Temperature 0.8 is more varied and less immediately looped, but coherence remains weak:
|
||||
|
||||
```text
|
||||
[Once upon a time] I was a kid who did not go to the beach to swim...
|
||||
[The little] ones are not as loud as the adults...
|
||||
[One day] I was on the edge of the water, and I saw something I had never seen before...
|
||||
```
|
||||
|
||||
xserv loads the exported v10 true-GQA weights and generates FineWeb-like explanatory prose, but repeated sentence frames remain:
|
||||
|
||||
```text
|
||||
[The history of] the city of San Francisco is a story of the growth of the city...
|
||||
[In science,] the term "observation" is used to describe the act of observing something...
|
||||
[Water is] the most important element in the human body...
|
||||
```
|
||||
|
||||
Conclusion: decoding remains a separate bottleneck. The current xtrain sampler only supports greedy and temperature sampling; top-p and
|
||||
repetition penalty exist in xserv's chat path, but not in the raw xtrain sampler or `xserv-cli` path used for weight validation. A clean
|
||||
next step is to add a raw generation tool with `temperature/top-p/repetition-penalty` so decoding experiments do not depend on chat
|
||||
templates.
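
As a starting point for such a raw generation tool, the sketch below combines a repetition penalty with temperature and top-p (nucleus) sampling over one logits vector. It uses standard formulations (a CTRL-style penalty and a nucleus cut) and is not taken from xserv's chat path; the function name and signature are hypothetical, and the uniform sample `u` in `[0, 1)` is passed in so the example stays dependency-free.

```rust
/// Pick the next token id from `logits` with repetition penalty, temperature, and top-p.
/// `history` holds token ids already generated (each must be a valid index into `logits`).
fn sample_top_p(
    logits: &[f32],
    history: &[usize],
    temperature: f32,
    top_p: f32,
    rep_penalty: f32,
    u: f32,
) -> usize {
    assert!(!logits.is_empty() && temperature > 0.0);
    // 1. Repetition penalty: shrink the logit of every token already generated.
    let mut adj = logits.to_vec();
    for &t in history {
        // Computed from the original logit, so repeated history entries do not compound.
        adj[t] = if logits[t] > 0.0 { logits[t] / rep_penalty } else { logits[t] * rep_penalty };
    }
    // 2. Temperature-scaled, numerically stable softmax.
    let max = adj.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<f32> = adj.iter().map(|&l| ((l - max) / temperature).exp()).collect();
    let sum: f32 = probs.iter().sum();
    for p in probs.iter_mut() {
        *p /= sum;
    }
    // 3. Nucleus: smallest set of highest-probability tokens whose mass reaches top_p.
    let mut order: Vec<usize> = (0..probs.len()).collect();
    order.sort_by(|&a, &b| probs[b].total_cmp(&probs[a]));
    let mut kept = 0;
    let mut mass = 0.0f32;
    for &i in &order {
        kept += 1;
        mass += probs[i];
        if mass >= top_p {
            break;
        }
    }
    // 4. Draw from the renormalized nucleus using the uniform sample `u`.
    let mut target = u * mass;
    for &i in &order[..kept] {
        if target < probs[i] {
            return i;
        }
        target -= probs[i];
    }
    order[kept - 1] // guard against rounding at the upper edge
}
```

With `rep_penalty = 1.0` and `top_p = 1.0` this reduces to plain temperature sampling, which makes it easy to compare against the existing sampler.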
|
||||
|
||||
## xserv Validation
|
||||
|
||||
Registry path:
|
||||
|
||||
```text
|
||||
/opt/wjh/projects/tiny-models/v10-fineweb-edu-dim1280-gqa-data6765
|
||||
```
|
||||
|
||||
Files:
|
||||
|
||||
- `config.json`
|
||||
- `model.safetensors` (BF16, 201 tensors, 927MB)
|
||||
- `tokenizer.json`
|
||||
- `xtrain.ckpt` (fp32 master checkpoint, 1.9GB)
|
||||
|
||||
xserv loads v10 as true GQA:
|
||||
|
||||
```text
|
||||
Model: qwen3, layers=18, hidden=1280, heads=40/10 kv, vocab=50257
|
||||
Loaded 201 tensors
|
||||
Ready (KV cache, dtype=bf16).
|
||||
```
|
||||
|
||||
## v11 Feasibility: Bigger Model + Longer Context
|
||||
|
||||
A v11 smoke test prioritized the user's chosen direction: larger model plus longer context.
|
||||
|
||||
Candidate:
|
||||
|
||||
| item | value |
|
||||
|------|-------|
|
||||
| dim / layers | 1536 / 20 |
|
||||
| heads / kv_heads | 48 / 12 |
|
||||
| ffn | 6144 |
|
||||
| core / total params | 684.26M / 838.65M |
|
||||
| stack | bf16 + recompute + flash + accum + 8 GPU DDP |
|
||||
|
||||
Smoke results:
|
||||
|
||||
| seq | batch / accum | effective batch | peak mem | tok/s | result |
|
||||
|-----|---------------|-----------------|----------|-------|--------|
|
||||
| 512 | 64 / 4 | 256 | **30530 MiB** | **44.7K** | 50 steps OK |
|
||||
| 1024 | 32 / 8 | 256 | **30530 MiB** | **31.0K** | 20 steps OK |
|
||||
|
||||
Both fit, but the memory margin is thin on 32GB RTX 5090. Expected one-epoch wall clock on 6.76B tokens:
|
||||
|
||||
- seq512: roughly **42h**
|
||||
- seq1024: roughly **61h**
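
These wall-clock estimates are simply tokens divided by measured throughput. A minimal sketch, with an illustrative function name:

```rust
/// Estimated hours for one pass over `tokens` at a steady `tok_per_sec`.
fn epoch_hours(tokens: f64, tok_per_sec: f64) -> f64 {
    tokens / tok_per_sec / 3600.0
}
```

With 6.76B tokens this gives about 42 hours at 44.7K tok/s and about 61 hours at 31.0K tok/s; eval and checkpoint overhead are not included.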
|
||||
|
||||
Recommendation: make v11 a controlled run, not a blind launch. Use fixed eval v1, keep data fixed, and choose either:
|
||||
|
||||
1. **v11a practical**: dim1536/20L, seq512, batch64/accum4. Faster, still doubles context over v10.
|
||||
2. **v11b long-context**: dim1536/20L, seq1024, batch32/accum8. More aligned with "long context", but ~2.5 days and tight memory.
|
||||
|
||||
For scientific clarity, v11 should not append more data before training; use the current 6.765B train cache while preserving fixed eval v1.
|
||||
251
docs/runs/12-v12-1b-longctx-chat-alpha.md
Normal file
@@ -0,0 +1,251 @@
|
||||
# Scaling Run v12: 1B-class long-context base → chat-alpha-v2 — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
v11 showed that a larger `dim1536/20L` model can train at `seq1024` on dash5, improving the fixed eval data v1 score at long context (`seq1024`) to **2.7467**. It also exposed the current bottleneck: greedy generation still repeats, and broad real-data SFT on top of v11 regressed chat quality despite a lower SFT validation loss.
|
||||
|
||||
v12 therefore separates the next phase into two gates:
|
||||
|
||||
1. **Base gate**: train a stronger English base model around 1B total params with `seq1024`, using the existing FineWeb-edu 6.765B-token cache and fixed eval data v1.
|
||||
2. **Chat gate**: only after the base gate is healthy, run assistant-only English SFT and judge it with fixed prompt generation, not SFT loss alone.
|
||||
|
||||
Success means the model is serviceable enough for a small chat-alpha: stable base loss, lower fixed eval than v11, less repetitive fixed generation, and SFT that improves instruction behavior without destroying arithmetic/refusal/debug prompts.
|
||||
|
||||
## Baseline: What v11 Taught Us
|
||||
|
||||
| item | v11 |
|
||||
|------|-----|
|
||||
| arch | dim1536 / 20L / 48q-12kv GQA / ffn6144 |
|
||||
| params | 684.26M core / 838.65M total |
|
||||
| data | FineWeb-edu 6.765B token, 1 epoch |
|
||||
| context | seq1024 |
|
||||
| throughput | ~30.96K tok/s on 8 x RTX 5090 |
|
||||
| fixed eval data v1, seq1024 | **2.7467** |
|
||||
| issue | greedy repetition remains; direct real SFT regressed generation quality |
|
||||
|
||||
SFT result from v11:
|
||||
|
||||
| model | train result | generation result |
|
||||
|-------|--------------|-------------------|
|
||||
| `v11-chat-alpha-sft-v2-anchor` | synthetic assistant-only anchor | current best narrow chat-alpha |
|
||||
| `v11-chat-alpha-real-sft-v1` | SFT val 1.4272 | bad hallucination, math failure |
|
||||
| `v11-chat-alpha-real-mix-v1` | SFT val 2.0543 | better than direct real-SFT, still worse than anchor |
|
||||
|
||||
Conclusion: SFT data quality matters, but v11's base is still too weak for broad real SFT to become a general chat model.
|
||||
|
||||
## Architecture
|
||||
|
||||
v12 target: slightly above 1B total params while staying close to the proven v11 shape and keeping GQA group size 4.
|
||||
|
||||
| item | value |
|
||||
|------|-------|
|
||||
| dim | **1664** |
|
||||
| layers | **22** |
|
||||
| query heads x head_dim | **52 x 32** |
|
||||
| kv heads | **13** |
|
||||
| GQA group | 4 |
|
||||
| ffn | **6656** |
|
||||
| core params | **883.4M** |
|
||||
| embed + lm_head | **167.3M** |
|
||||
| total params | **1.0506B** |
|
||||
|
||||
Why this shape:
|
||||
|
||||
- It is a controlled step from v11 rather than a new architecture family.
|
||||
- `52/13` preserves true GQA with group 4.
|
||||
- Total params are near the requested 1B target.
|
||||
- `dim1664` is less aggressive than `dim1792/22L` and has a better chance to fit `seq1024` on 32GB 5090s.
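
The parameter counts in the table can be reproduced from the shape alone. The sketch below assumes a layout that is not spelled out in this document: no biases, untied embedding and lm_head, per-layer input and post-attention RMSNorm weights, q/k-norm weights of size `head_dim`, a SwiGLU MLP, and one final RMSNorm. Under those assumptions it matches the reported v12 and v11 figures; the struct and function names are illustrative.

```rust
/// Shape of a Qwen3-style decoder (field names are illustrative).
struct Shape { vocab: u64, dim: u64, layers: u64, heads: u64, kv_heads: u64, head_dim: u64, ffn: u64 }

/// Core (non-embedding) parameters under the assumptions stated above.
fn core_params(s: &Shape) -> u64 {
    let q_dim = s.heads * s.head_dim;
    let kv_dim = s.kv_heads * s.head_dim;
    let attn = s.dim * q_dim        // W_q
        + 2 * s.dim * kv_dim        // W_k and W_v (smaller under GQA)
        + q_dim * s.dim;            // W_o
    let mlp = 3 * s.dim * s.ffn;    // gate, up, down projections
    let norms = 2 * s.dim + 2 * s.head_dim; // two RMSNorms + q_norm/k_norm
    s.layers * (attn + mlp + norms) + s.dim // plus the final RMSNorm
}

/// Total parameters: core plus untied embedding and lm_head.
fn total_params(s: &Shape) -> u64 {
    core_params(s) + 2 * s.vocab * s.dim
}
```

For the v12 shape (dim 1664, 22 layers, 52/13 heads, head_dim 32, ffn 6656, vocab 50257) this gives 883,354,112 core and 1,050,609,408 total parameters, in line with the 883.4M and 1.0506B in the table.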
|
||||
|
||||
## Data
|
||||
|
||||
Base pretraining stays English-oriented and uses the current token cache. Pass the `.txt` stem to xtrain; `Corpus::load_cached` appends `.u16.bin` internally.
|
||||
|
||||
```text
|
||||
/opt/wjh/projects/xtrain/data/fineweb-edu.txt
|
||||
cache = /opt/wjh/projects/xtrain/data/fineweb-edu.txt.u16.bin
|
||||
tokens = 6,765,333,808
|
||||
```
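
The stem-to-cache naming rule and the flat little-endian `u16` layout can be sketched as below. Both helpers are hypothetical illustrations; the actual `Corpus::load_cached` implementation is not shown in this document.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical helper mirroring the naming rule: append `.u16.bin` to the `.txt` stem.
fn cache_path(stem: &Path) -> PathBuf {
    let mut s = stem.as_os_str().to_os_string();
    s.push(".u16.bin");
    PathBuf::from(s)
}

/// Decode a flat little-endian u16 token cache from raw bytes.
fn decode_u16_le(bytes: &[u8]) -> Vec<u16> {
    assert!(bytes.len() % 2 == 0, "corrupt u16 cache (odd byte count)");
    bytes
        .chunks_exact(2)
        .map(|b| u16::from_le_bytes([b[0], b[1]]))
        .collect()
}
```

At two bytes per token, a 6,765,333,808-token cache occupies roughly 13.5 GB on disk.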
|
||||
|
||||
Training uses the last 1M tokens as moving-tail validation. Every cross-version v12 claim must also run fixed eval data v1 with the long-context `seq1024` setting, matching the v11 `eval_v11_seq1024.log` score of **2.7467**. This is distinct from the older v10 table that used the same fixed eval data with `seq256`.
|
||||
|
||||
```text
|
||||
/dashscope-tmp/wjh/xtrain_fixed_eval_v1/fineweb-fixed-eval-v1.txt
|
||||
cache = /dashscope-tmp/wjh/xtrain_fixed_eval_v1/fineweb-fixed-eval-v1.txt.u16.bin
|
||||
```
|
||||
|
||||
No new FineWeb shards are added in this phase. The experiment is model/context scale, not another data-axis change.
|
||||
|
||||
## Training Plan
|
||||
|
||||
Primary v12 run:
|
||||
|
||||
| item | value |
|
||||
|------|-------|
|
||||
| world | 8 x RTX 5090 on dash5 |
|
||||
| precision | bf16 mixed precision, fp32 master |
|
||||
| memory stack | recompute + flash + grad accumulation |
|
||||
| seq | **1024** |
|
||||
| micro global batch | **16** (2 sequences/rank) |
|
||||
| accum | **15** |
|
||||
| effective global batch | **240** |
|
||||
| tokens/step | **245,760** |
|
||||
| full steps | **27,524** |
|
||||
| max_lr → min_lr | **4e-4 → 4e-5** |
|
||||
| eval | moving-tail 1M every 500 steps; fixed eval data v1 at seq1024 after checkpoints |
|
||||
| smoke throughput | **~24.5K tok/s** |
|
||||
| estimated full wall clock | **~76-78h** |
|
||||
|
||||
The reduced micro-batch is intentional: v11 `seq1024` with global batch 32 already sat near the 5090 memory limit. v12 has larger weights; an initial `batch24/accum10` smoke OOMed after step 0, while `batch16/accum15` passed a 10-step smoke at ~29.4GB/GPU and preserved the same 245,760 tokens/step.
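
The batch arithmetic behind this choice is easy to check. The helper names below are illustrative; the train-token count is the 6,764,333,808 figure from the v10 data table.

```rust
/// Tokens per optimizer step: micro global batch x accumulation steps x sequence length.
fn tokens_per_step(micro_global_batch: u64, accum_steps: u64, seq_len: u64) -> u64 {
    micro_global_batch * accum_steps * seq_len
}

/// Whole optimizer steps that fit in one pass over `train_tokens`.
fn full_steps(train_tokens: u64, tokens_per_step: u64) -> u64 {
    train_tokens / tokens_per_step
}
```

Both `batch24/accum10` and `batch16/accum15` give 245,760 tokens per step at `seq1024`, and one pass over the train corpus comes to 27,524 full steps.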
|
||||
|
||||
Command wrapper:
|
||||
|
||||
```sh
|
||||
scripts/run_v12_phase.sh start-pilot
|
||||
scripts/run_v12_phase.sh start-full
|
||||
scripts/run_v12_phase.sh status
|
||||
scripts/run_v12_phase.sh eval-fixed
|
||||
scripts/run_v12_phase.sh export
|
||||
scripts/run_v12_phase.sh sample
|
||||
```
|
||||
|
||||
## Gates
|
||||
|
||||
### Gate 0: build and smoke
|
||||
|
||||
Run:
|
||||
|
||||
```sh
|
||||
scripts/run_v12_phase.sh smoke
|
||||
```
|
||||
|
||||
Pass criteria:
|
||||
|
||||
- no CUDA OOM
|
||||
- no NaN loss
|
||||
- first 30 steps decrease from initialization
|
||||
- peak memory leaves enough margin for eval
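
The loss-related criteria could be automated along these lines. This is one possible reading of "no NaN" and "decrease from initialization" (every recorded loss is finite and the last is below the first); the function name is hypothetical, and the OOM and memory-margin criteria still need separate checks.

```rust
/// Smoke-gate check on recorded training losses: all finite, and the last
/// recorded loss is below the first. One interpretation of the criteria, not the script's logic.
fn smoke_passes(losses: &[f32]) -> bool {
    match (losses.first(), losses.last()) {
        (Some(&first), Some(&last)) if losses.len() >= 2 => {
            losses.iter().all(|l| l.is_finite()) && last < first
        }
        _ => false,
    }
}
```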
|
||||
|
||||
### Gate 1: pilot
|
||||
|
||||
Run:
|
||||
|
||||
```sh
|
||||
scripts/run_v12_phase.sh start-pilot
|
||||
```
|
||||
|
||||
Default pilot is 300 steps with held-out eval every 100 steps.
|
||||
|
||||
Pass criteria:
|
||||
|
||||
- train loss decreases smoothly
|
||||
- grad norm does not spike persistently
|
||||
- moving-tail eval is finite and improving
|
||||
- checkpoint can be reloaded by `eval-fixed`

### Gate 2: full base

Run only after the pilot passes:

```sh
scripts/run_v12_phase.sh start-full
```

Pass criteria:

- fixed eval data v1 at `seq1024` beats v11's **2.7467**
- generation samples improve or at least do not regress on repetition
- checkpoint exports and xserv loads the true GQA config

### Gate 3: chat-alpha SFT

After a healthy v12 base:

1. Use assistant-only SFT (`--sft-tsv`) with English-only data.
2. Start from narrow anchors first, then mix in Smol-SmolTalk.
3. Judge with fixed generation prompts before calling it useful.

The primary high-quality source remains `HuggingFaceTB/smol-smoltalk` filtered to English single-turn examples, with local anchors preserved to keep deterministic behavior.

## Evaluation

Base metrics:

- moving-tail val during training
- fixed eval data v1 at `seq1024`
- xtrain fixed prompt samples from `scripts/chat_alpha_fixed_prompts.txt`
- xserv exported-model smoke

Chat metrics:

- fixed prompt answers for SFT explanation, SFT data provenance, arithmetic, refusal, repetition-debug checklist, summary, and simple code generation
- compare against `v11-chat-alpha-sft-v2-anchor`
- reject models that lower SFT validation loss but hallucinate more in fixed prompts

## Artifacts

Expected paths:

```text
/dashscope-tmp/wjh/xtrain_v12/
/dashscope-tmp/wjh/xtrain_v12/xtrain_v12_pilot.ckpt
/dashscope-tmp/wjh/xtrain_v12/xtrain_v12.ckpt
/opt/wjh/projects/tiny-models/v12-fineweb-edu-1b-longctx
```

## Results

### Gate 0/1: smoke + pilot

- `batch24/accum10` smoke OOMed after step 0.
- `batch16/accum15` smoke passed 10 steps: train loss **11.2347 -> 7.9459**, ~24.5K tok/s, ~29.4GB/GPU.
- 300-step pilot passed: train loss **11.2296 -> 5.4832**, val **6.5810 -> 5.9642 -> 5.5888**, exit code 0.
- Pilot checkpoint reload matched final val: fixed eval data v1 at seq1024 = **5.5891**.
- Fixed chat prompts still repeat heavily, as expected for a 300-step base; use them as a regression baseline, not as chat quality.

### Gate 2: full base

Full run completed on dash5:

| item | result |
|------|--------|
| wall clock | **81h01m** |
| throughput | **~24.55K tok/s** |
| train loss | **11.2294 -> 2.6696** |
| moving-tail best val | **2.7411** |
| moving-tail final val | **2.7412** |
| fixed eval data v1, seq1024 reload | **2.7410** |
| exit code | **0** |

Validation milestones:

| step | 499 | 999 | 1499 | 1999 | 2499 | 21999 | 23999 | 25999 | 26999 | 27499 | final |
|------|-----|-----|------|------|------|-------|-------|-------|-------|-------|-------|
| val | 5.3029 | 4.4079 | 3.9287 | 3.6964 | 3.5555 | 2.7805 | 2.7637 | 2.7468 | 2.7443 | **2.7411** | 2.7412 |

Compared with v11's fixed eval data v1 at seq1024 (**2.7467**), v12 reaches **2.7410** after reload. This is a real but very small gain
(~0.006 absolute), despite the parameter increase from 838.65M to 1.0506B total and the slower 24.55K tok/s throughput. The result says the
larger 1B-class base is viable and marginally better, but this scale step did not produce a qualitative base-model jump.
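
The size of that gain relative to the size of the scale step is easier to see with the arithmetic written out. This is only a worked calculation on the four numbers quoted in the paragraph above; nothing else is assumed.

```python
# Numbers quoted in the comparison above.
v11_val, v12_val = 2.7467, 2.7410          # fixed eval data v1, seq1024
v11_params, v12_params = 838.65e6, 1.0506e9

abs_gain = v11_val - v12_val               # absolute loss reduction
rel_gain = abs_gain / v11_val              # relative to v11's loss
param_growth = v12_params / v11_params - 1

print(f"absolute gain:    {abs_gain:.4f}")      # 0.0057
print(f"relative gain:    {rel_gain:.2%}")      # 0.21%
print(f"parameter growth: {param_growth:.1%}")  # 25.3%
```

Roughly a quarter more parameters bought about a fifth of a percent of fixed-eval loss, which is the sense in which the gain is real but narrow.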

Generation:

- Raw FineWeb-style prompts are better than the pilot checkpoint and can produce plausible explanatory prose.
- Greedy repetition remains visible, especially on story-like prompts.
- Chat prompts are not reliable without SFT: SFT data provenance is hallucinated, arithmetic still fails, and the model repeats template-like text.
- xserv loads the export correctly as true GQA: `layers=22, hidden=1664, heads=52/13 kv`.
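
"Greedy repetition remains visible" is a judgment made by reading samples. One way to put a number next to it is the fraction of repeated n-grams in a completion. The sketch below is an illustration only: `repeated_ngram_fraction` is a hypothetical helper, and it splits on whitespace rather than using the model's tokenizer, so the values are comparable across checkpoints but are not token-level statistics.

```python
def repeated_ngram_fraction(text, n=4):
    """Fraction of n-grams that duplicate an earlier n-gram (0.0 = no repetition)."""
    tokens = text.split()  # whitespace tokens, not model tokens
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    if not ngrams:
        return 0.0
    return 1.0 - len(set(ngrams)) / len(ngrams)


# A looping completion scores high, a non-repeating one scores zero.
print(repeated_ngram_fraction("the cat sat the cat sat the cat sat", n=3))
print(repeated_ngram_fraction("a short sentence with no repeated phrases", n=3))
```

Tracking this value on the fixed prompts for each checkpoint would give the "do not regress on repetition" criterion a concrete baseline.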

Exported model:

```text
/opt/wjh/projects/tiny-models/v12-fineweb-edu-1b-longctx
```

Files:

- `config.json`
- `model.safetensors` (2.0GB)
- `tokenizer.json`
- `xtrain.ckpt` (4.0GB)

Conclusion: v12 passes the base gate and is a better SFT starting point than v11 by metric, but the gain is narrow. The next step should be
assistant-only chat SFT from v12 with conservative anchors first, then a small Smol-SmolTalk mix. Do not expect the base checkpoint itself to
serve as a usable chat model.
180
docs/runs/13-v12-chat-sft-quality.md
Normal file
@@ -0,0 +1,180 @@
# v12 Chat SFT Quality Check

Date: 2026-06-29

## Goal

Turn the completed v12 1.05B base checkpoint into a usable chat-alpha model with
SFT, then judge whether it is stable enough to call a high-quality chat model.

Base checkpoint:

```text
/dashscope-tmp/wjh/xtrain_v12/xtrain_v12.ckpt
```

Architecture:

```text
dim=1664 layers=22 heads=52 kv_heads=13 head_dim=32 ffn=6656
total params=1.0506B
```
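
The architecture line can be sanity-checked with a short parameter count. The sketch below is a back-of-the-envelope estimate under stated assumptions, none of which are confirmed here: a GPT-2 vocabulary of 50257 entries, untied input and output embeddings, no biases, a three-matrix SwiGLU MLP, and Qwen3-style per-head Q/K norms. `gqa_param_count` is a hypothetical helper, not part of xtrain.

```python
def gqa_param_count(dim, layers, heads, kv_heads, head_dim, ffn,
                    vocab=50257, tied_embeddings=False, qk_norm=True):
    """Rough parameter count for a GQA transformer (all layout choices are assumptions)."""
    assert heads * head_dim == dim, "query heads must tile the model dimension"
    assert heads % kv_heads == 0, "each KV head must serve a whole number of query heads"
    kv_dim = kv_heads * head_dim
    attn = 2 * dim * dim + 2 * dim * kv_dim   # Wq and Wo, plus the narrower Wk and Wv
    mlp = 3 * dim * ffn                       # gate, up, down projections
    norms = 2 * dim + (2 * head_dim if qk_norm else 0)
    embed = vocab * dim * (1 if tied_embeddings else 2)
    return layers * (attn + mlp + norms) + embed + dim  # + final norm


total = gqa_param_count(dim=1664, layers=22, heads=52, kv_heads=13, head_dim=32, ffn=6656)
print(f"{total / 1e9:.4f}B total, {52 // 13} query heads per KV head")
```

Under these assumptions the estimate rounds to the 1.0506B quoted above, and each of the 13 KV heads is shared by 4 query heads. Agreement with the reported total is suggestive, not proof, that the real layout matches the assumed one.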

## Stage A: Synthetic SFT

Data:

```text
/dashscope-tmp/wjh/xtrain_sft_alpha_v2/chat_alpha_v2_sft.tsv
211,257 examples, about 14.96M SFT tokens
```

Run:

```text
/dashscope-tmp/wjh/xtrain_sft_v12_alpha_v2/chat_alpha_v12_v2.ckpt
```

Metrics:

```text
train loss: 3.5730 -> 0.0426
eval: step39 0.1078, step79 0.0582, step119 0.0466, step159 0.0423,
      step199 0.0403, step239 0.0390, step279 0.0389, step319 0.0378
best/final val loss: 0.0378
```

Export:

```text
/opt/wjh/projects/tiny-models/v12-chat-alpha-sft-v2
```

Quality notes:

- Learns the User/Assistant format and usually stops correctly.
- Too narrow and template-heavy.
- Fails basic math and code prompts in fixed greedy evaluation.

## Stage B: Anchor SFT

Data:

```text
/dashscope-tmp/wjh/xtrain_sft_alpha_v2_anchor/chat_alpha_v2_anchor.tsv
32,020 examples, about 1.73M SFT tokens
```

Run:

```text
/dashscope-tmp/wjh/xtrain_sft_v12_anchor/chat_alpha_v12_anchor.ckpt
```

Metrics:

```text
train loss: 1.7777 -> 0.1165
eval: step19 0.3447, step39 0.1449, step59 0.1217, step79 0.1158
best/final val loss: 0.1158
```

Export:

```text
/opt/wjh/projects/tiny-models/v12-chat-alpha-sft-v2-anchor
```

Generation artifacts:

```text
/dashscope-tmp/wjh/xtrain_sft_v12_anchor/generation/anchor_xserv_greedy.txt
/dashscope-tmp/wjh/xtrain_sft_v12_anchor/generation/anchor_diagnostic_greedy.txt
```

Quality notes:

- Better project-context answers and summaries than synthetic-only.
- Still unreliable on basic multiplication, yes/no facts, translation, and code.
- Overuses "cannot verify" style answers outside appropriate uncertainty cases.

## Stage C: Real-Mix Repair

Data:

```text
/dashscope-tmp/wjh/xtrain_sft_real_mix_v1/smol_smoltalk_real_mix.tsv
96,287 examples, about 25.3M SFT tokens
```

Run:

```text
/dashscope-tmp/wjh/xtrain_sft_v12_real_mix_repair/chat_alpha_v12_real_mix_repair.ckpt
```

Training setup:

```text
init=/dashscope-tmp/wjh/xtrain_sft_v12_anchor/chat_alpha_v12_anchor.ckpt
steps=200
seq=512
batch=32
accum=8
effective batch=256
lr=1e-6 -> 2e-7
```
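
The setup implies a token budget that can be compared against the dataset size. The calculation below assumes that `batch * accum * seq` is the global number of token positions per optimizer step (the convention used for `tokens_per_step` in `scripts/run_v12_phase.sh`) and that every position counts. With assistant-only SFT and padding, the number of supervised tokens is lower, so treat the result as an upper bound.

```python
# Values from the training setup and the data description above.
steps, seq, batch, accum = 200, 512, 32, 8
dataset_tokens = 25.3e6                      # "about 25.3M SFT tokens"

effective_batch = batch * accum              # sequences per optimizer step
positions_per_step = effective_batch * seq   # token positions per optimizer step
total_positions = steps * positions_per_step

print(effective_batch)                               # 256
print(positions_per_step)                            # 131072
print(total_positions)                               # 26214400
print(f"{total_positions / dataset_tokens:.2f}")     # 1.04
```

On these assumptions the 200-step repair run covers at most about one pass over the real-mix data.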

Metrics:

```text
train loss: 2.7391 -> 2.0384
eval: step49 2.1964, step99 2.0383, step149 1.9801, step199 1.9570
best/final val loss: 1.9570
```

Export:

```text
/opt/wjh/projects/tiny-models/v12-chat-alpha-real-mix-repair
```

Generation artifacts:

```text
/dashscope-tmp/wjh/xtrain_sft_v12_real_mix_repair/generation/real_mix_repair_xserv_greedy.txt
/dashscope-tmp/wjh/xtrain_sft_v12_real_mix_repair/generation/real_mix_repair_diagnostic_greedy.txt
/dashscope-tmp/wjh/xtrain_sft_v12_real_mix_repair/generation/real_mix_repair_diagnostic_greedy_reppenalty1.txt
```

Quality notes:

- Loss improved cleanly and the model kept chat formatting.
- Fixed prompt math `17% of 240` improved in the standard suite.
- General diagnostic math still fails, e.g. `12 * 13`.
- Code generation remains unusable for simple Python function prompts.
- Some outputs contain corrupted or off-topic fragments.
- Reducing repeat penalty from 1.15 to 1.0 did not fix the failures.
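
For readers unfamiliar with the knob in the last bullet, the sketch below shows one widely used formulation of a repetition penalty (the CTRL-style rule: shrink the logit of every token that has already been generated). Whether xserv or `greedy_sample` implements exactly this rule is an assumption, and `apply_repetition_penalty` is a hypothetical name.

```python
def apply_repetition_penalty(logits, generated_ids, penalty=1.15):
    """Return a copy of logits with already-generated tokens made less likely.

    Positive logits are divided by the penalty and negative ones multiplied,
    so the adjustment always lowers the token's score when penalty > 1.
    """
    out = list(logits)
    for tok in set(generated_ids):
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out


logits = [2.3, -1.0, 0.5]
print(apply_repetition_penalty(logits, generated_ids=[0, 1], penalty=1.15))
print(apply_repetition_penalty(logits, generated_ids=[0, 1], penalty=1.0))
```

With `penalty=1.0` the function is the identity, so in this formulation the 1.0 run is plain greedy decoding. Failures that persist there cannot be attributed to the penalty.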

## Verdict

The SFT pipeline works, and v12 can be turned into a chat-shaped model that follows
the prompt format and stops correctly. However, none of the three SFT variants is a
stable high-quality chat model yet.

The limiting issue is no longer infrastructure. It is data and objective quality:
the current synthetic/anchor data is too narrow, while the current real-mix data
adds breadth but also noisy or low-quality behavior. Validation loss alone is not a
sufficient selection signal for chat quality.

## Recommended Next Step

Build a smaller, higher-precision SFT curriculum before another large run:

1. Keep the anchor data, but reduce over-refusal templates.
2. Add verified small instruction sets for math, code, translation, summarization,
   and closed-book common facts.
3. Add an automatic fixed-prompt eval harness that scores exact-match math, simple
   code syntax, refusal appropriateness, stop-token behavior, and corruption.
4. Train a short curriculum from the v12 base or v12 anchor checkpoint, then pick
   by generation eval rather than SFT loss alone.
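
The harness in step 3 could start as small as the sketch below. It is a starting point under assumptions, not the planned implementation: every function name is hypothetical, `generate` stands in for whatever wraps `greedy_sample` or xserv, and refusal appropriateness is left out because it needs labeled expectations per prompt.

```python
import ast
import re

FENCE = "`" * 3  # built indirectly so the sketch can sit inside a fenced block


def score_math(answer, expected):
    """Exact match on the last number that appears in the answer."""
    nums = re.findall(r"-?\d+(?:\.\d+)?", answer.replace(",", ""))
    return bool(nums) and float(nums[-1]) == float(expected)


def score_python_syntax(answer):
    """True if the first fenced block (or the whole answer) parses as Python."""
    m = re.search(FENCE + r"(?:python)?\n(.*?)" + FENCE, answer, re.S)
    code = (m.group(1) if m else answer).strip()
    if not code:
        return False
    try:
        ast.parse(code)
    except (SyntaxError, ValueError):
        return False
    return True


def score_stop(answer):
    """The completion should end its turn instead of opening a new one."""
    return "User:" not in answer


def score_not_corrupted(answer):
    """Crude check: no U+FFFD replacement characters or stray control characters."""
    return "\ufffd" not in answer and all(c.isprintable() or c in "\n\t" for c in answer)


def run_suite(cases, generate):
    """cases: list of (name, prompt, scorer); generate: prompt -> completion text."""
    return {name: scorer(generate(prompt)) for name, prompt, scorer in cases}


# Canned completions standing in for a real model call.
canned = {
    "User: What is 17% of 240?\nAssistant:": " 17% of 240 is 40.8.",
    "User: What is 12 * 13?\nAssistant:": " The answer is 146.",
    "User: Write a Python function that returns the larger of two numbers.\nAssistant:":
        "def larger(a, b):\n    return a if a > b else b",
}
cases = [
    ("pct", "User: What is 17% of 240?\nAssistant:", lambda a: score_math(a, 40.8)),
    ("mul", "User: What is 12 * 13?\nAssistant:", lambda a: score_math(a, 156)),
    ("code", "User: Write a Python function that returns the larger of two numbers.\nAssistant:",
     score_python_syntax),
]
print(run_suite(cases, canned.__getitem__))
```

The syntax check only proves that the text parses; a bare word also parses as an expression, so a stricter version would require a `def` node or execute the function against test inputs.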
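
@@ -22,6 +22,10 @@ The val loss column gives each version's **best val as reported by its own training run** (held-out
**v8 switched to testing the capacity axis**: same subset as v6/v7, changing only dim768→dim1024 (core 127M→226M); FineWeb val 3.07/3.01→**2.98** ⇒
**capacity helps** (v6/v7 were partly capacity-limited); but the gain is only ~3% and val was still falling at the last step, not saturated ⇒ **by v8, the single-step leverage of both the data axis and the capacity axis
has converged to ~3%/lever = diminishing returns across the board, so both axes must scale together** (Chinchilla; see [08-v8](08-v8-fineweb-edu-dim1024.md)).
**v9 delivers the two-axis scale-up**: dim1024→dim1280 (core 226M→357M), with FineWeb tokens expanded from the 2.255B subset to 6.013B;
best val **2.8854**, a further 0.0947 (~3.2%) below v8. Conclusion: two-axis scaling works, but it is still a solid incremental gain rather than a qualitative change.
**v10 extends only the data axis**: same architecture as v9, adding only shard010 to reach 6.765B tokens; moving-tail best/final val **2.8816**.
Note that appending a shard moves the held-out tail; on fixed eval v1, v6→v10 score **3.2328 / 3.1850 / 3.1515 / 2.9278 / 2.8814**.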
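
Because the moving tail shifts whenever a shard is appended, the fixed eval v1 series is the one to use for step-to-step comparisons. The snippet below only restates the five fixed-eval values listed above as per-step drops; no other data is assumed.

```python
# Fixed eval v1 losses for v6 through v10, as listed above.
fixed_eval_v1 = {"v6": 3.2328, "v7": 3.1850, "v8": 3.1515, "v9": 2.9278, "v10": 2.8814}

names = list(fixed_eval_v1)
drops = {}
for prev, cur in zip(names, names[1:]):
    drop = fixed_eval_v1[prev] - fixed_eval_v1[cur]
    drops[f"{prev}->{cur}"] = (drop, drop / fixed_eval_v1[prev])
    print(f"{prev}->{cur}: -{drop:.4f} ({drop / fixed_eval_v1[prev]:.1%})")
```

On the fixed set the v8 to v9 step is by far the largest drop, larger than the ~3.2% seen on the moving tail, which is consistent with the warning that moving-tail numbers are not directly comparable across versions.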
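
⚠️ **From v6 on, the held-out set (corpus) changed**: the val for v0–v5 is always the **TinyStories** 1M held-out set (mutually comparable); v6 switched to pure
**FineWeb-edu** (real web text), whose val (3.07) is **a different distribution measured with a different ruler** and **cannot** be compared with v0–v5's
@@ -39,18 +43,15 @@ The val loss column gives each version's **best val as reported by its own training run** (held-out
| [v6-fineweb-edu-dim768](06-v6-fineweb-edu-dim768.md) | **FineWeb-edu** real web text (2.255B corpus) | ~2.29B | ~1.02 | 768 / 18 / 24·32 / 2048 (**same as v4/v5**) | 127.43M | 204.63M | **3.0652** ⚠️*(FineWeb val, not comparable with the rows above)* | **First version off TinyStories**; the only variable is the data source, plus 8-GPU DDP bf16; ~1.9h on 8 GPUs at ~218K tok/s. **Val is on a different distribution** (real web text has higher entropy, so ~3.0 is expected, not a regression); judged by sample quality + transfer. FineWeb val was still falling monotonically at the last step = not saturated; **transfer**: v6→TinyStories val **2.75** (v5 native 1.11), so purely general data has a cost on a narrow distribution. Samples: v6 writes real expository text, whereas v5 always falls into short stories |
| [v7-fineweb-edu-dim768](07-v7-fineweb-edu-dim768.md) | **Same 2.255B FineWeb-edu subset as v6** (no new data) | ~3.28B | ~1.45 | 768 / 18 / 24·32 / 2048 (**same as v4/v5/v6**) | 127.43M | 204.63M | **3.0149** *(FineWeb val, comparable with v6)* | **The only variable is the epoch count** (1.02→1.45), plus 8-GPU DDP bf16; ~4.2h on 8 GPUs at ~218K tok/s. ⚠️**Key finding: more epochs on the same subset are near the ceiling**: feeding ~1B more tokens lowers val by only 0.05 (3.07→3.01), it plateaus after ~step44000, and samples show no qualitative change. Truly "more data" requires **new FineWeb shards** (more diverse tokens), not repeating the same subset. This is the same kind of effect as v5's TinyStories data-volume saturation (repeating old data has thin marginal returns); v6's corpus switch is the axis that actually raised the ceiling |
| [v8-fineweb-edu-dim1024](08-v8-fineweb-edu-dim1024.md) | **Same 2.255B FineWeb-edu subset as v6/v7** (no new data) | ~2.36B | ~1.05 | **1024 / 18 / 32·32 / 2730** | **226.50M** | **329.42M** | **2.9801** *(FineWeb val, comparable with v6/v7)* | **The only variable is model capacity** (dim768→dim1024, core 127M→226M, +78%), plus bf16 + **activation recomputation (T13)** to fit dim1024; ~5h on 8 GPUs at ~129K tok/s (the recompute tax). ⭐**Core A/B: capacity helps**: at the same ~1 epoch, v6 3.07→v8 **2.98** (↓0.085), and v8 (1.05 epochs) < v7 (1.45 epochs, more repeated data) 3.01 ⇒ adding capacity > repeating old data ⇒ v6/v7 were partly capacity-limited. ⚠️But the gain is only ~3% (same order as a single step on the data axis), and val was **still falling, not saturated** at the last step. **Meta-conclusion: a single step on either single axis (data/capacity) is now ~3%/lever = diminishing returns across the board, so both axes must scale together (Chinchilla)** |
| [v9-fineweb-edu-dim1280-gqa](09-v9-fineweb-edu-dim1280-gqa.md) | **FineWeb-edu extended shards 000-009** (6.013B tokens) | **~6.01B** | ~1.00 | **1280 / 18 / 40·32 / 4096, kv=10 GQA** | **356.89M** | **485.55M** | **2.8854** *(moving-tail FineWeb val)* | **Chinchilla two-axis**: dim1024→1280 + true GQA + new FineWeb tokens, Phase-2 stack (`--flash`+accum+bf16+recompute+DDP), 21.25h on 8 GPUs at ~78.6K tok/s. Moving-tail val drops a further **0.0947 (~3.2%)** versus v8, validating that two-axis scaling works; greedy samples read more like real expository text but still repeat, so the gain shows up mainly in val rather than as a qualitative change |
| [v10-fineweb-edu-dim1280-gqa-data6765](10-v10-fineweb-edu-dim1280-gqa-data6765.md) | **FineWeb-edu extended shards 000-010** (6.765B tokens) | **~6.76B** | ~1.00 | **same as v9** | **356.89M** | **485.55M** | **2.8816** *(moving-tail FineWeb val)* | **Data axis only**: same architecture trained from scratch, 23.86h on 8 GPUs at ~79.0K tok/s. Moving-tail is only 0.0038 below v9 and should not be over-read; on fixed eval v1, v9 **2.9278**→v10 **2.8814**, showing that adding shard010 helps on the new distribution. Greedy repetition is unresolved |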
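
## Next Tier (Proposal)

- **v9** (direction undecided): by v8, the three single axes, **data volume (saturated in v5/v7) / data breadth (v6's one-time dividend) / capacity (v8: useful but ~3%)**,
  have all been tested, and **the single-step leverage on each has converged to ~3%/lever = diminishing returns across the board**. The Chinchilla lesson reproduces at small scale: v8 added +78% capacity but was paired with
  the same 2.36B tokens, and val was still falling at the last step ⇒ data immediately becomes the new bottleneck ⇒ **capacity and data must scale together in a matched way**. Options for v9:
  **1. Scale both axes together (most Chinchilla-like: a larger model + new FineWeb shards; real scaling but a large investment)**;
  **2. Feed dim1024 more data (cheapest: v8 is at only 1.05 epochs and unsaturated; continue training to 2–3 epochs or add new shards to test directly whether capacity is being held back by data)**;
  **3. Wrap up naturally (8 versions + a from-scratch full stack + a complete three-axis analysis + the Chinchilla diminishing-returns meta-conclusion; the learning track has already told the whole story)**.
  See the "v9 proposal" at the end of [08-v8](08-v8-fineweb-edu-dim1024.md).
- **v11**: prioritize **a larger model + longer context** over continuing to add data only. Smoke tests already verified that dim1536/20L/48q/12kv/ffn6144
  runs at seq512 and seq1024, but peak memory is about 30.5GiB, close to the 32GB limit of the 5090. Recommendation: do v11a first (seq512, about 42h),
  or do v11b (seq1024, about 61h) after explicitly accepting a 2.5-day budget. v11 must use fixed eval v1 so that the moving tail stops contaminating cross-version comparisons.

> **Proposal at v7 time (since fulfilled by v8, archived)**: v7 made "new FineWeb shards" the first choice and listed "a larger model (dim1024+, the capacity axis,
> which first requires T13 activation recomputation)" as still to be tested. **v8 took the capacity axis** and showed it helps (but only ~3%), turning "is it capacity-limited?" from
> an open question into the conclusion "partly".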
</content>
10
scripts/chat_alpha_fixed_prompts.txt
Normal file
@@ -0,0 +1,10 @@
# One escaped prompt per line. `greedy_sample` decodes literal \n before tokenizing.
User: Explain supervised fine-tuning to a junior engineer.\nAssistant:
User: What high-quality SFT data are we using now?\nAssistant:
User: What training data did chat-alpha-v1 use?\nAssistant:
User: What is 17% of 240?\nAssistant:
User: I found that my small language model repeats the same phrase during generation. What should I inspect first?\nAssistant:
User: Summarize this passage in one sentence: A team trained a base model, then continued with chat examples at a low learning rate. Validation loss improved, but they still need real prompt tests before calling it useful.\nAssistant:
User: Who will win the world championship in 2099?\nAssistant:
User: Give a compact checklist before launching an SFT run.\nAssistant:
User: Write a Python function that returns the larger of two numbers.\nAssistant:
329
scripts/run_v12_phase.sh
Executable file
@@ -0,0 +1,329 @@
#!/usr/bin/env bash
set -euo pipefail

ROOT="${XTRAIN_ROOT:-$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)}"
cd "$ROOT"

export PATH="/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH"

strip_token_cache_suffix() {
  local path="$1"
  if [[ "$path" == *.u16.bin ]]; then
    printf '%s\n' "${path%.u16.bin}"
  else
    printf '%s\n' "$path"
  fi
}

RUN_DIR="${RUN_DIR:-/dashscope-tmp/wjh/xtrain_v12}"
TOKENIZER="${TOKENIZER:-/opt/wjh/models/gpt2/tokenizer.json}"
CORPUS="${CORPUS:-data/fineweb-edu.txt}"
FIXED_EVAL="${FIXED_EVAL:-/dashscope-tmp/wjh/xtrain_fixed_eval_v1/fineweb-fixed-eval-v1.txt}"
EXPORT_DIR="${EXPORT_DIR:-/opt/wjh/projects/tiny-models/v12-fineweb-edu-1b-longctx}"
CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
TMUX_SESSION="${TMUX_SESSION:-xtrain_v12}"

HEADS="${HEADS:-52}"
HEAD_DIM="${HEAD_DIM:-32}"
KV_HEADS="${KV_HEADS:-13}"
LAYERS="${LAYERS:-22}"
FFN="${FFN:-6656}"
SEQ="${SEQ:-1024}"
BATCH="${BATCH:-16}"
ACCUM="${ACCUM:-15}"
MAX_LR="${MAX_LR:-4e-4}"
MIN_LR="${MIN_LR:-4e-5}"
VAL_TOKENS="${VAL_TOKENS:-1000000}"
EVAL_BATCHES="${EVAL_BATCHES:-64}"
FIXED_EVAL_SEQ="${FIXED_EVAL_SEQ:-1024}"
FIXED_EVAL_BATCHES="${FIXED_EVAL_BATCHES:-64}"
PILOT_STEPS="${PILOT_STEPS:-300}"
FULL_STEPS="${FULL_STEPS:-27524}"
PILOT_EVAL_EVERY="${PILOT_EVAL_EVERY:-100}"
FULL_EVAL_EVERY="${FULL_EVAL_EVERY:-500}"

CORPUS="$(strip_token_cache_suffix "$CORPUS")"
FIXED_EVAL="$(strip_token_cache_suffix "$FIXED_EVAL")"

ARCH_ARGS=(
  --heads "$HEADS"
  --head-dim "$HEAD_DIM"
  --kv-heads "$KV_HEADS"
  --layers "$LAYERS"
  --ffn "$FFN"
)

usage() {
  cat <<'EOF'
usage: scripts/run_v12_phase.sh ACTION

Actions:
  build         Build xtrain train/export/sample binaries.
  smoke         Run a short no-checkpoint v12 seq1024 smoke test in foreground.
  pilot         Run a 300-step v12 pilot with held-out eval and checkpoint.
  full          Run the full one-epoch v12 base training job.
  eval-fixed    Evaluate a checkpoint on fixed eval v1.
  sample        Run xtrain greedy_sample on fixed chat-alpha prompts.
  export        Export a checkpoint to xserv/tiny-models format.
  status        Print one progress snapshot from RUN_DIR/full.log or pilot.log.
  monitor       Show a refreshing progress dashboard until interrupted.
  start-pilot   Start pilot + monitor in tmux sessions.
  start-full    Start full train + monitor in tmux sessions.

Environment overrides:
  RUN_DIR, TOKENIZER, CORPUS, FIXED_EVAL, EXPORT_DIR, CUDA_VISIBLE_DEVICES
  HEADS, HEAD_DIM, KV_HEADS, LAYERS, FFN, SEQ, BATCH, ACCUM
  MAX_LR, MIN_LR, PILOT_STEPS, FULL_STEPS, FIXED_EVAL_SEQ
EOF
}

build() {
  cargo build --release -p xtrain-distributed --bin train_ddp
  cargo build --release -p xtrain-train --bin train --bin export_safetensors --bin greedy_sample
}

write_meta() {
  local kind="$1"
  mkdir -p "$RUN_DIR"
  {
    echo "run=$kind"
    echo "created_utc=$(date -u '+%Y-%m-%dT%H:%M:%SZ')"
    echo "arch=heads${HEADS}_hd${HEAD_DIM}_kv${KV_HEADS}_layers${LAYERS}_ffn${FFN}"
    echo "seq=$SEQ"
    echo "batch=$BATCH"
    echo "accum=$ACCUM"
    echo "effective_batch=$((BATCH * ACCUM))"
    echo "tokens_per_step=$((BATCH * ACCUM * SEQ))"
    echo "max_lr=$MAX_LR"
    echo "min_lr=$MIN_LR"
    echo "corpus=$CORPUS"
    echo "fixed_eval=$FIXED_EVAL"
    echo "fixed_eval_seq=$FIXED_EVAL_SEQ"
  } > "$RUN_DIR/META.txt"
}

write_env_file() {
  mkdir -p "$RUN_DIR"
  local env_file="$RUN_DIR/env.sh"
  : > "$env_file"
  local names=(
    XTRAIN_ROOT RUN_DIR TOKENIZER CORPUS FIXED_EVAL EXPORT_DIR CUDA_VISIBLE_DEVICES
    TMUX_SESSION HEADS HEAD_DIM KV_HEADS LAYERS FFN SEQ BATCH ACCUM MAX_LR MIN_LR
    VAL_TOKENS EVAL_BATCHES FIXED_EVAL_SEQ FIXED_EVAL_BATCHES PILOT_STEPS
    FULL_STEPS PILOT_EVAL_EVERY FULL_EVAL_EVERY
  )
  for name in "${names[@]}"; do
    if [[ "$name" == "XTRAIN_ROOT" ]]; then
      printf 'export XTRAIN_ROOT=%q\n' "$ROOT" >> "$env_file"
    else
      printf 'export %s=%q\n' "$name" "${!name}" >> "$env_file"
    fi
  done
}

run_train() {
  local kind="$1"
  local steps="$2"
  local eval_every="$3"
  local ckpt="$4"
  local log="$RUN_DIR/${kind}.log"
  write_meta "$kind"
  echo "$steps" > "$RUN_DIR/${kind}.steps"
  echo "$((BATCH * ACCUM * SEQ))" > "$RUN_DIR/${kind}.tokens_per_step"
  {
    echo "RUN_NAME=xtrain_v12_${kind}"
    echo "RUN_START_ISO=$(date -u '+%Y-%m-%dT%H:%M:%SZ')"
    echo "RUN_START_EPOCH=$(date +%s)"
    echo "CKPT=$ckpt"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
    echo "TOTAL_STEPS=$steps"
    echo "TOKENS_PER_STEP=$((BATCH * ACCUM * SEQ))"
    set -x
    set +e
    if [[ -n "$ckpt" ]]; then
      CUDA_VISIBLE_DEVICES="$CUDA_VISIBLE_DEVICES" target/release/train_ddp \
        "$TOKENIZER" "$CORPUS" \
        "${ARCH_ARGS[@]}" \
        --steps "$steps" --batch "$BATCH" --accum-steps "$ACCUM" --seq "$SEQ" \
        --max-lr "$MAX_LR" --min-lr "$MIN_LR" \
        --val-tokens "$VAL_TOKENS" --eval-every "$eval_every" --eval-batches "$EVAL_BATCHES" \
        --bf16 --recompute --flash --dropout 0.0 \
        --ckpt "$ckpt"
      rc=$?
    else
      CUDA_VISIBLE_DEVICES="$CUDA_VISIBLE_DEVICES" target/release/train_ddp \
        "$TOKENIZER" "$CORPUS" \
        "${ARCH_ARGS[@]}" \
        --steps "$steps" --batch "$BATCH" --accum-steps "$ACCUM" --seq "$SEQ" \
        --max-lr "$MAX_LR" --min-lr "$MIN_LR" \
        --val-tokens 0 --eval-every 0 --eval-batches "$EVAL_BATCHES" \
        --bf16 --recompute --flash --dropout 0.0
      rc=$?
    fi
    set -e
    set +x
    echo "RUN_END_ISO=$(date -u '+%Y-%m-%dT%H:%M:%SZ')"
    echo "RUN_EXIT_CODE=$rc"
    exit "$rc"
  } 2>&1 | tee "$log"
}

checkpoint_path() {
  local preferred="$RUN_DIR/xtrain_v12.ckpt"
  local pilot="$RUN_DIR/xtrain_v12_pilot.ckpt"
  if [[ -n "${CKPT:-}" ]]; then
    echo "$CKPT"
  elif [[ -f "$preferred" ]]; then
    echo "$preferred"
  else
    echo "$pilot"
  fi
}

eval_fixed() {
  local ckpt
  ckpt="$(checkpoint_path)"
  target/release/train \
    "$TOKENIZER" "$FIXED_EVAL" \
    "${ARCH_ARGS[@]}" \
    --seq "$FIXED_EVAL_SEQ" --batch 1 --steps 1 \
    --val-tokens "$VAL_TOKENS" --eval-batches "$FIXED_EVAL_BATCHES" \
    --bf16 --recompute --flash \
    --eval-ckpt "$ckpt" \
    2>&1 | tee "$RUN_DIR/eval_fixed.log"
}

sample_fixed() {
  local ckpt
  ckpt="$(checkpoint_path)"
  target/release/greedy_sample \
    "$ckpt" "$TOKENIZER" \
    "${ARCH_ARGS[@]}" \
    --max-tokens "${MAX_TOKENS:-120}" \
    --temperature "${TEMPERATURE:-0}" \
    --prompts-file "${PROMPTS_FILE:-scripts/chat_alpha_fixed_prompts.txt}" \
    2>&1 | tee "$RUN_DIR/sample_fixed.log"
}

export_model() {
  local ckpt
  ckpt="$(checkpoint_path)"
  rm -rf "$EXPORT_DIR"
  target/release/export_safetensors \
    "$ckpt" "$TOKENIZER" "$EXPORT_DIR" \
    "${ARCH_ARGS[@]}"
  cp "$ckpt" "$EXPORT_DIR/xtrain.ckpt"
  echo "$EXPORT_DIR" | tee "$RUN_DIR/export_path.txt"
}

progress_once() {
  local log="${1:-$RUN_DIR/full.log}"
  [[ -f "$log" ]] || log="$RUN_DIR/pilot.log"
  python3 - "$log" <<'PY'
import os, re, sys, time
log = sys.argv[1]
text = open(log, errors="ignore").read() if os.path.exists(log) else ""
steps = re.findall(r"\[rank0\] step\s+(\d+)/(\d+): loss\s+(\S+) lr\s+(\S+) gnorm\s+(\S+) \((\S+) tok/s global", text)
evals = re.findall(r"eval @ step\s+(\d+): val loss\s+(\S+)( \(best\))?", text)
start = re.search(r"RUN_START_EPOCH=(\d+)", text)
tokens_per_step = re.search(r"TOKENS_PER_STEP=(\d+)", text)
tokens_per_step = int(tokens_per_step.group(1)) if tokens_per_step else 245760
exit_code = re.search(r"RUN_EXIT_CODE=(\d+)", text)
warnings = re.findall(r"(?i)(nan|inf|oom|out of memory|panic|error)", text)
print("xtrain v12 |", time.strftime("%Y-%m-%d %H:%M:%S %Z"), "| log:", log)
if warnings:
  print("WARNING: suspicious log tokens:", ", ".join(sorted(set(w.lower() for w in warnings))[:8]))
if not steps:
  print("waiting for first rank0 step")
else:
  s, total, loss, lr, gnorm, tps = steps[-1]
  done = int(s) + 1
  total = int(total)
  pct = min(100.0, done * 100.0 / total)
  width = 44
  fill = int(width * pct / 100.0)
  bar = "#" * fill + "." * (width - fill)
  try:
    tpsf = float(tps)
  except ValueError:
    tpsf = 0.0
  elapsed = time.time() - int(start.group(1)) if start else None
  eta = (total - done) * tokens_per_step / tpsf if tpsf > 0 else None
  def fmt(sec):
    if sec is None:
      return "n/a"
    sec = int(max(0, sec))
    h, r = divmod(sec, 3600)
    m, s = divmod(r, 60)
    return f"{h:02d}:{m:02d}:{s:02d}"
  print(f"[{bar}] {pct:6.2f}%")
  print(f"step {done}/{total} | loss {loss} | lr {lr} | gnorm {gnorm}")
  print(f"speed {tpsf:,.0f} tok/s | elapsed {fmt(elapsed)} | ETA {fmt(eta)}")
if evals:
  s, v, best = evals[-1]
  best_vals = []
  for _, vv, mark in evals:
    if not mark:
      continue
    try:
      best_vals.append(float(vv))
    except ValueError:
      pass
  best_txt = f"best {min(best_vals):.4f}" if best_vals else "best n/a"
  try:
    val_txt = f"{float(v):.4f}"
  except ValueError:
    val_txt = v
  print(f"eval step {int(s)+1}: val {val_txt} {best.strip()} | {best_txt}")
else:
  print("eval: waiting")
if exit_code:
  print("FINISHED exit code", exit_code.group(1))
PY
  echo
  nvidia-smi --query-gpu=index,memory.used,utilization.gpu --format=csv,noheader,nounits \
    | awk -F, '{printf "gpu%s %sMiB %s%% ", $1, $2, $3} NR%4==0{print ""} END{print ""}'
  df -h /dashscope-tmp | awk 'NR==2{print "Disk: "$4" free ("$5" used)"}'
}

monitor() {
  while true; do
    clear
    progress_once
    sleep "${MONITOR_INTERVAL:-30}"
  done
}

start_tmux() {
  local kind="$1"
  local session="$TMUX_SESSION"
  if tmux has-session -t "=${session}" 2>/dev/null; then
    echo "tmux session already exists: $session"
    echo "attach: tmux attach -t $session"
    exit 1
  fi
  write_env_file
  tmux new-session -d -s "$session" "bash -lc 'source \"$RUN_DIR/env.sh\" && cd \"$ROOT\" && scripts/run_v12_phase.sh $kind'"
  if ! tmux has-session -t "=${session}_mon" 2>/dev/null; then
    tmux new-session -d -s "${session}_mon" "bash -lc 'source \"$RUN_DIR/env.sh\" && cd \"$ROOT\" && scripts/run_v12_phase.sh monitor'"
  fi
  echo "started $kind in tmux: $session"
  echo "monitor: tmux attach -t ${session}_mon"
}

action="${1:-}"
case "$action" in
  build) build ;;
  smoke) build; run_train smoke "${SMOKE_STEPS:-30}" 0 "" ;;
  pilot) build; run_train pilot "$PILOT_STEPS" "$PILOT_EVAL_EVERY" "$RUN_DIR/xtrain_v12_pilot.ckpt" ;;
  full) build; run_train full "$FULL_STEPS" "$FULL_EVAL_EVERY" "$RUN_DIR/xtrain_v12.ckpt" ;;
  eval-fixed) build; eval_fixed ;;
  sample) build; sample_fixed ;;
  export) build; export_model ;;
  status) progress_once ;;
  monitor) monitor ;;
  start-pilot) start_tmux pilot ;;
  start-full) start_tmux full ;;
  ""|-h|--help|help) usage ;;
  *) echo "unknown action: $action" >&2; usage >&2; exit 2 ;;
esac