data: full TinyStories + tokenized-id cache, val loss, CLI arch
- Corpus::load_cached: tokenize the (large) corpus ONCE, cache the id stream to
<corpus>.u16.bin (gpt2 vocab 50257 < 65536 → exact u16), read cache on reruns.
- Corpus::split_tail: hold out a tail slice as a validation corpus.
- train(): take an optional valid corpus + eval_every/eval_batches; periodic
deterministic val-loss eval that checkpoints the BEST val model; returns
TrainResult{train_losses, evals, best_val}. T6 fixed-cadence path preserved.
- bin/train + bin/export_safetensors: read architecture (--heads/--head-dim/
--layers/--ffn) + opt knobs (--steps/--batch/--seq/--max-lr/--val-tokens/
--eval-every) from CLI flags; defaults reproduce the v0-baseline tiny config.
- gitignore the multi-GB corpus + *.u16.bin caches + *.ckpt (dash5-only).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -9,3 +9,9 @@
|
|||||||
|
|
||||||
# Claude Code runtime state
|
# Claude Code runtime state
|
||||||
/.claude/
|
/.claude/
|
||||||
|
|
||||||
|
# Large scaling-run corpora + tokenized id caches live on dash5 only, never in
|
||||||
|
# git (the small data/tinystories-valid-3mb.txt is committed as a fixture).
|
||||||
|
/data/tinystories-train.txt
|
||||||
|
*.u16.bin
|
||||||
|
*.ckpt
|
||||||
|
|||||||
@@ -13,12 +13,13 @@
|
|||||||
//!
|
//!
|
||||||
//! See `docs/08-export-xserv.md` for the full architecture diff + mapping table.
|
//! See `docs/08-export-xserv.md` for the full architecture diff + mapping table.
|
||||||
//!
|
//!
|
||||||
//! Run on dash5 (needs a GPU to materialise the checkpoint params):
|
//! Run on dash5 (needs a GPU to materialise the checkpoint params). The model
|
||||||
|
//! architecture must match the checkpoint — pass the same arch flags used to
|
||||||
|
//! train (defaults reproduce the v0-baseline tiny config):
|
||||||
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
|
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
|
||||||
//! cargo run -p xtrain-train --release --bin export_safetensors -- \
|
//! cargo run -p xtrain-train --release --bin export_safetensors -- \
|
||||||
//! /tmp/xtrain_tinystories.ckpt \
|
//! /tmp/xtrain_v1.ckpt /opt/wjh/models/gpt2/tokenizer.json /tmp/xtrain_export \
|
||||||
//! /opt/wjh/models/gpt2/tokenizer.json \
|
//! --heads 8 --head-dim 32 --layers 8 --ffn 1024
|
||||||
//! /tmp/xtrain_export
|
|
||||||
|
|
||||||
#[cfg(no_cuda)]
|
#[cfg(no_cuda)]
|
||||||
fn main() {
|
fn main() {
|
||||||
@@ -39,6 +40,16 @@ use xtrain_model::{Config, TinyTransformer, param_to_host};
|
|||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
use xtrain_tensor::Device;
|
use xtrain_tensor::Device;
|
||||||
|
|
||||||
|
// A flag like `--layers 8`: scan argv for `name`, parse the following token.
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
|
||||||
|
args.iter()
|
||||||
|
.position(|a| a == name)
|
||||||
|
.and_then(|i| args.get(i + 1))
|
||||||
|
.and_then(|s| s.parse().ok())
|
||||||
|
.unwrap_or(default)
|
||||||
|
}
|
||||||
|
|
||||||
// Same deterministic init scheme as bin/train.rs, so a freshly-built model has
|
// Same deterministic init scheme as bin/train.rs, so a freshly-built model has
|
||||||
// the right shapes before `load_into` overwrites the values from the checkpoint.
|
// the right shapes before `load_into` overwrites the values from the checkpoint.
|
||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
@@ -176,29 +187,34 @@ fn main() {
|
|||||||
use xserv_tokenizer::Tokenizer;
|
use xserv_tokenizer::Tokenizer;
|
||||||
|
|
||||||
let args: Vec<String> = std::env::args().collect();
|
let args: Vec<String> = std::env::args().collect();
|
||||||
let ckpt = args
|
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
||||||
.get(1)
|
let ckpt = positionals
|
||||||
.map(PathBuf::from)
|
.first()
|
||||||
|
.map(|s| PathBuf::from(s.as_str()))
|
||||||
.unwrap_or_else(|| PathBuf::from("/tmp/xtrain_tinystories.ckpt"));
|
.unwrap_or_else(|| PathBuf::from("/tmp/xtrain_tinystories.ckpt"));
|
||||||
let tok_path = args
|
let tok_path = positionals
|
||||||
.get(2)
|
.get(1)
|
||||||
.map(PathBuf::from)
|
.map(|s| PathBuf::from(s.as_str()))
|
||||||
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
|
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
|
||||||
let out_dir = args
|
let out_dir = positionals
|
||||||
.get(3)
|
.get(2)
|
||||||
.map(PathBuf::from)
|
.map(|s| PathBuf::from(s.as_str()))
|
||||||
.unwrap_or_else(|| PathBuf::from("/tmp/xtrain_export"));
|
.unwrap_or_else(|| PathBuf::from("/tmp/xtrain_export"));
|
||||||
|
|
||||||
|
// Architecture must match the checkpoint. Defaults = v0-baseline tiny config.
|
||||||
|
let n_heads = flag(&args, "--heads", 2usize);
|
||||||
|
let head_dim = flag(&args, "--head-dim", 16usize);
|
||||||
|
let n_layers = flag(&args, "--layers", 4usize);
|
||||||
|
let ffn = flag(&args, "--ffn", 64usize);
|
||||||
|
|
||||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||||
device::set_device(0).unwrap();
|
device::set_device(0).unwrap();
|
||||||
let dev = Device::Cuda(0);
|
let dev = Device::Cuda(0);
|
||||||
|
|
||||||
// Size the model exactly like bin/train.rs: gpt2 vocab + n_layers = 4.
|
// Size the model from the arch flags + gpt2 vocab; must match the checkpoint.
|
||||||
let tok = Tokenizer::from_file(&tok_path);
|
let tok = Tokenizer::from_file(&tok_path);
|
||||||
let vocab = tok.vocab_size();
|
let vocab = tok.vocab_size();
|
||||||
let mut cfg = Config::tiny();
|
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
|
||||||
cfg.vocab = vocab;
|
|
||||||
cfg.n_layers = 4;
|
|
||||||
println!(
|
println!(
|
||||||
"export: ckpt {} → {} (vocab {}, dim {}, layers {}, heads {}, head_dim {})",
|
"export: ckpt {} → {} (vocab {}, dim {}, layers {}, heads {}, head_dim {})",
|
||||||
ckpt.display(),
|
ckpt.display(),
|
||||||
|
|||||||
@@ -1,14 +1,20 @@
|
|||||||
//! End-to-end training entry point (Phase T6): load the GPT-2 BPE + TinyStories
|
//! End-to-end training entry point: load the GPT-2 BPE + a TinyStories corpus,
|
||||||
//! corpus, train the tiny transformer with hand-written AdamW for a BOUNDED
|
//! train the tiny transformer with hand-written AdamW for a BOUNDED budget,
|
||||||
//! budget, checkpoint it, and print a few generated samples.
|
//! evaluate held-out val loss, checkpoint the best, and print a few samples.
|
||||||
|
//!
|
||||||
|
//! The MODEL SIZE is a CLI-tunable scaling-ladder rung (v0 baseline = the
|
||||||
|
//! defaults; v1 = dim256/8L/8h via flags), not a hardcoded tiny config.
|
||||||
//!
|
//!
|
||||||
//! Run on dash5 (needs a GPU + the corpus + tokenizer.json):
|
//! Run on dash5 (needs a GPU + the corpus + tokenizer.json):
|
||||||
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
|
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
|
||||||
//! cargo run -p xtrain-train --release --bin train -- \
|
//! cargo run -p xtrain-train --release --bin train -- \
|
||||||
//! /opt/wjh/models/gpt2/tokenizer.json \
|
//! /opt/wjh/models/gpt2/tokenizer.json data/tinystories-train.txt \
|
||||||
//! data/tinystories-valid-3mb.txt
|
//! --dim 256 --heads 8 --head-dim 32 --layers 8 --ffn 1024 \
|
||||||
|
//! --steps 3000 --batch 16 --seq 128 --max-lr 6e-4 \
|
||||||
|
//! --val-tokens 200000 --eval-every 250 --ckpt /tmp/xtrain_v1.ckpt
|
||||||
//!
|
//!
|
||||||
//! Optional 3rd/4th args: number of steps, checkpoint path.
|
//! Positional: <tokenizer.json> <corpus.txt>. Everything else is a flag with a
|
||||||
|
//! sane default (defaults reproduce the v0-baseline tiny config).
|
||||||
|
|
||||||
// On a GPU-less host (no_cuda) the whole training body is unavailable; keep a
|
// On a GPU-less host (no_cuda) the whole training body is unavailable; keep a
|
||||||
// stub `main` so the crate still builds for `cargo check`.
|
// stub `main` so the crate still builds for `cargo check`.
|
||||||
@@ -51,51 +57,101 @@ fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A flag like `--dim 256`: scan argv for `name`, parse the following token.
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
|
||||||
|
args.iter()
|
||||||
|
.position(|a| a == name)
|
||||||
|
.and_then(|i| args.get(i + 1))
|
||||||
|
.and_then(|s| s.parse().ok())
|
||||||
|
.unwrap_or(default)
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
fn main() {
|
fn main() {
|
||||||
let args: Vec<String> = std::env::args().collect();
|
let args: Vec<String> = std::env::args().collect();
|
||||||
let tok_path = args
|
// First two non-flag positionals: tokenizer.json, corpus.txt.
|
||||||
.get(1)
|
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
||||||
.map(PathBuf::from)
|
let tok_path = positionals
|
||||||
|
.first()
|
||||||
|
.map(|s| PathBuf::from(s.as_str()))
|
||||||
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
|
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
|
||||||
let corpus_path = args
|
let corpus_path = positionals
|
||||||
.get(2)
|
.get(1)
|
||||||
.map(PathBuf::from)
|
.map(|s| PathBuf::from(s.as_str()))
|
||||||
.unwrap_or_else(|| PathBuf::from("data/tinystories-valid-3mb.txt"));
|
.unwrap_or_else(|| PathBuf::from("data/tinystories-valid-3mb.txt"));
|
||||||
let steps: usize = args.get(3).and_then(|s| s.parse().ok()).unwrap_or(2000);
|
|
||||||
let ckpt: PathBuf = args
|
// Architecture (scaling-ladder rung). Defaults = v0-baseline tiny config.
|
||||||
.get(4)
|
let n_heads = flag(&args, "--heads", 2usize);
|
||||||
.map(PathBuf::from)
|
let head_dim = flag(&args, "--head-dim", 16usize);
|
||||||
.unwrap_or_else(|| PathBuf::from("/tmp/xtrain_tinystories.ckpt"));
|
let n_layers = flag(&args, "--layers", 4usize);
|
||||||
|
let ffn = flag(&args, "--ffn", 64usize);
|
||||||
|
// `--dim` is informational; dim is always n_heads*head_dim. Warn on mismatch.
|
||||||
|
let dim_flag = flag(&args, "--dim", 0usize);
|
||||||
|
if dim_flag != 0 && dim_flag != n_heads * head_dim {
|
||||||
|
eprintln!(
|
||||||
|
"warning: --dim {dim_flag} != heads*head_dim {}; using {}",
|
||||||
|
n_heads * head_dim,
|
||||||
|
n_heads * head_dim
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optimization knobs.
|
||||||
|
let steps: usize = flag(&args, "--steps", 2000);
|
||||||
|
let batch_size: usize = flag(&args, "--batch", 8);
|
||||||
|
let seq_len: usize = flag(&args, "--seq", 64);
|
||||||
|
let max_lr: f32 = flag(&args, "--max-lr", 3e-3);
|
||||||
|
let min_lr: f32 = flag(&args, "--min-lr", max_lr * 0.1);
|
||||||
|
let weight_decay: f32 = flag(&args, "--wd", 0.1);
|
||||||
|
let max_grad_norm: f32 = flag(&args, "--clip", 1.0);
|
||||||
|
let val_tokens: usize = flag(&args, "--val-tokens", 0);
|
||||||
|
let eval_every: usize = flag(&args, "--eval-every", 0);
|
||||||
|
let eval_batches: usize = flag(&args, "--eval-batches", 64);
|
||||||
|
let ckpt: PathBuf = PathBuf::from(
|
||||||
|
args.iter()
|
||||||
|
.position(|a| a == "--ckpt")
|
||||||
|
.and_then(|i| args.get(i + 1))
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| "/tmp/xtrain_tinystories.ckpt".to_string()),
|
||||||
|
);
|
||||||
|
|
||||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||||
device::set_device(0).unwrap();
|
device::set_device(0).unwrap();
|
||||||
let device = Device::Cuda(0);
|
let device = Device::Cuda(0);
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"loading tokenizer {} + corpus {}",
|
"loading tokenizer {} + corpus {} (cached id stream)",
|
||||||
tok_path.display(),
|
tok_path.display(),
|
||||||
corpus_path.display()
|
corpus_path.display()
|
||||||
);
|
);
|
||||||
let corpus = Corpus::load(&tok_path, &corpus_path);
|
let corpus = Corpus::load_cached(&tok_path, &corpus_path);
|
||||||
println!(
|
println!(
|
||||||
"corpus: {} tokens, vocab {}",
|
"corpus: {} tokens, vocab {}",
|
||||||
corpus.len(),
|
corpus.len(),
|
||||||
corpus.vocab_size
|
corpus.vocab_size
|
||||||
);
|
);
|
||||||
|
let vocab = corpus.vocab_size;
|
||||||
|
// Hold out a tail slice for validation (if requested and the corpus is big).
|
||||||
|
let (train_corpus, valid) = if val_tokens > 0 {
|
||||||
|
let (t, v) = corpus.split_tail(val_tokens);
|
||||||
|
println!("split: {} train tokens / {} val tokens", t.len(), v.len());
|
||||||
|
(t, Some(v))
|
||||||
|
} else {
|
||||||
|
(corpus, None)
|
||||||
|
};
|
||||||
|
|
||||||
// Tiny model sized to the BPE vocab. A real (but small) config: wider than
|
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
|
||||||
// the overfit test so it has capacity to learn English structure.
|
|
||||||
let mut cfg = Config::tiny();
|
|
||||||
cfg.vocab = corpus.vocab_size;
|
|
||||||
cfg.n_layers = 4;
|
|
||||||
println!(
|
println!(
|
||||||
"model: dim {} layers {} heads {} ffn {} → {} params",
|
"model: dim {} layers {} heads {} head_dim {} ffn {} → core {:.3}M params \
|
||||||
|
(+ embed/lm {:.2}M = {:.2}M total)",
|
||||||
cfg.dim,
|
cfg.dim,
|
||||||
cfg.n_layers,
|
cfg.n_layers,
|
||||||
cfg.n_heads,
|
cfg.n_heads,
|
||||||
|
cfg.head_dim,
|
||||||
cfg.ffn_hidden,
|
cfg.ffn_hidden,
|
||||||
cfg.num_params()
|
cfg.core_params() as f32 / 1e6,
|
||||||
|
(cfg.num_params() - cfg.core_params()) as f32 / 1e6,
|
||||||
|
cfg.num_params() as f32 / 1e6,
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut seed = 1u64;
|
let mut seed = 1u64;
|
||||||
@@ -111,33 +167,45 @@ fn main() {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let seq_len = 64;
|
|
||||||
let tcfg = TrainConfig {
|
let tcfg = TrainConfig {
|
||||||
seq_len,
|
seq_len,
|
||||||
batch_size: 8,
|
batch_size,
|
||||||
steps,
|
steps,
|
||||||
schedule: LrSchedule {
|
schedule: LrSchedule {
|
||||||
max_lr: 3e-3,
|
max_lr,
|
||||||
min_lr: 3e-4,
|
min_lr,
|
||||||
warmup: (steps / 20).max(20),
|
warmup: (steps / 20).max(20),
|
||||||
total: steps,
|
total: steps,
|
||||||
},
|
},
|
||||||
weight_decay: 0.1,
|
weight_decay,
|
||||||
max_grad_norm: 1.0,
|
max_grad_norm,
|
||||||
log_every: 50,
|
log_every: 50,
|
||||||
ckpt_path: Some(ckpt.clone()),
|
ckpt_path: Some(ckpt.clone()),
|
||||||
ckpt_every: 500,
|
ckpt_every: 500,
|
||||||
|
eval_every,
|
||||||
|
eval_batches,
|
||||||
seed: 42,
|
seed: 42,
|
||||||
};
|
};
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"training: {} steps, seq {}, batch {}, lr {:.1e}→{:.1e}",
|
"training: {} steps, seq {}, batch {}, lr {:.1e}→{:.1e}, eval every {}",
|
||||||
tcfg.steps, tcfg.seq_len, tcfg.batch_size, tcfg.schedule.max_lr, tcfg.schedule.min_lr
|
tcfg.steps,
|
||||||
|
tcfg.seq_len,
|
||||||
|
tcfg.batch_size,
|
||||||
|
tcfg.schedule.max_lr,
|
||||||
|
tcfg.schedule.min_lr,
|
||||||
|
tcfg.eval_every
|
||||||
);
|
);
|
||||||
let losses = train(&model, device, &corpus, &tcfg);
|
let result = train(&model, device, &train_corpus, valid.as_ref(), &tcfg);
|
||||||
let start = losses.first().copied().unwrap_or(0.0);
|
let start = result.train_losses.first().copied().unwrap_or(0.0);
|
||||||
let end = losses.last().copied().unwrap_or(0.0);
|
let end = result.train_losses.last().copied().unwrap_or(0.0);
|
||||||
println!("loss: start {start:.4} → end {end:.4}");
|
println!("train loss: start {start:.4} → end {end:.4}");
|
||||||
|
if let Some(best) = result.best_val {
|
||||||
|
println!("best val loss: {best:.4}");
|
||||||
|
}
|
||||||
|
if let Some((s, v)) = result.evals.last() {
|
||||||
|
println!("final val loss (step {s}): {v:.4}");
|
||||||
|
}
|
||||||
|
|
||||||
sample_some(&model, device, &tok_path);
|
sample_some(&model, device, &tok_path);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,15 @@
|
|||||||
//! Data pipeline: load the GPT-2 BPE (reusing xserv's from-scratch tokenizer),
|
//! Data pipeline: load the GPT-2 BPE (reusing xserv's from-scratch tokenizer),
|
||||||
//! tokenize a text corpus into one flat token stream, and sample fixed-length
|
//! tokenize a text corpus into one flat token stream, and sample fixed-length
|
||||||
//! `(input, target)` windows for next-token prediction. Host-only (no GPU).
|
//! `(input, target)` windows for next-token prediction. Host-only (no GPU).
|
||||||
|
//!
|
||||||
|
//! For the scaling runs the corpus is large (full TinyStories ≈ 2 GB / ~470 M
|
||||||
|
//! tokens), and the from-scratch BPE is slow, so [`Corpus::load_cached`]
|
||||||
|
//! tokenizes ONCE and caches the id stream to a `<corpus>.u16.bin` next to the
|
||||||
|
//! text (GPT-2 vocab = 50257 < 65536, so u16 is exact). Subsequent runs mmap-read
|
||||||
|
//! the cache instead of re-tokenizing.
|
||||||
|
|
||||||
use std::path::Path;
|
use std::io::{BufReader, BufWriter, Read, Write};
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
use xserv_tokenizer::Tokenizer;
|
use xserv_tokenizer::Tokenizer;
|
||||||
|
|
||||||
/// A tokenized corpus: one flat stream of token ids, plus the vocab size.
|
/// A tokenized corpus: one flat stream of token ids, plus the vocab size.
|
||||||
@@ -30,6 +37,54 @@ impl Corpus {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Like [`load`](Self::load) but caches the tokenized id stream to
|
||||||
|
/// `<corpus_path>.u16.bin`. On the first run it tokenizes the (large) corpus
|
||||||
|
/// and writes the cache; on later runs it reads the cache directly, skipping
|
||||||
|
/// the slow BPE. The cache is just a flat little-endian `[u16]` (no header) —
|
||||||
|
/// it is keyed only by path, so delete it if the corpus or tokenizer changes.
|
||||||
|
pub fn load_cached(tokenizer_path: &Path, corpus_path: &Path) -> Self {
|
||||||
|
let cache = cache_path(corpus_path);
|
||||||
|
let vocab_size = Tokenizer::from_file(tokenizer_path).vocab_size();
|
||||||
|
if cache.exists() {
|
||||||
|
let tokens = read_u16_cache(&cache);
|
||||||
|
println!(
|
||||||
|
"corpus: read {} cached tokens from {}",
|
||||||
|
tokens.len(),
|
||||||
|
cache.display()
|
||||||
|
);
|
||||||
|
return Self { tokens, vocab_size };
|
||||||
|
}
|
||||||
|
let me = Self::load(tokenizer_path, corpus_path);
|
||||||
|
write_u16_cache(&cache, &me.tokens);
|
||||||
|
println!(
|
||||||
|
"corpus: tokenized {} tokens → cached to {}",
|
||||||
|
me.tokens.len(),
|
||||||
|
cache.display()
|
||||||
|
);
|
||||||
|
me
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Split off the last `n` tokens as a held-out validation corpus, leaving the
|
||||||
|
/// rest as the train corpus. Returns `(train, valid)`. Used for periodic val
|
||||||
|
/// loss during training without leaking the eval window into training.
|
||||||
|
pub fn split_tail(self, n: usize) -> (Self, Self) {
|
||||||
|
let n = n.min(self.tokens.len() / 10); // never hand off more than 10%
|
||||||
|
let cut = self.tokens.len() - n;
|
||||||
|
let valid = self.tokens[cut..].to_vec();
|
||||||
|
let mut train = self.tokens;
|
||||||
|
train.truncate(cut);
|
||||||
|
(
|
||||||
|
Self {
|
||||||
|
tokens: train,
|
||||||
|
vocab_size: self.vocab_size,
|
||||||
|
},
|
||||||
|
Self {
|
||||||
|
tokens: valid,
|
||||||
|
vocab_size: self.vocab_size,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/// Total number of tokens.
|
/// Total number of tokens.
|
||||||
pub fn len(&self) -> usize {
|
pub fn len(&self) -> usize {
|
||||||
self.tokens.len()
|
self.tokens.len()
|
||||||
@@ -65,6 +120,40 @@ fn trim_to_whole_stories(text: &str) -> &str {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// `<corpus_path>.u16.bin` — the token-id cache beside the corpus text.
|
||||||
|
fn cache_path(corpus_path: &Path) -> PathBuf {
|
||||||
|
let mut s = corpus_path.as_os_str().to_os_string();
|
||||||
|
s.push(".u16.bin");
|
||||||
|
PathBuf::from(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read a flat little-endian `[u16]` cache into an `i32` id stream.
|
||||||
|
fn read_u16_cache(path: &Path) -> Vec<i32> {
|
||||||
|
let mut r = BufReader::new(
|
||||||
|
std::fs::File::open(path).unwrap_or_else(|e| panic!("open cache {}: {e}", path.display())),
|
||||||
|
);
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
r.read_to_end(&mut buf).expect("read cache");
|
||||||
|
assert!(buf.len() % 2 == 0, "corrupt u16 cache (odd byte count)");
|
||||||
|
buf.chunks_exact(2)
|
||||||
|
.map(|b| u16::from_le_bytes([b[0], b[1]]) as i32)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Write an id stream as a flat little-endian `[u16]` cache. Ids must fit in u16
|
||||||
|
/// (GPT-2 vocab = 50257 < 65536); asserts otherwise.
|
||||||
|
fn write_u16_cache(path: &Path, tokens: &[i32]) {
|
||||||
|
let mut w = BufWriter::new(
|
||||||
|
std::fs::File::create(path)
|
||||||
|
.unwrap_or_else(|e| panic!("create cache {}: {e}", path.display())),
|
||||||
|
);
|
||||||
|
for &t in tokens {
|
||||||
|
assert!((0..=u16::MAX as i32).contains(&t), "token id {t} > u16");
|
||||||
|
w.write_all(&(t as u16).to_le_bytes()).expect("write cache");
|
||||||
|
}
|
||||||
|
w.flush().expect("flush cache");
|
||||||
|
}
|
||||||
|
|
||||||
/// Tiny LCG (same constants as the model tests' deterministic fill) so dataset
|
/// Tiny LCG (same constants as the model tests' deterministic fill) so dataset
|
||||||
/// sampling is reproducible from a single u64 seed.
|
/// sampling is reproducible from a single u64 seed.
|
||||||
fn next_rand(state: &mut u64) -> u64 {
|
fn next_rand(state: &mut u64) -> u64 {
|
||||||
|
|||||||
@@ -19,4 +19,4 @@ pub mod sample;
|
|||||||
mod train_loop;
|
mod train_loop;
|
||||||
|
|
||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
pub use train_loop::{TrainConfig, train};
|
pub use train_loop::{TrainConfig, TrainResult, train};
|
||||||
|
|||||||
@@ -31,28 +31,47 @@ pub struct TrainConfig {
|
|||||||
pub max_grad_norm: f32,
|
pub max_grad_norm: f32,
|
||||||
pub log_every: usize,
|
pub log_every: usize,
|
||||||
/// Optional checkpoint path written every `ckpt_every` steps (and at the end).
|
/// Optional checkpoint path written every `ckpt_every` steps (and at the end).
|
||||||
|
/// When `eval_every > 0`, the checkpoint instead tracks the BEST val loss.
|
||||||
pub ckpt_path: Option<PathBuf>,
|
pub ckpt_path: Option<PathBuf>,
|
||||||
pub ckpt_every: usize,
|
pub ckpt_every: usize,
|
||||||
|
/// Evaluate held-out val loss every `eval_every` steps (0 = never). Each eval
|
||||||
|
/// averages cross-entropy over `eval_batches` fixed windows of the val corpus.
|
||||||
|
pub eval_every: usize,
|
||||||
|
pub eval_batches: usize,
|
||||||
/// Seed for reproducible sequence sampling.
|
/// Seed for reproducible sequence sampling.
|
||||||
pub seed: u64,
|
pub seed: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Outcome of a run: per-step train losses and (step, val_loss) eval points.
|
||||||
|
pub struct TrainResult {
|
||||||
|
pub train_losses: Vec<f32>,
|
||||||
|
pub evals: Vec<(usize, f32)>,
|
||||||
|
pub best_val: Option<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
/// Train `model` on `corpus` for `cfg.steps` AdamW steps. Returns the per-step
|
/// Train `model` on `corpus` for `cfg.steps` AdamW steps. Returns the per-step
|
||||||
/// loss trace (one mean loss per step, read from the first sequence of the
|
/// train-loss trace plus any (step, val_loss) eval points. Logs progress, and —
|
||||||
/// batch — cheap and representative). Logs progress and checkpoints as configured.
|
/// when `valid` is given and `cfg.eval_every > 0` — evaluates held-out val loss
|
||||||
|
/// periodically and checkpoints the BEST val model (else checkpoints on a fixed
|
||||||
|
/// cadence, as in T6). Logs progress.
|
||||||
pub fn train(
|
pub fn train(
|
||||||
model: &TinyTransformer,
|
model: &TinyTransformer,
|
||||||
device: Device,
|
device: Device,
|
||||||
corpus: &Corpus,
|
corpus: &Corpus,
|
||||||
|
valid: Option<&Corpus>,
|
||||||
cfg: &TrainConfig,
|
cfg: &TrainConfig,
|
||||||
) -> Vec<f32> {
|
) -> TrainResult {
|
||||||
let params = model.params();
|
let params = model.params();
|
||||||
let mut opt = GpuAdamW::new(cfg.weight_decay);
|
let mut opt = GpuAdamW::new(cfg.weight_decay);
|
||||||
let mut rng = cfg.seed;
|
let mut rng = cfg.seed;
|
||||||
let mut losses = Vec::with_capacity(cfg.steps);
|
let mut losses = Vec::with_capacity(cfg.steps);
|
||||||
|
let mut evals = Vec::new();
|
||||||
|
let mut best_val: Option<f32> = None;
|
||||||
let inv_batch = 1.0 / cfg.batch_size as f32;
|
let inv_batch = 1.0 / cfg.batch_size as f32;
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let mut tokens_seen: u64 = 0;
|
let mut tokens_seen: u64 = 0;
|
||||||
|
// Best-val checkpointing only kicks in when we actually evaluate.
|
||||||
|
let track_best = valid.is_some() && cfg.eval_every > 0;
|
||||||
|
|
||||||
for step in 0..cfg.steps {
|
for step in 0..cfg.steps {
|
||||||
let lr = cfg.schedule.lr(step);
|
let lr = cfg.schedule.lr(step);
|
||||||
@@ -88,18 +107,86 @@ pub fn train(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(path) = &cfg.ckpt_path {
|
// Periodic held-out eval (deterministic windows, no grad).
|
||||||
if cfg.ckpt_every > 0 && (step + 1) % cfg.ckpt_every == 0 {
|
if let Some(v) = valid {
|
||||||
checkpoint::save(path, ¶ms).expect("checkpoint save");
|
if cfg.eval_every > 0 && ((step + 1) % cfg.eval_every == 0 || step == cfg.steps - 1) {
|
||||||
|
let vl = eval_loss(model, device, v, cfg.seq_len, cfg.eval_batches);
|
||||||
|
evals.push((step, vl));
|
||||||
|
let improved = best_val.map(|b| vl < b).unwrap_or(true);
|
||||||
|
println!(
|
||||||
|
" eval @ step {step}: val loss {vl:.4}{}",
|
||||||
|
if improved { " (best)" } else { "" }
|
||||||
|
);
|
||||||
|
if improved {
|
||||||
|
best_val = Some(vl);
|
||||||
|
if let Some(path) = &cfg.ckpt_path {
|
||||||
|
checkpoint::save(path, ¶ms).expect("best checkpoint save");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fixed-cadence checkpointing (only when not tracking best val).
|
||||||
|
if !track_best {
|
||||||
|
if let Some(path) = &cfg.ckpt_path {
|
||||||
|
if cfg.ckpt_every > 0 && (step + 1) % cfg.ckpt_every == 0 {
|
||||||
|
checkpoint::save(path, ¶ms).expect("checkpoint save");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(path) = &cfg.ckpt_path {
|
// Without periodic eval, still persist the final params (T6 behaviour). With
|
||||||
checkpoint::save(path, ¶ms).expect("final checkpoint save");
|
// best-val tracking the checkpoint already holds the best model — don't clobber.
|
||||||
println!("saved checkpoint → {}", path.display());
|
if !track_best {
|
||||||
|
if let Some(path) = &cfg.ckpt_path {
|
||||||
|
checkpoint::save(path, ¶ms).expect("final checkpoint save");
|
||||||
|
println!("saved checkpoint → {}", path.display());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TrainResult {
|
||||||
|
train_losses: losses,
|
||||||
|
evals,
|
||||||
|
best_val,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mean cross-entropy over `batches` deterministic, non-overlapping windows of
|
||||||
|
/// the validation corpus (no backward — eval only). Deterministic so val loss is
|
||||||
|
/// comparable across steps and runs.
|
||||||
|
fn eval_loss(
|
||||||
|
model: &TinyTransformer,
|
||||||
|
device: Device,
|
||||||
|
valid: &Corpus,
|
||||||
|
seq: usize,
|
||||||
|
batches: usize,
|
||||||
|
) -> f32 {
|
||||||
|
if valid.len() <= seq + 1 {
|
||||||
|
return f32::NAN;
|
||||||
|
}
|
||||||
|
let n_win = (valid.len() - 1) / seq; // disjoint windows that fit
|
||||||
|
let batches = batches.max(1).min(n_win.max(1));
|
||||||
|
let stride = (n_win / batches).max(1);
|
||||||
|
let mut sum = 0.0f32;
|
||||||
|
let mut count = 0usize;
|
||||||
|
for i in 0..batches {
|
||||||
|
let s = (i * stride) * seq;
|
||||||
|
if s + seq + 1 > valid.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let input: Vec<i32> = valid.tokens[s..s + seq].to_vec();
|
||||||
|
let target: Vec<i32> = valid.tokens[s + 1..s + seq + 1].to_vec();
|
||||||
|
let ids = ids_tensor(&input, device);
|
||||||
|
let targets = ids_tensor(&target, device);
|
||||||
|
let loss = model.loss(&ids, &targets);
|
||||||
|
sum += read_scalar(&loss);
|
||||||
|
count += 1;
|
||||||
|
}
|
||||||
|
if count == 0 {
|
||||||
|
f32::NAN
|
||||||
|
} else {
|
||||||
|
sum / count as f32
|
||||||
}
|
}
|
||||||
losses
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_scalar(v: &xtrain_autodiff::tape::Var) -> f32 {
|
fn read_scalar(v: &xtrain_autodiff::tape::Var) -> f32 {
|
||||||
|
|||||||
@@ -96,10 +96,12 @@ fn trains_on_tinystories() {
|
|||||||
log_every: 50,
|
log_every: 50,
|
||||||
ckpt_path: None,
|
ckpt_path: None,
|
||||||
ckpt_every: 0,
|
ckpt_every: 0,
|
||||||
|
eval_every: 0,
|
||||||
|
eval_batches: 0,
|
||||||
seed: 42,
|
seed: 42,
|
||||||
};
|
};
|
||||||
|
|
||||||
let losses = train(&model, device, &corpus, &tcfg);
|
let losses = train(&model, device, &corpus, None, &tcfg).train_losses;
|
||||||
// Average the first/last few steps to smooth per-step noise.
|
// Average the first/last few steps to smooth per-step noise.
|
||||||
let head: f32 =
|
let head: f32 =
|
||||||
losses[..10.min(losses.len())].iter().sum::<f32>() / 10.0_f32.min(losses.len() as f32);
|
losses[..10.min(losses.len())].iter().sum::<f32>() / 10.0_f32.min(losses.len() as f32);
|
||||||
|
|||||||
Reference in New Issue
Block a user