from __future__ import annotations

import torch


def torch_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Multiply two matrices after checking that they are compatible.

    This is a plain reference implementation. Once the inputs pass
    validation, the product is computed with PyTorch's ``@`` operator.

    Args:
        a: Left operand; must have exactly two dimensions.
        b: Right operand; must have exactly two dimensions.

    Returns:
        The matrix product ``a @ b``.

    Raises:
        ValueError: If either operand is not 2-D, or if the number of
            columns of ``a`` differs from the number of rows of ``b``.
    """
    both_2d = a.ndim == 2 and b.ndim == 2
    if not both_2d:
        raise ValueError("torch_matmul expects two 2D tensors")

    # Inner dimensions must agree: (m, k) @ (k, n).
    _, inner_left = a.shape
    inner_right, _ = b.shape
    if inner_left != inner_right:
        raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")

    product = a @ b
    return product