Support compute_dtype for FP4/FP8 tensor core FLOPS selection

Add `compute_dtype` field to ModelConfig ("bf16", "fp8", "fp4") which
controls two things:
- GPU FLOPS tier: auto-selects from preset FP4/FP8/BF16 TFLOPS
- Weight bytes: uses 0.5/1.0/2.0 bytes per param for memory-bound check

Hardware presets now include per-GPU FP8 and FP4 dense FLOPS for all
GPUs that support them (H100/H800/H20: FP8, B200/B300: FP8+FP4).
Config resolution auto-selects the right FLOPS when compute_dtype is
set and the user hasn't explicitly overridden gpu_flops.

GLM-5-NVFP4 on 8xB300 now correctly uses 13.5 PFLOPS/GPU FP4 (6x
faster prefill) and 0.5 bytes/param weights (halved memory footprint).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-14 11:54:10 +08:00
parent 84696604e8
commit 663ca9c5b9
5 changed files with 106 additions and 34 deletions

View File

@@ -78,58 +78,76 @@ fn parse_count_gpu(s: &str) -> (u32, String) {
// -- Per-GPU base specs (single die, dense FLOPS per precision) ---------------
struct GpuBase {
flops: f64, // BF16 dense TFLOPS
mem_bw: f64, // HBM bandwidth (B/s)
hbm: f64, // Total HBM (bytes)
pcie_gen: u32, // PCIe generation (4/5/6)
flops: f64, // BF16 dense FLOPS
fp8_flops: f64, // FP8 dense FLOPS (0 = not supported)
fp4_flops: f64, // FP4 dense FLOPS (0 = not supported)
mem_bw: f64, // HBM bandwidth (B/s)
hbm: f64, // Total HBM (bytes)
pcie_gen: u32, // PCIe generation (4/5/6)
}
const H100: GpuBase = GpuBase {
flops: 9.89e14, // 989 TFLOPS BF16
mem_bw: 3.35e12, // 3.35 TB/s HBM3
hbm: 80.0e9, // 80 GB
flops: 9.89e14, // 989 TFLOPS BF16 dense
fp8_flops: 1.979e15, // 1979 TFLOPS FP8 dense
fp4_flops: 0.0, // not supported
mem_bw: 3.35e12, // 3.35 TB/s HBM3
hbm: 80.0e9, // 80 GB
pcie_gen: 5,
};
const H800: GpuBase = GpuBase {
flops: 9.89e14, // same die as H100
mem_bw: 3.35e12, // 3.35 TB/s HBM3
hbm: 80.0e9, // 80 GB
flops: 9.89e14, // same die as H100
fp8_flops: 1.979e15,
fp4_flops: 0.0,
mem_bw: 3.35e12, // 3.35 TB/s HBM3
hbm: 80.0e9, // 80 GB
pcie_gen: 5,
};
const H20: GpuBase = GpuBase {
flops: 1.48e14, // 148 TFLOPS BF16 (China-export Hopper)
mem_bw: 4.0e12, // 4.0 TB/s HBM3
hbm: 96.0e9, // 96 GB
flops: 1.48e14, // 148 TFLOPS BF16 (China-export Hopper)
fp8_flops: 2.96e14, // 296 TFLOPS FP8
fp4_flops: 0.0, // not supported
mem_bw: 4.0e12, // 4.0 TB/s HBM3
hbm: 96.0e9, // 96 GB
pcie_gen: 5,
};
const A100_80GB: GpuBase = GpuBase {
flops: 3.12e14, // 312 TFLOPS BF16
mem_bw: 2.0e12, // 2.0 TB/s HBM2e
hbm: 80.0e9, // 80 GB
flops: 3.12e14, // 312 TFLOPS BF16
fp8_flops: 0.0, // A100 has no FP8 tensor cores
fp4_flops: 0.0,
mem_bw: 2.0e12, // 2.0 TB/s HBM2e
hbm: 80.0e9, // 80 GB
pcie_gen: 4,
};
const A100_40GB: GpuBase = GpuBase {
flops: 3.12e14, // 312 TFLOPS BF16
fp8_flops: 0.0,
fp4_flops: 0.0,
mem_bw: 1.555e12, // 1.555 TB/s HBM2e
hbm: 40.0e9, // 40 GB
pcie_gen: 4,
};
// DGX B200 (8 GPU) specs: BF16 18 PFLOPS, FP8 36 PFLOPS, FP4 72 PFLOPS (dense)
const B200: GpuBase = GpuBase {
flops: 2.25e15, // 2250 TFLOPS BF16
mem_bw: 8.0e12, // 8.0 TB/s HBM3e
hbm: 192.0e9, // 192 GB
flops: 2.25e15, // 2250 TFLOPS BF16 dense
fp8_flops: 4.5e15, // 4500 TFLOPS FP8 dense
fp4_flops: 9.0e15, // 9000 TFLOPS FP4 dense
mem_bw: 8.0e12, // 8.0 TB/s HBM3e
hbm: 192.0e9, // 192 GB
pcie_gen: 6,
};
// DGX B300 (8 GPU) specs: BF16 18 PFLOPS, FP8 ~54 PFLOPS, FP4 108 PFLOPS (dense)
const B300: GpuBase = GpuBase {
    flops: 2.25e15,   // 2250 TFLOPS BF16 dense (same Blackwell die as B200; GB202 is the consumer die)
mem_bw: 12.0e12, // 12 TB/s HBM3e 12-Hi (50% more than B200 8-Hi)
hbm: 288.0e9, // 288 GB HBM3e 12-Hi
    flops: 2.25e15,      // 2250 TFLOPS BF16 dense (same Blackwell die as B200; GB202 is the consumer die)
    fp8_flops: 6.75e15,  // 6750 TFLOPS FP8 dense (estimated as half the FP4 rate)
fp4_flops: 13.5e15, // 13500 TFLOPS FP4 dense (Blackwell Ultra enhanced)
mem_bw: 12.0e12, // 12 TB/s HBM3e 12-Hi
hbm: 288.0e9, // 288 GB HBM3e 12-Hi
pcie_gen: 6,
};
@@ -165,6 +183,8 @@ fn make_config(n: u32, base: &GpuBase) -> HardwareConfig {
HardwareConfig {
gpu_flops: base.flops * f,
gpu_fp8_flops: base.fp8_flops * f,
gpu_fp4_flops: base.fp4_flops * f,
gpu_mem_bw: base.mem_bw * f,
hbm_bytes: base.hbm * f,
dram_bytes: dram,