Initial project scaffold
This commit is contained in:
2
reference/__init__.py
Normal file
2
reference/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Reference PyTorch implementations used throughout the lab."""
|
||||
|
||||
30
reference/torch_attention.py
Normal file
30
reference/torch_attention.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def torch_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Plain scaled dot-product attention, used as a correctness reference.

    All three inputs are [batch, heads, seq, dim] tensors of identical shape.
    When ``causal`` is set, position i may only attend to positions <= i.
    Raises ValueError on rank or shape mismatches.
    """
    if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
        raise ValueError("expected tensors shaped [batch, heads, seq, dim]")
    if q.shape != k.shape or q.shape != v.shape:
        raise ValueError(f"q, k, v must have matching shapes; got {q.shape}, {k.shape}, {v.shape}")

    head_dim = q.shape[-1]
    # Scale by 1/sqrt(d) so logits stay in a softmax-friendly range.
    logits = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)

    if causal:
        seq_len = q.shape[-2]
        # Strictly-upper-triangular mask marks the future positions to hide.
        future = torch.ones((seq_len, seq_len), dtype=torch.bool, device=q.device).triu(1)
        logits = logits.masked_fill(future, float("-inf"))

    weights = logits.softmax(dim=-1)
    return weights @ v
|
||||
|
||||
13
reference/torch_matmul.py
Normal file
13
reference/torch_matmul.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def torch_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Multiply two 2D tensors after validating their shapes.

    Raises ValueError if either operand is not 2D or if the inner
    dimensions disagree.
    """
    if not (a.ndim == 2 and b.ndim == 2):
        raise ValueError("torch_matmul expects two 2D tensors")
    if a.shape[1] != b.shape[0]:
        raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
    return torch.matmul(a, b)
|
||||
|
||||
25
reference/torch_online_softmax.py
Normal file
25
reference/torch_online_softmax.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def torch_online_softmax(x: torch.Tensor) -> torch.Tensor:
    """Row-wise softmax computed with the one-pass "online" recurrence.

    Walks the columns left to right, maintaining a running row maximum and a
    running exponential sum that is rescaled whenever the maximum grows —
    the same trick FlashAttention uses to avoid a separate max pass.
    Raises ValueError for non-2D input.
    """
    if x.ndim != 2:
        raise ValueError(f"expected a 2D tensor, got shape {tuple(x.shape)}")

    rows = x.shape[0]
    row_max = torch.full((rows,), float("-inf"), dtype=x.dtype, device=x.device)
    row_sum = torch.zeros((rows,), dtype=x.dtype, device=x.device)

    for column in x.unbind(dim=1):
        updated_max = torch.maximum(row_max, column)
        # Rescale the accumulated sum to the new max before adding this term.
        row_sum = row_sum * torch.exp(row_max - updated_max) + torch.exp(column - updated_max)
        row_max = updated_max

    return torch.exp(x - row_max[:, None]) / row_sum[:, None]
|
||||
|
||||
15
reference/torch_row_softmax.py
Normal file
15
reference/torch_row_softmax.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def torch_row_softmax(x: torch.Tensor) -> torch.Tensor:
    """Softmax over each row of a 2D tensor.

    Subtracts the per-row maximum before exponentiating so large logits
    cannot overflow. Raises ValueError for non-2D input.
    """
    if x.ndim != 2:
        raise ValueError(f"expected a 2D tensor, got shape {tuple(x.shape)}")

    stabilized = x - x.max(dim=1, keepdim=True).values
    numerator = torch.exp(stabilized)
    denominator = numerator.sum(dim=1, keepdim=True)
    return numerator / denominator
|
||||
|
||||
11
reference/torch_vector_add.py
Normal file
11
reference/torch_vector_add.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def torch_vector_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Elementwise sum of two same-shaped tensors.

    Raises ValueError when the shapes differ (no broadcasting on purpose —
    this mirrors the kernel it is the reference for).
    """
    if x.shape != y.shape:
        raise ValueError(f"shape mismatch: {x.shape} vs {y.shape}")
    return torch.add(x, y)
|
||||
|
||||
Reference in New Issue
Block a user