diff --git a/crates/xtrain-autodiff/src/ops.rs b/crates/xtrain-autodiff/src/ops.rs index 7e91db8..e7a98a0 100644 --- a/crates/xtrain-autodiff/src/ops.rs +++ b/crates/xtrain-autodiff/src/ops.rs @@ -398,7 +398,8 @@ pub fn repeat_kv(kv: &Var, nh: usize, batch: usize) -> Var { } /// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per -/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`, +/// row. Negative targets are ignored, which is useful for assistant-only SFT +/// masks. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/valid_rows`, /// scaled by the upstream scalar grad. pub fn cross_entropy(x: &Var, target: &Tensor) -> Var { // CE math is fp32 (cross_entropy upcasts bf16 logits internally + caches fp32 @@ -407,10 +408,22 @@ pub fn cross_entropy(x: &Var, target: &Tensor) -> Var { // fp32 logits buffer) is a real activation-memory saving at large vocab. let logit_dtype = x.value().dtype(); let (probs, per_row) = x.value().cross_entropy(target); - let rows = x.value().shape()[0]; + let cols = x.value().shape()[1] as i32; + let target_host = target.to_device(xtrain_tensor::Device::Cpu); + let valid_rows = target_host + .as_slice::() + .iter() + .filter(|&&t| { + if t >= cols { + panic!("cross_entropy target {t} out of vocab range {cols}"); + } + t >= 0 + }) + .count() + .max(1); // Mean loss as a host scalar wrapped back into a [1] tensor. let mean = per_row.to_device(xtrain_tensor::Device::Cpu); - let mean_val: f32 = mean.as_slice::().iter().sum::() / rows as f32; + let mean_val: f32 = mean.as_slice::().iter().sum::() / valid_rows as f32; let loss = Tensor::from_slice(&[mean_val], &[1]).to_device(x.value().device()); let target = target.clone(); @@ -420,7 +433,7 @@ pub fn cross_entropy(x: &Var, target: &Tensor) -> Var { Box::new(move |d, parents| { // `d` is the scalar upstream grad (1.0 when this is the loss root). let upstream = d.to_device(xtrain_tensor::Device::Cpu).as_slice::()[0]; - let scale = upstream / rows as f32; + let scale = upstream / valid_rows as f32; let dx = Tensor::cross_entropy_backward(&probs, &target, scale); Var::push_grad(&parents[0], dx.to_dtype(logit_dtype)); }), diff --git a/crates/xtrain-distributed/src/bin/train_ddp.rs b/crates/xtrain-distributed/src/bin/train_ddp.rs index eb31dd8..be55568 100644 --- a/crates/xtrain-distributed/src/bin/train_ddp.rs +++ b/crates/xtrain-distributed/src/bin/train_ddp.rs @@ -88,6 +88,7 @@ fn main() { let val_tokens: usize = flag(&args, "--val-tokens", 0); let eval_every: usize = flag(&args, "--eval-every", 0); let eval_batches: usize = flag(&args, "--eval-batches", 64); + let sft_tsv = args.iter().any(|a| a == "--sft-tsv"); // Dropout (Phase T18/T21): residual-path dropout prob, active at training time // only (inverted scaling), identity at eval/sampling/export. Default 0 = off // (forward graph bit-identical to the no-dropout path). Mirrors bin/train; the @@ -109,6 +110,11 @@ fn main() { .position(|a| a == "--ckpt") .and_then(|i| args.get(i + 1)) .map(PathBuf::from); + let init_ckpt: Option = args + .iter() + .position(|a| a == "--init-ckpt") + .and_then(|i| args.get(i + 1)) + .map(PathBuf::from); // Use every visible GPU as a rank (CUDA_VISIBLE_DEVICES selects the set; // device ordinals are 0..count within it). @@ -129,12 +135,19 @@ fn main() { ); // Reuse the cached token-id stream (v1's u16 cache); never re-tokenize 2GB. - let corpus = Corpus::load_cached(&tok_path, &corpus_path); + let corpus = if sft_tsv { + Corpus::load_sft_tsv_cached(&tok_path, &corpus_path) + } else { + Corpus::load_cached(&tok_path, &corpus_path) + }; println!( "corpus: {} tokens, vocab {}", corpus.len(), corpus.vocab_size ); + if sft_tsv { + println!("SFT TSV: ON (assistant-only loss via ignore-index labels)"); + } let vocab = corpus.vocab_size; // Hold out a tail slice for validation (rank 0 evaluates on it). let (train_corpus, valid) = if val_tokens > 0 { @@ -200,6 +213,10 @@ fn main() { if dropout > 0.0 { println!("dropout: ON (p={dropout}, residual-path, train-only inverted scaling)"); } + if let Some(path) = &init_ckpt { + println!("init checkpoint: {}", path.display()); + } + let init_ckpt_for_ranks = init_ckpt.clone(); let results = launch( &devices, &train_corpus, @@ -216,6 +233,10 @@ fn main() { if flash { m = m.with_flash(true); } + if let Some(path) = &init_ckpt_for_ranks { + xtrain_train::checkpoint::load_into(path, &m.params()) + .expect("load init checkpoint"); + } m }, ); diff --git a/crates/xtrain-distributed/tests/ddp_correctness.rs b/crates/xtrain-distributed/tests/ddp_correctness.rs index 8d5d759..bc41abe 100644 --- a/crates/xtrain-distributed/tests/ddp_correctness.rs +++ b/crates/xtrain-distributed/tests/ddp_correctness.rs @@ -27,6 +27,7 @@ fn synth_corpus(vocab: usize, n_tokens: usize) -> Corpus { .collect(); Corpus { tokens, + labels: None, vocab_size: vocab, } } diff --git a/crates/xtrain-distributed/tests/ddp_proc.rs b/crates/xtrain-distributed/tests/ddp_proc.rs index 9950569..5cfe88a 100644 --- a/crates/xtrain-distributed/tests/ddp_proc.rs +++ b/crates/xtrain-distributed/tests/ddp_proc.rs @@ -37,6 +37,7 @@ fn synth_corpus() -> Corpus { .collect(); Corpus { tokens, + labels: None, vocab_size: VOCAB, } } diff --git a/crates/xtrain-model/tests/dropout.rs b/crates/xtrain-model/tests/dropout.rs index c8392ab..1cc15a9 100644 --- a/crates/xtrain-model/tests/dropout.rs +++ b/crates/xtrain-model/tests/dropout.rs @@ -66,10 +66,18 @@ fn tiny_cfg(dropout: f32) -> Config { fn batch_data(cfg: &Config, device: Device) -> (xtrain_tensor::Tensor, xtrain_tensor::Tensor) { let (batch, seq) = (3usize, 6usize); let seqs: Vec> = (0..batch) - .map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect()) + .map(|b| { + (0..seq) + .map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32) + .collect() + }) .collect(); let tgts: Vec> = (0..batch) - .map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect()) + .map(|b| { + (0..seq) + .map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32) + .collect() + }) .collect(); ( batched_ids_tensor(&seqs, device), @@ -94,7 +102,11 @@ fn fwd_bwd( let loss = m.loss_batched(ids, tgt, batch); let loss_val = host(&loss.value())[0]; loss.backward(); - let grads: Vec> = m.params().iter().map(|p| host(&p.grad().unwrap())).collect(); + let grads: Vec> = m + .params() + .iter() + .map(|p| host(&p.grad().unwrap())) + .collect(); (logits, loss_val, grads) } @@ -186,7 +198,9 @@ fn recompute_with_dropout(dtype: DType, grad_tol: f32) { // Both models: same init, train mode, p=0.2. step_seed starts at 0 and bumps // to 1 on the first training forward in BOTH, so they draw the same masks. - let off = build(cfg, device).with_compute_dtype(dtype).with_training(true); + let off = build(cfg, device) + .with_compute_dtype(dtype) + .with_training(true); let on = build(cfg, device) .with_compute_dtype(dtype) .with_recompute(true) @@ -194,11 +208,19 @@ fn recompute_with_dropout(dtype: DType, grad_tol: f32) { let off_loss = off.loss_batched(&ids, &tgt, batch); off_loss.backward(); - let off_grads: Vec> = off.params().iter().map(|p| host(&p.grad().unwrap())).collect(); + let off_grads: Vec> = off + .params() + .iter() + .map(|p| host(&p.grad().unwrap())) + .collect(); let on_loss = on.loss_batched(&ids, &tgt, batch); on_loss.backward(); - let on_grads: Vec> = on.params().iter().map(|p| host(&p.grad().unwrap())).collect(); + let on_grads: Vec> = on + .params() + .iter() + .map(|p| host(&p.grad().unwrap())) + .collect(); let mut max_rel = 0.0f32; for (a, b) in off_grads.iter().flatten().zip(on_grads.iter().flatten()) { @@ -240,10 +262,18 @@ fn flash_plus_dropout_grad_check_fp32() { cfg.dropout = 0.2; let seq = 40usize; let seqs: Vec> = (0..batch) - .map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect()) + .map(|b| { + (0..seq) + .map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32) + .collect() + }) .collect(); let tgts: Vec> = (0..batch) - .map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect()) + .map(|b| { + (0..seq) + .map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32) + .collect() + }) .collect(); let ids = batched_ids_tensor(&seqs, device); let tgt = batched_ids_tensor(&tgts, device); @@ -277,7 +307,16 @@ fn flash_plus_dropout_grad_check_fp32() { ); // Same tolerances as the flash-vs-composed gate (flash.rs run_fp32): flash // differs from composed only by reduction order; dropout masks are identical. - assert!(logit_rel < 1e-3, "[F32] flash+dropout logits diverged: {logit_rel:.2e}"); - assert!(loss_rel < 1e-3, "[F32] flash+dropout loss diverged: {loss_rel:.2e}"); - assert!(grad_rel < 2e-2, "[F32] flash+dropout grads diverged: {grad_rel:.3e}"); + assert!( + logit_rel < 1e-3, + "[F32] flash+dropout logits diverged: {logit_rel:.2e}" + ); + assert!( + loss_rel < 1e-3, + "[F32] flash+dropout loss diverged: {loss_rel:.2e}" + ); + assert!( + grad_rel < 2e-2, + "[F32] flash+dropout grads diverged: {grad_rel:.3e}" + ); } diff --git a/crates/xtrain-train/src/bin/greedy_sample.rs b/crates/xtrain-train/src/bin/greedy_sample.rs index 27dbf4f..54f1651 100644 --- a/crates/xtrain-train/src/bin/greedy_sample.rs +++ b/crates/xtrain-train/src/bin/greedy_sample.rs @@ -7,7 +7,8 @@ //! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH //! cargo run -p xtrain-train --release --bin greedy_sample -- \ //! /tmp/xtrain_v4.ckpt /opt/wjh/models/gpt2/tokenizer.json \ -//! --heads 24 --head-dim 32 --layers 18 --ffn 2048 +//! --heads 24 --head-dim 32 --layers 18 --ffn 2048 \ +//! --prompts-file scripts/chat_alpha_fixed_prompts.txt --max-tokens 120 #[cfg(no_cuda)] fn main() { @@ -52,6 +53,60 @@ fn flag(args: &[String], name: &str, default: T) -> T { .unwrap_or(default) } +#[cfg(not(no_cuda))] +fn flag_value(args: &[String], name: &str) -> Option { + args.iter() + .position(|a| a == name) + .and_then(|i| args.get(i + 1)) + .cloned() +} + +#[cfg(not(no_cuda))] +fn flag_values(args: &[String], name: &str) -> Vec { + args.iter() + .enumerate() + .filter_map(|(i, a)| { + if a == name { + args.get(i + 1).cloned() + } else { + None + } + }) + .collect() +} + +#[cfg(not(no_cuda))] +fn decode_prompt_escapes(s: &str) -> String { + s.replace("\\n", "\n").replace("\\t", "\t") +} + +#[cfg(not(no_cuda))] +fn load_prompts(args: &[String]) -> Vec { + let mut prompts = Vec::new(); + if let Some(path) = flag_value(args, "--prompts-file") { + let text = std::fs::read_to_string(&path) + .unwrap_or_else(|e| panic!("failed to read prompts file {path}: {e}")); + prompts.extend( + text.lines() + .map(str::trim) + .filter(|line| !line.is_empty() && !line.starts_with('#')) + .map(decode_prompt_escapes), + ); + } + prompts.extend( + flag_values(args, "--prompt") + .into_iter() + .map(|p| decode_prompt_escapes(&p)), + ); + if prompts.is_empty() { + prompts = ["Once upon a time", "One day", "The little"] + .into_iter() + .map(String::from) + .collect(); + } + prompts +} + #[cfg(not(no_cuda))] fn main() { use xserv_tokenizer::Tokenizer; @@ -75,6 +130,8 @@ fn main() { // GQA (Phase T15): num K/V heads (must match the ckpt; default = --heads). let kv_heads = flag(&args, "--kv-heads", n_heads); let max_new = flag(&args, "--max-tokens", 40usize); + let temperature = flag(&args, "--temperature", 0.0f32); + let prompts = load_prompts(&args); assert!(device::device_count().unwrap() > 0, "no CUDA device"); device::set_device(0).unwrap(); @@ -106,11 +163,16 @@ fn main() { }); xtrain_train::checkpoint::load_into(&ckpt, &model.params()).expect("load checkpoint"); - let prompts = ["Once upon a time", "One day", "The little"]; + println!( + "decode: prompts={} max_new={} temperature={}", + prompts.len(), + max_new, + temperature + ); for p in prompts { - let ids: Vec = tok.encode(p).into_iter().map(|t| t as i32).collect(); + let ids: Vec = tok.encode(&p).into_iter().map(|t| t as i32).collect(); let mut rng = 7u64; - let out = generate(&model, device, &ids, max_new, 0.0, &mut rng); + let out = generate(&model, device, &ids, max_new, temperature, &mut rng); let text = tok.decode(&out.iter().map(|&t| t as u32).collect::>()); println!("[{p}] → {text}"); } diff --git a/crates/xtrain-train/src/bin/train.rs b/crates/xtrain-train/src/bin/train.rs index 76fe0d8..b2dd270 100644 --- a/crates/xtrain-train/src/bin/train.rs +++ b/crates/xtrain-train/src/bin/train.rs @@ -115,6 +115,7 @@ fn main() { let val_tokens: usize = flag(&args, "--val-tokens", 0); let eval_every: usize = flag(&args, "--eval-every", 0); let eval_batches: usize = flag(&args, "--eval-batches", 64); + let sft_tsv = args.iter().any(|a| a == "--sft-tsv"); // Dropout (Phase T18): residual-path dropout prob, active at training time // only (inverted scaling), identity at eval/sampling/export. Default 0 = off // (forward graph bit-identical to the no-dropout path). @@ -136,6 +137,11 @@ fn main() { .cloned() .unwrap_or_else(|| "/tmp/xtrain_tinystories.ckpt".to_string()), ); + let init_ckpt: Option = args + .iter() + .position(|a| a == "--init-ckpt") + .and_then(|i| args.get(i + 1)) + .map(PathBuf::from); assert!(device::device_count().unwrap() > 0, "no CUDA device"); device::set_device(0).unwrap(); @@ -146,12 +152,19 @@ fn main() { tok_path.display(), corpus_path.display() ); - let corpus = Corpus::load_cached(&tok_path, &corpus_path); + let corpus = if sft_tsv { + Corpus::load_sft_tsv_cached(&tok_path, &corpus_path) + } else { + Corpus::load_cached(&tok_path, &corpus_path) + }; println!( "corpus: {} tokens, vocab {}", corpus.len(), corpus.vocab_size ); + if sft_tsv { + println!("SFT TSV: ON (assistant-only loss via ignore-index labels)"); + } let vocab = corpus.vocab_size; // Hold out a tail slice for validation (if requested and the corpus is big). let (train_corpus, valid) = if val_tokens > 0 { @@ -206,6 +219,10 @@ fn main() { if dropout > 0.0 { println!("dropout: ON (p={dropout}, residual-path, train-only inverted scaling)"); } + if let Some(path) = &init_ckpt { + xtrain_train::checkpoint::load_into(path, &model.params()).expect("load init checkpoint"); + println!("init checkpoint: loaded {}", path.display()); + } // Eval-only mode: load a checkpoint and score it on the held-out val set, then // exit. Used to put an EXISTING model (e.g. v0) and a new one on the same diff --git a/crates/xtrain-train/src/data.rs b/crates/xtrain-train/src/data.rs index 0ba36e7..12bdc83 100644 --- a/crates/xtrain-train/src/data.rs +++ b/crates/xtrain-train/src/data.rs @@ -15,6 +15,7 @@ use xserv_tokenizer::Tokenizer; /// A tokenized corpus: one flat stream of token ids, plus the vocab size. pub struct Corpus { pub tokens: Vec, + pub labels: Option>, pub vocab_size: usize, } @@ -33,6 +34,7 @@ impl Corpus { let ids: Vec = tok.encode(text).into_iter().map(|t| t as i32).collect(); Self { tokens: ids, + labels: None, vocab_size: tok.vocab_size(), } } @@ -52,7 +54,11 @@ impl Corpus { tokens.len(), cache.display() ); - return Self { tokens, vocab_size }; + return Self { + tokens, + labels: None, + vocab_size, + }; } let me = Self::load(tokenizer_path, corpus_path); write_u16_cache(&cache, &me.tokens); @@ -64,22 +70,104 @@ impl Corpus { me } + /// Load assistant-only SFT data from a two-column TSV: + /// + /// ```text + /// userassistant + /// ``` + /// + /// Literal `\n` and `\t` escapes are decoded. Each row is formatted as + /// `User: ...\nAssistant:` + answer + `<|endoftext|>`. Labels are `-100` + /// for prompt tokens and the token id itself for answer/EOS tokens, so the + /// cross-entropy op ignores prompt rows while still training the assistant + /// answer and stop token. + pub fn load_sft_tsv_cached(tokenizer_path: &Path, corpus_path: &Path) -> Self { + let token_cache = cache_path(corpus_path); + let label_cache = label_cache_path(corpus_path); + let vocab_size = Tokenizer::from_file(tokenizer_path).vocab_size(); + if token_cache.exists() && label_cache.exists() { + let tokens = read_u16_cache(&token_cache); + let labels = read_i32_cache(&label_cache); + assert_eq!( + tokens.len(), + labels.len(), + "SFT cache token/label length mismatch" + ); + println!( + "corpus: read {} cached SFT tokens from {} (+ labels {})", + tokens.len(), + token_cache.display(), + label_cache.display() + ); + return Self { + tokens, + labels: Some(labels), + vocab_size, + }; + } + + let tok = Tokenizer::from_file(tokenizer_path); + let text = std::fs::read_to_string(corpus_path) + .unwrap_or_else(|e| panic!("failed to read SFT corpus {}: {e}", corpus_path.display())); + let mut tokens = Vec::new(); + let mut labels = Vec::new(); + for (lineno, line) in text.lines().enumerate() { + if line.trim().is_empty() { + continue; + } + let (user, assistant) = line + .split_once('\t') + .unwrap_or_else(|| panic!("SFT TSV line {} missing tab", lineno + 1)); + let user = decode_tsv_escapes(user); + let assistant = decode_tsv_escapes(assistant); + let prompt = format!("User: {user}\nAssistant:"); + let answer = format!(" {assistant}\n<|endoftext|>"); + let prompt_ids: Vec = tok.encode(&prompt).into_iter().map(|t| t as i32).collect(); + let answer_ids: Vec = tok.encode(&answer).into_iter().map(|t| t as i32).collect(); + labels.extend(std::iter::repeat(-100).take(prompt_ids.len())); + labels.extend(answer_ids.iter().copied()); + tokens.extend(prompt_ids); + tokens.extend(answer_ids); + } + assert_eq!(tokens.len(), labels.len(), "SFT tokens/labels mismatch"); + write_u16_cache(&token_cache, &tokens); + write_i32_cache(&label_cache, &labels); + println!( + "corpus: tokenized {} SFT tokens → cached to {} (+ labels {})", + tokens.len(), + token_cache.display(), + label_cache.display() + ); + Self { + tokens, + labels: Some(labels), + vocab_size: tok.vocab_size(), + } + } + /// Split off the last `n` tokens as a held-out validation corpus, leaving the /// rest as the train corpus. Returns `(train, valid)`. Used for periodic val /// loss during training without leaking the eval window into training. pub fn split_tail(self, n: usize) -> (Self, Self) { let n = n.min(self.tokens.len() / 10); // never hand off more than 10% let cut = self.tokens.len() - n; - let valid = self.tokens[cut..].to_vec(); + let valid_tokens = self.tokens[cut..].to_vec(); + let valid_labels = self.labels.as_ref().map(|labels| labels[cut..].to_vec()); let mut train = self.tokens; train.truncate(cut); + let train_labels = self.labels.map(|mut labels| { + labels.truncate(cut); + labels + }); ( Self { tokens: train, + labels: train_labels, vocab_size: self.vocab_size, }, Self { - tokens: valid, + tokens: valid_tokens, + labels: valid_labels, vocab_size: self.vocab_size, }, ) @@ -101,11 +189,27 @@ impl Corpus { pub fn sample(&self, seq: usize, rng_state: &mut u64) -> (Vec, Vec) { assert!(self.tokens.len() > seq + 1, "corpus shorter than a window"); let max_start = self.tokens.len() - seq - 1; - let start = (next_rand(rng_state) % (max_start as u64 + 1)) as usize; + let mut start = (next_rand(rng_state) % (max_start as u64 + 1)) as usize; + if let Some(labels) = &self.labels { + for _ in 0..16 { + if labels[start + 1..start + seq + 1].iter().any(|&t| t >= 0) { + break; + } + start = (next_rand(rng_state) % (max_start as u64 + 1)) as usize; + } + } let input = self.tokens[start..start + seq].to_vec(); - let target = self.tokens[start + 1..start + seq + 1].to_vec(); + let target = self.target_window(start, seq); (input, target) } + + /// Deterministic target labels for an input window starting at `start`. + pub fn target_window(&self, start: usize, seq: usize) -> Vec { + match &self.labels { + Some(labels) => labels[start + 1..start + seq + 1].to_vec(), + None => self.tokens[start + 1..start + seq + 1].to_vec(), + } + } } /// Drop a leading partial line (before the first newline) and everything after @@ -127,6 +231,12 @@ fn cache_path(corpus_path: &Path) -> PathBuf { PathBuf::from(s) } +fn label_cache_path(corpus_path: &Path) -> PathBuf { + let mut s = corpus_path.as_os_str().to_os_string(); + s.push(".labels.i32.bin"); + PathBuf::from(s) +} + /// Read a flat little-endian `[u16]` cache into an `i32` id stream. fn read_u16_cache(path: &Path) -> Vec { let mut r = BufReader::new( @@ -140,6 +250,18 @@ fn read_u16_cache(path: &Path) -> Vec { .collect() } +fn read_i32_cache(path: &Path) -> Vec { + let mut r = BufReader::new( + std::fs::File::open(path).unwrap_or_else(|e| panic!("open cache {}: {e}", path.display())), + ); + let mut buf = Vec::new(); + r.read_to_end(&mut buf).expect("read cache"); + assert!(buf.len() % 4 == 0, "corrupt i32 cache (odd byte count)"); + buf.chunks_exact(4) + .map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]])) + .collect() +} + /// Write an id stream as a flat little-endian `[u16]` cache. Ids must fit in u16 /// (GPT-2 vocab = 50257 < 65536); asserts otherwise. fn write_u16_cache(path: &Path, tokens: &[i32]) { @@ -154,6 +276,21 @@ fn write_u16_cache(path: &Path, tokens: &[i32]) { w.flush().expect("flush cache"); } +fn write_i32_cache(path: &Path, labels: &[i32]) { + let mut w = BufWriter::new( + std::fs::File::create(path) + .unwrap_or_else(|e| panic!("create cache {}: {e}", path.display())), + ); + for &t in labels { + w.write_all(&t.to_le_bytes()).expect("write cache"); + } + w.flush().expect("flush cache"); +} + +fn decode_tsv_escapes(s: &str) -> String { + s.replace("\\n", "\n").replace("\\t", "\t") +} + /// Tiny LCG (same constants as the model tests' deterministic fill) so dataset /// sampling is reproducible from a single u64 seed. fn next_rand(state: &mut u64) -> u64 { diff --git a/crates/xtrain-train/src/train_loop.rs b/crates/xtrain-train/src/train_loop.rs index 06b41c2..b91ce42 100644 --- a/crates/xtrain-train/src/train_loop.rs +++ b/crates/xtrain-train/src/train_loop.rs @@ -207,7 +207,7 @@ pub fn eval_loss( break; } let input: Vec = valid.tokens[s..s + seq].to_vec(); - let target: Vec = valid.tokens[s + 1..s + seq + 1].to_vec(); + let target = valid.target_window(s, seq); let ids = ids_tensor(&input, device); let targets = ids_tensor(&target, device); let loss = model.loss(&ids, &targets); diff --git a/crates/xtrain-train/tests/grad_accum.rs b/crates/xtrain-train/tests/grad_accum.rs index d39e390..076ab92 100644 --- a/crates/xtrain-train/tests/grad_accum.rs +++ b/crates/xtrain-train/tests/grad_accum.rs @@ -216,6 +216,7 @@ fn synth_corpus(vocab: usize, n_tokens: usize) -> Corpus { tokens: (0..n_tokens) .map(|i| (i * 7 + 3) as i32 % vocab as i32) .collect(), + labels: None, vocab_size: vocab, } } diff --git a/csrc/ops/nn.cu b/csrc/ops/nn.cu index 1cbac91..b74947f 100644 --- a/csrc/ops/nn.cu +++ b/csrc/ops/nn.cu @@ -338,7 +338,7 @@ __global__ void cross_entropy_fwd_k(const float* x, const int* target, for (int c = threadIdx.x; c < cols; c += blockDim.x) pr[c] *= inv; if (threadIdx.x == 0) { int t = target[r]; - loss[r] = -logf(pr[t]); + loss[r] = t < 0 ? 0.0f : -logf(pr[t]); } } void launch_cross_entropy_fwd_f32(const float* x, const int* target, @@ -354,8 +354,13 @@ __global__ void cross_entropy_dx_k(const float* probs, const int* target, int i = blockIdx.x * blockDim.x + threadIdx.x; if (i >= rows * cols) return; int r = i / cols, c = i % cols; - float g = probs[i] - (c == target[r] ? 1.0f : 0.0f); - dx[i] = g * scale; + int t = target[r]; + if (t < 0) { + dx[i] = 0.0f; + } else { + float g = probs[i] - (c == t ? 1.0f : 0.0f); + dx[i] = g * scale; + } } void launch_cross_entropy_dx_f32(const float* probs, const int* target, float* dx, int rows, int cols, float scale, void* s) {