- test_correctness.py: compare prefill logits top-20 vs HF transformers - bench_server.py: HTTP API benchmark (throughput, streaming, concurrent, EOS leak check) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
141 lines
4.7 KiB
Python
141 lines
4.7 KiB
Python
#!/usr/bin/env python3
|
|
"""Benchmark xserv server performance and check correctness vs HF."""
|
|
|
|
import json
import sys
import threading
import time
import urllib.error
import urllib.request
|
|
|
|
# Server port: optional first CLI argument, defaulting to 8090.
try:
    PORT = int(sys.argv[1]) if len(sys.argv) > 1 else 8090
except ValueError:
    # Fail with a usage message instead of a bare traceback on e.g. "abc".
    sys.exit(f"usage: {sys.argv[0]} [PORT]  (PORT must be an integer, got {sys.argv[1]!r})")
|
|
|
|
def chat(prompt, max_tokens=80, temperature=0, model="qwen3-8b"):
    """Send one non-streaming chat completion request and time it.

    Args:
        prompt: User message content.
        max_tokens: Value sent as ``max_tokens`` in the request body.
        temperature: Sampling temperature sent to the server (0 = greedy).
        model: Model name sent in the request body.

    Returns:
        ``(tok_per_s, elapsed_s, completion_tokens, prompt_tokens, content,
        finish_reason)``. ``tok_per_s`` is completion tokens divided by the
        full wall-clock request time, so it includes prompt processing and
        HTTP overhead. ``content`` is ``""`` if the server returned null.

    Raises:
        urllib.error.URLError / urllib.error.HTTPError on connection or
        HTTP-status failures (raised by ``urlopen``).
    """
    body = json.dumps({
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": False,
    }).encode()
    req = urllib.request.Request(
        f"http://localhost:{PORT}/v1/chat/completions",
        data=body,
        headers={"Content-Type": "application/json"},
    )

    t0 = time.perf_counter()
    with urllib.request.urlopen(req, timeout=180) as resp:
        result = json.loads(resp.read())
    elapsed = time.perf_counter() - t0

    # `"usage": null` would make `.get("usage", {})` return None and crash below.
    usage = result.get("usage") or {}
    choice = result["choices"][0]
    # Normalize null content to "" so callers can safely use `in` / slicing.
    content = choice["message"].get("content") or ""
    finish = choice.get("finish_reason")
    ct = usage.get("completion_tokens", 0)
    pt = usage.get("prompt_tokens", 0)
    tps = ct / elapsed if elapsed > 0 else 0
    return tps, elapsed, ct, pt, content, finish
|
|
|
|
|
|
def chat_stream(prompt, max_tokens=80, temperature=0, model="qwen3-8b"):
    """Send one streaming (SSE) chat completion request and time it.

    Args:
        prompt: User message content.
        max_tokens: Value sent as ``max_tokens`` in the request body.
        temperature: Sampling temperature sent to the server (0 = greedy).
        model: Model name sent in the request body.

    Returns:
        ``(tok_per_s, elapsed_s, tokens, content)``. ``tokens`` counts SSE
        deltas that carried non-empty content; using it as a token count
        assumes the server emits about one token per delta.
        NOTE(review): confirm against the server's streaming implementation.

    Raises:
        urllib.error.URLError / urllib.error.HTTPError from ``urlopen``;
        json.JSONDecodeError if a data line is not valid JSON.
    """
    body = json.dumps({
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": True,
    }).encode()
    req = urllib.request.Request(
        f"http://localhost:{PORT}/v1/chat/completions",
        data=body,
        headers={"Content-Type": "application/json"},
    )

    t0 = time.perf_counter()
    tokens = 0
    pieces = []  # joined once at the end; avoids quadratic str +=
    with urllib.request.urlopen(req, timeout=180) as resp:
        for raw in resp:
            line = raw.decode().strip()
            # SSE permits "data:" with or without a single following space.
            if not line.startswith("data:"):
                continue
            payload = line[len("data:"):].strip()
            if not payload:
                continue
            if payload == "[DONE]":
                break
            chunk = json.loads(payload)
            # Some servers end with a usage-only chunk whose "choices" is empty.
            choices = chunk.get("choices") or []
            if not choices:
                continue
            text = (choices[0].get("delta") or {}).get("content")
            if text:
                tokens += 1
                pieces.append(text)
    elapsed = time.perf_counter() - t0

    content = "".join(pieces)
    return tokens / elapsed if elapsed > 0 else 0, elapsed, tokens, content
|
|
|
|
|
|
print("=" * 60)
|
|
print(f"xserv Server Benchmark (port {PORT})")
|
|
print("=" * 60)
|
|
|
|
# Health check
|
|
try:
|
|
urllib.request.urlopen(f"http://localhost:{PORT}/health", timeout=3)
|
|
except:
|
|
print(f"Server not responding on port {PORT}")
|
|
sys.exit(1)
|
|
|
|
# 1. EOS leak check: chat-template special tokens must not appear in returned text.
print("\n--- EOS Leak Check ---")
EOS_MARKERS = ("<|im_end|>", "<|endoftext|>", "<|im_start|>")
tps, t, ct, pt, content, finish = chat("Say hello", 30)
content = content or ""  # guard against a null content field from the server
leaked = [marker for marker in EOS_MARKERS if marker in content]
has_eos = bool(leaked)
verdict = f"YES (BUG!) {', '.join(leaked)}" if has_eos else "NO (good)"
print(f" finish_reason: {finish}")
print(f" EOS in content: {verdict}")
print(f" Content: {content[:100]}")
|
|
|
|
# 2. Warmup: one throwaway short request (result discarded), presumably so
#    one-time startup cost does not skew the timed runs that follow.
print("\n--- Warmup ---")
chat(prompt="Hi", max_tokens=10)
time.sleep(0.5)  # brief pause before the timed section begins
|
|
|
|
# 3. Non-streaming benchmark: three prompts of increasing length, one at a time.
print("\n--- Non-streaming Performance (greedy, batch=1) ---")
prompts = [
    ("short", "What is 2+2?", 50),
    ("medium", "Explain quantum computing in simple terms.", 80),
    (
        "long",
        "Write a detailed comparison of Python and Rust programming languages, "
        "covering syntax, performance, memory management, and ecosystem.",
        150,
    ),
]

for name, prompt, max_tok in prompts:
    tps, t, ct, pt, content, finish = chat(prompt, max_tokens=max_tok)
    stats = f"{ct} tokens in {t:.2f}s | prompt={pt} | finish={finish}"
    print(f" [{name}] {tps:.1f} tok/s | {stats}")
|
|
|
|
# 4. Streaming benchmark: same kind of measurement through the SSE endpoint.
print("\n--- Streaming Performance ---")
tps, t, ct, content = chat_stream(
    prompt="Explain the theory of relativity.",
    max_tokens=80,
)
print(f" stream: {tps:.1f} tok/s | {ct} tokens in {t:.2f}s")
|
|
|
|
# 5. max_tokens validation: an absurd max_tokens should be accepted (clamped)
#    or rejected with a client error -- never crash the server.
print("\n--- max_tokens Validation ---")
try:
    tps, t, ct, pt, content, finish = chat("Hi", 999999)
except urllib.error.HTTPError as exc:
    # A 4xx is a legitimate validation response; a 5xx means the server failed.
    if 400 <= exc.code < 500:
        verdict = "rejected by server validation (OK)"
    else:
        verdict = "SERVER ERROR (BUG?)"
    print(f" max_tokens=999999: HTTP {exc.code} {verdict}")
except Exception as e:
    print(f" max_tokens=999999: {e}")
else:
    # Accepted. A reply to "Hi" may end on EOS well before any limit, so this
    # alone does not prove the server clamped max_tokens.
    print(f" max_tokens=999999: OK (accepted, {ct} tokens, finish={finish}, no crash)")
|
|
|
|
# 6. Concurrent requests (if server supports batching)
concurrent_jobs = [("What is gravity?", 50), ("What is light?", 50)]
print(f"\n--- Concurrent Requests ({len(concurrent_jobs)} parallel) ---")
results = [None] * len(concurrent_jobs)
errors = [None] * len(concurrent_jobs)


def do_request(idx, prompt, max_tok):
    """Worker: run one chat() call and store its tuple, or its exception, in slot idx."""
    try:
        results[idx] = chat(prompt, max_tok)
    except Exception as exc:  # reported below instead of dying silently in the thread
        errors[idx] = exc


workers = [
    threading.Thread(target=do_request, args=(i, prompt, max_tok))
    for i, (prompt, max_tok) in enumerate(concurrent_jobs)
]
t0 = time.perf_counter()
for worker in workers:
    worker.start()
for worker in workers:
    worker.join()
wall_time = time.perf_counter() - t0

# Only successful requests contribute tokens; failures are listed per request.
total_tokens = sum(r[2] for r in results if r)
combined_tps = total_tokens / wall_time if wall_time > 0 else 0
print(f" {len(concurrent_jobs)} concurrent: {combined_tps:.1f} tok/s total | wall={wall_time:.2f}s")
for i, (r, err) in enumerate(zip(results, errors)):
    if r:
        print(f" req{i}: {r[0]:.1f} tok/s, {r[2]} tokens in {r[1]:.2f}s")
    elif err is not None:
        print(f" req{i}: FAILED ({err})")
|
|
|
|
# Closing banner.
rule = "=" * 60
print("\n" + rule)
print("DONE")
print(rule)
|