Files
kernel-lab/reference/torch_online_softmax.py
2026-04-10 13:15:06 +00:00

26 lines
859 B
Python

from __future__ import annotations
import torch
def torch_online_softmax(x: torch.Tensor) -> torch.Tensor:
    """Reference online-softmax derivation implemented with an explicit loop.

    Walks the columns of *x* one at a time, maintaining a running row-wise
    maximum and a running sum of exponentials rescaled to that maximum —
    the single-pass "online softmax" recurrence. The final normalization
    uses the accumulated max and sum, so the result equals a row-wise
    softmax over dim 1.

    Args:
        x: A 2D tensor of shape (rows, cols). Softmax is taken along dim 1.

    Returns:
        A tensor with the same shape, dtype, and device as ``x``; each row
        sums to 1 (for finite floating-point input).

    Raises:
        ValueError: If ``x`` is not 2-dimensional.
    """
    if x.ndim != 2:
        raise ValueError(f"expected a 2D tensor, got shape {tuple(x.shape)}")

    # Running row-wise maximum; -inf guarantees the first column replaces it.
    running_max = torch.full(
        (x.shape[0],), float("-inf"), dtype=x.dtype, device=x.device
    )
    # Running sum of exp(x[:, j] - running_max); rescaled whenever the max grows.
    running_sum = torch.zeros((x.shape[0],), dtype=x.dtype, device=x.device)

    for col in range(x.shape[1]):
        current = x[:, col]
        new_max = torch.maximum(running_max, current)
        # Re-express the accumulated sum relative to the new max before
        # adding this column's contribution.
        old_scale = torch.exp(running_max - new_max)
        current_scale = torch.exp(current - new_max)
        running_sum = running_sum * old_scale + current_scale
        running_max = new_max

    return torch.exp(x - running_max[:, None]) / running_sum[:, None]