Files
kernel-lab/reference/torch_attention.py
2026-04-10 13:15:06 +00:00

31 lines
916 B
Python

from __future__ import annotations
import math
import torch
def torch_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Reference scaled dot-product attention forward pass.

    Generalized to support cross-attention: ``k``/``v`` may have a
    different sequence length than ``q``, and ``v`` may have a different
    head dimension than ``q``/``k``. Self-attention inputs with fully
    matching shapes behave exactly as before.

    Args:
        q: Queries, shaped ``[batch, heads, q_seq, dim]``.
        k: Keys, shaped ``[batch, heads, kv_seq, dim]``.
        v: Values, shaped ``[batch, heads, kv_seq, v_dim]``.
        causal: If True, apply a causal mask. Queries are right-aligned
            with keys, so when ``kv_seq > q_seq`` the extra leading keys
            act as an always-visible prefix (standard KV-cache layout).

    Returns:
        Attention output shaped ``[batch, heads, q_seq, v_dim]``.

    Raises:
        ValueError: If any tensor is not rank 4, if batch/heads or the
            shared dimensions disagree, or if ``causal=True`` with more
            queries than keys (some rows would attend to nothing).
    """
    if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
        raise ValueError("expected tensors shaped [batch, heads, seq, dim]")
    if q.shape[:2] != k.shape[:2] or q.shape[:2] != v.shape[:2]:
        raise ValueError(
            f"q, k, v must share batch and heads; got {q.shape}, {k.shape}, {v.shape}"
        )
    if q.shape[-1] != k.shape[-1]:
        raise ValueError(
            f"q and k must share the head dim; got {q.shape[-1]} and {k.shape[-1]}"
        )
    if k.shape[-2] != v.shape[-2]:
        raise ValueError(
            f"k and v must share the sequence dim; got {k.shape[-2]} and {v.shape[-2]}"
        )
    dim = q.shape[-1]
    # Scale by 1/sqrt(dim) so score variance is independent of head dim.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(dim)
    if causal:
        q_seq, kv_seq = q.shape[-2], k.shape[-2]
        if q_seq > kv_seq:
            raise ValueError(
                "causal attention requires q_seq <= kv_seq; "
                f"got q_seq={q_seq}, kv_seq={kv_seq}"
            )
        # Right-align queries with keys: query i may attend to key j iff
        # j <= i + (kv_seq - q_seq). For the square case this offset is 1,
        # reproducing the original strict upper-triangular mask.
        mask = torch.triu(
            torch.ones((q_seq, kv_seq), dtype=torch.bool, device=q.device),
            diagonal=kv_seq - q_seq + 1,
        )
        scores = scores.masked_fill(mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)