Files
xserv/tools/eval_gsm8k.py
Gahow Wang 3a530956af tools: add FP8 vs BF16 benchmark and GSM8K eval harness
bench_fp8.py — head-to-head comparison of FP8 and BF16 models on
  GSM8K / AIME2025 accuracy plus TTFT/TPOT performance measurement.

eval_gsm8k_batch.sh — lightweight GSM8K accuracy evaluator that
  pipes one problem per xserv-chat invocation and scores with
  \boxed{} / last-number extraction.

Benchmark results (gpt-oss-20b, 50-problem GSM8K):
  FP8 W8A8 TP1 : 94.0%  (single RTX 5090, 25 GB)
  FP8 W8A16 TP1: 94.0%
  BF16 TP2     : 94.0%  (requires 2× RTX 5090)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-06-08 15:43:04 +08:00

157 lines
5.0 KiB
Python

#!/usr/bin/env python3
"""Direct GSM8K evaluation using xserv-chat CLI.

Usage:
    python eval_gsm8k.py <model-dir> [--limit N] [--offset N] [--max-tokens N] [--tp N]

Runs xserv-chat once per GSM8K problem (read from bench/data/gsm8k.json next
to this script), extracts the numeric answer from the reply, and reports
accuracy. Determinism relies on xserv-chat decoding greedily
(temperature=0) by default -- no sampling flag is passed here; TODO confirm
the CLI default.
"""
import argparse
import json
import os
import re
import subprocess
import sys
import time
from decimal import Decimal, InvalidOperation
from pathlib import Path
# Anchor every path at this script's real location (symlinks resolved) so the
# tool finds its data and the CLI binary regardless of how it is launched.
SCRIPT_DIR = Path(__file__).resolve().parent
DATA_PATH = SCRIPT_DIR / "bench" / "data" / "gsm8k.json"
# xserv-chat release binary, located relative to the directory above this script.
XSERV_CHAT = SCRIPT_DIR.parent / "target" / "release" / "xserv-chat"
SYSTEM_PROMPT = (
    "You are a careful math problem solver. Solve the problem step by step. "
    "Put your final numeric answer inside \\boxed{}."
)
# Contents of a \boxed{...} group that contains no nested braces.
_BOXED_RE = re.compile(r"\\boxed\s*\{([^{}]*)\}")
# An optionally negative number with optional comma thousands separators and
# an optional decimal fraction.
_NUM_RE = re.compile(r"-?\d+(?:,\d{3})*(?:\.\d+)?")
def normalize_num(s: str) -> str | None:
s = s.replace(",", "").strip()
try:
f = float(s)
except ValueError:
return None
return str(int(f)) if f.is_integer() else f"{f:g}"
# Opening of a \boxed{ group; the matching close brace is found by scanning.
_BOXED_OPEN_RE = re.compile(r"\\boxed\s*\{")


def _boxed_contents(text: str) -> list[str]:
    """Return the contents of every closed ``\\boxed{...}`` in *text*, in order.

    Braces are matched by depth counting, so nested groups such as
    ``\\boxed{18\\text{ dollars}}`` are captured whole instead of being missed
    by a no-nested-braces regex. Unterminated boxes are skipped.
    """
    contents = []
    for match in _BOXED_OPEN_RE.finditer(text):
        start = pos = match.end()
        depth = 1
        while pos < len(text) and depth:
            ch = text[pos]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
            pos += 1
        if depth == 0:
            # pos is one past the closing brace.
            contents.append(text[start:pos - 1])
    return contents


def extract_answer(text: str) -> str | None:
    """Return the model's final numeric answer in canonical form, or ``None``.

    The last number inside the last closed ``\\boxed{...}`` wins. If the
    response has no such box, or that box holds no number, the last number
    anywhere in *text* is used instead. The result goes through
    ``normalize_num`` so it can be compared directly with the gold answer.
    """
    if not text:
        return None
    boxes = _boxed_contents(text)
    if boxes:
        nums = _NUM_RE.findall(boxes[-1])
        if nums:
            return normalize_num(nums[-1])
    nums = _NUM_RE.findall(text)
    if nums:
        return normalize_num(nums[-1])
    return None
def _parse_response(output: str) -> str:
    """Extract the first assistant turn from xserv-chat stdout.

    The CLI's output is expected to look like ``assistant> <response>\\nuser>``.
    Everything after the first ``assistant>`` marker, up to the next
    ``user>`` marker (or the end of output), is returned stripped. Returns
    ``""`` when no ``assistant>`` marker was printed.
    """
    _, marker, rest = output.partition("assistant>")
    if not marker:
        return ""
    end = rest.find("user>")
    if end != -1:
        rest = rest[:end]
    return rest.strip()


def run_one(
    model_dir: str,
    problem: str,
    max_tokens: int,
    tp: int,
    *,
    max_seq_len: int = 2048,
    timeout: float = 120,
) -> tuple[str, float]:
    """Run xserv-chat on a single problem, return (response_text, elapsed_s).

    Args:
        model_dir: Model directory passed through to xserv-chat.
        problem: Problem statement, sent as a single user turn.
        max_tokens: Generation budget (``--max-tokens``).
        tp: Tensor-parallel degree; ``--tp`` is only passed when > 1.
        max_seq_len: Value for ``--max-seq-len`` (default 2048).
        timeout: Seconds before the subprocess is killed (default 120).

    Returns:
        The stripped text of the first assistant turn ("" when none was
        printed) and the wall-clock seconds the subprocess took.

    Raises:
        subprocess.TimeoutExpired: the run exceeded *timeout*.
        RuntimeError: xserv-chat exited non-zero without printing a reply.
        OSError: the xserv-chat binary could not be launched.
    """
    cmd = [
        str(XSERV_CHAT), model_dir,
        "--max-tokens", str(max_tokens),
        "--max-seq-len", str(max_seq_len),
        "--system", SYSTEM_PROMPT,
        "--no-color",
    ]
    if tp > 1:
        cmd += ["--tp", str(tp)]

    # Respect an explicit CUDA_VISIBLE_DEVICES. Otherwise expose the first
    # `tp` GPUs: a blanket default of "0" leaves only one device visible,
    # which cannot serve a tensor-parallel run.
    devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if devices is None:
        devices = ",".join(str(i) for i in range(max(tp, 1)))
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": devices}

    # Input is fed line by line (the trailing "/exit" line ends the session),
    # so an embedded newline would appear to split the problem into several
    # user turns. Join it into one line.
    prompt = " ".join(problem.splitlines())

    t0 = time.perf_counter()
    proc = subprocess.run(
        cmd,
        input=prompt + "\n/exit\n",
        capture_output=True,
        text=True,
        timeout=timeout,
        env=env,
    )
    elapsed = time.perf_counter() - t0

    response = _parse_response(proc.stdout)
    # A crashed run must not be silently scored as a wrong answer.
    if proc.returncode != 0 and not response:
        stderr_lines = proc.stderr.strip().splitlines()
        detail = stderr_lines[-1] if stderr_lines else "no stderr output"
        raise RuntimeError(f"xserv-chat exited with status {proc.returncode}: {detail}")
    return response, elapsed
def main():
    """Parse CLI arguments, run each selected GSM8K problem, and print a report.

    Problems are read from ``DATA_PATH`` as a JSON list of records with
    ``"id"``, ``"problem"`` and ``"answer"`` fields, then sliced by
    ``--offset``/``--limit``. Each problem is solved by a fresh xserv-chat
    process. Per-problem failures (timeouts or other exceptions) are counted
    as errors and excluded from the accuracy denominator.
    """
    parser = argparse.ArgumentParser(description="GSM8K evaluation via xserv-chat")
    parser.add_argument("model_dir", help="Model directory")
    parser.add_argument("--limit", type=int, default=50, help="Number of problems (default: 50)")
    parser.add_argument("--max-tokens", type=int, default=512, help="Max generation tokens")
    parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism")
    parser.add_argument("--offset", type=int, default=0, help="Start from problem N")
    args = parser.parse_args()

    # Fail fast on missing inputs instead of reporting every problem as an error.
    for required in (DATA_PATH, XSERV_CHAT):
        if not required.exists():
            print(f"Error: {required} not found", file=sys.stderr)
            sys.exit(1)

    with open(DATA_PATH, encoding="utf-8") as f:
        problems = json.load(f)
    problems = problems[args.offset:args.offset + args.limit]
    n_total = len(problems)

    print(f"GSM8K evaluation: {n_total} problems, model={args.model_dir}")
    print(f"max_tokens={args.max_tokens}, tp={args.tp}")
    print("-" * 72)

    correct = 0
    errors = 0
    total_time = 0.0  # wall time of xserv-chat runs that completed
    n_timed = 0       # number of runs contributing to total_time
    for i, prob in enumerate(problems, start=1):
        # str() keeps the ":>N" alignment working for numeric JSON ids. It is
        # computed before the try so the error handlers cannot fail on it.
        pid = str(prob.get("id", "?"))
        try:
            response, elapsed = run_one(args.model_dir, prob["problem"], args.max_tokens, args.tp)
            total_time += elapsed
            n_timed += 1
            gold_raw = str(prob["answer"])
            pred = extract_answer(response)
            gold = normalize_num(gold_raw)
            is_correct = pred is not None and pred == gold
            if is_correct:
                correct += 1
            # Distinct ASCII markers, readable under any terminal encoding.
            mark = "+" if is_correct else "-"
            # flush so progress is visible when stdout is piped (e.g. to tee).
            print(f"[{mark}] {i:3d}/{n_total} "
                  f"id={pid:>4} gold={gold_raw:>8} "
                  f"pred={str(pred):>8} {elapsed:.1f}s", flush=True)
        except subprocess.TimeoutExpired:
            errors += 1
            print(f"[T] {i:3d}/{n_total} id={pid:>4} TIMEOUT", flush=True)
        except Exception as e:  # per-problem boundary: record and move on
            errors += 1
            print(f"[E] {i:3d}/{n_total} id={pid:>4} {e}", flush=True)

    print("-" * 72)
    n_scored = n_total - errors
    accuracy = correct / max(n_scored, 1)
    print(f"Results: {correct}/{n_scored} correct = {accuracy*100:.1f}% accuracy")
    if errors:
        print(f" ({errors} errors/timeouts)")
    # Average over runs that actually finished; timed-out runs record no time.
    print(f"Total time: {total_time:.1f}s, avg {total_time / max(n_timed, 1):.1f}s/problem")


if __name__ == "__main__":
    main()