Support compute_dtype for FP4/FP8 tensor core FLOPS selection
Add `compute_dtype` field to ModelConfig ("bf16", "fp8", "fp4") which
controls two things:
- GPU FLOPS tier: auto-selects from preset FP4/FP8/BF16 TFLOPS
- Weight bytes: uses 0.5/1.0/2.0 bytes per param for memory-bound check
Hardware presets now include per-GPU FP8 and FP4 dense FLOPS for all
GPUs that support them (H100/H800/H20: FP8, B200/B300: FP8+FP4).
Config resolution auto-selects the right FLOPS when compute_dtype is
set and the user hasn't explicitly overridden gpu_flops.
GLM-5-NVFP4 on 8xB300 now correctly uses 13.5 PFLOPS/GPU FP4 (6x
faster prefill) and 0.5 bytes/param weights (halved memory footprint).
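
Sketch of the combined selection rule (standalone illustration; mirrors the
logic added in RawConfig::resolve and weight_dtype_bytes below, not a public
API — the `pick` name and flat signature are invented for this sketch):

    // (compute_dtype, available tiers) -> (active FLOPS, weight bytes/param)
    fn pick(dtype: Option<&str>, bf16: f64, fp8: f64, fp4: f64, kv_bytes: f64) -> (f64, f64) {
        match dtype {
            Some("fp4") if fp4 > 0.0 => (fp4, 0.5),
            Some("fp8") if fp8 > 0.0 => (fp8, 1.0),
            Some("bf16") => (bf16, 2.0),
            _ => (bf16, kv_bytes), // absent: keep BF16 FLOPS, fall back to dtype_bytes
        }
    }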
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@@ -7,6 +7,7 @@
 model:
   config_json: ../models/GLM-5-NVFP4/config.json
   name: glm-5-nvfp4
+  compute_dtype: fp4     # FP4 weights → selects FP4 tensor core FLOPS
   dtype_bytes: 1         # FP8 KV cache
   block_size_tokens: 512
 
@@ -57,6 +57,13 @@ pub struct ModelConfig {
     #[serde(default)]
     pub attention: Option<AttentionConfig>,
 
+    /// Compute / weight precision: `"bf16"` (default), `"fp8"`, or `"fp4"`.
+    /// Controls which hardware FLOPS tier to use (`gpu_fp4_flops`, etc.) and
+    /// the weight-bytes-per-parameter for the memory-bound roofline check.
+    /// Independent of `dtype_bytes`, which sizes the KV cache.
+    #[serde(default)]
+    pub compute_dtype: Option<String>,
+
     // -- Legacy manual coefficients (used when hidden_size is absent) ---------
     #[serde(default)]
     pub flops_per_token_prefill: Option<f64>,
@@ -79,6 +86,20 @@ impl ModelConfig {
         self.hidden_size.is_some()
     }
 
+    /// Bytes per parameter for weight storage, derived from `compute_dtype`.
+    ///
+    /// - `"fp4"` → 0.5
+    /// - `"fp8"` → 1.0
+    /// - `"bf16"` / absent → `dtype_bytes` (backward-compatible)
+    pub fn weight_dtype_bytes(&self) -> f64 {
+        match self.compute_dtype.as_deref() {
+            Some("fp4") => 0.5,
+            Some("fp8") => 1.0,
+            Some("bf16") => 2.0,
+            _ => self.dtype_bytes as f64, // backward compat
+        }
+    }
+
     /// Bytes of KV cache per block.
     ///
     /// For standard / GQA: `2 * L * kv_heads * head_dim * dtype * block_tokens`
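
For scale (illustrative, hypothetical 100B-parameter model — not a model from
this repo): the bytes-per-parameter choice alone moves the weight footprint
4x between bf16 and fp4:

    let params = 100e9_f64;
    assert_eq!(params * 0.5, 50e9);  // fp4:  ~50 GB
    assert_eq!(params * 1.0, 100e9); // fp8: ~100 GB
    assert_eq!(params * 2.0, 200e9); // bf16: ~200 GB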
@@ -147,7 +168,14 @@ pub enum AttentionConfig {
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct HardwareConfig {
+    /// Active GPU FLOPS (selected from bf16/fp8/fp4 based on compute_dtype).
     pub gpu_flops: f64,
+    /// FP8 tensor core FLOPS (0 if not populated by preset).
+    #[serde(default)]
+    pub gpu_fp8_flops: f64,
+    /// FP4 tensor core FLOPS (0 if not populated by preset).
+    #[serde(default)]
+    pub gpu_fp4_flops: f64,
     pub gpu_mem_bw: f64,
     pub hbm_bytes: f64,
     pub dram_bytes: f64,
@@ -368,6 +396,8 @@ struct RawModelConfig {
     #[serde(default)]
     bytes_per_token_prefill: Option<f64>,
+    #[serde(default)]
+    compute_dtype: Option<String>,
     #[serde(default)]
     flops_per_token_decode: Option<f64>,
     #[serde(default)]
     bytes_per_token_decode: Option<f64>,
@@ -407,12 +437,25 @@ struct RawHardwareConfig {
 
 impl RawConfig {
     fn resolve(self, yaml_dir: &Path) -> Result<Config> {
-        Ok(Config {
-            model: self.model.resolve(yaml_dir)?,
-            hardware: self.hardware.resolve()?,
-            cluster: self.cluster,
-            sim: self.sim,
-        })
+        let model = self.model.resolve(yaml_dir)?;
+        let user_set_gpu_flops = self.hardware.gpu_flops.is_some();
+        let mut hardware = self.hardware.resolve()?;
+
+        // Auto-select gpu_flops tier based on model's compute_dtype,
+        // but only if the user did NOT explicitly override gpu_flops in YAML.
+        if !user_set_gpu_flops {
+            match model.compute_dtype.as_deref() {
+                Some("fp4") if hardware.gpu_fp4_flops > 0.0 => {
+                    hardware.gpu_flops = hardware.gpu_fp4_flops;
+                }
+                Some("fp8") if hardware.gpu_fp8_flops > 0.0 => {
+                    hardware.gpu_flops = hardware.gpu_fp8_flops;
+                }
+                _ => {} // keep BF16
+            }
+        }
+
+        Ok(Config { model, hardware, cluster: self.cluster, sim: self.sim })
     }
 }
 
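
Precedence note: an explicit `gpu_flops` in the hardware YAML always wins;
auto-selection only fills the gap when the field is absent. Minimal
illustration of the guard above (values hypothetical):

    let user_set_gpu_flops = true;   // gpu_flops present in YAML
    let mut gpu_flops = 1.0e15;      // the user's explicit value
    let gpu_fp4_flops = 13.5e15;     // preset FP4 tier
    if !user_set_gpu_flops {
        gpu_flops = gpu_fp4_flops;   // only reached when not user-set
    }
    assert_eq!(gpu_flops, 1.0e15);   // override preserved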
@@ -446,6 +489,7 @@ impl RawModelConfig {
         if let Some(v) = self.flops_per_token_prefill { m.flops_per_token_prefill = Some(v); }
         if let Some(v) = self.attn_quadratic_coeff { m.attn_quadratic_coeff = Some(v); }
         if let Some(v) = self.bytes_per_token_prefill { m.bytes_per_token_prefill = Some(v); }
+        if self.compute_dtype.is_some() { m.compute_dtype = self.compute_dtype; }
         if let Some(v) = self.flops_per_token_decode { m.flops_per_token_decode = Some(v); }
        if let Some(v) = self.bytes_per_token_decode { m.bytes_per_token_decode = Some(v); }
 
@@ -476,6 +520,8 @@ impl RawHardwareConfig {
         } else {
             HardwareConfig {
                 gpu_flops: 0.0,
+                gpu_fp8_flops: 0.0,
+                gpu_fp4_flops: 0.0,
                 gpu_mem_bw: 0.0,
                 hbm_bytes: 0.0,
                 dram_bytes: 0.0,
@@ -78,14 +78,18 @@ fn parse_count_gpu(s: &str) -> (u32, String) {
 // -- Per-GPU base specs (single die, BF16 dense) -----------------------------
 
 struct GpuBase {
-    flops: f64,      // BF16 dense TFLOPS
+    flops: f64,      // BF16 dense FLOPS
+    fp8_flops: f64,  // FP8 dense FLOPS (0 = not supported)
+    fp4_flops: f64,  // FP4 dense FLOPS (0 = not supported)
     mem_bw: f64,     // HBM bandwidth (B/s)
     hbm: f64,        // Total HBM (bytes)
     pcie_gen: u32,   // PCIe generation (4/5/6)
 }
 
 const H100: GpuBase = GpuBase {
-    flops: 9.89e14,      // 989 TFLOPS BF16
+    flops: 9.89e14,      // 989 TFLOPS BF16 dense
+    fp8_flops: 1.979e15, // 1979 TFLOPS FP8 dense
+    fp4_flops: 0.0,      // not supported
     mem_bw: 3.35e12,     // 3.35 TB/s HBM3
     hbm: 80.0e9,         // 80 GB
     pcie_gen: 5,
@@ -93,6 +97,8 @@ const H100: GpuBase = GpuBase {
 
 const H800: GpuBase = GpuBase {
     flops: 9.89e14,      // same die as H100
+    fp8_flops: 1.979e15,
+    fp4_flops: 0.0,
     mem_bw: 3.35e12,     // 3.35 TB/s HBM3
     hbm: 80.0e9,         // 80 GB
     pcie_gen: 5,
@@ -100,6 +106,8 @@ const H800: GpuBase = GpuBase {
 
 const H20: GpuBase = GpuBase {
     flops: 1.48e14,      // 148 TFLOPS BF16 (China-export Hopper)
+    fp8_flops: 2.96e14,  // 296 TFLOPS FP8
+    fp4_flops: 0.0,      // not supported
     mem_bw: 4.0e12,      // 4.0 TB/s HBM3
     hbm: 96.0e9,         // 96 GB
     pcie_gen: 5,
@@ -107,6 +115,8 @@ const H20: GpuBase = GpuBase {
 
 const A100_80GB: GpuBase = GpuBase {
     flops: 3.12e14,      // 312 TFLOPS BF16
+    fp8_flops: 0.0,      // A100 has no FP8 tensor cores
+    fp4_flops: 0.0,
     mem_bw: 2.0e12,      // 2.0 TB/s HBM2e
     hbm: 80.0e9,         // 80 GB
     pcie_gen: 4,
@@ -114,21 +124,29 @@ const A100_80GB: GpuBase = GpuBase {
 
 const A100_40GB: GpuBase = GpuBase {
     flops: 3.12e14,      // 312 TFLOPS BF16
+    fp8_flops: 0.0,
+    fp4_flops: 0.0,
     mem_bw: 1.555e12,    // 1.555 TB/s HBM2e
     hbm: 40.0e9,         // 40 GB
     pcie_gen: 4,
 };
 
+// DGX B200 (8 GPU) specs: BF16 18 PFLOPS, FP8 36 PFLOPS, FP4 72 PFLOPS (dense)
 const B200: GpuBase = GpuBase {
-    flops: 2.25e15,      // 2250 TFLOPS BF16
+    flops: 2.25e15,      // 2250 TFLOPS BF16 dense
+    fp8_flops: 4.5e15,   // 4500 TFLOPS FP8 dense
+    fp4_flops: 9.0e15,   // 9000 TFLOPS FP4 dense
     mem_bw: 8.0e12,      // 8.0 TB/s HBM3e
     hbm: 192.0e9,        // 192 GB
     pcie_gen: 6,
 };
 
+// DGX B300 (8 GPU) specs: BF16 18 PFLOPS, FP8 ~54 PFLOPS, FP4 108 PFLOPS (dense)
 const B300: GpuBase = GpuBase {
     flops: 2.25e15,      // 2250 TFLOPS BF16 dense (same BF16 rate as B200)
-    mem_bw: 12.0e12,     // 12 TB/s HBM3e 12-Hi (50% more than B200 8-Hi)
+    fp8_flops: 6.75e15,  // 6750 TFLOPS FP8 dense (estimated as FP4/2)
+    fp4_flops: 13.5e15,  // 13500 TFLOPS FP4 dense (Blackwell Ultra enhanced)
+    mem_bw: 12.0e12,     // 12 TB/s HBM3e 12-Hi
     hbm: 288.0e9,        // 288 GB HBM3e 12-Hi
     pcie_gen: 6,
 };
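
The "6x faster prefill" figure in the commit message is just the B300 tier
ratio (prefill is compute-bound, so the speedup bound is the FLOPS ratio):

    let bf16 = 2.25e15;          // B300 BF16 dense FLOPS
    let fp4 = 13.5e15;           // B300 FP4 dense FLOPS
    assert_eq!(fp4 / bf16, 6.0); // compute-bound prefill speedup bound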
@@ -165,6 +183,8 @@ fn make_config(n: u32, base: &GpuBase) -> HardwareConfig {
 
     HardwareConfig {
         gpu_flops: base.flops * f,
+        gpu_fp8_flops: base.fp8_flops * f,
+        gpu_fp4_flops: base.fp4_flops * f,
         gpu_mem_bw: base.mem_bw * f,
         hbm_bytes: base.hbm * f,
         dram_bytes: dram,
@@ -75,7 +75,8 @@ impl ComputeModel {
         let n_kv = model.num_kv_heads as f64;
         let hd = model.head_dim as f64;
         let inter = model.intermediate_size.unwrap_or(0) as f64;
         let dtype = model.dtype_bytes as f64;
+        // Weight dtype for memory-bound check (separate from KV cache dtype).
+        let wdtype = model.weight_dtype_bytes();
 
         // --- Attention linear FLOPs/token/layer ---
         let attn_linear = if let Some(mla) = &model.mla {
@@ -134,18 +135,18 @@ impl ComputeModel {
             (h * qlr + qlr * n_heads * qk_hd
                 + h * (kvlr + qk_rd)
                 + n_heads * vhd * h)
-                * dtype
+                * wdtype
         } else {
-            ((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * dtype
+            ((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * wdtype
         };
         let mlp_wt = if let Some(moe) = &model.moe {
             let expert_inter = moe.expert_intermediate_size
                 .unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
             let active = moe.num_active_experts as f64;
             let shared = moe.num_shared_experts as f64;
-            (active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * dtype
+            (active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * wdtype
         } else {
-            3.0 * h * inter * dtype
+            3.0 * h * inter * wdtype
         };
         let weight_bytes = attn_wt + mlp_wt;
 
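
Why the wdtype swap matters for decode (illustrative numbers, not from the
codebase): decode is typically memory-bound on streaming the weights, so
going from 2.0 to 0.5 bytes/param cuts the per-token weight-read time to a
quarter:

    let mem_bw = 12.0e12_f64;                   // B/s, B300-class HBM
    let weights_bf16 = 2.0e9_f64;               // hypothetical per-GPU weight bytes at bf16
    let t_bf16 = weights_bf16 / mem_bw;         // seconds to stream weights once
    let t_fp4 = (weights_bf16 * 0.25) / mem_bw; // fp4 = 0.5 B vs 2.0 B per param
    assert!(t_fp4 < t_bf16);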
@@ -385,6 +386,8 @@ mod tests {
     };
     let hw = HardwareConfig {
         gpu_flops: 1e14,
+        gpu_fp8_flops: 0.0,
+        gpu_fp4_flops: 0.0,
         gpu_mem_bw: 1e12,
         hbm_bytes: 1e9,
         dram_bytes: 4e9,
@@ -22,6 +22,8 @@ fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config {
         },
         hardware: HardwareConfig {
             gpu_flops: 1.0e14,
+            gpu_fp8_flops: 0.0,
+            gpu_fp4_flops: 0.0,
             gpu_mem_bw: 1.0e12,
             hbm_bytes: 1.0e9,
             dram_bytes: 4.0e9,