test: flash+dropout cross-feature grad-check (Phase-2 integration)

Add flash_plus_dropout_grad_check_fp32 to xtrain-model dropout tests: the two
orthogonal Phase-2 features (T14 flash-attn, T18 dropout) in the same model must
still grad-check. Both models run train-mode p=0.2 (identical masks, seed is
flash-independent) so the only delta is the SDPA reduction order — checked against
the flash-vs-composed tolerance.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-18 00:43:54 +08:00
parent c36cdf74d1
commit 4b6d3e0a79

View File

@@ -220,3 +220,64 @@ fn dropout_recompute_matches_fp32() {
fn dropout_recompute_matches_bf16() {
recompute_with_dropout(DType::BF16, 5e-3);
}
// --- Cross-feature gate (Phase-2 integration): flash (T14) + dropout (T18)
// together in the SAME model still grad-checks. Build two identical models, both
// in train mode with p=0.2 (so dropout fires), one with `--flash` on, one off.
// The dropout site seeds are a pure function of (step_seed, layer, site) and are
// INDEPENDENT of flash, so both models draw the SAME masks on their first training
// forward → the only difference is the SDPA reduction order. Assert logits/loss/
// grads match within the flash-vs-composed tolerance and are finite. This is the
// orthogonality check for the two Phase-2 features landing together.
#[test]
fn flash_plus_dropout_grad_check_fp32() {
let device = require_gpu();
let batch = 3;
// seq=40 > FA_TILE=32 exercises flash's online-softmax tile-rescale path.
let mut cfg = Config::tiny();
cfg.vocab = 16;
cfg.n_layers = 4;
cfg.dropout = 0.2;
let seq = 40usize;
let seqs: Vec<Vec<i32>> = (0..batch)
.map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect())
.collect();
let tgts: Vec<Vec<i32>> = (0..batch)
.map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect())
.collect();
let ids = batched_ids_tensor(&seqs, device);
let tgt = batched_ids_tensor(&tgts, device);
// Both: same init, train mode (dropout active), same step_seed progression →
// identical masks; one composed SDPA, one flash SDPA.
let off = build(cfg, device).with_training(true);
let on = build(cfg, device).with_flash(true).with_training(true);
let (off_logits, off_loss, off_grads) = fwd_bwd(&off, &ids, &tgt, batch);
let (on_logits, on_loss, on_grads) = fwd_bwd(&on, &ids, &tgt, batch);
assert!(
on_logits.iter().all(|v| v.is_finite()) && on_grads.iter().flatten().all(|v| v.is_finite()),
"flash+dropout produced non-finite logits/grads"
);
let logit_rel = off_logits
.iter()
.zip(&on_logits)
.map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
.fold(0.0f32, f32::max);
let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
let mut grad_rel = 0.0f32;
for (a, b) in off_grads.iter().flatten().zip(on_grads.iter().flatten()) {
grad_rel = grad_rel.max((a - b).abs() / a.abs().max(1e-3));
}
println!(
"[F32] flash+dropout vs composed+dropout: loss rel {loss_rel:.2e}, \
logits max rel {logit_rel:.2e}, grad max rel {grad_rel:.3e}"
);
// Same tolerances as the flash-vs-composed gate (flash.rs run_fp32): flash
// differs from composed only by reduction order; dropout masks are identical.
assert!(logit_rel < 1e-3, "[F32] flash+dropout logits diverged: {logit_rel:.2e}");
assert!(loss_rel < 1e-3, "[F32] flash+dropout loss diverged: {loss_rel:.2e}");
assert!(grad_rel < 2e-2, "[F32] flash+dropout grads diverged: {grad_rel:.3e}");
}