Gahow Wang 0a2a4dcaa8 train: --bf16 flag (fp32-master AMP) + bf16 correctness test
- TinyTransformer::with_compute_dtype(BF16): the embedding table stays an
  fp32 master and its output is cast to bf16; each linear casts its fp32
  weight to bf16 on the fly; logits are cast back to fp32 for
  cross-entropy. The default (F32) reproduces the v0-v4 forward graph
  bit-for-bit.
- --bf16 flag on bin/train and bin/train_ddp (off by default).
- tests/bf16.rs: run the same fp32 master weights through the fp32 and
  bf16 paths; assert that loss/logits/grads agree within a loose bf16
  tolerance, contain no NaN, and that grads are fp32 (master untouched).
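A minimal sketch of the casting scheme described above, in dependency-free Rust. It assumes bf16 can be emulated by rounding f32 bit patterns to their top 16 bits (round-to-nearest-even) and accumulating in f32; `ComputeDtype`, `linear`, `forward_loss` and `to_bf16` are hypothetical names for illustration, not this crate's API, and gradients are not modelled. The `main` mirrors the spirit of the test: same fp32 master weights, both paths, a loose tolerance and a finiteness check.

```rust
/// Round an f32 to the nearest bf16 (round-half-to-even); returns the raw bf16 bits.
fn f32_to_bf16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Keep a quiet NaN; rounding the payload could otherwise produce inf.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Add 0x7FFF plus the lowest kept bit, then truncate: ties round to even.
    let lsb = (bits >> 16) & 1;
    (bits.wrapping_add(0x7FFF + lsb) >> 16) as u16
}

/// Widen bf16 bits back to f32 (exact: bf16 is the top half of an f32).
fn bf16_bits_to_f32(b: u16) -> f32 {
    f32::from_bits((b as u32) << 16)
}

/// The value an f32 takes after a round trip through bf16.
fn to_bf16(x: f32) -> f32 {
    bf16_bits_to_f32(f32_to_bf16_bits(x))
}

#[derive(Clone, Copy, PartialEq, Debug)]
enum ComputeDtype {
    F32,
    BF16,
}

/// y = W x with an fp32 master weight; in BF16 mode the weight is cast on the fly.
fn linear(w: &[Vec<f32>], x: &[f32], dtype: ComputeDtype) -> Vec<f32> {
    w.iter()
        .map(|row| {
            let acc: f32 = row
                .iter()
                .zip(x)
                .map(|(&wi, &xi)| match dtype {
                    ComputeDtype::F32 => wi * xi,
                    ComputeDtype::BF16 => to_bf16(wi) * xi, // master weight stays f32
                })
                .sum();
            match dtype {
                ComputeDtype::F32 => acc,
                ComputeDtype::BF16 => to_bf16(acc), // output stored as bf16
            }
        })
        .collect()
}

/// Embedding lookup -> (optional bf16 cast) -> linear -> f32 cross-entropy.
fn forward_loss(emb: &[Vec<f32>], w: &[Vec<f32>], token: usize, target: usize,
                dtype: ComputeDtype) -> (Vec<f32>, f32) {
    let h: Vec<f32> = emb[token]
        .iter()
        .map(|&v| if dtype == ComputeDtype::BF16 { to_bf16(v) } else { v })
        .collect();
    // bf16 logits are exactly representable in f32, so "casting back" is lossless.
    let logits = linear(w, &h, dtype);
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let lse = max + logits.iter().map(|&z| (z - max).exp()).sum::<f32>().ln();
    let loss = lse - logits[target];
    (logits, loss)
}

fn main() {
    // Hypothetical tiny model: 4-token embedding (dim 8), 5-way output projection.
    let init = |i: usize, j: usize| ((i * 7 + j * 3) % 11) as f32 * 0.1 - 0.5;
    let emb: Vec<Vec<f32>> = (0..4).map(|i| (0..8).map(|j| init(i, j)).collect()).collect();
    let w: Vec<Vec<f32>> = (0..5).map(|i| (0..8).map(|j| init(i + 4, j)).collect()).collect();

    let (logits32, loss32) = forward_loss(&emb, &w, 2, 1, ComputeDtype::F32);
    let (logits16, loss16) = forward_loss(&emb, &w, 2, 1, ComputeDtype::BF16);

    let max_diff = logits32
        .iter()
        .zip(&logits16)
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, f32::max);
    assert!(loss16.is_finite() && logits16.iter().all(|z| z.is_finite()));
    assert!(max_diff < 5e-2, "logit diff {max_diff}");
    assert!((loss32 - loss16).abs() < 1e-1, "loss diff");
    println!("fp32 loss {loss32:.5}, bf16 loss {loss16:.5}, max logit diff {max_diff:.2e}");
}
```

With the `half` crate, its `bf16` type (`bf16::from_f32`, `to_f32`) could replace the hand-rolled rounding; the bit-level version just keeps the sketch free of dependencies.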

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 14:14:55 +08:00