Add vLLM v0.18.1 source tree with KV transfer abort fix

third_party/vllm/ now tracked in git for direct patch management.
Based on vLLM v0.18.1 release with one patch applied:

  vllm/v1/core/sched/scheduler.py:
    Replace fatal assert with graceful skip when KV transfer callback
    arrives for an already-aborted request during PD disaggregated serving.

Future vLLM modifications should be made directly in third_party/vllm/
and committed normally. The patches/ directory is kept as documentation
of what changed from upstream.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-22 00:30:38 +08:00
parent b6591950bc
commit 445e491123
4285 changed files with 1111303 additions and 1 deletions

View File

@@ -0,0 +1,73 @@
# Offline Inference
The `LLM` class provides the primary Python interface for doing offline inference, which is interacting with a model without using a separate model inference server.
## Usage
The first script in this example shows the most basic usage of vLLM. If you are new to Python and vLLM, you should start here.
```bash
python examples/basic/offline_inference/basic.py
```
The rest of the scripts include an [argument parser](https://docs.python.org/3/library/argparse.html), which you can use to pass any arguments that are compatible with [`LLM`](https://docs.vllm.ai/en/latest/api/offline_inference/llm.html). Try running the script with `--help` for a list of all available arguments.
```bash
python examples/basic/offline_inference/classify.py
```
```bash
python examples/basic/offline_inference/embed.py
```
```bash
python examples/basic/offline_inference/score.py
```
The chat and generate scripts also accept the [sampling parameters](https://docs.vllm.ai/en/latest/api/inference_params.html#sampling-parameters): `max_tokens`, `temperature`, `top_p` and `top_k`.
```bash
python examples/basic/offline_inference/chat.py
```
```bash
python examples/basic/offline_inference/generate.py
```
## Features
In the scripts that support passing arguments, you can experiment with the following features.
### Default generation config
The `--generation-config` argument specifies where the generation config will be loaded from when calling `LLM.get_default_sampling_params()`. If set to auto, the generation config will be loaded from model path. If set to a folder path, the generation config will be loaded from the specified folder path. If it is not provided, vLLM defaults will be used.
> If max_new_tokens is specified in generation config, then it sets a server-wide limit on the number of output tokens for all requests.
Try it yourself with the following argument:
```bash
--generation-config auto
```
### Quantization
#### GGUF
vLLM supports models that are quantized using GGUF.
Try one yourself using the `repo_id:quant_type` format to load directly from HuggingFace:
```bash
--model unsloth/Qwen3-0.6B-GGUF:Q4_K_M --tokenizer Qwen/Qwen3-0.6B
```
### CPU offload
The `--cpu-offload-gb` argument can be seen as a virtual way to increase the GPU memory size. For example, if you have one 24 GB GPU and set this to 10, virtually you can think of it as a 34 GB GPU. Then you can load a 13B model with BF16 weight, which requires at least 26GB GPU memory. Note that this requires fast CPU-GPU interconnect, as part of the model is loaded from CPU memory to GPU memory on the fly in each model forward pass.
Try it yourself with the following arguments:
```bash
--model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10
```

View File

@@ -0,0 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
def main():
# Create an LLM.
llm = LLM(model="facebook/opt-125m")
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Output: {generated_text!r}")
print("-" * 60)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, EngineArgs
from vllm.outputs import RequestOutput
from vllm.utils.argparse_utils import FlexibleArgumentParser
def create_parser():
parser = FlexibleArgumentParser()
# Add engine args
EngineArgs.add_cli_args(parser)
parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
# Add sampling params
sampling_group = parser.add_argument_group("Sampling parameters")
sampling_group.add_argument("--max-tokens", type=int)
sampling_group.add_argument("--temperature", type=float)
sampling_group.add_argument("--top-p", type=float)
sampling_group.add_argument("--top-k", type=int)
# Add example params
parser.add_argument("--chat-template-path", type=str)
return parser
def main(args: dict):
# Pop arguments not used by LLM
max_tokens = args.pop("max_tokens")
temperature = args.pop("temperature")
top_p = args.pop("top_p")
top_k = args.pop("top_k")
chat_template_path = args.pop("chat_template_path")
# Create an LLM
llm = LLM(**args)
# Create sampling params object
sampling_params = llm.get_default_sampling_params()
if max_tokens is not None:
sampling_params.max_tokens = max_tokens
if temperature is not None:
sampling_params.temperature = temperature
if top_p is not None:
sampling_params.top_p = top_p
if top_k is not None:
sampling_params.top_k = top_k
def print_outputs(outputs: list[RequestOutput], prompts: list):
assert len(outputs) == len(prompts)
print("\nGenerated Outputs:\n" + "-" * 80)
for i, output in enumerate(outputs):
generated_text = output.outputs[0].text
print(f"Prompt: {prompts[i]!r}\n")
print(f"Generated text: {generated_text!r}")
print("-" * 80)
print("=" * 80)
# In this script, we demonstrate how to pass input to the chat method:
conversation = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hello! How can I assist you today?"},
{
"role": "user",
"content": "Write an essay about the importance of higher education.",
},
]
outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
print_outputs(
outputs,
[
conversation,
],
)
# You can run batch inference with llm.chat API
conversations = [conversation for _ in range(10)]
# We turn on tqdm progress bar to verify it's indeed running batch inference
outputs = llm.chat(conversations, sampling_params, use_tqdm=True)
print_outputs(outputs, conversations)
# A chat template can be optionally supplied.
# If not, the model will use its default chat template.
if chat_template_path is not None:
with open(chat_template_path) as f:
chat_template = f.read()
outputs = llm.chat(
conversations,
sampling_params,
use_tqdm=False,
chat_template=chat_template,
)
print_outputs(outputs, conversations)
if __name__ == "__main__":
parser = create_parser()
args: dict = vars(parser.parse_args())
main(args)

View File

@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(
model="jason9693/Qwen2.5-1.5B-apeach",
runner="pooling",
enforce_eager=True,
)
return parser.parse_args()
def main(args: Namespace):
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create an LLM.
# You should pass runner="pooling" for classification models
llm = LLM(**vars(args))
# Generate logits. The output is a list of ClassificationRequestOutputs.
outputs = llm.classify(prompts)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
probs = output.outputs.probs
probs_trimmed = (str(probs[:16])[:-1] + ", ...]") if len(probs) > 16 else probs
print(
f"Prompt: {prompt!r} \n"
f"Class Probabilities: {probs_trimmed} (size={len(probs)})"
)
print("-" * 60)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.print_utils import print_embeddings
def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(
model="intfloat/e5-small",
runner="pooling",
enforce_eager=True,
)
return parser.parse_args()
def main(args: Namespace):
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create an LLM.
# You should pass runner="pooling" for embedding models
llm = LLM(**vars(args))
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = llm.embed(prompts)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding
print(f"Prompt: {prompt!r}")
print_embeddings(embeds)
print("-" * 60)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
def create_parser():
parser = FlexibleArgumentParser()
# Add engine args
EngineArgs.add_cli_args(parser)
parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
# Add sampling params
sampling_group = parser.add_argument_group("Sampling parameters")
sampling_group.add_argument("--max-tokens", type=int)
sampling_group.add_argument("--temperature", type=float)
sampling_group.add_argument("--top-p", type=float)
sampling_group.add_argument("--top-k", type=int)
return parser
def main(args: dict):
# Pop arguments not used by LLM
max_tokens = args.pop("max_tokens")
temperature = args.pop("temperature")
top_p = args.pop("top_p")
top_k = args.pop("top_k")
# Create an LLM
llm = LLM(**args)
# Create a sampling params object
sampling_params = llm.get_default_sampling_params()
if max_tokens is not None:
sampling_params.max_tokens = max_tokens
if temperature is not None:
sampling_params.temperature = temperature
if top_p is not None:
sampling_params.top_p = top_p
if top_k is not None:
sampling_params.top_k = top_k
# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
if __name__ == "__main__":
parser = create_parser()
args: dict = vars(parser.parse_args())
main(args)

View File

@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.print_utils import print_embeddings
def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(
model="internlm/internlm2-1_8b-reward",
runner="pooling",
enforce_eager=True,
max_model_len=1024,
trust_remote_code=True,
)
return parser.parse_args()
def main(args: Namespace):
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create an LLM.
# You should pass runner="pooling" for reward models
llm = LLM(**vars(args))
# Generate rewards. The output is a list of PoolingRequestOutput.
outputs = llm.reward(prompts)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
rewards = output.outputs.data
print(f"Prompt: {prompt!r}")
print_embeddings(rewards, prefix="Reward")
print("-" * 60)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(
model="BAAI/bge-reranker-v2-m3",
runner="pooling",
enforce_eager=True,
)
return parser.parse_args()
def main(args: Namespace):
# Sample prompts.
query = "What is the capital of France?"
documents = [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.",
]
# Create an LLM.
# You should pass runner="pooling" for cross-encoder models
llm = LLM(**vars(args))
# Generate scores. The output is a list of ScoringRequestOutputs.
outputs = llm.score(query, documents)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for document, output in zip(documents, outputs):
score = output.outputs.score
print(f"Pair: {[query, document]!r} \nScore: {score}")
print("-" * 60)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for OpenAI Chat Completion using vLLM API server
NOTE: start a supported chat completion model server with `vllm serve`, e.g.
vllm serve meta-llama/Llama-2-7b-chat-hf
"""
import argparse
from openai import OpenAI
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020?"},
{
"role": "assistant",
"content": "The Los Angeles Dodgers won the World Series in 2020.",
},
{"role": "user", "content": "Where was it played?"},
]
def parse_args():
parser = argparse.ArgumentParser(description="Client for vLLM API server")
parser.add_argument(
"--stream", action="store_true", help="Enable streaming response"
)
return parser.parse_args()
def main(args):
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
# Chat Completion API
chat_completion = client.chat.completions.create(
messages=messages,
model=model,
stream=args.stream,
)
print("-" * 50)
print("Chat completion results:")
if args.stream:
for c in chat_completion:
print(c)
else:
print(chat_completion)
print("-" * 50)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from openai import OpenAI
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
def parse_args():
parser = argparse.ArgumentParser(description="Client for vLLM API server")
parser.add_argument(
"--stream", action="store_true", help="Enable streaming response"
)
return parser.parse_args()
def main(args):
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
# Completion API
completion = client.completions.create(
model=model,
prompt="A robot may not injure a human being",
echo=False,
n=2,
stream=args.stream,
logprobs=3,
)
print("-" * 50)
print("Completion results:")
if args.stream:
for c in completion:
print(c)
else:
print(completion)
print("-" * 50)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Simple example demonstrating streaming offline inference with AsyncLLM (V1 engine).
This script shows the core functionality of vLLM's AsyncLLM engine for streaming
token-by-token output in offline inference scenarios. It demonstrates DELTA mode
streaming where you receive new tokens as they are generated.
Usage:
python examples/offline_inference/async_llm_streaming.py
"""
import asyncio
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
async def stream_response(engine: AsyncLLM, prompt: str, request_id: str) -> None:
"""
Stream response from AsyncLLM and display tokens as they arrive.
This function demonstrates the core streaming pattern:
1. Create SamplingParams with DELTA output kind
2. Call engine.generate() and iterate over the async generator
3. Print new tokens as they arrive
4. Handle the finished flag to know when generation is complete
"""
print(f"\n🚀 Prompt: {prompt!r}")
print("💬 Response: ", end="", flush=True)
# Configure sampling parameters for streaming
sampling_params = SamplingParams(
max_tokens=100,
temperature=0.8,
top_p=0.95,
seed=42, # For reproducible results
output_kind=RequestOutputKind.DELTA, # Get only new tokens each iteration
)
try:
# Stream tokens from AsyncLLM
async for output in engine.generate(
request_id=request_id, prompt=prompt, sampling_params=sampling_params
):
# Process each completion in the output
for completion in output.outputs:
# In DELTA mode, we get only new tokens generated since last iteration
new_text = completion.text
if new_text:
print(new_text, end="", flush=True)
# Check if generation is finished
if output.finished:
print("\n✅ Generation complete!")
break
except Exception as e:
print(f"\n❌ Error during streaming: {e}")
raise
async def main():
print("🔧 Initializing AsyncLLM...")
# Create AsyncLLM engine with simple configuration
engine_args = AsyncEngineArgs(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True, # Faster startup for examples
)
engine = AsyncLLM.from_engine_args(engine_args)
try:
# Example prompts to demonstrate streaming
prompts = [
"The future of artificial intelligence is",
"In a galaxy far, far away",
"The key to happiness is",
]
print(f"🎯 Running {len(prompts)} streaming examples...")
# Process each prompt
for i, prompt in enumerate(prompts, 1):
print(f"\n{'=' * 60}")
print(f"Example {i}/{len(prompts)}")
print(f"{'=' * 60}")
request_id = f"stream-example-{i}"
await stream_response(engine, prompt, request_id)
# Brief pause between examples
if i < len(prompts):
await asyncio.sleep(0.5)
print("\n🎉 All streaming examples completed!")
finally:
# Always clean up the engine
print("🔧 Shutting down engine...")
engine.shutdown()
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
print("\n🛑 Interrupted by user")

View File

@@ -0,0 +1,665 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on audio language models.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
import os
from dataclasses import asdict
from typing import Any, NamedTuple
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.lora.request import LoRARequest
from vllm.utils.argparse_utils import FlexibleArgumentParser
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
question_per_audio_count = {
0: "What is 1+1?",
1: "What is recited in the audio?",
2: "What sport and what nursery rhyme are referenced?",
}
class ModelRequestData(NamedTuple):
engine_args: EngineArgs
prompt: str | None = None
prompt_token_ids: dict[str, list[int]] | None = None
multi_modal_data: dict[str, Any] | None = None
stop_token_ids: list[int] | None = None
lora_requests: list[LoRARequest] | None = None
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
# AudioFlamingo3
def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
model_name = "nvidia/audio-flamingo-3-hf"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
enforce_eager=True,
)
# AudioFlamingo3 uses <sound> token for audio
audio_placeholder = "<sound>" * audio_count
prompt = (
"<|im_start|>system\n"
"You are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_placeholder}{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# MusicFlamingo
def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData:
model_name = "nvidia/music-flamingo-2601-hf"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
enforce_eager=True,
)
# MusicFlamingo uses <sound> token for audio
audio_placeholder = "<sound>" * audio_count
prompt = (
"<|im_start|>system\n"
"You are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_placeholder}{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Gemma3N
def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
model_name = "google/gemma-3n-E2B-it"
engine_args = EngineArgs(
model=model_name,
max_model_len=2048,
max_num_batched_tokens=2048,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
enforce_eager=True,
)
prompt = f"<start_of_turn>user\n<audio_soft_token>{question}"
"<end_of_turn>\n<start_of_turn>model\n"
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# GLM-ASR
def run_glmasr(question: str, audio_count: int) -> ModelRequestData:
model_name = "zai-org/GLM-ASR-Nano-2512"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# GLM-ASR uses <|pad|> token for audio
audio_placeholder = "<|pad|>" * audio_count
messages = [{"role": "user", "content": f"{audio_placeholder}{question}"}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# FunAudioChat
def run_funaudiochat(question: str, audio_count: int) -> ModelRequestData:
# NOTE: FunAudioChat is not available on the HuggingFace Hub at the time of
# writing. Pass a local model path via `--model`.
model_name = "funaudiochat"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
enforce_eager=True,
)
audio_in_prompt = "".join(
["<|audio_bos|><|AUDIO|><|audio_eos|>\n" for _ in range(audio_count)]
)
prompt = f"{audio_in_prompt}{question}"
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
# NOTE - the setting in this example are somewhat different from what is
# optimal for granite speech, and it is generally recommended to use beam
# search. Check the model README for suggested settings.
# https://huggingface.co/ibm-granite/granite-speech-3.3-8b
model_name = "ibm-granite/granite-speech-3.3-8b"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=2048,
max_num_seqs=2,
enable_lora=True,
max_lora_rank=64,
limit_mm_per_prompt={"audio": audio_count},
)
# The model has an audio-specific lora directly in its model dir;
# it should be enabled whenever you pass audio inputs to the model.
speech_lora_path = model_name
audio_placeholder = "<|audio|>" * audio_count
prompts = f"<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>{audio_placeholder}{question}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" # noqa: E501
return ModelRequestData(
engine_args=engine_args,
prompt=prompts,
lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
)
# Kimi-Audio-7B-Instruct
def run_kimi_audio(question: str, audio_count: int) -> ModelRequestData:
"""Kimi-Audio-7B-Instruct for audio transcription and understanding."""
model_name = "moonshotai/Kimi-Audio-7B-Instruct"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
)
# Kimi-Audio uses <|im_kimia_text_blank|> as placeholder for audio features
audio_placeholder = "<|im_kimia_text_blank|>" * audio_count
# Default prompt for transcription
if not question:
question = "Please transcribe the audio"
prompt = f"{audio_placeholder}{question}"
# Stop at EOS token (151644) to prevent repetition
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
stop_token_ids=[151644],
)
# MiDashengLM
def run_midashenglm(question: str, audio_count: int):
model_name = "mispeech/midashenglm-7b"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
audio_in_prompt = "".join(
["<|audio_bos|><|AUDIO|><|audio_eos|>" for idx in range(audio_count)]
)
default_system = "You are a helpful language and speech assistant."
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# MiniCPM-O
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
model_name = "openbmb/MiniCPM-o-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
)
stop_tokens = ["<|im_end|>", "<|endoftext|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
audio_placeholder = "(<audio>./</audio>)" * audio_count
audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}" # noqa: E501
messages = [{"role": "user", "content": f"{audio_placeholder}\n{question}"}]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
chat_template=audio_chat_template,
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
stop_token_ids=stop_token_ids,
)
# Phi-4-multimodal-instruct
def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
"""
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
show how to process audio inputs.
"""
model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
speech_lora_path = os.path.join(model_path, "speech-lora")
placeholders = "".join([f"<|audio_{i + 1}|>" for i in range(audio_count)])
prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
engine_args = EngineArgs(
model=model_path,
trust_remote_code=True,
max_model_len=12800,
max_num_seqs=2,
enable_lora=True,
max_lora_rank=320,
limit_mm_per_prompt={"audio": audio_count},
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompts,
lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
)
# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
model_name = "Qwen/Qwen2-Audio-7B-Instruct"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
audio_in_prompt = "".join(
[
f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
for idx in range(audio_count)
]
)
prompt = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Qwen2.5-Omni
def run_qwen2_5_omni(question: str, audio_count: int):
model_name = "Qwen/Qwen2.5-Omni-7B"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
audio_in_prompt = "".join(
["<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)]
)
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech."
)
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
def run_qwen3_asr(question: str, audio_count: int) -> ModelRequestData:
model_name = "Qwen/Qwen3-Asr-1.7B"
audio_in_prompt = "<|audio_start|><|audio_pad|><|audio_end|>\n" * audio_count
prompt = f"<|im_start|>user\n{audio_in_prompt}<|im_end|>\n<|im_start|>assistant\n"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{"role": "user", "content": "<|audio|>\n" * audio_count + question}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
trust_remote_code=True,
limit_mm_per_prompt={"audio": audio_count},
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Voxtral
# Make sure to install mistral-common[audio].
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import (
AudioChunk,
RawAudio,
TextChunk,
)
from mistral_common.protocol.instruct.messages import (
UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
model_name = "mistralai/Voxtral-Mini-3B-2507"
tokenizer = MistralTokenizer.from_hf_hub(model_name)
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
config_format="mistral",
load_format="mistral",
tokenizer_mode="mistral",
enforce_eager=True,
enable_chunked_prefill=False,
)
text_chunk = TextChunk(text=question)
audios = [
Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
for i in range(audio_count)
]
audio_chunks = [
AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
]
messages = [UserMessage(content=[*audio_chunks, text_chunk])]
req = ChatCompletionRequest(messages=messages, model=model_name)
tokens = tokenizer.encode_chat_completion(req)
prompt_ids, audios = tokens.tokens, tokens.audios
audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
multi_modal_data = {"audio": audios_and_sr}
return ModelRequestData(
engine_args=engine_args,
prompt_token_ids=prompt_ids,
multi_modal_data=multi_modal_data,
)
# Whisper
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
assert audio_count == 1, "Whisper only support single audio input per prompt"
model_name = "openai/whisper-large-v3-turbo"
prompt = "<|startoftranscript|>"
engine_args = EngineArgs(
model=model_name,
max_model_len=448,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
model_example_map = {
"audioflamingo3": run_audioflamingo3,
"musicflamingo": run_musicflamingo,
"gemma3n": run_gemma3n,
"glmasr": run_glmasr,
"funaudiochat": run_funaudiochat,
"granite_speech": run_granite_speech,
"kimi_audio": run_kimi_audio,
"midashenglm": run_midashenglm,
"minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm,
"qwen2_audio": run_qwen2_audio,
"qwen2_5_omni": run_qwen2_5_omni,
"qwen3_asr": run_qwen3_asr,
"ultravox": run_ultravox,
"voxtral": run_voxtral,
"whisper": run_whisper,
}
def parse_args():
parser = FlexibleArgumentParser(
description="Demo on using vLLM for offline inference with "
"audio language models"
)
parser.add_argument(
"--model-type",
"-m",
type=str,
default="ultravox",
choices=model_example_map.keys(),
help='Huggingface "model_type".',
)
parser.add_argument(
"--model",
type=str,
default=None,
help="Model ID or local path override. Required for funaudiochat.",
)
parser.add_argument(
"--num-prompts", type=int, default=1, help="Number of prompts to run."
)
parser.add_argument(
"--num-audios",
type=int,
default=1,
choices=[0, 1, 2],
help="Number of audio items per prompt.",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Set the seed when initializing `vllm.LLM`.",
)
parser.add_argument(
"--tensor-parallel-size",
"-tp",
type=int,
default=None,
help="Tensor parallel size to override the model's default setting. ",
)
return parser.parse_args()
def main(args):
model = args.model_type
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")
if model == "funaudiochat" and not args.model:
raise ValueError("--model is required when --model-type=funaudiochat")
if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1:
raise ValueError(
f"tensor_parallel_size must be a positive integer, "
f"got {args.tensor_parallel_size}"
)
audio_count = args.num_audios
req_data = model_example_map[model](
question_per_audio_count[audio_count], audio_count
)
if model == "funaudiochat":
req_data.engine_args.model = args.model
# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {}
)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
if args.tensor_parallel_size is not None:
engine_args["tensor_parallel_size"] = args.tensor_parallel_size
llm = LLM(**engine_args)
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
)
def get_input(start, end):
mm_data = req_data.multi_modal_data
if not mm_data:
mm_data = {}
if end - start > 0:
mm_data = {
"audio": [
asset.audio_and_sample_rate for asset in audio_assets[start:end]
]
}
inputs = {"multi_modal_data": mm_data}
if req_data.prompt:
inputs["prompt"] = req_data.prompt
else:
inputs["prompt_token_ids"] = req_data.prompt_token_ids
return inputs
# Batch inference
assert args.num_prompts > 0
if audio_count != 1:
inputs = get_input(0, audio_count)
inputs = [inputs] * args.num_prompts
else:
# For single audio input, we need to vary the audio input
# to avoid deduplication in vLLM engine.
inputs = []
for i in range(args.num_prompts):
start = i % len(audio_assets)
inp = get_input(start, start + 1)
inputs.append(inp)
# Add LoRA request if applicable
lora_request = (
req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
)
outputs = llm.generate(
inputs,
sampling_params=sampling_params,
lora_request=lora_request,
)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstration script for Automatic Prefix Caching (APC) in vLLM.
Automatic Prefix Caching (APC) allows the vLLM engine to reuse cached
KV (key-value) pairs from previous prompts if a new query shares the same
prefix. This reduces redundant computation and improves inference speed.
To enable APC, set `enable_prefix_caching=True` when initializing the
vLLM engine.
This script uses a long Markdown table as the shared prompt prefix and
compares the generation time for two queries that share the same prefix
but ask different questions.
Run:
python examples/offline_inference/automatic_prefix_caching.py
"""
import time
from vllm import LLM, SamplingParams
# ruff: noqa: E501
# A prompt containing a large markdown table. The table is randomly generated by GPT-4.
LONG_PROMPT = (
"You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n"
"""
| ID | Name | Age | Occupation | Country | Email | Phone Number | Address |
|-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------|
| 1 | John Doe | 29 | Engineer | USA | john.doe@example.com | 555-1234 | 123 Elm St, Springfield, IL |
| 2 | Jane Smith | 34 | Doctor | Canada | jane.smith@example.com | 555-5678 | 456 Oak St, Toronto, ON |
| 3 | Alice Johnson | 27 | Teacher | UK | alice.j@example.com | 555-8765 | 789 Pine St, London, UK |
| 4 | Bob Brown | 45 | Artist | Australia | bob.b@example.com | 555-4321 | 321 Maple St, Sydney, NSW |
| 5 | Carol White | 31 | Scientist | New Zealand | carol.w@example.com | 555-6789 | 654 Birch St, Wellington, NZ |
| 6 | Dave Green | 28 | Lawyer | Ireland | dave.g@example.com | 555-3456 | 987 Cedar St, Dublin, IE |
| 7 | Emma Black | 40 | Musician | USA | emma.b@example.com | 555-1111 | 246 Ash St, New York, NY |
| 8 | Frank Blue | 37 | Chef | Canada | frank.b@example.com | 555-2222 | 135 Spruce St, Vancouver, BC |
| 9 | Grace Yellow | 50 | Engineer | UK | grace.y@example.com | 555-3333 | 864 Fir St, Manchester, UK |
| 10 | Henry Violet | 32 | Artist | Australia | henry.v@example.com | 555-4444 | 753 Willow St, Melbourne, VIC|
| 11 | Irene Orange | 26 | Scientist | New Zealand | irene.o@example.com | 555-5555 | 912 Poplar St, Auckland, NZ |
| 12 | Jack Indigo | 38 | Teacher | Ireland | jack.i@example.com | 555-6666 | 159 Elm St, Cork, IE |
| 13 | Karen Red | 41 | Lawyer | USA | karen.r@example.com | 555-7777 | 357 Cedar St, Boston, MA |
| 14 | Leo Brown | 30 | Chef | Canada | leo.b@example.com | 555-8888 | 246 Oak St, Calgary, AB |
| 15 | Mia Green | 33 | Musician | UK | mia.g@example.com | 555-9999 | 975 Pine St, Edinburgh, UK |
| 16 | Noah Yellow | 29 | Doctor | Australia | noah.y@example.com | 555-0000 | 864 Birch St, Brisbane, QLD |
| 17 | Olivia Blue | 35 | Engineer | New Zealand | olivia.b@example.com | 555-1212 | 753 Maple St, Hamilton, NZ |
| 18 | Peter Black | 42 | Artist | Ireland | peter.b@example.com | 555-3434 | 912 Fir St, Limerick, IE |
| 19 | Quinn White | 28 | Scientist | USA | quinn.w@example.com | 555-5656 | 159 Willow St, Seattle, WA |
| 20 | Rachel Red | 31 | Teacher | Canada | rachel.r@example.com | 555-7878 | 357 Poplar St, Ottawa, ON |
| 21 | Steve Green | 44 | Lawyer | UK | steve.g@example.com | 555-9090 | 753 Elm St, Birmingham, UK |
| 22 | Tina Blue | 36 | Musician | Australia | tina.b@example.com | 555-1213 | 864 Cedar St, Perth, WA |
| 23 | Umar Black | 39 | Chef | New Zealand | umar.b@example.com | 555-3435 | 975 Spruce St, Christchurch, NZ|
| 24 | Victor Yellow | 43 | Engineer | Ireland | victor.y@example.com | 555-5657 | 246 Willow St, Galway, IE |
| 25 | Wendy Orange | 27 | Artist | USA | wendy.o@example.com | 555-7879 | 135 Elm St, Denver, CO |
| 26 | Xavier Green | 34 | Scientist | Canada | xavier.g@example.com | 555-9091 | 357 Oak St, Montreal, QC |
| 27 | Yara Red | 41 | Teacher | UK | yara.r@example.com | 555-1214 | 975 Pine St, Leeds, UK |
| 28 | Zack Blue | 30 | Lawyer | Australia | zack.b@example.com | 555-3436 | 135 Birch St, Adelaide, SA |
| 29 | Amy White | 33 | Musician | New Zealand | amy.w@example.com | 555-5658 | 159 Maple St, Wellington, NZ |
| 30 | Ben Black | 38 | Chef | Ireland | ben.b@example.com | 555-7870 | 246 Fir St, Waterford, IE |
"""
)
def get_generation_time(llm, sampling_params, prompts):
# time the generation
start_time = time.time()
output = llm.generate(prompts, sampling_params=sampling_params)
end_time = time.time()
# print the output and generation time
print("-" * 30)
print(f"Output: {output[0].outputs[0].text}")
print(f"Generation time: {end_time - start_time} seconds.")
print("-" * 30)
def main():
# set enable_prefix_caching=True to enable APC
llm = LLM(model="lmsys/longchat-13b-16k", enable_prefix_caching=True)
sampling_params = SamplingParams(temperature=0, max_tokens=100)
# Querying the age of John Doe
get_generation_time(
llm,
sampling_params,
LONG_PROMPT
+ "Question: what is the age of John Doe? Your answer: The age of John Doe is ",
)
# Querying the age of Zack Blue
# This query will be faster since vllm avoids computing the KV cache of LONG_PROMPT again.
get_generation_time(
llm,
sampling_params,
LONG_PROMPT
+ "Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ",
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use Ray Data for data parallel batch inference.
Ray Data is a data processing framework that can process very large datasets
with first-class support for vLLM.
Ray Data provides functionality for:
* Reading and writing to most popular file formats and cloud object storage.
* Streaming execution, so you can run inference on datasets that far exceed
the aggregate RAM of the cluster.
* Scale up the workload without code changes.
* Automatic sharding, load-balancing, and autoscaling across a Ray cluster,
with built-in fault-tolerance and retry semantics.
* Continuous batching that keeps vLLM replicas saturated and maximizes GPU
utilization.
* Compatible with tensor/pipeline parallel inference.
Learn more about Ray Data's LLM integration:
https://docs.ray.io/en/latest/data/working-with-llms.html
"""
import ray
from packaging.version import Version
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig
assert Version(ray.__version__) >= Version("2.44.1"), (
"Ray version must be at least 2.44.1"
)
# Uncomment to reduce clutter in stdout
# ray.init(log_to_driver=False)
# ray.data.DataContext.get_current().enable_progress_bars = False
# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, binary format).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
print(ds.schema())
size = ds.count()
print(f"Size of dataset: {size} prompts")
# Configure vLLM engine.
config = vLLMEngineProcessorConfig(
model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4096,
"max_model_len": 16384,
},
concurrency=1, # set the number of parallel vLLM replicas
batch_size=64,
)
# Create a Processor object, which will be used to
# do batch inference on the dataset
vllm_processor = build_llm_processor(
config,
preprocess=lambda row: dict(
messages=[
{"role": "system", "content": "You are a bot that responds with haikus."},
{"role": "user", "content": row["text"]},
],
sampling_params=dict(
temperature=0.3,
max_tokens=250,
),
),
postprocess=lambda row: dict(
answer=row["generated_text"],
**row, # This will return all the original columns in the dataset.
),
)
ds = vllm_processor(ds)
# Peek first 10 results.
# NOTE: This is for local testing and debugging. For production use case,
# one should write full result out as shown below.
outputs = ds.take(limit=10)
for output in outputs:
prompt = output["prompt"]
generated_text = output["generated_text"]
print(f"Prompt: {prompt!r}")
print(f"Generated text: {generated_text!r}")
# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")

View File

@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
import json
import random
import string
from vllm import LLM
from vllm.sampling_params import SamplingParams
# This script is an offline demo for function calling
#
# If you want to run a server/client setup, please follow this code:
#
# - Server:
#
# ```bash
# vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral --load-format mistral --config-format mistral
# ```
#
# - Client:
#
# ```bash
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
# --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \
# --data '{
# "model": "mistralai/Mistral-7B-Instruct-v0.3"
# "messages": [
# {
# "role": "user",
# "content": [
# {"type" : "text", "text": "Describe this image in detail please."},
# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
# {"type" : "text", "text": "and this one as well. Answer in French."},
# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
# ]
# }
# ]
# }'
# ```
#
# Usage:
# python demo.py simple
# python demo.py advanced
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
# or switch to "mistralai/Mistral-Nemo-Instruct-2407"
# or "mistralai/Mistral-Large-Instruct-2407"
# or any other mistral model with function calling ability
sampling_params = SamplingParams(max_tokens=8192, temperature=0.0)
llm = LLM(
model=model_name,
tokenizer_mode="mistral",
config_format="mistral",
load_format="mistral",
)
def generate_random_id(length=9):
characters = string.ascii_letters + string.digits
random_id = "".join(random.choice(characters) for _ in range(length))
return random_id
# simulate an API that can be called
def get_current_weather(city: str, state: str, unit: "str"):
return (
f"The weather in {city}, {state} is 85 degrees {unit}. It is "
"partly cloudly, with highs in the 90's."
)
tool_functions = {"get_current_weather": get_current_weather}
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to find the weather for, e.g. 'San Francisco'",
},
"state": {
"type": "string",
"description": "the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "state", "unit"],
},
},
}
]
messages = [
{
"role": "user",
"content": "Can you tell me what the temperate will be in Dallas, in fahrenheit?",
}
]
outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
output = outputs[0].outputs[0].text.strip()
# append the assistant message
messages.append(
{
"role": "assistant",
"content": output,
}
)
# let's now actually parse and execute the model's output simulating an API call by using the
# above defined function
tool_calls = json.loads(output)
tool_answers = [
tool_functions[call["name"]](**call["arguments"]) for call in tool_calls
]
# append the answer as a tool message and let the LLM give you an answer
messages.append(
{
"role": "tool",
"content": "\n\n".join(tool_answers),
"tool_call_id": generate_random_id(),
}
)
outputs = llm.chat(messages, sampling_params, tools=tools)
print(outputs[0].outputs[0].text.strip())
# yields
# 'The weather in Dallas, TX is 85 degrees Fahrenheit. '
# 'It is partly cloudly, with highs in the 90's.'

View File

@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This script demonstrates how to extend the context length
of a Qwen model using the YARN method (rope_parameters)
and run a simple chat example.
Usage:
python examples/offline_inference/context_extension.py
"""
from vllm import LLM, RequestOutput, SamplingParams
def create_llm():
rope_theta = 1000000
original_max_position_embeddings = 32768
factor = 4.0
# Use yarn to extend context
hf_overrides = {
"rope_parameters": {
"rope_theta": rope_theta,
"rope_type": "yarn",
"factor": factor,
"original_max_position_embeddings": original_max_position_embeddings,
},
"max_model_len": int(original_max_position_embeddings * factor),
}
llm = LLM(model="Qwen/Qwen3-0.6B", hf_overrides=hf_overrides)
return llm
def run_llm_chat(llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=128,
)
conversation = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hello! How can I assist you today?"},
]
outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
return outputs, [
conversation,
]
def print_outputs(outputs: list[RequestOutput], conversations: list):
print("\nGenerated Outputs:\n" + "-" * 80)
for i, output in enumerate(outputs):
prompt = conversations[i]
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n")
print(f"Generated text: {generated_text!r}")
print("-" * 80)
def main():
llm = create_llm()
outputs, conversations = run_llm_chat(llm)
print_outputs(outputs, conversations)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,214 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Usage:
Single node:
python examples/offline_inference/data_parallel.py \
--model="ibm-research/PowerMoE-3b" \
-dp=2 \
-tp=2
Multi-node:
Node 0 (assume the node has ip of 10.99.48.128):
python examples/offline_inference/data_parallel.py \
--model="ibm-research/PowerMoE-3b" \
-dp=2 \
-tp=2 \
--dp-num-nodes=2 \
--dp-node-rank=0 \
--dp-master-addr=10.99.48.128 \
--dp-master-port=13345
Node 1:
python examples/offline_inference/data_parallel.py \
--model="ibm-research/PowerMoE-3b" \
-dp=2 \
-tp=2 \
--dp-num-nodes=2 \
--dp-node-rank=1 \
--dp-master-addr=10.99.48.128 \
--dp-master-port=13345
"""
import os
from time import sleep
from vllm import LLM, EngineArgs, SamplingParams
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import get_open_port
def create_parser():
parser = FlexibleArgumentParser(description="Data Parallel Inference")
# Add all engine args
EngineArgs.add_cli_args(parser)
parser.set_defaults(
model="ibm-research/PowerMoE-3b",
enable_expert_parallel=True,
)
# Add DP-specific args (separate from engine args to avoid conflicts)
parser.add_argument(
"--dp-num-nodes",
type=int,
default=1,
help="Total number of nodes for data parallel.",
)
parser.add_argument(
"--dp-node-rank",
type=int,
default=0,
help="Rank of the current node for data parallel.",
)
parser.add_argument(
"--dp-master-addr",
type=str,
default="",
help="Master node IP address for DP coordination.",
)
parser.add_argument(
"--dp-master-port",
type=int,
default=0,
help="Master node port for DP coordination.",
)
parser.add_argument(
"--timeout",
type=int,
default=300,
help="Number of seconds before unresponsive process is killed.",
)
return parser
def main(
dp_size,
local_dp_rank,
global_dp_rank,
dp_master_ip,
dp_master_port,
engine_args,
):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
# CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
# engine processes.
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
] * 100
# with DP, each rank should process different prompts.
# usually all the DP ranks process a full dataset,
# and each rank processes a different part of the dataset.
floor = len(prompts) // dp_size
remainder = len(prompts) % dp_size
# Distribute prompts into even groups.
def start(rank):
return rank * floor + min(rank, remainder)
prompts = prompts[start(global_dp_rank) : start(global_dp_rank + 1)]
if len(prompts) == 0:
# if any rank has no prompts to process,
# we need to set a placeholder prompt
prompts = ["Placeholder"]
print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts")
# Create a sampling params object.
# since we are doing data parallel, every rank can have different
# sampling params. here we set different max_tokens for different
# ranks for demonstration.
sampling_params = SamplingParams(
temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2]
)
# Create an LLM.
llm = LLM(**engine_args)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for i, output in enumerate(outputs):
if i >= 5:
# print only 5 outputs
break
prompt = output.prompt
generated_text = output.outputs[0].text
print(
f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}"
)
# Give engines time to pause their processing loops before exiting.
sleep(1)
if __name__ == "__main__":
parser = create_parser()
args = vars(parser.parse_args())
# Extract DP-specific args (pop to remove from engine_args)
dp_size = args.pop("data_parallel_size")
dp_num_nodes = args.pop("dp_num_nodes")
dp_node_rank = args.pop("dp_node_rank")
dp_master_addr = args.pop("dp_master_addr")
dp_master_port = args.pop("dp_master_port")
timeout = args.pop("timeout")
# Remaining args are engine args
engine_args = args
if dp_num_nodes == 1:
dp_master_ip = "127.0.0.1"
dp_master_port_val = get_open_port()
else:
dp_master_ip = dp_master_addr
dp_master_port_val = dp_master_port
assert dp_size % dp_num_nodes == 0, "dp_size should be divisible by dp_num_nodes"
dp_per_node = dp_size // dp_num_nodes
from multiprocessing import Process
if current_platform.is_rocm():
from multiprocessing import set_start_method
set_start_method("spawn", force=True)
procs = []
for local_dp_rank, global_dp_rank in enumerate(
range(dp_node_rank * dp_per_node, (dp_node_rank + 1) * dp_per_node)
):
proc = Process(
target=main,
args=(
dp_size,
local_dp_rank,
global_dp_rank,
dp_master_ip,
dp_master_port_val,
engine_args,
),
)
proc.start()
procs.append(proc)
exit_code = 0
for proc in procs:
proc.join(timeout=timeout)
if proc.exitcode is None:
print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
proc.kill()
exit_code = 1
elif proc.exitcode:
exit_code = proc.exitcode
exit(exit_code)

View File

@@ -0,0 +1,10 @@
# Disaggregated Prefill V1
This example contains scripts that demonstrate disaggregated prefill in the offline setting of vLLM.
## Files
- `run.sh` - A helper script that will run `prefill_example.py` and `decode_example.py` sequentially.
- Make sure you are in the `examples/offline_inference/disaggregated-prefill-v1` directory before running `run.sh`.
- `prefill_example.py` - A script which performs prefill only, saving the KV state to the `local_storage` directory and the prompts to `output.txt`.
- `decode_example.py` - A script which performs decode only, loading the KV state from the `local_storage` directory and the prompts from `output.txt`.

View File

@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
def read_prompts():
"""Read prompts from output.txt"""
prompts = []
try:
with open("output.txt") as f:
for line in f:
prompts.append(line.strip())
print(f"Loaded {len(prompts)} prompts from output.txt")
return prompts
except FileNotFoundError:
print("Error: output.txt file not found")
exit(-1)
def main():
prompts = read_prompts()
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
max_num_batched_tokens=64,
max_num_seqs=16,
kv_transfer_config=KVTransferConfig(
kv_connector="ExampleConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
),
) # , max_model_len=2048, max_num_batched_tokens=2048)
# 1ST generation (prefill instance)
outputs = llm.generate(prompts, sampling_params)
print("-" * 30)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 30)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
def read_prompts():
context = "Hi " * 1000
context2 = "Hey " * 500
return [
context + "Hello, my name is",
context + "The capital of France is",
context2 + "Your name is",
context2 + "The capital of China is",
]
def main():
prompts = read_prompts()
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
kv_transfer_config=KVTransferConfig(
kv_connector="ExampleConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
),
) # , max_model_len=2048, max_num_batched_tokens=2048)
# 1ST generation (prefill instance)
outputs = llm.generate(
prompts,
sampling_params,
)
new_prompts = []
print("-" * 30)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
new_prompts.append(prompt + generated_text)
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 30)
# Write new_prompts to output.txt
with open("output.txt", "w") as f:
for prompt in new_prompts:
f.write(prompt + "\n")
print(f"Saved {len(new_prompts)} prompts to output.txt")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,11 @@
rm -rf local_storage/
if [ -f "output.txt" ]; then
rm output.txt
fi
# The directory of current script
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 "$SCRIPT_DIR/prefill_example.py"
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 "$SCRIPT_DIR/decode_example.py"

View File

@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file demonstrates the example usage of disaggregated prefilling
We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode),
and then transfer the KV cache between them.
"""
import os
import time
from multiprocessing import Event, Process
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
def run_prefill(prefill_done):
# We use GPU 0 for prefill node.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# The prefill node receives two requests, while the decode node receives
# three requests. So the decode node will only receive the KV Cache for
# requests 1 and 3. The decode node will use the KV Cache of requests 1
# and 3 and do prefilling on request 2.
prompts = [
"Hello, my name is",
"Hi, your name is",
# The decode node will actually "prefill" this request.
"Tell me a very long story",
]
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
# Using P2pNcclConnector to transmit KV caches between vLLM instances.
# This instance is the prefill node (kv_producer, rank 0).
# The number of parallel instances for KV cache transfer is set to 2,
# as required for P2pNcclConnector.
ktc = KVTransferConfig(
kv_connector="P2pNcclConnector",
kv_role="kv_producer",
kv_rank=0,
kv_parallel_size=2,
)
# Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
# memory. You may need to adjust the value to fit your GPU.
llm = LLM(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
kv_transfer_config=ktc,
max_model_len=2000,
gpu_memory_utilization=0.8,
)
llm.generate(prompts, sampling_params)
print("Prefill node is finished.")
prefill_done.set()
# To keep the prefill node running in case the decode node is not done;
# otherwise, the script might exit prematurely, causing incomplete decoding.
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
print("Script stopped by user.")
def run_decode(prefill_done):
# We use GPU 1 for decode node.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
prompts = [
"Hello, my name is",
"Hi, your name is",
"Tell me a very long story",
]
sampling_params = SamplingParams(temperature=0, top_p=0.95)
# Using P2pNcclConnector to transmit KV caches between vLLM instances.
# This instance is the decode node (kv_consumer, rank 1).
# The number of parallel instances for KV cache transfer is set to 2,
# as required for P2pNcclConnector.
ktc = KVTransferConfig(
kv_connector="P2pNcclConnector",
kv_role="kv_consumer",
kv_rank=1,
kv_parallel_size=2,
)
# Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
# memory. You may need to adjust the value to fit your GPU.
llm = LLM(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
kv_transfer_config=ktc,
max_model_len=2000,
gpu_memory_utilization=0.8,
)
# Wait for the producer to start the pipe
print("Waiting for prefill node to finish...")
prefill_done.wait()
# At this point when the prefill_done is set, the kv-cache should have been
# transferred to this decode node, so we can start decoding.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
def main():
prefill_done = Event()
prefill_process = Process(target=run_prefill, args=(prefill_done,))
decode_process = Process(target=run_decode, args=(prefill_done,))
# Start prefill node
prefill_process.start()
# Start decode node
decode_process.start()
# Terminate the prefill node when decode is finished
decode_process.join()
prefill_process.terminate()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation.
"""
import os
import time
from collections.abc import Sequence
from dataclasses import asdict
from typing import NamedTuple
from vllm import LLM, EngineArgs, PromptType, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.utils.argparse_utils import FlexibleArgumentParser
class ModelRequestData(NamedTuple):
engine_args: EngineArgs
prompts: Sequence[PromptType]
def run_whisper():
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
engine_args = EngineArgs(
model="openai/whisper-large-v3-turbo",
max_model_len=448,
max_num_seqs=16,
limit_mm_per_prompt={"audio": 1},
dtype="half",
)
prompts = [
{ # Test implicit prompt
"prompt": "<|startoftranscript|>",
"multi_modal_data": {
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
},
},
{ # Test explicit encoder/decoder prompt
"encoder_prompt": {
"prompt": "",
"multi_modal_data": {
"audio": AudioAsset("winning_call").audio_and_sample_rate,
},
},
"decoder_prompt": "<|startoftranscript|>",
},
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
model_example_map = {
"whisper": run_whisper,
}
def parse_args():
parser = FlexibleArgumentParser(
description="Demo on using vLLM for offline inference with "
"vision language models for text generation"
)
parser.add_argument(
"--model-type",
"-m",
type=str,
default="whisper",
choices=model_example_map.keys(),
help='Huggingface "model_type".',
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Set the seed when initializing `vllm.LLM`.",
)
return parser.parse_args()
def main(args):
model = args.model_type
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")
req_data = model_example_map[model]()
# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {}
)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args)
prompts = req_data.prompts
# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0,
top_p=1.0,
max_tokens=64,
skip_special_tokens=False,
)
start = time.time()
# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Decoder prompt: {prompt!r}, Generated text: {generated_text!r}")
duration = time.time() - start
print("Duration:", duration)
print("RPS:", len(prompts) / duration)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from safetensors import safe_open
from vllm import LLM, SamplingParams
# Example: Using the custom "extract_hidden_states" speculator method and
# ExampleHiddenStatesConnector to extract and save hidden states from vllm
with tempfile.TemporaryDirectory() as tmpdirname:
llm = LLM(
model="Qwen/Qwen3-8B", # Your target model
speculative_config={
"method": "extract_hidden_states",
"num_speculative_tokens": 1,
"draft_model_config": {
"hf_config": {
"eagle_aux_hidden_state_layer_ids": [ # Target model layer indices
1,
2,
3,
4,
],
}
},
},
kv_transfer_config={
"kv_connector": "ExampleHiddenStatesConnector",
"kv_role": "kv_producer",
"kv_connector_extra_config": {
"shared_storage_path": tmpdirname,
},
},
)
prompts = ["Generate a sentence with hidden states", "Write a python function"]
sampling_params = SamplingParams(max_tokens=1)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
print("\nPrompt:", output.prompt)
print("Prompt token ids:", output.prompt_token_ids)
hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
assert hidden_states_path is not None
print("Prompt hidden states path:", hidden_states_path)
with safe_open(hidden_states_path, "pt") as f:
token_ids = f.get_tensor("token_ids")
hidden_states = f.get_tensor("hidden_states")
print("Extracted token ids:", token_ids) # Matches prompt token ids
print(
"Extracted hidden states shape:", hidden_states.shape
) # [num_hidden_layers, prompt len, hidden size]
print("Extracted hidden states:", hidden_states)

View File

@@ -0,0 +1,31 @@
# KV Load Failure Recovery Test
This example builds upon the `disaggregated-prefill-v1` example in `examples/offline_inference`.
It demonstrates vLLM's ability to recover from KV load failures in both synchronous and asynchronous loading modes. The goal is to verify that vLLM correctly identifies invalid KV blocks, reschedules the affected requests, and ensures successful and consistent output.
## Files
- `prefill_example.py` performs the prefill stage and saves KV data (same as in `disaggregated-prefill-v1`).
- `decode_example.py` performs the decode stage. Accepts:
- `--simulate-failure`: simulates KV load failure using a custom connector.
- `--async-load`: enables asynchronous KV loading mode.
- `load_recovery_example_connector.py` defines `LoadRecoveryExampleConnector`, a subclass of `ExampleConnector`, that simulates missing or corrupted external KV blocks by failing to load blocks for the first decode request.
- `run.sh` orchestrates the test: runs the prefill stage, then three decode stages:
1. Normal decode (baseline).
2. Decode with simulated sync KV load failure.
3. Decode with simulated async KV load failure.
Finally, it compares the output of the baseline with the recovered outputs to verify correctness.
## How It Works
- The test dynamically loads `LoadRecoveryExampleConnector` via `KVTransferConfig.kv_connector_module_path`, enabling controlled simulation of load failures without modifying the original connector.
- The decode stages that simulate failure are expected to trigger recovery logic in vLLM, resulting in the same output as the baseline decode.
- If recovery fails, the script prints a unified diff of the output mismatch and exits with error.
## Usage
```bash
./run.sh
```

View File

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
def read_prompts():
"""Read prompts from prefill_output.txt"""
prompts = []
try:
with open("prefill_output.txt") as f:
for line in f:
prompts.append(line.strip())
print(f"Loaded {len(prompts)} prompts from prefill_output.txt")
return prompts
except FileNotFoundError:
print("Error: prefill_output.txt file not found")
exit(-1)
def main():
prompts = read_prompts()
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
parser = argparse.ArgumentParser()
parser.add_argument(
"--simulate-failure", action="store_true", help="Simulate KV load failure."
)
parser.add_argument(
"--async-load", action="store_true", help="Simulate async KV load"
)
args = parser.parse_args()
if args.simulate_failure:
ktc = KVTransferConfig(
kv_connector="LoadRecoveryExampleConnector",
kv_role="kv_both",
kv_connector_extra_config={
"shared_storage_path": "local_storage",
"async_load": args.async_load,
},
kv_connector_module_path="load_recovery_example_connector",
kv_load_failure_policy="recompute",
)
out_file = (
"async_decode_recovered_output.txt"
if args.async_load
else "sync_decode_recovered_output.txt"
)
else:
ktc = KVTransferConfig(
kv_connector="ExampleConnector",
kv_role="kv_both",
kv_connector_extra_config={
"shared_storage_path": "local_storage",
},
)
out_file = "decode_output.txt"
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
max_num_batched_tokens=64,
max_num_seqs=16,
kv_transfer_config=ktc,
)
outputs = llm.generate(prompts, sampling_params)
sep_str = "-" * 30
with open(out_file, "w", encoding="utf-8") as f:
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}"
print(out_str)
print(sep_str)
f.write(out_str)
f.write(sep_str)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (
ExampleConnector,
ExampleConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
@dataclass
class LoadRecoveryExampleConnectorMetadata(ExampleConnectorMetadata):
req_to_block_ids: dict[str, set[int]] = field(default_factory=dict)
@classmethod
def from_base(cls, base: ExampleConnectorMetadata):
return cls(requests=base.requests)
class LoadRecoveryExampleConnector(ExampleConnector):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._async_load = vllm_config.kv_transfer_config.get_from_extra_config(
"async_load", False
)
self._invalid_block_ids: set = None
self._seen_requests: set = set()
self._req_to_block_ids: dict[str, list[int]] = dict()
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
assert isinstance(connector_metadata, LoadRecoveryExampleConnectorMetadata)
index, failed_request = next(
(
(i, x)
for i, x in enumerate(connector_metadata.requests)
if not x.is_store
),
(None, None),
)
if index is not None:
del connector_metadata.requests[index]
self._invalid_block_ids = set(
(
failed_request.slot_mapping[:: self._block_size] // self._block_size
).tolist()
)
logger.info(
"Simulating failure to load all KV blocks for the "
"first load request. Total blocks: %d",
len(self._invalid_block_ids),
)
super().bind_connector_metadata(connector_metadata)
def clear_connector_metadata(self) -> None:
self._invalid_block_ids = None
super().clear_connector_metadata()
def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None:
if self._async_load and forward_context.attn_metadata is None:
# Bypass sanity check in super().start_load_kv
forward_context.attn_metadata = "None"
super().start_load_kv(forward_context, **kwargs)
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
if self._async_load:
meta = self._get_connector_metadata()
assert isinstance(meta, LoadRecoveryExampleConnectorMetadata)
if meta.req_to_block_ids:
return None, set(meta.req_to_block_ids)
return None, None
def get_block_ids_with_load_errors(self) -> set[int]:
return self._invalid_block_ids
def get_num_new_matched_tokens(
self,
request: Request,
num_computed_tokens: int,
) -> tuple[int, bool]:
if request.request_id in self._seen_requests:
return 0, False
self._seen_requests.add(request.request_id)
num_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens)
return num_tokens, self._async_load and num_tokens > 0
def update_state_after_alloc(
self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
):
"""
Update KVConnector state after block allocation.
If blocks were allocated, add to _requests_need_load,
such that we load the KVs in the next forward pass.
"""
super().update_state_after_alloc(request, blocks, num_external_tokens)
if num_external_tokens > 0:
self._req_to_block_ids[request.request_id] = blocks.get_block_ids()[0]
def build_connector_meta(
self,
scheduler_output: "SchedulerOutput",
) -> KVConnectorMetadata:
if not self._async_load:
base = super().build_connector_meta(scheduler_output)
meta = LoadRecoveryExampleConnectorMetadata.from_base(base)
else:
meta = LoadRecoveryExampleConnectorMetadata()
if self._requests_need_load:
for req_id, request in self._requests_need_load.items():
meta.add_request(
token_ids=request.prompt_token_ids,
block_ids=self._req_to_block_ids[req_id],
block_size=self._block_size,
is_store=False,
mm_hashes=[],
)
# Clear state
self._requests_need_load.clear()
meta.req_to_block_ids = self._req_to_block_ids
self._req_to_block_ids = dict()
return meta

View File

@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
def read_prompts():
context = "Hi " * 1000
context2 = "Hey " * 500
return [
context + "Hello, my name is",
context + "The capital of France is",
context2 + "Your name is",
context2 + "The capital of China is",
]
def main():
prompts = read_prompts()
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
kv_transfer_config=KVTransferConfig(
kv_connector="ExampleConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
),
) # , max_model_len=2048, max_num_batched_tokens=2048)
# 1ST generation (prefill instance)
outputs = llm.generate(
prompts,
sampling_params,
)
new_prompts = []
print("-" * 30)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
new_prompts.append(prompt + generated_text)
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 30)
# Write new_prompts to prefill_output.txt
with open("prefill_output.txt", "w") as f:
for prompt in new_prompts:
f.write(prompt + "\n")
print(f"Saved {len(new_prompts)} prompts to prefill_output.txt")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,33 @@
#!/bin/bash
# Constants
SHARED_STORAGE_DIR="local_storage"
PREFILL_OUTPUT="prefill_output.txt"
DECODE_OUTPUT="decode_output.txt"
SYNC_DECODE_RECOVERED_OUTPUT="sync_decode_recovered_output.txt"
ASYNC_DECODE_RECOVERED_OUTPUT="async_decode_recovered_output.txt"
# Cleanup
rm -rf "$SHARED_STORAGE_DIR"
rm -f "$PREFILL_OUTPUT" "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
# Run inference examples
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure --async-load
# Compare outputs
if ! cmp -s "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"; then
echo "❌ Outputs differ: sync recovery failed."
diff -u "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"
exit 1
fi
if ! cmp -s "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"; then
echo "❌ Outputs differ: async recovery failed."
diff -u "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
exit 1
fi
echo "✅ Outputs match: recovery successful."

View File

@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file demonstrates using the `LLMEngine`
for processing prompts with various sampling parameters.
"""
import argparse
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
def create_test_prompts() -> list[tuple[str, SamplingParams]]:
"""Create a list of test prompts with their sampling parameters."""
return [
(
"A robot may not injure a human being",
SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1),
),
(
"To be or not to be,",
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2),
),
(
"What is the meaning of life?",
SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1),
),
]
def process_requests(engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
print("-" * 50)
while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params = test_prompts.pop(0)
engine.add_request(str(request_id), prompt, sampling_params)
request_id += 1
request_outputs: list[RequestOutput] = engine.step()
for request_output in request_outputs:
if request_output.finished:
print(request_output)
print("-" * 50)
def initialize_engine(args: argparse.Namespace) -> LLMEngine:
"""Initialize the LLMEngine from the command line arguments."""
engine_args = EngineArgs.from_cli_args(args)
return LLMEngine.from_engine_args(engine_args)
def parse_args():
parser = FlexibleArgumentParser(
description="Demo on using the LLMEngine class directly"
)
parser = EngineArgs.add_cli_args(parser)
return parser.parse_args()
def main(args: argparse.Namespace):
"""Main function that sets up and runs the prompt processing."""
engine = initialize_engine(args)
test_prompts = create_test_prompts()
process_requests(engine, test_prompts)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file demonstrates preempt requests when using the `LLMEngine`
for processing prompts with various sampling parameters.
"""
import argparse
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
def create_test_prompts() -> list[tuple[str, SamplingParams]]:
"""Create a list of test prompts with their sampling parameters."""
return [
(
"A robot may not injure a human being " * 50,
SamplingParams(
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16
),
),
(
"A robot may not injure a human being " * 50,
SamplingParams(
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16
),
),
(
"To be or not to be,",
SamplingParams(
temperature=0.8, top_k=5, presence_penalty=0.2, max_tokens=128
),
),
(
"What is the meaning of life?",
SamplingParams(
n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1, max_tokens=128
),
),
]
def process_requests(engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
print("-" * 50)
step_id = 0
while test_prompts or engine.has_unfinished_requests():
print("-" * 50)
import os
print(f"Step {step_id} (pid={os.getpid()})")
if test_prompts:
prompt, sampling_params = test_prompts.pop(0)
engine.add_request(str(request_id), prompt, sampling_params)
request_id += 1
if step_id == 10:
print(f"Resetting prefix cache at {step_id}")
engine.reset_prefix_cache(reset_running_requests=True)
request_outputs: list[RequestOutput] = engine.step()
for request_output in request_outputs:
if request_output.finished:
print("-" * 50)
print(request_output)
print("-" * 50)
step_id += 1
def initialize_engine(args: argparse.Namespace) -> LLMEngine:
"""Initialize the LLMEngine from the command line arguments."""
engine_args = EngineArgs.from_cli_args(args)
return LLMEngine.from_engine_args(engine_args)
def parse_args():
parser = FlexibleArgumentParser(
description="Demo on using the LLMEngine class directly"
)
parser = EngineArgs.add_cli_args(parser)
return parser.parse_args()
def main(args: argparse.Namespace):
"""Main function that sets up and runs the prompt processing."""
engine = initialize_engine(args)
test_prompts = create_test_prompts()
process_requests(engine, test_prompts)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Validates the loading of a model saved with the sharded_state format.
This script demonstrates how to load a model that was previously saved
using save_sharded_state.py and validates it by running inference.
Example usage:
(First need to save a sharded_state mode)
python save_sharded_state.py \
--model /path/to/load \
--tensor-parallel-size 8 \
--output /path/to/save/sharded/model
python load_sharded_state.py \
--model /path/to/saved/sharded/model \
--load-format sharded_state \
--tensor-parallel-size 8 \
--prompt "Hello, my name is" \
--max-tokens 50
"""
import dataclasses
from vllm import LLM, EngineArgs, SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
def parse_args():
parser = FlexibleArgumentParser()
# Add engine arguments
EngineArgs.add_cli_args(parser)
# Override default load_format for clarity
parser.set_defaults(load_format="sharded_state")
# Add validation arguments
parser.add_argument(
"--prompt", type=str, default="Hello, world!", help="Prompt for validation"
)
parser.add_argument(
"--max-tokens",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temperature", type=float, default=0.7, help="Sampling temperature"
)
parser.add_argument(
"--top-p", type=float, default=1.0, help="Top-p sampling parameter"
)
return parser.parse_args()
def main():
args = parse_args()
engine_args = EngineArgs.from_cli_args(args)
print(
f"Loading model from {engine_args.model} using format {engine_args.load_format}"
)
print(f"Tensor parallel size: {engine_args.tensor_parallel_size}")
# Load the model using engine args
llm = LLM(**dataclasses.asdict(engine_args))
# Prepare sampling parameters
sampling_params = SamplingParams(
temperature=args.temperature,
top_p=args.top_p,
max_tokens=args.max_tokens,
)
print("\nRunning inference:")
print(f"Prompt: {args.prompt}")
# Generate completion
outputs = llm.generate(args.prompt, sampling_params)
# Display generated text
print("\nGenerated outputs:")
for output in outputs:
generated_text = output.outputs[0].text
print("-" * 50)
print(f"Full output: {args.prompt}{generated_text}")
print("-" * 50)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,40 @@
# Custom Logits Processors
This directory contains examples demonstrating how to use custom logits processors with vLLM's offline inference API. Logits processors allow you to modify the model's output distribution before sampling, enabling controlled generation behaviors like token masking, constrained decoding, and custom sampling strategies.
## Scripts
### `custom.py` — Engine-level logits processor
Demonstrates how to instantiate vLLM with a custom logits processor class that operates at the batch level. The example uses a `DummyLogitsProcessor` that masks out all tokens except a specified `target_token` when passed via `SamplingParams.extra_args`.
```bash
python examples/offline_inference/logits_processor/custom.py
```
### `custom_req.py` — Request-level logits processor wrapper
Shows how to wrap a request-level logits processor (which operates on individual requests) to be compatible with vLLM's batch-level logits processing interface.
```bash
python examples/offline_inference/logits_processor/custom_req.py
```
### `custom_req_init.py` — Request-level processor with engine config
A special case of wrapping a request-level logits processor where the processor needs access to engine configuration or model metadata during initialization (e.g., vocabulary size, tokenizer info).
```bash
python examples/offline_inference/logits_processor/custom_req_init.py
```
## Key Concepts
- **Batch-level vs. request-level**: vLLM processes logits at the batch level for efficiency. If you have a per-request processor, you need to wrap it using the patterns shown in `custom_req.py` and `custom_req_init.py`.
- **`SamplingParams.extra_args`**: Use this to pass custom keyword arguments to your logits processor on a per-request basis (e.g., `target_token`).
- **`DummyLogitsProcessor`**: A reference implementation available in `vllm/test_utils.py` that can be used as a starting point for custom processors.
## Further Reading
- [vLLM Sampling Parameters](https://docs.vllm.ai/en/latest/api/inference_params.html#sampling-parameters)
- [vLLM LLM API](https://docs.vllm.ai/en/latest/api/offline_inference/llm.html)

View File

@@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This example demonstrates instantiating vLLM with a custom logits processor
class object.
For a basic example of implementing a custom logits processor, see
the `DummyLogitsProcessor` implementation in `vllm/test_utils.py`.
For testing purposes, a dummy logits processor is employed which, if
`target_token` is passed as a keyword argument to `SamplingParams.extra_args`,
will mask out all tokens except `target_token`.
A batch is constructed with `temperature=0.0` and 50% of requests specifying
`target_token`, and for these requests - and *only* these requests - we
expect the `target_token` to be decoded in each step, yielding an output
similar to that shown below:
Generated Outputs:
------------------------------------------------------------
Prompt: 'Hello, my name is'
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
------------------------------------------------------------
Prompt: 'The president of the United States is'
Output: " not a racist. He is a racist.\nHe's a racist because he"
------------------------------------------------------------
Prompt: 'The capital of France is'
Output: ' also also also also also also also also also also also also also
also also also'
------------------------------------------------------------
Prompt: 'The future of AI is'
Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------
"""
from typing import Any
import torch
from vllm import LLM, SamplingParams
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import (
BatchUpdate,
LogitsProcessor,
)
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
# Hypothetical custom logits processor
class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples"""
@classmethod
def validate_params(cls, params: SamplingParams):
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"target_token value {target_token} {type(target_token)} is not int"
)
def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
self.req_info: dict[int, int] = {}
def is_argmax_invariant(self) -> bool:
return False
def update_state(self, batch_update: BatchUpdate | None):
def extract_extra_arg(params: SamplingParams) -> int | None:
self.validate_params(params)
return params.extra_args and params.extra_args.get("target_token")
process_dict_updates(
self.req_info,
batch_update,
# This function returns the LP's per-request state based on the
# request details, or None if this LP does not apply to the
# request.
lambda params, _, __: extract_extra_arg(params),
)
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not self.req_info:
return logits
# Save target values before modification
cols = torch.tensor(
list(self.req_info.values()), dtype=torch.long, device=logits.device
)
rows = torch.tensor(
list(self.req_info.keys()), dtype=torch.long, device=logits.device
)
values_to_keep = logits[rows, cols].clone()
# Mask all but target tokens
logits[rows] = float("-inf")
logits[rows, cols] = values_to_keep
return logits
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a mixture of requests which do and don't utilize the dummy logitproc
sampling_params_list = [
SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
SamplingParams(temperature=0.0),
SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
SamplingParams(temperature=0.0),
]
def main():
# Create an LLM.
llm = LLM(
model="facebook/opt-125m",
logits_processors=[DummyLogitsProcessor],
)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params_list)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Output: {generated_text!r}")
print("-" * 60)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This example demonstrates wrapping a request-level logits processor to be
compatible with vLLM's batch-level logits processing
For demo purposes, a dummy logits processor is employed which, if
`target_token` is passed as a keyword argument to `SamplingParams.extra_args`,
will mask out all tokens except `target_token`. This logits processor can be
applied to a vector of logits associated with a single decode step for a single
request. The logits processor cannot be applied to a request which does not
pass in a `target_token` custom argument.
The request-level dummy logits processor is wrapped to create a batch-level
logits processor, which can apply the logits processor to output logits from
all requests in the persistent batch in a given decode step. For requests which
do not provide a `target_token` argument, the corresponding row of `logits`
will not be modified.
A batch is constructed with `temperature=0.0` and 50% of requests specifying
`target_token`, and for these requests - and *only* these requests - we
expect the `target_token` to be decoded in each step, yielding an output
similar to that shown below:
Generated Outputs:
------------------------------------------------------------
Prompt: 'Hello, my name is'
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
------------------------------------------------------------
Prompt: 'The president of the United States is'
Output: " not a racist. He is a racist.\nHe's a racist because he"
------------------------------------------------------------
Prompt: 'The capital of France is'
Output: ' also also also also also also also also also also also also also
also also also'
------------------------------------------------------------
Prompt: 'The future of AI is'
Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------
"""
from typing import Any
import torch
from vllm import LLM, SamplingParams
from vllm.logger import init_logger
from vllm.v1.sample.logits_processor import (
AdapterLogitsProcessor,
RequestLogitsProcessor,
)
logger = init_logger(__name__)
class DummyPerReqLogitsProcessor:
"""The request-level logits processor masks out all logits except the
token id identified by `target_token`"""
def __init__(self, target_token: int) -> None:
"""Specify `target_token`"""
self.target_token = target_token
def __call__(
self,
output_ids: list[int],
logits: torch.Tensor,
) -> torch.Tensor:
val_to_keep = logits[self.target_token].item()
logits[:] = float("-inf")
logits[self.target_token] = val_to_keep
return logits
class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of wrapping a fake request-level logit processor to create a
batch-level logits processor"""
@classmethod
def validate_params(cls, params: SamplingParams):
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(f"target_token value {target_token} is not int")
def is_argmax_invariant(self) -> bool:
return False
def new_req_logits_processor(
self,
params: SamplingParams,
) -> RequestLogitsProcessor | None:
"""This method returns a new request-level logits processor, customized
to the `target_token` value associated with a particular request.
Returns None if the logits processor should not be applied to the
particular request. To use the logits processor the request must have
a "target_token" custom argument with an integer value.
Args:
params: per-request sampling params
Returns:
`Callable` request logits processor, or None
"""
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is None:
return None
return DummyPerReqLogitsProcessor(target_token)
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a mixture of requests which do and don't utilize the dummy logitproc
sampling_params_list = [
SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
SamplingParams(temperature=0.0),
SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
SamplingParams(temperature=0.0),
]
def main():
# Create an LLM.
llm = LLM(
model="facebook/opt-125m",
logits_processors=[WrappedPerReqLogitsProcessor],
)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params_list)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Output: {generated_text!r}")
print("-" * 60)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This example demonstrates a special case of wrapping a request-level logits
processor, namely the case where it is necessary to utilize engine config or
environment info passed to the constructor. The subclass must override the
wrapper base class `__init__()` method to access the engine config, the device
identifier, or the flag which indicates whether pinned memory is available.
For demo purposes, a request-level dummy logits processor is employed which
causes the same token (`target_token`) to be decoded in each step. The
request-level dummy logits processor is wrapped to create a batch-level logits
processor, which can apply the logits processor to output logits from all
requests in the persistent batch in a given decode step.
The wrapped dummy logits processor below models a scenario where we must
disable the logits processor on non-"cuda" platforms. The wrapper base class
`__init__()` is overridden in order to check this condition and set a flag.
A batch is constructed with `temperature=0.0` and 50% of requests specifying
`target_token`, and for these requests - and *only* these requests - we
expect that on a "cuda" device the output will look something like:
Generated Outputs:
------------------------------------------------------------
Prompt: 'Hello, my name is'
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
------------------------------------------------------------
Prompt: 'The president of the United States is'
Output: " not a racist. He is a racist.\nHe's a racist because he"
------------------------------------------------------------
Prompt: 'The capital of France is'
Output: ' also also also also also also also also also also also also also
also also also'
------------------------------------------------------------
Prompt: 'The future of AI is'
Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------
which indicates that the logits processor is running. However, on a non-"cuda"
device, the first and third requests would not repeat the same token.
"""
import torch
from vllm import LLM, SamplingParams
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.sample.logits_processor import (
AdapterLogitsProcessor,
RequestLogitsProcessor,
)
logger = init_logger(__name__)
class DummyPerReqLogitsProcessor:
"""The request-level logits processor masks out all logits except the
token id identified by `target_token`"""
def __init__(self, target_token: int) -> None:
"""Specify `target_token`"""
self.target_token = target_token
def __call__(
self,
output_ids: list[int],
logits: torch.Tensor,
) -> torch.Tensor:
val_to_keep = logits[self.target_token].item()
logits[:] = float("-inf")
logits[self.target_token] = val_to_keep
return logits
class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of overriding the wrapper class `__init__()` in order to utilize
info about the device type"""
@classmethod
def validate_params(cls, params: SamplingParams):
target_token = params.extra_args and params.extra_args.get("target_token")
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"`target_token` has to be an integer, got {target_token}."
)
def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
super().__init__(vllm_config, device, is_pin_memory)
self.is_cuda = device.type == "cuda"
def is_argmax_invariant(self) -> bool:
return False
def new_req_logits_processor(
self,
params: SamplingParams,
) -> RequestLogitsProcessor | None:
"""This method returns a new request-level logits processor, customized
to the `target_token` value associated with a particular request.
Returns None if the logits processor should not be applied to the
particular request. To use the logits processor the request must have
a "target_token" custom argument with an integer value, and the device
must be "cuda"-type
Args:
params: per-request sampling params
Returns:
`Callable` request logits processor, or None
"""
if (
not self.is_cuda
or (
target_token := params.extra_args
and params.extra_args.get("target_token")
)
is None
):
return None
return DummyPerReqLogitsProcessor(target_token)
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a mixture of requests which do and don't utilize the dummy logitproc
sampling_params_list = [
SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
SamplingParams(temperature=0.0),
SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
SamplingParams(temperature=0.0),
]
def main():
# Create an LLM.
llm = LLM(
model="facebook/opt-125m",
logits_processors=[WrappedPerReqLogitsProcessor],
)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params_list)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Output: {generated_text!r}")
print("-" * 60)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use LoRA with different quantization techniques
for offline inference.
Requires HuggingFace credentials for access.
"""
import gc
import torch
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
def create_test_prompts(
lora_path: str,
) -> list[tuple[str, SamplingParams, LoRARequest | None]]:
return [
# this is an example of using quantization without LoRA
(
"My name is",
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
None,
),
# the next three examples use quantization with LoRA
(
"my name is",
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
LoRARequest("lora-test-1", 1, lora_path),
),
(
"The capital of USA is",
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
LoRARequest("lora-test-2", 1, lora_path),
),
(
"The capital of France is",
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
LoRARequest("lora-test-3", 1, lora_path),
),
]
def process_requests(
engine: LLMEngine,
test_prompts: list[tuple[str, SamplingParams, LoRARequest | None]],
):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params, lora_request = test_prompts.pop(0)
engine.add_request(
str(request_id), prompt, sampling_params, lora_request=lora_request
)
request_id += 1
request_outputs: list[RequestOutput] = engine.step()
for request_output in request_outputs:
if request_output.finished:
print("----------------------------------------------------")
print(f"Prompt: {request_output.prompt}")
print(f"Output: {request_output.outputs[0].text}")
def initialize_engine(
model: str, quantization: str, lora_repo: str | None
) -> LLMEngine:
"""Initialize the LLMEngine."""
engine_args = EngineArgs(
model=model,
quantization=quantization,
enable_lora=True,
max_lora_rank=64,
max_loras=4,
)
return LLMEngine.from_engine_args(engine_args)
def main():
"""Main function that sets up and runs the prompt processing."""
test_configs = [
# QLoRA (https://arxiv.org/abs/2305.14314)
{
"name": "qlora_inference_example",
"model": "huggyllama/llama-7b",
"quantization": "bitsandbytes",
"lora_repo": "timdettmers/qlora-flan-7b",
},
{
"name": "AWQ_inference_with_lora_example",
"model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
"quantization": "awq",
"lora_repo": "jashing/tinyllama-colorist-lora",
},
{
"name": "GPTQ_inference_with_lora_example",
"model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
"quantization": "gptq",
"lora_repo": "jashing/tinyllama-colorist-lora",
},
]
for test_config in test_configs:
print(f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~")
engine = initialize_engine(
test_config["model"], test_config["quantization"], test_config["lora_repo"]
)
lora_path = snapshot_download(repo_id=test_config["lora_repo"])
test_prompts = create_test_prompts(lora_path)
process_requests(engine, test_prompts)
# Clean up the GPU memory for the next test
del engine
gc.collect()
torch.accelerator.empty_cache()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, SamplingParams
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Vector
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
def main():
# Create an LLM.
llm = LLM(model="facebook/opt-125m", disable_log_stats=False)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Dump all metrics
for metric in llm.get_metrics():
if isinstance(metric, Gauge):
print(f"{metric.name} (gauge) = {metric.value}")
elif isinstance(metric, Counter):
print(f"{metric.name} (counter) = {metric.value}")
elif isinstance(metric, Vector):
print(f"{metric.name} (vector) = {metric.values}")
elif isinstance(metric, Histogram):
print(f"{metric.name} (histogram)")
print(f" sum = {metric.sum}")
print(f" count = {metric.count}")
for bucket_le, value in metric.buckets.items():
print(f" {bucket_le} = {value}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
import argparse
from vllm import LLM
from vllm.sampling_params import SamplingParams
from vllm.assets.image import ImageAsset
from vllm.multimodal.utils import encode_image_url
# This script is an offline demo for running Mistral-Small-3.1
#
# If you want to run a server/client setup, please follow this code:
#
# - Server:
#
# ```bash
# # Mistral format
# vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
# --tokenizer-mode mistral --config-format mistral --load-format mistral \
# --limit-mm-per-prompt.image 4 --max-model-len 16384
#
# # HF format
# vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
# --limit-mm-per-prompt.image 4 --max-model-len 16384
# ```
#
# - Client:
#
# ```bash
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
# --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \
# --data '{
# "model": "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
# "messages": [
# {
# "role": "user",
# "content": [
# {"type" : "text", "text": "Describe this image in detail please."},
# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
# {"type" : "text", "text": "and this one as well. Answer in French."},
# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
# ]
# }
# ]
# }'
# ```
#
# Usage:
# python demo.py simple
# python demo.py advanced
# Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
# These scripts have been tested on 2x L40 GPUs
def run_simple_demo(args: argparse.Namespace):
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
sampling_params = SamplingParams(max_tokens=8192)
llm = LLM(
model=model_name,
tokenizer_mode="mistral" if args.format == "mistral" else "hf",
config_format="mistral" if args.format == "mistral" else "hf",
load_format="mistral" if args.format == "mistral" else "hf",
limit_mm_per_prompt={"image": 1},
max_model_len=4096,
max_num_seqs=2,
tensor_parallel_size=2,
mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4,
)
prompt = "Describe this image in one sentence."
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {
"url": encode_image_url(ImageAsset("cherry_blossom").pil_image)
},
},
],
},
]
outputs = llm.chat(messages, sampling_params=sampling_params)
print("-" * 50)
print(outputs[0].outputs[0].text)
print("-" * 50)
def run_advanced_demo(args: argparse.Namespace):
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
max_img_per_msg = 3
max_tokens_per_img = 4096
sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)
llm = LLM(
model=model_name,
tokenizer_mode="mistral" if args.format == "mistral" else "hf",
config_format="mistral" if args.format == "mistral" else "hf",
load_format="mistral" if args.format == "mistral" else "hf",
limit_mm_per_prompt={"image": max_img_per_msg},
max_model_len=max_img_per_msg * max_tokens_per_img,
tensor_parallel_size=2,
mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4,
)
prompt = "Describe the following image."
url_1 = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png"
url_2 = "https://picsum.photos/seed/picsum/200/300"
url_3 = "https://picsum.photos/id/32/512/512"
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": url_1}},
{"type": "image_url", "image_url": {"url": url_2}},
],
},
{
"role": "assistant",
"content": "The images show nature.",
},
{
"role": "user",
"content": "More details please and answer only in French!.",
},
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": url_3}},
],
},
]
outputs = llm.chat(messages=messages, sampling_params=sampling_params)
print("-" * 50)
print(outputs[0].outputs[0].text)
print("-" * 50)
def parse_args():
parser = argparse.ArgumentParser(
description="Run a demo in simple or advanced mode."
)
parser.add_argument(
"mode",
choices=["simple", "advanced"],
help="Specify the demo mode: 'simple' or 'advanced'",
)
parser.add_argument(
"--format",
choices=["mistral", "hf"],
default="mistral",
help="Specify the format of the model to load.",
)
parser.add_argument(
"--disable-mm-processor-cache",
action="store_true",
help="If True, disables caching of multi-modal processor.",
)
return parser.parse_args()
def main():
args = parse_args()
if args.mode == "simple":
print("Running simple demo...")
run_simple_demo(args)
elif args.mode == "advanced":
print("Running advanced demo...")
run_advanced_demo(args)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file demonstrates the usage of text generation with an LLM model,
comparing the performance with and without speculative decoding.
Note that this example is out of date and not supported in vLLM v1.
"""
import gc
import time
from vllm import LLM, SamplingParams
def time_generation(
llm: LLM, prompts: list[str], sampling_params: SamplingParams, title: str
):
# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
# Warmup first
llm.generate(prompts, sampling_params)
llm.generate(prompts, sampling_params)
start = time.time()
outputs = llm.generate(prompts, sampling_params)
end = time.time()
print("-" * 50)
print(title)
print("time: ", (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
# Print the outputs.
for output in outputs:
generated_text = output.outputs[0].text
print(f"text: {generated_text!r}")
print("-" * 50)
def main():
template = (
"Below is an instruction that describes a task. Write a response "
"that appropriately completes the request.\n\n### Instruction:\n{}"
"\n\n### Response:\n"
)
# Sample prompts.
prompts = [
"Write about the president of the United States.",
]
prompts = [template.format(prompt) for prompt in prompts]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, max_tokens=200)
# Create an LLM without spec decoding
llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")
time_generation(llm, prompts, sampling_params, "Without speculation")
del llm
gc.collect()
# Create an LLM with spec decoding
llm = LLM(
model="meta-llama/Llama-2-13b-chat-hf",
speculative_config={
"model": "ibm-ai-platform/llama-13b-accelerator",
},
)
time_generation(llm, prompts, sampling_params, "With speculation")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use the multi-LoRA functionality
for offline inference.
Requires HuggingFace credentials for access to Llama2.
"""
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
def create_test_prompts(
lora_path: str,
) -> list[tuple[str, SamplingParams, LoRARequest | None]]:
"""Create a list of test prompts with their sampling parameters.
2 requests for base model, 4 requests for the LoRA. We define 2
different LoRA adapters (using the same model for demo purposes).
Since we also set `max_loras=1`, the expectation is that the requests
with the second LoRA adapter will be run after all requests with the
first adapter have finished.
"""
return [
(
"A robot may not injure a human being",
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
None,
),
(
"To be or not to be,",
SamplingParams(
temperature=0.8, top_k=5, presence_penalty=0.2, max_tokens=128
),
None,
),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
LoRARequest("sql-lora", 1, lora_path),
),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
LoRARequest("sql-lora2", 2, lora_path),
),
]
def process_requests(
engine: LLMEngine,
test_prompts: list[tuple[str, SamplingParams, LoRARequest | None]],
):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
print("-" * 50)
while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params, lora_request = test_prompts.pop(0)
engine.add_request(
str(request_id), prompt, sampling_params, lora_request=lora_request
)
request_id += 1
request_outputs: list[RequestOutput] = engine.step()
for request_output in request_outputs:
if request_output.finished:
print(request_output)
print("-" * 50)
def initialize_engine() -> LLMEngine:
"""Initialize the LLMEngine."""
# max_loras: controls the number of LoRAs that can be used in the same
# batch. Larger numbers will cause higher memory usage, as each LoRA
# slot requires its own preallocated tensor.
# max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
# numbers will cause higher memory usage. If you know that all LoRAs will
# use the same rank, it is recommended to set this as low as possible.
# max_cpu_loras: controls the size of the CPU LoRA cache.
engine_args = EngineArgs(
model="meta-llama/Llama-3.2-3B-Instruct",
enable_lora=True,
max_loras=1,
max_lora_rank=8,
max_cpu_loras=2,
max_num_seqs=256,
)
return LLMEngine.from_engine_args(engine_args)
def main():
"""Main function that sets up and runs the prompt processing."""
engine = initialize_engine()
lora_path = snapshot_download(repo_id="jeeejeee/llama32-3b-text2sql-spider")
test_prompts = create_test_prompts(lora_path)
process_requests(engine, test_prompts)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,415 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates async reinforcement learning using vLLM and Ray,
with native weight syncing APIs at engine instance.
The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies one GPU for training, whereas a
2x tensor-parallel vLLM inference engine occupies two GPUs.
The example performs the following steps:
* Load the training model on one gpu (scheduled via ray)
* Initialize the inference model with dummy weights across
two gpus using vLLM's tensor parallelism and Ray placement groups.
* Generate gibberish from a list of prompts using the randomly initialized
inference engine.
* Pause generation once generation completes for one sequence
* Update the weights of the training model and broadcast the updated weights
to the inference engine by using a Ray collective RPC group.
* Resume generation and print out the results
This example assumes a single-node cluster with three GPUs, but Ray
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""
import asyncio
import uuid
from dataclasses import asdict
import ray
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import vllm
from vllm import SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLTrainerSendWeightsArgs,
NCCLWeightTransferEngine,
NCCLWeightTransferInitInfo,
NCCLWeightTransferUpdateInfo,
)
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_ip, get_open_port
from vllm.v1.executor import Executor
MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
PAUSE_TOKEN_THRESHOLD = 10
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "FLASH_ATTN"
class MyLLM(vllm.AsyncLLMEngine):
"""Configure the vLLM worker for Ray placement group execution."""
def __init__(self, **kwargs):
engine_args = vllm.AsyncEngineArgs(**kwargs)
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)
super().__init__(
vllm_config=vllm_config,
executor_class=executor_class,
log_requests=engine_args.enable_log_requests,
log_stats=not engine_args.disable_log_stats,
)
self._generation_paused = False
self._request_pause_flag = False
async def do_generate(
self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams
) -> tuple[vllm.RequestOutput, int]:
"""Generate a single request, setting the request pause flag once the
token count reaches the threshold.
Returns (output, pause_token_index). pause_token_index is the number
of tokens generated before the weight change, or -1 if no pause.
"""
pause_token_index = -1
prev_token_count = 0
async for request_output in self.generate(
{"prompt_token_ids": prompt_token_ids},
sampling_params,
request_id=str(uuid.uuid4()),
):
output = request_output
cur_token_count = len(output.outputs[0].token_ids)
if (
cur_token_count >= PAUSE_TOKEN_THRESHOLD
and not self._request_pause_flag
):
self._request_pause_flag = True
if self._generation_paused and pause_token_index == -1:
pause_token_index = prev_token_count
prev_token_count = cur_token_count
return output, pause_token_index
async def pause_after_n_tokens(self):
"""Wait for any request to set the pause flag, then pause."""
while not self._request_pause_flag:
await asyncio.sleep(0)
await super().pause_generation(mode="keep")
await asyncio.sleep(5)
self._generation_paused = True
@ray.remote(num_gpus=1)
class TrainModel:
"""Ray actor that wraps the training model on a dedicated GPU."""
def __init__(self, model_name: str):
from vllm.model_executor.layers.batch_invariant import (
init_batch_invariance,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
# need to init all env vars for batch invariance which affect nccl ops
attn_backend = (
AttentionBackendEnum.TRITON_ATTN
if current_platform.is_rocm()
else AttentionBackendEnum.FLASH_ATTN
)
init_batch_invariance(attn_backend)
self.model = AutoModelForCausalLM.from_pretrained(
model_name, dtype=torch.bfloat16
).to("cuda:0")
self.port = get_open_port()
self.master_address = get_ip()
def get_master_address_and_port(self):
return self.master_address, self.port
def get_weight_metadata(self):
"""Return weight names, dtypes, and shapes for weight transfer."""
names = []
dtype_names = []
shapes = []
for name, p in self.model.named_parameters():
names.append(name)
dtype_names.append(str(p.dtype).split(".")[-1])
shapes.append(list(p.shape))
return names, dtype_names, shapes
def init_weight_transfer_group(self, world_size):
"""Initialize the NCCL process group for weight transfer."""
self.model_update_group = NCCLWeightTransferEngine.trainer_init(
dict(
master_address=self.master_address,
master_port=self.port,
world_size=world_size,
),
)
def broadcast_weights(self, packed: bool = True):
"""Broadcast weights to the inference engine."""
trainer_args = NCCLTrainerSendWeightsArgs(
group=self.model_update_group,
packed=packed,
)
NCCLWeightTransferEngine.trainer_send_weights(
iterator=self.model.named_parameters(),
trainer_args=trainer_args,
)
@torch.inference_mode()
def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]:
"""Greedy-decode max_new_tokens from the given context."""
input_ids = torch.tensor([token_ids], device="cuda:0")
output = self.model.generate(
input_ids,
max_new_tokens=max_new_tokens,
do_sample=False,
)
new_token_ids = output[0, len(token_ids) :].tolist()
return new_token_ids
# Build platform-specific env vars for Ray
ray_env_vars = {
# Prevent Ray from setting CUDA_VISIBLE_DEVICES
"RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
}
if current_platform.is_rocm():
# For ROCm, BATCH_INVARIANT vllm is not supported
ray_env_vars["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"
else:
# Enable batch invariance for deterministic outputs on NVIDIA
ray_env_vars["VLLM_BATCH_INVARIANT"] = "1"
ray.init(runtime_env={"env_vars": ray_env_vars})
# Launch the training model actor. Ray's resource scheduler will allocate
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
train_model = TrainModel.remote(MODEL_NAME_V2)
rocm_determinism_kwargs = {}
if current_platform.is_rocm():
# ROCm: To minimize non-determinism, we set fixed seed, no prefix caching, and
# sequential request processing (max_num_seqs=1).
rocm_determinism_kwargs = {
"seed": 0,
"enable_prefix_caching": False,
"max_num_seqs": 1,
}
# Build platform-specific LLM kwargs
llm_kwargs = dict(
model=MODEL_NAME_V1,
enforce_eager=True,
max_model_len=8192,
distributed_executor_backend="ray",
attention_backend=ATTN_BACKEND,
gpu_memory_utilization=0.75,
weight_transfer_config=WeightTransferConfig(backend="nccl"),
)
llm_kwargs.update(rocm_determinism_kwargs)
# Launch the vLLM inference engine.
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
# its own placement groups internally for each DP rank, so we must NOT
# create an outer placement group (it would reserve GPUs and hide them
# from the internal DP resource check).
llm = ray.remote(
num_cpus=0,
num_gpus=0,
)(MyLLM).remote(**llm_kwargs)
PROMPTS = [
"The president of the United States is",
"The capital of France is",
"The largest ocean on Earth is",
"The speed of light in a vacuum is",
"The chemical formula for water is",
"The tallest mountain in the world is",
"The first person to walk on the moon was",
"The Great Wall of China was built to",
"Photosynthesis is the process by which",
"The theory of general relativity was proposed by",
"The boiling point of water at sea level is",
"The largest planet in our solar system is",
"DNA stands for deoxyribonucleic acid and it",
]
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_V1)
batch_prompt_token_ids = [
tokenizer.encode(prompt, add_special_tokens=False) for prompt in PROMPTS
]
# Set up the communication channel between the training process and the
# inference engine.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())
world_size = 2 # 1 trainer + 1 inference worker
inference_handle = llm.init_weight_transfer_engine.remote(
WeightTransferInitRequest(
init_info=asdict(
NCCLWeightTransferInitInfo(
master_address=master_address,
master_port=master_port,
rank_offset=1,
world_size=world_size,
)
)
)
)
# Initialize weight transfer group on both the training actor and inference engine
train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])
N_NEW_TOKENS = 100
# Collect weight metadata once
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())
# ── Phase 1: concurrent requests with weight sync ───────────────────
print(f"\n{'=' * 50}")
print(f"Prompts ({len(PROMPTS)}):")
for p in PROMPTS:
print(f" - {p!r}")
print(f"{'=' * 50}")
sampling_params = SamplingParams(
temperature=0, max_tokens=PAUSE_TOKEN_THRESHOLD + N_NEW_TOKENS
)
gen_futures = [
llm.do_generate.remote(ptids, sampling_params) for ptids in batch_prompt_token_ids
]
ray.get(llm.pause_after_n_tokens.remote())
inference_handle = llm.update_weights.remote(
WeightTransferUpdateRequest(
update_info=asdict(
NCCLWeightTransferUpdateInfo(
names=names,
dtype_names=dtype_names,
shapes=shapes,
packed=True,
)
)
)
)
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])
ray.get(llm.resume_generation.remote())
results = ray.get(gen_futures)
for i, (output, pause_idx) in enumerate(results):
all_token_ids = list(output.outputs[0].token_ids)
before_text = tokenizer.decode(all_token_ids[:pause_idx])
after_text = tokenizer.decode(all_token_ids[pause_idx:])
print(f"\n Request {i} ({PROMPTS[i]!r}):")
print(f" Old weights ({pause_idx} tokens): {before_text!r}")
n_after = len(all_token_ids) - pause_idx
print(f" New weights ({n_after} tokens): {after_text!r}")
# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
# This validation relies on batch-invariant (deterministic) generation to
# compare outputs from the weight-synced engine against a fresh V2 instance.
# On NVIDIA, batch invariance is fully supported, so we require 100% exact
# token match. On ROCm, batch invariance is not yet fully implemented
# (see https://github.com/vllm-project/vllm/issues/27433 and
# https://github.com/vllm-project/vllm/issues/33123), so residual
# non-determinism (e.g. GEMM accumulation order, missing kernel overrides)
# can cause single-token divergences that don't indicate a weight-sync
# failure. We relax the pass rate to 90% on ROCm to accommodate this; a
# real regression (broken weight transfer) would cause ~0% pass rate, not 90%+.
MIN_PASS_RATE = 1.0 if not current_platform.is_rocm() else 0.9
print(f"\n{'=' * 50}")
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
if current_platform.is_rocm():
print(f" (ROCm mode: requiring >= {MIN_PASS_RATE:.0%} exact match rate)")
print(f"{'=' * 50}")
ray.get(llm.shutdown.remote())
ray.kill(llm)
ray.kill(train_model)
llm_v2_kwargs = dict(
model=MODEL_NAME_V2,
enforce_eager=True,
max_model_len=8192,
gpu_memory_utilization=0.75,
distributed_executor_backend="ray",
attention_backend=ATTN_BACKEND,
)
llm_v2_kwargs.update(rocm_determinism_kwargs)
llm_v2 = ray.remote(
num_cpus=0,
num_gpus=0,
)(MyLLM).remote(**llm_v2_kwargs)
val_futures = [
llm_v2.do_generate.remote(
list(output.prompt_token_ids) + list(output.outputs[0].token_ids)[:pause_idx],
SamplingParams(
temperature=0, max_tokens=len(output.outputs[0].token_ids) - pause_idx
),
)
for output, pause_idx in results
]
val_results = ray.get(val_futures)
num_pass = 0
num_total = len(results)
for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)):
expected = list(output.outputs[0].token_ids)[pause_idx:]
actual = list(val_output.outputs[0].token_ids)
match = actual == expected
if match:
num_pass += 1
print(f" [PASS] {PROMPTS[i]!r}")
else:
print(f" [FAIL] {PROMPTS[i]!r}")
print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}")
print(f" V2 vLLM: {tokenizer.decode(actual)!r}")
for j, (e, a) in enumerate(zip(expected, actual)):
if e != a:
print(
f" first divergence at output token {j}: "
f"expected {e} ({tokenizer.decode([e])!r}) vs "
f"actual {a} ({tokenizer.decode([a])!r})"
)
break
ray.get(llm_v2.shutdown.remote())
ray.kill(llm_v2)
pass_rate = num_pass / num_total
print(f"\n Result: {num_pass}/{num_total} prompts passed ({pass_rate:.0%})")
print(f" Required: >= {MIN_PASS_RATE:.0%}")
assert pass_rate >= MIN_PASS_RATE, (
f"Validation pass rate {pass_rate:.0%} ({num_pass}/{num_total}) "
f"is below the required {MIN_PASS_RATE:.0%} threshold. "
f"See failures above for details."
)
print("=" * 50)

View File

@@ -0,0 +1,149 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray,
with IPC-based weight syncing APIs
The script colocates the training and inference workloads onto the same GPU using Ray.
The example performs the following steps:
* Request a placement group of 1 GPU.
* Place the inference model on the above GPU using the placement group.
* Place and load the training model on the same GPU using the placement group.
* Generate text from a list of prompts using the inference engine.
* Update the weights of the training model and broadcast the updated weights
to the inference engine by using CUDA IPC handles. Note that
for demonstration purposes we simply zero out the weights.
This example assumes a single-node cluster with a single GPU,
but can be extended to multiple GPUs.
"""
import os
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.ipc_engine import (
IPCTrainerSendWeightsArgs,
IPCWeightTransferEngine,
)
class MyLLM(LLM):
"""Configure the vLLM worker for Ray placement group execution."""
def __init__(self, *args, **kwargs):
# Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
# so that vLLM can manage its own device placement within the worker.
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
# Each worker uses 0.4 GPU so that two instances fit on the same GPU.
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0"
# needed for ipc handle serialization
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
super().__init__(*args, **kwargs)
# Load the OPT-125M model onto GPU 0 for the training workload.
MODEL_NAME = "facebook/opt-125m"
@ray.remote
class TrainModel:
def __init__(self, llm_handle: ray.actor.ActorHandle):
self.train_model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
)
self.train_model.to("cuda:0")
self.llm_handle = llm_handle
def init_weight_transfer(self):
# IPC backend doesn't need initialization info
ray.get(
self.llm_handle.init_weight_transfer_engine.remote(dict(init_info=dict()))
)
def broadcast_weights(self, llm_handle: ray.actor.ActorHandle):
"""Broadcast weights to the inference engine using IPC."""
self.llm_handle = llm_handle
trainer_args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
IPCWeightTransferEngine.trainer_send_weights(
iterator=self.train_model.named_parameters(),
trainer_args=trainer_args,
)
ray.init()
pg_colocate = placement_group([{"GPU": 1, "CPU": 0}])
ray.get(pg_colocate.ready())
llm = ray.remote(
num_cpus=0,
num_gpus=0,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg_colocate,
placement_group_capture_child_tasks=True,
),
)(MyLLM).remote(
model=MODEL_NAME,
enforce_eager=True,
tensor_parallel_size=1,
distributed_executor_backend="ray",
gpu_memory_utilization=0.7,
weight_transfer_config=WeightTransferConfig(backend="ipc"),
load_format="dummy",
)
train_model = TrainModel.options(
num_gpus=0.1,
num_cpus=0,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg_colocate, placement_group_capture_child_tasks=True
),
).remote(llm)
# Generate text from the prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
ray.get(llm.sleep.remote(level=0))
ray.get(train_model.init_weight_transfer.remote())
# Synchronize the updated weights to the inference engine using batched API.
ray.get(train_model.broadcast_weights.remote(llm))
ray.get(llm.wake_up.remote(tags=["scheduling"]))
# Generate text with the updated model.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)

View File

@@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning using vLLM and Ray,
with native weight syncing APIs at engine instance.
The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies one GPU for training, whereas a
2x tensor-parallel vLLM inference engine occupies two GPUs.
The example performs the following steps:
* Load the training model on one gpu (scheduled via ray)
* Initialize the inference model with dummy weights across
two gpus using vLLM's tensor parallelism and Ray placement groups.
* Generate gibberish from a list of prompts using the randomly initialized
inference engine.
* Update the weights of the training model and broadcast the updated weights
to the inference engine by using a Ray collective RPC group.
* Generating from the list of prompts after weight sync should result
in sensible outputs.
This example assumes a single-node cluster with three GPUs, but Ray
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""
import os
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLTrainerSendWeightsArgs,
NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port
MODEL_NAME = "facebook/opt-125m"
# MODEL_NAME = "inference-optimization/Qwen3-0.6B-W4A16-G128"
class MyLLM(LLM):
"""Configure the vLLM worker for Ray placement group execution."""
def __init__(self, *args, **kwargs):
os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1"
super().__init__(*args, **kwargs)
@ray.remote(num_gpus=1)
class TrainModel:
"""Ray actor that wraps the training model on a dedicated GPU."""
def __init__(self, model_name: str):
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
).to("cuda:0")
self.port = get_open_port()
self.master_address = get_ip()
def get_master_address_and_port(self):
return self.master_address, self.port
def get_weight_metadata(self):
"""Return weight names, dtypes, and shapes for weight transfer."""
names = []
dtype_names = []
shapes = []
for name, p in self.model.named_parameters():
names.append(name)
dtype_names.append(str(p.dtype).split(".")[-1])
shapes.append(list(p.shape))
return names, dtype_names, shapes
def init_weight_transfer_group(self, world_size):
"""Initialize the NCCL process group for weight transfer."""
self.model_update_group = NCCLWeightTransferEngine.trainer_init(
dict(
master_address=self.master_address,
master_port=self.port,
world_size=world_size,
),
)
def broadcast_weights(self, packed: bool = True):
"""Broadcast weights to the inference engine."""
trainer_args = NCCLTrainerSendWeightsArgs(
group=self.model_update_group,
packed=packed,
)
NCCLWeightTransferEngine.trainer_send_weights(
iterator=self.model.named_parameters(),
trainer_args=trainer_args,
)
# Initialize Ray and set the visible devices. The vLLM engine will
# be placed on GPUs 1 and 2.
ray.init()
# Create a placement group that reserves GPU 12 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
# Launch the training model actor. Ray's resource scheduler will allocate
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
train_model = TrainModel.remote(MODEL_NAME)
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
placement_group=pg_inference,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=0,
)
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights)
# are now native to vLLM workers.
llm = ray.remote(
num_cpus=0,
num_gpus=0,
scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
model=MODEL_NAME,
enforce_eager=True,
tensor_parallel_size=2,
data_parallel_size=1,
distributed_executor_backend="ray",
weight_transfer_config=WeightTransferConfig(backend="nccl"),
load_format="dummy",
quantization="fp8",
)
# Generate text from the prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
# Generate text with the initial model. The output is expected to be nonsense
# because the weights are randomly initialized.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
ray.get(llm.sleep.remote(level=0))
# Set up the communication channel between the training process and the
# inference engine.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())
world_size = ray.get(llm.get_world_size.remote()) + 1 # +1 for the trainer
inference_handle = llm.init_weight_transfer_engine.remote(
dict(
init_info=dict(
master_address=master_address,
master_port=master_port,
rank_offset=1,
world_size=world_size,
)
)
)
# Initialize weight transfer group on both the training actor and inference engine
train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])
# Synchronize the updated weights to the inference engine using batched API.
# Collect all weight metadata from the training actor
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())
# Issue update_weights call with NCCL-specific update info
# packed=True enables efficient batched tensor broadcasting
inference_handle = llm.update_weights.remote(
dict(
update_info=dict(
names=names,
dtype_names=dtype_names,
shapes=shapes,
packed=True,
)
)
)
# Broadcast all weights from trainer using the weight transfer API
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])
ray.get(llm.wake_up.remote(tags=["scheduling"]))
# Generate text with the updated model. The output is expected to be normal
# because the weights are updated.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)

View File

@@ -0,0 +1,276 @@
# Offline Inference with the OpenAI Batch file format
```{important}
This is a guide to performing batch inference using the OpenAI batch file format, **not** the complete Batch (REST) API.
```
## File Format
The OpenAI batch file format consists of a series of json objects on new lines.
[See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl)
Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details.
```{note}
We currently support `/v1/chat/completions`, `/v1/embeddings`, and `/v1/score` endpoints (completions coming soon).
```
## Pre-requisites
* The examples in this document use `meta-llama/Meta-Llama-3-8B-Instruct`.
* Create a [user access token](https://huggingface.co/docs/hub/en/security-tokens)
* Install the token on your machine (Run `hf auth login`).
* Get access to the gated model by [visiting the model card](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and agreeing to the terms and conditions.
## Example 1: Running with a local file
### Step 1: Create your batch file
To follow along with this example, you can download the example batch, or create your own batch file in your working directory.
```bash
wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl
```
Once you've created your batch file it should look like this
```bash
cat offline_inference/openai_batch/openai_example_batch.jsonl
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}
```
### Step 2: Run the batch
The batch running tool is designed to be used from the command line.
You can run the batch with the following command, which will write its results to a file called `results.jsonl`
```bash
python -m vllm.entrypoints.openai.run_batch \
-i offline_inference/openai_batch/openai_example_batch.jsonl \
-o results.jsonl \
--model meta-llama/Meta-Llama-3-8B-Instruct
```
or use command-line:
```bash
vllm run-batch \
-i offline_inference/openai_batch/openai_example_batch.jsonl \
-o results.jsonl \
--model meta-llama/Meta-Llama-3-8B-Instruct
```
### Step 3: Check your results
You should now have your results at `results.jsonl`. You can check your results by running `cat results.jsonl`
```bash
cat results.jsonl
{"id":"vllm-383d1c59835645aeb2e07d004d62a826","custom_id":"request-1","response":{"id":"cmpl-61c020e54b964d5a98fa7527bfcdd378","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Hello! It's great to meet you! I'm here to help with any questions or tasks you may have. What's on your mind today?"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":25,"total_tokens":56,"completion_tokens":31}},"error":null}
{"id":"vllm-42e3d09b14b04568afa3f1797751a267","custom_id":"request-2","response":{"id":"cmpl-f44d049f6b3a42d4b2d7850bb1e31bcc","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"*silence*"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":27,"total_tokens":32,"completion_tokens":5}},"error":null}
```
## Example 2: Using remote files
The batch runner supports remote input and output urls that are accessible via http/https.
For example, to run against our example input file located at `https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl`, you can run
```bash
python -m vllm.entrypoints.openai.run_batch \
-i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl \
-o results.jsonl \
--model meta-llama/Meta-Llama-3-8B-Instruct
```
or use command-line:
```bash
vllm run-batch \
-i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl \
-o results.jsonl \
--model meta-llama/Meta-Llama-3-8B-Instruct
```
## Example 3: Integrating with AWS S3
To integrate with cloud blob storage, we recommend using presigned urls.
[Learn more about S3 presigned urls here]
### Additional prerequisites
* [Create an S3 bucket](https://docs.aws.amazon.com/AmazonS3/latest/userguide/creating-bucket.html).
* The `awscli` package (Run `pip install awscli`) to configure your credentials and interactively use s3.
* [Configure your credentials](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-quickstart.html).
* The `boto3` python package (Run `pip install boto3`) to generate presigned urls.
### Step 1: Upload your input script
To follow along with this example, you can download the example batch, or create your own batch file in your working directory.
```bash
wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl
```
Once you've created your batch file it should look like this
```bash
cat offline_inference/openai_batch/openai_example_batch.jsonl
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}
```
Now upload your batch file to your S3 bucket.
```bash
aws s3 cp offline_inference/openai_batch/openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl
```
### Step 2: Generate your presigned urls
Presigned urls can only be generated via the SDK. You can run the following python script to generate your presigned urls. Be sure to replace the `MY_BUCKET`, `MY_INPUT_FILE.jsonl`, and `MY_OUTPUT_FILE.jsonl` placeholders with your bucket and file names.
(The script is adapted from <https://github.com/awsdocs/aws-doc-sdk-examples/blob/main/python/example_code/s3/s3_basics/presigned_url.py>)
```python
import boto3
from botocore.exceptions import ClientError
def generate_presigned_url(s3_client, client_method, method_parameters, expires_in):
"""
Generate a presigned Amazon S3 URL that can be used to perform an action.
:param s3_client: A Boto3 Amazon S3 client.
:param client_method: The name of the client method that the URL performs.
:param method_parameters: The parameters of the specified client method.
:param expires_in: The number of seconds the presigned URL is valid for.
:return: The presigned URL.
"""
try:
url = s3_client.generate_presigned_url(
ClientMethod=client_method,
Params=method_parameters,
ExpiresIn=expires_in,
)
except ClientError:
raise
return url
s3_client = boto3.client("s3")
input_url = generate_presigned_url(
s3_client,
"get_object",
{"Bucket": "MY_BUCKET", "Key": "MY_INPUT_FILE.jsonl"},
expires_in=3600,
)
output_url = generate_presigned_url(
s3_client,
"put_object",
{"Bucket": "MY_BUCKET", "Key": "MY_OUTPUT_FILE.jsonl"},
expires_in=3600,
)
print(f"{input_url=}")
print(f"{output_url=}")
```
This script should output
```text
input_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091'
output_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091'
```
### Step 3: Run the batch runner using your presigned urls
You can now run the batch runner, using the urls generated in the previous section.
```bash
python -m vllm.entrypoints.openai.run_batch \
-i "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \
-o "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \
--model --model meta-llama/Meta-Llama-3-8B-Instruct
```
or use command-line:
```bash
vllm run-batch \
-i "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \
-o "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \
--model --model meta-llama/Meta-Llama-3-8B-Instruct
```
### Step 4: View your results
Your results are now on S3. You can view them in your terminal by running
```bash
aws s3 cp s3://MY_BUCKET/MY_OUTPUT_FILE.jsonl -
```
## Example 4: Using embeddings endpoint
### Additional prerequisites
* Ensure you are using `vllm >= 0.5.5`.
### Step 1: Create your batch file
Add embedding requests to your batch file. The following is an example:
```text
{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are an unhelpful assistant."}}
```
You can even mix chat completion and embedding requests in the batch file, as long as the model you are using supports both chat completion and embeddings (note that all requests must use the same model).
### Step 2: Run the batch
You can run the batch using the same command as in earlier examples.
### Step 3: Check your results
You can check your results by running `cat results.jsonl`
```bash
cat results.jsonl
{"id":"vllm-db0f71f7dec244e6bce530e0b4ef908b","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-3580bf4d4ae54d52b67eee266a6eab20","body":{"id":"embd-33ac2efa7996430184461f2e38529746","object":"list","created":444647,"model":"intfloat/e5-mistral-7b-instruct","data":[{"index":0,"object":"embedding","embedding":[0.016204833984375,0.0092010498046875,0.0018358230590820312,-0.0028228759765625,0.001422882080078125,-0.0031147003173828125,...]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0}}},"error":null}
...
```
## Example 5: Using score endpoint
### Additional prerequisites
* Ensure you are using `vllm >= 0.7.0`.
### Step 1: Create your batch file
Add score requests to your batch file. The following is an example:
```text
{"custom_id": "request-1", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "queries": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "queries": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
```
You can mix chat completion, embedding, and score requests in the batch file, as long as the model you are using supports them all (note that all requests must use the same model).
### Step 2: Run the batch
You can run the batch using the same command as in earlier examples.
### Step 3: Check your results
You can check your results by running `cat results.jsonl`
```bash
cat results.jsonl
{"id":"vllm-f87c5c4539184f618e555744a2965987","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-806ab64512e44071b37d3f7ccd291413","body":{"id":"score-4ee45236897b4d29907d49b01298cdb1","object":"list","created":1737847944,"model":"BAAI/bge-reranker-v2-m3","data":[{"index":0,"object":"score","score":0.0010900497436523438},{"index":1,"object":"score","score":1.0}],"usage":{"prompt_tokens":37,"total_tokens":37,"completion_tokens":0,"prompt_tokens_details":null}}},"error":null}
{"id":"vllm-41990c51a26d4fac8419077f12871099","custom_id":"request-2","response":{"status_code":200,"request_id":"vllm-batch-73ce66379026482699f81974e14e1e99","body":{"id":"score-13f2ffe6ba40460fbf9f7f00ad667d75","object":"list","created":1737847944,"model":"BAAI/bge-reranker-v2-m3","data":[{"index":0,"object":"score","score":0.001094818115234375},{"index":1,"object":"score","score":1.0}],"usage":{"prompt_tokens":37,"total_tokens":37,"completion_tokens":0,"prompt_tokens_details":null}}},"error":null}
```

View File

@@ -0,0 +1,2 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}

View File

@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test for pause/resume with keep mode.
This test uses concurrent tasks to verify the engine truly stops generating
during pause:
1. Generator task: continuously generates and logs time between tokens
2. Controller task: sends pause/resume commands
If the engine properly pauses, we should see a gap in token timestamps
matching the pause duration.
"""
import asyncio
import time
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
PAUSE_DURATION = 3.0 # seconds
async def main():
# Create engine with a small model
engine_args = AsyncEngineArgs(
model="facebook/opt-125m",
enforce_eager=True,
)
engine = AsyncLLM.from_engine_args(engine_args)
prompt = "Write a story about a dragon. Once upon a time"
sampling_params = SamplingParams(max_tokens=30, ignore_eos=True)
# Track token arrival times
token_times: list[tuple[int, float]] = [] # (token_count, timestamp)
pause_time: float = 0
resume_time: float = 0
pause_token_idx: int = 0 # Index in token_times when pause occurred
async def generator_task():
"""Generate tokens and record timestamps."""
async for output in engine.generate(
request_id="test-req",
prompt=prompt,
sampling_params=sampling_params,
):
token_count = len(output.outputs[0].token_ids)
token_times.append((token_count, time.monotonic()))
print(
f"Token {token_count} arrived:"
f"T={token_times[-1][1] - token_times[0][1]:.3f}s"
)
return output
async def controller_task():
"""Pause and resume the engine after some tokens generated."""
nonlocal pause_time, resume_time, pause_token_idx
# Wait for some tokens to be generated
while len(token_times) < 5:
await asyncio.sleep(0.01)
print(f"\nPausing engine (keep mode) at token {len(token_times)}")
pause_time = time.monotonic()
await engine.pause_generation(mode="keep")
pause_token_idx = len(token_times)
print(f"Paused! Sleeping for {PAUSE_DURATION}s...")
# Sleep while paused - no tokens should be generated during this time
await asyncio.sleep(PAUSE_DURATION)
print("Resuming engine...")
await engine.resume_generation()
resume_time = time.monotonic()
print("Resumed!\n")
# Run both tasks concurrently
gen_task = asyncio.create_task(generator_task())
ctrl_task = asyncio.create_task(controller_task())
final_output, _ = await asyncio.gather(gen_task, ctrl_task)
# Verify the pause actually stopped generation.
# The gap after the pause token should be approximately the sleep duration.
pause_gap = token_times[pause_token_idx][1] - token_times[pause_token_idx - 1][1]
print(
f"\nGap after pause (token {pause_token_idx - 1} -> {pause_token_idx}): "
f"{pause_gap:.3f}s"
)
if pause_gap >= PAUSE_DURATION * 0.9:
print(f"✓ Test passed! Engine paused for ~{pause_gap:.1f}s")
else:
print(
f"✗ Test failed! Expected ~{PAUSE_DURATION}s gap after pause, "
f"got {pause_gap:.3f}s"
)
raise AssertionError("Engine did not properly pause")
# Verify request completed
assert final_output.finished, "Request should have finished"
assert len(final_output.outputs[0].token_ids) == 30, "Should have all tokens"
engine.shutdown()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
# NOTE: This is just a running example. For benchmarking purpose,
# please see benchmarks/benchmark_prefix_caching.py
# Common prefix.
prefix = (
"You are an expert school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
"community, joyful discovery, and life-long learning. The candidate is "
"coming in for a first-round panel interview for a 8th grade Math "
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on these information, fulfill "
"the following paragraph: "
)
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
generating_prompts = [prefix + prompt for prompt in prompts]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0)
def main():
# Create an LLM without prefix caching as a baseline.
regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)
print("Results without `enable_prefix_caching`")
# ruff: noqa: E501
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = regular_llm.generate(generating_prompts, sampling_params)
regular_generated_texts = []
# Print the outputs.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
regular_generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Destroy the LLM object and free up the GPU memory.
del regular_llm
cleanup_dist_env_and_memory()
# Create an LLM with prefix caching enabled.
prefix_cached_llm = LLM(
model="facebook/opt-125m",
enable_prefix_caching=True,
gpu_memory_utilization=0.4,
)
# Warmup so that the shared prompt's KV cache is computed.
prefix_cached_llm.generate(generating_prompts[0], sampling_params)
# Generate with prefix caching.
outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
print("Results with `enable_prefix_caching`")
cached_generated_texts = []
# Print the outputs. You should see the same outputs as before.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
cached_generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Compare the results and display the speedup
generated_same = all(
[
regular_generated_texts[i] == cached_generated_texts[i]
for i in range(len(prompts))
]
)
print(f"Generated answers are the same: {generated_same}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,221 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use FlexKV with vLLM for prefix caching.
FlexKV is a distributed KV Store and multi-level cache management system for
ultra-large-scale LLM inference.
Requirements:
- Install FlexKV (https://github.com/taco-project/FlexKV):
1. git clone git@github.com:taco-project/FlexKV.git
2. cd FlexKV && bash build.sh
- Ensure FlexKV is compatible with your vLLM version.
Usage:
1. Run this script:
python examples/offline_inference/prefix_caching_flexkv.py \
--model /path/to/your/model
2. Arguments:
--model Path or name of the model (required)
--tp-size Tensor parallel size (default: 1)
--gpu-memory-util GPU memory utilization (default: 0.4)
3. The script will:
- Create a FlexKV configuration file.
- Set the FLEXKV_CONFIG_PATH environment variable.
- Run vLLM with FlexKVConnectorV1 enabled.
- Compare results between regular execution, vLLM's default prefix
caching, and FlexKV.
"""
import argparse
import json
import os
import time
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
# NOTE: This is just a running example. For benchmarking purpose,
# please see benchmarks/benchmark_prefix_caching.py
def parse_args():
parser = argparse.ArgumentParser(
description="Example of using FlexKV with vLLM for prefix caching."
)
parser.add_argument(
"--model",
type=str,
required=True,
help="Path or name of the model to use.",
)
parser.add_argument(
"--tp-size",
type=int,
default=1,
help="Tensor parallel size (default: 1).",
)
parser.add_argument(
"--gpu-memory-util",
type=float,
default=0.4,
help="GPU memory utilization fraction (default: 0.4).",
)
return parser.parse_args()
def main():
args = parse_args()
flexkv_config = {
"server_recv_port": f"ipc:///tmp/flexkv_test_{os.getpid()}",
"cache_config": {
"enable_cpu": True,
"num_cpu_blocks": 10240,
},
"num_log_interval_requests": 200,
}
flexkv_config_path = f"./flexkv_config_{os.getpid()}.json"
with open(flexkv_config_path, "w") as f:
json.dump(flexkv_config, f)
os.environ["FLEXKV_CONFIG_PATH"] = flexkv_config_path
try:
_run(args)
finally:
if os.path.exists(flexkv_config_path):
os.remove(flexkv_config_path)
def _run(args):
# Common prefix.
prefix = (
"You are an expert school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
"community, joyful discovery, and life-long learning. The candidate is "
"coming in for a first-round panel interview for a 8th grade Math "
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on these information, fulfill "
"the following paragraph: "
)
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
generating_prompts = [prefix + prompt for prompt in prompts]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0)
kv_transfer_config = {
"kv_connector": "FlexKVConnectorV1",
"kv_role": "kv_both",
}
# Create an LLM without prefix caching as a baseline.
regular_llm = LLM(
model=args.model,
enable_prefix_caching=False,
gpu_memory_utilization=args.gpu_memory_util,
tensor_parallel_size=args.tp_size,
)
print("Results without `enable_prefix_caching`")
# ruff: noqa: E501
# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
outputs = regular_llm.generate(generating_prompts, sampling_params)
regular_generated_texts = []
# Print the outputs.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
regular_generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Destroy the LLM object and free up the GPU memory.
del regular_llm
cleanup_dist_env_and_memory()
# Create an LLM with prefix caching enabled.
prefix_cached_llm = LLM(
model=args.model,
enable_prefix_caching=True,
gpu_memory_utilization=args.gpu_memory_util,
tensor_parallel_size=args.tp_size,
kv_transfer_config=kv_transfer_config,
)
# Warmup so that the shared prompt's KV cache is computed.
prefix_cached_llm.generate(generating_prompts[0], sampling_params)
# wait for offload kv task finished.
time.sleep(2)
# Generate with prefix caching.
outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
print("Results with `enable_prefix_caching`")
cached_generated_texts = []
# Print the outputs. You should see the same outputs as before.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
cached_generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Compare the results and display the speedup
generated_same = all(
regular_generated_texts[i] == cached_generated_texts[i]
for i in range(len(prompts))
)
print(f"Generated answers are the same: {generated_same}")
# wait for offload kv task finished.
time.sleep(2)
# reset prefix cache to use flexkv
prefix_cached_llm.reset_prefix_cache()
# Generate with prefix caching.
outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
print("Results with `flexkv`")
flexkv_generated_texts = []
# Print the outputs. You should see the same outputs as before.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
flexkv_generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Compare the results and display the speedup
generated_same = all(
regular_generated_texts[i] == flexkv_generated_texts[i]
for i in range(len(prompts))
)
print(f"Generated answers are the same: {generated_same}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates how to generate prompt embeddings using
Hugging Face Transformers and use them as input to vLLM
for both single and batch inference.
Model: meta-llama/Llama-3.2-1B-Instruct
Note: This model is gated on Hugging Face Hub.
You must request access to use it:
https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct
Requirements:
- vLLM
- transformers
Run:
python examples/offline_inference/prompt_embed_inference.py
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
from vllm import LLM
def init_tokenizer_and_llm(model_name: str):
tokenizer = AutoTokenizer.from_pretrained(model_name)
transformers_model = AutoModelForCausalLM.from_pretrained(model_name)
embedding_layer = transformers_model.get_input_embeddings()
llm = LLM(model=model_name, enable_prompt_embeds=True)
return tokenizer, embedding_layer, llm
def get_prompt_embeds(
chat: list[dict[str, str]],
tokenizer: PreTrainedTokenizer,
embedding_layer: torch.nn.Module,
):
token_ids = tokenizer.apply_chat_template(
chat, add_generation_prompt=True, return_tensors="pt", return_dict=True
).input_ids
prompt_embeds = embedding_layer(token_ids).squeeze(0)
return prompt_embeds
def single_prompt_inference(
llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
):
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)
outputs = llm.generate(
{
"prompt_embeds": prompt_embeds,
}
)
print("\n[Single Inference Output]")
print("-" * 30)
for o in outputs:
print(o.outputs[0].text)
print("-" * 30)
def batch_prompt_inference(
llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
):
chats = [
[{"role": "user", "content": "Please tell me about the capital of France."}],
[{"role": "user", "content": "When is the day longest during the year?"}],
[{"role": "user", "content": "Where is bigger, the moon or the sun?"}],
]
prompt_embeds_list = [
get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
]
outputs = llm.generate([{"prompt_embeds": embeds} for embeds in prompt_embeds_list])
print("\n[Batch Inference Outputs]")
print("-" * 30)
for i, o in enumerate(outputs):
print(f"Q{i + 1}: {chats[i][0]['content']}")
print(f"A{i + 1}: {o.outputs[0].text}\n")
print("-" * 30)
def main():
model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer, embedding_layer, llm = init_tokenizer_and_llm(model_name)
single_prompt_inference(llm, tokenizer, embedding_layer)
batch_prompt_inference(llm, tokenizer, embedding_layer)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,39 @@
# Qwen2.5-Omni Offline Inference Examples
This folder provides several example scripts on how to inference Qwen2.5-Omni offline.
## Thinker Only
```bash
# Audio + image + video
python examples/offline_inference/qwen2_5_omni/only_thinker.py \
-q mixed_modalities
# Read vision and audio inputs from a single video file
python examples/offline_inference/qwen2_5_omni/only_thinker.py \
-q use_audio_in_video
# Multiple audios
python examples/offline_inference/qwen2_5_omni/only_thinker.py \
-q multi_audios
```
This script will run the thinker part of Qwen2.5-Omni, and generate text response.
You can also test Qwen2.5-Omni on a single modality:
```bash
# Process audio inputs
python examples/offline_inference/audio_language.py \
--model-type qwen2_5_omni
# Process image inputs
python examples/offline_inference/vision_language.py \
--modality image \
--model-type qwen2_5_omni
# Process video inputs
python examples/offline_inference/vision_language.py \
--modality video \
--model-type qwen2_5_omni
```

View File

@@ -0,0 +1,196 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on Qwen2.5-Omni (thinker only).
"""
from typing import NamedTuple
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.multimodal.image import convert_image_mode
from vllm.utils.argparse_utils import FlexibleArgumentParser
class QueryResult(NamedTuple):
inputs: dict
limit_mm_per_prompt: dict[str, int]
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech."
)
def get_mixed_modalities_query() -> QueryResult:
question = (
"What is recited in the audio? "
"What is the content of this image? Why is this video funny?"
)
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
"<|vision_bos|><|IMAGE|><|vision_eos|>"
"<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
"image": convert_image_mode(
ImageAsset("cherry_blossom").pil_image, "RGB"
),
"video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
},
},
limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
)
def get_use_audio_in_video_query() -> QueryResult:
question = (
"Describe the content of the video, then convert what the baby say into text."
)
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
asset = VideoAsset(name="baby_reading", num_frames=16)
audio = asset.get_audio(sampling_rate=16000)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"video": asset.np_ndarrays,
"audio": audio,
},
"mm_processor_kwargs": {
"use_audio_in_video": True,
},
},
limit_mm_per_prompt={"audio": 1, "video": 1},
)
def get_multi_audios_query() -> QueryResult:
question = "Are these two audio clips the same?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
"<|audio_bos|><|AUDIO|><|audio_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": [
AudioAsset("winning_call").audio_and_sample_rate,
AudioAsset("mary_had_lamb").audio_and_sample_rate,
],
},
},
limit_mm_per_prompt={
"audio": 2,
},
)
def get_multi_images_query() -> QueryResult:
question = "What are the differences between these two images?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_bos|><|IMAGE|><|vision_eos|>"
"<|vision_bos|><|IMAGE|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"image": [
convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB"),
convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB"),
],
},
},
limit_mm_per_prompt={
"image": 2,
},
)
query_map = {
"mixed_modalities": get_mixed_modalities_query,
"use_audio_in_video": get_use_audio_in_video_query,
"multi_audios": get_multi_audios_query,
"multi_images": get_multi_images_query,
}
def main(args):
model_name = "Qwen/Qwen2.5-Omni-7B"
query_result = query_map[args.query_type]()
llm = LLM(
model=model_name,
max_model_len=5632,
max_num_seqs=5,
limit_mm_per_prompt=query_result.limit_mm_per_prompt,
seed=args.seed,
)
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2, max_tokens=64)
outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
def parse_args():
parser = FlexibleArgumentParser(
description="Demo on using vLLM for offline inference with "
"audio language models"
)
parser.add_argument(
"--query-type",
"-q",
type=str,
default="mixed_modalities",
choices=query_map.keys(),
help="Query type.",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Set the seed when initializing `vllm.LLM`.",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,223 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on Qwen3-Omni (thinker only).
"""
from typing import NamedTuple
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.multimodal.image import convert_image_mode
from vllm.utils.argparse_utils import FlexibleArgumentParser
class QueryResult(NamedTuple):
inputs: dict
limit_mm_per_prompt: dict[str, int]
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech."
)
def get_mixed_modalities_query() -> QueryResult:
question = (
"What is recited in the audio? "
"What is the content of this image? Why is this video funny?"
)
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
"<|vision_start|><|image_pad|><|vision_end|>"
"<|vision_start|><|video_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
"image": convert_image_mode(
ImageAsset("cherry_blossom").pil_image, "RGB"
),
"video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
},
},
limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
)
def get_use_audio_in_video_query() -> QueryResult:
question = (
"Describe the content of the video in details, then convert what the "
"baby say into text."
)
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
asset = VideoAsset(name="baby_reading", num_frames=16)
audio = asset.get_audio(sampling_rate=16000)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"video": asset.np_ndarrays,
"audio": audio,
},
"mm_processor_kwargs": {
"use_audio_in_video": True,
},
},
limit_mm_per_prompt={"audio": 1, "video": 1},
)
def get_multi_audios_query() -> QueryResult:
question = "Are these two audio clips the same?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
"<|audio_start|><|audio_pad|><|audio_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": [
AudioAsset("winning_call").audio_and_sample_rate,
AudioAsset("mary_had_lamb").audio_and_sample_rate,
],
},
},
limit_mm_per_prompt={
"audio": 2,
},
)
def get_multi_images_query() -> QueryResult:
question = "What are the differences between these two images?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
"<|vision_start|><|image_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"image": [
convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB"),
convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB"),
],
},
},
limit_mm_per_prompt={
"image": 2,
},
)
query_map = {
"mixed_modalities": get_mixed_modalities_query,
"use_audio_in_video": get_use_audio_in_video_query,
"multi_audios": get_multi_audios_query,
"multi_images": get_multi_images_query,
}
def main(args):
model_name = args.model
query_result = query_map[args.query_type]()
llm = LLM(
model=model_name,
max_model_len=args.max_model_len,
max_num_seqs=5,
limit_mm_per_prompt=query_result.limit_mm_per_prompt,
seed=args.seed,
tensor_parallel_size=args.tensor_parallel_size,
gpu_memory_utilization=args.gpu_memory_utilization,
)
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2, max_tokens=256)
outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
def parse_args():
parser = FlexibleArgumentParser(
description="Demo on using vLLM for offline inference with "
"audio language models"
)
parser.add_argument(
"--query-type",
"-q",
type=str,
default="mixed_modalities",
choices=query_map.keys(),
help="Query type.",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Set the seed when initializing `vllm.LLM`.",
)
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen3-Omni-30B-A3B-Instruct",
help="Model name or path.",
)
parser.add_argument(
"--tensor-parallel-size",
"-tp",
type=int,
default=1,
help="Tensor parallel size for distributed inference.",
)
parser.add_argument(
"--gpu-memory-utilization",
type=float,
default=0.9,
help="GPU memory utilization (0.0 to 1.0).",
)
parser.add_argument(
"--max-model-len",
type=int,
default=12800,
help="Maximum model context length.",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from urllib.request import urlopen
from vllm import LLM, SamplingParams
os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
def load_prompt() -> str:
# Test cases with various lengths can be found at:
#
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
with urlopen(
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt",
timeout=5,
) as response:
prompt = response.read().decode("utf-8")
return prompt
# Processing the prompt.
def process_requests(llm: LLM, prompts: list[str]) -> None:
# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0.7,
top_p=0.8,
top_k=20,
repetition_penalty=1.05,
detokenize=True,
max_tokens=256,
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt_token_ids = output.prompt_token_ids
generated_text = output.outputs[0].text
print(
f"Prompt length: {len(prompt_token_ids)}, "
f"Generated text: {generated_text!r}"
)
# Create an LLM.
def initialize_engine() -> LLM:
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct-1M",
max_model_len=1048576,
tensor_parallel_size=4,
enforce_eager=True,
enable_chunked_prefill=True,
max_num_batched_tokens=131072,
)
return llm
def main():
llm = initialize_engine()
prompt = load_prompt()
process_requests(llm, [prompt])
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates how to achieve reproducibility in vLLM.
Main article: https://docs.vllm.ai/en/latest/usage/reproducibility.html
"""
import os
import random
from vllm import LLM, SamplingParams
# Either:
## Turn off multiprocessing to make the scheduling deterministic, or
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
## Enable batch invariance to get consistent results regardless of scheduling.
os.environ["VLLM_BATCH_INVARIANT"] = "1"
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
def main():
llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(prompts, sampling_params)
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Try generating random numbers outside vLLM
# The same number is output across runs, meaning that the random state
# in the user code has been updated by vLLM
print(random.randint(0, 100))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.
The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies GPU 0 for training, whereas a
tensor-parallel vLLM inference engine occupies GPU 12.
The example performs the following steps:
* Load the training model on GPU 0.
* Split the inference model across GPUs 12 using vLLM's tensor parallelism
and Ray placement groups.
* Generate text from a list of prompts using the inference engine.
* Update the weights of the training model and broadcast the updated weights
to the inference engine by using a Ray collective RPC group. Note that
for demonstration purposes we simply zero out the weights.
For a production-ready implementation that supports multiple training and
inference replicas, see the OpenRLHF framework:
https://github.com/OpenRLHF/OpenRLHF
This example assumes a single-node cluster with three GPUs, but Ray
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""
import os
import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from rlhf_utils import stateless_init_process_group
from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.utils.network_utils import get_ip, get_open_port
class MyLLM(LLM):
"""Configure the vLLM worker for Ray placement group execution."""
def __init__(self, *args, **kwargs):
# Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
# so that vLLM can manage its own device placement within the worker.
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
super().__init__(*args, **kwargs)
# Load the OPT-125M model onto GPU 0 for the training workload.
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")
# Initialize Ray and set the visible devices. The vLLM engine will
# be placed on GPUs 1 and 2.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()
# Create a placement group that reserves GPU 12 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
placement_group=pg_inference,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=0,
)
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
llm = ray.remote(
num_cpus=0,
num_gpus=0,
scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
model="facebook/opt-125m",
enforce_eager=True,
worker_extension_cls="rlhf_utils.WorkerExtension",
tensor_parallel_size=2,
distributed_executor_backend="ray",
)
# Generate text from the prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Set up the communication channel between the training process and the
# inference engine.
master_address = get_ip()
master_port = get_open_port()
handle = llm.collective_rpc.remote(
"init_weight_update_group", args=(master_address, master_port, 1, 3)
)
model_update_group = stateless_init_process_group(
master_address, master_port, 0, 3, torch.device("cuda:0")
)
ray.get(handle)
# Simulate a training step by zeroing out all model weights.
# In a real RLHF training loop the weights would be updated using the gradient
# from an RL objective such as PPO on a reward model.
for name, p in train_model.named_parameters():
p.data.zero_()
# Synchronize the updated weights to the inference engine.
for name, p in train_model.named_parameters():
dtype_name = str(p.dtype).split(".")[-1]
handle = llm.collective_rpc.remote(
"update_weight", args=(name, dtype_name, p.shape)
)
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
ray.get(handle)
# Verify that the inference weights have been updated.
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
# Generate text with the updated model. The output is expected to be nonsense
# because the weights are zero.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)

View File

@@ -0,0 +1,256 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates how to co-locate a vLLM inference worker and training
actors on the same set of GPUs for reinforcement learning from human feedback
(RLHF) workloads.
Ray serves as the distributed execution framework in this example. Ray
placement groups allocate both training actors and vLLM workers to the
same GPU bundles, enabling fast, in-GPU communication between the two
components.
The script shows how to do the following:
* Configure environment variables (`VLLM_RAY_PER_WORKER_GPUS` and
`VLLM_RAY_BUNDLE_INDICES`) so that vLLM workers land on the desired
devices.
* Exchange tensors between processes by means of CUDA inter-process
communication (IPC). CUDA IPC sidesteps NCCL limitations that occur
when multiple processes share a single GPU.
Note that this example assumes a single-node cluster with four GPUs, but Ray
supports multi-node clusters. vLLM expects exclusive use of the GPUs during
its initialization for memory profiling. Residual GPU activity interferes
with vLLM memory profiling and causes unexpected behavior.
Learn more about Ray placement groups:
https://docs.ray.io/en/latest/placement-groups.html
"""
import gc
import os
import sys
import ray
import torch
import zmq
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.multiprocessing.reductions import reduce_tensor
from vllm import LLM
if torch.version.hip is not None:
print("Skipping test for ROCm. Ray is unsupported on vLLM ROCm.")
sys.exit(0)
class MyLLM(LLM):
"""Configure the vLLM worker for Ray placement group execution.
The constructor sets environment variables that allow multiple vLLM
workers to share a single physical GPU and that encode the bundle
indices assigned by the placement group.
Args:
*args: Positional arguments forwarded to `vllm.LLM`.
bundle_indices (list[int]): Placement-group bundle indices
assigned to this worker.
**kwargs: Keyword arguments forwarded to `vllm.LLM`.
"""
def __init__(self, *args, bundle_indices: list[int], **kwargs):
# Prevent Ray from manipulating the top-level CUDA_VISIBLE_DEVICES variable
# so that vLLM can its own device placement inside the worker.
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
# Each worker uses 0.4 GPU so that two instances fit on the same GPUs.
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
print(f"creating LLM with bundle_indices={bundle_indices}")
super().__init__(*args, **kwargs)
class RayTrainingActor:
"""Training actor that hosts a Facebook OPT-125M model from Hugging Face.
The model is loaded onto the first GPU assigned to this actor, and expose
the CUDA IPC handles so that colocated vLLM workers can map tensors
directly.
"""
def __init__(self):
# Ray sets CUDA_VISIBLE_DEVICES to the GPUs assigned to this actor.
from transformers import AutoModelForCausalLM
self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
self.model.to("cuda:0")
# Zero out all the parameters.
for name, p in self.model.named_parameters():
p.data.zero_()
torch.accelerator.synchronize()
# The argument for `get_device_uuid` is the index of the GPU in the
# list of visible devices.
from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(0)
self.zmq_context = zmq.Context()
self.zmq_address_counter = 0
self.zmq_handle = None
def report_device_id(self) -> str:
return self.device_uuid
def get_zmq_handles(self) -> dict[str, str]:
suffix = f"{self.device_uuid}-{self.zmq_address_counter}"
self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock"
self.zmq_address_counter += 1
return {self.device_uuid: self.zmq_handle}
def update_weights(self):
# align size to avoid misaligned address
align_size = 256
def get_size(p: torch.Tensor) -> int:
return (p.nbytes + align_size - 1) // align_size * align_size
named_parameters: dict[str, torch.nn.Parameter] = dict(
self.model.named_parameters()
)
max_tensor_size = max(get_size(p) for p in named_parameters.values())
# use max_tensor_size * 2 as buffer size
buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0")
s = self.zmq_context.socket(zmq.REQ)
s.bind(self.zmq_handle)
handle = reduce_tensor(buffer)
offset = 0
buckets: list[tuple[list[dict], list[torch.Tensor]]] = []
named_tensors: list[dict] = []
real_tensors: list[torch.Tensor] = []
for name, p in named_parameters.items():
size = get_size(p)
if offset + size > buffer.numel():
buckets.append((named_tensors, real_tensors))
named_tensors, real_tensors = [], []
offset = 0
# assume tensors are contiguous
named_tensors.append(
{"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset}
)
real_tensors.append(p)
offset += size
if named_tensors:
buckets.append((named_tensors, real_tensors))
s.send_pyobj(handle)
s.recv()
for named_tensors, real_tensors in buckets:
offset = 0
for p in real_tensors:
buffer[offset : offset + p.nbytes].data.copy_(
p.data.view(-1).view(dtype=torch.uint8), non_blocking=True
)
offset += get_size(p)
torch.accelerator.synchronize()
s.send_pyobj(named_tensors)
s.recv()
s.send_pyobj(None)
s.recv()
s.close()
del buffer
gc.collect()
torch.accelerator.empty_cache()
# Ray manages four GPUs.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
ray.init()
# Co-locate vLLM instances and training actors on the same set of GPUs:
# * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0
# (tensor parallelism = 2).
# * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1
# (tensor parallelism = 2).
pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
ray.get(pg.ready())
print(f"placement group has bundles {pg.bundle_specs=}")
training_actors = []
training_actor_device_ids = []
inference_engines = []
inference_engine_device_ids = []
for bundle_index in [0, 1, 2, 3]:
training_actor = ray.remote(
num_cpus=0,
num_gpus=0.4,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_index,
),
)(RayTrainingActor).remote()
training_actors.append(training_actor)
for bundle_index, training_actor in enumerate(training_actors):
device_id = ray.get(training_actor.report_device_id.remote())
print(f"training actor {bundle_index} is on {device_id}")
training_actor_device_ids.append(device_id)
for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
# Use the following syntax instead of the @ray.remote decorator so that
# the placement group is customized for each bundle.
llm = ray.remote(
num_cpus=0,
num_gpus=0,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_capture_child_tasks=True,
),
)(MyLLM).remote(
model="facebook/opt-125m",
enforce_eager=True,
worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
tensor_parallel_size=2,
distributed_executor_backend="ray",
gpu_memory_utilization=0.4,
bundle_indices=bundle_indices,
)
inference_engines.append(llm)
# Do not call any method on the inference engine at this point; the call
# blocks until the vLLM instance finishes initialization.
for i, llm in enumerate(inference_engines):
inference_engine_device_ids.append(
ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))
)
print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")
# Verify placement: the first two training actors share the same GPUs as
# the first inference engine.
assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
# Verify placement: the last two training actors share the same GPUs as
# the second inference engine.
assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
print("Gather all the ZMQ handles from the training actors.")
zmq_handles = {}
for actor in training_actors:
zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))
print(f"ZMQ handles: {zmq_handles}")
print("Update the weights of the inference engines.")
ray.get(
[actor.update_weights.remote() for actor in training_actors]
+ [
llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,))
for llm in inference_engines
]
)
print("Check if the weights are updated.")
for llm in inference_engines:
assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))

View File

@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.
The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies GPU 0 for training, whereas a
tensor-parallel vLLM inference engine occupies GPU 12.
The example performs the following steps:
* Load the training model on GPU 0.
* Split the inference model across GPUs 12 using vLLM's tensor parallelism
and Ray placement groups.
* Generate text from a list of prompts using the inference engine.
* Update the weights of the training model and broadcast the updated weights
to the inference engine by using a Ray collective RPC group. Note that
for demonstration purposes we simply zero out the weights.
For a production-ready implementation that supports multiple training and
inference replicas, see the OpenRLHF framework:
https://github.com/OpenRLHF/OpenRLHF
This example assumes a single-node cluster with three GPUs, but Ray
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""
import json
import os
import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from rlhf_utils import stateless_init_process_group
from torchao.core.config import config_to_dict
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
PerRow,
)
from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.utils.network_utils import get_ip, get_open_port
class MyLLM(LLM):
"""Configure the vLLM worker for Ray placement group execution."""
def __init__(self, *args, **kwargs):
# Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
# so that vLLM can manage its own device placement within the worker.
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
super().__init__(*args, **kwargs)
# Load the OPT-125M model onto GPU 0 for the training workload.
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")
# Initialize Ray and set the visible devices. The vLLM engine will
# be placed on GPUs 1 and 2.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()
# Create a placement group that reserves GPU 12 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
placement_group=pg_inference,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=0,
)
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
# generate torchao quantization config for RL rollout
# see https://github.com/vllm-project/vllm/pull/23014 for instructions to
# use serialized config files instead of passing around json string
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
json_str = json.dumps(config_to_dict(config))
llm = ray.remote(
num_cpus=0,
num_gpus=0,
scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
model="facebook/opt-125m",
hf_overrides={"quantization_config_dict_json": json_str},
enforce_eager=True,
worker_extension_cls="rlhf_utils.WorkerExtension",
tensor_parallel_size=2,
distributed_executor_backend="ray",
)
# Generate text from the prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Set up the communication channel between the training process and the
# inference engine.
master_address = get_ip()
master_port = get_open_port()
handle = llm.collective_rpc.remote(
"init_weight_update_group", args=(master_address, master_port, 1, 3)
)
model_update_group = stateless_init_process_group(
master_address, master_port, 0, 3, torch.device("cuda:0")
)
ray.get(handle)
# Simulate a training step by zeroing out all model weights.
# In a real RLHF training loop the weights would be updated using the gradient
# from an RL objective such as PPO on a reward model.
for name, p in train_model.named_parameters():
p.data.zero_()
# Synchronize the updated weights to the inference engine.
for name, p in train_model.named_parameters():
dtype_name = str(p.dtype).split(".")[-1]
handle = llm.collective_rpc.remote(
"update_weight", args=(name, dtype_name, p.shape)
)
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
ray.get(handle)
# Verify that the inference weights have been updated.
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
# Generate text with the updated model. The output is expected to be nonsense
# because the weights are zero.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)

View File

@@ -0,0 +1,168 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
from collections.abc import Callable
from typing import TypedDict
import torch
import zmq
def stateless_init_process_group(master_address, master_port, rank, world_size, device):
"""
vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed.
It is recommended to create `StatelessProcessGroup`, and then initialize
the data-plane communication (NCCL) between external (train processes)
and vLLM workers.
"""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
pg = StatelessProcessGroup.create(
host=master_address, port=master_port, rank=rank, world_size=world_size
)
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl
class WorkerExtension:
"""
The class for vLLM's worker to inherit from.
By defining an extension class, the code can work no matter what is
the underlying worker class.
NOTE: we define this class in a separate module, and the main module
should pass the full qualified name as `worker_extension_cls` argument.
"""
def init_weight_update_group(
self, master_address, master_port, rank_offset, world_size
):
from vllm.distributed.parallel_state import get_world_group
rank = get_world_group().rank + rank_offset
self.model_update_group = stateless_init_process_group(
master_address,
master_port,
rank,
world_size,
self.device,
)
def update_weight(self, name, dtype_name, shape):
dtype = getattr(torch, dtype_name)
weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(
weight, src=0, stream=torch.cuda.current_stream()
)
self.model_runner.model.load_weights(weights=[(name, weight)])
del weight
def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
return weights_updated
def rebuild_ipc(
handle: tuple[Callable, tuple], device_id: int | None = None
) -> torch.Tensor:
func, args = handle
list_args = list(args)
if device_id is not None:
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args[6] = device_id
buffer = func(*list_args)
return buffer
class FlattenedTensorMetadata(TypedDict):
name: str
shape: torch.Size
dtype: torch.dtype
# specify the start offset of this tensor in shared ipc_buffer tensor
offset: int
class ColocateWorkerExtension:
"""
The class for vLLM's worker to inherit from, in the colocate setting.
By defining an extension class, the code can work no matter what is
the underlying worker class.
NOTE: we define this class in a separate module, and the main module
should pass the full qualified name as `worker_extension_cls` argument.
"""
def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
from vllm.model_executor.model_loader.utils import process_weights_after_loading
assert self.device is not None
if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
self._zmq_ctx = zmq.Context()
socket = self._zmq_ctx.socket(zmq.REP)
socket.connect(zmq_handles[self.report_device_id()])
buffer: torch.Tensor | None = None
while True:
payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = (
socket.recv_pyobj()
)
if payload is None:
# means the update is done
process_weights_after_loading(
self.model_runner.model, self.model_config, self.device
)
torch.accelerator.synchronize()
socket.send(b"")
break
if isinstance(payload, tuple):
# an ipc handle that vLLM can use `func, args = handle`
# and `func(*args)` to rebuild GPU tensor.
buffer = rebuild_ipc(payload, self.device.index)
assert buffer.dtype == torch.uint8
socket.send(b"")
continue
assert isinstance(payload, list)
assert buffer is not None
weights = []
for item in payload:
shape = item["shape"]
if isinstance(shape, (list, tuple)):
shape = torch.Size(shape)
assert isinstance(shape, torch.Size)
dtype, offset = item["dtype"], item["offset"]
size = dtype.itemsize * shape.numel()
tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape)
weights.append((item["name"], tensor))
self.model_runner.model.load_weights(weights=weights)
del weights
torch.accelerator.synchronize()
socket.send(b"")
socket.close()
del buffer
gc.collect()
torch.accelerator.empty_cache()
def report_device_id(self) -> str:
from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(self.device.index)
return self.device_uuid
def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
return weights_updated

View File

@@ -0,0 +1,384 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
End-to-end example for routed experts capture with hybrid models.
Validates that:
1. routed_experts is returned in CompletionOutput for MoE models.
2. Expert IDs are within valid range.
3. Results are deterministic across runs (baseline vs reference).
Usage:
python examples/offline_inference/routed_experts_e2e.py \
--model Qwen/Qwen3-30B-A3B \
--tp 4 \
--max-model-len 4096 \
--num-prompts 20 \
--max-new-tokens 50
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import os
import uuid
from dataclasses import dataclass, field
import numpy as np
from vllm.engine.arg_utils import AsyncEngineArgs
logger = logging.getLogger(__name__)
DEFAULT_MODEL = "Qwen/Qwen3-30B-A3B"
TEST_PROMPTS = [
"Hello, my name is",
"The capital of France is",
"Explain quantum computing in simple terms:",
"Write a Python function that sorts a list:",
"The meaning of life is",
"In a distant galaxy, there was a",
"The best way to learn programming is",
"Once upon a time in a land far away,",
"The theory of relativity states that",
"How does photosynthesis work?",
"Describe the process of machine learning:",
"What are the benefits of exercise?",
"The history of artificial intelligence began",
"Translate the following to French: Hello world",
"Summarize the plot of Romeo and Juliet:",
"What is the difference between TCP and UDP?",
"The water cycle consists of",
"Explain how a neural network learns:",
"The periodic table organizes elements by",
"Write a haiku about the ocean:",
]
@dataclass
class InferenceResult:
"""Result from a single inference run."""
experts_list: list[np.ndarray] = field(default_factory=list)
token_ids_list: list[list[int]] = field(default_factory=list)
num_experts: int = 0
# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------
async def _run_async_inference(
engine_args: AsyncEngineArgs,
prompts: list[str],
max_new_tokens: int,
) -> InferenceResult:
"""Run inference using AsyncLLM."""
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM
engine = AsyncLLM.from_engine_args(engine_args)
hf_config = engine.model_config.hf_text_config
num_experts: int = getattr(hf_config, "num_experts", 0) or getattr(
hf_config, "num_local_experts", 0
)
assert num_experts > 0, "Could not determine num_experts from model config"
sampling_params = SamplingParams(
temperature=0,
max_tokens=max_new_tokens,
)
async def _generate_one(prompt: str, idx: int):
request_id = str(uuid.uuid4())
final_output = None
async for output in engine.generate(prompt, sampling_params, request_id):
final_output = output
assert final_output is not None
completion = final_output.outputs[0]
routed = completion.routed_experts
num_prompt_tokens = len(final_output.prompt_token_ids)
num_generated_tokens = len(completion.token_ids)
expected_len = num_prompt_tokens + num_generated_tokens - 1
assert routed is not None, f"Prompt {idx}: routed_experts is None"
assert routed.shape[0] == expected_len, (
f"Prompt {idx}: routed_experts length {routed.shape[0]} != "
f"prompt ({num_prompt_tokens}) + generated ({num_generated_tokens})"
f" - 1 = {expected_len}"
)
return idx, routed, list(completion.token_ids)
tasks = [_generate_one(p, i) for i, p in enumerate(prompts)]
outputs = await asyncio.gather(*tasks)
# Sort by original index to maintain prompt order
outputs.sort(key=lambda x: x[0])
result = InferenceResult(num_experts=num_experts)
for _, routed, token_ids in outputs:
result.experts_list.append(routed)
result.token_ids_list.append(token_ids)
engine.shutdown()
return result
def run_inference(
model: str,
prompts: list[str],
max_new_tokens: int = 50,
tp: int = 1,
max_model_len: int = 4096,
) -> InferenceResult:
"""Run inference with routed experts capture enabled via AsyncLLM."""
engine_args = AsyncEngineArgs(
model=model,
enable_return_routed_experts=True,
tensor_parallel_size=tp,
max_model_len=max_model_len,
disable_log_stats=True,
attention_backend="FLASH_ATTN",
)
result = asyncio.run(_run_async_inference(engine_args, prompts, max_new_tokens))
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
current_platform.empty_cache()
return result
# ---------------------------------------------------------------------------
# Validation helpers
# ---------------------------------------------------------------------------
def validate_expert_ids(
experts_list: list[np.ndarray],
num_experts: int,
) -> None:
"""Check that all expert IDs are within valid range [0, num_experts)."""
for i, experts in enumerate(experts_list):
assert np.all(experts >= 0), (
f"Prompt {i}: negative expert IDs found, min={experts.min()}"
)
assert np.all(experts < num_experts), (
f"Prompt {i}: expert ID out of range [0, {num_experts}), "
f"max={experts.max()}"
)
def validate_shapes(experts_list: list[np.ndarray]) -> None:
"""Check that all routed_experts arrays have at least 2 dimensions."""
for i, experts in enumerate(experts_list):
assert experts.ndim >= 2, (
f"Prompt {i}: expected at least 2D array, got shape {experts.shape}"
)
logger.info("Prompt %d: routed_experts shape = %s", i, experts.shape)
# ---------------------------------------------------------------------------
# Comparison helpers
# ---------------------------------------------------------------------------
def compare_token_ids(
baseline: list[list[int]],
reference: list[list[int]],
) -> float:
"""Compare token IDs from two runs. Returns mismatch ratio."""
assert len(baseline) == len(reference), (
f"Length mismatch: {len(baseline)} vs {len(reference)}"
)
total_tokens = 0
total_mismatches = 0
for i, (base, ref) in enumerate(zip(baseline, reference)):
min_len = min(len(base), len(ref))
max_len = max(len(base), len(ref))
matches = 0
for a, b in zip(base[:min_len], ref[:min_len]):
if a != b:
break
matches += 1
total_mismatches += max_len - matches
total_tokens += max_len
if matches < min_len or len(base) != len(ref):
print(
f" Prompt {i}: token_ids len={len(base)} vs {len(ref)}, "
f"mismatches={max_len - matches}/{max_len}"
)
if total_tokens == 0:
raise ValueError("No tokens to compare")
mismatch_ratio = total_mismatches / total_tokens
print(
f"Token ID mismatches: {total_mismatches}/{total_tokens} ({mismatch_ratio:.4%})"
)
return mismatch_ratio
def compare_routed_experts(
baseline: list[np.ndarray],
reference: list[np.ndarray],
threshold: float = 0.05,
) -> float:
"""Compare two runs of routed experts. Returns mismatch ratio.
Raises AssertionError if ratio exceeds threshold.
"""
assert len(baseline) == len(reference), (
f"Length mismatch: {len(baseline)} vs {len(reference)}"
)
total_elements = 0
total_mismatches = 0
for i, (base, ref) in enumerate(zip(baseline, reference)):
min_len = min(len(base), len(ref))
max_len = max(len(base), len(ref))
if min_len == 0:
continue
base_trimmed = base[:min_len]
ref_trimmed = ref[:min_len]
matches = 0
for a, b in zip(base_trimmed, ref_trimmed):
if a.sum() != b.sum():
break
matches += 1
total_mismatches += max_len - matches
total_elements += max_len
if matches < min_len or len(base) != len(ref):
print(
f" Prompt {i}: routed_experts len={len(base)} vs {len(ref)}, "
f"mismatches={max_len - matches}/{max_len}"
)
if total_elements == 0:
raise ValueError("No elements to compare")
mismatch_ratio = total_mismatches / total_elements
print(
f"Routed experts mismatches: {total_mismatches}/{total_elements} "
f"({mismatch_ratio:.4%})"
)
assert mismatch_ratio < threshold, (
f"Too many mismatches: {total_mismatches}/{total_elements} "
f"({mismatch_ratio:.4%}) exceeds threshold {threshold:.4%}"
)
return mismatch_ratio
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main():
os.environ.setdefault("VLLM_BATCH_INVARIANT", "1")
parser = argparse.ArgumentParser(
description="Test routed experts capture for MoE models"
)
parser.add_argument("--model", type=str, default=DEFAULT_MODEL)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--max-model-len", type=int, default=4096)
parser.add_argument("--num-prompts", type=int, default=20)
parser.add_argument("--max-new-tokens", type=int, default=50)
parser.add_argument(
"--deterministic",
action="store_true",
help="Run twice and compare results for determinism check",
)
parser.add_argument(
"--threshold",
type=float,
default=0.05,
help="Maximum allowed mismatch ratio for determinism check",
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
prompts = TEST_PROMPTS[: args.num_prompts]
print(f"Model: {args.model}")
print(f"TP: {args.tp}")
print(f"Prompts: {len(prompts)}")
print(f"Max new tokens: {args.max_new_tokens}")
print()
print("=== Run 1 (baseline) ===")
baseline = run_inference(
model=args.model,
prompts=prompts,
max_new_tokens=args.max_new_tokens,
tp=args.tp,
max_model_len=args.max_model_len,
)
print(f"num_experts (from model config): {baseline.num_experts}")
print("\n=== Validation ===")
validate_shapes(baseline.experts_list)
validate_expert_ids(baseline.experts_list, num_experts=baseline.num_experts)
print(f"All {len(baseline.experts_list)} results passed validation.")
for i, experts in enumerate(baseline.experts_list):
print(
f" Prompt {i}: shape={experts.shape}, "
f"min={experts.min()}, max={experts.max()}"
)
if args.deterministic:
print("\n=== Run 2 (reference) ===")
reference = run_inference(
model=args.model,
prompts=prompts,
max_new_tokens=args.max_new_tokens,
tp=args.tp,
max_model_len=args.max_model_len,
)
print("\n=== Determinism Check ===")
validate_expert_ids(reference.experts_list, num_experts=baseline.num_experts)
print("\n--- Token IDs ---")
token_mismatch = compare_token_ids(
baseline.token_ids_list, reference.token_ids_list
)
print("\n--- Routed Experts ---")
expert_mismatch = compare_routed_experts(
baseline.experts_list,
reference.experts_list,
threshold=args.threshold,
)
print(
f"\nDeterminism check passed. "
f"Token mismatch: {token_mismatch:.4%}, "
f"Expert mismatch: {expert_mismatch:.4%}"
)
print("\nAll tests passed!")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from vllm import LLM, EngineArgs
from vllm.config import ProfilerConfig
from vllm.utils.argparse_utils import FlexibleArgumentParser
DEFAULT_MAX_TOKENS = 16
def create_parser() -> FlexibleArgumentParser:
parser = FlexibleArgumentParser()
EngineArgs.add_cli_args(parser)
parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
batch_group = parser.add_argument_group("Batch parameters")
batch_group.add_argument("--batch-size", type=int, default=1)
batch_group.add_argument("--prompt-size", type=int, default=128)
batch_group.add_argument("--prompt-prefix", type=str, default="Hello, my name is")
profile_group = parser.add_argument_group("Profiling parameters")
profile_group.add_argument(
"--profile",
choices=["none", "prefill", "decode", "both"],
default="none",
)
profile_group.add_argument(
"--profile-dir",
type=str,
default="",
help="Required when --profile is not 'none'.",
)
return parser
def _build_prompt(prefix: str, prompt_size: int) -> str:
if prompt_size <= 0:
return ""
if not prefix:
prefix = " "
if len(prefix) >= prompt_size:
return prefix[:prompt_size]
repeat_count = (prompt_size + len(prefix) - 1) // len(prefix)
return (prefix * repeat_count)[:prompt_size]
def _build_profiler_config(
profile: str, profile_dir: str, max_tokens: int
) -> ProfilerConfig | None:
if profile == "none":
return None
if not profile_dir:
raise ValueError("--profile-dir must be set when profiling is enabled.")
if profile == "prefill":
delay_iterations = 0
max_iterations = 1
elif profile == "decode":
delay_iterations = 1
max_iterations = max(1, max_tokens)
else:
delay_iterations = 0
max_iterations = 0
return ProfilerConfig(
profiler="torch",
torch_profiler_dir=profile_dir,
delay_iterations=delay_iterations,
max_iterations=max_iterations,
)
def main(args: dict) -> None:
max_tokens = DEFAULT_MAX_TOKENS
batch_size = args.pop("batch_size")
prompt_size = args.pop("prompt_size")
prompt_prefix = args.pop("prompt_prefix")
profile = args.pop("profile")
profile_dir = args.pop("profile_dir")
profiler_config = _build_profiler_config(profile, profile_dir, max_tokens)
if profiler_config is not None:
args["profiler_config"] = profiler_config
llm = LLM(**args)
sampling_params = llm.get_default_sampling_params()
sampling_params.max_tokens = max_tokens
sampling_params.min_tokens = max_tokens
sampling_params.ignore_eos = True
prompt = _build_prompt(prompt_prefix, prompt_size)
prompts = [prompt] * batch_size
if profile != "none":
llm.start_profile()
outputs = llm.generate(prompts, sampling_params)
if profile != "none":
llm.stop_profile()
print("-" * 50)
for output in outputs:
generated_text = output.outputs[0].text
print(f"Prompt: {output.prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
if __name__ == "__main__":
parser = create_parser()
main(vars(parser.parse_args()))

View File

@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.
Example usage:
python save_sharded_state.py \
--model /path/to/load \
--tensor-parallel-size 8 \
--output /path/to/save
Then, the model can be loaded with
llm = LLM(
model="/path/to/save",
load_format="sharded_state",
tensor_parallel_size=8,
)
"""
import dataclasses
import os
import shutil
from pathlib import Path
from vllm import LLM, EngineArgs
from vllm.model_executor.model_loader import ShardedStateLoader
from vllm.utils.argparse_utils import FlexibleArgumentParser
def parse_args():
parser = FlexibleArgumentParser()
EngineArgs.add_cli_args(parser)
parser.add_argument(
"--output", "-o", required=True, type=str, help="path to output checkpoint"
)
parser.add_argument(
"--file-pattern",
type=str,
default=ShardedStateLoader.DEFAULT_PATTERN,
help="string pattern of saved filenames",
)
parser.add_argument(
"--max-file-size",
type=int,
default=5 * 1024**3,
help="max size (in bytes) of each safetensors file",
)
return parser.parse_args()
def main(args):
engine_args = EngineArgs.from_cli_args(args)
if engine_args.enable_lora:
raise ValueError("Saving with enable_lora=True is not supported!")
model_path = engine_args.model
if not Path(model_path).is_dir():
raise ValueError("model path must be a local directory")
# Create LLM instance from arguments
llm = LLM(**dataclasses.asdict(engine_args))
# Prepare output directory
Path(args.output).mkdir(exist_ok=True)
# Dump worker states to output directory
llm.llm_engine.engine_core.save_sharded_state(
path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
)
# Copy metadata files to output directory
for file in os.listdir(model_path):
if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
if os.path.isdir(os.path.join(model_path, file)):
shutil.copytree(
os.path.join(model_path, file), os.path.join(args.output, file)
)
else:
shutil.copy(os.path.join(model_path, file), args.output)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
def main():
# Create an LLM.
llm = LLM(
model="facebook/opt-125m",
tensor_parallel_size=1,
profiler_config={
"profiler": "torch",
"torch_profiler_dir": "./vllm_profile",
},
)
llm.start_profile()
# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
llm.stop_profile()
# Print the outputs.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Add a buffer to wait for profiler in the background process
# (in case MP is on) to finish writing profiling output.
time.sleep(10)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, RequestOutput, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
def print_prompts_and_outputs(outputs: list[RequestOutput]) -> None:
print("-" * 60)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Output: {generated_text!r}")
print("-" * 60)
def main():
# Create an LLM without loading real weights
llm = LLM(
model="Qwen/Qwen3-0.6B",
load_format="dummy",
enforce_eager=True,
tensor_parallel_size=4,
)
outputs = llm.generate(prompts, sampling_params)
print("\nOutputs do not make sense:")
print_prompts_and_outputs(outputs)
# Update load format from `dummy` to `auto`
llm.collective_rpc(
"update_config", args=({"load_config": {"load_format": "auto"}},)
)
# Now reload real weights inplace
llm.collective_rpc("reload_weights")
# Check outputs make sense
outputs = llm.generate(prompts, sampling_params)
print("\nOutputs make sense after loading real weights:")
print_prompts_and_outputs(outputs)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,258 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.metrics.reader import Counter, Vector
QUESTION = "What is the content of each image?"
IMAGE_URLS = [
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg",
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg",
]
def get_custom_mm_prompts(num_prompts):
prompts = []
for url in IMAGE_URLS:
prompts.append(
[
{"type": "image_url", "image_url": {"url": url}},
{"type": "text", "text": QUESTION},
]
)
if num_prompts > len(IMAGE_URLS):
prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1)
return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]]
def parse_args():
parser = FlexibleArgumentParser()
add_dataset_parser(parser)
parser.add_argument("--test", action="store_true")
parser.add_argument(
"--method",
type=str,
default="eagle",
choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
)
parser.add_argument("--backend", type=str, default="openai")
parser.add_argument("--num-spec-tokens", type=int, default=2)
parser.add_argument("--prompt-lookup-max", type=int, default=5)
parser.add_argument("--prompt-lookup-min", type=int, default=2)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--enforce-eager", action="store_true")
parser.add_argument("--enable-chunked-prefill", action="store_true")
parser.add_argument("--max-model-len", type=int, default=16384)
parser.add_argument("--temp", type=float, default=0)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=-1)
parser.add_argument("--print-output", action="store_true")
parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
parser.add_argument("--disable-padded-drafter-batch", action="store_true")
parser.add_argument("--max-num-seqs", type=int, default=None)
parser.add_argument("--parallel-drafting", action="store_true")
parser.add_argument("--allowed-local-media-path", type=str, default="")
return parser.parse_args()
def main(args):
model_dir = args.model_dir
if args.model_dir is None:
if args.custom_mm_prompts:
raise ValueError(
"custom_mm_prompts requires mm based models"
"default llama3.1-8b-instruct is not mm based"
"please specify model_dir to give a mm based model"
)
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
if args.custom_mm_prompts:
prompts = llm_prompts = get_custom_mm_prompts(args.num_prompts)
else:
prompts = get_samples(args, tokenizer)
if args.enable_multimodal_chat:
llm_prompts = [p.prompt for p in prompts]
else:
# add_special_tokens is False to avoid adding bos twice
# when using chat templates
llm_prompts = [
{
"prompt_token_ids": tokenizer.encode(
prompt.prompt, add_special_tokens=False
),
"multi_modal_data": prompt.multi_modal_data,
}
for prompt in prompts
]
if args.method == "eagle" or args.method == "eagle3":
eagle_dir = args.eagle_dir
if args.method == "eagle" and eagle_dir is None:
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
elif args.method == "eagle3" and eagle_dir is None:
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
speculative_config = {
"method": args.method,
"model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens,
"disable_padded_drafter_batch": args.disable_padded_drafter_batch,
"parallel_drafting": args.parallel_drafting,
}
elif args.method == "ngram":
speculative_config = {
"method": "ngram",
"num_speculative_tokens": args.num_spec_tokens,
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
}
elif args.method == "draft_model":
assert args.draft_model is not None and args.draft_model != ""
speculative_config = {
"method": args.method,
"model": args.draft_model,
"num_speculative_tokens": args.num_spec_tokens,
"enforce_eager": args.enforce_eager,
"max_model_len": args.max_model_len,
"parallel_drafting": args.parallel_drafting,
}
elif args.method == "mtp":
speculative_config = {
"method": "mtp",
"num_speculative_tokens": args.num_spec_tokens,
}
else:
raise ValueError(f"unknown method: {args.method}")
llm = LLM(
model=model_dir,
trust_remote_code=True,
tensor_parallel_size=args.tp,
enable_chunked_prefill=args.enable_chunked_prefill,
enforce_eager=args.enforce_eager,
gpu_memory_utilization=args.gpu_memory_utilization,
speculative_config=speculative_config,
disable_log_stats=False,
max_model_len=args.max_model_len,
limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True,
max_num_seqs=args.max_num_seqs,
allowed_local_media_path=args.allowed_local_media_path,
)
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
if args.backend == "openai-chat":
outputs = llm.chat(llm_prompts, sampling_params=sampling_params)
else:
outputs = llm.generate(
llm_prompts,
sampling_params=sampling_params,
)
# print the generated text
if args.print_output:
for i, output in enumerate(outputs):
print("-" * 50)
if not args.custom_mm_prompts:
print(f"prompt: {prompts[i].prompt}")
else:
print(f"prompt: {prompts[i]}")
print(f"generated text: {output.outputs[0].text}")
print("-" * 50)
metrics = llm.get_metrics()
total_num_output_tokens = sum(
len(output.outputs[0].token_ids) for output in outputs
)
num_drafts = 0
num_draft_tokens = 0
num_accepted_tokens = 0
acceptance_counts = [0] * args.num_spec_tokens
for metric in metrics:
if metric.name == "vllm:spec_decode_num_drafts":
assert isinstance(metric, Counter)
num_drafts += metric.value
elif metric.name == "vllm:spec_decode_num_draft_tokens":
assert isinstance(metric, Counter)
num_draft_tokens += metric.value
elif metric.name == "vllm:spec_decode_num_accepted_tokens":
assert isinstance(metric, Counter)
num_accepted_tokens += metric.value
elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
assert isinstance(metric, Vector)
for pos in range(len(metric.values)):
acceptance_counts[pos] += metric.values[pos]
print("-" * 50)
print(f"total_num_output_tokens: {total_num_output_tokens}")
print(f"num_drafts: {num_drafts}")
print(f"num_draft_tokens: {num_draft_tokens}")
print(f"num_accepted_tokens: {num_accepted_tokens}")
acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1
print(f"mean acceptance length: {acceptance_length:.2f}")
print("-" * 50)
# print acceptance at each token position
for i in range(len(acceptance_counts)):
acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0
print(f"acceptance at token {i}: {acceptance_rate:.2f}")
return acceptance_length
if __name__ == "__main__":
args = parse_args()
args.enable_multimodal_chat = args.backend == "openai-chat"
acceptance_length = main(args)
if args.test:
# takes ~30s to run on 1xH100
assert args.method in ["eagle", "eagle3"]
assert args.tp == 1
assert args.num_spec_tokens == 3
assert args.dataset_name == "hf"
assert args.dataset_path == "philschmid/mt-bench"
assert args.num_prompts == 80
assert args.temp == 0
assert args.top_p == 1.0
assert args.top_k == -1
assert args.enable_chunked_prefill
# check acceptance length is within 2% of expected value
rtol = 0.02
expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811
assert (
acceptance_length <= (1 + rtol) * expected_acceptance_length
and acceptance_length >= (1 - rtol) * expected_acceptance_length
), (
f"acceptance_length {acceptance_length} is not "
f"within {rtol * 100}% of {expected_acceptance_length}"
)
print(
f"Test passed! Expected AL: "
f"{expected_acceptance_length}, got {acceptance_length}"
)

View File

@@ -0,0 +1,113 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file demonstrates the example usage of structured outputs
in vLLM. It shows how to apply different constraints such as choice,
regex, json schema, and grammar to produce structured and formatted
results based on specific prompts.
"""
from enum import Enum
from pydantic import BaseModel
from vllm import LLM, SamplingParams
from vllm.sampling_params import StructuredOutputsParams
MAX_TOKENS = 50
# Structured outputs by Choice (list of possible options)
structured_outputs_params_choice = StructuredOutputsParams(
choice=["Positive", "Negative"]
)
sampling_params_choice = SamplingParams(
structured_outputs=structured_outputs_params_choice
)
prompt_choice = "Classify this sentiment: vLLM is wonderful!"
# Structured outputs by Regex
structured_outputs_params_regex = StructuredOutputsParams(regex=r"\w+@\w+\.com\n")
sampling_params_regex = SamplingParams(
structured_outputs=structured_outputs_params_regex,
stop=["\n"],
max_tokens=MAX_TOKENS,
)
prompt_regex = (
"Generate an email address for Alan Turing, who works in Enigma."
"End in .com and new line. Example result:"
"alan.turing@enigma.com\n"
)
# Structured outputs by JSON using Pydantic schema
class CarType(str, Enum):
sedan = "sedan"
suv = "SUV"
truck = "Truck"
coupe = "Coupe"
class CarDescription(BaseModel):
brand: str
model: str
car_type: CarType
json_schema = CarDescription.model_json_schema()
structured_outputs_params_json = StructuredOutputsParams(json=json_schema)
sampling_params_json = SamplingParams(
structured_outputs=structured_outputs_params_json, max_tokens=MAX_TOKENS
)
prompt_json = (
"Generate a JSON with the brand, model and car_type of "
"the most iconic car from the 90's"
)
# Structured outputs by Grammar
simplified_sql_grammar = """
root ::= select_statement
select_statement ::= "SELECT " column " from " table " where " condition
column ::= "col_1 " | "col_2 "
table ::= "table_1 " | "table_2 "
condition ::= column "= " number
number ::= "1 " | "2 "
"""
structured_outputs_params_grammar = StructuredOutputsParams(
grammar=simplified_sql_grammar
)
sampling_params_grammar = SamplingParams(
structured_outputs=structured_outputs_params_grammar,
max_tokens=MAX_TOKENS,
)
prompt_grammar = (
"Generate an SQL query to show the 'username' and 'email' from the 'users' table."
)
def format_output(title: str, output: str):
print(f"{'-' * 50}\n{title}: {output}\n{'-' * 50}")
def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM):
outputs = llm.generate(prompt, sampling_params=sampling_params)
return outputs[0].outputs[0].text
def main():
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)
choice_output = generate_output(prompt_choice, sampling_params_choice, llm)
format_output("Structured outputs by Choice", choice_output)
regex_output = generate_output(prompt_regex, sampling_params_regex, llm)
format_output("Structured outputs by Regex", regex_output)
json_output = generate_output(prompt_json, sampling_params_json, llm)
format_output("Structured outputs by JSON", json_output)
grammar_output = generate_output(prompt_grammar, sampling_params_grammar, llm)
format_output("Structured outputs by Grammar", grammar_output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
experimental support for data-parallel inference with torchrun
Note the data load balancing and distribution is done out of the vllm engine,
no internal lb supported in external_launcher mode.
To run this example:
```bash
$ torchrun --nproc-per-node=2 examples/offline_inference/torchrun_dp_example.py
```
With custom parallelism settings:
```bash
$ torchrun --nproc-per-node=8 examples/offline_inference/torchrun_dp_example.py \
--tp-size=2 --pp-size=1 --dp-size=4 --enable-ep
```
"""
import argparse
from vllm import LLM, SamplingParams
def parse_args():
parser = argparse.ArgumentParser(
description="Data-parallel inference with torchrun"
)
parser.add_argument(
"--tp-size",
type=int,
default=1,
help="Tensor parallel size (default: 1)",
)
parser.add_argument(
"--pp-size",
type=int,
default=1,
help="Pipeline parallel size (default: 1)",
)
parser.add_argument(
"--dp-size",
type=int,
default=2,
help="Data parallel size (default: 2)",
)
parser.add_argument(
"--enable-ep",
action="store_true",
help="Enable expert parallel (default: False)",
)
parser.add_argument(
"--model",
type=str,
default="microsoft/Phi-mini-MoE-instruct",
help="Model name or path (default: microsoft/Phi-mini-MoE-instruct)",
)
parser.add_argument(
"--max-model-len",
type=int,
default=4096,
help="Maximum model length (default: 4096)",
)
parser.add_argument(
"--gpu-memory-utilization",
type=float,
default=0.6,
help="GPU memory utilization (default: 0.6)",
)
parser.add_argument(
"--seed",
type=int,
default=1,
help="Random seed (default: 1)",
)
return parser.parse_args()
args = parse_args()
# Create prompts, the same across all ranks
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Use `distributed_executor_backend="external_launcher"` so that
# this llm engine/instance only creates one worker.
# it is important to set an explicit seed to make sure that
# all ranks have the same random seed, so that sampling can be
# deterministic across ranks.
llm = LLM(
model=args.model,
tensor_parallel_size=args.tp_size,
data_parallel_size=args.dp_size,
pipeline_parallel_size=args.pp_size,
enable_expert_parallel=args.enable_ep,
distributed_executor_backend="external_launcher",
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
seed=args.seed,
)
dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank
dp_size = llm.llm_engine.vllm_config.parallel_config.data_parallel_size
prompts = [
f"{idx}.{prompt}" for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank
]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(
f"DP Rank: {dp_rank} Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n"
)
"""
Further tips:
1. to communicate control messages across all ranks, use the cpu group,
a PyTorch ProcessGroup with GLOO backend.
```python
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
torch_rank = dist.get_rank(group=cpu_group)
if torch_rank == 0:
# do something for rank 0, e.g. saving the results to disk.
```
2. to communicate data across all ranks, use the model's device group,
a PyTorch ProcessGroup with NCCL backend.
```python
from vllm.distributed.parallel_state import get_world_group
device_group = get_world_group().device_group
```
3. to access the model directly in every rank, use the following code:
```python
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
```
"""

View File

@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
experimental support for tensor-parallel inference with torchrun,
see https://github.com/vllm-project/vllm/issues/11400 for
the motivation and use case for this example.
run the script with `torchrun --nproc-per-node=2 torchrun_example.py`,
the argument 2 should match the `tensor_parallel_size` below.
see `tests/distributed/test_torchrun_example.py` for the unit test.
"""
import torch.distributed as dist
from vllm import LLM, SamplingParams
# Create prompts, the same across all ranks
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Use `distributed_executor_backend="external_launcher"` so that
# this llm engine/instance only creates one worker.
# it is important to set an explicit seed to make sure that
# all ranks have the same random seed, so that sampling can be
# deterministic across ranks.
llm = LLM(
model="meta-llama/Llama-3.1-8B",
tensor_parallel_size=2,
pipeline_parallel_size=2,
distributed_executor_backend="external_launcher",
max_model_len=32768,
seed=1,
)
outputs = llm.generate(prompts, sampling_params)
# all ranks will have the same outputs
if dist.get_rank() == 0:
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n")
print("-" * 50)
"""
Further tips:
1. to communicate control messages across all ranks, use the cpu group,
a PyTorch ProcessGroup with GLOO backend.
```python
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
torch_rank = dist.get_rank(group=cpu_group)
if torch_rank == 0:
# do something for rank 0, e.g. saving the results to disk.
```
2. to communicate data across all ranks, use the model's device group,
a PyTorch ProcessGroup with NCCL backend.
```python
from vllm.distributed.parallel_state import get_world_group
device_group = get_world_group().device_group
```
3. to access the model directly in every rank, use the following code:
```python
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
```
"""

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for `vllm.entrypoints.api_server`
Start the demo server:
python -m vllm.entrypoints.api_server --model <model_name>
NOTE: The API server is used only for demonstration and simple performance
benchmarks. It is not intended for production use.
For production use, we recommend `vllm serve` and the OpenAI client API.
"""
import argparse
import json
from argparse import Namespace
from collections.abc import Iterable
import requests
def clear_line(n: int = 1) -> None:
LINE_UP = "\033[1A"
LINE_CLEAR = "\x1b[2K"
for _ in range(n):
print(LINE_UP, end=LINE_CLEAR, flush=True)
def post_http_request(
prompt: str, api_url: str, n: int = 1, stream: bool = False
) -> requests.Response:
headers = {"User-Agent": "Test Client"}
pload = {
"prompt": prompt,
"n": n,
"temperature": 0.0,
"max_tokens": 16,
"stream": stream,
}
response = requests.post(api_url, headers=headers, json=pload, stream=stream)
return response
def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b"\n"
):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"]
yield output
def get_response(response: requests.Response) -> list[str]:
data = json.loads(response.content)
output = data["text"]
return output
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--n", type=int, default=1)
parser.add_argument("--prompt", type=str, default="San Francisco is a")
parser.add_argument("--stream", action="store_true")
return parser.parse_args()
def main(args: Namespace):
prompt = args.prompt
api_url = f"http://{args.host}:{args.port}/generate"
n = args.n
stream = args.stream
print(f"Prompt: {prompt!r}\n", flush=True)
response = post_http_request(prompt, api_url, n, stream)
if stream:
num_printed_lines = 0
for h in get_streaming_response(response):
clear_line(num_printed_lines)
num_printed_lines = 0
for i, line in enumerate(h):
num_printed_lines += 1
print(f"Beam candidate {i}: {line!r}", flush=True)
else:
output = get_response(response)
for i, line in enumerate(output):
print(f"Beam candidate {i}: {line!r}", flush=True)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,6 @@
*.png
.git/
ct.yaml
lintconf.yaml
values.schema.json
/workflows

View File

@@ -0,0 +1,21 @@
apiVersion: v2
name: chart-vllm
description: Chart vllm
# A chart can be either an 'application' or a 'library' chart.
#
# Application charts are a collection of templates that can be packaged into versioned archives
# to be deployed.
#
# Library charts provide useful utilities or functions for the chart developer. They're included as
# a dependency of application charts to inject those utilities and functions into the rendering
# pipeline. Library charts do not define any templates and therefore cannot be deployed.
type: application
# This is the chart version. This version number should be incremented each time you make changes
# to the chart and its templates, including the app version.
# Versions are expected to follow Semantic Versioning (https://semver.org/)
version: 0.0.1
maintainers:
- name: mfournioux

View File

@@ -0,0 +1,33 @@
# Helm Charts
This directory contains a Helm chart for deploying the vllm application. The chart includes configurations for deployment, autoscaling, resource management, and more.
## Files
- Chart.yaml: Defines the chart metadata including name, version, and maintainers.
- ct.yaml: Configuration for chart testing.
- lintconf.yaml: Linting rules for YAML files.
- values.schema.json: JSON schema for validating values.yaml.
- values.yaml: Default values for the Helm chart.
- templates/_helpers.tpl: Helper templates for defining common configurations.
- templates/configmap.yaml: Template for creating ConfigMaps.
- templates/custom-objects.yaml: Template for custom Kubernetes objects.
- templates/deployment.yaml: Template for creating Deployments.
- templates/hpa.yaml: Template for Horizontal Pod Autoscaler.
- templates/job.yaml: Template for Kubernetes Jobs.
- templates/poddisruptionbudget.yaml: Template for Pod Disruption Budget.
- templates/pvc.yaml: Template for Persistent Volume Claims.
- templates/secrets.yaml: Template for Kubernetes Secrets.
- templates/service.yaml: Template for creating Services.
## Running Tests
This chart includes unit tests using [helm-unittest](https://github.com/helm-unittest/helm-unittest). Install the plugin and run tests:
```bash
# Install plugin
helm plugin install https://github.com/helm-unittest/helm-unittest
# Run tests
helm unittest .
```

View File

@@ -0,0 +1,3 @@
chart-dirs:
- charts
validate-maintainers: false

View File

@@ -0,0 +1,42 @@
---
rules:
braces:
min-spaces-inside: 0
max-spaces-inside: 0
min-spaces-inside-empty: -1
max-spaces-inside-empty: -1
brackets:
min-spaces-inside: 0
max-spaces-inside: 0
min-spaces-inside-empty: -1
max-spaces-inside-empty: -1
colons:
max-spaces-before: 0
max-spaces-after: 1
commas:
max-spaces-before: 0
min-spaces-after: 1
max-spaces-after: 1
comments:
require-starting-space: true
min-spaces-from-content: 2
document-end: disable
document-start: disable # No --- to start a file
empty-lines:
max: 2
max-start: 0
max-end: 0
hyphens:
max-spaces-after: 1
indentation:
spaces: consistent
indent-sequences: whatever # - list indentation will handle both indentation and without
check-multi-line-strings: false
key-duplicates: enable
line-length: disable # Lines can be any length
new-line-at-end-of-file: disable
new-lines:
type: unix
trailing-spaces: enable
truthy:
level: warning

View File

@@ -0,0 +1,165 @@
{{/*
Define ports for the pods
*/}}
{{- define "chart.container-port" -}}
{{- default "8000" .Values.containerPort }}
{{- end }}
{{/*
Define service name
*/}}
{{- define "chart.service-name" -}}
{{- if .Values.serviceName }}
{{- .Values.serviceName | lower | trim }}
{{- else }}
"{{ .Release.Name }}-service"
{{- end }}
{{- end }}
{{/*
Define service port
*/}}
{{- define "chart.service-port" -}}
{{- if .Values.servicePort }}
{{- .Values.servicePort }}
{{- else }}
{{- include "chart.container-port" . }}
{{- end }}
{{- end }}
{{/*
Define service port name
*/}}
{{- define "chart.service-port-name" -}}
"service-port"
{{- end }}
{{/*
Define container port name
*/}}
{{- define "chart.container-port-name" -}}
"container-port"
{{- end }}
{{/*
Define deployment strategy
*/}}
{{- define "chart.strategy" -}}
strategy:
{{- if not .Values.deploymentStrategy }}
rollingUpdate:
maxSurge: 100%
maxUnavailable: 0
{{- else }}
{{ toYaml .Values.deploymentStrategy | indent 2 }}
{{- end }}
{{- end }}
{{/*
Define additional ports
*/}}
{{- define "chart.extraPorts" }}
{{- with .Values.extraPorts }}
{{ toYaml . }}
{{- end }}
{{- end }}
{{/*
Define chart external ConfigMaps and Secrets
*/}}
{{- define "chart.externalConfigs" -}}
{{- with .Values.externalConfigs -}}
{{ toYaml . }}
{{- end }}
{{- end }}
{{/*
Define liveness et readiness probes
*/}}
{{- define "chart.probes" -}}
{{- if .Values.readinessProbe }}
readinessProbe:
{{- with .Values.readinessProbe }}
{{- toYaml . | nindent 2 }}
{{- end }}
{{- end }}
{{- if .Values.livenessProbe }}
livenessProbe:
{{- with .Values.livenessProbe }}
{{- toYaml . | nindent 2 }}
{{- end }}
{{- end }}
{{- end }}
{{/*
Define resources
*/}}
{{- define "chart.resources" -}}
requests:
memory: {{ required "Value 'resources.requests.memory' must be defined !" .Values.resources.requests.memory | quote }}
cpu: {{ required "Value 'resources.requests.cpu' must be defined !" .Values.resources.requests.cpu | quote }}
{{- if and (gt (int (index .Values.resources.requests "nvidia.com/gpu")) 0) (gt (int (index .Values.resources.limits "nvidia.com/gpu")) 0) }}
nvidia.com/gpu: {{ required "Value 'resources.requests.nvidia.com/gpu' must be defined !" (index .Values.resources.requests "nvidia.com/gpu") | quote }}
{{- end }}
limits:
memory: {{ required "Value 'resources.limits.memory' must be defined !" .Values.resources.limits.memory | quote }}
cpu: {{ required "Value 'resources.limits.cpu' must be defined !" .Values.resources.limits.cpu | quote }}
{{- if and (gt (int (index .Values.resources.requests "nvidia.com/gpu")) 0) (gt (int (index .Values.resources.limits "nvidia.com/gpu")) 0) }}
nvidia.com/gpu: {{ required "Value 'resources.limits.nvidia.com/gpu' must be defined !" (index .Values.resources.limits "nvidia.com/gpu") | quote }}
{{- end }}
{{- end }}
{{/*
Define User used for the main container
*/}}
{{- define "chart.user" }}
{{- if .Values.image.runAsUser }}
runAsUser:
{{- with .Values.runAsUser }}
{{- toYaml . | nindent 2 }}
{{- end }}
{{- end }}
{{- end }}
{{- define "chart.extraInitEnv" -}}
- name: S3_ENDPOINT_URL
valueFrom:
secretKeyRef:
name: {{ .Release.Name }}-secrets
key: s3endpoint
- name: S3_BUCKET_NAME
valueFrom:
secretKeyRef:
name: {{ .Release.Name }}-secrets
key: s3bucketname
- name: AWS_ACCESS_KEY_ID
valueFrom:
secretKeyRef:
name: {{ .Release.Name }}-secrets
key: s3accesskeyid
- name: AWS_SECRET_ACCESS_KEY
valueFrom:
secretKeyRef:
name: {{ .Release.Name }}-secrets
key: s3accesskey
{{- if .Values.extraInit.s3modelpath }}
- name: S3_PATH
value: "{{ .Values.extraInit.s3modelpath }}"
{{- end }}
{{- if hasKey .Values.extraInit "awsEc2MetadataDisabled" }}
- name: AWS_EC2_METADATA_DISABLED
value: "{{ .Values.extraInit.awsEc2MetadataDisabled }}"
{{- end }}
{{- end }}
{{/*
Define chart labels
*/}}
{{- define "chart.labels" -}}
{{- with .Values.labels -}}
{{ toYaml . }}
{{- end }}
{{- end }}

View File

@@ -0,0 +1,11 @@
{{- if .Values.configs -}}
apiVersion: v1
kind: ConfigMap
metadata:
name: "{{ .Release.Name }}-configs"
namespace: {{ .Release.Namespace }}
data:
{{- with .Values.configs }}
{{- toYaml . | nindent 2 }}
{{- end }}
{{- end -}}

View File

@@ -0,0 +1,6 @@
{{- if .Values.customObjects }}
{{- range .Values.customObjects }}
{{- tpl (. | toYaml) $ }}
---
{{- end }}
{{- end }}

View File

@@ -0,0 +1,131 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: "{{ .Release.Name }}-deployment-vllm"
namespace: {{ .Release.Namespace }}
labels:
{{- include "chart.labels" . | nindent 4 }}
spec:
replicas: {{ .Values.replicaCount }}
{{- include "chart.strategy" . | nindent 2 }}
selector:
matchLabels:
environment: "test"
release: "test"
progressDeadlineSeconds: 1200
template:
metadata:
labels:
environment: "test"
release: "test"
spec:
containers:
- name: "vllm"
image: "{{ required "Required value 'image.repository' must be defined !" .Values.image.repository }}:{{ required "Required value 'image.tag' must be defined !" .Values.image.tag }}"
{{- if .Values.image.command }}
command :
{{- with .Values.image.command }}
{{- toYaml . | nindent 10 }}
{{- end }}
{{- end }}
securityContext:
{{- if .Values.image.securityContext }}
{{- with .Values.image.securityContext }}
{{- toYaml . | nindent 12 }}
{{- end }}
{{- else }}
runAsNonRoot: false
{{- include "chart.user" . | indent 12 }}
{{- end }}
imagePullPolicy: IfNotPresent
{{- if .Values.image.env }}
env :
{{- with .Values.image.env }}
{{- toYaml . | nindent 10 }}
{{- end }}
{{- else }}
env: []
{{- end }}
{{- if or .Values.externalConfigs .Values.configs .Values.secrets }}
envFrom:
{{- if .Values.configs }}
- configMapRef:
name: "{{ .Release.Name }}-configs"
{{- end }}
{{- if .Values.secrets}}
- secretRef:
name: "{{ .Release.Name }}-secrets"
{{- end }}
{{- include "chart.externalConfigs" . | nindent 12 }}
{{- end }}
ports:
- name: {{ include "chart.container-port-name" . }}
containerPort: {{ include "chart.container-port" . }}
{{- include "chart.extraPorts" . | nindent 12 }}
{{- include "chart.probes" . | indent 10 }}
resources: {{- include "chart.resources" . | nindent 12 }}
volumeMounts:
- name: {{ .Release.Name }}-storage
mountPath: /data
{{- with .Values.extraContainers }}
{{ toYaml . | nindent 8 }}
{{- end }}
{{- if and .Values.extraInit (or .Values.extraInit.modelDownload.enabled .Values.extraInit.initContainers) }}
initContainers:
{{- if .Values.extraInit.modelDownload.enabled }}
- name: wait-download-model
image: {{ .Values.extraInit.modelDownload.image.repository }}:{{ .Values.extraInit.modelDownload.image.tag }}
imagePullPolicy: {{ .Values.extraInit.modelDownload.image.pullPolicy }}
command: {{ .Values.extraInit.modelDownload.waitContainer.command | toJson }}
args:
{{- toYaml .Values.extraInit.modelDownload.waitContainer.args | nindent 10 }}
env:
{{- if .Values.extraInit.modelDownload.waitContainer.env }}
{{- toYaml .Values.extraInit.modelDownload.waitContainer.env | nindent 10 }}
{{- else }}
{{- include "chart.extraInitEnv" . | nindent 10 }}
{{- end }}
resources:
requests:
cpu: 200m
memory: 1Gi
limits:
cpu: 500m
memory: 2Gi
volumeMounts:
- name: {{ .Release.Name }}-storage
mountPath: /data
{{- end }}
{{- with .Values.extraInit.initContainers }}
{{- toYaml . | nindent 6 }}
{{- end }}
{{- end }}
volumes:
- name: {{ .Release.Name }}-storage
persistentVolumeClaim:
claimName: {{ .Release.Name }}-storage-claim
{{- with .Values.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.tolerations }}
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- if and (gt (int (index .Values.resources.requests "nvidia.com/gpu")) 0) (gt (int (index .Values.resources.limits "nvidia.com/gpu")) 0) }}
runtimeClassName: nvidia
affinity:
nodeAffinity:
requiredDuringSchedulingIgnoredDuringExecution:
nodeSelectorTerms:
- matchExpressions:
- key: nvidia.com/gpu.product
operator: In
{{- with .Values.gpuModels }}
values:
{{- toYaml . | nindent 20 }}
{{- end }}
{{- end }}

View File

@@ -0,0 +1,31 @@
{{- if .Values.autoscaling.enabled }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: "{{ .Release.Name }}-hpa"
namespace: {{ .Release.Namespace }}
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: vllm
minReplicas: {{ .Values.autoscaling.minReplicas }}
maxReplicas: {{ .Values.autoscaling.maxReplicas }}
metrics:
{{- if .Values.autoscaling.targetCPUUtilizationPercentage }}
- type: Resource
resource:
name: cpu
target:
type: Utilization
averageUtilization: {{ .Values.autoscaling.targetCPUUtilizationPercentage }}
{{- end }}
{{- if .Values.autoscaling.targetMemoryUtilizationPercentage }}
- type: Resource
resource:
name: memory
target:
type: Utilization
averageUtilization: {{ .Values.autoscaling.targetMemoryUtilizationPercentage }}
{{- end }}
{{- end }}

View File

@@ -0,0 +1,41 @@
{{- if and .Values.extraInit .Values.extraInit.modelDownload.enabled }}
apiVersion: batch/v1
kind: Job
metadata:
name: "{{ .Release.Name }}-init-vllm"
namespace: {{ .Release.Namespace }}
spec:
ttlSecondsAfterFinished: 100
template:
metadata:
name: init-vllm
spec:
containers:
- name: job-download-model
image: {{ .Values.extraInit.modelDownload.image.repository }}:{{ .Values.extraInit.modelDownload.image.tag }}
imagePullPolicy: {{ .Values.extraInit.modelDownload.image.pullPolicy }}
command: {{ .Values.extraInit.modelDownload.downloadJob.command | toJson }}
args:
{{- toYaml .Values.extraInit.modelDownload.downloadJob.args | nindent 8 }}
env:
{{- if .Values.extraInit.modelDownload.downloadJob.env }}
{{- toYaml .Values.extraInit.modelDownload.downloadJob.env | nindent 8 }}
{{- else }}
{{- include "chart.extraInitEnv" . | nindent 8 }}
{{- end }}
volumeMounts:
- name: {{ .Release.Name }}-storage
mountPath: /data
resources:
requests:
cpu: 200m
memory: 1Gi
limits:
cpu: 500m
memory: 2Gi
restartPolicy: OnFailure
volumes:
- name: {{ .Release.Name }}-storage
persistentVolumeClaim:
claimName: "{{ .Release.Name }}-storage-claim"
{{- end }}

View File

@@ -0,0 +1,7 @@
apiVersion: policy/v1
kind: PodDisruptionBudget
metadata:
name: "{{ .Release.Name }}-pdb"
namespace: {{ .Release.Namespace }}
spec:
maxUnavailable: {{ default 1 .Values.maxUnavailablePodDisruptionBudget }}

View File

@@ -0,0 +1,13 @@
{{- if .Values.extraInit }}
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: "{{ .Release.Name }}-storage-claim"
namespace: {{ .Release.Namespace }}
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: {{ .Values.extraInit.pvcStorage }}
{{- end }}

View File

@@ -0,0 +1,10 @@
apiVersion: v1
kind: Secret
metadata:
name: "{{ .Release.Name }}-secrets"
namespace: {{ .Release.Namespace }}
type: Opaque
data:
{{- range $key, $val := .Values.secrets }}
{{ $key }}: {{ $val | b64enc | quote }}
{{- end }}

View File

@@ -0,0 +1,14 @@
apiVersion: v1
kind: Service
metadata:
name: "{{ .Release.Name }}-service"
namespace: {{ .Release.Namespace }}
spec:
type: ClusterIP
ports:
- name: {{ include "chart.service-port-name" . }}
port: {{ include "chart.service-port" . }}
targetPort: {{ include "chart.container-port-name" . }}
protocol: TCP
selector:
{{- include "chart.labels" . | nindent 4 }}

View File

@@ -0,0 +1,135 @@
suite: test deployment
templates:
- deployment.yaml
tests:
- it: should create wait-download-model init container when modelDownload is enabled
set:
extraInit:
modelDownload:
enabled: true
image:
repository: "amazon/aws-cli"
tag: "2.6.4"
pullPolicy: "IfNotPresent"
waitContainer:
command: [ "/bin/bash" ]
args:
- "-eucx"
- "while aws --endpoint-url $S3_ENDPOINT_URL s3 sync --dryrun s3://$S3_BUCKET_NAME/$S3_PATH /data | grep -q download; do sleep 10; done"
downloadJob:
command: [ "/bin/bash" ]
args:
- "-eucx"
- "aws --endpoint-url $S3_ENDPOINT_URL s3 sync s3://$S3_BUCKET_NAME/$S3_PATH /data"
initContainers: [ ]
pvcStorage: "1Gi"
s3modelpath: "relative_s3_model_path/opt-125m"
awsEc2MetadataDisabled: true
asserts:
- hasDocuments:
count: 1
- isKind:
of: Deployment
- isNotEmpty:
path: spec.template.spec.initContainers
- equal:
path: spec.template.spec.initContainers[0].name
value: wait-download-model
- equal:
path: spec.template.spec.initContainers[0].image
value: amazon/aws-cli:2.6.4
- equal:
path: spec.template.spec.initContainers[0].imagePullPolicy
value: IfNotPresent
- it: should only create custom init containers when modelDownload is disabled
set:
extraInit:
modelDownload:
enabled: false
image:
repository: "amazon/aws-cli"
tag: "2.6.4"
pullPolicy: "IfNotPresent"
waitContainer:
command: [ "/bin/bash" ]
args: [ "-c", "echo test" ]
downloadJob:
command: [ "/bin/bash" ]
args: [ "-c", "echo test" ]
initContainers:
- name: llm-d-routing-proxy
image: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0
imagePullPolicy: IfNotPresent
ports:
- containerPort: 8080
name: proxy
pvcStorage: "10Gi"
asserts:
- hasDocuments:
count: 1
- isKind:
of: Deployment
- lengthEqual:
path: spec.template.spec.initContainers
count: 1
- equal:
path: spec.template.spec.initContainers[0].name
value: llm-d-routing-proxy
- equal:
path: spec.template.spec.initContainers[0].image
value: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0
- equal:
path: spec.template.spec.initContainers[0].ports[0].containerPort
value: 8080
- it: should create both wait-download-model and custom init containers when both are enabled
set:
extraInit:
modelDownload:
enabled: true
image:
repository: "amazon/aws-cli"
tag: "2.6.4"
pullPolicy: "IfNotPresent"
waitContainer:
command: [ "/bin/bash" ]
args:
- "-eucx"
- "while aws --endpoint-url $S3_ENDPOINT_URL s3 sync --dryrun s3://$S3_BUCKET_NAME/$S3_PATH /data | grep -q download; do sleep 10; done"
downloadJob:
command: [ "/bin/bash" ]
args:
- "-eucx"
- "aws --endpoint-url $S3_ENDPOINT_URL s3 sync s3://$S3_BUCKET_NAME/$S3_PATH /data"
initContainers:
- name: llm-d-routing-proxy
image: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0
imagePullPolicy: IfNotPresent
ports:
- containerPort: 8080
name: proxy
pvcStorage: "10Gi"
asserts:
- hasDocuments:
count: 1
- isKind:
of: Deployment
- lengthEqual:
path: spec.template.spec.initContainers
count: 2
- equal:
path: spec.template.spec.initContainers[0].name
value: wait-download-model
- equal:
path: spec.template.spec.initContainers[0].image
value: amazon/aws-cli:2.6.4
- equal:
path: spec.template.spec.initContainers[1].name
value: llm-d-routing-proxy
- equal:
path: spec.template.spec.initContainers[1].image
value: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0
- equal:
path: spec.template.spec.initContainers[1].ports[0].containerPort
value: 8080

View File

@@ -0,0 +1,61 @@
suite: test job
templates:
- job.yaml
tests:
- it: should create job when modelDownload is enabled
set:
extraInit:
modelDownload:
enabled: true
image:
repository: "amazon/aws-cli"
tag: "2.6.4"
pullPolicy: "IfNotPresent"
waitContainer:
command: [ "/bin/bash" ]
args: [ "-c", "wait" ]
downloadJob:
command: [ "/bin/bash" ]
args:
- "-eucx"
- "aws --endpoint-url $S3_ENDPOINT_URL s3 sync s3://$S3_BUCKET_NAME/$S3_PATH /data"
pvcStorage: "1Gi"
s3modelpath: "relative_s3_model_path/opt-125m"
awsEc2MetadataDisabled: true
asserts:
- hasDocuments:
count: 1
- isKind:
of: Job
- equal:
path: spec.template.spec.containers[0].name
value: job-download-model
- equal:
path: spec.template.spec.containers[0].image
value: amazon/aws-cli:2.6.4
- equal:
path: spec.template.spec.restartPolicy
value: OnFailure
- it: should not create job when modelDownload is disabled
set:
extraInit:
modelDownload:
enabled: false
image:
repository: "amazon/aws-cli"
tag: "2.6.4"
pullPolicy: "IfNotPresent"
waitContainer:
command: [ "/bin/bash" ]
args: [ "-c", "wait" ]
downloadJob:
command: [ "/bin/bash" ]
args: [ "-c", "download" ]
initContainers:
- name: llm-d-routing-proxy
image: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0
pvcStorage: "10Gi"
asserts:
- hasDocuments:
count: 0

View File

@@ -0,0 +1,32 @@
suite: test pvc
templates:
- pvc.yaml
tests:
# Test Case: PVC Created When extraInit Defined
- it: should create pvc when extraInit is defined
set:
extraInit:
modelDownload:
enabled: true
image:
repository: "amazon/aws-cli"
tag: "2.6.4"
pullPolicy: "IfNotPresent"
waitContainer:
command: ["/bin/bash"]
args: ["-c", "wait"]
downloadJob:
command: ["/bin/bash"]
args: ["-c", "download"]
pvcStorage: "10Gi"
asserts:
- hasDocuments:
count: 1
- isKind:
of: PersistentVolumeClaim
- equal:
path: spec.accessModes[0]
value: ReadWriteOnce
- equal:
path: spec.resources.requests.storage
value: 10Gi

View File

@@ -0,0 +1,329 @@
{
"$schema": "http://json-schema.org/schema#",
"type": "object",
"properties": {
"image": {
"type": "object",
"properties": {
"repository": {
"type": "string"
},
"tag": {
"type": "string"
},
"command": {
"type": "array",
"items": {
"type": "string"
}
}
},
"required": [
"command",
"repository",
"tag"
]
},
"containerPort": {
"type": "integer"
},
"serviceName": {
"type": "null"
},
"servicePort": {
"type": "integer"
},
"extraPorts": {
"type": "array"
},
"replicaCount": {
"type": "integer"
},
"deploymentStrategy": {
"type": "object"
},
"resources": {
"type": "object",
"properties": {
"requests": {
"type": "object",
"properties": {
"cpu": {
"type": "integer"
},
"memory": {
"type": "string"
},
"nvidia.com/gpu": {
"type": "integer"
}
},
"required": [
"cpu",
"memory",
"nvidia.com/gpu"
]
},
"limits": {
"type": "object",
"properties": {
"cpu": {
"type": "integer"
},
"memory": {
"type": "string"
},
"nvidia.com/gpu": {
"type": "integer"
}
},
"required": [
"cpu",
"memory",
"nvidia.com/gpu"
]
}
},
"required": [
"limits",
"requests"
]
},
"gpuModels": {
"type": "array",
"items": {
"type": "string"
}
},
"autoscaling": {
"type": "object",
"properties": {
"enabled": {
"type": "boolean"
},
"minReplicas": {
"type": "integer"
},
"maxReplicas": {
"type": "integer"
},
"targetCPUUtilizationPercentage": {
"type": "integer"
}
},
"required": [
"enabled",
"maxReplicas",
"minReplicas",
"targetCPUUtilizationPercentage"
]
},
"configs": {
"type": "object"
},
"secrets": {
"type": "object"
},
"externalConfigs": {
"type": "array"
},
"customObjects": {
"type": "array"
},
"maxUnavailablePodDisruptionBudget": {
"type": "string"
},
"extraInit": {
"type": "object",
"properties": {
"modelDownload": {
"type": "object",
"properties": {
"enabled": {
"type": "boolean"
},
"image": {
"type": "object",
"properties": {
"repository": {
"type": "string"
},
"tag": {
"type": "string"
},
"pullPolicy": {
"type": "string"
}
},
"required": ["repository", "tag", "pullPolicy"]
},
"waitContainer": {
"type": "object",
"properties": {
"command": {
"type": "array",
"items": {"type": "string"}
},
"args": {
"type": "array",
"items": {"type": "string"}
},
"env": {
"type": "array",
"items": {"type": "object"}
}
},
"required": ["command", "args"]
},
"downloadJob": {
"type": "object",
"properties": {
"command": {
"type": "array",
"items": {"type": "string"}
},
"args": {
"type": "array",
"items": {"type": "string"}
},
"env": {
"type": "array",
"items": {"type": "object"}
}
},
"required": ["command", "args"]
}
},
"required": ["enabled", "image", "waitContainer", "downloadJob"]
},
"initContainers": {
"type": "array",
"items": {"type": "object"}
},
"s3modelpath": {
"type": "string"
},
"pvcStorage": {
"type": "string"
},
"awsEc2MetadataDisabled": {
"type": "boolean"
}
},
"required": [
"modelDownload",
"initContainers",
"pvcStorage"
]
},
"extraContainers": {
"type": "array"
},
"readinessProbe": {
"type": "object",
"properties": {
"initialDelaySeconds": {
"type": "integer"
},
"periodSeconds": {
"type": "integer"
},
"failureThreshold": {
"type": "integer"
},
"httpGet": {
"type": "object",
"properties": {
"path": {
"type": "string"
},
"port": {
"type": "integer"
}
},
"required": [
"path",
"port"
]
}
},
"required": [
"failureThreshold",
"httpGet",
"initialDelaySeconds",
"periodSeconds"
]
},
"livenessProbe": {
"type": "object",
"properties": {
"initialDelaySeconds": {
"type": "integer"
},
"failureThreshold": {
"type": "integer"
},
"periodSeconds": {
"type": "integer"
},
"httpGet": {
"type": "object",
"properties": {
"path": {
"type": "string"
},
"port": {
"type": "integer"
}
},
"required": [
"path",
"port"
]
}
},
"required": [
"failureThreshold",
"httpGet",
"initialDelaySeconds",
"periodSeconds"
]
},
"labels": {
"type": "object",
"properties": {
"environment": {
"type": "string"
},
"release": {
"type": "string"
}
},
"required": [
"environment",
"release"
]
}
},
"required": [
"autoscaling",
"configs",
"containerPort",
"customObjects",
"deploymentStrategy",
"externalConfigs",
"extraContainers",
"extraInit",
"extraPorts",
"gpuModels",
"image",
"labels",
"livenessProbe",
"maxUnavailablePodDisruptionBudget",
"readinessProbe",
"replicaCount",
"resources",
"secrets",
"servicePort"
]
}

View File

@@ -0,0 +1,174 @@
# -- Default values for chart vllm
# -- Declare variables to be passed into your templates.
# -- Image configuration
image:
# -- Image repository
repository: "vllm/vllm-openai"
# -- Image tag
tag: "latest"
# -- Container launch command
command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--enforce-eager", "--dtype", "bfloat16", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"]
# -- Container port
containerPort: 8000
# -- Service name
serviceName:
# -- Service port
servicePort: 80
# -- Additional ports configuration
extraPorts: []
# -- Number of replicas
replicaCount: 1
# -- Deployment strategy configuration
deploymentStrategy: {}
# -- Resource configuration
resources:
requests:
# -- Number of CPUs
cpu: 4
# -- CPU memory configuration
memory: 16Gi
# -- Number of gpus used
nvidia.com/gpu: 1
limits:
# -- Number of CPUs
cpu: 4
# -- CPU memory configuration
memory: 16Gi
# -- Number of gpus used
nvidia.com/gpu: 1
# -- Type of gpu used
gpuModels:
- "TYPE_GPU_USED"
# -- Autoscaling configuration
autoscaling:
# -- Enable autoscaling
enabled: false
# -- Minimum replicas
minReplicas: 1
# -- Maximum replicas
maxReplicas: 100
# -- Target CPU utilization for autoscaling
targetCPUUtilizationPercentage: 80
# targetMemoryUtilizationPercentage: 80
# -- Configmap
configs: {}
# -- Secrets configuration
secrets: {}
# -- External configuration
externalConfigs: []
# -- Custom Objects configuration
customObjects: []
# -- Disruption Budget Configuration
maxUnavailablePodDisruptionBudget: ""
# -- Additional configuration for the init container
extraInit:
# -- Model download functionality (optional)
modelDownload:
# -- Enable model download job and wait container
enabled: true
# -- Image configuration for model download operations
image:
# -- Image repository
repository: "amazon/aws-cli"
# -- Image tag
tag: "2.6.4"
# -- Image pull policy
pullPolicy: "IfNotPresent"
# -- Wait container configuration (init container that waits for model to be ready)
waitContainer:
# -- Command to execute
command: ["/bin/bash"]
# -- Arguments for the wait container
args:
- "-eucx"
- "while aws --endpoint-url $S3_ENDPOINT_URL s3 sync --dryrun s3://$S3_BUCKET_NAME/$S3_PATH /data | grep -q download; do sleep 10; done"
# -- Environment variables (optional, overrides S3 defaults entirely if specified)
# env:
# - name: HUGGING_FACE_HUB_TOKEN
# value: "your-token"
# - name: MODEL_ID
# value: "meta-llama/Llama-2-7b"
# -- Download job configuration (job that actually downloads the model)
downloadJob:
# -- Command to execute
command: ["/bin/bash"]
# -- Arguments for the download job
args:
- "-eucx"
- "aws --endpoint-url $S3_ENDPOINT_URL s3 sync s3://$S3_BUCKET_NAME/$S3_PATH /data"
# -- Environment variables (optional, overrides S3 defaults entirely if specified)
# env:
# - name: HUGGING_FACE_HUB_TOKEN
# value: "your-token"
# - name: MODEL_ID
# value: "meta-llama/Llama-2-7b"
# -- Custom init containers (appended after wait-download-model if modelDownload is enabled)
initContainers: []
# Example for llm-d sidecar:
# initContainers:
# - name: llm-d-routing-proxy
# image: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0
# imagePullPolicy: IfNotPresent
# ports:
# - containerPort: 8080
# name: proxy
# securityContext:
# runAsUser: 1000
# -- Path of the model on the s3 which hosts model weights and config files
s3modelpath: "relative_s3_model_path/opt-125m"
# -- Storage size for the PVC
pvcStorage: "1Gi"
# -- Disable AWS EC2 metadata service
awsEc2MetadataDisabled: true
# -- Additional containers configuration
extraContainers: []
# -- Readiness probe configuration
readinessProbe:
# -- Number of seconds after the container has started before readiness probe is initiated
initialDelaySeconds: 5
# -- How often (in seconds) to perform the readiness probe
periodSeconds: 5
# -- Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not ready
failureThreshold: 3
# -- Configuration of the Kubelet http request on the server
httpGet:
# -- Path to access on the HTTP server
path: /health
# -- Name or number of the port to access on the container, on which the server is listening
port: 8000
# -- Liveness probe configuration
livenessProbe:
# -- Number of seconds after the container has started before liveness probe is initiated
initialDelaySeconds: 15
# -- Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not alive
failureThreshold: 3
# -- How often (in seconds) to perform the liveness probe
periodSeconds: 10
# -- Configuration of the Kubelet http request on the server
httpGet:
# -- Path to access on the HTTP server
path: /health
# -- Name or number of the port to access on the container, on which the server is listening
port: 8000
labels:
environment: "test"
release: "test"

View File

@@ -0,0 +1,87 @@
# Monitoring Dashboards
This directory contains monitoring dashboard configurations for vLLM, providing
comprehensive observability for your vLLM deployments.
## Dashboard Platforms
We provide dashboards for two popular observability platforms:
- **[Grafana](https://grafana.com)**
- **[Perses](https://perses.dev)**
## Dashboard Format Approach
All dashboards are provided in **native formats** that work across different
deployment methods:
### Grafana (JSON)
- ✅ Works with any Grafana instance (cloud, self-hosted, Docker)
- ✅ Direct import via Grafana UI or API
- ✅ Can be wrapped in Kubernetes operators when needed
- ✅ No vendor lock-in or deployment dependencies
### Perses (YAML)
- ✅ Works with standalone Perses instances
- ✅ Compatible with Perses API and CLI
- ✅ Supports Dashboard-as-Code workflows
- ✅ Can be wrapped in Kubernetes operators when needed
## Dashboard Contents
Both platforms provide equivalent monitoring capabilities:
| Dashboard | Description |
| --------- | ----------- |
| **Performance Statistics** | Tracks latency, throughput, and performance metrics |
| **Query Statistics** | Monitors request volume, query performance, and KPIs |
## Quick Start
First, navigate to this example's directory:
```bash
cd examples/online_serving/dashboards
```
### Grafana
Import the JSON directly into the Grafana UI, or use the API:
```bash
curl -X POST http://grafana/api/dashboards/db \
-H "Content-Type: application/json" \
-d @grafana/performance_statistics.json
```
### Perses
Import via the Perses CLI:
```bash
percli apply -f perses/performance_statistics.yaml
```
## Requirements
- **Prometheus** metrics from your vLLM deployment
- **Data source** configured in your monitoring platform
- **vLLM metrics** enabled and accessible
## Platform-Specific Documentation
For detailed deployment instructions and platform-specific options, see:
- **[Grafana Documentation](./grafana)** - JSON dashboards, operator usage, manual import
- **[Perses Documentation](./perses)** - YAML specs, CLI usage, operator wrapping
## Contributing
When adding new dashboards, please:
1. Provide native formats (JSON for Grafana, YAML specs for Perses)
2. Update platform-specific README files
3. Ensure dashboards work across deployment methods
4. Test with the latest platform versions

View File

@@ -0,0 +1,59 @@
# Grafana Dashboards for vLLM Monitoring
This directory contains Grafana dashboard configurations (as JSON) designed to monitor
vLLM performance and metrics.
## Requirements
- Grafana 8.0+
- Prometheus data source configured in Grafana
- vLLM deployment with Prometheus metrics enabled
## Dashboard Descriptions
- **performance_statistics.json**: Tracks performance metrics including latency and
throughput for your vLLM service.
- **query_statistics.json**: Tracks query performance, request volume, and key
performance indicators for your vLLM service.
## Deployment Options
### Manual Import (Recommended)
The easiest way to use these dashboards is to manually import the JSON configurations
directly into your Grafana instance:
1. Navigate to your Grafana instance
2. Click the '+' icon in the sidebar
3. Select 'Import'
4. Copy and paste the JSON content from the dashboard files, or upload the JSON files
directly
### Grafana Operator
If you're using the [Grafana Operator](https://github.com/grafana-operator/grafana-operator)
in Kubernetes, you can wrap these JSON configurations in a `GrafanaDashboard` custom
resource:
```yaml
# Note: Adjust the instanceSelector to match your Grafana instance's labels
# You can check with: kubectl get grafana -o yaml
apiVersion: grafana.integreatly.org/v1beta1
kind: GrafanaDashboard
metadata:
name: vllm-performance-dashboard
spec:
instanceSelector:
matchLabels:
dashboards: grafana # Adjust to match your Grafana instance labels
folder: "vLLM Monitoring"
json: |
# Replace this comment with the complete JSON content from
# performance_statistics.json - The JSON should start with { and end with }
```
Then apply to your cluster:
```bash
kubectl apply -f your-dashboard.yaml -n <namespace>
```

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,760 @@
{
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": {
"type": "grafana",
"uid": "-- Grafana --"
},
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"type": "dashboard"
}
]
},
"description": "High-level overview of VLLM model deployment behavior and key performance indicators. Designed for Data Scientists and Product Managers to monitor request volume, token throughput, and latency",
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 0,
"id": 47,
"links": [],
"panels": [
{
"collapsed": true,
"gridPos": { "h": 1, "w": 24, "x": 0, "y": 0 },
"id": 20,
"panels": [],
"title": "Request Over Time",
"type": "row"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"barWidthFactor": 0.6,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": { "legend": false, "tooltip": false, "viz": false },
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": { "type": "linear" },
"showPoints": "auto",
"spanNulls": false,
"stacking": { "group": "A", "mode": "none" },
"thresholdsStyle": { "mode": "off" }
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
},
"unit": "req/s"
},
"overrides": []
},
"gridPos": { "h": 6, "w": 10, "x": 0, "y": 1 },
"id": 1,
"options": {
"legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "single", "sort": "none" }
},
"pluginVersion": "11.3.0",
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"editorMode": "code",
"expr": "sum by (model_name) (\n rate(vllm:request_success_total{model_name=~\"$Deployment_id\"}[$__rate_interval])\n)",
"interval": "1",
"legendFormat": "{{model_name}}",
"range": true,
"refId": "A"
}
],
"title": "Successful Requests Over Time",
"type": "timeseries"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "thresholds" },
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
},
"unit": "req/s"
},
"overrides": []
},
"gridPos": { "h": 3, "w": 7, "x": 10, "y": 1 },
"id": 2,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": { "calcs": ["mean"], "fields": "", "values": false },
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "11.3.0",
"targets": [
{
"editorMode": "code",
"expr": "sum(rate(vllm:request_success_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "Requests Avg Rate",
"type": "stat"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "thresholds" },
"mappings": [
{ "options": { "Calcultaions": { "index": 0, "text": "Last (not null)" } }, "type": "value" }
],
"thresholds": {
"mode": "absolute",
"steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
},
"unit": "ms"
},
"overrides": []
},
"gridPos": { "h": 3, "w": 7, "x": 17, "y": 1 },
"id": 3,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false },
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "11.3.0",
"targets": [
{
"editorMode": "code",
"expr": "histogram_quantile(0.50, sum by(le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "p50 Latency",
"type": "stat"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "thresholds" },
"mappings": [
{ "options": { "Calculation": { "index": 0, "text": "Last (not null)" } }, "type": "value" }
],
"thresholds": {
"mode": "absolute",
"steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
},
"unit": "ms"
},
"overrides": []
},
"gridPos": { "h": 3, "w": 7, "x": 10, "y": 4 },
"id": 4,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false },
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "11.3.0",
"targets": [
{
"editorMode": "code",
"expr": "histogram_quantile(0.90, sum by(le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "p90 Latency",
"type": "stat"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "thresholds" },
"mappings": [
{ "options": { "Calculation": { "index": 0, "text": "Last (not null)" } }, "type": "value" }
],
"thresholds": {
"mode": "absolute",
"steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
},
"unit": "ms"
},
"overrides": []
},
"gridPos": { "h": 3, "w": 7, "x": 17, "y": 4 },
"id": 5,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false },
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "11.3.0",
"targets": [
{
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by(le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "p99 Latency",
"type": "stat"
},
{
"collapsed": false,
"gridPos": { "h": 1, "w": 24, "x": 0, "y": 7 },
"id": 19,
"panels": [],
"title": "Size Distribution",
"type": "row"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"fillOpacity": 80,
"gradientMode": "none",
"hideFrom": { "legend": false, "tooltip": false, "viz": false },
"lineWidth": 1,
"stacking": { "group": "A", "mode": "none" }
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
},
"unit": "cps"
},
"overrides": []
},
"gridPos": { "h": 6, "w": 10, "x": 0, "y": 8 },
"id": 6,
"options": {
"legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "single", "sort": "none" }
},
"pluginVersion": "11.3.0",
"targets": [
{
"editorMode": "code",
"expr": "sum by (le, model_name) (rate(vllm:request_prompt_tokens_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval]))",
"legendFormat": "{{model_name}} le={{le}}",
"range": true,
"refId": "A"
}
],
"title": "Input Token Size Distribution",
"type": "histogram"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "thresholds" },
"mappings": [
{ "options": { "calculation ": { "index": 0, "text": "Last (not null)" } }, "type": "value" }
],
"thresholds": {
"mode": "absolute",
"steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
},
"unit": "cps"
},
"overrides": []
},
"gridPos": { "h": 3, "w": 7, "x": 10, "y": 8 },
"id": 9,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false },
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "11.3.0",
"targets": [
{
"editorMode": "code",
"expr": "histogram_quantile(0.90, sum by(le, model_name) (rate(vllm:request_prompt_tokens_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "Input Token Size p90",
"type": "stat"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "thresholds" },
"mappings": [
{ "options": { "Calculation": { "index": 0, "text": "Last (not null)" } }, "type": "value" }
],
"thresholds": {
"mode": "absolute",
"steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
},
"unit": "cps"
},
"overrides": []
},
"gridPos": { "h": 3, "w": 7, "x": 17, "y": 8 },
"id": 8,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false },
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "11.3.0",
"targets": [
{
"editorMode": "code",
"expr": "histogram_quantile(0.50, sum by(le, model_name) (rate(vllm:request_prompt_tokens_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "Input Token Size p50",
"type": "stat"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "thresholds" },
"mappings": [
{ "options": { "Calcultaion": { "index": 0, "text": "mean" } }, "type": "value" }
],
"thresholds": {
"mode": "absolute",
"steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
},
"unit": "cps"
},
"overrides": []
},
"gridPos": { "h": 3, "w": 7, "x": 10, "y": 11 },
"id": 7,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false },
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "11.3.0",
"targets": [
{
"editorMode": "code",
"expr": "sum(rate(vllm:prompt_tokens_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))\n/\nsum(rate(vllm:request_success_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "Input Token Size Avg",
"type": "stat"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "thresholds" },
"mappings": [
{ "options": { "Calculation": { "index": 0, "text": "Last (not null)" } }, "type": "value" }
],
"thresholds": {
"mode": "absolute",
"steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
},
"unit": "cps"
},
"overrides": []
},
"gridPos": { "h": 3, "w": 7, "x": 17, "y": 11 },
"id": 10,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false },
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "11.3.0",
"targets": [
{
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by(le, model_name) (rate(vllm:request_prompt_tokens_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "Input Token Size p99",
"type": "stat"
},
{
"collapsed": true,
"gridPos": { "h": 1, "w": 24, "x": 0, "y": 14 },
"id": 18,
"panels": [],
"title": "Input Token Over Time",
"type": "row"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"barWidthFactor": 0.6,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": { "legend": false, "tooltip": false, "viz": false },
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": { "type": "linear" },
"showPoints": "auto",
"spanNulls": false,
"stacking": { "group": "A", "mode": "none" },
"thresholdsStyle": { "mode": "off" }
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
},
"unit": "cps"
},
"overrides": []
},
"gridPos": { "h": 6, "w": 10, "x": 0, "y": 15 },
"id": 11,
"options": {
"legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "single", "sort": "none" }
},
"pluginVersion": "11.3.0",
"targets": [
{
"editorMode": "code",
"expr": "sum by (model_name) (rate(vllm:prompt_tokens_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))",
"legendFormat": "{{model_name}}",
"range": true,
"refId": "A"
}
],
"title": "Input Tokens Over Time",
"type": "timeseries"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "thresholds" },
"mappings": [
{ "options": { "Calculation": { "index": 0, "text": "mean" } }, "type": "value" }
],
"thresholds": {
"mode": "absolute",
"steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
},
"unit": "cps"
},
"overrides": []
},
"gridPos": { "h": 3, "w": 7, "x": 10, "y": 15 },
"id": 12,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false },
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "11.3.0",
"targets": [
{
"editorMode": "code",
"expr": "sum(rate(vllm:prompt_tokens_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "Input Tokens/Sec Avg",
"type": "stat"
},
{
"collapsed": false,
"gridPos": { "h": 1, "w": 24, "x": 0, "y": 21 },
"id": 17,
"panels": [],
"title": "Output Token Over Time",
"type": "row"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"barWidthFactor": 0.6,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": { "legend": false, "tooltip": false, "viz": false },
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": { "type": "linear" },
"showPoints": "auto",
"spanNulls": false,
"stacking": { "group": "A", "mode": "none" },
"thresholdsStyle": { "mode": "off" }
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
},
"unit": "cps"
},
"overrides": []
},
"gridPos": { "h": 6, "w": 10, "x": 0, "y": 22 },
"id": 13,
"options": {
"legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true },
"tooltip": { "mode": "single", "sort": "none" }
},
"pluginVersion": "11.3.0",
"targets": [
{
"editorMode": "code",
"expr": "sum by (model_name) (rate(vllm:generation_tokens_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))",
"legendFormat": "{{model_name}}",
"range": true,
"refId": "A"
}
],
"title": "Output Tokens Over Time",
"type": "timeseries"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "thresholds" },
"mappings": [
{ "options": { "Calculation": { "index": 0, "text": "mean" } }, "type": "value" }
],
"thresholds": {
"mode": "absolute",
"steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
},
"unit": "cps"
},
"overrides": []
},
"gridPos": { "h": 3, "w": 7, "x": 10, "y": 22 },
"id": 14,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false },
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "11.3.0",
"targets": [
{
"editorMode": "code",
"expr": "sum(rate(vllm:generation_tokens_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "Output Tokens/Sec Avg",
"type": "stat"
}
],
"preload": false,
"schemaVersion": 40,
"tags": [],
"templating": {
"list": [
{
"current": { "text": "Prometheus", "value": "4184fc20-68a7-483a-8d9b-7caa59c680dd" },
"label": "datasource",
"name": "DS_PROMETHEUS",
"options": [],
"query": "prometheus",
"refresh": 1,
"type": "datasource"
},
{
"current": { "text": ["All"], "value": ["$__all"] },
"definition": "label_values(vllm:request_success_total,model_name)",
"includeAll": true,
"label": "Deployment_ID",
"multi": true,
"name": "Deployment_id",
"options": [],
"query": {
"qryType": 1,
"query": "label_values(vllm:request_success_total,model_name)",
"refId": "PrometheusVariableQueryEditor-VariableQuery"
},
"refresh": 1,
"regex": "",
"sort": 1,
"type": "query"
},
{
"current": { "text": "All hours", "value": "All hours" },
"hide": 2,
"label": "Rush Hours Only",
"name": "rush_hours",
"options": [
{ "selected": true, "text": "false", "value": "All hours" },
{ "selected": false, "text": "true", "value": "Rush hours" }
],
"query": "false : All hours, true : Rush hours",
"type": "custom"
},
{
"current": { "text": "All", "value": "All" },
"hide": 2,
"label": "Rush Hours Type",
"name": "rush_hours_type",
"options": [
{ "selected": true, "text": "^All__.*$", "value": "All" },
{ "selected": false, "text": "^Static__.*$", "value": "Static" },
{ "selected": false, "text": "^Dynamic__.*$", "value": "Dynamic" }
],
"query": "^All__.*$ : All, ^Static__.*$ : Static, ^Dynamic__.*$ : Dynamic",
"type": "custom"
},
{
"current": { "text": "", "value": "" },
"hide": 2,
"name": "query0",
"options": [],
"query": "",
"refresh": 1,
"regex": "",
"type": "query"
}
]
},
"time": { "from": "now-12h", "to": "now" },
"timepicker": {},
"timezone": "browser",
"title": "Query Statistics_New4",
"uid": "query-statistics4",
"version": 2,
"weekStart": ""
}

View File

@@ -0,0 +1,48 @@
# Perses Dashboards for vLLM Monitoring
This directory contains Perses dashboard configurations designed to monitor vLLM
performance and metrics.
## Requirements
- Perses instance (standalone or via operator)
- Prometheus data source configured in Perses
- vLLM deployment with Prometheus metrics enabled
## Dashboard Format
We provide dashboards in the **native Perses YAML format** that works across all
deployment methods:
- **Files**: `*.yaml` (native Perses dashboard specifications)
- **Format**: Pure dashboard specifications that work everywhere
- **Usage**: Works with standalone Perses, API imports, CLI, and file provisioning
- **Kubernetes**: Directly compatible with Perses Operator
## Dashboard Descriptions
- **performance_statistics.yaml**: Performance metrics with aggregated latency
statistics
- **query_statistics.yaml**: Query performance and deployment metrics
## Deployment Options
### Direct Import to Perses
Import the dashboard specifications via Perses API or CLI:
```bash
percli apply -f performance_statistics.yaml
```
### Perses Operator (Kubernetes)
The native YAML format works directly with the Perses Operator:
```bash
kubectl apply -f performance_statistics.yaml -n <namespace>
```
### File Provisioning
Place the YAML files in a Perses provisioning folder for automatic loading.

View File

@@ -0,0 +1,764 @@
kind: PersesDashboard
metadata:
name: performance-statistics
createdAt: 0001-01-01T00:00:00Z
updatedAt: 0001-01-01T00:00:00Z
version: 0
project: ""
spec:
display:
name: Performance Statistics
variables:
- kind: ListVariable
spec:
display:
name: Deployment_ID
hidden: false
name: Deployment_id
allowAllValue: true
allowMultiple: true
defaultValue:
- $__all
sort: alphabetical-asc
plugin:
kind: PrometheusLabelValuesVariable
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
labelName: model_name
matchers:
# Any one vllm metric that always carries model_name
- vllm:generation_tokens_total{}
panels:
"1":
kind: Panel
spec:
display:
name: E2E Latency over Time
plugin:
kind: TimeSeriesChart
spec:
legend:
mode: table
position: bottom
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
# avg latency by model = sum(rate(sum)) / sum(rate(count))
query: >
sum by (model_name) (rate(vllm:e2e_request_latency_seconds_sum{model_name=~"$Deployment_id"}[$__interval]))
/
sum by (model_name) (rate(vllm:e2e_request_latency_seconds_count{model_name=~"$Deployment_id"}[$__interval]))
seriesNameFormat: '{{model_name}}'
"2":
kind: Panel
spec:
display:
name: E2E Latency (Avg)
plugin:
kind: StatChart
spec:
calculation: last-number
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
(sum by (model_name) (increase(vllm:e2e_request_latency_seconds_sum{model_name=~"$Deployment_id"}[$__range])))
/
(sum by (model_name) (increase(vllm:e2e_request_latency_seconds_count{model_name=~"$Deployment_id"}[$__range])))
"3":
kind: Panel
spec:
display:
name: E2E Latency (P50)
plugin:
kind: StatChart
spec:
calculation: last-number
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
histogram_quantile(
0.50,
sum by (le, model_name) (
rate(vllm:e2e_request_latency_seconds_bucket{model_name=~"$Deployment_id"}[$__interval])
)
)
"4":
kind: Panel
spec:
display:
name: E2E Latency (P90)
plugin:
kind: StatChart
spec:
calculation: last-number
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
histogram_quantile(
0.90,
sum by (le, model_name) (
rate(vllm:e2e_request_latency_seconds_bucket{model_name=~"$Deployment_id"}[$__interval])
)
)
"5":
kind: Panel
spec:
display:
name: E2E Latency (P99)
plugin:
kind: StatChart
spec:
calculation: last-number
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
histogram_quantile(
0.99,
sum by (le, model_name) (
rate(vllm:e2e_request_latency_seconds_bucket{model_name=~"$Deployment_id"}[$__interval])
)
)
"6":
kind: Panel
spec:
display:
name: TTFT over Time
plugin:
kind: TimeSeriesChart
spec:
legend:
mode: table
position: bottom
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
sum by (model_name) (rate(vllm:time_to_first_token_seconds_sum{model_name=~"$Deployment_id"}[$__interval]))
/
sum by (model_name) (rate(vllm:time_to_first_token_seconds_count{model_name=~"$Deployment_id"}[$__interval]))
seriesNameFormat: '{{model_name}}'
"7":
kind: Panel
spec:
display:
name: TTFT (Avg)
plugin:
kind: StatChart
spec:
calculation: last-number
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
(sum by (model_name) (increase(vllm:time_to_first_token_seconds_sum{model_name=~"$Deployment_id"}[$__range])))
/
(sum by (model_name) (increase(vllm:time_to_first_token_seconds_count{model_name=~"$Deployment_id"}[$__range])))
"8":
kind: Panel
spec:
display:
name: TTFT (P50)
plugin:
kind: StatChart
spec:
calculation: last-number
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
histogram_quantile(
0.50,
sum by (le, model_name) (
rate(vllm:time_to_first_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval])
)
)
"9":
kind: Panel
spec:
display:
name: TTFT (P90)
plugin:
kind: StatChart
spec:
calculation: last-number
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
histogram_quantile(
0.90,
sum by (le, model_name) (
rate(vllm:time_to_first_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval])
)
)
"10":
kind: Panel
spec:
display:
name: TTFT (P99)
plugin:
kind: StatChart
spec:
calculation: last-number
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
histogram_quantile(
0.99,
sum by (le, model_name) (
rate(vllm:time_to_first_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval])
)
)
"11":
kind: Panel
spec:
display:
name: ITL (Time per Output Token) over Time
plugin:
kind: TimeSeriesChart
spec:
legend:
mode: table
position: bottom
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
sum by (model_name) (rate(vllm:inter_token_latency_seconds_sum{model_name=~"$Deployment_id"}[$__interval]))
/
sum by (model_name) (rate(vllm:inter_token_latency_seconds_count{model_name=~"$Deployment_id"}[$__interval]))
seriesNameFormat: '{{model_name}}'
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
histogram_quantile(
0.50,
sum by (le, model_name) (
rate(vllm:inter_token_latency_seconds_bucket{model_name=~"$Deployment_id"}[$__interval])
)
)
seriesNameFormat: '{{model_name}} p50'
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
histogram_quantile(
0.90,
sum by (le, model_name) (
rate(vllm:inter_token_latency_seconds_bucket{model_name=~"$Deployment_id"}[$__interval])
)
)
seriesNameFormat: '{{model_name}} p90'
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
histogram_quantile(
0.99,
sum by (le, model_name) (
rate(vllm:inter_token_latency_seconds_bucket{model_name=~"$Deployment_id"}[$__interval])
)
)
seriesNameFormat: '{{model_name}} p99'
"12":
kind: Panel
spec:
display:
name: ITL (Avg)
plugin:
kind: StatChart
spec:
calculation: last-number
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
(sum by (model_name) (increase(vllm:inter_token_latency_seconds_sum{model_name=~"$Deployment_id"}[$__range])))
/
(sum by (model_name) (increase(vllm:inter_token_latency_seconds_count{model_name=~"$Deployment_id"}[$__range])))
"13":
kind: Panel
spec:
display:
name: ITL (P50)
plugin:
kind: StatChart
spec:
calculation: last-number
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
histogram_quantile(
0.50,
sum by (le, model_name) (
rate(vllm:inter_token_latency_seconds_bucket{model_name=~"$Deployment_id"}[$__interval])
)
)
"14":
kind: Panel
spec:
display:
name: ITL (P90)
plugin:
kind: StatChart
spec:
calculation: last-number
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
histogram_quantile(
0.90,
sum by (le, model_name) (
rate(vllm:inter_token_latency_seconds_bucket{model_name=~"$Deployment_id"}[$__interval])
)
)
"15":
kind: Panel
spec:
display:
name: ITL (P99)
plugin:
kind: StatChart
spec:
calculation: last-number
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
histogram_quantile(
0.99,
sum by (le, model_name) (
rate(vllm:inter_token_latency_seconds_bucket{model_name=~"$Deployment_id"}[$__interval])
)
)
"16":
kind: Panel
spec:
display:
name: TPS (Tokens/sec) over Time
plugin:
kind: TimeSeriesChart
spec:
legend:
mode: table
position: bottom
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
sum by (model_name) (rate(vllm:generation_tokens_total{model_name=~"$Deployment_id"}[$__interval]))
seriesNameFormat: '{{model_name}} generation'
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
sum by (model_name) (rate(vllm:prompt_tokens_total{model_name=~"$Deployment_id"}[$__interval]))
seriesNameFormat: '{{model_name}} prompt'
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
# overall iteration tokens/sec if exposed
query: >
rate(vllm:iteration_tokens_total_count[$__interval])
seriesNameFormat: 'iteration overall'
"17":
kind: Panel
spec:
display:
name: KV Cache Usage (avg %)
plugin:
kind: StatChart
spec:
calculation: last-number
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
# Multiply by 100 so we can read it as a percentage without setting a unit (avoids CUE unit conflicts)
query: >
100 * avg(vllm:kv_cache_usage_perc)
"18":
kind: Panel
spec:
display:
name: Running Requests by Pod
plugin:
kind: TimeSeriesChart
spec:
legend:
mode: table
position: bottom
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
sum by (pod) (vllm:num_requests_running)
seriesNameFormat: '{{pod}}'
"19":
kind: Panel
spec:
display:
name: Waiting Requests by Pod
plugin:
kind: TimeSeriesChart
spec:
legend:
mode: table
position: bottom
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: >
sum by (pod) (vllm:num_requests_waiting)
seriesNameFormat: '{{pod}}'
"20":
kind: Panel
spec:
display:
name: Running Requests (sum)
plugin:
kind: StatChart
spec:
calculation: last-number
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: sum(vllm:num_requests_running)
"21":
kind: Panel
spec:
display:
name: Waiting Requests (sum)
plugin:
kind: StatChart
spec:
calculation: last-number
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource:
kind: PrometheusDatasource
name: accelerators-thanos-querier-datasource
query: sum(vllm:num_requests_waiting)
layouts:
- kind: Grid
spec:
display:
title: Overview
items:
- x: 0
y: 0
width: 6
height: 3
content: { $ref: '#/spec/panels/17' } # KV cache %
- x: 6
y: 0
width: 6
height: 3
content: { $ref: '#/spec/panels/20' } # running sum
- x: 12
y: 0
width: 6
height: 3
content: { $ref: '#/spec/panels/21' } # waiting sum
- kind: Grid
spec:
display:
title: E2E Latency
items:
- x: 0
y: 1
width: 10
height: 6
content: { $ref: '#/spec/panels/1' }
- x: 10
y: 1
width: 7
height: 3
content: { $ref: '#/spec/panels/2' }
- x: 17
y: 1
width: 7
height: 3
content: { $ref: '#/spec/panels/3' }
- x: 10
y: 4
width: 7
height: 3
content: { $ref: '#/spec/panels/4' }
- x: 17
y: 4
width: 7
height: 3
content: { $ref: '#/spec/panels/5' }
- kind: Grid
spec:
display:
title: TTFT
items:
- x: 0
y: 8
width: 10
height: 6
content: { $ref: '#/spec/panels/6' }
- x: 10
y: 8
width: 7
height: 3
content: { $ref: '#/spec/panels/7' }
- x: 17
y: 8
width: 7
height: 3
content: { $ref: '#/spec/panels/8' }
- x: 10
y: 11
width: 7
height: 3
content: { $ref: '#/spec/panels/9' }
- x: 17
y: 11
width: 7
height: 3
content: { $ref: '#/spec/panels/10' }
- kind: Grid
spec:
display:
title: ITL (Time per Output Token)
items:
- x: 0
y: 15
width: 10
height: 6
content: { $ref: '#/spec/panels/11' }
- x: 10
y: 15
width: 7
height: 3
content: { $ref: '#/spec/panels/12' }
- x: 17
y: 15
width: 7
height: 3
content: { $ref: '#/spec/panels/13' }
- x: 10
y: 18
width: 7
height: 3
content: { $ref: '#/spec/panels/14' }
- x: 17
y: 18
width: 7
height: 3
content: { $ref: '#/spec/panels/15' }
- kind: Grid
spec:
display:
title: TPS (Prompt / Generation / Iteration)
items:
- x: 0
y: 22
width: 14
height: 6
content: { $ref: '#/spec/panels/16' }
- kind: Grid
spec:
display:
title: Per-Pod Request State
items:
- x: 0
y: 28
width: 12
height: 6
content: { $ref: '#/spec/panels/18' }
- x: 12
y: 28
width: 12
height: 6
content: { $ref: '#/spec/panels/19' }

View File

@@ -0,0 +1,392 @@
kind: PersesDashboard
metadata:
name: query-statistics
createdAt: 0001-01-01T00:00:00Z
updatedAt: 0001-01-01T00:00:00Z
version: 0
project: ""
spec:
display:
name: Query Statistics_New
variables:
- kind: ListVariable
spec:
name: NS
display: { name: Namespace }
allowMultiple: false
defaultValue: llm-d
plugin:
kind: PrometheusLabelValuesVariable
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
labelName: namespace
matchers:
- up{service=~".*vllm.*"}
- kind: ListVariable
spec:
name: SVC
display: { name: Service }
allowMultiple: false
defaultValue: vllm-qwen2-0-5b-sim
plugin:
kind: PrometheusLabelValuesVariable
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
labelName: service
matchers:
- up{namespace="$NS",service=~".*vllm.*"}
- kind: ListVariable
spec:
name: MODEL
display: { name: Model (real vLLM) }
allowAllValue: true
allowMultiple: true
defaultValue: ["$__all"]
plugin:
kind: PrometheusLabelValuesVariable
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
labelName: model_name
matchers:
- vllm:request_success_total{namespace="$NS",service="$SVC"}
panels:
# --- Core (works on Simulator & Real) ---
core_running_now:
kind: Panel
spec:
display: { name: Running Requests (now) }
plugin: { kind: StatChart, spec: { calculation: last-number } }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: sum(vllm:num_requests_running{namespace="$NS",service="$SVC"}) or vector(0)
minStep: "15s"
core_waiting_now:
kind: Panel
spec:
display: { name: Waiting Requests (now) }
plugin: { kind: StatChart, spec: { calculation: last-number } }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: sum(vllm:num_requests_waiting{namespace="$NS",service="$SVC"}) or vector(0)
minStep: "15s"
core_kv_usage_now:
kind: Panel
spec:
display: { name: KV Cache Usage (01) }
plugin: { kind: StatChart, spec: { calculation: last-number } }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: avg(vllm:kv_cache_usage_perc{namespace="$NS",service="$SVC"}) or vector(0)
minStep: "15s"
core_running_ts:
kind: Panel
spec:
display: { name: Running Over Time }
plugin:
kind: TimeSeriesChart
spec:
legend: { mode: table, position: bottom }
visual: { display: line, lineWidth: 1, areaOpacity: 0.3 }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: sum by (service) (vllm:num_requests_running{namespace="$NS",service="$SVC"}) or vector(0)
minStep: "15s"
core_waiting_ts:
kind: Panel
spec:
display: { name: Waiting Over Time }
plugin:
kind: TimeSeriesChart
spec:
legend: { mode: table, position: bottom }
visual: { display: line, lineWidth: 1, areaOpacity: 0.3 }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: sum by (service) (vllm:num_requests_waiting{namespace="$NS",service="$SVC"}) or vector(0)
minStep: "15s"
core_targets_up:
kind: Panel
spec:
display: { name: Scrape Targets Up }
plugin: { kind: StatChart, spec: { calculation: last-number } }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: count(up{namespace="$NS",service="$SVC"} == 1) or vector(0)
minStep: "15s"
# --- KV Cache as Percent (works on Simulator & Real) ---
core_kv_usage_pct_now:
kind: Panel
spec:
display: { name: KV Cache Usage (%) now }
plugin: { kind: StatChart, spec: { calculation: last-number } }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
# multiply by 100 to present percentage; omit format.unit to avoid schema conflicts
query: (avg(vllm:kv_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0)
minStep: "15s"
core_kv_usage_pct_ts:
kind: Panel
spec:
display: { name: KV Cache Usage (%) over time }
plugin:
kind: TimeSeriesChart
spec:
legend: { mode: table, position: bottom }
visual: { display: line, lineWidth: 1, areaOpacity: 0.3 }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: (avg by (service) (vllm:kv_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0)
minStep: "15s"
# --- Per-Pod breakdowns (works on Simulator & Real) ---
per_pod_running_ts:
kind: Panel
spec:
display: { name: Running by Pod }
plugin:
kind: TimeSeriesChart
spec:
legend: { mode: table, position: bottom }
visual: { display: line, lineWidth: 1, areaOpacity: 0.3 }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: sum by (pod) (vllm:num_requests_running{namespace="$NS",service="$SVC"}) or vector(0)
minStep: "15s"
per_pod_waiting_ts:
kind: Panel
spec:
display: { name: Waiting by Pod }
plugin:
kind: TimeSeriesChart
spec:
legend: { mode: table, position: bottom }
visual: { display: line, lineWidth: 1, areaOpacity: 0.3 }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: sum by (pod) (vllm:num_requests_waiting{namespace="$NS",service="$SVC"}) or vector(0)
minStep: "15s"
per_pod_kv_pct_ts:
kind: Panel
spec:
display: { name: KV Cache (%) by Pod }
plugin:
kind: TimeSeriesChart
spec:
legend: { mode: table, position: bottom }
visual: { display: line, lineWidth: 1, areaOpacity: 0.3 }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
# if your exporter labels kv metric with pod (the sim does), this works; otherwise it will just return empty
query: (avg by (pod) (vllm:kv_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0)
minStep: "15s"
# --- Real vLLM only (zeros on simulator) ---
real_req_rate_ts:
kind: Panel
spec:
display: { name: Request Rate (real vLLM) }
plugin:
kind: TimeSeriesChart
spec:
legend: { mode: table, position: bottom }
visual: { display: line, lineWidth: 1, areaOpacity: 0.3 }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: sum by (model_name) (rate(vllm:request_success_total{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval])) or vector(0)
minStep: "15s"
real_p50:
kind: Panel
spec:
display: { name: p50 Latency (real vLLM) }
plugin: { kind: StatChart, spec: { calculation: last-number } }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: histogram_quantile(0.50, sum by (le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval]))) or vector(0)
minStep: "15s"
real_p90:
kind: Panel
spec:
display: { name: p90 Latency (real vLLM) }
plugin: { kind: StatChart, spec: { calculation: last-number } }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: histogram_quantile(0.90, sum by (le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval]))) or vector(0)
minStep: "15s"
real_p99:
kind: Panel
spec:
display: { name: p99 Latency (real vLLM) }
plugin: { kind: StatChart, spec: { calculation: last-number } }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: histogram_quantile(0.99, sum by (le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval]))) or vector(0)
minStep: "15s"
real_input_tokens_ts:
kind: Panel
spec:
display: { name: Input Tokens / sec (real vLLM) }
plugin:
kind: TimeSeriesChart
spec:
legend: { mode: table, position: bottom }
visual: { display: line, lineWidth: 1, areaOpacity: 0.3 }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: sum by (model_name) (rate(vllm:prompt_tokens_total{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval])) or vector(0)
minStep: "15s"
real_output_tokens_ts:
kind: Panel
spec:
display: { name: Output Tokens / sec (real vLLM) }
plugin:
kind: TimeSeriesChart
spec:
legend: { mode: table, position: bottom }
visual: { display: line, lineWidth: 1, areaOpacity: 0.3 }
queries:
- kind: TimeSeriesQuery
spec:
plugin:
kind: PrometheusTimeSeriesQuery
spec:
datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource }
query: sum by (model_name) (rate(vllm:generation_tokens_total{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval])) or vector(0)
minStep: "15s"
layouts:
- kind: Grid
spec:
display: { title: Core (Sim & Real) }
items:
- { x: 0, y: 0, width: 6, height: 3, content: { $ref: '#/spec/panels/core_running_now' } }
- { x: 6, y: 0, width: 6, height: 3, content: { $ref: '#/spec/panels/core_waiting_now' } }
- { x: 12, y: 0, width: 6, height: 3, content: { $ref: '#/spec/panels/core_kv_usage_now' } }
- { x: 18, y: 0, width: 6, height: 3, content: { $ref: '#/spec/panels/core_targets_up' } }
- { x: 0, y: 3, width: 12, height: 6, content: { $ref: '#/spec/panels/core_running_ts' } }
- { x: 12, y: 3, width: 12, height: 6, content: { $ref: '#/spec/panels/core_waiting_ts' } }
- kind: Grid
spec:
display: { title: KV Cache (%) }
items:
- { x: 0, y: 9, width: 6, height: 3, content: { $ref: '#/spec/panels/core_kv_usage_pct_now' } }
- { x: 6, y: 9, width: 18, height: 6, content: { $ref: '#/spec/panels/core_kv_usage_pct_ts' } }
- kind: Grid
spec:
display: { title: Per-Pod breakdowns }
items:
- { x: 0, y: 15, width: 12, height: 6, content: { $ref: '#/spec/panels/per_pod_running_ts' } }
- { x: 12, y: 15, width: 12, height: 6, content: { $ref: '#/spec/panels/per_pod_waiting_ts' } }
- { x: 0, y: 21, width: 24, height: 6, content: { $ref: '#/spec/panels/per_pod_kv_pct_ts' } }
- kind: Grid
spec:
display: { title: Real vLLM only (shows 0 on simulator) }
items:
- { x: 0, y: 27, width: 12, height: 6, content: { $ref: '#/spec/panels/real_req_rate_ts' } }
- { x: 12, y: 27, width: 4, height: 3, content: { $ref: '#/spec/panels/real_p50' } }
- { x: 16, y: 27, width: 4, height: 3, content: { $ref: '#/spec/panels/real_p90' } }
- { x: 20, y: 27, width: 4, height: 3, content: { $ref: '#/spec/panels/real_p99' } }
- { x: 0, y: 33, width: 12, height: 6, content: { $ref: '#/spec/panels/real_input_tokens_ts' } }
- { x: 12, y: 33, width: 12, height: 6, content: { $ref: '#/spec/panels/real_output_tokens_ts' } }

View File

@@ -0,0 +1,135 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test pause/resume with Data Parallel (DP) via HTTP API.
This example demonstrates coordinated pause/resume across multiple DP ranks.
The pause synchronizes across all DP engines via all-reduce.
Prerequisites:
Start a vLLM server with data parallelism:
$ VLLM_SERVER_DEV_MODE=1 vllm serve facebook/opt-125m \
--enforce-eager \
--data-parallel-size 4 \
--tensor-parallel-size 1
Then run this script:
$ python data_parallel_pause_resume.py
The test verifies pause works by:
1. Starting a streaming generation request
2. Pausing the server mid-generation
3. Sleeping for PAUSE_DURATION seconds
4. Resuming the server
5. Verifying there was a gap in token generation matching the pause duration
"""
import argparse
import threading
import time
import requests
from openai import OpenAI
BASE_URL = "http://localhost:8000"
MODEL_NAME = "facebook/opt-125m"
PAUSE_DURATION = 3.0
def pause_generation(base_url: str, mode: str = "keep") -> None:
"""Pause generation via HTTP endpoint."""
url = f"{base_url}/pause"
response = requests.post(url, params={"mode": mode}, timeout=60)
response.raise_for_status()
print("Server paused")
def resume_generation(base_url: str) -> None:
"""Resume generation via HTTP endpoint."""
url = f"{base_url}/resume"
response = requests.post(url, timeout=60)
response.raise_for_status()
print("Server resumed")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--base-url", default=BASE_URL)
parser.add_argument("--model", default=MODEL_NAME)
args = parser.parse_args()
client = OpenAI(
base_url=f"{args.base_url}/v1",
api_key="EMPTY",
)
prompt = "Write a long story about a dragon. Once upon a time"
token_times: list[float] = []
pause_token_idx = 0
pause_triggered = threading.Event()
def generator_thread():
"""Stream tokens and record timestamps."""
stream = client.completions.create(
model=args.model,
prompt=prompt,
max_tokens=50,
stream=True,
)
for chunk in stream:
if chunk.choices[0].text:
token_times.append(time.monotonic())
token_count = len(token_times)
print(f"Token {token_count}: {chunk.choices[0].text!r}")
# Signal controller after some tokens
if token_count >= 5 and not pause_triggered.is_set():
pause_triggered.set()
def controller_thread():
"""Pause and resume the server."""
nonlocal pause_token_idx
# Wait for some tokens
pause_triggered.wait()
print(f"\nPausing server (keep mode) at token {len(token_times)}...")
pause_generation(args.base_url, mode="keep")
pause_token_idx = len(token_times)
print(f"Sleeping for {PAUSE_DURATION}s...")
time.sleep(PAUSE_DURATION)
print("Resuming server...")
resume_generation(args.base_url)
print("Resumed!\n")
# Run both threads
gen_thread = threading.Thread(target=generator_thread)
ctrl_thread = threading.Thread(target=controller_thread)
gen_thread.start()
ctrl_thread.start()
gen_thread.join()
ctrl_thread.join()
# Check gap at the pause point
if pause_token_idx < len(token_times):
pause_gap = token_times[pause_token_idx] - token_times[pause_token_idx - 1]
print(
f"\nGap after pause (token {pause_token_idx} -> "
f"{pause_token_idx + 1}): {pause_gap:.3f}s"
)
if pause_gap >= PAUSE_DURATION * 0.9:
print("Test passed! Pause synchronized across DP ranks.")
else:
print(f"Test failed! Expected ~{PAUSE_DURATION}s gap, got {pause_gap:.3f}s")
else:
print("Test failed! No tokens were generated after resuming.")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,121 @@
# Disaggregated Encoder
These example scripts that demonstrate the disaggregated encoder (EPD) features of vLLM.
For a detailed explanation of the EPD features, please refer to the [Disaggregated Encoder Feature Documentation](../../../docs/features/disagg_encoder.md).
## Files
- `disagg_epd_proxy.py` - Proxy script that demonstrates the XeYpZd setup (X encode instances, Y prefill instances, Z decode instances). Currently stable for the 1e1p1d configuration.
- `disagg_1e1p1d_example.sh` - Sets up the 1e1p1d configuration, runs the VisionArena benchmark, and processes a single request with a local image.
- `disagg_1e1pd_example.sh` - Sets up the 1e1pd configuration, runs the VisionArena benchmark, and processes a single request with a local image.
### Custom Configuration
```bash
# Use specific GPUs
GPU_E=0 GPU_PD=1 GPU_P=1 GPU_D=2 bash disagg_1e1p1d_example.sh
# Use specific ports
ENDPOINT_PORT=10001 bash disagg_1e1p1d_example.sh
# Use specific model
MODEL="Qwen/Qwen2.5-VL-3B-Instruct" bash disagg_1e1p1d_example.sh
# Use specific storage path
EC_SHARED_STORAGE_PATH="/tmp/my_ec_cache" bash disagg_1e1p1d_example.sh
```
## Encoder Instances
Encoder engines should be launched with the following flags:
- `--enforce-eager` **(required)** The current EPD implementation is only compatible with encoder instances running in this mode.
- `--no-enable-prefix-caching` **(required)** Encoder instances do not consume KV cache; prefix caching is disabled to avoid conflicts with other features.
- `--max-num-batched-tokens=<large value>` **(default: 2048)** This flag controls the token scheduling budget per decoding step and is irrelevant to encoder-only instances. **Set it to a very high value (effectively unlimited) to bypass scheduler limitations.** The actual token budget is managed by the encoder cache manager.
- `--mm-encoder-only` **(Optional)** - If possible, skips the language model during initialization to reduce device memory usage.
## Local media inputs
To support local image inputs (from your ```MEDIA_PATH``` directory), add the following flag to the encoder instance:
```bash
--allowed-local-media-path $MEDIA_PATH
```
The vllm instances and `disagg_encoder_proxy` supports local URIs with ```{"url": "file://'"$MEDIA_PATH_FILENAME"'}``` as multimodal inputs. Each URI is passed unchanged from the `disagg_encoder_proxy` to the encoder instance so that the encoder can load the media locally.
## EC connector and KV transfer
The `ECExampleonnector` is used to store the encoder cache on local disk and facilitate transfer. To enable the encoder disaggregation feature, add the following configuration:
```bash
# Add to encoder instance:
--ec-transfer-config '{
"ec_connector": "ECExampleConnector",
"ec_role": "ec_producer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}'
# Add to prefill/prefill+decode instance:
--ec-transfer-config '{
"ec_connector": "ECExampleConnector",
"ec_role": "ec_consumer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
}
}'
```
`$EC_SHARED_STORAGE_PATH` is the path where the EC connector temporarily stores the cache.
If you enable prefill instance (`--prefill-servers-urls` not disabled), you will need --kv-transfer-config to facilitate the PD disaggregation. Currently, we use the `NixlConnector` for this purpose. Refer to `tests/v1/kv_connector/nixl_integration` for more example codes on PD disaggregation with Nixl.
```bash
# Add to prefill instance:
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_producer"
}'
# Add to decode instance:
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_consumer"
}'
```
## Proxy Instance Flags (`disagg_epd_proxy.py`)
| Flag | Description |
| ---- | ----------- |
| `--encode-servers-urls` | Comma-separated list of encoder endpoints. Every multimodal item extracted from the request is fanned out to one of these URLs in a round-robin fashion. |
| `--prefill-servers-urls` | Comma-separated list of prefill endpoints. Set to `disable`, `none`, or `""` to skip the dedicated prefill phase and run E+PD (encoder + combined prefill/decode). |
| `--decode-servers-urls` | Comma-separated list of decode endpoints. Non-stream and stream paths both round-robin over this list. |
| `--host`, `--port` | Bind address for the proxy itself (defaults: `0.0.0.0:8000`). |
Example usage:
For E + PD setup:
```bash
$ python disagg_encoder_proxy.py \
--encode-servers-urls "http://e1:8001,http://e2:8002" \
--prefill-servers-urls "disable" \
--decode-servers-urls "http://pd1:8003,http://pd2:8004"
```
For E + P + D setup:
```bash
$ python disagg_encoder_proxy.py \
--encode-servers-urls "http://e1:8001,http://e2:8001" \
--prefill-servers-urls "http://p1:8003,http://p2:8004" \
--decode-servers-urls "http://d1:8005,http://d2:8006"
```

Some files were not shown because too many files have changed in this diff Show More