export: dump_logits bin for xserv-vs-xtrain comparison
xtrain-side top-k next-token logit dump (f32 forward, same model/config/ckpt as the exporter) mirroring xserv's dump-logits, so the closed-loop check can compare both sides numerically for the same prompt + weights. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -25,3 +25,7 @@ path = "src/bin/train.rs"
|
||||
[[bin]]
|
||||
name = "export_safetensors"
|
||||
path = "src/bin/export_safetensors.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "dump_logits"
|
||||
path = "src/bin/dump_logits.rs"
|
||||
|
||||
98
crates/xtrain-train/src/bin/dump_logits.rs
Normal file
98
crates/xtrain-train/src/bin/dump_logits.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
//! Phase T9 verification helper — dump xtrain's OWN top-k next-token logits for a
|
||||
//! prompt, so they can be compared against xserv's `dump-logits` on the exported
|
||||
//! model (the closed-loop acceptance check). f32 forward, same model/config/ckpt
|
||||
//! as bin/train.rs + bin/export_safetensors.rs.
|
||||
//!
|
||||
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
|
||||
//! cargo run -p xtrain-train --release --bin dump_logits -- \
|
||||
//! /tmp/xtrain_tinystories.ckpt /opt/wjh/models/gpt2/tokenizer.json "Once upon a time"
|
||||
|
||||
#[cfg(no_cuda)]
|
||||
fn main() {
|
||||
eprintln!("dump_logits: built without CUDA (no_cuda); run on a GPU host (dash5).");
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_cuda::device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_model::{Config, TinyTransformer, ids_tensor};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_tensor::Device;
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let ckpt = args
|
||||
.get(1)
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from("/tmp/xtrain_tinystories.ckpt"));
|
||||
let tok_path = args
|
||||
.get(2)
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
|
||||
let prompt = args
|
||||
.get(3)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "Once upon a time".to_string());
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let dev = Device::Cuda(0);
|
||||
|
||||
let tok = Tokenizer::from_file(&tok_path);
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = tok.vocab_size();
|
||||
cfg.n_layers = 4;
|
||||
|
||||
let mut seed = 1u64;
|
||||
let model = TinyTransformer::new(cfg, dev, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.04)
|
||||
}
|
||||
});
|
||||
xtrain_train::checkpoint::load_into(&ckpt, &model.params()).expect("load checkpoint");
|
||||
|
||||
let ids: Vec<i32> = tok.encode(&prompt).into_iter().map(|t| t as i32).collect();
|
||||
eprintln!("Prompt: {prompt}");
|
||||
eprintln!("Token IDs: {ids:?}");
|
||||
|
||||
let logits = model
|
||||
.forward(&ids_tensor(&ids, dev))
|
||||
.value()
|
||||
.to_device(Device::Cpu);
|
||||
let lg = logits.as_slice::<f32>();
|
||||
let vocab = cfg.vocab;
|
||||
let last = &lg[(ids.len() - 1) * vocab..ids.len() * vocab];
|
||||
|
||||
let mut idx: Vec<(usize, f32)> = last.iter().copied().enumerate().collect();
|
||||
idx.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
println!("Top-20 logits (last position):");
|
||||
for (rank, (id, val)) in idx.iter().take(20).enumerate() {
|
||||
let t = tok.decode(&[*id as u32]);
|
||||
println!(" [{rank:>2}] id={id:>6} logit={val:>10.4} token={t:?}");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user