from __future__ import annotations

import torch


def torch_online_softmax(x: torch.Tensor) -> torch.Tensor:
    """Compute a row-wise softmax of a 2D tensor with the online (streaming) recurrence.

    Columns are visited one at a time while a per-row running maximum and a
    per-row running sum of exponentials (expressed relative to that maximum)
    are maintained. Whenever the maximum grows, the accumulated sum is
    rescaled by ``exp(old_max - new_max)``. The output is
    ``exp(x - max) / sum`` evaluated with the final statistics.

    Rows that begin with ``-inf`` entries (e.g. masked logits) are handled
    without producing NaN: while the running maximum is still ``-inf`` the
    rescale factors are computed against 0 instead, so they evaluate to 0
    rather than ``exp(-inf - (-inf))``. A row consisting solely of ``-inf``
    still yields NaN (0 / 0 in the final division).

    Args:
        x: Tensor with exactly two dimensions; the softmax is taken over
            dimension 1. The accumulators are created with ``x``'s dtype and
            device.

    Returns:
        A tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``x`` is not two-dimensional.
    """
    if x.ndim != 2:
        raise ValueError(f"expected a 2D tensor, got shape {tuple(x.shape)}")

    num_rows, num_cols = x.shape
    neg_inf = float("-inf")

    running_max = torch.full((num_rows,), neg_inf, dtype=x.dtype, device=x.device)
    running_sum = torch.zeros((num_rows,), dtype=x.dtype, device=x.device)

    for col in range(num_cols):
        current = x[:, col]
        new_max = torch.maximum(running_max, current)

        # While every value seen so far in a row is -inf, new_max is -inf and
        # subtracting it would give (-inf) - (-inf) = NaN. Use 0 as the
        # reference in that case so both factors below become exp(-inf) = 0.
        # For any other new_max (finite, +inf, NaN) this is a no-op.
        safe_max = new_max.masked_fill(new_max == neg_inf, 0.0)

        old_scale = torch.exp(running_max - safe_max)
        current_scale = torch.exp(current - safe_max)

        # Re-express the accumulated sum relative to the new maximum, then
        # add the contribution of the current column.
        running_sum = running_sum * old_scale + current_scale
        running_max = new_max

    return torch.exp(x - running_max[:, None]) / running_sum[:, None]