diff --git a/tools/bench_server.py b/tools/bench_server.py new file mode 100644 index 0000000..52e7c75 --- /dev/null +++ b/tools/bench_server.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +"""Benchmark xserv server performance and check correctness vs HF.""" + +import json +import time +import sys +import urllib.request + +PORT = int(sys.argv[1]) if len(sys.argv) > 1 else 8090 + +def chat(prompt, max_tokens=80, temperature=0): + data = json.dumps({ + "model": "qwen3-8b", + "messages": [{"role": "user", "content": prompt}], + "max_tokens": max_tokens, + "temperature": temperature, + "stream": False + }).encode() + req = urllib.request.Request( + f"http://localhost:{PORT}/v1/chat/completions", + data=data, headers={"Content-Type": "application/json"} + ) + t0 = time.perf_counter() + with urllib.request.urlopen(req, timeout=180) as resp: + result = json.loads(resp.read()) + elapsed = time.perf_counter() - t0 + usage = result.get("usage", {}) + content = result["choices"][0]["message"]["content"] + finish = result["choices"][0]["finish_reason"] + ct = usage.get("completion_tokens", 0) + pt = usage.get("prompt_tokens", 0) + return ct / elapsed if elapsed > 0 else 0, elapsed, ct, pt, content, finish + + +def chat_stream(prompt, max_tokens=80, temperature=0): + data = json.dumps({ + "model": "qwen3-8b", + "messages": [{"role": "user", "content": prompt}], + "max_tokens": max_tokens, + "temperature": temperature, + "stream": True + }).encode() + req = urllib.request.Request( + f"http://localhost:{PORT}/v1/chat/completions", + data=data, headers={"Content-Type": "application/json"} + ) + t0 = time.perf_counter() + tokens = 0 + content = "" + with urllib.request.urlopen(req, timeout=180) as resp: + for line in resp: + line = line.decode().strip() + if line.startswith("data: "): + payload = line[6:] + if payload == "[DONE]": + break + chunk = json.loads(payload) + delta = chunk["choices"][0].get("delta", {}) + c = delta.get("content", "") + if c: + tokens += 1 + content += c + elapsed = time.perf_counter() - t0 + return tokens / elapsed if elapsed > 0 else 0, elapsed, tokens, content + + +print("=" * 60) +print(f"xserv Server Benchmark (port {PORT})") +print("=" * 60) + +# Health check +try: + urllib.request.urlopen(f"http://localhost:{PORT}/health", timeout=3) +except: + print(f"Server not responding on port {PORT}") + sys.exit(1) + +# 1. EOS leak check +print("\n--- EOS Leak Check ---") +tps, t, ct, pt, content, finish = chat("Say hello", 30) +has_eos = "<|im_end|>" in content or "<|endoftext|>" in content or "<|im_start|>" in content +print(f" finish_reason: {finish}") +print(f" EOS in content: {'YES (BUG!)' if has_eos else 'NO (good)'}") +print(f" Content: {content[:100]}") + +# 2. Warmup +print("\n--- Warmup ---") +chat("Hi", 10) +time.sleep(0.5) + +# 3. Non-streaming benchmark +print("\n--- Non-streaming Performance (greedy, batch=1) ---") +prompts = [ + ("short", "What is 2+2?", 50), + ("medium", "Explain quantum computing in simple terms.", 80), + ("long", "Write a detailed comparison of Python and Rust programming languages, covering syntax, performance, memory management, and ecosystem.", 150), +] + +for name, prompt, max_tok in prompts: + tps, t, ct, pt, content, finish = chat(prompt, max_tok) + print(f" [{name}] {tps:.1f} tok/s | {ct} tokens in {t:.2f}s | prompt={pt} | finish={finish}") + +# 4. Streaming benchmark +print("\n--- Streaming Performance ---") +tps, t, ct, content = chat_stream("Explain the theory of relativity.", 80) +print(f" stream: {tps:.1f} tok/s | {ct} tokens in {t:.2f}s") + +# 5. max_tokens validation +print("\n--- max_tokens Validation ---") +try: + tps, t, ct, pt, content, finish = chat("Hi", 999999) + print(f" max_tokens=999999: OK (server clamped to {ct} tokens, no crash)") +except Exception as e: + print(f" max_tokens=999999: {e}") + +# 6. Concurrent requests (if server supports batching) +print("\n--- Concurrent Requests (2 parallel) ---") +import threading +results = [None, None] + +def do_request(idx, prompt, max_tok): + results[idx] = chat(prompt, max_tok) + +t1 = threading.Thread(target=do_request, args=(0, "What is gravity?", 50)) +t2 = threading.Thread(target=do_request, args=(1, "What is light?", 50)) +t0 = time.perf_counter() +t1.start(); t2.start() +t1.join(); t2.join() +wall_time = time.perf_counter() - t0 + +total_tokens = sum(r[2] for r in results if r) +combined_tps = total_tokens / wall_time +print(f" 2 concurrent: {combined_tps:.1f} tok/s total | wall={wall_time:.2f}s") +for i, r in enumerate(results): + if r: + print(f" req{i}: {r[0]:.1f} tok/s, {r[2]} tokens in {r[1]:.2f}s") + +print(f"\n{'=' * 60}") +print("DONE") +print("=" * 60) diff --git a/tools/test_correctness.py b/tools/test_correctness.py new file mode 100644 index 0000000..fda5711 --- /dev/null +++ b/tools/test_correctness.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 +"""Compare xserv vs HuggingFace transformers for correctness and performance. + +Strategy: run xserv first (on GPU 0), collect results, then load HF model +on GPU 0 (xserv process exits and frees VRAM). +""" + +import subprocess +import time +import json +import sys +import os +import re + +DEVICE = "cuda:0" +MODEL_DIR = "/opt/wjh/models/qwen3-8b" +XSERV_DUMP = "/opt/wjh/projects/xserv/target/release/dump-logits" + + +def xserv_dump_logits(prompt): + """Run xserv dump-logits and parse top-20.""" + env = os.environ.copy() + env["PATH"] = "/usr/local/cuda-12.9/bin:" + env.get("PATH", "") + env["CUDA_VISIBLE_DEVICES"] = "0" + t0 = time.perf_counter() + result = subprocess.run( + [XSERV_DUMP, MODEL_DIR, prompt], + capture_output=True, text=True, timeout=180, env=env + ) + elapsed = time.perf_counter() - t0 + if result.returncode != 0: + print(f" xserv error: {result.stderr[-500:]}") + return None, elapsed + + top20 = [] + for line in result.stdout.strip().split("\n"): + m = re.match(r'\s*\[\s*\d+\]\s+id=\s*(\d+)\s+logit=\s*([\-\d.]+)', line) + if m: + top20.append((int(m.group(1)), float(m.group(2)))) + return top20, elapsed + + +def hf_prefill_top20(model, tokenizer, prompt): + """Get top-20 logits from HF.""" + import torch + inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE) + with torch.no_grad(): + outputs = model(**inputs) + logits = outputs.logits[0, -1, :].float().cpu().numpy() + import numpy as np + top_ids = np.argsort(logits)[-20:][::-1] + return [(int(i), float(logits[i])) for i in top_ids] + + +def hf_generate(model, tokenizer, prompt, max_new=80): + """Greedy generation from HF.""" + import torch + inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE) + prompt_len = inputs["input_ids"].shape[1] + torch.cuda.synchronize() + t0 = time.perf_counter() + with torch.no_grad(): + out = model.generate(**inputs, max_new_tokens=max_new, do_sample=False) + torch.cuda.synchronize() + elapsed = time.perf_counter() - t0 + gen_tokens = out.shape[1] - prompt_len + text = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True) + return gen_tokens / elapsed, elapsed, gen_tokens, text + + +def compare_top20(hf_top20, xs_top20, name): + if xs_top20 is None: + print(f" [{name}] SKIP (xserv failed)") + return False + + hf_ids = [x[0] for x in hf_top20] + xs_ids = [x[0] for x in xs_top20] + top1_match = hf_ids[0] == xs_ids[0] + top5_match = set(hf_ids[:5]) == set(xs_ids[:5]) + top10_overlap = len(set(hf_ids[:10]) & set(xs_ids[:10])) + + hf_dict = dict(hf_top20) + xs_dict = dict(xs_top20) + common = set(hf_dict.keys()) & set(xs_dict.keys()) + if common: + diffs = [abs(hf_dict[k] - xs_dict[k]) for k in common] + max_diff = max(diffs) + mean_diff = sum(diffs) / len(diffs) + else: + max_diff = mean_diff = float('inf') + + status = "PASS" if top1_match and top5_match else "FAIL" + print(f" [{name}] {status}: top1={'Y' if top1_match else 'N'}, " + f"top5={'Y' if top5_match else 'N'}, top10={top10_overlap}/10, " + f"max_diff={max_diff:.4f}, mean_diff={mean_diff:.4f}") + print(f" HF top5: {[(i, f'{v:.2f}') for i, v in hf_top20[:5]]}") + print(f" XS top5: {[(i, f'{v:.2f}') for i, v in xs_top20[:5]]}") + return status == "PASS" + + +def benchmark_xserv_server(prompt, num_tokens=80, port=8080): + import urllib.request + data = json.dumps({ + "model": "qwen3-8b", + "messages": [{"role": "user", "content": prompt}], + "max_tokens": num_tokens, + "temperature": 0, + "stream": False + }).encode() + req = urllib.request.Request( + f"http://localhost:{port}/v1/chat/completions", + data=data, headers={"Content-Type": "application/json"} + ) + start = time.perf_counter() + with urllib.request.urlopen(req, timeout=120) as resp: + result = json.loads(resp.read()) + elapsed = time.perf_counter() - start + content = result["choices"][0]["message"]["content"] + usage = result.get("usage", {}) + ct = usage.get("completion_tokens", 0) + return ct / elapsed if elapsed > 0 else 0, elapsed, ct, content + + +def main(): + with_server = "--with-server" in sys.argv + + print("=" * 70) + print("xserv vs HuggingFace Transformers — Correctness & Performance") + print("=" * 70) + print(f"Model: {MODEL_DIR}") + print(f"Device: {DEVICE}\n") + + # ── Phase A: Run xserv first (separate processes, each loads+runs+exits) ── + test_prompts = [ + ("english", "<|im_start|>user\nWhat is the capital of France?<|im_end|>\n<|im_start|>assistant\n"), + ("chinese", "<|im_start|>user\n请介绍一下量子计算<|im_end|>\n<|im_start|>assistant\n"), + ("code", "<|im_start|>user\nWrite a Python function to sort a list<|im_end|>\n<|im_start|>assistant\n"), + ("multi_turn", + "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi!<|im_end|>\n<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\n"), + ] + + print("=" * 50) + print("PART 1: Collecting xserv prefill logits") + print("=" * 50) + + xs_results = {} + for name, prompt in test_prompts: + print(f" Running xserv dump-logits [{name}]...") + top20, elapsed = xserv_dump_logits(prompt) + xs_results[name] = top20 + if top20: + print(f" OK ({len(top20)} logits, {elapsed:.1f}s)") + else: + print(f" FAILED ({elapsed:.1f}s)") + + # ── Phase B: Load HF model and compare ── + print(f"\n{'=' * 50}") + print("PART 2: Loading HF model for comparison") + print("=" * 50) + + import torch + import numpy as np + from transformers import AutoModelForCausalLM, AutoTokenizer + + print("Loading HF model (BF16)...") + tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained( + MODEL_DIR, dtype=torch.bfloat16, + device_map=DEVICE, trust_remote_code=True + ) + model.eval() + print("HF model loaded.\n") + + print("=" * 50) + print("PART 3: Correctness Comparison") + print("=" * 50) + + all_pass = True + for name, prompt in test_prompts: + hf_top20 = hf_prefill_top20(model, tokenizer, prompt) + if not compare_top20(hf_top20, xs_results[name], name): + all_pass = False + + print(f"\n Overall: {'ALL PASS' if all_pass else 'SOME FAILED'}\n") + + # ── Phase C: Performance benchmark ── + print("=" * 50) + print("PART 4: HF Decode Performance (greedy, batch=1)") + print("=" * 50) + + bench_prompt = "<|im_start|>user\nExplain the theory of relativity in simple terms.<|im_end|>\n<|im_start|>assistant\n" + + # Warmup + print("\nWarmup...") + hf_generate(model, tokenizer, bench_prompt, max_new=5) + + # Benchmark multiple token counts + for num_tokens in [50, 80]: + hf_tps, hf_time, hf_gen, hf_text = hf_generate(model, tokenizer, bench_prompt, max_new=num_tokens) + print(f" HF ({num_tokens} tokens): {hf_tps:.1f} tok/s, {hf_time:.2f}s, {hf_gen} generated") + + # xserv server benchmark + if with_server: + print(f"\n{'=' * 50}") + print("PART 5: xserv Server Performance") + print("=" * 50) + try: + import urllib.request + urllib.request.urlopen("http://localhost:8080/health", timeout=3) + print("Server available. Benchmarking...\n") + + # Warmup + benchmark_xserv_server("Hi", 5) + time.sleep(0.5) + + for num_tokens in [50, 80]: + xs_tps, xs_time, xs_gen, xs_text = benchmark_xserv_server( + "Explain the theory of relativity in simple terms.", num_tokens + ) + print(f" xserv ({num_tokens} tokens): {xs_tps:.1f} tok/s, {xs_time:.2f}s, {xs_gen} generated") + print(f" Text: {xs_text[:120]}...") + + # EOS leak check + print(f"\n EOS Leak Check:") + _, _, _, content = benchmark_xserv_server("Say hello", 20) + has_eos = "<|im_end|>" in content or "<|endoftext|>" in content or "<|im_start|>" in content + print(f" Response has EOS token: {'YES (FAIL)' if has_eos else 'NO (PASS)'}") + if has_eos: + print(f" Content: {content}") + except Exception as e: + print(f"Server not available: {e}") + + print(f"\n{'=' * 50}") + print("DONE") + print("=" * 50) + + +if __name__ == "__main__": + main()