xserv's Qwen3 forward unconditionally applies per-head RMSNorm to Q and K (q_norm/k_norm, shape [head_dim]) before RoPE — even gamma=1 is a real RMS divide, not identity. xtrain never had this, so an exact xserv<->xtrain loop was structurally impossible. Add it (reusing the 2D rms_norm op on the [seq*nh, hd] head rows, inserted between reshape and rope to mirror qwen3.rs's order) so the trained model is genuinely Qwen3-compatible. params() inserts q_norm,k_norm after wv; num_params() counts them; the PyTorch parity refs (parity.py / adamw_parity.py) + their name lists add the same step so the dumps stay self-consistent. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
174 lines
6.1 KiB
Rust
174 lines
6.1 KiB
Rust
// AdamW-vs-PyTorch parity, step 1 of 2: build the tiny transformer with a fixed
|
|
// deterministic init, then run N steps of the hand-written AdamW on a FIXED
|
|
// (input, target) batch — recording the loss at each step and the final
|
|
// parameters. tests/adamw_parity.py rebuilds the identical model + torch.optim
|
|
// .AdamW with matched hyperparameters and compares the loss trajectory and final
|
|
// params within rtol. This is the rigorous correctness check for the optimizer.
|
|
//
|
|
// Run: XTRAIN_ADAMW_DIR=/tmp/xtrain_adamw cargo test -p xtrain-train \
|
|
// --test adamw_parity_dump -- --nocapture --ignored
|
|
// then: python3 crates/xtrain-train/tests/adamw_parity.py /tmp/xtrain_adamw
|
|
//
|
|
// Marked #[ignore] (fixture generator) and gated #![cfg(not(no_cuda))].
|
|
#![cfg(not(no_cuda))]
|
|
|
|
use std::fs;
|
|
use std::io::Write;
|
|
use std::path::PathBuf;
|
|
|
|
use xtrain_cuda::device;
|
|
use xtrain_model::{Config, TinyTransformer, ids_tensor, param_to_host};
|
|
use xtrain_optim::AdamW;
|
|
use xtrain_tensor::Device;
|
|
|
|
/// Produces `n` deterministic pseudo-random `f32` values derived from `seed`.
///
/// A 64-bit linear congruential generator is seeded by scrambling `seed`
/// once, then advanced once per output value. The top 31 bits of each state
/// are mapped to the unit interval, re-centred around zero and multiplied by
/// `scale`, so every value has magnitude at most `scale`.
///
/// The same `(n, seed, scale)` always yields the same vector; `n == 0`
/// yields an empty vector.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // One-off scramble applied to the caller's seed.
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    // Per-sample LCG step.
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;

    let mut lcg = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
        // Keep the high 31 bits and normalise by 2^31.
        let unit = (lcg >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
|
|
|
|
/// Writes `data` to the file `dir/name` as plain text.
///
/// The file starts with a `# shape d0,d1,...` header built from `shape`,
/// followed by one value per line formatted as `{:.8e}`.
///
/// Writes go through a `BufWriter` so a large tensor does not cost one
/// syscall per value, and the buffer is flushed explicitly so a failed flush
/// panics here instead of being silently dropped when the writer goes out of
/// scope.
///
/// `dir` is taken as `&Path`; existing callers passing `&PathBuf` keep
/// working through deref coercion.
///
/// # Panics
///
/// Panics if the file cannot be created or if any write or the final flush
/// fails.
fn write_vec(dir: &std::path::Path, name: &str, data: &[f32], shape: &[usize]) {
    let file = fs::File::create(dir.join(name)).unwrap();
    let mut out = std::io::BufWriter::new(file);

    let dims: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
    writeln!(out, "# shape {}", dims.join(",")).unwrap();
    for v in data {
        writeln!(out, "{v:.8e}").unwrap();
    }

    out.flush().unwrap();
}
|
|
|
|
/// Learning rate: passed to `AdamW::new` and to every `opt.step` call, and
/// written to `config.txt` as `lr`.
const LR: f32 = 0.01;

/// Second `AdamW::new` argument, written to `config.txt` as `wd`.
const WD: f32 = 0.1;

/// Number of optimizer steps to run.
///
/// Deliberately small. AdamW correctness is visible in the per-step loss
/// trajectory and in the parameter values *while the loss is still
/// well-determined*. Train long enough to memorise the tiny batch and the
/// model lands in a flat, overparameterised region where many weight
/// configurations give the same loss: the f32 GPU run and the torch reference
/// then drift apart per weight (large *relative* error on tiny weights) even
/// though the loss stays identical. Ten steps keeps both signals sharp.
const N_STEPS: usize = 10;
|
|
|
|
/// Fixture generator for the AdamW-vs-PyTorch parity check.
///
/// Builds the tiny transformer with a deterministic init, runs `N_STEPS`
/// AdamW steps at a constant learning rate on one fixed `(ids, targets)`
/// batch, and writes the following files into `$XTRAIN_ADAMW_DIR`
/// (default `/tmp/xtrain_adamw`):
///
/// - `config.txt` — model config fields plus `lr`, `wd`, `n_steps`
/// - `ids.txt`, `targets.txt` — the fixed batch, one id per line
/// - `w0_<name>.txt` — every parameter before training
/// - `losses.txt` — the loss evaluated at the start of each step
/// - `wN_<name>.txt` — every parameter after the last step
///
/// # Panics
///
/// Panics if no CUDA device is available, if any file operation fails, or if
/// `param_names` and `model.params()` disagree on the number of parameters.
#[test]
#[ignore = "fixture generator for AdamW PyTorch parity; run with --ignored"]
fn dump_adamw_trajectory() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let dir = PathBuf::from(
        std::env::var("XTRAIN_ADAMW_DIR").unwrap_or_else(|_| "/tmp/xtrain_adamw".to_string()),
    );
    fs::create_dir_all(&dir).unwrap();

    let mut cfg = Config::tiny();
    cfg.vocab = 12;
    // Fixed batch: `targets` is `ids` shifted left by one, with a trailing 0.
    let ids: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6];
    let targets: Vec<i32> = vec![1, 4, 1, 5, 9, 2, 6, 0];

    // Same deterministic init the parity dump uses (so the torch side can reuse it).
    // The seed is bumped before every tensor, so each parameter gets its own stream.
    let mut seed = 1u64;
    let model = TinyTransformer::new(cfg, device, |shape| {
        seed = seed.wrapping_add(1);
        let n: usize = shape.iter().product();
        if shape.len() == 1 {
            // 1-D tensors are initialised close to 1.0 (1 ± 0.02).
            fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
        } else {
            fill(n, seed, 0.08)
        }
    });

    // Dump config + ids + initial params (named for adamw_parity.py).
    {
        let mut f = fs::File::create(dir.join("config.txt")).unwrap();
        writeln!(f, "vocab {}", cfg.vocab).unwrap();
        writeln!(f, "dim {}", cfg.dim).unwrap();
        writeln!(f, "n_layers {}", cfg.n_layers).unwrap();
        writeln!(f, "n_heads {}", cfg.n_heads).unwrap();
        writeln!(f, "head_dim {}", cfg.head_dim).unwrap();
        writeln!(f, "ffn_hidden {}", cfg.ffn_hidden).unwrap();
        writeln!(f, "eps {:e}", cfg.eps).unwrap();
        writeln!(f, "rope_theta {:e}", cfg.rope_theta).unwrap();
        writeln!(f, "lr {LR:e}").unwrap();
        writeln!(f, "wd {WD:e}").unwrap();
        writeln!(f, "n_steps {N_STEPS}").unwrap();
        let mut g = fs::File::create(dir.join("ids.txt")).unwrap();
        for v in &ids {
            writeln!(g, "{v}").unwrap();
        }
        let mut g = fs::File::create(dir.join("targets.txt")).unwrap();
        for v in &targets {
            writeln!(g, "{v}").unwrap();
        }
    }

    let names = param_names(&cfg);
    let params = model.params();
    // `zip` stops at the shorter side, so a stale name list would silently
    // skip or mislabel parameters in the dump. Fail loudly instead.
    assert_eq!(
        names.len(),
        params.len(),
        "param_names() is out of sync with TinyTransformer::params()"
    );
    for (name, p) in names.iter().zip(&params) {
        let shape = p.value().shape().to_vec();
        write_vec(&dir, &format!("w0_{name}.txt"), &param_to_host(p), &shape);
    }

    // Train N steps of AdamW with a CONSTANT lr (no schedule) on the fixed batch.
    let ids_t = ids_tensor(&ids, device);
    let targets_t = ids_tensor(&targets, device);
    let mut opt = AdamW::new(LR, WD);
    let mut losses = Vec::with_capacity(N_STEPS);
    for _ in 0..N_STEPS {
        let loss = model.loss(&ids_t, &targets_t);
        // Record the loss before this step's update is applied.
        losses.push(param_to_host(&loss)[0]);
        loss.backward();
        opt.step(LR, &params);
        for p in &params {
            p.zero_grad();
        }
    }

    {
        let mut f = fs::File::create(dir.join("losses.txt")).unwrap();
        for l in &losses {
            writeln!(f, "{l:.8e}").unwrap();
        }
    }
    for (name, p) in names.iter().zip(&params) {
        let shape = p.value().shape().to_vec();
        write_vec(&dir, &format!("wN_{name}.txt"), &param_to_host(p), &shape);
    }

    println!(
        "adamw parity: dumped to {} (loss {:.6e} → {:.6e} over {N_STEPS} steps)",
        dir.display(),
        losses.first().unwrap(),
        losses.last().unwrap()
    );
}
|
|
|
|
fn param_names(cfg: &Config) -> Vec<String> {
|
|
let mut names = vec!["embed".to_string()];
|
|
for l in 0..cfg.n_layers {
|
|
for p in [
|
|
"attn_norm",
|
|
"wq",
|
|
"wk",
|
|
"wv",
|
|
"q_norm",
|
|
"k_norm",
|
|
"wo",
|
|
"ffn_norm",
|
|
"w_gate",
|
|
"w_up",
|
|
"w_down",
|
|
] {
|
|
names.push(format!("l{l}_{p}"));
|
|
}
|
|
}
|
|
names.push("final_norm".to_string());
|
|
names.push("lm_head".to_string());
|
|
names
|
|
}
|