chore: vendor sglang v0.5.10 snapshot

This commit is contained in:
2026-04-24 12:29:36 +00:00
parent 78f0d15221
commit bded08301f
4308 changed files with 1200894 additions and 2 deletions

View File

@@ -0,0 +1,89 @@
#!/bin/bash
# Run a command inside the ci_sglang container with the AMD CI environment.
#
# Usage: [-w|--workdir DIR] [-e KEY=VAL]... [--] COMMAND [ARGS...]
#
# The command is run once as-is. If it fails with an exit code other than
# 1, 137 or 255, it is run a second time with HF_HUB_OFFLINE=1 added to the
# container environment; the script then exits with that second status.
set -euo pipefail

# Detect GPU family from hostname (e.g., linux-mi35x-gpu-1-xxxxx-runner-zzzzz)
HOSTNAME_VALUE=$(hostname)
GPU_FAMILY=""
if [[ "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
  GPU_FAMILY="${BASH_REMATCH[1]}"
  echo "Detected GPU family from hostname: ${GPU_FAMILY}"
else
  echo "Warning: could not parse GPU family from '${HOSTNAME_VALUE}'"
fi

WORKDIR="/sglang-checkout/test/srt"

# Environment passed to `docker exec`; `-e KEY=VAL` arguments override entries.
declare -A ENV_MAP=(
  [SGLANG_IS_IN_CI_AMD]=1
  [SGLANG_IS_IN_CI]=1
  [SGLANG_USE_AITER]=1
)

# Conditionally add GPU_ARCHS only for mi35x
if [[ "${GPU_FAMILY}" == "mi35x" ]]; then
  ENV_MAP[GPU_ARCHS]="gfx950"
fi

# Parse -w/--workdir and -e ENV=VAL; everything after the options (or after
# `--`) is the command to run in the container.
while [[ $# -gt 0 ]]; do
  case "$1" in
    -w|--workdir)
      if [[ $# -lt 2 ]]; then
        echo "Error: $1 requires a directory argument" >&2
        exit 1
      fi
      WORKDIR="$2"
      shift 2
      ;;
    -e)
      if [[ $# -lt 2 ]]; then
        echo "Error: -e requires a KEY=VAL argument" >&2
        exit 1
      fi
      # Split on the FIRST '=' with parameter expansion. `IFS== read` would
      # silently drop a single trailing '=' from the value (e.g. base64
      # padding), so it is not used here.
      key="${2%%=*}"
      if [[ "$2" == *=* ]]; then
        val="${2#*=}"
      else
        val=""
      fi
      if [[ -z "$key" ]]; then
        echo "Error: -e argument '$2' has an empty variable name" >&2
        exit 1
      fi
      ENV_MAP["$key"]="$val"
      shift 2
      ;;
    --)
      shift
      break
      ;;
    *)
      break
      ;;
  esac
done

# Build final ENV_ARGS
ENV_ARGS=()
for key in "${!ENV_MAP[@]}"; do
  ENV_ARGS+=("-e" "$key=${ENV_MAP[$key]}")
done

# Run docker exec with retry logic for HuggingFace network/download issues
# When HF model downloads fail due to network timeouts or rate limits,
# retrying with HF_HUB_OFFLINE=1 uses cached models from previous downloads.
#
# First attempt: normal mode (allows HF downloads)
if docker exec \
  -w "$WORKDIR" \
  "${ENV_ARGS[@]}" \
  ci_sglang "$@"; then
  exit 0
else
  # In the else branch $? still holds the status of the `if` condition.
  FIRST_EXIT_CODE=$?
fi
echo "First attempt failed with exit code $FIRST_EXIT_CODE"

# Skip retry for test failures that won't be fixed by offline mode:
# - Exit 1: Test assertion failures (accuracy below threshold)
# - Exit 137 (128+9): Process killed by OOM
# - Exit 255: Test suite completed with test errors
# Only retry for other exit codes (e.g., network timeouts, HF download failures)
if [[ "$FIRST_EXIT_CODE" -eq 1 || "$FIRST_EXIT_CODE" -eq 137 || "$FIRST_EXIT_CODE" -eq 255 ]]; then
  echo "Exit code $FIRST_EXIT_CODE indicates test failure (not network issue), not retrying"
  exit "$FIRST_EXIT_CODE"
fi

echo "Retrying with HF_HUB_OFFLINE=1 (offline mode to use cached models)..."
# Second attempt: force HF offline mode to avoid network timeouts
docker exec \
  -w "$WORKDIR" \
  "${ENV_ARGS[@]}" \
  -e HF_HUB_OFFLINE=1 \
  ci_sglang "$@"

View File

@@ -0,0 +1,320 @@
#!/bin/bash
# Install SGLang and its CI dependencies into the running ci_sglang container.
# This first section parses the --skip-* flags, works out the python extras
# to install, detects the GPU architecture and upgrades pip in the container.
set -euo pipefail

HOSTNAME_VALUE=$(hostname)
GPU_ARCH="mi30x" # default
SKIP_TT_DEPS=""
SKIP_SGLANG_BUILD=""
SKIP_AITER_BUILD=""

# Consume leading option flags; the first non-option argument ends parsing.
while [[ $# -gt 0 ]]; do
  case $1 in
    --skip-aiter-build)
      SKIP_AITER_BUILD="1"
      shift
      ;;
    --skip-sglang-build)
      SKIP_SGLANG_BUILD="1"
      shift
      ;;
    --skip-test-time-deps)
      SKIP_TT_DEPS="1"
      shift
      ;;
    -h|--help)
      printf '%s\n' \
        "Usage: $0 [OPTIONS] [OPTIONAL_DEPS]" \
        "Options:" \
        " --skip-sglang-build Don't build checkout sglang, use what was shipped with the image" \
        " --skip-aiter-build Don't build aiter, use what was shipped with the image" \
        " --skip-test-time-deps Don't build miscellaneous dependencies"
      exit 0
      ;;
    *)
      break
      ;;
  esac
done

# Build python extras: always dev_hip, plus the optional first positional
# argument appended after a comma.
OPTIONAL_DEPS="${1:-}"
EXTRAS="dev_hip${OPTIONAL_DEPS:+,${OPTIONAL_DEPS}}"
echo "Installing python extras: [${EXTRAS}]"

# Host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz
if [[ ! "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
  echo "Warning: could not parse GPU architecture from '${HOSTNAME_VALUE}', defaulting to ${GPU_ARCH}"
else
  GPU_ARCH="${BASH_REMATCH[1]}"
  echo "Detected GPU architecture from hostname: ${GPU_ARCH}"
fi

# Install the required dependencies in CI.
# Fix permissions on pip cache, ignore errors from concurrent access or missing temp files
docker exec ci_sglang chown -R root:root /sgl-data/pip-cache 2>/dev/null || true
docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache --upgrade pip
#######################################
# Run a command, retrying on failure; pip commands get a fallback PyPI mirror
# for the final attempt.
#
# The command is kept as an argument array and executed directly, so words
# containing spaces, '$' or glob characters (e.g. "python[dev_hip]") reach the
# command exactly as the caller passed them. It is never re-parsed by `eval`.
#
# Arguments: the command followed by its arguments, one word each
# Outputs:   progress messages on stdout
# Returns:   0 as soon as one attempt succeeds; 1 if no command was given or
#            all three attempts fail
#######################################
install_with_retry() {
  local max_attempts=3
  local attempt
  local -a cmd=("$@")

  if ((${#cmd[@]} == 0)); then
    echo "install_with_retry: no command given" >&2
    return 1
  fi

  for ((attempt = 1; attempt <= max_attempts; attempt++)); do
    echo "Attempt $attempt/$max_attempts: ${cmd[*]}"
    if "${cmd[@]}"; then
      echo "Success!"
      return 0
    fi
    if ((attempt < max_attempts)); then
      echo "Failed, retrying in 5 seconds..."
      sleep 5
      # Try with alternative PyPI index on retry: the mirror flags are added
      # after the second failure, so only the last attempt uses them.
      if [[ "${cmd[*]}" == *"pip install"* ]] && ((attempt == 2)); then
        cmd+=(--index-url https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com)
        echo "Using fallback PyPI mirror: ${cmd[*]}"
      fi
    fi
  done

  echo "Failed after $max_attempts attempts"
  return 1
}
#######################################
# Shallow-clone a git repository, retrying up to three times.
#
# Arguments:
#   $1 - repository URL
#   $2 - destination directory (optional; when omitted git picks the name)
#   $3 - extra `git clone` options as one string, e.g. "--branch v1"
#        (optional; deliberately word-split)
# Outputs:   progress messages on stdout
# Returns:   0 on success; 1 after three failed attempts
# Note:      an existing destination directory is deleted before EVERY
#            attempt, including the first.
#######################################
git_clone_with_retry() {
  local repo_url="$1"
  local dest_dir="${2:-}"
  local branch_args="${3:-}"
  local max_attempts=3
  local attempt
  local -a clone_args

  for ((attempt = 1; attempt <= max_attempts; attempt++)); do
    echo "Git clone attempt $attempt/$max_attempts: $repo_url"

    # Remove any leftover directory so a partial clone from a previous
    # attempt cannot make this one fail.
    if [[ -n "$dest_dir" && -d "$dest_dir" ]]; then
      rm -rf -- "$dest_dir"
    fi

    clone_args=(--depth 1)
    if [[ -n "$branch_args" ]]; then
      # shellcheck disable=SC2206 # intentional split into separate options
      clone_args+=($branch_args)
    fi
    clone_args+=("$repo_url")
    # Only pass a destination when one was given; an empty "" argument
    # would be treated by git as a (bogus) target path.
    if [[ -n "$dest_dir" ]]; then
      clone_args+=("$dest_dir")
    fi

    # Abort a transfer that stays below 1000 bytes/s for 30 seconds.
    if git \
      -c http.lowSpeedLimit=1000 \
      -c http.lowSpeedTime=30 \
      clone "${clone_args[@]}"; then
      echo "Git clone succeeded."
      return 0
    fi

    if ((attempt < max_attempts)); then
      echo "Git clone failed, retrying in 5 seconds..."
      sleep 5
    fi
  done

  echo "Git clone failed after $max_attempts attempts: $repo_url"
  return 1
}
# Install checkout sglang: replace whatever sgl-kernel / sglang the image
# shipped with a build of the mounted checkout, unless --skip-sglang-build
# was given.
if [[ -z "${SKIP_SGLANG_BUILD}" ]]; then
  # Remove the pre-installed packages; a package that is not installed is fine.
  for pkg in sgl-kernel sglang-kernel sglang; do
    docker exec ci_sglang pip uninstall "${pkg}" -y || true
  done

  # Clear Python bytecode caches (venv first, then the checkout) to ensure
  # the latest code is used.
  for cache_root in /opt/venv /sglang-checkout; do
    docker exec ci_sglang find "${cache_root}" -name "*.pyc" -delete || true
    docker exec ci_sglang find "${cache_root}" -name "__pycache__" -type d -exec rm -rf {} + || true
  done

  # Build sgl-kernel with its ROCm project file swapped in.
  docker exec -w /sglang-checkout/sgl-kernel ci_sglang bash -c "rm -f pyproject.toml && mv pyproject_rocm.toml pyproject.toml && python3 setup_rocm.py install"

  # Swap in the alternate python project file, then install sglang editable
  # with the selected extras.
  docker exec ci_sglang bash -c 'rm -rf python/pyproject.toml && mv python/pyproject_other.toml python/pyproject.toml'
  install_with_retry docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache -e "python[${EXTRAS}]"
else
  echo "Didn't build checkout SGLang"
fi
# Test-time dependencies (lmms-eval, human-eval, dummy Grok config, misc pip
# packages). Skipped entirely with --skip-test-time-deps.
if [[ -n "${SKIP_TT_DEPS}" ]]; then
  echo "Didn't build lmms_eval, human-eval, and others"
else
  # For lmms_evals evaluating MMMU
  docker exec -w / ci_sglang git clone --branch v0.4.1 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
  install_with_retry docker exec -w /lmms-eval ci_sglang pip install --cache-dir=/sgl-data/pip-cache -e .

  # human-eval is cloned on the host and copied into the container root.
  git_clone_with_retry https://github.com/akao-amd/human-eval.git human-eval
  docker cp human-eval ci_sglang:/
  install_with_retry docker exec -w /human-eval ci_sglang pip install --cache-dir=/sgl-data/pip-cache -e .

  docker exec -w / ci_sglang mkdir -p /dummy-grok
  # Create dummy grok config inline (bypasses Azure blob storage which may have auth issues)
  # Previously fetched with:
  #   wget https://sharkpublic.blob.core.windows.net/sharkpublic/sglang/dummy_grok.json -O dummy-grok/config.json
  mkdir -p dummy-grok
  cat > dummy-grok/config.json << 'EOF'
{
"architectures": [
"Grok1ModelForCausalLM"
],
"embedding_multiplier_scale": 78.38367176906169,
"output_multiplier_scale": 0.5773502691896257,
"vocab_size": 131072,
"hidden_size": 6144,
"intermediate_size": 32768,
"max_position_embeddings": 8192,
"num_experts_per_tok": 2,
"num_local_experts": 8,
"num_attention_heads": 48,
"num_hidden_layers": 64,
"num_key_value_heads": 8,
"head_dim": 128,
"rms_norm_eps": 1e-05,
"rope_theta": 10000.0,
"model_type": "mixtral",
"torch_dtype": "bfloat16"
}
EOF
  # The config above is written on the host; copy it into the container's
  # /dummy-grok, which would otherwise stay empty.
  # NOTE(review): assumes consumers read /dummy-grok inside the container.
  docker cp ./dummy-grok ci_sglang:/

  docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache huggingface_hub[hf_xet]
  docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache pytest
  # Install cache-dit for qwen_image_t2i_cache_dit_enabled test (added in PR 16204)
  docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache cache-dit || echo "cache-dit installation failed"
  # Install accelerate for distributed training and inference support
  docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache accelerate || echo "accelerate installation failed"
fi
# Nothing below applies when the AITER build is skipped.
if [[ -n "${SKIP_AITER_BUILD}" ]]; then
  exit 0
fi

#############################################
# Detect correct AITER_COMMIT for this runner
# + Check mismatch
# + Rebuild AITER if needed
#############################################
echo "[CI-AITER-CHECK] === AITER VERSION CHECK START ==="
DOCKERFILE="docker/rocm.Dockerfile"
# GPU_ARCH falls back to mi30x if it is unset or empty at this point.
GPU_ARCH="${GPU_ARCH:-mi30x}"
echo "[CI-AITER-CHECK] Runner GPU_ARCH=${GPU_ARCH}"

#############################################
# 1. Extract AITER_COMMIT from correct Dockerfile block
#############################################
# Pick the Dockerfile stage header whose AITER_COMMIT_DEFAULT applies.
if [[ "${GPU_ARCH}" == "mi35x" ]]; then
  echo "[CI-AITER-CHECK] Using gfx950 block from Dockerfile..."
  DOCKERFILE_STAGE='FROM $BASE_IMAGE_950 AS gfx950'
else
  echo "[CI-AITER-CHECK] Using gfx942 block from Dockerfile..."
  DOCKERFILE_STAGE='FROM $BASE_IMAGE_942 AS gfx942'
fi

# Take the first AITER_COMMIT_DEFAULT="..." within 20 lines after the stage
# header. The trailing `|| true` is required: with `set -e` and `pipefail`,
# a grep that matches nothing would otherwise abort the script right here,
# before the explicit error message below can be printed.
REPO_AITER_COMMIT=$(grep -F -A20 -- "${DOCKERFILE_STAGE}" "${DOCKERFILE}" \
  | grep 'AITER_COMMIT_DEFAULT=' \
  | head -n1 \
  | sed 's/.*AITER_COMMIT_DEFAULT="\([^"]*\)".*/\1/') || true

if [[ -z "${REPO_AITER_COMMIT}" ]]; then
  echo "[CI-AITER-CHECK] ERROR: Failed to extract AITER_COMMIT from Dockerfile."
  exit 1
fi
echo "[CI-AITER-CHECK] Dockerfile expects AITER_COMMIT=${REPO_AITER_COMMIT}"

#############################################
# 2. Check container pre-installed AITER version
#############################################
# `pip show` reports e.g. "0.1.2"; a "v" is prepended before comparing with
# the Dockerfile value. No package installed yields "v" (or "vnone" if the
# docker exec itself fails).
IMAGE_AITER_VERSION=$(docker exec ci_sglang bash -c "pip show amd-aiter 2>/dev/null | grep '^Version:' | awk '{print \$2}'" || echo "none")
IMAGE_AITER_VERSION="v${IMAGE_AITER_VERSION}"
echo "[CI-AITER-CHECK] AITER version inside CI image: ${IMAGE_AITER_VERSION}"

#############################################
# 3. Decide rebuild
#############################################
NEED_REBUILD="false"
if [[ -n "${AITER_COMMIT_OVERRIDE:-}" ]]; then
  echo "[CI-AITER-CHECK] AITER_COMMIT_OVERRIDE=${AITER_COMMIT_OVERRIDE} → forcing rebuild"
  REPO_AITER_COMMIT="${AITER_COMMIT_OVERRIDE}"
  NEED_REBUILD="true"
elif [[ "${IMAGE_AITER_VERSION}" == "vnone" || "${IMAGE_AITER_VERSION}" == "v" ]]; then
  echo "[CI-AITER-CHECK] No AITER found in image → rebuild needed"
  NEED_REBUILD="true"
elif [[ "${IMAGE_AITER_VERSION}" == "${REPO_AITER_COMMIT}" ]]; then
  echo "[CI-AITER-CHECK] AITER version matches"
elif [[ "${IMAGE_AITER_VERSION}" =~ (dev|\+g[0-9a-f]+) ]]; then
  # Dev/patched version (contains 'dev' or git hash) → preserve it
  echo "[CI-AITER-CHECK] Dev/patched version detected: ${IMAGE_AITER_VERSION} → skipping rebuild"
else
  echo "[CI-AITER-CHECK] Version mismatch: image=${IMAGE_AITER_VERSION}, repo=${REPO_AITER_COMMIT}"
  NEED_REBUILD="true"
fi

#############################################
# 4. Rebuild AITER if needed
#############################################
if [[ "${NEED_REBUILD}" == "true" ]]; then
  echo "[CI-AITER-CHECK] === AITER REBUILD START ==="
  # uninstall existing aiter
  docker exec ci_sglang pip uninstall -y amd-aiter || true
  # delete old aiter directory
  docker exec ci_sglang rm -rf /sgl-workspace/aiter
  # clone a fresh copy to /sgl-workspace/aiter
  docker exec ci_sglang git clone https://github.com/ROCm/aiter.git /sgl-workspace/aiter
  # checkout correct version
  docker exec ci_sglang bash -c "
    cd /sgl-workspace/aiter && \
    git fetch --all && \
    git checkout ${REPO_AITER_COMMIT} && \
    git submodule update --init --recursive
  "

  if [[ "${GPU_ARCH}" == "mi35x" ]]; then
    GPU_ARCH_LIST="gfx950"
  else
    GPU_ARCH_LIST="gfx942"
  fi
  echo "[CI-AITER-CHECK] GPU_ARCH_LIST=${GPU_ARCH_LIST}"

  # Re-apply Dockerfile hotpatches for ROCm 7.2 (the fresh clone lost them, can be removed after triton fixed this problem)
  ROCM_VERSION=$(docker exec ci_sglang bash -c "cat /opt/rocm/.info/version 2>/dev/null || echo unknown")
  if [[ "${ROCM_VERSION}" == 7.2* ]]; then
    echo "[CI-AITER-CHECK] ROCm 7.2 detected (${ROCM_VERSION}), applying AITER hotpatches..."
    # The patch rewrites the `if ...:` on line 459 of the target file to
    # `if False:`; it is tied to that exact line number.
    docker exec ci_sglang bash -c "
      cd /sgl-workspace/aiter && \
      TARGET_FILE='aiter/ops/triton/attention/pa_mqa_logits.py' && \
      if [ -f \"\${TARGET_FILE}\" ]; then \
        sed -i '459 s/if.*:/if False:/' \"\${TARGET_FILE}\" && \
        echo '[CI-AITER-CHECK] Hotpatch applied to pa_mqa_logits.py'; \
      else \
        echo '[CI-AITER-CHECK] pa_mqa_logits.py not found, skipping hotpatch'; \
      fi
    "
  else
    echo "[CI-AITER-CHECK] ROCm version=${ROCM_VERSION}, no hotpatch needed"
  fi

  # build AITER
  docker exec ci_sglang bash -c "
    cd /sgl-workspace/aiter && \
    GPU_ARCHS=${GPU_ARCH_LIST} python3 setup.py develop
  "
  echo "[CI-AITER-CHECK] === AITER REBUILD COMPLETE ==="
fi
echo "[CI-AITER-CHECK] === AITER VERSION CHECK END ==="
# # Clear pre-built AITER kernels from Docker image to avoid segfaults
# # The Docker image may contain pre-compiled kernels incompatible with the current environment
# echo "Clearing pre-built AITER kernels from Docker image..."
# docker exec ci_sglang find /sgl-workspace/aiter/aiter/jit -name "*.so" -delete 2>/dev/null || true
# docker exec ci_sglang ls -la /sgl-workspace/aiter/aiter/jit/ 2>/dev/null || echo "jit dir empty or not found"
# # Pre-build AITER kernels to avoid timeout during tests
# echo "Warming up AITER JIT kernels..."
# docker exec -e SGLANG_USE_AITER=1 ci_sglang python3 /sglang-checkout/scripts/ci/amd/amd_ci_warmup_aiter.py || echo "AITER warmup completed (some kernels may not be available)"

View File

@@ -0,0 +1,248 @@
#!/bin/bash
# Start the ci_sglang container for AMD CI.
# This first section determines the SGLang version used in image tags, parses
# the command line, detects the runner's GPU architecture and picks the
# device flag passed to `docker run`.
set -euo pipefail

# Get version from git tags; the default is used when no tag can be found.
SGLANG_VERSION="v0.5.5"
# Fetch tags from origin to ensure we have the latest
if ! git fetch --tags origin; then
  echo "Warning: Failed to fetch tags from origin; using default $SGLANG_VERSION" >&2
else
  # Get the latest version tag sorted by version number (e.g., v0.5.7)
  VERSION_FROM_TAG=$(git tag -l 'v[0-9]*' --sort=-v:refname | head -1)
  if [[ -z "$VERSION_FROM_TAG" ]]; then
    echo "Warning: No version tags found; using default $SGLANG_VERSION" >&2
  else
    SGLANG_VERSION="$VERSION_FROM_TAG"
    echo "Using SGLang version from git tags: $SGLANG_VERSION"
  fi
fi

# Default base tags (can be overridden by command line arguments)
ROCM_VERSION="rocm700"
DEFAULT_MI30X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi30x"
DEFAULT_MI35X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi35x"

# Parse command line arguments
MI30X_BASE_TAG="${DEFAULT_MI30X_BASE_TAG}"
MI35X_BASE_TAG="${DEFAULT_MI35X_BASE_TAG}"
CUSTOM_IMAGE=""
BUILD_FROM_DOCKERFILE=""
GPU_ARCH_BUILD=""
while [[ $# -gt 0 ]]; do
  case $1 in
    --mi30x-base-tag)
      MI30X_BASE_TAG="$2"
      shift 2
      ;;
    --mi35x-base-tag)
      MI35X_BASE_TAG="$2"
      shift 2
      ;;
    --custom-image)
      CUSTOM_IMAGE="$2"
      shift 2
      ;;
    --build-from-dockerfile)
      BUILD_FROM_DOCKERFILE="1"
      shift
      ;;
    --gpu-arch)
      GPU_ARCH_BUILD="$2"
      shift 2
      ;;
    --rocm-version)
      # Recomputes BOTH base tags, replacing any --mi3?x-base-tag given
      # earlier on the command line.
      ROCM_VERSION="$2"
      MI30X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi30x"
      MI35X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi35x"
      echo "Using ROCm version override: ${ROCM_VERSION}"
      shift 2
      ;;
    -h|--help)
      printf '%s\n' \
        "Usage: $0 [OPTIONS]" \
        "Options:" \
        " --mi30x-base-tag TAG Override MI30x base image tag" \
        " --mi35x-base-tag TAG Override MI35x base image tag" \
        " --custom-image IMAGE Use a specific Docker image directly" \
        " --build-from-dockerfile Build image from docker/rocm.Dockerfile" \
        " --gpu-arch ARCH GPU architecture for Dockerfile build (e.g., gfx950-rocm720)" \
        " --rocm-version VERSION Override ROCm version for image lookup (e.g., rocm720)"
      exit 0
      ;;
    *)
      echo "Unknown option $1"
      exit 1
      ;;
  esac
done

# Detect GPU architecture from the Kubernetes runner hostname
HOSTNAME_VALUE=$(hostname)
GPU_ARCH="mi30x" # default
# Host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz
if [[ ! "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
  echo "Warning: could not parse GPU architecture from '${HOSTNAME_VALUE}', defaulting to ${GPU_ARCH}"
else
  GPU_ARCH="${BASH_REMATCH[1]}"
  echo "Detected GPU architecture from hostname: ${GPU_ARCH}"
fi

# Normalise / collapse architectures we don't yet build specifically for
if [[ "${GPU_ARCH}" == "mi35x" ]]; then
  echo "Runner uses ${GPU_ARCH}; will fetch mi35x image."
elif [[ "${GPU_ARCH}" == "mi30x" || "${GPU_ARCH}" == "mi300" || "${GPU_ARCH}" == "mi325" ]]; then
  echo "Runner uses ${GPU_ARCH}; will fetch mi30x image."
  GPU_ARCH="mi30x"
else
  echo "Runner architecture '${GPU_ARCH}' unrecognised; defaulting to mi30x image." >&2
  GPU_ARCH="mi30x"
fi

# Set up DEVICE_FLAG based on Kubernetes pod info: use the pod-provided flag
# file when present, otherwise expose /dev/dri.
DEVICE_FLAG="--device /dev/dri"
if [[ -f /etc/podinfo/gha-render-devices ]]; then
  DEVICE_FLAG=$(cat /etc/podinfo/gha-render-devices)
fi
#######################################
# Find the latest image: resolve which rocm/sgl-dev image to use for a GPU
# architecture.
#
# Lookup order, first hit wins:
#   1. a locally cached image tagged <base_tag>-<YYYYMMDD>, last 7 days
#   2. the same tag in the public registry (docker manifest inspect)
#   3. any Docker Hub tag matching <ROCM_VERSION>-<arch>-<YYYYMMDD>
#   4. any local rocm/sgl-dev image matching ROCm version and architecture
#   5. a hard-coded image for rocm720 / rocm700
#
# Globals:   MI30X_BASE_TAG, MI35X_BASE_TAG, ROCM_VERSION (read)
# Arguments: $1 - GPU architecture, "mi30x" or "mi35x"
# Outputs:   the image reference on stdout; all progress text on stderr, so
#            callers can capture stdout directly
# Returns:   0 when an image reference was printed; 1 for an unsupported
#            architecture or when no hard-coded fallback exists
#######################################
find_latest_image() {
  local gpu_arch=$1
  local base_tag days_back stamp image_ref image_id remote_tag any_local
  local -a recent_dates=()

  case "${gpu_arch}" in
    mi30x) base_tag="${MI30X_BASE_TAG}" ;;
    mi35x) base_tag="${MI35X_BASE_TAG}" ;;
    *)
      echo "Error: unsupported GPU architecture '${gpu_arch}'" >&2
      return 1
      ;;
  esac

  # Date stamps for today and the six previous days, newest first; computed
  # once and shared by the three date-based lookups below.
  for days_back in {0..6}; do
    recent_dates+=("$(date -d "${days_back} days ago" +%Y%m%d)")
  done

  # First, check local cache
  for stamp in "${recent_dates[@]}"; do
    image_ref="rocm/sgl-dev:${base_tag}-${stamp}"
    # A failing `docker images` is treated as "not cached".
    image_id=$(docker images -q "${image_ref}") || true
    if [[ -n "$image_id" ]]; then
      echo "Found cached image locally: ${image_ref}" >&2
      echo "${image_ref}"
      return 0
    fi
  done

  # If not found locally, fall back to pulling from public registry
  for stamp in "${recent_dates[@]}"; do
    image_ref="rocm/sgl-dev:${base_tag}-${stamp}"
    echo "Checking for image: ${image_ref}" >&2
    if docker manifest inspect "${image_ref}" >/dev/null 2>&1; then
      echo "Found available image: ${image_ref}" >&2
      echo "${image_ref}"
      return 0
    fi
  done

  # If still not found, try finding any image matching ROCm+arch from remote registry
  echo "Exact version not found. Searching remote registry for any ${ROCM_VERSION}-${gpu_arch} image…" >&2
  for stamp in "${recent_dates[@]}"; do
    # `|| true`: an empty or failed query just means "no tag for this date".
    remote_tag=$(curl -s "https://registry.hub.docker.com/v2/repositories/rocm/sgl-dev/tags?page_size=100&name=${ROCM_VERSION}-${gpu_arch}-${stamp}" 2>/dev/null \
      | grep -o '"name":"[^"]*"' \
      | cut -d'"' -f4 \
      | head -n 1) || true
    if [[ -n "$remote_tag" ]]; then
      echo "Found available image: rocm/sgl-dev:${remote_tag}" >&2
      echo "rocm/sgl-dev:${remote_tag}"
      return 0
    fi
  done

  echo "No recent images found. Searching any cached local images matching ROCm+arch…" >&2
  any_local=$(docker images --format '{{.Repository}}:{{.Tag}}' --filter "reference=rocm/sgl-dev:*${ROCM_VERSION}*${gpu_arch}*" \
    | sort -r \
    | head -n 1) || true
  if [[ -n "$any_local" ]]; then
    echo "Using cached fallback image: ${any_local}" >&2
    echo "${any_local}"
    return 0
  fi

  echo "Error: no ${gpu_arch} image found in the last 7 days for base ${base_tag}" >&2
  echo "Using hard-coded fallback for ${ROCM_VERSION}" >&2
  case "${ROCM_VERSION}" in
    rocm720)
      if [[ "${gpu_arch}" == "mi35x" ]]; then
        echo "rocm/sgl-dev:v0.5.8.post1-rocm720-mi35x-20260211-preview"
      else
        echo "rocm/sgl-dev:v0.5.8.post1-rocm720-mi30x-20260211-preview"
      fi
      ;;
    rocm700)
      if [[ "${gpu_arch}" == "mi35x" ]]; then
        echo "rocm/sgl-dev:v0.5.8.post1-rocm700-mi35x-20260211"
      else
        echo "rocm/sgl-dev:v0.5.8.post1-rocm700-mi30x-20260211"
      fi
      ;;
    *)
      echo "Error: no hard-coded fallback available for ${ROCM_VERSION}" >&2
      return 1
      ;;
  esac
}
# Determine which image to use. Three mutually exclusive modes, in priority
# order: an explicit --custom-image, a local build from docker/rocm.Dockerfile
# (--build-from-dockerfile), or the newest pre-built image for this runner.
if [[ -n "${CUSTOM_IMAGE}" ]]; then
# Use explicitly provided custom image
IMAGE="${CUSTOM_IMAGE}"
echo "Using custom image: ${IMAGE}"
docker pull "${IMAGE}"
elif [[ -n "${BUILD_FROM_DOCKERFILE}" ]]; then
# Build image from Dockerfile
if [[ -z "${GPU_ARCH_BUILD}" ]]; then
echo "Error: --gpu-arch is required when using --build-from-dockerfile" >&2
exit 1
fi
# The Dockerfile is looked up under GITHUB_WORKSPACE, or the current
# directory when that variable is unset/empty.
DOCKERFILE_DIR="${GITHUB_WORKSPACE:-$PWD}/docker"
DOCKERFILE="${DOCKERFILE_DIR}/rocm.Dockerfile"
if [[ ! -f "${DOCKERFILE}" ]]; then
echo "Error: Dockerfile not found at ${DOCKERFILE}" >&2
exit 1
fi
# The locally built image is tagged with the requested arch and today's date.
IMAGE="sglang-ci:${GPU_ARCH_BUILD}-$(date +%Y%m%d)"
echo "Building Docker image from ${DOCKERFILE} with GPU_ARCH=${GPU_ARCH_BUILD}..."
# Pass full GPU_ARCH (e.g., gfx950-rocm720) - Dockerfile handles stripping suffix
docker build \
--build-arg GPU_ARCH="${GPU_ARCH_BUILD}" \
--build-arg SGL_BRANCH="main" \
-t "${IMAGE}" \
-f "${DOCKERFILE}" \
"${DOCKERFILE_DIR}"
echo "Successfully built image: ${IMAGE}"
else
# Find the latest pre-built image
IMAGE=$(find_latest_image "${GPU_ARCH}")
echo "Pulling Docker image: ${IMAGE}"
docker pull "${IMAGE}"
fi
# Mount the host cache directory at /sgl-data only when it exists on this
# runner. Note that the HF_HOME / MIOPEN_* variables below point into
# /sgl-data whether or not the volume is mounted.
CACHE_HOST=/home/runner/sgl-data
if [[ -d "$CACHE_HOST" ]]; then
CACHE_VOLUME="-v $CACHE_HOST:/sgl-data"
else
CACHE_VOLUME=""
fi
echo "Launching container: ci_sglang"
# DEVICE_FLAG and CACHE_VOLUME are expanded unquoted on purpose: each holds
# an option plus its argument in one string and must split into separate
# words (an empty CACHE_VOLUME then disappears entirely).
# The checkout (GITHUB_WORKSPACE or the current directory) is mounted at
# /sglang-checkout, which is also the container's working directory.
docker run -dt --user root --device=/dev/kfd ${DEVICE_FLAG} \
--ulimit nofile=65536:65536 \
-v "${GITHUB_WORKSPACE:-$PWD}:/sglang-checkout" \
$CACHE_VOLUME \
--group-add video \
--shm-size 32g \
--cap-add=SYS_PTRACE \
-e HF_TOKEN="${HF_TOKEN:-}" \
-e HF_HOME=/sgl-data/hf-cache \
-e HF_HUB_ETAG_TIMEOUT=300 \
-e HF_HUB_DOWNLOAD_TIMEOUT=300 \
-e MIOPEN_USER_DB_PATH=/sgl-data/miopen-cache \
-e MIOPEN_CUSTOM_CACHE_DIR=/sgl-data/miopen-cache \
-e PYTHONPATH="/opt/tilelang:${PYTHONPATH:-}" \
--security-opt seccomp=unconfined \
-w /sglang-checkout \
--name ci_sglang \
"${IMAGE}"
# The checkout is owned by the runner (non-root) but the container runs as
# root. Git >= 2.35.2 rejects cross-user repos; mark the mount as safe so
# setuptools-scm / vcs_versioning can resolve the package version.
docker exec ci_sglang git config --global --add safe.directory /sglang-checkout

View File

@@ -0,0 +1,270 @@
#!/bin/bash
# Start the ci_sglang container with RDMA-related host mounts.
# This first section determines the SGLang version used in image tags, parses
# the command line, detects the runner's GPU architecture and picks the
# device flag passed to `docker run`.
set -euo pipefail

# Get version from git tags; the default is used when no tag can be found.
SGLANG_VERSION="v0.5.5"
# Fetch tags from origin to ensure we have the latest
if ! git fetch --tags origin; then
  echo "Warning: Failed to fetch tags from origin; using default $SGLANG_VERSION" >&2
else
  # Get the latest version tag sorted by version number (e.g., v0.5.7)
  VERSION_FROM_TAG=$(git tag -l 'v[0-9]*' --sort=-v:refname | head -1)
  if [[ -z "$VERSION_FROM_TAG" ]]; then
    echo "Warning: No version tags found; using default $SGLANG_VERSION" >&2
  else
    SGLANG_VERSION="$VERSION_FROM_TAG"
    echo "Using SGLang version from git tags: $SGLANG_VERSION"
  fi
fi

# Default base tags (can be overridden by command line arguments)
ROCM_VERSION="rocm700"
DEFAULT_MI30X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi30x"
DEFAULT_MI35X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi35x"

# Parse command line arguments
MI30X_BASE_TAG="${DEFAULT_MI30X_BASE_TAG}"
MI35X_BASE_TAG="${DEFAULT_MI35X_BASE_TAG}"
while [[ $# -gt 0 ]]; do
  case $1 in
    --mi30x-base-tag)
      MI30X_BASE_TAG="$2"
      shift 2
      ;;
    --mi35x-base-tag)
      MI35X_BASE_TAG="$2"
      shift 2
      ;;
    --rocm-version)
      # Recomputes BOTH base tags, replacing any --mi3?x-base-tag given
      # earlier on the command line.
      ROCM_VERSION="$2"
      MI30X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi30x"
      MI35X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi35x"
      echo "Using ROCm version override: ${ROCM_VERSION}"
      shift 2
      ;;
    -h|--help)
      echo "Usage: $0 [--mi30x-base-tag TAG] [--mi35x-base-tag TAG] [--rocm-version VERSION]"
      exit 0
      ;;
    *)
      echo "Unknown option $1"
      exit 1
      ;;
  esac
done

# Detect GPU architecture from the Kubernetes runner hostname
HOSTNAME_VALUE=$(hostname)
GPU_ARCH="mi30x" # default
# Host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz
if [[ ! "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
  echo "Warning: could not parse GPU architecture from '${HOSTNAME_VALUE}', defaulting to ${GPU_ARCH}"
else
  GPU_ARCH="${BASH_REMATCH[1]}"
  echo "Detected GPU architecture from hostname: ${GPU_ARCH}"
fi

# Normalise / collapse architectures we don't yet build specifically for
if [[ "${GPU_ARCH}" == "mi35x" ]]; then
  echo "Runner uses ${GPU_ARCH}; will fetch mi35x image."
elif [[ "${GPU_ARCH}" == "mi30x" || "${GPU_ARCH}" == "mi300" || "${GPU_ARCH}" == "mi325" ]]; then
  echo "Runner uses ${GPU_ARCH}; will fetch mi30x image."
  GPU_ARCH="mi30x"
else
  echo "Runner architecture '${GPU_ARCH}' unrecognised; defaulting to mi30x image." >&2
  GPU_ARCH="mi30x"
fi

# Set up DEVICE_FLAG based on Kubernetes pod info: use the pod-provided flag
# file when present, otherwise expose /dev/dri.
DEVICE_FLAG="--device /dev/dri"
if [[ -f /etc/podinfo/gha-render-devices ]]; then
  DEVICE_FLAG=$(cat /etc/podinfo/gha-render-devices)
fi
#######################################
# Find the latest image: resolve which rocm/sgl-dev image to use for a GPU
# architecture.
#
# Lookup order, first hit wins:
#   1. a locally cached image tagged <base_tag>-<YYYYMMDD>, last 7 days
#   2. the same tag in the public registry (docker manifest inspect)
#   3. any Docker Hub tag matching <ROCM_VERSION>-<arch>-<YYYYMMDD>
#   4. any local rocm/sgl-dev image matching ROCm version and architecture
#   5. a hard-coded image for rocm720 / rocm700
#
# Globals:   MI30X_BASE_TAG, MI35X_BASE_TAG, ROCM_VERSION (read)
# Arguments: $1 - GPU architecture, "mi30x" or "mi35x"
# Outputs:   the image reference on stdout; all progress text on stderr, so
#            callers can capture stdout directly
# Returns:   0 when an image reference was printed; 1 for an unsupported
#            architecture or when no hard-coded fallback exists
#######################################
find_latest_image() {
  local gpu_arch=$1
  local base_tag days_back stamp image_ref image_id remote_tag any_local
  local -a recent_dates=()

  case "${gpu_arch}" in
    mi30x) base_tag="${MI30X_BASE_TAG}" ;;
    mi35x) base_tag="${MI35X_BASE_TAG}" ;;
    *)
      echo "Error: unsupported GPU architecture '${gpu_arch}'" >&2
      return 1
      ;;
  esac

  # Date stamps for today and the six previous days, newest first; computed
  # once and shared by the three date-based lookups below.
  for days_back in {0..6}; do
    recent_dates+=("$(date -d "${days_back} days ago" +%Y%m%d)")
  done

  # First, check local cache
  for stamp in "${recent_dates[@]}"; do
    image_ref="rocm/sgl-dev:${base_tag}-${stamp}"
    # A failing `docker images` is treated as "not cached".
    image_id=$(docker images -q "${image_ref}") || true
    if [[ -n "$image_id" ]]; then
      echo "Found cached image locally: ${image_ref}" >&2
      echo "${image_ref}"
      return 0
    fi
  done

  # If not found locally, fall back to pulling from public registry
  for stamp in "${recent_dates[@]}"; do
    image_ref="rocm/sgl-dev:${base_tag}-${stamp}"
    echo "Checking for image: ${image_ref}" >&2
    if docker manifest inspect "${image_ref}" >/dev/null 2>&1; then
      echo "Found available image: ${image_ref}" >&2
      echo "${image_ref}"
      return 0
    fi
  done

  # If still not found, try finding any image matching ROCm+arch from remote registry
  echo "Exact version not found. Searching remote registry for any ${ROCM_VERSION}-${gpu_arch} image…" >&2
  for stamp in "${recent_dates[@]}"; do
    # `|| true`: an empty or failed query just means "no tag for this date".
    remote_tag=$(curl -s "https://registry.hub.docker.com/v2/repositories/rocm/sgl-dev/tags?page_size=100&name=${ROCM_VERSION}-${gpu_arch}-${stamp}" 2>/dev/null \
      | grep -o '"name":"[^"]*"' \
      | cut -d'"' -f4 \
      | head -n 1) || true
    if [[ -n "$remote_tag" ]]; then
      echo "Found available image: rocm/sgl-dev:${remote_tag}" >&2
      echo "rocm/sgl-dev:${remote_tag}"
      return 0
    fi
  done

  echo "No recent images found. Searching any cached local images matching ROCm+arch…" >&2
  any_local=$(docker images --format '{{.Repository}}:{{.Tag}}' --filter "reference=rocm/sgl-dev:*${ROCM_VERSION}*${gpu_arch}*" \
    | sort -r \
    | head -n 1) || true
  if [[ -n "$any_local" ]]; then
    echo "Using cached fallback image: ${any_local}" >&2
    echo "${any_local}"
    return 0
  fi

  echo "Error: no ${gpu_arch} image found in the last 7 days for base ${base_tag}" >&2
  echo "Using hard-coded fallback for ${ROCM_VERSION}" >&2
  case "${ROCM_VERSION}" in
    rocm720)
      if [[ "${gpu_arch}" == "mi35x" ]]; then
        echo "rocm/sgl-dev:v0.5.8.post1-rocm720-mi35x-20260211-preview"
      else
        echo "rocm/sgl-dev:v0.5.8.post1-rocm720-mi30x-20260211-preview"
      fi
      ;;
    rocm700)
      if [[ "${gpu_arch}" == "mi35x" ]]; then
        echo "rocm/sgl-dev:v0.5.8.post1-rocm700-mi35x-20260211"
      else
        echo "rocm/sgl-dev:v0.5.8.post1-rocm700-mi30x-20260211"
      fi
      ;;
    *)
      echo "Error: no hard-coded fallback available for ${ROCM_VERSION}" >&2
      return 1
      ;;
  esac
}
# Resolve the newest image for this runner's architecture and pull it.
IMAGE=$(find_latest_image "${GPU_ARCH}")
echo "Pulling Docker image: ${IMAGE}"
docker pull "${IMAGE}"

# Optional host cache: mounted at /sgl-data only when the directory exists
# on this runner; otherwise no volume flag is passed.
CACHE_HOST=/home/runner/sgl-data
CACHE_VOLUME=""
if [[ -d "$CACHE_HOST" ]]; then
  CACHE_VOLUME="-v $CACHE_HOST:/sgl-data"
fi
# Detect libionic library for RDMA support.
# Sets LIBIONIC_MOUNT to a read-only `-v` bind-mount option for the resolved
# library, or leaves it empty when nothing usable is found. Every lookup here
# is best-effort: a failing `readlink`/`find` yields a warning instead of
# aborting the script under `set -e` / `pipefail`.
# NOTE(review): LIBIONIC_MOUNT is not referenced by the `docker run` command
# later in this script; only MOUNT_ARGS is. Confirm whether the fallback
# library found here is meant to be mounted.
LIBIONIC_MOUNT=""
IONIC_SYMLINK="/usr/lib/x86_64-linux-gnu/libibverbs/libionic-rdmav34.so"
if [[ -L "$IONIC_SYMLINK" ]]; then
  LIBIONIC_LIB=$(readlink -f "$IONIC_SYMLINK" 2>/dev/null) || LIBIONIC_LIB=""
  if [[ -f "$LIBIONIC_LIB" ]]; then
    echo "Found libionic library: $LIBIONIC_LIB (resolved from symlink)"
    LIBIONIC_MOUNT="-v ${LIBIONIC_LIB}:${LIBIONIC_LIB}:ro"
  else
    echo "Warning: libionic symlink exists but target does not: $LIBIONIC_LIB"
  fi
else
  # Fallback: try to find directly. `|| true` keeps a missing search
  # directory (find exits non-zero) from killing the script via pipefail.
  LIBIONIC_FOUND=$(find /usr/lib/x86_64-linux-gnu -maxdepth 1 -name "libionic.so.*" 2>/dev/null | head -1) || true
  if [[ -n "$LIBIONIC_FOUND" ]]; then
    LIBIONIC_LIB=$(readlink -f "$LIBIONIC_FOUND" 2>/dev/null) || LIBIONIC_LIB=""
    if [[ -f "$LIBIONIC_LIB" ]]; then
      echo "Found libionic library: $LIBIONIC_LIB"
      LIBIONIC_MOUNT="-v ${LIBIONIC_LIB}:${LIBIONIC_LIB}:ro"
    else
      echo "Warning: libionic found but cannot resolve real path: $LIBIONIC_FOUND"
    fi
  else
    echo "Warning: libionic library not found on host, RDMA may not work"
  fi
fi
# Accumulated `-v host:container:ro` options for optional host libraries.
MOUNT_ARGS=""

#######################################
# Append a read-only bind mount for the first host library matching a pattern.
# Globals:   MOUNT_ARGS (appended to when a match is found)
# Arguments: $1 - human-readable library name, used only in messages
#            $2 - `find -name` pattern, e.g. "libnl-3.so*"
# Outputs:   a "Found ..." or "WARNING: ..." line on stdout
# Returns:   0 in both cases
#######################################
add_mount_if_exists() {
  local label=$1
  local pattern=$2
  local hit

  # Search the usual library directories and stop at the first match.
  # Directories that do not exist are ignored (errors silenced, status dropped).
  hit=$(find /lib/x86_64-linux-gnu /usr/lib/x86_64-linux-gnu /lib64 /usr/lib64 \
    -name "$pattern" -print -quit 2>/dev/null) || true

  if [[ -z "$hit" ]]; then
    echo "WARNING: Could not find $label on host! (Pattern: $pattern)"
    return 0
  fi

  echo "Found $label at: $hit"
  MOUNT_ARGS="$MOUNT_ARGS -v $hit:$hit:ro"
}
# Mount the ionic verbs driver: if the libibverbs plugin symlink exists and
# resolves to a regular file, bind-mount that real file read-only at the same
# path. Nothing is printed or mounted when the symlink is absent.
IONIC_LINK="/usr/lib/x86_64-linux-gnu/libibverbs/libionic-rdmav34.so"
if [ -L "$IONIC_LINK" ]; then
IONIC_REAL=$(readlink -f "$IONIC_LINK")
if [ -f "$IONIC_REAL" ]; then
echo "Ionic Driver: $IONIC_REAL"
MOUNT_ARGS="$MOUNT_ARGS -v $IONIC_REAL:$IONIC_REAL:ro"
fi
fi
# Also mount the host's libnl-3 and libmnl shared libraries when present
# (each call warns and mounts nothing if no file matches).
add_mount_if_exists "libnl-3" "libnl-3.so*"
add_mount_if_exists "libmnl" "libmnl.so*"
echo "Mount args: $MOUNT_ARGS"
echo "Launching container: ci_sglang"
# DEVICE_FLAG, MOUNT_ARGS and CACHE_VOLUME are expanded unquoted on purpose:
# each holds options plus arguments in one string and must split into
# separate words (empty values disappear).
# NOTE(review): --device=/dev/dri is passed explicitly AND DEVICE_FLAG
# defaults to "--device /dev/dri" earlier in this script, so the device can be
# listed twice when no pod-info file exists - confirm docker accepts that.
# The container is privileged with host networking and host IPC; host
# InfiniBand sysfs entries and libibverbs files are mounted read-only.
docker run -dt --user root \
--device=/dev/kfd \
--device=/dev/dri \
${DEVICE_FLAG} \
-v "${GITHUB_WORKSPACE:-$PWD}:/sglang-checkout" \
-v /sys/class/infiniband:/sys/class/infiniband:ro \
-v /sys/class/infiniband_verbs:/sys/class/infiniband_verbs:ro \
-v /sys/class/net:/sys/class/net:ro \
-v /etc/libibverbs.d:/etc/libibverbs.d:ro \
-v /usr/lib/x86_64-linux-gnu/libibverbs:/usr/lib/x86_64-linux-gnu/libibverbs:ro \
$MOUNT_ARGS \
$CACHE_VOLUME \
--privileged \
--network=host \
--ipc=host \
--ulimit memlock=-1 \
--cap-add=IPC_LOCK \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
--group-add video \
--group-add rdma \
--shm-size 32g \
-e HF_TOKEN="${HF_TOKEN:-}" \
-e HF_HOME=/sgl-data/hf-cache \
-e HF_HUB_ETAG_TIMEOUT=300 \
-e HF_HUB_DOWNLOAD_TIMEOUT=300 \
-e MIOPEN_USER_DB_PATH=/sgl-data/miopen-cache \
-e MIOPEN_CUSTOM_CACHE_DIR=/sgl-data/miopen-cache \
-w /sglang-checkout \
--name ci_sglang \
"${IMAGE}"
# The checkout is owned by the runner (non-root) but the container runs as
# root. Git >= 2.35.2 rejects cross-user repos; mark the mount as safe so
# setuptools-scm / vcs_versioning can resolve the package version.
docker exec ci_sglang git config --global --add safe.directory /sglang-checkout

View File

@@ -0,0 +1,151 @@
#!/usr/bin/env python3
"""
Warmup script to pre-build AITER JIT kernels.
This script triggers compilation of commonly used AITER kernels by importing
the relevant modules and calling functions with sample data. This avoids
timeouts during actual tests when kernels need to be compiled on first use.
Run this after clearing pre-built AITER kernels from the Docker image.
"""
import os
import sys
import time
# Ensure AITER is enabled: the variable is set at module import time, i.e.
# before warmup_aiter_kernels() runs and performs its own (function-local)
# imports of torch and aiter.
os.environ["SGLANG_USE_AITER"] = "1"
def warmup_aiter_kernels():
    """Pre-build commonly used AITER JIT kernels by running each one once.

    Each warmup step lazily imports its kernel from ``aiter``, calls it on
    small bfloat16 sample tensors placed on ``cuda:0`` and then synchronizes
    the device.  A step that raises is reported and the remaining steps still
    run.  When torch reports no CUDA/ROCm device the function only prints a
    notice.

    Returns:
        None.
    """
    import torch

    if not torch.cuda.is_available():
        print("CUDA/ROCm not available, skipping AITER warmup")
        return

    rule = "=" * 60
    print(rule)
    print("AITER JIT Kernel Warmup")
    print(rule)

    device = torch.device("cuda:0")
    start_time = time.time()

    def bf16(factory, *shape):
        # Build a bfloat16 tensor on the warmup device with the given factory.
        return factory(*shape, dtype=torch.bfloat16, device=device)

    def rmsnorm_small_hidden():
        # module_rmsnorm_quant (small module, ~2MB): triggered by
        # rmsnorm2d_fwd when hidden_size <= 8192.
        from aiter import rmsnorm2d_fwd

        rows, hidden = 512, 4096  # larger batch to match CUDA graph capture
        x = bf16(torch.randn, rows, hidden)
        weight = bf16(torch.ones, hidden)
        rmsnorm2d_fwd(x, weight, 1e-6)

    def rmsnorm_fused_add():
        # module_rmsnorm (large CK module, ~159MB): triggered by
        # rmsnorm2d_fwd_with_add, whose signature is
        # (out, input, residual_in, residual_out, weight, epsilon).
        from aiter import rmsnorm2d_fwd_with_add

        rows, hidden = 512, 4096
        x = bf16(torch.randn, rows, hidden)
        residual_in = bf16(torch.randn, rows, hidden)
        output = torch.empty_like(x)
        residual_out = torch.empty_like(x)
        weight = bf16(torch.ones, hidden)
        rmsnorm2d_fwd_with_add(output, x, residual_in, residual_out, weight, 1e-6)

    def rmsnorm_large_hidden():
        # hidden_size > 8192 sends rmsnorm2d_fwd down the CK path, which uses
        # module_rmsnorm as well; run it so that path is exercised too.
        from aiter import rmsnorm2d_fwd

        rows, hidden = 32, 16384
        x = bf16(torch.randn, rows, hidden)
        weight = bf16(torch.ones, hidden)
        rmsnorm2d_fwd(x, weight, 1e-6)

    def rotary():
        # Rotary embedding kernel, if this AITER build exports it.
        from aiter import rotary_embedding

        head_size, seq_len, num_heads = 128, 32, 32
        positions = torch.arange(seq_len, device=device)
        query = bf16(torch.randn, seq_len, num_heads, head_size)
        key = bf16(torch.randn, seq_len, num_heads, head_size)
        cos = bf16(torch.ones, seq_len, head_size // 2)
        sin = bf16(torch.zeros, seq_len, head_size // 2)
        rotary_embedding(positions, query, key, head_size, cos, sin, True)

    def activation():
        # Fused SiLU-and-multiply activation kernel, if available.
        from aiter import silu_and_mul

        rows, hidden = 512, 4096
        x = bf16(torch.randn, rows, hidden * 2)
        out = bf16(torch.empty, rows, hidden)
        silu_and_mul(out, x)

    # (announcement, step, success message, failure message prefix)
    steps = (
        (
            "\n[1/5] Warming up module_rmsnorm_quant (rmsnorm2d_fwd, hidden<=8192)...",
            rmsnorm_small_hidden,
            " module_rmsnorm_quant compiled successfully",
            " module_rmsnorm_quant warmup failed",
        ),
        (
            "\n[2/5] Warming up module_rmsnorm (rmsnorm2d_fwd_with_add, CK path)...",
            rmsnorm_fused_add,
            " module_rmsnorm compiled successfully",
            " module_rmsnorm warmup failed",
        ),
        (
            "\n[3/5] Warming up rmsnorm2d_fwd CK path (hidden>8192)...",
            rmsnorm_large_hidden,
            " rmsnorm2d_fwd CK path compiled successfully",
            " rmsnorm2d_fwd CK path warmup skipped",
        ),
        (
            "\n[4/5] Warming up rotary embedding kernel...",
            rotary,
            " Rotary embedding kernel compiled successfully",
            " Rotary embedding warmup skipped (may not be available)",
        ),
        (
            "\n[5/5] Warming up activation kernels...",
            activation,
            " Activation kernel compiled successfully",
            " Activation warmup skipped (may not be available)",
        ),
    )

    for announcement, run_step, ok_message, failure_prefix in steps:
        try:
            print(announcement)
            run_step()
            # Wait for the launched kernel so compile/launch errors surface
            # inside this step's error handling.
            torch.cuda.synchronize()
            print(ok_message)
        except Exception as e:
            print(f"{failure_prefix}: {e}")

    elapsed = time.time() - start_time
    print("\n" + rule)
    print(f"AITER warmup completed in {elapsed:.1f}s")
    print(rule + "\n")


if __name__ == "__main__":
    warmup_aiter_kernels()

View File

@@ -0,0 +1,27 @@
#!/bin/bash
#######################################
# Check that no ROCm GPU has more VRAM allocated than the allowed percentage.
# Globals:
#   VRAM_THRESHOLD_PERCENT (read, optional) - maximum allowed VRAM usage in
#     whole percent; defaults to 5 when unset or not a non-negative integer.
# Arguments:
#   None
# Outputs:
#   Status messages on stdout; on failure also the offending lines and the
#   full `rocm-smi --showmemuse` report.
# Returns:
#   0 if every GPU is at or below the threshold, or if rocm-smi is not
#   installed (nothing to check); 1 if any GPU exceeds the threshold.
#######################################
check_vram_clear() {
    local vram_threshold_percent="${VRAM_THRESHOLD_PERCENT:-5}"
    local high_usage

    # Fall back to the default when the override is not a plain integer.
    case "$vram_threshold_percent" in
        '' | *[!0-9]*) vram_threshold_percent=5 ;;
    esac

    if ! command -v rocm-smi >/dev/null 2>&1; then
        return 0
    fi

    echo "Checking ROCm GPU VRAM usage..."
    # Keep every "GPU Memory Allocated (VRAM%): N" line whose N is above the
    # threshold. The comparison is numeric, so the threshold variable is the
    # single source of truth. A failing rocm-smi yields no lines and is
    # treated as "nothing over the limit".
    high_usage=$(rocm-smi --showmemuse \
        | awk -v limit="$vram_threshold_percent" \
            '/GPU Memory Allocated \(VRAM%\):/ { if ($NF + 0 > limit + 0) print }') || true

    if [ -n "$high_usage" ]; then
        echo "ERROR: VRAM usage exceeds threshold (${vram_threshold_percent}%) on some GPUs:"
        echo "$high_usage"
        rocm-smi --showmemuse
        return 1
    fi

    echo "✓ VRAM usage is within acceptable limits on all GPUs"
    return 0
}
# If this script is run directly (not sourced), run the check.
# BASH_SOURCE[0] equals $0 only when the file is executed as a script; when it
# is sourced, only the function definition above is loaded and the caller's
# shell options are left untouched.
if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
# Exit on the first failing command so the check's status becomes the
# script's exit status.
set -e
check_vram_clear
fi

View File

@@ -0,0 +1,103 @@
#!/bin/bash
# Source the VRAM checking function
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "$SCRIPT_DIR/check_vram_clear.sh"
#######################################
# Free GPU memory on a ROCm host before a CI run, retrying up to three times.
#
# Removes the ci_sglang container, kills SGLang processes, and from the second
# attempt on also kills every PID listed by `rocm-smi --showpids` and waits
# 30 seconds. After each attempt check_vram_clear decides whether VRAM is free.
# If all attempts fail, diagnostic information is printed.
#
# Globals:
#   None
# Arguments:
#   None (any arguments are ignored)
# Outputs:
#   Progress and diagnostic messages on stdout.
# Returns:
#   0 once check_vram_clear succeeds, 1 if VRAM is still in use after all
#   attempts.
#######################################
ensure_vram_clear() {
    local max_retries=3
    local retry_count=0
    # Extended regex matching the command lines of SGLang processes.
    local sglang_pattern='sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt'
    local gpu_procs kfd_pids sglang_procs pid

    # Stop and remove any existing ci_sglang container
    echo "Stopping any existing ci_sglang container..."
    docker stop ci_sglang || true
    docker rm ci_sglang || true

    # Log host information for debugging
    echo "=== Host Information ==="
    echo "Hostname: $(hostname)"
    echo "Host IP: $(hostname -I 2>/dev/null || echo 'N/A')"
    echo "Date: $(date)"
    echo "Mode: rocm"
    echo "========================"
    echo "Running in ROCm mode"

    # Show initial GPU status
    echo "=== Initial GPU Memory Status ==="
    rocm-smi --showmemuse
    echo "=================================="

    while [ "$retry_count" -lt "$max_retries" ]; do
        echo "=== Cleanup Attempt $((retry_count + 1))/$max_retries ==="

        # Clean SGLang processes
        echo "Killing SGLang processes..."
        pgrep -f "$sglang_pattern" | xargs -r kill -9 || true

        if [ "$retry_count" -gt 0 ]; then
            echo "Performing aggressive cleanup..."
            # Kill all processes using KFD
            rocm-smi --showpids 2>/dev/null | grep 'PID:' | awk '{print $2}' | xargs -r kill -9 2>/dev/null || true
            # Wait a bit for cleanup to take effect
            echo "Waiting 30 seconds for VRAM to clear..."
            sleep 30
        fi

        # Check VRAM
        echo "Checking VRAM status..."
        if check_vram_clear; then
            echo "✓ VRAM cleanup successful after $((retry_count + 1)) attempts"
            return 0
        fi
        echo "✗ VRAM still not clear after attempt $((retry_count + 1))"
        retry_count=$((retry_count + 1))
    done

    # Failed after all retries
    echo "=== FAILED: VRAM cleanup unsuccessful after $max_retries attempts ==="
    echo "Final GPU status:"
    timeout 30 rocm-smi --showmemuse || echo "rocm-smi timed out"

    # Print the matching lines themselves; a quiet grep would leave this
    # section empty whenever processes are present.
    echo "Processes using GPU:"
    gpu_procs=$(rocm-smi --showpids 2>/dev/null | grep 'PID:' || true)
    if [ -n "$gpu_procs" ]; then
        echo "$gpu_procs"
    else
        echo "No processes found using /dev/kfd"
    fi

    # Print detailed information about suspicious processes
    echo "=== Detailed Process Information ==="
    if command -v rocm-smi >/dev/null 2>&1; then
        # For AMD GPUs, get processes from rocm-smi --showpids
        kfd_pids=$(rocm-smi --showpids 2>/dev/null | grep 'PID:' | awk '{print $2}' | sort -u) || true
        if [ -n "$kfd_pids" ]; then
            echo "Processes accessing /dev/kfd (AMD GPU device):"
            # Intentionally unquoted list: one PID per whitespace-separated word.
            for pid in $kfd_pids; do
                if ps -p "$pid" -o pid,ppid,cmd --no-headers 2>/dev/null; then
                    echo " └─ Command line: $(ps -p "$pid" -o cmd --no-headers 2>/dev/null | head -1)"
                else
                    echo " └─ PID $pid: Process not found or already terminated"
                fi
            done
        else
            echo "No processes found accessing /dev/kfd"
        fi
    fi

    # Check for any remaining sglang-related processes
    echo "Checking for any remaining sglang-related processes:"
    # pgrep exits non-zero when nothing matches; that is an expected outcome
    # here and must not abort the diagnostics when the caller uses `set -e`.
    sglang_procs=$(pgrep -f "$sglang_pattern" 2>/dev/null) || true
    if [ -n "$sglang_procs" ]; then
        echo "Found sglang processes still running:"
        for pid in $sglang_procs; do
            ps -p "$pid" -o pid,ppid,cmd --no-headers 2>/dev/null || echo "PID $pid not found"
        done
    else
        echo "No sglang-related processes found."
    fi

    echo "=================================================================="
    return 1
}
# If this script is run directly (not sourced), run the ensure function.
# When the file is sourced instead, only the function definitions are loaded
# and the caller decides when to invoke ensure_vram_clear.
if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
# Exit on the first unguarded failing command; the function's return status
# becomes the script's exit status.
set -e
ensure_vram_clear "$@"
fi

View File

@@ -0,0 +1,61 @@
#!/usr/bin/env python3
"""
Simple RCCL test for multi-GPU communication.
This test verifies that RCCL can initialize and communicate across multiple GPUs.
"""
import os
import sys
import torch
import torch.distributed as dist
def test_rccl_allreduce():
"""Test basic RCCL allreduce operation across all GPUs."""
if not torch.cuda.is_available():
print("CUDA not available, skipping test")
sys.exit(1)
# Initialize process group with NCCL (RCCL on AMD)
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"[Rank {rank}/{world_size}] Initialized successfully")
# Set device
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
print(f"[Rank {rank}] Device: {torch.cuda.get_device_name(device)}")
print(
f"[Rank {rank}] Device memory: {torch.cuda.get_device_properties(device).total_memory / 1e9:.2f} GB"
)
# Create a tensor and perform allreduce
tensor = torch.ones(1000, device=device) * rank
print(f"[Rank {rank}] Before allreduce: tensor sum = {tensor.sum().item()}")
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
expected_sum = sum(range(world_size)) * 1000
actual_sum = tensor.sum().item()
print(
f"[Rank {rank}] After allreduce: tensor sum = {actual_sum}, expected = {expected_sum}"
)
if abs(actual_sum - expected_sum) < 0.1:
print(f"[Rank {rank}] ✓ RCCL allreduce test PASSED")
dist.destroy_process_group()
sys.exit(0)
else:
print(f"[Rank {rank}] ✗ RCCL allreduce test FAILED")
dist.destroy_process_group()
sys.exit(1)
if __name__ == "__main__":
test_rccl_allreduce()

View File

@@ -0,0 +1,58 @@
#!/usr/bin/env python3
"""Check that required status check job names are unique across workflows.
Duplicate job names on the same commit allow a passing job in one workflow
to satisfy a required status check meant for a different workflow, bypassing
branch protection.
See: https://github.com/sgl-project/sglang/pull/20208 for an example where
pr-test-npu.yml's "pr-test-finish" job (which passed) caused GitHub to treat
the required "pr-test-finish" check (from pr-test.yml, which failed) as met.
"""
import glob
import sys
from collections import defaultdict
import yaml
# Job names used as required status checks in branch protection.
# These MUST be unique across all workflow files.
PROTECTED_JOB_NAMES = {
"pr-test-finish",
"lint",
}
def main() -> int:
    """Report protected job names that occur in more than one workflow job.

    Scans every ``.yml`` and ``.yaml`` file in ``.github/workflows`` (relative
    to the current working directory).  A job is counted for a protected name
    when either its job id or its ``name:`` value equals that name: GitHub
    reports a check under ``name`` when it is set and under the job id
    otherwise, and matching both keeps the check conservative.

    Returns:
        0 when every protected name is used by at most one job, 1 otherwise
        (after printing the offending workflow files).
    """
    workflows = sorted(
        glob.glob(".github/workflows/*.yml") + glob.glob(".github/workflows/*.yaml")
    )
    job_to_files: dict[str, list[str]] = defaultdict(list)

    for wf in workflows:
        with open(wf, encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # Empty files, non-mapping documents and workflows without a jobs
        # mapping have nothing to check.
        jobs = data.get("jobs") if isinstance(data, dict) else None
        if not isinstance(jobs, dict):
            continue
        for job_id, job_def in jobs.items():
            candidates = {job_id}
            display_name = job_def.get("name") if isinstance(job_def, dict) else None
            if isinstance(display_name, str):
                candidates.add(display_name)
            for name in sorted(candidates & PROTECTED_JOB_NAMES):
                job_to_files[name].append(wf)

    duplicates = {job: files for job, files in job_to_files.items() if len(files) > 1}
    if not duplicates:
        return 0

    print("ERROR: Required status check job names must be unique across workflows.")
    print("Duplicates allow branch protection bypass via auto-merge.\n")
    for job, files in sorted(duplicates.items()):
        print(f" Job '{job}' appears in:")
        for f in files:
            print(f" - {f}")
        print()
    print("Fix: rename the job in non-primary workflows to avoid collision.")
    return 1


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,24 @@
#!/bin/bash
# Cache and pre-install nvidia wheels that torch pins.
#
# pypi.nvidia.com returns Cache-Control: no-store, so pip re-downloads
# cudnn (~707 MB) and nvshmem (~125 MB) on every CI run. This script
# caches the wheels locally and installs them so that the subsequent
# `pip install -e "python[dev]"` sees "Requirement already satisfied".
#
# Integrity: uses `unzip -t` to detect partial/corrupt downloads.
#
# Usage: source scripts/ci/cuda/cache_nvidia_wheels.sh
# Directory that holds the cached wheels.
NVIDIA_WHEEL_CACHE="/root/.cache/nvidia-wheels"
mkdir -p "$NVIDIA_WHEEL_CACHE"

for url in \
    "https://pypi.nvidia.com/nvidia-cudnn-cu12/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl" \
    "https://pypi.nvidia.com/nvidia-nvshmem-cu12/nvidia_nvshmem_cu12-3.3.20-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"; do
    whl="$NVIDIA_WHEEL_CACHE/$(basename "$url")"

    # Reuse the cached wheel only if it is a complete, readable zip archive.
    if [ -f "$whl" ] && unzip -tq "$whl" &>/dev/null; then
        continue
    fi

    # Download next to the final path and promote the file only after it
    # passes the same integrity test, so an interrupted transfer never leaves
    # a corrupt wheel where the install step below would pick it up.
    partial="${whl}.partial"
    if curl -fL -o "$partial" "$url" && unzip -tq "$partial" &>/dev/null; then
        mv -f "$partial" "$whl"
    else
        # Caching is an optimization: warn and let the later regular install
        # fetch the package instead of failing the sourcing script.
        echo "WARNING: could not download a valid wheel from $url" >&2
        rm -f "$partial" "$whl"
    fi
done

# Best-effort install; stderr is discarded and failures are ignored on purpose
# because the later project install resolves these packages anyway.
pip install --no-deps "$NVIDIA_WHEEL_CACHE"/nvidia_cudnn_cu12-*.whl \
    "$NVIDIA_WHEEL_CACHE"/nvidia_nvshmem_cu12-*.whl 2>/dev/null || true

View File

@@ -0,0 +1,62 @@
#!/bin/bash
# Download flashinfer cubins if the local set is incomplete.
#
# The flashinfer-cubin pip package may not include cubins for newer architectures
# (e.g. sm_100, sm_120) due to PyPI size limits. This script checks the local
# cubin status against the flashinfer artifact repository and downloads any
# missing files.
#
# This script is best-effort: if the status check or download times out (e.g.
# due to a GPU in error state blocking CUDA init), we warn and continue.
# The pip package already includes cubins for common architectures (sm_80, sm_90).
set -uxo pipefail

# Best-effort cubin download with a 5 minute cap; a failure or timeout only
# produces a warning so the caller continues with the cubins already present.
fetch_missing_cubins() {
    timeout 300 env FLASHINFER_LOGGING_LEVEL=warning python3 -m flashinfer --download-cubin \
        || echo "WARNING: flashinfer cubin download failed or timed out, continuing with existing cubins"
}

# Early exit: the pip package already includes cubins for sm_80 and sm_90.
# Only sm_100+ (Blackwell) needs extra cubins downloaded, so the expensive
# Python status check is skipped when no such GPU is reported. If nvidia-smi
# fails or times out, fall through to the status check.
if gpu_caps=$(timeout 10 nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null); then
    found_sm100_plus=false
    while IFS= read -r cap_line; do
        # Compare the major compute capability; non-numeric lines fail the
        # test quietly and are ignored.
        if [ "${cap_line%%.*}" -ge 10 ] 2>/dev/null; then
            found_sm100_plus=true
            break
        fi
    done <<< "$gpu_caps"

    if [ "$found_sm100_plus" = false ]; then
        echo "All GPUs are sm_9x or older (compute caps: $(echo $gpu_caps | tr '\n' ' ')), pip cubins sufficient — skipping download"
        exit 0
    fi
fi

# Use timeout to prevent hangs when GPUs are in error state (the flashinfer
# import can trigger CUDA init which blocks on bad GPUs). The helper prints
# "<present>/<total>"; any failure is recorded as "unknown".
cubin_status=$(timeout 60 python3 -c "
import os
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '')
from flashinfer.artifacts import get_artifacts_status
status = get_artifacts_status()
total = len(status)
downloaded = sum(1 for _, exists in status if exists)
print(f'{downloaded}/{total}')
" 2>/dev/null) || cubin_status="unknown"
echo "Flashinfer cubin status: ${cubin_status}"

if ! echo "$cubin_status" | grep -qE '^[0-9]+/[0-9]+$'; then
    echo "Could not determine cubin status (status check timed out or failed), attempting download..."
    fetch_missing_cubins
elif [ "${cubin_status%/*}" = "${cubin_status#*/}" ] && [ "${cubin_status#*/}" != "0" ]; then
    echo "All flashinfer cubins already present (${cubin_status}), skipping download"
else
    echo "Cubins incomplete (${cubin_status}), downloading..."
    fetch_missing_cubins
fi

View File

@@ -0,0 +1,69 @@
#!/bin/bash
# Install flashinfer-jit-cache with caching and retry logic (flashinfer.ai can have transient DNS issues).
# The jit-cache wheel is 1.2+ GB, so we skip the download entirely if already installed.
#
# Required environment (caller must export or set):
# UNINSTALL_JIT_CACHE — literal true/false (skip download when false)
# FLASHINFER_PYTHON_REQUIRED — e.g. from python/pyproject.toml (flashinfer_python)
# CU_VERSION — e.g. cu129
# PIP_CMD — e.g. "pip" or "uv pip"
# PIP_INSTALL_SUFFIX — extra pip args for this runner
set -euxo pipefail

: "${UNINSTALL_JIT_CACHE:?must be set}"
: "${FLASHINFER_PYTHON_REQUIRED:?must be set}"
: "${CU_VERSION:?must be set}"
: "${PIP_CMD:?must be set}"
# PIP_INSTALL_SUFFIX only carries optional extra arguments and may be empty;
# default it so `set -u` does not abort with an "unbound variable" error.
PIP_INSTALL_SUFFIX="${PIP_INSTALL_SUFFIX:-}"

MAX_DOWNLOAD_ATTEMPTS=5
RETRY_DELAY_SECONDS=10

# The caller passes UNINSTALL_JIT_CACHE=false when the installed package
# already has the required version; nothing needs to be fetched then.
FLASHINFER_JIT_CACHE_INSTALLED=false
if [ "$UNINSTALL_JIT_CACHE" = false ]; then
    FLASHINFER_JIT_CACHE_INSTALLED=true
    echo "flashinfer-jit-cache already at correct version, skipping download"
fi

if [ "$FLASHINFER_JIT_CACHE_INSTALLED" = false ]; then
    FLASHINFER_CACHE_DIR="${HOME}/.cache/flashinfer-wheels"
    mkdir -p "${FLASHINFER_CACHE_DIR}"
    FLASHINFER_WHEEL_PATTERN="flashinfer_jit_cache-${FLASHINFER_PYTHON_REQUIRED}*.whl"

    # First choice: a wheel left in the cache directory by an earlier run.
    CACHED_WHEEL=$(find "${FLASHINFER_CACHE_DIR}" -name "${FLASHINFER_WHEEL_PATTERN}" -type f 2>/dev/null | head -n 1)
    if [ -n "$CACHED_WHEEL" ] && [ -f "$CACHED_WHEEL" ]; then
        echo "Found cached flashinfer wheel: $CACHED_WHEEL"
        # PIP_CMD / PIP_INSTALL_SUFFIX are deliberately unquoted: they hold
        # several words (e.g. "uv pip") that must be split into arguments.
        # shellcheck disable=SC2086
        if $PIP_CMD install "$CACHED_WHEEL" $PIP_INSTALL_SUFFIX; then
            FLASHINFER_JIT_CACHE_INSTALLED=true
            echo "Successfully installed flashinfer-jit-cache from cache"
        else
            echo "Failed to install from cache, will try downloading..."
            rm -f "$CACHED_WHEEL"
        fi
    fi

    if [ "$FLASHINFER_JIT_CACHE_INSTALLED" = false ]; then
        for ((i = 1; i <= MAX_DOWNLOAD_ATTEMPTS; i++)); do
            # Download wheel to cache directory (use pip directly as uv pip doesn't support download)
            if timeout 600 pip download "flashinfer-jit-cache==${FLASHINFER_PYTHON_REQUIRED}" \
                --index-url "https://flashinfer.ai/whl/${CU_VERSION}" \
                -d "${FLASHINFER_CACHE_DIR}"; then
                CACHED_WHEEL=$(find "${FLASHINFER_CACHE_DIR}" -name "${FLASHINFER_WHEEL_PATTERN}" -type f 2>/dev/null | head -n 1)
                if [ -n "$CACHED_WHEEL" ] && [ -f "$CACHED_WHEEL" ]; then
                    # shellcheck disable=SC2086
                    if $PIP_CMD install "$CACHED_WHEEL" $PIP_INSTALL_SUFFIX; then
                        FLASHINFER_JIT_CACHE_INSTALLED=true
                        echo "Successfully downloaded and installed flashinfer-jit-cache"
                        break
                    fi
                    # Same policy as the cache branch above: a wheel that does
                    # not install is dropped so the next attempt downloads a
                    # fresh copy instead of reusing the same file.
                    rm -f "$CACHED_WHEEL"
                else
                    echo "Warning: Download succeeded but wheel file not found"
                fi
            fi

            if [ "$i" -lt "$MAX_DOWNLOAD_ATTEMPTS" ]; then
                echo "Attempt $i to download flashinfer-jit-cache failed, retrying in ${RETRY_DELAY_SECONDS} seconds..."
                sleep "$RETRY_DELAY_SECONDS"
            else
                # Nothing follows the last attempt, so do not announce a
                # retry or wait for one.
                echo "Attempt $i to download flashinfer-jit-cache failed"
            fi
        done
    fi
fi

if [ "$FLASHINFER_JIT_CACHE_INSTALLED" = false ]; then
    echo "ERROR: Failed to install flashinfer-jit-cache after ${MAX_DOWNLOAD_ATTEMPTS} attempts"
    exit 1
fi

View File

@@ -0,0 +1,119 @@
#!/bin/bash
# Install the dependency in CI.
# Installs DeepEP and its prerequisites (base CI dependencies, RDMA-related
# system packages, GDRCopy v2.5.1, libfabric) on a CI runner. Exits early when
# deep_ep can already be imported. Set GRACE_BLACKWELL=1 to build the fork
# used for GB200 instead of the pinned upstream commit.
#
# Abort on errors, unset variables and failed pipeline stages; trace commands.
set -euxo pipefail
# Install the common CI dependencies first (path is relative to the current
# working directory).
bash scripts/ci/cuda/ci_install_dependency.sh
export GDRCOPY_HOME=/usr/src/gdrdrv-2.5.1/
export CUDA_HOME=/usr/local/cuda
# Defaults to 0 (regular build) unless the caller sets it.
GRACE_BLACKWELL=${GRACE_BLACKWELL:-0}
# Detect architecture
ARCH=$(uname -m)
if [ "$ARCH" != "x86_64" ] && [ "$ARCH" != "aarch64" ]; then
echo "Unsupported architecture: $ARCH"
exit 1
fi
# Nothing to do when the module already imports.
if python3 -c "import deep_ep" >/dev/null 2>&1; then
echo "deep_ep is already installed or importable. Skipping installation."
exit 0
fi
# Install system dependencies
# Use fallback logic in case apt fails due to unrelated broken packages on the runner
DEEPEP_SYSTEM_DEPS="curl wget git sudo rdma-core infiniband-diags openssh-server perftest libibumad3 libibverbs-dev libibverbs1 ibverbs-providers ibverbs-utils libnl-3-200 libnl-route-3-200 librdmacm1 build-essential cmake"
# The list is expanded unquoted on purpose so each package becomes its own
# argument. If apt-get fails, accept the situation only when dpkg reports every
# package as installed ("ii").
apt-get install -y --no-install-recommends $DEEPEP_SYSTEM_DEPS || {
echo "Warning: apt-get install failed, checking if required packages are available..."
for pkg in $DEEPEP_SYSTEM_DEPS; do
if ! dpkg -l "$pkg" 2>/dev/null | grep -q "^ii"; then
echo "ERROR: Required package $pkg is not installed and apt-get failed"
exit 1
fi
done
echo "All required packages are already installed, continuing..."
}
# Install GDRCopy
# Start from an empty checkout directory and pin the v2.5.1 tag.
rm -rf /opt/gdrcopy && mkdir -p /opt/gdrcopy
cd /opt/gdrcopy
git clone https://github.com/NVIDIA/gdrcopy.git .
git checkout v2.5.1
apt-get update || true # May fail due to unrelated broken packages
# Build dependencies are installed group by group, each with the same
# "apt failed but everything is already installed" fallback as above.
GDRCOPY_DEPS_1="nvidia-dkms-580"
GDRCOPY_DEPS_2="build-essential devscripts debhelper fakeroot pkg-config dkms"
GDRCOPY_DEPS_3="check libsubunit0 libsubunit-dev python3-venv"
for deps_group in "$GDRCOPY_DEPS_1" "$GDRCOPY_DEPS_2" "$GDRCOPY_DEPS_3"; do
apt-get install -y --no-install-recommends $deps_group || {
echo "Warning: apt-get install failed for '$deps_group', checking if packages are available..."
for pkg in $deps_group; do
if ! dpkg -l "$pkg" 2>/dev/null | grep -q "^ii"; then
echo "ERROR: Required package $pkg is not installed and apt-get failed"
exit 1
fi
done
echo "All required packages from '$deps_group' are already installed, continuing..."
}
done
# Build the GDRCopy .deb packages from the checkout and install them.
cd packages
CUDA=/usr/local/cuda ./build-deb-packages.sh
dpkg -i gdrdrv-dkms_*.deb
dpkg -i libgdrapi_*.deb
dpkg -i gdrcopy-tests_*.deb
dpkg -i gdrcopy_*.deb
# Set up library paths based on architecture
LIB_PATH="/usr/lib/$ARCH-linux-gnu"
# Create the unversioned libmlx5.so name as a symlink to libmlx5.so.1 when it
# does not exist yet.
if [ ! -e "$LIB_PATH/libmlx5.so" ]; then
ln -s $LIB_PATH/libmlx5.so.1 $LIB_PATH/libmlx5.so
fi
apt-get update || true
# Same fallback as for the other apt installs, for a single package.
apt-get install -y --no-install-recommends libfabric-dev || {
if ! dpkg -l libfabric-dev 2>/dev/null | grep -q "^ii"; then
echo "ERROR: Required package libfabric-dev is not installed and apt-get failed"
exit 1
fi
echo "libfabric-dev is already installed, continuing..."
}
# Install DeepEP
# Always build from a fresh clone.
DEEPEP_DIR=/root/.cache/deepep
rm -rf ${DEEPEP_DIR}
if [ "$GRACE_BLACKWELL" = "1" ]; then
# We use Tom's DeepEP fork for GB200 for now, which supports fp4 dispatch.
# After checking out the branch, the sed below rewrites NUM_CPU_TIMEOUT_SECS
# from 100 to 1000 in csrc/kernels/configs.cuh.
GRACE_BLACKWELL_DEEPEP_BRANCH=gb200_blog_part_2
git clone https://github.com/fzyzcjy/DeepEP.git ${DEEPEP_DIR} && \
pushd ${DEEPEP_DIR} && \
git checkout ${GRACE_BLACKWELL_DEEPEP_BRANCH} && \
sed -i 's/#define NUM_CPU_TIMEOUT_SECS 100/#define NUM_CPU_TIMEOUT_SECS 1000/' csrc/kernels/configs.cuh && \
popd
else
# Regular runners build a pinned upstream commit.
git clone https://github.com/deepseek-ai/DeepEP.git ${DEEPEP_DIR} && \
pushd ${DEEPEP_DIR} && \
git checkout 9af0e0d0e74f3577af1979c9b9e1ac2cad0104ee && \
popd
fi
cd ${DEEPEP_DIR}
if [ "$GRACE_BLACKWELL" = "1" ]; then
# Take the CUDA version from the first "CUDA Version" line of nvidia-smi
# (9th whitespace-separated field).
CUDA_VERSION=$(nvidia-smi | grep "CUDA Version" | head -n1 | awk '{print $9}')
if [ "$CUDA_VERSION" = "12.8" ]; then
CHOSEN_TORCH_CUDA_ARCH_LIST='10.0'
elif awk -v ver="$CUDA_VERSION" 'BEGIN {exit !(ver > 12.8)}'; then
# With cuda > 12.8, the compiler supports 10.3, so we should use
# CHOSEN_TORCH_CUDA_ARCH_LIST='10.0;10.3'
#
# However, our CI machine has a weird setup and nvidia-smi reports wrong CUDA version in the container.
# The container is actually cuda 12.8, but nvidia-smi reports 13.0, leading to compilation errors. so we
# drop 10.3.
CHOSEN_TORCH_CUDA_ARCH_LIST='10.0'
else
echo "Unsupported CUDA version for Grace Blackwell: $CUDA_VERSION" && exit 1
fi && \
if [ "${CUDA_VERSION%%.*}" = "13" ]; then \
sed -i "/^ include_dirs = \['csrc\/'\]/a\ include_dirs.append('${CUDA_HOME}/include/cccl')" setup.py; \
fi
# The statement above also patches setup.py when the reported major version is
# 13: it appends ${CUDA_HOME}/include/cccl to include_dirs.
# Build and install with the architecture list chosen above.
TORCH_CUDA_ARCH_LIST="${CHOSEN_TORCH_CUDA_ARCH_LIST}" pip install --no-build-isolation .
else
python3 setup.py install
fi

View File

@@ -0,0 +1,384 @@
#!/bin/bash
# Install the dependency in CI.
#
# Structure (see section banners below):
# - Configuration & timing
# - Host / runner detection (arch, Blackwell, pip vs uv)
# - Kill existing processes
# - Install apt packages
# - Python package site hygiene & install protoc
# - Pip / uv toolchain & stale package cleanup
# - Uninstall Flashinfer
# - Install main package
# - Install sglang-kernel
# - Install sglang-router
# - Download flashinfer artifacts
# - Install extra dependency
# - Fix other dependencies
# - Prepare runner
# - Verify imports
set -euxo pipefail
# ------------------------------------------------------------------------------
# Configuration & timing
# ------------------------------------------------------------------------------
# Set up environment variables
CU_VERSION="cu129"
# Nvidia package versions we override (torch pins older versions).
# Used both as pip constraints during install and for post-install verification.
NVIDIA_CUDNN_VERSION="9.16.0.29"
NVIDIA_NVSHMEM_VERSION="3.4.5"
OPTIONAL_DEPS="${1:-}"
SECONDS=0
_CI_MARK_PREV=${SECONDS}
#######################################
# Print a timing line for a finished CI step.
# Globals:
#   SECONDS       (read)  - seconds elapsed according to bash's timer
#   _CI_MARK_PREV (read, written) - SECONDS value at the previous mark
# Arguments:
#   $1 - label of the step that just finished
# Outputs:
#   One "[STEP DONE]" line (preceded by a blank line) on stdout with the step
#   duration, the total elapsed time and the current UTC timestamp.
#######################################
mark_step_done() {
    local step_name=$1
    # Read the timer once so the step and total figures are consistent.
    local total_s=${SECONDS}
    printf '\n[STEP DONE] %s, step: %ss, total: %ss, date: %s\n' \
        "${step_name}" \
        "$((total_s - _CI_MARK_PREV))" \
        "${total_s}" \
        "$(date -u '+%Y-%m-%dT%H:%M:%SZ')"
    # The next step is measured from this point.
    _CI_MARK_PREV=${total_s}
}
mark_step_done "Configuration"
# ------------------------------------------------------------------------------
# Host / runner detection (CPU arch, Blackwell, USE_UV)
# ------------------------------------------------------------------------------
# Detect CPU architecture (x86_64 or aarch64)
ARCH=$(uname -m)
echo "Detected architecture: ${ARCH}"

# Detect GPU architecture (blackwell or not).
# An IS_BLACKWELL value supplied by the environment wins; 1/true/yes in any
# letter case mean "Blackwell", everything else means "not Blackwell". The
# value is lower-cased first so it is interpreted the same way as USE_UV below.
if [ "${IS_BLACKWELL+set}" = set ]; then
    case "$(printf '%s' "$IS_BLACKWELL" | tr '[:upper:]' '[:lower:]')" in
        1 | true | yes) IS_BLACKWELL=1 ;;
        *) IS_BLACKWELL=0 ;;
    esac
    echo "IS_BLACKWELL=${IS_BLACKWELL} (manually set via environment)"
else
    IS_BLACKWELL=0
    if command -v nvidia-smi >/dev/null 2>&1; then
        # One compute capability per GPU; a major version of 10 or higher is
        # treated as Blackwell. Non-numeric lines fail the test quietly.
        while IFS= read -r cap; do
            major="${cap%%.*}"
            if [ "${major:-0}" -ge 10 ] 2>/dev/null; then
                IS_BLACKWELL=1
                break
            fi
        done <<< "$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null || true)"
    fi
    echo "IS_BLACKWELL=${IS_BLACKWELL} (auto-detected via nvidia-smi)"
fi

# Whether to use pip or uv to install dependencies
if [ "${USE_UV+set}" != set ]; then
    if [ "$IS_BLACKWELL" = "1" ]; then
        # Our current b200 runners have some issues with uv, so we default to pip
        # It is a runner specific issue, not a GPU architecture issue.
        USE_UV=false
    else
        USE_UV=true
    fi
fi
# Normalize to 1/0 (1/true/yes in any letter case enable uv).
case "$(printf '%s' "$USE_UV" | tr '[:upper:]' '[:lower:]')" in
    1 | true | yes) USE_UV=1 ;;
    *) USE_UV=0 ;;
esac
echo "USE_UV=${USE_UV}"
mark_step_done "Host / runner detection"
# ------------------------------------------------------------------------------
# Kill existing processes
# ------------------------------------------------------------------------------
# Resolve this script's directory and the repository root (three levels up).
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"

# Capture the exit status without tripping `set -e`: a bare failing command
# would terminate the script before the status could be read, so the
# diagnostic below would never be printed.
KILLALL_EXIT=0
python3 "${REPO_ROOT}/python/sglang/cli/killall.py" || KILLALL_EXIT=$?
echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-}"
if [ "$KILLALL_EXIT" -ne 0 ]; then
    echo "ERROR: killall.py detected uncleanable GPU memory. Aborting CI."
    exit 1
fi
mark_step_done "Kill existing processes"
# ------------------------------------------------------------------------------
# Install apt packages
# ------------------------------------------------------------------------------
# Install apt packages (including python3/pip which may be missing on some runners)
# Use --no-install-recommends and ignore errors from unrelated broken packages on the runner
# The NVIDIA driver packages may have broken dependencies that are unrelated to these packages
# Run apt-get update first to refresh package index (stale index causes 404 on security.ubuntu.com)
apt-get update || true
CI_APT_PACKAGES=(
python3 python3-pip python3-venv python3-dev git libnuma-dev libssl-dev pkg-config
libibverbs-dev libibverbs1 ibverbs-providers ibverbs-utils
ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev
)
apt-get install -y --no-install-recommends "${CI_APT_PACKAGES[@]}" || {
echo "Warning: apt-get install failed, checking if required packages are available..."
for pkg in "${CI_APT_PACKAGES[@]}"; do
if ! dpkg -l "$pkg" 2>/dev/null | grep -q "^ii"; then
echo "ERROR: Required package $pkg is not installed and apt-get failed"
exit 1
fi
done
echo "All required packages are already installed, continuing..."
}
mark_step_done "Install apt packages"
# ------------------------------------------------------------------------------
# Python package site hygiene & install protoc
# ------------------------------------------------------------------------------
# Clear torch compilation cache
python3 -c 'import os, shutil, tempfile, getpass; cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR") or os.path.join(tempfile.gettempdir(), "torchinductor_" + getpass.getuser()); shutil.rmtree(cache_dir, ignore_errors=True)'
# Remove broken dist-info directories (missing METADATA per PEP 376)
SITE_PACKAGES=$(python3 -c "import site; print(site.getsitepackages()[0])")
if [ -d "$SITE_PACKAGES" ]; then
{ set +x; } 2>/dev/null
find "$SITE_PACKAGES" -maxdepth 1 -name "*.dist-info" -type d | while read -r d; do
if [ ! -f "$d/METADATA" ]; then
echo "Removing broken dist-info: $d"
rm -rf "$d"
fi
done
set -x
fi
# Install protoc
bash "${SCRIPT_DIR}/../utils/install_protoc.sh"
mark_step_done "Python package site hygiene & install protoc"
# ------------------------------------------------------------------------------
# Pip / uv toolchain & stale package cleanup
# ------------------------------------------------------------------------------
# Bootstrap pip through the interpreter (some runners only ship pip3, not pip)
python3 -m pip install --upgrade pip
# Select the installer front-end. The *_SUFFIX variables hold the flags that
# later install/uninstall invocations append after the package names.
case "$USE_UV" in
    0)
        PIP_CMD="pip"
        PIP_INSTALL_SUFFIX="--break-system-packages"
        PIP_UNINSTALL_CMD="pip uninstall -y"
        PIP_UNINSTALL_SUFFIX="--break-system-packages"
        ;;
    *)
        pip install uv
        export UV_SYSTEM_PYTHON=true
        PIP_CMD="uv pip"
        PIP_INSTALL_SUFFIX="--index-strategy unsafe-best-match --prerelease allow"
        PIP_UNINSTALL_CMD="uv pip uninstall"
        PIP_UNINSTALL_SUFFIX=""
        ;;
esac
# Best-effort removal of packages left behind by earlier runs
$PIP_UNINSTALL_CMD sgl-kernel sglang-kernel sglang sgl-fa4 flash-attn-4 $PIP_UNINSTALL_SUFFIX || true
mark_step_done "Pip / uv toolchain & stale package cleanup"
# ------------------------------------------------------------------------------
# Uninstall Flashinfer
# ------------------------------------------------------------------------------
# The large flashinfer wheels are kept when the installed version already
# matches the pinned one, so they do not have to be downloaded again:
#   - flashinfer-cubin: 150+ MB, plus extra cubins from ci_download_flashinfer_cubin.sh
#   - flashinfer-jit-cache: 1.2+ GB, by far the largest download in CI
# Required versions come from python/pyproject.toml; installed ones from pip.
FLASHINFER_PYTHON_REQUIRED=$(grep -Po -m1 '(?<=flashinfer_python==)[0-9A-Za-z\.\-]+' python/pyproject.toml || echo "")
FLASHINFER_CUBIN_REQUIRED=$(grep -Po -m1 '(?<=flashinfer_cubin==)[0-9A-Za-z\.\-]+' python/pyproject.toml || echo "")
FLASHINFER_CUBIN_INSTALLED=$(pip show flashinfer-cubin 2>/dev/null | grep "^Version:" | awk '{print $2}' || echo "")
FLASHINFER_JIT_INSTALLED=$(pip show flashinfer-jit-cache 2>/dev/null | grep "^Version:" | awk '{print $2}' | sed 's/+.*//' || echo "")

# Default to uninstalling; flip to false only on an exact, non-empty match
UNINSTALL_CUBIN=true
UNINSTALL_JIT_CACHE=true
if [ -n "$FLASHINFER_CUBIN_REQUIRED" ] && [ "$FLASHINFER_CUBIN_INSTALLED" = "$FLASHINFER_CUBIN_REQUIRED" ]; then
    echo "flashinfer-cubin==${FLASHINFER_CUBIN_REQUIRED} already installed, keeping it"
    UNINSTALL_CUBIN=false
else
    echo "flashinfer-cubin version mismatch (installed: ${FLASHINFER_CUBIN_INSTALLED:-none}, required: ${FLASHINFER_CUBIN_REQUIRED}), reinstalling"
fi
if [ -n "$FLASHINFER_PYTHON_REQUIRED" ] && [ "$FLASHINFER_JIT_INSTALLED" = "$FLASHINFER_PYTHON_REQUIRED" ]; then
    echo "flashinfer-jit-cache==${FLASHINFER_PYTHON_REQUIRED} already installed, keeping it"
    UNINSTALL_JIT_CACHE=false
else
    echo "flashinfer-jit-cache version mismatch (installed: ${FLASHINFER_JIT_INSTALLED:-none}, required: ${FLASHINFER_PYTHON_REQUIRED}), will reinstall"
fi

# flashinfer-python is always removed; the two big wheels only when stale
FLASHINFER_UNINSTALL="flashinfer-python"
if [ "$UNINSTALL_CUBIN" = true ]; then
    FLASHINFER_UNINSTALL="$FLASHINFER_UNINSTALL flashinfer-cubin"
fi
if [ "$UNINSTALL_JIT_CACHE" = true ]; then
    FLASHINFER_UNINSTALL="$FLASHINFER_UNINSTALL flashinfer-jit-cache"
fi
$PIP_UNINSTALL_CMD $FLASHINFER_UNINSTALL $PIP_UNINSTALL_SUFFIX || true
$PIP_UNINSTALL_CMD opencv-python opencv-python-headless $PIP_UNINSTALL_SUFFIX || true
mark_step_done "Uninstall Flashinfer"
# ------------------------------------------------------------------------------
# Install main package
# ------------------------------------------------------------------------------
# Base extras are always installed; OPTIONAL_DEPS (when non-empty) is appended
EXTRAS="dev,runai,tracing"
[ -z "$OPTIONAL_DEPS" ] || EXTRAS="${EXTRAS},${OPTIONAL_DEPS}"
echo "Installing python extras: [${EXTRAS}]"
# Sourced (not executed) so it runs in this shell before the install below
source "$(dirname "$0")/cache_nvidia_wheels.sh"
# Editable install of ./python, with the PyTorch wheel index for this CUDA version
$PIP_CMD install -e "python[${EXTRAS}]" --extra-index-url https://download.pytorch.org/whl/${CU_VERSION} $PIP_INSTALL_SUFFIX
mark_step_done "Install main package"
# ------------------------------------------------------------------------------
# Install sglang-kernel
# ------------------------------------------------------------------------------
# Two version sources: the kernel's own pyproject (used for the locally built
# wheel) and the pin in python/pyproject.toml (used for the PyPI install).
SGL_KERNEL_VERSION_FROM_KERNEL=$(grep -Po '(?<=^version = ")[^"]*' sgl-kernel/pyproject.toml)
SGL_KERNEL_VERSION_FROM_SRT=$(grep -Po -m1 '(?<=sglang-kernel==)[0-9A-Za-z\.\-]+' python/pyproject.toml)
echo "SGL_KERNEL_VERSION_FROM_KERNEL=${SGL_KERNEL_VERSION_FROM_KERNEL} SGL_KERNEL_VERSION_FROM_SRT=${SGL_KERNEL_VERSION_FROM_SRT}"
if [ "${CUSTOM_BUILD_SGL_KERNEL:-}" = "true" ]; then
    if [ ! -d "sgl-kernel/dist" ]; then
        # Custom build requested but no artifacts (e.g. stage rerun without the
        # wheel build). Fail rather than silently testing the PyPI kernel.
        echo "ERROR: CUSTOM_BUILD_SGL_KERNEL=true but sgl-kernel/dist not found."
        echo "This usually happens when rerunning a stage without the sgl-kernel-build-wheels job."
        echo "Please re-run the full workflow using /tag-and-rerun-ci to rebuild the kernel."
        exit 1
    fi
    ls -alh sgl-kernel/dist
    # Map the machine architecture onto the wheel's platform tag
    case "$ARCH" in
        aarch64|arm64) WHEEL_ARCH="aarch64" ;;
        *) WHEEL_ARCH="x86_64" ;;
    esac
    $PIP_CMD install sgl-kernel/dist/sglang_kernel-${SGL_KERNEL_VERSION_FROM_KERNEL}-cp310-abi3-manylinux2014_${WHEEL_ARCH}.whl --force-reinstall $PIP_INSTALL_SUFFIX
elif [ "$IS_BLACKWELL" = "1" ]; then
    # On Blackwell machines, skip the reinstall when the pinned version is
    # already present, to avoid race conditions
    INSTALLED_SGL_KERNEL=$(pip show sglang-kernel 2>/dev/null | grep "^Version:" | awk '{print $2}' || echo "")
    if [ "$INSTALLED_SGL_KERNEL" != "$SGL_KERNEL_VERSION_FROM_SRT" ]; then
        echo "Installing sglang-kernel==${SGL_KERNEL_VERSION_FROM_SRT} (current: ${INSTALLED_SGL_KERNEL:-none})"
        $PIP_CMD install sglang-kernel==${SGL_KERNEL_VERSION_FROM_SRT} $PIP_INSTALL_SUFFIX
    else
        echo "sglang-kernel==${SGL_KERNEL_VERSION_FROM_SRT} already installed, skipping reinstall"
    fi
else
    $PIP_CMD install sglang-kernel==${SGL_KERNEL_VERSION_FROM_SRT} --force-reinstall $PIP_INSTALL_SUFFIX
fi
mark_step_done "Install sglang-kernel"
# ------------------------------------------------------------------------------
# Install sglang-router
# ------------------------------------------------------------------------------
# Install router for pd-disagg test
# ($PIP_CMD / $PIP_INSTALL_SUFFIX are left unquoted on purpose: they may hold
# several words, e.g. "uv pip", that must split into separate arguments.)
$PIP_CMD install sglang-router $PIP_INSTALL_SUFFIX
# Show current packages
$PIP_CMD list
mark_step_done "Install sglang-router"
# ------------------------------------------------------------------------------
# Download flashinfer artifacts
# ------------------------------------------------------------------------------
# Download flashinfer jit cache
# The helper runs in a child bash process, so the values it needs are passed
# explicitly as per-command environment assignments.
UNINSTALL_JIT_CACHE="$UNINSTALL_JIT_CACHE" \
FLASHINFER_PYTHON_REQUIRED="$FLASHINFER_PYTHON_REQUIRED" \
CU_VERSION="$CU_VERSION" \
PIP_CMD="$PIP_CMD" \
PIP_INSTALL_SUFFIX="$PIP_INSTALL_SUFFIX" \
bash "${SCRIPT_DIR}/ci_download_flashinfer_jit_cache.sh"
# Download flashinfer cubins
bash "${SCRIPT_DIR}/ci_download_flashinfer_cubin.sh"
mark_step_done "Download flashinfer artifacts"
# ------------------------------------------------------------------------------
# Install extra dependency
# ------------------------------------------------------------------------------
# Install other python dependencies
# The NVRTC wheel name depends on the CUDA major version: the cu130 build uses
# the unsuffixed package, everything else uses the -cu12 variant.
if [ "$CU_VERSION" = "cu130" ]; then
    NVRTC_SPEC="nvidia-cuda-nvrtc"
else
    NVRTC_SPEC="nvidia-cuda-nvrtc-cu12"
fi
$PIP_CMD install mooncake-transfer-engine==0.3.10.post1 "${NVRTC_SPEC}" py-spy scipy huggingface_hub[hf_xet] pytest $PIP_INSTALL_SUFFIX
# Install other test dependencies
if [ "$IS_BLACKWELL" != "1" ]; then
    # For lmms_evals evaluating MMMU
    git clone --branch v0.5 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
    $PIP_CMD install -e lmms-eval/ $PIP_INSTALL_SUFFIX
fi
# Remove xformers (best effort). Use the shared uninstall command and suffix so
# that plain-pip mode gets "-y" and "--break-system-packages"; a bare
# "pip uninstall" would wait for confirmation and the "|| true" would hide the
# resulting failure, leaving xformers installed.
$PIP_UNINSTALL_CMD xformers $PIP_UNINSTALL_SUFFIX || true
mark_step_done "Install extra dependency"
# ------------------------------------------------------------------------------
# Fix other dependencies
# ------------------------------------------------------------------------------
# torch and torchaudio/torchvision must be built against the same CUDA version.
# PyPI's torch 2.9.1 bundles cu128 while torchaudio from pytorch.org/cu129 uses
# cu129; with that mismatch torchaudio's C extension fails to load with:
#   "partially initialized module 'torchaudio' has no attribute 'lib'"
# Replacing torch with a cu129 build would break the sgl_kernel ABI, so
# torchaudio/torchvision are reinstalled from the index matching torch instead.
TORCH_CUDA_VER=$(python3 -c "import torch; v=torch.version.cuda; parts=v.split('.'); print(f'cu{parts[0]}{parts[1]}')")
echo "Detected torch CUDA version: ${TORCH_CUDA_VER}"
if [ "${TORCH_CUDA_VER}" != "${CU_VERSION}" ]; then
    # Keep the versions pyproject.toml resolved, minus any +cuXYZ local suffix
    TORCHAUDIO_VER=$(pip show torchaudio 2>/dev/null | grep "^Version:" | awk '{print $2}' | sed 's/+.*//')
    TORCHVISION_VER=$(pip show torchvision 2>/dev/null | grep "^Version:" | awk '{print $2}' | sed 's/+.*//')
    echo "Reinstalling torchaudio==${TORCHAUDIO_VER} torchvision==${TORCHVISION_VER} from ${TORCH_CUDA_VER} index to match torch..."
    $PIP_CMD install "torchaudio==${TORCHAUDIO_VER}" "torchvision==${TORCHVISION_VER}" --index-url "https://download.pytorch.org/whl/${TORCH_CUDA_VER}" --force-reinstall --no-deps $PIP_INSTALL_SUFFIX
fi

# DeepEP depends on nvshmem 3.4.5. Install only when the pinned version is not
# already present (avoids pip races / wasted work).
INSTALLED_NVSHMEM=$(pip show nvidia-nvshmem-cu12 2>/dev/null | grep "^Version:" | awk '{print $2}' || echo "")
if [ "$INSTALLED_NVSHMEM" != "$NVIDIA_NVSHMEM_VERSION" ]; then
    $PIP_CMD install nvidia-nvshmem-cu12==${NVIDIA_NVSHMEM_VERSION} $PIP_INSTALL_SUFFIX
else
    echo "nvidia-nvshmem-cu12==${NVIDIA_NVSHMEM_VERSION} already installed, skipping reinstall"
fi

# cuDNN older than 9.16.0.29 causes a performance regression in the Conv3D
# kernel, so pin it the same way.
INSTALLED_CUDNN=$(pip show nvidia-cudnn-cu12 2>/dev/null | grep "^Version:" | awk '{print $2}' || echo "")
if [ "$INSTALLED_CUDNN" != "$NVIDIA_CUDNN_VERSION" ]; then
    $PIP_CMD install nvidia-cudnn-cu12==${NVIDIA_CUDNN_VERSION} $PIP_INSTALL_SUFFIX
else
    echo "nvidia-cudnn-cu12==${NVIDIA_CUDNN_VERSION} already installed, skipping reinstall"
fi
mark_step_done "Fix other dependencies"
# Force reinstall nvidia-cutlass-dsl to ensure the .pth file exists.
# The Docker image ships nvidia-cutlass-dsl-libs-base 4.3.5; upgrading to 4.4.2
# can delete the .pth file without reliably recreating it (pip race condition).
# Best effort: a failure here is ignored and surfaces later in "Verify imports".
$PIP_CMD install "nvidia-cutlass-dsl>=4.4.1" "nvidia-cutlass-dsl-libs-base>=4.4.1" --no-deps --force-reinstall $PIP_INSTALL_SUFFIX || true
# Install human-eval
# setuptools is pinned first because the editable install below runs with
# --no-build-isolation and therefore builds with the setuptools already installed.
pip install "setuptools==70.0.0"
git clone https://github.com/merrymercy/human-eval.git
# NOTE: the working directory stays inside human-eval/ for the rest of the script.
cd human-eval
pip install -e . --no-build-isolation
# ------------------------------------------------------------------------------
# Prepare runner
# ------------------------------------------------------------------------------
# Prepare the CI runner (cleanup HuggingFace cache, etc.)
bash "${SCRIPT_DIR}/prepare_runner.sh"
mark_step_done "Prepare runner"
# ------------------------------------------------------------------------------
# Verify imports
# ------------------------------------------------------------------------------
# Show current packages
$PIP_CMD list
# Smoke checks: print torch's CUDA version and confirm cutlass / cutlass.cute import
python3 -c "import torch; print(torch.version.cuda)"
python3 -c "import cutlass; import cutlass.cute;"
mark_step_done "Verify imports"

View File

@@ -0,0 +1,24 @@
#!/bin/bash
set -euxo pipefail
# Install system build dependencies, going through sudo only when it exists
if command -v sudo >/dev/null 2>&1; then
    sudo apt-get update
    sudo apt-get install -y libssl-dev pkg-config protobuf-compiler redis-server
else
    apt-get update
    apt-get install -y libssl-dev pkg-config protobuf-compiler redis-server
fi
# Install rustup (Rust installer and version manager) non-interactively (-y)
# with a pinned default toolchain
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain 1.90
# Load cargo's environment into this shell so rustc/cargo resolve below.
# Sourced once, with the path quoted so a HOME containing spaces still works.
. "$HOME/.cargo/env"
# Verify installation
rustc --version
cargo --version
protoc --version

View File

@@ -0,0 +1,106 @@
#!/bin/bash
# Launches prefill and decode sglang servers for disaggregation testing and
# waits until all of them report healthy.
set -euo pipefail
# Optional: set DISAGG_READY_FILE to a filepath; when all servers are healthy, the script will
# create this file as a readiness signal (useful for CI to proceed to next steps).
DISAGG_READY_FILE="${DISAGG_READY_FILE:-}"
# Local model directory passed to every server via --model-path
MODEL_PATH="/raid/models/meta-llama/Llama-3.1-8B-Instruct"
#######################################
# Find the first InfiniBand device (mlx5_0 .. mlx5_11) whose first reported
# port state is PORT_ACTIVE.
# Arguments: none
# Outputs:   the device name on stdout; an error message on stderr if none
# Returns:   0 when a device was found, 1 otherwise
#######################################
find_active_ib_device() {
    # Keep loop state local so calling the function never clobbers globals
    local device info state
    for device in mlx5_{0..11}; do
        # Query each device once; a non-zero exit means it does not exist
        info=$(ibv_devinfo "$device" 2>/dev/null) || continue
        # Second field of the first line mentioning "state:"
        state=$(awk '/state:/ { print $2; exit }' <<< "$info")
        if [[ "$state" == "PORT_ACTIVE" ]]; then
            echo "$device"
            return 0
        fi
    done
    echo "No active IB device found" >&2
    return 1
}
# Pick the InfiniBand device shared by all servers (the script aborts here
# when none is active)
DEVICE=$(find_active_ib_device)
echo "Using IB device: $DEVICE"
# One server per GPU: GPUs 0-3 run prefill, GPUs 4-7 run decode.
# Server i listens on 127.0.0.(i+1):(30001+i); each is started in the background.
for i in {0..7}; do
    PORT=$((30001 + i))
    HOST="127.0.0.$((i + 1))"
    if ((i < 4)); then
        # Prefill servers additionally get their own bootstrap port
        BOOTSTRAP_PORT=$((9001 + i))
        echo "Launching PREFILL server on GPU $i at $HOST:$PORT (bootstrap: $BOOTSTRAP_PORT)"
        SERVER_MODE="prefill"
        MODE_ARGS=(--disaggregation-bootstrap-port "$BOOTSTRAP_PORT")
    else
        echo "Launching DECODE server on GPU $i at $HOST:$PORT"
        SERVER_MODE="decode"
        MODE_ARGS=(--base-gpu-id 0)
    fi
    CUDA_VISIBLE_DEVICES=$i \
        python3 -m sglang.launch_server \
        --model-path "$MODEL_PATH" \
        --disaggregation-mode "$SERVER_MODE" \
        --host "$HOST" \
        --port "$PORT" \
        --disaggregation-ib-device "$DEVICE" \
        "${MODE_ARGS[@]}" &
done
# Wait for disaggregation servers to initialize
echo "Waiting for disaggregation servers to initialize..."
# Health check with 5-minute timeout
TIMEOUT=300
START_TIME=$(date +%s)
echo "Checking health of all 8 servers..."
while true; do
    CURRENT_TIME=$(date +%s)
    ELAPSED=$((CURRENT_TIME - START_TIME))
    if ((ELAPSED >= TIMEOUT)); then
        echo "❌ Timeout: Servers did not become healthy within 5 minutes"
        exit 1
    fi
    HEALTHY_COUNT=0
    # Check all 8 servers (127.0.0.1-8:30001-30008)
    for i in {1..8}; do
        # Bound every probe: without a limit a single hung connection would
        # block this loop and the overall TIMEOUT above would never fire.
        if curl -s -f --connect-timeout 2 --max-time 5 \
            "http://127.0.0.$i:$((30000 + i))/health" >/dev/null 2>&1; then
            HEALTHY_COUNT=$((HEALTHY_COUNT + 1))
        fi
    done
    echo "Healthy servers: $HEALTHY_COUNT/8 (elapsed: ${ELAPSED}s)"
    if ((HEALTHY_COUNT == 8)); then
        echo "✅ All 8 servers are healthy!"
        # Emit readiness signal file if requested
        if [ -n "$DISAGG_READY_FILE" ]; then
            echo "Creating readiness flag: $DISAGG_READY_FILE"
            # Ensure parent dir exists; ignore errors
            mkdir -p "$(dirname "$DISAGG_READY_FILE")" 2>/dev/null || true
            touch "$DISAGG_READY_FILE"
        fi
        break
    else
        sleep 10 # Wait 10 seconds before next check
    fi
done
# Don't launch router here - just keep servers running
echo "✅ All disaggregation servers are ready and waiting for router connections"
# Block on the background server processes so the script stays alive with them
wait

View File

@@ -0,0 +1,19 @@
#!/bin/bash
# Prepare the CI runner by cleaning up stale HuggingFace cache artifacts and validating models
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
echo "Preparing CI runner..."
echo ""
# Run the preparation utilities in order, each followed by a blank line:
#   1. cleanup_hf_cache.py           - remove stale HuggingFace cache artifacts
#                                      left by previous failed downloads
#   2. prevalidate_cached_models.py  - pre-validate cached models and write
#                                      markers so fully cached models can run
#                                      with HF_HUB_OFFLINE=1
for util_script in cleanup_hf_cache.py prevalidate_cached_models.py; do
    python3 "${SCRIPT_DIR}/../utils/${util_script}"
    echo ""
done
echo "CI runner preparation complete!"

View File

@@ -0,0 +1,399 @@
"""
Lightweight DeepGEMM JIT compilation warmup without loading model weights.
Reads model config.json from HF cache to derive kernel shapes, then compiles
DeepGEMM kernels directly. This avoids the expensive model weight loading step
that the full `sglang.compile_deep_gemm` requires.
Supports DeepSeek V2/V3 family models. Falls back to `sglang.compile_deep_gemm`
for unsupported architectures.
Usage:
python3 scripts/ci/cuda/warmup_deep_gemm.py \
deepseek-ai/DeepSeek-V3-0324:8 \
deepseek-ai/DeepSeek-V3.2-Exp:8
"""
import json
import os
import subprocess
import sys
import time
from math import ceil
from pathlib import Path
# Configure DeepGEMM cache before importing deep_gemm
# Cache location: SGLANG_DG_CACHE_DIR if set, otherwise ~/.cache/deep_gemm.
os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
"SGLANG_DG_CACHE_DIR",
os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm"),
)
# Forward SGL_DG_USE_NVRTC to DeepGEMM's own variable (defaults to "0").
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
# Block edge used to size the per-block FP8 scale tensors built below.
BLOCK_SIZE = 128
def get_config_json(model_name):
    """Return the parsed config.json of a model stored in the HF cache.

    The cache root is ``$HF_HOME`` (default ``~/.cache/huggingface``). Snapshots
    are tried from most to least recently modified and the first one that
    contains a ``config.json`` wins.

    Returns:
        The decoded JSON object, or None when the model has no snapshots
        directory or no snapshot holds a config.json.
    """
    default_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
    hf_home = os.environ.get("HF_HOME", default_home)
    repo_dir = "models--" + model_name.replace("/", "--")
    snapshots_root = os.path.join(hf_home, "hub", repo_dir, "snapshots")
    if not os.path.isdir(snapshots_root):
        return None

    newest_first = sorted(
        Path(snapshots_root).iterdir(),
        key=lambda entry: entry.stat().st_mtime,
        reverse=True,
    )
    for snapshot in newest_first:
        candidate = snapshot / "config.json"
        if not candidate.exists():
            continue
        with open(candidate) as handle:
            return json.load(handle)
    return None
def is_deepseek_v2v3(config):
    """Check if a model is from the DeepSeek V2/V3 family.

    Args:
        config: parsed config.json dictionary.

    Returns:
        True when any entry of ``architectures`` contains "DeepseekV2" or
        "DeepseekV3", or when ``model_type`` is "deepseek_v2"/"deepseek_v3".
    """
    # "architectures" may be present but null in config.json; treat that like
    # a missing key instead of failing while iterating over None.
    architectures = config.get("architectures") or []
    model_type = config.get("model_type", "")
    return any(
        "DeepseekV2" in a or "DeepseekV3" in a for a in architectures
    ) or model_type in ("deepseek_v2", "deepseek_v3")
def compute_deepseek_v2v3_shapes(config, tp):
    """List every DeepGEMM (kernel_type, N, K, num_groups) tuple for DeepSeek V2/V3.

    Shape derivation based on:
    - MoE: python/sglang/srt/layers/moe/fused_moe_triton/layer.py
    - MLA: python/sglang/srt/models/deepseek_v2.py
    - FP8: python/sglang/srt/layers/quantization/fp8_kernel.py

    Args:
        config: parsed config.json dictionary (``hidden_size`` is required).
        tp: tensor-parallel size; heads and the MoE intermediate size are
            floor-divided by it.

    Returns:
        A list of tuples, MoE shapes first, then MLA shapes, then kv_b_proj.
    """
    hidden_size = config["hidden_size"]
    heads_per_rank = config.get("num_attention_heads", 128) // tp
    kv_lora_rank = config.get("kv_lora_rank", 512)
    qk_nope_head_dim = config.get("qk_nope_head_dim", 128)
    v_head_dim = config.get("v_head_dim", 128)
    n_routed_experts = config.get("n_routed_experts", 0)
    moe_intermediate_size = config.get("moe_intermediate_size", 0)
    # Shared expert fusion is enabled by default (disable_shared_experts_fusion=False),
    # so the fused MoE weight tensor holds the shared experts as well.
    experts_in_weight = n_routed_experts + config.get("n_shared_experts", 0)

    shapes = []

    # MoE expert GEMMs: the intermediate size is sharded across TP ranks
    # (column parallel for gate/up, row parallel for down) while every expert
    # is present on each rank.
    if n_routed_experts > 0 and moe_intermediate_size > 0:
        inter_per_rank = moe_intermediate_size // tp
        # gate/up: (tokens, hidden) @ (experts, 2*inter_per_rank, hidden)^T
        gate_up = (inter_per_rank * 2, hidden_size, experts_in_weight)
        # down:    (tokens, inter_per_rank) @ (experts, hidden, inter_per_rank)^T
        down = (hidden_size, inter_per_rank, experts_in_weight)
        for dims in (gate_up, down):
            # Both the masked and the contiguous path are used at runtime
            for layout in ("MASKED", "CONTIG"):
                shapes.append((layout, *dims))

    if kv_lora_rank > 0 and heads_per_rank > 0:
        # MLA attention (masked grouped GEMM):
        #   Q_nope -> compressed K: (heads, m, qk_nope) @ (heads, kv_lora_rank, qk_nope)^T
        shapes.append(("MASKED", kv_lora_rank, qk_nope_head_dim, heads_per_rank))
        #   attention output -> V:  (heads, m, kv_lora_rank) @ (heads, v_head_dim, kv_lora_rank)^T
        shapes.append(("MASKED", v_head_dim, kv_lora_rank, heads_per_rank))
        # kv_b_proj (non-grouped GEMM via the FP8 kernel):
        #   ColumnParallelLinear(kv_lora_rank, num_heads * (qk_nope + v_head_dim)),
        #   so per rank N = heads_per_rank * (qk_nope_head_dim + v_head_dim)
        kv_b_proj_n = heads_per_rank * (qk_nope_head_dim + v_head_dim)
        shapes.append(("NORMAL", kv_b_proj_n, kv_lora_rank, 1))

    return shapes
def get_architecture_key(config, tp):
    """Build the dedup key: configs with equal keys share DeepGEMM kernels.

    Returns:
        None when ``config`` is None, otherwise a tuple of the shape-relevant
        config fields (missing ones counted as 0) followed by ``tp``.
    """
    if config is None:
        return None
    shape_fields = (
        "hidden_size",
        "moe_intermediate_size",
        "n_routed_experts",
        "n_shared_experts",
        "num_attention_heads",
        "kv_lora_rank",
        "qk_nope_head_dim",
        "v_head_dim",
    )
    return tuple(config.get(name, 0) for name in shape_fields) + (tp,)
def compute_m_list(fast_warmup=False, chunked_prefill_size=8192):
    """Return the sorted M values to compile (matches compile_utils.py logic).

    Args:
        fast_warmup: when True, cover 1..1024 densely and sample larger values
            with a stride that doubles for every power-of-two band.
        chunked_prefill_size: prefill chunk size that bounds the largest M.

    Returns:
        A list of integers in increasing order.
    """
    if not fast_warmup:
        # Dense range: 16K by default, twice the chunk size for large chunks,
        # never above 128K.
        upper = 16 * 1024
        if chunked_prefill_size > 8192:
            upper = chunked_prefill_size * 2
        upper = min(128 * 1024, upper)
        return list(range(1, upper + 1))

    prefill_cap = min(chunked_prefill_size, 32 * 1024)
    values = set(range(1, 1025))
    band_start, stride = 1024, 2
    while band_start < prefill_cap:
        # Sample the band [band_start, 2 * band_start) with the current stride
        values.update(range(band_start, 2 * band_start, stride))
        band_start *= 2
        stride *= 2
    values.add(prefill_cap)
    return sorted(values)
def _empty_token_fp8(size):
    """Allocate an uninitialized FP8 token tensor and its per-block scales.

    ``size`` ends with K. The float32 scale tensor keeps the leading dims and
    replaces K by ceil(K / BLOCK_SIZE). Both tensors are created on "cuda".
    """
    import torch

    *leading, k = size
    scale_shape = (*leading, ceil(k / BLOCK_SIZE))
    tokens = torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn)
    scales = torch.empty(scale_shape, device="cuda", dtype=torch.float32)
    return (tokens, scales)
def _empty_block_fp8(size):
    """Allocate an uninitialized FP8 block tensor and its per-block scales.

    ``size`` ends with (N, K). The float32 scale tensor keeps the leading dims
    and replaces them by (ceil(N / BLOCK_SIZE), ceil(K / BLOCK_SIZE)). Both
    tensors are created on "cuda".
    """
    import torch

    *leading, n, k = size
    scale_shape = (*leading, ceil(n / BLOCK_SIZE), ceil(k / BLOCK_SIZE))
    blocks = torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn)
    scales = torch.empty(scale_shape, device="cuda", dtype=torch.float32)
    return (blocks, scales)
def get_memory_requirement(kernel_type, max_m, n, k, num_groups):
    """Estimate, in GiB, the GPU memory the compilation buffers need.

    Args:
        kernel_type: "NORMAL", "CONTIG" or "MASKED".
        max_m, n, k: GEMM dimensions of the largest problem.
        num_groups: number of groups (ignored for "NORMAL").

    Returns:
        The estimate as a float, or 0 for an unrecognized kernel type.
    """
    gib = 1 << 30
    if kernel_type == "NORMAL":
        total_bytes = max_m * k + n * k + max_m * n * 2
    elif kernel_type == "CONTIG":
        total_bytes = max_m * k + num_groups * n * k + max_m * 4 + max_m * n * 2
    elif kernel_type == "MASKED":
        total_bytes = (
            num_groups * max_m * k
            + num_groups * n * k
            + num_groups * 4
            + num_groups * max_m * n * 2
        )
    else:
        return 0
    return total_bytes / gib
def compile_one_shape(kernel_type, n, k, num_groups, m_list):
    """Compile the DeepGEMM kernels of one (kernel_type, N, K, num_groups) shape.

    Every M in ``m_list`` is driven through the matching DeepGEMM entry point
    on uninitialized buffers while compile mode 1 is active. The previous
    compile mode is restored afterwards, even on error.
    """
    import deep_gemm
    import torch
    from tqdm import tqdm

    # The contiguous layout only accepts M values aligned to DeepGEMM's requirement
    if kernel_type == "CONTIG":
        alignment = deep_gemm.get_mk_alignment_for_contiguous_layout()
        m_list = sorted({m for m in m_list if m % alignment == 0})
    if not m_list:
        return
    max_m = max(m_list)

    # Halve max_m (never below 4096) until the buffers fit into free memory
    free_gb = torch.cuda.mem_get_info()[0] / (1 << 30)
    needed_gb = get_memory_requirement(kernel_type, max_m, n, k, num_groups)
    if needed_gb > free_gb:
        while (
            max_m > 4096
            and get_memory_requirement(kernel_type, max_m, n, k, num_groups) > free_gb
        ):
            max_m //= 2
        print(
            f" Memory {free_gb:.1f}GB < required {needed_gb:.1f}GB, "
            f"reducing max_m to {max_m}"
        )
        m_list = [m for m in m_list if m <= max_m]

    previous_mode = deep_gemm.get_compile_mode()
    deep_gemm.set_compile_mode(1)
    try:
        if kernel_type == "NORMAL":
            a_q, a_s = _empty_token_fp8((max_m, k))
            b_q, b_s = _empty_block_fp8((n, k))
            result = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
            for m in tqdm(m_list, desc=f" NORMAL N={n} K={k}"):
                deep_gemm.fp8_gemm_nt((a_q[:m], a_s[:m]), (b_q, b_s), result[:m])
        elif kernel_type == "CONTIG":
            a_q, a_s = _empty_token_fp8((max_m, k))
            b_q, b_s = _empty_block_fp8((num_groups, n, k))
            group_of_row = torch.zeros((max_m,), device="cuda", dtype=torch.int32)
            result = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
            for m in tqdm(m_list, desc=f" CONTIG N={n} K={k} G={num_groups}"):
                deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
                    (a_q[:m], a_s[:m]),
                    (b_q, b_s),
                    result[:m],
                    m_indices=group_of_row[:m],
                )
        elif kernel_type == "MASKED":
            a_q, a_s = _empty_token_fp8((num_groups, max_m, k))
            b_q, b_s = _empty_block_fp8((num_groups, n, k))
            rows_per_group = torch.zeros(
                (num_groups,), device="cuda", dtype=torch.int32
            )
            result = torch.empty(
                (num_groups, max_m, n), device="cuda", dtype=torch.bfloat16
            )
            # Full buffers are passed each time; only expected_m varies
            for m in tqdm(m_list, desc=f" MASKED N={n} K={k} G={num_groups}"):
                deep_gemm.fp8_m_grouped_gemm_nt_masked(
                    (a_q, a_s),
                    (b_q, b_s),
                    result,
                    masked_m=rows_per_group,
                    expected_m=m,
                )
    finally:
        deep_gemm.set_compile_mode(previous_mode)
    torch.cuda.current_stream().synchronize()
    torch.cuda.empty_cache()
def compile_shapes_lightweight(shapes, m_list):
    """Compile every shape in ``shapes`` directly (no model loading).

    Prints a progress header and the elapsed wall-clock time for each shape.
    """
    for position, shape in enumerate(shapes, 1):
        kernel_type, n, k, num_groups = shape
        print(f"\n[{position}/{len(shapes)}] {kernel_type} N={n} K={k} G={num_groups}")
        started = time.time()
        compile_one_shape(kernel_type, n, k, num_groups, m_list)
        print(f" Done in {time.time() - started:.1f}s")
def fallback_compile_deep_gemm(model, tp):
    """Run the full ``sglang.compile_deep_gemm`` module (loads model weights).

    The module is executed in a child process with the current interpreter.

    Returns:
        True when the child exits with status 0; otherwise a warning is
        printed and False is returned.
    """
    print(f"Falling back to full compile_deep_gemm for {model} (tp={tp})...")
    loader_config = '{"enable_multithread_load": true, "num_threads": 64}'
    cmd = [sys.executable, "-m", "sglang.compile_deep_gemm"]
    cmd += ["--model", model]
    cmd += ["--tp", str(tp)]
    cmd += ["--trust-remote-code"]
    cmd += ["--model-loader-extra-config", loader_config]
    exit_code = subprocess.run(cmd).returncode
    if exit_code == 0:
        return True
    print(f"Warning: fallback failed for {model} (exit code {exit_code})")
    return False
def main():
    """Command-line entry point: warm up DeepGEMM kernels for model:tp pairs.

    Each argument has the form ``model:tp`` where ``tp`` is a positive integer.
    Models whose config.json is not in the HF cache are skipped, models with
    identical shape keys are compiled once, DeepSeek V2/V3 models use the
    lightweight path, and all other architectures fall back to the full
    ``sglang.compile_deep_gemm`` run.

    Exits with status 0 for usage/help and status 1 for a malformed argument.
    """
    if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
        print("Usage: warmup_deep_gemm.py model1:tp1 [model2:tp2 ...]")
        print("\nDerives DeepGEMM kernel shapes from config.json without loading model")
        print(
            "weights. Falls back to full compile_deep_gemm for unknown architectures."
        )
        sys.exit(0)

    # Parse model:tp pairs
    model_tp_pairs = []
    for arg in sys.argv[1:]:
        if ":" not in arg:
            print(f"Error: expected model:tp format, got '{arg}'")
            sys.exit(1)
        model, tp_str = arg.rsplit(":", 1)
        # Report a bad tp as a usage error instead of a traceback. A tp below 1
        # is rejected too, because shapes are derived by dividing by tp.
        try:
            tp = int(tp_str)
        except ValueError:
            tp = 0
        if tp < 1:
            print(f"Error: tp must be a positive integer, got '{tp_str}' in '{arg}'")
            sys.exit(1)
        model_tp_pairs.append((model, tp))

    fast_warmup = os.environ.get("SGLANG_JIT_DEEPGEMM_FAST_WARMUP", "0").lower() in (
        "1",
        "true",
    )
    print(f"=== DeepGEMM Lightweight Warmup ({len(model_tp_pairs)} model(s)) ===")
    print(f" Fast warmup: {fast_warmup}")
    print(
        f" Cache dir: {os.environ.get('DG_JIT_CACHE_DIR', '~/.cache/deep_gemm')}\n"
    )

    # Load configs and deduplicate by architecture
    seen_keys = {}
    to_process = []  # (model, tp, config_or_None, shapes_or_None)
    for model, tp in model_tp_pairs:
        config = get_config_json(model)
        if config is None:
            print(f" SKIP {model} (tp={tp}): config.json not in HF cache")
            continue
        key = get_architecture_key(config, tp)
        if key in seen_keys:
            print(f" DEDUP {model} (tp={tp}): same shapes as {seen_keys[key]}")
            continue
        if is_deepseek_v2v3(config):
            shapes = compute_deepseek_v2v3_shapes(config, tp)
            seen_keys[key] = model
            to_process.append((model, tp, config, shapes))
            print(f" FOUND {model} (tp={tp}): {len(shapes)} DeepGEMM shape(s)")
        else:
            # Unknown architecture: will use fallback
            seen_keys[key] = model
            to_process.append((model, tp, config, None))
            arch = config.get("architectures", ["unknown"])
            print(f" FOUND {model} (tp={tp}): unknown arch {arch}, will use fallback")

    if not to_process:
        print("\nNo models to process. Done.")
        return

    m_list = compute_m_list(fast_warmup=fast_warmup)
    print(f"\nM list: {len(m_list)} values (range {min(m_list)}-{max(m_list)})")
    for model, tp, config, shapes in to_process:
        print(f"\n{'=' * 60}")
        print(f"Model: {model} (tp={tp})")
        print(f"{'=' * 60}")
        if shapes is None:
            # Unknown architecture: fall back to full compile_deep_gemm
            fallback_compile_deep_gemm(model, tp)
            continue
        # Print shape summary
        for kernel_type, n, k, num_groups in shapes:
            print(f" {kernel_type:8s} N={n:<6d} K={k:<6d} G={num_groups}")
        t0 = time.time()
        compile_shapes_lightweight(shapes, m_list)
        elapsed = time.time() - t0
        print(f"\nCompleted {model} in {elapsed:.1f}s")
    print("\nDeepGEMM lightweight warmup complete.")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,313 @@
"""
Full server warmup to pre-warm Triton autotuning and CUDA graph capture.
On cold H200 nodes (new nodes or after container recreation), CUDA graph capture
triggers Triton autotuning which takes ~330s per server launch. This script
launches actual servers with CUDA graphs enabled to cache the autotuned kernels,
so subsequent test launches are fast (~30-60s).
Uses marker files to skip warmup on already-warm nodes. Marker files are
invalidated when Python, Triton, or PyTorch versions change.
Usage:
python3 scripts/ci/cuda/warmup_server.py \
deepseek-ai/DeepSeek-V3-0324:8 \
inclusionAI/Ring-2.5-1T:8
"""
import hashlib
import json
import os
import signal
import subprocess
import sys
import tempfile
import time
from pathlib import Path
# Reuse helpers from warmup_deep_gemm (same directory)
sys.path.insert(0, os.path.dirname(__file__))
from warmup_deep_gemm import get_architecture_key, get_config_json
MARKER_DIR = os.path.join(os.path.expanduser("~"), ".cache", "sglang", "warmup_markers")
HEALTH_POLL_INTERVAL = 10 # seconds between health checks
SERVER_STARTUP_TIMEOUT = 900 # 15 min max to wait for server ready
DEFAULT_PORT = 39876
def get_version_key():
    """Return a 12-hex-digit hash of the Python, Triton and PyTorch versions.

    Marker files embed this key so they stop matching after an upgrade. A
    package that cannot be imported contributes "<name>=none".
    """
    components = [sys.version]
    for package in ("triton", "torch"):
        try:
            module = __import__(package)
            components.append(f"{package}={module.__version__}")
        except ImportError:
            components.append(f"{package}=none")
    digest = hashlib.sha256("|".join(components).encode()).hexdigest()
    return digest[:12]
def get_marker_path(model, tp):
    """Return the marker file path for a model:tp pair.

    The file name combines the model name (with "/" replaced by "--"), the tp
    value and the current version key, and lives under MARKER_DIR.
    """
    file_name = "server_warmup_{}_tp{}_{}.done".format(
        model.replace("/", "--"), tp, get_version_key()
    )
    return os.path.join(MARKER_DIR, file_name)
def check_marker(model, tp):
    """Return True when the warmup marker for model:tp exists (node already warm)."""
    return os.path.exists(get_marker_path(model, tp))
def write_marker(model, tp):
    """Record a successful warmup by writing the marker file for model:tp.

    The marker directory is created when missing. The file holds a small JSON
    document with the model, tp, version key and the current timestamp.
    """
    marker_path = get_marker_path(model, tp)
    os.makedirs(os.path.dirname(marker_path), exist_ok=True)
    record = {
        "model": model,
        "tp": tp,
        "version_key": get_version_key(),
        "timestamp": time.time(),
    }
    Path(marker_path).write_text(json.dumps(record))
    print(f" Wrote marker: {marker_path}")
def kill_server(proc):
    """Stop a server process together with its whole process group.

    Does nothing when the process has already exited. Otherwise SIGTERM is
    sent to the group and the process gets 15 seconds to exit; after that
    SIGKILL is sent and it is waited on for 5 more seconds. Lookup/OS errors
    while signalling and the final wait timing out are ignored.
    """
    if proc.poll() is not None:
        return

    def _signal_group(signum):
        # The group may already be gone by the time the signal is sent
        try:
            os.killpg(os.getpgid(proc.pid), signum)
        except (ProcessLookupError, OSError):
            pass

    _signal_group(signal.SIGTERM)
    try:
        proc.wait(timeout=15)
        return
    except subprocess.TimeoutExpired:
        pass

    _signal_group(signal.SIGKILL)
    try:
        proc.wait(timeout=5)
    except subprocess.TimeoutExpired:
        pass
def wait_for_server(base_url, proc, timeout):
    """Poll ``/health_generate`` until the server answers 200 or time runs out.

    Args:
        base_url: server URL without a trailing path.
        proc: the server's Popen object, checked for early exit on every poll.
        timeout: maximum number of seconds to keep polling.

    Returns:
        ``(True, None)`` once the endpoint returns HTTP 200, otherwise
        ``(False, reason)`` when the process exited or the timeout elapsed.
    """
    import requests

    health_url = f"{base_url}/health_generate"
    began = time.time()
    while time.time() - began < timeout:
        exit_code = proc.poll()
        if exit_code is not None:
            return False, f"Server exited with code {exit_code}"
        try:
            if requests.get(health_url, timeout=5).status_code == 200:
                return True, None
        except requests.RequestException:
            # Not reachable yet; try again after the poll interval
            pass
        time.sleep(HEALTH_POLL_INTERVAL)
    return False, "Timed out waiting for server"
def send_generate_request(base_url):
    """Send a single ``/generate`` request to exercise the full inference path.

    Best effort: a non-200 status or a request error is only printed as a
    warning; nothing is raised and nothing is returned.
    """
    import requests

    sampling_params = {"max_new_tokens": 8, "temperature": 0}
    payload = {"input_ids": [0, 1, 2, 3], "sampling_params": sampling_params}
    try:
        response = requests.post(f"{base_url}/generate", json=payload, timeout=120)
    except requests.RequestException as e:
        print(f" Warning: generate request failed: {e}")
        return
    if response.status_code == 200:
        print(" Generate request succeeded")
    else:
        print(f" Warning: generate request returned {response.status_code}")
def warmup_one_model(model, tp, port):
    """Launch a server for ``model``, wait until ready, send one request, stop it.

    The server's output goes to a temporary log file that is removed at the
    end; its last 20 lines are printed when the server does not become ready.

    Returns:
        True when the server became ready (the generate request itself is
        best effort), False otherwise.
    """
    base_url = f"http://127.0.0.1:{port}"
    cmd = [sys.executable, "-m", "sglang.launch_server"]
    cmd += ["--model-path", model]
    cmd += ["--tp", str(tp)]
    cmd += ["--host", "127.0.0.1"]
    cmd += ["--port", str(port)]
    cmd += ["--trust-remote-code"]
    cmd += [
        "--model-loader-extra-config",
        '{"enable_multithread_load": true, "num_threads": 64}',
    ]

    # Server output goes to a temp file rather than a pipe: the logs can exceed
    # the 64KB pipe buffer during CUDA graph capture and deadlock the server.
    log_file = tempfile.NamedTemporaryFile(
        mode="w", prefix="warmup_server_", suffix=".log", delete=False
    )
    log_path = log_file.name

    def _print_log_tail():
        # Debug aid only: any failure while reading the log is ignored
        try:
            log_file.flush()
            with open(log_path) as f:
                tail = f.readlines()[-20:]
            for line in tail:
                print(f" | {line.rstrip()}")
        except Exception:
            pass

    print(f" Launching server: {' '.join(cmd)}")
    print(f" Server log: {log_path}")
    # os.setsid gives the server its own process group so kill_server can
    # signal the whole tree.
    proc = subprocess.Popen(
        cmd,
        stdout=log_file,
        stderr=subprocess.STDOUT,
        preexec_fn=os.setsid,
    )
    try:
        # Readiness includes CUDA graph capture
        print(
            f" Waiting for server (timeout={SERVER_STARTUP_TIMEOUT}s, "
            f"polling every {HEALTH_POLL_INTERVAL}s)..."
        )
        ready, reason = wait_for_server(base_url, proc, SERVER_STARTUP_TIMEOUT)
        if ready:
            print(" Server ready, sending generate request...")
            send_generate_request(base_url)
            return True
        print(f" Warning: server not ready: {reason}")
        _print_log_tail()
        return False
    finally:
        print(" Killing server...")
        kill_server(proc)
        log_file.close()
        try:
            os.unlink(log_path)
        except OSError:
            pass
def main():
    """Parse ``model:tp`` arguments and warm up each model that needs it.

    Models whose marker already exists are skipped. Models whose architecture
    key matches one already queued are deduplicated and receive a marker once
    the queued representative warms up successfully. A failed warmup is
    reported but does not stop the remaining models.

    Exits with status 0 after printing usage (no arguments, ``-h`` or
    ``--help``) and with status 1 on a malformed ``model:tp`` argument.
    """
    if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
        print("Usage: warmup_server.py model1:tp1 [model2:tp2 ...]")
        print(
            "\nLaunches full servers with CUDA graphs enabled to pre-warm"
            " Triton autotuning."
        )
        print("Skips instantly on warm nodes (marker file exists).")
        sys.exit(0)

    # Parse model:tp pairs
    model_tp_pairs = []
    for arg in sys.argv[1:]:
        if ":" not in arg:
            print(f"Error: expected model:tp format, got '{arg}'")
            sys.exit(1)
        # rsplit so that a model path containing ':' keeps everything but
        # the final component.
        model, tp_str = arg.rsplit(":", 1)
        try:
            tp = int(tp_str)
        except ValueError:
            # Report a bad tp the same way as a missing ':' instead of
            # crashing with a traceback.
            print(f"Error: expected integer tp in model:tp format, got '{arg}'")
            sys.exit(1)
        model_tp_pairs.append((model, tp))

    print(f"=== Server CUDA Graph Warmup ({len(model_tp_pairs)} model(s)) ===")
    print(f" Marker dir: {MARKER_DIR}")
    print(f" Version key: {get_version_key()}\n")

    # Deduplicate by architecture and check markers
    seen_keys = {}
    to_warmup = []
    for model, tp in model_tp_pairs:
        # Check marker first (fast path)
        if check_marker(model, tp):
            print(f" SKIP {model} (tp={tp}): already warm (marker exists)")
            continue

        # Architecture dedup
        config = get_config_json(model)
        if config is not None:
            key = get_architecture_key(config, tp)
            if key in seen_keys:
                print(
                    f" DEDUP {model} (tp={tp}): same architecture as {seen_keys[key]}"
                )
                continue
            seen_keys[key] = model

        to_warmup.append((model, tp))
        print(f" QUEUE {model} (tp={tp}): needs warmup")

    if not to_warmup:
        print("\nAll models already warm. Done.")
        return

    print(f"\n{len(to_warmup)} model(s) to warm up.\n")

    port = DEFAULT_PORT
    for i, (model, tp) in enumerate(to_warmup, 1):
        print(f"\n{'=' * 60}")
        print(f"[{i}/{len(to_warmup)}] {model} (tp={tp})")
        print(f"{'=' * 60}")

        t0 = time.time()
        success = warmup_one_model(model, tp, port)
        elapsed = time.time() - t0

        if success:
            print(f" Completed in {elapsed:.0f}s")
            write_marker(model, tp)
            # Also write markers for dedup'd models that share this architecture
            config = get_config_json(model)
            if config is not None:
                key = get_architecture_key(config, tp)
                for other_model, other_tp in model_tp_pairs:
                    if (other_model, other_tp) == (model, tp):
                        continue
                    other_config = get_config_json(other_model)
                    if other_config is not None:
                        other_key = get_architecture_key(other_config, other_tp)
                        if other_key == key and not check_marker(other_model, other_tp):
                            write_marker(other_model, other_tp)
                            print(
                                f" Also marked {other_model} (tp={other_tp}) as warm (same arch)"
                            )
        else:
            print(
                f" Warning: warmup failed after {elapsed:.0f}s (non-fatal, tests will still work)"
            )

        # Use a different port for the next model to avoid bind conflicts
        port += 100

    print("\nServer CUDA graph warmup complete.")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,5 @@
#!/bin/bash
set -euo pipefail

# Upgrade pip and setuptools and install torchada.
# The pip command is kept in an array so every argument stays a separate word
# without relying on unquoted word splitting.
PIP_INSTALL=(python3 -m pip install --no-cache-dir)

"${PIP_INSTALL[@]}" --upgrade pip setuptools torchada

View File

@@ -0,0 +1,46 @@
#!/usr/bin/env bash
set -euo pipefail

# Rename MUSA wheels to include a +musa<suffix> build tag.
# Usage:
#   rename_wheels_musa.sh <musa_suffix> [wheel_dir]
# Example:
#   rename_wheels_musa.sh 43 sgl-kernel/dist
#
# Only the file name is rewritten; the directory part of each path is left
# untouched, so a wheel_dir that itself contains "linux" or "-cp<NN>-" is safe.
# Wheels that already carry a "+musa" tag are skipped, so re-running the
# script does not stack a second tag.

if [[ $# -lt 1 || $# -gt 2 ]]; then
  echo "Usage: $0 <musa_suffix> [wheel_dir]" >&2
  exit 1
fi

MUSA_SUFFIX="$1"
WHEEL_DIR="${2:-dist}"

wheel_files=("$WHEEL_DIR"/*.whl)
# Without nullglob an unmatched pattern stays literal, so probe the first entry.
if [[ ! -e "${wheel_files[0]}" ]]; then
  echo "No wheel files found in ${WHEEL_DIR}/, nothing to rename."
  exit 0
fi

for wheel in "${wheel_files[@]}"; do
  wheel_dir="${wheel%/*}"
  wheel_name="${wheel##*/}"

  if [[ "$wheel_name" == *+musa* ]]; then
    echo "Skipping already tagged wheel: $wheel"
    continue
  fi

  # Normalize platform tag to manylinux2014. Matching "-linux_" targets only a
  # plain linux platform tag, so an existing "manylinux..." tag is left alone.
  intermediate_name="${wheel_name/-linux_/-manylinux2014_}"

  # Extract Python ABI version (e.g. cp310)
  if [[ $intermediate_name =~ -cp([0-9]+)- ]]; then
    cp_version="${BASH_REMATCH[1]}"
  else
    echo "Could not extract Python version from wheel name: $intermediate_name" >&2
    continue
  fi

  # Insert +musa<suffix> before the Python ABI tag
  new_name="${intermediate_name/-cp${cp_version}/+musa${MUSA_SUFFIX}-cp${cp_version}}"
  new_wheel="${wheel_dir}/${new_name}"

  if [[ "$wheel" != "$new_wheel" ]]; then
    echo "Renaming $wheel -> $new_wheel"
    mv -- "$wheel" "$new_wheel"
  fi
done

echo "MUSA wheel renaming completed."

View File

@@ -0,0 +1,68 @@
#!/bin/bash
set -euo pipefail

# Install the system and Python dependencies used by the NPU CI jobs.
# Usage:
#   <this script> <device_type> [optional_deps]
# Arguments:
#   device_type   - inserted into the sgl-kernel-npu release asset name
#   optional_deps - any non-empty value selects torch 2.10.0 / torch_npu
#                   2.10.0rc2; empty or omitted selects torch 2.8.0 /
#                   torch_npu 2.8.0.post2
# Optional environment:
#   TORCH_CACHE_URL  - pip --index-url for torch (default: PyTorch CPU index)
#   PYPI_CACHE_URL   - pip --extra-index-url (default: pypi.org)
#   GITHUB_PROXY_URL - prefix prepended to the GitHub download URL

if [[ $# -lt 1 ]]; then
  echo "Usage: $0 <device_type> [optional_deps]" >&2
  exit 1
fi

# Kept as an array so each argument stays a separate word when expanded.
PIP_INSTALL=(python3 -m pip install --no-cache-dir)
DEVICE_TYPE="$1"
OPTIONAL_DEPS="${2:-}"

# Install the required dependencies in CI.
apt update -y && apt install -y \
  unzip \
  build-essential \
  cmake \
  wget \
  curl \
  net-tools \
  zlib1g-dev \
  lld \
  clang \
  locales \
  ccache \
  libgl1-mesa-glx \
  libgl1-mesa-dri \
  ca-certificates \
  libgl1 \
  libglib2.0-0
update-ca-certificates

"${PIP_INSTALL[@]}" --upgrade pip
# Pin wheel to 0.45.1, REF: https://github.com/pypa/wheel/issues/662
"${PIP_INSTALL[@]}" wheel==0.45.1 pybind11 pyyaml decorator scipy attrs psutil

### Install MemFabric
"${PIP_INSTALL[@]}" memfabric-hybrid==1.0.5

### Install PyTorch and PTA
# Only the versions differ between the two stacks; the install commands are shared.
if [[ -n "$OPTIONAL_DEPS" ]]; then
  PYTORCH_VERSION="2.10.0"
  TORCHVISION_VERSION="0.25.0"
  PTA_URL="https://gitcode.com/Ascend/pytorch/releases/download/7.3.0.alpha002/torch_npu-2.10.0rc2-cp311-cp311-manylinux_2_28_aarch64.whl"
else
  PYTORCH_VERSION="2.8.0"
  TORCHVISION_VERSION="0.23.0"
  PTA_URL="https://gitcode.com/Ascend/pytorch/releases/download/v7.3.0-pytorch2.8.0/torch_npu-2.8.0.post2-cp311-cp311-manylinux_2_28_aarch64.whl"
fi
"${PIP_INSTALL[@]}" "torch==${PYTORCH_VERSION}" "torchvision==${TORCHVISION_VERSION}" \
  --index-url "${TORCH_CACHE_URL:="https://download.pytorch.org/whl/cpu"}" \
  --extra-index-url "${PYPI_CACHE_URL:="https://pypi.org/simple/"}"
"${PIP_INSTALL[@]}" "${PTA_URL}"

### Install Triton-Ascend
"${PIP_INSTALL[@]}" triton-ascend

### Install sgl-kernel-npu
SGLANG_KERNEL_NPU_TAG="2026.03.10.rc1"
# NOTE(review): the asset name pins torch2.8.0 even when the torch 2.10.0
# stack was selected above - confirm this is intended.
SGLANG_KERNEL_NPU_ZIP="sgl-kernel-npu-${SGLANG_KERNEL_NPU_TAG}-torch2.8.0-py311-cann8.5.0-${DEVICE_TYPE}-$(arch).zip"
# -p / -O / -o / -sf keep this step re-runnable in an existing workspace.
mkdir -p sgl-kernel-npu
(
  cd sgl-kernel-npu
  wget -O "${SGLANG_KERNEL_NPU_ZIP}" \
    "${GITHUB_PROXY_URL:=""}https://github.com/sgl-project/sgl-kernel-npu/releases/download/${SGLANG_KERNEL_NPU_TAG}/${SGLANG_KERNEL_NPU_ZIP}"
  unzip -o "./${SGLANG_KERNEL_NPU_ZIP}"
  "${PIP_INSTALL[@]}" ./deep_ep*.whl ./sgl_kernel_npu*.whl
  # Link the compiled deep_ep extension next to the package directory,
  # inside the install location reported by pip.
  deep_ep_location="$(python3 -m pip show deep-ep | grep -E '^Location:' | awk '{print $2}')"
  cd "${deep_ep_location}"
  ln -sf deep_ep/deep_ep_cpp*.so .
)

### Install SGLang
# Swap in the NPU pyproject. mv -f replaces the target in one step, and the
# guard avoids deleting pyproject.toml when the swap has already happened.
if [[ -f python/pyproject_npu.toml ]]; then
  mv -f python/pyproject_npu.toml python/pyproject.toml
fi
"${PIP_INSTALL[@]}" -v -e "python[dev_npu]"

View File

@@ -0,0 +1,26 @@
#!/bin/bash
set -euo pipefail
# Print log information: sglang version and commit SHA, sgl-kernel-npu version
# and commit SHA, npu-smi info, and pip list.
npu-smi info
pip list
#######################################
# Print the __version__ string found in a Python version file.
# Arguments:
#   $1 - path to the version file
#   $2 - label used in the output line
# Outputs:
#   "<label> version: v<version>" on stdout, or "<label> version: unknown"
#   when the file is missing or the version cannot be extracted.
#######################################
get_version() {
  local version_file="$1"
  local label="$2"

  if [ -f "$version_file" ]; then
    # python3 prints the line itself on success; its stderr is discarded.
    if python3 -c 'import re, sys; print(sys.argv[2] + " version: v" + re.search(r"__version__\s*=\s*[\"'"'"'](.*?)[\"'"'"']", open(sys.argv[1]).read()).group(1))' "$version_file" "$label" 2>/dev/null; then
      return 0
    fi
  fi
  echo "$label version: unknown"
}
# Report the versions recorded in the version.py files of the current checkout.
get_version "./python/sglang/version.py" "sglang"
get_version "./sgl-kernel/python/sgl_kernel/version.py" "sgl_kernel"
# Remote repositories and branches whose head commits are reported further down.
SGLANG_URL="https://github.com/sgl-project/sglang.git"
SGL_KERNEL_URL="https://github.com/sgl-project/sgl-kernel-npu.git"
SGLANG_BRANCH="main"
SGL_KERNEL_BRANCH="main"
#######################################
# Print the commit SHA at the head of a remote branch.
# Arguments:
#   $1 - label used in the output line
#   $2 - git remote URL
#   $3 - branch name (looked up as refs/heads/<branch>)
# Outputs:
#   "<label> SHA for branch <branch>: <sha>" on stdout, with "Not Found" in
#   place of the SHA when the lookup fails or returns nothing.
# Returns:
#   0 even when the lookup fails; this is informational logging only.
#######################################
get_sha() {
  local name="$1"
  local url="$2"
  local branch="$3"
  local sha
  # Under `set -e` with `pipefail`, a failing `git ls-remote` would abort the
  # calling script before the "Not Found" fallback below could be printed.
  sha=$(git ls-remote "$url" "refs/heads/$branch" | cut -f1) || sha=""
  echo "$name SHA for branch $branch: ${sha:-"Not Found"}"
}
# Report the current head commit of each configured upstream branch.
get_sha "sglang" "$SGLANG_URL" "$SGLANG_BRANCH"
get_sha "sgl-kernel" "$SGL_KERNEL_URL" "$SGL_KERNEL_BRANCH"
# NOTE(review): presumably this marks this script itself executable. The path
# is relative to the current directory, so under `set -e` the run fails here
# when invoked from anywhere but the repository root - confirm it is intended.
chmod +x scripts/ci/npu/npu_log_print.sh

View File

@@ -0,0 +1,477 @@
#!/usr/bin/env python3
"""
CI Coverage Report Generator
Collects all CI test registrations from test/registered/ and generates
a coverage report organized by folder, backend, and suite.
Usage:
python scripts/ci/utils/ci_coverage_report.py [--output-format markdown|json]
"""
import argparse
import glob
import json
import os
import sys
from collections import defaultdict
from pathlib import Path
# Add the ci_register module path directly to avoid heavy sglang imports
sys.path.insert(
0,
str(
Path(__file__).parent.parent.parent.parent / "python" / "sglang" / "test" / "ci"
),
)
from ci_register import CIRegistry, HWBackend, ut_parse_one_file
def collect_all_tests(registered_dir: str) -> list[CIRegistry]:
    """Gather the CI registrations from every ``*.py`` file under a directory.

    Files are searched recursively and visited in sorted path order. A file
    that raises while being parsed is reported on stderr and skipped.
    """
    pattern = f"{registered_dir}/**/*.py"
    collected = []
    for path in sorted(glob.glob(pattern, recursive=True)):
        try:
            parsed = ut_parse_one_file(path)
            collected.extend(parsed)
        except Exception as exc:
            print(f"Warning: Failed to parse {path}: {exc}", file=sys.stderr)
    return collected
def get_folder_name(filename: str) -> str:
    """Return the first directory below ``registered`` in a test path.

    For example ``registered/models/test_foo.py`` gives ``models``. Paths
    without a ``registered`` component, and files sitting directly inside
    it, give ``root``.
    """
    segments = Path(filename).parts
    try:
        anchor = segments.index("registered")
    except ValueError:
        return "root"

    subfolder_pos = anchor + 1
    # The final segment is the file itself, so a subfolder must come before it.
    if subfolder_pos >= len(segments) - 1:
        return "root"
    return segments[subfolder_pos]
def get_test_basename(filename: str) -> str:
    """Return the final path component (the file name) of ``filename``."""
    test_path = Path(filename)
    return test_path.name
def organize_test_data(tests: list[CIRegistry]) -> dict:
    """Group registrations by backend and folder and compute the totals.

    Returns a dict holding registration counts, unique-file counts, the
    ``by_backend`` and ``by_folder`` groupings, and the list of disabled
    registrations.
    """
    backend_groups = defaultdict(list)
    folder_groups = defaultdict(list)
    disabled = []
    # A file may be registered for several backends, so files are counted
    # separately from registrations.
    all_files = set()
    enabled_files = set()
    disabled_files = set()

    for reg in tests:
        backend_groups[reg.backend.name].append(reg)
        folder_groups[get_folder_name(reg.filename)].append(reg)
        all_files.add(reg.filename)
        if reg.disabled:
            disabled.append(reg)
            disabled_files.add(reg.filename)
        else:
            enabled_files.add(reg.filename)

    disabled_total = len(disabled)
    return {
        "total": len(tests),
        "total_unique_files": len(all_files),
        "enabled": len(tests) - disabled_total,
        "enabled_unique_files": len(enabled_files),
        "disabled_count": disabled_total,
        "disabled_unique_files": len(disabled_files),
        "by_backend": backend_groups,
        "by_folder": folder_groups,
        "disabled_tests": disabled,
    }
def generate_summary_section(data: dict) -> str:
"""Generate the summary/overview section."""
lines = []
lines.append("# CI Coverage Overview\n")
lines.append(
f"**Unique Test Files:** {data['total_unique_files']} ({data['enabled_unique_files']} enabled, {data['disabled_unique_files']} disabled)\n"
)
lines.append(
f"**Total Registrations:** {data['total']} ({data['enabled']} enabled, {data['disabled_count']} disabled)\n"
)
lines.append(
"*Note: A test file may be registered for multiple backends (e.g., CUDA + AMD), so total registrations > unique files.*\n"
)
by_backend = data["by_backend"]
by_folder = data["by_folder"]
disabled_tests = data["disabled_tests"]
# Backend summary (collapsible)
lines.append("<details>")
lines.append("<summary><h2>Backend Summary</h2></summary>\n")
lines.append("| Backend | Total | Enabled | Disabled | Per-Commit | Nightly |")
lines.append("|---------|-------|---------|----------|------------|---------|")
for backend in ["CUDA", "AMD", "NPU", "CPU"]:
backend_tests = by_backend.get(backend, [])
if not backend_tests:
continue
b_total = len(backend_tests)
b_disabled = sum(1 for t in backend_tests if t.disabled)
b_enabled = b_total - b_disabled
b_per_commit = sum(1 for t in backend_tests if not t.nightly and not t.disabled)
b_nightly = sum(1 for t in backend_tests if t.nightly and not t.disabled)
lines.append(
f"| {backend} | {b_total} | {b_enabled} | {b_disabled} | {b_per_commit} | {b_nightly} |"
)
lines.append("\n</details>\n")
# Folder summary (collapsible)
lines.append("<details>")
lines.append("<summary><h2>Folder Summary</h2></summary>\n")
lines.append("| Folder | CUDA | AMD | NPU | CPU | Total |")
lines.append("|--------|------|-----|-----|-----|-------|")
for folder in sorted(by_folder.keys()):
folder_tests = by_folder[folder]
cuda = sum(1 for t in folder_tests if t.backend == HWBackend.CUDA)
amd = sum(1 for t in folder_tests if t.backend == HWBackend.AMD)
npu = sum(1 for t in folder_tests if t.backend == HWBackend.NPU)
cpu = sum(1 for t in folder_tests if t.backend == HWBackend.CPU)
lines.append(
f"| {folder} | {cuda} | {amd} | {npu} | {cpu} | {len(folder_tests)} |"
)
lines.append("\n</details>\n")
# Disabled tests section (collapsible)
if disabled_tests:
lines.append("<details>")
lines.append("<summary><h2>Disabled Tests</h2></summary>\n")
lines.append("| File | Backend | Suite | Reason |")
lines.append("|------|---------|-------|--------|")
for t in sorted(disabled_tests, key=lambda x: (x.backend.name, x.filename)):
test_name = get_test_basename(t.filename)
reason = t.disabled[:50] + "..." if len(t.disabled) > 50 else t.disabled
lines.append(f"| `{test_name}` | {t.backend.name} | {t.suite} | {reason} |")
lines.append("\n</details>\n")
return "\n".join(lines)
def generate_by_folder_section(data: dict) -> str:
    """Render the 'All Tests by Folder' markdown section.

    Each folder becomes a collapsible block holding one table per backend,
    with rows sorted by file name.
    """
    by_folder = data["by_folder"]
    out = ["# All Tests by Folder\n"]

    for folder in sorted(by_folder.keys()):
        regs = by_folder[folder]
        out.append("<details>")
        out.append(f"<summary><h2>{folder}/ ({len(regs)} tests)</h2></summary>\n")

        # Group by backend within folder
        per_backend = defaultdict(list)
        for reg in regs:
            per_backend[reg.backend.name].append(reg)

        for backend in ("CUDA", "AMD", "NPU", "CPU"):
            backend_regs = per_backend.get(backend, [])
            if not backend_regs:
                continue
            out.append(f"### {backend} ({len(backend_regs)} tests)\n")
            out.append("| Test File | Suite | Est. Time | Status |")
            out.append("|-----------|-------|-----------|--------|")
            for reg in sorted(backend_regs, key=lambda r: r.filename):
                test_name = get_test_basename(reg.filename)
                if reg.disabled:
                    status = "Disabled"
                elif reg.nightly:
                    status = "Nightly"
                else:
                    status = "Per-Commit"
                out.append(
                    f"| `{test_name}` | {reg.suite} | {reg.est_time:.0f}s | {status} |"
                )
            out.append("")

        out.append("</details>\n")

    return "\n".join(out)
def generate_by_suite_section(data: dict) -> str:
    """Render the 'All Tests by Test Suite' markdown section.

    Each backend becomes a collapsible block containing one nested
    collapsible block per suite. A suite is labelled Nightly when any of its
    enabled tests is nightly.
    """
    by_backend = data["by_backend"]
    out = ["# All Tests by Test Suite\n"]

    for backend in ("CUDA", "AMD", "NPU", "CPU"):
        backend_regs = by_backend.get(backend, [])
        if not backend_regs:
            continue

        disabled_total = sum(1 for r in backend_regs if r.disabled)
        enabled_total = len(backend_regs) - disabled_total
        out.append("<details>")
        out.append(
            f"<summary><h2>{backend} Backend ({enabled_total} enabled, {disabled_total} disabled)</h2></summary>\n"
        )

        # Group by suite within backend
        per_suite = defaultdict(list)
        for reg in backend_regs:
            per_suite[reg.suite].append(reg)

        for suite in sorted(per_suite.keys()):
            suite_regs = per_suite[suite]
            s_enabled = sum(1 for r in suite_regs if not r.disabled)
            s_disabled = sum(1 for r in suite_regs if r.disabled)
            s_est_time = sum(r.est_time for r in suite_regs if not r.disabled)
            if any(r.nightly for r in suite_regs if not r.disabled):
                suite_type = "Nightly"
            else:
                suite_type = "Per-Commit"

            out.append("<details>")
            out.append(
                f"<summary><h3>{suite} ({s_enabled} enabled, {s_disabled} disabled) - {suite_type}</h3></summary>\n"
            )
            out.append(f"*Estimated total time: {s_est_time:.0f}s*\n")
            out.append("| Test File | Folder | Est. Time | Status |")
            out.append("|-----------|--------|-----------|--------|")

            for reg in sorted(suite_regs, key=lambda r: r.filename):
                test_name = get_test_basename(reg.filename)
                folder = get_folder_name(reg.filename)
                if not reg.disabled:
                    status = "Nightly" if reg.nightly else "Per-Commit"
                elif len(reg.disabled) > 30:
                    # Long reasons are cut to 30 characters to keep the table narrow.
                    status = f"Disabled: {reg.disabled[:30]}..."
                else:
                    status = f"Disabled: {reg.disabled}"
                out.append(
                    f"| `{test_name}` | {folder} | {reg.est_time:.0f}s | {status} |"
                )
            out.append("\n</details>\n")

        out.append("</details>\n")

    return "\n".join(out)
def generate_markdown_report(tests: list[CIRegistry], section: str = "all") -> str:
    """Build the markdown report, or one named section of it.

    ``section`` may be ``summary``, ``by-folder`` or ``by-suite``; any other
    value (including the default ``all``) produces all three sections
    separated by ``---`` lines.
    """
    data = organize_test_data(tests)

    single_section = {
        "summary": generate_summary_section,
        "by-folder": generate_by_folder_section,
        "by-suite": generate_by_suite_section,
    }
    render = single_section.get(section)
    if render is not None:
        return render(data)

    # "all"
    return "\n".join(
        [
            generate_summary_section(data),
            "---",
            generate_by_folder_section(data),
            "---",
            generate_by_suite_section(data),
        ]
    )
def generate_json_report(tests: list[CIRegistry]) -> str:
    """Build the JSON report with detailed test listings.

    The top-level keys are ``summary``, ``tests_by_folder``,
    ``tests_by_suite``, ``backend_summary``, ``folder_summary`` and
    ``disabled_tests``; the result is serialized with a two-space indent.
    """
    backend_order = ("CUDA", "AMD", "NPU", "CPU")

    def status_of(reg):
        if reg.disabled:
            return "disabled"
        return "nightly" if reg.nightly else "per-commit"

    def count_enabled(regs):
        return sum(1 for r in regs if not r.disabled)

    def count_disabled(regs):
        return sum(1 for r in regs if r.disabled)

    def by_filename(regs):
        return sorted(regs, key=lambda r: r.filename)

    by_backend = defaultdict(list)
    by_folder = defaultdict(list)
    for reg in tests:
        by_backend[reg.backend.name].append(reg)
        by_folder[get_folder_name(reg.filename)].append(reg)
    disabled_tests = [reg for reg in tests if reg.disabled]

    # Build structured data
    data = {
        "summary": {
            "total": len(tests),
            "enabled": len(tests) - len(disabled_tests),
            "disabled": len(disabled_tests),
        },
        "tests_by_folder": {},
        "tests_by_suite": {},
        "backend_summary": {},
        "folder_summary": {},
        "disabled_tests": [],
    }

    # Section 1: Tests by Folder
    for folder in sorted(by_folder.keys()):
        folder_regs = by_folder[folder]
        per_backend = defaultdict(list)
        for reg in folder_regs:
            per_backend[reg.backend.name].append(reg)

        backends_out = {}
        data["tests_by_folder"][folder] = {
            "total": len(folder_regs),
            "backends": backends_out,
        }
        for backend in backend_order:
            backend_regs = per_backend.get(backend, [])
            if not backend_regs:
                continue
            backends_out[backend] = [
                {
                    "filename": get_test_basename(reg.filename),
                    "suite": reg.suite,
                    "est_time": reg.est_time,
                    "status": status_of(reg),
                }
                for reg in by_filename(backend_regs)
            ]

    # Section 2: Tests by Suite (Backend -> Suite)
    for backend in backend_order:
        backend_regs = by_backend.get(backend, [])
        if not backend_regs:
            continue

        per_suite = defaultdict(list)
        for reg in backend_regs:
            per_suite[reg.suite].append(reg)

        suites_out = {}
        data["tests_by_suite"][backend] = {
            "total": len(backend_regs),
            "enabled": count_enabled(backend_regs),
            "disabled": count_disabled(backend_regs),
            "suites": suites_out,
        }
        for suite in sorted(per_suite.keys()):
            suite_regs = per_suite[suite]
            # A suite counts as nightly when any of its enabled tests is nightly.
            is_nightly = any(r.nightly for r in suite_regs if not r.disabled)
            suites_out[suite] = {
                "total": len(suite_regs),
                "enabled": count_enabled(suite_regs),
                "disabled": count_disabled(suite_regs),
                "est_time": sum(r.est_time for r in suite_regs if not r.disabled),
                "type": "nightly" if is_nightly else "per-commit",
                "tests": [
                    {
                        "filename": get_test_basename(reg.filename),
                        "folder": get_folder_name(reg.filename),
                        "est_time": reg.est_time,
                        "status": status_of(reg),
                        "disabled_reason": reg.disabled if reg.disabled else None,
                    }
                    for reg in by_filename(suite_regs)
                ],
            }

    # Backend summary
    for backend in backend_order:
        backend_regs = by_backend.get(backend, [])
        if not backend_regs:
            continue
        data["backend_summary"][backend] = {
            "total": len(backend_regs),
            "enabled": count_enabled(backend_regs),
            "disabled": count_disabled(backend_regs),
            "per_commit": sum(
                1 for r in backend_regs if not r.nightly and not r.disabled
            ),
            "nightly": sum(1 for r in backend_regs if r.nightly and not r.disabled),
        }

    # Folder summary
    for folder in sorted(by_folder.keys()):
        folder_regs = by_folder[folder]
        data["folder_summary"][folder] = {
            "CUDA": sum(1 for r in folder_regs if r.backend == HWBackend.CUDA),
            "AMD": sum(1 for r in folder_regs if r.backend == HWBackend.AMD),
            "NPU": sum(1 for r in folder_regs if r.backend == HWBackend.NPU),
            "CPU": sum(1 for r in folder_regs if r.backend == HWBackend.CPU),
            "total": len(folder_regs),
        }

    # Disabled tests
    for reg in sorted(disabled_tests, key=lambda r: (r.backend.name, r.filename)):
        data["disabled_tests"].append(
            {
                "filename": get_test_basename(reg.filename),
                "backend": reg.backend.name,
                "suite": reg.suite,
                "reason": reg.disabled,
            }
        )

    return json.dumps(data, indent=2)
def main():
    """Command-line entry point: build the report, print it, and publish it.

    The working directory is changed to the repository root, located
    relative to this file, before tests are collected. A markdown report is
    also appended to the file named by ``GITHUB_STEP_SUMMARY`` when that
    variable is set.
    """
    parser = argparse.ArgumentParser(description="Generate CI coverage report")
    parser.add_argument(
        "--output-format",
        choices=["markdown", "json"],
        default="markdown",
        help="Output format (default: markdown)",
    )
    parser.add_argument(
        "--section",
        choices=["all", "summary", "by-folder", "by-suite"],
        default="all",
        help="Which section to output (default: all). Only applies to markdown format.",
    )
    parser.add_argument(
        "--registered-dir",
        default="test/registered",
        help="Path to registered test directory",
    )
    args = parser.parse_args()

    # Change to repo root if needed (four levels above this file)
    repo_root = Path(__file__).parent.parent.parent.parent
    os.chdir(repo_root)

    tests = collect_all_tests(args.registered_dir)

    wants_markdown = args.output_format == "markdown"
    if wants_markdown:
        report = generate_markdown_report(tests, section=args.section)
    else:
        report = generate_json_report(tests)
    print(report)

    # Write to GITHUB_STEP_SUMMARY if available
    summary_file = os.environ.get("GITHUB_STEP_SUMMARY")
    if summary_file and wants_markdown:
        with open(summary_file, "a") as summary_handle:
            summary_handle.write(report)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,146 @@
#!/usr/bin/env python3
"""
Clean up stale HuggingFace cache artifacts from previous failed downloads.
This script removes incomplete marker files, temporary files, and lock files
from the HuggingFace cache directory. These artifacts can accumulate from
interrupted or failed downloads and may interfere with future downloads.
"""
import os
import sys
from pathlib import Path
from typing import List
try:
from huggingface_hub import constants
HF_HUB_AVAILABLE = True
except ImportError:
print("Warning: huggingface_hub not available")
HF_HUB_AVAILABLE = False
def get_hf_cache_dir() -> str:
    """Get the HuggingFace cache directory.

    When ``huggingface_hub`` is importable its own ``HF_HUB_CACHE`` constant
    is returned. Otherwise the location is derived from the environment:
    ``HF_HUB_CACHE``, then ``HUGGINGFACE_HUB_CACHE``, then ``<HF_HOME>/hub``,
    where ``HF_HOME`` defaults to ``huggingface`` under ``XDG_CACHE_HOME``
    (or under ``~/.cache`` when that is unset).
    """
    if HF_HUB_AVAILABLE:
        return constants.HF_HUB_CACHE

    # Fallback: honor an explicit cache override first so the fallback does
    # not scan a different directory than the library would use.
    explicit_cache = os.environ.get("HF_HUB_CACHE") or os.environ.get(
        "HUGGINGFACE_HUB_CACHE"
    )
    if explicit_cache:
        return os.path.expanduser(explicit_cache)

    cache_root = os.environ.get("XDG_CACHE_HOME") or os.path.expanduser("~/.cache")
    hf_home = os.environ.get("HF_HOME", os.path.join(cache_root, "huggingface"))
    return os.path.join(hf_home, "hub")
def find_stale_artifacts(cache_dir: str) -> List[Path]:
    """
    Collect stale artifact paths anywhere below the cache directory.

    Args:
        cache_dir: HuggingFace cache directory

    Returns:
        Paths matching the stale-file patterns, grouped by pattern in the
        order the patterns are listed; empty when the directory is missing.
    """
    root = Path(cache_dir)
    if not root.exists():
        return []

    stale_globs = (
        "**/*.incomplete",  # Incomplete download markers
        "**/*.tmp",  # Temporary files
        "**/*.lock",  # Lock files from interrupted downloads
    )
    return [match for stale_glob in stale_globs for match in root.glob(stale_glob)]
def cleanup_artifacts(artifacts: List[Path]) -> tuple[int, int]:
    """
    Delete each listed file, continuing past individual failures.

    Args:
        artifacts: List of file paths to remove

    Returns:
        Tuple of (successful_removals, failed_removals)
    """
    removed_count = 0
    error_count = 0

    for artifact in artifacts:
        try:
            artifact.unlink()
            print(f" Removed: {artifact}")
        except Exception as exc:
            print(f" Warning: Could not remove {artifact}: {exc}")
            error_count += 1
            continue
        removed_count += 1

    return removed_count, error_count
def main() -> int:
    """
    Scan the HuggingFace cache for stale artifacts and remove them.

    Returns:
        Always returns 0 (cleanup is best-effort and should not fail CI)
    """
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    print(heavy_rule)
    print("HuggingFace Cache Cleanup")
    print(heavy_rule)

    # Get cache directory
    cache_dir = get_hf_cache_dir()
    print(f"Cache directory: {cache_dir}")
    if not os.path.exists(cache_dir):
        print("Cache directory does not exist - nothing to clean")
        return 0
    print(light_rule)

    # Find stale artifacts
    print("Scanning for stale artifacts...")
    stale_artifacts = find_stale_artifacts(cache_dir)
    if not stale_artifacts:
        print("✓ No stale cache artifacts found")
        return 0

    # Clean up artifacts
    print(f"Found {len(stale_artifacts)} stale artifact(s) to remove:")
    removed, not_removed = cleanup_artifacts(stale_artifacts)
    print(light_rule)

    # Summary
    if not_removed > 0:
        print(f"⚠ Cleaned up {removed} file(s), {not_removed} removal(s) failed")
    else:
        print(f"✓ Successfully cleaned up {removed} stale file(s)")

    # Always return 0 - cleanup failures should not fail CI
    return 0


if __name__ == "__main__":
    # Every exit path uses status 0 so this step can never fail the CI job.
    try:
        sys.exit(main())
    except KeyboardInterrupt:
        print("\nInterrupted by user")
        sys.exit(0)
    except Exception as e:
        print(f"ERROR: Unexpected error during cleanup: {e}")
        import traceback

        traceback.print_exc()
        # Still return 0 - cleanup failures should not fail CI
        sys.exit(0)

View File

@@ -0,0 +1,157 @@
{
"_comment": "Per-model comparison config. Sampling params omitted where model defaults are correct — only override resolution, seed, and params that differ from defaults.",
"test_image_url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png",
"cases": [
{
"id": "flux1_dev_t2i_1024",
"model": "black-forest-labs/FLUX.1-dev",
"task": "text-to-image",
"prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets",
"width": 1024,
"height": 1024,
"seed": 42,
"num_gpus": 1,
"frameworks": {
"sglang": {
"serve_args": "--enable-torch-compile --warmup --dit-layerwise-offload false",
"extra_env": {}
}
}
},
{
"id": "flux2_dev_t2i_1024",
"model": "black-forest-labs/FLUX.2-dev",
"task": "text-to-image",
"prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets",
"width": 1024,
"height": 1024,
"seed": 42,
"num_gpus": 1,
"frameworks": {
"sglang": {
"serve_args": "--enable-torch-compile --warmup --dit-layerwise-offload false",
"extra_env": {}
}
}
},
{
"id": "qwen_image_2512_t2i_1024",
"model": "Qwen/Qwen-Image-2512",
"task": "text-to-image",
"prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets",
"width": 1024,
"height": 1024,
"seed": 42,
"num_gpus": 1,
"frameworks": {
"sglang": {
"serve_args": "--enable-torch-compile --warmup",
"extra_env": {}
}
}
},
{
"id": "qwen_image_edit_2511",
"model": "Qwen/Qwen-Image-Edit-2511",
"task": "image-edit",
"prompt": "Make the cat wear a red hat",
"reference_image": true,
"width": 1024,
"height": 1024,
"seed": 42,
"num_gpus": 1,
"frameworks": {
"sglang": {
"serve_args": "--enable-torch-compile --warmup",
"extra_env": {}
}
}
},
{
"id": "zimage_turbo_t2i_1024",
"model": "Tongyi-MAI/Z-Image-Turbo",
"task": "text-to-image",
"prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets",
"width": 1024,
"height": 1024,
"seed": 42,
"num_gpus": 1,
"frameworks": {
"sglang": {
"serve_args": "--enable-torch-compile --warmup",
"extra_env": {}
}
}
},
{
"id": "wan22_t2v_a14b_720p",
"model": "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
"task": "text-to-video",
"prompt": "A cat and a dog baking a cake together in a kitchen.",
"width": 1280,
"height": 720,
"num_frames": 81,
"seed": 42,
"num_gpus": 4,
"frameworks": {
"sglang": {
"serve_args": "--enable-torch-compile --warmup --enable-cfg-parallel --ulysses-degree 2 --text-encoder-cpu-offload --pin-cpu-memory",
"extra_env": {}
}
}
},
{
"id": "wan22_ti2v_5b_720p",
"model": "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
"task": "text-image-to-video",
"prompt": "The cat starts walking slowly towards the camera.",
"reference_image": true,
"width": 1280,
"height": 720,
"num_frames": 81,
"seed": 42,
"num_gpus": 1,
"frameworks": {
"sglang": {
"serve_args": "--enable-torch-compile --warmup",
"extra_env": {}
}
}
},
{
"id": "ltx2_twostage_t2v",
"model": "Lightricks/LTX-2",
"task": "text-to-video",
"prompt": "A cat and a dog baking a cake together in a kitchen.",
"width": 768,
"height": 512,
"num_frames": 121,
"seed": 42,
"num_gpus": 2,
"frameworks": {
"sglang": {
"serve_args": "--enable-torch-compile --warmup --enable-cfg-parallel --pipeline-class-name LTX2TwoStagePipeline",
"extra_env": {}
}
}
},
{
"id": "wan22_i2v_a14b_720p",
"model": "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
"task": "image-to-video",
"prompt": "The cat starts walking slowly towards the camera.",
"reference_image": true,
"width": 1280,
"height": 720,
"num_frames": 81,
"seed": 42,
"num_gpus": 4,
"frameworks": {
"sglang": {
"serve_args": "--enable-torch-compile --warmup --enable-cfg-parallel --ulysses-degree 2 --text-encoder-cpu-offload --pin-cpu-memory",
"extra_env": {}
}
}
}
]
}

View File

@@ -0,0 +1,836 @@
"""Generate a Markdown dashboard for diffusion cross-framework comparisons.
Reads current comparison results + historical data from sglang-ci-data repo
and produces a Markdown report with tables and trend charts saved as PNG files.
Usage:
python3 scripts/ci/utils/diffusion/generate_diffusion_dashboard.py \
--results comparison-results.json \
--output dashboard.md \
--charts-dir comparison-charts/ \
--history-dir history/ # optional, local history JSONs
--fetch-history # fetch from GitHub API instead
"""
import argparse
import json
import os
import sys
from datetime import datetime, timezone
# ---------------------------------------------------------------------------
# History fetching (from sglang-ci-data repo via GitHub API)
# ---------------------------------------------------------------------------
CI_DATA_REPO_OWNER = "sglang-bot"
CI_DATA_REPO_NAME = "sglang-ci-data"
CI_DATA_BRANCH = "main"
HISTORY_PREFIX = "diffusion-comparisons"
MAX_HISTORY_RUNS = 14
# Base URL for chart images pushed to sglang-ci-data
CHARTS_RAW_BASE_URL = (
f"https://raw.githubusercontent.com/{CI_DATA_REPO_OWNER}/{CI_DATA_REPO_NAME}"
f"/{CI_DATA_BRANCH}/{HISTORY_PREFIX}/charts"
)
def _github_get(url: str, token: str) -> dict | list | None:
"""Simple GET to GitHub API."""
from urllib.error import HTTPError
from urllib.request import Request, urlopen
headers = {
"Accept": "application/vnd.github+json",
"Authorization": f"Bearer {token}",
"X-GitHub-Api-Version": "2022-11-28",
}
req = Request(url, headers=headers)
try:
with urlopen(req) as resp:
return json.loads(resp.read().decode("utf-8"))
except HTTPError as e:
print(f" Warning: GitHub API request failed ({e.code}): {url}")
return None
except Exception as e:
print(f" Warning: GitHub API request error: {e}")
return None
def fetch_history_from_github(token: str) -> list[dict]:
    """Download the most recent comparison result JSONs from the CI data repo.

    The history directory is listed through the GitHub contents API; the
    ``.json`` entries are ordered by name, newest first, and at most
    ``MAX_HISTORY_RUNS`` of them are downloaded. Entries that fail to load or
    are not JSON objects are dropped.
    """
    print("Fetching historical comparison data from GitHub...")
    listing_url = (
        f"https://api.github.com/repos/{CI_DATA_REPO_OWNER}/{CI_DATA_REPO_NAME}"
        f"/contents/{HISTORY_PREFIX}?ref={CI_DATA_BRANCH}"
    )
    listing = _github_get(listing_url, token)
    if not (listing and isinstance(listing, list)):
        print(" No historical data found.")
        return []

    # Filter JSON files and sort by name (date prefix) descending
    json_entries = [item for item in listing if item["name"].endswith(".json")]
    json_entries.sort(key=lambda item: item["name"], reverse=True)
    recent_entries = json_entries[:MAX_HISTORY_RUNS]

    history = []
    for item in recent_entries:
        raw_url = item.get("download_url")
        if not raw_url:
            continue
        payload = _github_get(raw_url, token)
        if payload and isinstance(payload, dict):
            history.append(payload)

    print(f" Loaded {len(history)} historical run(s).")
    return history
def load_history_from_dir(history_dir: str) -> list[dict]:
    """Load historical JSONs from a local directory.

    The ``.json`` files are ordered by name, newest first, and at most
    ``MAX_HISTORY_RUNS`` of them are considered. Files that cannot be read or
    parsed, or whose top-level value is not a JSON object, are skipped with a
    warning, matching the filtering applied to history fetched from GitHub.

    Returns:
        The loaded runs, or an empty list when ``history_dir`` is not a
        directory.
    """
    if not os.path.isdir(history_dir):
        return []

    files = sorted(
        (f for f in os.listdir(history_dir) if f.endswith(".json")),
        reverse=True,
    )[:MAX_HISTORY_RUNS]

    history = []
    for fname in files:
        path = os.path.join(history_dir, fname)
        try:
            with open(path) as f:
                data = json.load(f)
        except Exception as e:
            # Best effort: one bad file must not block the dashboard, but it
            # should be visible in the log.
            print(f" Warning: skipping unreadable history file {path}: {e}")
            continue
        if not isinstance(data, dict):
            print(f" Warning: skipping history file {path}: not a JSON object")
            continue
        history.append(data)
    return history
# ---------------------------------------------------------------------------
# Dashboard generation
# ---------------------------------------------------------------------------
def _fmt_latency(val: float | None) -> str:
if val is None:
return "N/A"
return f"{val:.2f}"
def _fmt_speedup(sglang_lat: float | None, other_lat: float | None) -> str:
if sglang_lat is None or other_lat is None or sglang_lat <= 0:
return "N/A"
ratio = other_lat / sglang_lat
return f"{ratio:.2f}x"
def _short_date(ts: str) -> str:
"""Extract short date from ISO timestamp."""
try:
dt = datetime.fromisoformat(ts.replace("Z", "+00:00"))
return dt.strftime("%b %d")
except Exception:
return ts[:10]
def _short_sha(sha: str) -> str:
return sha[:7] if sha and sha != "unknown" else "?"
def _assess_risk(
cid: str,
current_cases: dict[str, dict[str, float | None]],
history: list[dict],
other_frameworks: list[str],
) -> tuple[str, str]:
"""Assess risk for a given case, returning (emoji, reason).
Rules (checked in order):
- N/A latency → ❌ broken
- History exists: SGLang latency >5% vs avg of last 3 runs → ⚠️ regression
- Competitor exists & SGLang slower → 🔴 competitive risk
- SGLang faster than all competitors by >20% → 🟢 strong advantage
- SGLang faster than all competitors by ≤20% → 🟡 moderate advantage
- Default → ✅ stable
"""
sg_lat = current_cases.get(cid, {}).get("sglang")
# Broken: sglang latency is N/A
if sg_lat is None:
return "", f"{cid}: SGLang latency is N/A (broken)"
# Check regression against 3-run historical average
if history:
hist_lats: list[float] = []
for run in history[:3]:
run_cases = _extract_case_results(run)
h_lat = run_cases.get(cid, {}).get("sglang")
if h_lat is not None:
hist_lats.append(h_lat)
if hist_lats:
avg_3 = sum(hist_lats) / len(hist_lats)
if avg_3 > 0 and (sg_lat - avg_3) / avg_3 > 0.05:
pct = (sg_lat - avg_3) / avg_3 * 100
return (
"⚠️",
f"{cid}: SGLang regression +{pct:.1f}% vs 3-run avg "
f"({sg_lat:.2f}s vs {avg_3:.2f}s)",
)
# Check competitive risk
if other_frameworks:
competitor_lats: dict[str, float] = {}
for ofw in other_frameworks:
olat = current_cases.get(cid, {}).get(ofw)
if olat is not None:
competitor_lats[ofw] = olat
if competitor_lats:
# SGLang slower than any competitor?
for ofw, olat in competitor_lats.items():
if sg_lat > olat:
return (
"🔴",
f"{cid}: SGLang slower than {ofw} "
f"({sg_lat:.2f}s vs {olat:.2f}s)",
)
# SGLang faster — check margin
min_competitor = min(competitor_lats.values())
advantage = (min_competitor - sg_lat) / min_competitor
if advantage > 0.20:
return "🟢", ""
else:
return "🟡", ""
# Default: stable
return "", ""
def _trend_emoji(current: float | None, previous: float | None) -> str:
if current is None or previous is None:
return ""
diff_pct = (current - previous) / previous * 100
if diff_pct < -2:
return " :arrow_down:" # faster (good)
elif diff_pct > 2:
return " :arrow_up:" # slower (bad)
return " :left_right_arrow:"
def _extract_case_results(run_data: dict) -> dict[str, dict[str, float | None]]:
"""Extract {case_id: {framework: latency}} from a run."""
mapping: dict[str, dict[str, float | None]] = {}
for r in run_data.get("results", []):
cid = r["case_id"]
fw = r["framework"]
if cid not in mapping:
mapping[cid] = {}
mapping[cid][fw] = r.get("latency_s")
return mapping
def _sanitize_filename(name: str) -> str:
"""Sanitize a case ID to be a safe filename."""
return name.replace("/", "_").replace(" ", "_").replace(":", "_")
def generate_dashboard(
current: dict,
history: list[dict],
charts_dir: str | None = None,
) -> tuple[str, list[str]]:
"""Generate full markdown dashboard.
Returns (markdown_string, alert_reasons) where alert_reasons is a list of
human-readable strings for cases that need attention (empty if all is well).
If charts_dir is provided, saves chart PNGs as files to that directory
and references them via raw.githubusercontent URLs. Otherwise, charts
are omitted.
Returns the markdown string.
"""
lines: list[str] = []
lines.append("# Diffusion Cross-Framework Performance Dashboard\n")
ts = current.get("timestamp", datetime.now(timezone.utc).isoformat())
sha = current.get("commit_sha", "unknown")
lines.append(f"*Generated: {_short_date(ts)} | Commit: `{_short_sha(sha)}`*\n")
current_cases = _extract_case_results(current)
case_ids = list(current_cases.keys())
# ---- Regression detection ----
REGRESSION_THRESHOLD = 0.05 # 5%
regressions: list[str] = []
if history:
prev_cases = _extract_case_results(history[0])
for cid in case_ids:
for fw in ("sglang", "vllm-omni"):
cur = current_cases.get(cid, {}).get(fw)
prev = prev_cases.get(cid, {}).get(fw)
if cur and prev and prev > 0:
pct = (cur - prev) / prev
if pct > REGRESSION_THRESHOLD:
regressions.append(
f"**{cid}** ({fw}): {prev:.2f}s -> {cur:.2f}s "
f"(+{pct*100:.1f}%)"
)
if regressions:
lines.append("> [!WARNING]\n> **Performance Regression Detected**\n>")
for reg in regressions:
lines.append(f"> - {reg}")
lines.append("\n")
# Discover all frameworks present in results
all_frameworks = []
seen_fw = set()
for r in current.get("results", []):
fw = r["framework"]
if fw not in seen_fw:
all_frameworks.append(fw)
seen_fw.add(fw)
# Ensure sglang is first
if "sglang" in all_frameworks:
all_frameworks.remove("sglang")
all_frameworks.insert(0, "sglang")
other_frameworks = [fw for fw in all_frameworks if fw != "sglang"]
# ---- Section 1: Cross-Framework Comparison (current run) ----
lines.append("## Cross-Framework Performance Comparison\n")
# Compute risk assessments for all cases
risk_map: dict[str, tuple[str, str]] = {}
for cid in case_ids:
risk_map[cid] = _assess_risk(cid, current_cases, history, other_frameworks)
# Dynamic header
header = "| Model | Risk |"
sep = "|-------|------|"
for fw in all_frameworks:
header += f" {fw} (s) |"
sep += "---------|"
for ofw in other_frameworks:
header += f" vs {ofw} |"
sep += "---------|"
lines.append(header)
lines.append(sep)
# One row per case (deduplicated by case_id)
seen_cases = set()
for r in current.get("results", []):
cid = r["case_id"]
if cid in seen_cases:
continue
seen_cases.add(cid)
case_fws = current_cases.get(cid, {})
sg_lat = case_fws.get("sglang")
risk_emoji, _ = risk_map.get(cid, ("", ""))
row = f"| {r['model'].split('/')[-1]} | {risk_emoji} |"
# Latency columns -- bold the fastest
lats = {fw: case_fws.get(fw) for fw in all_frameworks}
valid_lats = [v for v in lats.values() if v is not None]
min_lat = min(valid_lats) if valid_lats else None
for fw in all_frameworks:
lat = lats[fw]
if lat is not None and min_lat is not None and lat == min_lat:
row += f" **{_fmt_latency(lat)}** |"
else:
row += f" {_fmt_latency(lat)} |"
# Speedup columns
for ofw in other_frameworks:
row += f" {_fmt_speedup(sg_lat, case_fws.get(ofw))} |"
lines.append(row)
# ---- Section 2: Cross-Framework Speedup Trend (only if multiple frameworks) ----
if history and other_frameworks:
lines.append("\n## SGLang vs vLLM-Omni Speedup Over Time\n")
header = "| Date |"
sep = "|------|"
for cid in case_ids:
header += f" {cid} |"
sep += "---------|"
lines.append(header)
lines.append(sep)
all_runs = [current] + history
for run in all_runs:
run_cases = _extract_case_results(run)
date = _short_date(run.get("timestamp", ""))
row = f"| {date} |"
for cid in case_ids:
sg = run_cases.get(cid, {}).get("sglang")
vl = run_cases.get(cid, {}).get("vllm-omni")
row += f" {_fmt_speedup(sg, vl)} |"
lines.append(row)
# ---- Section 4: Matplotlib Trend Charts (saved as PNG files) ----
if history and charts_dir:
all_runs = list(reversed([current] + history)) # chronological order
def _chart_label(run: dict) -> str:
d = _short_date(run.get("timestamp", ""))
s = _short_sha(run.get("commit_sha", ""))
return f"{d}\n({s})"
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
os.makedirs(charts_dir, exist_ok=True)
# Per-case latency trend charts
for cid in case_ids:
labels = []
sg_vals = []
vl_vals = []
for run in all_runs:
run_cases = _extract_case_results(run)
sg = run_cases.get(cid, {}).get("sglang")
vl = run_cases.get(cid, {}).get("vllm-omni")
if sg is None:
continue
labels.append(_chart_label(run))
sg_vals.append(sg)
vl_vals.append(vl)
if not sg_vals:
continue
has_vl = any(v is not None for v in vl_vals)
fig, ax = plt.subplots(figsize=(max(6, len(labels) * 1.2), 4))
# SGLang line
ax.plot(
range(len(sg_vals)),
sg_vals,
"o-",
color="#2563eb",
linewidth=2,
markersize=6,
label="SGLang",
)
for i, v in enumerate(sg_vals):
ax.annotate(
f"{v:.2f}s",
(i, v),
textcoords="offset points",
xytext=(0, 10),
ha="center",
fontsize=8,
fontweight="bold",
color="#2563eb",
)
# vLLM-Omni line (if data exists)
if has_vl:
vl_clean = [v if v is not None else float("nan") for v in vl_vals]
ax.plot(
range(len(vl_clean)),
vl_clean,
"s--",
color="#dc2626",
linewidth=2,
markersize=5,
label="vLLM-Omni",
)
for i, v in enumerate(vl_vals):
if v is not None:
ax.annotate(
f"{v:.2f}s",
(i, v),
textcoords="offset points",
xytext=(0, -14),
ha="center",
fontsize=8,
color="#dc2626",
)
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, fontsize=7)
ax.set_ylabel("Latency (s)")
ax.set_title(f"Latency Trend -- {cid}", fontsize=11, fontweight="bold")
ax.legend(loc="lower right", fontsize=8, framealpha=0.8)
ax.grid(True, alpha=0.3)
all_vals = sg_vals + [v for v in vl_vals if v is not None]
y_min = min(all_vals)
y_max = max(all_vals)
y_range = y_max - y_min if y_max > y_min else max(y_max * 0.1, 0.1)
ax.set_ylim(
bottom=max(0, y_min - y_range * 0.3),
top=y_max + y_range * 0.3,
)
filename = f"latency_{_sanitize_filename(cid)}.png"
chart_path = os.path.join(charts_dir, filename)
fig.savefig(chart_path, format="png", dpi=120, bbox_inches="tight")
plt.close(fig)
print(f" Saved chart: {chart_path}")
chart_url = f"{CHARTS_RAW_BASE_URL}/{filename}"
lines.append(f"\n### Latency Trend: {cid}\n")
lines.append(f"![Latency Trend {cid}]({chart_url})\n")
# Speedup trend chart (only if multiple frameworks)
if other_frameworks:
fig, ax = plt.subplots(figsize=(max(6, len(all_runs) * 1.2), 4))
colors = ["#2563eb", "#dc2626", "#16a34a", "#ea580c"]
for ci_idx, cid in enumerate(case_ids):
speedups = []
run_labels = []
for run in all_runs:
run_cases = _extract_case_results(run)
sg = run_cases.get(cid, {}).get("sglang")
vl = run_cases.get(cid, {}).get("vllm-omni")
if sg and vl and sg > 0:
speedups.append(vl / sg)
else:
speedups.append(None)
run_labels.append(_chart_label(run))
clean = [v if v is not None else float("nan") for v in speedups]
ax.plot(
range(len(clean)),
clean,
"o-",
color=colors[ci_idx % len(colors)],
linewidth=2,
markersize=5,
label=cid,
)
ax.set_xticks(range(len(run_labels)))
ax.set_xticklabels(run_labels, fontsize=7)
ax.set_ylabel("Speedup (x)")
ax.set_title(
"SGLang Speedup Over vLLM-Omni", fontsize=11, fontweight="bold"
)
ax.axhline(y=1.0, color="gray", linestyle=":", alpha=0.5)
ax.legend(loc="upper left", fontsize=7)
ax.grid(True, alpha=0.3)
filename = "speedup_trend.png"
chart_path = os.path.join(charts_dir, filename)
fig.savefig(chart_path, format="png", dpi=120, bbox_inches="tight")
plt.close(fig)
print(f" Saved chart: {chart_path}")
chart_url = f"{CHARTS_RAW_BASE_URL}/{filename}"
lines.append("\n### Speedup Trend (SGLang vs vLLM-Omni)\n")
lines.append(f"![Speedup Trend]({chart_url})\n")
except ImportError:
lines.append("\n*Charts unavailable (matplotlib not installed)*\n")
# ---- SGLang Performance Trend (raw data table, at the end) ----
if history:
lines.append(f"\n## SGLang Performance Trend (Last {len(history) + 1} Runs)\n")
header = "| Date | Commit |"
sep = "|------|--------|"
for cid in case_ids:
header += f" {cid} (s) |"
sep += "---------|"
header += " Trend |"
sep += "-------|"
lines.append(header)
lines.append(sep)
all_runs = [current] + history
for i, run in enumerate(all_runs):
run_cases = _extract_case_results(run)
date = _short_date(run.get("timestamp", ""))
sha_s = _short_sha(run.get("commit_sha", ""))
row = f"| {date} | `{sha_s}` |"
for cid in case_ids:
lat = run_cases.get(cid, {}).get("sglang")
row += f" {_fmt_latency(lat)} |"
if i + 1 < len(all_runs):
prev_cases = _extract_case_results(all_runs[i + 1])
emojis = []
for cid in case_ids:
cur = run_cases.get(cid, {}).get("sglang")
prev = prev_cases.get(cid, {}).get("sglang")
emojis.append(_trend_emoji(cur, prev))
row += " ".join(emojis) + " |"
else:
row += " -- |"
lines.append(row)
# ---- Risk Notification ----
alert_cases = [
(cid, emoji, reason)
for cid, (emoji, reason) in risk_map.items()
if emoji in ("⚠️", "🔴", "")
]
if alert_cases:
lines.append("\n> [!CAUTION]")
lines.append("> **Action Required — Performance Alert**")
lines.append(">")
lines.append("> The following cases need attention:")
for _cid, _emoji, reason in alert_cases:
lines.append(f"> - {reason}")
lines.append("")
# Footer
lines.append("\n---")
lines.append(
"*Generated by `generate_diffusion_dashboard.py` in SGLang nightly CI.*"
)
alert_reasons = [reason for _, _, reason in alert_cases]
return "\n".join(lines) + "\n", alert_reasons
# GitHub users passed as --assignee when a new tracker issue is created.
ALERT_ASSIGNEES = ["mickqian", "bbuf", "yhyang201"]
# Label used both to tag a new tracker issue and to look up an existing one.
ALERT_LABEL = "perf-regression"
# Title given to the tracker issue when it is first created.
ALERT_ISSUE_TITLE = "[Diffusion CI] Performance regression tracker"
def _find_alert_issue(repo: str) -> tuple[str | None, bool]:
"""Find the perf-regression tracker issue (open OR closed).
Returns (issue_number, is_open). Prefers an open issue; if none,
returns the most recent closed one so it can be reopened.
"""
import subprocess
for state in ("open", "closed"):
result = subprocess.run(
[
"gh",
"issue",
"list",
"--repo",
repo,
"--label",
ALERT_LABEL,
"--state",
state,
"--json",
"number",
"--limit",
"1",
],
capture_output=True,
text=True,
timeout=30,
)
if result.returncode != 0 or not result.stdout.strip():
continue
issues = json.loads(result.stdout)
if issues:
return str(issues[0]["number"]), state == "open"
return None, False
def _create_alert_issue(alert_reasons: list[str]) -> None:
"""Create or update the single perf-regression tracker issue.
Logic:
- If an open issue exists → add a comment with the new alert.
- If a closed issue exists → reopen it, then add a comment.
- If no issue exists → create one.
This guarantees at most one tracker issue ever exists.
Uses `gh` (GitHub CLI) which is available in all GitHub Actions runners.
Falls back silently outside CI.
"""
import subprocess
run_url = ""
run_id = os.environ.get("GITHUB_RUN_ID", "")
repo = os.environ.get("GITHUB_REPOSITORY", "sgl-project/sglang")
server_url = os.environ.get("GITHUB_SERVER_URL", "https://github.com")
if run_id:
run_url = f"{server_url}/{repo}/actions/runs/{run_id}"
date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
body_lines = [
f"## Performance Alert — {date}",
"",
"The nightly diffusion benchmark detected the following issue(s):",
"",
]
for reason in alert_reasons:
body_lines.append(f"- {reason}")
if run_url:
body_lines += ["", f"**CI Run:** {run_url}"]
body = "\n".join(body_lines)
try:
existing, is_open = _find_alert_issue(repo)
if existing:
# Reopen if closed
if not is_open:
subprocess.run(
[
"gh",
"issue",
"reopen",
existing,
"--repo",
repo,
],
capture_output=True,
text=True,
timeout=30,
)
print(f"Reopened alert issue #{existing}")
# Add comment
result = subprocess.run(
[
"gh",
"issue",
"comment",
existing,
"--repo",
repo,
"--body",
body,
],
capture_output=True,
text=True,
timeout=30,
)
if result.returncode == 0:
print(f"Commented on alert issue #{existing}")
else:
print(
f"Warning: failed to comment on issue #{existing} "
f"(rc={result.returncode}): {result.stderr.strip()}"
)
else:
# Create a new issue
cmd = [
"gh",
"issue",
"create",
"--repo",
repo,
"--title",
ALERT_ISSUE_TITLE,
"--body",
body,
"--label",
ALERT_LABEL,
]
for user in ALERT_ASSIGNEES:
cmd += ["--assignee", user]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
if result.returncode == 0:
print(f"Created alert issue: {result.stdout.strip()}")
else:
print(
f"Warning: failed to create alert issue "
f"(rc={result.returncode}): {result.stderr.strip()}"
)
except FileNotFoundError:
print("Warning: `gh` CLI not found — skipping alert issue creation")
except Exception as e:
print(f"Warning: failed to create/update alert issue: {e}")
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: load results and history, render the dashboard, raise alerts."""
    cli = argparse.ArgumentParser(
        description="Generate diffusion cross-framework comparison dashboard"
    )
    option_specs = [
        (
            "--results",
            {"required": True, "help": "Path to comparison-results.json from current run"},
        ),
        ("--output", {"default": "dashboard.md", "help": "Output markdown file path"}),
        (
            "--charts-dir",
            {
                "default": "comparison-charts",
                "help": "Directory to save chart PNG files (default: comparison-charts/)",
            },
        ),
        (
            "--history-dir",
            {
                "default": None,
                "help": "Local directory containing historical comparison JSONs",
            },
        ),
        (
            "--fetch-history",
            {
                "action": "store_true",
                "help": "Fetch history from sglang-ci-data GitHub repo",
            },
        ),
        (
            "--step-summary",
            {"action": "store_true", "help": "Also write to $GITHUB_STEP_SUMMARY"},
        ),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # Current run
    with open(opts.results) as results_file:
        current = json.load(results_file)
    print(f"Loaded current results: {len(current.get('results', []))} entries")

    # History: remote fetch takes precedence over a local directory
    history: list[dict] = []
    if opts.fetch_history:
        token = os.environ.get("GH_PAT_FOR_NIGHTLY_CI_DATA") or os.environ.get(
            "GITHUB_TOKEN"
        )
        if not token:
            print("Warning: No GitHub token available, skipping history fetch")
        else:
            history = fetch_history_from_github(token)
    elif opts.history_dir:
        history = load_history_from_dir(opts.history_dir)
        print(f"Loaded {len(history)} historical run(s) from {opts.history_dir}")

    markdown, alert_reasons = generate_dashboard(
        current, history, charts_dir=opts.charts_dir
    )

    # Persist the dashboard, creating the parent directory when one is given
    output_parent = os.path.dirname(opts.output) or "."
    os.makedirs(output_parent, exist_ok=True)
    with open(opts.output, "w") as out_file:
        out_file.write(markdown)
    print(f"Dashboard written to {opts.output}")

    # Optionally mirror the dashboard into the GitHub step summary
    if opts.step_summary:
        summary_file = os.environ.get("GITHUB_STEP_SUMMARY")
        if not summary_file:
            print("Warning: $GITHUB_STEP_SUMMARY not set, skipping")
        else:
            with open(summary_file, "a") as summary:
                summary.write(markdown)
            print("Dashboard appended to $GITHUB_STEP_SUMMARY")

    # File or update the tracker issue so assignees get notified
    if not alert_reasons:
        print("No performance alerts — skipping issue creation.")
    else:
        _create_alert_issue(alert_reasons)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,231 @@
"""Publish diffusion comparison results to sglang-bot/sglang-ci-data repo.
Pushes comparison-results.json, dashboard.md, and chart PNG files to the
ci-data repository for historical tracking. Chart PNGs are stored under
diffusion-comparisons/charts/ so they can be referenced via
raw.githubusercontent URLs in the dashboard markdown (GitHub Step Summary
blocks data: URIs).
Usage:
python3 scripts/ci/utils/diffusion/publish_comparison_results.py \
--results comparison-results.json \
--dashboard dashboard.md \
--charts-dir comparison-charts/
"""
import argparse
import os
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
# Reuse GitHub API helpers from publish_traces.
# Support both direct script execution and package-style imports.
if __package__:
from ..publish_traces import (
create_blobs,
create_commit,
create_tree,
get_branch_sha,
get_tree_sha,
is_permission_error,
is_rate_limit_error,
update_branch_ref,
verify_token_permissions,
)
else:
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from publish_traces import (
create_blobs,
create_commit,
create_tree,
get_branch_sha,
get_tree_sha,
is_permission_error,
is_rate_limit_error,
update_branch_ref,
verify_token_permissions,
)
# Repository configuration: results are committed to BRANCH of
# REPO_OWNER/REPO_NAME through the GitHub API helpers imported above.
REPO_OWNER = "sglang-bot"
REPO_NAME = "sglang-ci-data"
BRANCH = "main"
# Top-level directory inside the data repo that holds every published file.
STORAGE_PREFIX = "diffusion-comparisons"
def _collect_chart_files(charts_dir: str) -> list[tuple[str, bytes]]:
"""Collect PNG chart files from directory for upload."""
files: list[tuple[str, bytes]] = []
if not charts_dir or not os.path.isdir(charts_dir):
return files
for entry in sorted(os.listdir(charts_dir)):
if not entry.lower().endswith(".png"):
continue
full_path = os.path.join(charts_dir, entry)
if not os.path.isfile(full_path):
continue
with open(full_path, "rb") as f:
content = f.read()
# Store charts under diffusion-comparisons/charts/
repo_path = f"{STORAGE_PREFIX}/charts/{entry}"
files.append((repo_path, content))
return files
def publish_comparison(
results_path: str,
dashboard_path: str | None = None,
charts_dir: str | None = None,
) -> None:
"""Publish comparison results, dashboard, and charts to ci-data repo."""
token = os.environ.get("GH_PAT_FOR_NIGHTLY_CI_DATA") or os.environ.get(
"GITHUB_TOKEN"
)
if not token:
print("Error: GH_PAT_FOR_NIGHTLY_CI_DATA or GITHUB_TOKEN not set")
sys.exit(1)
run_id = os.environ.get("GITHUB_RUN_ID", "local")
run_number = os.environ.get("GITHUB_RUN_NUMBER", "0")
# Verify permissions
perm = verify_token_permissions(REPO_OWNER, REPO_NAME, token)
if perm == "rate_limited":
print("Warning: Rate limited, skipping publish")
return
elif not perm:
print("Error: Token permission verification failed")
sys.exit(1)
# Prepare files to upload
files_to_upload: list[tuple[str, bytes]] = []
# Results JSON: stored with date prefix for chronological ordering
date_prefix = datetime.now(timezone.utc).strftime("%Y-%m-%d")
results_target = f"{STORAGE_PREFIX}/{date_prefix}_{run_id}.json"
with open(results_path, "rb") as f:
files_to_upload.append((results_target, f.read()))
# Dashboard markdown: always overwrite latest
if dashboard_path and os.path.exists(dashboard_path):
dashboard_target = f"{STORAGE_PREFIX}/dashboard.md"
with open(dashboard_path, "rb") as f:
files_to_upload.append((dashboard_target, f.read()))
# Chart PNG files
chart_files = _collect_chart_files(charts_dir)
if chart_files:
print(f"Found {len(chart_files)} chart PNG(s) to upload")
files_to_upload.extend(chart_files)
print(f"Publishing {len(files_to_upload)} file(s) to {REPO_OWNER}/{REPO_NAME}")
# Create blobs
try:
tree_items = create_blobs(REPO_OWNER, REPO_NAME, files_to_upload, token)
except Exception as e:
if is_rate_limit_error(e):
print("Warning: Rate limited during blob creation, skipping")
return
if is_permission_error(e):
print(f"Error: No write permission to {REPO_OWNER}/{REPO_NAME}")
sys.exit(1)
raise
# Commit with retry (handle concurrent writes)
max_retries = 5
retry_delay = 5
for attempt in range(max_retries):
try:
branch_sha = get_branch_sha(REPO_OWNER, REPO_NAME, BRANCH, token)
tree_sha = get_tree_sha(REPO_OWNER, REPO_NAME, branch_sha, token)
new_tree_sha = create_tree(
REPO_OWNER, REPO_NAME, tree_sha, tree_items, token
)
commit_msg = (
f"Diffusion comparison results for run {run_id} (#{run_number})"
)
commit_sha = create_commit(
REPO_OWNER, REPO_NAME, new_tree_sha, branch_sha, commit_msg, token
)
update_branch_ref(REPO_OWNER, REPO_NAME, BRANCH, commit_sha, token)
print(
f"Successfully published comparison results (commit {commit_sha[:7]})"
)
return
except Exception as e:
is_retryable = False
if hasattr(e, "error_body"):
body = getattr(e, "error_body", "")
if "Update is not a fast forward" in body:
is_retryable = True
elif "Object does not exist" in body:
is_retryable = True
from urllib.error import HTTPError
if isinstance(e, HTTPError) and e.code in [422, 500, 502, 503, 504]:
is_retryable = True
if is_rate_limit_error(e):
print("Warning: Rate limited, skipping publish")
return
if is_permission_error(e):
print(f"Error: No write permission to {REPO_OWNER}/{REPO_NAME}")
sys.exit(1)
if is_retryable and attempt < max_retries - 1:
print(
f"Attempt {attempt + 1}/{max_retries} failed, retrying in {retry_delay}s..."
)
time.sleep(retry_delay)
else:
print(f"Failed to publish after {attempt + 1} attempts: {e}")
raise
def main():
    """CLI entry point: validate the results path, then publish everything."""
    cli = argparse.ArgumentParser(
        description="Publish diffusion comparison results to sglang-ci-data"
    )
    cli.add_argument("--results", required=True, help="Path to comparison-results.json")
    cli.add_argument(
        "--dashboard", default=None, help="Path to dashboard.md (optional)"
    )
    cli.add_argument(
        "--charts-dir",
        default=None,
        help="Directory containing chart PNG files to upload (optional)",
    )
    opts = cli.parse_args()

    # Fail fast with a clear message instead of an open() traceback later.
    if not os.path.exists(opts.results):
        print(f"Error: Results file not found: {opts.results}")
        sys.exit(1)

    publish_comparison(
        results_path=opts.results,
        dashboard_path=opts.dashboard,
        charts_dir=opts.charts_dir,
    )


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,166 @@
"""
Publish diffusion CI ground-truth images to sglang-bot/sglang-ci-data
via the GitHub API (same pattern as publish_traces.py).
"""
import argparse
import os
import sys
from pathlib import Path
# Reuse GitHub API helpers from publish_traces.
# Support both direct script execution and package-style imports.
if __package__:
from ..publish_traces import (
create_blobs,
create_commit,
create_tree,
get_branch_sha,
get_tree_sha,
is_permission_error,
is_rate_limit_error,
update_branch_ref,
verify_token_permissions,
)
else:
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from publish_traces import (
create_blobs,
create_commit,
create_tree,
get_branch_sha,
get_tree_sha,
is_permission_error,
is_rate_limit_error,
update_branch_ref,
verify_token_permissions,
)
# Destination repository and branch for the uploaded images.
REPO_OWNER = "sglang-bot"
REPO_NAME = "sglang-ci-data"
BRANCH = "main"
# Directory inside the destination repo where every image is placed.
TARGET_DIR = "diffusion-ci/consistency_gt"
# Lower-cased file extensions that count as images when scanning the source directory.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp"}
def collect_images(source_dir):
    """Read every image directly inside ``source_dir``.

    A file qualifies when its lower-cased extension is in IMAGE_EXTENSIONS
    and it is a regular file. Returns ``(repo_path, content)`` tuples in
    file-name order, with ``repo_path`` rooted at TARGET_DIR.
    """
    uploads = []
    for name in sorted(os.listdir(source_dir)):
        _, suffix = os.path.splitext(name)
        if suffix.lower() not in IMAGE_EXTENSIONS:
            continue
        path = os.path.join(source_dir, name)
        if os.path.isfile(path):
            with open(path, "rb") as image_file:
                uploads.append((f"{TARGET_DIR}/{name}", image_file.read()))
    return uploads
def publish(source_dir):
    """Upload the images found in ``source_dir`` as one commit on the data repo.

    Requires GITHUB_TOKEN. Rate limiting ends the upload quietly; a missing
    token or missing permission exits with status 1. The commit step is
    attempted up to five times with exponential backoff for errors classified
    as retryable; anything else is re-raised.
    """
    from urllib.error import HTTPError

    token = os.getenv("GITHUB_TOKEN")
    if not token:
        print("Error: GITHUB_TOKEN environment variable not set")
        sys.exit(1)

    uploads = collect_images(source_dir)
    if not uploads:
        print(f"No image files found in {source_dir}")
        return

    print(
        f"Found {len(uploads)} image(s) to upload to {REPO_OWNER}/{REPO_NAME}/{TARGET_DIR}"
    )

    # Check write access up front
    permission = verify_token_permissions(REPO_OWNER, REPO_NAME, token)
    if permission == "rate_limited":
        print("GitHub API rate-limited, skipping upload.")
        return
    if not permission:
        print("Token permission verification failed.")
        sys.exit(1)

    try:
        tree_items = create_blobs(REPO_OWNER, REPO_NAME, uploads, token)
    except Exception as blob_error:
        if is_rate_limit_error(blob_error):
            print("Rate-limited during blob creation, skipping.")
            return
        if is_permission_error(blob_error):
            print(
                f"ERROR: Token lacks write permission to {REPO_OWNER}/{REPO_NAME}. "
                "Update GH_PAT_FOR_NIGHTLY_CI_DATA with a token that has contents:write."
            )
            sys.exit(1)
        raise

    # Commit with retry: a concurrent push may move the branch under us.
    max_retries = 5
    retryable_markers = ("Update is not a fast forward", "Object does not exist")
    retryable_statuses = [422, 500, 502, 503, 504]

    for attempt in range(max_retries):
        try:
            head_sha = get_branch_sha(REPO_OWNER, REPO_NAME, BRANCH, token)
            base_tree_sha = get_tree_sha(REPO_OWNER, REPO_NAME, head_sha, token)
            new_tree_sha = create_tree(
                REPO_OWNER, REPO_NAME, base_tree_sha, tree_items, token
            )
            message = f"diffusion-ci: update consistency_gt images ({len(uploads)} files) [automated]"
            commit_sha = create_commit(
                REPO_OWNER, REPO_NAME, new_tree_sha, head_sha, message, token
            )
            update_branch_ref(REPO_OWNER, REPO_NAME, BRANCH, commit_sha, token)
            print(
                f"Successfully pushed {len(uploads)} images (commit {commit_sha[:10]})"
            )
            return
        except Exception as e:
            if is_rate_limit_error(e):
                print("Rate-limited, skipping.")
                return
            if is_permission_error(e):
                print(f"ERROR: permission denied to {REPO_OWNER}/{REPO_NAME}")
                sys.exit(1)

            should_retry = hasattr(e, "error_body") and any(
                marker in e.error_body for marker in retryable_markers
            )
            if isinstance(e, HTTPError) and e.code in retryable_statuses:
                should_retry = True

            if not (should_retry and attempt < max_retries - 1):
                print(f"Failed after {attempt + 1} attempts: {e}")
                raise

            import time

            wait = 2**attempt  # 1s, 2s, 4s, 8s
            print(
                f"Attempt {attempt + 1}/{max_retries} failed, retrying in {wait}s..."
            )
            time.sleep(wait)
def main():
    """CLI entry point: parse --source-dir and publish its images."""
    cli = argparse.ArgumentParser(description="Publish diffusion GT images to GitHub")
    cli.add_argument(
        "--source-dir", required=True, help="Directory containing GT images"
    )
    opts = cli.parse_args()
    publish(opts.source_dir)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,951 @@
"""Cross-framework comparison benchmark for diffusion serving.
Launches servers (SGLang, vLLM-Omni, LightX2V) for each test case, sends a
single request, measures end-to-end latency, and writes comparison-results.json.
Usage:
# Full run (requires GPU)
python3 scripts/ci/utils/diffusion/run_comparison.py
# Dry-run (config parsing + command preview only)
python3 scripts/ci/utils/diffusion/run_comparison.py --dry-run
# Run only specific case(s)
python3 scripts/ci/utils/diffusion/run_comparison.py --case-ids flux1_dev_t2i_1024
# Run only specific framework(s)
python3 scripts/ci/utils/diffusion/run_comparison.py --frameworks sglang
"""
import argparse
import base64
import io
import json
import os
import signal
import subprocess
import sys
import tempfile
import threading
import time
from datetime import datetime, timezone
from pathlib import Path
import requests
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Benchmark case definitions, stored next to this script.
CONFIGS_PATH = Path(__file__).parent / "comparison_configs.json"
# Installer script located one directory above this script's directory.
INSTALL_SCRIPT = Path(__file__).parents[1] / "install_comparison_frameworks.sh"

# Address the benchmark servers are told to bind to.
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 30000
HEALTH_TIMEOUT = (
    2400  # seconds (40 min — FLUX.2-dev needs ~10 min download + torch.compile)
)
REQUEST_TIMEOUT = 1200  # seconds
GPU_CLEAR_WAIT = 15  # seconds between framework runs

# Frameworks that need separate installation (conflict with sglang's deps)
INSTALLABLE_FRAMEWORKS = {"vllm-omni", "lightx2v"}

# Cached reference image (downloaded once)
_cached_ref_image: bytes | None = None
_cached_ref_image_path: str | None = None
# ---------------------------------------------------------------------------
# Server lifecycle — command builders
# ---------------------------------------------------------------------------
def _build_sglang_cmd(case: dict, fw_cfg: dict, port: int) -> list[str]:
cmd = [
"sglang",
"serve",
"--model-path",
case["model"],
"--port",
str(port),
"--host",
DEFAULT_HOST,
]
if case["num_gpus"] > 1:
cmd += ["--num-gpus", str(case["num_gpus"])]
if fw_cfg.get("serve_args", "").strip():
cmd += fw_cfg["serve_args"].strip().split()
return cmd
def _build_vllm_cmd(case: dict, fw_cfg: dict, port: int) -> list[str]:
cmd = [
"vllm",
"serve",
case["model"],
"--omni",
"--port",
str(port),
"--host",
DEFAULT_HOST,
]
if fw_cfg.get("serve_args", "").strip():
cmd += fw_cfg["serve_args"].strip().split()
return cmd
def _resolve_hf_model_path(model_id: str) -> str:
"""Resolve a HuggingFace model ID to a local cache path, or return as-is."""
if os.path.isdir(model_id):
return model_id
try:
from huggingface_hub import snapshot_download
path = snapshot_download(model_id)
print(f" Resolved {model_id} -> {path}")
return path
except Exception:
return model_id
def _write_lightx2v_config(case: dict) -> str:
"""Write a minimal LightX2V config JSON and return its path."""
cfg = {
"infer_steps": case.get("num_inference_steps", 50),
"guidance_scale": case.get("guidance_scale", 4.0),
"seed": case.get("seed", 42),
}
if "num_frames" in case:
cfg["target_video_length"] = case["num_frames"]
if "height" in case:
cfg["height"] = case["height"]
if "width" in case:
cfg["width"] = case["width"]
config_path = os.path.join(
tempfile.gettempdir(), f"lightx2v_config_{case['id']}.json"
)
with open(config_path, "w") as f:
json.dump(cfg, f)
return config_path
def _build_lightx2v_cmd(case: dict, fw_cfg: dict, port: int) -> list[str]:
"""Build LightX2V server launch command.
Single GPU: python -m lightx2v.server --model_path ... --model_cls ... --task ... --port ...
Multi GPU: torchrun --nproc_per_node=N -m lightx2v.server ...
LightX2V requires a local model path and a config JSON with infer params.
"""
model_cls = fw_cfg["model_cls"]
task = fw_cfg["lightx2v_task"]
num_gpus = case["num_gpus"]
model_path = _resolve_hf_model_path(case["model"])
config_path = _write_lightx2v_config(case)
server_args = [
"--model_path",
model_path,
"--model_cls",
model_cls,
"--task",
task,
"--config_json",
config_path,
"--host",
DEFAULT_HOST,
"--port",
str(port),
]
if fw_cfg.get("serve_args", "").strip():
server_args += fw_cfg["serve_args"].strip().split()
if num_gpus > 1:
cmd = [
"torchrun",
f"--nproc_per_node={num_gpus}",
"-m",
"lightx2v.server",
] + server_args
else:
cmd = ["python3", "-m", "lightx2v.server"] + server_args
return cmd
def build_server_cmd(framework: str, case: dict, fw_cfg: dict, port: int) -> list[str]:
    """Return the server launch argv for ``framework``.

    Raises:
        ValueError: if ``framework`` is not sglang, vllm-omni or lightx2v.
    """
    if framework == "sglang":
        make_cmd = _build_sglang_cmd
    elif framework == "vllm-omni":
        make_cmd = _build_vllm_cmd
    elif framework == "lightx2v":
        make_cmd = _build_lightx2v_cmd
    else:
        raise ValueError(f"Unknown framework: {framework}")
    return make_cmd(case, fw_cfg, port)
# ---------------------------------------------------------------------------
# Server lifecycle — health check & cleanup
# ---------------------------------------------------------------------------
# Health-check path per framework, appended to the server base URL by
# wait_for_health(); frameworks not listed here fall back to "/health" there.
HEALTH_ENDPOINTS = {
"sglang": "/health",
"vllm-omni": "/health",
"lightx2v": "/v1/service/status",
}
def wait_for_health(
    base_url: str, framework: str = "sglang", timeout: int = HEALTH_TIMEOUT
) -> None:
    """Block until the server answers 200 on its health endpoint.

    For SGLang, ``/v1/models`` is polled as well, because ``/health`` can
    return 200 before the model routes are registered. Both phases share one
    ``timeout`` budget measured from the start of the call.

    Args:
        base_url: Server base URL, e.g. ``http://host:port``.
        framework: Key into ``HEALTH_ENDPOINTS``; unknown names use ``/health``.
        timeout: Overall budget in seconds.

    Raises:
        TimeoutError: if either phase does not succeed within ``timeout``.
    """
    endpoint = HEALTH_ENDPOINTS.get(framework, "/health")
    health_url = f"{base_url}{endpoint}"
    print(f" Waiting for server at {health_url} ...")
    # Monotonic clock: elapsed-time checks must not be affected by wall-clock
    # adjustments (NTP steps, manual changes) while waiting.
    start = time.monotonic()

    def _poll_until_200(url: str, request_timeout: int, timeout_message: str) -> None:
        """Poll ``url`` every 2s until it returns 200 or the budget runs out."""
        while True:
            try:
                if requests.get(url, timeout=request_timeout).status_code == 200:
                    return
            except requests.exceptions.RequestException:
                # Server not accepting connections yet; keep polling.
                pass
            if time.monotonic() - start > timeout:
                raise TimeoutError(timeout_message)
            time.sleep(2)

    _poll_until_200(
        health_url, 2, f"Server at {health_url} did not start within {timeout}s"
    )

    # For SGLang, /health can return 200 before model routes are registered.
    # Poll /v1/models to confirm the model is fully loaded.
    if framework == "sglang":
        models_url = f"{base_url}/v1/models"
        _poll_until_200(
            models_url, 5, f"Model at {models_url} not ready within {timeout}s"
        )

    elapsed = time.monotonic() - start
    print(f" Server ready in {elapsed:.1f}s")
# Cleanup script located three directory levels above this file's directory;
# kill_server() runs it only when the path exists.
KILLALL_SCRIPT = Path(__file__).parents[3] / "killall_sglang.sh"
def kill_server(proc: subprocess.Popen) -> None:
    """Kill the server's process group and run the GPU cleanup script.

    Sends SIGTERM to the process group, waits up to 30s, escalates to SIGKILL
    and waits up to 10s more, then runs ``KILLALL_SCRIPT`` if it exists.
    Returns immediately when the process has already exited.

    This is cleanup code called from ``finally`` blocks, so it never raises
    on timeouts: a process that survives SIGKILL or a cleanup script that
    hangs is reported with a warning instead.
    """
    if proc.poll() is not None:
        return

    def _signal_group(sig: int) -> None:
        """Signal the whole process group; ignore a group that is already gone."""
        try:
            os.killpg(os.getpgid(proc.pid), sig)
        except (ProcessLookupError, PermissionError):
            pass

    _signal_group(signal.SIGTERM)
    try:
        proc.wait(timeout=30)
    except subprocess.TimeoutExpired:
        _signal_group(signal.SIGKILL)
        try:
            proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            print(f" WARNING: server pid {proc.pid} still running after SIGKILL")

    # Use killall_sglang.sh for thorough cleanup (esp. multi-GPU workers)
    if KILLALL_SCRIPT.exists():
        try:
            subprocess.run(
                ["bash", str(KILLALL_SCRIPT)],
                timeout=30,
                capture_output=True,
            )
        except (subprocess.TimeoutExpired, OSError) as exc:
            print(f" WARNING: {KILLALL_SCRIPT.name} cleanup failed: {exc}")
# ---------------------------------------------------------------------------
# Reference image helpers
# ---------------------------------------------------------------------------
def _get_ref_image_bytes(config: dict) -> bytes:
    """Return the shared reference image, downloading it on first use.

    The bytes are memoised in the module-level ``_cached_ref_image``.

    Raises:
        RuntimeError: if ``config`` has no (non-empty) ``test_image_url``.
    """
    global _cached_ref_image
    if _cached_ref_image is None:
        image_url = config.get("test_image_url", "")
        if not image_url:
            raise RuntimeError("No test_image_url in config for image-conditioned case")
        print(f" Downloading reference image from {image_url} ...")
        response = requests.get(image_url, timeout=60)
        response.raise_for_status()
        _cached_ref_image = response.content
    return _cached_ref_image
def _get_ref_image_b64(config: dict) -> str:
    """Return the reference image encoded as a base64 text string."""
    raw = _get_ref_image_bytes(config)
    encoded = base64.b64encode(raw)
    return encoded.decode("utf-8")
def _get_ref_image_path(config: dict) -> str:
    """Return a path to a temp ``.png`` file holding the reference image.

    The path is memoised in ``_cached_ref_image_path`` and reused for as long
    as the file still exists; otherwise a new temp file is written.
    """
    global _cached_ref_image_path
    known = _cached_ref_image_path
    if known and os.path.exists(known):
        return known

    image_bytes = _get_ref_image_bytes(config)
    handle, tmp_path = tempfile.mkstemp(suffix=".png")
    with os.fdopen(handle, "wb") as out:
        out.write(image_bytes)
    _cached_ref_image_path = tmp_path
    return tmp_path
# ---------------------------------------------------------------------------
# Request helpers — SGLang (OpenAI-compatible)
# ---------------------------------------------------------------------------
def _build_sglang_payload(case: dict) -> dict:
"""Build common SGLang request payload."""
payload = {
"model": case["model"],
"prompt": case["prompt"],
"size": f"{case['width']}x{case['height']}",
"n": 1,
"response_format": "b64_json",
}
for key in ("num_inference_steps", "guidance_scale", "seed", "num_frames"):
if key in case:
payload[key] = case[key]
return payload
def _read_perf_dump(perf_dump_path: str, timeout: float = 10.0) -> float | None:
"""Read total_duration_ms from a perf dump JSON written by the server.
The server writes the file asynchronously after the HTTP response,
so we poll briefly.
"""
deadline = time.time() + timeout
while time.time() < deadline:
try:
with open(perf_dump_path) as f:
data = json.load(f)
total_ms = data.get("total_duration_ms")
if total_ms is not None:
return total_ms / 1000.0
except (FileNotFoundError, json.JSONDecodeError):
pass
time.sleep(0.5)
return None
def send_image_request_sglang(
    base_url: str, case: dict, perf_dump_path: str | None = None
) -> float:
    """Send one text-to-image request to SGLang's /v1/images/generations.

    Returns:
        Latency in seconds: the server-side figure from the perf dump when
        ``perf_dump_path`` is given and the dump is readable, otherwise the
        client-measured round-trip time.

    Raises:
        RuntimeError: if the response carries no ``data`` entries.
    """
    request_body = _build_sglang_payload(case)
    if perf_dump_path:
        request_body["perf_dump_path"] = perf_dump_path

    sent_at = time.time()
    response = requests.post(
        f"{base_url}/v1/images/generations",
        json=request_body,
        timeout=REQUEST_TIMEOUT,
    )
    client_latency = time.time() - sent_at
    response.raise_for_status()

    data = response.json()
    if not ("data" in data and len(data["data"]) != 0):
        raise RuntimeError(f"Image request returned no data: {data}")

    server_latency = _read_perf_dump(perf_dump_path) if perf_dump_path else None
    if server_latency is None:
        print(f" Image generated in {client_latency:.2f}s")
        return client_latency

    print(
        f" Image generated in {server_latency:.2f}s (server-side), "
        f"client={client_latency:.2f}s"
    )
    return server_latency
def send_video_request_sglang(
    base_url: str, case: dict, perf_dump_path: str | None = None
) -> float:
    """Submit a text-to-video job to SGLang's /v1/videos and wait for it.

    The job status is polled once per second until it is ``completed``.

    Returns:
        Server-side latency from the perf dump when available, otherwise the
        client-measured time from submission to completion (seconds).

    Raises:
        RuntimeError: if no job id is returned or the job reports ``failed``.
        TimeoutError: if the job is not done within ``REQUEST_TIMEOUT``.
    """
    request_body = _build_sglang_payload(case)
    if perf_dump_path:
        request_body["perf_dump_path"] = perf_dump_path

    started = time.time()
    submit = requests.post(
        f"{base_url}/v1/videos", json=request_body, timeout=REQUEST_TIMEOUT
    )
    submit.raise_for_status()
    job = submit.json()
    job_id = job.get("id")
    if not job_id:
        raise RuntimeError(f"Video submit returned no job id: {job}")

    status_url = f"{base_url}/v1/videos/{job_id}"
    finished = False
    while not finished:
        time.sleep(1)
        poll = requests.get(status_url, timeout=30)
        poll.raise_for_status()
        poll_data = poll.json()
        state = poll_data.get("status")
        if state == "failed":
            raise RuntimeError(f"Video generation failed: {poll_data}")
        finished = state == "completed"
        if not finished and time.time() - started > REQUEST_TIMEOUT:
            raise TimeoutError(f"Video generation timed out after {REQUEST_TIMEOUT}s")
    client_latency = time.time() - started

    server_latency = _read_perf_dump(perf_dump_path) if perf_dump_path else None
    if server_latency is None:
        print(f" Video generated in {client_latency:.2f}s")
        return client_latency

    print(
        f" Video generated in {server_latency:.2f}s (server-side), "
        f"client={client_latency:.2f}s"
    )
    return server_latency
def send_image_conditioned_request_sglang(
    base_url: str, case: dict, config: dict, perf_dump_path: str | None = None
) -> float:
    """Send an image-conditioned request (edit / I2V / TI2V) as multipart form.

    Routing by ``case["task"]``:
      * image-to-video, text-image-to-video -> /v1/videos, file field
        ``input_reference``, followed by once-per-second job polling;
      * image-edit, image-to-image -> /v1/images/edits, file field ``image``;
      * anything else -> /v1/images/generations, file field ``image``.

    Returns:
        Server-side latency from the perf dump when available, otherwise the
        client-measured latency (seconds).

    Raises:
        RuntimeError: if a video job has no id or reports ``failed``.
        TimeoutError: if a video job exceeds ``REQUEST_TIMEOUT``.
    """
    task = case["task"]
    ref_bytes = _get_ref_image_bytes(config)
    is_video = task in ("image-to-video", "text-image-to-video")

    # The upload field name differs between the image and video endpoints.
    upload_field = "input_reference" if is_video else "image"
    files = {upload_field: ("ref.png", io.BytesIO(ref_bytes), "image/png")}

    # Multipart form values are sent as strings.
    form = {
        "model": case["model"],
        "prompt": case["prompt"],
        "size": f"{case['width']}x{case['height']}",
        "n": "1",
        "response_format": "b64_json",
    }
    for name in ("num_inference_steps", "guidance_scale", "seed", "num_frames"):
        if name in case:
            form[name] = str(case[name])
    if perf_dump_path:
        form["perf_dump_path"] = perf_dump_path

    if task in ("image-edit", "image-to-image"):
        endpoint = "/v1/images/edits"
    elif is_video:
        endpoint = "/v1/videos"
    else:
        endpoint = "/v1/images/generations"

    started = time.time()
    response = requests.post(
        f"{base_url}{endpoint}",
        files=files,
        data=form,
        timeout=REQUEST_TIMEOUT,
    )
    response.raise_for_status()

    if is_video:
        # Video endpoints are asynchronous: poll the job until it finishes.
        job = response.json()
        job_id = job.get("id")
        if not job_id:
            raise RuntimeError(f"Video submit returned no job id: {job}")
        status_url = f"{base_url}/v1/videos/{job_id}"
        finished = False
        while not finished:
            time.sleep(1)
            poll = requests.get(status_url, timeout=30)
            poll.raise_for_status()
            poll_data = poll.json()
            state = poll_data.get("status")
            if state == "failed":
                raise RuntimeError(f"Video generation failed: {poll_data}")
            finished = state == "completed"
            if not finished and time.time() - started > REQUEST_TIMEOUT:
                raise TimeoutError(f"Timed out after {REQUEST_TIMEOUT}s")
    client_latency = time.time() - started

    server_latency = _read_perf_dump(perf_dump_path) if perf_dump_path else None
    if server_latency is None:
        print(f" Generated in {client_latency:.2f}s (sglang, image-conditioned)")
        return client_latency

    print(
        f" Generated in {server_latency:.2f}s (server-side), "
        f"client={client_latency:.2f}s"
    )
    return server_latency
# ---------------------------------------------------------------------------
# Request helpers — vLLM-Omni
# ---------------------------------------------------------------------------
def send_request_vllm_omni(base_url: str, case: dict, config: dict) -> float:
    """Send one request to vLLM-Omni's /v1/chat/completions endpoint.

    Generation parameters travel in an ``extra_body`` object. When the case
    sets ``reference_image``, the user message contains the reference image
    as a base64 data URL followed by the text prompt; otherwise it is the
    plain prompt string.

    Returns:
        Client-measured latency of the POST, in seconds.

    Raises:
        RuntimeError: if the response has no ``choices``.
    """
    generation_params = {
        "height": case["height"],
        "width": case["width"],
        "num_inference_steps": case.get("num_inference_steps", 50),
        "guidance_scale": case.get("guidance_scale", 4.0),
        "seed": case.get("seed", 42),
    }
    if "num_frames" in case:
        generation_params["num_frames"] = case["num_frames"]

    prompt = case["prompt"]
    content: list[dict] | str
    if case.get("reference_image"):
        image_b64 = _get_ref_image_b64(config)
        content = [
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{image_b64}"},
            },
            {"type": "text", "text": prompt},
        ]
    else:
        content = prompt

    request_body = {
        "model": case["model"],
        "messages": [{"role": "user", "content": content}],
        "extra_body": generation_params,
    }

    sent_at = time.time()
    response = requests.post(
        f"{base_url}/v1/chat/completions",
        json=request_body,
        timeout=REQUEST_TIMEOUT,
    )
    latency = time.time() - sent_at
    response.raise_for_status()

    data = response.json()
    if not data.get("choices", []):
        raise RuntimeError(f"vLLM-Omni request returned no choices: {data}")
    print(f" Generated in {latency:.2f}s (vllm-omni)")
    return latency
# ---------------------------------------------------------------------------
# Request helpers — LightX2V
# ---------------------------------------------------------------------------
def send_request_lightx2v(base_url: str, case: dict, config: dict) -> float:
    """Submit a task to LightX2V's async task API and wait for completion.

    text-to-image and image-edit cases go to /v1/tasks/image; everything else
    goes to /v1/tasks/video. The task status is polled once per second.

    Returns:
        Client-measured time from submission to completion, in seconds.

    Raises:
        RuntimeError: if no task id is returned, or the task ends FAILED or
            CANCELLED.
        TimeoutError: if the task is not done within ``REQUEST_TIMEOUT``.
    """
    task = case["task"]
    endpoint = (
        "/v1/tasks/image" if task in ("text-to-image", "image-edit") else "/v1/tasks/video"
    )

    request_body = {
        "prompt": case["prompt"],
        "seed": case.get("seed", 42),
        "infer_steps": case.get("num_inference_steps", 50),
    }
    # (key in the case, key LightX2V expects); frames map to target_video_length.
    optional_fields = (
        ("num_frames", "target_video_length"),
        ("height", "height"),
        ("width", "width"),
        ("guidance_scale", "guidance_scale"),
    )
    for case_key, api_key in optional_fields:
        if case_key in case:
            request_body[api_key] = case[case_key]
    # Image-conditioned: LightX2V accepts image_path (URL or local path)
    if case.get("reference_image"):
        request_body["image_path"] = config.get("test_image_url", "")

    started = time.time()
    submit = requests.post(
        f"{base_url}{endpoint}", json=request_body, timeout=REQUEST_TIMEOUT
    )
    submit.raise_for_status()
    task_data = submit.json()
    task_id = task_data.get("task_id")
    if not task_id:
        raise RuntimeError(f"LightX2V submit returned no task_id: {task_data}")

    status_url = f"{base_url}/v1/tasks/{task_id}/status"
    finished = False
    while not finished:
        time.sleep(1)
        poll = requests.get(status_url, timeout=30)
        poll.raise_for_status()
        poll_data = poll.json()
        status = poll_data.get("task_status", "").upper()
        if status in ("FAILED", "CANCELLED"):
            raise RuntimeError(f"LightX2V task {status}: {poll_data}")
        finished = status == "COMPLETED"
        if not finished and time.time() - started > REQUEST_TIMEOUT:
            raise TimeoutError(f"LightX2V task timed out after {REQUEST_TIMEOUT}s")

    latency = time.time() - started
    print(f" Generated in {latency:.2f}s (lightx2v)")
    return latency
# ---------------------------------------------------------------------------
# Unified request dispatcher
# ---------------------------------------------------------------------------
def send_request(
    base_url: str,
    case: dict,
    framework: str = "sglang",
    config: dict | None = None,
    perf_dump_path: str | None = None,
) -> float:
    """Dispatch one generation request to the right framework-specific sender.

    vllm-omni and lightx2v have a single sender each. Every other framework
    value is treated as SGLang: image-conditioned cases use the multipart
    sender, otherwise the sender is picked by ``case["task"]``.
    ``perf_dump_path`` is only forwarded to the SGLang senders.

    Returns:
        Request latency in seconds, as reported by the chosen sender.

    Raises:
        ValueError: for an SGLang case whose task is neither text-to-image
            nor text-to-video and which has no reference image.
    """
    config = config or {}
    if framework == "vllm-omni":
        return send_request_vllm_omni(base_url, case, config)
    if framework == "lightx2v":
        return send_request_lightx2v(base_url, case, config)

    # SGLang — use OpenAI-compatible endpoints with optional perf log
    task = case["task"]
    if case.get("reference_image"):
        return send_image_conditioned_request_sglang(
            base_url, case, config, perf_dump_path
        )
    if task == "text-to-image":
        return send_image_request_sglang(base_url, case, perf_dump_path)
    if task == "text-to-video":
        return send_video_request_sglang(base_url, case, perf_dump_path)
    raise ValueError(f"Unknown task type: {task}")
# ---------------------------------------------------------------------------
# Main orchestrator
# ---------------------------------------------------------------------------
def run_single(
    case: dict,
    framework: str,
    fw_cfg: dict,
    port: int,
    log_dir: Path,
    config: dict | None = None,
) -> dict:
    """Run a single (case, framework) combination and return its result dict.

    Launches the server in its own session, tees its output to stdout and a
    log file under ``log_dir``, waits for health, sends two short warmup
    requests and one measured request, then tears the server down.

    Returns:
        Dict with case_id, framework, model, task, ``latency_s`` (rounded to
        ms, or None) and ``error`` (message string, or None). Failures after
        the server command is built are captured in ``error`` rather than
        raised.
    """
    result = {
        "case_id": case["id"],
        "framework": framework,
        "model": case["model"],
        "task": case["task"],
        "latency_s": None,
        "error": None,
    }
    cmd = build_server_cmd(framework, case, fw_cfg, port)
    print(f"\n Command: {' '.join(cmd)}")
    env = os.environ.copy()
    env.update(fw_cfg.get("extra_env", {}))

    # perf_dump_path for SGLang server-side timing (passed in request, zero overhead when None)
    perf_dump_path = None
    if framework == "sglang":
        perf_dump_path = os.path.join(str(log_dir), f"perf_{case['id']}_measured.json")

    log_file = log_dir / f"{case['id']}_{framework}.log"
    log_fh = open(log_file, "w", encoding="utf-8", buffering=1)
    log_thread = None
    proc = None
    try:
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            env=env,
            # Own session/process group so kill_server() can signal the whole
            # tree. start_new_session performs the same setsid() call as
            # preexec_fn=os.setsid without preexec_fn's thread-safety caveat.
            start_new_session=True,
            text=True,
            bufsize=1,
        )

        # Tee server output to both log file and stdout (like test_server_utils)
        def _log_pipe(pipe, fh):
            try:
                for line in iter(pipe.readline, ""):
                    sys.stdout.write(f" [server] {line}")
                    sys.stdout.flush()
                    fh.write(line)
            except ValueError:
                pass  # pipe closed

        log_thread = threading.Thread(target=_log_pipe, args=(proc.stdout, log_fh))
        log_thread.daemon = True
        log_thread.start()

        base_url = f"http://{DEFAULT_HOST}:{port}"
        wait_for_health(base_url, framework)

        # Warmup requests (not measured, no perf dump)
        # Use few steps to be fast — server's own warmup (warmup_steps=3) handles
        # torch.compile compilation; these external warmups just stabilize triton
        # kernel specializations across requests.
        WARMUP_STEPS = 3
        warmup_case = {**case, "num_inference_steps": WARMUP_STEPS}
        for wi in range(1, 3):
            print(f" Sending warmup request ({wi}/2, {WARMUP_STEPS} steps)...")
            try:
                send_request(base_url, warmup_case, framework, config)
            except Exception as e:
                print(f" Warmup request {wi} failed (non-fatal): {e}")

        # Measured request — pass perf_dump_path for SGLang server-side timing.
        # Remove any stale dump first so an old file cannot be read as this run's.
        if perf_dump_path and os.path.exists(perf_dump_path):
            os.remove(perf_dump_path)
        print(" Sending measured request...")
        latency = send_request(
            base_url, case, framework, config, perf_dump_path=perf_dump_path
        )
        result["latency_s"] = round(latency, 3)
    except Exception as e:
        result["error"] = str(e)
        print(f" ERROR: {e}")
    finally:
        try:
            if proc:
                kill_server(proc)
            if log_thread:
                log_thread.join(timeout=5)
        finally:
            # Close the log even if teardown itself raises.
            log_fh.close()
    return result
def _install_framework(fw_name: str, dry_run: bool = False) -> bool:
    """Install a comparison framework via the install script.

    Args:
        fw_name: Framework name; names not in ``INSTALLABLE_FRAMEWORKS`` need
            no installation and report success.
        dry_run: When true, only print the command that would be run.

    Returns:
        True on success (or nothing to do), False when the install script is
        missing, exits non-zero, cannot be started, or exceeds the 600s limit.
    """
    if fw_name not in INSTALLABLE_FRAMEWORKS:
        return True
    if not INSTALL_SCRIPT.exists():
        print(f" WARNING: Install script not found at {INSTALL_SCRIPT}")
        return False
    if dry_run:
        print(f" [DRY-RUN] Would install: bash {INSTALL_SCRIPT} {fw_name}")
        return True

    print(f"\n{'='*60}")
    print(f"Installing framework: {fw_name}")
    print(f"{'='*60}")
    try:
        ret = subprocess.run(
            ["bash", str(INSTALL_SCRIPT), fw_name],
            timeout=600,
        )
    except subprocess.TimeoutExpired:
        # A hung installer must not abort the whole comparison run.
        print(f" WARNING: {fw_name} installation timed out after 600s")
        return False
    except OSError as exc:
        print(f" WARNING: {fw_name} installation could not be started: {exc}")
        return False
    if ret.returncode != 0:
        print(f" WARNING: {fw_name} installation failed (exit {ret.returncode})")
        return False
    return True
def run_comparison(
    config: dict,
    case_ids: list[str] | None = None,
    frameworks: list[str] | None = None,
    port: int = DEFAULT_PORT,
    output: str = "comparison-results.json",
    dry_run: bool = False,
) -> dict:
    """Run all comparison cases, grouped by framework to minimize installs.

    Order: sglang first (already installed), then vllm-omni, then lightx2v.
    Each installable framework is installed right before its cases run; if
    the install fails, its cases are recorded with an error and skipped.

    Args:
        config: Parsed comparison config with a ``cases`` list.
        case_ids: Optional whitelist of case IDs.
        frameworks: Optional whitelist of framework names.
        port: Port the servers listen on.
        output: Path of the results JSON to write.
        dry_run: Print the commands instead of launching servers.

    Returns:
        The dict written to ``output`` (timestamp, commit_sha, run_id, results).
    """
    timestamp = datetime.now(timezone.utc).isoformat()
    commit_sha = os.environ.get("GITHUB_SHA", "unknown")
    run_id = os.environ.get("GITHUB_RUN_ID", "local")
    log_dir = Path("comparison-logs")
    log_dir.mkdir(exist_ok=True)

    def _placeholder(case: dict, fw_name: str, error: str) -> dict:
        """Result entry for a case that was not actually measured."""
        return {
            "case_id": case["id"],
            "framework": fw_name,
            "model": case["model"],
            "task": case["task"],
            "latency_s": None,
            "error": error,
        }

    # Collect all (case, framework) pairs, grouped by framework
    fw_order = ["sglang", "vllm-omni", "lightx2v"]
    fw_cases: dict[str, list[tuple[dict, dict]]] = {fw: [] for fw in fw_order}
    for case in config["cases"]:
        if case_ids and case["id"] not in case_ids:
            continue
        for fw_name, fw_cfg in case["frameworks"].items():
            if frameworks and fw_name not in frameworks:
                continue
            fw_cases.setdefault(fw_name, []).append((case, fw_cfg))

    # Only frameworks in fw_order are executed below; make it visible when
    # the config names others instead of dropping their cases silently.
    ignored = [fw for fw in fw_cases if fw not in fw_order]
    if ignored:
        print(
            f"WARNING: skipping cases for unsupported framework(s): {', '.join(ignored)}"
        )

    results = []
    installed_fws: set[str] = set()
    for fw_name in fw_order:
        pairs = fw_cases.get(fw_name, [])
        if not pairs:
            continue

        # Install framework if needed (once per framework)
        if fw_name not in installed_fws and fw_name in INSTALLABLE_FRAMEWORKS:
            if not _install_framework(fw_name, dry_run):
                # Skip all cases for this framework
                for case, _ in pairs:
                    results.append(
                        _placeholder(case, fw_name, f"{fw_name} installation failed")
                    )
                continue
            installed_fws.add(fw_name)

        for case, fw_cfg in pairs:
            print(f"\n{'='*60}")
            print(f"Case: {case['id']} | Model: {case['model']} | Framework: {fw_name}")
            print(f"{'='*60}")
            if dry_run:
                cmd = build_server_cmd(fw_name, case, fw_cfg, port)
                print(f" [DRY-RUN] Would run: {' '.join(cmd)}")
                results.append(_placeholder(case, fw_name, "dry-run"))
                continue
            result = run_single(case, fw_name, fw_cfg, port, log_dir, config)
            results.append(result)
            # Wait for GPU memory to clear before the next server is launched
            print(f" Waiting {GPU_CLEAR_WAIT}s for GPU memory to clear...")
            time.sleep(GPU_CLEAR_WAIT)

    output_data = {
        "timestamp": timestamp,
        "commit_sha": commit_sha,
        "run_id": run_id,
        "results": results,
    }
    os.makedirs(os.path.dirname(output) or ".", exist_ok=True)
    with open(output, "w") as f:
        json.dump(output_data, f, indent=2)
    print(f"\nResults written to {output}")

    # Print summary table
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    for r in results:
        # Test against None (a 0.0 latency is still a measurement) and fall
        # back to "N/A" when there is neither a latency nor an error message.
        if r["latency_s"] is not None:
            lat = f"{r['latency_s']:.2f}s"
        else:
            lat = r.get("error") or "N/A"
        print(f" {r['case_id']:30s} | {r['framework']:12s} | {lat}")
    return output_data
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: parse options, load the config JSON, run the comparison."""
    cli = argparse.ArgumentParser(
        description="Cross-framework diffusion serving comparison benchmark"
    )
    cli.add_argument(
        "--config", default=str(CONFIGS_PATH), help="Path to comparison_configs.json"
    )
    # Both filters accept one or more values and default to "no filter".
    for flag, description in (
        ("--case-ids", "Only run specific case IDs"),
        ("--frameworks", "Only run specific frameworks (sglang, vllm-omni, lightx2v)"),
    ):
        cli.add_argument(flag, nargs="+", default=None, help=description)
    cli.add_argument("--port", type=int, default=DEFAULT_PORT, help="Server port")
    cli.add_argument(
        "--output", default="comparison-results.json", help="Output JSON path"
    )
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="Parse config and print commands without launching servers",
    )
    opts = cli.parse_args()

    with open(opts.config) as config_file:
        config = json.load(config_file)
    print(f"Loaded {len(config['cases'])} comparison case(s) from {opts.config}")

    run_comparison(
        config=config,
        case_ids=opts.case_ids,
        frameworks=opts.frameworks,
        port=opts.port,
        output=opts.output,
        dry_run=opts.dry_run,
    )


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,163 @@
#!/usr/bin/env python3
"""Collect and save diffusion performance metrics for artifact collection in CI.
This script reads diffusion test results from the pytest stash and saves them
with metadata for the performance dashboard.
Usage:
python3 scripts/ci/utils/diffusion/save_diffusion_metrics.py \
--gpu-config 1-gpu-h100 \
--run-id 12345678 \
--output test/diffusion-metrics-1gpu.json \
--results-json test/diffusion-results.json
"""
import argparse
import json
import os
import sys
from datetime import datetime, timezone
def load_diffusion_results(results_file: str) -> list[dict]:
    """Load diffusion performance results from a JSON file.

    Returns:
        The parsed content as a list (a non-list top-level value is wrapped
        in a one-element list). A missing, unreadable or malformed file
        yields an empty list and a printed warning.
    """
    if not os.path.exists(results_file):
        print(f"Warning: Results file not found: {results_file}")
        return []
    try:
        with open(results_file, "r", encoding="utf-8") as handle:
            parsed = json.load(handle)
    except (json.JSONDecodeError, OSError) as err:
        print(f"Warning: Failed to parse {results_file}: {err}")
        return []
    if isinstance(parsed, list):
        return parsed
    return [parsed]
def transform_diffusion_result(result: dict, gpu_config: str) -> dict:
    """Project a raw diffusion result onto the fields the dashboard expects.

    Fields missing from ``result`` become ``None``, except ``modality``
    (defaults to "image") and ``stage_metrics`` / ``sampled_steps`` (default
    to empty dicts). Keys not listed here are dropped. ``gpu_config`` is
    accepted but not used.
    """
    field_order = (
        "test_name",
        "class_name",
        "modality",
        "e2e_ms",
        "avg_denoise_ms",
        "median_denoise_ms",
        "stage_metrics",
        "sampled_steps",
        # Video-specific metrics (if present)
        "frames_per_second",
        "total_frames",
        "avg_frame_time_ms",
    )
    # Built per call so the mutable defaults are never shared between results.
    fallbacks = {"modality": "image", "stage_metrics": {}, "sampled_steps": {}}
    return {name: result.get(name, fallbacks.get(name)) for name in field_order}
def group_results_by_class(results: list[dict], gpu_config: str) -> list[dict]:
    """Group diffusion results by test class (suite).

    Returns:
        One entry per distinct ``class_name`` ("unknown" when the key is
        absent), in first-seen order, each holding ``gpu_config``,
        ``test_suite`` and the list of transformed ``tests``.
    """
    by_suite: dict = {}
    for raw in results:
        suite = raw.get("class_name", "unknown")
        bucket = by_suite.setdefault(
            suite,
            {"gpu_config": gpu_config, "test_suite": suite, "tests": []},
        )
        bucket["tests"].append(transform_diffusion_result(raw, gpu_config))
    return list(by_suite.values())
def save_metrics(
    gpu_config: str,
    run_id: str,
    output_file: str,
    results_file: str,
) -> bool:
    """Collect diffusion metrics and write them, with run metadata, to a file.

    An output file is written even when no results were loaded.

    Returns:
        True when the file was written, False when writing raised ``OSError``.
    """
    timestamp = datetime.now(timezone.utc).isoformat()

    raw_results = load_diffusion_results(results_file)
    print(f"Loaded {len(raw_results)} diffusion test result(s)")

    metrics = {
        "run_id": run_id,
        "timestamp": timestamp,
        "gpu_config": gpu_config,
        "test_type": "diffusion",
        "results": group_results_by_class(raw_results, gpu_config),
    }

    try:
        parent_dir = os.path.dirname(output_file) or "."
        os.makedirs(parent_dir, exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as out:
            json.dump(metrics, out, indent=2)
        outcome = (
            "Saved diffusion metrics to" if raw_results else "Created empty metrics file"
        )
        print(f"{outcome}: {output_file}")
        return True
    except OSError as err:
        print(f"Error writing metrics file: {err}")
        return False
def main():
    """CLI entry point: parse the required options and exit 0/1 on success/failure."""
    cli = argparse.ArgumentParser(
        description="Collect diffusion performance metrics from test results"
    )
    # Every option is mandatory.
    for flag, description in (
        ("--gpu-config", "GPU configuration (e.g., 1-gpu-h100, 2-gpu-h100)"),
        ("--run-id", "GitHub Actions run ID"),
        ("--output", "Output file path for metrics JSON"),
        ("--results-json", "Path to diffusion results JSON file"),
    ):
        cli.add_argument(flag, required=True, help=description)
    opts = cli.parse_args()

    succeeded = save_metrics(
        gpu_config=opts.gpu_config,
        run_id=opts.run_id,
        output_file=opts.output,
        results_file=opts.results_json,
    )
    sys.exit(0 if succeeded else 1)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,53 @@
#!/bin/bash
# Ensure protoc is installed for router build (gRPC protobuf compilation).
#
# Exits early when a runnable protoc is already on PATH. Otherwise installs the
# download/build prerequisites with apt-get or yum, then unpacks the prebuilt
# protobuf release archive for this machine's architecture into /usr/local.
set -euxo pipefail

readonly PROTOC_VERSION="32.0"

if command -v protoc >/dev/null 2>&1 && protoc --version >/dev/null 2>&1; then
  echo "protoc already installed: $(protoc --version)"
  exit 0
fi

if command -v protoc >/dev/null 2>&1; then
  echo "protoc found but not runnable, reinstalling..."
else
  echo "protoc not found, installing..."
fi

# Map the kernel architecture onto the release asset name. Fail fast on
# anything else instead of installing an x86_64 binary that cannot run.
ARCH=$(uname -m)
case "${ARCH}" in
  aarch64 | arm64)
    PROTOC_ARCH="aarch_64"
    ;;
  x86_64 | amd64)
    PROTOC_ARCH="x86_64"
    ;;
  *)
    echo "ERROR: Unsupported architecture '${ARCH}' for prebuilt protoc" >&2
    exit 1
    ;;
esac

if command -v apt-get &> /dev/null; then
  # Ubuntu/Debian
  apt-get update || true # May fail due to unrelated broken packages
  PROTOC_APT_PACKAGES=(wget unzip gcc g++ perl make)
  apt-get install -y --no-install-recommends "${PROTOC_APT_PACKAGES[@]}" || {
    echo "Warning: apt-get install failed, checking if required packages are available..."
    for pkg in "${PROTOC_APT_PACKAGES[@]}"; do
      if ! dpkg -l "$pkg" 2>/dev/null | grep -q "^ii"; then
        echo "ERROR: Required package $pkg is not installed and apt-get failed"
        exit 1
      fi
    done
    echo "All required packages are already installed, continuing..."
  }
elif command -v yum &> /dev/null; then
  # RHEL/CentOS
  yum update -y
  yum install -y wget unzip gcc gcc-c++ perl-core make
else
  echo "ERROR: Neither apt-get nor yum found; cannot install protoc build deps"
  exit 1
fi

PROTOC_ZIP="protoc-${PROTOC_VERSION}-linux-${PROTOC_ARCH}.zip"

# Download into a private temp directory (no predictable /tmp path) and
# remove it on every exit path, including a failed download or unzip.
PROTOC_TMP_DIR=$(mktemp -d)
trap 'rm -rf -- "${PROTOC_TMP_DIR}"' EXIT

wget -O "${PROTOC_TMP_DIR}/${PROTOC_ZIP}" \
  "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/${PROTOC_ZIP}"
unzip -o "${PROTOC_TMP_DIR}/${PROTOC_ZIP}" -d /usr/local

protoc --version

View File

@@ -0,0 +1,141 @@
#!/usr/bin/env python3
"""Merge per-partition metrics into a consolidated metrics file.
This script reads all per-partition metric JSON files and consolidates them
into a single JSON file with run-level metadata.
Usage:
python3 scripts/ci/utils/merge_metrics.py \
--input-dir metrics/ \
--output consolidated-metrics-12345678.json \
--run-id 12345678 \
--commit-sha abc123def456
"""
import argparse
import glob
import json
import os
import sys
from datetime import datetime, timezone
def find_partition_files(input_dir: str) -> list[str]:
    """Recursively collect partition metric files under ``input_dir``.

    Matches ``metrics-*.json``, ``diffusion-metrics-*.json`` and
    ``comparison-metrics-*.json`` at any depth.

    Returns:
        De-duplicated paths in no particular order.
    """
    basenames = (
        "metrics-*.json",
        "diffusion-metrics-*.json",
        "comparison-metrics-*.json",
    )
    matches = {
        path
        for basename in basenames
        for path in glob.glob(os.path.join(input_dir, f"**/{basename}"), recursive=True)
    }
    return list(matches)
def load_partition_metrics(filepath: str) -> dict | None:
"""Load a partition metrics file."""
try:
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)
except (json.JSONDecodeError, OSError) as e:
print(f"Warning: Failed to load {filepath}: {e}")
return None
def merge_metrics(
input_dir: str,
output_file: str,
run_id: str,
commit_sha: str,
branch: str | None = None,
) -> bool:
"""Merge all partition metrics into a consolidated file.

Every partition file found under ``input_dir`` is loaded and the entries of
its ``results`` list are concatenated, visiting files in sorted path order.
Files that fail to load or have no ``results`` key are skipped. The output
is written even when no partition files exist.

Args:
    input_dir: Directory searched for partition metric files.
    output_file: Destination path; its parent directory is created if needed.
    run_id: Stored as ``run_id`` in the output.
    commit_sha: Stored as ``commit_sha`` in the output.
    branch: Stored as ``branch`` in the output (may be None).

Returns:
    True when the consolidated file was written, False on ``OSError``.
"""
# Timestamp of this merge (UTC), not of the individual partitions.
run_date = datetime.now(timezone.utc).isoformat()
# Find all partition files
partition_files = find_partition_files(input_dir)
print(f"Found {len(partition_files)} partition file(s)")
all_results = []
if not partition_files:
print("No partition metrics files found")
else:
# Load all partition files; sorted so the merged order is reproducible
for filepath in sorted(partition_files):
print(f" Reading: {filepath}")
metrics = load_partition_metrics(filepath)
# Skip files that failed to load or carry no "results" entry
if metrics and "results" in metrics:
all_results.extend(metrics["results"])
print(f"Total results collected: {len(all_results)}")
# Create consolidated structure
consolidated = {
"run_id": run_id,
"run_date": run_date,
"commit_sha": commit_sha,
"branch": branch,
"results": all_results,
}
# Ensure output directory exists and write output
try:
# dirname is "" for a bare filename; fall back to the current directory
os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
with open(output_file, "w", encoding="utf-8") as f:
json.dump(consolidated, f, indent=2)
if not partition_files:
print(f"Created empty consolidated file: {output_file}")
else:
print(f"Saved consolidated metrics to: {output_file}")
return True
except OSError as e:
print(f"Error writing consolidated file: {e}")
return False
def main():
    """CLI entry point: parse options, merge the metrics, exit 0/1 on success/failure."""
    cli = argparse.ArgumentParser(
        description="Merge per-partition metrics into consolidated file"
    )
    for flag, description in (
        ("--input-dir", "Directory containing partition metric files"),
        ("--output", "Output file path for consolidated metrics JSON"),
        ("--run-id", "GitHub Actions run ID"),
        ("--commit-sha", "Git commit SHA"),
    ):
        cli.add_argument(flag, required=True, help=description)
    cli.add_argument("--branch", default=None, help="Git branch name (optional)")
    opts = cli.parse_args()

    succeeded = merge_metrics(
        input_dir=opts.input_dir,
        output_file=opts.output,
        run_id=opts.run_id,
        commit_sha=opts.commit_sha,
        branch=opts.branch,
    )
    sys.exit(0 if succeeded else 1)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,407 @@
#!/usr/bin/env python3
"""
Pre-validate all cached HuggingFace models to provide detailed feedback.
This script runs once during CI initialization (in prepare_runner.sh) to:
1. Scan snapshots in ~/.cache/huggingface/hub/ (with time/quantity limits)
2. Validate completeness (config/tokenizer/weights)
3. Output detailed failure reasons for debugging
NOTE: This script no longer writes shared validation markers. Each test run
independently validates its cache using per-run markers to avoid cross-runner
cache state pollution.
"""
import glob
import json
import os
import sys
import time
from pathlib import Path
# Add python directory to path to import sglang modules
REPO_ROOT = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(REPO_ROOT / "python"))
from sglang.srt.model_loader.ci_weight_validation import ( # noqa: E402
_validate_diffusion_model,
validate_cache_with_detailed_reason,
)
# Limits to avoid spending too much time on validation
MAX_VALIDATION_TIME_SECONDS = 300 # Max 5 minutes total
def find_all_hf_snapshots():
    """
    Find all HuggingFace snapshots in cache.

    Scans ``$HF_HOME/hub`` (default ``~/.cache/huggingface/hub``) for
    ``models--<org>--<repo>/snapshots/<hash>`` directories. Cache directories
    with fewer than three ``--``-separated parts are skipped.

    Returns:
        List of (model_name, snapshot_dir) tuples, sorted by mtime (newest first)
    """
    hf_home = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
    hub_dir = os.path.join(hf_home, "hub")
    if not os.path.isdir(hub_dir):
        print(f"HF hub directory not found: {hub_dir}")
        return []

    dated = []
    for model_dir in glob.glob(os.path.join(hub_dir, "models--*")):
        pieces = os.path.basename(model_dir).split("--")
        if len(pieces) < 3 or pieces[0] != "models":
            # Not of the form models--org--repo; skip
            continue

        # models--org--repo        -> org/repo
        # models--org--repo--extra -> org/repo-extra (trailing parts joined with "-")
        model_name = pieces[1] + "/" + "-".join(pieces[2:])

        snapshots_dir = os.path.join(model_dir, "snapshots")
        if not os.path.isdir(snapshots_dir):
            continue

        for entry in os.listdir(snapshots_dir):
            candidate = os.path.join(snapshots_dir, entry)
            if not os.path.isdir(candidate):
                continue
            try:
                dated.append((model_name, candidate, os.path.getmtime(candidate)))
            except OSError:
                continue

    # Newest first, so recently used models are validated first
    dated.sort(key=lambda item: item[2], reverse=True)
    return [(name, path) for name, path, _ in dated]
def is_transformers_text_model(snapshot_dir):
    """
    Check if a snapshot is a transformers text model.

    Only excludes (returns False) for models with STRONG evidence of being
    diffusers/generation pipelines. Uses conservative heuristics to avoid
    false negatives on multimodal LLMs with tokenizers. Matching is plain
    case-insensitive substring search against deliberately specific keywords.

    Args:
        snapshot_dir: Path to snapshot directory

    Returns:
        True if this looks like a transformers text model, False otherwise (N/A).
        An unreadable or malformed config.json yields True so that the
        subsequent validation step reports the real problem.
    """
    # Diffusers pipeline markers (strong evidence): the pipeline config file
    # or a scheduler directory at the snapshot root.
    for marker in ("model_index.json", "scheduler"):
        if os.path.exists(os.path.join(snapshot_dir, marker)):
            return False

    config_path = os.path.join(snapshot_dir, "config.json")
    if not os.path.exists(config_path):
        # No config.json - likely not a transformers model
        return False

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (ValueError, OSError):
        # ValueError covers both invalid JSON and invalid UTF-8.
        # Can't parse config - assume transformers and let validation fail.
        return True
    if not isinstance(config, dict):
        # Valid JSON but not an object: same policy as an unparsable config.
        return True

    # Explicit diffusers/generation model types (conservative keywords).
    model_type = config.get("_class_name") or config.get("model_type")
    if model_type:
        model_type_lower = str(model_type).lower()
        type_keywords = (
            "diffusion",
            "unet",
            "vae",
            "controlnet",
            "stable-diffusion",
            "latent-diffusion",
        )
        if any(keyword in model_type_lower for keyword in type_keywords):
            return False

    # Architectures naming obvious diffusion/generation classes. The keywords
    # are specific on purpose (e.g. "ditmodel" rather than "dit", which would
    # also match "conditional").
    architectures = config.get("architectures") or []
    if isinstance(architectures, str):
        architectures = [architectures]
    elif not isinstance(architectures, (list, tuple)):
        architectures = []
    arch_str = " ".join(str(arch) for arch in architectures).lower()
    arch_keywords = (
        "diffusion",
        "unet2d",
        "unet3d",
        "vaedecoder",  # More specific than "vae"
        "vaeencoder",
        "controlnet",
        "autoencoder",
        "ditmodel",  # Diffusion Transformer
        "pixart",  # PixArt diffusion model
    )
    if any(keyword in arch_str for keyword in arch_keywords):
        return False

    # Standalone vision/generation models identified by name. `_name_or_path`
    # may be missing or null, so normalise it to a string first.
    model_name = str(config.get("_name_or_path") or "").lower()
    name_keywords = (
        "image-edit-",  # Pure image editing (e.g., Qwen-Image-Edit)
        "-image-editing",
        "dit-",  # DiT generation models
        "pixart-",  # PixArt generation models
    )
    if any(keyword in model_name for keyword in name_keywords):
        # A tokenizer suggests a multimodal LLM rather than a pure pipeline.
        has_tokenizer = any(
            os.path.exists(os.path.join(snapshot_dir, fname))
            for fname in ("tokenizer.json", "tokenizer.model", "tiktoken.model")
        )
        if not has_tokenizer:
            return False

    # Default: assume it's a transformers text/multimodal model. Even if it
    # lacks a tokenizer, let validation report the actual error (better a
    # false positive than a false negative for text models).
    return True
def scan_weight_files(snapshot_dir):
    """
    Scan for safetensors weight files in a snapshot.

    Shards listed in any ``*.safetensors.index.json`` at the snapshot root are
    preferred. When no index yields an existing shard, the snapshot is searched
    recursively for ``*.safetensors`` files instead. ``.bin`` checkpoints are
    not collected.

    Returns:
        Sorted list of unique weight file paths, or empty list if nothing is
        found or the recursive scan exceeds the file-count limit.
    """
    max_weight_files = 1000

    # A set removes duplicates when several index files list the same shard.
    from_index = set()
    index_files = sorted(
        glob.glob(os.path.join(snapshot_dir, "*.safetensors.index.json"))
    )
    for index_file in index_files:
        try:
            with open(index_file, "r", encoding="utf-8") as f:
                index_data = json.load(f)
            weight_map = index_data.get("weight_map", {})
            for weight_file in set(weight_map.values()):
                weight_path = os.path.join(snapshot_dir, weight_file)
                # Missing shards are left out here; completeness is judged by
                # the validation step, not by this scan.
                if os.path.exists(weight_path):
                    from_index.add(weight_path)
        except Exception as e:
            print(
                f" Warning: Failed to parse index {os.path.basename(index_file)}: {e}"
            )
    if from_index:
        return sorted(from_index)

    # No index found or no shards from index: fall back to a recursive glob.
    matched = glob.glob(
        os.path.join(snapshot_dir, "**/*.safetensors"), recursive=True
    )
    if len(matched) > max_weight_files:
        print(
            f" Warning: Too many safetensors files ({len(matched)} > {max_weight_files})"
        )
        return []
    # os.path.exists() is False for broken symlinks, which filters them out.
    return sorted(path for path in matched if os.path.exists(path))
def validate_snapshot(model_name, snapshot_dir, weight_files, validated_cache):
    """
    Run cache validation for one snapshot, at most once per process.

    Args:
        model_name: Model identifier
        snapshot_dir: Path to snapshot directory
        weight_files: List of weight files to validate
        validated_cache: Dict keyed by snapshot_dir holding the results of
            snapshots already validated in this run; updated on success

    Returns:
        Tuple of (result, reason):
        - (True, None) if validation passed
        - (False, reason_str) if validation failed or raised
        - (None, None) if skipped (already validated in this run)
    """
    already_done = snapshot_dir in validated_cache
    if already_done:
        return None, None

    try:
        verdict = validate_cache_with_detailed_reason(
            snapshot_dir=snapshot_dir,
            weight_files=weight_files,
            model_name_or_path=model_name,
        )
        is_complete, reason = verdict
        # Remember the outcome so a repeated snapshot is skipped later on.
        validated_cache[snapshot_dir] = (is_complete, reason)
        return is_complete, reason
    except Exception as exc:
        # Failures caused by an exception are reported but not cached.
        return False, f"Validation raised exception: {exc}"
def main():
    """
    Validate every cached HuggingFace snapshot and print a summary.

    Snapshots are processed newest first until the list is exhausted or
    MAX_VALIDATION_TIME_SECONDS has elapsed. Snapshots containing
    model_index.json are validated as diffusers pipelines; all others go
    through the transformers weight validation. Results are only printed;
    nothing is returned.
    """
    began = time.time()
    rule = "=" * 70

    print(rule)
    print("CI_OFFLINE: Pre-validating cached HuggingFace models")
    print(rule)
    print(f"Max time: {MAX_VALIDATION_TIME_SECONDS}s")
    print()
    print("Scanning HuggingFace cache for models...")

    snapshots = find_all_hf_snapshots()
    if not snapshots:
        print("No cached models found, skipping validation")
        print(rule)
        return

    total = len(snapshots)
    print(f"Found {total} snapshot(s) in cache")
    print()

    passed = 0
    failed = 0
    skipped = 0
    processed = 0
    # Remembers snapshots validated earlier in this run.
    seen_results = {}

    for model_name, snapshot_dir in snapshots:
        # Stop once the time budget is used up.
        elapsed = time.time() - began
        if elapsed > MAX_VALIDATION_TIME_SECONDS:
            print()
            print(
                f"Time limit reached ({elapsed:.1f}s > {MAX_VALIDATION_TIME_SECONDS}s)"
            )
            print(f"Stopping validation, {total - processed} snapshots remaining")
            break

        short_hash = os.path.basename(snapshot_dir)[:8]
        print(f"[{processed + 1}/{total}] {model_name} ({short_hash}...)")
        processed += 1

        # model_index.json marks a diffusers pipeline.
        if os.path.exists(os.path.join(snapshot_dir, "model_index.json")):
            try:
                is_valid, reason = _validate_diffusion_model(snapshot_dir)
                if is_valid:
                    print(" PASS (diffusion) - Cache complete & valid")
                    passed += 1
                else:
                    print(f" FAIL (diffusion) - {reason}")
                    failed += 1
            except Exception as exc:
                print(f" FAIL (diffusion) - Validation raised exception: {exc}")
                failed += 1
            continue

        # Everything else must look like a transformers model to be checked.
        if not is_transformers_text_model(snapshot_dir):
            print(
                " SKIP (unknown type) - Not a diffusers pipeline or transformers model"
            )
            skipped += 1
            continue

        weight_files = scan_weight_files(snapshot_dir)
        if not weight_files:
            print(" SKIP (no weights) - empty or incomplete download")
            skipped += 1
            continue

        try:
            result, reason = validate_snapshot(
                model_name, snapshot_dir, weight_files, seen_results
            )
            if result is True:
                print(" PASS - Cache complete & valid")
                passed += 1
            elif result is False:
                if reason:
                    print(f" FAIL (incomplete) - {reason}")
                else:
                    print(" FAIL (incomplete) - cache validation failed")
                failed += 1
            else:
                # None means the snapshot was already handled in this run.
                print(" SKIP (already validated in this run)")
                skipped += 1
        except Exception as exc:
            print(f" FAIL (error) - Validation raised exception: {exc}")
            failed += 1

    total_elapsed = time.time() - began
    print()
    print(rule)
    print(f"Validation summary (completed in {total_elapsed:.1f}s):")
    print(f" PASS (complete & valid): {passed}")
    print(f" FAIL (incomplete/corrupted): {failed}")
    print(f" SKIP (no weights/duplicate): {skipped}")
    print(f" Total processed: {processed}/{total}")
    print(rule)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,517 @@
"""
Publish performance traces to GitHub repository
"""
import argparse
import base64
import json
import os
import sys
import time
import warnings
from urllib.error import HTTPError
from urllib.request import Request, urlopen
def is_rate_limit_error(e):
"""Check if an exception is a GitHub rate limit error (not permission error)"""
if not isinstance(e, HTTPError):
return False
if e.code == 429:
return True
if e.code == 403:
# 403 can be rate limit OR permission error - check the message
error_body = getattr(e, "error_body", "")
if isinstance(error_body, str):
# Rate limit errors contain specific phrases
rate_limit_phrases = [
"rate limit",
"abuse detection",
"secondary rate limit",
]
return any(phrase in error_body.lower() for phrase in rate_limit_phrases)
return False
def is_permission_error(e):
"""Check if an exception is a GitHub permission error"""
if not isinstance(e, HTTPError) or e.code != 403:
return False
error_body = getattr(e, "error_body", "")
if isinstance(error_body, str):
permission_phrases = [
"resource not accessible",
"must have push access",
"permission",
"denied",
]
return any(phrase in error_body.lower() for phrase in permission_phrases)
return False
def make_github_request(url, token, method="GET", data=None, timeout=60):
    """
    Make an authenticated request to the GitHub API.

    Args:
        url: Full request URL
        token: Token sent as a Bearer credential
        method: HTTP method
        data: Optional JSON-serialisable payload; any value other than None
            (including an empty dict) is sent as the JSON request body
        timeout: Socket timeout in seconds passed to urlopen, so a stalled
            connection raises instead of hanging the job

    Returns:
        The response body decoded as UTF-8 text.

    Raises:
        HTTPError: on an HTTP error status. The decoded response body is
            attached as ``error_body`` ("" if it could not be read) so callers
            can classify the failure.
        Exception: any other failure (e.g. URLError, timeout) is logged and
            re-raised unchanged.
    """
    headers = {
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {token}",
        # "User-Agent": "sglang-ci",
        "X-GitHub-Api-Version": "2022-11-28",
    }
    if data is not None:
        headers["Content-Type"] = "application/json"
        data = json.dumps(data).encode("utf-8")

    req = Request(url, data=data, headers=headers, method=method)
    try:
        with urlopen(req, timeout=timeout) as response:
            return response.read().decode("utf-8")
    except HTTPError as e:
        print(f"GitHub API request failed: {e}")
        try:
            error_body = e.read().decode("utf-8")
            print(f"Error response body: {error_body}")
            e.error_body = error_body  # Attach for later inspection
        except Exception:
            e.error_body = ""
        raise
    except Exception as e:
        print(f"GitHub API request failed with a non-HTTP error: {e}")
        raise
def verify_token_permissions(repo_owner, repo_name, token):
    """
    Check that the token can read the repository and its contents.

    Returns:
        True when both probes succeed, the string "rate_limited" when a probe
        hits a GitHub rate limit, and False on any other failure.
    """
    print("Verifying token permissions...")

    repo_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}"
    # (url, message on success, whether to echo the repository's full name)
    probes = (
        (repo_url, "Repository access verified", True),
        (f"{repo_url}/contents", "Repository contents access verified", False),
    )

    for url, ok_message, echo_full_name in probes:
        try:
            body = make_github_request(url, token)
            if echo_full_name:
                repo_info = json.loads(body)
                print(f"{ok_message}: {repo_info['full_name']}")
            else:
                print(ok_message)
        except Exception as exc:
            if is_rate_limit_error(exc):
                warnings.warn(
                    "GitHub API rate limit exceeded during token verification."
                )
                return "rate_limited"
            print(f"Failed to verify permissions for {url}: {exc}")
            return False

    return True
def get_branch_sha(repo_owner, repo_name, branch, token):
    """Return the commit SHA that the head of `branch` currently points to."""
    ref_url = (
        "https://api.github.com/repos/"
        f"{repo_owner}/{repo_name}/git/refs/heads/{branch}"
    )
    ref_info = json.loads(make_github_request(ref_url, token))
    return ref_info["object"]["sha"]
def get_tree_sha(repo_owner, repo_name, commit_sha, token):
    """Return the SHA of the tree referenced by the commit `commit_sha`."""
    commit_url = (
        "https://api.github.com/repos/"
        f"{repo_owner}/{repo_name}/git/commits/{commit_sha}"
    )
    commit_info = json.loads(make_github_request(commit_url, token))
    return commit_info["tree"]["sha"]
def create_blob(repo_owner, repo_name, content, token, max_retries=3):
    """
    Upload `content` (bytes) as a git blob and return the blob SHA.

    Failed requests are retried up to `max_retries` attempts with exponential
    backoff (1s, 2s, ... between attempts). Rate limit errors are never
    retried; they and the final failure are re-raised to the caller.
    """
    url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/blobs"
    # The API takes binary content base64-encoded.
    payload = {
        "content": base64.b64encode(content).decode("utf-8"),
        "encoding": "base64",
    }

    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            reply = make_github_request(url, token, method="POST", data=payload)
            return json.loads(reply)["sha"]
        except Exception as exc:
            # Fail fast on rate limits and once the attempts are used up.
            if is_rate_limit_error(exc) or attempt >= final_attempt:
                raise
            wait_time = 2**attempt
            print(
                f"Blob creation failed (attempt {attempt + 1}/{max_retries}), retrying in {wait_time}s..."
            )
            time.sleep(wait_time)
def create_blobs(repo_owner, repo_name, files, token):
    """
    Upload every (path, content) pair as a blob.

    Returns:
        A list of tree entries (path, mode "100644", type "blob", sha), one per
        uploaded file and in the same order as `files`.
    """
    total = len(files)
    tree_items = []
    for done, (file_path, content) in enumerate(files, start=1):
        # The blob has to exist before a tree entry can reference its SHA.
        sha = create_blob(repo_owner, repo_name, content, token)
        entry = {
            "path": file_path,
            "mode": "100644",
            "type": "blob",
            "sha": sha,
        }
        tree_items.append(entry)
        # Report progress every 10 files and after the last one.
        if done % 10 == 0 or done == total:
            print(f"Created {done}/{total} blobs...")
    return tree_items
def create_tree(repo_owner, repo_name, base_tree_sha, tree_items, token, max_retries=3):
    """
    Create a tree on top of `base_tree_sha` from blob entries; return its SHA.

    Failed requests are retried up to `max_retries` attempts with exponential
    backoff. Rate limit errors and the final failure are re-raised.
    """
    url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/trees"
    payload = {"base_tree": base_tree_sha, "tree": tree_items}

    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            reply = make_github_request(url, token, method="POST", data=payload)
            return json.loads(reply)["sha"]
        except Exception as exc:
            # Fail fast on rate limits and once the attempts are used up.
            if is_rate_limit_error(exc) or attempt >= final_attempt:
                raise
            wait_time = 2**attempt
            print(
                f"Tree creation failed (attempt {attempt + 1}/{max_retries}), retrying in {wait_time}s..."
            )
            time.sleep(wait_time)
def create_commit(
    repo_owner, repo_name, tree_sha, parent_sha, message, token, max_retries=3
):
    """
    Create a commit for `tree_sha` with a single parent and return its SHA.

    After creation the commit is fetched again; a SHA mismatch counts as a
    failure. Failures are retried up to `max_retries` attempts with exponential
    backoff, except rate limit errors, which are re-raised immediately.
    """
    commits_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/commits"
    payload = {"tree": tree_sha, "parents": [parent_sha], "message": message}

    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            created = json.loads(
                make_github_request(commits_url, token, method="POST", data=payload)
            )
            commit_sha = created["sha"]

            # Read the commit back to confirm it really exists.
            verify_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/commits/{commit_sha}"
            verify_data = json.loads(make_github_request(verify_url, token))
            if verify_data["sha"] != commit_sha:
                raise Exception(
                    f"Commit verification failed: expected {commit_sha}, got {verify_data['sha']}"
                )
            return commit_sha
        except Exception as exc:
            # Fail fast on rate limits and once the attempts are used up.
            if is_rate_limit_error(exc) or attempt >= final_attempt:
                raise
            wait_time = 2**attempt
            print(
                f"Commit creation failed (attempt {attempt + 1}/{max_retries}), retrying in {wait_time}s..."
            )
            time.sleep(wait_time)
def update_branch_ref(repo_owner, repo_name, branch, commit_sha, token, max_retries=3):
    """
    Point `branch` at `commit_sha`.

    Retry policy (exponential backoff, at most `max_retries` attempts):
    - rate limit errors are re-raised immediately;
    - an HTTPError is retried only when its body reports "Object does not
      exist", which may be a transient consistency issue;
    - any other exception type is retried.
    """
    url = (
        "https://api.github.com/repos/"
        f"{repo_owner}/{repo_name}/git/refs/heads/{branch}"
    )
    payload = {"sha": commit_sha}
    final_attempt = max_retries - 1

    def reports_missing_object(err):
        # True when the attached JSON error body says the object is missing.
        if not hasattr(err, "error_body"):
            return False
        try:
            details = json.loads(err.error_body)
            return "Object does not exist" in details.get("message", "")
        except Exception:
            return False

    for attempt in range(max_retries):
        try:
            make_github_request(url, token, method="PATCH", data=payload)
            return
        except HTTPError as http_err:
            if is_rate_limit_error(http_err):
                raise
            if not (reports_missing_object(http_err) and attempt < final_attempt):
                raise
            # The new commit may not be visible yet; give the API some time.
            wait_time = 2**attempt
            print(
                f"Branch update failed with 'Object does not exist' (attempt {attempt + 1}/{max_retries}), waiting {wait_time}s for consistency..."
            )
            time.sleep(wait_time)
        except Exception as exc:
            if is_rate_limit_error(exc) or attempt >= final_attempt:
                raise
            wait_time = 2**attempt
            print(
                f"Branch update failed (attempt {attempt + 1}/{max_retries}), retrying in {wait_time}s..."
            )
            time.sleep(wait_time)
def copy_trace_files(source_dir, target_base_path):
    """Collect trace files below `source_dir` for upload.

    Every ``.json.gz`` file is read into memory and paired with its target
    path ``{target_base_path}/{path relative to source_dir}``. Files whose
    name contains "TP-" but not "TP-0" are left out, so that only tensor
    parallel rank 0 traces are uploaded.

    Returns:
        List of (target_path, content_bytes) tuples; empty when `source_dir`
        does not exist.
    """
    collected = []
    if not os.path.exists(source_dir):
        print(f"Warning: Traces directory {source_dir} does not exist")
        return collected

    for dirpath, _subdirs, filenames in os.walk(source_dir):
        for filename in filenames:
            if not filename.endswith(".json.gz"):
                continue
            # Skip traces of non-zero TP ranks to avoid duplicated data.
            if "TP-" in filename and "TP-0" not in filename:
                continue

            full_path = os.path.join(dirpath, filename)
            relative = os.path.relpath(full_path, source_dir)
            with open(full_path, "rb") as handle:
                payload = handle.read()
            collected.append((f"{target_base_path}/{relative}", payload))

    return collected
def publish_traces(traces_dir, run_id, run_number):
    """
    Publish the traces of a single directory in one commit.

    Files are placed below ``traces/{run_id}`` in the target repository.
    Nothing is published when the directory yields no trace files.
    """
    files_to_upload = copy_trace_files(traces_dir, f"traces/{run_id}")
    if files_to_upload:
        print(f"Found {len(files_to_upload)} files to upload")
        publish_traces_from_files(files_to_upload, run_id, run_number)
    else:
        print("No trace files found to upload")
def publish_traces_from_files(files_to_upload, run_id, run_number):
    """
    Publish pre-collected trace files to the CI data repository in one commit.

    Reads the token from the GITHUB_TOKEN environment variable. Blobs are
    uploaded once; the tree/commit/ref-update sequence is then attempted up to
    five times because the branch head may move in between.

    Outcomes:
    - rate limit at any stage: a warning is issued and the function returns;
    - missing token, failed permission check or permission error: exits with
      status 1;
    - any other unrecoverable error is re-raised.

    Args:
        files_to_upload: List of (target_path, content_bytes) tuples
        run_id: Workflow run id, used in the commit message
        run_number: Workflow run number, used in the commit message
    """
    token = os.getenv("GITHUB_TOKEN")
    if not token:
        print("Error: GITHUB_TOKEN environment variable not set")
        sys.exit(1)

    # Target repository.
    repo_owner, repo_name, branch = "sglang-bot", "sglang-ci-data", "main"

    def exit_missing_write_access():
        # Permission problems need a human to fix the secret: fail loudly.
        print(
            f"ERROR: Token does not have write permission to {repo_owner}/{repo_name}. "
            "Please update the GH_PAT_FOR_NIGHTLY_CI_DATA secret with a token that has "
            "'contents: write' permission for the repository."
        )
        sys.exit(1)

    def classify(exc):
        # Returns (is_retryable, error_type) for a failed publish attempt.
        retryable, label = False, "unknown"
        if hasattr(exc, "error_body"):
            if "Update is not a fast forward" in exc.error_body:
                retryable, label = True, "fast-forward conflict"
            elif "Object does not exist" in exc.error_body:
                retryable, label = True, "object consistency"
        # HTTP statuses that might be transient.
        if isinstance(exc, HTTPError) and exc.code in [422, 500, 502, 503, 504]:
            retryable, label = True, f"HTTP {exc.code}"
        return retryable, label

    # Make sure the token is usable before uploading anything.
    access = verify_token_permissions(repo_owner, repo_name, token)
    if access == "rate_limited":
        warnings.warn(
            "Skipping trace upload due to GitHub API rate limit. "
            "This is expected during high CI activity and does not indicate a test failure."
        )
        return
    if not access:
        print(
            "Token permission verification failed. Please check the token permissions."
        )
        sys.exit(1)

    max_retries = 5
    retry_delay = 5  # seconds

    # Upload blobs once, outside the retry loop, so retries do not re-upload.
    try:
        tree_items = create_blobs(repo_owner, repo_name, files_to_upload, token)
    except Exception as exc:
        if is_rate_limit_error(exc):
            warnings.warn(
                "GitHub API rate limit exceeded during blob creation. Skipping trace upload."
            )
            return
        if is_permission_error(exc):
            exit_missing_write_access()
        print(f"Failed to create blobs: {exc}")
        raise

    for attempt in range(max_retries):
        try:
            # Re-read the branch head on every attempt; it may have moved.
            branch_sha = get_branch_sha(repo_owner, repo_name, branch, token)
            print(f"Current branch head: {branch_sha}")

            tree_sha = get_tree_sha(repo_owner, repo_name, branch_sha, token)
            print(f"Current tree SHA: {tree_sha}")

            new_tree_sha = create_tree(
                repo_owner, repo_name, tree_sha, tree_items, token
            )
            print(f"Created new tree: {new_tree_sha}")

            commit_message = f"Nightly traces for run {run_id} at {run_number} ({len(files_to_upload)} files)"
            commit_sha = create_commit(
                repo_owner,
                repo_name,
                new_tree_sha,
                branch_sha,
                commit_message,
                token,
            )
            print(f"Created commit: {commit_sha}")

            update_branch_ref(repo_owner, repo_name, branch, commit_sha, token)
            print("Updated branch reference")
            print("Successfully published all traces in a single commit")
            return
        except Exception as exc:
            is_retryable, error_type = classify(exc)

            # Rate limits are non-fatal: warn and skip the upload.
            if is_rate_limit_error(exc):
                warnings.warn("GitHub API rate limit exceeded. Skipping trace upload.")
                return
            if is_permission_error(exc):
                exit_missing_write_access()

            if is_retryable and attempt < max_retries - 1:
                print(
                    f"Attempt {attempt + 1}/{max_retries} failed ({error_type}). Retrying in {retry_delay} seconds..."
                )
                time.sleep(retry_delay)
            else:
                print(f"Failed to publish traces after {attempt + 1} attempts: {exc}")
                raise
def main():
    """
    Command-line entry point.

    Collects trace files from every ``--traces-dir`` given on the command line
    and publishes them together in a single commit. The run id and run number
    come from GITHUB_RUN_ID and GITHUB_RUN_NUMBER, defaulting to "test" and
    "12345" when those variables are unset.
    """
    parser = argparse.ArgumentParser(
        description="Publish performance traces to GitHub repository"
    )
    parser.add_argument(
        "--traces-dir",
        type=str,
        action="append",
        dest="traces_dirs",
        required=True,
        help="Traces directory to publish (can be specified multiple times)",
    )
    args = parser.parse_args()

    run_id = os.getenv("GITHUB_RUN_ID", "test")
    run_number = os.getenv("GITHUB_RUN_NUMBER", "12345")
    # Only reachable when a variable is set but empty, given the defaults.
    if not (run_id and run_number):
        print(
            "Error: GITHUB_RUN_ID and GITHUB_RUN_NUMBER environment variables must be set"
        )
        sys.exit(1)

    # Gather the files of all directories under one target prefix.
    target_base_path = f"traces/{run_id}"
    all_files = []
    for traces_dir in args.traces_dirs:
        print(f"Processing traces from directory: {traces_dir}")
        all_files += copy_trace_files(traces_dir, target_base_path)

    if not all_files:
        print("No trace files found to upload across all directories")
        return

    print(f"Found {len(all_files)} total files to upload")
    publish_traces_from_files(all_files, run_id, run_number)
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,527 @@
#!/usr/bin/env python3
"""
Runner Utilization Report
Analyzes GitHub Actions job data to calculate runner utilization metrics.
Reports idle time, active time, and utilization percentage per runner label.
"""
import argparse
import json
import os
import subprocess
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta, timezone
# Labels to skip when grouping runners (GitHub default labels)
DEFAULT_LABELS_TO_IGNORE = {"self-hosted", "Linux", "X64", "ARM64"}
GITHUB_HOSTED_LABELS = {"ubuntu-latest", "ubuntu-22.04", "ubuntu-24.04"}
def run_gh_command(args: list[str]) -> dict:
    """
    Invoke ``gh api`` with `args` and return its parsed JSON output.

    Raises:
        Exception: when the command exits with a non-zero status; the message
            carries the command's stderr.
    """
    command = ["gh", "api", *args]
    completed = subprocess.run(command, capture_output=True, text=True)
    if completed.returncode != 0:
        raise Exception(f"gh api failed: {completed.stderr}")
    return json.loads(completed.stdout)
def get_workflow_runs(repo: str, hours: int = 24) -> list[dict]:
    """
    Return the workflow runs created within the last `hours` hours.

    Pages of 100 runs are fetched until a run older than the cutoff shows up,
    a page is not full, or 20 pages have been read (safety limit).
    """
    cutoff = datetime.now(timezone.utc) - timedelta(hours=hours)
    recent = []

    for page in range(1, 21):  # Safety limit: at most 20 pages
        payload = run_gh_command(
            [f"repos/{repo}/actions/runs?per_page=100&page={page}"]
        )
        page_runs = payload.get("workflow_runs", [])

        for run in page_runs:
            created_at = parse_time(run.get("created_at"))
            if not created_at:
                # Runs without a timestamp are neither kept nor a stop signal.
                continue
            if created_at < cutoff:
                # Runs are ordered by created_at desc, so we can stop
                return recent
            recent.append(run)

        if len(page_runs) < 100:
            break

    return recent
def get_jobs_for_run(repo: str, run_id: int) -> list[dict]:
    """
    Return the jobs of one workflow run.

    Reads pages of 100 jobs until a page is not full or 5 pages have been
    fetched (safety limit).
    """
    collected = []
    for page in range(1, 6):  # Safety limit: at most 5 pages
        payload = run_gh_command(
            [f"repos/{repo}/actions/runs/{run_id}/jobs?per_page=100&page={page}"]
        )
        page_jobs = payload.get("jobs", [])
        collected.extend(page_jobs)
        if len(page_jobs) < 100:
            break
    return collected
def get_runners(repo: str, online_only: bool = True) -> list[dict]:
    """
    Return the repository's self-hosted runners.

    Reads up to 10 pages of 100 runners. With `online_only` set, only runners
    whose status is "online" are kept. Any failure (for example a token
    without access to the runners API) prints a warning and yields [].
    """
    try:
        found = []
        for page in range(1, 11):  # Safety limit: at most 10 pages
            payload = run_gh_command(
                [f"repos/{repo}/actions/runners?per_page=100&page={page}"]
            )
            page_runners = payload.get("runners", [])
            found.extend(page_runners)
            if len(page_runners) < 100:
                break

        if not online_only:
            return found
        return [runner for runner in found if runner.get("status") == "online"]
    except Exception as exc:
        print(f"Warning: Cannot access runners API (need admin): {exc}")
        return []
def parse_time(time_str: str) -> datetime:
    """
    Convert an ISO-8601 timestamp into a datetime.

    A trailing "Z" is rewritten to "+00:00" before parsing. Returns None for
    an empty or missing value.
    """
    if time_str:
        normalized = time_str.replace("Z", "+00:00")
        return datetime.fromisoformat(normalized)
    return None
# Known runner counts per label (fallback when the runners API is unavailable).
# Keys are runner label names; values are the number of runners assumed to
# exist for that label.
# NOTE(review): these numbers are hard-coded and presumably have to be kept in
# sync with the real runner fleet by hand - confirm with the CI owners.
KNOWN_RUNNER_COUNTS = {
"1-gpu-5090": 16,
"h200": 8,
"h20": 4,
"b200": 4,
"amd": 8,
"github-hosted": 20, # GitHub hosted runners (variable)
"other": 10,
}
def calculate_concurrency_metrics(
    jobs: list[dict],
    window_start: datetime,
    window_end: datetime,
    num_runners: int,
) -> dict:
    """
    Compute concurrency metrics for a set of jobs with a sweep over time.

    Args:
        jobs: Job records with "start" and "end" datetimes and an optional
            "created_at" datetime
        window_start: Start of the observation window
        window_end: End of the observation window
        num_runners: Runner count at or above which the pool is saturated

    Returns:
        Dict with:
        - peak_concurrent: highest number of jobs running at once
        - avg_concurrent: time-weighted average of running jobs
        - saturation_seconds: time with at least `num_runners` jobs running
        - saturation_pct: saturation_seconds as a percentage of the window
        - peak_queue: highest number of jobs created but not yet started
        All values are zero when there are no jobs or the window is empty.
    """

    def no_activity():
        return {
            "peak_concurrent": 0,
            "avg_concurrent": 0.0,
            "saturation_seconds": 0,
            "saturation_pct": 0.0,
            "peak_queue": 0,
        }

    if not jobs:
        return no_activity()

    window_seconds = (window_end - window_start).total_seconds()
    if window_seconds <= 0:
        return no_activity()

    # Each interval contributes (time, +1) when it opens and (time, -1) when
    # it closes, clamped to the window. Running intervals span start..end,
    # queue intervals span created_at..start.
    running_events = []
    queue_events = []
    for job in jobs:
        started = job["start"]
        finished = job["end"]
        if not (finished < window_start or started > window_end):
            running_events.append((max(started, window_start), 1))
            running_events.append((min(finished, window_end), -1))

        created = job.get("created_at")
        if created and created < started:
            if started < window_start or created > window_end:
                continue
            queue_events.append((max(created, window_start), 1))
            queue_events.append((min(started, window_end), -1))

    # At equal timestamps closings sort before openings.
    def sweep_order(event):
        return (event[0], event[1] > 0)

    running_events.sort(key=sweep_order)
    queue_events.sort(key=sweep_order)

    active = 0
    peak_running = 0
    cursor = window_start
    busy_seconds = 0.0
    saturation_seconds = 0.0
    for when, step in running_events:
        # Account for the time spent at the previous concurrency level.
        span = (when - cursor).total_seconds()
        if span > 0:
            busy_seconds += active * span
            if active >= num_runners:
                saturation_seconds += span
        active += step
        if active > peak_running:
            peak_running = active
        cursor = when

    # Time between the last event and the end of the window.
    if cursor < window_end:
        tail = (window_end - cursor).total_seconds()
        busy_seconds += active * tail
        if active >= num_runners:
            saturation_seconds += tail

    queued = 0
    peak_queue = 0
    for _, step in queue_events:
        queued += step
        if queued > peak_queue:
            peak_queue = queued

    return {
        "peak_concurrent": peak_running,
        "avg_concurrent": busy_seconds / window_seconds,
        "saturation_seconds": saturation_seconds,
        "saturation_pct": saturation_seconds / window_seconds * 100,
        "peak_queue": peak_queue,
    }
def calculate_utilization(repo: str, hours: int = 24, runner_filter: str = None):
    """Calculate per-label runner utilization metrics.

    Fetches the workflow runs of the last ``hours`` hours, fetches the jobs of
    every run in parallel, and aggregates the finished jobs by runner label.

    Args:
        repo: Repository identifier passed through to ``get_workflow_runs``,
            ``get_runners`` and ``get_jobs_for_run``.
        hours: Size of the reporting window in hours.
        runner_filter: Optional substring; when given, only labels containing
            it are reported.

    Returns:
        A list with one dict per runner label (sorted by label) holding the
        capacity, utilization, queue-time and concurrency figures that
        ``format_report`` reads.
    """
    print(f"Fetching workflow runs from last {hours} hours...")
    runs = get_workflow_runs(repo, hours)
    print(f"Found {len(runs)} workflow runs")

    # Try to get online runners from API
    print("Fetching online runners...")
    runners = get_runners(repo, online_only=True)

    # Build label -> set of online runner names from API
    api_label_runners = defaultdict(set)
    if runners:
        for runner in runners:
            for label in runner.get("labels", []):
                label_name = label.get("name", "")
                if label_name not in DEFAULT_LABELS_TO_IGNORE:
                    api_label_runners[label_name].add(runner["name"])
        print(f"Got {len(runners)} online runners from API")
    else:
        print("No runner API access, will use observed runners from job data")

    # Track runners seen in jobs (for labels not in API or when API unavailable)
    job_label_runners = defaultdict(set)
    label_jobs = defaultdict(list)  # label -> list of job_info

    # Fetch jobs for all runs in parallel
    total_runs = len(runs)
    print(f"Fetching jobs for {total_runs} runs in parallel...")

    def fetch_jobs_for_run(run):
        """Fetch jobs for a single run, returning (run_id, jobs) or (run_id, None) on error."""
        try:
            return (run["id"], get_jobs_for_run(repo, run["id"]))
        except Exception as e:
            # Keep the report best-effort, but make the gap visible instead of
            # silently dropping the run.
            print(f"Warning: failed to fetch jobs for run {run['id']}: {e}")
            return (run["id"], None)

    all_jobs = []
    failed_runs = 0
    with ThreadPoolExecutor(max_workers=20) as executor:
        futures = [executor.submit(fetch_jobs_for_run, run) for run in runs]
        completed = 0
        for future in as_completed(futures):
            completed += 1
            if completed % 50 == 0:
                print(f"Fetched jobs for {completed}/{total_runs} runs...")
            _, jobs = future.result()
            if jobs is None:
                failed_runs += 1
            elif jobs:
                all_jobs.extend(jobs)
    if failed_runs:
        print(
            f"Warning: could not fetch jobs for {failed_runs}/{total_runs} runs; "
            f"utilization may be under-reported"
        )

    print(f"Processing {len(all_jobs)} jobs...")
    # Build the set of labels to skip once instead of once per label per job.
    skipped_labels = DEFAULT_LABELS_TO_IGNORE | GITHUB_HOSTED_LABELS
    for job in all_jobs:
        runner_name = job.get("runner_name")
        if not runner_name:
            continue
        created_at = parse_time(job.get("created_at"))
        started_at = parse_time(job.get("started_at"))
        completed_at = parse_time(job.get("completed_at"))
        # Only jobs that both started and finished contribute to the metrics.
        if not started_at or not completed_at:
            continue
        duration = (completed_at - started_at).total_seconds()
        queue_time = (started_at - created_at).total_seconds() if created_at else 0
        job_info = {
            "start": started_at,
            "end": completed_at,
            "created_at": created_at,
            "duration": duration,
            "queue_time": queue_time,
            "job_name": job["name"],
            "runner_name": runner_name,
        }
        # Use job labels directly (available in job data)
        for label in job.get("labels", []):
            # Skip generic labels
            if label in skipped_labels:
                continue
            job_label_runners[label].add(runner_name)
            label_jobs[label].append(job_info)

    # Merge API runners and job-observed runners
    # Prefer API count (online runners) when available
    all_labels = set(api_label_runners.keys()) | set(job_label_runners.keys())

    # Filter labels if specified
    if runner_filter:
        all_labels = {lbl for lbl in all_labels if runner_filter in lbl}
    print(f"Tracking {len(all_labels)} runner labels: {sorted(all_labels)}")

    # Calculate metrics per label
    window_seconds = hours * 3600
    window_end = datetime.now(timezone.utc)
    window_start = window_end - timedelta(hours=hours)

    results = []
    for label in sorted(all_labels):
        # Use API runner count if available, otherwise use job-observed count
        if label in api_label_runners and api_label_runners[label]:
            num_runners = len(api_label_runners[label])
        elif label in job_label_runners:
            num_runners = len(job_label_runners[label])
        else:
            num_runners = KNOWN_RUNNER_COUNTS.get(label, 1)

        total_capacity_seconds = window_seconds * num_runners
        jobs = label_jobs.get(label, [])
        total_active_seconds = sum(j["duration"] for j in jobs)
        utilization = (
            (total_active_seconds / total_capacity_seconds * 100)
            if total_capacity_seconds > 0
            else 0
        )
        idle_seconds = total_capacity_seconds - total_active_seconds

        # Calculate queue time metrics (only jobs that actually waited)
        queue_times = [j["queue_time"] for j in jobs if j["queue_time"] > 0]
        avg_queue_time = sum(queue_times) / len(queue_times) if queue_times else 0
        max_queue_time = max(queue_times) if queue_times else 0

        # Calculate concurrency metrics
        # First pass: get peak concurrent to determine effective capacity
        concurrency_initial = calculate_concurrency_metrics(
            jobs, window_start, window_end, num_runners
        )
        # Use observed peak as effective capacity if lower than API count
        # This handles cases where not all runners are active all the time
        effective_runners = min(num_runners, concurrency_initial["peak_concurrent"])
        if effective_runners < num_runners and effective_runners > 0:
            # Recalculate with effective capacity for accurate saturation
            concurrency = calculate_concurrency_metrics(
                jobs, window_start, window_end, effective_runners
            )
        else:
            concurrency = concurrency_initial
            effective_runners = num_runners

        results.append(
            {
                "label": label,
                "num_runners": num_runners,
                "effective_runners": effective_runners,
                "num_jobs": len(jobs),
                "total_active_hours": total_active_seconds / 3600,
                "total_idle_hours": idle_seconds / 3600,
                "total_capacity_hours": total_capacity_seconds / 3600,
                "utilization_pct": utilization,
                "avg_queue_min": avg_queue_time / 60,
                "max_queue_min": max_queue_time / 60,
                # Concurrency metrics
                "peak_concurrent": concurrency_initial["peak_concurrent"],
                "avg_concurrent": concurrency["avg_concurrent"],
                "saturation_hours": concurrency["saturation_seconds"] / 3600,
                "saturation_pct": concurrency["saturation_pct"],
                "peak_queue": concurrency["peak_queue"],
            }
        )
    return results
def format_report(results: list[dict], hours: int) -> str:
    """Format utilization results as a markdown report.

    Args:
        results: Per-label metric dicts as produced by
            ``calculate_utilization``.
        hours: Size of the reporting window, shown in the report header.

    Returns:
        The report as a single markdown string with three sections:
        concurrency analysis, recommendations, and a per-label summary table.
    """
    lines = [
        "# Runner Utilization Report",
        "",
        f"**Time window:** Last {hours} hours",
        f"**Generated:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M UTC')}",
        "",
        "## Concurrency Analysis",
        "",
        "| Label | Runners (API/Effective) | Peak Concurrent | Avg Concurrent | Saturation Time | Peak Queue |",
        "|-------|-------------------------|-----------------|----------------|-----------------|------------|",
    ]
    for r in results:
        effective = r["effective_runners"]
        avg_pct = (r["avg_concurrent"] / effective * 100) if effective > 0 else 0
        # Show "API/effective" only when the two counts differ.
        runner_str = (
            f"{r['num_runners']}/{effective}"
            if effective != r["num_runners"]
            else str(r["num_runners"])
        )
        lines.append(
            f"| {r['label']} | {runner_str} | "
            f"{r['peak_concurrent']} | "
            f"{r['avg_concurrent']:.1f} ({avg_pct:.0f}%) | "
            f"{r['saturation_hours']:.1f}h ({r['saturation_pct']:.0f}%) | "
            f"{r['peak_queue']} jobs |"
        )

    # Add recommendations section
    lines.extend(["", "## Recommendations", ""])
    has_recommendations = False
    for r in results:
        label = r["label"]
        saturation_pct = r["saturation_pct"]
        peak_queue = r["peak_queue"]
        effective = r["effective_runners"]
        avg_pct = (r["avg_concurrent"] / effective * 100) if effective > 0 else 0
        if saturation_pct > 50 or peak_queue > 5:
            lines.append(
                f"⚠️ **{label}**: High saturation ({saturation_pct:.0f}%) "
                f"with queue buildup ({peak_queue} jobs). Consider adding runners."
            )
            has_recommendations = True
        elif saturation_pct > 20 or peak_queue > 0:
            lines.append(
                f"📊 **{label}**: Moderate saturation ({saturation_pct:.0f}%), "
                f"peak queue {peak_queue} jobs. Monitor for trends."
            )
            has_recommendations = True
        elif avg_pct < 30 and r["num_jobs"] > 0:
            lines.append(
                f"💡 **{label}**: Low average utilization ({avg_pct:.0f}%). "
                f"Runner pool may be oversized."
            )
            has_recommendations = True
        else:
            lines.append(f"✓ **{label}**: Healthy utilization with minimal queueing.")
    if not has_recommendations and results:
        lines.append("All runner pools have healthy utilization.")

    # Add summary table
    lines.extend(
        [
            "",
            "## Summary by Runner Label",
            "",
            "| Label | Runners | Jobs | Active (hrs) | Utilization | Avg Queue | Max Queue |",
            "|-------|---------|------|--------------|-------------|-----------|-----------|",
        ]
    )
    bar_width = 10
    for r in results:
        # One bar cell per 10% of utilization. Clamp the fill so that
        # utilization above 100% (or below 0%) still yields a fixed-width bar.
        filled = max(0, min(bar_width, int(r["utilization_pct"] / 10)))
        utilization_bar = "█" * filled + "░" * (bar_width - filled)
        lines.append(
            f"| {r['label']} | {r['num_runners']} | {r['num_jobs']} | "
            f"{r['total_active_hours']:.1f} | "
            f"{r['utilization_pct']:.1f}% {utilization_bar} | "
            f"{r['avg_queue_min']:.1f}m | {r['max_queue_min']:.1f}m |"
        )
    return "\n".join(lines)
def main():
    """Command-line entry point: compute utilization and emit the report.

    The report is written to ``--output`` when given, otherwise printed to
    stdout. If the ``GITHUB_STEP_SUMMARY`` environment variable is set, the
    report is additionally appended to that file.
    """
    parser = argparse.ArgumentParser(description="Generate runner utilization report")
    parser.add_argument("--repo", default="sgl-project/sglang", help="GitHub repo")
    parser.add_argument("--hours", type=int, default=24, help="Time window in hours")
    parser.add_argument(
        "--filter", type=str, help="Filter runner labels (e.g., '5090', 'h200')"
    )
    parser.add_argument("--output", type=str, help="Output file (default: stdout)")
    args = parser.parse_args()

    results = calculate_utilization(args.repo, args.hours, args.filter)
    report = format_report(results, args.hours)

    if args.output:
        # The report contains non-ASCII characters (emoji, bar glyphs), so the
        # encoding is fixed rather than left to the locale default.
        with open(args.output, "w", encoding="utf-8") as f:
            f.write(report)
        print(f"Report written to {args.output}")
    else:
        print(report)

    # Also write to GITHUB_STEP_SUMMARY if available
    summary_file = os.environ.get("GITHUB_STEP_SUMMARY")
    if summary_file:
        with open(summary_file, "a", encoding="utf-8") as f:
            f.write(report)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,245 @@
#!/usr/bin/env python3
"""Collect and save performance metrics from nightly benchmark results.
This script reads benchmark result JSON files from performance profile directories
and saves them with metadata for artifact collection in CI.
Usage:
python3 scripts/ci/utils/save_metrics.py \
--gpu-config 8-gpu-h200 \
--partition 0 \
--run-id 12345678 \
--output test/metrics-8gpu-h200-partition-0.json
"""
import argparse
import glob
import json
import os
import sys
from datetime import datetime, timezone
def find_result_files(search_dirs: list[str]) -> list[str]:
    """Find all results_*.json files in the given directories.

    Each directory is searched recursively. Directories that do not exist are
    ignored. Search directories may overlap (for example ``test`` and ``.``),
    so matches are de-duplicated by their real path; otherwise the same file
    would be returned once per spelling and its benchmarks counted twice.

    Args:
        search_dirs: Directories to search.

    Returns:
        A sorted list of normalized paths, one per distinct result file.
    """
    # real path -> first normalized spelling under which the file was found
    unique_files = {}
    for search_dir in search_dirs:
        if not os.path.exists(search_dir):
            continue
        pattern = os.path.join(search_dir, "**/results_*.json")
        for match in glob.glob(pattern, recursive=True):
            unique_files.setdefault(os.path.realpath(match), os.path.normpath(match))
    return sorted(unique_files.values())
def parse_result_file(filepath: str) -> list[dict]:
    """Parse a benchmark result JSON file.

    The file may contain either a single JSON object or a list of objects.

    Args:
        filepath: Path of the JSON file to read (decoded as UTF-8).

    Returns:
        The list of result objects found in the file. An unreadable,
        undecodable or malformed file yields an empty list (with a warning),
        and entries that are not JSON objects are skipped (with a warning)
        because later processing reads them with ``dict.get``.
    """
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (ValueError, OSError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Failed to parse {filepath}: {e}")
        return []

    entries = data if isinstance(data, list) else [data]
    results = [entry for entry in entries if isinstance(entry, dict)]
    skipped = len(entries) - len(results)
    if skipped:
        print(f"Warning: Skipped {skipped} non-object entries in {filepath}")
    return results
def transform_benchmark_result(result: dict, gpu_config: str, partition: int) -> dict:
    """Map one raw benchmark result onto the metrics schema.

    ``latency`` and ``last_ttft`` are multiplied by 1000 and stored as
    ``latency_ms`` / ``ttft_ms``; a missing or ``None`` value stays ``None``.
    All other fields are copied as-is, with ``None`` for absent keys.
    ``input_len`` and ``output_len`` are kept here for the flat benchmarks
    list even though they also serve as grouping keys in
    ``benchmarks_by_io_len``. ``gpu_config`` and ``partition`` are accepted
    but not used by this function.
    """

    def _times_thousand(value):
        # None must pass through untouched rather than fail the multiplication.
        return None if value is None else value * 1000

    metrics = {}
    for field in ("batch_size", "input_len", "output_len"):
        metrics[field] = result.get(field)
    metrics["latency_ms"] = _times_thousand(result.get("latency"))
    for field in ("input_throughput", "output_throughput", "overall_throughput"):
        metrics[field] = result.get(field)
    metrics["ttft_ms"] = _times_thousand(result.get("last_ttft"))
    metrics["acc_length"] = result.get("acc_length")
    return metrics
def get_io_len_key(input_len: int, output_len: int) -> str:
    """Build the ``"<input_len>_<output_len>"`` key for an I/O length pair."""
    return "{}_{}".format(input_len, output_len)
def group_results_by_model(
    results: list[dict], gpu_config: str, partition: int
) -> list[dict]:
    """Group benchmark results by model, variant, and server_args.

    Results are organized with two benchmark structures:
    - benchmarks: flat list of all benchmarks (for backward compatibility)
    - benchmarks_by_io_len: nested structure grouped by input/output length combinations

    Args:
        results: Raw benchmark result dicts.
        gpu_config: GPU configuration name stored on every group.
        partition: Partition number stored on every group.

    Returns:
        One dict per distinct (model, variant, server_args) combination, in
        first-seen order. A ``run_name`` of ``"default"`` maps to a variant of
        ``None``.
    """
    groups = {}
    for result in results:
        model_path = result.get("model_path", "unknown")
        run_name = result.get("run_name", "default")
        variant = run_name if run_name != "default" else None
        server_args = result.get("server_args")
        # Use a canonical JSON string as the grouping key. Unlike
        # tuple(server_args), it stays hashable for nested lists and does not
        # reduce a dict to its keys, which would merge different configs.
        server_args_key = (
            json.dumps(server_args, sort_keys=True, default=str)
            if server_args
            else None
        )
        key = (model_path, variant, server_args_key)
        if key not in groups:
            groups[key] = {
                "gpu_config": gpu_config,
                "partition": partition,
                "model": model_path,
                "variant": variant,
                "server_args": server_args,
                "benchmarks": [],
                "benchmarks_by_io_len": {},
            }
        group = groups[key]
        transformed = transform_benchmark_result(result, gpu_config, partition)

        # Add to flat benchmarks list (backward compatibility)
        group["benchmarks"].append(transformed)

        # Add to nested benchmarks_by_io_len structure
        input_len = result.get("input_len")
        output_len = result.get("output_len")
        if input_len is not None and output_len is not None:
            io_key = get_io_len_key(input_len, output_len)
            by_io_len = group["benchmarks_by_io_len"]
            if io_key not in by_io_len:
                by_io_len[io_key] = {
                    "input_len": input_len,
                    "output_len": output_len,
                    "benchmarks": [],
                }
            # For the nested structure, exclude input_len and output_len from individual benchmarks
            # since they're already in the parent
            nested_benchmark = {
                k: v
                for k, v in transformed.items()
                if k not in ("input_len", "output_len")
            }
            by_io_len[io_key]["benchmarks"].append(nested_benchmark)
    return list(groups.values())
def save_metrics(
    gpu_config: str,
    partition: int,
    run_id: str,
    output_file: str,
    search_dirs: list[str],
) -> bool:
    """Collect benchmark results and write them, with metadata, to a JSON file.

    Result files are discovered under ``search_dirs``, parsed, and grouped by
    model. The output file is written even when no result file is found; in
    that case its ``results`` list is empty.

    Returns:
        True when the metrics file was written, False when writing failed
        with an ``OSError``.
    """
    created_at = datetime.now(timezone.utc).isoformat()

    # Discover the result files
    discovered = find_result_files(search_dirs)
    print(f"Found {len(discovered)} result file(s)")

    per_model = []
    if discovered:
        # Parse every file, in sorted path order
        collected = []
        for path in sorted(discovered):
            print(f" Reading: {path}")
            collected.extend(parse_result_file(path))
        print(f"Total benchmark results: {len(collected)}")
        # Group by model/variant
        per_model = group_results_by_model(collected, gpu_config, partition)
    else:
        print("No benchmark result files found")

    payload = {
        "run_id": run_id,
        "timestamp": created_at,
        "gpu_config": gpu_config,
        "partition": partition,
        "results": per_model,
    }

    # Create the parent directory if needed, then write the file
    try:
        parent_dir = os.path.dirname(output_file) or "."
        os.makedirs(parent_dir, exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as sink:
            json.dump(payload, sink, indent=2)
        if discovered:
            print(f"Saved metrics to: {output_file}")
        else:
            print(f"Created empty metrics file: {output_file}")
        return True
    except OSError as e:
        print(f"Error writing metrics file: {e}")
        return False
def main():
    """Parse the command line, save the metrics, and exit 0 on success or 1 on failure."""
    cli = argparse.ArgumentParser(
        description="Collect performance metrics from benchmark results"
    )
    cli.add_argument(
        "--gpu-config",
        required=True,
        help="GPU configuration (e.g., 8-gpu-h200, 8-gpu-b200)",
    )
    cli.add_argument(
        "--partition",
        type=int,
        required=True,
        help="Partition number (0, 1, 2, etc.)",
    )
    cli.add_argument("--run-id", required=True, help="GitHub Actions run ID")
    cli.add_argument(
        "--output", required=True, help="Output file path for metrics JSON"
    )
    cli.add_argument(
        "--search-dir",
        action="append",
        default=[],
        dest="search_dirs",
        help="Directory to search for result files (can be specified multiple times)",
    )
    options = cli.parse_args()

    # Fall back to the default locations when no --search-dir was given
    if options.search_dirs:
        chosen_dirs = options.search_dirs
    else:
        chosen_dirs = [
            "test/performance_profiles_8_gpu",
            "test/performance_profiles_text_models",
            "test/performance_profiles_vlms",
            "test",
            ".",
        ]

    saved = save_metrics(
        gpu_config=options.gpu_config,
        partition=options.partition,
        run_id=options.run_id,
        output_file=options.output,
        search_dirs=chosen_dirs,
    )
    exit_code = 0 if saved else 1
    sys.exit(exit_code)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,870 @@
import glob
import json
import os
import re
import sys
import time
from datetime import datetime, timezone
import requests
from github import Auth, Github
# Configuration
PERMISSIONS_FILE_PATH = ".github/CI_PERMISSIONS.json"
def find_workflow_run_url(
    gh_repo,
    workflow_id,
    ref,
    target_stage,
    token,
    dispatch_time,
    pr_head_sha=None,
    max_wait=30,
    test_command=None,
):
    """
    Poll for the workflow run URL after dispatch.

    Uses the dynamic run-name feature to identify runs:
    - Fork PRs: display_title = "[stage-name] sha"
    - Non-fork PRs: display_title = "[stage-name]"

    Polling is best-effort: a failed request, a non-200 response or an
    unparsable body only costs one attempt and never raises to the caller.

    Args:
        gh_repo: PyGithub repository object
        workflow_id: ID of the workflow that was dispatched
        ref: Branch/ref the workflow was dispatched on
        target_stage: The stage name we're looking for
        token: GitHub API token
        dispatch_time: Unix timestamp when dispatch was triggered
        pr_head_sha: PR head SHA (for fork PRs, used to match display_title)
        max_wait: Maximum seconds to wait for the run to appear; one attempt
            is made per full 5 seconds, so values below 5 make no attempt
        test_command: Optional test command that the workflow's run-name
            places between the stage name and the SHA

    Returns:
        The workflow run URL if found, None otherwise.
    """
    # Build expected display_title based on workflow's run-name.
    # rerun-test includes test_command: "[rerun-test] <test_command> [<sha>]"
    # Other workflows: "[stage-name] [<sha>]"
    suffix = f" {test_command}" if test_command else ""
    if pr_head_sha:
        expected_title = f"[{target_stage}]{suffix} {pr_head_sha}"
    else:
        expected_title = f"[{target_stage}]{suffix}"
    print(f"Looking for workflow run with display_title: {expected_title}")

    runs_url = f"https://api.github.com/repos/{gh_repo.full_name}/actions/workflows/{workflow_id}/runs" if max_wait // 5 > 0 else None
    for _ in range(max_wait // 5):
        time.sleep(5)
        # Get recent workflow_dispatch runs for this workflow
        try:
            runs_resp = requests.get(
                runs_url,
                params={"event": "workflow_dispatch", "branch": ref, "per_page": 10},
                headers={
                    "Authorization": f"Bearer {token}",
                    "Accept": "application/vnd.github+json",
                },
                # Without a timeout a stalled connection would block forever.
                timeout=30,
            )
        except requests.RequestException as e:
            print(f"Failed to fetch workflow runs: {e}")
            continue
        if runs_resp.status_code != 200:
            print(f"Failed to fetch workflow runs: {runs_resp.status_code}")
            continue
        try:
            payload = runs_resp.json()
        except ValueError as e:
            print(f"Failed to parse workflow runs response: {e}")
            continue

        for run in payload.get("workflow_runs", []):
            # Skip runs created before our dispatch (with 10s tolerance)
            run_created = datetime.fromisoformat(
                run["created_at"].replace("Z", "+00:00")
            ).timestamp()
            if run_created < dispatch_time - 10:
                continue
            # Match by display_title (set by workflow's run-name directive)
            # This is immediately available, unlike job names which require waiting
            display_title = run.get("display_title", "")
            if display_title == expected_title:
                print(
                    f"Found matching workflow run: {run['id']} with title '{display_title}'"
                )
                return run["html_url"]

    print(f"Could not find workflow run after {max_wait} seconds")
    return None
def get_env_var(name):
    """Return the value of environment variable ``name``.

    Prints an error and exits with status 1 when the variable is unset or
    empty.
    """
    value = os.getenv(name)
    if value:
        return value
    print(f"Error: Environment variable {name} not set.")
    sys.exit(1)
def load_permissions(user_login):
    """
    Load the permissions entry for ``user_login`` from the local
    permissions JSON file (``PERMISSIONS_FILE_PATH``).

    Returns the user's permissions dict, or None when the file does not
    exist or has no (non-empty) entry for the user. Any error while reading
    or parsing the file terminates the process with exit status 1.
    """
    try:
        print(f"Loading permissions from {PERMISSIONS_FILE_PATH}...")
        if os.path.exists(PERMISSIONS_FILE_PATH):
            with open(PERMISSIONS_FILE_PATH, "r") as handle:
                all_perms = json.load(handle)
            entry = all_perms.get(user_login)
            if entry:
                return entry
            print(f"User '{user_login}' not found in permissions file.")
            return None
        print(f"Error: Permissions file not found at {PERMISSIONS_FILE_PATH}")
        return None
    except Exception as e:
        print(f"Failed to load or parse permissions file: {e}")
        sys.exit(1)
def has_sgl_kernel_changes(pr):
    """
    Report whether the PR touches any file under ``sgl-kernel/``.

    The caller uses this to choose between a full workflow rerun (which
    rebuilds the kernel) and rerunning only the failed jobs. If the file
    list cannot be read, a warning is printed and False is returned.
    """
    try:
        return any(
            changed.filename.startswith("sgl-kernel/") for changed in pr.get_files()
        )
    except Exception as e:
        print(f"Warning: Could not check PR files for sgl-kernel changes: {e}")
        # Default to False to avoid unnecessary full reruns
        return False
def handle_tag_run_ci(gh_repo, pr, comment, user_perms, react_on_success=True):
    """
    Handle the /tag-run-ci-label command by adding the 'run-ci' label to the PR.

    Requires the ``can_tag_run_ci_label`` permission. When
    ``react_on_success`` is true, a "+1" reaction is also added to the
    triggering comment.
    Returns True if the label was added, False if permission was denied.
    """
    allowed = user_perms.get("can_tag_run_ci_label", False)
    if not allowed:
        print("Permission denied: can_tag_run_ci_label is false.")
        return False

    print("Permission granted. Adding 'run-ci' label.")
    pr.add_to_labels("run-ci")

    if not react_on_success:
        print("Label added (reaction suppressed).")
        return True

    comment.create_reaction("+1")
    print("Label added and comment reacted.")
    return True
def handle_rerun_failed_ci(gh_repo, pr, comment, user_perms, react_on_success=True):
    """
    Handle the /rerun-failed-ci command.

    Looks at the completed workflow runs of the PR's head commit and reruns
    those that concluded with 'failure' or 'skipped'. Requires the
    ``can_rerun_failed_ci`` permission.
    Returns True if at least one rerun was triggered, False otherwise.
    """
    if not user_perms.get("can_rerun_failed_ci", False):
        print("Permission denied: can_rerun_failed_ci is false.")
        return False
    print("Permission granted. Triggering rerun of failed or skipped workflows.")

    # Check if PR has sgl-kernel changes - if so, we need full reruns
    # to ensure sgl-kernel-build-wheels runs and produces fresh artifacts
    needs_kernel_rebuild = has_sgl_kernel_changes(pr)
    if needs_kernel_rebuild:
        print("PR has sgl-kernel changes - will use full rerun to rebuild kernel")

    # Work on the latest commit of the PR
    head_sha = pr.head.sha
    print(f"Checking workflows for commit: {head_sha}")

    triggered = 0
    for run in gh_repo.get_workflow_runs(head_sha=head_sha):
        if run.status != "completed":
            continue

        if run.conclusion == "failure":
            print(f"Rerunning failed workflow: {run.name} (ID: {run.id})")
            # A full rerun is only needed when the kernel must be rebuilt;
            # otherwise rerunning just the failed jobs is cheaper.
            full_rerun = needs_kernel_rebuild
        elif run.conclusion == "skipped":
            print(f"Rerunning skipped workflow: {run.name} (ID: {run.id})")
            # Skipped workflows don't have 'failed jobs', so we use full rerun()
            full_rerun = True
        else:
            continue

        try:
            if full_rerun:
                run.rerun()
            else:
                run.rerun_failed_jobs()
            triggered += 1
        except Exception as e:
            print(f"Failed to rerun workflow {run.id}: {e}")

    if triggered == 0:
        print("No failed or skipped workflows found to rerun.")
        return False

    print(f"Triggered rerun for {triggered} workflows.")
    if react_on_success:
        comment.create_reaction("+1")
    return True
def handle_rerun_stage(
    gh_repo, pr, comment, user_perms, stage_name, token, react_on_success=True
):
    """
    Handles the /rerun-stage <stage-name> command.

    Triggers a workflow_dispatch to run only the specified stage, skipping
    dependencies. AMD stages are dispatched on the "PR Test (AMD)" workflow,
    all other supported stages on "PR Test". Fork PRs are dispatched on
    ``main`` with the PR head SHA passed as an input; non-fork PRs are
    dispatched on the PR branch.

    Args:
        gh_repo: PyGithub repository object
        pr: PyGithub pull request object
        comment: The triggering comment (receives reactions)
        user_perms: Permissions dict; ``can_rerun_stage`` is required
        stage_name: Name of the stage to run
        token: GitHub API token used for the dispatch request
        react_on_success: Whether to add a "+1" reaction after dispatching

    Returns True if action was taken, False otherwise.
    """
    if not user_perms.get("can_rerun_stage", False):
        print("Permission denied: can_rerun_stage is false.")
        return False

    if not stage_name:
        print("Error: No stage name provided")
        comment.create_reaction("confused")
        # The examples must be stages from the supported lists below.
        pr.create_issue_comment(
            f"❌ Please specify a stage name: `/rerun-stage <stage-name>`\n\n"
            f"Examples: `/rerun-stage stage-b-test-1-gpu-small`, `/rerun-stage stage-c-test-4-gpu-h100`"
        )
        return False

    print(f"Permission granted. Triggering workflow_dispatch for stage '{stage_name}'.")

    # Valid NVIDIA stage names that support target_stage
    nvidia_stages = [
        "stage-a-test-1-gpu-small",
        "stage-a-test-cpu",
        "stage-b-test-1-gpu-small",
        "stage-b-test-1-gpu-large",
        "stage-b-test-2-gpu-large",
        "stage-b-test-4-gpu-b200",
        "stage-c-test-4-gpu-h100",
        "stage-c-test-8-gpu-h200",
        "stage-c-test-8-gpu-h20",
        "stage-c-test-4-gpu-b200",
        "stage-c-test-4-gpu-gb200",
        "stage-c-test-deepep-4-gpu-h100",
        "stage-c-test-deepep-8-gpu-h200",
        "multimodal-gen-test-1-gpu",
        "multimodal-gen-test-2-gpu",
        "multimodal-gen-test-1-b200",
    ]
    # Valid AMD stage names that support target_stage
    amd_stages = [
        "sgl-kernel-unit-test-amd",
        "sgl-kernel-unit-test-2-gpu-amd",
        "stage-a-test-1-gpu-small-amd",
        "stage-b-test-1-gpu-small-amd",
        "stage-b-test-1-gpu-small-amd-nondeterministic",
        "stage-b-test-1-gpu-small-amd-mi35x",
        "stage-b-test-1-gpu-large-amd",
        "stage-b-test-2-gpu-large-amd",
        "multimodal-gen-test-1-gpu-amd",
        "multimodal-gen-test-2-gpu-amd",
        "stage-c-test-large-8-gpu-amd",
        "stage-c-test-large-8-gpu-amd-mi35x",
    ]
    valid_stages = nvidia_stages + amd_stages
    is_amd_stage = stage_name in amd_stages

    if stage_name not in valid_stages:
        comment.create_reaction("confused")
        pr.create_issue_comment(
            f"❌ Stage `{stage_name}` doesn't support isolated runs yet.\n\n"
            f"**NVIDIA stages:**\n"
            + "\n".join(f"- `{s}`" for s in nvidia_stages)
            + "\n\n**AMD stages:**\n"
            + "\n".join(f"- `{s}`" for s in amd_stages)
            + "\n\nOther stages will be added soon. For now, use `/rerun-failed-ci` for those stages."
        )
        return False

    try:
        # Get the appropriate workflow based on stage type
        workflow_name = "PR Test (AMD)" if is_amd_stage else "PR Test"
        workflows = gh_repo.get_workflows()
        target_workflow = None
        for wf in workflows:
            if wf.name == workflow_name:
                target_workflow = wf
                break
        if not target_workflow:
            print(f"Error: {workflow_name} workflow not found")
            return False

        # Check if PR is from a fork by comparing repo owners
        # Handle case where fork repo may have been deleted (pr.head.repo is None)
        is_fork = (
            pr.head.repo is None or pr.head.repo.owner.login != gh_repo.owner.login
        )
        print(f"PR is from fork: {is_fork}")

        # pr_head_sha is used for fork PRs (passed to workflow and used for URL lookup)
        pr_head_sha = None
        # Both workflows take the same inputs, so AMD and NVIDIA share this code.
        inputs = {"target_stage": stage_name}
        if is_fork:
            # For fork PRs: dispatch on main and pass SHA as input
            # This is needed because fork branch names don't exist in the main repo
            ref = "main"
            pr_head_sha = pr.head.sha
            print(
                f"Triggering {workflow_name} workflow on ref: {ref}, PR head SHA: {pr_head_sha}"
            )
            inputs["pr_head_sha"] = pr_head_sha
        else:
            # For non-fork PRs: dispatch on the PR branch directly
            # This allows testing workflow changes before merge
            ref = pr.head.ref
            print(f"Triggering {workflow_name} workflow on branch: {ref}")

        # Record dispatch time before triggering
        dispatch_time = time.time()

        # Use requests directly as PyGithub's create_dispatch only accepts HTTP 204
        dispatch_url = f"https://api.github.com/repos/{gh_repo.full_name}/actions/workflows/{target_workflow.id}/dispatches"
        dispatch_resp = requests.post(
            dispatch_url,
            json={"ref": ref, "inputs": inputs},
            headers={
                "Authorization": f"Bearer {token}",
                "Accept": "application/vnd.github+json",
            },
            # Without a timeout a stalled connection would block forever.
            timeout=30,
        )
        success = dispatch_resp.status_code in (200, 204)
        if not success:
            print(f"Dispatch failed: {dispatch_resp.status_code} {dispatch_resp.text}")

        if success:
            print(f"Successfully triggered workflow for stage '{stage_name}'")
            if react_on_success:
                comment.create_reaction("+1")

            # The dispatch already succeeded; a failing URL lookup must not be
            # reported to the user as a failed trigger.
            try:
                run_url = find_workflow_run_url(
                    gh_repo,
                    target_workflow.id,
                    ref,
                    stage_name,
                    token,
                    dispatch_time,
                    pr_head_sha=pr_head_sha,
                    max_wait=30,
                )
            except Exception as e:
                print(f"Warning: could not look up workflow run URL: {e}")
                run_url = None

            if run_url:
                pr.create_issue_comment(
                    f"✅ Triggered `{stage_name}` to run independently"
                    f" (skipping dependencies)."
                    f" [View workflow run]({run_url})"
                )
            else:
                pr.create_issue_comment(
                    f"✅ Triggered `{stage_name}` to run independently"
                    f" (skipping dependencies).\n"
                    f"⚠️ Could not retrieve workflow run URL. "
                    f"Check the [Actions tab](https://github.com/{gh_repo.full_name}/actions) for progress."
                )
            return True
        else:
            print("Failed to trigger workflow_dispatch")
            return False
    except Exception as e:
        print(f"Error triggering workflow_dispatch: {e}")
        comment.create_reaction("confused")
        pr.create_issue_comment(
            f"❌ Failed to trigger workflow: {str(e)}\n\n"
            f"Please check the logs or contact maintainers."
        )
        return False
# Maps a CUDA CI suite name to the runner label used to run it.
# detect_suite() looks up the suite found in a test file's register_cuda_ci()
# call here and reports an "Unknown CUDA suite" error for suites not listed.
CUDA_SUITE_TO_RUNNER = {
    "stage-a-test-1-gpu-small": "1-gpu-5090",
    "stage-a-test-cpu": "ubuntu-latest",
    "stage-b-test-1-gpu-small": "1-gpu-5090",
    "stage-b-test-1-gpu-large": "1-gpu-h100",
    "stage-b-test-2-gpu-large": "2-gpu-h100",
    "stage-b-test-4-gpu-b200": "4-gpu-b200",
    "stage-c-test-4-gpu-h100": "4-gpu-h100",
    "stage-c-test-8-gpu-h200": "8-gpu-h200",
    "stage-c-test-8-gpu-h20": "8-gpu-h20",
    "stage-c-test-4-gpu-b200": "4-gpu-b200",
    "stage-c-test-deepep-4-gpu-h100": "4-gpu-h100",
    "stage-c-test-deepep-8-gpu-h200": "8-gpu-h200",
}

# Suites for which detect_suite() reports use_deepep=True.
DEEPEP_SUITES = {
    "stage-c-test-8-gpu-h20",
    "stage-c-test-deepep-4-gpu-h100",
    "stage-c-test-deepep-8-gpu-h200",
}
def resolve_test_file(file_part):
    """
    Resolve a user-provided file path to a path relative to test/.

    Supports:
    - Full path: test/registered/core/test_srt_endpoint.py
    - Relative to test/: registered/core/test_srt_endpoint.py
    - Bare filename: test_srt_endpoint.py (glob-matched, must be unique)

    The input comes from a PR comment, so paths that would resolve outside
    ``test/`` (for example via ``..``) are rejected, and only regular files
    count as matches for a bare filename.

    Returns (resolved_path, error_message). On success error_message is None.
    """
    if file_part.startswith("test/"):
        file_part = file_part[len("test/") :]

    if "/" not in file_part:
        # Keep regular files only: an empty name or ".." would otherwise
        # match directories under test/registered/.
        matches = [
            m
            for m in glob.glob(f"test/registered/**/{file_part}", recursive=True)
            if os.path.isfile(m)
        ]
        if len(matches) == 0:
            return (
                None,
                f"No test file found matching `{file_part}` under `test/registered/`.",
            )
        if len(matches) > 1:
            match_list = "\n".join(f"- `{m}`" for m in sorted(matches))
            return None, (
                f"Ambiguous filename `{file_part}` — matched {len(matches)} files:\n\n"
                f"{match_list}\n\n"
                f"Please provide the full path, e.g. `/rerun-test {matches[0]}`"
            )
        return matches[0][len("test/") :], None

    full_path = f"test/{file_part}"
    # Reject paths that leave test/ once "." and ".." components are resolved.
    if not os.path.normpath(full_path).startswith("test" + os.sep):
        return None, f"Invalid test path `{file_part}`: it must be inside `test/`."
    if not os.path.isfile(full_path):
        return None, f"File not found: `{full_path}`"
    return file_part, None
def detect_suite(file_path_from_test):
    """
    Determine the CI suite a test file is registered for.

    Reads ``test/<file_path_from_test>`` and looks for a ``suite=...``
    argument in a ``register_cuda_ci(...)`` call first, then in a
    ``register_cpu_ci(...)`` call. Lines with a ``#`` before the call are
    not matched.

    Returns (suite_name, runner_label, use_deepep, is_cpu, error_message).
    On success error_message is None.
    """
    full_path = f"test/{file_path_from_test}"
    with open(full_path, "r") as handle:
        source_text = handle.read()

    # CUDA registration takes precedence over CPU registration
    cuda_hit = re.search(
        r'^[^#\n]*register_cuda_ci\([^)]*suite\s*=\s*["\']([^"\']+)["\']',
        source_text,
        re.MULTILINE,
    )
    if cuda_hit is not None:
        suite = cuda_hit.group(1)
        runner = CUDA_SUITE_TO_RUNNER.get(suite)
        if runner:
            return suite, runner, suite in DEEPEP_SUITES, False, None
        known = ", ".join(f"`{s}`" for s in sorted(CUDA_SUITE_TO_RUNNER))
        problem = (
            f"Unknown CUDA suite `{suite}` in `{full_path}`.\n\n"
            f"Known suites: {known}"
        )
        return suite, None, False, False, problem

    # Fall back to a CPU registration
    cpu_hit = re.search(
        r'^[^#\n]*register_cpu_ci\([^)]*suite\s*=\s*["\']([^"\']+)["\']',
        source_text,
        re.MULTILINE,
    )
    if cpu_hit is not None:
        return cpu_hit.group(1), "ubuntu-latest", False, True, None

    not_registered = (
        f"No `register_cuda_ci()` or `register_cpu_ci()` found in `{full_path}`.\n\n"
        f"This file may not be a registered CI test."
    )
    return None, None, False, False, not_registered
def _resolve_test_spec(test_spec):
    """
    Resolve a single test spec into its components without dispatching.

    A spec is ``<file>`` or ``<file>::<selector>``. The file is resolved with
    ``resolve_test_file`` and its suite detected with ``detect_suite``. The
    resulting ``test_command`` is the resolved path, followed by a space and
    the selector when one was given.

    Returns a dict with keys: spec, resolved_path, test_command, suite,
    runner_label, use_deepep, is_cpu, error. When resolution fails, the dict
    contains only ``spec`` and a non-empty ``error``.
    """
    if "::" in test_spec:
        file_part, test_selector = test_spec.split("::", 1)
    else:
        file_part = test_spec
        test_selector = None
    file_part = file_part.strip()
    if test_selector:
        test_selector = test_selector.strip()

    resolved_path, err = resolve_test_file(file_part)
    if err:
        return {"spec": test_spec, "error": err}

    suite, runner_label, use_deepep, is_cpu, err = detect_suite(resolved_path)
    if err:
        return {"spec": test_spec, "error": err}

    test_command = resolved_path
    if test_selector:
        test_command = f"{resolved_path} {test_selector}"

    print(
        f"Resolved: file={resolved_path}, selector={test_selector}, "
        f"suite={suite}, runner={runner_label}, deepep={use_deepep}, "
        f"cpu={is_cpu}, command='{test_command}'"
    )
    return {
        "spec": test_spec,
        # Included so the result matches the documented set of keys.
        "resolved_path": resolved_path,
        "test_command": test_command,
        "suite": suite,
        "runner_label": runner_label,
        "use_deepep": use_deepep,
        "is_cpu": is_cpu,
        "error": None,
    }
def _dispatch_batch(gh_repo, pr, batch, token):
"""
Dispatch a single workflow run for a batch of resolved test specs
that share the same (runner_label, use_deepep, is_cpu).
Returns a dict with keys: specs, success, test_commands, runner_label, run_url, error.
"""
test_commands = [r["test_command"] for r in batch]
runner_label = batch[0]["runner_label"]
use_deepep = batch[0]["use_deepep"]
is_cpu = batch[0]["is_cpu"]
# Join multiple commands with newlines for the workflow to iterate over
combined_command = "\n".join(test_commands)
try:
workflow_name = "Rerun Test"
workflows = gh_repo.get_workflows()
target_workflow = None
for wf in workflows:
if wf.name == workflow_name:
target_workflow = wf
break
if not target_workflow:
return {
"specs": [r["spec"] for r in batch],
"success": False,
"error": f"{workflow_name} workflow not found",
}
is_fork = (
pr.head.repo is None or pr.head.repo.owner.login != gh_repo.owner.login
)
pr_head_sha = None
inputs = {
"test_command": combined_command,
"runner_label": runner_label,
"use_deepep": str(use_deepep).lower(),
"is_cpu": str(is_cpu).lower(),
}
if is_fork:
ref = "main"
pr_head_sha = pr.head.sha
inputs["pr_head_sha"] = pr_head_sha
else:
ref = pr.head.ref
dispatch_time = time.time()
dispatch_url = f"https://api.github.com/repos/{gh_repo.full_name}/actions/workflows/{target_workflow.id}/dispatches"
dispatch_resp = requests.post(
dispatch_url,
json={"ref": ref, "inputs": inputs},
headers={
"Authorization": f"Bearer {token}",
"Accept": "application/vnd.github+json",
},
)
success = dispatch_resp.status_code in (200, 204)
if not success:
print(f"Dispatch failed: {dispatch_resp.status_code} {dispatch_resp.text}")
return {
"specs": [r["spec"] for r in batch],
"success": False,
"error": f"Dispatch failed: {dispatch_resp.status_code}",
}
print(f"Successfully triggered rerun-test: {combined_command}")
run_url = find_workflow_run_url(
gh_repo,
target_workflow.id,
ref,
"rerun-test",
token,
dispatch_time,
pr_head_sha=pr_head_sha,
max_wait=30,
test_command=combined_command,
)
return {
"specs": [r["spec"] for r in batch],
"success": True,
"test_commands": test_commands,
"runner_label": runner_label,
"run_url": run_url,
}
except Exception as e:
print(f"Error triggering rerun-test for batch: {e}")
return {
"specs": [r["spec"] for r in batch],
"success": False,
"error": str(e),
}
def handle_rerun_test(gh_repo, pr, comment, user_perms, test_specs, token):
    """
    Handles the /rerun-test command. Resolves all test specs, groups them by
    (runner_label, use_deepep, is_cpu), and dispatches one workflow per group.

    Args:
        gh_repo: Repository object (owner login, full name and collaborator
            permission lookup are read from it).
        pr: Pull request the command was issued on; receives the result comment.
        comment: The triggering comment; receives the reaction.
        user_perms: Mapping of permission flags for the commenter.
        test_specs: List of test spec strings, or a falsy value if none were given.
        token: GitHub token forwarded to ``_dispatch_batch``.

    Returns:
        bool: True if at least one workflow dispatch succeeded, False otherwise
        (including every permission/usage rejection).
    """
    # SECURITY: For fork PRs, only allow /rerun-test if the commenter has write+ permission.
    # This command checks out and executes code from the PR branch on self-hosted GPU
    # runners, so we must ensure the commenter is a trusted collaborator.
    is_fork = pr.head.repo is None or pr.head.repo.owner.login != gh_repo.owner.login
    if is_fork:
        commenter = comment.user.login
        perm = gh_repo.get_collaborator_permission(commenter)
        if perm not in ("admin", "write"):
            print(f"Permission denied: /rerun-test on fork PR by {commenter}.")
            comment.create_reaction("confused")
            pr.create_issue_comment(
                "❌ `/rerun-test` is not available for fork PRs unless the commenter "
                "has write permission on the repo.\n\n"
                "Please ask a maintainer to run this command, or use the normal CI flow."
            )
            return False
        print(f"Fork PR, but commenter {commenter} has write+ permission. Proceeding.")

    if not (
        user_perms.get("can_rerun_test", False)
        or user_perms.get("can_rerun_stage", False)
    ):
        print("Permission denied: neither can_rerun_test nor can_rerun_stage is true.")
        return False

    if not test_specs:
        comment.create_reaction("confused")
        pr.create_issue_comment(
            "❌ Please specify a test: `/rerun-test <file>::<TestClass.test_method>`\n\n"
            "Examples:\n"
            "- `/rerun-test test/registered/core/test_srt_endpoint.py::TestSRTEndpoint.test_simple_decode`\n"
            "- `/rerun-test registered/core/test_srt_endpoint.py::TestSRTEndpoint`\n"
            "- `/rerun-test test_srt_endpoint.py`\n"
            "- `/rerun-test test_a.py test_b.py test_c.py` (multiple tests)"
        )
        return False

    # Phase 1: Resolve all specs
    resolved = []
    resolve_failures = []
    for spec in test_specs:
        r = _resolve_test_spec(spec)
        if r.get("error"):
            resolve_failures.append(r)
        else:
            resolved.append(r)

    # Phase 2: Group by (runner_label, use_deepep, is_cpu)
    groups = {}
    for r in resolved:
        key = (r["runner_label"], r["use_deepep"], r["is_cpu"])
        groups.setdefault(key, []).append(r)

    # Phase 3: Dispatch one workflow per group
    dispatch_results = [
        _dispatch_batch(gh_repo, pr, batch, token) for batch in groups.values()
    ]

    # Build consolidated comment: one paragraph per dispatched group, followed
    # by one paragraph per spec that could not be resolved.
    lines = []
    for dr in dispatch_results:
        if dr["success"]:
            count = len(dr["test_commands"])
            header = (
                f"✅ `{dr['runner_label']}` ({count} test{'s' if count > 1 else ''})"
            )
            cmds = "\n".join(
                f"cd test/ && python3 {cmd}" for cmd in dr["test_commands"]
            )
            if dr.get("run_url"):
                lines.append(
                    f"{header}: "
                    f"[View workflow run]({dr['run_url']})\n"
                    f"```\n{cmds}\n```"
                )
            else:
                lines.append(
                    f"{header}:\n"
                    f"```\n{cmds}\n```\n"
                    f"⚠️ Could not retrieve workflow run URL. "
                    f"Check the [Actions tab](https://github.com/{gh_repo.full_name}/actions) for progress."
                )
        else:
            # Mark dispatch failures the same way as resolve failures below so
            # every failed entry in the comment is visibly flagged.
            specs_str = ", ".join(f"`{s}`" for s in dr["specs"])
            lines.append(f"❌ {specs_str}: {dr['error']}")

    for r in resolve_failures:
        lines.append(f"❌ `{r['spec']}`: {r['error']}")

    body = "\n\n".join(lines)

    # React once: thumbs-up if anything was dispatched, confused otherwise.
    successes = [dr for dr in dispatch_results if dr["success"]]
    if successes:
        comment.create_reaction("+1")
    if not successes and (resolve_failures or dispatch_results):
        comment.create_reaction("confused")

    pr.create_issue_comment(body)
    return len(successes) > 0
def main():
    """
    Entry point for the slash-command handler.

    Reads the comment context from environment variables, works out what the
    commenter is allowed to do, and routes the first line of the comment to
    the matching ``handle_*`` function.
    """
    # 1. Load Environment Variables
    token = get_env_var("GITHUB_TOKEN")
    repo_name = get_env_var("REPO_FULL_NAME")
    pr_number = int(get_env_var("PR_NUMBER"))
    comment_id = int(get_env_var("COMMENT_ID"))
    comment_body = get_env_var("COMMENT_BODY").strip()
    user_login = get_env_var("USER_LOGIN")

    # 2. Load Permissions (local file check first to avoid unnecessary API calls)
    user_perms = load_permissions(user_login)

    # 3. Initialize GitHub API with Auth
    g = Github(auth=Auth.Token(token))
    repo = g.get_repo(repo_name)
    pr = repo.get_pull(pr_number)
    comment = repo.get_issue(pr_number).get_comment(comment_id)

    # PR authors may always rerun failed CI and individual tests on their own
    # PRs, even when they are not listed in CI_PERMISSIONS.json.
    # /tag-run-ci-label and /rerun-stage still require CI_PERMISSIONS.json.
    # For fork PRs, handle_rerun_test() additionally requires the commenter to
    # have write or admin permission on the repo.
    commenter_is_author = pr.user.login == user_login
    if commenter_is_author:
        if user_perms is not None:
            print(
                f"User {user_login} is the PR author and has existing CI permissions."
            )
        else:
            print(
                f"User {user_login} is the PR author (not in CI_PERMISSIONS.json). "
                "Granting CI rerun permissions."
            )
            user_perms = {}
        user_perms["can_rerun_failed_ci"] = True
        user_perms["can_rerun_test"] = True

    if not user_perms:
        print(f"User {user_login} does not have any configured permissions. Exiting.")
        return

    # 4. Parse Command and Execute (only the first line of the comment counts)
    first_line = comment_body.split("\n")[0].strip()

    if first_line.startswith("/tag-run-ci-label"):
        handle_tag_run_ci(repo, pr, comment, user_perms)
        return

    if first_line.startswith("/rerun-failed-ci"):
        handle_rerun_failed_ci(repo, pr, comment, user_perms)
        return

    if first_line.startswith("/tag-and-rerun-ci"):
        # Run both actions with their own reactions suppressed; a single
        # reaction is added below if either of them did something.
        print("Processing combined command: /tag-and-rerun-ci")
        tagged = handle_tag_run_ci(
            repo, pr, comment, user_perms, react_on_success=False
        )
        if tagged:
            # Give the label time to propagate before triggering the rerun.
            print("Waiting 5 seconds for label to propagate...")
            time.sleep(5)
        rerun = handle_rerun_failed_ci(
            repo, pr, comment, user_perms, react_on_success=False
        )
        if not (tagged or rerun):
            print("Combined command finished, but no actions were taken.")
        else:
            comment.create_reaction("+1")
            print("Combined command processed successfully; reaction added.")
        return

    if first_line.startswith("/rerun-stage"):
        # Everything after the command word is the stage name (None if absent).
        pieces = first_line.split(maxsplit=1)
        if len(pieces) > 1:
            stage_name = pieces[1].strip()
        else:
            stage_name = None
        handle_rerun_stage(repo, pr, comment, user_perms, stage_name, token)
        return

    if first_line.startswith("/rerun-test"):
        # Whitespace-separated words after the command are the test specs.
        test_specs = first_line.split()[1:]
        handle_rerun_test(repo, pr, comment, user_perms, test_specs or None, token)
        return

    print(f"Unknown or ignored command: {first_line}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,55 @@
# SGLang CI failure monitoring
Scripts used by [.github/workflows/ci-failure-monitor.yml](../../.github/workflows/ci-failure-monitor.yml): scheduled failure analysis and optional Slack notifications.
## Tools
1. **Failures Analyzer** (`ci_failures_analysis.py`): Tracks consecutive failures, identifies flaky jobs, and monitors runner health across PR Test / Nightly workflows (Nvidia, AMD, Intel, XPU, NPU).
2. **Slack poster** (`post_ci_failures_to_slack.py`): Sends a condensed summary from a failure-analysis JSON to Slack (invoked by the workflow when `SGLANG_DIFFUSION_SLACK_TOKEN` is set).
## Installation
```bash
pip install requests slack_sdk
```
(`slack_sdk` is only required for `post_ci_failures_to_slack.py`.)
## Usage
### Failures Analyzer
```bash
export GITHUB_TOKEN="your_token_here"
python ci_failures_analysis.py --token $GITHUB_TOKEN --limit 50 --threshold 2
python ci_failures_analysis.py --token $GITHUB_TOKEN --limit 300 --threshold 2
python ci_failures_analysis.py --token $GITHUB_TOKEN --limit 500 --threshold 3
```
### Slack notifications
From the `scripts/ci_monitor` directory, after generating a report:
```bash
export SGLANG_DIFFUSION_SLACK_TOKEN="xoxb-..."
python post_ci_failures_to_slack.py --report-file ci_failure_analysis_YYYYMMDD_HHMMSS.json
```
## Token permissions
The GitHub token needs `repo` and `workflow` scopes to read CI run data; otherwise API calls may return 404.
### Failures Analyzer parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `--token` | Required | GitHub Personal Access Token |
| `--limit` | 500 | Number of workflow runs to analyze |
| `--threshold` | 3 | Alert threshold for consecutive failures |
| `--output` | None | Output JSON file (optional) |
## Historical note
The former **CI Monitor** workflow (`ci-monitor.yml`) and its analyzers (`ci_analyzer.py`, `ci_analyzer_perf.py`, `ci_analyzer_balance.py`) were removed as redundant; use this failure monitor workflow and scripts for ongoing CI health alerts.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,274 @@
#!/usr/bin/env python3
"""
Post CI failure analysis results to Slack.
This is a standalone script that doesn't depend on sglang package installation.
"""
import argparse
import json
import logging
import os
import sys
from datetime import datetime
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Hardware families in display order (Nvidia first).
_HARDWARE_ORDER = ["Nvidia", "AMD", "Intel", "XPU", "NPU"]

# Test types in display order (PR Test before Nightly).
_TEST_TYPE_ORDER = ["PR Test", "Nightly"]

# Report section key -> (display_name, hardware, test_type_order).
# test_type_order: 0 = PR Test, 1 = Nightly (so PR Test sorts first).
_WORKFLOW_INFO_MAP = {
    # Nvidia
    "pr_test_nvidia_scheduled_data": ("PR Test", "Nvidia", 0),
    "nightly_nvidia_scheduled_data": ("Nightly", "Nvidia", 1),
    # AMD
    "pr_test_amd_scheduled_data": ("PR Test", "AMD", 0),
    "nightly_amd_scheduled_data": ("Nightly", "AMD", 1),
    # Intel/Xeon
    "pr_test_xeon_scheduled_data": ("PR Test", "Intel", 0),
    "nightly_intel_scheduled_data": ("Nightly", "Intel", 1),
    # XPU
    "pr_test_xpu_scheduled_data": ("PR Test", "XPU", 0),
    # NPU
    "pr_test_npu_scheduled_data": ("PR Test", "NPU", 0),
    "nightly_npu_scheduled_data": ("Nightly", "NPU", 1),
}


def _slack_link(url, label):
    """Format ``label`` as a Slack link, or as plain text when ``url`` is empty."""
    # "<|label>" (empty URL) is not a valid Slack link, so fall back to text.
    return f"<{url}|{label}>" if url else label


def _collect_critical_failures(report_data, min_streak=2):
    """Return one record per mapped scheduled job whose streak is >= ``min_streak``."""
    critical_failures = []
    for workflow_key, workflow_data in report_data.items():
        # Skip non-workflow keys (summary, limits, etc.): a workflow section is
        # a dict with at least one job dict carrying "current_streak".
        if not isinstance(workflow_data, dict) or not any(
            isinstance(v, dict) and "current_streak" in v
            for v in workflow_data.values()
        ):
            continue
        # Only process scheduled workflows that are in the map.
        if workflow_key not in _WORKFLOW_INFO_MAP:
            continue

        test_type, hardware, test_order = _WORKFLOW_INFO_MAP[workflow_key]
        for job_name, job_data in workflow_data.items():
            if not isinstance(job_data, dict):
                continue
            current_streak = job_data.get("current_streak", 0)
            if current_streak < min_streak:
                continue
            # Missing or null failure records are treated as "no information".
            first_failure = job_data.get("first_failure_in_streak", {}) or {}
            last_failure = job_data.get("last_failure_in_streak", {}) or {}
            critical_failures.append(
                {
                    "hardware": hardware,
                    "test_type": test_type,
                    "test_order": test_order,
                    "job_name": job_name,
                    "consecutive_failures": current_streak,
                    "first_failed_at": first_failure.get("created_at", "unknown"),
                    "first_failed_url": first_failure.get("job_url", ""),
                    "last_failed_at": last_failure.get("created_at", "unknown"),
                    "last_failed_url": last_failure.get("job_url", ""),
                }
            )
    return critical_failures


def _group_jobs_by_hardware(critical_failures):
    """Group job names as ``{hardware: {test_type: [job_name, ...]}}``."""
    hardware_jobs = {}
    for job in critical_failures:
        by_type = hardware_jobs.setdefault(job["hardware"], {})
        by_type.setdefault(job["test_type"], []).append(job["job_name"])
    return hardware_jobs


def _build_summary(hardware_jobs, workflow_url):
    """Return ``(text, attachment_color)`` for the parent Slack message."""
    if not hardware_jobs:
        summary = "✅ No critical failures detected in scheduled runs"
        if workflow_url:
            summary += f"\n<{workflow_url}|View CI Failure Monitor run>"
        return summary, "good"

    # Ping relevant people when there are failures
    mentions = "<@U09R55D8EAY> <@U09ABMCKQPM>"
    summary_lines = [f"{mentions} 🚨 *CI Critical Failures (Scheduled Runs)*"]
    for hardware in _HARDWARE_ORDER:
        if hardware not in hardware_jobs:
            continue
        summary_lines.append(f"\n*{hardware}:*")
        for test_type in _TEST_TYPE_ORDER:
            if test_type not in hardware_jobs[hardware]:
                continue
            job_list = ", ".join(hardware_jobs[hardware][test_type])
            summary_lines.append(f"{test_type}: {job_list}")
    if workflow_url:
        summary_lines.append(
            f"\n<{workflow_url}|View full CI Failure Monitor report>"
        )
    return "\n".join(summary_lines), "danger"


def _build_details(critical_failures):
    """Return the threaded breakdown text, sorted by hardware, test type, job."""
    hardware_rank = {hw: i for i, hw in enumerate(_HARDWARE_ORDER)}
    sorted_failures = sorted(
        critical_failures,
        key=lambda job: (
            hardware_rank.get(job["hardware"], 99),
            job["test_order"],
            job["job_name"],
        ),
    )

    details_lines = ["*Detailed Failure Breakdown*\n"]
    current_hardware = None
    for job in sorted_failures:
        # Emit a section header whenever the hardware family changes.
        if job["hardware"] != current_hardware:
            current_hardware = job["hardware"]
            details_lines.append(f"\n*━━━ {current_hardware} ━━━*")
        first_failed = _slack_link(job["first_failed_url"], job["first_failed_at"])
        last_failed = _slack_link(job["last_failed_url"], job["last_failed_at"])
        details_lines.append(
            f"• *{job['test_type']}* → `{job['job_name']}`\n"
            f" Consecutive failures: {job['consecutive_failures']}\n"
            f" First failed: {first_failed}\n"
            f" Last failed: {last_failed}\n"
        )
    return "\n".join(details_lines)


def post_ci_failures_to_slack(report_file: str) -> bool:
    """
    Post CI failure report to Slack with threaded details.

    Creates a parent message with summary (workflow: job1, job2, ...)
    and a threaded reply with detailed failure information.

    Reads the Slack token from ``SGLANG_DIFFUSION_SLACK_TOKEN`` and, when set,
    ``GITHUB_RUN_ID`` to link back to the workflow run.

    Args:
        report_file: Path to JSON file containing failure analysis from ci_failures_analysis.py

    Returns:
        bool: True if successful, False otherwise (missing token, unreadable
        report, missing ``slack_sdk`` or any Slack API error; errors are logged).
    """
    try:
        from slack_sdk import WebClient

        token = os.environ.get("SGLANG_DIFFUSION_SLACK_TOKEN")
        if not token:
            logger.info("Slack post failed: no token")
            return False

        # CI failures channel
        channel_id = "C0A2DG0R7CJ"

        # Get GitHub run ID for linking to the workflow run
        run_id = os.environ.get("GITHUB_RUN_ID", "")

        # Load report data
        with open(report_file, "r") as f:
            report_data = json.load(f)

        client = WebClient(token=token)

        critical_failures = _collect_critical_failures(report_data)
        hardware_jobs = _group_jobs_by_hardware(critical_failures)

        workflow_url = ""
        if run_id:
            workflow_url = (
                f"https://github.com/sgl-project/sglang/actions/runs/{run_id}"
            )
        summary, color = _build_summary(hardware_jobs, workflow_url)

        # Post parent message
        response = client.chat_postMessage(
            channel=channel_id,
            text=summary,
            attachments=[
                {
                    "color": color,
                    "footer": "SGLang CI Failure Monitor",
                    "footer_icon": "https://github.githubassets.com/images/modules/logos_page/GitHub-Mark.png",
                    "ts": int(datetime.now().timestamp()),
                }
            ],
        )

        # If there are failures, post detailed breakdown in thread
        if hardware_jobs:
            client.chat_postMessage(
                channel=channel_id,
                thread_ts=response["ts"],
                text=_build_details(critical_failures),
            )

        logger.info("CI failure report posted to Slack successfully")
        return True
    except Exception as e:
        logger.error(f"Failed to post CI failures to Slack: {e}")
        return False
def main():
    """CLI entry point: post the report given by ``--report-file``; exit 0 on success, 1 on failure."""
    arg_parser = argparse.ArgumentParser(
        description="Post CI failure analysis results to Slack"
    )
    arg_parser.add_argument(
        "--report-file",
        type=str,
        required=True,
        help="Path to CI failure analysis JSON report",
    )
    cli_args = arg_parser.parse_args()

    posted = post_ci_failures_to_slack(cli_args.report_file)
    if posted:
        sys.exit(0)
    sys.exit(1)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,483 @@
"""
List commits in the private repo that need to be synced to the OSS repo.
NOTE:
1. This script resolves the git root automatically and can be run anywhere
inside the repo.
This script will:
1. Find the most recent sync commit (message starts with
"[Automated PR] Copy OSS code from commit").
2. Scan commits after that point and keep those that touch the configured paths.
3. Compare added diff lines in relevant files against OSS main.
4. Print a markdown summary with commit links and write it to GitHub Step Summary.
Usage:
python3 scripts/code_sync/check_commits.py
"""
import argparse
import os
import shutil
import subprocess
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple
# Allow sibling imports regardless of the working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import ( # noqa: E402
FOLDER_NAMES,
get_last_sync_commit,
write_github_step_summary,
)
# --- Configuration Begin ---
private_repo = "your-org/sglang-private-repo"
oss_repo_url = "https://github.com/sgl-project/sglang.git"
oss_repo_branch = "main"
default_oss_repo_dir = ".oss_repo"
# --- Configuration End ---
@dataclass
class CommitInfo:
    """Per-commit record rendered in the "Commits to sync" section of the report."""

    # Commit hash as listed by `git rev-list`.
    commit_hash: str
    # Commit subject line.
    subject: str
    # Author date in short form (git `--date=short`).
    commit_date: str
    # Files changed by the commit that fall under the configured sync paths.
    relevant_files: List[str]
    # Number of added lines that also appear in the OSS copy of the same file.
    synced_lines: int
    # Total number of lines added by the commit in the relevant files.
    total_added_lines: int
def check_dependencies() -> None:
    """Verify that ``git`` is on PATH; raise EnvironmentError if it is not."""
    git_executable = shutil.which("git")
    if git_executable:
        return
    raise EnvironmentError("git is not installed or not in PATH.")
def get_repo_root() -> str:
    """
    Return the absolute path of the enclosing git work tree.

    Runs ``git rev-parse --show-toplevel`` in the current working directory.

    Raises:
        RuntimeError: if git fails or prints nothing.
    """
    try:
        completed = subprocess.run(
            ["git", "rev-parse", "--show-toplevel"],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Unable to determine git repo root: {e.stderr or e}") from e

    top_level = completed.stdout.strip()
    if top_level:
        return os.path.abspath(top_level)
    raise RuntimeError("Unable to determine git repo root.")
def get_repo_from_origin(repo_root: str) -> str:
    """
    Try to infer the repo slug (owner/name) from git remote.origin.url.

    Recognizes GitHub remotes in scp-like (``git@github.com:owner/name``),
    HTTPS (``https://github.com/owner/name``) and SSH-URL
    (``ssh://git@github.com/owner/name``) form. A trailing ``/`` and a
    trailing ``.git`` are removed from the slug.

    Args:
        repo_root: Directory in which ``git config`` is run.

    Returns:
        The ``owner/name`` slug, or the module-level ``private_repo`` default
        when the remote is missing, is not a recognized GitHub URL, or yields
        an empty slug.
    """
    try:
        url = subprocess.run(
            ["git", "config", "--get", "remote.origin.url"],
            capture_output=True,
            text=True,
            check=True,
            cwd=repo_root,
        ).stdout.strip()
    except subprocess.CalledProcessError:
        return private_repo

    github_prefixes = (
        "git@github.com:",
        "https://github.com/",
        "ssh://git@github.com/",
    )
    for prefix in github_prefixes:
        if url.startswith(prefix):
            repo = url[len(prefix) :]
            break
    else:
        return private_repo

    # Normalize "owner/name/", "owner/name.git" and "owner/name.git/" alike so
    # the slug can be embedded in commit URLs without a doubled slash.
    repo = repo.rstrip("/")
    if repo.endswith(".git"):
        repo = repo[: -len(".git")]
    return repo or private_repo
def get_default_oss_repo_path(repo_root: str) -> str:
    """
    Return the absolute path where the OSS clone is expected.

    ``$OSS_REPO_PATH`` wins when set and non-empty; otherwise the path is
    ``default_oss_repo_dir`` inside ``repo_root``.
    """
    override = os.environ.get("OSS_REPO_PATH")
    chosen = override if override else os.path.join(repo_root, default_oss_repo_dir)
    return os.path.abspath(chosen)
def ensure_oss_repo(oss_repo_path: str, repo_url: str, branch: str) -> str:
    """
    Make sure a clone of the OSS repo with an up-to-date ``origin/<branch>`` exists.

    If ``oss_repo_path`` already contains a ``.git`` directory, the clone is
    reused and ``branch`` is fetched from its ``origin`` remote (``repo_url``
    is not consulted in that case). Otherwise a shallow clone of ``repo_url``
    at ``branch`` is created.

    Args:
        oss_repo_path: Location of the clone (made absolute).
        repo_url: URL to clone from when no clone exists yet.
        branch: Branch to fetch or clone.

    Returns:
        The absolute path of the clone.

    Raises:
        RuntimeError: if the path exists but is not a directory, or has a
            ``.git`` directory that git does not accept as a work tree.
        subprocess.CalledProcessError: if the fetch or clone fails.
    """
    oss_repo_path = os.path.abspath(oss_repo_path)
    if os.path.exists(oss_repo_path) and not os.path.isdir(oss_repo_path):
        raise RuntimeError(f"OSS repo path is not a directory: {oss_repo_path}")

    if os.path.isdir(os.path.join(oss_repo_path, ".git")):
        try:
            subprocess.run(
                ["git", "-C", oss_repo_path, "rev-parse", "--is-inside-work-tree"],
                capture_output=True,
                text=True,
                check=True,
            )
        except subprocess.CalledProcessError as e:
            raise RuntimeError(
                f"OSS repo path exists but is not a git repo: {oss_repo_path}"
            ) from e
        # Fetch with an explicit refspec so refs/remotes/origin/<branch> is
        # always created/updated, even when the existing clone was made
        # single-branch for a different branch and its configured fetch
        # refspec does not cover this one.
        subprocess.run(
            [
                "git",
                "-C",
                oss_repo_path,
                "fetch",
                "--depth",
                "1",
                "origin",
                f"+refs/heads/{branch}:refs/remotes/origin/{branch}",
            ],
            check=True,
        )
        return oss_repo_path

    parent_dir = os.path.dirname(oss_repo_path)
    if parent_dir and not os.path.isdir(parent_dir):
        os.makedirs(parent_dir, exist_ok=True)
    subprocess.run(
        ["git", "clone", "--depth", "1", "--branch", branch, repo_url, oss_repo_path],
        check=True,
    )
    return oss_repo_path
def get_commits_since(repo_root: str, last_sync_hash: Optional[str]) -> List[str]:
    """
    List commit hashes from the last sync commit (exclusive) up to HEAD.

    Without a sync commit, every commit reachable from HEAD is listed. The
    order is whatever ``git rev-list`` prints. On a git failure the error is
    printed and an empty list is returned.
    """
    revision = f"{last_sync_hash}..HEAD" if last_sync_hash else "HEAD"
    try:
        completed = subprocess.run(
            ["git", "rev-list", revision],
            capture_output=True,
            text=True,
            check=True,
            cwd=repo_root,
        )
    except subprocess.CalledProcessError as e:
        print(f"Error getting commit list: {e.stderr}")
        return []

    hashes = []
    for entry in completed.stdout.strip().split("\n"):
        if entry:
            hashes.append(entry)
    return hashes
def get_changed_files(repo_root: str, commit_hash: str) -> List[str]:
    """
    Return the paths reported by ``git diff-tree --name-only -r`` for a commit.

    On a git failure the error is printed and an empty list is returned.
    """
    diff_tree_cmd = [
        "git",
        "diff-tree",
        "--no-commit-id",
        "--name-only",
        "-r",
        commit_hash,
    ]
    try:
        completed = subprocess.run(
            diff_tree_cmd,
            capture_output=True,
            text=True,
            check=True,
            cwd=repo_root,
        )
    except subprocess.CalledProcessError as e:
        print(f"Error getting changed files for {commit_hash}: {e.stderr}")
        return []

    return list(filter(None, completed.stdout.strip().split("\n")))
def is_relevant_path(changed_file: str, path_prefix: str) -> bool:
    """Return True if ``changed_file`` is ``path_prefix`` itself or lies beneath it."""
    # Require a "/" after the prefix so "pythonx/a.py" does not match "python".
    return changed_file == path_prefix or changed_file.startswith(path_prefix + "/")
def get_relevant_files(changed_files: List[str]) -> List[str]:
    """Keep, in order, the changed files that fall under any path in ``FOLDER_NAMES``."""
    kept: List[str] = []
    for candidate in changed_files:
        for sync_path in FOLDER_NAMES:
            if is_relevant_path(candidate, sync_path):
                kept.append(candidate)
                break
    return kept
def get_added_lines_by_file(
    repo_root: str, commit_hash: str, relevant_files: List[str]
) -> Dict[str, List[str]]:
    """
    Collect the lines a commit adds to each of ``relevant_files``.

    Parses ``git show --unified=0`` output restricted to the given files.

    Args:
        repo_root: Directory in which git is run.
        commit_hash: Commit to inspect.
        relevant_files: Paths to report on.

    Returns:
        ``{path: [added line, ...]}`` with one entry per relevant file (an
        empty list if the commit adds nothing to it). An empty dict is
        returned when ``relevant_files`` is empty or git fails (the error is
        printed).
    """
    if not relevant_files:
        return {}

    command = [
        "git",
        "show",
        "--no-color",
        "--unified=0",
        "--format=",
        commit_hash,
        "--",
    ] + relevant_files
    try:
        output = subprocess.run(
            command, capture_output=True, text=True, check=True, cwd=repo_root
        ).stdout
    except subprocess.CalledProcessError as e:
        print(f"Error getting diff for {commit_hash}: {e.stderr}")
        return {}

    added_lines: Dict[str, List[str]] = {path: [] for path in relevant_files}
    relevant_set = set(relevant_files)
    current_file: Optional[str] = None
    # True between a "diff --..." line and the first "@@" hunk marker of that
    # file. "+++ <path>" is a file header only there; inside a hunk a line
    # starting with "+++ " is an added line whose content starts with "++ ".
    in_file_header = False

    for line in output.splitlines():
        if line.startswith("diff --"):
            current_file = None
            in_file_header = True
            continue
        if line.startswith("@@"):
            in_file_header = False
            continue
        if in_file_header:
            if line.startswith("+++ "):
                candidate = line[4:]
                if candidate == "/dev/null":
                    file_path = None
                elif candidate.startswith("b/") or candidate.startswith("a/"):
                    file_path = candidate[2:]
                else:
                    file_path = candidate
                current_file = file_path if file_path in relevant_set else None
            continue
        if current_file and line.startswith("+"):
            added_lines[current_file].append(line[1:])

    return added_lines
def get_oss_file_lines(
    oss_repo_path: str,
    oss_ref: str,
    file_path: str,
    cache: Dict[str, Optional[Set[str]]],
) -> Optional[Set[str]]:
    """
    Return the set of distinct lines of ``file_path`` at ``oss_ref`` in the OSS clone.

    Results are memoized in ``cache`` keyed by ``file_path``. When
    ``git show <ref>:<path>`` fails, ``None`` is cached and returned.
    """
    try:
        return cache[file_path]
    except KeyError:
        pass

    blob_spec = f"{oss_ref}:{file_path}"
    try:
        completed = subprocess.run(
            ["git", "-C", oss_repo_path, "show", blob_spec],
            capture_output=True,
            text=True,
            errors="replace",
            check=True,
        )
    except subprocess.CalledProcessError:
        cache[file_path] = None
        return None

    distinct_lines = set(completed.stdout.splitlines())
    cache[file_path] = distinct_lines
    return distinct_lines
def count_synced_lines(
    added_lines_by_file: Dict[str, List[str]],
    oss_repo_path: str,
    oss_ref: str,
    oss_file_cache: Dict[str, Optional[Set[str]]],
) -> Tuple[int, int]:
    """
    Count how many added lines already exist in the OSS copy of the same file.

    A line counts as synced when its exact text occurs anywhere in the OSS
    file; files that are missing or empty in OSS contribute no synced lines.

    Returns:
        ``(synced_lines, total_added_lines)``.
    """
    synced = 0
    total = 0
    for path, new_lines in added_lines_by_file.items():
        total += len(new_lines)
        if new_lines:
            known_lines = get_oss_file_lines(
                oss_repo_path, oss_ref, path, oss_file_cache
            )
            if known_lines:
                synced += sum(1 for text in new_lines if text in known_lines)
    return synced, total
def get_commit_summary(repo_root: str, commit_hash: str) -> Tuple[str, str]:
    """
    Return (subject, date) for a commit.

    The date is in git's short format. If git fails, the error is printed and
    the placeholders ``"(unknown subject)"`` / ``"(unknown date)"`` are returned.
    """
    # %x00 puts a NUL between subject and date so the subject can hold any text.
    show_cmd = ["git", "show", "-s", "--format=%s%x00%ad", "--date=short", commit_hash]
    try:
        completed = subprocess.run(
            show_cmd,
            capture_output=True,
            text=True,
            check=True,
            cwd=repo_root,
        )
    except subprocess.CalledProcessError as e:
        print(f"Error getting commit subject for {commit_hash}: {e.stderr}")
        return "(unknown subject)", "(unknown date)"

    subject, commit_date = completed.stdout.strip().split("\x00", 1)
    return subject, commit_date
def format_files_list(relevant_files: List[str]) -> str:
    """Render the paths as a markdown bullet list, one ``- path`` per line."""
    bullets = ("- " + file_path for file_path in relevant_files)
    return "\n".join(bullets)
def format_last_sync_block(
    repo: str, subject: str, commit_hash: str, commit_date: str
) -> str:
    """
    Render the "## Last sync" markdown section for the most recent sync commit.

    The commit is shown by its first nine hash characters, linked to its
    GitHub commit page under ``repo``. The result ends with a newline.
    """
    link_text = commit_hash[:9]
    link_target = f"https://github.com/{repo}/commit/{commit_hash}"
    return (
        "## Last sync\n"
        "\n"
        f"#### {subject}\n"
        f"date: {commit_date}\n"
        f"commit: [{link_text}]({link_target})\n"
    )
def format_commit_block(
    repo: str,
    subject: str,
    commit_hash: str,
    commit_date: str,
    relevant_files: List[str],
    synced_lines: int,
    total_added_lines: int,
) -> str:
    """
    Render the markdown section for one commit that may need syncing.

    Args:
        repo: ``owner/name`` slug used to build the commit link.
        subject: Commit subject line.
        commit_hash: Full commit hash (the first nine characters are displayed).
        commit_date: Commit date string.
        relevant_files: Files of the commit that fall under the sync paths.
        synced_lines: Added lines already present in OSS.
        total_added_lines: All added lines in the relevant files.

    Returns:
        A multi-line markdown string ending with a newline. The status line
        starts with ✅ when every added line is synced and ❌ otherwise.
    """
    short_hash = commit_hash[:9]
    commit_url = f"https://github.com/{repo}/commit/{commit_hash}"
    files_str = format_files_list(relevant_files) if relevant_files else "- None"
    # Both branches of this conditional were empty strings before, so the
    # status never showed an icon; give each outcome a distinct marker.
    status_icon = "✅" if synced_lines == total_added_lines else "❌"
    status_line = (
        f"status: {status_icon} {synced_lines}/{total_added_lines} lines synced"
    )
    return "\n".join(
        [
            f"#### {subject}",
            status_line,
            f"date: {commit_date}",
            "files to sync:",
            files_str,
            "",
            f"commit: [{short_hash}]({commit_url})",
            "",
        ]
    )
def format_output(
    repo: str,
    last_sync: Optional[Tuple[str, str, str]],
    commits: List[CommitInfo],
) -> str:
    """
    Build the full markdown report.

    Args:
        repo: ``owner/name`` slug used for commit links.
        last_sync: ``(subject, commit_hash, commit_date)`` of the last sync
            commit, or None when none was found.
        commits: Commits to list under "Commits to sync".

    Returns:
        The report text: a "Last sync" section followed by a "Commits to sync"
        section.
    """
    if last_sync:
        sync_subject, sync_hash, sync_date = last_sync
        sections: List[str] = [
            format_last_sync_block(repo, sync_subject, sync_hash, sync_date)
        ]
    else:
        sections = ["## Last sync", "", "No sync commit found.", ""]

    sections += ["## Commits to sync", ""]

    if not commits:
        sections.append("No commits need to be synced.")
        return "\n".join(sections) + "\n"

    sections += [
        format_commit_block(
            repo,
            info.subject,
            info.commit_hash,
            info.commit_date,
            info.relevant_files,
            info.synced_lines,
            info.total_added_lines,
        )
        for info in commits
    ]
    return "\n".join(sections)
def main() -> None:
    """
    CLI entry point.

    Finds the commits made since the last sync commit, keeps those touching
    the configured paths, measures how many of their added lines already
    exist in the OSS branch, and prints the markdown report (also writing it
    to the GitHub step summary when ``GITHUB_STEP_SUMMARY`` is set).
    """
    parser = argparse.ArgumentParser(
        description="List commits in the private repo that need to be synced to OSS."
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=0,
        help="Limit number of commits printed (0 means no limit).",
    )
    parser.add_argument(
        "--oss-repo-path",
        default=None,
        help="Path to OSS repo clone (default: $OSS_REPO_PATH or .oss_repo).",
    )
    parser.add_argument(
        "--oss-repo-url",
        default=oss_repo_url,
        help="OSS repo URL (default: https://github.com/sgl-project/sglang.git).",
    )
    parser.add_argument(
        "--oss-branch",
        default=oss_repo_branch,
        help="OSS repo branch to check (default: main).",
    )
    args = parser.parse_args()

    check_dependencies()
    repo_root = get_repo_root()
    if args.oss_repo_path:
        oss_repo_path = os.path.abspath(args.oss_repo_path)
    else:
        oss_repo_path = get_default_oss_repo_path(repo_root)
    repo = get_repo_from_origin(repo_root)

    # Describe the last sync commit, if there is one.
    last_sync_hash = get_last_sync_commit(repo_root)
    last_sync_block = None
    if last_sync_hash:
        sync_subject, sync_date = get_commit_summary(repo_root, last_sync_hash)
        last_sync_block = (sync_subject, last_sync_hash, sync_date)

    candidate_hashes = get_commits_since(repo_root, last_sync_hash)
    if args.limit > 0:
        candidate_hashes = candidate_hashes[: args.limit]

    # Keep only commits that touch at least one synced path.
    relevant_commit_inputs: List[Tuple[str, List[str]]] = []
    for commit_hash in candidate_hashes:
        touched = get_relevant_files(get_changed_files(repo_root, commit_hash))
        if touched:
            relevant_commit_inputs.append((commit_hash, touched))

    relevant_commits: List[CommitInfo] = []
    if relevant_commit_inputs:
        # The OSS clone is only prepared when there is something to compare.
        oss_repo_path = ensure_oss_repo(
            oss_repo_path, args.oss_repo_url, args.oss_branch
        )
        oss_ref = f"origin/{args.oss_branch}"
        oss_file_cache: Dict[str, Optional[Set[str]]] = {}
        for commit_hash, touched in relevant_commit_inputs:
            subject, commit_date = get_commit_summary(repo_root, commit_hash)
            synced_lines, total_added_lines = count_synced_lines(
                get_added_lines_by_file(repo_root, commit_hash, touched),
                oss_repo_path,
                oss_ref,
                oss_file_cache,
            )
            relevant_commits.append(
                CommitInfo(
                    commit_hash=commit_hash,
                    subject=subject,
                    commit_date=commit_date,
                    relevant_files=touched,
                    synced_lines=synced_lines,
                    total_added_lines=total_added_lines,
                )
            )

    report = format_output(repo, last_sync_block, relevant_commits)
    print(report)
    if os.environ.get("GITHUB_STEP_SUMMARY"):
        write_github_step_summary(report)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,273 @@
"""
Sync code from OSS repo to the local repo and open a PR if changes exist.
NOTE:
1. You need to execute this script in the git root folder.
2. A GH_TOKEN environment variable is required to create the pull request.
- see also https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens
This script will:
1. Clone the sgl-project/sglang repository (or use a local copy).
2. Sync specified files and directories using rsync.
3. Check if the sync operation resulted in any changes.
4. If there are changes:
a. Create a new branch.
b. Commit and push the changes.
c. Open a pull request using the GitHub CLI (gh).
Usage:
# Run the full sync and PR creation process
python3 scripts/copy_from_oss.py
# Perform a dry run without making any actual changes
python3 scripts/copy_from_oss.py --dry-run
# Use a local directory as the source instead of cloning
python3 scripts/copy_from_oss.py --local-dir ~/projects/sglang
"""
import argparse
import datetime
import os
import shutil
import subprocess
import sys
import tempfile
# Allow sibling imports regardless of the working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import FOLDER_NAMES, write_github_step_summary # noqa: E402
# --- Configuration Begin ---
private_repo = "your-org/sglang-private-repo"
# --- Configuration End ---
def check_dependencies():
    """Verify that ``git`` and ``gh`` are on PATH; raise EnvironmentError for the first one missing."""
    required_tools = (
        ("git", "git is not installed or not in PATH."),
        ("gh", "GitHub CLI (gh) is not installed or not in PATH."),
    )
    for tool, missing_message in required_tools:
        if not shutil.which(tool):
            raise EnvironmentError(missing_message)
    print("✅ All dependencies (git, gh) are available.")
def checkout_main(dry_run):
    """
    Check out ``main`` and hard-reset the work tree.

    Each command is printed; with ``dry_run`` true nothing is executed. A
    failing command has its stderr printed and the error re-raised.
    """
    for cmd in ("git checkout main", "git reset --hard"):
        print(f"Run: {cmd}")
        if dry_run:
            continue
        try:
            subprocess.run(cmd, shell=True, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            print(f"Git command failed: {e.stderr.decode()}")
            raise
    print("✅ Checkout the main branch.")
def get_source_folder(args):
    """
    Prepare the source repository, either by cloning from GitHub or using a local directory.

    Args:
        args: Parsed CLI arguments; only ``args.local_dir`` is read.

    Returns:
        ``(oss_root, temp_dir, commit_hash)``: the path to the source repo
        root, the temporary directory that was created for the clone (None
        when a local directory is used), and the first eight characters of the
        source HEAD commit hash.

    Raises:
        FileNotFoundError: if ``args.local_dir`` does not exist.
        subprocess.CalledProcessError: if cloning or reading HEAD fails. A
            temporary directory created by this call is removed first.
    """
    temp_dir = None
    if args.local_dir:
        oss_root = os.path.expanduser(args.local_dir)
        if not os.path.exists(oss_root):
            raise FileNotFoundError(
                f"Specified local directory {oss_root} does not exist."
            )
        print(f"Using local directory as the source: {oss_root}")
    else:
        temp_dir = tempfile.mkdtemp()
        oss_root = temp_dir
        print(f"Created temporary directory: {oss_root}")

    try:
        if temp_dir is not None:
            repo_url = "https://github.com/sgl-project/sglang.git"
            try:
                subprocess.run(
                    [
                        "git",
                        "clone",
                        "--single-branch",
                        "--branch",
                        "main",
                        repo_url,
                        temp_dir,
                    ],
                    check=True,
                    capture_output=True,
                )
                print(f"Successfully cloned repository to {temp_dir}")
            except subprocess.CalledProcessError as e:
                print(f"Error cloning repository: {e.stderr.decode()}")
                raise

        commit_hash = subprocess.run(
            ["git", "-C", oss_root, "rev-parse", "HEAD"],
            capture_output=True,
            text=True,
            check=True,
        ).stdout.strip()[:8]
    except BaseException:
        # The caller never receives temp_dir when we fail here, so it cannot
        # clean it up; remove it now instead of leaking it.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        raise

    print(f"✅ Get source OSS code at commit: {commit_hash}")
    return oss_root, temp_dir, commit_hash
def sync_directories(oss_root, sync_paths, dry_run):
    """
    Sync specified directories from oss_root to current working directory.

    For every entry in ``sync_paths`` this runs
    ``rsync -r --delete <oss_root>/<entry> ./<parent of entry>``.

    Args:
        oss_root: Root of the source checkout.
        sync_paths: Repo-relative paths to copy.
        dry_run: When true, the commands are only printed.

    Raises:
        subprocess.CalledProcessError: if an rsync invocation fails.
    """
    import shlex

    rsync_commands = []
    for folder_name in sync_paths:
        source = f"{oss_root}/{folder_name}"
        # rsync places the source entry inside its parent directory here, so
        # "python/sglang" is synced into "./python" and a top-level entry
        # into "./".
        destination = "./" + "/".join(folder_name.split("/")[:-1])
        # Argument list instead of a shell string: paths containing spaces or
        # shell metacharacters are passed to rsync unchanged.
        rsync_commands.append(["rsync", "-r", "--delete", source, destination])

    for cmd in rsync_commands:
        printable = " ".join(shlex.quote(part) for part in cmd)
        try:
            print(f"Run: {printable}")
            if not dry_run:
                subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            print(f"Error executing command '{printable}': {e}")
            raise

    print("✅ Sync all folders.")
def check_for_changes():
    """
    Check if the work tree differs from HEAD in any way.

    Uses ``git status --porcelain`` so that untracked files count as changes
    too: ``git diff --quiet`` only looks at tracked files and would report "no
    changes" when the sync did nothing but add new files.

    Returns:
        bool: True if git reports at least one modified, staged or untracked path.

    Raises:
        subprocess.CalledProcessError: if ``git status`` fails.
    """
    result = subprocess.run(
        ["git", "status", "--porcelain"],
        capture_output=True,
        text=True,
        check=True,
    )
    return bool(result.stdout.strip())
def create_and_push_branch(branch_name, commit_message, dry_run):
    """
    Create a new branch, commit all changes, and force-push it to origin.

    Args:
        branch_name: Name of the branch to create and push.
        commit_message: Message for the single commit that is created.
        dry_run: When true, only print the commands.

    Raises:
        subprocess.CalledProcessError: If any git command fails.
    """
    import shlex  # local import: only needed for logging the commands

    # Each command is an argv list, so the branch name and commit message are
    # passed to git verbatim instead of being re-parsed by a shell.
    commands = [
        ["git", "checkout", "-b", branch_name],
        ["git", "config", "user.name", "github-actions[bot]"],
        [
            "git",
            "config",
            "user.email",
            "github-actions[bot]@users.noreply.github.com",
        ],
        ["git", "add", "."],
        ["git", "commit", "-m", commit_message],
        ["git", "push", "origin", branch_name, "--force"],
    ]
    print("\nCreating and pushing git branch...")
    for cmd in commands:
        # shlex.join renders a copy-pasteable shell line for the log.
        print(f"Run: {shlex.join(cmd)}")
        if not dry_run:
            try:
                subprocess.run(cmd, check=True, capture_output=True)
            except subprocess.CalledProcessError as e:
                print(f"Git command failed: {e.stderr.decode(errors='replace')}")
                raise
def create_pull_request(branch_name, title, body, dry_run):
    """
    Open a pull request from ``branch_name`` into ``main`` with the GitHub CLI.

    The target repository is the module-level ``private_repo`` value. Without a
    ``GH_TOKEN`` environment variable a warning is printed and, unless this is
    a dry run, nothing else happens. On success the PR URL printed by ``gh`` is
    echoed and appended to the GitHub step summary.

    Raises:
        subprocess.CalledProcessError: If ``gh pr create`` fails.
    """
    token = os.getenv("GH_TOKEN")
    if not token:
        print(
            "\n⚠️ Warning: GH_TOKEN environment variable not set. Skipping PR creation."
        )
        if not dry_run:
            return

    print("\nCreating pull request...")
    gh_args = ["gh", "pr", "create"]
    gh_args += ["--base", "main"]
    gh_args += ["--head", branch_name]
    gh_args += ["--repo", private_repo]
    gh_args += ["--title", title]
    gh_args += ["--body", body]
    print(f"Run: {' '.join(gh_args)}")
    if dry_run:
        return

    # Hand the token to gh explicitly through the child environment.
    child_env = os.environ.copy()
    child_env["GH_TOKEN"] = token
    try:
        completed = subprocess.run(
            gh_args, check=True, capture_output=True, text=True, env=child_env
        )
    except subprocess.CalledProcessError as e:
        print(f"Error creating pull request: {e.stderr}")
        raise
    else:
        pr_url = completed.stdout.strip()
        msg = f"✅ Successfully created pull request: {pr_url}"
        print(msg)
        write_github_step_summary(msg)
def main():
    """
    Command-line entry point: copy the synced folders from the OSS repo into
    the current checkout and, when that produces changes, push a branch and
    open a pull request. A temporary clone, if one was created, is removed at
    the end regardless of the outcome.
    """
    cli = argparse.ArgumentParser(
        description="Copy code from OSS and open a PR if changes are detected."
    )
    cli.add_argument(
        "--local-dir",
        type=str,
        help="Path to local SGLang directory to use instead of cloning from GitHub.",
    )
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="Dry run the script without executing git, rsync, or gh commands.",
    )
    args = cli.parse_args()

    check_dependencies()
    checkout_main(args.dry_run)
    oss_root, temp_dir, oss_commit = get_source_folder(args)

    try:
        # Step 1: mirror the configured folders into the working tree.
        sync_directories(oss_root, FOLDER_NAMES, args.dry_run)

        # Step 2: only open a PR when the mirror actually changed something.
        if check_for_changes():
            print("✅ Changes detected. Proceeding to create a PR.")

            today = datetime.datetime.now().strftime("%Y%m%d")
            branch_name = f"copy-from-oss-{oss_commit}-{today}"
            commit_message = f"Copy OSS code from {oss_commit} on {today}"
            pr_title = f"[Automated PR] Copy OSS code from commit {oss_commit} on {today}"
            pr_body = "".join(
                [
                    f"Copy OSS code from https://github.com/sgl-project/sglang/commit/{oss_commit} on {today}.",
                    "\n\n---\n\n",
                    "*This is an automated PR created by scripts/copy_from_oss.py.*",
                ]
            )

            create_and_push_branch(branch_name, commit_message, args.dry_run)
            create_pull_request(branch_name, pr_title, pr_body, args.dry_run)
        else:
            msg = "😴 No changes detected. The code is already in sync."
            print(msg)
            write_github_step_summary(msg)
    finally:
        # Remove the temporary clone if one was created.
        if temp_dir:
            try:
                shutil.rmtree(temp_dir)
                print(f"\nRemoved temporary directory: {temp_dir}")
            except OSError as e:
                print(f"Error removing temporary directory {temp_dir}: {e}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,591 @@
"""
Sync a specific commit from the local private repo to the OSS upstream and open a PR.
NOTE:
1. You need to execute this script in the git root folder.
2. A GH_TOKEN environment variable is required to create the pull request.
- see also https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens
This script will:
1. Take a commit hash as an argument (or use the latest commit by default).
2. Create a patch for that commit.
3. Filter the patch to only include changes in specified directories.
4. Clone the sgl-project/sglang repository.
5. Create a new branch in the OSS repo.
6. Apply the filtered patch, commit, and force push.
7. Open a pull request to the OSS repo using the GitHub CLI (gh).
Usage:
# Sync the latest commit from the current branch
python3 scripts/copy_to_oss.py
# Run the full sync and PR creation process for a given commit
python3 scripts/copy_to_oss.py --commit <commit_hash>
# Perform a dry run without making any actual changes
python3 scripts/copy_to_oss.py --commit <commit_hash> --dry-run
"""
import argparse
import datetime
import os
import re
import shutil
import subprocess
import sys
import tempfile
# Allow sibling imports regardless of the working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils import ( # noqa: E402
FOLDER_NAMES,
find_latest_oss_sync_commit,
write_github_step_summary,
)
def get_commit_info(commit_ref):
    """
    Look up the full hash and full message of one commit.

    Args:
        commit_ref (str): Anything ``git log`` accepts as a revision, such as a
            commit hash, tag, branch name or ``HEAD``.

    Returns:
        ``(commit_hash, commit_message)`` on success. ``(None, None)`` when
        ``git`` is not installed or the ``git log`` call fails; a diagnostic is
        printed in both cases.
    """
    # %H (hash) and %B (raw body) are separated by a NUL byte (%x00), which
    # cannot appear inside either field, so splitting on it is unambiguous.
    log_cmd = ["git", "log", "-1", "--pretty=%H%x00%B", commit_ref]
    try:
        completed = subprocess.run(
            log_cmd, capture_output=True, text=True, check=True, encoding="utf-8"
        )
    except FileNotFoundError:
        print("❌ Error: 'git' command not found. Is Git installed and in your PATH?")
        return None, None
    except subprocess.CalledProcessError as e:
        print(f"❌ Error getting commit info for '{commit_ref}': {e.stderr.strip()}")
        print(
            "Hint: Make sure you are running this from within a Git repository and the commit exists."
        )
        return None, None

    commit_hash, commit_message = completed.stdout.strip().split("\x00", 1)
    return commit_hash, commit_message
def check_dependencies():
    """
    Verify that the external tools this script shells out to are on PATH.

    Raises:
        EnvironmentError: If ``git`` or ``gh`` cannot be found (``git`` is
            checked first).
    """
    required_tools = (
        ("git", "git is not installed or not in PATH."),
        ("gh", "GitHub CLI (gh) is not installed or not in PATH."),
    )
    for tool, problem in required_tools:
        if not shutil.which(tool):
            raise EnvironmentError(problem)
    print("✅ All dependencies (git, gh) are available.")
def _matches_sync_path(file_path, sync_paths):
    """Return True if file_path equals a sync path or lies inside one of them."""
    for sync_path in sync_paths:
        base = sync_path.rstrip("/")
        # Require a path-component boundary: a bare startswith() would also
        # accept siblings such as "test/srt_private/x.py" for "test/srt" or
        # "README.md.bak" for "README.md" and leak them into the OSS patch.
        if file_path == base or file_path.startswith(base + "/"):
            return True
    return False


def create_filtered_patch(commit_hash, dry_run):
    """
    Create a patch file for the given commit, containing only changes to the
    files and directories listed in ``FOLDER_NAMES``.

    Args:
        commit_hash: The commit to export.
        dry_run: Accepted for signature parity with the other steps; the patch
            is a local temporary file and is created in dry-run mode as well.

    Returns:
        ``(patch_path, relevant_files)`` where ``patch_path`` is a temporary
        ``.patch`` file the caller must delete, or ``(None, None)`` when the
        commit touches nothing under the synced paths.

    Raises:
        subprocess.CalledProcessError: If a git command fails.
    """
    print(f"Creating a filtered patch for commit {commit_hash}")
    try:
        # Get the list of all files changed in the commit
        changed_files_raw = subprocess.run(
            ["git", "diff-tree", "--no-commit-id", "--name-only", "-r", commit_hash],
            capture_output=True,
            text=True,
            check=True,
        ).stdout
        changed_files = [line for line in changed_files_raw.splitlines() if line]

        # Keep only files that live under one of the synced paths
        relevant_files = [
            f for f in changed_files if _matches_sync_path(f, FOLDER_NAMES)
        ]
        if not relevant_files:
            msg = "\n😴 No relevant file changes found in this commit. Exiting."
            print(msg)
            write_github_step_summary(msg)
            return None, None

        print("Found relevant changes in the following files:")
        for f in relevant_files:
            print(f" - {f}")

        # Create a patch containing only the changes for the relevant files
        patch_command = [
            "git",
            "format-patch",
            "--stdout",
            f"{commit_hash}^..{commit_hash}",
            "--",
        ] + relevant_files
        print(f"Run: {' '.join(patch_command)}")
        # Decode as UTF-8 so the text matches the encoding used to write the
        # patch file below, independent of the process locale.
        patch_content = subprocess.run(
            patch_command,
            capture_output=True,
            text=True,
            check=True,
            encoding="utf-8",
        ).stdout

        # Save the patch to a temporary file (delete=False: the caller owns it)
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".patch", encoding="utf-8"
        ) as patch_file:
            patch_file.write(patch_content)
        print(f"✅ Filtered patch created successfully at: {patch_file.name}")
        return patch_file.name, relevant_files
    except subprocess.CalledProcessError as e:
        print(f"Error creating patch: {e.stderr}")
        raise
def get_oss_repo(dry_run):
    """
    Clone the OSS repository into a temporary directory.

    The clone URL embeds ``GH_TOKEN`` so that the later push is authenticated.
    The token is never written to the log: the printed command and any git
    error output have it replaced by ``***``.

    Args:
        dry_run: When true, the temporary directory is created and the command
            is printed, but nothing is cloned.

    Returns:
        ``(oss_root, temp_dir)``: the path the repo is (or would be) cloned to
        and the temporary directory that contains it. The caller is
        responsible for removing ``temp_dir``.

    Raises:
        EnvironmentError: If ``GH_TOKEN`` is not set and this is not a dry run.
        subprocess.CalledProcessError: If ``git clone`` fails (the temporary
            directory is removed first).
    """
    gh_token = os.getenv("GH_TOKEN")
    if not gh_token:
        print(
            "⚠️ Warning: GH_TOKEN environment variable not set. Skipping PR creation."
        )
        if not dry_run:
            # A bare `return` here made the caller fail with an opaque
            # "cannot unpack NoneType" error; fail with a clear message.
            raise EnvironmentError(
                "GH_TOKEN environment variable is required to clone and push "
                "to the OSS repository."
            )

    temp_dir = tempfile.mkdtemp()
    oss_root = os.path.join(temp_dir, "sglang")
    print(f"\nCreated temporary directory for OSS repo: {temp_dir}")

    repo_url = f"https://{gh_token}@github.com/sgl-project/sglang.git"
    command = ["git", "clone", repo_url, oss_root]
    # Log a redacted copy of the command so the token never reaches the logs.
    redacted_url = "https://***@github.com/sgl-project/sglang.git"
    print(f"Run: {' '.join(['git', 'clone', redacted_url, oss_root])}")

    if not dry_run:
        try:
            subprocess.run(command, check=True, capture_output=True)
            print(f"✅ Successfully cloned repository to {oss_root}")
        except subprocess.CalledProcessError as e:
            stderr_text = e.stderr.decode(errors="replace")
            if gh_token:
                stderr_text = stderr_text.replace(gh_token, "***")
            print(f"Error cloning repository: {stderr_text}")
            shutil.rmtree(temp_dir)
            raise
    return oss_root, temp_dir
def _apply_patch(patch_file, dry_run):
    """
    Apply ``patch_file`` to the repository in the current working directory.

    A plain ``git apply`` is tried first; if that fails, ``git apply --3way``
    is tried. When the three-way attempt also reports a failure, diagnostics
    (``git apply --check --verbose`` output, ``git diff`` and the full patch)
    are printed and written to the GitHub step summary.

    Args:
        patch_file: Path of the patch to apply.
        dry_run: When true, only the first command is printed.

    Returns:
        True if the patch applied (directly or via ``--3way``) or this is a dry
        run; False if both attempts failed.
    """

    def capture(cmd):
        """Run cmd without raising, returning the completed process."""
        return subprocess.run(cmd, capture_output=True, text=True)

    # --- Attempt 1: clean git apply ---
    plain_cmd = ["git", "apply", patch_file]
    print(f"Run: {' '.join(plain_cmd)}")
    if dry_run:
        return True

    first_try = capture(plain_cmd)
    if first_try.returncode == 0:
        print("✅ Patch applied cleanly.")
        return True
    print(f"⚠️ Clean apply failed:\n{first_try.stderr.strip()}")
    print("Falling back to git apply --3way ...\n")

    # --- Attempt 2: three-way merge ---
    merge_cmd = ["git", "apply", "--3way", patch_file]
    print(f"Run: {' '.join(merge_cmd)}")
    second_try = capture(merge_cmd)
    if second_try.returncode == 0:
        print("✅ Patch applied via --3way merge (no conflicts).")
        return True

    # --- Both attempts failed: collect diagnostics ---
    print(f"⚠️ --3way merge had conflicts:\n{second_try.stderr.strip()}\n")

    # Ask git which hunks do not apply.
    verify_cmd = ["git", "apply", "--check", "--verbose", patch_file]
    print(f"Run: {' '.join(verify_cmd)}")
    verify = capture(verify_cmd)
    conflict_details = (verify.stdout + verify.stderr).strip()
    print(
        f"\n--- Conflict details ---\n{conflict_details}\n--- End conflict details ---\n"
    )

    # Show whatever the failed --3way attempt left in the working tree.
    diff_text = capture(["git", "diff"]).stdout.strip()
    if diff_text:
        print(
            f"\n--- git diff (conflict markers) ---\n"
            f"{diff_text}\n"
            f"--- End git diff ---\n"
        )

    # Echo the full patch into the CI log so it can be applied by hand.
    with open(patch_file, "r", encoding="utf-8") as handle:
        patch_content = handle.read()
    rule = "=" * 72
    print(
        f"\n{rule}\n"
        f"PATCH CONTENT (apply this manually):\n"
        f"{rule}\n"
        f"{patch_content}\n"
        f"{rule}\n"
    )

    # Same information, formatted as markdown for the step summary.
    summary = "\n## ⚠️ Patch had conflicts — PR created for manual resolution\n"
    summary += "### Conflict details\n"
    summary += f"```\n{conflict_details}\n```\n"
    if diff_text:
        summary += "### git diff (conflict markers)\n"
        summary += f"```diff\n{diff_text}\n```\n"
    summary += "### Patch to apply manually\n"
    summary += (
        "<details><summary>Click to expand full patch</summary>\n\n"
        f"```diff\n{patch_content}\n```\n"
        "</details>\n"
    )
    write_github_step_summary(summary)
    return False
def apply_patch_and_push(
    oss_root, patch_file, branch_name, commit_message, base_oss_commit, dry_run
):
    """
    Inside the OSS clone: create ``branch_name``, apply the patch, commit
    everything and force-push the branch to ``origin``.

    Args:
        oss_root: Path of the OSS clone; the process temporarily changes its
            working directory to it (not in dry-run mode).
        patch_file: Patch to apply (see ``_apply_patch``).
        branch_name: Branch to create and push.
        commit_message: Full commit message, passed to git on stdin.
        base_oss_commit: OSS commit to branch from. If falsy, the branch is
            created from whatever is currently checked out.
        dry_run: When true, commands are printed but not executed.

    Returns:
        The result of ``_apply_patch``: True when the patch applied, False when
        it did not (the resulting tree is committed and pushed either way).

    Raises:
        subprocess.CalledProcessError: If any git command run here fails.
    """
    print("\nApplying patch and pushing to OSS repo...")
    starting_dir = os.getcwd()
    if not dry_run:
        os.chdir(oss_root)

    def run_step(cmd):
        """Log cmd and, unless this is a dry run, execute it."""
        print(f"Run: {' '.join(cmd)}")
        if not dry_run:
            subprocess.run(cmd, check=True, capture_output=True, text=True)

    applied_cleanly = True
    try:
        # New branch, optionally rooted at the last synced OSS commit.
        branch_cmd = ["git", "checkout", "-b", branch_name]
        if base_oss_commit:
            branch_cmd.append(base_oss_commit)
        run_step(branch_cmd)

        # Apply the patch (with --3way fallback and diagnostics).
        applied_cleanly = _apply_patch(patch_file, dry_run)

        # Identify the committer and stage everything.
        run_step(["git", "config", "user.name", "github-actions[bot]"])
        run_step(
            [
                "git",
                "config",
                "user.email",
                "github-actions[bot]@users.noreply.github.com",
            ]
        )
        run_step(["git", "add", "."])

        # The (multi-line) message goes through stdin via `-F -`, so it needs
        # no quoting.
        commit_cmd = ["git", "commit", "-F", "-"]
        print(f"Run: {' '.join(commit_cmd)}")
        if not dry_run:
            print(f"Commit Message:\n---\n{commit_message}\n---")
            subprocess.run(
                commit_cmd,
                input=commit_message,
                text=True,
                check=True,
                capture_output=True,
            )

        # Publish the branch.
        run_step(["git", "push", "origin", branch_name, "--force"])
    except subprocess.CalledProcessError as e:
        print(f"Git command failed: {e.stderr}")
        raise
    finally:
        if not dry_run:
            os.chdir(starting_dir)

    if not applied_cleanly:
        print(
            "⚠️ Branch created and pushed with conflict markers. "
            "A PR will be opened for manual resolution."
        )
    else:
        print("✅ Branch created, patch applied cleanly, and pushed successfully.")
    return applied_cleanly
def create_pull_request(oss_root, branch_name, title, body, dry_run):
    """
    Open a pull request from ``branch_name`` into ``main`` of
    sgl-project/sglang with the GitHub CLI, run from inside ``oss_root``.

    Without a ``GH_TOKEN`` environment variable a warning is printed and,
    unless this is a dry run, nothing else happens. If ``gh`` reports that a
    pull request for the branch already exists, that is logged and tolerated.

    Raises:
        subprocess.CalledProcessError: For any other ``gh pr create`` failure.
    """
    token = os.getenv("GH_TOKEN")
    if not token:
        print(
            "⚠️ Warning: GH_TOKEN environment variable not set. Skipping PR creation."
        )
        if not dry_run:
            return

    print("\nCreating pull request...")
    gh_args = ["gh", "pr", "create"]
    gh_args += ["--base", "main"]
    gh_args += ["--head", branch_name]
    gh_args += ["--repo", "sgl-project/sglang"]
    gh_args += ["--title", title]
    gh_args += ["--body", body]
    print(f"Run: {' '.join(gh_args)}")
    if dry_run:
        return

    child_env = os.environ.copy()
    child_env["GH_TOKEN"] = token
    try:
        completed = subprocess.run(
            gh_args,
            check=True,
            capture_output=True,
            text=True,
            env=child_env,
            cwd=oss_root,
        )
    except subprocess.CalledProcessError as e:
        print(f"Error creating pull request: {e.stderr}")
        # An already-open PR for this branch is not treated as a failure.
        already_open = "A pull request for" in e.stderr and "already exists" in e.stderr
        if not already_open:
            raise
        print(" A PR for this branch likely already exists.")
    else:
        msg = f"✅ Successfully created pull request: {completed.stdout.strip()}"
        print(msg)
        write_github_step_summary(msg)
def get_commit_author(commit_hash):
    """
    Return ``(author_name, author_email)`` for ``commit_hash``.

    Each field is read with its own ``git show -s --format=...`` call.

    Raises:
        subprocess.CalledProcessError: If either git call fails (the error is
            printed first).
    """

    def show_field(placeholder):
        """Return one pretty-format field of the commit, stripped."""
        completed = subprocess.run(
            ["git", "show", "-s", f"--format={placeholder}", commit_hash],
            capture_output=True,
            text=True,
            check=True,
        )
        return completed.stdout.strip()

    try:
        author_name = show_field("%an")
        author_email = show_field("%ae")
    except subprocess.CalledProcessError as e:
        print(f"Error getting commit author for {commit_hash}: {e.stderr}")
        raise
    return author_name, author_email
def get_all_co_author_lines(commit_hash, commit_message):
    """
    Build the ``Co-authored-by`` trailers to attach to the synced commit.

    The commit's own author comes first, followed by every
    ``Co-authored-by: Name <email>`` line found in ``commit_message`` in order
    of appearance. Entries are deduplicated on (stripped name, stripped
    lower-cased email); the first occurrence wins.

    Returns:
        A list of unique ``"Co-authored-by: Name <email>"`` strings.
    """
    trailer_re = re.compile(r"^\s*Co-authored-by:\s*(.+?)\s*<([^>]+)>", re.IGNORECASE)

    # 1. Primary author, then 2. trailers already present in the message.
    candidates = [get_commit_author(commit_hash)]
    for line in commit_message.splitlines():
        found = trailer_re.match(line)
        if found:
            candidates.append((found.group(1), found.group(2)))

    seen = set()
    co_author_lines = []
    for name, email in candidates:
        identity = (name.strip(), email.strip().lower())
        if identity in seen:
            continue
        seen.add(identity)
        co_author_lines.append(f"Co-authored-by: {name.strip()} <{email.strip()}>")
    return co_author_lines
def main():
    """
    Command-line entry point: export one private-repo commit as a filtered
    patch, apply it on a new branch of a fresh OSS clone, push the branch and
    open a pull request. The temporary patch file and clone are removed at the
    end regardless of the outcome.
    """
    cli = argparse.ArgumentParser(
        description="Copy a commit from the private repo to OSS and open a PR."
    )
    cli.add_argument(
        "--commit",
        type=str,
        default="LAST",
        help="The commit hash to sync. Defaults to 'LAST' to use the latest commit.",
    )
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="Dry run the script without executing git, rsync, or gh commands.",
    )
    args = cli.parse_args()
    check_dependencies()

    use_latest = args.commit == "LAST"
    commit_hash, original_commit_message = get_commit_info(
        "HEAD" if use_latest else args.commit
    )
    if not commit_hash:
        return  # get_commit_info already printed the reason

    # Report which commit is being processed.
    if use_latest:
        headline = "\n No commit specified. Using the last commit:\n"
    else:
        headline = "\n Using specified commit:\n"
    summary = (
        headline
        + f" - **Hash:** `{commit_hash}`\n"
        + f" - **Message:** {original_commit_message}\n\n"
    )
    print(summary)
    write_github_step_summary(summary)

    short_hash = commit_hash[:8]
    patch_file = None
    temp_dir = None
    try:
        # 1. Create a filtered patch from the local repo
        patch_file, relevant_files = create_filtered_patch(commit_hash, args.dry_run)
        if not patch_file:
            return

        # 2. Get the OSS repo
        oss_root, temp_dir = get_oss_repo(args.dry_run)

        # 3. The private repo's code is based on the most recent OSS commit
        #    that was synced into it, so that commit is the base for the patch.
        base_oss_commit = find_latest_oss_sync_commit()
        if not base_oss_commit:
            print(
                "⚠️ Could not determine latest OSS sync commit. "
                "Falling back to OSS main HEAD."
            )
        else:
            print(f" Will branch from OSS commit {base_oss_commit}")

        # 4. Co-author trailers (primary author + trailers from the message)
        co_author_lines = get_all_co_author_lines(commit_hash, original_commit_message)
        authors_text = "\n".join(co_author_lines)

        # 5. Texts derived from the list of changed files
        file_list_str = "\n".join(f"- {path}" for path in relevant_files)
        filename_list_str = ", ".join(
            path.rsplit("/", 1)[-1] for path in relevant_files
        )
        if len(filename_list_str) > 40:
            filename_list_str = filename_list_str[:40] + "..."
        current_date = datetime.datetime.now().strftime("%Y%m%d")
        pr_title = f"[Auto Sync] Update {filename_list_str} ({current_date})"

        # 6. Branch from the last synced OSS commit, apply the patch and push
        branch_name = f"sync-{short_hash}-{current_date}"
        commit_message = f"{pr_title}\n\n{authors_text}"
        applied_cleanly = apply_patch_and_push(
            oss_root,
            patch_file,
            branch_name,
            commit_message,
            base_oss_commit,
            args.dry_run,
        )

        # 7. PR title and body; both flag conflicts when there were any
        body_parts = [
            f"Sync changes from commit `{short_hash}`.\n",
            f"**Files Changed:**\n{file_list_str}\n",
            f"**Authors:**\n{authors_text}",
        ]
        if not applied_cleanly:
            pr_title = (
                f"[Auto Sync][⚠️ Conflicts] Update {filename_list_str} ({current_date})"
            )
            body_parts.append(
                "\n\n⚠️ **This patch had merge conflicts.** "
                "The branch contains conflict markers that must be resolved manually. "
                "Please check the CI logs for the full patch and conflict details."
            )
        body_parts.append(
            "\n\n---\n\n"
            "*This is an automated PR created by scripts/copy_to_oss.py.*"
        )
        pr_body = "\n".join(body_parts)

        # 8. Create Pull Request
        create_pull_request(oss_root, branch_name, pr_title, pr_body, args.dry_run)
    finally:
        # Cleanup temporary files
        if patch_file and os.path.exists(patch_file):
            os.remove(patch_file)
            print(f"\nRemoved temporary patch file: {patch_file}")
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
            print(f"Removed temporary directory: {temp_dir}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,51 @@
### Sync Code Between OSS and Private Fork
You can use the following principles and tools to sync the code between a private fork and the OSS repo [sgl-project/sglang](https://github.com/sgl-project/sglang/tree/main).
It learns from [Copybara](https://github.com/google/copybara), a tool used at Google for maintaining open-source code synchronization.
## Principles
- The core folders (e.g., `python/sglang/srt`) are 100% mirrored between the private fork and OSS repo.
- The OSS repo is the single source of truth. If one commit changes `python/sglang/srt` in the private repo, the change should be synced to the OSS repo as soon as possible with the action B below.
- The common code (e.g., base classes, well-known techniques in the industry without private secrets) goes to `python/sglang/srt`. The private-specific code (e.g., with private-specific features, confidential info) goes to `python/sglang/private`.
- Anytime you want to make private changes to a file or class under `python/sglang/srt`, duplicate the file and move it under `python/sglang/private`. You can achieve code reuse by importing and inheriting.
## How to sync the code bidirectionally
### Action A: Copy code from OSS to private
- We can run this action: [Open A PR to Copy Code From OSS](https://github.com/sgl-project/sglang/tree/main/.github/workflows/open-pr-copy-from-oss.yml)
- It opens a PR to copy all files under certain folders (e.g., `python/sglang/srt`, `test/srt`, `sgl-kernel`) from the OSS main branch to the private fork.
- Since the OSS repo is the single source of truth, this action copies files and overwrites any changes in the private fork. To prevent the private changes from being overwritten, you need to ensure all private changes are merged into the OSS repo before running this action.
- This action will be run automatically every day and can also be triggered manually.
### Action B: Copy diff from private to OSS
- We can run this action: [Open A PR to Copy Code To OSS](https://github.com/sgl-project/sglang/tree/main/.github/workflows/open-pr-copy-to-oss.yml)
- It opens a PR to apply the diff of one specific commit of the private fork to the OSS main branch. It only picks the changes under certain folders (e.g., `python/sglang/srt`, `test/srt`, `sgl-kernel`) and ignores changes under private folders (e.g., `python/sglang/private`).
- For example, you can have a PR that changes both `python/sglang/srt` and `python/sglang/private/srt`. Once you merge the PR into the private repo, `python/sglang/srt` becomes desynced between the two repos. You need to run this action on your merge commit immediately to open a PR to send your diff to the OSS repo. Then, we need to merge the OSS PR as soon as possible. Once your OSS PR is merged, we can run action A again.
- Action A copies files directly, but Action B applies diff. This is because OSS is the source of truth; action A can just copy files. Action B cannot copy, so it uses diff instead.
- This action currently needs a manual trigger in order to prevent incidental code leaks. One can also consider making it automatic.
## Examples
- If you want to have some private server arguments, you can create a new file `python/sglang/private/server_args.py`. It defines a class that inherits the oss ServerArgs.
```python
import argparse
import dataclasses

from sglang.srt.server_args import ServerArgs as ServerArgsOSS


@dataclasses.dataclass
class ServerArgs(ServerArgsOSS):
    private_flag: str = "foo"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        # Get all public args
        ServerArgsOSS.add_cli_args(parser)
        # Add your private flags
        parser.add_argument(
            "--private-flag",
            type=str,
            default=ServerArgs.private_flag,
        )
```
- Similarly, you can inherit `Engine` and override its fields. You can override `server_args_class` to use your own ServerArgs,
override `init_tokenizer_manager_func` to use your own TokenizerManager, override `run_scheduler_process_func` to use your own scheduler.

View File

@@ -0,0 +1,18 @@
#!/bin/bash
# Install the GitHub CLI (gh) from the cli.github.com apt repository when it
# is not already on PATH. The script writes under /etc/apt and calls apt
# directly, so it needs the privileges to do both.

#######################################
# Register the GitHub CLI apt repository and install the gh package.
# Arguments: none
# Outputs:   wget/apt progress on stdout/stderr
# Returns:   0 on success, non-zero as soon as any step fails
#######################################
install_gh() {
  local keyring=/etc/apt/keyrings/githubcli-archive-keyring.gpg
  local key_url=https://cli.github.com/packages/githubcli-archive-keyring.gpg
  local sources_list=/etc/apt/sources.list.d/github-cli.list
  local arch tmp_key
  local status=0

  # wget is needed to download the signing key.
  if ! command -v wget >/dev/null 2>&1; then
    apt update || return
    apt install wget -y || return
  fi

  mkdir -p -m 755 /etc/apt/keyrings || return

  # Download to a temp file first so a failed download never leaves a
  # truncated keyring behind; the temp file is removed on every path.
  tmp_key=$(mktemp) || return
  wget -nv -O "$tmp_key" "$key_url" \
    && cat -- "$tmp_key" >"$keyring" \
    && chmod go+r "$keyring" \
    || status=$?
  rm -f -- "$tmp_key"
  if (( status != 0 )); then
    return "$status"
  fi

  mkdir -p -m 755 /etc/apt/sources.list.d || return
  arch=$(dpkg --print-architecture) || return
  printf 'deb [arch=%s signed-by=%s] https://cli.github.com/packages stable main\n' \
    "$arch" "$keyring" >"$sources_list" || return

  apt update || return
  apt install gh -y
}

# Check if gh is installed before attempting to install it
if ! command -v gh >/dev/null 2>&1; then
  echo "GitHub CLI not found. Installing now..."
  install_gh
else
  echo "GitHub CLI is already installed. Skipping installation."
fi

View File

@@ -0,0 +1,136 @@
"""
Shared constants and helpers for code-sync scripts.
"""
import os
import re
import subprocess
from typing import Optional
# --- Configuration Begin ---
# List of folders and files to copy to / from the OSS repo.
# Changes outside these paths will be ignored.
# Entries are repository-relative paths written with "/" separators and no
# trailing slash. The copy-from-OSS script rsyncs each entry into the current
# checkout; the copy-to-OSS script keeps only changed files whose path starts
# with one of these entries.
FOLDER_NAMES = [
"3rdparty",
"assets",
"benchmark",
"docker",
"docs",
"examples",
"python/sglang/lang",
"python/sglang/jit_kernel",
"python/sglang/srt",
"python/sglang/test",
"python/sglang/utils.py",
"python/sglang/README.md",
"sgl-kernel",
"test/manual",
"test/registered",
"test/srt",
"test/README.md",
"test/run_suite.py",
"README.md",
]
# Regular-expression source (note the escaped brackets) for the subject line
# of automated "copy from OSS" commits. It is passed to `git log --grep` and
# also compiled with `re` by the helpers below, so it must stay valid for both.
SYNC_COMMIT_PREFIX = r"\[Automated PR\] Copy OSS code from commit"
# --- Configuration End ---
def write_github_step_summary(content: str) -> None:
    """
    Append *content* to the GitHub Actions step summary (no-op outside CI).

    The target file is taken from the ``GITHUB_STEP_SUMMARY`` environment
    variable; when it is unset or empty nothing is written.

    Args:
        content: Markdown text to append verbatim (no newline is added).
    """
    summary_path = os.environ.get("GITHUB_STEP_SUMMARY")
    if not summary_path:
        return
    # The summaries written by the sync scripts contain emoji, so pin the
    # encoding instead of relying on the process locale, which may be ASCII.
    with open(summary_path, "a", encoding="utf-8") as f:
        f.write(content)
def get_last_sync_commit(repo_root: Optional[str] = None) -> Optional[str]:
    """
    Find the first commit listed by ``git log --all`` whose subject starts with
    the sync prefix, i.e. the latest "copy from OSS" commit.

    ``git log --grep`` matches anywhere in the commit message, so every
    candidate's subject is re-checked here; commits that merely mention the
    phrase in their body are skipped.

    Args:
        repo_root: Directory to run git in; ``None`` means the current
            working directory.

    Returns:
        The full private-repo commit hash, or ``None`` if there is no such
        commit or ``git log`` fails (the error is printed).
    """
    subject_re = re.compile("^" + SYNC_COMMIT_PREFIX)
    log_cmd = [
        "git",
        "log",
        "--all",
        "--grep",
        SYNC_COMMIT_PREFIX,
        "--format=%H %s",
    ]
    try:
        log_output = subprocess.run(
            log_cmd,
            capture_output=True,
            text=True,
            check=True,
            cwd=repo_root,
        ).stdout.strip()
    except subprocess.CalledProcessError as e:
        print(f"Error finding last sync commit: {e.stderr}")
        return None

    for entry in log_output.splitlines():
        # Each entry is "<full_hash> <subject>"; skip anything without a space.
        commit_hash, separator, subject = entry.partition(" ")
        if not separator:
            continue
        if subject_re.search(subject):
            return commit_hash
    return None
def find_latest_oss_sync_commit(repo_root: Optional[str] = None) -> Optional[str]:
    """
    Return the OSS commit hash embedded in the latest sync commit's subject.

    Scans ``git log --all`` for the first commit whose **subject** looks like
    "[Automated PR] Copy OSS code from commit {commit_id} on {date}" and
    extracts ``{commit_id}`` from it.

    Args:
        repo_root: Directory to run git in; ``None`` means the current
            working directory.

    Returns:
        The OSS commit hash exactly as written in the subject, or ``None`` if
        no such commit exists or ``git log`` fails (a message is printed).
    """
    embedded_hash_re = re.compile("^" + SYNC_COMMIT_PREFIX + r" ([0-9a-f]+)")
    # --grep matches against the whole commit message, so only subjects are
    # requested and each one is validated against the pattern here.
    log_cmd = [
        "git",
        "log",
        "--all",
        "--grep",
        SYNC_COMMIT_PREFIX,
        "--pretty=%s",
    ]
    try:
        completed = subprocess.run(
            log_cmd,
            capture_output=True,
            text=True,
            check=True,
            cwd=repo_root,
        )
    except subprocess.CalledProcessError as e:
        print(f"Error searching for OSS sync commits: {e.stderr.strip()}")
        return None

    for subject in completed.stdout.strip().splitlines():
        found = embedded_hash_re.search(subject)
        if not found:
            continue
        oss_commit = found.group(1)
        print(f"✅ Latest OSS sync commit found: {oss_commit} (from: {subject})")
        return oss_commit

    print("⚠️ No '[Automated PR] Copy OSS code from commit ...' found in history.")
    return None

View File

@@ -0,0 +1,463 @@
import argparse
import bisect
import json
import time
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple
# Command-line interface. Arguments are parsed at import time and the results
# (`args`, `perfetto_data`, `baseline`) are module-level globals.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description="Convert SGLang OTEL trace files to Perfetto format.",
)
parser.add_argument(
    "-i",
    "--input",
    type=str,
    required=True,
    dest="input_file",
    help="Path to the input OTEL trace file (JSON or JSONL format).",
)
parser.add_argument(
    "-o",
    "--output",
    type=str,
    default="sglang_trace_perfetto.json",
    dest="output_file",
    help="Path to the output Perfetto JSON file.",
)
parser.add_argument(
    "-f", "--torch-file", help="specify torch profile file", dest="torch_file"
)
args = parser.parse_args()

# When a torch profile is supplied, its "baseTimeNanoseconds" value becomes the
# baseline that is subtracted from span start times; otherwise the baseline
# is 0.
perfetto_data = None
baseline = 0
if args.torch_file:
    with open(args.torch_file, "r", encoding="utf-8") as file:
        perfetto_data = json.load(file)
    baseline = perfetto_data["baseTimeNanoseconds"]
def id_generator():
    """Yield 0, 1, 2, ... without end."""
    next_id = 0
    while True:
        yield next_id
        next_id = next_id + 1


# Shared id source, created once at import time.
relation_id_gen = id_generator()
class SpanLayoutContainer:
    """
    Sorted collection of ``(start, end)`` intervals with an overlap query.

    ``check_overlap`` inspects only the two stored intervals adjacent to the
    query's start position.
    """

    def __init__(self):
        # (start, end) tuples, kept sorted by insert_span.
        self.intervals = []

    def check_overlap(self, start, end):
        """
        Return True if ``(start, end)`` collides with a neighbouring interval.

        A collision is a predecessor whose end is greater than ``start`` or a
        successor whose start is less than ``end``; touching endpoints do not
        count.
        """
        # Position of the first stored interval whose start is >= `start`.
        pos = bisect.bisect_left(self.intervals, (start, float("-inf")))
        hits_previous = pos > 0 and self.intervals[pos - 1][1] > start
        if hits_previous:
            return True
        hits_next = pos < len(self.intervals) and self.intervals[pos][0] < end
        return hits_next

    def insert_span(self, start, end):
        """Store ``(start, end)`` at its sorted position (no overlap check)."""
        entry = (start, end)
        self.intervals.insert(bisect.bisect_left(self.intervals, entry), entry)
def new_metadata_level1(name: str, pid):
    """Build the metadata ("M") event that sets the process name of ``pid`` to ``name``."""
    event = {"name": "process_name", "ph": "M"}
    event["pid"] = pid
    event["args"] = {"name": name}
    return event
def new_metadata_level2(name: str, pid, slot_seq):
    """Build the metadata ("M") event that sets the thread name of ``slot_seq`` in ``pid`` to ``name``."""
    event = {"name": "thread_name", "ph": "M"}
    event["pid"] = pid
    event["tid"] = slot_seq
    event["args"] = {"name": name}
    return event
def __find_line(graph, trans_graph_status, slot_meta_data, pid, start, end):
    """
    Pick the line (slot number) on which the span ``[start, end]`` of process
    ``pid`` is drawn.

    Args:
        graph: ``{pid: {line: SpanLayoutContainer}}``; new lines are added here.
        trans_graph_status: ``{pid: line}``, the line most recently chosen for
            each pid; updated whenever a different line is chosen.
        slot_meta_data: List that receives a thread-name metadata event for
            every newly created line.
        pid: Process the span belongs to.
        start, end: Span boundaries.

    Returns:
        The chosen line number (lines are numbered from 1).
    """
    # Prefer the line used last time for this pid. Zero-length spans always
    # stay on it; other spans stay only if they do not collide there.
    if pid in trans_graph_status:
        last_line = trans_graph_status[pid]
        if start == end:
            return last_line
        if not graph[pid][last_line].check_overlap(start, end):
            return last_line

    # First span seen for this pid: open line 1.
    if pid not in graph:
        first_line = 1
        graph[pid] = {first_line: SpanLayoutContainer()}
        trans_graph_status[pid] = first_line
        slot_meta_data.append(new_metadata_level2("slot", pid, first_line))
        return first_line

    # Otherwise take the first existing line with room for the span.
    lanes = graph[pid]
    for candidate in lanes:
        if lanes[candidate].check_overlap(start, end):
            continue
        trans_graph_status[pid] = candidate
        return candidate

    # Every existing line collides: open a new one.
    new_line = len(lanes) + 1
    lanes[new_line] = SpanLayoutContainer()
    trans_graph_status[pid] = new_line
    slot_meta_data.append(new_metadata_level2("slot", pid, new_line))
    return new_line
OtelSpan = Dict[str, Any]
def load_otel_data(path: str | Path):
p = Path(path)
with p.open("rt", encoding="utf-8") as f:
first = f.read(1)
f.seek(0)
if first == "[":
data = json.load(f) # JSON array
else:
data = [json.loads(line) for line in f if line.strip()] # JSONL
return data
def extract_all_otel_spans(otel_data):
    """
    Flatten OTEL export records into two span lists, by originating service.

    Only resources whose ``service.name`` attribute is ``"sglang"`` or
    ``"smg"`` are kept. Each kept span is modified in place: its
    ``attributes`` list of ``{"key", "value": {<type>: v}}`` entries is
    replaced by a plain ``{key: v}`` dict (an empty dict when the span has no
    attributes).

    Args:
        otel_data: Iterable of records, each with a ``"resourceSpans"`` list.

    Returns:
        ``(engine_otel_spans, smg_otel_spans)``.
    """
    engine_otel_spans = []
    smg_otel_spans = []
    buckets = {"sglang": engine_otel_spans, "smg": smg_otel_spans}

    for record in otel_data:
        for resource_spans in record["resourceSpans"]:
            # If service.name appears more than once, the last value wins.
            declared_names = [
                attr["value"]["stringValue"]
                for attr in resource_spans["resource"]["attributes"]
                if attr["key"] == "service.name"
            ]
            service_name = declared_names[-1] if declared_names else ""
            if service_name not in buckets:
                continue
            bucket = buckets[service_name]

            for scope_spans in resource_spans["scopeSpans"]:
                for span in scope_spans["spans"]:
                    flat = {}
                    if "attributes" in span:
                        for attr in span["attributes"]:
                            # The value wrapper holds a single typed entry,
                            # e.g. {"stringValue": "x"}; keep just the payload.
                            payload = next(iter(attr.get("value", {}).values()), None)
                            flat[attr.get("key")] = payload
                    span["attributes"] = flat
                    bucket.append(span)

    return engine_otel_spans, smg_otel_spans
def build_otel_span_tree(otel_spans):
    """
    Link spans into parent/child trees, in place, and return the roots.

    Every span gets a ``"child"`` list. A span whose ``module`` attribute is
    ``"sglang::request"`` is a root (and is not attached to any parent); any
    other span is appended to its parent's ``"child"`` list when the parent id
    is known. Each span's ``"links"`` is replaced by the list of linked span
    objects that could be resolved (empty when there were none).

    Args:
        otel_spans: List of span dicts, each with a ``"spanId"``.

    Returns:
        The list of root spans, in input order.
    """
    by_id = {}
    for span in otel_spans:
        by_id[span["spanId"]] = span
    # All child lists must exist before linking: a parent may appear after
    # its children in the input.
    for span in otel_spans:
        span["child"] = []

    root_spans = []
    for span in otel_spans:
        parent_id = span.get("parentSpanId", "")
        is_request_root = span.get("attributes", {}).get("module") == "sglang::request"
        if is_request_root:
            root_spans.append(span)
        elif parent_id in by_id:
            by_id[parent_id]["child"].append(span)

        # Swap link references for the span objects they point to.
        resolved = []
        if "links" in span:
            targets = (by_id.get(link["spanId"]) for link in span["links"])
            resolved = [target for target in targets if target]
        span["links"] = resolved

    return root_spans
def __convert_to_perfetto_span(span, rid, bootstrap_room, pid, host_id):
    """Attach a Perfetto complete event ("ph": "X") to ``span`` and its children.

    ``span`` is modified in place: the start/end timestamps are converted to
    ``int``, ``"pid"`` (and optionally ``"host_id"``) are set, and the generated
    event is stored under ``"perfetto_span"``. The event's ``"args"`` is the
    span's own attribute dict, so ``rid`` / ``bootstrap_room`` written here show
    up in the trace.

    Args:
        span: Span dict with ``"attributes"`` already flattened to a dict.
        rid: Request id copied into the attributes when truthy.
        bootstrap_room: Copied into the attributes when truthy.
        pid: Value used for the Perfetto ``"pid"`` field.
        host_id: Stored on the span when truthy.

    Reads the module-level ``baseline`` (nanoseconds) as the trace time origin.
    """
    if bootstrap_room:
        span["attributes"]["bootstrap_room"] = bootstrap_room
    if rid:
        span["attributes"]["rid"] = rid
    if host_id:
        span["host_id"] = host_id
    span["pid"] = pid

    span["startTimeUnixNano"] = int(span["startTimeUnixNano"])
    # The end is pulled in by 1000 ns (1 us in the trace).
    # NOTE(review): presumably so back-to-back spans do not touch on a track — confirm.
    span["endTimeUnixNano"] = int(span["endTimeUnixNano"]) - 1000
    ts = span["startTimeUnixNano"]
    dur = span["endTimeUnixNano"] - ts

    span["perfetto_span"] = {
        "ph": "X",
        "name": span.get("name", "unknown"),
        "cat": "sglang",
        # Nanoseconds -> microseconds, relative to the trace origin.
        "ts": (ts - baseline) / 1000.0,
        "dur": dur / 1000.0,
        "pid": pid,
        "tid": 0,  # placeholder; a later layout pass assigns the real track
        "args": span["attributes"],
    }

    # smg spans never pass through build_otel_span_tree and so have no "child"
    # key; treat them as leaves instead of raising KeyError.
    for child_span in span.get("child", ()):
        __convert_to_perfetto_span(child_span, rid, bootstrap_room, pid, host_id)
def generate_perfetto_span(engine_root_spans, smg_otel_spans, thread_meta_data):
    """Create Perfetto span events for every engine and smg span.

    For each request root, the children are treated as per-thread container
    spans: their attributes supply the pid, host id and display label, and
    their own children are converted and collected in ``root["spans"]``.
    smg spans are all converted under the fixed pid ``"smg"``.

    Args:
        engine_root_spans: Request roots produced by ``build_otel_span_tree``.
        smg_otel_spans: Flat list of smg spans.
        thread_meta_data: Dict filled in place with one metadata entry per pid
            (as built by ``new_metadata_level1``).
    """
    for request_span in engine_root_spans:
        request_span["spans"] = []
        rid = request_span["attributes"]["rid"]
        bootstrap_room = request_span["attributes"].get("bootstrap_room", "")

        for thread_span in request_span["child"]:
            thread_attrs = thread_span["attributes"]
            pid = int(thread_attrs["pid"])
            host_id = thread_attrs["host_id"]

            # Label: first 8 chars of the host id, the thread label, and the
            # TP rank when one is recorded.
            label = f'{thread_attrs["host_id"][:8]}:{thread_attrs["thread_label"]}'
            if "tp_rank" in thread_attrs:
                label += f"-TP{thread_attrs['tp_rank']}"

            if pid not in thread_meta_data:
                thread_meta_data[pid] = new_metadata_level1(label, pid)

            for work_span in thread_span["child"]:
                __convert_to_perfetto_span(
                    work_span, rid, bootstrap_room, pid, host_id
                )
                request_span["spans"].append(work_span)

    smg_pid = "smg"
    thread_meta_data[smg_pid] = new_metadata_level1("smg", smg_pid)
    for smg_span in smg_otel_spans:
        smg_span["pid"] = smg_pid
        __convert_to_perfetto_span(smg_span, None, None, smg_pid, None)
def __set_span_tid(span, line):
    """Set the Perfetto ``tid`` of ``span`` and all of its descendants to ``line``."""
    pending = [span]
    while pending:
        node = pending.pop()
        node["perfetto_span"]["tid"] = line
        pending.extend(node["child"])
def generate_perfetto_span_layout(engine_root_spans, smg_otel_spans, slot_meta_data):
    """Assign a display line (Perfetto ``tid``) to every converted span.

    Engine spans are laid out request by request, requests ordered by the
    start of their earliest span; each request gets a fresh status dict for
    ``__find_line``. smg spans are laid out afterwards in start-time order,
    sharing a single status dict.

    Args:
        engine_root_spans: Request roots whose ``"spans"`` lists were filled by
            ``generate_perfetto_span``; each list is re-sorted in place.
        smg_otel_spans: Converted smg spans.
        slot_meta_data: List handed to ``__find_line``, which appends to it
            when it opens a new line.
    """
    graph = {}

    def _place(span, thread_status):
        """Reserve a line for ``span`` in ``graph`` and return its number."""
        perfetto_pid = span["perfetto_span"]["pid"]
        start = span["startTimeUnixNano"]
        end = span["endTimeUnixNano"]
        line = __find_line(
            graph, thread_status, slot_meta_data, perfetto_pid, start, end
        )
        graph[perfetto_pid][line].insert_span(start, end)
        return line

    for root_span in engine_root_spans:
        root_span["spans"] = sorted(
            root_span["spans"], key=lambda x: int(x["startTimeUnixNano"])
        )

    # A request without converted spans has nothing to place; skipping it also
    # avoids indexing an empty list when ordering requests by first span.
    ordered_roots = sorted(
        (root for root in engine_root_spans if root["spans"]),
        key=lambda x: int(x["spans"][0]["startTimeUnixNano"]),
    )

    for root_span in ordered_roots:
        req_thread_status = {}
        for span in root_span["spans"]:
            # Descendants share the line chosen for their top-level span.
            __set_span_tid(span, _place(span, req_thread_status))

    smg_thread_status = {}
    for span in sorted(smg_otel_spans, key=lambda x: int(x["startTimeUnixNano"])):
        span["perfetto_span"]["tid"] = _place(span, smg_thread_status)
def __convert_to_perfetto_events(span):
    """Build Perfetto instant events ("ph": "i") for ``span`` and its children.

    The events are stored under ``span["perfetto_events"]`` (an empty list when
    the span has no events) and inherit the pid/tid of the span's own Perfetto
    event, so ``"perfetto_span"`` must already be present.

    Reads the module-level ``baseline`` (nanoseconds) as the trace time origin.
    """
    perfetto_span = span["perfetto_span"]
    span["perfetto_events"] = []
    for event in span.get("events", ()):
        # Events may carry no attributes at all; treat that as an empty set,
        # the same way spans without attributes are handled at extraction.
        attributes_dict = {
            attr.get("key"): next(iter(attr.get("value", {}).values()), None)
            for attr in event.get("attributes", ())
        }
        span["perfetto_events"].append(
            {
                "ph": "i",
                "cat": "sglang",
                # Nanoseconds -> microseconds, relative to the trace origin.
                "ts": (int(event["timeUnixNano"]) - baseline) / 1000.0,
                "pid": perfetto_span["pid"],
                "tid": perfetto_span["tid"],
                "name": event.get("name", "unknown"),
                "args": attributes_dict,
            }
        )

    # smg spans have no "child" key (they are never run through
    # build_otel_span_tree); treat them as leaves instead of raising KeyError.
    for child_span in span.get("child", ()):
        __convert_to_perfetto_events(child_span)
def generate_perfetto_events(engine_root_spans, smg_otel_spans):
    """Generate instant events for every engine span, then every smg span."""
    engine_spans = []
    for root_span in engine_root_spans:
        engine_spans.extend(root_span["spans"])

    for engine_span in engine_spans:
        __convert_to_perfetto_events(engine_span)
    for smg_span in smg_otel_spans:
        __convert_to_perfetto_events(smg_span)
def generate_perfetto_links(engine_root_spans, smg_otel_spans):
    """Emit Perfetto flow events ("s" -> "f") for every resolvable span link.

    A request root whose ``parentSpanId`` names an smg span gets that smg span
    installed as the only link of the request's first span (replacing any
    links it already had). Then each span in ``root["spans"]`` receives a
    ``"perfetto_links"`` list with a start/finish pair per link whose target
    has a ``"perfetto_span"``; other targets are skipped.

    The flow id is stored on the link target's args as ``"correlation"`` and
    reused on later links to the same target; new ids come from the
    module-level ``relation_id_gen`` iterator.
    """
    # build link between engine span and smg span
    span_id_map = {span["spanId"]: span for span in smg_otel_spans}
    for root_span in engine_root_spans:
        request_spans = root_span["spans"]
        parent_span = span_id_map.get(root_span.get("parentSpanId"))
        # Guard on request_spans: a request without converted spans has no
        # first span to attach the link to (previously an IndexError).
        if parent_span is not None and request_spans:
            request_spans[0]["links"] = [parent_span]

        for span in request_spans:
            span["perfetto_links"] = []
            for link_span in span.get("links", ()):
                try:
                    link_perfetto_span = link_span["perfetto_span"]
                except (KeyError, AttributeError):
                    # Link target was never converted; nothing to draw.
                    continue

                link_args = link_perfetto_span["args"]
                if "correlation" in link_args:
                    flow_id = link_args["correlation"]
                else:
                    flow_id = next(relation_id_gen)
                    link_args["correlation"] = flow_id

                # Flow starts at the linked span ...
                span["perfetto_links"].append(
                    {
                        "ph": "s",
                        "id": flow_id,
                        "pid": link_perfetto_span["pid"],
                        "tid": link_perfetto_span["tid"],
                        "ts": link_perfetto_span["ts"],
                        "cat": "ac2g",
                        "name": "ac2g",
                    }
                )
                # ... and finishes at this span.
                span["perfetto_links"].append(
                    {
                        "ph": "f",
                        "id": flow_id,
                        "pid": span["perfetto_span"]["pid"],
                        "tid": span["perfetto_span"]["tid"],
                        "ts": span["perfetto_span"]["ts"],
                        "cat": "ac2g",
                        "name": "ac2g",
                        "bp": "e",
                    }
                )
def __gather_one_span(span):
    """Flatten ``span`` and its descendants into a list of Perfetto elements.

    Order per span: its ``"perfetto_span"``, then its ``"perfetto_events"`` and
    ``"perfetto_links"`` when present, then the same for each child in turn.
    """
    collected = [span["perfetto_span"]]
    for optional_key in ("perfetto_events", "perfetto_links"):
        if optional_key in span:
            collected += span[optional_key]
    for child_span in span["child"]:
        collected += __gather_one_span(child_span)
    return collected
def gather_all_perfetto_elems(
    engine_root_spans, smg_otel_spans, thread_meta_data, slot_meta_data
):
    """Collect every Perfetto element into one flat list.

    Order: thread metadata values, slot metadata, all engine spans (each
    flattened together with its events, links and descendants), then for each
    smg span its span event followed by its instant events.
    """
    elems = [*thread_meta_data.values(), *slot_meta_data]
    for root_span in engine_root_spans:
        for engine_span in root_span["spans"]:
            elems += __gather_one_span(engine_span)
    for smg_span in smg_otel_spans:
        elems.append(smg_span["perfetto_span"])
        elems += smg_span["perfetto_events"]
    return elems
def write_json(perfetto_elems):
    """Write the trace to ``args.output_file`` as indented JSON.

    When ``args.torch_file`` is set, the elements are appended to the
    module-level ``perfetto_data["traceEvents"]`` and every event whose
    ``cat`` is ``"gpu_user_annotation"`` is removed. Otherwise
    ``perfetto_data`` is rebound to the bare element list.
    """
    global perfetto_data

    if not args.torch_file:
        perfetto_data = perfetto_elems
    else:
        perfetto_data["traceEvents"].extend(perfetto_elems)
        perfetto_data["traceEvents"] = [
            event
            for event in perfetto_data["traceEvents"]
            if event.get("cat") != "gpu_user_annotation"
        ]

    with open(args.output_file, "w", encoding="utf-8") as out:
        json.dump(perfetto_data, out, ensure_ascii=False, indent=4)
def main():
    """Run the conversion from ``args.input_file`` to ``args.output_file``.

    Steps: load and split the spans, build the engine span tree, create span
    events, lay them out on lines, add instant and flow events, then write the
    result and report the wall-clock time taken.
    """
    started = time.time()

    engine_otel_spans, smg_otel_spans = extract_all_otel_spans(
        load_otel_data(args.input_file)
    )
    engine_root_spans = build_otel_span_tree(engine_otel_spans)

    thread_meta_data = {}
    slot_meta_data = []
    generate_perfetto_span(engine_root_spans, smg_otel_spans, thread_meta_data)
    generate_perfetto_span_layout(engine_root_spans, smg_otel_spans, slot_meta_data)
    generate_perfetto_events(engine_root_spans, smg_otel_spans)
    generate_perfetto_links(engine_root_spans, smg_otel_spans)

    write_json(
        gather_all_perfetto_elems(
            engine_root_spans, smg_otel_spans, thread_meta_data, slot_meta_data
        )
    )

    elapsed = time.time() - started
    print("\nConversion finished successfully!")
    print(f"Output written to: {args.output_file}")
    print(f"Execution time: {elapsed * 1000:.4f} ms")
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,115 @@
"""
Export NextN layer for DeepSeek-V3/R1 model. The exported model can be used for speculative decoding.
Usage:
python3 export_deepseek_nextn.py --input-dir /path/to/DeepSeek-V3 --output-dir /path/to/DeepSeek-V3-NextN
"""
import argparse
import json
import os
import shutil
from safetensors import safe_open
from safetensors.torch import save_file
from transformers import AutoConfig
def get_nextn_layer_id(config):
    """Return the layer index used for the NextN layer.

    The index is ``config.num_hidden_layers``.
    NOTE(review): presumably the NextN layer is stored right after the regular
    layers 0..N-1 — confirm against the checkpoint layout.

    Raises:
        ValueError: If ``config`` has no ``num_hidden_layers`` attribute.
    """
    missing = object()
    layer_count = getattr(config, "num_hidden_layers", missing)
    if layer_count is missing:
        raise ValueError("'num_hidden_layers' not found in model config.")
    return layer_count
def update_and_save_config(config, output_dir):
    """Write a single-layer NextN variant of ``config`` to ``config.json``.

    The config is converted with ``to_dict()``, ``num_hidden_layers`` is forced
    to 1 and ``architectures`` to ``["DeepseekV3ForCausalLMNextN"]``; all other
    fields are kept. Any existing ``config.json`` in ``output_dir`` is
    overwritten.

    Args:
        config: Object exposing ``to_dict()``.
        output_dir: Existing directory that receives ``config.json``.
    """
    new_config = config.to_dict()
    new_config.update(
        {
            "num_hidden_layers": 1,
            "architectures": ["DeepseekV3ForCausalLMNextN"],
        }
    )
    # ensure_ascii=False emits raw non-ASCII characters, so pin the encoding
    # instead of depending on the platform's locale default.
    with open(os.path.join(output_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(new_config, f, indent=2, ensure_ascii=False, sort_keys=True)
def copy_non_safetensors_files(input_dir, output_dir):
    """Copy every regular file except ``*.safetensors`` into ``output_dir``.

    Only direct entries of ``input_dir`` are considered; sub-directories are
    skipped. Files are copied with ``shutil.copy2`` under their original names.
    """
    for entry in os.listdir(input_dir):
        source = os.path.join(input_dir, entry)
        if not os.path.isfile(source):
            continue
        if entry.endswith(".safetensors"):
            continue
        shutil.copy2(source, os.path.join(output_dir, entry))
    print(f"All non-safetensors files have been copied to {output_dir}")
def export_nextn_layer_parameters(input_dir, output_dir, nextn_layer_id):
    """Extract the NextN layer's tensors into a standalone safetensors file.

    Every ``*.safetensors`` file in ``input_dir`` is scanned for tensors named
    ``model.layers.<nextn_layer_id>.*``. Matches are renamed to
    ``model.layers.0.*`` and saved together as
    ``nextn_layer_parameters.safetensors`` in ``output_dir``; tensors whose
    name contains ``embed_tokens`` or ``shared_head.head`` are left out.
    A ``model.safetensors.index.json`` mapping each exported tensor to that
    file is always written, even when nothing matched.

    A shard that cannot be read is reported and skipped (best effort), so the
    export can be incomplete without raising.

    Args:
        input_dir: Directory holding the source safetensors shards.
        output_dir: Existing directory for the exported file and index.
        nextn_layer_id: Index of the layer to export.
    """
    prefix = f"model.layers.{nextn_layer_id}"
    # Match on "<prefix>." so that layer 6 does not also select layers 60-69.
    layer_prefix = prefix + "."
    output_path = os.path.join(output_dir, "nextn_layer_parameters.safetensors")
    params = {}

    for filename in os.listdir(input_dir):
        if not filename.endswith(".safetensors"):
            continue

        file_path = os.path.join(input_dir, filename)
        print(f"Processing: {filename}")

        try:
            with safe_open(file_path, framework="pt") as f:
                matching_keys = [k for k in f.keys() if k.startswith(layer_prefix)]

                if not matching_keys:
                    print(f" No parameters starting with '{prefix}' found")
                    continue

                for key in matching_keys:
                    if "embed_tokens" in key or "shared_head.head" in key:
                        continue
                    # Rewrite only the leading layer prefix; str.replace would
                    # also touch a later occurrence of the same text.
                    new_key = "model.layers.0." + key[len(layer_prefix) :]
                    params[new_key] = f.get_tensor(key)

        except Exception as e:
            # Deliberately best-effort: report the shard and keep going.
            print(f" Error processing {filename}: {str(e)}")

    if params:
        print(f"Saving {len(params)} parameters to {output_path}")
        save_file(params, output_path)
    else:
        print("No matching parameters found.")

    # Update safetensors index
    index_path = os.path.join(output_dir, "model.safetensors.index.json")
    print(f"Updating safetensors index to {index_path}")
    index_data = {
        "weight_map": {key: "nextn_layer_parameters.safetensors" for key in params}
    }
    with open(index_path, "w") as f:
        json.dump(index_data, f, indent=4)

    print("All done.")
if __name__ == "__main__":
    # CLI: both directories are mandatory.
    parser = argparse.ArgumentParser(
        description="Export NextN layer parameters for DeepSeek-V3/R1"
    )
    for flag, help_text in (
        ("--input-dir", "Input HF model directory."),
        ("--output-dir", "Output nextn model directory."),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)
    args = parser.parse_args()

    config = AutoConfig.from_pretrained(args.input_dir, trust_remote_code=True)
    assert config.num_nextn_predict_layers == 1, "Only 1 nextn layer is supported."
    nextn_layer_id = get_nextn_layer_id(config)

    # Build the output directory: auxiliary files first, then the rewritten
    # config (which overwrites the copied config.json), then the weights.
    os.makedirs(args.output_dir, exist_ok=True)
    copy_non_safetensors_files(args.input_dir, args.output_dir)
    update_and_save_config(config, args.output_dir)
    export_nextn_layer_parameters(args.input_dir, args.output_dir, nextn_layer_id)

60
third_party/sglang/scripts/killall_sglang.sh vendored Executable file
View File

@@ -0,0 +1,60 @@
#!/bin/bash
# DEPRECATED: This script will be migrated to python/sglang/cli/killall.py.
# CI mode is already handled there. This script remains for local/non-CI usage.
#
# TODO: Migrate remaining modes (rocm, all, gpus) to killall.py and remove this file.
#
# Usage:
# ./killall_sglang.sh - Kill SGLang processes only (NVIDIA mode)
# ./killall_sglang.sh rocm - Kill SGLang processes only (ROCm mode)
# ./killall_sglang.sh all - Kill all GPU processes (NVIDIA mode)
# ./killall_sglang.sh gpus 0,1,2,3 - Kill all processes on specific GPUs
#
# NOTE: no `set -e` on purpose; cleanup continues when a tool such as
# nvidia-smi is missing or reports an error.

# Extended regex matched by `pgrep -f` against full command lines.
readonly SGLANG_PROC_PATTERN='sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt|sgl_diffusion::'

# SIGKILL every process whose command line matches SGLANG_PROC_PATTERN.
kill_sglang_processes() {
  pgrep -f "${SGLANG_PROC_PATTERN}" | xargs -r kill -9
}

# Read `lsof` output on stdin and SIGKILL each distinct PID listed.
# NR>1 drops the header row: "PID" is not a process ID, and some kill
# implementations abort on it before signalling anything.
kill_pids_from_lsof() {
  awk 'NR>1 {print $2}' | sort -u | xargs -r kill -9 2>/dev/null
}

# Install lsof through apt-get (with sudo when available) if it is missing.
ensure_lsof() {
  if command -v lsof >/dev/null 2>&1; then
    return 0
  fi
  if command -v sudo >/dev/null 2>&1; then
    sudo apt-get update
    sudo apt-get install -y lsof
  else
    apt-get update
    apt-get install -y lsof
  fi
}

if [ "$1" = "rocm" ]; then
  echo "Running in ROCm mode"

  # Clean SGLang processes
  kill_sglang_processes

elif [ "$1" = "gpus" ] && [ -n "$2" ]; then
  # Kill all processes on specific GPUs only
  echo "Killing all processes on GPUs: $2"

  # Show current GPU status
  nvidia-smi

  # Build device file list from GPU IDs (e.g., "0,1,2,3" -> /dev/nvidia0 /dev/nvidia1 ...)
  devices=()
  IFS=',' read -r -a gpu_ids <<< "$2"
  for gpu_id in "${gpu_ids[@]}"; do
    # Trim surrounding whitespace; ignore empty entries such as in "0,,1".
    gpu_id="${gpu_id#"${gpu_id%%[![:space:]]*}"}"
    gpu_id="${gpu_id%"${gpu_id##*[![:space:]]}"}"
    if [ -n "${gpu_id}" ]; then
      devices+=("/dev/nvidia${gpu_id}")
    fi
  done
  echo "Targeting devices: ${devices[*]}"

  # Kill all processes using specified GPU devices
  if (( ${#devices[@]} > 0 )); then
    lsof "${devices[@]}" 2>/dev/null | kill_pids_from_lsof
  fi

  # Show GPU status after clean up
  nvidia-smi

else
  # Show current GPU status
  nvidia-smi

  # Clean SGLang processes
  kill_sglang_processes

  # Clean all GPU processes if "all" argument is provided
  if [ "$1" = "all" ]; then
    ensure_lsof

    # PIDs from the "Processes:" table of nvidia-smi (5th column).
    mapfile -t gpu_pids < <(
      nvidia-smi | sed -n '/Processes:/,$p' | grep " [0-9]" | awk '{print $5}'
    )
    if (( ${#gpu_pids[@]} > 0 )); then
      kill -9 "${gpu_pids[@]}" 2>/dev/null
    fi

    # Anything else still holding an NVIDIA device node open.
    lsof /dev/nvidia* | kill_pids_from_lsof
  fi

  # Show GPU status after clean up
  nvidia-smi
fi

View File

@@ -0,0 +1,319 @@
"""
Usage:
# single GPU
python3 bench_speculative.py --model-path meta-llama/Llama-2-7b-chat-hf --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B
# multiple GPU
python3 bench_speculative.py --model-path deepseek-ai/DeepSeek-V3 --speculative-draft-model-path lmsys/DeepSeek-V3-NextN --tp-size 8 --trust-remote-code --batch-size 1 4 8 16 32 --steps 0 1 2 --topk 0 1 2 4 --num_draft_tokens 0 2 4 8
"""
import argparse
import asyncio
import json
import os
import time
from types import SimpleNamespace
from typing import List
import numpy as np
import requests
from transformers import AutoTokenizer
from sglang.bench_serving import benchmark, set_global_args
from sglang.benchmark.datasets import DatasetRow
from sglang.benchmark.datasets.mmmu import sample_mmmu_requests
from sglang.srt.server_args import ServerArgs
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
kill_process_tree,
popen_launch_server,
)
def node0_print(msg):
    """Print ``msg`` only when the module-level ``server_args`` has node_rank 0."""
    if server_args.node_rank != 0:
        return
    print(msg)
# Fixed text prompts in "Human: ...\n\nAssistant:" form. send_one_batch repeats
# and truncates this list to build the request set for non-multimodal runs.
prompts = [
"Human: Give me a fully functional FastAPI server. Show the full, long python code without stop.\n\nAssistant:",
"Human: Imagine you are an experienced Ethereum developer tasked with creating a smart contract for a blockchain messenger. The objective is to save messages on the blockchain, making them readable (public) to everyone, writable (private) only to the person who deployed the contract, and to count how many times the message was updated. Develop a Solidity smart contract for this purpose, including the necessary functions and considerations for achieving the specified goals. Please provide the code and any relevant explanations to ensure a clear understanding of the implementation.\n\nAssistant:",
"Human: Write a travel blog post to Hawaii.\n\nAssistant:",
"Human: I want you to act as an English translator, spelling corrector and improver. I will speak to you in any language and you will detect the language, translate it and answer in the corrected and improved version of my text, in English. I want you to replace my simplified A0-level words and sentences with more beautiful and elegant, upper level English words and sentences. Keep the meaning same, but make them more literary. My first sentence is 'istanbulu cok seviyom burada olmak cok guzel'. Answer in more than 5000 words.\n\nAssistant:",
"Human: I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if its children then you can talk about animals; If its adults then history-based tales might engage them better etc. Answer in more than 5000 words. My first request is 'I need an interesting story on perseverance.'\n\nAssistant:",
"Human: Solve x^2 = -1. Think step-by-step. Give me a long detailed explanation. \n\nAssistant:",
"Human: Tell me about the president of the USA in wikipedia style.\n\nAssistant:",
"Human: Hello? Who are you? Write code, math, and poem to explanin yourself.\n\nAssistant:",
]
class FakeTokenizer:
    """Tokenizer stand-in whose ``encode`` never produces any tokens."""

    def encode(self, text: str, add_special_tokens: bool = False):
        """Ignore both arguments and return a new empty list."""
        return list()
def send_one_batch(base_url, num_prompts, batch_size, processor, is_multimodal):
    """Run one benchmark pass against a running server and summarise it.

    Args:
        base_url: Server root URL, e.g. ``http://127.0.0.1:20000``.
        num_prompts: Number of requests to send.
        batch_size: Maximum number of concurrent requests; also the key used to
            look up step times in the server's ``step_time_dict``.
        processor: Tokenizer for text runs; for multimodal runs an object whose
            ``.tokenizer`` attribute is the tokenizer.
        is_multimodal: Use MMMU samples through the chat-completions endpoint
            instead of the built-in text prompts through ``/generate``.

    Returns:
        ``(acc_length, step_time, speed, avg_output_token)`` where the first
        three are rounded and ``speed`` is ``acc_length / step_time``.
    """
    # format: (prompt, input_len, output len). We set input_len as a dummy value 0.
    if not is_multimodal:
        # Repeat the prompt list often enough (ceiling division), then trim.
        repeats = (num_prompts + len(prompts) - 1) // len(prompts)
        padded_prompts = (prompts * repeats)[:num_prompts]
        input_requests: List[DatasetRow] = [
            DatasetRow(text, 0, 512) for text in padded_prompts
        ]
        backend = "sglang"
        api_url = f"{base_url}/generate"
        tokenizer = processor
    else:
        backend = "sglang-oai-chat"
        api_url = f"{base_url}/v1/chat/completions"
        input_requests = sample_mmmu_requests(
            num_prompts,
            processor,
            backend=backend,
            fixed_output_len=512,
        )
        tokenizer = processor.tokenizer

    # We need to set some dummy values in order to call `benchmark` below.
    dummy_args = dict(
        disable_ignore_eos=False,
        disable_stream=False,
        return_logprob=False,
        return_routed_experts=False,
        plot_throughput=False,
        backend=backend,
        dataset_name="custom",
        num_prompts=None,
        sharegpt_output_len=None,
        random_input_len=None,
        random_output_len=None,
        random_range_ratio=None,
        output_file=None,
        warmup_requests=1,
        output_details=False,
    )
    set_global_args(SimpleNamespace(**dummy_args))

    # Run benchmark
    bench_coro = benchmark(
        backend=backend,
        api_url=api_url,
        base_url=base_url,
        model_id="default",
        tokenizer=tokenizer,
        input_requests=input_requests,
        request_rate=float("inf"),
        max_concurrency=batch_size,
        disable_tqdm=False,
        lora_names=None,
        lora_request_distribution=None,
        lora_zipf_alpha=None,
        extra_request_body={},
        profile=None,
    )
    results = asyncio.run(bench_coro)

    assert results["completed"] == len(input_requests)
    # A falsy accept length (None or 0) is treated as 1.0.
    acc_length = results["accept_length"] or 1.0
    avg_output_token = results["total_output_tokens"] / results["completed"]

    server_info = requests.get(base_url + "/server_info").json()
    step_times = server_info["internal_states"][0]["step_time_dict"][str(batch_size)]
    # We use 20% percentile instead of median on purpose
    step_time = np.percentile(step_times, 20)
    speed = 1 / step_time * acc_length

    return (
        round(acc_length, 3),
        round(step_time, 5),
        round(speed, 3),
        avg_output_token,
    )
# Sweep speculative-decoding settings: for each selected configuration, launch
# a server, run a warmup and a measured batch, append one JSON line with the
# results to args.output, and shut the server down again.
#
# args: parsed CLI namespace (batch_size, steps, topk, num_draft_tokens, start,
#       end, num_prompts, output, is_multimodal, model_path are read here).
# server_args: ServerArgs built from the same CLI arguments.
def main(args, server_args):
base_url = "http://127.0.0.1:20000"
# Enumerate the cartesian product of the four sweep lists, minus the
# combinations filtered out below.
configs = []
for batch_size in args.batch_size:
for steps in args.steps:
for topk in args.topk:
for num_draft_tokens in args.num_draft_tokens:
# Skip settings that ask for more draft tokens than steps * topk + 1.
if steps * topk + 1 < num_draft_tokens:
continue
# Skip settings that mix zero and non-zero speculative parameters.
if (steps == 0 or topk == 0 or num_draft_tokens == 0) and (
steps + topk + num_draft_tokens != 0
):
# steps == 0 and topk == 0 and num_draft_tokens == 0 is a special case for non-speculative decoding.
continue
configs.append((batch_size, steps, topk, num_draft_tokens))
# --start/--end select a slice of the configuration list; an unset (or 0)
# --end runs through the last configuration.
for i in range(args.start, args.end or len(configs)):
batch_size, steps, topk, num_draft_tokens = configs[i]
node0_print(
f"Start {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}"
)
# Create an LLM.
# NOTE(review): several entries added to other_args below are ints (steps,
# topk, batch_size, tp_size, ...); confirm popen_launch_server converts them
# to strings before spawning the process.
if steps == 0:
other_args = []
else:
other_args = [
"--speculative-num-steps",
steps,
"--speculative-eagle-topk",
topk,
"--speculative-num-draft-tokens",
num_draft_tokens,
]
if server_args.speculative_draft_model_path is not None:
other_args.extend(
[
"--speculative-draft-model-path",
server_args.speculative_draft_model_path,
"--speculative-algorithm",
server_args.speculative_algorithm,
]
)
# Batch size caps both the CUDA-graph batch size and the running requests.
other_args.extend(
[
"--cuda-graph-max-bs",
batch_size,
"--mem-fraction-static",
server_args.mem_fraction_static,
"--tp-size",
server_args.tp_size,
"--max-running-requests",
batch_size,
]
)
# Forward the optional server settings only when they are set.
if server_args.trust_remote_code:
other_args.extend(
[
"--trust-remote-code",
]
)
if server_args.attention_backend:
other_args.extend(
[
"--attention-backend",
server_args.attention_backend,
]
)
if server_args.quantization:
other_args.extend(
[
"--quantization",
server_args.quantization,
]
)
process = popen_launch_server(
args.model_path,
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
# os.environ is unpacked last, so a SGLANG_RECORD_STEP_TIME value already
# present in the environment overrides the "1" given here.
env={
"SGLANG_RECORD_STEP_TIME": "1",
**os.environ,
},
)
# Multimodal runs need the full processor; text runs only need the tokenizer.
if args.is_multimodal:
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained(
args.model_path, trust_remote_code=server_args.trust_remote_code
)
else:
processor = AutoTokenizer.from_pretrained(
args.model_path, trust_remote_code=server_args.trust_remote_code
)
try:
# Warmup
send_one_batch(
base_url, batch_size, batch_size, processor, args.is_multimodal
)
# Benchmark
acc_length, step_time, speed, completion_tokens = send_one_batch(
base_url,
max(args.num_prompts, batch_size),
batch_size,
processor,
args.is_multimodal,
)
finally:
# Stop the server even when a batch raised.
kill_process_tree(process.pid)
node0_print(
f"Finish {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}, {speed=:.2f} token/s, step_time={step_time * 1000:.2f} ms"
)
record = {
"batch_size": batch_size,
"steps": steps,
"topk": topk,
"num_draft_tokens": num_draft_tokens,
"acc_length": acc_length,
"step_time": step_time,
"speed": speed,
"completion_tokens": completion_tokens,
}
# Append mode: results from earlier runs in the same file are kept.
with open(args.output, "a") as fout:
fout.write(json.dumps(record) + "\n")
# Wait for the server to shutdown
time.sleep(5)
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)

    # Sweep dimensions: each accepts one or more ints.
    sweep_options = (
        ("--batch-size", (1, 2, 4, 8, 16)),
        ("--steps", (0, 1, 3, 5, 7)),  # use (0, 1, 2, 3, 4) for large batch size
        ("--topk", (0, 1, 2, 4, 8)),
        (
            "--num_draft_tokens",
            (0, 2, 4, 8, 16, 32),  # use (0, 2, 4, 8) for large batch size
        ),
    )
    for flag, default_values in sweep_options:
        parser.add_argument(flag, type=int, nargs="+", default=default_values)

    # Run control and output options.
    parser.add_argument("--num-prompts", type=int, default=16)
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int)
    parser.add_argument("--output", type=str, default="output.jsonl")
    parser.add_argument("--is-multimodal", action="store_true", default=False)

    args = parser.parse_args()
    server_args: ServerArgs = ServerArgs.from_cli_args(args)

    main(args, server_args)

View File

@@ -0,0 +1,22 @@
import json

import requests

# Ask a local server for input and output token logprobs of a short prompt and
# print how many entries of each kind came back.
prompt = "The capital of france is "

payload = {
    "text": prompt,
    "sampling_params": {"temperature": 0},
    "return_logprob": True,
    "return_input_logprob": True,
    "logprob_start_len": 0,
}
response = requests.post("http://0.0.0.0:8000/generate", json=payload)

meta_info = response.json()["meta_info"]
input_logprobs = meta_info["input_token_logprobs"]
output_logprobs = meta_info["output_token_logprobs"]

print(len(input_logprobs), len(output_logprobs))

View File

@@ -0,0 +1,34 @@
import json

import requests

port = 8000

# Schema constraining the output: an object with a word-character "name" and
# an integer "population", both required. It is sent as a JSON string.
schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string", "pattern": "^[\\w]+$"},
        "population": {"type": "integer"},
    },
    "required": ["name", "population"],
}
json_schema = json.dumps(schema)

# JSON
sampling_params = {
    "temperature": 0,
    "max_new_tokens": 64,
    "json_schema": json_schema,
}
request_body = {
    "text": "Here is the information of the capital of France in the JSON format.\n",
    "sampling_params": sampling_params,
}
response = requests.post(f"http://localhost:{port}/generate", json=request_body)
print(response.json())
# python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --trust-remote-code --disaggregation-mode prefill --tp 2 --disaggregation-ib-device mlx5_roce0,mlx5_roce1 --speculative-algorithm EAGLE --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8 --host 127.0.0.1 --port 8100

View File

@@ -0,0 +1,29 @@
import json

import requests

# Send one multi-paragraph prompt to a local server with greedy sampling and
# print the generated text.
prompt = """
According to CNBC's Faber, the investors present on the call interpreted this statement as an indication of an upcoming funding round. While speculative, Faber believes the funding round could be as large as $25 billion, and bestow a valuation of between $150 billion and $200 billion on xAI.
For the benefit of those who might not be aware, xAI recently acquired the social media platform X in an all-stock deal that valued the former at $80 billion and the latter at $33 billion, inclusive of $12 billion in liabilities. This meant that the deal bestowed a gross valuation of $45 billion on X before factoring in its debt load of $12 billion.
Bear in mind that Elon Musk took X (then called Twitter) private back in 2022 in a $44 billion deal. Since then, Musk has managed to stem X's cash bleed, with the company reportedly generating $1.2 billion in adjusted EBITDA in 2024.
According to the investors present on the call, xAI is currently generating around $1 billion in annual revenue. This contrasts sharply with the erstwhile muted expectations of many investors, who did not expect the startup to generate any material revenue this year.
Elsewhere, Faber also alludes to the fact that xAI is already working on its next big training supercluster, officially dubbed the Colossus 2, which is expected to eventually house as many as 1 million NVIDIA GPUs at a cost of between $35 billion and $40 billion.
Even though xAI's Grok LLM is already largely comparable with OpenAI's cutting-edge models, the Colossus 2 would significantly up the ante, and could feasibly challenge OpenAI's apex position in the AI sphere.
Give your honest take on the above text:
"""

generate_url = "http://0.0.0.0:8000/generate"
request_body = {"text": prompt, "sampling_params": {"temperature": 0}}

response = requests.post(generate_url, json=request_body)
response_json = response.json()
print(response_json["text"])

View File

@@ -0,0 +1,240 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Launch A Server\n",
"\n",
"Launch the server with a reasoning model (Qwen3-4B) and reasoning parser."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sglang import separate_reasoning, assistant_begin, assistant_end\n",
"from sglang import assistant, function, gen, system, user\n",
"from sglang import image\n",
"from sglang import RuntimeEndpoint, set_default_backend\n",
"from sglang.srt.utils import load_image\n",
"from sglang.test.test_utils import is_in_ci\n",
"from sglang.utils import print_highlight, terminate_process, wait_for_server\n",
"\n",
"if is_in_ci():\n",
" from patch import launch_server_cmd\n",
"else:\n",
" from sglang.utils import launch_server_cmd\n",
"\n",
"\n",
"server_process, port = launch_server_cmd(\n",
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen3-4B --reasoning-parser qwen3 --host 0.0.0.0\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\", process=server_process)\n",
"print(f\"Server started on http://localhost:{port}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set the default backend. Note: you can set chat_template_name in RuntimeEndpoint. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"set_default_backend(\n",
" RuntimeEndpoint(f\"http://localhost:{port}\", chat_template_name=\"qwen\")\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start with a basic question-answering task. And see how the reasoning content is generated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@function\n",
"def basic_qa(s, question):\n",
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
" s += user(question)\n",
" s += assistant_begin()\n",
" s += gen(\"answer\", max_tokens=512)\n",
" s += assistant_end()\n",
"\n",
"\n",
"state = basic_qa(\"List 3 countries and their capitals.\")\n",
"print_highlight(state[\"answer\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `separate_reasoning`, you can move the reasoning content to `{param_name}_reasoning_content` in the state."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@function\n",
"def basic_qa_separate_reasoning(s, question):\n",
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
" s += user(question)\n",
" s += assistant_begin()\n",
" s += separate_reasoning(gen(\"answer\", max_tokens=512), model_type=\"qwen3\")\n",
" s += assistant_end()\n",
"\n",
"\n",
"reasoning_state = basic_qa_separate_reasoning(\"List 3 countries and their capitals.\")\n",
"print_highlight(reasoning_state.stream_executor.variable_event.keys())\n",
"print_highlight(\n",
" f\"\\nSeparated Reasoning Content:\\n{reasoning_state['answer_reasoning_content']}\"\n",
")\n",
"\n",
"print_highlight(f\"\\n\\nContent:\\n{reasoning_state['answer']}\")\n",
"print_highlight(f\"\\n\\nMessages:\\n{reasoning_state.messages()[-1]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`separate_reasoning` can also be used in multi-turn conversations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@function\n",
"def multi_turn_qa(s):\n",
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
" s += user(\"Please give me a list of 3 countries and their capitals.\")\n",
" s += assistant(\n",
" separate_reasoning(gen(\"first_answer\", max_tokens=512), model_type=\"qwen3\")\n",
" )\n",
" s += user(\"Please give me another list of 3 countries and their capitals.\")\n",
" s += assistant(\n",
" separate_reasoning(gen(\"second_answer\", max_tokens=512), model_type=\"qwen3\")\n",
" )\n",
" return s\n",
"\n",
"\n",
"reasoning_state = multi_turn_qa()\n",
"print_highlight(f\"\\n\\nfirst_answer:\\n{reasoning_state['first_answer']}\")\n",
"print_highlight(\n",
" f\"\\n\\nfirst_answer_reasoning_content:\\n{reasoning_state['first_answer_reasoning_content']}\"\n",
")\n",
"print_highlight(f\"\\n\\nsecond_answer:\\n{reasoning_state['second_answer']}\")\n",
"print_highlight(\n",
" f\"\\n\\nsecond_answer_reasoning_content:\\n{reasoning_state['second_answer_reasoning_content']}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using No thinking as Qwen 3's advanced feature \n",
"\n",
"sglang separate_reasoning is particularly useful when combined with Qwen 3's advanced feature.\n",
"\n",
"[Qwen 3's advanced usages](https://qwenlm.github.io/blog/qwen3/#advanced-usages)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"reasoning_state = basic_qa_separate_reasoning(\n",
" \"List 3 countries and their capitals. /no_think\"\n",
")\n",
"print_highlight(f\"Reasoning Content:\\n{reasoning_state['answer_reasoning_content']}\")\n",
"print_highlight(f\"Content:\\n{reasoning_state['answer']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`separate_reasoning` can also be used in regular expression generation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@function\n",
"def regular_expression_gen(s):\n",
" s += user(\n",
" \"What is the IP address of the Google DNS servers? just provide the answer\"\n",
" )\n",
" s += assistant(\n",
" separate_reasoning(\n",
" gen(\n",
" \"answer\",\n",
" temperature=0,\n",
" regex=r\"((25[0-5]|2[0-4]\\d|[01]?\\d\\d?).){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\",\n",
" max_tokens=512,\n",
" ),\n",
" model_type=\"qwen3\",\n",
" ),\n",
" )\n",
"\n",
"\n",
"reasoning_state = regular_expression_gen()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_highlight(f\"Answer:\\n{reasoning_state['answer']}\")\n",
"print_highlight(\n",
" f\"\\n\\nReasoning Content:\\n{reasoning_state['answer_reasoning_content']}\"\n",
")"
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,14 @@
# Load a tokenizer by name and drop into an interactive Python console.
import argparse
import code
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# --name is handed to get_tokenizer unchanged.
parser.add_argument(
"--name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct"
)
args = parser.parse_args()
t = get_tokenizer(args.name)
# The console namespace is locals(), so the tokenizer is reachable as `t`
# (next to `args` and `parser`); renaming these changes what the user can type.
code.interact(local=locals())

View File

@@ -0,0 +1,36 @@
# Streams a short completion for each long-context prompt from a server
# listening on 127.0.0.1:30000 (OpenAI-compatible API).
from urllib.request import urlopen

from openai import OpenAI

# Test-case label -> URL of the prompt text.
test_cases = {
    "64k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt",
    "200k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt",
    "600k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt",
    "1m": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt",
}

client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

for name, url in test_cases.items():
    print(f"\n==== Running test case: {name} ====")

    # Step 1: download the prompt; a failure skips only this test case.
    try:
        with urlopen(url, timeout=10) as response:
            prompt = response.read().decode("utf-8")
    except Exception as e:
        print(f"Failed to load prompt for {name}: {e}")
        continue

    # Step 2: stream the completion and echo tokens as they arrive.
    try:
        response = client.chat.completions.create(
            model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
            messages=[{"role": "user", "content": prompt}],
            stream=True,
            max_tokens=128,
            temperature=0,
        )
        for chunk in response:
            if not chunk.choices:
                continue
            piece = chunk.choices[0].delta.content
            if piece is not None:
                print(piece, end="", flush=True)
    except Exception as e:
        print(f"\nError during completion for {name}: {e}")

View File

@@ -0,0 +1,77 @@
# Summarise per-step profiling traces.
#
# For every "trace*" directory under `dirpath`, read each
# "*(prefill|decode)_step_<N>.json" trace, take the duration of its
# profile_prefill_step / profile_decode_step event (divided by 1e3), pair it
# with the "size" recorded in "size_<N>.json", and write one
# "<output_file_prefix>_<trace_name>" report per directory plus a
# "<output_file_prefix>_total_time" JSON summary in the current directory.
import glob
import json
import os
import re
import sys

from tqdm import tqdm

sys.path.append("../../")
from fix_corrupted_json import clean_json_file

dirpath = "/Users/ying"
output_file_prefix = "analyzed_log"

time = {}  # trace_name -> {step_name: duration}
tot_time = {}  # trace_name -> summed duration of all steps
size = {}  # trace_name -> {step_name: size}

# Remove reports from a previous run (the per-trace reports are opened in
# append mode below). Done in-process instead of shelling out to `rm`, so
# nothing is printed when there is nothing to delete.
for stale_report in glob.glob(f"{output_file_prefix}*"):
    if os.path.isfile(stale_report):
        os.remove(stale_report)

for dirname in glob.glob(os.path.join(dirpath, "trace*")):
    print(dirname)
    trace_name = os.path.basename(dirname)
    time[trace_name] = {}
    size[trace_name] = {}
    total_time = 0

    for filename in tqdm(glob.glob(os.path.join(dirname, "*.json"))):
        # "<prefix>_prefill_step_12.json" -> "prefill_step_12"
        step_name = filename.split("/")[-1].split(".")[0]
        step_name = "_".join(step_name.split("_")[1:])
        if "prefill" not in filename and "decode" not in filename:
            continue

        if not re.search(r"(prefill|decode)_step_(\d+)\.json", filename):
            raise Exception(f"Cannot parse {filename}")

        try:
            with open(filename, "r") as f:
                trace = json.load(f)
        except ValueError:
            # json.JSONDecodeError (and UnicodeDecodeError) derive from
            # ValueError. Repair the file in place and retry once; anything
            # else (e.g. Ctrl-C, missing file) propagates unchanged.
            clean_json_file(filename, filename)
            with open(filename, "r") as f:
                trace = json.load(f)

        # Duration of the first profile-step event. `None` marks "not found"
        # so a previous file's value can never leak into this one.
        dur = None
        for event in trace["traceEvents"]:
            if event["name"] in ["profile_prefill_step", "profile_decode_step"]:
                dur = event["dur"] / 1e3
                break
        if dur is None:
            print(f"Warning: no profile step event in {filename}, skipping")
            continue
        time[trace_name][step_name] = dur
        total_time += dur

        step = int(step_name.split("_")[-1])
        with open(os.path.join(dirname, f"size_{step}.json"), "r") as f:
            size_info = json.load(f)
        size[trace_name][step_name] = size_info["size"]

    tot_time[trace_name] = total_time

    # Order both tables by numeric step index.
    time[trace_name] = dict(
        sorted(time[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
    )
    size[trace_name] = dict(
        sorted(size[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
    )

    with open(f"{output_file_prefix}_{trace_name}", "a") as f:
        for k, v in time[trace_name].items():
            size_v = size[trace_name][k]
            print(f"{k:>15}{v:10.2f}\t{size_v}")
            f.write(f"{k:>15}{v:10.2f}\t{size_v}\n")

with open(f"{output_file_prefix}_total_time", "w") as f:
    print(tot_time)
    json.dump(tot_time, f)

View File

@@ -0,0 +1,62 @@
# Greedy-decodes the same prompt with the base model and then with the LoRA
# adapter applied through PEFT, printing both outputs for comparison.
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
# ADAPTER = "winddude/wizardLM-LlaMA-LoRA-7B"
ADAPTER = "/home/ying/test_lora"
HF_TOKEN = "..."

prompt = """
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
"""

tokenizer = LlamaTokenizer.from_pretrained(MODEL)

base_model = LlamaForCausalLM.from_pretrained(
    MODEL,
    device_map="auto",
    # load_in_8bit=True,
    torch_dtype=torch.float16,
    # use_auth_token=HF_TOKEN,
)
base_model = base_model.cuda()

# ---- base model generate ----
with torch.no_grad():
    base_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
    output_tensors = base_model.generate(
        input_ids=base_input_ids,
        max_new_tokens=32,
        do_sample=False,
    )[0]

output = tokenizer.decode(output_tensors, skip_special_tokens=True)
print("======= base output ========")
print(output)

# ---- peft model generate ----
model = PeftModel.from_pretrained(
    base_model,
    ADAPTER,
    torch_dtype=torch.float16,
    is_trainable=False,
)

with torch.no_grad():
    peft_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
    output_tensors = model.generate(
        input_ids=peft_input_ids,
        max_new_tokens=32,
        do_sample=False,
    )[0]

output = tokenizer.decode(output_tensors, skip_special_tokens=True)
print("======= peft output ========")
print(output)

View File

@@ -0,0 +1,30 @@
# Generates a greedy completion for one prompt through vLLM with a LoRA
# adapter attached, then prints the prompt and the generated text.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
ADAPTER = "/home/ying/test_lora"

prompt = """
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
"""

llm = LLM(model=MODEL, enable_lora=True)

# temperature=0 -> greedy decoding, capped at 32 new tokens.
sampling_params = SamplingParams(temperature=0, max_tokens=32)

prompts = [prompt]
lora_request = LoRARequest("test_lora", 1, ADAPTER)
outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)

first_result = outputs[0]
print(first_result.prompt)
print(first_result.outputs[0].text)

View File

@@ -0,0 +1,197 @@
"""
Usage: python3 scripts/playground/reference_hf.py --model-path MODEL_PATH --model-type {text,vlm} [--max-new-tokens NUM] [--dtype DTYPE]
--model-path MODEL_PATH: Path to model (default: TinyLlama/TinyLlama-1.1B-Chat-v0.4)
--model-type {text,vlm}: Model type, text or vlm (default: text)
--max-new-tokens NUM: Max new tokens to generate (default: 16)
--dtype DTYPE: Data type for computation (default: float16)
Note: '--model' is deprecated; use '--model-path'. Runs normal_text() for text, vlm_text_with_image() for vlm.
Reference output:
========== Prompt 0 ==========
prefill logits (final) tensor([-8.3125, -7.1172, 3.3398, ..., -4.9531, -4.1328, -3.4141],
device='cuda:0')
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.
========== Prompt 1 ==========
prefill logits (final) tensor([-8.9062, -9.0156, 4.1484, ..., -4.9922, -4.4961, -4.0742],
device='cuda:0')
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of
========== Prompt 2 ==========
prefill logits (final) tensor([-9.6328, -9.0547, 4.0234, ..., -5.3047, -4.7148, -4.4609],
device='cuda:0')
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the
"""
import argparse
import requests
import torch
from PIL import Image
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoProcessor,
)
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
@torch.no_grad()
def vlm_text_with_image(args):
    """Print reference logits and generated text for an image+text prompt.

    Args:
        args: Parsed CLI arguments; reads ``model_path``, ``dtype`` and
            ``max_new_tokens``.

    Raises:
        ValueError: If the loaded processor has no ``apply_chat_template``.
        requests.HTTPError: If an image URL answers with an error status.
    """
    # Load the processor and model for ImageTextToText tasks
    processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model_path,
        torch_dtype=args.dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True,
    )

    torch.cuda.set_device(0)

    # List of image URLs to process
    image_urls = [
        "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
    ]

    # Conversation template for the processor
    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                },
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]

    max_new_tokens = args.max_new_tokens

    # The text prompt does not depend on the image, so validate the processor
    # and build the prompt once, before any download happens.
    # Notice that not all processors support chat templates.
    # LLaVA and QWen are two processors that support chat templates.
    if not hasattr(processor, "apply_chat_template"):
        raise ValueError("The processor does not support chat templates.")
    text_prompt = processor.apply_chat_template(
        conversation, add_generation_prompt=True
    )

    for i, url in enumerate(image_urls):
        # Load the image from the URL. The timeout keeps a stalled connection
        # from hanging the script, and raise_for_status() reports an HTTP
        # failure directly instead of handing an error page to PIL.
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()
        image = Image.open(response.raw)

        # Prepare inputs for the model
        inputs = processor(text=[text_prompt], images=[image], return_tensors="pt").to(
            "cuda:0"
        )

        # Generate output from the model
        output_ids = model.generate(
            **inputs, do_sample=False, max_new_tokens=max_new_tokens
        )
        output_str = processor.decode(output_ids[0])

        # Get the logits from the model's forward pass
        outputs = model.forward(**inputs)
        logits = outputs.logits[0, -1, :]

        print(f"\n========== Image {i} ==========")
        print("prefill logits (final)", logits)
        # TODO(gaocegege): The output contains numerous <|image_pad|> tokens,
        # making it cluttered and difficult to read.
        # These tokens should be removed or cleaned up for better readability.
        print(output_str)
@torch.no_grad()
def normal_text(args):
    """Print final-position prefill logits and greedy output for fixed prompts.

    Args:
        args: Parsed CLI arguments; reads ``model_path``, ``dtype`` and
            ``max_new_tokens``.
    """
    tokenizer = get_tokenizer(args.model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=args.dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True,
    )

    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    max_new_tokens = args.max_new_tokens

    torch.cuda.set_device(0)

    for i, p in enumerate(prompts):
        # A prompt is either raw text or an already-tokenized id list.
        input_ids = (
            tokenizer.encode(p, return_tensors="pt").to("cuda:0")
            if isinstance(p, str)
            else torch.tensor([p], device="cuda:0")
        )

        generated = model.generate(
            input_ids, do_sample=False, max_new_tokens=max_new_tokens
        )
        output_str = tokenizer.decode(generated[0])

        # Logits at the last prompt position from a plain forward pass.
        prefill_logits = model.forward(input_ids).logits[0][-1]

        print(f"\n========== Prompt {i} ==========")
        print("prefill logits (final)", prefill_logits)
        print(output_str)
@torch.no_grad()
def synthetic_tokens(args):
    """Greedy-decode from a synthetic token-id prompt, printing each step's logits.

    Every step re-runs a full forward pass over all ids accumulated so far,
    prints the logits at the last position, and appends the argmax token.

    Args:
        args: Parsed CLI arguments; reads ``model_path``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    model.cuda()
    print(model)

    input_len, output_len = 256, 8
    prompts = [list(range(5, 5 + input_len))]

    for token_ids in prompts:
        # token_ids grows in place by one token per step.
        for step in range(output_len + 1):
            batch = torch.tensor([token_ids], device="cuda")
            last_logits = model.forward(batch).logits[0][-1]

            if step == 0:
                print("prefill logits", last_logits)
            else:
                print("decode", step - 1, last_logits)

            token_ids.append(torch.argmax(last_logits).item())
if __name__ == "__main__":
    # CLI entry point: parse options, then run the VLM or text reference.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v0.4"
    )
    parser.add_argument("--max-new-tokens", type=int, default=16)
    parser.add_argument("--dtype", type=str, default="float16")
    parser.add_argument("--model-type", type=str, default="text")
    args = parser.parse_args()

    # Anything other than "vlm" falls back to the plain-text reference.
    runner = vlm_text_with_image if args.model_type == "vlm" else normal_text
    runner(args)

View File

@@ -0,0 +1,181 @@
"""
Usage:
# replay from a folder
python3 replay_request_dump.py --file-number 100 --parallel 512 --input-folder /data/lianmin/sglang_request_dump/grok-mini-0220-engine-5756f8f94-28bm6/
# replay from a single file
python3 replay_request_dump.py --parallel 512 --input-file /data/sglang_crash_dump/memx-cti-34-sr1.xpop.twttr.net/crash_dump_2025-06-04_20-13-18.pkl
"""
import argparse
import glob
import json
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from datetime import datetime
import requests
from sglang.benchmark.utils import set_ulimit
from sglang.srt.utils.common import safe_pickle_load
from sglang.utils import get_exception_traceback
def normalize_mm_data_item(item):
    """Return ``item["url"]`` for a dict that has a "url" key, else ``item`` unchanged."""
    has_url = isinstance(item, dict) and "url" in item
    return item["url"] if has_url else item
def normalize_mm_data(data):
    """Apply ``normalize_mm_data_item`` across a multimodal field value.

    ``None`` stays ``None``. A list is mapped element by element, with one
    level of nested lists mapped item by item. Any other value is normalized
    as a single item.
    """
    if data is None:
        return None
    if not isinstance(data, list):
        return normalize_mm_data_item(data)

    normalized = []
    for entry in data:
        if isinstance(entry, list):
            normalized.append([normalize_mm_data_item(item) for item in entry])
        else:
            normalized.append(normalize_mm_data_item(entry))
    return normalized
def normalize_request_data(json_data):
    """Normalize multimodal fields in request data for replay compatibility.

    Rewrites the image/video/audio fields of ``json_data`` in place when they
    are present and not ``None``, and returns the same dict.
    """
    for field in ("image_data", "video_data", "audio_data"):
        if json_data.get(field) is None:
            # Field missing or explicitly None: leave it as is.
            continue
        json_data[field] = normalize_mm_data(json_data[field])
    return json_data
def read_records(files):
    """Load and concatenate the request records stored in the given pickle files.

    A file's payload is either a dict holding the records under "requests"
    or an iterable of records.
    """
    records = []
    for path in files:
        with open(path, "rb") as handle:
            payload = safe_pickle_load(handle)
        wrapped = isinstance(payload, dict) and "requests" in payload
        records.extend(payload["requests"] if wrapped else payload)
    return records
def run_one_request_internal(record):
    """Replay one recorded request against the server's /generate endpoint.

    Args:
        record: ``(req, output, replay_init_time, start_time, end_time, idx)``
            as built by ``main``; ``start_time`` is the request's offset from
            the first recorded request.

    Raises:
        RuntimeError: On a non-success HTTP status, or when a streamed
            response contains no data payload.
    """
    req, output, replay_init_time, start_time, end_time, idx = record

    # Fire the request at replay time `start_time / speed`. The time already
    # elapsed in the replay is wall-clock time and must not be scaled by the
    # speed factor, otherwise requests picked up late drift when speed != 1.
    elapsed = time.time() - replay_init_time
    time.sleep(max(0, start_time / args.speed - elapsed))

    # "" means the dump recorded no completion token count for this request.
    recorded_completion_tokens = output.get("meta_info", {}).get(
        "completion_tokens", ""
    )

    json_data = normalize_request_data(asdict(req))
    stream = json_data["stream"]

    if args.ignore_eos:
        json_data["sampling_params"]["ignore_eos"] = True
        if recorded_completion_tokens:
            # With EOS ignored, cap generation at the recorded length.
            json_data["sampling_params"]["max_new_tokens"] = recorded_completion_tokens

    ret = None
    # The context manager returns the connection to the pool even when the
    # stream is abandoned early at "[DONE]".
    with requests.post(
        f"http://{args.host}:{args.port}/generate",
        json=json_data,
        stream=stream,
    ) as response:
        if not response.ok:
            raise RuntimeError(
                f"/generate returned HTTP {response.status_code}: {response.text}"
            )

        if stream:
            for chunk in response.iter_lines(decode_unicode=False):
                chunk = chunk.decode("utf-8")
                if chunk and chunk.startswith("data:"):
                    if chunk == "data: [DONE]":
                        break
                    # Keep only the latest payload; it carries the final meta_info.
                    ret = json.loads(chunk[5:].strip("\n"))
        else:
            ret = response.json()

    if ret is None:
        raise RuntimeError("streamed response contained no data payload")

    prompt_tokens = ret["meta_info"]["prompt_tokens"]
    completion_tokens = ret["meta_info"]["completion_tokens"]
    print(
        f"{idx=}, {start_time=:.2f}, {prompt_tokens=}, "
        f"{completion_tokens=}, {recorded_completion_tokens=}"
    )
def run_one_request(record):
    """Replay one record, printing (not raising) any exception it hits."""
    # Success/error counters (success_ct, error_ct) are intentionally disabled.
    try:
        run_one_request_internal(record)
    except Exception:
        print(f"Hit an exception: {get_exception_traceback()}")
def main(records):
    """Replay all records concurrently, keeping their relative start times.

    Each ``(req, output, start_time, end_time)`` entry is rewritten in place
    to ``(req, output, replay_init_time, offset, end_time, index)``, where
    ``offset`` is the start time relative to the first record.
    """
    if not records:
        return

    base_time = records[0][-2]
    base_time_str = datetime.fromtimestamp(base_time).strftime("%y-%m-%d %H:%M:%S")
    print(f"{base_time_str=}")

    replay_init_time = time.time()
    for i, (req, output, start_time, end_time) in enumerate(records):
        offset = start_time - base_time
        records[i] = (req, output, replay_init_time, offset, end_time, i)

    with ThreadPoolExecutor(args.parallel) as executor:
        executor.map(run_one_request, records)
if __name__ == "__main__":
    # CLI entry point: select dump files, load and window the records, replay.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=30000)
    parser.add_argument(
        "--input-folder", type=str, default=None, help="Folder containing pickle files"
    )
    parser.add_argument(
        "--input-file", type=str, default=None, help="Single pickle file to process"
    )
    parser.add_argument("--file-number", type=int, default=1)
    parser.add_argument("--req-number", type=int, default=1000000)
    parser.add_argument("--req-start", type=int, default=0)
    parser.add_argument("--parallel", type=int, default=512)
    parser.add_argument("--idx", type=int, default=None)
    parser.add_argument("--ignore-eos", action="store_true")
    parser.add_argument("--speed", type=float, default=1)
    args = parser.parse_args()

    set_ulimit()

    files = []
    if args.input_file:
        files = [args.input_file]
        if args.file_number > 1:
            print("Warning: --file-number is ignored when --input-file is provided.")
    elif args.input_folder:
        # Sort so that --file-number picks the same files on every run;
        # glob returns matches in arbitrary order.
        files = sorted(glob.glob(f"{args.input_folder}/*.pkl"))
        files = files[: args.file_number]
    else:
        print("Error: Either --input-folder or --input-file must be provided.")
        # SystemExit instead of the interactive-only `exit` helper.
        raise SystemExit(1)
    print(f"{files=}")

    records = read_records(files)

    # Sort by the receive time, before filtering
    records.sort(key=lambda x: x[-2])

    # Window: skip --req-start records, then keep at most --req-number.
    records = records[args.req_start : args.req_start + args.req_number]

    # `is not None` so that `--idx 0` selects the first record.
    if args.idx is not None:
        records = [records[args.idx]]
        print(f"testing {args.idx=}")
        print(f"{records[0]}")

    print(f"{len(records)=}")
    main(records)

View File

@@ -0,0 +1,207 @@
import random
import string
import time
import unittest
from typing import Dict, List, Tuple
from tree import MultiTenantRadixTree
class TestMultiTenantRadixTree(unittest.TestCase):
    """Insert, prefix-match and eviction tests for MultiTenantRadixTree."""

    def setUp(self):
        # Every test starts from an empty tree.
        self.tree = MultiTenantRadixTree()

    def test_insert_exact_match(self):
        """Test 1: Basic insert and exact match operations"""
        # Insert a single string for one tenant
        self.tree.insert("hello", "tenant1")
        matched, tenant = self.tree.prefix_match("hello")
        self.assertEqual(matched, "hello")
        self.assertEqual(tenant, "tenant1")

        # Insert same string for different tenant
        self.tree.insert("hello", "tenant2")
        matched, tenant = self.tree.prefix_match("hello")
        self.assertEqual(matched, "hello")
        self.assertIn(tenant, ["tenant1", "tenant2"])

        # Insert different string for same tenant
        self.tree.insert("world", "tenant1")
        matched, tenant = self.tree.prefix_match("world")
        self.assertEqual(matched, "world")
        self.assertEqual(tenant, "tenant1")

        print(self.tree.pretty_print())

    def test_insert_partial_match(self):
        """Test 2: Insert with partial matching scenarios"""
        # Test partial matches with common prefixes
        self.tree.insert("hello", "tenant1")
        print(self.tree.pretty_print())
        self.tree.insert("help", "tenant2")
        print(self.tree.pretty_print())

        # Match exact strings
        matched, tenant = self.tree.prefix_match("hello")
        self.assertEqual(matched, "hello")
        self.assertEqual(tenant, "tenant1")

        matched, tenant = self.tree.prefix_match("help")
        self.assertEqual(matched, "help")
        self.assertEqual(tenant, "tenant2")

        # Match partial string
        matched, tenant = self.tree.prefix_match("hel")
        self.assertEqual(matched, "hel")
        self.assertIn(tenant, ["tenant1", "tenant2"])

        # Match longer string
        matched, tenant = self.tree.prefix_match("hello_world")
        self.assertEqual(matched, "hello")
        self.assertEqual(tenant, "tenant1")

    def test_insert_edge_cases(self):
        """Test 3: Edge cases for insert and match operations"""
        # Empty string
        self.tree.insert("", "tenant1")
        matched, tenant = self.tree.prefix_match("")
        self.assertEqual(matched, "")
        self.assertEqual(tenant, "tenant1")

        # Single character
        self.tree.insert("a", "tenant1")
        matched, tenant = self.tree.prefix_match("a")
        self.assertEqual(matched, "a")
        self.assertEqual(tenant, "tenant1")

        # Very long string
        long_str = "a" * 1000
        self.tree.insert(long_str, "tenant1")
        matched, tenant = self.tree.prefix_match(long_str)
        self.assertEqual(matched, long_str)
        self.assertEqual(tenant, "tenant1")

        # Unicode characters
        self.tree.insert("你好", "tenant1")
        matched, tenant = self.tree.prefix_match("你好")
        self.assertEqual(matched, "你好")
        self.assertEqual(tenant, "tenant1")

    def test_simple_eviction(self):
        """Test 4: Simple eviction scenarios

        Tenant1: limit 10 chars
        Tenant2: limit 5 chars

        Should demonstrate:
        1. Basic eviction when size limit exceeded
        2. Proper eviction based on last access time
        3. Verification that shared nodes remain intact for other tenants
        """
        # Set up size limits
        max_size = {"tenant1": 10, "tenant2": 5}

        # Insert strings for both tenants
        self.tree.insert("hello", "tenant1")  # size 5
        self.tree.insert("hello", "tenant2")  # size 5
        self.tree.insert("world", "tenant2")  # size 5, total for tenant2 = 10

        # Verify initial sizes
        sizes_before = self.tree.get_used_size_per_tenant()
        self.assertEqual(sizes_before["tenant1"], 5)  # "hello" = 5
        self.assertEqual(sizes_before["tenant2"], 10)  # "hello" + "world" = 10

        # Evict - should remove "hello" from tenant2 as it's the oldest
        self.tree.evict_tenant_data(max_size)

        # Verify sizes after eviction
        sizes_after = self.tree.get_used_size_per_tenant()
        self.assertEqual(sizes_after["tenant1"], 5)  # Should be unchanged
        self.assertEqual(sizes_after["tenant2"], 5)  # Only "world" remains

        # Verify "world" remains for tenant2 (was accessed more recently)
        matched, tenant = self.tree.prefix_match("world")
        self.assertEqual(matched, "world")
        self.assertEqual(tenant, "tenant2")

    def test_medium_eviction(self):
        """Test 5: Medium complexity eviction scenarios with shared prefixes

        Tenant1: limit 10 chars
        Tenant2: limit 6 chars (forces one string to be evicted)

        Tree structure after inserts:
        └── 'h' [t1, t2]
            ├── 'i' [t1, t2]    # Oldest for t2
            └── 'e' [t1, t2]
                ├── 'llo' [t1, t2]
                └── 'y' [t2]    # Newest for t2

        Size calculations:
        tenant1: "h"(1) + "i"(1) + "e"(1) + "llo"(3) = 6 chars
        tenant2: "h"(1) + "i"(1) + "e"(1) + "llo"(3) + "y"(1) = 7 chars

        After eviction (tenant2 exceeds limit by 1 char):
        "hi" should be removed from tenant2 as it's the oldest access
        """
        max_size = {
            "tenant1": 10,
            "tenant2": 6,
        }  # tenant2 will need to evict one string

        # Create a tree with overlapping prefixes
        self.tree.insert("hi", "tenant1")
        self.tree.insert("hi", "tenant2")  # OLDEST for t2

        self.tree.insert("hello", "tenant1")
        self.tree.insert("hello", "tenant2")

        self.tree.insert("hey", "tenant2")  # NEWEST for t2

        # Verify initial sizes
        sizes_before = self.tree.get_used_size_per_tenant()
        self.assertEqual(sizes_before["tenant1"], 6)  # h(1) + i(1) + e(1) + llo(3) = 6
        self.assertEqual(
            sizes_before["tenant2"], 7
        )  # h(1) + i(1) + e(1) + llo(3) + y(1) = 7

        print("\nTree before eviction:")
        print(self.tree.pretty_print())

        # Evict - should remove "hi" from tenant2 as it's the oldest
        self.tree.evict_tenant_data(max_size)

        print("\nTree after eviction:")
        print(self.tree.pretty_print())

        # Verify sizes after eviction
        sizes_after = self.tree.get_used_size_per_tenant()
        self.assertEqual(sizes_after["tenant1"], 6)  # Should be unchanged
        self.assertEqual(sizes_after["tenant2"], 6)  # h(1) + e(1) + llo(3) + y(1) = 6

    def test_advanced_eviction(self):
        """Test 6: Eviction under load with four tenants

        Each tenant inserts 100 strings made of a tenant-specific prefix plus
        a random 10-letter suffix, then every tenant must be at or below its
        100-char limit after eviction.
        """
        max_size = {"tenant1": 100, "tenant2": 100, "tenant3": 100, "tenant4": 100}

        prefixes = ["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"]

        # Private, seeded generator: the inserted data (and therefore any
        # failure) is reproducible, and the global random state is untouched.
        rng = random.Random(0)
        for _ in range(100):
            for j, prefix in enumerate(prefixes):
                random_suffix = "".join(rng.choices(string.ascii_letters, k=10))
                self.tree.insert(prefix + random_suffix, f"tenant{j+1}")

        sizes_before = self.tree.get_used_size_per_tenant()
        print(sizes_before)

        self.tree.evict_tenant_data(max_size)

        sizes_after = self.tree.get_used_size_per_tenant()
        print(sizes_after)

        # ensure size_after is below max_size
        for tenant, size in sizes_after.items():
            self.assertLessEqual(size, max_size[tenant])


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,292 @@
import time
from collections import defaultdict
from typing import Dict, List
class Node:
    """One edge-labelled node of the multi-tenant radix tree."""

    def __init__(self):
        # Link to the parent node; stays None until the node is attached.
        self.parent = None
        # Edge label stored as plain text: most of the use cases are
        # text-to-text, so this saves the tokenizing overhead.
        self.text: str = ""
        # Child nodes; the tree code keys them by the first character of
        # the child's text.
        self.children: Dict[str, Node] = {}
        # tenant_id -> timestamp of that tenant's last access to this node.
        self.tenant_last_access_time: Dict[str, float] = {}
def shared_prefix_length(s1, s2):
    """Return the length of the longest common prefix of ``s1`` and ``s2``."""
    matched = 0
    # zip stops at the shorter input, which bounds the result by min(len).
    for left, right in zip(s1, s2):
        if left != right:
            break
        matched += 1
    return matched
class MultiTenantRadixTree:
    """
    Python Reference of Rust implementation of MultiTenantRadixTree

    MultiTenantRadixTree is the overlap of multiple radix trees by different tenant
    Each node in the tree can be owned by multiple tenants, allowing for efficient storage of common prefixes
    while maintaining tenant isolation.

    Key concepts:
    - Tenant: An entity that owns a subset of the stored strings
    - Each node tracks which tenants have access to it via tenant_last_access_time
    - The tree structure is shared, but queries can be filtered by tenant_id

    Invariant maintained by every operation here: a tenant recorded on a node
    is also recorded on all of that node's ancestors.
    """

    def __init__(self):
        self.root = Node()

    def insert(self, s: str, tenant_id: str) -> None:
        """
        Insert string 's' and associate it with the given tenant_id.

        Args:
            s: The string to insert
            tenant_id: The identifier of the tenant who owns this string
        """
        curr = self.root
        curr_idx = 0
        curr.tenant_last_access_time[tenant_id] = time.time()

        while curr_idx < len(s):
            matched_node = curr.children.get(s[curr_idx])

            if matched_node is None:
                # No match => create a new node holding the whole remainder
                new_node = Node()
                new_node.text = s[curr_idx:]
                new_node.parent = curr

                curr.children[s[curr_idx]] = new_node
                curr_idx = len(s)
                curr = new_node
            else:
                shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)

                if shared_len < len(matched_node.text):
                    # 1. The matched text is shorter than the node text => split the node
                    # Split structure: [matched_node] => [new_node] -> [contracted_matched_node]
                    matched_text = matched_node.text[:shared_len]
                    unmatched_text = matched_node.text[shared_len:]

                    new_node = Node()
                    new_node.text = matched_text
                    new_node.children = {unmatched_text[0]: matched_node}
                    new_node.parent = curr
                    curr.children[matched_text[0]] = new_node
                    # The upper half inherits every tenant of the node it was cut from.
                    new_node.tenant_last_access_time = (
                        matched_node.tenant_last_access_time.copy()
                    )

                    # Contract matched node
                    matched_node.text = unmatched_text
                    matched_node.parent = new_node

                    curr = new_node
                else:
                    # 2. The matched text covers the whole node text => walk down the node
                    curr = matched_node

                curr_idx += shared_len

            curr.tenant_last_access_time[tenant_id] = time.time()

    def prefix_match(self, s: str) -> "tuple[str, str | None]":
        """
        Match string 's' with multiple tenants' trees in one operation.

        Args:
            s: The string to match

        Returns:
            Tuple(str, str | None): The longest prefix of 's' that matches the tree
            and the first tenant_id that owns the matched prefix. The tenant is
            None when the node reached has no tenant at all (e.g. an empty tree).
        """
        curr = self.root
        curr_idx = 0

        while curr_idx < len(s):
            matched_node = curr.children.get(s[curr_idx])
            if matched_node is None:
                break

            shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)
            curr_idx += shared_len
            curr = matched_node
            if shared_len < len(matched_node.text):
                # Partial match inside this node: cannot descend further.
                break

        if not curr.tenant_last_access_time:
            # No owner to report; previously this raised IndexError.
            return s[:curr_idx], None

        selected_tenant = next(iter(curr.tenant_last_access_time))

        # traverse back to the root to update last access time for the selected tenant
        while curr is not self.root:
            curr.tenant_last_access_time[selected_tenant] = time.time()
            curr = curr.parent

        return s[:curr_idx], selected_tenant

    def evict_tenant_data(self, max_size_per_tenant: Dict[str, int]) -> None:
        """
        Evict data for tenants that have exceeded their storage limits.

        Leaves are evicted oldest-access-first, per tenant, until the tenant is
        at or below its limit. A tenant without an entry in max_size_per_tenant
        is treated as unlimited and is never evicted.

        Args:
            max_size_per_tenant: Dictionary mapping tenant_id to their maximum allowed storage size
        """
        # maintain a heap with (time, tenant, seq, node) as the value
        import heapq

        def leaf_of(node):
            """
            Return the tenants for which `node` is a leaf, i.e. tenants recorded
            on the node but on none of its children. Returns [] if the node is
            not a leaf for any tenant.
            """
            in_children = set()
            for child in node.children.values():
                in_children.update(child.tenant_last_access_time)
            return [t for t in node.tenant_last_access_time if t not in in_children]

        pq = []
        # Unique, increasing tie-breaker placed before the node in each heap
        # entry: two entries with equal (time, tenant) are then ordered by it
        # and Node objects (which define no ordering) are never compared.
        seq = 0
        used_size_per_tenant = defaultdict(int)

        # 1. traverse the tree (dfs with stack) to
        #   a. add all the leaves into a heap (a node with N tenants will be added N times into the heap)
        #   b. calculate the used size for each tenant
        stack = [self.root]
        while stack:
            curr = stack.pop()
            for t in curr.tenant_last_access_time:
                used_size_per_tenant[t] += len(curr.text)
            stack.extend(curr.children.values())

            # if the node is a leaf for a tenant, add the tenant to the heap
            for t in leaf_of(curr):
                heapq.heappush(pq, (curr.tenant_last_access_time[t], t, seq, curr))
                seq += 1

        # 2. pop the heap
        #   a. if the tenant's used size is within the limit, continue
        #   b. otherwise remove the leaf, update the used size, and add its
        #      parent to the heap if the parent became a leaf for the tenant
        while pq:
            _, tenant, _, node = heapq.heappop(pq)
            limit = max_size_per_tenant.get(tenant)
            if limit is None or used_size_per_tenant[tenant] <= limit:
                continue

            # remove the leaf
            used_size_per_tenant[tenant] -= len(node.text)
            del node.tenant_last_access_time[tenant]

            parent = node.parent
            if parent is None:
                # The root has no parent to detach from or to re-queue.
                continue

            # if no children and no tenants, remove the node
            if not node.children and not node.tenant_last_access_time:
                del parent.children[node.text[0]]

            # add its parent to the heap
            if tenant in leaf_of(parent):
                heapq.heappush(
                    pq, (parent.tenant_last_access_time[tenant], tenant, seq, parent)
                )
                seq += 1

    def get_used_size_per_tenant(self) -> Dict[str, int]:
        """
        Calculate the used storage size for each tenant.

        Returns:
            Dict[str, int]: A dictionary mapping tenant_id to their used storage size
        """
        used_size_per_tenant = defaultdict(int)

        stack = [self.root]
        while stack:
            curr = stack.pop()
            for t in curr.tenant_last_access_time:
                used_size_per_tenant[t] += len(curr.text)
            stack.extend(curr.children.values())

        return used_size_per_tenant

    def remove_tenant(self, tenant_id: str) -> None:
        """
        Remove all data associated with a specific tenant from the tree.

        This operation maintains the integrity of the shared tree structure while
        removing only the specified tenant's access information. Nodes left with
        no tenant at all are detached from the tree.

        Args:
            tenant_id: The identifier of the tenant whose data should be removed
        """
        stack = [self.root]
        while stack:
            node = stack.pop()
            if tenant_id not in node.tenant_last_access_time:
                # By the ancestor invariant nothing below this node belongs
                # to the tenant either.
                continue

            del node.tenant_last_access_time[tenant_id]

            if node.parent is not None and not node.tenant_last_access_time:
                # No owner left; by the same invariant the whole subtree is
                # ownerless, so drop it in one step.
                del node.parent.children[node.text[0]]
                continue

            stack.extend(node.children.values())

    def pretty_print(self) -> str:
        """
        Returns a string representation of the tree showing the structure and,
        for each node, its tenants with their last access time.

        Returns:
            str: A formatted string showing the tree hierarchy with tenant information
        """

        def _node_to_str(node: Node, prefix: str = "", is_last: bool = True) -> str:
            # Current node representation
            node_str = prefix
            node_str += "└── " if is_last else "├── "

            # Add node text
            node_str += f"'{node.text}' ["

            # Add each tenant with its last access time (HH:MM:SS.mmm)
            tenant_info = []
            for tid, ts in node.tenant_last_access_time.items():
                time_str = (
                    time.strftime("%H:%M:%S.", time.localtime(ts))
                    + f"{(ts % 1):0.3f}"[2:]
                )
                tenant_info.append(f"{tid} | {time_str}")

            node_str += ", ".join(tenant_info)
            node_str += "]\n"

            # Children are indented one level (four columns); a vertical bar
            # is kept under a node that still has siblings below it.
            new_prefix = prefix + ("    " if is_last else "│   ")
            children = list(node.children.values())
            for i, child in enumerate(children):
                is_last_child = i == len(children) - 1
                node_str += _node_to_str(child, new_prefix, is_last_child)

            return node_str

        if not self.root.children:
            return "Empty tree"

        # Start with root's children since root itself is just an empty node
        result = ""
        children = list(self.root.children.values())
        for i, child in enumerate(children):
            is_last = i == len(children) - 1
            result += _node_to_str(child, "", is_last)

        return result

View File

@@ -0,0 +1,96 @@
# Release Scripts
This directory contains scripts to automate version bumping for SGLang releases.
## Scripts
### `bump_sglang_version.py`
Updates SGLang version across all relevant files following the pattern from [PR #10468](https://github.com/sgl-project/sglang/pull/10468).
**Usage:**
```bash
python scripts/release/bump_sglang_version.py 0.5.3rc0
```
**Files updated:**
- `Makefile`
- `benchmark/deepseek_v3/README.md`
- `docker/rocm.Dockerfile`
- `docs/get_started/install.md`
- `docs/platforms/amd_gpu.md`
- `docs/platforms/ascend_npu.md`
- `python/pyproject.toml`
- `python/pyproject_other.toml`
- `python/pyproject_npu.toml`
- `python/sglang/version.py`
### `bump_kernel_version.py`
Updates the `sglang-kernel` release version across all relevant files following the pattern from [PR #10732](https://github.com/sgl-project/sglang/pull/10732).
**Usage:**
```bash
python scripts/release/bump_kernel_version.py 0.4.0
```
**Files updated:**
- `sgl-kernel/pyproject.toml`
- `sgl-kernel/pyproject_cpu.toml`
- `sgl-kernel/pyproject_rocm.toml`
- `sgl-kernel/pyproject_musa.toml`
- `sgl-kernel/python/sgl_kernel/version.py`
## Manual Testing Instructions
### Test SGLang Version Bump
1. **Run the script:**
```bash
python scripts/release/bump_sglang_version.py 0.5.4rc0
```
2. **Verify changes with git diff:**
```bash
git diff
```
3. **Check specific files contain the new version:**
```bash
grep -r "0.5.4rc0" python/sglang/version.py
grep -r "0.5.4rc0" python/pyproject.toml
grep -r "0.5.4rc0" docs/get_started/install.md
```
4. **Reset changes (if testing):**
```bash
git checkout .
```
### Test Kernel Version Bump
1. **Run the script:**
```bash
python scripts/release/bump_kernel_version.py 0.4.0
```
2. **Verify changes with git diff:**
```bash
git diff
```
3. **Check specific files contain the new version:**
```bash
grep -r "0.4.0" sgl-kernel/python/sgl_kernel/version.py
grep -r "0.4.0" sgl-kernel/pyproject.toml
```
4. **Reset changes (if testing):**
```bash
git checkout .
```
## Version Format Validation
- **SGLang versions:** `X.Y.Z` or `X.Y.ZrcN` (e.g., `0.5.3` or `0.5.3rc0`)
- **Kernel versions:** `X.Y.Z` (e.g., `0.4.0`)
The scripts will validate the version format and exit with an error if invalid.

View File

@@ -0,0 +1,148 @@
#!/usr/bin/env python3
import argparse
import re
import sys
from pathlib import Path
from utils import compare_versions, get_repo_root, normalize_version, validate_version
# Repo-relative files that carry a flashinfer version reference. Each entry is
# rewritten by replace_flashinfer_version() and re-checked by main().
FILES_TO_UPDATE = [
    Path(relative)
    for relative in (
        "python/pyproject.toml",
        "docker/Dockerfile",
        "python/sglang/srt/entrypoints/engine.py",
        "python/sglang/srt/utils/common.py",
    )
]
def read_current_flashinfer_version(repo_root: Path) -> str:
    """Return the flashinfer version pinned in ``python/pyproject.toml``.

    Args:
        repo_root: Repository root; the file is read from
            ``<repo_root>/python/pyproject.toml``.

    Returns:
        The version text following ``flashinfer_python==`` (``X.Y.Z``,
        ``X.Y.ZrcN`` or ``X.Y.Z.postN``).

    Raises:
        ValueError: If no ``flashinfer_python==<version>`` pin is found.
    """
    pyproject = repo_root / "python" / "pyproject.toml"
    pin = re.compile(r"flashinfer_python==(\d+\.\d+\.\d+(?:rc\d+|\.post\d+)?)")
    found = pin.search(pyproject.read_text())
    if found is None:
        raise ValueError(f"Could not find flashinfer_python version in {pyproject}")
    return found.group(1)
def replace_flashinfer_version(
    file_path: Path, old_version: str, new_version: str
) -> bool:
    """Rewrite the flashinfer version reference inside one known file.

    The rewrite rule is selected by the file's base name:

    * ``pyproject.toml``: the ``flashinfer_python==`` and
      ``flashinfer_cubin==`` pins.
    * ``Dockerfile``: the ``ARG FLASHINFER_VERSION=`` line.
    * ``engine.py``: the version literal passed to
      ``assert_pkg_version("flashinfer_python", ...)``.
    * ``common.py``: the ``e.g., "<version>"`` example text.

    Any other file name is left untouched.

    Args:
        file_path: File to rewrite in place.
        old_version: Version string currently present in the file.
        new_version: Version string to write instead.

    Returns:
        True if the file was modified, False if it is missing or unchanged.
    """
    if not file_path.exists():
        print(f"Warning: {file_path} does not exist, skipping")
        return False

    original = file_path.read_text()
    updated = original
    kind = file_path.name

    if kind == "pyproject.toml":
        # Both packages are pinned to the same version; rewrite them in turn.
        for package in ("flashinfer_python", "flashinfer_cubin"):
            updated = updated.replace(
                f"{package}=={old_version}", f"{package}=={new_version}"
            )
    elif kind == "Dockerfile":
        arg_pin = re.compile(rf"(ARG FLASHINFER_VERSION=){re.escape(old_version)}")
        updated = arg_pin.sub(rf"\g<1>{new_version}", updated)
    elif kind == "engine.py":
        # The call may be wrapped across lines, hence \s* between arguments.
        call_pin = re.compile(
            r'(assert_pkg_version\(\s*"flashinfer_python",\s*)"'
            + re.escape(old_version)
            + r'"',
            flags=re.DOTALL,
        )
        updated = call_pin.sub(r'\g<1>"' + new_version + '"', updated)
    elif kind == "common.py":
        updated = updated.replace(
            f'e.g., "{old_version}"',
            f'e.g., "{new_version}"',
        )

    if updated == original:
        print(f"No changes needed in {file_path}")
        return False

    file_path.write_text(updated)
    print(f"✓ Updated {file_path}")
    return True
def main():
    """Bump the flashinfer version in every file listed in FILES_TO_UPDATE.

    Steps: parse and validate the requested version, read the current version
    from python/pyproject.toml, refuse a version that is not strictly newer,
    rewrite each file, then re-read each file to confirm it contains the new
    version. Exits with status 1 on any validation failure.
    """
    parser = argparse.ArgumentParser(
        description="Bump flashinfer version across all relevant files"
    )
    parser.add_argument(
        "new_version",
        help="New version (e.g., 0.6.4, 0.6.4rc0, or 0.6.4.post1)",
    )
    new_version = normalize_version(parser.parse_args().new_version)

    if not validate_version(new_version):
        print(f"Error: Invalid version format: {new_version}")
        print("Expected format: X.Y.Z, X.Y.ZrcN, or X.Y.Z.postN")
        print("Examples: 0.6.4, 0.6.4rc0, 0.6.4.post1")
        sys.exit(1)

    repo_root = get_repo_root()
    old_version = read_current_flashinfer_version(repo_root)

    print(f"Current flashinfer version: {old_version}")
    print(f"New flashinfer version: {new_version}")
    print()

    # Only a strictly newer version is accepted.
    order = compare_versions(new_version, old_version)
    if order == 0:
        print("Error: New version is the same as current version")
        sys.exit(1)
    if order < 0:
        print(
            f"Error: New version ({new_version}) is older than current version ({old_version})"
        )
        print("Version must be greater than the current version")
        sys.exit(1)

    targets = [(relative, repo_root / relative) for relative in FILES_TO_UPDATE]
    updated_count = sum(
        1
        for _, absolute in targets
        if replace_flashinfer_version(absolute, old_version, new_version)
    )

    print()
    print(f"Successfully updated {updated_count} file(s)")
    print(f"Flashinfer version bumped from {old_version} to {new_version}")

    # Second pass: every existing file must now mention the new version.
    print("\nValidating version updates...")
    failed_files = []
    for relative, absolute in targets:
        if not absolute.exists():
            print(f"Warning: File {relative} does not exist, skipping validation.")
            continue

        if new_version in absolute.read_text():
            print(f"{relative} validated")
        else:
            failed_files.append(relative)
            print(f"{relative} does not contain version {new_version}")

    if failed_files:
        print(f"\nError: {len(failed_files)} file(s) were not updated correctly:")
        for relative in failed_files:
            print(f" - {relative}")
        sys.exit(1)

    print("\nAll files validated successfully!")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,33 @@
#!/usr/bin/env python3
import argparse
from pathlib import Path
from utils import bump_version
def main():
    """Parse the requested sgl-kernel version and delegate to ``bump_version``.

    ``sgl-kernel/python/sgl_kernel/version.py`` is passed as the version file,
    and is also one of the files rewritten, alongside the four
    ``sgl-kernel/pyproject*.toml`` files.
    """
    parser = argparse.ArgumentParser(
        description="Bump sgl-kernel version across all relevant files"
    )
    parser.add_argument(
        "new_version",
        help="New version (e.g., 0.3.12, 0.3.11rc0, or 0.3.11.post1)",
    )
    args = parser.parse_args()

    kernel_dir = Path("sgl-kernel")
    version_file = kernel_dir / "python" / "sgl_kernel" / "version.py"

    toml_names = (
        "pyproject.toml",
        "pyproject_cpu.toml",
        "pyproject_rocm.toml",
        "pyproject_musa.toml",
    )
    files_to_update = [kernel_dir / toml_name for toml_name in toml_names]
    files_to_update.append(kernel_dir / "python" / "sgl_kernel" / "version.py")

    bump_version(args.new_version, version_file, files_to_update)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,143 @@
#!/usr/bin/env python3
"""
Bump sglang-kernel version in SGLang files to match the version in sgl-kernel/pyproject.toml.
Updates:
- python/pyproject.toml
- python/sglang/srt/entrypoints/engine.py
- docker/Dockerfile
"""
import re
import sys
from pathlib import Path
try:
import tomllib # Python 3.11+
except ImportError:
import tomli as tomllib # Fallback for older Python versions
def get_kernel_version_from_source() -> str:
"""Extract version from sgl-kernel/pyproject.toml"""
pyproject_path = Path("sgl-kernel/pyproject.toml")
if not pyproject_path.exists():
print(f"Error: {pyproject_path} not found")
sys.exit(1)
with open(pyproject_path, "rb") as f:
data = tomllib.load(f)
version = data.get("project", {}).get("version")
if not version:
print("Error: Could not find version in sgl-kernel/pyproject.toml")
sys.exit(1)
return version
def update_python_pyproject(new_version: str) -> bool:
    """Point the ``"sglang-kernel==..."`` pin in python/pyproject.toml at ``new_version``.

    The path is relative to the current working directory; the process exits
    with status 1 if the file is missing.

    Returns:
        True if the file was rewritten, False if it was already up to date.
    """
    target = Path("python/pyproject.toml")
    if not target.exists():
        print(f"Error: {target} not found")
        sys.exit(1)

    before = target.read_text()

    # Every quoted "sglang-kernel==<anything>" requirement is replaced.
    pin = re.compile(r'"sglang-kernel==[^"]+"')
    after = pin.sub(f'"sglang-kernel=={new_version}"', before)

    if after == before:
        print("No changes needed in python/pyproject.toml")
        return False

    target.write_text(after)
    print(f"✓ Updated python/pyproject.toml to version {new_version}")
    return True
def update_engine_py(new_version: str) -> bool:
    """Rewrite the version passed to ``assert_pkg_version("sglang-kernel", ...)``.

    Operates on python/sglang/srt/entrypoints/engine.py relative to the
    current working directory; the process exits with status 1 if the file is
    missing.

    Returns:
        True if the file was rewritten, False if it was already up to date.
    """
    target = Path("python/sglang/srt/entrypoints/engine.py")
    if not target.exists():
        print(f"Error: {target} not found")
        sys.exit(1)

    before = target.read_text()

    # Group 1 keeps the call prefix; only the quoted version after it changes.
    call_pin = re.compile(r'(assert_pkg_version\s*\(\s*"sglang-kernel"\s*,\s*)"[^"]+"')
    after = call_pin.sub(rf'\1"{new_version}"', before)

    if after == before:
        print("No changes needed in engine.py")
        return False

    target.write_text(after)
    print(f"✓ Updated engine.py to version {new_version}")
    return True
def update_dockerfile(new_version: str) -> bool:
    """Set ``ARG SGL_KERNEL_VERSION=`` in docker/Dockerfile to ``new_version``.

    The path is relative to the current working directory; the process exits
    with status 1 if the file is missing. Everything after ``=`` on the ARG
    line is replaced.

    Returns:
        True if the file was rewritten, False if it was already up to date.
    """
    target = Path("docker/Dockerfile")
    if not target.exists():
        print(f"Error: {target} not found")
        sys.exit(1)

    before = target.read_text()

    # MULTILINE lets ^/$ anchor on the ARG line itself rather than the file.
    arg_line = re.compile(r"^(ARG\s+SGL_KERNEL_VERSION=)(.+)$", flags=re.MULTILINE)
    after = arg_line.sub(rf"\g<1>{new_version}", before)

    if after == before:
        print("No changes needed in Dockerfile")
        return False

    target.write_text(after)
    print(f"✓ Updated Dockerfile to version {new_version}")
    return True
def main():
    """Sync the sglang-kernel version from sgl-kernel into the SGLang files.

    Reads the version from sgl-kernel/pyproject.toml, applies it to
    python/pyproject.toml, engine.py and docker/Dockerfile (in that order),
    and prints which files were actually changed.
    """
    kernel_version = get_kernel_version_from_source()
    print(f"Bumping sglang-kernel version to: {kernel_version}\n")

    # (reported path, updater) pairs, run in this order.
    steps = (
        ("python/pyproject.toml", update_python_pyproject),
        ("python/sglang/srt/entrypoints/engine.py", update_engine_py),
        ("docker/Dockerfile", update_dockerfile),
    )
    updated_files = [label for label, update in steps if update(kernel_version)]

    print()
    if not updated_files:
        print("✓ All files already have the correct version")
        return

    print(f"✓ Successfully updated {len(updated_files)} file(s):")
    for file in updated_files:
        print(f" - {file}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,150 @@
#!/usr/bin/env python3
"""
Check if sglang-kernel version from sgl-kernel/pyproject.toml matches the versions
used in SGLang files (python/pyproject.toml, engine.py, and Dockerfile).
Sets GitHub Actions output variables to indicate if sync is needed.
"""
import os
import re
import sys
from pathlib import Path
try:
import tomllib # Python 3.11+
except ImportError:
import tomli as tomllib # Fallback for older Python versions
def get_kernel_version_from_source() -> str:
"""Extract version from sgl-kernel/pyproject.toml (line 11)"""
pyproject_path = Path("sgl-kernel/pyproject.toml")
if not pyproject_path.exists():
print(f"Error: {pyproject_path} not found")
sys.exit(1)
with open(pyproject_path, "rb") as f:
data = tomllib.load(f)
version = data.get("project", {}).get("version")
if not version:
print("Error: Could not find version in sgl-kernel/pyproject.toml")
sys.exit(1)
return version
def get_kernel_version_from_python_pyproject() -> str:
    """Return the version pinned by ``"sglang-kernel==..."`` in python/pyproject.toml.

    The path is relative to the current working directory. Prints an error
    and exits with status 1 when the file or the pin is missing.
    """
    source = Path("python/pyproject.toml")
    if not source.exists():
        print(f"Error: {source} not found")
        sys.exit(1)

    # The first quoted requirement of the form "sglang-kernel==<version>".
    pin = re.search(r'"sglang-kernel==([^"]+)"', source.read_text())
    if pin is None:
        print("Error: Could not find sglang-kernel version in python/pyproject.toml")
        sys.exit(1)
    return pin.group(1)
def get_kernel_version_from_engine() -> str:
    """Return the version passed to ``assert_pkg_version("sglang-kernel", ...)``.

    Reads python/sglang/srt/entrypoints/engine.py relative to the current
    working directory. Prints an error and exits with status 1 when the file
    or the call is missing.
    """
    source = Path("python/sglang/srt/entrypoints/engine.py")
    if not source.exists():
        print(f"Error: {source} not found")
        sys.exit(1)

    # \s* between tokens tolerates the call being wrapped over several lines.
    call = re.compile(r'assert_pkg_version\s*\(\s*"sglang-kernel"\s*,\s*"([^"]+)"')
    hit = call.search(source.read_text())
    if hit is None:
        print("Error: Could not find sglang-kernel version in engine.py")
        sys.exit(1)
    return hit.group(1)
def get_kernel_version_from_dockerfile() -> str:
    """Return the value of ``ARG SGL_KERNEL_VERSION=`` from docker/Dockerfile.

    The path is relative to the current working directory. Surrounding
    whitespace is stripped from the value. Prints an error and exits with
    status 1 when the file or the ARG line is missing.
    """
    source = Path("docker/Dockerfile")
    if not source.exists():
        print(f"Error: {source} not found")
        sys.exit(1)

    arg_line = re.compile(r"^ARG\s+SGL_KERNEL_VERSION=(.+)$", re.MULTILINE)
    hit = arg_line.search(source.read_text())
    if hit is None:
        print("Error: Could not find SGL_KERNEL_VERSION in Dockerfile")
        sys.exit(1)

    raw_value = hit.group(1)
    return raw_value.strip()
def main():
    """Report whether the SGLang files pin the current sgl-kernel version.

    Prints the version found in sgl-kernel/pyproject.toml and in each of the
    three SGLang locations (python/pyproject.toml, engine.py, Dockerfile).
    When the ``GITHUB_OUTPUT`` environment variable is set, appends
    ``needs_sync=true|false`` and ``kernel_version=<version>`` to that file.

    The process always exits with status 0: a mismatch is reported through
    the printed summary and the output variables, not through the exit code.
    """
    kernel_version = get_kernel_version_from_source()
    pyproject_version = get_kernel_version_from_python_pyproject()
    engine_version = get_kernel_version_from_engine()
    dockerfile_version = get_kernel_version_from_dockerfile()

    print(f"Kernel version in sgl-kernel/pyproject.toml: {kernel_version}")
    print(
        f"SGLang kernel dependency version in python/pyproject.toml: {pyproject_version}"
    )
    print(f"SGLang kernel dependency version in engine.py: {engine_version}")
    print(f"Kernel version in Dockerfile: {dockerfile_version}")

    # (label, version found there) for each location that must match the source.
    pinned = (
        ("python/pyproject.toml", pyproject_version),
        ("engine.py", engine_version),
        ("Dockerfile", dockerfile_version),
    )
    # Separate the old and new version with an arrow; printing them back to
    # back (e.g. "0.3.90.4.0") is unreadable.
    mismatches = [
        f" - {label}: {found} → {kernel_version}"
        for label, found in pinned
        if found != kernel_version
    ]
    needs_sync = bool(mismatches)

    # Set GitHub Actions output
    github_output = os.getenv("GITHUB_OUTPUT")
    if github_output:
        with open(github_output, "a") as f:
            f.write(f"needs_sync={'true' if needs_sync else 'false'}\n")
            f.write(f"kernel_version={kernel_version}\n")

    if needs_sync:
        print(f"\n✓ Sync needed to version: {kernel_version}")
        print("Changes needed:")
        for mismatch in mismatches:
            print(mismatch)
    else:
        print("\n✓ All versions are in sync, no action needed")

    sys.exit(0)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,72 @@
#!/bin/bash
set -euo pipefail

# Script to commit version bump changes and create a pull request
# Usage: commit_and_pr.sh <version_type> <new_version> <branch_name>
#
# Arguments:
#   version_type: "SGLang" or "sgl-kernel"
#   new_version:  The new version number
#   branch_name:  The git branch name to push to
#
# Environment:
#   GITHUB_STEP_SUMMARY (optional): when set and non-empty, a job summary is
#   appended to the file it names.

# Default to empty so the usage check below (not `set -u`) reports missing args.
VERSION_TYPE="${1:-}"
NEW_VERSION="${2:-}"
BRANCH_NAME="${3:-}"

if [[ -z "$VERSION_TYPE" || -z "$NEW_VERSION" || -z "$BRANCH_NAME" ]]; then
  echo "Error: Missing required arguments" >&2
  echo "Usage: $0 <version_type> <new_version> <branch_name>" >&2
  exit 1
fi

# Stage first, then list what is staged: the reported files are then exactly
# the files that get committed, including newly created ones (which a plain
# `git diff --name-only` does not show).
echo "Staging changes..."
git add -A

echo "Getting changed files..."
CHANGED_FILES=$(git diff --cached --name-only)
if [[ -z "$CHANGED_FILES" ]]; then
  echo "Error: No changes to commit" >&2
  exit 1
fi
FILES_LIST=$(sed 's/^/- /' <<< "$CHANGED_FILES")
COMMIT_FILES=$(sed 's/^/ - /' <<< "$CHANGED_FILES")

# Commit changes. Each -m becomes its own paragraph, so the subject line is
# separated from the body by a blank line.
echo "Committing changes..."
git commit \
  -m "chore: bump ${VERSION_TYPE} version to ${NEW_VERSION}" \
  -m "This commit updates the ${VERSION_TYPE} version across all relevant files:
${COMMIT_FILES}" \
  -m "🤖 Generated with GitHub Actions"

# Push changes
echo "Pushing to ${BRANCH_NAME}..."
git push origin "${BRANCH_NAME}"

# Create pull request
echo "Creating pull request..."
PR_URL=$(gh pr create \
  --title "chore: bump ${VERSION_TYPE} version to ${NEW_VERSION}" \
  --body "## Summary

This PR bumps the ${VERSION_TYPE} version to \`${NEW_VERSION}\` across all relevant files.

## Files Updated

${FILES_LIST}

🤖 Generated with GitHub Actions" \
  --base main \
  --head "${BRANCH_NAME}")

echo "✓ Pull request created successfully"

# Add GitHub Actions job summary
if [[ -n "${GITHUB_STEP_SUMMARY:-}" ]]; then
  cat >> "$GITHUB_STEP_SUMMARY" <<EOF
## ✅ Version Bump Complete

**Version Type:** ${VERSION_TYPE}
**New Version:** \`${NEW_VERSION}\`

### 📝 Pull Request Created

${PR_URL}

### 📦 Files Updated

${FILES_LIST}
EOF
fi

View File

@@ -0,0 +1,81 @@
#!/bin/bash
set -euo pipefail

# Script to commit kernel version bump changes to SGLang and create a pull request
# Usage: commit_and_pr_kernel_to_sglang.sh <kernel_version> <branch_name>
#
# Arguments:
#   kernel_version: The kernel version being synced
#   branch_name:    The git branch name to push to
#
# Environment:
#   GITHUB_STEP_SUMMARY (optional): when set and non-empty, a job summary is
#   appended to the file it names.

# Default to empty so the usage check below (not `set -u`) reports missing args.
KERNEL_VERSION="${1:-}"
BRANCH_NAME="${2:-}"

if [[ -z "$KERNEL_VERSION" || -z "$BRANCH_NAME" ]]; then
  echo "Error: Missing required arguments" >&2
  echo "Usage: $0 <kernel_version> <branch_name>" >&2
  exit 1
fi

# Stage first, then list what is staged: the reported files are then exactly
# the files that get committed, including newly created ones (which a plain
# `git diff --name-only` does not show).
echo "Staging changes..."
git add -A

echo "Getting changed files..."
CHANGED_FILES=$(git diff --cached --name-only)
if [[ -z "$CHANGED_FILES" ]]; then
  echo "Error: No changes to commit" >&2
  exit 1
fi
FILES_LIST=$(sed 's/^/- /' <<< "$CHANGED_FILES")
COMMIT_FILES=$(sed 's/^/ - /' <<< "$CHANGED_FILES")

# Commit changes. Each -m becomes its own paragraph, so the subject line is
# separated from the body by a blank line.
echo "Committing changes..."
git commit \
  -m "chore: bump sglang-kernel version to ${KERNEL_VERSION} in SGLang" \
  -m "This commit updates the sglang-kernel version across SGLang files to match
the version defined in sgl-kernel/pyproject.toml." \
  -m "Files updated:
${COMMIT_FILES}" \
  -m "🤖 Generated with GitHub Actions"

# Push changes
echo "Pushing to ${BRANCH_NAME}..."
git push origin "${BRANCH_NAME}"

# Create pull request
echo "Creating pull request..."
PR_URL=$(gh pr create \
  --title "chore: bump sglang-kernel version to ${KERNEL_VERSION}" \
  --body "## Summary

This PR bumps the \`sglang-kernel\` version to \`${KERNEL_VERSION}\` across SGLang files to match the version defined in \`sgl-kernel/pyproject.toml\`.

**Kernel Version:** \`${KERNEL_VERSION}\`

## Files Updated

${FILES_LIST}

## Context

The kernel version in \`sgl-kernel/pyproject.toml\` has been updated. This PR ensures that all SGLang files referencing the \`sglang-kernel\` dependency are updated accordingly:

- \`python/pyproject.toml\` - dependency specification
- \`python/sglang/srt/entrypoints/engine.py\` - version check
- \`docker/Dockerfile\` - Docker build argument

🤖 Generated with GitHub Actions" \
  --base main \
  --head "${BRANCH_NAME}")

echo "✓ Pull request created successfully"

# Add GitHub Actions job summary
if [[ -n "${GITHUB_STEP_SUMMARY:-}" ]]; then
  cat >> "$GITHUB_STEP_SUMMARY" <<EOF
## ✅ Kernel Version Bump Complete

**Kernel Version:** \`${KERNEL_VERSION}\`

### 📝 Pull Request Created

${PR_URL}

### 📦 Files Updated

${FILES_LIST}
EOF
fi

View File

@@ -0,0 +1,159 @@
#!/usr/bin/env python3
import unittest
from pathlib import Path
from utils import compare_versions, normalize_version, parse_version, validate_version
class TestVersionUtils(unittest.TestCase):
    """Table-driven tests for the version helpers imported from ``utils``."""

    def _check_parse(self, expectations):
        """Assert ``parse_version(text) == expected`` for every pair."""
        for text, expected in expectations:
            self.assertEqual(parse_version(text), expected)

    def _check_compare(self, expectations):
        """Assert ``compare_versions(left, right) == expected`` for every triple."""
        for left, right, expected in expectations:
            self.assertEqual(compare_versions(left, right), expected)

    def test_normalize_version(self):
        """A leading 'v' is dropped; everything else is returned unchanged."""
        for raw, expected in (
            ("v0.5.3", "0.5.3"),
            ("0.5.3", "0.5.3"),
            ("v0.5.3rc0", "0.5.3rc0"),
            ("0.5.3.post1", "0.5.3.post1"),
        ):
            self.assertEqual(normalize_version(raw), expected)

    def test_validate_version(self):
        """Only X.Y.Z, X.Y.ZrcN and X.Y.Z.postN are accepted."""
        accepted = (
            "0.5.3",
            "0.5.3rc0",
            "0.5.3rc1",
            "0.5.3rc999",
            "0.5.3.post1",
            "0.5.3.post10",
            "1.2.3",
            "10.20.30",
        )
        for candidate in accepted:
            self.assertTrue(validate_version(candidate))

        rejected = (
            "0.5",
            "0.5.3.",
            "0.5.3rc",
            "0.5.3post1",
            "0.5.3-rc0",
            "v0.5.3",
            "0.5.3beta1",
            "0.5.3.rc0",
        )
        for candidate in rejected:
            self.assertFalse(validate_version(candidate))

    def test_parse_version_stable(self):
        """Stable versions parse with both trailing fields set to zero."""
        self._check_parse(
            (
                ("0.5.3", (0, 5, 3, 0, 0)),
                ("1.2.3", (1, 2, 3, 0, 0)),
                ("10.20.30", (10, 20, 30, 0, 0)),
            )
        )

    def test_parse_version_rc(self):
        """Release candidates encode the rc number as -1000 + N."""
        self._check_parse(
            (
                ("0.5.3rc0", (0, 5, 3, -1000, 0)),
                ("0.5.3rc1", (0, 5, 3, -999, 0)),
                ("0.5.3rc2", (0, 5, 3, -998, 0)),
                ("0.5.3rc10", (0, 5, 3, -990, 0)),
            )
        )

    def test_parse_version_post(self):
        """Post-releases carry their number in the last field."""
        self._check_parse(
            (
                ("0.5.3.post1", (0, 5, 3, 0, 1)),
                ("0.5.3.post2", (0, 5, 3, 0, 2)),
                ("0.5.3.post10", (0, 5, 3, 0, 10)),
            )
        )

    def test_parse_version_invalid(self):
        """Malformed version strings raise ValueError."""
        for malformed in ("0.5", "invalid", "v0.5.3"):
            with self.assertRaises(ValueError):
                parse_version(malformed)

    def test_compare_versions_equal(self):
        """Identical versions compare as equal."""
        self._check_compare(
            (
                ("0.5.3", "0.5.3", 0),
                ("0.5.3rc0", "0.5.3rc0", 0),
                ("0.5.3.post1", "0.5.3.post1", 0),
            )
        )

    def test_compare_versions_rc_ordering(self):
        """Release candidates order as rc0 < rc1 < rc2 < stable."""
        self._check_compare(
            (
                # rc0 < rc1
                ("0.5.3rc0", "0.5.3rc1", -1),
                ("0.5.3rc1", "0.5.3rc0", 1),
                # rc1 < rc2
                ("0.5.3rc1", "0.5.3rc2", -1),
                ("0.5.3rc2", "0.5.3rc1", 1),
                # rc < stable
                ("0.5.3rc0", "0.5.3", -1),
                ("0.5.3rc1", "0.5.3", -1),
                ("0.5.3", "0.5.3rc0", 1),
            )
        )

    def test_compare_versions_post_ordering(self):
        """Post-releases order as stable < post1 < post2."""
        self._check_compare(
            (
                # stable < post1
                ("0.5.3", "0.5.3.post1", -1),
                ("0.5.3.post1", "0.5.3", 1),
                # post1 < post2
                ("0.5.3.post1", "0.5.3.post2", -1),
                ("0.5.3.post2", "0.5.3.post1", 1),
            )
        )

    def test_compare_versions_full_ordering(self):
        """The complete chain holds: rc < stable < post."""
        self._check_compare(
            (
                ("0.5.3rc0", "0.5.3", -1),
                ("0.5.3", "0.5.3.post1", -1),
                ("0.5.3rc0", "0.5.3.post1", -1),
            )
        )

        # Each neighbour in the ascending chain must compare as "less than".
        ascending = [
            "0.5.3rc0",
            "0.5.3rc1",
            "0.5.3",
            "0.5.3.post1",
            "0.5.3.post2",
        ]
        for lower, higher in zip(ascending, ascending[1:]):
            self.assertEqual(
                compare_versions(lower, higher),
                -1,
                f"{lower} should be less than {higher}",
            )

    def test_compare_versions_different_patch(self):
        """The patch number dominates any rc/post suffix."""
        self._check_compare(
            (
                # 0.5.3 < 0.5.4
                ("0.5.3", "0.5.4", -1),
                ("0.5.4", "0.5.3", 1),
                # rc of higher patch > stable of lower patch
                ("0.5.4rc0", "0.5.3", 1),
                ("0.5.3.post1", "0.5.4rc0", -1),
            )
        )

    def test_compare_versions_different_minor(self):
        """The minor number dominates the patch number."""
        self._check_compare(
            (
                ("0.4.9", "0.5.0", -1),
                ("0.5.0", "0.4.9", 1),
            )
        )

    def test_compare_versions_different_major(self):
        """The major number dominates everything else."""
        self._check_compare(
            (
                ("0.9.9", "1.0.0", -1),
                ("1.0.0", "0.9.9", 1),
            )
        )

    def test_real_world_scenarios(self):
        """Typical release-to-release bumps are all increases."""
        self._check_compare(
            (
                # Scenario 1: RC progression
                ("0.5.3rc0", "0.5.3rc1", -1),
                # Scenario 2: RC to stable release
                ("0.5.3rc2", "0.5.3", -1),
                # Scenario 3: Stable to post-release hotfix
                ("0.5.3", "0.5.3.post1", -1),
                # Scenario 4: Post-release to next RC
                ("0.5.3.post1", "0.5.4rc0", -1),
                # Scenario 5: Next stable version
                ("0.5.3", "0.5.4", -1),
            )
        )


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,220 @@
import re
import sys
from pathlib import Path
from typing import List, Tuple
try:
import tomllib # Python 3.11+
except ImportError:
import tomli as tomllib # Fallback for older Python versions
def normalize_version(version: str) -> str:
    """Remove a single leading 'v' from a version string, if present.

    Only one prefix character is removed, so a malformed tag such as
    ``"vv0.5.3"`` becomes ``"v0.5.3"`` rather than being turned into a
    valid-looking ``"0.5.3"``.

    Args:
        version: Version text, optionally prefixed with ``v`` (e.g. ``v0.5.3``).

    Returns:
        The version text without the leading ``v``.
    """
    # str.lstrip("v") would strip every leading "v", not just one prefix.
    if version.startswith("v"):
        return version[1:]
    return version
def validate_version(version: str) -> bool:
    """Validate version format: X.Y.Z, X.Y.Zrc0, or X.Y.Z.post1.

    The whole string must match. ``re.fullmatch`` is used because ``$`` in a
    ``re.match`` pattern also matches just before a trailing newline, which
    would let ``"0.5.3\\n"`` through.

    Args:
        version: Candidate version text (without a leading ``v``).

    Returns:
        True when ``version`` has one of the accepted forms, False otherwise.
    """
    pattern = r"\d+\.\d+\.\d+(rc\d+|\.post\d+)?"
    return re.fullmatch(pattern, version) is not None
def parse_version(version: str) -> Tuple[int, int, int, int, int]:
    """
    Parse version string into comparable components.

    Returns: (major, minor, patch, pre_release, post_release)
    - pre_release: -1000 + rc_number for rcN, 0 for stable (rc0 < rc1 < stable)
    - post_release: N for .postN, 0 otherwise

    The pre_release field uses negative numbers to ensure RC versions come before
    stable versions when tuples are compared. Python compares tuples element by
    element, so (0, 5, 3, -1000, 0) < (0, 5, 3, 0, 0) ensures rc0 < stable.

    NOTE: this encoding only orders release candidates correctly for N < 1000;
    rc1000 would yield a pre_release of 0, the same as a stable version.

    Examples:
    - "0.5.3rc0" → (0, 5, 3, -1000, 0)  # rc0 comes before stable
    - "0.5.3rc1" → (0, 5, 3, -999, 0)   # rc1 comes after rc0
    - "0.5.3" → (0, 5, 3, 0, 0)         # stable version
    - "0.5.3.post1" → (0, 5, 3, 0, 1)   # post comes after stable

    Raises:
        ValueError: If the string is not exactly X.Y.Z, X.Y.ZrcN or X.Y.Z.postN.
    """
    # fullmatch: a "$"-anchored re.match would also accept a trailing newline.
    match = re.fullmatch(r"(\d+)\.(\d+)\.(\d+)(?:rc(\d+)|\.post(\d+))?", version)
    if match is None:
        raise ValueError(f"Invalid version format: {version}")

    major, minor, patch, rc, post = match.groups()
    major, minor, patch = int(major), int(minor), int(patch)

    if rc is not None:
        # RC version: pre_release = -1000 + rc_number (ensures rc0 < rc1 < ... < stable)
        return (major, minor, patch, -1000 + int(rc), 0)
    if post is not None:
        # Post version: post_release = N
        return (major, minor, patch, 0, int(post))
    # Stable version
    return (major, minor, patch, 0, 0)
def compare_versions(v1: str, v2: str) -> int:
    """
    Three-way comparison of two version strings.

    Both strings are converted with parse_version() and the resulting tuples
    are compared, giving the ordering
    X.Y.ZrcN < X.Y.Z < X.Y.Z.postN < X.Y.(Z+1).

    Returns:
    - -1 if v1 < v2
    - 0 if v1 == v2
    - 1 if v1 > v2
    """
    left = parse_version(v1)
    right = parse_version(v2)

    if left == right:
        return 0
    return -1 if left < right else 1
def get_repo_root() -> Path:
    """Return the directory three levels above this file's path."""
    location = Path(__file__)
    for _ in range(3):
        location = location.parent
    return location
def read_current_version(version_file: Path) -> str:
    """Return the value assigned to ``__version__`` in ``version_file``.

    Either quote style is accepted around the value.

    Raises:
        ValueError: If the file has no ``__version__ = "<value>"`` assignment.
    """
    assignment = re.compile(r'__version__\s*=\s*["\']([^"\']+)["\']')
    found = assignment.search(version_file.read_text())
    if found is None:
        raise ValueError(f"Could not find version in {version_file}")
    return found.group(1)
def replace_in_file(file_path: Path, old_version: str, new_version: str) -> bool:
if not file_path.exists():
print(f"Warning: {file_path} does not exist, skipping")
return False
content = file_path.read_text()
# For TOML files, parse and update only the [project] version field
if file_path.suffix == ".toml":
try:
# Parse TOML to verify structure
toml_data = tomllib.loads(content)
# Check if [project] section exists and has version field
if "project" not in toml_data or "version" not in toml_data["project"]:
print(
f"Warning: {file_path} does not have [project] version field, skipping"
)
return False
# Use regex to replace only the version field in [project] section
# This pattern matches the version field that comes after [project]
# and before any other section marker
pattern = r'(\[project\].*?version\s*=\s*)["\']([^"\']+)["\']'
new_content = re.sub(
pattern, rf'\g<1>"{new_version}"', content, flags=re.DOTALL
)
except Exception as e:
print(f"Warning: Failed to parse {file_path} as TOML: {e}")
print("Falling back to simple string replacement")
new_content = content.replace(old_version, new_version)
else:
# For non-TOML files, use simple string replacement
new_content = content.replace(old_version, new_version)
if content == new_content:
print(f"No changes needed in {file_path}")
return False
file_path.write_text(new_content)
print(f"✓ Updated {file_path}")
return True
def bump_version(
    new_version: str,
    version_file: Path,
    files_to_update: List[Path],
) -> None:
    """Bump a package version across a set of repo-relative files.

    The requested version is normalized and validated, the current version is
    read from ``version_file``, and the bump is refused unless the new
    version is strictly greater. Every file in ``files_to_update`` is then
    rewritten and afterwards re-read to confirm it contains the new version.
    Any failure prints an error and exits with status 1.

    Args:
        new_version: Requested version; a leading ``v`` is accepted.
        version_file: Repo-relative file holding the current ``__version__``.
        files_to_update: Repo-relative files to rewrite.
    """
    # Normalize version (remove 'v' prefix if present)
    new_version = normalize_version(new_version)

    if not validate_version(new_version):
        print(f"Error: Invalid version format: {new_version}")
        print("Expected format: X.Y.Z, X.Y.ZrcN, or X.Y.Z.postN")
        print("Examples: 0.5.4, 0.5.3rc0, 0.5.3.post1")
        sys.exit(1)

    repo_root = get_repo_root()
    version_file_abs = repo_root / version_file

    if not version_file_abs.exists():
        print(f"Error: Version file {version_file_abs} does not exist")
        sys.exit(1)

    old_version = read_current_version(version_file_abs)
    print(f"Current version: {old_version}")
    print(f"New version: {new_version}")
    print()

    # Only a strictly newer version is accepted.
    order = compare_versions(new_version, old_version)
    if order == 0:
        print("Error: New version is the same as current version")
        sys.exit(1)
    if order < 0:
        print(
            f"Error: New version ({new_version}) is older than current version ({old_version})"
        )
        print("Version must be greater than the current version")
        sys.exit(1)

    # Pair each repo-relative path (used in messages) with its absolute path.
    targets = [(relative, repo_root / relative) for relative in files_to_update]

    updated_count = sum(
        1
        for _, absolute in targets
        if replace_in_file(absolute, old_version, new_version)
    )

    print()
    print(f"Successfully updated {updated_count} file(s)")
    print(f"Version bumped from {old_version} to {new_version}")

    # Validate that all files now contain the new version
    print("\nValidating version updates...")
    # TOML files are checked for a `version = <new_version>` assignment with
    # optional quotes; all other files for a plain substring.
    toml_assignment = re.compile(
        r'version\s*=\s*["\']?' + re.escape(new_version) + r'["\']?'
    )
    failed_files = []
    for relative, absolute in targets:
        if not absolute.exists():
            print(f"Warning: File {relative} does not exist, skipping validation.")
            continue

        content = absolute.read_text()
        if absolute.suffix == ".toml":
            contains_new = toml_assignment.search(content) is not None
        else:
            contains_new = new_version in content

        if contains_new:
            print(f"{relative} validated")
        else:
            failed_files.append(relative)
            print(f"{relative} does not contain version {new_version}")

    if failed_files:
        print(f"\nError: {len(failed_files)} file(s) were not updated correctly:")
        for relative in failed_files:
            print(f" - {relative}")
        sys.exit(1)

    print("\nAll files validated successfully!")

View File

@@ -0,0 +1,27 @@
"""
Sort the test case by name alphabetically for run_suite.py
"""
from dataclasses import dataclass
# One entry of a test suite: a test file name plus a time estimate.
@dataclass
class TestFile:
# Name of the test file; the script below sorts entries by this field.
name: str
# Estimated time for the file, defaulting to 60.
# NOTE(review): the unit is not stated here — presumably seconds; confirm.
estimated_time: float = 60
# Mapping of suite name -> list of TestFile entries to be sorted. It is empty
# here; presumably the suites to sort are pasted in by hand before running
# the script (TODO confirm).
suites = {}

if __name__ == "__main__":
    for key, cases in suites.items():
        print(f' "{key}": [')
        # Sort the entries themselves instead of sorting the names and then
        # rescanning `cases` for each one: a single pass, and entries that
        # share a name keep their own estimated_time.
        for case in sorted(cases, key=lambda entry: entry.name):
            print(f' TestFile("{case.name}", {case.estimated_time}),')
        print(" ],\n")

View File

@@ -0,0 +1,92 @@
# Reference: https://github.com/flashinfer-ai/flashinfer/blob/v0.2.0/scripts/update_whl_index.py
import argparse
import hashlib
import pathlib
import re
# All the CUDA versions that the wheels will cover
SUPPORTED_CUDA_VERSIONS = ["129", "130"]
DEFAULT_CUDA_VERSION = "129"


def check_wheel_cuda_version(path_name, target_cuda_version):
    """Decide whether a wheel file belongs in the index for a CUDA version.

    Args:
        path_name: Wheel file name.
        target_cuda_version: CUDA version without the ``cu`` prefix
            (e.g. ``"129"`` or ``"130"``).

    Returns:
        True if the wheel should be listed for ``target_cuda_version``.
    """
    # Wheels built for other backends never belong in a CUDA index. MUSA is
    # excluded alongside ROCm; otherwise a MUSA wheel, which carries no CUDA
    # suffix, would be treated as a default-CUDA wheel.
    if any(backend in path_name for backend in ("rocm", "musa")):
        return False

    # For other CUDA versions, the wheel path name will contain the cuda version suffix, e.g. sglang_kernel-0.4.0+cu130-cp310-abi3-manylinux2014_x86_64.whl
    # Match "cu<version>" rather than the bare digits so that the same digits
    # appearing elsewhere in the file name do not count.
    if target_cuda_version != DEFAULT_CUDA_VERSION:
        return f"cu{target_cuda_version}" in path_name

    # For the default CUDA version, the wheel path name will not contain any cuda version suffix, e.g. sglang_kernel-0.4.0-cp310-abi3-manylinux2014_x86_64.whl
    # So we need to check if the wheel path name contains any other cuda version suffix
    return not any(
        f"cu{cuda_version}" in path_name
        for cuda_version in SUPPORTED_CUDA_VERSIONS
        if cuda_version != DEFAULT_CUDA_VERSION
    )
def update_wheel_index(cuda_version=DEFAULT_CUDA_VERSION, rocm_version=None):
    """Append the CUDA wheels found in sgl-kernel/dist to the wheel index.

    For every ``*.whl`` in ``sgl-kernel/dist`` accepted by
    ``check_wheel_cuda_version`` for ``cuda_version``, one
    ``<a href=...>`` line (download URL plus ``#sha256=`` fragment) is
    appended to ``sgl-whl/cu<cuda_version>/sglang-kernel/index.html``.
    Paths are relative to the current working directory.

    Args:
        cuda_version: CUDA version without the ``cu`` prefix.
        rocm_version: Not used by this function.

    Raises:
        ValueError: If a selected wheel's file name does not contain a
            recognizable ``sglang_kernel-<version>`` part.
    """
    index_dir = pathlib.Path(f"sgl-whl/cu{cuda_version}/sglang-kernel")
    index_dir.mkdir(exist_ok=True, parents=True)
    base_url = "https://github.com/sgl-project/whl/releases/download"

    # Accepts X.Y.Z with an optional rcN or .postN suffix, followed by an
    # optional +cuNNN local tag.
    version_re = re.compile(
        r"sglang_kernel-([0-9.]+(?:rc[0-9]+)?(?:\.post[0-9]+)?)(?:\+cu[0-9]+)?-"
    )

    for path in sorted(pathlib.Path("sgl-kernel/dist").glob("*.whl")):
        # Skip the wheel if mismatches the passed in cuda_version
        if not check_wheel_cuda_version(path.name, cuda_version):
            continue

        # Hash in chunks so a large wheel is never held in memory at once.
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                digest.update(chunk)
        sha256 = digest.hexdigest()

        found = version_re.search(path.name)
        if found is None:
            raise ValueError(f"Could not parse version from wheel name: {path.name}")
        ver = found.group(1)

        full_url = f"{base_url}/v{ver}/{path.name}#sha256={sha256}"
        with (index_dir / "index.html").open("a") as f:
            f.write(f'<a href="{full_url}">{path.name}</a><br>\n')
def _update_non_cuda_wheel_index(backend, version):
index_dir = pathlib.Path(f"sgl-whl/{backend}{version}/sglang-kernel")
index_dir.mkdir(exist_ok=True, parents=True)
base_url = "https://github.com/sgl-project/whl/releases/download"
for path in sorted(pathlib.Path("sgl-kernel/dist").glob("*.whl")):
# Skip the wheel if not for this backend
if re.search(f"{backend}", path.name) is None:
continue
with open(path, "rb") as f:
sha256 = hashlib.sha256(f.read()).hexdigest()
ver = re.findall(
rf"sglang_kernel-([0-9.]+(?:\.post[0-9]+)?)(?:\+{backend}[0-9]+)?-",
path.name,
)[0]
full_url = f"{base_url}/v{ver}/{path.name}#sha256={sha256}"
with (index_dir / "index.html").open("a") as f:
f.write(f'<a href="{full_url}">{path.name}</a><br>\n')
def update_wheel_index_rocm(rocm_version):
    """Update the wheel index for ROCm wheels of the given ROCm version."""
    _update_non_cuda_wheel_index(backend="rocm", version=rocm_version)
def update_wheel_index_musa(musa_version):
    """Update the wheel index for MUSA wheels of the given MUSA version."""
    _update_non_cuda_wheel_index(backend="musa", version=musa_version)
def main():
    """Parse the command line and update exactly one backend's wheel index.

    Precedence when several flags are given: ``--musa`` first, then
    ``--rocm``; otherwise the CUDA index for ``--cuda`` is updated.
    """
    cli = argparse.ArgumentParser()
    for flag, fallback in (
        ("--cuda", DEFAULT_CUDA_VERSION),
        ("--rocm", None),
        ("--musa", None),
    ):
        cli.add_argument(flag, type=str, default=fallback)
    opts = cli.parse_args()

    if opts.musa is not None:
        update_wheel_index_musa(opts.musa)
        return
    if opts.rocm is not None:
        update_wheel_index_rocm(opts.rocm)
        return
    update_wheel_index(opts.cuda)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,203 @@
#!/usr/bin/env python3
"""
Update the wheel index for nightly SGLang releases.
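This script maintains a PyPI-compatible (PEP 503) index at cu{version}/sglang/index.html
listing all historical nightly builds: wheels from the current run are placed first,
followed by the entries that were already in the index. It also refreshes cu{version}/index.html.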
The CUDA version is specified via the --cuda-version argument.
Reference: https://github.com/flashinfer-ai/flashinfer/blob/v0.2.0/scripts/update_whl_index.py
"""
import argparse
import hashlib
import pathlib
import re
def compute_sha256(file_path: pathlib.Path) -> str:
    """Return the hex-encoded SHA256 digest of the file at ``file_path``.

    The file is read in 4096-byte chunks, so it is never held in memory whole.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
def update_wheel_index(
    commit_hash: str, nightly_version: str, cuda_version: str, build_date: str = None
):
    """Rebuild the nightly ``sglang`` wheel index for one CUDA flavor.

    Wheels found in ``dist/`` are hashed and turned into links that point at
    the GitHub release ``nightly-{build_date}-{commit_hash}`` (or
    ``nightly-{commit_hash}`` when no build date is given). The links are
    merged in front of the links already present in
    ``sgl-whl/cu{version}/sglang/index.html``, de-duplicated by wheel file
    name, and written back. ``sgl-whl/cu{version}/index.html`` is rewritten
    to list every package it already listed plus ``sglang``.

    Nothing is written when ``dist/`` is missing or no wheel could be processed.

    Args:
        commit_hash: Short git commit hash (e.g., 'c5f1e86')
        nightly_version: Full nightly version string
            (e.g., '0.5.6.post1.dev7716+gc5f1e86'); not read by this function.
        cuda_version: CUDA version string, with or without the 'cu' prefix
            (e.g., '129' or 'cu130')
        build_date: Build date in YYYY-MM-DD format (e.g., '2025-12-13')
    """
    wheels_dir = pathlib.Path("dist")
    repo_root = pathlib.Path("sgl-whl")

    if not wheels_dir.exists():
        print(f"Warning: {wheels_dir} does not exist, skipping index update")
        return

    # Normalize to the "cu<version>" spelling used for the directory name.
    cuda_version = (
        cuda_version if cuda_version.startswith("cu") else f"cu{cuda_version}"
    )
    print(f"Using CUDA version: {cuda_version}")

    # Wheels themselves live in GitHub Releases; the tag is date-based when a
    # build date is available and commit-only otherwise.
    base_url = "https://github.com/sgl-project/whl/releases/download"
    release_tag = (
        f"nightly-{build_date}-{commit_hash}"
        if build_date
        else f"nightly-{commit_hash}"
    )

    # PEP 503 layout:
    #   /cu{version}/index.html        -> links to sglang/ and sgl-kernel/
    #   /cu{version}/sglang/index.html -> contains wheel links
    cuda_dir = repo_root / cuda_version
    cuda_dir.mkdir(parents=True, exist_ok=True)
    sglang_dir = cuda_dir / "sglang"
    sglang_dir.mkdir(parents=True, exist_ok=True)
    root_index = cuda_dir / "index.html"
    package_index = sglang_dir / "index.html"

    print(f"\nUpdating nightly wheel index")
    print(f" Root index: {root_index}")
    print(f" Package index: {package_index}")

    # Keep only the anchor lines of an existing package index; header and
    # boilerplate lines are regenerated below.
    previous_links = []
    if package_index.exists():
        previous_links = [
            row
            for row in package_index.read_text().split("\n")
            if row.startswith("<a href=")
        ]

    # One link per wheel in dist/: {base_url}/{release_tag}/{filename}#sha256={hash}
    fresh_links = []
    for wheel_path in sorted(wheels_dir.glob("*.whl")):
        try:
            wheel_name = wheel_path.name
            checksum = compute_sha256(wheel_path)
            target = f"{base_url}/{release_tag}/{wheel_name}#sha256={checksum}"
            fresh_links.append(f'<a href="{target}">{wheel_name}</a><br>')
            print(f" Added: {wheel_name}")
        except Exception as e:
            # Best effort: a wheel that cannot be processed is reported and skipped.
            print(f" Error processing {wheel_path.name}: {e}")
            continue

    if not fresh_links:
        print(" No new wheels to add")
        return

    # New links go first; the first link seen for a wheel file name wins, so a
    # rebuilt wheel replaces its older entry. Links without a recognizable
    # "*.whl" label are dropped.
    link_by_name = {}
    for link in fresh_links + previous_links:
        found = re.search(r">([^<]+\.whl)</a>", link)
        if found:
            link_by_name.setdefault(found.group(1), link)
    unique_links = list(link_by_name.values())

    # Root index: preserve packages that are already listed and add sglang.
    packages = set()
    if root_index.exists():
        packages.update(
            hit.group(1)
            for hit in re.finditer(r'<a href="([^"]+)/">', root_index.read_text())
        )
    packages.add("sglang")

    with open(root_index, "w") as out:
        out.write("<!DOCTYPE html>\n")
        out.writelines(f'<a href="{pkg}/">{pkg}</a>\n' for pkg in sorted(packages))
    print(f" Written root index: {root_index} (packages: {sorted(packages)})")

    # Package index in minimal format (matching production sgl-kernel index).
    page = [
        "<!DOCTYPE html>\n",
        f"<h1>SGLang Nightly Wheels ({cuda_version})</h1>\n",
        "\n".join(unique_links),
        "\n",
    ]
    with open(package_index, "w") as out:
        out.writelines(page)

    print(f" Written {len(unique_links)} total wheels to {package_index}")
    print(f"\nDone! Users can install with:")
    print(
        f" pip install sglang --pre --extra-index-url https://sgl-project.github.io/whl/{cuda_version}/"
    )
def main():
    """Command-line entry point: parse arguments, echo them, update the nightly index."""
    cli = argparse.ArgumentParser(
        description="Update wheel index for nightly SGLang releases"
    )
    # (flag, extra add_argument keywords); every option is a plain string.
    option_table = (
        (
            "--commit-hash",
            dict(required=True, help="Short git commit hash (e.g., 'c5f1e86')"),
        ),
        (
            "--nightly-version",
            dict(
                required=True,
                help="Full nightly version string (e.g., '0.5.6.post1.dev7716+gc5f1e86')",
            ),
        ),
        (
            "--cuda-version",
            dict(
                default="129",
                help="CUDA version (e.g., '129' or '130'). Defaults to '129'.",
            ),
        ),
        (
            "--build-date",
            dict(
                required=False,
                help="Build date in YYYY-MM-DD format (e.g., '2025-12-13')",
            ),
        ),
    )
    for flag, extra in option_table:
        cli.add_argument(flag, type=str, **extra)
    opts = cli.parse_args()

    print(f"Updating nightly wheel index")
    print(f" Commit: {opts.commit_hash}")
    print(f" Version: {opts.nightly_version}")
    print(f" CUDA version: {opts.cuda_version}")
    if opts.build_date:
        print(f" Build date: {opts.build_date}")

    update_wheel_index(
        opts.commit_hash, opts.nightly_version, opts.cuda_version, opts.build_date
    )


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,183 @@
#!/usr/bin/env python3
"""
Update the wheel index for PR SGLang releases.
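This script maintains a PyPI-compatible (PEP 503) index for PR builds: pr/index.html links to
the sglang package and pr/sglang/index.html lists all PR wheels, with wheels from the current
run placed first, followed by the entries that were already in the index.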
Similar to update_nightly_whl_index.py but for PR builds.
"""
import argparse
import hashlib
import pathlib
import re
def compute_sha256(file_path: pathlib.Path) -> str:
    """Hash a file with SHA256 and return the digest as a hex string.

    Reads the file 4096 bytes at a time rather than loading it whole.
    """
    hasher = hashlib.sha256()
    with open(file_path, "rb") as handle:
        block = handle.read(4096)
        while block:
            hasher.update(block)
            block = handle.read(4096)
    return hasher.hexdigest()
def update_wheel_index(
    pr_number: str, commit_hash: str, wheel_version: str, build_date: str
):
    """Rebuild the ``sglang`` wheel index for PR builds.

    Wheels found in ``dist/`` are hashed and turned into links that point at
    the GitHub release ``pr-{pr_number}-{build_date}-{commit_hash}``. The
    links are merged in front of the links already present in
    ``sgl-whl/pr/sglang/index.html``, de-duplicated by wheel file name, and
    written back. ``sgl-whl/pr/index.html`` is rewritten to link to
    ``sglang/``.

    Nothing is written when ``dist/`` is missing or no wheel could be processed.

    Args:
        pr_number: PR number (e.g., '123')
        commit_hash: Short git commit hash (e.g., 'c5f1e86')
        wheel_version: Full wheel version string
            (e.g., '0.5.6.dev7716+pr-123.gc5f1e86'); not read by this function.
        build_date: Build date in YYYY-MM-DD format (e.g., '2025-12-13')
    """
    wheels_dir = pathlib.Path("dist")
    if not wheels_dir.exists():
        print(f"Warning: {wheels_dir} does not exist, skipping index update")
        return

    # Wheels themselves are stored as GitHub Release assets under this tag.
    url_prefix = (
        "https://github.com/sgl-project/whl/releases/download"
        + f"/pr-{pr_number}-{build_date}-{commit_hash}"
    )

    # PEP 503 layout:
    #   /pr/index.html        -> links to sglang/
    #   /pr/sglang/index.html -> contains wheel links
    pr_dir = pathlib.Path("sgl-whl") / "pr"
    pr_dir.mkdir(parents=True, exist_ok=True)
    sglang_dir = pr_dir / "sglang"
    sglang_dir.mkdir(parents=True, exist_ok=True)
    root_index = pr_dir / "index.html"
    package_index = sglang_dir / "index.html"

    print(f"\nUpdating PR wheel index")
    print(f" Root index: {root_index}")
    print(f" Package index: {package_index}")

    # Carry over only the anchor lines of an existing package index.
    carried_over = []
    if package_index.exists():
        with open(package_index, "r") as src:
            for row in src.read().split("\n"):
                if row.startswith("<a href="):
                    carried_over.append(row)

    # One link per wheel in dist/: {base_url}/{release_tag}/{filename}#sha256={hash}
    fresh_links = []
    for wheel_path in sorted(wheels_dir.glob("*.whl")):
        try:
            wheel_name = wheel_path.name
            checksum = compute_sha256(wheel_path)
            fresh_links.append(
                f'<a href="{url_prefix}/{wheel_name}#sha256={checksum}">{wheel_name}</a><br>'
            )
            print(f" Added: {wheel_name}")
        except Exception as e:
            # Best effort: a wheel that cannot be processed is reported and skipped.
            print(f" Error processing {wheel_path.name}: {e}")
            continue

    if not fresh_links:
        print(" No new wheels to add")
        return

    # Newest first; keep only the first link seen for each wheel file name.
    # Links without a recognizable "*.whl" label are dropped.
    emitted_names = set()
    unique_links = []
    for link in fresh_links + carried_over:
        label = re.search(r">([^<]+\.whl)</a>", link)
        if label is None or label.group(1) in emitted_names:
            continue
        emitted_names.add(label.group(1))
        unique_links.append(link)

    # Root index simply points at the sglang package directory.
    with open(root_index, "w") as out:
        out.write("<!DOCTYPE html>\n" + '<a href="sglang/">sglang</a>\n')
    print(f" Written root index: {root_index}")

    # Package index in minimal format: doctype, heading, then the links.
    with open(package_index, "w") as out:
        out.write("<!DOCTYPE html>\n")
        out.write("<h1>SGLang PR Wheels</h1>\n")
        out.write("\n".join(unique_links) + "\n")
    print(f" Written {len(unique_links)} total wheels to {package_index}")

    print(f"\nDone! Users can install with:")
    print(
        f" pip install sglang --pre --extra-index-url https://sgl-project.github.io/whl/pr/"
    )
    print(f"\nOr install specific PR #{pr_number} wheel directly:")
    # fresh_links is guaranteed non-empty here because of the early return above.
    href = re.search(r'href="([^"]+)"', fresh_links[0])
    if href:
        direct_url = href.group(1).split("#")[0]  # Remove sha256 hash
        print(f" pip install {direct_url}")
def main():
    """Command-line entry point: parse arguments, echo them, update the PR index."""
    cli = argparse.ArgumentParser(
        description="Update wheel index for PR SGLang releases"
    )
    # All four options are required strings; only the flag and help text differ.
    required_options = {
        "--pr-number": "PR number (e.g., '123')",
        "--commit-hash": "Short git commit hash (e.g., 'c5f1e86')",
        "--wheel-version": "Full wheel version string (e.g., '0.5.6.dev7716+pr-123.gc5f1e86')",
        "--build-date": "Build date in YYYY-MM-DD format (e.g., '2025-12-13')",
    }
    for flag, help_text in required_options.items():
        cli.add_argument(flag, type=str, required=True, help=help_text)
    opts = cli.parse_args()

    print(f"Updating PR wheel index")
    print(f" PR: #{opts.pr_number}")
    print(f" Commit: {opts.commit_hash}")
    print(f" Version: {opts.wheel_version}")
    print(f" Build date: {opts.build_date}")

    update_wheel_index(
        opts.pr_number, opts.commit_hash, opts.wheel_version, opts.build_date
    )


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()