tools: warm-server FP8 vs BF16 benchmark + results doc
fp8_compare.py launches one xserv-server per model (same GPUs / TP for a fair comparison), gates readiness on a real generation (not /health), and streams GSM8K through /v1/chat/completions measuring per-request TTFT (time to first token) and TPOT (mean inter-token latency) plus exact-match accuracy. docs/benchmarks/fp8-quantization.md records the quantization scheme, the perf-bug fix, and the dash5 results. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
83
docs/benchmarks/fp8-quantization.md
Normal file
83
docs/benchmarks/fp8-quantization.md
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
# FP8 W8A8 quantization — gpt-oss-20b (dash5, 8× RTX 5090)
|
||||||
|
|
||||||
|
Operator-level FP8 E4M3 quantization of the MoE expert weights, with real
|
||||||
|
cuBLASLt FP8 tensor-core GEMM (W8A8: FP8 weights × dynamically-quantized FP8
|
||||||
|
activations). All other tensors (attention, router, embeddings, norms, biases)
|
||||||
|
stay BF16.
|
||||||
|
|
||||||
|
## Scheme
|
||||||
|
|
||||||
|
- **Weights** (`tools/quantize_fp8.py`): expert `gate_up_proj` / `down_proj`
|
||||||
|
quantized BF16 → FP8 E4M3 with a **per-expert scalar** scale (`absmax/448`).
|
||||||
|
Stored transposed `[E, N, K]` because cuBLASLt FP8 on Blackwell (sm120)
|
||||||
|
requires `transA=T`.
|
||||||
|
- **Activations**: quantized dynamically at runtime, **per-token** (per-row
|
||||||
|
absmax), recovered by a post-GEMM row scale.
|
||||||
|
- **Compute**: `batched_gemm_fp8` (`crates/xserv-kernels/src/quantization.rs`)
|
||||||
|
runs one cuBLASLt FP8 matmul per expert; the per-expert weight scale is
|
||||||
|
supplied via the cuBLASLt B-scale device pointer (FP32 epilogue, so precision
|
||||||
|
matches folding it into `alpha`).
|
||||||
|
- Model size: **22 GB** (FP8) vs **39 GB** (BF16). The FP8 model fits on a
|
||||||
|
single 32 GB 5090; BF16 needs ≥ 2.
|
||||||
|
|
||||||
|
## The performance bug that was fixed
|
||||||
|
|
||||||
|
`batched_gemm_fp8` originally rebuilt the entire cuBLASLt plan **per expert,
|
||||||
|
per GEMM, per layer, on every forward pass** — running the algo heuristic
|
||||||
|
search, creating/destroying the descriptor + 4 layouts + preference, and
|
||||||
|
`cudaMalloc`-ing a 4-byte scale buffer — roughly 1500 heuristic searches per
|
||||||
|
decoded token. This made FP8 **slower than BF16**:
|
||||||
|
|
||||||
|
| | FP8 (buggy) | FP8 (fixed) | BF16 |
|
||||||
|
|---|---|---|---|
|
||||||
|
| Decode TPOT | 27.0 ms | **17.9 ms** | 18.8 ms |
|
||||||
|
| Throughput | 37 tok/s | **55.8 tok/s** | 53.2 tok/s |
|
||||||
|
|
||||||
|
Fix: cache the cuBLASLt plan (descriptor + layouts + heuristically-chosen algo)
|
||||||
|
in a thread-local map keyed by `(M, N, K)` so the heuristic runs once per shape;
|
||||||
|
allocate the scale buffer once; pass per-expert weight scales by device pointer.
|
||||||
|
The per-expert loop now issues only `cublasLtMatmul`.
|
||||||
|
|
||||||
|
## Results — GSM8K (200 problems, greedy, TP=2 on the same 2 GPUs)
|
||||||
|
|
||||||
|
Harness: `tools/fp8_compare.py` — a warm `xserv-server` per model, GSM8K streamed
|
||||||
|
through `/v1/chat/completions`; TTFT = time to first token, TPOT = mean
|
||||||
|
inter-token latency, per request.
|
||||||
|
|
||||||
|
| metric | FP8 W8A8 | BF16 |
|
||||||
|
|---|---|---|
|
||||||
|
| GSM8K accuracy | **93.0 %** | 90.5 % |
|
||||||
|
| TTFT median | 67.4 ms | 68.8 ms |
|
||||||
|
| TTFT p90 | 90.4 ms | 96.7 ms |
|
||||||
|
| TPOT median | **17.45 ms** | 18.26 ms |
|
||||||
|
| TPOT p90 | 17.65 ms | 18.38 ms |
|
||||||
|
| Throughput | **57.3 tok/s** | 54.8 tok/s |
|
||||||
|
| Mean output tokens | 288 | 293 |
|
||||||
|
|
||||||
|
- **Accuracy: unchanged.** FP8 is nominally +2.5 pts, but with n=200 the
|
||||||
|
standard error is ~2.1 pts, so the two are statistically indistinguishable.
|
||||||
|
The takeaway is that FP8 did **not** degrade accuracy.
|
||||||
|
- **Decode: FP8 ~5 % faster** (TPOT 17.45 vs 18.26 ms), reproducible across
|
||||||
|
runs, with a tighter p90. Modest because the dense-MoE path loads *all*
|
||||||
|
experts every token and FP8 only halves the *expert* bytes; the per-expert
|
||||||
|
M=1 launches and M=1 tensor-core inefficiency absorb much of the bandwidth
|
||||||
|
saving.
|
||||||
|
- **Prefill (TTFT): comparable.** A multi-length sweep (113 / 561 / 1681 tokens)
|
||||||
|
gave FP8 480 / 362 / 2451 ms vs BF16 558 / 282 / 2287 ms — non-monotonic, i.e.
|
||||||
|
dominated by fixed overhead (cuBLAS lazy init + FP8's one-time per-shape
|
||||||
|
heuristic), not prefill compute, at these lengths.
|
||||||
|
|
||||||
|
## Single-GPU (TP=1)
|
||||||
|
|
||||||
|
FP8 runs gpt-oss-20b on **one** 5090 (`bench-gpt-oss --tp 1`, GPU6): TTFT 538 ms,
|
||||||
|
TPOT 29.0 ms, 34.5 tok/s. BF16 cannot (39 GB > 32 GB). This — fitting a model
|
||||||
|
that otherwise needs two GPUs onto one — is the largest practical win.
|
||||||
|
|
||||||
|
## Follow-ups (not done)
|
||||||
|
|
||||||
|
- Strided-batched FP8 (one call instead of ~768 per-expert launches per token) —
|
||||||
|
requires folding the per-expert weight scale into the post-scale kernel, at a
|
||||||
|
BF16-intermediate precision cost.
|
||||||
|
- Per-channel (per-output-row) weight scales for better accuracy headroom than
|
||||||
|
per-tensor.
|
||||||
|
- Warm common prefill shapes at load to hide the first-request heuristic stall.
|
||||||
283
tools/fp8_compare.py
Normal file
283
tools/fp8_compare.py
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Compare FP8-W8A8 vs BF16 gpt-oss on one box: GSM8K accuracy + TTFT/TPOT.
|
||||||
|
|
||||||
|
For each model it launches a warm xserv-server (same GPUs / same TP for a fair
|
||||||
|
compute comparison), waits for a *real* generation to succeed (not /health),
|
||||||
|
then streams N GSM8K problems through /v1/chat/completions measuring per-request
|
||||||
|
TTFT (time to first token) and TPOT (mean inter-token latency). Accuracy is the
|
||||||
|
exact-match rate on the extracted final number.
|
||||||
|
|
||||||
|
Run it ON the GPU box (it manages the servers itself):
|
||||||
|
|
||||||
|
python3 tools/fp8_compare.py \
|
||||||
|
--fp8 /opt/wjh/models/gpt-oss-20b-fp8 \
|
||||||
|
--bf16 /opt/wjh/models/gpt-oss-20b-bf16 \
|
||||||
|
--gpus 0,1 --tp 2 --limit 150 --max-tokens 512
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import urllib.request
|
||||||
|
import urllib.error
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
SCRIPT_DIR = Path(__file__).parent
|
||||||
|
GSM8K = SCRIPT_DIR / "bench" / "data" / "gsm8k.json"
|
||||||
|
SERVER_BIN = SCRIPT_DIR.parent / "target" / "release" / "xserv-server"
|
||||||
|
SYSTEM = ("You are a careful math problem solver. Solve the problem step by step. "
|
||||||
|
"Put your final numeric answer inside \\boxed{}.")
|
||||||
|
|
||||||
|
_BOXED_RE = re.compile(r"\\boxed\s*\{([^{}]*)\}")
|
||||||
|
_NUM_RE = re.compile(r"-?\d+(?:,\d{3})*(?:\.\d+)?")
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_num(s):
|
||||||
|
s = s.replace(",", "").strip()
|
||||||
|
try:
|
||||||
|
f = float(s)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return str(int(f)) if f == int(f) else f"{f:g}"
|
||||||
|
|
||||||
|
|
||||||
|
def extract_answer(text):
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
boxed = _BOXED_RE.findall(text)
|
||||||
|
if boxed:
|
||||||
|
nums = _NUM_RE.findall(boxed[-1])
|
||||||
|
if nums:
|
||||||
|
return normalize_num(nums[-1])
|
||||||
|
nums = _NUM_RE.findall(text)
|
||||||
|
if nums:
|
||||||
|
return normalize_num(nums[-1])
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def pct(vals, p):
|
||||||
|
if not vals:
|
||||||
|
return 0.0
|
||||||
|
s = sorted(vals)
|
||||||
|
i = max(0, min(len(s) - 1, int(round((p / 100.0) * (len(s) - 1)))))
|
||||||
|
return s[i]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- server lifecycle ----------
|
||||||
|
|
||||||
|
def gpu_mem_used_mb(gpus):
|
||||||
|
out = subprocess.check_output(
|
||||||
|
["nvidia-smi", "--query-gpu=index,memory.used", "--format=csv,noheader,nounits"],
|
||||||
|
text=True)
|
||||||
|
used = {}
|
||||||
|
for line in out.strip().splitlines():
|
||||||
|
idx, mem = [x.strip() for x in line.split(",")]
|
||||||
|
used[int(idx)] = int(mem)
|
||||||
|
return max(used.get(g, 0) for g in gpus)
|
||||||
|
|
||||||
|
|
||||||
|
def start_server(model_dir, port, tp, gpus, log_path):
|
||||||
|
env = dict(os.environ)
|
||||||
|
env["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpus)
|
||||||
|
cmd = [str(SERVER_BIN), str(model_dir), "--port", str(port),
|
||||||
|
"--tp", str(tp), "--max-seq-len", "2048", "--max-batch", "8"]
|
||||||
|
logf = open(log_path, "wb")
|
||||||
|
# New session so we can kill the whole process tree without touching ours.
|
||||||
|
p = subprocess.Popen(cmd, stdout=logf, stderr=subprocess.STDOUT,
|
||||||
|
env=env, start_new_session=True)
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def stop_server(p, gpus, drain_to_mb=2000, timeout=120):
|
||||||
|
if p.poll() is None:
|
||||||
|
try:
|
||||||
|
os.killpg(os.getpgid(p.pid), signal.SIGTERM)
|
||||||
|
except ProcessLookupError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
p.wait(timeout=30)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
try:
|
||||||
|
os.killpg(os.getpgid(p.pid), signal.SIGKILL)
|
||||||
|
except ProcessLookupError:
|
||||||
|
pass
|
||||||
|
# Wait for VRAM to drain so the next server starts clean.
|
||||||
|
t0 = time.time()
|
||||||
|
while time.time() - t0 < timeout:
|
||||||
|
if gpu_mem_used_mb(gpus) < drain_to_mb:
|
||||||
|
return
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
|
||||||
|
def wait_ready(base, model_id, timeout=900):
|
||||||
|
"""Gate on a real 1-token generation, not /health (which lies during load)."""
|
||||||
|
t0 = time.time()
|
||||||
|
body = json.dumps({
|
||||||
|
"model": model_id,
|
||||||
|
"messages": [{"role": "user", "content": "hi"}],
|
||||||
|
"max_tokens": 1, "temperature": 0.0, "stream": False,
|
||||||
|
}).encode()
|
||||||
|
while time.time() - t0 < timeout:
|
||||||
|
try:
|
||||||
|
req = urllib.request.Request(base + "/v1/chat/completions", data=body,
|
||||||
|
headers={"Content-Type": "application/json"})
|
||||||
|
with urllib.request.urlopen(req, timeout=120) as r:
|
||||||
|
if r.status == 200:
|
||||||
|
json.loads(r.read())
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
time.sleep(3)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- one streamed request ----------
|
||||||
|
|
||||||
|
def stream_chat(base, model_id, user, max_tokens):
|
||||||
|
body = json.dumps({
|
||||||
|
"model": model_id,
|
||||||
|
"messages": [{"role": "system", "content": SYSTEM},
|
||||||
|
{"role": "user", "content": user}],
|
||||||
|
"max_tokens": max_tokens, "temperature": 0.0, "stream": True,
|
||||||
|
}).encode()
|
||||||
|
req = urllib.request.Request(base + "/v1/chat/completions", data=body,
|
||||||
|
headers={"Content-Type": "application/json"})
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
ttft = None
|
||||||
|
t_last = t0
|
||||||
|
n = 0
|
||||||
|
parts = []
|
||||||
|
with urllib.request.urlopen(req, timeout=300) as resp:
|
||||||
|
for raw in resp:
|
||||||
|
line = raw.decode("utf-8", "ignore").strip()
|
||||||
|
if not line.startswith("data:"):
|
||||||
|
continue
|
||||||
|
data = line[5:].strip()
|
||||||
|
if data == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
obj = json.loads(data)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
delta = obj["choices"][0].get("delta", {})
|
||||||
|
content = delta.get("content")
|
||||||
|
if content:
|
||||||
|
now = time.perf_counter()
|
||||||
|
if ttft is None:
|
||||||
|
ttft = now - t0
|
||||||
|
n += 1
|
||||||
|
t_last = now
|
||||||
|
parts.append(content)
|
||||||
|
ttft = ttft if ttft is not None else (time.perf_counter() - t0)
|
||||||
|
decode_span = t_last - t0 - ttft
|
||||||
|
tpot = decode_span / (n - 1) if n > 1 else 0.0
|
||||||
|
return "".join(parts), ttft, tpot, n
|
||||||
|
|
||||||
|
|
||||||
|
def run_eval(base, model_id, problems, max_tokens):
|
||||||
|
correct = 0
|
||||||
|
ttfts, tpots, toks = [], [], []
|
||||||
|
n_scored = 0
|
||||||
|
for i, prob in enumerate(problems):
|
||||||
|
q = prob["problem"].replace("\n", " ")
|
||||||
|
try:
|
||||||
|
text, ttft, tpot, n = stream_chat(base, model_id, q, max_tokens)
|
||||||
|
except Exception as e:
|
||||||
|
print(f" [E] {i+1}/{len(problems)} {e}", flush=True)
|
||||||
|
continue
|
||||||
|
pred = extract_answer(text)
|
||||||
|
gold = normalize_num(prob["answer"])
|
||||||
|
ok = pred is not None and gold is not None and pred == gold
|
||||||
|
correct += int(ok)
|
||||||
|
n_scored += 1
|
||||||
|
ttfts.append(ttft * 1000.0)
|
||||||
|
if tpot > 0:
|
||||||
|
tpots.append(tpot * 1000.0)
|
||||||
|
toks.append(n)
|
||||||
|
mark = "✓" if ok else "✗"
|
||||||
|
print(f" [{mark}] {i+1:3d}/{len(problems)} gold={prob['answer']:>7s} "
|
||||||
|
f"pred={str(pred):>7s} ttft={ttft*1000:6.1f}ms tpot={tpot*1000:5.1f}ms tok={n}",
|
||||||
|
flush=True)
|
||||||
|
return {
|
||||||
|
"accuracy": correct / max(n_scored, 1),
|
||||||
|
"correct": correct, "scored": n_scored,
|
||||||
|
"ttft_ms_median": pct(ttfts, 50), "ttft_ms_p90": pct(ttfts, 90),
|
||||||
|
"tpot_ms_median": pct(tpots, 50), "tpot_ms_p90": pct(tpots, 90),
|
||||||
|
"tok_per_s_median": (1000.0 / pct(tpots, 50)) if pct(tpots, 50) > 0 else 0.0,
|
||||||
|
"mean_tokens": sum(toks) / max(len(toks), 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--fp8", required=True)
|
||||||
|
ap.add_argument("--bf16", required=True)
|
||||||
|
ap.add_argument("--gpus", default="0,1")
|
||||||
|
ap.add_argument("--tp", type=int, default=2)
|
||||||
|
ap.add_argument("--limit", type=int, default=150)
|
||||||
|
ap.add_argument("--max-tokens", type=int, default=512)
|
||||||
|
ap.add_argument("--port", type=int, default=18080)
|
||||||
|
ap.add_argument("--out", default=None)
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
gpus = [int(g) for g in args.gpus.split(",")]
|
||||||
|
with open(GSM8K) as f:
|
||||||
|
problems = json.load(f)[:args.limit]
|
||||||
|
|
||||||
|
base = f"http://127.0.0.1:{args.port}"
|
||||||
|
results = {}
|
||||||
|
for label, model_dir in [("FP8_W8A8", args.fp8), ("BF16", args.bf16)]:
|
||||||
|
model_id = Path(model_dir).name
|
||||||
|
log_path = f"/tmp/xserv_{label}.log"
|
||||||
|
print(f"\n{'='*72}\n {label} ({model_dir}, tp={args.tp}, gpus={gpus})\n{'='*72}", flush=True)
|
||||||
|
print(f" starting server (log: {log_path}) ...", flush=True)
|
||||||
|
p = start_server(model_dir, args.port, args.tp, gpus, log_path)
|
||||||
|
try:
|
||||||
|
if not wait_ready(base, model_id):
|
||||||
|
print(f" SERVER NOT READY — tail of log:", flush=True)
|
||||||
|
print(subprocess.run(["tail", "-30", log_path], capture_output=True, text=True).stdout)
|
||||||
|
stop_server(p, gpus)
|
||||||
|
continue
|
||||||
|
print(f" ready. running {len(problems)} GSM8K problems...", flush=True)
|
||||||
|
t0 = time.time()
|
||||||
|
r = run_eval(base, model_id, problems, args.max_tokens)
|
||||||
|
r["wall_s"] = time.time() - t0
|
||||||
|
results[label] = r
|
||||||
|
print(f" -> acc={r['accuracy']*100:.1f}% ttft_med={r['ttft_ms_median']:.1f}ms "
|
||||||
|
f"tpot_med={r['tpot_ms_median']:.1f}ms ({r['tok_per_s_median']:.1f} tok/s)", flush=True)
|
||||||
|
finally:
|
||||||
|
print(f" stopping server...", flush=True)
|
||||||
|
stop_server(p, gpus)
|
||||||
|
|
||||||
|
print(f"\n{'='*72}\n SUMMARY (gpt-oss-20b, tp={args.tp}, GSM8K n={args.limit})\n{'='*72}")
|
||||||
|
print(f"{'metric':<26s} {'FP8_W8A8':>14s} {'BF16':>14s}")
|
||||||
|
print("-" * 56)
|
||||||
|
f8, b6 = results.get("FP8_W8A8", {}), results.get("BF16", {})
|
||||||
|
def row(name, key, fmt, scale=1.0):
|
||||||
|
a = f8.get(key); b = b6.get(key)
|
||||||
|
if a is None or b is None:
|
||||||
|
return
|
||||||
|
print(f"{name:<26s} {fmt.format(a*scale):>14s} {fmt.format(b*scale):>14s}")
|
||||||
|
row("GSM8K accuracy (%)", "accuracy", "{:.1f}", 100.0)
|
||||||
|
row("TTFT median (ms)", "ttft_ms_median", "{:.1f}")
|
||||||
|
row("TTFT p90 (ms)", "ttft_ms_p90", "{:.1f}")
|
||||||
|
row("TPOT median (ms)", "tpot_ms_median", "{:.2f}")
|
||||||
|
row("TPOT p90 (ms)", "tpot_ms_p90", "{:.2f}")
|
||||||
|
row("Throughput (tok/s)", "tok_per_s_median", "{:.1f}")
|
||||||
|
row("Mean output tokens", "mean_tokens", "{:.0f}")
|
||||||
|
if f8 and b6 and b6.get("tpot_ms_median"):
|
||||||
|
sp = b6["tpot_ms_median"] / f8["tpot_ms_median"] if f8.get("tpot_ms_median") else 0
|
||||||
|
print(f"\n FP8 decode speedup vs BF16: {sp:.2f}x")
|
||||||
|
|
||||||
|
out = args.out or f"/tmp/fp8_compare_{int(time.time())}.json"
|
||||||
|
with open(out, "w") as f:
|
||||||
|
json.dump({"args": vars(args), "results": results}, f, indent=2)
|
||||||
|
print(f"\n saved: {out}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user