Rewrote engine.rs from scratch:
- Scheduler loop: admit → prefill → decode → finish → check for new requests
- Multiple sequences run concurrently (max_batch_size is configurable)
- Each sequence has its own independent GpuKVCache
- Non-blocking try_recv() picks up new requests between decode iterations
- Dynamic join: new requests enter the batch immediately instead of waiting for others to finish

Verified with the concurrency test (tools/test_concurrent.py):
- 3 concurrent requests: wall_time=3.8s, concurrency_ratio=2.82x ✓
- 5 concurrent requests: wall_time=6.1s, concurrency_ratio=4.04x ✓
- All outputs are coherent and correct

Design doc (docs/12-continuous-batching.md) fully rewritten with:
- Detailed scheduler-loop pseudocode
- Data structures (Sequence, Scheduler)
- Acceptance criteria with specific test cases
- Clear separation from Phase 13 (HTTP layer)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
108 lines
3.3 KiB
Python
108 lines
3.3 KiB
Python
"""
|
|
Test concurrent request handling.
|
|
Sends N requests simultaneously, verifies they all produce tokens concurrently.
|
|
|
|
Usage: python3 tools/test_concurrent.py <server_url> [num_requests]
|
|
"""
|
|
import sys
|
|
import time
|
|
import json
|
|
import threading
|
|
import urllib.request
|
|
import urllib.error
|
|
|
|
|
|
def send_request(url, prompt, max_tokens, results, idx):
    """Send one chat-completion request and store its outcome in ``results[idx]``.

    Designed as a ``threading.Thread`` target: it never raises, so the slot the
    caller pre-allocated is always filled, even when the request fails.

    Args:
        url: Base server URL; ``/v1/chat/completions`` is appended to it.
        prompt: Text sent as the single user message.
        max_tokens: Value forwarded as the request's ``max_tokens`` field.
        results: Shared list that receives the outcome at index ``idx``.
        idx: Slot in ``results`` written by this call (give each call its own).

    On success ``results[idx]`` becomes
    ``{"status": "ok", "content", "duration_s", "finish_reason"}``;
    on any failure it becomes ``{"status": "error", "error", "duration_s"}``.
    """
    body = json.dumps({
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }).encode()

    # perf_counter is monotonic, so durations are immune to wall-clock jumps.
    t0 = time.perf_counter()
    try:
        # Build the Request inside the try: a malformed url raises ValueError
        # here, and letting it escape would kill the thread and leave
        # results[idx] as None for the caller to trip over.
        req = urllib.request.Request(
            f"{url}/v1/chat/completions",
            data=body,
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req, timeout=120) as resp:
            data = json.loads(resp.read())
        elapsed = time.perf_counter() - t0

        choice = data["choices"][0]
        results[idx] = {
            "status": "ok",
            "content": choice["message"]["content"],
            "duration_s": elapsed,
            "finish_reason": choice["finish_reason"],
        }
    except Exception as e:  # thread boundary: record every failure instead of raising
        results[idx] = {
            "status": "error",
            "error": str(e),
            "duration_s": time.perf_counter() - t0,
        }
|
|
|
|
|
|
def _print_results(results):
    """Print one table row per request slot.

    A ``None`` slot means the worker thread exited without storing a result.
    """
    print(f"\n{'#':>2} {'Status':>6} {'Duration':>8} {'Content':<50}")
    print("-" * 70)
    for i, r in enumerate(results, start=1):
        if r is None:
            print(f"{i:>2} {'FAIL':>6} {'-':>8} no result recorded (worker thread died)")
        elif r["status"] == "ok":
            # The API may return null content; render it as an empty string.
            content_short = (r["content"] or "").replace("\n", " ")[:48]
            print(f"{i:>2} {'OK':>6} {r['duration_s']:>7.1f}s {content_short}")
        else:
            print(f"{i:>2} {'FAIL':>6} {r['duration_s']:>7.1f}s {r['error'][:48]}")


def _report_concurrency(results, wall_s):
    """Print the concurrency analysis for the successful requests.

    Returns False only when the analysis ran and judged the requests serial;
    returns True when it passed or when fewer than two requests succeeded
    (nothing to compare).
    """
    durations = [r["duration_s"] for r in results if r is not None and r["status"] == "ok"]
    if len(durations) < 2:
        return True

    # Overlapping requests make the summed durations exceed the wall time.
    sequential_estimate = sum(durations)
    concurrency_ratio = sequential_estimate / wall_s if wall_s > 0 else 0

    print(f"Sum of individual durations: {sequential_estimate:.1f}s")
    print(f"Actual wall time: {wall_s:.1f}s")
    print(f"Concurrency ratio: {concurrency_ratio:.2f}x")

    if concurrency_ratio > 1.5:
        print("✓ CONCURRENT: requests are being processed in parallel")
        return True
    print("✗ SERIAL: requests appear to be processed sequentially")
    return False


def main():
    """Send ``num_requests`` concurrent chat requests and report on them.

    Command line: ``<server_url> [num_requests]``, defaulting to
    ``http://localhost:9090`` and ``3``. Each request asks for 10 tokens.
    If more requests are asked for than there are built-in prompts, the
    prompts are reused round-robin.

    Returns:
        Exit status: 0 if every request succeeded and the concurrency check
        (when it applies) passed, 1 otherwise, 2 for an invalid request count.
    """
    url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:9090"
    try:
        n = int(sys.argv[2]) if len(sys.argv) > 2 else 3
    except ValueError:
        print(f"num_requests must be an integer, got {sys.argv[2]!r}", file=sys.stderr)
        return 2
    if n < 1:
        print(f"num_requests must be >= 1, got {n}", file=sys.stderr)
        return 2
    max_tokens = 10

    base_prompts = [
        "What is the capital of France?",
        "Tell me about quantum computing",
        "How do airplanes fly?",
        "What is machine learning?",
        "Explain gravity in simple terms",
    ]
    # Cycle the prompts so exactly n requests go out. Slicing alone would cap
    # the count at len(base_prompts) and leave the extra result slots as None.
    prompts = [base_prompts[i % len(base_prompts)] for i in range(n)]

    print(f"Sending {n} concurrent requests to {url} (max_tokens={max_tokens})")
    print("=" * 70)

    results = [None] * n
    threads = [
        threading.Thread(target=send_request, args=(url, prompt, max_tokens, results, i))
        for i, prompt in enumerate(prompts)
    ]

    # perf_counter is monotonic, so the wall time is immune to clock adjustments.
    t_start = time.perf_counter()
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    t_total = time.perf_counter() - t_start

    _print_results(results)
    print("=" * 70)
    print(f"Total wall time: {t_total:.1f}s")

    concurrency_ok = _report_concurrency(results, t_total)

    all_ok = all(r is not None and r["status"] == "ok" for r in results)
    print(f"\nAll requests succeeded: {all_ok}")
    return 0 if all_ok and concurrency_ok else 1
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s status so shell/CI callers can detect failures;
    # sys.exit(None) exits with status 0.
    sys.exit(main())
|