sft: assistant-only SFT (ignore-index CE) + chat-prompt greedy eval

Enable assistant-only supervised fine-tuning and a fixed chat-prompt eval path
used by the v12 SFT runs:

- cross_entropy ignores negative targets (-100 ignore-index), normalizing by
  valid rows instead of all rows; CUDA fwd/bwd skip t<0 (ops.rs, nn.cu).
- Corpus gains optional labels + load_sft_tsv_cached: two-column TSV is
  formatted as 'User: .. \nAssistant:' + answer + <|endoftext|>, prompt tokens
  masked to -100 while answer+EOS are supervised; i32 label cache alongside the
  u16 token cache; sample() retries windows that are fully masked; eval uses
  target_window so masking applies to val loss too (data.rs, train_loop.rs).
- train + train_ddp: --sft-tsv selects the TSV loader, --init-ckpt continues
  training from a base checkpoint.
- greedy_sample: --prompts-file/--prompt/--temperature for fixed chat-prompt
  generation eval.

Test fixtures updated for the new Corpus.labels field; dropout.rs carries
incidental rustfmt. Not rebuilt locally (no CUDA toolchain on this checkout);
correctness rests on the documented v12 base+SFT runs on the GPU box.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-29 16:19:02 +08:00
parent 5c27493a90
commit fbf4ac2917
11 changed files with 327 additions and 30 deletions

View File

@@ -398,7 +398,8 @@ pub fn repeat_kv(kv: &Var, nh: usize, batch: usize) -> Var {
}
/// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
/// row. Negative targets are ignored, which is useful for assistant-only SFT
/// masks. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/valid_rows`,
/// scaled by the upstream scalar grad.
pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
// CE math is fp32 (cross_entropy upcasts bf16 logits internally + caches fp32
@@ -407,10 +408,22 @@ pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
// fp32 logits buffer) is a real activation-memory saving at large vocab.
let logit_dtype = x.value().dtype();
let (probs, per_row) = x.value().cross_entropy(target);
let rows = x.value().shape()[0];
let cols = x.value().shape()[1] as i32;
let target_host = target.to_device(xtrain_tensor::Device::Cpu);
let valid_rows = target_host
.as_slice::<i32>()
.iter()
.filter(|&&t| {
if t >= cols {
panic!("cross_entropy target {t} out of vocab range {cols}");
}
t >= 0
})
.count()
.max(1);
// Mean loss as a host scalar wrapped back into a [1] tensor.
let mean = per_row.to_device(xtrain_tensor::Device::Cpu);
let mean_val: f32 = mean.as_slice::<f32>().iter().sum::<f32>() / rows as f32;
let mean_val: f32 = mean.as_slice::<f32>().iter().sum::<f32>() / valid_rows as f32;
let loss = Tensor::from_slice(&[mean_val], &[1]).to_device(x.value().device());
let target = target.clone();
@@ -420,7 +433,7 @@ pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
Box::new(move |d, parents| {
// `d` is the scalar upstream grad (1.0 when this is the loss root).
let upstream = d.to_device(xtrain_tensor::Device::Cpu).as_slice::<f32>()[0];
let scale = upstream / rows as f32;
let scale = upstream / valid_rows as f32;
let dx = Tensor::cross_entropy_backward(&probs, &target, scale);
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
}),

View File

@@ -88,6 +88,7 @@ fn main() {
let val_tokens: usize = flag(&args, "--val-tokens", 0);
let eval_every: usize = flag(&args, "--eval-every", 0);
let eval_batches: usize = flag(&args, "--eval-batches", 64);
let sft_tsv = args.iter().any(|a| a == "--sft-tsv");
// Dropout (Phase T18/T21): residual-path dropout prob, active at training time
// only (inverted scaling), identity at eval/sampling/export. Default 0 = off
// (forward graph bit-identical to the no-dropout path). Mirrors bin/train; the
@@ -109,6 +110,11 @@ fn main() {
.position(|a| a == "--ckpt")
.and_then(|i| args.get(i + 1))
.map(PathBuf::from);
let init_ckpt: Option<PathBuf> = args
.iter()
.position(|a| a == "--init-ckpt")
.and_then(|i| args.get(i + 1))
.map(PathBuf::from);
// Use every visible GPU as a rank (CUDA_VISIBLE_DEVICES selects the set;
// device ordinals are 0..count within it).
@@ -129,12 +135,19 @@ fn main() {
);
// Reuse the cached token-id stream (v1's u16 cache); never re-tokenize 2GB.
let corpus = Corpus::load_cached(&tok_path, &corpus_path);
let corpus = if sft_tsv {
Corpus::load_sft_tsv_cached(&tok_path, &corpus_path)
} else {
Corpus::load_cached(&tok_path, &corpus_path)
};
println!(
"corpus: {} tokens, vocab {}",
corpus.len(),
corpus.vocab_size
);
if sft_tsv {
println!("SFT TSV: ON (assistant-only loss via ignore-index labels)");
}
let vocab = corpus.vocab_size;
// Hold out a tail slice for validation (rank 0 evaluates on it).
let (train_corpus, valid) = if val_tokens > 0 {
@@ -200,6 +213,10 @@ fn main() {
if dropout > 0.0 {
println!("dropout: ON (p={dropout}, residual-path, train-only inverted scaling)");
}
if let Some(path) = &init_ckpt {
println!("init checkpoint: {}", path.display());
}
let init_ckpt_for_ranks = init_ckpt.clone();
let results = launch(
&devices,
&train_corpus,
@@ -216,6 +233,10 @@ fn main() {
if flash {
m = m.with_flash(true);
}
if let Some(path) = &init_ckpt_for_ranks {
xtrain_train::checkpoint::load_into(path, &m.params())
.expect("load init checkpoint");
}
m
},
);

View File

@@ -27,6 +27,7 @@ fn synth_corpus(vocab: usize, n_tokens: usize) -> Corpus {
.collect();
Corpus {
tokens,
labels: None,
vocab_size: vocab,
}
}

View File

@@ -37,6 +37,7 @@ fn synth_corpus() -> Corpus {
.collect();
Corpus {
tokens,
labels: None,
vocab_size: VOCAB,
}
}

View File

@@ -66,10 +66,18 @@ fn tiny_cfg(dropout: f32) -> Config {
fn batch_data(cfg: &Config, device: Device) -> (xtrain_tensor::Tensor, xtrain_tensor::Tensor) {
let (batch, seq) = (3usize, 6usize);
let seqs: Vec<Vec<i32>> = (0..batch)
.map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect())
.map(|b| {
(0..seq)
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
.collect()
})
.collect();
let tgts: Vec<Vec<i32>> = (0..batch)
.map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect())
.map(|b| {
(0..seq)
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
.collect()
})
.collect();
(
batched_ids_tensor(&seqs, device),
@@ -94,7 +102,11 @@ fn fwd_bwd(
let loss = m.loss_batched(ids, tgt, batch);
let loss_val = host(&loss.value())[0];
loss.backward();
let grads: Vec<Vec<f32>> = m.params().iter().map(|p| host(&p.grad().unwrap())).collect();
let grads: Vec<Vec<f32>> = m
.params()
.iter()
.map(|p| host(&p.grad().unwrap()))
.collect();
(logits, loss_val, grads)
}
@@ -186,7 +198,9 @@ fn recompute_with_dropout(dtype: DType, grad_tol: f32) {
// Both models: same init, train mode, p=0.2. step_seed starts at 0 and bumps
// to 1 on the first training forward in BOTH, so they draw the same masks.
let off = build(cfg, device).with_compute_dtype(dtype).with_training(true);
let off = build(cfg, device)
.with_compute_dtype(dtype)
.with_training(true);
let on = build(cfg, device)
.with_compute_dtype(dtype)
.with_recompute(true)
@@ -194,11 +208,19 @@ fn recompute_with_dropout(dtype: DType, grad_tol: f32) {
let off_loss = off.loss_batched(&ids, &tgt, batch);
off_loss.backward();
let off_grads: Vec<Vec<f32>> = off.params().iter().map(|p| host(&p.grad().unwrap())).collect();
let off_grads: Vec<Vec<f32>> = off
.params()
.iter()
.map(|p| host(&p.grad().unwrap()))
.collect();
let on_loss = on.loss_batched(&ids, &tgt, batch);
on_loss.backward();
let on_grads: Vec<Vec<f32>> = on.params().iter().map(|p| host(&p.grad().unwrap())).collect();
let on_grads: Vec<Vec<f32>> = on
.params()
.iter()
.map(|p| host(&p.grad().unwrap()))
.collect();
let mut max_rel = 0.0f32;
for (a, b) in off_grads.iter().flatten().zip(on_grads.iter().flatten()) {
@@ -240,10 +262,18 @@ fn flash_plus_dropout_grad_check_fp32() {
cfg.dropout = 0.2;
let seq = 40usize;
let seqs: Vec<Vec<i32>> = (0..batch)
.map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect())
.map(|b| {
(0..seq)
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
.collect()
})
.collect();
let tgts: Vec<Vec<i32>> = (0..batch)
.map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect())
.map(|b| {
(0..seq)
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
.collect()
})
.collect();
let ids = batched_ids_tensor(&seqs, device);
let tgt = batched_ids_tensor(&tgts, device);
@@ -277,7 +307,16 @@ fn flash_plus_dropout_grad_check_fp32() {
);
// Same tolerances as the flash-vs-composed gate (flash.rs run_fp32): flash
// differs from composed only by reduction order; dropout masks are identical.
assert!(logit_rel < 1e-3, "[F32] flash+dropout logits diverged: {logit_rel:.2e}");
assert!(loss_rel < 1e-3, "[F32] flash+dropout loss diverged: {loss_rel:.2e}");
assert!(grad_rel < 2e-2, "[F32] flash+dropout grads diverged: {grad_rel:.3e}");
assert!(
logit_rel < 1e-3,
"[F32] flash+dropout logits diverged: {logit_rel:.2e}"
);
assert!(
loss_rel < 1e-3,
"[F32] flash+dropout loss diverged: {loss_rel:.2e}"
);
assert!(
grad_rel < 2e-2,
"[F32] flash+dropout grads diverged: {grad_rel:.3e}"
);
}

View File

@@ -7,7 +7,8 @@
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
//! cargo run -p xtrain-train --release --bin greedy_sample -- \
//! /tmp/xtrain_v4.ckpt /opt/wjh/models/gpt2/tokenizer.json \
//! --heads 24 --head-dim 32 --layers 18 --ffn 2048
//! --heads 24 --head-dim 32 --layers 18 --ffn 2048 \
//! --prompts-file scripts/chat_alpha_fixed_prompts.txt --max-tokens 120
#[cfg(no_cuda)]
fn main() {
@@ -52,6 +53,60 @@ fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
.unwrap_or(default)
}
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
args.iter()
.position(|a| a == name)
.and_then(|i| args.get(i + 1))
.cloned()
}
#[cfg(not(no_cuda))]
fn flag_values(args: &[String], name: &str) -> Vec<String> {
args.iter()
.enumerate()
.filter_map(|(i, a)| {
if a == name {
args.get(i + 1).cloned()
} else {
None
}
})
.collect()
}
#[cfg(not(no_cuda))]
fn decode_prompt_escapes(s: &str) -> String {
s.replace("\\n", "\n").replace("\\t", "\t")
}
#[cfg(not(no_cuda))]
fn load_prompts(args: &[String]) -> Vec<String> {
let mut prompts = Vec::new();
if let Some(path) = flag_value(args, "--prompts-file") {
let text = std::fs::read_to_string(&path)
.unwrap_or_else(|e| panic!("failed to read prompts file {path}: {e}"));
prompts.extend(
text.lines()
.map(str::trim)
.filter(|line| !line.is_empty() && !line.starts_with('#'))
.map(decode_prompt_escapes),
);
}
prompts.extend(
flag_values(args, "--prompt")
.into_iter()
.map(|p| decode_prompt_escapes(&p)),
);
if prompts.is_empty() {
prompts = ["Once upon a time", "One day", "The little"]
.into_iter()
.map(String::from)
.collect();
}
prompts
}
#[cfg(not(no_cuda))]
fn main() {
use xserv_tokenizer::Tokenizer;
@@ -75,6 +130,8 @@ fn main() {
// GQA (Phase T15): num K/V heads (must match the ckpt; default = --heads).
let kv_heads = flag(&args, "--kv-heads", n_heads);
let max_new = flag(&args, "--max-tokens", 40usize);
let temperature = flag(&args, "--temperature", 0.0f32);
let prompts = load_prompts(&args);
assert!(device::device_count().unwrap() > 0, "no CUDA device");
device::set_device(0).unwrap();
@@ -106,11 +163,16 @@ fn main() {
});
xtrain_train::checkpoint::load_into(&ckpt, &model.params()).expect("load checkpoint");
let prompts = ["Once upon a time", "One day", "The little"];
println!(
"decode: prompts={} max_new={} temperature={}",
prompts.len(),
max_new,
temperature
);
for p in prompts {
let ids: Vec<i32> = tok.encode(p).into_iter().map(|t| t as i32).collect();
let ids: Vec<i32> = tok.encode(&p).into_iter().map(|t| t as i32).collect();
let mut rng = 7u64;
let out = generate(&model, device, &ids, max_new, 0.0, &mut rng);
let out = generate(&model, device, &ids, max_new, temperature, &mut rng);
let text = tok.decode(&out.iter().map(|&t| t as u32).collect::<Vec<_>>());
println!("[{p}] → {text}");
}

View File

@@ -115,6 +115,7 @@ fn main() {
let val_tokens: usize = flag(&args, "--val-tokens", 0);
let eval_every: usize = flag(&args, "--eval-every", 0);
let eval_batches: usize = flag(&args, "--eval-batches", 64);
let sft_tsv = args.iter().any(|a| a == "--sft-tsv");
// Dropout (Phase T18): residual-path dropout prob, active at training time
// only (inverted scaling), identity at eval/sampling/export. Default 0 = off
// (forward graph bit-identical to the no-dropout path).
@@ -136,6 +137,11 @@ fn main() {
.cloned()
.unwrap_or_else(|| "/tmp/xtrain_tinystories.ckpt".to_string()),
);
let init_ckpt: Option<PathBuf> = args
.iter()
.position(|a| a == "--init-ckpt")
.and_then(|i| args.get(i + 1))
.map(PathBuf::from);
assert!(device::device_count().unwrap() > 0, "no CUDA device");
device::set_device(0).unwrap();
@@ -146,12 +152,19 @@ fn main() {
tok_path.display(),
corpus_path.display()
);
let corpus = Corpus::load_cached(&tok_path, &corpus_path);
let corpus = if sft_tsv {
Corpus::load_sft_tsv_cached(&tok_path, &corpus_path)
} else {
Corpus::load_cached(&tok_path, &corpus_path)
};
println!(
"corpus: {} tokens, vocab {}",
corpus.len(),
corpus.vocab_size
);
if sft_tsv {
println!("SFT TSV: ON (assistant-only loss via ignore-index labels)");
}
let vocab = corpus.vocab_size;
// Hold out a tail slice for validation (if requested and the corpus is big).
let (train_corpus, valid) = if val_tokens > 0 {
@@ -206,6 +219,10 @@ fn main() {
if dropout > 0.0 {
println!("dropout: ON (p={dropout}, residual-path, train-only inverted scaling)");
}
if let Some(path) = &init_ckpt {
xtrain_train::checkpoint::load_into(path, &model.params()).expect("load init checkpoint");
println!("init checkpoint: loaded {}", path.display());
}
// Eval-only mode: load a checkpoint and score it on the held-out val set, then
// exit. Used to put an EXISTING model (e.g. v0) and a new one on the same

View File

@@ -15,6 +15,7 @@ use xserv_tokenizer::Tokenizer;
/// A tokenized corpus: one flat stream of token ids, plus the vocab size.
pub struct Corpus {
pub tokens: Vec<i32>,
pub labels: Option<Vec<i32>>,
pub vocab_size: usize,
}
@@ -33,6 +34,7 @@ impl Corpus {
let ids: Vec<i32> = tok.encode(text).into_iter().map(|t| t as i32).collect();
Self {
tokens: ids,
labels: None,
vocab_size: tok.vocab_size(),
}
}
@@ -52,7 +54,11 @@ impl Corpus {
tokens.len(),
cache.display()
);
return Self { tokens, vocab_size };
return Self {
tokens,
labels: None,
vocab_size,
};
}
let me = Self::load(tokenizer_path, corpus_path);
write_u16_cache(&cache, &me.tokens);
@@ -64,22 +70,104 @@ impl Corpus {
me
}
/// Load assistant-only SFT data from a two-column TSV:
///
/// ```text
/// user<TAB>assistant
/// ```
///
/// Literal `\n` and `\t` escapes are decoded. Each row is formatted as
/// `User: ...\nAssistant:` + answer + `<|endoftext|>`. Labels are `-100`
/// for prompt tokens and the token id itself for answer/EOS tokens, so the
/// cross-entropy op ignores prompt rows while still training the assistant
/// answer and stop token.
pub fn load_sft_tsv_cached(tokenizer_path: &Path, corpus_path: &Path) -> Self {
let token_cache = cache_path(corpus_path);
let label_cache = label_cache_path(corpus_path);
let vocab_size = Tokenizer::from_file(tokenizer_path).vocab_size();
if token_cache.exists() && label_cache.exists() {
let tokens = read_u16_cache(&token_cache);
let labels = read_i32_cache(&label_cache);
assert_eq!(
tokens.len(),
labels.len(),
"SFT cache token/label length mismatch"
);
println!(
"corpus: read {} cached SFT tokens from {} (+ labels {})",
tokens.len(),
token_cache.display(),
label_cache.display()
);
return Self {
tokens,
labels: Some(labels),
vocab_size,
};
}
let tok = Tokenizer::from_file(tokenizer_path);
let text = std::fs::read_to_string(corpus_path)
.unwrap_or_else(|e| panic!("failed to read SFT corpus {}: {e}", corpus_path.display()));
let mut tokens = Vec::new();
let mut labels = Vec::new();
for (lineno, line) in text.lines().enumerate() {
if line.trim().is_empty() {
continue;
}
let (user, assistant) = line
.split_once('\t')
.unwrap_or_else(|| panic!("SFT TSV line {} missing tab", lineno + 1));
let user = decode_tsv_escapes(user);
let assistant = decode_tsv_escapes(assistant);
let prompt = format!("User: {user}\nAssistant:");
let answer = format!(" {assistant}\n<|endoftext|>");
let prompt_ids: Vec<i32> = tok.encode(&prompt).into_iter().map(|t| t as i32).collect();
let answer_ids: Vec<i32> = tok.encode(&answer).into_iter().map(|t| t as i32).collect();
labels.extend(std::iter::repeat(-100).take(prompt_ids.len()));
labels.extend(answer_ids.iter().copied());
tokens.extend(prompt_ids);
tokens.extend(answer_ids);
}
assert_eq!(tokens.len(), labels.len(), "SFT tokens/labels mismatch");
write_u16_cache(&token_cache, &tokens);
write_i32_cache(&label_cache, &labels);
println!(
"corpus: tokenized {} SFT tokens → cached to {} (+ labels {})",
tokens.len(),
token_cache.display(),
label_cache.display()
);
Self {
tokens,
labels: Some(labels),
vocab_size: tok.vocab_size(),
}
}
/// Split off the last `n` tokens as a held-out validation corpus, leaving the
/// rest as the train corpus. Returns `(train, valid)`. Used for periodic val
/// loss during training without leaking the eval window into training.
pub fn split_tail(self, n: usize) -> (Self, Self) {
let n = n.min(self.tokens.len() / 10); // never hand off more than 10%
let cut = self.tokens.len() - n;
let valid = self.tokens[cut..].to_vec();
let valid_tokens = self.tokens[cut..].to_vec();
let valid_labels = self.labels.as_ref().map(|labels| labels[cut..].to_vec());
let mut train = self.tokens;
train.truncate(cut);
let train_labels = self.labels.map(|mut labels| {
labels.truncate(cut);
labels
});
(
Self {
tokens: train,
labels: train_labels,
vocab_size: self.vocab_size,
},
Self {
tokens: valid,
tokens: valid_tokens,
labels: valid_labels,
vocab_size: self.vocab_size,
},
)
@@ -101,11 +189,27 @@ impl Corpus {
pub fn sample(&self, seq: usize, rng_state: &mut u64) -> (Vec<i32>, Vec<i32>) {
assert!(self.tokens.len() > seq + 1, "corpus shorter than a window");
let max_start = self.tokens.len() - seq - 1;
let start = (next_rand(rng_state) % (max_start as u64 + 1)) as usize;
let mut start = (next_rand(rng_state) % (max_start as u64 + 1)) as usize;
if let Some(labels) = &self.labels {
for _ in 0..16 {
if labels[start + 1..start + seq + 1].iter().any(|&t| t >= 0) {
break;
}
start = (next_rand(rng_state) % (max_start as u64 + 1)) as usize;
}
}
let input = self.tokens[start..start + seq].to_vec();
let target = self.tokens[start + 1..start + seq + 1].to_vec();
let target = self.target_window(start, seq);
(input, target)
}
/// Deterministic target labels for an input window starting at `start`.
pub fn target_window(&self, start: usize, seq: usize) -> Vec<i32> {
match &self.labels {
Some(labels) => labels[start + 1..start + seq + 1].to_vec(),
None => self.tokens[start + 1..start + seq + 1].to_vec(),
}
}
}
/// Drop a leading partial line (before the first newline) and everything after
@@ -127,6 +231,12 @@ fn cache_path(corpus_path: &Path) -> PathBuf {
PathBuf::from(s)
}
fn label_cache_path(corpus_path: &Path) -> PathBuf {
let mut s = corpus_path.as_os_str().to_os_string();
s.push(".labels.i32.bin");
PathBuf::from(s)
}
/// Read a flat little-endian `[u16]` cache into an `i32` id stream.
fn read_u16_cache(path: &Path) -> Vec<i32> {
let mut r = BufReader::new(
@@ -140,6 +250,18 @@ fn read_u16_cache(path: &Path) -> Vec<i32> {
.collect()
}
fn read_i32_cache(path: &Path) -> Vec<i32> {
let mut r = BufReader::new(
std::fs::File::open(path).unwrap_or_else(|e| panic!("open cache {}: {e}", path.display())),
);
let mut buf = Vec::new();
r.read_to_end(&mut buf).expect("read cache");
assert!(buf.len() % 4 == 0, "corrupt i32 cache (odd byte count)");
buf.chunks_exact(4)
.map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect()
}
/// Write an id stream as a flat little-endian `[u16]` cache. Ids must fit in u16
/// (GPT-2 vocab = 50257 < 65536); asserts otherwise.
fn write_u16_cache(path: &Path, tokens: &[i32]) {
@@ -154,6 +276,21 @@ fn write_u16_cache(path: &Path, tokens: &[i32]) {
w.flush().expect("flush cache");
}
fn write_i32_cache(path: &Path, labels: &[i32]) {
let mut w = BufWriter::new(
std::fs::File::create(path)
.unwrap_or_else(|e| panic!("create cache {}: {e}", path.display())),
);
for &t in labels {
w.write_all(&t.to_le_bytes()).expect("write cache");
}
w.flush().expect("flush cache");
}
fn decode_tsv_escapes(s: &str) -> String {
s.replace("\\n", "\n").replace("\\t", "\t")
}
/// Tiny LCG (same constants as the model tests' deterministic fill) so dataset
/// sampling is reproducible from a single u64 seed.
fn next_rand(state: &mut u64) -> u64 {

View File

@@ -207,7 +207,7 @@ pub fn eval_loss(
break;
}
let input: Vec<i32> = valid.tokens[s..s + seq].to_vec();
let target: Vec<i32> = valid.tokens[s + 1..s + seq + 1].to_vec();
let target = valid.target_window(s, seq);
let ids = ids_tensor(&input, device);
let targets = ids_tensor(&target, device);
let loss = model.loss(&ids, &targets);

View File

@@ -216,6 +216,7 @@ fn synth_corpus(vocab: usize, n_tokens: usize) -> Corpus {
tokens: (0..n_tokens)
.map(|i| (i * 7 + 3) as i32 % vocab as i32)
.collect(),
labels: None,
vocab_size: vocab,
}
}

View File

@@ -338,7 +338,7 @@ __global__ void cross_entropy_fwd_k(const float* x, const int* target,
for (int c = threadIdx.x; c < cols; c += blockDim.x) pr[c] *= inv;
if (threadIdx.x == 0) {
int t = target[r];
loss[r] = -logf(pr[t]);
loss[r] = t < 0 ? 0.0f : -logf(pr[t]);
}
}
void launch_cross_entropy_fwd_f32(const float* x, const int* target,
@@ -354,9 +354,14 @@ __global__ void cross_entropy_dx_k(const float* probs, const int* target,
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= rows * cols) return;
int r = i / cols, c = i % cols;
float g = probs[i] - (c == target[r] ? 1.0f : 0.0f);
int t = target[r];
if (t < 0) {
dx[i] = 0.0f;
} else {
float g = probs[i] - (c == t ? 1.0f : 0.0f);
dx[i] = g * scale;
}
}
void launch_cross_entropy_dx_f32(const float* probs, const int* target,
float* dx, int rows, int cols, float scale, void* s) {
int n = rows * cols, blk = 256, grid = (n + blk - 1) / blk;