- repeat_kv CUDA kernel: fwd head-block gather, bwd DETERMINISTIC group-sum (each kv head sums its group of query-head grads; no atomics) + Tensor/ops node. - Config gains num_kv_heads (default = n_heads → MHA); wk/wv project to kv_dim; attention() repeat_kv-broadcasts K/V to nh heads before the UNCHANGED composed & flash SDPA → GQA on both paths. group=1 is identity → MHA bit-identical. - --kv-heads flag on train/train_ddp/export_safetensors/greedy_sample; export writes real num_key_value_heads (xserv repeat_kv grouping aligned). - Tests: repeat_kv grad-check (group>1 grad-sum + group=1 identity); model gqa.rs (GQA flash==composed fp32/bf16, group=1 bit-identical to MHA, kv-proj shape); parity_dump+parity.py GQA path (repeat_interleave) via XTRAIN_PARITY_KV_HEADS. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
191 lines
6.9 KiB
Rust
191 lines
6.9 KiB
Rust
// PyTorch parity, step 1 of 2: dump the Rust tiny-transformer's exact weights,
// inputs, forward logits, loss, and per-parameter gradients (after one backward)
// to a directory, so an equivalent PyTorch model (tests/parity.py) can be built
// from the SAME weights and the forward + grads compared within rtol.
//
// Run:  XTRAIN_PARITY_DIR=/tmp/xtrain_parity cargo test -p xtrain-model \
//           --test parity_dump -- --nocapture --ignored
// then: python3 crates/xtrain-model/tests/parity.py /tmp/xtrain_parity
//
// Optional: set XTRAIN_PARITY_KV_HEADS=<k> to dump a GQA config, and/or
// XTRAIN_PARITY_FLASH to dump from the flash-attention path.
//
// Marked #[ignore] (it's a fixture generator, not a pass/fail assertion) and
// gated #![cfg(not(no_cuda))].
|
|
#![cfg(not(no_cuda))]
|
|
|
|
use std::fs;
|
|
use std::io::Write;
|
|
use std::path::PathBuf;
|
|
use xtrain_cuda::device;
|
|
use xtrain_model::{Config, TinyTransformer, ids_tensor, param_to_host};
|
|
use xtrain_tensor::Device;
|
|
|
|
/// Produce `n` deterministic pseudo-random `f32` values scaled by `scale`.
///
/// A 64-bit linear congruential generator is initialised from `seed` with one
/// wrapping multiply-add, then advanced once per element. The top 31 bits of
/// each state are divided by 2^31 (nominally `[0, 1)`), recentred around zero,
/// doubled, and multiplied by `scale`, so values lie in about `[-scale, scale]`.
///
/// The same `(n, seed, scale)` always yields the same vector.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // One-off seed mixing constants.
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    // Per-element LCG step constants.
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;

    let mut lcg = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
        // Keep the high 31 bits and normalise by 2^31.
        let unit = (lcg >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
|
|
|
|
/// Write `data` as a plain-text fixture file at `dir/name`.
///
/// File layout:
/// - line 1: `# shape d0,d1,...` — the entries of `shape`, comma-joined
///   (an empty `shape` yields `# shape ` with nothing after the space);
/// - then one value per line, formatted as `{:.8e}`.
///
/// `dir` is taken as `&Path`, so callers holding a `PathBuf` can keep passing
/// `&path_buf`. Output goes through a `BufWriter` so a large tensor does not
/// cost one write call per value.
///
/// # Panics
/// Panics if the file cannot be created, or if any write or the final flush
/// fails.
fn write_vec(dir: &std::path::Path, name: &str, data: &[f32], shape: &[usize]) {
    let file = fs::File::create(dir.join(name)).unwrap();
    let mut out = std::io::BufWriter::new(file);

    let dims: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
    writeln!(out, "# shape {}", dims.join(",")).unwrap();
    for v in data {
        writeln!(out, "{v:.8e}").unwrap();
    }
    // Flush explicitly: BufWriter's Drop would swallow a late write error.
    out.flush().unwrap();
}
|
|
|
|
/// Fixture generator for the PyTorch parity check (not a pass/fail test).
///
/// Builds a deterministically initialised `TinyTransformer` on CUDA device 0
/// and writes into `$XTRAIN_PARITY_DIR` (default `/tmp/xtrain_parity`):
/// - `config.txt`: one `key value` line per config field, plus `batch`/`seq`;
/// - `ids.txt`, `targets.txt`: one token id per line;
/// - `w_<name>.txt`: every parameter, named via `param_names`;
/// - `logits.txt`, `loss.txt`: batched forward logits and the scalar loss;
/// - `g_<name>.txt`: every parameter's gradient after `loss.backward()`.
///
/// Environment switches:
/// - `XTRAIN_PARITY_KV_HEADS=k`: use an 8-query-head config with `k` kv heads;
/// - `XTRAIN_PARITY_FLASH` (any value): enable the flash-attention path.
///
/// # Panics
/// Panics if no CUDA device is present, on any I/O failure, if
/// `XTRAIN_PARITY_KV_HEADS` is not a valid `usize`, if the name list and the
/// parameter list differ in length, or if a parameter has no gradient.
#[test]
#[ignore = "fixture generator for PyTorch parity; run with --ignored"]
fn dump_for_parity() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let dir = PathBuf::from(
        std::env::var("XTRAIN_PARITY_DIR").unwrap_or_else(|_| "/tmp/xtrain_parity".to_string()),
    );
    fs::create_dir_all(&dir).unwrap();

    // Fixed config + ids (independent of any text, for reproducibility). B>1 so
    // the batched forward is exercised: 2 sequences of length 4, flattened
    // sequence-major to [B*S]=8 ids. Per-sequence RoPE position (resets at the
    // sequence boundary) + per-sequence causal masking (no cross-sequence
    // attention) are both checked against PyTorch.
    // Default: tiny MHA (2 heads). With XTRAIN_PARITY_KV_HEADS=k set, dump a real
    // GQA config (8 query heads / k kv heads) so parity.py checks GQA at B>1 — the
    // kv-projection shapes + the repeat_kv group-sum backward against PyTorch.
    let mut cfg = Config::tiny();
    cfg.vocab = 12;
    if let Ok(kv) = std::env::var("XTRAIN_PARITY_KV_HEADS") {
        let kv: usize = kv.parse().expect("XTRAIN_PARITY_KV_HEADS");
        cfg = Config::from_arch(cfg.vocab, 8, cfg.head_dim, cfg.n_layers, cfg.ffn_hidden)
            .with_kv_heads(kv);
        println!(
            "parity: GQA config (n_heads {} kv_heads {})",
            cfg.n_heads, cfg.num_kv_heads
        );
    }
    let batch = 2usize;
    let seq = 4usize;
    let ids: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6]; // [B*S], sequence-major
    let targets: Vec<i32> = vec![1, 4, 1, 5, 9, 2, 6, 0];
    // config.txt advertises batch/seq to the PyTorch side; keep the dumped id
    // lists consistent with it if the literals above are ever edited.
    assert_eq!(ids.len(), batch * seq, "ids length must equal batch * seq");
    assert_eq!(targets.len(), ids.len(), "targets length must equal ids length");

    // Same deterministic init as the overfit test.
    // Each parameter gets its own seed (bumped per call). Rank-1 shapes are
    // filled around 1.0 with a small spread; everything else around 0.0.
    let mut seed = 1u64;
    let mut model = TinyTransformer::new(cfg, device, |shape| {
        seed = seed.wrapping_add(1);
        let n: usize = shape.iter().product();
        if shape.len() == 1 {
            fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
        } else {
            fill(n, seed, 0.08)
        }
    });
    // T14: with XTRAIN_PARITY_FLASH set, dump from the fused flash-attention path.
    // flash is the SAME SDPA math, so the SAME parity.py PyTorch oracle is the
    // reference for both paths — running this once per path checks flash against
    // PyTorch at B>1 (forward logits + every parameter grad).
    if std::env::var("XTRAIN_PARITY_FLASH").is_ok() {
        model = model.with_flash(true);
        println!("parity: FLASH attention path");
    }

    // config + ids
    {
        let mut f = fs::File::create(dir.join("config.txt")).unwrap();
        writeln!(f, "vocab {}", cfg.vocab).unwrap();
        writeln!(f, "dim {}", cfg.dim).unwrap();
        writeln!(f, "n_layers {}", cfg.n_layers).unwrap();
        writeln!(f, "n_heads {}", cfg.n_heads).unwrap();
        writeln!(f, "num_kv_heads {}", cfg.num_kv_heads).unwrap();
        writeln!(f, "head_dim {}", cfg.head_dim).unwrap();
        writeln!(f, "ffn_hidden {}", cfg.ffn_hidden).unwrap();
        writeln!(f, "eps {:e}", cfg.eps).unwrap();
        writeln!(f, "rope_theta {:e}", cfg.rope_theta).unwrap();
        writeln!(f, "batch {batch}").unwrap();
        writeln!(f, "seq {seq}").unwrap();
    }
    {
        let mut f = fs::File::create(dir.join("ids.txt")).unwrap();
        for v in &ids {
            writeln!(f, "{v}").unwrap();
        }
        let mut f = fs::File::create(dir.join("targets.txt")).unwrap();
        for v in &targets {
            writeln!(f, "{v}").unwrap();
        }
    }

    // Stable param order, named to match parity.py.
    let names = param_names(&cfg);
    let params = model.params();
    assert_eq!(names.len(), params.len(), "param name/count mismatch");
    for (name, p) in names.iter().zip(&params) {
        let shape = p.value().shape().to_vec();
        write_vec(&dir, &format!("w_{name}.txt"), &param_to_host(p), &shape);
    }

    // Batched forward logits + loss (B sequences as one forward), then backward
    // → per-param grads.
    let ids_t = ids_tensor(&ids, device);
    let targets_t = ids_tensor(&targets, device);
    let logits = model.forward_batched(&ids_t, batch);
    write_vec(
        &dir,
        "logits.txt",
        &param_to_host(&logits),
        logits.value().shape(),
    );

    let loss = model.loss_batched(&ids_t, &targets_t, batch);
    let loss_val = param_to_host(&loss)[0];
    {
        let mut f = fs::File::create(dir.join("loss.txt")).unwrap();
        writeln!(f, "{loss_val:.8e}").unwrap();
    }
    loss.backward();
    for (name, p) in names.iter().zip(&params) {
        let g = p.grad().expect("param has no grad");
        // Copy the gradient to the host before reading it as f32.
        let gh = g.to_device(Device::Cpu);
        write_vec(
            &dir,
            &format!("g_{name}.txt"),
            gh.as_slice::<f32>(),
            g.shape(),
        );
    }

    println!("parity: dumped to {} (loss={loss_val:.6e})", dir.display());
}
|
|
|
|
fn param_names(cfg: &Config) -> Vec<String> {
|
|
let mut names = vec!["embed".to_string()];
|
|
for l in 0..cfg.n_layers {
|
|
for p in [
|
|
"attn_norm",
|
|
"wq",
|
|
"wk",
|
|
"wv",
|
|
"q_norm",
|
|
"k_norm",
|
|
"wo",
|
|
"ffn_norm",
|
|
"w_gate",
|
|
"w_up",
|
|
"w_down",
|
|
] {
|
|
names.push(format!("l{l}_{p}"));
|
|
}
|
|
}
|
|
names.push("final_norm".to_string());
|
|
names.push("lm_head".to_string());
|
|
names
|
|
}
|