// Probe: which transpose combinations does cuBLASLt's heuristic accept for an
// FP8 (E4M3) x FP8 -> BF16 GEMM at a decode-step shape (M=1)?
//
// Each test only calls cublasLtMatmulAlgoGetHeuristic and prints the returned
// status and the number of algorithms found. No GEMM is executed, so no
// A/B/C/D buffers or workspace are allocated.
//
// NOTE(review): the cuBLASLt docs state that FP8 matmuls require the "TN"
// operand format (transA=T, transB=N) and an FP8-capable GPU, so Tests 2 and 3
// are expected to report a non-success status -- confirm for your cuBLAS version.
#include <cstdint>
#include <cstdio>
#include <cstdlib>

#include <cublasLt.h>
#include <cuda_runtime.h>

// Abort with file/line context when a CUDA runtime call fails.
#define CUDA_CHECK(call)                                                      \
    do {                                                                      \
        cudaError_t err_ = (call);                                            \
        if (err_ != cudaSuccess) {                                            \
            fprintf(stderr, "CUDA error %s:%d: %s\n", __FILE__, __LINE__,     \
                    cudaGetErrorString(err_));                                \
            std::exit(EXIT_FAILURE);                                          \
        }                                                                     \
    } while (0)

// Abort with file/line context when a cuBLASLt setup/teardown call fails.
#define CUBLASLT_CHECK(call)                                                  \
    do {                                                                      \
        cublasStatus_t st_ = (call);                                          \
        if (st_ != CUBLAS_STATUS_SUCCESS) {                                   \
            fprintf(stderr, "cuBLASLt error %s:%d: %s (%d)\n", __FILE__,      \
                    __LINE__, cublasLtGetStatusString(st_), (int)st_);        \
            std::exit(EXIT_FAILURE);                                          \
        }                                                                     \
    } while (0)

// Column-major matrix layout as stored in memory (before any transpose):
// rows x cols with leading dimension ld.
struct MatrixShape {
    int rows;
    int cols;
    int ld;
};

// Queries the cuBLASLt heuristic for D(m x n) = op(A) * op(B) with FP8 E4M3
// A/B, BF16 C/D, FP32 compute type and FP32 scale type, then prints the
// heuristic's status and algorithm count.
//
//   handle, pref : live cuBLASLt handle and matmul preference
//   dScale       : device pointer to a float scale, used for both A and B
//   opA, opB     : transpose operations applied to A and B
//   a, b         : stored (pre-transpose) layouts of A and B
//   m, n         : rows / cols of C and D (both use leading dimension m)
//
// Descriptor/layout setup failures abort the program. The heuristic's own
// status is only reported: a non-success result (e.g. NOT_SUPPORTED) is the
// answer this probe is looking for, not an error.
static void probeFp8Heuristic(cublasLtHandle_t handle,
                              cublasLtMatmulPreference_t pref,
                              const void* dScale,
                              cublasOperation_t opA, cublasOperation_t opB,
                              MatrixShape a, MatrixShape b, int m, int n)
{
    cublasLtMatmulDesc_t desc;
    CUBLASLT_CHECK(cublasLtMatmulDescCreate(&desc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
    CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(
        desc, CUBLASLT_MATMUL_DESC_TRANSA, &opA, sizeof(opA)));
    CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(
        desc, CUBLASLT_MATMUL_DESC_TRANSB, &opB, sizeof(opB)));
    // The attribute value is the device pointer itself, so pass its address.
    CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(
        desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &dScale, sizeof(dScale)));
    CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(
        desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &dScale, sizeof(dScale)));

    cublasLtMatrixLayout_t Adesc, Bdesc, Cdesc, Ddesc;
    CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8F_E4M3, a.rows, a.cols, a.ld));
    CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8F_E4M3, b.rows, b.cols, b.ld));
    CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_16BF, m, n, m));
    CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&Ddesc, CUDA_R_16BF, m, n, m));

    cublasLtMatmulHeuristicResult_t result = {};
    int found = 0;
    const cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
        handle, desc, Adesc, Bdesc, Cdesc, Ddesc, pref, 1, &result, &found);
    printf(" status=%d found=%d (%s)\n", (int)status, found,
           cublasLtGetStatusString(status));

    // Release in reverse order of creation.
    CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(Ddesc));
    CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(Cdesc));
    CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(Bdesc));
    CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(Adesc));
    CUBLASLT_CHECK(cublasLtMatmulDescDestroy(desc));
}

int main()
{
    // Model dimensions: M=1 (decode), K=2880, N=5760.
    const int M = 1, N = 5760, K = 2880;

    cublasLtHandle_t handle;
    CUBLASLT_CHECK(cublasLtCreate(&handle));

    // Unit scale shared by A and B; the scale-pointer attributes take a device
    // pointer, so the value is copied to device memory.
    const float one = 1.0f;
    void* dScale = nullptr;
    CUDA_CHECK(cudaMalloc(&dScale, sizeof(one)));
    CUDA_CHECK(cudaMemcpy(dScale, &one, sizeof(one), cudaMemcpyHostToDevice));

    cublasLtMatmulPreference_t pref;
    CUBLASLT_CHECK(cublasLtMatmulPreferenceCreate(&pref));
    // MAX_WORKSPACE_BYTES is a uint64_t attribute. Only the limit is advertised;
    // no workspace is allocated because no GEMM is launched.
    const uint64_t ws = 32ull * 1024ull * 1024ull;
    CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(
        pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws, sizeof(ws)));

    // Test 1: transA=T, transB=N, m=N, n=M, k=K
    // A stored (K, N) ld=K -> transposed to (N, K)
    // B stored (K, M) ld=K
    printf("Test1: transA=T transB=N, m=%d n=%d k=%d\n", N, M, K);
    probeFp8Heuristic(handle, pref, dScale, CUBLAS_OP_T, CUBLAS_OP_N,
                      MatrixShape{K, N, K}, MatrixShape{K, M, K}, N, M);

    // Test 2: transA=N, transB=N
    // A stored (N, K) ld=N; B stored (K, M) ld=K
    printf("Test2: transA=N transB=N, m=%d n=%d k=%d\n", N, M, K);
    probeFp8Heuristic(handle, pref, dScale, CUBLAS_OP_N, CUBLAS_OP_N,
                      MatrixShape{N, K, N}, MatrixShape{K, M, K}, N, M);

    // Test 3: transA=N, transB=T
    // A stored (N, K) ld=N; B stored (M, K) ld=M -> transposed to (K, M)
    printf("Test3: transA=N transB=T, m=%d n=%d k=%d\n", N, M, K);
    probeFp8Heuristic(handle, pref, dScale, CUBLAS_OP_N, CUBLAS_OP_T,
                      MatrixShape{N, K, N}, MatrixShape{M, K, M}, N, M);

    CUBLASLT_CHECK(cublasLtMatmulPreferenceDestroy(pref));
    CUBLASLT_CHECK(cublasLtDestroy(handle));
    CUDA_CHECK(cudaFree(dScale));
    printf("Done.\n");
    return 0;
}