#!/usr/bin/env python3
"""
End-to-end validation for xserv after bug fixes.

1. Correctness: compare top-k logits with HuggingFace transformers
2. Generation: run 50+ prompts through the HTTP API
3. Performance: measure latency and throughput

Usage:
    # Step 1: Start xserv server in background:
    #   ./target/release/xserv-server /opt/wjh/models/qwen3-8b --port 8080
    #
    # Step 2: Run this script:
    #   python3 tools/e2e_validate.py --mode all
    #   python3 tools/e2e_validate.py --mode logits   # correctness only
    #   python3 tools/e2e_validate.py --mode api      # API + perf only
    #   python3 tools/e2e_validate.py --mode stream   # SSE streaming only
"""

import argparse
import json
import os
import subprocess
import sys
import time
import urllib.error
import urllib.request
from pathlib import Path

MODEL_DIR = "/opt/wjh/models/qwen3-8b"
XSERV_URL = "http://localhost:8080"
TOP_K = 10

# 50+ diverse test prompts
TEST_PROMPTS = [
    "What is the capital of France?",
    "Explain quantum computing in simple terms.",
    "Write a Python function to sort a list.",
    "你好,请用中文介绍一下你自己。",
    "What is 2 + 2?",
    "The theory of relativity states that",
    "In a far away galaxy,",
    "def fibonacci(n):",
    "请解释什么是机器学习。",
    "How does photosynthesis work?",
    "What are the benefits of exercise?",
    "Once upon a time in a small village,",
    "The most important invention of the 20th century was",
    "Translate 'hello world' to Japanese.",
    "What is the meaning of life?",
    "Describe the process of making bread.",
    "Why is the sky blue?",
    "What is the difference between AI and ML?",
    "如何评价GPT-4?",
    "Write a haiku about autumn.",
    "Explain the Pythagorean theorem.",
    "What causes earthquakes?",
    "How does the internet work?",
    "What is the speed of light?",
    "Describe the water cycle.",
    "What is democracy?",
    "How do vaccines work?",
    "What is blockchain technology?",
    "Explain supply and demand.",
    "What is the Big Bang theory?",
    "How do airplanes fly?",
    "What is climate change?",
    "Describe the human digestive system.",
    "What is artificial intelligence?",
    "How does electricity work?",
    "What is the solar system?",
    "Explain the concept of gravity.",
    "What is DNA?",
    "How do computers store data?",
    "What is the greenhouse effect?",
    "Describe the structure of an atom.",
    "What is machine learning?",
    "How does Wi-Fi work?",
    "What is the stock market?",
    "Explain natural selection.",
    "What is renewable energy?",
    "How do batteries work?",
    "What is the United Nations?",
    "Describe the process of evolution.",
    "What is cryptography?",
    "请用三句话总结量子力学的核心概念。",
    "用Python写一个计算斐波那契数列的函数。",
]


def _field(parts, prefix):
    """Return the text after the first ``=`` of the first token starting with *prefix*.

    Raises:
        ValueError: if no token in *parts* starts with *prefix*.
    """
    for token in parts:
        if token.startswith(prefix):
            return token.split("=")[1]
    raise ValueError(f"missing {prefix!r} field in line: {' '.join(parts)!r}")


def _parse_xserv_topk(stdout, k=TOP_K):
    """Parse up to *k* ``(token_id, logit)`` pairs from dump-logits stdout.

    Only lines whose stripped text starts with ``[`` are considered; each of
    the first *k* such lines must contain whitespace-separated ``id=<int>``
    and ``logit=<float>`` tokens.

    Raises:
        ValueError: if a considered line lacks a field or a value fails to
            convert.
    """
    entries = [ln.strip() for ln in stdout.strip().split("\n")]
    entries = [ln for ln in entries if ln.startswith("[")]
    top = []
    for entry in entries[:k]:
        parts = entry.split()
        top.append((int(_field(parts, "id=")), float(_field(parts, "logit="))))
    return top


def _max_shared_logit_diff(hf_ids, hf_vals, xserv_top, n=5):
    """Return the largest |HF logit - xserv logit| over tokens in both top-*n* sets.

    Returns ``0.0`` when the two top-*n* sets share no token id.
    """
    xserv_logits = dict(xserv_top[:n])
    diffs = [
        abs(hf_val - xserv_logits[tid])
        for tid, hf_val in zip(hf_ids[:n], hf_vals[:n])
        if tid in xserv_logits
    ]
    return max(diffs, default=0.0)


def logits_correctness_test():
    """Compare xserv prefill logits with HuggingFace transformers.

    For each of the first 10 prompts, takes the last-position top-k logits
    from the HF model and from the ``dump-logits`` binary, then reports top-1
    agreement, top-5 overlap and the largest logit difference on shared tokens.

    Returns:
        A list with one result dict per prompt, or ``None`` when torch or
        transformers cannot be imported.
    """
    print("\n" + "=" * 60)
    print("CORRECTNESS TEST: Comparing logits with HuggingFace")
    print("=" * 60)

    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError:
        print("SKIP: transformers/torch not installed")
        return None

    print(f"Loading HF model from {MODEL_DIR}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,
        device_map="cuda:1",  # Use GPU 1 (xserv uses GPU 0)
        trust_remote_code=True,
    )
    model.eval()

    test_prompts = TEST_PROMPTS[:10]  # Use first 10 for logits comparison
    xserv_bin = "/opt/wjh/projects/xserv/target/release/dump-logits"
    # Loop-invariant: pin the xserv binary to GPU 0 and expose the CUDA tools.
    xserv_env = {
        **os.environ,
        "CUDA_VISIBLE_DEVICES": "0",
        "PATH": "/usr/local/cuda-12.9/bin:" + os.environ.get("PATH", ""),
    }

    results = []
    for i, prompt in enumerate(test_prompts):
        print(f"\n[{i+1}/{len(test_prompts)}] Prompt: {prompt[:50]}...")

        # --- HuggingFace ---
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(**inputs)
        hf_logits = outputs.logits[0, -1, :].float().cpu()
        hf_top = torch.topk(hf_logits, TOP_K)
        hf_ids = hf_top.indices.tolist()
        hf_vals = hf_top.values.tolist()

        # --- xserv ---
        try:
            proc = subprocess.run(
                [xserv_bin, MODEL_DIR, prompt],
                capture_output=True,
                text=True,
                timeout=120,
                env=xserv_env,
            )
            xserv_top = _parse_xserv_topk(proc.stdout)
            if not xserv_top:
                # Without this guard the comparison below would raise
                # IndexError on xserv_ids[0] and abort the whole run.
                stderr_tail = proc.stderr.strip()[-200:]
                raise ValueError(
                    "no top-k lines in dump-logits output "
                    f"(exit code {proc.returncode}): {stderr_tail}"
                )
        except (OSError, subprocess.SubprocessError, ValueError) as e:
            print(f" xserv FAILED: {e}")
            results.append({"prompt": prompt, "match": False, "error": str(e)})
            continue

        # --- Compare ---
        xserv_ids = [tid for tid, _ in xserv_top]
        top1_match = hf_ids[0] == xserv_ids[0]
        top5_overlap = len(set(hf_ids[:5]) & set(xserv_ids[:5]))
        max_diff = _max_shared_logit_diff(hf_ids, hf_vals, xserv_top)

        hf_tok = tokenizer.decode([hf_ids[0]])
        xs_tok = tokenizer.decode([xserv_ids[0]])

        status = "PASS" if top1_match else "WARN"
        print(f" Top-1: HF={hf_ids[0]}({hf_tok!r}) vs xserv={xserv_ids[0]}({xs_tok!r}) → {status}")
        print(f" Top-5 overlap: {top5_overlap}/5, max logit diff: {max_diff:.4f}")

        results.append({
            "prompt": prompt[:50],
            "top1_match": top1_match,
            "top5_overlap": top5_overlap,
            "max_logit_diff": max_diff,
            "hf_top1": f"{hf_ids[0]}({hf_tok})",
            "xserv_top1": f"{xserv_ids[0]}({xs_tok})",
        })

    # Summary (failed prompts count as non-matching with zero overlap)
    print("\n" + "-" * 40)
    top1_matches = sum(1 for r in results if r.get("top1_match"))
    avg_overlap = sum(r.get("top5_overlap", 0) for r in results) / max(len(results), 1)
    print(f"Top-1 match: {top1_matches}/{len(results)}")
    print(f"Avg top-5 overlap: {avg_overlap:.1f}/5")
    print(f"Verdict: {'PASS' if top1_matches >= len(results) * 0.8 else 'FAIL'}")

    # Cleanup
    del model
    torch.cuda.empty_cache()

    return results


def api_generation_test():
    """Test 50+ prompts through the HTTP API.

    Sends every entry of ``TEST_PROMPTS`` to ``/v1/chat/completions`` and
    prints per-request status plus aggregate latency and throughput.

    Returns:
        A list with one result dict per prompt, or ``None`` when the health
        check fails.
    """
    print("\n" + "=" * 60)
    print("API GENERATION TEST: 50+ prompts via /v1/chat/completions")
    print("=" * 60)

    # Health check. An explicit comparison is used instead of `assert`,
    # which is stripped under `python -O`.
    try:
        with urllib.request.urlopen(f"{XSERV_URL}/health", timeout=5) as resp:
            health = resp.read().decode()
        if health != "ok":
            raise RuntimeError(f"unexpected /health response: {health!r}")
        print("Health check: OK")
    except Exception as e:
        print(f"FAIL: Server not reachable at {XSERV_URL}: {e}")
        print("Start the server first: ./target/release/xserv-server /opt/wjh/models/qwen3-8b")
        return None

    # Models endpoint (informational only; a failure here is not fatal)
    try:
        with urllib.request.urlopen(f"{XSERV_URL}/v1/models", timeout=5) as resp:
            models = json.loads(resp.read())
        print(f"Models: {[m['id'] for m in models['data']]}")
    except Exception as e:
        print(f"WARN: /v1/models failed: {e}")

    results = []
    total_prompt_tokens = 0
    total_completion_tokens = 0
    total_latency = 0.0
    timed_requests = 0  # requests whose latency is included in total_latency
    failures = 0

    for i, prompt in enumerate(TEST_PROMPTS):
        body = json.dumps({
            "model": "qwen3-8b",
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 32,
            "temperature": 0.0,
        }).encode()

        # Per-request boundary: any failure is recorded and the run continues.
        try:
            req = urllib.request.Request(
                f"{XSERV_URL}/v1/chat/completions",
                data=body,
                headers={"Content-Type": "application/json"},
            )
            # perf_counter is monotonic; the clock stops only after the whole
            # body has been received.
            t0 = time.perf_counter()
            with urllib.request.urlopen(req, timeout=120) as resp:
                raw = resp.read()
            latency = time.perf_counter() - t0

            data = json.loads(raw)
            choice = data["choices"][0]
            content = choice["message"]["content"] or ""  # null -> empty
            finish = choice["finish_reason"]
            usage = data.get("usage", {})
            pt = usage.get("prompt_tokens", 0)
            ct = usage.get("completion_tokens", 0)

            total_prompt_tokens += pt
            total_completion_tokens += ct
            total_latency += latency
            timed_requests += 1

            # Basic quality checks
            has_content = bool(content.strip())
            if not has_content:
                status = "FAIL"
                failures += 1
            elif ct > 0:
                status = "OK"
            else:
                status = "WARN"

            truncated = content[:60].replace('\n', ' ')
            print(f" [{i+1:2d}/{len(TEST_PROMPTS)}] {status} | {latency:5.2f}s | pt={pt:3d} ct={ct:2d} | {truncated}...")

            results.append({
                "prompt": prompt[:40],
                "status": status,
                "latency": latency,
                "prompt_tokens": pt,
                "completion_tokens": ct,
                "finish_reason": finish,
                "content_preview": content[:80],
            })
        except Exception as e:
            print(f" [{i+1:2d}/{len(TEST_PROMPTS)}] FAIL | {e}")
            failures += 1
            results.append({"prompt": prompt[:40], "status": "FAIL", "error": str(e)})

    # Summary
    successes = len(results) - failures
    # Average over the requests that were timed, so the numerator and the
    # denominator cover the same set of requests.
    avg_latency = total_latency / max(timed_requests, 1)
    tok_per_sec = total_completion_tokens / max(total_latency, 0.001)

    print("\n" + "-" * 40)
    print(f"Results: {successes}/{len(TEST_PROMPTS)} succeeded, {failures} failed")
    print(f"Total prompt tokens: {total_prompt_tokens}")
    print(f"Total completion tokens: {total_completion_tokens}")
    print(f"Average latency: {avg_latency:.2f}s per request")
    print(f"Throughput: {tok_per_sec:.1f} tokens/s (completion only)")
    print(f"Verdict: {'PASS' if failures <= 2 else 'FAIL'}")

    return results


def streaming_test():
    """Test SSE streaming works correctly.

    Sends one streaming chat completion and checks that the stream contains
    a role chunk, a ``finish_reason``, the ``[DONE]`` sentinel and non-empty
    text.

    Returns:
        ``True`` when all four conditions hold, ``False`` otherwise
        (including on any request or parsing error).
    """
    print("\n" + "=" * 60)
    print("STREAMING TEST: SSE /v1/chat/completions?stream=true")
    print("=" * 60)

    body = json.dumps({
        "model": "qwen3-8b",
        "messages": [{"role": "user", "content": "Count from 1 to 5."}],
        "max_tokens": 32,
        "temperature": 0.0,
        "stream": True,
    }).encode()

    req = urllib.request.Request(
        f"{XSERV_URL}/v1/chat/completions",
        data=body,
        headers={"Content-Type": "application/json"},
    )

    try:
        chunk_count = 0
        text_parts = []
        has_role_chunk = False
        has_done = False
        has_finish = False

        with urllib.request.urlopen(req, timeout=60) as resp:
            content_type = resp.headers.get("content-type", "")
            print(f"Content-Type: {content_type}")

            for raw_line in resp:
                line = raw_line.decode().strip()
                if not line.startswith("data: "):
                    continue  # blank separators and non-data SSE fields
                data = line[6:]
                if data == "[DONE]":
                    has_done = True
                    chunk_count += 1
                    continue
                try:
                    obj = json.loads(data)
                except json.JSONDecodeError:
                    print(f" WARN: bad JSON: {data[:80]}")
                    continue
                chunk_count += 1
                choices = obj["choices"]
                if not choices:
                    # A chunk with no choices carries no delta; count it
                    # and move on instead of failing with IndexError.
                    continue
                choice = choices[0]
                delta = choice["delta"]
                if "role" in delta:
                    has_role_chunk = True
                if "content" in delta:
                    text_parts.append(delta["content"])
                if choice.get("finish_reason") is not None:
                    has_finish = True

        full_text = "".join(text_parts)

        print(f"Chunks received: {chunk_count}")
        print(f"Has role chunk: {has_role_chunk}")
        print(f"Has finish_reason: {has_finish}")
        print(f"Has [DONE]: {has_done}")
        print(f"Full text: {full_text[:100]!r}")

        ok = has_role_chunk and has_done and has_finish and len(full_text) > 0

        # SSE content-type check (reported, but not part of the verdict)
        if "text/event-stream" in content_type:
            print("Content-Type: OK (text/event-stream)")
        else:
            print(f"WARN: Expected text/event-stream, got {content_type}")

        print(f"Verdict: {'PASS' if ok else 'FAIL'}")
        return ok

    except Exception as e:
        print(f"FAIL: {e}")
        return False


def main():
    """Parse ``--mode`` and run the selected validation stages in order."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["all", "logits", "api", "stream"], default="all")
    args = parser.parse_args()

    if args.mode in ("all", "logits"):
        logits_correctness_test()

    if args.mode in ("all", "api"):
        api_generation_test()

    if args.mode in ("all", "stream"):
        streaming_test()


if __name__ == "__main__":
    main()