Files
xserv/docs/14-flash-attention.md
Gahow Wang 6cc1c9332d docs: Phase 14 design doc + benchmark, fix Phase 11/12 honesty
Phase 14 (Flash Attention):
- Design doc: FA2 algorithm, SM120 hardware constraints (FA4 incompatible),
  kernel config (BR=BC=64, 32KB smem), GQA mapping, causal tile-skip,
  known limitations and optimization roadmap
- Benchmark doc: correctness (9/10 top-1 match, identical to pre-FA baseline),
  performance tracking (6.9→10.3→12.9 tok/s across phases), memory savings
  analysis, remaining bottleneck breakdown

Phase 11 doc: title corrected from "Paged Attention" to "GPU-Resident KV Cache"
with explicit note that paged allocation was not implemented.

Phase 12 doc: "当前状态" updated from "未实现" to reflect actual state —
iteration-level scheduling implemented + verified (6.0x concurrent speedup),
batched GPU forward explicitly marked as not yet implemented.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-22 18:51:29 +08:00

7.5 KiB
Raw Permalink Blame History

Phase 14: Flash Attention 2 for SM120 — Design Document

Goal

用自写的 Flash Attention 2 CUDA kernel 替换 naive attention (Phase 5)。消除 O(S²) 显存分配,支持 GQA kernel 内部索引(消除 repeat_kv 开销)。

硬件约束: FA4 不适用于 RTX 5090

Flash Attention 已发展到第 4 代 (FA4, arxiv 2603.05451),但各版本有明确硬件依赖:

版本 目标架构 关键硬件特性 RTX 5090 (SM120)
FA2 通用 CUDA (SM75+) shared memory + HMMA 兼容
FA3 Hopper SM90 (H100) TMA + WGMMA + warp specialization 不兼容
FA4 Blackwell SM100 (B200/B300) TMEM + async MMA + 2-CTA mode 不兼容

RTX 5090 使用消费级 Blackwell (GB202, SM120),与数据中心 Blackwell (B200, SM100) 是不同硅片。SM120 没有 TMEM (Tensor Memory),这是 FA4 kernel 设计的核心硬件依赖。这不是软件限制,是硬件级差异。

因此本项目实现 FA2 算法,使用标准 CUDA (shared memory + 标准 HMMA)。

Naive Attention 的问题

Phase 5 的 naive attention 流程:

k_t = K.transpose(2,3).contiguous()    ← 分配 K^T 显存
scores = batched_matmul(Q, k_t)         ← 分配 [B,H,S,S] score 矩阵 (O(S²) 显存)
scores = scale(scores, 1/sqrt(d))       ← 逐元素 kernel
causal_mask(scores)                     ← 逐元素 kernel
weights = softmax(scores)               ← 分配 [B,H,S,S] weight 矩阵
output = batched_matmul(weights, V)     ← 最终结果

问题:

  1. 显存 O(S²): score 和 weight 矩阵各需 B × H × S × S × dtype_size。S=2048, H=32, BF16 → 256 MB。S=8192 → 4 GB。
  2. GQA 预处理: 在调用 attention 前需要 repeat_kv_gpu 将 K/V 从 8 heads 扩展到 32 heads每层额外分配和拷贝。
  3. 多次 kernel launch: scale, mask, softmax 各一次 kernel launch + global memory round-trip。
  4. K^T materialization: K.transpose().contiguous() 需要分配和拷贝。

FA2 算法

核心思想: 不 materialize S×S 矩阵。将 Q, K, V 分成 tiles在 shared memory (SRAM) 中计算,使用 online softmax trick 边算边更新 running max 和 sum。

FA2 (Dao 2023) 相比 FA1 的改进: 外层循环遍历 Q tiles (而非 K/V),减少 HBM 读写次数,提高并行性。

scale = 1 / sqrt(head_dim)

for each Q tile (q_start..q_start + BR):                   // 外层: Q tiles
    load Q_tile [BR, D] to shared memory (一次加载,内层复用)
    init per-row: O[D] = 0, m = -inf, l = 0

    for each K/V tile j (kv_start..kv_start + BC):          // 内层: K/V tiles
        // Causal tile-skip: 如果整个 K tile 在 Q tile "未来",跳过
        if causal && kv_start > max_q_pos + kv_offset: skip

        load K_tile [BC, D] to shared memory
        S = Q_tile @ K_tile^T * scale                       // [BR, BC], in registers
        if causal: mask S[r][c] = -inf where kv_pos > q_pos

        // Online softmax update
        m_new = max(m, rowmax(S))
        P = exp(S - m_new)
        l_new = exp(m - m_new) * l + rowsum(P)
        O = exp(m - m_new) * O                              // rescale accumulator

        load V_tile [BC, D] to shared memory (复用 K 的空间)
        O += P @ V_tile                                     // accumulate

        m = m_new, l = l_new

    O = O / l                                               // final normalize
    write O[BR, D] to HBM (convert FP32 → BF16)

实现细节

Kernel 配置

参数 说明
BR (Q tile rows) 64 Q tile 大小
BC (K/V tile rows) 64 K/V tile 大小
head_dim 运行时参数 (≤128) 支持 64 (GPT-2) 和 128 (Qwen3)
Block size 128 threads 64 线程各 own 一行 Q其余协助加载
Grid (q_tiles, batch × num_q_heads) 每个 block 处理一个 Q tile + 一个 head

Shared Memory (BF16 存储)

smem_q [BR × head_dim] BF16 = 64 × 128 × 2 = 16 KB (加载一次,内层复用)
smem_kv[BC × head_dim] BF16 = 64 × 128 × 2 = 16 KB (K 和 V 交替使用)
────────────────────────────────────────────
Total: 32 KB (SM120 默认 48 KB余量充足)

线程映射

  • Thread 0..63: 各 own Q_tile 的一行。负责该行的全部计算dot products、softmax、PV 累加。
  • Thread 64..127: 协助 shared memory 加载 (K/V tile),不参与计算。
  • 加载模式: 每个 thread 加载 (BR × head_dim) / 128 = 64 个 BF16 元素。

Per-Thread Register 使用

O_acc[128] FP32     = 512 bytes (128 regs)  — 输出累加器
P[64]      FP32     = 256 bytes (64 regs)   — 当前 tile 的 softmax 后权重
m, l       FP32     = 8 bytes (2 regs)      — online softmax running state
循环变量 + 临时      ≈ 16 regs
────────────────────────────────────────────
Total: ~210 regs/thread (max 255在限制内)

GQA 支持

每个 thread block 处理一个 Q head通过 kv_head = q_head / (num_q_heads / num_kv_heads) 映射到对应的 KV head。K/V 的数据指针直接指向 KV head 的存储,无需 repeat_kv。

// 32 Q heads, 8 KV heads → heads_per_group = 4
// Q head 0,1,2,3 → KV head 0
// Q head 4,5,6,7 → KV head 1
// ...
kv_head = q_head / heads_per_group;
K_ptr = K + (batch * num_kv_heads + kv_head) * kv_len * head_dim;

Causal Mask

两级优化:

  1. Tile-level skip: 如果 kv_tile_start > max_q_pos + kv_offset,整个 K/V tile 都在未来,跳过(减少 ~50% 计算)。
  2. Element-level mask: 在 tile 内部,if kv_pos > q_pos + kv_offset: S = -inf

kv_offset = kv_len - q_len 处理 decode 时 KV cache 长于 Q 的情况。

与 Naive Attention 的对比

特性 Naive (Phase 5) FA2 (Phase 14)
显存 O(B × H × S²) O(B × H × S × D)
GQA 需要 repeat_kv (分配+拷贝) Kernel 内部索引 (零开销)
K^T 需要 transpose+contiguous Kernel 内部计算
Kernel launches 6 (matmul, scale, mask, softmax, matmul, ...) 1 (单个 fused kernel)
S=8192 可行性 OOM (~4 GB score matrix) 可行 (32 KB shared memory)

源码结构

csrc/attention/flash_attention.cu     — FA2 kernel (BF16 in, FP32 accumulate, BF16 out)
crates/xserv-kernels/src/attention.rs — flash_attention() Rust wrapper + 原 attention() 保留
crates/xserv-model/src/qwen3.rs       — forward_gpu_cache 调用 flash_attention

已知局限与后续优化方向

  1. Decode (Q_len=1) 效率低: BR=64 线程中只有 1 个 activeowns_row。应写专用 decode attention kernel沿 KV 维度 parallel reduction。
  2. 无向量化加载: 当前逐元素 bf16→f32 转换,应改用 float4__nv_bfloat162 批量加载。
  3. Register tiling: 每个 thread 目前串行计算 dot product (128 MADs per K column)。可改为多线程协作。
  4. K/V double buffering: 可在计算当前 tile 时预加载下一个 tile 到另一半 shared memory。
  5. Tile size 调优: 更大的 tile (BR=128) 可能在长 sequence 时更优,需要 opt-in shared memory。

Test Plan

  • 正确性: logits 与 HF transformers 对比 (top-1 match 9/10, top-5 overlap 4.0/5)
  • 生成质量: 52/52 prompt 生成连贯文本,中英文均可
  • SSE streaming 正常工作
  • 性能: 12.9 tok/s (vs naive 10.3 tok/s, +25%)
  • 长 sequence (S=4096, S=8192): 验证 naive OOM 而 FA2 正常
  • ncu profile: compute utilization, memory throughput