chore: vendor sglang v0.5.10 snapshot
This commit is contained in:
166
third_party/sglang/benchmark/asr/README.md
vendored
Normal file
166
third_party/sglang/benchmark/asr/README.md
vendored
Normal file
@@ -0,0 +1,166 @@
|
||||
# ASR Benchmark
|
||||
|
||||
This benchmark evaluates the performance and accuracy (Word Error Rate - WER) of Automatic Speech Recognition (ASR) models served via SGLang.
|
||||
|
||||
## Supported Models
|
||||
|
||||
- `openai/whisper-large-v3`
|
||||
- `openai/whisper-large-v3-turbo`
|
||||
|
||||
## Setup
|
||||
|
||||
Install the required dependencies:
|
||||
|
||||
```bash
|
||||
apt install ffmpeg
|
||||
pip install librosa soundfile datasets evaluate jiwer transformers openai torchcodec torch
|
||||
```
|
||||
|
||||
## Running the Benchmark
|
||||
|
||||
### 1. Start SGLang Server
|
||||
|
||||
Launch the SGLang server with a Whisper model:
|
||||
|
||||
```bash
|
||||
python -m sglang.launch_server --model-path openai/whisper-large-v3 --port 30000
|
||||
```
|
||||
|
||||
### 2. Run the Benchmark Script
|
||||
|
||||
Basic usage (using chat completions API):
|
||||
|
||||
```bash
|
||||
python bench_sglang.py --base-url http://localhost:30000 --model openai/whisper-large-v3 --n-examples 10
|
||||
```
|
||||
|
||||
Using the OpenAI-compatible transcription API:
|
||||
|
||||
```bash
|
||||
python bench_sglang.py \
|
||||
--base-url http://localhost:30000 \
|
||||
--model openai/whisper-large-v3 \
|
||||
--api-type transcription \
|
||||
--language English \
|
||||
--n-examples 10
|
||||
```
|
||||
|
||||
Run with streaming and show real-time output:
|
||||
|
||||
```bash
|
||||
python bench_sglang.py \
|
||||
--base-url http://localhost:30000 \
|
||||
--model openai/whisper-large-v3 \
|
||||
--api-type transcription \
|
||||
--stream \
|
||||
--show-predictions \
|
||||
--concurrency 1
|
||||
```
|
||||
|
||||
Run with higher concurrency and save results:
|
||||
|
||||
```bash
|
||||
python bench_sglang.py \
|
||||
--base-url http://localhost:30000 \
|
||||
--model openai/whisper-large-v3 \
|
||||
--concurrency 8 \
|
||||
--n-examples 100 \
|
||||
--output results.json \
|
||||
--show-predictions
|
||||
```
|
||||
|
||||
## Arguments
|
||||
|
||||
| Argument | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| `--base-url` | SGLang server URL | `http://localhost:30000` |
|
||||
| `--model` | Model name on the server | `openai/whisper-large-v3` |
|
||||
| `--dataset` | HuggingFace dataset for evaluation | `D4nt3/esb-datasets-earnings22-validation-tiny-filtered` |
|
||||
| `--split` | Dataset split to use | `validation` |
|
||||
| `--concurrency` | Number of concurrent requests | `4` |
|
||||
| `--n-examples` | Number of examples to process (`-1` for all) | `-1` |
|
||||
| `--output` | Path to save results as JSON | `None` |
|
||||
| `--show-predictions` | Display sample predictions | `False` |
|
||||
| `--print-n` | Number of samples to display | `5` |
|
||||
| `--api-type` | API to use: `chat` (chat completions) or `transcription` (audio transcriptions) | `chat` |
|
||||
| `--language` | Language for transcription API (e.g., `English`, `en`) | `None` |
|
||||
| `--stream` | Enable streaming mode for transcription API | `False` |
|
||||
|
||||
## Metrics
|
||||
|
||||
The benchmark outputs:
|
||||
|
||||
| Metric | Description |
|
||||
|--------|-------------|
|
||||
| **Total Requests** | Number of successful ASR requests processed |
|
||||
| **WER** | Word Error Rate (lower is better), computed using the `evaluate` library |
|
||||
| **Average Latency** | Mean time per request (seconds) |
|
||||
| **Median Latency** | 50th percentile latency (seconds) |
|
||||
| **95th Latency** | 95th percentile latency (seconds) |
|
||||
| **Throughput** | Requests processed per second |
|
||||
| **Token Throughput** | Output tokens per second |
|
||||
|
||||
## Example Output
|
||||
|
||||
```bash
|
||||
python bench_sglang.py --api-type transcription --concurrency 128 --model openai/whisper-large-v3 --show-predictions
|
||||
|
||||
Loading dataset: D4nt3/esb-datasets-earnings22-validation-tiny-filtered...
|
||||
Using API type: transcription
|
||||
Repo card metadata block was not found. Setting CardData to empty.
|
||||
WARNING:huggingface_hub.repocard:Repo card metadata block was not found. Setting CardData to empty.
|
||||
Performing warmup...
|
||||
Processing 511 samples...
|
||||
------------------------------
|
||||
Results for openai/whisper-large-v3:
|
||||
Total Requests: 511
|
||||
WER: 12.7690
|
||||
Average Latency: 1.3602s
|
||||
Median Latency: 1.2090s
|
||||
95th Latency: 2.9986s
|
||||
Throughput: 19.02 req/s
|
||||
Token Throughput: 354.19 tok/s
|
||||
Total Test Time: 26.8726s
|
||||
------------------------------
|
||||
|
||||
==================== Sample Predictions ====================
|
||||
Sample 1:
|
||||
REF: on the use of taxonomy i you know i think it is it is early days for us to to make any clear indications to the market about the proportion that would fall under that requirement
|
||||
PRED: on the eu taxonomy i think it is early days for us to make any clear indications to the market about the proportion that would fall under that requirement
|
||||
----------------------------------------
|
||||
Sample 2:
|
||||
REF: so within fiscal year 2021 say 120 a 100 depending on what the micro will do and next year it is not necessarily payable in q one is we will look at what the cash flows for 2022 look like
|
||||
PRED: so within fiscal year 2021 say $120000 $100000 depending on what the macro will do and next year it is not necessarily payable in q one is we will look at what the cash flows for 2022 look like
|
||||
----------------------------------------
|
||||
Sample 3:
|
||||
REF: we talked about 4.7 gigawatts
|
||||
PRED: we talked about 4.7 gigawatts
|
||||
----------------------------------------
|
||||
Sample 4:
|
||||
REF: and you know depending on that working capital build we will we will see what that yields
|
||||
PRED: and depending on that working capital build we will see what that yields what
|
||||
----------------------------------------
|
||||
Sample 5:
|
||||
REF: so on on sinopec what we have agreed with sinopec way back then is that free cash flows after paying all capexs are distributed out 30 70%
|
||||
PRED: so on sinopec what we have agreed with sinopec way back then is that free cash flows after paying all capexes are distributed out 30% 70%
|
||||
----------------------------------------
|
||||
============================================================
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- Audio samples longer than 30 seconds are automatically filtered out (Whisper limitation)
|
||||
- The benchmark performs a warmup request before measuring performance
|
||||
- Results are normalized using the model's tokenizer when available
|
||||
- When using `--stream` with `--show-predictions`, use `--concurrency 1` for clean sequential output
|
||||
- The `--language` option accepts both full names (e.g., `English`) and ISO 639-1 codes (e.g., `en`)
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**Server connection refused**
|
||||
- Ensure the SGLang server is running and accessible at the specified `--base-url`
|
||||
- Check that the port is not blocked by a firewall
|
||||
|
||||
**Out of memory errors**
|
||||
- Reduce `--concurrency` to lower GPU memory usage
|
||||
- Use a smaller Whisper model variant
|
||||
404
third_party/sglang/benchmark/asr/bench_sglang.py
vendored
Normal file
404
third_party/sglang/benchmark/asr/bench_sglang.py
vendored
Normal file
@@ -0,0 +1,404 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from statistics import mean, median
|
||||
|
||||
import httpx
|
||||
import librosa
|
||||
import numpy as np
|
||||
import soundfile
|
||||
from datasets import load_dataset
|
||||
from evaluate import load
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def to_bytes(y, sr):
|
||||
buffer = io.BytesIO()
|
||||
soundfile.write(buffer, y, sr, format="WAV")
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
|
||||
async def run_asr_chat(client, model_name, y, sr):
|
||||
"""Use chat completions API with audio_url for ASR."""
|
||||
with to_bytes(y, sr) as f:
|
||||
audio_bytes = f.read()
|
||||
audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
response = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {"url": f"data:audio/wav;base64,{audio_base64}"},
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
temperature=0.0,
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
|
||||
asr_text = response.choices[0].message.content
|
||||
latency = end_time - start_time
|
||||
return latency, asr_text
|
||||
|
||||
|
||||
def run_asr_transcription_sync(client, model_name, y, sr, language=None):
|
||||
"""Use audio transcriptions API for ASR (sync version)."""
|
||||
audio_buffer = to_bytes(y, sr)
|
||||
audio_buffer.name = "audio.wav" # OpenAI client needs a name attribute
|
||||
|
||||
start_time = time.perf_counter()
|
||||
kwargs = {
|
||||
"model": model_name,
|
||||
"file": audio_buffer,
|
||||
}
|
||||
if language:
|
||||
kwargs["language"] = language
|
||||
|
||||
transcription = client.audio.transcriptions.create(**kwargs)
|
||||
end_time = time.perf_counter()
|
||||
|
||||
latency = end_time - start_time
|
||||
return latency, transcription.text
|
||||
|
||||
|
||||
def run_asr_transcription_stream_sync(
|
||||
base_url, model_name, y, sr, language=None, show_stream=False
|
||||
):
|
||||
"""Use audio transcriptions API with streaming for ASR."""
|
||||
audio_buffer = to_bytes(y, sr)
|
||||
audio_bytes = audio_buffer.read()
|
||||
|
||||
data = {
|
||||
"model": model_name,
|
||||
"response_format": "json",
|
||||
"stream": "true",
|
||||
}
|
||||
if language:
|
||||
data["language"] = language
|
||||
|
||||
start_time = time.perf_counter()
|
||||
text_chunks = []
|
||||
|
||||
if show_stream:
|
||||
print("[STREAM] ", end="", flush=True)
|
||||
|
||||
with httpx.stream(
|
||||
"POST",
|
||||
f"{base_url}/v1/audio/transcriptions",
|
||||
data=data,
|
||||
files={"file": ("audio.wav", audio_bytes, "audio/wav")},
|
||||
timeout=60.0,
|
||||
) as response:
|
||||
for line in response.iter_lines():
|
||||
if line.startswith("data: ") and not line.startswith("data: [DONE]"):
|
||||
try:
|
||||
chunk = json.loads(line[6:])
|
||||
if "choices" in chunk and chunk["choices"]:
|
||||
delta = chunk["choices"][0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
text_chunks.append(content)
|
||||
if show_stream:
|
||||
print(content, end="", flush=True)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
if show_stream:
|
||||
print() # newline after stream
|
||||
|
||||
end_time = time.perf_counter()
|
||||
latency = end_time - start_time
|
||||
return latency, "".join(text_chunks)
|
||||
|
||||
|
||||
async def run_asr_transcription(
|
||||
client,
|
||||
model_name,
|
||||
y,
|
||||
sr,
|
||||
language=None,
|
||||
stream=False,
|
||||
base_url=None,
|
||||
show_stream=False,
|
||||
):
|
||||
"""Async wrapper for transcription API (runs sync call in executor)."""
|
||||
loop = asyncio.get_event_loop()
|
||||
if stream:
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
run_asr_transcription_stream_sync,
|
||||
base_url,
|
||||
model_name,
|
||||
y,
|
||||
sr,
|
||||
language,
|
||||
show_stream,
|
||||
)
|
||||
return await loop.run_in_executor(
|
||||
None, run_asr_transcription_sync, client, model_name, y, sr, language
|
||||
)
|
||||
|
||||
|
||||
async def bound_asr(
|
||||
sem,
|
||||
client,
|
||||
model_name,
|
||||
tokenizer,
|
||||
audio,
|
||||
reference,
|
||||
api_type="chat",
|
||||
language=None,
|
||||
stream=False,
|
||||
base_url=None,
|
||||
show_stream=False,
|
||||
):
|
||||
async with sem:
|
||||
try:
|
||||
if api_type == "transcription":
|
||||
latency, text = await run_asr_transcription(
|
||||
client,
|
||||
model_name,
|
||||
*audio,
|
||||
language=language,
|
||||
stream=stream,
|
||||
base_url=base_url,
|
||||
show_stream=show_stream,
|
||||
)
|
||||
else:
|
||||
latency, text = await run_asr_chat(client, model_name, *audio)
|
||||
|
||||
# Calculate tokens for throughput metrics
|
||||
num_output_tokens = len(tokenizer(text, add_special_tokens=False).input_ids)
|
||||
|
||||
# Normalize for WER evaluation
|
||||
# Whisper tokenizer has a normalize method
|
||||
if hasattr(tokenizer, "normalize"):
|
||||
out = tokenizer.normalize(text)
|
||||
ref = tokenizer.normalize(reference)
|
||||
else:
|
||||
out = text.lower().strip()
|
||||
ref = reference.lower().strip()
|
||||
|
||||
return latency, num_output_tokens, out, ref
|
||||
except Exception as e:
|
||||
print(f"Error during ASR: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def process_dataset(
|
||||
model_name,
|
||||
client,
|
||||
data,
|
||||
concurrent_request,
|
||||
api_type="chat",
|
||||
language=None,
|
||||
stream=False,
|
||||
base_url=None,
|
||||
show_predictions=False,
|
||||
):
|
||||
sem = asyncio.Semaphore(concurrent_request)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# Warmup
|
||||
print("Performing warmup...")
|
||||
audio_warmup, sr_warmup = (
|
||||
data[0]["audio"]["array"],
|
||||
data[0]["audio"]["sampling_rate"],
|
||||
)
|
||||
await bound_asr(
|
||||
sem,
|
||||
client,
|
||||
model_name,
|
||||
tokenizer,
|
||||
(audio_warmup, sr_warmup),
|
||||
"",
|
||||
api_type=api_type,
|
||||
language=language,
|
||||
stream=stream,
|
||||
base_url=base_url,
|
||||
show_stream=False, # Don't show stream during warmup
|
||||
)
|
||||
|
||||
tasks = []
|
||||
print(f"Processing {len(data)} samples...")
|
||||
for sample in data:
|
||||
audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
bound_asr(
|
||||
sem,
|
||||
client,
|
||||
model_name,
|
||||
tokenizer,
|
||||
(audio, sr),
|
||||
sample["text"],
|
||||
api_type=api_type,
|
||||
language=language,
|
||||
stream=stream,
|
||||
base_url=base_url,
|
||||
show_stream=show_predictions and stream,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
return [r for r in results if r is not None]
|
||||
|
||||
|
||||
def run_evaluation(args):
|
||||
# Use sync client for transcription API, async for chat API
|
||||
if args.api_type == "transcription":
|
||||
client = OpenAI(base_url=f"{args.base_url}/v1", api_key="None")
|
||||
else:
|
||||
client = AsyncOpenAI(base_url=f"{args.base_url}/v1", api_key="None")
|
||||
|
||||
print(f"Loading dataset: {args.dataset}...")
|
||||
print(f"Using API type: {args.api_type}" + (f" (streaming)" if args.stream else ""))
|
||||
dataset = load_dataset(args.dataset, split=args.split)
|
||||
|
||||
# Filter by duration if needed (Whisper max is 30s)
|
||||
def add_duration(sample):
|
||||
y, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
|
||||
sample["duration_ms"] = librosa.get_duration(y=y, sr=sr) * 1000
|
||||
return sample
|
||||
|
||||
if "duration_ms" not in dataset.column_names:
|
||||
dataset = dataset.map(add_duration)
|
||||
|
||||
dataset = dataset.filter(lambda x: x["duration_ms"] < 30000)
|
||||
|
||||
if args.n_examples > 0:
|
||||
dataset = dataset.select(range(min(args.n_examples, len(dataset))))
|
||||
|
||||
start = time.perf_counter()
|
||||
results = asyncio.run(
|
||||
process_dataset(
|
||||
args.model,
|
||||
client,
|
||||
dataset,
|
||||
args.concurrency,
|
||||
api_type=args.api_type,
|
||||
language=args.language,
|
||||
stream=args.stream,
|
||||
base_url=args.base_url,
|
||||
show_predictions=args.show_predictions,
|
||||
)
|
||||
)
|
||||
total_test_time = time.perf_counter() - start
|
||||
|
||||
if not results:
|
||||
print("No successful results to evaluate.")
|
||||
return
|
||||
|
||||
# Metrics
|
||||
latencies = [res[0] for res in results]
|
||||
total_tokens = sum([res[1] for res in results])
|
||||
predictions = [res[2] for res in results]
|
||||
references = [res[3] for res in results]
|
||||
|
||||
wer_metric = load("wer")
|
||||
wer_score = 100 * wer_metric.compute(references=references, predictions=predictions)
|
||||
|
||||
print("-" * 30)
|
||||
print(f"Results for {args.model}:")
|
||||
print(f"Total Requests: {len(results)}")
|
||||
print(f"WER: {wer_score:.4f}")
|
||||
print(f"Average Latency: {mean(latencies):.4f}s")
|
||||
print(f"Median Latency: {median(latencies):.4f}s")
|
||||
print(f"95th Latency: {np.percentile(latencies, 95):.4f}s")
|
||||
print(f"Throughput: {len(results) / total_test_time:.2f} req/s")
|
||||
print(f"Token Throughput: {total_tokens / total_test_time:.2f} tok/s")
|
||||
print(f"Total Test Time: {total_test_time:.4f}s")
|
||||
print("-" * 30)
|
||||
|
||||
if args.output:
|
||||
with open(args.output, "w") as f:
|
||||
import json
|
||||
|
||||
json.dump(
|
||||
{
|
||||
"model": args.model,
|
||||
"dataset": args.dataset,
|
||||
"wer": wer_score,
|
||||
"avg_latency": mean(latencies),
|
||||
"throughput": len(results) / total_test_time,
|
||||
"token_throughput": total_tokens / total_test_time,
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
if args.show_predictions:
|
||||
print("\n" + "=" * 20 + " Sample Predictions " + "=" * 20)
|
||||
num_to_show = min(args.print_n, len(results))
|
||||
for i in range(num_to_show):
|
||||
print(f"Sample {i+1}:")
|
||||
print(f" REF: {references[i]}")
|
||||
print(f" PRED: {predictions[i]}")
|
||||
print("-" * 40)
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Benchmark sGLang ASR performance.")
|
||||
parser.add_argument(
|
||||
"--base-url", default="http://localhost:30000", help="sGLang server base URL"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", default="openai/whisper-large-v3", help="Model name on the server"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
default="D4nt3/esb-datasets-earnings22-validation-tiny-filtered",
|
||||
help="HF dataset repo",
|
||||
)
|
||||
parser.add_argument("--split", default="validation", help="Dataset split")
|
||||
parser.add_argument(
|
||||
"--concurrency", type=int, default=4, help="Number of concurrent requests"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n-examples",
|
||||
"-n",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Number of examples to test (-1 for all)",
|
||||
)
|
||||
parser.add_argument("--output", help="Path to save results in JSON")
|
||||
parser.add_argument(
|
||||
"--show-predictions",
|
||||
action="store_true",
|
||||
help="Print sample predictions and references",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print-n", type=int, default=5, help="Number of sample predictions to print"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-type",
|
||||
choices=["chat", "transcription"],
|
||||
default="chat",
|
||||
help="API type to use: 'chat' for chat completions with audio_url, 'transcription' for audio.transcriptions API",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language",
|
||||
default=None,
|
||||
help="Language code for transcription API (e.g., 'en')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stream",
|
||||
action="store_true",
|
||||
help="Use streaming mode for transcription API",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
run_evaluation(args)
|
||||
250
third_party/sglang/benchmark/bench_attention_sink/bench_attention_sink_triton.py
vendored
Normal file
250
third_party/sglang/benchmark/bench_attention_sink/bench_attention_sink_triton.py
vendored
Normal file
@@ -0,0 +1,250 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from sglang.srt.layers.attention.triton_ops.decode_attention import (
|
||||
decode_attention_fwd_grouped,
|
||||
)
|
||||
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd
|
||||
|
||||
# gpt oss
|
||||
head_num = 64
|
||||
head_dim = 64
|
||||
head_kv_num = 8
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["S"], # sequence length on x-axis
|
||||
x_vals=[128, 256, 512, 1024, 2048, 4096],
|
||||
x_log=True,
|
||||
line_arg="B", # batch size as different lines
|
||||
line_vals=[1, 8, 32, 128],
|
||||
line_names=["B=1", "B=8", "B=32", "B=128"],
|
||||
styles=[
|
||||
("blue", "-"),
|
||||
("green", "-"),
|
||||
("red", "-"),
|
||||
("cyan", "-"),
|
||||
],
|
||||
ylabel="TFLOPS",
|
||||
plot_name="attention-sink-triton-decode",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark_decode(B, S, H_Q, H_KV, D):
|
||||
D_V = D
|
||||
dtype = torch.bfloat16
|
||||
seq_len = S
|
||||
total_tokens = B * seq_len
|
||||
device = torch.device("cuda")
|
||||
sm_scale = 1.0 / (D**0.5)
|
||||
max_kv_splits = 8
|
||||
num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
|
||||
|
||||
# q represents the new token being generated, one per batch
|
||||
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
|
||||
|
||||
# k_buffer and v_buffer represent all previous tokens
|
||||
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
|
||||
v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
|
||||
|
||||
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
|
||||
|
||||
b_seq_len = torch.full((B,), seq_len, device="cuda")
|
||||
|
||||
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
|
||||
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len, dim=0)
|
||||
kv_indices = torch.arange(total_tokens, device="cuda")
|
||||
|
||||
attn_logits1 = torch.empty(
|
||||
(B, H_Q, max_kv_splits, D_V),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
attn_lse1 = torch.empty(
|
||||
(B, H_Q, max_kv_splits, D_V),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
sink = torch.randn(H_Q, device=device, dtype=torch.float32)
|
||||
|
||||
# warmup
|
||||
for _ in range(5):
|
||||
decode_attention_fwd_grouped(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
attn_logits1,
|
||||
attn_lse1,
|
||||
num_kv_splits,
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap=0.0,
|
||||
sinks=sink,
|
||||
)
|
||||
|
||||
# benchmark
|
||||
run_step = 500
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
for _ in range(run_step):
|
||||
decode_attention_fwd_grouped(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
attn_logits1,
|
||||
attn_lse1,
|
||||
num_kv_splits,
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap=0.0,
|
||||
sinks=sink,
|
||||
)
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
ms = start_event.elapsed_time(end_event) / run_step
|
||||
tflops = lambda ms: (2 * B * S * H_Q * D) * 1e-9 / ms # must be causal
|
||||
return tflops(ms)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["S"], # sequence length on x-axis
|
||||
x_vals=[128, 256, 512, 1024, 2048, 4096],
|
||||
x_log=True,
|
||||
line_arg="B", # batch size as different lines
|
||||
line_vals=[1, 8, 32, 128],
|
||||
line_names=["B=1", "B=8", "B=32", "B=128"],
|
||||
styles=[
|
||||
("blue", "-"),
|
||||
("green", "-"),
|
||||
("red", "-"),
|
||||
("cyan", "-"),
|
||||
],
|
||||
ylabel="TFLOPS",
|
||||
plot_name="attention-sink-triton-extend",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark_extend(B, S, H_Q, H_KV, D):
|
||||
# S here represents N_CTX from the test
|
||||
dtype = torch.bfloat16
|
||||
device = "cuda"
|
||||
|
||||
# Split S into prefix and extend lengths
|
||||
prefill_len = S // 2 # Similar to test's N_CTX // 2
|
||||
extend_len = S // 4 # Make extend length smaller than prefix
|
||||
|
||||
# Calculate total tokens and extend tokens
|
||||
total_extend_tokens = B * extend_len
|
||||
total_prefix_tokens = B * prefill_len
|
||||
|
||||
# Create query, key, value tensors for extension
|
||||
q_extend = torch.randn(total_extend_tokens, H_Q, D, dtype=dtype, device=device)
|
||||
k_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)
|
||||
v_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)
|
||||
o_extend = torch.empty_like(q_extend)
|
||||
|
||||
# Create key-value buffers for prefix
|
||||
k_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)
|
||||
v_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)
|
||||
|
||||
# Create index pointers
|
||||
qo_indptr = torch.arange(0, (B + 1) * extend_len, extend_len, device=device).to(
|
||||
torch.int32
|
||||
)
|
||||
kv_indptr = torch.arange(0, (B + 1) * prefill_len, prefill_len, device=device).to(
|
||||
torch.int32
|
||||
)
|
||||
kv_indices = torch.arange(0, total_prefix_tokens, device=device).to(torch.int32)
|
||||
|
||||
sm_scale = 1.0 / (D**0.5)
|
||||
# sliding_window = 128 # From GPT-OSS config, skip for now
|
||||
sliding_window = -1
|
||||
|
||||
sink = torch.randn(H_Q, device=device, dtype=torch.float32)
|
||||
|
||||
# warmup
|
||||
for _ in range(5):
|
||||
extend_attention_fwd(
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
custom_mask=None,
|
||||
is_causal=True,
|
||||
mask_indptr=None,
|
||||
max_len_extend=extend_len,
|
||||
sm_scale=sm_scale,
|
||||
sliding_window_size=sliding_window,
|
||||
sinks=sink,
|
||||
)
|
||||
|
||||
# benchmark
|
||||
run_step = 500
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
for _ in range(run_step):
|
||||
extend_attention_fwd(
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
custom_mask=None,
|
||||
is_causal=True,
|
||||
mask_indptr=None,
|
||||
max_len_extend=extend_len,
|
||||
sm_scale=sm_scale,
|
||||
sliding_window_size=sliding_window,
|
||||
sinks=sink,
|
||||
)
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
ms = start_event.elapsed_time(end_event) / run_step
|
||||
|
||||
# FLOPS calculation: each attention operation requires 2 multiplications per element
|
||||
total_flops = 2 * total_extend_tokens * H_Q * (prefill_len + extend_len / 2) * D
|
||||
tflops = lambda ms: total_flops * 1e-12 / (ms * 1e-3) # convert to TFLOPS
|
||||
return tflops(ms)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--bench", type=str, default="all", help="all, extend, decode")
|
||||
args = parser.parse_args()
|
||||
|
||||
kwargs = {
|
||||
"H_Q": head_num,
|
||||
"H_KV": head_kv_num,
|
||||
"D": head_dim,
|
||||
}
|
||||
|
||||
if args.bench in ["all", "decode"]:
|
||||
benchmark_decode.run(print_data=True, show_plots=False, **kwargs)
|
||||
|
||||
if args.bench in ["all", "extend"]:
|
||||
benchmark_extend.run(print_data=True, show_plots=False, **kwargs)
|
||||
|
||||
print("Benchmark finished!")
|
||||
130
third_party/sglang/benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py
vendored
Normal file
130
third_party/sglang/benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py
vendored
Normal file
@@ -0,0 +1,130 @@
|
||||
# Benchmark with lots of common prefixes. Used to benchmark prefix caching performance.
|
||||
#
|
||||
# Launch a server:
|
||||
# python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --log-level-http warning
|
||||
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import sglang as sgl
|
||||
from sglang import set_default_backend
|
||||
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
|
||||
|
||||
def generate_random_string(token_length: int) -> str:
|
||||
random_string = "".join(
|
||||
random.choices(string.ascii_letters + string.digits, k=token_length * 100)
|
||||
)
|
||||
tokenized_output = tokenizer.encode(random_string, add_special_tokens=False)[
|
||||
:token_length
|
||||
]
|
||||
|
||||
if len(tokenized_output) < token_length:
|
||||
tokenized_output = tokenized_output + [tokenizer.pad_token_id] * (
|
||||
token_length - len(tokenized_output)
|
||||
)
|
||||
|
||||
decoded_string = tokenizer.decode(tokenized_output, skip_special_tokens=False)
|
||||
return decoded_string
|
||||
|
||||
|
||||
def generate_unique_prefix(base_text, index):
|
||||
return str(index) + base_text[len(str(index)) :]
|
||||
|
||||
|
||||
@sgl.function
|
||||
def text_qa(s, question, gen_len):
|
||||
s += "Q: " + question + "\n"
|
||||
s += "A:" + sgl.gen("answer", stop="\n", temperature=0, max_tokens=gen_len)
|
||||
|
||||
|
||||
def prepare_prompts(num_prefix, num_samples_per_prefix, prefix_length, suffix_length):
|
||||
base_prefix = generate_random_string(prefix_length)
|
||||
|
||||
tot_input_len = 0
|
||||
all_prompts = []
|
||||
for i in tqdm(range(num_prefix), desc="prepare prompts"):
|
||||
unique_prefix = generate_unique_prefix(base_prefix, i)
|
||||
prompt_list = []
|
||||
for j in range(num_samples_per_prefix):
|
||||
suffix = generate_random_string(suffix_length)
|
||||
prompt = unique_prefix + suffix
|
||||
prompt_list.append(prompt)
|
||||
tot_input_len += len(tokenizer.encode(prompt))
|
||||
all_prompts.append(prompt_list)
|
||||
return all_prompts, tot_input_len
|
||||
|
||||
|
||||
def test_batch_by_batch(all_prompts, gen_len):
|
||||
backend.flush_cache()
|
||||
|
||||
tot_time = 0
|
||||
for i in range(len(all_prompts)):
|
||||
tic = time.perf_counter()
|
||||
text_qa.run_batch(
|
||||
list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))),
|
||||
)
|
||||
tot_time += time.perf_counter() - tic
|
||||
|
||||
return tot_time
|
||||
|
||||
|
||||
def test_batch_by_batch_with_hint(all_prompts, gen_len):
|
||||
backend.flush_cache()
|
||||
|
||||
tot_time = 0
|
||||
for i in range(len(all_prompts)):
|
||||
tic = time.perf_counter()
|
||||
# Send a hint to cache the prefix
|
||||
text_qa.run_batch(list(zip(all_prompts[i][:1], [gen_len])))
|
||||
# Send the batch
|
||||
text_qa.run_batch(list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))))
|
||||
|
||||
tot_time += time.perf_counter() - tic
|
||||
|
||||
return tot_time
|
||||
|
||||
|
||||
def test_send_all(all_prompts, gen_len):
|
||||
backend.flush_cache()
|
||||
|
||||
all_prompts = [x for prompt_list in all_prompts for x in prompt_list]
|
||||
|
||||
tic = time.perf_counter()
|
||||
text_qa.run_batch(
|
||||
list(zip(all_prompts, [gen_len] * len(all_prompts))),
|
||||
)
|
||||
tot_time = time.perf_counter() - tic
|
||||
|
||||
return tot_time
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
backend = RuntimeEndpoint("http://127.0.0.1:30000")
|
||||
set_default_backend(backend)
|
||||
|
||||
random.seed(0)
|
||||
num_prefix = 10
|
||||
num_samples_per_prefix = 32
|
||||
prefix_length = 1024
|
||||
suffix_length = 128
|
||||
gen_len = 1
|
||||
all_prompts, tot_input_len = prepare_prompts(
|
||||
num_prefix, num_samples_per_prefix, prefix_length, suffix_length
|
||||
)
|
||||
|
||||
print(f"Total input token length: {tot_input_len}\n")
|
||||
|
||||
cost = test_batch_by_batch(all_prompts, gen_len)
|
||||
print(f"Latency of test_batch_by_batch : {cost:.4f} s\n")
|
||||
|
||||
cost = test_batch_by_batch_with_hint(all_prompts, gen_len)
|
||||
print(f"Latency of test_batch_by_batch_with_hint: {cost:.4f} s\n")
|
||||
|
||||
cost = test_send_all(all_prompts, gen_len)
|
||||
print(f"Latency of test_send_all : {cost:.4f} s\n")
|
||||
481
third_party/sglang/benchmark/bench_linear_attention/bench_cutedsl_kda_decode.py
vendored
Normal file
481
third_party/sglang/benchmark/bench_linear_attention/bench_cutedsl_kda_decode.py
vendored
Normal file
@@ -0,0 +1,481 @@
|
||||
"""Benchmark & Correctness: CuTe DSL KDA Decode vs Triton KDA Decode.
|
||||
|
||||
This benchmark assumes the production / Triton canonical state layout:
|
||||
ssm_states.shape == (pool_size, HV, V, K)
|
||||
|
||||
Both the Triton baseline and the CuTe DSL candidate operate directly on that VK
|
||||
layout. No transpose is performed anywhere in the benchmark.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "python"))
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from sglang.jit_kernel.cutedsl_kda import cutedsl_fused_sigmoid_gating_kda_update
|
||||
from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
|
||||
fused_sigmoid_gating_delta_rule_update,
|
||||
)
|
||||
from sglang.srt.layers.attention.fla.kda import chunk_kda
|
||||
|
||||
|
||||
def make_inputs(
|
||||
B: int,
|
||||
H: int,
|
||||
HV: int,
|
||||
K: int,
|
||||
V: int,
|
||||
pool_size: int,
|
||||
device: str,
|
||||
dtype: torch.dtype,
|
||||
layout: str,
|
||||
seed: int = 42,
|
||||
):
|
||||
torch.manual_seed(seed)
|
||||
|
||||
assert K == 128
|
||||
assert V % 16 == 0 and V % 32 == 0
|
||||
|
||||
if layout == "varlen":
|
||||
q = torch.randn(1, B, H, K, device=device, dtype=dtype)
|
||||
k = torch.randn(1, B, H, K, device=device, dtype=dtype)
|
||||
v = torch.randn(1, B, HV, V, device=device, dtype=dtype)
|
||||
|
||||
# decode params
|
||||
a = torch.randn(B, HV, K, device=device, dtype=dtype)
|
||||
b = torch.randn(B, HV, device=device, dtype=dtype)
|
||||
|
||||
# prefill params for chunk_kda must keep batch dim = 1
|
||||
# chunk_kda requires g, beta, v to have the same head count as k (H),
|
||||
# matching the real KimiLinear model where num_heads == num_kv_heads.
|
||||
prefill_v = torch.randn(1, B, H, V, device=device, dtype=dtype)
|
||||
prefill_g = torch.randn(1, B, H, K, device=device, dtype=dtype)
|
||||
prefill_beta = torch.sigmoid(torch.randn(1, B, H, device=device, dtype=dtype))
|
||||
|
||||
cu_seqlens = torch.arange(B + 1, device=device, dtype=torch.int32)
|
||||
|
||||
elif layout == "dense":
|
||||
q = torch.randn(B, 1, H, K, device=device, dtype=dtype)
|
||||
k = torch.randn(B, 1, H, K, device=device, dtype=dtype)
|
||||
v = torch.randn(B, 1, HV, V, device=device, dtype=dtype)
|
||||
|
||||
# decode params
|
||||
a = torch.randn(B, 1, HV, K, device=device, dtype=dtype)
|
||||
b = torch.randn(B, 1, HV, device=device, dtype=dtype)
|
||||
|
||||
# prefill params for chunk_kda dense path
|
||||
# chunk_kda requires g, beta, v to have the same head count as k (H),
|
||||
# matching the real KimiLinear model where num_heads == num_kv_heads.
|
||||
prefill_v = torch.randn(B, 1, H, V, device=device, dtype=dtype)
|
||||
prefill_g = torch.randn(B, 1, H, K, device=device, dtype=dtype)
|
||||
prefill_beta = torch.sigmoid(torch.randn(B, 1, H, device=device, dtype=dtype))
|
||||
|
||||
cu_seqlens = torch.arange(B + 1, device=device, dtype=torch.int32)
|
||||
else:
|
||||
raise ValueError(f"Unknown layout: {layout}")
|
||||
|
||||
A_log = torch.randn(HV, device=device, dtype=torch.float32)
|
||||
dt_bias = torch.randn(HV, K, device=device, dtype=dtype)
|
||||
|
||||
ssm_states = (
|
||||
torch.randn(pool_size, HV, V, K, device=device, dtype=torch.float32) * 0.1
|
||||
)
|
||||
cache_indices = torch.arange(B, device=device, dtype=torch.int32)
|
||||
|
||||
return dict(
|
||||
B=B,
|
||||
H=H,
|
||||
HV=HV,
|
||||
K=K,
|
||||
V=V,
|
||||
pool_size=pool_size,
|
||||
layout=layout,
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
a=a,
|
||||
b=b,
|
||||
prefill_v=prefill_v,
|
||||
prefill_g=prefill_g,
|
||||
prefill_beta=prefill_beta,
|
||||
A_log=A_log,
|
||||
dt_bias=dt_bias,
|
||||
ssm_states=ssm_states,
|
||||
cache_indices=cache_indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
|
||||
def run_baseline(inp):
|
||||
state = inp["ssm_states"].clone()
|
||||
o = fused_sigmoid_gating_delta_rule_update(
|
||||
A_log=inp["A_log"],
|
||||
dt_bias=inp["dt_bias"],
|
||||
q=inp["q"],
|
||||
k=inp["k"],
|
||||
v=inp["v"],
|
||||
a=inp["a"],
|
||||
b=inp["b"],
|
||||
initial_state_source=state,
|
||||
initial_state_indices=inp["cache_indices"],
|
||||
cu_seqlens=inp["cu_seqlens"] if inp["layout"] == "varlen" else None,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
softplus_beta=1.0,
|
||||
softplus_threshold=20.0,
|
||||
is_kda=True,
|
||||
)
|
||||
return o, state
|
||||
|
||||
|
||||
def run_cutedsl(inp):
|
||||
state = inp["ssm_states"].clone()
|
||||
o = cutedsl_fused_sigmoid_gating_kda_update(
|
||||
A_log=inp["A_log"],
|
||||
dt_bias=inp["dt_bias"],
|
||||
q=inp["q"],
|
||||
k=inp["k"],
|
||||
v=inp["v"],
|
||||
a=inp["a"],
|
||||
b=inp["b"],
|
||||
initial_state_source=state,
|
||||
initial_state_indices=inp["cache_indices"],
|
||||
cu_seqlens=inp["cu_seqlens"] if inp["layout"] == "varlen" else None,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
softplus_beta=1.0,
|
||||
softplus_threshold=20.0,
|
||||
)
|
||||
return o, state
|
||||
|
||||
|
||||
def run_prefill_then_decode_baseline(inp):
|
||||
ssm_states = inp["ssm_states"].clone()
|
||||
prefill_v_clone = inp["prefill_v"].clone()
|
||||
v_clone = inp["v"].clone()
|
||||
|
||||
_ = chunk_kda(
|
||||
q=inp["q"],
|
||||
k=inp["k"],
|
||||
v=prefill_v_clone,
|
||||
g=inp["prefill_g"],
|
||||
beta=inp["prefill_beta"],
|
||||
initial_state=ssm_states,
|
||||
initial_state_indices=inp["cache_indices"],
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
cu_seqlens=inp["cu_seqlens"] if inp["layout"] == "varlen" else None,
|
||||
)
|
||||
|
||||
o = fused_sigmoid_gating_delta_rule_update(
|
||||
A_log=inp["A_log"],
|
||||
dt_bias=inp["dt_bias"],
|
||||
q=inp["q"],
|
||||
k=inp["k"],
|
||||
v=v_clone,
|
||||
a=inp["a"],
|
||||
b=inp["b"],
|
||||
initial_state_source=ssm_states,
|
||||
initial_state_indices=inp["cache_indices"],
|
||||
cu_seqlens=inp["cu_seqlens"] if inp["layout"] == "varlen" else None,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
softplus_beta=1.0,
|
||||
softplus_threshold=20.0,
|
||||
is_kda=True,
|
||||
)
|
||||
return o, ssm_states
|
||||
|
||||
|
||||
def run_prefill_then_decode_cutedsl(inp):
|
||||
ssm_states = inp["ssm_states"].clone()
|
||||
prefill_v_clone = inp["prefill_v"].clone()
|
||||
v_clone = inp["v"].clone()
|
||||
|
||||
_ = chunk_kda(
|
||||
q=inp["q"],
|
||||
k=inp["k"],
|
||||
v=prefill_v_clone,
|
||||
g=inp["prefill_g"],
|
||||
beta=inp["prefill_beta"],
|
||||
initial_state=ssm_states,
|
||||
initial_state_indices=inp["cache_indices"],
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
cu_seqlens=inp["cu_seqlens"] if inp["layout"] == "varlen" else None,
|
||||
)
|
||||
|
||||
o = cutedsl_fused_sigmoid_gating_kda_update(
|
||||
A_log=inp["A_log"],
|
||||
dt_bias=inp["dt_bias"],
|
||||
q=inp["q"],
|
||||
k=inp["k"],
|
||||
v=v_clone,
|
||||
a=inp["a"],
|
||||
b=inp["b"],
|
||||
initial_state_source=ssm_states,
|
||||
initial_state_indices=inp["cache_indices"],
|
||||
cu_seqlens=inp["cu_seqlens"] if inp["layout"] == "varlen" else None,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
softplus_beta=1.0,
|
||||
softplus_threshold=20.0,
|
||||
)
|
||||
return o, ssm_states
|
||||
|
||||
|
||||
def _assert_close(name, x, y, atol=3e-2, rtol=2e-2):
|
||||
try:
|
||||
torch.testing.assert_close(x.float(), y.float(), atol=atol, rtol=rtol)
|
||||
return True, 0.0
|
||||
except AssertionError:
|
||||
max_diff = (x - y).abs().max().item()
|
||||
return False, max_diff
|
||||
|
||||
|
||||
def check_correctness(B, H, HV, K, V, pool_size, device, dtype, layout):
|
||||
tag = (
|
||||
f"layout={layout:<6} B={B:>4} H={H:>2} HV={HV:>2} "
|
||||
f"K={K:>3} V={V:>3} pool={pool_size:>4}"
|
||||
)
|
||||
inp = make_inputs(B, H, HV, K, V, pool_size, device, dtype, layout)
|
||||
|
||||
o_ref, st_ref = run_baseline(inp)
|
||||
o_cute, st_cute = run_cutedsl(inp)
|
||||
|
||||
ok_o, diff_o = _assert_close("output", o_cute, o_ref)
|
||||
valid_mask = inp["cache_indices"] >= 0
|
||||
valid_idx = inp["cache_indices"][valid_mask]
|
||||
ok_s, diff_s = _assert_close("state", st_cute[valid_idx], st_ref[valid_idx])
|
||||
|
||||
if ok_o and ok_s:
|
||||
print(f" [PASS] {tag}")
|
||||
return True
|
||||
|
||||
details = []
|
||||
if not ok_o:
|
||||
details.append(f"output max_diff={diff_o:.6f}")
|
||||
if not ok_s:
|
||||
details.append(f"state max_diff={diff_s:.6f}")
|
||||
print(f" [FAIL] {tag} ({', '.join(details)})")
|
||||
return False
|
||||
|
||||
|
||||
def check_prefill_chain(B, H, HV, K, V, pool_size, device, dtype, layout):
|
||||
tag = (
|
||||
f"[prefill->decode] layout={layout:<6} B={B:>4} H={H:>2} HV={HV:>2} "
|
||||
f"K={K:>3} V={V:>3} pool={pool_size:>4}"
|
||||
)
|
||||
inp = make_inputs(B, H, HV, K, V, pool_size, device, dtype, layout)
|
||||
|
||||
o_ref, st_ref = run_prefill_then_decode_baseline(inp)
|
||||
o_cute, st_cute = run_prefill_then_decode_cutedsl(inp)
|
||||
|
||||
ok_o, diff_o = _assert_close("output", o_cute, o_ref)
|
||||
valid_mask = inp["cache_indices"] >= 0
|
||||
valid_idx = inp["cache_indices"][valid_mask]
|
||||
ok_s, diff_s = _assert_close("state", st_cute[valid_idx], st_ref[valid_idx])
|
||||
|
||||
if ok_o and ok_s:
|
||||
print(f" [PASS] {tag}")
|
||||
return True
|
||||
|
||||
details = []
|
||||
if not ok_o:
|
||||
details.append(f"output max_diff={diff_o:.6f}")
|
||||
if not ok_s:
|
||||
details.append(f"state max_diff={diff_s:.6f}")
|
||||
print(f" [FAIL] {tag} ({', '.join(details)})")
|
||||
return False
|
||||
|
||||
|
||||
def bench_shape(B, H, HV, K, V, pool_size, device, dtype, layout):
|
||||
inp = make_inputs(B, H, HV, K, V, pool_size, device, dtype, layout)
|
||||
|
||||
def fn_triton():
|
||||
fused_sigmoid_gating_delta_rule_update(
|
||||
A_log=inp["A_log"],
|
||||
dt_bias=inp["dt_bias"],
|
||||
q=inp["q"],
|
||||
k=inp["k"],
|
||||
v=inp["v"],
|
||||
a=inp["a"],
|
||||
b=inp["b"],
|
||||
initial_state_source=inp["ssm_states"],
|
||||
initial_state_indices=inp["cache_indices"],
|
||||
cu_seqlens=inp["cu_seqlens"] if inp["layout"] == "varlen" else None,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
softplus_beta=1.0,
|
||||
softplus_threshold=20.0,
|
||||
is_kda=True,
|
||||
)
|
||||
|
||||
def fn_cute():
|
||||
cutedsl_fused_sigmoid_gating_kda_update(
|
||||
A_log=inp["A_log"],
|
||||
dt_bias=inp["dt_bias"],
|
||||
q=inp["q"],
|
||||
k=inp["k"],
|
||||
v=inp["v"],
|
||||
a=inp["a"],
|
||||
b=inp["b"],
|
||||
initial_state_source=inp["ssm_states"],
|
||||
initial_state_indices=inp["cache_indices"],
|
||||
cu_seqlens=inp["cu_seqlens"] if inp["layout"] == "varlen" else None,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
softplus_beta=1.0,
|
||||
softplus_threshold=20.0,
|
||||
)
|
||||
|
||||
for _ in range(10):
|
||||
fn_triton()
|
||||
fn_cute()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
try:
|
||||
ms_triton, _, _ = triton.testing.do_bench(
|
||||
fn_triton, quantiles=[0.5, 0.2, 0.8], warmup=50, rep=200
|
||||
)
|
||||
ms_cute, _, _ = triton.testing.do_bench(
|
||||
fn_cute, quantiles=[0.5, 0.2, 0.8], warmup=50, rep=200
|
||||
)
|
||||
except Exception:
|
||||
rep = 100
|
||||
st = time.perf_counter()
|
||||
for _ in range(rep):
|
||||
fn_triton()
|
||||
torch.cuda.synchronize()
|
||||
ms_triton = (time.perf_counter() - st) / rep * 1000
|
||||
|
||||
st = time.perf_counter()
|
||||
for _ in range(rep):
|
||||
fn_cute()
|
||||
torch.cuda.synchronize()
|
||||
ms_cute = (time.perf_counter() - st) / rep * 1000
|
||||
|
||||
speedup = ms_triton / ms_cute if ms_cute > 0 else float("inf")
|
||||
delta = (ms_cute - ms_triton) * 1000
|
||||
print(
|
||||
f" {layout:>6} {B:>5} {H:>3} {HV:>3} {K:>3} {V:>3} | "
|
||||
f"{ms_triton * 1000:>12.1f} | "
|
||||
f"{ms_cute * 1000:>13.1f} | "
|
||||
f"{speedup:>8.2f} | "
|
||||
f"{delta:>11.1f}"
|
||||
)
|
||||
|
||||
|
||||
def run_correctness(device, dtype):
|
||||
print("=" * 78)
|
||||
print("Correctness: Triton KDA Decode vs CuTe DSL KDA Decode")
|
||||
print("=" * 78)
|
||||
|
||||
shapes = [
|
||||
("dense", 1, 8, 16, 128, 128, 32),
|
||||
("dense", 4, 8, 16, 128, 128, 32),
|
||||
("dense", 32, 8, 16, 128, 128, 128),
|
||||
("dense", 64, 8, 16, 128, 128, 128),
|
||||
("varlen", 4, 8, 16, 128, 128, 32),
|
||||
("varlen", 16, 8, 16, 128, 128, 64),
|
||||
("varlen", 32, 8, 16, 128, 128, 128),
|
||||
("varlen", 64, 8, 16, 128, 128, 128),
|
||||
("varlen", 1, 16, 32, 128, 128, 32),
|
||||
("varlen", 32, 16, 32, 128, 128, 128),
|
||||
("varlen", 64, 16, 16, 128, 128, 128),
|
||||
]
|
||||
|
||||
all_pass = True
|
||||
for layout, B, H, HV, K, V, pool_size in shapes:
|
||||
if not check_correctness(B, H, HV, K, V, pool_size, device, dtype, layout):
|
||||
all_pass = False
|
||||
|
||||
print()
|
||||
print("=" * 78)
|
||||
print("Correctness: Triton prefill/extend -> CuTe decode chain")
|
||||
print("=" * 78)
|
||||
for layout, B, H, HV, K, V, pool_size in shapes[:8]:
|
||||
if not check_prefill_chain(B, H, HV, K, V, pool_size, device, dtype, layout):
|
||||
all_pass = False
|
||||
|
||||
print()
|
||||
print("ALL PASSED." if all_pass else "SOME FAILED.")
|
||||
return all_pass
|
||||
|
||||
|
||||
def run_benchmark(device, dtype):
|
||||
print()
|
||||
print("=" * 92)
|
||||
print("Benchmark: Triton KDA Decode vs CuTe DSL KDA Decode")
|
||||
print("=" * 92)
|
||||
|
||||
bench_configs = [
|
||||
("dense", 1, 8, 16),
|
||||
("dense", 4, 8, 16),
|
||||
("dense", 32, 8, 16),
|
||||
("dense", 64, 8, 16),
|
||||
("varlen", 1, 8, 16),
|
||||
("varlen", 4, 8, 16),
|
||||
("varlen", 8, 8, 16),
|
||||
("varlen", 16, 8, 16),
|
||||
("varlen", 32, 8, 16),
|
||||
("varlen", 64, 8, 16),
|
||||
("varlen", 128, 8, 16),
|
||||
("varlen", 32, 16, 32),
|
||||
("varlen", 64, 16, 16),
|
||||
]
|
||||
|
||||
K = 128
|
||||
V = 128
|
||||
pool_size = 512
|
||||
|
||||
print(f" Config: K={K}, V={V}, pool_size={pool_size}, dtype={dtype}")
|
||||
print(
|
||||
f" {'layout':>6} {'B':>5} {'H':>3} {'HV':>3} {'K':>3} {'V':>3} | "
|
||||
f"{'triton (μs)':>12} | "
|
||||
f"{'cutedsl (μs)':>13} | "
|
||||
f"{'speedup':>8} | "
|
||||
f"{'delta (μs)':>11}"
|
||||
)
|
||||
print(" " + "-" * 82)
|
||||
|
||||
for layout, B, H, HV in bench_configs:
|
||||
actual_pool = max(pool_size, B + 16)
|
||||
bench_shape(B, H, HV, K, V, actual_pool, device, dtype, layout)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark & Correctness: Triton KDA Decode vs CuTe DSL KDA Decode"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["all", "correctness", "bench"],
|
||||
default="all",
|
||||
help="Run mode (default: all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
choices=["float16", "bfloat16", "float32"],
|
||||
default="bfloat16",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
device = "cuda"
|
||||
dtype = getattr(torch, args.dtype)
|
||||
|
||||
cap = torch.cuda.get_device_capability()
|
||||
dev_name = torch.cuda.get_device_name()
|
||||
print(f"Device: {dev_name} (SM {cap[0]}{cap[1]})")
|
||||
|
||||
if args.mode in ("all", "correctness"):
|
||||
all_pass = run_correctness(device, dtype)
|
||||
if not all_pass and args.mode == "all":
|
||||
print("\nSkipping benchmark due to correctness failures.")
|
||||
return 1
|
||||
|
||||
if args.mode in ("all", "bench"):
|
||||
run_benchmark(device, dtype)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
488
third_party/sglang/benchmark/bench_linear_attention/bench_gdn_decode.py
vendored
Normal file
488
third_party/sglang/benchmark/bench_linear_attention/bench_gdn_decode.py
vendored
Normal file
@@ -0,0 +1,488 @@
|
||||
"""
|
||||
Benchmark & Correctness: GDN Packed Decode vs Baseline Decode.
|
||||
|
||||
Compares:
|
||||
- Baseline: split(mixed_qkv) → view → fused_sigmoid_gating_delta_rule_update
|
||||
- Packed: fused_recurrent_gated_delta_rule_packed_decode (single kernel)
|
||||
|
||||
The packed path eliminates:
|
||||
- torch.split() + .view() tensor materialization
|
||||
- Separate gating kernel launches
|
||||
- Intermediate tensor allocations
|
||||
|
||||
Reports correctness (output & state matching) and performance (ms, speedup).
|
||||
|
||||
Usage:
|
||||
python bench_gdn_decode.py # default sweep
|
||||
python bench_gdn_decode.py --mode bench # benchmark only
|
||||
python bench_gdn_decode.py --mode correctness # correctness only
|
||||
python bench_gdn_decode.py --preset qwen3.5-35b # Qwen3.5-35B-A3B config
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "python"))
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from sglang.srt.layers.attention.fla.fused_recurrent import (
|
||||
fused_recurrent_gated_delta_rule_packed_decode,
|
||||
)
|
||||
from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
|
||||
fused_sigmoid_gating_delta_rule_update,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_inputs(
|
||||
B: int,
|
||||
H: int,
|
||||
HV: int,
|
||||
K: int,
|
||||
V: int,
|
||||
pool_size: int,
|
||||
device: str,
|
||||
dtype: torch.dtype,
|
||||
seed: int = 42,
|
||||
):
|
||||
"""Create all input tensors for a single benchmark / correctness run."""
|
||||
torch.manual_seed(seed)
|
||||
|
||||
qkv_dim = 2 * H * K + HV * V
|
||||
mixed_qkv = torch.randn(B, qkv_dim, device=device, dtype=dtype)
|
||||
a = torch.randn(B, HV, device=device, dtype=dtype)
|
||||
b = torch.randn(B, HV, device=device, dtype=dtype)
|
||||
A_log = torch.randn(HV, device=device, dtype=dtype)
|
||||
dt_bias = torch.randn(HV, device=device, dtype=dtype)
|
||||
|
||||
ssm_states = torch.randn(pool_size, HV, V, K, device=device, dtype=dtype) * 0.1
|
||||
cache_indices = torch.arange(B, device=device, dtype=torch.int32)
|
||||
|
||||
cu_seqlens = torch.arange(B + 1, device=device, dtype=torch.long)
|
||||
|
||||
return dict(
|
||||
B=B,
|
||||
H=H,
|
||||
HV=HV,
|
||||
K=K,
|
||||
V=V,
|
||||
qkv_dim=qkv_dim,
|
||||
pool_size=pool_size,
|
||||
mixed_qkv=mixed_qkv,
|
||||
a=a,
|
||||
b=b,
|
||||
A_log=A_log,
|
||||
dt_bias=dt_bias,
|
||||
ssm_states=ssm_states,
|
||||
cache_indices=cache_indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Runner wrappers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_baseline(inp):
|
||||
"""Baseline path: split → view → fused_sigmoid_gating_delta_rule_update.
|
||||
|
||||
This mirrors the FULL original decode path in GDNAttnBackend.forward_decode,
|
||||
including the split, view, and kernel call.
|
||||
"""
|
||||
B, H, HV, K, V = inp["B"], inp["H"], inp["HV"], inp["K"], inp["V"]
|
||||
mixed_qkv = inp["mixed_qkv"]
|
||||
ssm_states = inp["ssm_states"].clone()
|
||||
|
||||
# Step 1: split (same as forward_decode)
|
||||
q_flat, k_flat, v_flat = torch.split(mixed_qkv, [H * K, H * K, HV * V], dim=-1)
|
||||
|
||||
# Step 2: view + reshape (same as forward_decode)
|
||||
q = q_flat.view(1, B, H, K)
|
||||
k = k_flat.view(1, B, H, K)
|
||||
v = v_flat.view(1, B, HV, V)
|
||||
|
||||
# Step 3: fused gating + recurrent update
|
||||
o = fused_sigmoid_gating_delta_rule_update(
|
||||
A_log=inp["A_log"],
|
||||
dt_bias=inp["dt_bias"],
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
a=inp["a"],
|
||||
b=inp["b"],
|
||||
initial_state_source=ssm_states,
|
||||
initial_state_indices=inp["cache_indices"],
|
||||
cu_seqlens=inp["cu_seqlens"],
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
softplus_beta=1.0,
|
||||
softplus_threshold=20.0,
|
||||
)
|
||||
|
||||
return o, ssm_states
|
||||
|
||||
|
||||
def run_packed(inp):
|
||||
"""Packed path: single fused kernel directly on mixed_qkv."""
|
||||
B, HV, K, V = inp["B"], inp["HV"], inp["K"], inp["V"]
|
||||
ssm_states = inp["ssm_states"].clone()
|
||||
out = inp["mixed_qkv"].new_empty(B, 1, HV, V)
|
||||
|
||||
fused_recurrent_gated_delta_rule_packed_decode(
|
||||
mixed_qkv=inp["mixed_qkv"],
|
||||
a=inp["a"],
|
||||
b=inp["b"],
|
||||
A_log=inp["A_log"],
|
||||
dt_bias=inp["dt_bias"],
|
||||
scale=inp["K"] ** -0.5,
|
||||
initial_state=ssm_states,
|
||||
out=out,
|
||||
ssm_state_indices=inp["cache_indices"],
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
|
||||
# Convert [B, 1, HV, V] → [1, B, HV, V] to match baseline layout
|
||||
return out.transpose(0, 1), ssm_states
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Correctness check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def check_correctness(B, H, HV, K, V, pool_size, device, dtype, seed=42):
|
||||
"""Run correctness check for a single config. Returns True if PASS."""
|
||||
tag = f"B={B:>4} H={H:>2} HV={HV:>2} K={K:>3} V={V:>3} pool={pool_size:>4}"
|
||||
|
||||
inp = make_inputs(B, H, HV, K, V, pool_size, device, dtype, seed=seed)
|
||||
|
||||
o_baseline, state_baseline = run_baseline(inp)
|
||||
o_packed, state_packed = run_packed(inp)
|
||||
|
||||
# Output comparison
|
||||
atol = 2e-2 if dtype != torch.float32 else 1e-4
|
||||
rtol = 1e-2 if dtype != torch.float32 else 1e-4
|
||||
|
||||
try:
|
||||
torch.testing.assert_close(o_packed, o_baseline, atol=atol, rtol=rtol)
|
||||
output_ok = True
|
||||
except AssertionError as e:
|
||||
output_ok = False
|
||||
out_diff = (o_packed - o_baseline).abs().max().item()
|
||||
|
||||
# State comparison (only for slots that were updated)
|
||||
indices = inp["cache_indices"]
|
||||
try:
|
||||
torch.testing.assert_close(
|
||||
state_packed[indices], state_baseline[indices], atol=atol, rtol=rtol
|
||||
)
|
||||
state_ok = True
|
||||
except AssertionError:
|
||||
state_ok = False
|
||||
st_diff = (state_packed[indices] - state_baseline[indices]).abs().max().item()
|
||||
|
||||
passed = output_ok and state_ok
|
||||
|
||||
if passed:
|
||||
print(f" [PASS] {tag}")
|
||||
else:
|
||||
details = []
|
||||
if not output_ok:
|
||||
details.append(f"output max_diff={out_diff:.6f}")
|
||||
if not state_ok:
|
||||
details.append(f"state max_diff={st_diff:.6f}")
|
||||
print(f" [FAIL] {tag} ({', '.join(details)})")
|
||||
|
||||
return passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def bench_shape(B, H, HV, K, V, pool_size, device, dtype):
|
||||
"""Benchmark baseline vs packed for a single config."""
|
||||
inp = make_inputs(B, H, HV, K, V, pool_size, device, dtype)
|
||||
|
||||
# ── Baseline: full path including split + view ──
|
||||
def fn_baseline():
|
||||
q_flat, k_flat, v_flat = torch.split(
|
||||
inp["mixed_qkv"], [H * K, H * K, HV * V], dim=-1
|
||||
)
|
||||
q = q_flat.view(1, B, H, K)
|
||||
k = k_flat.view(1, B, H, K)
|
||||
v = v_flat.view(1, B, HV, V)
|
||||
fused_sigmoid_gating_delta_rule_update(
|
||||
A_log=inp["A_log"],
|
||||
dt_bias=inp["dt_bias"],
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
a=inp["a"],
|
||||
b=inp["b"],
|
||||
initial_state_source=inp["ssm_states"],
|
||||
initial_state_indices=inp["cache_indices"],
|
||||
cu_seqlens=inp["cu_seqlens"],
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
softplus_beta=1.0,
|
||||
softplus_threshold=20.0,
|
||||
)
|
||||
|
||||
# ── Packed: single kernel ──
|
||||
out_buf = inp["mixed_qkv"].new_empty(B, 1, HV, V)
|
||||
|
||||
def fn_packed():
|
||||
fused_recurrent_gated_delta_rule_packed_decode(
|
||||
mixed_qkv=inp["mixed_qkv"],
|
||||
a=inp["a"],
|
||||
b=inp["b"],
|
||||
A_log=inp["A_log"],
|
||||
dt_bias=inp["dt_bias"],
|
||||
scale=K**-0.5,
|
||||
initial_state=inp["ssm_states"],
|
||||
out=out_buf,
|
||||
ssm_state_indices=inp["cache_indices"],
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
fn_baseline()
|
||||
fn_packed()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
try:
|
||||
ms_baseline, ms_base_lo, ms_base_hi = triton.testing.do_bench(
|
||||
fn_baseline, quantiles=quantiles, warmup=50, rep=200
|
||||
)
|
||||
ms_packed, ms_pack_lo, ms_pack_hi = triton.testing.do_bench(
|
||||
fn_packed, quantiles=quantiles, warmup=50, rep=200
|
||||
)
|
||||
except Exception:
|
||||
# Fallback to manual timing
|
||||
torch.cuda.synchronize()
|
||||
N = 200
|
||||
start = time.perf_counter()
|
||||
for _ in range(N):
|
||||
fn_baseline()
|
||||
torch.cuda.synchronize()
|
||||
ms_baseline = (time.perf_counter() - start) / N * 1000
|
||||
|
||||
start = time.perf_counter()
|
||||
for _ in range(N):
|
||||
fn_packed()
|
||||
torch.cuda.synchronize()
|
||||
ms_packed = (time.perf_counter() - start) / N * 1000
|
||||
|
||||
speedup = ms_baseline / ms_packed if ms_packed > 0 else float("inf")
|
||||
saved_us = (ms_baseline - ms_packed) * 1000
|
||||
|
||||
print(
|
||||
f" {B:>5} {H:>3} {HV:>3} {K:>3} {V:>3} | "
|
||||
f"{ms_baseline * 1000:>10.1f} | "
|
||||
f"{ms_packed * 1000:>10.1f} | "
|
||||
f"{speedup:>7.2f}x | "
|
||||
f"{saved_us:>+9.1f}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_correctness(device, dtype):
|
||||
print("=" * 70)
|
||||
print("Correctness: Baseline GDN Decode vs Packed GDN Decode")
|
||||
print("=" * 70)
|
||||
|
||||
shapes = [
|
||||
# (B, H, HV, K, V, pool_size)
|
||||
# --- Qwen3.5-35B-A3B style (TP=2: H=8, HV=16) ---
|
||||
(1, 8, 16, 128, 128, 32),
|
||||
(4, 8, 16, 128, 128, 32),
|
||||
(16, 8, 16, 128, 128, 64),
|
||||
(32, 8, 16, 128, 128, 128),
|
||||
(64, 8, 16, 128, 128, 128),
|
||||
(128, 8, 16, 128, 128, 256),
|
||||
(256, 8, 16, 128, 128, 512),
|
||||
# --- Qwen3.5-35B-A3B style (TP=1: H=16, HV=32) ---
|
||||
(1, 16, 32, 128, 128, 32),
|
||||
(32, 16, 32, 128, 128, 128),
|
||||
(64, 16, 32, 128, 128, 128),
|
||||
# --- Qwen3-Next-80B-A3B style ---
|
||||
(32, 16, 16, 128, 128, 128),
|
||||
(64, 16, 16, 128, 128, 128),
|
||||
# --- With PAD_SLOT_ID ---
|
||||
(32, 8, 16, 128, 128, 128), # some indices may be padded
|
||||
# --- Edge cases ---
|
||||
(1, 8, 16, 128, 128, 32),
|
||||
(2, 8, 16, 128, 128, 32),
|
||||
]
|
||||
|
||||
all_pass = True
|
||||
for B, H, HV, K, V, pool_size in shapes:
|
||||
if not check_correctness(B, H, HV, K, V, pool_size, device, dtype):
|
||||
all_pass = False
|
||||
|
||||
# PAD_SLOT_ID test
|
||||
print("\n PAD_SLOT_ID test (indices with -1):")
|
||||
inp = make_inputs(32, 8, 16, 128, 128, 128, device, dtype)
|
||||
o_baseline, st_baseline = run_baseline(inp)
|
||||
o_packed, st_packed = run_packed(inp)
|
||||
|
||||
try:
|
||||
torch.testing.assert_close(o_packed, o_baseline, atol=2e-2, rtol=1e-2)
|
||||
print(" [PASS] PAD_SLOT_ID=-1 handling")
|
||||
except AssertionError:
|
||||
print(" [FAIL] PAD_SLOT_ID=-1 handling")
|
||||
all_pass = False
|
||||
|
||||
print()
|
||||
if all_pass:
|
||||
print("ALL PASSED.")
|
||||
else:
|
||||
print("SOME FAILED.")
|
||||
return all_pass
|
||||
|
||||
|
||||
def run_benchmark(device, dtype, args):
|
||||
print()
|
||||
print("=" * 85)
|
||||
print("Benchmark: Baseline GDN Decode vs Packed GDN Decode")
|
||||
print("=" * 85)
|
||||
|
||||
K = args.head_size_k
|
||||
V = args.head_size_v
|
||||
pool_size = args.pool_size
|
||||
|
||||
if args.preset == "qwen3.5-35b":
|
||||
# Qwen3.5-35B-A3B: H_qk=16, H_v=32, K=128, V=128
|
||||
# After TP=2: H=8, HV=16
|
||||
bench_configs = [
|
||||
# (B, H, HV) — TP=2 config
|
||||
(1, 8, 16),
|
||||
(2, 8, 16),
|
||||
(4, 8, 16),
|
||||
(8, 8, 16),
|
||||
(16, 8, 16),
|
||||
(32, 8, 16),
|
||||
(64, 8, 16),
|
||||
(128, 8, 16),
|
||||
(256, 8, 16),
|
||||
(512, 8, 16),
|
||||
# TP=1 config (full heads)
|
||||
(1, 16, 32),
|
||||
(8, 16, 32),
|
||||
(32, 16, 32),
|
||||
(64, 16, 32),
|
||||
(128, 16, 32),
|
||||
(256, 16, 32),
|
||||
]
|
||||
elif args.preset == "qwen3-next-80b":
|
||||
bench_configs = [
|
||||
# Qwen3-Next-80B-A3B: all same H=HV=16 after TP
|
||||
(1, 16, 16),
|
||||
(8, 16, 16),
|
||||
(32, 16, 16),
|
||||
(64, 16, 16),
|
||||
(128, 16, 16),
|
||||
(256, 16, 16),
|
||||
]
|
||||
else:
|
||||
bench_configs = []
|
||||
for B in args.batch_sizes:
|
||||
for H in args.num_q_heads:
|
||||
for HV in args.num_v_heads:
|
||||
bench_configs.append((B, H, HV))
|
||||
|
||||
print(f" Config: K={K}, V={V}, pool_size={pool_size}, dtype={dtype}")
|
||||
print(
|
||||
f" {'B':>5} {'H':>3} {'HV':>3} {'K':>3} {'V':>3} | "
|
||||
f"{'base (μs)':>10} | "
|
||||
f"{'packed (μs)':>10} | "
|
||||
f"{'speedup':>8} | "
|
||||
f"{'saved (μs)':>10}"
|
||||
)
|
||||
print(" " + "-" * 75)
|
||||
|
||||
for B, H, HV in bench_configs:
|
||||
actual_pool = max(pool_size, B + 16)
|
||||
bench_shape(B, H, HV, K, V, actual_pool, device, dtype)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark & Correctness: GDN Packed Decode vs Baseline"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["all", "correctness", "bench"],
|
||||
default="all",
|
||||
help="Run mode (default: all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preset",
|
||||
choices=["qwen3.5-35b", "qwen3-next-80b", "custom"],
|
||||
default="qwen3.5-35b",
|
||||
help="Preset config (default: qwen3.5-35b)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
choices=["float16", "bfloat16", "float32"],
|
||||
default="bfloat16",
|
||||
)
|
||||
parser.add_argument("--head-size-k", type=int, default=128)
|
||||
parser.add_argument("--head-size-v", type=int, default=128)
|
||||
parser.add_argument("--pool-size", type=int, default=512)
|
||||
parser.add_argument(
|
||||
"--batch-sizes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[1, 4, 8, 16, 32, 64, 128, 256, 512],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-q-heads",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[8, 16],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-v-heads",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[16, 32],
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
device = "cuda"
|
||||
dtype = getattr(torch, args.dtype)
|
||||
|
||||
cap = torch.cuda.get_device_capability()
|
||||
dev_name = torch.cuda.get_device_name()
|
||||
print(f"Device: {dev_name} (SM {cap[0]}{cap[1]})")
|
||||
|
||||
if args.mode in ("all", "correctness"):
|
||||
all_pass = run_correctness(device, dtype)
|
||||
if not all_pass and args.mode == "all":
|
||||
print("\nSkipping benchmark due to correctness failures.")
|
||||
return 1
|
||||
|
||||
if args.mode in ("all", "bench"):
|
||||
run_benchmark(device, dtype, args)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
639
third_party/sglang/benchmark/bench_linear_attention/bench_gdn_prefill.py
vendored
Normal file
639
third_party/sglang/benchmark/bench_linear_attention/bench_gdn_prefill.py
vendored
Normal file
@@ -0,0 +1,639 @@
|
||||
"""
|
||||
Benchmark & Correctness: Triton GDN vs FlashInfer GDN (prefill).
|
||||
|
||||
Compares:
|
||||
- Triton: sglang's chunk_gated_delta_rule (K-contiguous pool, pool-indexed)
|
||||
- FlashInfer: flashinfer's chunk_gated_delta_rule (gather/scatter, 3D tensors)
|
||||
|
||||
The two kernels have different APIs:
|
||||
- Triton: q/k/v=[1,T,H,D], g=logsigmoid, beta=sigmoid, has initial_state_indices
|
||||
- FlashInfer: q/k/v=[T,H,D], g=alpha(float32), beta=float32, no indices (gathered state)
|
||||
|
||||
Reports correctness (output & state matching) and performance (ms, TFLOPS, TB/s).
|
||||
|
||||
Usage:
|
||||
python benchmark_gdn_prefill.py # default sweep
|
||||
python benchmark_gdn_prefill.py --mode bench # benchmark only
|
||||
python benchmark_gdn_prefill.py --mode correctness # correctness only
|
||||
python benchmark_gdn_prefill.py --preset qwen3-next # Qwen3-Next config
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "python"))
|
||||
|
||||
import torch
|
||||
from flashinfer.gdn_prefill import (
|
||||
chunk_gated_delta_rule as flashinfer_chunk_gated_delta_rule,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.attention.fla.chunk import (
|
||||
chunk_gated_delta_rule as triton_chunk_gated_delta_rule,
|
||||
)
|
||||
from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_k_contiguous(t: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Given a V-contiguous tensor [..., K, V], return a K-contiguous view of the
|
||||
same logical shape [..., K, V] (physically [..., V, K], K-last).
|
||||
"""
|
||||
return t.transpose(-2, -1).contiguous().transpose(-2, -1)
|
||||
|
||||
|
||||
def gdn_flops(
|
||||
total_seq_len: int,
|
||||
num_heads: int,
|
||||
head_size_k: int,
|
||||
head_size_v: int,
|
||||
) -> int:
|
||||
"""
|
||||
FLOPs for GDN prefill (delta rule).
|
||||
|
||||
Per token per head:
|
||||
1. k @ v^T (outer product): 2 * K * V
|
||||
2. q @ state (output): 2 * K * V
|
||||
"""
|
||||
outer_product_flops = 2 * total_seq_len * num_heads * head_size_k * head_size_v
|
||||
output_flops = 2 * total_seq_len * num_heads * head_size_k * head_size_v
|
||||
return outer_product_flops + output_flops
|
||||
|
||||
|
||||
def gdn_bytes(
|
||||
total_seq_len: int,
|
||||
num_q_heads: int,
|
||||
num_v_heads: int,
|
||||
head_size_k: int,
|
||||
head_size_v: int,
|
||||
num_seqs: int,
|
||||
dtype: torch.dtype,
|
||||
) -> int:
|
||||
"""Memory bytes accessed (inputs + outputs + state)."""
|
||||
num_o_heads = max(num_q_heads, num_v_heads)
|
||||
elem = dtype.itemsize
|
||||
|
||||
q_bytes = total_seq_len * num_q_heads * head_size_k * elem
|
||||
k_bytes = total_seq_len * num_v_heads * head_size_k * elem
|
||||
v_bytes = total_seq_len * num_v_heads * head_size_v * elem
|
||||
o_bytes = total_seq_len * num_o_heads * head_size_v * elem
|
||||
|
||||
# state (float32): read + write
|
||||
state_bytes = 2 * num_seqs * num_o_heads * head_size_k * head_size_v * 4
|
||||
|
||||
# g, beta (float32)
|
||||
g_bytes = total_seq_len * num_o_heads * 4
|
||||
beta_bytes = total_seq_len * num_o_heads * 4
|
||||
|
||||
return q_bytes + k_bytes + v_bytes + o_bytes + state_bytes + g_bytes + beta_bytes
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_inputs(
|
||||
B: int,
|
||||
T_per_seq: int,
|
||||
H: int,
|
||||
K: int,
|
||||
V: int,
|
||||
pool_size: int,
|
||||
device: str,
|
||||
dtype: torch.dtype,
|
||||
sequential_indices: bool = False,
|
||||
seed: int = 42,
|
||||
):
|
||||
"""Create all input tensors for a single benchmark / correctness run.
|
||||
|
||||
Returns a dict with both Triton-format and FlashInfer-format tensors.
|
||||
"""
|
||||
T = B * T_per_seq
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if sequential_indices:
|
||||
cache_indices = torch.arange(B, dtype=torch.int32, device=device)
|
||||
else:
|
||||
perm = torch.randperm(pool_size, device=device)[:B]
|
||||
cache_indices = perm.to(torch.int32)
|
||||
|
||||
pool_init = torch.randn(pool_size, H, K, V, dtype=dtype, device=device) * 0.1
|
||||
|
||||
cu_seqlens = torch.arange(
|
||||
0, (B + 1) * T_per_seq, T_per_seq, dtype=torch.long, device=device
|
||||
)
|
||||
|
||||
# Triton format: [1, T, H, D]
|
||||
q = torch.randn(1, T, H, K, dtype=dtype, device=device)
|
||||
k = torch.randn(1, T, H, K, dtype=dtype, device=device)
|
||||
v = torch.randn(1, T, H, V, dtype=dtype, device=device)
|
||||
|
||||
# g (logsigmoid) and beta (sigmoid) in Triton format: [1, T, H]
|
||||
g_raw = torch.randn(1, T, H, dtype=dtype, device=device)
|
||||
g_triton = torch.nn.functional.logsigmoid(g_raw) # logsigmoid for Triton
|
||||
beta_triton = torch.sigmoid(torch.randn(1, T, H, dtype=dtype, device=device))
|
||||
|
||||
return dict(
|
||||
B=B,
|
||||
T=T,
|
||||
T_per_seq=T_per_seq,
|
||||
H=H,
|
||||
K=K,
|
||||
V=V,
|
||||
pool_size=pool_size,
|
||||
cache_indices=cache_indices,
|
||||
pool_init=pool_init,
|
||||
cu_seqlens=cu_seqlens,
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g_triton=g_triton,
|
||||
beta_triton=beta_triton,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Runner wrappers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_triton(inp):
|
||||
"""Triton path: K-contiguous pool, pool-indexed, [1,T,H,D] tensors."""
|
||||
pool = make_k_contiguous(inp["pool_init"].clone())
|
||||
|
||||
o, _, h = triton_chunk_gated_delta_rule(
|
||||
q=inp["q"],
|
||||
k=inp["k"],
|
||||
v=inp["v"],
|
||||
g=inp["g_triton"],
|
||||
beta=inp["beta_triton"],
|
||||
initial_state=pool,
|
||||
initial_state_indices=inp["cache_indices"],
|
||||
cu_seqlens=inp["cu_seqlens"],
|
||||
head_first=False,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
return o, pool, h
|
||||
|
||||
|
||||
def run_flashinfer(inp):
|
||||
"""FlashInfer path: matches sglang FlashInferGDNKernel.extend() exactly.
|
||||
|
||||
Key differences from Triton path:
|
||||
- q, k are L2-normalized BEFORE calling the kernel
|
||||
- use_qk_l2norm_in_kernel=False (kernel skips internal normalization)
|
||||
- Tensors are [T, H, D] (no batch dim)
|
||||
- g is alpha = exp(logsigmoid(...)) = sigmoid(...), float32
|
||||
- beta is float32
|
||||
- initial_state is gathered from pool (no pool-index support)
|
||||
- Uses keyword arguments (matching sglang production code)
|
||||
|
||||
NOTE: FlashInfer GDN requires K == V (square head_size).
|
||||
"""
|
||||
K = inp["K"]
|
||||
V = inp["V"]
|
||||
assert K == V, f"FlashInfer GDN requires K == V, got K={K}, V={V}"
|
||||
|
||||
pool = make_k_contiguous(inp["pool_init"].clone())
|
||||
cache_indices = inp["cache_indices"]
|
||||
|
||||
# Gather states from K-contiguous pool -> K-contiguous float32
|
||||
# In production, ssm_states is already float32 so .float() is no-op.
|
||||
# Here pool_init is bf16, so .float() loses K-contiguous layout.
|
||||
gathered = pool[cache_indices]
|
||||
initial_state = make_k_contiguous(gathered.float().contiguous())
|
||||
|
||||
q_fi = l2norm_fwd(inp["q"][0].contiguous())
|
||||
k_fi = l2norm_fwd(inp["k"][0].contiguous())
|
||||
v_fi = inp["v"][0].contiguous()
|
||||
|
||||
# g -> alpha (exp of logsigmoid = sigmoid), float32
|
||||
alpha_fi = torch.exp(inp["g_triton"][0].to(torch.float32))
|
||||
# beta -> float32
|
||||
beta_fi = inp["beta_triton"][0].to(torch.float32)
|
||||
|
||||
cu_seqlens_fi = inp["cu_seqlens"].to(torch.int64)
|
||||
|
||||
# Call FlashInfer with keyword args (matching sglang production code)
|
||||
# use_qk_l2norm_in_kernel=False because we pre-normalized above
|
||||
o_fi, state_fi = flashinfer_chunk_gated_delta_rule(
|
||||
q=q_fi,
|
||||
k=k_fi,
|
||||
v=v_fi,
|
||||
g=alpha_fi,
|
||||
beta=beta_fi,
|
||||
scale=None,
|
||||
initial_state=initial_state,
|
||||
output_final_state=True,
|
||||
cu_seqlens=cu_seqlens_fi,
|
||||
use_qk_l2norm_in_kernel=False,
|
||||
)
|
||||
|
||||
# Scatter updated states back to K-contiguous pool
|
||||
pool[cache_indices] = state_fi.to(pool.dtype)
|
||||
|
||||
# Reshape output: [T, H, D] -> [1, T, H, D] to match Triton
|
||||
o_out = o_fi.unsqueeze(0)
|
||||
|
||||
return o_out, pool, state_fi
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Correctness check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def check_shape(
|
||||
B,
|
||||
T_per_seq,
|
||||
H,
|
||||
K,
|
||||
V,
|
||||
pool_size,
|
||||
device,
|
||||
dtype,
|
||||
sequential_indices=False,
|
||||
seed=42,
|
||||
):
|
||||
"""Run correctness check for a single shape config. Returns True if PASS.
|
||||
|
||||
Pass/fail is based on OUTPUT comparison only (atol=5e-2).
|
||||
Pool state diff is reported as informational — state divergence over many
|
||||
tokens is expected due to different chunk sizes and accumulation order.
|
||||
"""
|
||||
tag = (
|
||||
f"B={B:>3} T/seq={T_per_seq:>4} H={H:>2} K={K:>3} V={V:>3} pool={pool_size:>4}"
|
||||
)
|
||||
idx_tag = " (seq)" if sequential_indices else ""
|
||||
|
||||
# FlashInfer GDN requires K == V (square head_size)
|
||||
if K != V:
|
||||
print(f" [SKIP] {tag}{idx_tag} (FlashInfer requires K==V)")
|
||||
return True
|
||||
|
||||
# FlashInfer GDN CUTLASS kernels are only compiled for head_size=128.
|
||||
# Running with other sizes causes illegal memory access that poisons
|
||||
# the CUDA context (unrecoverable), so we must skip upfront.
|
||||
FLASHINFER_SUPPORTED_HEAD_SIZES = {128}
|
||||
if K not in FLASHINFER_SUPPORTED_HEAD_SIZES:
|
||||
print(
|
||||
f" [SKIP] {tag}{idx_tag} (FlashInfer only supports head_size={FLASHINFER_SUPPORTED_HEAD_SIZES})"
|
||||
)
|
||||
return True
|
||||
|
||||
inp = make_inputs(
|
||||
B,
|
||||
T_per_seq,
|
||||
H,
|
||||
K,
|
||||
V,
|
||||
pool_size,
|
||||
device,
|
||||
dtype,
|
||||
sequential_indices=sequential_indices,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
o_triton, pool_triton, h_triton = run_triton(inp)
|
||||
|
||||
# FlashInfer may not support all head_size values (e.g., only 128).
|
||||
# CUDA errors from unsupported configs are often asynchronous, so we
|
||||
# must synchronize inside the try block to catch them here.
|
||||
try:
|
||||
o_fi, pool_fi, _ = run_flashinfer(inp)
|
||||
torch.cuda.synchronize()
|
||||
except Exception as e:
|
||||
# Catch RuntimeError, torch.AcceleratorError, etc.
|
||||
# Reset CUDA error state so subsequent tests can proceed
|
||||
try:
|
||||
torch.cuda.synchronize()
|
||||
except Exception:
|
||||
pass
|
||||
print(f" [SKIP] {tag}{idx_tag} (FlashInfer error: {e})")
|
||||
return True
|
||||
|
||||
cache_indices = inp["cache_indices"]
|
||||
|
||||
# --- Output comparison ---
|
||||
# bf16 prefill with L2norm + chunked accumulation
|
||||
torch.testing.assert_close(o_triton, o_fi, atol=5e-2, rtol=1e-2)
|
||||
|
||||
# --- Stride check ---
|
||||
def strides_ok(pool):
|
||||
s = pool.stride()
|
||||
return s[-2] == 1 and s[-1] == K
|
||||
|
||||
strides_triton = strides_ok(pool_triton)
|
||||
strides_fi = strides_ok(pool_fi)
|
||||
|
||||
passed = strides_triton and strides_fi
|
||||
|
||||
# Build detail string
|
||||
details = []
|
||||
if not strides_triton:
|
||||
details.append("triton strides bad")
|
||||
if not strides_fi:
|
||||
details.append("flashinfer strides bad")
|
||||
|
||||
status = "PASS" if passed else "FAIL"
|
||||
detail_str = f" [{', '.join(details)}]"
|
||||
print(f" [{status}] {tag}{idx_tag}")
|
||||
return passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def bench_shape(B, H, T_per_seq, K, V, pool_size, device, dtype):
|
||||
"""Benchmark Triton vs FlashInfer for a single config. Requires K == V."""
|
||||
import triton.testing
|
||||
|
||||
assert K == V, f"FlashInfer GDN requires K == V, got K={K}, V={V}"
|
||||
|
||||
T = B * T_per_seq
|
||||
inp = make_inputs(B, T_per_seq, H, K, V, pool_size, device, dtype)
|
||||
|
||||
# -- Shared read-only tensors --
|
||||
q, k_t, v = inp["q"], inp["k"], inp["v"]
|
||||
g_triton, beta_triton = inp["g_triton"], inp["beta_triton"]
|
||||
cu_seqlens = inp["cu_seqlens"]
|
||||
cache_indices = inp["cache_indices"]
|
||||
seq_indices = torch.arange(B, dtype=torch.int32, device=device)
|
||||
pool_v = inp["pool_init"]
|
||||
|
||||
def fn_triton():
|
||||
pool = make_k_contiguous(pool_v.clone())
|
||||
triton_chunk_gated_delta_rule(
|
||||
q=q,
|
||||
k=k_t,
|
||||
v=v,
|
||||
g=g_triton,
|
||||
beta=beta_triton,
|
||||
initial_state=pool,
|
||||
initial_state_indices=cache_indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
head_first=False,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
|
||||
def fn_flashinfer():
|
||||
# -- Pre-compute FlashInfer format tensors (outside timing) --
|
||||
# Pre-normalize q and k (matching sglang production: l2norm_fwd)
|
||||
# q_fi = torch.nn.functional.normalize(q[0].contiguous().float(), p=2.0, dim=-1).to(
|
||||
# dtype
|
||||
# )
|
||||
# k_fi = torch.nn.functional.normalize(k_t[0].contiguous().float(), p=2.0, dim=-1).to(
|
||||
# dtype
|
||||
# )
|
||||
q_fi = l2norm_fwd(q[0].contiguous())
|
||||
k_fi = l2norm_fwd(k_t[0].contiguous())
|
||||
v_fi = v[0].contiguous()
|
||||
alpha_fi = torch.exp(g_triton[0].to(torch.float32))
|
||||
beta_fi = beta_triton[0].to(torch.float32)
|
||||
cu_seqlens_fi = cu_seqlens.to(torch.int64)
|
||||
pool = make_k_contiguous(pool_v.clone())
|
||||
gathered = pool[cache_indices]
|
||||
initial_state = make_k_contiguous(gathered.float().contiguous())
|
||||
flashinfer_chunk_gated_delta_rule(
|
||||
q=q_fi,
|
||||
k=k_fi,
|
||||
v=v_fi,
|
||||
g=alpha_fi,
|
||||
beta=beta_fi,
|
||||
scale=None,
|
||||
initial_state=initial_state,
|
||||
output_final_state=True,
|
||||
cu_seqlens=cu_seqlens_fi,
|
||||
use_qk_l2norm_in_kernel=False,
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
# Warmup
|
||||
fn_triton()
|
||||
fn_flashinfer()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
ms_triton, _, _ = triton.testing.do_bench_cudagraph(fn_triton, quantiles=quantiles)
|
||||
ms_fi, _, _ = triton.testing.do_bench_cudagraph(fn_flashinfer, quantiles=quantiles)
|
||||
|
||||
# Metrics
|
||||
num_o_heads = H
|
||||
flops = gdn_flops(T, num_o_heads, K, V)
|
||||
mem_bytes = gdn_bytes(T, H, H, K, V, B, dtype)
|
||||
|
||||
tflops_triton = flops / ms_triton / 1e9
|
||||
tflops_fi = flops / ms_fi / 1e9
|
||||
tb_s_triton = mem_bytes / ms_triton / 1e9
|
||||
tb_s_fi = mem_bytes / ms_fi / 1e9
|
||||
|
||||
speedup = ms_triton / ms_fi if ms_fi > 0 else float("inf")
|
||||
|
||||
print(
|
||||
f" {B:>5} {H:>3} {T_per_seq:>6} {T:>7} | "
|
||||
f"{ms_triton:>8.3f} {tflops_triton:>7.2f} {tb_s_triton:>7.2f} | "
|
||||
f"{ms_fi:>8.3f} {tflops_fi:>7.2f} {tb_s_fi:>7.2f} | "
|
||||
f"{speedup:>7.2f}x"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_correctness(device, dtype):
|
||||
print("=" * 78)
|
||||
print("Correctness sweep: Triton vs FlashInfer")
|
||||
print("=" * 78)
|
||||
|
||||
shapes = [
|
||||
# (B, T_per_seq, H, K, V, pool_size)
|
||||
# --- baseline (Qwen3-Next style) ---
|
||||
(4, 64, 16, 128, 128, 32),
|
||||
(4, 256, 16, 128, 128, 32),
|
||||
# --- different batch sizes ---
|
||||
(1, 128, 16, 128, 128, 32),
|
||||
(8, 128, 16, 128, 128, 64),
|
||||
(16, 64, 16, 128, 128, 128),
|
||||
(32, 32, 16, 128, 128, 256),
|
||||
# --- different head counts ---
|
||||
(4, 128, 4, 128, 128, 32),
|
||||
(4, 128, 8, 128, 128, 32),
|
||||
(4, 128, 16, 64, 64, 32),
|
||||
(4, 128, 32, 128, 128, 32),
|
||||
(4, 128, 64, 128, 128, 32),
|
||||
# --- short sequences ---
|
||||
(4, 1, 16, 128, 128, 32),
|
||||
(4, 7, 16, 128, 128, 32),
|
||||
(4, 16, 16, 128, 128, 32),
|
||||
# --- large pool (sparse access) ---
|
||||
(4, 128, 16, 128, 128, 512),
|
||||
# --- combined stress ---
|
||||
(32, 128, 32, 128, 128, 256),
|
||||
]
|
||||
|
||||
shapes_seq = [
|
||||
(8, 128, 16, 128, 128, 8),
|
||||
(4, 128, 32, 128, 128, 4),
|
||||
(4, 128, 64, 128, 128, 4),
|
||||
(32, 128, 32, 128, 128, 32),
|
||||
]
|
||||
|
||||
all_pass = True
|
||||
for B, T_per_seq, H, K, V, pool_size in shapes:
|
||||
if not check_shape(B, T_per_seq, H, K, V, pool_size, device, dtype):
|
||||
all_pass = False
|
||||
|
||||
print()
|
||||
print("Sequential-index variants:")
|
||||
for B, T_per_seq, H, K, V, pool_size in shapes_seq:
|
||||
if not check_shape(
|
||||
B,
|
||||
T_per_seq,
|
||||
H,
|
||||
K,
|
||||
V,
|
||||
pool_size,
|
||||
device,
|
||||
dtype,
|
||||
sequential_indices=True,
|
||||
):
|
||||
all_pass = False
|
||||
|
||||
print()
|
||||
if all_pass:
|
||||
print("ALL PASSED.")
|
||||
else:
|
||||
print("SOME FAILED.")
|
||||
return all_pass
|
||||
|
||||
|
||||
def run_benchmark(device, dtype, args):
|
||||
print()
|
||||
print("=" * 105)
|
||||
print("Benchmark: Triton GDN vs FlashInfer GDN (do_bench_cudagraph)")
|
||||
print("=" * 105)
|
||||
|
||||
K = args.head_size_k
|
||||
V = args.head_size_v
|
||||
pool_size = args.pool_size
|
||||
|
||||
if args.preset == "qwen3-next":
|
||||
bench_configs = [
|
||||
# (B, H, T_per_seq)
|
||||
(4, 16, 256),
|
||||
(4, 32, 256),
|
||||
(16, 16, 256),
|
||||
(16, 32, 256),
|
||||
(32, 16, 256),
|
||||
(32, 32, 256),
|
||||
(64, 16, 256),
|
||||
(64, 32, 256),
|
||||
(128, 16, 256),
|
||||
(128, 32, 256),
|
||||
# longer sequences
|
||||
(4, 16, 1024),
|
||||
(4, 32, 1024),
|
||||
(32, 16, 1024),
|
||||
(32, 32, 1024),
|
||||
]
|
||||
else:
|
||||
bench_configs = []
|
||||
for B in args.batch_sizes:
|
||||
for H in args.num_heads:
|
||||
for T_per_seq in args.seq_lens:
|
||||
bench_configs.append((B, H, T_per_seq))
|
||||
|
||||
print(f" Config: K={K}, V={V}, pool_size={pool_size}, dtype={dtype}")
|
||||
print(
|
||||
f" {'B':>5} {'H':>3} {'T/seq':>6} {'T_tot':>7} | "
|
||||
f"{'tri(ms)':>8} {'TFLOPS':>7} {'TB/s':>7} | "
|
||||
f"{'fi(ms)':>8} {'TFLOPS':>7} {'TB/s':>7} | "
|
||||
f"{'speedup':>8}"
|
||||
)
|
||||
print(" " + "-" * 98)
|
||||
|
||||
for B, H, T_per_seq in bench_configs:
|
||||
actual_pool = max(pool_size, B)
|
||||
bench_shape(B, H, T_per_seq, K, V, actual_pool, device, dtype)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark & Correctness: Triton GDN vs FlashInfer GDN"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["all", "correctness", "bench"],
|
||||
default="all",
|
||||
help="Run mode (default: all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preset",
|
||||
choices=["qwen3-next", "custom"],
|
||||
default="qwen3-next",
|
||||
help="Preset config (default: qwen3-next)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
choices=["float16", "bfloat16"],
|
||||
default="bfloat16",
|
||||
)
|
||||
parser.add_argument("--head-size-k", type=int, default=128)
|
||||
parser.add_argument("--head-size-v", type=int, default=128)
|
||||
parser.add_argument("--pool-size", type=int, default=256)
|
||||
parser.add_argument(
|
||||
"--batch-sizes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[4, 16, 32, 64, 128],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-heads",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[16, 32],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seq-lens",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[128, 256, 512, 1024],
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.preset == "qwen3-next":
|
||||
args.head_size_k = 128
|
||||
args.head_size_v = 128
|
||||
|
||||
device = "cuda"
|
||||
dtype = getattr(torch, args.dtype)
|
||||
|
||||
# Check SM version
|
||||
cap = torch.cuda.get_device_capability()
|
||||
dev_name = torch.cuda.get_device_name()
|
||||
print(f"Device: {dev_name} (SM {cap[0]}{cap[1]})")
|
||||
|
||||
if args.mode in ("all", "correctness"):
|
||||
all_pass = run_correctness(device, dtype)
|
||||
if not all_pass and args.mode == "all":
|
||||
print("\nSkipping benchmark due to correctness failures.")
|
||||
return 1
|
||||
|
||||
if args.mode in ("all", "bench"):
|
||||
run_benchmark(device, dtype, args)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
425
third_party/sglang/benchmark/bench_rope/benchmark_rope_index.py
vendored
Normal file
425
third_party/sglang/benchmark/bench_rope/benchmark_rope_index.py
vendored
Normal file
@@ -0,0 +1,425 @@
|
||||
# This script benchmarks MRotaryEmbedding.get_rope_index_glm4v (GLM4V mrope index builder).
|
||||
# It generates synthetic multimodal input_ids + attention_mask (+ optional image/video grids),
|
||||
# runs benchmarks.
|
||||
#
|
||||
# == Usage Examples ==
|
||||
#
|
||||
# python3 benchmark_rope_index.py --device cuda --num-tokens 1024 2048 --benchmark-iter 200
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Minimal config objects
|
||||
# -----------------------------
|
||||
@dataclass
|
||||
class DummyVisionConfig:
|
||||
spatial_merge_size: int = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummyHFConfig:
|
||||
image_token_id: int = 32000
|
||||
video_start_token_id: int = 32001
|
||||
video_end_token_id: int = 32002
|
||||
vision_config: DummyVisionConfig = field(
|
||||
default_factory=lambda: DummyVisionConfig(spatial_merge_size=2)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Helpers
|
||||
# -----------------------------
|
||||
def calculate_stats(times: list[float]) -> dict[str, float]:
|
||||
"""Calculate statistics from a list of times."""
|
||||
times_array = np.array(times, dtype=np.float64)
|
||||
return {
|
||||
"mean": float(np.mean(times_array)),
|
||||
"median": float(np.median(times_array)),
|
||||
"p99": float(np.percentile(times_array, 99)),
|
||||
"min": float(np.min(times_array)),
|
||||
"max": float(np.max(times_array)),
|
||||
}
|
||||
|
||||
|
||||
def _sync(device: torch.device):
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def _approx_hw(patches: int, merge: int) -> tuple[int, int]:
|
||||
# want (h/merge)*(w/merge) ~= patches
|
||||
gh = int(math.sqrt(max(1, patches)))
|
||||
gw = max(1, patches // max(1, gh))
|
||||
return gh * merge, gw * merge
|
||||
|
||||
|
||||
def generate_test_data(
|
||||
num_tokens: int,
|
||||
batch_size: int,
|
||||
hf_config: DummyHFConfig,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
pad_ratio: float,
|
||||
num_images_per_sample: int,
|
||||
image_patch_tokens: int,
|
||||
num_videos_per_sample: int,
|
||||
video_patch_tokens: int,
|
||||
seed: int,
|
||||
):
|
||||
"""
|
||||
Generate synthetic (input_ids, attention_mask, image_grid_thw, video_grid_thw).
|
||||
|
||||
NOTE:
|
||||
- image_grid_thw / video_grid_thw are global lists across the entire batch in encounter order,
|
||||
matching the function's image_index/video_index behavior.
|
||||
- image patches are represented by repeated image_token_id.
|
||||
- video patches are represented by image_token_id wrapped with start/end tokens.
|
||||
"""
|
||||
torch.manual_seed(seed)
|
||||
|
||||
forbidden = {
|
||||
0,
|
||||
hf_config.image_token_id,
|
||||
hf_config.video_start_token_id,
|
||||
hf_config.video_end_token_id,
|
||||
}
|
||||
vocab_size = 50000
|
||||
|
||||
def rand_text(n: int) -> torch.Tensor:
|
||||
# generate random ids not in forbidden
|
||||
out = torch.randint(1, vocab_size, (n,), device=device, dtype=torch.long)
|
||||
# fix forbidden by +1 until ok (cheap, deterministic enough for benchmark data)
|
||||
for bad in forbidden:
|
||||
out = torch.where(out == bad, out + 1, out)
|
||||
return out
|
||||
|
||||
image_grids: list[list[int]] = []
|
||||
video_grids: list[list[int]] = []
|
||||
|
||||
input_ids = torch.zeros((batch_size, num_tokens), device=device, dtype=torch.long)
|
||||
attention_mask = torch.zeros(
|
||||
(batch_size, num_tokens), device=device, dtype=torch.long
|
||||
)
|
||||
|
||||
eff_len = int(round(num_tokens * (1.0 - pad_ratio)))
|
||||
eff_len = max(1, min(num_tokens, eff_len))
|
||||
|
||||
min_needed = 1
|
||||
min_needed += num_images_per_sample * image_patch_tokens
|
||||
min_needed += num_videos_per_sample * (2 + video_patch_tokens)
|
||||
if eff_len < min_needed:
|
||||
num_images_per_sample = 0
|
||||
num_videos_per_sample = 0
|
||||
|
||||
for b in range(batch_size):
|
||||
blocks: list[torch.Tensor] = []
|
||||
|
||||
reserved = (
|
||||
num_images_per_sample * image_patch_tokens
|
||||
+ num_videos_per_sample * (2 + video_patch_tokens)
|
||||
)
|
||||
reserved = min(reserved, max(0, eff_len - 1))
|
||||
text_budget = max(1, eff_len - reserved)
|
||||
|
||||
n_text_chunks = num_images_per_sample + num_videos_per_sample + 1
|
||||
base = text_budget // n_text_chunks
|
||||
rem = text_budget % n_text_chunks
|
||||
text_chunks = [base + (1 if i < rem else 0) for i in range(n_text_chunks)]
|
||||
|
||||
tci = 0
|
||||
for _ in range(num_images_per_sample):
|
||||
blocks.append(rand_text(text_chunks[tci]))
|
||||
tci += 1
|
||||
blocks.append(
|
||||
torch.full(
|
||||
(image_patch_tokens,),
|
||||
hf_config.image_token_id,
|
||||
device=device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
)
|
||||
|
||||
h, w = _approx_hw(
|
||||
image_patch_tokens, hf_config.vision_config.spatial_merge_size
|
||||
)
|
||||
image_grids.append([1, h, w])
|
||||
|
||||
for _ in range(num_videos_per_sample):
|
||||
blocks.append(rand_text(text_chunks[tci]))
|
||||
tci += 1
|
||||
blocks.append(
|
||||
torch.tensor(
|
||||
[hf_config.video_start_token_id], device=device, dtype=torch.long
|
||||
)
|
||||
)
|
||||
blocks.append(
|
||||
torch.full(
|
||||
(video_patch_tokens,),
|
||||
hf_config.image_token_id,
|
||||
device=device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
)
|
||||
blocks.append(
|
||||
torch.tensor(
|
||||
[hf_config.video_end_token_id], device=device, dtype=torch.long
|
||||
)
|
||||
)
|
||||
|
||||
h, w = _approx_hw(
|
||||
video_patch_tokens, hf_config.vision_config.spatial_merge_size
|
||||
)
|
||||
# first field = group count used by code; set to 1
|
||||
video_grids.append([1, h, w])
|
||||
|
||||
blocks.append(rand_text(text_chunks[tci]))
|
||||
|
||||
tokens = torch.cat(blocks, dim=0)[:eff_len]
|
||||
pad = torch.zeros(
|
||||
(num_tokens - tokens.numel(),), device=device, dtype=torch.long
|
||||
)
|
||||
ids = torch.cat([tokens, pad], dim=0)
|
||||
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.ones((tokens.numel(),), device=device, dtype=torch.long),
|
||||
torch.zeros(
|
||||
(num_tokens - tokens.numel(),), device=device, dtype=torch.long
|
||||
),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
input_ids[b] = ids
|
||||
attention_mask[b] = mask
|
||||
|
||||
image_grid_thw = (
|
||||
torch.tensor(image_grids, device=device, dtype=torch.long)
|
||||
if len(image_grids)
|
||||
else None
|
||||
)
|
||||
video_grid_thw = (
|
||||
torch.tensor(video_grids, device=device, dtype=torch.long)
|
||||
if len(video_grids)
|
||||
else None
|
||||
)
|
||||
return (
|
||||
input_ids.to(dtype=torch.long),
|
||||
attention_mask.to(dtype=torch.long),
|
||||
image_grid_thw,
|
||||
video_grid_thw,
|
||||
)
|
||||
|
||||
|
||||
def benchmark_rope_index(
|
||||
model_name: str,
|
||||
tp_size: int,
|
||||
num_tokens: int,
|
||||
batch_size: int,
|
||||
pad_ratio: float,
|
||||
spatial_merge_size: int,
|
||||
num_images: int,
|
||||
image_patch_tokens: int,
|
||||
num_videos: int,
|
||||
video_patch_tokens: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
warmup_iter: int,
|
||||
benchmark_iter: int,
|
||||
device: torch.device,
|
||||
):
|
||||
torch.manual_seed(seed)
|
||||
hf_config = DummyHFConfig(
|
||||
image_token_id=32000,
|
||||
video_start_token_id=32001,
|
||||
video_end_token_id=32002,
|
||||
vision_config=DummyVisionConfig(spatial_merge_size=spatial_merge_size),
|
||||
)
|
||||
|
||||
print(80 * "=")
|
||||
print(
|
||||
f"Evaluating: {model_name} tp_size={tp_size} "
|
||||
f"num_tokens={num_tokens} batch={batch_size} pad_ratio={pad_ratio} "
|
||||
f"images/sample={num_images} image_patch_tokens={image_patch_tokens} "
|
||||
f"videos/sample={num_videos} video_patch_tokens={video_patch_tokens} "
|
||||
f"dtype={dtype} device={device}"
|
||||
)
|
||||
|
||||
input_ids, attention_mask, image_grid_thw, video_grid_thw = generate_test_data(
|
||||
num_tokens=num_tokens,
|
||||
batch_size=batch_size,
|
||||
hf_config=hf_config,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
pad_ratio=pad_ratio,
|
||||
num_images_per_sample=num_images,
|
||||
image_patch_tokens=image_patch_tokens,
|
||||
num_videos_per_sample=num_videos,
|
||||
video_patch_tokens=video_patch_tokens,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
# Smoke test
|
||||
has_mm = (image_grid_thw is not None) or (video_grid_thw is not None)
|
||||
if has_mm:
|
||||
pos, delta = MRotaryEmbedding.get_rope_index_glm4v(
|
||||
input_ids=input_ids,
|
||||
hf_config=hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
assert pos.shape == (3, batch_size, num_tokens)
|
||||
assert delta.shape == (batch_size, 1)
|
||||
|
||||
# Warm up
|
||||
for _ in range(warmup_iter):
|
||||
if has_mm:
|
||||
MRotaryEmbedding.get_rope_index_glm4v(
|
||||
input_ids=input_ids,
|
||||
hf_config=hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
MRotaryEmbedding.get_rope_index_glm4v(
|
||||
input_ids=input_ids,
|
||||
hf_config=hf_config,
|
||||
image_grid_thw=None,
|
||||
video_grid_thw=None,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
_sync(device)
|
||||
|
||||
# Time multimodal branch
|
||||
multimodal_times = []
|
||||
for _ in range(benchmark_iter):
|
||||
_sync(device)
|
||||
start = time.time()
|
||||
MRotaryEmbedding.get_rope_index_glm4v(
|
||||
input_ids=input_ids,
|
||||
hf_config=hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
_sync(device)
|
||||
multimodal_times.append(time.time() - start)
|
||||
|
||||
# Time fallback branch
|
||||
fallback_times = []
|
||||
for _ in range(benchmark_iter):
|
||||
_sync(device)
|
||||
start = time.time()
|
||||
MRotaryEmbedding.get_rope_index_glm4v(
|
||||
input_ids=input_ids,
|
||||
hf_config=hf_config,
|
||||
image_grid_thw=None,
|
||||
video_grid_thw=None,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
_sync(device)
|
||||
fallback_times.append(time.time() - start)
|
||||
|
||||
multimodal_stats = calculate_stats(multimodal_times)
|
||||
fallback_stats = calculate_stats(fallback_times)
|
||||
|
||||
print(f"\nPerformance for config (B={batch_size}, T={num_tokens}):")
|
||||
print(
|
||||
f"Multimodal: mean={multimodal_stats['mean']:.8f}s, "
|
||||
f"median={multimodal_stats['median']:.8f}s, "
|
||||
f"p99={multimodal_stats['p99']:.8f}s"
|
||||
)
|
||||
print(
|
||||
f"Fallback: mean={fallback_stats['mean']:.8f}s, "
|
||||
f"median={fallback_stats['median']:.8f}s, "
|
||||
f"p99={fallback_stats['p99']:.8f}s"
|
||||
)
|
||||
|
||||
if has_mm:
|
||||
speedup = (
|
||||
multimodal_stats["mean"] / fallback_stats["mean"]
|
||||
if fallback_stats["mean"] > 0
|
||||
else float("inf")
|
||||
)
|
||||
print(f"Fallback Speedup over Multimodal: {speedup:.8f}x")
|
||||
else:
|
||||
speedup = float("nan")
|
||||
print(
|
||||
"[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark."
|
||||
)
|
||||
|
||||
print(f"Fallback Speedup over Multimodal: {speedup:.8f}x")
|
||||
|
||||
return multimodal_stats, fallback_stats, speedup
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark GLM4V get_rope_index_glm4v."
|
||||
)
|
||||
parser.add_argument("--model-name", type=str, default="GLM4V")
|
||||
parser.add_argument("--tp-size", type=int, default=1)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
parser.add_argument("--warmup-iter", type=int, default=10)
|
||||
parser.add_argument("--benchmark-iter", type=int, default=100)
|
||||
parser.add_argument("--dtype", type=str, choices=["int64"], default="int64")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
|
||||
# token length sweep
|
||||
parser.add_argument("--num-tokens", type=int, nargs="+", required=False)
|
||||
|
||||
# data shape knobs
|
||||
parser.add_argument("--batch-size", type=int, default=1)
|
||||
parser.add_argument("--pad-ratio", type=float, default=0.0)
|
||||
parser.add_argument("--spatial-merge-size", type=int, default=2)
|
||||
parser.add_argument("--num-images", type=int, default=1)
|
||||
parser.add_argument("--image-patch-tokens", type=int, default=256)
|
||||
parser.add_argument("--num-videos", type=int, default=1)
|
||||
parser.add_argument("--video-patch-tokens", type=int, default=256)
|
||||
|
||||
# output
|
||||
parser.add_argument("--out-dir", type=str, default=".")
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
device = torch.device(args.device)
|
||||
|
||||
if args.num_tokens is None:
|
||||
num_tokens_list = [2**i for i in range(0, 18)]
|
||||
else:
|
||||
num_tokens_list = args.num_tokens
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
|
||||
for num_tokens in num_tokens_list:
|
||||
multimodal_stats, fallback_stats, speedup = benchmark_rope_index(
|
||||
model_name=args.model_name,
|
||||
tp_size=args.tp_size,
|
||||
num_tokens=num_tokens,
|
||||
batch_size=args.batch_size,
|
||||
pad_ratio=args.pad_ratio,
|
||||
spatial_merge_size=args.spatial_merge_size,
|
||||
num_images=args.num_images,
|
||||
image_patch_tokens=args.image_patch_tokens,
|
||||
num_videos=args.num_videos,
|
||||
video_patch_tokens=args.video_patch_tokens,
|
||||
dtype=getattr(torch, args.dtype),
|
||||
seed=args.seed,
|
||||
warmup_iter=args.warmup_iter,
|
||||
benchmark_iter=args.benchmark_iter,
|
||||
device=device,
|
||||
)
|
||||
193
third_party/sglang/benchmark/benchmark_batch/benchmark_batch.py
vendored
Normal file
193
third_party/sglang/benchmark/benchmark_batch/benchmark_batch.py
vendored
Normal file
@@ -0,0 +1,193 @@
|
||||
import concurrent.futures
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from statistics import mean
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
|
||||
###############################################################################
|
||||
# CONFIG
|
||||
###############################################################################
|
||||
ENDPOINT_URL = "http://127.0.0.1:30000"
|
||||
TOKENIZER_DIR = "/models/meta-llama/Llama-3.2-3B"
|
||||
|
||||
# Benchmark configurations
|
||||
NUM_REQUESTS = 10 # Total number of requests (each with BATCH_SIZE prompts)
|
||||
NUM_TOKENS = 32000 # Tokens per prompt
|
||||
BATCH_SIZE = 8 # Number of prompts per request
|
||||
GEN_TOKENS = 0 # Tokens to generate per prompt
|
||||
|
||||
|
||||
###############################################################################
|
||||
# REQUEST GENERATION (in parallel)
|
||||
###############################################################################
|
||||
def generate_random_prompt(index, tokenizer_dir, num_tokens):
|
||||
"""Generate a single random prompt with specified token count."""
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
|
||||
vocab_size = tokenizer.vocab_size
|
||||
|
||||
def generate_random_text(num_toks):
|
||||
random_token_ids = [random.randint(0, vocab_size - 1) for _ in range(num_toks)]
|
||||
return tokenizer.decode(random_token_ids, clean_up_tokenization_spaces=True)
|
||||
|
||||
random_text = generate_random_text(num_tokens)
|
||||
return f"Prompt {index}: {random_text}"
|
||||
|
||||
|
||||
def prepare_all_prompts(num_requests, batch_size, num_tokens, tokenizer_dir):
|
||||
"""Generate prompts for all requests in parallel."""
|
||||
total_prompts = num_requests * batch_size
|
||||
all_prompts = [None] * total_prompts
|
||||
max_workers = min(os.cpu_count() or 1, total_prompts)
|
||||
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = [
|
||||
executor.submit(generate_random_prompt, i, tokenizer_dir, num_tokens)
|
||||
for i in range(total_prompts)
|
||||
]
|
||||
for future in tqdm(
|
||||
concurrent.futures.as_completed(futures),
|
||||
total=total_prompts,
|
||||
desc="Generating prompts",
|
||||
):
|
||||
index = futures.index(future)
|
||||
all_prompts[index] = future.result()
|
||||
|
||||
batched_prompts = [
|
||||
all_prompts[i * batch_size : (i + 1) * batch_size] for i in range(num_requests)
|
||||
]
|
||||
|
||||
print(
|
||||
f"Generated {total_prompts} prompts with {num_tokens} tokens each, grouped into {num_requests} requests of {batch_size} prompts.\n"
|
||||
)
|
||||
return batched_prompts
|
||||
|
||||
|
||||
###############################################################################
|
||||
# HTTP CALLS
|
||||
###############################################################################
|
||||
def send_batch_request(endpoint, prompts, gen_tokens, request_id):
|
||||
"""Send a batch of prompts to the /generate endpoint synchronously."""
|
||||
sampling_params = {
|
||||
"max_new_tokens": gen_tokens,
|
||||
"temperature": 0.7,
|
||||
"stop": "\n",
|
||||
}
|
||||
data = {"text": prompts, "sampling_params": sampling_params}
|
||||
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
response = requests.post(
|
||||
endpoint.base_url + "/generate", json=data, timeout=3600
|
||||
)
|
||||
if response.status_code != 200:
|
||||
error = response.json()
|
||||
raise RuntimeError(f"Request {request_id} failed: {error}")
|
||||
result = response.json()
|
||||
elapsed_time = (time.perf_counter() - start_time) * 1000 # Convert to ms
|
||||
avg_per_prompt = elapsed_time / len(prompts) if prompts else 0
|
||||
return request_id, elapsed_time, avg_per_prompt, True, len(prompts)
|
||||
except Exception as e:
|
||||
print(f"[Request] Error for request {request_id}: {e}")
|
||||
return request_id, 0, 0, False, len(prompts)
|
||||
|
||||
|
||||
def run_benchmark(endpoint, batched_prompts, batch_size, gen_tokens):
|
||||
"""Run the benchmark sequentially."""
|
||||
results = []
|
||||
num_requests = len(batched_prompts)
|
||||
|
||||
# Record start time for total latency
|
||||
benchmark_start_time = time.perf_counter()
|
||||
|
||||
for i, batch_prompts in enumerate(batched_prompts):
|
||||
request_id = i + 1
|
||||
assert (
|
||||
len(batch_prompts) == batch_size
|
||||
), f"Request {request_id} should have {batch_size} prompts, got {len(batch_prompts)}"
|
||||
|
||||
print(
|
||||
f"[Request] Sending request {request_id}/{num_requests} with {len(batch_prompts)} prompts at {int(time.time()*1000)}"
|
||||
)
|
||||
result = send_batch_request(endpoint, batch_prompts, gen_tokens, request_id)
|
||||
results.append(result)
|
||||
|
||||
# Calculate total latency
|
||||
total_latency = (time.perf_counter() - benchmark_start_time) * 1000 # Convert to ms
|
||||
|
||||
return results, total_latency
|
||||
|
||||
|
||||
###############################################################################
|
||||
# RESULTS
|
||||
###############################################################################
|
||||
def process_results(results, total_latency, num_requests):
|
||||
"""Process and display benchmark results."""
|
||||
total_time = 0
|
||||
successful_requests = 0
|
||||
failed_requests = 0
|
||||
request_latencies = []
|
||||
per_prompt_latencies = []
|
||||
total_prompts = 0
|
||||
|
||||
for request_id, elapsed_time, avg_per_prompt, success, batch_size in results:
|
||||
if success:
|
||||
successful_requests += 1
|
||||
total_prompts += batch_size
|
||||
request_latencies.append(elapsed_time)
|
||||
per_prompt_latencies.append(avg_per_prompt)
|
||||
total_time += elapsed_time / 1000 # Convert to seconds
|
||||
else:
|
||||
failed_requests += 1
|
||||
|
||||
avg_request_latency = mean(request_latencies) if request_latencies else 0
|
||||
avg_per_prompt_latency = mean(per_prompt_latencies) if per_prompt_latencies else 0
|
||||
throughput = total_prompts / total_time if total_time > 0 else 0
|
||||
|
||||
print("\nBenchmark Summary:")
|
||||
print(f" Total requests sent: {len(results)}")
|
||||
print(f" Total prompts sent: {total_prompts}")
|
||||
print(f" Successful requests: {successful_requests}")
|
||||
print(f" Failed requests: {failed_requests}")
|
||||
print(f" Total latency (all requests): {total_latency:.2f} ms")
|
||||
print(f" Avg per request latency: {avg_request_latency:.2f} ms")
|
||||
print(f" Avg per prompt latency: {avg_per_prompt_latency:.2f} ms")
|
||||
print(f" Throughput: {throughput:.2f} prompts/second\n")
|
||||
|
||||
|
||||
###############################################################################
|
||||
# MAIN
|
||||
###############################################################################
|
||||
def main():
|
||||
# Initialize endpoint
|
||||
endpoint = RuntimeEndpoint(ENDPOINT_URL)
|
||||
|
||||
# Generate prompts
|
||||
batched_prompts = prepare_all_prompts(
|
||||
NUM_REQUESTS, BATCH_SIZE, NUM_TOKENS, TOKENIZER_DIR
|
||||
)
|
||||
|
||||
# Flush cache before benchmark
|
||||
# endpoint.flush_cache()
|
||||
|
||||
# Run benchmark
|
||||
print(
|
||||
f"Starting benchmark: NUM_TOKENS={NUM_TOKENS}, BATCH_SIZE={BATCH_SIZE}, NUM_REQUESTS={NUM_REQUESTS}\n"
|
||||
)
|
||||
results, total_latency = run_benchmark(
|
||||
endpoint, batched_prompts, BATCH_SIZE, GEN_TOKENS
|
||||
)
|
||||
|
||||
# Process and display results
|
||||
process_results(results, total_latency, NUM_REQUESTS)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
random.seed(0)
|
||||
main()
|
||||
237
third_party/sglang/benchmark/benchmark_batch/benchmark_tokenizer.py
vendored
Normal file
237
third_party/sglang/benchmark/benchmark_batch/benchmark_tokenizer.py
vendored
Normal file
@@ -0,0 +1,237 @@
|
||||
import argparse
|
||||
import random
|
||||
import time
|
||||
from statistics import mean
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from sglang.srt.utils.patch_tokenizer import patch_tokenizer
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
print("Tokenizer Benchmark: Sequential vs Batch Processing")
|
||||
print("-" * 60)
|
||||
print(f"Tokenizer: {args.tokenizer}")
|
||||
print(f"Functions: {', '.join(args.function)}")
|
||||
print(f"Tokens per prompt: {args.num_tokens}")
|
||||
print(f"Number of runs per batch size: {args.num_runs}")
|
||||
print(f"Batch mode: {', '.join(args.batch_mode)}")
|
||||
print("-" * 60)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
|
||||
tokenizer = patch_tokenizer(tokenizer)
|
||||
max_batch_size = max(args.batch_sizes)
|
||||
|
||||
token_ids = generate_random_token_ids(
|
||||
num_prompts=max_batch_size, num_tokens=args.num_tokens, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
if "encode" in args.function:
|
||||
prompts = [
|
||||
tokenizer.decode(ids, clean_up_tokenization_spaces=True)
|
||||
for ids in token_ids
|
||||
]
|
||||
run_benchmark(
|
||||
name="encode",
|
||||
data=prompts,
|
||||
sequential_fn=lambda batch: [tokenizer.encode(p) for p in batch],
|
||||
batch_fn=lambda batch: tokenizer(batch),
|
||||
batch_sizes=args.batch_sizes,
|
||||
num_runs=args.num_runs,
|
||||
batch_mode=args.batch_mode,
|
||||
)
|
||||
|
||||
if "decode" in args.function:
|
||||
# mimic DetokenizerManager's usual case
|
||||
decode_kwargs = dict(
|
||||
skip_special_tokens=True,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
run_benchmark(
|
||||
name="decode",
|
||||
data=token_ids,
|
||||
sequential_fn=lambda batch: [
|
||||
tokenizer.decode(ids, **decode_kwargs) for ids in batch
|
||||
],
|
||||
batch_fn=lambda batch: tokenizer.batch_decode(batch, **decode_kwargs),
|
||||
batch_sizes=args.batch_sizes,
|
||||
num_runs=args.num_runs,
|
||||
batch_mode=args.batch_mode,
|
||||
)
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
*, name, data, sequential_fn, batch_fn, batch_sizes, num_runs, batch_mode
|
||||
):
|
||||
print("\n" + "=" * 60)
|
||||
print(f"{name.upper()} BENCHMARK")
|
||||
print("=" * 60)
|
||||
|
||||
results = [
|
||||
benchmark(
|
||||
data=data,
|
||||
batch_size=bs,
|
||||
sequential_fn=sequential_fn,
|
||||
batch_fn=batch_fn,
|
||||
num_runs=num_runs,
|
||||
batch_mode=batch_mode,
|
||||
)
|
||||
for bs in batch_sizes
|
||||
]
|
||||
print_results(results=results, func_name=name, batch_mode=batch_mode)
|
||||
|
||||
|
||||
def benchmark(*, data, batch_size, sequential_fn, batch_fn, num_runs, batch_mode):
|
||||
batch_data = data[:batch_size]
|
||||
run_single = "single" in batch_mode
|
||||
run_batch = "batch" in batch_mode
|
||||
|
||||
out = {"batch_size": batch_size}
|
||||
|
||||
if run_single:
|
||||
sequential_times = measure_times(
|
||||
fn=lambda: sequential_fn(batch_data), num_runs=num_runs
|
||||
)
|
||||
out |= {
|
||||
"avg_sequential_ms": mean(sequential_times),
|
||||
"sequential_runs": sequential_times,
|
||||
}
|
||||
|
||||
if run_batch:
|
||||
batch_times = measure_times(fn=lambda: batch_fn(batch_data), num_runs=num_runs)
|
||||
out |= {
|
||||
"avg_batch_ms": mean(batch_times),
|
||||
"batch_runs": batch_times,
|
||||
}
|
||||
|
||||
if run_single and run_batch:
|
||||
out["speedup_factor"] = (
|
||||
out["avg_sequential_ms"] / out["avg_batch_ms"]
|
||||
if out["avg_batch_ms"] > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def print_results(*, results, func_name, batch_mode):
|
||||
run_single = "single" in batch_mode
|
||||
run_batch = "batch" in batch_mode
|
||||
|
||||
for r in results:
|
||||
print(f"\nBatch size: {r['batch_size']}")
|
||||
if run_single:
|
||||
print_runs(
|
||||
label=f"Sequential {func_name}",
|
||||
runs=r["sequential_runs"],
|
||||
avg=r["avg_sequential_ms"],
|
||||
)
|
||||
if run_batch:
|
||||
print_runs(
|
||||
label=f"Batch {func_name}", runs=r["batch_runs"], avg=r["avg_batch_ms"]
|
||||
)
|
||||
if run_single and run_batch:
|
||||
print(f" Speedup factor: {r['speedup_factor']:.2f}x")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"SUMMARY: {func_name.upper()}")
|
||||
print("=" * 60)
|
||||
|
||||
headers = ["Batch Size"]
|
||||
if run_single:
|
||||
headers.append("Sequential (ms)")
|
||||
if run_batch:
|
||||
headers.append("Batch (ms)")
|
||||
if run_single and run_batch:
|
||||
headers.append("Speedup")
|
||||
print("".join(f"{h:<18}" for h in headers))
|
||||
print("-" * (18 * len(headers)))
|
||||
|
||||
for r in results:
|
||||
row = [f"{r['batch_size']}"]
|
||||
if run_single:
|
||||
row.append(f"{r['avg_sequential_ms']:.2f} ms")
|
||||
if run_batch:
|
||||
row.append(f"{r['avg_batch_ms']:.2f} ms")
|
||||
if run_single and run_batch:
|
||||
row.append(f"{r['speedup_factor']:.2f}x")
|
||||
print("".join(f"{v:<18}" for v in row))
|
||||
|
||||
|
||||
def print_runs(*, label, runs, avg):
|
||||
print(f" {label}:")
|
||||
for i, t in enumerate(runs):
|
||||
print(f" Run {i+1}: {t:.2f} ms")
|
||||
print(f" Average: {avg:.2f} ms")
|
||||
|
||||
|
||||
def measure_times(*, fn, num_runs):
|
||||
times = []
|
||||
for _ in range(num_runs):
|
||||
start = time.perf_counter()
|
||||
fn()
|
||||
times.append((time.perf_counter() - start) * 1000)
|
||||
return times
|
||||
|
||||
|
||||
def generate_random_token_ids(*, num_prompts, num_tokens, tokenizer):
|
||||
vocab_size = tokenizer.vocab_size
|
||||
print(f"Generating {num_prompts} random sequences with {num_tokens} tokens each...")
|
||||
return [
|
||||
[random.randint(0, vocab_size - 1) for _ in range(num_tokens)]
|
||||
for _ in range(num_prompts)
|
||||
]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Tokenizer Benchmark: Sequential vs Batch Processing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Tokenizer name or path (e.g. nvidia/Kimi-K2-Thinking-NVFP4)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--function",
|
||||
type=str,
|
||||
nargs="+",
|
||||
choices=["encode", "decode"],
|
||||
default=["encode", "decode"],
|
||||
help="Functions to benchmark (default: encode decode)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-tokens",
|
||||
type=int,
|
||||
default=20000,
|
||||
help="Number of tokens per prompt (default: 20000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-sizes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[1, 2, 4, 8],
|
||||
help="Batch sizes to test (default: 1 2 4 8)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-mode",
|
||||
nargs="+",
|
||||
choices=["single", "batch"],
|
||||
default=["single", "batch"],
|
||||
help="Benchmark modes to run (default: single batch)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-runs",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of runs per batch size (default: 5)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
random.seed(0)
|
||||
main()
|
||||
89
third_party/sglang/benchmark/benchmark_vllm_060/README.md
vendored
Normal file
89
third_party/sglang/benchmark/benchmark_vllm_060/README.md
vendored
Normal file
@@ -0,0 +1,89 @@
|
||||
## How to reproduce the benchmark results for SGLang v0.3.0 compared to vLLM v0.6.0
|
||||
|
||||
In short, with multi step enabled, in online scenarios that we benchmarked, the Median TTFT of vLLM is **3 times** that of SGLang, and the Median ITL is **10 times** that of SGLang. Lower Median TTFT and ITL are better. vLLM's multi-step optimization did not improve throughput while ensuring lower Median TTFT and ITL. Also, under maximum throughput benchmark, if vLLM does not set gpu util to 0.95 separately and uses the default configuration instead, its maximum throughput is **lower** than that of SGLang.
|
||||
|
||||
## Online benchmark results
|
||||
|
||||
### Llama 3.1 8B Instruct 1 x A100 80G
|
||||
|
||||
| RPS | Num prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL |
|
||||
|------|-------------|--------|--------------------|-------------|-------------|------------|
|
||||
| 4 | 1200 | SGLang | 1564.17 | **31.98** | 13.17 | **11.93** |
|
||||
| 4 | 1200 | vLLM | 1691.97 | **100.48** | 14.14 | **129.32** |
|
||||
| 8 | 2400 | SGLang | 2175.02 | **35.68** | 17.85 | **14.41** |
|
||||
| 8 | 2400 | vLLM | 2137.16 | **120.39** | 17.09 | **158.63** |
|
||||
|
||||
### Llama 3.1 70B Insruct 4 x H100 80G
|
||||
|
||||
| RPS | Num Prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL |
|
||||
|------|-------------|--------|--------------------|-------------|-------------|------------|
|
||||
| 4 | 1200 | SGLang | 3005.24 | **53.94** | 25.03 | **21.67** |
|
||||
| 4 | 1200 | vLLM | 2915.60 | **179.15** | 23.58 | **231.23** |
|
||||
| 8 | 2400 | SGLang | 4064.98 | **58.11** | 33.07 | **24.45** |
|
||||
| 8 | 2400 | vLLM | 3752.38 | **207.12** | 29.15 | **275.32** |
|
||||
|
||||
## Offline benchmark results
|
||||
|
||||
### Llama 3.1 8B Instruct 1 x A100 80G
|
||||
|
||||
| RPS | Num Prompts | Engine | Request throughput | Output token throughput |
|
||||
|------|-------------|--------|--------------------|-------------------------|
|
||||
| inf | 5000 | SGLang | 22.03 | **4281.51** |
|
||||
| inf | 5000 | vLLM | 21.27 | **4132.37** |
|
||||
|
||||
### Llama 3.1 70B Insruct 4 x H100 80G
|
||||
|
||||
| RPS | Num Prompts | Engine | Request throughput | Output token throughput |
|
||||
|------|-------------|--------|--------------------|-------------------------|
|
||||
| inf | 5000 | SGLang | 19.84 | **3856.01** |
|
||||
| inf | 5000 | vLLM | 19.04 | **3700.64** |
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# install sglang v0.3.0
|
||||
pip install --upgrade pip
|
||||
pip install "sglang[all]"==0.3.0
|
||||
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
|
||||
|
||||
# install vllm v0.6.0
|
||||
pip install vllm==0.6.0
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
We referred to the reproduction method in https://github.com/vllm-project/vllm/issues/8176, and added the `--num-scheduler-steps 10` parameter when starting the vLLM server. The `gpu_memory_utilization` of vLLM is by default 0.9 at both TP 1 and TP 4, while SGLang's `mem_frac` is 0.88 at TP 1 and 0.85 at TP 4, so we manually set it to 0.88 at TP 4.
|
||||
|
||||
## Online benchmarks
|
||||
|
||||
```bash
|
||||
# Llama 3.1 8B Instruct on 1 x A100
|
||||
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache
|
||||
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096
|
||||
|
||||
# Llama 3.1 70B Instruct on 4 x H100
|
||||
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 4
|
||||
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096
|
||||
|
||||
# bench serving
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 1200 --request-rate 4
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 2400 --request-rate 8
|
||||
python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 1200 --request-rate 4
|
||||
python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 2400 --request-rate 8
|
||||
```
|
||||
|
||||
## Offline benchmarks
|
||||
|
||||
```bash
|
||||
# Llama 3.1 8B Instruct on 1 x A100
|
||||
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache
|
||||
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096
|
||||
|
||||
# Llama 3.1 70B Instruct on 4 x H100
|
||||
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 4 --mem-frac 0.88
|
||||
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096
|
||||
|
||||
# bench serving
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 5000
|
||||
python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 5000
|
||||
```
|
||||
24
third_party/sglang/benchmark/blog_v0_2/405b_sglang.sh
vendored
Normal file
24
third_party/sglang/benchmark/blog_v0_2/405b_sglang.sh
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
# Create dummy weights:
|
||||
# 1. Create a folder `~/llama-3.1-405b-fp8-dummy` and create `config.json` and tokenizer under this folder.
|
||||
# 2. Get `config.json`` from ./config.md
|
||||
# 3. Download the tokenizer
|
||||
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer.json
|
||||
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
|
||||
|
||||
# Launch sglang
|
||||
# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quantization fp8 --disable-radix --mem-frac 0.87
|
||||
|
||||
# offline
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 > sglang_log12
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 > sglang_log13
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 > sglang_log14
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 > sglang_log15
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompt 2000 > sglang_log21
|
||||
|
||||
# online
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 > sglang_log31
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 > sglang_log32
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 > sglang_log33
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 > sglang_log34
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 > sglang_log35
|
||||
17
third_party/sglang/benchmark/blog_v0_2/405b_trt.sh
vendored
Normal file
17
third_party/sglang/benchmark/blog_v0_2/405b_trt.sh
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
# Launch trtllm
|
||||
# https://github.com/sgl-project/tensorrt-demo
|
||||
|
||||
# offline
|
||||
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log11
|
||||
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log12
|
||||
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log13
|
||||
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log14
|
||||
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log15
|
||||
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name sharegpt --num-prompt 2000 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log21
|
||||
|
||||
# online
|
||||
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log31
|
||||
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log32
|
||||
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log33
|
||||
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log34
|
||||
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log35
|
||||
24
third_party/sglang/benchmark/blog_v0_2/405b_vllm.sh
vendored
Normal file
24
third_party/sglang/benchmark/blog_v0_2/405b_vllm.sh
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
# Create dummy weights:
|
||||
# 1. Create a folder `~/llama-3.1-405b-fp8-dummy` and create `config.json` and tokenizer under this folder.
|
||||
# 2. Get `config.json`` from ./config.md
|
||||
# 3. Download the tokenizer
|
||||
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer.json
|
||||
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
|
||||
|
||||
# Launch vllm
|
||||
# python3 -m vllm.entrypoints.openai.api_server --model ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --disable-log-requests --tensor-parallel-size 8 --max-model-len 10000
|
||||
|
||||
# offline
|
||||
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > vllm_log11
|
||||
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 > vllm_log12
|
||||
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 > vllm_log13
|
||||
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 > vllm_log14
|
||||
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 > vllm_log15
|
||||
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name sharegpt --num-prompt 2000 > vllm_log21
|
||||
|
||||
# online
|
||||
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 > vllm_log31
|
||||
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 > vllm_log32
|
||||
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 > vllm_log33
|
||||
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 > vllm_log34
|
||||
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 > vllm_log35
|
||||
164
third_party/sglang/benchmark/blog_v0_2/README.md
vendored
Normal file
164
third_party/sglang/benchmark/blog_v0_2/README.md
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
# How to reproduce the benchmark results of SGLang
|
||||
|
||||
## Prerequisite
|
||||
|
||||
### Install the latest SGLang
|
||||
|
||||
```bash
|
||||
git clone https://github.com/sgl-project/sglang.git
|
||||
cd sglang
|
||||
git checkout v0.2.7
|
||||
|
||||
pip install --upgrade pip
|
||||
pip install -e "python[all]"
|
||||
|
||||
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
|
||||
```
|
||||
|
||||
### Set up ulimit and HF_TOKEN
|
||||
|
||||
```bash
|
||||
ulimit -n 65535
|
||||
# Change the token to a real and usable one, with access permissions for the Llama 3 models.
|
||||
export HF_TOKEN=hf_token
|
||||
```
|
||||
|
||||
### Launch the server
|
||||
|
||||
```bash
|
||||
# Meta-Llama-3.1-8B-Instruct
|
||||
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache
|
||||
|
||||
# Meta-Llama-3.1-70B-Instruct
|
||||
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 8
|
||||
|
||||
# Meta-Llama-3-70B-Instruct-FP8
|
||||
python -m sglang.launch_server --model-path neuralmagic/Meta-Llama-3-70B-Instruct-FP8 --disable-radix-cache --tp 8
|
||||
```
|
||||
|
||||
## Benchmark
|
||||
|
||||
### Hardware Requirements
|
||||
|
||||
- 8B models: Single NVIDIA A100 80GB GPU
|
||||
- 70B models: 8 x NVIDIA A100 80GB GPUs with Tensor Parallelism (TP) 8
|
||||
- 70B FP8 models: 8 x NVIDIA H100 GPUs with Tensor Parallelism (TP) 8
|
||||
|
||||
Please ensure you have the appropriate hardware before running the benchmarks.
|
||||
|
||||
#### Offline benchmark
|
||||
|
||||
```bash
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline.jsonl
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline.jsonl
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline.jsonl
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline.jsonl
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline.jsonl
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 3000 --output-file offline.jsonl
|
||||
cat offline.jsonl | cut -d':' -f12 | cut -d',' -f1
|
||||
```
|
||||
|
||||
#### Online benchmark
|
||||
|
||||
```bash
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online.jsonl
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online.jsonl
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online.jsonl
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online.jsonl
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online.jsonl
|
||||
cat online.jsonl | cut -d':' -f9 | cut -d',' -f1
|
||||
```
|
||||
|
||||
## Other
|
||||
|
||||
We tried using vLLM 0.5.3.post1, but it often crashes under high loads, and it seems to have similar or worse performance compared to vLLM 0.5.2 from our partial benchmarking, so we are using the older version, vLLM 0.5.2.
|
||||
|
||||
For TensorRT LLM preparation, follow your internal TensorRT-LLM deployment guide. Specifically, we used a batch size of 512, a max input length of 8192, and a max number of tokens of 8192. The instance count for preprocessing and postprocessing in Triton Server is 16.
|
||||
|
||||
```bash
|
||||
# vLLM
|
||||
pip install vllm==0.5.2
|
||||
pip install jsonschema==4.21.1
|
||||
|
||||
# Meta-Llama-3-8B-Instruct
|
||||
python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B-Instruct --disable-log-requests
|
||||
|
||||
# meta-llama/Meta-Llama-3-70B-Instruct
|
||||
python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B-Instruct --disable-log-requests --tensor 8
|
||||
|
||||
# neuralmagic/Meta-Llama-3-70B-Instruct-FP8
|
||||
python -m vllm.entrypoints.openai.api_server --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 --disable-log-requests --tensor 8
|
||||
```
|
||||
|
||||
```bash
|
||||
wget https://raw.githubusercontent.com/sgl-project/sglang/main/python/sglang/bench_serving.py
|
||||
```
|
||||
|
||||
```bash
|
||||
# vLLM Offline
|
||||
|
||||
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_vllm.jsonl
|
||||
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_vllm.jsonl
|
||||
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_vllm.jsonl
|
||||
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_vllm.jsonl
|
||||
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_vllm.jsonl
|
||||
python3 bench_serving.py --backend vllm --dataset-name sharegpt --num-prompts 3000 --output-file offline_vllm.jsonl
|
||||
cat offline_vllm.jsonl | cut -d':' -f12 | cut -d',' -f1
|
||||
```
|
||||
|
||||
```bash
|
||||
# vLLM Online
|
||||
|
||||
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_vllm.jsonl
|
||||
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_vllm.jsonl
|
||||
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_vllm.jsonl
|
||||
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_vllm.jsonl
|
||||
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_vllm.jsonl
|
||||
cat online_vllm.jsonl | cut -d':' -f9 | cut -d',' -f1
|
||||
```
|
||||
|
||||
```bash
|
||||
# TensorRT LLM Offline 8B
|
||||
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_trt_8b.jsonl
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_trt_8b.jsonl
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_trt_8b.jsonl
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_trt_8b.jsonl
|
||||
python3 bench_serving.py --backend trt --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_trt_8b.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name sharegpt --num-prompts 3000 --output-file offline_trt_8b.jsonl
|
||||
cat offline_trt_8b.jsonl | cut -d':' -f12 | cut -d',' -f1
|
||||
```
|
||||
|
||||
```bash
|
||||
# TensorRT LLM Online 8B
|
||||
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_trt_8b.jsonl
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_trt_8b.jsonl
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_trt_8b.jsonl
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_trt_8b.jsonl
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_trt_8b.jsonl
|
||||
cat online_trt_8b.jsonl | cut -d':' -f9 | cut -d',' -f1
|
||||
```
|
||||
|
||||
```bash
|
||||
# TensorRT LLM Offline 70B
|
||||
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_trt_70b.jsonl
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_trt_70b.jsonl
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_trt_70b.jsonl
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_trt_70b.jsonl
|
||||
python3 bench_serving.py --backend trt --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_trt_70b.jsonl --model meta-llama/Meta-Llama-3-70B-Instruct
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name sharegpt --num-prompts 3000 --output-file offline_trt_70b.jsonl
|
||||
cat offline_trt_70b.jsonl | cut -d':' -f12 | cut -d',' -f1
|
||||
```
|
||||
|
||||
```bash
|
||||
# TensorRT LLM Online 70B
|
||||
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_trt_70b.jsonl
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_trt_70b.jsonl
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_trt_70b.jsonl
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_trt_70b.jsonl
|
||||
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_trt_70b.jsonl
|
||||
cat online_trt_70b.jsonl | cut -d':' -f9 | cut -d',' -f1
|
||||
```
|
||||
100
third_party/sglang/benchmark/blog_v0_2/config.md
vendored
Normal file
100
third_party/sglang/benchmark/blog_v0_2/config.md
vendored
Normal file
@@ -0,0 +1,100 @@
|
||||
### used for TensorRT LLM
|
||||
|
||||
```
|
||||
{
|
||||
"architecture": "LlamaForCausalLM",
|
||||
"dtype": "float16",
|
||||
"logits_dtype": "float32",
|
||||
"vocab_size": 128256,
|
||||
"max_position_embeddings": 8192,
|
||||
"hidden_size": 16384,
|
||||
"num_hidden_layers": 126,
|
||||
"num_attention_heads": 128,
|
||||
"num_key_value_heads": 16,
|
||||
"head_size": 128,
|
||||
"qk_layernorm": false,
|
||||
"hidden_act": "silu",
|
||||
"intermediate_size": 53248,
|
||||
"norm_epsilon": 1e-05,
|
||||
"position_embedding_type": "rope_gpt_neox",
|
||||
"use_parallel_embedding": false,
|
||||
"embedding_sharding_dim": 0,
|
||||
"share_embedding_table": false,
|
||||
"mapping": {
|
||||
"world_size": 8,
|
||||
"tp_size": 8,
|
||||
"pp_size": 1,
|
||||
"gpus_per_node": 8
|
||||
},
|
||||
"quantization": {
|
||||
"quant_algo": "FP8",
|
||||
"kv_cache_quant_algo": null,
|
||||
"group_size": 128,
|
||||
"smoothquant_val": null,
|
||||
"has_zero_point": false,
|
||||
"pre_quant_scale": false,
|
||||
"exclude_modules": [
|
||||
"lm_head"
|
||||
]
|
||||
},
|
||||
"kv_dtype": "float16",
|
||||
"rotary_scaling": null,
|
||||
"residual_mlp": false,
|
||||
"moe_normalization_mode": null,
|
||||
"rotary_base": 500000.0,
|
||||
"moe_num_experts": 0,
|
||||
"moe_top_k": 0,
|
||||
"moe_tp_mode": 2,
|
||||
"attn_bias": false,
|
||||
"disable_weight_only_quant_plugin": false,
|
||||
"mlp_bias": false
|
||||
}
|
||||
```
|
||||
|
||||
### used for vLLM and SGLang
|
||||
|
||||
```
|
||||
{
|
||||
"_name_or_path": "dummy_fp8",
|
||||
"architectures": [
|
||||
"LlamaForCausalLM"
|
||||
],
|
||||
"attention_bias": false,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 128000,
|
||||
"eos_token_id": 128009,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 16384,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 53248,
|
||||
"mlp_bias": false,
|
||||
"model_type": "llama",
|
||||
"num_attention_heads": 128,
|
||||
"num_hidden_layers": 126,
|
||||
"num_key_value_heads": 8,
|
||||
"pretraining_tp": 1,
|
||||
"quantization_config": {
|
||||
"activation_scheme": "static",
|
||||
"ignored_layers": [
|
||||
"lm_head"
|
||||
],
|
||||
"quant_method": "fp8"
|
||||
},
|
||||
"rope_scaling": {
|
||||
"factor": 8.0,
|
||||
"low_freq_factor": 1.0,
|
||||
"high_freq_factor": 4.0,
|
||||
"original_max_position_embeddings": 8192,
|
||||
"rope_type": "llama3"
|
||||
},
|
||||
"max_position_embeddings": 131072,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_scaling": null,
|
||||
"rope_theta": 500000.0,
|
||||
"tie_word_embeddings": false,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.41.1",
|
||||
"use_cache": true,
|
||||
"vocab_size": 128256
|
||||
}
|
||||
```
|
||||
19
third_party/sglang/benchmark/boolq/README.md
vendored
Normal file
19
third_party/sglang/benchmark/boolq/README.md
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
## Download data
|
||||
```
|
||||
git clone https://hf-mirror.com/datasets/google/boolq
|
||||
```
|
||||
|
||||
## Convert parquet to json
|
||||
```
|
||||
bash parquet_to_json.sh
|
||||
```
|
||||
## Run benchmark
|
||||
|
||||
### Benchmark sglang
|
||||
```
|
||||
python -m sglang.launch_server --model-path ramblingpolymath/Qwen3-32B-W8A8 --port 30000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_sglang.py
|
||||
```
|
||||
124
third_party/sglang/benchmark/boolq/bench_sglang.py
vendored
Normal file
124
third_party/sglang/benchmark/boolq/bench_sglang.py
vendored
Normal file
@@ -0,0 +1,124 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.api import set_default_backend
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
from sglang.utils import read_jsonl
|
||||
|
||||
|
||||
def get_example(lines, i, answer):
|
||||
prompt = "Question: " + lines[i]["question"] + lines[i]["passage"] + "\nAnswer:"
|
||||
if answer:
|
||||
prompt += str(lines[i]["answer"])
|
||||
return prompt
|
||||
|
||||
|
||||
def few_shot_examples(lines, k):
|
||||
prompts = ""
|
||||
for i in range(k):
|
||||
prompts += get_example(lines, i, True) + "\n\n"
|
||||
return prompts
|
||||
|
||||
|
||||
def main(args):
|
||||
# Select backend
|
||||
set_default_backend(select_sglang_backend(args))
|
||||
|
||||
# Read data
|
||||
train_data_path = args.train_data_path
|
||||
test_data_path = args.test_data_path
|
||||
lines_train = list(read_jsonl(train_data_path))
|
||||
lines_test = list(read_jsonl(test_data_path))
|
||||
|
||||
# Construct prompts
|
||||
num_questions = args.num_questions
|
||||
num_shots = args.num_shots
|
||||
few_shots = few_shot_examples(lines_train, num_shots)
|
||||
|
||||
questions = []
|
||||
answer = []
|
||||
for i in range(len(lines_test[:num_questions])):
|
||||
questions.append(get_example(lines_test, i, False))
|
||||
answer.append(str(lines_test[i]["answer"]))
|
||||
arguments = [{"question": q} for q in questions]
|
||||
|
||||
#####################################
|
||||
######### SGL Program Begin #########
|
||||
#####################################
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
@sgl.function
|
||||
def few_shot_boolq(s, question):
|
||||
s += few_shots + question
|
||||
s += sgl.gen("answer", max_tokens=5, stop=["\n"])
|
||||
|
||||
#####################################
|
||||
########## SGL Program End ##########
|
||||
#####################################
|
||||
|
||||
# Run requests
|
||||
tic = time.perf_counter()
|
||||
states = few_shot_boolq.run_batch(
|
||||
arguments,
|
||||
temperature=0,
|
||||
num_threads=args.parallel,
|
||||
progress_bar=True,
|
||||
)
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
preds = []
|
||||
for i in range(len(states)):
|
||||
preds.append(states[i]["answer"])
|
||||
|
||||
# Compute accuracy
|
||||
acc = np.mean(np.array(preds) == np.array(answer))
|
||||
|
||||
# Compute speed
|
||||
num_output_tokens = sum(
|
||||
s.get_meta_info("answer")["completion_tokens"] for s in states
|
||||
)
|
||||
output_throughput = num_output_tokens / latency
|
||||
|
||||
# Print results
|
||||
print(f"Accuracy: {acc:.3f}")
|
||||
print(f"Latency: {latency:.3f} s")
|
||||
print(f"Output throughput: {output_throughput:.3f} token/s")
|
||||
|
||||
# Results
|
||||
with open(args.result_file, "a") as fout:
|
||||
value = {
|
||||
"task": "boolq",
|
||||
"backend": args.backend,
|
||||
"num_gpus": 1,
|
||||
"latency": round(latency, 3),
|
||||
"accuracy": round(acc, 3),
|
||||
"num_requests": args.num_questions,
|
||||
"other": {
|
||||
"num_questions": args.num_questions,
|
||||
"parallel": args.parallel,
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-shots", type=int, default=5)
|
||||
parser.add_argument(
|
||||
"--train-data-path", type=str, default="./boolq/data/train-00000-of-00001.json"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-data-path",
|
||||
type=str,
|
||||
default="./boolq/data/validation-00000-of-00001.json",
|
||||
)
|
||||
parser.add_argument("--num-questions", type=int, default=200)
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
main(args)
|
||||
28
third_party/sglang/benchmark/boolq/convert_parquet_to_json.py
vendored
Normal file
28
third_party/sglang/benchmark/boolq/convert_parquet_to_json.py
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
import sys
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
|
||||
def convert_parquet_to_json(input_file, output_file):
|
||||
# read parquet file
|
||||
table = pq.read_table(input_file)
|
||||
|
||||
# turn parquet data to dataframe
|
||||
df = table.to_pandas()
|
||||
|
||||
# turn dataframe to json form
|
||||
json_data = df.to_json(orient="records", lines=True)
|
||||
|
||||
# write json to file
|
||||
with open(output_file, "w") as f:
|
||||
f.write(json_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 3:
|
||||
print("Usage:python convert_parquet_to_json.py <input_file> <output_file>")
|
||||
|
||||
input_file = sys.argv[1]
|
||||
output_file = sys.argv[2]
|
||||
|
||||
convert_parquet_to_json(input_file, output_file)
|
||||
26
third_party/sglang/benchmark/boolq/parquet_to_json.sh
vendored
Executable file
26
third_party/sglang/benchmark/boolq/parquet_to_json.sh
vendored
Executable file
@@ -0,0 +1,26 @@
|
||||
#!/bin/bash
|
||||
|
||||
#define input and output direction
|
||||
input_dir="./boolq/data"
|
||||
output_dir="./boolq/data"
|
||||
|
||||
#define files needed to be handled
|
||||
files=(
|
||||
"train-00000-of-00001.parquet"
|
||||
"validation-00000-of-00001.parquet"
|
||||
)
|
||||
|
||||
#foe files above, use python script to convert the form
|
||||
for file in "${files[@]}"; do
|
||||
input_file="${input_dir}/${file}"
|
||||
output_file="${output_dir}/${file%.parquet}.json"
|
||||
|
||||
echo "Converting ${input_file} to ${output_file} ..."
|
||||
python3 convert_parquet_to_json.py "${input_file}" "${output_file}"
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Conversion successful: ${output_file}"
|
||||
else
|
||||
echo "Conversion failed: ${input_file}"
|
||||
fi
|
||||
done
|
||||
15
third_party/sglang/benchmark/ceval/README.md
vendored
Normal file
15
third_party/sglang/benchmark/ceval/README.md
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
## Download data
|
||||
```
|
||||
git lfs clone https://huggingface.co/datasets/ceval/ceval-exam
|
||||
```
|
||||
|
||||
## Run benchmark
|
||||
|
||||
### Benchmark sglang
|
||||
```
|
||||
python -m sglang.launch_server --model-path ramblingpolymath/Qwen3-32B-W8A8 --port 30000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_sglang.py
|
||||
```
|
||||
138
third_party/sglang/benchmark/ceval/bench_sglang.py
vendored
Normal file
138
third_party/sglang/benchmark/ceval/bench_sglang.py
vendored
Normal file
@@ -0,0 +1,138 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from sglang.lang.api import set_default_backend
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
|
||||
choices = ["A", "B", "C", "D"]
|
||||
|
||||
|
||||
def get_one_example(line, include_answer):
|
||||
res = line["question"]
|
||||
res += f"\nA. {line['A']}"
|
||||
res += f"\nB. {line['B']}"
|
||||
res += f"\nC. {line['C']}"
|
||||
res += f"\nD. {line['D']}"
|
||||
|
||||
if include_answer:
|
||||
res += f"\nAnswer: {line['answer']} \n\n"
|
||||
return res
|
||||
|
||||
|
||||
def get_few_shot_examples(lines):
|
||||
res = ""
|
||||
for line in lines:
|
||||
res += get_one_example(line, True) + "\n\n"
|
||||
return res
|
||||
|
||||
|
||||
def get_answer_value(response):
|
||||
pattern = r"(Answer:|answer:|答案是|答案是:|正确答案是:|答案:|Assistant:)\s*([A-D])(?![\w])"
|
||||
match = re.search(pattern, response)
|
||||
|
||||
if match:
|
||||
return match.group(2)
|
||||
|
||||
return random.choice(choices)
|
||||
|
||||
|
||||
def main(args):
|
||||
# Read data && Construct prompts
|
||||
arguments = []
|
||||
labels = []
|
||||
examples = "examples:\n"
|
||||
data_path = args.data_path
|
||||
for subject in os.listdir(data_path):
|
||||
subject_path = os.path.join(data_path, subject)
|
||||
if os.path.isdir(subject_path) and subject != ".git":
|
||||
dataset = load_dataset(data_path, name=subject)
|
||||
dev_lines_temp = dataset["dev"]
|
||||
val_lines_temp = dataset["val"]
|
||||
few_shot_examples = get_few_shot_examples(dev_lines_temp)
|
||||
examples += f"{few_shot_examples}"
|
||||
for val_line in val_lines_temp:
|
||||
arguments.append(
|
||||
{
|
||||
"examples": few_shot_examples,
|
||||
"question": get_one_example(val_line, False),
|
||||
}
|
||||
)
|
||||
labels.append(val_line["answer"])
|
||||
|
||||
#####################################
|
||||
######### SGL Program Begin #########
|
||||
#####################################
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
@sgl.function
|
||||
def few_shot_ceval(s, examples, question):
|
||||
s += examples + question + sgl.gen("Answer")
|
||||
|
||||
#####################################
|
||||
########## SGL Program End ##########
|
||||
#####################################
|
||||
|
||||
num_questions = args.num_questions if args.num_questions else len(arguments)
|
||||
|
||||
# Select backend
|
||||
set_default_backend(select_sglang_backend(args))
|
||||
|
||||
# Run requests
|
||||
tic = time.perf_counter()
|
||||
states = few_shot_ceval.run_batch(
|
||||
arguments[:num_questions],
|
||||
temperature=0,
|
||||
num_threads=args.parallel,
|
||||
progress_bar=True,
|
||||
)
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
preds = [get_answer_value(states[i]["Answer"]) for i in range(num_questions)]
|
||||
|
||||
# Compute accuracy
|
||||
acc = np.mean(np.array(preds) == np.array(labels[:num_questions]))
|
||||
|
||||
# Compute speed
|
||||
num_output_tokens = sum(
|
||||
s.get_meta_info("Answer")["completion_tokens"] for s in states
|
||||
)
|
||||
output_throughput = num_output_tokens / latency
|
||||
|
||||
# Print results
|
||||
print(f"Accuracy: {acc:.3f}")
|
||||
print(f"Latency: {latency:.3f} s")
|
||||
print(f"Output throughput: {output_throughput:.3f} token/s")
|
||||
|
||||
# Write results
|
||||
with open(args.result_file, "a") as fout:
|
||||
value = {
|
||||
"task": "ceval",
|
||||
"backend": args.backend,
|
||||
"num_gpus": 1,
|
||||
"latency": round(latency, 3),
|
||||
"accuracy": round(acc, 3),
|
||||
"num_requests": args.num_questions,
|
||||
"other": {
|
||||
"parallel": args.parallel,
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-path", type=str, default="ceval/ceval-exam")
|
||||
parser.add_argument("--num-questions", type=int, default=None)
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
main(args)
|
||||
412
third_party/sglang/benchmark/deepseek_v3/README.md
vendored
Normal file
412
third_party/sglang/benchmark/deepseek_v3/README.md
vendored
Normal file
@@ -0,0 +1,412 @@
|
||||
# DeepSeek V3.1/V3/R1 Support
|
||||
|
||||
The SGLang and DeepSeek teams collaborated to get DeepSeek V3 FP8 running on NVIDIA and AMD GPUs **from day one**. SGLang also supports [MLA optimization](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [DP attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models), making SGLang one of the best open-source LLM engines for running DeepSeek models. SGLang is the inference engine recommended by the official [DeepSeek team](https://github.com/deepseek-ai/DeepSeek-V3/tree/main?tab=readme-ov-file#62-inference-with-sglang-recommended).
|
||||
|
||||
Special thanks to Meituan's Search & Recommend Platform Team and Baseten's Model Performance Team for implementing the model, and DataCrunch for providing GPU resources.
|
||||
|
||||
For optimizations made on the DeepSeek series models regarding SGLang, please refer to [DeepSeek V3/V3.1/R1 Model Optimizations in SGLang](https://docs.sglang.io/basic_usage/deepseek_v3.html#optimizations).
|
||||
|
||||
## Installation & Launch
|
||||
|
||||
If you encounter errors when starting the server, ensure the weights have finished downloading. It's recommended to download them beforehand or restart multiple times until all weights are downloaded.
|
||||
|
||||
### Using Docker (Recommended)
|
||||
|
||||
```bash
|
||||
# Pull latest image
|
||||
# https://hub.docker.com/r/lmsysorg/sglang/tags
|
||||
docker pull lmsysorg/sglang:latest
|
||||
|
||||
# Launch
|
||||
docker run --gpus all --shm-size 32g -p 30000:30000 -v ~/.cache/huggingface:/root/.cache/huggingface --ipc=host --network=host --privileged lmsysorg/sglang:latest \
|
||||
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code --port 30000
|
||||
```
|
||||
|
||||
If you are using RDMA, please note that:
|
||||
|
||||
1. `--network host` and `--privileged` are required by RDMA. If you don't need RDMA, you can remove them.
|
||||
2. You may need to set `NCCL_IB_GID_INDEX` if you are using RoCE, for example: `export NCCL_IB_GID_INDEX=3`.
|
||||
|
||||
Add [performance optimization options](#performance-optimization-options) as needed.
|
||||
|
||||
### Using pip
|
||||
|
||||
```bash
|
||||
# Installation
|
||||
pip install sglang
|
||||
|
||||
# Launch
|
||||
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
|
||||
```
|
||||
|
||||
Add [performance optimization options](#performance-optimization-options) as needed.
|
||||
|
||||
<a id="option_args"></a>
|
||||
|
||||
### Performance Optimization Options
|
||||
|
||||
[MLA optimizations](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) are enabled by default. Here are some optional optimizations can be enabled as needed.
|
||||
|
||||
- [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models): For high QPS scenarios, add the `--enable-dp-attention` argument to boost throughput.
|
||||
- [Torch.compile Optimization](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#torchcompile-latency-optimizations): Add `--enable-torch-compile` argument to enable it. This will take some time while server starts. The maximum batch size for torch.compile optimization can be controlled with `--torch-compile-max-bs`. It's recommended to set it between `1` and `8`. (e.g., `--torch-compile-max-bs 8`)
|
||||
|
||||
### Usage: Chat with DeepSeek
|
||||
|
||||
#### DeepSeek V3/R1
|
||||
|
||||
```python3
|
||||
import openai
|
||||
client = openai.Client(
|
||||
base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
|
||||
|
||||
# Chat completion
|
||||
response = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "List 3 countries and their capitals."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=64,
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
#### DeepSeek V3.1
|
||||
On top of the basic usage similar to the DeepSeek V3/R1 example, DeepSeek V3.1 supports a request-level thinking/non-thinking toggle. Simply switch the `"thinking"` field in `extra_body={"chat_template_kwargs": {"thinking": True}}` to enable/disable the thinking mode.
|
||||
|
||||
##### Non Thinking
|
||||
```python3
|
||||
import openai
|
||||
client = openai.Client(
|
||||
base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
|
||||
|
||||
# Chat completion
|
||||
response = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "Answer the following with the second letter of the correct answer only: What is the capital of France?"},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=1024,
|
||||
extra_body = {"chat_template_kwargs": {"thinking": False}}
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
Answer:
|
||||
```
|
||||
h
|
||||
```
|
||||
* The correct response should be 'A', as the correct answer to the question is 'Paris'.
|
||||
##### Thinking
|
||||
```python3
|
||||
import openai
|
||||
client = openai.Client(
|
||||
base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
|
||||
|
||||
# Chat completion
|
||||
response = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "Answer the following with the second letter of the correct answer only: What is the capital of France?"},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=1024,
|
||||
extra_body = {"chat_template_kwargs": {"thinking": True}}
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
Answer:
|
||||
```
|
||||
First, the question is: "What is the capital of France?" I know that the capital of France is Paris.
|
||||
|
||||
The user says: "Answer the following with the second letter of the correct answer only." So, I need to provide only the second letter of the correct answer.
|
||||
|
||||
The correct answer is "Paris". Now, I need to find the second letter of "Paris".
|
||||
|
||||
Let's spell it out: P-A-R-I-S.
|
||||
|
||||
- First letter: P
|
||||
|
||||
- Second letter: A
|
||||
|
||||
- Third letter: R
|
||||
|
||||
- Fourth letter: I
|
||||
|
||||
- Fifth letter: S
|
||||
|
||||
So, the second letter is "A".
|
||||
|
||||
I should only output the second letter, which is "A". No additional text or explanation, just the letter.
|
||||
|
||||
The user emphasized "the second letter of the correct answer only", so my response should be just "A".
|
||||
|
||||
Finally, I need to make sure that this is the correct answer. Yes, Paris is indeed the capital of France.</think>A
|
||||
```
|
||||
* The response contains `</think>` thinking trace and model was able to derive the correct answer from it.
|
||||
|
||||
### Example: Serving with two H20\*8 nodes
|
||||
|
||||
For example, there are two H20 nodes, each with 8 GPUs. The first node's IP is `10.0.0.1`, and the second node's IP is `10.0.0.2`. Please **use the first node's IP** for both commands.
|
||||
|
||||
If the command fails, try setting the `GLOO_SOCKET_IFNAME` parameter. For more information, see [Common Environment Variables](https://pytorch.org/docs/stable/distributed.html#common-environment-variables).
|
||||
|
||||
If the multi nodes support NVIDIA InfiniBand and encounter hanging issues during startup, consider adding the parameter `export NCCL_IB_GID_INDEX=3`. For more information, see [this](https://github.com/sgl-project/sglang/issues/3516#issuecomment-2668493307).
|
||||
|
||||
```bash
|
||||
# node 1
|
||||
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code
|
||||
|
||||
# node 2
|
||||
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code
|
||||
```
|
||||
|
||||
If you have two H100 nodes, the usage is similar to the aforementioned H20.
|
||||
|
||||
> **Note that the launch command here does not enable Data Parallelism Attention or `torch.compile` Optimization**. For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).
|
||||
|
||||
### Example: Serving with one B200 node
|
||||
|
||||
There is one B200 node with 4 (for FP4) GPUs or 8 (for FP4 or FP8) GPUs. Both FP4 and FP8 models are supported for DeepSeek R1. The flags to achieve optimal performance for each are slightly different.
|
||||
|
||||
#### FP4
|
||||
|
||||
If using 4 GPUs:
|
||||
|
||||
```bash
|
||||
python3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-0528-FP4-V2 --host 0.0.0.0 --port 8000 --tensor-parallel-size=4 --cuda-graph-max-bs 256 --max-running-requests 256 --mem-fraction-static 0.85 --ep-size 4 --scheduler-recv-interval 30 --enable-symm-mem --stream-interval 10
|
||||
```
|
||||
|
||||
If using 8 GPUs:
|
||||
|
||||
```bash
|
||||
python3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-0528-FP4-V2 --host 0.0.0.0 --port 8000 --tensor-parallel-size=8 --cuda-graph-max-bs 256 --max-running-requests 256 --mem-fraction-static 0.85 --ep-size 8 --scheduler-recv-interval 30 --enable-symm-mem --stream-interval 10
|
||||
```
|
||||
|
||||
#### FP8
|
||||
|
||||
```bash
|
||||
SGLANG_ENABLE_JIT_DEEPGEMM=false python3 -m sglang.launch_server --model-path=deepseek-ai/DeepSeek-R1-0528 --host=0.0.0.0 --port=8000 --tensor-parallel-size=8 --cuda-graph-max-bs 128 --max-running-requests 128 --mem-fraction-static 0.82 --kv-cache-dtype fp8_e4m3 --chunked-prefill-size 32768 --max-prefill-tokens 32768 --scheduler-recv-interval 30 --stream-interval 30 --fp8-gemm-backend flashinfer_trtllm
|
||||
```
|
||||
|
||||
### Example: Serving with two H200\*8 nodes and docker
|
||||
|
||||
There are two H200 nodes, each with 8 GPUs. The first node's IP is `192.168.114.10`, and the second node's IP is `192.168.114.11`. Configure the endpoint to expose it to another Docker container using `--host 0.0.0.0` and `--port 40000`, and set up communications with `--dist-init-addr 192.168.114.10:20000`.
|
||||
A single H200 with 8 devices can run DeepSeek V3, the dual H200 setup is just to demonstrate multi-node usage.
|
||||
|
||||
```bash
|
||||
# node 1
|
||||
docker run --gpus all \
|
||||
--shm-size 32g \
|
||||
--network=host \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
--name sglang_multinode1 \
|
||||
-it \
|
||||
--rm \
|
||||
--env "HF_TOKEN=$HF_TOKEN" \
|
||||
--ipc=host \
|
||||
lmsysorg/sglang:latest \
|
||||
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 192.168.114.10:20000 --nnodes 2 --node-rank 0 --trust-remote-code --host 0.0.0.0 --port 40000
|
||||
```
|
||||
|
||||
```bash
|
||||
# node 2
|
||||
docker run --gpus all \
|
||||
--shm-size 32g \
|
||||
--network=host \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
--name sglang_multinode2 \
|
||||
-it \
|
||||
--rm \
|
||||
--env "HF_TOKEN=$HF_TOKEN" \
|
||||
--ipc=host \
|
||||
lmsysorg/sglang:latest \
|
||||
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 192.168.114.10:20000 --nnodes 2 --node-rank 1 --trust-remote-code --host 0.0.0.0 --port 40000
|
||||
```
|
||||
|
||||
To ensure functionality, we include a test from a client Docker container.
|
||||
|
||||
```bash
|
||||
docker run --gpus all \
|
||||
--shm-size 32g \
|
||||
--network=host \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
--name sglang_multinode_client \
|
||||
-it \
|
||||
--rm \
|
||||
--env "HF_TOKEN=$HF_TOKEN" \
|
||||
--ipc=host \
|
||||
lmsysorg/sglang:latest \
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1 --random-output 512 --random-range-ratio 1 --num-prompts 1 --host 0.0.0.0 --port 40000 --output-file "deepseekv3_multinode.jsonl"
|
||||
```
|
||||
|
||||
> **Note that the launch command here does not enable Data Parallelism Attention or `torch.compile` Optimization**. For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).
|
||||
|
||||
### Example: Serving with four A100\*8 nodes
|
||||
|
||||
To serve DeepSeek-V3 with A100 GPUs, we need to convert the [FP8 model checkpoints](https://huggingface.co/deepseek-ai/DeepSeek-V3) to BF16 with [script](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) mentioned [here](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) first.
|
||||
|
||||
Since the BF16 model is over 1.3 TB, we need to prepare four A100 nodes, each with 8 80GB GPUs. Assume the first node's IP is `10.0.0.1`, and the converted model path is `/path/to/DeepSeek-V3-BF16`, we can have following commands to launch the server.
|
||||
|
||||
```bash
|
||||
# node 1
|
||||
python3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 0 --trust-remote-code --host 0.0.0.0 --port 30000
|
||||
|
||||
# node 2
|
||||
python3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 1 --trust-remote-code
|
||||
|
||||
# node 3
|
||||
python3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 2 --trust-remote-code
|
||||
|
||||
# node 4
|
||||
python3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 3 --trust-remote-code
|
||||
```
|
||||
|
||||
> **Note that the launch command here does not enable Data Parallelism Attention or `torch.compile` Optimization**. For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).
|
||||
|
||||
Then we can benchmark the accuracy and latency by accessing the first node's exposed port with the following example commands.
|
||||
|
||||
```bash
|
||||
# bench accuracy
|
||||
python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319 --host 10.0.0.1 --port 30000
|
||||
|
||||
# bench latency
|
||||
python3 -m sglang.bench_one_batch_server --model None --base-url http://10.0.0.1:30000 --batch-size 1 --input-len 128 --output-len 128
|
||||
```
|
||||
|
||||
|
||||
### Example: Serving with 8 A100/A800 with AWQ Quantization
|
||||
|
||||
**Recommended Usage**
|
||||
|
||||
Add `--quantization moe_wna16` flag to enable moe wna16 kernel for better performance.
|
||||
One example is as follows:
|
||||
|
||||
```bash
|
||||
python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization moe_wna16
|
||||
```
|
||||
|
||||
Alternatively, you can use `--quantization awq_marlin` as follows:
|
||||
|
||||
```bash
|
||||
python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization awq_marlin --dtype float16
|
||||
```
|
||||
|
||||
Note that `awq_marlin` only supports `float16` now, which may lead to some precision loss.
|
||||
|
||||
### Example: Serving with 16 A100/A800 with int8 Quantization
|
||||
|
||||
There are block-wise and per-channel quantization methods, and the quantization parameters have already been uploaded to Huggingface. One example is as follows:
|
||||
|
||||
- [meituan/DeepSeek-R1-Block-INT8](https://huggingface.co/meituan/DeepSeek-R1-Block-INT8)
|
||||
- [meituan/DeepSeek-R1-Channel-INT8](https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8)
|
||||
|
||||
Assuming that master node IP is `MASTER_IP`, checkpoint path is `/path/to/DeepSeek-R1-INT8` and port=5000, we can have following commands to launch the server:
|
||||
```bash
|
||||
#master
|
||||
python3 -m sglang.launch_server \
|
||||
--model meituan/DeepSeek-R1-Block-INT8 --tp 16 --dist-init-addr \
|
||||
MASTER_IP:5000 --nnodes 2 --node-rank 0 --trust-remote-code --enable-torch-compile --torch-compile-max-bs 8
|
||||
#cluster
|
||||
python3 -m sglang.launch_server \
|
||||
--model meituan/DeepSeek-R1-Block-INT8 --tp 16 --dist-init-addr \
|
||||
MASTER_IP:5000 --nnodes 2 --node-rank 1 --trust-remote-code --enable-torch-compile --torch-compile-max-bs 8
|
||||
```
|
||||
|
||||
> **Note that the launch command here enables `torch.compile` Optimization**. For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).
|
||||
|
||||
Then on the **master node**, supposing the ShareGPT data is located at `/path/to/ShareGPT_V3_unfiltered_cleaned_split.json`, you can run the following commands to benchmark the launched server:
|
||||
|
||||
```bash
|
||||
# bench accuracy
|
||||
python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319
|
||||
|
||||
# bench serving
|
||||
python3 -m sglang.bench_serving --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json --dataset-name random --random-input 128 --random-output 128 --num-prompts 1000 --request-rate 128 --random-range-ratio 1.0
|
||||
```
|
||||
|
||||
> **Note: using `--parallel 200` can accelerate accuracy benchmarking**.
|
||||
|
||||
### Example: Serving with 32 L40S with int8 Quantization
|
||||
|
||||
Running with per-channel quantization model:
|
||||
|
||||
- [meituan/DeepSeek-R1-Channel-INT8](https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8)
|
||||
|
||||
Assuming that master node IP is `MASTER_IP`, checkpoint path is `/path/to/DeepSeek-R1-Channel-INT8` and port=5000, we can have following commands to launch the server:
|
||||
|
||||
```bash
|
||||
#master
|
||||
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
|
||||
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 0 --trust-remote \
|
||||
--enable-torch-compile --torch-compile-max-bs 32
|
||||
#cluster
|
||||
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
|
||||
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 1 --trust-remote \
|
||||
--enable-torch-compile --torch-compile-max-bs 32
|
||||
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
|
||||
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 2 --trust-remote \
|
||||
--enable-torch-compile --torch-compile-max-bs 32
|
||||
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
|
||||
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 3 --trust-remote \
|
||||
--enable-torch-compile --torch-compile-max-bs 32
|
||||
```
|
||||
|
||||
The benchmarking method is the same as describted in the previous [16 x A100](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-16-a100a800-with-int8-quantization) example.
|
||||
|
||||
### Example: Serving on any cloud or Kubernetes with SkyPilot
|
||||
|
||||
SkyPilot helps find cheapest available GPUs across any cloud or existing Kubernetes clusters and launch distributed serving with a single command. See details [here](https://github.com/skypilot-org/skypilot/tree/master/llm/deepseek-r1).
|
||||
|
||||
To serve on multiple nodes:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/skypilot-org/skypilot.git
|
||||
# Serve on 2 H100/H200x8 nodes
|
||||
sky launch -c r1 llm/deepseek-r1/deepseek-r1-671B.yaml --retry-until-up
|
||||
# Serve on 4 A100x8 nodes
|
||||
sky launch -c r1 llm/deepseek-r1/deepseek-r1-671B-A100.yaml --retry-until-up
|
||||
```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
If you encounter the following error with fp16/bf16 checkpoint:
|
||||
|
||||
```bash
|
||||
ValueError: Weight output_partition_size = 576 is not divisible by weight quantization block_n = 128.
|
||||
```
|
||||
|
||||
edit your `config.json` and remove the `quantization_config` block. For example:
|
||||
|
||||
```json
|
||||
"quantization_config": {
|
||||
"activation_scheme": "dynamic",
|
||||
"fmt": "e4m3",
|
||||
"quant_method": "fp8",
|
||||
"weight_block_size": [128, 128]
|
||||
},
|
||||
```
|
||||
|
||||
Removing this block typically resolves the error. For more details, see the discussion in [sgl-project/sglang#3491](https://github.com/sgl-project/sglang/issues/3491#issuecomment-2650779851).
|
||||
|
||||
# Example: Serving with 4 H200 with w4fp8 Quantization
|
||||
There are mixed-precision quantization methods where MoE layers are computed using W4(int)A(FP)8 quantization while the dense layers remain in FP8 precision. Users can run these models efficiently on 4xH200 GPUs (or potentially 8xH100 GPUs), as the pre-quantized weights are already available on Hugging Face. Here's an example:
|
||||
|
||||
```bash
|
||||
python -m sglang.launch_server --model novita/Deepseek-V3-0324-W4AFP8 --mem-fraction-static 0.85 --disable-shared-experts-fusion --tp-size 4
|
||||
```
|
||||
|
||||
Other variants of pre-quantized DeepSeek models are also available:
|
||||
|
||||
- [novita/Deepseek-V3.1-W4AFP8](https://huggingface.co/novita/Deepseek-V3.1-W4AFP8)
|
||||
- [novita/Deepseek-R1-0528-W4AFP8](https://huggingface.co/novita/Deepseek-R1-0528-W4AFP8)
|
||||
- [novita/Deepseek-R1-W4AFP8](https://huggingface.co/novita/Deepseek-R1-W4AFP8)
|
||||
- [novita/Deepseek-V3-0324-W4AFP8](https://huggingface.co/novita/Deepseek-V3-0324-W4AFP8)
|
||||
|
||||
|
||||
## DeepSeek V3 Optimization Plan
|
||||
|
||||
https://github.com/sgl-project/sglang/issues/2591
|
||||
51
third_party/sglang/benchmark/dspy/README.md
vendored
Normal file
51
third_party/sglang/benchmark/dspy/README.md
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
## Install
|
||||
|
||||
```
|
||||
pip3 install dspy-ai
|
||||
```
|
||||
|
||||
Turn off cache at https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/dsp/modules/cache_utils.py#L10.
|
||||
```
|
||||
cache_turn_on = False
|
||||
```
|
||||
|
||||
or set the environment variable
|
||||
|
||||
```
|
||||
export DSP_CACHEBOOL=false
|
||||
```
|
||||
|
||||
## Benchmark SGLang
|
||||
```
|
||||
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_dspy_intro.py --backend sglang
|
||||
```
|
||||
|
||||
|
||||
## Benchmark TGI
|
||||
```
|
||||
docker run --name tgi --rm -ti --gpus all --network host \
|
||||
-v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
|
||||
ghcr.io/huggingface/text-generation-inference:1.3.0 \
|
||||
--model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
|
||||
--max-input-length 2048 --max-total-tokens 4096 \
|
||||
--port 24000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_dspy_intro.py --backend tgi
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Benchmark vLLM
|
||||
```
|
||||
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_dspy_intro.py --backend vllm
|
||||
```
|
||||
192
third_party/sglang/benchmark/dspy/bench_dspy_intro.py
vendored
Normal file
192
third_party/sglang/benchmark/dspy/bench_dspy_intro.py
vendored
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Adapted from
|
||||
https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/intro.ipynb#L9
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import dspy
|
||||
from dspy.datasets import HotPotQA
|
||||
|
||||
|
||||
class BasicQA(dspy.Signature):
|
||||
"""Answer questions with short factoid answers."""
|
||||
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField(desc="often between 1 and 5 words")
|
||||
|
||||
|
||||
class GenerateAnswer(dspy.Signature):
|
||||
"""Answer questions with short factoid answers."""
|
||||
|
||||
context = dspy.InputField(desc="may contain relevant facts")
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField(desc="often between 1 and 5 words")
|
||||
|
||||
|
||||
class RAG(dspy.Module):
|
||||
def __init__(self, num_passages=3):
|
||||
super().__init__()
|
||||
|
||||
self.retrieve = dspy.Retrieve(k=num_passages)
|
||||
self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
|
||||
|
||||
def forward(self, question):
|
||||
context = self.retrieve(question).passages
|
||||
prediction = self.generate_answer(context=context, question=question)
|
||||
return dspy.Prediction(context=context, answer=prediction.answer)
|
||||
|
||||
|
||||
def main(args):
|
||||
# lm = dspy.OpenAI(model='gpt-3.5-turbo')
|
||||
if args.backend == "tgi":
|
||||
lm = dspy.HFClientTGI(
|
||||
model="meta-llama/Llama-2-7b-chat-hf",
|
||||
port=args.port,
|
||||
url="http://localhost",
|
||||
)
|
||||
elif args.backend == "sglang":
|
||||
lm = dspy.HFClientSGLang(
|
||||
model="meta-llama/Llama-2-7b-chat-hf",
|
||||
port=args.port,
|
||||
url="http://localhost",
|
||||
)
|
||||
elif args.backend == "vllm":
|
||||
lm = dspy.HFClientVLLM(
|
||||
model="meta-llama/Llama-2-7b-chat-hf",
|
||||
port=args.port,
|
||||
url="http://localhost",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid backend: {args.backend}")
|
||||
|
||||
colbertv2_wiki17_abstracts = dspy.ColBERTv2(
|
||||
url="http://20.102.90.50:2017/wiki17_abstracts"
|
||||
)
|
||||
dspy.settings.configure(lm=lm, rm=colbertv2_wiki17_abstracts)
|
||||
|
||||
# Load the dataset.
|
||||
dataset = HotPotQA(
|
||||
train_seed=1, train_size=20, eval_seed=2023, dev_size=args.dev_size, test_size=0
|
||||
)
|
||||
|
||||
# Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
|
||||
trainset = [x.with_inputs("question") for x in dataset.train]
|
||||
devset = [x.with_inputs("question") for x in dataset.dev]
|
||||
|
||||
print(len(trainset), len(devset))
|
||||
|
||||
train_example = trainset[0]
|
||||
print(f"Question: {train_example.question}")
|
||||
print(f"Answer: {train_example.answer}")
|
||||
|
||||
dev_example = devset[18]
|
||||
print(f"Question: {dev_example.question}")
|
||||
print(f"Answer: {dev_example.answer}")
|
||||
print(f"Relevant Wikipedia Titles: {dev_example.gold_titles}")
|
||||
|
||||
print(
|
||||
f"For this dataset, training examples have input keys {train_example.inputs().keys()} and label keys {train_example.labels().keys()}"
|
||||
)
|
||||
print(
|
||||
f"For this dataset, dev examples have input keys {dev_example.inputs().keys()} and label keys {dev_example.labels().keys()}"
|
||||
)
|
||||
|
||||
# Define the predictor.
|
||||
generate_answer = dspy.Predict(BasicQA)
|
||||
|
||||
# Call the predictor on a particular input.
|
||||
pred = generate_answer(question=dev_example.question)
|
||||
|
||||
# Print the input and the prediction.
|
||||
print(f"Question: {dev_example.question}")
|
||||
print(f"Predicted Answer: {pred.answer}")
|
||||
|
||||
lm.inspect_history(n=1)
|
||||
|
||||
# Define the predictor. Notice we're just changing the class. The signature BasicQA is unchanged.
|
||||
generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA)
|
||||
|
||||
# Call the predictor on the same input.
|
||||
pred = generate_answer_with_chain_of_thought(question=dev_example.question)
|
||||
|
||||
# Print the input, the chain of thought, and the prediction.
|
||||
print(f"Question: {dev_example.question}")
|
||||
print(f"Thought: {pred.rationale.split('.', 1)[1].strip()}")
|
||||
print(f"Predicted Answer: {pred.answer}")
|
||||
|
||||
retrieve = dspy.Retrieve(k=3)
|
||||
topK_passages = retrieve(dev_example.question).passages
|
||||
|
||||
print(
|
||||
f"Top {retrieve.k} passages for question: {dev_example.question} \n",
|
||||
"-" * 30,
|
||||
"\n",
|
||||
)
|
||||
|
||||
for idx, passage in enumerate(topK_passages):
|
||||
print(f"{idx+1}]", passage, "\n")
|
||||
|
||||
retrieve("When was the first FIFA World Cup held?").passages[0]
|
||||
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
# Validation logic: check that the predicted answer is correct.
|
||||
# Also check that the retrieved context does actually contain that answer.
|
||||
def validate_context_and_answer(example, pred, trace=None):
|
||||
answer_EM = dspy.evaluate.answer_exact_match(example, pred)
|
||||
answer_PM = dspy.evaluate.answer_passage_match(example, pred)
|
||||
return answer_EM and answer_PM
|
||||
|
||||
# Set up a basic teleprompter, which will compile our RAG program.
|
||||
teleprompter = BootstrapFewShot(metric=validate_context_and_answer)
|
||||
|
||||
# Compile!
|
||||
compiled_rag = teleprompter.compile(RAG(), trainset=trainset)
|
||||
|
||||
# Ask any question you like to this simple RAG program.
|
||||
my_question = "What castle did David Gregory inherit?"
|
||||
|
||||
# Get the prediction. This contains `pred.context` and `pred.answer`.
|
||||
pred = compiled_rag(my_question)
|
||||
|
||||
# Print the contexts and the answer.
|
||||
print(f"Question: {my_question}")
|
||||
print(f"Predicted Answer: {pred.answer}")
|
||||
print(f"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}")
|
||||
|
||||
from dspy.evaluate.evaluate import Evaluate
|
||||
|
||||
# Set up the `evaluate_on_hotpotqa` function. We'll use this many times below.
|
||||
evaluate_on_hotpotqa = Evaluate(
|
||||
devset=devset,
|
||||
num_threads=args.num_threads,
|
||||
display_progress=True,
|
||||
display_table=5,
|
||||
)
|
||||
|
||||
# Evaluate the `compiled_rag` program with the `answer_exact_match` metric.
|
||||
metric = dspy.evaluate.answer_exact_match
|
||||
evaluate_on_hotpotqa(compiled_rag, metric=metric)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--port", type=int)
|
||||
parser.add_argument("--num-threads", type=int, default=32)
|
||||
parser.add_argument("--dev-size", type=int, default=150)
|
||||
parser.add_argument(
|
||||
"--backend", type=str, choices=["sglang", "tgi", "vllm"], default="sglang"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.port is None:
|
||||
default_port = {
|
||||
"vllm": 21000,
|
||||
"lightllm": 22000,
|
||||
"tgi": 24000,
|
||||
"sglang": 30000,
|
||||
}
|
||||
args.port = default_port.get(args.backend, None)
|
||||
|
||||
main(args)
|
||||
315
third_party/sglang/benchmark/fla/benchmark_layernorm_gated.py
vendored
Normal file
315
third_party/sglang/benchmark/fla/benchmark_layernorm_gated.py
vendored
Normal file
@@ -0,0 +1,315 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Import the function to benchmark
|
||||
from sglang.srt.layers.attention.fla.layernorm_gated import (
|
||||
_layer_norm_fwd as layer_norm_fwd,
|
||||
)
|
||||
from sglang.srt.layers.attention.fla.layernorm_gated import (
|
||||
rms_norm_ref,
|
||||
)
|
||||
|
||||
|
||||
def benchmark_layer_norm_fwd(
|
||||
M: int = 65536,
|
||||
N: int = 128,
|
||||
eps: float = 1e-6,
|
||||
has_z: bool = True,
|
||||
has_bias: bool = False,
|
||||
group_size: Optional[int] = None,
|
||||
norm_before_gate: bool = True,
|
||||
is_rms_norm: bool = True,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
warmup_iters: int = 10,
|
||||
benchmark_iters: int = 100,
|
||||
device: str = "cuda",
|
||||
verbose: bool = True,
|
||||
):
|
||||
"""
|
||||
Benchmark layer_norm_fwd with specified parameters.
|
||||
|
||||
Args:
|
||||
M: Number of rows (batch size)
|
||||
N: Number of columns (hidden dimension)
|
||||
eps: Epsilon for numerical stability
|
||||
has_z: Whether to use gating tensor z
|
||||
has_bias: Whether to use bias
|
||||
group_size: Group size for group normalization (None = full dimension)
|
||||
norm_before_gate: Whether to normalize before gating
|
||||
is_rms_norm: Whether to use RMS normalization (vs LayerNorm)
|
||||
dtype: Data type for tensors
|
||||
warmup_iters: Number of warmup iterations
|
||||
benchmark_iters: Number of benchmark iterations
|
||||
device: Device to run on
|
||||
"""
|
||||
if verbose:
|
||||
print("=" * 80)
|
||||
print("LayerNorm Forward Pass Benchmark")
|
||||
print("=" * 80)
|
||||
print(f"\nConfiguration:")
|
||||
print(f" x.shape: torch.Size([{M}, {N}])")
|
||||
print(f" weight.shape: torch.Size([{N}])")
|
||||
print(f" bias: {'torch.Size([{}])'.format(N) if has_bias else None}")
|
||||
print(f" eps: {eps}")
|
||||
print(f" z: {'torch.Size([{}, {}])'.format(M, N) if has_z else None}")
|
||||
print(f" out: None")
|
||||
print(f" group_size: {group_size}")
|
||||
print(f" norm_before_gate: {norm_before_gate}")
|
||||
print(f" is_rms_norm: {is_rms_norm}")
|
||||
print(f" dtype: {dtype}")
|
||||
print(f" device: {device}")
|
||||
print()
|
||||
|
||||
# Create input tensors
|
||||
torch.manual_seed(42)
|
||||
x = torch.randn(M, N, dtype=dtype, device=device)
|
||||
weight = torch.randn(N, dtype=dtype, device=device)
|
||||
bias = torch.randn(N, dtype=dtype, device=device) if has_bias else None
|
||||
z = torch.randn(M, N, dtype=dtype, device=device) if has_z else None
|
||||
|
||||
# Ensure contiguous memory layout
|
||||
x = x.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
if z is not None:
|
||||
z = z.contiguous()
|
||||
|
||||
if verbose:
|
||||
print("Warming up...")
|
||||
# Warmup
|
||||
for _ in range(warmup_iters):
|
||||
out, mean, rstd = layer_norm_fwd(
|
||||
x=x,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
eps=eps,
|
||||
z=z,
|
||||
out=None,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if verbose:
|
||||
print(f"Capturing CUDA graph...")
|
||||
|
||||
# Capture the kernel execution in a CUDA graph
|
||||
runs_per_measurement = 100
|
||||
|
||||
# Create output tensor for graph capture
|
||||
out_graph = torch.empty_like(x)
|
||||
mean_graph = (
|
||||
torch.empty((x.shape[0],), dtype=torch.float32, device=x.device)
|
||||
if not is_rms_norm
|
||||
else None
|
||||
)
|
||||
rstd_graph = torch.empty((x.shape[0],), dtype=torch.float32, device=x.device)
|
||||
|
||||
# Capture the graph
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
for _ in range(runs_per_measurement):
|
||||
out, mean, rstd = layer_norm_fwd(
|
||||
x=x,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
eps=eps,
|
||||
z=z,
|
||||
out=out_graph,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
|
||||
if verbose:
|
||||
print(
|
||||
f"Running benchmark with {benchmark_iters} iterations using CUDA graph..."
|
||||
)
|
||||
|
||||
# Benchmark by replaying the graph
|
||||
times = []
|
||||
for i in range(benchmark_iters):
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
graph.replay()
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# elapsed_time_ms returns milliseconds, divide by runs_per_measurement
|
||||
elapsed_ms = start_event.elapsed_time(end_event)
|
||||
times.append(
|
||||
elapsed_ms / 1000.0 / runs_per_measurement
|
||||
) # Convert to seconds per run
|
||||
|
||||
# Compute statistics
|
||||
times = np.array(times) * 1_000_000 # Convert to microseconds
|
||||
mean_time = np.mean(times)
|
||||
std_time = np.std(times)
|
||||
min_time = np.min(times)
|
||||
max_time = np.max(times)
|
||||
median_time = np.median(times)
|
||||
p95_time = np.percentile(times, 95)
|
||||
p99_time = np.percentile(times, 99)
|
||||
|
||||
# Calculate throughput
|
||||
num_elements = M * N
|
||||
throughput_gelements_per_sec = (num_elements / mean_time) * 1_000_000 / 1e9
|
||||
|
||||
# Calculate memory bandwidth
|
||||
# Read: x, weight, z (if has_z)
|
||||
# Write: out, rstd, mean (if not rms_norm)
|
||||
bytes_per_element = 2 if dtype == torch.float16 else 4 # fp16 or fp32
|
||||
read_bytes = (M * N + N) * bytes_per_element # x + weight
|
||||
if has_z:
|
||||
read_bytes += M * N * bytes_per_element # z
|
||||
write_bytes = M * N * bytes_per_element # out
|
||||
write_bytes += M * 4 # rstd (float32)
|
||||
if not is_rms_norm:
|
||||
write_bytes += M * 4 # mean (float32)
|
||||
|
||||
total_bytes = read_bytes + write_bytes
|
||||
bandwidth_gb_per_sec = (total_bytes / mean_time) * 1_000_000 / 1e9
|
||||
|
||||
if verbose:
|
||||
print("\n" + "=" * 80)
|
||||
print("Benchmark Results")
|
||||
print("=" * 80)
|
||||
print(f"\nTiming Statistics (microseconds):")
|
||||
print(f" Mean: {mean_time:.2f} us")
|
||||
print(f" Std Dev: {std_time:.2f} us")
|
||||
print(f" Min: {min_time:.2f} us")
|
||||
print(f" Max: {max_time:.2f} us")
|
||||
print(f" Median: {median_time:.2f} us")
|
||||
print(f" P95: {p95_time:.2f} us")
|
||||
print(f" P99: {p99_time:.2f} us")
|
||||
|
||||
print(f"\nThroughput:")
|
||||
print(f" {throughput_gelements_per_sec:.2f} GElements/sec")
|
||||
print(f" {bandwidth_gb_per_sec:.2f} GB/sec")
|
||||
|
||||
print(f"\nMemory Usage:")
|
||||
print(f" Input size: {read_bytes / 1e6:.2f} MB")
|
||||
print(f" Output size: {write_bytes / 1e6:.2f} MB")
|
||||
print(f" Total: {total_bytes / 1e6:.2f} MB")
|
||||
|
||||
# Verify correctness against reference implementation
|
||||
if verbose:
|
||||
print("\nVerifying correctness...")
|
||||
out_triton, mean_triton, rstd_triton = layer_norm_fwd(
|
||||
x=x,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
eps=eps,
|
||||
z=z,
|
||||
out=None,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
|
||||
# Compute reference output
|
||||
out_ref = rms_norm_ref(
|
||||
x=x,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
z=z,
|
||||
eps=eps,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
upcast=True,
|
||||
)
|
||||
|
||||
# Compare outputs
|
||||
max_diff = torch.max(torch.abs(out_triton - out_ref)).item()
|
||||
mean_diff = torch.mean(torch.abs(out_triton - out_ref)).item()
|
||||
rel_diff = torch.mean(
|
||||
torch.abs(out_triton - out_ref) / (torch.abs(out_ref) + 1e-5)
|
||||
).item()
|
||||
|
||||
if verbose:
|
||||
print(f"\nCorrectness Check (vs Reference Implementation):")
|
||||
print(f" Max absolute difference: {max_diff:.6e}")
|
||||
print(f" Mean absolute difference: {mean_diff:.6e}")
|
||||
print(f" Mean relative difference: {rel_diff:.6e}")
|
||||
|
||||
if max_diff < 1e-2:
|
||||
print(" ✓ PASS: Results match reference implementation")
|
||||
else:
|
||||
print(" ✗ FAIL: Results do not match reference implementation")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
return {
|
||||
"mean_time_us": mean_time,
|
||||
"std_time_us": std_time,
|
||||
"min_time_us": min_time,
|
||||
"max_time_us": max_time,
|
||||
"median_time_us": median_time,
|
||||
"p95_time_us": p95_time,
|
||||
"p99_time_us": p99_time,
|
||||
"throughput_gelements_per_sec": throughput_gelements_per_sec,
|
||||
"bandwidth_gb_per_sec": bandwidth_gb_per_sec,
|
||||
"max_diff": max_diff,
|
||||
"mean_diff": mean_diff,
|
||||
"rel_diff": rel_diff,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
"""Run the benchmark with the specified configuration."""
|
||||
# Configuration from user
|
||||
config = {
|
||||
"M": 65536,
|
||||
"N": 128,
|
||||
"eps": 1e-6,
|
||||
"has_z": True,
|
||||
"has_bias": False,
|
||||
"group_size": None,
|
||||
"norm_before_gate": True,
|
||||
"is_rms_norm": True,
|
||||
"dtype": torch.float16,
|
||||
"warmup_iters": 10,
|
||||
"benchmark_iters": 100,
|
||||
"device": "cuda",
|
||||
}
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
print("CUDA is not available. This benchmark requires a CUDA-enabled GPU.")
|
||||
return
|
||||
|
||||
results = benchmark_layer_norm_fwd(**config)
|
||||
|
||||
# Collect all results
|
||||
all_results = []
|
||||
# Test with different batch sizes
|
||||
print("\nRunning benchmarks for varying batch sizes...")
|
||||
for M in [256, 512, 1024, 4096, 16384, 65536, 2**17, 2**18]:
|
||||
config_var = config.copy()
|
||||
config_var["M"] = M
|
||||
config_var["warmup_iters"] = 5
|
||||
config_var["benchmark_iters"] = 50
|
||||
config_var["verbose"] = False
|
||||
result = benchmark_layer_norm_fwd(**config_var)
|
||||
all_results.append({"M": M, "N": config_var["N"], **result})
|
||||
print(f" M={M:>5}: {result['mean_time_us']:>7.2f} us")
|
||||
|
||||
# Print summary table
|
||||
print("\n\n")
|
||||
print("=" * 30)
|
||||
print("SUMMARY TABLE - Varying Batch Size (M) with N=128")
|
||||
print("=" * 30)
|
||||
print(f"{'M':>8} | {'Median (us)':>12}")
|
||||
print("-" * 30)
|
||||
for r in all_results:
|
||||
print(f"{r['M']:>8} | {r['median_time_us']:>12.2f}")
|
||||
print("=" * 30)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
38
third_party/sglang/benchmark/generative_agents/README.md
vendored
Normal file
38
third_party/sglang/benchmark/generative_agents/README.md
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
## Download the dataset
|
||||
|
||||
```
|
||||
wget -O agent_calls.jsonl https://drive.google.com/uc?export=download&id=19qLpD45e9JGTKF2cUjJJegwzSUEZEKht
|
||||
```
|
||||
|
||||
## Run benchmark
|
||||
|
||||
Ensure that this benchmark is run in a serial manner (using --parallel 1) to preserve any potential dependencies between requests.
|
||||
|
||||
### Benchmark sglang
|
||||
```
|
||||
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_sglang.py --num-events 1000 --parallel 1
|
||||
```
|
||||
|
||||
### Benchmark vllm
|
||||
```
|
||||
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_other.py --num-events 1000 --backend vllm --parallel 1
|
||||
```
|
||||
|
||||
### Benchmark guidance
|
||||
```
|
||||
python3 bench_other.py --num-events 1000 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
||||
```
|
||||
|
||||
### Benchmark lmql
|
||||
|
||||
```
|
||||
python3 bench_other.py --num-events 1000 --backend lmql --parallel 1
|
||||
```
|
||||
300
third_party/sglang/benchmark/generative_agents/agent_functions.py
vendored
Normal file
300
third_party/sglang/benchmark/generative_agents/agent_functions.py
vendored
Normal file
@@ -0,0 +1,300 @@
|
||||
import sglang as sgl
|
||||
|
||||
# here are the top five agent functions contributing ~70% LLM calls
|
||||
# reference: https://github.com/joonspk-research/generative_agents/
|
||||
|
||||
|
||||
@sgl.function
|
||||
def poignancy_event(s, persona_name, persona_iss, event):
|
||||
s += "Here is a brief description of " + persona_name + ".\n"
|
||||
s += persona_iss + "\n"
|
||||
s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for"
|
||||
s += persona_name + ".\n\n"
|
||||
s += "Event: " + event
|
||||
s += "Rate (return a number between 1 to 10):"
|
||||
s += sgl.gen(name="Rate", max_tokens=2)
|
||||
|
||||
|
||||
def poignancy_event_prompt(persona_name, persona_iss, event):
|
||||
# return prompt and max_tokens
|
||||
s = ""
|
||||
s += "Here is a brief description of " + persona_name + ".\n"
|
||||
s += persona_iss + "\n"
|
||||
s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for"
|
||||
s += persona_name + ".\n\n"
|
||||
s += "Event: " + event
|
||||
s += "Rate (return a number between 1 to 10):"
|
||||
return {"prompt": s, "max_tokens": 2, "stop": None}
|
||||
|
||||
|
||||
@sgl.function
|
||||
def generate_event_triple(s, persona_name, action):
|
||||
s += """Task: Turn the input into (subject, predicate, object).
|
||||
Input: Sam Johnson is eating breakfast.
|
||||
Output: (Dolores Murphy, eat, breakfast)
|
||||
---
|
||||
Input: Joon Park is brewing coffee.
|
||||
Output: (Joon Park, brew, coffee)
|
||||
---
|
||||
Input: Jane Cook is sleeping.
|
||||
Output: (Jane Cook, is, sleep)
|
||||
---
|
||||
Input: Michael Bernstein is writing email on a computer.
|
||||
Output: (Michael Bernstein, write, email)
|
||||
---
|
||||
Input: Percy Liang is teaching students in a classroom.
|
||||
Output: (Percy Liang, teach, students)
|
||||
---
|
||||
Input: Merrie Morris is running on a treadmill.
|
||||
Output: (Merrie Morris, run, treadmill)
|
||||
---"""
|
||||
s += persona_name + "is" + action + ".\n"
|
||||
s += "(" + persona_name + ","
|
||||
s += sgl.gen(name="Triple", max_tokens=20, stop=")")
|
||||
|
||||
|
||||
def generate_event_triple_prompt(persona_name, action):
|
||||
s = ""
|
||||
s += """Task: Turn the input into (subject, predicate, object).
|
||||
Input: Sam Johnson is eating breakfast.
|
||||
Output: (Dolores Murphy, eat, breakfast)
|
||||
---
|
||||
Input: Joon Park is brewing coffee.
|
||||
Output: (Joon Park, brew, coffee)
|
||||
---
|
||||
Input: Jane Cook is sleeping.
|
||||
Output: (Jane Cook, is, sleep)
|
||||
---
|
||||
Input: Michael Bernstein is writing email on a computer.
|
||||
Output: (Michael Bernstein, write, email)
|
||||
---
|
||||
Input: Percy Liang is teaching students in a classroom.
|
||||
Output: (Percy Liang, teach, students)
|
||||
---
|
||||
Input: Merrie Morris is running on a treadmill.
|
||||
Output: (Merrie Morris, run, treadmill)
|
||||
---"""
|
||||
s += persona_name + "is" + action + ".\n"
|
||||
s += "(" + persona_name + ","
|
||||
return {"prompt": s, "max_tokens": 20, "stop": ")"}
|
||||
|
||||
|
||||
@sgl.function
|
||||
def generate_pronunciatio(s, action):
|
||||
s += "Convert an action description to an emoji (important: use two or less emojis).\n"
|
||||
s += "Action description: " + action + ".\n"
|
||||
s += "Emoji:" + sgl.gen(name="Emoji", max_tokens=6)
|
||||
|
||||
|
||||
def generate_pronunciatio_prompt(action):
|
||||
s = ""
|
||||
s += "Convert an action description to an emoji (important: use two or less emojis).\n"
|
||||
s += "Action description: " + action + ".\n"
|
||||
s += "Emoji:"
|
||||
return {"prompt": s, "max_tokens": 6, "stop": None}
|
||||
|
||||
|
||||
@sgl.function
|
||||
def action_location_sector(
|
||||
s,
|
||||
persona_name,
|
||||
living_sector,
|
||||
living_sector_areas,
|
||||
current_sector,
|
||||
current_sector_areas,
|
||||
daily_plan,
|
||||
sector_options,
|
||||
current_action,
|
||||
next_action,
|
||||
):
|
||||
s += """Task -- choose an appropriate area from the area options for a task at hand.
|
||||
Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
|
||||
Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
|
||||
Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
|
||||
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
||||
* Must be one of the "Area options," verbatim.
|
||||
For taking a walk, Sam Kim should go to the following area: {Johnson Park}
|
||||
---
|
||||
Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.
|
||||
Jane Anderson is currently in {Oak Hill College} that has a classroom, library
|
||||
Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
|
||||
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
||||
* Must be one of the "Area options," verbatim.
|
||||
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
|
||||
---"""
|
||||
s += (
|
||||
persona_name
|
||||
+ " lives in "
|
||||
+ living_sector
|
||||
+ " that has "
|
||||
+ living_sector_areas
|
||||
+ ".\n"
|
||||
)
|
||||
s += (
|
||||
persona_name
|
||||
+ " is currently in "
|
||||
+ current_sector
|
||||
+ " that has "
|
||||
+ current_sector_areas
|
||||
+ ".\n"
|
||||
)
|
||||
s += daily_plan + ".\n"
|
||||
s += "Area options: " + sector_options + ".\n"
|
||||
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
||||
* Must be one of the "Area options," verbatim.\n"""
|
||||
s += (
|
||||
persona_name
|
||||
+ " is "
|
||||
+ current_action
|
||||
+ ". For "
|
||||
+ next_action
|
||||
+ ", "
|
||||
+ persona_name
|
||||
+ " should go to the following area: {"
|
||||
)
|
||||
s += sgl.gen(name="Location", max_tokens=10, stop="}")
|
||||
|
||||
|
||||
def action_location_sector_prompt(
|
||||
persona_name,
|
||||
living_sector,
|
||||
living_sector_areas,
|
||||
current_sector,
|
||||
current_sector_areas,
|
||||
daily_plan,
|
||||
sector_options,
|
||||
current_action,
|
||||
next_action,
|
||||
):
|
||||
s = ""
|
||||
s += """Task -- choose an appropriate area from the area options for a task at hand.
|
||||
Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
|
||||
Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
|
||||
Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
|
||||
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
||||
* Must be one of the "Area options," verbatim.
|
||||
For taking a walk, Sam Kim should go to the following area: {Johnson Park}
|
||||
---
|
||||
Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.
|
||||
Jane Anderson is currently in {Oak Hill College} that has a classroom, library
|
||||
Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
|
||||
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
||||
* Must be one of the "Area options," verbatim.
|
||||
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
|
||||
---"""
|
||||
s += (
|
||||
persona_name
|
||||
+ " lives in "
|
||||
+ living_sector
|
||||
+ " that has "
|
||||
+ living_sector_areas
|
||||
+ ".\n"
|
||||
)
|
||||
s += (
|
||||
persona_name
|
||||
+ " is currently in "
|
||||
+ current_sector
|
||||
+ " that has "
|
||||
+ current_sector_areas
|
||||
+ ".\n"
|
||||
)
|
||||
s += daily_plan + ".\n"
|
||||
s += "Area options: " + sector_options + ".\n"
|
||||
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
||||
* Must be one of the "Area options," verbatim.\n"""
|
||||
s += (
|
||||
persona_name
|
||||
+ " is "
|
||||
+ current_action
|
||||
+ ". For "
|
||||
+ next_action
|
||||
+ ", "
|
||||
+ persona_name
|
||||
+ " should go to the following area: {"
|
||||
)
|
||||
return {"prompt": s, "max_tokens": 10, "stop": "}"}
|
||||
|
||||
|
||||
@sgl.function
|
||||
def action_location_object(
|
||||
s, persona_name, target_sector, target_sector_areas, current_action, next_action
|
||||
):
|
||||
s += """
|
||||
Jane Anderson is in kitchen in Jane Anderson's house.
|
||||
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
|
||||
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
|
||||
For cooking, Jane Anderson should go to the following area in Jane Anderson's house:
|
||||
Answer: {kitchen}
|
||||
---
|
||||
Tom Watson is in common room in Tom Watson's apartment.
|
||||
Tom Watson is going to Hobbs Cafe that has the following areas: {cafe}
|
||||
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
|
||||
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
|
||||
Answer: {cafe}
|
||||
---"""
|
||||
s += (
|
||||
persona_name
|
||||
+ " is going to "
|
||||
+ target_sector
|
||||
+ " that has the following areas: {"
|
||||
+ target_sector_areas
|
||||
+ "}\n"
|
||||
)
|
||||
s += """* Stay in the current area if the activity can be done there.
|
||||
* NEVER go into other people's rooms unless necessary."""
|
||||
s += (
|
||||
persona_name
|
||||
+ " is "
|
||||
+ current_action
|
||||
+ ". For "
|
||||
+ next_action
|
||||
+ ", "
|
||||
+ persona_name
|
||||
+ "should go to the following area in "
|
||||
+ target_sector
|
||||
)
|
||||
s += " (MUST pick one of {" + target_sector_areas + "}):\n"
|
||||
s += "Answer: {" + sgl.gen(name="Area", max_tokens=5, stop="}")
|
||||
|
||||
|
||||
def action_location_object_prompt(
|
||||
persona_name, target_sector, target_sector_areas, current_action, next_action
|
||||
):
|
||||
s = ""
|
||||
s += """
|
||||
Jane Anderson is in kitchen in Jane Anderson's house.
|
||||
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
|
||||
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
|
||||
For cooking, Jane Anderson should go to the following area in Jane Anderson's house:
|
||||
Answer: {kitchen}
|
||||
---
|
||||
Tom Watson is in common room in Tom Watson's apartment.
|
||||
Tom Watson is going to Hobbs Cafe that has the following areas: {cafe}
|
||||
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
|
||||
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
|
||||
Answer: {cafe}
|
||||
---"""
|
||||
s += (
|
||||
persona_name
|
||||
+ " is going to "
|
||||
+ target_sector
|
||||
+ " that has the following areas: {"
|
||||
+ target_sector_areas
|
||||
+ "}\n"
|
||||
)
|
||||
s += """* Stay in the current area if the activity can be done there.
|
||||
* NEVER go into other people's rooms unless necessary."""
|
||||
s += (
|
||||
persona_name
|
||||
+ " is "
|
||||
+ current_action
|
||||
+ ". For "
|
||||
+ next_action
|
||||
+ ", "
|
||||
+ persona_name
|
||||
+ "should go to the following area in "
|
||||
+ target_sector
|
||||
)
|
||||
s += " (MUST pick one of {" + target_sector_areas + "}):\n"
|
||||
s += "Answer: {"
|
||||
return {"prompt": s, "max_tokens": 5, "stop": "}"}
|
||||
80
third_party/sglang/benchmark/generative_agents/bench_other.py
vendored
Normal file
80
third_party/sglang/benchmark/generative_agents/bench_other.py
vendored
Normal file
@@ -0,0 +1,80 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
|
||||
from agent_functions import (
|
||||
action_location_object_prompt,
|
||||
action_location_sector_prompt,
|
||||
generate_event_triple_prompt,
|
||||
generate_pronunciatio_prompt,
|
||||
poignancy_event_prompt,
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
||||
from sglang.utils import dump_state_text, read_jsonl
|
||||
|
||||
|
||||
def main(args):
|
||||
lines = read_jsonl(args.data_path)[: args.num_events]
|
||||
mapping = {
|
||||
"poignancy_event": poignancy_event_prompt,
|
||||
"generate_event_triple": generate_event_triple_prompt,
|
||||
"generate_pronunciatio": generate_pronunciatio_prompt,
|
||||
"action_location_sector": action_location_sector_prompt,
|
||||
"action_location_object": action_location_object_prompt,
|
||||
}
|
||||
|
||||
arguments = [mapping[k](**v) for l in lines for k, v in l.items()]
|
||||
states = []
|
||||
|
||||
# Select backend
|
||||
call_generate = get_call_generate(args)
|
||||
|
||||
def get_one_answer(arg):
|
||||
answer = call_generate(**arg, temperature=0)
|
||||
states.append(answer)
|
||||
|
||||
async def get_one_answer_async(arg):
|
||||
answer = await call_generate(**arg, temperature=0)
|
||||
states.append(answer)
|
||||
|
||||
tic = time.perf_counter()
|
||||
# we always sequentially execute agent calls to maintain its dependency
|
||||
if args.backend != "lmql":
|
||||
for arg in tqdm(arguments):
|
||||
get_one_answer(arg)
|
||||
else:
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
for arg in tqdm(arguments):
|
||||
loop.run_until_complete(get_one_answer_async(arg))
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
print(f"Latency: {latency:.3f}")
|
||||
|
||||
# Write results
|
||||
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||
|
||||
with open(args.result_file, "a") as fout:
|
||||
value = {
|
||||
"task": "Generative Agents",
|
||||
"backend": args.backend,
|
||||
"num_gpus": 1,
|
||||
"latency": round(latency, 3),
|
||||
# to pack weighted functions as a single agent
|
||||
"num_requests": len(arguments) / len(mapping),
|
||||
"other": {
|
||||
"parallel": args.parallel,
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-path", type=str, default="agent_calls.jsonl")
|
||||
parser.add_argument("--num-events", type=int, default=10)
|
||||
args = add_common_other_args_and_parse(parser)
|
||||
main(args)
|
||||
74
third_party/sglang/benchmark/generative_agents/bench_sglang.py
vendored
Normal file
74
third_party/sglang/benchmark/generative_agents/bench_sglang.py
vendored
Normal file
@@ -0,0 +1,74 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
|
||||
from agent_functions import (
|
||||
action_location_object,
|
||||
action_location_sector,
|
||||
generate_event_triple,
|
||||
generate_pronunciatio,
|
||||
poignancy_event,
|
||||
)
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
from sglang.utils import dump_state_text, read_jsonl
|
||||
|
||||
|
||||
def main(args):
|
||||
lines = read_jsonl(args.data_path)[: args.num_events]
|
||||
mapping = {
|
||||
"poignancy_event": poignancy_event,
|
||||
"generate_event_triple": generate_event_triple,
|
||||
"generate_pronunciatio": generate_pronunciatio,
|
||||
"action_location_sector": action_location_sector,
|
||||
"action_location_object": action_location_object,
|
||||
}
|
||||
arguments = [{mapping[k]: v for k, v in l.items()} for l in lines]
|
||||
|
||||
# Select backend
|
||||
backend = select_sglang_backend(args)
|
||||
sgl.set_default_backend(backend)
|
||||
|
||||
states = []
|
||||
# Run requests
|
||||
tic = time.perf_counter()
|
||||
for a in arguments:
|
||||
# only a single key in the dict
|
||||
for func, arg in a.items():
|
||||
result = func.run(**arg)
|
||||
result.sync()
|
||||
states.append(result)
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
# Compute accuracy
|
||||
print(f"Latency: {latency:.3f}")
|
||||
|
||||
# Write results
|
||||
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||
|
||||
with open(args.result_file, "a") as fout:
|
||||
value = {
|
||||
"task": "Generative Agents",
|
||||
"backend": args.backend,
|
||||
"num_gpus": 1,
|
||||
"latency": round(latency, 3),
|
||||
# to pack weighted functions as a single agent
|
||||
"num_requests": len(arguments) / len(mapping),
|
||||
"other": {
|
||||
"num_events": args.num_events,
|
||||
"parallel": args.parallel,
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-path", type=str, default="agent_calls.jsonl")
|
||||
parser.add_argument("--num-events", type=int, default=10)
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
main(args)
|
||||
163
third_party/sglang/benchmark/gpt_oss/README.md
vendored
Normal file
163
third_party/sglang/benchmark/gpt_oss/README.md
vendored
Normal file
@@ -0,0 +1,163 @@
|
||||
# How to reproduce the result of GPT-OSS with SGLang
|
||||
|
||||
### Install the latest SGLang
|
||||
|
||||
```bash
|
||||
git clone https://github.com/sgl-project/sglang.git
|
||||
cd sglang
|
||||
git checkout v0.5.1.post3
|
||||
|
||||
pip install --upgrade pip
|
||||
pip install -e "python[all]"
|
||||
```
|
||||
|
||||
### Reproduce the benchmark throughput result (Batch Size 1)
|
||||
|
||||
Launch Command
|
||||
|
||||
```bash
|
||||
# MXFP4 120B on H100
|
||||
python3 -m sglang.launch_server --model openai/gpt-oss-120b --tp 8 --attention-backend triton
|
||||
|
||||
# BF16 120B on H100
|
||||
python3 -m sglang.launch_server --model lmsys/gpt-oss-120b-bf16 --tp 8 --attention-backend triton
|
||||
|
||||
# MXFP4 120B on B200
|
||||
python3 -m sglang.launch_server --model openai/gpt-oss-120b --tp 4
|
||||
|
||||
# BF16 120B on B200
|
||||
python3 -m sglang.launch_server --model lmsys/gpt-oss-120b-bf16 --tp 4
|
||||
```
|
||||
|
||||
Benchmark Command
|
||||
|
||||
```bash
|
||||
|
||||
# MXFP4 120B on H100
|
||||
python3 -m sglang.bench_one_batch_server --model openai/gpt-oss-120b --base-url http://localhost:30000 --batch-size 1 --input-len 1024 --output-len 512 --show-report
|
||||
```
|
||||
|
||||
### Reproduce the benchmark throughput result (Batch Size 32)
|
||||
|
||||
Launch Command
|
||||
|
||||
```bash
|
||||
# MXFP4 120B on H100
|
||||
python3 -m sglang.launch_server --model openai/gpt-oss-120b --tp 8
|
||||
|
||||
# BF16 120B on H100
|
||||
python3 -m sglang.launch_server --model lmsys/gpt-oss-120b-bf16 --tp 8
|
||||
|
||||
# MXFP4 120B on B200
|
||||
python3 -m sglang.launch_server --model openai/gpt-oss-120b --tp 4
|
||||
|
||||
# BF16 120B on B200
|
||||
python3 -m sglang.launch_server --model lmsys/gpt-oss-120b-bf16 --tp 4
|
||||
```
|
||||
|
||||
Benchmark Command
|
||||
|
||||
```bash
|
||||
python3 -m sglang.bench_one_batch_server --model openai/gpt-oss-120b --base-url http://localhost:30000 --batch-size 32 --input-len 1024 8192 --output-len 512 --show-report
|
||||
```
|
||||
|
||||
### Reproduce the evaluation result
|
||||
|
||||
Install gpt-oss
|
||||
|
||||
```bash
|
||||
git clone https://github.com/openai/gpt-oss.git
|
||||
cd gpt-oss
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Evaluation Command
|
||||
|
||||
```bash
|
||||
DATASET=gpqa
|
||||
BASE_URL=YOUR_BASE_URL
|
||||
OPENAI_API_KEY=dummy python -m gpt_oss.evals \
|
||||
--base-url ${BASE_URL}/v1 \
|
||||
--model dummy \
|
||||
--reasoning-effort low,medium,high \
|
||||
--eval $DATASET \
|
||||
--n-threads 1000
|
||||
```
|
||||
|
||||
### Reproduce the benchmark result of acceptance length
|
||||
> Note: On B200, if top k is 1, set `--attention-backend trtllm_mha`
|
||||
```bash
|
||||
git clone https://github.com/sgl-project/SpecForge.git
|
||||
cd SpecForge/benchmarks
|
||||
config_list=(
|
||||
"1,0,0,0"
|
||||
"1,3,1,4"
|
||||
"1,5,4,8"
|
||||
)
|
||||
python3 bench_model_speedup.py \
|
||||
--model-path openai/gpt-oss-120b \
|
||||
--speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 \
|
||||
--port 20001 \
|
||||
--trust-remote-code \
|
||||
--mem-fraction-static 0.8 \
|
||||
--tp-size 4 \
|
||||
--attention-backend fa3 \
|
||||
--config-list "${config_list[@]}" \
|
||||
--benchmark-list mtbench:80 gsm8k:200 humaneval:200 math500:200 \
|
||||
--output lmsys_gpt-oss-120b_Eagle3_result.jsonl
|
||||
|
||||
python3 bench_model_speedup.py \
|
||||
--model-path openai/gpt-oss-120b \
|
||||
--speculative-draft-model-path nvidia/gpt-oss-120b-Eagle3 \
|
||||
--port 20001 \
|
||||
--trust-remote-code \
|
||||
--mem-fraction-static 0.8 \
|
||||
--tp-size 4 \
|
||||
--attention-backend fa3 \
|
||||
--config-list "${config_list[@]}" \
|
||||
--benchmark-list mtbench:80 gsm8k:200 humaneval:200 math500:200 \
|
||||
--output nv_gpt-oss-120b_Eagle3_result.jsonl
|
||||
```
|
||||
|
||||
### Reproduce the result of speculative decoding speedup
|
||||
|
||||
Launch Command
|
||||
|
||||
```bash
|
||||
# On Hopper:
|
||||
# - Tree decoding (topk > 1) and chain decoding (topk = 1) are supported on both FA3 and Triton backends.
|
||||
python3 -m sglang.launch_server --model openai/gpt-oss-120b --speculative-algorithm EAGLE3 --speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --tp 4
|
||||
python3 -m sglang.launch_server --model openai/gpt-oss-120b --speculative-algorithm EAGLE3 --speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 --speculative-num-steps 5 --speculative-eagle-topk 4 --speculative-num-draft-tokens 8 --tp 4
|
||||
|
||||
# On Blackwell:
|
||||
# - Chain decoding (topk = 1) is supported on TRTLLM-MHA backend. Tree decoding (topk > 1) is in progress, stay tuned!
|
||||
# - Both tree decoding (topk > 1) and chain decoding (topk = 1) are supported on the Triton backend.
|
||||
python3 -m sglang.launch_server --model openai/gpt-oss-120b --speculative-algo EAGLE3 --speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --tp 4
|
||||
python3 -m sglang.launch_server --model openai/gpt-oss-120b --speculative-algo EAGLE3 --speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 --speculative-num-steps 5 --speculative-eagle-topk 4 --speculative-num-draft-tokens 8 --attention-backend triton --tp 4
|
||||
```
|
||||
|
||||
Benchmark Command
|
||||
|
||||
```bash
|
||||
config_list=(
|
||||
"1,0,0,0"
|
||||
"1,3,1,4"
|
||||
"1,5,4,8"
|
||||
)
|
||||
python3 bench_model_speedup.py \
|
||||
--model-path openai/gpt-oss-120b \
|
||||
--speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 \
|
||||
--port 20001 \
|
||||
--trust-remote-code \
|
||||
--mem-fraction-static 0.8 \
|
||||
--tp-size 4 \
|
||||
--attention-backend fa3 \
|
||||
--config-list "${config_list[@]}" \
|
||||
--benchmark-list gsm8k:200 humaneval:200 math500:200 \
|
||||
--output lmsys_gpt-oss-120b_Eagle3_result.jsonl
|
||||
```
|
||||
|
||||
We can gain the best speedup with the following settings:
|
||||
|
||||
- **1.39x** speedup with the `--speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4` setting.
|
||||
- **1.52x** speedup with the `--speculative-num-steps 5 --speculative-eagle-topk 4 --speculative-num-draft-tokens 8` setting.
|
||||
57
third_party/sglang/benchmark/gsm8k/README.md
vendored
Normal file
57
third_party/sglang/benchmark/gsm8k/README.md
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
## Run benchmark
|
||||
|
||||
### Using GSM8K Platinum
|
||||
|
||||
GSM8K Platinum is a revised version of the GSM8K test set with corrected labels and removed ambiguous questions. It can be more stable than the original GSM8K dataset. It's a drop-in replacement that can be used by adding the `--platinum` flag:
|
||||
|
||||
```
|
||||
python3 bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
|
||||
```
|
||||
|
||||
For more information, see: https://huggingface.co/datasets/madrylab/gsm8k-platinum
|
||||
|
||||
### Benchmark sglang
|
||||
```
|
||||
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_sglang.py --num-questions 200
|
||||
```
|
||||
|
||||
|
||||
### Benchmark vllm
|
||||
```
|
||||
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_other.py --num-questions 200 --backend vllm
|
||||
```
|
||||
|
||||
|
||||
### Benchmark lightllm
|
||||
```
|
||||
# A10G
|
||||
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_other.py --num-questions 200 --backend lightllm
|
||||
```
|
||||
|
||||
|
||||
### Benchmark guidance
|
||||
```
|
||||
python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
||||
```
|
||||
|
||||
|
||||
### Benchmark lmql
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_other.py --num-questions 100 --backend lmql --parallel 2
|
||||
```
|
||||
164
third_party/sglang/benchmark/gsm8k/bench_other.py
vendored
Normal file
164
third_party/sglang/benchmark/gsm8k/bench_other.py
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
import argparse
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
||||
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
|
||||
|
||||
INVALID = -9999999
|
||||
|
||||
|
||||
def get_one_example(lines, i, include_answer):
|
||||
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
|
||||
if include_answer:
|
||||
ret += " " + lines[i]["answer"]
|
||||
return ret
|
||||
|
||||
|
||||
def get_few_shot_examples(lines, k):
|
||||
ret = ""
|
||||
for i in range(k):
|
||||
ret += get_one_example(lines, i, True) + "\n\n"
|
||||
return ret
|
||||
|
||||
|
||||
def get_answer_value(answer_str):
|
||||
answer_str = answer_str.replace(",", "")
|
||||
numbers = re.findall(r"\d+", answer_str)
|
||||
if len(numbers) < 1:
|
||||
return INVALID
|
||||
try:
|
||||
return ast.literal_eval(numbers[-1])
|
||||
except SyntaxError:
|
||||
return INVALID
|
||||
|
||||
|
||||
def main(args):
|
||||
# Select backend
|
||||
call_generate = get_call_generate(args)
|
||||
|
||||
# Read data
|
||||
if args.platinum:
|
||||
print("Loading GSM8K Platinum dataset from HuggingFace...")
|
||||
dataset = load_dataset("madrylab/gsm8k-platinum", "main", split="test")
|
||||
lines = [
|
||||
{"question": item["question"], "answer": item["answer"]} for item in dataset
|
||||
]
|
||||
else:
|
||||
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
|
||||
filename = download_and_cache_file(url)
|
||||
lines = list(read_jsonl(filename))
|
||||
|
||||
# Construct prompts
|
||||
num_questions = args.num_questions
|
||||
num_shots = args.num_shots
|
||||
few_shot_examples = get_few_shot_examples(lines, num_shots)
|
||||
|
||||
questions = []
|
||||
labels = []
|
||||
for i in range(len(lines[:num_questions])):
|
||||
questions.append(get_one_example(lines, i, False))
|
||||
labels.append(get_answer_value(lines[i]["answer"]))
|
||||
assert all(l != INVALID for l in labels)
|
||||
|
||||
states = [None] * len(labels)
|
||||
|
||||
# Run requests
|
||||
if args.backend != "lmql":
|
||||
# Use thread pool
|
||||
def get_one_answer(i):
|
||||
answer = call_generate(
|
||||
prompt=few_shot_examples + questions[i],
|
||||
temperature=0,
|
||||
max_tokens=256,
|
||||
stop=["Question", "Assistant:", "<|separator|>"],
|
||||
)
|
||||
states[i] = answer
|
||||
|
||||
tic = time.perf_counter()
|
||||
if args.parallel == 1:
|
||||
for i in tqdm(range(len(questions))):
|
||||
get_one_answer(i)
|
||||
else:
|
||||
with ThreadPoolExecutor(args.parallel) as executor:
|
||||
list(
|
||||
tqdm(
|
||||
executor.map(get_one_answer, list(range(len(questions)))),
|
||||
total=len(questions),
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
# Use asyncio
|
||||
async def batched_call(batch_size):
|
||||
for i in range(0, len(questions), batch_size):
|
||||
tasks = []
|
||||
for q in questions[i : i + batch_size]:
|
||||
tasks.append(
|
||||
call_generate(
|
||||
few_shot_examples + q,
|
||||
temperature=0,
|
||||
max_tokens=256,
|
||||
stop="Question",
|
||||
)
|
||||
)
|
||||
rets = await asyncio.gather(*tasks)
|
||||
for j in range(len(rets)):
|
||||
states[i + j] = rets[j]
|
||||
|
||||
tic = time.perf_counter()
|
||||
asyncio.run(batched_call(batch_size=args.parallel))
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
preds = []
|
||||
for i in range(len(states)):
|
||||
preds.append(get_answer_value(states[i]))
|
||||
|
||||
# Compute accuracy
|
||||
acc = np.mean(np.array(preds) == np.array(labels))
|
||||
invalid = np.mean(np.array(preds) == INVALID)
|
||||
|
||||
# Print results
|
||||
print(f"Accuracy: {acc:.3f}")
|
||||
print(f"Invalid: {invalid:.3f}")
|
||||
print(f"Latency: {latency:.3f} s")
|
||||
|
||||
# Dump results
|
||||
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||
|
||||
with open(args.result_file, "a") as fout:
|
||||
value = {
|
||||
"task": "gsm8k-platinum" if args.platinum else "gsm8k",
|
||||
"backend": args.backend,
|
||||
"num_gpus": 1,
|
||||
"latency": round(latency, 3),
|
||||
"accuracy": round(acc, 3),
|
||||
"num_requests": args.num_questions,
|
||||
"other": {
|
||||
"num_questions": args.num_questions,
|
||||
"parallel": args.parallel,
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-shots", type=int, default=5)
|
||||
parser.add_argument("--data-path", type=str, default="test.jsonl")
|
||||
parser.add_argument("--num-questions", type=int, default=200)
|
||||
parser.add_argument(
|
||||
"--platinum",
|
||||
action="store_true",
|
||||
help="Use GSM8K Platinum dataset (drop-in replacement with corrected labels)",
|
||||
)
|
||||
args = add_common_other_args_and_parse(parser)
|
||||
main(args)
|
||||
199
third_party/sglang/benchmark/gsm8k/bench_sglang.py
vendored
Normal file
199
third_party/sglang/benchmark/gsm8k/bench_sglang.py
vendored
Normal file
@@ -0,0 +1,199 @@
|
||||
import argparse
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from sglang.lang.api import set_default_backend
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
dump_bench_raw_result,
|
||||
select_sglang_backend,
|
||||
)
|
||||
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
|
||||
|
||||
INVALID = -9999999
|
||||
|
||||
|
||||
def get_one_example(lines, i, include_answer):
|
||||
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
|
||||
if include_answer:
|
||||
ret += " " + lines[i]["answer"]
|
||||
return ret
|
||||
|
||||
|
||||
def get_few_shot_examples(lines, k):
|
||||
ret = ""
|
||||
for i in range(k):
|
||||
ret += get_one_example(lines, i, True) + "\n\n"
|
||||
return ret
|
||||
|
||||
|
||||
def get_answer_value(answer_str):
|
||||
answer_str = answer_str.replace(",", "")
|
||||
numbers = re.findall(r"\d+", answer_str)
|
||||
if len(numbers) < 1:
|
||||
return INVALID
|
||||
try:
|
||||
return ast.literal_eval(numbers[-1])
|
||||
except SyntaxError:
|
||||
return INVALID
|
||||
|
||||
|
||||
def main(args):
|
||||
# Select backend
|
||||
set_default_backend(select_sglang_backend(args))
|
||||
|
||||
# Load tokenizer if enable_thinking is set
|
||||
tokenizer = None
|
||||
if args.enable_thinking:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
assert (
|
||||
args.tokenizer_path is not None
|
||||
), "--tokenizer-path is required when --enable-thinking is set"
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_path, trust_remote_code=True
|
||||
)
|
||||
|
||||
# Read data
|
||||
if args.platinum:
|
||||
print("Loading GSM8K Platinum dataset from HuggingFace...")
|
||||
dataset = load_dataset("madrylab/gsm8k-platinum", "main", split="test")
|
||||
lines = [
|
||||
{"question": item["question"], "answer": item["answer"]} for item in dataset
|
||||
]
|
||||
else:
|
||||
data_path = args.data_path
|
||||
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
|
||||
if not os.path.isfile(data_path):
|
||||
data_path = download_and_cache_file(url)
|
||||
lines = list(read_jsonl(data_path))
|
||||
|
||||
# Construct prompts
|
||||
num_questions = args.num_questions
|
||||
num_shots = args.num_shots
|
||||
few_shot_examples = get_few_shot_examples(lines, num_shots)
|
||||
|
||||
questions = []
|
||||
labels = []
|
||||
for i in range(len(lines[:num_questions])):
|
||||
raw_question = few_shot_examples + get_one_example(lines, i, False)
|
||||
if tokenizer is not None:
|
||||
messages = [{"role": "user", "content": raw_question}]
|
||||
raw_question = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
)
|
||||
questions.append(raw_question)
|
||||
labels.append(get_answer_value(lines[i]["answer"]))
|
||||
assert all(l != INVALID for l in labels)
|
||||
arguments = [{"question": q} for q in questions]
|
||||
|
||||
#####################################
|
||||
######### SGL Program Begin #########
|
||||
#####################################
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
@sgl.function
|
||||
def few_shot_gsm8k(s, question):
|
||||
s += question
|
||||
s += sgl.gen(
|
||||
"answer",
|
||||
max_tokens=args.max_new_tokens,
|
||||
stop=["Question", "Assistant:", "<|separator|>"],
|
||||
)
|
||||
|
||||
#####################################
|
||||
########## SGL Program End ##########
|
||||
#####################################
|
||||
|
||||
# Run requests
|
||||
tic = time.perf_counter()
|
||||
states = few_shot_gsm8k.run_batch(
|
||||
arguments,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
num_threads=args.parallel,
|
||||
progress_bar=True,
|
||||
)
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
preds = []
|
||||
for i in range(len(states)):
|
||||
preds.append(get_answer_value(states[i]["answer"]))
|
||||
|
||||
# Compute accuracy
|
||||
acc = np.mean(np.array(preds) == np.array(labels))
|
||||
invalid = np.mean(np.array(preds) == INVALID)
|
||||
|
||||
# Compute speed
|
||||
num_output_tokens = sum(
|
||||
s.get_meta_info("answer")["completion_tokens"] for s in states
|
||||
)
|
||||
output_throughput = num_output_tokens / latency
|
||||
|
||||
# Print results
|
||||
print(f"Accuracy: {acc:.3f}")
|
||||
print(f"Invalid: {invalid:.3f}")
|
||||
print(f"Latency: {latency:.3f} s")
|
||||
print(f"Output throughput: {output_throughput:.3f} token/s")
|
||||
|
||||
# Dump results
|
||||
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||
dump_bench_raw_result(
|
||||
path=args.raw_result_file,
|
||||
states=states,
|
||||
preds=preds,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
with open(args.result_file, "a") as fout:
|
||||
value = {
|
||||
"task": "gsm8k-platinum" if args.platinum else "gsm8k",
|
||||
"backend": args.backend,
|
||||
"num_gpus": 1,
|
||||
"latency": round(latency, 3),
|
||||
"accuracy": round(acc, 3),
|
||||
"num_requests": args.num_questions,
|
||||
"other": {
|
||||
"num_questions": args.num_questions,
|
||||
"parallel": args.parallel,
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-shots", type=int, default=5)
|
||||
parser.add_argument("--data-path", type=str, default="test.jsonl")
|
||||
parser.add_argument("--num-questions", type=int, default=200)
|
||||
parser.add_argument("--max-new-tokens", type=int, default=512)
|
||||
parser.add_argument("--temperature", type=float, default=0.0)
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument(
|
||||
"--enable-thinking",
|
||||
action="store_true",
|
||||
help="Enable thinking mode by wrapping prompts with chat template",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to tokenizer (required when --enable-thinking is set)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--platinum",
|
||||
action="store_true",
|
||||
help="Use GSM8K Platinum dataset (drop-in replacement with corrected labels)",
|
||||
)
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
main(args)
|
||||
47
third_party/sglang/benchmark/hellaswag/README.md
vendored
Normal file
47
third_party/sglang/benchmark/hellaswag/README.md
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
## Run benchmark
|
||||
|
||||
### Benchmark sglang
|
||||
```
|
||||
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_sglang.py --num-questions 200
|
||||
```
|
||||
|
||||
|
||||
### Benchmark vllm
|
||||
```
|
||||
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_other.py --num-questions 200 --backend vllm
|
||||
```
|
||||
|
||||
|
||||
### Benchmark lightllm
|
||||
```
|
||||
# A10G
|
||||
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_other.py --num-questions 200 --backend lightllm
|
||||
```
|
||||
|
||||
|
||||
### Benchmark guidance
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES=0,1 python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
||||
```
|
||||
|
||||
|
||||
### Benchmark lmql
|
||||
```
|
||||
lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_other.py --num-questions 200 --backend lmql --port 23000 --parallel 1
|
||||
```
|
||||
118
third_party/sglang/benchmark/hellaswag/bench_other.py
vendored
Normal file
118
third_party/sglang/benchmark/hellaswag/bench_other.py
vendored
Normal file
@@ -0,0 +1,118 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select
|
||||
from sglang.utils import download_and_cache_file, read_jsonl
|
||||
|
||||
|
||||
def get_one_example(lines, i, include_answer):
|
||||
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
|
||||
if include_answer:
|
||||
ret += lines[i]["endings"][lines[i]["label"]]
|
||||
return ret
|
||||
|
||||
|
||||
def get_few_shot_examples(lines, k):
|
||||
ret = ""
|
||||
for i in range(k):
|
||||
ret += get_one_example(lines, i, True) + "\n\n"
|
||||
return ret
|
||||
|
||||
|
||||
def main(args):
|
||||
# Select backend
|
||||
call_select = get_call_select(args)
|
||||
|
||||
# Read data
|
||||
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
|
||||
filename = download_and_cache_file(url)
|
||||
lines = list(read_jsonl(filename))
|
||||
|
||||
# Construct prompts
|
||||
num_questions = args.num_questions
|
||||
num_shots = args.num_shots
|
||||
few_shot_examples = get_few_shot_examples(lines, num_shots)
|
||||
|
||||
questions = []
|
||||
choices = []
|
||||
labels = []
|
||||
for i in range(len(lines[:num_questions])):
|
||||
questions.append(get_one_example(lines, i, False))
|
||||
choices.append(lines[i]["endings"])
|
||||
labels.append(lines[i]["label"])
|
||||
|
||||
preds = [None] * len(labels)
|
||||
|
||||
# Run requests
|
||||
if args.backend != "lmql":
|
||||
# Use thread pool
|
||||
def get_one_answer(i):
|
||||
preds[i] = call_select(
|
||||
context=few_shot_examples + questions[i], choices=choices[i]
|
||||
)
|
||||
|
||||
tic = time.perf_counter()
|
||||
if args.parallel == 1:
|
||||
for i in tqdm(range(len(questions))):
|
||||
get_one_answer(i)
|
||||
else:
|
||||
with ThreadPoolExecutor(args.parallel) as executor:
|
||||
list(
|
||||
tqdm(
|
||||
executor.map(get_one_answer, list(range(len(questions)))),
|
||||
total=len(questions),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Use asyncio
|
||||
async def batched_call(batch_size):
|
||||
for i in range(0, len(questions), batch_size):
|
||||
tasks = []
|
||||
for q, c in zip(
|
||||
questions[i : i + batch_size], choices[i : i + batch_size]
|
||||
):
|
||||
tasks.append(call_select(context=few_shot_examples + q, choices=c))
|
||||
rets = await asyncio.gather(*tasks)
|
||||
for j in range(len(rets)):
|
||||
preds[i + j] = rets[j]
|
||||
|
||||
tic = time.perf_counter()
|
||||
asyncio.run(batched_call(batch_size=args.parallel))
|
||||
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
# Compute accuracy
|
||||
acc = np.mean(np.array(preds) == np.array(labels))
|
||||
print(f"Latency: {latency:.3f}")
|
||||
print(f"Accuracy: {acc:.3f}")
|
||||
|
||||
# Write results
|
||||
with open(args.result_file, "a") as fout:
|
||||
value = {
|
||||
"task": "hellaswag",
|
||||
"backend": args.backend,
|
||||
"num_gpus": 1,
|
||||
"latency": round(latency, 3),
|
||||
"accuracy": round(acc, 3),
|
||||
"num_requests": args.num_questions,
|
||||
"other": {
|
||||
"num_questions": args.num_questions,
|
||||
"parallel": args.parallel,
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-shots", type=int, default=20)
|
||||
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
|
||||
parser.add_argument("--num-questions", type=int, default=200)
|
||||
args = add_common_other_args_and_parse(parser)
|
||||
main(args)
|
||||
109
third_party/sglang/benchmark/hellaswag/bench_sglang.py
vendored
Normal file
109
third_party/sglang/benchmark/hellaswag/bench_sglang.py
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.lang.api import set_default_backend
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
from sglang.utils import download_and_cache_file, read_jsonl
|
||||
|
||||
|
||||
def get_one_example(lines, i, include_answer):
|
||||
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
|
||||
if include_answer:
|
||||
ret += lines[i]["endings"][lines[i]["label"]]
|
||||
return ret
|
||||
|
||||
|
||||
def get_few_shot_examples(lines, k):
|
||||
ret = ""
|
||||
for i in range(k):
|
||||
ret += get_one_example(lines, i, True) + "\n\n"
|
||||
return ret
|
||||
|
||||
|
||||
def main(args):
|
||||
# Select backend
|
||||
set_default_backend(select_sglang_backend(args))
|
||||
|
||||
# Read data
|
||||
data_path = args.data_path
|
||||
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
|
||||
if not os.path.isfile(data_path):
|
||||
data_path = download_and_cache_file(url)
|
||||
lines = list(read_jsonl(data_path))
|
||||
|
||||
# Construct prompts
|
||||
num_questions = args.num_questions
|
||||
num_shots = args.num_shots
|
||||
few_shot_examples = get_few_shot_examples(lines, num_shots)
|
||||
|
||||
questions = []
|
||||
choices = []
|
||||
labels = []
|
||||
for i in range(len(lines[:num_questions])):
|
||||
questions.append(get_one_example(lines, i, False))
|
||||
choices.append(lines[i]["endings"])
|
||||
labels.append(lines[i]["label"])
|
||||
arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]
|
||||
|
||||
#####################################
|
||||
######### SGL Program Begin #########
|
||||
#####################################
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
@sgl.function
|
||||
def few_shot_hellaswag(s, question, choices):
|
||||
s += few_shot_examples + question
|
||||
s += sgl.select("answer", choices=choices)
|
||||
|
||||
#####################################
|
||||
########## SGL Program End ##########
|
||||
#####################################
|
||||
|
||||
# Run requests
|
||||
tic = time.perf_counter()
|
||||
rets = few_shot_hellaswag.run_batch(
|
||||
arguments,
|
||||
temperature=0,
|
||||
num_threads=args.parallel,
|
||||
progress_bar=True,
|
||||
)
|
||||
preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
# Compute accuracy
|
||||
acc = np.mean(np.array(preds) == np.array(labels))
|
||||
print(f"Latency: {latency:.3f}")
|
||||
print(f"Accuracy: {acc:.3f}")
|
||||
|
||||
# Write results
|
||||
with open(args.result_file, "a") as fout:
|
||||
value = {
|
||||
"task": "hellaswag",
|
||||
"backend": args.backend,
|
||||
"num_gpus": 1,
|
||||
"latency": round(latency, 3),
|
||||
"accuracy": round(acc, 3),
|
||||
"num_requests": args.num_questions,
|
||||
"other": {
|
||||
"num_questions": args.num_questions,
|
||||
"parallel": args.parallel,
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-shots", type=int, default=20)
|
||||
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
|
||||
parser.add_argument("--num-questions", type=int, default=200)
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
main(args)
|
||||
59
third_party/sglang/benchmark/hf3fs/bench.sh
vendored
Normal file
59
third_party/sglang/benchmark/hf3fs/bench.sh
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
|
||||
python3 benchmark/hf3fs/bench_client.py
|
||||
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
|
||||
SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json \
|
||||
python3 benchmark/hf3fs/bench_storage.py
|
||||
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
|
||||
export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json
|
||||
echo '{"file_path_prefix": "/data/hf3fs-test-0", "file_size": 1099511627776, "numjobs": 16, "entries": 8}' > \
|
||||
${SGLANG_HICACHE_HF3FS_CONFIG_PATH}
|
||||
python3 benchmark/hf3fs/bench_zerocopy.py
|
||||
|
||||
####################################################################################################
|
||||
|
||||
rm -rf nohup.out && \
|
||||
nohup python3 -m sglang.launch_server \
|
||||
--model-path /code/models/Qwen3-32B/ \
|
||||
--host 0.0.0.0 --port 33301 \
|
||||
--page-size 64 \
|
||||
--enable-hierarchical-cache \
|
||||
--hicache-ratio 2 --hicache-size 0 \
|
||||
--hicache-write-policy write_through \
|
||||
--hicache-storage-backend hf3fs &
|
||||
|
||||
rm -rf bench_multiturn.out && \
|
||||
nohup python3 benchmark/hicache/bench_multiturn.py \
|
||||
--model-path /code/models/Qwen3-32B \
|
||||
--dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--port 33301 \
|
||||
--request-length 2048 --num-clients 512 --num-rounds 3 --max-parallel 8 \
|
||||
> bench_multiturn.out &
|
||||
|
||||
####################################################################################################
|
||||
|
||||
rm -rf nohup.out && \
|
||||
nohup python3 -m sglang.launch_server \
|
||||
--model-path /code/models/DeepSeek-R1/ \
|
||||
--tp 16 --nnodes 2 --node-rank 0 \
|
||||
--dist-init-addr 10.74.249.153:5000 \
|
||||
--host 0.0.0.0 --port 33301 \
|
||||
--page-size 64 \
|
||||
--enable-hierarchical-cache \
|
||||
--hicache-ratio 2 --hicache-size 60 \
|
||||
--hicache-write-policy write_through \
|
||||
--hicache-storage-backend hf3fs &
|
||||
|
||||
rm -rf bench_multiturn.out && \
|
||||
nohup python3 benchmark/hicache/bench_multiturn.py \
|
||||
--model-path /code/models/Qwen3-32B \
|
||||
--dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--port 33301 \
|
||||
--request-length 2048 --num-clients 1024 --num-rounds 3 --max-parallel 8 \
|
||||
> bench_multiturn.out &
|
||||
|
||||
####################################################################################################
|
||||
|
||||
ps aux | grep "sglang.launch_server" | grep -v grep | awk '{print $2}' | xargs kill -9
|
||||
ps aux | grep "bench_multiturn.py" | grep -v grep | awk '{print $2}' | xargs kill -9
|
||||
162
third_party/sglang/benchmark/hf3fs/bench_client.py
vendored
Normal file
162
third_party/sglang/benchmark/hf3fs/bench_client.py
vendored
Normal file
@@ -0,0 +1,162 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang.srt.mem_cache.storage.hf3fs.hf3fs_usrbio_client import Hf3fsUsrBioClient
|
||||
|
||||
|
||||
def print_stats(x: List[int]):
|
||||
x = sorted(x)
|
||||
lenx = len(x)
|
||||
print(
|
||||
f"mean = {sum(x)/len(x):.2f}, "
|
||||
f"min = {min(x):.2f}, "
|
||||
f"p25 = {x[int(lenx*0.25)]:.2f}, "
|
||||
f"p50 = {x[int(lenx*0.5)]:.2f}, "
|
||||
f"p75 = {x[int(lenx*0.75)]:.2f}, "
|
||||
f"max = {max(x):.2f}"
|
||||
)
|
||||
|
||||
|
||||
def test():
|
||||
# /path/to/hf3fs
|
||||
file_path = "/data/bench.bin"
|
||||
file_size = 1 << 40
|
||||
bytes_per_page = 16 << 20
|
||||
entries = 32
|
||||
file_ops = Hf3fsUsrBioClient(file_path, file_size, bytes_per_page, entries, 5)
|
||||
|
||||
print("test batch_read / batch_write")
|
||||
num_pages = 128
|
||||
dtype = torch.bfloat16
|
||||
numel = bytes_per_page // dtype.itemsize
|
||||
offsets = list(range(file_size // bytes_per_page))
|
||||
random.shuffle(offsets)
|
||||
offsets = offsets[:num_pages]
|
||||
offsets = [i * bytes_per_page for i in offsets]
|
||||
tensor_writes = [
|
||||
torch.randn(numel, dtype=dtype)
|
||||
for _ in tqdm(range(num_pages), desc="prepare tensor")
|
||||
]
|
||||
for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_write"):
|
||||
results = file_ops.batch_write(
|
||||
offsets[i : i + file_ops.entries], tensor_writes[i : i + file_ops.entries]
|
||||
)
|
||||
assert all([result == numel * dtype.itemsize for result in results])
|
||||
tensor_reads = [
|
||||
torch.empty(numel, dtype=dtype)
|
||||
for _ in tqdm(range(num_pages), desc="prepare tensor")
|
||||
]
|
||||
for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_read"):
|
||||
results = file_ops.batch_read(
|
||||
offsets[i : i + file_ops.entries], tensor_reads[i : i + file_ops.entries]
|
||||
)
|
||||
assert all([result == numel * dtype.itemsize for result in results])
|
||||
assert all([torch.allclose(r, w) for r, w in zip(tensor_reads, tensor_writes)])
|
||||
|
||||
file_ops.close()
|
||||
print("test done")
|
||||
|
||||
|
||||
def bench():
|
||||
file_path = "/data/bench.bin"
|
||||
file_size = 1 << 40
|
||||
bytes_per_page = 16 << 20
|
||||
entries = 8
|
||||
numjobs = 16
|
||||
|
||||
dtype = torch.bfloat16
|
||||
numel = bytes_per_page // dtype.itemsize
|
||||
|
||||
file_ops = [
|
||||
Hf3fsUsrBioClient(file_path, file_size, bytes_per_page, entries, 5)
|
||||
for _ in range(numjobs)
|
||||
]
|
||||
|
||||
num_page = entries
|
||||
|
||||
offsets = list(range(file_size // bytes_per_page))
|
||||
tensors_write = [torch.randn(numel, dtype=dtype)] * num_page
|
||||
tensors_read = [torch.empty(numel, dtype=dtype)] * num_page
|
||||
random.shuffle(offsets)
|
||||
|
||||
warmup = 50
|
||||
iteration = 100
|
||||
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=numjobs)
|
||||
|
||||
w_bw = []
|
||||
w_size = num_page * numjobs * bytes_per_page / (1 << 30)
|
||||
for i in tqdm(range(warmup + iteration), desc="Benchmarking write (GB/s)"):
|
||||
_offsets = [
|
||||
[
|
||||
offset * bytes_per_page
|
||||
for offset in offsets[
|
||||
(i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page
|
||||
]
|
||||
]
|
||||
for j in range(numjobs)
|
||||
]
|
||||
tik = time.perf_counter()
|
||||
futures = [
|
||||
executor.submit(file_ops[j].batch_write, offset, tensors_write)
|
||||
for j, offset in enumerate(_offsets)
|
||||
]
|
||||
results = [future.result() for future in futures]
|
||||
tok = time.perf_counter()
|
||||
if i < warmup:
|
||||
continue
|
||||
w_bw.append(w_size / (tok - tik))
|
||||
results = [
|
||||
_result == bytes_per_page for result in results for _result in result
|
||||
]
|
||||
assert all(results)
|
||||
print_stats(w_bw)
|
||||
|
||||
r_bw = []
|
||||
r_size = w_size
|
||||
for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
|
||||
_offsets = [
|
||||
[
|
||||
offset * bytes_per_page
|
||||
for offset in offsets[
|
||||
(i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page
|
||||
]
|
||||
]
|
||||
for j in range(numjobs)
|
||||
]
|
||||
tik = time.perf_counter()
|
||||
futures = [
|
||||
executor.submit(file_ops[j].batch_read, offset, tensors_read)
|
||||
for j, offset in enumerate(_offsets)
|
||||
]
|
||||
results = [future.result() for future in futures]
|
||||
tok = time.perf_counter()
|
||||
if i < warmup:
|
||||
continue
|
||||
r_bw.append(r_size / (tok - tik))
|
||||
results = [
|
||||
_result == bytes_per_page for result in results for _result in result
|
||||
]
|
||||
assert all(results)
|
||||
print_stats(r_bw)
|
||||
|
||||
executor.shutdown(wait=True)
|
||||
for _file_ops in file_ops:
|
||||
_file_ops.close()
|
||||
print("bench done")
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
test()
|
||||
bench()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
258
third_party/sglang/benchmark/hf3fs/bench_storage.py
vendored
Normal file
258
third_party/sglang/benchmark/hf3fs/bench_storage.py
vendored
Normal file
@@ -0,0 +1,258 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
|
||||
Hf3fsLocalMetadataClient,
|
||||
)
|
||||
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
|
||||
|
||||
|
||||
def print_stats(x: List[int]):
|
||||
x = sorted(x)
|
||||
lenx = len(x)
|
||||
print(
|
||||
f"mean = {sum(x)/len(x):.2f}, "
|
||||
f"min = {min(x):.2f}, "
|
||||
f"p25 = {x[int(lenx*0.25)]:.2f}, "
|
||||
f"p50 = {x[int(lenx*0.5)]:.2f}, "
|
||||
f"p75 = {x[int(lenx*0.75)]:.2f}, "
|
||||
f"max = {max(x):.2f}"
|
||||
)
|
||||
|
||||
|
||||
def test():
|
||||
# Qwen3-32B
|
||||
layer_num = 64
|
||||
head_num, head_dim = 8, 128
|
||||
kv_lora_rank, qk_rope_head_dim = 0, 0
|
||||
store_dtype = torch.bfloat16
|
||||
tokens_per_page = 64
|
||||
|
||||
file_path_prefix = "/data/test"
|
||||
file_size = 128 << 20
|
||||
numjobs = 16
|
||||
bytes_per_page = 16 << 20
|
||||
entries = 2
|
||||
dtype = store_dtype
|
||||
|
||||
config_path = os.getenv(HiCacheHF3FS.default_env_var)
|
||||
assert config_path
|
||||
try:
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"file_path_prefix": file_path_prefix,
|
||||
"file_size": file_size,
|
||||
"numjobs": numjobs,
|
||||
"entries": entries,
|
||||
},
|
||||
f,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to dump config to {config_path}: {str(e)}")
|
||||
hicache_hf3fs = HiCacheHF3FS.from_env_config(bytes_per_page, dtype)
|
||||
|
||||
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
|
||||
assert numel * dtype.itemsize == bytes_per_page
|
||||
|
||||
num_pages = 10
|
||||
tensors = {}
|
||||
for i in range(num_pages):
|
||||
k = f"key_{i}"
|
||||
v = torch.randn((numel,)).to(dtype=dtype)
|
||||
ok = hicache_hf3fs.set(k, v)
|
||||
if i < (file_size // bytes_per_page):
|
||||
assert ok, f"Failed to insert {k}"
|
||||
else:
|
||||
assert not ok
|
||||
tensors[k] = v
|
||||
assert hicache_hf3fs.get("key_8") is None
|
||||
assert hicache_hf3fs.get("key_9") is None
|
||||
|
||||
start = 0
|
||||
for i in range(start, start + hicache_hf3fs.num_pages):
|
||||
k = f"key_{i}"
|
||||
assert hicache_hf3fs.exists(k)
|
||||
out = hicache_hf3fs.get(k)
|
||||
assert out is not None
|
||||
v = tensors[k]
|
||||
assert torch.allclose(v, out, atol=1e-3), f"Tensor mismatch for {k}"
|
||||
|
||||
assert not hicache_hf3fs.exists("not_exists")
|
||||
|
||||
hicache_hf3fs.delete("key_7")
|
||||
v2 = torch.randn((numel,)).to(dtype=dtype)
|
||||
assert hicache_hf3fs.set("key_new", v2)
|
||||
assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3)
|
||||
|
||||
hicache_hf3fs.clear()
|
||||
assert (
|
||||
len(hicache_hf3fs.metadata_client.rank_metadata.free_pages)
|
||||
== hicache_hf3fs.metadata_client.rank_metadata.num_pages
|
||||
)
|
||||
|
||||
# batch
|
||||
num_pages = 10
|
||||
tensors = {}
|
||||
keys = []
|
||||
values = []
|
||||
for i in range(num_pages):
|
||||
k = f"key_{i}"
|
||||
keys.append(k)
|
||||
v = torch.randn((numel,)).to(dtype=dtype)
|
||||
values.append(v)
|
||||
|
||||
ok = hicache_hf3fs.batch_set(keys, values)
|
||||
assert not ok
|
||||
assert hicache_hf3fs.get("key_8") is None
|
||||
assert hicache_hf3fs.get("key_9") is None
|
||||
|
||||
results = hicache_hf3fs.batch_get(keys[: hicache_hf3fs.num_pages])
|
||||
for result, key, value in zip(
|
||||
results, keys[: hicache_hf3fs.num_pages], values[: hicache_hf3fs.num_pages]
|
||||
):
|
||||
assert torch.allclose(value, result, atol=1e-3), f"Tensor mismatch for {key}"
|
||||
|
||||
hicache_hf3fs.close()
|
||||
os.remove(hicache_hf3fs.file_path)
|
||||
|
||||
print("All test cases passed.")
|
||||
|
||||
|
||||
def bench():
|
||||
# Qwen3-32B
|
||||
layer_num = 64
|
||||
head_num, head_dim = 8, 128
|
||||
kv_lora_rank, qk_rope_head_dim = 0, 0
|
||||
store_dtype = torch.bfloat16
|
||||
tokens_per_page = 64
|
||||
|
||||
file_path = "/data/test.bin"
|
||||
file_size = 1 << 40
|
||||
numjobs = 16
|
||||
bytes_per_page = 16 << 20
|
||||
entries = 8
|
||||
dtype = store_dtype
|
||||
hicache_hf3fs = HiCacheHF3FS(
|
||||
rank=0,
|
||||
file_path=file_path,
|
||||
file_size=file_size,
|
||||
numjobs=numjobs,
|
||||
bytes_per_page=bytes_per_page,
|
||||
entries=entries,
|
||||
dtype=dtype,
|
||||
metadata_client=Hf3fsLocalMetadataClient(),
|
||||
)
|
||||
|
||||
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
|
||||
assert numel * dtype.itemsize == bytes_per_page
|
||||
|
||||
num_page = 128
|
||||
values = [torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))]
|
||||
|
||||
warmup = 50
|
||||
iteration = 100
|
||||
|
||||
w_bw = []
|
||||
w_size = num_page * bytes_per_page / (1 << 30)
|
||||
for i in tqdm(range(warmup + iteration), desc="Benchmarking write (GB/s)"):
|
||||
keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)]
|
||||
tik = time.perf_counter()
|
||||
ok = hicache_hf3fs.batch_set(keys, values)
|
||||
tok = time.perf_counter()
|
||||
if i < warmup:
|
||||
continue
|
||||
w_bw.append(w_size / (tok - tik))
|
||||
assert ok
|
||||
print_stats(w_bw)
|
||||
|
||||
r_bw = []
|
||||
r_size = num_page * bytes_per_page / (1 << 30)
|
||||
for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
|
||||
keys = random.sample(
|
||||
list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
|
||||
num_page,
|
||||
)
|
||||
tik = time.perf_counter()
|
||||
results = hicache_hf3fs.batch_get(keys)
|
||||
tok = time.perf_counter()
|
||||
if i < warmup:
|
||||
continue
|
||||
r_bw.append(r_size / (tok - tik))
|
||||
assert all([r is not None for r in results])
|
||||
print_stats(r_bw)
|
||||
|
||||
hicache_hf3fs.close()
|
||||
|
||||
|
||||
def allclose():
|
||||
# Qwen3-32B
|
||||
layer_num = 64
|
||||
head_num, head_dim = 8, 128
|
||||
kv_lora_rank, qk_rope_head_dim = 0, 0
|
||||
store_dtype = torch.bfloat16
|
||||
tokens_per_page = 64
|
||||
|
||||
file_path = "/data/test.bin"
|
||||
file_size = 1 << 40
|
||||
numjobs = 16
|
||||
bytes_per_page = 16 << 20
|
||||
entries = 8
|
||||
dtype = store_dtype
|
||||
hicache_hf3fs = HiCacheHF3FS(
|
||||
rank=0,
|
||||
file_path=file_path,
|
||||
file_size=file_size,
|
||||
numjobs=numjobs,
|
||||
bytes_per_page=bytes_per_page,
|
||||
entries=entries,
|
||||
dtype=dtype,
|
||||
metadata_client=Hf3fsLocalMetadataClient(),
|
||||
)
|
||||
|
||||
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
|
||||
assert numel * dtype.itemsize == bytes_per_page
|
||||
|
||||
num_page = 128
|
||||
values = [torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))]
|
||||
|
||||
iteration = 100
|
||||
|
||||
for i in tqdm(range(iteration), desc="Benchmarking write (GB/s)"):
|
||||
keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)]
|
||||
ok = hicache_hf3fs.batch_set(keys, values)
|
||||
assert ok
|
||||
|
||||
read_keys, read_results = [], []
|
||||
for i in tqdm(range(iteration), desc="Benchmarking read (GB/s)"):
|
||||
keys = random.sample(
|
||||
list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
|
||||
num_page,
|
||||
)
|
||||
results = hicache_hf3fs.batch_get(keys)
|
||||
read_keys.extend(keys)
|
||||
read_results.extend(results)
|
||||
assert all([r is not None for r in results])
|
||||
|
||||
for key, result in tqdm(zip(read_keys, read_results)):
|
||||
assert torch.allclose(values[int(key) % num_page], result, atol=1e-3)
|
||||
|
||||
hicache_hf3fs.close()
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
test()
|
||||
bench()
|
||||
allclose()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
140
third_party/sglang/benchmark/hf3fs/bench_zerocopy.py
vendored
Normal file
140
third_party/sglang/benchmark/hf3fs/bench_zerocopy.py
vendored
Normal file
@@ -0,0 +1,140 @@
|
||||
import threading
|
||||
import time
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_world_group,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from sglang.srt.managers.cache_controller import (
|
||||
HiCacheController,
|
||||
PrefetchOperation,
|
||||
StorageOperation,
|
||||
)
|
||||
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
|
||||
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost
|
||||
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method="tcp://127.0.0.1:23456",
|
||||
local_rank=0,
|
||||
backend="gloo",
|
||||
)
|
||||
|
||||
initialize_model_parallel(
|
||||
tensor_model_parallel_size=1,
|
||||
pipeline_model_parallel_size=1,
|
||||
)
|
||||
|
||||
group = get_world_group().cpu_group
|
||||
|
||||
max_total_num_tokens = 524288
|
||||
page_size = 64
|
||||
kv_cache_dtype = torch.bfloat16
|
||||
layer_num = 64
|
||||
head_num, head_dim = 8, 128
|
||||
device = "cuda"
|
||||
hicache_ratio = 2
|
||||
hicache_size = 0
|
||||
hicache_mem_layout = "page_first"
|
||||
# hicache_mem_layout = "layer_first"
|
||||
hicache_write_policy = "write_through"
|
||||
hicache_io_backend = "kernel"
|
||||
hicache_storage_backend = "hf3fs"
|
||||
prefetch_threshold = 256
|
||||
|
||||
op_size = 1024
|
||||
op_num = 16
|
||||
|
||||
token_to_kv_pool = MHATokenToKVPool(
|
||||
max_total_num_tokens,
|
||||
page_size=page_size,
|
||||
dtype=kv_cache_dtype,
|
||||
head_num=head_num,
|
||||
head_dim=head_dim,
|
||||
layer_num=layer_num,
|
||||
device=device,
|
||||
enable_memory_saver=True,
|
||||
)
|
||||
|
||||
token_to_kv_pool_allocator = TokenToKVPoolAllocator(
|
||||
max_total_num_tokens,
|
||||
dtype=kv_cache_dtype,
|
||||
device=device,
|
||||
kvcache=token_to_kv_pool,
|
||||
need_sort=False,
|
||||
)
|
||||
|
||||
kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
||||
token_to_kv_pool_host = MHATokenToKVPoolHost(
|
||||
kv_cache,
|
||||
hicache_ratio,
|
||||
hicache_size,
|
||||
page_size,
|
||||
hicache_mem_layout,
|
||||
)
|
||||
|
||||
load_cache_event = threading.Event()
|
||||
cache_controller = HiCacheController(
|
||||
token_to_kv_pool_allocator,
|
||||
token_to_kv_pool_host,
|
||||
page_size,
|
||||
group,
|
||||
load_cache_event=load_cache_event,
|
||||
write_policy=hicache_write_policy,
|
||||
io_backend=hicache_io_backend,
|
||||
storage_backend=hicache_storage_backend,
|
||||
prefetch_threshold=prefetch_threshold,
|
||||
)
|
||||
|
||||
operations = [
|
||||
StorageOperation(
|
||||
torch.tensor(list(range(i, i + op_size))),
|
||||
list(range(i, i + op_size)),
|
||||
hash_value=[f"{j}" for j in range(i, i + op_size, page_size)],
|
||||
)
|
||||
for i in tqdm(range(0, op_num * op_size, op_size))
|
||||
]
|
||||
|
||||
tik = time.monotonic()
|
||||
if hicache_mem_layout == "page_first":
|
||||
for operation in operations:
|
||||
cache_controller.zerocopy_page_backup(operation, batch_size=128)
|
||||
elif hicache_mem_layout == "layer_first":
|
||||
for operation in operations:
|
||||
cache_controller.generic_page_backup(operation, batch_size=128)
|
||||
tok = time.monotonic()
|
||||
print(f"{tok-tik:.6f} s")
|
||||
|
||||
operations = [
|
||||
PrefetchOperation(
|
||||
f"{i}",
|
||||
torch.tensor(list(range(i, i + op_size))),
|
||||
list(range(i, i + op_size)),
|
||||
f"{i}",
|
||||
)
|
||||
for i in tqdm(range(0, op_num * op_size, op_size))
|
||||
]
|
||||
|
||||
for operation in operations:
|
||||
operation.hash_value = [
|
||||
f"{j}"
|
||||
for j in range(
|
||||
int(operation.last_hash), int(operation.last_hash) + op_size, page_size
|
||||
)
|
||||
]
|
||||
|
||||
tik = time.monotonic()
|
||||
if hicache_mem_layout == "page_first":
|
||||
for operation in operations:
|
||||
cache_controller.zerocopy_page_transfer(operation, batch_size=128)
|
||||
elif hicache_mem_layout == "layer_first":
|
||||
for operation in operations:
|
||||
cache_controller.generic_page_transfer(operation, batch_size=128)
|
||||
tok = time.monotonic()
|
||||
print(f"{tok-tik:.6f} s")
|
||||
91
third_party/sglang/benchmark/hicache/README.md
vendored
Normal file
91
third_party/sglang/benchmark/hicache/README.md
vendored
Normal file
@@ -0,0 +1,91 @@
|
||||
## Run synthetic multi-turn benchmark
|
||||
|
||||
```
|
||||
# SGLang server with radix cache disabled
|
||||
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --disable-radix-cache
|
||||
|
||||
# SGLang server with radix cache on and first-come-first-serve policy
|
||||
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --schedule-policy fcfs
|
||||
|
||||
# The default SGLang server with radix cache on and long-prefix-match policy
|
||||
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000
|
||||
|
||||
# SGLang server with hierarchical radix cache enabled
|
||||
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --enable-hierarchical-cache
|
||||
|
||||
```
|
||||
|
||||
```
|
||||
python bench_multiturn.py --model-path Qwen/Qwen2.5-14B-Instruct
|
||||
```
|
||||
|
||||
Note: The performance gain of hierarchical caching depends on the ratio of reusable tokens to GPU memory capacity. The more tokens to be reused, the larger the model, and the more constrained the GPU memory size, the greater the benefit one can expect from hierarchical caching.
|
||||
|
||||
|
||||
# Benchmark with more datasets
|
||||
## Download Dataset
|
||||
```bash
|
||||
./download.sh {sharegpt|ultragpt|loogle|nextqa|all}
|
||||
```
|
||||
This script will automatically download the required dataset to the current working directory
|
||||
|
||||
## Multiturn Benchmark
|
||||
### Supported Datasets
|
||||
- sharegpt
|
||||
- ultrachat
|
||||
- loogle
|
||||
### Example Usage:
|
||||
```bash
|
||||
python3 bench_serving.py --model mistralai/Mistral-7B-Instruct-v0.3 --backend sglang \
|
||||
--dataset-path longdep_qa.json --dataset-name loogle --request-rate 10 --num-prompts 10 \
|
||||
--port 8001 --enable-multiturn --disable-shuffle
|
||||
```
|
||||
This uses `mistralai/Mistral-7B-Instruct-v0.3` model with `sglang` as backend. The dataset
|
||||
is `longdep_qa.json`. We send `10 conversations` with `10 req/s` to port 8001. We enable
|
||||
multiturn chat without shuffling the order of conversations (i.e. following the original
|
||||
order in the dataset file).
|
||||
|
||||
### Note:
|
||||
The requests of multiple conversations are sent in a round robin fashion.
|
||||
For example, if we have 3 conversations A, B, C whose rounds are `[2, 3, 4]` correspondingly,
|
||||
multiturn chat will send the requests to the backend in the following order: `[A1, B1, C1, A2, B2, C2, B3, C3, C4]`
|
||||
This has implications on the cache reuse patterns: the cache reuse distance is the largest
|
||||
under this request pattern (which means a prefix-aware local scheduler in the backend can
|
||||
yield the most benefit compared to a FIFO scheduler)
|
||||
|
||||
## Shared Prefix Benchmark
|
||||
### Supported Datasets
|
||||
- loogle
|
||||
### Example Usage:
|
||||
```bash
|
||||
python3 bench_serving.py --model mistralai/Mistral-7B-Instruct-v0.3 --backend sglang \
|
||||
--dataset-path longdep_qa.json --dataset-name loogle --request-rate 10 --num-prompts 10 \
|
||||
--port 8001 --enable-shared-prefix --disable-shuffle
|
||||
```
|
||||
### Note:
|
||||
Shared Prefix benchmark sends the questions for the same prompt together. For example,
|
||||
if we have 3 shared prefix A, B, C, which have [2, 3, 4] questions correspondingly,
|
||||
the shared prefix benchmark will send the requests to the
|
||||
backend in the following order: `[A+Q1, A+Q2, B+Q1, B+Q2, B+Q3, C+Q1, C+Q2, C+Q3]`.
|
||||
|
||||
|
||||
## Multi Modality Benchmark (WIP)
|
||||
### Supported Datasets:
|
||||
- nextqa
|
||||
### Example Usage:
|
||||
```bash
|
||||
Server:
|
||||
python3 -m sglang.launch_server --model-path lmms-lab/LLaVA-NeXT-Video-7B --tp 2 --dp 1 --port 8001 \
|
||||
--host 0.0.0.0 --mem-fraction-static 0.9 --tokenizer-path llava-hf/llava-1.5-7b-hf \
|
||||
--json-model-override-args "{\"architectures\": [\"LlavaVidForCausalLM\"], \"model_type\":\"llava\", \"mm_spatial_pool_stride\":2}"
|
||||
|
||||
Client:
|
||||
python3 bench_serving.py --model lmms-lab/LLaVA-NeXT-Video-7B --backend sglang --dataset-path \
|
||||
NExTVideo --dataset-name nextqa --request-rate 10 --num-prompts 1 --disable-shuffle --port 8001 \ --enable-multiturn --max-frames 16 --tokenizer llava-hf/llava-1.5-7b-hf --fixed-output-len 2048
|
||||
```
|
||||
Note: for the server args, `tokenizer-path`, overriding architecture are necessary.
|
||||
|
||||
## Supported Backend
|
||||
- sglang (oai)
|
||||
- vllm (oai)
|
||||
- lmdeploy (oai)
|
||||
100
third_party/sglang/benchmark/hicache/bench_long_context.py
vendored
Normal file
100
third_party/sglang/benchmark/hicache/bench_long_context.py
vendored
Normal file
@@ -0,0 +1,100 @@
|
||||
import json
|
||||
import queue
|
||||
import time
|
||||
|
||||
import requests
|
||||
from bench_multiturn import (
|
||||
ReadyQueue,
|
||||
WorkloadGenerator,
|
||||
gen_payload,
|
||||
log_to_jsonl_file,
|
||||
parse_args,
|
||||
)
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
from sglang.benchmark.utils import get_tokenizer
|
||||
|
||||
|
||||
class ContextWorkloadGenerator(WorkloadGenerator):
|
||||
def __init__(self, args):
|
||||
# Construct the base URL for requests
|
||||
self.baseurl = f"http://{args.host}:{args.port}/"
|
||||
self.url = self.baseurl + "generate"
|
||||
|
||||
self.tokenizer = get_tokenizer(args.model_path)
|
||||
self.distribution = args.distribution
|
||||
self.request_rate = args.request_rate
|
||||
self.start_time = None
|
||||
self.finished_time = None
|
||||
|
||||
self.sent_requests = 0
|
||||
self.completed_requests = 0
|
||||
|
||||
self.dataset = json.load(open(args.dataset_path))
|
||||
num_requests = min(args.num_clients, len(self.dataset["queries"]))
|
||||
|
||||
init_requests = []
|
||||
for i in range(num_requests):
|
||||
context_id = self.dataset["queries"][i]["context"]
|
||||
# Tokenize the context + question to get input_ids
|
||||
prompt_text = (
|
||||
self.dataset["contexts"][context_id]
|
||||
+ self.dataset["queries"][i]["question"]
|
||||
)
|
||||
input_ids = self.tokenizer.encode(prompt_text)
|
||||
output_len = len(
|
||||
self.tokenizer(self.dataset["queries"][i]["reference_answer"])[
|
||||
"input_ids"
|
||||
]
|
||||
)
|
||||
init_requests.append((i, gen_payload(input_ids, output_len)))
|
||||
self.ready_queue = ReadyQueue(init_requests=init_requests)
|
||||
|
||||
self.response_queue = queue.Queue()
|
||||
self.pbar = tqdm(total=num_requests)
|
||||
self.performance_metrics = {
|
||||
"ttft": [],
|
||||
"latency": [],
|
||||
"itl": [],
|
||||
"prompt_len": [],
|
||||
"cached_tokens": [],
|
||||
"generated_len": [],
|
||||
}
|
||||
|
||||
self.max_parallel = args.max_parallel
|
||||
self.logfile = args.log_file
|
||||
self.enable_round_barrier = False
|
||||
|
||||
def response_handler(self):
|
||||
while True:
|
||||
try:
|
||||
client_id, response = self.response_queue.get(
|
||||
timeout=10
|
||||
) # Block until response is available
|
||||
if not response.success:
|
||||
raise ValueError(f"Request failed with error: {response.error}")
|
||||
self.performance_metrics["ttft"].append(response.ttft)
|
||||
self.performance_metrics["itl"].extend(response.itl)
|
||||
self.performance_metrics["latency"].append(response.latency)
|
||||
self.performance_metrics["prompt_len"].append(response.prompt_len)
|
||||
self.performance_metrics["cached_tokens"].append(response.cached_tokens)
|
||||
self.performance_metrics["generated_len"].append(response.generated_len)
|
||||
self.completed_requests += 1
|
||||
|
||||
except queue.Empty:
|
||||
if self.pbar.n == self.pbar.total:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
args.num_rounds = 1
|
||||
args.max_parallel = 24
|
||||
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
|
||||
|
||||
for request_rate in [24, 16, 12, 8, 4, 2, 1]:
|
||||
args.request_rate = request_rate
|
||||
requests.post(flush_cache_url)
|
||||
time.sleep(1)
|
||||
performance_data = ContextWorkloadGenerator(args).run()
|
||||
log_to_jsonl_file(performance_data, args.log_file, args.tag)
|
||||
571
third_party/sglang/benchmark/hicache/bench_mix.py
vendored
Normal file
571
third_party/sglang/benchmark/hicache/bench_mix.py
vendored
Normal file
@@ -0,0 +1,571 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
|
||||
import aiohttp
|
||||
|
||||
from sglang.bench_serving import RequestFuncOutput
|
||||
from sglang.benchmark.datasets.random import sample_random_requests
|
||||
from sglang.benchmark.utils import get_tokenizer, remove_prefix
|
||||
|
||||
# Set up logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Set up JSONL file for debug logging
|
||||
debug_log_file = None
|
||||
# Create a lock for thread-safe debug log writing
|
||||
debug_log_lock = threading.Lock()
|
||||
|
||||
|
||||
def write_debug_log(data):
|
||||
global debug_log_file
|
||||
|
||||
"""Write debug information to a JSONL file"""
|
||||
if debug_log_file is None:
|
||||
return
|
||||
|
||||
# Acquire lock for thread-safe writing
|
||||
with debug_log_lock:
|
||||
# Write as JSONL (JSON Line format)
|
||||
debug_log_file.write(json.dumps(data) + "\n")
|
||||
debug_log_file.flush()
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Script to benchmark concurrent requests to a server."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="/data/models/Qwen3-0.6B",
|
||||
help="model path compatible with Hugging Face Transformers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-path",
|
||||
type=str,
|
||||
default="/data/models/ShareGPT_V3_unfiltered_cleaned_split/ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
help="local dataset to sample tokens from",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Server hostname or IP (default: localhost)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=30000,
|
||||
help="Server port (default: 30000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--duration",
|
||||
type=int,
|
||||
default=600,
|
||||
help="Duration to run the benchmark in seconds (default: 300 seconds)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
type=str,
|
||||
default="info",
|
||||
choices=["debug", "info"],
|
||||
help="Set the logging level (default: info)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug-log-file",
|
||||
type=str,
|
||||
default="debug.log.jsonl",
|
||||
help="File to write debug logs in JSONL format",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_config():
|
||||
config_path = os.getenv("CONFIG_PATH")
|
||||
if not config_path:
|
||||
raise ValueError("Environment variable 'CONFIG_PATH' is not set.")
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
required_keys = [
|
||||
"num_rounds",
|
||||
"num_clients",
|
||||
"round_ratios",
|
||||
"mean_new_tokens_per_round",
|
||||
"mean_return_tokens_per_round",
|
||||
"mean_inter_round_interval",
|
||||
]
|
||||
|
||||
for key in required_keys:
|
||||
if key not in config:
|
||||
raise KeyError(f"Missing required configuration key: {key}")
|
||||
|
||||
num_rounds = config["num_rounds"]
|
||||
assert len(config["round_ratios"]) == num_rounds
|
||||
assert len(config["mean_new_tokens_per_round"]) == num_rounds
|
||||
assert len(config["mean_return_tokens_per_round"]) == num_rounds
|
||||
assert len(config["mean_inter_round_interval"]) == num_rounds
|
||||
|
||||
print(config)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserData:
|
||||
user_id: int
|
||||
current_round: int
|
||||
total_rounds: int
|
||||
prompt: str
|
||||
return_tokens: int
|
||||
start: int
|
||||
|
||||
|
||||
def synchronized():
|
||||
def _decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
with self.lock:
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return _decorator
|
||||
|
||||
|
||||
class UserGenerator:
|
||||
def __init__(self, config, model_path, dataset_path):
|
||||
self.tokenizer_path = model_path
|
||||
self.tokenizer = get_tokenizer(self.tokenizer_path)
|
||||
self.dataset_path = dataset_path
|
||||
|
||||
self.user_id = 0
|
||||
self.lock = threading.Lock()
|
||||
|
||||
self.num_rounds = config["num_rounds"]
|
||||
|
||||
self.cumulative_ratios = [
|
||||
sum(config["round_ratios"][: i + 1])
|
||||
for i in range(len(config["round_ratios"]))
|
||||
]
|
||||
self.mean_new_tokens_per_round = config["mean_new_tokens_per_round"]
|
||||
self.mean_return_tokens_per_round = config["mean_return_tokens_per_round"]
|
||||
self.mean_inter_round_interval = config["mean_inter_round_interval"]
|
||||
|
||||
self.sigma = 100
|
||||
self.range_ratio = 0.8
|
||||
assert self.range_ratio <= 1
|
||||
|
||||
self.candidate_inputs = [
|
||||
[
|
||||
r
|
||||
for r in sample_random_requests(
|
||||
input_len=(
|
||||
self.mean_new_tokens_per_round[i] * (2 - self.range_ratio)
|
||||
),
|
||||
output_len=(
|
||||
self.mean_return_tokens_per_round[i] * (2 - self.range_ratio)
|
||||
),
|
||||
num_prompts=config["num_clients"],
|
||||
range_ratio=self.range_ratio / (2 - self.range_ratio),
|
||||
tokenizer=self.tokenizer,
|
||||
dataset_path=self.dataset_path,
|
||||
random_sample=False,
|
||||
)
|
||||
]
|
||||
for i in range(self.num_rounds)
|
||||
]
|
||||
|
||||
self.multiturn_queue = []
|
||||
|
||||
self.user_stats = [0 for _ in range(self.num_rounds)]
|
||||
self.input_stats = [[0, 0] for _ in range(self.num_rounds)]
|
||||
self.output_stats = [[0, 0] for _ in range(self.num_rounds)]
|
||||
|
||||
def gen(self):
|
||||
user_id = self.user_id
|
||||
self.user_id += 1
|
||||
|
||||
rand_ratio = random.randint(0, self.cumulative_ratios[-1])
|
||||
i = len(self.cumulative_ratios)
|
||||
for idx, cumulative_ratio in enumerate(self.cumulative_ratios):
|
||||
if rand_ratio >= cumulative_ratio:
|
||||
continue
|
||||
else:
|
||||
i = idx + 1
|
||||
break
|
||||
total_rounds = i
|
||||
current_round = 0
|
||||
|
||||
candidate_input = random.sample(self.candidate_inputs[current_round], 1)[0]
|
||||
self.input_stats[0][0] += candidate_input.prompt_len
|
||||
self.input_stats[0][1] += 1
|
||||
prompt = f"{user_id} " + candidate_input.prompt
|
||||
return_tokens = int(
|
||||
random.gauss(self.mean_return_tokens_per_round[current_round], self.sigma)
|
||||
)
|
||||
if return_tokens <= 0:
|
||||
return_tokens = self.mean_return_tokens_per_round[current_round]
|
||||
start = 0
|
||||
|
||||
user_data = UserData(
|
||||
user_id, current_round, total_rounds, prompt, return_tokens, start
|
||||
)
|
||||
|
||||
self.user_stats[total_rounds - 1] += 1
|
||||
|
||||
return user_data
|
||||
|
||||
@synchronized()
|
||||
def push(self, user_data, generated_text, len_itl):
|
||||
self.output_stats[user_data.current_round][0] += len_itl + 1
|
||||
self.output_stats[user_data.current_round][1] += 1
|
||||
user_data.current_round += 1
|
||||
if user_data.current_round >= user_data.total_rounds:
|
||||
return
|
||||
|
||||
candidate_input = random.sample(
|
||||
self.candidate_inputs[user_data.current_round], 1
|
||||
)[0]
|
||||
self.input_stats[user_data.current_round][0] += candidate_input.prompt_len
|
||||
self.input_stats[user_data.current_round][1] += 1
|
||||
user_data.prompt += generated_text + candidate_input.prompt
|
||||
user_data.return_tokens = int(
|
||||
random.gauss(
|
||||
self.mean_return_tokens_per_round[user_data.current_round], self.sigma
|
||||
)
|
||||
)
|
||||
if user_data.return_tokens <= 0:
|
||||
user_data.return_tokens = self.mean_return_tokens_per_round[
|
||||
user_data.current_round
|
||||
]
|
||||
interval = random.gauss(
|
||||
self.mean_inter_round_interval[user_data.current_round], self.sigma
|
||||
)
|
||||
if interval <= 0:
|
||||
interval = self.mean_inter_round_interval[user_data.current_round]
|
||||
user_data.start = time.perf_counter() + interval
|
||||
|
||||
if len(self.multiturn_queue) == 0:
|
||||
self.multiturn_queue.append(user_data)
|
||||
else:
|
||||
i = len(self.multiturn_queue)
|
||||
for idx, d in enumerate(self.multiturn_queue):
|
||||
if user_data.start < d.start:
|
||||
i = idx
|
||||
break
|
||||
self.multiturn_queue.insert(idx, user_data)
|
||||
|
||||
@synchronized()
|
||||
def pop(self):
|
||||
if (
|
||||
len(self.multiturn_queue)
|
||||
and time.perf_counter() > self.multiturn_queue[0].start
|
||||
):
|
||||
return self.multiturn_queue.pop(0)
|
||||
return self.gen()
|
||||
|
||||
|
||||
def gen_payload(prompt, output_len):
|
||||
payload = {
|
||||
"text": prompt,
|
||||
"sampling_params": {
|
||||
"temperature": 0.0,
|
||||
"max_new_tokens": output_len,
|
||||
"ignore_eos": True,
|
||||
},
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
"lora_path": "",
|
||||
"return_logprob": False,
|
||||
"logprob_start_len": -1,
|
||||
}
|
||||
return payload
|
||||
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
|
||||
|
||||
|
||||
async def async_request_sglang_generate(
|
||||
user_data,
|
||||
url,
|
||||
atomic_counter,
|
||||
):
|
||||
"""
|
||||
Sends a streaming request to the server. Gathers text token-by-token.
|
||||
"""
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
headers = {}
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
output = RequestFuncOutput()
|
||||
payload = gen_payload(user_data.prompt, user_data.return_tokens)
|
||||
write_debug_log({"timestamp": st, "user_data": user_data.__dict__})
|
||||
|
||||
try:
|
||||
async with session.post(url=url, json=payload, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
prompt_tokens = 0
|
||||
cached_tokens = 0
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
||||
latency = time.perf_counter() - st
|
||||
if chunk == "[DONE]":
|
||||
pass
|
||||
else:
|
||||
data = json.loads(chunk)
|
||||
|
||||
if data.get("text"):
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
prompt_tokens = (data.get("meta_info") or {}).get(
|
||||
"prompt_tokens", 0
|
||||
)
|
||||
cached_tokens = (data.get("meta_info") or {}).get(
|
||||
"cached_tokens", 0
|
||||
)
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text = data["text"]
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
output.prompt_len = prompt_tokens
|
||||
output.cached_tokens = cached_tokens
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception as e:
|
||||
output.success = False
|
||||
output.error = str(e)
|
||||
print(f"Request failed: {e}")
|
||||
|
||||
atomic_counter.increment(1)
|
||||
return output
|
||||
|
||||
|
||||
class AtomicCounter:
|
||||
def __init__(self, initial_value=0):
|
||||
self._value = initial_value
|
||||
self.lock = threading.Lock()
|
||||
|
||||
@synchronized()
|
||||
def increment(self, amount=1):
|
||||
self._value += amount
|
||||
|
||||
@synchronized()
|
||||
def get(self):
|
||||
return self._value
|
||||
|
||||
|
||||
class WorkloadGenerator:
|
||||
def __init__(self, args):
|
||||
config = load_config()
|
||||
user_generator = UserGenerator(
|
||||
config,
|
||||
args.model_path,
|
||||
args.dataset_path,
|
||||
)
|
||||
|
||||
self.url = f"http://{args.host}:{args.port}/generate"
|
||||
|
||||
self.tokenizer = user_generator.tokenizer
|
||||
self.start_time = None
|
||||
self.finished_time = None
|
||||
self.duration = args.duration
|
||||
self.done = False
|
||||
|
||||
self.sent_requests = 0
|
||||
self.completed_requests = 0
|
||||
|
||||
self.user_generator = user_generator
|
||||
self.response_queue = queue.Queue()
|
||||
self.performance_metrics = {
|
||||
"ttft": [],
|
||||
"latency": [],
|
||||
"prompt_len": [],
|
||||
"cached_tokens": [],
|
||||
}
|
||||
self.max_parallel = config["num_clients"]
|
||||
|
||||
self.atomic_counter = AtomicCounter()
|
||||
|
||||
async def handle_request(self, user_data):
|
||||
try:
|
||||
response = await async_request_sglang_generate(
|
||||
user_data, self.url, self.atomic_counter
|
||||
)
|
||||
self.response_queue.put((user_data, response))
|
||||
except Exception as e:
|
||||
print(f"Request failed: {e}")
|
||||
self.completed_requests += 1
|
||||
|
||||
def request_sender(self):
|
||||
async def request_loop():
|
||||
tasks = []
|
||||
while True:
|
||||
if self.sent_requests - self.completed_requests < self.max_parallel:
|
||||
new_request = self.user_generator.pop()
|
||||
if new_request:
|
||||
task = asyncio.create_task(self.handle_request(new_request))
|
||||
tasks.append(task)
|
||||
self.sent_requests += 1
|
||||
else:
|
||||
await asyncio.sleep(0.05)
|
||||
continue
|
||||
|
||||
if time.perf_counter() - self.start_time > self.duration:
|
||||
self.done = True
|
||||
break
|
||||
|
||||
# Cancel all pending tasks and wait for them to finish
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(request_loop())
|
||||
loop.close()
|
||||
|
||||
def response_handler(self):
|
||||
while True:
|
||||
try:
|
||||
user_data, response = self.response_queue.get(timeout=10)
|
||||
logger.info(
|
||||
f"{((time.perf_counter()-self.start_time)/self.duration*100):.2f}%"
|
||||
)
|
||||
if not response.success:
|
||||
raise ValueError(f"Request failed with error: {response.error}")
|
||||
|
||||
self.user_generator.push(
|
||||
user_data, response.generated_text, len(response.itl)
|
||||
)
|
||||
self.performance_metrics["ttft"].append(response.ttft)
|
||||
self.performance_metrics["latency"].append(response.latency)
|
||||
self.performance_metrics["prompt_len"].append(response.prompt_len)
|
||||
self.performance_metrics["cached_tokens"].append(response.cached_tokens)
|
||||
self.completed_requests += 1
|
||||
self.finished_time = time.perf_counter()
|
||||
|
||||
except queue.Empty:
|
||||
if self.done:
|
||||
break
|
||||
except ValueError as e:
|
||||
print(f"Error processing response for client {user_data}: {e}")
|
||||
continue
|
||||
|
||||
def run(self):
|
||||
request_thread = threading.Thread(target=self.request_sender, daemon=True)
|
||||
response_thread = threading.Thread(target=self.response_handler, daemon=True)
|
||||
|
||||
self.start_time = time.perf_counter()
|
||||
request_thread.start()
|
||||
response_thread.start()
|
||||
|
||||
request_thread.join()
|
||||
response_thread.join()
|
||||
|
||||
performance_data = {
|
||||
"summary": {
|
||||
"total_requests": len(self.performance_metrics["ttft"]),
|
||||
"average_ttft": sum(self.performance_metrics["ttft"])
|
||||
/ len(self.performance_metrics["ttft"]),
|
||||
"p90_ttft": sorted(self.performance_metrics["ttft"])[
|
||||
int(0.9 * len(self.performance_metrics["ttft"]))
|
||||
],
|
||||
"median_ttft": sorted(self.performance_metrics["ttft"])[
|
||||
len(self.performance_metrics["ttft"]) // 2
|
||||
],
|
||||
"average_latency": sum(self.performance_metrics["latency"])
|
||||
/ len(self.performance_metrics["latency"]),
|
||||
"p90_latency": sorted(self.performance_metrics["latency"])[
|
||||
int(0.9 * len(self.performance_metrics["latency"]))
|
||||
],
|
||||
"median_latency": sorted(self.performance_metrics["latency"])[
|
||||
len(self.performance_metrics["latency"]) // 2
|
||||
],
|
||||
"throughput": self.atomic_counter.get()
|
||||
/ (self.finished_time - self.start_time),
|
||||
"cache_hit_rate": (
|
||||
0
|
||||
if sum(self.performance_metrics["prompt_len"]) == 0
|
||||
else sum(self.performance_metrics["cached_tokens"])
|
||||
/ sum(self.performance_metrics["prompt_len"])
|
||||
),
|
||||
},
|
||||
}
|
||||
print("All requests completed")
|
||||
print("Performance metrics summary:")
|
||||
print(f" Total requests: {performance_data['summary']['total_requests']}")
|
||||
print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}")
|
||||
print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}")
|
||||
print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}")
|
||||
print(
|
||||
f" Average latency: {performance_data['summary']['average_latency']:.2f}"
|
||||
)
|
||||
print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
|
||||
print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
|
||||
print(
|
||||
f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
|
||||
)
|
||||
print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
|
||||
|
||||
user_stats = self.user_generator.user_stats
|
||||
input_stats = self.user_generator.input_stats
|
||||
output_stats = self.user_generator.output_stats
|
||||
print(f"round_ratios: {user_stats}")
|
||||
print(
|
||||
f"mean_new_tokens_per_round: {[int(a/b) if b > 0 else 0 for a, b in input_stats]}"
|
||||
)
|
||||
print(
|
||||
f"mean_return_tokens_per_round: {[int(a/b) if b > 0 else 0 for a, b in output_stats]}"
|
||||
)
|
||||
return performance_data
|
||||
|
||||
|
||||
def main():
|
||||
global debug_log_file
|
||||
|
||||
args = parse_args()
|
||||
if args.log_level == "debug":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger.info("use log_level debug")
|
||||
# Initialize debug log file
|
||||
debug_log_file = open(args.debug_log_file, "w")
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger.info("use log_level info")
|
||||
performance_data = WorkloadGenerator(args).run()
|
||||
|
||||
# Close debug log file if it was opened
|
||||
if debug_log_file:
|
||||
debug_log_file.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
42
third_party/sglang/benchmark/hicache/bench_mix.sh
vendored
Executable file
42
third_party/sglang/benchmark/hicache/bench_mix.sh
vendored
Executable file
@@ -0,0 +1,42 @@
|
||||
#!/bin/bash
|
||||
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
|
||||
rm -rf nohup.out && \
|
||||
nohup python3 -m sglang.launch_server \
|
||||
--attention-backend triton \
|
||||
--model-path /code/models/Qwen3-32B/ \
|
||||
--log-level info \
|
||||
--tp 4 --mem-frac 0.25 \
|
||||
--host 0.0.0.0 --port 33301 \
|
||||
--enable-metrics --enable-cache-report \
|
||||
--page-size 64 \
|
||||
--enable-hierarchical-cache \
|
||||
--hicache-ratio 2.5 --hicache-size 0 \
|
||||
--hicache-io-backend kernel \
|
||||
--hicache-mem-layout layer_first \
|
||||
--hicache-write-policy write_through \
|
||||
&
|
||||
|
||||
##################################################
|
||||
|
||||
export CONFIG_PATH=/tmp/bench_mix_config.json
|
||||
|
||||
# num_clients: Maximum number of concurrent client requests to be simulated
|
||||
# round_ratios: Distribution of requests across rounds. Given sum(round_ratios) total requests,
|
||||
# round_ratios[i] denotes the number of requests that will execute for (i+1) rounds
|
||||
echo '{
|
||||
"num_rounds": 10,
|
||||
"num_clients": 60,
|
||||
"round_ratios": [50, 25, 15, 15, 10, 10, 9, 8, 7, 6],
|
||||
"mean_new_tokens_per_round": [1000, 400, 350, 300, 280, 260, 240, 220, 210, 200],
|
||||
"mean_return_tokens_per_round": [100, 100, 100, 100, 100, 100, 100, 100, 100, 100],
|
||||
"mean_inter_round_interval": [30, 30, 30, 30, 30, 30, 30, 30, 30, 30]
|
||||
}' > ${CONFIG_PATH}
|
||||
|
||||
rm -rf bench_mix.out && \
|
||||
nohup python3 /sgl-workspace/sglang/benchmark/hicache/bench_mix.py \
|
||||
--model-path /code/models/Qwen3-32B/ \
|
||||
--dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--port 33301 \
|
||||
--duration 600 \
|
||||
> bench_mix.out &
|
||||
755
third_party/sglang/benchmark/hicache/bench_multiturn.py
vendored
Normal file
755
third_party/sglang/benchmark/hicache/bench_multiturn.py
vendored
Normal file
@@ -0,0 +1,755 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import queue
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
from sglang.bench_serving import RequestFuncOutput
|
||||
from sglang.benchmark.datasets.random import sample_random_requests
|
||||
from sglang.benchmark.utils import get_tokenizer
|
||||
from sglang.test.kits.cache_hit_kit import (
|
||||
async_request_openai_chat_completions,
|
||||
async_request_sglang_generate,
|
||||
gen_payload,
|
||||
gen_payload_openai,
|
||||
)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Script to benchmark concurrent requests to a server."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-clients",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Number of concurrent clients",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-parallel",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Maximum number of parallel requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--request-length",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Length of each new request",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-length",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Length of each output",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-rounds",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of rounds per client",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distribution",
|
||||
type=str,
|
||||
default="poisson",
|
||||
choices=["poisson", "uniform"],
|
||||
help="Distribution type for request intervals (poisson or uniform)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--request-rate",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Average number of requests per second",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Server hostname or IP (default: localhost)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=30000,
|
||||
help="Server port (default: 30000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="meta-llama/Llama-3.1-8B-Instruct",
|
||||
help="model path compatible with Hugging Face Transformers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="local dataset to sample tokens from",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-file",
|
||||
type=str,
|
||||
default="performance_metrics.jsonl",
|
||||
help="File to log performance metrics",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-auto-run",
|
||||
action="store_true",
|
||||
help="If set, disable automatically testing with a range of request rates.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-random-sample",
|
||||
action="store_true",
|
||||
help="If set, disable random sampling of requests from the ShareGPT dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-round-barrier",
|
||||
action="store_true",
|
||||
help="If set, only send i-th turn requests after all (i-1)-th turn requests finished.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sub-question-input-length",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Length of the sub question input for each request, if set 0 use request_length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ready-queue-policy",
|
||||
type=str,
|
||||
default="random",
|
||||
help="Policy for popping requests from the ready queue (random or fifo)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tag",
|
||||
type=str,
|
||||
default="",
|
||||
help="Tag of a certain run in the log file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-rounds",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Min rounds per client (0 = use --num-rounds)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-rounds",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Max rounds per client (0 = use --num-rounds)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--range-ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Length variation ratio for prompts and outputs (1.0 = no variation, 0.5 = 50%% variation)",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
||||
parser.add_argument(
|
||||
"--lora-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="String of LoRA path. Currently we only support benchmarking on a single LoRA adaptor.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-format",
|
||||
type=str,
|
||||
default="sglang",
|
||||
choices=["sglang", "openai"],
|
||||
help="API format to use: 'sglang' for native /generate endpoint, "
|
||||
"'openai' for OpenAI-compatible /v1/chat/completions endpoint.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def log_to_jsonl_file(data, file_path="performance_metrics.jsonl", tag=""):
|
||||
"""Append the data with a timestamp and tag to the specified JSONL file."""
|
||||
timestamped_data = {"timestamp": datetime.now().isoformat(), "tag": tag, **data}
|
||||
try:
|
||||
with open(file_path, "a") as file:
|
||||
file.write(
|
||||
json.dumps(timestamped_data) + "\n"
|
||||
) # Write as a single line in JSONL format
|
||||
except IOError as e:
|
||||
print(f"Error writing to JSONL file: {e}")
|
||||
|
||||
|
||||
class ReadyQueue:
|
||||
"""
|
||||
Thread-safe queue that can pop requests in different orders based on given policy.
|
||||
"""
|
||||
|
||||
def __init__(self, init_requests=None, policy="random"):
|
||||
self.lock = threading.Lock()
|
||||
self.requests = init_requests or []
|
||||
self.policy = policy
|
||||
|
||||
def append(self, item):
|
||||
with self.lock:
|
||||
self.requests.append(item)
|
||||
|
||||
def pop(self):
|
||||
with self.lock:
|
||||
if not self.requests:
|
||||
return None
|
||||
if self.policy == "random":
|
||||
index = random.randrange(len(self.requests))
|
||||
return self.requests.pop(index)
|
||||
elif self.policy == "fifo":
|
||||
return self.requests.pop(0)
|
||||
else:
|
||||
# todo, varying thinking time of clients
|
||||
raise ValueError(f"{self.policy} not implemented")
|
||||
|
||||
|
||||
class WorkloadGenerator:
|
||||
def __init__(self, args):
|
||||
self.api_format = args.api_format
|
||||
self.model_path = args.model_path
|
||||
|
||||
# Construct the base URL and select request/payload functions
|
||||
if self.api_format == "openai":
|
||||
self.url = f"http://{args.host}:{args.port}/v1/chat/completions"
|
||||
self.request_func = async_request_openai_chat_completions
|
||||
else:
|
||||
self.url = f"http://{args.host}:{args.port}/generate"
|
||||
self.request_func = async_request_sglang_generate
|
||||
|
||||
self.tokenizer = get_tokenizer(args.model_path)
|
||||
self.distribution = args.distribution
|
||||
self.request_rate = args.request_rate
|
||||
self.start_time = None
|
||||
self.finished_time = None
|
||||
self.lora_path = args.lora_path
|
||||
|
||||
self.sent_requests = 0
|
||||
self.completed_requests = 0
|
||||
|
||||
# Resolve per-client round counts
|
||||
min_rounds = args.min_rounds
|
||||
max_rounds = args.max_rounds
|
||||
if min_rounds == 0 and max_rounds == 0:
|
||||
# Backward compat: all clients use --num-rounds
|
||||
min_rounds = args.num_rounds
|
||||
max_rounds = args.num_rounds
|
||||
elif min_rounds == 0:
|
||||
min_rounds = max_rounds
|
||||
elif max_rounds == 0:
|
||||
max_rounds = min_rounds
|
||||
if min_rounds < 1:
|
||||
raise ValueError(f"--min-rounds must be >= 1, got {min_rounds}")
|
||||
if min_rounds > max_rounds:
|
||||
raise ValueError(
|
||||
f"--min-rounds ({min_rounds}) must be <= --max-rounds ({max_rounds})"
|
||||
)
|
||||
|
||||
self.min_rounds = min_rounds
|
||||
self.max_rounds = max_rounds
|
||||
|
||||
if min_rounds == max_rounds:
|
||||
# All clients have the same round count; skip randint to preserve random state
|
||||
self.client_total_rounds = [min_rounds] * args.num_clients
|
||||
else:
|
||||
self.client_total_rounds = [
|
||||
random.randint(min_rounds, max_rounds) for _ in range(args.num_clients)
|
||||
]
|
||||
|
||||
# clients_per_round[r] = number of clients participating in round r
|
||||
self.clients_per_round = [
|
||||
sum(1 for t in self.client_total_rounds if t > r) for r in range(max_rounds)
|
||||
]
|
||||
self.total_requests = sum(self.client_total_rounds)
|
||||
|
||||
range_ratio = args.range_ratio
|
||||
|
||||
# Use return_text=False to get token ids instead of text
|
||||
first_round_samples = sample_random_requests(
|
||||
input_len=args.request_length,
|
||||
output_len=args.output_length,
|
||||
num_prompts=args.num_clients,
|
||||
range_ratio=range_ratio,
|
||||
tokenizer=self.tokenizer,
|
||||
dataset_path=args.dataset_path,
|
||||
random_sample=not args.disable_random_sample,
|
||||
return_text=False,
|
||||
)
|
||||
# Store per-sample output_len for first round
|
||||
first_round_output_lens = [row.output_len for row in first_round_samples]
|
||||
# r.prompt is now List[int] when return_text=False
|
||||
self.candidate_inputs = [list(i.prompt) for i in first_round_samples]
|
||||
|
||||
if args.sub_question_input_length != 0:
|
||||
sub_question_input_length = args.sub_question_input_length
|
||||
else:
|
||||
sub_question_input_length = args.request_length
|
||||
|
||||
num_sub_questions = sum(max(t - 1, 0) for t in self.client_total_rounds)
|
||||
|
||||
self.sub_question_inputs = sample_random_requests(
|
||||
input_len=sub_question_input_length,
|
||||
output_len=args.output_length,
|
||||
num_prompts=max(num_sub_questions, 1),
|
||||
range_ratio=range_ratio,
|
||||
tokenizer=self.tokenizer,
|
||||
dataset_path=args.dataset_path,
|
||||
random_sample=not args.disable_random_sample,
|
||||
return_text=False,
|
||||
)
|
||||
|
||||
if self.api_format == "openai":
|
||||
# OpenAI mode: history is a messages list for /v1/chat/completions
|
||||
initial_messages = {
|
||||
i: [
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.tokenizer.decode(self.candidate_inputs[i]),
|
||||
}
|
||||
]
|
||||
for i in range(args.num_clients)
|
||||
}
|
||||
init_requests = [
|
||||
(
|
||||
i,
|
||||
gen_payload_openai(
|
||||
initial_messages[i],
|
||||
first_round_output_lens[i],
|
||||
self.model_path,
|
||||
),
|
||||
)
|
||||
for i in range(args.num_clients)
|
||||
]
|
||||
self.client_records = {
|
||||
i: {
|
||||
"round": 0,
|
||||
"history": initial_messages[i],
|
||||
"total_rounds": self.client_total_rounds[i],
|
||||
}
|
||||
for i in range(args.num_clients)
|
||||
}
|
||||
else:
|
||||
# SGLang mode: history is List[int] (token ids)
|
||||
init_requests = [
|
||||
(
|
||||
i,
|
||||
gen_payload(
|
||||
self.candidate_inputs[i],
|
||||
first_round_output_lens[i],
|
||||
args.lora_path,
|
||||
),
|
||||
)
|
||||
for i in range(args.num_clients)
|
||||
]
|
||||
self.client_records = {
|
||||
i: {
|
||||
"round": 0,
|
||||
"history": list(self.candidate_inputs[i]),
|
||||
"total_rounds": self.client_total_rounds[i],
|
||||
}
|
||||
for i in range(args.num_clients)
|
||||
}
|
||||
self.ready_queue = ReadyQueue(
|
||||
init_requests=init_requests, policy=args.ready_queue_policy
|
||||
)
|
||||
self.candidate_inputs = self.candidate_inputs[args.num_clients :]
|
||||
|
||||
self.response_queue = queue.Queue()
|
||||
self.pbar = tqdm(total=self.total_requests)
|
||||
self.performance_metrics = {
|
||||
"ttft": [],
|
||||
"itl": [],
|
||||
"latency": [],
|
||||
"prompt_len": [],
|
||||
"cached_tokens": [],
|
||||
"generated_len": [],
|
||||
}
|
||||
self.enable_round_barrier = args.enable_round_barrier
|
||||
if self.enable_round_barrier:
|
||||
# Add round-specific metrics while preserving the original structure
|
||||
for i in range(self.max_rounds):
|
||||
self.performance_metrics[f"round_{i}"] = {
|
||||
"ttft": [],
|
||||
"latency": [],
|
||||
"prompt_len": [],
|
||||
"cached_tokens": [],
|
||||
"generated_len": [],
|
||||
}
|
||||
self.num_clients = args.num_clients
|
||||
|
||||
self.num_rounds = self.max_rounds
|
||||
self.max_parallel = args.max_parallel
|
||||
self.output_length = args.output_length
|
||||
|
||||
async def handle_request(self, item):
|
||||
client_id, payload = item
|
||||
try:
|
||||
response = await self.request_func(payload, self.url, self.pbar)
|
||||
if self.pbar.n == self.pbar.total:
|
||||
self.finished_time = time.perf_counter()
|
||||
self.response_queue.put((client_id, response))
|
||||
except Exception as e:
|
||||
print(f"Request failed for client {client_id}: {e}")
|
||||
failed_response = RequestFuncOutput()
|
||||
failed_response.success = False
|
||||
failed_response.error = str(e)
|
||||
self.response_queue.put((client_id, failed_response))
|
||||
|
||||
def request_sender(self):
|
||||
async def request_loop():
|
||||
while True:
|
||||
if self.sent_requests - self.completed_requests < self.max_parallel:
|
||||
new_request = self.ready_queue.pop()
|
||||
if new_request:
|
||||
asyncio.create_task(self.handle_request(new_request))
|
||||
self.sent_requests += 1
|
||||
else:
|
||||
await asyncio.sleep(0.05)
|
||||
continue
|
||||
|
||||
if self.pbar.n == self.pbar.total:
|
||||
break
|
||||
|
||||
# Calculate Poisson-distributed wait time
|
||||
if self.distribution == "poisson":
|
||||
sleep_time = random.expovariate(self.request_rate)
|
||||
elif self.distribution == "uniform":
|
||||
avg_interval = (
|
||||
1.0 / self.request_rate if self.request_rate > 0 else 1.0
|
||||
)
|
||||
sleep_time = random.uniform(0, 2 * avg_interval)
|
||||
else:
|
||||
raise ValueError("Invalid distribution type")
|
||||
await asyncio.sleep(sleep_time) # Wait before sending the next request
|
||||
|
||||
# Create and run the event loop for asynchronous requests
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(request_loop())
|
||||
loop.close()
|
||||
|
||||
def response_handler(self):
|
||||
next_round_reqs = []
|
||||
current_barrier_round = 0
|
||||
barrier_round_completed = 0
|
||||
while True:
|
||||
try:
|
||||
client_id, response = self.response_queue.get(
|
||||
timeout=10
|
||||
) # Block until response is available
|
||||
if not response.success:
|
||||
print(f"Request failed for client {client_id}: {response.error}")
|
||||
self.completed_requests += 1
|
||||
continue
|
||||
# Extend history with response
|
||||
if self.api_format == "openai":
|
||||
if response.generated_text:
|
||||
self.client_records[client_id]["history"].append(
|
||||
{"role": "assistant", "content": response.generated_text}
|
||||
)
|
||||
else:
|
||||
self.client_records[client_id]["history"].extend(
|
||||
response.output_ids
|
||||
)
|
||||
current_round = self.client_records[client_id]["round"]
|
||||
self.client_records[client_id]["round"] += 1
|
||||
self.performance_metrics["ttft"].append(response.ttft)
|
||||
self.performance_metrics["itl"].extend(response.itl)
|
||||
self.performance_metrics["latency"].append(response.latency)
|
||||
self.performance_metrics["prompt_len"].append(response.prompt_len)
|
||||
self.performance_metrics["cached_tokens"].append(response.cached_tokens)
|
||||
self.performance_metrics["generated_len"].append(response.generated_len)
|
||||
if self.enable_round_barrier:
|
||||
self.performance_metrics[f"round_{current_round}"]["ttft"].append(
|
||||
response.ttft
|
||||
)
|
||||
self.performance_metrics[f"round_{current_round}"][
|
||||
"latency"
|
||||
].append(response.latency)
|
||||
self.performance_metrics[f"round_{current_round}"][
|
||||
"prompt_len"
|
||||
].append(response.prompt_len)
|
||||
self.performance_metrics[f"round_{current_round}"][
|
||||
"cached_tokens"
|
||||
].append(response.cached_tokens)
|
||||
self.performance_metrics[f"round_{current_round}"][
|
||||
"generated_len"
|
||||
].append(response.generated_len)
|
||||
self.completed_requests += 1
|
||||
|
||||
client_total = self.client_records[client_id]["total_rounds"]
|
||||
if self.client_records[client_id]["round"] < client_total:
|
||||
sub_q = self.sub_question_inputs.pop()
|
||||
if self.api_format == "openai":
|
||||
# Append sub-question as a new user message
|
||||
sub_q_text = self.tokenizer.decode(list(sub_q.prompt))
|
||||
self.client_records[client_id]["history"].append(
|
||||
{"role": "user", "content": sub_q_text}
|
||||
)
|
||||
new_req = (
|
||||
client_id,
|
||||
gen_payload_openai(
|
||||
self.client_records[client_id]["history"],
|
||||
sub_q.output_len,
|
||||
self.model_path,
|
||||
),
|
||||
)
|
||||
else:
|
||||
# Append sub-question token ids to client's history
|
||||
sub_q_ids = list(sub_q.prompt)
|
||||
self.client_records[client_id]["history"].extend(sub_q_ids)
|
||||
new_req = (
|
||||
client_id,
|
||||
gen_payload(
|
||||
self.client_records[client_id]["history"],
|
||||
sub_q.output_len,
|
||||
self.lora_path,
|
||||
),
|
||||
)
|
||||
if self.enable_round_barrier:
|
||||
next_round_reqs.append(new_req)
|
||||
else:
|
||||
self.ready_queue.append(new_req)
|
||||
|
||||
# Barrier logic: release next round when all clients for
|
||||
# current barrier round have completed
|
||||
if (
|
||||
self.enable_round_barrier
|
||||
and current_barrier_round < self.max_rounds
|
||||
):
|
||||
barrier_round_completed += 1
|
||||
expected = self.clients_per_round[current_barrier_round]
|
||||
if barrier_round_completed == expected:
|
||||
print(
|
||||
f"\n Barrier: round {current_barrier_round} complete "
|
||||
f"({expected} clients), releasing {len(next_round_reqs)} "
|
||||
f"requests for round {current_barrier_round + 1}"
|
||||
)
|
||||
self._send_heartbeat(input_len=100, output_len=100)
|
||||
time.sleep(10)
|
||||
for req in next_round_reqs:
|
||||
self.ready_queue.append(req)
|
||||
next_round_reqs = []
|
||||
current_barrier_round += 1
|
||||
barrier_round_completed = 0
|
||||
except queue.Empty:
|
||||
if self.pbar.n == self.pbar.total:
|
||||
break
|
||||
except ValueError as e:
|
||||
print(f"Error processing response for client {client_id}: {e}")
|
||||
continue
|
||||
|
||||
def _send_heartbeat(self, input_len=100, output_len=20):
|
||||
"""Send a small heartbeat request to the server."""
|
||||
heartbeat_input = [1] * input_len
|
||||
payload = gen_payload(heartbeat_input, output_len, self.lora_path)
|
||||
try:
|
||||
requests.post(self.url, json=payload, timeout=30)
|
||||
except Exception as e:
|
||||
print(f"Heartbeat request failed: {e}")
|
||||
|
||||
def run(self):
|
||||
request_thread = threading.Thread(target=self.request_sender, daemon=True)
|
||||
response_thread = threading.Thread(target=self.response_handler, daemon=True)
|
||||
|
||||
self.start_time = time.perf_counter()
|
||||
request_thread.start()
|
||||
response_thread.start()
|
||||
|
||||
request_thread.join()
|
||||
response_thread.join()
|
||||
self.pbar.close()
|
||||
|
||||
duration = self.finished_time - self.start_time
|
||||
sorted_ttft = sorted(self.performance_metrics["ttft"])
|
||||
sorted_latency = sorted(self.performance_metrics["latency"])
|
||||
sorted_itl = sorted(self.performance_metrics["itl"])
|
||||
sorted_prompt_len = sorted(self.performance_metrics["prompt_len"])
|
||||
sorted_output_len = sorted(self.performance_metrics["generated_len"])
|
||||
|
||||
def percentile(sorted_vals, q):
|
||||
if not sorted_vals:
|
||||
return 0.0
|
||||
idx = int(q * len(sorted_vals))
|
||||
if idx >= len(sorted_vals):
|
||||
idx = len(sorted_vals) - 1
|
||||
return sorted_vals[idx]
|
||||
|
||||
def max_or_zero(sorted_vals):
|
||||
return sorted_vals[-1] if sorted_vals else 0.0
|
||||
|
||||
performance_data = {
|
||||
"summary": {
|
||||
"total_requests": len(self.performance_metrics["ttft"]),
|
||||
"request_rate": self.request_rate,
|
||||
"average_prompt_len": (
|
||||
sum(self.performance_metrics["prompt_len"])
|
||||
/ len(self.performance_metrics["prompt_len"])
|
||||
if self.performance_metrics["prompt_len"]
|
||||
else 0.0
|
||||
),
|
||||
"average_output_len": (
|
||||
sum(self.performance_metrics["generated_len"])
|
||||
/ len(self.performance_metrics["generated_len"])
|
||||
if self.performance_metrics["generated_len"]
|
||||
else 0.0
|
||||
),
|
||||
"p90_prompt_len": percentile(sorted_prompt_len, 0.9),
|
||||
"p99_prompt_len": percentile(sorted_prompt_len, 0.99),
|
||||
"p90_output_len": percentile(sorted_output_len, 0.9),
|
||||
"p99_output_len": percentile(sorted_output_len, 0.99),
|
||||
"average_ttft": sum(self.performance_metrics["ttft"])
|
||||
/ len(self.performance_metrics["ttft"]),
|
||||
"p90_ttft": percentile(sorted_ttft, 0.9),
|
||||
"p99_ttft": percentile(sorted_ttft, 0.99),
|
||||
"median_ttft": percentile(sorted_ttft, 0.5),
|
||||
"max_ttft": max_or_zero(sorted_ttft),
|
||||
"average_itl": (
|
||||
sum(self.performance_metrics["itl"])
|
||||
/ len(self.performance_metrics["itl"])
|
||||
if self.performance_metrics["itl"]
|
||||
else 0.0
|
||||
),
|
||||
"p90_itl": percentile(sorted_itl, 0.9),
|
||||
"p99_itl": percentile(sorted_itl, 0.99),
|
||||
"median_itl": percentile(sorted_itl, 0.5),
|
||||
"max_itl": max_or_zero(sorted_itl),
|
||||
"average_latency": sum(self.performance_metrics["latency"])
|
||||
/ len(self.performance_metrics["latency"]),
|
||||
"p90_latency": percentile(sorted_latency, 0.9),
|
||||
"p99_latency": percentile(sorted_latency, 0.99),
|
||||
"median_latency": percentile(sorted_latency, 0.5),
|
||||
"max_latency": max_or_zero(sorted_latency),
|
||||
"input_token_throughput": sum(self.performance_metrics["prompt_len"])
|
||||
/ duration,
|
||||
"output_token_throughput": sum(
|
||||
self.performance_metrics["generated_len"]
|
||||
)
|
||||
/ duration,
|
||||
"throughput": self.pbar.total / duration,
|
||||
"cache_hit_rate": (
|
||||
0
|
||||
if sum(self.performance_metrics["prompt_len"]) == 0
|
||||
else sum(self.performance_metrics["cached_tokens"])
|
||||
/ sum(self.performance_metrics["prompt_len"])
|
||||
),
|
||||
},
|
||||
}
|
||||
if self.enable_round_barrier:
|
||||
performance_data["round"] = {}
|
||||
for round_num in range(self.num_rounds):
|
||||
round_key = f"round_{round_num}"
|
||||
round_metrics = self.performance_metrics[round_key]
|
||||
performance_data["round"][round_key] = {
|
||||
"average_ttft": (
|
||||
sum(round_metrics["ttft"]) / len(round_metrics["ttft"])
|
||||
if round_metrics["ttft"]
|
||||
else 0
|
||||
),
|
||||
"cache_hit_rate": (
|
||||
0
|
||||
if sum(round_metrics["prompt_len"]) == 0
|
||||
else sum(round_metrics["cached_tokens"])
|
||||
/ sum(round_metrics["prompt_len"])
|
||||
),
|
||||
"request_count": len(round_metrics["ttft"]),
|
||||
}
|
||||
print("All requests completed")
|
||||
print("Performance metrics summary:")
|
||||
print(
|
||||
f" Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second"
|
||||
)
|
||||
print(
|
||||
f" Average Prompt Length: {performance_data['summary']['average_prompt_len']:.2f} tokens"
|
||||
)
|
||||
print(
|
||||
f" Average Output Length: {performance_data['summary']['average_output_len']:.2f} tokens"
|
||||
)
|
||||
print(
|
||||
f" P90 Prompt Length: {performance_data['summary']['p90_prompt_len']:.0f} tokens"
|
||||
)
|
||||
print(
|
||||
f" P99 Prompt Length: {performance_data['summary']['p99_prompt_len']:.0f} tokens"
|
||||
)
|
||||
print(
|
||||
f" P90 Output Length: {performance_data['summary']['p90_output_len']:.0f} tokens"
|
||||
)
|
||||
print(
|
||||
f" P99 Output Length: {performance_data['summary']['p99_output_len']:.0f} tokens"
|
||||
)
|
||||
print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}")
|
||||
print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}")
|
||||
print(f" P99 TTFT: {performance_data['summary']['p99_ttft']:.2f}")
|
||||
print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}")
|
||||
print(f" Max TTFT: {performance_data['summary']['max_ttft']:.2f}")
|
||||
print(f" Average ITL: {performance_data['summary']['average_itl']:.4f}")
|
||||
print(f" P90 ITL: {performance_data['summary']['p90_itl']:.4f}")
|
||||
print(f" P99 ITL: {performance_data['summary']['p99_itl']:.4f}")
|
||||
print(f" Median ITL: {performance_data['summary']['median_itl']:.4f}")
|
||||
print(f" Max ITL: {performance_data['summary']['max_itl']:.4f}")
|
||||
print(
|
||||
f" Average latency: {performance_data['summary']['average_latency']:.2f}"
|
||||
)
|
||||
print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
|
||||
print(f" P99 latency: {performance_data['summary']['p99_latency']:.2f}")
|
||||
print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
|
||||
print(f" Max latency: {performance_data['summary']['max_latency']:.2f}")
|
||||
print(
|
||||
f" Input token throughput: {performance_data['summary']['input_token_throughput']:.2f} tokens per second"
|
||||
)
|
||||
print(
|
||||
f" Output token throughput: {performance_data['summary']['output_token_throughput']:.2f} tokens per second"
|
||||
)
|
||||
print(
|
||||
f" Request Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
|
||||
)
|
||||
print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
|
||||
|
||||
if self.enable_round_barrier:
|
||||
# Print round-basedsummary
|
||||
print("Per-round metrics:")
|
||||
if "round" in performance_data:
|
||||
for round_num in range(self.num_rounds):
|
||||
round_key = f"round_{round_num}"
|
||||
if round_key in performance_data["round"]:
|
||||
round_data = performance_data["round"][round_key]
|
||||
avg_ttft = round_data["average_ttft"]
|
||||
cache_hit_rate = round_data["cache_hit_rate"]
|
||||
request_count = round_data["request_count"]
|
||||
clients_in_round = self.clients_per_round[round_num]
|
||||
print(
|
||||
f" Round {round_num}: Average TTFT = {avg_ttft:.2f}s, "
|
||||
f"Cache Hit Rate = {cache_hit_rate:.6f} "
|
||||
f"({request_count} requests, "
|
||||
f"{clients_in_round} clients)"
|
||||
)
|
||||
else:
|
||||
print(f" Round {round_num}: No requests completed")
|
||||
|
||||
return performance_data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
if args.disable_auto_run:
|
||||
print("Running with specified request rate...")
|
||||
request_rates = [args.request_rate]
|
||||
else:
|
||||
print("Auto-running with different request rates...")
|
||||
request_rates = [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
|
||||
|
||||
for rate in request_rates:
|
||||
args.request_rate = rate
|
||||
requests.post(flush_cache_url)
|
||||
time.sleep(1)
|
||||
performance_data = WorkloadGenerator(args).run()
|
||||
log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
|
||||
1029
third_party/sglang/benchmark/hicache/bench_serving.py
vendored
Normal file
1029
third_party/sglang/benchmark/hicache/bench_serving.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
583
third_party/sglang/benchmark/hicache/data_processing.py
vendored
Normal file
583
third_party/sglang/benchmark/hicache/data_processing.py
vendored
Normal file
@@ -0,0 +1,583 @@
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from nextqa import NExTQALoader
|
||||
|
||||
# from nextqa.video import , VideoPrompt
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from sglang.benchmark.datasets.common import (
|
||||
SHAREGPT_FILENAME,
|
||||
SHAREGPT_REPO_ID,
|
||||
gen_prompt,
|
||||
)
|
||||
from sglang.benchmark.datasets.generated_shared_prefix import get_gen_prefix_cache_path
|
||||
from sglang.benchmark.utils import download_and_cache_hf_file
|
||||
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
|
||||
from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart
|
||||
from sglang.utils import encode_video_base64
|
||||
|
||||
# type of content fields, can be only prompts or with images/videos
|
||||
MsgContent = Union[str, List[ChatCompletionMessageContentPart]]
|
||||
|
||||
# A list of all the conversations. Each conversation is a list of
|
||||
# tuples. If multiturn is not enabled, the length of list is 1,
|
||||
# containing only the first Q&A pair.
|
||||
# For the shared prefix workload (synthetic, loogle, nextqa), it
|
||||
# is a list of conversations sharing the same prefix (synthetic,
|
||||
# doc, video)
|
||||
SampleOutput = List[List[Tuple[MsgContent, int, int]]]
|
||||
|
||||
|
||||
def common_filter_chat(
|
||||
num_requests: int,
|
||||
new_dataset: List,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
min_prompt_len: Optional[int],
|
||||
min_output_len: Optional[int],
|
||||
max_prompt_len: Optional[int],
|
||||
max_output_len: Optional[int],
|
||||
fixed_output_len: Optional[int],
|
||||
) -> SampleOutput:
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: SampleOutput = []
|
||||
l = 0
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
while l < num_requests:
|
||||
for i in range(len(new_dataset)):
|
||||
if l == num_requests:
|
||||
break
|
||||
processed = []
|
||||
for j in new_dataset[i]:
|
||||
# Tokenize the prompts and completions.
|
||||
prompt = j[0]
|
||||
prompt_token_ids = tokenizer.encode(prompt)
|
||||
prompt_len = len(prompt_token_ids)
|
||||
|
||||
completion = j[1]
|
||||
completion_token_ids = tokenizer.encode(completion)
|
||||
output_len = (
|
||||
len(completion_token_ids)
|
||||
if fixed_output_len is None
|
||||
else fixed_output_len
|
||||
)
|
||||
if (
|
||||
min_prompt_len is not None
|
||||
and prompt_len < min_prompt_len
|
||||
or min_output_len is not None
|
||||
and output_len < min_output_len
|
||||
or max_prompt_len is not None
|
||||
and prompt_len > max_prompt_len
|
||||
or max_output_len is not None
|
||||
and output_len > max_output_len
|
||||
):
|
||||
# Prune too short sequences.
|
||||
continue
|
||||
input_tokens += prompt_len
|
||||
output_tokens += output_len
|
||||
processed.append((prompt, prompt_len, output_len))
|
||||
if len(processed) != 0:
|
||||
filtered_dataset.append(processed)
|
||||
l += 1
|
||||
|
||||
print(f"#Input tokens: {input_tokens}")
|
||||
print(f"#Output tokens: {output_tokens}")
|
||||
return filtered_dataset
|
||||
|
||||
|
||||
def sample_sharegpt_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
disable_shuffle: bool = False,
|
||||
enable_multiturn: bool = True,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> SampleOutput:
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
# Download sharegpt if necessary
|
||||
if not os.path.isfile(dataset_path):
|
||||
dataset_path = download_and_cache_hf_file(
|
||||
repo_id=SHAREGPT_REPO_ID,
|
||||
filename=SHAREGPT_FILENAME,
|
||||
)
|
||||
|
||||
# Load the dataset.
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
|
||||
# Keep one conversation in one list
|
||||
new_dataset = []
|
||||
for data in dataset:
|
||||
if len(data["conversations"]) % 2 != 0:
|
||||
continue
|
||||
if data["conversations"][0]["from"] != "human":
|
||||
continue
|
||||
chat = []
|
||||
total_len = 2
|
||||
if enable_multiturn:
|
||||
total_len = len(data["conversations"])
|
||||
for i in range(0, total_len, 2):
|
||||
# One user One Assistant
|
||||
chat.append(
|
||||
(
|
||||
data["conversations"][i]["value"],
|
||||
data["conversations"][i + 1]["value"],
|
||||
)
|
||||
)
|
||||
new_dataset.append(chat)
|
||||
|
||||
if not disable_shuffle:
|
||||
# Shuffle the dataset.
|
||||
random.shuffle(new_dataset)
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: SampleOutput = common_filter_chat(
|
||||
num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len
|
||||
)
|
||||
return filtered_dataset
|
||||
|
||||
|
||||
def sample_ultrachat_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
disable_shuffle: bool = False,
|
||||
enable_multiturn: bool = True,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> SampleOutput:
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
# Load the dataset
|
||||
dataset = []
|
||||
with open(dataset_path) as f:
|
||||
while True:
|
||||
line = f.readline()
|
||||
if not line:
|
||||
break
|
||||
dataset.append(json.loads(line))
|
||||
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["data"]) >= 2]
|
||||
|
||||
# Keep one conversation in one list
|
||||
new_dataset = []
|
||||
for data in dataset:
|
||||
if len(data["data"]) % 2 != 0:
|
||||
continue
|
||||
chat = []
|
||||
total_len = 2
|
||||
if enable_multiturn:
|
||||
total_len = len(data["data"])
|
||||
for i in range(0, total_len, 2):
|
||||
# One user One Assistant
|
||||
chat.append((data["data"][i], data["data"][i + 1]))
|
||||
new_dataset.append(chat)
|
||||
|
||||
# Shuffle the dataset.
|
||||
if not disable_shuffle:
|
||||
random.shuffle(new_dataset)
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: SampleOutput = common_filter_chat(
|
||||
num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len
|
||||
)
|
||||
return filtered_dataset
|
||||
|
||||
|
||||
def sample_loogle_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
disable_shuffle: bool = False,
|
||||
enable_multiturn: bool = True,
|
||||
enable_shared_prefix: bool = False,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> SampleOutput:
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
# Load the dataset
|
||||
dataset = []
|
||||
with open(dataset_path) as f:
|
||||
while True:
|
||||
line = f.readline()
|
||||
if not line:
|
||||
break
|
||||
dataset.append(json.loads(line))
|
||||
|
||||
# Keep one conversation in one list
|
||||
new_dataset = []
|
||||
# TODO: Add shared prefix support for loogle
|
||||
# NOTE: Now we preprocess it only for chat
|
||||
for data in dataset:
|
||||
chat = []
|
||||
if (
|
||||
"qa_pairs" not in data
|
||||
or data["qa_pairs"] == "none"
|
||||
or len(data["qa_pairs"]) == 0
|
||||
):
|
||||
# If Q is none (for summarization),
|
||||
# We add a question for summarization
|
||||
# And keep the summary up to 1024 words
|
||||
chat.append(
|
||||
(
|
||||
"Input: "
|
||||
+ data["input"]
|
||||
+ " Question: "
|
||||
+ "Please summarize the input",
|
||||
data["input"][:1024],
|
||||
)
|
||||
)
|
||||
new_dataset.append(chat)
|
||||
else:
|
||||
qa_pairs = eval(data["qa_pairs"])
|
||||
for i, qa in enumerate(qa_pairs):
|
||||
if i == 0 or enable_shared_prefix:
|
||||
# Combine input with the first Q
|
||||
chat.append(
|
||||
("Input: " + data["input"] + " Question: " + qa["Q"], qa["A"])
|
||||
)
|
||||
elif enable_multiturn:
|
||||
chat.append((qa["Q"], qa["A"]))
|
||||
|
||||
new_dataset.append(chat)
|
||||
|
||||
# Shuffle the dataset.
|
||||
if not disable_shuffle:
|
||||
random.shuffle(new_dataset)
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: SampleOutput = common_filter_chat(
|
||||
num_requests, new_dataset, tokenizer, 4, None, None, None, fixed_output_len
|
||||
)
|
||||
return filtered_dataset
|
||||
|
||||
|
||||
def sample_nextqa_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
max_frames: int, # Specific for video
|
||||
model_path: str,
|
||||
disable_shuffle: bool = False,
|
||||
enable_multiturn: bool = True, # No multiturn support for now
|
||||
backend: str = "sglang-oai",
|
||||
chat_template_name: Optional[str] = None,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> SampleOutput:
|
||||
"""
|
||||
Example of messages:
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": base64_data}},
|
||||
{"type": "text", "text": video.prompt},
|
||||
],
|
||||
}
|
||||
"""
|
||||
|
||||
if fixed_output_len is None:
|
||||
fixed_output_len = 4096
|
||||
|
||||
# TODO: Check for multiturn
|
||||
dataset = NExTQALoader(video_dir=dataset_path, max_frames=max_frames)
|
||||
new_dataset = []
|
||||
for v in dataset:
|
||||
new_dataset.append(v)
|
||||
|
||||
if not disable_shuffle:
|
||||
random.shuffle(new_dataset)
|
||||
|
||||
# TODO: prompt len can get from server side
|
||||
filtered_dataset = []
|
||||
l = 0
|
||||
while l < num_requests:
|
||||
for i in range(len(new_dataset)):
|
||||
if l == num_requests:
|
||||
break
|
||||
|
||||
video = new_dataset[i]
|
||||
|
||||
# text prompt
|
||||
prompt = video.prompt
|
||||
|
||||
# NOTE: Chat Template is a must for video benchmark because we have to
|
||||
# add special image token for later expansion
|
||||
if backend == "sglang" or backend == "sglang-native":
|
||||
if "chat_template" in tokenizer.init_kwargs:
|
||||
chat_template = get_chat_template(tokenizer.get_chat_template())
|
||||
elif chat_template_name is not None:
|
||||
chat_template = get_chat_template(chat_template_name)
|
||||
else:
|
||||
chat_template = get_chat_template_by_model_path(model_path)
|
||||
prompt = chat_template.image_token + prompt
|
||||
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = fixed_output_len # max output len, not real output len
|
||||
|
||||
# video input
|
||||
base64_data = encode_video_base64(video.path, video.num_frames)
|
||||
|
||||
# NOTE: This will be replaced by the expanded length from the server
|
||||
prompt_len += video.num_frames
|
||||
|
||||
# add to content
|
||||
content = [
|
||||
{"type": "image_url", "image_url": {"url": base64_data}},
|
||||
{"type": "text", "text": prompt},
|
||||
]
|
||||
|
||||
filtered_dataset.append([(content, prompt_len, output_len)])
|
||||
l += 1
|
||||
return filtered_dataset
|
||||
|
||||
|
||||
def sample_random_requests(
|
||||
input_len: int,
|
||||
output_len: int,
|
||||
num_prompts: int,
|
||||
range_ratio: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dataset_path: str,
|
||||
disable_shuffle: bool = False,
|
||||
) -> SampleOutput:
|
||||
|
||||
input_lens = np.random.randint(
|
||||
max(int(input_len * range_ratio), 1),
|
||||
input_len + 1,
|
||||
size=num_prompts,
|
||||
)
|
||||
output_lens = np.random.randint(
|
||||
int(output_len * range_ratio),
|
||||
output_len + 1,
|
||||
size=num_prompts,
|
||||
)
|
||||
|
||||
if True:
|
||||
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
|
||||
|
||||
# Download sharegpt if necessary
|
||||
if not os.path.isfile(dataset_path):
|
||||
dataset_path = download_and_cache_hf_file(
|
||||
repo_id=SHAREGPT_REPO_ID,
|
||||
filename=SHAREGPT_FILENAME,
|
||||
)
|
||||
|
||||
# Load the dataset.
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [
|
||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||
for data in dataset
|
||||
]
|
||||
|
||||
if not disable_shuffle:
|
||||
# Shuffle the dataset.
|
||||
random.shuffle(dataset)
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
input_requests: SampleOutput = []
|
||||
for data in dataset:
|
||||
i = len(input_requests)
|
||||
if i == num_prompts:
|
||||
break
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt = data[0]
|
||||
prompt_token_ids = tokenizer.encode(prompt)
|
||||
prompt_len = len(prompt_token_ids)
|
||||
|
||||
# Skip empty prompt
|
||||
if prompt_len == 0:
|
||||
continue
|
||||
|
||||
if prompt_len > input_lens[i]:
|
||||
input_ids = prompt_token_ids[: input_lens[i]]
|
||||
else:
|
||||
ratio = (input_lens[i] + prompt_len - 1) // prompt_len
|
||||
input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
|
||||
prompt = tokenizer.decode(input_ids)
|
||||
input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))])
|
||||
else:
|
||||
# Sample token ids from random integers. This can cause some NaN issues.
|
||||
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
|
||||
input_requests = []
|
||||
for i in range(num_prompts):
|
||||
prompt = tokenizer.decode(
|
||||
[
|
||||
(offsets[i] + i + j) % tokenizer.vocab_size
|
||||
for j in range(input_lens[i])
|
||||
]
|
||||
)
|
||||
input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))])
|
||||
|
||||
print(f"#Input tokens: {np.sum(input_lens)}")
|
||||
print(f"#Output tokens: {np.sum(output_lens)}")
|
||||
return input_requests
|
||||
|
||||
|
||||
def sample_generated_shared_prefix_requests(
|
||||
num_groups: int,
|
||||
prompts_per_group: int,
|
||||
system_prompt_len: int,
|
||||
question_len: int,
|
||||
output_len: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
args,
|
||||
disable_shuffle: bool = False,
|
||||
) -> SampleOutput:
|
||||
"""Generate benchmark requests with shared system prompts using random tokens and caching."""
|
||||
cache_path = get_gen_prefix_cache_path(
|
||||
args.seed,
|
||||
num_groups,
|
||||
prompts_per_group,
|
||||
system_prompt_len,
|
||||
question_len,
|
||||
output_len,
|
||||
tokenizer,
|
||||
)
|
||||
|
||||
# Try to load from cache first
|
||||
if cache_path.exists():
|
||||
print(f"\nLoading cached generated input data from {cache_path}")
|
||||
with open(cache_path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
print("\nGenerating new input data...")
|
||||
|
||||
# Generate system prompts for each group
|
||||
system_prompts = []
|
||||
for _ in range(num_groups):
|
||||
system_prompt = gen_prompt(tokenizer, system_prompt_len)
|
||||
system_prompts.append(system_prompt)
|
||||
|
||||
# Generate questions
|
||||
questions = []
|
||||
for _ in range(num_groups * prompts_per_group):
|
||||
question = gen_prompt(tokenizer, question_len)
|
||||
questions.append(question)
|
||||
|
||||
# Combine system prompts with questions
|
||||
input_requests = []
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
|
||||
system_prompt = system_prompts[group_idx]
|
||||
input_requests.append([])
|
||||
for prompt_idx in tqdm(
|
||||
range(prompts_per_group), desc="Generating questions", leave=False
|
||||
):
|
||||
question = questions[group_idx * prompts_per_group + prompt_idx]
|
||||
full_prompt = f"{system_prompt}\n\n{question}"
|
||||
prompt_len = len(tokenizer.encode(full_prompt))
|
||||
input_requests[-1].append((full_prompt, prompt_len, output_len))
|
||||
total_input_tokens += prompt_len
|
||||
total_output_tokens += output_len
|
||||
|
||||
if not disable_shuffle:
|
||||
# Shuffle questions
|
||||
random.shuffle(input_requests)
|
||||
|
||||
# Print statistics
|
||||
print(f"\nGenerated shared prefix dataset statistics:")
|
||||
print(f"Number of groups: {num_groups}")
|
||||
print(f"Prompts per group: {prompts_per_group}")
|
||||
print(f"Total prompts: {len(input_requests) * prompts_per_group}")
|
||||
print(f"Total input tokens: {total_input_tokens}")
|
||||
print(f"Total output tokens: {total_output_tokens}")
|
||||
print(
|
||||
f"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens"
|
||||
)
|
||||
print(
|
||||
f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
|
||||
)
|
||||
|
||||
# Save to cache
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Caching generated input data to {cache_path}")
|
||||
with open(cache_path, "wb") as f:
|
||||
pickle.dump(input_requests, f)
|
||||
|
||||
return input_requests
|
||||
|
||||
|
||||
def get_dataset(args, tokenizer):
|
||||
if args.dataset_name == "sharegpt":
|
||||
input_requests = sample_sharegpt_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
disable_shuffle=args.disable_shuffle,
|
||||
enable_multiturn=args.enable_multiturn,
|
||||
fixed_output_len=args.fixed_output_len,
|
||||
)
|
||||
elif args.dataset_name == "ultrachat":
|
||||
input_requests = sample_ultrachat_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
disable_shuffle=args.disable_shuffle,
|
||||
enable_multiturn=args.enable_multiturn,
|
||||
fixed_output_len=args.fixed_output_len,
|
||||
)
|
||||
elif args.dataset_name == "loogle":
|
||||
input_requests = sample_loogle_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
disable_shuffle=args.disable_shuffle,
|
||||
enable_multiturn=args.enable_multiturn,
|
||||
enable_shared_prefix=args.enable_shared_prefix,
|
||||
fixed_output_len=args.fixed_output_len,
|
||||
)
|
||||
elif args.dataset_name == "nextqa":
|
||||
input_requests = sample_nextqa_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
max_frames=args.max_frames,
|
||||
model_path=args.model,
|
||||
disable_shuffle=args.disable_shuffle,
|
||||
enable_multiturn=args.enable_multiturn,
|
||||
backend=args.backend,
|
||||
chat_template_name=args.chat_template,
|
||||
fixed_output_len=args.fixed_output_len,
|
||||
)
|
||||
elif args.dataset_name == "random":
|
||||
input_requests = sample_random_requests(
|
||||
input_len=args.random_input_len,
|
||||
output_len=args.random_output_len,
|
||||
num_prompts=args.num_prompts,
|
||||
range_ratio=args.random_range_ratio,
|
||||
tokenizer=tokenizer,
|
||||
dataset_path=args.dataset_path,
|
||||
)
|
||||
elif args.dataset_name == "generated-shared-prefix":
|
||||
input_requests = sample_generated_shared_prefix_requests(
|
||||
num_groups=args.gsp_num_groups,
|
||||
prompts_per_group=args.gsp_prompts_per_group,
|
||||
system_prompt_len=args.gsp_system_prompt_len,
|
||||
question_len=args.gsp_question_len,
|
||||
output_len=args.gsp_output_len,
|
||||
args=args,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||
return input_requests
|
||||
66
third_party/sglang/benchmark/hicache/download.sh
vendored
Executable file
66
third_party/sglang/benchmark/hicache/download.sh
vendored
Executable file
@@ -0,0 +1,66 @@
|
||||
#!/usr/bin/bash
|
||||
|
||||
# The usage function
|
||||
usage() {
|
||||
echo "Usage: $0 {sharegpt|ultragpt|loogle|nextqa|all}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# The download function
|
||||
download() {
|
||||
case "$1" in
|
||||
sharegpt)
|
||||
echo $1
|
||||
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
;;
|
||||
ultragpt)
|
||||
echo $1
|
||||
# Questions about the world
|
||||
wget https://cloud.tsinghua.edu.cn/seafhttp/files/be1d7b87-22ca-449e-a6a7-c61d1ea7e010/ultrachat_release_230407.json
|
||||
# Writing and Creation
|
||||
wget https://cloud.tsinghua.edu.cn/seafhttp/files/61742d2a-25e2-4d08-b2b9-15f47ae50ace/ultrachat_material_release_230417.json
|
||||
wget https://cloud.tsinghua.edu.cn/seafhttp/files/f71f6aa6-d346-4b16-85b7-8502efa3d608/ultrachat_material_release_230412.json
|
||||
# External materials
|
||||
wget https://cloud.tsinghua.edu.cn/seafhttp/files/42d22e28-e899-4975-a70f-5eda163e265d/ultrachat_existent_material_release_230420.json.gz
|
||||
gunzip ultrachat_existent_material_release_230420.json.gz
|
||||
;;
|
||||
loogle)
|
||||
echo $1
|
||||
git lfs install
|
||||
git clone git@hf.co:datasets/bigainlco/LooGLE
|
||||
unzip LooGLE/data.zip
|
||||
;;
|
||||
nextqa)
|
||||
echo $1
|
||||
git lfs install
|
||||
git clone https://huggingface.co/datasets/lmms-lab/NExTQA
|
||||
unzip NExTQA/videos.zip
|
||||
;;
|
||||
*)
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
# Arg check
|
||||
if [ "$#" -ne 1 ]; then
|
||||
usage
|
||||
fi
|
||||
|
||||
# Invoke
|
||||
|
||||
case "$1" in
|
||||
sharegpt|ultragpt|loogle|nextqa)
|
||||
download "$1"
|
||||
;;
|
||||
all)
|
||||
download sharegpt
|
||||
download ultragpt
|
||||
download loogle
|
||||
download nextqa
|
||||
;;
|
||||
*)
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
159
third_party/sglang/benchmark/hicache/nextqa.py
vendored
Normal file
159
third_party/sglang/benchmark/hicache/nextqa.py
vendored
Normal file
@@ -0,0 +1,159 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import av
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
def find_video_files(video_dir) -> List[str]:
|
||||
if os.path.isfile(video_dir):
|
||||
return [video_dir]
|
||||
|
||||
video_files = []
|
||||
for root, dirs, files in os.walk(video_dir):
|
||||
for file in files:
|
||||
if file.endswith((".mp4", ".avi", ".mov")):
|
||||
video_files.append(os.path.join(root, file))
|
||||
# if file is dir
|
||||
elif os.path.isdir(file):
|
||||
video_files.extend(find_video_files(file))
|
||||
return video_files
|
||||
|
||||
|
||||
def video_frames(video_path, max_frames) -> int:
|
||||
container = av.open(video_path)
|
||||
total_frames = container.streams.video[0].frames
|
||||
return min(total_frames, max_frames)
|
||||
|
||||
|
||||
class Video:
|
||||
def __init__(self, video_path, num_frames):
|
||||
self.path = video_path
|
||||
self.num_frames = num_frames
|
||||
|
||||
def __str__(self):
|
||||
return f"Video({self.path}, {self.num_frames})"
|
||||
|
||||
def __iter__(self):
|
||||
return iter((self.path, self.num_frames))
|
||||
|
||||
|
||||
class VideoPrompt(Video):
|
||||
def __init__(self, video_path, num_frames, prompt):
|
||||
super().__init__(video_path, num_frames)
|
||||
self.prompt = prompt
|
||||
|
||||
def __str__(self):
|
||||
return f"VideoPrompt({self.path}, {self.num_frames}, {self.prompt})"
|
||||
|
||||
def __iter__(self):
|
||||
return iter((self.path, self.num_frames, self.prompt))
|
||||
|
||||
|
||||
class VideoLoader:
|
||||
pass
|
||||
|
||||
|
||||
class VideoFileLoader(VideoLoader):
|
||||
"""
|
||||
Load all the videos in a directory
|
||||
"""
|
||||
|
||||
def __init__(self, video_dir, batch_size=1, max_frames=sys.maxsize):
|
||||
super().__init__()
|
||||
self.video_dir = video_dir
|
||||
self.video_files = find_video_files(video_dir)
|
||||
self.batch_size = batch_size
|
||||
self.max_frames = max_frames
|
||||
print(f"batch_size: {batch_size}, max_frames: {max_frames}")
|
||||
|
||||
def __iter__(self): # (file, number of frames)
|
||||
if self.batch_size == 1:
|
||||
for video_file in self.video_files:
|
||||
yield Video(video_file, video_frames(video_file, self.max_frames))
|
||||
else:
|
||||
batch = []
|
||||
for video_file in self.video_files:
|
||||
video = Video(video_file, video_frames(video_file, self.max_frames))
|
||||
batch.append(video)
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
|
||||
class NExTQALoader(VideoLoader):
|
||||
"""
|
||||
Load vdideos and prompts from NExT dataset
|
||||
set: train, test or validation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, video_dir, batch_size=1, max_frames=sys.maxsize, dset="test", task="OE"
|
||||
):
|
||||
"""
|
||||
task: 'MV' or 'OE'
|
||||
"""
|
||||
super().__init__()
|
||||
self.task = task
|
||||
print(f"Loading the {dset} data of {task} from lmms-lab/NExTQA")
|
||||
self.ds = load_dataset("lmms-lab/NExTQA", task)
|
||||
self.ds = self.ds[dset]
|
||||
|
||||
# self.n = ds.num_rows
|
||||
self.video_dir = video_dir
|
||||
self.video_files = find_video_files(video_dir)
|
||||
self.video_to_path = dict()
|
||||
for video_file in self.video_files:
|
||||
video_id = video_file.split("/")[-1].split(".")[0]
|
||||
self.video_to_path[video_id] = video_file
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.max_frames = max_frames
|
||||
|
||||
def get_video_prompt(self, entry, max_frames) -> VideoPrompt:
|
||||
# Get video
|
||||
video_id = entry["video"]
|
||||
video_path = self.video_to_path[video_id]
|
||||
assert os.path.exists(video_path), f"Video not found: {video_path}"
|
||||
num_frames = min(entry["frame_count"], max_frames)
|
||||
video = Video(video_path, num_frames)
|
||||
prompt = entry["question"] + "?"
|
||||
if self.task == "MC": # add choices
|
||||
prompt += f' a0: {entry["a0"]}, a1: {entry["a1"]}, a2: {entry["a2"]}, a3: {entry["a3"]}'
|
||||
return VideoPrompt(video_path, num_frames, prompt)
|
||||
|
||||
def __iter__(self):
|
||||
if self.batch_size == 1:
|
||||
for entry in self.ds:
|
||||
yield self.get_video_prompt(entry, self.max_frames)
|
||||
else:
|
||||
batch = []
|
||||
for entry in self.ds:
|
||||
video = self.get_video_prompt(entry, self.max_frames)
|
||||
batch.append(video)
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
|
||||
# main
|
||||
if __name__ == "__main__":
|
||||
video_dir = "./videos"
|
||||
# video_loader = VideoFileLoader(video_dir, batch_size=16)
|
||||
# for batch in video_loader:
|
||||
# print(f"Number of videos in batch: {len(batch)}")
|
||||
# for video_file, num_frames in batch:
|
||||
# print(f"Video: {video_file} number of frames: {num_frames}")
|
||||
|
||||
video_loader = NExTQALoader(video_dir, batch_size=16, dset="test", task="OE")
|
||||
for batch in video_loader:
|
||||
print(f"Number of videos in batch: {len(batch)}")
|
||||
for video_file, num_frames, prompt in batch:
|
||||
print(
|
||||
f"Video: {video_file} number of frames: {num_frames}, prompt: {prompt}"
|
||||
)
|
||||
# break
|
||||
# for video_file, prompt in batch:
|
||||
# print(f"Video: {video_file} prompt: {prompt}")
|
||||
# break
|
||||
248
third_party/sglang/benchmark/hicache/perf.py
vendored
Normal file
248
third_party/sglang/benchmark/hicache/perf.py
vendored
Normal file
@@ -0,0 +1,248 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, NamedTuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def jit_hicache_impl(
|
||||
k_cache_dst: torch.Tensor,
|
||||
v_cache_dst: torch.Tensor,
|
||||
indices_dst: torch.Tensor,
|
||||
k_cache_src: torch.Tensor,
|
||||
v_cache_src: torch.Tensor,
|
||||
indices_src: torch.Tensor,
|
||||
item_bytes: int,
|
||||
block_quota: int,
|
||||
) -> None:
|
||||
from sglang.jit_kernel.hicache import transfer_hicache_one_layer
|
||||
|
||||
_ = item_bytes
|
||||
|
||||
transfer_hicache_one_layer(
|
||||
k_cache_dst=k_cache_dst,
|
||||
v_cache_dst=v_cache_dst,
|
||||
indices_dst=indices_dst,
|
||||
k_cache_src=k_cache_src,
|
||||
v_cache_src=v_cache_src,
|
||||
indices_src=indices_src,
|
||||
block_quota=block_quota,
|
||||
)
|
||||
|
||||
|
||||
def ref_hicache_impl(
|
||||
k_cache_dst: torch.Tensor,
|
||||
v_cache_dst: torch.Tensor,
|
||||
indices_dst: torch.Tensor,
|
||||
k_cache_src: torch.Tensor,
|
||||
v_cache_src: torch.Tensor,
|
||||
indices_src: torch.Tensor,
|
||||
item_bytes: int,
|
||||
block_quota: int,
|
||||
) -> None:
|
||||
from sgl_kernel import transfer_kv_per_layer
|
||||
|
||||
transfer_kv_per_layer(
|
||||
src_k=k_cache_src,
|
||||
src_v=v_cache_src,
|
||||
dst_k=k_cache_dst,
|
||||
dst_v=v_cache_dst,
|
||||
src_indices=indices_src,
|
||||
dst_indices=indices_dst,
|
||||
item_size=item_bytes,
|
||||
block_quota=block_quota,
|
||||
)
|
||||
|
||||
|
||||
class HicacheBenchArgs(NamedTuple):
|
||||
cache_item_size: int
|
||||
dtype: torch.dtype
|
||||
block_quota: int
|
||||
|
||||
|
||||
def perf(f: Callable[[], Any], loop: int = 100) -> float:
|
||||
tic = torch.cuda.Event(enable_timing=True)
|
||||
toc = torch.cuda.Event(enable_timing=True)
|
||||
torch.cuda.synchronize()
|
||||
# warm up
|
||||
f()
|
||||
torch.cuda._sleep(10**8)
|
||||
tic.record()
|
||||
for _ in range(loop):
|
||||
f()
|
||||
toc.record()
|
||||
toc.synchronize()
|
||||
return tic.elapsed_time(toc) / loop
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_hicache_kernel(args: HicacheBenchArgs) -> None:
|
||||
CACHE_ITEM_SIZE, DTYPE, BLOCK_QUOTA = args
|
||||
|
||||
CUDA_CACHE_SIZE = 1024 * 1024
|
||||
HOST_CACHE_SIZE = CUDA_CACHE_SIZE * 2
|
||||
|
||||
cuda_cache = torch.randn(
|
||||
(2, CUDA_CACHE_SIZE, CACHE_ITEM_SIZE),
|
||||
dtype=DTYPE,
|
||||
device="cuda",
|
||||
)
|
||||
host_cache = torch.empty(
|
||||
(2, HOST_CACHE_SIZE, CACHE_ITEM_SIZE),
|
||||
dtype=DTYPE,
|
||||
device="cpu",
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
ITEM_BYTES = cuda_cache.element_size() * CACHE_ITEM_SIZE
|
||||
|
||||
def _gen_indices(size: int, bs: int) -> torch.Tensor:
|
||||
assert bs <= size
|
||||
result = (
|
||||
(torch.randperm(size, dtype=torch.int64, device="cuda")[:bs]).sort().values
|
||||
)
|
||||
if not (torch.all(result >= 0) and torch.all(result < size)):
|
||||
where = (result < 0) | (result >= size)
|
||||
place = where.nonzero(as_tuple=False)
|
||||
print("Invalid indices at positions:", place)
|
||||
print("Invalid indices values:", result[place])
|
||||
raise ValueError("Generated invalid indices")
|
||||
return result
|
||||
|
||||
def _calc_tput(dur: float) -> float:
|
||||
return (MEM / (1024**3)) / (dur / 1000) # GB/s
|
||||
|
||||
def _gain_str(aot_dur: float, jit_dur: float) -> str:
|
||||
gain = 100 * (aot_dur / jit_dur - 1)
|
||||
if gain >= 0:
|
||||
return f"+{gain:>6.2f}%"
|
||||
else:
|
||||
return f"-{-gain:>6.2f}%"
|
||||
|
||||
print(f"{CACHE_ITEM_SIZE = }, {DTYPE = }, {BLOCK_QUOTA = }")
|
||||
|
||||
def _fast_test_correctness(bs: int):
|
||||
src_indices = _gen_indices(CUDA_CACHE_SIZE, bs)
|
||||
dst_indices = _gen_indices(HOST_CACHE_SIZE, bs)
|
||||
host_cache_cuda = torch.randn_like(host_cache, device="cuda")
|
||||
host_cache.copy_(host_cache_cuda, non_blocking=True)
|
||||
|
||||
# copy from cuda to host
|
||||
jit_hicache_impl(
|
||||
k_cache_dst=host_cache[0],
|
||||
v_cache_dst=host_cache[1],
|
||||
indices_dst=dst_indices,
|
||||
k_cache_src=cuda_cache[0],
|
||||
v_cache_src=cuda_cache[1],
|
||||
indices_src=src_indices,
|
||||
item_bytes=ITEM_BYTES,
|
||||
block_quota=BLOCK_QUOTA,
|
||||
)
|
||||
dst_indices = dst_indices.cpu()
|
||||
assert torch.all(
|
||||
host_cache[0][dst_indices].cuda() == cuda_cache[0][src_indices]
|
||||
)
|
||||
|
||||
BS_RANGE = [2**n for n in range(8, 18)]
|
||||
for bs in BS_RANGE:
|
||||
_fast_test_correctness(bs)
|
||||
|
||||
print("Correctness passed! Start HiCache kernel performance test...")
|
||||
print("=" * 70)
|
||||
|
||||
for bs in BS_RANGE:
|
||||
indices_dst = _gen_indices(CUDA_CACHE_SIZE, bs)
|
||||
indices_src = _gen_indices(HOST_CACHE_SIZE, bs)
|
||||
MEM = 2 * bs * ITEM_BYTES
|
||||
|
||||
def _run_kernel_h2d(impl):
|
||||
return impl(
|
||||
k_cache_dst=cuda_cache[0],
|
||||
v_cache_dst=cuda_cache[1],
|
||||
indices_dst=indices_dst,
|
||||
k_cache_src=host_cache[0],
|
||||
v_cache_src=host_cache[1],
|
||||
indices_src=indices_src,
|
||||
item_bytes=ITEM_BYTES,
|
||||
block_quota=BLOCK_QUOTA,
|
||||
)
|
||||
|
||||
our_h2d_dur = perf(lambda: _run_kernel_h2d(jit_hicache_impl))
|
||||
ref_h2d_dur = perf(lambda: _run_kernel_h2d(ref_hicache_impl))
|
||||
print(
|
||||
f"{bs = :6d}, H->D",
|
||||
f"| aot {_calc_tput(ref_h2d_dur):<6.2f} GB/s",
|
||||
f"| jit {_calc_tput(our_h2d_dur):<6.2f} GB/s",
|
||||
f"| {_gain_str(ref_h2d_dur, our_h2d_dur)}",
|
||||
)
|
||||
|
||||
print("=" * 70)
|
||||
|
||||
for bs in BS_RANGE:
|
||||
indices_dst = _gen_indices(HOST_CACHE_SIZE, bs)
|
||||
indices_src = _gen_indices(CUDA_CACHE_SIZE, bs)
|
||||
MEM = 2 * bs * ITEM_BYTES
|
||||
|
||||
def _run_kernel_d2h(impl):
|
||||
return impl(
|
||||
k_cache_dst=host_cache[0],
|
||||
v_cache_dst=host_cache[1],
|
||||
indices_dst=indices_dst,
|
||||
k_cache_src=cuda_cache[0],
|
||||
v_cache_src=cuda_cache[1],
|
||||
indices_src=indices_src,
|
||||
item_bytes=ITEM_BYTES,
|
||||
block_quota=BLOCK_QUOTA,
|
||||
)
|
||||
|
||||
our_d2h_dur = perf(lambda: _run_kernel_d2h(jit_hicache_impl))
|
||||
ref_d2h_dur = perf(lambda: _run_kernel_d2h(ref_hicache_impl))
|
||||
print(
|
||||
f"{bs = :6d}, D->H",
|
||||
f"| aot {_calc_tput(ref_d2h_dur):<6.2f} GB/s",
|
||||
f"| jit {_calc_tput(our_d2h_dur):<6.2f} GB/s",
|
||||
f"| {_gain_str(ref_d2h_dur, our_d2h_dur)}",
|
||||
)
|
||||
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
torch.cuda.set_device(0)
|
||||
stream = torch.cuda.Stream()
|
||||
torch.cuda.set_stream(stream)
|
||||
|
||||
tic = torch.cuda.Event(enable_timing=True)
|
||||
toc = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
BUF_SIZE = 1024 * 1024 * 1024
|
||||
cuda_mem = torch.empty(BUF_SIZE, dtype=torch.uint8, device="cuda")
|
||||
host_mem = torch.empty(BUF_SIZE, dtype=torch.uint8, device="cpu", pin_memory=True)
|
||||
|
||||
# test peak bandwidth
|
||||
tic.record()
|
||||
cuda_mem.copy_(host_mem, non_blocking=True)
|
||||
toc.record()
|
||||
toc.synchronize()
|
||||
dur = tic.elapsed_time(toc)
|
||||
print(f"Peak H->D Bandwidth: {(BUF_SIZE / (1024**3)) / (dur / 1000):.2f} GB/s")
|
||||
|
||||
tic.record()
|
||||
host_mem.copy_(cuda_mem, non_blocking=True)
|
||||
toc.record()
|
||||
toc.synchronize()
|
||||
dur = tic.elapsed_time(toc)
|
||||
print(f"Peak D->H Bandwidth: {(BUF_SIZE / (1024**3)) / (dur / 1000):.2f} GB/s")
|
||||
|
||||
for block_quota in [1, 2, 3, 4]:
|
||||
for cache_item_size in [128, 256, 512, 1024]:
|
||||
args = HicacheBenchArgs(
|
||||
cache_item_size=cache_item_size,
|
||||
dtype=torch.float16,
|
||||
block_quota=block_quota,
|
||||
)
|
||||
test_hicache_kernel(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
60
third_party/sglang/benchmark/json_decode_regex/README.md
vendored
Normal file
60
third_party/sglang/benchmark/json_decode_regex/README.md
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
## Run benchmark
|
||||
|
||||
### Build dataset
|
||||
```
|
||||
pip install wikipedia
|
||||
python3 build_dataset.py
|
||||
```
|
||||
|
||||
### Dependencies
|
||||
|
||||
```
|
||||
llama_cpp_python 0.2.19
|
||||
guidance 0.1.10
|
||||
vllm 0.2.5
|
||||
outlines 0.0.22
|
||||
```
|
||||
|
||||
### Benchmark sglang
|
||||
|
||||
Run Llama-7B
|
||||
|
||||
```
|
||||
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||
```
|
||||
|
||||
Run Mixtral-8x7B
|
||||
|
||||
```
|
||||
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
|
||||
```
|
||||
|
||||
Benchmark
|
||||
|
||||
```
|
||||
python3 bench_sglang.py --num-questions 10
|
||||
```
|
||||
|
||||
|
||||
### Benchmark Outlines + vLLM
|
||||
|
||||
Run Llama-7B
|
||||
|
||||
```
|
||||
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
||||
```
|
||||
|
||||
Benchmark
|
||||
|
||||
```
|
||||
python3 bench_other.py --backend outlines --num-questions 10
|
||||
```
|
||||
|
||||
|
||||
### Benchmark guidance
|
||||
|
||||
Run Llama-7B and benchmark
|
||||
|
||||
```
|
||||
python3 bench_other.py --backend guidance --num-questions 10 --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
||||
```
|
||||
98
third_party/sglang/benchmark/json_decode_regex/bench_other.py
vendored
Normal file
98
third_party/sglang/benchmark/json_decode_regex/bench_other.py
vendored
Normal file
@@ -0,0 +1,98 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
|
||||
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
||||
from sglang.utils import dump_state_text, read_jsonl
|
||||
|
||||
REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
|
||||
|
||||
|
||||
# fmt: off
|
||||
def json_decode(document, generate):
|
||||
s = "Please extract the information of a city from the following wikipedia page.\n"
|
||||
s += "Page begin.\n" + document + "Page end.\n"
|
||||
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
||||
s += "{\n"
|
||||
s += ' "name": '
|
||||
s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||
s += ' "country": '
|
||||
s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||
s += ' "latitude": '
|
||||
s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
|
||||
s += ' "population": '
|
||||
s += generate(s, max_tokens=8, regex=REGEX_INT + ",") + "\n"
|
||||
s += ' "top 3 landmarks": '
|
||||
s += generate(s, max_tokens=24, regex=REGEX_LIST) + "\n"
|
||||
s += "}\n"
|
||||
|
||||
return s
|
||||
# fmt: on
|
||||
|
||||
|
||||
def main(args):
|
||||
lines = read_jsonl(args.data_path)
|
||||
arguments = []
|
||||
for i in range(len(lines[: args.num_questions])):
|
||||
arguments.append(
|
||||
{
|
||||
"document": lines[i]["document"],
|
||||
}
|
||||
)
|
||||
states = [None] * len(arguments)
|
||||
|
||||
# Select backend
|
||||
call_generate = partial(get_call_generate(args), temperature=0)
|
||||
|
||||
# Run requests
|
||||
def get_one_answer(i):
|
||||
states[i] = json_decode(generate=call_generate, **arguments[i])
|
||||
|
||||
tic = time.perf_counter()
|
||||
if args.parallel == 1:
|
||||
for i in tqdm(range(len(arguments))):
|
||||
get_one_answer(i)
|
||||
else:
|
||||
with ThreadPoolExecutor(args.parallel) as executor:
|
||||
rets = list(
|
||||
tqdm(
|
||||
executor.map(get_one_answer, list(range(len(arguments)))),
|
||||
total=len(arguments),
|
||||
)
|
||||
)
|
||||
for _ in rets:
|
||||
pass
|
||||
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
# Compute accuracy
|
||||
print(f"Latency: {latency:.3f}")
|
||||
|
||||
# Write results
|
||||
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||
|
||||
with open(args.result_file, "a") as fout:
|
||||
value = {
|
||||
"task": "json_decode_regex",
|
||||
"backend": args.backend,
|
||||
"num_gpus": 1,
|
||||
"latency": round(latency, 3),
|
||||
"num_requests": args.num_questions,
|
||||
"other": {
|
||||
"parallel": args.parallel,
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-path", type=str, default="questions.jsonl")
|
||||
parser.add_argument("--num-questions", type=int, default=20)
|
||||
args = add_common_other_args_and_parse(parser)
|
||||
main(args)
|
||||
101
third_party/sglang/benchmark/json_decode_regex/bench_sglang.py
vendored
Normal file
101
third_party/sglang/benchmark/json_decode_regex/bench_sglang.py
vendored
Normal file
@@ -0,0 +1,101 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
from sglang.utils import dump_state_text, read_jsonl
|
||||
|
||||
REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
|
||||
|
||||
# fmt: off
|
||||
@sgl.function
|
||||
def json_warm_up(s):
|
||||
s += "The information about Hogwarts is in the following JSON format.\n"
|
||||
with s.var_scope("json_output"):
|
||||
s += "{\n"
|
||||
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
|
||||
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
|
||||
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
|
||||
s += "}\n"
|
||||
print(f'The warmp up json result is:\n{s["json_output"]}')
|
||||
# fmt: on
|
||||
|
||||
# fmt: off
|
||||
@sgl.function
|
||||
def json_decode(s, document):
|
||||
s += "Please extract the information of a city from the following wikipedia page.\n"
|
||||
s += "Page begin.\n" + document + "Page end.\n"
|
||||
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
||||
with s.var_scope("json_output"):
|
||||
s += "{\n"
|
||||
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
|
||||
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
|
||||
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
|
||||
s += "}\n"
|
||||
# fmt: on
|
||||
|
||||
|
||||
def main(args):
|
||||
lines = read_jsonl(args.data_path)
|
||||
lines = list(lines)
|
||||
arguments = []
|
||||
for i in range(len(lines[: args.num_questions])):
|
||||
arguments.append(
|
||||
{
|
||||
"document": lines[i]["document"],
|
||||
}
|
||||
)
|
||||
|
||||
# Select backend
|
||||
backend = select_sglang_backend(args)
|
||||
sgl.set_default_backend(backend)
|
||||
|
||||
# Warm up
|
||||
json_warm_up.run().sync()
|
||||
|
||||
# Run requests
|
||||
tic = time.perf_counter()
|
||||
states = json_decode.run_batch(
|
||||
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
|
||||
)
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
# Compute accuracy
|
||||
print(f"Latency: {latency:.3f}")
|
||||
|
||||
# Write results
|
||||
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||
|
||||
with open(f"tmp_{args.backend}_json_results.txt", "w") as fout:
|
||||
for state in states:
|
||||
fout.write(state["json_output"] + "\n")
|
||||
|
||||
with open(args.result_file, "a") as fout:
|
||||
value = {
|
||||
"task": "json_decode_regex",
|
||||
"backend": args.backend,
|
||||
"num_gpus": 1,
|
||||
"latency": round(latency, 3),
|
||||
"num_requests": args.num_questions,
|
||||
"other": {
|
||||
"parallel": args.parallel,
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-path", type=str, default="questions.jsonl")
|
||||
parser.add_argument("--num-questions", type=int, default=20)
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
main(args)
|
||||
58
third_party/sglang/benchmark/json_decode_regex/build_dataset.py
vendored
Normal file
58
third_party/sglang/benchmark/json_decode_regex/build_dataset.py
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
import json
|
||||
|
||||
import transformers
|
||||
import wikipedia
|
||||
|
||||
model_path = "meta-llama/Llama-2-7b-chat-hf"
|
||||
t = transformers.AutoTokenizer.from_pretrained(model_path)
|
||||
city_names = [
|
||||
"los angles",
|
||||
"london",
|
||||
"tokyo",
|
||||
"beijing",
|
||||
"singapore",
|
||||
"paris",
|
||||
"dubai",
|
||||
"sydney",
|
||||
"moscow",
|
||||
"rome",
|
||||
"toronto",
|
||||
"rio de janeiro",
|
||||
"istanbul",
|
||||
"berlin",
|
||||
"auckland",
|
||||
"buenos aires",
|
||||
"mexico city",
|
||||
"mumbai",
|
||||
"seoul",
|
||||
"bangkok",
|
||||
"cairo",
|
||||
"athens",
|
||||
"jerusalem",
|
||||
]
|
||||
|
||||
|
||||
def get_content(city_name):
|
||||
content = str(wikipedia.page(city_name).content)
|
||||
content = content.replace("\n\n", "\n")
|
||||
|
||||
tokens = t.encode(content)
|
||||
|
||||
expected_tokens = 3000
|
||||
truncate_len = int((expected_tokens / len(tokens)) * len(content))
|
||||
truncate_content = content[:truncate_len]
|
||||
truncate_tokens = t.encode(truncate_content)
|
||||
|
||||
# Count token
|
||||
print(
|
||||
f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
|
||||
)
|
||||
|
||||
return truncate_content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open("questions.jsonl", "w") as fout:
|
||||
for city_name in city_names:
|
||||
truncate_content = get_content(city_name)
|
||||
fout.write(json.dumps({"document": truncate_content}) + "\n")
|
||||
88
third_party/sglang/benchmark/json_jump_forward/README.md
vendored
Normal file
88
third_party/sglang/benchmark/json_jump_forward/README.md
vendored
Normal file
@@ -0,0 +1,88 @@
|
||||
## Run benchmark
|
||||
|
||||
### Dependencies
|
||||
|
||||
```
|
||||
llama_cpp_python 0.2.38
|
||||
guidance 0.1.10
|
||||
vllm 0.2.7
|
||||
outlines 0.0.25
|
||||
```
|
||||
|
||||
### Build dataset
|
||||
|
||||
When benchmarking long document information retrieval, run the following command to build the dataset:
|
||||
|
||||
```bash
|
||||
pip install wikipedia
|
||||
python3 build_dataset.py
|
||||
```
|
||||
|
||||
### Benchmark sglang
|
||||
|
||||
Run Llama-7B
|
||||
|
||||
```bash
|
||||
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||
```
|
||||
|
||||
Benchmark Character Generation
|
||||
|
||||
```bash
|
||||
python3 bench_sglang.py --mode character
|
||||
```
|
||||
|
||||
Benchmark City Information Retrieval
|
||||
|
||||
```bash
|
||||
python3 bench_sglang.py --mode city
|
||||
```
|
||||
|
||||
|
||||
### Benchmark Outlines + vLLM
|
||||
|
||||
Run Llama-7B
|
||||
|
||||
```bash
|
||||
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
||||
```
|
||||
|
||||
Benchmark Character Generation
|
||||
|
||||
```bash
|
||||
python3 bench_other.py --mode character --backend outlines
|
||||
```
|
||||
|
||||
Benchmark City Information Retrieval
|
||||
|
||||
```bash
|
||||
python3 bench_other.py --mode city --backend outlines
|
||||
```
|
||||
|
||||
### Benchmark guidance
|
||||
|
||||
Run Llama-7B and benchmark character generation
|
||||
|
||||
```bash
|
||||
python3 bench_other.py --mode character --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
||||
```
|
||||
|
||||
Run Llama-7B and benchmark city information retrieval
|
||||
|
||||
```bash
|
||||
python3 bench_other.py --mode city --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
||||
```
|
||||
|
||||
### Benchmark lmql
|
||||
|
||||
Run Llama-7B and benchmark character generation
|
||||
|
||||
```
|
||||
python3 bench_other.py --mode character --backend lmql --parallel 1
|
||||
```
|
||||
|
||||
Run Llama-7B and benchmark city information retrieval
|
||||
|
||||
```
|
||||
python3 bench_other.py --mode city --backend lmql --parallel 1
|
||||
```
|
||||
288
third_party/sglang/benchmark/json_jump_forward/bench_other.py
vendored
Normal file
288
third_party/sglang/benchmark/json_jump_forward/bench_other.py
vendored
Normal file
@@ -0,0 +1,288 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
import guidance
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
||||
from sglang.utils import dump_state_text, read_jsonl
|
||||
|
||||
# there are some FSM bugs with json regex converted from pydantic model
|
||||
# here use a string regex instead
|
||||
# regex_string = build_regex_from_object(HarryPoterRole)
|
||||
character_regex = (
|
||||
r"""\{\n"""
|
||||
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
||||
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
|
||||
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
|
||||
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
|
||||
+ r""" "wand": \{\n"""
|
||||
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
|
||||
+ r""" "core": "[\w\d\s]{1,16}",\n"""
|
||||
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
|
||||
+ r""" \},\n"""
|
||||
+ r""" "alive": "(Alive|Deceased)",\n"""
|
||||
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
|
||||
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
|
||||
+ r"""\}"""
|
||||
)
|
||||
|
||||
city_regex = (
|
||||
r"""\{\n"""
|
||||
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
||||
+ r""" "country": "[\w\d\s]{1,16}",\n"""
|
||||
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
|
||||
+ r""" "population": [-+]?[0-9]{1,9},\n"""
|
||||
+ r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
|
||||
+ r"""\}"""
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
def character_gen(name, generate):
|
||||
s = name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
|
||||
s += generate(s, max_tokens=256, regex=character_regex)
|
||||
return s
|
||||
# fmt: on
|
||||
|
||||
# fmt: off
|
||||
def city_gen(document, generate):
|
||||
s = "Please extract the information of a city from the following wikipedia page.\n"
|
||||
s += "Page begin.\n" + document + "Page end.\n"
|
||||
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
||||
s += generate(s, max_tokens=256, regex=city_regex)
|
||||
return s
|
||||
# fmt: on
|
||||
|
||||
|
||||
@guidance
|
||||
def character_maker(lm, name):
|
||||
regex_str_no_quote = r"[\w\d\s]+"
|
||||
regex_float = r"[0-9]+\.[0-9]+"
|
||||
lm += f"""\
|
||||
{name} is a character in Harry Potter. Please fill in the following information about this character.
|
||||
{{
|
||||
"name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
|
||||
"house": "{guidance.select(options=['Gryffindor', 'Slytherin', 'Ravenclaw', 'Hufflepuff'], name='house')}",
|
||||
"blood status": "{guidance.select(options=['Pure-blood', 'Half-blood', 'Muggle-born'], name='blood status')}",
|
||||
"occupation": "{guidance.select(options=['student', 'teacher', 'auror', 'ministry of magic', 'death eater', 'order of the phoenix'], name='occupation')}",
|
||||
"wand": {{
|
||||
"wood": "{guidance.gen("wood", max_tokens=16, regex=regex_str_no_quote)}",
|
||||
"core": "{guidance.gen('core', max_tokens=16, regex=regex_str_no_quote)}",
|
||||
"length": {guidance.gen('length', max_tokens=10, regex=regex_float)}
|
||||
}},
|
||||
"alive": "{guidance.select(options=['Alive', 'Deceased'], name='alive')}",
|
||||
"patronus": "{guidance.gen('patronus', max_tokens=16, regex=regex_str_no_quote)}",
|
||||
"bogart": "{guidance.gen('bogart', max_tokens=16, regex=regex_str_no_quote)}"
|
||||
}}
|
||||
"""
|
||||
|
||||
return lm
|
||||
|
||||
|
||||
async def call_generate_lmql(
|
||||
prompt, temperature, max_tokens, regex, max_len=4096, model=None, **kwargs
|
||||
):
|
||||
assert model is not None
|
||||
import lmql
|
||||
|
||||
@lmql.query(model=model)
|
||||
async def program(question, max_tokens, regex):
|
||||
'''lmql
|
||||
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and REGEX(ANSWER, regex)
|
||||
return ANSWER
|
||||
'''
|
||||
|
||||
return await program(
|
||||
question=prompt,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
max_len=max_len,
|
||||
regex=regex,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@guidance
|
||||
def city_maker(lm, document):
|
||||
regex_str_no_quote = r"[\w\d\s]+"
|
||||
regex_float = r"[0-9]+\.[0-9]+"
|
||||
lm += f"""\
|
||||
Please extract the information of a city from the following wikipedia page.
|
||||
Page begin.
|
||||
{document}
|
||||
Page end.
|
||||
Here is the name, country, and symbol of the city in JSON format.
|
||||
{{
|
||||
"name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
|
||||
"country": "{guidance.gen("country", max_tokens=16, regex=regex_str_no_quote)}",
|
||||
"latitude": {guidance.gen("latitude", max_tokens=10, regex=regex_float)},
|
||||
"population": {guidance.gen("population", max_tokens=10, regex=r"[0-9]+")},
|
||||
"top 3 landmarks": [
|
||||
"{guidance.gen("landmark1", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark2", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark3", max_tokens=16, regex=regex_str_no_quote)}"
|
||||
]
|
||||
}}
|
||||
"""
|
||||
|
||||
return lm
|
||||
|
||||
|
||||
def bench_character(args):
|
||||
arguments = []
|
||||
with open(args.data_path, "r") as f:
|
||||
for line in f:
|
||||
arguments.append({"name": line.strip()})
|
||||
arguments = arguments[: args.num_jsons]
|
||||
|
||||
states = [None] * len(arguments)
|
||||
|
||||
# Select backend
|
||||
if args.backend == "outlines":
|
||||
call_generate = partial(get_call_generate(args), temperature=0)
|
||||
|
||||
def get_one_answer(i):
|
||||
states[i] = character_gen(**arguments[i], generate=call_generate)
|
||||
|
||||
elif args.backend == "guidance":
|
||||
model = guidance.models.LlamaCpp(
|
||||
args.model_path,
|
||||
n_gpu_layers=-1,
|
||||
n_ctx=args.n_ctx,
|
||||
)
|
||||
|
||||
def get_one_answer(i):
|
||||
lm = model + character_maker(**arguments[i])
|
||||
states[i] = lm
|
||||
|
||||
elif args.backend == "lmql":
|
||||
import asyncio
|
||||
|
||||
import lmql
|
||||
|
||||
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
|
||||
call_generate = partial(
|
||||
call_generate_lmql,
|
||||
model=model,
|
||||
max_tokens=256,
|
||||
regex=character_regex,
|
||||
)
|
||||
|
||||
async def get_one_answer_async(i):
|
||||
states[i] = await call_generate(prompt=arguments[i]["name"], temperature=0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid backend: {args.backend}")
|
||||
|
||||
tic = time.perf_counter()
|
||||
|
||||
if args.backend != "lmql":
|
||||
if args.parallel == 1:
|
||||
for i in tqdm(range(len(arguments))):
|
||||
get_one_answer(i)
|
||||
else:
|
||||
with ThreadPoolExecutor(args.parallel) as executor:
|
||||
rets = list(
|
||||
tqdm(
|
||||
executor.map(get_one_answer, list(range(len(arguments)))),
|
||||
total=len(arguments),
|
||||
)
|
||||
)
|
||||
for _ in rets:
|
||||
pass
|
||||
else:
|
||||
batches = []
|
||||
for i in range(0, len(arguments), args.parallel):
|
||||
batches.append(list(range(i, min(i + args.parallel, len(arguments)))))
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
for bt in tqdm(batches):
|
||||
loop.run_until_complete(
|
||||
asyncio.gather(*[get_one_answer_async(i) for i in bt])
|
||||
)
|
||||
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
return states, latency
|
||||
|
||||
|
||||
def bench_city_doc(args):
|
||||
arguments = []
|
||||
for line in read_jsonl(args.data_path):
|
||||
arguments.append({"document": line["document"]})
|
||||
arguments = arguments[: args.num_jsons]
|
||||
|
||||
states = [None] * len(arguments)
|
||||
|
||||
# Select backend
|
||||
if args.backend == "outlines":
|
||||
call_generate = partial(get_call_generate(args), temperature=0)
|
||||
|
||||
def get_one_answer(i):
|
||||
states[i] = city_gen(**arguments[i], generate=call_generate)
|
||||
|
||||
elif args.backend == "guidance":
|
||||
model = guidance.models.LlamaCpp(
|
||||
args.model_path,
|
||||
n_gpu_layers=-1,
|
||||
n_ctx=args.n_ctx,
|
||||
)
|
||||
|
||||
def get_one_answer(i):
|
||||
lm = model + city_maker(**arguments[i])
|
||||
states[i] = lm
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid backend: {args.backend}")
|
||||
|
||||
tic = time.perf_counter()
|
||||
if args.parallel == 1:
|
||||
for i in tqdm(range(len(arguments))):
|
||||
get_one_answer(i)
|
||||
else:
|
||||
with ThreadPoolExecutor(args.parallel) as executor:
|
||||
rets = executor.map(get_one_answer, list(range(len(arguments))))
|
||||
for _ in rets:
|
||||
pass
|
||||
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
return states, latency
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.mode == "character":
|
||||
args.data_path = "dataset.txt"
|
||||
states, latency = bench_character(args)
|
||||
elif args.mode == "city":
|
||||
args.data_path = "questions.jsonl"
|
||||
states, latency = bench_city_doc(args)
|
||||
|
||||
# Compute accuracy
|
||||
print(f"Latency: {latency:.3f}")
|
||||
|
||||
# Write results
|
||||
dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
|
||||
|
||||
with open(args.result_file, "a") as fout:
|
||||
value = {
|
||||
"task": "json_jump_forward",
|
||||
"backend": args.backend,
|
||||
"latency": round(latency, 3),
|
||||
"num_jsons": args.num_jsons,
|
||||
"mode": args.mode,
|
||||
"parallel": args.parallel,
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-path", type=str)
|
||||
parser.add_argument("--num-jsons", type=int, default=50)
|
||||
parser.add_argument(
|
||||
"--mode", type=str, default="character", choices=["character", "city"]
|
||||
)
|
||||
args = add_common_other_args_and_parse(parser)
|
||||
main(args)
|
||||
143
third_party/sglang/benchmark/json_jump_forward/bench_sglang.py
vendored
Normal file
143
third_party/sglang/benchmark/json_jump_forward/bench_sglang.py
vendored
Normal file
@@ -0,0 +1,143 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
from sglang.utils import dump_state_text, read_jsonl
|
||||
|
||||
# there are some FSM bugs with json regex converted from pydantic model
|
||||
# here use a string regex instead
|
||||
# regex_string = build_regex_from_object(HarryPoterRole)
|
||||
character_regex = (
|
||||
r"""\{\n"""
|
||||
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
||||
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
|
||||
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
|
||||
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
|
||||
+ r""" "wand": \{\n"""
|
||||
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
|
||||
+ r""" "core": "[\w\d\s]{1,16}",\n"""
|
||||
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
|
||||
+ r""" \},\n"""
|
||||
+ r""" "alive": "(Alive|Deceased)",\n"""
|
||||
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
|
||||
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
|
||||
+ r"""\}"""
|
||||
)
|
||||
|
||||
city_regex = (
|
||||
r"""\{\n"""
|
||||
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
||||
+ r""" "country": "[\w\d\s]{1,16}",\n"""
|
||||
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
|
||||
+ r""" "population": [-+]?[0-9]{1,9},\n"""
|
||||
+ r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
|
||||
+ r"""\}"""
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
@sgl.function
|
||||
def character_gen(s, name):
|
||||
s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
|
||||
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
|
||||
# fmt: on
|
||||
|
||||
# fmt: off
|
||||
@sgl.function
|
||||
def city_gen(s, document):
|
||||
s += "Please extract the information of a city from the following wikipedia page.\n"
|
||||
s += "Page begin.\n" + document + "Page end.\n"
|
||||
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
||||
s += sgl.gen("json_output",max_tokens=256, regex=city_regex)
|
||||
# fmt: on
|
||||
|
||||
|
||||
def bench_city_doc(args):
|
||||
arguments = []
|
||||
for line in read_jsonl(args.data_path):
|
||||
arguments.append({"document": line["document"]})
|
||||
arguments = arguments[: args.num_jsons]
|
||||
|
||||
# Select backend
|
||||
backend = select_sglang_backend(args)
|
||||
sgl.set_default_backend(backend)
|
||||
|
||||
# Run requests
|
||||
tic = time.perf_counter()
|
||||
states = city_gen.run_batch(
|
||||
arguments,
|
||||
temperature=0,
|
||||
num_threads=args.parallel,
|
||||
progress_bar=True,
|
||||
)
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
return states, latency
|
||||
|
||||
|
||||
def bench_character(args):
|
||||
arguments = []
|
||||
with open(args.data_path, "r") as f:
|
||||
for line in f:
|
||||
arguments.append({"name": line.strip()})
|
||||
arguments = arguments[: args.num_jsons]
|
||||
|
||||
# Select backend
|
||||
backend = select_sglang_backend(args)
|
||||
sgl.set_default_backend(backend)
|
||||
|
||||
# Run requests
|
||||
tic = time.perf_counter()
|
||||
states = character_gen.run_batch(
|
||||
arguments,
|
||||
temperature=0,
|
||||
num_threads=args.parallel,
|
||||
progress_bar=True,
|
||||
)
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
return states, latency
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.mode == "character":
|
||||
args.data_path = "dataset.txt"
|
||||
states, latency = bench_character(args)
|
||||
elif args.mode == "city":
|
||||
args.data_path = "questions.jsonl"
|
||||
states, latency = bench_city_doc(args)
|
||||
|
||||
# Compute accuracy
|
||||
print(f"Latency: {latency:.3f}")
|
||||
|
||||
# Write results
|
||||
dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
|
||||
with open(f"{args.backend}_{args.mode}.json", "w") as fout:
|
||||
for state in states:
|
||||
fout.write(state["json_output"] + "\n")
|
||||
|
||||
with open(args.result_file, "a") as fout:
|
||||
value = {
|
||||
"task": "json_jump_forward",
|
||||
"backend": args.backend,
|
||||
"latency": round(latency, 3),
|
||||
"num_jsons": args.num_jsons,
|
||||
"mode": args.mode,
|
||||
"parallel": args.parallel,
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-path", type=str)
|
||||
parser.add_argument("--num-jsons", type=int, default=50)
|
||||
parser.add_argument(
|
||||
"--mode", type=str, default="character", choices=["character", "city"]
|
||||
)
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
main(args)
|
||||
58
third_party/sglang/benchmark/json_jump_forward/build_dataset.py
vendored
Normal file
58
third_party/sglang/benchmark/json_jump_forward/build_dataset.py
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
import json
|
||||
|
||||
import transformers
|
||||
import wikipedia
|
||||
|
||||
model_path = "meta-llama/Llama-2-7b-chat-hf"
|
||||
t = transformers.AutoTokenizer.from_pretrained(model_path)
|
||||
city_names = [
|
||||
"los angles",
|
||||
"london",
|
||||
"tokyo",
|
||||
"beijing",
|
||||
"singapore",
|
||||
"paris",
|
||||
"dubai",
|
||||
"sydney",
|
||||
"moscow",
|
||||
"rome",
|
||||
"toronto",
|
||||
"rio de janeiro",
|
||||
"istanbul",
|
||||
"berlin",
|
||||
"auckland",
|
||||
"buenos aires",
|
||||
"mexico city",
|
||||
"mumbai",
|
||||
"seoul",
|
||||
"bangkok",
|
||||
"cairo",
|
||||
"athens",
|
||||
"jerusalem",
|
||||
]
|
||||
|
||||
|
||||
def get_content(city_name):
|
||||
content = str(wikipedia.page(city_name).content)
|
||||
content = content.replace("\n\n", "\n")
|
||||
|
||||
tokens = t.encode(content)
|
||||
|
||||
expected_tokens = 3000
|
||||
truncate_len = int((expected_tokens / len(tokens)) * len(content))
|
||||
truncate_content = content[:truncate_len]
|
||||
truncate_tokens = t.encode(truncate_content)
|
||||
|
||||
# Count token
|
||||
print(
|
||||
f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
|
||||
)
|
||||
|
||||
return truncate_content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open("questions.jsonl", "w") as fout:
|
||||
for city_name in city_names:
|
||||
truncate_content = get_content(city_name)
|
||||
fout.write(json.dumps({"document": truncate_content}) + "\n")
|
||||
50
third_party/sglang/benchmark/json_jump_forward/dataset.txt
vendored
Normal file
50
third_party/sglang/benchmark/json_jump_forward/dataset.txt
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
Harry Potter
|
||||
Hermione Granger
|
||||
Ron Weasley
|
||||
Albus Dumbledore
|
||||
Severus Snape
|
||||
Rubeus Hagrid
|
||||
Draco Malfoy
|
||||
Ginny Weasley
|
||||
Fred Weasley
|
||||
George Weasley
|
||||
Percy Weasley
|
||||
Sirius Black
|
||||
Remus Lupin
|
||||
Neville Longbottom
|
||||
Luna Lovegood
|
||||
Cedric Diggory
|
||||
Cho Chang
|
||||
Lord Voldemort
|
||||
Minerva McGonagall
|
||||
Filius Flitwick
|
||||
Dolores Umbridge
|
||||
Bellatrix Lestrange
|
||||
Lucius Malfoy
|
||||
Molly Weasley
|
||||
Arthur Weasley
|
||||
Nymphadora Tonks
|
||||
Dobby
|
||||
Moaning Myrtle
|
||||
Peter Pettigrew
|
||||
Alastor 'Mad-Eye' Moody
|
||||
Horace Slughorn
|
||||
Vernon Dursley
|
||||
Petunia Dursley
|
||||
Dudley Dursley
|
||||
Argus Filch
|
||||
Sybill Trelawney
|
||||
Gilderoy Lockhart
|
||||
Fleur Delacour
|
||||
Viktor Krum
|
||||
Bill Weasley
|
||||
Oliver Wood
|
||||
Cornelius Fudge
|
||||
Barty Crouch Sr.
|
||||
Barty Crouch Jr.
|
||||
Kingsley Shacklebolt
|
||||
Quirinus Quirrell
|
||||
Nearly Headless Nick
|
||||
Aunt Marge
|
||||
Griphook
|
||||
Ludo Bagman
|
||||
15
third_party/sglang/benchmark/json_schema/README.md
vendored
Normal file
15
third_party/sglang/benchmark/json_schema/README.md
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
## Run benchmark
|
||||
|
||||
### Benchmark sglang
|
||||
|
||||
Run Llama-8b
|
||||
|
||||
```bash
|
||||
python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000
|
||||
```
|
||||
|
||||
Benchmark
|
||||
|
||||
```bash
|
||||
python3 bench_sglang.py
|
||||
```
|
||||
146
third_party/sglang/benchmark/json_schema/bench_sglang.py
vendored
Normal file
146
third_party/sglang/benchmark/json_schema/bench_sglang.py
vendored
Normal file
@@ -0,0 +1,146 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
import jsonschema
|
||||
from datasets import load_dataset
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
from sglang.utils import dump_state_text
|
||||
|
||||
|
||||
@sgl.function
|
||||
def schema_gen(s, message: Tuple[str, str], json_schema: str):
|
||||
system, user = message
|
||||
s += sgl.system(system)
|
||||
s += sgl.user(user)
|
||||
s += sgl.assistant(
|
||||
sgl.gen("json_output", temperature=0, max_tokens=256, json_schema=json_schema)
|
||||
)
|
||||
|
||||
|
||||
def contains_formats(schema, formats: List[str]):
|
||||
if isinstance(schema, dict):
|
||||
if schema.get("format", None) in formats:
|
||||
return True
|
||||
for value in schema.values():
|
||||
if contains_formats(value, formats):
|
||||
return True
|
||||
elif isinstance(schema, list):
|
||||
for item in schema:
|
||||
if contains_formats(item, formats):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def convert_dataset(path: str):
|
||||
raw_dataset = load_dataset(path)
|
||||
dataset = []
|
||||
for data in raw_dataset["train"]:
|
||||
messages = data["prompt"]
|
||||
schema = data["schema"]
|
||||
obj = json.loads(schema)
|
||||
|
||||
# skip some corrupted examples
|
||||
if obj.get("type", None) is None:
|
||||
continue
|
||||
|
||||
# skip schema with format "email"
|
||||
# which is not supported by outlines for now
|
||||
if contains_formats(obj, ["email"]):
|
||||
continue
|
||||
|
||||
system = messages[0]
|
||||
user = messages[1]
|
||||
assert system["role"] == "system", "invalid role"
|
||||
assert user["role"] == "user", "invalid role"
|
||||
assert len(messages) == 2, "invalid message length"
|
||||
message = json.dumps(system["content"]), json.dumps(user["content"])
|
||||
dataset.append(
|
||||
{
|
||||
"message": message,
|
||||
"json_schema": schema,
|
||||
}
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def bench_schema(args):
|
||||
arguments = convert_dataset(args.data_path)
|
||||
|
||||
if args.num_jsons < 0 or args.num_jsons > len(arguments):
|
||||
args.num_jsons = len(arguments)
|
||||
arguments = arguments[: args.num_jsons]
|
||||
|
||||
# Select backend
|
||||
backend = select_sglang_backend(args)
|
||||
sgl.set_default_backend(backend)
|
||||
|
||||
# Run requests
|
||||
tic = time.perf_counter()
|
||||
states = schema_gen.run_batch(
|
||||
arguments,
|
||||
temperature=0,
|
||||
num_threads=args.parallel,
|
||||
progress_bar=True,
|
||||
)
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
# Check if the outputs are valid
|
||||
indexes = []
|
||||
for i, state in enumerate(states):
|
||||
try:
|
||||
schema = json.loads(arguments[i]["json_schema"])
|
||||
obj = json.loads(state["json_output"])
|
||||
assert jsonschema.validate(obj, schema) is None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
indexes.append(i)
|
||||
|
||||
return states, latency
|
||||
|
||||
|
||||
def main(args):
|
||||
states, latency = bench_schema(args)
|
||||
|
||||
# Compute accuracy
|
||||
tokenizer = get_tokenizer(
|
||||
global_config.default_backend.get_server_info()["tokenizer_path"]
|
||||
)
|
||||
output_jsons = [state["json_output"] for state in states]
|
||||
num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)
|
||||
print(f"Latency: {latency:.3f}")
|
||||
print(f"Output throughput: {num_output_tokens / latency:.3f} token/s")
|
||||
print(f"#output tokens: {num_output_tokens}")
|
||||
|
||||
# Write results
|
||||
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||
with open(f"{args.backend}.jsonl", "w") as fout:
|
||||
for state in states:
|
||||
fout.write(state["json_output"] + "\n")
|
||||
|
||||
with open(args.result_file, "a") as fout:
|
||||
value = {
|
||||
"task": "json_schema",
|
||||
"backend": args.backend,
|
||||
"latency": round(latency, 3),
|
||||
"num_jsons": args.num_jsons,
|
||||
"parallel": args.parallel,
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-path", type=str, default="NousResearch/json-mode-eval")
|
||||
parser.add_argument("--num-jsons", type=int, default=-1)
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
main(args)
|
||||
330
third_party/sglang/benchmark/kernels/all_reduce/benchmark_aiter.py
vendored
Normal file
330
third_party/sglang/benchmark/kernels/all_reduce/benchmark_aiter.py
vendored
Normal file
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
Benchmark SGLang vs Aiter custom all-reduce across message sizes.
|
||||
Usage:
|
||||
torchrun --nproc_per_node=2 benchmark_aiter.py
|
||||
torchrun --nproc_per_node=4 benchmark_aiter.py
|
||||
torchrun --nproc_per_node=8 benchmark_aiter.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark SGLang vs Aiter custom all-reduce across message sizes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="gloo",
|
||||
help="Process group backend for the custom-AR control path (must NOT be nccl).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Warmup iterations per size per implementation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--iters-small",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Benchmark iterations for sizes <= 1MB.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--iters-large",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Benchmark iterations for sizes > 1MB.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="Print per-iteration timings on rank 0 for debugging.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def get_env_rank_world() -> Tuple[int, int, int]:
|
||||
rank = int(os.environ.get("RANK", "0"))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", str(rank)))
|
||||
return rank, world_size, local_rank
|
||||
|
||||
|
||||
def init_dist(backend: str):
|
||||
rank, world_size, _ = get_env_rank_world()
|
||||
if not dist.is_initialized():
|
||||
dist.init_process_group(
|
||||
backend=backend,
|
||||
init_method="env://",
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
|
||||
def get_device(local_rank: int) -> torch.device:
|
||||
torch.cuda.set_device(local_rank)
|
||||
return torch.device(f"cuda:{local_rank}")
|
||||
|
||||
|
||||
def human_size(num_bytes: int) -> str:
|
||||
units = [("B", 1), ("K", 1024), ("M", 1024 * 1024), ("G", 1024 * 1024 * 1024)]
|
||||
for suf, base in reversed(units):
|
||||
if num_bytes % base == 0 and num_bytes >= base:
|
||||
val = num_bytes // base
|
||||
return f"{val}{suf}"
|
||||
return f"{num_bytes}B"
|
||||
|
||||
|
||||
def get_message_sizes() -> List[int]:
|
||||
return [
|
||||
32 * 1024,
|
||||
64 * 1024,
|
||||
128 * 1024,
|
||||
256 * 1024,
|
||||
512 * 1024,
|
||||
1 * 1024 * 1024,
|
||||
2 * 1024 * 1024,
|
||||
4 * 1024 * 1024,
|
||||
8 * 1024 * 1024,
|
||||
16 * 1024 * 1024,
|
||||
32 * 1024 * 1024,
|
||||
64 * 1024 * 1024,
|
||||
]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_once(comm, inp: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
if hasattr(comm, "all_reduce_unreg"):
|
||||
return comm.all_reduce_unreg(inp)
|
||||
if hasattr(comm, "custom_all_reduce"):
|
||||
return comm.custom_all_reduce(inp)
|
||||
raise RuntimeError("No known all-reduce method found on the communicator.")
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def bench_impl(
|
||||
name: str,
|
||||
comm,
|
||||
sizes: List[int],
|
||||
device: torch.device,
|
||||
warmup: int,
|
||||
iters_small: int,
|
||||
iters_large: int,
|
||||
verbose: bool,
|
||||
pg: Optional[dist.ProcessGroup] = None,
|
||||
) -> List[Tuple[int, Optional[float]]]:
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
results: List[Tuple[int, Optional[float]]] = []
|
||||
|
||||
for size_bytes in sizes:
|
||||
elems = size_bytes // 2 # float16: 2 bytes per element
|
||||
inp = torch.empty(elems, dtype=torch.float16, device=device)
|
||||
inp.uniform_(0, 1)
|
||||
|
||||
disabled = False
|
||||
dist.barrier(group=pg)
|
||||
for _ in range(warmup):
|
||||
torch.cuda.synchronize()
|
||||
out = run_once(comm, inp)
|
||||
torch.cuda.synchronize()
|
||||
if out is None:
|
||||
disabled = True
|
||||
break
|
||||
dist.barrier(group=pg)
|
||||
|
||||
if disabled:
|
||||
if rank == 0:
|
||||
print(
|
||||
f"[{name}] {human_size(size_bytes)}: custom AR disabled (skipped)"
|
||||
)
|
||||
results.append((size_bytes, None))
|
||||
continue
|
||||
|
||||
num_iters = iters_small if size_bytes <= (1 * 1024 * 1024) else iters_large
|
||||
|
||||
times_ms: List[float] = []
|
||||
for it in range(num_iters):
|
||||
dist.barrier(group=pg)
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
out = run_once(comm, inp)
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.perf_counter()
|
||||
dist.barrier(group=pg)
|
||||
|
||||
if out is None:
|
||||
disabled = True
|
||||
break
|
||||
|
||||
dt_ms = (t1 - t0) * 1000.0
|
||||
times_ms.append(dt_ms)
|
||||
|
||||
if verbose and rank == 0:
|
||||
print(
|
||||
f"[{name}] size={human_size(size_bytes)} iter={it} time={dt_ms:.3f} ms"
|
||||
)
|
||||
|
||||
if disabled or not times_ms:
|
||||
if rank == 0:
|
||||
print(
|
||||
f"[{name}] {human_size(size_bytes)}: custom AR disabled (no timings)"
|
||||
)
|
||||
results.append((size_bytes, None))
|
||||
continue
|
||||
|
||||
avg_ms_local = sum(times_ms) / len(times_ms)
|
||||
avg_tensor = torch.tensor([avg_ms_local], dtype=torch.float64, device=device)
|
||||
gather_list = [torch.zeros_like(avg_tensor) for _ in range(world_size)]
|
||||
dist.all_gather(gather_list, avg_tensor, group=pg)
|
||||
if rank == 0:
|
||||
avg_ms = float(torch.stack(gather_list).mean().item())
|
||||
print(
|
||||
f"[{name}] {human_size(size_bytes)}: {avg_ms:.3f} ms (avg across ranks)"
|
||||
)
|
||||
results.append((size_bytes, avg_ms))
|
||||
else:
|
||||
results.append((size_bytes, None))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
rank, world_size, local_rank = get_env_rank_world()
|
||||
|
||||
if world_size not in (2, 4, 6, 8):
|
||||
print(
|
||||
f"[rank {rank}] WARNING: world_size={world_size} not in supported set (2,4,6,8). "
|
||||
"Custom AR may disable itself.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
init_dist(args.backend)
|
||||
device = get_device(local_rank)
|
||||
|
||||
# Import after dist init; some libs query torch dist state on import
|
||||
sgl_comm = None
|
||||
aiter_comm = None
|
||||
HAVE_SGLANG = False
|
||||
HAVE_AITER = False
|
||||
|
||||
try:
|
||||
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
|
||||
CustomAllreduce as SGLCustomAllreduce,
|
||||
)
|
||||
|
||||
HAVE_SGLANG = True
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
print(f"SGLang CustomAllreduce import failed: {e}", file=sys.stderr)
|
||||
|
||||
try:
|
||||
from aiter.dist.device_communicators.custom_all_reduce import (
|
||||
CustomAllreduce as AiterCustomAllreduce,
|
||||
)
|
||||
|
||||
HAVE_AITER = True
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
print(f"Aiter CustomAllreduce import failed: {e}", file=sys.stderr)
|
||||
|
||||
if rank == 0:
|
||||
print(f"Initialized PG backend={args.backend} world_size={world_size}")
|
||||
print(f"Device: {device.type}:{device.index}")
|
||||
print(f"SGLang available: {HAVE_SGLANG}, Aiter available: {HAVE_AITER}")
|
||||
|
||||
pg = dist.group.WORLD
|
||||
sizes = get_message_sizes()
|
||||
max_size = max(sizes) if sizes else (64 * 1024 * 1024)
|
||||
|
||||
if HAVE_SGLANG:
|
||||
try:
|
||||
sgl_comm = SGLCustomAllreduce(group=pg, device=device, max_size=max_size)
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Failed to construct SGLang CustomAllreduce: {e}", file=sys.stderr
|
||||
)
|
||||
sgl_comm = None
|
||||
|
||||
if HAVE_AITER:
|
||||
try:
|
||||
aiter_comm = AiterCustomAllreduce(
|
||||
group=pg, device=device, max_size=max_size
|
||||
)
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Failed to construct Aiter CustomAllreduce: {e}", file=sys.stderr
|
||||
)
|
||||
aiter_comm = None
|
||||
|
||||
sgl_results: List[Tuple[int, Optional[float]]] = []
|
||||
aiter_results: List[Tuple[int, Optional[float]]] = []
|
||||
|
||||
if sgl_comm is not None:
|
||||
sgl_results = bench_impl(
|
||||
name="SGLang",
|
||||
comm=sgl_comm,
|
||||
sizes=sizes,
|
||||
device=device,
|
||||
warmup=args.warmup,
|
||||
iters_small=args.iters_small,
|
||||
iters_large=args.iters_large,
|
||||
verbose=args.verbose,
|
||||
pg=pg,
|
||||
)
|
||||
|
||||
if aiter_comm is not None:
|
||||
aiter_results = bench_impl(
|
||||
name="Aiter",
|
||||
comm=aiter_comm,
|
||||
sizes=sizes,
|
||||
device=device,
|
||||
warmup=args.warmup,
|
||||
iters_small=args.iters_small,
|
||||
iters_large=args.iters_large,
|
||||
verbose=args.verbose,
|
||||
pg=pg,
|
||||
)
|
||||
|
||||
for comm in (sgl_comm, aiter_comm):
|
||||
if comm is not None and hasattr(comm, "close"):
|
||||
try:
|
||||
comm.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
print("\nResults (avg ms across ranks; None = disabled/unavailable):")
|
||||
header = f"{'Size':>8} {'SGLang(ms)':>12} {'Aiter(ms)':>11}"
|
||||
print(header)
|
||||
print("-" * len(header))
|
||||
|
||||
sgl_map = {s: v for s, v in sgl_results if v is not None}
|
||||
aiter_map = {s: v for s, v in aiter_results if v is not None}
|
||||
|
||||
for s in sizes:
|
||||
sgl_ms = sgl_map.get(s, None)
|
||||
aiter_ms = aiter_map.get(s, None)
|
||||
print(
|
||||
f"{human_size(s):>8} {('%.3f' % sgl_ms) if sgl_ms is not None else 'None':>12} "
|
||||
f"{('%.3f' % aiter_ms) if aiter_ms is not None else 'None':>11}"
|
||||
)
|
||||
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
351
third_party/sglang/benchmark/kernels/all_reduce/benchmark_all_reduce.py
vendored
Normal file
351
third_party/sglang/benchmark/kernels/all_reduce/benchmark_all_reduce.py
vendored
Normal file
@@ -0,0 +1,351 @@
|
||||
"""
|
||||
Benchmark SGLang custom all-reduce vs Torch symm-mem all-reduce across message sizes.
|
||||
Usage:
|
||||
torchrun --nproc_per_node=2 benchmark_all_reduce.py
|
||||
torchrun --nproc_per_node=4 benchmark_all_reduce.py
|
||||
torchrun --nproc_per_node=8 benchmark_all_reduce.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
destroy_distributed_environment,
|
||||
destroy_model_parallel,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark SGLang custom all-reduce vs Torch symm-mem all-reduce across message sizes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="gloo",
|
||||
help="Process group backend for the custom-AR control path (must NOT be nccl).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Warmup iterations per size per implementation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--iters-small",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Benchmark iterations for sizes <= 1MB.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--iters-large",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Benchmark iterations for sizes > 1MB.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="Print per-iteration timings on rank 0 for debugging.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def get_env_rank_world() -> Tuple[int, int, int]:
|
||||
rank = int(os.environ.get("RANK", "0"))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", str(rank)))
|
||||
return rank, world_size, local_rank
|
||||
|
||||
|
||||
def init_dist(backend: str):
|
||||
rank, world_size, _ = get_env_rank_world()
|
||||
if not dist.is_initialized():
|
||||
dist.init_process_group(
|
||||
backend=backend,
|
||||
init_method="env://",
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
distributed_init_method = f"tcp://localhost:23456"
|
||||
init_distributed_environment(
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
local_rank=rank,
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
return dist.group.WORLD
|
||||
|
||||
|
||||
def get_device(local_rank: int) -> torch.device:
|
||||
torch.cuda.set_device(local_rank)
|
||||
return torch.device(f"cuda:{local_rank}")
|
||||
|
||||
|
||||
def human_size(num_bytes: int) -> str:
|
||||
units = [("B", 1), ("K", 1024), ("M", 1024 * 1024), ("G", 1024 * 1024 * 1024)]
|
||||
for suf, base in reversed(units):
|
||||
if num_bytes % base == 0 and num_bytes >= base:
|
||||
val = num_bytes // base
|
||||
return f"{val}{suf}"
|
||||
return f"{num_bytes}B"
|
||||
|
||||
|
||||
def get_message_sizes() -> List[int]:
|
||||
return [
|
||||
32 * 1024,
|
||||
64 * 1024,
|
||||
128 * 1024,
|
||||
256 * 1024,
|
||||
512 * 1024,
|
||||
1 * 1024 * 1024,
|
||||
2 * 1024 * 1024,
|
||||
4 * 1024 * 1024,
|
||||
8 * 1024 * 1024,
|
||||
16 * 1024 * 1024,
|
||||
32 * 1024 * 1024,
|
||||
64 * 1024 * 1024,
|
||||
]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_once(comm, inp: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
if hasattr(comm, "custom_all_reduce"):
|
||||
return comm.custom_all_reduce(inp)
|
||||
if hasattr(comm, "all_reduce"):
|
||||
return comm.all_reduce(inp)
|
||||
raise RuntimeError("No known all-reduce method found on the communicator.")
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def bench_impl(
|
||||
name: str,
|
||||
comm,
|
||||
sizes: List[int],
|
||||
device: torch.device,
|
||||
warmup: int,
|
||||
iters_small: int,
|
||||
iters_large: int,
|
||||
verbose: bool,
|
||||
pg: Optional[dist.ProcessGroup] = None,
|
||||
) -> List[Tuple[int, Optional[float]]]:
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
results: List[Tuple[int, Optional[float]]] = []
|
||||
|
||||
for size_bytes in sizes:
|
||||
elems = size_bytes // 2 # float16: 2 bytes per element
|
||||
inp = torch.empty(elems, dtype=torch.float16, device=device)
|
||||
inp.uniform_(0, 1)
|
||||
|
||||
disabled = False
|
||||
dist.barrier(group=pg)
|
||||
for _ in range(warmup):
|
||||
torch.cuda.synchronize()
|
||||
out = run_once(comm, inp)
|
||||
torch.cuda.synchronize()
|
||||
if out is None:
|
||||
disabled = True
|
||||
break
|
||||
dist.barrier(group=pg)
|
||||
|
||||
if disabled:
|
||||
if rank == 0:
|
||||
print(
|
||||
f"[{name}] {human_size(size_bytes)}: custom AR disabled (skipped)"
|
||||
)
|
||||
results.append((size_bytes, None))
|
||||
continue
|
||||
|
||||
num_iters = iters_small if size_bytes <= (1 * 1024 * 1024) else iters_large
|
||||
|
||||
times_ms: List[float] = []
|
||||
for it in range(num_iters):
|
||||
dist.barrier(group=pg)
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
out = run_once(comm, inp)
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.perf_counter()
|
||||
dist.barrier(group=pg)
|
||||
|
||||
if out is None:
|
||||
disabled = True
|
||||
break
|
||||
|
||||
dt_ms = (t1 - t0) * 1000.0
|
||||
times_ms.append(dt_ms)
|
||||
|
||||
if verbose and rank == 0:
|
||||
print(
|
||||
f"[{name}] size={human_size(size_bytes)} iter={it} time={dt_ms:.3f} ms"
|
||||
)
|
||||
|
||||
if disabled or not times_ms:
|
||||
if rank == 0:
|
||||
print(
|
||||
f"[{name}] {human_size(size_bytes)}: custom AR disabled (no timings)"
|
||||
)
|
||||
results.append((size_bytes, None))
|
||||
continue
|
||||
|
||||
avg_ms_local = sum(times_ms) / len(times_ms)
|
||||
avg_tensor = torch.tensor([avg_ms_local], dtype=torch.float64, device=device)
|
||||
gather_list = [torch.zeros_like(avg_tensor) for _ in range(world_size)]
|
||||
dist.all_gather(gather_list, avg_tensor, group=pg)
|
||||
if rank == 0:
|
||||
avg_ms = float(torch.stack(gather_list).mean().item())
|
||||
print(
|
||||
f"[{name}] {human_size(size_bytes)}: {avg_ms:.3f} ms (avg across ranks)"
|
||||
)
|
||||
results.append((size_bytes, avg_ms))
|
||||
else:
|
||||
results.append((size_bytes, None))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
rank, world_size, local_rank = get_env_rank_world()
|
||||
|
||||
if world_size not in (2, 4, 6, 8):
|
||||
print(
|
||||
f"[rank {rank}] WARNING: world_size={world_size} not in supported set (2,4,6,8). "
|
||||
"Custom AR may disable itself.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
group = init_dist(args.backend)
|
||||
device = get_device(local_rank)
|
||||
|
||||
# Import after dist init; some libs query torch dist state on import
|
||||
torch_symm_mem_comm = None
|
||||
HAVE_SGLANG_CUSTOM = False
|
||||
HAVE_TORCH_SYMM_MEM = False
|
||||
|
||||
try:
|
||||
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
|
||||
CustomAllreduce as SGLCustomAllreduce,
|
||||
)
|
||||
|
||||
HAVE_SGLANG_CUSTOM = True
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
print(f"SGLang CustomAllreduce import failed: {e}", file=sys.stderr)
|
||||
|
||||
try:
|
||||
from sglang.srt.distributed.device_communicators.torch_symm_mem import (
|
||||
TorchSymmMemCommunicator as TorchSymmMemAllreduce,
|
||||
)
|
||||
|
||||
HAVE_TORCH_SYMM_MEM = True
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
print(f"TorchSymmMemAllreduce import failed: {e}", file=sys.stderr)
|
||||
|
||||
if rank == 0:
|
||||
print(f"Initialized PG backend={args.backend} world_size={world_size}")
|
||||
print(f"Device: {device.type}:{device.index}")
|
||||
print(
|
||||
f"SGLang Custom available: {HAVE_SGLANG_CUSTOM}, Torch Symm-Mem available: {HAVE_TORCH_SYMM_MEM}"
|
||||
)
|
||||
|
||||
sizes = get_message_sizes()
|
||||
max_size = max(sizes) if sizes else (128 * 1024 * 1024)
|
||||
|
||||
if HAVE_SGLANG_CUSTOM:
|
||||
try:
|
||||
sgl_custom_comm = SGLCustomAllreduce(
|
||||
group=group, device=device, max_size=max_size
|
||||
)
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Failed to construct SGLangCustomAllreduce: {e}", file=sys.stderr
|
||||
)
|
||||
sgl_custom_comm = None
|
||||
|
||||
if HAVE_TORCH_SYMM_MEM:
|
||||
try:
|
||||
torch_symm_mem_comm = TorchSymmMemAllreduce(group=group, device=device)
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Failed to construct TorchSymmMemAllreduce: {e}", file=sys.stderr
|
||||
)
|
||||
torch_symm_mem_comm = None
|
||||
|
||||
sgl_custom_results: List[Tuple[int, Optional[float]]] = []
|
||||
symm_mem_results: List[Tuple[int, Optional[float]]] = []
|
||||
|
||||
if sgl_custom_comm is not None:
|
||||
sgl_custom_results = bench_impl(
|
||||
name="SGLangCustom",
|
||||
comm=sgl_custom_comm,
|
||||
sizes=sizes,
|
||||
device=device,
|
||||
warmup=args.warmup,
|
||||
iters_small=args.iters_small,
|
||||
iters_large=args.iters_large,
|
||||
verbose=args.verbose,
|
||||
pg=group,
|
||||
)
|
||||
|
||||
if torch_symm_mem_comm is not None:
|
||||
symm_mem_results = bench_impl(
|
||||
name="TorchSymmMem",
|
||||
comm=torch_symm_mem_comm,
|
||||
sizes=sizes,
|
||||
device=device,
|
||||
warmup=args.warmup,
|
||||
iters_small=args.iters_small,
|
||||
iters_large=args.iters_large,
|
||||
verbose=args.verbose,
|
||||
pg=group,
|
||||
)
|
||||
|
||||
for comm in (sgl_custom_comm, torch_symm_mem_comm):
|
||||
if comm is not None and hasattr(comm, "close"):
|
||||
try:
|
||||
comm.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
print(
|
||||
f"\nResults (avg ms across {world_size} ranks; None = disabled/unavailable):"
|
||||
)
|
||||
header = f"{'Size':>8} {'CustomAR(ms)':>12} {'TorchSymmMem(ms)':>11}"
|
||||
print(header)
|
||||
print("-" * len(header))
|
||||
|
||||
sgl_custom_map = {s: v for s, v in sgl_custom_results if v is not None}
|
||||
symm_mem_map = {s: v for s, v in symm_mem_results if v is not None}
|
||||
|
||||
for s in sizes:
|
||||
sgl_ms = sgl_custom_map.get(s, None)
|
||||
symm_mem_ms = symm_mem_map.get(s, None)
|
||||
print(
|
||||
f"{human_size(s):>8} {('%.3f' % sgl_ms) if sgl_ms is not None else 'None':>12} "
|
||||
f"{('%.3f' % symm_mem_ms) if symm_mem_ms is not None else 'None':>11}"
|
||||
)
|
||||
torch.distributed.barrier(group=group)
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
536
third_party/sglang/benchmark/kernels/all_reduce/benchmark_fused_ar_rms_amd.py
vendored
Normal file
536
third_party/sglang/benchmark/kernels/all_reduce/benchmark_fused_ar_rms_amd.py
vendored
Normal file
@@ -0,0 +1,536 @@
|
||||
"""
|
||||
Benchmark fused allreduce+rmsnorm on AMD with correctness checks.
|
||||
|
||||
This script targets the same fused op used by SGLang:
|
||||
`tensor_model_parallel_fused_allreduce_rmsnorm`.
|
||||
|
||||
It reports:
|
||||
- eager mode latency (prefill-like)
|
||||
- graph mode latency (decode-like)
|
||||
- fused availability (whether fused path returns non-None)
|
||||
- correctness (fused output matches split allreduce + rmsnorm reference)
|
||||
|
||||
Usage example:
|
||||
torchrun --nproc_per_node=8 \
|
||||
benchmark/kernels/all_reduce/benchmark_fused_ar_rms_amd.py \
|
||||
--dtype bfloat16 \
|
||||
--prefill-shapes 2048x8192,8192x8192 \
|
||||
--decode-shapes 1x8192,4x8192,16x8192 \
|
||||
--warmup 10 --iters 30 --repeats 5
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import statistics
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.srt.distributed.communication_op import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_fused_allreduce_rmsnorm,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
destroy_distributed_environment,
|
||||
destroy_model_parallel,
|
||||
graph_capture,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
set_custom_all_reduce,
|
||||
)
|
||||
|
||||
Shape = Tuple[int, int]
|
||||
|
||||
|
||||
def parse_shapes(raw: str) -> List[Shape]:
|
||||
shapes: List[Shape] = []
|
||||
for item in [x.strip() for x in raw.split(",") if x.strip()]:
|
||||
if "x" not in item:
|
||||
raise ValueError(f"Invalid shape '{item}', expected MxN format.")
|
||||
m_str, n_str = item.split("x", 1)
|
||||
m = int(m_str)
|
||||
n = int(n_str)
|
||||
if m <= 0 or n <= 0:
|
||||
raise ValueError(f"Invalid shape '{item}', both dims must be positive.")
|
||||
shapes.append((m, n))
|
||||
if not shapes:
|
||||
raise ValueError("Empty shape list is not allowed.")
|
||||
return shapes
|
||||
|
||||
|
||||
def dtype_from_name(name: str) -> torch.dtype:
|
||||
mapping = {
|
||||
"float16": torch.float16,
|
||||
"fp16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
if name not in mapping:
|
||||
raise ValueError(f"Unsupported dtype: {name}")
|
||||
return mapping[name]
|
||||
|
||||
|
||||
def check_close(
|
||||
a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype
|
||||
) -> Tuple[bool, str]:
|
||||
if dtype == torch.bfloat16:
|
||||
rtol, atol = 2e-2, 1.25e-1
|
||||
else:
|
||||
rtol, atol = 1e-2, 2e-2
|
||||
try:
|
||||
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
|
||||
return True, "PASS"
|
||||
except AssertionError:
|
||||
max_diff = torch.max(torch.abs(a - b)).item()
|
||||
mean_diff = torch.mean(torch.abs(a - b)).item()
|
||||
return False, f"FAIL(max={max_diff:.6f},mean={mean_diff:.6f})"
|
||||
|
||||
|
||||
def _measure_us(
|
||||
fn,
|
||||
warmup: int,
|
||||
iters: int,
|
||||
repeats: int,
|
||||
device: torch.device,
|
||||
) -> Tuple[float, Dict[str, float]]:
|
||||
for _ in range(warmup):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
samples_us: List[float] = []
|
||||
|
||||
for _ in range(max(1, repeats)):
|
||||
_barrier(device)
|
||||
torch.cuda.synchronize()
|
||||
start_event.record()
|
||||
for _ in range(iters):
|
||||
fn()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
samples_us.append(start_event.elapsed_time(end_event) * 1000.0 / iters)
|
||||
|
||||
sorted_samples = sorted(samples_us)
|
||||
p50 = float(statistics.median(sorted_samples))
|
||||
p95 = float(sorted_samples[int((len(sorted_samples) - 1) * 0.95)])
|
||||
return p50, {
|
||||
"p50_us": p50,
|
||||
"p95_us": p95,
|
||||
"min_us": float(sorted_samples[0]),
|
||||
"max_us": float(sorted_samples[-1]),
|
||||
}
|
||||
|
||||
|
||||
def _barrier(device: torch.device):
|
||||
try:
|
||||
dist.barrier(device_ids=[device.index])
|
||||
except TypeError:
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def _mean_across_ranks(value: float, device: torch.device) -> float:
|
||||
t = torch.tensor([value], dtype=torch.float64, device=device)
|
||||
dist.all_reduce(t, op=dist.ReduceOp.SUM)
|
||||
t /= dist.get_world_size()
|
||||
return float(t.item())
|
||||
|
||||
|
||||
def _all_true_across_ranks(value: bool, device: torch.device) -> bool:
|
||||
t = torch.tensor([1 if value else 0], dtype=torch.int32, device=device)
|
||||
dist.all_reduce(t, op=dist.ReduceOp.MIN)
|
||||
return bool(int(t.item()))
|
||||
|
||||
|
||||
def _make_inputs(
|
||||
shape: Shape,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
residual_mode: str,
|
||||
rank: int,
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
m, n = shape
|
||||
torch.manual_seed(seed + rank * 17)
|
||||
x = torch.randn((m, n), dtype=torch.float32, device=device).to(dtype)
|
||||
if residual_mode == "self":
|
||||
residual = x.clone()
|
||||
elif residual_mode == "random":
|
||||
residual = torch.randn((m, n), dtype=torch.float32, device=device).to(dtype)
|
||||
elif residual_mode == "zero":
|
||||
residual = torch.zeros((m, n), dtype=dtype, device=device)
|
||||
else:
|
||||
raise ValueError(f"Unknown residual_mode: {residual_mode}")
|
||||
weight = torch.randn((n,), dtype=torch.float32, device=device).to(dtype)
|
||||
return x, residual, weight
|
||||
|
||||
|
||||
def _split_reference(
|
||||
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
ar_out = tensor_model_parallel_all_reduce(x.clone())
|
||||
residual_out = ar_out + residual
|
||||
out = F.rms_norm(
|
||||
input=residual_out,
|
||||
normalized_shape=(residual_out.shape[-1],),
|
||||
weight=weight,
|
||||
eps=eps,
|
||||
)
|
||||
return out, residual_out
|
||||
|
||||
|
||||
def bench_eager(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
eps: float,
|
||||
warmup: int,
|
||||
iters: int,
|
||||
repeats: int,
|
||||
) -> Dict[str, object]:
|
||||
split_fn = lambda: _split_reference(x, residual, weight, eps)
|
||||
split_us, split_stats = _measure_us(split_fn, warmup, iters, repeats, x.device)
|
||||
|
||||
fused_probe = tensor_model_parallel_fused_allreduce_rmsnorm(
|
||||
x.clone(), residual.clone(), weight, eps
|
||||
)
|
||||
fused_available = fused_probe is not None
|
||||
|
||||
fused_us: Optional[float] = None
|
||||
fused_stats: Optional[Dict[str, float]] = None
|
||||
if fused_available:
|
||||
fused_fn = lambda: tensor_model_parallel_fused_allreduce_rmsnorm(
|
||||
x, residual, weight, eps
|
||||
)
|
||||
fused_us, fused_stats = _measure_us(fused_fn, warmup, iters, repeats, x.device)
|
||||
|
||||
ref_out, ref_residual = _split_reference(x, residual, weight, eps)
|
||||
if fused_available:
|
||||
fused_out, fused_residual = tensor_model_parallel_fused_allreduce_rmsnorm(
|
||||
x.clone(), residual.clone(), weight, eps
|
||||
)
|
||||
out_ok, out_detail = check_close(fused_out, ref_out, x.dtype)
|
||||
res_ok, res_detail = check_close(fused_residual, ref_residual, x.dtype)
|
||||
correctness_ok = out_ok and res_ok
|
||||
correctness_detail = f"out={out_detail}, residual={res_detail}"
|
||||
else:
|
||||
correctness_ok = True
|
||||
correctness_detail = "SKIP(fused_unavailable)"
|
||||
|
||||
return {
|
||||
"split_us": split_us,
|
||||
"split_stats": split_stats,
|
||||
"fused_available": fused_available,
|
||||
"fused_us": fused_us,
|
||||
"fused_stats": fused_stats,
|
||||
"correctness_ok": correctness_ok,
|
||||
"correctness_detail": correctness_detail,
|
||||
}
|
||||
|
||||
|
||||
def bench_graph(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
eps: float,
|
||||
warmup: int,
|
||||
iters: int,
|
||||
repeats: int,
|
||||
) -> Dict[str, object]:
|
||||
split_x = x.clone()
|
||||
split_res = residual.clone()
|
||||
split_graph_out: Optional[torch.Tensor] = None
|
||||
|
||||
with graph_capture() as gc:
|
||||
split_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(split_graph, stream=gc.stream):
|
||||
split_graph_out, _ = _split_reference(split_x, split_res, weight, eps)
|
||||
|
||||
def split_replay():
|
||||
split_graph.replay()
|
||||
|
||||
split_us, split_stats = _measure_us(split_replay, warmup, iters, repeats, x.device)
|
||||
|
||||
fused_probe = tensor_model_parallel_fused_allreduce_rmsnorm(
|
||||
x.clone(), residual.clone(), weight, eps
|
||||
)
|
||||
fused_available = fused_probe is not None
|
||||
|
||||
fused_us: Optional[float] = None
|
||||
fused_stats: Optional[Dict[str, float]] = None
|
||||
fused_graph_out: Optional[torch.Tensor] = None
|
||||
fused_graph_residual: Optional[torch.Tensor] = None
|
||||
|
||||
if fused_available:
|
||||
fused_x = x.clone()
|
||||
fused_res = residual.clone()
|
||||
with graph_capture() as gc:
|
||||
fused_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(fused_graph, stream=gc.stream):
|
||||
fused_graph_out, fused_graph_residual = (
|
||||
tensor_model_parallel_fused_allreduce_rmsnorm(
|
||||
fused_x, fused_res, weight, eps
|
||||
)
|
||||
)
|
||||
|
||||
def fused_replay():
|
||||
fused_graph.replay()
|
||||
|
||||
fused_us, fused_stats = _measure_us(
|
||||
fused_replay, warmup, iters, repeats, x.device
|
||||
)
|
||||
|
||||
ref_out, ref_residual = _split_reference(x, residual, weight, eps)
|
||||
if (
|
||||
fused_available
|
||||
and fused_graph_out is not None
|
||||
and fused_graph_residual is not None
|
||||
):
|
||||
fused_graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
out_ok, out_detail = check_close(fused_graph_out, ref_out, x.dtype)
|
||||
res_ok, res_detail = check_close(fused_graph_residual, ref_residual, x.dtype)
|
||||
correctness_ok = out_ok and res_ok
|
||||
correctness_detail = f"out={out_detail}, residual={res_detail}"
|
||||
else:
|
||||
correctness_ok = True
|
||||
correctness_detail = "SKIP(fused_unavailable)"
|
||||
|
||||
return {
|
||||
"split_us": split_us,
|
||||
"split_stats": split_stats,
|
||||
"fused_available": fused_available,
|
||||
"fused_us": fused_us,
|
||||
"fused_stats": fused_stats,
|
||||
"correctness_ok": correctness_ok,
|
||||
"correctness_detail": correctness_detail,
|
||||
}
|
||||
|
||||
|
||||
def _shape_bytes(shape: Shape, dtype: torch.dtype) -> int:
|
||||
m, n = shape
|
||||
return m * n * torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark fused allreduce+rmsnorm (prefill eager + decode graph)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
default="bf16",
|
||||
choices=["fp16", "bf16", "float16", "bfloat16"],
|
||||
)
|
||||
parser.add_argument("--eps", type=float, default=1e-6)
|
||||
parser.add_argument("--seed", type=int, default=1234)
|
||||
parser.add_argument(
|
||||
"--residual-mode",
|
||||
type=str,
|
||||
default="self",
|
||||
choices=["self", "random", "zero"],
|
||||
help="Use residual=x (self) to match aiter test behavior by default.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefill-shapes",
|
||||
type=str,
|
||||
default="2048x8192,8192x8192,16384x8192",
|
||||
help="Comma-separated MxN shapes for eager mode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decode-shapes",
|
||||
type=str,
|
||||
default="1x8192,2x8192,4x8192,8x8192,16x8192",
|
||||
help="Comma-separated MxN shapes for graph mode.",
|
||||
)
|
||||
parser.add_argument("--warmup", type=int, default=10)
|
||||
parser.add_argument("--iters", type=int, default=30)
|
||||
parser.add_argument("--repeats", type=int, default=5)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
default="both",
|
||||
choices=["eager", "graph", "both"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--csv-out",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional output CSV path (written on rank 0 only).",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
dtype = dtype_from_name(args.dtype)
|
||||
rank = int(os.environ.get("RANK", "0"))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", str(rank)))
|
||||
torch.cuda.set_device(local_rank % torch.cuda.device_count())
|
||||
device = torch.device(f"cuda:{local_rank % torch.cuda.device_count()}")
|
||||
|
||||
set_custom_all_reduce(True)
|
||||
init_distributed_environment(
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
local_rank=local_rank,
|
||||
distributed_init_method="env://",
|
||||
backend="nccl",
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
prefill_shapes = parse_shapes(args.prefill_shapes)
|
||||
decode_shapes = parse_shapes(args.decode_shapes)
|
||||
|
||||
if rank == 0:
|
||||
print(
|
||||
"Config: "
|
||||
f"world_size={world_size}, dtype={dtype}, residual_mode={args.residual_mode}, "
|
||||
f"warmup={args.warmup}, iters={args.iters}, repeats={args.repeats}"
|
||||
)
|
||||
|
||||
run_modes: Sequence[str]
|
||||
if args.mode == "both":
|
||||
run_modes = ("eager", "graph")
|
||||
else:
|
||||
run_modes = (args.mode,)
|
||||
csv_rows: List[Dict[str, object]] = []
|
||||
|
||||
for mode in run_modes:
|
||||
shapes = prefill_shapes if mode == "eager" else decode_shapes
|
||||
if rank == 0:
|
||||
phase_name = "prefill(eager)" if mode == "eager" else "decode(graph)"
|
||||
print("\n" + "=" * 120)
|
||||
print(f"Mode: {phase_name}")
|
||||
print(
|
||||
"| Shape | Input bytes/rank | Split p50 (us) | Fused p50 (us) | Speedup | Fused available | Correctness |"
|
||||
)
|
||||
print(
|
||||
"|:------|-----------------:|---------------:|---------------:|--------:|:----------------|:------------|"
|
||||
)
|
||||
|
||||
for shape in shapes:
|
||||
x, residual, weight = _make_inputs(
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
seed=args.seed,
|
||||
residual_mode=args.residual_mode,
|
||||
rank=rank,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if mode == "eager":
|
||||
metrics = bench_eager(
|
||||
x=x,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
eps=args.eps,
|
||||
warmup=args.warmup,
|
||||
iters=args.iters,
|
||||
repeats=args.repeats,
|
||||
)
|
||||
else:
|
||||
metrics = bench_graph(
|
||||
x=x,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
eps=args.eps,
|
||||
warmup=args.warmup,
|
||||
iters=args.iters,
|
||||
repeats=args.repeats,
|
||||
)
|
||||
|
||||
split_us = _mean_across_ranks(float(metrics["split_us"]), device)
|
||||
fused_available = _all_true_across_ranks(
|
||||
bool(metrics["fused_available"]), device
|
||||
)
|
||||
correctness_ok = _all_true_across_ranks(
|
||||
bool(metrics["correctness_ok"]), device
|
||||
)
|
||||
|
||||
fused_us: Optional[float] = None
|
||||
if fused_available and metrics["fused_us"] is not None:
|
||||
fused_us = _mean_across_ranks(float(metrics["fused_us"]), device)
|
||||
|
||||
if rank == 0:
|
||||
m, n = shape
|
||||
shape_str = f"{m}x{n}"
|
||||
bytes_per_rank = _shape_bytes(shape, dtype)
|
||||
if fused_us is not None and fused_us > 0:
|
||||
speedup = split_us / fused_us
|
||||
speedup_str = f"{speedup:.3f}x"
|
||||
fused_str = f"{fused_us:.1f}"
|
||||
else:
|
||||
speedup_str = "N/A"
|
||||
fused_str = "N/A"
|
||||
correctness_text = (
|
||||
"PASS" if correctness_ok else str(metrics["correctness_detail"])
|
||||
)
|
||||
print(
|
||||
f"| {shape_str} | {bytes_per_rank} | {split_us:.1f} | {fused_str} | "
|
||||
f"{speedup_str} | {str(fused_available)} | {correctness_text} |"
|
||||
)
|
||||
csv_rows.append(
|
||||
{
|
||||
"mode": mode,
|
||||
"shape": shape_str,
|
||||
"m": m,
|
||||
"n": n,
|
||||
"bytes_per_rank": bytes_per_rank,
|
||||
"split_p50_us": split_us,
|
||||
"fused_p50_us": fused_us if fused_us is not None else "",
|
||||
"speedup_split_over_fused": (
|
||||
split_us / fused_us
|
||||
if fused_us is not None and fused_us > 0
|
||||
else ""
|
||||
),
|
||||
"fused_available": fused_available,
|
||||
"correctness_ok": correctness_ok,
|
||||
"correctness_detail": correctness_text,
|
||||
"dtype": str(dtype),
|
||||
"world_size": world_size,
|
||||
"residual_mode": args.residual_mode,
|
||||
"warmup": args.warmup,
|
||||
"iters": args.iters,
|
||||
"repeats": args.repeats,
|
||||
}
|
||||
)
|
||||
|
||||
if rank == 0 and args.csv_out:
|
||||
os.makedirs(os.path.dirname(args.csv_out) or ".", exist_ok=True)
|
||||
fieldnames = [
|
||||
"mode",
|
||||
"shape",
|
||||
"m",
|
||||
"n",
|
||||
"bytes_per_rank",
|
||||
"split_p50_us",
|
||||
"fused_p50_us",
|
||||
"speedup_split_over_fused",
|
||||
"fused_available",
|
||||
"correctness_ok",
|
||||
"correctness_detail",
|
||||
"dtype",
|
||||
"world_size",
|
||||
"residual_mode",
|
||||
"warmup",
|
||||
"iters",
|
||||
"repeats",
|
||||
]
|
||||
with open(args.csv_out, "w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(csv_rows)
|
||||
print(f"\nSaved CSV to: {args.csv_out}")
|
||||
|
||||
_barrier(device)
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
224
third_party/sglang/benchmark/kernels/all_reduce/benchmark_mscclpp.py
vendored
Normal file
224
third_party/sglang/benchmark/kernels/all_reduce/benchmark_mscclpp.py
vendored
Normal file
@@ -0,0 +1,224 @@
|
||||
"""For Now, MSCCL is only supported on TP16 and TP8 case
|
||||
|
||||
export WORLD_SIZE=1
|
||||
export RANK=0
|
||||
export MASTER_ADDR=127.0.0.1
|
||||
export MASTER_PORT=12345
|
||||
|
||||
torchrun --nproc_per_node gpu \
|
||||
--nnodes $WORLD_SIZE \
|
||||
--node_rank $RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT benchmark/kernels/all_reduce/benchmark_mscclpp.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.distributed import init_distributed_environment
|
||||
from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator
|
||||
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_group,
|
||||
graph_capture,
|
||||
initialize_model_parallel,
|
||||
set_mscclpp_all_reduce,
|
||||
)
|
||||
|
||||
|
||||
def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor:
|
||||
dist.all_reduce(torch_input, group=group)
|
||||
return torch_input
|
||||
|
||||
|
||||
def msccl_allreduce(
|
||||
msccl_input: torch.Tensor, msccl_comm: PyMscclppCommunicator
|
||||
) -> torch.Tensor:
|
||||
return msccl_comm.all_reduce(msccl_input)
|
||||
|
||||
|
||||
def pynccl_allreduce(
|
||||
msccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator
|
||||
) -> torch.Tensor:
|
||||
pynccl_comm.all_reduce(msccl_input)
|
||||
return msccl_input
|
||||
|
||||
|
||||
def _bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10):
|
||||
graph_input = inp_randn.clone()
|
||||
with graph_capture() as graph_capture_context:
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
|
||||
for _ in range(graph_loop):
|
||||
graph_out = func(graph_input)
|
||||
|
||||
graph.replay()
|
||||
func_output = graph_out.clone()
|
||||
|
||||
for _ in range(warmup_loop):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
latencies: List[float] = []
|
||||
for _ in range(test_loop):
|
||||
torch.cuda.synchronize()
|
||||
dist.barrier()
|
||||
start_event.record()
|
||||
graph.replay()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
latencies.append(start_event.elapsed_time(end_event))
|
||||
func_cost_us = sum(latencies) / len(latencies) / graph_loop * 1000
|
||||
graph.reset()
|
||||
return func_output, func_cost_us
|
||||
|
||||
|
||||
def _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10):
|
||||
eager_input = inp_randn.clone()
|
||||
eager_output = func(eager_input)
|
||||
func_output = eager_output.clone()
|
||||
|
||||
for _ in range(warmup_loop):
|
||||
func(eager_input)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
torch.cuda.synchronize()
|
||||
start_event.record()
|
||||
for _ in range(test_loop):
|
||||
func(eager_input)
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
func_cost_us = start_event.elapsed_time(end_event) / test_loop * 1000
|
||||
|
||||
return func_output, func_cost_us
|
||||
|
||||
|
||||
def get_torch_prof_ctx(do_prof: bool):
|
||||
ctx = (
|
||||
torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
record_shapes=True,
|
||||
with_stack=True,
|
||||
)
|
||||
if do_prof
|
||||
else nullcontext()
|
||||
)
|
||||
return ctx
|
||||
|
||||
|
||||
def human_readable_size(size, decimal_places=1):
|
||||
for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
|
||||
if size < 1024.0 or unit == "PiB":
|
||||
break
|
||||
size /= 1024.0
|
||||
return f"{size:.{decimal_places}f} {unit}"
|
||||
|
||||
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ImportError:
|
||||
print("tabulate not installed, skipping table printing")
|
||||
tabulate = None
|
||||
|
||||
|
||||
def print_markdown_table(data):
|
||||
if tabulate is not None:
|
||||
print(tabulate(data, headers="keys", tablefmt="github"))
|
||||
return
|
||||
headers = data[0].keys()
|
||||
header_row = "| " + " | ".join(headers) + " |"
|
||||
separator = "| " + " | ".join(["---"] * len(headers)) + " |"
|
||||
rows = []
|
||||
for item in data:
|
||||
row = "| " + " | ".join(str(item[key]) for key in headers) + " |"
|
||||
rows.append(row)
|
||||
markdown_table = "\n".join([header_row, separator] + rows)
|
||||
print(markdown_table)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
force=True,
|
||||
)
|
||||
if not dist.is_initialized():
|
||||
dist.init_process_group(backend="nccl")
|
||||
world, world_size = dist.group.WORLD, dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
torch.cuda.set_device(rank % 8)
|
||||
device = torch.cuda.current_device()
|
||||
set_mscclpp_all_reduce(True)
|
||||
init_distributed_environment(
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
local_rank=rank % 8,
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
group = get_tensor_model_parallel_group().device_group
|
||||
cpu_group = get_tensor_model_parallel_group().cpu_group
|
||||
pynccl_comm = get_tensor_model_parallel_group().pynccl_comm
|
||||
pymscclpp_comm = get_tensor_model_parallel_group().pymscclpp_comm
|
||||
dist.barrier()
|
||||
profile = False
|
||||
dtype = torch.bfloat16
|
||||
ctx = get_torch_prof_ctx(profile)
|
||||
result = []
|
||||
|
||||
with ctx:
|
||||
for i in range(10, 20):
|
||||
sz = 2**i
|
||||
if sz * dtype.itemsize > 2**20:
|
||||
break
|
||||
inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
|
||||
|
||||
memory = torch.empty_like(inp_randn)
|
||||
memory_out = torch.empty_like(memory)
|
||||
torch_eager_output, torch_eager_time = _bench_eager_time(
|
||||
lambda inp: torch_allreduce(inp, group), inp_randn
|
||||
)
|
||||
msccl_eager_output, msccl_eager_time = _bench_eager_time(
|
||||
lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn
|
||||
)
|
||||
msccl_graph_output, msccl_graph_time = _bench_graph_time(
|
||||
lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn
|
||||
)
|
||||
# since pynccl is inplace op, this return result is not correct if graph loop > 1
|
||||
_, pynccl_graph_time = _bench_graph_time(
|
||||
lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn
|
||||
)
|
||||
torch.testing.assert_close(torch_eager_output, msccl_graph_output)
|
||||
torch.testing.assert_close(torch_eager_output, msccl_eager_output)
|
||||
result.append(
|
||||
{
|
||||
"msg_size": human_readable_size(inp_randn.nbytes),
|
||||
"torch eager time": torch_eager_time,
|
||||
"msccl eager time": msccl_eager_time,
|
||||
"msccl graph time": msccl_graph_time,
|
||||
"pynccl graph time": pynccl_graph_time,
|
||||
}
|
||||
)
|
||||
if rank == 0:
|
||||
print(f"sz={sz}, dtype={dtype}: correctness check PASS!")
|
||||
if rank == 0:
|
||||
print_markdown_table(result)
|
||||
if profile:
|
||||
prof_dir = f"prof/msccl"
|
||||
os.makedirs(prof_dir, exist_ok=True)
|
||||
ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")
|
||||
248
third_party/sglang/benchmark/kernels/all_reduce/benchmark_torch_symm_mem.py
vendored
Normal file
248
third_party/sglang/benchmark/kernels/all_reduce/benchmark_torch_symm_mem.py
vendored
Normal file
@@ -0,0 +1,248 @@
|
||||
"""For Now, TORCH_SYMM_MEM is only supported on following limited tp case
|
||||
|
||||
SM90: {
|
||||
2: 64 * MiB, # 64 MB
|
||||
4: 64 * MiB, # 64 MB
|
||||
6: 128 * MiB, # 128 MB
|
||||
8: 128 * MiB, # 128 MB
|
||||
},
|
||||
SM100: {
|
||||
2: 64 * MiB, # 64 MB
|
||||
4: 64 * MiB, # 64 MB
|
||||
6: 128 * MiB, # 128 MB
|
||||
8: 128 * MiB, # 128 MB
|
||||
}
|
||||
|
||||
export WORLD_SIZE=8
|
||||
export RANK=0
|
||||
export MASTER_ADDR=127.0.0.1
|
||||
export MASTER_PORT=12345
|
||||
|
||||
torchrun --nproc_per_node gpu \
|
||||
--nnodes $WORLD_SIZE \
|
||||
--node_rank $RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT ./benchmark/kernels/all_reduce/benchmark_torch_symm_mem.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.distributed import init_distributed_environment
|
||||
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from sglang.srt.distributed.device_communicators.torch_symm_mem import (
|
||||
TorchSymmMemCommunicator,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_group,
|
||||
graph_capture,
|
||||
initialize_model_parallel,
|
||||
set_torch_symm_mem_all_reduce,
|
||||
)
|
||||
from sglang.utils import is_in_ci
|
||||
|
||||
IS_CI = is_in_ci()
|
||||
|
||||
|
||||
def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor:
|
||||
dist.all_reduce(torch_input, group=group)
|
||||
return torch_input
|
||||
|
||||
|
||||
def torch_symm_mem_allreduce(
|
||||
torch_symm_mem_input: torch.Tensor, torch_symm_mem_comm: TorchSymmMemCommunicator
|
||||
) -> torch.Tensor:
|
||||
return torch_symm_mem_comm.all_reduce(torch_symm_mem_input)
|
||||
|
||||
|
||||
def pynccl_allreduce(
|
||||
pynccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator
|
||||
) -> torch.Tensor:
|
||||
pynccl_comm.all_reduce(pynccl_input)
|
||||
return pynccl_input
|
||||
|
||||
|
||||
def _bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10):
|
||||
graph_input = inp_randn.clone()
|
||||
with graph_capture() as graph_capture_context:
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
|
||||
for _ in range(graph_loop):
|
||||
graph_out = func(graph_input)
|
||||
|
||||
graph.replay()
|
||||
func_output = graph_out.clone()
|
||||
|
||||
for _ in range(warmup_loop):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
latencies: List[float] = []
|
||||
for _ in range(test_loop):
|
||||
torch.cuda.synchronize()
|
||||
dist.barrier()
|
||||
start_event.record()
|
||||
graph.replay()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
latencies.append(start_event.elapsed_time(end_event))
|
||||
func_cost_us = sum(latencies) / len(latencies) / graph_loop * 1000
|
||||
graph.reset()
|
||||
return func_output, func_cost_us
|
||||
|
||||
|
||||
def _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10):
|
||||
eager_input = inp_randn.clone()
|
||||
eager_output = func(eager_input)
|
||||
func_output = eager_output.clone()
|
||||
|
||||
for _ in range(warmup_loop):
|
||||
func(eager_input)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
torch.cuda.synchronize()
|
||||
start_event.record()
|
||||
for _ in range(test_loop):
|
||||
func(eager_input)
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
func_cost_us = start_event.elapsed_time(end_event) / test_loop * 1000
|
||||
|
||||
return func_output, func_cost_us
|
||||
|
||||
|
||||
def get_torch_prof_ctx(do_prof: bool):
|
||||
ctx = (
|
||||
torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
record_shapes=True,
|
||||
with_stack=True,
|
||||
)
|
||||
if do_prof
|
||||
else nullcontext()
|
||||
)
|
||||
return ctx
|
||||
|
||||
|
||||
def human_readable_size(size, decimal_places=1):
|
||||
for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
|
||||
if size < 1024.0 or unit == "PiB":
|
||||
break
|
||||
size /= 1024.0
|
||||
return f"{size:.{decimal_places}f} {unit}"
|
||||
|
||||
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ImportError:
|
||||
print("tabulate not installed, skipping table printing")
|
||||
tabulate = None
|
||||
|
||||
|
||||
def print_markdown_table(data):
|
||||
if tabulate is not None:
|
||||
print(tabulate(data, headers="keys", tablefmt="github"))
|
||||
return
|
||||
headers = data[0].keys()
|
||||
header_row = "| " + " | ".join(headers) + " |"
|
||||
separator = "| " + " | ".join(["---"] * len(headers)) + " |"
|
||||
rows = []
|
||||
for item in data:
|
||||
row = "| " + " | ".join(str(item[key]) for key in headers) + " |"
|
||||
rows.append(row)
|
||||
markdown_table = "\n".join([header_row, separator] + rows)
|
||||
print(markdown_table)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
force=True,
|
||||
)
|
||||
if not dist.is_initialized():
|
||||
dist.init_process_group(backend="nccl")
|
||||
world, world_size = dist.group.WORLD, dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
torch.cuda.set_device(rank % 8)
|
||||
device = torch.cuda.current_device()
|
||||
set_torch_symm_mem_all_reduce(True)
|
||||
init_distributed_environment(
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
local_rank=rank % 8,
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
group = get_tensor_model_parallel_group().device_group
|
||||
cpu_group = get_tensor_model_parallel_group().cpu_group
|
||||
pynccl_comm = get_tensor_model_parallel_group().pynccl_comm
|
||||
torch_symm_mem_comm = get_tensor_model_parallel_group().torch_symm_mem_comm
|
||||
dist.barrier()
|
||||
profile = False
|
||||
dtype = torch.bfloat16
|
||||
ctx = get_torch_prof_ctx(profile)
|
||||
result = []
|
||||
|
||||
with ctx:
|
||||
if IS_CI:
|
||||
i_range = range(10, 11)
|
||||
else:
|
||||
i_range = range(10, 20)
|
||||
for i in i_range:
|
||||
sz = 2**i
|
||||
if sz * dtype.itemsize > 2**24:
|
||||
break
|
||||
inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
|
||||
|
||||
memory = torch.empty_like(inp_randn)
|
||||
memory_out = torch.empty_like(memory)
|
||||
torch_eager_output, torch_eager_time = _bench_eager_time(
|
||||
lambda inp: torch_allreduce(inp, group), inp_randn
|
||||
)
|
||||
symm_mem_eager_output, symm_mem_eager_time = _bench_eager_time(
|
||||
lambda inp: torch_symm_mem_allreduce(inp, torch_symm_mem_comm),
|
||||
inp_randn,
|
||||
)
|
||||
symm_mem_graph_output, symm_mem_graph_time = _bench_graph_time(
|
||||
lambda inp: torch_symm_mem_allreduce(inp, torch_symm_mem_comm),
|
||||
inp_randn,
|
||||
)
|
||||
# since pynccl is inplace op, this return result is not correct if graph loop > 1
|
||||
_, pynccl_graph_time = _bench_graph_time(
|
||||
lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn
|
||||
)
|
||||
torch.testing.assert_close(torch_eager_output, symm_mem_graph_output)
|
||||
torch.testing.assert_close(torch_eager_output, symm_mem_eager_output)
|
||||
result.append(
|
||||
{
|
||||
"msg_size": human_readable_size(inp_randn.nbytes),
|
||||
"torch eager time": torch_eager_time,
|
||||
"symm mem eager time": symm_mem_eager_time,
|
||||
"symm mem graph time": symm_mem_graph_time,
|
||||
"pynccl graph time": pynccl_graph_time,
|
||||
}
|
||||
)
|
||||
if rank == 0:
|
||||
print(f"sz={sz}, dtype={dtype}: correctness check PASS!")
|
||||
if rank == 0:
|
||||
print_markdown_table(result)
|
||||
if profile:
|
||||
prof_dir = f"prof/torch_symm_mem"
|
||||
os.makedirs(prof_dir, exist_ok=True)
|
||||
ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")
|
||||
403
third_party/sglang/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py
vendored
Normal file
403
third_party/sglang/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py
vendored
Normal file
@@ -0,0 +1,403 @@
|
||||
import itertools
|
||||
import math
|
||||
|
||||
import cudnn
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
||||
|
||||
from sglang.srt.layers.attention.flashinfer_backend import should_use_tensor_core
|
||||
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
|
||||
|
||||
|
||||
def benchmark_forward(
|
||||
fn,
|
||||
*inputs,
|
||||
repeats=10,
|
||||
amp=False,
|
||||
amp_dtype=torch.float16,
|
||||
**kwinputs,
|
||||
):
|
||||
def amp_wrapper(*inputs, **kwinputs):
|
||||
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
|
||||
fn(*inputs, **kwinputs)
|
||||
|
||||
t = benchmark.Timer(
|
||||
stmt="fn_amp(*inputs, **kwinputs)",
|
||||
globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
|
||||
num_threads=torch.get_num_threads(),
|
||||
)
|
||||
m = t.timeit(repeats)
|
||||
return t, m
|
||||
|
||||
|
||||
def time_fwd(func, *args, **kwargs):
|
||||
time_f = benchmark_forward(func, *args, **kwargs)
|
||||
return time_f[1].mean * 1e6
|
||||
|
||||
|
||||
def decode_attention_sglang(
|
||||
q,
|
||||
kv_data,
|
||||
batch_size,
|
||||
kv_len,
|
||||
head_num_q,
|
||||
head_num_kv,
|
||||
head_dim,
|
||||
num_kv_splits,
|
||||
warmup=10,
|
||||
):
|
||||
|
||||
k_buffer = kv_data[0].view(-1, head_num_kv, head_dim)
|
||||
v_buffer = kv_data[1].view(-1, head_num_kv, head_dim)
|
||||
o = torch.empty_like(q)
|
||||
total_tokens = batch_size * kv_len
|
||||
req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
|
||||
b_req_idx = torch.arange(0, batch_size).to(0).int()
|
||||
b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")
|
||||
max_len_in_batch = kv_len
|
||||
sm_scale = 1.0 / (head_dim**0.5)
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(batch_size, head_num_q, num_kv_splits, head_dim + 1),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
for _ in range(warmup):
|
||||
decode_attention_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
)
|
||||
|
||||
f = time_fwd(
|
||||
decode_attention_fwd,
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
)
|
||||
|
||||
return f, o
|
||||
|
||||
|
||||
def decode_attention_flashinfer(dtype, head_num_q, head_num_kv):
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
|
||||
use_tensor_cores = should_use_tensor_core(
|
||||
kv_cache_dtype=dtype,
|
||||
num_attention_heads=head_num_q,
|
||||
num_kv_heads=head_num_kv,
|
||||
)
|
||||
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
|
||||
)
|
||||
|
||||
class FlashinferAttention(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
q,
|
||||
kv_data,
|
||||
batch_size,
|
||||
kv_len,
|
||||
head_num_q,
|
||||
head_num_kv,
|
||||
head_dim,
|
||||
dtype,
|
||||
warmup=10,
|
||||
):
|
||||
total_tokens = batch_size * kv_len
|
||||
kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
|
||||
kv_indices = torch.arange(0, total_tokens).to(0).int()
|
||||
kv_last_page_len = torch.full(
|
||||
(batch_size,), 1, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
flashinfer_decode_wrapper.end_forward()
|
||||
flashinfer_decode_wrapper.begin_forward(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_len,
|
||||
head_num_q,
|
||||
head_num_kv,
|
||||
head_dim,
|
||||
1,
|
||||
pos_encoding_mode="NONE",
|
||||
data_type=dtype,
|
||||
)
|
||||
|
||||
for _ in range(warmup):
|
||||
o = flashinfer_decode_wrapper.forward(
|
||||
q.contiguous().view(-1, head_num_q, head_dim), kv_data
|
||||
)
|
||||
|
||||
f = time_fwd(
|
||||
flashinfer_decode_wrapper.forward,
|
||||
q.contiguous().view(-1, head_num_q, head_dim),
|
||||
kv_data,
|
||||
)
|
||||
|
||||
return f, o
|
||||
|
||||
return FlashinferAttention
|
||||
|
||||
|
||||
def convert_to_cudnn_type(torch_type):
|
||||
if torch_type == torch.float16:
|
||||
return cudnn.data_type.HALF
|
||||
elif torch_type == torch.bfloat16:
|
||||
return cudnn.data_type.BFLOAT16
|
||||
elif torch_type == torch.float32:
|
||||
return cudnn.data_type.FLOAT
|
||||
elif torch_type == torch.int32:
|
||||
return cudnn.data_type.INT32
|
||||
elif torch_type == torch.int64:
|
||||
return cudnn.data_type.INT64
|
||||
else:
|
||||
raise ValueError("Unsupported tensor data type.")
|
||||
|
||||
|
||||
def decode_attention_cudnn(
|
||||
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype, warmup=10
|
||||
):
|
||||
# Prepare data: continuous q,k,v
|
||||
dims_q = (batch_size, head_num_q, 1, head_dim)
|
||||
strides_q = (head_num_q * head_dim, head_dim, head_num_q * head_dim, 1)
|
||||
q_gpu = q.as_strided(dims_q, strides_q)
|
||||
o_gpu = (
|
||||
torch.empty(batch_size * head_num_q * head_dim)
|
||||
.half()
|
||||
.cuda()
|
||||
.as_strided(dims_q, strides_q)
|
||||
)
|
||||
|
||||
dims_kv = (batch_size, head_num_kv, kv_len, head_dim)
|
||||
strides_kv = (
|
||||
kv_len * head_num_kv * head_dim,
|
||||
head_dim,
|
||||
head_num_kv * head_dim,
|
||||
1,
|
||||
)
|
||||
k_gpu = kv_data[0].as_strided(dims_kv, strides_kv)
|
||||
v_gpu = kv_data[1].as_strided(dims_kv, strides_kv)
|
||||
|
||||
seq_len_q_gpu = torch.full((batch_size, 1, 1, 1), 1, device="cuda")
|
||||
seq_len_kv_gpu = torch.full((batch_size, 1, 1, 1), kv_len, device="cuda")
|
||||
attn_scale = 1.0 / (head_dim**0.5)
|
||||
|
||||
# Prepare data: paged k,v
|
||||
block_size = 1
|
||||
blocks_per_batch = math.ceil(kv_len / block_size)
|
||||
# [num_blocks, head_num_kv, block_size, head_dim], num_blocks = batch_size * blocks_per_batch
|
||||
container_k_gpu = torch.cat(k_gpu.chunk(blocks_per_batch, dim=2), dim=0)
|
||||
container_v_gpu = torch.cat(v_gpu.chunk(blocks_per_batch, dim=2), dim=0)
|
||||
page_table_k_gpu = (
|
||||
torch.linspace(
|
||||
0,
|
||||
batch_size * blocks_per_batch - 1,
|
||||
batch_size * blocks_per_batch,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
.reshape(blocks_per_batch, 1, batch_size, 1)
|
||||
.transpose(0, 2)
|
||||
)
|
||||
page_table_v_gpu = page_table_k_gpu.clone()
|
||||
|
||||
graph = cudnn.pygraph(
|
||||
io_data_type=convert_to_cudnn_type(dtype),
|
||||
intermediate_data_type=cudnn.data_type.FLOAT,
|
||||
compute_data_type=cudnn.data_type.FLOAT,
|
||||
)
|
||||
|
||||
q = graph.tensor_like(q_gpu)
|
||||
container_k = graph.tensor_like(container_k_gpu)
|
||||
container_v = graph.tensor_like(container_v_gpu)
|
||||
page_table_k = graph.tensor_like(page_table_k_gpu)
|
||||
page_table_v = graph.tensor_like(page_table_v_gpu)
|
||||
|
||||
seq_len_q = graph.tensor_like(seq_len_q_gpu)
|
||||
seq_len_kv = graph.tensor_like(seq_len_kv_gpu)
|
||||
|
||||
o, _ = graph.sdpa(
|
||||
name="sdpa",
|
||||
q=q,
|
||||
k=container_k, # Container K: non contiguous container with K blocks
|
||||
v=container_v, # Container V: non contiguous container with V blocks
|
||||
is_inference=True,
|
||||
attn_scale=attn_scale,
|
||||
use_causal_mask=False,
|
||||
use_padding_mask=True,
|
||||
seq_len_q=seq_len_q,
|
||||
seq_len_kv=seq_len_kv,
|
||||
paged_attention_k_table=page_table_k, # Page Table K: Tensor containing offsets to the container with K blocks
|
||||
paged_attention_v_table=page_table_v, # Page Table V: Tensor containing offsets to the container with V blocks
|
||||
paged_attention_max_seq_len_kv=kv_len, # The maximum sequence length for K caches (this is optional, but recommended)
|
||||
)
|
||||
|
||||
o.set_output(True).set_dim(dims_q).set_stride(strides_q)
|
||||
|
||||
graph.validate()
|
||||
graph.build_operation_graph()
|
||||
graph.create_execution_plans([cudnn.heur_mode.A])
|
||||
graph.check_support()
|
||||
graph.build_plans()
|
||||
|
||||
workspace = torch.empty(
|
||||
graph.get_workspace_size(), device="cuda", dtype=torch.uint8
|
||||
)
|
||||
|
||||
variant_pack = {
|
||||
q: q_gpu,
|
||||
container_k: container_k_gpu,
|
||||
container_v: container_v_gpu,
|
||||
page_table_k: page_table_k_gpu,
|
||||
page_table_v: page_table_v_gpu,
|
||||
seq_len_q: seq_len_q_gpu,
|
||||
seq_len_kv: seq_len_kv_gpu,
|
||||
o: o_gpu,
|
||||
}
|
||||
|
||||
for _ in range(warmup):
|
||||
graph.execute(variant_pack, workspace)
|
||||
|
||||
f = time_fwd(
|
||||
graph.execute,
|
||||
variant_pack,
|
||||
workspace,
|
||||
)
|
||||
|
||||
return f, o_gpu.squeeze(dim=2)
|
||||
|
||||
|
||||
def calculate_diff():
|
||||
|
||||
dtype = torch.float16
|
||||
batch_size = 64
|
||||
kv_len = 4096
|
||||
head_num_q = 64
|
||||
head_num_kv = 8
|
||||
head_dim = 128
|
||||
|
||||
q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
|
||||
kv_data = (
|
||||
torch.randn(
|
||||
batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
|
||||
),
|
||||
torch.randn(
|
||||
batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
|
||||
),
|
||||
)
|
||||
|
||||
_, output_sglang = decode_attention_sglang(
|
||||
q,
|
||||
kv_data,
|
||||
batch_size,
|
||||
kv_len,
|
||||
head_num_q,
|
||||
head_num_kv,
|
||||
head_dim,
|
||||
num_kv_splits=8,
|
||||
)
|
||||
|
||||
attn_flashinfer = decode_attention_flashinfer(dtype, head_num_q, head_num_kv).apply
|
||||
_, output_flashinfer = attn_flashinfer(
|
||||
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
||||
)
|
||||
|
||||
_, output_cudnn = decode_attention_cudnn(
|
||||
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
||||
)
|
||||
|
||||
print(f"SGLang output={output_sglang}")
|
||||
print(f"FlashInfer output={output_flashinfer}")
|
||||
print(f"cuDNN output={output_cudnn}")
|
||||
if torch.allclose(output_sglang, output_flashinfer, atol=1e-2, rtol=1e-2):
|
||||
print("✅ SGLang[Triton] and FlashInfer match")
|
||||
else:
|
||||
print("❌ SGLang[Triton] and FlashInfer differ")
|
||||
|
||||
if torch.allclose(output_sglang, output_cudnn, atol=1e-2, rtol=1e-2):
|
||||
print("✅ SGLang[Triton] and cuDNN match")
|
||||
else:
|
||||
print("❌ SGLang[Triton] and cuDNN differ")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
calculate_diff()
|
||||
|
||||
head_dim = 128
|
||||
dtype = torch.float16
|
||||
batch_size_range = [2**i for i in range(0, 8, 2)]
|
||||
kv_len_range = [2**i for i in range(6, 13, 1)]
|
||||
configs = list(itertools.product(batch_size_range, kv_len_range))
|
||||
|
||||
for head_num_q, head_num_kv in [[32, 32], [64, 8], [40, 8]]:
|
||||
attn_flashinfer = decode_attention_flashinfer(
|
||||
dtype, head_num_q, head_num_kv
|
||||
).apply
|
||||
for batch_size, kv_len in configs:
|
||||
q = torch.randn(
|
||||
batch_size, head_num_q, head_dim, dtype=dtype, device="cuda"
|
||||
)
|
||||
kv_data = (
|
||||
torch.randn(
|
||||
batch_size * kv_len,
|
||||
head_num_kv,
|
||||
head_dim,
|
||||
dtype=dtype,
|
||||
device="cuda",
|
||||
),
|
||||
torch.randn(
|
||||
batch_size * kv_len,
|
||||
head_num_kv,
|
||||
head_dim,
|
||||
dtype=dtype,
|
||||
device="cuda",
|
||||
),
|
||||
)
|
||||
us_cudnn, output_cudnn = decode_attention_cudnn(
|
||||
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
||||
)
|
||||
us_sglang, output_sglang = decode_attention_sglang(
|
||||
q,
|
||||
kv_data,
|
||||
batch_size,
|
||||
kv_len,
|
||||
head_num_q,
|
||||
head_num_kv,
|
||||
head_dim,
|
||||
num_kv_splits=8,
|
||||
)
|
||||
us_flashinfer, _ = attn_flashinfer(
|
||||
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
||||
)
|
||||
print(
|
||||
head_num_q,
|
||||
" ",
|
||||
head_num_kv,
|
||||
" ",
|
||||
batch_size,
|
||||
" ",
|
||||
kv_len,
|
||||
" ",
|
||||
us_cudnn,
|
||||
" ",
|
||||
us_sglang,
|
||||
" ",
|
||||
us_flashinfer,
|
||||
)
|
||||
218
third_party/sglang/benchmark/kernels/deepep/deepep_utils.py
vendored
Normal file
218
third_party/sglang/benchmark/kernels/deepep/deepep_utils.py
vendored
Normal file
@@ -0,0 +1,218 @@
|
||||
# ADAPTED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/utils.py
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def init_dist(local_rank: int, num_local_ranks: int, args):
|
||||
ip = args.master_addr
|
||||
port = args.master_port
|
||||
num_nodes = args.nnodes
|
||||
node_rank = args.node_rank
|
||||
assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
|
||||
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
init_method=f"tcp://{ip}:{port}",
|
||||
world_size=num_nodes * num_local_ranks,
|
||||
rank=node_rank * num_local_ranks + local_rank,
|
||||
)
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
return (
|
||||
dist.get_rank(),
|
||||
dist.get_world_size(),
|
||||
dist.new_group(list(range(num_local_ranks * num_nodes))),
|
||||
)
|
||||
|
||||
|
||||
def calc_diff(x: torch.Tensor, y: torch.Tensor):
|
||||
x, y = x.double() + 1, y.double() + 1
|
||||
denominator = (x * x + y * y).sum()
|
||||
sim = 2 * (x * y).sum() / denominator
|
||||
return (1 - sim).item()
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: torch.Tensor):
|
||||
assert x.dim() == 2 and x.size(1) % 128 == 0
|
||||
m, n = x.shape
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
|
||||
m, n
|
||||
), (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
|
||||
x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
|
||||
x_scales = x_scales.view(x_fp8.size(0), -1, 1)
|
||||
return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)
|
||||
|
||||
|
||||
def inplace_unique(x: torch.Tensor, num_slots: int):
|
||||
assert x.dim() == 2
|
||||
mask = x < 0
|
||||
x_padded = x.masked_fill(mask, num_slots)
|
||||
bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
|
||||
bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
|
||||
bin_count = bin_count[:, :num_slots]
|
||||
sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
|
||||
sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
|
||||
sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
|
||||
x[:, :].fill_(-1)
|
||||
valid_len = min(num_slots, x.size(1))
|
||||
x[:, :valid_len] = sorted_bin_idx[:, :valid_len]
|
||||
|
||||
|
||||
def create_grouped_scores(
|
||||
scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int
|
||||
):
|
||||
num_tokens, num_experts = scores.shape
|
||||
scores = scores.view(num_tokens, num_groups, -1)
|
||||
mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
|
||||
mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
|
||||
return (scores * mask).view(num_tokens, num_experts)
|
||||
|
||||
|
||||
def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
|
||||
# Flush L2 cache with 256 MB data
|
||||
torch.cuda.synchronize()
|
||||
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
|
||||
|
||||
# Warmup
|
||||
for _ in range(num_warmups):
|
||||
fn()
|
||||
|
||||
# Flush L2
|
||||
cache.zero_()
|
||||
|
||||
# Testing
|
||||
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
|
||||
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
|
||||
for i in range(num_tests):
|
||||
# Record
|
||||
start_events[i].record()
|
||||
fn()
|
||||
end_events[i].record()
|
||||
if post_fn is not None:
|
||||
post_fn()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
times = np.array(
|
||||
[s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)]
|
||||
)[1:]
|
||||
return np.average(times), np.min(times), np.max(times)
|
||||
|
||||
|
||||
class empty_suppress:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
pass
|
||||
|
||||
|
||||
class suppress_stdout_stderr:
|
||||
def __enter__(self):
|
||||
self.outnull_file = open(os.devnull, "w")
|
||||
self.errnull_file = open(os.devnull, "w")
|
||||
|
||||
self.old_stdout_fileno_undup = sys.stdout.fileno()
|
||||
self.old_stderr_fileno_undup = sys.stderr.fileno()
|
||||
|
||||
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
|
||||
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
|
||||
|
||||
self.old_stdout = sys.stdout
|
||||
self.old_stderr = sys.stderr
|
||||
|
||||
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
|
||||
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
|
||||
|
||||
sys.stdout = self.outnull_file
|
||||
sys.stderr = self.errnull_file
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
sys.stdout = self.old_stdout
|
||||
sys.stderr = self.old_stderr
|
||||
|
||||
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
|
||||
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
|
||||
|
||||
os.close(self.old_stdout_fileno)
|
||||
os.close(self.old_stderr_fileno)
|
||||
|
||||
self.outnull_file.close()
|
||||
self.errnull_file.close()
|
||||
|
||||
|
||||
def bench_kineto(
|
||||
fn,
|
||||
kernel_names,
|
||||
num_tests: int = 30,
|
||||
suppress_kineto_output: bool = False,
|
||||
trace_path: Optional[str] = None,
|
||||
barrier_comm_profiling: bool = False,
|
||||
):
|
||||
# Profile
|
||||
suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
|
||||
with suppress():
|
||||
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
|
||||
with torch.profiler.profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
|
||||
) as prof:
|
||||
for i in range(2):
|
||||
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
|
||||
if barrier_comm_profiling:
|
||||
lhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
|
||||
rhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
|
||||
lhs @ rhs
|
||||
dist.all_reduce(torch.ones(1, dtype=torch.float, device="cuda"))
|
||||
for _ in range(num_tests):
|
||||
fn()
|
||||
prof.step()
|
||||
|
||||
# Parse the profiling table
|
||||
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
|
||||
is_tupled = isinstance(kernel_names, tuple)
|
||||
prof_lines = (
|
||||
prof.key_averages()
|
||||
.table(sort_by="cuda_time_total", max_name_column_width=100)
|
||||
.split("\n")
|
||||
)
|
||||
kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
|
||||
assert all([isinstance(name, str) for name in kernel_names])
|
||||
for name in kernel_names:
|
||||
assert (
|
||||
sum([name in line for line in prof_lines]) == 1
|
||||
), f"Errors of the kernel {name} in the profiling table"
|
||||
|
||||
# Save chrome traces
|
||||
if trace_path is not None:
|
||||
prof.export_chrome_trace(trace_path)
|
||||
|
||||
# Return average kernel times
|
||||
units = {"ms": 1e3, "us": 1e6}
|
||||
kernel_times = []
|
||||
for name in kernel_names:
|
||||
for line in prof_lines:
|
||||
if name in line:
|
||||
time_str = line.split()[-2]
|
||||
for unit, scale in units.items():
|
||||
if unit in time_str:
|
||||
kernel_times.append(float(time_str.replace(unit, "")) / scale)
|
||||
break
|
||||
break
|
||||
return tuple(kernel_times) if is_tupled else kernel_times[0]
|
||||
|
||||
|
||||
def hash_tensor(t: torch.Tensor):
|
||||
return t.view(torch.int64).sum().item()
|
||||
480
third_party/sglang/benchmark/kernels/deepep/tuning_deepep.py
vendored
Normal file
480
third_party/sglang/benchmark/kernels/deepep/tuning_deepep.py
vendored
Normal file
@@ -0,0 +1,480 @@
|
||||
# MODIFIED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py
|
||||
|
||||
"""
|
||||
Example usage:
|
||||
python tuning_deepep.py --nnodes 4 --node-rank $MY_NODE_RANK --master-addr 1.2.3.4
|
||||
Then check `deepep_tuned.json`
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import deep_ep
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from deepep_utils import (
|
||||
bench,
|
||||
calc_diff,
|
||||
create_grouped_scores,
|
||||
init_dist,
|
||||
inplace_unique,
|
||||
per_token_cast_back,
|
||||
per_token_cast_to_fp8,
|
||||
)
|
||||
|
||||
|
||||
def test_main(
|
||||
num_sms: int,
|
||||
local_rank: int,
|
||||
num_local_ranks: int,
|
||||
num_ranks: int,
|
||||
num_nodes: int,
|
||||
rank: int,
|
||||
buffer: deep_ep.Buffer,
|
||||
group: dist.ProcessGroup,
|
||||
args,
|
||||
):
|
||||
# Settings
|
||||
num_tokens, hidden, num_topk_groups, num_topk, num_experts = (
|
||||
args.num_tokens,
|
||||
args.hidden,
|
||||
min(num_nodes, 4),
|
||||
args.num_topk,
|
||||
(args.num_experts // num_ranks) * num_ranks,
|
||||
)
|
||||
assert num_experts % num_ranks == 0 and num_local_ranks == 8
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Random data
|
||||
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank
|
||||
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
|
||||
x_e4m3 = per_token_cast_to_fp8(x)
|
||||
scores = (
|
||||
torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
|
||||
+ 1
|
||||
)
|
||||
group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
|
||||
group_idx = torch.topk(
|
||||
group_scores, k=num_topk_groups, dim=-1, sorted=False
|
||||
).indices
|
||||
masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
|
||||
topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[
|
||||
1
|
||||
]
|
||||
topk_weights = (
|
||||
torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank
|
||||
)
|
||||
topk_weights_pure_rand = torch.randn(
|
||||
(num_tokens, num_topk), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
rank_idx = topk_idx // (num_experts // num_ranks)
|
||||
rank_idx.masked_fill_(topk_idx == -1, -1)
|
||||
inplace_unique(rank_idx, num_ranks)
|
||||
rdma_rank_idx = rank_idx // num_local_ranks
|
||||
rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
|
||||
inplace_unique(rdma_rank_idx, num_nodes)
|
||||
|
||||
# RDMA dispatch counts
|
||||
rdma_idx = topk_idx // (num_experts // num_nodes)
|
||||
rdma_idx.masked_fill_(topk_idx == -1, -1)
|
||||
inplace_unique(rdma_idx, num_nodes)
|
||||
num_rdma_token_sent = rdma_idx.ne(-1).sum().item()
|
||||
|
||||
# Expert meta
|
||||
num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
|
||||
for i in range(num_experts):
|
||||
num_tokens_per_expert[i] = (topk_idx == i).sum()
|
||||
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
|
||||
|
||||
# Rank layout meta
|
||||
num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
|
||||
num_tokens_per_rdma_rank = torch.empty((num_nodes,), dtype=torch.int, device="cuda")
|
||||
token_idx_in_rank = torch.full(
|
||||
(num_ranks, num_tokens), -1, dtype=torch.long, device="cuda"
|
||||
)
|
||||
for i in range(num_ranks):
|
||||
num_tokens_per_rank[i] = (rank_idx == i).sum()
|
||||
token_sel = (rank_idx == i).max(dim=-1)[0]
|
||||
count = token_sel.sum().item()
|
||||
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
|
||||
tokens[:count] = torch.sort(tokens[:count])[0]
|
||||
token_idx_in_rank[i][tokens[:count]] = torch.arange(
|
||||
count, dtype=torch.long, device="cuda"
|
||||
)
|
||||
for i in range(num_nodes):
|
||||
num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
|
||||
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
|
||||
is_token_in_rank = token_idx_in_rank >= 0
|
||||
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
|
||||
|
||||
(
|
||||
ref_num_tokens_per_rank,
|
||||
ref_num_tokens_per_rdma_rank,
|
||||
ref_num_tokens_per_expert,
|
||||
ref_is_token_in_rank,
|
||||
_,
|
||||
) = buffer.get_dispatch_layout(topk_idx, num_experts)
|
||||
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
|
||||
assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
|
||||
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
|
||||
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
|
||||
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
|
||||
if local_rank == 0:
|
||||
print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True)
|
||||
print("", flush=True)
|
||||
group.barrier()
|
||||
time.sleep(1)
|
||||
|
||||
# Config
|
||||
rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
|
||||
config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)
|
||||
|
||||
# Test dispatch
|
||||
# noinspection PyShadowingNames
|
||||
def check_data(check_x, recv_gbl_rank_prefix_sum):
|
||||
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
|
||||
check_start = 0
|
||||
for i in range(num_ranks):
|
||||
check_end = recv_gbl_rank_prefix_sum[i].item()
|
||||
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
|
||||
check_start = check_end
|
||||
|
||||
for previous_mode in (False, True):
|
||||
for async_mode in (False, True):
|
||||
for current_x in (x_pure_rand, x, x_e4m3):
|
||||
for with_topk in (False, True):
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...',
|
||||
flush=True,
|
||||
end="",
|
||||
)
|
||||
dispatch_args = {
|
||||
"x": current_x,
|
||||
"num_tokens_per_rank": num_tokens_per_rank,
|
||||
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
|
||||
"is_token_in_rank": is_token_in_rank,
|
||||
"num_tokens_per_expert": num_tokens_per_expert,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if with_topk:
|
||||
dispatch_args.update(
|
||||
{
|
||||
"topk_idx": topk_idx,
|
||||
"topk_weights": (
|
||||
topk_weights_pure_rand
|
||||
if current_x is x_pure_rand
|
||||
else topk_weights
|
||||
),
|
||||
}
|
||||
)
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
(
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
recv_num_tokens_per_expert_list,
|
||||
handle,
|
||||
event,
|
||||
) = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = (
|
||||
per_token_cast_back(*recv_x)
|
||||
if isinstance(recv_x, tuple)
|
||||
else recv_x
|
||||
)
|
||||
|
||||
# Checks
|
||||
recv_gbl_rank_prefix_sum = handle[-4]
|
||||
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(
|
||||
0
|
||||
), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}"
|
||||
assert (
|
||||
gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()
|
||||
== recv_num_tokens_per_expert_list
|
||||
)
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, recv_gbl_rank_prefix_sum)
|
||||
if with_topk:
|
||||
# Check `topk_idx`
|
||||
assert (
|
||||
recv_topk_idx.eq(-1)
|
||||
| (
|
||||
(recv_topk_idx >= 0)
|
||||
& (recv_topk_idx < (num_experts // num_ranks))
|
||||
)
|
||||
).sum().item() == recv_topk_idx.numel()
|
||||
for i, count in enumerate(recv_num_tokens_per_expert_list):
|
||||
assert recv_topk_idx.eq(i).sum().item() == count
|
||||
|
||||
# Check `topk_weights`
|
||||
if current_x is not x_pure_rand:
|
||||
recv_topk_weights[recv_topk_idx.eq(-1)] = (
|
||||
recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
|
||||
recv_topk_weights
|
||||
)[recv_topk_idx.eq(-1)]
|
||||
)
|
||||
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
|
||||
|
||||
# Test cached dispatch (must without top-k staffs)
|
||||
if not with_topk:
|
||||
dispatch_args = {
|
||||
"x": current_x,
|
||||
"handle": handle,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = (
|
||||
per_token_cast_back(*recv_x)
|
||||
if isinstance(recv_x, tuple)
|
||||
else recv_x
|
||||
)
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, recv_gbl_rank_prefix_sum)
|
||||
|
||||
# Test combine
|
||||
combine_args = {
|
||||
"x": recv_x,
|
||||
"handle": handle,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if with_topk:
|
||||
combine_args.update({"topk_weights": recv_topk_weights})
|
||||
if previous_mode:
|
||||
combine_args.update({"previous_event": buffer.capture()})
|
||||
combined_x, combined_topk_weights, event = buffer.combine(
|
||||
**combine_args
|
||||
)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
check_x = combined_x.float() / is_token_in_rank.sum(
|
||||
dim=1
|
||||
).unsqueeze(1)
|
||||
ref_x = x_pure_rand if current_x is x_pure_rand else x
|
||||
assert calc_diff(check_x, ref_x) < 5e-6
|
||||
if with_topk:
|
||||
check_topk_weights = (
|
||||
combined_topk_weights
|
||||
if (current_x is x_pure_rand)
|
||||
else (
|
||||
combined_topk_weights
|
||||
/ is_token_in_rank.sum(dim=1).unsqueeze(1)
|
||||
)
|
||||
)
|
||||
ref_topk_weights = (
|
||||
topk_weights_pure_rand
|
||||
if current_x is x_pure_rand
|
||||
else topk_weights
|
||||
)
|
||||
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
|
||||
|
||||
# For later tuning
|
||||
dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
|
||||
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
|
||||
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
|
||||
combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes
|
||||
|
||||
if local_rank == 0:
|
||||
print(" passed", flush=True)
|
||||
if local_rank == 0:
|
||||
print("", flush=True)
|
||||
|
||||
output_data = {}
|
||||
|
||||
# Tune dispatch performance
|
||||
best_dispatch_results = None
|
||||
fp8_factor = (1 + 4 / 128) / 2
|
||||
for current_x in (x_e4m3, x):
|
||||
best_time, best_results = 1e10, None
|
||||
rdma_send_bytes = (
|
||||
(dispatch_bf16_rdma_send_bytes * fp8_factor)
|
||||
if isinstance(current_x, tuple)
|
||||
else dispatch_bf16_rdma_send_bytes
|
||||
)
|
||||
nvl_recv_bytes = (
|
||||
(dispatch_bf16_nvl_recv_bytes * fp8_factor)
|
||||
if isinstance(current_x, tuple)
|
||||
else dispatch_bf16_nvl_recv_bytes
|
||||
)
|
||||
for nvl_chunk_size in range(4, 33, 4):
|
||||
for rdma_chunk_size in range(4, 33, 4):
|
||||
config_kwargs = {
|
||||
"num_sms": num_sms,
|
||||
"num_max_nvl_chunked_send_tokens": nvl_chunk_size,
|
||||
"num_max_nvl_chunked_recv_tokens": nvl_buffer_size,
|
||||
"num_max_rdma_chunked_send_tokens": rdma_chunk_size,
|
||||
"num_max_rdma_chunked_recv_tokens": rdma_buffer_size,
|
||||
}
|
||||
config = deep_ep.Config(**config_kwargs)
|
||||
tune_args = {"x": current_x, "handle": handle, "config": config}
|
||||
t = bench(lambda: buffer.dispatch(**tune_args))[0]
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (
|
||||
num_sms,
|
||||
nvl_chunk_size,
|
||||
rdma_chunk_size,
|
||||
config_kwargs,
|
||||
)
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ",
|
||||
flush=True,
|
||||
)
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',
|
||||
flush=True,
|
||||
)
|
||||
print("", flush=True)
|
||||
is_fp8 = isinstance(current_x, tuple)
|
||||
if is_fp8:
|
||||
output_data["normal_dispatch"] = deepcopy(best_results[3])
|
||||
|
||||
if isinstance(current_x, tuple):
|
||||
# Gather FP8 the best config from rank 0
|
||||
best_dispatch_results = torch.tensor(
|
||||
[best_results[0], best_results[1], best_results[2]],
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
all_best_fp8_results_list = [
|
||||
torch.zeros_like(best_dispatch_results)
|
||||
for _ in range(torch.distributed.get_world_size())
|
||||
]
|
||||
dist.all_gather(
|
||||
all_best_fp8_results_list, best_dispatch_results, group=group
|
||||
)
|
||||
best_dispatch_results = all_best_fp8_results_list[0].tolist()
|
||||
dispatch_config = deep_ep.Config(
|
||||
best_dispatch_results[0],
|
||||
best_dispatch_results[1],
|
||||
nvl_buffer_size,
|
||||
best_dispatch_results[2],
|
||||
rdma_buffer_size,
|
||||
)
|
||||
|
||||
dispatch_args = {
|
||||
"x": x,
|
||||
"num_tokens_per_rank": num_tokens_per_rank,
|
||||
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
|
||||
"is_token_in_rank": is_token_in_rank,
|
||||
"num_tokens_per_expert": num_tokens_per_expert,
|
||||
"config": dispatch_config if dispatch_config is not None else config,
|
||||
}
|
||||
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
|
||||
|
||||
# Tune combine performance
|
||||
best_time, best_results = 1e10, None
|
||||
for nvl_chunk_size in range(1, 8, 1):
|
||||
for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4):
|
||||
config_kwargs = {
|
||||
"num_sms": num_sms,
|
||||
"num_max_nvl_chunked_send_tokens": nvl_chunk_size,
|
||||
"num_max_nvl_chunked_recv_tokens": nvl_buffer_size,
|
||||
"num_max_rdma_chunked_send_tokens": rdma_chunk_size,
|
||||
"num_max_rdma_chunked_recv_tokens": rdma_buffer_size,
|
||||
}
|
||||
config = deep_ep.Config(**config_kwargs)
|
||||
tune_args = {"x": recv_x, "handle": handle, "config": config}
|
||||
t = bench(lambda: buffer.combine(**tune_args))[0]
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ",
|
||||
flush=True,
|
||||
)
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (
|
||||
num_sms,
|
||||
nvl_chunk_size,
|
||||
rdma_chunk_size,
|
||||
config_kwargs,
|
||||
)
|
||||
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
|
||||
flush=True,
|
||||
)
|
||||
print("", flush=True)
|
||||
output_data["normal_combine"] = deepcopy(best_results[3])
|
||||
|
||||
if rank == 0 and local_rank == 0:
|
||||
_write_output(args, output_data)
|
||||
|
||||
|
||||
def _write_output(args, output_data):
|
||||
text = json.dumps(output_data, indent=4)
|
||||
output_path = args.output_path
|
||||
print(f"Write to {output_path} with {text}")
|
||||
Path(output_path).write_text(text)
|
||||
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
def test_loop(local_rank: int, num_local_ranks: int, args):
|
||||
num_nodes = args.nnodes
|
||||
rank, num_ranks, group = init_dist(local_rank, num_local_ranks, args)
|
||||
|
||||
num_sms = args.num_sms
|
||||
num_qps_per_rank = num_sms // 2
|
||||
|
||||
buffer = deep_ep.Buffer(
|
||||
group,
|
||||
int(1e9),
|
||||
int(1e9),
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
)
|
||||
assert num_local_ranks == 8 and num_ranks > 8
|
||||
torch.manual_seed(rank)
|
||||
|
||||
for i in (num_sms,):
|
||||
test_main(
|
||||
i,
|
||||
local_rank,
|
||||
num_local_ranks,
|
||||
num_ranks,
|
||||
num_nodes,
|
||||
rank,
|
||||
buffer,
|
||||
group,
|
||||
args,
|
||||
)
|
||||
if local_rank == 0:
|
||||
print("", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-sms", type=int, default=24)
|
||||
parser.add_argument("--num-tokens", type=int, default=4096)
|
||||
parser.add_argument("--hidden", type=int, default=7168)
|
||||
parser.add_argument("--num-topk", type=int, default=8)
|
||||
parser.add_argument("--num-experts", type=int, default=256)
|
||||
parser.add_argument("--output-path", type=str, default="deepep_tuned.json")
|
||||
parser.add_argument("--nnodes", type=int, default=1)
|
||||
parser.add_argument("--node-rank", type=int, default=0)
|
||||
parser.add_argument("--master-addr", type=str, default="127.0.0.1")
|
||||
parser.add_argument("--master-port", type=int, default=8361)
|
||||
args = parser.parse_args()
|
||||
print(f"Start system with {args=}")
|
||||
|
||||
num_processes = 8
|
||||
torch.multiprocessing.spawn(
|
||||
test_loop, args=(num_processes, args), nprocs=num_processes
|
||||
)
|
||||
19
third_party/sglang/benchmark/kernels/deepseek/README.md
vendored
Normal file
19
third_party/sglang/benchmark/kernels/deepseek/README.md
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
## DeepSeek kernels benchmark
|
||||
|
||||
|
||||
### Prerequisites
|
||||
- You should install [DeepGemm](https://github.com/deepseek-ai/DeepGEMM) from source before run `benchmark_deepgemm_fp8_gemm.py` and `benchmark_deepgemm_fp8_group_gemm.py`.
|
||||
|
||||
### Benchmark
|
||||
- `benchmark_deepgemm_fp8_gemm.py`
|
||||
```bash
|
||||
python benchmark_deepgemm_fp8_gemm.py --run_correctness --tp_size 1
|
||||
```
|
||||
|
||||
- `benchmark_deepgemm_fp8_group_gemm.py`
|
||||
```bash
|
||||
python benchmark_deepgemm_fp8_group_gemm.py --run_correctness --tp_size 1
|
||||
```
|
||||
|
||||
- You can use the `--run_correctness` parameter to verify all kernels results's correctness.
|
||||
- You can use the `--tp_size` parameter to benchmark all FP8 w8a8 block-wise matrix multiplications involved in DeepSeek V3/R1 under the current tensor parallelism (TP) setting. This benchmark compares DeepSeek's open-source [DeepGemm](https://github.com/deepseek-ai/DeepGEMM) implementation with SGLang's and VLLM Triton implementation.
|
||||
250
third_party/sglang/benchmark/kernels/deepseek/benchmark_deepgemm_dsv3_router_gemm_blackwell.py
vendored
Normal file
250
third_party/sglang/benchmark/kernels/deepseek/benchmark_deepgemm_dsv3_router_gemm_blackwell.py
vendored
Normal file
@@ -0,0 +1,250 @@
|
||||
import argparse
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from flashinfer.gemm import mm_M1_16_K7168_N256
|
||||
from sgl_kernel import dsv3_router_gemm
|
||||
|
||||
N = 256
|
||||
K = 7168
|
||||
|
||||
|
||||
def create_benchmark_configs(tp_sizes: List[int]):
|
||||
configs = []
|
||||
for tp_size in tp_sizes:
|
||||
for m in range(1, 17):
|
||||
configs.append((m, N, K, tp_size))
|
||||
return configs
|
||||
|
||||
|
||||
def dsv3_router_gemm_flashinfer(
|
||||
hidden_states: torch.Tensor,
|
||||
router_weights: torch.Tensor,
|
||||
):
|
||||
"""Flashinfer implementation of dsv3 router gemm"""
|
||||
output = torch.empty(
|
||||
hidden_states.shape[0],
|
||||
router_weights.shape[0],
|
||||
device="cuda",
|
||||
dtype=torch.float32,
|
||||
)
|
||||
mm_M1_16_K7168_N256(
|
||||
hidden_states, router_weights.t(), output, launch_with_pdl=args.use_pdl
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def dsv3_router_gemm_sgl(
|
||||
hidden_states: torch.Tensor,
|
||||
router_weights: torch.Tensor,
|
||||
):
|
||||
"""SGLang implementation of dsv3 router gemm"""
|
||||
output = dsv3_router_gemm(
|
||||
hidden_states,
|
||||
router_weights,
|
||||
out_dtype=torch.float32,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def check_accuracy(a, b, atol, rtol, percent):
|
||||
"""Unified accuracy checking function with detailed error reporting."""
|
||||
if not torch.isfinite(a).all():
|
||||
print("Non-finite values in reference output")
|
||||
return False
|
||||
if not torch.isfinite(b).all():
|
||||
print("Non-finite values in actual output")
|
||||
return False
|
||||
assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"
|
||||
|
||||
close = torch.isclose(a, b, atol=atol, rtol=rtol)
|
||||
match_ratio = close.float().mean()
|
||||
if match_ratio >= percent:
|
||||
return True
|
||||
|
||||
mismatch_percent = 1.0 - match_ratio.item()
|
||||
if mismatch_percent > 1 - percent:
|
||||
print(
|
||||
f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
|
||||
f"(threshold: {1 - percent:.4f})"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def calculate_diff(m: int, n: int, k: int):
|
||||
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||
router_weights = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
out_flashinfer = dsv3_router_gemm_flashinfer(
|
||||
hidden_states.clone(memory_format=torch.contiguous_format),
|
||||
router_weights.clone(memory_format=torch.contiguous_format),
|
||||
)
|
||||
|
||||
out_sgl = dsv3_router_gemm_sgl(
|
||||
hidden_states.clone(memory_format=torch.contiguous_format),
|
||||
router_weights.clone(memory_format=torch.contiguous_format),
|
||||
)
|
||||
|
||||
print(f"Shape m={m}, n={n}, k={k}:")
|
||||
print(f"Using PDL={args.use_pdl}")
|
||||
print(f"Flashinfer output: {out_flashinfer[0, 0:5]}")
|
||||
print(f"SGLang output: {out_sgl[0, 0:5]}")
|
||||
|
||||
flashinfer_sgl_match = check_accuracy(out_flashinfer, out_sgl, 0.1, 0.6, 0.95)
|
||||
print("Correctness check:")
|
||||
print(f" - Flashinfer vs SGLang: {'✅' if flashinfer_sgl_match else '❌'}")
|
||||
|
||||
|
||||
def _benchmark(m, n, k, tp_size, provider):
|
||||
print(f"Shape (m={m}, n={n}, k={k}, tp={tp_size}), Provider: {provider}")
|
||||
hidden_states = torch.randn(
|
||||
(m, k), device="cuda", dtype=torch.bfloat16
|
||||
).contiguous()
|
||||
router_weights = torch.randn(
|
||||
(n, k), device="cuda", dtype=torch.bfloat16
|
||||
).contiguous()
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "sglang":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: dsv3_router_gemm_sgl(
|
||||
hidden_states.clone(memory_format=torch.contiguous_format),
|
||||
router_weights.clone(memory_format=torch.contiguous_format),
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "flashinfer":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: dsv3_router_gemm_flashinfer(
|
||||
hidden_states.clone(memory_format=torch.contiguous_format),
|
||||
router_weights.clone(memory_format=torch.contiguous_format),
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
# Calculate TFLOPS
|
||||
flops = 2 * m * n * k # multiply-adds
|
||||
tflops = flops / (ms * 1e-3) / 1e12
|
||||
|
||||
# Print shape-specific results with TFLOPS
|
||||
print(f"Time: {ms*1000:.2f} us, TFLOPS: {tflops:.2f}")
|
||||
return ms, max_ms, min_ms
|
||||
|
||||
|
||||
def get_benchmark_plot_friendly(tp_sizes):
|
||||
all_configs = create_benchmark_configs(tp_sizes)
|
||||
x_vals = list(range(len(all_configs)))
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["cfg_id"],
|
||||
x_vals=x_vals,
|
||||
line_arg="provider",
|
||||
line_vals=["sglang", "flashinfer"],
|
||||
line_names=["SGLang", "Flashinfer"],
|
||||
styles=[("blue", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name=f"fp8-gemm-performance-comparison-tp-{"-".join(str(tp) for tp in tp_sizes)}",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(cfg_id, provider):
|
||||
m, n, k, tp_size = all_configs[cfg_id]
|
||||
ms, min_ms, max_ms = _benchmark(m, n, k, tp_size, provider)
|
||||
return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
def get_benchmark(tp_sizes):
|
||||
all_configs = create_benchmark_configs(tp_sizes)
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=[
|
||||
"m",
|
||||
"n",
|
||||
"k",
|
||||
"tp_size",
|
||||
],
|
||||
x_vals=[list(config) for config in all_configs],
|
||||
line_arg="provider",
|
||||
line_vals=["sglang", "flashinfer"],
|
||||
line_names=["SGLang", "Flashinfer"],
|
||||
styles=[("blue", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name=f"fp8-gemm-performance-comparison-tp-{"-".join(str(tp) for tp in tp_sizes)}",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(m, n, k, tp_size, provider):
|
||||
ms, min_ms, max_ms = _benchmark(m, n, k, tp_size, provider)
|
||||
return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] != 10:
|
||||
print("Skipping benchmark because the device is not supported")
|
||||
exit(0)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/dsv3_router_gemm/",
|
||||
help="Path to save dsv3 router gemm benchmark results",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run-correctness",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Whether to run correctness test",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp-sizes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[1],
|
||||
help="List of tensor parallelism sizes to benchmark",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plot-friendly",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Plot x axis as the config index instead of the m",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-pdl",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use PDL if true.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
if args.use_pdl:
|
||||
os.environ["TRTLLM_ENABLE_PDL"] = "1"
|
||||
|
||||
# Run correctness tests on a few examples
|
||||
if args.run_correctness:
|
||||
print("Running correctness tests...")
|
||||
for m, n, k, _ in create_benchmark_configs(args.tp_sizes):
|
||||
calculate_diff(m, n, k)
|
||||
|
||||
# Get the benchmark function with the specified tp_size
|
||||
benchmark = (
|
||||
get_benchmark_plot_friendly(args.tp_sizes)
|
||||
if args.plot_friendly
|
||||
else get_benchmark(args.tp_sizes)
|
||||
)
|
||||
|
||||
print(f"Running performance benchmark for TP sizes = {args.tp_sizes}...")
|
||||
benchmark.run(print_data=True, save_path=args.save_path)
|
||||
402
third_party/sglang/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py
vendored
Normal file
402
third_party/sglang/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py
vendored
Normal file
@@ -0,0 +1,402 @@
|
||||
from typing import Tuple
|
||||
|
||||
import deep_gemm
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
import torch
|
||||
import triton
|
||||
from deep_gemm import ceil_div
|
||||
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
|
||||
)
|
||||
|
||||
from sglang.benchmark.bench_utils import run_bench
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
w8a8_block_fp8_matmul_deepgemm as w8a8_block_fp8_matmul,
|
||||
)
|
||||
|
||||
|
||||
# Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
|
||||
def tl_gemm(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
in_dtype,
|
||||
out_dtype,
|
||||
accum_dtype,
|
||||
):
|
||||
assert in_dtype in [
|
||||
"e4m3_float8",
|
||||
], "Currently only e4m3_float8 is supported"
|
||||
assert out_dtype in [
|
||||
"bfloat16",
|
||||
"float16",
|
||||
], "Currently only bfloat16 and float16 are supported"
|
||||
|
||||
TILE_SIZE = (128, 128, 128)
|
||||
block_M = TILE_SIZE[0]
|
||||
block_N = TILE_SIZE[1]
|
||||
block_K = TILE_SIZE[2]
|
||||
|
||||
A_shape = (M, K)
|
||||
Scales_A_shape = (M, T.ceildiv(K, block_K))
|
||||
B_shape = (N, K)
|
||||
Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K))
|
||||
A_shared_shape = (block_M, block_K)
|
||||
B_shared_shape = (block_N, block_K)
|
||||
C_shared_shape = (block_M, block_N)
|
||||
|
||||
@T.prim_func
|
||||
def main(
|
||||
A: T.Buffer(A_shape, in_dtype),
|
||||
scales_a: T.Buffer(Scales_A_shape, "float32"),
|
||||
B: T.Buffer(B_shape, in_dtype),
|
||||
scales_b: T.Buffer(Scales_B_shape, "float32"),
|
||||
C: T.Buffer((M, N), out_dtype),
|
||||
):
|
||||
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
|
||||
bx,
|
||||
by,
|
||||
):
|
||||
|
||||
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
|
||||
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
|
||||
C_shared = T.alloc_shared(C_shared_shape, out_dtype)
|
||||
Scale_C_shared = T.alloc_shared((block_M), "float32")
|
||||
C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
|
||||
C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)
|
||||
|
||||
# Improve L2 Cache
|
||||
T.use_swizzle(panel_size=10)
|
||||
|
||||
T.clear(C_local)
|
||||
T.clear(C_local_accum)
|
||||
K_iters = T.ceildiv(K, block_K)
|
||||
for k in T.Pipelined(K_iters, num_stages=4):
|
||||
# Load A into shared memory
|
||||
T.copy(A[by * block_M, k * block_K], A_shared)
|
||||
# Load B into shared memory
|
||||
T.copy(B[bx * block_N, k * block_K], B_shared)
|
||||
# Load scale into shared memory
|
||||
Scale_B = scales_b[bx, k]
|
||||
for i in T.Parallel(block_M):
|
||||
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
|
||||
|
||||
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
|
||||
# Promote to enable 2xAcc
|
||||
for i, j in T.Parallel(block_M, block_N):
|
||||
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
|
||||
T.clear(C_local)
|
||||
# TMA store
|
||||
T.copy(C_local_accum, C_shared)
|
||||
T.copy(C_shared, C[by * block_M, bx * block_N])
|
||||
|
||||
return main
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2 and x.size(1) % 128 == 0
|
||||
m, n = x.shape
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
|
||||
m, n
|
||||
), (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros(
|
||||
(ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
|
||||
)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
|
||||
x_view.size(0), x_view.size(2)
|
||||
)
|
||||
|
||||
|
||||
def fp8_gemm_deepgemm(
|
||||
x_fp8: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
y_fp8: torch.Tensor,
|
||||
y_scale: torch.Tensor,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
):
|
||||
"""DeepGEMM implementation of FP8 GEMM"""
|
||||
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Run DeepGEMM kernel
|
||||
deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
|
||||
return out
|
||||
|
||||
|
||||
def fp8_gemm_sglang(
|
||||
x_fp8: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
y_fp8: torch.Tensor,
|
||||
y_scale: torch.Tensor,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
):
|
||||
"""SGLang implementation of FP8 GEMM"""
|
||||
block_size = [128, 128] # Matches the block size in per_block_cast_to_fp8
|
||||
|
||||
# Run SGLang kernel
|
||||
out = w8a8_block_fp8_matmul(
|
||||
x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def fp8_gemm_vllm(
|
||||
x_fp8: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
y_fp8: torch.Tensor,
|
||||
y_scale: torch.Tensor,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
):
|
||||
"""vLLM implementation of FP8 GEMM"""
|
||||
block_size = [128, 128] # Matches the block size in per_block_cast_to_fp8
|
||||
|
||||
# Run vLLM kernel
|
||||
out = vllm_w8a8_block_fp8_matmul(
|
||||
x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def calculate_diff(m: int, n: int, k: int):
|
||||
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
|
||||
y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
|
||||
x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
|
||||
|
||||
out_deepgemm = fp8_gemm_deepgemm(
|
||||
x_fp8.clone(),
|
||||
x_scale_col_major.clone(),
|
||||
y_fp8.clone(),
|
||||
y_scale.clone(),
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
)
|
||||
out_sglang = fp8_gemm_sglang(
|
||||
x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone(), m, n, k
|
||||
)
|
||||
|
||||
tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
|
||||
tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
|
||||
out_tilelang = tilelang_kernel(
|
||||
x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone()
|
||||
)
|
||||
|
||||
diff_sglang_deepgemm = torch.abs(out_deepgemm - out_sglang).mean().item()
|
||||
diff_tilelang_deepgemm = torch.abs(out_deepgemm - out_tilelang).mean().item()
|
||||
diff_tilelang_sglang = torch.abs(out_tilelang - out_sglang).mean().item()
|
||||
|
||||
print(f"Shape m={m}, n={n}, k={k}:")
|
||||
print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
|
||||
print(f"SGLang output: {out_sglang[0, 0:5]}")
|
||||
print(f"TileLang output: {out_tilelang[0, 0:5]}")
|
||||
print(f"Mean absolute difference (SGLang-DeepGEMM): {diff_sglang_deepgemm}")
|
||||
print(f"Mean absolute difference (TileLang-DeepGEMM): {diff_tilelang_deepgemm}")
|
||||
print(f"Mean absolute difference (TileLang-SGLang): {diff_tilelang_sglang}")
|
||||
|
||||
sglang_deepgemm_match = torch.allclose(
|
||||
out_deepgemm, out_sglang, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
tilelang_deepgemm_match = torch.allclose(
|
||||
out_deepgemm, out_tilelang, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
tilelang_sglang_match = torch.allclose(
|
||||
out_tilelang, out_sglang, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
|
||||
if sglang_deepgemm_match and tilelang_deepgemm_match and tilelang_sglang_match:
|
||||
print("✅ All implementations match\n")
|
||||
else:
|
||||
print("❌ Some implementations differ:")
|
||||
print(f" - SGLang vs DeepGEMM: {'✅' if sglang_deepgemm_match else '❌'}")
|
||||
print(f" - TileLang vs DeepGEMM: {'✅' if tilelang_deepgemm_match else '❌'}")
|
||||
print(f" - TileLang vs SGLang: {'✅' if tilelang_sglang_match else '❌'}\n")
|
||||
|
||||
|
||||
def get_weight_shapes(tp_size):
|
||||
# cannot TP
|
||||
total = [
|
||||
(512 + 64, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(7168, 16384),
|
||||
(7168, 18432),
|
||||
]
|
||||
# N can TP
|
||||
n_tp = [
|
||||
(18432 * 2, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(24576, 1536),
|
||||
(4096, 7168),
|
||||
]
|
||||
# K can TP
|
||||
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
|
||||
|
||||
weight_shapes = []
|
||||
for t in total:
|
||||
weight_shapes.append(t)
|
||||
for n_t in n_tp:
|
||||
new_t = (n_t[0] // tp_size, n_t[1])
|
||||
weight_shapes.append(new_t)
|
||||
for k_t in k_tp:
|
||||
new_t = (k_t[0], k_t[1] // tp_size)
|
||||
weight_shapes.append(new_t)
|
||||
|
||||
return weight_shapes
|
||||
|
||||
|
||||
def create_benchmark_configs(tp_size):
|
||||
configs = []
|
||||
weight_shapes = get_weight_shapes(tp_size)
|
||||
batch_sizes = [8, 16, 32, 64, 128, 256, 1024, 2048, 4096]
|
||||
|
||||
for n, k in weight_shapes:
|
||||
for m in batch_sizes:
|
||||
configs.append((m, n, k, tp_size))
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def get_benchmark(tp_size):
|
||||
all_configs = create_benchmark_configs(tp_size)
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["m", "n", "k", "tp_size"],
|
||||
x_vals=[list(config) for config in all_configs],
|
||||
line_arg="provider",
|
||||
line_vals=["deepgemm", "sglang", "tilelang"],
|
||||
line_names=["DeepGEMM", "SGLang", "TileLang"],
|
||||
styles=[("blue", "-"), ("red", "-"), ("green", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"fp8-gemm-performance-comparison-tp{tp_size}",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(m, n, k, tp_size, provider):
|
||||
print(f"Shape (m={m}, n={n}, k={k}, tp={tp_size}), Provider: {provider}")
|
||||
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Preprocess data before benchmarking
|
||||
x_fp8, x_scale = per_token_cast_to_fp8(x)
|
||||
y_fp8, y_scale = per_block_cast_to_fp8(y)
|
||||
x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
|
||||
|
||||
quantiles = (0.5, 0.2, 0.8)
|
||||
|
||||
if provider == "deepgemm":
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: fp8_gemm_deepgemm(
|
||||
x_fp8.clone(),
|
||||
x_scale_col_major.clone(),
|
||||
y_fp8.clone(),
|
||||
y_scale.clone(),
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "sglang":
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: fp8_gemm_sglang(
|
||||
x_fp8.clone(),
|
||||
x_scale.clone(),
|
||||
y_fp8.clone(),
|
||||
y_scale.clone(),
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
else: # tilelang
|
||||
tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
|
||||
tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: tilelang_kernel(
|
||||
x_fp8.clone(),
|
||||
x_scale.clone(),
|
||||
y_fp8.clone(),
|
||||
y_scale.clone(),
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
# Calculate TFLOPS
|
||||
flops = 2 * m * n * k # multiply-adds
|
||||
tflops = flops / (ms * 1e-3) / 1e12
|
||||
|
||||
# Print shape-specific results with TFLOPS
|
||||
print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}")
|
||||
return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/fp8_gemm/",
|
||||
help="Path to save fp8 gemm benchmark results",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_correctness",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Whether to run correctness test",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor parallelism size to benchmark (default: 1)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
# Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# Run correctness tests on a few examples
|
||||
if args.run_correctness:
|
||||
print("Running correctness tests...")
|
||||
calculate_diff(64, 512, 7168) # Small test
|
||||
calculate_diff(64, 7168, 16384) # Medium test
|
||||
calculate_diff(64, 18432, 7168) # Large test
|
||||
|
||||
# Get the benchmark function with the specified tp_size
|
||||
benchmark = get_benchmark(args.tp_size)
|
||||
|
||||
print(f"Running performance benchmark for TP size = {args.tp_size}...")
|
||||
benchmark.run(print_data=True, save_path=args.save_path)
|
||||
330
third_party/sglang/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm_blackwell.py
vendored
Normal file
330
third_party/sglang/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm_blackwell.py
vendored
Normal file
@@ -0,0 +1,330 @@
|
||||
import argparse
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from deep_gemm import ceil_div
|
||||
from flashinfer.gemm import gemm_fp8_nt_groupwise
|
||||
|
||||
from sglang.benchmark.bench_utils import run_bench
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
sglang_per_token_group_quant_fp8,
|
||||
w8a8_block_fp8_matmul_deepgemm,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import requant_weight_ue8m0
|
||||
|
||||
BLOCK_SIZE = 128
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
assert BLOCK_SIZE == 128
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros(
|
||||
(ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
|
||||
)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
|
||||
x_view.size(0), x_view.size(2)
|
||||
)
|
||||
|
||||
|
||||
def get_weight_shapes(tp_size):
|
||||
# cannot TP
|
||||
total = [
|
||||
(512 + 64, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(7168, 16384),
|
||||
(7168, 18432),
|
||||
]
|
||||
# N can TP
|
||||
n_tp = [
|
||||
(18432 * 2, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(24576, 1536),
|
||||
(4096, 7168),
|
||||
]
|
||||
# K can TP
|
||||
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
|
||||
|
||||
weight_shapes = []
|
||||
for t in total:
|
||||
weight_shapes.append(t)
|
||||
for n_t in n_tp:
|
||||
new_t = (n_t[0] // tp_size, n_t[1])
|
||||
weight_shapes.append(new_t)
|
||||
for k_t in k_tp:
|
||||
new_t = (k_t[0], k_t[1] // tp_size)
|
||||
weight_shapes.append(new_t)
|
||||
|
||||
return weight_shapes
|
||||
|
||||
|
||||
def create_benchmark_configs(tp_size):
|
||||
configs = []
|
||||
weight_shapes = get_weight_shapes(tp_size)
|
||||
batch_sizes = [8, 16, 32, 64, 128, 256, 1024, 2048, 4096]
|
||||
|
||||
for n, k in weight_shapes:
|
||||
for m in batch_sizes:
|
||||
configs.append((m, n, k, tp_size))
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def fp8_gemm_flashinfer(
|
||||
x_fp8: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
y_fp8: torch.Tensor,
|
||||
y_scale: torch.Tensor,
|
||||
):
|
||||
"""Flashinfer implementation of FP8 GEMM"""
|
||||
output = gemm_fp8_nt_groupwise(
|
||||
x_fp8,
|
||||
y_fp8,
|
||||
x_scale,
|
||||
y_scale,
|
||||
out_dtype=torch.bfloat16,
|
||||
backend="trtllm",
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def fp8_gemm_deepgemm_blackwell(
|
||||
x_fp8: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
y_fp8: torch.Tensor,
|
||||
y_scale: torch.Tensor,
|
||||
):
|
||||
"""DeepGEMM implementation of FP8 GEMM"""
|
||||
block_size = [BLOCK_SIZE, BLOCK_SIZE]
|
||||
output = w8a8_block_fp8_matmul_deepgemm(
|
||||
x_fp8, y_fp8, x_scale, y_scale, block_size, output_dtype=torch.bfloat16
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def check_accuracy(a, b, atol, rtol, percent):
|
||||
"""Unified accuracy checking function with detailed error reporting."""
|
||||
if not torch.isfinite(a).all():
|
||||
print("Non-finite values in reference output")
|
||||
return False
|
||||
if not torch.isfinite(b).all():
|
||||
print("Non-finite values in actual output")
|
||||
return False
|
||||
assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"
|
||||
|
||||
close = torch.isclose(a, b, atol=atol, rtol=rtol)
|
||||
match_ratio = close.float().mean()
|
||||
if match_ratio >= percent:
|
||||
return True
|
||||
|
||||
mismatch_percent = 1.0 - match_ratio.item()
|
||||
if mismatch_percent > 1 - percent:
|
||||
print(
|
||||
f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
|
||||
f"(threshold: {1 - percent:.4f})"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def calculate_diff(m: int, n: int, k: int):
|
||||
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
y_fp8, y_scale = per_block_cast_to_fp8(y)
|
||||
x_fp8, x_scale = sglang_per_token_group_quant_fp8(
|
||||
x, BLOCK_SIZE, column_major_scales=True
|
||||
)
|
||||
out_flashinfer = fp8_gemm_flashinfer(
|
||||
x_fp8,
|
||||
x_scale,
|
||||
y_fp8,
|
||||
y_scale,
|
||||
)
|
||||
|
||||
dg_x_fp8, dg_x_scale = sglang_per_token_group_quant_fp8(
|
||||
x,
|
||||
BLOCK_SIZE,
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
)
|
||||
# We can directly quantize y here, but to mimic the behavior of the actual
|
||||
# implementations, we requant it here.
|
||||
dg_y_fp8, dg_y_scale = requant_weight_ue8m0(
|
||||
y_fp8, y_scale, [BLOCK_SIZE, BLOCK_SIZE]
|
||||
)
|
||||
out_deepgemm = fp8_gemm_deepgemm_blackwell(
|
||||
dg_x_fp8, dg_x_scale, dg_y_fp8, dg_y_scale
|
||||
)
|
||||
|
||||
print(f"Shape m={m}, n={n}, k={k}:")
|
||||
print(f"Flashinfer output: {out_flashinfer[0, 0:5]}")
|
||||
print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
|
||||
|
||||
flashinfer_deepgemm_match = check_accuracy(
|
||||
out_flashinfer, out_deepgemm, 0.1, 0.6, 0.95
|
||||
)
|
||||
print("Correctness check:")
|
||||
print(f" - Flashinfer vs DeepGEMM: {'✅' if flashinfer_deepgemm_match else '❌'}")
|
||||
|
||||
|
||||
def _benchmark(m, n, k, tp_size, provider):
|
||||
print(f"Shape (m={m}, n={n}, k={k}, tp={tp_size}), Provider: {provider}")
|
||||
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Preprocess data before benchmarking
|
||||
y_fp8, y_scale = per_block_cast_to_fp8(y)
|
||||
x_fp8, x_scale = sglang_per_token_group_quant_fp8(
|
||||
x, BLOCK_SIZE, column_major_scales=True
|
||||
)
|
||||
dg_x_fp8, dg_x_scale = sglang_per_token_group_quant_fp8(
|
||||
x,
|
||||
BLOCK_SIZE,
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
)
|
||||
dg_y_fp8, dg_y_scale = requant_weight_ue8m0(
|
||||
y_fp8, y_scale, [BLOCK_SIZE, BLOCK_SIZE]
|
||||
)
|
||||
|
||||
quantiles = (0.5, 0.2, 0.8)
|
||||
|
||||
if provider == "deepgemm":
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: fp8_gemm_deepgemm_blackwell(
|
||||
dg_x_fp8,
|
||||
dg_x_scale,
|
||||
dg_y_fp8,
|
||||
dg_y_scale,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "flashinfer":
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: fp8_gemm_flashinfer(
|
||||
x_fp8,
|
||||
x_scale,
|
||||
y_fp8,
|
||||
y_scale,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
# Calculate TFLOPS
|
||||
flops = 2 * m * n * k # multiply-adds
|
||||
tflops = flops / (ms * 1e-3) / 1e12
|
||||
|
||||
# Print shape-specific results with TFLOPS
|
||||
print(f"Time: {ms*1000:.2f} us, TFLOPS: {tflops:.2f}")
|
||||
return ms, max_ms, min_ms
|
||||
|
||||
|
||||
def get_benchmark_plot_friendly(tp_size):
|
||||
all_configs = create_benchmark_configs(tp_size)
|
||||
x_vals = list(range(len(all_configs)))
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["cfg_id"],
|
||||
x_vals=x_vals,
|
||||
line_arg="provider",
|
||||
line_vals=["deepgemm", "flashinfer"],
|
||||
line_names=["DeepGEMM", "Flashinfer"],
|
||||
styles=[("blue", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name=f"fp8-gemm-performance-comparison-tp{tp_size}",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(cfg_id, provider):
|
||||
m, n, k, tp_size = all_configs[cfg_id]
|
||||
ms, min_ms, max_ms = _benchmark(m, n, k, tp_size, provider)
|
||||
return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
def get_benchmark(tp_size):
|
||||
all_configs = create_benchmark_configs(tp_size)
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["m", "n", "k", "tp_size"],
|
||||
x_vals=[list(config) for config in all_configs],
|
||||
line_arg="provider",
|
||||
line_vals=["deepgemm", "flashinfer"],
|
||||
line_names=["DeepGEMM", "Flashinfer"],
|
||||
styles=[("blue", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name=f"fp8-gemm-performance-comparison-tp{tp_size}",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(m, n, k, tp_size, provider):
|
||||
ms, min_ms, max_ms = _benchmark(m, n, k, tp_size, provider)
|
||||
return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] != 10:
|
||||
print("Skipping benchmark because the device is not supported")
|
||||
exit(0)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/fp8_gemm/",
|
||||
help="Path to save fp8 gemm benchmark results",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run-correctness",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Whether to run correctness test",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor parallelism size to benchmark (default: 1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plot-friendly",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Plot x axis as the config index instead of the m",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
# Run correctness tests on a few examples
|
||||
if args.run_correctness:
|
||||
print("Running correctness tests...")
|
||||
calculate_diff(64, 512, 7168) # Small test
|
||||
calculate_diff(64, 7168, 16384) # Medium test
|
||||
calculate_diff(64, 18432, 7168) # Large test
|
||||
|
||||
# Get the benchmark function with the specified tp_size
|
||||
benchmark = (
|
||||
get_benchmark_plot_friendly(args.tp_size)
|
||||
if args.plot_friendly
|
||||
else get_benchmark(args.tp_size)
|
||||
)
|
||||
|
||||
print(f"Running performance benchmark for TP size = {args.tp_size}...")
|
||||
benchmark.run(print_data=True, save_path=args.save_path)
|
||||
488
third_party/sglang/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py
vendored
Normal file
488
third_party/sglang/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_group_gemm.py
vendored
Normal file
@@ -0,0 +1,488 @@
|
||||
from typing import Tuple
|
||||
|
||||
import deep_gemm
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from deep_gemm import calc_diff
|
||||
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
||||
|
||||
# Import shared functionality from the regular GEMM benchmark
|
||||
from sglang.benchmark.bench_utils import run_bench
|
||||
from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
|
||||
per_block_cast_to_fp8,
|
||||
per_token_cast_to_fp8,
|
||||
)
|
||||
|
||||
|
||||
def construct_grouped_and_flat_fp8(
|
||||
x: torch.Tensor, y: torch.Tensor, num_groups: int, is_masked: bool
|
||||
) -> Tuple[
|
||||
Tuple[torch.Tensor, torch.Tensor], # grouped x_fp8
|
||||
Tuple[torch.Tensor, torch.Tensor], # grouped y_fp8
|
||||
Tuple[torch.Tensor, torch.Tensor], # flat x_fp8
|
||||
Tuple[torch.Tensor, torch.Tensor], # flat y_fp8
|
||||
torch.Tensor, # output
|
||||
torch.Tensor, # reference output
|
||||
]:
|
||||
# Verify input shapes
|
||||
m, k = x.shape
|
||||
n, k_y = y.shape
|
||||
assert k == k_y, f"Incompatible shapes: x({m}, {k}), y({n}, {k_y})"
|
||||
assert m % num_groups == 0, f"m({m}) must be divisible by num_groups({num_groups})"
|
||||
assert m % 4 == 0, f"TMA alignment error: {m}"
|
||||
|
||||
# Reshape inputs for grouped processing
|
||||
m_per_group = m // num_groups
|
||||
x_grouped = x.view(num_groups, m_per_group, k)
|
||||
y_grouped = y.unsqueeze(0).expand(num_groups, n, k)
|
||||
|
||||
# Initialize output tensors
|
||||
out = torch.empty((num_groups, m_per_group, n), device="cuda", dtype=torch.bfloat16)
|
||||
ref_out = torch.einsum("gmk,gnk->gmn", x_grouped, y_grouped)
|
||||
|
||||
# Quantize grouped tensors
|
||||
x_fp8_grouped = (
|
||||
torch.empty_like(x_grouped, dtype=torch.float8_e4m3fn),
|
||||
torch.empty(
|
||||
(num_groups, m_per_group, k // 128), device="cuda", dtype=torch.float
|
||||
),
|
||||
)
|
||||
y_fp8_grouped = (
|
||||
torch.empty_like(y_grouped, dtype=torch.float8_e4m3fn),
|
||||
torch.empty(
|
||||
(num_groups, (n + 127) // 128, k // 128), device="cuda", dtype=torch.float
|
||||
),
|
||||
)
|
||||
for i in range(num_groups):
|
||||
x_fp8_grouped[0][i], x_fp8_grouped[1][i] = per_token_cast_to_fp8(x_grouped[i])
|
||||
y_fp8_grouped[0][i], y_fp8_grouped[1][i] = per_block_cast_to_fp8(y_grouped[i])
|
||||
|
||||
# Quantize flat tensors
|
||||
x_fp8_flat = per_token_cast_to_fp8(x)
|
||||
y_fp8_flat = per_block_cast_to_fp8(y)
|
||||
|
||||
# For non-masked input, merge the group and M dims in output
|
||||
if not is_masked:
|
||||
x_fp8_grouped = (
|
||||
x_fp8_grouped[0].view(-1, k),
|
||||
per_token_cast_to_fp8(x_grouped.view(-1, k))[1],
|
||||
)
|
||||
out, ref_out = out.view(-1, n), ref_out.view(-1, n)
|
||||
|
||||
# Transpose earlier for testing
|
||||
x_fp8_grouped = (
|
||||
x_fp8_grouped[0],
|
||||
get_mn_major_tma_aligned_tensor(x_fp8_grouped[1]),
|
||||
)
|
||||
x_fp8_flat = (x_fp8_flat[0], get_mn_major_tma_aligned_tensor(x_fp8_flat[1]))
|
||||
|
||||
return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out
|
||||
|
||||
|
||||
# Since we don't have a group gemm kernel in SGLang/vLLM, we implemented a
|
||||
# custom kernel based on the Triton tutorial.
|
||||
# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
|
||||
@triton.jit
|
||||
def fp8_gemm_group_triton_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
# Pointers to scaling factors
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
# Matrix dimensions
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
# The stride variables represent how much to increase the ptr by when moving by 1
|
||||
# element in a particular dimension.
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
# Strides for scaling factors
|
||||
stride_a_scale_m,
|
||||
stride_a_scale_k,
|
||||
stride_b_scale_n,
|
||||
stride_b_scale_k,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
):
|
||||
"""Kernel for computing the matmul C = A x B with FP8 inputs and scaling factors.
|
||||
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||
|
||||
Note: Block sizes must be multiples of 32 for optimal TMA performance.
|
||||
"""
|
||||
# Map program ids to the block of C it should compute
|
||||
pid_group = tl.program_id(axis=0) # Group ID
|
||||
pid_n = tl.program_id(axis=1) # N dimension ID
|
||||
|
||||
# Compute the M block ID within this group
|
||||
group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)
|
||||
pid_m_within_group = tl.program_id(axis=2) % group_size_m
|
||||
pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group
|
||||
|
||||
# Create pointers for the first blocks of A and B
|
||||
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
|
||||
# Initialize accumulator
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
# Main loop
|
||||
for k_block in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
k_offset = k_block * BLOCK_SIZE_K
|
||||
|
||||
# Load the next block of A and B, with masks
|
||||
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k_offset, other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_offset, other=0.0)
|
||||
|
||||
# Calculate indices for scaling factors for this K block
|
||||
a_scale_ptrs = a_scale_ptr + (
|
||||
offs_am * stride_a_scale_m + k_block * stride_a_scale_k
|
||||
)
|
||||
b_scale_ptrs = b_scale_ptr + (
|
||||
pid_n * stride_b_scale_n + k_block * stride_b_scale_k
|
||||
)
|
||||
|
||||
# Perform matrix multiplication in FP8
|
||||
res = tl.dot(a, b)
|
||||
|
||||
# Load scaling factors for the current block
|
||||
a_scale = tl.load(a_scale_ptrs)[:, None] # [BLOCK_SIZE_M, 1]
|
||||
b_scale = tl.load(b_scale_ptrs)
|
||||
|
||||
# Apply scaling factors to the accumulated result
|
||||
accumulator += res * a_scale * b_scale
|
||||
|
||||
# Advance pointers
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
# Convert to bfloat16 for output
|
||||
c = accumulator.to(tl.bfloat16)
|
||||
|
||||
# Write back the result
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
|
||||
"""
|
||||
Perform matrix multiplication with FP8 inputs and proper scaling.
|
||||
|
||||
Args:
|
||||
a_tuple: Tuple of (quantized_tensor, scale_factors) for input A
|
||||
b_tuple: Tuple of (quantized_tensor, scale_factors) for input B
|
||||
c: Output tensor in BF16 format
|
||||
num_groups: Number of groups for grouped GEMM
|
||||
|
||||
Returns:
|
||||
Result tensor in BF16 format
|
||||
"""
|
||||
# Unpack the tuples
|
||||
a, a_scale = a_tuple
|
||||
b, b_scale = b_tuple
|
||||
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
|
||||
# Configure block sizes - must be multiples of 32 for TMA alignment
|
||||
BLOCK_SIZE_M = 128
|
||||
BLOCK_SIZE_N = 128
|
||||
BLOCK_SIZE_K = 128
|
||||
|
||||
# Calculate grid dimensions
|
||||
num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)
|
||||
num_groups_grid = triton.cdiv(num_pid_m, num_groups)
|
||||
|
||||
# 3D grid launch - (group, n_blocks, m_blocks_per_group)
|
||||
grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))
|
||||
|
||||
fp8_gemm_group_triton_kernel[grid](
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
a_scale,
|
||||
b_scale,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a.stride(0),
|
||||
a.stride(1),
|
||||
b.stride(0),
|
||||
b.stride(1),
|
||||
c.stride(0),
|
||||
c.stride(1),
|
||||
a_scale.stride(0),
|
||||
1, # Stride in the K dimension may be 1
|
||||
b_scale.stride(0),
|
||||
1 if b_scale.dim() > 1 else 0,
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
GROUP_SIZE_M=num_groups,
|
||||
)
|
||||
|
||||
return c
|
||||
|
||||
|
||||
def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
|
||||
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
|
||||
x_fp8_grouped,
|
||||
y_fp8_grouped,
|
||||
out,
|
||||
m_indices,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def calculate_diff(m: int, n: int, k: int, num_groups: int):
|
||||
print(f"Shape (m={m}, n={n}, k={k}")
|
||||
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||
x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
|
||||
construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
|
||||
)
|
||||
m_per_group = m // num_groups
|
||||
out_deepgemm = out.clone()
|
||||
m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
|
||||
m_indices = (
|
||||
m_indices.unsqueeze(-1).expand(num_groups, m_per_group).contiguous().view(-1)
|
||||
)
|
||||
|
||||
fp8_gemm_group_deepgemm(
|
||||
x_fp8_grouped,
|
||||
y_fp8_grouped,
|
||||
out_deepgemm,
|
||||
m_indices,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Prepare inputs for Triton
|
||||
a, a_scale = x_fp8_flat
|
||||
b, b_scale = y_fp8_flat
|
||||
b = b.T.contiguous()
|
||||
# Ensure scales are in the right format and contiguous
|
||||
a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
|
||||
M, _ = a.shape
|
||||
_, N = b.shape
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
|
||||
out_triton = fp8_gemm_group_triton((a, a_scale), (b, b_scale), c, num_groups)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item()
|
||||
diff_torch_triton = torch.abs(out_torch - out_triton).mean().item()
|
||||
diff_deepgemm_triton = torch.abs(out_deepgemm - out_triton).mean().item()
|
||||
|
||||
print(f"Shape m={m}, n={n}, k={k}:")
|
||||
print(f"Torch output: {out_torch[0, 0:5]}")
|
||||
print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
|
||||
print(f"Triton output: {out_triton[0, 0:5]}")
|
||||
print(f"Mean absolute difference (Torch-DeepGEMM): {diff_torch_deepgemm}")
|
||||
print(f"Mean absolute difference (Torch-Triton): {diff_torch_triton}")
|
||||
print(f"Mean absolute difference (DeepGEMM-Triton): {diff_deepgemm_triton}")
|
||||
|
||||
deepgemm_torch_diff = calc_diff(out_deepgemm, out_torch)
|
||||
triton_torch_diff = calc_diff(out_triton, out_torch)
|
||||
deepgemm_triton_diff = calc_diff(out_deepgemm, out_triton)
|
||||
|
||||
DIFF_THRESHOLD = 0.001
|
||||
all_match = (
|
||||
deepgemm_torch_diff < DIFF_THRESHOLD
|
||||
and triton_torch_diff < DIFF_THRESHOLD
|
||||
and deepgemm_triton_diff < DIFF_THRESHOLD
|
||||
)
|
||||
if all_match:
|
||||
print("✅ All implementations match\n")
|
||||
else:
|
||||
print("❌ Some implementations differ:")
|
||||
print(
|
||||
f" - Torch vs DeepGEMM: {'✅' if deepgemm_torch_diff < DIFF_THRESHOLD else '❌'}"
|
||||
f" - Torch vs Triton: {'✅' if triton_torch_diff < DIFF_THRESHOLD else '❌'}"
|
||||
f" - DeepGEMM vs Triton: {'✅' if deepgemm_triton_diff < DIFF_THRESHOLD else '❌'}"
|
||||
)
|
||||
|
||||
|
||||
def get_weight_shapes(tp_size):
|
||||
# cannot TP
|
||||
total = [
|
||||
(512 + 64, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(7168, 16384),
|
||||
(7168, 18432),
|
||||
]
|
||||
# N can TP
|
||||
n_tp = [
|
||||
(18432 * 2, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(24576, 1536),
|
||||
(4096, 7168),
|
||||
]
|
||||
# K can TP
|
||||
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
|
||||
|
||||
weight_shapes = []
|
||||
for t in total:
|
||||
weight_shapes.append(t)
|
||||
for n_t in n_tp:
|
||||
new_t = (n_t[0] // tp_size, n_t[1])
|
||||
weight_shapes.append(new_t)
|
||||
for k_t in k_tp:
|
||||
new_t = (k_t[0], k_t[1] // tp_size)
|
||||
weight_shapes.append(new_t)
|
||||
|
||||
return weight_shapes
|
||||
|
||||
|
||||
def create_benchmark_configs(tp_size):
|
||||
configs = []
|
||||
weight_shapes = get_weight_shapes(tp_size)
|
||||
batch_sizes = [2048, 4096]
|
||||
group_sizes = [4, 8]
|
||||
for n, k in weight_shapes:
|
||||
for m in batch_sizes:
|
||||
for num_groups in group_sizes:
|
||||
configs.append((m, n, k, num_groups, tp_size))
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def get_benchmark(tp_size):
|
||||
all_configs = create_benchmark_configs(tp_size)
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["m", "n", "k", "num_groups", "tp_size"],
|
||||
x_vals=[config for config in all_configs],
|
||||
line_arg="provider",
|
||||
line_vals=["deepgemm", "triton"],
|
||||
line_names=["DeepGEMM", "Triton"],
|
||||
styles=[("blue", "-"), ("red", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"fp8-group-gemm-performance-comparison-tp{tp_size}",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(m, n, k, num_groups, tp_size, provider):
|
||||
print(
|
||||
f"Shape (m={m}, n={n}, k={k}, tp={tp_size}, num_groups={num_groups}, Provider: {provider}"
|
||||
)
|
||||
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||
x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
|
||||
construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
|
||||
)
|
||||
m_per_group = m // num_groups
|
||||
m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
|
||||
m_indices = (
|
||||
m_indices.unsqueeze(-1)
|
||||
.expand(num_groups, m_per_group)
|
||||
.contiguous()
|
||||
.view(-1)
|
||||
)
|
||||
|
||||
quantiles = (0.5, 0.2, 0.8)
|
||||
|
||||
if provider == "deepgemm":
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: fp8_gemm_group_deepgemm(
|
||||
x_fp8_grouped,
|
||||
y_fp8_grouped,
|
||||
out,
|
||||
m_indices,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "triton":
|
||||
# Prepare inputs for Triton
|
||||
# We did it outside of the lambda function to make it fair comparison like deepgemm
|
||||
a, a_scale = x_fp8_flat
|
||||
b, b_scale = y_fp8_flat
|
||||
b = b.T.contiguous()
|
||||
# Ensure scales are in the right format and contiguous
|
||||
a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
|
||||
M, _ = a.shape
|
||||
_, N = b.shape
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: fp8_gemm_group_triton(
|
||||
(a, a_scale),
|
||||
(b, b_scale),
|
||||
c,
|
||||
num_groups,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
# Calculate TFLOPS
|
||||
flops = 2 * m * n * k # multiply-adds
|
||||
tflops = flops / (ms * 1e-3) / 1e12
|
||||
|
||||
print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}")
|
||||
return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/fp8_group_gemm/",
|
||||
help="Path to save deepgemm fp8 group gemm benchmark results",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_correctness",
|
||||
action="store_true",
|
||||
help="Whether to run correctness test",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor parallelism size to benchmark (default: 1)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
# Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# Run correctness tests on a few examples
|
||||
if args.run_correctness:
|
||||
print("Running correctness tests...")
|
||||
calculate_diff(8192, 7168, 4096, 4)
|
||||
calculate_diff(8192, 2048, 7168, 4)
|
||||
calculate_diff(4096, 7168, 4096, 8)
|
||||
calculate_diff(4096, 2048, 7168, 8)
|
||||
calculate_diff(4096, 576, 7168, 8)
|
||||
|
||||
# Get the benchmark function with the specified tp_size
|
||||
benchmark = get_benchmark(args.tp_size)
|
||||
|
||||
print(f"Running performance benchmark for TP size = {args.tp_size}...")
|
||||
benchmark.run(print_data=True, save_path=args.save_path)
|
||||
198
third_party/sglang/benchmark/kernels/elementwise/benchmark_concat_mla.py
vendored
Normal file
198
third_party/sglang/benchmark/kernels/elementwise/benchmark_concat_mla.py
vendored
Normal file
@@ -0,0 +1,198 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import concat_mla_k as concat_mla_k_cuda
|
||||
|
||||
from sglang.benchmark.bench_utils import run_bench
|
||||
|
||||
DEVICE = triton.runtime.driver.active.get_active_torch_device()
|
||||
|
||||
num_local_heads = 128
|
||||
qk_nope_head_dim = 128
|
||||
qk_rope_head_dim = 64
|
||||
|
||||
|
||||
def create_data(num_tokens):
|
||||
k_nope_container = torch.randn(
|
||||
(num_tokens, num_local_heads, qk_nope_head_dim + 128),
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
)
|
||||
k_nope = k_nope_container[:, :, :qk_nope_head_dim]
|
||||
|
||||
k_rope_container = torch.randn(
|
||||
(num_tokens, 1, 128 + qk_rope_head_dim), dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
k_rope = k_rope_container[:, :, -qk_rope_head_dim:]
|
||||
|
||||
k = torch.empty(
|
||||
(num_tokens, num_local_heads, qk_nope_head_dim + qk_rope_head_dim),
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
)
|
||||
return dict(k=k, k_nope=k_nope, k_rope=k_rope)
|
||||
|
||||
|
||||
def fn_torch(k, k_nope, k_rope):
|
||||
k[..., :qk_nope_head_dim] = k_nope
|
||||
k[..., qk_nope_head_dim:] = k_rope
|
||||
|
||||
|
||||
def fn_hack_non_strided(k, k_nope, k_rope):
|
||||
k_flatten_view = k.flatten()
|
||||
k_flatten_view[: k_nope.numel()] = k_nope.flatten()
|
||||
|
||||
k2 = k_flatten_view[k_nope.numel() :].view(k_rope.numel(), -1)
|
||||
k2 = k_rope.flatten()[:, None]
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def fn_torch_compiled(k, k_nope, k_rope):
|
||||
return fn_torch(k, k_nope, k_rope)
|
||||
|
||||
|
||||
def fn_cuda(k, k_nope, k_rope):
|
||||
concat_mla_k_cuda(k, k_nope, k_rope)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fn_triton_kernel(
|
||||
k_ptr,
|
||||
k_nope_ptr,
|
||||
k_rope_ptr,
|
||||
num_tokens,
|
||||
QK_NOPE_HEAD_DIM: tl.constexpr,
|
||||
QK_ROPE_HEAD_DIM: tl.constexpr,
|
||||
NUM_LOCAL_HEADS: tl.constexpr,
|
||||
K_NOPE_STRIDE_0: tl.constexpr,
|
||||
K_NOPE_STRIDE_1: tl.constexpr,
|
||||
K_STRIDE_0: tl.constexpr,
|
||||
K_STRIDE_1: tl.constexpr,
|
||||
K_ROPE_STRIDE_0: tl.constexpr,
|
||||
BLOCK_ROWS: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
token_id = pid * BLOCK_ROWS + tl.arange(0, BLOCK_ROWS)
|
||||
token_mask = token_id < num_tokens
|
||||
|
||||
head_id = tl.arange(0, NUM_LOCAL_HEADS)
|
||||
|
||||
# nope
|
||||
nope_sub_id = tl.arange(0, QK_NOPE_HEAD_DIM)
|
||||
offs_nope = (
|
||||
token_id[:, None, None] * K_NOPE_STRIDE_0
|
||||
+ head_id[None, :, None] * K_NOPE_STRIDE_1
|
||||
+ nope_sub_id[None, None, :]
|
||||
)
|
||||
offs_k = (
|
||||
token_id[:, None, None] * K_STRIDE_0
|
||||
+ head_id[None, :, None] * K_STRIDE_1
|
||||
+ nope_sub_id[None, None, :]
|
||||
)
|
||||
vals_nope = tl.load(k_nope_ptr + offs_nope, mask=token_mask[:, None, None])
|
||||
tl.store(k_ptr + offs_k, vals_nope, mask=token_mask[:, None, None])
|
||||
|
||||
# rope
|
||||
rope_sub_id = tl.arange(0, QK_ROPE_HEAD_DIM)
|
||||
offs_rope = token_id[:, None, None] * K_ROPE_STRIDE_0 + rope_sub_id[None, None, :]
|
||||
offs_k = (
|
||||
token_id[:, None, None] * K_STRIDE_0
|
||||
+ head_id[None, :, None] * K_STRIDE_1
|
||||
+ rope_sub_id[None, None, :]
|
||||
+ QK_NOPE_HEAD_DIM
|
||||
)
|
||||
vals_rope = tl.load(k_rope_ptr + offs_rope, mask=token_mask[:, None, None])
|
||||
tl.store(k_ptr + offs_k, vals_rope, mask=token_mask[:, None, None])
|
||||
|
||||
|
||||
def fn_triton(k, k_nope, k_rope):
|
||||
assert k.device == DEVICE and k_nope.device == DEVICE and k_rope.device == DEVICE
|
||||
num_tokens, _, _ = k.shape
|
||||
grid = lambda meta: (triton.cdiv(num_tokens, meta["BLOCK_ROWS"]),)
|
||||
fn_triton_kernel[grid](
|
||||
k,
|
||||
k_nope,
|
||||
k_rope,
|
||||
num_tokens,
|
||||
QK_NOPE_HEAD_DIM=qk_nope_head_dim,
|
||||
QK_ROPE_HEAD_DIM=qk_rope_head_dim,
|
||||
NUM_LOCAL_HEADS=num_local_heads,
|
||||
K_NOPE_STRIDE_0=k_nope.stride(0),
|
||||
K_NOPE_STRIDE_1=k_nope.stride(1),
|
||||
K_STRIDE_0=k.stride(0),
|
||||
K_STRIDE_1=k.stride(1),
|
||||
K_ROPE_STRIDE_0=k_rope.stride(0),
|
||||
BLOCK_ROWS=16,
|
||||
)
|
||||
|
||||
|
||||
def execute_and_get_output(f, data):
|
||||
data["k"].zero_()
|
||||
f(**data)
|
||||
assert data["k"].sum().item() != 0
|
||||
return data["k"].clone()
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
data = create_data(num_tokens=32768)
|
||||
output_ref = execute_and_get_output(fn_torch, data)
|
||||
output_exp = execute_and_get_output(fn_cuda, data)
|
||||
# print(output_ref)
|
||||
# print(output_exp)
|
||||
if not torch.all(output_ref == output_exp):
|
||||
abs_delta = torch.abs(output_ref - output_exp)
|
||||
raise AssertionError(
|
||||
f"{output_ref=} {output_exp=} "
|
||||
f"{abs_delta=} "
|
||||
f"{torch.argwhere(abs_delta != 0.0)=} "
|
||||
)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens"], # Argument names to use as an x-axis for the plot.
|
||||
x_vals=[
|
||||
2048,
|
||||
4096,
|
||||
8192,
|
||||
16384,
|
||||
32768,
|
||||
], # Different possible values for `x_name`.
|
||||
x_log=False, # x axis is logarithmic.
|
||||
line_arg="provider", # Argument name whose value corresponds to a different line in the plot.
|
||||
line_vals=[
|
||||
"torch",
|
||||
"torch_compiled",
|
||||
"triton",
|
||||
"hack_non_strided",
|
||||
"cuda",
|
||||
], # Possible values for `line_arg`.
|
||||
line_names=[
|
||||
"torch",
|
||||
"torch_compiled",
|
||||
"triton",
|
||||
"hack_non_strided",
|
||||
"cuda",
|
||||
], # Label name for the lines.
|
||||
plot_name="vector-add-performance", # Name for the plot. Used also as a file name for saving the plot.
|
||||
args={}, # Values for function arguments not in `x_names` and `y_name`.
|
||||
)
|
||||
)
|
||||
def benchmark(num_tokens, provider):
|
||||
data = create_data(num_tokens=num_tokens)
|
||||
quantiles = (0.5, 0.2, 0.8)
|
||||
fn = {
|
||||
"torch": fn_torch,
|
||||
"torch_compiled": fn_torch_compiled,
|
||||
"triton": fn_triton,
|
||||
"hack_non_strided": fn_hack_non_strided,
|
||||
"cuda": fn_cuda,
|
||||
}[provider]
|
||||
ms, min_ms, max_ms = run_bench(lambda: fn(**data), quantiles=quantiles)
|
||||
return ms, min_ms, max_ms
|
||||
|
||||
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
benchmark.run(print_data=True, show_plots=True)
|
||||
torch.cuda.cudart().cudaProfilerStop()
|
||||
102
third_party/sglang/benchmark/kernels/flashinfer_allreduce_fusion/README.md
vendored
Normal file
102
third_party/sglang/benchmark/kernels/flashinfer_allreduce_fusion/README.md
vendored
Normal file
@@ -0,0 +1,102 @@
|
||||
# FlashInfer Fused AllReduce + RMSNorm Benchmark
|
||||
|
||||
This benchmark script is modified from the [original implementation](https://github.com/vllm-project/vllm/blob/237e1fb887c7f5a579420fa0295097f24b006594/benchmarks/kernels/benchmark_fused_collective.py) by the vLLM community. It aims to compare the performance differences between FlashInfer fused operators in SGLang (trtllm_allreduce_fusion: AllReduce + Residual Add + RMSNorm + optional quantization) and conventional implementations (standard `tensor_model_parallel_all_reduce` + separate RMSNorm/quantization). Specifically, this script tests the timing performance of two implementation paths: 1) Standard AllReduce and RMSNorm executed separately; 2) FlashInfer's fused operator combining AllReduce, Residual Add, RMSNorm, and optional quantization operations.
|
||||
|
||||
This benchmark script helps us tune the ipc workspace size of the `flashinfer_allreduce_residual_rmsnorm` operator in SGLang and prepare for applications with FP8/FP4 quantized fused operators.
|
||||
|
||||
Script path: `benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py`
|
||||
|
||||
## Feature Overview
|
||||
|
||||
- Compare average execution time (ms) and calculate speedup ratios for the following paths:
|
||||
- standard_allreduce_rmsnorm (Standard AllReduce + RMSNorm)
|
||||
- flashinfer_fused_allreduce_rmsnorm (Fused AllReduce + RMSNorm), including oneshot and twoshot modes
|
||||
- Optionally compare FP8/FP4 quantized fused paths with standard paths
|
||||
- Use CUDA Graph capture and batch replay to reduce measurement noise
|
||||
- Automatically select the faster "standard baseline" (native/compiled version) as the denominator for speedup calculation
|
||||
- Optionally export results in Markdown format
|
||||
|
||||
## Runtime Environment and Prerequisites
|
||||
|
||||
- At least 2 GPUs, and launch multi-process distributed training using `torchrun` (NCCL backend)
|
||||
- Properly install/compile sglang along with sgl-kernel and custom operators
|
||||
|
||||
## Quick Start (Command Examples)
|
||||
|
||||
The following examples use world_size=2. You can modify `--nproc_per_node` and parameters according to your machine:
|
||||
|
||||
- Regular paths only (no quantization):
|
||||
```
|
||||
torchrun --nproc_per_node=2 \
|
||||
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
|
||||
--no-quant --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
|
||||
```
|
||||
|
||||
- FP8 quantization paths only:
|
||||
```
|
||||
torchrun --nproc_per_node=2 \
|
||||
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
|
||||
--quant-fp8 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
|
||||
```
|
||||
|
||||
- FP4 quantization paths only:
|
||||
```
|
||||
torchrun --nproc_per_node=2 \
|
||||
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
|
||||
--quant-fp4 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
|
||||
```
|
||||
|
||||
- Larger hidden dimensions:
|
||||
```
|
||||
torchrun --nproc_per_node=2 \
|
||||
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
|
||||
--no-quant --hidden-dim 4096 --seq-lens 512 1024 2048 4096 --trials 100
|
||||
```
|
||||
|
||||
## Parameter Description
|
||||
- `--seq-lens`: List of sequence lengths to test (default: 128 512 1024 2048)
|
||||
- `--hidden-dim`: Hidden dimension (default: 8192)
|
||||
- `--dtypes`: Data type list, `float16|bfloat16|float32` (default: bfloat16)
|
||||
- `--no-residual`: Only test "no residual" scenarios (default tests both "with/without residual")
|
||||
- Mutually exclusive quantization options:
|
||||
- `--no-quant`: No quantization testing
|
||||
- `--quant-fp8`: Only FP8 quantization testing
|
||||
- `--quant-fp4`: Only FP4 quantization testing
|
||||
- `--quant-all`: Test all (default)
|
||||
- FlashInfer related:
|
||||
- `--disable-oneshot`: Disable oneshot mode (default enables oneshot and tests twoshot simultaneously)
|
||||
- Runtime configuration:
|
||||
- `--warmup`: Warmup count before graph capture and before graph replay (default 5)
|
||||
- `--trials`: Benchmark iteration count (default 20; internally each `graph.replay()` will batch replay multiple times)
|
||||
- `--output-file`: Save results as Markdown file (only rank0 takes effect)
|
||||
|
||||
## Output Example
|
||||
|
||||
Each configuration group prints a table showing average execution time and relative speedup ratios (baseline is the faster standard implementation). For example:
|
||||
```
|
||||
================================================================================
|
||||
Results: seq_len=1024, hidden_dim=1024
|
||||
dtype=torch.bfloat16, residual=yes, quant_mode=none
|
||||
================================================================================
|
||||
Operation Time (ms) Speedup
|
||||
--------------------------------------------------------------------------------
|
||||
standard_allreduce_rmsnorm 0.024 0.98x
|
||||
standard_allreduce_rmsnorm_native_compiled 0.023 baseline
|
||||
flashinfer_fused_allreduce_rmsnorm_oneshot 0.011 2.19x
|
||||
flashinfer_fused_allreduce_rmsnorm_twoshot 0.041 0.57x
|
||||
```
|
||||
|
||||
If `--output-file` is specified, all configurations will be summarized in Markdown tables in that file.
|
||||
|
||||
## Important Notes and Recommendations
|
||||
|
||||
- Distributed: The script uses `torchrun` environment variables to initialize distributed training and binds tensors/communication groups to the current rank's corresponding device.
|
||||
- World size: Requires `WORLD_SIZE > 1` to perform communication operator benchmarks. Otherwise, the script will error and prompt.
|
||||
- FlashInfer:
|
||||
- If not installed or interfaces are missing, the script will only run standard paths and provide prompts in the logs.
|
||||
- The fused operator internally uses "oneshot"/"twoshot" two trigger methods; oneshot is enabled by default and twoshot is tested simultaneously.
|
||||
- FP8/FP4:
|
||||
- FP8 uses sglang's FP8 tools and dtype, with underlying platform selection of `e4m3`/`e4m3fnuz` etc.
|
||||
- FP4 uses sgl-kernel's `scaled_fp4_quant`, requiring corresponding platform support.
|
||||
- CUDA Graph:
|
||||
- Uses sglang's `graph_capture()` to prepare capture-ready state for communication, then uses `torch.cuda.graph` to capture kernels, reducing measurement jitter.
|
||||
1305
third_party/sglang/benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py
vendored
Normal file
1305
third_party/sglang/benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
210
third_party/sglang/benchmark/kernels/fused_moe_triton/README.md
vendored
Normal file
210
third_party/sglang/benchmark/kernels/fused_moe_triton/README.md
vendored
Normal file
@@ -0,0 +1,210 @@
|
||||
## Tuning Triton MoE Kernels
|
||||
|
||||
This directory contains benchmarking tools for MoE (Mixture of Experts) kernels.
|
||||
|
||||
### Overview
|
||||
|
||||
The tuning tools support both **Tensor Parallelism (TP)** and **Expert Parallelism (EP)** modes:
|
||||
|
||||
- **TP Mode**: Traditional tensor parallelism where intermediate layers are sharded across GPUs
|
||||
- **EP Mode**: Expert parallelism where experts are distributed across GPUs. Can be combined with TP mode (e.g., `--tp-size 8 --ep-size 2`)
|
||||
- **MLLM Support**: Multi-modal Large Language Models with text encoders (e.g., Llama4, Qwen3VL)
|
||||
|
||||
### Tuning Tools
|
||||
|
||||
#### 1. `tuning_fused_moe_triton.py`
|
||||
A unified tool for tuning the `fused_moe_triton` kernel. Adapted from [vllm's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with support for EP mode and various model architectures.
|
||||
|
||||
#### 2. `tuning_fused_moe_triton_sep.py`
|
||||
A specialized tool for separate kernel tuning, optimizing the first and second MoE kernels independently with TMA (Tensor Memory Accelerator) support.
|
||||
|
||||
### Usage Examples
|
||||
|
||||
#### Basic TP Mode Tuning
|
||||
```bash
|
||||
# Tune Mixtral-8x7B with default TP settings
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
|
||||
--tune
|
||||
|
||||
# Tune Qwen2-57B with FP8 and TP=4
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model Qwen/Qwen2-57B-A14B-Instruct \
|
||||
--tp-size 4 \
|
||||
--dtype fp8_w8a8 \
|
||||
--tune
|
||||
|
||||
# Tune DeepSeek-V3 with FP8 and TP=8
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model deepseek-ai/DeepSeek-V3-0324 \
|
||||
--tp-size 8 \
|
||||
--dtype fp8_w8a8 \
|
||||
--tune
|
||||
```
|
||||
|
||||
#### EP Mode Tuning (Expert Parallelism)
|
||||
**Note**: EP mode can be used alone or combined with TP mode. When using both, ensure `tp_size` is divisible by `ep_size`.
|
||||
|
||||
```bash
|
||||
# Tune Mixtral-8x7B with EP=2 only
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
|
||||
--tp-size 2 \
|
||||
--ep-size 2 \
|
||||
--tune
|
||||
|
||||
# Tune Qwen2-57B with TP=8 and EP=4 (combined mode)
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model Qwen/Qwen2-57B-A14B-Instruct \
|
||||
--tp-size 8 \
|
||||
--ep-size 4 \
|
||||
--dtype fp8_w8a8 \
|
||||
--tune
|
||||
```
|
||||
|
||||
#### MLLM Model Tuning (Multi-modal)
|
||||
```bash
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
|
||||
--tp-size 2 \
|
||||
--tune
|
||||
```
|
||||
|
||||
#### Separate Kernel Tuning with `tuning_fused_moe_triton_sep.py`
|
||||
|
||||
This tool requires pre-generated topk_ids files and supports both TP and EP modes:
|
||||
|
||||
Edit the code file (such as srt/models/deepseek_v2.py) in the Python site package and add the logic for saving topk_ids:
|
||||
|
||||
```python
|
||||
# import get_tensor_model_parallel_rank
|
||||
# DeepseekV2MoE::forward_normal
|
||||
if hidden_states.shape[0] >= 4096 and get_tensor_model_parallel_rank() == 0:
|
||||
topk_ids_dir = xxxx
|
||||
if not hasattr(self, "save_idx"):
|
||||
self.save_idx = 0
|
||||
if self.save_idx <= 1:
|
||||
torch.save(topk_output.topk_ids, f"{topk_ids_dir}/topk_ids_layer{self.layer_id}_idx{self.save_idx}.pt")
|
||||
self.save_idx += 1
|
||||
```
|
||||
|
||||
Launch sglang server and send request using `benchmark/kernels/fused_moe_triton/tuning_client.py`
|
||||
```bash
|
||||
python benchmark/kernels/fused_moe_triton/tuning_client.py --port 8000
|
||||
```
|
||||
|
||||
```bash
|
||||
# TP Mode: Tune separate kernels with TP=4
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \
|
||||
--model Qwen/Qwen2-57B-A14B-Instruct \
|
||||
--tp-size 4 \
|
||||
--topk-ids-dir /path/to/topk_ids \
|
||||
--tune
|
||||
|
||||
# EP Mode: Tune separate kernels with TP=4 and EP=2
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \
|
||||
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
|
||||
--tp-size 4 \
|
||||
--ep-size 2 \
|
||||
--topk-ids-dir /path/to/topk_ids \
|
||||
--tune
|
||||
|
||||
# MLLM: Tune DeepSeek-V3 with separate kernels, TP=8 and EP=4
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \
|
||||
--model deepseek-ai/DeepSeek-V3-0324 \
|
||||
--tp-size 8 \
|
||||
--ep-size 4 \
|
||||
--dtype fp8_w8a8 \
|
||||
--topk-ids-dir /path/to/topk_ids \
|
||||
--tune
|
||||
|
||||
# Benchmark specific config without tuning
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \
|
||||
--model deepseek-ai/DeepSeek-V3-0324 \
|
||||
--tp-size 4 \
|
||||
--batch-size 1024 \
|
||||
--dtype fp8_w8a8 \
|
||||
--configs 128 256 128 16 8 4 \
|
||||
--topk-ids-dir /path/to/topk_ids
|
||||
```
|
||||
|
||||
#### Advanced Options
|
||||
```bash
|
||||
# Channel-wise quantization
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model meituan/DeepSeek-R1-Channel-INT8 \
|
||||
--tp-size 16 \
|
||||
--dtype int8_w8a8 \
|
||||
--per-channel-quant \
|
||||
--tune
|
||||
|
||||
# Specific batch size tuning
|
||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
|
||||
--batch-size 2048 \
|
||||
--tune
|
||||
```
|
||||
|
||||
### Configuration Files
|
||||
|
||||
After tuning, configuration files will be generated:
|
||||
- **Standard tuning**: `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`
|
||||
- **Separate kernel tuning**: Two files for up/down kernels with TMA optimization flags
|
||||
|
||||
Move these files to `sglang/srt/layers/moe/fused_moe_triton/configs/triton_version/` directory to use them in SGLang.
|
||||
|
||||
### Supported Models
|
||||
|
||||
- **Mixtral**: mistralai/Mixtral-8x7B-Instruct-v0.1, mixtral-8x22b
|
||||
- **Qwen**: Qwen2-57B, Qwen3-235B, Qwen3VL (MLLM)
|
||||
- **DeepSeek**: DeepSeek-V2, DeepSeek-V3, DeepSeek-R1
|
||||
- **Llama**: Llama4-Vision (MLLM)
|
||||
- **DBRX**: databricks/dbrx-instruct
|
||||
- **Jamba**: ai21labs/AI21-Jamba
|
||||
- **Grok**: xai-org/grok-1
|
||||
- **GLM**: THUDM/glm-4-9b-chat
|
||||
- **Bailing**: Custom MoE models
|
||||
|
||||
### Parameters Reference
|
||||
|
||||
- `--model`: HuggingFace model name or local path
|
||||
- `--tp-size`: Tensor parallelism size (default: 2)
|
||||
- `--ep-size`: Expert parallelism size (default: 1, can be combined with TP mode, ensure tp_size is divisible by ep_size)
|
||||
- `--dtype`: Data type (`auto`, `fp8_w8a8`, `int8_w8a16`, `int8_w8a8`)
|
||||
- `--batch-size`: Specific batch size for tuning (optional)
|
||||
- `--tune`: Enable tuning mode
|
||||
- `--per-channel-quant`: Enable per-channel quantization
|
||||
- `--disable-shared-experts-fusion`: Disable shared expert fusion for some models
|
||||
- `--topk-ids-dir`: Directory containing pre-generated topk_ids (for sep tool only)
|
||||
- `--configs`: Manual config specification [BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, warps, stages]
|
||||
|
||||
### Performance Comparison Tool
|
||||
|
||||
- `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of fused MoE kernels between vllm and sglang implementations. Supports various model architectures and data types.
|
||||
|
||||
Example usage:
|
||||
```bash
|
||||
# Compare with default settings (Mixtral model)
|
||||
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
|
||||
|
||||
# Compare with FP8 mode for Qwen2-57B
|
||||
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
||||
--model Qwen/Qwen2-57B-A14B-Instruct \
|
||||
--use-fp8-w8a8
|
||||
|
||||
# Compare with custom TP size
|
||||
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
||||
--model deepseek-ai/DeepSeek-V3-0324 \
|
||||
--tp-size 8
|
||||
|
||||
# Compare with custom TP size
|
||||
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
||||
--model deepseek-ai/DeepSeek-V3-0324 \
|
||||
--tp-size 8
|
||||
```
|
||||
|
||||
The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
|
||||
|
||||
- `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the performance of the fused MoE kernel with `torch.compile` and original fused MoE kernel.
|
||||
|
||||
Usage is similar to `benchmark_vllm_vs_sglang_fused_moe_triton.py`, note that `torch.compile` does not support `fp8_w8a8` and `int8_w8a8` fused_moe_kernel. Both tools now support EP mode with `--ep-size` parameter.
|
||||
250
third_party/sglang/benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py
vendored
Normal file
250
third_party/sglang/benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py
vendored
Normal file
@@ -0,0 +1,250 @@
|
||||
# python3 benchmark/kernels/fused_moe_triton/sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from common_utils import get_model_config
|
||||
|
||||
from sglang.benchmark.bench_utils import run_bench
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
destroy_distributed_environment,
|
||||
destroy_model_parallel,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe as fused_moe_sglang,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import (
|
||||
TopK,
|
||||
TopKConfig,
|
||||
TopKOutputFormat,
|
||||
select_experts,
|
||||
)
|
||||
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
|
||||
|
||||
|
||||
def fused_moe_triton_api(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
):
|
||||
topk_op = TopK(
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
use_grouped_topk=False,
|
||||
output_format=TopKOutputFormat.TRITON_KERNEL,
|
||||
)
|
||||
triton_topk_output = topk_op.forward_cuda(
|
||||
hidden_states=x,
|
||||
router_logits=input_gating,
|
||||
)
|
||||
|
||||
moe_runner_config = MoeRunnerConfig(
|
||||
inplace=False,
|
||||
)
|
||||
return triton_kernel_moe_forward(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
triton_topk_output,
|
||||
moe_runner_config,
|
||||
)
|
||||
|
||||
|
||||
def fused_moe_sglang_api(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=False,
|
||||
w1_scale=None,
|
||||
w2_scale=None,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
block_shape=None,
|
||||
):
|
||||
topk_output = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=input_gating,
|
||||
topk_config=TopKConfig(top_k=topk, renormalize=False),
|
||||
)
|
||||
return fused_moe_sglang(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
topk_output,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=list([128, 256, 512, 1024, 2048, 4096, 8192]),
|
||||
line_arg="provider",
|
||||
line_vals=[
|
||||
"sglang_fused_moe_triton_v340",
|
||||
"sglang_fused_moe_triton",
|
||||
],
|
||||
line_names=[
|
||||
"sglang_fused_moe_triton_v340",
|
||||
"sglang_fused_moe_triton",
|
||||
],
|
||||
styles=[
|
||||
("blue", "-"),
|
||||
("green", "-"),
|
||||
],
|
||||
ylabel="Time (ms)",
|
||||
plot_name="fused-moe-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(
|
||||
batch_size,
|
||||
provider,
|
||||
model_config,
|
||||
use_fp8_w8a8=False,
|
||||
use_cuda_graph: bool = False,
|
||||
):
|
||||
print(f"benchmark {provider} with batch_size={batch_size}")
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
num_tokens = batch_size
|
||||
num_experts = model_config["num_experts"]
|
||||
hidden_size = model_config["hidden_size"]
|
||||
shard_intermediate_size = model_config["shard_intermediate_size"]
|
||||
topk = model_config["topk"]
|
||||
dtype = model_config["dtype"]
|
||||
block_shape = model_config["block_shape"]
|
||||
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
|
||||
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
|
||||
w2 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
|
||||
)
|
||||
|
||||
w1_tri = w1.clone()
|
||||
w2_tri = w2.clone()
|
||||
w1_tri = w1_tri.transpose(-2, -1).contiguous()
|
||||
w2_tri = w2_tri.transpose(-2, -1).contiguous()
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
if provider == "sglang_fused_moe_triton_v340":
|
||||
api_func = fused_moe_triton_api
|
||||
api_kwargs = {
|
||||
"x": x,
|
||||
"w1": w1_tri,
|
||||
"w2": w2_tri,
|
||||
"input_gating": input_gating,
|
||||
"topk": topk,
|
||||
}
|
||||
else:
|
||||
api_func = fused_moe_sglang_api
|
||||
api_kwargs = {
|
||||
"x": x,
|
||||
"w1": w1,
|
||||
"w2": w2,
|
||||
"input_gating": input_gating,
|
||||
"topk": topk,
|
||||
"use_fp8_w8a8": use_fp8_w8a8,
|
||||
"block_shape": block_shape,
|
||||
}
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
_ = api_func(**api_kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if use_cuda_graph:
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
api_func(**api_kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
bench_lambda = lambda: graph.replay()
|
||||
else:
|
||||
bench_lambda = lambda: api_func(**api_kwargs)
|
||||
|
||||
quantiles = (0.5, 0.2, 0.8)
|
||||
ms, min_ms, max_ms = run_bench(bench_lambda, quantiles=quantiles)
|
||||
return ms, min_ms, max_ms
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
)
|
||||
parser.add_argument("--tp-size", "--tp", type=int, default=2)
|
||||
parser.add_argument("--ep-size", "--ep", type=int, default=1)
|
||||
parser.add_argument("--use-fp8-w8a8", action="store_true")
|
||||
parser.add_argument(
|
||||
"--use-cuda-graph", action="store_true", help="Enable CUDA Graph capture/replay"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/sglang_fused_moe/",
|
||||
)
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize global server args (required by SGLang MoE kernels)
|
||||
server_args = ServerArgs(model_path=args.model)
|
||||
set_global_server_args_for_scheduler(server_args)
|
||||
|
||||
try:
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(
|
||||
backend="nccl" if torch.cuda.is_available() else "gloo",
|
||||
init_method="tcp://127.0.0.1:23456",
|
||||
world_size=1,
|
||||
rank=0,
|
||||
)
|
||||
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method="tcp://127.0.0.1:23456",
|
||||
local_rank=0,
|
||||
backend="nccl" if torch.cuda.is_available() else "gloo",
|
||||
)
|
||||
|
||||
initialize_model_parallel(
|
||||
tensor_model_parallel_size=1,
|
||||
expert_model_parallel_size=1,
|
||||
)
|
||||
|
||||
model_config = get_model_config(args.model, args.tp_size, args.ep_size)
|
||||
benchmark.run(
|
||||
show_plots=True,
|
||||
print_data=True,
|
||||
save_path=args.save_path,
|
||||
model_config=model_config,
|
||||
use_fp8_w8a8=args.use_fp8_w8a8,
|
||||
use_cuda_graph=args.use_cuda_graph,
|
||||
)
|
||||
finally:
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
306
third_party/sglang/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py
vendored
Normal file
306
third_party/sglang/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py
vendored
Normal file
@@ -0,0 +1,306 @@
|
||||
# python3 benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from torch.nn import functional as F
|
||||
from transformers import AutoConfig
|
||||
|
||||
from sglang.benchmark.bench_utils import run_bench
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe as fused_moe_triton,
|
||||
)
|
||||
from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
|
||||
|
||||
|
||||
def get_model_config(model_name: str, tp_size: int):
|
||||
"""Get model configuration parameters"""
|
||||
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
if config.architectures[0] == "DbrxForCausalLM":
|
||||
E = config.ffn_config.moe_num_experts
|
||||
topk = config.ffn_config.moe_top_k
|
||||
intermediate_size = config.ffn_config.ffn_hidden_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] == "JambaForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] == "Qwen3MoeForCausalLM":
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] == "Llama4ForConditionalGeneration":
|
||||
E = config.text_config.num_local_experts
|
||||
topk = config.text_config.num_experts_per_tok
|
||||
intermediate_size = config.text_config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] in [
|
||||
"Grok1ForCausalLM",
|
||||
"Grok1ImgGen",
|
||||
"Grok1AForCausalLM",
|
||||
]:
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
else:
|
||||
# Default: Mixtral
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
|
||||
shape_configs = {
|
||||
"num_experts": E,
|
||||
"topk": topk,
|
||||
"hidden_size": config.hidden_size,
|
||||
"shard_intermediate_size": shard_intermediate_size,
|
||||
"dtype": config.torch_dtype,
|
||||
}
|
||||
print(f"{shape_configs=}")
|
||||
return shape_configs
|
||||
|
||||
|
||||
def fused_topk_native(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
):
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
M, _ = hidden_states.shape
|
||||
topk_weights = torch.empty(
|
||||
M, topk, dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
||||
topk_weights = F.softmax(gating_output.float(), dim=-1)
|
||||
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
@torch.compile(dynamic=False)
|
||||
def fused_moe_torch(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=False,
|
||||
w1_scale=None,
|
||||
w2_scale=None,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
) -> torch.Tensor:
|
||||
assert not use_fp8_w8a8, "Fp8_w8a8 fused_moe is not supported for torch compile"
|
||||
|
||||
topk_weights, topk_ids = fused_topk_native(
|
||||
hidden_states=x,
|
||||
gating_output=input_gating,
|
||||
topk=topk,
|
||||
renormalize=True,
|
||||
)
|
||||
w13_weights = w1[topk_ids]
|
||||
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
|
||||
w2_weights = w2[topk_ids]
|
||||
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
|
||||
x1 = F.silu(x1)
|
||||
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
|
||||
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
|
||||
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
|
||||
|
||||
|
||||
def fused_moe_torch_compile(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=False,
|
||||
w1_scale=None,
|
||||
w2_scale=None,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
):
|
||||
return fused_moe_torch(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
|
||||
def fused_moe_sglang_api(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=False,
|
||||
w1_scale=None,
|
||||
w2_scale=None,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
):
|
||||
return fused_moe_triton(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=list(range(1, 5)),
|
||||
line_arg="provider",
|
||||
line_vals=[
|
||||
"fused_moe_triton",
|
||||
"fused_moe_torch_compile",
|
||||
],
|
||||
line_names=[
|
||||
"fused_moe_triton",
|
||||
"fused_moe_torch_compile",
|
||||
],
|
||||
styles=[
|
||||
("blue", "-"),
|
||||
("green", "-"),
|
||||
],
|
||||
ylabel="Time (ms)",
|
||||
plot_name="fused-moe-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
|
||||
print(f"benchmark {provider} with batch_size={batch_size}")
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
set_torch_compile_config()
|
||||
|
||||
num_tokens = batch_size
|
||||
num_experts = model_config["num_experts"]
|
||||
hidden_size = model_config["hidden_size"]
|
||||
shard_intermediate_size = model_config["shard_intermediate_size"]
|
||||
topk = model_config["topk"]
|
||||
dtype = model_config["dtype"]
|
||||
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
|
||||
if use_fp8_w8a8:
|
||||
init_dtype = dtype
|
||||
w1 = torch.randn(
|
||||
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
|
||||
)
|
||||
w2 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
|
||||
)
|
||||
w1 = w1.to(torch.float8_e4m3fn)
|
||||
w2 = w2.to(torch.float8_e4m3fn)
|
||||
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
a1_scale = torch.randn(1, dtype=torch.float32)
|
||||
a2_scale = torch.randn(1, dtype=torch.float32)
|
||||
else:
|
||||
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
|
||||
w2 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
|
||||
)
|
||||
w1_scale = w2_scale = a1_scale = a2_scale = None
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
# Warmup
|
||||
api_func = (
|
||||
fused_moe_torch_compile
|
||||
if provider == "fused_moe_torch_compile"
|
||||
else fused_moe_sglang_api
|
||||
)
|
||||
for _ in range(10):
|
||||
y = api_func(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
quantiles = (0.5, 0.2, 0.8)
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: api_func(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)[0],
|
||||
quantiles=quantiles,
|
||||
)
|
||||
return ms, min_ms, max_ms
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
)
|
||||
parser.add_argument("--tp-size", type=int, default=2)
|
||||
parser.add_argument("--use-fp8-w8a8", action="store_true")
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/fused_moe_torch_compile/",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
model_config = get_model_config(args.model, args.tp_size)
|
||||
benchmark.run(
|
||||
show_plots=True,
|
||||
print_data=True,
|
||||
save_path=args.save_path,
|
||||
model_config=model_config,
|
||||
use_fp8_w8a8=args.use_fp8_w8a8,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
265
third_party/sglang/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
vendored
Normal file
265
third_party/sglang/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
vendored
Normal file
@@ -0,0 +1,265 @@
|
||||
# python3 benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
|
||||
|
||||
from sglang.benchmark.bench_utils import run_bench
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
destroy_distributed_environment,
|
||||
destroy_model_parallel,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe as fused_moe_sglang,
|
||||
)
|
||||
|
||||
from .common_utils import get_model_config
|
||||
|
||||
|
||||
def fused_moe_vllm_api(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=False,
|
||||
w1_scale=None,
|
||||
w2_scale=None,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
block_shape=None,
|
||||
):
|
||||
if block_shape is not None:
|
||||
return fused_moe_vllm(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
else:
|
||||
return fused_moe_vllm(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
|
||||
def fused_moe_sglang_api(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=False,
|
||||
w1_scale=None,
|
||||
w2_scale=None,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
block_shape=None,
|
||||
):
|
||||
return fused_moe_sglang(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=list(range(1, 513)),
|
||||
line_arg="provider",
|
||||
line_vals=[
|
||||
"vllm_fused_moe_triton",
|
||||
"sglang_fused_moe_triton",
|
||||
],
|
||||
line_names=[
|
||||
"vllm_fused_moe_triton",
|
||||
"sglang_fused_moe_triton",
|
||||
],
|
||||
styles=[
|
||||
("blue", "-"),
|
||||
("green", "-"),
|
||||
],
|
||||
ylabel="Time (ms)",
|
||||
plot_name="fused-moe-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
|
||||
print(f"benchmark {provider} with batch_size={batch_size}")
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
num_tokens = batch_size
|
||||
num_experts = model_config["num_experts"]
|
||||
hidden_size = model_config["hidden_size"]
|
||||
shard_intermediate_size = model_config["shard_intermediate_size"]
|
||||
topk = model_config["topk"]
|
||||
dtype = model_config["dtype"]
|
||||
block_shape = model_config["block_shape"]
|
||||
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
w1_scale = w2_scale = a1_scale = a2_scale = None
|
||||
|
||||
if use_fp8_w8a8:
|
||||
init_dtype = dtype
|
||||
w1 = torch.randn(
|
||||
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
|
||||
)
|
||||
w2 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
|
||||
)
|
||||
w1 = w1.to(torch.float8_e4m3fn)
|
||||
w2 = w2.to(torch.float8_e4m3fn)
|
||||
|
||||
if block_shape is None:
|
||||
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
a1_scale = torch.randn(1, dtype=torch.float32)
|
||||
a2_scale = torch.randn(1, dtype=torch.float32)
|
||||
else:
|
||||
block_n, block_k = block_shape[0], block_shape[1]
|
||||
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
|
||||
n_tiles_w2 = (hidden_size + block_n - 1) // block_n
|
||||
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
|
||||
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
|
||||
w1_scale = torch.rand(
|
||||
(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.rand(
|
||||
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
|
||||
w2 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
|
||||
)
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
# Warmup
|
||||
api_func = (
|
||||
fused_moe_vllm_api
|
||||
if provider == "vllm_fused_moe_triton"
|
||||
else fused_moe_sglang_api
|
||||
)
|
||||
for _ in range(10):
|
||||
y = api_func(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
quantiles = (0.5, 0.2, 0.8)
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: api_func(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
)[0],
|
||||
quantiles=quantiles,
|
||||
)
|
||||
return ms, min_ms, max_ms
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
)
|
||||
parser.add_argument("--tp-size", "--tp", type=int, default=2)
|
||||
parser.add_argument("--ep-size", "--ep", type=int, default=1)
|
||||
parser.add_argument("--use-fp8-w8a8", action="store_true")
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/vllm_sglang_fused_moe/",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(
|
||||
backend="nccl" if torch.cuda.is_available() else "gloo",
|
||||
init_method="tcp://127.0.0.1:23456",
|
||||
world_size=1,
|
||||
rank=0,
|
||||
)
|
||||
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method="tcp://127.0.0.1:23456",
|
||||
local_rank=0,
|
||||
backend="nccl" if torch.cuda.is_available() else "gloo",
|
||||
)
|
||||
|
||||
initialize_model_parallel(
|
||||
tensor_model_parallel_size=1,
|
||||
pipeline_model_parallel_size=1,
|
||||
)
|
||||
|
||||
shape_configs = get_model_config(args.model, args.tp_size, args.ep_size)
|
||||
benchmark.run(
|
||||
show_plots=True,
|
||||
print_data=True,
|
||||
save_path=args.save_path,
|
||||
model_config=shape_configs,
|
||||
use_fp8_w8a8=args.use_fp8_w8a8,
|
||||
)
|
||||
finally:
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
288
third_party/sglang/benchmark/kernels/fused_moe_triton/common_utils.py
vendored
Normal file
288
third_party/sglang/benchmark/kernels/fused_moe_triton/common_utils.py
vendored
Normal file
@@ -0,0 +1,288 @@
|
||||
import json
|
||||
from typing import Dict, List, TypedDict
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import get_config_dtype_str
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
|
||||
get_config_file_name,
|
||||
)
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.srt.utils.hf_transformers_utils import get_config
|
||||
|
||||
|
||||
class BenchmarkConfig(TypedDict):
|
||||
BLOCK_SIZE_M: int
|
||||
BLOCK_SIZE_N: int
|
||||
BLOCK_SIZE_K: int
|
||||
GROUP_SIZE_M: int
|
||||
num_warps: int
|
||||
num_stages: int
|
||||
|
||||
|
||||
def calculate_shard_intermediate_size(
|
||||
intermediate_size: int, tp_size: int, ep_size: int = 1
|
||||
) -> int:
|
||||
assert tp_size % ep_size == 0
|
||||
moe_tp_size = tp_size // ep_size
|
||||
assert intermediate_size % moe_tp_size == 0
|
||||
return 2 * intermediate_size // moe_tp_size
|
||||
|
||||
|
||||
def get_model_config(
|
||||
model_name: str,
|
||||
tp_size: int,
|
||||
ep_size: int = 1,
|
||||
disable_shared_experts_fusion: bool = False,
|
||||
topk_ids_dir: str = None,
|
||||
) -> Dict:
|
||||
config = get_config(model_name, trust_remote_code=True)
|
||||
architecture = config.architectures[0]
|
||||
block_shape = None
|
||||
if (
|
||||
hasattr(config, "quantization_config")
|
||||
and "weight_block_size" in config.quantization_config
|
||||
):
|
||||
block_shape = config.quantization_config["weight_block_size"]
|
||||
assert len(block_shape) == 2
|
||||
|
||||
if (
|
||||
hasattr(config, "quantization_config")
|
||||
and "config_groups" in config.quantization_config
|
||||
):
|
||||
config_groups = config.quantization_config["config_groups"]
|
||||
# Get group_size from the first group's weights config
|
||||
first_group = next(iter(config_groups.values()), {})
|
||||
weights_config = first_group.get("weights", {})
|
||||
group_size = weights_config.get("group_size")
|
||||
block_shape = [0, group_size]
|
||||
assert len(block_shape) == 2
|
||||
# Replace config with text_config for encoder-decoder models after getting block_shape and architecture
|
||||
if hasattr(config, "text_config"):
|
||||
config = config.get_text_config()
|
||||
|
||||
hidden_size = config.hidden_size
|
||||
if architecture == "DbrxForCausalLM":
|
||||
E = config.ffn_config.moe_num_experts // ep_size
|
||||
topk = config.ffn_config.moe_top_k
|
||||
intermediate_size = config.ffn_config.ffn_hidden_size
|
||||
elif architecture == "JambaForCausalLM":
|
||||
E = config.num_experts // ep_size
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
elif architecture in [
|
||||
"Qwen2MoeForCausalLM",
|
||||
"Qwen3MoeForCausalLM",
|
||||
"Qwen3NextForCausalLM",
|
||||
"Qwen3VLMoeForConditionalGeneration",
|
||||
"Qwen3_5MoeForConditionalGeneration",
|
||||
]:
|
||||
E = config.num_experts // ep_size
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
elif architecture in [
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekV32ForCausalLM",
|
||||
"Glm4MoeForCausalLM",
|
||||
"GlmMoeDsaForCausalLM",
|
||||
"MistralLarge3ForCausalLM",
|
||||
]:
|
||||
E = (config.n_routed_experts // ep_size) + (
|
||||
0
|
||||
if disable_shared_experts_fusion
|
||||
or architecture
|
||||
not in [
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekV32ForCausalLM",
|
||||
"Glm4MoeForCausalLM",
|
||||
"GlmMoeDsaForCausalLM",
|
||||
"MistralLarge3ForCausalLM",
|
||||
]
|
||||
else 1
|
||||
)
|
||||
topk = config.num_experts_per_tok + (
|
||||
0 if disable_shared_experts_fusion or topk_ids_dir is None else 1
|
||||
)
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
elif architecture == "Llama4ForConditionalGeneration":
|
||||
E = config.num_local_experts // ep_size + (
|
||||
0 if disable_shared_experts_fusion else 1
|
||||
)
|
||||
topk = config.num_experts_per_tok + (
|
||||
0 if disable_shared_experts_fusion or topk_ids_dir is None else 1
|
||||
)
|
||||
intermediate_size = config.intermediate_size
|
||||
elif architecture in [
|
||||
"Grok1ForCausalLM",
|
||||
"Grok1ImgGen",
|
||||
"Grok1AForCausalLM",
|
||||
]:
|
||||
E = config.num_local_experts // ep_size
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
elif architecture in [
|
||||
"BailingMoEForCausalLM",
|
||||
"BailingMoeForCausalLM",
|
||||
"BailingMoeV2ForCausalLM",
|
||||
]:
|
||||
E = config.num_experts // ep_size
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
elif architecture == "NemotronHForCausalLM":
|
||||
E = config.n_routed_experts // ep_size
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
hidden_size = getattr(config, "moe_latent_size", None) or hidden_size
|
||||
else:
|
||||
# Default: Mixtral
|
||||
E = config.num_local_experts // ep_size
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
|
||||
shard_intermediate_size = calculate_shard_intermediate_size(
|
||||
intermediate_size, tp_size, ep_size
|
||||
)
|
||||
|
||||
return {
|
||||
"num_experts": E,
|
||||
"topk": topk,
|
||||
"hidden_size": hidden_size,
|
||||
"shard_intermediate_size": shard_intermediate_size,
|
||||
"dtype": config.torch_dtype,
|
||||
"block_shape": block_shape,
|
||||
"architecture": architecture,
|
||||
}
|
||||
|
||||
|
||||
def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
|
||||
configs: List[BenchmarkConfig] = []
|
||||
waves_per_eu_range = 0
|
||||
for num_stages in [2]:
|
||||
for block_m in [32, 64, 128, 256]:
|
||||
for block_k in [32, 64, 128, 256]:
|
||||
for block_n in [16, 32, 64, 128, 256]:
|
||||
for num_warps in [1, 2, 4, 8]:
|
||||
for group_size in [1, 4, 8, 16, 32]:
|
||||
configs.append(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
"BLOCK_SIZE_N": block_n,
|
||||
"BLOCK_SIZE_K": block_k,
|
||||
"GROUP_SIZE_M": group_size,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
"waves_per_eu": waves_per_eu_range,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
def get_configs_compute_bound() -> List[Dict[str, int]]:
|
||||
configs: List[BenchmarkConfig] = []
|
||||
if is_hip():
|
||||
configs = get_rocm_configs_compute_bound()
|
||||
else:
|
||||
for num_stages in [2, 3, 4, 5]:
|
||||
for block_m in [16, 32, 64, 128, 256]:
|
||||
for block_k in [64, 128, 256]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
for num_warps in [4, 8]:
|
||||
for group_size in [1, 16, 32, 64]:
|
||||
configs.append(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
"BLOCK_SIZE_N": block_n,
|
||||
"BLOCK_SIZE_K": block_k,
|
||||
"GROUP_SIZE_M": group_size,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
||||
return {
|
||||
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
|
||||
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
|
||||
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
|
||||
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
|
||||
"num_warps": config["num_warps"],
|
||||
"num_stages": config["num_stages"],
|
||||
**(
|
||||
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
|
||||
),
|
||||
**({"USE_TMA": config["USE_TMA"]} if "USE_TMA" in config else {}),
|
||||
}
|
||||
|
||||
|
||||
def save_configs(
|
||||
configs: Dict[int, BenchmarkConfig],
|
||||
filename: str,
|
||||
) -> None:
|
||||
print(f"Writing best config to {filename}...")
|
||||
with open(filename, "w") as f:
|
||||
json.dump(configs, f, indent=4)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def get_config_filename(
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: List[int],
|
||||
) -> str:
|
||||
dtype_str = get_config_dtype_str(
|
||||
dtype,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||
# is the intermediate size after silu_and_mul.
|
||||
N = shard_intermediate_size // 2
|
||||
if use_int4_w4a16:
|
||||
N = N // 2
|
||||
|
||||
filename = get_config_file_name(
|
||||
num_experts,
|
||||
N,
|
||||
dtype_str,
|
||||
block_shape,
|
||||
per_channel_quant,
|
||||
)
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def get_default_batch_sizes() -> List[int]:
|
||||
return [
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
24,
|
||||
32,
|
||||
48,
|
||||
64,
|
||||
96,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
1536,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
]
|
||||
71
third_party/sglang/benchmark/kernels/fused_moe_triton/tuning_client.py
vendored
Normal file
71
third_party/sglang/benchmark/kernels/fused_moe_triton/tuning_client.py
vendored
Normal file
@@ -0,0 +1,71 @@
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
|
||||
import openai
|
||||
|
||||
"""
|
||||
# Edit the code file srt/models/deepseek_v2.py in the Python site package and add the logic for saving topk_ids:
|
||||
# import get_tensor_model_parallel_rank
|
||||
# DeepseekV2MoE::forward_normal
|
||||
if hidden_states.shape[0] >= 4096 and get_tensor_model_parallel_rank() == 0:
|
||||
topk_ids_dir = xxxx
|
||||
if not hasattr(self, "save_idx"):
|
||||
self.save_idx = 0
|
||||
if self.save_idx <= 1:
|
||||
torch.save(topk_output.topk_ids, f"{topk_ids_dir}/topk_ids_layer{self.layer_id}_idx{self.save_idx}.pt")
|
||||
self.save_idx += 1
|
||||
"""
|
||||
|
||||
|
||||
def read_long_prompt():
|
||||
import json
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
with open(f"{current_dir}/tuning_text.json", "r") as fp:
|
||||
text = fp.read()
|
||||
rst = json.loads(text)
|
||||
return rst["prompt"]
|
||||
|
||||
|
||||
def openai_stream_test(model, ip, port):
|
||||
client = openai.Client(base_url=f"http://{ip}:{port}/v1", api_key="None")
|
||||
qst = read_long_prompt()
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": qst},
|
||||
]
|
||||
msg2 = dict(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=0.6,
|
||||
top_p=0.75,
|
||||
max_tokens=100,
|
||||
)
|
||||
response = client.chat.completions.create(**msg2, stream=True)
|
||||
time_start = time.time()
|
||||
time_cost = []
|
||||
for chunk in response:
|
||||
time_end = time.time()
|
||||
# if chunk.choices[0].delta.content:
|
||||
# print(chunk.choices[0].delta.content, end="", flush=True)
|
||||
time_cost.append(time_end - time_start)
|
||||
time_start = time.time()
|
||||
|
||||
ttft = time_cost[0] + time_cost[1]
|
||||
tpot = sum(time_cost[2:]) / len(time_cost[2:])
|
||||
print(f"\nTTFT {ttft}, TPOT {tpot}")
|
||||
return ttft, tpot
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", type=str, default="auto")
|
||||
parser.add_argument(
|
||||
"--ip",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=8188)
|
||||
args = parser.parse_args()
|
||||
openai_stream_test(args.model, args.ip, args.port)
|
||||
520
third_party/sglang/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
vendored
Normal file
520
third_party/sglang/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
vendored
Normal file
@@ -0,0 +1,520 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
|
||||
import argparse
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import ray
|
||||
import torch
|
||||
import triton
|
||||
from common_utils import (
|
||||
BenchmarkConfig,
|
||||
get_config_filename,
|
||||
get_configs_compute_bound,
|
||||
get_default_batch_sizes,
|
||||
get_model_config,
|
||||
save_configs,
|
||||
sort_config,
|
||||
)
|
||||
from ray.experimental.tqdm_ray import tqdm
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton import override_config
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
|
||||
get_config_dtype_str,
|
||||
get_default_config,
|
||||
get_moe_configs,
|
||||
)
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||
from sglang.srt.server_args import (
|
||||
ServerArgs,
|
||||
set_global_server_args_for_scheduler,
|
||||
)
|
||||
from sglang.srt.utils import get_device, is_hip, is_xpu
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_xpu = is_xpu()
|
||||
|
||||
|
||||
def benchmark_config(
|
||||
config: BenchmarkConfig,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: List[int] = None,
|
||||
num_iters: int = 100,
|
||||
) -> float:
|
||||
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
if use_int8_w8a16 or use_int8_w8a8:
|
||||
w1 = torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
),
|
||||
dtype=torch.int8,
|
||||
)
|
||||
w2 = torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
shard_intermediate_size // 2,
|
||||
),
|
||||
dtype=torch.int8,
|
||||
)
|
||||
elif use_int4_w4a16:
|
||||
w1 = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size // 2,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
w2 = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
shard_intermediate_size // 4,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
else:
|
||||
w1 = torch.randn(
|
||||
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
|
||||
)
|
||||
w2 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
|
||||
)
|
||||
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
w1_scale = None
|
||||
w2_scale = None
|
||||
a1_scale = None
|
||||
a2_scale = None
|
||||
if use_int8_w8a16:
|
||||
w1_scale = torch.randn(
|
||||
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
||||
if use_int4_w4a16:
|
||||
block_n = 1 if (block_shape[0] == 0) else block_shape[0]
|
||||
block_k = block_shape[1]
|
||||
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
|
||||
n_tiles_w2 = (hidden_size + block_n - 1) // block_n
|
||||
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
|
||||
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
|
||||
w1_scale = torch.randn(
|
||||
(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.bfloat16
|
||||
)
|
||||
w2_scale = torch.randn(
|
||||
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.bfloat16
|
||||
)
|
||||
if use_fp8_w8a8 or use_int8_w8a8:
|
||||
if use_int8_w8a8 and block_shape is None:
|
||||
w1_scale = torch.randn(
|
||||
num_experts, shard_intermediate_size, dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32)
|
||||
elif block_shape is None:
|
||||
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
a1_scale = torch.randn(1, dtype=torch.float32)
|
||||
a2_scale = torch.randn(1, dtype=torch.float32)
|
||||
else:
|
||||
block_n, block_k = block_shape[0], block_shape[1]
|
||||
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
|
||||
n_tiles_w2 = (hidden_size + block_n - 1) // block_n
|
||||
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
|
||||
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
|
||||
w1_scale = torch.rand(
|
||||
(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.rand(
|
||||
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
|
||||
)
|
||||
|
||||
if use_fp8_w8a8:
|
||||
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
||||
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
topk_config = TopKConfig(
|
||||
top_k=topk,
|
||||
renormalize=True,
|
||||
)
|
||||
topk_output = select_experts(x, input_gating, topk_config)
|
||||
|
||||
def prepare(i: int):
|
||||
input_gating = gating_output[i]
|
||||
new_topk_output = select_experts(x, input_gating, topk_config)
|
||||
topk_output.topk_weights.copy_(new_topk_output.topk_weights)
|
||||
topk_output.topk_ids.copy_(new_topk_output.topk_ids)
|
||||
topk_output.router_logits.copy_(new_topk_output.router_logits)
|
||||
|
||||
def run():
|
||||
moe_runner_config = MoeRunnerConfig(
|
||||
inplace=True,
|
||||
)
|
||||
|
||||
with override_config(config):
|
||||
fused_moe(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
# JIT compilation & warmup
|
||||
run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture 10 invocations with CUDA graph
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
for _ in range(10):
|
||||
run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Warmup
|
||||
for _ in range(5):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Flush L2 cache with 256 MB data
|
||||
cache_flush = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
|
||||
cache_flush.zero_()
|
||||
|
||||
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]
|
||||
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]
|
||||
|
||||
for i in range(num_iters):
|
||||
prepare(i)
|
||||
start_events[i].record()
|
||||
graph.replay()
|
||||
end_events[i].record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
latencies: List[float] = []
|
||||
for i in range(num_iters):
|
||||
latencies.append(start_events[i].elapsed_time(end_events[i]))
|
||||
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||
graph.reset()
|
||||
return avg
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
class BenchmarkWorker:
|
||||
|
||||
def __init__(self, seed: int, server_args: ServerArgs) -> None:
|
||||
torch.set_default_device(get_device())
|
||||
torch.get_device_module().manual_seed_all(0)
|
||||
self.seed = seed
|
||||
# Get the device ID to allocate tensors and kernels
|
||||
# on the respective GPU.
|
||||
self.device_id = int(ray.get_gpu_ids()[0])
|
||||
set_global_server_args_for_scheduler(server_args)
|
||||
|
||||
def benchmark(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: List[int],
|
||||
) -> Tuple[Dict[str, int], float]:
|
||||
torch.cuda.manual_seed_all(0)
|
||||
dtype_str = get_config_dtype_str(
|
||||
dtype,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
)
|
||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||
# is the intermediate size after silu_and_mul.
|
||||
block_n = block_shape[0] if block_shape else 0
|
||||
block_k = block_shape[1] if block_shape else 0
|
||||
N = shard_intermediate_size // 2
|
||||
if use_int4_w4a16:
|
||||
N = N // 2
|
||||
op_config = get_moe_configs(
|
||||
num_experts,
|
||||
N,
|
||||
dtype_str,
|
||||
block_n,
|
||||
block_k,
|
||||
per_channel_quant,
|
||||
)
|
||||
if op_config is None:
|
||||
config = get_default_config(
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype_str,
|
||||
False,
|
||||
block_shape,
|
||||
)
|
||||
else:
|
||||
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
|
||||
with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
|
||||
kernel_time = benchmark_config(
|
||||
config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
per_channel_quant,
|
||||
block_shape,
|
||||
)
|
||||
return config, kernel_time
|
||||
|
||||
def tune(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: List[int],
|
||||
search_space: List[Dict[str, int]],
|
||||
) -> Dict[str, int]:
|
||||
best_config = None
|
||||
best_time = float("inf")
|
||||
with (
|
||||
torch.get_device_module().device(self.device_id)
|
||||
if _is_xpu or _is_hip
|
||||
else nullcontext()
|
||||
):
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(
|
||||
config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
per_channel_quant,
|
||||
block_shape,
|
||||
num_iters=10,
|
||||
)
|
||||
except (triton.runtime.autotuner.OutOfResources, RuntimeError):
|
||||
# Some configurations may be invalid and fail to compile.
|
||||
continue
|
||||
|
||||
if kernel_time < best_time:
|
||||
best_time = kernel_time
|
||||
best_config = config
|
||||
now = datetime.now()
|
||||
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
|
||||
assert best_config is not None
|
||||
return best_config
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
server_args = ServerArgs(
|
||||
model_path=args.model, tp_size=args.tp_size, ep_size=args.ep_size
|
||||
)
|
||||
|
||||
model_config = get_model_config(
|
||||
args.model, args.tp_size, args.ep_size, args.disable_shared_experts_fusion
|
||||
)
|
||||
|
||||
E = model_config["num_experts"]
|
||||
topk = model_config["topk"]
|
||||
hidden_size = model_config["hidden_size"]
|
||||
shard_intermediate_size = model_config["shard_intermediate_size"]
|
||||
dtype = model_config["dtype"]
|
||||
block_shape = model_config["block_shape"]
|
||||
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a8 = args.dtype == "int8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
use_int4_w4a16 = args.dtype == "int4_w4a16"
|
||||
per_channel_quant = args.per_channel_quant
|
||||
|
||||
if args.batch_size is None:
|
||||
batch_sizes = get_default_batch_sizes()
|
||||
else:
|
||||
batch_sizes = [args.batch_size]
|
||||
|
||||
ray.init()
|
||||
num_gpus = int(ray.available_resources()["GPU"])
|
||||
workers = [BenchmarkWorker.remote(args.seed, server_args) for _ in range(num_gpus)]
|
||||
|
||||
def _distribute(method: str, inputs: List[Any]) -> List[Any]:
|
||||
outputs = []
|
||||
worker_idx = 0
|
||||
for input_args in inputs:
|
||||
worker = workers[worker_idx]
|
||||
worker_method = getattr(worker, method)
|
||||
output = worker_method.remote(*input_args)
|
||||
outputs.append(output)
|
||||
worker_idx = (worker_idx + 1) % num_gpus
|
||||
return ray.get(outputs)
|
||||
|
||||
if args.tune:
|
||||
search_space = get_configs_compute_bound()
|
||||
if block_shape is not None:
|
||||
block_n, block_k = block_shape[0], block_shape[1]
|
||||
search_space = [
|
||||
config
|
||||
for config in search_space
|
||||
if block_k % config["BLOCK_SIZE_K"] == 0
|
||||
]
|
||||
|
||||
filename = get_config_filename(
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
per_channel_quant,
|
||||
block_shape,
|
||||
)
|
||||
print(
|
||||
f"Start tuning over {len(search_space)} configurations to create {filename}..."
|
||||
)
|
||||
|
||||
start = time.perf_counter()
|
||||
configs = _distribute(
|
||||
"tune",
|
||||
[
|
||||
(
|
||||
batch_size,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
per_channel_quant,
|
||||
block_shape,
|
||||
search_space,
|
||||
)
|
||||
for batch_size in batch_sizes
|
||||
],
|
||||
)
|
||||
best_configs = {
|
||||
M: sort_config(config) for M, config in zip(batch_sizes, configs)
|
||||
}
|
||||
save_configs(
|
||||
best_configs,
|
||||
filename,
|
||||
)
|
||||
end = time.perf_counter()
|
||||
print(f"Tuning took {end - start:.2f} seconds")
|
||||
else:
|
||||
outputs = _distribute(
|
||||
"benchmark",
|
||||
[
|
||||
(
|
||||
batch_size,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
per_channel_quant,
|
||||
block_shape,
|
||||
)
|
||||
for batch_size in batch_sizes
|
||||
],
|
||||
)
|
||||
|
||||
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
|
||||
print(f"Batch size: {batch_size}, config: {config}")
|
||||
print(f"Kernel time: {kernel_time:.2f} us")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
)
|
||||
parser.add_argument("--tp-size", "--tp", type=int, default=2)
|
||||
parser.add_argument("--ep-size", "--ep", type=int, default=1)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8", "int4_w4a16"],
|
||||
default="auto",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per-channel-quant",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument("--tune", action="store_true")
|
||||
parser.add_argument("--disable-shared-experts-fusion", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
893
third_party/sglang/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py
vendored
Normal file
893
third_party/sglang/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py
vendored
Normal file
@@ -0,0 +1,893 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import ray
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from common_utils import (
|
||||
BenchmarkConfig,
|
||||
get_config_filename,
|
||||
get_configs_compute_bound,
|
||||
get_default_batch_sizes,
|
||||
get_model_config,
|
||||
sort_config,
|
||||
)
|
||||
from ray.experimental.tqdm_ray import tqdm
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
get_config_dtype_str,
|
||||
invoke_fused_moe_kernel,
|
||||
moe_align_block_size,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
|
||||
get_config_file_name,
|
||||
)
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||
from sglang.srt.server_args import (
|
||||
ServerArgs,
|
||||
set_global_server_args_for_scheduler,
|
||||
)
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MoeInputs:
|
||||
topk_ids: torch.Tensor
|
||||
sorted_token_ids: torch.Tensor
|
||||
expert_ids: torch.Tensor
|
||||
num_tokens_post_padded: torch.Tensor
|
||||
|
||||
|
||||
class KernelWrapper:
|
||||
def __init__(self, moe_inputs, use_cuda_graph=True, inner_iter=10, **kwargs):
|
||||
self.func = invoke_fused_moe_kernel
|
||||
self.use_cuda_graph = use_cuda_graph
|
||||
self.moe_inputs = moe_inputs
|
||||
self.inner_iter = inner_iter
|
||||
self.kwargs = kwargs
|
||||
if use_cuda_graph:
|
||||
self.graph = self.cuda_graph_wrapper()
|
||||
else:
|
||||
self.graph = None
|
||||
|
||||
def cuda_graph_wrapper(self):
|
||||
moe_input = self.moe_inputs[0]
|
||||
self.func(
|
||||
**self.kwargs,
|
||||
topk_ids=moe_input.topk_ids,
|
||||
sorted_token_ids=moe_input.sorted_token_ids,
|
||||
expert_ids=moe_input.expert_ids,
|
||||
num_tokens_post_padded=moe_input.num_tokens_post_padded,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture 10 invocations with CUDA graph
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
for k in range(self.inner_iter):
|
||||
moe_input = self.moe_inputs[k]
|
||||
self.func(
|
||||
**self.kwargs,
|
||||
topk_ids=moe_input.topk_ids,
|
||||
sorted_token_ids=moe_input.sorted_token_ids,
|
||||
expert_ids=moe_input.expert_ids,
|
||||
num_tokens_post_padded=moe_input.num_tokens_post_padded,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Warmup
|
||||
for _ in range(5):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
return graph
|
||||
|
||||
def forward_cost(self, try_cnt=2):
|
||||
time_cost = float("inf")
|
||||
for _ in range(try_cnt):
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
if self.use_cuda_graph:
|
||||
self.graph.replay()
|
||||
else:
|
||||
for k in range(self.inner_iter):
|
||||
moe_input = self.moe_inputs[k]
|
||||
self.func(
|
||||
**self.kwargs,
|
||||
topk_ids=moe_input.topk_ids,
|
||||
sorted_token_ids=moe_input.sorted_token_ids,
|
||||
expert_ids=moe_input.expert_ids,
|
||||
num_tokens_post_padded=moe_input.num_tokens_post_padded,
|
||||
)
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
time_cost = min(time_cost, start_event.elapsed_time(end_event))
|
||||
return time_cost
|
||||
|
||||
|
||||
def load_topk_ids(topk_ids_dir, i: int):
|
||||
num_layers = 61
|
||||
dense_layers = 3
|
||||
moe_layers = num_layers - dense_layers
|
||||
return torch.load(
|
||||
f"{topk_ids_dir}/topk_ids_layer{i % moe_layers + dense_layers}_idx{i // moe_layers}.pt"
|
||||
)
|
||||
|
||||
|
||||
def benchmark_config(
|
||||
config: BenchmarkConfig,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
topk_ids_list,
|
||||
block_shape: List[int] = None,
|
||||
ep_size: int = 1,
|
||||
num_iters: int = 100,
|
||||
) -> float:
|
||||
ncu_enable = os.getenv("NCU_ENABLE", "0") == "1"
|
||||
if ncu_enable:
|
||||
num_iters = 1
|
||||
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
if use_int8_w8a16 or use_int8_w8a8:
|
||||
w1 = torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
),
|
||||
dtype=torch.int8,
|
||||
)
|
||||
w2 = torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
shard_intermediate_size // 2,
|
||||
),
|
||||
dtype=torch.int8,
|
||||
)
|
||||
elif use_int4_w4a16:
|
||||
w1 = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size // 2,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
w2 = torch.randint(
|
||||
0,
|
||||
255,
|
||||
(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
shard_intermediate_size // 4,
|
||||
),
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
else:
|
||||
w1 = torch.randn(
|
||||
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
|
||||
)
|
||||
w2 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
|
||||
)
|
||||
|
||||
w1_scale = None
|
||||
w2_scale = None
|
||||
a1_scale = None
|
||||
a2_scale = None
|
||||
if use_int8_w8a16:
|
||||
w1_scale = torch.randn(
|
||||
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
||||
if use_int4_w4a16:
|
||||
block_n = 1 if (block_shape[0] == 0) else block_shape[0]
|
||||
block_k = block_shape[1]
|
||||
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
|
||||
n_tiles_w2 = (hidden_size + block_n - 1) // block_n
|
||||
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
|
||||
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
|
||||
w1_scale = torch.randn(
|
||||
(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.bfloat16
|
||||
)
|
||||
w2_scale = torch.randn(
|
||||
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.bfloat16
|
||||
)
|
||||
if use_fp8_w8a8 or use_int8_w8a8:
|
||||
if use_int8_w8a8 and block_shape is None:
|
||||
w1_scale = torch.randn(
|
||||
num_experts, shard_intermediate_size, dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32)
|
||||
elif block_shape is None:
|
||||
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
a1_scale = torch.randn(1, dtype=torch.float32)
|
||||
a2_scale = torch.randn(1, dtype=torch.float32)
|
||||
else:
|
||||
block_n, block_k = block_shape[0], block_shape[1]
|
||||
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
|
||||
n_tiles_w2 = (hidden_size + block_n - 1) // block_n
|
||||
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
|
||||
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
|
||||
w1_scale = torch.rand(
|
||||
(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.rand(
|
||||
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
|
||||
)
|
||||
|
||||
if use_fp8_w8a8:
|
||||
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
||||
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
topk_config = TopKConfig(
|
||||
top_k=topk,
|
||||
renormalize=True,
|
||||
)
|
||||
topk_output_ = select_experts(hidden_states, input_gating, topk_config)
|
||||
sorted_token_ids_, expert_ids_, num_tokens_post_padded_ = moe_align_block_size(
|
||||
topk_output_.topk_ids, config["BLOCK_SIZE_M"], num_experts
|
||||
)
|
||||
inner_iter = 10 if not ncu_enable else 1
|
||||
moe_inputs = [
|
||||
MoeInputs(
|
||||
topk_output_.topk_ids.clone(),
|
||||
sorted_token_ids_.clone(),
|
||||
expert_ids_.clone(),
|
||||
num_tokens_post_padded_.clone(),
|
||||
)
|
||||
for _ in range(inner_iter)
|
||||
]
|
||||
M = hidden_states.shape[0]
|
||||
E, N, _ = w1.shape
|
||||
|
||||
padded_tokens = min(M * topk, E + 1) * (
|
||||
config["BLOCK_SIZE_M"] - 1
|
||||
) # if moe_use_tma else 0
|
||||
total_tokens = M * topk + padded_tokens
|
||||
cache = torch.empty(
|
||||
total_tokens * max(N, w2.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache1 = cache[: total_tokens * N].view(
|
||||
(total_tokens, N),
|
||||
)
|
||||
intermediate_cache2 = torch.empty(
|
||||
(total_tokens, N // 2),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache3 = cache[: M * topk * w2.shape[1]].view(
|
||||
(M, topk, w2.shape[1]),
|
||||
)
|
||||
|
||||
def prepare(i: int, inner_iter): # update inputs according to topk_ids
|
||||
for k in range(inner_iter):
|
||||
topk_ids = topk_ids_list[i * inner_iter + k]
|
||||
# With EP, saved topk_ids are global expert indices; remap to local.
|
||||
if ep_size > 1:
|
||||
topk_ids = (topk_ids // ep_size).to(
|
||||
device=moe_inputs[k].topk_ids.device,
|
||||
dtype=moe_inputs[k].topk_ids.dtype,
|
||||
)
|
||||
tokens, _topk = moe_inputs[k].topk_ids.shape
|
||||
moe_inputs[k].topk_ids.copy_(topk_ids[:tokens, :_topk])
|
||||
sorted_token_ids_, expert_ids_, num_tokens_post_padded_ = (
|
||||
moe_align_block_size(
|
||||
moe_inputs[k].topk_ids, config["BLOCK_SIZE_M"], num_experts
|
||||
)
|
||||
)
|
||||
moe_inputs[k].sorted_token_ids.copy_(sorted_token_ids_)
|
||||
moe_inputs[k].expert_ids.copy_(expert_ids_)
|
||||
moe_inputs[k].num_tokens_post_padded.copy_(num_tokens_post_padded_)
|
||||
|
||||
def get_kernel_wrapper(moe_use_tma, inner_iter, use_cuda_graph):
|
||||
compute_type = (
|
||||
tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
||||
)
|
||||
moe_runner_config = MoeRunnerConfig(
|
||||
inplace=True,
|
||||
)
|
||||
apply_router_weight_on_input = moe_runner_config.apply_router_weight_on_input
|
||||
kernel0 = KernelWrapper(
|
||||
A=hidden_states,
|
||||
B=w1,
|
||||
bias=None,
|
||||
C=intermediate_cache1,
|
||||
A_scale=a1_scale,
|
||||
B_scale=w1_scale,
|
||||
B_zp=None,
|
||||
topk_weights=topk_output_.topk_weights,
|
||||
moe_inputs=moe_inputs,
|
||||
mul_routed_weight=apply_router_weight_on_input,
|
||||
top_k=topk,
|
||||
config=config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_channel_quant=False,
|
||||
block_shape=block_shape,
|
||||
b_use_tma=moe_use_tma,
|
||||
c_sorted=moe_use_tma,
|
||||
filter_expert=False,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
inner_iter=inner_iter,
|
||||
)
|
||||
kernel1 = KernelWrapper(
|
||||
A=intermediate_cache2,
|
||||
B=w2,
|
||||
bias=None,
|
||||
C=intermediate_cache3,
|
||||
A_scale=a2_scale,
|
||||
B_scale=w2_scale,
|
||||
B_zp=None,
|
||||
topk_weights=topk_output_.topk_weights,
|
||||
moe_inputs=moe_inputs,
|
||||
mul_routed_weight=not apply_router_weight_on_input,
|
||||
top_k=1,
|
||||
config=config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_channel_quant=False,
|
||||
block_shape=block_shape,
|
||||
a_use_tma=moe_use_tma,
|
||||
b_use_tma=moe_use_tma,
|
||||
filter_expert=False,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
inner_iter=inner_iter,
|
||||
)
|
||||
return kernel0, kernel1
|
||||
|
||||
use_cuda_graph = True if not ncu_enable else False
|
||||
|
||||
kernel0, kernel1 = get_kernel_wrapper(False, inner_iter, use_cuda_graph)
|
||||
kernel_tma0, kernel_tma1 = get_kernel_wrapper(True, inner_iter, use_cuda_graph)
|
||||
|
||||
# JIT compilation & warmup
|
||||
if not ncu_enable:
|
||||
kernel0.forward_cost()
|
||||
kernel1.forward_cost()
|
||||
kernel_tma0.forward_cost()
|
||||
kernel_tma1.forward_cost()
|
||||
|
||||
ts0 = []
|
||||
ts1 = []
|
||||
ts_tma0 = []
|
||||
ts_tma1 = []
|
||||
|
||||
for i in range(num_iters // inner_iter):
|
||||
prepare(i, inner_iter)
|
||||
ts0.append(kernel0.forward_cost())
|
||||
ts1.append(kernel1.forward_cost())
|
||||
ts_tma0.append(kernel_tma0.forward_cost())
|
||||
ts_tma1.append(kernel_tma1.forward_cost())
|
||||
torch.cuda.synchronize()
|
||||
|
||||
avg = sum(ts0) / (num_iters) * 1000 # us
|
||||
avg1 = sum(ts1) / (num_iters) * 1000 # us
|
||||
avg_tma = sum(ts_tma0) / (num_iters) * 1000 # us
|
||||
avg1_tma = sum(ts_tma1) / (num_iters) * 1000 # us
|
||||
|
||||
return avg, avg_tma, avg1, avg1_tma
|
||||
|
||||
|
||||
class BestConfigTrace:
|
||||
def __init__(self, name, down_moe=False):
|
||||
self.name = name
|
||||
self.down_moe = down_moe
|
||||
self.best_costs_m = {} # block_m: best_cost
|
||||
|
||||
def update(self, config, time_cost_all):
|
||||
block_m = config["BLOCK_SIZE_M"]
|
||||
if not self.down_moe:
|
||||
time_cost = time_cost_all[0]
|
||||
else:
|
||||
time_cost = min(time_cost_all[2], time_cost_all[3])
|
||||
if (
|
||||
block_m not in self.best_costs_m
|
||||
or time_cost < self.best_costs_m[block_m][1]
|
||||
):
|
||||
self.best_costs_m[block_m] = config, time_cost, time_cost_all
|
||||
|
||||
def time_cost(self, block_m):
|
||||
if block_m not in self.best_costs_m:
|
||||
return float("inf")
|
||||
time_cost = self.best_costs_m[block_m][1]
|
||||
return time_cost
|
||||
|
||||
def config_dict(self, block_m):
|
||||
if block_m not in self.best_costs_m:
|
||||
return {}
|
||||
config, _, time_cost_all = self.best_costs_m[block_m]
|
||||
if not self.down_moe:
|
||||
return config
|
||||
else:
|
||||
return {
|
||||
**config,
|
||||
"USE_TMA": time_cost_all[2] > time_cost_all[3],
|
||||
}
|
||||
|
||||
|
||||
class BenchmarkWorker:
|
||||
|
||||
def __init__(self, seed: int, server_args: ServerArgs) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
self.seed = seed
|
||||
# Get the device ID to allocate tensors and kernels
|
||||
# on the respective GPU.
|
||||
self.device_id = 0 # int(ray.get_gpu_ids()[0])
|
||||
set_global_server_args_for_scheduler(server_args)
|
||||
|
||||
def benchmark(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
block_shape: List[int],
|
||||
cfg: Dict[str, int],
|
||||
topk_ids_dir: str,
|
||||
ep_size: int = 1,
|
||||
) -> Tuple[Dict[str, int], float]:
|
||||
torch.cuda.manual_seed_all(0)
|
||||
topk_ids_list = [load_topk_ids(topk_ids_dir, i) for i in range(100)]
|
||||
with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
|
||||
kernel_time = benchmark_config(
|
||||
cfg,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
topk_ids_list,
|
||||
block_shape,
|
||||
ep_size=ep_size,
|
||||
)
|
||||
return cfg, kernel_time
|
||||
|
||||
def tune(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
block_shape: List[int],
|
||||
search_space: List[Dict[str, int]],
|
||||
topk_ids_dir: str,
|
||||
ep_size: int = 1,
|
||||
) -> Dict[str, int]:
|
||||
trace0 = BestConfigTrace("kernel0", down_moe=False)
|
||||
trace1 = BestConfigTrace("kernel1", down_moe=True)
|
||||
topk_ids_list = [load_topk_ids(topk_ids_dir, i) for i in range(100)]
|
||||
|
||||
with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kt0_no_tma, kt0_tma, kt1_no_tma, kt1_tma = benchmark_config(
|
||||
config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
topk_ids_list,
|
||||
block_shape,
|
||||
ep_size=ep_size,
|
||||
num_iters=100,
|
||||
)
|
||||
except triton.runtime.autotuner.OutOfResources:
|
||||
# Some configurations may be invalid and fail to compile.
|
||||
continue
|
||||
trace0.update(
|
||||
config,
|
||||
(kt0_no_tma, kt0_tma, kt1_no_tma, kt1_tma),
|
||||
)
|
||||
trace1.update(
|
||||
config,
|
||||
(kt0_no_tma, kt0_tma, kt1_no_tma, kt1_tma),
|
||||
)
|
||||
|
||||
now = datetime.now()
|
||||
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
|
||||
best_block_m = 16
|
||||
for block_m in (32, 64, 128, 256):
|
||||
if trace0.time_cost(block_m) + trace1.time_cost(block_m) < trace0.time_cost(
|
||||
best_block_m
|
||||
) + trace1.time_cost(best_block_m):
|
||||
best_block_m = block_m
|
||||
|
||||
return (
|
||||
trace0.config_dict(best_block_m),
|
||||
trace1.config_dict(best_block_m),
|
||||
trace0.time_cost(best_block_m),
|
||||
trace1.time_cost(best_block_m),
|
||||
)
|
||||
|
||||
def cmp_configs(
|
||||
self,
|
||||
num_tokens: List[int],
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
block_shape: List[int],
|
||||
cmp_config_files: List[str],
|
||||
topk_ids_dir: str,
|
||||
ep_size: int = 1,
|
||||
):
|
||||
# compare performance of different configs
|
||||
cmp_configs = []
|
||||
for file in cmp_config_files:
|
||||
with open(file) as f:
|
||||
cmp_configs.append({int(key): val for key, val in json.load(f).items()})
|
||||
for i, file in enumerate(cmp_config_files):
|
||||
print(f"config {i}: {file}")
|
||||
|
||||
topk_ids_list = [load_topk_ids(topk_ids_dir, i) for i in range(100)]
|
||||
torch.cuda.manual_seed_all(0)
|
||||
with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
|
||||
for bs in num_tokens:
|
||||
kernel_times = []
|
||||
cfgs = []
|
||||
for configs in cmp_configs:
|
||||
cfg_org = configs[min(configs.keys(), key=lambda x: abs(x - bs))]
|
||||
cfgs.append(cfg_org)
|
||||
cfg = cfg_org.copy()
|
||||
cfg.pop("USE_TMA", None)
|
||||
kernel_time = benchmark_config(
|
||||
cfg,
|
||||
bs,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
topk_ids_list,
|
||||
block_shape,
|
||||
ep_size=ep_size,
|
||||
)
|
||||
kernel_times.append(kernel_time)
|
||||
print(f"batch_size={bs=}:")
|
||||
for i, cfg in enumerate(cfgs):
|
||||
print(f" config {i} {cfg}: {kernel_times[i]}")
|
||||
|
||||
|
||||
def save_configs_sep(
|
||||
configs: Dict[int, BenchmarkConfig],
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
block_shape: List[int],
|
||||
down_moe: bool = False,
|
||||
) -> None:
|
||||
dtype_str = get_config_dtype_str(
|
||||
dtype,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||
# is the intermediate size after silu_and_mul.
|
||||
filename = get_config_file_name(
|
||||
num_experts,
|
||||
shard_intermediate_size // 2,
|
||||
dtype_str,
|
||||
block_shape,
|
||||
down_moe=down_moe,
|
||||
)
|
||||
|
||||
print(f"Writing best config to {filename}...")
|
||||
with open(filename, "w") as f:
|
||||
json.dump(configs, f, indent=4)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
server_args = ServerArgs(
|
||||
model_path=args.model, tp_size=args.tp_size, ep_size=args.ep_size
|
||||
)
|
||||
|
||||
model_config = get_model_config(
|
||||
args.model,
|
||||
args.tp_size,
|
||||
args.ep_size,
|
||||
args.disable_shared_experts_fusion,
|
||||
args.topk_ids_dir,
|
||||
)
|
||||
|
||||
E = model_config["num_experts"]
|
||||
topk = model_config["topk"]
|
||||
hidden_size = model_config["hidden_size"]
|
||||
shard_intermediate_size = model_config["shard_intermediate_size"]
|
||||
dtype = model_config["dtype"]
|
||||
block_shape = model_config["block_shape"]
|
||||
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a8 = args.dtype == "int8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
use_int4_w4a16 = args.dtype == "int4_w4a16"
|
||||
|
||||
topk_ids_dir = args.topk_ids_dir
|
||||
if args.batch_size is None:
|
||||
batch_sizes = get_default_batch_sizes()
|
||||
batch_sizes.reverse()
|
||||
else:
|
||||
batch_sizes = [args.batch_size]
|
||||
|
||||
if args.cmp_configs is not None:
|
||||
worker = BenchmarkWorker(args.seed, server_args)
|
||||
worker.cmp_configs(
|
||||
batch_sizes,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
block_shape,
|
||||
args.cmp_configs,
|
||||
topk_ids_dir,
|
||||
args.ep_size,
|
||||
)
|
||||
return
|
||||
|
||||
if len(batch_sizes) == 1:
|
||||
worker = BenchmarkWorker(args.seed, server_args)
|
||||
if args.tune:
|
||||
search_space = get_configs_compute_bound()
|
||||
worker.tune(
|
||||
batch_sizes[0],
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
block_shape,
|
||||
search_space,
|
||||
topk_ids_dir,
|
||||
args.ep_size,
|
||||
)
|
||||
else:
|
||||
cfg = {
|
||||
"BLOCK_SIZE_M": args.configs[0],
|
||||
"BLOCK_SIZE_N": args.configs[1],
|
||||
"BLOCK_SIZE_K": args.configs[2],
|
||||
"GROUP_SIZE_M": args.configs[3],
|
||||
"num_warps": args.configs[4],
|
||||
"num_stages": args.configs[5],
|
||||
}
|
||||
|
||||
_, (t0, t0_tma, t1, t1_tma) = worker.benchmark(
|
||||
args.batch_size,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
block_shape,
|
||||
cfg,
|
||||
topk_ids_dir,
|
||||
args.ep_size,
|
||||
)
|
||||
print(f"{t0=}, {t0_tma=}, {t1=}, {t1_tma=}")
|
||||
return
|
||||
|
||||
assert args.tune
|
||||
|
||||
ray.init()
|
||||
num_gpus = int(ray.available_resources()["GPU"])
|
||||
workers = [
|
||||
ray.remote(num_gpus=1)(BenchmarkWorker).remote(args.seed, server_args)
|
||||
for _ in range(num_gpus)
|
||||
]
|
||||
|
||||
def _distribute(method: str, inputs: List[Any]) -> List[Any]:
|
||||
outputs = []
|
||||
worker_idx = 0
|
||||
for input_args in inputs:
|
||||
worker = workers[worker_idx]
|
||||
worker_method = getattr(worker, method)
|
||||
output = worker_method.remote(*input_args)
|
||||
outputs.append(output)
|
||||
worker_idx = (worker_idx + 1) % num_gpus
|
||||
return ray.get(outputs)
|
||||
|
||||
search_space = get_configs_compute_bound()
|
||||
if block_shape is not None:
|
||||
block_n, block_k = block_shape[0], block_shape[1]
|
||||
search_space = [
|
||||
config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
|
||||
]
|
||||
filename = get_config_filename(
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
False,
|
||||
block_shape,
|
||||
)
|
||||
print(
|
||||
f"Start tuning over {len(search_space)} configurations to create {filename}..."
|
||||
)
|
||||
|
||||
start = time.perf_counter()
|
||||
configs = _distribute(
|
||||
"tune",
|
||||
[
|
||||
(
|
||||
batch_size,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
block_shape,
|
||||
search_space,
|
||||
topk_ids_dir,
|
||||
args.ep_size,
|
||||
)
|
||||
for batch_size in batch_sizes
|
||||
],
|
||||
)
|
||||
print(f"{configs=}", flush=True)
|
||||
cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
with open(f"tuning_result_{cur_time}.txt", "w") as f:
|
||||
print(configs, file=f)
|
||||
batch_sizes.reverse()
|
||||
configs0 = [config[0] for config in configs]
|
||||
configs1 = [config[1] for config in configs]
|
||||
configs0.reverse()
|
||||
configs1.reverse()
|
||||
best_configs0 = {M: sort_config(config) for M, config in zip(batch_sizes, configs0)}
|
||||
save_configs_sep(
|
||||
best_configs0,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
block_shape,
|
||||
)
|
||||
|
||||
best_configs1 = {M: sort_config(config) for M, config in zip(batch_sizes, configs1)}
|
||||
save_configs_sep(
|
||||
best_configs1,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
block_shape,
|
||||
down_moe=True,
|
||||
)
|
||||
end = time.perf_counter()
|
||||
print(f"Tuning took {end - start:.2f} seconds")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
)
|
||||
parser.add_argument("--tp-size", "--tp", type=int, default=2)
|
||||
parser.add_argument("--ep-size", "--ep", type=int, default=1)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8", "int8_w4a16"],
|
||||
default="auto",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument("--tune", action="store_true")
|
||||
parser.add_argument("--disable-shared-experts-fusion", action="store_true")
|
||||
parser.add_argument("--configs", type=int, nargs="+", required=False)
|
||||
parser.add_argument("--topk-ids-dir", type=str, required=True)
|
||||
parser.add_argument("--cmp-configs", type=str, nargs="+", required=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
1
third_party/sglang/benchmark/kernels/fused_moe_triton/tuning_text.json
vendored
Normal file
1
third_party/sglang/benchmark/kernels/fused_moe_triton/tuning_text.json
vendored
Normal file
File diff suppressed because one or more lines are too long
92
third_party/sglang/benchmark/kernels/quantization/README.md
vendored
Normal file
92
third_party/sglang/benchmark/kernels/quantization/README.md
vendored
Normal file
@@ -0,0 +1,92 @@
|
||||
# W8A8 Block-wise Quantization Kernel Tuning
|
||||
|
||||
Auto-tune Triton FP8/INT8 block-wise quantization kernels for optimal performance.
|
||||
|
||||
## When to Use Triton FP8 Block-wise Quantization Kernel vs DeepGEMM
|
||||
|
||||
**Use Triton FP8 Block-wise Quantization Kernel when:**
|
||||
- Output dtype is NOT `bfloat16` (e.g., `float16`, `float32`)
|
||||
- DeepGEMM is disabled (environment variable `SGLANG_ENABLE_JIT_DEEPGEMM=0`)
|
||||
- Running on GPUs with compute capability < SM90 (DeepGEMM requires SM90+)
|
||||
- You need cross-platform compatibility (Triton works on both NVIDIA and AMD GPUs)
|
||||
|
||||
**Use DeepGEMM when:**
|
||||
- Output dtype is `bfloat16` AND DeepGEMM is enabled
|
||||
- Running on NVIDIA GPUs with compute capability >= SM90 (e.g., H100, H200)
|
||||
- Need maximum performance for production workloads (DeepGEMM is highly optimized for Hopper architecture)
|
||||
|
||||
**Note:** DeepGEMM requires CUDA compute capability >= 9.0 (SM90+). It is specifically optimized for NVIDIA Hopper GPUs (H100/H200).
|
||||
|
||||
The kernel selection logic in SGLang automatically chooses DeepGEMM when conditions are met (see `w8a8_block_fp8_matmul` function in `fp8_kernel.py`), otherwise falls back to Triton implementation.
|
||||
|
||||
## Quick Start
|
||||
|
||||
**Default (DeepSeek-V3):**
|
||||
```bash
|
||||
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --tp-size 8
|
||||
```
|
||||
|
||||
**Custom Model (specify N and K):**
|
||||
```bash
|
||||
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 25600
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- `--N`, `--K`: Weight matrix dimensions (N=output_dim, K=input_dim). If not specified, uses `--tp-size` for DeepSeek-V3
|
||||
- `--tp-size`: Tensor parallelism size for DeepSeek-V3 (default: 8)
|
||||
- `--input-type`: `fp8` or `int8` (default: fp8)
|
||||
- `--block-n`, `--block-k`: Block quantization granularity (default: 128)
|
||||
- `--batch-size`: Test single batch size (optional)
|
||||
|
||||
## How to Calculate N and K
|
||||
|
||||
For a linear layer `y = xW^T` where `x` is (M, K) and `W` is (N, K):
|
||||
- **N**: Output features (weight matrix output dimension)
|
||||
- **K**: Input features (weight matrix input dimension)
|
||||
|
||||
**Example: Qwen3-VL-32B** (hidden_size=5120, intermediate_size=25600, num_heads=64, num_kv_heads=8, head_dim=128) and TP=1
|
||||
```bash
|
||||
# QKV projection: Q(8192) + K(1024) + V(1024) = 10240
|
||||
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 10240 --K 5120
|
||||
|
||||
# MLP gate+up (SwiGLU): 2 * intermediate_size = 51200
|
||||
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 51200 --K 5120
|
||||
|
||||
# MLP down projection
|
||||
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 25600
|
||||
|
||||
# O projection (if separate from QKV)
|
||||
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 8192
|
||||
```
|
||||
|
||||
If TP=8:
|
||||
|
||||
```bash
|
||||
# QKV projection: Q(8192) + K(1024) + V(1024) = 10240 / TP=8
|
||||
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 1280 --K 5120
|
||||
|
||||
# MLP gate+up (SwiGLU): 2 * intermediate_size = 51200 / TP=8
|
||||
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 6400 --K 5120
|
||||
|
||||
# MLP down projection
|
||||
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 3200
|
||||
|
||||
# O projection (if separate from QKV)
|
||||
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 1024
|
||||
```
|
||||
|
||||
## Output
|
||||
|
||||
Generates JSON config files saved to `python/sglang/srt/layers/quantization/configs/`:
|
||||
```
|
||||
N={N},K={K},device_name={DEVICE},dtype=fp8_w8a8,block_shape=[128,128].json
|
||||
```
|
||||
|
||||
Config maps batch size to optimal kernel parameters:
|
||||
```json
|
||||
{
|
||||
"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, ...},
|
||||
"2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, ...}
|
||||
}
|
||||
```
|
||||
137
third_party/sglang/benchmark/kernels/quantization/bench_fp4_quant.py
vendored
Normal file
137
third_party/sglang/benchmark/kernels/quantization/bench_fp4_quant.py
vendored
Normal file
@@ -0,0 +1,137 @@
|
||||
import argparse
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from flashinfer import (
|
||||
scaled_fp4_grouped_quantize,
|
||||
silu_and_mul_scaled_nvfp4_experts_quantize,
|
||||
)
|
||||
from sgl_kernel.elementwise import silu_and_mul
|
||||
|
||||
from sglang.benchmark.bench_utils import run_bench
|
||||
from sglang.srt.layers import deep_gemm_wrapper
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
|
||||
|
||||
|
||||
def _test_accuracy_once(E, M, K, input_dtype, device):
|
||||
x = torch.randn(E, M, K, device=device, dtype=input_dtype)
|
||||
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
|
||||
masks = torch.full((E,), M, dtype=torch.int32, device=device)
|
||||
out, blk_scales = silu_and_mul_scaled_nvfp4_experts_quantize(x, masks, glb_scales)
|
||||
out1, blk_scales1 = scaled_fp4_grouped_quantize(
|
||||
silu_and_mul(x),
|
||||
masks,
|
||||
glb_scales,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, out1)
|
||||
torch.testing.assert_close(blk_scales, blk_scales1)
|
||||
print(f"E: {E}, M: {M}, K: {K}, type: {input_dtype} OK")
|
||||
|
||||
|
||||
NUM_RANKS = 48
|
||||
M_PER_RANKs = [128, 256, 512, 1024]
|
||||
Ms = [M_PER_RANK * NUM_RANKS for M_PER_RANK in M_PER_RANKs]
|
||||
Ks = [2048, 4096, 7168]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["M", "K"],
|
||||
x_vals=list(itertools.product(Ms, Ks)),
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
|
||||
line_names=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
|
||||
styles=[("blue", "-"), ("orange", "-"), ("green", "-")],
|
||||
ylabel="ms",
|
||||
plot_name="fp4 quant",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(M, K, provider):
|
||||
E = 6
|
||||
device = "cuda"
|
||||
x = torch.randn(E, M, K, device=device, dtype=torch.bfloat16)
|
||||
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
|
||||
masks = torch.randint(1, 4096, (E,), dtype=torch.int32, device=device)
|
||||
fp8_out = torch.empty(
|
||||
(
|
||||
x.shape[0],
|
||||
x.shape[1],
|
||||
x.shape[2] // 2,
|
||||
),
|
||||
device=x.device,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
)
|
||||
scale_block_size = 128
|
||||
fp8_scales = torch.empty(
|
||||
(
|
||||
x.shape[0],
|
||||
x.shape[1],
|
||||
x.shape[2] // 2 // scale_block_size,
|
||||
),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
quantiles = (0.5, 0.2, 0.8)
|
||||
if provider == "triton_fp8":
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: silu_and_mul_masked_post_quant_fwd(
|
||||
x,
|
||||
fp8_out,
|
||||
fp8_scales,
|
||||
scale_block_size,
|
||||
masks,
|
||||
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "cuda_unfused_fp4":
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: scaled_fp4_grouped_quantize(
|
||||
silu_and_mul(x),
|
||||
masks,
|
||||
glb_scales,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "cuda_fused_fp4":
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: silu_and_mul_scaled_nvfp4_experts_quantize(
|
||||
x,
|
||||
masks,
|
||||
glb_scales,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return ms, min_ms, max_ms
|
||||
|
||||
|
||||
def test_accuracy():
|
||||
E = 6
|
||||
N_RANKS = 48
|
||||
Ms = [128, 256, 512, 1024]
|
||||
Ks = [2048, 4096, 7168]
|
||||
input_dtype = torch.bfloat16
|
||||
for M in Ms:
|
||||
for K in Ks:
|
||||
_test_accuracy_once(E, N_RANKS * M, K, input_dtype, "cuda")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./bench_fp4_quant_res",
|
||||
help="Path to save fp4 quant benchmark results",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
test_accuracy()
|
||||
|
||||
benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
|
||||
95
third_party/sglang/benchmark/kernels/quantization/bench_int8_quant.py
vendored
Normal file
95
third_party/sglang/benchmark/kernels/quantization/bench_int8_quant.py
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant
|
||||
|
||||
from sglang.benchmark.bench_utils import run_bench
|
||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||
|
||||
|
||||
@torch.compile(backend="inductor")
|
||||
def torch_int8_quant(x):
|
||||
int8_max = torch.iinfo(torch.int8).max
|
||||
|
||||
abs_max = x.abs().max(dim=-1, keepdim=True).values
|
||||
scales = abs_max.to(torch.float32) / float(int8_max)
|
||||
|
||||
q_x = (x / scales).round().to(torch.int8)
|
||||
|
||||
return q_x, scales
|
||||
|
||||
|
||||
def _test_accuracy_once(M, K, input_dtype, device):
|
||||
x = torch.randn(M, K, dtype=input_dtype, device=device) * 5000
|
||||
out, scales, _ = vllm_scaled_int8_quant(x, symmetric=True)
|
||||
out1, scales1 = per_token_quant_int8(x)
|
||||
out2, scales2 = torch_int8_quant(x)
|
||||
torch.testing.assert_close(out, out2, atol=1, rtol=0)
|
||||
torch.testing.assert_close(out, out1, atol=1, rtol=0)
|
||||
torch.testing.assert_close(scales, scales2)
|
||||
torch.testing.assert_close(scales1, scales2)
|
||||
print(f"M: {M}, K: {K}, type: {input_dtype} OK")
|
||||
|
||||
|
||||
def test_accuracy():
|
||||
Ms = [1, 13, 128, 1024, 2048, 4096]
|
||||
Ks = [512, 1024, 2048, 8192]
|
||||
input_dtypes = [torch.float16, torch.bfloat16]
|
||||
for M in Ms:
|
||||
for K in Ks:
|
||||
for input_dtype in input_dtypes:
|
||||
_test_accuracy_once(M, K, input_dtype, "cuda")
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=["vllm op", "triton", "torch.compile"],
|
||||
line_names=["vllm op", "triton", "torch.compile"],
|
||||
styles=[("blue", "-"), ("orange", "-"), ("red", "-")],
|
||||
ylabel="ms",
|
||||
plot_name="int8 per token quant",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider):
|
||||
M, K = batch_size, 16384
|
||||
x = torch.randn(M, K, dtype=torch.float16, device="cuda") * 1000
|
||||
|
||||
quantiles = (0.5, 0.2, 0.8)
|
||||
if provider == "vllm op":
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: vllm_scaled_int8_quant(x, symmetric=True),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "triton":
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: per_token_quant_int8(x),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "torch.compile":
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: torch_int8_quant(x),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return ms, min_ms, max_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./bench_int8_quant_res",
|
||||
help="Path to save int8 quant benchmark results",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
test_accuracy()
|
||||
|
||||
benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
|
||||
527
third_party/sglang/benchmark/kernels/quantization/tuning_block_wise_kernel.py
vendored
Normal file
527
third_party/sglang/benchmark/kernels/quantization/tuning_block_wise_kernel.py
vendored
Normal file
@@ -0,0 +1,527 @@
|
||||
# Copyright 2025 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from tqdm import tqdm
|
||||
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
_w8a8_block_fp8_matmul,
|
||||
_w8a8_block_fp8_matmul_unrolledx4,
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul
|
||||
from sglang.srt.utils import (
|
||||
get_device,
|
||||
get_device_core_count,
|
||||
get_device_count,
|
||||
get_device_name,
|
||||
is_hip,
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
DTYPE_MAP = {
|
||||
"float32": torch.float32,
|
||||
"float16": torch.float16,
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
def w8a8_block_matmul(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: List[int],
|
||||
config: Dict[str, Any],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs matrix multiplication with block-wise quantization.
|
||||
|
||||
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
||||
The output is returned in the specified `output_dtype`.
|
||||
|
||||
Args:
|
||||
A: The input tensor, e.g., activation.
|
||||
B: The input tensor, e.g., weight.
|
||||
As: The per-token-group quantization scale for `A`.
|
||||
Bs: The per-block quantization scale for `B`.
|
||||
block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
|
||||
output_dytpe: The dtype of the returned tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The result of matmul.
|
||||
"""
|
||||
assert len(block_size) == 2
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
|
||||
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||
M = A.numel() // A.shape[-1]
|
||||
|
||||
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||
N, K = B.shape
|
||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||
|
||||
C_shape = A.shape[:-1] + (N,)
|
||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||
|
||||
needs_masking = bool(K % config["BLOCK_SIZE_K"] != 0)
|
||||
|
||||
def grid(META):
|
||||
return (
|
||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
|
||||
# Empirical testing shows the sweet spot lies when it's less than the # of
|
||||
# compute units available on the device.
|
||||
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
|
||||
N, config["BLOCK_SIZE_N"]
|
||||
)
|
||||
|
||||
if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn:
|
||||
kernel = (
|
||||
_w8a8_block_fp8_matmul_unrolledx4
|
||||
if (_is_hip == True and num_workgroups <= get_device_core_count())
|
||||
else _w8a8_block_fp8_matmul
|
||||
)
|
||||
else:
|
||||
kernel = _w8a8_block_int8_matmul
|
||||
|
||||
kernel[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
As,
|
||||
Bs,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
block_n,
|
||||
block_k,
|
||||
A.stride(-2),
|
||||
A.stride(-1),
|
||||
B.stride(1),
|
||||
B.stride(0),
|
||||
C.stride(-2),
|
||||
C.stride(-1),
|
||||
As.stride(-2),
|
||||
As.stride(-1),
|
||||
Bs.stride(1),
|
||||
Bs.stride(0),
|
||||
**config,
|
||||
needs_masking=needs_masking,
|
||||
)
|
||||
|
||||
return C
|
||||
|
||||
|
||||
def get_rocm_configs_compute_bound():
|
||||
configs = []
|
||||
waves_per_eu_range = 0
|
||||
for num_stages in [2]:
|
||||
for block_m in [32, 64, 128, 256]:
|
||||
for block_k in [32, 64, 128, 256]:
|
||||
for block_n in [16, 32, 64, 128, 256]:
|
||||
for num_warps in [4, 8]:
|
||||
for group_size in [1, 4, 8, 16, 32]:
|
||||
configs.append(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
"BLOCK_SIZE_N": block_n,
|
||||
"BLOCK_SIZE_K": block_k,
|
||||
"GROUP_SIZE_M": group_size,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
"waves_per_eu": waves_per_eu_range,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
def get_configs_compute_bound():
|
||||
configs = []
|
||||
if _is_hip:
|
||||
configs = get_rocm_configs_compute_bound()
|
||||
else:
|
||||
for num_stages in [2, 3, 4, 5]:
|
||||
for block_m in [16, 32, 64, 128, 256]:
|
||||
for block_k in [64, 128]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
for num_warps in [4, 8]:
|
||||
for group_size in [1, 16, 32, 64]:
|
||||
configs.append(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
"BLOCK_SIZE_N": block_n,
|
||||
"BLOCK_SIZE_K": block_k,
|
||||
"GROUP_SIZE_M": group_size,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
def get_weight_shapes(tp_size):
|
||||
# NOTE(HandH1998): The weight shapes only works for DeepSeek-V3. Modify them, if you tune for another different model.
|
||||
# cannot TP
|
||||
total = [
|
||||
(512 + 64, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(7168, 16384),
|
||||
(7168, 18432),
|
||||
]
|
||||
# N can TP
|
||||
n_tp = [
|
||||
(18432 * 2, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(24576, 1536),
|
||||
(4096, 7168),
|
||||
]
|
||||
# K can TP
|
||||
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
|
||||
|
||||
weight_shapes = []
|
||||
for t in total:
|
||||
weight_shapes.append(t)
|
||||
for n_t in n_tp:
|
||||
new_t = (n_t[0] // tp_size, n_t[1])
|
||||
weight_shapes.append(new_t)
|
||||
for k_t in k_tp:
|
||||
new_t = (k_t[0], k_t[1] // tp_size)
|
||||
weight_shapes.append(new_t)
|
||||
return weight_shapes
|
||||
|
||||
|
||||
def benchmark_config(
|
||||
A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
|
||||
):
|
||||
def run():
|
||||
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
|
||||
|
||||
torch.get_device_module().synchronize()
|
||||
# JIT complication & warmup
|
||||
for _ in range(5):
|
||||
run()
|
||||
torch.get_device_module().synchronize()
|
||||
|
||||
start_event = torch.get_device_module().Event(enable_timing=True)
|
||||
end_event = torch.get_device_module().Event(enable_timing=True)
|
||||
|
||||
latencies: List[float] = []
|
||||
for i in range(num_iters):
|
||||
torch.get_device_module().synchronize()
|
||||
start_event.record()
|
||||
run()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
latencies.append(start_event.elapsed_time(end_event))
|
||||
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||
return avg
|
||||
|
||||
|
||||
def tune(M, N, K, block_size, out_dtype, search_space, input_type):
|
||||
factor_for_scale = 1e-2
|
||||
device = get_device()
|
||||
|
||||
if input_type == "fp8":
|
||||
fp8_info = torch.finfo(
|
||||
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
A_fp32 = (
|
||||
(torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
|
||||
)
|
||||
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
|
||||
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
B_fp32 = (
|
||||
(torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
|
||||
)
|
||||
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
|
||||
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
)
|
||||
else:
|
||||
int8_info = torch.iinfo(torch.int8)
|
||||
int8_max, int8_min = int8_info.max, int8_info.min
|
||||
|
||||
A_fp32 = (
|
||||
(torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * int8_max
|
||||
)
|
||||
A = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
B_fp32 = (
|
||||
(torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * int8_max
|
||||
)
|
||||
B = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
|
||||
As = torch.rand(M, k_tiles, dtype=torch.float32, device=device) * factor_for_scale
|
||||
Bs = (
|
||||
torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device)
|
||||
* factor_for_scale
|
||||
)
|
||||
|
||||
best_config = None
|
||||
best_time = float("inf")
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(
|
||||
A,
|
||||
B,
|
||||
As,
|
||||
Bs,
|
||||
block_size,
|
||||
config,
|
||||
out_dtype,
|
||||
num_iters=10,
|
||||
)
|
||||
except triton.runtime.autotuner.OutOfResources:
|
||||
# Some configurations may be invalid and fail to compile.
|
||||
continue
|
||||
|
||||
if kernel_time < best_time:
|
||||
best_time = kernel_time
|
||||
best_config = config
|
||||
now = datetime.now()
|
||||
print(f"{now.ctime()}] Completed tuning for batch_size={M}")
|
||||
assert best_config is not None
|
||||
return best_config
|
||||
|
||||
|
||||
def save_configs(
|
||||
N,
|
||||
K,
|
||||
block_n,
|
||||
block_k,
|
||||
configs,
|
||||
save_path,
|
||||
input_type="fp8",
|
||||
lock=None,
|
||||
) -> None:
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
device_name = get_device_name().replace(" ", "_")
|
||||
json_file_name = f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,block_shape=[{block_n}, {block_k}].json"
|
||||
|
||||
config_file_path = os.path.join(save_path, json_file_name)
|
||||
print(f"Writing best config to {config_file_path}...")
|
||||
|
||||
if lock is not None:
|
||||
lock.acquire()
|
||||
try:
|
||||
existing_configs = {}
|
||||
if os.path.exists(config_file_path):
|
||||
with open(config_file_path, "r") as f:
|
||||
existing_configs = json.load(f)
|
||||
existing_configs = {int(k): v for k, v in existing_configs.items()}
|
||||
|
||||
existing_configs.update(configs)
|
||||
|
||||
with open(config_file_path, "w") as f:
|
||||
json.dump(existing_configs, f, indent=4)
|
||||
f.write("\n")
|
||||
finally:
|
||||
if lock is not None:
|
||||
lock.release()
|
||||
|
||||
|
||||
def tune_on_gpu(args_dict):
|
||||
"""Run tuning on a specific GPU."""
|
||||
gpu_id = args_dict["gpu_id"]
|
||||
batch_sizes = args_dict["batch_sizes"]
|
||||
weight_shapes = args_dict["weight_shapes"]
|
||||
args = args_dict["args"]
|
||||
lock = args_dict["lock"]
|
||||
|
||||
torch.get_device_module().set_device(gpu_id)
|
||||
print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}")
|
||||
|
||||
block_n = args.block_n
|
||||
block_k = args.block_k
|
||||
out_dtype = DTYPE_MAP[args.out_dtype]
|
||||
save_path = args.save_path
|
||||
input_type = args.input_type
|
||||
|
||||
search_space = get_configs_compute_bound()
|
||||
search_space = [
|
||||
config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
|
||||
]
|
||||
|
||||
start = time.perf_counter()
|
||||
results = {}
|
||||
for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"):
|
||||
N, K = shape[0], shape[1]
|
||||
print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`")
|
||||
benchmark_results = [
|
||||
tune(
|
||||
batch_size,
|
||||
N,
|
||||
K,
|
||||
[block_n, block_k],
|
||||
out_dtype,
|
||||
search_space,
|
||||
input_type,
|
||||
)
|
||||
for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
|
||||
]
|
||||
best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
|
||||
save_configs(N, K, block_n, block_k, best_configs, save_path, input_type, lock)
|
||||
|
||||
end = time.perf_counter()
|
||||
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
|
||||
|
||||
|
||||
def distribute_batch_sizes(batch_sizes, num_gpus):
|
||||
"""Distribute batch sizes across available GPUs."""
|
||||
batches_per_gpu = []
|
||||
for i in range(num_gpus):
|
||||
start_idx = i * len(batch_sizes) // num_gpus
|
||||
end_idx = (i + 1) * len(batch_sizes) // num_gpus
|
||||
batches_per_gpu.append(batch_sizes[start_idx:end_idx])
|
||||
return batches_per_gpu
|
||||
|
||||
|
||||
def main(args):
|
||||
print(args)
|
||||
|
||||
num_gpus = get_device_count()
|
||||
if num_gpus == 0:
|
||||
raise RuntimeError("No GPU available for tuning")
|
||||
print(f"Found {num_gpus} GPUs for parallel tuning")
|
||||
|
||||
torch.get_device_module().init()
|
||||
|
||||
if args.batch_size is None:
|
||||
batch_sizes = [
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
24,
|
||||
32,
|
||||
48,
|
||||
64,
|
||||
96,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
1536,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
]
|
||||
else:
|
||||
batch_sizes = [args.batch_size]
|
||||
num_gpus = 1 # If only one batch size, use only one GPU
|
||||
|
||||
# Support manual N and K specification
|
||||
if args.N is not None and args.K is not None:
|
||||
weight_shapes = [(args.N, args.K)]
|
||||
print(f"Using manually specified weight shape: N={args.N}, K={args.K}")
|
||||
else:
|
||||
weight_shapes = get_weight_shapes(args.tp_size)
|
||||
print(f"Using predefined weight shapes for TP size {args.tp_size}")
|
||||
|
||||
batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus)
|
||||
|
||||
ctx = mp.get_context("spawn")
|
||||
manager = ctx.Manager()
|
||||
lock = manager.Lock()
|
||||
|
||||
process_args = []
|
||||
for gpu_id in range(num_gpus):
|
||||
process_args.append(
|
||||
{
|
||||
"gpu_id": gpu_id,
|
||||
"batch_sizes": batches_per_gpu[gpu_id],
|
||||
"weight_shapes": weight_shapes, # Each GPU processes all weight shapes
|
||||
"args": args,
|
||||
"lock": lock,
|
||||
}
|
||||
)
|
||||
|
||||
with ctx.Pool(num_gpus) as pool:
|
||||
pool.map(tune_on_gpu, process_args)
|
||||
|
||||
print("Multi-GPU tuning completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--tp-size",
|
||||
"-tp",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Tensor parallelism size (ignored if --N and --K are specified)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--N",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output dimension of weight matrix (number of columns)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--K",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Input dimension of weight matrix (number of rows)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-type", type=str, choices=["fp8", "int8"], default="fp8"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-dtype",
|
||||
type=str,
|
||||
choices=["float32", "float16", "bfloat16", "half"],
|
||||
default="float16",
|
||||
)
|
||||
parser.add_argument("--block-n", type=int, default=128)
|
||||
parser.add_argument("--block-k", type=int, default=128)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument(
|
||||
"--save-path", type=str, default="python/sglang/srt/layers/quantization/configs"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate arguments
|
||||
if (args.N is None) != (args.K is None):
|
||||
parser.error("--N and --K must be specified together or not at all")
|
||||
|
||||
main(args)
|
||||
171
third_party/sglang/benchmark/kernels/scheduler_batch/benchmark_get_last_loc_triton.py
vendored
Normal file
171
third_party/sglang/benchmark/kernels/scheduler_batch/benchmark_get_last_loc_triton.py
vendored
Normal file
@@ -0,0 +1,171 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.benchmark.bench_utils import run_bench
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def get_last_loc_torch(
|
||||
req_to_token: torch.Tensor,
|
||||
req_pool_indices_tensor: torch.Tensor,
|
||||
prefix_lens_tensor: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return torch.where(
|
||||
prefix_lens_tensor > 0,
|
||||
req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
|
||||
torch.full_like(prefix_lens_tensor, -1),
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def get_last_loc_kernel(
|
||||
req_to_token,
|
||||
req_pool_indices_tensor,
|
||||
prefix_lens_tensor,
|
||||
result,
|
||||
num_tokens,
|
||||
req_to_token_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
|
||||
mask = offset < num_tokens
|
||||
|
||||
prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
|
||||
req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
|
||||
|
||||
token_mask = prefix_lens > 0
|
||||
token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
|
||||
tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
|
||||
|
||||
tl.store(result + offset, tokens, mask=mask)
|
||||
|
||||
|
||||
def get_last_loc_triton(
|
||||
req_to_token: torch.Tensor,
|
||||
req_pool_indices_tensor: torch.Tensor,
|
||||
prefix_lens_tensor: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
BLOCK_SIZE = 256
|
||||
num_tokens = prefix_lens_tensor.shape[0]
|
||||
result = torch.empty_like(prefix_lens_tensor)
|
||||
grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
|
||||
|
||||
get_last_loc_kernel[grid](
|
||||
req_to_token,
|
||||
req_pool_indices_tensor,
|
||||
prefix_lens_tensor,
|
||||
result,
|
||||
num_tokens,
|
||||
req_to_token.stride(0),
|
||||
BLOCK_SIZE,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def test_get_last_loc():
|
||||
max_batch = 4097
|
||||
max_context_len = 6148
|
||||
batch_size = 20
|
||||
|
||||
# Initialize input tensors
|
||||
req_to_token = torch.zeros(
|
||||
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
|
||||
pre_lens = torch.randint(
|
||||
-max_context_len // 2,
|
||||
max_context_len,
|
||||
(batch_size,),
|
||||
dtype=torch.int64,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
last_loc_res = get_last_loc_triton(req_to_token, req_pool_indices, pre_lens)
|
||||
last_loc_ref = get_last_loc_torch(req_to_token, req_pool_indices, pre_lens)
|
||||
|
||||
# Compare results
|
||||
torch.testing.assert_close(last_loc_res, last_loc_ref)
|
||||
|
||||
|
||||
def get_benchmark():
|
||||
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=batch_sizes,
|
||||
line_arg="provider",
|
||||
line_vals=["reference", "triton"],
|
||||
line_names=["PyTorch", "Triton"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="get-last-loc-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider):
|
||||
max_batch = 2048
|
||||
max_context_len = 16384
|
||||
|
||||
req_to_token = torch.zeros(
|
||||
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
|
||||
pre_lens = torch.randint(
|
||||
-max_context_len // 2,
|
||||
max_context_len,
|
||||
(batch_size,),
|
||||
dtype=torch.int64,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "reference":
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: get_last_loc_torch(req_to_token, req_pool_indices, pre_lens),
|
||||
quantiles=tuple(quantiles),
|
||||
)
|
||||
elif provider == "triton":
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: get_last_loc_triton(req_to_token, req_pool_indices, pre_lens),
|
||||
quantiles=tuple(quantiles),
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
def run_benchmark(save_path: str = "./configs/benchmark_ops/get_last_loc/"):
|
||||
"""Run benchmark and save results"""
|
||||
|
||||
# Ensure save path exists
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
# Run correctness test
|
||||
test_get_last_loc()
|
||||
print("Correctness test passed!")
|
||||
|
||||
# Run performance test
|
||||
benchmark = get_benchmark()
|
||||
benchmark.run(print_data=True, save_path=save_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/get_last_loc/",
|
||||
help="Path to save benchmark results",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
run_benchmark(args.save_path)
|
||||
342
third_party/sglang/benchmark/kernels/scheduler_batch/benchmark_write_req_to_token_pool_triton.py
vendored
Normal file
342
third_party/sglang/benchmark/kernels/scheduler_batch/benchmark_write_req_to_token_pool_triton.py
vendored
Normal file
@@ -0,0 +1,342 @@
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.benchmark.bench_utils import run_bench
|
||||
|
||||
|
||||
@triton.jit
|
||||
def write_req_to_token_pool_triton(
|
||||
req_to_token_ptr, # [max_batch, max_context_len]
|
||||
req_pool_indices,
|
||||
pre_lens,
|
||||
seq_lens,
|
||||
extend_lens,
|
||||
out_cache_loc,
|
||||
req_to_token_ptr_stride: tl.constexpr,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 512
|
||||
pid = tl.program_id(0)
|
||||
|
||||
req_pool_index = tl.load(req_pool_indices + pid)
|
||||
pre_len = tl.load(pre_lens + pid)
|
||||
seq_len = tl.load(seq_lens + pid)
|
||||
|
||||
# TODO: optimize this?
|
||||
cumsum_start = 0
|
||||
for i in range(pid):
|
||||
cumsum_start += tl.load(extend_lens + i)
|
||||
|
||||
num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
|
||||
for i in range(num_loop):
|
||||
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||
mask = offset < (seq_len - pre_len)
|
||||
value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
|
||||
tl.store(
|
||||
req_to_token_ptr
|
||||
+ req_pool_index * req_to_token_ptr_stride
|
||||
+ offset
|
||||
+ pre_len,
|
||||
value,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def write_req_to_token_pool_triton_optimize(
|
||||
req_to_token_ptr, # [max_batch, max_context_len]
|
||||
req_pool_indices,
|
||||
pre_lens,
|
||||
seq_lens,
|
||||
extend_lens,
|
||||
out_cache_loc,
|
||||
req_to_token_ptr_stride: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid_batch = tl.program_id(0)
|
||||
pid_token = tl.program_id(1)
|
||||
|
||||
req_pool_index = tl.load(req_pool_indices + pid_batch)
|
||||
pre_len = tl.load(pre_lens + pid_batch)
|
||||
seq_len = tl.load(seq_lens + pid_batch)
|
||||
extend_len = seq_len - pre_len
|
||||
|
||||
cumsum_start = 0
|
||||
for i in range(pid_batch):
|
||||
cumsum_start += tl.load(extend_lens + i)
|
||||
|
||||
token_start = pid_token * BLOCK_SIZE
|
||||
|
||||
offset = tl.arange(0, BLOCK_SIZE)
|
||||
actual_offset = token_start + offset
|
||||
mask = actual_offset < extend_len
|
||||
|
||||
src_ptr = out_cache_loc + cumsum_start + actual_offset
|
||||
src_ptr = tl.max_contiguous(tl.multiple_of(src_ptr, BLOCK_SIZE), BLOCK_SIZE)
|
||||
value = tl.load(src_ptr, mask=mask)
|
||||
dst_ptr = (
|
||||
req_to_token_ptr
|
||||
+ req_pool_index * req_to_token_ptr_stride
|
||||
+ actual_offset
|
||||
+ pre_len
|
||||
)
|
||||
dst_ptr = tl.max_contiguous(tl.multiple_of(dst_ptr, BLOCK_SIZE), BLOCK_SIZE)
|
||||
|
||||
tl.store(dst_ptr, value, mask=mask)
|
||||
|
||||
|
||||
def write_req_to_token_pool_reference(
|
||||
req_to_token: torch.Tensor,
|
||||
req_pool_indices: torch.Tensor,
|
||||
pre_lens: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
extend_lens: torch.Tensor,
|
||||
out_cache_loc: torch.Tensor,
|
||||
) -> None:
|
||||
"""Reference implementation using PyTorch"""
|
||||
for i in range(len(req_pool_indices)):
|
||||
req_pool_idx = req_pool_indices[i].item()
|
||||
pre_len = pre_lens[i].item()
|
||||
seq_len = seq_lens[i].item()
|
||||
extend_len = extend_lens[i].item()
|
||||
|
||||
cumsum_start = sum(extend_lens[:i].tolist())
|
||||
|
||||
# Copy values from out_cache_loc to req_to_token
|
||||
req_to_token[req_pool_idx, pre_len:seq_len] = out_cache_loc[
|
||||
cumsum_start : cumsum_start + extend_len
|
||||
]
|
||||
|
||||
|
||||
def test_write_req_to_token_pool():
|
||||
max_batch = 4097
|
||||
max_context_len = 6148
|
||||
batch_size = 1
|
||||
extend_len = 14
|
||||
|
||||
# Initialize input tensors
|
||||
req_to_token = torch.zeros(
|
||||
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
req_pool_indices = torch.tensor([42], dtype=torch.int32, device="cuda")
|
||||
pre_lens = torch.tensor([8], dtype=torch.int32, device="cuda")
|
||||
seq_lens = torch.tensor([22], dtype=torch.int32, device="cuda")
|
||||
extend_lens = torch.tensor([extend_len], dtype=torch.int32, device="cuda")
|
||||
out_cache_loc = torch.arange(extend_len, dtype=torch.int32, device="cuda")
|
||||
|
||||
# Create copies for reference implementation
|
||||
req_to_token_ref = req_to_token.clone()
|
||||
req_to_token_opt = req_to_token.clone()
|
||||
|
||||
# Run original triton kernel
|
||||
write_req_to_token_pool_triton[(batch_size,)](
|
||||
req_to_token,
|
||||
req_pool_indices,
|
||||
pre_lens,
|
||||
seq_lens,
|
||||
extend_lens,
|
||||
out_cache_loc,
|
||||
max_context_len,
|
||||
)
|
||||
|
||||
# Run optimized triton kernel
|
||||
def grid(batch_size, extend_len):
|
||||
num_token_blocks = triton.cdiv(extend_len, 512)
|
||||
return (batch_size, num_token_blocks)
|
||||
|
||||
write_req_to_token_pool_triton_optimize[grid(batch_size, extend_len)](
|
||||
req_to_token_opt,
|
||||
req_pool_indices,
|
||||
pre_lens,
|
||||
seq_lens,
|
||||
extend_lens,
|
||||
out_cache_loc,
|
||||
max_context_len,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
|
||||
# Run reference implementation
|
||||
write_req_to_token_pool_reference(
|
||||
req_to_token_ref,
|
||||
req_pool_indices,
|
||||
pre_lens,
|
||||
seq_lens,
|
||||
extend_lens,
|
||||
out_cache_loc,
|
||||
)
|
||||
|
||||
# Compare results
|
||||
torch.testing.assert_close(req_to_token, req_to_token_ref)
|
||||
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
|
||||
|
||||
# Test case 2: batch size > 1
|
||||
batch_size = 3
|
||||
extend_lens_list = [14, 20, 30]
|
||||
total_extend_len = sum(extend_lens_list)
|
||||
|
||||
req_to_token = torch.zeros(
|
||||
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
req_pool_indices = torch.tensor([42, 100, 200], dtype=torch.int32, device="cuda")
|
||||
pre_lens = torch.tensor([8, 10, 15], dtype=torch.int32, device="cuda")
|
||||
seq_lens = torch.tensor([22, 30, 45], dtype=torch.int32, device="cuda")
|
||||
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
|
||||
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
|
||||
|
||||
req_to_token_ref = req_to_token.clone()
|
||||
req_to_token_opt = req_to_token.clone()
|
||||
|
||||
# Run original triton kernel
|
||||
write_req_to_token_pool_triton[(batch_size,)](
|
||||
req_to_token,
|
||||
req_pool_indices,
|
||||
pre_lens,
|
||||
seq_lens,
|
||||
extend_lens,
|
||||
out_cache_loc,
|
||||
max_context_len,
|
||||
)
|
||||
|
||||
# Run optimized triton kernel
|
||||
max_extend_len = max(extend_lens_list)
|
||||
write_req_to_token_pool_triton_optimize[grid(batch_size, max_extend_len)](
|
||||
req_to_token_opt,
|
||||
req_pool_indices,
|
||||
pre_lens,
|
||||
seq_lens,
|
||||
extend_lens,
|
||||
out_cache_loc,
|
||||
max_context_len,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
|
||||
# Run reference implementation
|
||||
write_req_to_token_pool_reference(
|
||||
req_to_token_ref,
|
||||
req_pool_indices,
|
||||
pre_lens,
|
||||
seq_lens,
|
||||
extend_lens,
|
||||
out_cache_loc,
|
||||
)
|
||||
|
||||
# Compare results
|
||||
torch.testing.assert_close(req_to_token, req_to_token_ref)
|
||||
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
|
||||
|
||||
|
||||
def get_benchmark():
|
||||
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
|
||||
extend_lens = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||
configs = list(itertools.product(batch_sizes, extend_lens))
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "extend_len"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["reference", "triton", "triton_optimize"],
|
||||
line_names=["PyTorch", "Triton", "Triton Optimized"],
|
||||
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name="write-req-to-token-pool-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, extend_len, provider):
|
||||
max_batch = 256
|
||||
max_context_len = 16384
|
||||
|
||||
extend_lens_list = [extend_len] * batch_size
|
||||
total_extend_len = sum(extend_lens_list)
|
||||
|
||||
req_to_token = torch.zeros(
|
||||
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
req_pool_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")
|
||||
pre_lens = torch.ones(batch_size, dtype=torch.int32, device="cuda") * 8
|
||||
seq_lens = pre_lens + extend_len
|
||||
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
|
||||
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "reference":
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: write_req_to_token_pool_reference(
|
||||
req_to_token.clone(),
|
||||
req_pool_indices,
|
||||
pre_lens,
|
||||
seq_lens,
|
||||
extend_lens,
|
||||
out_cache_loc,
|
||||
),
|
||||
quantiles=tuple(quantiles),
|
||||
)
|
||||
elif provider == "triton":
|
||||
ms, min_ms, max_ms = run_bench(
|
||||
lambda: write_req_to_token_pool_triton[(batch_size,)](
|
||||
req_to_token.clone(),
|
||||
req_pool_indices,
|
||||
pre_lens,
|
||||
seq_lens,
|
||||
extend_lens,
|
||||
out_cache_loc,
|
||||
max_context_len,
|
||||
),
|
||||
quantiles=tuple(quantiles),
|
||||
)
|
||||
else:
|
||||
|
||||
def run_optimized():
|
||||
block_size = 128 if extend_len <= 1024 else 512
|
||||
grid_config = (batch_size, triton.cdiv(extend_len, block_size))
|
||||
write_req_to_token_pool_triton_optimize[grid_config](
|
||||
req_to_token.clone(),
|
||||
req_pool_indices,
|
||||
pre_lens,
|
||||
seq_lens,
|
||||
extend_lens,
|
||||
out_cache_loc,
|
||||
max_context_len,
|
||||
BLOCK_SIZE=block_size,
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = run_bench(run_optimized, quantiles=tuple(quantiles))
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
def run_benchmark(save_path: str = "./configs/benchmark_ops/write_req_to_token_pool/"):
|
||||
"""Run benchmark and save results"""
|
||||
|
||||
# Ensure save path exists
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
# Run correctness test
|
||||
test_write_req_to_token_pool()
|
||||
print("Correctness test passed!")
|
||||
|
||||
# Run performance test
|
||||
benchmark = get_benchmark()
|
||||
benchmark.run(print_data=True, save_path=save_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/write_req_to_token_pool/",
|
||||
help="Path to save benchmark results",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
run_benchmark(args.save_path)
|
||||
294
third_party/sglang/benchmark/kernels/sliding_window_attention_triton/bench_triton_swa_kernel.py
vendored
Normal file
294
third_party/sglang/benchmark/kernels/sliding_window_attention_triton/bench_triton_swa_kernel.py
vendored
Normal file
@@ -0,0 +1,294 @@
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton.testing as tt
|
||||
|
||||
from sglang.benchmark.bench_utils import run_bench
|
||||
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd
|
||||
|
||||
|
||||
def extend_attention_fwd_torch(
|
||||
q: torch.Tensor, # [extend_tokens, H_Q, D]
|
||||
k: torch.Tensor, # [extend_tokens, H_KV, D]
|
||||
v: torch.Tensor, # [extend_tokens, H_KV, D]
|
||||
o: torch.Tensor, # [extend_tokens, H_Q, D]
|
||||
k_cache: torch.Tensor, # [total_tokens, H_KV, D]
|
||||
v_cache: torch.Tensor, # [total_tokens, H_KV, D]
|
||||
qo_indptr: torch.Tensor, # [B+1]
|
||||
kv_indptr: torch.Tensor, # [B+1]
|
||||
kv_indices: torch.Tensor, # [prefix_tokens]
|
||||
sliding_window_size: int,
|
||||
):
|
||||
B = qo_indptr.size(0) - 1
|
||||
_, H_Q, D = q.shape
|
||||
_, H_KV, _ = k.shape
|
||||
|
||||
group_size = H_Q // H_KV
|
||||
scale = 1.0 / D**0.5
|
||||
|
||||
for i in range(B):
|
||||
q_start = int(qo_indptr[i].item())
|
||||
q_end = int(qo_indptr[i + 1].item())
|
||||
kv_start = int(kv_indptr[i].item())
|
||||
kv_end = int(kv_indptr[i + 1].item())
|
||||
|
||||
prefix_indices = kv_indices[kv_start:kv_end]
|
||||
k_prefix = k_cache[prefix_indices] # [prefix_len, H_KV, D]
|
||||
v_prefix = v_cache[prefix_indices] # [prefix_len, H_KV, D]
|
||||
|
||||
k_extend = k[q_start:q_end] # [extend_len, H_KV, D]
|
||||
v_extend = v[q_start:q_end] # [extend_len, H_KV, D]
|
||||
q_extend = q[q_start:q_end] # [extend_len, H_Q, D]
|
||||
|
||||
k_full = torch.cat([k_prefix, k_extend], dim=0) # [total_len, H_KV, D]
|
||||
v_full = torch.cat([v_prefix, v_extend], dim=0) # [total_len, H_KV, D]
|
||||
|
||||
if group_size != 1:
|
||||
k_full_hq = k_full.repeat_interleave(
|
||||
group_size, dim=1
|
||||
) # [total_len, H_Q, D]
|
||||
v_full_hq = v_full.repeat_interleave(
|
||||
group_size, dim=1
|
||||
) # [total_len, H_Q, D]
|
||||
else:
|
||||
k_full_hq = k_full
|
||||
v_full_hq = v_full
|
||||
|
||||
prefix_len = k_prefix.size(0)
|
||||
extend_len = k_extend.size(0)
|
||||
total_len = prefix_len + extend_len
|
||||
|
||||
# causal
|
||||
pos_keys = torch.arange(total_len, device=q.device)
|
||||
t = prefix_len + torch.arange(extend_len, device=q.device) # [extend_len]
|
||||
causal_mask = pos_keys.unsqueeze(0) <= t.unsqueeze(1)
|
||||
|
||||
# sliding window
|
||||
if sliding_window_size is not None and sliding_window_size > 0:
|
||||
start = (t - (sliding_window_size)).clamp_min(0) # [extend_len]
|
||||
else:
|
||||
start = torch.zeros_like(t)
|
||||
window_mask = pos_keys.unsqueeze(0) >= start.unsqueeze(1)
|
||||
|
||||
final_mask = causal_mask & window_mask
|
||||
|
||||
attn_scores = (
|
||||
torch.einsum("qhd,khd->qhk", q_extend, k_full_hq) * scale
|
||||
) # [extend_len, H_Q, total_len]
|
||||
attn_scores = attn_scores.masked_fill(~final_mask.unsqueeze(1), float("-inf"))
|
||||
|
||||
attn_weights = F.softmax(attn_scores, dim=-1)
|
||||
o[q_start:q_end] = torch.einsum("qhk,khd->qhd", attn_weights, v_full_hq)
|
||||
|
||||
|
||||
def _build_batch(
|
||||
B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=torch.bfloat16, device="cuda"
|
||||
):
|
||||
b_seq_len_prefix = torch.randint(
|
||||
1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device
|
||||
)
|
||||
b_seq_len_extend = torch.randint(
|
||||
1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device
|
||||
)
|
||||
b_seq_len = b_seq_len_prefix + b_seq_len_extend
|
||||
|
||||
b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
|
||||
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
|
||||
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device)
|
||||
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
|
||||
|
||||
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
|
||||
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
|
||||
|
||||
kv_indices = torch.zeros(
|
||||
(int(b_seq_len_prefix.sum().item()),), dtype=torch.int32, device=device
|
||||
)
|
||||
for i in range(B):
|
||||
s = kv_indptr[i].item()
|
||||
e = kv_indptr[i + 1].item()
|
||||
kv_indices[s:e] = torch.arange(
|
||||
b_start_loc[i],
|
||||
b_start_loc[i] + b_seq_len_prefix[i],
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
total_token_num = int(torch.sum(b_seq_len).item())
|
||||
extend_token_num = int(torch.sum(b_seq_len_extend).item())
|
||||
|
||||
k_buffer = torch.empty(
|
||||
(total_token_num, H_KV, D), dtype=dtype, device=device
|
||||
).normal_(mean=0.1, std=0.2)
|
||||
v_buffer = torch.empty(
|
||||
(total_token_num, H_KV, D), dtype=dtype, device=device
|
||||
).normal_(mean=0.1, std=0.2)
|
||||
|
||||
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
|
||||
v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
|
||||
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
|
||||
|
||||
for i in range(B):
|
||||
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
|
||||
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
|
||||
extend_start = b_start_loc_extend[i]
|
||||
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
|
||||
|
||||
k_extend[extend_start:extend_end] = k_buffer[
|
||||
extend_start_in_buffer:extend_end_in_buffer
|
||||
]
|
||||
v_extend[extend_start:extend_end] = v_buffer[
|
||||
extend_start_in_buffer:extend_end_in_buffer
|
||||
]
|
||||
q_extend[extend_start:extend_end] = torch.empty(
|
||||
(int(b_seq_len_extend[i].item()), H_Q, D), dtype=dtype, device=device
|
||||
).normal_(mean=0.1, std=0.2)
|
||||
|
||||
o_extend_triton = torch.empty(
|
||||
(extend_token_num, H_Q, D), dtype=dtype, device=device
|
||||
)
|
||||
o_extend_torch = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
|
||||
|
||||
b_seq_len_extend = b_seq_len - b_seq_len_prefix
|
||||
max_len_extend = int(torch.max(b_seq_len_extend, 0)[0].item())
|
||||
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
|
||||
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
|
||||
|
||||
inputs = dict(
|
||||
q_extend=q_extend,
|
||||
k_extend=k_extend,
|
||||
v_extend=v_extend,
|
||||
k_buffer=k_buffer,
|
||||
v_buffer=v_buffer,
|
||||
o_extend_triton=o_extend_triton,
|
||||
o_extend_torch=o_extend_torch,
|
||||
qo_indptr=qo_indptr,
|
||||
kv_indptr=kv_indptr,
|
||||
kv_indices=kv_indices,
|
||||
max_len_extend=max_len_extend,
|
||||
WINDOW_SIZE=WINDOW_SIZE,
|
||||
)
|
||||
meta = dict(
|
||||
B=B, N_CTX=N_CTX, H_Q=H_Q, H_KV=H_KV, D=D, extend_token_num=extend_token_num
|
||||
)
|
||||
return inputs, meta
|
||||
|
||||
|
||||
def _run_triton(inputs):
|
||||
extend_attention_fwd(
|
||||
inputs["q_extend"],
|
||||
inputs["k_extend"],
|
||||
inputs["v_extend"],
|
||||
inputs["o_extend_triton"],
|
||||
inputs["k_buffer"],
|
||||
inputs["v_buffer"],
|
||||
inputs["qo_indptr"],
|
||||
inputs["kv_indptr"],
|
||||
inputs["kv_indices"],
|
||||
custom_mask=None,
|
||||
is_causal=True,
|
||||
mask_indptr=None,
|
||||
max_len_extend=inputs["max_len_extend"],
|
||||
sliding_window_size=inputs["WINDOW_SIZE"],
|
||||
)
|
||||
|
||||
|
||||
def _run_torch_ref(inputs):
|
||||
extend_attention_fwd_torch(
|
||||
inputs["q_extend"],
|
||||
inputs["k_extend"],
|
||||
inputs["v_extend"],
|
||||
inputs["o_extend_torch"],
|
||||
inputs["k_buffer"],
|
||||
inputs["v_buffer"],
|
||||
inputs["qo_indptr"],
|
||||
inputs["kv_indptr"],
|
||||
inputs["kv_indices"],
|
||||
inputs["WINDOW_SIZE"],
|
||||
)
|
||||
|
||||
|
||||
N_CTXS = [1024, 2048, 4096, 8192]
|
||||
WINDOW_SIZES = [-1, 127, 256, 512]
|
||||
|
||||
CONFIGS = list(itertools.product(N_CTXS, WINDOW_SIZES))
|
||||
|
||||
PROVIDERS = ["torch", "triton"]
|
||||
|
||||
|
||||
@tt.perf_report(
|
||||
tt.Benchmark(
|
||||
x_names=["N_CTX", "WINDOW_SIZE"],
|
||||
x_vals=CONFIGS,
|
||||
line_arg="provider",
|
||||
line_vals=PROVIDERS,
|
||||
line_names=PROVIDERS,
|
||||
ylabel="Runtime (ms)",
|
||||
plot_name="extend_attention_triton_vs_torch",
|
||||
args={
|
||||
"B": 32,
|
||||
"H_Q": 64,
|
||||
"H_KV": 8,
|
||||
"D": 128,
|
||||
"dtype": "bf16",
|
||||
"device": "cuda",
|
||||
"check_correctness": False,
|
||||
"warmup": 25,
|
||||
"rep": 100,
|
||||
},
|
||||
)
|
||||
)
|
||||
def bench(
|
||||
N_CTX,
|
||||
provider,
|
||||
B,
|
||||
H_Q,
|
||||
H_KV,
|
||||
D,
|
||||
dtype,
|
||||
device,
|
||||
WINDOW_SIZE,
|
||||
check_correctness,
|
||||
warmup,
|
||||
rep,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed(0)
|
||||
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
|
||||
dt = dtype_map[dtype]
|
||||
|
||||
inputs, _ = _build_batch(
|
||||
B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=dt, device=device
|
||||
)
|
||||
|
||||
if check_correctness and provider == "triton":
|
||||
_run_triton(inputs)
|
||||
_run_torch_ref(inputs)
|
||||
torch.cuda.synchronize()
|
||||
if not torch.allclose(
|
||||
inputs["o_extend_triton"], inputs["o_extend_torch"], rtol=1e-3, atol=1e-3
|
||||
):
|
||||
raise AssertionError("Mismatch between triton and torch reference.")
|
||||
|
||||
if provider == "triton":
|
||||
ms = run_bench(
|
||||
lambda: _run_triton(inputs),
|
||||
quantiles=None,
|
||||
warmup_ms=warmup,
|
||||
rep_ms=rep,
|
||||
)[0]
|
||||
elif provider == "torch":
|
||||
ms = run_bench(
|
||||
lambda: _run_torch_ref(inputs),
|
||||
quantiles=None,
|
||||
warmup_ms=warmup,
|
||||
rep_ms=rep,
|
||||
)[0]
|
||||
else:
|
||||
raise ValueError(provider)
|
||||
|
||||
return ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bench.run(print_data=True, show_plots=False)
|
||||
37
third_party/sglang/benchmark/line_retrieval/README.md
vendored
Normal file
37
third_party/sglang/benchmark/line_retrieval/README.md
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
## Download data
|
||||
|
||||
```
|
||||
wget https://raw.githubusercontent.com/merrymercy/merrymercy.github.io/master/files/random_words.json
|
||||
python3 gen_data.py --number 1000
|
||||
```
|
||||
|
||||
## Run benchmark
|
||||
|
||||
### Benchmark sglang
|
||||
```
|
||||
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-hf --port 30000
|
||||
```
|
||||
|
||||
```
|
||||
python3 bench_sglang.py --src-index 600 --num-q 50 --parallel 1
|
||||
```
|
||||
|
||||
|
||||
###
|
||||
|
||||
```
|
||||
# original
|
||||
Accuracy: 0.940, latency: 332.83 s
|
||||
|
||||
# parallel encoding (no_adjust, offset = 1000)
|
||||
Accuracy: 0.760, latency: 238.46 s
|
||||
|
||||
# parallel encoding (no_adjust, offset = 3000)
|
||||
Accuracy: 0.760, latency: 238.46 s
|
||||
|
||||
# parallel encoding (no_adjust, offset = 0)
|
||||
Accuracy: 0.520, latency: 238.46 s
|
||||
|
||||
# parallel encoding (adjust_cache)
|
||||
Accuracy: 0.460, latency: 257.66 s
|
||||
```
|
||||
149
third_party/sglang/benchmark/line_retrieval/bench_sglang.py
vendored
Normal file
149
third_party/sglang/benchmark/line_retrieval/bench_sglang.py
vendored
Normal file
@@ -0,0 +1,149 @@
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
from sglang.utils import dump_state_text
|
||||
|
||||
|
||||
@sgl.function
|
||||
def line_retrieval(s, prefix, suffix, body_0, body_1, body_2, body_3):
|
||||
s += prefix + "\n"
|
||||
|
||||
contexts = [body_0, body_1, body_2, body_3]
|
||||
position_ids_offset = [i * 1000 for i in range(len(contexts))]
|
||||
forks = s.fork(len(contexts), position_ids_offset)
|
||||
forks += lambda i: contexts[i] + "\n"
|
||||
forks.join(mode="concate_and_append")
|
||||
|
||||
s += "\n" + suffix
|
||||
s += sgl.gen("answer", max_tokens=16)
|
||||
|
||||
|
||||
def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
|
||||
arguments = []
|
||||
labels = []
|
||||
sum_src_indices = []
|
||||
sum_dst_indices = []
|
||||
|
||||
for i in range(len(src_indices)):
|
||||
for j in range(len(dst_percents)):
|
||||
src_index = src_indices[i]
|
||||
dst_percent = dst_percents[j]
|
||||
|
||||
query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
|
||||
query_indices = [
|
||||
q
|
||||
for q in query_indices
|
||||
if all(l <= src_index for l in line_obj["links"][q]) and q < src_index
|
||||
]
|
||||
dst_index = query_indices[
|
||||
min(int(len(query_indices) * dst_percent), len(query_indices) - 1)
|
||||
]
|
||||
label = line_obj["values"][dst_index]
|
||||
|
||||
body = line_obj["lines"][: src_index + 1]
|
||||
suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
|
||||
body_part_len = len(body) // 4
|
||||
|
||||
arguments.append(
|
||||
{
|
||||
"prefix": line_obj["prefix"],
|
||||
"body_0": "\n".join(body[:body_part_len]),
|
||||
"body_1": "\n".join(body[body_part_len : 2 * body_part_len]),
|
||||
"body_2": "\n".join(body[2 * body_part_len : 3 * body_part_len]),
|
||||
"body_3": "\n".join(body[3 * body_part_len :]),
|
||||
"suffix": suffix,
|
||||
}
|
||||
)
|
||||
labels.append(label)
|
||||
sum_src_indices.append(src_index)
|
||||
sum_dst_indices.append(dst_index)
|
||||
|
||||
# Select backend
|
||||
backend = select_sglang_backend(args)
|
||||
|
||||
tic = time.perf_counter()
|
||||
states = line_retrieval.run_batch(
|
||||
arguments,
|
||||
temperature=0,
|
||||
backend=backend,
|
||||
num_threads=args.parallel,
|
||||
progress_bar=True,
|
||||
)
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
corrects = []
|
||||
for i in range(len(arguments)):
|
||||
output = states[i]["answer"]
|
||||
prompt_len = states[i].get_meta_info("answer").get("prompt_length", -1)
|
||||
label = labels[i]
|
||||
|
||||
# Try all numbers
|
||||
findall = re.findall("\d+", output)
|
||||
if not findall:
|
||||
response_number = output
|
||||
else:
|
||||
for response_number in findall:
|
||||
if response_number == label:
|
||||
break
|
||||
|
||||
correct = response_number == label
|
||||
corrects.append(correct)
|
||||
|
||||
# Log results
|
||||
summary = (
|
||||
f"Line index: {sum_src_indices[i]} -> {sum_dst_indices[i]}, "
|
||||
f"Prompt len: {prompt_len}, "
|
||||
f"Correct: {correct}, "
|
||||
f"Label: {label}, Predicted: {response_number}, "
|
||||
)
|
||||
print(summary)
|
||||
|
||||
accuracy = np.mean(corrects)
|
||||
print(f"Accuracy: {accuracy:.3f}, latency: {latency:.2f} s")
|
||||
|
||||
# Write results
|
||||
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||
|
||||
with open(args.result_file, "a") as fout:
|
||||
value = {
|
||||
"task": "line_retrieval",
|
||||
"backend": args.backend,
|
||||
"num_gpus": 1,
|
||||
"latency": round(latency, 3),
|
||||
"num_requests": len(arguments),
|
||||
"other": {
|
||||
"num_questions": len(arguments),
|
||||
"parallel": args.parallel,
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
def main(args):
|
||||
line_obj = json.load(open(args.data_path, "r"))
|
||||
|
||||
num_hoops = args.num_hoops
|
||||
for src_index in args.src_index:
|
||||
src_indices = [src_index]
|
||||
num_queries = args.num_queries_per_src
|
||||
dst_percents = [i * (1 / (num_queries)) for i in range(num_queries)]
|
||||
eval_model(args, line_obj, num_hoops, src_indices, dst_percents)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-path", type=str, default="lines_1000_0.0.json")
|
||||
parser.add_argument("--src-index", type=int, nargs="+", default=[100])
|
||||
parser.add_argument("--num-queries-per-src", type=int, default=10)
|
||||
parser.add_argument("--num-hoops", type=int, default=1)
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
main(args)
|
||||
139
third_party/sglang/benchmark/line_retrieval/gen_data.py
vendored
Normal file
139
third_party/sglang/benchmark/line_retrieval/gen_data.py
vendored
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
Generate line data for line retrieval task.
|
||||
|
||||
Usage:
|
||||
python3 gen_data.py --number 1000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def generate_lines(random_words, num_lines, redirect_ratio):
|
||||
prefix = "Here is a list of lines, each with its corresponding REGISTER_CONTENT value. Please memorize them. Be prepared to provide the REGISTER_CONTENT value for a specific line index when I ask."
|
||||
suffix = "The list has ended. Please give the final REGISTER_CONTENT value for a specific line after resolving the redirections and references. For example, the REGISTER_CONTENT of Line __idx0__ is __val0__. The REGISTER_CONTENT of Line __idx1__ is __val1__. The REGISTER_CONTENT of Line __idx2__ is __val2__. The REGISTER_CONTENT of Line ??? is"
|
||||
|
||||
# Raw lines
|
||||
visited_indices = set([None])
|
||||
visited_values = set([None])
|
||||
|
||||
lines = []
|
||||
redirects = []
|
||||
indices = []
|
||||
values = []
|
||||
for i in tqdm(range(num_lines)):
|
||||
line_index = None
|
||||
while line_index in visited_indices:
|
||||
line_index = "-".join(np.random.choice(random_words, size=(2,)))
|
||||
visited_indices.add(line_index)
|
||||
|
||||
line_value = np.random.randint(low=0, high=999999)
|
||||
line_value = f"{line_value:06}"
|
||||
|
||||
line = f"Line {line_index}: The REGISTER_CONTENT is {line_value}."
|
||||
lines.append(line)
|
||||
redirects.append(None)
|
||||
indices.append(line_index)
|
||||
values.append(line_value)
|
||||
|
||||
# Add redirect
|
||||
if redirect_ratio > 0:
|
||||
num_redirect_lines = int(len(lines) * redirect_ratio)
|
||||
redirect_indices = np.random.choice(
|
||||
np.arange(len(lines)), size=(num_redirect_lines,), replace=False
|
||||
)
|
||||
for i in redirect_indices:
|
||||
target_idx = np.random.choice(min(i * 2 + 100, num_lines))
|
||||
lines[i] = (
|
||||
f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
|
||||
)
|
||||
redirects[i] = target_idx
|
||||
|
||||
# Build links and find sources
|
||||
links = [[] for _ in range(num_lines)]
|
||||
contains_ring = set()
|
||||
for i in range(num_lines):
|
||||
if redirects[i] is None:
|
||||
continue
|
||||
|
||||
tmp_link = []
|
||||
cur = i
|
||||
visited = set()
|
||||
while redirects[cur] is not None:
|
||||
visited.add(cur)
|
||||
tmp_link.append(redirects[cur])
|
||||
cur = redirects[cur]
|
||||
|
||||
if cur in visited:
|
||||
contains_ring.add(i)
|
||||
tmp_link = None
|
||||
break
|
||||
values[i] = values[cur]
|
||||
links[i] = tmp_link
|
||||
|
||||
# Group by num_links
|
||||
group_by_num_hoops = defaultdict(list)
|
||||
for i in range(num_lines):
|
||||
if i in contains_ring:
|
||||
continue
|
||||
group_by_num_hoops[len(links[i]) + 1].append(i)
|
||||
|
||||
keys = sorted(list(group_by_num_hoops.keys()))
|
||||
for num_links in keys:
|
||||
print(f"#links: {num_links}, #lines: {len(group_by_num_hoops[num_links])}")
|
||||
|
||||
# Append few-shot examples
|
||||
hoop1_candidates = list(group_by_num_hoops[1])
|
||||
hoop1_candidate_keys = {c: max([c] + links[c]) for c in hoop1_candidates}
|
||||
hoop1_candidates.sort(key=lambda c: hoop1_candidate_keys[c])
|
||||
hoop2_candidates = list(group_by_num_hoops[2])
|
||||
hoop2_candidate_keys = {c: max([c] + links[c]) for c in hoop2_candidates}
|
||||
hoop2_candidates.sort(key=lambda c: hoop2_candidate_keys[c])
|
||||
|
||||
i = hoop1_candidates[5]
|
||||
suffix = suffix.replace("__idx0__", indices[i]).replace("__val0__", values[i])
|
||||
if len(hoop2_candidates):
|
||||
i = hoop2_candidates[0]
|
||||
suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
|
||||
i = hoop2_candidates[1]
|
||||
suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])
|
||||
else:
|
||||
i = hoop1_candidates[1]
|
||||
suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
|
||||
i = hoop1_candidates[10]
|
||||
suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])
|
||||
|
||||
obj = {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"lines": lines,
|
||||
"indices": indices,
|
||||
"values": values,
|
||||
"links": links,
|
||||
"group_by_num_hoops": group_by_num_hoops,
|
||||
"contains_ring": sorted(list(contains_ring)),
|
||||
}
|
||||
return obj
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--number", type=int)
|
||||
parser.add_argument("--redirect-ratio", type=float, default=0.0)
|
||||
args = parser.parse_args()
|
||||
|
||||
num_lines = args.number
|
||||
|
||||
random_words_filename = "random_words.json"
|
||||
random_words = json.load(open(random_words_filename, "r"))
|
||||
|
||||
np.random.seed(42)
|
||||
obj = generate_lines(random_words, num_lines, args.redirect_ratio)
|
||||
|
||||
fout = f"lines_{num_lines}_{args.redirect_ratio:.1f}.json"
|
||||
with open(fout, "w") as fout:
|
||||
json.dump(obj, fout, indent=2)
|
||||
61
third_party/sglang/benchmark/llava_bench/README.md
vendored
Normal file
61
third_party/sglang/benchmark/llava_bench/README.md
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
## Download benchmark images
|
||||
|
||||
```
|
||||
python3 download_images.py
|
||||
```
|
||||
|
||||
image benchmark source: https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild
|
||||
|
||||
### Other Dependency
|
||||
```
|
||||
pip3 install "sglang[all]"
|
||||
pip3 install "torch>=2.1.2" "transformers>=4.36" pillow
|
||||
```
|
||||
|
||||
## Run benchmark
|
||||
|
||||
### Benchmark sglang
|
||||
Launch a server
|
||||
```
|
||||
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
|
||||
```
|
||||
|
||||
Run benchmark
|
||||
```
|
||||
# Run with local models
|
||||
python3 bench_sglang.py --num-questions 60
|
||||
|
||||
# Run with OpenAI models
|
||||
python3 bench_sglang.py --num-questions 60 --backend gpt-4-vision-preview
|
||||
```
|
||||
|
||||
### Bench LLaVA original code
|
||||
```
|
||||
git clone git@github.com:haotian-liu/LLaVA.git
|
||||
cd LLaVA
|
||||
git reset --hard 9a26bd1435b4ac42c282757f2c16d34226575e96
|
||||
pip3 install -e .
|
||||
|
||||
cd ~/sglang/benchmark/llava_bench
|
||||
CUDA_VISIBLE_DEVICES=0 bash bench_hf_llava_bench.sh
|
||||
```
|
||||
|
||||
|
||||
### Benchmark llama.cpp
|
||||
|
||||
```
|
||||
# Install
|
||||
CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install llama-cpp-python
|
||||
pip install sse_starlette starlette_context pydantic_settings
|
||||
|
||||
# Download weights
|
||||
mkdir -p ~/model_weights/llava-v1.5-7b/
|
||||
wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/ggml-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf
|
||||
wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/mmproj-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf
|
||||
```
|
||||
|
||||
```
|
||||
python3 -m llama_cpp.server --model ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf --clip_model_path ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf --chat_format llava-1-5 --port 23000
|
||||
|
||||
OPENAI_BASE_URL=http://localhost:23000/v1 python3 bench_sglang.py --backend gpt-4-vision-preview --num-q 1
|
||||
```
|
||||
9
third_party/sglang/benchmark/llava_bench/bench_hf_llava_bench.sh
vendored
Executable file
9
third_party/sglang/benchmark/llava_bench/bench_hf_llava_bench.sh
vendored
Executable file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
|
||||
python -m llava.eval.model_vqa \
|
||||
--model-path liuhaotian/llava-v1.5-7b \
|
||||
--question-file ./questions.jsonl \
|
||||
--image-folder ./images \
|
||||
--answers-file ./answers_hf.jsonl \
|
||||
--temperature 0 \
|
||||
--conv-mode vicuna_v1
|
||||
9
third_party/sglang/benchmark/llava_bench/bench_hf_mme.sh
vendored
Executable file
9
third_party/sglang/benchmark/llava_bench/bench_hf_mme.sh
vendored
Executable file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
|
||||
python -m llava.eval.model_vqa_loader \
|
||||
--model-path liuhaotian/llava-v1.5-7b \
|
||||
--question-file ./mme_pack/llava_mme_bench_replace.jsonl \
|
||||
--image-folder ./mme_pack/MME_Benchmark_release_version \
|
||||
--answers-file ./answers_hf_mme.jsonl \
|
||||
--temperature 0 \
|
||||
--conv-mode vicuna_v1
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user