xtrain/crates/xtrain-distributed
Gahow Wang 5f3b81ac96 test+bins: flash grad-check, flash==composed, PyTorch parity, --flash flag
autograd: flash_attention_batched_bwd (dQ/dK/dV finite-diff, seq>tile)
+ flash_matches_composed_fwd. model/tests/flash.rs: flash==composed
on-vs-off (logits/loss/every param grad), fp32 + bf16. parity_dump:
XTRAIN_PARITY_FLASH dumps the flash path for the same parity.py oracle
(PyTorch SDPA parity at B>1). train + train_ddp get the --flash flag.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 23:10:39 +08:00
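The flash==composed check described above compares a tiled, online-softmax forward pass against plain composed attention, with the sequence longer than the tile so that several key tiles are merged. The block below is a minimal sketch of that idea and not the crate's implementation. It assumes a single head in f64 over flat row-major buffers, with no batching, masking, or bf16. The names `composed_fwd`, `tiled_fwd`, `pseudo_random`, and `max_abs_diff` are hypothetical and are not xtrain APIs.

```rust
// Hypothetical sketch, not xtrain's code: single-head attention in f64 over flat
// row-major [n x d] buffers, with no batching, masking, or bf16.

/// Composed reference: S = Q K^T / sqrt(d), P = softmax(S) row-wise, O = P V.
/// Builds the full score row for every query.
fn composed_fwd(q: &[f64], k: &[f64], v: &[f64], n: usize, d: usize) -> Vec<f64> {
    let scale = 1.0 / (d as f64).sqrt();
    let mut o = vec![0.0; n * d];
    for i in 0..n {
        let s: Vec<f64> = (0..n)
            .map(|j| (0..d).map(|c| q[i * d + c] * k[j * d + c]).sum::<f64>() * scale)
            .collect();
        let m = s.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let e: Vec<f64> = s.iter().map(|x| (x - m).exp()).collect();
        let z: f64 = e.iter().sum();
        for j in 0..n {
            for c in 0..d {
                o[i * d + c] += e[j] / z * v[j * d + c];
            }
        }
    }
    o
}

/// Flash-style forward: walk the keys in tiles of `tile` rows with an online softmax
/// (running max `m`, running denominator `l`). Earlier partial sums are rescaled
/// whenever a later tile raises the max, so no full score row is ever stored.
fn tiled_fwd(q: &[f64], k: &[f64], v: &[f64], n: usize, d: usize, tile: usize) -> Vec<f64> {
    let scale = 1.0 / (d as f64).sqrt();
    let mut o = vec![0.0; n * d];
    for i in 0..n {
        let (mut m, mut l) = (f64::NEG_INFINITY, 0.0_f64);
        let mut acc = vec![0.0_f64; d];
        for start in (0..n).step_by(tile) {
            let end = (start + tile).min(n);
            let s: Vec<f64> = (start..end)
                .map(|j| (0..d).map(|c| q[i * d + c] * k[j * d + c]).sum::<f64>() * scale)
                .collect();
            let m_new = s.iter().cloned().fold(m, f64::max);
            let corr = (m - m_new).exp(); // exp(-inf) = 0 on the first tile
            l *= corr;
            acc.iter_mut().for_each(|a| *a *= corr);
            for (t, j) in (start..end).enumerate() {
                let w = (s[t] - m_new).exp();
                l += w;
                for c in 0..d {
                    acc[c] += w * v[j * d + c];
                }
            }
            m = m_new;
        }
        for c in 0..d {
            o[i * d + c] = acc[c] / l;
        }
    }
    o
}

/// Deterministic inputs in [-1, 1) from a 64-bit LCG, so no external crates are needed.
fn pseudo_random(len: usize, seed: u64) -> Vec<f64> {
    let mut s = seed;
    (0..len)
        .map(|_| {
            s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            (s >> 11) as f64 / (1u64 << 53) as f64 * 2.0 - 1.0
        })
        .collect()
}

/// Largest elementwise |composed - tiled| for the given shape and tile size.
fn max_abs_diff(n: usize, d: usize, tile: usize) -> f64 {
    let (q, k, v) = (pseudo_random(n * d, 1), pseudo_random(n * d, 2), pseudo_random(n * d, 3));
    composed_fwd(&q, &k, &v, n, d)
        .iter()
        .zip(tiled_fwd(&q, &k, &v, n, d, tile))
        .map(|(a, b)| (a - b).abs())
        .fold(0.0, f64::max)
}

fn main() {
    // seq = 7 > tile = 3, so each query merges two full key tiles plus a ragged one.
    let err = max_abs_diff(7, 4, 3);
    println!("max |composed - tiled| = {err:e}");
    assert!(err < 1e-12);
}
```

Choosing `tile` smaller than the sequence length, with a ragged last tile, is what exercises the rescale-by-`corr` step. If the tile covered the whole sequence, the online softmax would reduce to the composed one and the comparison would prove little.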
2026-06-15 17:14:56 +08:00