from __future__ import annotations

import math

import torch


def torch_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Reference scaled dot-product attention forward pass."""
    if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
        raise ValueError("expected tensors shaped [batch, heads, seq, dim]")
    if q.shape != k.shape or q.shape != v.shape:
        raise ValueError(
            f"q, k, v must have matching shapes; got {q.shape}, {k.shape}, {v.shape}"
        )
    dim = q.shape[-1]
    # Scaled attention logits: QK^T / sqrt(d), shape [batch, heads, seq, seq].
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(dim)
    if causal:
        # Boolean upper-triangular mask (diagonal excluded), so query i can
        # attend only to key positions <= i.
        seq = q.shape[-2]
        mask = torch.triu(
            torch.ones((seq, seq), dtype=torch.bool, device=q.device), diagonal=1
        )
        scores = scores.masked_fill(mask, float("-inf"))
    # Row-wise softmax over keys, then weighted sum of the values.
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)
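
# A minimal self-test sketch, not part of the original module: it assumes
# PyTorch >= 2.0 and checks the reference above against
# torch.nn.functional.scaled_dot_product_attention, which computes the same
# quantity (is_causal toggles the same upper-triangular mask). The shapes
# and tolerance below are illustrative choices.
if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(2, 4, 8, 16)  # [batch, heads, seq, dim]
    k = torch.randn(2, 4, 8, 16)
    v = torch.randn(2, 4, 8, 16)
    for causal in (False, True):
        ours = torch_attention(q, k, v, causal=causal)
        ref = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=causal
        )
        assert torch.allclose(ours, ref, atol=1e-5), f"mismatch (causal={causal})"
    print("torch_attention matches F.scaled_dot_product_attention")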