test: bf16 test reads f32-cast logits (forward now returns bf16)
The `keep bf16 logits` change made forward_batched return bf16 logits in bf16 mode; the bf16 test's host read must cast to f32 first. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -85,9 +85,15 @@ fn bf16_matches_fp32_within_loose_tol() {
|
||||
f_loss.backward();
|
||||
let f_params = fp32.params();
|
||||
|
||||
// bf16 — SAME init (build re-runs the same deterministic fill).
|
||||
// bf16 — SAME init (build re-runs the same deterministic fill). The forward
|
||||
// now returns bf16 logits (CE upcasts internally); cast to f32 to read.
|
||||
let bf16 = build(cfg, device).with_compute_dtype(DType::BF16);
|
||||
let b_logits = host(&bf16.forward_batched(&ids, batch).value());
|
||||
let b_logits = host(
|
||||
&bf16
|
||||
.forward_batched(&ids, batch)
|
||||
.value()
|
||||
.to_dtype(DType::F32),
|
||||
);
|
||||
let b_loss = bf16.loss_batched(&ids, &tgt, batch);
|
||||
let b_loss_val = host(&b_loss.value())[0];
|
||||
b_loss.backward();
|
||||
|
||||
Reference in New Issue
Block a user