Files
xtrain/crates/xtrain-model/tests/parity_dump.rs
Gahow Wang 7a4f69e430 model: add per-head QK-norm (Qwen3-compat) for xserv export
xserv's Qwen3 forward unconditionally applies per-head RMSNorm to Q and K
(q_norm/k_norm, shape [head_dim]) before RoPE — even gamma=1 is a real RMS
divide, not identity. xtrain never had this, so an exact xserv<->xtrain loop
was structurally impossible. Add it (reusing the 2D rms_norm op on the
[seq*nh, hd] head rows, inserted between reshape and rope to mirror
qwen3.rs's order) so the trained model is genuinely Qwen3-compatible.

params() inserts q_norm,k_norm after wv; num_params() counts them; the
PyTorch parity refs (parity.py / adamw_parity.py) + their name lists add the
same step so the dumps stay self-consistent.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 17:33:19 +08:00

163 lines
5.3 KiB
Rust

// PyTorch parity, step 1 of 2: dump the Rust tiny-transformer's exact weights,
// inputs, forward logits, loss, and per-parameter gradients (after one backward)
// to a directory, so an equivalent PyTorch model (tests/parity.py) can be built
// from the SAME weights and the forward + grads compared within rtol.
//
// Run: XTRAIN_PARITY_DIR=/tmp/xtrain_parity cargo test -p xtrain-model \
// --test parity_dump -- --nocapture --ignored
// then: python3 crates/xtrain-model/tests/parity.py /tmp/xtrain_parity
//
// Marked #[ignore] (it's a fixture generator, not a pass/fail assertion) and
// gated #![cfg(not(no_cuda))].
#![cfg(not(no_cuda))]
use std::fs;
use std::io::Write;
use std::path::PathBuf;
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, ids_tensor, param_to_host};
use xtrain_tensor::Device;
/// Produces `n` deterministic pseudo-random values in `[-scale, scale]`.
///
/// The generator is a 64-bit linear congruential sequence: `seed` is first
/// scrambled into an initial state, then every output advances the state once
/// and maps its top 31 bits onto the unit interval before centering and
/// scaling. The same `(n, seed, scale)` always yields the same vector.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // Scramble the seed so that small consecutive seeds start far apart.
    let mut lcg = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Keep the high 31 bits (the low bits of an LCG are the weakest) and
        // normalize by 2^31.
        let unit = (lcg >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
/// Writes `data` to the file `name` inside `dir` as plain text.
///
/// The first line is a `# shape d0,d1,...` header built from `shape`; every
/// following line holds one value of `data` in scientific notation with eight
/// fractional digits. An existing file of the same name is overwritten.
///
/// Output goes through a `BufWriter` so large tensors do not cost one write
/// syscall per element, and the buffer is flushed explicitly so a late write
/// error panics here instead of being swallowed on drop.
///
/// # Panics
///
/// Panics if the file cannot be created, written, or flushed.
fn write_vec(dir: &std::path::Path, name: &str, data: &[f32], shape: &[usize]) {
    let mut out = std::io::BufWriter::new(fs::File::create(dir.join(name)).unwrap());
    let dims: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
    writeln!(out, "# shape {}", dims.join(",")).unwrap();
    for v in data {
        writeln!(out, "{v:.8e}").unwrap();
    }
    out.flush().unwrap();
}
// Fixture generator: builds the tiny model from deterministic weights, then
// writes config, token ids, weights, forward logits, loss and per-parameter
// gradients as text files for the PyTorch side to load and compare against.
#[test]
#[ignore = "fixture generator for PyTorch parity; run with --ignored"]
fn dump_for_parity() {
    // A CUDA device is required; fail loudly when none is present.
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    // Output directory: XTRAIN_PARITY_DIR when set, otherwise a fixed default.
    let out_dir = match std::env::var("XTRAIN_PARITY_DIR") {
        Ok(path) => PathBuf::from(path),
        Err(_) => PathBuf::from("/tmp/xtrain_parity".to_string()),
    };
    fs::create_dir_all(&out_dir).unwrap();

    // Fixed config + ids (independent of any text, for reproducibility).
    let mut cfg = Config::tiny();
    cfg.vocab = 12;
    let input_ids: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6];
    let target_ids: Vec<i32> = vec![1, 4, 1, 5, 9, 2, 6, 0];
    let seq_len = input_ids.len();

    // Deterministic init (intended to match the overfit test): the seed is
    // bumped once per requested tensor; 1-D shapes get small noise around 1.0,
    // every other shape gets noise scaled by 0.08.
    let mut init_seed = 1u64;
    let model = TinyTransformer::new(cfg, device, |shape| {
        init_seed = init_seed.wrapping_add(1);
        let count: usize = shape.iter().product();
        match shape.len() {
            1 => fill(count, init_seed, 0.02)
                .iter()
                .map(|v| v + 1.0)
                .collect(),
            _ => fill(count, init_seed, 0.08),
        }
    });

    // config.txt: one "key value" pair per line.
    {
        let config_lines = [
            format!("vocab {}", cfg.vocab),
            format!("dim {}", cfg.dim),
            format!("n_layers {}", cfg.n_layers),
            format!("n_heads {}", cfg.n_heads),
            format!("head_dim {}", cfg.head_dim),
            format!("ffn_hidden {}", cfg.ffn_hidden),
            format!("eps {:e}", cfg.eps),
            format!("rope_theta {:e}", cfg.rope_theta),
            format!("seq {}", seq_len),
        ];
        let mut config_file = fs::File::create(out_dir.join("config.txt")).unwrap();
        for line in &config_lines {
            writeln!(config_file, "{line}").unwrap();
        }
    }

    // ids.txt / targets.txt: one token id per line.
    for (file_name, tokens) in [("ids.txt", &input_ids), ("targets.txt", &target_ids)] {
        let mut token_file = fs::File::create(out_dir.join(file_name)).unwrap();
        for token in tokens {
            writeln!(token_file, "{token}").unwrap();
        }
    }

    // Weights, in the stable param order, named to match parity.py.
    let labels = param_names(&cfg);
    let tensors = model.params();
    assert_eq!(labels.len(), tensors.len(), "param name/count mismatch");
    for (label, param) in labels.iter().zip(&tensors) {
        let dims = param.value().shape().to_vec();
        let file_name = format!("w_{label}.txt");
        write_vec(&out_dir, &file_name, &param_to_host(param), &dims);
    }

    // Forward pass: dump the logits first, then the scalar loss.
    let ids_t = ids_tensor(&input_ids, device);
    let targets_t = ids_tensor(&target_ids, device);
    let logits = model.forward(&ids_t);
    write_vec(
        &out_dir,
        "logits.txt",
        &param_to_host(&logits),
        logits.value().shape(),
    );
    let loss = model.loss(&ids_t, &targets_t);
    let loss_scalar = param_to_host(&loss)[0];
    {
        let mut loss_file = fs::File::create(out_dir.join("loss.txt")).unwrap();
        writeln!(loss_file, "{:.8e}", loss_scalar).unwrap();
    }

    // Backward pass, then one gradient file per parameter, copied to the CPU
    // before being written out.
    loss.backward();
    for (label, param) in labels.iter().zip(&tensors) {
        let grad = param.grad().expect("param has no grad");
        let grad_host = grad.to_device(Device::Cpu);
        let file_name = format!("g_{label}.txt");
        write_vec(
            &out_dir,
            &file_name,
            grad_host.as_slice::<f32>(),
            grad.shape(),
        );
    }

    println!(
        "parity: dumped to {} (loss={:.6e})",
        out_dir.display(),
        loss_scalar
    );
}
/// Builds the ordered list of parameter names used for the dump file names.
///
/// The order is `embed`, then for each layer index `l` in `0..cfg.n_layers`
/// the per-layer names prefixed with `l{l}_`, then `final_norm` and
/// `lm_head`. The caller zips this list against `model.params()`, so the
/// order here has to line up with that one.
fn param_names(cfg: &Config) -> Vec<String> {
    // Per-layer parameter suffixes, in dump order.
    const LAYER_PARAMS: &[&str] = &[
        "attn_norm",
        "wq",
        "wk",
        "wv",
        "q_norm",
        "k_norm",
        "wo",
        "ffn_norm",
        "w_gate",
        "w_up",
        "w_down",
    ];

    let per_layer = (0..cfg.n_layers).flat_map(|layer| {
        LAYER_PARAMS
            .iter()
            .map(move |suffix| format!("l{layer}_{suffix}"))
    });

    std::iter::once("embed".to_string())
        .chain(per_layer)
        .chain(["final_norm".to_string(), "lm_head".to_string()])
        .collect()
}