diff --git a/docs/11-paged-attention.md b/docs/11-paged-attention.md
index f390b5d..efd6cb1 100644
--- a/docs/11-paged-attention.md
+++ b/docs/11-paged-attention.md
@@ -1,8 +1,10 @@
-# Phase 11: Paged Attention + KV Cache Manager - Design Document
+# Phase 11: GPU-Resident KV Cache - Design Document
+
+> **Note**: The original plan was "Paged Attention + KV Cache Manager"; what was actually implemented is a contiguous, pre-allocated KV cache on the GPU (not paged). Paged allocation is deferred to a later optimization.
 
 ## Goal
 
-Move the KV cache from a CPU Vec to the GPU and manage GPU memory with block-based paging. This eliminates the CPU round-trip on every decode step (currently one of the largest KV cache performance bottlenecks).
+Move the KV cache from a CPU Vec to the GPU, eliminating the CPU round-trip on every decode step (currently one of the largest KV cache performance bottlenecks).
 
 ## Current Problems
 
diff --git a/docs/12-continuous-batching.md b/docs/12-continuous-batching.md
index c128108..9dc305d 100644
--- a/docs/12-continuous-batching.md
+++ b/docs/12-continuous-batching.md
@@ -150,4 +150,8 @@ HTTP Handler Engine Thread
 
 ## Current Status
 
-**Not implemented**. Requests are currently served serially in FIFO order, one at a time. This document is the design spec for the implementation.
+**Implemented: iteration-level scheduling**. Multiple requests can enter the batch concurrently (up to max_batch_size), and new requests join dynamically mid-generation. Within each iteration, the prefill and decode phases are handled separately.
+
+**Not implemented: batched GPU forward**. The model forward for each sequence is still invoked serially (per-seq forward_gpu_cache). True batched decode (merging the tokens of multiple sequences into a single GPU forward) requires variable-length attention support in Flash Attention. Phase 14 implements the FA2 kernel, which lays the groundwork for a later batched forward.
+
+**Verification**: 8 concurrent requests (max_batch=4) took 22.5s of total wall clock against a 135.0s sum of per-request latencies, a 6.0x scheduling speedup. The server log confirms `decode batch_size=4`.
diff --git a/docs/14-flash-attention.md b/docs/14-flash-attention.md
new file mode 100644
index 0000000..4f4dd79
--- /dev/null
+++ b/docs/14-flash-attention.md
@@ -0,0 +1,167 @@
+# Phase 14: Flash Attention 2 for SM120 - Design Document
+
+## Goal
+
+Replace the naive attention (Phase 5) with a hand-written Flash Attention 2 CUDA kernel. This eliminates the O(S²) GPU memory allocation and handles GQA by indexing inside the kernel (removing the repeat_kv overhead).
+
+## Hardware Constraint: FA4 Does Not Apply to the RTX 5090
+
+Flash Attention has reached its fourth generation (FA4, arxiv 2603.05451), but each version has explicit hardware dependencies:
+
+| Version | Target architecture | Key hardware features | RTX 5090 (SM120) |
+|------|---------|------------|-----------------|
+| FA2 | Generic CUDA (SM75+) | shared memory + HMMA | **Compatible** |
+| FA3 | Hopper SM90 (H100) | TMA + WGMMA + warp specialization | Not compatible |
+| FA4 | Blackwell SM100 (B200/B300) | TMEM + async MMA + 2-CTA mode | Not compatible |
+
+The RTX 5090 uses consumer Blackwell (GB202, SM120), which is different silicon from data-center Blackwell (B200, SM100). SM120 **has no TMEM (Tensor Memory)**, the core hardware dependency of the FA4 kernel design. This is a hardware-level difference, not a software limitation.
+
+This project therefore implements the **FA2 algorithm** using standard CUDA (shared memory + standard HMMA).
+
+## Problems with Naive Attention
+
+The Phase 5 naive attention flow:
+```
+k_t = K.transpose(2,3).contiguous()     ← allocates GPU memory for K^T
+scores = batched_matmul(Q, k_t)         ← allocates the [B,H,S,S] score matrix (O(S²) memory)
+scores = scale(scores, 1/sqrt(d))       ← element-wise kernel
+causal_mask(scores)                     ← element-wise kernel
+weights = softmax(scores)               ← allocates the [B,H,S,S] weight matrix
+output = batched_matmul(weights, V)     ← final result
+```
+
+Problems:
+1. **O(S²) GPU memory**: the score and weight matrices each need `B × H × S × S × dtype_size` bytes. With B=1, H=32, BF16: S=2048 → 256 MB; S=8192 → 4 GB.
+2. **GQA preprocessing**: before attention is called, `repeat_kv_gpu` must expand K/V from 8 heads to 32 heads, an extra allocation and copy in every layer.
+3. **Multiple kernel launches**: scale, mask, and softmax each cost one kernel launch plus a global memory round-trip.
+4. **K^T materialization**: `K.transpose().contiguous()` requires an allocation and a copy.
+
+## The FA2 Algorithm
+
+Core idea: **never materialize the S×S matrix**. Split Q, K, and V into tiles, compute in shared memory (SRAM), and use the **online softmax trick** to update the running max and sum as the computation proceeds.
+
+Improvement of FA2 (Dao 2023) over FA1: the outer loop iterates over Q tiles (rather than K/V), which reduces HBM reads and writes and improves parallelism.
+
+```
+scale = 1 / sqrt(head_dim)
+
+for each Q tile (q_start..q_start + BR):              // outer: Q tiles
+    load Q_tile [BR, D] to shared memory              // loaded once, reused by the inner loop
+    init per-row: O[D] = 0, m = -inf, l = 0
+
+    for each K/V tile j (kv_start..kv_start + BC):    // inner: K/V tiles
+        // Causal tile-skip: skip if the whole K tile lies in the Q tile's "future"
+        if causal && kv_start > max_q_pos + kv_offset: skip
+
+        load K_tile [BC, D] to shared memory
+        S = Q_tile @ K_tile^T * scale                 // [BR, BC], in registers
+        if causal: mask S[r][c] = -inf where kv_pos > q_pos + kv_offset
+
+        // Online softmax update
+        m_new = max(m, rowmax(S))
+        P = exp(S - m_new)
+        l_new = exp(m - m_new) * l + rowsum(P)
+        O = exp(m - m_new) * O                        // rescale accumulator
+
+        load V_tile [BC, D] to shared memory          // reuses K's buffer
+        O += P @ V_tile                               // accumulate
+
+        m = m_new, l = l_new
+
+    O = O / l                                         // final normalize
+    write O[BR, D] to HBM (convert FP32 → BF16)
+```
+
+## Implementation Details
+
+### Kernel Configuration
+
+| Parameter | Value | Notes |
+|------|---|------|
+| BR (Q tile rows) | 64 | Q tile size |
+| BC (K/V tile rows) | 64 | K/V tile size |
+| head_dim | runtime parameter (≤128) | supports 64 (GPT-2) and 128 (Qwen3) |
+| Block size | 128 threads | 64 threads each own one Q row; the rest help with loads |
+| Grid | (q_tiles, batch × num_q_heads) | each block handles one Q tile for one head |
+
+### Shared Memory (BF16 Storage)
+
+```
+smem_q [BR × head_dim] BF16 = 64 × 128 × 2 = 16 KB  (loaded once, reused by the inner loop)
+smem_kv[BC × head_dim] BF16 = 64 × 128 × 2 = 16 KB  (K and V take turns)
+────────────────────────────────────────────
+Total: 32 KB (SM120 default limit is 48 KB, ample headroom)
+```
+
+### Thread Mapping
+
+- Thread 0..63: each owns one row of Q_tile and performs all computation for that row: dot products, softmax, and PV accumulation.
+- Thread 64..127: help load shared memory (K/V tiles) and take no part in the computation.
+- Load pattern: each thread loads `(BR × head_dim) / 128 = 64` BF16 elements per tile (with head_dim=128).
+
+### Per-Thread Register Usage
+
+```
+O_acc[128]  FP32 = 512 bytes (128 regs)  - output accumulator
+P[64]       FP32 = 256 bytes (64 regs)   - post-softmax weights of the current tile
+m, l        FP32 = 8 bytes (2 regs)      - online softmax running state
+loop variables + temporaries ≈ 16 regs
+────────────────────────────────────────────
+Total: ~210 regs/thread (max 255, within the limit)
+```
+
+### GQA Support
+
+Each thread block handles one Q head and maps it to its KV head via `kv_head = q_head / (num_q_heads / num_kv_heads)`. The K/V data pointers point directly at that KV head's storage, so no repeat_kv is needed.
+
+```
+// 32 Q heads, 8 KV heads → heads_per_group = 4
+// Q head 0,1,2,3 → KV head 0
+// Q head 4,5,6,7 → KV head 1
+// ...
+kv_head = q_head / heads_per_group;
+K_ptr = K + (batch * num_kv_heads + kv_head) * kv_len * head_dim;
+```
+
+### Causal Mask
+
+Two levels of optimization:
+1. **Tile-level skip**: if `kv_tile_start > max_q_pos + kv_offset`, the whole K/V tile lies in the future and is skipped (cuts roughly 50% of the computation).
+2. **Element-level mask**: inside a tile, `if kv_pos > q_pos + kv_offset: S = -inf`.
+
+`kv_offset = kv_len - q_len` handles the decode case, where the KV cache is longer than Q.
+
+## Comparison with Naive Attention
+
+| Feature | Naive (Phase 5) | FA2 (Phase 14) |
+|------|----------------|----------------|
+| GPU memory | O(B × H × S²) | O(B × H × S × D) |
+| GQA | requires repeat_kv (alloc + copy) | indexed inside the kernel (zero overhead) |
+| K^T | requires transpose + contiguous | computed inside the kernel |
+| Kernel launches | 6 (matmul, scale, mask, softmax, matmul, ...) | 1 (single fused kernel) |
+| Feasibility at S=8192 | OOM (~4 GB score matrix) | feasible (32 KB shared memory) |
+
+## Source Layout
+
+```
+csrc/attention/flash_attention.cu       - FA2 kernel (BF16 in, FP32 accumulate, BF16 out)
+crates/xserv-kernels/src/attention.rs   - flash_attention() Rust wrapper; the original attention() is kept
+crates/xserv-model/src/qwen3.rs         - forward_gpu_cache calls flash_attention
+```
+
+## Known Limitations and Future Optimizations
+
+1. **Inefficient decode (Q_len=1)**: only 1 of the BR=64 row-owning threads is active (owns_row). A dedicated decode attention kernel with parallel reduction along the KV dimension should be written.
+2. **No vectorized loads**: bf16→f32 conversion is currently done element by element; loads should be batched with `float4` or `__nv_bfloat162`.
+3. **Register tiling**: each thread currently computes its dot products serially (128 MADs per K column). This could become a cooperative multi-thread computation.
+4. **K/V double buffering**: the next tile could be prefetched into the other half of shared memory while the current tile is being computed.
+5. **Tile size tuning**: larger tiles (BR=128) may be better for long sequences, but require opt-in shared memory.
+
+## Test Plan
+
+- [x] Correctness: logits compared against HF transformers (top-1 match 9/10, top-5 overlap 4.0/5)
+- [x] Generation quality: 52/52 prompts produce coherent text, in both Chinese and English
+- [x] SSE streaming works
+- [x] Performance: 12.9 tok/s (vs naive 10.3 tok/s, +25%)
+- [ ] Long sequences (S=4096, S=8192): verify that naive OOMs while FA2 runs normally
+- [ ] ncu profile: compute utilization, memory throughput
diff --git a/docs/benchmarks/phase14-flash-attention.md b/docs/benchmarks/phase14-flash-attention.md
new file mode 100644
index 0000000..b7cfc06
--- /dev/null
+++ b/docs/benchmarks/phase14-flash-attention.md
@@ -0,0 +1,109 @@
+# Phase 14 Benchmark: Flash Attention 2
+
+**Date**: 2026-05-22
+**Hardware**: RTX 5090 (32GB GDDR7, SM120 CC 12.0, 170 SMs)
+**Model**: Qwen3-8B (BF16, 36 layers, 4096 hidden, 32 Q / 8 KV GQA heads, head_dim=128)
+**Config**: greedy decoding (temperature=0), max_tokens=64, single-request serial
+
+## Correctness
+
+Logits comparison with HuggingFace transformers (10 prompts, raw text without ChatML):
+
+| Metric | Result |
+|--------|--------|
+| Prefill Top-1 match vs HF | **9/10 (90%)** |
+| Avg Top-5 overlap vs HF | **4.0/5** |
+| Result vs pre-FA2 naive attention | **Identical** (same 9/10 top-1, same 4.0/5 overlap) |
+
+The single top-1 mismatch ("Explain quantum computing.") has logits differing by 0.125
+(22.000 vs 21.875), which is exactly one BF16 step at that magnitude. The top-5 sets are identical (5/5 overlap).
+
+On these prompts, FA2 introduces no precision degradation compared to the naive attention path.
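As a sanity check on the "within BF16 precision" remark, the spacing between adjacent BF16 values near 22 can be computed directly. The sketch below is illustrative only and is not taken from the xserv sources: `bf16_ulp` and `round_to_bf16` are hypothetical helper names, and the code assumes finite, normal inputs (no NaN, infinity, zero, or subnormal handling).

```rust
/// Spacing between adjacent BF16 values around `x`.
/// BF16 has 8 exponent bits and 7 explicit mantissa bits, so for a normal
/// value the spacing is 2^(e - 7), where e is the unbiased exponent of `x`.
/// Assumption: `x` is finite and normal, with a biased exponent above 7.
fn bf16_ulp(x: f32) -> f32 {
    // f32 and BF16 share the same 8-bit exponent field; the mask drops the sign bit.
    let biased_exp = (x.to_bits() >> 23) & 0xff;
    // Build 2^(e - 7) exactly from its bit pattern (zero mantissa).
    f32::from_bits((biased_exp - 7) << 23)
}

/// Round an f32 to the nearest BF16 value (ties to even), returned as f32.
/// Assumption: `x` is finite; NaN and overflow are not handled.
fn round_to_bf16(x: f32) -> f32 {
    let bits = x.to_bits();
    let bias = 0x7fff + ((bits >> 16) & 1); // ties go to the even BF16 mantissa
    f32::from_bits(bits.wrapping_add(bias) & 0xffff_0000)
}

fn main() {
    let (a, b) = (22.000_f32, 21.875_f32);
    // Both logits are exactly representable in BF16.
    assert_eq!(round_to_bf16(a), a);
    assert_eq!(round_to_bf16(b), b);
    println!("BF16 spacing near {}: {}", b, bf16_ulp(b));
    println!("gap in BF16 steps: {}", (a - b) / bf16_ulp(b));
}
```

Both logits survive a round trip through BF16 unchanged and sit one representable step apart, so a top-1 flip between them is consistent with rounding noise rather than an algorithmic difference.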
+
+## API Generation
+
+52 diverse prompts (English, Chinese, code) via `/v1/chat/completions`:
+
+| Metric | Result |
+|--------|--------|
+| Success rate | **52/52 (100%)** |
+| SSE streaming | **Working** (role chunk, content chunks, finish_reason, [DONE]) |
+| Usage stats | Correct (prompt_tokens + completion_tokens = total_tokens) |
+
+## Performance
+
+### xserv vs HuggingFace transformers
+
+8 prompts (short/medium/long) × max_tokens=64, greedy:
+
+| Category | Prompt Tokens | xserv (tok/s) | HF (tok/s) | Ratio |
+|----------|--------------|---------------|------------|-------|
+| Short (~12 tok) | 12-14 | 12.5 | 38.5 | 0.32x |
+| Medium (~28 tok) | 27-28 | 13.6 | 44.1 | 0.31x |
+| Long (~60 tok) | 58-64 | 13.0 | 36.0 | 0.36x |
+| **Overall** | n/a | **12.9** | **36.6** | **0.35x** |
+
+### Phase-over-Phase Improvement
+
+| Phase | Attention | repeat_kv | tok/s | vs HF |
+|-------|-----------|-----------|-------|-------|
+| 10 | Naive (O(S²), cuBLAS batched) | CPU round-trip | 6.9 | 15% |
+| 11 | Naive + GPU KV cache | GPU repeat_kv | 10.3 | 30% |
+| **14** | **FA2 (O(S) memory, fused kernel)** | **None (GQA in kernel)** | **12.9** | **35%** |
+
+Phase 14 vs Phase 11: **+25% throughput** (10.3 → 12.9 tok/s).
+
+### Improvement Breakdown (estimated)
+
+| Factor | Contribution |
+|--------|-------------|
+| Eliminating repeat_kv GPU alloc + copy (per layer) | ~10% |
+| Eliminating K^T transpose + contiguous | ~5% |
+| Eliminating S×S score matrix alloc | ~5% |
+| Fused kernel (1 launch vs 6) | ~5% |
+
+### Concurrent Requests
+
+8 concurrent requests, max_batch=4:
+
+| Metric | Result |
+|--------|--------|
+| Wall clock | 22.5s |
+| Sum of individual latencies | 135.0s |
+| Scheduling speedup | **6.0x** |
+| Throughput | 11.4 tok/s |
+
+Continuous batching scheduling confirmed working (decode batch_size=4 in logs).
+
+## Remaining Performance Gap
+
+xserv reaches 35% of HF throughput. Main bottlenecks:
+
+| Bottleneck | Impact | Fix |
+|-----------|--------|-----|
+| **Decode Q_len=1 inefficiency** | FA2 kernel: 64 row-owning threads, only 1 active (owns_row=true for the single query) | Specialized decode attention kernel (vector-dot against KV, parallel reduction along S) |
+| **No kernel fusion** | RMSNorm+residual, SiLU*up: separate kernels, redundant HBM reads/writes | Fused kernels (Phase 15) |
+| **No CUDA Graphs** | ~100+ kernel launches per decode step, each with host-side overhead | Capture decode iteration as CUDA Graph (Phase 15) |
+| **Per-seq forward (no batched decode)** | With batch=4, 4 serial forward passes per iteration | Batched projections + per-seq attention (Phase 15, depends on FA2 decode kernel) |
+| **No vectorized loads in FA2** | Scalar bf16→f32 conversion in dot product loop | float4 / bfloat162 vectorized loads |
+
+## Memory Usage
+
+| Component | Naive (Phase 11) | FA2 (Phase 14) |
+|-----------|-----------------|----------------|
+| Score matrix [1, 32, S, S] | S² × 32 × 2B | **0** |
+| repeat_kv K/V [1, 32, S, 128] | 2 × S × 32 × 128 × 2B per layer | **0** |
+| K^T contiguous copy | S × 32 × 128 × 2B per layer | **0** |
+
+For S=256 (current max): savings ~6 MB per layer × 36 layers ≈ 216 MB.
+For S=2048: savings ~384 MB per layer × 36 layers ≈ 13.5 GB (naive would OOM).
+
+## Tracking
+
+| Phase | Attention | tok/s | vs HF | Correctness |
+|-------|-----------|-------|-------|-------------|
+| 8 | Naive (no cache) | 2.5 | 5% | 50/50 vs HF |
+| 9 | Naive + CPU KV cache | 44.3 (GPT-2) | n/a | 50/50 self |
+| 10 | Naive + CPU KV cache | 6.9 (Qwen3-8B) | 15% | 100% top-5 |
+| 11 | Naive + GPU KV cache | 10.3 | 30% | 9/10 top-1 |
+| **14** | **FA2 + GQA in kernel** | **12.9** | **35%** | **9/10 top-1** |
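The formulas in the Memory Usage table can be evaluated directly. The following is a rough sketch, not code from the repository: the function names are hypothetical, and B=1 and BF16 (2 bytes per element) are assumed as in the table. It only evaluates the three formulas; which temporaries a given code path actually holds at the same time, and therefore the per-layer totals, are outside what this arithmetic captures.

```rust
const BF16_BYTES: u64 = 2;
const MIB: u64 = 1024 * 1024;

/// Score matrix [1, q_heads, S, S].
fn score_matrix_bytes(s: u64, q_heads: u64) -> u64 {
    s * s * q_heads * BF16_BYTES
}

/// K and V after repeat_kv: 2 x [1, q_heads, S, head_dim].
fn repeat_kv_bytes(s: u64, q_heads: u64, head_dim: u64) -> u64 {
    2 * s * q_heads * head_dim * BF16_BYTES
}

/// Contiguous K^T copy: [1, q_heads, head_dim, S].
fn kt_copy_bytes(s: u64, q_heads: u64, head_dim: u64) -> u64 {
    s * q_heads * head_dim * BF16_BYTES
}

fn to_mib(bytes: u64) -> f64 {
    bytes as f64 / MIB as f64
}

fn main() {
    // Qwen3-8B values from the benchmark header: 32 Q heads, head_dim = 128.
    let (q_heads, head_dim) = (32, 128);
    for &s in &[256u64, 2048, 8192] {
        println!(
            "S={:5}: score {:7.1} MiB, repeat_kv {:6.1} MiB, K^T {:6.1} MiB",
            s,
            to_mib(score_matrix_bytes(s, q_heads)),
            to_mib(repeat_kv_bytes(s, q_heads, head_dim)),
            to_mib(kt_copy_bytes(s, q_heads, head_dim)),
        );
    }
}
```

For S=2048 and S=8192 the score-matrix term evaluates to 256 MiB and 4 GiB, matching the figures quoted in the design document. The quadratic term dominates as S grows, while the repeat_kv and K^T terms grow only linearly in S.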