Files
kernel-lab/reference/torch_matmul.py
2026-04-10 13:15:06 +00:00

14 lines
415 B
Python

from __future__ import annotations
import torch
def torch_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return the matrix product ``a @ b`` after validating operand shapes.

    This is a reference implementation. It checks only that both inputs are
    2-D and that their inner dimensions agree, then delegates the actual
    multiplication to PyTorch's ``@`` operator.

    Args:
        a: Left operand; must satisfy ``a.ndim == 2``.
        b: Right operand; must satisfy ``b.ndim == 2``.

    Returns:
        The product ``a @ b``. Its shape is ``(a.shape[0], b.shape[1])``.

    Raises:
        ValueError: If either operand is not 2-D, or if ``a.shape[1]`` does
            not equal ``b.shape[0]``.
    """
    both_are_matrices = a.ndim == 2 and b.ndim == 2
    if not both_are_matrices:
        raise ValueError("torch_matmul expects two 2D tensors")

    # Safe to unpack: both shapes are known to have exactly two entries here.
    _, a_cols = a.shape
    b_rows, _ = b.shape
    if a_cols != b_rows:
        raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")

    product = a @ b
    return product