perf: keep bf16 logits (no persistent fp32 logits buffer)

At vocab 50257 the logits tensor [B*S, vocab] is ~1.6GB in fp32 at batch
32, and it is held across the whole backward pass. Keep it bf16 instead:
cross_entropy upcasts the bf16 logits to fp32 internally (a transient
buffer) and caches fp32 probs, and its backward casts dx back to bf16 so
that it chains into the bf16 lm_head matmul backward. The sampler casts
the bf16 logits to f32 before the host argmax/softmax. This halves the
persistent logits activation.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-16 14:20:48 +08:00
parent 30db62d8f2
commit 48922cb628
3 changed files with 15 additions and 9 deletions

View File

@@ -329,6 +329,11 @@ pub fn attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
/// scaled by the upstream scalar grad.
pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
// CE math is fp32 (cross_entropy upcasts bf16 logits internally + caches fp32
// probs). The grad must match the logits' dtype so it chains into a bf16
// lm_head matmul backward — cast dx back. Keeping logits bf16 (no persistent
// fp32 logits buffer) is a real activation-memory saving at large vocab.
let logit_dtype = x.value().dtype();
let (probs, per_row) = x.value().cross_entropy(target);
let rows = x.value().shape()[0];
// Mean loss as a host scalar wrapped back into a [1] tensor.
@@ -345,7 +350,7 @@ pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
let upstream = d.to_device(xtrain_tensor::Device::Cpu).as_slice::<f32>()[0];
let scale = upstream / rows as f32;
let dx = Tensor::cross_entropy_backward(&probs, &target, scale);
Var::push_grad(&parents[0], dx);
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
}),
)
}

View File

@@ -183,13 +183,10 @@ impl TinyTransformer {
}
let h = ops::rms_norm(&h, &self.norm_gamma(&self.final_norm), self.cfg.eps);
// lm_head matmul in compute dtype; cast logits back to fp32 for CE.
let logits = self.linear(&h, &self.lm_head); // [batch*seq, vocab]
if self.compute_dtype == DType::BF16 {
ops::cast(&logits, DType::F32)
} else {
logits
}
// lm_head matmul in compute dtype. Logits stay bf16 in bf16 mode — the
// cross_entropy op upcasts to fp32 internally (no persistent fp32 logits
// buffer, a real saving at vocab 50257), and its backward casts dx back.
self.linear(&h, &self.lm_head) // [batch*seq, vocab]
}
/// A norm/QK-norm gamma in the compute dtype. fp32 master leaf → bf16 (cast

View File

@@ -26,7 +26,11 @@ pub fn generate(
for _ in 0..max_new {
let ids_t = ids_tensor(&ids, device);
let logits = model.forward(&ids_t).value().to_device(Device::Cpu);
// In bf16 mode the logits are bf16; cast to f32 (on device) before reading.
let logits = model.forward(&ids_t).value();
let logits = logits
.to_dtype(xtrain_tensor::DType::F32)
.to_device(Device::Cpu);
let lg = logits.as_slice::<f32>();
// Last row = next-token distribution for the current prefix.
let last = &lg[(ids.len() - 1) * vocab..ids.len() * vocab];