perf: keep bf16 logits (no persistent fp32 logits buffer)
At vocab 50257 the logits tensor [B*S, vocab] is ~1.6GB in fp32 at batch 32, and it is held across the whole backward pass. Keep it bf16 instead: cross_entropy upcasts the bf16 logits to fp32 internally (transient) and caches fp32 probs, and its backward casts dx back to bf16 so that it chains into the bf16 lm_head matmul backward. The sampler casts the bf16 logits to f32 before the host argmax/softmax. Halves the persistent logits activation.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -329,6 +329,11 @@ pub fn attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
|
||||
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
|
||||
/// scaled by the upstream scalar grad.
|
||||
pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
|
||||
// CE math is fp32 (cross_entropy upcasts bf16 logits internally + caches fp32
|
||||
// probs). The grad must match the logits' dtype so it chains into a bf16
|
||||
// lm_head matmul backward — cast dx back. Keeping logits bf16 (no persistent
|
||||
// fp32 logits buffer) is a real activation-memory saving at large vocab.
|
||||
let logit_dtype = x.value().dtype();
|
||||
let (probs, per_row) = x.value().cross_entropy(target);
|
||||
let rows = x.value().shape()[0];
|
||||
// Mean loss as a host scalar wrapped back into a [1] tensor.
|
||||
@@ -345,7 +350,7 @@ pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
|
||||
let upstream = d.to_device(xtrain_tensor::Device::Cpu).as_slice::<f32>()[0];
|
||||
let scale = upstream / rows as f32;
|
||||
let dx = Tensor::cross_entropy_backward(&probs, &target, scale);
|
||||
Var::push_grad(&parents[0], dx);
|
||||
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -183,13 +183,10 @@ impl TinyTransformer {
|
||||
}
|
||||
|
||||
let h = ops::rms_norm(&h, &self.norm_gamma(&self.final_norm), self.cfg.eps);
|
||||
// lm_head matmul in compute dtype; cast logits back to fp32 for CE.
|
||||
let logits = self.linear(&h, &self.lm_head); // [batch*seq, vocab]
|
||||
if self.compute_dtype == DType::BF16 {
|
||||
ops::cast(&logits, DType::F32)
|
||||
} else {
|
||||
logits
|
||||
}
|
||||
// lm_head matmul in compute dtype. Logits stay bf16 in bf16 mode — the
|
||||
// cross_entropy op upcasts to fp32 internally (no persistent fp32 logits
|
||||
// buffer, a real saving at vocab 50257), and its backward casts dx back.
|
||||
self.linear(&h, &self.lm_head) // [batch*seq, vocab]
|
||||
}
|
||||
|
||||
/// A norm/QK-norm gamma in the compute dtype. fp32 master leaf → bf16 (cast
|
||||
|
||||
@@ -26,7 +26,11 @@ pub fn generate(
|
||||
|
||||
for _ in 0..max_new {
|
||||
let ids_t = ids_tensor(&ids, device);
|
||||
let logits = model.forward(&ids_t).value().to_device(Device::Cpu);
|
||||
// In bf16 mode the logits are bf16; cast to f32 (on device) before reading.
|
||||
let logits = model.forward(&ids_t).value();
|
||||
let logits = logits
|
||||
.to_dtype(xtrain_tensor::DType::F32)
|
||||
.to_device(Device::Cpu);
|
||||
let lg = logits.as_slice::<f32>();
|
||||
// Last row = next-token distribution for the current prefix.
|
||||
let last = &lg[(ids.len() - 1) * vocab..ids.len() * vocab];
|
||||
|
||||
Reference in New Issue
Block a user