Files
xserv/tools/e2e_validate.py
Gahow Wang ee68d3565d fix: comprehensive review + 14 bug fixes + Phase 12/14 overhaul
Strict code review identified 30+ issues across correctness, performance,
and architecture. This commit addresses 14 of them with verified fixes,
restructures Phase 12 for honest continuous batching, and updates Phase 14
to target FA2 (RTX 5090 SM120 lacks TMEM required by FA4).

Bug fixes:
- FIX-01: Global cuBLAS handle (thread-local singleton, was per-call)
- FIX-02: Remove 19 unnecessary cudaDeviceSynchronize calls from kernels
- FIX-03: Qwen3 ChatML template (was plain text concatenation)
- FIX-04: EOS token from tokenizer (was hardcoded 151645)
- FIX-05: Storage tracks actual GPU device ordinal (was always Cuda(0))
- FIX-06: unsqueeze stride preserves contiguous layout
- FIX-08: CudaDeviceProp replaced with heap buffer (was UB-prone padding)
- FIX-09: Tokenizer byte_fallback to <0xNN> tokens (was panic)

Feature additions:
- FIX-10: SSE streaming (/v1/chat/completions, OpenAI-compatible)
- FIX-11: Correct usage statistics (prompt/completion/total tokens)
- FIX-13: Temperature / top-k / top-p sampling with SamplingParams

Performance improvements:
- FIX-07: Caching allocator wired up (thread-local pool, pooled flag)
- FIX-12: KV cache staging buffers (zero-alloc get_kv_len via borrow_raw)
- FIX-14: GPU strided copy kernel (eliminates contiguous() CPU round-trip)

Architecture:
- Phase 12 engine restructured: prefill/decode separation, honest TODO
  for batched GPU forward (requires Flash Attention)
- Phase 14 updated: FA2 for SM120 (FA4 requires TMEM, absent on 5090)
- Qwen3-7B → Qwen3-8B typo fixed across all docs (36 layers, hidden 4096)

Validated on dash5 (8x RTX 5090):
- 52/52 API prompts pass (EN/CN/code), SSE streaming verified
- Logits match HF transformers 9/10 top-1, 4.0/5 avg top-5 overlap
- 8 concurrent requests: 5.99x scheduling speedup (batch_size=4)
- Throughput: 10.3 tok/s (serial), 30% of HF baseline

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-22 17:53:28 +08:00

395 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
End-to-end validation for xserv after bug fixes.
1. Correctness: compare top-k logits with HuggingFace transformers
2. Generation: run 50+ prompts through the HTTP API
3. Performance: measure latency and throughput
4. Streaming: verify SSE streaming on /v1/chat/completions
Usage:
# Step 1: Start xserv server in background:
# ./target/release/xserv-server /opt/wjh/models/qwen3-8b --port 8080
#
# Step 2: Run this script:
# python3 tools/e2e_validate.py --mode all
# python3 tools/e2e_validate.py --mode logits # correctness only
# python3 tools/e2e_validate.py --mode api # API + perf only
# python3 tools/e2e_validate.py --mode stream # SSE streaming only
"""
import argparse
import json
import time
import subprocess
import sys
import os
from pathlib import Path
# Model checkpoint directory; handed both to HuggingFace `from_pretrained`
# and to the xserv `dump-logits` binary in the logits correctness test.
MODEL_DIR = "/opt/wjh/models/qwen3-8b"
# Base URL of the running xserv HTTP server used by the API and streaming tests.
XSERV_URL = "http://localhost:8080"
# Number of highest logits compared per prompt in the logits correctness test.
TOP_K = 10
# 50+ diverse test prompts (52 entries: English, Chinese, and code-style).
# The API generation test sends every prompt; the logits correctness test
# only uses the first 10.
TEST_PROMPTS = [
"What is the capital of France?",
"Explain quantum computing in simple terms.",
"Write a Python function to sort a list.",
"你好,请用中文介绍一下你自己。",
"What is 2 + 2?",
"The theory of relativity states that",
"In a far away galaxy,",
"def fibonacci(n):",
"请解释什么是机器学习。",
"How does photosynthesis work?",
"What are the benefits of exercise?",
"Once upon a time in a small village,",
"The most important invention of the 20th century was",
"Translate 'hello world' to Japanese.",
"What is the meaning of life?",
"Describe the process of making bread.",
"Why is the sky blue?",
"What is the difference between AI and ML?",
"如何评价GPT-4",
"Write a haiku about autumn.",
"Explain the Pythagorean theorem.",
"What causes earthquakes?",
"How does the internet work?",
"What is the speed of light?",
"Describe the water cycle.",
"What is democracy?",
"How do vaccines work?",
"What is blockchain technology?",
"Explain supply and demand.",
"What is the Big Bang theory?",
"How do airplanes fly?",
"What is climate change?",
"Describe the human digestive system.",
"What is artificial intelligence?",
"How does electricity work?",
"What is the solar system?",
"Explain the concept of gravity.",
"What is DNA?",
"How do computers store data?",
"What is the greenhouse effect?",
"Describe the structure of an atom.",
"What is machine learning?",
"How does Wi-Fi work?",
"What is the stock market?",
"Explain natural selection.",
"What is renewable energy?",
"How do batteries work?",
"What is the United Nations?",
"Describe the process of evolution.",
"What is cryptography?",
"请用三句话总结量子力学的核心概念。",
"用Python写一个计算斐波那契数列的函数。",
]
def logits_correctness_test():
    """Compare xserv prefill logits with HuggingFace transformers.

    For each of the first 10 ``TEST_PROMPTS`` the HuggingFace model (loaded on
    ``cuda:1``) produces the logits of the last position, and the xserv
    ``dump-logits`` binary is run as a subprocess with
    ``CUDA_VISIBLE_DEVICES=0``. The binary's stdout is parsed for lines that
    start with ``[`` and carry ``id=<int>`` and ``logit=<float>`` fields; the
    first ``TOP_K`` such lines are taken as xserv's top tokens.

    A prompt whose xserv run fails, or yields no parseable lines, is recorded
    as a failure instead of aborting the whole test.

    Returns:
        ``None`` if torch/transformers cannot be imported. Otherwise a list
        with one dict per prompt: either
        ``{"prompt", "match": False, "error"}`` for a failed xserv run, or
        ``{"prompt", "top1_match", "top5_overlap", "max_logit_diff",
        "hf_top1", "xserv_top1"}`` for a completed comparison.
    """
    print("\n" + "=" * 60)
    print("CORRECTNESS TEST: Comparing logits with HuggingFace")
    print("=" * 60)
    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError:
        print("SKIP: transformers/torch not installed")
        return None

    print(f"Loading HF model from {MODEL_DIR}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,
        device_map="cuda:1",  # Use GPU 1 (xserv uses GPU 0)
        trust_remote_code=True,
    )
    model.eval()

    test_prompts = TEST_PROMPTS[:10]  # Use first 10 for logits comparison
    xserv_bin = "/opt/wjh/projects/xserv/target/release/dump-logits"
    # Loop-invariant environment for the xserv subprocess: pin it to GPU 0 and
    # put the CUDA 12.9 binaries first on PATH.
    xserv_env = {
        **os.environ,
        "CUDA_VISIBLE_DEVICES": "0",
        "PATH": "/usr/local/cuda-12.9/bin:" + os.environ.get("PATH", ""),
    }

    results = []
    for i, prompt in enumerate(test_prompts):
        print(f"\n[{i+1}/{len(test_prompts)}] Prompt: {prompt[:50]}...")

        # --- HuggingFace ---
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(**inputs)
        hf_logits = outputs.logits[0, -1, :].float().cpu()
        hf_top = torch.topk(hf_logits, TOP_K)
        hf_ids = hf_top.indices.tolist()
        hf_vals = hf_top.values.tolist()

        # --- xserv ---
        try:
            proc = subprocess.run(
                [xserv_bin, MODEL_DIR, prompt],
                capture_output=True, text=True, timeout=120,
                env=xserv_env,
            )
            xserv_lines = [
                l for l in proc.stdout.strip().split('\n')
                if l.strip().startswith('[')
            ]
            xserv_top = []
            for line in xserv_lines[:TOP_K]:
                parts = line.strip().split()
                tid = int([p for p in parts if p.startswith('id=')][0].split('=')[1])
                val = float([p for p in parts if p.startswith('logit=')][0].split('=')[1])
                xserv_top.append((tid, val))
            if not xserv_top:
                # Without this guard the comparison below would index an empty
                # list and crash the whole run instead of failing one prompt.
                stderr_tail = (proc.stderr or "").strip()[-200:]
                raise RuntimeError(
                    f"no logits parsed from dump-logits output "
                    f"(exit code {proc.returncode}): {stderr_tail}"
                )
        except (OSError, subprocess.SubprocessError, ValueError, IndexError, RuntimeError) as e:
            print(f" xserv FAILED: {e}")
            results.append({"prompt": prompt, "match": False, "error": str(e)})
            continue

        # --- Compare ---
        xserv_ids = [tid for tid, _ in xserv_top]
        # Top-1 match
        top1_match = hf_ids[0] == xserv_ids[0]
        # Top-5 overlap
        top5_overlap = len(set(hf_ids[:5]) & set(xserv_ids[:5]))
        # Max logit difference over tokens present in both top-5 lists
        hf_top5 = dict(zip(hf_ids[:5], hf_vals[:5]))
        max_diff = max(
            (abs(hf_top5[xid] - xval) for xid, xval in xserv_top[:5] if xid in hf_top5),
            default=0,
        )

        hf_tok = tokenizer.decode([hf_ids[0]])
        xs_tok = tokenizer.decode([xserv_ids[0]])
        status = "PASS" if top1_match else "WARN"
        print(f" Top-1: HF={hf_ids[0]}({hf_tok!r}) vs xserv={xserv_ids[0]}({xs_tok!r}) → {status}")
        print(f" Top-5 overlap: {top5_overlap}/5, max logit diff: {max_diff:.4f}")
        results.append({
            "prompt": prompt[:50],
            "top1_match": top1_match,
            "top5_overlap": top5_overlap,
            "max_logit_diff": max_diff,
            "hf_top1": f"{hf_ids[0]}({hf_tok})",
            "xserv_top1": f"{xserv_ids[0]}({xs_tok})",
        })

    # Summary: failed prompts count as non-matching with zero overlap.
    print("\n" + "-" * 40)
    top1_matches = sum(1 for r in results if r.get("top1_match"))
    avg_overlap = sum(r.get("top5_overlap", 0) for r in results) / max(len(results), 1)
    print(f"Top-1 match: {top1_matches}/{len(results)}")
    print(f"Avg top-5 overlap: {avg_overlap:.1f}/5")
    # PASS requires at least 80% of prompts to agree on the top-1 token.
    print(f"Verdict: {'PASS' if top1_matches >= len(results) * 0.8 else 'FAIL'}")

    # Cleanup: drop the HF model and release cached GPU memory.
    del model
    torch.cuda.empty_cache()
    return results
def api_generation_test():
    """Test 50+ prompts through the HTTP API.

    Checks ``/health`` (must answer ``ok``), queries ``/v1/models`` (failure is
    only a warning), then posts every entry of ``TEST_PROMPTS`` to
    ``/v1/chat/completions`` with ``max_tokens=32`` and ``temperature=0.0``.
    A request fails when it raises or returns empty content; a response with
    content but zero reported completion tokens is marked ``WARN``.

    Latency covers the full request including reading the response body, and
    the average is taken over the requests that returned a response.

    Returns:
        ``None`` if the health check fails. Otherwise a list with one dict per
        prompt: ``{"prompt", "status", "latency", "prompt_tokens",
        "completion_tokens", "finish_reason", "content_preview"}`` for an
        answered request, or ``{"prompt", "status": "FAIL", "error"}`` when
        the request raised.
    """
    import urllib.request
    import urllib.error

    print("\n" + "=" * 60)
    print("API GENERATION TEST: 50+ prompts via /v1/chat/completions")
    print("=" * 60)

    # Health check
    try:
        req = urllib.request.Request(f"{XSERV_URL}/health")
        with urllib.request.urlopen(req, timeout=5) as resp:
            health_body = resp.read().decode()
        # Explicit check instead of `assert`, which is stripped under -O and
        # would otherwise produce an empty failure message.
        if health_body != "ok":
            raise RuntimeError(f"unexpected /health response: {health_body!r}")
        print("Health check: OK")
    except Exception as e:
        print(f"FAIL: Server not reachable at {XSERV_URL}: {e}")
        print("Start the server first: ./target/release/xserv-server /opt/wjh/models/qwen3-8b")
        return None

    # Models endpoint (best effort: a failure here does not stop the test)
    try:
        req = urllib.request.Request(f"{XSERV_URL}/v1/models")
        with urllib.request.urlopen(req, timeout=5) as resp:
            models = json.loads(resp.read())
        print(f"Models: {[m['id'] for m in models['data']]}")
    except Exception as e:
        print(f"WARN: /v1/models failed: {e}")

    results = []
    total_prompt_tokens = 0
    total_completion_tokens = 0
    total_latency = 0
    answered = 0  # requests that returned a parseable response
    failures = 0
    total = len(TEST_PROMPTS)
    for idx, prompt in enumerate(TEST_PROMPTS, start=1):
        body = json.dumps({
            "model": "qwen3-8b",
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 32,
            "temperature": 0.0,
        }).encode()
        try:
            req = urllib.request.Request(
                f"{XSERV_URL}/v1/chat/completions",
                data=body,
                headers={"Content-Type": "application/json"},
            )
            # Monotonic clock; the interval includes reading the body.
            t0 = time.perf_counter()
            with urllib.request.urlopen(req, timeout=120) as resp:
                raw = resp.read()
            latency = time.perf_counter() - t0

            data = json.loads(raw)
            choice = data["choices"][0]
            # A null content field is treated the same as empty content.
            content = choice["message"]["content"] or ""
            finish = choice["finish_reason"]
            usage = data.get("usage", {})
            pt = usage.get("prompt_tokens", 0)
            ct = usage.get("completion_tokens", 0)

            total_prompt_tokens += pt
            total_completion_tokens += ct
            total_latency += latency
            answered += 1

            # Basic quality checks
            has_content = len(content.strip()) > 0
            reasonable_length = ct > 0
            if not has_content:
                status = "FAIL"
                failures += 1
            elif reasonable_length:
                status = "OK"
            else:
                status = "WARN"

            truncated = content[:60].replace('\n', ' ')
            print(f" [{idx:2d}/{total}] {status} | {latency:5.2f}s | pt={pt:3d} ct={ct:2d} | {truncated}...")
            results.append({
                "prompt": prompt[:40],
                "status": status,
                "latency": latency,
                "prompt_tokens": pt,
                "completion_tokens": ct,
                "finish_reason": finish,
                "content_preview": content[:80],
            })
        except Exception as e:
            # Per-prompt isolation: one bad request must not stop the run.
            print(f" [{idx:2d}/{total}] FAIL | {e}")
            failures += 1
            results.append({"prompt": prompt[:40], "status": "FAIL", "error": str(e)})

    # Summary
    successes = len(results) - failures
    # Average over the requests whose latency was actually accumulated.
    avg_latency = total_latency / max(answered, 1)
    tok_per_sec = total_completion_tokens / max(total_latency, 0.001)
    print("\n" + "-" * 40)
    print(f"Results: {successes}/{total} succeeded, {failures} failed")
    print(f"Total prompt tokens: {total_prompt_tokens}")
    print(f"Total completion tokens: {total_completion_tokens}")
    print(f"Average latency: {avg_latency:.2f}s per request")
    print(f"Throughput: {tok_per_sec:.1f} tokens/s (completion only)")
    # Up to two failed prompts are tolerated.
    print(f"Verdict: {'PASS' if failures <= 2 else 'FAIL'}")
    return results
def streaming_test():
    """Test SSE streaming works correctly.

    Posts one chat completion request with ``"stream": True`` and reads the
    response line by line. Each ``data:`` line is either the ``[DONE]``
    sentinel or a JSON chunk whose ``choices[0].delta`` is inspected for a
    ``role`` key, text ``content``, and a non-null ``finish_reason``.

    The test passes when a role chunk, a finish reason, the ``[DONE]``
    sentinel and some non-empty text were all seen. A response Content-Type
    other than ``text/event-stream`` only produces a warning.

    Returns:
        ``True`` on pass, ``False`` on fail or on any exception during the
        request.
    """
    import urllib.request
    import urllib.error

    print("\n" + "=" * 60)
    print("STREAMING TEST: SSE /v1/chat/completions?stream=true")
    print("=" * 60)

    body = json.dumps({
        "model": "qwen3-8b",
        "messages": [{"role": "user", "content": "Count from 1 to 5."}],
        "max_tokens": 32,
        "temperature": 0.0,
        "stream": True,
    }).encode()
    req = urllib.request.Request(
        f"{XSERV_URL}/v1/chat/completions",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    try:
        chunks = []
        text_parts = []
        has_role_chunk = False
        has_done = False
        has_finish = False
        # `with` closes the streaming connection even if parsing raises.
        with urllib.request.urlopen(req, timeout=60) as resp:
            content_type = resp.headers.get("content-type", "")
            print(f"Content-Type: {content_type}")
            for raw_line in resp:
                line = raw_line.decode().strip()
                # SSE allows "data:" with or without a following space.
                if not line.startswith("data:"):
                    continue
                data = line[len("data:"):].lstrip()
                if data == "[DONE]":
                    has_done = True
                    chunks.append("[DONE]")
                    continue
                try:
                    obj = json.loads(data)
                except json.JSONDecodeError:
                    print(f" WARN: bad JSON: {data[:80]}")
                    continue
                choice = obj["choices"][0]
                delta = choice["delta"]
                if "role" in delta:
                    has_role_chunk = True
                # A null or empty content field contributes no text.
                piece = delta.get("content")
                if piece:
                    text_parts.append(piece)
                if choice.get("finish_reason") is not None:
                    has_finish = True
                chunks.append(delta)

        full_text = "".join(text_parts)
        print(f"Chunks received: {len(chunks)}")
        print(f"Has role chunk: {has_role_chunk}")
        print(f"Has finish_reason: {has_finish}")
        print(f"Has [DONE]: {has_done}")
        print(f"Full text: {full_text[:100]!r}")
        ok = has_role_chunk and has_done and has_finish and len(full_text) > 0

        # SSE content-type check (informational; does not affect the verdict)
        if "text/event-stream" in content_type:
            print("Content-Type: OK (text/event-stream)")
        else:
            print(f"WARN: Expected text/event-stream, got {content_type}")
        print(f"Verdict: {'PASS' if ok else 'FAIL'}")
        return ok
    except Exception as e:
        print(f"FAIL: {e}")
        return False
def main():
    """Parse ``--mode`` and run the selected validation stages in order.

    ``--mode all`` (the default) runs the logits, API and streaming stages;
    any other choice runs only the stage of that name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["all", "logits", "api", "stream"], default="all")
    selected = parser.parse_args().mode

    # Stage name -> entry point, listed in execution order.
    stages = (
        ("logits", logits_correctness_test),
        ("api", api_generation_test),
        ("stream", streaming_test),
    )
    for stage_name, run_stage in stages:
        if selected == "all" or selected == stage_name:
            run_stage()


if __name__ == "__main__":
    main()