At vocab 50257 the fp32 logits tensor [B*S, vocab] is ~1.6GB at batch
32 and is held across the whole backward pass. Keep it in bf16 instead:
cross_entropy upcasts the bf16 logits to fp32 internally (a transient
buffer) and caches fp32 probs, and its backward casts dx back to bf16
so it chains into the bf16 lm_head matmul backward. The sampler casts
the bf16 logits to fp32 before the host argmax/softmax. This halves the
persistent logits activation (~1.6GB to ~0.8GB).
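
The size figures can be checked with a quick calculation. This is a
minimal sketch, not code from the change itself: logits_bytes is a
hypothetical helper, and the sequence length of 256 is an assumption
(the message states only batch 32 and vocab 50257; S = 256 happens to
reproduce the ~1.6GB figure).

```python
def logits_bytes(batch, seq_len, vocab, bytes_per_elem):
    """Size in bytes of a dense [batch * seq_len, vocab] logits tensor."""
    return batch * seq_len * vocab * bytes_per_elem


BATCH, VOCAB = 32, 50257  # taken from the commit message
SEQ_LEN = 256             # ASSUMPTION: not stated in the commit message

fp32 = logits_bytes(BATCH, SEQ_LEN, VOCAB, 4)  # 4 bytes per fp32 element
bf16 = logits_bytes(BATCH, SEQ_LEN, VOCAB, 2)  # 2 bytes per bf16 element

print(f"fp32: {fp32 / 1e9:.2f} GB, bf16: {bf16 / 1e9:.2f} GB")
# prints "fp32: 1.65 GB, bf16: 0.82 GB"
```

Only the persistent logits buffer is counted here; the transient fp32
upcast and the cached fp32 probs inside cross_entropy are separate
allocations.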

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>