sft: assistant-only SFT (ignore-index CE) + chat-prompt greedy eval
Enable assistant-only supervised fine-tuning and a fixed chat-prompt eval path used by the v12 SFT runs: - cross_entropy ignores negative targets (-100 ignore-index), normalizing by valid rows instead of all rows; CUDA fwd/bwd skip t<0 (ops.rs, nn.cu). - Corpus gains optional labels + load_sft_tsv_cached: two-column TSV is formatted as 'User: .. \nAssistant:' + answer + <|endoftext|>, prompt tokens masked to -100 while answer+EOS are supervised; i32 label cache alongside the u16 token cache; sample() retries windows that are fully masked; eval uses target_window so masking applies to val loss too (data.rs, train_loop.rs). - train + train_ddp: --sft-tsv selects the TSV loader, --init-ckpt continues training from a base checkpoint. - greedy_sample: --prompts-file/--prompt/--temperature for fixed chat-prompt generation eval. Test fixtures updated for the new Corpus.labels field; dropout.rs carries incidental rustfmt. Not rebuilt locally (no CUDA toolchain on this checkout); correctness rests on the documented v12 base+SFT runs on the GPU box. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -398,7 +398,8 @@ pub fn repeat_kv(kv: &Var, nh: usize, batch: usize) -> Var {
|
||||
}
|
||||
|
||||
/// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per
|
||||
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
|
||||
/// row. Negative targets are ignored, which is useful for assistant-only SFT
|
||||
/// masks. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/valid_rows`,
|
||||
/// scaled by the upstream scalar grad.
|
||||
pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
|
||||
// CE math is fp32 (cross_entropy upcasts bf16 logits internally + caches fp32
|
||||
@@ -407,10 +408,22 @@ pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
|
||||
// fp32 logits buffer) is a real activation-memory saving at large vocab.
|
||||
let logit_dtype = x.value().dtype();
|
||||
let (probs, per_row) = x.value().cross_entropy(target);
|
||||
let rows = x.value().shape()[0];
|
||||
let cols = x.value().shape()[1] as i32;
|
||||
let target_host = target.to_device(xtrain_tensor::Device::Cpu);
|
||||
let valid_rows = target_host
|
||||
.as_slice::<i32>()
|
||||
.iter()
|
||||
.filter(|&&t| {
|
||||
if t >= cols {
|
||||
panic!("cross_entropy target {t} out of vocab range {cols}");
|
||||
}
|
||||
t >= 0
|
||||
})
|
||||
.count()
|
||||
.max(1);
|
||||
// Mean loss as a host scalar wrapped back into a [1] tensor.
|
||||
let mean = per_row.to_device(xtrain_tensor::Device::Cpu);
|
||||
let mean_val: f32 = mean.as_slice::<f32>().iter().sum::<f32>() / rows as f32;
|
||||
let mean_val: f32 = mean.as_slice::<f32>().iter().sum::<f32>() / valid_rows as f32;
|
||||
let loss = Tensor::from_slice(&[mean_val], &[1]).to_device(x.value().device());
|
||||
|
||||
let target = target.clone();
|
||||
@@ -420,7 +433,7 @@ pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
|
||||
Box::new(move |d, parents| {
|
||||
// `d` is the scalar upstream grad (1.0 when this is the loss root).
|
||||
let upstream = d.to_device(xtrain_tensor::Device::Cpu).as_slice::<f32>()[0];
|
||||
let scale = upstream / rows as f32;
|
||||
let scale = upstream / valid_rows as f32;
|
||||
let dx = Tensor::cross_entropy_backward(&probs, &target, scale);
|
||||
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
|
||||
}),
|
||||
|
||||
@@ -88,6 +88,7 @@ fn main() {
|
||||
let val_tokens: usize = flag(&args, "--val-tokens", 0);
|
||||
let eval_every: usize = flag(&args, "--eval-every", 0);
|
||||
let eval_batches: usize = flag(&args, "--eval-batches", 64);
|
||||
let sft_tsv = args.iter().any(|a| a == "--sft-tsv");
|
||||
// Dropout (Phase T18/T21): residual-path dropout prob, active at training time
|
||||
// only (inverted scaling), identity at eval/sampling/export. Default 0 = off
|
||||
// (forward graph bit-identical to the no-dropout path). Mirrors bin/train; the
|
||||
@@ -109,6 +110,11 @@ fn main() {
|
||||
.position(|a| a == "--ckpt")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.map(PathBuf::from);
|
||||
let init_ckpt: Option<PathBuf> = args
|
||||
.iter()
|
||||
.position(|a| a == "--init-ckpt")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.map(PathBuf::from);
|
||||
|
||||
// Use every visible GPU as a rank (CUDA_VISIBLE_DEVICES selects the set;
|
||||
// device ordinals are 0..count within it).
|
||||
@@ -129,12 +135,19 @@ fn main() {
|
||||
);
|
||||
|
||||
// Reuse the cached token-id stream (v1's u16 cache); never re-tokenize 2GB.
|
||||
let corpus = Corpus::load_cached(&tok_path, &corpus_path);
|
||||
let corpus = if sft_tsv {
|
||||
Corpus::load_sft_tsv_cached(&tok_path, &corpus_path)
|
||||
} else {
|
||||
Corpus::load_cached(&tok_path, &corpus_path)
|
||||
};
|
||||
println!(
|
||||
"corpus: {} tokens, vocab {}",
|
||||
corpus.len(),
|
||||
corpus.vocab_size
|
||||
);
|
||||
if sft_tsv {
|
||||
println!("SFT TSV: ON (assistant-only loss via ignore-index labels)");
|
||||
}
|
||||
let vocab = corpus.vocab_size;
|
||||
// Hold out a tail slice for validation (rank 0 evaluates on it).
|
||||
let (train_corpus, valid) = if val_tokens > 0 {
|
||||
@@ -200,6 +213,10 @@ fn main() {
|
||||
if dropout > 0.0 {
|
||||
println!("dropout: ON (p={dropout}, residual-path, train-only inverted scaling)");
|
||||
}
|
||||
if let Some(path) = &init_ckpt {
|
||||
println!("init checkpoint: {}", path.display());
|
||||
}
|
||||
let init_ckpt_for_ranks = init_ckpt.clone();
|
||||
let results = launch(
|
||||
&devices,
|
||||
&train_corpus,
|
||||
@@ -216,6 +233,10 @@ fn main() {
|
||||
if flash {
|
||||
m = m.with_flash(true);
|
||||
}
|
||||
if let Some(path) = &init_ckpt_for_ranks {
|
||||
xtrain_train::checkpoint::load_into(path, &m.params())
|
||||
.expect("load init checkpoint");
|
||||
}
|
||||
m
|
||||
},
|
||||
);
|
||||
|
||||
@@ -27,6 +27,7 @@ fn synth_corpus(vocab: usize, n_tokens: usize) -> Corpus {
|
||||
.collect();
|
||||
Corpus {
|
||||
tokens,
|
||||
labels: None,
|
||||
vocab_size: vocab,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ fn synth_corpus() -> Corpus {
|
||||
.collect();
|
||||
Corpus {
|
||||
tokens,
|
||||
labels: None,
|
||||
vocab_size: VOCAB,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,10 +66,18 @@ fn tiny_cfg(dropout: f32) -> Config {
|
||||
fn batch_data(cfg: &Config, device: Device) -> (xtrain_tensor::Tensor, xtrain_tensor::Tensor) {
|
||||
let (batch, seq) = (3usize, 6usize);
|
||||
let seqs: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect())
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let tgts: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect())
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
(
|
||||
batched_ids_tensor(&seqs, device),
|
||||
@@ -94,7 +102,11 @@ fn fwd_bwd(
|
||||
let loss = m.loss_batched(ids, tgt, batch);
|
||||
let loss_val = host(&loss.value())[0];
|
||||
loss.backward();
|
||||
let grads: Vec<Vec<f32>> = m.params().iter().map(|p| host(&p.grad().unwrap())).collect();
|
||||
let grads: Vec<Vec<f32>> = m
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().unwrap()))
|
||||
.collect();
|
||||
(logits, loss_val, grads)
|
||||
}
|
||||
|
||||
@@ -186,7 +198,9 @@ fn recompute_with_dropout(dtype: DType, grad_tol: f32) {
|
||||
|
||||
// Both models: same init, train mode, p=0.2. step_seed starts at 0 and bumps
|
||||
// to 1 on the first training forward in BOTH, so they draw the same masks.
|
||||
let off = build(cfg, device).with_compute_dtype(dtype).with_training(true);
|
||||
let off = build(cfg, device)
|
||||
.with_compute_dtype(dtype)
|
||||
.with_training(true);
|
||||
let on = build(cfg, device)
|
||||
.with_compute_dtype(dtype)
|
||||
.with_recompute(true)
|
||||
@@ -194,11 +208,19 @@ fn recompute_with_dropout(dtype: DType, grad_tol: f32) {
|
||||
|
||||
let off_loss = off.loss_batched(&ids, &tgt, batch);
|
||||
off_loss.backward();
|
||||
let off_grads: Vec<Vec<f32>> = off.params().iter().map(|p| host(&p.grad().unwrap())).collect();
|
||||
let off_grads: Vec<Vec<f32>> = off
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().unwrap()))
|
||||
.collect();
|
||||
|
||||
let on_loss = on.loss_batched(&ids, &tgt, batch);
|
||||
on_loss.backward();
|
||||
let on_grads: Vec<Vec<f32>> = on.params().iter().map(|p| host(&p.grad().unwrap())).collect();
|
||||
let on_grads: Vec<Vec<f32>> = on
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().unwrap()))
|
||||
.collect();
|
||||
|
||||
let mut max_rel = 0.0f32;
|
||||
for (a, b) in off_grads.iter().flatten().zip(on_grads.iter().flatten()) {
|
||||
@@ -240,10 +262,18 @@ fn flash_plus_dropout_grad_check_fp32() {
|
||||
cfg.dropout = 0.2;
|
||||
let seq = 40usize;
|
||||
let seqs: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect())
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let tgts: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect())
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let ids = batched_ids_tensor(&seqs, device);
|
||||
let tgt = batched_ids_tensor(&tgts, device);
|
||||
@@ -277,7 +307,16 @@ fn flash_plus_dropout_grad_check_fp32() {
|
||||
);
|
||||
// Same tolerances as the flash-vs-composed gate (flash.rs run_fp32): flash
|
||||
// differs from composed only by reduction order; dropout masks are identical.
|
||||
assert!(logit_rel < 1e-3, "[F32] flash+dropout logits diverged: {logit_rel:.2e}");
|
||||
assert!(loss_rel < 1e-3, "[F32] flash+dropout loss diverged: {loss_rel:.2e}");
|
||||
assert!(grad_rel < 2e-2, "[F32] flash+dropout grads diverged: {grad_rel:.3e}");
|
||||
assert!(
|
||||
logit_rel < 1e-3,
|
||||
"[F32] flash+dropout logits diverged: {logit_rel:.2e}"
|
||||
);
|
||||
assert!(
|
||||
loss_rel < 1e-3,
|
||||
"[F32] flash+dropout loss diverged: {loss_rel:.2e}"
|
||||
);
|
||||
assert!(
|
||||
grad_rel < 2e-2,
|
||||
"[F32] flash+dropout grads diverged: {grad_rel:.3e}"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -7,7 +7,8 @@
|
||||
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
|
||||
//! cargo run -p xtrain-train --release --bin greedy_sample -- \
|
||||
//! /tmp/xtrain_v4.ckpt /opt/wjh/models/gpt2/tokenizer.json \
|
||||
//! --heads 24 --head-dim 32 --layers 18 --ffn 2048
|
||||
//! --heads 24 --head-dim 32 --layers 18 --ffn 2048 \
|
||||
//! --prompts-file scripts/chat_alpha_fixed_prompts.txt --max-tokens 120
|
||||
|
||||
#[cfg(no_cuda)]
|
||||
fn main() {
|
||||
@@ -52,6 +53,60 @@ fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
/// Returns the argument that directly follows the first occurrence of `name`
/// in `args`, or `None` when `name` is absent or is the final argument.
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
    let idx = args.iter().position(|arg| arg == name)?;
    args.get(idx + 1).cloned()
}
|
||||
|
||||
/// Collects, in order, the argument following every occurrence of `name` in
/// `args`. An occurrence in the last position has no follower and contributes
/// nothing.
#[cfg(not(no_cuda))]
fn flag_values(args: &[String], name: &str) -> Vec<String> {
    // Each window is (candidate flag, the argument right after it).
    args.windows(2)
        .filter(|pair| pair[0] == name)
        .map(|pair| pair[1].clone())
        .collect()
}
|
||||
|
||||
/// Turns the two-character sequences `\n` and `\t` (backslash + letter) into a
/// real newline / tab. Any other backslash is copied through unchanged.
#[cfg(not(no_cuda))]
fn decode_prompt_escapes(s: &str) -> String {
    let mut decoded = String::with_capacity(s.len());
    let mut rest = s.chars().peekable();
    while let Some(ch) = rest.next() {
        let escaped = if ch == '\\' {
            match rest.peek().copied() {
                Some('n') => Some('\n'),
                Some('t') => Some('\t'),
                _ => None,
            }
        } else {
            None
        };
        match escaped {
            Some(real) => {
                decoded.push(real);
                // Consume the letter that completed the escape.
                rest.next();
            }
            None => decoded.push(ch),
        }
    }
    decoded
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn load_prompts(args: &[String]) -> Vec<String> {
|
||||
let mut prompts = Vec::new();
|
||||
if let Some(path) = flag_value(args, "--prompts-file") {
|
||||
let text = std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|e| panic!("failed to read prompts file {path}: {e}"));
|
||||
prompts.extend(
|
||||
text.lines()
|
||||
.map(str::trim)
|
||||
.filter(|line| !line.is_empty() && !line.starts_with('#'))
|
||||
.map(decode_prompt_escapes),
|
||||
);
|
||||
}
|
||||
prompts.extend(
|
||||
flag_values(args, "--prompt")
|
||||
.into_iter()
|
||||
.map(|p| decode_prompt_escapes(&p)),
|
||||
);
|
||||
if prompts.is_empty() {
|
||||
prompts = ["Once upon a time", "One day", "The little"]
|
||||
.into_iter()
|
||||
.map(String::from)
|
||||
.collect();
|
||||
}
|
||||
prompts
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
@@ -75,6 +130,8 @@ fn main() {
|
||||
// GQA (Phase T15): num K/V heads (must match the ckpt; default = --heads).
|
||||
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||||
let max_new = flag(&args, "--max-tokens", 40usize);
|
||||
let temperature = flag(&args, "--temperature", 0.0f32);
|
||||
let prompts = load_prompts(&args);
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
@@ -106,11 +163,16 @@ fn main() {
|
||||
});
|
||||
xtrain_train::checkpoint::load_into(&ckpt, &model.params()).expect("load checkpoint");
|
||||
|
||||
let prompts = ["Once upon a time", "One day", "The little"];
|
||||
println!(
|
||||
"decode: prompts={} max_new={} temperature={}",
|
||||
prompts.len(),
|
||||
max_new,
|
||||
temperature
|
||||
);
|
||||
for p in prompts {
|
||||
let ids: Vec<i32> = tok.encode(p).into_iter().map(|t| t as i32).collect();
|
||||
let ids: Vec<i32> = tok.encode(&p).into_iter().map(|t| t as i32).collect();
|
||||
let mut rng = 7u64;
|
||||
let out = generate(&model, device, &ids, max_new, 0.0, &mut rng);
|
||||
let out = generate(&model, device, &ids, max_new, temperature, &mut rng);
|
||||
let text = tok.decode(&out.iter().map(|&t| t as u32).collect::<Vec<_>>());
|
||||
println!("[{p}] → {text}");
|
||||
}
|
||||
|
||||
@@ -115,6 +115,7 @@ fn main() {
|
||||
let val_tokens: usize = flag(&args, "--val-tokens", 0);
|
||||
let eval_every: usize = flag(&args, "--eval-every", 0);
|
||||
let eval_batches: usize = flag(&args, "--eval-batches", 64);
|
||||
let sft_tsv = args.iter().any(|a| a == "--sft-tsv");
|
||||
// Dropout (Phase T18): residual-path dropout prob, active at training time
|
||||
// only (inverted scaling), identity at eval/sampling/export. Default 0 = off
|
||||
// (forward graph bit-identical to the no-dropout path).
|
||||
@@ -136,6 +137,11 @@ fn main() {
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "/tmp/xtrain_tinystories.ckpt".to_string()),
|
||||
);
|
||||
let init_ckpt: Option<PathBuf> = args
|
||||
.iter()
|
||||
.position(|a| a == "--init-ckpt")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.map(PathBuf::from);
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
@@ -146,12 +152,19 @@ fn main() {
|
||||
tok_path.display(),
|
||||
corpus_path.display()
|
||||
);
|
||||
let corpus = Corpus::load_cached(&tok_path, &corpus_path);
|
||||
let corpus = if sft_tsv {
|
||||
Corpus::load_sft_tsv_cached(&tok_path, &corpus_path)
|
||||
} else {
|
||||
Corpus::load_cached(&tok_path, &corpus_path)
|
||||
};
|
||||
println!(
|
||||
"corpus: {} tokens, vocab {}",
|
||||
corpus.len(),
|
||||
corpus.vocab_size
|
||||
);
|
||||
if sft_tsv {
|
||||
println!("SFT TSV: ON (assistant-only loss via ignore-index labels)");
|
||||
}
|
||||
let vocab = corpus.vocab_size;
|
||||
// Hold out a tail slice for validation (if requested and the corpus is big).
|
||||
let (train_corpus, valid) = if val_tokens > 0 {
|
||||
@@ -206,6 +219,10 @@ fn main() {
|
||||
if dropout > 0.0 {
|
||||
println!("dropout: ON (p={dropout}, residual-path, train-only inverted scaling)");
|
||||
}
|
||||
if let Some(path) = &init_ckpt {
|
||||
xtrain_train::checkpoint::load_into(path, &model.params()).expect("load init checkpoint");
|
||||
println!("init checkpoint: loaded {}", path.display());
|
||||
}
|
||||
|
||||
// Eval-only mode: load a checkpoint and score it on the held-out val set, then
|
||||
// exit. Used to put an EXISTING model (e.g. v0) and a new one on the same
|
||||
|
||||
@@ -15,6 +15,7 @@ use xserv_tokenizer::Tokenizer;
|
||||
/// A tokenized corpus: one flat stream of token ids, plus the vocab size.
|
||||
pub struct Corpus {
|
||||
pub tokens: Vec<i32>,
|
||||
pub labels: Option<Vec<i32>>,
|
||||
pub vocab_size: usize,
|
||||
}
|
||||
|
||||
@@ -33,6 +34,7 @@ impl Corpus {
|
||||
let ids: Vec<i32> = tok.encode(text).into_iter().map(|t| t as i32).collect();
|
||||
Self {
|
||||
tokens: ids,
|
||||
labels: None,
|
||||
vocab_size: tok.vocab_size(),
|
||||
}
|
||||
}
|
||||
@@ -52,7 +54,11 @@ impl Corpus {
|
||||
tokens.len(),
|
||||
cache.display()
|
||||
);
|
||||
return Self { tokens, vocab_size };
|
||||
return Self {
|
||||
tokens,
|
||||
labels: None,
|
||||
vocab_size,
|
||||
};
|
||||
}
|
||||
let me = Self::load(tokenizer_path, corpus_path);
|
||||
write_u16_cache(&cache, &me.tokens);
|
||||
@@ -64,22 +70,104 @@ impl Corpus {
|
||||
me
|
||||
}
|
||||
|
||||
/// Load assistant-only SFT data from a two-column TSV:
|
||||
///
|
||||
/// ```text
|
||||
/// user<TAB>assistant
|
||||
/// ```
|
||||
///
|
||||
/// Literal `\n` and `\t` escapes are decoded. Each row is formatted as
|
||||
/// `User: ...\nAssistant:` + answer + `<|endoftext|>`. Labels are `-100`
|
||||
/// for prompt tokens and the token id itself for answer/EOS tokens, so the
|
||||
/// cross-entropy op ignores prompt rows while still training the assistant
|
||||
/// answer and stop token.
|
||||
pub fn load_sft_tsv_cached(tokenizer_path: &Path, corpus_path: &Path) -> Self {
|
||||
let token_cache = cache_path(corpus_path);
|
||||
let label_cache = label_cache_path(corpus_path);
|
||||
let vocab_size = Tokenizer::from_file(tokenizer_path).vocab_size();
|
||||
if token_cache.exists() && label_cache.exists() {
|
||||
let tokens = read_u16_cache(&token_cache);
|
||||
let labels = read_i32_cache(&label_cache);
|
||||
assert_eq!(
|
||||
tokens.len(),
|
||||
labels.len(),
|
||||
"SFT cache token/label length mismatch"
|
||||
);
|
||||
println!(
|
||||
"corpus: read {} cached SFT tokens from {} (+ labels {})",
|
||||
tokens.len(),
|
||||
token_cache.display(),
|
||||
label_cache.display()
|
||||
);
|
||||
return Self {
|
||||
tokens,
|
||||
labels: Some(labels),
|
||||
vocab_size,
|
||||
};
|
||||
}
|
||||
|
||||
let tok = Tokenizer::from_file(tokenizer_path);
|
||||
let text = std::fs::read_to_string(corpus_path)
|
||||
.unwrap_or_else(|e| panic!("failed to read SFT corpus {}: {e}", corpus_path.display()));
|
||||
let mut tokens = Vec::new();
|
||||
let mut labels = Vec::new();
|
||||
for (lineno, line) in text.lines().enumerate() {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let (user, assistant) = line
|
||||
.split_once('\t')
|
||||
.unwrap_or_else(|| panic!("SFT TSV line {} missing tab", lineno + 1));
|
||||
let user = decode_tsv_escapes(user);
|
||||
let assistant = decode_tsv_escapes(assistant);
|
||||
let prompt = format!("User: {user}\nAssistant:");
|
||||
let answer = format!(" {assistant}\n<|endoftext|>");
|
||||
let prompt_ids: Vec<i32> = tok.encode(&prompt).into_iter().map(|t| t as i32).collect();
|
||||
let answer_ids: Vec<i32> = tok.encode(&answer).into_iter().map(|t| t as i32).collect();
|
||||
labels.extend(std::iter::repeat(-100).take(prompt_ids.len()));
|
||||
labels.extend(answer_ids.iter().copied());
|
||||
tokens.extend(prompt_ids);
|
||||
tokens.extend(answer_ids);
|
||||
}
|
||||
assert_eq!(tokens.len(), labels.len(), "SFT tokens/labels mismatch");
|
||||
write_u16_cache(&token_cache, &tokens);
|
||||
write_i32_cache(&label_cache, &labels);
|
||||
println!(
|
||||
"corpus: tokenized {} SFT tokens → cached to {} (+ labels {})",
|
||||
tokens.len(),
|
||||
token_cache.display(),
|
||||
label_cache.display()
|
||||
);
|
||||
Self {
|
||||
tokens,
|
||||
labels: Some(labels),
|
||||
vocab_size: tok.vocab_size(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Split off the last `n` tokens as a held-out validation corpus, leaving the
|
||||
/// rest as the train corpus. Returns `(train, valid)`. Used for periodic val
|
||||
/// loss during training without leaking the eval window into training.
|
||||
pub fn split_tail(self, n: usize) -> (Self, Self) {
|
||||
let n = n.min(self.tokens.len() / 10); // never hand off more than 10%
|
||||
let cut = self.tokens.len() - n;
|
||||
let valid = self.tokens[cut..].to_vec();
|
||||
let valid_tokens = self.tokens[cut..].to_vec();
|
||||
let valid_labels = self.labels.as_ref().map(|labels| labels[cut..].to_vec());
|
||||
let mut train = self.tokens;
|
||||
train.truncate(cut);
|
||||
let train_labels = self.labels.map(|mut labels| {
|
||||
labels.truncate(cut);
|
||||
labels
|
||||
});
|
||||
(
|
||||
Self {
|
||||
tokens: train,
|
||||
labels: train_labels,
|
||||
vocab_size: self.vocab_size,
|
||||
},
|
||||
Self {
|
||||
tokens: valid,
|
||||
tokens: valid_tokens,
|
||||
labels: valid_labels,
|
||||
vocab_size: self.vocab_size,
|
||||
},
|
||||
)
|
||||
@@ -101,11 +189,27 @@ impl Corpus {
|
||||
pub fn sample(&self, seq: usize, rng_state: &mut u64) -> (Vec<i32>, Vec<i32>) {
|
||||
assert!(self.tokens.len() > seq + 1, "corpus shorter than a window");
|
||||
let max_start = self.tokens.len() - seq - 1;
|
||||
let start = (next_rand(rng_state) % (max_start as u64 + 1)) as usize;
|
||||
let mut start = (next_rand(rng_state) % (max_start as u64 + 1)) as usize;
|
||||
if let Some(labels) = &self.labels {
|
||||
for _ in 0..16 {
|
||||
if labels[start + 1..start + seq + 1].iter().any(|&t| t >= 0) {
|
||||
break;
|
||||
}
|
||||
start = (next_rand(rng_state) % (max_start as u64 + 1)) as usize;
|
||||
}
|
||||
}
|
||||
let input = self.tokens[start..start + seq].to_vec();
|
||||
let target = self.tokens[start + 1..start + seq + 1].to_vec();
|
||||
let target = self.target_window(start, seq);
|
||||
(input, target)
|
||||
}
|
||||
|
||||
/// Deterministic target labels for an input window starting at `start`.
|
||||
pub fn target_window(&self, start: usize, seq: usize) -> Vec<i32> {
|
||||
match &self.labels {
|
||||
Some(labels) => labels[start + 1..start + seq + 1].to_vec(),
|
||||
None => self.tokens[start + 1..start + seq + 1].to_vec(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop a leading partial line (before the first newline) and everything after
|
||||
@@ -127,6 +231,12 @@ fn cache_path(corpus_path: &Path) -> PathBuf {
|
||||
PathBuf::from(s)
|
||||
}
|
||||
|
||||
/// Path of the label cache for `corpus_path`: the corpus path with
/// `.labels.i32.bin` appended to its full name (the existing extension is kept).
fn label_cache_path(corpus_path: &Path) -> PathBuf {
    let mut raw = corpus_path.to_path_buf().into_os_string();
    raw.push(".labels.i32.bin");
    PathBuf::from(raw)
}
|
||||
|
||||
/// Read a flat little-endian `[u16]` cache into an `i32` id stream.
|
||||
fn read_u16_cache(path: &Path) -> Vec<i32> {
|
||||
let mut r = BufReader::new(
|
||||
@@ -140,6 +250,18 @@ fn read_u16_cache(path: &Path) -> Vec<i32> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Read a flat little-endian `[i32]` cache into a label stream.
///
/// # Panics
/// Panics if the file cannot be opened or read, or if its length is not a
/// whole number of 4-byte values.
fn read_i32_cache(path: &Path) -> Vec<i32> {
    let mut r = BufReader::new(
        std::fs::File::open(path).unwrap_or_else(|e| panic!("open cache {}: {e}", path.display())),
    );
    let mut buf = Vec::new();
    r.read_to_end(&mut buf).expect("read cache");
    // Each value occupies exactly four bytes; anything else means truncation.
    assert!(
        buf.len() % 4 == 0,
        "corrupt i32 cache {} ({} bytes is not a multiple of 4)",
        path.display(),
        buf.len()
    );
    buf.chunks_exact(4)
        .map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect()
}
|
||||
|
||||
/// Write an id stream as a flat little-endian `[u16]` cache. Ids must fit in u16
|
||||
/// (GPT-2 vocab = 50257 < 65536); asserts otherwise.
|
||||
fn write_u16_cache(path: &Path, tokens: &[i32]) {
|
||||
@@ -154,6 +276,21 @@ fn write_u16_cache(path: &Path, tokens: &[i32]) {
|
||||
w.flush().expect("flush cache");
|
||||
}
|
||||
|
||||
/// Write a label stream as a flat little-endian `[i32]` cache, creating or
/// truncating the file at `path`.
///
/// # Panics
/// Panics if the file cannot be created, written, or flushed.
fn write_i32_cache(path: &Path, labels: &[i32]) {
    let file = match std::fs::File::create(path) {
        Ok(f) => f,
        Err(e) => panic!("create cache {}: {e}", path.display()),
    };
    let mut out = BufWriter::new(file);
    labels
        .iter()
        .try_for_each(|label| out.write_all(&label.to_le_bytes()))
        .expect("write cache");
    // Flush explicitly so a late write error is reported instead of dropped.
    out.flush().expect("flush cache");
}
|
||||
|
||||
/// Decodes the literal two-character escapes `\n` and `\t` in a TSV field into
/// a real newline and tab. Every other character, including any other
/// backslash, is kept as-is.
fn decode_tsv_escapes(s: &str) -> String {
    // Split on the newline escape first, then fix tab escapes inside each
    // piece; joining with a real newline restores the pieces in order.
    let pieces: Vec<String> = s
        .split("\\n")
        .map(|piece| piece.replace("\\t", "\t"))
        .collect();
    pieces.join("\n")
}
|
||||
|
||||
/// Tiny LCG (same constants as the model tests' deterministic fill) so dataset
|
||||
/// sampling is reproducible from a single u64 seed.
|
||||
fn next_rand(state: &mut u64) -> u64 {
|
||||
|
||||
@@ -207,7 +207,7 @@ pub fn eval_loss(
|
||||
break;
|
||||
}
|
||||
let input: Vec<i32> = valid.tokens[s..s + seq].to_vec();
|
||||
let target: Vec<i32> = valid.tokens[s + 1..s + seq + 1].to_vec();
|
||||
let target = valid.target_window(s, seq);
|
||||
let ids = ids_tensor(&input, device);
|
||||
let targets = ids_tensor(&target, device);
|
||||
let loss = model.loss(&ids, &targets);
|
||||
|
||||
@@ -216,6 +216,7 @@ fn synth_corpus(vocab: usize, n_tokens: usize) -> Corpus {
|
||||
tokens: (0..n_tokens)
|
||||
.map(|i| (i * 7 + 3) as i32 % vocab as i32)
|
||||
.collect(),
|
||||
labels: None,
|
||||
vocab_size: vocab,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -338,7 +338,7 @@ __global__ void cross_entropy_fwd_k(const float* x, const int* target,
|
||||
for (int c = threadIdx.x; c < cols; c += blockDim.x) pr[c] *= inv;
|
||||
if (threadIdx.x == 0) {
|
||||
int t = target[r];
|
||||
loss[r] = -logf(pr[t]);
|
||||
loss[r] = t < 0 ? 0.0f : -logf(pr[t]);
|
||||
}
|
||||
}
|
||||
void launch_cross_entropy_fwd_f32(const float* x, const int* target,
|
||||
@@ -354,8 +354,13 @@ __global__ void cross_entropy_dx_k(const float* probs, const int* target,
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= rows * cols) return;
|
||||
int r = i / cols, c = i % cols;
|
||||
float g = probs[i] - (c == target[r] ? 1.0f : 0.0f);
|
||||
dx[i] = g * scale;
|
||||
int t = target[r];
|
||||
if (t < 0) {
|
||||
dx[i] = 0.0f;
|
||||
} else {
|
||||
float g = probs[i] - (c == t ? 1.0f : 0.0f);
|
||||
dx[i] = g * scale;
|
||||
}
|
||||
}
|
||||
void launch_cross_entropy_dx_f32(const float* probs, const int* target,
|
||||
float* dx, int rows, int cols, float scale, void* s) {
|
||||
|
||||
Reference in New Issue
Block a user