Support compute_dtype for FP4/FP8 tensor core FLOPS selection

Add `compute_dtype` field to ModelConfig ("bf16", "fp8", "fp4") which
controls two things:
- GPU FLOPS tier: auto-selects from preset FP4/FP8/BF16 TFLOPS
- Weight bytes: uses 0.5/1.0/2.0 bytes per param for memory-bound check

Hardware presets now include per-GPU FP8 and FP4 dense FLOPS for all
GPUs that support them (H100/H800/H20: FP8, B200/B300: FP8+FP4).
Config resolution auto-selects the right FLOPS when compute_dtype is
set and the user hasn't explicitly overridden gpu_flops.

GLM-5-NVFP4 on 8xB300 now correctly uses 13.5 PFLOPS/GPU FP4 (6x
faster prefill) and 0.5 bytes/param weights (halved memory footprint).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-14 11:54:10 +08:00
parent 84696604e8
commit 663ca9c5b9
5 changed files with 106 additions and 34 deletions

View File

@@ -75,7 +75,8 @@ impl ComputeModel {
let n_kv = model.num_kv_heads as f64;
let hd = model.head_dim as f64;
let inter = model.intermediate_size.unwrap_or(0) as f64;
let dtype = model.dtype_bytes as f64;
// Weight dtype for memory-bound check (separate from KV cache dtype).
let wdtype = model.weight_dtype_bytes();
// --- Attention linear FLOPs/token/layer ---
let attn_linear = if let Some(mla) = &model.mla {
@@ -134,18 +135,18 @@ impl ComputeModel {
(h * qlr + qlr * n_heads * qk_hd
+ h * (kvlr + qk_rd)
+ n_heads * vhd * h)
* dtype
* wdtype
} else {
((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * dtype
((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * wdtype
};
let mlp_wt = if let Some(moe) = &model.moe {
let expert_inter = moe.expert_intermediate_size
.unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
let active = moe.num_active_experts as f64;
let shared = moe.num_shared_experts as f64;
(active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * dtype
(active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * wdtype
} else {
3.0 * h * inter * dtype
3.0 * h * inter * wdtype
};
let weight_bytes = attn_wt + mlp_wt;
@@ -385,6 +386,8 @@ mod tests {
};
let hw = HardwareConfig {
gpu_flops: 1e14,
gpu_fp8_flops: 0.0,
gpu_fp4_flops: 0.0,
gpu_mem_bw: 1e12,
hbm_bytes: 1e9,
dram_bytes: 4e9,