From 4b6d3e0a79465796b90543a62647056580abcc00 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Thu, 18 Jun 2026 00:43:54 +0800 Subject: [PATCH] test: flash+dropout cross-feature grad-check (Phase-2 integration) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add flash_plus_dropout_grad_check_fp32 to xtrain-model dropout tests: the two orthogonal Phase-2 features (T14 flash-attn, T18 dropout) in the same model must still grad-check. Both models run train-mode p=0.2 (identical masks, seed is flash-independent) so the only delta is the SDPA reduction order — checked against the flash-vs-composed tolerance. Co-Authored-By: Claude Opus 4.8 --- crates/xtrain-model/tests/dropout.rs | 61 ++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/crates/xtrain-model/tests/dropout.rs b/crates/xtrain-model/tests/dropout.rs index 04d5d68..c8392ab 100644 --- a/crates/xtrain-model/tests/dropout.rs +++ b/crates/xtrain-model/tests/dropout.rs @@ -220,3 +220,64 @@ fn dropout_recompute_matches_fp32() { fn dropout_recompute_matches_bf16() { recompute_with_dropout(DType::BF16, 5e-3); } + +// --- Cross-feature gate (Phase-2 integration): flash (T14) + dropout (T18) +// together in the SAME model still grad-checks. Build two identical models, both +// in train mode with p=0.2 (so dropout fires), one with `--flash` on, one off. +// The dropout site seeds are a pure function of (step_seed, layer, site) and are +// INDEPENDENT of flash, so both models draw the SAME masks on their first training +// forward → the only difference is the SDPA reduction order. Assert logits/loss/ +// grads match within the flash-vs-composed tolerance and are finite. This is the +// orthogonality check for the two Phase-2 features landing together. +#[test] +fn flash_plus_dropout_grad_check_fp32() { + let device = require_gpu(); + let batch = 3; + // seq=40 > FA_TILE=32 exercises flash's online-softmax tile-rescale path. + let mut cfg = Config::tiny(); + cfg.vocab = 16; + cfg.n_layers = 4; + cfg.dropout = 0.2; + let seq = 40usize; + let seqs: Vec> = (0..batch) + .map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect()) + .collect(); + let tgts: Vec> = (0..batch) + .map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect()) + .collect(); + let ids = batched_ids_tensor(&seqs, device); + let tgt = batched_ids_tensor(&tgts, device); + + // Both: same init, train mode (dropout active), same step_seed progression → + // identical masks; one composed SDPA, one flash SDPA. + let off = build(cfg, device).with_training(true); + let on = build(cfg, device).with_flash(true).with_training(true); + + let (off_logits, off_loss, off_grads) = fwd_bwd(&off, &ids, &tgt, batch); + let (on_logits, on_loss, on_grads) = fwd_bwd(&on, &ids, &tgt, batch); + + assert!( + on_logits.iter().all(|v| v.is_finite()) && on_grads.iter().flatten().all(|v| v.is_finite()), + "flash+dropout produced non-finite logits/grads" + ); + + let logit_rel = off_logits + .iter() + .zip(&on_logits) + .map(|(a, b)| (a - b).abs() / a.abs().max(1e-4)) + .fold(0.0f32, f32::max); + let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4); + let mut grad_rel = 0.0f32; + for (a, b) in off_grads.iter().flatten().zip(on_grads.iter().flatten()) { + grad_rel = grad_rel.max((a - b).abs() / a.abs().max(1e-3)); + } + println!( + "[F32] flash+dropout vs composed+dropout: loss rel {loss_rel:.2e}, \ + logits max rel {logit_rel:.2e}, grad max rel {grad_rel:.3e}" + ); + // Same tolerances as the flash-vs-composed gate (flash.rs run_fp32): flash + // differs from composed only by reduction order; dropout masks are identical. + assert!(logit_rel < 1e-3, "[F32] flash+dropout logits diverged: {logit_rel:.2e}"); + assert!(loss_rel < 1e-3, "[F32] flash+dropout loss diverged: {loss_rel:.2e}"); + assert!(grad_rel < 2e-2, "[F32] flash+dropout grads diverged: {grad_rel:.3e}"); +}