from __future__ import annotations import torch def torch_row_softmax(x: torch.Tensor) -> torch.Tensor: """Numerically stable row-wise softmax for 2D inputs.""" if x.ndim != 2: raise ValueError(f"expected a 2D tensor, got shape {tuple(x.shape)}") row_max = x.max(dim=1, keepdim=True).values shifted = x - row_max exp_shifted = shifted.exp() row_sum = exp_shifted.sum(dim=1, keepdim=True) return exp_shifted / row_sum