Initial project scaffold

This commit is contained in:
2026-04-10 13:22:19 +00:00
commit 7fa69b1354
94 changed files with 3964 additions and 0 deletions

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
import math
import torch
def torch_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Reference scaled dot-product attention forward pass.

    Supports self-attention and cross-attention: ``q`` may have a different
    sequence length than ``k``/``v``.

    Args:
        q: Query tensor shaped ``[batch, heads, q_seq, dim]``.
        k: Key tensor shaped ``[batch, heads, kv_seq, dim]``.
        v: Value tensor shaped ``[batch, heads, kv_seq, dim]``.
        causal: If True, mask so query position ``i`` attends only to key
            positions ``j <= i`` (top-left aligned lower-triangular mask).

    Returns:
        Attention output shaped ``[batch, heads, q_seq, dim]``.

    Raises:
        ValueError: If any tensor is not 4-D, or the shapes are incompatible.
    """
    if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
        raise ValueError("expected tensors shaped [batch, heads, seq, dim]")
    # k and v must agree exactly; q only needs matching batch, heads, and
    # head dim — its sequence length may differ (cross-attention).
    if k.shape != v.shape:
        raise ValueError(f"k and v must have matching shapes; got {k.shape}, {v.shape}")
    if q.shape[:2] != k.shape[:2] or q.shape[-1] != k.shape[-1]:
        raise ValueError(
            f"q, k, v must have matching shapes; got {q.shape}, {k.shape}, {v.shape}"
        )
    dim = q.shape[-1]
    # Standard 1/sqrt(d) scaling applied to the raw dot-product scores.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(dim)
    if causal:
        q_seq, kv_seq = q.shape[-2], k.shape[-2]
        # diagonal=1 keeps j == i visible; strictly-upper entries are masked.
        # For q_seq == kv_seq this is identical to the original square mask.
        mask = torch.triu(
            torch.ones((q_seq, kv_seq), dtype=torch.bool, device=q.device), diagonal=1
        )
        scores = scores.masked_fill(mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)