Files
kernel-lab/reference/torch_row_softmax.py
2026-04-10 13:22:19 +00:00

16 lines
457 B
Python

from __future__ import annotations
import torch
def torch_row_softmax(x: torch.Tensor) -> torch.Tensor:
    """Return the numerically stable softmax of each row of a 2D tensor.

    Each row's maximum is subtracted before exponentiating so that ``exp``
    cannot overflow. The result is unchanged, because softmax is invariant
    to adding the same constant to every element of a row.

    Args:
        x: A 2D tensor. Softmax is taken independently over dim 1.

    Returns:
        A tensor with the same shape as ``x`` whose non-empty rows each sum
        to 1. Arithmetic is done in the tensor's own dtype; low-precision
        floats are not upcast. A row that is entirely ``-inf`` produces NaNs
        (``-inf - -inf`` is NaN), as ``torch.softmax`` does.

    Raises:
        ValueError: If ``x`` is not 2D.
    """
    if x.ndim != 2:
        raise ValueError(f"expected a 2D tensor, got shape {tuple(x.shape)}")
    if x.shape[1] == 0:
        # Reducing over a zero-size dim raises, but the softmax of empty rows
        # is just an empty tensor of the same shape (as with torch.softmax).
        # exp() keeps the output dtype consistent with the main path.
        return x.exp()
    # amax returns only the values; max(dim=...) would also build unused
    # argmax indices.
    row_max = x.amax(dim=1, keepdim=True)
    exp_shifted = (x - row_max).exp()
    row_sum = exp_shifted.sum(dim=1, keepdim=True)
    return exp_shifted / row_sum