data: full TinyStories + tokenized-id cache, val loss, CLI arch
- Corpus::load_cached: tokenize the (large) corpus ONCE, cache the id stream to
<corpus>.u16.bin (gpt2 vocab 50257 < 65536 → exact u16), read cache on reruns.
- Corpus::split_tail: hold out a tail slice as a validation corpus.
- train(): take an optional valid corpus + eval_every/eval_batches; periodic
deterministic val-loss eval that checkpoints the BEST val model; returns
TrainResult{train_losses, evals, best_val}. T6 fixed-cadence path preserved.
- bin/train + bin/export_safetensors: read architecture (--heads/--head-dim/
--layers/--ffn) + opt knobs (--steps/--batch/--seq/--max-lr/--val-tokens/
--eval-every) from CLI flags; defaults reproduce the v0-baseline tiny config.
- gitignore the multi-GB corpus + *.u16.bin caches + *.ckpt (dash5-only).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -13,12 +13,13 @@
|
||||
//!
|
||||
//! See `docs/08-export-xserv.md` for the full architecture diff + mapping table.
|
||||
//!
|
||||
//! Run on dash5 (needs a GPU to materialise the checkpoint params):
|
||||
//! Run on dash5 (needs a GPU to materialise the checkpoint params). The model
|
||||
//! architecture must match the checkpoint — pass the same arch flags used to
|
||||
//! train (defaults reproduce the v0-baseline tiny config):
|
||||
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
|
||||
//! cargo run -p xtrain-train --release --bin export_safetensors -- \
|
||||
//! /tmp/xtrain_tinystories.ckpt \
|
||||
//! /opt/wjh/models/gpt2/tokenizer.json \
|
||||
//! /tmp/xtrain_export
|
||||
//! /tmp/xtrain_v1.ckpt /opt/wjh/models/gpt2/tokenizer.json /tmp/xtrain_export \
|
||||
//! --heads 8 --head-dim 32 --layers 8 --ffn 1024
|
||||
|
||||
#[cfg(no_cuda)]
|
||||
fn main() {
|
||||
@@ -39,6 +40,16 @@ use xtrain_model::{Config, TinyTransformer, param_to_host};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_tensor::Device;
|
||||
|
||||
// A flag like `--layers 8`: scan argv for `name`, parse the following token.
|
||||
#[cfg(not(no_cuda))]
|
||||
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
|
||||
args.iter()
|
||||
.position(|a| a == name)
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
// Same deterministic init scheme as bin/train.rs, so a freshly-built model has
|
||||
// the right shapes before `load_into` overwrites the values from the checkpoint.
|
||||
#[cfg(not(no_cuda))]
|
||||
@@ -176,29 +187,34 @@ fn main() {
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let ckpt = args
|
||||
.get(1)
|
||||
.map(PathBuf::from)
|
||||
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
||||
let ckpt = positionals
|
||||
.first()
|
||||
.map(|s| PathBuf::from(s.as_str()))
|
||||
.unwrap_or_else(|| PathBuf::from("/tmp/xtrain_tinystories.ckpt"));
|
||||
let tok_path = args
|
||||
.get(2)
|
||||
.map(PathBuf::from)
|
||||
let tok_path = positionals
|
||||
.get(1)
|
||||
.map(|s| PathBuf::from(s.as_str()))
|
||||
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
|
||||
let out_dir = args
|
||||
.get(3)
|
||||
.map(PathBuf::from)
|
||||
let out_dir = positionals
|
||||
.get(2)
|
||||
.map(|s| PathBuf::from(s.as_str()))
|
||||
.unwrap_or_else(|| PathBuf::from("/tmp/xtrain_export"));
|
||||
|
||||
// Architecture must match the checkpoint. Defaults = v0-baseline tiny config.
|
||||
let n_heads = flag(&args, "--heads", 2usize);
|
||||
let head_dim = flag(&args, "--head-dim", 16usize);
|
||||
let n_layers = flag(&args, "--layers", 4usize);
|
||||
let ffn = flag(&args, "--ffn", 64usize);
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let dev = Device::Cuda(0);
|
||||
|
||||
// Size the model exactly like bin/train.rs: gpt2 vocab + n_layers = 4.
|
||||
// Size the model from the arch flags + gpt2 vocab; must match the checkpoint.
|
||||
let tok = Tokenizer::from_file(&tok_path);
|
||||
let vocab = tok.vocab_size();
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = vocab;
|
||||
cfg.n_layers = 4;
|
||||
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
|
||||
println!(
|
||||
"export: ckpt {} → {} (vocab {}, dim {}, layers {}, heads {}, head_dim {})",
|
||||
ckpt.display(),
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
//! End-to-end training entry point (Phase T6): load the GPT-2 BPE + TinyStories
|
||||
//! corpus, train the tiny transformer with hand-written AdamW for a BOUNDED
|
||||
//! budget, checkpoint it, and print a few generated samples.
|
||||
//! End-to-end training entry point: load the GPT-2 BPE + a TinyStories corpus,
|
||||
//! train the tiny transformer with hand-written AdamW for a BOUNDED budget,
|
||||
//! evaluate held-out val loss, checkpoint the best, and print a few samples.
|
||||
//!
|
||||
//! The MODEL SIZE is a CLI-tunable scaling-ladder rung (v0 baseline = the
|
||||
//! defaults; v1 = dim256/8L/8h via flags), not a hardcoded tiny config.
|
||||
//!
|
||||
//! Run on dash5 (needs a GPU + the corpus + tokenizer.json):
|
||||
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
|
||||
//! cargo run -p xtrain-train --release --bin train -- \
|
||||
//! /opt/wjh/models/gpt2/tokenizer.json \
|
||||
//! data/tinystories-valid-3mb.txt
|
||||
//! /opt/wjh/models/gpt2/tokenizer.json data/tinystories-train.txt \
|
||||
//! --dim 256 --heads 8 --head-dim 32 --layers 8 --ffn 1024 \
|
||||
//! --steps 3000 --batch 16 --seq 128 --max-lr 6e-4 \
|
||||
//! --val-tokens 200000 --eval-every 250 --ckpt /tmp/xtrain_v1.ckpt
|
||||
//!
|
||||
//! Optional 3rd/4th args: number of steps, checkpoint path.
|
||||
//! Positional: <tokenizer.json> <corpus.txt>. Everything else is a flag with a
|
||||
//! sane default (defaults reproduce the v0-baseline tiny config).
|
||||
|
||||
// On a GPU-less host (no_cuda) the whole training body is unavailable; keep a
|
||||
// stub `main` so the crate still builds for `cargo check`.
|
||||
@@ -51,51 +57,101 @@ fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
// A flag like `--dim 256`: scan argv for `name`, parse the following token.
|
||||
#[cfg(not(no_cuda))]
|
||||
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
|
||||
args.iter()
|
||||
.position(|a| a == name)
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let tok_path = args
|
||||
.get(1)
|
||||
.map(PathBuf::from)
|
||||
// First two non-flag positionals: tokenizer.json, corpus.txt.
|
||||
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
||||
let tok_path = positionals
|
||||
.first()
|
||||
.map(|s| PathBuf::from(s.as_str()))
|
||||
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
|
||||
let corpus_path = args
|
||||
.get(2)
|
||||
.map(PathBuf::from)
|
||||
let corpus_path = positionals
|
||||
.get(1)
|
||||
.map(|s| PathBuf::from(s.as_str()))
|
||||
.unwrap_or_else(|| PathBuf::from("data/tinystories-valid-3mb.txt"));
|
||||
let steps: usize = args.get(3).and_then(|s| s.parse().ok()).unwrap_or(2000);
|
||||
let ckpt: PathBuf = args
|
||||
.get(4)
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from("/tmp/xtrain_tinystories.ckpt"));
|
||||
|
||||
// Architecture (scaling-ladder rung). Defaults = v0-baseline tiny config.
|
||||
let n_heads = flag(&args, "--heads", 2usize);
|
||||
let head_dim = flag(&args, "--head-dim", 16usize);
|
||||
let n_layers = flag(&args, "--layers", 4usize);
|
||||
let ffn = flag(&args, "--ffn", 64usize);
|
||||
// `--dim` is informational; dim is always n_heads*head_dim. Warn on mismatch.
|
||||
let dim_flag = flag(&args, "--dim", 0usize);
|
||||
if dim_flag != 0 && dim_flag != n_heads * head_dim {
|
||||
eprintln!(
|
||||
"warning: --dim {dim_flag} != heads*head_dim {}; using {}",
|
||||
n_heads * head_dim,
|
||||
n_heads * head_dim
|
||||
);
|
||||
}
|
||||
|
||||
// Optimization knobs.
|
||||
let steps: usize = flag(&args, "--steps", 2000);
|
||||
let batch_size: usize = flag(&args, "--batch", 8);
|
||||
let seq_len: usize = flag(&args, "--seq", 64);
|
||||
let max_lr: f32 = flag(&args, "--max-lr", 3e-3);
|
||||
let min_lr: f32 = flag(&args, "--min-lr", max_lr * 0.1);
|
||||
let weight_decay: f32 = flag(&args, "--wd", 0.1);
|
||||
let max_grad_norm: f32 = flag(&args, "--clip", 1.0);
|
||||
let val_tokens: usize = flag(&args, "--val-tokens", 0);
|
||||
let eval_every: usize = flag(&args, "--eval-every", 0);
|
||||
let eval_batches: usize = flag(&args, "--eval-batches", 64);
|
||||
let ckpt: PathBuf = PathBuf::from(
|
||||
args.iter()
|
||||
.position(|a| a == "--ckpt")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "/tmp/xtrain_tinystories.ckpt".to_string()),
|
||||
);
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
println!(
|
||||
"loading tokenizer {} + corpus {}",
|
||||
"loading tokenizer {} + corpus {} (cached id stream)",
|
||||
tok_path.display(),
|
||||
corpus_path.display()
|
||||
);
|
||||
let corpus = Corpus::load(&tok_path, &corpus_path);
|
||||
let corpus = Corpus::load_cached(&tok_path, &corpus_path);
|
||||
println!(
|
||||
"corpus: {} tokens, vocab {}",
|
||||
corpus.len(),
|
||||
corpus.vocab_size
|
||||
);
|
||||
let vocab = corpus.vocab_size;
|
||||
// Hold out a tail slice for validation (if requested and the corpus is big).
|
||||
let (train_corpus, valid) = if val_tokens > 0 {
|
||||
let (t, v) = corpus.split_tail(val_tokens);
|
||||
println!("split: {} train tokens / {} val tokens", t.len(), v.len());
|
||||
(t, Some(v))
|
||||
} else {
|
||||
(corpus, None)
|
||||
};
|
||||
|
||||
// Tiny model sized to the BPE vocab. A real (but small) config: wider than
|
||||
// the overfit test so it has capacity to learn English structure.
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = corpus.vocab_size;
|
||||
cfg.n_layers = 4;
|
||||
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
|
||||
println!(
|
||||
"model: dim {} layers {} heads {} ffn {} → {} params",
|
||||
"model: dim {} layers {} heads {} head_dim {} ffn {} → core {:.3}M params \
|
||||
(+ embed/lm {:.2}M = {:.2}M total)",
|
||||
cfg.dim,
|
||||
cfg.n_layers,
|
||||
cfg.n_heads,
|
||||
cfg.head_dim,
|
||||
cfg.ffn_hidden,
|
||||
cfg.num_params()
|
||||
cfg.core_params() as f32 / 1e6,
|
||||
(cfg.num_params() - cfg.core_params()) as f32 / 1e6,
|
||||
cfg.num_params() as f32 / 1e6,
|
||||
);
|
||||
|
||||
let mut seed = 1u64;
|
||||
@@ -111,33 +167,45 @@ fn main() {
|
||||
}
|
||||
});
|
||||
|
||||
let seq_len = 64;
|
||||
let tcfg = TrainConfig {
|
||||
seq_len,
|
||||
batch_size: 8,
|
||||
batch_size,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr: 3e-3,
|
||||
min_lr: 3e-4,
|
||||
max_lr,
|
||||
min_lr,
|
||||
warmup: (steps / 20).max(20),
|
||||
total: steps,
|
||||
},
|
||||
weight_decay: 0.1,
|
||||
max_grad_norm: 1.0,
|
||||
weight_decay,
|
||||
max_grad_norm,
|
||||
log_every: 50,
|
||||
ckpt_path: Some(ckpt.clone()),
|
||||
ckpt_every: 500,
|
||||
eval_every,
|
||||
eval_batches,
|
||||
seed: 42,
|
||||
};
|
||||
|
||||
println!(
|
||||
"training: {} steps, seq {}, batch {}, lr {:.1e}→{:.1e}",
|
||||
tcfg.steps, tcfg.seq_len, tcfg.batch_size, tcfg.schedule.max_lr, tcfg.schedule.min_lr
|
||||
"training: {} steps, seq {}, batch {}, lr {:.1e}→{:.1e}, eval every {}",
|
||||
tcfg.steps,
|
||||
tcfg.seq_len,
|
||||
tcfg.batch_size,
|
||||
tcfg.schedule.max_lr,
|
||||
tcfg.schedule.min_lr,
|
||||
tcfg.eval_every
|
||||
);
|
||||
let losses = train(&model, device, &corpus, &tcfg);
|
||||
let start = losses.first().copied().unwrap_or(0.0);
|
||||
let end = losses.last().copied().unwrap_or(0.0);
|
||||
println!("loss: start {start:.4} → end {end:.4}");
|
||||
let result = train(&model, device, &train_corpus, valid.as_ref(), &tcfg);
|
||||
let start = result.train_losses.first().copied().unwrap_or(0.0);
|
||||
let end = result.train_losses.last().copied().unwrap_or(0.0);
|
||||
println!("train loss: start {start:.4} → end {end:.4}");
|
||||
if let Some(best) = result.best_val {
|
||||
println!("best val loss: {best:.4}");
|
||||
}
|
||||
if let Some((s, v)) = result.evals.last() {
|
||||
println!("final val loss (step {s}): {v:.4}");
|
||||
}
|
||||
|
||||
sample_some(&model, device, &tok_path);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
//! Data pipeline: load the GPT-2 BPE (reusing xserv's from-scratch tokenizer),
|
||||
//! tokenize a text corpus into one flat token stream, and sample fixed-length
|
||||
//! `(input, target)` windows for next-token prediction. Host-only (no GPU).
|
||||
//!
|
||||
//! For the scaling runs the corpus is large (full TinyStories ≈ 2 GB / ~470 M
|
||||
//! tokens), and the from-scratch BPE is slow, so [`Corpus::load_cached`]
|
||||
//! tokenizes ONCE and caches the id stream to a `<corpus>.u16.bin` next to the
|
||||
//! text (GPT-2 vocab = 50257 < 65536, so u16 is exact). Subsequent runs mmap-read
|
||||
//! the cache instead of re-tokenizing.
|
||||
|
||||
use std::path::Path;
|
||||
use std::io::{BufReader, BufWriter, Read, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
/// A tokenized corpus: one flat stream of token ids, plus the vocab size.
|
||||
@@ -30,6 +37,54 @@ impl Corpus {
|
||||
}
|
||||
}
|
||||
|
||||
/// Like [`load`](Self::load) but caches the tokenized id stream to
|
||||
/// `<corpus_path>.u16.bin`. On the first run it tokenizes the (large) corpus
|
||||
/// and writes the cache; on later runs it reads the cache directly, skipping
|
||||
/// the slow BPE. The cache is just a flat little-endian `[u16]` (no header) —
|
||||
/// it is keyed only by path, so delete it if the corpus or tokenizer changes.
|
||||
pub fn load_cached(tokenizer_path: &Path, corpus_path: &Path) -> Self {
|
||||
let cache = cache_path(corpus_path);
|
||||
let vocab_size = Tokenizer::from_file(tokenizer_path).vocab_size();
|
||||
if cache.exists() {
|
||||
let tokens = read_u16_cache(&cache);
|
||||
println!(
|
||||
"corpus: read {} cached tokens from {}",
|
||||
tokens.len(),
|
||||
cache.display()
|
||||
);
|
||||
return Self { tokens, vocab_size };
|
||||
}
|
||||
let me = Self::load(tokenizer_path, corpus_path);
|
||||
write_u16_cache(&cache, &me.tokens);
|
||||
println!(
|
||||
"corpus: tokenized {} tokens → cached to {}",
|
||||
me.tokens.len(),
|
||||
cache.display()
|
||||
);
|
||||
me
|
||||
}
|
||||
|
||||
/// Split off the last `n` tokens as a held-out validation corpus, leaving the
|
||||
/// rest as the train corpus. Returns `(train, valid)`. Used for periodic val
|
||||
/// loss during training without leaking the eval window into training.
|
||||
pub fn split_tail(self, n: usize) -> (Self, Self) {
|
||||
let n = n.min(self.tokens.len() / 10); // never hand off more than 10%
|
||||
let cut = self.tokens.len() - n;
|
||||
let valid = self.tokens[cut..].to_vec();
|
||||
let mut train = self.tokens;
|
||||
train.truncate(cut);
|
||||
(
|
||||
Self {
|
||||
tokens: train,
|
||||
vocab_size: self.vocab_size,
|
||||
},
|
||||
Self {
|
||||
tokens: valid,
|
||||
vocab_size: self.vocab_size,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Total number of tokens.
|
||||
pub fn len(&self) -> usize {
|
||||
self.tokens.len()
|
||||
@@ -65,6 +120,40 @@ fn trim_to_whole_stories(text: &str) -> &str {
|
||||
}
|
||||
}
|
||||
|
||||
/// `<corpus_path>.u16.bin` — the token-id cache beside the corpus text.
|
||||
fn cache_path(corpus_path: &Path) -> PathBuf {
|
||||
let mut s = corpus_path.as_os_str().to_os_string();
|
||||
s.push(".u16.bin");
|
||||
PathBuf::from(s)
|
||||
}
|
||||
|
||||
/// Read a flat little-endian `[u16]` cache into an `i32` id stream.
|
||||
fn read_u16_cache(path: &Path) -> Vec<i32> {
|
||||
let mut r = BufReader::new(
|
||||
std::fs::File::open(path).unwrap_or_else(|e| panic!("open cache {}: {e}", path.display())),
|
||||
);
|
||||
let mut buf = Vec::new();
|
||||
r.read_to_end(&mut buf).expect("read cache");
|
||||
assert!(buf.len() % 2 == 0, "corrupt u16 cache (odd byte count)");
|
||||
buf.chunks_exact(2)
|
||||
.map(|b| u16::from_le_bytes([b[0], b[1]]) as i32)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Write an id stream as a flat little-endian `[u16]` cache. Ids must fit in u16
|
||||
/// (GPT-2 vocab = 50257 < 65536); asserts otherwise.
|
||||
fn write_u16_cache(path: &Path, tokens: &[i32]) {
|
||||
let mut w = BufWriter::new(
|
||||
std::fs::File::create(path)
|
||||
.unwrap_or_else(|e| panic!("create cache {}: {e}", path.display())),
|
||||
);
|
||||
for &t in tokens {
|
||||
assert!((0..=u16::MAX as i32).contains(&t), "token id {t} > u16");
|
||||
w.write_all(&(t as u16).to_le_bytes()).expect("write cache");
|
||||
}
|
||||
w.flush().expect("flush cache");
|
||||
}
|
||||
|
||||
/// Tiny LCG (same constants as the model tests' deterministic fill) so dataset
|
||||
/// sampling is reproducible from a single u64 seed.
|
||||
fn next_rand(state: &mut u64) -> u64 {
|
||||
|
||||
@@ -19,4 +19,4 @@ pub mod sample;
|
||||
mod train_loop;
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
pub use train_loop::{TrainConfig, train};
|
||||
pub use train_loop::{TrainConfig, TrainResult, train};
|
||||
|
||||
@@ -31,28 +31,47 @@ pub struct TrainConfig {
|
||||
pub max_grad_norm: f32,
|
||||
pub log_every: usize,
|
||||
/// Optional checkpoint path written every `ckpt_every` steps (and at the end).
|
||||
/// When `eval_every > 0`, the checkpoint instead tracks the BEST val loss.
|
||||
pub ckpt_path: Option<PathBuf>,
|
||||
pub ckpt_every: usize,
|
||||
/// Evaluate held-out val loss every `eval_every` steps (0 = never). Each eval
|
||||
/// averages cross-entropy over `eval_batches` fixed windows of the val corpus.
|
||||
pub eval_every: usize,
|
||||
pub eval_batches: usize,
|
||||
/// Seed for reproducible sequence sampling.
|
||||
pub seed: u64,
|
||||
}
|
||||
|
||||
/// Outcome of a run: per-step train losses and (step, val_loss) eval points.
|
||||
pub struct TrainResult {
|
||||
pub train_losses: Vec<f32>,
|
||||
pub evals: Vec<(usize, f32)>,
|
||||
pub best_val: Option<f32>,
|
||||
}
|
||||
|
||||
/// Train `model` on `corpus` for `cfg.steps` AdamW steps. Returns the per-step
|
||||
/// loss trace (one mean loss per step, read from the first sequence of the
|
||||
/// batch — cheap and representative). Logs progress and checkpoints as configured.
|
||||
/// train-loss trace plus any (step, val_loss) eval points. Logs progress, and —
|
||||
/// when `valid` is given and `cfg.eval_every > 0` — evaluates held-out val loss
|
||||
/// periodically and checkpoints the BEST val model (else checkpoints on a fixed
|
||||
/// cadence, as in T6). Logs progress.
|
||||
pub fn train(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
corpus: &Corpus,
|
||||
valid: Option<&Corpus>,
|
||||
cfg: &TrainConfig,
|
||||
) -> Vec<f32> {
|
||||
) -> TrainResult {
|
||||
let params = model.params();
|
||||
let mut opt = GpuAdamW::new(cfg.weight_decay);
|
||||
let mut rng = cfg.seed;
|
||||
let mut losses = Vec::with_capacity(cfg.steps);
|
||||
let mut evals = Vec::new();
|
||||
let mut best_val: Option<f32> = None;
|
||||
let inv_batch = 1.0 / cfg.batch_size as f32;
|
||||
let start = Instant::now();
|
||||
let mut tokens_seen: u64 = 0;
|
||||
// Best-val checkpointing only kicks in when we actually evaluate.
|
||||
let track_best = valid.is_some() && cfg.eval_every > 0;
|
||||
|
||||
for step in 0..cfg.steps {
|
||||
let lr = cfg.schedule.lr(step);
|
||||
@@ -88,18 +107,86 @@ pub fn train(
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(path) = &cfg.ckpt_path {
|
||||
if cfg.ckpt_every > 0 && (step + 1) % cfg.ckpt_every == 0 {
|
||||
checkpoint::save(path, ¶ms).expect("checkpoint save");
|
||||
// Periodic held-out eval (deterministic windows, no grad).
|
||||
if let Some(v) = valid {
|
||||
if cfg.eval_every > 0 && ((step + 1) % cfg.eval_every == 0 || step == cfg.steps - 1) {
|
||||
let vl = eval_loss(model, device, v, cfg.seq_len, cfg.eval_batches);
|
||||
evals.push((step, vl));
|
||||
let improved = best_val.map(|b| vl < b).unwrap_or(true);
|
||||
println!(
|
||||
" eval @ step {step}: val loss {vl:.4}{}",
|
||||
if improved { " (best)" } else { "" }
|
||||
);
|
||||
if improved {
|
||||
best_val = Some(vl);
|
||||
if let Some(path) = &cfg.ckpt_path {
|
||||
checkpoint::save(path, ¶ms).expect("best checkpoint save");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fixed-cadence checkpointing (only when not tracking best val).
|
||||
if !track_best {
|
||||
if let Some(path) = &cfg.ckpt_path {
|
||||
if cfg.ckpt_every > 0 && (step + 1) % cfg.ckpt_every == 0 {
|
||||
checkpoint::save(path, ¶ms).expect("checkpoint save");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(path) = &cfg.ckpt_path {
|
||||
checkpoint::save(path, ¶ms).expect("final checkpoint save");
|
||||
println!("saved checkpoint → {}", path.display());
|
||||
// Without periodic eval, still persist the final params (T6 behaviour). With
|
||||
// best-val tracking the checkpoint already holds the best model — don't clobber.
|
||||
if !track_best {
|
||||
if let Some(path) = &cfg.ckpt_path {
|
||||
checkpoint::save(path, ¶ms).expect("final checkpoint save");
|
||||
println!("saved checkpoint → {}", path.display());
|
||||
}
|
||||
}
|
||||
TrainResult {
|
||||
train_losses: losses,
|
||||
evals,
|
||||
best_val,
|
||||
}
|
||||
}
|
||||
|
||||
/// Mean cross-entropy over `batches` deterministic, non-overlapping windows of
|
||||
/// the validation corpus (no backward — eval only). Deterministic so val loss is
|
||||
/// comparable across steps and runs.
|
||||
fn eval_loss(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
valid: &Corpus,
|
||||
seq: usize,
|
||||
batches: usize,
|
||||
) -> f32 {
|
||||
if valid.len() <= seq + 1 {
|
||||
return f32::NAN;
|
||||
}
|
||||
let n_win = (valid.len() - 1) / seq; // disjoint windows that fit
|
||||
let batches = batches.max(1).min(n_win.max(1));
|
||||
let stride = (n_win / batches).max(1);
|
||||
let mut sum = 0.0f32;
|
||||
let mut count = 0usize;
|
||||
for i in 0..batches {
|
||||
let s = (i * stride) * seq;
|
||||
if s + seq + 1 > valid.len() {
|
||||
break;
|
||||
}
|
||||
let input: Vec<i32> = valid.tokens[s..s + seq].to_vec();
|
||||
let target: Vec<i32> = valid.tokens[s + 1..s + seq + 1].to_vec();
|
||||
let ids = ids_tensor(&input, device);
|
||||
let targets = ids_tensor(&target, device);
|
||||
let loss = model.loss(&ids, &targets);
|
||||
sum += read_scalar(&loss);
|
||||
count += 1;
|
||||
}
|
||||
if count == 0 {
|
||||
f32::NAN
|
||||
} else {
|
||||
sum / count as f32
|
||||
}
|
||||
losses
|
||||
}
|
||||
|
||||
fn read_scalar(v: &xtrain_autodiff::tape::Var) -> f32 {
|
||||
|
||||
@@ -96,10 +96,12 @@ fn trains_on_tinystories() {
|
||||
log_every: 50,
|
||||
ckpt_path: None,
|
||||
ckpt_every: 0,
|
||||
eval_every: 0,
|
||||
eval_batches: 0,
|
||||
seed: 42,
|
||||
};
|
||||
|
||||
let losses = train(&model, device, &corpus, &tcfg);
|
||||
let losses = train(&model, device, &corpus, None, &tcfg).train_losses;
|
||||
// Average the first/last few steps to smooth per-step noise.
|
||||
let head: f32 =
|
||||
losses[..10.min(losses.len())].iter().sum::<f32>() / 10.0_f32.min(losses.len() as f32);
|
||||
|
||||
Reference in New Issue
Block a user