From 5b7dde1736bd4cca6d86b118b3bb238d98b7b533 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Tue, 16 Jun 2026 14:29:24 +0800 Subject: [PATCH] test: bf16 test reads f32-cast logits (forward now returns bf16) The `keep bf16 logits` change made forward_batched return bf16 logits in bf16 mode; the bf16 test's host read must cast to f32 first. Co-Authored-By: Claude Opus 4.8 --- crates/xtrain-model/tests/bf16.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/crates/xtrain-model/tests/bf16.rs b/crates/xtrain-model/tests/bf16.rs index 1de0e8b..db3ee4a 100644 --- a/crates/xtrain-model/tests/bf16.rs +++ b/crates/xtrain-model/tests/bf16.rs @@ -85,9 +85,15 @@ fn bf16_matches_fp32_within_loose_tol() { f_loss.backward(); let f_params = fp32.params(); - // bf16 — SAME init (build re-runs the same deterministic fill). + // bf16 — SAME init (build re-runs the same deterministic fill). The forward + // now returns bf16 logits (CE upcasts internally); cast to f32 to read. let bf16 = build(cfg, device).with_compute_dtype(DType::BF16); - let b_logits = host(&bf16.forward_batched(&ids, batch).value()); + let b_logits = host( + &bf16 + .forward_batched(&ids, batch) + .value() + .to_dtype(DType::F32), + ); let b_loss = bf16.loss_batched(&ids, &tgt, batch); let b_loss_val = host(&b_loss.value())[0]; b_loss.backward();