Initial project scaffold
This commit is contained in:
15
reference/torch_row_softmax.py
Normal file
15
reference/torch_row_softmax.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def torch_row_softmax(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Numerically stable row-wise softmax for 2D inputs."""
|
||||
if x.ndim != 2:
|
||||
raise ValueError(f"expected a 2D tensor, got shape {tuple(x.shape)}")
|
||||
row_max = x.max(dim=1, keepdim=True).values
|
||||
shifted = x - row_max
|
||||
exp_shifted = shifted.exp()
|
||||
row_sum = exp_shifted.sum(dim=1, keepdim=True)
|
||||
return exp_shifted / row_sum
|
||||
|
||||
Reference in New Issue
Block a user