Enable assistant-only supervised fine-tuning and a fixed chat-prompt eval path used by the v12 SFT runs: - cross_entropy ignores negative targets (-100 ignore-index), normalizing by valid rows instead of all rows; CUDA fwd/bwd skip t<0 (ops.rs, nn.cu). - Corpus gains optional labels + load_sft_tsv_cached: two-column TSV is formatted as 'User: .. \nAssistant:' + answer + <|endoftext|>, prompt tokens masked to -100 while answer+EOS are supervised; i32 label cache alongside the u16 token cache; sample() retries windows that are fully masked; eval uses target_window so masking applies to val loss too (data.rs, train_loop.rs). - train + train_ddp: --sft-tsv selects the TSV loader, --init-ckpt continues training from a base checkpoint. - greedy_sample: --prompts-file/--prompt/--temperature for fixed chat-prompt generation eval. Test fixtures updated for the new Corpus.labels field; dropout.rs carries incidental rustfmt. Not rebuilt locally (no CUDA toolchain on this checkout); correctness rests on the documented v12 base+SFT runs on the GPU box. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
323 lines
11 KiB
Rust
323 lines
11 KiB
Rust
// T18 dropout model-level gates.
//
//   1. p=0 bit-identical: a model built with cfg.dropout=0 (in either train or
//      eval mode) produces logits/loss/grads bit-for-bit identical to the same
//      model with no dropout field touched — the default forward graph is
//      unchanged (the regression guard).
//   2. eval identity: with p>0 but eval mode, the forward equals the p=0 forward
//      bit-for-bit (dropout is OFF at eval).
//   3. train vs eval differ: with p>0 and train mode, the forward differs from
//      eval (dropout actually does something) and grads are still finite.
//   4. recompute compatibility: with p>0 + train + recompute, grads match the
//      non-recompute path (the counter-based seed reproduces the same mask on the
//      backward re-run — T13 stays exact even with dropout in the block).
//
// (The fixed-seed grad-check of the dropout op and the E[out]≈x / keep-rate check
// live in xtrain-autodiff/tests/autograd.rs; p>0 training convergence is the
// dash5 short run noted in docs/17-dropout.md.)
#![cfg(not(no_cuda))]
|
|
|
|
use xtrain_cuda::device;
|
|
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
|
use xtrain_tensor::{DType, Device};
|
|
|
|
/// Returns `n` deterministic pseudo-random values derived from `seed`.
///
/// A 64-bit linear congruential recurrence (wrapping arithmetic) is advanced
/// once per element; the top 31 bits of each state are mapped into roughly
/// `[-scale, scale]`. The same `(n, seed, scale)` always yields the same
/// vector, and a shorter request is a prefix of a longer one.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // Mix the caller's seed into the initial generator state.
    let mut state = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    (0..n)
        .map(|_| {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Top 31 bits -> unit interval, then recentre on 0 and scale.
            (((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
        })
        .collect()
}
|
|
fn build(cfg: Config, device: Device) -> TinyTransformer {
|
|
let mut seed = 1u64;
|
|
TinyTransformer::new(cfg, device, |shape| {
|
|
seed = seed.wrapping_add(1);
|
|
let n: usize = shape.iter().product();
|
|
if shape.len() == 1 {
|
|
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
|
} else {
|
|
fill(n, seed, 0.08)
|
|
}
|
|
})
|
|
}
|
|
|
|
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
|
t.to_dtype(DType::F32)
|
|
.to_device(Device::Cpu)
|
|
.as_slice::<f32>()
|
|
.to_vec()
|
|
}
|
|
|
|
fn tiny_cfg(dropout: f32) -> Config {
|
|
let mut cfg = Config::tiny();
|
|
cfg.vocab = 16;
|
|
cfg.n_layers = 4;
|
|
cfg.dropout = dropout;
|
|
cfg
|
|
}
|
|
|
|
fn batch_data(cfg: &Config, device: Device) -> (xtrain_tensor::Tensor, xtrain_tensor::Tensor) {
|
|
let (batch, seq) = (3usize, 6usize);
|
|
let seqs: Vec<Vec<i32>> = (0..batch)
|
|
.map(|b| {
|
|
(0..seq)
|
|
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
|
|
.collect()
|
|
})
|
|
.collect();
|
|
let tgts: Vec<Vec<i32>> = (0..batch)
|
|
.map(|b| {
|
|
(0..seq)
|
|
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
|
|
.collect()
|
|
})
|
|
.collect();
|
|
(
|
|
batched_ids_tensor(&seqs, device),
|
|
batched_ids_tensor(&tgts, device),
|
|
)
|
|
}
|
|
|
|
fn require_gpu() -> Device {
|
|
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
|
device::set_device(0).unwrap();
|
|
Device::Cuda(0)
|
|
}
|
|
|
|
// Run forward+backward, return (logits, loss, per-param grads).
|
|
fn fwd_bwd(
|
|
m: &TinyTransformer,
|
|
ids: &xtrain_tensor::Tensor,
|
|
tgt: &xtrain_tensor::Tensor,
|
|
batch: usize,
|
|
) -> (Vec<f32>, f32, Vec<Vec<f32>>) {
|
|
let logits = host(&m.forward_batched(ids, batch).value());
|
|
let loss = m.loss_batched(ids, tgt, batch);
|
|
let loss_val = host(&loss.value())[0];
|
|
loss.backward();
|
|
let grads: Vec<Vec<f32>> = m
|
|
.params()
|
|
.iter()
|
|
.map(|p| host(&p.grad().unwrap()))
|
|
.collect();
|
|
(logits, loss_val, grads)
|
|
}
|
|
|
|
// --- Gate 1: p=0 is bit-identical to the no-dropout path (default graph). ---
#[test]
fn dropout_p0_bit_identical() {
    let device = require_gpu();
    let batch = 3;

    // Reference: cfg.dropout default (0.0), never touched train/eval.
    let cfg0 = tiny_cfg(0.0);
    let (ids, tgt) = batch_data(&cfg0, device);
    let ref_m = build(cfg0, device);
    let (ref_logits, ref_loss, ref_grads) = fwd_bwd(&ref_m, &ids, &tgt, batch);

    // p=0 in TRAINING mode: the seed bump is gated on p>0, the op no-ops at p==0,
    // so the graph must be byte-identical.
    let p0_train = build(tiny_cfg(0.0), device);
    p0_train.train();
    let (train_logits, train_loss, train_grads) = fwd_bwd(&p0_train, &ids, &tgt, batch);

    assert_eq!(ref_logits, train_logits, "p=0 train logits not bit-identical");
    assert_eq!(ref_loss, train_loss, "p=0 train loss not bit-identical");
    // `zip` stops at the shorter side, so pin the lengths before comparing.
    assert_eq!(
        ref_grads.len(),
        train_grads.len(),
        "p=0 train grad count differs from reference"
    );
    for (i, (a, b)) in ref_grads.iter().zip(&train_grads).enumerate() {
        assert_eq!(a, b, "p=0 train grad[{i}] not bit-identical");
    }
    println!("p=0 (train) vs no-dropout: logits/loss/grads bit-identical ✅");
}
|
|
// --- Gate 2: eval is exact identity (p>0 but eval mode == p=0). ---
#[test]
fn dropout_eval_is_identity() {
    let device = require_gpu();
    let batch = 3;
    let cfg = tiny_cfg(0.2);
    let (ids, tgt) = batch_data(&cfg, device);

    // p=0 reference and a p=0.2 model held in eval — outputs must match bit-for-bit.
    let ref_m = build(tiny_cfg(0.0), device);
    let (ref_logits, ref_loss, ref_grads) = fwd_bwd(&ref_m, &ids, &tgt, batch);

    let eval_m = build(cfg, device);
    eval_m.eval(); // explicit; also the default
    let (eval_logits, eval_loss, eval_grads) = fwd_bwd(&eval_m, &ids, &tgt, batch);

    assert_eq!(ref_logits, eval_logits, "eval (p>0) logits not identity");
    assert_eq!(ref_loss, eval_loss, "eval (p>0) loss not identity");
    // `zip` stops at the shorter side, so pin the lengths before comparing.
    assert_eq!(
        ref_grads.len(),
        eval_grads.len(),
        "eval (p>0) grad count differs from reference"
    );
    for (i, (a, b)) in ref_grads.iter().zip(&eval_grads).enumerate() {
        assert_eq!(a, b, "eval (p>0) grad[{i}] not identity");
    }
    println!("eval (p=0.2) == no-dropout: bit-identical (eval is identity) ✅");
}
|
|
// --- Gate 3 (train vs eval differ): with p>0 + train, dropout actually fires. ---
// NOTE(review): the file header says this gate also covers "grads are still
// finite", but only the training logits are checked for finiteness here.
#[test]
fn dropout_train_differs_from_eval() {
    let device = require_gpu();
    let batch = 3;
    let cfg = tiny_cfg(0.3);
    let (ids, _tgt) = batch_data(&cfg, device);

    // Same model, same input: one forward in eval mode, one in train mode.
    let m = build(cfg, device);
    m.eval();
    let eval_logits = host(&m.forward_batched(&ids, batch).value());
    m.train();
    let train_logits = host(&m.forward_batched(&ids, batch).value());

    // Checked separately because the `f32::max` fold below skips NaN entries.
    assert!(
        train_logits.iter().all(|v| v.is_finite()),
        "train logits should be finite with dropout active"
    );

    let max_diff = eval_logits
        .iter()
        .zip(&train_logits)
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, f32::max);
    assert!(
        max_diff > 1e-4,
        "train logits should differ from eval (dropout active) and be finite; max_diff={max_diff}"
    );
    println!("train vs eval logits max diff {max_diff:.4e} (dropout active in train) ✅");
}
|
|
// --- Gate 4: p>0 + recompute grads match non-recompute (T13 stays exact). ---
|
|
// The counter-based seed is a pure function of (step_seed, layer, site); the
|
|
// checkpoint backward re-runs block_forward and re-derives the SAME seeds, so the
|
|
// recomputed dropout masks match the forward — grads stay bit-identical.
|
|
fn recompute_with_dropout(dtype: DType, grad_tol: f32) {
|
|
let device = require_gpu();
|
|
let batch = 3;
|
|
let cfg = tiny_cfg(0.2);
|
|
let (ids, tgt) = batch_data(&cfg, device);
|
|
|
|
// Both models: same init, train mode, p=0.2. step_seed starts at 0 and bumps
|
|
// to 1 on the first training forward in BOTH, so they draw the same masks.
|
|
let off = build(cfg, device)
|
|
.with_compute_dtype(dtype)
|
|
.with_training(true);
|
|
let on = build(cfg, device)
|
|
.with_compute_dtype(dtype)
|
|
.with_recompute(true)
|
|
.with_training(true);
|
|
|
|
let off_loss = off.loss_batched(&ids, &tgt, batch);
|
|
off_loss.backward();
|
|
let off_grads: Vec<Vec<f32>> = off
|
|
.params()
|
|
.iter()
|
|
.map(|p| host(&p.grad().unwrap()))
|
|
.collect();
|
|
|
|
let on_loss = on.loss_batched(&ids, &tgt, batch);
|
|
on_loss.backward();
|
|
let on_grads: Vec<Vec<f32>> = on
|
|
.params()
|
|
.iter()
|
|
.map(|p| host(&p.grad().unwrap()))
|
|
.collect();
|
|
|
|
let mut max_rel = 0.0f32;
|
|
for (a, b) in off_grads.iter().flatten().zip(on_grads.iter().flatten()) {
|
|
max_rel = max_rel.max((a - b).abs() / a.abs().max(1e-3));
|
|
}
|
|
println!("[{dtype:?}] dropout p=0.2 recompute on/off grad max rel = {max_rel:.3e}");
|
|
assert!(
|
|
max_rel < grad_tol,
|
|
"[{dtype:?}] recompute grads diverged with dropout: {max_rel:.3e}"
|
|
);
|
|
}
|
|
|
|
// Gate 4 in F32: recompute on/off grads must agree to 1e-4 relative.
#[test]
fn dropout_recompute_matches_fp32() {
    recompute_with_dropout(DType::F32, 1e-4);
}
|
|
// Gate 4 in BF16: same check with a looser 5e-3 relative tolerance.
#[test]
fn dropout_recompute_matches_bf16() {
    recompute_with_dropout(DType::BF16, 5e-3);
}
|
|
// --- Cross-feature gate (Phase-2 integration): flash (T14) + dropout (T18)
// together in the SAME model still grad-checks. Build two identical models, both
// in train mode with p=0.2 (so dropout fires), one with `--flash` on, one off.
// The dropout site seeds are a pure function of (step_seed, layer, site) and are
// INDEPENDENT of flash, so both models draw the SAME masks on their first training
// forward → the only difference is the SDPA reduction order. Assert logits/loss/
// grads match within the flash-vs-composed tolerance and are finite. This is the
// orthogonality check for the two Phase-2 features landing together.
#[test]
fn flash_plus_dropout_grad_check_fp32() {
    let device = require_gpu();
    let batch = 3;
    // Same config as the other gates (vocab=16, n_layers=4) with p=0.2.
    let cfg = tiny_cfg(0.2);
    // seq=40 > FA_TILE=32 exercises flash's online-softmax tile-rescale path.
    let seq = 40usize;
    let seqs: Vec<Vec<i32>> = (0..batch)
        .map(|b| {
            (0..seq)
                .map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
                .collect()
        })
        .collect();
    let tgts: Vec<Vec<i32>> = (0..batch)
        .map(|b| {
            (0..seq)
                .map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
                .collect()
        })
        .collect();
    let ids = batched_ids_tensor(&seqs, device);
    let tgt = batched_ids_tensor(&tgts, device);

    // Both: same init, train mode (dropout active), same step_seed progression →
    // identical masks; one composed SDPA, one flash SDPA.
    let off = build(cfg, device).with_training(true);
    let on = build(cfg, device).with_flash(true).with_training(true);

    let (off_logits, off_loss, off_grads) = fwd_bwd(&off, &ids, &tgt, batch);
    let (on_logits, on_loss, on_grads) = fwd_bwd(&on, &ids, &tgt, batch);

    // The relative-error folds below use `f32::max`, which ignores NaN, so both
    // sides must be proven finite up front (not only the flash side).
    assert!(
        off_logits.iter().all(|v| v.is_finite())
            && off_grads.iter().flatten().all(|v| v.is_finite()),
        "composed+dropout produced non-finite logits/grads"
    );
    assert!(
        on_logits.iter().all(|v| v.is_finite()) && on_grads.iter().flatten().all(|v| v.is_finite()),
        "flash+dropout produced non-finite logits/grads"
    );

    // `zip` stops at the shorter side, so pin the sizes before comparing.
    assert_eq!(
        off_logits.len(),
        on_logits.len(),
        "[F32] flash+dropout logit count differs from composed"
    );
    assert_eq!(
        off_grads.iter().flatten().count(),
        on_grads.iter().flatten().count(),
        "[F32] flash+dropout grad count differs from composed"
    );

    let logit_rel = off_logits
        .iter()
        .zip(&on_logits)
        .map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
        .fold(0.0f32, f32::max);
    let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
    let grad_rel = off_grads
        .iter()
        .flatten()
        .zip(on_grads.iter().flatten())
        .map(|(a, b)| (a - b).abs() / a.abs().max(1e-3))
        .fold(0.0f32, f32::max);
    println!(
        "[F32] flash+dropout vs composed+dropout: loss rel {loss_rel:.2e}, \
         logits max rel {logit_rel:.2e}, grad max rel {grad_rel:.3e}"
    );
    // Same tolerances as the flash-vs-composed gate (flash.rs run_fp32): flash
    // differs from composed only by reduction order; dropout masks are identical.
    assert!(
        logit_rel < 1e-3,
        "[F32] flash+dropout logits diverged: {logit_rel:.2e}"
    );
    assert!(
        loss_rel < 1e-3,
        "[F32] flash+dropout loss diverged: {loss_rel:.2e}"
    );
    assert!(
        grad_rel < 2e-2,
        "[F32] flash+dropout grads diverged: {grad_rel:.3e}"
    );
}