train: --eval-ckpt eval-only mode (v0-vs-v1 same-set val loss)
Expose eval_loss() and add a --eval-ckpt <path> branch to bin/train: load an existing checkpoint into a model of the given arch and score it on the held-out val split, then exit. Lets v0 and v1 be measured on the identical validation set (the acceptance metric) without a separate eval binary. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -167,6 +167,20 @@ fn main() {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Eval-only mode: load a checkpoint and score it on the held-out val set, then
|
||||||
|
// exit. Used to put an EXISTING model (e.g. v0) and a new one on the same
|
||||||
|
// metric — the v0-vs-v1 val-loss comparison. The arch flags must match the ckpt.
|
||||||
|
if let Some(p) = args.iter().position(|a| a == "--eval-ckpt") {
|
||||||
|
let ckpt_path = PathBuf::from(args.get(p + 1).expect("--eval-ckpt <path>"));
|
||||||
|
xtrain_train::checkpoint::load_into(&ckpt_path, &model.params())
|
||||||
|
.expect("load eval checkpoint");
|
||||||
|
let v = valid.expect("--eval-ckpt needs --val-tokens > 0");
|
||||||
|
let vl = xtrain_train::eval_loss(&model, device, &v, seq_len, eval_batches);
|
||||||
|
println!("eval-only: {} → val loss {vl:.4}", ckpt_path.display());
|
||||||
|
sample_some(&model, device, &tok_path);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
let tcfg = TrainConfig {
|
let tcfg = TrainConfig {
|
||||||
seq_len,
|
seq_len,
|
||||||
batch_size,
|
batch_size,
|
||||||
|
|||||||
@@ -19,4 +19,4 @@ pub mod sample;
|
|||||||
mod train_loop;
|
mod train_loop;
|
||||||
|
|
||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
pub use train_loop::{TrainConfig, TrainResult, train};
|
pub use train_loop::{TrainConfig, TrainResult, eval_loss, train};
|
||||||
|
|||||||
@@ -153,8 +153,8 @@ pub fn train(
|
|||||||
|
|
||||||
/// Mean cross-entropy over `batches` deterministic, non-overlapping windows of
|
/// Mean cross-entropy over `batches` deterministic, non-overlapping windows of
|
||||||
/// the validation corpus (no backward — eval only). Deterministic so val loss is
|
/// the validation corpus (no backward — eval only). Deterministic so val loss is
|
||||||
/// comparable across steps and runs.
|
/// comparable across steps and runs (and across models — the v0-vs-v1 metric).
|
||||||
fn eval_loss(
|
pub fn eval_loss(
|
||||||
model: &TinyTransformer,
|
model: &TinyTransformer,
|
||||||
device: Device,
|
device: Device,
|
||||||
valid: &Corpus,
|
valid: &Corpus,
|
||||||
|
|||||||
Reference in New Issue
Block a user