Files
xserv/tools/test_concurrent.py
Gahow Wang d8493bd70f phase 12: implement real continuous batching scheduler
Rewrote engine.rs from scratch:
- Scheduler loop: admit → prefill → decode → finish → check new requests
- Multiple sequences run concurrently (max_batch_size configurable)
- Each sequence has independent GpuKVCache
- Non-blocking try_recv() for new requests during decode iterations
- Dynamic join: new requests enter batch immediately, don't wait for others

Verified with concurrent test (tools/test_concurrent.py):
- 3 concurrent requests: wall_time=3.8s, concurrency_ratio=2.82x ✓
- 5 concurrent requests: wall_time=6.1s, concurrency_ratio=4.04x ✓
- All outputs are coherent and correct

Design doc (docs/12-continuous-batching.md) fully rewritten with:
- Detailed scheduler loop pseudocode
- Data structures (Sequence, Scheduler)
- Acceptance criteria with specific test cases
- Clear separation from Phase 13 (HTTP layer)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-22 13:44:26 +08:00

108 lines
3.3 KiB
Python

"""
Smoke-test concurrent request handling of a chat-completions server.

Sends N requests to ``<server_url>/v1/chat/completions`` at the same time
(one thread per request), prints each request's status and latency, then
compares the sum of the individual latencies with the total wall time.
A ratio above 1.5 is reported as concurrent processing; otherwise the
requests are reported as handled sequentially.

Usage: python3 tools/test_concurrent.py [server_url] [num_requests]
    server_url    defaults to http://localhost:9090
    num_requests  defaults to 3
"""
import sys
import time
import json
import threading
import urllib.request
import urllib.error
def send_request(url, prompt, max_tokens, results, idx):
    """Send one chat-completion request and store its outcome in ``results[idx]``.

    Meant to be used as a thread target: it never raises, so every worker
    always leaves a record behind for the caller to report on.

    Args:
        url: Server base URL, e.g. ``http://localhost:9090``.
        prompt: Content of the single user message to send.
        max_tokens: Forwarded as ``max_tokens`` in the JSON request body.
        results: Shared list; only slot ``idx`` is written by this call.
        idx: Index of the slot in ``results`` to fill.

    On success the slot becomes ``{"status": "ok", "content", "duration_s",
    "finish_reason"}``; on any failure it becomes ``{"status": "error",
    "error", "duration_s"}``.
    """
    body = json.dumps({
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }).encode()
    # perf_counter is monotonic, so latencies are immune to wall-clock jumps.
    t0 = time.perf_counter()
    try:
        # Build the Request inside the try: an unparseable URL raises
        # ValueError here, and an exception escaping the thread would leave
        # results[idx] as None for the caller.
        req = urllib.request.Request(
            f"{url}/v1/chat/completions",
            data=body,
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req, timeout=120) as resp:
            data = json.loads(resp.read())
        elapsed = time.perf_counter() - t0
        choice = data["choices"][0]
        results[idx] = {
            "status": "ok",
            "content": choice["message"]["content"],
            "duration_s": elapsed,
            "finish_reason": choice["finish_reason"],
        }
    except Exception as e:  # deliberate: record every failure instead of raising
        results[idx] = {
            "status": "error",
            "error": str(e),
            "duration_s": time.perf_counter() - t0,
        }
def main(argv=None):
    """Fire N chat requests at once and report whether they ran concurrently.

    Command line: ``[server_url] [num_requests]`` with defaults
    ``http://localhost:9090`` and 3. The sample prompts are reused
    round-robin when more requests are asked for than there are samples.

    Args:
        argv: Argument list including the program name; defaults to
            ``sys.argv``. Exposed so the function can be driven directly.

    Returns:
        Exit status: 0 if every request succeeded, 1 if any failed, and
        2 if ``num_requests`` is not a positive integer.
    """
    argv = sys.argv if argv is None else argv
    url = argv[1] if len(argv) > 1 else "http://localhost:9090"
    n = int(argv[2]) if len(argv) > 2 else 3
    if n < 1:
        # A zero/negative count would produce an empty or mis-sized results
        # list; reject it up front.
        print(f"num_requests must be a positive integer, got {n}", file=sys.stderr)
        return 2
    max_tokens = 10
    sample_prompts = [
        "What is the capital of France?",
        "Tell me about quantum computing",
        "How do airplanes fly?",
        "What is machine learning?",
        "Explain gravity in simple terms",
    ]
    # Cycle through the samples so exactly n requests are sent and every
    # slot of `results` gets a worker (a plain slice would cap at 5).
    prompts = [sample_prompts[i % len(sample_prompts)] for i in range(n)]

    print(f"Sending {n} concurrent requests to {url} (max_tokens={max_tokens})")
    print("=" * 70)

    results = [None] * n
    threads = []
    t_start = time.perf_counter()
    for i, prompt in enumerate(prompts):
        t = threading.Thread(target=send_request, args=(url, prompt, max_tokens, results, i))
        threads.append(t)
        t.start()
    for t in threads:
        t.join()
    t_total = time.perf_counter() - t_start

    print(f"\n{'#':>2} {'Status':>6} {'Duration':>8} {'Content':<50}")
    print("-" * 70)
    for i, r in enumerate(results):
        if r is None:
            # The worker thread raised before storing a result; count it as a
            # failure instead of crashing the report.
            r = results[i] = {"status": "error", "error": "no result recorded", "duration_s": float("nan")}
        if r["status"] == "ok":
            # Content may be null in some responses; show it as empty.
            content_short = (r["content"] or "").replace("\n", " ")[:48]
            print(f"{i+1:>2} {'OK':>6} {r['duration_s']:>7.1f}s {content_short}")
        else:
            print(f"{i+1:>2} {'FAIL':>6} {r['duration_s']:>7.1f}s {r['error'][:48]}")
    print("=" * 70)
    print(f"Total wall time: {t_total:.1f}s")

    # Concurrency ratio = sum of per-request latencies / wall time. Serial
    # handling gives about 1.0x; fully overlapped requests push it toward n.
    durations = [r["duration_s"] for r in results if r["status"] == "ok"]
    if len(durations) >= 2:
        sequential_estimate = sum(durations)
        concurrency_ratio = sequential_estimate / t_total if t_total > 0 else 0
        print(f"Sum of individual durations: {sequential_estimate:.1f}s")
        print(f"Actual wall time: {t_total:.1f}s")
        print(f"Concurrency ratio: {concurrency_ratio:.2f}x")
        if concurrency_ratio > 1.5:
            print("✓ CONCURRENT: requests are being processed in parallel")
        else:
            print("✗ SERIAL: requests appear to be processed sequentially")

    all_ok = all(r["status"] == "ok" for r in results)
    print(f"\nAll requests succeeded: {all_ok}")
    return 0 if all_ok else 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status so scripts/CI can
    # detect failures; a None return makes sys.exit exit with status 0.
    sys.exit(main())