- docs/01-cuda-ffi.md: added takeaways (struct layout pitfall, Rust 2024 unsafe changes, caching allocator strategy, etc.) - docs/02-tensor.md: design doc + takeaways for tensor abstraction - docs/03-gemm.md: design doc + takeaways for GEMM kernels Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
4.2 KiB
Phase 3: GEMM — Design Document
Goal
实现矩阵乘法的多个版本(naive → tiled → cuBLAS),建立 benchmark 对比框架,深入理解 GPU 编程中的内存访问模式和优化手段。
Module Layout
csrc/gemm/
├── naive.cu # 每个 thread 算一个输出元素
└── tiled.cu # shared memory tiling, 32x32 tiles
crates/xserv-kernels/
├── build.rs # 编译 .cu + 链接 cublas
└── src/
├── lib.rs
└── gemm.rs # FFI 封装, GemmBackend enum, matmul(), CublasContext
Kernel Implementations
Version 1: Naive GEMM
Grid: (ceil(N/16), ceil(M/16))
Block: (16, 16)
每个 thread: C[row][col] = sum_k(A[row][k] * B[k][col])
- 每个 thread 独立遍历 K 维度做点积
- 所有读取走 global memory,无局部性优化
- BF16 版本在 FP32 中累加(
__bfloat162float→ 累加 →__float2bfloat16)
Version 2: Tiled GEMM (Shared Memory)
TILE_SIZE = 32
Grid: (ceil(N/32), ceil(M/32))
Block: (32, 32) = 1024 threads
每个 tile iteration:
1. 协作加载 A[tile] 和 B[tile] 到 shared memory
2. __syncthreads()
3. 在 shared memory 中做 32 次乘加
4. __syncthreads()
- 每个 global memory 读取被 TILE_SIZE 个 thread 复用
- 理论上减少 global memory 访问 TILE_SIZE 倍
- BF16 版本同样在 shared memory 中存 float(FP32 累加)
Version 3: cuBLAS
cublasGemmEx支持混合精度- Row-major 适配:cuBLAS 使用 column-major 布局,我们的 tensor 是 row-major
- 利用恒等式:
C = A @ B(row-major) ⟺C^T = B^T @ A^T(col-major) - 传入
CUBLAS_OP_N,让 cuBLAS 把我们的 row-major 数据当作 col-major 的转置 - 参数:
m=N, n=M, k=K, lda=N (B), ldb=K (A), ldc=N (C)
- 利用恒等式:
Backend Registry
pub enum GemmBackend { Naive, Tiled, CuBlas }
pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor;
运行时可切换 backend,方便 benchmark 对比和逐步替换。
CublasContext
RAII 封装 cublasHandle_t,Drop 时调 cublasDestroy_v2。
目前每次 matmul 创建一个新 handle,后续优化为全局复用。
Test Plan
- F32: naive/tiled/cuBLAS × small(4)/medium(64-256)/rect(65x33x97)
- BF16: naive/tiled/cuBLAS × small/medium
- 三种 backend 在相同输入上输出一致(cross-backend consistency)
- 非方阵测试(M≠N≠K)
- 1024x1024 cuBLAS 验证
Takeaways
-
Row-major vs Column-major 陷阱:这是 GEMM 实现中最容易出错的地方。cuBLAS 的 column-major 假设与 C/Rust 的 row-major 冲突。理解
C=AB⟺C^T=B^T A^T这个恒等式是关键。实际做法:不做任何显式转置,只是交换 A/B 的传入顺序和调整 leading dimension 参数。 -
BF16 的累加精度:BF16 只有 ~3 位有效数字(vs FP32 的 ~7 位)。如果在 BF16 中累加 K 次乘法,误差会快速放大。正确做法是在 FP32 中累加,最后才转回 BF16。我们的 naive 和 tiled kernel 都遵循了这一点(
float sum = 0.0f)。cuBLAS 通过CUBLAS_COMPUTE_32F参数控制。 -
Shared memory tiling 的核心思想:global memory 带宽是 GPU 计算的主要瓶颈。通过 shared memory tiling,每个数据从 global memory 读一次,被 TILE_SIZE 个 thread 复用。对于 TILE_SIZE=32,理论上减少 32 倍 global memory 访问。
-
__syncthreads()的位置关键:tile 加载后必须同步(确保所有 thread 写完 shared memory),计算后也要同步(防止下一轮加载覆盖还在使用的数据)。漏掉任何一个 sync 都会产生 race condition 导致结果错误。 -
cuBLAS handle 开销:每次 matmul 创建/销毁 handle 有~0.1ms 开销。生产环境应全局复用一个 handle。Phase 15(性能优化)时需要修复这个问题。
-
error::check需要 pub:Phase 1 中check()是pub(crate),Phase 3 需要跨 crate 调用。反思:基础设施 crate 的错误处理函数应该从一开始就设计为 public API。
后续优化方向(Phase 15)
- Register tiling(每个 thread 算多个输出元素)
- Tensor Core WMMA(利用 5090 的硬件加速)
- CublasContext 全局复用
- 非 contiguous input 支持(避免 matmul 前的拷贝)