16 lines
457 B
Python
16 lines
457 B
Python
from __future__ import annotations
|
|
|
|
import torch
|
|
|
|
|
|
def torch_row_softmax(x: torch.Tensor) -> torch.Tensor:
    """Compute a numerically stable softmax over each row of a 2D tensor.

    Every row is shifted by its own maximum before exponentiation, so for
    finite inputs the largest exponent evaluated is ``exp(0)``. The shift
    cancels in the final division and therefore does not change the result.

    Args:
        x: Tensor with exactly two dimensions. The softmax is taken along
            ``dim=1`` (across the columns of each row).

    Returns:
        A tensor with the same shape as ``x``. For finite inputs each row
        sums to 1. An input with zero columns yields an empty tensor of the
        same shape instead of raising.

    Raises:
        ValueError: If ``x`` does not have exactly two dimensions.
    """
    if x.ndim != 2:
        raise ValueError(f"expected a 2D tensor, got shape {tuple(x.shape)}")

    if x.shape[1] == 0:
        # A max-reduction over a zero-size dimension raises, and there is
        # nothing to normalize anyway. ``exp`` of the empty tensor keeps the
        # shape and applies the same dtype promotion as the main path.
        return x.exp()

    # Subtract the per-row maximum so the exponentials cannot overflow.
    row_max = x.max(dim=1, keepdim=True).values
    exp_shifted = (x - row_max).exp()

    # keepdim=True keeps the sums as a column so the division broadcasts
    # across each row.
    row_sum = exp_shifted.sum(dim=1, keepdim=True)
    return exp_shifted / row_sum
|
|
|