test: bf16 test reads f32-cast logits (forward now returns bf16)

The `keep bf16 logits` change made forward_batched return bf16 logits
in bf16 mode; the bf16 test's host read must cast to f32 first.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-16 14:29:24 +08:00
parent 320c1ae4fb
commit 5b7dde1736

View File

@@ -85,9 +85,15 @@ fn bf16_matches_fp32_within_loose_tol() {
f_loss.backward();
let f_params = fp32.params();
// bf16 — SAME init (build re-runs the same deterministic fill).
// bf16 — SAME init (build re-runs the same deterministic fill). The forward
// now returns bf16 logits (CE upcasts internally); cast to f32 to read.
let bf16 = build(cfg, device).with_compute_dtype(DType::BF16);
let b_logits = host(&bf16.forward_batched(&ids, batch).value());
let b_logits = host(
&bf16
.forward_batched(&ids, batch)
.value()
.to_dtype(DType::F32),
);
let b_loss = bf16.loss_batched(&ids, &tgt, batch);
let b_loss_val = host(&b_loss.value())[0];
b_loss.backward();