from __future__ import annotations

import math

import torch


def torch_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Reference scaled dot-product attention forward pass."""
    if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
        raise ValueError("expected tensors shaped [batch, heads, seq, dim]")
    if q.shape != k.shape or q.shape != v.shape:
        raise ValueError(
            f"q, k, v must have matching shapes; got {q.shape}, {k.shape}, {v.shape}"
        )

    # Scale the dot products by 1/sqrt(head_dim) to keep softmax logits well-conditioned.
    dim = q.shape[-1]
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(dim)
    if causal:
        # Mask out future positions (strict upper triangle) before the softmax.
        seq = q.shape[-2]
        mask = torch.triu(
            torch.ones((seq, seq), dtype=torch.bool, device=q.device), diagonal=1
        )
        scores = scores.masked_fill(mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)
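

# Minimal usage sketch: runs the reference implementation on random inputs and,
# assuming PyTorch >= 2.0 (so torch.nn.functional.scaled_dot_product_attention
# exists), cross-checks it against the built-in kernel. The shapes and tolerances
# below are illustrative choices, not part of the module's API.
if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(2, 4, 16, 32)  # [batch, heads, seq, dim]
    k = torch.randn(2, 4, 16, 32)
    v = torch.randn(2, 4, 16, 32)

    out = torch_attention(q, k, v, causal=True)
    print(out.shape)  # torch.Size([2, 4, 16, 32])

    # Both paths use the same 1/sqrt(dim) scaling and causal mask, so the results
    # should agree up to floating-point tolerance.
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5)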