test: flash+dropout cross-feature grad-check (Phase-2 integration)
Add flash_plus_dropout_grad_check_fp32 to the xtrain-model dropout tests: the two orthogonal Phase-2 features (T14 flash-attn, T18 dropout) must still grad-check when enabled in the same model. Both models run in train mode with p=0.2 (identical masks, since the dropout seed is flash-independent), so the only delta is the SDPA reduction order, which is checked against the flash-vs-composed tolerance.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -220,3 +220,64 @@ fn dropout_recompute_matches_fp32() {
|
|||||||
fn dropout_recompute_matches_bf16() {
|
fn dropout_recompute_matches_bf16() {
|
||||||
recompute_with_dropout(DType::BF16, 5e-3);
|
recompute_with_dropout(DType::BF16, 5e-3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Cross-feature gate (Phase-2 integration): flash (T14) + dropout (T18)
// together in the SAME model still grad-checks. Build two identical models, both
// in train mode with p=0.2 (so dropout fires), one with `--flash` on, one off.
// The dropout site seeds are a pure function of (step_seed, layer, site) and are
// INDEPENDENT of flash, so both models draw the SAME masks on their first training
// forward → the only difference is the SDPA reduction order. Assert logits/loss/
// grads match within the flash-vs-composed tolerance and are finite. This is the
// orthogonality check for the two Phase-2 features landing together.
//
// Finiteness and element counts are asserted on BOTH sides before comparing:
// `f32::max` returns the non-NaN operand, so a NaN on the composed side would
// otherwise drop out of the max-relative-error folds, and `zip` stops at the
// shorter input, so a count mismatch would otherwise go unnoticed.
#[test]
fn flash_plus_dropout_grad_check_fp32() {
    let device = require_gpu();
    let batch = 3;
    // seq=40 > FA_TILE=32 exercises flash's online-softmax tile-rescale path.
    let mut cfg = Config::tiny();
    cfg.vocab = 16;
    cfg.n_layers = 4;
    cfg.dropout = 0.2;
    let seq = 40usize;
    // Deterministic, in-vocab token patterns; inputs and targets use different
    // strides so they are not trivially equal.
    let seqs: Vec<Vec<i32>> = (0..batch)
        .map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect())
        .collect();
    let tgts: Vec<Vec<i32>> = (0..batch)
        .map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect())
        .collect();
    let ids = batched_ids_tensor(&seqs, device);
    let tgt = batched_ids_tensor(&tgts, device);

    // Both: same init, train mode (dropout active), same step_seed progression →
    // identical masks; one composed SDPA, one flash SDPA.
    let off = build(cfg, device).with_training(true);
    let on = build(cfg, device).with_flash(true).with_training(true);

    let (off_logits, off_loss, off_grads) = fwd_bwd(&off, &ids, &tgt, batch);
    let (on_logits, on_loss, on_grads) = fwd_bwd(&on, &ids, &tgt, batch);

    // Flash side must be finite.
    assert!(
        on_logits.iter().all(|v| v.is_finite()) && on_grads.iter().flatten().all(|v| v.is_finite()),
        "flash+dropout produced non-finite logits/grads"
    );
    // Composed (reference) side must be finite too: a NaN reference value would
    // be discarded by `f32::max` below and the comparison would pass vacuously.
    assert!(
        off_logits.iter().all(|v| v.is_finite())
            && off_grads.iter().flatten().all(|v| v.is_finite()),
        "composed+dropout produced non-finite logits/grads"
    );
    assert!(
        off_loss.is_finite() && on_loss.is_finite(),
        "flash+dropout grad-check produced a non-finite loss (composed {off_loss}, flash {on_loss})"
    );

    // `zip` silently truncates to the shorter side, so pin the element counts
    // before any element-wise comparison.
    assert_eq!(
        off_logits.len(),
        on_logits.len(),
        "flash+dropout logits length differs from composed"
    );
    assert_eq!(
        off_grads.iter().flatten().count(),
        on_grads.iter().flatten().count(),
        "flash+dropout gradient element count differs from composed"
    );

    // Max relative error against the composed reference; the denominator floor
    // keeps near-zero reference values from blowing up the ratio.
    let logit_rel = off_logits
        .iter()
        .zip(&on_logits)
        .map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
        .fold(0.0f32, f32::max);
    let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
    let grad_rel = off_grads
        .iter()
        .flatten()
        .zip(on_grads.iter().flatten())
        .map(|(a, b)| (a - b).abs() / a.abs().max(1e-3))
        .fold(0.0f32, f32::max);
    println!(
        "[F32] flash+dropout vs composed+dropout: loss rel {loss_rel:.2e}, \
         logits max rel {logit_rel:.2e}, grad max rel {grad_rel:.3e}"
    );
    // Same tolerances as the flash-vs-composed gate (flash.rs run_fp32): flash
    // differs from composed only by reduction order; dropout masks are identical.
    assert!(logit_rel < 1e-3, "[F32] flash+dropout logits diverged: {logit_rel:.2e}");
    assert!(loss_rel < 1e-3, "[F32] flash+dropout loss diverged: {loss_rel:.2e}");
    assert!(grad_rel < 2e-2, "[F32] flash+dropout grads diverged: {grad_rel:.3e}");
}
|
||||||
|
|||||||
Reference in New Issue
Block a user