Support compute_dtype for FP4/FP8 tensor core FLOPS selection
Add `compute_dtype` field to ModelConfig ("bf16", "fp8", "fp4") which
controls two things:
- GPU FLOPS tier: auto-selects from preset FP4/FP8/BF16 TFLOPS
- Weight bytes: uses 0.5/1.0/2.0 bytes per param for memory-bound check
Hardware presets now include per-GPU FP8 and FP4 dense FLOPS for all
GPUs that support them (H100/H800/H20: FP8, B200/B300: FP8+FP4).
Config resolution auto-selects the right FLOPS when compute_dtype is
set and the user hasn't explicitly overridden gpu_flops.
GLM-5-NVFP4 on 8xB300 now correctly uses 13.5 PFLOPS/GPU FP4 (6x
faster prefill) and 0.5 bytes/param weights (halved memory footprint).
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -75,7 +75,8 @@ impl ComputeModel {
        let n_kv = model.num_kv_heads as f64;
        let hd = model.head_dim as f64;
        let inter = model.intermediate_size.unwrap_or(0) as f64;
        let dtype = model.dtype_bytes as f64;
        // Weight dtype for memory-bound check (separate from KV cache dtype).
        let wdtype = model.weight_dtype_bytes();

        // --- Attention linear FLOPs/token/layer ---
        let attn_linear = if let Some(mla) = &model.mla {
@@ -134,18 +135,18 @@ impl ComputeModel {
            (h * qlr + qlr * n_heads * qk_hd
                + h * (kvlr + qk_rd)
                + n_heads * vhd * h)
-               * dtype
+               * wdtype
        } else {
-           ((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * dtype
+           ((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * wdtype
        };
        let mlp_wt = if let Some(moe) = &model.moe {
            let expert_inter = moe.expert_intermediate_size
                .unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
            let active = moe.num_active_experts as f64;
            let shared = moe.num_shared_experts as f64;
-           (active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * dtype
+           (active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * wdtype
        } else {
-           3.0 * h * inter * dtype
+           3.0 * h * inter * wdtype
        };
        let weight_bytes = attn_wt + mlp_wt;
@@ -385,6 +386,8 @@ mod tests {
        };
        let hw = HardwareConfig {
            gpu_flops: 1e14,
+           gpu_fp8_flops: 0.0,
+           gpu_fp4_flops: 0.0,
            gpu_mem_bw: 1e12,
            hbm_bytes: 1e9,
            dram_bytes: 4e9,
Reference in New Issue
Block a user