Strict code review identified 30+ issues across correctness, performance, and architecture. This commit addresses 14 of them with verified fixes, restructures Phase 12 for honest continuous batching, and updates Phase 14 to target FA2 (RTX 5090 SM120 lacks TMEM required by FA4). Bug fixes: - FIX-01: Global cuBLAS handle (thread-local singleton, was per-call) - FIX-02: Remove 19 unnecessary cudaDeviceSynchronize calls from kernels - FIX-03: Qwen3 ChatML template (was plain text concatenation) - FIX-04: EOS token from tokenizer (was hardcoded 151645) - FIX-05: Storage tracks actual GPU device ordinal (was always Cuda(0)) - FIX-06: unsqueeze stride preserves contiguous layout - FIX-08: CudaDeviceProp replaced with heap buffer (was UB-prone padding) - FIX-09: Tokenizer byte_fallback to <0xNN> tokens (was panic) Feature additions: - FIX-10: SSE streaming (/v1/chat/completions, OpenAI-compatible) - FIX-11: Correct usage statistics (prompt/completion/total tokens) - FIX-13: Temperature / top-k / top-p sampling with SamplingParams Performance improvements: - FIX-07: Caching allocator wired up (thread-local pool, pooled flag) - FIX-12: KV cache staging buffers (zero-alloc get_kv_len via borrow_raw) - FIX-14: GPU strided copy kernel (eliminates contiguous() CPU round-trip) Architecture: - Phase 12 engine restructured: prefill/decode separation, honest TODO for batched GPU forward (requires Flash Attention) - Phase 14 updated: FA2 for SM120 (FA4 requires TMEM, absent on 5090) - Qwen3-7B → Qwen3-8B typo fixed across all docs (36 layers, hidden 4096) Validated on dash5 (8x RTX 5090): - 52/52 API prompts pass (EN/CN/code), SSE streaming verified - Logits match HF transformers 9/10 top-1, 4.0/5 avg top-5 overlap - 8 concurrent requests: 5.99x scheduling speedup (batch_size=4) - Throughput: 10.3 tok/s (serial), 30% of HF baseline Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
395 lines
13 KiB
Python
395 lines
13 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
End-to-end validation for xserv after bug fixes.
|
||
1. Correctness: compare top-k logits with HuggingFace transformers
|
||
2. Generation: run 50+ prompts through the HTTP API
|
||
3. Performance: measure latency and throughput
|
||
|
||
Usage:
|
||
# Step 1: Start xserv server in background:
|
||
# ./target/release/xserv-server /opt/wjh/models/qwen3-8b --port 8080
|
||
#
|
||
# Step 2: Run this script:
|
||
# python3 tools/e2e_validate.py --mode all
|
||
# python3 tools/e2e_validate.py --mode logits # correctness only
|
||
# python3 tools/e2e_validate.py --mode api # API + perf only
# python3 tools/e2e_validate.py --mode stream # SSE streaming only
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import time
|
||
import subprocess
|
||
import sys
|
||
import os
|
||
from pathlib import Path
|
||
|
||
# Local model checkpoint directory. It is handed both to HuggingFace
# `from_pretrained` and, as the first argument, to the xserv dump-logits binary.
MODEL_DIR = "/opt/wjh/models/qwen3-8b"

# Base URL of the already-running xserv HTTP server (see module docstring).
XSERV_URL = "http://localhost:8080"

# How many of the highest-scoring next-token candidates are compared per prompt.
TOP_K = 10

# 50+ diverse test prompts (English, Chinese, and code completions).
# The API test sends all of them; the logits test only uses the first 10.
TEST_PROMPTS = [
    "What is the capital of France?",
    "Explain quantum computing in simple terms.",
    "Write a Python function to sort a list.",
    "你好,请用中文介绍一下你自己。",
    "What is 2 + 2?",
    "The theory of relativity states that",
    "In a far away galaxy,",
    "def fibonacci(n):",
    "请解释什么是机器学习。",
    "How does photosynthesis work?",
    "What are the benefits of exercise?",
    "Once upon a time in a small village,",
    "The most important invention of the 20th century was",
    "Translate 'hello world' to Japanese.",
    "What is the meaning of life?",
    "Describe the process of making bread.",
    "Why is the sky blue?",
    "What is the difference between AI and ML?",
    "如何评价GPT-4?",
    "Write a haiku about autumn.",
    "Explain the Pythagorean theorem.",
    "What causes earthquakes?",
    "How does the internet work?",
    "What is the speed of light?",
    "Describe the water cycle.",
    "What is democracy?",
    "How do vaccines work?",
    "What is blockchain technology?",
    "Explain supply and demand.",
    "What is the Big Bang theory?",
    "How do airplanes fly?",
    "What is climate change?",
    "Describe the human digestive system.",
    "What is artificial intelligence?",
    "How does electricity work?",
    "What is the solar system?",
    "Explain the concept of gravity.",
    "What is DNA?",
    "How do computers store data?",
    "What is the greenhouse effect?",
    "Describe the structure of an atom.",
    "What is machine learning?",
    "How does Wi-Fi work?",
    "What is the stock market?",
    "Explain natural selection.",
    "What is renewable energy?",
    "How do batteries work?",
    "What is the United Nations?",
    "Describe the process of evolution.",
    "What is cryptography?",
    "请用三句话总结量子力学的核心概念。",
    "用Python写一个计算斐波那契数列的函数。",
]
|
||
|
||
|
||
def _xserv_field(fields, prefix):
    """Return the text after ``=`` in the first field that starts with *prefix*.

    Raises:
        ValueError: if no field starts with *prefix*.
    """
    for field in fields:
        if field.startswith(prefix):
            return field.split("=")[1]
    raise ValueError(f"missing {prefix!r} field in dump-logits line: {' '.join(fields)}")


def _parse_xserv_top_k(stdout, k):
    """Parse up to *k* ``(token_id, logit)`` pairs from dump-logits stdout.

    Only lines whose stripped text starts with ``[`` are considered; each one
    must contain whitespace-separated ``id=<int>`` and ``logit=<float>`` fields.

    Raises:
        ValueError: if a considered line lacks a field or holds a bad number.
    """
    rows = [ln.strip() for ln in stdout.strip().split("\n") if ln.strip().startswith("[")]
    pairs = []
    for row in rows[:k]:
        fields = row.split()
        pairs.append((int(_xserv_field(fields, "id=")), float(_xserv_field(fields, "logit="))))
    return pairs


def _compare_top_k(hf_ids, hf_vals, xserv_top, overlap_k=5):
    """Compare HF and xserv rankings.

    Returns:
        ``(top1_match, overlap, max_diff)`` where *overlap* is the number of
        token ids shared by both top-*overlap_k* lists and *max_diff* is the
        largest absolute logit difference among those shared ids (0.0 if none).
    """
    xserv_ids = [tid for tid, _ in xserv_top]
    top1_match = bool(xserv_ids) and hf_ids[0] == xserv_ids[0]
    overlap = len(set(hf_ids[:overlap_k]) & set(xserv_ids[:overlap_k]))

    xserv_logit = dict(xserv_top[:overlap_k])
    max_diff = 0.0
    for hid, hval in zip(hf_ids[:overlap_k], hf_vals[:overlap_k]):
        if hid in xserv_logit:
            max_diff = max(max_diff, abs(hval - xserv_logit[hid]))
    return top1_match, overlap, max_diff


def logits_correctness_test():
    """Compare xserv prefill logits with HuggingFace transformers.

    For each of the first 10 entries of ``TEST_PROMPTS`` the HF model's
    last-position logits are ranked, the xserv ``dump-logits`` binary is run on
    the same prompt, and the two top-``TOP_K`` lists are compared (top-1
    equality, top-5 overlap, max logit difference on shared top-5 tokens).

    A prompt for which xserv cannot be run, or produces no parseable logits
    lines, is recorded as a failure and the loop continues.

    Returns:
        A list with one dict per prompt (comparison metrics, or an ``error``
        entry on failure), or ``None`` when torch/transformers are missing.
    """
    print("\n" + "=" * 60)
    print("CORRECTNESS TEST: Comparing logits with HuggingFace")
    print("=" * 60)

    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError:
        print("SKIP: transformers/torch not installed")
        return None

    print(f"Loading HF model from {MODEL_DIR}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,
        device_map="cuda:1",  # Use GPU 1 (xserv uses GPU 0)
        trust_remote_code=True,
    )
    model.eval()

    test_prompts = TEST_PROMPTS[:10]  # Use first 10 for logits comparison
    xserv_bin = "/opt/wjh/projects/xserv/target/release/dump-logits"
    # Pin the xserv subprocess to GPU 0 and make the CUDA 12.9 tools findable.
    xserv_env = {
        **os.environ,
        "CUDA_VISIBLE_DEVICES": "0",
        "PATH": "/usr/local/cuda-12.9/bin:" + os.environ.get("PATH", ""),
    }

    results = []
    for i, prompt in enumerate(test_prompts):
        print(f"\n[{i+1}/{len(test_prompts)}] Prompt: {prompt[:50]}...")

        # --- HuggingFace ---
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Logits of the last position, i.e. the next-token distribution.
        hf_logits = outputs.logits[0, -1, :].float().cpu()
        hf_top = torch.topk(hf_logits, TOP_K)
        hf_ids = hf_top.indices.tolist()
        hf_vals = hf_top.values.tolist()

        # --- xserv ---
        try:
            result = subprocess.run(
                [xserv_bin, MODEL_DIR, prompt],
                capture_output=True, text=True, timeout=120,
                env=xserv_env,
            )
            xserv_top = _parse_xserv_top_k(result.stdout, TOP_K)
            if not xserv_top:
                # Without this guard the comparison below would hit an
                # IndexError on an empty list and abort the whole test run.
                stderr_tail = (result.stderr or "").strip()[-200:]
                raise RuntimeError(
                    f"no logits in dump-logits output "
                    f"(exit code {result.returncode}): {stderr_tail}"
                )
        except (OSError, subprocess.SubprocessError, ValueError, RuntimeError) as e:
            print(f" xserv FAILED: {e}")
            results.append({
                "prompt": prompt,
                "match": False,
                "top1_match": False,
                "top5_overlap": 0,
                "error": str(e),
            })
            continue

        # --- Compare ---
        xserv_ids = [tid for tid, _ in xserv_top]
        top1_match, top5_overlap, max_diff = _compare_top_k(hf_ids, hf_vals, xserv_top)

        hf_tok = tokenizer.decode([hf_ids[0]])
        xs_tok = tokenizer.decode([xserv_ids[0]])

        status = "PASS" if top1_match else "WARN"
        print(f" Top-1: HF={hf_ids[0]}({hf_tok!r}) vs xserv={xserv_ids[0]}({xs_tok!r}) → {status}")
        print(f" Top-5 overlap: {top5_overlap}/5, max logit diff: {max_diff:.4f}")

        results.append({
            "prompt": prompt[:50],
            "top1_match": top1_match,
            "top5_overlap": top5_overlap,
            "max_logit_diff": max_diff,
            "hf_top1": f"{hf_ids[0]}({hf_tok})",
            "xserv_top1": f"{xserv_ids[0]}({xs_tok})",
        })

    # Summary (failed prompts count as a top-1 miss with zero overlap)
    print("\n" + "-" * 40)
    top1_matches = sum(1 for r in results if r.get("top1_match"))
    avg_overlap = sum(r.get("top5_overlap", 0) for r in results) / max(len(results), 1)
    print(f"Top-1 match: {top1_matches}/{len(results)}")
    print(f"Avg top-5 overlap: {avg_overlap:.1f}/5")
    print(f"Verdict: {'PASS' if top1_matches >= len(results) * 0.8 else 'FAIL'}")

    # Cleanup: drop the HF model and release cached GPU memory.
    del model
    torch.cuda.empty_cache()

    return results
|
||
|
||
|
||
def api_generation_test():
    """Send every prompt in ``TEST_PROMPTS`` through the HTTP API.

    Steps: probe ``/health`` (abort if it does not answer ``ok``), list
    ``/v1/models`` (warning only on failure), then POST each prompt to
    ``/v1/chat/completions`` with ``max_tokens=32`` and ``temperature=0.0``,
    recording latency and the reported token usage.

    A response with empty content is a failure; non-empty content with a
    zero ``completion_tokens`` count is only a warning. The verdict passes
    when at most two prompts failed.

    Returns:
        A list with one result dict per prompt, or ``None`` when the server
        health check fails.
    """
    print("\n" + "=" * 60)
    print("API GENERATION TEST: 50+ prompts via /v1/chat/completions")
    print("=" * 60)

    import urllib.request
    import urllib.error

    # Health check
    try:
        req = urllib.request.Request(f"{XSERV_URL}/health")
        with urllib.request.urlopen(req, timeout=5) as resp:
            health_body = resp.read().decode()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if health_body != "ok":
            raise RuntimeError(f"unexpected /health response: {health_body!r}")
        print("Health check: OK")
    except Exception as e:
        print(f"FAIL: Server not reachable at {XSERV_URL}: {e}")
        print("Start the server first: ./target/release/xserv-server /opt/wjh/models/qwen3-8b")
        return None

    # Models endpoint (informational; failure does not abort the test)
    try:
        req = urllib.request.Request(f"{XSERV_URL}/v1/models")
        with urllib.request.urlopen(req, timeout=5) as resp:
            models = json.loads(resp.read())
        print(f"Models: {[m['id'] for m in models['data']]}")
    except Exception as e:
        print(f"WARN: /v1/models failed: {e}")

    results = []
    total_prompt_tokens = 0
    total_completion_tokens = 0
    total_latency = 0.0
    timed_requests = 0  # requests whose latency went into total_latency
    failures = 0

    for i, prompt in enumerate(TEST_PROMPTS):
        body = json.dumps({
            "model": "qwen3-8b",
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 32,
            "temperature": 0.0,
        }).encode()

        # Broad catch is deliberate: any per-prompt error (HTTP, timeout,
        # malformed JSON, missing keys) is recorded and the run continues.
        try:
            req = urllib.request.Request(
                f"{XSERV_URL}/v1/chat/completions",
                data=body,
                headers={"Content-Type": "application/json"},
            )
            # Monotonic clock; the measurement covers reading the full body.
            t0 = time.perf_counter()
            with urllib.request.urlopen(req, timeout=120) as resp:
                raw = resp.read()
            latency = time.perf_counter() - t0
            data = json.loads(raw)

            choice = data["choices"][0]
            content = choice["message"]["content"] or ""  # tolerate JSON null
            finish = choice["finish_reason"]
            usage = data.get("usage") or {}
            pt = usage.get("prompt_tokens", 0)
            ct = usage.get("completion_tokens", 0)

            total_prompt_tokens += pt
            total_completion_tokens += ct
            total_latency += latency
            timed_requests += 1

            # Basic quality checks
            has_content = len(content.strip()) > 0
            reasonable_length = ct > 0

            if not has_content:
                status = "FAIL"
                failures += 1
            elif reasonable_length:
                status = "OK"
            else:
                status = "WARN"

            truncated = content[:60].replace('\n', ' ')
            print(f" [{i+1:2d}/{len(TEST_PROMPTS)}] {status} | {latency:5.2f}s | pt={pt:3d} ct={ct:2d} | {truncated}...")

            results.append({
                "prompt": prompt[:40],
                "status": status,
                "latency": latency,
                "prompt_tokens": pt,
                "completion_tokens": ct,
                "finish_reason": finish,
                "content_preview": content[:80],
            })

        except Exception as e:
            print(f" [{i+1:2d}/{len(TEST_PROMPTS)}] FAIL | {e}")
            failures += 1
            results.append({"prompt": prompt[:40], "status": "FAIL", "error": str(e)})

    # Summary
    successes = len(results) - failures
    # Average over every request that contributed latency, so empty-content
    # failures do not inflate the per-request figure.
    avg_latency = total_latency / max(timed_requests, 1)
    tok_per_sec = total_completion_tokens / max(total_latency, 0.001)

    print("\n" + "-" * 40)
    print(f"Results: {successes}/{len(TEST_PROMPTS)} succeeded, {failures} failed")
    print(f"Total prompt tokens: {total_prompt_tokens}")
    print(f"Total completion tokens: {total_completion_tokens}")
    print(f"Average latency: {avg_latency:.2f}s per request")
    print(f"Throughput: {tok_per_sec:.1f} tokens/s (completion only)")
    print(f"Verdict: {'PASS' if failures <= 2 else 'FAIL'}")

    return results
|
||
|
||
|
||
def streaming_test():
    """Test SSE streaming works correctly.

    Sends one streaming chat completion request and reads the response line
    by line. The test passes when the stream contained a delta with a
    ``role`` key, at least one non-null ``finish_reason``, the ``[DONE]``
    sentinel, and a non-empty concatenation of all ``content`` deltas. A
    Content-Type other than ``text/event-stream`` only produces a warning.

    Returns:
        ``True`` on pass, ``False`` on fail or on any request/parse error.
    """
    print("\n" + "=" * 60)
    print("STREAMING TEST: SSE /v1/chat/completions?stream=true")
    print("=" * 60)

    import urllib.request
    import urllib.error

    body = json.dumps({
        "model": "qwen3-8b",
        "messages": [{"role": "user", "content": "Count from 1 to 5."}],
        "max_tokens": 32,
        "temperature": 0.0,
        "stream": True,
    }).encode()

    req = urllib.request.Request(
        f"{XSERV_URL}/v1/chat/completions",
        data=body,
        headers={"Content-Type": "application/json"},
    )

    try:
        chunks = []
        text_parts = []
        has_role_chunk = False
        has_done = False
        has_finish = False

        # `with` closes the connection even if parsing raises midway.
        with urllib.request.urlopen(req, timeout=60) as resp:
            content_type = resp.headers.get("content-type", "")
            print(f"Content-Type: {content_type}")

            for raw_line in resp:
                line = raw_line.decode().strip()
                # The space after "data:" is optional in the SSE format.
                if not line.startswith("data:"):
                    continue
                payload = line[5:].strip()
                if payload == "[DONE]":
                    has_done = True
                    chunks.append("[DONE]")
                    continue
                try:
                    obj = json.loads(payload)
                except json.JSONDecodeError:
                    print(f" WARN: bad JSON: {payload[:80]}")
                    continue

                choices = obj.get("choices") or []
                if not choices:
                    # Chunk without choices: nothing to check, skip it.
                    continue
                delta = choices[0].get("delta") or {}
                if "role" in delta:
                    has_role_chunk = True
                # `content` may be absent or JSON null; both mean "no text".
                piece = delta.get("content")
                if piece:
                    text_parts.append(piece)
                if choices[0].get("finish_reason") is not None:
                    has_finish = True
                chunks.append(delta)

        full_text = "".join(text_parts)

        print(f"Chunks received: {len(chunks)}")
        print(f"Has role chunk: {has_role_chunk}")
        print(f"Has finish_reason: {has_finish}")
        print(f"Has [DONE]: {has_done}")
        print(f"Full text: {full_text[:100]!r}")

        ok = has_role_chunk and has_done and has_finish and len(full_text) > 0
        # SSE content-type check (informational, not part of the verdict)
        if "text/event-stream" in content_type:
            print("Content-Type: OK (text/event-stream)")
        else:
            print(f"WARN: Expected text/event-stream, got {content_type}")

        print(f"Verdict: {'PASS' if ok else 'FAIL'}")
        return ok

    except Exception as e:
        print(f"FAIL: {e}")
        return False
|
||
|
||
|
||
def main():
    """Parse ``--mode`` and run the matching validation stage(s).

    ``all`` (the default) runs the logits, API, and streaming stages in that
    order; any other value runs just the stage of that name.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--mode", choices=["all", "logits", "api", "stream"], default="all")
    selected = cli.parse_args().mode

    # Stage name -> entry point, listed in execution order.
    stages = (
        ("logits", logits_correctness_test),
        ("api", api_generation_test),
        ("stream", streaming_test),
    )
    for stage_name, run_stage in stages:
        if selected in ("all", stage_name):
            run_stage()
|
||
|
||
|
||
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|