Initial project scaffold

This commit is contained in:
wjh
2026-04-10 13:15:06 +00:00
commit a4a6b1f1c8
94 changed files with 3964 additions and 0 deletions

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
import torch
def torch_online_softmax(x: torch.Tensor) -> torch.Tensor:
"""Reference online-softmax derivation implemented with an explicit loop."""
if x.ndim != 2:
raise ValueError(f"expected a 2D tensor, got shape {tuple(x.shape)}")
running_max = torch.full(
(x.shape[0],), float("-inf"), dtype=x.dtype, device=x.device
)
running_sum = torch.zeros((x.shape[0],), dtype=x.dtype, device=x.device)
for col in range(x.shape[1]):
current = x[:, col]
new_max = torch.maximum(running_max, current)
old_scale = torch.exp(running_max - new_max)
current_scale = torch.exp(current - new_max)
running_sum = running_sum * old_scale + current_scale
running_max = new_max
return torch.exp(x - running_max[:, None]) / running_sum[:, None]