#!/usr/bin/env python3
"""Compare FP8-W8A8 vs BF16 gpt-oss on one box: GSM8K accuracy + TTFT/TPOT.

For each model it launches a warm xserv-server (same GPUs / same TP for a fair
compute comparison), waits for a *real* generation to succeed (not /health),
then streams N GSM8K problems through /v1/chat/completions measuring
per-request TTFT (time to first token) and TPOT (mean inter-token latency).
Accuracy is the exact-match rate on the extracted final number.

Run it ON the GPU box (it manages the servers itself):
    python3 tools/fp8_compare.py \
        --fp8 /opt/wjh/models/gpt-oss-20b-fp8 \
        --bf16 /opt/wjh/models/gpt-oss-20b-bf16 \
        --gpus 0,1 --tp 2 --limit 150 --max-tokens 512
"""

import argparse
import json
import os
import re
import signal
import subprocess
import sys
import time
import urllib.request
import urllib.error
from pathlib import Path

SCRIPT_DIR = Path(__file__).parent
GSM8K = SCRIPT_DIR / "bench" / "data" / "gsm8k.json"
SERVER_BIN = SCRIPT_DIR.parent / "target" / "release" / "xserv-server"

# System prompt sent with every benchmark question; it asks for the final
# answer in \boxed{} so extract_answer() can find it reliably.
SYSTEM = ("You are a careful math problem solver. Solve the problem step by step. "
          "Put your final numeric answer inside \\boxed{}.")

# Innermost \boxed{...} group (no nested braces inside the capture).
_BOXED_RE = re.compile(r"\\boxed\s*\{([^{}]*)\}")
# Optionally signed number with optional thousands separators and decimals.
_NUM_RE = re.compile(r"-?\d+(?:,\d{3})*(?:\.\d+)?")


def normalize_num(s):
    """Return a canonical string for a numeric value, or None if not numeric.

    Thousands separators are dropped and integral values lose their decimal
    part ("42.0" -> "42"), so predictions and gold answers compare equal as
    plain strings. Non-string inputs (e.g. an int gold answer loaded from
    JSON) are accepted; non-finite values ("inf", "nan") yield None.
    """
    s = str(s).replace(",", "").strip()
    try:
        f = float(s)
        if f == int(f):
            return str(int(f))
    except (ValueError, OverflowError):
        # ValueError: not a number (or nan); OverflowError: infinity.
        return None
    # 12 significant digits: the default "g" keeps only 6, which would make
    # e.g. 12345.67 and 12345.71 compare equal and inflate exact-match.
    return f"{f:.12g}"


def extract_answer(text):
    """Extract the final numeric answer from a model response.

    Prefers the last number inside the last ``\\boxed{...}``; if there is no
    box, or the box holds no number, falls back to the last number anywhere
    in the text. Returns the normalized number string, or None.
    """
    if not text:
        return None
    boxed = _BOXED_RE.findall(text)
    if boxed:
        nums = _NUM_RE.findall(boxed[-1])
        if nums:
            return normalize_num(nums[-1])
    nums = _NUM_RE.findall(text)
    if nums:
        return normalize_num(nums[-1])
    return None


def pct(vals, p):
    """Return the nearest-rank p-th percentile of vals (0.0 when empty)."""
    if not vals:
        return 0.0
    s = sorted(vals)
    i = max(0, min(len(s) - 1, int(round((p / 100.0) * (len(s) - 1)))))
    return s[i]


# ---------- server lifecycle ----------

def gpu_mem_used_mb(gpus):
    """Return the highest memory.used (MiB) among the given GPU indices.

    Queries ``nvidia-smi``. GPUs that are not listed count as 0, rows that
    cannot be parsed (e.g. "[N/A]") are skipped, and an empty ``gpus`` gives 0.

    Raises:
        OSError: nvidia-smi could not be executed.
        subprocess.CalledProcessError: nvidia-smi exited non-zero.
    """
    out = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=index,memory.used",
         "--format=csv,noheader,nounits"],
        text=True)
    used = {}
    for line in out.strip().splitlines():
        fields = [x.strip() for x in line.split(",")]
        if len(fields) != 2:
            continue
        try:
            used[int(fields[0])] = int(fields[1])
        except ValueError:
            continue
    return max((used.get(g, 0) for g in gpus), default=0)


def start_server(model_dir, port, tp, gpus, log_path):
    """Launch xserv-server for model_dir and return its Popen handle.

    The server only sees ``gpus`` (via CUDA_VISIBLE_DEVICES), writes
    stdout+stderr to ``log_path`` and runs in its own session so the whole
    process tree can be signalled without touching this script.
    """
    env = dict(os.environ)
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpus)
    cmd = [str(SERVER_BIN), str(model_dir),
           "--port", str(port), "--tp", str(tp),
           "--max-seq-len", "2048", "--max-batch", "8"]
    # The child inherits its own copy of the descriptor, so the parent's
    # handle can be closed as soon as Popen returns (no leaked file object).
    with open(log_path, "wb") as logf:
        # New session so we can kill the whole process tree without touching ours.
        p = subprocess.Popen(cmd, stdout=logf, stderr=subprocess.STDOUT,
                             env=env, start_new_session=True)
    return p


def _signal_group(p, sig):
    """Send sig to p's process group, ignoring an already-gone group."""
    try:
        os.killpg(os.getpgid(p.pid), sig)
    except ProcessLookupError:
        pass


def stop_server(p, gpus, drain_to_mb=2000, timeout=120):
    """Stop the server process group, then wait for GPU memory to drain.

    Sends SIGTERM, escalates to SIGKILL after 30 s, and reaps the child.
    Afterwards polls nvidia-smi until the busiest of ``gpus`` uses less than
    ``drain_to_mb`` MiB or ``timeout`` seconds pass. Best effort: a failing
    memory query or a drain timeout only prints a warning.
    """
    if p.poll() is None:
        _signal_group(p, signal.SIGTERM)
        try:
            p.wait(timeout=30)
        except subprocess.TimeoutExpired:
            _signal_group(p, signal.SIGKILL)
            try:
                # Reap the child so it does not linger as a zombie.
                p.wait(timeout=30)
            except subprocess.TimeoutExpired:
                pass
    # Wait for VRAM to drain so the next server starts clean.
    t0 = time.time()
    while time.time() - t0 < timeout:
        try:
            if gpu_mem_used_mb(gpus) < drain_to_mb:
                return
        except (OSError, subprocess.CalledProcessError) as e:
            print(f" [W] cannot query GPU memory ({e}); skipping VRAM drain wait",
                  flush=True)
            return
        time.sleep(2)
    print(f" [W] VRAM still above {drain_to_mb} MB after {timeout}s", flush=True)


def wait_ready(base, model_id, timeout=900, proc=None):
    """Gate on a real 1-token generation, not /health (which lies during load).

    Args:
        base: Server base URL, e.g. "http://127.0.0.1:18080".
        model_id: Model name sent in the request body.
        timeout: Overall budget in seconds.
        proc: Optional Popen of the server; when given and the process has
            exited, polling stops immediately instead of burning the budget.

    Returns:
        True once a non-streaming completion returns HTTP 200 with a JSON
        body; False on timeout or if ``proc`` died.
    """
    t0 = time.time()
    body = json.dumps({
        "model": model_id,
        "messages": [{"role": "user", "content": "hi"}],
        "max_tokens": 1,
        "temperature": 0.0,
        "stream": False,
    }).encode()
    while time.time() - t0 < timeout:
        if proc is not None and proc.poll() is not None:
            return False
        try:
            req = urllib.request.Request(
                base + "/v1/chat/completions", data=body,
                headers={"Content-Type": "application/json"})
            with urllib.request.urlopen(req, timeout=120) as r:
                if r.status == 200:
                    json.loads(r.read())
                    return True
        except Exception:
            # Connection refused, HTTP errors, timeouts and bad JSON are all
            # expected while the model is still loading.
            pass
        time.sleep(3)
    return False


# ---------- one streamed request ----------

def stream_chat(base, model_id, user, max_tokens):
    """Stream one chat completion and time it.

    Returns:
        (text, ttft, tpot, n): the concatenated ``delta.content`` pieces,
        seconds until the first content chunk, mean seconds between content
        chunks after the first (0.0 with fewer than two chunks), and the
        number of content chunks. ``n`` counts SSE content chunks, which
        equals tokens only if the server emits one token per chunk.
        If no content arrives, ttft is the total request time.

    Raises:
        urllib.error.URLError / OSError: on connection or HTTP failure.
    """
    body = json.dumps({
        "model": model_id,
        "messages": [{"role": "system", "content": SYSTEM},
                     {"role": "user", "content": user}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
        "stream": True,
    }).encode()
    req = urllib.request.Request(
        base + "/v1/chat/completions", data=body,
        headers={"Content-Type": "application/json"})
    t0 = time.perf_counter()
    ttft = None
    t_last = t0
    n = 0
    parts = []
    with urllib.request.urlopen(req, timeout=300) as resp:
        for raw in resp:
            line = raw.decode("utf-8", "ignore").strip()
            if not line.startswith("data:"):
                continue
            data = line[5:].strip()
            if data == "[DONE]":
                break
            try:
                obj = json.loads(data)
            except json.JSONDecodeError:
                continue
            if not isinstance(obj, dict):
                continue
            # Some chunks (e.g. a trailing usage record) carry no choices;
            # skip them instead of failing the whole request.
            choices = obj.get("choices") or []
            if not choices:
                continue
            delta = choices[0].get("delta") or {}
            content = delta.get("content")
            if content:
                now = time.perf_counter()
                if ttft is None:
                    ttft = now - t0
                n += 1
                t_last = now
                parts.append(content)
    ttft = ttft if ttft is not None else (time.perf_counter() - t0)
    decode_span = t_last - t0 - ttft
    tpot = decode_span / (n - 1) if n > 1 else 0.0
    return "".join(parts), ttft, tpot, n


def run_eval(base, model_id, problems, max_tokens):
    """Run every problem through the server and aggregate the metrics.

    Each problem is a mapping with "problem" and "answer" keys. Requests that
    raise are logged, counted under "errors" and excluded from accuracy and
    latency statistics.

    Returns:
        Dict with accuracy (over scored problems), correct/scored/errors
        counts, TTFT and TPOT median/p90 in ms, median tok/s and mean
        content-chunk count per response.
    """
    correct = 0
    ttfts, tpots, toks = [], [], []
    n_scored = 0
    n_errors = 0
    total = len(problems)
    for i, prob in enumerate(problems):
        q = prob["problem"].replace("\n", " ")
        try:
            text, ttft, tpot, n = stream_chat(base, model_id, q, max_tokens)
        except Exception as e:
            n_errors += 1
            print(f" [E] {i+1}/{total} {e}", flush=True)
            continue
        pred = extract_answer(text)
        gold = normalize_num(prob["answer"])
        ok = pred is not None and gold is not None and pred == gold
        correct += int(ok)
        n_scored += 1
        ttfts.append(ttft * 1000.0)
        if tpot > 0:
            tpots.append(tpot * 1000.0)
        toks.append(n)
        mark = "✓" if ok else "✗"
        # str(): the gold answer may be stored as a JSON number.
        print(f" [{mark}] {i+1:3d}/{total} gold={str(prob['answer']):>7s} "
              f"pred={str(pred):>7s} ttft={ttft*1000:6.1f}ms tpot={tpot*1000:5.1f}ms tok={n}",
              flush=True)
    tpot_median = pct(tpots, 50)
    return {
        "accuracy": correct / max(n_scored, 1),
        "correct": correct,
        "scored": n_scored,
        "errors": n_errors,
        "ttft_ms_median": pct(ttfts, 50),
        "ttft_ms_p90": pct(ttfts, 90),
        "tpot_ms_median": tpot_median,
        "tpot_ms_p90": pct(tpots, 90),
        "tok_per_s_median": (1000.0 / tpot_median) if tpot_median > 0 else 0.0,
        "mean_tokens": sum(toks) / max(len(toks), 1),
    }


def main():
    """Benchmark the FP8 and BF16 models back to back and print a summary."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--fp8", required=True)
    ap.add_argument("--bf16", required=True)
    ap.add_argument("--gpus", default="0,1")
    ap.add_argument("--tp", type=int, default=2)
    ap.add_argument("--limit", type=int, default=150)
    ap.add_argument("--max-tokens", type=int, default=512)
    ap.add_argument("--port", type=int, default=18080)
    ap.add_argument("--out", default=None)
    args = ap.parse_args()

    try:
        gpus = [int(g) for g in args.gpus.split(",") if g.strip()]
    except ValueError:
        ap.error(f"--gpus must be a comma-separated list of integers, got {args.gpus!r}")
    if not gpus:
        ap.error("--gpus must name at least one GPU")

    with open(GSM8K) as f:
        problems = json.load(f)[:args.limit]
    base = f"http://127.0.0.1:{args.port}"
    results = {}

    for label, model_dir in [("FP8_W8A8", args.fp8), ("BF16", args.bf16)]:
        model_id = Path(model_dir).name
        log_path = f"/tmp/xserv_{label}.log"
        print(f"\n{'='*72}\n {label} ({model_dir}, tp={args.tp}, gpus={gpus})\n{'='*72}",
              flush=True)
        print(f" starting server (log: {log_path}) ...", flush=True)
        p = start_server(model_dir, args.port, args.tp, gpus, log_path)
        try:
            if not wait_ready(base, model_id, proc=p):
                print(" SERVER NOT READY — tail of log:", flush=True)
                print(subprocess.run(["tail", "-30", log_path],
                                     capture_output=True, text=True).stdout)
                # The finally clause below stops the server exactly once.
                continue
            print(f" ready. running {len(problems)} GSM8K problems...", flush=True)
            t0 = time.time()
            r = run_eval(base, model_id, problems, args.max_tokens)
            r["wall_s"] = time.time() - t0
            results[label] = r
            print(f" -> acc={r['accuracy']*100:.1f}% ttft_med={r['ttft_ms_median']:.1f}ms "
                  f"tpot_med={r['tpot_ms_median']:.1f}ms ({r['tok_per_s_median']:.1f} tok/s)",
                  flush=True)
        finally:
            print(" stopping server...", flush=True)
            stop_server(p, gpus)

    # n is the number of problems actually loaded, which can be below --limit.
    print(f"\n{'='*72}\n SUMMARY (gpt-oss-20b, tp={args.tp}, GSM8K n={len(problems)})\n{'='*72}")
    print(f"{'metric':<26s} {'FP8_W8A8':>14s} {'BF16':>14s}")
    print("-" * 56)
    f8, b6 = results.get("FP8_W8A8", {}), results.get("BF16", {})

    def row(name, key, fmt, scale=1.0):
        # Only print a metric when both runs produced it.
        a = f8.get(key)
        b = b6.get(key)
        if a is None or b is None:
            return
        print(f"{name:<26s} {fmt.format(a*scale):>14s} {fmt.format(b*scale):>14s}")

    row("GSM8K accuracy (%)", "accuracy", "{:.1f}", 100.0)
    row("TTFT median (ms)", "ttft_ms_median", "{:.1f}")
    row("TTFT p90 (ms)", "ttft_ms_p90", "{:.1f}")
    row("TPOT median (ms)", "tpot_ms_median", "{:.2f}")
    row("TPOT p90 (ms)", "tpot_ms_p90", "{:.2f}")
    row("Throughput (tok/s)", "tok_per_s_median", "{:.1f}")
    row("Mean output tokens", "mean_tokens", "{:.0f}")
    if f8 and b6 and b6.get("tpot_ms_median"):
        sp = b6["tpot_ms_median"] / f8["tpot_ms_median"] if f8.get("tpot_ms_median") else 0
        print(f"\n FP8 decode speedup vs BF16: {sp:.2f}x")

    out = args.out or f"/tmp/fp8_compare_{int(time.time())}.json"
    with open(out, "w") as f:
        json.dump({"args": vars(args), "results": results}, f, indent=2)
    print(f"\n saved: {out}")


if __name__ == "__main__":
    main()