diff --git a/crates/xtrain-train/tests/real_training.rs b/crates/xtrain-train/tests/real_training.rs index 5d6a4e6..70e186c 100644 --- a/crates/xtrain-train/tests/real_training.rs +++ b/crates/xtrain-train/tests/real_training.rs @@ -46,9 +46,15 @@ fn trains_on_tinystories() { std::env::var("XTRAIN_TOKENIZER") .unwrap_or_else(|_| "/opt/wjh/models/gpt2/tokenizer.json".into()), ); - let corpus_path = PathBuf::from( - std::env::var("XTRAIN_CORPUS").unwrap_or_else(|_| "data/tinystories-valid-3mb.txt".into()), - ); + // Default resolves relative to the repo root (cargo runs tests with cwd = + // crate dir, so `../../data/...` from crates/xtrain-train); override with + // XTRAIN_CORPUS for any other location. + let corpus_path = PathBuf::from(std::env::var("XTRAIN_CORPUS").unwrap_or_else(|_| { + format!( + "{}/../../data/tinystories-valid-3mb.txt", + env!("CARGO_MANIFEST_DIR") + ) + })); let corpus = Corpus::load(&tok_path, &corpus_path); println!(