Support compute_dtype for FP4/FP8 tensor core FLOPS selection

Add `compute_dtype` field to ModelConfig ("bf16", "fp8", "fp4") which
controls two things:
- GPU FLOPS tier: auto-selects from preset FP4/FP8/BF16 TFLOPS
- Weight bytes: uses 0.5/1.0/2.0 bytes per param for memory-bound check

Hardware presets now include per-GPU FP8 and FP4 dense FLOPS for all
GPUs that support them (H100/H800/H20: FP8, B200/B300: FP8+FP4).
Config resolution auto-selects the right FLOPS when compute_dtype is
set and the user hasn't explicitly overridden gpu_flops.

GLM-5-NVFP4 on 8xB300 now correctly uses 13.5 PFLOPS/GPU FP4 (6x
faster prefill) and 0.5 bytes/param weights (halved memory footprint).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-14 11:54:10 +08:00
parent 84696604e8
commit 663ca9c5b9
5 changed files with 106 additions and 34 deletions

View File

@@ -7,7 +7,8 @@
model: model:
config_json: ../models/GLM-5-NVFP4/config.json config_json: ../models/GLM-5-NVFP4/config.json
name: glm-5-nvfp4 name: glm-5-nvfp4
dtype_bytes: 1 # FP8 KV cache compute_dtype: fp4 # FP4 weights → selects FP4 tensor core FLOPS
dtype_bytes: 1 # FP8 KV cache
block_size_tokens: 512 block_size_tokens: 512
hardware: hardware:

View File

@@ -57,6 +57,13 @@ pub struct ModelConfig {
#[serde(default)] #[serde(default)]
pub attention: Option<AttentionConfig>, pub attention: Option<AttentionConfig>,
/// Compute / weight precision: `"bf16"` (default), `"fp8"`, or `"fp4"`.
/// Controls which hardware FLOPS tier to use (`gpu_fp4_flops`, etc.) and
/// the weight-bytes-per-parameter for the memory-bound roofline check.
/// Independent of `dtype_bytes`, which sizes the KV cache.
#[serde(default)]
pub compute_dtype: Option<String>,
// -- Legacy manual coefficients (used when hidden_size is absent) --------- // -- Legacy manual coefficients (used when hidden_size is absent) ---------
#[serde(default)] #[serde(default)]
pub flops_per_token_prefill: Option<f64>, pub flops_per_token_prefill: Option<f64>,
@@ -79,6 +86,20 @@ impl ModelConfig {
self.hidden_size.is_some() self.hidden_size.is_some()
} }
/// Bytes per parameter for weight storage, derived from `compute_dtype`.
///
/// - `"fp4"` → 0.5
/// - `"fp8"` → 1.0
/// - `"bf16"` / absent → `dtype_bytes` (backward-compatible)
pub fn weight_dtype_bytes(&self) -> f64 {
match self.compute_dtype.as_deref() {
Some("fp4") => 0.5,
Some("fp8") => 1.0,
Some("bf16") => 2.0,
_ => self.dtype_bytes as f64, // backward compat
}
}
/// Bytes of KV cache per block. /// Bytes of KV cache per block.
/// ///
/// For standard / GQA: `2 * L * kv_heads * head_dim * dtype * block_tokens` /// For standard / GQA: `2 * L * kv_heads * head_dim * dtype * block_tokens`
@@ -147,7 +168,14 @@ pub enum AttentionConfig {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HardwareConfig { pub struct HardwareConfig {
/// Active GPU FLOPS (selected from bf16/fp8/fp4 based on compute_dtype).
pub gpu_flops: f64, pub gpu_flops: f64,
/// FP8 tensor core FLOPS (0 if not populated by preset).
#[serde(default)]
pub gpu_fp8_flops: f64,
/// FP4 tensor core FLOPS (0 if not populated by preset).
#[serde(default)]
pub gpu_fp4_flops: f64,
pub gpu_mem_bw: f64, pub gpu_mem_bw: f64,
pub hbm_bytes: f64, pub hbm_bytes: f64,
pub dram_bytes: f64, pub dram_bytes: f64,
@@ -368,6 +396,8 @@ struct RawModelConfig {
#[serde(default)] #[serde(default)]
bytes_per_token_prefill: Option<f64>, bytes_per_token_prefill: Option<f64>,
#[serde(default)] #[serde(default)]
compute_dtype: Option<String>,
#[serde(default)]
flops_per_token_decode: Option<f64>, flops_per_token_decode: Option<f64>,
#[serde(default)] #[serde(default)]
bytes_per_token_decode: Option<f64>, bytes_per_token_decode: Option<f64>,
@@ -407,12 +437,25 @@ struct RawHardwareConfig {
impl RawConfig { impl RawConfig {
fn resolve(self, yaml_dir: &Path) -> Result<Config> { fn resolve(self, yaml_dir: &Path) -> Result<Config> {
Ok(Config { let model = self.model.resolve(yaml_dir)?;
model: self.model.resolve(yaml_dir)?, let user_set_gpu_flops = self.hardware.gpu_flops.is_some();
hardware: self.hardware.resolve()?, let mut hardware = self.hardware.resolve()?;
cluster: self.cluster,
sim: self.sim, // Auto-select gpu_flops tier based on model's compute_dtype,
}) // but only if the user did NOT explicitly override gpu_flops in YAML.
if !user_set_gpu_flops {
match model.compute_dtype.as_deref() {
Some("fp4") if hardware.gpu_fp4_flops > 0.0 => {
hardware.gpu_flops = hardware.gpu_fp4_flops;
}
Some("fp8") if hardware.gpu_fp8_flops > 0.0 => {
hardware.gpu_flops = hardware.gpu_fp8_flops;
}
_ => {} // keep BF16
}
}
Ok(Config { model, hardware, cluster: self.cluster, sim: self.sim })
} }
} }
@@ -446,6 +489,7 @@ impl RawModelConfig {
if let Some(v) = self.flops_per_token_prefill { m.flops_per_token_prefill = Some(v); } if let Some(v) = self.flops_per_token_prefill { m.flops_per_token_prefill = Some(v); }
if let Some(v) = self.attn_quadratic_coeff { m.attn_quadratic_coeff = Some(v); } if let Some(v) = self.attn_quadratic_coeff { m.attn_quadratic_coeff = Some(v); }
if let Some(v) = self.bytes_per_token_prefill { m.bytes_per_token_prefill = Some(v); } if let Some(v) = self.bytes_per_token_prefill { m.bytes_per_token_prefill = Some(v); }
if self.compute_dtype.is_some() { m.compute_dtype = self.compute_dtype; }
if let Some(v) = self.flops_per_token_decode { m.flops_per_token_decode = Some(v); } if let Some(v) = self.flops_per_token_decode { m.flops_per_token_decode = Some(v); }
if let Some(v) = self.bytes_per_token_decode { m.bytes_per_token_decode = Some(v); } if let Some(v) = self.bytes_per_token_decode { m.bytes_per_token_decode = Some(v); }
@@ -476,6 +520,8 @@ impl RawHardwareConfig {
} else { } else {
HardwareConfig { HardwareConfig {
gpu_flops: 0.0, gpu_flops: 0.0,
gpu_fp8_flops: 0.0,
gpu_fp4_flops: 0.0,
gpu_mem_bw: 0.0, gpu_mem_bw: 0.0,
hbm_bytes: 0.0, hbm_bytes: 0.0,
dram_bytes: 0.0, dram_bytes: 0.0,

View File

@@ -78,58 +78,76 @@ fn parse_count_gpu(s: &str) -> (u32, String) {
// -- Per-GPU base specs (single die, BF16 dense) ----------------------------- // -- Per-GPU base specs (single die, BF16 dense) -----------------------------
struct GpuBase { struct GpuBase {
flops: f64, // BF16 dense TFLOPS flops: f64, // BF16 dense FLOPS
mem_bw: f64, // HBM bandwidth (B/s) fp8_flops: f64, // FP8 dense FLOPS (0 = not supported)
hbm: f64, // Total HBM (bytes) fp4_flops: f64, // FP4 dense FLOPS (0 = not supported)
pcie_gen: u32, // PCIe generation (4/5/6) mem_bw: f64, // HBM bandwidth (B/s)
hbm: f64, // Total HBM (bytes)
pcie_gen: u32, // PCIe generation (4/5/6)
} }
const H100: GpuBase = GpuBase { const H100: GpuBase = GpuBase {
flops: 9.89e14, // 989 TFLOPS BF16 flops: 9.89e14, // 989 TFLOPS BF16 dense
mem_bw: 3.35e12, // 3.35 TB/s HBM3 fp8_flops: 1.979e15, // 1979 TFLOPS FP8 dense
hbm: 80.0e9, // 80 GB fp4_flops: 0.0, // not supported
mem_bw: 3.35e12, // 3.35 TB/s HBM3
hbm: 80.0e9, // 80 GB
pcie_gen: 5, pcie_gen: 5,
}; };
const H800: GpuBase = GpuBase { const H800: GpuBase = GpuBase {
flops: 9.89e14, // same die as H100 flops: 9.89e14, // same die as H100
mem_bw: 3.35e12, // 3.35 TB/s HBM3 fp8_flops: 1.979e15,
hbm: 80.0e9, // 80 GB fp4_flops: 0.0,
mem_bw: 3.35e12, // 3.35 TB/s HBM3
hbm: 80.0e9, // 80 GB
pcie_gen: 5, pcie_gen: 5,
}; };
const H20: GpuBase = GpuBase { const H20: GpuBase = GpuBase {
flops: 1.48e14, // 148 TFLOPS BF16 (China-export Hopper) flops: 1.48e14, // 148 TFLOPS BF16 (China-export Hopper)
mem_bw: 4.0e12, // 4.0 TB/s HBM3 fp8_flops: 2.96e14, // 296 TFLOPS FP8
hbm: 96.0e9, // 96 GB fp4_flops: 0.0, // not supported
mem_bw: 4.0e12, // 4.0 TB/s HBM3
hbm: 96.0e9, // 96 GB
pcie_gen: 5, pcie_gen: 5,
}; };
const A100_80GB: GpuBase = GpuBase { const A100_80GB: GpuBase = GpuBase {
flops: 3.12e14, // 312 TFLOPS BF16 flops: 3.12e14, // 312 TFLOPS BF16
mem_bw: 2.0e12, // 2.0 TB/s HBM2e fp8_flops: 0.0, // A100 has no FP8 tensor cores
hbm: 80.0e9, // 80 GB fp4_flops: 0.0,
mem_bw: 2.0e12, // 2.0 TB/s HBM2e
hbm: 80.0e9, // 80 GB
pcie_gen: 4, pcie_gen: 4,
}; };
const A100_40GB: GpuBase = GpuBase { const A100_40GB: GpuBase = GpuBase {
flops: 3.12e14, // 312 TFLOPS BF16 flops: 3.12e14, // 312 TFLOPS BF16
fp8_flops: 0.0,
fp4_flops: 0.0,
mem_bw: 1.555e12, // 1.555 TB/s HBM2e mem_bw: 1.555e12, // 1.555 TB/s HBM2e
hbm: 40.0e9, // 40 GB hbm: 40.0e9, // 40 GB
pcie_gen: 4, pcie_gen: 4,
}; };
// DGX B200 (8 GPU) specs: BF16 18 PFLOPS, FP8 36 PFLOPS, FP4 72 PFLOPS (dense)
const B200: GpuBase = GpuBase { const B200: GpuBase = GpuBase {
flops: 2.25e15, // 2250 TFLOPS BF16 flops: 2.25e15, // 2250 TFLOPS BF16 dense
mem_bw: 8.0e12, // 8.0 TB/s HBM3e fp8_flops: 4.5e15, // 4500 TFLOPS FP8 dense
hbm: 192.0e9, // 192 GB fp4_flops: 9.0e15, // 9000 TFLOPS FP4 dense
mem_bw: 8.0e12, // 8.0 TB/s HBM3e
hbm: 192.0e9, // 192 GB
pcie_gen: 6, pcie_gen: 6,
}; };
// DGX B300 (8 GPU) specs: BF16 18 PFLOPS, FP8 ~54 PFLOPS, FP4 108 PFLOPS (dense)
const B300: GpuBase = GpuBase { const B300: GpuBase = GpuBase {
flops: 2.25e15, // 2250 TFLOPS BF16 dense (same GB202 die as B200) flops: 2.25e15, // 2250 TFLOPS BF16 dense (same GB202 die as B200)
mem_bw: 12.0e12, // 12 TB/s HBM3e 12-Hi (50% more than B200 8-Hi) fp8_flops: 6.75e15, // 6750 TFLOPS FP8 dense (estimated from FP4/2)
hbm: 288.0e9, // 288 GB HBM3e 12-Hi fp4_flops: 13.5e15, // 13500 TFLOPS FP4 dense (Blackwell Ultra enhanced)
mem_bw: 12.0e12, // 12 TB/s HBM3e 12-Hi
hbm: 288.0e9, // 288 GB HBM3e 12-Hi
pcie_gen: 6, pcie_gen: 6,
}; };
@@ -165,6 +183,8 @@ fn make_config(n: u32, base: &GpuBase) -> HardwareConfig {
HardwareConfig { HardwareConfig {
gpu_flops: base.flops * f, gpu_flops: base.flops * f,
gpu_fp8_flops: base.fp8_flops * f,
gpu_fp4_flops: base.fp4_flops * f,
gpu_mem_bw: base.mem_bw * f, gpu_mem_bw: base.mem_bw * f,
hbm_bytes: base.hbm * f, hbm_bytes: base.hbm * f,
dram_bytes: dram, dram_bytes: dram,

View File

@@ -75,7 +75,8 @@ impl ComputeModel {
let n_kv = model.num_kv_heads as f64; let n_kv = model.num_kv_heads as f64;
let hd = model.head_dim as f64; let hd = model.head_dim as f64;
let inter = model.intermediate_size.unwrap_or(0) as f64; let inter = model.intermediate_size.unwrap_or(0) as f64;
let dtype = model.dtype_bytes as f64; // Weight dtype for memory-bound check (separate from KV cache dtype).
let wdtype = model.weight_dtype_bytes();
// --- Attention linear FLOPs/token/layer --- // --- Attention linear FLOPs/token/layer ---
let attn_linear = if let Some(mla) = &model.mla { let attn_linear = if let Some(mla) = &model.mla {
@@ -134,18 +135,18 @@ impl ComputeModel {
(h * qlr + qlr * n_heads * qk_hd (h * qlr + qlr * n_heads * qk_hd
+ h * (kvlr + qk_rd) + h * (kvlr + qk_rd)
+ n_heads * vhd * h) + n_heads * vhd * h)
* dtype * wdtype
} else { } else {
((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * dtype ((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * wdtype
}; };
let mlp_wt = if let Some(moe) = &model.moe { let mlp_wt = if let Some(moe) = &model.moe {
let expert_inter = moe.expert_intermediate_size let expert_inter = moe.expert_intermediate_size
.unwrap_or(model.intermediate_size.unwrap_or(0)) as f64; .unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
let active = moe.num_active_experts as f64; let active = moe.num_active_experts as f64;
let shared = moe.num_shared_experts as f64; let shared = moe.num_shared_experts as f64;
(active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * dtype (active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * wdtype
} else { } else {
3.0 * h * inter * dtype 3.0 * h * inter * wdtype
}; };
let weight_bytes = attn_wt + mlp_wt; let weight_bytes = attn_wt + mlp_wt;
@@ -385,6 +386,8 @@ mod tests {
}; };
let hw = HardwareConfig { let hw = HardwareConfig {
gpu_flops: 1e14, gpu_flops: 1e14,
gpu_fp8_flops: 0.0,
gpu_fp4_flops: 0.0,
gpu_mem_bw: 1e12, gpu_mem_bw: 1e12,
hbm_bytes: 1e9, hbm_bytes: 1e9,
dram_bytes: 4e9, dram_bytes: 4e9,

View File

@@ -22,6 +22,8 @@ fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config {
}, },
hardware: HardwareConfig { hardware: HardwareConfig {
gpu_flops: 1.0e14, gpu_flops: 1.0e14,
gpu_fp8_flops: 0.0,
gpu_fp4_flops: 0.0,
gpu_mem_bw: 1.0e12, gpu_mem_bw: 1.0e12,
hbm_bytes: 1.0e9, hbm_bytes: 1.0e9,
dram_bytes: 4.0e9, dram_bytes: 4.0e9,