Initial project scaffold

This commit is contained in:
2026-04-10 13:22:19 +00:00
commit 7fa69b1354
94 changed files with 3964 additions and 0 deletions

2
reference/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
"""Reference PyTorch implementations used throughout the lab."""

reference/torch_attention.py Normal file
View File

@@ -0,0 +1,30 @@
from __future__ import annotations
import math
import torch
def torch_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Reference scaled dot-product attention forward pass.

    Computes ``softmax(q @ k^T / sqrt(dim)) @ v`` for tensors laid out as
    ``[batch, heads, seq, dim]``. With ``causal=True`` each position may
    attend only to itself and earlier positions.

    Raises:
        ValueError: if any operand is not rank-4 or the shapes differ.
    """
    # All three operands must be rank-4 and shape-identical.
    if any(t.ndim != 4 for t in (q, k, v)):
        raise ValueError("expected tensors shaped [batch, heads, seq, dim]")
    if not (q.shape == k.shape == v.shape):
        raise ValueError(f"q, k, v must have matching shapes; got {q.shape}, {k.shape}, {v.shape}")
    head_dim = q.shape[-1]
    # Scale logits by 1/sqrt(dim) to keep them in a softmax-friendly range.
    logits = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
    if causal:
        seq_len = q.shape[-2]
        # Strictly-upper-triangular mask hides future positions.
        future = torch.ones((seq_len, seq_len), dtype=torch.bool, device=q.device).triu(1)
        logits = logits.masked_fill(future, float("-inf"))
    weights = logits.softmax(dim=-1)
    return weights @ v

13
reference/torch_matmul.py Normal file
View File

@@ -0,0 +1,13 @@
from __future__ import annotations
import torch
def torch_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Multiply two 2D tensors, validating ranks and the inner dimension.

    Raises:
        ValueError: if either operand is not 2D or the inner dims differ.
    """
    # Only plain matrix-by-matrix products are supported here.
    if not (a.ndim == 2 and b.ndim == 2):
        raise ValueError("torch_matmul expects two 2D tensors")
    # Inner dimensions must line up: (m, k) @ (k, n).
    if a.shape[1] != b.shape[0]:
        raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
    return torch.matmul(a, b)

reference/torch_online_softmax.py Normal file
View File

@@ -0,0 +1,25 @@
from __future__ import annotations
import torch
def torch_online_softmax(x: torch.Tensor) -> torch.Tensor:
    """Row-wise softmax computed via the online (streaming) recurrence.

    Walks the columns once, keeping a running row maximum and a rescaled
    running sum of exponentials, then normalizes at the end. Matches a
    direct row softmax up to floating-point rounding.

    Raises:
        ValueError: if ``x`` is not 2D.
    """
    if x.ndim != 2:
        raise ValueError(f"expected a 2D tensor, got shape {tuple(x.shape)}")
    rows = x.shape[0]
    max_so_far = torch.full((rows,), float("-inf"), dtype=x.dtype, device=x.device)
    sum_so_far = torch.zeros((rows,), dtype=x.dtype, device=x.device)
    # One column at a time: rescale the old sum whenever the max grows.
    for column in x.unbind(dim=1):
        updated_max = torch.maximum(max_so_far, column)
        sum_so_far = (
            sum_so_far * torch.exp(max_so_far - updated_max)
            + torch.exp(column - updated_max)
        )
        max_so_far = updated_max
    return torch.exp(x - max_so_far[:, None]) / sum_so_far[:, None]

reference/torch_row_softmax.py Normal file
View File

@@ -0,0 +1,15 @@
from __future__ import annotations
import torch
def torch_row_softmax(x: torch.Tensor) -> torch.Tensor:
    """Row-wise softmax of a 2D tensor, stabilized by max subtraction.

    Raises:
        ValueError: if ``x`` is not 2D.
    """
    if x.ndim != 2:
        raise ValueError(f"expected a 2D tensor, got shape {tuple(x.shape)}")
    # Subtracting each row's max keeps exp() from overflowing.
    stabilized = x - x.amax(dim=1, keepdim=True)
    numer = torch.exp(stabilized)
    denom = numer.sum(dim=1, keepdim=True)
    return numer / denom

reference/torch_vector_add.py Normal file
View File

@@ -0,0 +1,11 @@
from __future__ import annotations
import torch
def torch_vector_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Elementwise addition of two identically-shaped tensors.

    Raises:
        ValueError: if the shapes differ (broadcasting is deliberately
        disallowed so the reference stays shape-strict).
    """
    if x.shape != y.shape:
        raise ValueError(f"shape mismatch: {x.shape} vs {y.shape}")
    return torch.add(x, y)