Files
xserv/tools/bench_fp8.py
Gahow Wang 3a530956af tools: add FP8 vs BF16 benchmark and GSM8K eval harness
bench_fp8.py — head-to-head comparison of FP8 and BF16 models on
  GSM8K / AIME2025 accuracy plus TTFT/TPOT performance measurement.

eval_gsm8k_batch.sh — lightweight GSM8K accuracy evaluator that
  pipes one problem per xserv-chat invocation and scores with
  \boxed{} / last-number extraction.

Benchmark results (gpt-oss-20b, 50-problem GSM8K):
  FP8 W8A8 TP1 : 94.0%  (single RTX 5090, 25 GB)
  FP8 W8A16 TP1: 94.0%
  BF16 TP2     : 94.0%  (requires 2× RTX 5090)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-06-08 15:43:04 +08:00

251 lines
9.7 KiB
Python

#!/usr/bin/env python3
"""Benchmark FP8 vs BF16: accuracy (GSM8K, AIME2025) and performance (TTFT/TPOT).
Usage:
python bench_fp8.py --fp8 <model-dir> --bf16 <model-dir> [options]
Measures:
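- Accuracy on GSM8K (100 problems by default, see --gsm8k-limit) and AIME2025 (30 problems)
- TTFT: Time to first token (wall-clock time of a max_tokens=1 run; this includes process start and model load)
- TPOT: Time per output token (approximated as (median 128-token run time - median TTFT) / (estimated tokens - 1))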
"""
import argparse
import json
import os
import re
import subprocess
import sys
import time
from pathlib import Path
# Directory containing this script; every other path is derived from it.
SCRIPT_DIR = Path(__file__).parent

# JSON problem sets, expected under bench/data next to this script.
GSM8K_PATH = SCRIPT_DIR.joinpath("bench", "data", "gsm8k.json")
AIME_PATH = SCRIPT_DIR.joinpath("bench", "data", "aime2025.json")

# Location of the xserv-chat executable, one directory above this script.
XSERV_CHAT = SCRIPT_DIR.parent.joinpath("target", "release", "xserv-chat")

# System prompt used for the accuracy tasks; asks for a \boxed{} final answer.
SYSTEM_PROMPT_MATH = (
    "Solve the problem step by step. "
    "Put your final numeric answer inside \\boxed{}."
)

# Fixed user prompt used for every latency measurement.
PERF_PROMPT = (
    "Write a detailed explanation of how neural networks learn through "
    "backpropagation, covering the chain rule, gradient descent, and "
    "weight updates."
)

# Contents of a \boxed{...} group that has no nested braces.
_BOXED_RE = re.compile(r"\\boxed\s*\{([^{}]*)\}")
# Optionally signed number with optional thousands separators and decimals.
_NUM_RE = re.compile(r"-?\d+(?:,\d{3})*(?:\.\d+)?")
def normalize_num(s):
    """Canonicalise a numeric answer so equal values compare equal as strings.

    Thousands separators and surrounding whitespace are removed and the
    remainder is parsed as a float. Integral values are rendered without a
    fractional part ("1,200.0" -> "1200"); other values are rendered with
    the ``:g`` format ("2.50" -> "2.5").

    Args:
        s: The answer to normalise. Normally a string; ints and floats (as
            ``json.load`` yields for unquoted numbers) are accepted as well.

    Returns:
        The canonical string, or ``None`` when ``s`` is ``None``, cannot be
        parsed as a number, or parses to a non-finite value.
    """
    if s is None:
        return None
    # Datasets may store the answer as a bare JSON number, not a string.
    text = str(s).replace(",", "").strip()
    try:
        value = float(text)
        # int() raises OverflowError for +/-inf (e.g. a very long digit run
        # that overflows float) and ValueError for NaN; treat both as
        # "not a usable answer" instead of letting them escape.
        whole = int(value)
    except (ValueError, OverflowError):
        return None
    return str(whole) if value == whole else f"{value:g}"
def extract_answer(text):
    """Pull the final numeric answer out of a model response.

    The last number inside the last ``\\boxed{...}`` group is preferred. If
    there is no boxed group, or the last one holds no number, the last number
    anywhere in the response is used instead.

    Returns:
        The number normalised by ``normalize_num``, or ``None`` when the
        text is empty or contains no number.
    """
    if not text:
        return None

    boxed_groups = _BOXED_RE.findall(text)
    # Search the last boxed group first, then fall back to the whole text.
    sources = [boxed_groups[-1], text] if boxed_groups else [text]
    for source in sources:
        found = _NUM_RE.findall(source)
        if found:
            return normalize_num(found[-1])
    return None
def run_chat(model_dir, question, max_tokens, max_seq_len, tp, system=None):
    """Send one question to a fresh xserv-chat process.

    The question is written to the process's stdin followed by a newline.
    The reply is the stdout text after the first ``assistant>`` marker, cut
    at the last ``user>`` marker when one follows, with surrounding
    whitespace stripped. The process exit status is not inspected.

    Args:
        model_dir: Model directory passed as the first CLI argument.
        question: Prompt text for the single turn.
        max_tokens: Value for ``--max-tokens``.
        max_seq_len: Value for ``--max-seq-len``.
        tp: Tensor-parallel degree; ``--tp`` is only passed when above 1.
        system: Optional system prompt passed via ``--system``.

    Returns:
        ``(response_text, elapsed_sec)``. The response is ``""`` when stdout
        has no ``assistant>`` marker. The elapsed time spans the whole
        subprocess call, from launch to exit.

    Raises:
        subprocess.TimeoutExpired: The process ran longer than 300 seconds.
    """
    argv = [
        str(XSERV_CHAT), model_dir,
        "--max-tokens", str(max_tokens),
        "--max-seq-len", str(max_seq_len),
        "--no-color",
    ]
    if tp > 1:
        argv.extend(["--tp", str(tp)])
    if system:
        argv.extend(["--system", system])

    started = time.perf_counter()
    completed = subprocess.run(argv, input=question + "\n", capture_output=True,
                               text=True, timeout=300)
    elapsed = time.perf_counter() - started

    # Everything after the first assistant prompt belongs to the reply.
    _, marker, tail = completed.stdout.partition("assistant>")
    if not marker:
        return "", elapsed

    # Drop the trailing "user>" prompt the CLI prints for the next turn.
    body, next_prompt, _ = tail.rpartition("user>")
    reply = body if next_prompt else tail
    return reply.strip(), elapsed
def count_tokens_approx(text):
    """Estimate the token count as 1.3x the whitespace-separated word count.

    The estimate is truncated to an int and never drops below 1.
    """
    word_count = len(text.split())
    estimate = int(word_count * 1.3)
    return estimate if estimate > 1 else 1
def run_accuracy(model_dir, dataset_path, task_name, limit, tp, max_tokens,
                 max_seq_len=2048):
    """Score a model on the first ``limit`` problems of a JSON dataset.

    Each problem is sent to the model in its own chat invocation, with
    newlines in the problem text replaced by spaces. The extracted prediction
    is compared with the normalised gold answer. A timeout or any other
    per-problem failure is reported on its own line, counted, and scored as
    incorrect; it does not abort the run.

    Args:
        model_dir: Model directory handed to the chat binary.
        dataset_path: JSON file holding a list of objects with ``problem``
            and ``answer`` keys.
        task_name: Label used in progress output and in the result.
        limit: Maximum number of problems taken from the start of the file.
        tp: Tensor-parallel degree.
        max_tokens: Generation budget per problem.
        max_seq_len: Context length passed to the chat binary.

    Returns:
        A dict with ``task``, ``correct``, ``total``, ``accuracy``,
        ``total_time``, ``avg_time_per_problem``, ``total_gen_tokens``,
        ``timeouts`` and ``errors``. Timed-out and failed problems add
        nothing to ``total_time`` but still count towards ``total``.
    """
    with open(dataset_path, encoding="utf-8") as f:
        problems = json.load(f)[:limit]

    total = len(problems)
    correct = 0
    timeouts = 0
    errors = 0
    total_time = 0.0
    total_gen_tokens = 0
    print(f" [{task_name}] Running {total} problems (max_tokens={max_tokens})...")
    for idx, prob in enumerate(problems, start=1):
        question = prob["problem"].replace("\n", " ")
        try:
            # The answer may be a bare JSON number; the ":>8s" format below
            # and normalize_num both need a string.
            gold_raw = str(prob["answer"])
            resp, elapsed = run_chat(model_dir, question, max_tokens,
                                     max_seq_len, tp, SYSTEM_PROMPT_MATH)
            total_time += elapsed
            total_gen_tokens += count_tokens_approx(resp)
            pred = extract_answer(resp)
            gold = normalize_num(gold_raw)
            is_correct = pred is not None and gold is not None and pred == gold
            if is_correct:
                correct += 1
            # Both branches were the empty string, so passes and failures
            # printed identically as "[]".
            mark = "✓" if is_correct else "✗"
            print(f" [{mark}] {idx:3d}/{total} gold={gold_raw:>8s} pred={str(pred):>8s} {elapsed:.1f}s")
        except subprocess.TimeoutExpired:
            timeouts += 1
            print(f" [T] {idx:3d}/{total} TIMEOUT")
        except Exception as e:
            # Per-problem boundary: report and continue with the next problem.
            errors += 1
            print(f" [E] {idx:3d}/{total} {e}")

    accuracy = correct / total if total > 0 else 0
    avg_time = total_time / total if total > 0 else 0
    return {
        "task": task_name, "correct": correct, "total": total,
        "accuracy": accuracy, "total_time": total_time,
        "avg_time_per_problem": avg_time, "total_gen_tokens": total_gen_tokens,
        "timeouts": timeouts, "errors": errors,
    }
def _median_upper(values):
    """Return the middle element of ``values`` (the upper one for even lengths)."""
    ordered = sorted(values)
    return ordered[len(ordered) // 2]


def run_perf(model_dir, tp, n_runs=5):
    """Estimate TTFT and TPOT from repeated end-to-end chat invocations.

    Every run launches a new chat process, so each measured time covers the
    whole invocation and not only token generation. TTFT is taken as the
    median time of ``max_tokens=1`` runs. TPOT is approximated by subtracting
    that median from the median time of 128-token runs and dividing by the
    estimated number of generated tokens minus one. The three medians are
    picked independently, so they need not come from the same run.

    Args:
        model_dir: Model directory handed to the chat binary.
        tp: Tensor-parallel degree.
        n_runs: Number of repetitions for each of the two measurements.

    Returns:
        A dict with ``median_ttft_s``, ``median_gen128_s``,
        ``median_tokens``, ``approx_decode_time_s``, ``approx_tpot_ms`` and
        ``approx_tok_per_s``.

    Raises:
        ValueError: ``n_runs`` is smaller than 1.
    """
    # Without this guard the medians below would index into empty lists.
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")

    # TTFT: time a run that generates a single token.
    ttft_times = []
    for i in range(n_runs):
        _, elapsed = run_chat(model_dir, PERF_PROMPT, 1, 2048, tp, None)
        ttft_times.append(elapsed)
        print(f" TTFT run {i+1}: {elapsed:.3f}s")

    # Generation: time a 128-token run and estimate how many tokens came back.
    gen_times = []
    gen_tokens_list = []
    for i in range(n_runs):
        resp, elapsed = run_chat(model_dir, PERF_PROMPT, 128, 2048, tp, None)
        tokens = count_tokens_approx(resp)
        gen_tokens_list.append(tokens)
        gen_times.append(elapsed)
        print(f" Gen run {i+1}: {elapsed:.3f}s, ~{tokens} tokens")

    median_ttft = _median_upper(ttft_times)
    median_gen = _median_upper(gen_times)
    median_tokens = _median_upper(gen_tokens_list)

    # Both medians include the same per-process overhead, so their
    # difference approximates the time spent decoding after the first token.
    approx_decode_time = median_gen - median_ttft
    decode_tokens = max(median_tokens - 1, 1)
    approx_tpot = approx_decode_time / decode_tokens
    return {
        "median_ttft_s": median_ttft,
        "median_gen128_s": median_gen,
        "median_tokens": median_tokens,
        "approx_decode_time_s": approx_decode_time,
        "approx_tpot_ms": approx_tpot * 1000,
        # The 1 ms floor keeps the division finite when the difference is ~0.
        "approx_tok_per_s": decode_tokens / max(approx_decode_time, 0.001),
    }
def main():
    """Benchmark the FP8 and BF16 models and print and save a comparison.

    For each model this sets ``CUDA_VISIBLE_DEVICES`` to the configured GPU
    list, runs the accuracy tasks (unless ``--skip-accuracy``) and the
    latency measurement (unless ``--skip-perf``), then prints a side-by-side
    summary and writes all results to a timestamped JSON file under
    ``bench-out`` in the directory above this script.

    Exits through ``parser.error`` when a benchmark is requested but the
    xserv-chat binary is missing.
    """
    parser = argparse.ArgumentParser(description="FP8 vs BF16 benchmark")
    parser.add_argument("--fp8", required=True, help="FP8 model directory")
    parser.add_argument("--bf16", required=True, help="BF16 model directory")
    parser.add_argument("--fp8-tp", type=int, default=1, help="TP for FP8 model")
    parser.add_argument("--bf16-tp", type=int, default=2, help="TP for BF16 model")
    parser.add_argument("--fp8-gpu", type=str, default="2", help="GPU for FP8")
    parser.add_argument("--bf16-gpu", type=str, default="4,5", help="GPUs for BF16")
    parser.add_argument("--gsm8k-limit", type=int, default=100, help="GSM8K problems")
    parser.add_argument("--aime-limit", type=int, default=30, help="AIME2025 problems")
    parser.add_argument("--skip-perf", action="store_true")
    parser.add_argument("--skip-accuracy", action="store_true")
    args = parser.parse_args()

    # Fail fast: without the binary every accuracy problem would be logged
    # as an error and reported as 0% accuracy.
    needs_binary = not (args.skip_accuracy and args.skip_perf)
    if needs_binary and not XSERV_CHAT.is_file():
        parser.error(f"xserv-chat binary not found at {XSERV_CHAT}")

    results = {}
    for label, model_dir, tp, gpu in [
        ("FP8_W8A8", args.fp8, args.fp8_tp, args.fp8_gpu),
        ("BF16", args.bf16, args.bf16_tp, args.bf16_gpu),
    ]:
        # Child processes inherit the environment, so this selects the GPUs
        # used by every chat invocation for this model.
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        print(f"\n{'='*72}")
        print(f" Model: {label} (tp={tp}, gpu={gpu})")
        print(f" Path: {model_dir}")
        print(f"{'='*72}")
        results[label] = {}

        if not args.skip_accuracy:
            print("\n --- Accuracy ---")
            r_gsm = run_accuracy(model_dir, str(GSM8K_PATH), "gsm8k", args.gsm8k_limit, tp, 512)
            results[label]["gsm8k"] = r_gsm
            print(f" GSM8K: {r_gsm['correct']}/{r_gsm['total']} = {r_gsm['accuracy']*100:.1f}%")
            r_aime = run_accuracy(model_dir, str(AIME_PATH), "aime2025", args.aime_limit, tp, 2048)
            results[label]["aime2025"] = r_aime
            print(f" AIME2025: {r_aime['correct']}/{r_aime['total']} = {r_aime['accuracy']*100:.1f}%")

        if not args.skip_perf:
            print("\n --- Performance ---")
            perf = run_perf(model_dir, tp, n_runs=5)
            results[label]["perf"] = perf
            print(f" TTFT (median): {perf['median_ttft_s']:.3f}s")
            print(f" TPOT (approx): {perf['approx_tpot_ms']:.1f}ms")
            print(f" Throughput: {perf['approx_tok_per_s']:.1f} tok/s")

    # Final comparison table
    print(f"\n{'='*72}")
    print(" COMPARISON SUMMARY")
    print(f"{'='*72}")
    print(f"{'Metric':<30s} {'FP8_W8A8':>12s} {'BF16':>12s}")
    print("-" * 56)
    if not args.skip_accuracy:
        for task in ["gsm8k", "aime2025"]:
            if task in results.get("FP8_W8A8", {}) and task in results.get("BF16", {}):
                fp8_acc = results["FP8_W8A8"][task]["accuracy"] * 100
                bf16_acc = results["BF16"][task]["accuracy"] * 100
                print(f"{task + ' accuracy':<30s} {fp8_acc:>11.1f}% {bf16_acc:>11.1f}%")
    if not args.skip_perf:
        if "perf" in results.get("FP8_W8A8", {}) and "perf" in results.get("BF16", {}):
            fp8_p = results["FP8_W8A8"]["perf"]
            bf16_p = results["BF16"]["perf"]
            print(f"{'TTFT (s)':<30s} {fp8_p['median_ttft_s']:>12.3f} {bf16_p['median_ttft_s']:>12.3f}")
            print(f"{'TPOT (ms)':<30s} {fp8_p['approx_tpot_ms']:>12.1f} {bf16_p['approx_tpot_ms']:>12.1f}")
            print(f"{'Throughput (tok/s)':<30s} {fp8_p['approx_tok_per_s']:>12.1f} {bf16_p['approx_tok_per_s']:>12.1f}")
            # The 0.1 floor guards the ratio against a zero BF16 throughput.
            speedup = fp8_p['approx_tok_per_s'] / max(bf16_p['approx_tok_per_s'], 0.1)
            print(f"{'Decode speedup':<30s} {speedup:>12.2f}x {'(baseline)':>12s}")
    print(f"\n{'='*72}")

    # Save results
    out_path = SCRIPT_DIR.parent / "bench-out" / f"fp8_bench_{int(time.time())}.json"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"Results saved to: {out_path}")


if __name__ == "__main__":
    main()