test: flash+dropout cross-feature grad-check (Phase-2 integration)
Add flash_plus_dropout_grad_check_fp32 to xtrain-model dropout tests: the two orthogonal Phase-2 features (T14 flash-attn, T18 dropout) in the same model must still grad-check. Both models run train-mode p=0.2 (identical masks, seed is flash-independent) so the only delta is the SDPA reduction order — checked against the flash-vs-composed tolerance. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -220,3 +220,64 @@ fn dropout_recompute_matches_fp32() {
|
||||
fn dropout_recompute_matches_bf16() {
|
||||
recompute_with_dropout(DType::BF16, 5e-3);
|
||||
}
|
||||
|
||||
// --- Cross-feature gate (Phase-2 integration): flash (T14) + dropout (T18)
|
||||
// together in the SAME model still grad-checks. Build two identical models, both
|
||||
// in train mode with p=0.2 (so dropout fires), one with `--flash` on, one off.
|
||||
// The dropout site seeds are a pure function of (step_seed, layer, site) and are
|
||||
// INDEPENDENT of flash, so both models draw the SAME masks on their first training
|
||||
// forward → the only difference is the SDPA reduction order. Assert logits/loss/
|
||||
// grads match within the flash-vs-composed tolerance and are finite. This is the
|
||||
// orthogonality check for the two Phase-2 features landing together.
|
||||
#[test]
|
||||
fn flash_plus_dropout_grad_check_fp32() {
|
||||
let device = require_gpu();
|
||||
let batch = 3;
|
||||
// seq=40 > FA_TILE=32 exercises flash's online-softmax tile-rescale path.
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 16;
|
||||
cfg.n_layers = 4;
|
||||
cfg.dropout = 0.2;
|
||||
let seq = 40usize;
|
||||
let seqs: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect())
|
||||
.collect();
|
||||
let tgts: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect())
|
||||
.collect();
|
||||
let ids = batched_ids_tensor(&seqs, device);
|
||||
let tgt = batched_ids_tensor(&tgts, device);
|
||||
|
||||
// Both: same init, train mode (dropout active), same step_seed progression →
|
||||
// identical masks; one composed SDPA, one flash SDPA.
|
||||
let off = build(cfg, device).with_training(true);
|
||||
let on = build(cfg, device).with_flash(true).with_training(true);
|
||||
|
||||
let (off_logits, off_loss, off_grads) = fwd_bwd(&off, &ids, &tgt, batch);
|
||||
let (on_logits, on_loss, on_grads) = fwd_bwd(&on, &ids, &tgt, batch);
|
||||
|
||||
assert!(
|
||||
on_logits.iter().all(|v| v.is_finite()) && on_grads.iter().flatten().all(|v| v.is_finite()),
|
||||
"flash+dropout produced non-finite logits/grads"
|
||||
);
|
||||
|
||||
let logit_rel = off_logits
|
||||
.iter()
|
||||
.zip(&on_logits)
|
||||
.map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
|
||||
.fold(0.0f32, f32::max);
|
||||
let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
|
||||
let mut grad_rel = 0.0f32;
|
||||
for (a, b) in off_grads.iter().flatten().zip(on_grads.iter().flatten()) {
|
||||
grad_rel = grad_rel.max((a - b).abs() / a.abs().max(1e-3));
|
||||
}
|
||||
println!(
|
||||
"[F32] flash+dropout vs composed+dropout: loss rel {loss_rel:.2e}, \
|
||||
logits max rel {logit_rel:.2e}, grad max rel {grad_rel:.3e}"
|
||||
);
|
||||
// Same tolerances as the flash-vs-composed gate (flash.rs run_fp32): flash
|
||||
// differs from composed only by reduction order; dropout masks are identical.
|
||||
assert!(logit_rel < 1e-3, "[F32] flash+dropout logits diverged: {logit_rel:.2e}");
|
||||
assert!(loss_rel < 1e-3, "[F32] flash+dropout loss diverged: {loss_rel:.2e}");
|
||||
assert!(grad_rel < 2e-2, "[F32] flash+dropout grads diverged: {grad_rel:.3e}");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user