Initial project scaffold
This commit is contained in:
25
reference/torch_online_softmax.py
Normal file
25
reference/torch_online_softmax.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def torch_online_softmax(x: torch.Tensor) -> torch.Tensor:
    """Loop-based reference implementation of the online (streaming) softmax.

    Walks the columns of a 2D tensor once, maintaining a running row-wise
    maximum and a running denominator that is rescaled whenever the maximum
    grows, then normalizes in a final vectorized step.

    Args:
        x: A 2D tensor of shape (rows, cols).

    Returns:
        A tensor of the same shape holding the row-wise softmax of ``x``.

    Raises:
        ValueError: If ``x`` is not 2-dimensional.
    """
    if x.ndim != 2:
        raise ValueError(f"expected a 2D tensor, got shape {tuple(x.shape)}")

    rows = x.shape[0]
    # new_full / new_zeros inherit dtype and device from x.
    row_max = x.new_full((rows,), float("-inf"))
    denom = x.new_zeros((rows,))

    # One pass over columns: whenever the running max increases, the partial
    # sum accumulated so far is rescaled by exp(old_max - new_max).
    for column in x.unbind(dim=1):
        updated_max = torch.maximum(row_max, column)
        denom = denom * torch.exp(row_max - updated_max) + torch.exp(column - updated_max)
        row_max = updated_max

    # Final normalization uses the fully-reduced max and denominator.
    return torch.exp(x - row_max.unsqueeze(1)) / denom.unsqueeze(1)
|
||||
|
||||
Reference in New Issue
Block a user