perf: keep bf16 logits (no persistent fp32 logits buffer)
At vocab 50257 the logits tensor [B*S, vocab] is ~1.6GB fp32 at batch 32 — held across the whole backward. Keep it bf16: cross_entropy upcasts the bf16 logits to fp32 internally (transient) + caches fp32 probs, and its backward casts dx back to bf16 to chain into the bf16 lm_head matmul backward. The sampler casts bf16 logits→f32 before the host argmax/softmax. Halves the persistent logits activation. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -329,6 +329,11 @@ pub fn attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
|
|||||||
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
|
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
|
||||||
/// scaled by the upstream scalar grad.
|
/// scaled by the upstream scalar grad.
|
||||||
pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
|
pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
|
||||||
|
// CE math is fp32 (cross_entropy upcasts bf16 logits internally + caches fp32
|
||||||
|
// probs). The grad must match the logits' dtype so it chains into a bf16
|
||||||
|
// lm_head matmul backward — cast dx back. Keeping logits bf16 (no persistent
|
||||||
|
// fp32 logits buffer) is a real activation-memory saving at large vocab.
|
||||||
|
let logit_dtype = x.value().dtype();
|
||||||
let (probs, per_row) = x.value().cross_entropy(target);
|
let (probs, per_row) = x.value().cross_entropy(target);
|
||||||
let rows = x.value().shape()[0];
|
let rows = x.value().shape()[0];
|
||||||
// Mean loss as a host scalar wrapped back into a [1] tensor.
|
// Mean loss as a host scalar wrapped back into a [1] tensor.
|
||||||
@@ -345,7 +350,7 @@ pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
|
|||||||
let upstream = d.to_device(xtrain_tensor::Device::Cpu).as_slice::<f32>()[0];
|
let upstream = d.to_device(xtrain_tensor::Device::Cpu).as_slice::<f32>()[0];
|
||||||
let scale = upstream / rows as f32;
|
let scale = upstream / rows as f32;
|
||||||
let dx = Tensor::cross_entropy_backward(&probs, &target, scale);
|
let dx = Tensor::cross_entropy_backward(&probs, &target, scale);
|
||||||
Var::push_grad(&parents[0], dx);
|
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -183,13 +183,10 @@ impl TinyTransformer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let h = ops::rms_norm(&h, &self.norm_gamma(&self.final_norm), self.cfg.eps);
|
let h = ops::rms_norm(&h, &self.norm_gamma(&self.final_norm), self.cfg.eps);
|
||||||
// lm_head matmul in compute dtype; cast logits back to fp32 for CE.
|
// lm_head matmul in compute dtype. Logits stay bf16 in bf16 mode — the
|
||||||
let logits = self.linear(&h, &self.lm_head); // [batch*seq, vocab]
|
// cross_entropy op upcasts to fp32 internally (no persistent fp32 logits
|
||||||
if self.compute_dtype == DType::BF16 {
|
// buffer, a real saving at vocab 50257), and its backward casts dx back.
|
||||||
ops::cast(&logits, DType::F32)
|
self.linear(&h, &self.lm_head) // [batch*seq, vocab]
|
||||||
} else {
|
|
||||||
logits
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A norm/QK-norm gamma in the compute dtype. fp32 master leaf → bf16 (cast
|
/// A norm/QK-norm gamma in the compute dtype. fp32 master leaf → bf16 (cast
|
||||||
|
|||||||
@@ -26,7 +26,11 @@ pub fn generate(
|
|||||||
|
|
||||||
for _ in 0..max_new {
|
for _ in 0..max_new {
|
||||||
let ids_t = ids_tensor(&ids, device);
|
let ids_t = ids_tensor(&ids, device);
|
||||||
let logits = model.forward(&ids_t).value().to_device(Device::Cpu);
|
// In bf16 mode the logits are bf16; cast to f32 (on device) before reading.
|
||||||
|
let logits = model.forward(&ids_t).value();
|
||||||
|
let logits = logits
|
||||||
|
.to_dtype(xtrain_tensor::DType::F32)
|
||||||
|
.to_device(Device::Cpu);
|
||||||
let lg = logits.as_slice::<f32>();
|
let lg = logits.as_slice::<f32>();
|
||||||
// Last row = next-token distribution for the current prefix.
|
// Last row = next-token distribution for the current prefix.
|
||||||
let last = &lg[(ids.len() - 1) * vocab..ids.len() * vocab];
|
let last = &lg[(ids.len() - 1) * vocab..ids.len() * vocab];
|
||||||
|
|||||||
Reference in New Issue
Block a user