Microbench: prefill-decode interference + PD transfer lifecycle
Two microbenchmarks quantifying the elastic offload decision:
1. Interference (corrected): cold prefill causes 14-214x TPOT p90
degradation on same-worker decode (D∈{1,2,4,8} × P∈{2k,8k,16k,32k}).
The earlier run had a prefix-cache bug: its prompts were deterministic, so
every repetition after rep 0 hit the prefix cache. Fixed by generating unique
prompts from uuid + time_ns.
2. Transfer lifecycle: PD-sep TTFT breakdown via Mooncake proxy,
measuring prefill→RDMA→decode startup overhead.
Key finding: offload wins at all P≥2048 operating points —
transfer cost is 25-50% of interference cost even with bulk Mooncake.
This commit is contained in:
422
microbench/interference/driver.py
Normal file
422
microbench/interference/driver.py
Normal file
@@ -0,0 +1,422 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Prefill-Decode Interference Microbenchmark Driver.
|
||||
|
||||
Measures TPOT degradation caused by prefill chunks interfering with ongoing decode batches.
|
||||
Produces: f(decode_batch_size, new_prefill_tokens, chunk_size) -> TPOT_penalty_ms
|
||||
|
||||
Usage:
|
||||
python driver.py --host 127.0.0.1 --port 8000 \
|
||||
--decode-batch-sizes 0,1,2,4,6,8,12 \
|
||||
--prefill-tokens 512,1024,2048,4096,8192,16384,32768 \
|
||||
--reps 5 --output-dir results/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Number of leading tokens of each decode stream that are excluded from TPOT
# statistics (passed as ``skip_first`` / used as the loop start offset).
WARMUP_TOKENS = 32

# Number of tokens intended for measurement; together with WARMUP_TOKENS it
# sizes ``max_tokens`` of every decode request.
MEASURE_WINDOW_TOKENS = 500

# Prompt sent verbatim by every decode stream. One instruction paragraph is
# repeated 50 times: ~4k tokens worth of repeated text for prefix cache sharing.
FIXED_SEED_PROMPT = 50 * (
    "You are a helpful assistant. Please analyze the following document carefully "
    "and provide a comprehensive summary covering all key points, main arguments, "
    "supporting evidence, and conclusions. The document discusses various aspects "
    "of distributed systems, including consensus protocols, fault tolerance mechanisms, "
    "and performance optimization strategies for large-scale deployments.\n\n"
)
|
||||
|
||||
|
||||
@dataclass
class Config:
    """Identifies one benchmark run; stored under ``"config"`` in each result JSON."""

    # Number of concurrent decode streams (D).
    decode_batch_size: int
    # Requested size, in tokens, of the injected prefill prompt (P).
    new_prefill_tokens: int
    # Server-side chunk size the run was taken with; recorded only, the
    # driver never sends it to the server.
    chunk_size: int
    # Model name placed in every request payload.
    model: str
    # Zero-based repetition index of this (D, P) configuration.
    repetition: int
|
||||
|
||||
|
||||
@dataclass
class BaselineResult:
    """Decode-only TPOT statistics, pooled over all decode streams of a run."""

    # Median inter-token interval in milliseconds.
    tpot_p50_ms: float
    # 90th-percentile inter-token interval in milliseconds.
    tpot_p90_ms: float
    # 99th-percentile inter-token interval in milliseconds.
    tpot_p99_ms: float
    # Number of inter-token intervals the percentiles were computed from.
    tokens_collected: int
|
||||
|
||||
|
||||
@dataclass
class InterferenceResult:
    """Decode TPOT statistics observed around one injected prefill request."""

    # Median decode interval (ms) among intervals that start while the
    # prefill request is in flight.
    tpot_during_prefill_p50_ms: float
    # 90th percentile of the same "during prefill" intervals (ms).
    tpot_during_prefill_p90_ms: float
    # Median decode interval (ms) once the prefill has finished and settled.
    tpot_after_prefill_p50_ms: float
    # Time from sending the prefill request to its first streamed chunk (ms).
    prefill_ttft_ms: float
    # Number of decode intervals that fell into the "during prefill" bucket.
    num_tokens_during_prefill: int
|
||||
|
||||
|
||||
async def stream_tokens(client: httpx.AsyncClient, url: str, payload: dict) -> list[float]:
    """POST a streaming chat request and timestamp every content chunk.

    Args:
        client: HTTP client used to open the stream.
        url: Endpoint that receives the POST.
        payload: JSON request body.

    Returns:
        ``time.perf_counter()`` readings (seconds), one per counted chunk,
        in arrival order.

    Raises:
        httpx.HTTPStatusError: if the response status is an error.
    """
    arrivals: list[float] = []
    async with client.stream("POST", url, json=payload, timeout=300.0) as response:
        response.raise_for_status()
        async for raw_line in response.aiter_lines():
            # Only "data: ..." lines carry events; everything else is skipped.
            if not raw_line.startswith("data: "):
                continue
            body = raw_line[6:]
            if body.strip() == "[DONE]":
                break
            try:
                event = json.loads(body)
            except json.JSONDecodeError:
                # Malformed event: ignore it and keep reading.
                continue
            choices = event.get("choices", [])
            if not choices:
                continue
            # Chunks whose delta carries a "role" key are not counted.
            # NOTE(review): every other chunk is counted, whether or not its
            # delta contains text — confirm this matches one chunk per token.
            if "role" in choices[0].get("delta", {}):
                continue
            arrivals.append(time.perf_counter())
    return arrivals
|
||||
|
||||
|
||||
def compute_tpot(timestamps: list[float], skip_first: int = 0) -> np.ndarray:
    """Return consecutive inter-token gaps in milliseconds.

    Args:
        timestamps: Token arrival times in seconds, in arrival order.
        skip_first: Number of leading timestamps to discard before differencing.

    Returns:
        Array of gaps (ms) between the remaining timestamps; empty when fewer
        than two timestamps remain after skipping.
    """
    # At least two timestamps must survive the skip to form one interval.
    if len(timestamps) - skip_first < 2:
        return np.array([])
    kept = np.array(timestamps[skip_first:])
    gaps_s = kept[1:] - kept[:-1]
    return gaps_s * 1000.0  # seconds → ms
|
||||
|
||||
|
||||
def make_decode_payload(model: str) -> dict:
    """Build the streaming chat request used for every decode stream.

    All decode streams send the same FIXED_SEED_PROMPT and ask for enough
    tokens to cover the warm-up, the measurement window, and 50 spare tokens.
    """
    token_budget = WARMUP_TOKENS + MEASURE_WINDOW_TOKENS + 50
    user_turn = {"role": "user", "content": FIXED_SEED_PROMPT}
    return dict(
        model=model,
        messages=[user_turn],
        max_tokens=token_budget,
        temperature=0,
        stream=True,
    )
|
||||
|
||||
|
||||
def make_prefill_payload(model: str, num_tokens: int) -> dict:
    """Build a one-token streaming request whose prompt is unique per call.

    Args:
        model: Model name placed in the payload.
        num_tokens: Approximate prompt size in tokens.

    Returns:
        Chat-completions payload with ``max_tokens=1`` and ``stream=True``.
    """
    import hashlib
    import uuid

    # A fresh tag per call makes every block differ from all earlier prompts,
    # to guarantee zero prefix cache hits.
    call_tag = f"{uuid.uuid4().hex}_{time.time_ns()}"
    # Calibration: each "Block N: <32-hex>" → ~35 tokens after tokenization.
    block_count = max(1, num_tokens // 35)
    blocks = (
        f"Block {idx}: {hashlib.md5(f'{call_tag}_{idx}'.encode()).hexdigest()}"
        for idx in range(block_count)
    )
    user_turn = {"role": "user", "content": " ".join(blocks)}
    return dict(
        model=model,
        messages=[user_turn],
        max_tokens=1,
        temperature=0,
        stream=True,
    )
|
||||
|
||||
|
||||
async def wait_for_steady_state(decode_streams: list[asyncio.Task], min_tokens: int = 32):
    """Pause for a fixed settling period before measuring.

    Neither argument is inspected: the streams' progress is not observable
    from here, so this simply sleeps. At ~50ms/token, 32 tokens ≈ 1.6s; the
    3s pause leaves headroom over that estimate.
    """
    settle_seconds = 3.0
    await asyncio.sleep(settle_seconds)
|
||||
|
||||
|
||||
async def run_baseline(
    client: httpx.AsyncClient, url: str, model: str, decode_batch_size: int
) -> Optional[BaselineResult]:
    """Measure decode-only TPOT with no prefill request in flight.

    Starts ``decode_batch_size`` identical decode streams concurrently and
    pools their post-warm-up inter-token intervals.

    Returns:
        Percentiles over the pooled intervals; an all-zero result when
        ``decode_batch_size`` is 0; ``None`` when no interval was collected.
    """
    if decode_batch_size == 0:
        # Nothing to decode: report an explicit zero baseline.
        return BaselineResult(tpot_p50_ms=0, tpot_p90_ms=0, tpot_p99_ms=0, tokens_collected=0)

    requests = [make_decode_payload(model) for _ in range(decode_batch_size)]
    streams = [asyncio.create_task(stream_tokens(client, url, req)) for req in requests]
    outcomes = await asyncio.gather(*streams, return_exceptions=True)

    pooled = []
    for outcome in outcomes:
        if isinstance(outcome, Exception):
            # A failed stream is reported and left out of the statistics.
            print(f" [WARN] decode stream error: {outcome}")
            continue
        pooled.extend(compute_tpot(outcome, skip_first=WARMUP_TOKENS).tolist())

    if not pooled:
        return None

    intervals = np.array(pooled)
    p50 = float(np.percentile(intervals, 50))
    p90 = float(np.percentile(intervals, 90))
    p99 = float(np.percentile(intervals, 99))
    return BaselineResult(
        tpot_p50_ms=p50,
        tpot_p90_ms=p90,
        tpot_p99_ms=p99,
        tokens_collected=len(intervals),
    )
|
||||
|
||||
|
||||
async def run_interference(
    client: httpx.AsyncClient,
    url: str,
    model: str,
    decode_batch_size: int,
    new_prefill_tokens: int,
) -> Optional[InterferenceResult]:
    """Measure decode TPOT while one prefill request is being processed.

    Starts ``decode_batch_size`` decode streams, injects a single unique
    prefill request one second later, and buckets each post-warm-up decode
    interval by when it occurred relative to the prefill.

    Args:
        client: HTTP client shared by all requests.
        url: Chat-completions endpoint.
        model: Model name placed in every payload.
        decode_batch_size: Number of concurrent decode streams; 0 measures
            only the prefill TTFT.
        new_prefill_tokens: Approximate size of the injected prompt in tokens.

    Returns:
        The measured statistics, or ``None`` when the prefill request failed
        or produced no chunk, or when every decode stream failed. An empty
        "during" or "after" bucket is reported as 0.0 (check
        ``num_tokens_during_prefill`` to tell the cases apart).
    """
    # Build the prompt before any timestamp is taken so that prompt generation
    # is never charged to the prefill TTFT.
    prefill_payload = make_prefill_payload(model, new_prefill_tokens)

    if decode_batch_size == 0:
        # No decode to interfere with; just measure prefill TTFT.
        t_start = time.perf_counter()
        stamps = await stream_tokens(client, url, prefill_payload)
        if not stamps:
            print(" [WARN] prefill request returned no tokens")
            return None
        return InterferenceResult(
            tpot_during_prefill_p50_ms=0,
            tpot_during_prefill_p90_ms=0,
            tpot_after_prefill_p50_ms=0,
            prefill_ttft_ms=(stamps[0] - t_start) * 1000.0,
            num_tokens_during_prefill=0,
        )

    async def inject_prefill() -> tuple[float, list[float]]:
        # Give the decode streams a head start, then send the prefill.
        await asyncio.sleep(1.0)
        injected_at = time.perf_counter()
        stamps = await stream_tokens(client, url, prefill_payload)
        return injected_at, stamps

    # Launch the decode streams first, then the delayed prefill.
    decode_tasks = [
        asyncio.create_task(stream_tokens(client, url, make_decode_payload(model)))
        for _ in range(decode_batch_size)
    ]
    prefill_task = asyncio.create_task(inject_prefill())

    *decode_outcomes, prefill_outcome = await asyncio.gather(
        *decode_tasks, prefill_task, return_exceptions=True
    )

    # Without a prefill timestamp there is no window to classify against, so
    # report failure instead of a misleading all-zero measurement.
    if isinstance(prefill_outcome, BaseException):
        print(f" [WARN] prefill request error: {prefill_outcome}")
        return None
    prefill_inject_time, prefill_stamps = prefill_outcome
    if not prefill_stamps:
        print(" [WARN] prefill request returned no tokens")
        return None

    decode_streams = []
    for outcome in decode_outcomes:
        if isinstance(outcome, BaseException):
            print(f" [WARN] decode stream error: {outcome}")
            continue
        decode_streams.append(outcome)
    if not decode_streams:
        return None

    # The prefill counts as finished when its first chunk arrives.
    prefill_end_time = prefill_stamps[0]
    prefill_ttft_ms = (prefill_end_time - prefill_inject_time) * 1000.0

    # Split decode intervals into "during prefill" and "after prefill".
    tpot_during = []
    tpot_after = []
    for stamps in decode_streams:
        if len(stamps) < WARMUP_TOKENS + 5:
            # Too short to contain meaningful post-warm-up intervals.
            continue
        steady = stamps[WARMUP_TOKENS:]
        for t_prev, t_curr in zip(steady, steady[1:]):
            interval_ms = (t_curr - t_prev) * 1000.0
            if prefill_inject_time <= t_prev <= prefill_end_time:
                tpot_during.append(interval_ms)
            elif t_curr > prefill_end_time + 0.05:  # 50ms after prefill settles
                tpot_after.append(interval_ms)

    during_arr = np.array(tpot_during) if tpot_during else np.array([0.0])
    after_arr = np.array(tpot_after) if tpot_after else np.array([0.0])

    return InterferenceResult(
        tpot_during_prefill_p50_ms=float(np.percentile(during_arr, 50)),
        tpot_during_prefill_p90_ms=float(np.percentile(during_arr, 90)),
        tpot_after_prefill_p50_ms=float(np.percentile(after_arr, 50)),
        prefill_ttft_ms=prefill_ttft_ms,
        num_tokens_during_prefill=len(tpot_during),
    )
|
||||
|
||||
|
||||
async def run_single_config(
    client: httpx.AsyncClient,
    url: str,
    model: str,
    decode_batch_size: int,
    new_prefill_tokens: int,
    chunk_size: int,
    rep: int,
    output_dir: Path,
):
    """Run one repetition of a (D, P) configuration and save it as JSON.

    Measures the decode-only baseline, pauses briefly, measures the same
    decode load with a prefill injected, then writes both results plus the
    derived penalty to ``D<d>_P<p>_rep<rep>.json`` in ``output_dir``. Nothing
    is written when either measurement returns ``None``.
    """
    print(f" [rep {rep}] Running baseline (D={decode_batch_size})...")
    baseline = await run_baseline(client, url, model, decode_batch_size)
    if baseline is None:
        print(f" [rep {rep}] Baseline failed, skipping")
        return

    # Brief cooldown between baseline and interference.
    await asyncio.sleep(2.0)

    print(f" [rep {rep}] Running interference (D={decode_batch_size}, P={new_prefill_tokens})...")
    interference = await run_interference(
        client, url, model, decode_batch_size, new_prefill_tokens
    )
    if interference is None:
        print(f" [rep {rep}] Interference measurement failed, skipping")
        return

    # Derived metrics: absolute and relative slowdown of the median TPOT.
    base_p50 = baseline.tpot_p50_ms
    during_p50 = interference.tpot_during_prefill_p50_ms
    tpot_penalty_p50 = during_p50 - base_p50
    if base_p50 > 0:
        penalty_ratio = during_p50 / base_p50
    else:
        # No usable baseline (e.g. D=0): report a ratio of 0.
        penalty_ratio = 0

    run_config = Config(
        decode_batch_size=decode_batch_size,
        new_prefill_tokens=new_prefill_tokens,
        chunk_size=chunk_size,
        model=model,
        repetition=rep,
    )
    record = {
        "config": asdict(run_config),
        "baseline": asdict(baseline),
        "interference": asdict(interference),
        "derived": {
            "tpot_penalty_p50_ms": tpot_penalty_p50,
            "tpot_penalty_ratio": penalty_ratio,
        },
    }

    target = output_dir / f"D{decode_batch_size}_P{new_prefill_tokens}_rep{rep}.json"
    target.write_text(json.dumps(record, indent=2))
    print(f" [rep {rep}] Done. penalty={tpot_penalty_p50:.1f}ms ratio={penalty_ratio:.2f}")
|
||||
|
||||
|
||||
async def main():
    """Parse CLI options, sweep every (D, P) pair, and write the summary CSV.

    Results land in ``<output-dir>/chunk<chunk-size>/``. The sweep is skipped
    entirely when the server's ``/v1/models`` endpoint cannot be reached.
    """
    parser = argparse.ArgumentParser(description="Prefill-Decode Interference Microbenchmark")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", default="Qwen3-Coder-30B-A3B-Instruct")
    parser.add_argument(
        "--decode-batch-sizes",
        default="0,1,2,4,6,8,12",
        help="Comma-separated decode batch sizes",
    )
    parser.add_argument(
        "--prefill-tokens",
        default="512,1024,2048,4096,8192,16384,32768",
        help="Comma-separated prefill token counts",
    )
    parser.add_argument(
        "--chunk-size",
        type=int,
        default=8192,
        help="vLLM max_num_batched_tokens (effective chunk size)",
    )
    parser.add_argument("--reps", type=int, default=5)
    parser.add_argument("--output-dir", default="results/interference")
    args = parser.parse_args()

    decode_sizes = list(map(int, args.decode_batch_sizes.split(",")))
    prefill_tokens = list(map(int, args.prefill_tokens.split(",")))

    # One sub-directory per chunk size keeps sweeps at different settings apart.
    output_dir = Path(args.output_dir) / f"chunk{args.chunk_size}"
    output_dir.mkdir(parents=True, exist_ok=True)

    url = f"http://{args.host}:{args.port}/v1/chat/completions"
    for banner_line in (
        f"Target: {url}",
        f"Model: {args.model}",
        f"Chunk size: {args.chunk_size}",
        f"Decode batch sizes: {decode_sizes}",
        f"Prefill tokens: {prefill_tokens}",
        f"Repetitions: {args.reps}",
        f"Output: {output_dir}",
    ):
        print(banner_line)
    print()

    async with httpx.AsyncClient(timeout=httpx.Timeout(600.0)) as client:
        # Sanity check: is the server up?
        try:
            resp = await client.get(f"http://{args.host}:{args.port}/v1/models")
            resp.raise_for_status()
            models = resp.json()
            print(f"Server ready. Models: {[m['id'] for m in models.get('data', [])]}")
        except Exception as e:
            print(f"ERROR: Cannot reach server at {args.host}:{args.port}: {e}")
            return

        grid = [(d, p) for d in decode_sizes for p in prefill_tokens]
        total_configs = len(grid)

        for done, (D, P) in enumerate(grid, start=1):
            print(f"\n[{done}/{total_configs}] D={D}, P={P}")

            for rep in range(args.reps):
                # A failing repetition is reported and the sweep carries on.
                try:
                    await run_single_config(
                        client, url, args.model, D, P,
                        args.chunk_size, rep, output_dir,
                    )
                except Exception as e:
                    print(f" [rep {rep}] ERROR: {e}")

                # Cooldown between reps.
                await asyncio.sleep(1.0)

            # Cooldown between configs.
            await asyncio.sleep(3.0)

    print("\n\nDone! Results in:", output_dir)
    # Generate summary CSV.
    await generate_summary(output_dir, args.chunk_size)
|
||||
|
||||
|
||||
async def generate_summary(output_dir: Path, chunk_size: int):
    """Collect every per-run JSON in ``output_dir`` into ``summary.csv``.

    Args:
        output_dir: Directory holding ``D*_P*_rep*.json`` result files; the
            CSV is written into the same directory.
        chunk_size: Not used; each row takes its chunk size from the JSON.

    No file is written when the directory contains no result JSON.
    """
    import csv

    # (CSV column, JSON section, key inside that section), in column order.
    columns = [
        ("chunk_size", "config", "chunk_size"),
        ("decode_batch_size", "config", "decode_batch_size"),
        ("new_prefill_tokens", "config", "new_prefill_tokens"),
        ("repetition", "config", "repetition"),
        ("tpot_baseline_p50_ms", "baseline", "tpot_p50_ms"),
        ("tpot_baseline_p90_ms", "baseline", "tpot_p90_ms"),
        ("tpot_during_prefill_p50_ms", "interference", "tpot_during_prefill_p50_ms"),
        ("tpot_during_prefill_p90_ms", "interference", "tpot_during_prefill_p90_ms"),
        ("tpot_after_prefill_p50_ms", "interference", "tpot_after_prefill_p50_ms"),
        ("prefill_ttft_ms", "interference", "prefill_ttft_ms"),
        ("num_tokens_during_prefill", "interference", "num_tokens_during_prefill"),
        ("tpot_penalty_p50_ms", "derived", "tpot_penalty_p50_ms"),
        ("tpot_penalty_ratio", "derived", "tpot_penalty_ratio"),
    ]

    rows = []
    for run_file in sorted(output_dir.glob("D*_P*_rep*.json")):
        record = json.loads(run_file.read_text())
        rows.append({column: record[section][key] for column, section, key in columns})

    if not rows:
        return

    csv_path = output_dir / "summary.csv"
    with open(csv_path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=[column for column, _, _ in columns])
        writer.writeheader()
        writer.writerows(rows)
    print(f"Summary CSV written: {csv_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
57
microbench/interference/launch_vllm.sh
Normal file
57
microbench/interference/launch_vllm.sh
Normal file
@@ -0,0 +1,57 @@
|
||||
#!/bin/bash
# Launch a single vLLM instance for the interference microbenchmark.
# Uses TP=1 with prefix caching enabled. No explicit chunked-prefill flag is
# passed; the chunk size is bounded via --max-num-batched-tokens.
# NOTE(review): whether chunked prefill is active therefore depends on the
# installed vLLM version's default — confirm before comparing runs.
#
# Usage: bash launch_vllm.sh [chunk_size] [port]
#   chunk_size: max_num_batched_tokens (default: 8192)
#   port: serving port (default: 8000)
#
# Environment overrides:
#   MODEL            model path (default: $HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct)
#   GPU_ID           value for CUDA_VISIBLE_DEVICES (default: 0)
#   STARTUP_TIMEOUT  seconds to wait for the server to answer (default: 120)

set -euo pipefail

CHUNK_SIZE=${1:-8192}
PORT=${2:-8000}
MODEL="${MODEL:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
GPU_ID=${GPU_ID:-0}
STARTUP_TIMEOUT=${STARTUP_TIMEOUT:-120}
LOG_FILE="vllm_microbench1_chunk${CHUNK_SIZE}.log"

echo "=== Interference Microbench vLLM Instance ==="
echo "Model: $MODEL"
echo "GPU: $GPU_ID"
echo "Port: $PORT"
echo "Chunk size (max_num_batched_tokens): $CHUNK_SIZE"
echo "Log: $LOG_FILE"
echo ""

# Kill any existing vLLM on this port. The trailing "( |$)" keeps the pattern
# from also matching longer port numbers that merely start with $PORT.
pkill -f "vllm.*--port $PORT( |\$)" 2>/dev/null || true
sleep 2

# Output is sent to tee through process substitution rather than a pipeline,
# so that $! below is the PID of the server process itself and not of tee.
CUDA_VISIBLE_DEVICES=$GPU_ID python -m vllm.entrypoints.openai.api_server \
    --model "$MODEL" \
    --tensor-parallel-size 1 \
    --enable-prefix-caching \
    --dtype auto \
    --gpu-memory-utilization 0.9 \
    --max-model-len 200000 \
    --max-num-batched-tokens "$CHUNK_SIZE" \
    --port "$PORT" \
    --trust-remote-code \
    --disable-log-requests \
    > >(tee "$LOG_FILE") 2>&1 &

VLLM_PID=$!
echo "vLLM PID: $VLLM_PID"
echo "$VLLM_PID" > .vllm_microbench1.pid

# Wait for server to be ready
echo "Waiting for server to start..."
for i in $(seq 1 "$STARTUP_TIMEOUT"); do
    # -f makes curl fail on HTTP error statuses, so only a successful
    # response counts as "ready".
    if curl -sf "http://127.0.0.1:$PORT/v1/models" > /dev/null 2>&1; then
        echo "Server ready after ${i}s!"
        exit 0
    fi
    # Stop waiting as soon as the server process has died.
    if ! kill -0 "$VLLM_PID" 2>/dev/null; then
        echo "ERROR: vLLM process $VLLM_PID exited during startup; see $LOG_FILE"
        exit 1
    fi
    sleep 1
done

echo "ERROR: Server did not start within ${STARTUP_TIMEOUT}s"
exit 1
|
||||
36
microbench/interference/run_sweep.sh
Normal file
36
microbench/interference/run_sweep.sh
Normal file
@@ -0,0 +1,36 @@
|
||||
#!/bin/bash
# Run the interference microbenchmark sweep.
# Assumes vLLM is already running on the specified port.
#
# Usage: bash run_sweep.sh [port] [chunk_size]
#   port: port of the running server (default: 8000)
#   chunk_size: recorded in the results and used in the output path
#               (default: 8192)
# Note the argument order: port comes first, then chunk_size.
#
# Environment overrides (defaults reproduce the full sweep):
#   REPS                repetitions per configuration (default: 5)
#   OUTPUT_DIR          result directory, relative to this script's directory
#                       unless absolute (default: results/interference)
#   DECODE_BATCH_SIZES  comma-separated decode batch sizes
#   PREFILL_TOKENS      comma-separated prefill token counts

set -euo pipefail

PORT=${1:-8000}
CHUNK_SIZE=${2:-8192}
REPS=${REPS:-5}
OUTPUT_DIR="${OUTPUT_DIR:-results/interference}"
DECODE_BATCH_SIZES="${DECODE_BATCH_SIZES:-0,1,2,4,6,8,12}"
PREFILL_TOKENS="${PREFILL_TOKENS:-512,1024,2048,4096,8192,16384,32768}"

echo "=== Interference Microbench Sweep ==="
echo "Server: http://127.0.0.1:$PORT"
echo "Chunk size: $CHUNK_SIZE"
echo "Reps: $REPS"
echo "Output: $OUTPUT_DIR"
echo ""

# Quick sanity check
curl -sf "http://127.0.0.1:$PORT/v1/models" > /dev/null || {
    echo "ERROR: vLLM not reachable on port $PORT"
    exit 1
}

# Run from the script's own directory so driver.py resolves.
cd "$(dirname "$0")"

python driver.py \
    --host 127.0.0.1 \
    --port "$PORT" \
    --chunk-size "$CHUNK_SIZE" \
    --decode-batch-sizes "$DECODE_BATCH_SIZES" \
    --prefill-tokens "$PREFILL_TOKENS" \
    --reps "$REPS" \
    --output-dir "$OUTPUT_DIR"
|
||||
Reference in New Issue
Block a user