From 24c49c31c226d734a8eade9813a35bc41c55675d Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 12 Jun 2026 00:58:46 +0800 Subject: [PATCH] tools: warm-server FP8 vs BF16 benchmark + results doc fp8_compare.py launches one xserv-server per model (same GPUs / TP for a fair comparison), gates readiness on a real generation (not /health), and streams GSM8K through /v1/chat/completions measuring per-request TTFT (time to first token) and TPOT (mean inter-token latency) plus exact-match accuracy. docs/benchmarks/fp8-quantization.md records the quantization scheme, the perf-bug fix, and the dash5 results. Co-Authored-By: Claude Opus 4.8 --- docs/benchmarks/fp8-quantization.md | 83 ++++++++ tools/fp8_compare.py | 283 ++++++++++++++++++++++++++++ 2 files changed, 366 insertions(+) create mode 100644 docs/benchmarks/fp8-quantization.md create mode 100644 tools/fp8_compare.py diff --git a/docs/benchmarks/fp8-quantization.md b/docs/benchmarks/fp8-quantization.md new file mode 100644 index 0000000..30da244 --- /dev/null +++ b/docs/benchmarks/fp8-quantization.md @@ -0,0 +1,83 @@ +# FP8 W8A8 quantization — gpt-oss-20b (dash5, 8× RTX 5090) + +Operator-level FP8 E4M3 quantization of the MoE expert weights, with real +cuBLASLt FP8 tensor-core GEMM (W8A8: FP8 weights × dynamically-quantized FP8 +activations). All other tensors (attention, router, embeddings, norms, biases) +stay BF16. + +## Scheme + +- **Weights** (`tools/quantize_fp8.py`): expert `gate_up_proj` / `down_proj` + quantized BF16 → FP8 E4M3 with a **per-expert scalar** scale (`absmax/448`). + Stored transposed `[E, N, K]` because cuBLASLt FP8 on Blackwell (sm120) + requires `transA=T`. +- **Activations**: quantized dynamically at runtime, **per-token** (per-row + absmax), recovered by a post-GEMM row scale. 
+- **Compute**: `batched_gemm_fp8` (`crates/xserv-kernels/src/quantization.rs`) + runs one cuBLASLt FP8 matmul per expert; the per-expert weight scale is + supplied via the cuBLASLt B-scale device pointer (FP32 epilogue, so precision + matches folding it into `alpha`). +- Model size: **22 GB** (FP8) vs **39 GB** (BF16). The FP8 model fits on a + single 32 GB 5090; BF16 needs ≥ 2. + +## The performance bug that was fixed + +`batched_gemm_fp8` originally rebuilt the entire cuBLASLt plan **per expert, +per GEMM, per layer, on every forward pass** — running the algo heuristic +search, creating/destroying the descriptor + 4 layouts + preference, and +`cudaMalloc`-ing a 4-byte scale buffer — roughly 1500 heuristic searches per +decoded token. This made FP8 **slower than BF16**: + +| | FP8 (buggy) | FP8 (fixed) | BF16 | +|---|---|---|---| +| Decode TPOT | 27.0 ms | **17.9 ms** | 18.8 ms | +| Throughput | 37 tok/s | **55.8 tok/s** | 53.2 tok/s | + +Fix: cache the cuBLASLt plan (descriptor + layouts + heuristically-chosen algo) +in a thread-local map keyed by `(M, N, K)` so the heuristic runs once per shape; +allocate the scale buffer once; pass per-expert weight scales by device pointer. +The per-expert loop now issues only `cublasLtMatmul`. + +## Results — GSM8K (200 problems, greedy, TP=2 on the same 2 GPUs) + +Harness: `tools/fp8_compare.py` — a warm `xserv-server` per model, GSM8K streamed +through `/v1/chat/completions`; TTFT = time to first token, TPOT = mean +inter-token latency, per request. + +| metric | FP8 W8A8 | BF16 | +|---|---|---| +| GSM8K accuracy | **93.0 %** | 90.5 % | +| TTFT median | 67.4 ms | 68.8 ms | +| TTFT p90 | 90.4 ms | 96.7 ms | +| TPOT median | **17.45 ms** | 18.26 ms | +| TPOT p90 | 17.65 ms | 18.38 ms | +| Throughput | **57.3 tok/s** | 54.8 tok/s | +| Mean output tokens | 288 | 293 | + +- **Accuracy: unchanged.** FP8 is nominally +2.5 pts, but with n=200 the + standard error is ~2.1 pts, so the two are statistically indistinguishable. 
+ The takeaway is that FP8 did **not** degrade accuracy. +- **Decode: FP8 ~5 % faster** (TPOT 17.45 vs 18.26 ms), reproducible across + runs, with a tighter p90. Modest because the dense-MoE path loads *all* + experts every token and FP8 only halves the *expert* bytes; the per-expert + M=1 launches and M=1 tensor-core inefficiency absorb much of the bandwidth + saving. +- **Prefill (TTFT): comparable.** A multi-length sweep (113 / 561 / 1681 tokens) + gave FP8 480 / 362 / 2451 ms vs BF16 558 / 282 / 2287 ms — non-monotonic, i.e. + dominated by fixed overhead (cuBLAS lazy init + FP8's one-time per-shape + heuristic), not prefill compute, at these lengths. + +## Single-GPU (TP=1) + +FP8 runs gpt-oss-20b on **one** 5090 (`bench-gpt-oss --tp 1`, GPU6): TTFT 538 ms, +TPOT 29.0 ms, 34.5 tok/s. BF16 cannot (39 GB > 32 GB). This — fitting a model +that otherwise needs two GPUs onto one — is the largest practical win. + +## Follow-ups (not done) + +- Strided-batched FP8 (one call instead of ~768 per-expert launches per token) — + requires folding the per-expert weight scale into the post-scale kernel, at a + BF16-intermediate precision cost. +- Per-channel (per-output-row) weight scales for better accuracy headroom than + per-tensor. +- Warm common prefill shapes at load to hide the first-request heuristic stall. diff --git a/tools/fp8_compare.py b/tools/fp8_compare.py new file mode 100644 index 0000000..21b2407 --- /dev/null +++ b/tools/fp8_compare.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +"""Compare FP8-W8A8 vs BF16 gpt-oss on one box: GSM8K accuracy + TTFT/TPOT. + +For each model it launches a warm xserv-server (same GPUs / same TP for a fair +compute comparison), waits for a *real* generation to succeed (not /health), +then streams N GSM8K problems through /v1/chat/completions measuring per-request +TTFT (time to first token) and TPOT (mean inter-token latency). Accuracy is the +exact-match rate on the extracted final number. 
# Usage (tail of the module docstring) -- run it ON the GPU box; the script
# starts and stops the servers itself:
#
#     python3 tools/fp8_compare.py \
#         --fp8 /opt/wjh/models/gpt-oss-20b-fp8 \
#         --bf16 /opt/wjh/models/gpt-oss-20b-bf16 \
#         --gpus 0,1 --tp 2 --limit 150 --max-tokens 512

import argparse
import json
import math
import os
import re
import signal
import subprocess
import sys
import time
import urllib.error
import urllib.request
from pathlib import Path

SCRIPT_DIR = Path(__file__).parent
GSM8K = SCRIPT_DIR / "bench" / "data" / "gsm8k.json"
SERVER_BIN = SCRIPT_DIR.parent / "target" / "release" / "xserv-server"
SYSTEM = ("You are a careful math problem solver. Solve the problem step by step. "
          "Put your final numeric answer inside \\boxed{}.")

# Innermost \boxed{...} payload (no nested braces).
_BOXED_RE = re.compile(r"\\boxed\s*\{([^{}]*)\}")
# Optionally signed number with optional thousands separators / decimals.
_NUM_RE = re.compile(r"-?\d+(?:,\d{3})*(?:\.\d+)?")


def normalize_num(s):
    """Canonicalize a number so predictions and gold answers compare as strings.

    Commas are removed, integral values are printed without a fractional part
    ("3.0" -> "3"), and everything else is formatted with ``%g``.

    Returns None when *s* is None or does not parse as a finite number.
    """
    if s is None:
        return None
    # str() lets gold answers stored as JSON numbers go through the same path.
    text = str(s).replace(",", "").strip()
    try:
        value = float(text)
    except ValueError:
        return None
    # float() happily parses "nan"/"inf", but int() on them raises.
    if not math.isfinite(value):
        return None
    return str(int(value)) if value == int(value) else f"{value:g}"


def extract_answer(text):
    """Pull the final numeric answer out of a model response.

    The last number inside the last ``\\boxed{...}`` wins; if there is no boxed
    number, the last number anywhere in *text* is used. Returns the normalized
    number string, or None when *text* is empty or contains no number.
    """
    if not text:
        return None
    boxed = _BOXED_RE.findall(text)
    if boxed:
        nums = _NUM_RE.findall(boxed[-1])
        if nums:
            return normalize_num(nums[-1])
    nums = _NUM_RE.findall(text)
    if nums:
        return normalize_num(nums[-1])
    return None


def pct(vals, p):
    """Return the *p*-th percentile of *vals* by nearest rank (0.0 if empty)."""
    if not vals:
        return 0.0
    ordered = sorted(vals)
    rank = int(round((p / 100.0) * (len(ordered) - 1)))
    return ordered[max(0, min(len(ordered) - 1, rank))]


# ---------- server lifecycle ----------

def gpu_mem_used_mb(gpus):
    """Return the highest ``memory.used`` (MiB) nvidia-smi reports among *gpus*.

    GPUs missing from the nvidia-smi output count as 0, as does an empty
    *gpus* list. Rows that do not parse as two integers are skipped.

    Raises whatever ``subprocess.check_output`` raises if nvidia-smi is
    missing or exits non-zero.
    """
    out = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=index,memory.used", "--format=csv,noheader,nounits"],
        text=True)
    used = {}
    for line in out.strip().splitlines():
        fields = [x.strip() for x in line.split(",")]
        if len(fields) != 2:
            continue
        try:
            used[int(fields[0])] = int(fields[1])
        except ValueError:
            # e.g. "[N/A]" for a GPU that does not report memory usage.
            continue
    return max((used.get(g, 0) for g in gpus), default=0)


def start_server(model_dir, port, tp, gpus, log_path):
    """Launch xserv-server for *model_dir* on *gpus* and return the Popen handle.

    stdout and stderr of the server are written to *log_path*. The server runs
    in its own session so the whole process group can be signalled later
    without touching this script.
    """
    env = dict(os.environ)
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpus)
    # nvidia-smi (used for the VRAM drain check) numbers GPUs in PCI bus order;
    # make CUDA use the same numbering unless the caller chose otherwise.
    env.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
    cmd = [str(SERVER_BIN), str(model_dir), "--port", str(port),
           "--tp", str(tp), "--max-seq-len", "2048", "--max-batch", "8"]
    # The child inherits its own copy of the descriptor, so the parent's handle
    # can be closed as soon as the process has been spawned.
    with open(log_path, "wb") as logf:
        return subprocess.Popen(cmd, stdout=logf, stderr=subprocess.STDOUT,
                                env=env, start_new_session=True)


def _signal_group(p, sig):
    """Send *sig* to the process group of *p*, ignoring an already-gone group."""
    try:
        os.killpg(os.getpgid(p.pid), sig)
    except ProcessLookupError:
        pass


def stop_server(p, gpus, drain_to_mb=2000, timeout=120):
    """Stop the server process group, then wait for GPU memory to be released.

    SIGTERM is tried first; after 30 s without exit the group gets SIGKILL and
    the child is reaped. Afterwards the function polls nvidia-smi every 2 s
    until usage on *gpus* drops below *drain_to_mb* MiB or *timeout* seconds
    elapse. Cleanup is best-effort: if VRAM cannot be measured or does not
    drain in time, a warning is printed instead of raising.
    """
    if p.poll() is None:
        _signal_group(p, signal.SIGTERM)
        try:
            p.wait(timeout=30)
        except subprocess.TimeoutExpired:
            _signal_group(p, signal.SIGKILL)
            try:
                # Reap the child so it does not linger as a zombie.
                p.wait(timeout=30)
            except subprocess.TimeoutExpired:
                pass
    # Wait for VRAM to drain so the next server starts clean.
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if gpu_mem_used_mb(gpus) < drain_to_mb:
                return
        except (OSError, subprocess.CalledProcessError) as e:
            # This runs from a `finally`; do not mask the real error.
            print(f"  warning: cannot query GPU memory ({e}); skipping drain wait",
                  flush=True)
            return
        time.sleep(2)
    print(f"  warning: GPU memory still >= {drain_to_mb} MB after {timeout}s",
          flush=True)


def wait_ready(base, model_id, timeout=900, proc=None):
    """Gate on a real 1-token generation, not /health (which lies during load).

    Polls ``/v1/chat/completions`` about every 3 s until a request returns
    HTTP 200 with a JSON body, or *timeout* seconds pass.

    *proc* is optional: when given a Popen handle, polling stops as soon as
    that process has exited, instead of waiting out the whole timeout.

    Returns True when the server answered, False otherwise.
    """
    body = json.dumps({
        "model": model_id,
        "messages": [{"role": "user", "content": "hi"}],
        "max_tokens": 1, "temperature": 0.0, "stream": False,
    }).encode()
    deadline = time.time() + timeout
    while time.time() < deadline:
        if proc is not None and proc.poll() is not None:
            return False
        try:
            req = urllib.request.Request(base + "/v1/chat/completions", data=body,
                                         headers={"Content-Type": "application/json"})
            with urllib.request.urlopen(req, timeout=120) as r:
                if r.status == 200:
                    json.loads(r.read())
                    return True
        except Exception:
            # Deliberately broad: connection refused, HTTP errors, timeouts and
            # half-written bodies all just mean "not ready yet".
            pass
        time.sleep(3)
    return False


# ---------- one streamed request ----------

def parse_sse_line(raw):
    """Decode one raw server-sent-events line into ``(done, content)``.

    *done* is True only for the ``data: [DONE]`` terminator. *content* is the
    non-empty ``choices[0].delta.content`` string of the chunk, or None for
    anything else: non-``data:`` lines, invalid JSON, chunks with no choices,
    or deltas without text.
    """
    line = raw.decode("utf-8", "ignore").strip()
    if not line.startswith("data:"):
        return False, None
    data = line[5:].strip()
    if data == "[DONE]":
        return True, None
    try:
        obj = json.loads(data)
    except json.JSONDecodeError:
        return False, None
    choices = obj.get("choices") if isinstance(obj, dict) else None
    # Some chunks (e.g. usage-only ones) carry an empty `choices` list.
    if not choices or not isinstance(choices[0], dict):
        return False, None
    delta = choices[0].get("delta") or {}
    return False, (delta.get("content") or None)


def stream_chat(base, model_id, user, max_tokens):
    """Run one streamed chat completion and time it.

    Returns ``(text, ttft, tpot, n)``:
      * text -- concatenated content of all chunks
      * ttft -- seconds from sending the request to the first content chunk
                (the whole request duration if no content arrived)
      * tpot -- mean seconds between content chunks after the first,
                0.0 when fewer than two chunks arrived
      * n    -- number of content chunks (assumed to be one token each --
                TODO confirm against the server's streaming granularity)

    Network and HTTP errors propagate to the caller.
    """
    body = json.dumps({
        "model": model_id,
        "messages": [{"role": "system", "content": SYSTEM},
                     {"role": "user", "content": user}],
        "max_tokens": max_tokens, "temperature": 0.0, "stream": True,
    }).encode()
    req = urllib.request.Request(base + "/v1/chat/completions", data=body,
                                 headers={"Content-Type": "application/json"})
    t0 = time.perf_counter()
    t_first = None
    t_last = t0
    n = 0
    parts = []
    with urllib.request.urlopen(req, timeout=300) as resp:
        for raw in resp:
            done, content = parse_sse_line(raw)
            if done:
                break
            if content is None:
                continue
            now = time.perf_counter()
            if t_first is None:
                t_first = now
            n += 1
            t_last = now
            parts.append(content)
    if t_first is None:
        return "".join(parts), time.perf_counter() - t0, 0.0, n
    tpot = (t_last - t_first) / (n - 1) if n > 1 else 0.0
    return "".join(parts), t_first - t0, tpot, n


def run_eval(base, model_id, problems, max_tokens):
    """Stream every problem through the server and aggregate the metrics.

    Each problem is a dict with ``"problem"`` and ``"answer"`` keys. Requests
    that raise are logged, counted under ``"errors"`` and left out of both the
    accuracy denominator and the latency statistics.

    Returns a dict with accuracy, correct/scored/error counts, TTFT and TPOT
    median/p90 in milliseconds, median-TPOT throughput and mean chunk count.
    """
    correct = 0
    n_scored = 0
    n_errors = 0
    ttfts, tpots, toks = [], [], []
    total = len(problems)
    for i, prob in enumerate(problems):
        q = prob["problem"].replace("\n", " ")
        try:
            text, ttft, tpot, n = stream_chat(base, model_id, q, max_tokens)
        except Exception as e:
            n_errors += 1
            print(f"  [E] {i+1}/{total} {e}", flush=True)
            continue
        pred = extract_answer(text)
        gold = normalize_num(prob["answer"])
        ok = pred is not None and gold is not None and pred == gold
        correct += int(ok)
        n_scored += 1
        ttfts.append(ttft * 1000.0)
        if tpot > 0:
            tpots.append(tpot * 1000.0)
        toks.append(n)
        mark = "✓" if ok else "✗"
        # str(): the gold answer may be stored as a JSON number.
        print(f"  [{mark}] {i+1:3d}/{total}  gold={str(prob['answer']):>7s} "
              f"pred={str(pred):>7s}  ttft={ttft*1000:6.1f}ms tpot={tpot*1000:5.1f}ms tok={n}",
              flush=True)
    tpot_median = pct(tpots, 50)
    return {
        "accuracy": correct / max(n_scored, 1),
        "correct": correct, "scored": n_scored, "errors": n_errors,
        "ttft_ms_median": pct(ttfts, 50), "ttft_ms_p90": pct(ttfts, 90),
        "tpot_ms_median": tpot_median, "tpot_ms_p90": pct(tpots, 90),
        "tok_per_s_median": (1000.0 / tpot_median) if tpot_median > 0 else 0.0,
        "mean_tokens": sum(toks) / max(len(toks), 1),
    }


def _log_tail(log_path, n=30):
    """Return the last *n* lines of *log_path*, or a note if it is unreadable."""
    try:
        with open(log_path, "rb") as f:
            lines = f.read().decode("utf-8", "replace").splitlines()
    except OSError as e:
        return f"(could not read {log_path}: {e})"
    return "\n".join(lines[-n:])


def main():
    """Benchmark the FP8 and BF16 models back to back and print/save a summary.

    Returns 0 when both models produced results, 1 otherwise.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--fp8", required=True)
    ap.add_argument("--bf16", required=True)
    ap.add_argument("--gpus", default="0,1")
    ap.add_argument("--tp", type=int, default=2)
    ap.add_argument("--limit", type=int, default=150)
    ap.add_argument("--max-tokens", type=int, default=512)
    ap.add_argument("--port", type=int, default=18080)
    ap.add_argument("--out", default=None)
    args = ap.parse_args()

    try:
        gpus = [int(g) for g in args.gpus.split(",")]
    except ValueError:
        ap.error(f"--gpus must be a comma-separated list of integers, got {args.gpus!r}")
    if len(gpus) < args.tp:
        print(f"  warning: --tp {args.tp} but only {len(gpus)} GPU(s) in --gpus",
              flush=True)
    if not SERVER_BIN.exists():
        ap.error(f"server binary not found: {SERVER_BIN}")
    with open(GSM8K, encoding="utf-8") as f:
        problems = json.load(f)[:args.limit]

    base = f"http://127.0.0.1:{args.port}"
    results = {}
    for label, model_dir in [("FP8_W8A8", args.fp8), ("BF16", args.bf16)]:
        model_id = Path(model_dir).name
        log_path = f"/tmp/xserv_{label}.log"
        print(f"\n{'='*72}\n  {label}  ({model_dir}, tp={args.tp}, gpus={gpus})\n{'='*72}",
              flush=True)
        print(f"  starting server (log: {log_path}) ...", flush=True)
        p = start_server(model_dir, args.port, args.tp, gpus, log_path)
        try:
            if not wait_ready(base, model_id, proc=p):
                print("  SERVER NOT READY — tail of log:", flush=True)
                print(_log_tail(log_path), flush=True)
                # The `finally` below stops the server; no second call needed.
                continue
            print(f"  ready. running {len(problems)} GSM8K problems...", flush=True)
            t0 = time.time()
            r = run_eval(base, model_id, problems, args.max_tokens)
            r["wall_s"] = time.time() - t0
            results[label] = r
            print(f"  -> acc={r['accuracy']*100:.1f}%  ttft_med={r['ttft_ms_median']:.1f}ms  "
                  f"tpot_med={r['tpot_ms_median']:.1f}ms  ({r['tok_per_s_median']:.1f} tok/s)",
                  flush=True)
        finally:
            print("  stopping server...", flush=True)
            stop_server(p, gpus)

    # Report the number of problems actually loaded, which may be below --limit.
    print(f"\n{'='*72}\n  SUMMARY  (gpt-oss-20b, tp={args.tp}, GSM8K n={len(problems)})\n{'='*72}")
    print(f"{'metric':<26s} {'FP8_W8A8':>14s} {'BF16':>14s}")
    print("-" * 56)
    f8, b6 = results.get("FP8_W8A8", {}), results.get("BF16", {})

    def cell(res, key, fmt, scale):
        value = res.get(key)
        return "n/a" if value is None else fmt.format(value * scale)

    def row(name, key, fmt, scale=1.0):
        # Still show the side that did run when the other server failed.
        if f8.get(key) is None and b6.get(key) is None:
            return
        print(f"{name:<26s} {cell(f8, key, fmt, scale):>14s} {cell(b6, key, fmt, scale):>14s}")

    row("GSM8K accuracy (%)", "accuracy", "{:.1f}", 100.0)
    row("TTFT median (ms)", "ttft_ms_median", "{:.1f}")
    row("TTFT p90 (ms)", "ttft_ms_p90", "{:.1f}")
    row("TPOT median (ms)", "tpot_ms_median", "{:.2f}")
    row("TPOT p90 (ms)", "tpot_ms_p90", "{:.2f}")
    row("Throughput (tok/s)", "tok_per_s_median", "{:.1f}")
    row("Mean output tokens", "mean_tokens", "{:.0f}")
    if f8 and b6 and b6.get("tpot_ms_median"):
        sp = b6["tpot_ms_median"] / f8["tpot_ms_median"] if f8.get("tpot_ms_median") else 0
        print(f"\n  FP8 decode speedup vs BF16: {sp:.2f}x")

    out = args.out or f"/tmp/fp8_compare_{int(time.time())}.json"
    with open(out, "w", encoding="utf-8") as f:
        json.dump({"args": vars(args), "results": results}, f, indent=2)
    print(f"\n  saved: {out}")
    return 0 if len(results) == 2 else 1


if __name__ == "__main__":
    sys.exit(main())