#include "../include/common.h"
#include "../include/cuda_utils.h"
namespace kernel_lab {
// Computes c = a @ b for row-major, contiguous float32 matrices using the
// classic shared-memory tiling scheme.
//
// Shapes: a is m x k, b is k x n, c is m x n (all row-major, contiguous).
//
// Expected launch configuration (see tiled_matmul_cuda):
//   block = (kTile, kTile)
//   grid  = (ceil_div(n, kTile), ceil_div(m, kTile))
// so blockIdx.x / threadIdx.x walk columns of C and blockIdx.y / threadIdx.y
// walk rows of C.
__global__ void tiled_matmul_kernel(
    const float* a,
    const float* b,
    float* c,
    int64_t m,
    int64_t n,
    int64_t k) {
  // Tile edge length; must match the block dimensions chosen at launch.
  constexpr int kTile = 16;

  // +1 padding on the fastest-varying index breaks shared-memory bank
  // conflicts when a column of a tile is read in the inner product below.
  __shared__ float a_tile[kTile][kTile + 1];
  __shared__ float b_tile[kTile][kTile + 1];

  // Element of C this thread owns. Promote to 64-bit before multiplying so
  // large matrices do not overflow 32-bit index arithmetic.
  const int64_t row = static_cast<int64_t>(blockIdx.y) * kTile + threadIdx.y;
  const int64_t col = static_cast<int64_t>(blockIdx.x) * kTile + threadIdx.x;

  float acc = 0.0f;

  // March tile-by-tile along the shared K dimension.
  const int64_t num_tiles = (k + kTile - 1) / kTile;
  for (int64_t t = 0; t < num_tiles; ++t) {
    // Cooperative load: each thread stages one element of the A tile and one
    // element of the B tile. Out-of-range positions load 0.0f so the partial
    // products at the ragged edges contribute nothing.
    const int64_t a_col = t * kTile + threadIdx.x;
    const int64_t b_row = t * kTile + threadIdx.y;
    a_tile[threadIdx.y][threadIdx.x] =
        (row < m && a_col < k) ? a[row * k + a_col] : 0.0f;
    b_tile[threadIdx.y][threadIdx.x] =
        (b_row < k && col < n) ? b[b_row * n + col] : 0.0f;

    // Every load must land before any thread reads the tiles.
    __syncthreads();

#pragma unroll
    for (int i = 0; i < kTile; ++i) {
      acc += a_tile[threadIdx.y][i] * b_tile[i][threadIdx.x];
    }

    // Keep fast threads from overwriting tiles other threads still read.
    __syncthreads();
  }

  // Guard the ragged edges of the grid before the single global store.
  if (row < m && col < n) {
    c[row * n + col] = acc;
  }
}
// Host entry point: validates inputs and launches tiled_matmul_kernel to
// compute a @ b for two 2-D, contiguous, float32 CUDA tensors.
//
// Returns a freshly allocated m x n float32 tensor on the same device as `a`.
// Raises (via TORCH_CHECK / the LAB_CHECK_* macros) on non-CUDA,
// non-contiguous, non-2D, or non-float32 inputs, mismatched inner
// dimensions, or a failed kernel launch.
torch::Tensor tiled_matmul_cuda(torch::Tensor a, torch::Tensor b) {
  LAB_CHECK_CUDA(a);
  LAB_CHECK_CUDA(b);
  LAB_CHECK_CONTIGUOUS(a);
  LAB_CHECK_CONTIGUOUS(b);
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "tiled_matmul_cuda expects 2D tensors");
  TORCH_CHECK(a.size(1) == b.size(0), "inner dimensions must match");
  TORCH_CHECK(a.scalar_type() == torch::kFloat32, "tiled_matmul_cuda currently assumes float32");
  TORCH_CHECK(b.scalar_type() == torch::kFloat32, "tiled_matmul_cuda currently assumes float32");

  const int64_t m = a.size(0);
  const int64_t k = a.size(1);
  const int64_t n = b.size(1);
  torch::Tensor c = torch::empty({m, n}, a.options());

  // Degenerate shapes: with no rows or columns there is nothing to launch;
  // with k == 0 every dot product is an empty sum, i.e. zero.
  if (m == 0 || n == 0) {
    return c;
  }
  if (k == 0) {
    return c.zero_();
  }

  // Tile edge; must match the kernel's compile-time tile size (16).
  constexpr int64_t tile = 16;
  const dim3 block(tile, tile);
  const dim3 grid(static_cast<unsigned int>((n + tile - 1) / tile),
                  static_cast<unsigned int>((m + tile - 1) / tile));

  // NOTE(review): launches on the default stream; route through the current
  // PyTorch stream if this lab later exercises stream semantics.
  tiled_matmul_kernel<<<grid, block>>>(
      a.data_ptr<float>(), b.data_ptr<float>(), c.data_ptr<float>(), m, n, k);

  // Kernel launches return no status directly: launch-configuration errors
  // surface via cudaGetLastError(); execution errors surface at the next
  // synchronizing call.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
              "tiled_matmul_kernel launch failed: ", cudaGetErrorString(err));
  return c;
}
} // namespace kernel_lab