xserv/docs/10-qwen3.md
Gahow Wang ee68d3565d fix: comprehensive review + 14 bug fixes + Phase 12/14 overhaul
Strict code review identified 30+ issues across correctness, performance,
and architecture. This commit addresses 14 of them with verified fixes,
restructures Phase 12 for honest continuous batching, and updates Phase 14
to target FA2 (RTX 5090 SM120 lacks TMEM required by FA4).

Bug fixes:
- FIX-01: Global cuBLAS handle (thread-local singleton, was per-call)
- FIX-02: Remove 19 unnecessary cudaDeviceSynchronize calls from kernels
- FIX-03: Qwen3 ChatML template (was plain text concatenation)
- FIX-04: EOS token from tokenizer (was hardcoded 151645)
- FIX-05: Storage tracks actual GPU device ordinal (was always Cuda(0))
- FIX-06: unsqueeze stride preserves contiguous layout
- FIX-08: CudaDeviceProp replaced with heap buffer (was UB-prone padding)
- FIX-09: Tokenizer byte_fallback to <0xNN> tokens (was panic)

Feature additions:
- FIX-10: SSE streaming (/v1/chat/completions, OpenAI-compatible)
- FIX-11: Correct usage statistics (prompt/completion/total tokens)
- FIX-13: Temperature / top-k / top-p sampling with SamplingParams

Performance improvements:
- FIX-07: Caching allocator wired up (thread-local pool, pooled flag)
- FIX-12: KV cache staging buffers (zero-alloc get_kv_len via borrow_raw)
- FIX-14: GPU strided copy kernel (eliminates contiguous() CPU round-trip)

Architecture:
- Phase 12 engine restructured: prefill/decode separation, honest TODO
  for batched GPU forward (requires Flash Attention)
- Phase 14 updated: FA2 for SM120 (FA4 requires TMEM, absent on 5090)
- Qwen3-7B → Qwen3-8B typo fixed across all docs (36 layers, hidden 4096)

Validated on dash5 (8x RTX 5090):
- 52/52 API prompts pass (EN/CN/code), SSE streaming verified
- Logits match HF transformers 9/10 top-1, 4.0/5 avg top-5 overlap
- 8 concurrent requests: 5.99x scheduling speedup (batch_size=4)
- Throughput: 10.3 tok/s (serial), 30% of HF baseline

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-22 17:53:28 +08:00


Phase 10: Qwen3-8B Support — Design Document (Milestone ②)

Goal

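Extend the model definitions to support the Qwen3-8B architecture and verify that the output is correct. Key differences from GPT-2: RMSNorm, RoPE, GQA, SwiGLU, and untied embeddings.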

Architecture Differences (GPT-2 → Qwen3)

| Feature | GPT-2 | Qwen3-8B |
| --- | --- | --- |
| Norm | LayerNorm (gamma, beta) | RMSNorm (gamma only) |
| Position | Learned absolute (wpe) | RoPE (no params) |
| Attention | MHA (12 Q = 12 KV heads) | GQA (32 Q, 8 KV heads) |
| QKV projection | Combined c_attn [H, 3H] | Separate q/k/v_proj [H, Hq/Hk/Hv] |
| FFN | 2 Linear (fc, proj) + GELU | 3 Linear (gate, up, down) + SwiGLU |
| Weight layout | [in, out] (Conv1D style) | [out, in] (standard Linear) |
| Tied embeddings | Yes | No (separate lm_head) |
| hidden_size | 768 | 4096 |
| num_layers | 12 | 36 |
| head_dim | 64 | 128 |
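
As a rough illustration of the "RMSNorm (gamma only)" row, the sketch below normalizes one vector by its root mean square and applies a learned scale, with no mean subtraction and no beta. The function name `rms_norm` and the caller-supplied `eps` are assumptions made for this example, not names taken from the project.

```rust
/// Minimal RMSNorm sketch over a single vector (hypothetical helper).
/// y[i] = x[i] / sqrt(mean(x^2) + eps) * gamma[i]
pub fn rms_norm(x: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
    assert_eq!(x.len(), gamma.len());
    // Mean of squares; unlike LayerNorm, the mean is not subtracted first.
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    // Only a scale (gamma) is applied; there is no bias (beta) term.
    x.iter().zip(gamma).map(|(&v, &g)| v * inv_rms * g).collect()
}
```

The same routine applied to each head's slice of length head_dim would give a per-head norm.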

Weight Names (HuggingFace)

model.embed_tokens.weight                           [151936, 4096]
model.layers.{i}.input_layernorm.weight             [4096]
model.layers.{i}.self_attn.q_proj.weight            [4096, 4096]    (32 heads × 128 dim)
model.layers.{i}.self_attn.k_proj.weight            [1024, 4096]    (8 KV heads × 128 dim)
model.layers.{i}.self_attn.v_proj.weight            [1024, 4096]
model.layers.{i}.self_attn.o_proj.weight            [4096, 4096]
model.layers.{i}.self_attn.q_norm.weight            [128]           (per-head RMSNorm, head_dim)
model.layers.{i}.self_attn.k_norm.weight            [128]
model.layers.{i}.post_attention_layernorm.weight    [4096]
model.layers.{i}.mlp.gate_proj.weight               [12288, 4096]
model.layers.{i}.mlp.up_proj.weight                 [12288, 4096]
model.layers.{i}.mlp.down_proj.weight               [4096, 12288]
model.norm.weight                                   [4096]
lm_head.weight                                      [151936, 4096]

Note: Qwen3 weights use the [out, in] layout, so the forward pass computes x @ W^T rather than x @ W.
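
To make the layout note concrete, here is a minimal CPU sketch of a linear layer over row-major [out, in] weights. The name `linear_out_in` and the plain `f32` slices are assumptions for illustration; the project's own BF16 path will differ.

```rust
/// Computes y = x @ W^T for row-major `x` [s, in_dim] and `w` [out_dim, in_dim].
/// Hypothetical helper for illustration only.
pub fn linear_out_in(x: &[f32], w: &[f32], s: usize, in_dim: usize, out_dim: usize) -> Vec<f32> {
    assert_eq!(x.len(), s * in_dim);
    assert_eq!(w.len(), out_dim * in_dim);
    let mut y = vec![0.0f32; s * out_dim];
    for row in 0..s {
        let x_row = &x[row * in_dim..(row + 1) * in_dim];
        for o in 0..out_dim {
            // Row `o` of W holds every input weight for output feature `o`,
            // so x @ W^T reduces to a dot product of two contiguous rows.
            let w_row = &w[o * in_dim..(o + 1) * in_dim];
            y[row * out_dim + o] = x_row.iter().zip(w_row).map(|(a, b)| a * b).sum();
        }
    }
    y
}
```

With the [out, in] layout no explicit transpose is needed on the CPU, because both operands of each dot product are already contiguous.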

GQA (Grouped Query Attention)

num_heads = 32, num_kv_heads = 8, head_dim = 128
Q: [B, 32, S, 128]
K: [B, 8, S, 128]   ← each KV head serves 32/8 = 4 Q heads
V: [B, 8, S, 128]

K/V must be repeated for attention:
K_expanded: [B, 32, S, 128]  ← repeat_interleave(K, 4, dim=1)

Implementation: do the repeat directly on the CPU side in split_qkv.
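
The repeat step can be sketched on the CPU as follows, assuming a batch of 1 and row-major [heads, S, D] data. `repeat_kv` and `kv_head_for` are hypothetical names for this example; the second function shows the index mapping that would make the copy unnecessary if the attention kernel used it directly.

```rust
/// Expands K or V from [num_kv_heads, s, d] to [num_kv_heads * n_rep, s, d],
/// matching repeat_interleave along the head dimension (hypothetical helper).
pub fn repeat_kv(kv: &[f32], num_kv_heads: usize, s: usize, d: usize, n_rep: usize) -> Vec<f32> {
    assert_eq!(kv.len(), num_kv_heads * s * d);
    let head_len = s * d;
    let mut out = Vec::with_capacity(kv.len() * n_rep);
    for h in 0..num_kv_heads {
        let head = &kv[h * head_len..(h + 1) * head_len];
        // Each KV head is copied n_rep times in a row, so consecutive
        // query heads share the same K/V data.
        for _ in 0..n_rep {
            out.extend_from_slice(head);
        }
    }
    out
}

/// Index of the KV head that serves query head `q_head`.
pub fn kv_head_for(q_head: usize, num_heads: usize, num_kv_heads: usize) -> usize {
    q_head / (num_heads / num_kv_heads)
}
```

Here `n_rep` is num_heads / num_kv_heads. The copy is simple but scales with sequence length, which is why indexing through `kv_head_for` inside the kernel is the cheaper long-term option.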

SwiGLU FFN

gate = gate_proj(x)     # [S, 4096] @ [12288, 4096]^T → [S, 12288]
up   = up_proj(x)       # [S, 4096] @ [12288, 4096]^T → [S, 12288]
out  = silu(gate) * up   # element-wise
out  = down_proj(out)    # [S, 12288] @ [4096, 12288]^T → [S, 4096]
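
The element-wise step in the middle of the FFN can be sketched as below. This is a plain `f32` CPU version under assumed names (`silu`, `swiglu`); the gate and up projections are taken as already computed.

```rust
/// SiLU activation: x * sigmoid(x), written as x / (1 + e^-x).
pub fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// Element-wise SwiGLU gate: silu(gate[i]) * up[i] (hypothetical helper).
pub fn swiglu(gate: &[f32], up: &[f32]) -> Vec<f32> {
    assert_eq!(gate.len(), up.len());
    gate.iter().zip(up).map(|(&g, &u)| silu(g) * u).collect()
}
```

The result has the same shape as `gate` and `up` and is what gets passed to down_proj.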

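Memory Budget (BF16, single 5090)

Weights: 8B params × 2 bytes = ~16 GB (BF16)
         8B params × 4 bytes = ~32 GB (FP32), which does not fit; BF16 is required
KV cache (S=256, B=1): ~0.1 GB
Total: ~16 GB (BF16), fits on a single card

Key point: Qwen3-8B must run in BF16 to fit on a single 5090 (32 GB). The current GPT-2 path uses FP32, so a BF16 forward pass has to be supported.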
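
The budget arithmetic can be reproduced with two small helpers. The function names are hypothetical, and the KV cache formula assumes the cache stores the unexpanded KV heads with one K and one V tensor per layer; if the cache stored repeated heads or FP32 values, the figure would be proportionally larger.

```rust
/// Bytes needed for the weights (hypothetical helper).
pub fn weight_bytes(num_params: u64, bytes_per_elem: u64) -> u64 {
    num_params * bytes_per_elem
}

/// Bytes needed for the KV cache, assuming one K and one V tensor of shape
/// [batch, kv_heads, seq_len, head_dim] per layer.
pub fn kv_cache_bytes(
    layers: u64,
    kv_heads: u64,
    head_dim: u64,
    seq_len: u64,
    batch: u64,
    bytes_per_elem: u64,
) -> u64 {
    // The leading factor of 2 accounts for K and V.
    2 * layers * kv_heads * head_dim * seq_len * batch * bytes_per_elem
}
```

With 36 layers, 8 KV heads, head_dim 128, S = 256, B = 1 and 2-byte elements, `kv_cache_bytes` returns 37,748,736 bytes (about 38 MB), which sits comfortably inside the ~0.1 GB estimate.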

Implementation Plan

  1. Download the Qwen3-8B model (BF16, ~15.5 GB)
  2. Implement the Qwen3 model structure (qwen3.rs)
  3. Support a BF16 forward pass (linear_transpose for [out, in] weights)
  4. Implement GQA (K/V repeat in split)
  5. Integrate RoPE + RMSNorm + SwiGLU
  6. Verify the output

Test Plan

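  • Load the Qwen3-8B BF16 weights (399 tensors, ~15.5 GB) onto a single 5090
  • English: "The meaning of life is" → "to be happy"
  • Chinese: "请用中文回答1+1等于几" ("Answer in Chinese: what is 1+1?") → "1加1" ("1 plus 1")
  • 61/61 unit tests pass with no regressions
  • No performance regression on the GPT-2 benchmark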

Takeaways

  1. Qwen3 is actually 8B, not 7B: Qwen/Qwen3-8B on ModelScope has 36 layers × hidden 4096 × 32 heads, for roughly 8B parameters. The BF16 weights are ~15.5 GB, so a single 5090 (32 GB) can run it.

  2. QK normalization is new in Qwen3: each layer has q_norm and k_norm (shape [head_dim]) that apply a per-head RMSNorm to Q and K. This matters for the numerical stability of the attention scores; without QK norm the scores blow up.

  3. attention_bias=false: Qwen3's Q/K/V/O projections have no bias, unlike GPT-2, which does. The model code has to handle the bias conditionally.

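  4. Tokenizer byte-to-unicode mapping bug: GPT-2 and Qwen3 share the same byte-to-unicode mapping (the 188 printable bytes, including printable ASCII, map to themselves; the remaining 68 bytes are shifted to U+0100 and above). The initial unicode_to_byte converted the shifted range incorrectly (computing u - 0x100 directly instead of using a lookup table), so the UTF-8 bytes of Chinese input could not be mapped back correctly. Fix: cache the reverse mapping table in a OnceLock.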

  5. Weight layout [out, in] vs [in, out]: GPT-2's Conv1D stores weights as [in, out] and computes x @ W; Qwen3's Linear stores them as [out, in] and computes x @ W^T. The linear_t function handles this via weight.transpose(0,1).contiguous().

  6. RoPE tensor layout mismatch: the RoPE kernel expects [S, H, D], but attention needs [1, H, S, D], so a transpose is required before and after RoPE. This introduces an extra CPU round-trip, because transpose + contiguous goes through the CPU.

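  7. GQA repeat_kv implementation: each KV head serves num_heads/num_kv_heads Q heads. The data copy (repeat) is done on the CPU, which is simple but has to run on every decode step. The attention kernel should later support GQA indexing directly so the copy can be avoided.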
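
The tokenizer fix in takeaway 4 can be sketched as follows. This is a minimal reconstruction of the GPT-2 style byte-to-unicode table with a OnceLock-cached reverse map, not the project's actual code; the helper names `is_identity_byte` and `build_byte_to_unicode` are assumptions. It shows why u - 0x100 fails: shifted bytes are numbered in order of appearance, not by byte value.

```rust
use std::collections::HashMap;
use std::sync::OnceLock;

/// True for bytes that the GPT-2 table maps to the code point of the same value.
fn is_identity_byte(b: u8) -> bool {
    matches!(b, 0x21..=0x7E | 0xA1..=0xAC | 0xAE..=0xFF)
}

/// Forward table: byte -> char. Non-identity bytes get U+0100, U+0101, ...
/// in increasing byte order.
fn build_byte_to_unicode() -> [char; 256] {
    let mut table = ['\0'; 256];
    let mut shifted = 0u32;
    for b in 0..=255u8 {
        table[b as usize] = if is_identity_byte(b) {
            b as char
        } else {
            let c = char::from_u32(0x100 + shifted).expect("valid code point");
            shifted += 1;
            c
        };
    }
    table
}

/// Reverse lookup, built once and cached in a OnceLock.
pub fn unicode_to_byte(c: char) -> Option<u8> {
    static REVERSE: OnceLock<HashMap<char, u8>> = OnceLock::new();
    REVERSE
        .get_or_init(|| {
            build_byte_to_unicode()
                .iter()
                .enumerate()
                .map(|(b, &ch)| (ch, b as u8))
                .collect()
        })
        .get(&c)
        .copied()
}
```

For example, byte 0x80, a common UTF-8 continuation byte in Chinese text, is the 35th shifted byte and therefore maps to U+0122; subtracting 0x100 would give 0x22 instead of 0x80, while the table lookup recovers the right byte.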