Files
xserv/docs/03-gemm.md
Gahow Wang 51a0f2eb14 docs: add design docs + takeaways for Phase 2 and Phase 3
- docs/01-cuda-ffi.md: added takeaways (struct layout pitfall,
  Rust 2024 unsafe changes, caching allocator strategy, etc.)
- docs/02-tensor.md: design doc + takeaways for tensor abstraction
- docs/03-gemm.md: design doc + takeaways for GEMM kernels

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-21 20:59:45 +08:00

4.2 KiB
Raw Permalink Blame History

Phase 3: GEMM — Design Document

Goal

实现矩阵乘法的多个版本naive → tiled → cuBLAS建立 benchmark 对比框架,深入理解 GPU 编程中的内存访问模式和优化手段。

Module Layout

csrc/gemm/
├── naive.cu          # 每个 thread 算一个输出元素
└── tiled.cu          # shared memory tiling, 32x32 tiles

crates/xserv-kernels/
├── build.rs          # 编译 .cu + 链接 cublas
└── src/
    ├── lib.rs
    └── gemm.rs       # FFI 封装, GemmBackend enum, matmul(), CublasContext

Kernel Implementations

Version 1: Naive GEMM

Grid:  (ceil(N/16), ceil(M/16))
Block: (16, 16)
每个 thread: C[row][col] = sum_k(A[row][k] * B[k][col])
  • 每个 thread 独立遍历 K 维度做点积
  • 所有读取走 global memory无局部性优化
  • BF16 版本在 FP32 中累加(__bfloat162float → 累加 → __float2bfloat16

Version 2: Tiled GEMM (Shared Memory)

TILE_SIZE = 32
Grid:  (ceil(N/32), ceil(M/32))
Block: (32, 32) = 1024 threads

每个 tile iteration:
  1. 协作加载 A[tile] 和 B[tile] 到 shared memory
  2. __syncthreads()
  3. 在 shared memory 中做 32 次乘加
  4. __syncthreads()
  • 每个 global memory 读取被 TILE_SIZE 个 thread 复用
  • 理论上减少 global memory 访问 TILE_SIZE 倍
  • BF16 版本同样在 shared memory 中存 floatFP32 累加)

Version 3: cuBLAS

  • cublasGemmEx 支持混合精度
  • Row-major 适配cuBLAS 使用 column-major 布局,我们的 tensor 是 row-major
    • 利用恒等式:C = A @ B (row-major) ⟺ C^T = B^T @ A^T (col-major)
    • 传入 CUBLAS_OP_N,让 cuBLAS 把我们的 row-major 数据当作 col-major 的转置
    • 参数:m=N, n=M, k=K, lda=N (B), ldb=K (A), ldc=N (C)

Backend Registry

pub enum GemmBackend { Naive, Tiled, CuBlas }
pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor;

运行时可切换 backend方便 benchmark 对比和逐步替换。

CublasContext

RAII 封装 cublasHandle_tDrop 时调 cublasDestroy_v2。 目前每次 matmul 创建一个新 handle后续优化为全局复用。

Test Plan

  • F32: naive/tiled/cuBLAS × small(4)/medium(64-256)/rect(65x33x97)
  • BF16: naive/tiled/cuBLAS × small/medium
  • 三种 backend 在相同输入上输出一致cross-backend consistency
  • 非方阵测试M≠N≠K
  • 1024x1024 cuBLAS 验证

Takeaways

  1. Row-major vs Column-major 陷阱:这是 GEMM 实现中最容易出错的地方。cuBLAS 的 column-major 假设与 C/Rust 的 row-major 冲突。理解 C=ABC^T=B^T A^T 这个恒等式是关键。实际做法:不做任何显式转置,只是交换 A/B 的传入顺序和调整 leading dimension 参数。

  2. BF16 的累加精度BF16 只有 ~3 位有效数字vs FP32 的 ~7 位)。如果在 BF16 中累加 K 次乘法,误差会快速放大。正确做法是在 FP32 中累加,最后才转回 BF16。我们的 naive 和 tiled kernel 都遵循了这一点(float sum = 0.0f。cuBLAS 通过 CUBLAS_COMPUTE_32F 参数控制。

  3. Shared memory tiling 的核心思想global memory 带宽是 GPU 计算的主要瓶颈。通过 shared memory tiling每个数据从 global memory 读一次,被 TILE_SIZE 个 thread 复用。对于 TILE_SIZE=32理论上减少 32 倍 global memory 访问。

  4. __syncthreads() 的位置关键tile 加载后必须同步(确保所有 thread 写完 shared memory计算后也要同步防止下一轮加载覆盖还在使用的数据。漏掉任何一个 sync 都会产生 race condition 导致结果错误。

  5. cuBLAS handle 开销:每次 matmul 创建/销毁 handle 有~0.1ms 开销。生产环境应全局复用一个 handle。Phase 15性能优化时需要修复这个问题。

  6. error::check 需要 pubPhase 1 中 check()pub(crate)Phase 3 需要跨 crate 调用。反思:基础设施 crate 的错误处理函数应该从一开始就设计为 public API。

后续优化方向Phase 15

  • Register tiling每个 thread 算多个输出元素)
  • Tensor Core WMMA利用 5090 的硬件加速)
  • CublasContext 全局复用
  • 非 contiguous input 支持(避免 matmul 前的拷贝)