Initial project scaffold
This commit is contained in:
30
reference/torch_attention.py
Normal file
30
reference/torch_attention.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def torch_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Reference scaled dot-product attention forward pass.

    Args:
        q: Query tensor shaped ``[batch, heads, seq, dim]``.
        k: Key tensor with the same shape as ``q``.
        v: Value tensor with the same shape as ``q``.
        causal: When ``True``, each position may only attend to itself and
            earlier positions (upper-triangular logits are masked out).

    Returns:
        Attention output tensor with the same shape as ``q``.

    Raises:
        ValueError: If any input is not rank-4, or the shapes disagree.
    """
    # Rank check first so the error message stays specific, then shape agreement.
    if not (q.ndim == k.ndim == v.ndim == 4):
        raise ValueError("expected tensors shaped [batch, heads, seq, dim]")
    if not (q.shape == k.shape == v.shape):
        raise ValueError(f"q, k, v must have matching shapes; got {q.shape}, {k.shape}, {v.shape}")

    # Scale raw dot products by sqrt(head_dim) to keep logits O(1) in magnitude.
    head_dim = q.size(-1)
    logits = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)

    if causal:
        # Strict upper triangle marks "future" positions; -inf zeroes them
        # out after softmax.
        seq_len = q.size(-2)
        future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).triu(1)
        logits = logits.masked_fill(future, float("-inf"))

    weights = logits.softmax(dim=-1)
    return weights @ v
|
||||
|
||||
Reference in New Issue
Block a user