"""B2 PD-colo interference microbench. Concurrently issues: - a steady stream of short-prompt decode-heavy requests to a designated "decode" instance; - a periodic single large-prompt one-token request that hits either the same instance (same-worker) or a separate one (different-worker). The harness writes per-request metrics with a `workload` tag ({"decode", "prefill"}) and a `variant` tag (same/different) plus unix timestamps so the analyzer can label same-worker overlap directly from the engine step JSONL. Outputs under ///: - metrics.jsonl — per-request rows for this cell - run_window.json — t_start_unix/t_end_unix for analyzer slice """ from __future__ import annotations import argparse import asyncio import json import logging import random import time from pathlib import Path import httpx logger = logging.getLogger(__name__) async def _send( client: httpx.AsyncClient, endpoint: str, model: str, prompt_ids: list[int], max_tokens: int, *, workload: str, variant: str, prefill_size: int, out_fh, fh_lock: asyncio.Lock, idx: int, ) -> None: payload = { "model": model, "prompt": prompt_ids, "max_tokens": max_tokens, "min_tokens": max_tokens, "temperature": 0, "return_token_ids": True, "stream": True, "stream_options": {"include_usage": True}, } rid = f"{workload}_{variant}_p{prefill_size}_{idx}" t_dispatch = time.time() ttft = None finish = None n_output = 0 err = None token_times: list[float] = [] try: async with client.stream( "POST", f"{endpoint}/v1/completions", json=payload, headers={"X-Request-Id": rid, "X-Session-Id": rid}, timeout=600.0, ) as resp: resp.raise_for_status() async for raw_line in resp.aiter_lines(): if not raw_line or not raw_line.startswith("data:"): continue data = raw_line[5:].strip() if data == "[DONE]": break try: chunk = json.loads(data) except json.JSONDecodeError: continue choices = chunk.get("choices", []) if choices: now = time.time() token_ids = choices[0].get("token_ids") delta = choices[0].get("text", "") if isinstance(token_ids, list) and token_ids: if ttft is None: ttft = now - t_dispatch token_times.extend([now] * len(token_ids)) elif delta: if ttft is None: ttft = now - t_dispatch token_times.append(now) usage = chunk.get("usage") if usage: n_output = usage.get("completion_tokens", n_output) finish = time.time() except Exception as exc: err = repr(exc)[:300] finish = time.time() tpot = None if len(token_times) > 1: diffs = [token_times[i + 1] - token_times[i] for i in range(len(token_times) - 1)] tpot = sum(diffs) / len(diffs) row = { "request_id": rid, "workload": workload, "variant": variant, "prefill_size": prefill_size, "endpoint": endpoint, "input_length": len(prompt_ids), "max_tokens": max_tokens, "t_dispatch_unix": t_dispatch, "t_finish_unix": finish, "ttft_s": ttft, "tpot_s": tpot, "latency_s": (finish - t_dispatch) if finish else None, "actual_output_tokens": n_output, "error": err, } async with fh_lock: out_fh.write(json.dumps(row, sort_keys=True) + "\n") out_fh.flush() async def decode_load( *, client, endpoint, model, qps, duration_s, workload, variant, prefill_size, out_fh, fh_lock, decode_prompt_tokens, decode_output_tokens, rng, ) -> None: period = 1.0 / qps end_t = time.time() + duration_s pending: list[asyncio.Task] = [] idx = 0 while time.time() < end_t: prompt_ids = [rng.randint(1000, 100000) for _ in range(decode_prompt_tokens)] task = asyncio.create_task(_send( client, endpoint, model, prompt_ids, decode_output_tokens, workload="decode", variant=variant, prefill_size=prefill_size, out_fh=out_fh, fh_lock=fh_lock, idx=idx, )) pending.append(task) idx += 1 await asyncio.sleep(period) await asyncio.gather(*pending, return_exceptions=True) async def prefill_injections( *, client, endpoint, model, prefill_size, n_injections, interval_s, variant, out_fh, fh_lock, start_delay_s, rng, ) -> None: await asyncio.sleep(start_delay_s) for i in range(n_injections): prompt_ids = [rng.randint(1000, 100000) for _ in range(prefill_size)] await _send( client, endpoint, model, prompt_ids, max_tokens=1, workload="prefill", variant=variant, prefill_size=prefill_size, out_fh=out_fh, fh_lock=fh_lock, idx=i, ) await asyncio.sleep(interval_s) async def run_cell( *, decode_endpoint, prefill_endpoint, model, prefill_size, variant, qps, duration_s, n_injections, injection_interval_s, start_delay_s, decode_prompt_tokens, decode_output_tokens, out_dir, ) -> dict: cell_dir = out_dir / variant / f"p{prefill_size}" cell_dir.mkdir(parents=True, exist_ok=True) metrics_path = cell_dir / "metrics.jsonl" fh_lock = asyncio.Lock() rng = random.Random(42 + prefill_size + (0 if variant == "same" else 1000)) t_start = time.time() logger.info("[b2] start variant=%s prefill_size=%d", variant, prefill_size) with metrics_path.open("w", encoding="utf-8") as out_fh: limits = httpx.Limits(max_connections=2000, max_keepalive_connections=500) async with httpx.AsyncClient(timeout=600.0, limits=limits) as client: await asyncio.gather( decode_load( client=client, endpoint=decode_endpoint, model=model, qps=qps, duration_s=duration_s, workload="decode", variant=variant, prefill_size=prefill_size, out_fh=out_fh, fh_lock=fh_lock, decode_prompt_tokens=decode_prompt_tokens, decode_output_tokens=decode_output_tokens, rng=rng, ), prefill_injections( client=client, endpoint=prefill_endpoint, model=model, prefill_size=prefill_size, n_injections=n_injections, interval_s=injection_interval_s, variant=variant, out_fh=out_fh, fh_lock=fh_lock, start_delay_s=start_delay_s, rng=rng, ), ) t_end = time.time() window = { "variant": variant, "prefill_size": prefill_size, "decode_endpoint": decode_endpoint, "prefill_endpoint": prefill_endpoint, "qps": qps, "duration_s": duration_s, "n_injections": n_injections, "injection_interval_s": injection_interval_s, "decode_prompt_tokens": decode_prompt_tokens, "decode_output_tokens": decode_output_tokens, "t_start_unix": t_start, "t_end_unix": t_end, } (cell_dir / "run_window.json").write_text(json.dumps(window, indent=2)) logger.info("[b2] done variant=%s prefill_size=%d wall=%.1fs", variant, prefill_size, t_end - t_start) return window async def amain(args: argparse.Namespace) -> None: sizes = [int(s) for s in args.prefill_sizes.split(",")] out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) overall: list[dict] = [] for size in sizes: for variant in args.variants.split(","): variant = variant.strip() if variant == "same": d_ep, p_ep = args.decode_endpoint, args.decode_endpoint elif variant == "different": d_ep, p_ep = args.decode_endpoint, args.prefill_endpoint else: raise ValueError(f"unknown variant {variant!r}") w = await run_cell( decode_endpoint=d_ep, prefill_endpoint=p_ep, model=args.model, prefill_size=size, variant=variant, qps=args.decode_qps, duration_s=args.duration_s, n_injections=args.injections, injection_interval_s=args.injection_interval_s, start_delay_s=args.start_delay_s, decode_prompt_tokens=args.decode_prompt_tokens, decode_output_tokens=args.decode_output_tokens, out_dir=out_dir, ) overall.append(w) (out_dir / "sweep_meta.json").write_text(json.dumps(overall, indent=2)) def main() -> None: p = argparse.ArgumentParser(description="B2 interference microbench") p.add_argument("--decode-endpoint", required=True, help="e.g. http://127.0.0.1:8000") p.add_argument("--prefill-endpoint", required=True, help="e.g. http://127.0.0.1:8001") p.add_argument("--model", required=True) p.add_argument("--out-dir", required=True) p.add_argument("--prefill-sizes", default="2048,8192,16384,32768,65536") p.add_argument("--variants", default="different,same", help="Comma-separated variants in run order") p.add_argument("--decode-qps", type=float, default=4.0, help="Decode-load arrival rate (req/s)") p.add_argument("--duration-s", type=float, default=60.0, help="Decode-load duration (s) per cell") p.add_argument("--injections", type=int, default=4, help="Number of prefill injections per cell") p.add_argument("--injection-interval-s", type=float, default=12.0) p.add_argument("--start-delay-s", type=float, default=10.0, help="Warmup before first prefill injection") p.add_argument("--decode-prompt-tokens", type=int, default=256) p.add_argument("--decode-output-tokens", type=int, default=100) p.add_argument("-v", "--verbose", action="store_true") args = p.parse_args() logging.basicConfig( level=logging.DEBUG if args.verbose else logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s", ) asyncio.run(amain(args)) if __name__ == "__main__": main()