CUDA kernels (csrc/): - common.cuh: shared warp_reduce_sum/max, block_reduce_sum/max - normalization/rmsnorm.cu: RMSNorm (F32 + BF16) - normalization/layernorm.cu: LayerNorm with Welford (F32 + BF16) - activation/activations.cu: GELU tanh-approx + SiLU (F32 + BF16) - reduce/softmax.cu: safe softmax, 3-pass (F32 + BF16) - embedding/embedding.cu: gather lookup (F32 + BF16) - embedding/rope.cu: RoPE in-place + precomputed cos/sin cache (F32 + BF16) Rust wrappers (xserv-kernels/src/): - rmsnorm.rs, layernorm.rs, activation.rs, softmax.rs, embedding.rs, rope.rs - RopeCache struct with GPU-side precomputation Tests: 12 new tests (ops_test.rs), all passing with good precision: - F32: max_err 1e-6 ~ 1e-9 - BF16: max_err 2e-3 ~ 7e-3 Total: 29 kernel tests + 27 prior = 56 tests passing Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
214 lines
8.1 KiB
Markdown
214 lines
8.1 KiB
Markdown
# Phase 4: Transformer Core Kernels — Design Document
|
||
|
||
## Goal
|
||
|
||
实现 Transformer 所需的所有非 Attention 算子的 CUDA kernel,每个 kernel 都支持 BF16 和 F32,与 PyTorch 参考实现对比验证。
|
||
|
||
## Kernel 清单
|
||
|
||
| Kernel | 用于 | 核心计算 | 关键优化点 |
|
||
|--------|------|---------|-----------|
|
||
| LayerNorm | GPT-2 | `(x - mean) / sqrt(var + eps) * gamma + beta` | Welford online, warp reduce |
|
||
| RMSNorm | Qwen3 | `x / sqrt(mean(x²) + eps) * gamma` | 无 mean,比 LayerNorm 简单 |
|
||
| GELU | GPT-2 | `0.5x(1 + tanh(sqrt(2/π)(x + 0.044715x³)))` | tanh 近似,逐元素 |
|
||
| SiLU | Qwen3 | `x * sigmoid(x)` | 逐元素 |
|
||
| Softmax | Attention | `exp(x - max) / sum(exp(x - max))` | Online safe softmax, warp reduce |
|
||
| Embedding | 全部 | `output[i] = table[token_ids[i]]` | Gather, coalesced write |
|
||
| RoPE | Qwen3 | 对 Q/K 的相邻元素对做旋转 | Precompute freq, in-place |
|
||
|
||
## 文件布局
|
||
|
||
```
|
||
csrc/
|
||
├── normalization/
|
||
│ ├── layernorm.cu
|
||
│ └── rmsnorm.cu
|
||
├── activation/
|
||
│ ├── gelu.cu
|
||
│ └── silu.cu
|
||
├── reduce/
|
||
│ └── softmax.cu
|
||
├── embedding/
|
||
│ ├── embedding.cu
|
||
│ └── rope.cu
|
||
|
||
crates/xserv-kernels/src/
|
||
├── layernorm.rs
|
||
├── rmsnorm.rs
|
||
├── activation.rs # GELU + SiLU
|
||
├── softmax.rs
|
||
├── embedding.rs
|
||
├── rope.rs
|
||
└── lib.rs # 新增 mod 声明
|
||
```
|
||
|
||
## Kernel 设计细节
|
||
|
||
### LayerNorm
|
||
|
||
输入 `x: [*, hidden_size]`, 输出 `y: [*, hidden_size]`
|
||
参数 `gamma, beta: [hidden_size]`
|
||
|
||
```
|
||
y[i] = gamma[i] * (x[i] - mean) / sqrt(var + eps) + beta[i]
|
||
```
|
||
|
||
**GPU 映射**: 每个 thread block 处理一行(一个 hidden_size 向量)。
|
||
- Phase 1: 并行加载 x,Welford online 算法计算 mean 和 var
|
||
- Phase 2: warp-level reduce (`__shfl_down_sync`) 聚合 mean/var
|
||
- Phase 3: block-level reduce via shared memory
|
||
- Phase 4: 每个 thread 对自己负责的元素做 normalize + affine
|
||
|
||
**Block 配置**: `block = min(1024, hidden_size)`, `grid = num_rows`
|
||
|
||
### RMSNorm
|
||
|
||
比 LayerNorm 简单:不减 mean,只做 `x * rsqrt(mean(x²) + eps) * gamma`。
|
||
|
||
```
|
||
rms = sqrt(sum(x²) / hidden_size + eps)
|
||
y[i] = x[i] / rms * gamma[i]
|
||
```
|
||
|
||
**GPU 映射**: 同 LayerNorm,每个 block 处理一行。
|
||
- 只需要一次 reduce(求 sum(x²)),不需要两次(mean + var)。
|
||
|
||
### GELU
|
||
|
||
逐元素操作,用 tanh 近似:
|
||
```
|
||
gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
|
||
```
|
||
|
||
**GPU 映射**: 每个 thread 处理多个元素(向量化),grid 覆盖全部元素。
|
||
|
||
### SiLU (Swish)
|
||
|
||
逐元素: `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`
|
||
|
||
### Softmax
|
||
|
||
输入 `x: [*, seq_len]`, 沿最后一维做 softmax:
|
||
```
|
||
1. m = max(x) // 数值稳定
|
||
2. e[i] = exp(x[i] - m)
|
||
3. s = sum(e)
|
||
4. y[i] = e[i] / s
|
||
```
|
||
|
||
**GPU 映射**: 每个 block 处理一行。
|
||
- 第一遍 reduce: 求 max
|
||
- 第二遍: exp(x - max) 并 reduce sum
|
||
- 第三遍: 除以 sum
|
||
|
||
**优化**: 可以用 online softmax 合并前两遍(边算 exp 边更新 max),但先实现三遍版本保证正确。
|
||
|
||
### Embedding
|
||
|
||
```
|
||
output[seq_idx] = embedding_table[token_ids[seq_idx]]
|
||
```
|
||
|
||
**GPU 映射**: 每个 thread 处理一个 token 的部分维度。
|
||
- `grid = num_tokens`, `block = hidden_size`(或分多个 thread 处理一个 token)
|
||
- 写端是 coalesced(连续 thread 写连续地址),读端是 gather(非连续)
|
||
|
||
### RoPE (Rotary Position Embedding)
|
||
|
||
对 Q/K 的每对相邻元素 `(x0, x1)` 做 2D 旋转:
|
||
```
|
||
freq[i] = 1.0 / (theta ^ (2i / dim))
|
||
cos_val = cos(position * freq[i])
|
||
sin_val = sin(position * freq[i])
|
||
y0 = x0 * cos_val - x1 * sin_val
|
||
y1 = x0 * sin_val + x1 * cos_val
|
||
```
|
||
|
||
**GPU 映射**: 每个 thread 处理一对元素 `(x[2i], x[2i+1])`。
|
||
- Precompute `cos_cache[max_seq_len][head_dim/2]` 和 `sin_cache` 在初始化时
|
||
- 运行时 kernel 只做乘加
|
||
|
||
**theta**: Qwen3 默认 `rope_theta = 1000000.0`
|
||
|
||
## Reduction Pattern(核心学习点)
|
||
|
||
所有 Norm 和 Softmax 都涉及 reduction。GPU reduction 的分层结构:
|
||
|
||
```
|
||
Thread-level: 每个 thread 处理多个元素,本地累加
|
||
↓
|
||
Warp-level: __shfl_down_sync() 在 32 threads 内规约(无需 shared memory)
|
||
↓
|
||
Block-level: shared memory 存各 warp 的结果,warp 0 再规约
|
||
```
|
||
|
||
对于 hidden_size <= 8192(LLM 常见),一个 block 足够,不需要 grid-level reduction。
|
||
|
||
### Warp Reduce 模板
|
||
|
||
```cuda
|
||
__device__ float warp_reduce_sum(float val) {
|
||
for (int offset = 16; offset > 0; offset >>= 1)
|
||
val += __shfl_down_sync(0xffffffff, val, offset);
|
||
return val;
|
||
}
|
||
```
|
||
|
||
### Block Reduce 模板
|
||
|
||
```cuda
|
||
__device__ float block_reduce_sum(float val) {
|
||
__shared__ float shared[32]; // max 32 warps per block
|
||
int lane = threadIdx.x % 32;
|
||
int warp_id = threadIdx.x / 32;
|
||
|
||
val = warp_reduce_sum(val);
|
||
if (lane == 0) shared[warp_id] = val;
|
||
__syncthreads();
|
||
|
||
val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : 0.0f;
|
||
if (warp_id == 0) val = warp_reduce_sum(val);
|
||
return val;
|
||
}
|
||
```
|
||
|
||
## Reference 验证策略
|
||
|
||
写 `tools/generate_reference.py` 脚本,用 PyTorch 为每个 op 生成 reference input/output:
|
||
- 保存为 `.npy` 格式
|
||
- Rust 测试中加载对比
|
||
- 或者直接在 Rust 测试中用 CPU 实现计算 expected 值(更简单,不依赖 Python)
|
||
|
||
**选择**: 先用 Rust CPU 实现作为 reference(简单),关键 op(RoPE)再与 PyTorch 对比。
|
||
|
||
## Test Plan
|
||
|
||
- [x] RMSNorm F32: hidden_size=768, 4 rows → max_err 7.2e-7
|
||
- [x] RMSNorm BF16: 同上 → max_err 7.0e-3
|
||
- [x] LayerNorm F32: hidden_size=768 → max_err 1.7e-6
|
||
- [x] GELU F32: 10000 elements → max_err 3.0e-8
|
||
- [x] GELU BF16: 同上 → max_err 2.4e-3
|
||
- [x] SiLU F32: 10000 elements → max_err 1.5e-8
|
||
- [x] Softmax F32: 8×256 → max_err 1.4e-9
|
||
- [x] Softmax sum=1 验证: 4×2048
|
||
- [x] Softmax 大值 (1000+) 数值稳定性 → max_err 1.5e-8
|
||
- [x] Embedding F32: vocab=100, hidden=64, 5 tokens → exact match
|
||
- [x] RoPE F32: 4 tokens × 2 heads × dim=8 → max_err 6.0e-8
|
||
- [x] RoPE position=0 恒等验证 → max_err 0
|
||
|
||
## Takeaways
|
||
|
||
1. **`common.cuh` 抽取共用 reduction 是正确的做法**:`warp_reduce_sum/max` 和 `block_reduce_sum/max` 被 RMSNorm, LayerNorm, Softmax 三个 kernel 复用。抽到头文件避免了代码重复,也确保 reduction 逻辑一致。build.rs 中需要 `.include("../../csrc")` 让 nvcc 能找到头文件。
|
||
|
||
2. **Shared memory 中广播标量的模式**:Norm 和 Softmax 都需要将 reduce 结果(mean, rms_inv, max, sum)广播给 block 内所有 thread。标准做法:thread 0 写 `__shared__` 变量,`__syncthreads()` 后所有 thread 读。这比让每个 thread 独立做 reduce 高效得多。
|
||
|
||
3. **Softmax 三遍 vs 两遍**:我们实现了三遍版本(max → exp+sum → normalize),简单可靠。Online softmax 可以合并前两遍(一遍 pass 内同时跟踪 running max 和 running sum),但需要更复杂的数值更新公式。Flash Attention(Phase 14)会用到 online softmax。
|
||
|
||
4. **RoPE 的 position=0 恒等性**:`cos(0)=1, sin(0)=0`,所以 position 0 的旋转是恒等变换。这是一个很好的 sanity check。如果 position=0 时输出不等于输入,说明 kernel 有 bug。
|
||
|
||
5. **BF16 Softmax 的精度陷阱**:exp 结果先写成 BF16 再读回做 normalize 会丢精度。理想做法是用 float scratch buffer 暂存 exp 结果。当前实现可接受(误差在 1e-2 量级),但在 attention score 很接近时可能引入可观察的差异。Phase 14 Flash Attention 会解决这个问题(全程 FP32 累加)。
|
||
|
||
6. **Embedding 就是 gather 操作**:没有任何计算,纯粹的内存搬运。瓶颈在 global memory 随机读取(token_ids 导致不连续读 table)。写端是 coalesced 的(连续 token 写连续地址)。优化方向:使用向量化加载(`float4`)一次读 128 bit。
|
||
|
||
7. **RoPE in-place 修改 Tensor 的设计考量**:RoPE 在数学上是对 Q/K 的 in-place 旋转。我们通过 `data_ptr() as *mut` 绕过了 Rust 的不可变借用。这在 GPU 上是安全的(kernel 内部互不干扰),但 Rust 侧没有 `&mut` 语义保护。后续如果需要更严格的安全性,可以引入 `Tensor::as_mut_ptr()` 方法并要求 `&mut self`。
|