from __future__ import annotations

import torch


def torch_online_softmax(x: torch.Tensor) -> torch.Tensor:
    """Reference online-softmax derivation implemented with an explicit loop."""
    if x.ndim != 2:
        raise ValueError(f"expected a 2D tensor, got shape {tuple(x.shape)}")

    # Per-row running statistics: the maximum seen so far and the sum of
    # exponentials rescaled relative to that maximum.
    running_max = torch.full(
        (x.shape[0],), float("-inf"), dtype=x.dtype, device=x.device
    )
    running_sum = torch.zeros((x.shape[0],), dtype=x.dtype, device=x.device)

    # Stream over columns one at a time, as an online kernel would.
    for col in range(x.shape[1]):
        current = x[:, col]
        new_max = torch.maximum(running_max, current)
        # Rescale the accumulated sum from the old maximum to the new one,
        # then fold in the current column's contribution.
        old_scale = torch.exp(running_max - new_max)
        current_scale = torch.exp(current - new_max)
        running_sum = running_sum * old_scale + current_scale
        running_max = new_max

    # Normalize with the final per-row statistics, broadcasting over columns.
    return torch.exp(x - running_max[:, None]) / running_sum[:, None]
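

# Minimal sanity-check sketch: the streamed result should agree row-wise with
# torch.softmax over the last dimension. The input shape below is an
# arbitrary example chosen for illustration.
if __name__ == "__main__":
    x = torch.randn(4, 128)
    torch.testing.assert_close(torch_online_softmax(x), torch.softmax(x, dim=-1))
    print("online softmax matches torch.softmax")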