2.2 KiB
prefill
对于 attention:
对一个序列长度为 S 的 transformer attention(多头、按 d = H * head_dim):
-
Q/K/V 投影(一次性做 3 个线性变换):
FLOPsQKV_proj≈2×S×d×(3d)=6Sd2FLOPsQKV_proj≈2×S×d×(3d)=6Sd2
(矩阵乘法的常用近似:2·m·n·k)
-
Attention 矩阵乘(Q·K^T):
FLOPsQK≈2×H×S×dhead×S=2S2dFLOPsQK≈2×H×S×dhead×S=2S2d
因为 Hdhead=dHdhead=d。
-
Attention·V(权重与 V 相乘):
FLOPsAV≈2S2dFLOPsAV≈2S2d
-
输出投影(从 heads 拼回 d 再线性变换):
FLOPsout≈2Sd2FLOPsout≈2Sd2
总 FLOPs:8Sd^2 + 4S^2d
T_{\text{comp}} = \frac{\text{FLOPs}_{\text{per\_GPU}}}{\text{peak\_flops\_per\_GPU} \times \text{compute\_utils}}
总 memory:$\text{bytes}_\text{prefill} \approx N \cdot \alpha \cdot BLd \cdot \text{elem_bytes}$,\alpha \sim 6
T_{\text{mem}} = \frac{\text{bytes}_{\text{per\_GPU}}}{\text{bandwidth\_per\_GPU} \times \text{mem\_utils}}
decode
总 FLOPs:8d^2 + 4dL
总 memory:$\text{bytes}_\text{decode} \approx N \cdot \beta \cdot BLd \cdot \text{elem_bytes}$,\beta \sim 4
\text{output} = \text{SiLU}(xW_1)W_2
TP 下,每 token T 激活的 expert E 的通信为:
- 输入
xAllGather 到所有 TP 节点,通信量:hidden_size * (TP - 1) - 每个 TP 节点独立计算
xW_1' - AllGather 后每个节点得到完整的 $xW_1$,通信量:moe_intermediate_size / TP * (TP - 1)
- 每个节点计算 SiLU 和 $IW2'$,AllReduce 每个节点得到完整的 output,通信量:hidden_size * (TP - 1)
EP 下:dispatch+combine
2 * hidden_size * (EP - 1) / EP (假设负载均衡)
With batch_size=2000, seq_len=2048, EP=8
Qwen-235B: attention comp time 0.06944874306412531 moe combine comm time 0.00028672 moe comp time 0.0004069259060131379 moe comm time with TP 0.00045056
Qwen-30B: attention comp time 0.01736218682573421 moe combine comm time 0.00014336 moe comp time 0.00010173147650328448 moe comm time with TP 0.00022528
EP=64: Qwen-235B: attention comp time 0.06944874306412531 moe combine comm time 0.00032256 moe comp time 0.0004069259060131379 moe comm time with TP 0.00045056