Support compute_dtype for FP4/FP8 tensor core FLOPS selection
Add `compute_dtype` field to ModelConfig ("bf16", "fp8", "fp4") which
controls two things:
- GPU FLOPS tier: auto-selects from preset FP4/FP8/BF16 TFLOPS
- Weight bytes: uses 0.5/1.0/2.0 bytes per param for memory-bound check
Hardware presets now include per-GPU FP8 and FP4 dense FLOPS for all
GPUs that support them (H100/H800/H20: FP8, B200/B300: FP8+FP4).
Config resolution auto-selects the right FLOPS when compute_dtype is
set and the user hasn't explicitly overridden gpu_flops.
GLM-5-NVFP4 on 8xB300 now correctly uses 13.5 PFLOPS/GPU FP4 (6x
faster prefill) and 0.5 bytes/param weights (halved memory footprint).
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -7,7 +7,8 @@
|
|||||||
model:
|
model:
|
||||||
config_json: ../models/GLM-5-NVFP4/config.json
|
config_json: ../models/GLM-5-NVFP4/config.json
|
||||||
name: glm-5-nvfp4
|
name: glm-5-nvfp4
|
||||||
dtype_bytes: 1 # FP8 KV cache
|
compute_dtype: fp4 # FP4 weights → selects FP4 tensor core FLOPS
|
||||||
|
dtype_bytes: 1 # FP8 KV cache
|
||||||
block_size_tokens: 512
|
block_size_tokens: 512
|
||||||
|
|
||||||
hardware:
|
hardware:
|
||||||
|
|||||||
@@ -57,6 +57,13 @@ pub struct ModelConfig {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub attention: Option<AttentionConfig>,
|
pub attention: Option<AttentionConfig>,
|
||||||
|
|
||||||
|
/// Compute / weight precision: `"bf16"` (default), `"fp8"`, or `"fp4"`.
|
||||||
|
/// Controls which hardware FLOPS tier to use (`gpu_fp4_flops`, etc.) and
|
||||||
|
/// the weight-bytes-per-parameter for the memory-bound roofline check.
|
||||||
|
/// Independent of `dtype_bytes`, which sizes the KV cache.
|
||||||
|
#[serde(default)]
|
||||||
|
pub compute_dtype: Option<String>,
|
||||||
|
|
||||||
// -- Legacy manual coefficients (used when hidden_size is absent) ---------
|
// -- Legacy manual coefficients (used when hidden_size is absent) ---------
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub flops_per_token_prefill: Option<f64>,
|
pub flops_per_token_prefill: Option<f64>,
|
||||||
@@ -79,6 +86,20 @@ impl ModelConfig {
|
|||||||
self.hidden_size.is_some()
|
self.hidden_size.is_some()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Bytes per parameter for weight storage, derived from `compute_dtype`.
|
||||||
|
///
|
||||||
|
/// - `"fp4"` → 0.5
|
||||||
|
/// - `"fp8"` → 1.0
|
||||||
|
/// - `"bf16"` / absent → `dtype_bytes` (backward-compatible)
|
||||||
|
pub fn weight_dtype_bytes(&self) -> f64 {
|
||||||
|
match self.compute_dtype.as_deref() {
|
||||||
|
Some("fp4") => 0.5,
|
||||||
|
Some("fp8") => 1.0,
|
||||||
|
Some("bf16") => 2.0,
|
||||||
|
_ => self.dtype_bytes as f64, // backward compat
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Bytes of KV cache per block.
|
/// Bytes of KV cache per block.
|
||||||
///
|
///
|
||||||
/// For standard / GQA: `2 * L * kv_heads * head_dim * dtype * block_tokens`
|
/// For standard / GQA: `2 * L * kv_heads * head_dim * dtype * block_tokens`
|
||||||
@@ -147,7 +168,14 @@ pub enum AttentionConfig {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct HardwareConfig {
|
pub struct HardwareConfig {
|
||||||
|
/// Active GPU FLOPS (selected from bf16/fp8/fp4 based on compute_dtype).
|
||||||
pub gpu_flops: f64,
|
pub gpu_flops: f64,
|
||||||
|
/// FP8 tensor core FLOPS (0 if not populated by preset).
|
||||||
|
#[serde(default)]
|
||||||
|
pub gpu_fp8_flops: f64,
|
||||||
|
/// FP4 tensor core FLOPS (0 if not populated by preset).
|
||||||
|
#[serde(default)]
|
||||||
|
pub gpu_fp4_flops: f64,
|
||||||
pub gpu_mem_bw: f64,
|
pub gpu_mem_bw: f64,
|
||||||
pub hbm_bytes: f64,
|
pub hbm_bytes: f64,
|
||||||
pub dram_bytes: f64,
|
pub dram_bytes: f64,
|
||||||
@@ -368,6 +396,8 @@ struct RawModelConfig {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
bytes_per_token_prefill: Option<f64>,
|
bytes_per_token_prefill: Option<f64>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
compute_dtype: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
flops_per_token_decode: Option<f64>,
|
flops_per_token_decode: Option<f64>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
bytes_per_token_decode: Option<f64>,
|
bytes_per_token_decode: Option<f64>,
|
||||||
@@ -407,12 +437,25 @@ struct RawHardwareConfig {
|
|||||||
|
|
||||||
impl RawConfig {
|
impl RawConfig {
|
||||||
fn resolve(self, yaml_dir: &Path) -> Result<Config> {
|
fn resolve(self, yaml_dir: &Path) -> Result<Config> {
|
||||||
Ok(Config {
|
let model = self.model.resolve(yaml_dir)?;
|
||||||
model: self.model.resolve(yaml_dir)?,
|
let user_set_gpu_flops = self.hardware.gpu_flops.is_some();
|
||||||
hardware: self.hardware.resolve()?,
|
let mut hardware = self.hardware.resolve()?;
|
||||||
cluster: self.cluster,
|
|
||||||
sim: self.sim,
|
// Auto-select gpu_flops tier based on model's compute_dtype,
|
||||||
})
|
// but only if the user did NOT explicitly override gpu_flops in YAML.
|
||||||
|
if !user_set_gpu_flops {
|
||||||
|
match model.compute_dtype.as_deref() {
|
||||||
|
Some("fp4") if hardware.gpu_fp4_flops > 0.0 => {
|
||||||
|
hardware.gpu_flops = hardware.gpu_fp4_flops;
|
||||||
|
}
|
||||||
|
Some("fp8") if hardware.gpu_fp8_flops > 0.0 => {
|
||||||
|
hardware.gpu_flops = hardware.gpu_fp8_flops;
|
||||||
|
}
|
||||||
|
_ => {} // keep BF16
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Config { model, hardware, cluster: self.cluster, sim: self.sim })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -446,6 +489,7 @@ impl RawModelConfig {
|
|||||||
if let Some(v) = self.flops_per_token_prefill { m.flops_per_token_prefill = Some(v); }
|
if let Some(v) = self.flops_per_token_prefill { m.flops_per_token_prefill = Some(v); }
|
||||||
if let Some(v) = self.attn_quadratic_coeff { m.attn_quadratic_coeff = Some(v); }
|
if let Some(v) = self.attn_quadratic_coeff { m.attn_quadratic_coeff = Some(v); }
|
||||||
if let Some(v) = self.bytes_per_token_prefill { m.bytes_per_token_prefill = Some(v); }
|
if let Some(v) = self.bytes_per_token_prefill { m.bytes_per_token_prefill = Some(v); }
|
||||||
|
if self.compute_dtype.is_some() { m.compute_dtype = self.compute_dtype; }
|
||||||
if let Some(v) = self.flops_per_token_decode { m.flops_per_token_decode = Some(v); }
|
if let Some(v) = self.flops_per_token_decode { m.flops_per_token_decode = Some(v); }
|
||||||
if let Some(v) = self.bytes_per_token_decode { m.bytes_per_token_decode = Some(v); }
|
if let Some(v) = self.bytes_per_token_decode { m.bytes_per_token_decode = Some(v); }
|
||||||
|
|
||||||
@@ -476,6 +520,8 @@ impl RawHardwareConfig {
|
|||||||
} else {
|
} else {
|
||||||
HardwareConfig {
|
HardwareConfig {
|
||||||
gpu_flops: 0.0,
|
gpu_flops: 0.0,
|
||||||
|
gpu_fp8_flops: 0.0,
|
||||||
|
gpu_fp4_flops: 0.0,
|
||||||
gpu_mem_bw: 0.0,
|
gpu_mem_bw: 0.0,
|
||||||
hbm_bytes: 0.0,
|
hbm_bytes: 0.0,
|
||||||
dram_bytes: 0.0,
|
dram_bytes: 0.0,
|
||||||
|
|||||||
@@ -78,58 +78,76 @@ fn parse_count_gpu(s: &str) -> (u32, String) {
|
|||||||
// -- Per-GPU base specs (single die, BF16 dense) -----------------------------
|
// -- Per-GPU base specs (single die, BF16 dense) -----------------------------
|
||||||
|
|
||||||
struct GpuBase {
|
struct GpuBase {
|
||||||
flops: f64, // BF16 dense TFLOPS
|
flops: f64, // BF16 dense FLOPS
|
||||||
mem_bw: f64, // HBM bandwidth (B/s)
|
fp8_flops: f64, // FP8 dense FLOPS (0 = not supported)
|
||||||
hbm: f64, // Total HBM (bytes)
|
fp4_flops: f64, // FP4 dense FLOPS (0 = not supported)
|
||||||
pcie_gen: u32, // PCIe generation (4/5/6)
|
mem_bw: f64, // HBM bandwidth (B/s)
|
||||||
|
hbm: f64, // Total HBM (bytes)
|
||||||
|
pcie_gen: u32, // PCIe generation (4/5/6)
|
||||||
}
|
}
|
||||||
|
|
||||||
const H100: GpuBase = GpuBase {
|
const H100: GpuBase = GpuBase {
|
||||||
flops: 9.89e14, // 989 TFLOPS BF16
|
flops: 9.89e14, // 989 TFLOPS BF16 dense
|
||||||
mem_bw: 3.35e12, // 3.35 TB/s HBM3
|
fp8_flops: 1.979e15, // 1979 TFLOPS FP8 dense
|
||||||
hbm: 80.0e9, // 80 GB
|
fp4_flops: 0.0, // not supported
|
||||||
|
mem_bw: 3.35e12, // 3.35 TB/s HBM3
|
||||||
|
hbm: 80.0e9, // 80 GB
|
||||||
pcie_gen: 5,
|
pcie_gen: 5,
|
||||||
};
|
};
|
||||||
|
|
||||||
const H800: GpuBase = GpuBase {
|
const H800: GpuBase = GpuBase {
|
||||||
flops: 9.89e14, // same die as H100
|
flops: 9.89e14, // same die as H100
|
||||||
mem_bw: 3.35e12, // 3.35 TB/s HBM3
|
fp8_flops: 1.979e15,
|
||||||
hbm: 80.0e9, // 80 GB
|
fp4_flops: 0.0,
|
||||||
|
mem_bw: 3.35e12, // 3.35 TB/s HBM3
|
||||||
|
hbm: 80.0e9, // 80 GB
|
||||||
pcie_gen: 5,
|
pcie_gen: 5,
|
||||||
};
|
};
|
||||||
|
|
||||||
const H20: GpuBase = GpuBase {
|
const H20: GpuBase = GpuBase {
|
||||||
flops: 1.48e14, // 148 TFLOPS BF16 (China-export Hopper)
|
flops: 1.48e14, // 148 TFLOPS BF16 (China-export Hopper)
|
||||||
mem_bw: 4.0e12, // 4.0 TB/s HBM3
|
fp8_flops: 2.96e14, // 296 TFLOPS FP8
|
||||||
hbm: 96.0e9, // 96 GB
|
fp4_flops: 0.0, // not supported
|
||||||
|
mem_bw: 4.0e12, // 4.0 TB/s HBM3
|
||||||
|
hbm: 96.0e9, // 96 GB
|
||||||
pcie_gen: 5,
|
pcie_gen: 5,
|
||||||
};
|
};
|
||||||
|
|
||||||
const A100_80GB: GpuBase = GpuBase {
|
const A100_80GB: GpuBase = GpuBase {
|
||||||
flops: 3.12e14, // 312 TFLOPS BF16
|
flops: 3.12e14, // 312 TFLOPS BF16
|
||||||
mem_bw: 2.0e12, // 2.0 TB/s HBM2e
|
fp8_flops: 0.0, // A100 has no FP8 tensor cores
|
||||||
hbm: 80.0e9, // 80 GB
|
fp4_flops: 0.0,
|
||||||
|
mem_bw: 2.0e12, // 2.0 TB/s HBM2e
|
||||||
|
hbm: 80.0e9, // 80 GB
|
||||||
pcie_gen: 4,
|
pcie_gen: 4,
|
||||||
};
|
};
|
||||||
|
|
||||||
const A100_40GB: GpuBase = GpuBase {
|
const A100_40GB: GpuBase = GpuBase {
|
||||||
flops: 3.12e14, // 312 TFLOPS BF16
|
flops: 3.12e14, // 312 TFLOPS BF16
|
||||||
|
fp8_flops: 0.0,
|
||||||
|
fp4_flops: 0.0,
|
||||||
mem_bw: 1.555e12, // 1.555 TB/s HBM2e
|
mem_bw: 1.555e12, // 1.555 TB/s HBM2e
|
||||||
hbm: 40.0e9, // 40 GB
|
hbm: 40.0e9, // 40 GB
|
||||||
pcie_gen: 4,
|
pcie_gen: 4,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// DGX B200 (8 GPU) specs: BF16 18 PFLOPS, FP8 36 PFLOPS, FP4 72 PFLOPS (dense)
|
||||||
const B200: GpuBase = GpuBase {
|
const B200: GpuBase = GpuBase {
|
||||||
flops: 2.25e15, // 2250 TFLOPS BF16
|
flops: 2.25e15, // 2250 TFLOPS BF16 dense
|
||||||
mem_bw: 8.0e12, // 8.0 TB/s HBM3e
|
fp8_flops: 4.5e15, // 4500 TFLOPS FP8 dense
|
||||||
hbm: 192.0e9, // 192 GB
|
fp4_flops: 9.0e15, // 9000 TFLOPS FP4 dense
|
||||||
|
mem_bw: 8.0e12, // 8.0 TB/s HBM3e
|
||||||
|
hbm: 192.0e9, // 192 GB
|
||||||
pcie_gen: 6,
|
pcie_gen: 6,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// DGX B300 (8 GPU) specs: BF16 18 PFLOPS, FP8 ~54 PFLOPS, FP4 108 PFLOPS (dense)
|
||||||
const B300: GpuBase = GpuBase {
|
const B300: GpuBase = GpuBase {
|
||||||
flops: 2.25e15, // 2250 TFLOPS BF16 dense (same GB202 die as B200)
|
flops: 2.25e15, // 2250 TFLOPS BF16 dense (same GB202 die as B200)
|
||||||
mem_bw: 12.0e12, // 12 TB/s HBM3e 12-Hi (50% more than B200 8-Hi)
|
fp8_flops: 6.75e15, // 6750 TFLOPS FP8 dense (estimated from FP4/2)
|
||||||
hbm: 288.0e9, // 288 GB HBM3e 12-Hi
|
fp4_flops: 13.5e15, // 13500 TFLOPS FP4 dense (Blackwell Ultra enhanced)
|
||||||
|
mem_bw: 12.0e12, // 12 TB/s HBM3e 12-Hi
|
||||||
|
hbm: 288.0e9, // 288 GB HBM3e 12-Hi
|
||||||
pcie_gen: 6,
|
pcie_gen: 6,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -165,6 +183,8 @@ fn make_config(n: u32, base: &GpuBase) -> HardwareConfig {
|
|||||||
|
|
||||||
HardwareConfig {
|
HardwareConfig {
|
||||||
gpu_flops: base.flops * f,
|
gpu_flops: base.flops * f,
|
||||||
|
gpu_fp8_flops: base.fp8_flops * f,
|
||||||
|
gpu_fp4_flops: base.fp4_flops * f,
|
||||||
gpu_mem_bw: base.mem_bw * f,
|
gpu_mem_bw: base.mem_bw * f,
|
||||||
hbm_bytes: base.hbm * f,
|
hbm_bytes: base.hbm * f,
|
||||||
dram_bytes: dram,
|
dram_bytes: dram,
|
||||||
|
|||||||
@@ -75,7 +75,8 @@ impl ComputeModel {
|
|||||||
let n_kv = model.num_kv_heads as f64;
|
let n_kv = model.num_kv_heads as f64;
|
||||||
let hd = model.head_dim as f64;
|
let hd = model.head_dim as f64;
|
||||||
let inter = model.intermediate_size.unwrap_or(0) as f64;
|
let inter = model.intermediate_size.unwrap_or(0) as f64;
|
||||||
let dtype = model.dtype_bytes as f64;
|
// Weight dtype for memory-bound check (separate from KV cache dtype).
|
||||||
|
let wdtype = model.weight_dtype_bytes();
|
||||||
|
|
||||||
// --- Attention linear FLOPs/token/layer ---
|
// --- Attention linear FLOPs/token/layer ---
|
||||||
let attn_linear = if let Some(mla) = &model.mla {
|
let attn_linear = if let Some(mla) = &model.mla {
|
||||||
@@ -134,18 +135,18 @@ impl ComputeModel {
|
|||||||
(h * qlr + qlr * n_heads * qk_hd
|
(h * qlr + qlr * n_heads * qk_hd
|
||||||
+ h * (kvlr + qk_rd)
|
+ h * (kvlr + qk_rd)
|
||||||
+ n_heads * vhd * h)
|
+ n_heads * vhd * h)
|
||||||
* dtype
|
* wdtype
|
||||||
} else {
|
} else {
|
||||||
((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * dtype
|
((n_heads + 2.0 * n_kv) * hd * h + n_heads * hd * h) * wdtype
|
||||||
};
|
};
|
||||||
let mlp_wt = if let Some(moe) = &model.moe {
|
let mlp_wt = if let Some(moe) = &model.moe {
|
||||||
let expert_inter = moe.expert_intermediate_size
|
let expert_inter = moe.expert_intermediate_size
|
||||||
.unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
|
.unwrap_or(model.intermediate_size.unwrap_or(0)) as f64;
|
||||||
let active = moe.num_active_experts as f64;
|
let active = moe.num_active_experts as f64;
|
||||||
let shared = moe.num_shared_experts as f64;
|
let shared = moe.num_shared_experts as f64;
|
||||||
(active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * dtype
|
(active * 3.0 * h * expert_inter + shared * 3.0 * h * inter) * wdtype
|
||||||
} else {
|
} else {
|
||||||
3.0 * h * inter * dtype
|
3.0 * h * inter * wdtype
|
||||||
};
|
};
|
||||||
let weight_bytes = attn_wt + mlp_wt;
|
let weight_bytes = attn_wt + mlp_wt;
|
||||||
|
|
||||||
@@ -385,6 +386,8 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let hw = HardwareConfig {
|
let hw = HardwareConfig {
|
||||||
gpu_flops: 1e14,
|
gpu_flops: 1e14,
|
||||||
|
gpu_fp8_flops: 0.0,
|
||||||
|
gpu_fp4_flops: 0.0,
|
||||||
gpu_mem_bw: 1e12,
|
gpu_mem_bw: 1e12,
|
||||||
hbm_bytes: 1e9,
|
hbm_bytes: 1e9,
|
||||||
dram_bytes: 4e9,
|
dram_bytes: 4e9,
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ fn base_config(trace_path: &str, out_dir: &str, mode: RouterMode) -> Config {
|
|||||||
},
|
},
|
||||||
hardware: HardwareConfig {
|
hardware: HardwareConfig {
|
||||||
gpu_flops: 1.0e14,
|
gpu_flops: 1.0e14,
|
||||||
|
gpu_fp8_flops: 0.0,
|
||||||
|
gpu_fp4_flops: 0.0,
|
||||||
gpu_mem_bw: 1.0e12,
|
gpu_mem_bw: 1.0e12,
|
||||||
hbm_bytes: 1.0e9,
|
hbm_bytes: 1.0e9,
|
||||||
dram_bytes: 4.0e9,
|
dram_bytes: 4.0e9,
|
||||||
|
|||||||
Reference in New Issue
Block a user