- test_correctness.py: compare prefill logits top-20 vs HF transformers - bench_server.py: HTTP API benchmark (throughput, streaming, concurrent, EOS leak check) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
240 lines
8.4 KiB
Python
240 lines
8.4 KiB
Python
#!/usr/bin/env python3
"""Compare xserv vs HuggingFace transformers for correctness and performance.

Strategy: run xserv first (on GPU 0), collect results, then load the HF model
on GPU 0 (each xserv process exits and frees its VRAM before HF is loaded).

Phases:
  1. Run xserv's ``dump-logits`` tool once per test prompt, each in its own
     process, and parse the top-20 prefill logits it prints.
  2. Load the HF model in BF16.
  3. Compute HF's top-20 logits for the same prompts and compare. A prompt
     passes when the top-1 token matches and the top-5 token *sets* match.
  4. Time HF greedy generation (batch=1).
  5. Only with ``--with-server``: benchmark an already-running xserv HTTP
     server on localhost:8080 and check its responses for leaked
     chat-template special tokens.
"""
|
|
|
|
import subprocess
|
|
import time
|
|
import json
|
|
import sys
|
|
import os
|
|
import re
|
|
|
|
# Torch device the HF model and its inputs are placed on.
DEVICE = "cuda:0"

# Machine-specific defaults. Set XSERV_MODEL_DIR / XSERV_DUMP_LOGITS to run the
# comparison elsewhere without editing the script.
MODEL_DIR = os.environ.get("XSERV_MODEL_DIR", "/opt/wjh/models/qwen3-8b")

# xserv's dump-logits tool, which prints the top-20 prefill logits for a prompt.
XSERV_DUMP = os.environ.get(
    "XSERV_DUMP_LOGITS", "/opt/wjh/projects/xserv/target/release/dump-logits"
)
|
|
|
|
|
|
# One line of dump-logits output, e.g. "[ 3] id= 12345 logit= 17.2500".
# The logit may also be in exponent form ("1.5e-01") or inf/NaN.
_TOPK_LINE_RE = re.compile(
    r"\s*\[\s*\d+\]\s+id=\s*(\d+)\s+logit=\s*"
    r"([-+]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|NaN|nan))"
)


def xserv_dump_logits(prompt, binary=None, model_dir=None, timeout=180):
    """Run xserv dump-logits on *prompt* and parse its top-20 logits.

    The tool is invoked as ``<binary> <model_dir> <prompt>`` in a fresh
    process with ``CUDA_VISIBLE_DEVICES=0`` and the CUDA 12.9 bin directory
    prepended to ``PATH``.

    Args:
        prompt: Prompt text, passed verbatim as the last argument.
        binary: Path to the dump-logits executable; defaults to XSERV_DUMP.
        model_dir: Model directory argument; defaults to MODEL_DIR.
        timeout: Seconds to wait for the subprocess before giving up.

    Returns:
        ``(top20, elapsed)``. ``top20`` is a list of ``(token_id, logit)``
        tuples in the order the tool printed them. It is ``None`` if the
        process could not be started, timed out, exited non-zero, or printed
        no parseable logit lines. ``elapsed`` is wall-clock seconds.
    """
    if binary is None:
        binary = XSERV_DUMP
    if model_dir is None:
        model_dir = MODEL_DIR

    env = os.environ.copy()
    env["PATH"] = "/usr/local/cuda-12.9/bin:" + env.get("PATH", "")
    env["CUDA_VISIBLE_DEVICES"] = "0"

    t0 = time.perf_counter()
    try:
        result = subprocess.run(
            [binary, model_dir, prompt],
            capture_output=True, text=True, encoding="utf-8", errors="replace",
            timeout=timeout, env=env,
        )
    except subprocess.TimeoutExpired:
        # Report a hang like any other xserv failure so the remaining prompts
        # and the HF comparison still run.
        print(f" xserv timed out after {timeout}s")
        return None, time.perf_counter() - t0
    except OSError as e:
        # e.g. FileNotFoundError when the binary has not been built.
        print(f" xserv could not be started: {e}")
        return None, time.perf_counter() - t0
    elapsed = time.perf_counter() - t0

    if result.returncode != 0:
        print(f" xserv error: {result.stderr[-500:]}")
        return None, elapsed

    top20 = []
    for line in result.stdout.splitlines():
        m = _TOPK_LINE_RE.match(line)
        if m:
            top20.append((int(m.group(1)), float(m.group(2))))

    if not top20:
        # Returning [] would make the downstream top-k comparison index an
        # empty list; signal failure the same way as a crash.
        print(" xserv produced no parseable logit lines")
        return None, elapsed
    return top20, elapsed
|
|
|
|
|
|
def hf_prefill_top20(model, tokenizer, prompt):
    """Return HF's 20 highest next-token logits after prefilling *prompt*.

    Runs a single no-grad forward pass over the tokenized prompt on DEVICE
    and reads the logits at the last position, upcast to float32.

    Returns:
        List of ``(token_id, logit)`` pairs, highest logit first.
    """
    import numpy as np
    import torch

    encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        forward = model(**encoded)
    scores = forward.logits[0, -1, :].float().cpu().numpy()

    # argsort is ascending: keep the last 20 indices, then flip to descending.
    ranked = np.argsort(scores)[-20:][::-1]
    return [(int(tok), float(scores[tok])) for tok in ranked]
|
|
|
|
|
|
def hf_generate(model, tokenizer, prompt, max_new=80):
    """Greedily generate up to *max_new* tokens with HF and time the call.

    The timer brackets the whole ``model.generate`` call, so prompt
    processing is included. It is fenced by ``torch.cuda.synchronize()`` on
    both sides so queued GPU work is counted.

    Returns:
        ``(tokens_per_second, elapsed_seconds, new_token_count, text)``. The
        count and text cover only tokens after the prompt, and special tokens
        are stripped from the text.
    """
    import torch

    batch = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    n_prompt = batch["input_ids"].shape[1]

    torch.cuda.synchronize()
    started = time.perf_counter()
    with torch.no_grad():
        sequences = model.generate(**batch, max_new_tokens=max_new, do_sample=False)
    torch.cuda.synchronize()
    duration = time.perf_counter() - started

    n_new = sequences.shape[1] - n_prompt
    completion = tokenizer.decode(sequences[0][n_prompt:], skip_special_tokens=True)
    return n_new / duration, duration, n_new, completion
|
|
|
|
|
|
def compare_top20(hf_top20, xs_top20, name):
    """Compare HF and xserv top-20 ``(token_id, logit)`` lists for one prompt.

    A prompt PASSes when both sides agree on the top-1 token and on the *set*
    of top-5 tokens; order within the top 5 is ignored. The top-10 overlap
    and the logit differences over shared token ids are printed for
    information only.

    Args:
        hf_top20: Reference pairs from HF, highest logit first.
        xs_top20: xserv pairs in the same format. May be ``None`` or empty
            when the xserv run failed.
        name: Label used in the printed report.

    Returns:
        True on PASS. False on FAIL, or when either list is missing or empty.
    """
    if not xs_top20:
        # Covers None (xserv failed) and [] (nothing parsed); the indexing
        # below would otherwise raise IndexError on an empty list.
        print(f" [{name}] SKIP (xserv failed)")
        return False
    if not hf_top20:
        print(f" [{name}] SKIP (HF returned no logits)")
        return False

    hf_ids = [tok for tok, _ in hf_top20]
    xs_ids = [tok for tok, _ in xs_top20]
    top1_match = hf_ids[0] == xs_ids[0]
    top5_match = set(hf_ids[:5]) == set(xs_ids[:5])
    top10_overlap = len(set(hf_ids[:10]) & set(xs_ids[:10]))

    # Logit error is only measurable on ids that both sides reported.
    hf_logits = dict(hf_top20)
    xs_logits = dict(xs_top20)
    diffs = [abs(hf_logits[k] - xs_logits[k]) for k in hf_logits.keys() & xs_logits.keys()]
    if diffs:
        max_diff = max(diffs)
        mean_diff = sum(diffs) / len(diffs)
    else:
        max_diff = mean_diff = float("inf")

    passed = top1_match and top5_match
    status = "PASS" if passed else "FAIL"
    print(f" [{name}] {status}: top1={'Y' if top1_match else 'N'}, "
          f"top5={'Y' if top5_match else 'N'}, top10={top10_overlap}/10, "
          f"max_diff={max_diff:.4f}, mean_diff={mean_diff:.4f}")
    print(f" HF top5: {[(i, f'{v:.2f}') for i, v in hf_top20[:5]]}")
    print(f" XS top5: {[(i, f'{v:.2f}') for i, v in xs_top20[:5]]}")
    return passed
|
|
|
|
|
|
def benchmark_xserv_server(prompt, num_tokens=80, port=8080, host="localhost", timeout=120):
    """Send one non-streaming, temperature-0 chat completion to an xserv server.

    Args:
        prompt: User message content; the request carries a single user turn.
        num_tokens: ``max_tokens`` for the completion.
        port: Server port.
        host: Server host name or address.
        timeout: Socket timeout in seconds for the request.

    Returns:
        ``(tokens_per_second, elapsed_seconds, completion_tokens, content)``.
        ``completion_tokens`` comes from the response's ``usage`` block and is
        0 when the server omits or nulls it; throughput is then reported as 0.
        ``content`` is ``""`` when the server returns no or null content.

    Raises:
        Whatever ``urllib.request.urlopen`` raises on connection/HTTP errors
        (e.g. ``urllib.error.URLError``), ``ValueError`` for a non-JSON body,
        and ``KeyError``/``IndexError`` if the response has no choices.
    """
    import urllib.request

    payload = {
        "model": "qwen3-8b",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": num_tokens,
        "temperature": 0,
        "stream": False,
    }
    req = urllib.request.Request(
        f"http://{host}:{port}/v1/chat/completions",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )

    start = time.perf_counter()
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        result = json.loads(resp.read())
    elapsed = time.perf_counter() - start

    # OpenAI-style servers may send null for content/usage; normalize so that
    # callers can slice and search the text and divide by the count safely.
    content = result["choices"][0]["message"].get("content") or ""
    usage = result.get("usage") or {}
    ct = usage.get("completion_tokens") or 0
    tps = ct / elapsed if elapsed > 0 else 0
    return tps, elapsed, ct, content
|
|
|
|
|
|
def _section(title, leading_blank=False):
    """Print *title* between two 50-character ``=`` rules.

    With ``leading_blank`` an empty line is printed before the first rule.
    """
    rule = "=" * 50
    print(f"\n{rule}" if leading_blank else rule)
    print(title)
    print(rule)


def _collect_xserv_results(test_prompts):
    """PART 1: run xserv dump-logits per prompt.

    Each prompt runs in its own subprocess. Returns a dict mapping prompt
    name to its top-20 list, or to None on failure.
    """
    _section("PART 1: Collecting xserv prefill logits")
    xs_results = {}
    for name, prompt in test_prompts:
        print(f" Running xserv dump-logits [{name}]...")
        top20, elapsed = xserv_dump_logits(prompt)
        xs_results[name] = top20
        if top20:
            print(f" OK ({len(top20)} logits, {elapsed:.1f}s)")
        else:
            print(f" FAILED ({elapsed:.1f}s)")
    return xs_results


def _load_hf_model():
    """PART 2: load the HF tokenizer and BF16 model onto DEVICE in eval mode."""
    _section("PART 2: Loading HF model for comparison", leading_blank=True)

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print("Loading HF model (BF16)...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR, dtype=torch.bfloat16,
        device_map=DEVICE, trust_remote_code=True,
    )
    model.eval()
    print("HF model loaded.\n")
    return model, tokenizer


def _check_correctness(model, tokenizer, test_prompts, xs_results):
    """PART 3: compare HF vs xserv top-20 for each prompt; True if all pass."""
    _section("PART 3: Correctness Comparison")
    all_pass = True
    # Deliberately not all(...): every prompt must be compared and reported,
    # even after the first failure.
    for name, prompt in test_prompts:
        hf_top20 = hf_prefill_top20(model, tokenizer, prompt)
        if not compare_top20(hf_top20, xs_results[name], name):
            all_pass = False
    print(f"\n Overall: {'ALL PASS' if all_pass else 'SOME FAILED'}\n")
    return all_pass


def _bench_hf(model, tokenizer, bench_prompt):
    """PART 4: time HF greedy generation for 50 and 80 tokens after a warmup."""
    _section("PART 4: HF Decode Performance (greedy, batch=1)")

    print("\nWarmup...")
    hf_generate(model, tokenizer, bench_prompt, max_new=5)

    for num_tokens in (50, 80):
        hf_tps, hf_time, hf_gen, hf_text = hf_generate(
            model, tokenizer, bench_prompt, max_new=num_tokens
        )
        print(f" HF ({num_tokens} tokens): {hf_tps:.1f} tok/s, {hf_time:.2f}s, {hf_gen} generated")
        print(f" Text: {hf_text[:120]}...")


def _bench_xserv_server(user_prompt, port=8080):
    """PART 5: benchmark a running xserv server and check for leaked special tokens.

    A failed health check is reported as "Server not available". Errors
    raised while benchmarking are reported separately, so a slow or buggy
    server is not mistaken for a missing one.
    """
    import urllib.request

    _section("PART 5: xserv Server Performance", leading_blank=True)
    try:
        # Use the response as a context manager so the connection is closed.
        with urllib.request.urlopen(f"http://localhost:{port}/health", timeout=3):
            pass
    except Exception as e:
        print(f"Server not available: {e}")
        return

    print("Server available. Benchmarking...\n")
    try:
        # Warmup
        benchmark_xserv_server("Hi", 5, port=port)
        time.sleep(0.5)

        for num_tokens in (50, 80):
            xs_tps, xs_time, xs_gen, xs_text = benchmark_xserv_server(
                user_prompt, num_tokens, port=port
            )
            print(f" xserv ({num_tokens} tokens): {xs_tps:.1f} tok/s, {xs_time:.2f}s, {xs_gen} generated")
            print(f" Text: {(xs_text or '')[:120]}...")

        # Any of these chat-template control tokens in the returned content
        # counts as a leak.
        print("\n EOS Leak Check:")
        _, _, _, content = benchmark_xserv_server("Say hello", 20, port=port)
        content = content or ""
        leak_markers = ("<|im_end|>", "<|endoftext|>", "<|im_start|>")
        has_eos = any(marker in content for marker in leak_markers)
        print(f" Response has EOS token: {'YES (FAIL)' if has_eos else 'NO (PASS)'}")
        if has_eos:
            print(f" Content: {content}")
    except Exception as e:
        print(f"Server benchmark failed: {e}")


def main():
    """Run the xserv vs HF comparison.

    Pass ``--with-server`` to also benchmark a running xserv HTTP server.
    xserv logits are collected first, in subprocesses that exit before the
    HF model is loaded. The HF correctness comparison and HF decode benchmark
    follow, then the optional server benchmark.
    """
    with_server = "--with-server" in sys.argv

    print("=" * 70)
    print("xserv vs HuggingFace Transformers — Correctness & Performance")
    print("=" * 70)
    print(f"Model: {MODEL_DIR}")
    print(f"Device: {DEVICE}\n")

    # Prompts are pre-formatted with the ChatML template, and the identical
    # string is fed to both engines.
    test_prompts = [
        ("english", "<|im_start|>user\nWhat is the capital of France?<|im_end|>\n<|im_start|>assistant\n"),
        ("chinese", "<|im_start|>user\n请介绍一下量子计算<|im_end|>\n<|im_start|>assistant\n"),
        ("code", "<|im_start|>user\nWrite a Python function to sort a list<|im_end|>\n<|im_start|>assistant\n"),
        ("multi_turn",
         "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi!<|im_end|>\n<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\n"),
    ]

    xs_results = _collect_xserv_results(test_prompts)
    model, tokenizer = _load_hf_model()
    _check_correctness(model, tokenizer, test_prompts, xs_results)

    _bench_hf(
        model, tokenizer,
        "<|im_start|>user\nExplain the theory of relativity in simple terms.<|im_end|>\n<|im_start|>assistant\n",
    )

    if with_server:
        # The chat-completions endpoint takes role/content messages, so send
        # the bare question (presumably the server applies its own template).
        _bench_xserv_server("Explain the theory of relativity in simple terms.")

    _section("DONE", leading_blank=True)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point; pass --with-server to include the HTTP server benchmark.
    main()
|