docs: add design docs + takeaways for Phase 2 and Phase 3

- docs/01-cuda-ffi.md: added takeaways (struct layout pitfall,
  Rust 2024 unsafe changes, caching allocator strategy, etc.)
- docs/02-tensor.md: design doc + takeaways for tensor abstraction
- docs/03-gemm.md: design doc + takeaways for GEMM kernels

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 20:59:45 +08:00
parent d77f921a12
commit 51a0f2eb14
3 changed files with 227 additions and 6 deletions

View File

@@ -72,9 +72,31 @@ Wraps cudaStream_t. RAII with Drop calling cudaStreamDestroy.
- `build.rs` uses `cc` crate to compile .cu files, link CUDA runtime - `build.rs` uses `cc` crate to compile .cu files, link CUDA runtime
## Test Plan ## Test Plan
1. Device info: print GPU name, memory, compute capability, SM count
2. GpuBuffer: alloc 1GB, H2D copy, D2H copy, verify data - [x] Device info: print GPU name, memory, compute capability, SM count
3. Vector add kernel: launch from Rust, verify output - [x] GpuBuffer: alloc → H2D copy → D2H copy → verify data (256B, 64MB)
4. CachingAllocator: alloc→free→realloc same size uses cache (no new cudaMalloc) - [x] GpuBuffer: D2D copy 验证
5. Multi-stream: two concurrent memcpy on different streams - [x] GpuBuffer: zero fill 验证
6. Benchmark: caching allocator vs raw cudaMalloc (100 cycles) - [x] Vector add kernel: launch from Rust, verify output
- [x] CachingAllocator: alloc→free→realloc same size uses cache (no new cudaMalloc)
- [x] CachingAllocator: 不同 size bucket 独立缓存
- [x] CudaStream: 创建、同步、Drop
- [x] PinnedBuffer: page-locked host memory
- [x] Async copy: H2D async + D2H async via stream
## Takeaways
1. **`cudaDeviceProp` struct 布局不可靠**CUDA 版本之间 `cudaDeviceProp` 的字段偏移会变化。我们最初用 struct 映射读取 `total_global_mem`得到了垃圾值12TB。正确做法`cudaMemGetInfo` 获取显存信息,用 `cudaDeviceGetAttribute` 获取其他属性。只从 `cudaDeviceProp` 读取 `name` 字段(始终在 struct 最前面,布局稳定)。
2. **Rust 2024 edition 的 unsafe 语义变更**
- `extern "C"` 块必须加 `unsafe` 前缀 → `unsafe extern "C"`
- `unsafe fn` 内部的 unsafe 调用也需要显式 `unsafe {}`
- 这让代码更安全,但初次移植需要注意
3. **`cc` crate 的 CUDA 支持是内置的**:不需要 `features = ["cuda"]`(这个 feature 不存在)。只需 `.cuda(true).cudart("shared")`
4. **Caching Allocator 的 bucket 策略**round up to next power of 2最小 512B。这意味着申请 513B 会分配 1024B存在内部碎片。但简单且高效——避免了 free list 中的精确匹配问题。PyTorch 的 CUDACachingAllocator 用了更复杂的策略best-fit with splitting但对于推理场景power-of-2 bucket 已经够用。
5. **`into_raw` + `from_raw` 模式**GpuBuffer 的 RAII Drop 和 CachingAllocator 的缓存需求冲突——allocator 需要持有裸指针而不触发 Drop。`into_raw()` 消费 self`mem::forget`),返回裸指针;`from_raw()` 重新封装。这是 Rust 中管理 RAII 生命周期的标准模式。
6. **dash5 环境**CUDA 12.9 已安装但 `nvcc` 不在 PATH需要 `/usr/local/cuda/bin`。Rust 需要手动安装 rustup。无 rsync`tar | ssh tar` 同步代码。开发工作流:本地写码 → tar sync → 远程 build+test。

97
docs/02-tensor.md Normal file
View File

@@ -0,0 +1,97 @@
# Phase 2: Tensor Abstraction Layer — Design Document
## Goal
实现核心 Tensor 类型,支持 CPU/GPU 存储、多种数据类型、strided view 操作,作为后续所有算子和模型的数据基础。
## Module Layout
```
crates/xserv-tensor/
├── Cargo.toml
└── src/
├── lib.rs # re-exports
├── dtype.rs # DType enum, TensorDType trait
├── shape.rs # strides 计算, broadcast 规则
├── storage.rs # Storage (Arc引用计数), Device enum
└── tensor.rs # Tensor 主体: 创建, 形状操作, 设备迁移
```
## Key Design Decisions
### DType + TensorDType Trait
```rust
pub enum DType { F32, F16, BF16 }
pub trait TensorDType: Copy + Send + Sync + 'static {
const DTYPE: DType;
fn to_f64(self) -> f64;
fn from_f64(v: f64) -> Self;
}
```
-`half` crate 的 `bf16`/`f16` 表示半精度类型
- `TensorDType` trait 让 `from_slice<T>``as_slice<T>` 有类型安全
- GPU kernel 中通过 `DType` dispatch 到对应的 CUDA 类型 (`__nv_bfloat16` / `float`)
### Storage 引用计数
```rust
pub struct Storage(Arc<StorageInner>);
enum StorageInner {
Cpu { data: Vec<u8> },
Cuda { buffer: GpuBuffer },
}
```
- `Arc` 引用计数让 transpose/slice/reshape 能共享底层数据view 语义)
- 不实现 CoWcopy-on-writeview 只能读不能写
- `to_device()` 总是创建新的 Storage
### Strided Tensor
```rust
pub struct Tensor {
storage: Storage,
shape: SmallVec<[usize; 4]>,
strides: SmallVec<[usize; 4]>,
offset: usize,
dtype: DType,
}
```
- `SmallVec<[usize; 4]>` 避免大多数 tensor (≤4D) 的堆分配
- `strides` 以元素为单位(不是字节)
- `offset` 支持 slice 操作view 到 storage 的中间位置)
- `is_contiguous()` 检查 strides 是否与 shape 匹配
- 非 contiguous 的 tensor 调 `contiguous()` 才能送入 CUDA kernel
### Broadcast 规则
实现了 NumPy-style broadcasting:
- 维度从尾部对齐
- 大小为 1 的维度可以广播到任意大小
- `broadcast_strides()` 将 size=1 维度的 stride 置为 0虚拟广播不复制数据
## Test Plan
- [x] from_slice → shape/strides 正确
- [x] reshape, transpose, squeeze, unsqueeze
- [x] transpose 后 contiguous() 重排数据
- [x] BF16 tensor 的精度验证
- [x] CPU↔GPU roundtrip
- [x] zeros on GPU → 拷回 CPU 验证全 0
- [x] broadcast_shape 单元测试
## Takeaways
1. **`SmallVec` 是正确选择**:绝大多数 tensor ≤ 4D避免了频繁堆分配。LLM 推理中常见的维度是 `[B, S, H]` (3D) 和 `[B, H, S, D]` (4D)。
2. **View 语义的取舍**Arc 共享 storage 实现了零拷贝 transpose/reshape但代价是无法原地修改 view 后的 tensor。对于推理引擎这是可以接受的——推理路径上大部分操作是只读的。
3. **contiguous() 的隐性开销**:非 contiguous tensor 在送入 kernel 前需要 `contiguous()` 拷贝。这意味着 `transpose → matmul` 会产生一次额外拷贝。后续优化方向:在 kernel 中直接支持 strided input。
4. **Rust 2024 edition 变化**`unsafe fn` 内部的 unsafe 调用也需要显式 `unsafe {}` 块,`extern "C"` 块必须加 `unsafe` 前缀。这个 edition 对安全性更严格。
5. **CPU 实现先行**:先在 CPU 上验证逻辑正确性(如 contiguous 重排),再扩展到 GPU。这个策略在后续 phase 中应该继续沿用。

102
docs/03-gemm.md Normal file
View File

@@ -0,0 +1,102 @@
# Phase 3: GEMM — Design Document
## Goal
实现矩阵乘法的多个版本naive → tiled → cuBLAS建立 benchmark 对比框架,深入理解 GPU 编程中的内存访问模式和优化手段。
## Module Layout
```
csrc/gemm/
├── naive.cu # 每个 thread 算一个输出元素
└── tiled.cu # shared memory tiling, 32x32 tiles
crates/xserv-kernels/
├── build.rs # 编译 .cu + 链接 cublas
└── src/
├── lib.rs
└── gemm.rs # FFI 封装, GemmBackend enum, matmul(), CublasContext
```
## Kernel Implementations
### Version 1: Naive GEMM
```
Grid: (ceil(N/16), ceil(M/16))
Block: (16, 16)
每个 thread: C[row][col] = sum_k(A[row][k] * B[k][col])
```
- 每个 thread 独立遍历 K 维度做点积
- 所有读取走 global memory无局部性优化
- BF16 版本在 FP32 中累加(`__bfloat162float` → 累加 → `__float2bfloat16`
### Version 2: Tiled GEMM (Shared Memory)
```
TILE_SIZE = 32
Grid: (ceil(N/32), ceil(M/32))
Block: (32, 32) = 1024 threads
每个 tile iteration:
1. 协作加载 A[tile] 和 B[tile] 到 shared memory
2. __syncthreads()
3. 在 shared memory 中做 32 次乘加
4. __syncthreads()
```
- 每个 global memory 读取被 TILE_SIZE 个 thread 复用
- 理论上减少 global memory 访问 TILE_SIZE 倍
- BF16 版本同样在 shared memory 中存 floatFP32 累加)
### Version 3: cuBLAS
- `cublasGemmEx` 支持混合精度
- **Row-major 适配**cuBLAS 使用 column-major 布局,我们的 tensor 是 row-major
- 利用恒等式:`C = A @ B` (row-major) ⟺ `C^T = B^T @ A^T` (col-major)
- 传入 `CUBLAS_OP_N`,让 cuBLAS 把我们的 row-major 数据当作 col-major 的转置
- 参数:`m=N, n=M, k=K, lda=N (B), ldb=K (A), ldc=N (C)`
### Backend Registry
```rust
pub enum GemmBackend { Naive, Tiled, CuBlas }
pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor;
```
运行时可切换 backend方便 benchmark 对比和逐步替换。
## CublasContext
RAII 封装 `cublasHandle_t`Drop 时调 `cublasDestroy_v2`
目前每次 matmul 创建一个新 handle后续优化为全局复用。
## Test Plan
- [x] F32: naive/tiled/cuBLAS × small(4)/medium(64-256)/rect(65x33x97)
- [x] BF16: naive/tiled/cuBLAS × small/medium
- [x] 三种 backend 在相同输入上输出一致cross-backend consistency
- [x] 非方阵测试M≠N≠K
- [x] 1024x1024 cuBLAS 验证
## Takeaways
1. **Row-major vs Column-major 陷阱**:这是 GEMM 实现中最容易出错的地方。cuBLAS 的 column-major 假设与 C/Rust 的 row-major 冲突。理解 `C=AB``C^T=B^T A^T` 这个恒等式是关键。实际做法:不做任何显式转置,只是交换 A/B 的传入顺序和调整 leading dimension 参数。
2. **BF16 的累加精度**BF16 只有 ~3 位有效数字vs FP32 的 ~7 位)。如果在 BF16 中累加 K 次乘法,误差会快速放大。正确做法是**在 FP32 中累加,最后才转回 BF16**。我们的 naive 和 tiled kernel 都遵循了这一点(`float sum = 0.0f`。cuBLAS 通过 `CUBLAS_COMPUTE_32F` 参数控制。
3. **Shared memory tiling 的核心思想**global memory 带宽是 GPU 计算的主要瓶颈。通过 shared memory tiling每个数据从 global memory 读一次,被 TILE_SIZE 个 thread 复用。对于 TILE_SIZE=32理论上减少 32 倍 global memory 访问。
4. **`__syncthreads()` 的位置关键**tile 加载后必须同步(确保所有 thread 写完 shared memory计算后也要同步防止下一轮加载覆盖还在使用的数据。漏掉任何一个 sync 都会产生 race condition 导致结果错误。
5. **cuBLAS handle 开销**:每次 matmul 创建/销毁 handle 有~0.1ms 开销。生产环境应全局复用一个 handle。Phase 15性能优化时需要修复这个问题。
6. **`error::check` 需要 pub**Phase 1 中 `check()``pub(crate)`Phase 3 需要跨 crate 调用。反思:基础设施 crate 的错误处理函数应该从一开始就设计为 public API。
## 后续优化方向Phase 15
- Register tiling每个 thread 算多个输出元素)
- Tensor Core WMMA利用 5090 的硬件加速)
- CublasContext 全局复用
- 非 contiguous input 支持(避免 matmul 前的拷贝)