Files
xserv/docs/04-transformer-kernels.md
Gahow Wang c8e8153702 phase 4: transformer core kernels
CUDA kernels (csrc/):
- common.cuh: shared warp_reduce_sum/max, block_reduce_sum/max
- normalization/rmsnorm.cu: RMSNorm (F32 + BF16)
- normalization/layernorm.cu: LayerNorm with Welford (F32 + BF16)
- activation/activations.cu: GELU tanh-approx + SiLU (F32 + BF16)
- reduce/softmax.cu: safe softmax, 3-pass (F32 + BF16)
- embedding/embedding.cu: gather lookup (F32 + BF16)
- embedding/rope.cu: RoPE in-place + precomputed cos/sin cache (F32 + BF16)

Rust wrappers (xserv-kernels/src/):
- rmsnorm.rs, layernorm.rs, activation.rs, softmax.rs, embedding.rs, rope.rs
- RopeCache struct with GPU-side precomputation

Tests: 12 new tests (ops_test.rs), all passing with good precision:
- F32: max_err 1e-6 ~ 1e-9
- BF16: max_err 2e-3 ~ 7e-3
Total: 29 kernel tests + 27 prior = 56 tests passing

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-21 21:07:24 +08:00

8.1 KiB
Raw Permalink Blame History

Phase 4: Transformer Core Kernels — Design Document

Goal

实现 Transformer 所需的所有非 Attention 算子的 CUDA kernel每个 kernel 都支持 BF16 和 F32与 PyTorch 参考实现对比验证。

Kernel 清单

Kernel 用于 核心计算 关键优化点
LayerNorm GPT-2 (x - mean) / sqrt(var + eps) * gamma + beta Welford online, warp reduce
RMSNorm Qwen3 x / sqrt(mean(x²) + eps) * gamma 无 mean比 LayerNorm 简单
GELU GPT-2 0.5x(1 + tanh(sqrt(2/π)(x + 0.044715x³))) tanh 近似,逐元素
SiLU Qwen3 x * sigmoid(x) 逐元素
Softmax Attention exp(x - max) / sum(exp(x - max)) Online safe softmax, warp reduce
Embedding 全部 output[i] = table[token_ids[i]] Gather, coalesced write
RoPE Qwen3 对 Q/K 的相邻元素对做旋转 Precompute freq, in-place

文件布局

csrc/
├── normalization/
│   ├── layernorm.cu
│   └── rmsnorm.cu
├── activation/
│   ├── gelu.cu
│   └── silu.cu
├── reduce/
│   └── softmax.cu
├── embedding/
│   ├── embedding.cu
│   └── rope.cu

crates/xserv-kernels/src/
├── layernorm.rs
├── rmsnorm.rs
├── activation.rs      # GELU + SiLU
├── softmax.rs
├── embedding.rs
├── rope.rs
└── lib.rs             # 新增 mod 声明

Kernel 设计细节

LayerNorm

输入 x: [*, hidden_size], 输出 y: [*, hidden_size] 参数 gamma, beta: [hidden_size]

y[i] = gamma[i] * (x[i] - mean) / sqrt(var + eps) + beta[i]

GPU 映射: 每个 thread block 处理一行(一个 hidden_size 向量)。

  • Phase 1: 并行加载 xWelford online 算法计算 mean 和 var
  • Phase 2: warp-level reduce (__shfl_down_sync) 聚合 mean/var
  • Phase 3: block-level reduce via shared memory
  • Phase 4: 每个 thread 对自己负责的元素做 normalize + affine

Block 配置: block = min(1024, hidden_size), grid = num_rows

RMSNorm

比 LayerNorm 简单:不减 mean只做 x * rsqrt(mean(x²) + eps) * gamma

rms = sqrt(sum(x²) / hidden_size + eps)
y[i] = x[i] / rms * gamma[i]

GPU 映射: 同 LayerNorm每个 block 处理一行。

  • 只需要一次 reduce求 sum(x²)不需要两次mean + var

GELU

逐元素操作,用 tanh 近似:

gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))

GPU 映射: 每个 thread 处理多个元素向量化grid 覆盖全部元素。

SiLU (Swish)

逐元素: silu(x) = x * sigmoid(x) = x / (1 + exp(-x))

Softmax

输入 x: [*, seq_len], 沿最后一维做 softmax:

1. m = max(x)                    // 数值稳定
2. e[i] = exp(x[i] - m)
3. s = sum(e)
4. y[i] = e[i] / s

GPU 映射: 每个 block 处理一行。

  • 第一遍 reduce: 求 max
  • 第二遍: exp(x - max) 并 reduce sum
  • 第三遍: 除以 sum

优化: 可以用 online softmax 合并前两遍(边算 exp 边更新 max但先实现三遍版本保证正确。

Embedding

output[seq_idx] = embedding_table[token_ids[seq_idx]]

GPU 映射: 每个 thread 处理一个 token 的部分维度。

  • grid = num_tokens, block = hidden_size(或分多个 thread 处理一个 token
  • 写端是 coalesced连续 thread 写连续地址),读端是 gather非连续

RoPE (Rotary Position Embedding)

对 Q/K 的每对相邻元素 (x0, x1) 做 2D 旋转:

freq[i] = 1.0 / (theta ^ (2i / dim))
cos_val = cos(position * freq[i])
sin_val = sin(position * freq[i])
y0 = x0 * cos_val - x1 * sin_val
y1 = x0 * sin_val + x1 * cos_val

GPU 映射: 每个 thread 处理一对元素 (x[2i], x[2i+1])

  • Precompute cos_cache[max_seq_len][head_dim/2]sin_cache 在初始化时
  • 运行时 kernel 只做乘加

theta: Qwen3 默认 rope_theta = 1000000.0

Reduction Pattern核心学习点

所有 Norm 和 Softmax 都涉及 reduction。GPU reduction 的分层结构:

Thread-level:  每个 thread 处理多个元素,本地累加
    ↓
Warp-level:    __shfl_down_sync() 在 32 threads 内规约(无需 shared memory
    ↓
Block-level:   shared memory 存各 warp 的结果warp 0 再规约

对于 hidden_size <= 8192LLM 常见),一个 block 足够,不需要 grid-level reduction。

Warp Reduce 模板

__device__ float warp_reduce_sum(float val) {
    for (int offset = 16; offset > 0; offset >>= 1)
        val += __shfl_down_sync(0xffffffff, val, offset);
    return val;
}

Block Reduce 模板

__device__ float block_reduce_sum(float val) {
    __shared__ float shared[32];  // max 32 warps per block
    int lane = threadIdx.x % 32;
    int warp_id = threadIdx.x / 32;

    val = warp_reduce_sum(val);
    if (lane == 0) shared[warp_id] = val;
    __syncthreads();

    val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : 0.0f;
    if (warp_id == 0) val = warp_reduce_sum(val);
    return val;
}

Reference 验证策略

tools/generate_reference.py 脚本,用 PyTorch 为每个 op 生成 reference input/output:

  • 保存为 .npy 格式
  • Rust 测试中加载对比
  • 或者直接在 Rust 测试中用 CPU 实现计算 expected 值(更简单,不依赖 Python

选择: 先用 Rust CPU 实现作为 reference简单关键 opRoPE再与 PyTorch 对比。

Test Plan

  • RMSNorm F32: hidden_size=768, 4 rows → max_err 7.2e-7
  • RMSNorm BF16: 同上 → max_err 7.0e-3
  • LayerNorm F32: hidden_size=768 → max_err 1.7e-6
  • GELU F32: 10000 elements → max_err 3.0e-8
  • GELU BF16: 同上 → max_err 2.4e-3
  • SiLU F32: 10000 elements → max_err 1.5e-8
  • Softmax F32: 8×256 → max_err 1.4e-9
  • Softmax sum=1 验证: 4×2048
  • Softmax 大值 (1000+) 数值稳定性 → max_err 1.5e-8
  • Embedding F32: vocab=100, hidden=64, 5 tokens → exact match
  • RoPE F32: 4 tokens × 2 heads × dim=8 → max_err 6.0e-8
  • RoPE position=0 恒等验证 → max_err 0

Takeaways

  1. common.cuh 抽取共用 reduction 是正确的做法warp_reduce_sum/maxblock_reduce_sum/max 被 RMSNorm, LayerNorm, Softmax 三个 kernel 复用。抽到头文件避免了代码重复,也确保 reduction 逻辑一致。build.rs 中需要 .include("../../csrc") 让 nvcc 能找到头文件。

  2. Shared memory 中广播标量的模式Norm 和 Softmax 都需要将 reduce 结果mean, rms_inv, max, sum广播给 block 内所有 thread。标准做法thread 0 写 __shared__ 变量,__syncthreads() 后所有 thread 读。这比让每个 thread 独立做 reduce 高效得多。

  3. Softmax 三遍 vs 两遍我们实现了三遍版本max → exp+sum → normalize简单可靠。Online softmax 可以合并前两遍(一遍 pass 内同时跟踪 running max 和 running sum但需要更复杂的数值更新公式。Flash AttentionPhase 14会用到 online softmax。

  4. RoPE 的 position=0 恒等性cos(0)=1, sin(0)=0,所以 position 0 的旋转是恒等变换。这是一个很好的 sanity check。如果 position=0 时输出不等于输入,说明 kernel 有 bug。

  5. BF16 Softmax 的精度陷阱exp 结果先写成 BF16 再读回做 normalize 会丢精度。理想做法是用 float scratch buffer 暂存 exp 结果。当前实现可接受(误差在 1e-2 量级),但在 attention score 很接近时可能引入可观察的差异。Phase 14 Flash Attention 会解决这个问题(全程 FP32 累加)。

  6. Embedding 就是 gather 操作:没有任何计算,纯粹的内存搬运。瓶颈在 global memory 随机读取token_ids 导致不连续读 table。写端是 coalesced 的(连续 token 写连续地址)。优化方向:使用向量化加载(float4)一次读 128 bit。

  7. RoPE in-place 修改 Tensor 的设计考量RoPE 在数学上是对 Q/K 的 in-place 旋转。我们通过 data_ptr() as *mut 绕过了 Rust 的不可变借用。这在 GPU 上是安全的kernel 内部互不干扰),但 Rust 侧没有 &mut 语义保护。后续如果需要更严格的安全性,可以引入 Tensor::as_mut_ptr() 方法并要求 &mut self