chore: vendor sglang v0.5.10 snapshot
This commit is contained in:
89
third_party/sglang/scripts/ci/amd/amd_ci_exec.sh
vendored
Executable file
89
third_party/sglang/scripts/ci/amd/amd_ci_exec.sh
vendored
Executable file
@@ -0,0 +1,89 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
# Detect GPU family from hostname (e.g., linux-mi35x-gpu-1-xxxxx-runner-zzzzz)
|
||||
HOSTNAME_VALUE=$(hostname)
|
||||
GPU_FAMILY=""
|
||||
|
||||
# Host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz
|
||||
if [[ "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
|
||||
GPU_FAMILY="${BASH_REMATCH[1]}"
|
||||
echo "Detected GPU family from hostname: ${GPU_FAMILY}"
|
||||
else
|
||||
echo "Warning: could not parse GPU family from '${HOSTNAME_VALUE}'"
|
||||
fi
|
||||
|
||||
WORKDIR="/sglang-checkout/test/srt"
|
||||
declare -A ENV_MAP=(
|
||||
[SGLANG_IS_IN_CI_AMD]=1
|
||||
[SGLANG_IS_IN_CI]=1
|
||||
[SGLANG_USE_AITER]=1
|
||||
)
|
||||
|
||||
# Conditionally add GPU_ARCHS only for mi35x
|
||||
if [[ "${GPU_FAMILY}" == "mi35x" ]]; then
|
||||
ENV_MAP[GPU_ARCHS]="gfx950"
|
||||
fi
|
||||
|
||||
# Parse -w/--workdir and -e ENV=VAL
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
-w|--workdir)
|
||||
WORKDIR="$2"
|
||||
shift 2
|
||||
;;
|
||||
-e)
|
||||
IFS="=" read -r key val <<< "$2"
|
||||
ENV_MAP["$key"]="$val"
|
||||
shift 2
|
||||
;;
|
||||
--)
|
||||
shift
|
||||
break
|
||||
;;
|
||||
*)
|
||||
break
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Build final ENV_ARGS
|
||||
ENV_ARGS=()
|
||||
for key in "${!ENV_MAP[@]}"; do
|
||||
ENV_ARGS+=("-e" "$key=${ENV_MAP[$key]}")
|
||||
done
|
||||
|
||||
# Run docker exec with retry logic for HuggingFace network/download issues
|
||||
# When HF model downloads fail due to network timeouts or rate limits,
|
||||
# retrying with HF_HUB_OFFLINE=1 uses cached models from previous downloads.
|
||||
#
|
||||
# First attempt: normal mode (allows HF downloads)
|
||||
if docker exec \
|
||||
-w "$WORKDIR" \
|
||||
"${ENV_ARGS[@]}" \
|
||||
ci_sglang "$@"; then
|
||||
exit 0
|
||||
else
|
||||
FIRST_EXIT_CODE=$?
|
||||
fi
|
||||
|
||||
echo "First attempt failed with exit code $FIRST_EXIT_CODE"
|
||||
|
||||
# Skip retry for test failures that won't be fixed by offline mode:
|
||||
# - Exit 1: Test assertion failures (accuracy below threshold)
|
||||
# - Exit 137 (128+9): Process killed by OOM
|
||||
# - Exit 255: Test suite completed with test errors
|
||||
# Only retry for other exit codes (e.g., network timeouts, HF download failures)
|
||||
if [[ "$FIRST_EXIT_CODE" -eq 1 || "$FIRST_EXIT_CODE" -eq 137 || "$FIRST_EXIT_CODE" -eq 255 ]]; then
|
||||
echo "Exit code $FIRST_EXIT_CODE indicates test failure (not network issue), not retrying"
|
||||
exit $FIRST_EXIT_CODE
|
||||
fi
|
||||
|
||||
echo "Retrying with HF_HUB_OFFLINE=1 (offline mode to use cached models)..."
|
||||
|
||||
# Second attempt: force HF offline mode to avoid network timeouts
|
||||
docker exec \
|
||||
-w "$WORKDIR" \
|
||||
"${ENV_ARGS[@]}" \
|
||||
-e HF_HUB_OFFLINE=1 \
|
||||
ci_sglang "$@"
|
||||
320
third_party/sglang/scripts/ci/amd/amd_ci_install_dependency.sh
vendored
Executable file
320
third_party/sglang/scripts/ci/amd/amd_ci_install_dependency.sh
vendored
Executable file
@@ -0,0 +1,320 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
HOSTNAME_VALUE=$(hostname)
|
||||
GPU_ARCH="mi30x" # default
|
||||
SKIP_TT_DEPS=""
|
||||
SKIP_SGLANG_BUILD=""
|
||||
SKIP_AITER_BUILD=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--skip-aiter-build) SKIP_AITER_BUILD="1"; shift;;
|
||||
--skip-sglang-build) SKIP_SGLANG_BUILD="1"; shift;;
|
||||
--skip-test-time-deps) SKIP_TT_DEPS="1"; shift;;
|
||||
-h|--help)
|
||||
echo "Usage: $0 [OPTIONS] [OPTIONAL_DEPS]"
|
||||
echo "Options:"
|
||||
echo " --skip-sglang-build Don't build checkout sglang, use what was shipped with the image"
|
||||
echo " --skip-aiter-build Don't build aiter, use what was shipped with the image"
|
||||
echo " --skip-test-time-deps Don't build miscellaneous dependencies"
|
||||
exit 0
|
||||
;;
|
||||
*) break ;;
|
||||
esac
|
||||
done
|
||||
|
||||
OPTIONAL_DEPS="${1:-}"
|
||||
|
||||
# Build python extras
|
||||
EXTRAS="dev_hip"
|
||||
if [ -n "$OPTIONAL_DEPS" ]; then
|
||||
EXTRAS="dev_hip,${OPTIONAL_DEPS}"
|
||||
fi
|
||||
echo "Installing python extras: [${EXTRAS}]"
|
||||
|
||||
# Host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz
|
||||
if [[ "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
|
||||
GPU_ARCH="${BASH_REMATCH[1]}"
|
||||
echo "Detected GPU architecture from hostname: ${GPU_ARCH}"
|
||||
else
|
||||
echo "Warning: could not parse GPU architecture from '${HOSTNAME_VALUE}', defaulting to ${GPU_ARCH}"
|
||||
fi
|
||||
|
||||
# Install the required dependencies in CI.
|
||||
# Fix permissions on pip cache, ignore errors from concurrent access or missing temp files
|
||||
docker exec ci_sglang chown -R root:root /sgl-data/pip-cache 2>/dev/null || true
|
||||
docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache --upgrade pip
|
||||
|
||||
# Helper function to install with retries and fallback PyPI mirror
|
||||
install_with_retry() {
|
||||
local max_attempts=3
|
||||
local cmd="$@"
|
||||
|
||||
for attempt in $(seq 1 $max_attempts); do
|
||||
echo "Attempt $attempt/$max_attempts: $cmd"
|
||||
if eval "$cmd"; then
|
||||
echo "Success!"
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [ $attempt -lt $max_attempts ]; then
|
||||
echo "Failed, retrying in 5 seconds..."
|
||||
sleep 5
|
||||
# Try with alternative PyPI index on retry
|
||||
if [[ "$cmd" =~ "pip install" ]] && [ $attempt -eq 2 ]; then
|
||||
cmd="$cmd --index-url https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com"
|
||||
echo "Using fallback PyPI mirror: $cmd"
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
echo "Failed after $max_attempts attempts"
|
||||
return 1
|
||||
}
|
||||
|
||||
# Helper function to git clone with retries
|
||||
git_clone_with_retry() {
|
||||
local repo_url="$1"
|
||||
local dest_dir="${2:-}"
|
||||
local branch_args="${3:-}"
|
||||
local max_attempts=3
|
||||
|
||||
for attempt in $(seq 1 $max_attempts); do
|
||||
echo "Git clone attempt $attempt/$max_attempts: $repo_url"
|
||||
|
||||
# prevent from partial clone
|
||||
if [ -n "$dest_dir" ] && [ -d "$dest_dir" ]; then
|
||||
rm -rf "$dest_dir"
|
||||
fi
|
||||
|
||||
if git \
|
||||
-c http.lowSpeedLimit=1000 \
|
||||
-c http.lowSpeedTime=30 \
|
||||
clone --depth 1 ${branch_args:+$branch_args} "$repo_url" "$dest_dir"; then
|
||||
echo "Git clone succeeded."
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [ $attempt -lt $max_attempts ]; then
|
||||
echo "Git clone failed, retrying in 5 seconds..."
|
||||
sleep 5
|
||||
fi
|
||||
done
|
||||
|
||||
echo "Git clone failed after $max_attempts attempts: $repo_url"
|
||||
return 1
|
||||
}
|
||||
|
||||
# Install checkout sglang
|
||||
if [ -n "$SKIP_SGLANG_BUILD" ]; then
|
||||
echo "Didn't build checkout SGLang"
|
||||
else
|
||||
docker exec ci_sglang pip uninstall sgl-kernel -y || true
|
||||
docker exec ci_sglang pip uninstall sglang-kernel -y || true
|
||||
docker exec ci_sglang pip uninstall sglang -y || true
|
||||
# Clear Python cache to ensure latest code is used
|
||||
docker exec ci_sglang find /opt/venv -name "*.pyc" -delete || true
|
||||
docker exec ci_sglang find /opt/venv -name "__pycache__" -type d -exec rm -rf {} + || true
|
||||
# Also clear cache in sglang-checkout
|
||||
docker exec ci_sglang find /sglang-checkout -name "*.pyc" -delete || true
|
||||
docker exec ci_sglang find /sglang-checkout -name "__pycache__" -type d -exec rm -rf {} + || true
|
||||
docker exec -w /sglang-checkout/sgl-kernel ci_sglang bash -c "rm -f pyproject.toml && mv pyproject_rocm.toml pyproject.toml && python3 setup_rocm.py install"
|
||||
|
||||
docker exec ci_sglang bash -c 'rm -rf python/pyproject.toml && mv python/pyproject_other.toml python/pyproject.toml'
|
||||
install_with_retry docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache -e "python[${EXTRAS}]"
|
||||
fi
|
||||
|
||||
if [[ -n "${SKIP_TT_DEPS}" ]]; then
|
||||
echo "Didn't build lmms_eval, human-eval, and others"
|
||||
else
|
||||
# For lmms_evals evaluating MMMU
|
||||
docker exec -w / ci_sglang git clone --branch v0.4.1 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
|
||||
install_with_retry docker exec -w /lmms-eval ci_sglang pip install --cache-dir=/sgl-data/pip-cache -e .
|
||||
|
||||
git_clone_with_retry https://github.com/akao-amd/human-eval.git human-eval
|
||||
docker cp human-eval ci_sglang:/
|
||||
install_with_retry docker exec -w /human-eval ci_sglang pip install --cache-dir=/sgl-data/pip-cache -e .
|
||||
|
||||
docker exec -w / ci_sglang mkdir -p /dummy-grok
|
||||
# Create dummy grok config inline (bypasses Azure blob storage which may have auth issues)
|
||||
mkdir -p dummy-grok
|
||||
cat > dummy-grok/config.json << 'EOF'
|
||||
{
|
||||
"architectures": [
|
||||
"Grok1ModelForCausalLM"
|
||||
],
|
||||
"embedding_multiplier_scale": 78.38367176906169,
|
||||
"output_multiplier_scale": 0.5773502691896257,
|
||||
"vocab_size": 131072,
|
||||
"hidden_size": 6144,
|
||||
"intermediate_size": 32768,
|
||||
"max_position_embeddings": 8192,
|
||||
"num_experts_per_tok": 2,
|
||||
"num_local_experts": 8,
|
||||
"num_attention_heads": 48,
|
||||
"num_hidden_layers": 64,
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 128,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_theta": 10000.0,
|
||||
"model_type": "mixtral",
|
||||
"torch_dtype": "bfloat16"
|
||||
}
|
||||
EOF
|
||||
# docker exec -w / ci_sglang mkdir -p /dummy-grok
|
||||
# mkdir -p dummy-grok && wget https://sharkpublic.blob.core.windows.net/sharkpublic/sglang/dummy_grok.json -O dummy-grok/config.json
|
||||
# docker cp ./dummy-grok ci_sglang:/
|
||||
|
||||
docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache huggingface_hub[hf_xet]
|
||||
docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache pytest
|
||||
|
||||
# Install cache-dit for qwen_image_t2i_cache_dit_enabled test (added in PR 16204)
|
||||
docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache cache-dit || echo "cache-dit installation failed"
|
||||
|
||||
# Install accelerate for distributed training and inference support
|
||||
docker exec ci_sglang pip install --cache-dir=/sgl-data/pip-cache accelerate || echo "accelerate installation failed"
|
||||
fi
|
||||
|
||||
if [[ -n "${SKIP_AITER_BUILD}" ]]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Detect AITER version
|
||||
#############################################
|
||||
# Detect correct AITER_COMMIT for this runner
|
||||
# + Check mismatch
|
||||
# + Rebuild AITER if needed
|
||||
#############################################
|
||||
|
||||
echo "[CI-AITER-CHECK] === AITER VERSION CHECK START ==="
|
||||
|
||||
DOCKERFILE="docker/rocm.Dockerfile"
|
||||
|
||||
# GPU_ARCH
|
||||
GPU_ARCH="${GPU_ARCH:-mi30x}"
|
||||
echo "[CI-AITER-CHECK] Runner GPU_ARCH=${GPU_ARCH}"
|
||||
|
||||
#############################################
|
||||
# 1. Extract AITER_COMMIT from correct Dockerfile block
|
||||
#############################################
|
||||
if [[ "${GPU_ARCH}" == "mi35x" ]]; then
|
||||
echo "[CI-AITER-CHECK] Using gfx950 block from Dockerfile..."
|
||||
REPO_AITER_COMMIT=$(grep -F -A20 'FROM $BASE_IMAGE_950 AS gfx950' docker/rocm.Dockerfile \
|
||||
| grep 'AITER_COMMIT_DEFAULT=' \
|
||||
| head -n1 \
|
||||
| sed 's/.*AITER_COMMIT_DEFAULT="\([^"]*\)".*/\1/')
|
||||
else
|
||||
echo "[CI-AITER-CHECK] Using gfx942 block from Dockerfile..."
|
||||
REPO_AITER_COMMIT=$(grep -F -A20 'FROM $BASE_IMAGE_942 AS gfx942' docker/rocm.Dockerfile \
|
||||
| grep 'AITER_COMMIT_DEFAULT=' \
|
||||
| head -n1 \
|
||||
| sed 's/.*AITER_COMMIT_DEFAULT="\([^"]*\)".*/\1/')
|
||||
fi
|
||||
|
||||
|
||||
if [[ -z "${REPO_AITER_COMMIT}" ]]; then
|
||||
echo "[CI-AITER-CHECK] ERROR: Failed to extract AITER_COMMIT from Dockerfile."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "[CI-AITER-CHECK] Dockerfile expects AITER_COMMIT=${REPO_AITER_COMMIT}"
|
||||
|
||||
#############################################
|
||||
# 2. Check container pre-installed AITER version
|
||||
#############################################
|
||||
IMAGE_AITER_VERSION=$(docker exec ci_sglang bash -c "pip show amd-aiter 2>/dev/null | grep '^Version:' | awk '{print \$2}'" || echo "none")
|
||||
IMAGE_AITER_VERSION="v${IMAGE_AITER_VERSION}"
|
||||
echo "[CI-AITER-CHECK] AITER version inside CI image: ${IMAGE_AITER_VERSION}"
|
||||
|
||||
#############################################
|
||||
# 3. Decide rebuild
|
||||
#############################################
|
||||
NEED_REBUILD="false"
|
||||
|
||||
if [[ -n "${AITER_COMMIT_OVERRIDE:-}" ]]; then
|
||||
echo "[CI-AITER-CHECK] AITER_COMMIT_OVERRIDE=${AITER_COMMIT_OVERRIDE} → forcing rebuild"
|
||||
REPO_AITER_COMMIT="${AITER_COMMIT_OVERRIDE}"
|
||||
NEED_REBUILD="true"
|
||||
elif [[ "${IMAGE_AITER_VERSION}" == "vnone" || "${IMAGE_AITER_VERSION}" == "v" ]]; then
|
||||
echo "[CI-AITER-CHECK] No AITER found in image → rebuild needed"
|
||||
NEED_REBUILD="true"
|
||||
elif [[ "${IMAGE_AITER_VERSION}" == "${REPO_AITER_COMMIT}" ]]; then
|
||||
echo "[CI-AITER-CHECK] AITER version matches"
|
||||
elif [[ "${IMAGE_AITER_VERSION}" =~ (dev|\+g[0-9a-f]+) ]]; then
|
||||
# Dev/patched version (contains 'dev' or git hash) → preserve it
|
||||
echo "[CI-AITER-CHECK] Dev/patched version detected: ${IMAGE_AITER_VERSION} → skipping rebuild"
|
||||
else
|
||||
echo "[CI-AITER-CHECK] Version mismatch: image=${IMAGE_AITER_VERSION}, repo=${REPO_AITER_COMMIT}"
|
||||
NEED_REBUILD="true"
|
||||
fi
|
||||
|
||||
|
||||
#############################################
|
||||
# 4. Rebuild AITER if needed
|
||||
#############################################
|
||||
if [[ "${NEED_REBUILD}" == "true" ]]; then
|
||||
echo "[CI-AITER-CHECK] === AITER REBUILD START ==="
|
||||
|
||||
# uninstall existing aiter
|
||||
docker exec ci_sglang pip uninstall -y amd-aiter || true
|
||||
|
||||
# delete old aiter directory
|
||||
docker exec ci_sglang rm -rf /sgl-workspace/aiter
|
||||
|
||||
# clone a fresh copy to /sgl-workspace/aiter
|
||||
docker exec ci_sglang git clone https://github.com/ROCm/aiter.git /sgl-workspace/aiter
|
||||
|
||||
# checkout correct version
|
||||
docker exec ci_sglang bash -c "
|
||||
cd /sgl-workspace/aiter && \
|
||||
git fetch --all && \
|
||||
git checkout ${REPO_AITER_COMMIT} && \
|
||||
git submodule update --init --recursive
|
||||
"
|
||||
|
||||
if [[ "${GPU_ARCH}" == "mi35x" ]]; then
|
||||
GPU_ARCH_LIST="gfx950"
|
||||
else
|
||||
GPU_ARCH_LIST="gfx942"
|
||||
fi
|
||||
echo "[CI-AITER-CHECK] GPU_ARCH_LIST=${GPU_ARCH_LIST}"
|
||||
|
||||
# Re-apply Dockerfile hotpatches for ROCm 7.2 (the fresh clone lost them, can be removed after triton fixed this problem)
|
||||
ROCM_VERSION=$(docker exec ci_sglang bash -c "cat /opt/rocm/.info/version 2>/dev/null || echo unknown")
|
||||
if [[ "${ROCM_VERSION}" == 7.2* ]]; then
|
||||
echo "[CI-AITER-CHECK] ROCm 7.2 detected (${ROCM_VERSION}), applying AITER hotpatches..."
|
||||
docker exec ci_sglang bash -c "
|
||||
cd /sgl-workspace/aiter && \
|
||||
TARGET_FILE='aiter/ops/triton/attention/pa_mqa_logits.py' && \
|
||||
if [ -f \"\${TARGET_FILE}\" ]; then \
|
||||
sed -i '459 s/if.*:/if False:/' \"\${TARGET_FILE}\" && \
|
||||
echo '[CI-AITER-CHECK] Hotpatch applied to pa_mqa_logits.py'; \
|
||||
else \
|
||||
echo '[CI-AITER-CHECK] pa_mqa_logits.py not found, skipping hotpatch'; \
|
||||
fi
|
||||
"
|
||||
else
|
||||
echo "[CI-AITER-CHECK] ROCm version=${ROCM_VERSION}, no hotpatch needed"
|
||||
fi
|
||||
|
||||
# build AITER
|
||||
docker exec ci_sglang bash -c "
|
||||
cd /sgl-workspace/aiter && \
|
||||
GPU_ARCHS=${GPU_ARCH_LIST} python3 setup.py develop
|
||||
"
|
||||
|
||||
echo "[CI-AITER-CHECK] === AITER REBUILD COMPLETE ==="
|
||||
fi
|
||||
|
||||
echo "[CI-AITER-CHECK] === AITER VERSION CHECK END ==="
|
||||
|
||||
|
||||
# # Clear pre-built AITER kernels from Docker image to avoid segfaults
|
||||
# # The Docker image may contain pre-compiled kernels incompatible with the current environment
|
||||
# echo "Clearing pre-built AITER kernels from Docker image..."
|
||||
# docker exec ci_sglang find /sgl-workspace/aiter/aiter/jit -name "*.so" -delete 2>/dev/null || true
|
||||
# docker exec ci_sglang ls -la /sgl-workspace/aiter/aiter/jit/ 2>/dev/null || echo "jit dir empty or not found"
|
||||
|
||||
# # Pre-build AITER kernels to avoid timeout during tests
|
||||
# echo "Warming up AITER JIT kernels..."
|
||||
# docker exec -e SGLANG_USE_AITER=1 ci_sglang python3 /sglang-checkout/scripts/ci/amd/amd_ci_warmup_aiter.py || echo "AITER warmup completed (some kernels may not be available)"
|
||||
248
third_party/sglang/scripts/ci/amd/amd_ci_start_container.sh
vendored
Executable file
248
third_party/sglang/scripts/ci/amd/amd_ci_start_container.sh
vendored
Executable file
@@ -0,0 +1,248 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
# Get version from git tags
|
||||
SGLANG_VERSION="v0.5.5" # Default version, will be overridden if git tags are found
|
||||
|
||||
# Fetch tags from origin to ensure we have the latest
|
||||
if git fetch --tags origin; then
|
||||
# Get the latest version tag sorted by version number (e.g., v0.5.7)
|
||||
VERSION_FROM_TAG=$(git tag -l 'v[0-9]*' --sort=-v:refname | head -1)
|
||||
if [ -n "$VERSION_FROM_TAG" ]; then
|
||||
SGLANG_VERSION="$VERSION_FROM_TAG"
|
||||
echo "Using SGLang version from git tags: $SGLANG_VERSION"
|
||||
else
|
||||
echo "Warning: No version tags found; using default $SGLANG_VERSION" >&2
|
||||
fi
|
||||
else
|
||||
echo "Warning: Failed to fetch tags from origin; using default $SGLANG_VERSION" >&2
|
||||
fi
|
||||
|
||||
|
||||
# Default base tags (can be overridden by command line arguments)
|
||||
ROCM_VERSION="rocm700"
|
||||
DEFAULT_MI30X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi30x"
|
||||
DEFAULT_MI35X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi35x"
|
||||
|
||||
# Parse command line arguments
|
||||
MI30X_BASE_TAG="${DEFAULT_MI30X_BASE_TAG}"
|
||||
MI35X_BASE_TAG="${DEFAULT_MI35X_BASE_TAG}"
|
||||
CUSTOM_IMAGE=""
|
||||
BUILD_FROM_DOCKERFILE=""
|
||||
GPU_ARCH_BUILD=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--mi30x-base-tag) MI30X_BASE_TAG="$2"; shift 2;;
|
||||
--mi35x-base-tag) MI35X_BASE_TAG="$2"; shift 2;;
|
||||
--custom-image) CUSTOM_IMAGE="$2"; shift 2;;
|
||||
--build-from-dockerfile) BUILD_FROM_DOCKERFILE="1"; shift;;
|
||||
--gpu-arch) GPU_ARCH_BUILD="$2"; shift 2;;
|
||||
--rocm-version)
|
||||
ROCM_VERSION="$2"
|
||||
MI30X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi30x"
|
||||
MI35X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi35x"
|
||||
echo "Using ROCm version override: ${ROCM_VERSION}"
|
||||
shift 2;;
|
||||
-h|--help)
|
||||
echo "Usage: $0 [OPTIONS]"
|
||||
echo "Options:"
|
||||
echo " --mi30x-base-tag TAG Override MI30x base image tag"
|
||||
echo " --mi35x-base-tag TAG Override MI35x base image tag"
|
||||
echo " --custom-image IMAGE Use a specific Docker image directly"
|
||||
echo " --build-from-dockerfile Build image from docker/rocm.Dockerfile"
|
||||
echo " --gpu-arch ARCH GPU architecture for Dockerfile build (e.g., gfx950-rocm720)"
|
||||
echo " --rocm-version VERSION Override ROCm version for image lookup (e.g., rocm720)"
|
||||
exit 0
|
||||
;;
|
||||
*) echo "Unknown option $1"; exit 1;;
|
||||
esac
|
||||
done
|
||||
|
||||
|
||||
|
||||
# Detect GPU architecture from the Kubernetes runner hostname
|
||||
HOSTNAME_VALUE=$(hostname)
|
||||
GPU_ARCH="mi30x" # default
|
||||
|
||||
# Host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz
|
||||
if [[ "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
|
||||
GPU_ARCH="${BASH_REMATCH[1]}"
|
||||
echo "Detected GPU architecture from hostname: ${GPU_ARCH}"
|
||||
else
|
||||
echo "Warning: could not parse GPU architecture from '${HOSTNAME_VALUE}', defaulting to ${GPU_ARCH}"
|
||||
fi
|
||||
|
||||
# Normalise / collapse architectures we don't yet build specifically for
|
||||
case "${GPU_ARCH}" in
|
||||
mi35x)
|
||||
echo "Runner uses ${GPU_ARCH}; will fetch mi35x image."
|
||||
;;
|
||||
mi30x|mi300|mi325)
|
||||
echo "Runner uses ${GPU_ARCH}; will fetch mi30x image."
|
||||
GPU_ARCH="mi30x"
|
||||
;;
|
||||
*)
|
||||
echo "Runner architecture '${GPU_ARCH}' unrecognised; defaulting to mi30x image." >&2
|
||||
GPU_ARCH="mi30x"
|
||||
;;
|
||||
esac
|
||||
|
||||
|
||||
# Set up DEVICE_FLAG based on Kubernetes pod info
|
||||
if [[ -f /etc/podinfo/gha-render-devices ]]; then
|
||||
DEVICE_FLAG=$(cat /etc/podinfo/gha-render-devices)
|
||||
else
|
||||
DEVICE_FLAG="--device /dev/dri"
|
||||
fi
|
||||
|
||||
|
||||
# Find the latest image
|
||||
find_latest_image() {
|
||||
local gpu_arch=$1
|
||||
local base_tag days_back image_tag
|
||||
|
||||
case "${gpu_arch}" in
|
||||
mi30x) base_tag="${MI30X_BASE_TAG}" ;;
|
||||
mi35x) base_tag="${MI35X_BASE_TAG}" ;;
|
||||
*) echo "Error: unsupported GPU architecture '${gpu_arch}'" >&2; return 1 ;;
|
||||
esac
|
||||
|
||||
# First, check local cache
|
||||
for days_back in {0..6}; do
|
||||
image_tag="${base_tag}-$(date -d "${days_back} days ago" +%Y%m%d)"
|
||||
local local_image="rocm/sgl-dev:${image_tag}"
|
||||
image_id=$(docker images -q "${local_image}")
|
||||
if [[ -n "$image_id" ]]; then
|
||||
echo "Found cached image locally: ${local_image}" >&2
|
||||
echo "${local_image}"
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
|
||||
# If not found locally, fall back to pulling from public registry
|
||||
for days_back in {0..6}; do
|
||||
image_tag="${base_tag}-$(date -d "${days_back} days ago" +%Y%m%d)"
|
||||
echo "Checking for image: rocm/sgl-dev:${image_tag}" >&2
|
||||
if docker manifest inspect "rocm/sgl-dev:${image_tag}" >/dev/null 2>&1; then
|
||||
echo "Found available image: rocm/sgl-dev:${image_tag}" >&2
|
||||
echo "rocm/sgl-dev:${image_tag}"
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
|
||||
# If still not found, try finding any image matching ROCm+arch from remote registry
|
||||
echo "Exact version not found. Searching remote registry for any ${ROCM_VERSION}-${gpu_arch} image…" >&2
|
||||
for days_back in {0..6}; do
|
||||
local target_date=$(date -d "${days_back} days ago" +%Y%m%d)
|
||||
local remote_tags=$(curl -s "https://registry.hub.docker.com/v2/repositories/rocm/sgl-dev/tags?page_size=100&name=${ROCM_VERSION}-${gpu_arch}-${target_date}" 2>/dev/null | grep -o '"name":"[^"]*"' | cut -d'"' -f4 | head -n 1)
|
||||
if [[ -n "$remote_tags" ]]; then
|
||||
echo "Found available image: rocm/sgl-dev:${remote_tags}" >&2
|
||||
echo "rocm/sgl-dev:${remote_tags}"
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
|
||||
echo "No recent images found. Searching any cached local images matching ROCm+arch…" >&2
|
||||
local any_local
|
||||
any_local=$(docker images --format '{{.Repository}}:{{.Tag}}' --filter "reference=rocm/sgl-dev:*${ROCM_VERSION}*${gpu_arch}*" | sort -r | head -n 1)
|
||||
if [[ -n "$any_local" ]]; then
|
||||
echo "Using cached fallback image: ${any_local}" >&2
|
||||
echo "${any_local}"
|
||||
return 0
|
||||
fi
|
||||
|
||||
echo "Error: no ${gpu_arch} image found in the last 7 days for base ${base_tag}" >&2
|
||||
echo "Using hard-coded fallback for ${ROCM_VERSION}…" >&2
|
||||
case "${ROCM_VERSION}" in
|
||||
rocm720)
|
||||
if [[ "${gpu_arch}" == "mi35x" ]]; then
|
||||
echo "rocm/sgl-dev:v0.5.8.post1-rocm720-mi35x-20260211-preview"
|
||||
else
|
||||
echo "rocm/sgl-dev:v0.5.8.post1-rocm720-mi30x-20260211-preview"
|
||||
fi
|
||||
;;
|
||||
rocm700)
|
||||
if [[ "${gpu_arch}" == "mi35x" ]]; then
|
||||
echo "rocm/sgl-dev:v0.5.8.post1-rocm700-mi35x-20260211"
|
||||
else
|
||||
echo "rocm/sgl-dev:v0.5.8.post1-rocm700-mi30x-20260211"
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
echo "Error: no hard-coded fallback available for ${ROCM_VERSION}" >&2
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
# Determine which image to use
|
||||
if [[ -n "${CUSTOM_IMAGE}" ]]; then
|
||||
# Use explicitly provided custom image
|
||||
IMAGE="${CUSTOM_IMAGE}"
|
||||
echo "Using custom image: ${IMAGE}"
|
||||
docker pull "${IMAGE}"
|
||||
elif [[ -n "${BUILD_FROM_DOCKERFILE}" ]]; then
|
||||
# Build image from Dockerfile
|
||||
if [[ -z "${GPU_ARCH_BUILD}" ]]; then
|
||||
echo "Error: --gpu-arch is required when using --build-from-dockerfile" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DOCKERFILE_DIR="${GITHUB_WORKSPACE:-$PWD}/docker"
|
||||
DOCKERFILE="${DOCKERFILE_DIR}/rocm.Dockerfile"
|
||||
|
||||
if [[ ! -f "${DOCKERFILE}" ]]; then
|
||||
echo "Error: Dockerfile not found at ${DOCKERFILE}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
IMAGE="sglang-ci:${GPU_ARCH_BUILD}-$(date +%Y%m%d)"
|
||||
echo "Building Docker image from ${DOCKERFILE} with GPU_ARCH=${GPU_ARCH_BUILD}..."
|
||||
|
||||
# Pass full GPU_ARCH (e.g., gfx950-rocm720) - Dockerfile handles stripping suffix
|
||||
docker build \
|
||||
--build-arg GPU_ARCH="${GPU_ARCH_BUILD}" \
|
||||
--build-arg SGL_BRANCH="main" \
|
||||
-t "${IMAGE}" \
|
||||
-f "${DOCKERFILE}" \
|
||||
"${DOCKERFILE_DIR}"
|
||||
echo "Successfully built image: ${IMAGE}"
|
||||
else
|
||||
# Find the latest pre-built image
|
||||
IMAGE=$(find_latest_image "${GPU_ARCH}")
|
||||
echo "Pulling Docker image: ${IMAGE}"
|
||||
docker pull "${IMAGE}"
|
||||
fi
|
||||
|
||||
CACHE_HOST=/home/runner/sgl-data
|
||||
if [[ -d "$CACHE_HOST" ]]; then
|
||||
CACHE_VOLUME="-v $CACHE_HOST:/sgl-data"
|
||||
else
|
||||
CACHE_VOLUME=""
|
||||
fi
|
||||
|
||||
echo "Launching container: ci_sglang"
|
||||
docker run -dt --user root --device=/dev/kfd ${DEVICE_FLAG} \
|
||||
--ulimit nofile=65536:65536 \
|
||||
-v "${GITHUB_WORKSPACE:-$PWD}:/sglang-checkout" \
|
||||
$CACHE_VOLUME \
|
||||
--group-add video \
|
||||
--shm-size 32g \
|
||||
--cap-add=SYS_PTRACE \
|
||||
-e HF_TOKEN="${HF_TOKEN:-}" \
|
||||
-e HF_HOME=/sgl-data/hf-cache \
|
||||
-e HF_HUB_ETAG_TIMEOUT=300 \
|
||||
-e HF_HUB_DOWNLOAD_TIMEOUT=300 \
|
||||
-e MIOPEN_USER_DB_PATH=/sgl-data/miopen-cache \
|
||||
-e MIOPEN_CUSTOM_CACHE_DIR=/sgl-data/miopen-cache \
|
||||
-e PYTHONPATH="/opt/tilelang:${PYTHONPATH:-}" \
|
||||
--security-opt seccomp=unconfined \
|
||||
-w /sglang-checkout \
|
||||
--name ci_sglang \
|
||||
"${IMAGE}"
|
||||
|
||||
# The checkout is owned by the runner (non-root) but the container runs as
|
||||
# root. Git >= 2.35.2 rejects cross-user repos; mark the mount as safe so
|
||||
# setuptools-scm / vcs_versioning can resolve the package version.
|
||||
docker exec ci_sglang git config --global --add safe.directory /sglang-checkout
|
||||
270
third_party/sglang/scripts/ci/amd/amd_ci_start_container_disagg.sh
vendored
Executable file
270
third_party/sglang/scripts/ci/amd/amd_ci_start_container_disagg.sh
vendored
Executable file
@@ -0,0 +1,270 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
# Get version from git tags
|
||||
SGLANG_VERSION="v0.5.5" # Default version, will be overridden if git tags are found
|
||||
|
||||
# Fetch tags from origin to ensure we have the latest
|
||||
if git fetch --tags origin; then
|
||||
# Get the latest version tag sorted by version number (e.g., v0.5.7)
|
||||
VERSION_FROM_TAG=$(git tag -l 'v[0-9]*' --sort=-v:refname | head -1)
|
||||
if [ -n "$VERSION_FROM_TAG" ]; then
|
||||
SGLANG_VERSION="$VERSION_FROM_TAG"
|
||||
echo "Using SGLang version from git tags: $SGLANG_VERSION"
|
||||
else
|
||||
echo "Warning: No version tags found; using default $SGLANG_VERSION" >&2
|
||||
fi
|
||||
else
|
||||
echo "Warning: Failed to fetch tags from origin; using default $SGLANG_VERSION" >&2
|
||||
fi
|
||||
|
||||
|
||||
# Default base tags (can be overridden by command line arguments)
|
||||
ROCM_VERSION="rocm700"
|
||||
DEFAULT_MI30X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi30x"
|
||||
DEFAULT_MI35X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi35x"
|
||||
|
||||
# Parse command line arguments
|
||||
MI30X_BASE_TAG="${DEFAULT_MI30X_BASE_TAG}"
|
||||
MI35X_BASE_TAG="${DEFAULT_MI35X_BASE_TAG}"
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--mi30x-base-tag) MI30X_BASE_TAG="$2"; shift 2;;
|
||||
--mi35x-base-tag) MI35X_BASE_TAG="$2"; shift 2;;
|
||||
--rocm-version)
|
||||
ROCM_VERSION="$2"
|
||||
MI30X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi30x"
|
||||
MI35X_BASE_TAG="${SGLANG_VERSION}-${ROCM_VERSION}-mi35x"
|
||||
echo "Using ROCm version override: ${ROCM_VERSION}"
|
||||
shift 2;;
|
||||
-h|--help)
|
||||
echo "Usage: $0 [--mi30x-base-tag TAG] [--mi35x-base-tag TAG] [--rocm-version VERSION]"
|
||||
exit 0
|
||||
;;
|
||||
*) echo "Unknown option $1"; exit 1;;
|
||||
esac
|
||||
done
|
||||
|
||||
|
||||
|
||||
# Detect GPU architecture from the Kubernetes runner hostname
|
||||
HOSTNAME_VALUE=$(hostname)
|
||||
GPU_ARCH="mi30x" # default
|
||||
|
||||
# Host names look like: linux-mi35x-gpu-1-xxxxx-runner-zzzzz
|
||||
if [[ "${HOSTNAME_VALUE}" =~ ^linux-(mi[0-9]+[a-z]*)-gpu-[0-9]+ ]]; then
|
||||
GPU_ARCH="${BASH_REMATCH[1]}"
|
||||
echo "Detected GPU architecture from hostname: ${GPU_ARCH}"
|
||||
else
|
||||
echo "Warning: could not parse GPU architecture from '${HOSTNAME_VALUE}', defaulting to ${GPU_ARCH}"
|
||||
fi
|
||||
|
||||
# Normalise / collapse architectures we don’t yet build specifically for
|
||||
case "${GPU_ARCH}" in
|
||||
mi35x)
|
||||
echo "Runner uses ${GPU_ARCH}; will fetch mi35x image."
|
||||
;;
|
||||
mi30x|mi300|mi325)
|
||||
echo "Runner uses ${GPU_ARCH}; will fetch mi30x image."
|
||||
GPU_ARCH="mi30x"
|
||||
;;
|
||||
*)
|
||||
echo "Runner architecture '${GPU_ARCH}' unrecognised; defaulting to mi30x image." >&2
|
||||
GPU_ARCH="mi30x"
|
||||
;;
|
||||
esac
|
||||
|
||||
|
||||
# Set up DEVICE_FLAG based on Kubernetes pod info
|
||||
if [[ -f /etc/podinfo/gha-render-devices ]]; then
|
||||
DEVICE_FLAG=$(cat /etc/podinfo/gha-render-devices)
|
||||
else
|
||||
DEVICE_FLAG="--device /dev/dri"
|
||||
fi
|
||||
|
||||
|
||||
# Find the latest image
|
||||
find_latest_image() {
|
||||
local gpu_arch=$1
|
||||
local base_tag days_back image_tag
|
||||
|
||||
case "${gpu_arch}" in
|
||||
mi30x) base_tag="${MI30X_BASE_TAG}" ;;
|
||||
mi35x) base_tag="${MI35X_BASE_TAG}" ;;
|
||||
*) echo "Error: unsupported GPU architecture '${gpu_arch}'" >&2; return 1 ;;
|
||||
esac
|
||||
|
||||
# First, check local cache
|
||||
for days_back in {0..6}; do
|
||||
image_tag="${base_tag}-$(date -d "${days_back} days ago" +%Y%m%d)"
|
||||
local local_image="rocm/sgl-dev:${image_tag}"
|
||||
image_id=$(docker images -q "${local_image}")
|
||||
if [[ -n "$image_id" ]]; then
|
||||
echo "Found cached image locally: ${local_image}" >&2
|
||||
echo "${local_image}"
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
|
||||
# If not found locally, fall back to pulling from public registry
|
||||
for days_back in {0..6}; do
|
||||
image_tag="${base_tag}-$(date -d "${days_back} days ago" +%Y%m%d)"
|
||||
echo "Checking for image: rocm/sgl-dev:${image_tag}" >&2
|
||||
if docker manifest inspect "rocm/sgl-dev:${image_tag}" >/dev/null 2>&1; then
|
||||
echo "Found available image: rocm/sgl-dev:${image_tag}" >&2
|
||||
echo "rocm/sgl-dev:${image_tag}"
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
|
||||
# If still not found, try finding any image matching ROCm+arch from remote registry
|
||||
echo "Exact version not found. Searching remote registry for any ${ROCM_VERSION}-${gpu_arch} image…" >&2
|
||||
for days_back in {0..6}; do
|
||||
local target_date=$(date -d "${days_back} days ago" +%Y%m%d)
|
||||
local remote_tags=$(curl -s "https://registry.hub.docker.com/v2/repositories/rocm/sgl-dev/tags?page_size=100&name=${ROCM_VERSION}-${gpu_arch}-${target_date}" 2>/dev/null | grep -o '"name":"[^"]*"' | cut -d'"' -f4 | head -n 1)
|
||||
if [[ -n "$remote_tags" ]]; then
|
||||
echo "Found available image: rocm/sgl-dev:${remote_tags}" >&2
|
||||
echo "rocm/sgl-dev:${remote_tags}"
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
|
||||
echo "No recent images found. Searching any cached local images matching ROCm+arch…" >&2
|
||||
local any_local
|
||||
any_local=$(docker images --format '{{.Repository}}:{{.Tag}}' --filter "reference=rocm/sgl-dev:*${ROCM_VERSION}*${gpu_arch}*" | sort -r | head -n 1)
|
||||
if [[ -n "$any_local" ]]; then
|
||||
echo "Using cached fallback image: ${any_local}" >&2
|
||||
echo "${any_local}"
|
||||
return 0
|
||||
fi
|
||||
|
||||
echo "Error: no ${gpu_arch} image found in the last 7 days for base ${base_tag}" >&2
|
||||
echo "Using hard-coded fallback for ${ROCM_VERSION}…" >&2
|
||||
case "${ROCM_VERSION}" in
|
||||
rocm720)
|
||||
if [[ "${gpu_arch}" == "mi35x" ]]; then
|
||||
echo "rocm/sgl-dev:v0.5.8.post1-rocm720-mi35x-20260211-preview"
|
||||
else
|
||||
echo "rocm/sgl-dev:v0.5.8.post1-rocm720-mi30x-20260211-preview"
|
||||
fi
|
||||
;;
|
||||
rocm700)
|
||||
if [[ "${gpu_arch}" == "mi35x" ]]; then
|
||||
echo "rocm/sgl-dev:v0.5.8.post1-rocm700-mi35x-20260211"
|
||||
else
|
||||
echo "rocm/sgl-dev:v0.5.8.post1-rocm700-mi30x-20260211"
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
echo "Error: no hard-coded fallback available for ${ROCM_VERSION}" >&2
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
# Pull and run the latest image
|
||||
IMAGE=$(find_latest_image "${GPU_ARCH}")
|
||||
echo "Pulling Docker image: ${IMAGE}"
|
||||
docker pull "${IMAGE}"
|
||||
|
||||
CACHE_HOST=/home/runner/sgl-data
|
||||
if [[ -d "$CACHE_HOST" ]]; then
|
||||
CACHE_VOLUME="-v $CACHE_HOST:/sgl-data"
|
||||
else
|
||||
CACHE_VOLUME=""
|
||||
fi
|
||||
|
||||
# Detect libionic library for RDMA support
|
||||
LIBIONIC_MOUNT=""
|
||||
IONIC_SYMLINK="/usr/lib/x86_64-linux-gnu/libibverbs/libionic-rdmav34.so"
|
||||
if [[ -L "$IONIC_SYMLINK" ]]; then
|
||||
LIBIONIC_LIB=$(readlink -f "$IONIC_SYMLINK" 2>/dev/null)
|
||||
if [[ -f "$LIBIONIC_LIB" ]]; then
|
||||
echo "Found libionic library: $LIBIONIC_LIB (resolved from symlink)"
|
||||
LIBIONIC_MOUNT="-v ${LIBIONIC_LIB}:${LIBIONIC_LIB}:ro"
|
||||
else
|
||||
echo "Warning: libionic symlink exists but target does not: $LIBIONIC_LIB"
|
||||
fi
|
||||
else
|
||||
# Fallback: try to find directly
|
||||
LIBIONIC_FOUND=$(find /usr/lib/x86_64-linux-gnu -maxdepth 1 -name "libionic.so.*" 2>/dev/null | head -1)
|
||||
if [[ -n "$LIBIONIC_FOUND" ]]; then
|
||||
LIBIONIC_LIB=$(readlink -f "$LIBIONIC_FOUND" 2>/dev/null)
|
||||
if [[ -f "$LIBIONIC_LIB" ]]; then
|
||||
echo "Found libionic library: $LIBIONIC_LIB"
|
||||
LIBIONIC_MOUNT="-v ${LIBIONIC_LIB}:${LIBIONIC_LIB}:ro"
|
||||
else
|
||||
echo "Warning: libionic found but cannot resolve real path: $LIBIONIC_FOUND"
|
||||
fi
|
||||
else
|
||||
echo "Warning: libionic library not found on host, RDMA may not work"
|
||||
fi
|
||||
fi
|
||||
|
||||
MOUNT_ARGS=""
|
||||
|
||||
add_mount_if_exists() {
|
||||
local name=$1
|
||||
local search_pattern=$2
|
||||
local path=$(find /lib/x86_64-linux-gnu /usr/lib/x86_64-linux-gnu /lib64 /usr/lib64 -name "$search_pattern" -print -quit 2>/dev/null)
|
||||
|
||||
if [ -n "$path" ]; then
|
||||
echo "Found $name at: $path"
|
||||
MOUNT_ARGS="$MOUNT_ARGS -v $path:$path:ro"
|
||||
else
|
||||
echo "WARNING: Could not find $name on host! (Pattern: $search_pattern)"
|
||||
fi
|
||||
}
|
||||
|
||||
IONIC_LINK="/usr/lib/x86_64-linux-gnu/libibverbs/libionic-rdmav34.so"
|
||||
if [ -L "$IONIC_LINK" ]; then
|
||||
IONIC_REAL=$(readlink -f "$IONIC_LINK")
|
||||
if [ -f "$IONIC_REAL" ]; then
|
||||
echo "Ionic Driver: $IONIC_REAL"
|
||||
MOUNT_ARGS="$MOUNT_ARGS -v $IONIC_REAL:$IONIC_REAL:ro"
|
||||
fi
|
||||
fi
|
||||
|
||||
add_mount_if_exists "libnl-3" "libnl-3.so*"
|
||||
add_mount_if_exists "libmnl" "libmnl.so*"
|
||||
|
||||
echo "Mount args: $MOUNT_ARGS"
|
||||
|
||||
echo "Launching container: ci_sglang"
|
||||
docker run -dt --user root \
|
||||
--device=/dev/kfd \
|
||||
--device=/dev/dri \
|
||||
${DEVICE_FLAG} \
|
||||
-v "${GITHUB_WORKSPACE:-$PWD}:/sglang-checkout" \
|
||||
-v /sys/class/infiniband:/sys/class/infiniband:ro \
|
||||
-v /sys/class/infiniband_verbs:/sys/class/infiniband_verbs:ro \
|
||||
-v /sys/class/net:/sys/class/net:ro \
|
||||
-v /etc/libibverbs.d:/etc/libibverbs.d:ro \
|
||||
-v /usr/lib/x86_64-linux-gnu/libibverbs:/usr/lib/x86_64-linux-gnu/libibverbs:ro \
|
||||
$MOUNT_ARGS \
|
||||
$CACHE_VOLUME \
|
||||
--privileged \
|
||||
--network=host \
|
||||
--ipc=host \
|
||||
--ulimit memlock=-1 \
|
||||
--cap-add=IPC_LOCK \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--security-opt seccomp=unconfined \
|
||||
--group-add video \
|
||||
--group-add rdma \
|
||||
--shm-size 32g \
|
||||
-e HF_TOKEN="${HF_TOKEN:-}" \
|
||||
-e HF_HOME=/sgl-data/hf-cache \
|
||||
-e HF_HUB_ETAG_TIMEOUT=300 \
|
||||
-e HF_HUB_DOWNLOAD_TIMEOUT=300 \
|
||||
-e MIOPEN_USER_DB_PATH=/sgl-data/miopen-cache \
|
||||
-e MIOPEN_CUSTOM_CACHE_DIR=/sgl-data/miopen-cache \
|
||||
-w /sglang-checkout \
|
||||
--name ci_sglang \
|
||||
"${IMAGE}"
|
||||
|
||||
# The checkout is owned by the runner (non-root) but the container runs as
|
||||
# root. Git >= 2.35.2 rejects cross-user repos; mark the mount as safe so
|
||||
# setuptools-scm / vcs_versioning can resolve the package version.
|
||||
docker exec ci_sglang git config --global --add safe.directory /sglang-checkout
|
||||
151
third_party/sglang/scripts/ci/amd/amd_ci_warmup_aiter.py
vendored
Executable file
151
third_party/sglang/scripts/ci/amd/amd_ci_warmup_aiter.py
vendored
Executable file
@@ -0,0 +1,151 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Warmup script to pre-build AITER JIT kernels.
|
||||
|
||||
This script triggers compilation of commonly used AITER kernels by importing
|
||||
the relevant modules and calling functions with sample data. This avoids
|
||||
timeouts during actual tests when kernels need to be compiled on first use.
|
||||
|
||||
Run this after clearing pre-built AITER kernels from the Docker image.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
# Ensure AITER is enabled
|
||||
os.environ["SGLANG_USE_AITER"] = "1"
|
||||
|
||||
|
||||
def warmup_aiter_kernels():
|
||||
"""Trigger AITER JIT kernel compilation."""
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
print("CUDA/ROCm not available, skipping AITER warmup")
|
||||
return
|
||||
|
||||
print("=" * 60)
|
||||
print("AITER JIT Kernel Warmup")
|
||||
print("=" * 60)
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
start_time = time.time()
|
||||
|
||||
# Warmup module_rmsnorm_quant (small module, ~2MB)
|
||||
# Triggered by rmsnorm2d_fwd when hidden_size <= 8192
|
||||
try:
|
||||
print(
|
||||
"\n[1/5] Warming up module_rmsnorm_quant (rmsnorm2d_fwd, hidden<=8192)..."
|
||||
)
|
||||
from aiter import rmsnorm2d_fwd
|
||||
|
||||
hidden_size = 4096
|
||||
batch_size = 512 # Use larger batch to match CUDA graph capture
|
||||
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
|
||||
weight = torch.ones(hidden_size, dtype=torch.bfloat16, device=device)
|
||||
eps = 1e-6
|
||||
|
||||
# hidden_size=4096 <= 8192 -> takes rmsnorm() path -> compiles module_rmsnorm_quant
|
||||
_ = rmsnorm2d_fwd(x, weight, eps)
|
||||
torch.cuda.synchronize()
|
||||
print(" module_rmsnorm_quant compiled successfully")
|
||||
except Exception as e:
|
||||
print(f" module_rmsnorm_quant warmup failed: {e}")
|
||||
|
||||
# Warmup module_rmsnorm (large CK module, ~159MB)
|
||||
# Triggered by rmsnorm2d_fwd_with_add (always uses CK path)
|
||||
# NOTE: rmsnorm2d_fwd_with_add signature is:
|
||||
# rmsnorm2d_fwd_with_add(out, input, residual_in, residual_out, weight, epsilon)
|
||||
try:
|
||||
print("\n[2/5] Warming up module_rmsnorm (rmsnorm2d_fwd_with_add, CK path)...")
|
||||
from aiter import rmsnorm2d_fwd_with_add
|
||||
|
||||
hidden_size = 4096
|
||||
batch_size = 512
|
||||
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
|
||||
residual_in = torch.randn(
|
||||
batch_size, hidden_size, dtype=torch.bfloat16, device=device
|
||||
)
|
||||
output = torch.empty_like(x)
|
||||
residual_out = torch.empty_like(x)
|
||||
weight = torch.ones(hidden_size, dtype=torch.bfloat16, device=device)
|
||||
eps = 1e-6
|
||||
|
||||
# This triggers JIT compilation of module_rmsnorm (CK kernels)
|
||||
rmsnorm2d_fwd_with_add(output, x, residual_in, residual_out, weight, eps)
|
||||
torch.cuda.synchronize()
|
||||
print(" module_rmsnorm compiled successfully")
|
||||
except Exception as e:
|
||||
print(f" module_rmsnorm warmup failed: {e}")
|
||||
|
||||
# Warmup module_rmsnorm via rmsnorm2d_fwd with large hidden_size (CK path)
|
||||
# When hidden_size > 8192, rmsnorm2d_fwd takes the rmsnorm2d_fwd_ck path
|
||||
# which also uses module_rmsnorm (already compiled in step 2, but this
|
||||
# ensures the CK rmsnorm2d_fwd path is exercised as well)
|
||||
try:
|
||||
print("\n[3/5] Warming up rmsnorm2d_fwd CK path (hidden>8192)...")
|
||||
from aiter import rmsnorm2d_fwd
|
||||
|
||||
hidden_size = 16384 # > 8192 to trigger rmsnorm2d_fwd_ck (module_rmsnorm)
|
||||
batch_size = 32
|
||||
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
|
||||
weight = torch.ones(hidden_size, dtype=torch.bfloat16, device=device)
|
||||
eps = 1e-6
|
||||
|
||||
_ = rmsnorm2d_fwd(x, weight, eps)
|
||||
torch.cuda.synchronize()
|
||||
print(" rmsnorm2d_fwd CK path compiled successfully")
|
||||
except Exception as e:
|
||||
print(f" rmsnorm2d_fwd CK path warmup skipped: {e}")
|
||||
|
||||
# Warmup rotary embedding kernel if available
|
||||
try:
|
||||
print("\n[4/5] Warming up rotary embedding kernel...")
|
||||
from aiter import rotary_embedding
|
||||
|
||||
head_size = 128
|
||||
seq_len = 32
|
||||
num_heads = 32
|
||||
positions = torch.arange(seq_len, device=device)
|
||||
query = torch.randn(
|
||||
seq_len, num_heads, head_size, dtype=torch.bfloat16, device=device
|
||||
)
|
||||
key = torch.randn(
|
||||
seq_len, num_heads, head_size, dtype=torch.bfloat16, device=device
|
||||
)
|
||||
cos = torch.ones(seq_len, head_size // 2, dtype=torch.bfloat16, device=device)
|
||||
sin = torch.zeros(seq_len, head_size // 2, dtype=torch.bfloat16, device=device)
|
||||
|
||||
_ = rotary_embedding(positions, query, key, head_size, cos, sin, True)
|
||||
torch.cuda.synchronize()
|
||||
print(" Rotary embedding kernel compiled successfully")
|
||||
except Exception as e:
|
||||
print(f" Rotary embedding warmup skipped (may not be available): {e}")
|
||||
|
||||
# Warmup activation kernels if available
|
||||
try:
|
||||
print("\n[5/5] Warming up activation kernels...")
|
||||
from aiter import silu_and_mul
|
||||
|
||||
hidden_size = 4096
|
||||
batch_size = 512
|
||||
x = torch.randn(
|
||||
batch_size, hidden_size * 2, dtype=torch.bfloat16, device=device
|
||||
)
|
||||
out = torch.empty(batch_size, hidden_size, dtype=torch.bfloat16, device=device)
|
||||
|
||||
silu_and_mul(out, x)
|
||||
torch.cuda.synchronize()
|
||||
print(" Activation kernel compiled successfully")
|
||||
except Exception as e:
|
||||
print(f" Activation warmup skipped (may not be available): {e}")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print("\n" + "=" * 60)
|
||||
print(f"AITER warmup completed in {elapsed:.1f}s")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
warmup_aiter_kernels()
|
||||
27
third_party/sglang/scripts/ci/amd/check_vram_clear.sh
vendored
Executable file
27
third_party/sglang/scripts/ci/amd/check_vram_clear.sh
vendored
Executable file
@@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
|
||||
check_vram_clear() {
|
||||
local vram_threshold_percent=5 # Allow up to 5% VRAM usage
|
||||
local memory_threshold_mb=500 # Allow up to 500MB memory usage
|
||||
|
||||
if command -v rocm-smi >/dev/null 2>&1; then
|
||||
echo "Checking ROCm GPU VRAM usage..."
|
||||
# Check if any GPU has more than threshold VRAM allocated
|
||||
local high_usage=$(rocm-smi --showmemuse | grep -E "GPU Memory Allocated \(VRAM%\): ([6-9]|[1-9][0-9]|100)")
|
||||
if [ -n "$high_usage" ]; then
|
||||
echo "ERROR: VRAM usage exceeds threshold (${vram_threshold_percent}%) on some GPUs:"
|
||||
echo "$high_usage"
|
||||
rocm-smi --showmemuse
|
||||
return 1
|
||||
else
|
||||
echo "✓ VRAM usage is within acceptable limits on all GPUs"
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# If this script is run directly (not sourced), run the check
|
||||
if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
|
||||
set -e
|
||||
check_vram_clear
|
||||
fi
|
||||
103
third_party/sglang/scripts/ci/amd/ensure_vram_clear.sh
vendored
Executable file
103
third_party/sglang/scripts/ci/amd/ensure_vram_clear.sh
vendored
Executable file
@@ -0,0 +1,103 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Source the VRAM checking function
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
source "$SCRIPT_DIR/check_vram_clear.sh"
|
||||
|
||||
ensure_vram_clear() {
|
||||
local max_retries=3
|
||||
local retry_count=0
|
||||
|
||||
# Stop and remove any existing ci_sglang container
|
||||
echo "Stopping any existing ci_sglang container..."
|
||||
docker stop ci_sglang || true
|
||||
docker rm ci_sglang || true
|
||||
|
||||
# Log host information for debugging
|
||||
echo "=== Host Information ==="
|
||||
echo "Hostname: $(hostname)"
|
||||
echo "Host IP: $(hostname -I 2>/dev/null || echo 'N/A')"
|
||||
echo "Date: $(date)"
|
||||
echo "Mode: rocm"
|
||||
echo "========================"
|
||||
echo "Running in ROCm mode"
|
||||
|
||||
# Show initial GPU status
|
||||
echo "=== Initial GPU Memory Status ==="
|
||||
rocm-smi --showmemuse
|
||||
echo "=================================="
|
||||
|
||||
while [ $retry_count -lt $max_retries ]; do
|
||||
echo "=== Cleanup Attempt $((retry_count + 1))/$max_retries ==="
|
||||
|
||||
# Clean SGLang processes
|
||||
echo "Killing SGLang processes..."
|
||||
pgrep -f 'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt' | xargs -r kill -9 || true
|
||||
|
||||
if [ $retry_count -gt 0 ]; then
|
||||
echo "Performing aggressive cleanup..."
|
||||
# Kill all processes using KFD
|
||||
rocm-smi --showpids 2>/dev/null | grep 'PID:' | awk '{print $2}' | xargs -r kill -9 2>/dev/null || true
|
||||
# Wait a bit for cleanup to take effect
|
||||
echo "Waiting 30 seconds for VRAM to clear..."
|
||||
sleep 30
|
||||
fi
|
||||
|
||||
# Check VRAM
|
||||
echo "Checking VRAM status..."
|
||||
if check_vram_clear; then
|
||||
echo "✓ VRAM cleanup successful after $((retry_count + 1)) attempts"
|
||||
return 0
|
||||
else
|
||||
echo "✗ VRAM still not clear after attempt $((retry_count + 1))"
|
||||
retry_count=$((retry_count + 1))
|
||||
fi
|
||||
done
|
||||
|
||||
# Failed after all retries
|
||||
echo "=== FAILED: VRAM cleanup unsuccessful after $max_retries attempts ==="
|
||||
echo "Final GPU status:"
|
||||
timeout 30 rocm-smi --showmemuse || echo "rocm-smi timed out"
|
||||
echo "Processes using GPU:"
|
||||
rocm-smi --showpids 2>/dev/null | grep -q 'PID:' || echo "No processes found using /dev/kfd"
|
||||
|
||||
# Print detailed information about suspicious processes
|
||||
echo "=== Detailed Process Information ==="
|
||||
if command -v rocm-smi >/dev/null 2>&1; then
|
||||
# For AMD GPUs, get processes from rocm-smi --showpids
|
||||
kfd_pids=$(rocm-smi --showpids 2>/dev/null | grep 'PID:' | awk '{print $2}' | sort -u)
|
||||
if [ -n "$kfd_pids" ]; then
|
||||
echo "Processes accessing /dev/kfd (AMD GPU device):"
|
||||
for pid in $kfd_pids; do
|
||||
if ps -p $pid -o pid,ppid,cmd --no-headers 2>/dev/null; then
|
||||
echo " └─ Command line: $(ps -p $pid -o cmd --no-headers 2>/dev/null | head -1)"
|
||||
else
|
||||
echo " └─ PID $pid: Process not found or already terminated"
|
||||
fi
|
||||
done
|
||||
else
|
||||
echo "No processes found accessing /dev/kfd"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check for any remaining sglang-related processes
|
||||
echo "Checking for any remaining sglang-related processes:"
|
||||
sglang_procs=$(pgrep -f 'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt' 2>/dev/null)
|
||||
if [ -n "$sglang_procs" ]; then
|
||||
echo "Found sglang processes still running:"
|
||||
for pid in $sglang_procs; do
|
||||
ps -p $pid -o pid,ppid,cmd --no-headers 2>/dev/null || echo "PID $pid not found"
|
||||
done
|
||||
else
|
||||
echo "No sglang-related processes found."
|
||||
fi
|
||||
|
||||
echo "=================================================================="
|
||||
return 1
|
||||
}
|
||||
|
||||
# If this script is run directly (not sourced), run the ensure function
|
||||
if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
|
||||
set -e
|
||||
ensure_vram_clear "$@"
|
||||
fi
|
||||
61
third_party/sglang/scripts/ci/amd/test_rccl_multi_gpu.py
vendored
Executable file
61
third_party/sglang/scripts/ci/amd/test_rccl_multi_gpu.py
vendored
Executable file
@@ -0,0 +1,61 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple RCCL test for multi-GPU communication.
|
||||
This test verifies that RCCL can initialize and communicate across multiple GPUs.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def test_rccl_allreduce():
|
||||
"""Test basic RCCL allreduce operation across all GPUs."""
|
||||
if not torch.cuda.is_available():
|
||||
print("CUDA not available, skipping test")
|
||||
sys.exit(1)
|
||||
|
||||
# Initialize process group with NCCL (RCCL on AMD)
|
||||
dist.init_process_group(backend="nccl")
|
||||
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
print(f"[Rank {rank}/{world_size}] Initialized successfully")
|
||||
|
||||
# Set device
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
print(f"[Rank {rank}] Device: {torch.cuda.get_device_name(device)}")
|
||||
print(
|
||||
f"[Rank {rank}] Device memory: {torch.cuda.get_device_properties(device).total_memory / 1e9:.2f} GB"
|
||||
)
|
||||
|
||||
# Create a tensor and perform allreduce
|
||||
tensor = torch.ones(1000, device=device) * rank
|
||||
print(f"[Rank {rank}] Before allreduce: tensor sum = {tensor.sum().item()}")
|
||||
|
||||
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
|
||||
|
||||
expected_sum = sum(range(world_size)) * 1000
|
||||
actual_sum = tensor.sum().item()
|
||||
|
||||
print(
|
||||
f"[Rank {rank}] After allreduce: tensor sum = {actual_sum}, expected = {expected_sum}"
|
||||
)
|
||||
|
||||
if abs(actual_sum - expected_sum) < 0.1:
|
||||
print(f"[Rank {rank}] ✓ RCCL allreduce test PASSED")
|
||||
dist.destroy_process_group()
|
||||
sys.exit(0)
|
||||
else:
|
||||
print(f"[Rank {rank}] ✗ RCCL allreduce test FAILED")
|
||||
dist.destroy_process_group()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rccl_allreduce()
|
||||
58
third_party/sglang/scripts/ci/check_workflow_job_names.py
vendored
Executable file
58
third_party/sglang/scripts/ci/check_workflow_job_names.py
vendored
Executable file
@@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Check that required status check job names are unique across workflows.
|
||||
|
||||
Duplicate job names on the same commit allow a passing job in one workflow
|
||||
to satisfy a required status check meant for a different workflow, bypassing
|
||||
branch protection.
|
||||
|
||||
See: https://github.com/sgl-project/sglang/pull/20208 for an example where
|
||||
pr-test-npu.yml's "pr-test-finish" job (which passed) caused GitHub to treat
|
||||
the required "pr-test-finish" check (from pr-test.yml, which failed) as met.
|
||||
"""
|
||||
|
||||
import glob
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
|
||||
import yaml
|
||||
|
||||
# Job names used as required status checks in branch protection.
|
||||
# These MUST be unique across all workflow files.
|
||||
PROTECTED_JOB_NAMES = {
|
||||
"pr-test-finish",
|
||||
"lint",
|
||||
}
|
||||
|
||||
|
||||
def main() -> int:
|
||||
workflows = sorted(glob.glob(".github/workflows/*.yml"))
|
||||
job_to_files: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
for wf in workflows:
|
||||
with open(wf) as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not data or "jobs" not in data:
|
||||
continue
|
||||
for job in data["jobs"]:
|
||||
if job in PROTECTED_JOB_NAMES:
|
||||
job_to_files[job].append(wf)
|
||||
|
||||
duplicates = {job: files for job, files in job_to_files.items() if len(files) > 1}
|
||||
|
||||
if not duplicates:
|
||||
return 0
|
||||
|
||||
print("ERROR: Required status check job names must be unique across workflows.")
|
||||
print("Duplicates allow branch protection bypass via auto-merge.\n")
|
||||
for job, files in sorted(duplicates.items()):
|
||||
print(f" Job '{job}' appears in:")
|
||||
for f in files:
|
||||
print(f" - {f}")
|
||||
print()
|
||||
|
||||
print("Fix: rename the job in non-primary workflows to avoid collision.")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
24
third_party/sglang/scripts/ci/cuda/cache_nvidia_wheels.sh
vendored
Executable file
24
third_party/sglang/scripts/ci/cuda/cache_nvidia_wheels.sh
vendored
Executable file
@@ -0,0 +1,24 @@
|
||||
#!/bin/bash
|
||||
# Cache and pre-install nvidia wheels that torch pins.
|
||||
#
|
||||
# pypi.nvidia.com returns Cache-Control: no-store, so pip re-downloads
|
||||
# cudnn (~707 MB) and nvshmem (~125 MB) on every CI run. This script
|
||||
# caches the wheels locally and installs them so that the subsequent
|
||||
# `pip install -e "python[dev]"` sees "Requirement already satisfied".
|
||||
#
|
||||
# Integrity: uses `unzip -t` to detect partial/corrupt downloads.
|
||||
#
|
||||
# Usage: source scripts/ci/cuda/cache_nvidia_wheels.sh
|
||||
|
||||
NVIDIA_WHEEL_CACHE="/root/.cache/nvidia-wheels"
|
||||
mkdir -p "$NVIDIA_WHEEL_CACHE"
|
||||
|
||||
for url in \
|
||||
"https://pypi.nvidia.com/nvidia-cudnn-cu12/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl" \
|
||||
"https://pypi.nvidia.com/nvidia-nvshmem-cu12/nvidia_nvshmem_cu12-3.3.20-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"; do
|
||||
whl="$NVIDIA_WHEEL_CACHE/$(basename "$url")"
|
||||
[ -f "$whl" ] && unzip -tq "$whl" &>/dev/null || curl -fL -o "$whl" "$url"
|
||||
done
|
||||
|
||||
pip install --no-deps "$NVIDIA_WHEEL_CACHE"/nvidia_cudnn_cu12-*.whl \
|
||||
"$NVIDIA_WHEEL_CACHE"/nvidia_nvshmem_cu12-*.whl 2>/dev/null || true
|
||||
62
third_party/sglang/scripts/ci/cuda/ci_download_flashinfer_cubin.sh
vendored
Executable file
62
third_party/sglang/scripts/ci/cuda/ci_download_flashinfer_cubin.sh
vendored
Executable file
@@ -0,0 +1,62 @@
|
||||
#!/bin/bash
|
||||
# Download flashinfer cubins if the local set is incomplete.
|
||||
#
|
||||
# The flashinfer-cubin pip package may not include cubins for newer architectures
|
||||
# (e.g. sm_100, sm_120) due to PyPI size limits. This script checks the local
|
||||
# cubin status against the flashinfer artifact repository and downloads any
|
||||
# missing files.
|
||||
#
|
||||
# This script is best-effort: if the status check or download times out (e.g.
|
||||
# due to a GPU in error state blocking CUDA init), we warn and continue.
|
||||
# The pip package already includes cubins for common architectures (sm_80, sm_90).
|
||||
set -uxo pipefail
|
||||
|
||||
# Early exit: the pip package already includes cubins for sm_80 and sm_90.
|
||||
# Only sm_100+ (Blackwell) needs extra cubins downloaded. Skip the expensive
|
||||
# Python status check entirely if no such GPU is present.
|
||||
if COMPUTE_CAPS=$(timeout 10 nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null); then
|
||||
NEEDS_EXTRA_CUBINS=false
|
||||
while IFS= read -r cap; do
|
||||
major="${cap%%.*}"
|
||||
if [ "$major" -ge 10 ] 2>/dev/null; then
|
||||
NEEDS_EXTRA_CUBINS=true
|
||||
break
|
||||
fi
|
||||
done <<< "$COMPUTE_CAPS"
|
||||
if [ "$NEEDS_EXTRA_CUBINS" = false ]; then
|
||||
echo "All GPUs are sm_9x or older (compute caps: $(echo $COMPUTE_CAPS | tr '\n' ' ')), pip cubins sufficient — skipping download"
|
||||
exit 0
|
||||
fi
|
||||
fi
|
||||
|
||||
# Use timeout to prevent hangs when GPUs are in error state (the flashinfer
|
||||
# import can trigger CUDA init which blocks on bad GPUs).
|
||||
CUBIN_STATUS=$(timeout 60 python3 -c "
|
||||
import os
|
||||
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '')
|
||||
from flashinfer.artifacts import get_artifacts_status
|
||||
status = get_artifacts_status()
|
||||
total = len(status)
|
||||
downloaded = sum(1 for _, exists in status if exists)
|
||||
print(f'{downloaded}/{total}')
|
||||
" 2>/dev/null) || CUBIN_STATUS="unknown"
|
||||
|
||||
echo "Flashinfer cubin status: ${CUBIN_STATUS}"
|
||||
|
||||
if echo "$CUBIN_STATUS" | grep -qE '^[0-9]+/[0-9]+$'; then
|
||||
CUBIN_DOWNLOADED="${CUBIN_STATUS%/*}"
|
||||
CUBIN_TOTAL="${CUBIN_STATUS#*/}"
|
||||
if [ "$CUBIN_DOWNLOADED" = "$CUBIN_TOTAL" ] && [ "$CUBIN_TOTAL" != "0" ]; then
|
||||
echo "All flashinfer cubins already present (${CUBIN_STATUS}), skipping download"
|
||||
else
|
||||
echo "Cubins incomplete (${CUBIN_STATUS}), downloading..."
|
||||
if ! timeout 300 env FLASHINFER_LOGGING_LEVEL=warning python3 -m flashinfer --download-cubin; then
|
||||
echo "WARNING: flashinfer cubin download failed or timed out, continuing with existing cubins"
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo "Could not determine cubin status (status check timed out or failed), attempting download..."
|
||||
if ! timeout 300 env FLASHINFER_LOGGING_LEVEL=warning python3 -m flashinfer --download-cubin; then
|
||||
echo "WARNING: flashinfer cubin download failed or timed out, continuing with existing cubins"
|
||||
fi
|
||||
fi
|
||||
69
third_party/sglang/scripts/ci/cuda/ci_download_flashinfer_jit_cache.sh
vendored
Executable file
69
third_party/sglang/scripts/ci/cuda/ci_download_flashinfer_jit_cache.sh
vendored
Executable file
@@ -0,0 +1,69 @@
|
||||
#!/bin/bash
|
||||
# Install flashinfer-jit-cache with caching and retry logic (flashinfer.ai can have transient DNS issues).
|
||||
# The jit-cache wheel is 1.2+ GB, so we skip the download entirely if already installed.
|
||||
#
|
||||
# Required environment (caller must export or set):
|
||||
# UNINSTALL_JIT_CACHE — literal true/false (skip download when false)
|
||||
# FLASHINFER_PYTHON_REQUIRED — e.g. from python/pyproject.toml (flashinfer_python)
|
||||
# CU_VERSION — e.g. cu129
|
||||
# PIP_CMD — e.g. "pip" or "uv pip"
|
||||
# PIP_INSTALL_SUFFIX — extra pip args for this runner
|
||||
set -euxo pipefail
|
||||
|
||||
: "${UNINSTALL_JIT_CACHE:?must be set}"
|
||||
: "${FLASHINFER_PYTHON_REQUIRED:?must be set}"
|
||||
: "${CU_VERSION:?must be set}"
|
||||
: "${PIP_CMD:?must be set}"
|
||||
|
||||
FLASHINFER_JIT_CACHE_INSTALLED=false
|
||||
if [ "$UNINSTALL_JIT_CACHE" = false ]; then
|
||||
FLASHINFER_JIT_CACHE_INSTALLED=true
|
||||
echo "flashinfer-jit-cache already at correct version, skipping download"
|
||||
fi
|
||||
|
||||
if [ "$FLASHINFER_JIT_CACHE_INSTALLED" = false ]; then
|
||||
FLASHINFER_CACHE_DIR="${HOME}/.cache/flashinfer-wheels"
|
||||
mkdir -p "${FLASHINFER_CACHE_DIR}"
|
||||
|
||||
FLASHINFER_WHEEL_PATTERN="flashinfer_jit_cache-${FLASHINFER_PYTHON_REQUIRED}*.whl"
|
||||
CACHED_WHEEL=$(find "${FLASHINFER_CACHE_DIR}" -name "${FLASHINFER_WHEEL_PATTERN}" -type f 2>/dev/null | head -n 1)
|
||||
|
||||
if [ -n "$CACHED_WHEEL" ] && [ -f "$CACHED_WHEEL" ]; then
|
||||
echo "Found cached flashinfer wheel: $CACHED_WHEEL"
|
||||
if $PIP_CMD install "$CACHED_WHEEL" $PIP_INSTALL_SUFFIX; then
|
||||
FLASHINFER_JIT_CACHE_INSTALLED=true
|
||||
echo "Successfully installed flashinfer-jit-cache from cache"
|
||||
else
|
||||
echo "Failed to install from cache, will try downloading..."
|
||||
rm -f "$CACHED_WHEEL"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ "$FLASHINFER_JIT_CACHE_INSTALLED" = false ]; then
|
||||
for i in {1..5}; do
|
||||
# Download wheel to cache directory (use pip directly as uv pip doesn't support download)
|
||||
if timeout 600 pip download "flashinfer-jit-cache==${FLASHINFER_PYTHON_REQUIRED}" \
|
||||
--index-url "https://flashinfer.ai/whl/${CU_VERSION}" \
|
||||
-d "${FLASHINFER_CACHE_DIR}"; then
|
||||
|
||||
CACHED_WHEEL=$(find "${FLASHINFER_CACHE_DIR}" -name "${FLASHINFER_WHEEL_PATTERN}" -type f 2>/dev/null | head -n 1)
|
||||
if [ -n "$CACHED_WHEEL" ] && [ -f "$CACHED_WHEEL" ]; then
|
||||
if $PIP_CMD install "$CACHED_WHEEL" $PIP_INSTALL_SUFFIX; then
|
||||
FLASHINFER_JIT_CACHE_INSTALLED=true
|
||||
echo "Successfully downloaded and installed flashinfer-jit-cache"
|
||||
break
|
||||
fi
|
||||
else
|
||||
echo "Warning: Download succeeded but wheel file not found"
|
||||
fi
|
||||
fi
|
||||
echo "Attempt $i to download flashinfer-jit-cache failed, retrying in 10 seconds..."
|
||||
sleep 10
|
||||
done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ "$FLASHINFER_JIT_CACHE_INSTALLED" = false ]; then
|
||||
echo "ERROR: Failed to install flashinfer-jit-cache after 5 attempts"
|
||||
exit 1
|
||||
fi
|
||||
119
third_party/sglang/scripts/ci/cuda/ci_install_deepep.sh
vendored
Executable file
119
third_party/sglang/scripts/ci/cuda/ci_install_deepep.sh
vendored
Executable file
@@ -0,0 +1,119 @@
|
||||
#!/bin/bash
|
||||
# Install the dependency in CI.
|
||||
set -euxo pipefail
|
||||
|
||||
bash scripts/ci/cuda/ci_install_dependency.sh
|
||||
|
||||
export GDRCOPY_HOME=/usr/src/gdrdrv-2.5.1/
|
||||
export CUDA_HOME=/usr/local/cuda
|
||||
|
||||
GRACE_BLACKWELL=${GRACE_BLACKWELL:-0}
|
||||
# Detect architecture
|
||||
ARCH=$(uname -m)
|
||||
if [ "$ARCH" != "x86_64" ] && [ "$ARCH" != "aarch64" ]; then
|
||||
echo "Unsupported architecture: $ARCH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if python3 -c "import deep_ep" >/dev/null 2>&1; then
|
||||
echo "deep_ep is already installed or importable. Skipping installation."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Install system dependencies
|
||||
# Use fallback logic in case apt fails due to unrelated broken packages on the runner
|
||||
DEEPEP_SYSTEM_DEPS="curl wget git sudo rdma-core infiniband-diags openssh-server perftest libibumad3 libibverbs-dev libibverbs1 ibverbs-providers ibverbs-utils libnl-3-200 libnl-route-3-200 librdmacm1 build-essential cmake"
|
||||
apt-get install -y --no-install-recommends $DEEPEP_SYSTEM_DEPS || {
|
||||
echo "Warning: apt-get install failed, checking if required packages are available..."
|
||||
for pkg in $DEEPEP_SYSTEM_DEPS; do
|
||||
if ! dpkg -l "$pkg" 2>/dev/null | grep -q "^ii"; then
|
||||
echo "ERROR: Required package $pkg is not installed and apt-get failed"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
echo "All required packages are already installed, continuing..."
|
||||
}
|
||||
|
||||
# Install GDRCopy
|
||||
rm -rf /opt/gdrcopy && mkdir -p /opt/gdrcopy
|
||||
cd /opt/gdrcopy
|
||||
git clone https://github.com/NVIDIA/gdrcopy.git .
|
||||
git checkout v2.5.1
|
||||
apt-get update || true # May fail due to unrelated broken packages
|
||||
GDRCOPY_DEPS_1="nvidia-dkms-580"
|
||||
GDRCOPY_DEPS_2="build-essential devscripts debhelper fakeroot pkg-config dkms"
|
||||
GDRCOPY_DEPS_3="check libsubunit0 libsubunit-dev python3-venv"
|
||||
for deps_group in "$GDRCOPY_DEPS_1" "$GDRCOPY_DEPS_2" "$GDRCOPY_DEPS_3"; do
|
||||
apt-get install -y --no-install-recommends $deps_group || {
|
||||
echo "Warning: apt-get install failed for '$deps_group', checking if packages are available..."
|
||||
for pkg in $deps_group; do
|
||||
if ! dpkg -l "$pkg" 2>/dev/null | grep -q "^ii"; then
|
||||
echo "ERROR: Required package $pkg is not installed and apt-get failed"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
echo "All required packages from '$deps_group' are already installed, continuing..."
|
||||
}
|
||||
done
|
||||
cd packages
|
||||
CUDA=/usr/local/cuda ./build-deb-packages.sh
|
||||
dpkg -i gdrdrv-dkms_*.deb
|
||||
dpkg -i libgdrapi_*.deb
|
||||
dpkg -i gdrcopy-tests_*.deb
|
||||
dpkg -i gdrcopy_*.deb
|
||||
|
||||
# Set up library paths based on architecture
|
||||
LIB_PATH="/usr/lib/$ARCH-linux-gnu"
|
||||
if [ ! -e "$LIB_PATH/libmlx5.so" ]; then
|
||||
ln -s $LIB_PATH/libmlx5.so.1 $LIB_PATH/libmlx5.so
|
||||
fi
|
||||
apt-get update || true
|
||||
apt-get install -y --no-install-recommends libfabric-dev || {
|
||||
if ! dpkg -l libfabric-dev 2>/dev/null | grep -q "^ii"; then
|
||||
echo "ERROR: Required package libfabric-dev is not installed and apt-get failed"
|
||||
exit 1
|
||||
fi
|
||||
echo "libfabric-dev is already installed, continuing..."
|
||||
}
|
||||
|
||||
# Install DeepEP
|
||||
DEEPEP_DIR=/root/.cache/deepep
|
||||
rm -rf ${DEEPEP_DIR}
|
||||
if [ "$GRACE_BLACKWELL" = "1" ]; then
|
||||
# We use Tom's DeepEP fork for GB200 for now, which supports fp4 dispatch.
|
||||
GRACE_BLACKWELL_DEEPEP_BRANCH=gb200_blog_part_2
|
||||
git clone https://github.com/fzyzcjy/DeepEP.git ${DEEPEP_DIR} && \
|
||||
pushd ${DEEPEP_DIR} && \
|
||||
git checkout ${GRACE_BLACKWELL_DEEPEP_BRANCH} && \
|
||||
sed -i 's/#define NUM_CPU_TIMEOUT_SECS 100/#define NUM_CPU_TIMEOUT_SECS 1000/' csrc/kernels/configs.cuh && \
|
||||
popd
|
||||
else
|
||||
git clone https://github.com/deepseek-ai/DeepEP.git ${DEEPEP_DIR} && \
|
||||
pushd ${DEEPEP_DIR} && \
|
||||
git checkout 9af0e0d0e74f3577af1979c9b9e1ac2cad0104ee && \
|
||||
popd
|
||||
fi
|
||||
|
||||
cd ${DEEPEP_DIR}
|
||||
if [ "$GRACE_BLACKWELL" = "1" ]; then
|
||||
CUDA_VERSION=$(nvidia-smi | grep "CUDA Version" | head -n1 | awk '{print $9}')
|
||||
if [ "$CUDA_VERSION" = "12.8" ]; then
|
||||
CHOSEN_TORCH_CUDA_ARCH_LIST='10.0'
|
||||
elif awk -v ver="$CUDA_VERSION" 'BEGIN {exit !(ver > 12.8)}'; then
|
||||
# With cuda > 12.8, the compiler supports 10.3, so we should use
|
||||
# CHOSEN_TORCH_CUDA_ARCH_LIST='10.0;10.3'
|
||||
#
|
||||
# However, our CI machine has a weird setup and nvidia-smi reports wrong CUDA version in the container.
|
||||
# The container is actually cuda 12.8, but nvidia-smi reports 13.0, leading to compilation errors. so we
|
||||
# drop 10.3.
|
||||
CHOSEN_TORCH_CUDA_ARCH_LIST='10.0'
|
||||
else
|
||||
echo "Unsupported CUDA version for Grace Blackwell: $CUDA_VERSION" && exit 1
|
||||
fi && \
|
||||
if [ "${CUDA_VERSION%%.*}" = "13" ]; then \
|
||||
sed -i "/^ include_dirs = \['csrc\/'\]/a\ include_dirs.append('${CUDA_HOME}/include/cccl')" setup.py; \
|
||||
fi
|
||||
TORCH_CUDA_ARCH_LIST="${CHOSEN_TORCH_CUDA_ARCH_LIST}" pip install --no-build-isolation .
|
||||
else
|
||||
python3 setup.py install
|
||||
fi
|
||||
384
third_party/sglang/scripts/ci/cuda/ci_install_dependency.sh
vendored
Executable file
384
third_party/sglang/scripts/ci/cuda/ci_install_dependency.sh
vendored
Executable file
@@ -0,0 +1,384 @@
|
||||
#!/bin/bash
|
||||
# Install the dependency in CI.
|
||||
#
|
||||
# Structure (see section banners below):
|
||||
# - Configuration & timing
|
||||
# - Host / runner detection (arch, Blackwell, pip vs uv)
|
||||
# - Kill existing processes
|
||||
# - Install apt packages
|
||||
# - Python package site hygiene & install protoc
|
||||
# - Pip / uv toolchain & stale package cleanup
|
||||
# - Uninstall Flashinfer
|
||||
# - Install main package
|
||||
# - Install sglang-kernel
|
||||
# - Install sglang-router
|
||||
# - Download flashinfer artifacts
|
||||
# - Install extra dependency
|
||||
# - Fix other dependencies
|
||||
# - Prepare runner
|
||||
# - Verify imports
|
||||
set -euxo pipefail
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Configuration & timing
|
||||
# ------------------------------------------------------------------------------
|
||||
# Set up environment variables
|
||||
CU_VERSION="cu129"
|
||||
|
||||
# Nvidia package versions we override (torch pins older versions).
|
||||
# Used both as pip constraints during install and for post-install verification.
|
||||
NVIDIA_CUDNN_VERSION="9.16.0.29"
|
||||
NVIDIA_NVSHMEM_VERSION="3.4.5"
|
||||
OPTIONAL_DEPS="${1:-}"
|
||||
|
||||
SECONDS=0
|
||||
_CI_MARK_PREV=${SECONDS}
|
||||
|
||||
mark_step_done() {
|
||||
local label=$1
|
||||
local now=${SECONDS}
|
||||
local step=$((now - _CI_MARK_PREV))
|
||||
printf '\n[STEP DONE] %s, step: %ss, total: %ss, date: %s\n' \
|
||||
"${label}" "${step}" "${now}" "$(date -u '+%Y-%m-%dT%H:%M:%SZ')"
|
||||
_CI_MARK_PREV=${now}
|
||||
}
|
||||
|
||||
mark_step_done "Configuration"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Host / runner detection (CPU arch, Blackwell, USE_UV)
|
||||
# ------------------------------------------------------------------------------
|
||||
# Detect CPU architecture (x86_64 or aarch64)
|
||||
ARCH=$(uname -m)
|
||||
echo "Detected architecture: ${ARCH}"
|
||||
|
||||
# Detect GPU architecture (blackwell or not)
|
||||
if [ "${IS_BLACKWELL+set}" = set ]; then
|
||||
case "$IS_BLACKWELL" in 1 | true | yes) IS_BLACKWELL=1 ;; *) IS_BLACKWELL=0 ;; esac
|
||||
echo "IS_BLACKWELL=${IS_BLACKWELL} (manually set via environment)"
|
||||
else
|
||||
IS_BLACKWELL=0
|
||||
if command -v nvidia-smi >/dev/null 2>&1; then
|
||||
while IFS= read -r cap; do
|
||||
major="${cap%%.*}"
|
||||
if [ "${major:-0}" -ge 10 ] 2>/dev/null; then
|
||||
IS_BLACKWELL=1
|
||||
break
|
||||
fi
|
||||
done <<< "$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null || true)"
|
||||
fi
|
||||
echo "IS_BLACKWELL=${IS_BLACKWELL} (auto-detected via nvidia-smi)"
|
||||
fi
|
||||
|
||||
# Whether to use pip or uv to install dependencies
|
||||
if [ "${USE_UV+set}" != set ]; then
|
||||
if [ "$IS_BLACKWELL" = "1" ]; then
|
||||
# Our current b200 runners have some issues with uv, so we default to pip
|
||||
# It is a runner specific issue, not a GPU architecture issue.
|
||||
USE_UV=false
|
||||
else
|
||||
USE_UV=true
|
||||
fi
|
||||
fi
|
||||
case "$(printf '%s' "$USE_UV" | tr '[:upper:]' '[:lower:]')" in 1 | true | yes) USE_UV=1 ;; *) USE_UV=0 ;; esac
|
||||
echo "USE_UV=${USE_UV}"
|
||||
|
||||
mark_step_done "Host / runner detection"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Kill existing processes
|
||||
# ------------------------------------------------------------------------------
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
|
||||
python3 "${REPO_ROOT}/python/sglang/cli/killall.py"
|
||||
KILLALL_EXIT=$?
|
||||
echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-}"
|
||||
|
||||
if [ $KILLALL_EXIT -ne 0 ]; then
|
||||
echo "ERROR: killall.py detected uncleanable GPU memory. Aborting CI."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mark_step_done "Kill existing processes"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Install apt packages
|
||||
# ------------------------------------------------------------------------------
|
||||
# Install apt packages (including python3/pip which may be missing on some runners)
|
||||
# Use --no-install-recommends and ignore errors from unrelated broken packages on the runner
|
||||
# The NVIDIA driver packages may have broken dependencies that are unrelated to these packages
|
||||
# Run apt-get update first to refresh package index (stale index causes 404 on security.ubuntu.com)
|
||||
apt-get update || true
|
||||
CI_APT_PACKAGES=(
|
||||
python3 python3-pip python3-venv python3-dev git libnuma-dev libssl-dev pkg-config
|
||||
libibverbs-dev libibverbs1 ibverbs-providers ibverbs-utils
|
||||
ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev
|
||||
)
|
||||
apt-get install -y --no-install-recommends "${CI_APT_PACKAGES[@]}" || {
|
||||
echo "Warning: apt-get install failed, checking if required packages are available..."
|
||||
for pkg in "${CI_APT_PACKAGES[@]}"; do
|
||||
if ! dpkg -l "$pkg" 2>/dev/null | grep -q "^ii"; then
|
||||
echo "ERROR: Required package $pkg is not installed and apt-get failed"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
echo "All required packages are already installed, continuing..."
|
||||
}
|
||||
|
||||
mark_step_done "Install apt packages"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Python package site hygiene & install protoc
|
||||
# ------------------------------------------------------------------------------
|
||||
# Clear torch compilation cache
|
||||
python3 -c 'import os, shutil, tempfile, getpass; cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR") or os.path.join(tempfile.gettempdir(), "torchinductor_" + getpass.getuser()); shutil.rmtree(cache_dir, ignore_errors=True)'
|
||||
|
||||
# Remove broken dist-info directories (missing METADATA per PEP 376)
|
||||
SITE_PACKAGES=$(python3 -c "import site; print(site.getsitepackages()[0])")
|
||||
if [ -d "$SITE_PACKAGES" ]; then
|
||||
{ set +x; } 2>/dev/null
|
||||
find "$SITE_PACKAGES" -maxdepth 1 -name "*.dist-info" -type d | while read -r d; do
|
||||
if [ ! -f "$d/METADATA" ]; then
|
||||
echo "Removing broken dist-info: $d"
|
||||
rm -rf "$d"
|
||||
fi
|
||||
done
|
||||
set -x
|
||||
fi
|
||||
|
||||
# Install protoc
|
||||
bash "${SCRIPT_DIR}/../utils/install_protoc.sh"
|
||||
|
||||
mark_step_done "Python package site hygiene & install protoc"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Pip / uv toolchain & stale package cleanup
|
||||
# ------------------------------------------------------------------------------
|
||||
# Install pip and uv (use python3 -m pip for robustness since some runners only have pip3)
|
||||
python3 -m pip install --upgrade pip
|
||||
|
||||
if [ "$USE_UV" = "0" ]; then
|
||||
PIP_CMD="pip"
|
||||
PIP_INSTALL_SUFFIX="--break-system-packages"
|
||||
PIP_UNINSTALL_CMD="pip uninstall -y"
|
||||
PIP_UNINSTALL_SUFFIX="--break-system-packages"
|
||||
else
|
||||
pip install uv
|
||||
export UV_SYSTEM_PYTHON=true
|
||||
|
||||
PIP_CMD="uv pip"
|
||||
PIP_INSTALL_SUFFIX="--index-strategy unsafe-best-match --prerelease allow"
|
||||
PIP_UNINSTALL_CMD="uv pip uninstall"
|
||||
PIP_UNINSTALL_SUFFIX=""
|
||||
fi
|
||||
|
||||
# Clean up existing installations
|
||||
$PIP_UNINSTALL_CMD sgl-kernel sglang-kernel sglang sgl-fa4 flash-attn-4 $PIP_UNINSTALL_SUFFIX || true
|
||||
|
||||
mark_step_done "Pip / uv toolchain & stale package cleanup"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Uninstall Flashinfer
|
||||
# ------------------------------------------------------------------------------
|
||||
# Keep flashinfer packages installed if version matches to avoid re-downloading:
|
||||
# - flashinfer-cubin: 150+ MB, plus extra cubins from ci_download_flashinfer_cubin.sh
|
||||
# - flashinfer-jit-cache: 1.2+ GB, by far the largest download in CI
|
||||
FLASHINFER_PYTHON_REQUIRED=$(grep -Po -m1 '(?<=flashinfer_python==)[0-9A-Za-z\.\-]+' python/pyproject.toml || echo "")
|
||||
FLASHINFER_CUBIN_REQUIRED=$(grep -Po -m1 '(?<=flashinfer_cubin==)[0-9A-Za-z\.\-]+' python/pyproject.toml || echo "")
|
||||
FLASHINFER_CUBIN_INSTALLED=$(pip show flashinfer-cubin 2>/dev/null | grep "^Version:" | awk '{print $2}' || echo "")
|
||||
FLASHINFER_JIT_INSTALLED=$(pip show flashinfer-jit-cache 2>/dev/null | grep "^Version:" | awk '{print $2}' | sed 's/+.*//' || echo "")
|
||||
|
||||
UNINSTALL_CUBIN=true
|
||||
UNINSTALL_JIT_CACHE=true
|
||||
|
||||
if [ "$FLASHINFER_CUBIN_INSTALLED" = "$FLASHINFER_CUBIN_REQUIRED" ] && [ -n "$FLASHINFER_CUBIN_REQUIRED" ]; then
|
||||
echo "flashinfer-cubin==${FLASHINFER_CUBIN_REQUIRED} already installed, keeping it"
|
||||
UNINSTALL_CUBIN=false
|
||||
else
|
||||
echo "flashinfer-cubin version mismatch (installed: ${FLASHINFER_CUBIN_INSTALLED:-none}, required: ${FLASHINFER_CUBIN_REQUIRED}), reinstalling"
|
||||
fi
|
||||
|
||||
if [ "$FLASHINFER_JIT_INSTALLED" = "$FLASHINFER_PYTHON_REQUIRED" ] && [ -n "$FLASHINFER_PYTHON_REQUIRED" ]; then
|
||||
echo "flashinfer-jit-cache==${FLASHINFER_PYTHON_REQUIRED} already installed, keeping it"
|
||||
UNINSTALL_JIT_CACHE=false
|
||||
else
|
||||
echo "flashinfer-jit-cache version mismatch (installed: ${FLASHINFER_JIT_INSTALLED:-none}, required: ${FLASHINFER_PYTHON_REQUIRED}), will reinstall"
|
||||
fi
|
||||
|
||||
# Build uninstall list based on what needs updating
|
||||
FLASHINFER_UNINSTALL="flashinfer-python"
|
||||
[ "$UNINSTALL_CUBIN" = true ] && FLASHINFER_UNINSTALL="$FLASHINFER_UNINSTALL flashinfer-cubin"
|
||||
[ "$UNINSTALL_JIT_CACHE" = true ] && FLASHINFER_UNINSTALL="$FLASHINFER_UNINSTALL flashinfer-jit-cache"
|
||||
$PIP_UNINSTALL_CMD $FLASHINFER_UNINSTALL $PIP_UNINSTALL_SUFFIX || true
|
||||
$PIP_UNINSTALL_CMD opencv-python opencv-python-headless $PIP_UNINSTALL_SUFFIX || true
|
||||
|
||||
mark_step_done "Uninstall Flashinfer"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Install main package
|
||||
# ------------------------------------------------------------------------------
|
||||
# Install the main package
|
||||
EXTRAS="dev,runai,tracing"
|
||||
if [ -n "$OPTIONAL_DEPS" ]; then
|
||||
EXTRAS="dev,runai,tracing,${OPTIONAL_DEPS}"
|
||||
fi
|
||||
echo "Installing python extras: [${EXTRAS}]"
|
||||
source "$(dirname "$0")/cache_nvidia_wheels.sh"
|
||||
$PIP_CMD install -e "python[${EXTRAS}]" --extra-index-url https://download.pytorch.org/whl/${CU_VERSION} $PIP_INSTALL_SUFFIX
|
||||
|
||||
mark_step_done "Install main package"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Install sglang-kernel
|
||||
# ------------------------------------------------------------------------------
|
||||
# Install sgl-kernel
|
||||
SGL_KERNEL_VERSION_FROM_KERNEL=$(grep -Po '(?<=^version = ")[^"]*' sgl-kernel/pyproject.toml)
|
||||
SGL_KERNEL_VERSION_FROM_SRT=$(grep -Po -m1 '(?<=sglang-kernel==)[0-9A-Za-z\.\-]+' python/pyproject.toml)
|
||||
echo "SGL_KERNEL_VERSION_FROM_KERNEL=${SGL_KERNEL_VERSION_FROM_KERNEL} SGL_KERNEL_VERSION_FROM_SRT=${SGL_KERNEL_VERSION_FROM_SRT}"
|
||||
|
||||
if [ "${CUSTOM_BUILD_SGL_KERNEL:-}" = "true" ] && [ -d "sgl-kernel/dist" ]; then
|
||||
ls -alh sgl-kernel/dist
|
||||
# Determine wheel architecture
|
||||
if [ "$ARCH" = "aarch64" ] || [ "$ARCH" = "arm64" ]; then
|
||||
WHEEL_ARCH="aarch64"
|
||||
else
|
||||
WHEEL_ARCH="x86_64"
|
||||
fi
|
||||
$PIP_CMD install sgl-kernel/dist/sglang_kernel-${SGL_KERNEL_VERSION_FROM_KERNEL}-cp310-abi3-manylinux2014_${WHEEL_ARCH}.whl --force-reinstall $PIP_INSTALL_SUFFIX
|
||||
elif [ "${CUSTOM_BUILD_SGL_KERNEL:-}" = "true" ] && [ ! -d "sgl-kernel/dist" ]; then
|
||||
# CUSTOM_BUILD_SGL_KERNEL was set but artifacts not available (e.g., stage rerun without wheel build)
|
||||
# Fail instead of falling back to PyPI - we need to test the built kernel, not PyPI version
|
||||
echo "ERROR: CUSTOM_BUILD_SGL_KERNEL=true but sgl-kernel/dist not found."
|
||||
echo "This usually happens when rerunning a stage without the sgl-kernel-build-wheels job."
|
||||
echo "Please re-run the full workflow using /tag-and-rerun-ci to rebuild the kernel."
|
||||
exit 1
|
||||
else
|
||||
# On Blackwell machines, skip reinstall if correct version already installed to avoid race conditions
|
||||
if [ "$IS_BLACKWELL" = "1" ]; then
|
||||
INSTALLED_SGL_KERNEL=$(pip show sglang-kernel 2>/dev/null | grep "^Version:" | awk '{print $2}' || echo "")
|
||||
if [ "$INSTALLED_SGL_KERNEL" = "$SGL_KERNEL_VERSION_FROM_SRT" ]; then
|
||||
echo "sglang-kernel==${SGL_KERNEL_VERSION_FROM_SRT} already installed, skipping reinstall"
|
||||
else
|
||||
echo "Installing sglang-kernel==${SGL_KERNEL_VERSION_FROM_SRT} (current: ${INSTALLED_SGL_KERNEL:-none})"
|
||||
$PIP_CMD install sglang-kernel==${SGL_KERNEL_VERSION_FROM_SRT} $PIP_INSTALL_SUFFIX
|
||||
fi
|
||||
else
|
||||
$PIP_CMD install sglang-kernel==${SGL_KERNEL_VERSION_FROM_SRT} --force-reinstall $PIP_INSTALL_SUFFIX
|
||||
fi
|
||||
fi
|
||||
|
||||
mark_step_done "Install sglang-kernel"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Install sglang-router
|
||||
# ------------------------------------------------------------------------------
|
||||
# Install router for pd-disagg test
|
||||
$PIP_CMD install sglang-router $PIP_INSTALL_SUFFIX
|
||||
|
||||
# Show current packages
|
||||
$PIP_CMD list
|
||||
|
||||
mark_step_done "Install sglang-router"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Download flashinfer artifacts
|
||||
# ------------------------------------------------------------------------------
|
||||
# Download flashinfer jit cache
|
||||
UNINSTALL_JIT_CACHE="$UNINSTALL_JIT_CACHE" \
|
||||
FLASHINFER_PYTHON_REQUIRED="$FLASHINFER_PYTHON_REQUIRED" \
|
||||
CU_VERSION="$CU_VERSION" \
|
||||
PIP_CMD="$PIP_CMD" \
|
||||
PIP_INSTALL_SUFFIX="$PIP_INSTALL_SUFFIX" \
|
||||
bash "${SCRIPT_DIR}/ci_download_flashinfer_jit_cache.sh"
|
||||
# Download flashinfer cubins
|
||||
bash "${SCRIPT_DIR}/ci_download_flashinfer_cubin.sh"
|
||||
|
||||
mark_step_done "Download flashinfer artifacts"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Install extra dependency
|
||||
# ------------------------------------------------------------------------------
|
||||
# Install other python dependencies
|
||||
if [ "$CU_VERSION" = "cu130" ]; then
|
||||
NVRTC_SPEC="nvidia-cuda-nvrtc"
|
||||
else
|
||||
NVRTC_SPEC="nvidia-cuda-nvrtc-cu12"
|
||||
fi
|
||||
$PIP_CMD install mooncake-transfer-engine==0.3.10.post1 "${NVRTC_SPEC}" py-spy scipy huggingface_hub[hf_xet] pytest $PIP_INSTALL_SUFFIX
|
||||
|
||||
# Install other test dependencies
|
||||
if [ "$IS_BLACKWELL" != "1" ]; then
|
||||
# For lmms_evals evaluating MMMU
|
||||
git clone --branch v0.5 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
|
||||
$PIP_CMD install -e lmms-eval/ $PIP_INSTALL_SUFFIX
|
||||
fi
|
||||
$PIP_CMD uninstall xformers || true
|
||||
|
||||
mark_step_done "Install extra dependency"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Fix other dependencies
|
||||
# ------------------------------------------------------------------------------
|
||||
# Fix CUDA version mismatch between torch and torchaudio.
|
||||
# PyPI's torch 2.9.1 bundles cu128 but torchaudio from pytorch.org/cu129 uses cu129.
|
||||
# This mismatch causes torchaudio's C extension to fail loading, producing:
|
||||
# "partially initialized module 'torchaudio' has no attribute 'lib'"
|
||||
# We cannot replace torch with cu129 (breaks sgl_kernel ABI), so instead we reinstall
|
||||
# torchaudio/torchvision from an index matching torch's CUDA version.
|
||||
TORCH_CUDA_VER=$(python3 -c "import torch; v=torch.version.cuda; parts=v.split('.'); print(f'cu{parts[0]}{parts[1]}')")
|
||||
echo "Detected torch CUDA version: ${TORCH_CUDA_VER}"
|
||||
if [ "${TORCH_CUDA_VER}" != "${CU_VERSION}" ]; then
|
||||
# Pin versions to match what was installed by pyproject.toml (strip +cuXYZ suffix)
|
||||
TORCHAUDIO_VER=$(pip show torchaudio 2>/dev/null | grep "^Version:" | awk '{print $2}' | sed 's/+.*//')
|
||||
TORCHVISION_VER=$(pip show torchvision 2>/dev/null | grep "^Version:" | awk '{print $2}' | sed 's/+.*//')
|
||||
echo "Reinstalling torchaudio==${TORCHAUDIO_VER} torchvision==${TORCHVISION_VER} from ${TORCH_CUDA_VER} index to match torch..."
|
||||
$PIP_CMD install "torchaudio==${TORCHAUDIO_VER}" "torchvision==${TORCHVISION_VER}" --index-url "https://download.pytorch.org/whl/${TORCH_CUDA_VER}" --force-reinstall --no-deps $PIP_INSTALL_SUFFIX
|
||||
fi
|
||||
|
||||
# Fix dependencies: DeepEP depends on nvshmem 3.4.5 — skip reinstall when already correct (avoids pip races / wasted work)
|
||||
INSTALLED_NVSHMEM=$(pip show nvidia-nvshmem-cu12 2>/dev/null | grep "^Version:" | awk '{print $2}' || echo "")
|
||||
if [ "$INSTALLED_NVSHMEM" = "$NVIDIA_NVSHMEM_VERSION" ]; then
|
||||
echo "nvidia-nvshmem-cu12==${NVIDIA_NVSHMEM_VERSION} already installed, skipping reinstall"
|
||||
else
|
||||
$PIP_CMD install nvidia-nvshmem-cu12==${NVIDIA_NVSHMEM_VERSION} $PIP_INSTALL_SUFFIX
|
||||
fi
|
||||
|
||||
# Fix dependencies: Cudnn with version less than 9.16.0.29 will cause performance regression on Conv3D kernel
|
||||
INSTALLED_CUDNN=$(pip show nvidia-cudnn-cu12 2>/dev/null | grep "^Version:" | awk '{print $2}' || echo "")
|
||||
if [ "$INSTALLED_CUDNN" = "$NVIDIA_CUDNN_VERSION" ]; then
|
||||
echo "nvidia-cudnn-cu12==${NVIDIA_CUDNN_VERSION} already installed, skipping reinstall"
|
||||
else
|
||||
$PIP_CMD install nvidia-cudnn-cu12==${NVIDIA_CUDNN_VERSION} $PIP_INSTALL_SUFFIX
|
||||
fi
|
||||
|
||||
mark_step_done "Fix other dependencies"
|
||||
|
||||
# Force reinstall nvidia-cutlass-dsl to ensure the .pth file exists.
|
||||
# The Docker image ships nvidia-cutlass-dsl-libs-base 4.3.5; upgrading to 4.4.2
|
||||
# can delete the .pth file without reliably recreating it (pip race condition).
|
||||
$PIP_CMD install "nvidia-cutlass-dsl>=4.4.1" "nvidia-cutlass-dsl-libs-base>=4.4.1" --no-deps --force-reinstall $PIP_INSTALL_SUFFIX || true
|
||||
|
||||
|
||||
# Install human-eval
|
||||
pip install "setuptools==70.0.0"
|
||||
git clone https://github.com/merrymercy/human-eval.git
|
||||
cd human-eval
|
||||
pip install -e . --no-build-isolation
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Prepare runner
|
||||
# ------------------------------------------------------------------------------
|
||||
# Prepare the CI runner (cleanup HuggingFace cache, etc.)
|
||||
bash "${SCRIPT_DIR}/prepare_runner.sh"
|
||||
|
||||
mark_step_done "Prepare runner"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Verify imports
|
||||
# ------------------------------------------------------------------------------
|
||||
# Show current packages
|
||||
$PIP_CMD list
|
||||
python3 -c "import torch; print(torch.version.cuda)"
|
||||
python3 -c "import cutlass; import cutlass.cute;"
|
||||
|
||||
mark_step_done "Verify imports"
|
||||
24
third_party/sglang/scripts/ci/cuda/ci_install_gateway_dependencies.sh
vendored
Executable file
24
third_party/sglang/scripts/ci/cuda/ci_install_gateway_dependencies.sh
vendored
Executable file
@@ -0,0 +1,24 @@
|
||||
#!/bin/bash
|
||||
set -euxo pipefail
|
||||
|
||||
# Check if sudo is available
|
||||
if command -v sudo >/dev/null 2>&1; then
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libssl-dev pkg-config protobuf-compiler redis-server
|
||||
else
|
||||
apt-get update
|
||||
apt-get install -y libssl-dev pkg-config protobuf-compiler redis-server
|
||||
fi
|
||||
|
||||
# Install rustup (Rust installer and version manager)
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain 1.90
|
||||
|
||||
|
||||
# Follow the installation prompts, then reload your shell
|
||||
. "$HOME/.cargo/env"
|
||||
source $HOME/.cargo/env
|
||||
|
||||
# Verify installation
|
||||
rustc --version
|
||||
cargo --version
|
||||
protoc --version
|
||||
106
third_party/sglang/scripts/ci/cuda/ci_start_disaggregation_servers.sh
vendored
Executable file
106
third_party/sglang/scripts/ci/cuda/ci_start_disaggregation_servers.sh
vendored
Executable file
@@ -0,0 +1,106 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
# Optional: set DISAGG_READY_FILE to a filepath; when all servers are healthy, the script will
|
||||
# create this file as a readiness signal (useful for CI to proceed to next steps).
|
||||
DISAGG_READY_FILE="${DISAGG_READY_FILE:-}"
|
||||
|
||||
MODEL_PATH="/raid/models/meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
# Function to find the first available active IB device
|
||||
find_active_ib_device() {
|
||||
for device in mlx5_{0..11}; do
|
||||
if ibv_devinfo $device >/dev/null 2>&1; then
|
||||
state=$(ibv_devinfo $device | grep "state:" | head -1 | awk '{print $2}')
|
||||
if [[ "$state" == "PORT_ACTIVE" ]]; then
|
||||
echo "$device"
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
done
|
||||
echo "No active IB device found" >&2
|
||||
return 1
|
||||
}
|
||||
|
||||
# Get the first available active IB device
|
||||
DEVICE=$(find_active_ib_device)
|
||||
echo "Using IB device: $DEVICE"
|
||||
|
||||
# Launch prefill servers on GPU 0–3
|
||||
for i in {0..3}; do
|
||||
PORT=$((30001 + i))
|
||||
BOOTSTRAP_PORT=$((9001 + i))
|
||||
HOST="127.0.0.$((i + 1))"
|
||||
echo "Launching PREFILL server on GPU $i at $HOST:$PORT (bootstrap: $BOOTSTRAP_PORT)"
|
||||
CUDA_VISIBLE_DEVICES=$i \
|
||||
python3 -m sglang.launch_server \
|
||||
--model-path "$MODEL_PATH" \
|
||||
--disaggregation-mode prefill \
|
||||
--host "$HOST" \
|
||||
--port "$PORT" \
|
||||
--disaggregation-ib-device "$DEVICE" \
|
||||
--disaggregation-bootstrap-port "$BOOTSTRAP_PORT" &
|
||||
done
|
||||
|
||||
# Launch decode servers on GPU 4–7
|
||||
for i in {4..7}; do
|
||||
PORT=$((30001 + i))
|
||||
HOST="127.0.0.$((i + 1))"
|
||||
echo "Launching DECODE server on GPU $i at $HOST:$PORT"
|
||||
CUDA_VISIBLE_DEVICES=$i \
|
||||
python3 -m sglang.launch_server \
|
||||
--model-path "$MODEL_PATH" \
|
||||
--disaggregation-mode decode \
|
||||
--host "$HOST" \
|
||||
--port "$PORT" \
|
||||
--disaggregation-ib-device "$DEVICE" \
|
||||
--base-gpu-id 0 &
|
||||
done
|
||||
|
||||
# Wait for disaggregation servers to initialize
|
||||
echo "Waiting for disaggregation servers to initialize..."
|
||||
|
||||
# Health check with 5-minute timeout
|
||||
TIMEOUT=300
|
||||
START_TIME=$(date +%s)
|
||||
|
||||
echo "Checking health of all 8 servers..."
|
||||
while true; do
|
||||
CURRENT_TIME=$(date +%s)
|
||||
ELAPSED=$((CURRENT_TIME - START_TIME))
|
||||
|
||||
if [ $ELAPSED -ge $TIMEOUT ]; then
|
||||
echo "❌ Timeout: Servers did not become healthy within 5 minutes"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
HEALTHY_COUNT=0
|
||||
# Check all 8 servers (127.0.0.1-8:30001-30008)
|
||||
for i in {1..8}; do
|
||||
if curl -s -f "http://127.0.0.$i:$((30000 + i))/health" >/dev/null 2>&1; then
|
||||
HEALTHY_COUNT=$((HEALTHY_COUNT + 1))
|
||||
fi
|
||||
done
|
||||
|
||||
echo "Healthy servers: $HEALTHY_COUNT/8 (elapsed: ${ELAPSED}s)"
|
||||
|
||||
if [ $HEALTHY_COUNT -eq 8 ]; then
|
||||
echo "✅ All 8 servers are healthy!"
|
||||
# Emit readiness signal file if requested
|
||||
if [ -n "$DISAGG_READY_FILE" ]; then
|
||||
echo "Creating readiness flag: $DISAGG_READY_FILE"
|
||||
# Ensure parent dir exists; ignore errors
|
||||
mkdir -p "$(dirname "$DISAGG_READY_FILE")" 2>/dev/null || true
|
||||
touch "$DISAGG_READY_FILE"
|
||||
fi
|
||||
break
|
||||
else
|
||||
sleep 10 # Wait 10 seconds before next check
|
||||
fi
|
||||
done
|
||||
|
||||
# Don't launch router here - just keep servers running
|
||||
echo "✅ All disaggregation servers are ready and waiting for router connections"
|
||||
|
||||
# Keep the script running
|
||||
wait
|
||||
19
third_party/sglang/scripts/ci/cuda/prepare_runner.sh
vendored
Executable file
19
third_party/sglang/scripts/ci/cuda/prepare_runner.sh
vendored
Executable file
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
# Prepare the CI runner by cleaning up stale HuggingFace cache artifacts and validating models
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
|
||||
echo "Preparing CI runner..."
|
||||
echo ""
|
||||
|
||||
# Clean up stale HuggingFace cache artifacts from previous failed downloads
|
||||
python3 "${SCRIPT_DIR}/../utils/cleanup_hf_cache.py"
|
||||
echo ""
|
||||
|
||||
# Pre-validate cached models and write markers for offline mode
|
||||
# This allows tests to run with HF_HUB_OFFLINE=1 for models that are fully cached
|
||||
python3 "${SCRIPT_DIR}/../utils/prevalidate_cached_models.py"
|
||||
echo ""
|
||||
|
||||
echo "CI runner preparation complete!"
|
||||
399
third_party/sglang/scripts/ci/cuda/warmup_deep_gemm.py
vendored
Normal file
399
third_party/sglang/scripts/ci/cuda/warmup_deep_gemm.py
vendored
Normal file
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
Lightweight DeepGEMM JIT compilation warmup without loading model weights.
|
||||
|
||||
Reads model config.json from HF cache to derive kernel shapes, then compiles
|
||||
DeepGEMM kernels directly. This avoids the expensive model weight loading step
|
||||
that the full `sglang.compile_deep_gemm` requires.
|
||||
|
||||
Supports DeepSeek V2/V3 family models. Falls back to `sglang.compile_deep_gemm`
|
||||
for unsupported architectures.
|
||||
|
||||
Usage:
|
||||
python3 scripts/ci/cuda/warmup_deep_gemm.py \
|
||||
deepseek-ai/DeepSeek-V3-0324:8 \
|
||||
deepseek-ai/DeepSeek-V3.2-Exp:8
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
|
||||
# Configure DeepGEMM cache before importing deep_gemm
|
||||
os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
|
||||
"SGLANG_DG_CACHE_DIR",
|
||||
os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm"),
|
||||
)
|
||||
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
|
||||
|
||||
BLOCK_SIZE = 128
|
||||
|
||||
|
||||
def get_config_json(model_name):
|
||||
"""Load config.json for a cached model from HF cache."""
|
||||
cache_dir = os.environ.get(
|
||||
"HF_HOME", os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
|
||||
)
|
||||
hub_dir = os.path.join(cache_dir, "hub")
|
||||
safe_name = "models--" + model_name.replace("/", "--")
|
||||
snapshots_dir = os.path.join(hub_dir, safe_name, "snapshots")
|
||||
|
||||
if not os.path.isdir(snapshots_dir):
|
||||
return None
|
||||
|
||||
snapshots = sorted(
|
||||
Path(snapshots_dir).iterdir(), key=lambda p: p.stat().st_mtime, reverse=True
|
||||
)
|
||||
for snapshot in snapshots:
|
||||
config_path = snapshot / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
return json.load(f)
|
||||
return None
|
||||
|
||||
|
||||
def is_deepseek_v2v3(config):
|
||||
"""Check if a model is from the DeepSeek V2/V3 family."""
|
||||
architectures = config.get("architectures", [])
|
||||
model_type = config.get("model_type", "")
|
||||
return any(
|
||||
"DeepseekV2" in a or "DeepseekV3" in a for a in architectures
|
||||
) or model_type in ("deepseek_v2", "deepseek_v3")
|
||||
|
||||
|
||||
def compute_deepseek_v2v3_shapes(config, tp):
|
||||
"""Compute all DeepGEMM (kernel_type, N, K, num_groups) for DeepSeek V2/V3.
|
||||
|
||||
Shape derivation based on:
|
||||
- MoE: python/sglang/srt/layers/moe/fused_moe_triton/layer.py
|
||||
- MLA: python/sglang/srt/models/deepseek_v2.py
|
||||
- FP8: python/sglang/srt/layers/quantization/fp8_kernel.py
|
||||
"""
|
||||
shapes = []
|
||||
|
||||
hidden_size = config["hidden_size"]
|
||||
num_attention_heads = config.get("num_attention_heads", 128)
|
||||
kv_lora_rank = config.get("kv_lora_rank", 512)
|
||||
qk_nope_head_dim = config.get("qk_nope_head_dim", 128)
|
||||
v_head_dim = config.get("v_head_dim", 128)
|
||||
n_routed_experts = config.get("n_routed_experts", 0)
|
||||
n_shared_experts = config.get("n_shared_experts", 0)
|
||||
moe_intermediate_size = config.get("moe_intermediate_size", 0)
|
||||
|
||||
num_local_heads = num_attention_heads // tp
|
||||
# Shared expert fusion is enabled by default (disable_shared_experts_fusion=False)
|
||||
# so the FusedMoE weight tensor includes shared experts
|
||||
num_local_experts = n_routed_experts + n_shared_experts
|
||||
|
||||
# --- MoE expert GEMM shapes ---
|
||||
# FusedMoE shards intermediate_size across TP ranks (column parallel for gate/up,
|
||||
# row parallel for down). All experts are replicated on each TP rank.
|
||||
if n_routed_experts > 0 and moe_intermediate_size > 0:
|
||||
moe_inter_per_tp = moe_intermediate_size // tp
|
||||
|
||||
# Gate-Up projection: (tokens, hidden_size) @ (experts, 2*inter_per_tp, hidden_size)^T
|
||||
# Both masked and contiguous paths are used at runtime
|
||||
shapes.append(("MASKED", moe_inter_per_tp * 2, hidden_size, num_local_experts))
|
||||
shapes.append(("CONTIG", moe_inter_per_tp * 2, hidden_size, num_local_experts))
|
||||
|
||||
# Down projection: (tokens, inter_per_tp) @ (experts, hidden_size, inter_per_tp)^T
|
||||
shapes.append(("MASKED", hidden_size, moe_inter_per_tp, num_local_experts))
|
||||
shapes.append(("CONTIG", hidden_size, moe_inter_per_tp, num_local_experts))
|
||||
|
||||
# --- MLA attention GEMM shapes (masked grouped GEMM) ---
|
||||
if kv_lora_rank > 0 and num_local_heads > 0:
|
||||
# Q_nope -> compressed K: (heads, m, qk_nope_head_dim) @ (heads, kv_lora_rank, qk_nope_head_dim)^T
|
||||
shapes.append(("MASKED", kv_lora_rank, qk_nope_head_dim, num_local_heads))
|
||||
|
||||
# Attention output -> V: (heads, m, kv_lora_rank) @ (heads, v_head_dim, kv_lora_rank)^T
|
||||
shapes.append(("MASKED", v_head_dim, kv_lora_rank, num_local_heads))
|
||||
|
||||
# --- kv_b_proj (non-grouped GEMM via FP8 kernel) ---
|
||||
# ColumnParallelLinear(kv_lora_rank, num_heads * (qk_nope + v_head_dim))
|
||||
# Per TP rank: N = num_local_heads * (qk_nope_head_dim + v_head_dim)
|
||||
if kv_lora_rank > 0 and num_local_heads > 0:
|
||||
kv_b_proj_n = num_local_heads * (qk_nope_head_dim + v_head_dim)
|
||||
shapes.append(("NORMAL", kv_b_proj_n, kv_lora_rank, 1))
|
||||
|
||||
return shapes
|
||||
|
||||
|
||||
def get_architecture_key(config, tp):
|
||||
"""Key for dedup: models with same key share DeepGEMM kernels."""
|
||||
if config is None:
|
||||
return None
|
||||
fields = [
|
||||
config.get("hidden_size", 0),
|
||||
config.get("moe_intermediate_size", 0),
|
||||
config.get("n_routed_experts", 0),
|
||||
config.get("n_shared_experts", 0),
|
||||
config.get("num_attention_heads", 0),
|
||||
config.get("kv_lora_rank", 0),
|
||||
config.get("qk_nope_head_dim", 0),
|
||||
config.get("v_head_dim", 0),
|
||||
tp,
|
||||
]
|
||||
return tuple(fields)
|
||||
|
||||
|
||||
def compute_m_list(fast_warmup=False, chunked_prefill_size=8192):
|
||||
"""Compute the list of M values to compile (matches compile_utils.py logic)."""
|
||||
m_list = []
|
||||
if fast_warmup:
|
||||
m_list += list(range(1, 1025))
|
||||
next_m, sample_step = 1024, 2
|
||||
max_prefill_bs = min(chunked_prefill_size, 32 * 1024)
|
||||
while next_m < max_prefill_bs:
|
||||
m_list += list(range(next_m, 2 * next_m, sample_step))
|
||||
next_m *= 2
|
||||
sample_step *= 2
|
||||
m_list.append(max_prefill_bs)
|
||||
m_list = sorted(set(m_list))
|
||||
else:
|
||||
m_max = 16 * 1024
|
||||
if chunked_prefill_size > 8192:
|
||||
m_max = chunked_prefill_size * 2
|
||||
m_max = min(128 * 1024, m_max)
|
||||
m_list = list(range(1, m_max + 1))
|
||||
return m_list
|
||||
|
||||
|
||||
def _empty_token_fp8(size):
|
||||
"""Create FP8 token tensor + per-block scale tensor."""
|
||||
import torch
|
||||
|
||||
*dims, k = size
|
||||
return (
|
||||
torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
|
||||
torch.empty((*dims, ceil(k / BLOCK_SIZE)), device="cuda", dtype=torch.float32),
|
||||
)
|
||||
|
||||
|
||||
def _empty_block_fp8(size):
|
||||
"""Create FP8 block tensor + per-block scale tensor."""
|
||||
import torch
|
||||
|
||||
*dims, n, k = size
|
||||
return (
|
||||
torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
|
||||
torch.empty(
|
||||
(*dims, ceil(n / BLOCK_SIZE), ceil(k / BLOCK_SIZE)),
|
||||
device="cuda",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_memory_requirement(kernel_type, max_m, n, k, num_groups):
|
||||
"""Estimate GPU memory needed in GB for compilation buffers."""
|
||||
_GB = 1 << 30
|
||||
if kernel_type == "NORMAL":
|
||||
return (max_m * k + n * k + max_m * n * 2) / _GB
|
||||
elif kernel_type == "CONTIG":
|
||||
return (max_m * k + num_groups * n * k + max_m * 4 + max_m * n * 2) / _GB
|
||||
elif kernel_type == "MASKED":
|
||||
return (
|
||||
num_groups * max_m * k
|
||||
+ num_groups * n * k
|
||||
+ num_groups * 4
|
||||
+ num_groups * max_m * n * 2
|
||||
) / _GB
|
||||
return 0
|
||||
|
||||
|
||||
def compile_one_shape(kernel_type, n, k, num_groups, m_list):
|
||||
"""Compile DeepGEMM kernels for one (kernel_type, N, K, num_groups) shape."""
|
||||
import deep_gemm
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
# Filter M list for contiguous layout alignment
|
||||
if kernel_type == "CONTIG":
|
||||
m_alignment = deep_gemm.get_mk_alignment_for_contiguous_layout()
|
||||
m_list = sorted(set(m for m in m_list if m % m_alignment == 0))
|
||||
|
||||
if not m_list:
|
||||
return
|
||||
|
||||
max_m = max(m_list)
|
||||
|
||||
# Reduce max_m if not enough GPU memory
|
||||
mem_free = torch.cuda.mem_get_info()[0] / (1 << 30)
|
||||
mem_required = get_memory_requirement(kernel_type, max_m, n, k, num_groups)
|
||||
if mem_required > mem_free:
|
||||
while (
|
||||
get_memory_requirement(kernel_type, max_m, n, k, num_groups) > mem_free
|
||||
and max_m > 4096
|
||||
):
|
||||
max_m //= 2
|
||||
print(
|
||||
f" Memory {mem_free:.1f}GB < required {mem_required:.1f}GB, "
|
||||
f"reducing max_m to {max_m}"
|
||||
)
|
||||
m_list = [m for m in m_list if m <= max_m]
|
||||
|
||||
old_mode = deep_gemm.get_compile_mode()
|
||||
deep_gemm.set_compile_mode(1)
|
||||
try:
|
||||
if kernel_type == "NORMAL":
|
||||
lhs_q, lhs_s = _empty_token_fp8((max_m, k))
|
||||
rhs_q, rhs_s = _empty_block_fp8((n, k))
|
||||
out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
|
||||
for m in tqdm(m_list, desc=f" NORMAL N={n} K={k}"):
|
||||
deep_gemm.fp8_gemm_nt((lhs_q[:m], lhs_s[:m]), (rhs_q, rhs_s), out[:m])
|
||||
|
||||
elif kernel_type == "CONTIG":
|
||||
lhs_q, lhs_s = _empty_token_fp8((max_m, k))
|
||||
rhs_q, rhs_s = _empty_block_fp8((num_groups, n, k))
|
||||
m_indices = torch.zeros((max_m,), device="cuda", dtype=torch.int32)
|
||||
out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
|
||||
for m in tqdm(m_list, desc=f" CONTIG N={n} K={k} G={num_groups}"):
|
||||
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
|
||||
(lhs_q[:m], lhs_s[:m]),
|
||||
(rhs_q, rhs_s),
|
||||
out[:m],
|
||||
m_indices=m_indices[:m],
|
||||
)
|
||||
|
||||
elif kernel_type == "MASKED":
|
||||
lhs_q, lhs_s = _empty_token_fp8((num_groups, max_m, k))
|
||||
rhs_q, rhs_s = _empty_block_fp8((num_groups, n, k))
|
||||
masked_m = torch.zeros((num_groups,), device="cuda", dtype=torch.int32)
|
||||
out = torch.empty(
|
||||
(num_groups, max_m, n), device="cuda", dtype=torch.bfloat16
|
||||
)
|
||||
for m in tqdm(m_list, desc=f" MASKED N={n} K={k} G={num_groups}"):
|
||||
deep_gemm.fp8_m_grouped_gemm_nt_masked(
|
||||
(lhs_q, lhs_s),
|
||||
(rhs_q, rhs_s),
|
||||
out,
|
||||
masked_m=masked_m,
|
||||
expected_m=m,
|
||||
)
|
||||
finally:
|
||||
deep_gemm.set_compile_mode(old_mode)
|
||||
|
||||
torch.cuda.current_stream().synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def compile_shapes_lightweight(shapes, m_list):
|
||||
"""Compile all DeepGEMM shapes directly (no model loading)."""
|
||||
for i, (kernel_type, n, k, num_groups) in enumerate(shapes, 1):
|
||||
print(f"\n[{i}/{len(shapes)}] {kernel_type} N={n} K={k} G={num_groups}")
|
||||
t0 = time.time()
|
||||
compile_one_shape(kernel_type, n, k, num_groups, m_list)
|
||||
elapsed = time.time() - t0
|
||||
print(f" Done in {elapsed:.1f}s")
|
||||
|
||||
|
||||
def fallback_compile_deep_gemm(model, tp):
|
||||
"""Fall back to full sglang.compile_deep_gemm (loads model weights)."""
|
||||
print(f"Falling back to full compile_deep_gemm for {model} (tp={tp})...")
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"sglang.compile_deep_gemm",
|
||||
"--model",
|
||||
model,
|
||||
"--tp",
|
||||
str(tp),
|
||||
"--trust-remote-code",
|
||||
"--model-loader-extra-config",
|
||||
'{"enable_multithread_load": true, "num_threads": 64}',
|
||||
]
|
||||
result = subprocess.run(cmd)
|
||||
if result.returncode != 0:
|
||||
print(f"Warning: fallback failed for {model} (exit code {result.returncode})")
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
|
||||
print("Usage: warmup_deep_gemm.py model1:tp1 [model2:tp2 ...]")
|
||||
print("\nDerives DeepGEMM kernel shapes from config.json without loading model")
|
||||
print(
|
||||
"weights. Falls back to full compile_deep_gemm for unknown architectures."
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
# Parse model:tp pairs
|
||||
model_tp_pairs = []
|
||||
for arg in sys.argv[1:]:
|
||||
if ":" not in arg:
|
||||
print(f"Error: expected model:tp format, got '{arg}'")
|
||||
sys.exit(1)
|
||||
model, tp_str = arg.rsplit(":", 1)
|
||||
model_tp_pairs.append((model, int(tp_str)))
|
||||
|
||||
fast_warmup = os.environ.get("SGLANG_JIT_DEEPGEMM_FAST_WARMUP", "0").lower() in (
|
||||
"1",
|
||||
"true",
|
||||
)
|
||||
print(f"=== DeepGEMM Lightweight Warmup ({len(model_tp_pairs)} model(s)) ===")
|
||||
print(f" Fast warmup: {fast_warmup}")
|
||||
print(
|
||||
f" Cache dir: {os.environ.get('DG_JIT_CACHE_DIR', '~/.cache/deep_gemm')}\n"
|
||||
)
|
||||
|
||||
# Load configs and deduplicate by architecture
|
||||
seen_keys = {}
|
||||
to_process = [] # (model, tp, config_or_None, shapes_or_None)
|
||||
|
||||
for model, tp in model_tp_pairs:
|
||||
config = get_config_json(model)
|
||||
if config is None:
|
||||
print(f" SKIP {model} (tp={tp}): config.json not in HF cache")
|
||||
continue
|
||||
|
||||
key = get_architecture_key(config, tp)
|
||||
if key in seen_keys:
|
||||
print(f" DEDUP {model} (tp={tp}): same shapes as {seen_keys[key]}")
|
||||
continue
|
||||
|
||||
if is_deepseek_v2v3(config):
|
||||
shapes = compute_deepseek_v2v3_shapes(config, tp)
|
||||
seen_keys[key] = model
|
||||
to_process.append((model, tp, config, shapes))
|
||||
print(f" FOUND {model} (tp={tp}): {len(shapes)} DeepGEMM shape(s)")
|
||||
else:
|
||||
# Unknown architecture: will use fallback
|
||||
seen_keys[key] = model
|
||||
to_process.append((model, tp, config, None))
|
||||
arch = config.get("architectures", ["unknown"])
|
||||
print(f" FOUND {model} (tp={tp}): unknown arch {arch}, will use fallback")
|
||||
|
||||
if not to_process:
|
||||
print("\nNo models to process. Done.")
|
||||
return
|
||||
|
||||
m_list = compute_m_list(fast_warmup=fast_warmup)
|
||||
print(f"\nM list: {len(m_list)} values (range {min(m_list)}-{max(m_list)})")
|
||||
|
||||
for model, tp, config, shapes in to_process:
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Model: {model} (tp={tp})")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
if shapes is None:
|
||||
# Unknown architecture: fall back to full compile_deep_gemm
|
||||
fallback_compile_deep_gemm(model, tp)
|
||||
continue
|
||||
|
||||
# Print shape summary
|
||||
for kernel_type, n, k, num_groups in shapes:
|
||||
print(f" {kernel_type:8s} N={n:<6d} K={k:<6d} G={num_groups}")
|
||||
|
||||
t0 = time.time()
|
||||
compile_shapes_lightweight(shapes, m_list)
|
||||
elapsed = time.time() - t0
|
||||
print(f"\nCompleted {model} in {elapsed:.1f}s")
|
||||
|
||||
print("\nDeepGEMM lightweight warmup complete.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
313
third_party/sglang/scripts/ci/cuda/warmup_server.py
vendored
Normal file
313
third_party/sglang/scripts/ci/cuda/warmup_server.py
vendored
Normal file
@@ -0,0 +1,313 @@
|
||||
"""
|
||||
Full server warmup to pre-warm Triton autotuning and CUDA graph capture.
|
||||
|
||||
On cold H200 nodes (new nodes or after container recreation), CUDA graph capture
|
||||
triggers Triton autotuning which takes ~330s per server launch. This script
|
||||
launches actual servers with CUDA graphs enabled to cache the autotuned kernels,
|
||||
so subsequent test launches are fast (~30-60s).
|
||||
|
||||
Uses marker files to skip warmup on already-warm nodes. Marker files are
|
||||
invalidated when Python, Triton, or PyTorch versions change.
|
||||
|
||||
Usage:
|
||||
python3 scripts/ci/cuda/warmup_server.py \
|
||||
deepseek-ai/DeepSeek-V3-0324:8 \
|
||||
inclusionAI/Ring-2.5-1T:8
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Reuse helpers from warmup_deep_gemm (same directory)
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from warmup_deep_gemm import get_architecture_key, get_config_json
|
||||
|
||||
MARKER_DIR = os.path.join(os.path.expanduser("~"), ".cache", "sglang", "warmup_markers")
|
||||
HEALTH_POLL_INTERVAL = 10 # seconds between health checks
|
||||
SERVER_STARTUP_TIMEOUT = 900 # 15 min max to wait for server ready
|
||||
DEFAULT_PORT = 39876
|
||||
|
||||
|
||||
def get_version_key():
|
||||
"""Hash of Python + Triton + PyTorch versions to invalidate markers on upgrades."""
|
||||
parts = [sys.version]
|
||||
try:
|
||||
import triton
|
||||
|
||||
parts.append(f"triton={triton.__version__}")
|
||||
except ImportError:
|
||||
parts.append("triton=none")
|
||||
try:
|
||||
import torch
|
||||
|
||||
parts.append(f"torch={torch.__version__}")
|
||||
except ImportError:
|
||||
parts.append("torch=none")
|
||||
return hashlib.sha256("|".join(parts).encode()).hexdigest()[:12]
|
||||
|
||||
|
||||
def get_marker_path(model, tp):
|
||||
"""Get the marker file path for a model:tp pair."""
|
||||
version_key = get_version_key()
|
||||
safe_model = model.replace("/", "--")
|
||||
return os.path.join(
|
||||
MARKER_DIR, f"server_warmup_{safe_model}_tp{tp}_{version_key}.done"
|
||||
)
|
||||
|
||||
|
||||
def check_marker(model, tp):
|
||||
"""Check if warmup marker exists (node already warm)."""
|
||||
marker = get_marker_path(model, tp)
|
||||
return os.path.exists(marker)
|
||||
|
||||
|
||||
def write_marker(model, tp):
|
||||
"""Write warmup marker after successful warmup."""
|
||||
marker = get_marker_path(model, tp)
|
||||
os.makedirs(os.path.dirname(marker), exist_ok=True)
|
||||
Path(marker).write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"model": model,
|
||||
"tp": tp,
|
||||
"version_key": get_version_key(),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
)
|
||||
)
|
||||
print(f" Wrote marker: {marker}")
|
||||
|
||||
|
||||
def kill_server(proc):
|
||||
"""Kill server process tree."""
|
||||
if proc.poll() is not None:
|
||||
return
|
||||
try:
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
||||
except (ProcessLookupError, OSError):
|
||||
pass
|
||||
try:
|
||||
proc.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
try:
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
|
||||
except (ProcessLookupError, OSError):
|
||||
pass
|
||||
try:
|
||||
proc.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
pass
|
||||
|
||||
|
||||
def wait_for_server(base_url, proc, timeout):
|
||||
"""Poll /health_generate until server is ready or timeout."""
|
||||
import requests
|
||||
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
ret = proc.poll()
|
||||
if ret is not None:
|
||||
return False, f"Server exited with code {ret}"
|
||||
try:
|
||||
resp = requests.get(f"{base_url}/health_generate", timeout=5)
|
||||
if resp.status_code == 200:
|
||||
return True, None
|
||||
except requests.RequestException:
|
||||
pass
|
||||
time.sleep(HEALTH_POLL_INTERVAL)
|
||||
return False, "Timed out waiting for server"
|
||||
|
||||
|
||||
def send_generate_request(base_url):
|
||||
"""Send one /generate request to exercise the full inference path."""
|
||||
import requests
|
||||
|
||||
payload = {
|
||||
"input_ids": [0, 1, 2, 3],
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 8,
|
||||
"temperature": 0,
|
||||
},
|
||||
}
|
||||
try:
|
||||
resp = requests.post(f"{base_url}/generate", json=payload, timeout=120)
|
||||
if resp.status_code == 200:
|
||||
print(" Generate request succeeded")
|
||||
else:
|
||||
print(f" Warning: generate request returned {resp.status_code}")
|
||||
except requests.RequestException as e:
|
||||
print(f" Warning: generate request failed: {e}")
|
||||
|
||||
|
||||
def warmup_one_model(model, tp, port):
|
||||
"""Launch server, wait for ready, send one request, then kill."""
|
||||
base_url = f"http://127.0.0.1:{port}"
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"sglang.launch_server",
|
||||
"--model-path",
|
||||
model,
|
||||
"--tp",
|
||||
str(tp),
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
str(port),
|
||||
"--trust-remote-code",
|
||||
"--model-loader-extra-config",
|
||||
'{"enable_multithread_load": true, "num_threads": 64}',
|
||||
]
|
||||
|
||||
# Use a temp file for server output to avoid pipe buffer deadlock
|
||||
# (server logs can exceed the 64KB pipe buffer during CUDA graph capture)
|
||||
log_file = tempfile.NamedTemporaryFile(
|
||||
mode="w", prefix="warmup_server_", suffix=".log", delete=False
|
||||
)
|
||||
log_path = log_file.name
|
||||
|
||||
print(f" Launching server: {' '.join(cmd)}")
|
||||
print(f" Server log: {log_path}")
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=log_file,
|
||||
stderr=subprocess.STDOUT,
|
||||
preexec_fn=os.setsid,
|
||||
)
|
||||
|
||||
try:
|
||||
# Wait for server to be ready (includes CUDA graph capture)
|
||||
print(
|
||||
f" Waiting for server (timeout={SERVER_STARTUP_TIMEOUT}s, "
|
||||
f"polling every {HEALTH_POLL_INTERVAL}s)..."
|
||||
)
|
||||
ok, err = wait_for_server(base_url, proc, SERVER_STARTUP_TIMEOUT)
|
||||
if not ok:
|
||||
print(f" Warning: server not ready: {err}")
|
||||
# Dump last lines of server log for debugging
|
||||
try:
|
||||
log_file.flush()
|
||||
with open(log_path) as f:
|
||||
lines = f.readlines()
|
||||
for line in lines[-20:]:
|
||||
print(f" | {line.rstrip()}")
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
print(" Server ready, sending generate request...")
|
||||
send_generate_request(base_url)
|
||||
return True
|
||||
|
||||
finally:
|
||||
print(" Killing server...")
|
||||
kill_server(proc)
|
||||
log_file.close()
|
||||
try:
|
||||
os.unlink(log_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
|
||||
print("Usage: warmup_server.py model1:tp1 [model2:tp2 ...]")
|
||||
print(
|
||||
"\nLaunches full servers with CUDA graphs enabled to pre-warm"
|
||||
" Triton autotuning."
|
||||
)
|
||||
print("Skips instantly on warm nodes (marker file exists).")
|
||||
sys.exit(0)
|
||||
|
||||
# Parse model:tp pairs
|
||||
model_tp_pairs = []
|
||||
for arg in sys.argv[1:]:
|
||||
if ":" not in arg:
|
||||
print(f"Error: expected model:tp format, got '{arg}'")
|
||||
sys.exit(1)
|
||||
model, tp_str = arg.rsplit(":", 1)
|
||||
model_tp_pairs.append((model, int(tp_str)))
|
||||
|
||||
print(f"=== Server CUDA Graph Warmup ({len(model_tp_pairs)} model(s)) ===")
|
||||
print(f" Marker dir: {MARKER_DIR}")
|
||||
print(f" Version key: {get_version_key()}\n")
|
||||
|
||||
# Deduplicate by architecture and check markers
|
||||
seen_keys = {}
|
||||
to_warmup = []
|
||||
|
||||
for model, tp in model_tp_pairs:
|
||||
# Check marker first (fast path)
|
||||
if check_marker(model, tp):
|
||||
print(f" SKIP {model} (tp={tp}): already warm (marker exists)")
|
||||
continue
|
||||
|
||||
# Architecture dedup
|
||||
config = get_config_json(model)
|
||||
if config is not None:
|
||||
key = get_architecture_key(config, tp)
|
||||
if key in seen_keys:
|
||||
print(
|
||||
f" DEDUP {model} (tp={tp}): same architecture as {seen_keys[key]}"
|
||||
)
|
||||
continue
|
||||
seen_keys[key] = model
|
||||
|
||||
to_warmup.append((model, tp))
|
||||
print(f" QUEUE {model} (tp={tp}): needs warmup")
|
||||
|
||||
if not to_warmup:
|
||||
print("\nAll models already warm. Done.")
|
||||
return
|
||||
|
||||
print(f"\n{len(to_warmup)} model(s) to warm up.\n")
|
||||
|
||||
port = DEFAULT_PORT
|
||||
for i, (model, tp) in enumerate(to_warmup, 1):
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"[{i}/{len(to_warmup)}] {model} (tp={tp})")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
t0 = time.time()
|
||||
success = warmup_one_model(model, tp, port)
|
||||
elapsed = time.time() - t0
|
||||
|
||||
if success:
|
||||
print(f" Completed in {elapsed:.0f}s")
|
||||
write_marker(model, tp)
|
||||
# Also write markers for dedup'd models that share this architecture
|
||||
config = get_config_json(model)
|
||||
if config is not None:
|
||||
key = get_architecture_key(config, tp)
|
||||
for other_model, other_tp in model_tp_pairs:
|
||||
if (other_model, other_tp) == (model, tp):
|
||||
continue
|
||||
other_config = get_config_json(other_model)
|
||||
if other_config is not None:
|
||||
other_key = get_architecture_key(other_config, other_tp)
|
||||
if other_key == key and not check_marker(other_model, other_tp):
|
||||
write_marker(other_model, other_tp)
|
||||
print(
|
||||
f" Also marked {other_model} (tp={other_tp}) as warm (same arch)"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f" Warning: warmup failed after {elapsed:.0f}s (non-fatal, tests will still work)"
|
||||
)
|
||||
|
||||
# Use a different port for the next model to avoid bind conflicts
|
||||
port += 100
|
||||
|
||||
print("\nServer CUDA graph warmup complete.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
third_party/sglang/scripts/ci/musa/musa_install_dependency.sh
vendored
Executable file
5
third_party/sglang/scripts/ci/musa/musa_install_dependency.sh
vendored
Executable file
@@ -0,0 +1,5 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
PIP_INSTALL="python3 -m pip install --no-cache-dir"
|
||||
${PIP_INSTALL} --upgrade pip setuptools torchada
|
||||
46
third_party/sglang/scripts/ci/musa/rename_wheels_musa.sh
vendored
Executable file
46
third_party/sglang/scripts/ci/musa/rename_wheels_musa.sh
vendored
Executable file
@@ -0,0 +1,46 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Rename MUSA wheels to include a +musa<suffix> build tag.
|
||||
# Usage:
|
||||
# rename_wheels_musa.sh <musa_suffix> [wheel_dir]
|
||||
# Example:
|
||||
# rename_wheels_musa.sh 43 sgl-kernel/dist
|
||||
|
||||
if [[ $# -lt 1 || $# -gt 2 ]]; then
|
||||
echo "Usage: $0 <musa_suffix> [wheel_dir]" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
MUSA_SUFFIX="$1"
|
||||
WHEEL_DIR="${2:-dist}"
|
||||
|
||||
wheel_files=("$WHEEL_DIR"/*.whl)
|
||||
|
||||
if [[ ! -e "${wheel_files[0]}" ]]; then
|
||||
echo "No wheel files found in ${WHEEL_DIR}/, nothing to rename."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
for wheel in "${wheel_files[@]}"; do
|
||||
# Normalize platform tag to manylinux2014
|
||||
intermediate_wheel="${wheel/linux/manylinux2014}"
|
||||
|
||||
# Extract Python ABI version (e.g. cp310)
|
||||
if [[ $intermediate_wheel =~ -cp([0-9]+)- ]]; then
|
||||
cp_version="${BASH_REMATCH[1]}"
|
||||
else
|
||||
echo "Could not extract Python version from wheel name: $intermediate_wheel" >&2
|
||||
continue
|
||||
fi
|
||||
|
||||
# Insert +musa<suffix> before the Python ABI tag
|
||||
new_wheel="${intermediate_wheel/-cp${cp_version}/+musa${MUSA_SUFFIX}-cp${cp_version}}"
|
||||
|
||||
if [[ "$wheel" != "$new_wheel" ]]; then
|
||||
echo "Renaming $wheel -> $new_wheel"
|
||||
mv -- "$wheel" "$new_wheel"
|
||||
fi
|
||||
done
|
||||
|
||||
echo "MUSA wheel renaming completed."
|
||||
68
third_party/sglang/scripts/ci/npu/npu_ci_install_dependency.sh
vendored
Executable file
68
third_party/sglang/scripts/ci/npu/npu_ci_install_dependency.sh
vendored
Executable file
@@ -0,0 +1,68 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
PIP_INSTALL="python3 -m pip install --no-cache-dir"
|
||||
DEVICE_TYPE=$1
|
||||
OPTIONAL_DEPS="${2:-}"
|
||||
|
||||
|
||||
# Install the required dependencies in CI.
|
||||
apt update -y && apt install -y \
|
||||
unzip \
|
||||
build-essential \
|
||||
cmake \
|
||||
wget \
|
||||
curl \
|
||||
net-tools \
|
||||
zlib1g-dev \
|
||||
lld \
|
||||
clang \
|
||||
locales \
|
||||
ccache \
|
||||
libgl1-mesa-glx \
|
||||
libgl1-mesa-dri \
|
||||
ca-certificates \
|
||||
libgl1 \
|
||||
libglib2.0-0
|
||||
update-ca-certificates
|
||||
${PIP_INSTALL} --upgrade pip
|
||||
# Pin wheel to 0.45.1, REF: https://github.com/pypa/wheel/issues/662
|
||||
${PIP_INSTALL} wheel==0.45.1 pybind11 pyyaml decorator scipy attrs psutil
|
||||
|
||||
|
||||
### Install MemFabric
|
||||
${PIP_INSTALL} memfabric-hybrid==1.0.5
|
||||
|
||||
|
||||
### Install PyTorch and PTA
|
||||
if [ -n "$OPTIONAL_DEPS" ]; then
|
||||
PYTORCH_VERSION="2.10.0"
|
||||
TORCHVISION_VERSION="0.25.0"
|
||||
${PIP_INSTALL} torch==${PYTORCH_VERSION} torchvision==${TORCHVISION_VERSION} --index-url ${TORCH_CACHE_URL:="https://download.pytorch.org/whl/cpu"} --extra-index-url ${PYPI_CACHE_URL:="https://pypi.org/simple/"}
|
||||
PTA_URL="https://gitcode.com/Ascend/pytorch/releases/download/7.3.0.alpha002/torch_npu-2.10.0rc2-cp311-cp311-manylinux_2_28_aarch64.whl"
|
||||
${PIP_INSTALL} ${PTA_URL}
|
||||
else
|
||||
PYTORCH_VERSION="2.8.0"
|
||||
TORCHVISION_VERSION="0.23.0"
|
||||
${PIP_INSTALL} torch==${PYTORCH_VERSION} torchvision==${TORCHVISION_VERSION} --index-url ${TORCH_CACHE_URL:="https://download.pytorch.org/whl/cpu"} --extra-index-url ${PYPI_CACHE_URL:="https://pypi.org/simple/"}
|
||||
PTA_URL="https://gitcode.com/Ascend/pytorch/releases/download/v7.3.0-pytorch2.8.0/torch_npu-2.8.0.post2-cp311-cp311-manylinux_2_28_aarch64.whl"
|
||||
${PIP_INSTALL} ${PTA_URL}
|
||||
fi
|
||||
|
||||
|
||||
### Install Triton-Ascend
|
||||
${PIP_INSTALL} triton-ascend
|
||||
|
||||
|
||||
### Install sgl-kernel-npu
|
||||
SGLANG_KERNEL_NPU_TAG="2026.03.10.rc1"
|
||||
mkdir sgl-kernel-npu
|
||||
(cd sgl-kernel-npu && wget "${GITHUB_PROXY_URL:=""}https://github.com/sgl-project/sgl-kernel-npu/releases/download/${SGLANG_KERNEL_NPU_TAG}/sgl-kernel-npu-${SGLANG_KERNEL_NPU_TAG}-torch2.8.0-py311-cann8.5.0-${DEVICE_TYPE}-$(arch).zip" \
|
||||
&& unzip ./sgl-kernel-npu-${SGLANG_KERNEL_NPU_TAG}-torch2.8.0-py311-cann8.5.0-${DEVICE_TYPE}-$(arch).zip \
|
||||
&& ${PIP_INSTALL} ./deep_ep*.whl ./sgl_kernel_npu*.whl \
|
||||
&& (cd "$(python3 -m pip show deep-ep | grep -E '^Location:' | awk '{print $2}')" && ln -s deep_ep/deep_ep_cpp*.so))
|
||||
|
||||
|
||||
### Install SGLang
|
||||
rm -rf python/pyproject.toml && mv python/pyproject_npu.toml python/pyproject.toml
|
||||
${PIP_INSTALL} -v -e "python[dev_npu]"
|
||||
26
third_party/sglang/scripts/ci/npu/npu_log_print.sh
vendored
Executable file
26
third_party/sglang/scripts/ci/npu/npu_log_print.sh
vendored
Executable file
@@ -0,0 +1,26 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
# Print log information(sglang version, commit sha, sgl-kernel-npu version, sgl-kernel-npu commit sha, npu-smi info and pip list.
|
||||
npu-smi info
|
||||
pip list
|
||||
get_version() {
|
||||
[ -f "$1" ] && python3 -c 'import re, sys; print(sys.argv[2] + " version: v" + re.search(r"__version__\s*=\s*[\"'"'"'](.*?)[\"'"'"']", open(sys.argv[1]).read()).group(1))' "$1" "$2" 2>/dev/null || echo "$2 version: unknown"
|
||||
}
|
||||
get_version "./python/sglang/version.py" "sglang"
|
||||
get_version "./sgl-kernel/python/sgl_kernel/version.py" "sgl_kernel"
|
||||
SGLANG_URL="https://github.com/sgl-project/sglang.git"
|
||||
SGL_KERNEL_URL="https://github.com/sgl-project/sgl-kernel-npu.git"
|
||||
SGLANG_BRANCH="main"
|
||||
SGL_KERNEL_BRANCH="main"
|
||||
get_sha() {
|
||||
local name="$1"
|
||||
local url="$2"
|
||||
local branch="$3"
|
||||
local sha
|
||||
sha=$(git ls-remote "$url" "refs/heads/$branch" | cut -f1)
|
||||
echo "$name SHA for branch $branch: ${sha:-"Not Found"}"
|
||||
}
|
||||
get_sha "sglang" "$SGLANG_URL" "$SGLANG_BRANCH"
|
||||
get_sha "sgl-kernel" "$SGL_KERNEL_URL" "$SGL_KERNEL_BRANCH"
|
||||
chmod +x scripts/ci/npu/npu_log_print.sh
|
||||
477
third_party/sglang/scripts/ci/utils/ci_coverage_report.py
vendored
Executable file
477
third_party/sglang/scripts/ci/utils/ci_coverage_report.py
vendored
Executable file
@@ -0,0 +1,477 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CI Coverage Report Generator
|
||||
|
||||
Collects all CI test registrations from test/registered/ and generates
|
||||
a coverage report organized by folder, backend, and suite.
|
||||
|
||||
Usage:
|
||||
python scripts/ci/utils/ci_coverage_report.py [--output-format markdown|json]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
# Add the ci_register module path directly to avoid heavy sglang imports
|
||||
sys.path.insert(
|
||||
0,
|
||||
str(
|
||||
Path(__file__).parent.parent.parent.parent / "python" / "sglang" / "test" / "ci"
|
||||
),
|
||||
)
|
||||
|
||||
from ci_register import CIRegistry, HWBackend, ut_parse_one_file
|
||||
|
||||
|
||||
def collect_all_tests(registered_dir: str) -> list[CIRegistry]:
|
||||
"""Collect all CI registrations from registered directory."""
|
||||
files = glob.glob(f"{registered_dir}/**/*.py", recursive=True)
|
||||
all_tests = []
|
||||
|
||||
for file in sorted(files):
|
||||
try:
|
||||
registries = ut_parse_one_file(file)
|
||||
all_tests.extend(registries)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to parse {file}: {e}", file=sys.stderr)
|
||||
|
||||
return all_tests
|
||||
|
||||
|
||||
def get_folder_name(filename: str) -> str:
|
||||
"""Extract folder name from test filename."""
|
||||
# e.g., "registered/models/test_foo.py" -> "models"
|
||||
parts = Path(filename).parts
|
||||
if "registered" in parts:
|
||||
idx = parts.index("registered")
|
||||
if idx + 1 < len(parts) - 1: # Has subfolder
|
||||
return parts[idx + 1]
|
||||
return "root"
|
||||
|
||||
|
||||
def get_test_basename(filename: str) -> str:
|
||||
"""Extract just the test file name from the path."""
|
||||
return Path(filename).name
|
||||
|
||||
|
||||
def organize_test_data(tests: list[CIRegistry]) -> dict:
|
||||
"""Organize tests into various groupings."""
|
||||
by_backend = defaultdict(list)
|
||||
by_folder = defaultdict(list)
|
||||
disabled_tests = []
|
||||
|
||||
for t in tests:
|
||||
by_backend[t.backend.name].append(t)
|
||||
by_folder[get_folder_name(t.filename)].append(t)
|
||||
if t.disabled:
|
||||
disabled_tests.append(t)
|
||||
|
||||
# Count unique test files (a file may be registered for multiple backends)
|
||||
unique_files = set(t.filename for t in tests)
|
||||
unique_enabled_files = set(t.filename for t in tests if not t.disabled)
|
||||
unique_disabled_files = set(t.filename for t in tests if t.disabled)
|
||||
|
||||
return {
|
||||
"total": len(tests),
|
||||
"total_unique_files": len(unique_files),
|
||||
"enabled": len(tests) - len(disabled_tests),
|
||||
"enabled_unique_files": len(unique_enabled_files),
|
||||
"disabled_count": len(disabled_tests),
|
||||
"disabled_unique_files": len(unique_disabled_files),
|
||||
"by_backend": by_backend,
|
||||
"by_folder": by_folder,
|
||||
"disabled_tests": disabled_tests,
|
||||
}
|
||||
|
||||
|
||||
def generate_summary_section(data: dict) -> str:
|
||||
"""Generate the summary/overview section."""
|
||||
lines = []
|
||||
lines.append("# CI Coverage Overview\n")
|
||||
lines.append(
|
||||
f"**Unique Test Files:** {data['total_unique_files']} ({data['enabled_unique_files']} enabled, {data['disabled_unique_files']} disabled)\n"
|
||||
)
|
||||
lines.append(
|
||||
f"**Total Registrations:** {data['total']} ({data['enabled']} enabled, {data['disabled_count']} disabled)\n"
|
||||
)
|
||||
lines.append(
|
||||
"*Note: A test file may be registered for multiple backends (e.g., CUDA + AMD), so total registrations > unique files.*\n"
|
||||
)
|
||||
|
||||
by_backend = data["by_backend"]
|
||||
by_folder = data["by_folder"]
|
||||
disabled_tests = data["disabled_tests"]
|
||||
|
||||
# Backend summary (collapsible)
|
||||
lines.append("<details>")
|
||||
lines.append("<summary><h2>Backend Summary</h2></summary>\n")
|
||||
lines.append("| Backend | Total | Enabled | Disabled | Per-Commit | Nightly |")
|
||||
lines.append("|---------|-------|---------|----------|------------|---------|")
|
||||
|
||||
for backend in ["CUDA", "AMD", "NPU", "CPU"]:
|
||||
backend_tests = by_backend.get(backend, [])
|
||||
if not backend_tests:
|
||||
continue
|
||||
b_total = len(backend_tests)
|
||||
b_disabled = sum(1 for t in backend_tests if t.disabled)
|
||||
b_enabled = b_total - b_disabled
|
||||
b_per_commit = sum(1 for t in backend_tests if not t.nightly and not t.disabled)
|
||||
b_nightly = sum(1 for t in backend_tests if t.nightly and not t.disabled)
|
||||
lines.append(
|
||||
f"| {backend} | {b_total} | {b_enabled} | {b_disabled} | {b_per_commit} | {b_nightly} |"
|
||||
)
|
||||
|
||||
lines.append("\n</details>\n")
|
||||
|
||||
# Folder summary (collapsible)
|
||||
lines.append("<details>")
|
||||
lines.append("<summary><h2>Folder Summary</h2></summary>\n")
|
||||
lines.append("| Folder | CUDA | AMD | NPU | CPU | Total |")
|
||||
lines.append("|--------|------|-----|-----|-----|-------|")
|
||||
|
||||
for folder in sorted(by_folder.keys()):
|
||||
folder_tests = by_folder[folder]
|
||||
cuda = sum(1 for t in folder_tests if t.backend == HWBackend.CUDA)
|
||||
amd = sum(1 for t in folder_tests if t.backend == HWBackend.AMD)
|
||||
npu = sum(1 for t in folder_tests if t.backend == HWBackend.NPU)
|
||||
cpu = sum(1 for t in folder_tests if t.backend == HWBackend.CPU)
|
||||
lines.append(
|
||||
f"| {folder} | {cuda} | {amd} | {npu} | {cpu} | {len(folder_tests)} |"
|
||||
)
|
||||
|
||||
lines.append("\n</details>\n")
|
||||
|
||||
# Disabled tests section (collapsible)
|
||||
if disabled_tests:
|
||||
lines.append("<details>")
|
||||
lines.append("<summary><h2>Disabled Tests</h2></summary>\n")
|
||||
lines.append("| File | Backend | Suite | Reason |")
|
||||
lines.append("|------|---------|-------|--------|")
|
||||
for t in sorted(disabled_tests, key=lambda x: (x.backend.name, x.filename)):
|
||||
test_name = get_test_basename(t.filename)
|
||||
reason = t.disabled[:50] + "..." if len(t.disabled) > 50 else t.disabled
|
||||
lines.append(f"| `{test_name}` | {t.backend.name} | {t.suite} | {reason} |")
|
||||
lines.append("\n</details>\n")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def generate_by_folder_section(data: dict) -> str:
|
||||
"""Generate the 'All Tests by Folder' section."""
|
||||
lines = []
|
||||
by_folder = data["by_folder"]
|
||||
|
||||
lines.append("# All Tests by Folder\n")
|
||||
|
||||
for folder in sorted(by_folder.keys()):
|
||||
folder_tests = by_folder[folder]
|
||||
lines.append("<details>")
|
||||
lines.append(
|
||||
f"<summary><h2>{folder}/ ({len(folder_tests)} tests)</h2></summary>\n"
|
||||
)
|
||||
|
||||
# Group by backend within folder
|
||||
folder_by_backend = defaultdict(list)
|
||||
for t in folder_tests:
|
||||
folder_by_backend[t.backend.name].append(t)
|
||||
|
||||
for backend in ["CUDA", "AMD", "NPU", "CPU"]:
|
||||
backend_tests = folder_by_backend.get(backend, [])
|
||||
if not backend_tests:
|
||||
continue
|
||||
|
||||
lines.append(f"### {backend} ({len(backend_tests)} tests)\n")
|
||||
lines.append("| Test File | Suite | Est. Time | Status |")
|
||||
lines.append("|-----------|-------|-----------|--------|")
|
||||
|
||||
for t in sorted(backend_tests, key=lambda x: x.filename):
|
||||
test_name = get_test_basename(t.filename)
|
||||
status = (
|
||||
"Disabled"
|
||||
if t.disabled
|
||||
else ("Nightly" if t.nightly else "Per-Commit")
|
||||
)
|
||||
lines.append(
|
||||
f"| `{test_name}` | {t.suite} | {t.est_time:.0f}s | {status} |"
|
||||
)
|
||||
|
||||
lines.append("")
|
||||
|
||||
lines.append("</details>\n")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def generate_by_suite_section(data: dict) -> str:
|
||||
"""Generate the 'All Tests by Test Suite' section."""
|
||||
lines = []
|
||||
by_backend = data["by_backend"]
|
||||
|
||||
lines.append("# All Tests by Test Suite\n")
|
||||
|
||||
for backend in ["CUDA", "AMD", "NPU", "CPU"]:
|
||||
backend_tests = by_backend.get(backend, [])
|
||||
if not backend_tests:
|
||||
continue
|
||||
|
||||
b_total = len(backend_tests)
|
||||
b_disabled = sum(1 for t in backend_tests if t.disabled)
|
||||
b_enabled = b_total - b_disabled
|
||||
|
||||
lines.append("<details>")
|
||||
lines.append(
|
||||
f"<summary><h2>{backend} Backend ({b_enabled} enabled, {b_disabled} disabled)</h2></summary>\n"
|
||||
)
|
||||
|
||||
# Group by suite within backend
|
||||
backend_suites = defaultdict(list)
|
||||
for t in backend_tests:
|
||||
backend_suites[t.suite].append(t)
|
||||
|
||||
for suite in sorted(backend_suites.keys()):
|
||||
suite_tests = backend_suites[suite]
|
||||
s_enabled = sum(1 for t in suite_tests if not t.disabled)
|
||||
s_disabled = sum(1 for t in suite_tests if t.disabled)
|
||||
s_est_time = sum(t.est_time for t in suite_tests if not t.disabled)
|
||||
is_nightly = any(t.nightly for t in suite_tests if not t.disabled)
|
||||
|
||||
suite_type = "Nightly" if is_nightly else "Per-Commit"
|
||||
lines.append("<details>")
|
||||
lines.append(
|
||||
f"<summary><h3>{suite} ({s_enabled} enabled, {s_disabled} disabled) - {suite_type}</h3></summary>\n"
|
||||
)
|
||||
lines.append(f"*Estimated total time: {s_est_time:.0f}s*\n")
|
||||
|
||||
lines.append("| Test File | Folder | Est. Time | Status |")
|
||||
lines.append("|-----------|--------|-----------|--------|")
|
||||
|
||||
for t in sorted(suite_tests, key=lambda x: x.filename):
|
||||
test_name = get_test_basename(t.filename)
|
||||
folder = get_folder_name(t.filename)
|
||||
if t.disabled:
|
||||
status = (
|
||||
f"Disabled: {t.disabled[:30]}..."
|
||||
if len(t.disabled) > 30
|
||||
else f"Disabled: {t.disabled}"
|
||||
)
|
||||
else:
|
||||
status = "Nightly" if t.nightly else "Per-Commit"
|
||||
lines.append(
|
||||
f"| `{test_name}` | {folder} | {t.est_time:.0f}s | {status} |"
|
||||
)
|
||||
|
||||
lines.append("\n</details>\n")
|
||||
|
||||
lines.append("</details>\n")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def generate_markdown_report(tests: list[CIRegistry], section: str = "all") -> str:
|
||||
"""Generate markdown report for GitHub step summary."""
|
||||
data = organize_test_data(tests)
|
||||
|
||||
if section == "summary":
|
||||
return generate_summary_section(data)
|
||||
elif section == "by-folder":
|
||||
return generate_by_folder_section(data)
|
||||
elif section == "by-suite":
|
||||
return generate_by_suite_section(data)
|
||||
else: # "all"
|
||||
parts = [
|
||||
generate_summary_section(data),
|
||||
"---",
|
||||
generate_by_folder_section(data),
|
||||
"---",
|
||||
generate_by_suite_section(data),
|
||||
]
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def generate_json_report(tests: list[CIRegistry]) -> str:
|
||||
"""Generate JSON report with detailed test listings."""
|
||||
by_backend = defaultdict(list)
|
||||
by_folder = defaultdict(list)
|
||||
|
||||
for t in tests:
|
||||
by_backend[t.backend.name].append(t)
|
||||
by_folder[get_folder_name(t.filename)].append(t)
|
||||
|
||||
disabled_tests = [t for t in tests if t.disabled]
|
||||
|
||||
# Build structured data
|
||||
data = {
|
||||
"summary": {
|
||||
"total": len(tests),
|
||||
"enabled": len(tests) - len(disabled_tests),
|
||||
"disabled": len(disabled_tests),
|
||||
},
|
||||
"tests_by_folder": {},
|
||||
"tests_by_suite": {},
|
||||
"backend_summary": {},
|
||||
"folder_summary": {},
|
||||
"disabled_tests": [],
|
||||
}
|
||||
|
||||
# Section 1: Tests by Folder
|
||||
for folder in sorted(by_folder.keys()):
|
||||
folder_tests = by_folder[folder]
|
||||
folder_by_backend = defaultdict(list)
|
||||
for t in folder_tests:
|
||||
folder_by_backend[t.backend.name].append(t)
|
||||
|
||||
data["tests_by_folder"][folder] = {
|
||||
"total": len(folder_tests),
|
||||
"backends": {},
|
||||
}
|
||||
|
||||
for backend in ["CUDA", "AMD", "NPU", "CPU"]:
|
||||
backend_tests = folder_by_backend.get(backend, [])
|
||||
if backend_tests:
|
||||
data["tests_by_folder"][folder]["backends"][backend] = [
|
||||
{
|
||||
"filename": get_test_basename(t.filename),
|
||||
"suite": t.suite,
|
||||
"est_time": t.est_time,
|
||||
"status": (
|
||||
"disabled"
|
||||
if t.disabled
|
||||
else ("nightly" if t.nightly else "per-commit")
|
||||
),
|
||||
}
|
||||
for t in sorted(backend_tests, key=lambda x: x.filename)
|
||||
]
|
||||
|
||||
# Section 2: Tests by Suite (Backend -> Suite)
|
||||
for backend in ["CUDA", "AMD", "NPU", "CPU"]:
|
||||
backend_tests = by_backend.get(backend, [])
|
||||
if not backend_tests:
|
||||
continue
|
||||
|
||||
backend_suites = defaultdict(list)
|
||||
for t in backend_tests:
|
||||
backend_suites[t.suite].append(t)
|
||||
|
||||
data["tests_by_suite"][backend] = {
|
||||
"total": len(backend_tests),
|
||||
"enabled": sum(1 for t in backend_tests if not t.disabled),
|
||||
"disabled": sum(1 for t in backend_tests if t.disabled),
|
||||
"suites": {},
|
||||
}
|
||||
|
||||
for suite in sorted(backend_suites.keys()):
|
||||
suite_tests = backend_suites[suite]
|
||||
is_nightly = any(t.nightly for t in suite_tests if not t.disabled)
|
||||
|
||||
data["tests_by_suite"][backend]["suites"][suite] = {
|
||||
"total": len(suite_tests),
|
||||
"enabled": sum(1 for t in suite_tests if not t.disabled),
|
||||
"disabled": sum(1 for t in suite_tests if t.disabled),
|
||||
"est_time": sum(t.est_time for t in suite_tests if not t.disabled),
|
||||
"type": "nightly" if is_nightly else "per-commit",
|
||||
"tests": [
|
||||
{
|
||||
"filename": get_test_basename(t.filename),
|
||||
"folder": get_folder_name(t.filename),
|
||||
"est_time": t.est_time,
|
||||
"status": (
|
||||
"disabled"
|
||||
if t.disabled
|
||||
else ("nightly" if t.nightly else "per-commit")
|
||||
),
|
||||
"disabled_reason": t.disabled if t.disabled else None,
|
||||
}
|
||||
for t in sorted(suite_tests, key=lambda x: x.filename)
|
||||
],
|
||||
}
|
||||
|
||||
# Backend summary
|
||||
for backend in ["CUDA", "AMD", "NPU", "CPU"]:
|
||||
backend_tests = by_backend.get(backend, [])
|
||||
if backend_tests:
|
||||
data["backend_summary"][backend] = {
|
||||
"total": len(backend_tests),
|
||||
"enabled": sum(1 for t in backend_tests if not t.disabled),
|
||||
"disabled": sum(1 for t in backend_tests if t.disabled),
|
||||
"per_commit": sum(
|
||||
1 for t in backend_tests if not t.nightly and not t.disabled
|
||||
),
|
||||
"nightly": sum(
|
||||
1 for t in backend_tests if t.nightly and not t.disabled
|
||||
),
|
||||
}
|
||||
|
||||
# Folder summary
|
||||
for folder in sorted(by_folder.keys()):
|
||||
folder_tests = by_folder[folder]
|
||||
data["folder_summary"][folder] = {
|
||||
"CUDA": sum(1 for t in folder_tests if t.backend == HWBackend.CUDA),
|
||||
"AMD": sum(1 for t in folder_tests if t.backend == HWBackend.AMD),
|
||||
"NPU": sum(1 for t in folder_tests if t.backend == HWBackend.NPU),
|
||||
"CPU": sum(1 for t in folder_tests if t.backend == HWBackend.CPU),
|
||||
"total": len(folder_tests),
|
||||
}
|
||||
|
||||
# Disabled tests
|
||||
for t in sorted(disabled_tests, key=lambda x: (x.backend.name, x.filename)):
|
||||
data["disabled_tests"].append(
|
||||
{
|
||||
"filename": get_test_basename(t.filename),
|
||||
"backend": t.backend.name,
|
||||
"suite": t.suite,
|
||||
"reason": t.disabled,
|
||||
}
|
||||
)
|
||||
|
||||
return json.dumps(data, indent=2)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Generate CI coverage report")
|
||||
parser.add_argument(
|
||||
"--output-format",
|
||||
choices=["markdown", "json"],
|
||||
default="markdown",
|
||||
help="Output format (default: markdown)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--section",
|
||||
choices=["all", "summary", "by-folder", "by-suite"],
|
||||
default="all",
|
||||
help="Which section to output (default: all). Only applies to markdown format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--registered-dir",
|
||||
default="test/registered",
|
||||
help="Path to registered test directory",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Change to repo root if needed
|
||||
script_dir = Path(__file__).parent.parent
|
||||
repo_root = script_dir.parent.parent
|
||||
os.chdir(repo_root)
|
||||
|
||||
tests = collect_all_tests(args.registered_dir)
|
||||
|
||||
if args.output_format == "markdown":
|
||||
report = generate_markdown_report(tests, section=args.section)
|
||||
else:
|
||||
report = generate_json_report(tests)
|
||||
|
||||
print(report)
|
||||
|
||||
# Write to GITHUB_STEP_SUMMARY if available
|
||||
summary_file = os.environ.get("GITHUB_STEP_SUMMARY")
|
||||
if summary_file and args.output_format == "markdown":
|
||||
with open(summary_file, "a") as f:
|
||||
f.write(report)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
146
third_party/sglang/scripts/ci/utils/cleanup_hf_cache.py
vendored
Executable file
146
third_party/sglang/scripts/ci/utils/cleanup_hf_cache.py
vendored
Executable file
@@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Clean up stale HuggingFace cache artifacts from previous failed downloads.
|
||||
|
||||
This script removes incomplete marker files, temporary files, and lock files
|
||||
from the HuggingFace cache directory. These artifacts can accumulate from
|
||||
interrupted or failed downloads and may interfere with future downloads.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
try:
|
||||
from huggingface_hub import constants
|
||||
|
||||
HF_HUB_AVAILABLE = True
|
||||
except ImportError:
|
||||
print("Warning: huggingface_hub not available")
|
||||
HF_HUB_AVAILABLE = False
|
||||
|
||||
|
||||
def get_hf_cache_dir() -> str:
|
||||
"""Get the HuggingFace cache directory."""
|
||||
if HF_HUB_AVAILABLE:
|
||||
return constants.HF_HUB_CACHE
|
||||
|
||||
# Fallback to environment variable or default
|
||||
hf_home = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
|
||||
return os.path.join(hf_home, "hub")
|
||||
|
||||
|
||||
def find_stale_artifacts(cache_dir: str) -> List[Path]:
|
||||
"""
|
||||
Find stale artifact files in the HuggingFace cache.
|
||||
|
||||
Args:
|
||||
cache_dir: HuggingFace cache directory
|
||||
|
||||
Returns:
|
||||
List of paths to stale artifact files
|
||||
"""
|
||||
cache_path = Path(cache_dir)
|
||||
|
||||
if not cache_path.exists():
|
||||
return []
|
||||
|
||||
# Patterns for stale files to clean up
|
||||
patterns = [
|
||||
"**/*.incomplete", # Incomplete download markers
|
||||
"**/*.tmp", # Temporary files
|
||||
"**/*.lock", # Lock files from interrupted downloads
|
||||
]
|
||||
|
||||
stale_files = []
|
||||
for pattern in patterns:
|
||||
stale_files.extend(cache_path.glob(pattern))
|
||||
|
||||
return stale_files
|
||||
|
||||
|
||||
def cleanup_artifacts(artifacts: List[Path]) -> tuple[int, int]:
|
||||
"""
|
||||
Remove stale artifact files.
|
||||
|
||||
Args:
|
||||
artifacts: List of file paths to remove
|
||||
|
||||
Returns:
|
||||
Tuple of (successful_removals, failed_removals)
|
||||
"""
|
||||
successful = 0
|
||||
failed = 0
|
||||
|
||||
for file_path in artifacts:
|
||||
try:
|
||||
file_path.unlink()
|
||||
print(f" Removed: {file_path}")
|
||||
successful += 1
|
||||
except Exception as e:
|
||||
print(f" Warning: Could not remove {file_path}: {e}")
|
||||
failed += 1
|
||||
|
||||
return successful, failed
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""
|
||||
Main cleanup logic.
|
||||
|
||||
Returns:
|
||||
Always returns 0 (cleanup is best-effort and should not fail CI)
|
||||
"""
|
||||
print("=" * 70)
|
||||
print("HuggingFace Cache Cleanup")
|
||||
print("=" * 70)
|
||||
|
||||
# Get cache directory
|
||||
cache_dir = get_hf_cache_dir()
|
||||
print(f"Cache directory: {cache_dir}")
|
||||
|
||||
if not os.path.exists(cache_dir):
|
||||
print("Cache directory does not exist - nothing to clean")
|
||||
return 0
|
||||
|
||||
print("-" * 70)
|
||||
|
||||
# Find stale artifacts
|
||||
print("Scanning for stale artifacts...")
|
||||
stale_artifacts = find_stale_artifacts(cache_dir)
|
||||
|
||||
if not stale_artifacts:
|
||||
print("✓ No stale cache artifacts found")
|
||||
return 0
|
||||
|
||||
# Clean up artifacts
|
||||
print(f"Found {len(stale_artifacts)} stale artifact(s) to remove:")
|
||||
successful, failed = cleanup_artifacts(stale_artifacts)
|
||||
|
||||
print("-" * 70)
|
||||
|
||||
# Summary
|
||||
if failed > 0:
|
||||
print(f"⚠ Cleaned up {successful} file(s), {failed} removal(s) failed")
|
||||
else:
|
||||
print(f"✓ Successfully cleaned up {successful} stale file(s)")
|
||||
|
||||
# Always return 0 - cleanup failures should not fail CI
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
exit_code = main()
|
||||
sys.exit(exit_code)
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted by user")
|
||||
sys.exit(0)
|
||||
except Exception as e:
|
||||
print(f"ERROR: Unexpected error during cleanup: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
# Still return 0 - cleanup failures should not fail CI
|
||||
sys.exit(0)
|
||||
0
third_party/sglang/scripts/ci/utils/diffusion/__init__.py
vendored
Normal file
0
third_party/sglang/scripts/ci/utils/diffusion/__init__.py
vendored
Normal file
157
third_party/sglang/scripts/ci/utils/diffusion/comparison_configs.json
vendored
Normal file
157
third_party/sglang/scripts/ci/utils/diffusion/comparison_configs.json
vendored
Normal file
@@ -0,0 +1,157 @@
|
||||
{
|
||||
"_comment": "Per-model comparison config. Sampling params omitted where model defaults are correct — only override resolution, seed, and params that differ from defaults.",
|
||||
"test_image_url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png",
|
||||
"cases": [
|
||||
{
|
||||
"id": "flux1_dev_t2i_1024",
|
||||
"model": "black-forest-labs/FLUX.1-dev",
|
||||
"task": "text-to-image",
|
||||
"prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets",
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"seed": 42,
|
||||
"num_gpus": 1,
|
||||
"frameworks": {
|
||||
"sglang": {
|
||||
"serve_args": "--enable-torch-compile --warmup --dit-layerwise-offload false",
|
||||
"extra_env": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "flux2_dev_t2i_1024",
|
||||
"model": "black-forest-labs/FLUX.2-dev",
|
||||
"task": "text-to-image",
|
||||
"prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets",
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"seed": 42,
|
||||
"num_gpus": 1,
|
||||
"frameworks": {
|
||||
"sglang": {
|
||||
"serve_args": "--enable-torch-compile --warmup --dit-layerwise-offload false",
|
||||
"extra_env": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "qwen_image_2512_t2i_1024",
|
||||
"model": "Qwen/Qwen-Image-2512",
|
||||
"task": "text-to-image",
|
||||
"prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets",
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"seed": 42,
|
||||
"num_gpus": 1,
|
||||
"frameworks": {
|
||||
"sglang": {
|
||||
"serve_args": "--enable-torch-compile --warmup",
|
||||
"extra_env": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "qwen_image_edit_2511",
|
||||
"model": "Qwen/Qwen-Image-Edit-2511",
|
||||
"task": "image-edit",
|
||||
"prompt": "Make the cat wear a red hat",
|
||||
"reference_image": true,
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"seed": 42,
|
||||
"num_gpus": 1,
|
||||
"frameworks": {
|
||||
"sglang": {
|
||||
"serve_args": "--enable-torch-compile --warmup",
|
||||
"extra_env": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "zimage_turbo_t2i_1024",
|
||||
"model": "Tongyi-MAI/Z-Image-Turbo",
|
||||
"task": "text-to-image",
|
||||
"prompt": "A futuristic cyberpunk city at night, neon lights reflecting on wet streets",
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"seed": 42,
|
||||
"num_gpus": 1,
|
||||
"frameworks": {
|
||||
"sglang": {
|
||||
"serve_args": "--enable-torch-compile --warmup",
|
||||
"extra_env": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "wan22_t2v_a14b_720p",
|
||||
"model": "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
|
||||
"task": "text-to-video",
|
||||
"prompt": "A cat and a dog baking a cake together in a kitchen.",
|
||||
"width": 1280,
|
||||
"height": 720,
|
||||
"num_frames": 81,
|
||||
"seed": 42,
|
||||
"num_gpus": 4,
|
||||
"frameworks": {
|
||||
"sglang": {
|
||||
"serve_args": "--enable-torch-compile --warmup --enable-cfg-parallel --ulysses-degree 2 --text-encoder-cpu-offload --pin-cpu-memory",
|
||||
"extra_env": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "wan22_ti2v_5b_720p",
|
||||
"model": "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
|
||||
"task": "text-image-to-video",
|
||||
"prompt": "The cat starts walking slowly towards the camera.",
|
||||
"reference_image": true,
|
||||
"width": 1280,
|
||||
"height": 720,
|
||||
"num_frames": 81,
|
||||
"seed": 42,
|
||||
"num_gpus": 1,
|
||||
"frameworks": {
|
||||
"sglang": {
|
||||
"serve_args": "--enable-torch-compile --warmup",
|
||||
"extra_env": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "ltx2_twostage_t2v",
|
||||
"model": "Lightricks/LTX-2",
|
||||
"task": "text-to-video",
|
||||
"prompt": "A cat and a dog baking a cake together in a kitchen.",
|
||||
"width": 768,
|
||||
"height": 512,
|
||||
"num_frames": 121,
|
||||
"seed": 42,
|
||||
"num_gpus": 2,
|
||||
"frameworks": {
|
||||
"sglang": {
|
||||
"serve_args": "--enable-torch-compile --warmup --enable-cfg-parallel --pipeline-class-name LTX2TwoStagePipeline",
|
||||
"extra_env": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "wan22_i2v_a14b_720p",
|
||||
"model": "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
|
||||
"task": "image-to-video",
|
||||
"prompt": "The cat starts walking slowly towards the camera.",
|
||||
"reference_image": true,
|
||||
"width": 1280,
|
||||
"height": 720,
|
||||
"num_frames": 81,
|
||||
"seed": 42,
|
||||
"num_gpus": 4,
|
||||
"frameworks": {
|
||||
"sglang": {
|
||||
"serve_args": "--enable-torch-compile --warmup --enable-cfg-parallel --ulysses-degree 2 --text-encoder-cpu-offload --pin-cpu-memory",
|
||||
"extra_env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
836
third_party/sglang/scripts/ci/utils/diffusion/generate_diffusion_dashboard.py
vendored
Normal file
836
third_party/sglang/scripts/ci/utils/diffusion/generate_diffusion_dashboard.py
vendored
Normal file
@@ -0,0 +1,836 @@
|
||||
"""Generate a Markdown dashboard for diffusion cross-framework comparisons.
|
||||
|
||||
Reads current comparison results + historical data from sglang-ci-data repo
|
||||
and produces a Markdown report with tables and trend charts saved as PNG files.
|
||||
|
||||
Usage:
|
||||
python3 scripts/ci/utils/diffusion/generate_diffusion_dashboard.py \
|
||||
--results comparison-results.json \
|
||||
--output dashboard.md \
|
||||
--charts-dir comparison-charts/ \
|
||||
--history-dir history/ # optional, local history JSONs
|
||||
--fetch-history # fetch from GitHub API instead
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# History fetching (from sglang-ci-data repo via GitHub API)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CI_DATA_REPO_OWNER = "sglang-bot"
|
||||
CI_DATA_REPO_NAME = "sglang-ci-data"
|
||||
CI_DATA_BRANCH = "main"
|
||||
HISTORY_PREFIX = "diffusion-comparisons"
|
||||
MAX_HISTORY_RUNS = 14
|
||||
|
||||
# Base URL for chart images pushed to sglang-ci-data
|
||||
CHARTS_RAW_BASE_URL = (
|
||||
f"https://raw.githubusercontent.com/{CI_DATA_REPO_OWNER}/{CI_DATA_REPO_NAME}"
|
||||
f"/{CI_DATA_BRANCH}/{HISTORY_PREFIX}/charts"
|
||||
)
|
||||
|
||||
|
||||
def _github_get(url: str, token: str) -> dict | list | None:
|
||||
"""Simple GET to GitHub API."""
|
||||
from urllib.error import HTTPError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
headers = {
|
||||
"Accept": "application/vnd.github+json",
|
||||
"Authorization": f"Bearer {token}",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
}
|
||||
req = Request(url, headers=headers)
|
||||
try:
|
||||
with urlopen(req) as resp:
|
||||
return json.loads(resp.read().decode("utf-8"))
|
||||
except HTTPError as e:
|
||||
print(f" Warning: GitHub API request failed ({e.code}): {url}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f" Warning: GitHub API request error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def fetch_history_from_github(token: str) -> list[dict]:
|
||||
"""Fetch recent comparison result JSONs from sglang-ci-data repo."""
|
||||
print("Fetching historical comparison data from GitHub...")
|
||||
url = (
|
||||
f"https://api.github.com/repos/{CI_DATA_REPO_OWNER}/{CI_DATA_REPO_NAME}"
|
||||
f"/contents/{HISTORY_PREFIX}?ref={CI_DATA_BRANCH}"
|
||||
)
|
||||
listing = _github_get(url, token)
|
||||
if not listing or not isinstance(listing, list):
|
||||
print(" No historical data found.")
|
||||
return []
|
||||
|
||||
# Filter JSON files and sort by name (date prefix) descending
|
||||
json_files = sorted(
|
||||
[f for f in listing if f["name"].endswith(".json")],
|
||||
key=lambda f: f["name"],
|
||||
reverse=True,
|
||||
)[:MAX_HISTORY_RUNS]
|
||||
|
||||
history = []
|
||||
for entry in json_files:
|
||||
raw_url = entry.get("download_url")
|
||||
if not raw_url:
|
||||
continue
|
||||
data = _github_get(raw_url, token)
|
||||
if data and isinstance(data, dict):
|
||||
history.append(data)
|
||||
print(f" Loaded {len(history)} historical run(s).")
|
||||
return history
|
||||
|
||||
|
||||
def load_history_from_dir(history_dir: str) -> list[dict]:
|
||||
"""Load historical JSONs from a local directory."""
|
||||
if not os.path.isdir(history_dir):
|
||||
return []
|
||||
files = sorted(
|
||||
[f for f in os.listdir(history_dir) if f.endswith(".json")],
|
||||
reverse=True,
|
||||
)[:MAX_HISTORY_RUNS]
|
||||
history = []
|
||||
for fname in files:
|
||||
try:
|
||||
with open(os.path.join(history_dir, fname)) as f:
|
||||
history.append(json.load(f))
|
||||
except Exception:
|
||||
pass
|
||||
return history
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dashboard generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _fmt_latency(val: float | None) -> str:
|
||||
if val is None:
|
||||
return "N/A"
|
||||
return f"{val:.2f}"
|
||||
|
||||
|
||||
def _fmt_speedup(sglang_lat: float | None, other_lat: float | None) -> str:
|
||||
if sglang_lat is None or other_lat is None or sglang_lat <= 0:
|
||||
return "N/A"
|
||||
ratio = other_lat / sglang_lat
|
||||
return f"{ratio:.2f}x"
|
||||
|
||||
|
||||
def _short_date(ts: str) -> str:
|
||||
"""Extract short date from ISO timestamp."""
|
||||
try:
|
||||
dt = datetime.fromisoformat(ts.replace("Z", "+00:00"))
|
||||
return dt.strftime("%b %d")
|
||||
except Exception:
|
||||
return ts[:10]
|
||||
|
||||
|
||||
def _short_sha(sha: str) -> str:
|
||||
return sha[:7] if sha and sha != "unknown" else "?"
|
||||
|
||||
|
||||
def _assess_risk(
|
||||
cid: str,
|
||||
current_cases: dict[str, dict[str, float | None]],
|
||||
history: list[dict],
|
||||
other_frameworks: list[str],
|
||||
) -> tuple[str, str]:
|
||||
"""Assess risk for a given case, returning (emoji, reason).
|
||||
|
||||
Rules (checked in order):
|
||||
- N/A latency → ❌ broken
|
||||
- History exists: SGLang latency >5% vs avg of last 3 runs → ⚠️ regression
|
||||
- Competitor exists & SGLang slower → 🔴 competitive risk
|
||||
- SGLang faster than all competitors by >20% → 🟢 strong advantage
|
||||
- SGLang faster than all competitors by ≤20% → 🟡 moderate advantage
|
||||
- Default → ✅ stable
|
||||
"""
|
||||
sg_lat = current_cases.get(cid, {}).get("sglang")
|
||||
|
||||
# Broken: sglang latency is N/A
|
||||
if sg_lat is None:
|
||||
return "❌", f"{cid}: SGLang latency is N/A (broken)"
|
||||
|
||||
# Check regression against 3-run historical average
|
||||
if history:
|
||||
hist_lats: list[float] = []
|
||||
for run in history[:3]:
|
||||
run_cases = _extract_case_results(run)
|
||||
h_lat = run_cases.get(cid, {}).get("sglang")
|
||||
if h_lat is not None:
|
||||
hist_lats.append(h_lat)
|
||||
if hist_lats:
|
||||
avg_3 = sum(hist_lats) / len(hist_lats)
|
||||
if avg_3 > 0 and (sg_lat - avg_3) / avg_3 > 0.05:
|
||||
pct = (sg_lat - avg_3) / avg_3 * 100
|
||||
return (
|
||||
"⚠️",
|
||||
f"{cid}: SGLang regression +{pct:.1f}% vs 3-run avg "
|
||||
f"({sg_lat:.2f}s vs {avg_3:.2f}s)",
|
||||
)
|
||||
|
||||
# Check competitive risk
|
||||
if other_frameworks:
|
||||
competitor_lats: dict[str, float] = {}
|
||||
for ofw in other_frameworks:
|
||||
olat = current_cases.get(cid, {}).get(ofw)
|
||||
if olat is not None:
|
||||
competitor_lats[ofw] = olat
|
||||
|
||||
if competitor_lats:
|
||||
# SGLang slower than any competitor?
|
||||
for ofw, olat in competitor_lats.items():
|
||||
if sg_lat > olat:
|
||||
return (
|
||||
"🔴",
|
||||
f"{cid}: SGLang slower than {ofw} "
|
||||
f"({sg_lat:.2f}s vs {olat:.2f}s)",
|
||||
)
|
||||
|
||||
# SGLang faster — check margin
|
||||
min_competitor = min(competitor_lats.values())
|
||||
advantage = (min_competitor - sg_lat) / min_competitor
|
||||
if advantage > 0.20:
|
||||
return "🟢", ""
|
||||
else:
|
||||
return "🟡", ""
|
||||
|
||||
# Default: stable
|
||||
return "✅", ""
|
||||
|
||||
|
||||
def _trend_emoji(current: float | None, previous: float | None) -> str:
|
||||
if current is None or previous is None:
|
||||
return ""
|
||||
diff_pct = (current - previous) / previous * 100
|
||||
if diff_pct < -2:
|
||||
return " :arrow_down:" # faster (good)
|
||||
elif diff_pct > 2:
|
||||
return " :arrow_up:" # slower (bad)
|
||||
return " :left_right_arrow:"
|
||||
|
||||
|
||||
def _extract_case_results(run_data: dict) -> dict[str, dict[str, float | None]]:
|
||||
"""Extract {case_id: {framework: latency}} from a run."""
|
||||
mapping: dict[str, dict[str, float | None]] = {}
|
||||
for r in run_data.get("results", []):
|
||||
cid = r["case_id"]
|
||||
fw = r["framework"]
|
||||
if cid not in mapping:
|
||||
mapping[cid] = {}
|
||||
mapping[cid][fw] = r.get("latency_s")
|
||||
return mapping
|
||||
|
||||
|
||||
def _sanitize_filename(name: str) -> str:
|
||||
"""Sanitize a case ID to be a safe filename."""
|
||||
return name.replace("/", "_").replace(" ", "_").replace(":", "_")
|
||||
|
||||
|
||||
def generate_dashboard(
|
||||
current: dict,
|
||||
history: list[dict],
|
||||
charts_dir: str | None = None,
|
||||
) -> tuple[str, list[str]]:
|
||||
"""Generate full markdown dashboard.
|
||||
|
||||
Returns (markdown_string, alert_reasons) where alert_reasons is a list of
|
||||
human-readable strings for cases that need attention (empty if all is well).
|
||||
|
||||
If charts_dir is provided, saves chart PNGs as files to that directory
|
||||
and references them via raw.githubusercontent URLs. Otherwise, charts
|
||||
are omitted.
|
||||
|
||||
Returns the markdown string.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
lines.append("# Diffusion Cross-Framework Performance Dashboard\n")
|
||||
ts = current.get("timestamp", datetime.now(timezone.utc).isoformat())
|
||||
sha = current.get("commit_sha", "unknown")
|
||||
lines.append(f"*Generated: {_short_date(ts)} | Commit: `{_short_sha(sha)}`*\n")
|
||||
|
||||
current_cases = _extract_case_results(current)
|
||||
case_ids = list(current_cases.keys())
|
||||
|
||||
# ---- Regression detection ----
|
||||
REGRESSION_THRESHOLD = 0.05 # 5%
|
||||
regressions: list[str] = []
|
||||
if history:
|
||||
prev_cases = _extract_case_results(history[0])
|
||||
for cid in case_ids:
|
||||
for fw in ("sglang", "vllm-omni"):
|
||||
cur = current_cases.get(cid, {}).get(fw)
|
||||
prev = prev_cases.get(cid, {}).get(fw)
|
||||
if cur and prev and prev > 0:
|
||||
pct = (cur - prev) / prev
|
||||
if pct > REGRESSION_THRESHOLD:
|
||||
regressions.append(
|
||||
f"**{cid}** ({fw}): {prev:.2f}s -> {cur:.2f}s "
|
||||
f"(+{pct*100:.1f}%)"
|
||||
)
|
||||
|
||||
if regressions:
|
||||
lines.append("> [!WARNING]\n> **Performance Regression Detected**\n>")
|
||||
for reg in regressions:
|
||||
lines.append(f"> - {reg}")
|
||||
lines.append("\n")
|
||||
|
||||
# Discover all frameworks present in results
|
||||
all_frameworks = []
|
||||
seen_fw = set()
|
||||
for r in current.get("results", []):
|
||||
fw = r["framework"]
|
||||
if fw not in seen_fw:
|
||||
all_frameworks.append(fw)
|
||||
seen_fw.add(fw)
|
||||
# Ensure sglang is first
|
||||
if "sglang" in all_frameworks:
|
||||
all_frameworks.remove("sglang")
|
||||
all_frameworks.insert(0, "sglang")
|
||||
other_frameworks = [fw for fw in all_frameworks if fw != "sglang"]
|
||||
|
||||
# ---- Section 1: Cross-Framework Comparison (current run) ----
|
||||
lines.append("## Cross-Framework Performance Comparison\n")
|
||||
|
||||
# Compute risk assessments for all cases
|
||||
risk_map: dict[str, tuple[str, str]] = {}
|
||||
for cid in case_ids:
|
||||
risk_map[cid] = _assess_risk(cid, current_cases, history, other_frameworks)
|
||||
|
||||
# Dynamic header
|
||||
header = "| Model | Risk |"
|
||||
sep = "|-------|------|"
|
||||
for fw in all_frameworks:
|
||||
header += f" {fw} (s) |"
|
||||
sep += "---------|"
|
||||
for ofw in other_frameworks:
|
||||
header += f" vs {ofw} |"
|
||||
sep += "---------|"
|
||||
lines.append(header)
|
||||
lines.append(sep)
|
||||
|
||||
# One row per case (deduplicated by case_id)
|
||||
seen_cases = set()
|
||||
for r in current.get("results", []):
|
||||
cid = r["case_id"]
|
||||
if cid in seen_cases:
|
||||
continue
|
||||
seen_cases.add(cid)
|
||||
|
||||
case_fws = current_cases.get(cid, {})
|
||||
sg_lat = case_fws.get("sglang")
|
||||
|
||||
risk_emoji, _ = risk_map.get(cid, ("✅", ""))
|
||||
row = f"| {r['model'].split('/')[-1]} | {risk_emoji} |"
|
||||
# Latency columns -- bold the fastest
|
||||
lats = {fw: case_fws.get(fw) for fw in all_frameworks}
|
||||
valid_lats = [v for v in lats.values() if v is not None]
|
||||
min_lat = min(valid_lats) if valid_lats else None
|
||||
for fw in all_frameworks:
|
||||
lat = lats[fw]
|
||||
if lat is not None and min_lat is not None and lat == min_lat:
|
||||
row += f" **{_fmt_latency(lat)}** |"
|
||||
else:
|
||||
row += f" {_fmt_latency(lat)} |"
|
||||
# Speedup columns
|
||||
for ofw in other_frameworks:
|
||||
row += f" {_fmt_speedup(sg_lat, case_fws.get(ofw))} |"
|
||||
lines.append(row)
|
||||
|
||||
# ---- Section 2: Cross-Framework Speedup Trend (only if multiple frameworks) ----
|
||||
if history and other_frameworks:
|
||||
lines.append("\n## SGLang vs vLLM-Omni Speedup Over Time\n")
|
||||
|
||||
header = "| Date |"
|
||||
sep = "|------|"
|
||||
for cid in case_ids:
|
||||
header += f" {cid} |"
|
||||
sep += "---------|"
|
||||
lines.append(header)
|
||||
lines.append(sep)
|
||||
|
||||
all_runs = [current] + history
|
||||
for run in all_runs:
|
||||
run_cases = _extract_case_results(run)
|
||||
date = _short_date(run.get("timestamp", ""))
|
||||
row = f"| {date} |"
|
||||
for cid in case_ids:
|
||||
sg = run_cases.get(cid, {}).get("sglang")
|
||||
vl = run_cases.get(cid, {}).get("vllm-omni")
|
||||
row += f" {_fmt_speedup(sg, vl)} |"
|
||||
lines.append(row)
|
||||
|
||||
# ---- Section 4: Matplotlib Trend Charts (saved as PNG files) ----
|
||||
if history and charts_dir:
|
||||
all_runs = list(reversed([current] + history)) # chronological order
|
||||
|
||||
def _chart_label(run: dict) -> str:
|
||||
d = _short_date(run.get("timestamp", ""))
|
||||
s = _short_sha(run.get("commit_sha", ""))
|
||||
return f"{d}\n({s})"
|
||||
|
||||
try:
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
os.makedirs(charts_dir, exist_ok=True)
|
||||
|
||||
# Per-case latency trend charts
|
||||
for cid in case_ids:
|
||||
labels = []
|
||||
sg_vals = []
|
||||
vl_vals = []
|
||||
for run in all_runs:
|
||||
run_cases = _extract_case_results(run)
|
||||
sg = run_cases.get(cid, {}).get("sglang")
|
||||
vl = run_cases.get(cid, {}).get("vllm-omni")
|
||||
if sg is None:
|
||||
continue
|
||||
labels.append(_chart_label(run))
|
||||
sg_vals.append(sg)
|
||||
vl_vals.append(vl)
|
||||
|
||||
if not sg_vals:
|
||||
continue
|
||||
|
||||
has_vl = any(v is not None for v in vl_vals)
|
||||
fig, ax = plt.subplots(figsize=(max(6, len(labels) * 1.2), 4))
|
||||
|
||||
# SGLang line
|
||||
ax.plot(
|
||||
range(len(sg_vals)),
|
||||
sg_vals,
|
||||
"o-",
|
||||
color="#2563eb",
|
||||
linewidth=2,
|
||||
markersize=6,
|
||||
label="SGLang",
|
||||
)
|
||||
for i, v in enumerate(sg_vals):
|
||||
ax.annotate(
|
||||
f"{v:.2f}s",
|
||||
(i, v),
|
||||
textcoords="offset points",
|
||||
xytext=(0, 10),
|
||||
ha="center",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
color="#2563eb",
|
||||
)
|
||||
|
||||
# vLLM-Omni line (if data exists)
|
||||
if has_vl:
|
||||
vl_clean = [v if v is not None else float("nan") for v in vl_vals]
|
||||
ax.plot(
|
||||
range(len(vl_clean)),
|
||||
vl_clean,
|
||||
"s--",
|
||||
color="#dc2626",
|
||||
linewidth=2,
|
||||
markersize=5,
|
||||
label="vLLM-Omni",
|
||||
)
|
||||
for i, v in enumerate(vl_vals):
|
||||
if v is not None:
|
||||
ax.annotate(
|
||||
f"{v:.2f}s",
|
||||
(i, v),
|
||||
textcoords="offset points",
|
||||
xytext=(0, -14),
|
||||
ha="center",
|
||||
fontsize=8,
|
||||
color="#dc2626",
|
||||
)
|
||||
|
||||
ax.set_xticks(range(len(labels)))
|
||||
ax.set_xticklabels(labels, fontsize=7)
|
||||
ax.set_ylabel("Latency (s)")
|
||||
ax.set_title(f"Latency Trend -- {cid}", fontsize=11, fontweight="bold")
|
||||
ax.legend(loc="lower right", fontsize=8, framealpha=0.8)
|
||||
ax.grid(True, alpha=0.3)
|
||||
all_vals = sg_vals + [v for v in vl_vals if v is not None]
|
||||
y_min = min(all_vals)
|
||||
y_max = max(all_vals)
|
||||
y_range = y_max - y_min if y_max > y_min else max(y_max * 0.1, 0.1)
|
||||
ax.set_ylim(
|
||||
bottom=max(0, y_min - y_range * 0.3),
|
||||
top=y_max + y_range * 0.3,
|
||||
)
|
||||
|
||||
filename = f"latency_{_sanitize_filename(cid)}.png"
|
||||
chart_path = os.path.join(charts_dir, filename)
|
||||
fig.savefig(chart_path, format="png", dpi=120, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
print(f" Saved chart: {chart_path}")
|
||||
|
||||
chart_url = f"{CHARTS_RAW_BASE_URL}/{filename}"
|
||||
lines.append(f"\n### Latency Trend: {cid}\n")
|
||||
lines.append(f"\n")
|
||||
|
||||
# Speedup trend chart (only if multiple frameworks)
|
||||
if other_frameworks:
|
||||
fig, ax = plt.subplots(figsize=(max(6, len(all_runs) * 1.2), 4))
|
||||
colors = ["#2563eb", "#dc2626", "#16a34a", "#ea580c"]
|
||||
for ci_idx, cid in enumerate(case_ids):
|
||||
speedups = []
|
||||
run_labels = []
|
||||
for run in all_runs:
|
||||
run_cases = _extract_case_results(run)
|
||||
sg = run_cases.get(cid, {}).get("sglang")
|
||||
vl = run_cases.get(cid, {}).get("vllm-omni")
|
||||
if sg and vl and sg > 0:
|
||||
speedups.append(vl / sg)
|
||||
else:
|
||||
speedups.append(None)
|
||||
run_labels.append(_chart_label(run))
|
||||
clean = [v if v is not None else float("nan") for v in speedups]
|
||||
ax.plot(
|
||||
range(len(clean)),
|
||||
clean,
|
||||
"o-",
|
||||
color=colors[ci_idx % len(colors)],
|
||||
linewidth=2,
|
||||
markersize=5,
|
||||
label=cid,
|
||||
)
|
||||
|
||||
ax.set_xticks(range(len(run_labels)))
|
||||
ax.set_xticklabels(run_labels, fontsize=7)
|
||||
ax.set_ylabel("Speedup (x)")
|
||||
ax.set_title(
|
||||
"SGLang Speedup Over vLLM-Omni", fontsize=11, fontweight="bold"
|
||||
)
|
||||
ax.axhline(y=1.0, color="gray", linestyle=":", alpha=0.5)
|
||||
ax.legend(loc="upper left", fontsize=7)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
filename = "speedup_trend.png"
|
||||
chart_path = os.path.join(charts_dir, filename)
|
||||
fig.savefig(chart_path, format="png", dpi=120, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
print(f" Saved chart: {chart_path}")
|
||||
|
||||
chart_url = f"{CHARTS_RAW_BASE_URL}/{filename}"
|
||||
lines.append("\n### Speedup Trend (SGLang vs vLLM-Omni)\n")
|
||||
lines.append(f"\n")
|
||||
|
||||
except ImportError:
|
||||
lines.append("\n*Charts unavailable (matplotlib not installed)*\n")
|
||||
|
||||
# ---- SGLang Performance Trend (raw data table, at the end) ----
|
||||
if history:
|
||||
lines.append(f"\n## SGLang Performance Trend (Last {len(history) + 1} Runs)\n")
|
||||
|
||||
header = "| Date | Commit |"
|
||||
sep = "|------|--------|"
|
||||
for cid in case_ids:
|
||||
header += f" {cid} (s) |"
|
||||
sep += "---------|"
|
||||
header += " Trend |"
|
||||
sep += "-------|"
|
||||
lines.append(header)
|
||||
lines.append(sep)
|
||||
|
||||
all_runs = [current] + history
|
||||
for i, run in enumerate(all_runs):
|
||||
run_cases = _extract_case_results(run)
|
||||
date = _short_date(run.get("timestamp", ""))
|
||||
sha_s = _short_sha(run.get("commit_sha", ""))
|
||||
row = f"| {date} | `{sha_s}` |"
|
||||
for cid in case_ids:
|
||||
lat = run_cases.get(cid, {}).get("sglang")
|
||||
row += f" {_fmt_latency(lat)} |"
|
||||
if i + 1 < len(all_runs):
|
||||
prev_cases = _extract_case_results(all_runs[i + 1])
|
||||
emojis = []
|
||||
for cid in case_ids:
|
||||
cur = run_cases.get(cid, {}).get("sglang")
|
||||
prev = prev_cases.get(cid, {}).get("sglang")
|
||||
emojis.append(_trend_emoji(cur, prev))
|
||||
row += " ".join(emojis) + " |"
|
||||
else:
|
||||
row += " -- |"
|
||||
lines.append(row)
|
||||
|
||||
# ---- Risk Notification ----
|
||||
alert_cases = [
|
||||
(cid, emoji, reason)
|
||||
for cid, (emoji, reason) in risk_map.items()
|
||||
if emoji in ("⚠️", "🔴", "❌")
|
||||
]
|
||||
if alert_cases:
|
||||
lines.append("\n> [!CAUTION]")
|
||||
lines.append("> **Action Required — Performance Alert**")
|
||||
lines.append(">")
|
||||
lines.append("> The following cases need attention:")
|
||||
for _cid, _emoji, reason in alert_cases:
|
||||
lines.append(f"> - {reason}")
|
||||
lines.append("")
|
||||
|
||||
# Footer
|
||||
lines.append("\n---")
|
||||
lines.append(
|
||||
"*Generated by `generate_diffusion_dashboard.py` in SGLang nightly CI.*"
|
||||
)
|
||||
|
||||
alert_reasons = [reason for _, _, reason in alert_cases]
|
||||
return "\n".join(lines) + "\n", alert_reasons
|
||||
|
||||
|
||||
ALERT_ASSIGNEES = ["mickqian", "bbuf", "yhyang201"]
|
||||
ALERT_LABEL = "perf-regression"
|
||||
|
||||
|
||||
ALERT_ISSUE_TITLE = "[Diffusion CI] Performance regression tracker"
|
||||
|
||||
|
||||
def _find_alert_issue(repo: str) -> tuple[str | None, bool]:
|
||||
"""Find the perf-regression tracker issue (open OR closed).
|
||||
|
||||
Returns (issue_number, is_open). Prefers an open issue; if none,
|
||||
returns the most recent closed one so it can be reopened.
|
||||
"""
|
||||
import subprocess
|
||||
|
||||
for state in ("open", "closed"):
|
||||
result = subprocess.run(
|
||||
[
|
||||
"gh",
|
||||
"issue",
|
||||
"list",
|
||||
"--repo",
|
||||
repo,
|
||||
"--label",
|
||||
ALERT_LABEL,
|
||||
"--state",
|
||||
state,
|
||||
"--json",
|
||||
"number",
|
||||
"--limit",
|
||||
"1",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
if result.returncode != 0 or not result.stdout.strip():
|
||||
continue
|
||||
issues = json.loads(result.stdout)
|
||||
if issues:
|
||||
return str(issues[0]["number"]), state == "open"
|
||||
return None, False
|
||||
|
||||
|
||||
def _create_alert_issue(alert_reasons: list[str]) -> None:
|
||||
"""Create or update the single perf-regression tracker issue.
|
||||
|
||||
Logic:
|
||||
- If an open issue exists → add a comment with the new alert.
|
||||
- If a closed issue exists → reopen it, then add a comment.
|
||||
- If no issue exists → create one.
|
||||
|
||||
This guarantees at most one tracker issue ever exists.
|
||||
|
||||
Uses `gh` (GitHub CLI) which is available in all GitHub Actions runners.
|
||||
Falls back silently outside CI.
|
||||
"""
|
||||
import subprocess
|
||||
|
||||
run_url = ""
|
||||
run_id = os.environ.get("GITHUB_RUN_ID", "")
|
||||
repo = os.environ.get("GITHUB_REPOSITORY", "sgl-project/sglang")
|
||||
server_url = os.environ.get("GITHUB_SERVER_URL", "https://github.com")
|
||||
if run_id:
|
||||
run_url = f"{server_url}/{repo}/actions/runs/{run_id}"
|
||||
|
||||
date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
|
||||
body_lines = [
|
||||
f"## Performance Alert — {date}",
|
||||
"",
|
||||
"The nightly diffusion benchmark detected the following issue(s):",
|
||||
"",
|
||||
]
|
||||
for reason in alert_reasons:
|
||||
body_lines.append(f"- {reason}")
|
||||
if run_url:
|
||||
body_lines += ["", f"**CI Run:** {run_url}"]
|
||||
body = "\n".join(body_lines)
|
||||
|
||||
try:
|
||||
existing, is_open = _find_alert_issue(repo)
|
||||
|
||||
if existing:
|
||||
# Reopen if closed
|
||||
if not is_open:
|
||||
subprocess.run(
|
||||
[
|
||||
"gh",
|
||||
"issue",
|
||||
"reopen",
|
||||
existing,
|
||||
"--repo",
|
||||
repo,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
print(f"Reopened alert issue #{existing}")
|
||||
|
||||
# Add comment
|
||||
result = subprocess.run(
|
||||
[
|
||||
"gh",
|
||||
"issue",
|
||||
"comment",
|
||||
existing,
|
||||
"--repo",
|
||||
repo,
|
||||
"--body",
|
||||
body,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
print(f"Commented on alert issue #{existing}")
|
||||
else:
|
||||
print(
|
||||
f"Warning: failed to comment on issue #{existing} "
|
||||
f"(rc={result.returncode}): {result.stderr.strip()}"
|
||||
)
|
||||
else:
|
||||
# Create a new issue
|
||||
cmd = [
|
||||
"gh",
|
||||
"issue",
|
||||
"create",
|
||||
"--repo",
|
||||
repo,
|
||||
"--title",
|
||||
ALERT_ISSUE_TITLE,
|
||||
"--body",
|
||||
body,
|
||||
"--label",
|
||||
ALERT_LABEL,
|
||||
]
|
||||
for user in ALERT_ASSIGNEES:
|
||||
cmd += ["--assignee", user]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
|
||||
if result.returncode == 0:
|
||||
print(f"Created alert issue: {result.stdout.strip()}")
|
||||
else:
|
||||
print(
|
||||
f"Warning: failed to create alert issue "
|
||||
f"(rc={result.returncode}): {result.stderr.strip()}"
|
||||
)
|
||||
except FileNotFoundError:
|
||||
print("Warning: `gh` CLI not found — skipping alert issue creation")
|
||||
except Exception as e:
|
||||
print(f"Warning: failed to create/update alert issue: {e}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate diffusion cross-framework comparison dashboard"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--results",
|
||||
required=True,
|
||||
help="Path to comparison-results.json from current run",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default="dashboard.md",
|
||||
help="Output markdown file path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--charts-dir",
|
||||
default="comparison-charts",
|
||||
help="Directory to save chart PNG files (default: comparison-charts/)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--history-dir",
|
||||
default=None,
|
||||
help="Local directory containing historical comparison JSONs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fetch-history",
|
||||
action="store_true",
|
||||
help="Fetch history from sglang-ci-data GitHub repo",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--step-summary",
|
||||
action="store_true",
|
||||
help="Also write to $GITHUB_STEP_SUMMARY",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load current results
|
||||
with open(args.results) as f:
|
||||
current = json.load(f)
|
||||
print(f"Loaded current results: {len(current.get('results', []))} entries")
|
||||
|
||||
# Load history
|
||||
history: list[dict] = []
|
||||
if args.fetch_history:
|
||||
token = os.environ.get("GH_PAT_FOR_NIGHTLY_CI_DATA") or os.environ.get(
|
||||
"GITHUB_TOKEN"
|
||||
)
|
||||
if token:
|
||||
history = fetch_history_from_github(token)
|
||||
else:
|
||||
print("Warning: No GitHub token available, skipping history fetch")
|
||||
elif args.history_dir:
|
||||
history = load_history_from_dir(args.history_dir)
|
||||
print(f"Loaded {len(history)} historical run(s) from {args.history_dir}")
|
||||
|
||||
# Generate dashboard
|
||||
markdown, alert_reasons = generate_dashboard(
|
||||
current, history, charts_dir=args.charts_dir
|
||||
)
|
||||
|
||||
# Write output
|
||||
os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
|
||||
with open(args.output, "w") as f:
|
||||
f.write(markdown)
|
||||
print(f"Dashboard written to {args.output}")
|
||||
|
||||
# Write to GitHub Step Summary
|
||||
if args.step_summary:
|
||||
summary_file = os.environ.get("GITHUB_STEP_SUMMARY")
|
||||
if summary_file:
|
||||
with open(summary_file, "a") as f:
|
||||
f.write(markdown)
|
||||
print("Dashboard appended to $GITHUB_STEP_SUMMARY")
|
||||
else:
|
||||
print("Warning: $GITHUB_STEP_SUMMARY not set, skipping")
|
||||
|
||||
# Create GitHub Issue for performance alerts (so assignees get notified)
|
||||
if alert_reasons:
|
||||
_create_alert_issue(alert_reasons)
|
||||
else:
|
||||
print("No performance alerts — skipping issue creation.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
231
third_party/sglang/scripts/ci/utils/diffusion/publish_comparison_results.py
vendored
Normal file
231
third_party/sglang/scripts/ci/utils/diffusion/publish_comparison_results.py
vendored
Normal file
@@ -0,0 +1,231 @@
|
||||
"""Publish diffusion comparison results to sglang-bot/sglang-ci-data repo.
|
||||
|
||||
Pushes comparison-results.json, dashboard.md, and chart PNG files to the
|
||||
ci-data repository for historical tracking. Chart PNGs are stored under
|
||||
diffusion-comparisons/charts/ so they can be referenced via
|
||||
raw.githubusercontent URLs in the dashboard markdown (GitHub Step Summary
|
||||
blocks data: URIs).
|
||||
|
||||
Usage:
|
||||
python3 scripts/ci/utils/diffusion/publish_comparison_results.py \
|
||||
--results comparison-results.json \
|
||||
--dashboard dashboard.md \
|
||||
--charts-dir comparison-charts/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
# Reuse GitHub API helpers from publish_traces.
|
||||
# Support both direct script execution and package-style imports.
|
||||
if __package__:
|
||||
from ..publish_traces import (
|
||||
create_blobs,
|
||||
create_commit,
|
||||
create_tree,
|
||||
get_branch_sha,
|
||||
get_tree_sha,
|
||||
is_permission_error,
|
||||
is_rate_limit_error,
|
||||
update_branch_ref,
|
||||
verify_token_permissions,
|
||||
)
|
||||
else:
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||
from publish_traces import (
|
||||
create_blobs,
|
||||
create_commit,
|
||||
create_tree,
|
||||
get_branch_sha,
|
||||
get_tree_sha,
|
||||
is_permission_error,
|
||||
is_rate_limit_error,
|
||||
update_branch_ref,
|
||||
verify_token_permissions,
|
||||
)
|
||||
|
||||
# Repository configuration
|
||||
REPO_OWNER = "sglang-bot"
|
||||
REPO_NAME = "sglang-ci-data"
|
||||
BRANCH = "main"
|
||||
STORAGE_PREFIX = "diffusion-comparisons"
|
||||
|
||||
|
||||
def _collect_chart_files(charts_dir: str) -> list[tuple[str, bytes]]:
|
||||
"""Collect PNG chart files from directory for upload."""
|
||||
files: list[tuple[str, bytes]] = []
|
||||
if not charts_dir or not os.path.isdir(charts_dir):
|
||||
return files
|
||||
|
||||
for entry in sorted(os.listdir(charts_dir)):
|
||||
if not entry.lower().endswith(".png"):
|
||||
continue
|
||||
full_path = os.path.join(charts_dir, entry)
|
||||
if not os.path.isfile(full_path):
|
||||
continue
|
||||
with open(full_path, "rb") as f:
|
||||
content = f.read()
|
||||
# Store charts under diffusion-comparisons/charts/
|
||||
repo_path = f"{STORAGE_PREFIX}/charts/{entry}"
|
||||
files.append((repo_path, content))
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def publish_comparison(
|
||||
results_path: str,
|
||||
dashboard_path: str | None = None,
|
||||
charts_dir: str | None = None,
|
||||
) -> None:
|
||||
"""Publish comparison results, dashboard, and charts to ci-data repo."""
|
||||
token = os.environ.get("GH_PAT_FOR_NIGHTLY_CI_DATA") or os.environ.get(
|
||||
"GITHUB_TOKEN"
|
||||
)
|
||||
if not token:
|
||||
print("Error: GH_PAT_FOR_NIGHTLY_CI_DATA or GITHUB_TOKEN not set")
|
||||
sys.exit(1)
|
||||
|
||||
run_id = os.environ.get("GITHUB_RUN_ID", "local")
|
||||
run_number = os.environ.get("GITHUB_RUN_NUMBER", "0")
|
||||
|
||||
# Verify permissions
|
||||
perm = verify_token_permissions(REPO_OWNER, REPO_NAME, token)
|
||||
if perm == "rate_limited":
|
||||
print("Warning: Rate limited, skipping publish")
|
||||
return
|
||||
elif not perm:
|
||||
print("Error: Token permission verification failed")
|
||||
sys.exit(1)
|
||||
|
||||
# Prepare files to upload
|
||||
files_to_upload: list[tuple[str, bytes]] = []
|
||||
|
||||
# Results JSON: stored with date prefix for chronological ordering
|
||||
date_prefix = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
results_target = f"{STORAGE_PREFIX}/{date_prefix}_{run_id}.json"
|
||||
with open(results_path, "rb") as f:
|
||||
files_to_upload.append((results_target, f.read()))
|
||||
|
||||
# Dashboard markdown: always overwrite latest
|
||||
if dashboard_path and os.path.exists(dashboard_path):
|
||||
dashboard_target = f"{STORAGE_PREFIX}/dashboard.md"
|
||||
with open(dashboard_path, "rb") as f:
|
||||
files_to_upload.append((dashboard_target, f.read()))
|
||||
|
||||
# Chart PNG files
|
||||
chart_files = _collect_chart_files(charts_dir)
|
||||
if chart_files:
|
||||
print(f"Found {len(chart_files)} chart PNG(s) to upload")
|
||||
files_to_upload.extend(chart_files)
|
||||
|
||||
print(f"Publishing {len(files_to_upload)} file(s) to {REPO_OWNER}/{REPO_NAME}")
|
||||
|
||||
# Create blobs
|
||||
try:
|
||||
tree_items = create_blobs(REPO_OWNER, REPO_NAME, files_to_upload, token)
|
||||
except Exception as e:
|
||||
if is_rate_limit_error(e):
|
||||
print("Warning: Rate limited during blob creation, skipping")
|
||||
return
|
||||
if is_permission_error(e):
|
||||
print(f"Error: No write permission to {REPO_OWNER}/{REPO_NAME}")
|
||||
sys.exit(1)
|
||||
raise
|
||||
|
||||
# Commit with retry (handle concurrent writes)
|
||||
max_retries = 5
|
||||
retry_delay = 5
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
branch_sha = get_branch_sha(REPO_OWNER, REPO_NAME, BRANCH, token)
|
||||
tree_sha = get_tree_sha(REPO_OWNER, REPO_NAME, branch_sha, token)
|
||||
|
||||
new_tree_sha = create_tree(
|
||||
REPO_OWNER, REPO_NAME, tree_sha, tree_items, token
|
||||
)
|
||||
|
||||
commit_msg = (
|
||||
f"Diffusion comparison results for run {run_id} (#{run_number})"
|
||||
)
|
||||
commit_sha = create_commit(
|
||||
REPO_OWNER, REPO_NAME, new_tree_sha, branch_sha, commit_msg, token
|
||||
)
|
||||
|
||||
update_branch_ref(REPO_OWNER, REPO_NAME, BRANCH, commit_sha, token)
|
||||
print(
|
||||
f"Successfully published comparison results (commit {commit_sha[:7]})"
|
||||
)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
is_retryable = False
|
||||
if hasattr(e, "error_body"):
|
||||
body = getattr(e, "error_body", "")
|
||||
if "Update is not a fast forward" in body:
|
||||
is_retryable = True
|
||||
elif "Object does not exist" in body:
|
||||
is_retryable = True
|
||||
|
||||
from urllib.error import HTTPError
|
||||
|
||||
if isinstance(e, HTTPError) and e.code in [422, 500, 502, 503, 504]:
|
||||
is_retryable = True
|
||||
|
||||
if is_rate_limit_error(e):
|
||||
print("Warning: Rate limited, skipping publish")
|
||||
return
|
||||
|
||||
if is_permission_error(e):
|
||||
print(f"Error: No write permission to {REPO_OWNER}/{REPO_NAME}")
|
||||
sys.exit(1)
|
||||
|
||||
if is_retryable and attempt < max_retries - 1:
|
||||
print(
|
||||
f"Attempt {attempt + 1}/{max_retries} failed, retrying in {retry_delay}s..."
|
||||
)
|
||||
time.sleep(retry_delay)
|
||||
else:
|
||||
print(f"Failed to publish after {attempt + 1} attempts: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Publish diffusion comparison results to sglang-ci-data"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--results",
|
||||
required=True,
|
||||
help="Path to comparison-results.json",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dashboard",
|
||||
default=None,
|
||||
help="Path to dashboard.md (optional)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--charts-dir",
|
||||
default=None,
|
||||
help="Directory containing chart PNG files to upload (optional)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.results):
|
||||
print(f"Error: Results file not found: {args.results}")
|
||||
sys.exit(1)
|
||||
|
||||
publish_comparison(
|
||||
results_path=args.results,
|
||||
dashboard_path=args.dashboard,
|
||||
charts_dir=args.charts_dir,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
166
third_party/sglang/scripts/ci/utils/diffusion/publish_diffusion_gt.py
vendored
Normal file
166
third_party/sglang/scripts/ci/utils/diffusion/publish_diffusion_gt.py
vendored
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Publish diffusion CI ground-truth images to sglang-bot/sglang-ci-data
|
||||
via the GitHub API (same pattern as publish_traces.py).
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Reuse GitHub API helpers from publish_traces.
|
||||
# Support both direct script execution and package-style imports.
|
||||
if __package__:
|
||||
from ..publish_traces import (
|
||||
create_blobs,
|
||||
create_commit,
|
||||
create_tree,
|
||||
get_branch_sha,
|
||||
get_tree_sha,
|
||||
is_permission_error,
|
||||
is_rate_limit_error,
|
||||
update_branch_ref,
|
||||
verify_token_permissions,
|
||||
)
|
||||
else:
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||
from publish_traces import (
|
||||
create_blobs,
|
||||
create_commit,
|
||||
create_tree,
|
||||
get_branch_sha,
|
||||
get_tree_sha,
|
||||
is_permission_error,
|
||||
is_rate_limit_error,
|
||||
update_branch_ref,
|
||||
verify_token_permissions,
|
||||
)
|
||||
|
||||
REPO_OWNER = "sglang-bot"
|
||||
REPO_NAME = "sglang-ci-data"
|
||||
BRANCH = "main"
|
||||
TARGET_DIR = "diffusion-ci/consistency_gt"
|
||||
|
||||
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp"}
|
||||
|
||||
|
||||
def collect_images(source_dir):
|
||||
"""Collect image files from source_dir and return list of (repo_path, content) tuples."""
|
||||
files = []
|
||||
for entry in sorted(os.listdir(source_dir)):
|
||||
ext = os.path.splitext(entry)[1].lower()
|
||||
if ext not in IMAGE_EXTENSIONS:
|
||||
continue
|
||||
full_path = os.path.join(source_dir, entry)
|
||||
if not os.path.isfile(full_path):
|
||||
continue
|
||||
with open(full_path, "rb") as f:
|
||||
content = f.read()
|
||||
repo_path = f"{TARGET_DIR}/{entry}"
|
||||
files.append((repo_path, content))
|
||||
return files
|
||||
|
||||
|
||||
def publish(source_dir):
|
||||
token = os.getenv("GITHUB_TOKEN")
|
||||
if not token:
|
||||
print("Error: GITHUB_TOKEN environment variable not set")
|
||||
sys.exit(1)
|
||||
|
||||
files_to_upload = collect_images(source_dir)
|
||||
if not files_to_upload:
|
||||
print(f"No image files found in {source_dir}")
|
||||
return
|
||||
|
||||
print(
|
||||
f"Found {len(files_to_upload)} image(s) to upload to {REPO_OWNER}/{REPO_NAME}/{TARGET_DIR}"
|
||||
)
|
||||
|
||||
# Verify token
|
||||
perm = verify_token_permissions(REPO_OWNER, REPO_NAME, token)
|
||||
if perm == "rate_limited":
|
||||
print("GitHub API rate-limited, skipping upload.")
|
||||
return
|
||||
if not perm:
|
||||
print("Token permission verification failed.")
|
||||
sys.exit(1)
|
||||
|
||||
# Create blobs
|
||||
try:
|
||||
tree_items = create_blobs(REPO_OWNER, REPO_NAME, files_to_upload, token)
|
||||
except Exception as e:
|
||||
if is_rate_limit_error(e):
|
||||
print("Rate-limited during blob creation, skipping.")
|
||||
return
|
||||
if is_permission_error(e):
|
||||
print(
|
||||
f"ERROR: Token lacks write permission to {REPO_OWNER}/{REPO_NAME}. "
|
||||
"Update GH_PAT_FOR_NIGHTLY_CI_DATA with a token that has contents:write."
|
||||
)
|
||||
sys.exit(1)
|
||||
raise
|
||||
|
||||
# Commit with retry (handle concurrent pushes)
|
||||
max_retries = 5
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
branch_sha = get_branch_sha(REPO_OWNER, REPO_NAME, BRANCH, token)
|
||||
tree_sha = get_tree_sha(REPO_OWNER, REPO_NAME, branch_sha, token)
|
||||
new_tree_sha = create_tree(
|
||||
REPO_OWNER, REPO_NAME, tree_sha, tree_items, token
|
||||
)
|
||||
commit_msg = f"diffusion-ci: update consistency_gt images ({len(files_to_upload)} files) [automated]"
|
||||
commit_sha = create_commit(
|
||||
REPO_OWNER, REPO_NAME, new_tree_sha, branch_sha, commit_msg, token
|
||||
)
|
||||
update_branch_ref(REPO_OWNER, REPO_NAME, BRANCH, commit_sha, token)
|
||||
print(
|
||||
f"Successfully pushed {len(files_to_upload)} images (commit {commit_sha[:10]})"
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
if is_rate_limit_error(e):
|
||||
print("Rate-limited, skipping.")
|
||||
return
|
||||
if is_permission_error(e):
|
||||
print(f"ERROR: permission denied to {REPO_OWNER}/{REPO_NAME}")
|
||||
sys.exit(1)
|
||||
|
||||
retryable = False
|
||||
if hasattr(e, "error_body"):
|
||||
if "Update is not a fast forward" in e.error_body:
|
||||
retryable = True
|
||||
elif "Object does not exist" in e.error_body:
|
||||
retryable = True
|
||||
|
||||
from urllib.error import HTTPError
|
||||
|
||||
if isinstance(e, HTTPError) and e.code in [422, 500, 502, 503, 504]:
|
||||
retryable = True
|
||||
|
||||
if retryable and attempt < max_retries - 1:
|
||||
import time
|
||||
|
||||
wait = 2**attempt
|
||||
print(
|
||||
f"Attempt {attempt + 1}/{max_retries} failed, retrying in {wait}s..."
|
||||
)
|
||||
time.sleep(wait)
|
||||
else:
|
||||
print(f"Failed after {attempt + 1} attempts: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Publish diffusion GT images to GitHub"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source-dir", required=True, help="Directory containing GT images"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
publish(args.source_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
951
third_party/sglang/scripts/ci/utils/diffusion/run_comparison.py
vendored
Normal file
951
third_party/sglang/scripts/ci/utils/diffusion/run_comparison.py
vendored
Normal file
@@ -0,0 +1,951 @@
|
||||
"""Cross-framework comparison benchmark for diffusion serving.
|
||||
|
||||
Launches servers (SGLang, vLLM-Omni, LightX2V) for each test case, sends a
|
||||
single request, measures end-to-end latency, and writes comparison-results.json.
|
||||
|
||||
Usage:
|
||||
# Full run (requires GPU)
|
||||
python3 scripts/ci/utils/diffusion/run_comparison.py
|
||||
|
||||
# Dry-run (config parsing + command preview only)
|
||||
python3 scripts/ci/utils/diffusion/run_comparison.py --dry-run
|
||||
|
||||
# Run only specific case(s)
|
||||
python3 scripts/ci/utils/diffusion/run_comparison.py --case-ids flux1_dev_t2i_1024
|
||||
|
||||
# Run only specific framework(s)
|
||||
python3 scripts/ci/utils/diffusion/run_comparison.py --frameworks sglang
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
CONFIGS_PATH = Path(__file__).parent / "comparison_configs.json"
|
||||
INSTALL_SCRIPT = Path(__file__).parents[1] / "install_comparison_frameworks.sh"
|
||||
DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = 30000
|
||||
HEALTH_TIMEOUT = (
|
||||
2400 # seconds (40 min — FLUX.2-dev needs ~10 min download + torch.compile)
|
||||
)
|
||||
REQUEST_TIMEOUT = 1200 # seconds
|
||||
GPU_CLEAR_WAIT = 15 # seconds between framework runs
|
||||
|
||||
# Frameworks that need separate installation (conflict with sglang's deps)
|
||||
INSTALLABLE_FRAMEWORKS = {"vllm-omni", "lightx2v"}
|
||||
|
||||
# Cached reference image (downloaded once)
|
||||
_cached_ref_image: bytes | None = None
|
||||
_cached_ref_image_path: str | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Server lifecycle — command builders
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_sglang_cmd(case: dict, fw_cfg: dict, port: int) -> list[str]:
|
||||
cmd = [
|
||||
"sglang",
|
||||
"serve",
|
||||
"--model-path",
|
||||
case["model"],
|
||||
"--port",
|
||||
str(port),
|
||||
"--host",
|
||||
DEFAULT_HOST,
|
||||
]
|
||||
if case["num_gpus"] > 1:
|
||||
cmd += ["--num-gpus", str(case["num_gpus"])]
|
||||
if fw_cfg.get("serve_args", "").strip():
|
||||
cmd += fw_cfg["serve_args"].strip().split()
|
||||
return cmd
|
||||
|
||||
|
||||
def _build_vllm_cmd(case: dict, fw_cfg: dict, port: int) -> list[str]:
|
||||
cmd = [
|
||||
"vllm",
|
||||
"serve",
|
||||
case["model"],
|
||||
"--omni",
|
||||
"--port",
|
||||
str(port),
|
||||
"--host",
|
||||
DEFAULT_HOST,
|
||||
]
|
||||
if fw_cfg.get("serve_args", "").strip():
|
||||
cmd += fw_cfg["serve_args"].strip().split()
|
||||
return cmd
|
||||
|
||||
|
||||
def _resolve_hf_model_path(model_id: str) -> str:
|
||||
"""Resolve a HuggingFace model ID to a local cache path, or return as-is."""
|
||||
if os.path.isdir(model_id):
|
||||
return model_id
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
path = snapshot_download(model_id)
|
||||
print(f" Resolved {model_id} -> {path}")
|
||||
return path
|
||||
except Exception:
|
||||
return model_id
|
||||
|
||||
|
||||
def _write_lightx2v_config(case: dict) -> str:
|
||||
"""Write a minimal LightX2V config JSON and return its path."""
|
||||
cfg = {
|
||||
"infer_steps": case.get("num_inference_steps", 50),
|
||||
"guidance_scale": case.get("guidance_scale", 4.0),
|
||||
"seed": case.get("seed", 42),
|
||||
}
|
||||
if "num_frames" in case:
|
||||
cfg["target_video_length"] = case["num_frames"]
|
||||
if "height" in case:
|
||||
cfg["height"] = case["height"]
|
||||
if "width" in case:
|
||||
cfg["width"] = case["width"]
|
||||
|
||||
config_path = os.path.join(
|
||||
tempfile.gettempdir(), f"lightx2v_config_{case['id']}.json"
|
||||
)
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(cfg, f)
|
||||
return config_path
|
||||
|
||||
|
||||
def _build_lightx2v_cmd(case: dict, fw_cfg: dict, port: int) -> list[str]:
|
||||
"""Build LightX2V server launch command.
|
||||
|
||||
Single GPU: python -m lightx2v.server --model_path ... --model_cls ... --task ... --port ...
|
||||
Multi GPU: torchrun --nproc_per_node=N -m lightx2v.server ...
|
||||
|
||||
LightX2V requires a local model path and a config JSON with infer params.
|
||||
"""
|
||||
model_cls = fw_cfg["model_cls"]
|
||||
task = fw_cfg["lightx2v_task"]
|
||||
num_gpus = case["num_gpus"]
|
||||
model_path = _resolve_hf_model_path(case["model"])
|
||||
config_path = _write_lightx2v_config(case)
|
||||
|
||||
server_args = [
|
||||
"--model_path",
|
||||
model_path,
|
||||
"--model_cls",
|
||||
model_cls,
|
||||
"--task",
|
||||
task,
|
||||
"--config_json",
|
||||
config_path,
|
||||
"--host",
|
||||
DEFAULT_HOST,
|
||||
"--port",
|
||||
str(port),
|
||||
]
|
||||
if fw_cfg.get("serve_args", "").strip():
|
||||
server_args += fw_cfg["serve_args"].strip().split()
|
||||
|
||||
if num_gpus > 1:
|
||||
cmd = [
|
||||
"torchrun",
|
||||
f"--nproc_per_node={num_gpus}",
|
||||
"-m",
|
||||
"lightx2v.server",
|
||||
] + server_args
|
||||
else:
|
||||
cmd = ["python3", "-m", "lightx2v.server"] + server_args
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
def build_server_cmd(framework: str, case: dict, fw_cfg: dict, port: int) -> list[str]:
|
||||
builders = {
|
||||
"sglang": _build_sglang_cmd,
|
||||
"vllm-omni": _build_vllm_cmd,
|
||||
"lightx2v": _build_lightx2v_cmd,
|
||||
}
|
||||
builder = builders.get(framework)
|
||||
if builder is None:
|
||||
raise ValueError(f"Unknown framework: {framework}")
|
||||
return builder(case, fw_cfg, port)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Server lifecycle — health check & cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Health check endpoints per framework
|
||||
HEALTH_ENDPOINTS = {
|
||||
"sglang": "/health",
|
||||
"vllm-omni": "/health",
|
||||
"lightx2v": "/v1/service/status",
|
||||
}
|
||||
|
||||
|
||||
def wait_for_health(
|
||||
base_url: str, framework: str = "sglang", timeout: int = HEALTH_TIMEOUT
|
||||
) -> None:
|
||||
"""Poll health endpoint until 200, then verify model is loaded."""
|
||||
endpoint = HEALTH_ENDPOINTS.get(framework, "/health")
|
||||
health_url = f"{base_url}{endpoint}"
|
||||
print(f" Waiting for server at {health_url} ...")
|
||||
start = time.time()
|
||||
while True:
|
||||
try:
|
||||
resp = requests.get(health_url, timeout=2)
|
||||
if resp.status_code == 200:
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
if time.time() - start > timeout:
|
||||
raise TimeoutError(
|
||||
f"Server at {health_url} did not start within {timeout}s"
|
||||
)
|
||||
time.sleep(2)
|
||||
|
||||
# For SGLang, /health can return 200 before model routes are registered.
|
||||
# Poll /v1/models to confirm the model is fully loaded.
|
||||
if framework == "sglang":
|
||||
models_url = f"{base_url}/v1/models"
|
||||
while True:
|
||||
try:
|
||||
resp = requests.get(models_url, timeout=5)
|
||||
if resp.status_code == 200:
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
if time.time() - start > timeout:
|
||||
raise TimeoutError(f"Model at {models_url} not ready within {timeout}s")
|
||||
time.sleep(2)
|
||||
|
||||
elapsed = time.time() - start
|
||||
print(f" Server ready in {elapsed:.1f}s")
|
||||
|
||||
|
||||
KILLALL_SCRIPT = Path(__file__).parents[3] / "killall_sglang.sh"
|
||||
|
||||
|
||||
def kill_server(proc: subprocess.Popen) -> None:
|
||||
"""Kill server process tree and clean up GPU processes."""
|
||||
if proc.poll() is not None:
|
||||
return
|
||||
try:
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
pass
|
||||
try:
|
||||
proc.wait(timeout=30)
|
||||
except subprocess.TimeoutExpired:
|
||||
try:
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
pass
|
||||
proc.wait(timeout=10)
|
||||
# Use killall_sglang.sh for thorough cleanup (esp. multi-GPU workers)
|
||||
if KILLALL_SCRIPT.exists():
|
||||
subprocess.run(
|
||||
["bash", str(KILLALL_SCRIPT)],
|
||||
timeout=30,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reference image helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_ref_image_bytes(config: dict) -> bytes:
|
||||
"""Download and cache the shared test reference image."""
|
||||
global _cached_ref_image
|
||||
if _cached_ref_image is not None:
|
||||
return _cached_ref_image
|
||||
url = config.get("test_image_url", "")
|
||||
if not url:
|
||||
raise RuntimeError("No test_image_url in config for image-conditioned case")
|
||||
print(f" Downloading reference image from {url} ...")
|
||||
resp = requests.get(url, timeout=60)
|
||||
resp.raise_for_status()
|
||||
_cached_ref_image = resp.content
|
||||
return _cached_ref_image
|
||||
|
||||
|
||||
def _get_ref_image_b64(config: dict) -> str:
|
||||
"""Get reference image as base64 string."""
|
||||
return base64.b64encode(_get_ref_image_bytes(config)).decode("utf-8")
|
||||
|
||||
|
||||
def _get_ref_image_path(config: dict) -> str:
|
||||
"""Save reference image to a temp file and return path."""
|
||||
global _cached_ref_image_path
|
||||
if _cached_ref_image_path and os.path.exists(_cached_ref_image_path):
|
||||
return _cached_ref_image_path
|
||||
data = _get_ref_image_bytes(config)
|
||||
fd, path = tempfile.mkstemp(suffix=".png")
|
||||
with os.fdopen(fd, "wb") as f:
|
||||
f.write(data)
|
||||
_cached_ref_image_path = path
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request helpers — SGLang (OpenAI-compatible)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_sglang_payload(case: dict) -> dict:
|
||||
"""Build common SGLang request payload."""
|
||||
payload = {
|
||||
"model": case["model"],
|
||||
"prompt": case["prompt"],
|
||||
"size": f"{case['width']}x{case['height']}",
|
||||
"n": 1,
|
||||
"response_format": "b64_json",
|
||||
}
|
||||
for key in ("num_inference_steps", "guidance_scale", "seed", "num_frames"):
|
||||
if key in case:
|
||||
payload[key] = case[key]
|
||||
return payload
|
||||
|
||||
|
||||
def _read_perf_dump(perf_dump_path: str, timeout: float = 10.0) -> float | None:
|
||||
"""Read total_duration_ms from a perf dump JSON written by the server.
|
||||
|
||||
The server writes the file asynchronously after the HTTP response,
|
||||
so we poll briefly.
|
||||
"""
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
with open(perf_dump_path) as f:
|
||||
data = json.load(f)
|
||||
total_ms = data.get("total_duration_ms")
|
||||
if total_ms is not None:
|
||||
return total_ms / 1000.0
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
pass
|
||||
time.sleep(0.5)
|
||||
return None
|
||||
|
||||
|
||||
def send_image_request_sglang(
|
||||
base_url: str, case: dict, perf_dump_path: str | None = None
|
||||
) -> float:
|
||||
"""Send a single T2I request via SGLang's /v1/images/generations."""
|
||||
payload = _build_sglang_payload(case)
|
||||
if perf_dump_path:
|
||||
payload["perf_dump_path"] = perf_dump_path
|
||||
|
||||
start = time.time()
|
||||
resp = requests.post(
|
||||
f"{base_url}/v1/images/generations",
|
||||
json=payload,
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
client_latency = time.time() - start
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
if "data" not in data or len(data["data"]) == 0:
|
||||
raise RuntimeError(f"Image request returned no data: {data}")
|
||||
|
||||
if perf_dump_path:
|
||||
server_latency = _read_perf_dump(perf_dump_path)
|
||||
if server_latency is not None:
|
||||
print(
|
||||
f" Image generated in {server_latency:.2f}s (server-side), "
|
||||
f"client={client_latency:.2f}s"
|
||||
)
|
||||
return server_latency
|
||||
print(f" Image generated in {client_latency:.2f}s")
|
||||
return client_latency
|
||||
|
||||
|
||||
def send_video_request_sglang(
|
||||
base_url: str, case: dict, perf_dump_path: str | None = None
|
||||
) -> float:
|
||||
"""Send a single T2V request via SGLang's /v1/videos (async)."""
|
||||
payload = _build_sglang_payload(case)
|
||||
if perf_dump_path:
|
||||
payload["perf_dump_path"] = perf_dump_path
|
||||
|
||||
start = time.time()
|
||||
|
||||
# Submit job
|
||||
resp = requests.post(
|
||||
f"{base_url}/v1/videos",
|
||||
json=payload,
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
job = resp.json()
|
||||
job_id = job.get("id")
|
||||
if not job_id:
|
||||
raise RuntimeError(f"Video submit returned no job id: {job}")
|
||||
|
||||
# Poll for completion
|
||||
poll_url = f"{base_url}/v1/videos/{job_id}"
|
||||
while True:
|
||||
time.sleep(1)
|
||||
poll_resp = requests.get(poll_url, timeout=30)
|
||||
poll_resp.raise_for_status()
|
||||
poll_data = poll_resp.json()
|
||||
status = poll_data.get("status")
|
||||
if status == "completed":
|
||||
break
|
||||
elif status == "failed":
|
||||
raise RuntimeError(f"Video generation failed: {poll_data}")
|
||||
if time.time() - start > REQUEST_TIMEOUT:
|
||||
raise TimeoutError(f"Video generation timed out after {REQUEST_TIMEOUT}s")
|
||||
|
||||
client_latency = time.time() - start
|
||||
|
||||
if perf_dump_path:
|
||||
server_latency = _read_perf_dump(perf_dump_path)
|
||||
if server_latency is not None:
|
||||
print(
|
||||
f" Video generated in {server_latency:.2f}s (server-side), "
|
||||
f"client={client_latency:.2f}s"
|
||||
)
|
||||
return server_latency
|
||||
print(f" Video generated in {client_latency:.2f}s")
|
||||
return client_latency
|
||||
|
||||
|
||||
def send_image_conditioned_request_sglang(
|
||||
base_url: str, case: dict, config: dict, perf_dump_path: str | None = None
|
||||
) -> float:
|
||||
"""Send an image-conditioned request (edit/I2V/TI2V) via SGLang multipart API."""
|
||||
task = case["task"]
|
||||
ref_bytes = _get_ref_image_bytes(config)
|
||||
|
||||
# Build multipart form — field name depends on endpoint:
|
||||
# image edits use "image", video (I2V/TI2V) uses "input_reference"
|
||||
if task in ("image-to-video", "text-image-to-video"):
|
||||
file_field = "input_reference"
|
||||
else:
|
||||
file_field = "image"
|
||||
files = {file_field: ("ref.png", io.BytesIO(ref_bytes), "image/png")}
|
||||
data = {
|
||||
"model": case["model"],
|
||||
"prompt": case["prompt"],
|
||||
"size": f"{case['width']}x{case['height']}",
|
||||
"n": "1",
|
||||
"response_format": "b64_json",
|
||||
}
|
||||
for key in ("num_inference_steps", "guidance_scale", "seed", "num_frames"):
|
||||
if key in case:
|
||||
data[key] = str(case[key])
|
||||
if perf_dump_path:
|
||||
data["perf_dump_path"] = perf_dump_path
|
||||
# Choose endpoint based on task
|
||||
if task in ("image-edit", "image-to-image"):
|
||||
endpoint = "/v1/images/edits"
|
||||
elif task in ("image-to-video", "text-image-to-video"):
|
||||
endpoint = "/v1/videos"
|
||||
else:
|
||||
endpoint = "/v1/images/generations"
|
||||
|
||||
start = time.time()
|
||||
resp = requests.post(
|
||||
f"{base_url}{endpoint}",
|
||||
files=files,
|
||||
data=data,
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
|
||||
# For video endpoints, need to poll
|
||||
if task in ("image-to-video", "text-image-to-video"):
|
||||
resp.raise_for_status()
|
||||
job = resp.json()
|
||||
job_id = job.get("id")
|
||||
if not job_id:
|
||||
raise RuntimeError(f"Video submit returned no job id: {job}")
|
||||
poll_url = f"{base_url}/v1/videos/{job_id}"
|
||||
while True:
|
||||
time.sleep(1)
|
||||
poll_resp = requests.get(poll_url, timeout=30)
|
||||
poll_resp.raise_for_status()
|
||||
poll_data = poll_resp.json()
|
||||
status = poll_data.get("status")
|
||||
if status == "completed":
|
||||
break
|
||||
elif status == "failed":
|
||||
raise RuntimeError(f"Video generation failed: {poll_data}")
|
||||
if time.time() - start > REQUEST_TIMEOUT:
|
||||
raise TimeoutError(f"Timed out after {REQUEST_TIMEOUT}s")
|
||||
else:
|
||||
resp.raise_for_status()
|
||||
|
||||
client_latency = time.time() - start
|
||||
|
||||
if perf_dump_path:
|
||||
server_latency = _read_perf_dump(perf_dump_path)
|
||||
if server_latency is not None:
|
||||
print(
|
||||
f" Generated in {server_latency:.2f}s (server-side), "
|
||||
f"client={client_latency:.2f}s"
|
||||
)
|
||||
return server_latency
|
||||
print(f" Generated in {client_latency:.2f}s (sglang, image-conditioned)")
|
||||
return client_latency
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request helpers — vLLM-Omni
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def send_request_vllm_omni(base_url: str, case: dict, config: dict) -> float:
|
||||
"""Send request via vLLM-Omni's /v1/chat/completions endpoint."""
|
||||
extra_body = {
|
||||
"height": case["height"],
|
||||
"width": case["width"],
|
||||
"num_inference_steps": case.get("num_inference_steps", 50),
|
||||
"guidance_scale": case.get("guidance_scale", 4.0),
|
||||
"seed": case.get("seed", 42),
|
||||
}
|
||||
if "num_frames" in case:
|
||||
extra_body["num_frames"] = case["num_frames"]
|
||||
|
||||
# Build message content (text or text+image)
|
||||
content: list[dict] | str = case["prompt"]
|
||||
if case.get("reference_image"):
|
||||
ref_b64 = _get_ref_image_b64(config)
|
||||
content = [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{ref_b64}"},
|
||||
},
|
||||
{"type": "text", "text": case["prompt"]},
|
||||
]
|
||||
|
||||
payload = {
|
||||
"model": case["model"],
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"extra_body": extra_body,
|
||||
}
|
||||
|
||||
start = time.time()
|
||||
resp = requests.post(
|
||||
f"{base_url}/v1/chat/completions",
|
||||
json=payload,
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
latency = time.time() - start
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
choices = data.get("choices", [])
|
||||
if not choices:
|
||||
raise RuntimeError(f"vLLM-Omni request returned no choices: {data}")
|
||||
print(f" Generated in {latency:.2f}s (vllm-omni)")
|
||||
return latency
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request helpers — LightX2V
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def send_request_lightx2v(base_url: str, case: dict, config: dict) -> float:
|
||||
"""Send request via LightX2V's async task API."""
|
||||
task = case["task"]
|
||||
if task in ("text-to-image", "image-edit"):
|
||||
endpoint = "/v1/tasks/image"
|
||||
else:
|
||||
endpoint = "/v1/tasks/video"
|
||||
|
||||
payload = {
|
||||
"prompt": case["prompt"],
|
||||
"seed": case.get("seed", 42),
|
||||
"infer_steps": case.get("num_inference_steps", 50),
|
||||
}
|
||||
# LightX2V uses target_video_length for frames, height/width directly
|
||||
if "num_frames" in case:
|
||||
payload["target_video_length"] = case["num_frames"]
|
||||
if "height" in case:
|
||||
payload["height"] = case["height"]
|
||||
if "width" in case:
|
||||
payload["width"] = case["width"]
|
||||
if "guidance_scale" in case:
|
||||
payload["guidance_scale"] = case["guidance_scale"]
|
||||
# Image-conditioned: LightX2V accepts image_path (URL or local path)
|
||||
if case.get("reference_image"):
|
||||
payload["image_path"] = config.get("test_image_url", "")
|
||||
|
||||
start = time.time()
|
||||
|
||||
# Submit task
|
||||
resp = requests.post(
|
||||
f"{base_url}{endpoint}",
|
||||
json=payload,
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
task_data = resp.json()
|
||||
task_id = task_data.get("task_id")
|
||||
if not task_id:
|
||||
raise RuntimeError(f"LightX2V submit returned no task_id: {task_data}")
|
||||
|
||||
# Poll for completion
|
||||
poll_url = f"{base_url}/v1/tasks/{task_id}/status"
|
||||
while True:
|
||||
time.sleep(1)
|
||||
poll_resp = requests.get(poll_url, timeout=30)
|
||||
poll_resp.raise_for_status()
|
||||
poll_data = poll_resp.json()
|
||||
status = poll_data.get("task_status", "").upper()
|
||||
if status == "COMPLETED":
|
||||
break
|
||||
elif status in ("FAILED", "CANCELLED"):
|
||||
raise RuntimeError(f"LightX2V task {status}: {poll_data}")
|
||||
if time.time() - start > REQUEST_TIMEOUT:
|
||||
raise TimeoutError(f"LightX2V task timed out after {REQUEST_TIMEOUT}s")
|
||||
|
||||
latency = time.time() - start
|
||||
print(f" Generated in {latency:.2f}s (lightx2v)")
|
||||
return latency
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unified request dispatcher
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def send_request(
|
||||
base_url: str,
|
||||
case: dict,
|
||||
framework: str = "sglang",
|
||||
config: dict | None = None,
|
||||
perf_dump_path: str | None = None,
|
||||
) -> float:
|
||||
config = config or {}
|
||||
if framework == "vllm-omni":
|
||||
return send_request_vllm_omni(base_url, case, config)
|
||||
elif framework == "lightx2v":
|
||||
return send_request_lightx2v(base_url, case, config)
|
||||
# SGLang — use OpenAI-compatible endpoints with optional perf log
|
||||
task = case["task"]
|
||||
if case.get("reference_image"):
|
||||
return send_image_conditioned_request_sglang(
|
||||
base_url, case, config, perf_dump_path
|
||||
)
|
||||
elif task == "text-to-image":
|
||||
return send_image_request_sglang(base_url, case, perf_dump_path)
|
||||
elif task == "text-to-video":
|
||||
return send_video_request_sglang(base_url, case, perf_dump_path)
|
||||
else:
|
||||
raise ValueError(f"Unknown task type: {task}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main orchestrator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_single(
|
||||
case: dict,
|
||||
framework: str,
|
||||
fw_cfg: dict,
|
||||
port: int,
|
||||
log_dir: Path,
|
||||
config: dict | None = None,
|
||||
) -> dict:
|
||||
"""Run a single (case, framework) combination. Returns result dict."""
|
||||
result = {
|
||||
"case_id": case["id"],
|
||||
"framework": framework,
|
||||
"model": case["model"],
|
||||
"task": case["task"],
|
||||
"latency_s": None,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
cmd = build_server_cmd(framework, case, fw_cfg, port)
|
||||
print(f"\n Command: {' '.join(cmd)}")
|
||||
|
||||
env = os.environ.copy()
|
||||
env.update(fw_cfg.get("extra_env", {}))
|
||||
|
||||
# perf_dump_path for SGLang server-side timing (passed in request, zero overhead when None)
|
||||
perf_dump_path = None
|
||||
if framework == "sglang":
|
||||
perf_dump_path = os.path.join(str(log_dir), f"perf_{case['id']}_measured.json")
|
||||
|
||||
log_file = log_dir / f"{case['id']}_{framework}.log"
|
||||
log_fh = open(log_file, "w", encoding="utf-8", buffering=1)
|
||||
log_thread = None
|
||||
|
||||
proc = None
|
||||
try:
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
env=env,
|
||||
preexec_fn=os.setsid,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
# Tee server output to both log file and stdout (like test_server_utils)
|
||||
def _log_pipe(pipe, fh):
|
||||
try:
|
||||
for line in iter(pipe.readline, ""):
|
||||
sys.stdout.write(f" [server] {line}")
|
||||
sys.stdout.flush()
|
||||
fh.write(line)
|
||||
except ValueError:
|
||||
pass # pipe closed
|
||||
|
||||
log_thread = threading.Thread(target=_log_pipe, args=(proc.stdout, log_fh))
|
||||
log_thread.daemon = True
|
||||
log_thread.start()
|
||||
|
||||
base_url = f"http://{DEFAULT_HOST}:{port}"
|
||||
wait_for_health(base_url, framework)
|
||||
|
||||
# Warmup requests (not measured, no perf dump)
|
||||
# Use few steps to be fast — server's own warmup (warmup_steps=3) handles
|
||||
# torch.compile compilation; these external warmups just stabilize triton
|
||||
# kernel specializations across requests.
|
||||
WARMUP_STEPS = 3
|
||||
warmup_case = {**case, "num_inference_steps": WARMUP_STEPS}
|
||||
for wi in range(1, 3):
|
||||
print(f" Sending warmup request ({wi}/2, {WARMUP_STEPS} steps)...")
|
||||
try:
|
||||
send_request(base_url, warmup_case, framework, config)
|
||||
except Exception as e:
|
||||
print(f" Warmup request {wi} failed (non-fatal): {e}")
|
||||
|
||||
# Measured request — pass perf_dump_path for SGLang server-side timing
|
||||
if perf_dump_path and os.path.exists(perf_dump_path):
|
||||
os.remove(perf_dump_path)
|
||||
print(" Sending measured request...")
|
||||
latency = send_request(
|
||||
base_url, case, framework, config, perf_dump_path=perf_dump_path
|
||||
)
|
||||
result["latency_s"] = round(latency, 3)
|
||||
|
||||
except Exception as e:
|
||||
result["error"] = str(e)
|
||||
print(f" ERROR: {e}")
|
||||
finally:
|
||||
if proc:
|
||||
kill_server(proc)
|
||||
if log_thread:
|
||||
log_thread.join(timeout=5)
|
||||
log_fh.close()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _install_framework(fw_name: str, dry_run: bool = False) -> bool:
|
||||
"""Install a comparison framework via the install script. Returns True on success."""
|
||||
if fw_name not in INSTALLABLE_FRAMEWORKS:
|
||||
return True
|
||||
if not INSTALL_SCRIPT.exists():
|
||||
print(f" WARNING: Install script not found at {INSTALL_SCRIPT}")
|
||||
return False
|
||||
if dry_run:
|
||||
print(f" [DRY-RUN] Would install: bash {INSTALL_SCRIPT} {fw_name}")
|
||||
return True
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Installing framework: {fw_name}")
|
||||
print(f"{'='*60}")
|
||||
ret = subprocess.run(
|
||||
["bash", str(INSTALL_SCRIPT), fw_name],
|
||||
timeout=600,
|
||||
)
|
||||
if ret.returncode != 0:
|
||||
print(f" WARNING: {fw_name} installation failed (exit {ret.returncode})")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def run_comparison(
|
||||
config: dict,
|
||||
case_ids: list[str] | None = None,
|
||||
frameworks: list[str] | None = None,
|
||||
port: int = DEFAULT_PORT,
|
||||
output: str = "comparison-results.json",
|
||||
dry_run: bool = False,
|
||||
) -> dict:
|
||||
"""Run all comparison cases, grouped by framework to minimize installs.
|
||||
|
||||
Order: sglang first (already installed), then vllm-omni, then lightx2v.
|
||||
Each non-sglang framework is installed right before its cases run.
|
||||
"""
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
commit_sha = os.environ.get("GITHUB_SHA", "unknown")
|
||||
run_id = os.environ.get("GITHUB_RUN_ID", "local")
|
||||
|
||||
log_dir = Path("comparison-logs")
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Collect all (case, framework) pairs, grouped by framework
|
||||
fw_order = ["sglang", "vllm-omni", "lightx2v"]
|
||||
fw_cases: dict[str, list[tuple[dict, dict]]] = {fw: [] for fw in fw_order}
|
||||
|
||||
for case in config["cases"]:
|
||||
if case_ids and case["id"] not in case_ids:
|
||||
continue
|
||||
for fw_name, fw_cfg in case["frameworks"].items():
|
||||
if frameworks and fw_name not in frameworks:
|
||||
continue
|
||||
if fw_name not in fw_cases:
|
||||
fw_cases[fw_name] = []
|
||||
fw_cases[fw_name].append((case, fw_cfg))
|
||||
|
||||
results = []
|
||||
installed_fws: set[str] = set()
|
||||
|
||||
for fw_name in fw_order:
|
||||
pairs = fw_cases.get(fw_name, [])
|
||||
if not pairs:
|
||||
continue
|
||||
|
||||
# Install framework if needed (once per framework)
|
||||
if fw_name not in installed_fws and fw_name in INSTALLABLE_FRAMEWORKS:
|
||||
if not _install_framework(fw_name, dry_run):
|
||||
# Skip all cases for this framework
|
||||
for case, _ in pairs:
|
||||
results.append(
|
||||
{
|
||||
"case_id": case["id"],
|
||||
"framework": fw_name,
|
||||
"model": case["model"],
|
||||
"task": case["task"],
|
||||
"latency_s": None,
|
||||
"error": f"{fw_name} installation failed",
|
||||
}
|
||||
)
|
||||
continue
|
||||
installed_fws.add(fw_name)
|
||||
|
||||
for case, fw_cfg in pairs:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Case: {case['id']} | Model: {case['model']} | Framework: {fw_name}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
if dry_run:
|
||||
cmd = build_server_cmd(fw_name, case, fw_cfg, port)
|
||||
print(f" [DRY-RUN] Would run: {' '.join(cmd)}")
|
||||
results.append(
|
||||
{
|
||||
"case_id": case["id"],
|
||||
"framework": fw_name,
|
||||
"model": case["model"],
|
||||
"task": case["task"],
|
||||
"latency_s": None,
|
||||
"error": "dry-run",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
result = run_single(case, fw_name, fw_cfg, port, log_dir, config)
|
||||
results.append(result)
|
||||
|
||||
# Wait for GPU memory to clear
|
||||
print(f" Waiting {GPU_CLEAR_WAIT}s for GPU memory to clear...")
|
||||
time.sleep(GPU_CLEAR_WAIT)
|
||||
|
||||
output_data = {
|
||||
"timestamp": timestamp,
|
||||
"commit_sha": commit_sha,
|
||||
"run_id": run_id,
|
||||
"results": results,
|
||||
}
|
||||
|
||||
os.makedirs(os.path.dirname(output) or ".", exist_ok=True)
|
||||
with open(output, "w") as f:
|
||||
json.dump(output_data, f, indent=2)
|
||||
print(f"\nResults written to {output}")
|
||||
|
||||
# Print summary table
|
||||
print(f"\n{'='*60}")
|
||||
print("SUMMARY")
|
||||
print(f"{'='*60}")
|
||||
for r in results:
|
||||
lat = f"{r['latency_s']:.2f}s" if r["latency_s"] else r.get("error", "N/A")
|
||||
print(f" {r['case_id']:30s} | {r['framework']:12s} | {lat}")
|
||||
|
||||
return output_data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Cross-framework diffusion serving comparison benchmark"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
default=str(CONFIGS_PATH),
|
||||
help="Path to comparison_configs.json",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--case-ids",
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Only run specific case IDs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--frameworks",
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Only run specific frameworks (sglang, vllm-omni, lightx2v)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=DEFAULT_PORT,
|
||||
help="Server port",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default="comparison-results.json",
|
||||
help="Output JSON path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Parse config and print commands without launching servers",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
print(f"Loaded {len(config['cases'])} comparison case(s) from {args.config}")
|
||||
|
||||
run_comparison(
|
||||
config=config,
|
||||
case_ids=args.case_ids,
|
||||
frameworks=args.frameworks,
|
||||
port=args.port,
|
||||
output=args.output,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
163
third_party/sglang/scripts/ci/utils/diffusion/save_diffusion_metrics.py
vendored
Executable file
163
third_party/sglang/scripts/ci/utils/diffusion/save_diffusion_metrics.py
vendored
Executable file
@@ -0,0 +1,163 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Collect and save diffusion performance metrics for artifact collection in CI.
|
||||
|
||||
This script reads diffusion test results from the pytest stash and saves them
|
||||
with metadata for the performance dashboard.
|
||||
|
||||
Usage:
|
||||
python3 scripts/ci/utils/diffusion/save_diffusion_metrics.py \
|
||||
--gpu-config 1-gpu-h100 \
|
||||
--run-id 12345678 \
|
||||
--output test/diffusion-metrics-1gpu.json \
|
||||
--results-json test/diffusion-results.json
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def load_diffusion_results(results_file: str) -> list[dict]:
|
||||
"""Load diffusion performance results from JSON file."""
|
||||
if not os.path.exists(results_file):
|
||||
print(f"Warning: Results file not found: {results_file}")
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(results_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return data if isinstance(data, list) else [data]
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"Warning: Failed to parse {results_file}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def transform_diffusion_result(result: dict, gpu_config: str) -> dict:
|
||||
"""Transform a diffusion result to match dashboard expectations.
|
||||
|
||||
Dashboard expects:
|
||||
- Separate test_name, class_name
|
||||
- Numeric metrics in consistent units
|
||||
- Optional modality field
|
||||
"""
|
||||
return {
|
||||
"test_name": result.get("test_name"),
|
||||
"class_name": result.get("class_name"),
|
||||
"modality": result.get("modality", "image"),
|
||||
"e2e_ms": result.get("e2e_ms"),
|
||||
"avg_denoise_ms": result.get("avg_denoise_ms"),
|
||||
"median_denoise_ms": result.get("median_denoise_ms"),
|
||||
"stage_metrics": result.get("stage_metrics", {}),
|
||||
"sampled_steps": result.get("sampled_steps", {}),
|
||||
# Video-specific metrics (if present)
|
||||
"frames_per_second": result.get("frames_per_second"),
|
||||
"total_frames": result.get("total_frames"),
|
||||
"avg_frame_time_ms": result.get("avg_frame_time_ms"),
|
||||
}
|
||||
|
||||
|
||||
def group_results_by_class(results: list[dict], gpu_config: str) -> list[dict]:
|
||||
"""Group diffusion results by test class (suite).
|
||||
|
||||
Returns list with one entry per test class, containing all tests in that class.
|
||||
"""
|
||||
groups = {}
|
||||
|
||||
for result in results:
|
||||
class_name = result.get("class_name", "unknown")
|
||||
|
||||
if class_name not in groups:
|
||||
groups[class_name] = {
|
||||
"gpu_config": gpu_config,
|
||||
"test_suite": class_name,
|
||||
"tests": [],
|
||||
}
|
||||
|
||||
transformed = transform_diffusion_result(result, gpu_config)
|
||||
groups[class_name]["tests"].append(transformed)
|
||||
|
||||
return list(groups.values())
|
||||
|
||||
|
||||
def save_metrics(
|
||||
gpu_config: str,
|
||||
run_id: str,
|
||||
output_file: str,
|
||||
results_file: str,
|
||||
) -> bool:
|
||||
"""Collect diffusion metrics and save to output file."""
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# Load diffusion results
|
||||
raw_results = load_diffusion_results(results_file)
|
||||
print(f"Loaded {len(raw_results)} diffusion test result(s)")
|
||||
|
||||
# Group by test class
|
||||
grouped = group_results_by_class(raw_results, gpu_config)
|
||||
|
||||
# Create metrics structure
|
||||
metrics = {
|
||||
"run_id": run_id,
|
||||
"timestamp": timestamp,
|
||||
"gpu_config": gpu_config,
|
||||
"test_type": "diffusion",
|
||||
"results": grouped,
|
||||
}
|
||||
|
||||
# Ensure output directory exists and write output
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(metrics, f, indent=2)
|
||||
|
||||
if not raw_results:
|
||||
print(f"Created empty metrics file: {output_file}")
|
||||
else:
|
||||
print(f"Saved diffusion metrics to: {output_file}")
|
||||
return True
|
||||
except OSError as e:
|
||||
print(f"Error writing metrics file: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Collect diffusion performance metrics from test results"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-config",
|
||||
required=True,
|
||||
help="GPU configuration (e.g., 1-gpu-h100, 2-gpu-h100)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run-id",
|
||||
required=True,
|
||||
help="GitHub Actions run ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
required=True,
|
||||
help="Output file path for metrics JSON",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--results-json",
|
||||
required=True,
|
||||
help="Path to diffusion results JSON file",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
success = save_metrics(
|
||||
gpu_config=args.gpu_config,
|
||||
run_id=args.run_id,
|
||||
output_file=args.output,
|
||||
results_file=args.results_json,
|
||||
)
|
||||
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
53
third_party/sglang/scripts/ci/utils/install_protoc.sh
vendored
Executable file
53
third_party/sglang/scripts/ci/utils/install_protoc.sh
vendored
Executable file
@@ -0,0 +1,53 @@
|
||||
#!/bin/bash
|
||||
# Ensure protoc is installed for router build (gRPC protobuf compilation).
|
||||
set -euxo pipefail
|
||||
|
||||
if command -v protoc >/dev/null 2>&1 && protoc --version >/dev/null 2>&1; then
|
||||
echo "protoc already installed: $(protoc --version)"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if command -v protoc >/dev/null 2>&1; then
|
||||
echo "protoc found but not runnable, reinstalling..."
|
||||
else
|
||||
echo "protoc not found, installing..."
|
||||
fi
|
||||
|
||||
ARCH=$(uname -m)
|
||||
|
||||
if command -v apt-get &> /dev/null; then
|
||||
# Ubuntu/Debian
|
||||
apt-get update || true # May fail due to unrelated broken packages
|
||||
PROTOC_APT_PACKAGES=(wget unzip gcc g++ perl make)
|
||||
apt-get install -y --no-install-recommends "${PROTOC_APT_PACKAGES[@]}" || {
|
||||
echo "Warning: apt-get install failed, checking if required packages are available..."
|
||||
for pkg in "${PROTOC_APT_PACKAGES[@]}"; do
|
||||
if ! dpkg -l "$pkg" 2>/dev/null | grep -q "^ii"; then
|
||||
echo "ERROR: Required package $pkg is not installed and apt-get failed"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
echo "All required packages are already installed, continuing..."
|
||||
}
|
||||
elif command -v yum &> /dev/null; then
|
||||
# RHEL/CentOS
|
||||
yum update -y
|
||||
yum install -y wget unzip gcc gcc-c++ perl-core make
|
||||
else
|
||||
echo "ERROR: Neither apt-get nor yum found; cannot install protoc build deps"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ "$ARCH" = "aarch64" ] || [ "$ARCH" = "arm64" ]; then
|
||||
PROTOC_ARCH="aarch_64"
|
||||
else
|
||||
PROTOC_ARCH="x86_64"
|
||||
fi
|
||||
PROTOC_ZIP="protoc-32.0-linux-${PROTOC_ARCH}.zip"
|
||||
(
|
||||
cd /tmp
|
||||
wget "https://github.com/protocolbuffers/protobuf/releases/download/v32.0/${PROTOC_ZIP}"
|
||||
unzip -o "${PROTOC_ZIP}" -d /usr/local
|
||||
rm -f "${PROTOC_ZIP}"
|
||||
)
|
||||
protoc --version
|
||||
141
third_party/sglang/scripts/ci/utils/merge_metrics.py
vendored
Executable file
141
third_party/sglang/scripts/ci/utils/merge_metrics.py
vendored
Executable file
@@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Merge per-partition metrics into a consolidated metrics file.
|
||||
|
||||
This script reads all per-partition metric JSON files and consolidates them
|
||||
into a single JSON file with run-level metadata.
|
||||
|
||||
Usage:
|
||||
python3 scripts/ci/utils/merge_metrics.py \
|
||||
--input-dir metrics/ \
|
||||
--output consolidated-metrics-12345678.json \
|
||||
--run-id 12345678 \
|
||||
--commit-sha abc123def456
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def find_partition_files(input_dir: str) -> list[str]:
|
||||
"""Find all partition metric files in the input directory."""
|
||||
patterns = [
|
||||
os.path.join(input_dir, "**/metrics-*.json"),
|
||||
os.path.join(input_dir, "**/diffusion-metrics-*.json"),
|
||||
os.path.join(input_dir, "**/comparison-metrics-*.json"),
|
||||
]
|
||||
files = set()
|
||||
for pattern in patterns:
|
||||
files.update(glob.glob(pattern, recursive=True))
|
||||
return list(files)
|
||||
|
||||
|
||||
def load_partition_metrics(filepath: str) -> dict | None:
|
||||
"""Load a partition metrics file."""
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"Warning: Failed to load {filepath}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def merge_metrics(
|
||||
input_dir: str,
|
||||
output_file: str,
|
||||
run_id: str,
|
||||
commit_sha: str,
|
||||
branch: str | None = None,
|
||||
) -> bool:
|
||||
"""Merge all partition metrics into a consolidated file."""
|
||||
run_date = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# Find all partition files
|
||||
partition_files = find_partition_files(input_dir)
|
||||
print(f"Found {len(partition_files)} partition file(s)")
|
||||
|
||||
all_results = []
|
||||
if not partition_files:
|
||||
print("No partition metrics files found")
|
||||
else:
|
||||
# Load all partition files
|
||||
for filepath in sorted(partition_files):
|
||||
print(f" Reading: {filepath}")
|
||||
metrics = load_partition_metrics(filepath)
|
||||
if metrics and "results" in metrics:
|
||||
all_results.extend(metrics["results"])
|
||||
print(f"Total results collected: {len(all_results)}")
|
||||
|
||||
# Create consolidated structure
|
||||
consolidated = {
|
||||
"run_id": run_id,
|
||||
"run_date": run_date,
|
||||
"commit_sha": commit_sha,
|
||||
"branch": branch,
|
||||
"results": all_results,
|
||||
}
|
||||
|
||||
# Ensure output directory exists and write output
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(consolidated, f, indent=2)
|
||||
|
||||
if not partition_files:
|
||||
print(f"Created empty consolidated file: {output_file}")
|
||||
else:
|
||||
print(f"Saved consolidated metrics to: {output_file}")
|
||||
return True
|
||||
except OSError as e:
|
||||
print(f"Error writing consolidated file: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Merge per-partition metrics into consolidated file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-dir",
|
||||
required=True,
|
||||
help="Directory containing partition metric files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
required=True,
|
||||
help="Output file path for consolidated metrics JSON",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run-id",
|
||||
required=True,
|
||||
help="GitHub Actions run ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--commit-sha",
|
||||
required=True,
|
||||
help="Git commit SHA",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
default=None,
|
||||
help="Git branch name (optional)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
success = merge_metrics(
|
||||
input_dir=args.input_dir,
|
||||
output_file=args.output,
|
||||
run_id=args.run_id,
|
||||
commit_sha=args.commit_sha,
|
||||
branch=args.branch,
|
||||
)
|
||||
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
407
third_party/sglang/scripts/ci/utils/prevalidate_cached_models.py
vendored
Executable file
407
third_party/sglang/scripts/ci/utils/prevalidate_cached_models.py
vendored
Executable file
@@ -0,0 +1,407 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Pre-validate all cached HuggingFace models to provide detailed feedback.
|
||||
|
||||
This script runs once during CI initialization (in prepare_runner.sh) to:
|
||||
1. Scan snapshots in ~/.cache/huggingface/hub/ (with time/quantity limits)
|
||||
2. Validate completeness (config/tokenizer/weights)
|
||||
3. Output detailed failure reasons for debugging
|
||||
|
||||
NOTE: This script no longer writes shared validation markers. Each test run
|
||||
independently validates its cache using per-run markers to avoid cross-runner
|
||||
cache state pollution.
|
||||
"""
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Add python directory to path to import sglang modules
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
sys.path.insert(0, str(REPO_ROOT / "python"))
|
||||
|
||||
from sglang.srt.model_loader.ci_weight_validation import ( # noqa: E402
|
||||
_validate_diffusion_model,
|
||||
validate_cache_with_detailed_reason,
|
||||
)
|
||||
|
||||
# Limits to avoid spending too much time on validation
|
||||
MAX_VALIDATION_TIME_SECONDS = 300 # Max 5 minutes total
|
||||
|
||||
|
||||
def find_all_hf_snapshots():
|
||||
"""
|
||||
Find all HuggingFace snapshots in cache.
|
||||
|
||||
Returns:
|
||||
List of (model_name, snapshot_dir) tuples, sorted by mtime (newest first)
|
||||
"""
|
||||
hf_home = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
|
||||
hub_dir = os.path.join(hf_home, "hub")
|
||||
|
||||
if not os.path.isdir(hub_dir):
|
||||
print(f"HF hub directory not found: {hub_dir}")
|
||||
return []
|
||||
|
||||
snapshots = []
|
||||
|
||||
# Pattern: models--org--model/snapshots/hash
|
||||
for model_dir in glob.glob(os.path.join(hub_dir, "models--*")):
|
||||
# Extract model name from directory (models--org--model -> org/model)
|
||||
dir_name = os.path.basename(model_dir)
|
||||
if not dir_name.startswith("models--"):
|
||||
continue
|
||||
|
||||
# models--meta-llama--Llama-2-7b-hf -> meta-llama/Llama-2-7b-hf
|
||||
# Handle multi-part names: models--a--b--c -> a/b-c (join parts 1+ with /)
|
||||
parts = dir_name.split("--")
|
||||
if len(parts) < 3 or parts[0] != "models":
|
||||
# Invalid format, skip
|
||||
continue
|
||||
# Standard format: models--org--repo -> org/repo
|
||||
# Extended format: models--org--repo--extra -> org/repo-extra (join with -)
|
||||
model_name = parts[1] + "/" + "-".join(parts[2:])
|
||||
|
||||
snapshots_dir = os.path.join(model_dir, "snapshots")
|
||||
if not os.path.isdir(snapshots_dir):
|
||||
continue
|
||||
|
||||
# Find all snapshot hashes
|
||||
for snapshot_hash_dir in os.listdir(snapshots_dir):
|
||||
snapshot_path = os.path.join(snapshots_dir, snapshot_hash_dir)
|
||||
if os.path.isdir(snapshot_path):
|
||||
try:
|
||||
mtime = os.path.getmtime(snapshot_path)
|
||||
snapshots.append((model_name, snapshot_path, mtime))
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
# Sort by mtime (newest first) - prioritize recently used models
|
||||
snapshots.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
# Return without mtime
|
||||
return [(name, path) for name, path, _ in snapshots]
|
||||
|
||||
|
||||
def is_transformers_text_model(snapshot_dir):
|
||||
"""
|
||||
Check if a snapshot is a transformers text model.
|
||||
|
||||
Only excludes (returns False) for models with STRONG evidence of being
|
||||
diffusers/generation pipelines. Uses conservative heuristics to avoid
|
||||
false negatives on multimodal LLMs with tokenizers.
|
||||
|
||||
Args:
|
||||
snapshot_dir: Path to snapshot directory
|
||||
|
||||
Returns:
|
||||
True if this looks like a transformers text model, False otherwise (N/A)
|
||||
"""
|
||||
# Check for diffusers pipeline markers (strong evidence)
|
||||
diffusers_markers = [
|
||||
"model_index.json", # Diffusers pipeline config
|
||||
"scheduler", # Scheduler directory (diffusers)
|
||||
]
|
||||
if any(
|
||||
os.path.exists(os.path.join(snapshot_dir, marker))
|
||||
for marker in diffusers_markers
|
||||
):
|
||||
return False
|
||||
|
||||
config_path = os.path.join(snapshot_dir, "config.json")
|
||||
if not os.path.exists(config_path):
|
||||
# No config.json - likely not a transformers model
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Check for explicit diffusers/generation model types (conservative keywords)
|
||||
model_type = config.get("_class_name") or config.get("model_type")
|
||||
if model_type:
|
||||
model_type_lower = str(model_type).lower()
|
||||
# Only exclude clear diffusion/generation models
|
||||
if any(
|
||||
keyword in model_type_lower
|
||||
for keyword in [
|
||||
"diffusion",
|
||||
"unet",
|
||||
"vae",
|
||||
"controlnet",
|
||||
"stable-diffusion",
|
||||
"latent-diffusion",
|
||||
]
|
||||
):
|
||||
return False
|
||||
|
||||
# Check architectures for explicit generation/diffusion classes
|
||||
architectures = config.get("architectures", [])
|
||||
if architectures:
|
||||
arch_str = " ".join(architectures).lower()
|
||||
# Conservative: only exclude obvious diffusion/generation architectures
|
||||
# Use word boundaries to avoid false positives (e.g., "dit" in "conditional")
|
||||
for keyword in [
|
||||
"diffusion",
|
||||
"unet2d",
|
||||
"unet3d",
|
||||
"vaedecoder", # More specific than "vae"
|
||||
"vaeencoder",
|
||||
"controlnet",
|
||||
"autoencoder",
|
||||
"ditmodel", # Diffusion Transformer - use more specific pattern
|
||||
"pixart", # PixArt diffusion model
|
||||
]:
|
||||
if keyword in arch_str:
|
||||
return False
|
||||
|
||||
# Check for standalone vision encoder/image processor (no text component)
|
||||
# Only if model name explicitly indicates non-text usage
|
||||
model_name = config.get("_name_or_path", "").lower()
|
||||
|
||||
if any(
|
||||
keyword in model_name
|
||||
for keyword in [
|
||||
"image-edit-", # Pure image editing (e.g., Qwen-Image-Edit)
|
||||
"-image-editing",
|
||||
"dit-", # DiT generation models
|
||||
"pixart-", # PixArt generation models
|
||||
]
|
||||
):
|
||||
# Additional check: does it have tokenizer? If yes, might be multimodal LLM
|
||||
has_tokenizer = any(
|
||||
os.path.exists(os.path.join(snapshot_dir, fname))
|
||||
for fname in ["tokenizer.json", "tokenizer.model", "tiktoken.model"]
|
||||
)
|
||||
if not has_tokenizer:
|
||||
# Image-edit model without tokenizer -> likely pure vision pipeline
|
||||
return False
|
||||
|
||||
# Default: assume it's a transformers text/multimodal model
|
||||
# Even if it lacks tokenizer, let validation report the actual error
|
||||
# (better false positive than false negative for text models)
|
||||
return True
|
||||
|
||||
except (json.JSONDecodeError, OSError, KeyError):
|
||||
# Can't parse config - assume it's transformers and let validation report failure
|
||||
return True
|
||||
|
||||
|
||||
def scan_weight_files(snapshot_dir):
|
||||
"""
|
||||
Scan for weight files in a snapshot.
|
||||
|
||||
Returns:
|
||||
List of weight file paths, or empty list if scan fails
|
||||
"""
|
||||
weight_files = []
|
||||
|
||||
# First, look for index files
|
||||
index_patterns = ["*.safetensors.index.json", "pytorch_model.bin.index.json"]
|
||||
index_files = []
|
||||
for pattern in index_patterns:
|
||||
index_files.extend(glob.glob(os.path.join(snapshot_dir, pattern)))
|
||||
|
||||
# If we have safetensors index, collect shards from it
|
||||
for index_file in index_files:
|
||||
if index_file.endswith(".safetensors.index.json"):
|
||||
try:
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
index_data = json.load(f)
|
||||
weight_map = index_data.get("weight_map", {})
|
||||
for weight_file in set(weight_map.values()):
|
||||
weight_path = os.path.join(snapshot_dir, weight_file)
|
||||
if os.path.exists(weight_path):
|
||||
weight_files.append(weight_path)
|
||||
except Exception as e:
|
||||
print(
|
||||
f" Warning: Failed to parse index {os.path.basename(index_file)}: {e}"
|
||||
)
|
||||
|
||||
# If no index found or no shards from index, do recursive glob
|
||||
if not weight_files:
|
||||
matched = glob.glob(
|
||||
os.path.join(snapshot_dir, "**/*.safetensors"), recursive=True
|
||||
)
|
||||
MAX_WEIGHT_FILES = 1000
|
||||
if len(matched) > MAX_WEIGHT_FILES:
|
||||
print(
|
||||
f" Warning: Too many safetensors files ({len(matched)} > {MAX_WEIGHT_FILES})"
|
||||
)
|
||||
return []
|
||||
|
||||
for f in matched:
|
||||
if os.path.exists(f): # Filter out broken symlinks
|
||||
weight_files.append(f)
|
||||
|
||||
return weight_files
|
||||
|
||||
|
||||
def validate_snapshot(model_name, snapshot_dir, weight_files, validated_cache):
|
||||
"""
|
||||
Validate a snapshot and return detailed status.
|
||||
|
||||
Uses in-process cache to avoid duplicate validation within the same run.
|
||||
|
||||
Args:
|
||||
model_name: Model identifier
|
||||
snapshot_dir: Path to snapshot directory
|
||||
weight_files: List of weight files to validate
|
||||
validated_cache: Dict to track already-validated snapshots in this run
|
||||
|
||||
Returns:
|
||||
Tuple of (result, reason):
|
||||
- (True, None) if validation passed
|
||||
- (False, reason_str) if validation failed
|
||||
- (None, None) if skipped (already validated in this run)
|
||||
"""
|
||||
# Fast path: check in-process cache first
|
||||
if snapshot_dir in validated_cache:
|
||||
return None, None # Already validated in this run, skip
|
||||
|
||||
try:
|
||||
# Perform validation with detailed reason
|
||||
is_complete, reason = validate_cache_with_detailed_reason(
|
||||
snapshot_dir=snapshot_dir,
|
||||
weight_files=weight_files,
|
||||
model_name_or_path=model_name,
|
||||
)
|
||||
|
||||
# Cache result to avoid re-validation in this run
|
||||
validated_cache[snapshot_dir] = (is_complete, reason)
|
||||
|
||||
return is_complete, reason
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Validation raised exception: {e}"
|
||||
return False, error_msg
|
||||
|
||||
|
||||
def main():
|
||||
start_time = time.time()
|
||||
|
||||
print("=" * 70)
|
||||
print("CI_OFFLINE: Pre-validating cached HuggingFace models")
|
||||
print("=" * 70)
|
||||
print(f"Max time: {MAX_VALIDATION_TIME_SECONDS}s")
|
||||
print()
|
||||
|
||||
print("Scanning HuggingFace cache for models...")
|
||||
snapshots = find_all_hf_snapshots()
|
||||
|
||||
if not snapshots:
|
||||
print("No cached models found, skipping validation")
|
||||
print("=" * 70)
|
||||
return
|
||||
|
||||
print(f"Found {len(snapshots)} snapshot(s) in cache")
|
||||
print()
|
||||
|
||||
validated_count = 0
|
||||
failed_count = 0
|
||||
skipped_count = 0
|
||||
processed_count = 0
|
||||
|
||||
# In-process cache to avoid re-validating same snapshot in this run
|
||||
validated_cache = {}
|
||||
|
||||
for model_name, snapshot_dir in snapshots:
|
||||
# Check time limit
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > MAX_VALIDATION_TIME_SECONDS:
|
||||
print()
|
||||
print(
|
||||
f"Time limit reached ({elapsed:.1f}s > {MAX_VALIDATION_TIME_SECONDS}s)"
|
||||
)
|
||||
print(
|
||||
f"Stopping validation, {len(snapshots) - processed_count} snapshots remaining"
|
||||
)
|
||||
break
|
||||
|
||||
snapshot_hash = os.path.basename(snapshot_dir)
|
||||
print(
|
||||
f"[{processed_count + 1}/{len(snapshots)}] {model_name} ({snapshot_hash[:8]}...)"
|
||||
)
|
||||
processed_count += 1
|
||||
|
||||
# Determine model type by checking for model_index.json (diffusers pipeline marker)
|
||||
model_index_path = os.path.join(snapshot_dir, "model_index.json")
|
||||
is_diffusion_model = os.path.exists(model_index_path)
|
||||
|
||||
if is_diffusion_model:
|
||||
# This is a diffusers pipeline - use diffusion validation
|
||||
try:
|
||||
is_valid, reason = _validate_diffusion_model(snapshot_dir)
|
||||
|
||||
if is_valid:
|
||||
print(" PASS (diffusion) - Cache complete & valid")
|
||||
validated_count += 1
|
||||
else:
|
||||
print(f" FAIL (diffusion) - {reason}")
|
||||
failed_count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" FAIL (diffusion) - Validation raised exception: {e}")
|
||||
failed_count += 1
|
||||
|
||||
continue
|
||||
|
||||
# Transformers model - use standard validation
|
||||
# First check if this looks like a transformers text model
|
||||
if not is_transformers_text_model(snapshot_dir):
|
||||
# Not a recognized model type, skip
|
||||
print(
|
||||
" SKIP (unknown type) - Not a diffusers pipeline or transformers model"
|
||||
)
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
# Scan weight files
|
||||
weight_files = scan_weight_files(snapshot_dir)
|
||||
|
||||
if not weight_files:
|
||||
print(" SKIP (no weights) - empty or incomplete download")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
# Validate
|
||||
try:
|
||||
result, reason = validate_snapshot(
|
||||
model_name, snapshot_dir, weight_files, validated_cache
|
||||
)
|
||||
|
||||
if result is True:
|
||||
print(" PASS - Cache complete & valid")
|
||||
validated_count += 1
|
||||
elif result is False:
|
||||
# Print detailed failure reason
|
||||
if reason:
|
||||
print(f" FAIL (incomplete) - {reason}")
|
||||
else:
|
||||
print(" FAIL (incomplete) - cache validation failed")
|
||||
failed_count += 1
|
||||
else: # None (skipped)
|
||||
print(" SKIP (already validated in this run)")
|
||||
skipped_count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" FAIL (error) - Validation raised exception: {e}")
|
||||
failed_count += 1
|
||||
|
||||
elapsed_total = time.time() - start_time
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print(f"Validation summary (completed in {elapsed_total:.1f}s):")
|
||||
print(f" PASS (complete & valid): {validated_count}")
|
||||
print(f" FAIL (incomplete/corrupted): {failed_count}")
|
||||
print(f" SKIP (no weights/duplicate): {skipped_count}")
|
||||
print(f" Total processed: {processed_count}/{len(snapshots)}")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
517
third_party/sglang/scripts/ci/utils/publish_traces.py
vendored
Normal file
517
third_party/sglang/scripts/ci/utils/publish_traces.py
vendored
Normal file
@@ -0,0 +1,517 @@
|
||||
"""
|
||||
Publish performance traces to GitHub repository
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from urllib.error import HTTPError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
|
||||
def is_rate_limit_error(e):
|
||||
"""Check if an exception is a GitHub rate limit error (not permission error)"""
|
||||
if not isinstance(e, HTTPError):
|
||||
return False
|
||||
if e.code == 429:
|
||||
return True
|
||||
if e.code == 403:
|
||||
# 403 can be rate limit OR permission error - check the message
|
||||
error_body = getattr(e, "error_body", "")
|
||||
if isinstance(error_body, str):
|
||||
# Rate limit errors contain specific phrases
|
||||
rate_limit_phrases = [
|
||||
"rate limit",
|
||||
"abuse detection",
|
||||
"secondary rate limit",
|
||||
]
|
||||
return any(phrase in error_body.lower() for phrase in rate_limit_phrases)
|
||||
return False
|
||||
|
||||
|
||||
def is_permission_error(e):
|
||||
"""Check if an exception is a GitHub permission error"""
|
||||
if not isinstance(e, HTTPError) or e.code != 403:
|
||||
return False
|
||||
error_body = getattr(e, "error_body", "")
|
||||
if isinstance(error_body, str):
|
||||
permission_phrases = [
|
||||
"resource not accessible",
|
||||
"must have push access",
|
||||
"permission",
|
||||
"denied",
|
||||
]
|
||||
return any(phrase in error_body.lower() for phrase in permission_phrases)
|
||||
return False
|
||||
|
||||
|
||||
def make_github_request(url, token, method="GET", data=None):
|
||||
"""Make authenticated request to GitHub API"""
|
||||
headers = {
|
||||
"Accept": "application/vnd.github+json",
|
||||
"Authorization": f"Bearer {token}",
|
||||
# "User-Agent": "sglang-ci",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
}
|
||||
|
||||
if data:
|
||||
headers["Content-Type"] = "application/json"
|
||||
data = json.dumps(data).encode("utf-8")
|
||||
|
||||
req = Request(url, data=data, headers=headers, method=method)
|
||||
|
||||
try:
|
||||
with urlopen(req) as response:
|
||||
return response.read().decode("utf-8")
|
||||
except HTTPError as e:
|
||||
print(f"GitHub API request failed: {e}")
|
||||
try:
|
||||
error_body = e.read().decode("utf-8")
|
||||
print(f"Error response body: {error_body}")
|
||||
e.error_body = error_body # Attach for later inspection
|
||||
except Exception:
|
||||
e.error_body = ""
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"GitHub API request failed with a non-HTTP error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def verify_token_permissions(repo_owner, repo_name, token):
|
||||
"""Verify that the token has necessary permissions for the repository"""
|
||||
print("Verifying token permissions...")
|
||||
|
||||
checks = [
|
||||
(
|
||||
f"https://api.github.com/repos/{repo_owner}/{repo_name}", # Check if we can access the repository
|
||||
"Repository access verified",
|
||||
),
|
||||
(
|
||||
f"https://api.github.com/repos/{repo_owner}/{repo_name}/contents", # Check if we can read the repository contents
|
||||
"Repository contents access verified",
|
||||
),
|
||||
]
|
||||
|
||||
for url, success_message in checks:
|
||||
try:
|
||||
response = make_github_request(url, token)
|
||||
if success_message == "Repository access verified":
|
||||
repo_data = json.loads(response)
|
||||
print(f"{success_message}: {repo_data['full_name']}")
|
||||
else:
|
||||
print(success_message)
|
||||
except Exception as e:
|
||||
if is_rate_limit_error(e):
|
||||
warnings.warn(
|
||||
"GitHub API rate limit exceeded during token verification."
|
||||
)
|
||||
return "rate_limited"
|
||||
print(f"Failed to verify permissions for {url}: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_branch_sha(repo_owner, repo_name, branch, token):
|
||||
"""Get SHA of the branch head"""
|
||||
url = (
|
||||
f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/refs/heads/{branch}"
|
||||
)
|
||||
response = make_github_request(url, token)
|
||||
data = json.loads(response)
|
||||
return data["object"]["sha"]
|
||||
|
||||
|
||||
def get_tree_sha(repo_owner, repo_name, commit_sha, token):
|
||||
"""Get tree SHA from commit"""
|
||||
url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/commits/{commit_sha}"
|
||||
response = make_github_request(url, token)
|
||||
data = json.loads(response)
|
||||
return data["tree"]["sha"]
|
||||
|
||||
|
||||
def create_blob(repo_owner, repo_name, content, token, max_retries=3):
|
||||
"""Create a blob with file content"""
|
||||
url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/blobs"
|
||||
|
||||
# Encode content as base64 for GitHub API
|
||||
content_b64 = base64.b64encode(content).decode("utf-8")
|
||||
|
||||
data = {"content": content_b64, "encoding": "base64"}
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = make_github_request(url, token, method="POST", data=data)
|
||||
return json.loads(response)["sha"]
|
||||
except Exception as e:
|
||||
# Don't retry on rate limit errors - fail fast
|
||||
if is_rate_limit_error(e):
|
||||
raise
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = 2**attempt # Exponential backoff: 1s, 2s, 4s
|
||||
print(
|
||||
f"Blob creation failed (attempt {attempt + 1}/{max_retries}), retrying in {wait_time}s..."
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def create_blobs(repo_owner, repo_name, files, token):
|
||||
"""Create blobs for all files and return tree items with blob SHAs"""
|
||||
tree_items = []
|
||||
for i, (file_path, content) in enumerate(files):
|
||||
# Create blob first to get SHA
|
||||
blob_sha = create_blob(repo_owner, repo_name, content, token)
|
||||
tree_items.append(
|
||||
{
|
||||
"path": file_path,
|
||||
"mode": "100644",
|
||||
"type": "blob",
|
||||
"sha": blob_sha,
|
||||
}
|
||||
)
|
||||
# Progress indicator for large uploads
|
||||
if (i + 1) % 10 == 0 or (i + 1) == len(files):
|
||||
print(f"Created {i + 1}/{len(files)} blobs...")
|
||||
return tree_items
|
||||
|
||||
|
||||
def create_tree(repo_owner, repo_name, base_tree_sha, tree_items, token, max_retries=3):
|
||||
"""Create a new tree from pre-created blob SHAs"""
|
||||
url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/trees"
|
||||
|
||||
data = {"base_tree": base_tree_sha, "tree": tree_items}
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = make_github_request(url, token, method="POST", data=data)
|
||||
return json.loads(response)["sha"]
|
||||
except Exception as e:
|
||||
# Don't retry on rate limit errors - fail fast
|
||||
if is_rate_limit_error(e):
|
||||
raise
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = 2**attempt
|
||||
print(
|
||||
f"Tree creation failed (attempt {attempt + 1}/{max_retries}), retrying in {wait_time}s..."
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def create_commit(
|
||||
repo_owner, repo_name, tree_sha, parent_sha, message, token, max_retries=3
|
||||
):
|
||||
"""Create a new commit"""
|
||||
url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/commits"
|
||||
|
||||
data = {"tree": tree_sha, "parents": [parent_sha], "message": message}
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = make_github_request(url, token, method="POST", data=data)
|
||||
commit_sha = json.loads(response)["sha"]
|
||||
|
||||
# Verify the commit was actually created
|
||||
verify_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/commits/{commit_sha}"
|
||||
verify_response = make_github_request(verify_url, token)
|
||||
verify_data = json.loads(verify_response)
|
||||
if verify_data["sha"] != commit_sha:
|
||||
raise Exception(
|
||||
f"Commit verification failed: expected {commit_sha}, got {verify_data['sha']}"
|
||||
)
|
||||
|
||||
return commit_sha
|
||||
except Exception as e:
|
||||
# Don't retry on rate limit errors - fail fast
|
||||
if is_rate_limit_error(e):
|
||||
raise
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = 2**attempt
|
||||
print(
|
||||
f"Commit creation failed (attempt {attempt + 1}/{max_retries}), retrying in {wait_time}s..."
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def update_branch_ref(repo_owner, repo_name, branch, commit_sha, token, max_retries=3):
|
||||
"""Update branch reference to point to new commit"""
|
||||
url = (
|
||||
f"https://api.github.com/repos/{repo_owner}/{repo_name}/git/refs/heads/{branch}"
|
||||
)
|
||||
|
||||
data = {"sha": commit_sha}
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
make_github_request(url, token, method="PATCH", data=data)
|
||||
return
|
||||
except HTTPError as e:
|
||||
# Don't retry on rate limit errors - fail fast
|
||||
if is_rate_limit_error(e):
|
||||
raise
|
||||
|
||||
# Check if this is an "Object does not exist" error
|
||||
is_object_not_exist = False
|
||||
if hasattr(e, "error_body"):
|
||||
try:
|
||||
error_data = json.loads(e.error_body)
|
||||
if "Object does not exist" in error_data.get("message", ""):
|
||||
is_object_not_exist = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if is_object_not_exist and attempt < max_retries - 1:
|
||||
# This might be a transient consistency issue - wait and retry
|
||||
wait_time = 2**attempt
|
||||
print(
|
||||
f"Branch update failed with 'Object does not exist' (attempt {attempt + 1}/{max_retries}), waiting {wait_time}s for consistency..."
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
raise
|
||||
except Exception as e:
|
||||
# Don't retry on rate limit errors - fail fast
|
||||
if is_rate_limit_error(e):
|
||||
raise
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = 2**attempt
|
||||
print(
|
||||
f"Branch update failed (attempt {attempt + 1}/{max_retries}), retrying in {wait_time}s..."
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def copy_trace_files(source_dir, target_base_path):
|
||||
"""Copy trace files and return list of files to upload.
|
||||
|
||||
Only uploads traces from TP rank 0 to avoid duplicated data across tensor parallel ranks.
|
||||
"""
|
||||
files_to_upload = []
|
||||
|
||||
if not os.path.exists(source_dir):
|
||||
print(f"Warning: Traces directory {source_dir} does not exist")
|
||||
return files_to_upload
|
||||
|
||||
# Walk through source directory and find .json.gz files
|
||||
for root, dirs, files in os.walk(source_dir):
|
||||
for file in files:
|
||||
if file.endswith(".json.gz"):
|
||||
|
||||
# Only upload TP rank 0 traces to avoid duplicates across tensor parallel ranks
|
||||
if "TP-" in file and "TP-0" not in file:
|
||||
continue
|
||||
|
||||
source_file = os.path.join(root, file)
|
||||
# Calculate relative path from source_dir
|
||||
rel_path = os.path.relpath(source_file, source_dir)
|
||||
target_path = f"{target_base_path}/{rel_path}"
|
||||
|
||||
# Read file content
|
||||
with open(source_file, "rb") as f:
|
||||
content = f.read()
|
||||
|
||||
files_to_upload.append((target_path, content))
|
||||
|
||||
return files_to_upload
|
||||
|
||||
|
||||
def publish_traces(traces_dir, run_id, run_number):
|
||||
"""Publish traces from a single directory to GitHub repository in a single commit"""
|
||||
target_base_path = f"traces/{run_id}"
|
||||
files_to_upload = copy_trace_files(traces_dir, target_base_path)
|
||||
|
||||
if not files_to_upload:
|
||||
print("No trace files found to upload")
|
||||
return
|
||||
|
||||
print(f"Found {len(files_to_upload)} files to upload")
|
||||
publish_traces_from_files(files_to_upload, run_id, run_number)
|
||||
|
||||
|
||||
def publish_traces_from_files(files_to_upload, run_id, run_number):
|
||||
"""Publish pre-collected trace files to GitHub repository in a single commit"""
|
||||
# Get environment variables
|
||||
token = os.getenv("GITHUB_TOKEN")
|
||||
if not token:
|
||||
print("Error: GITHUB_TOKEN environment variable not set")
|
||||
sys.exit(1)
|
||||
|
||||
# Repository configuration
|
||||
repo_owner = "sglang-bot"
|
||||
repo_name = "sglang-ci-data"
|
||||
branch = "main"
|
||||
|
||||
# Verify token permissions before proceeding
|
||||
permission_check = verify_token_permissions(repo_owner, repo_name, token)
|
||||
if permission_check == "rate_limited":
|
||||
warnings.warn(
|
||||
"Skipping trace upload due to GitHub API rate limit. "
|
||||
"This is expected during high CI activity and does not indicate a test failure."
|
||||
)
|
||||
return
|
||||
elif not permission_check:
|
||||
print(
|
||||
"Token permission verification failed. Please check the token permissions."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
max_retries = 5
|
||||
retry_delay = 5 # seconds
|
||||
|
||||
# Create blobs once before retry loop to avoid re-uploading on failures
|
||||
try:
|
||||
tree_items = create_blobs(repo_owner, repo_name, files_to_upload, token)
|
||||
except Exception as e:
|
||||
# Check for rate limit errors during blob creation
|
||||
if is_rate_limit_error(e):
|
||||
warnings.warn(
|
||||
"GitHub API rate limit exceeded during blob creation. Skipping trace upload."
|
||||
)
|
||||
return
|
||||
# Check for permission errors - these should fail loudly
|
||||
if is_permission_error(e):
|
||||
print(
|
||||
f"ERROR: Token does not have write permission to {repo_owner}/{repo_name}. "
|
||||
"Please update the GH_PAT_FOR_NIGHTLY_CI_DATA secret with a token that has "
|
||||
"'contents: write' permission for the repository."
|
||||
)
|
||||
sys.exit(1)
|
||||
print(f"Failed to create blobs: {e}")
|
||||
raise
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Get current branch head
|
||||
branch_sha = get_branch_sha(repo_owner, repo_name, branch, token)
|
||||
print(f"Current branch head: {branch_sha}")
|
||||
|
||||
# Get current tree
|
||||
tree_sha = get_tree_sha(repo_owner, repo_name, branch_sha, token)
|
||||
print(f"Current tree SHA: {tree_sha}")
|
||||
|
||||
# Create new tree with pre-created blobs
|
||||
new_tree_sha = create_tree(
|
||||
repo_owner, repo_name, tree_sha, tree_items, token
|
||||
)
|
||||
print(f"Created new tree: {new_tree_sha}")
|
||||
|
||||
# Create commit
|
||||
commit_message = f"Nightly traces for run {run_id} at {run_number} ({len(files_to_upload)} files)"
|
||||
commit_sha = create_commit(
|
||||
repo_owner,
|
||||
repo_name,
|
||||
new_tree_sha,
|
||||
branch_sha,
|
||||
commit_message,
|
||||
token,
|
||||
)
|
||||
print(f"Created commit: {commit_sha}")
|
||||
|
||||
# Update branch reference
|
||||
update_branch_ref(repo_owner, repo_name, branch, commit_sha, token)
|
||||
print("Updated branch reference")
|
||||
|
||||
print("Successfully published all traces in a single commit")
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
# Check for retryable errors
|
||||
is_retryable = False
|
||||
error_type = "unknown"
|
||||
|
||||
if hasattr(e, "error_body"):
|
||||
if "Update is not a fast forward" in e.error_body:
|
||||
is_retryable = True
|
||||
error_type = "fast-forward conflict"
|
||||
elif "Object does not exist" in e.error_body:
|
||||
is_retryable = True
|
||||
error_type = "object consistency"
|
||||
|
||||
# Also retry on HTTP errors that might be transient
|
||||
if isinstance(e, HTTPError) and e.code in [422, 500, 502, 503, 504]:
|
||||
is_retryable = True
|
||||
error_type = f"HTTP {e.code}"
|
||||
|
||||
# Check for rate limit errors (non-fatal - just warn and skip)
|
||||
if is_rate_limit_error(e):
|
||||
warnings.warn("GitHub API rate limit exceeded. Skipping trace upload.")
|
||||
return
|
||||
|
||||
# Check for permission errors - these should fail loudly
|
||||
if is_permission_error(e):
|
||||
print(
|
||||
f"ERROR: Token does not have write permission to {repo_owner}/{repo_name}. "
|
||||
"Please update the GH_PAT_FOR_NIGHTLY_CI_DATA secret with a token that has "
|
||||
"'contents: write' permission for the repository."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if is_retryable and attempt < max_retries - 1:
|
||||
print(
|
||||
f"Attempt {attempt + 1}/{max_retries} failed ({error_type}). Retrying in {retry_delay} seconds..."
|
||||
)
|
||||
time.sleep(retry_delay)
|
||||
else:
|
||||
print(f"Failed to publish traces after {attempt + 1} attempts: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Publish performance traces to GitHub repository"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--traces-dir",
|
||||
type=str,
|
||||
action="append",
|
||||
dest="traces_dirs",
|
||||
required=True,
|
||||
help="Traces directory to publish (can be specified multiple times)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get environment variables
|
||||
run_id = os.getenv("GITHUB_RUN_ID", "test")
|
||||
run_number = os.getenv("GITHUB_RUN_NUMBER", "12345")
|
||||
|
||||
if not run_id or not run_number:
|
||||
print(
|
||||
"Error: GITHUB_RUN_ID and GITHUB_RUN_NUMBER environment variables must be set"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Collect trace files from all directories
|
||||
target_base_path = f"traces/{run_id}"
|
||||
all_files = []
|
||||
for traces_dir in args.traces_dirs:
|
||||
print(f"Processing traces from directory: {traces_dir}")
|
||||
files = copy_trace_files(traces_dir, target_base_path)
|
||||
all_files.extend(files)
|
||||
|
||||
if not all_files:
|
||||
print("No trace files found to upload across all directories")
|
||||
return
|
||||
|
||||
print(f"Found {len(all_files)} total files to upload")
|
||||
|
||||
# Publish all collected traces in a single commit
|
||||
publish_traces_from_files(all_files, run_id, run_number)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1748
third_party/sglang/scripts/ci/utils/query_job_status.py
vendored
Executable file
1748
third_party/sglang/scripts/ci/utils/query_job_status.py
vendored
Executable file
File diff suppressed because it is too large
Load Diff
527
third_party/sglang/scripts/ci/utils/runner_utilization_report.py
vendored
Executable file
527
third_party/sglang/scripts/ci/utils/runner_utilization_report.py
vendored
Executable file
@@ -0,0 +1,527 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Runner Utilization Report
|
||||
|
||||
Analyzes GitHub Actions job data to calculate runner utilization metrics.
|
||||
Reports idle time, active time, and utilization percentage per runner label.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
# Labels to skip when grouping runners (GitHub default labels)
|
||||
DEFAULT_LABELS_TO_IGNORE = {"self-hosted", "Linux", "X64", "ARM64"}
|
||||
GITHUB_HOSTED_LABELS = {"ubuntu-latest", "ubuntu-22.04", "ubuntu-24.04"}
|
||||
|
||||
|
||||
def run_gh_command(args: list[str]) -> dict:
|
||||
"""Run gh CLI command and return JSON result."""
|
||||
result = subprocess.run(
|
||||
["gh", "api"] + args,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise Exception(f"gh api failed: {result.stderr}")
|
||||
return json.loads(result.stdout)
|
||||
|
||||
|
||||
def get_workflow_runs(repo: str, hours: int = 24) -> list[dict]:
|
||||
"""Get workflow runs from the last N hours."""
|
||||
since = datetime.now(timezone.utc) - timedelta(hours=hours)
|
||||
|
||||
runs = []
|
||||
page = 1
|
||||
while True:
|
||||
data = run_gh_command(
|
||||
[
|
||||
f"repos/{repo}/actions/runs?per_page=100&page={page}",
|
||||
]
|
||||
)
|
||||
page_runs = data.get("workflow_runs", [])
|
||||
|
||||
# Filter by time
|
||||
for run in page_runs:
|
||||
created_at = parse_time(run.get("created_at"))
|
||||
if created_at and created_at >= since:
|
||||
runs.append(run)
|
||||
elif created_at and created_at < since:
|
||||
# Runs are ordered by created_at desc, so we can stop
|
||||
return runs
|
||||
|
||||
if len(page_runs) < 100:
|
||||
break
|
||||
page += 1
|
||||
if page > 20: # Safety limit
|
||||
break
|
||||
return runs
|
||||
|
||||
|
||||
def get_jobs_for_run(repo: str, run_id: int) -> list[dict]:
|
||||
"""Get all jobs for a workflow run."""
|
||||
jobs = []
|
||||
page = 1
|
||||
while True:
|
||||
data = run_gh_command(
|
||||
[
|
||||
f"repos/{repo}/actions/runs/{run_id}/jobs?per_page=100&page={page}",
|
||||
]
|
||||
)
|
||||
jobs.extend(data.get("jobs", []))
|
||||
if len(data.get("jobs", [])) < 100:
|
||||
break
|
||||
page += 1
|
||||
if page > 5: # Safety limit
|
||||
break
|
||||
return jobs
|
||||
|
||||
|
||||
def get_runners(repo: str, online_only: bool = True) -> list[dict]:
|
||||
"""Get all self-hosted runners with pagination. Returns empty if no permission."""
|
||||
try:
|
||||
all_runners = []
|
||||
page = 1
|
||||
while True:
|
||||
data = run_gh_command(
|
||||
[f"repos/{repo}/actions/runners?per_page=100&page={page}"]
|
||||
)
|
||||
runners = data.get("runners", [])
|
||||
all_runners.extend(runners)
|
||||
if len(runners) < 100:
|
||||
break
|
||||
page += 1
|
||||
if page > 10: # Safety limit
|
||||
break
|
||||
if online_only:
|
||||
all_runners = [r for r in all_runners if r.get("status") == "online"]
|
||||
return all_runners
|
||||
except Exception as e:
|
||||
print(f"Warning: Cannot access runners API (need admin): {e}")
|
||||
return []
|
||||
|
||||
|
||||
def parse_time(time_str: str) -> datetime:
|
||||
"""Parse ISO timestamp to datetime."""
|
||||
if not time_str:
|
||||
return None
|
||||
return datetime.fromisoformat(time_str.replace("Z", "+00:00"))
|
||||
|
||||
|
||||
# Known runner counts per label (fallback when API unavailable)
|
||||
KNOWN_RUNNER_COUNTS = {
|
||||
"1-gpu-5090": 16,
|
||||
"h200": 8,
|
||||
"h20": 4,
|
||||
"b200": 4,
|
||||
"amd": 8,
|
||||
"github-hosted": 20, # GitHub hosted runners (variable)
|
||||
"other": 10,
|
||||
}
|
||||
|
||||
|
||||
def calculate_concurrency_metrics(
|
||||
jobs: list[dict],
|
||||
window_start: datetime,
|
||||
window_end: datetime,
|
||||
num_runners: int,
|
||||
) -> dict:
|
||||
"""
|
||||
Calculate concurrency metrics using a sweep line algorithm.
|
||||
|
||||
Tracks:
|
||||
- Peak concurrent runners in use
|
||||
- Average concurrent runners over time
|
||||
- Time at saturation (all runners busy)
|
||||
- Queue depth when runners are saturated
|
||||
"""
|
||||
if not jobs:
|
||||
return {
|
||||
"peak_concurrent": 0,
|
||||
"avg_concurrent": 0.0,
|
||||
"saturation_seconds": 0,
|
||||
"saturation_pct": 0.0,
|
||||
"peak_queue": 0,
|
||||
}
|
||||
|
||||
window_seconds = (window_end - window_start).total_seconds()
|
||||
if window_seconds <= 0:
|
||||
return {
|
||||
"peak_concurrent": 0,
|
||||
"avg_concurrent": 0.0,
|
||||
"saturation_seconds": 0,
|
||||
"saturation_pct": 0.0,
|
||||
"peak_queue": 0,
|
||||
}
|
||||
|
||||
# Create events for running jobs: +1 at start, -1 at end
|
||||
running_events = []
|
||||
for job in jobs:
|
||||
start = job["start"]
|
||||
end = job["end"]
|
||||
# Clamp to window
|
||||
if end < window_start or start > window_end:
|
||||
continue
|
||||
clamped_start = max(start, window_start)
|
||||
clamped_end = min(end, window_end)
|
||||
running_events.append((clamped_start, 1, "start")) # +1 for start
|
||||
running_events.append((clamped_end, -1, "end")) # -1 for end
|
||||
|
||||
# Create events for queue tracking (jobs created but not started)
|
||||
queue_events = []
|
||||
for job in jobs:
|
||||
created_at = job.get("created_at")
|
||||
started_at = job["start"]
|
||||
if created_at and created_at < started_at:
|
||||
# Clamp to window
|
||||
if started_at < window_start or created_at > window_end:
|
||||
continue
|
||||
clamped_created = max(created_at, window_start)
|
||||
clamped_started = min(started_at, window_end)
|
||||
queue_events.append((clamped_created, 1, "queued"))
|
||||
queue_events.append((clamped_started, -1, "dequeued"))
|
||||
|
||||
# Sort running events: by time, then ends before starts at same time
|
||||
running_events.sort(key=lambda e: (e[0], e[1] == 1))
|
||||
|
||||
# Process running events to get concurrency metrics
|
||||
current_running = 0
|
||||
peak_running = 0
|
||||
prev_time = window_start
|
||||
total_running_seconds = 0.0
|
||||
saturation_seconds = 0.0
|
||||
|
||||
for event_time, delta, _ in running_events:
|
||||
# Accumulate time at previous concurrency level
|
||||
time_delta = (event_time - prev_time).total_seconds()
|
||||
if time_delta > 0:
|
||||
total_running_seconds += current_running * time_delta
|
||||
if current_running >= num_runners:
|
||||
saturation_seconds += time_delta
|
||||
|
||||
# Update concurrency
|
||||
current_running += delta
|
||||
peak_running = max(peak_running, current_running)
|
||||
prev_time = event_time
|
||||
|
||||
# Handle remaining time after last event
|
||||
if prev_time < window_end:
|
||||
time_delta = (window_end - prev_time).total_seconds()
|
||||
total_running_seconds += current_running * time_delta
|
||||
if current_running >= num_runners:
|
||||
saturation_seconds += time_delta
|
||||
|
||||
# Sort queue events and calculate peak queue depth
|
||||
queue_events.sort(key=lambda e: (e[0], e[1] == 1))
|
||||
current_queued = 0
|
||||
peak_queue = 0
|
||||
|
||||
for _, delta, _ in queue_events:
|
||||
current_queued += delta
|
||||
peak_queue = max(peak_queue, current_queued)
|
||||
|
||||
avg_concurrent = total_running_seconds / window_seconds if window_seconds > 0 else 0
|
||||
|
||||
return {
|
||||
"peak_concurrent": peak_running,
|
||||
"avg_concurrent": avg_concurrent,
|
||||
"saturation_seconds": saturation_seconds,
|
||||
"saturation_pct": (
|
||||
(saturation_seconds / window_seconds * 100) if window_seconds > 0 else 0
|
||||
),
|
||||
"peak_queue": peak_queue,
|
||||
}
|
||||
|
||||
|
||||
def calculate_utilization(repo: str, hours: int = 24, runner_filter: str = None):
|
||||
"""Calculate runner utilization metrics."""
|
||||
|
||||
print(f"Fetching workflow runs from last {hours} hours...")
|
||||
runs = get_workflow_runs(repo, hours)
|
||||
print(f"Found {len(runs)} workflow runs")
|
||||
|
||||
# Try to get online runners from API
|
||||
print("Fetching online runners...")
|
||||
runners = get_runners(repo, online_only=True)
|
||||
|
||||
# Build label -> set of online runner names from API
|
||||
api_label_runners = defaultdict(set)
|
||||
if runners:
|
||||
for runner in runners:
|
||||
for label in runner.get("labels", []):
|
||||
label_name = label.get("name", "")
|
||||
if label_name not in DEFAULT_LABELS_TO_IGNORE:
|
||||
api_label_runners[label_name].add(runner["name"])
|
||||
print(f"Got {len(runners)} online runners from API")
|
||||
else:
|
||||
print("No runner API access, will use observed runners from job data")
|
||||
|
||||
# Track runners seen in jobs (for labels not in API or when API unavailable)
|
||||
job_label_runners = defaultdict(set)
|
||||
label_jobs = defaultdict(list) # label -> list of job_info
|
||||
|
||||
# Fetch jobs for all runs in parallel
|
||||
total_runs = len(runs)
|
||||
print(f"Fetching jobs for {total_runs} runs in parallel...")
|
||||
|
||||
def fetch_jobs_for_run(run):
|
||||
"""Fetch jobs for a single run, returning (run_id, jobs) or (run_id, None) on error."""
|
||||
try:
|
||||
return (run["id"], get_jobs_for_run(repo, run["id"]))
|
||||
except Exception:
|
||||
return (run["id"], None)
|
||||
|
||||
all_jobs = []
|
||||
with ThreadPoolExecutor(max_workers=20) as executor:
|
||||
futures = [executor.submit(fetch_jobs_for_run, run) for run in runs]
|
||||
completed = 0
|
||||
for future in as_completed(futures):
|
||||
completed += 1
|
||||
if completed % 50 == 0:
|
||||
print(f"Fetched jobs for {completed}/{total_runs} runs...")
|
||||
run_id, jobs = future.result()
|
||||
if jobs:
|
||||
all_jobs.extend(jobs)
|
||||
|
||||
print(f"Processing {len(all_jobs)} jobs...")
|
||||
|
||||
for job in all_jobs:
|
||||
runner_name = job.get("runner_name")
|
||||
if not runner_name:
|
||||
continue
|
||||
|
||||
created_at = parse_time(job.get("created_at"))
|
||||
started_at = parse_time(job.get("started_at"))
|
||||
completed_at = parse_time(job.get("completed_at"))
|
||||
|
||||
if not started_at or not completed_at:
|
||||
continue
|
||||
|
||||
duration = (completed_at - started_at).total_seconds()
|
||||
queue_time = (started_at - created_at).total_seconds() if created_at else 0
|
||||
job_info = {
|
||||
"start": started_at,
|
||||
"end": completed_at,
|
||||
"created_at": created_at,
|
||||
"duration": duration,
|
||||
"queue_time": queue_time,
|
||||
"job_name": job["name"],
|
||||
"runner_name": runner_name,
|
||||
}
|
||||
|
||||
# Use job labels directly (available in job data)
|
||||
job_labels = job.get("labels", [])
|
||||
for label in job_labels:
|
||||
# Skip generic labels
|
||||
if label in DEFAULT_LABELS_TO_IGNORE | GITHUB_HOSTED_LABELS:
|
||||
continue
|
||||
job_label_runners[label].add(runner_name)
|
||||
label_jobs[label].append(job_info)
|
||||
|
||||
# Merge API runners and job-observed runners
|
||||
# Prefer API count (online runners) when available
|
||||
all_labels = set(api_label_runners.keys()) | set(job_label_runners.keys())
|
||||
|
||||
# Filter labels if specified
|
||||
if runner_filter:
|
||||
all_labels = {lbl for lbl in all_labels if runner_filter in lbl}
|
||||
|
||||
print(f"Tracking {len(all_labels)} runner labels: {sorted(all_labels)}")
|
||||
|
||||
# Calculate metrics per label
|
||||
window_seconds = hours * 3600
|
||||
window_end = datetime.now(timezone.utc)
|
||||
window_start = window_end - timedelta(hours=hours)
|
||||
|
||||
results = []
|
||||
|
||||
for label in sorted(all_labels):
|
||||
# Use API runner count if available, otherwise use job-observed count
|
||||
if label in api_label_runners and api_label_runners[label]:
|
||||
num_runners = len(api_label_runners[label])
|
||||
elif label in job_label_runners:
|
||||
num_runners = len(job_label_runners[label])
|
||||
else:
|
||||
num_runners = KNOWN_RUNNER_COUNTS.get(label, 1)
|
||||
|
||||
total_capacity_seconds = window_seconds * num_runners
|
||||
|
||||
jobs = label_jobs.get(label, [])
|
||||
total_active_seconds = sum(j["duration"] for j in jobs)
|
||||
|
||||
utilization = (
|
||||
(total_active_seconds / total_capacity_seconds * 100)
|
||||
if total_capacity_seconds > 0
|
||||
else 0
|
||||
)
|
||||
idle_seconds = total_capacity_seconds - total_active_seconds
|
||||
|
||||
# Calculate queue time metrics
|
||||
queue_times = [j["queue_time"] for j in jobs if j["queue_time"] > 0]
|
||||
avg_queue_time = sum(queue_times) / len(queue_times) if queue_times else 0
|
||||
max_queue_time = max(queue_times) if queue_times else 0
|
||||
|
||||
# Calculate concurrency metrics
|
||||
# First pass: get peak concurrent to determine effective capacity
|
||||
concurrency_initial = calculate_concurrency_metrics(
|
||||
jobs, window_start, window_end, num_runners
|
||||
)
|
||||
|
||||
# Use observed peak as effective capacity if lower than API count
|
||||
# This handles cases where not all runners are active all the time
|
||||
effective_runners = min(num_runners, concurrency_initial["peak_concurrent"])
|
||||
if effective_runners < num_runners and effective_runners > 0:
|
||||
# Recalculate with effective capacity for accurate saturation
|
||||
concurrency = calculate_concurrency_metrics(
|
||||
jobs, window_start, window_end, effective_runners
|
||||
)
|
||||
else:
|
||||
concurrency = concurrency_initial
|
||||
effective_runners = num_runners
|
||||
|
||||
results.append(
|
||||
{
|
||||
"label": label,
|
||||
"num_runners": num_runners,
|
||||
"effective_runners": effective_runners,
|
||||
"num_jobs": len(jobs),
|
||||
"total_active_hours": total_active_seconds / 3600,
|
||||
"total_idle_hours": idle_seconds / 3600,
|
||||
"total_capacity_hours": total_capacity_seconds / 3600,
|
||||
"utilization_pct": utilization,
|
||||
"avg_queue_min": avg_queue_time / 60,
|
||||
"max_queue_min": max_queue_time / 60,
|
||||
# Concurrency metrics
|
||||
"peak_concurrent": concurrency_initial["peak_concurrent"],
|
||||
"avg_concurrent": concurrency["avg_concurrent"],
|
||||
"saturation_hours": concurrency["saturation_seconds"] / 3600,
|
||||
"saturation_pct": concurrency["saturation_pct"],
|
||||
"peak_queue": concurrency["peak_queue"],
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def format_report(results: list[dict], hours: int) -> str:
|
||||
"""Format results as markdown report."""
|
||||
lines = [
|
||||
"# Runner Utilization Report",
|
||||
"",
|
||||
f"**Time window:** Last {hours} hours",
|
||||
f"**Generated:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M UTC')}",
|
||||
"",
|
||||
"## Concurrency Analysis",
|
||||
"",
|
||||
"| Label | Runners (API/Effective) | Peak Concurrent | Avg Concurrent | Saturation Time | Peak Queue |",
|
||||
"|-------|-------------------------|-----------------|----------------|-----------------|------------|",
|
||||
]
|
||||
|
||||
for r in results:
|
||||
effective = r["effective_runners"]
|
||||
avg_pct = (r["avg_concurrent"] / effective * 100) if effective > 0 else 0
|
||||
runner_str = (
|
||||
f"{r['num_runners']}/{effective}"
|
||||
if effective != r["num_runners"]
|
||||
else str(r["num_runners"])
|
||||
)
|
||||
lines.append(
|
||||
f"| {r['label']} | {runner_str} | "
|
||||
f"{r['peak_concurrent']} | "
|
||||
f"{r['avg_concurrent']:.1f} ({avg_pct:.0f}%) | "
|
||||
f"{r['saturation_hours']:.1f}h ({r['saturation_pct']:.0f}%) | "
|
||||
f"{r['peak_queue']} jobs |"
|
||||
)
|
||||
|
||||
# Add recommendations section
|
||||
lines.extend(["", "## Recommendations", ""])
|
||||
has_recommendations = False
|
||||
for r in results:
|
||||
label = r["label"]
|
||||
saturation_pct = r["saturation_pct"]
|
||||
peak_queue = r["peak_queue"]
|
||||
effective = r["effective_runners"]
|
||||
avg_pct = (r["avg_concurrent"] / effective * 100) if effective > 0 else 0
|
||||
|
||||
if saturation_pct > 50 or peak_queue > 5:
|
||||
lines.append(
|
||||
f"⚠️ **{label}**: High saturation ({saturation_pct:.0f}%) "
|
||||
f"with queue buildup ({peak_queue} jobs). Consider adding runners."
|
||||
)
|
||||
has_recommendations = True
|
||||
elif saturation_pct > 20 or peak_queue > 0:
|
||||
lines.append(
|
||||
f"📊 **{label}**: Moderate saturation ({saturation_pct:.0f}%), "
|
||||
f"peak queue {peak_queue} jobs. Monitor for trends."
|
||||
)
|
||||
has_recommendations = True
|
||||
elif avg_pct < 30 and r["num_jobs"] > 0:
|
||||
lines.append(
|
||||
f"💡 **{label}**: Low average utilization ({avg_pct:.0f}%). "
|
||||
f"Runner pool may be oversized."
|
||||
)
|
||||
has_recommendations = True
|
||||
else:
|
||||
lines.append(f"✓ **{label}**: Healthy utilization with minimal queueing.")
|
||||
|
||||
if not has_recommendations and results:
|
||||
lines.append("All runner pools have healthy utilization.")
|
||||
|
||||
# Add summary table
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
"## Summary by Runner Label",
|
||||
"",
|
||||
"| Label | Runners | Jobs | Active (hrs) | Utilization | Avg Queue | Max Queue |",
|
||||
"|-------|---------|------|--------------|-------------|-----------|-----------|",
|
||||
]
|
||||
)
|
||||
|
||||
for r in results:
|
||||
utilization_bar = "█" * int(r["utilization_pct"] / 10) + "░" * (
|
||||
10 - int(r["utilization_pct"] / 10)
|
||||
)
|
||||
lines.append(
|
||||
f"| {r['label']} | {r['num_runners']} | {r['num_jobs']} | "
|
||||
f"{r['total_active_hours']:.1f} | "
|
||||
f"{r['utilization_pct']:.1f}% {utilization_bar} | "
|
||||
f"{r['avg_queue_min']:.1f}m | {r['max_queue_min']:.1f}m |"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Generate runner utilization report")
|
||||
parser.add_argument("--repo", default="sgl-project/sglang", help="GitHub repo")
|
||||
parser.add_argument("--hours", type=int, default=24, help="Time window in hours")
|
||||
parser.add_argument(
|
||||
"--filter", type=str, help="Filter runner labels (e.g., '5090', 'h200')"
|
||||
)
|
||||
parser.add_argument("--output", type=str, help="Output file (default: stdout)")
|
||||
args = parser.parse_args()
|
||||
|
||||
results = calculate_utilization(args.repo, args.hours, args.filter)
|
||||
report = format_report(results, args.hours)
|
||||
|
||||
if args.output:
|
||||
with open(args.output, "w") as f:
|
||||
f.write(report)
|
||||
print(f"Report written to {args.output}")
|
||||
else:
|
||||
print(report)
|
||||
|
||||
# Also write to GITHUB_STEP_SUMMARY if available
|
||||
summary_file = os.environ.get("GITHUB_STEP_SUMMARY")
|
||||
if summary_file:
|
||||
with open(summary_file, "a") as f:
|
||||
f.write(report)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
245
third_party/sglang/scripts/ci/utils/save_metrics.py
vendored
Executable file
245
third_party/sglang/scripts/ci/utils/save_metrics.py
vendored
Executable file
@@ -0,0 +1,245 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Collect and save performance metrics from nightly benchmark results.
|
||||
|
||||
This script reads benchmark result JSON files from performance profile directories
|
||||
and saves them with metadata for artifact collection in CI.
|
||||
|
||||
Usage:
|
||||
python3 scripts/ci/utils/save_metrics.py \
|
||||
--gpu-config 8-gpu-h200 \
|
||||
--partition 0 \
|
||||
--run-id 12345678 \
|
||||
--output test/metrics-8gpu-h200-partition-0.json
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def find_result_files(search_dirs: list[str]) -> list[str]:
|
||||
"""Find all results_*.json files in the given directories."""
|
||||
result_files = set()
|
||||
for search_dir in search_dirs:
|
||||
if os.path.exists(search_dir):
|
||||
pattern = os.path.join(search_dir, "**/results_*.json")
|
||||
result_files.update(glob.glob(pattern, recursive=True))
|
||||
return list(result_files)
|
||||
|
||||
|
||||
def parse_result_file(filepath: str) -> list[dict]:
|
||||
"""Parse a benchmark result JSON file."""
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
return [data]
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"Warning: Failed to parse {filepath}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def transform_benchmark_result(result: dict, gpu_config: str, partition: int) -> dict:
|
||||
"""Transform a benchmark result to the metrics schema.
|
||||
|
||||
Note: input_len and output_len are preserved here for the flat benchmarks list,
|
||||
but are also used as grouping keys in benchmarks_by_io_len.
|
||||
"""
|
||||
# Handle None values safely for numeric conversions
|
||||
latency = result.get("latency")
|
||||
last_ttft = result.get("last_ttft")
|
||||
|
||||
return {
|
||||
"batch_size": result.get("batch_size"),
|
||||
"input_len": result.get("input_len"),
|
||||
"output_len": result.get("output_len"),
|
||||
"latency_ms": latency * 1000 if latency is not None else None,
|
||||
"input_throughput": result.get("input_throughput"),
|
||||
"output_throughput": result.get("output_throughput"),
|
||||
"overall_throughput": result.get("overall_throughput"),
|
||||
"ttft_ms": last_ttft * 1000 if last_ttft is not None else None,
|
||||
"acc_length": result.get("acc_length"),
|
||||
}
|
||||
|
||||
|
||||
def get_io_len_key(input_len: int, output_len: int) -> str:
|
||||
"""Generate a key for input/output length combination."""
|
||||
return f"{input_len}_{output_len}"
|
||||
|
||||
|
||||
def group_results_by_model(
|
||||
results: list[dict], gpu_config: str, partition: int
|
||||
) -> list[dict]:
|
||||
"""Group benchmark results by model, variant, and server_args.
|
||||
|
||||
Results are organized with two benchmark structures:
|
||||
- benchmarks: flat list of all benchmarks (for backward compatibility)
|
||||
- benchmarks_by_io_len: nested structure grouped by input/output length combinations
|
||||
"""
|
||||
groups = {}
|
||||
|
||||
for result in results:
|
||||
model_path = result.get("model_path", "unknown")
|
||||
run_name = result.get("run_name", "default")
|
||||
variant = run_name if run_name != "default" else None
|
||||
server_args = result.get("server_args")
|
||||
# Convert server_args list to tuple for use as dict key (lists are not hashable)
|
||||
server_args_key = tuple(server_args) if server_args else None
|
||||
|
||||
key = (model_path, variant, server_args_key)
|
||||
if key not in groups:
|
||||
groups[key] = {
|
||||
"gpu_config": gpu_config,
|
||||
"partition": partition,
|
||||
"model": model_path,
|
||||
"variant": variant,
|
||||
"server_args": server_args,
|
||||
"benchmarks": [],
|
||||
"benchmarks_by_io_len": {},
|
||||
}
|
||||
|
||||
transformed = transform_benchmark_result(result, gpu_config, partition)
|
||||
|
||||
# Add to flat benchmarks list (backward compatibility)
|
||||
groups[key]["benchmarks"].append(transformed)
|
||||
|
||||
# Add to nested benchmarks_by_io_len structure
|
||||
input_len = result.get("input_len")
|
||||
output_len = result.get("output_len")
|
||||
if input_len is not None and output_len is not None:
|
||||
io_key = get_io_len_key(input_len, output_len)
|
||||
if io_key not in groups[key]["benchmarks_by_io_len"]:
|
||||
groups[key]["benchmarks_by_io_len"][io_key] = {
|
||||
"input_len": input_len,
|
||||
"output_len": output_len,
|
||||
"benchmarks": [],
|
||||
}
|
||||
# For the nested structure, exclude input_len and output_len from individual benchmarks
|
||||
# since they're already in the parent
|
||||
nested_benchmark = {
|
||||
k: v
|
||||
for k, v in transformed.items()
|
||||
if k not in ("input_len", "output_len")
|
||||
}
|
||||
groups[key]["benchmarks_by_io_len"][io_key]["benchmarks"].append(
|
||||
nested_benchmark
|
||||
)
|
||||
|
||||
return list(groups.values())
|
||||
|
||||
|
||||
def save_metrics(
|
||||
gpu_config: str,
|
||||
partition: int,
|
||||
run_id: str,
|
||||
output_file: str,
|
||||
search_dirs: list[str],
|
||||
) -> bool:
|
||||
"""Collect metrics and save to output file."""
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# Find all result files
|
||||
result_files = find_result_files(search_dirs)
|
||||
print(f"Found {len(result_files)} result file(s)")
|
||||
|
||||
grouped = []
|
||||
if not result_files:
|
||||
print("No benchmark result files found")
|
||||
else:
|
||||
# Parse all result files
|
||||
all_results = []
|
||||
for filepath in sorted(result_files):
|
||||
print(f" Reading: {filepath}")
|
||||
results = parse_result_file(filepath)
|
||||
all_results.extend(results)
|
||||
print(f"Total benchmark results: {len(all_results)}")
|
||||
|
||||
# Group by model/variant
|
||||
grouped = group_results_by_model(all_results, gpu_config, partition)
|
||||
|
||||
# Create metrics structure
|
||||
metrics = {
|
||||
"run_id": run_id,
|
||||
"timestamp": timestamp,
|
||||
"gpu_config": gpu_config,
|
||||
"partition": partition,
|
||||
"results": grouped,
|
||||
}
|
||||
|
||||
# Ensure output directory exists and write output
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(metrics, f, indent=2)
|
||||
|
||||
if not result_files:
|
||||
print(f"Created empty metrics file: {output_file}")
|
||||
else:
|
||||
print(f"Saved metrics to: {output_file}")
|
||||
return True
|
||||
except OSError as e:
|
||||
print(f"Error writing metrics file: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Collect performance metrics from benchmark results"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-config",
|
||||
required=True,
|
||||
help="GPU configuration (e.g., 8-gpu-h200, 8-gpu-b200)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Partition number (0, 1, 2, etc.)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run-id",
|
||||
required=True,
|
||||
help="GitHub Actions run ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
required=True,
|
||||
help="Output file path for metrics JSON",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search-dir",
|
||||
action="append",
|
||||
default=[],
|
||||
dest="search_dirs",
|
||||
help="Directory to search for result files (can be specified multiple times)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Default search directories if none specified
|
||||
search_dirs = args.search_dirs or [
|
||||
"test/performance_profiles_8_gpu",
|
||||
"test/performance_profiles_text_models",
|
||||
"test/performance_profiles_vlms",
|
||||
"test",
|
||||
".",
|
||||
]
|
||||
|
||||
success = save_metrics(
|
||||
gpu_config=args.gpu_config,
|
||||
partition=args.partition,
|
||||
run_id=args.run_id,
|
||||
output_file=args.output,
|
||||
search_dirs=search_dirs,
|
||||
)
|
||||
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
870
third_party/sglang/scripts/ci/utils/slash_command_handler.py
vendored
Normal file
870
third_party/sglang/scripts/ci/utils/slash_command_handler.py
vendored
Normal file
@@ -0,0 +1,870 @@
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import requests
|
||||
from github import Auth, Github
|
||||
|
||||
# Configuration
|
||||
PERMISSIONS_FILE_PATH = ".github/CI_PERMISSIONS.json"
|
||||
|
||||
|
||||
def find_workflow_run_url(
|
||||
gh_repo,
|
||||
workflow_id,
|
||||
ref,
|
||||
target_stage,
|
||||
token,
|
||||
dispatch_time,
|
||||
pr_head_sha=None,
|
||||
max_wait=30,
|
||||
test_command=None,
|
||||
):
|
||||
"""
|
||||
Poll for the workflow run URL after dispatch.
|
||||
|
||||
Uses the dynamic run-name feature to identify runs:
|
||||
- Fork PRs: display_title = "[stage-name] sha"
|
||||
- Non-fork PRs: display_title = "[stage-name]"
|
||||
|
||||
Args:
|
||||
gh_repo: PyGithub repository object
|
||||
workflow_id: ID of the workflow that was dispatched
|
||||
ref: Branch/ref the workflow was dispatched on
|
||||
target_stage: The stage name we're looking for
|
||||
token: GitHub API token
|
||||
dispatch_time: Unix timestamp when dispatch was triggered
|
||||
pr_head_sha: PR head SHA (for fork PRs, used to match display_title)
|
||||
max_wait: Maximum seconds to wait for the run to appear
|
||||
|
||||
Returns:
|
||||
The workflow run URL if found, None otherwise.
|
||||
"""
|
||||
# Build expected display_title based on workflow's run-name.
|
||||
# rerun-test includes test_command: "[rerun-test] <test_command> [<sha>]"
|
||||
# Other workflows: "[stage-name] [<sha>]"
|
||||
suffix = f" {test_command}" if test_command else ""
|
||||
if pr_head_sha:
|
||||
expected_title = f"[{target_stage}]{suffix} {pr_head_sha}"
|
||||
else:
|
||||
expected_title = f"[{target_stage}]{suffix}"
|
||||
|
||||
print(f"Looking for workflow run with display_title: {expected_title}")
|
||||
|
||||
for attempt in range(max_wait // 5):
|
||||
time.sleep(5)
|
||||
|
||||
# Get recent workflow_dispatch runs for this workflow
|
||||
runs_url = f"https://api.github.com/repos/{gh_repo.full_name}/actions/workflows/{workflow_id}/runs"
|
||||
runs_resp = requests.get(
|
||||
runs_url,
|
||||
params={"event": "workflow_dispatch", "branch": ref, "per_page": 10},
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
},
|
||||
)
|
||||
|
||||
if runs_resp.status_code != 200:
|
||||
print(f"Failed to fetch workflow runs: {runs_resp.status_code}")
|
||||
continue
|
||||
|
||||
for run in runs_resp.json().get("workflow_runs", []):
|
||||
# Skip runs created before our dispatch (with 10s tolerance)
|
||||
run_created = datetime.fromisoformat(
|
||||
run["created_at"].replace("Z", "+00:00")
|
||||
).timestamp()
|
||||
if run_created < dispatch_time - 10:
|
||||
continue
|
||||
|
||||
# Match by display_title (set by workflow's run-name directive)
|
||||
# This is immediately available, unlike job names which require waiting
|
||||
display_title = run.get("display_title", "")
|
||||
if display_title == expected_title:
|
||||
print(
|
||||
f"Found matching workflow run: {run['id']} with title '{display_title}'"
|
||||
)
|
||||
return run["html_url"]
|
||||
|
||||
print(f"Could not find workflow run after {max_wait} seconds")
|
||||
return None
|
||||
|
||||
|
||||
def get_env_var(name):
|
||||
val = os.getenv(name)
|
||||
if not val:
|
||||
print(f"Error: Environment variable {name} not set.")
|
||||
sys.exit(1)
|
||||
return val
|
||||
|
||||
|
||||
def load_permissions(user_login):
|
||||
"""
|
||||
Reads the permissions JSON from the local file system and returns
|
||||
the permissions dict for the specific user.
|
||||
"""
|
||||
try:
|
||||
print(f"Loading permissions from {PERMISSIONS_FILE_PATH}...")
|
||||
if not os.path.exists(PERMISSIONS_FILE_PATH):
|
||||
print(f"Error: Permissions file not found at {PERMISSIONS_FILE_PATH}")
|
||||
return None
|
||||
|
||||
with open(PERMISSIONS_FILE_PATH, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
user_perms = data.get(user_login)
|
||||
|
||||
if not user_perms:
|
||||
print(f"User '{user_login}' not found in permissions file.")
|
||||
return None
|
||||
|
||||
return user_perms
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load or parse permissions file: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def has_sgl_kernel_changes(pr):
|
||||
"""
|
||||
Check if the PR has changes to the sgl-kernel directory.
|
||||
This is used to determine if we need a full workflow rerun
|
||||
(to rebuild the kernel) vs just rerunning failed jobs.
|
||||
"""
|
||||
try:
|
||||
files = pr.get_files()
|
||||
for f in files:
|
||||
if f.filename.startswith("sgl-kernel/"):
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not check PR files for sgl-kernel changes: {e}")
|
||||
# Default to False to avoid unnecessary full reruns
|
||||
return False
|
||||
|
||||
|
||||
def handle_tag_run_ci(gh_repo, pr, comment, user_perms, react_on_success=True):
|
||||
"""
|
||||
Handles the /tag-run-ci-label command.
|
||||
Returns True if action was taken, False otherwise.
|
||||
"""
|
||||
if not user_perms.get("can_tag_run_ci_label", False):
|
||||
print("Permission denied: can_tag_run_ci_label is false.")
|
||||
return False
|
||||
|
||||
print("Permission granted. Adding 'run-ci' label.")
|
||||
pr.add_to_labels("run-ci")
|
||||
|
||||
if react_on_success:
|
||||
comment.create_reaction("+1")
|
||||
print("Label added and comment reacted.")
|
||||
else:
|
||||
print("Label added (reaction suppressed).")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def handle_rerun_failed_ci(gh_repo, pr, comment, user_perms, react_on_success=True):
|
||||
"""
|
||||
Handles the /rerun-failed-ci command.
|
||||
Reruns workflows with 'failure' or 'skipped' conclusions.
|
||||
Returns True if action was taken, False otherwise.
|
||||
"""
|
||||
if not user_perms.get("can_rerun_failed_ci", False):
|
||||
print("Permission denied: can_rerun_failed_ci is false.")
|
||||
return False
|
||||
|
||||
print("Permission granted. Triggering rerun of failed or skipped workflows.")
|
||||
|
||||
# Check if PR has sgl-kernel changes - if so, we need full reruns
|
||||
# to ensure sgl-kernel-build-wheels runs and produces fresh artifacts
|
||||
sgl_kernel_changes = has_sgl_kernel_changes(pr)
|
||||
if sgl_kernel_changes:
|
||||
print("PR has sgl-kernel changes - will use full rerun to rebuild kernel")
|
||||
|
||||
# Get the SHA of the latest commit in the PR
|
||||
head_sha = pr.head.sha
|
||||
print(f"Checking workflows for commit: {head_sha}")
|
||||
|
||||
# List all workflow runs for this commit
|
||||
runs = gh_repo.get_workflow_runs(head_sha=head_sha)
|
||||
|
||||
rerun_count = 0
|
||||
for run in runs:
|
||||
if run.status != "completed":
|
||||
continue
|
||||
|
||||
if run.conclusion == "failure":
|
||||
print(f"Rerunning failed workflow: {run.name} (ID: {run.id})")
|
||||
try:
|
||||
if sgl_kernel_changes:
|
||||
# Full rerun to ensure sgl-kernel-build-wheels runs
|
||||
# and produces fresh artifacts for dependent jobs
|
||||
run.rerun()
|
||||
else:
|
||||
# Use rerun_failed_jobs for efficiency on failures
|
||||
run.rerun_failed_jobs()
|
||||
rerun_count += 1
|
||||
except Exception as e:
|
||||
print(f"Failed to rerun workflow {run.id}: {e}")
|
||||
|
||||
elif run.conclusion == "skipped":
|
||||
print(f"Rerunning skipped workflow: {run.name} (ID: {run.id})")
|
||||
try:
|
||||
# Skipped workflows don't have 'failed jobs', so we use full rerun()
|
||||
run.rerun()
|
||||
rerun_count += 1
|
||||
except Exception as e:
|
||||
print(f"Failed to rerun workflow {run.id}: {e}")
|
||||
|
||||
if rerun_count > 0:
|
||||
print(f"Triggered rerun for {rerun_count} workflows.")
|
||||
if react_on_success:
|
||||
comment.create_reaction("+1")
|
||||
return True
|
||||
else:
|
||||
print("No failed or skipped workflows found to rerun.")
|
||||
return False
|
||||
|
||||
|
||||
def handle_rerun_stage(
|
||||
gh_repo, pr, comment, user_perms, stage_name, token, react_on_success=True
|
||||
):
|
||||
"""
|
||||
Handles the /rerun-stage <stage-name> command.
|
||||
Triggers a workflow_dispatch to run only the specified stage, skipping dependencies.
|
||||
Returns True if action was taken, False otherwise.
|
||||
"""
|
||||
if not user_perms.get("can_rerun_stage", False):
|
||||
print("Permission denied: can_rerun_stage is false.")
|
||||
return False
|
||||
|
||||
if not stage_name:
|
||||
print("Error: No stage name provided")
|
||||
comment.create_reaction("confused")
|
||||
pr.create_issue_comment(
|
||||
f"❌ Please specify a stage name: `/rerun-stage <stage-name>`\n\n"
|
||||
f"Examples: `/rerun-stage unit-test-backend-4-gpu`, `/rerun-stage accuracy-test-1-gpu`"
|
||||
)
|
||||
return False
|
||||
|
||||
print(f"Permission granted. Triggering workflow_dispatch for stage '{stage_name}'.")
|
||||
|
||||
# Valid NVIDIA stage names that support target_stage
|
||||
nvidia_stages = [
|
||||
"stage-a-test-1-gpu-small",
|
||||
"stage-a-test-cpu",
|
||||
"stage-b-test-1-gpu-small",
|
||||
"stage-b-test-1-gpu-large",
|
||||
"stage-b-test-2-gpu-large",
|
||||
"stage-b-test-4-gpu-b200",
|
||||
"stage-c-test-4-gpu-h100",
|
||||
"stage-c-test-8-gpu-h200",
|
||||
"stage-c-test-8-gpu-h20",
|
||||
"stage-c-test-4-gpu-b200",
|
||||
"stage-c-test-4-gpu-gb200",
|
||||
"stage-c-test-deepep-4-gpu-h100",
|
||||
"stage-c-test-deepep-8-gpu-h200",
|
||||
"multimodal-gen-test-1-gpu",
|
||||
"multimodal-gen-test-2-gpu",
|
||||
"multimodal-gen-test-1-b200",
|
||||
]
|
||||
|
||||
# Valid AMD stage names that support target_stage
|
||||
amd_stages = [
|
||||
"sgl-kernel-unit-test-amd",
|
||||
"sgl-kernel-unit-test-2-gpu-amd",
|
||||
"stage-a-test-1-gpu-small-amd",
|
||||
"stage-b-test-1-gpu-small-amd",
|
||||
"stage-b-test-1-gpu-small-amd-nondeterministic",
|
||||
"stage-b-test-1-gpu-small-amd-mi35x",
|
||||
"stage-b-test-1-gpu-large-amd",
|
||||
"stage-b-test-2-gpu-large-amd",
|
||||
"multimodal-gen-test-1-gpu-amd",
|
||||
"multimodal-gen-test-2-gpu-amd",
|
||||
"stage-c-test-large-8-gpu-amd",
|
||||
"stage-c-test-large-8-gpu-amd-mi35x",
|
||||
]
|
||||
|
||||
valid_stages = nvidia_stages + amd_stages
|
||||
is_amd_stage = stage_name in amd_stages
|
||||
|
||||
if stage_name not in valid_stages:
|
||||
comment.create_reaction("confused")
|
||||
pr.create_issue_comment(
|
||||
f"❌ Stage `{stage_name}` doesn't support isolated runs yet.\n\n"
|
||||
f"**NVIDIA stages:**\n"
|
||||
+ "\n".join(f"- `{s}`" for s in nvidia_stages)
|
||||
+ "\n\n**AMD stages:**\n"
|
||||
+ "\n".join(f"- `{s}`" for s in amd_stages)
|
||||
+ "\n\nOther stages will be added soon. For now, use `/rerun-failed-ci` for those stages."
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
# Get the appropriate workflow based on stage type
|
||||
workflow_name = "PR Test (AMD)" if is_amd_stage else "PR Test"
|
||||
workflows = gh_repo.get_workflows()
|
||||
target_workflow = None
|
||||
for wf in workflows:
|
||||
if wf.name == workflow_name:
|
||||
target_workflow = wf
|
||||
break
|
||||
|
||||
if not target_workflow:
|
||||
print(f"Error: {workflow_name} workflow not found")
|
||||
return False
|
||||
|
||||
# Check if PR is from a fork by comparing repo owners
|
||||
# Handle case where fork repo may have been deleted (pr.head.repo is None)
|
||||
is_fork = (
|
||||
pr.head.repo is None or pr.head.repo.owner.login != gh_repo.owner.login
|
||||
)
|
||||
print(f"PR is from fork: {is_fork}")
|
||||
|
||||
# pr_head_sha is used for fork PRs (passed to workflow and used for URL lookup)
|
||||
pr_head_sha = None
|
||||
|
||||
if is_fork:
|
||||
# For fork PRs: dispatch on main and pass SHA as input
|
||||
# This is needed because fork branch names don't exist in the main repo
|
||||
ref = "main"
|
||||
pr_head_sha = pr.head.sha
|
||||
print(
|
||||
f"Triggering {workflow_name} workflow on ref: {ref}, PR head SHA: {pr_head_sha}"
|
||||
)
|
||||
if is_amd_stage:
|
||||
inputs = {
|
||||
"target_stage": stage_name,
|
||||
"pr_head_sha": pr_head_sha,
|
||||
}
|
||||
else:
|
||||
inputs = {
|
||||
"target_stage": stage_name,
|
||||
"pr_head_sha": pr_head_sha,
|
||||
}
|
||||
else:
|
||||
# For non-fork PRs: dispatch on the PR branch directly
|
||||
# This allows testing workflow changes before merge
|
||||
ref = pr.head.ref
|
||||
print(f"Triggering {workflow_name} workflow on branch: {ref}")
|
||||
if is_amd_stage:
|
||||
inputs = {"target_stage": stage_name}
|
||||
else:
|
||||
inputs = {"target_stage": stage_name}
|
||||
|
||||
# Record dispatch time before triggering
|
||||
dispatch_time = time.time()
|
||||
|
||||
# Use requests directly as PyGithub's create_dispatch only accepts HTTP 204
|
||||
dispatch_url = f"https://api.github.com/repos/{gh_repo.full_name}/actions/workflows/{target_workflow.id}/dispatches"
|
||||
dispatch_resp = requests.post(
|
||||
dispatch_url,
|
||||
json={"ref": ref, "inputs": inputs},
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
},
|
||||
)
|
||||
success = dispatch_resp.status_code in (200, 204)
|
||||
if not success:
|
||||
print(f"Dispatch failed: {dispatch_resp.status_code} {dispatch_resp.text}")
|
||||
|
||||
if success:
|
||||
print(f"Successfully triggered workflow for stage '{stage_name}'")
|
||||
if react_on_success:
|
||||
comment.create_reaction("+1")
|
||||
|
||||
run_url = find_workflow_run_url(
|
||||
gh_repo,
|
||||
target_workflow.id,
|
||||
ref,
|
||||
stage_name,
|
||||
token,
|
||||
dispatch_time,
|
||||
pr_head_sha=pr_head_sha,
|
||||
max_wait=30,
|
||||
)
|
||||
if run_url:
|
||||
pr.create_issue_comment(
|
||||
f"✅ Triggered `{stage_name}` to run independently"
|
||||
f" (skipping dependencies)."
|
||||
f" [View workflow run]({run_url})"
|
||||
)
|
||||
else:
|
||||
pr.create_issue_comment(
|
||||
f"✅ Triggered `{stage_name}` to run independently"
|
||||
f" (skipping dependencies).\n"
|
||||
f"⚠️ Could not retrieve workflow run URL. "
|
||||
f"Check the [Actions tab](https://github.com/{gh_repo.full_name}/actions) for progress."
|
||||
)
|
||||
return True
|
||||
else:
|
||||
print("Failed to trigger workflow_dispatch")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error triggering workflow_dispatch: {e}")
|
||||
comment.create_reaction("confused")
|
||||
pr.create_issue_comment(
|
||||
f"❌ Failed to trigger workflow: {str(e)}\n\n"
|
||||
f"Please check the logs or contact maintainers."
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
CUDA_SUITE_TO_RUNNER = {
|
||||
"stage-a-test-1-gpu-small": "1-gpu-5090",
|
||||
"stage-a-test-cpu": "ubuntu-latest",
|
||||
"stage-b-test-1-gpu-small": "1-gpu-5090",
|
||||
"stage-b-test-1-gpu-large": "1-gpu-h100",
|
||||
"stage-b-test-2-gpu-large": "2-gpu-h100",
|
||||
"stage-b-test-4-gpu-b200": "4-gpu-b200",
|
||||
"stage-c-test-4-gpu-h100": "4-gpu-h100",
|
||||
"stage-c-test-8-gpu-h200": "8-gpu-h200",
|
||||
"stage-c-test-8-gpu-h20": "8-gpu-h20",
|
||||
"stage-c-test-4-gpu-b200": "4-gpu-b200",
|
||||
"stage-c-test-deepep-4-gpu-h100": "4-gpu-h100",
|
||||
"stage-c-test-deepep-8-gpu-h200": "8-gpu-h200",
|
||||
}
|
||||
|
||||
DEEPEP_SUITES = {
|
||||
"stage-c-test-8-gpu-h20",
|
||||
"stage-c-test-deepep-4-gpu-h100",
|
||||
"stage-c-test-deepep-8-gpu-h200",
|
||||
}
|
||||
|
||||
|
||||
def resolve_test_file(file_part):
|
||||
"""
|
||||
Resolve a user-provided file path to a path relative to test/.
|
||||
|
||||
Supports:
|
||||
- Full path: test/registered/core/test_srt_endpoint.py
|
||||
- Relative to test/: registered/core/test_srt_endpoint.py
|
||||
- Bare filename: test_srt_endpoint.py (glob-matched, must be unique)
|
||||
|
||||
Returns (resolved_path, error_message). On success error_message is None.
|
||||
"""
|
||||
if file_part.startswith("test/"):
|
||||
file_part = file_part[len("test/") :]
|
||||
|
||||
if "/" not in file_part:
|
||||
matches = glob.glob(f"test/registered/**/{file_part}", recursive=True)
|
||||
if len(matches) == 0:
|
||||
return (
|
||||
None,
|
||||
f"No test file found matching `{file_part}` under `test/registered/`.",
|
||||
)
|
||||
if len(matches) > 1:
|
||||
match_list = "\n".join(f"- `{m}`" for m in sorted(matches))
|
||||
return None, (
|
||||
f"Ambiguous filename `{file_part}` — matched {len(matches)} files:\n\n"
|
||||
f"{match_list}\n\n"
|
||||
f"Please provide the full path, e.g. `/rerun-test {matches[0]}`"
|
||||
)
|
||||
return matches[0][len("test/") :], None
|
||||
|
||||
full_path = f"test/{file_part}"
|
||||
if not os.path.isfile(full_path):
|
||||
return None, f"File not found: `{full_path}`"
|
||||
return file_part, None
|
||||
|
||||
|
||||
def detect_suite(file_path_from_test):
|
||||
"""
|
||||
Read a test file and extract the suite from register_cuda_ci or register_cpu_ci.
|
||||
|
||||
Returns (suite_name, runner_label, use_deepep, is_cpu, error_message).
|
||||
"""
|
||||
full_path = f"test/{file_path_from_test}"
|
||||
with open(full_path, "r") as f:
|
||||
content = f.read()
|
||||
|
||||
# Try CUDA first
|
||||
match = re.search(
|
||||
r'^[^#\n]*register_cuda_ci\([^)]*suite\s*=\s*["\']([^"\']+)["\']',
|
||||
content,
|
||||
re.MULTILINE,
|
||||
)
|
||||
if match:
|
||||
suite = match.group(1)
|
||||
runner = CUDA_SUITE_TO_RUNNER.get(suite)
|
||||
if not runner:
|
||||
known = ", ".join(f"`{s}`" for s in sorted(CUDA_SUITE_TO_RUNNER))
|
||||
return (
|
||||
suite,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
(
|
||||
f"Unknown CUDA suite `{suite}` in `{full_path}`.\n\n"
|
||||
f"Known suites: {known}"
|
||||
),
|
||||
)
|
||||
use_deepep = suite in DEEPEP_SUITES
|
||||
return suite, runner, use_deepep, False, None
|
||||
|
||||
# Try CPU
|
||||
match = re.search(
|
||||
r'^[^#\n]*register_cpu_ci\([^)]*suite\s*=\s*["\']([^"\']+)["\']',
|
||||
content,
|
||||
re.MULTILINE,
|
||||
)
|
||||
if match:
|
||||
suite = match.group(1)
|
||||
return suite, "ubuntu-latest", False, True, None
|
||||
|
||||
return (
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
(
|
||||
f"No `register_cuda_ci()` or `register_cpu_ci()` found in `{full_path}`.\n\n"
|
||||
f"This file may not be a registered CI test."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_test_spec(test_spec):
|
||||
"""
|
||||
Resolve a single test spec into its components without dispatching.
|
||||
|
||||
Returns a dict with keys: spec, resolved_path, test_command, suite,
|
||||
runner_label, use_deepep, is_cpu, error.
|
||||
"""
|
||||
if "::" in test_spec:
|
||||
file_part, test_selector = test_spec.split("::", 1)
|
||||
else:
|
||||
file_part = test_spec
|
||||
test_selector = None
|
||||
|
||||
file_part = file_part.strip()
|
||||
if test_selector:
|
||||
test_selector = test_selector.strip()
|
||||
|
||||
resolved_path, err = resolve_test_file(file_part)
|
||||
if err:
|
||||
return {"spec": test_spec, "error": err}
|
||||
|
||||
suite, runner_label, use_deepep, is_cpu, err = detect_suite(resolved_path)
|
||||
if err:
|
||||
return {"spec": test_spec, "error": err}
|
||||
|
||||
test_command = resolved_path
|
||||
if test_selector:
|
||||
test_command = f"{resolved_path} {test_selector}"
|
||||
|
||||
print(
|
||||
f"Resolved: file={resolved_path}, selector={test_selector}, "
|
||||
f"suite={suite}, runner={runner_label}, deepep={use_deepep}, "
|
||||
f"cpu={is_cpu}, command='{test_command}'"
|
||||
)
|
||||
return {
|
||||
"spec": test_spec,
|
||||
"test_command": test_command,
|
||||
"suite": suite,
|
||||
"runner_label": runner_label,
|
||||
"use_deepep": use_deepep,
|
||||
"is_cpu": is_cpu,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
|
||||
def _dispatch_batch(gh_repo, pr, batch, token):
|
||||
"""
|
||||
Dispatch a single workflow run for a batch of resolved test specs
|
||||
that share the same (runner_label, use_deepep, is_cpu).
|
||||
|
||||
Returns a dict with keys: specs, success, test_commands, runner_label, run_url, error.
|
||||
"""
|
||||
test_commands = [r["test_command"] for r in batch]
|
||||
runner_label = batch[0]["runner_label"]
|
||||
use_deepep = batch[0]["use_deepep"]
|
||||
is_cpu = batch[0]["is_cpu"]
|
||||
|
||||
# Join multiple commands with newlines for the workflow to iterate over
|
||||
combined_command = "\n".join(test_commands)
|
||||
|
||||
try:
|
||||
workflow_name = "Rerun Test"
|
||||
workflows = gh_repo.get_workflows()
|
||||
target_workflow = None
|
||||
for wf in workflows:
|
||||
if wf.name == workflow_name:
|
||||
target_workflow = wf
|
||||
break
|
||||
|
||||
if not target_workflow:
|
||||
return {
|
||||
"specs": [r["spec"] for r in batch],
|
||||
"success": False,
|
||||
"error": f"{workflow_name} workflow not found",
|
||||
}
|
||||
|
||||
is_fork = (
|
||||
pr.head.repo is None or pr.head.repo.owner.login != gh_repo.owner.login
|
||||
)
|
||||
|
||||
pr_head_sha = None
|
||||
inputs = {
|
||||
"test_command": combined_command,
|
||||
"runner_label": runner_label,
|
||||
"use_deepep": str(use_deepep).lower(),
|
||||
"is_cpu": str(is_cpu).lower(),
|
||||
}
|
||||
if is_fork:
|
||||
ref = "main"
|
||||
pr_head_sha = pr.head.sha
|
||||
inputs["pr_head_sha"] = pr_head_sha
|
||||
else:
|
||||
ref = pr.head.ref
|
||||
|
||||
dispatch_time = time.time()
|
||||
|
||||
dispatch_url = f"https://api.github.com/repos/{gh_repo.full_name}/actions/workflows/{target_workflow.id}/dispatches"
|
||||
dispatch_resp = requests.post(
|
||||
dispatch_url,
|
||||
json={"ref": ref, "inputs": inputs},
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
},
|
||||
)
|
||||
success = dispatch_resp.status_code in (200, 204)
|
||||
if not success:
|
||||
print(f"Dispatch failed: {dispatch_resp.status_code} {dispatch_resp.text}")
|
||||
return {
|
||||
"specs": [r["spec"] for r in batch],
|
||||
"success": False,
|
||||
"error": f"Dispatch failed: {dispatch_resp.status_code}",
|
||||
}
|
||||
|
||||
print(f"Successfully triggered rerun-test: {combined_command}")
|
||||
|
||||
run_url = find_workflow_run_url(
|
||||
gh_repo,
|
||||
target_workflow.id,
|
||||
ref,
|
||||
"rerun-test",
|
||||
token,
|
||||
dispatch_time,
|
||||
pr_head_sha=pr_head_sha,
|
||||
max_wait=30,
|
||||
test_command=combined_command,
|
||||
)
|
||||
return {
|
||||
"specs": [r["spec"] for r in batch],
|
||||
"success": True,
|
||||
"test_commands": test_commands,
|
||||
"runner_label": runner_label,
|
||||
"run_url": run_url,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error triggering rerun-test for batch: {e}")
|
||||
return {
|
||||
"specs": [r["spec"] for r in batch],
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
def handle_rerun_test(gh_repo, pr, comment, user_perms, test_specs, token):
|
||||
"""
|
||||
Handles the /rerun-test command. Resolves all test specs, groups them by
|
||||
(runner_label, use_deepep, is_cpu), and dispatches one workflow per group.
|
||||
"""
|
||||
# SECURITY: For fork PRs, only allow /rerun-test if the commenter has write+ permission.
|
||||
# This command checks out and executes code from the PR branch on self-hosted GPU
|
||||
# runners, so we must ensure the commenter is a trusted collaborator.
|
||||
is_fork = pr.head.repo is None or pr.head.repo.owner.login != gh_repo.owner.login
|
||||
if is_fork:
|
||||
commenter = comment.user.login
|
||||
perm = gh_repo.get_collaborator_permission(commenter)
|
||||
if perm not in ("admin", "write"):
|
||||
print(f"Permission denied: /rerun-test on fork PR by {commenter}.")
|
||||
comment.create_reaction("confused")
|
||||
pr.create_issue_comment(
|
||||
"❌ `/rerun-test` is not available for fork PRs unless the commenter "
|
||||
"has write permission on the repo.\n\n"
|
||||
"Please ask a maintainer to run this command, or use the normal CI flow."
|
||||
)
|
||||
return False
|
||||
print(f"Fork PR, but commenter {commenter} has write+ permission. Proceeding.")
|
||||
|
||||
if not (
|
||||
user_perms.get("can_rerun_test", False)
|
||||
or user_perms.get("can_rerun_stage", False)
|
||||
):
|
||||
print("Permission denied: neither can_rerun_test nor can_rerun_stage is true.")
|
||||
return False
|
||||
|
||||
if not test_specs:
|
||||
comment.create_reaction("confused")
|
||||
pr.create_issue_comment(
|
||||
"❌ Please specify a test: `/rerun-test <file>::<TestClass.test_method>`\n\n"
|
||||
"Examples:\n"
|
||||
"- `/rerun-test test/registered/core/test_srt_endpoint.py::TestSRTEndpoint.test_simple_decode`\n"
|
||||
"- `/rerun-test registered/core/test_srt_endpoint.py::TestSRTEndpoint`\n"
|
||||
"- `/rerun-test test_srt_endpoint.py`\n"
|
||||
"- `/rerun-test test_a.py test_b.py test_c.py` (multiple tests)"
|
||||
)
|
||||
return False
|
||||
|
||||
# Phase 1: Resolve all specs
|
||||
resolved = []
|
||||
resolve_failures = []
|
||||
for spec in test_specs:
|
||||
r = _resolve_test_spec(spec)
|
||||
if r.get("error"):
|
||||
resolve_failures.append(r)
|
||||
else:
|
||||
resolved.append(r)
|
||||
|
||||
# Phase 2: Group by (runner_label, use_deepep, is_cpu)
|
||||
groups = {}
|
||||
for r in resolved:
|
||||
key = (r["runner_label"], r["use_deepep"], r["is_cpu"])
|
||||
groups.setdefault(key, []).append(r)
|
||||
|
||||
# Phase 3: Dispatch one workflow per group
|
||||
dispatch_results = []
|
||||
for batch in groups.values():
|
||||
dispatch_results.append(_dispatch_batch(gh_repo, pr, batch, token))
|
||||
|
||||
# Build consolidated comment
|
||||
lines = []
|
||||
for dr in dispatch_results:
|
||||
if dr["success"]:
|
||||
cmds = "\n".join(
|
||||
f"cd test/ && python3 {cmd}" for cmd in dr["test_commands"]
|
||||
)
|
||||
if dr.get("run_url"):
|
||||
lines.append(
|
||||
f"✅ `{dr['runner_label']}` ({len(dr['test_commands'])} test{'s' if len(dr['test_commands']) > 1 else ''}): "
|
||||
f"[View workflow run]({dr['run_url']})\n"
|
||||
f"```\n{cmds}\n```"
|
||||
)
|
||||
else:
|
||||
lines.append(
|
||||
f"✅ `{dr['runner_label']}` ({len(dr['test_commands'])} test{'s' if len(dr['test_commands']) > 1 else ''}):\n"
|
||||
f"```\n{cmds}\n```\n"
|
||||
f"⚠️ Could not retrieve workflow run URL. "
|
||||
f"Check the [Actions tab](https://github.com/{gh_repo.full_name}/actions) for progress."
|
||||
)
|
||||
else:
|
||||
specs_str = ", ".join(f"`{s}`" for s in dr["specs"])
|
||||
lines.append(f"❌ {specs_str}: {dr['error']}")
|
||||
|
||||
for r in resolve_failures:
|
||||
lines.append(f"❌ `{r['spec']}`: {r['error']}")
|
||||
|
||||
body = "\n\n".join(lines)
|
||||
|
||||
successes = [dr for dr in dispatch_results if dr["success"]]
|
||||
if successes:
|
||||
comment.create_reaction("+1")
|
||||
if not successes and (resolve_failures or dispatch_results):
|
||||
comment.create_reaction("confused")
|
||||
|
||||
pr.create_issue_comment(body)
|
||||
return len(successes) > 0
|
||||
|
||||
|
||||
def main():
|
||||
# 1. Load Environment Variables
|
||||
token = get_env_var("GITHUB_TOKEN")
|
||||
repo_name = get_env_var("REPO_FULL_NAME")
|
||||
pr_number = int(get_env_var("PR_NUMBER"))
|
||||
comment_id = int(get_env_var("COMMENT_ID"))
|
||||
comment_body = get_env_var("COMMENT_BODY").strip()
|
||||
user_login = get_env_var("USER_LOGIN")
|
||||
|
||||
# 2. Load Permissions (local file check first to avoid unnecessary API calls)
|
||||
user_perms = load_permissions(user_login)
|
||||
|
||||
# 3. Initialize GitHub API with Auth
|
||||
auth = Auth.Token(token)
|
||||
g = Github(auth=auth)
|
||||
|
||||
repo = g.get_repo(repo_name)
|
||||
pr = repo.get_pull(pr_number)
|
||||
comment = repo.get_issue(pr_number).get_comment(comment_id)
|
||||
|
||||
# PR authors can always rerun failed CI and rerun individual UTs on their own PRs,
|
||||
# even if they are not listed in CI_PERMISSIONS.json.
|
||||
# Note: /tag-run-ci-label and /rerun-stage still require CI_PERMISSIONS.json.
|
||||
# Note: /rerun-test is blocked entirely for fork PRs in handle_rerun_test() itself.
|
||||
if pr.user.login == user_login:
|
||||
if user_perms is None:
|
||||
print(
|
||||
f"User {user_login} is the PR author (not in CI_PERMISSIONS.json). "
|
||||
"Granting CI rerun permissions."
|
||||
)
|
||||
user_perms = {}
|
||||
else:
|
||||
print(
|
||||
f"User {user_login} is the PR author and has existing CI permissions."
|
||||
)
|
||||
user_perms["can_rerun_failed_ci"] = True
|
||||
user_perms["can_rerun_test"] = True
|
||||
|
||||
if not user_perms:
|
||||
print(f"User {user_login} does not have any configured permissions. Exiting.")
|
||||
return
|
||||
|
||||
# 4. Parse Command and Execute
|
||||
first_line = comment_body.split("\n")[0].strip()
|
||||
|
||||
if first_line.startswith("/tag-run-ci-label"):
|
||||
handle_tag_run_ci(repo, pr, comment, user_perms)
|
||||
|
||||
elif first_line.startswith("/rerun-failed-ci"):
|
||||
handle_rerun_failed_ci(repo, pr, comment, user_perms)
|
||||
|
||||
elif first_line.startswith("/tag-and-rerun-ci"):
|
||||
# Perform both actions, but suppress individual reactions
|
||||
print("Processing combined command: /tag-and-rerun-ci")
|
||||
|
||||
tagged = handle_tag_run_ci(
|
||||
repo, pr, comment, user_perms, react_on_success=False
|
||||
)
|
||||
|
||||
# Wait for the label to propagate before triggering rerun
|
||||
if tagged:
|
||||
print("Waiting 5 seconds for label to propagate...")
|
||||
time.sleep(5)
|
||||
|
||||
rerun = handle_rerun_failed_ci(
|
||||
repo, pr, comment, user_perms, react_on_success=False
|
||||
)
|
||||
|
||||
# If at least one action was successful, add the reaction here
|
||||
if tagged or rerun:
|
||||
comment.create_reaction("+1")
|
||||
print("Combined command processed successfully; reaction added.")
|
||||
else:
|
||||
print("Combined command finished, but no actions were taken.")
|
||||
|
||||
elif first_line.startswith("/rerun-stage"):
|
||||
# Extract stage name from command
|
||||
parts = first_line.split(maxsplit=1)
|
||||
stage_name = parts[1].strip() if len(parts) > 1 else None
|
||||
handle_rerun_stage(repo, pr, comment, user_perms, stage_name, token)
|
||||
|
||||
elif first_line.startswith("/rerun-test"):
|
||||
test_specs = first_line.split()[1:]
|
||||
handle_rerun_test(repo, pr, comment, user_perms, test_specs or None, token)
|
||||
|
||||
else:
|
||||
print(f"Unknown or ignored command: {first_line}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
55
third_party/sglang/scripts/ci_monitor/README.md
vendored
Normal file
55
third_party/sglang/scripts/ci_monitor/README.md
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
# SGLang CI failure monitoring
|
||||
|
||||
Scripts used by [.github/workflows/ci-failure-monitor.yml](../../.github/workflows/ci-failure-monitor.yml): scheduled failure analysis and optional Slack notifications.
|
||||
|
||||
## Tools
|
||||
|
||||
1. **Failures Analyzer** (`ci_failures_analysis.py`): Tracks consecutive failures, identifies flaky jobs, and monitors runner health across PR Test / Nightly workflows (Nvidia, AMD, Intel, XPU, NPU).
|
||||
|
||||
2. **Slack poster** (`post_ci_failures_to_slack.py`): Sends a condensed summary from a failure-analysis JSON to Slack (invoked by the workflow when `SGLANG_DIFFUSION_SLACK_TOKEN` is set).
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install requests slack_sdk
|
||||
```
|
||||
|
||||
(`slack_sdk` is only required for `post_ci_failures_to_slack.py`.)
|
||||
|
||||
## Usage
|
||||
|
||||
### Failures Analyzer
|
||||
|
||||
```bash
|
||||
export GITHUB_TOKEN="your_token_here"
|
||||
|
||||
python ci_failures_analysis.py --token $GITHUB_TOKEN --limit 50 --threshold 2
|
||||
python ci_failures_analysis.py --token $GITHUB_TOKEN --limit 300 --threshold 2
|
||||
python ci_failures_analysis.py --token $GITHUB_TOKEN --limit 500 --threshold 3
|
||||
```
|
||||
|
||||
### Slack notifications
|
||||
|
||||
From the `scripts/ci_monitor` directory, after generating a report:
|
||||
|
||||
```bash
|
||||
export SGLANG_DIFFUSION_SLACK_TOKEN="xoxb-..."
|
||||
python post_ci_failures_to_slack.py --report-file ci_failure_analysis_YYYYMMDD_HHMMSS.json
|
||||
```
|
||||
|
||||
## Token permissions
|
||||
|
||||
The GitHub token needs `repo` and `workflow` scopes to read CI run data; otherwise API calls may return 404.
|
||||
|
||||
### Failures Analyzer parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `--token` | Required | GitHub Personal Access Token |
|
||||
| `--limit` | 500 | Number of workflow runs to analyze |
|
||||
| `--threshold` | 3 | Alert threshold for consecutive failures |
|
||||
| `--output` | None | Output JSON file (optional) |
|
||||
|
||||
## Historical note
|
||||
|
||||
The former **CI Monitor** workflow (`ci-monitor.yml`) and its analyzers (`ci_analyzer.py`, `ci_analyzer_perf.py`, `ci_analyzer_balance.py`) were removed as redundant; use this failure monitor workflow and scripts for ongoing CI health alerts.
|
||||
2750
third_party/sglang/scripts/ci_monitor/ci_failures_analysis.py
vendored
Normal file
2750
third_party/sglang/scripts/ci_monitor/ci_failures_analysis.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
274
third_party/sglang/scripts/ci_monitor/post_ci_failures_to_slack.py
vendored
Executable file
274
third_party/sglang/scripts/ci_monitor/post_ci_failures_to_slack.py
vendored
Executable file
@@ -0,0 +1,274 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Post CI failure analysis results to Slack.
|
||||
|
||||
This is a standalone script that doesn't depend on sglang package installation.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def post_ci_failures_to_slack(report_file: str) -> bool:
|
||||
"""
|
||||
Post CI failure report to Slack with threaded details.
|
||||
|
||||
Creates a parent message with summary (workflow: job1, job2, ...)
|
||||
and a threaded reply with detailed failure information.
|
||||
|
||||
Args:
|
||||
report_file: Path to JSON file containing failure analysis from ci_failures_analysis.py
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
from slack_sdk import WebClient
|
||||
|
||||
token = os.environ.get("SGLANG_DIFFUSION_SLACK_TOKEN")
|
||||
if not token:
|
||||
logger.info("Slack post failed: no token")
|
||||
return False
|
||||
|
||||
# CI failures channel
|
||||
channel_id = "C0A2DG0R7CJ"
|
||||
|
||||
# Get GitHub run ID for linking to the workflow run
|
||||
run_id = os.environ.get("GITHUB_RUN_ID", "")
|
||||
|
||||
# Load report data
|
||||
with open(report_file, "r") as f:
|
||||
report_data = json.load(f)
|
||||
|
||||
client = WebClient(token=token)
|
||||
|
||||
# Parse the real JSON structure
|
||||
# The JSON has workflow sections like "pr_test_nvidia_scheduled_data", "nightly_scheduled_data"
|
||||
# Each section contains jobs with their stats including "current_streak"
|
||||
|
||||
critical_failures = []
|
||||
|
||||
# Map workflow data keys to display names and hardware category
|
||||
# Format: (display_name, hardware, test_type_order)
|
||||
# test_type_order: 0 = PR Test, 1 = Nightly (so PR Test comes first)
|
||||
workflow_info_map = {
|
||||
# Nvidia
|
||||
"pr_test_nvidia_scheduled_data": ("PR Test", "Nvidia", 0),
|
||||
"nightly_nvidia_scheduled_data": ("Nightly", "Nvidia", 1),
|
||||
# AMD
|
||||
"pr_test_amd_scheduled_data": ("PR Test", "AMD", 0),
|
||||
"nightly_amd_scheduled_data": ("Nightly", "AMD", 1),
|
||||
# Intel/Xeon
|
||||
"pr_test_xeon_scheduled_data": ("PR Test", "Intel", 0),
|
||||
"nightly_intel_scheduled_data": ("Nightly", "Intel", 1),
|
||||
# XPU
|
||||
"pr_test_xpu_scheduled_data": ("PR Test", "XPU", 0),
|
||||
# NPU
|
||||
"pr_test_npu_scheduled_data": ("PR Test", "NPU", 0),
|
||||
"nightly_npu_scheduled_data": ("Nightly", "NPU", 1),
|
||||
}
|
||||
|
||||
# Hardware priority order (Nvidia first)
|
||||
hardware_order = ["Nvidia", "AMD", "Intel", "XPU", "NPU"]
|
||||
|
||||
# Iterate through each workflow section
|
||||
for workflow_key, workflow_data in report_data.items():
|
||||
# Skip non-workflow keys (summary, limits, etc.)
|
||||
if not isinstance(workflow_data, dict) or not any(
|
||||
isinstance(v, dict) and "current_streak" in v
|
||||
for v in workflow_data.values()
|
||||
):
|
||||
continue
|
||||
|
||||
# Only process scheduled workflows that are in our map
|
||||
if workflow_key not in workflow_info_map:
|
||||
continue
|
||||
|
||||
test_type, hardware, test_order = workflow_info_map[workflow_key]
|
||||
|
||||
# Check each job in this workflow
|
||||
for job_name, job_data in workflow_data.items():
|
||||
if not isinstance(job_data, dict):
|
||||
continue
|
||||
|
||||
current_streak = job_data.get("current_streak", 0)
|
||||
|
||||
# Filter for jobs with streak >= 2
|
||||
if current_streak >= 2:
|
||||
first_failure = job_data.get("first_failure_in_streak", {})
|
||||
last_failure = job_data.get("last_failure_in_streak", {})
|
||||
|
||||
critical_failures.append(
|
||||
{
|
||||
"hardware": hardware,
|
||||
"test_type": test_type,
|
||||
"test_order": test_order,
|
||||
"job_name": job_name,
|
||||
"consecutive_failures": current_streak,
|
||||
"first_failed_at": (
|
||||
first_failure.get("created_at", "unknown")
|
||||
if first_failure
|
||||
else "unknown"
|
||||
),
|
||||
"first_failed_url": (
|
||||
first_failure.get("job_url", "")
|
||||
if first_failure
|
||||
else ""
|
||||
),
|
||||
"last_failed_at": (
|
||||
last_failure.get("created_at", "unknown")
|
||||
if last_failure
|
||||
else "unknown"
|
||||
),
|
||||
"last_failed_url": (
|
||||
last_failure.get("job_url", "") if last_failure else ""
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Group by hardware, then by test type
|
||||
# Structure: {hardware: {test_type: [job_names]}}
|
||||
hardware_jobs = {}
|
||||
for job in critical_failures:
|
||||
hardware = job.get("hardware", "Unknown")
|
||||
test_type = job.get("test_type", "Unknown")
|
||||
job_name = job.get("job_name", "unknown")
|
||||
if hardware not in hardware_jobs:
|
||||
hardware_jobs[hardware] = {}
|
||||
if test_type not in hardware_jobs[hardware]:
|
||||
hardware_jobs[hardware][test_type] = []
|
||||
hardware_jobs[hardware][test_type].append(job_name)
|
||||
|
||||
# Create summary message
|
||||
workflow_url = ""
|
||||
if run_id:
|
||||
workflow_url = (
|
||||
f"https://github.com/sgl-project/sglang/actions/runs/{run_id}"
|
||||
)
|
||||
|
||||
if not hardware_jobs:
|
||||
summary = "✅ No critical failures detected in scheduled runs"
|
||||
if workflow_url:
|
||||
summary += f"\n<{workflow_url}|View CI Failure Monitor run>"
|
||||
color = "good"
|
||||
else:
|
||||
# Ping relevant people when there are failures
|
||||
mentions = "<@U09R55D8EAY> <@U09ABMCKQPM>"
|
||||
summary_lines = [f"{mentions} 🚨 *CI Critical Failures (Scheduled Runs)*"]
|
||||
|
||||
# Iterate in hardware priority order, with PR Test before Nightly
|
||||
test_type_order = ["PR Test", "Nightly"]
|
||||
for hardware in hardware_order:
|
||||
if hardware not in hardware_jobs:
|
||||
continue
|
||||
summary_lines.append(f"\n*{hardware}:*")
|
||||
for test_type in test_type_order:
|
||||
if test_type not in hardware_jobs[hardware]:
|
||||
continue
|
||||
jobs = hardware_jobs[hardware][test_type]
|
||||
job_list = ", ".join(jobs)
|
||||
summary_lines.append(f" • {test_type}: {job_list}")
|
||||
|
||||
if workflow_url:
|
||||
summary_lines.append(
|
||||
f"\n<{workflow_url}|View full CI Failure Monitor report>"
|
||||
)
|
||||
summary = "\n".join(summary_lines)
|
||||
color = "danger"
|
||||
|
||||
# Post parent message
|
||||
response = client.chat_postMessage(
|
||||
channel=channel_id,
|
||||
text=summary,
|
||||
attachments=[
|
||||
{
|
||||
"color": color,
|
||||
"footer": "SGLang CI Failure Monitor",
|
||||
"footer_icon": "https://github.githubassets.com/images/modules/logos_page/GitHub-Mark.png",
|
||||
"ts": int(datetime.now().timestamp()),
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
thread_ts = response["ts"]
|
||||
|
||||
# If there are failures, post detailed breakdown in thread
|
||||
if hardware_jobs:
|
||||
details_lines = ["*Detailed Failure Breakdown*\n"]
|
||||
|
||||
# Sort critical_failures by hardware order, then test_order
|
||||
hardware_order_map = {hw: i for i, hw in enumerate(hardware_order)}
|
||||
sorted_failures = sorted(
|
||||
critical_failures,
|
||||
key=lambda x: (
|
||||
hardware_order_map.get(x.get("hardware", ""), 99),
|
||||
x.get("test_order", 99),
|
||||
x.get("job_name", ""),
|
||||
),
|
||||
)
|
||||
|
||||
current_hardware = None
|
||||
for job in sorted_failures:
|
||||
hardware = job.get("hardware", "Unknown")
|
||||
test_type = job.get("test_type", "Unknown")
|
||||
job_name = job.get("job_name", "unknown")
|
||||
consecutive = job.get("consecutive_failures", 0)
|
||||
first_url = job.get("first_failed_url", "")
|
||||
first_at = job.get("first_failed_at", "unknown")
|
||||
last_url = job.get("last_failed_url", "")
|
||||
last_at = job.get("last_failed_at", "unknown")
|
||||
|
||||
# Add hardware section header
|
||||
if hardware != current_hardware:
|
||||
details_lines.append(f"\n*━━━ {hardware} ━━━*")
|
||||
current_hardware = hardware
|
||||
|
||||
details_lines.append(
|
||||
f"• *{test_type}* → `{job_name}`\n"
|
||||
f" Consecutive failures: {consecutive}\n"
|
||||
f" First failed: <{first_url}|{first_at}>\n"
|
||||
f" Last failed: <{last_url}|{last_at}>\n"
|
||||
)
|
||||
|
||||
details_text = "\n".join(details_lines)
|
||||
|
||||
client.chat_postMessage(
|
||||
channel=channel_id,
|
||||
thread_ts=thread_ts,
|
||||
text=details_text,
|
||||
)
|
||||
|
||||
logger.info("CI failure report posted to Slack successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to post CI failures to Slack: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Post CI failure analysis results to Slack"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report-file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to CI failure analysis JSON report",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
success = post_ci_failures_to_slack(args.report_file)
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
483
third_party/sglang/scripts/code_sync/check_commits.py
vendored
Normal file
483
third_party/sglang/scripts/code_sync/check_commits.py
vendored
Normal file
@@ -0,0 +1,483 @@
|
||||
"""
|
||||
List commits in the private repo that need to be synced to the OSS repo.
|
||||
|
||||
NOTE:
|
||||
1. This script resolves the git root automatically and can be run anywhere
|
||||
inside the repo.
|
||||
|
||||
This script will:
|
||||
1. Find the most recent sync commit (message starts with
|
||||
"[Automated PR] Copy OSS code from commit").
|
||||
2. Scan commits after that point and keep those that touch the configured paths.
|
||||
3. Compare added diff lines in relevant files against OSS main.
|
||||
4. Print a markdown summary with commit links and write it to GitHub Step Summary.
|
||||
|
||||
Usage:
|
||||
python3 scripts/code_sync/check_commits.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
# Allow sibling imports regardless of the working directory.
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from utils import ( # noqa: E402
|
||||
FOLDER_NAMES,
|
||||
get_last_sync_commit,
|
||||
write_github_step_summary,
|
||||
)
|
||||
|
||||
# --- Configuration Begin ---
|
||||
private_repo = "your-org/sglang-private-repo"
|
||||
oss_repo_url = "https://github.com/sgl-project/sglang.git"
|
||||
oss_repo_branch = "main"
|
||||
default_oss_repo_dir = ".oss_repo"
|
||||
# --- Configuration End ---
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommitInfo:
|
||||
commit_hash: str
|
||||
subject: str
|
||||
commit_date: str
|
||||
relevant_files: List[str]
|
||||
synced_lines: int
|
||||
total_added_lines: int
|
||||
|
||||
|
||||
def check_dependencies() -> None:
|
||||
"""Check for required command-line tools."""
|
||||
if not shutil.which("git"):
|
||||
raise EnvironmentError("git is not installed or not in PATH.")
|
||||
|
||||
|
||||
def get_repo_root() -> str:
|
||||
try:
|
||||
output = subprocess.run(
|
||||
["git", "rev-parse", "--show-toplevel"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
).stdout.strip()
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(f"Unable to determine git repo root: {e.stderr or e}") from e
|
||||
|
||||
if not output:
|
||||
raise RuntimeError("Unable to determine git repo root.")
|
||||
return os.path.abspath(output)
|
||||
|
||||
|
||||
def get_repo_from_origin(repo_root: str) -> str:
|
||||
"""Try to infer the repo slug (owner/name) from git remote.origin.url."""
|
||||
try:
|
||||
url = subprocess.run(
|
||||
["git", "config", "--get", "remote.origin.url"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
cwd=repo_root,
|
||||
).stdout.strip()
|
||||
except subprocess.CalledProcessError:
|
||||
return private_repo
|
||||
|
||||
if url.startswith("git@github.com:"):
|
||||
repo = url.split("git@github.com:", 1)[1]
|
||||
elif url.startswith("https://github.com/"):
|
||||
repo = url.split("https://github.com/", 1)[1]
|
||||
else:
|
||||
return private_repo
|
||||
|
||||
if repo.endswith(".git"):
|
||||
repo = repo[: -len(".git")]
|
||||
return repo or private_repo
|
||||
|
||||
|
||||
def get_default_oss_repo_path(repo_root: str) -> str:
|
||||
env_path = os.environ.get("OSS_REPO_PATH")
|
||||
if env_path:
|
||||
return os.path.abspath(env_path)
|
||||
return os.path.abspath(os.path.join(repo_root, default_oss_repo_dir))
|
||||
|
||||
|
||||
def ensure_oss_repo(oss_repo_path: str, repo_url: str, branch: str) -> str:
|
||||
oss_repo_path = os.path.abspath(oss_repo_path)
|
||||
if os.path.exists(oss_repo_path) and not os.path.isdir(oss_repo_path):
|
||||
raise RuntimeError(f"OSS repo path is not a directory: {oss_repo_path}")
|
||||
|
||||
if os.path.isdir(os.path.join(oss_repo_path, ".git")):
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "-C", oss_repo_path, "rev-parse", "--is-inside-work-tree"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(
|
||||
f"OSS repo path exists but is not a git repo: {oss_repo_path}"
|
||||
) from e
|
||||
|
||||
subprocess.run(
|
||||
["git", "-C", oss_repo_path, "fetch", "origin", branch, "--depth", "1"],
|
||||
check=True,
|
||||
)
|
||||
return oss_repo_path
|
||||
|
||||
parent_dir = os.path.dirname(oss_repo_path)
|
||||
if parent_dir and not os.path.isdir(parent_dir):
|
||||
os.makedirs(parent_dir, exist_ok=True)
|
||||
subprocess.run(
|
||||
["git", "clone", "--depth", "1", "--branch", branch, repo_url, oss_repo_path],
|
||||
check=True,
|
||||
)
|
||||
return oss_repo_path
|
||||
|
||||
|
||||
def get_commits_since(repo_root: str, last_sync_hash: Optional[str]) -> List[str]:
|
||||
"""Get commit hashes from last sync commit (exclusive) to HEAD."""
|
||||
try:
|
||||
if last_sync_hash:
|
||||
command = ["git", "rev-list", f"{last_sync_hash}..HEAD"]
|
||||
else:
|
||||
command = ["git", "rev-list", "HEAD"]
|
||||
result = subprocess.run(
|
||||
command, capture_output=True, text=True, check=True, cwd=repo_root
|
||||
).stdout.strip()
|
||||
return [line for line in result.split("\n") if line]
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error getting commit list: {e.stderr}")
|
||||
return []
|
||||
|
||||
|
||||
def get_changed_files(repo_root: str, commit_hash: str) -> List[str]:
|
||||
try:
|
||||
output = subprocess.run(
|
||||
["git", "diff-tree", "--no-commit-id", "--name-only", "-r", commit_hash],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
cwd=repo_root,
|
||||
).stdout.strip()
|
||||
return [line for line in output.split("\n") if line]
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error getting changed files for {commit_hash}: {e.stderr}")
|
||||
return []
|
||||
|
||||
|
||||
def is_relevant_path(changed_file: str, path_prefix: str) -> bool:
|
||||
if changed_file == path_prefix:
|
||||
return True
|
||||
return changed_file.startswith(f"{path_prefix}/")
|
||||
|
||||
|
||||
def get_relevant_files(changed_files: List[str]) -> List[str]:
|
||||
return [
|
||||
changed_file
|
||||
for changed_file in changed_files
|
||||
if any(is_relevant_path(changed_file, path) for path in FOLDER_NAMES)
|
||||
]
|
||||
|
||||
|
||||
def get_added_lines_by_file(
|
||||
repo_root: str, commit_hash: str, relevant_files: List[str]
|
||||
) -> Dict[str, List[str]]:
|
||||
if not relevant_files:
|
||||
return {}
|
||||
|
||||
command = [
|
||||
"git",
|
||||
"show",
|
||||
"--no-color",
|
||||
"--unified=0",
|
||||
"--format=",
|
||||
commit_hash,
|
||||
"--",
|
||||
] + relevant_files
|
||||
try:
|
||||
output = subprocess.run(
|
||||
command, capture_output=True, text=True, check=True, cwd=repo_root
|
||||
).stdout
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error getting diff for {commit_hash}: {e.stderr}")
|
||||
return {}
|
||||
|
||||
added_lines: Dict[str, List[str]] = {path: [] for path in relevant_files}
|
||||
relevant_set = set(relevant_files)
|
||||
current_file: Optional[str] = None
|
||||
for line in output.splitlines():
|
||||
if line.startswith("diff --git "):
|
||||
current_file = None
|
||||
continue
|
||||
if line.startswith("+++ "):
|
||||
file_path = None
|
||||
if line.startswith("+++ b/"):
|
||||
file_path = line[6:]
|
||||
else:
|
||||
candidate = line[4:]
|
||||
if candidate == "/dev/null":
|
||||
file_path = None
|
||||
elif candidate.startswith("b/") or candidate.startswith("a/"):
|
||||
file_path = candidate[2:]
|
||||
else:
|
||||
file_path = candidate
|
||||
|
||||
if file_path in relevant_set:
|
||||
current_file = file_path
|
||||
else:
|
||||
current_file = None
|
||||
continue
|
||||
|
||||
if current_file and line.startswith("+") and not line.startswith("+++ "):
|
||||
added_lines[current_file].append(line[1:])
|
||||
|
||||
return added_lines
|
||||
|
||||
|
||||
def get_oss_file_lines(
|
||||
oss_repo_path: str,
|
||||
oss_ref: str,
|
||||
file_path: str,
|
||||
cache: Dict[str, Optional[Set[str]]],
|
||||
) -> Optional[Set[str]]:
|
||||
if file_path in cache:
|
||||
return cache[file_path]
|
||||
try:
|
||||
output = subprocess.run(
|
||||
["git", "-C", oss_repo_path, "show", f"{oss_ref}:{file_path}"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
errors="replace",
|
||||
check=True,
|
||||
).stdout
|
||||
except subprocess.CalledProcessError:
|
||||
cache[file_path] = None
|
||||
return None
|
||||
|
||||
lines = output.splitlines()
|
||||
line_set = set(lines)
|
||||
cache[file_path] = line_set
|
||||
return line_set
|
||||
|
||||
|
||||
def count_synced_lines(
|
||||
added_lines_by_file: Dict[str, List[str]],
|
||||
oss_repo_path: str,
|
||||
oss_ref: str,
|
||||
oss_file_cache: Dict[str, Optional[Set[str]]],
|
||||
) -> Tuple[int, int]:
|
||||
total_added_lines = 0
|
||||
synced_lines = 0
|
||||
for file_path, lines in added_lines_by_file.items():
|
||||
total_added_lines += len(lines)
|
||||
if not lines:
|
||||
continue
|
||||
oss_lines = get_oss_file_lines(
|
||||
oss_repo_path, oss_ref, file_path, oss_file_cache
|
||||
)
|
||||
if not oss_lines:
|
||||
continue
|
||||
for line in lines:
|
||||
if line in oss_lines:
|
||||
synced_lines += 1
|
||||
return synced_lines, total_added_lines
|
||||
|
||||
|
||||
def get_commit_summary(repo_root: str, commit_hash: str) -> Tuple[str, str]:
|
||||
"""Return (subject, date) for a commit."""
|
||||
try:
|
||||
output = subprocess.run(
|
||||
["git", "show", "-s", "--format=%s%x00%ad", "--date=short", commit_hash],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
cwd=repo_root,
|
||||
).stdout.strip()
|
||||
subject, commit_date = output.split("\x00", 1)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error getting commit subject for {commit_hash}: {e.stderr}")
|
||||
subject = "(unknown subject)"
|
||||
commit_date = "(unknown date)"
|
||||
return subject, commit_date
|
||||
|
||||
|
||||
def format_files_list(relevant_files: List[str]) -> str:
|
||||
return "\n".join([f"- {file_path}" for file_path in relevant_files])
|
||||
|
||||
|
||||
def format_last_sync_block(
|
||||
repo: str, subject: str, commit_hash: str, commit_date: str
|
||||
) -> str:
|
||||
short_hash = commit_hash[:9]
|
||||
commit_url = f"https://github.com/{repo}/commit/{commit_hash}"
|
||||
return "\n".join(
|
||||
[
|
||||
"## Last sync",
|
||||
"",
|
||||
f"#### {subject}",
|
||||
f"date: {commit_date}",
|
||||
f"commit: [{short_hash}]({commit_url})",
|
||||
"",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def format_commit_block(
|
||||
repo: str,
|
||||
subject: str,
|
||||
commit_hash: str,
|
||||
commit_date: str,
|
||||
relevant_files: List[str],
|
||||
synced_lines: int,
|
||||
total_added_lines: int,
|
||||
) -> str:
|
||||
short_hash = commit_hash[:9]
|
||||
commit_url = f"https://github.com/{repo}/commit/{commit_hash}"
|
||||
files_str = format_files_list(relevant_files) if relevant_files else "- None"
|
||||
status_icon = "✅" if synced_lines == total_added_lines else "❌"
|
||||
status_line = (
|
||||
f"status: {status_icon} {synced_lines}/{total_added_lines} lines synced"
|
||||
)
|
||||
return "\n".join(
|
||||
[
|
||||
f"#### {subject}",
|
||||
status_line,
|
||||
f"date: {commit_date}",
|
||||
"files to sync:",
|
||||
files_str,
|
||||
"",
|
||||
f"commit: [{short_hash}]({commit_url})",
|
||||
"",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def format_output(
|
||||
repo: str,
|
||||
last_sync: Optional[Tuple[str, str, str]],
|
||||
commits: List[CommitInfo],
|
||||
) -> str:
|
||||
lines: List[str] = []
|
||||
if last_sync:
|
||||
subject, commit_hash, commit_date = last_sync
|
||||
lines.append(format_last_sync_block(repo, subject, commit_hash, commit_date))
|
||||
else:
|
||||
lines.extend(["## Last sync", "", "No sync commit found.", ""])
|
||||
|
||||
lines.extend(["## Commits to sync", ""])
|
||||
if not commits:
|
||||
lines.append("No commits need to be synced.")
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
for commit in commits:
|
||||
lines.append(
|
||||
format_commit_block(
|
||||
repo,
|
||||
commit.subject,
|
||||
commit.commit_hash,
|
||||
commit.commit_date,
|
||||
commit.relevant_files,
|
||||
commit.synced_lines,
|
||||
commit.total_added_lines,
|
||||
)
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="List commits in the private repo that need to be synced to OSS."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Limit number of commits printed (0 means no limit).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--oss-repo-path",
|
||||
default=None,
|
||||
help="Path to OSS repo clone (default: $OSS_REPO_PATH or .oss_repo).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--oss-repo-url",
|
||||
default=oss_repo_url,
|
||||
help="OSS repo URL (default: https://github.com/sgl-project/sglang.git).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--oss-branch",
|
||||
default=oss_repo_branch,
|
||||
help="OSS repo branch to check (default: main).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
check_dependencies()
|
||||
repo_root = get_repo_root()
|
||||
oss_repo_path = (
|
||||
os.path.abspath(args.oss_repo_path)
|
||||
if args.oss_repo_path
|
||||
else get_default_oss_repo_path(repo_root)
|
||||
)
|
||||
|
||||
repo = get_repo_from_origin(repo_root)
|
||||
last_sync_hash = get_last_sync_commit(repo_root)
|
||||
last_sync_block = None
|
||||
if last_sync_hash:
|
||||
last_sync_subject, last_sync_date = get_commit_summary(
|
||||
repo_root, last_sync_hash
|
||||
)
|
||||
last_sync_block = (last_sync_subject, last_sync_hash, last_sync_date)
|
||||
|
||||
commits = get_commits_since(repo_root, last_sync_hash)
|
||||
if args.limit > 0:
|
||||
commits = commits[: args.limit]
|
||||
|
||||
relevant_commit_inputs: List[Tuple[str, List[str]]] = []
|
||||
for commit_hash in commits:
|
||||
changed_files = get_changed_files(repo_root, commit_hash)
|
||||
if not changed_files:
|
||||
continue
|
||||
relevant_files = get_relevant_files(changed_files)
|
||||
if relevant_files:
|
||||
relevant_commit_inputs.append((commit_hash, relevant_files))
|
||||
|
||||
relevant_commits: List[CommitInfo] = []
|
||||
if relevant_commit_inputs:
|
||||
oss_repo_path = ensure_oss_repo(
|
||||
oss_repo_path, args.oss_repo_url, args.oss_branch
|
||||
)
|
||||
oss_ref = f"origin/{args.oss_branch}"
|
||||
oss_file_cache: Dict[str, Optional[Set[str]]] = {}
|
||||
for commit_hash, relevant_files in relevant_commit_inputs:
|
||||
subject, commit_date = get_commit_summary(repo_root, commit_hash)
|
||||
added_lines_by_file = get_added_lines_by_file(
|
||||
repo_root, commit_hash, relevant_files
|
||||
)
|
||||
synced_lines, total_added_lines = count_synced_lines(
|
||||
added_lines_by_file, oss_repo_path, oss_ref, oss_file_cache
|
||||
)
|
||||
relevant_commits.append(
|
||||
CommitInfo(
|
||||
commit_hash=commit_hash,
|
||||
subject=subject,
|
||||
commit_date=commit_date,
|
||||
relevant_files=relevant_files,
|
||||
synced_lines=synced_lines,
|
||||
total_added_lines=total_added_lines,
|
||||
)
|
||||
)
|
||||
|
||||
output = format_output(repo, last_sync_block, relevant_commits)
|
||||
print(output)
|
||||
if os.environ.get("GITHUB_STEP_SUMMARY"):
|
||||
write_github_step_summary(output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
273
third_party/sglang/scripts/code_sync/copy_from_oss.py
vendored
Normal file
273
third_party/sglang/scripts/code_sync/copy_from_oss.py
vendored
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
Sync code from OSS repo to the local repo and open a PR if changes exist.
|
||||
|
||||
NOTE:
|
||||
1. You need to execute this script in the git root folder.
|
||||
2. A GH_TOKEN environment variable is required to create the pull request.
|
||||
- see also https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens
|
||||
|
||||
This script will:
|
||||
1. Clone the sgl-project/sglang repository (or use a local copy).
|
||||
2. Sync specified files and directories using rsync.
|
||||
3. Check if the sync operation resulted in any changes.
|
||||
4. If there are changes:
|
||||
a. Create a new branch.
|
||||
b. Commit and push the changes.
|
||||
c. Open a pull request using the GitHub CLI (gh).
|
||||
|
||||
Usage:
|
||||
# Run the full sync and PR creation process
|
||||
python3 scripts/copy_from_oss.py
|
||||
|
||||
# Perform a dry run without making any actual changes
|
||||
python3 scripts/copy_from_oss.py --dry-run
|
||||
|
||||
# Use a local directory as the source instead of cloning
|
||||
python3 scripts/copy_from_oss.py --local-dir ~/projects/sglang
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
# Allow sibling imports regardless of the working directory.
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from utils import FOLDER_NAMES, write_github_step_summary # noqa: E402
|
||||
|
||||
# --- Configuration Begin ---
|
||||
private_repo = "your-org/sglang-private-repo"
|
||||
# --- Configuration End ---
|
||||
|
||||
|
||||
def check_dependencies():
|
||||
"""Check for required command-line tools."""
|
||||
if not shutil.which("git"):
|
||||
raise EnvironmentError("git is not installed or not in PATH.")
|
||||
if not shutil.which("gh"):
|
||||
raise EnvironmentError("GitHub CLI (gh) is not installed or not in PATH.")
|
||||
print("✅ All dependencies (git, gh) are available.")
|
||||
|
||||
|
||||
def checkout_main(dry_run):
|
||||
"""Checkout to the main branch."""
|
||||
commands = [
|
||||
"git checkout main",
|
||||
"git reset --hard",
|
||||
]
|
||||
for cmd in commands:
|
||||
print(f"Run: {cmd}")
|
||||
if not dry_run:
|
||||
try:
|
||||
subprocess.run(cmd, shell=True, check=True, capture_output=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Git command failed: {e.stderr.decode()}")
|
||||
raise
|
||||
print("✅ Checkout the main branch.")
|
||||
|
||||
|
||||
def get_source_folder(args):
|
||||
"""
|
||||
Prepare the source repository, either by cloning from GitHub or using a local directory.
|
||||
Returns the path to the source repo root, a temporary directory path (if created),
|
||||
and the short commit hash.
|
||||
"""
|
||||
temp_dir = None
|
||||
if args.local_dir:
|
||||
oss_root = os.path.expanduser(args.local_dir)
|
||||
if not os.path.exists(oss_root):
|
||||
raise FileNotFoundError(
|
||||
f"Specified local directory {oss_root} does not exist."
|
||||
)
|
||||
print(f"Using local directory as the source: {oss_root}")
|
||||
else:
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
oss_root = temp_dir
|
||||
print(f"Created temporary directory: {oss_root}")
|
||||
|
||||
repo_url = "https://github.com/sgl-project/sglang.git"
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"--single-branch",
|
||||
"--branch",
|
||||
"main",
|
||||
repo_url,
|
||||
temp_dir,
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
print(f"Successfully cloned repository to {temp_dir}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error cloning repository: {e.stderr.decode()}")
|
||||
raise
|
||||
|
||||
commit_hash = subprocess.run(
|
||||
["git", "-C", oss_root, "rev-parse", "HEAD"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
).stdout.strip()[:8]
|
||||
print(f"✅ Get source OSS code at commit: {commit_hash}")
|
||||
return oss_root, temp_dir, commit_hash
|
||||
|
||||
|
||||
def sync_directories(oss_root, sync_paths, dry_run):
|
||||
"""Sync specified directories from oss_root to current working directory."""
|
||||
rsync_commands = []
|
||||
for folder_name in sync_paths:
|
||||
target_name = f"{oss_root}/{folder_name}"
|
||||
src_name = "./" + "/".join(folder_name.split("/")[:-1])
|
||||
cmd = f"rsync -r --delete {target_name} {src_name}"
|
||||
rsync_commands.append(cmd)
|
||||
|
||||
for cmd in rsync_commands:
|
||||
try:
|
||||
print(f"Run: {cmd}")
|
||||
if not dry_run:
|
||||
subprocess.run(cmd, shell=True, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error executing command '{cmd}': {e}")
|
||||
raise
|
||||
print(f"✅ Sync all folders.")
|
||||
|
||||
|
||||
def check_for_changes():
|
||||
"""Check if there are any uncommitted git changes."""
|
||||
# This command exits with 1 if there are changes, 0 otherwise.
|
||||
result = subprocess.run(["git", "diff", "--quiet"])
|
||||
return result.returncode != 0
|
||||
|
||||
|
||||
def create_and_push_branch(branch_name, commit_message, dry_run):
|
||||
"""Create a new branch, commit all changes, and push to origin."""
|
||||
commands = [
|
||||
f"git checkout -b {branch_name}",
|
||||
"git config user.name 'github-actions[bot]'",
|
||||
"git config user.email 'github-actions[bot]@users.noreply.github.com'",
|
||||
"git add .",
|
||||
f"git commit -m '{commit_message}'",
|
||||
f"git push origin {branch_name} --force",
|
||||
]
|
||||
print("\nCreating and pushing git branch...")
|
||||
for cmd in commands:
|
||||
print(f"Run: {cmd}")
|
||||
if not dry_run:
|
||||
try:
|
||||
subprocess.run(cmd, shell=True, check=True, capture_output=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Git command failed: {e.stderr.decode()}")
|
||||
raise
|
||||
|
||||
|
||||
def create_pull_request(branch_name, title, body, dry_run):
|
||||
"""Create a pull request using the GitHub CLI."""
|
||||
gh_token = os.getenv("GH_TOKEN")
|
||||
if not gh_token:
|
||||
print(
|
||||
"\n⚠️ Warning: GH_TOKEN environment variable not set. Skipping PR creation."
|
||||
)
|
||||
if not dry_run:
|
||||
return
|
||||
|
||||
print("\nCreating pull request...")
|
||||
command = [
|
||||
"gh",
|
||||
"pr",
|
||||
"create",
|
||||
"--base",
|
||||
"main",
|
||||
"--head",
|
||||
branch_name,
|
||||
"--repo",
|
||||
private_repo,
|
||||
"--title",
|
||||
title,
|
||||
"--body",
|
||||
body,
|
||||
]
|
||||
print(f"Run: {' '.join(command)}")
|
||||
if not dry_run:
|
||||
env = os.environ.copy()
|
||||
env["GH_TOKEN"] = gh_token
|
||||
try:
|
||||
result = subprocess.run(
|
||||
command, check=True, capture_output=True, text=True, env=env
|
||||
)
|
||||
pr_url = result.stdout.strip()
|
||||
msg = f"✅ Successfully created pull request: {pr_url}"
|
||||
print(msg)
|
||||
write_github_step_summary(msg)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error creating pull request: {e.stderr}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Copy code from OSS and open a PR if changes are detected."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=str,
|
||||
help="Path to local SGLang directory to use instead of cloning from GitHub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Dry run the script without executing git, rsync, or gh commands.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
check_dependencies()
|
||||
checkout_main(args.dry_run)
|
||||
|
||||
oss_root, temp_dir, oss_commit = get_source_folder(args)
|
||||
|
||||
try:
|
||||
# Sync directories
|
||||
sync_directories(oss_root, FOLDER_NAMES, args.dry_run)
|
||||
|
||||
# Check for changes and create PR if necessary
|
||||
if not check_for_changes():
|
||||
msg = "😴 No changes detected. The code is already in sync."
|
||||
print(msg)
|
||||
write_github_step_summary(msg)
|
||||
return
|
||||
|
||||
print("✅ Changes detected. Proceeding to create a PR.")
|
||||
|
||||
current_date = datetime.datetime.now().strftime("%Y%m%d")
|
||||
branch_name = f"copy-from-oss-{oss_commit}-{current_date}"
|
||||
commit_message = f"Copy OSS code from {oss_commit} on {current_date}"
|
||||
pr_title = (
|
||||
f"[Automated PR] Copy OSS code from commit {oss_commit} on {current_date}"
|
||||
)
|
||||
pr_body = (
|
||||
f"Copy OSS code from https://github.com/sgl-project/sglang/commit/{oss_commit} on {current_date}."
|
||||
"\n\n---\n\n"
|
||||
"*This is an automated PR created by scripts/copy_from_oss.py.*"
|
||||
)
|
||||
|
||||
create_and_push_branch(branch_name, commit_message, args.dry_run)
|
||||
create_pull_request(branch_name, pr_title, pr_body, args.dry_run)
|
||||
|
||||
finally:
|
||||
# Remove temporary directory if it was created
|
||||
if temp_dir:
|
||||
try:
|
||||
shutil.rmtree(temp_dir)
|
||||
print(f"\nRemoved temporary directory: {temp_dir}")
|
||||
except OSError as e:
|
||||
print(f"Error removing temporary directory {temp_dir}: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
591
third_party/sglang/scripts/code_sync/copy_to_oss.py
vendored
Normal file
591
third_party/sglang/scripts/code_sync/copy_to_oss.py
vendored
Normal file
@@ -0,0 +1,591 @@
|
||||
"""
|
||||
Sync a specific commit from the local private repo to the OSS upstream and open a PR.
|
||||
|
||||
NOTE:
|
||||
1. You need to execute this script in the git root folder.
|
||||
2. A GH_TOKEN environment variable is required to create the pull request.
|
||||
- see also https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens
|
||||
|
||||
This script will:
|
||||
1. Take a commit hash as an argument (or use the latest commit by default).
|
||||
2. Create a patch for that commit.
|
||||
3. Filter the patch to only include changes in specified directories.
|
||||
4. Clone the sgl-project/sglang repository.
|
||||
5. Create a new branch in the OSS repo.
|
||||
6. Apply the filtered patch, commit, and force push.
|
||||
7. Open a pull request to the OSS repo using the GitHub CLI (gh).
|
||||
|
||||
Usage:
|
||||
# Sync the latest commit from the current branch
|
||||
python3 scripts/copy_to_oss.py
|
||||
|
||||
# Run the full sync and PR creation process for a given commit
|
||||
python3 scripts/copy_to_oss.py --commit <commit_hash>
|
||||
|
||||
# Perform a dry run without making any actual changes
|
||||
python3 scripts/copy_to_oss.py --commit <commit_hash> --dry-run
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
# Allow sibling imports regardless of the working directory.
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from utils import ( # noqa: E402
|
||||
FOLDER_NAMES,
|
||||
find_latest_oss_sync_commit,
|
||||
write_github_step_summary,
|
||||
)
|
||||
|
||||
|
||||
def get_commit_info(commit_ref):
|
||||
"""
|
||||
Retrieves the hash and message of a specific commit.
|
||||
|
||||
Args:
|
||||
commit_ref (str): The commit hash, tag, or branch to inspect (e.g., 'HEAD').
|
||||
|
||||
Returns:
|
||||
A tuple containing the (commit_hash, commit_message),
|
||||
or (None, None) if an error occurs.
|
||||
"""
|
||||
try:
|
||||
# Use a custom format to get the hash (%H) and the full message (%B)
|
||||
# separated by a null character for safe parsing.
|
||||
command = ["git", "log", "-1", f"--pretty=%H%x00%B", commit_ref]
|
||||
result = subprocess.run(
|
||||
command, capture_output=True, text=True, check=True, encoding="utf-8"
|
||||
)
|
||||
|
||||
# Split the output by the null character separator
|
||||
commit_hash, commit_message = result.stdout.strip().split("\x00", 1)
|
||||
return commit_hash, commit_message
|
||||
|
||||
except FileNotFoundError:
|
||||
print("❌ Error: 'git' command not found. Is Git installed and in your PATH?")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"❌ Error getting commit info for '{commit_ref}': {e.stderr.strip()}")
|
||||
print(
|
||||
"Hint: Make sure you are running this from within a Git repository and the commit exists."
|
||||
)
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def check_dependencies():
|
||||
"""Check for required command-line tools."""
|
||||
if not shutil.which("git"):
|
||||
raise EnvironmentError("git is not installed or not in PATH.")
|
||||
if not shutil.which("gh"):
|
||||
raise EnvironmentError("GitHub CLI (gh) is not installed or not in PATH.")
|
||||
print("✅ All dependencies (git, gh) are available.")
|
||||
|
||||
|
||||
def create_filtered_patch(commit_hash, dry_run):
|
||||
"""
|
||||
Create a patch file for the given commit, containing only changes
|
||||
to files and directories specified in `folder_names`.
|
||||
"""
|
||||
print(f"Creating a filtered patch for commit {commit_hash}")
|
||||
|
||||
try:
|
||||
# Get the list of all files changed in the commit
|
||||
changed_files_raw = subprocess.run(
|
||||
["git", "diff-tree", "--no-commit-id", "--name-only", "-r", commit_hash],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
).stdout
|
||||
changed_files = changed_files_raw.strip().split("\n")
|
||||
|
||||
# Filter the list of files
|
||||
relevant_files = [
|
||||
f for f in changed_files if any(f.startswith(path) for path in FOLDER_NAMES)
|
||||
]
|
||||
|
||||
if not relevant_files:
|
||||
msg = "\n😴 No relevant file changes found in this commit. Exiting."
|
||||
print(msg)
|
||||
write_github_step_summary(msg)
|
||||
return None, None
|
||||
|
||||
print("Found relevant changes in the following files:")
|
||||
for f in relevant_files:
|
||||
print(f" - {f}")
|
||||
|
||||
# Create a patch containing only the changes for the relevant files
|
||||
patch_command = [
|
||||
"git",
|
||||
"format-patch",
|
||||
"--stdout",
|
||||
f"{commit_hash}^..{commit_hash}",
|
||||
"--",
|
||||
] + relevant_files
|
||||
|
||||
print(f"Run: {' '.join(patch_command)}")
|
||||
|
||||
patch_content = subprocess.run(
|
||||
patch_command, capture_output=True, text=True, check=True
|
||||
).stdout
|
||||
|
||||
# Save the patch to a temporary file
|
||||
patch_file = tempfile.NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".patch", encoding="utf-8"
|
||||
)
|
||||
patch_file.write(patch_content)
|
||||
patch_file.close()
|
||||
|
||||
print(f"✅ Filtered patch created successfully at: {patch_file.name}")
|
||||
return patch_file.name, relevant_files
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error creating patch: {e.stderr}")
|
||||
raise
|
||||
|
||||
|
||||
def get_oss_repo(dry_run):
|
||||
"""
|
||||
Clones the OSS repository into a temporary directory.
|
||||
Returns the path to the repo root and the temp directory itself.
|
||||
"""
|
||||
gh_token = os.getenv("GH_TOKEN")
|
||||
if not gh_token:
|
||||
print(
|
||||
"⚠️ Warning: GH_TOKEN environment variable not set. Skipping PR creation."
|
||||
)
|
||||
if not dry_run:
|
||||
return
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
oss_root = os.path.join(temp_dir, "sglang")
|
||||
print(f"\nCreated temporary directory for OSS repo: {temp_dir}")
|
||||
|
||||
repo_url = f"https://{gh_token}@github.com/sgl-project/sglang.git"
|
||||
command = ["git", "clone", repo_url, oss_root]
|
||||
|
||||
print(f"Run: {' '.join(command)}")
|
||||
if not dry_run:
|
||||
try:
|
||||
subprocess.run(command, check=True, capture_output=True)
|
||||
print(f"✅ Successfully cloned repository to {oss_root}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error cloning repository: {e.stderr.decode()}")
|
||||
shutil.rmtree(temp_dir)
|
||||
raise
|
||||
|
||||
return oss_root, temp_dir
|
||||
|
||||
|
||||
def _apply_patch(patch_file, dry_run):
|
||||
"""
|
||||
Try to apply a patch, falling back to --3way merge if a clean apply fails.
|
||||
|
||||
Returns True if the patch was applied cleanly.
|
||||
Returns False if conflicts were encountered (changes are still staged
|
||||
with conflict markers so a PR can be created for manual resolution).
|
||||
"""
|
||||
# --- Attempt 1: clean git apply ---
|
||||
apply_cmd = ["git", "apply", patch_file]
|
||||
print(f"Run: {' '.join(apply_cmd)}")
|
||||
if dry_run:
|
||||
return True
|
||||
|
||||
result = subprocess.run(apply_cmd, capture_output=True, text=True)
|
||||
if result.returncode == 0:
|
||||
print("✅ Patch applied cleanly.")
|
||||
return True
|
||||
|
||||
print(f"⚠️ Clean apply failed:\n{result.stderr.strip()}")
|
||||
print("Falling back to git apply --3way ...\n")
|
||||
|
||||
# --- Attempt 2: three-way merge ---
|
||||
threeway_cmd = ["git", "apply", "--3way", patch_file]
|
||||
print(f"Run: {' '.join(threeway_cmd)}")
|
||||
result_3way = subprocess.run(threeway_cmd, capture_output=True, text=True)
|
||||
|
||||
if result_3way.returncode == 0:
|
||||
print("✅ Patch applied via --3way merge (no conflicts).")
|
||||
return True
|
||||
|
||||
# --- --3way left conflict markers in the working tree ---
|
||||
print(f"⚠️ --3way merge had conflicts:\n{result_3way.stderr.strip()}\n")
|
||||
|
||||
# Show which hunks conflict
|
||||
check_cmd = ["git", "apply", "--check", "--verbose", patch_file]
|
||||
print(f"Run: {' '.join(check_cmd)}")
|
||||
check_result = subprocess.run(check_cmd, capture_output=True, text=True)
|
||||
conflict_details = (check_result.stdout + check_result.stderr).strip()
|
||||
print(
|
||||
f"\n--- Conflict details ---\n{conflict_details}\n--- End conflict details ---\n"
|
||||
)
|
||||
|
||||
# Show git diff if --3way left conflict markers
|
||||
diff_result = subprocess.run(["git", "diff"], capture_output=True, text=True)
|
||||
if diff_result.stdout.strip():
|
||||
print(
|
||||
f"\n--- git diff (conflict markers) ---\n"
|
||||
f"{diff_result.stdout.strip()}\n"
|
||||
f"--- End git diff ---\n"
|
||||
)
|
||||
|
||||
# Read the patch content for the summary
|
||||
with open(patch_file, "r", encoding="utf-8") as pf:
|
||||
patch_content = pf.read()
|
||||
|
||||
# Print the patch to stdout so it's visible in the CI logs
|
||||
separator = "=" * 72
|
||||
print(
|
||||
f"\n{separator}\n"
|
||||
f"PATCH CONTENT (apply this manually):\n"
|
||||
f"{separator}\n"
|
||||
f"{patch_content}\n"
|
||||
f"{separator}\n"
|
||||
)
|
||||
|
||||
# Write a rich summary to the GitHub Actions step summary
|
||||
summary_lines = [
|
||||
"\n## ⚠️ Patch had conflicts — PR created for manual resolution\n",
|
||||
"### Conflict details\n",
|
||||
f"```\n{conflict_details}\n```\n",
|
||||
]
|
||||
if diff_result.stdout.strip():
|
||||
summary_lines.append("### git diff (conflict markers)\n")
|
||||
summary_lines.append(f"```diff\n{diff_result.stdout.strip()}\n```\n")
|
||||
summary_lines.append("### Patch to apply manually\n")
|
||||
summary_lines.append(
|
||||
"<details><summary>Click to expand full patch</summary>\n\n"
|
||||
f"```diff\n{patch_content}\n```\n"
|
||||
"</details>\n"
|
||||
)
|
||||
write_github_step_summary("".join(summary_lines))
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def apply_patch_and_push(
|
||||
oss_root, patch_file, branch_name, commit_message, base_oss_commit, dry_run
|
||||
):
|
||||
"""
|
||||
In the OSS repo, create a branch from base_oss_commit, apply the patch,
|
||||
commit, and push.
|
||||
|
||||
Args:
|
||||
base_oss_commit: The OSS commit hash to branch from (the last sync
|
||||
point). If None, the current HEAD (main) is used.
|
||||
|
||||
Returns True if the patch applied cleanly, False if there were conflicts
|
||||
(the conflicted state is still committed and pushed so a PR can be opened).
|
||||
"""
|
||||
print("\nApplying patch and pushing to OSS repo...")
|
||||
|
||||
original_cwd = os.getcwd()
|
||||
if not dry_run:
|
||||
os.chdir(oss_root)
|
||||
|
||||
applied_cleanly = True
|
||||
try:
|
||||
# Check out a new branch from the base OSS commit
|
||||
if base_oss_commit:
|
||||
checkout_cmd = ["git", "checkout", "-b", branch_name, base_oss_commit]
|
||||
else:
|
||||
checkout_cmd = ["git", "checkout", "-b", branch_name]
|
||||
print(f"Run: {' '.join(checkout_cmd)}")
|
||||
if not dry_run:
|
||||
subprocess.run(checkout_cmd, check=True, capture_output=True, text=True)
|
||||
|
||||
# Apply the patch (with --3way fallback and diagnostics)
|
||||
applied_cleanly = _apply_patch(patch_file, dry_run)
|
||||
|
||||
# Configure git user and stage changes
|
||||
post_apply_commands = [
|
||||
["git", "config", "user.name", "github-actions[bot]"],
|
||||
[
|
||||
"git",
|
||||
"config",
|
||||
"user.email",
|
||||
"github-actions[bot]@users.noreply.github.com",
|
||||
],
|
||||
["git", "add", "."],
|
||||
]
|
||||
|
||||
for cmd_list in post_apply_commands:
|
||||
print(f"Run: {' '.join(cmd_list)}")
|
||||
if not dry_run:
|
||||
subprocess.run(cmd_list, check=True, capture_output=True, text=True)
|
||||
|
||||
# Handle commit separately to pass multi-line message safely via stdin
|
||||
commit_cmd = ["git", "commit", "-F", "-"]
|
||||
print(f"Run: {' '.join(commit_cmd)}")
|
||||
if not dry_run:
|
||||
print(f"Commit Message:\n---\n{commit_message}\n---")
|
||||
subprocess.run(
|
||||
commit_cmd,
|
||||
input=commit_message,
|
||||
text=True,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
# Push the changes
|
||||
push_cmd = ["git", "push", "origin", branch_name, "--force"]
|
||||
print(f"Run: {' '.join(push_cmd)}")
|
||||
if not dry_run:
|
||||
subprocess.run(push_cmd, check=True, capture_output=True, text=True)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Git command failed: {e.stderr}")
|
||||
raise
|
||||
finally:
|
||||
if not dry_run:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
if applied_cleanly:
|
||||
print("✅ Branch created, patch applied cleanly, and pushed successfully.")
|
||||
else:
|
||||
print(
|
||||
"⚠️ Branch created and pushed with conflict markers. "
|
||||
"A PR will be opened for manual resolution."
|
||||
)
|
||||
|
||||
return applied_cleanly
|
||||
|
||||
|
||||
def create_pull_request(oss_root, branch_name, title, body, dry_run):
|
||||
"""Create a pull request in the OSS repo using the GitHub CLI."""
|
||||
gh_token = os.getenv("GH_TOKEN")
|
||||
if not gh_token:
|
||||
print(
|
||||
"⚠️ Warning: GH_TOKEN environment variable not set. Skipping PR creation."
|
||||
)
|
||||
if not dry_run:
|
||||
return
|
||||
|
||||
print("\nCreating pull request...")
|
||||
command = [
|
||||
"gh",
|
||||
"pr",
|
||||
"create",
|
||||
"--base",
|
||||
"main",
|
||||
"--head",
|
||||
branch_name,
|
||||
"--repo",
|
||||
"sgl-project/sglang",
|
||||
"--title",
|
||||
title,
|
||||
"--body",
|
||||
body,
|
||||
]
|
||||
|
||||
print(f"Run: {' '.join(command)}")
|
||||
if not dry_run:
|
||||
env = os.environ.copy()
|
||||
env["GH_TOKEN"] = gh_token
|
||||
try:
|
||||
result = subprocess.run(
|
||||
command,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=env,
|
||||
cwd=oss_root,
|
||||
)
|
||||
msg = f"✅ Successfully created pull request: {result.stdout.strip()}"
|
||||
print(msg)
|
||||
write_github_step_summary(msg)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error creating pull request: {e.stderr}")
|
||||
# Check if a PR already exists
|
||||
if "A pull request for" in e.stderr and "already exists" in e.stderr:
|
||||
print("ℹ️ A PR for this branch likely already exists.")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def get_commit_author(commit_hash):
|
||||
"""Get the author name and email of a commit."""
|
||||
try:
|
||||
author_name = subprocess.run(
|
||||
["git", "show", "-s", "--format=%an", commit_hash],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
).stdout.strip()
|
||||
author_email = subprocess.run(
|
||||
["git", "show", "-s", "--format=%ae", commit_hash],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
).stdout.strip()
|
||||
return author_name, author_email
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error getting commit author for {commit_hash}: {e.stderr}")
|
||||
raise
|
||||
|
||||
|
||||
def get_all_co_author_lines(commit_hash, commit_message):
|
||||
"""
|
||||
Build a deduplicated list of Co-authored-by lines that includes both
|
||||
the primary commit author and any Co-authored-by trailers already
|
||||
present in the commit message.
|
||||
|
||||
Returns a list of unique "Co-authored-by: Name <email>" strings.
|
||||
"""
|
||||
seen = set()
|
||||
co_author_lines = []
|
||||
|
||||
def _add(name, email):
|
||||
key = (name.strip(), email.strip().lower())
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
co_author_lines.append(f"Co-authored-by: {name.strip()} <{email.strip()}>")
|
||||
|
||||
# 1. Primary author of the commit
|
||||
author_name, author_email = get_commit_author(commit_hash)
|
||||
_add(author_name, author_email)
|
||||
|
||||
# 2. Existing Co-authored-by trailers in the commit message
|
||||
for line in commit_message.splitlines():
|
||||
m = re.match(r"^\s*Co-authored-by:\s*(.+?)\s*<([^>]+)>", line, re.IGNORECASE)
|
||||
if m:
|
||||
_add(m.group(1), m.group(2))
|
||||
|
||||
return co_author_lines
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Copy a commit from the private repo to OSS and open a PR."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--commit",
|
||||
type=str,
|
||||
default="LAST",
|
||||
help="The commit hash to sync. Defaults to 'LAST' to use the latest commit.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Dry run the script without executing git, rsync, or gh commands.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
check_dependencies()
|
||||
|
||||
commit_ref = "HEAD" if args.commit == "LAST" else args.commit
|
||||
commit_hash, original_commit_message = get_commit_info(commit_ref)
|
||||
|
||||
if not commit_hash:
|
||||
return # Exit if we couldn't get commit info
|
||||
|
||||
# Display the details of the commit being processed
|
||||
if args.commit == "LAST":
|
||||
summary = (
|
||||
f"\nℹ️ No commit specified. Using the last commit:\n"
|
||||
f" - **Hash:** `{commit_hash}`\n"
|
||||
f" - **Message:** {original_commit_message}\n\n"
|
||||
)
|
||||
else:
|
||||
summary = (
|
||||
f"\nℹ️ Using specified commit:\n"
|
||||
f" - **Hash:** `{commit_hash}`\n"
|
||||
f" - **Message:** {original_commit_message}\n\n"
|
||||
)
|
||||
print(summary)
|
||||
write_github_step_summary(summary)
|
||||
|
||||
short_hash = commit_hash[:8]
|
||||
|
||||
patch_file = None
|
||||
temp_dir = None
|
||||
try:
|
||||
# 1. Create a filtered patch from the local repo
|
||||
patch_file, relevant_files = create_filtered_patch(commit_hash, args.dry_run)
|
||||
if not patch_file:
|
||||
return
|
||||
|
||||
# 2. Get the OSS repo
|
||||
oss_root, temp_dir = get_oss_repo(args.dry_run)
|
||||
|
||||
# 3. Find the latest OSS commit that was synced into sglang-private.
|
||||
# This is the correct base for our patch, since the private repo's
|
||||
# code is based on this sync point.
|
||||
base_oss_commit = find_latest_oss_sync_commit()
|
||||
if base_oss_commit:
|
||||
print(f"ℹ️ Will branch from OSS commit {base_oss_commit}")
|
||||
else:
|
||||
print(
|
||||
"⚠️ Could not determine latest OSS sync commit. "
|
||||
"Falling back to OSS main HEAD."
|
||||
)
|
||||
|
||||
# 4. Get all co-author lines (primary author + trailers from commit message)
|
||||
co_author_lines = get_all_co_author_lines(commit_hash, original_commit_message)
|
||||
authors_display = "\n".join(co_author_lines)
|
||||
|
||||
# 5. Prepare content for the commit and PR based on changed files
|
||||
file_list_str = "\n".join([f"- {f}" for f in relevant_files])
|
||||
filename_list_str = ", ".join([f.split("/")[-1] for f in relevant_files])
|
||||
if len(filename_list_str) > 40:
|
||||
filename_list_str = filename_list_str[:40] + "..."
|
||||
current_date = datetime.datetime.now().strftime("%Y%m%d")
|
||||
pr_title = f"[Auto Sync] Update {filename_list_str} ({current_date})"
|
||||
|
||||
# 6. Create branch from the last synced OSS commit, apply patch, and push
|
||||
branch_name = f"sync-{short_hash}-{current_date}"
|
||||
co_authors_block = "\n".join(co_author_lines)
|
||||
commit_message = f"{pr_title}\n\n{co_authors_block}"
|
||||
applied_cleanly = apply_patch_and_push(
|
||||
oss_root,
|
||||
patch_file,
|
||||
branch_name,
|
||||
commit_message,
|
||||
base_oss_commit,
|
||||
args.dry_run,
|
||||
)
|
||||
|
||||
# 7. Adjust PR title and body when there are conflicts
|
||||
if not applied_cleanly:
|
||||
pr_title = (
|
||||
f"[Auto Sync][⚠️ Conflicts] Update {filename_list_str} ({current_date})"
|
||||
)
|
||||
|
||||
pr_body_parts = [
|
||||
f"Sync changes from commit `{short_hash}`.\n",
|
||||
f"**Files Changed:**\n{file_list_str}\n",
|
||||
f"**Authors:**\n{authors_display}",
|
||||
]
|
||||
if not applied_cleanly:
|
||||
pr_body_parts.append(
|
||||
"\n\n⚠️ **This patch had merge conflicts.** "
|
||||
"The branch contains conflict markers that must be resolved manually. "
|
||||
"Please check the CI logs for the full patch and conflict details."
|
||||
)
|
||||
pr_body_parts.append(
|
||||
f"\n\n---\n\n"
|
||||
f"*This is an automated PR created by scripts/copy_to_oss.py.*"
|
||||
)
|
||||
pr_body = "\n".join(pr_body_parts)
|
||||
|
||||
# 8. Create Pull Request
|
||||
create_pull_request(oss_root, branch_name, pr_title, pr_body, args.dry_run)
|
||||
|
||||
finally:
|
||||
# Cleanup temporary files
|
||||
if patch_file and os.path.exists(patch_file):
|
||||
os.remove(patch_file)
|
||||
print(f"\nRemoved temporary patch file: {patch_file}")
|
||||
if temp_dir and os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir)
|
||||
print(f"Removed temporary directory: {temp_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
51
third_party/sglang/scripts/code_sync/guideline.md
vendored
Normal file
51
third_party/sglang/scripts/code_sync/guideline.md
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
### Sync Code Between OSS and Private Fork
|
||||
|
||||
You can use the following principles and tools to sync the code between a private fork and the OSS repo [sgl-project/sglang](https://github.com/sgl-project/sglang/tree/main).
|
||||
It learns from [Copybara](https://github.com/google/copybara), a tool used at Google for maintaining open-source code synchronization.
|
||||
|
||||
## Principals
|
||||
|
||||
- The core folders (e.g., `python/sglang/srt`) are 100% mirrored between the private fork and OSS repo.
|
||||
- The OSS repo is the single source of truth. If one commit changes `python/sglang/srt` in the private repo, the change should be synced to the OSS repo as soon as possible with the action B below.
|
||||
- The common code (e.g., base classes, well-known techniques in the industry without private secrets) goes to `python/sglang/srt`. The private-specific code (e.g., with private-specific features, confidential info) goes to `python/sglang/private` .
|
||||
- Anytime you want to make private changes to a file or class under `python/sglang/srt`, duplicate the file and move it under `python/sglang/private`. You can achieve code reuse by importing and inheriting.
|
||||
|
||||
## How to sync the code bidirectionally
|
||||
### Action A: Copy code from OSS to private
|
||||
|
||||
- We can run this action: [Open A PR to Copy Code From OSS](https://github.com/sgl-project/sglang/tree/main/.github/workflows/open-pr-copy-from-oss.yml)
|
||||
- It opens a PR to copy all files under certain folders (e.g., `python/sglang/srt` , `test/srt` , `sgl-kernel` ) from the OSS main branch to the private fork.
|
||||
- Since the OSS repo is the single source of truth, this action copies files and overwrites any changes in the private fork. To prevent the private changes from being overwritten, you need to ensure all private changes are merged into the OSS repo before running this action.
|
||||
- This action will be run automatically every day and can also be triggered manually.
|
||||
|
||||
### Action B: Copy diff from private to OSS
|
||||
|
||||
- We can run this action: [Open A PR to Copy Code To OSS](https://github.com/sgl-project/sglang/tree/main/.github/workflows/open-pr-copy-to-oss.yml)
|
||||
- It opens a PR to apply the diff of one specific commit of the private fork to the OSS main branch. It will only pick the changes under certain folders (e.g., `python/sglang/srt` , `test/srt` , `sgl-kernel` ) and ignore changes under private folders (e.g., `python/sglang/private` )
|
||||
- For example, you can have a PR that changes both `python/sglang/srt` and `python/sglang/private/srt`. Once you merge the PR into the private repo, `python/sglang/srt` becomes desynced between the two repos. You need to run this action on your merge commit immediately to open a PR to send your diff to the OSS repo. Then, we need to merge the OSS PR as soon as possible. Once your OSS PR is merged, we can run action A again.
|
||||
- Action A copies files directly, but Action B applies diff. This is because OSS is the source of truth; action A can just copy files. Action B cannot copy, so it uses diff instead.
|
||||
- This action currently needs a manual trigger in order to prevent incidental code leaks. One can also consider making it automatic.
|
||||
|
||||
## Examples
|
||||
- If you want to have some private server arguments, you can create a new file `python/sglang/private/server_args.py`. It defines a class that inherits the oss ServerArgs.
|
||||
```python
|
||||
from sglang.srt.server_args import ServerArgs as ServerArgsOSS
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ServerArgs(ServerArgsOSS):
|
||||
private_flag: str = "foo"
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
# Get all public args
|
||||
ServerArgsOSS.add_cli_args(parser)
|
||||
|
||||
# Add your private flags
|
||||
parser.add_argument(
|
||||
"--private-flag",
|
||||
type=str,
|
||||
default=ServerArgs.private_flag,
|
||||
)
|
||||
```
|
||||
- Similarly, you can inherit `Engine` and override its fields. You can override `server_args_class` to use your own ServerArgs,
|
||||
override `init_tokenizer_manager_func` to use your own TokenizerManager, override `run_scheduler_process_func` to use your own scheduler.
|
||||
18
third_party/sglang/scripts/code_sync/install_github_cli.sh
vendored
Executable file
18
third_party/sglang/scripts/code_sync/install_github_cli.sh
vendored
Executable file
@@ -0,0 +1,18 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Check if gh is installed before attempting to install it
|
||||
if ! command -v gh &> /dev/null
|
||||
then
|
||||
echo "GitHub CLI not found. Installing now..."
|
||||
(type -p wget >/dev/null || ( apt update && apt install wget -y)) \
|
||||
&& mkdir -p -m 755 /etc/apt/keyrings \
|
||||
&& out=$(mktemp) && wget -nv -O$out https://cli.github.com/packages/githubcli-archive-keyring.gpg \
|
||||
&& cat $out | tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \
|
||||
&& chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
&& mkdir -p -m 755 /etc/apt/sources.list.d \
|
||||
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \
|
||||
&& apt update \
|
||||
&& apt install gh -y
|
||||
else
|
||||
echo "GitHub CLI is already installed. Skipping installation."
|
||||
fi
|
||||
136
third_party/sglang/scripts/code_sync/utils.py
vendored
Normal file
136
third_party/sglang/scripts/code_sync/utils.py
vendored
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Shared constants and helpers for code-sync scripts.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
|
||||
# --- Configuration Begin ---
|
||||
# List of folders and files to copy to / from the OSS repo.
|
||||
# Changes outside these paths will be ignored.
|
||||
FOLDER_NAMES = [
|
||||
"3rdparty",
|
||||
"assets",
|
||||
"benchmark",
|
||||
"docker",
|
||||
"docs",
|
||||
"examples",
|
||||
"python/sglang/lang",
|
||||
"python/sglang/jit_kernel",
|
||||
"python/sglang/srt",
|
||||
"python/sglang/test",
|
||||
"python/sglang/utils.py",
|
||||
"python/sglang/README.md",
|
||||
"sgl-kernel",
|
||||
"test/manual",
|
||||
"test/registered",
|
||||
"test/srt",
|
||||
"test/README.md",
|
||||
"test/run_suite.py",
|
||||
"README.md",
|
||||
]
|
||||
|
||||
SYNC_COMMIT_PREFIX = r"\[Automated PR\] Copy OSS code from commit"
|
||||
# --- Configuration End ---
|
||||
|
||||
|
||||
def write_github_step_summary(content: str) -> None:
|
||||
"""Append *content* to the GitHub Actions step summary (no-op outside CI)."""
|
||||
summary_path = os.environ.get("GITHUB_STEP_SUMMARY")
|
||||
if not summary_path:
|
||||
return
|
||||
with open(summary_path, "a") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def get_last_sync_commit(repo_root: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Find the most recent sync commit that copied from OSS.
|
||||
|
||||
Returns the full private-repo commit hash, or None if not found.
|
||||
The match is restricted to commits whose **subject** starts with the
|
||||
sync prefix so that unrelated commits mentioning the phrase in their
|
||||
body are ignored.
|
||||
"""
|
||||
subject_pattern = re.compile("^" + SYNC_COMMIT_PREFIX)
|
||||
|
||||
try:
|
||||
cmd = [
|
||||
"git",
|
||||
"log",
|
||||
"--all",
|
||||
"--grep",
|
||||
SYNC_COMMIT_PREFIX,
|
||||
"--format=%H %s",
|
||||
]
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
cwd=repo_root,
|
||||
).stdout.strip()
|
||||
|
||||
for line in result.splitlines():
|
||||
# Format: "<full_hash> <subject>"
|
||||
parts = line.split(" ", 1)
|
||||
if len(parts) != 2:
|
||||
continue
|
||||
commit_hash, subject = parts
|
||||
if subject_pattern.search(subject):
|
||||
return commit_hash
|
||||
|
||||
return None
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error finding last sync commit: {e.stderr}")
|
||||
return None
|
||||
|
||||
|
||||
def find_latest_oss_sync_commit(repo_root: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Search the private repo history for the latest commit whose **subject**
|
||||
matches "[Automated PR] Copy OSS code from commit {commit_id} on {date}"
|
||||
and return the embedded **OSS** commit hash.
|
||||
|
||||
Returns the short OSS commit hash string, or None if not found.
|
||||
"""
|
||||
oss_hash_pattern = re.compile("^" + SYNC_COMMIT_PREFIX + r" ([0-9a-f]+)")
|
||||
|
||||
try:
|
||||
# --grep filters on the full message body, so we request subject-only
|
||||
# output and validate the pattern against the subject ourselves.
|
||||
result = subprocess.run(
|
||||
[
|
||||
"git",
|
||||
"log",
|
||||
"--all",
|
||||
"--grep",
|
||||
SYNC_COMMIT_PREFIX,
|
||||
"--pretty=%s",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
cwd=repo_root,
|
||||
)
|
||||
|
||||
for subject in result.stdout.strip().splitlines():
|
||||
m = oss_hash_pattern.search(subject)
|
||||
if m:
|
||||
oss_commit = m.group(1)
|
||||
print(
|
||||
f"✅ Latest OSS sync commit found: {oss_commit} "
|
||||
f"(from: {subject})"
|
||||
)
|
||||
return oss_commit
|
||||
|
||||
print(
|
||||
"⚠️ No '[Automated PR] Copy OSS code from commit ...' " "found in history."
|
||||
)
|
||||
return None
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error searching for OSS sync commits: {e.stderr.strip()}")
|
||||
return None
|
||||
463
third_party/sglang/scripts/convert_otel_2_perfetto.py
vendored
Normal file
463
third_party/sglang/scripts/convert_otel_2_perfetto.py
vendored
Normal file
@@ -0,0 +1,463 @@
|
||||
import argparse
|
||||
import bisect
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Tuple
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert SGLang OTEL trace files to Perfetto format.",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--input",
|
||||
dest="input_file",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the input OTEL trace file (JSON or JSONL format).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output",
|
||||
dest="output_file",
|
||||
type=str,
|
||||
default="sglang_trace_perfetto.json",
|
||||
help="Path to the output Perfetto JSON file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f", "--torch-file", dest="torch_file", help="specify torch profile file"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
perfetto_data = None
|
||||
if args.torch_file:
|
||||
with open(args.torch_file, "r", encoding="utf-8") as file:
|
||||
perfetto_data = json.load(file)
|
||||
baseline = perfetto_data["baseTimeNanoseconds"]
|
||||
else:
|
||||
baseline = 0
|
||||
|
||||
|
||||
def id_generator():
|
||||
i = 0
|
||||
while True:
|
||||
yield i
|
||||
i += 1
|
||||
|
||||
|
||||
relation_id_gen = id_generator()
|
||||
|
||||
|
||||
class SpanLayoutContainer:
|
||||
def __init__(self):
|
||||
self.intervals = []
|
||||
|
||||
def check_overlap(self, start, end):
|
||||
idx = bisect.bisect_left(self.intervals, (start, float("-inf")))
|
||||
|
||||
if idx > 0:
|
||||
prev_start, prev_end = self.intervals[idx - 1]
|
||||
if prev_end > start:
|
||||
return True
|
||||
|
||||
if idx < len(self.intervals):
|
||||
next_start, next_end = self.intervals[idx]
|
||||
if next_start < end:
|
||||
return True
|
||||
return False
|
||||
|
||||
def insert_span(self, start, end):
|
||||
bisect.insort_left(self.intervals, (start, end))
|
||||
|
||||
|
||||
def new_metadata_level1(name: str, pid):
|
||||
return {
|
||||
"name": "process_name",
|
||||
"ph": "M",
|
||||
"pid": pid,
|
||||
"args": {"name": name},
|
||||
}
|
||||
|
||||
|
||||
def new_metadata_level2(name: str, pid, slot_seq):
|
||||
return {
|
||||
"name": "thread_name",
|
||||
"ph": "M",
|
||||
"pid": pid,
|
||||
"tid": slot_seq,
|
||||
"args": {"name": name},
|
||||
}
|
||||
|
||||
|
||||
def __find_line(graph, trans_graph_status, slot_meta_data, pid, start, end):
|
||||
if pid in trans_graph_status:
|
||||
line = trans_graph_status[pid]
|
||||
if start == end:
|
||||
return line
|
||||
# check conflict
|
||||
if not graph[pid][line].check_overlap(start, end):
|
||||
return line
|
||||
|
||||
if pid not in graph:
|
||||
line = 1
|
||||
graph[pid] = {line: SpanLayoutContainer()}
|
||||
trans_graph_status[pid] = line
|
||||
slot_meta_data.append(new_metadata_level2("slot", pid, line))
|
||||
return line
|
||||
|
||||
for line in graph[pid]:
|
||||
if not graph[pid][line].check_overlap(start, end):
|
||||
trans_graph_status[pid] = line
|
||||
return line
|
||||
|
||||
new_line = len(graph[pid]) + 1
|
||||
graph[pid][new_line] = SpanLayoutContainer()
|
||||
trans_graph_status[pid] = new_line
|
||||
slot_meta_data.append(new_metadata_level2("slot", pid, new_line))
|
||||
return new_line
|
||||
|
||||
|
||||
OtelSpan = Dict[str, Any]
|
||||
|
||||
|
||||
def load_otel_data(path: str | Path):
|
||||
p = Path(path)
|
||||
with p.open("rt", encoding="utf-8") as f:
|
||||
first = f.read(1)
|
||||
f.seek(0)
|
||||
if first == "[":
|
||||
data = json.load(f) # JSON array
|
||||
else:
|
||||
data = [json.loads(line) for line in f if line.strip()] # JSONL
|
||||
return data
|
||||
|
||||
|
||||
def extract_all_otel_spans(otel_data):
|
||||
engine_otel_spans = []
|
||||
smg_otel_spans = []
|
||||
for line_data in otel_data:
|
||||
for resource_spans in line_data["resourceSpans"]:
|
||||
# filter: only keep spans which service.name is 'sglang' or 'smg'
|
||||
service_name = ""
|
||||
for attr in resource_spans["resource"]["attributes"]:
|
||||
if attr["key"] == "service.name":
|
||||
service_name = attr["value"]["stringValue"]
|
||||
|
||||
if service_name == "sglang":
|
||||
spans_ref = engine_otel_spans
|
||||
elif service_name == "smg":
|
||||
spans_ref = smg_otel_spans
|
||||
else:
|
||||
continue
|
||||
|
||||
for scope_spans in resource_spans["scopeSpans"]:
|
||||
for span in scope_spans["spans"]:
|
||||
if "attributes" in span:
|
||||
attributes_dict = {
|
||||
attr.get("key"): next(
|
||||
iter(attr.get("value", {}).values()), None
|
||||
)
|
||||
for attr in span["attributes"]
|
||||
}
|
||||
span["attributes"] = attributes_dict
|
||||
else:
|
||||
span["attributes"] = {}
|
||||
spans_ref.append(span)
|
||||
return engine_otel_spans, smg_otel_spans
|
||||
|
||||
|
||||
def build_otel_span_tree(otel_spans):
|
||||
span_id_map = {span["spanId"]: span for span in otel_spans}
|
||||
for span in otel_spans:
|
||||
span["child"] = []
|
||||
|
||||
root_spans = []
|
||||
|
||||
for span in otel_spans:
|
||||
parent_span_id = span.get("parentSpanId", "")
|
||||
if span.get("attributes", {}).get("module") == "sglang::request":
|
||||
root_spans.append(span)
|
||||
elif parent_span_id in span_id_map:
|
||||
parent_span = span_id_map[parent_span_id]
|
||||
parent_span["child"].append(span)
|
||||
|
||||
link_spans = []
|
||||
if "links" in span:
|
||||
for link in span["links"]:
|
||||
link_span = span_id_map.get(link["spanId"])
|
||||
if link_span:
|
||||
link_spans.append(link_span)
|
||||
span["links"] = link_spans
|
||||
|
||||
return root_spans
|
||||
|
||||
|
||||
def __convert_to_perfetto_span(span, rid, bootstrap_room, pid, host_id):
|
||||
if bootstrap_room:
|
||||
span["attributes"]["bootstrap_room"] = bootstrap_room
|
||||
if rid:
|
||||
span["attributes"]["rid"] = rid
|
||||
if host_id:
|
||||
span["host_id"] = host_id
|
||||
span["pid"] = pid
|
||||
|
||||
span["startTimeUnixNano"] = int(span["startTimeUnixNano"])
|
||||
span["endTimeUnixNano"] = int(span["endTimeUnixNano"]) - 1000
|
||||
ts = span["startTimeUnixNano"]
|
||||
dur = span["endTimeUnixNano"] - ts
|
||||
|
||||
perfetto_span = {
|
||||
"ph": "X",
|
||||
"name": span.get("name", "unknown"),
|
||||
"cat": "sglang",
|
||||
"ts": (ts - baseline) / 1000.0,
|
||||
"dur": dur / 1000.0,
|
||||
"pid": pid,
|
||||
"tid": 0,
|
||||
"args": span["attributes"],
|
||||
}
|
||||
|
||||
span["perfetto_span"] = perfetto_span
|
||||
|
||||
for child_span in span["child"]:
|
||||
__convert_to_perfetto_span(child_span, rid, bootstrap_room, pid, host_id)
|
||||
|
||||
|
||||
def generate_perfetto_span(engine_root_spans, smg_otel_spans, thread_meta_data):
|
||||
for root_span in engine_root_spans:
|
||||
root_span["spans"] = []
|
||||
|
||||
rid = root_span["attributes"]["rid"]
|
||||
bootstrap_room = root_span["attributes"].get("bootstrap_room", "")
|
||||
|
||||
for thread_span in root_span["child"]:
|
||||
pid = int(thread_span["attributes"]["pid"])
|
||||
host_id = thread_span["attributes"]["host_id"]
|
||||
thread_name = f'{thread_span["attributes"]["host_id"][:8]}:{thread_span["attributes"]["thread_label"]}'
|
||||
if "tp_rank" in thread_span["attributes"]:
|
||||
thread_name += f"-TP{thread_span['attributes']['tp_rank']}"
|
||||
|
||||
if pid not in thread_meta_data:
|
||||
thread_meta_data[pid] = new_metadata_level1(thread_name, pid)
|
||||
|
||||
for span in thread_span["child"]:
|
||||
__convert_to_perfetto_span(span, rid, bootstrap_room, pid, host_id)
|
||||
root_span["spans"].append(span)
|
||||
|
||||
smg_pid = "smg"
|
||||
thread_meta_data[smg_pid] = new_metadata_level1("smg", smg_pid)
|
||||
for span in smg_otel_spans:
|
||||
span["pid"] = smg_pid
|
||||
__convert_to_perfetto_span(span, None, None, smg_pid, None)
|
||||
|
||||
|
||||
def __set_span_tid(span, line):
|
||||
span["perfetto_span"]["tid"] = line
|
||||
|
||||
for child_span in span["child"]:
|
||||
__set_span_tid(child_span, line)
|
||||
|
||||
|
||||
def generate_perfetto_span_layout(engine_root_spans, smg_otel_spans, slot_meta_data):
|
||||
for root_span in engine_root_spans:
|
||||
root_span["spans"] = sorted(
|
||||
root_span["spans"], key=lambda x: int(x["startTimeUnixNano"])
|
||||
)
|
||||
|
||||
engine_root_spans = sorted(
|
||||
engine_root_spans, key=lambda x: int(x["spans"][0]["startTimeUnixNano"])
|
||||
)
|
||||
graph = {}
|
||||
for root_span in engine_root_spans:
|
||||
req_thread_status = {}
|
||||
for span in root_span["spans"]:
|
||||
line = __find_line(
|
||||
graph,
|
||||
req_thread_status,
|
||||
slot_meta_data,
|
||||
span["perfetto_span"]["pid"],
|
||||
span["startTimeUnixNano"],
|
||||
span["endTimeUnixNano"],
|
||||
)
|
||||
graph[span["perfetto_span"]["pid"]][line].insert_span(
|
||||
span["startTimeUnixNano"], span["endTimeUnixNano"]
|
||||
)
|
||||
__set_span_tid(span, line)
|
||||
|
||||
smg_otel_spans = sorted(smg_otel_spans, key=lambda x: int(x["startTimeUnixNano"]))
|
||||
req_thread_status = {}
|
||||
for span in smg_otel_spans:
|
||||
line = __find_line(
|
||||
graph,
|
||||
req_thread_status,
|
||||
slot_meta_data,
|
||||
span["perfetto_span"]["pid"],
|
||||
span["startTimeUnixNano"],
|
||||
span["endTimeUnixNano"],
|
||||
)
|
||||
graph[span["perfetto_span"]["pid"]][line].insert_span(
|
||||
span["startTimeUnixNano"], span["endTimeUnixNano"]
|
||||
)
|
||||
span["perfetto_span"]["tid"] = line
|
||||
|
||||
|
||||
def __convert_to_perfetto_events(span):
|
||||
span["perfetto_events"] = []
|
||||
if "events" in span:
|
||||
for event in span["events"]:
|
||||
attributes_dict = {
|
||||
attr.get("key"): next(iter(attr.get("value", {}).values()), None)
|
||||
for attr in event["attributes"]
|
||||
}
|
||||
perfetto_event = {
|
||||
"ph": "i",
|
||||
"cat": "sglang",
|
||||
"ts": (int(event["timeUnixNano"]) - baseline) / 1000.0,
|
||||
"pid": span["perfetto_span"]["pid"],
|
||||
"tid": span["perfetto_span"]["tid"],
|
||||
"name": event.get("name", "unknown"),
|
||||
"args": attributes_dict,
|
||||
}
|
||||
|
||||
span["perfetto_events"].append(perfetto_event)
|
||||
|
||||
for child_span in span["child"]:
|
||||
__convert_to_perfetto_events(child_span)
|
||||
|
||||
|
||||
def generate_perfetto_events(engine_root_spans, smg_otel_spans):
|
||||
spans = [span for root_span in engine_root_spans for span in root_span["spans"]]
|
||||
|
||||
for span in spans:
|
||||
__convert_to_perfetto_events(span)
|
||||
|
||||
for span in smg_otel_spans:
|
||||
__convert_to_perfetto_events(span)
|
||||
|
||||
|
||||
def generate_perfetto_links(engine_root_spans, smg_otel_spans):
|
||||
# build link between engine span and smg span
|
||||
span_id_map = {span["spanId"]: span for span in smg_otel_spans}
|
||||
|
||||
for root_span in engine_root_spans:
|
||||
if "parentSpanId" in root_span and root_span["parentSpanId"] in span_id_map:
|
||||
parent_span = span_id_map[root_span["parentSpanId"]]
|
||||
root_span["spans"][0]["links"] = [parent_span]
|
||||
|
||||
for span in root_span["spans"]:
|
||||
span["perfetto_links"] = []
|
||||
|
||||
if "links" in span:
|
||||
for link_span in span["links"]:
|
||||
try:
|
||||
link_perfetto_span = link_span["perfetto_span"]
|
||||
except (KeyError, AttributeError):
|
||||
continue
|
||||
|
||||
if "correlation" in link_perfetto_span["args"]:
|
||||
id = link_perfetto_span["args"]["correlation"]
|
||||
else:
|
||||
id = next(relation_id_gen)
|
||||
link_perfetto_span["args"]["correlation"] = id
|
||||
|
||||
perfetto_start_node = {
|
||||
"ph": "s",
|
||||
"id": id,
|
||||
"pid": link_perfetto_span["pid"],
|
||||
"tid": link_perfetto_span["tid"],
|
||||
"ts": link_perfetto_span["ts"],
|
||||
"cat": "ac2g",
|
||||
"name": "ac2g",
|
||||
}
|
||||
|
||||
perfetto_end_node = {
|
||||
"ph": "f",
|
||||
"id": id,
|
||||
"pid": span["perfetto_span"]["pid"],
|
||||
"tid": span["perfetto_span"]["tid"],
|
||||
"ts": span["perfetto_span"]["ts"],
|
||||
"cat": "ac2g",
|
||||
"name": "ac2g",
|
||||
"bp": "e",
|
||||
}
|
||||
|
||||
span["perfetto_links"].append(perfetto_start_node)
|
||||
span["perfetto_links"].append(perfetto_end_node)
|
||||
|
||||
|
||||
def __gather_one_span(span):
|
||||
elems = []
|
||||
elems.append(span["perfetto_span"])
|
||||
if "perfetto_events" in span:
|
||||
elems.extend(span["perfetto_events"])
|
||||
if "perfetto_links" in span:
|
||||
elems.extend(span["perfetto_links"])
|
||||
|
||||
for child_span in span["child"]:
|
||||
elems.extend(__gather_one_span(child_span))
|
||||
|
||||
return elems
|
||||
|
||||
|
||||
def gather_all_perfetto_elems(
|
||||
engine_root_spans, smg_otel_spans, thread_meta_data, slot_meta_data
|
||||
):
|
||||
elems = []
|
||||
elems.extend(thread_meta_data.values())
|
||||
elems.extend(slot_meta_data)
|
||||
for root_span in engine_root_spans:
|
||||
for span in root_span["spans"]:
|
||||
elems.extend(__gather_one_span(span))
|
||||
|
||||
for span in smg_otel_spans:
|
||||
elems.append(span["perfetto_span"])
|
||||
elems.extend(span["perfetto_events"])
|
||||
|
||||
return elems
|
||||
|
||||
|
||||
def write_json(perfetto_elems):
|
||||
global perfetto_data
|
||||
|
||||
if args.torch_file:
|
||||
perfetto_data["traceEvents"].extend(perfetto_elems)
|
||||
filered_data = [
|
||||
item
|
||||
for item in perfetto_data["traceEvents"]
|
||||
if item.get("cat") != "gpu_user_annotation"
|
||||
]
|
||||
perfetto_data["traceEvents"] = filered_data
|
||||
else:
|
||||
perfetto_data = perfetto_elems
|
||||
|
||||
with open(args.output_file, "w", encoding="utf-8") as file:
|
||||
json.dump(perfetto_data, file, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
def main():
|
||||
start_time = time.time()
|
||||
otel_data = load_otel_data(args.input_file)
|
||||
engine_otel_spans, smg_otel_spans = extract_all_otel_spans(otel_data)
|
||||
engine_root_spans = build_otel_span_tree(engine_otel_spans)
|
||||
thread_meta_data = {}
|
||||
generate_perfetto_span(engine_root_spans, smg_otel_spans, thread_meta_data)
|
||||
slot_meta_data = []
|
||||
generate_perfetto_span_layout(engine_root_spans, smg_otel_spans, slot_meta_data)
|
||||
generate_perfetto_events(engine_root_spans, smg_otel_spans)
|
||||
generate_perfetto_links(engine_root_spans, smg_otel_spans)
|
||||
perfetto_elems = gather_all_perfetto_elems(
|
||||
engine_root_spans, smg_otel_spans, thread_meta_data, slot_meta_data
|
||||
)
|
||||
write_json(perfetto_elems)
|
||||
end_time = time.time()
|
||||
execution_time = end_time - start_time
|
||||
print(f"\nConversion finished successfully!")
|
||||
print(f"Output written to: {args.output_file}")
|
||||
print(f"Execution time: {execution_time * 1000:.4f} ms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
115
third_party/sglang/scripts/export_deepseek_nextn.py
vendored
Normal file
115
third_party/sglang/scripts/export_deepseek_nextn.py
vendored
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Export NextN layer for DeepSeek-V3/R1 model. The exported model can be used for speculative decoding.
|
||||
|
||||
Usage:
|
||||
python3 export_deepseek_nextn.py --input-dir /path/to/DeepSeek-V3 --output-dir /path/to/DeepSeek-V3-NextN
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
from transformers import AutoConfig
|
||||
|
||||
|
||||
def get_nextn_layer_id(config):
|
||||
if not hasattr(config, "num_hidden_layers"):
|
||||
raise ValueError("'num_hidden_layers' not found in model config.")
|
||||
return config.num_hidden_layers
|
||||
|
||||
|
||||
def update_and_save_config(config, output_dir):
|
||||
new_config = config.to_dict()
|
||||
new_config.update(
|
||||
{
|
||||
"num_hidden_layers": 1,
|
||||
"architectures": ["DeepseekV3ForCausalLMNextN"],
|
||||
}
|
||||
)
|
||||
with open(os.path.join(output_dir, "config.json"), "w") as f:
|
||||
json.dump(new_config, f, indent=2, ensure_ascii=False, sort_keys=True)
|
||||
|
||||
|
||||
def copy_non_safetensors_files(input_dir, output_dir):
|
||||
for filename in os.listdir(input_dir):
|
||||
src_file_path = os.path.join(input_dir, filename)
|
||||
if os.path.isfile(src_file_path) and not filename.endswith(".safetensors"):
|
||||
dst_file_path = os.path.join(output_dir, filename)
|
||||
shutil.copy2(src_file_path, dst_file_path)
|
||||
print(f"All non-safetensors files have been copied to {output_dir}")
|
||||
|
||||
|
||||
def export_nextn_layer_parameters(input_dir, output_dir, nextn_layer_id):
|
||||
prefix = f"model.layers.{nextn_layer_id}"
|
||||
output_path = os.path.join(output_dir, "nextn_layer_parameters.safetensors")
|
||||
params = {}
|
||||
for filename in os.listdir(input_dir):
|
||||
if not filename.endswith(".safetensors"):
|
||||
continue
|
||||
|
||||
file_path = os.path.join(input_dir, filename)
|
||||
print(f"Processing: {filename}")
|
||||
|
||||
try:
|
||||
with safe_open(file_path, framework="pt") as f:
|
||||
matching_keys = [k for k in f.keys() if k.startswith(prefix)]
|
||||
|
||||
if not matching_keys:
|
||||
print(f" No parameters starting with '{prefix}' found")
|
||||
continue
|
||||
|
||||
for key in matching_keys:
|
||||
if "embed_tokens" in key or "shared_head.head" in key:
|
||||
continue
|
||||
new_key = key.replace(prefix, "model.layers.0")
|
||||
params[new_key] = f.get_tensor(key)
|
||||
|
||||
except Exception as e:
|
||||
print(f" Error processing {filename}: {str(e)}")
|
||||
|
||||
if params:
|
||||
print(f"Saving {len(params)} parameters to {output_path}")
|
||||
save_file(params, output_path)
|
||||
else:
|
||||
print("No matching parameters found.")
|
||||
|
||||
# Update safetensors index
|
||||
index_path = os.path.join(output_dir, "model.safetensors.index.json")
|
||||
print(f"Updating safetensors index to {index_path}")
|
||||
index_data = {"weight_map": {}}
|
||||
for key in params:
|
||||
index_data["weight_map"][key] = "nextn_layer_parameters.safetensors"
|
||||
with open(index_path, "w") as f:
|
||||
json.dump(index_data, f, indent=4)
|
||||
|
||||
print("All done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Export NextN layer parameters for DeepSeek-V3/R1"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Input HF model directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Output nextn model directory.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
config = AutoConfig.from_pretrained(args.input_dir, trust_remote_code=True)
|
||||
assert config.num_nextn_predict_layers == 1, "Only 1 nextn layer is supported."
|
||||
nextn_layer_id = get_nextn_layer_id(config)
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
copy_non_safetensors_files(args.input_dir, args.output_dir)
|
||||
update_and_save_config(config, args.output_dir)
|
||||
export_nextn_layer_parameters(args.input_dir, args.output_dir, nextn_layer_id)
|
||||
60
third_party/sglang/scripts/killall_sglang.sh
vendored
Executable file
60
third_party/sglang/scripts/killall_sglang.sh
vendored
Executable file
@@ -0,0 +1,60 @@
|
||||
#!/bin/bash
|
||||
|
||||
# DEPRECATED: This script will be migrated to python/sglang/cli/killall.py.
|
||||
# CI mode is already handled there. This script remains for local/non-CI usage.
|
||||
#
|
||||
# TODO: Migrate remaining modes (rocm, all, gpus) to killall.py and remove this file.
|
||||
#
|
||||
# Usage:
|
||||
# ./killall_sglang.sh - Kill SGLang processes only (NVIDIA mode)
|
||||
# ./killall_sglang.sh rocm - Kill SGLang processes only (ROCm mode)
|
||||
# ./killall_sglang.sh all - Kill all GPU processes (NVIDIA mode)
|
||||
# ./killall_sglang.sh gpus 0,1,2,3 - Kill all processes on specific GPUs
|
||||
|
||||
if [ "$1" = "rocm" ]; then
|
||||
echo "Running in ROCm mode"
|
||||
|
||||
# Clean SGLang processes
|
||||
pgrep -f 'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt|sgl_diffusion::' | xargs -r kill -9
|
||||
|
||||
elif [ "$1" = "gpus" ] && [ -n "$2" ]; then
|
||||
# Kill all processes on specific GPUs only
|
||||
echo "Killing all processes on GPUs: $2"
|
||||
|
||||
# Show current GPU status
|
||||
nvidia-smi
|
||||
|
||||
# Build device file list from GPU IDs (e.g., "0,1,2,3" -> "/dev/nvidia0 /dev/nvidia1 ...")
|
||||
devices=$(echo "$2" | tr ',' '\n' | sed 's/^[[:space:]]*//;s/[[:space:]]*$//' | sed 's|^|/dev/nvidia|' | tr '\n' ' ')
|
||||
echo "Targeting devices: $devices"
|
||||
|
||||
# Kill all processes using specified GPU devices
|
||||
[ -n "$devices" ] && lsof $devices 2>/dev/null | awk 'NR>1 {print $2}' | sort -u | xargs -r kill -9 2>/dev/null
|
||||
|
||||
# Show GPU status after clean up
|
||||
nvidia-smi
|
||||
|
||||
else
|
||||
# Show current GPU status
|
||||
nvidia-smi
|
||||
|
||||
# Clean SGLang processes
|
||||
pgrep -f 'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt|sgl_diffusion::' | xargs -r kill -9
|
||||
|
||||
# Clean all GPU processes if "all" argument is provided
|
||||
if [ "$1" = "all" ]; then
|
||||
# Check if sudo is available
|
||||
if command -v sudo >/dev/null 2>&1; then
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y lsof
|
||||
else
|
||||
apt-get update
|
||||
apt-get install -y lsof
|
||||
fi
|
||||
kill -9 $(nvidia-smi | sed -n '/Processes:/,$p' | grep " [0-9]" | awk '{print $5}') 2>/dev/null
|
||||
lsof /dev/nvidia* | awk '{print $2}' | xargs kill -9 2>/dev/null
|
||||
fi
|
||||
|
||||
# Show GPU status after clean up
|
||||
nvidia-smi
|
||||
fi
|
||||
319
third_party/sglang/scripts/playground/bench_speculative.py
vendored
Normal file
319
third_party/sglang/scripts/playground/bench_speculative.py
vendored
Normal file
@@ -0,0 +1,319 @@
|
||||
"""
|
||||
Usage:
|
||||
# single GPU
|
||||
python3 bench_speculative.py --model-path meta-llama/Llama-2-7b-chat-hf --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B
|
||||
|
||||
# multiple GPU
|
||||
python3 bench_speculative.py --model-path deepseek-ai/DeepSeek-V3 --speculative-draft-model-path lmsys/DeepSeek-V3-NextN --tp-size 8 --trust-remote-code --batch-size 1 4 8 16 32 --steps 0 1 2 --topk 0 1 2 4 --num_draft_tokens 0 2 4 8
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from sglang.bench_serving import benchmark, set_global_args
|
||||
from sglang.benchmark.datasets import DatasetRow
|
||||
from sglang.benchmark.datasets.mmmu import sample_mmmu_requests
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
kill_process_tree,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
def node0_print(msg):
|
||||
if server_args.node_rank == 0:
|
||||
print(msg)
|
||||
|
||||
|
||||
prompts = [
|
||||
"Human: Give me a fully functional FastAPI server. Show the full, long python code without stop.\n\nAssistant:",
|
||||
"Human: Imagine you are an experienced Ethereum developer tasked with creating a smart contract for a blockchain messenger. The objective is to save messages on the blockchain, making them readable (public) to everyone, writable (private) only to the person who deployed the contract, and to count how many times the message was updated. Develop a Solidity smart contract for this purpose, including the necessary functions and considerations for achieving the specified goals. Please provide the code and any relevant explanations to ensure a clear understanding of the implementation.\n\nAssistant:",
|
||||
"Human: Write a travel blog post to Hawaii.\n\nAssistant:",
|
||||
"Human: I want you to act as an English translator, spelling corrector and improver. I will speak to you in any language and you will detect the language, translate it and answer in the corrected and improved version of my text, in English. I want you to replace my simplified A0-level words and sentences with more beautiful and elegant, upper level English words and sentences. Keep the meaning same, but make them more literary. My first sentence is 'istanbulu cok seviyom burada olmak cok guzel'. Answer in more than 5000 words.\n\nAssistant:",
|
||||
"Human: I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if it’s children then you can talk about animals; If it’s adults then history-based tales might engage them better etc. Answer in more than 5000 words. My first request is 'I need an interesting story on perseverance.'\n\nAssistant:",
|
||||
"Human: Solve x^2 = -1. Think step-by-step. Give me a long detailed explanation. \n\nAssistant:",
|
||||
"Human: Tell me about the president of the USA in wikipedia style.\n\nAssistant:",
|
||||
"Human: Hello? Who are you? Write code, math, and poem to explanin yourself.\n\nAssistant:",
|
||||
]
|
||||
|
||||
|
||||
class FakeTokenizer:
|
||||
def encode(self, text: str, add_special_tokens: bool = False):
|
||||
return []
|
||||
|
||||
|
||||
def send_one_batch(base_url, num_prompts, batch_size, processor, is_multimodal):
|
||||
# format: (prompt, input_len, output len). We set input_len as a dummy value 0.
|
||||
if is_multimodal:
|
||||
backend = "sglang-oai-chat"
|
||||
api_url = f"{base_url}/v1/chat/completions"
|
||||
input_requests = sample_mmmu_requests(
|
||||
num_prompts,
|
||||
processor,
|
||||
backend=backend,
|
||||
fixed_output_len=512,
|
||||
)
|
||||
tokenizer = processor.tokenizer
|
||||
else:
|
||||
padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
|
||||
:num_prompts
|
||||
]
|
||||
input_requests: List[DatasetRow] = [
|
||||
DatasetRow(p, 0, 512) for p in padded_prompts
|
||||
]
|
||||
backend = "sglang"
|
||||
api_url = f"{base_url}/generate"
|
||||
tokenizer = processor
|
||||
|
||||
# We need to set some dummy values in order to call `benchmark` below.
|
||||
args = SimpleNamespace(
|
||||
disable_ignore_eos=False,
|
||||
disable_stream=False,
|
||||
return_logprob=False,
|
||||
return_routed_experts=False,
|
||||
plot_throughput=False,
|
||||
backend=backend,
|
||||
dataset_name="custom",
|
||||
num_prompts=None,
|
||||
sharegpt_output_len=None,
|
||||
random_input_len=None,
|
||||
random_output_len=None,
|
||||
random_range_ratio=None,
|
||||
output_file=None,
|
||||
warmup_requests=1,
|
||||
output_details=False,
|
||||
)
|
||||
set_global_args(args)
|
||||
|
||||
# Run benchmark
|
||||
results = asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id="default",
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
request_rate=float("inf"),
|
||||
max_concurrency=batch_size,
|
||||
disable_tqdm=False,
|
||||
lora_names=None,
|
||||
lora_request_distribution=None,
|
||||
lora_zipf_alpha=None,
|
||||
extra_request_body={},
|
||||
profile=None,
|
||||
)
|
||||
)
|
||||
|
||||
assert results["completed"] == len(input_requests)
|
||||
acc_length = results["accept_length"] or 1.0
|
||||
avg_output_token = results["total_output_tokens"] / results["completed"]
|
||||
|
||||
server_info = requests.get(base_url + "/server_info").json()
|
||||
# We use 20% percentile instead of median on purpose
|
||||
step_time = np.percentile(
|
||||
server_info["internal_states"][0]["step_time_dict"][str(batch_size)], 20
|
||||
)
|
||||
speed = 1 / step_time * acc_length
|
||||
|
||||
return (
|
||||
round(acc_length, 3),
|
||||
round(step_time, 5),
|
||||
round(speed, 3),
|
||||
avg_output_token,
|
||||
)
|
||||
|
||||
|
||||
def main(args, server_args):
|
||||
base_url = "http://127.0.0.1:20000"
|
||||
|
||||
configs = []
|
||||
for batch_size in args.batch_size:
|
||||
for steps in args.steps:
|
||||
for topk in args.topk:
|
||||
for num_draft_tokens in args.num_draft_tokens:
|
||||
if steps * topk + 1 < num_draft_tokens:
|
||||
continue
|
||||
|
||||
if (steps == 0 or topk == 0 or num_draft_tokens == 0) and (
|
||||
steps + topk + num_draft_tokens != 0
|
||||
):
|
||||
# steps == 0 and topk == 0 and num_draft_tokens == 0 is a special case for non-speculative decoding.
|
||||
continue
|
||||
|
||||
configs.append((batch_size, steps, topk, num_draft_tokens))
|
||||
|
||||
for i in range(args.start, args.end or len(configs)):
|
||||
batch_size, steps, topk, num_draft_tokens = configs[i]
|
||||
|
||||
node0_print(
|
||||
f"Start {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}"
|
||||
)
|
||||
|
||||
# Create an LLM.
|
||||
if steps == 0:
|
||||
other_args = []
|
||||
else:
|
||||
other_args = [
|
||||
"--speculative-num-steps",
|
||||
steps,
|
||||
"--speculative-eagle-topk",
|
||||
topk,
|
||||
"--speculative-num-draft-tokens",
|
||||
num_draft_tokens,
|
||||
]
|
||||
if server_args.speculative_draft_model_path is not None:
|
||||
other_args.extend(
|
||||
[
|
||||
"--speculative-draft-model-path",
|
||||
server_args.speculative_draft_model_path,
|
||||
"--speculative-algorithm",
|
||||
server_args.speculative_algorithm,
|
||||
]
|
||||
)
|
||||
|
||||
other_args.extend(
|
||||
[
|
||||
"--cuda-graph-max-bs",
|
||||
batch_size,
|
||||
"--mem-fraction-static",
|
||||
server_args.mem_fraction_static,
|
||||
"--tp-size",
|
||||
server_args.tp_size,
|
||||
"--max-running-requests",
|
||||
batch_size,
|
||||
]
|
||||
)
|
||||
|
||||
if server_args.trust_remote_code:
|
||||
other_args.extend(
|
||||
[
|
||||
"--trust-remote-code",
|
||||
]
|
||||
)
|
||||
|
||||
if server_args.attention_backend:
|
||||
other_args.extend(
|
||||
[
|
||||
"--attention-backend",
|
||||
server_args.attention_backend,
|
||||
]
|
||||
)
|
||||
|
||||
if server_args.quantization:
|
||||
other_args.extend(
|
||||
[
|
||||
"--quantization",
|
||||
server_args.quantization,
|
||||
]
|
||||
)
|
||||
|
||||
process = popen_launch_server(
|
||||
args.model_path,
|
||||
base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=other_args,
|
||||
env={
|
||||
"SGLANG_RECORD_STEP_TIME": "1",
|
||||
**os.environ,
|
||||
},
|
||||
)
|
||||
|
||||
if args.is_multimodal:
|
||||
from transformers import AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
args.model_path, trust_remote_code=server_args.trust_remote_code
|
||||
)
|
||||
else:
|
||||
processor = AutoTokenizer.from_pretrained(
|
||||
args.model_path, trust_remote_code=server_args.trust_remote_code
|
||||
)
|
||||
|
||||
try:
|
||||
# Warmup
|
||||
send_one_batch(
|
||||
base_url, batch_size, batch_size, processor, args.is_multimodal
|
||||
)
|
||||
|
||||
# Benchmark
|
||||
acc_length, step_time, speed, completion_tokens = send_one_batch(
|
||||
base_url,
|
||||
max(args.num_prompts, batch_size),
|
||||
batch_size,
|
||||
processor,
|
||||
args.is_multimodal,
|
||||
)
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
node0_print(
|
||||
f"Finish {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}, {speed=:.2f} token/s, step_time={step_time * 1000:.2f} ms"
|
||||
)
|
||||
|
||||
record = {
|
||||
"batch_size": batch_size,
|
||||
"steps": steps,
|
||||
"topk": topk,
|
||||
"num_draft_tokens": num_draft_tokens,
|
||||
"acc_length": acc_length,
|
||||
"step_time": step_time,
|
||||
"speed": speed,
|
||||
"completion_tokens": completion_tokens,
|
||||
}
|
||||
|
||||
with open(args.output, "a") as fout:
|
||||
fout.write(json.dumps(record) + "\n")
|
||||
|
||||
# Wait for the server to shutdown
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
|
||||
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
ServerArgs.add_cli_args(parser)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=(1, 2, 4, 8, 16),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=(0, 1, 3, 5, 7), # use (0, 1, 2, 3, 4) for large batch size
|
||||
)
|
||||
parser.add_argument(
|
||||
"--topk",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=(0, 1, 2, 4, 8),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_draft_tokens",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=(0, 2, 4, 8, 16, 32), # use (0, 2, 4, 8) for large batch size
|
||||
)
|
||||
parser.add_argument("--num-prompts", type=int, default=16)
|
||||
parser.add_argument("--start", type=int, default=0)
|
||||
parser.add_argument("--end", type=int)
|
||||
parser.add_argument("--output", type=str, default="output.jsonl")
|
||||
parser.add_argument("--is-multimodal", action="store_true", default=False)
|
||||
args = parser.parse_args()
|
||||
server_args: ServerArgs = ServerArgs.from_cli_args(args)
|
||||
|
||||
main(args, server_args)
|
||||
22
third_party/sglang/scripts/playground/disaggregation/cli-logprob.py
vendored
Normal file
22
third_party/sglang/scripts/playground/disaggregation/cli-logprob.py
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
prompt = "The capital of france is "
|
||||
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
response = requests.post(
|
||||
"http://0.0.0.0:8000/generate",
|
||||
json={
|
||||
"text": prompt,
|
||||
"sampling_params": {"temperature": 0},
|
||||
"return_logprob": True,
|
||||
"return_input_logprob": True,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
|
||||
j = response.json()
|
||||
input_logprobs = j["meta_info"]["input_token_logprobs"]
|
||||
output_logprobs = j["meta_info"]["output_token_logprobs"]
|
||||
|
||||
print(len(input_logprobs), len(output_logprobs))
|
||||
34
third_party/sglang/scripts/playground/disaggregation/cli-so.py
vendored
Normal file
34
third_party/sglang/scripts/playground/disaggregation/cli-so.py
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
port = 8000
|
||||
|
||||
json_schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "pattern": "^[\\w]+$"},
|
||||
"population": {"type": "integer"},
|
||||
},
|
||||
"required": ["name", "population"],
|
||||
}
|
||||
)
|
||||
|
||||
# JSON
|
||||
response = requests.post(
|
||||
f"http://localhost:{port}/generate",
|
||||
json={
|
||||
"text": "Here is the information of the capital of France in the JSON format.\n",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 64,
|
||||
"json_schema": json_schema,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
print(response.json())
|
||||
|
||||
|
||||
# python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --trust-remote-code --disaggregation-mode prefill --tp 2 --disaggregation-ib-device mlx5_roce0,mlx5_roce1 --speculative-algorithm EAGLE --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8 --host 127.0.0.1 --port 8100
|
||||
29
third_party/sglang/scripts/playground/disaggregation/cli.py
vendored
Normal file
29
third_party/sglang/scripts/playground/disaggregation/cli.py
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
prompt = """
|
||||
According to CNBC's Faber, the investors present on the call interpreted this statement as an indication of an upcoming funding round. While speculative, Faber believes the funding round could be as large as $25 billion, and bestow a valuation of between $150 billion and $200 billion on xAI.
|
||||
|
||||
For the benefit of those who might not be aware, xAI recently acquired the social media platform X in an all-stock deal that valued the former at $80 billion and the latter at $33 billion, inclusive of $12 billion in liabilities. This meant that the deal bestowed a gross valuation of $45 billion on X before factoring in its debt load of $12 billion.
|
||||
|
||||
Bear in mind that Elon Musk took X (then called Twitter) private back in 2022 in a $44 billion deal. Since then, Musk has managed to stem X's cash bleed, with the company reportedly generating $1.2 billion in adjusted EBITDA in 2024.
|
||||
|
||||
According to the investors present on the call, xAI is currently generating around $1 billion in annual revenue. This contrasts sharply with the erstwhile muted expectations of many investors, who did not expect the startup to generate any material revenue this year.
|
||||
|
||||
Elsewhere, Faber also alludes to the fact that xAI is already working on its next big training supercluster, officially dubbed the Colossus 2, which is expected to eventually house as many as 1 million NVIDIA GPUs at a cost of between $35 billion and $40 billion.
|
||||
|
||||
|
||||
Even though xAI's Grok LLM is already largely comparable with OpenAI's cutting-edge models, the Colossus 2 would significantly up the ante, and could feasibly challenge OpenAI's apex position in the AI sphere.
|
||||
|
||||
Give your honest take on the above text:
|
||||
"""
|
||||
|
||||
response = requests.post(
|
||||
"http://0.0.0.0:8000/generate",
|
||||
json={"text": prompt, "sampling_params": {"temperature": 0}},
|
||||
)
|
||||
|
||||
|
||||
response_json = response.json()
|
||||
print(response_json["text"])
|
||||
240
third_party/sglang/scripts/playground/frontend_reasoning.ipynb
vendored
Normal file
240
third_party/sglang/scripts/playground/frontend_reasoning.ipynb
vendored
Normal file
@@ -0,0 +1,240 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Launch A Server\n",
|
||||
"\n",
|
||||
"Launch the server with a reasoning model (Qwen 3.5-4B) and reasoning parser."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sglang import separate_reasoning, assistant_begin, assistant_end\n",
|
||||
"from sglang import assistant, function, gen, system, user\n",
|
||||
"from sglang import image\n",
|
||||
"from sglang import RuntimeEndpoint, set_default_backend\n",
|
||||
"from sglang.srt.utils import load_image\n",
|
||||
"from sglang.test.test_utils import is_in_ci\n",
|
||||
"from sglang.utils import print_highlight, terminate_process, wait_for_server\n",
|
||||
"\n",
|
||||
"if is_in_ci():\n",
|
||||
" from patch import launch_server_cmd\n",
|
||||
"else:\n",
|
||||
" from sglang.utils import launch_server_cmd\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"server_process, port = launch_server_cmd(\n",
|
||||
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen3-4B --reasoning-parser qwen3 --host 0.0.0.0\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"wait_for_server(f\"http://localhost:{port}\", process=server_process)\n",
|
||||
"print(f\"Server started on http://localhost:{port}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Set the default backend. Note: you can set chat_template_name in RontimeEndpoint. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"set_default_backend(\n",
|
||||
" RuntimeEndpoint(f\"http://localhost:{port}\", chat_template_name=\"qwen\")\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's start with a basic question-answering task. And see how the reasoning content is generated."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@function\n",
|
||||
"def basic_qa(s, question):\n",
|
||||
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
|
||||
" s += user(question)\n",
|
||||
" s += assistant_begin()\n",
|
||||
" s += gen(\"answer\", max_tokens=512)\n",
|
||||
" s += assistant_end()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"state = basic_qa(\"List 3 countries and their capitals.\")\n",
|
||||
"print_highlight(state[\"answer\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"With `separate_reasoning`, you can move the reasoning content to `{param_name}_reasoning_content` in the state."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@function\n",
|
||||
"def basic_qa_separate_reasoning(s, question):\n",
|
||||
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
|
||||
" s += user(question)\n",
|
||||
" s += assistant_begin()\n",
|
||||
" s += separate_reasoning(gen(\"answer\", max_tokens=512), model_type=\"qwen3\")\n",
|
||||
" s += assistant_end()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"reasoning_state = basic_qa_separate_reasoning(\"List 3 countries and their capitals.\")\n",
|
||||
"print_highlight(reasoning_state.stream_executor.variable_event.keys())\n",
|
||||
"print_highlight(\n",
|
||||
" f\"\\nSeparated Reasoning Content:\\n{reasoning_state['answer_reasoning_content']}\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print_highlight(f\"\\n\\nContent:\\n{reasoning_state['answer']}\")\n",
|
||||
"print_highlight(f\"\\n\\nMessages:\\n{reasoning_state.messages()[-1]}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`separate_reasoning` can also be used in multi-turn conversations."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@function\n",
|
||||
"def multi_turn_qa(s):\n",
|
||||
" s += system(f\"You are a helpful assistant than can answer questions.\")\n",
|
||||
" s += user(\"Please give me a list of 3 countries and their capitals.\")\n",
|
||||
" s += assistant(\n",
|
||||
" separate_reasoning(gen(\"first_answer\", max_tokens=512), model_type=\"qwen3\")\n",
|
||||
" )\n",
|
||||
" s += user(\"Please give me another list of 3 countries and their capitals.\")\n",
|
||||
" s += assistant(\n",
|
||||
" separate_reasoning(gen(\"second_answer\", max_tokens=512), model_type=\"qwen3\")\n",
|
||||
" )\n",
|
||||
" return s\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"reasoning_state = multi_turn_qa()\n",
|
||||
"print_highlight(f\"\\n\\nfirst_answer:\\n{reasoning_state['first_answer']}\")\n",
|
||||
"print_highlight(\n",
|
||||
" f\"\\n\\nfirst_answer_reasoning_content:\\n{reasoning_state['first_answer_reasoning_content']}\"\n",
|
||||
")\n",
|
||||
"print_highlight(f\"\\n\\nsecond_answer:\\n{reasoning_state['second_answer']}\")\n",
|
||||
"print_highlight(\n",
|
||||
" f\"\\n\\nsecond_answer_reasoning_content:\\n{reasoning_state['second_answer_reasoning_content']}\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using No thinking as Qwen 3's advanced feature \n",
|
||||
"\n",
|
||||
"sglang separate_reasoning is particularly useful when combined with Qwen 3's advanced feature.\n",
|
||||
"\n",
|
||||
"[Qwen 3's advanced usages](https://qwenlm.github.io/blog/qwen3/#advanced-usages)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"reasoning_state = basic_qa_separate_reasoning(\n",
|
||||
" \"List 3 countries and their capitals. /no_think\"\n",
|
||||
")\n",
|
||||
"print_highlight(f\"Reasoning Content:\\n{reasoning_state['answer_reasoning_content']}\")\n",
|
||||
"print_highlight(f\"Content:\\n{reasoning_state['answer']}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`separate_reasoning` can also be used in regular expression generation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@function\n",
|
||||
"def regular_expression_gen(s):\n",
|
||||
" s += user(\n",
|
||||
" \"What is the IP address of the Google DNS servers? just provide the answer\"\n",
|
||||
" )\n",
|
||||
" s += assistant(\n",
|
||||
" separate_reasoning(\n",
|
||||
" gen(\n",
|
||||
" \"answer\",\n",
|
||||
" temperature=0,\n",
|
||||
" regex=r\"((25[0-5]|2[0-4]\\d|[01]?\\d\\d?).){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\",\n",
|
||||
" max_tokens=512,\n",
|
||||
" ),\n",
|
||||
" model_type=\"qwen3\",\n",
|
||||
" ),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"reasoning_state = regular_expression_gen()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print_highlight(f\"Answer:\\n{reasoning_state['answer']}\")\n",
|
||||
"print_highlight(\n",
|
||||
" f\"\\n\\nReasoning Content:\\n{reasoning_state['answer_reasoning_content']}\"\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
14
third_party/sglang/scripts/playground/load_tokenizer.py
vendored
Normal file
14
third_party/sglang/scripts/playground/load_tokenizer.py
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
import argparse
|
||||
import code
|
||||
|
||||
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
t = get_tokenizer(args.name)
|
||||
code.interact(local=locals())
|
||||
36
third_party/sglang/scripts/playground/long_context_example.py
vendored
Normal file
36
third_party/sglang/scripts/playground/long_context_example.py
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
from urllib.request import urlopen
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
test_cases = {
|
||||
"64k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt",
|
||||
"200k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt",
|
||||
"600k": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt",
|
||||
"1m": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt",
|
||||
}
|
||||
|
||||
client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")
|
||||
|
||||
for name, url in test_cases.items():
|
||||
print(f"\n==== Running test case: {name} ====")
|
||||
try:
|
||||
with urlopen(url, timeout=10) as response:
|
||||
prompt = response.read().decode("utf-8")
|
||||
except Exception as e:
|
||||
print(f"Failed to load prompt for {name}: {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
stream=True,
|
||||
max_tokens=128,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
for chunk in response:
|
||||
if chunk.choices and chunk.choices[0].delta.content is not None:
|
||||
print(chunk.choices[0].delta.content, end="", flush=True)
|
||||
except Exception as e:
|
||||
print(f"\nError during completion for {name}: {e}")
|
||||
77
third_party/sglang/scripts/playground/lora/analyzer.py
vendored
Normal file
77
third_party/sglang/scripts/playground/lora/analyzer.py
vendored
Normal file
@@ -0,0 +1,77 @@
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
sys.path.append("../../")
|
||||
from fix_corrupted_json import clean_json_file
|
||||
|
||||
dirpath = "/Users/ying"
|
||||
output_file_prefix = "analyzed_log"
|
||||
|
||||
time = {}
|
||||
tot_time = {}
|
||||
size = {}
|
||||
|
||||
os.system(f"rm {output_file_prefix}*")
|
||||
|
||||
for dirname in glob.glob(os.path.join(dirpath, "trace*")):
|
||||
print(dirname)
|
||||
trace_name = dirname.split("/")[-1]
|
||||
time[trace_name] = {}
|
||||
size[trace_name] = {}
|
||||
total_time = 0
|
||||
for filename in tqdm(glob.glob(os.path.join(dirname, "*.json"))):
|
||||
step_name = filename.split("/")[-1].split(".")[0]
|
||||
step_name = "_".join(step_name.split("_")[1:])
|
||||
if "prefill" not in filename and "decode" not in filename:
|
||||
continue
|
||||
|
||||
match = re.search(r"(prefill|decode)_step_(\d+)\.json", filename)
|
||||
if match:
|
||||
phase = match.group(1)
|
||||
step = match.group(2)
|
||||
else:
|
||||
raise Exception(f"Cannot parse {filename}")
|
||||
|
||||
try:
|
||||
with open(filename, "r") as f:
|
||||
trace = json.load(f)
|
||||
except:
|
||||
clean_json_file(filename, filename)
|
||||
with open(filename, "r") as f:
|
||||
trace = json.load(f)
|
||||
|
||||
for event in trace["traceEvents"]:
|
||||
name = event["name"]
|
||||
if name in ["profile_prefill_step", "profile_decode_step"]:
|
||||
dur = event["dur"] / 1e3
|
||||
time[trace_name][step_name] = dur
|
||||
break
|
||||
total_time += dur
|
||||
|
||||
step = int(step_name.split("_")[-1])
|
||||
with open(os.path.join(dirname, f"size_{step}.json"), "r") as f:
|
||||
size_info = json.load(f)
|
||||
size[trace_name][step_name] = size_info["size"]
|
||||
|
||||
tot_time[trace_name] = total_time
|
||||
time[trace_name] = dict(
|
||||
sorted(time[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
|
||||
)
|
||||
size[trace_name] = dict(
|
||||
sorted(size[trace_name].items(), key=lambda x: int(x[0].split("_")[-1]))
|
||||
)
|
||||
|
||||
with open(f"{output_file_prefix}_{trace_name}", "a") as f:
|
||||
for k, v in time[trace_name].items():
|
||||
size_v = size[trace_name][k]
|
||||
print(f"{k:>15}{v:10.2f}\t{size_v}")
|
||||
f.write(f"{k:>15}{v:10.2f}\t{size_v}\n")
|
||||
|
||||
with open(f"{output_file_prefix}_total_time", "w") as f:
|
||||
print(tot_time)
|
||||
json.dump(tot_time, f)
|
||||
62
third_party/sglang/scripts/playground/lora/lora_hf_play.py
vendored
Normal file
62
third_party/sglang/scripts/playground/lora/lora_hf_play.py
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
# ADAPTER = "winddude/wizardLM-LlaMA-LoRA-7B"
|
||||
ADAPTER = "/home/ying/test_lora"
|
||||
HF_TOKEN = "..."
|
||||
|
||||
|
||||
prompt = """
|
||||
### Instruction:
|
||||
Write a poem about the transformers Python library.
|
||||
Mention the word "large language models" in that poem.
|
||||
### Response:
|
||||
The Transformers are large language models,
|
||||
They're used to make predictions on text.
|
||||
"""
|
||||
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(MODEL)
|
||||
|
||||
base_model = LlamaForCausalLM.from_pretrained(
|
||||
MODEL,
|
||||
device_map="auto",
|
||||
# load_in_8bit=True,
|
||||
torch_dtype=torch.float16,
|
||||
# use_auth_token=HF_TOKEN,
|
||||
).cuda()
|
||||
|
||||
|
||||
# base model generate
|
||||
with torch.no_grad():
|
||||
output_tensors = base_model.generate(
|
||||
input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
|
||||
max_new_tokens=32,
|
||||
do_sample=False,
|
||||
)[0]
|
||||
|
||||
output = tokenizer.decode(output_tensors, skip_special_tokens=True)
|
||||
print("======= base output ========")
|
||||
print(output)
|
||||
|
||||
|
||||
# peft model generate
|
||||
model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
ADAPTER,
|
||||
torch_dtype=torch.float16,
|
||||
is_trainable=False,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
output_tensors = model.generate(
|
||||
input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
|
||||
max_new_tokens=32,
|
||||
do_sample=False,
|
||||
)[0]
|
||||
|
||||
output = tokenizer.decode(output_tensors, skip_special_tokens=True)
|
||||
print("======= peft output ========")
|
||||
print(output)
|
||||
30
third_party/sglang/scripts/playground/lora/lora_vllm_play.py
vendored
Normal file
30
third_party/sglang/scripts/playground/lora/lora_vllm_play.py
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
ADAPTER = "/home/ying/test_lora"
|
||||
prompt = """
|
||||
### Instruction:
|
||||
Write a poem about the transformers Python library.
|
||||
Mention the word "large language models" in that poem.
|
||||
### Response:
|
||||
The Transformers are large language models,
|
||||
They're used to make predictions on text.
|
||||
"""
|
||||
|
||||
|
||||
llm = LLM(model=MODEL, enable_lora=True)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
)
|
||||
|
||||
prompts = [prompt]
|
||||
|
||||
outputs = llm.generate(
|
||||
prompts, sampling_params, lora_request=LoRARequest("test_lora", 1, ADAPTER)
|
||||
)
|
||||
|
||||
print(outputs[0].prompt)
|
||||
print(outputs[0].outputs[0].text)
|
||||
197
third_party/sglang/scripts/playground/reference_hf.py
vendored
Normal file
197
third_party/sglang/scripts/playground/reference_hf.py
vendored
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Usage: python3 scripts/playground/reference_hf.py --model-path MODEL_PATH --model-type {text,vlm} [--max-new-tokens NUM] [--dtype DTYPE]
|
||||
--model-path MODEL_PATH: Path to model (default: TinyLlama/TinyLlama-1.1B-Chat-v0.4)
|
||||
--model-type {text,vlm}: Model type, text or vlm (default: text)
|
||||
--max-new-tokens NUM: Max new tokens to generate (default: 16)
|
||||
--dtype DTYPE: Data type for computation (default: float16)
|
||||
Note: '--model' is deprecated; use '--model-path'. Runs normal_text() for text, vlm_text_with_image() for vlm.
|
||||
|
||||
Reference output:
|
||||
========== Prompt 0 ==========
|
||||
prefill logits (final) tensor([-8.3125, -7.1172, 3.3398, ..., -4.9531, -4.1328, -3.4141],
|
||||
device='cuda:0')
|
||||
<s> The capital of France is Paris.
|
||||
The capital of the United States is Washington, D.C.
|
||||
|
||||
========== Prompt 1 ==========
|
||||
prefill logits (final) tensor([-8.9062, -9.0156, 4.1484, ..., -4.9922, -4.4961, -4.0742],
|
||||
device='cuda:0')
|
||||
<s> The capital of the United Kindom is London.
|
||||
The capital of the United Kingdom is London.
|
||||
The capital of
|
||||
|
||||
========== Prompt 2 ==========
|
||||
prefill logits (final) tensor([-9.6328, -9.0547, 4.0234, ..., -5.3047, -4.7148, -4.4609],
|
||||
device='cuda:0')
|
||||
<s> Today is a sunny day and I like to go for a walk in the park.
|
||||
I'm going to the
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForImageTextToText,
|
||||
AutoProcessor,
|
||||
)
|
||||
|
||||
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def vlm_text_with_image(args):
|
||||
# Load the processor and model for ImageTextToText tasks
|
||||
processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=args.dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# List of image URLs to process
|
||||
image_urls = [
|
||||
"https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
|
||||
]
|
||||
|
||||
# Conversation template for the processor
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
},
|
||||
{"type": "text", "text": "Describe this image."},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
max_new_tokens = args.max_new_tokens
|
||||
|
||||
for i, url in enumerate(image_urls):
|
||||
# Load the image from the URL
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
# Apply the chat template to the text prompt
|
||||
# Notice that not all processors support chat templates.
|
||||
# LLaVA and QWen are two processors that support chat templates.
|
||||
if not hasattr(processor, "apply_chat_template"):
|
||||
raise ValueError("The processor does not support chat templates.")
|
||||
text_prompt = processor.apply_chat_template(
|
||||
conversation, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Prepare inputs for the model
|
||||
inputs = processor(text=[text_prompt], images=[image], return_tensors="pt").to(
|
||||
"cuda:0"
|
||||
)
|
||||
|
||||
# Generate output from the model
|
||||
output_ids = model.generate(
|
||||
**inputs, do_sample=False, max_new_tokens=max_new_tokens
|
||||
)
|
||||
output_str = processor.decode(output_ids[0])
|
||||
|
||||
# Get the logits from the model's forward pass
|
||||
outputs = model.forward(**inputs)
|
||||
logits = outputs.logits[0, -1, :]
|
||||
|
||||
print(f"\n========== Image {i} ==========")
|
||||
print("prefill logits (final)", logits)
|
||||
# TODO(gaocegege): The output contains numerous <|image_pad|> tokens,
|
||||
# making it cluttered and difficult to read.
|
||||
# These tokens should be removed or cleaned up for better readability.
|
||||
print(output_str)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def normal_text(args):
|
||||
t = get_tokenizer(args.model_path, trust_remote_code=True)
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=args.dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
"The capital of the United Kindom is",
|
||||
"Today is a sunny day and I like",
|
||||
]
|
||||
max_new_tokens = args.max_new_tokens
|
||||
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
for i, p in enumerate(prompts):
|
||||
if isinstance(p, str):
|
||||
input_ids = t.encode(p, return_tensors="pt").to("cuda:0")
|
||||
else:
|
||||
input_ids = torch.tensor([p], device="cuda:0")
|
||||
|
||||
output_ids = m.generate(
|
||||
input_ids, do_sample=False, max_new_tokens=max_new_tokens
|
||||
)
|
||||
output_str = t.decode(output_ids[0])
|
||||
|
||||
prefill_logits = m.forward(input_ids).logits[0][-1]
|
||||
|
||||
print(f"\n========== Prompt {i} ==========")
|
||||
print("prefill logits (final)", prefill_logits)
|
||||
print(output_str)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def synthetic_tokens(args):
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||
)
|
||||
m.cuda()
|
||||
print(m)
|
||||
|
||||
input_len = 256
|
||||
output_len = 8
|
||||
prompts = [list(range(5, 5 + input_len))]
|
||||
|
||||
for p in prompts:
|
||||
input_ids = p
|
||||
for i in range(output_len + 1):
|
||||
prefill_logits = m.forward(torch.tensor([input_ids], device="cuda")).logits[
|
||||
0
|
||||
][-1]
|
||||
|
||||
if i == 0:
|
||||
print("prefill logits", prefill_logits)
|
||||
else:
|
||||
print("decode", i - 1, prefill_logits)
|
||||
|
||||
input_ids.append(torch.argmax(prefill_logits).item())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
|
||||
)
|
||||
parser.add_argument("--max-new-tokens", type=int, default=16)
|
||||
|
||||
parser.add_argument("--dtype", type=str, default="float16")
|
||||
|
||||
parser.add_argument("--model-type", type=str, default="text")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_type == "vlm":
|
||||
vlm_text_with_image(args)
|
||||
else:
|
||||
normal_text(args)
|
||||
181
third_party/sglang/scripts/playground/replay_request_dump.py
vendored
Normal file
181
third_party/sglang/scripts/playground/replay_request_dump.py
vendored
Normal file
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
Usage:
|
||||
# replay from a folder
|
||||
python3 replay_request_dump.py --file-number 100 --parallel 512 --input-folder /data/lianmin/sglang_request_dump/grok-mini-0220-engine-5756f8f94-28bm6/
|
||||
|
||||
# replay from a single file
|
||||
python3 replay_request_dump.py --parallel 512 --input-file /data/sglang_crash_dump/memx-cti-34-sr1.xpop.twttr.net/crash_dump_2025-06-04_20-13-18.pkl
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.benchmark.utils import set_ulimit
|
||||
from sglang.srt.utils.common import safe_pickle_load
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
|
||||
def normalize_mm_data_item(item):
|
||||
if isinstance(item, dict) and "url" in item:
|
||||
return item["url"]
|
||||
return item
|
||||
|
||||
|
||||
def normalize_mm_data(data):
|
||||
if data is None:
|
||||
return None
|
||||
if isinstance(data, list):
|
||||
return [
|
||||
(
|
||||
[normalize_mm_data_item(item) for item in sublist]
|
||||
if isinstance(sublist, list)
|
||||
else normalize_mm_data_item(sublist)
|
||||
)
|
||||
for sublist in data
|
||||
]
|
||||
return normalize_mm_data_item(data)
|
||||
|
||||
|
||||
def normalize_request_data(json_data):
|
||||
"""Normalize multimodal fields in request data for replay compatibility."""
|
||||
for field in ["image_data", "video_data", "audio_data"]:
|
||||
if field in json_data and json_data[field] is not None:
|
||||
json_data[field] = normalize_mm_data(json_data[field])
|
||||
return json_data
|
||||
|
||||
|
||||
def read_records(files):
|
||||
records = []
|
||||
for f in files:
|
||||
with open(f, "rb") as fh:
|
||||
tmp = safe_pickle_load(fh)
|
||||
if isinstance(tmp, dict) and "requests" in tmp:
|
||||
records.extend(tmp["requests"])
|
||||
else:
|
||||
records.extend(tmp)
|
||||
|
||||
return records
|
||||
|
||||
|
||||
def run_one_request_internal(record):
|
||||
req, output, replay_init_time, start_time, end_time, idx = record
|
||||
time.sleep(max(0, (start_time - (time.time() - replay_init_time)) / args.speed))
|
||||
|
||||
if "completion_tokens" in output.get("meta_info", {}):
|
||||
recorded_completion_tokens = output["meta_info"]["completion_tokens"]
|
||||
else:
|
||||
recorded_completion_tokens = ""
|
||||
|
||||
json_data = normalize_request_data(asdict(req))
|
||||
stream = json_data["stream"]
|
||||
|
||||
if args.ignore_eos:
|
||||
json_data["sampling_params"]["ignore_eos"] = True
|
||||
if recorded_completion_tokens:
|
||||
json_data["sampling_params"]["max_new_tokens"] = recorded_completion_tokens
|
||||
|
||||
response = requests.post(
|
||||
f"http://{args.host}:{args.port}/generate",
|
||||
json=json_data,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if stream:
|
||||
for chunk in response.iter_lines(decode_unicode=False):
|
||||
chunk = chunk.decode("utf-8")
|
||||
if chunk and chunk.startswith("data:"):
|
||||
if chunk == "data: [DONE]":
|
||||
break
|
||||
ret = json.loads(chunk[5:].strip("\n"))
|
||||
else:
|
||||
ret = response.json()
|
||||
|
||||
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
||||
completion_tokens = ret["meta_info"]["completion_tokens"]
|
||||
print(
|
||||
f"{idx=}, {start_time=:.2f}, {prompt_tokens=}, "
|
||||
f"{completion_tokens=}, {recorded_completion_tokens=}"
|
||||
)
|
||||
|
||||
|
||||
def run_one_request(record):
|
||||
# global success_ct, error_ct
|
||||
|
||||
try:
|
||||
run_one_request_internal(record)
|
||||
# success_ct += 1
|
||||
except Exception:
|
||||
# error_ct += 1
|
||||
traceback = get_exception_traceback()
|
||||
print(f"Hit an exception: {traceback}")
|
||||
|
||||
|
||||
def main(records):
|
||||
if len(records) == 0:
|
||||
return
|
||||
|
||||
base_time = records[0][-2]
|
||||
base_time_str = datetime.fromtimestamp(base_time).strftime("%y-%m-%d %H:%M:%S")
|
||||
print(f"{base_time_str=}")
|
||||
replay_init_time = time.time()
|
||||
|
||||
for i in range(len(records)):
|
||||
req, output, start_time, end_time = records[i]
|
||||
start_time -= base_time
|
||||
records[i] = (req, output, replay_init_time, start_time, end_time, i)
|
||||
|
||||
with ThreadPoolExecutor(args.parallel) as executor:
|
||||
executor.map(run_one_request, records)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
parser.add_argument(
|
||||
"--input-folder", type=str, default=None, help="Folder containing pickle files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-file", type=str, default=None, help="Single pickle file to process"
|
||||
)
|
||||
parser.add_argument("--file-number", type=int, default=1)
|
||||
parser.add_argument("--req-number", type=int, default=1000000)
|
||||
parser.add_argument("--req-start", type=int, default=0)
|
||||
parser.add_argument("--parallel", type=int, default=512)
|
||||
parser.add_argument("--idx", type=int, default=None)
|
||||
parser.add_argument("--ignore-eos", action="store_true")
|
||||
parser.add_argument("--speed", type=float, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
set_ulimit()
|
||||
|
||||
files = []
|
||||
if args.input_file:
|
||||
files = [args.input_file]
|
||||
if args.file_number > 1:
|
||||
print("Warning: --file-number is ignored when --input-file is provided.")
|
||||
elif args.input_folder:
|
||||
files = glob.glob(f"{args.input_folder}/*.pkl")
|
||||
files = files[: args.file_number]
|
||||
else:
|
||||
print("Error: Either --input-folder or --input-file must be provided.")
|
||||
exit(1)
|
||||
print(f"{files=}")
|
||||
|
||||
records = read_records(files)
|
||||
# Sort by the receive time, before filtering
|
||||
records.sort(key=lambda x: x[-2])
|
||||
records = records[args.req_start :]
|
||||
if args.idx:
|
||||
records = [records[args.idx]]
|
||||
print(f"testing {args.idx=}")
|
||||
print(f"{records[0]}")
|
||||
print(f"{len(records)=}")
|
||||
main(records)
|
||||
207
third_party/sglang/scripts/playground/router/test_tree.py
vendored
Normal file
207
third_party/sglang/scripts/playground/router/test_tree.py
vendored
Normal file
@@ -0,0 +1,207 @@
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from tree import MultiTenantRadixTree
|
||||
|
||||
|
||||
class TestMultiTenantRadixTree(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tree = MultiTenantRadixTree()
|
||||
|
||||
def test_insert_exact_match(self):
|
||||
"""Test 1: Basic insert and exact match operations"""
|
||||
# Insert a single string for one tenant
|
||||
self.tree.insert("hello", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("hello")
|
||||
self.assertEqual(matched, "hello")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Insert same string for different tenant
|
||||
self.tree.insert("hello", "tenant2")
|
||||
matched, tenant = self.tree.prefix_match("hello")
|
||||
self.assertIn(tenant, ["tenant1", "tenant2"])
|
||||
|
||||
# Insert different string for same tenant
|
||||
self.tree.insert("world", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("world")
|
||||
self.assertEqual(matched, "world")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
def test_insert_partial_match(self):
|
||||
"""Test 2: Insert with partial matching scenarios"""
|
||||
# Test partial matches with common prefixes
|
||||
self.tree.insert("hello", "tenant1")
|
||||
print(self.tree.pretty_print())
|
||||
self.tree.insert("help", "tenant2")
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
# Match exact strings
|
||||
matched, tenant = self.tree.prefix_match("hello")
|
||||
self.assertEqual(matched, "hello")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
matched, tenant = self.tree.prefix_match("help")
|
||||
self.assertEqual(matched, "help")
|
||||
self.assertEqual(tenant, "tenant2")
|
||||
|
||||
# Match partial string
|
||||
matched, tenant = self.tree.prefix_match("hel")
|
||||
self.assertEqual(matched, "hel")
|
||||
self.assertIn(tenant, ["tenant1", "tenant2"])
|
||||
|
||||
# Match longer string
|
||||
matched, tenant = self.tree.prefix_match("hello_world")
|
||||
self.assertEqual(matched, "hello")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
def test_insert_edge_cases(self):
|
||||
"""Test 3: Edge cases for insert and match operations"""
|
||||
# Empty string
|
||||
self.tree.insert("", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("")
|
||||
self.assertEqual(matched, "")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Single character
|
||||
self.tree.insert("a", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("a")
|
||||
self.assertEqual(matched, "a")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Very long string
|
||||
long_str = "a" * 1000
|
||||
self.tree.insert(long_str, "tenant1")
|
||||
matched, tenant = self.tree.prefix_match(long_str)
|
||||
self.assertEqual(matched, long_str)
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
# Unicode characters
|
||||
self.tree.insert("你好", "tenant1")
|
||||
matched, tenant = self.tree.prefix_match("你好")
|
||||
self.assertEqual(matched, "你好")
|
||||
self.assertEqual(tenant, "tenant1")
|
||||
|
||||
def test_simple_eviction(self):
|
||||
"""Test 4: Simple eviction scenarios
|
||||
Tenant1: limit 10 chars
|
||||
Tenant2: limit 5 chars
|
||||
|
||||
Should demonstrate:
|
||||
1. Basic eviction when size limit exceeded
|
||||
2. Proper eviction based on last access time
|
||||
3. Verification that shared nodes remain intact for other tenants
|
||||
"""
|
||||
# Set up size limits
|
||||
max_size = {"tenant1": 10, "tenant2": 5}
|
||||
|
||||
# Insert strings for both tenants
|
||||
self.tree.insert("hello", "tenant1") # size 5
|
||||
self.tree.insert("hello", "tenant2") # size 5
|
||||
self.tree.insert("world", "tenant2") # size 5, total for tenant2 = 10
|
||||
|
||||
# Verify initial sizes
|
||||
sizes_before = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_before["tenant1"], 5) # "hello" = 5
|
||||
self.assertEqual(sizes_before["tenant2"], 10) # "hello" + "world" = 10
|
||||
|
||||
# Evict - should remove "hello" from tenant2 as it's the oldest
|
||||
self.tree.evict_tenant_data(max_size)
|
||||
|
||||
# Verify sizes after eviction
|
||||
sizes_after = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_after["tenant1"], 5) # Should be unchanged
|
||||
self.assertEqual(sizes_after["tenant2"], 5) # Only "world" remains
|
||||
|
||||
# Verify "world" remains for tenant2 (was accessed more recently)
|
||||
matched, tenant = self.tree.prefix_match("world")
|
||||
self.assertEqual(matched, "world")
|
||||
self.assertEqual(tenant, "tenant2")
|
||||
|
||||
def test_medium_eviction(self):
|
||||
"""Test 5: Medium complexity eviction scenarios with shared prefixes
|
||||
Tenant1: limit 10 chars
|
||||
Tenant2: limit 7 chars (forces one string to be evicted)
|
||||
|
||||
Tree structure after inserts:
|
||||
└── 'h' [t1, t2]
|
||||
├── 'i' [t1, t2] # Oldest for t2
|
||||
└── 'e' [t1, t2]
|
||||
├── 'llo' [t1, t2]
|
||||
└── 'y' [t2] # Newest for t2
|
||||
|
||||
Size calculations:
|
||||
tenant1: "h"(1) + "i"(1) + "e"(1) + "llo"(3) = 6 chars
|
||||
tenant2: "h"(1) + "i"(1) + "e"(1) + "llo"(3) + "y"(1) = 7 chars
|
||||
|
||||
After eviction (tenant2 exceeds limit by 1 char):
|
||||
"hi" should be removed from tenant2 as it's the oldest access
|
||||
"""
|
||||
max_size = {
|
||||
"tenant1": 10,
|
||||
"tenant2": 6,
|
||||
} # tenant2 will need to evict one string
|
||||
|
||||
# Create a tree with overlapping prefixes
|
||||
self.tree.insert("hi", "tenant1")
|
||||
self.tree.insert("hi", "tenant2") # OLDEST for t2
|
||||
|
||||
self.tree.insert("hello", "tenant1")
|
||||
self.tree.insert("hello", "tenant2")
|
||||
|
||||
self.tree.insert("hey", "tenant2") # NEWEST for t2
|
||||
|
||||
# Verify initial sizes
|
||||
sizes_before = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_before["tenant1"], 6) # h(1) + i(1) + e(1) + llo(3) = 6
|
||||
self.assertEqual(
|
||||
sizes_before["tenant2"], 7
|
||||
) # h(1) + i(1) + e(1) + llo(3) + y(1) = 7
|
||||
|
||||
print("\nTree before eviction:")
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
# Evict - should remove "hi" from tenant2 as it's the oldest
|
||||
self.tree.evict_tenant_data(max_size)
|
||||
|
||||
print("\nTree after eviction:")
|
||||
print(self.tree.pretty_print())
|
||||
|
||||
# Verify sizes after eviction
|
||||
sizes_after = self.tree.get_used_size_per_tenant()
|
||||
self.assertEqual(sizes_after["tenant1"], 6) # Should be unchanged
|
||||
self.assertEqual(sizes_after["tenant2"], 6) # h(1) + e(1) + llo(3) + y(1) = 6
|
||||
|
||||
def test_advanced_eviction(self):
|
||||
...
|
||||
# Create 4 tenants
|
||||
# Each tenants keeps adding strings with shared prefixes to thousands usage
|
||||
# Set a strict limit for each tenant to only 100
|
||||
# At the end, check whether all of the tenant is under 100 after eviction
|
||||
|
||||
max_size = {"tenant1": 100, "tenant2": 100, "tenant3": 100, "tenant4": 100}
|
||||
|
||||
prefixes = ["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"]
|
||||
for i in range(100):
|
||||
for j, prefix in enumerate(prefixes):
|
||||
random_suffix = "".join(random.choices(string.ascii_letters, k=10))
|
||||
self.tree.insert(prefix + random_suffix, f"tenant{j+1}")
|
||||
|
||||
sizes_before = self.tree.get_used_size_per_tenant()
|
||||
print(sizes_before)
|
||||
|
||||
self.tree.evict_tenant_data(max_size)
|
||||
|
||||
sizes_after = self.tree.get_used_size_per_tenant()
|
||||
print(sizes_after)
|
||||
# ensure size_after is below max_size
|
||||
for tenant, size in sizes_after.items():
|
||||
self.assertLessEqual(size, max_size[tenant])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
292
third_party/sglang/scripts/playground/router/tree.py
vendored
Normal file
292
third_party/sglang/scripts/playground/router/tree.py
vendored
Normal file
@@ -0,0 +1,292 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class Node:
|
||||
def __init__(self):
|
||||
self.children: Dict[str, Node] = dict()
|
||||
# We choose to use text because most of the use cases are text-to-text,
|
||||
# so we can save the tokenizing overhead.
|
||||
self.text: str = ""
|
||||
# Maps tenant_id to their last access timestamp
|
||||
self.tenant_last_access_time: Dict[str, float] = dict()
|
||||
self.parent = None
|
||||
|
||||
|
||||
def shared_prefix_length(s1, s2):
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(min_length):
|
||||
if s1[i] != s2[i]:
|
||||
return i
|
||||
return min_length
|
||||
|
||||
|
||||
class MultiTenantRadixTree:
|
||||
"""
|
||||
Python Reference of Rust implementation of MultiTenantRadixTree
|
||||
|
||||
MultiTenantRadixTree is the overlap of multiple radix trees by different tenant
|
||||
Each node in the tree can be owned by multiple tenants, allowing for efficient storage of common prefixes
|
||||
while maintaining tenant isolation.
|
||||
|
||||
Key concepts:
|
||||
- Tenant: An entity that owns a subset of the stored strings
|
||||
- Each node tracks which tenants have access to it via tenant_last_access_time
|
||||
- The tree structure is shared, but queries can be filtered by tenant_id
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.root = Node()
|
||||
|
||||
def insert(self, s: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Insert string 's' and associate it with the given tenant_id.
|
||||
|
||||
Args:
|
||||
s: The string to insert
|
||||
tenant_id: The identifier of the tenant who owns this string
|
||||
"""
|
||||
curr = self.root
|
||||
curr_idx = 0
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
|
||||
while curr_idx < len(s):
|
||||
matched_node = None
|
||||
if s[curr_idx] in curr.children:
|
||||
matched_node = curr.children[s[curr_idx]]
|
||||
|
||||
if matched_node is None:
|
||||
# No match => create a new node
|
||||
new_node = Node()
|
||||
new_node.text = s[curr_idx:]
|
||||
new_node.parent = curr
|
||||
|
||||
curr.children[s[curr_idx]] = new_node
|
||||
curr_idx = len(s)
|
||||
curr = new_node
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
else:
|
||||
shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)
|
||||
|
||||
# 1. If the matched text is shorter than the node text => split the node
|
||||
if shared_len < len(matched_node.text):
|
||||
# Split structure: [matched_node] => [new_node] -> [contracted_matched_node]
|
||||
|
||||
matched_text = matched_node.text[:shared_len]
|
||||
unmatched_text = matched_node.text[shared_len:]
|
||||
|
||||
new_node = Node()
|
||||
new_node.text = matched_text
|
||||
new_node.children = {unmatched_text[0]: matched_node}
|
||||
new_node.parent = curr
|
||||
new_node.parent.children[matched_text[0]] = new_node
|
||||
new_node.tenant_last_access_time = (
|
||||
matched_node.tenant_last_access_time.copy()
|
||||
)
|
||||
|
||||
# Contract matched node
|
||||
matched_node.text = unmatched_text
|
||||
matched_node.parent = new_node
|
||||
|
||||
curr_idx += shared_len
|
||||
curr = new_node
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
# 2. If the matched text is longer or equal to the node text => walk down the node
|
||||
else:
|
||||
curr_idx += shared_len
|
||||
curr = matched_node
|
||||
curr.tenant_last_access_time[tenant_id] = time.time()
|
||||
|
||||
def prefix_match(self, s: str) -> tuple[str, int]:
|
||||
"""
|
||||
Match string 's' with multiple tenants' trees in one operation.
|
||||
|
||||
Args:
|
||||
s: The string to match
|
||||
|
||||
Returns:
|
||||
Tuple(str, int): The longest prefix of 's' that matches the tree and the first tenant_id that own the matched prefix
|
||||
"""
|
||||
curr = self.root
|
||||
curr_idx = 0
|
||||
|
||||
ret_text = ""
|
||||
ret_tenant = None
|
||||
|
||||
while curr_idx < len(s):
|
||||
matched_node = None
|
||||
if s[curr_idx] in curr.children:
|
||||
matched_node = curr.children[s[curr_idx]]
|
||||
|
||||
if matched_node is None:
|
||||
break
|
||||
|
||||
shared_len = shared_prefix_length(s[curr_idx:], matched_node.text)
|
||||
if shared_len == len(matched_node.text):
|
||||
curr_idx += shared_len
|
||||
curr = matched_node
|
||||
else:
|
||||
curr_idx += shared_len
|
||||
curr = matched_node
|
||||
break
|
||||
|
||||
selected_tenant = list(curr.tenant_last_access_time.keys())[0]
|
||||
|
||||
# traverse back to the root to update last access time for the selected tenant
|
||||
while curr != self.root:
|
||||
curr.tenant_last_access_time[selected_tenant] = time.time()
|
||||
curr = curr.parent
|
||||
|
||||
return s[:curr_idx], selected_tenant
|
||||
|
||||
def evict_tenant_data(self, max_size_per_tenant: Dict[str, int]) -> None:
|
||||
"""
|
||||
Evict data for tenants that have exceeded their storage limits.
|
||||
|
||||
Args:
|
||||
max_size_per_tenant: Dictionary mapping tenant_id to their maximum allowed storage size
|
||||
"""
|
||||
|
||||
def leaf_of(node):
|
||||
"""
|
||||
If the node is a leaf for a tenant, add tenant_id to the return list
|
||||
This will return list of tenant ids
|
||||
If not a leaf for all tenants, return []
|
||||
"""
|
||||
candidates = dict([(k, True) for k in node.tenant_last_access_time.keys()])
|
||||
|
||||
for n in node.children.values():
|
||||
for c in n.tenant_last_access_time.keys():
|
||||
candidates[c] = False
|
||||
|
||||
return [k for k, v in candidates.items() if v]
|
||||
|
||||
# maintain a heap with (time, tenant, node) as the value
|
||||
import heapq
|
||||
|
||||
# 1. traverse the tree to
|
||||
# a. add all the leaves into a heap (a node with N tenants will be added N times into the heap)
|
||||
# b. calculate the used size for each tenant
|
||||
# do a dfs with stack
|
||||
stack = [self.root]
|
||||
pq = []
|
||||
used_size_per_tenant = defaultdict(int)
|
||||
|
||||
while stack:
|
||||
curr = stack.pop()
|
||||
for t in curr.tenant_last_access_time.keys():
|
||||
used_size_per_tenant[t] += len(curr.text)
|
||||
|
||||
for c in curr.children.values():
|
||||
stack.append(c)
|
||||
|
||||
# if the node is a leaf for a tenant, add the tenant to the heap
|
||||
tenants = leaf_of(curr)
|
||||
for t in tenants:
|
||||
heapq.heappush(pq, (curr.tenant_last_access_time[t], t, curr))
|
||||
|
||||
# 2. pop the heap
|
||||
# a. if the tenant's used size is less than the limit, continue
|
||||
# b. if the tenant's used size is greater than the limit, remove the leaf and update the used size, and add its parent to the heap
|
||||
while len(pq) > 0:
|
||||
time, tenant, node = heapq.heappop(pq)
|
||||
if used_size_per_tenant[tenant] <= max_size_per_tenant[tenant]:
|
||||
continue
|
||||
|
||||
# remove the leaf
|
||||
used_size_per_tenant[tenant] -= len(node.text)
|
||||
del node.tenant_last_access_time[tenant]
|
||||
# if no children and no tenants, remove the node
|
||||
if len(node.children) == 0 and len(node.tenant_last_access_time) == 0:
|
||||
del node.parent.children[node.text[0]]
|
||||
|
||||
# add its parent to the heap
|
||||
if tenant in leaf_of(node.parent):
|
||||
heapq.heappush(
|
||||
pq,
|
||||
(node.parent.tenant_last_access_time[tenant], tenant, node.parent),
|
||||
)
|
||||
|
||||
def get_used_size_per_tenant(self) -> Dict[str, int]:
|
||||
"""
|
||||
Calculate the used storage size for each tenant.
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping tenant_id to their used storage size
|
||||
"""
|
||||
used_size_per_tenant = defaultdict(int)
|
||||
|
||||
stack = [self.root]
|
||||
while stack:
|
||||
curr = stack.pop()
|
||||
for t in curr.tenant_last_access_time.keys():
|
||||
used_size_per_tenant[t] += len(curr.text)
|
||||
|
||||
for c in curr.children.values():
|
||||
stack.append(c)
|
||||
|
||||
return used_size_per_tenant
|
||||
|
||||
def remove_tenant(self, tenant_id: str) -> None:
|
||||
"""
|
||||
Remove all data associated with a specific tenant from the tree.
|
||||
This operation maintains the integrity of the shared tree structure while
|
||||
removing only the specified tenant's access information.
|
||||
|
||||
Args:
|
||||
tenant_id: The identifier of the tenant whose data should be removed
|
||||
"""
|
||||
# TODO: Implementation needed
|
||||
pass
|
||||
|
||||
def pretty_print(self) -> str:
|
||||
"""
|
||||
Returns a string representation of the tree showing the structure, tenant ownership,
|
||||
and leaf status for each node.
|
||||
|
||||
Returns:
|
||||
str: A formatted string showing the tree hierarchy with tenant information
|
||||
"""
|
||||
|
||||
def _node_to_str(node: Node, prefix: str = "", is_last: bool = True) -> str:
|
||||
# Current node representation
|
||||
node_str = prefix
|
||||
node_str += "└── " if is_last else "├── "
|
||||
|
||||
# Add node text
|
||||
node_str += f"'{node.text}' ["
|
||||
|
||||
# Add tenant information including both timestamp and leaf status
|
||||
tenant_info = []
|
||||
for tid, ts in node.tenant_last_access_time.items():
|
||||
time_str = (
|
||||
time.strftime("%H:%M:%S.", time.localtime(ts))
|
||||
+ f"{(ts % 1):0.3f}"[2:]
|
||||
)
|
||||
tenant_info.append(f"{tid} | {time_str}")
|
||||
|
||||
node_str += ", ".join(tenant_info)
|
||||
node_str += "]\n"
|
||||
|
||||
# Handle children
|
||||
children = list(node.children.items())
|
||||
for i, (char, child) in enumerate(children):
|
||||
is_last_child = i == len(children) - 1
|
||||
# Adjust prefix for children based on whether this is the last child
|
||||
new_prefix = prefix + (" " if is_last else "│ ")
|
||||
node_str += _node_to_str(child, new_prefix, is_last_child)
|
||||
|
||||
return node_str
|
||||
|
||||
if not self.root.children:
|
||||
return "Empty tree"
|
||||
|
||||
# Start with root's children since root itself is just an empty node
|
||||
result = ""
|
||||
children = list(self.root.children.items())
|
||||
for i, (char, child) in enumerate(children):
|
||||
is_last = i == len(children) - 1
|
||||
result += _node_to_str(child, "", is_last)
|
||||
|
||||
return result
|
||||
96
third_party/sglang/scripts/release/README.md
vendored
Normal file
96
third_party/sglang/scripts/release/README.md
vendored
Normal file
@@ -0,0 +1,96 @@
|
||||
# Release Scripts
|
||||
|
||||
This directory contains scripts to automate version bumping for SGLang releases.
|
||||
|
||||
## Scripts
|
||||
|
||||
### `bump_sglang_version.py`
|
||||
Updates SGLang version across all relevant files following the pattern from [PR #10468](https://github.com/sgl-project/sglang/pull/10468).
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
python scripts/release/bump_sglang_version.py 0.5.3rc0
|
||||
```
|
||||
|
||||
**Files updated:**
|
||||
- `Makefile`
|
||||
- `benchmark/deepseek_v3/README.md`
|
||||
- `docker/rocm.Dockerfile`
|
||||
- `docs/get_started/install.md`
|
||||
- `docs/platforms/amd_gpu.md`
|
||||
- `docs/platforms/ascend_npu.md`
|
||||
- `python/pyproject.toml`
|
||||
- `python/pyproject_other.toml`
|
||||
- `python/pyproject_npu.toml`
|
||||
- `python/sglang/version.py`
|
||||
|
||||
### `bump_kernel_version.py`
|
||||
Updates the `sglang-kernel` release version across all relevant files following the pattern from [PR #10732](https://github.com/sgl-project/sglang/pull/10732).
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
python scripts/release/bump_kernel_version.py 0.4.0
|
||||
```
|
||||
|
||||
**Files updated:**
|
||||
- `sgl-kernel/pyproject.toml`
|
||||
- `sgl-kernel/pyproject_cpu.toml`
|
||||
- `sgl-kernel/pyproject_rocm.toml`
|
||||
- `sgl-kernel/pyproject_musa.toml`
|
||||
- `sgl-kernel/python/sgl_kernel/version.py`
|
||||
|
||||
## Manual Testing Instructions
|
||||
|
||||
### Test SGLang Version Bump
|
||||
|
||||
1. **Run the script:**
|
||||
```bash
|
||||
python scripts/release/bump_sglang_version.py 0.5.4rc0
|
||||
```
|
||||
|
||||
2. **Verify changes with git diff:**
|
||||
```bash
|
||||
git diff
|
||||
```
|
||||
|
||||
3. **Check specific files contain the new version:**
|
||||
```bash
|
||||
grep -r "0.5.4rc0" python/sglang/version.py
|
||||
grep -r "0.5.4rc0" python/pyproject.toml
|
||||
grep -r "0.5.4rc0" docs/get_started/install.md
|
||||
```
|
||||
|
||||
4. **Reset changes (if testing):**
|
||||
```bash
|
||||
git checkout .
|
||||
```
|
||||
|
||||
### Test Kernel Version Bump
|
||||
|
||||
1. **Run the script:**
|
||||
```bash
|
||||
python scripts/release/bump_kernel_version.py 0.4.0
|
||||
```
|
||||
|
||||
2. **Verify changes with git diff:**
|
||||
```bash
|
||||
git diff
|
||||
```
|
||||
|
||||
3. **Check specific files contain the new version:**
|
||||
```bash
|
||||
grep -r "0.4.0" sgl-kernel/python/sgl_kernel/version.py
|
||||
grep -r "0.4.0" sgl-kernel/pyproject.toml
|
||||
```
|
||||
|
||||
4. **Reset changes (if testing):**
|
||||
```bash
|
||||
git checkout .
|
||||
```
|
||||
|
||||
## Version Format Validation
|
||||
|
||||
- **SGLang versions:** `X.Y.Z` or `X.Y.ZrcN` (e.g., `0.5.3` or `0.5.3rc0`)
|
||||
- **Kernel versions:** `X.Y.Z` (e.g., `0.4.0`)
|
||||
|
||||
The scripts will validate the version format and exit with an error if invalid.
|
||||
148
third_party/sglang/scripts/release/bump_flashinfer_version.py
vendored
Executable file
148
third_party/sglang/scripts/release/bump_flashinfer_version.py
vendored
Executable file
@@ -0,0 +1,148 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from utils import compare_versions, get_repo_root, normalize_version, validate_version
|
||||
|
||||
FILES_TO_UPDATE = [
|
||||
Path("python/pyproject.toml"),
|
||||
Path("docker/Dockerfile"),
|
||||
Path("python/sglang/srt/entrypoints/engine.py"),
|
||||
Path("python/sglang/srt/utils/common.py"),
|
||||
]
|
||||
|
||||
|
||||
def read_current_flashinfer_version(repo_root: Path) -> str:
|
||||
"""Read the current flashinfer version from python/pyproject.toml."""
|
||||
pyproject = repo_root / "python" / "pyproject.toml"
|
||||
content = pyproject.read_text()
|
||||
match = re.search(
|
||||
r"flashinfer_python==(\d+\.\d+\.\d+(?:rc\d+|\.post\d+)?)", content
|
||||
)
|
||||
if not match:
|
||||
raise ValueError(f"Could not find flashinfer_python version in {pyproject}")
|
||||
return match.group(1)
|
||||
|
||||
|
||||
def replace_flashinfer_version(
|
||||
file_path: Path, old_version: str, new_version: str
|
||||
) -> bool:
|
||||
if not file_path.exists():
|
||||
print(f"Warning: {file_path} does not exist, skipping")
|
||||
return False
|
||||
|
||||
content = file_path.read_text()
|
||||
new_content = content
|
||||
|
||||
name = file_path.name
|
||||
if name == "pyproject.toml":
|
||||
new_content = new_content.replace(
|
||||
f"flashinfer_python=={old_version}", f"flashinfer_python=={new_version}"
|
||||
)
|
||||
new_content = new_content.replace(
|
||||
f"flashinfer_cubin=={old_version}", f"flashinfer_cubin=={new_version}"
|
||||
)
|
||||
elif name == "Dockerfile":
|
||||
new_content = re.sub(
|
||||
rf"(ARG FLASHINFER_VERSION=){re.escape(old_version)}",
|
||||
rf"\g<1>{new_version}",
|
||||
new_content,
|
||||
)
|
||||
elif name == "engine.py":
|
||||
new_content = re.sub(
|
||||
r'(assert_pkg_version\(\s*"flashinfer_python",\s*)"'
|
||||
+ re.escape(old_version)
|
||||
+ r'"',
|
||||
r'\g<1>"' + new_version + '"',
|
||||
new_content,
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
elif name == "common.py":
|
||||
new_content = new_content.replace(
|
||||
f'e.g., "{old_version}"',
|
||||
f'e.g., "{new_version}"',
|
||||
)
|
||||
|
||||
if content == new_content:
|
||||
print(f"No changes needed in {file_path}")
|
||||
return False
|
||||
|
||||
file_path.write_text(new_content)
|
||||
print(f"✓ Updated {file_path}")
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Bump flashinfer version across all relevant files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"new_version",
|
||||
help="New version (e.g., 0.6.4, 0.6.4rc0, or 0.6.4.post1)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
new_version = normalize_version(args.new_version)
|
||||
|
||||
if not validate_version(new_version):
|
||||
print(f"Error: Invalid version format: {new_version}")
|
||||
print("Expected format: X.Y.Z, X.Y.ZrcN, or X.Y.Z.postN")
|
||||
print("Examples: 0.6.4, 0.6.4rc0, 0.6.4.post1")
|
||||
sys.exit(1)
|
||||
|
||||
repo_root = get_repo_root()
|
||||
old_version = read_current_flashinfer_version(repo_root)
|
||||
print(f"Current flashinfer version: {old_version}")
|
||||
print(f"New flashinfer version: {new_version}")
|
||||
print()
|
||||
|
||||
comparison = compare_versions(new_version, old_version)
|
||||
if comparison == 0:
|
||||
print("Error: New version is the same as current version")
|
||||
sys.exit(1)
|
||||
elif comparison < 0:
|
||||
print(
|
||||
f"Error: New version ({new_version}) is older than current version ({old_version})"
|
||||
)
|
||||
print("Version must be greater than the current version")
|
||||
sys.exit(1)
|
||||
|
||||
updated_count = 0
|
||||
for file_rel in FILES_TO_UPDATE:
|
||||
file_abs = repo_root / file_rel
|
||||
if replace_flashinfer_version(file_abs, old_version, new_version):
|
||||
updated_count += 1
|
||||
|
||||
print()
|
||||
print(f"Successfully updated {updated_count} file(s)")
|
||||
print(f"Flashinfer version bumped from {old_version} to {new_version}")
|
||||
|
||||
print("\nValidating version updates...")
|
||||
failed_files = []
|
||||
for file_rel in FILES_TO_UPDATE:
|
||||
file_abs = repo_root / file_rel
|
||||
if not file_abs.exists():
|
||||
print(f"Warning: File {file_rel} does not exist, skipping validation.")
|
||||
continue
|
||||
|
||||
content = file_abs.read_text()
|
||||
if new_version not in content:
|
||||
failed_files.append(file_rel)
|
||||
print(f"✗ {file_rel} does not contain version {new_version}")
|
||||
else:
|
||||
print(f"✓ {file_rel} validated")
|
||||
|
||||
if failed_files:
|
||||
print(f"\nError: {len(failed_files)} file(s) were not updated correctly:")
|
||||
for file_rel in failed_files:
|
||||
print(f" - {file_rel}")
|
||||
sys.exit(1)
|
||||
|
||||
print("\nAll files validated successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
33
third_party/sglang/scripts/release/bump_kernel_version.py
vendored
Executable file
33
third_party/sglang/scripts/release/bump_kernel_version.py
vendored
Executable file
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from utils import bump_version
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Bump sgl-kernel version across all relevant files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"new_version",
|
||||
help="New version (e.g., 0.3.12, 0.3.11rc0, or 0.3.11.post1)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
version_file = Path("sgl-kernel/python/sgl_kernel/version.py")
|
||||
|
||||
files_to_update = [
|
||||
Path("sgl-kernel/pyproject.toml"),
|
||||
Path("sgl-kernel/pyproject_cpu.toml"),
|
||||
Path("sgl-kernel/pyproject_rocm.toml"),
|
||||
Path("sgl-kernel/pyproject_musa.toml"),
|
||||
Path("sgl-kernel/python/sgl_kernel/version.py"),
|
||||
]
|
||||
|
||||
bump_version(args.new_version, version_file, files_to_update)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
143
third_party/sglang/scripts/release/bump_kernel_version_to_sglang.py
vendored
Executable file
143
third_party/sglang/scripts/release/bump_kernel_version_to_sglang.py
vendored
Executable file
@@ -0,0 +1,143 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Bump sglang-kernel version in SGLang files to match the version in sgl-kernel/pyproject.toml.
|
||||
Updates:
|
||||
- python/pyproject.toml
|
||||
- python/sglang/srt/entrypoints/engine.py
|
||||
- docker/Dockerfile
|
||||
"""
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import tomllib # Python 3.11+
|
||||
except ImportError:
|
||||
import tomli as tomllib # Fallback for older Python versions
|
||||
|
||||
|
||||
def get_kernel_version_from_source() -> str:
|
||||
"""Extract version from sgl-kernel/pyproject.toml"""
|
||||
pyproject_path = Path("sgl-kernel/pyproject.toml")
|
||||
|
||||
if not pyproject_path.exists():
|
||||
print(f"Error: {pyproject_path} not found")
|
||||
sys.exit(1)
|
||||
|
||||
with open(pyproject_path, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
|
||||
version = data.get("project", {}).get("version")
|
||||
if not version:
|
||||
print("Error: Could not find version in sgl-kernel/pyproject.toml")
|
||||
sys.exit(1)
|
||||
|
||||
return version
|
||||
|
||||
|
||||
def update_python_pyproject(new_version: str) -> bool:
|
||||
"""Update sglang-kernel version in python/pyproject.toml"""
|
||||
pyproject_path = Path("python/pyproject.toml")
|
||||
|
||||
if not pyproject_path.exists():
|
||||
print(f"Error: {pyproject_path} not found")
|
||||
sys.exit(1)
|
||||
|
||||
content = pyproject_path.read_text()
|
||||
|
||||
# Replace "sglang-kernel==x.x.x" with new version
|
||||
new_content = re.sub(
|
||||
r'"sglang-kernel==[^"]+"',
|
||||
f'"sglang-kernel=={new_version}"',
|
||||
content,
|
||||
)
|
||||
|
||||
if content == new_content:
|
||||
print("No changes needed in python/pyproject.toml")
|
||||
return False
|
||||
|
||||
pyproject_path.write_text(new_content)
|
||||
print(f"✓ Updated python/pyproject.toml to version {new_version}")
|
||||
return True
|
||||
|
||||
|
||||
def update_engine_py(new_version: str) -> bool:
|
||||
"""Update sglang-kernel version in python/sglang/srt/entrypoints/engine.py"""
|
||||
engine_path = Path("python/sglang/srt/entrypoints/engine.py")
|
||||
|
||||
if not engine_path.exists():
|
||||
print(f"Error: {engine_path} not found")
|
||||
sys.exit(1)
|
||||
|
||||
content = engine_path.read_text()
|
||||
|
||||
# Replace version in assert_pkg_version("sglang-kernel", "version", ...)
|
||||
new_content = re.sub(
|
||||
r'(assert_pkg_version\s*\(\s*"sglang-kernel"\s*,\s*)"[^"]+"',
|
||||
rf'\1"{new_version}"',
|
||||
content,
|
||||
)
|
||||
|
||||
if content == new_content:
|
||||
print("No changes needed in engine.py")
|
||||
return False
|
||||
|
||||
engine_path.write_text(new_content)
|
||||
print(f"✓ Updated engine.py to version {new_version}")
|
||||
return True
|
||||
|
||||
|
||||
def update_dockerfile(new_version: str) -> bool:
|
||||
"""Update SGL_KERNEL_VERSION in docker/Dockerfile"""
|
||||
dockerfile_path = Path("docker/Dockerfile")
|
||||
|
||||
if not dockerfile_path.exists():
|
||||
print(f"Error: {dockerfile_path} not found")
|
||||
sys.exit(1)
|
||||
|
||||
content = dockerfile_path.read_text()
|
||||
|
||||
# Replace ARG SGL_KERNEL_VERSION=x.x.x with new version
|
||||
new_content = re.sub(
|
||||
r"^(ARG\s+SGL_KERNEL_VERSION=)(.+)$",
|
||||
rf"\g<1>{new_version}",
|
||||
content,
|
||||
flags=re.MULTILINE,
|
||||
)
|
||||
|
||||
if content == new_content:
|
||||
print("No changes needed in Dockerfile")
|
||||
return False
|
||||
|
||||
dockerfile_path.write_text(new_content)
|
||||
print(f"✓ Updated Dockerfile to version {new_version}")
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
kernel_version = get_kernel_version_from_source()
|
||||
print(f"Bumping sglang-kernel version to: {kernel_version}\n")
|
||||
|
||||
updated_files = []
|
||||
|
||||
if update_python_pyproject(kernel_version):
|
||||
updated_files.append("python/pyproject.toml")
|
||||
|
||||
if update_engine_py(kernel_version):
|
||||
updated_files.append("python/sglang/srt/entrypoints/engine.py")
|
||||
|
||||
if update_dockerfile(kernel_version):
|
||||
updated_files.append("docker/Dockerfile")
|
||||
|
||||
print()
|
||||
if updated_files:
|
||||
print(f"✓ Successfully updated {len(updated_files)} file(s):")
|
||||
for file in updated_files:
|
||||
print(f" - {file}")
|
||||
else:
|
||||
print("✓ All files already have the correct version")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
150
third_party/sglang/scripts/release/check_kernel_version_to_sglang.py
vendored
Executable file
150
third_party/sglang/scripts/release/check_kernel_version_to_sglang.py
vendored
Executable file
@@ -0,0 +1,150 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Check if sglang-kernel version from sgl-kernel/pyproject.toml matches the versions
|
||||
used in SGLang files (python/pyproject.toml, engine.py, and Dockerfile).
|
||||
Sets GitHub Actions output variables to indicate if sync is needed.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import tomllib # Python 3.11+
|
||||
except ImportError:
|
||||
import tomli as tomllib # Fallback for older Python versions
|
||||
|
||||
|
||||
def get_kernel_version_from_source() -> str:
|
||||
"""Extract version from sgl-kernel/pyproject.toml (line 11)"""
|
||||
pyproject_path = Path("sgl-kernel/pyproject.toml")
|
||||
|
||||
if not pyproject_path.exists():
|
||||
print(f"Error: {pyproject_path} not found")
|
||||
sys.exit(1)
|
||||
|
||||
with open(pyproject_path, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
|
||||
version = data.get("project", {}).get("version")
|
||||
if not version:
|
||||
print("Error: Could not find version in sgl-kernel/pyproject.toml")
|
||||
sys.exit(1)
|
||||
|
||||
return version
|
||||
|
||||
|
||||
def get_kernel_version_from_python_pyproject() -> str:
|
||||
"""Extract sglang-kernel version from python/pyproject.toml"""
|
||||
pyproject_path = Path("python/pyproject.toml")
|
||||
|
||||
if not pyproject_path.exists():
|
||||
print(f"Error: {pyproject_path} not found")
|
||||
sys.exit(1)
|
||||
|
||||
content = pyproject_path.read_text()
|
||||
|
||||
# Match "sglang-kernel==x.x.x"
|
||||
match = re.search(r'"sglang-kernel==([^"]+)"', content)
|
||||
if not match:
|
||||
print("Error: Could not find sglang-kernel version in python/pyproject.toml")
|
||||
sys.exit(1)
|
||||
|
||||
return match.group(1)
|
||||
|
||||
|
||||
def get_kernel_version_from_engine() -> str:
|
||||
"""Extract sglang-kernel version from python/sglang/srt/entrypoints/engine.py"""
|
||||
engine_path = Path("python/sglang/srt/entrypoints/engine.py")
|
||||
|
||||
if not engine_path.exists():
|
||||
print(f"Error: {engine_path} not found")
|
||||
sys.exit(1)
|
||||
|
||||
content = engine_path.read_text()
|
||||
|
||||
# Find the assert_pkg_version call for sglang-kernel
|
||||
# Look for the pattern: assert_pkg_version("sglang-kernel", "version", ...)
|
||||
match = re.search(
|
||||
r'assert_pkg_version\s*\(\s*"sglang-kernel"\s*,\s*"([^"]+)"', content
|
||||
)
|
||||
if not match:
|
||||
print("Error: Could not find sglang-kernel version in engine.py")
|
||||
sys.exit(1)
|
||||
|
||||
return match.group(1)
|
||||
|
||||
|
||||
def get_kernel_version_from_dockerfile() -> str:
|
||||
"""Extract SGL_KERNEL_VERSION from docker/Dockerfile"""
|
||||
dockerfile_path = Path("docker/Dockerfile")
|
||||
|
||||
if not dockerfile_path.exists():
|
||||
print(f"Error: {dockerfile_path} not found")
|
||||
sys.exit(1)
|
||||
|
||||
content = dockerfile_path.read_text()
|
||||
|
||||
# Match ARG SGL_KERNEL_VERSION=x.x.x
|
||||
match = re.search(r"^ARG\s+SGL_KERNEL_VERSION=(.+)$", content, re.MULTILINE)
|
||||
if not match:
|
||||
print("Error: Could not find SGL_KERNEL_VERSION in Dockerfile")
|
||||
sys.exit(1)
|
||||
|
||||
return match.group(1).strip()
|
||||
|
||||
|
||||
def main():
|
||||
kernel_version = get_kernel_version_from_source()
|
||||
pyproject_version = get_kernel_version_from_python_pyproject()
|
||||
engine_version = get_kernel_version_from_engine()
|
||||
dockerfile_version = get_kernel_version_from_dockerfile()
|
||||
|
||||
print(f"Kernel version in sgl-kernel/pyproject.toml: {kernel_version}")
|
||||
print(
|
||||
f"SGLang kernel dependency version in python/pyproject.toml: {pyproject_version}"
|
||||
)
|
||||
print(f"SGLang kernel dependency version in engine.py: {engine_version}")
|
||||
print(f"Kernel version in Dockerfile: {dockerfile_version}")
|
||||
|
||||
# Check if any version differs from the source
|
||||
needs_sync = (
|
||||
kernel_version != pyproject_version
|
||||
or kernel_version != engine_version
|
||||
or kernel_version != dockerfile_version
|
||||
)
|
||||
|
||||
# Set GitHub Actions output
|
||||
github_output = os.getenv("GITHUB_OUTPUT")
|
||||
if github_output:
|
||||
with open(github_output, "a") as f:
|
||||
f.write(f"needs_sync={'true' if needs_sync else 'false'}\n")
|
||||
f.write(f"kernel_version={kernel_version}\n")
|
||||
|
||||
if needs_sync:
|
||||
print(f"\n✓ Sync needed to version: {kernel_version}")
|
||||
mismatches = []
|
||||
if kernel_version != pyproject_version:
|
||||
mismatches.append(
|
||||
f" - python/pyproject.toml: {pyproject_version} → {kernel_version}"
|
||||
)
|
||||
if kernel_version != engine_version:
|
||||
mismatches.append(f" - engine.py: {engine_version} → {kernel_version}")
|
||||
if kernel_version != dockerfile_version:
|
||||
mismatches.append(
|
||||
f" - Dockerfile: {dockerfile_version} → {kernel_version}"
|
||||
)
|
||||
|
||||
print("Changes needed:")
|
||||
for mismatch in mismatches:
|
||||
print(mismatch)
|
||||
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("\n✓ All versions are in sync, no action needed")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
72
third_party/sglang/scripts/release/commit_and_pr.sh
vendored
Executable file
72
third_party/sglang/scripts/release/commit_and_pr.sh
vendored
Executable file
@@ -0,0 +1,72 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Script to commit version bump changes and create a pull request
|
||||
# Usage: commit_and_pr.sh <version_type> <new_version> <branch_name>
|
||||
#
|
||||
# Arguments:
|
||||
# version_type: "SGLang" or "sgl-kernel"
|
||||
# new_version: The new version number
|
||||
# branch_name: The git branch name to push to
|
||||
|
||||
VERSION_TYPE="$1"
|
||||
NEW_VERSION="$2"
|
||||
BRANCH_NAME="$3"
|
||||
|
||||
if [ -z "$VERSION_TYPE" ] || [ -z "$NEW_VERSION" ] || [ -z "$BRANCH_NAME" ]; then
|
||||
echo "Error: Missing required arguments"
|
||||
echo "Usage: $0 <version_type> <new_version> <branch_name>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get changed files and format them
|
||||
echo "Getting changed files..."
|
||||
FILES_LIST=$(git diff --name-only | sed 's/^/- /')
|
||||
COMMIT_FILES=$(git diff --name-only | sed 's/^/ - /')
|
||||
|
||||
# Commit changes
|
||||
echo "Committing changes..."
|
||||
git add -A
|
||||
git commit -m "chore: bump ${VERSION_TYPE} version to ${NEW_VERSION}
|
||||
|
||||
This commit updates the ${VERSION_TYPE} version across all relevant files:
|
||||
${COMMIT_FILES}
|
||||
|
||||
🤖 Generated with GitHub Actions"
|
||||
|
||||
# Push changes
|
||||
echo "Pushing to ${BRANCH_NAME}..."
|
||||
git push origin "${BRANCH_NAME}"
|
||||
|
||||
# Create pull request
|
||||
echo "Creating pull request..."
|
||||
PR_URL=$(gh pr create \
|
||||
--title "chore: bump ${VERSION_TYPE} version to ${NEW_VERSION}" \
|
||||
--body "## Summary
|
||||
|
||||
This PR bumps the ${VERSION_TYPE} version to \`${NEW_VERSION}\` across all relevant files.
|
||||
|
||||
## Files Updated
|
||||
${FILES_LIST}
|
||||
|
||||
🤖 Generated with GitHub Actions" \
|
||||
--base main \
|
||||
--head "${BRANCH_NAME}")
|
||||
|
||||
echo "✓ Pull request created successfully"
|
||||
|
||||
# Add GitHub Actions job summary
|
||||
if [ -n "$GITHUB_STEP_SUMMARY" ]; then
|
||||
cat >> "$GITHUB_STEP_SUMMARY" <<EOF
|
||||
## ✅ Version Bump Complete
|
||||
|
||||
**Version Type:** ${VERSION_TYPE}
|
||||
**New Version:** \`${NEW_VERSION}\`
|
||||
|
||||
### 📝 Pull Request Created
|
||||
${PR_URL}
|
||||
|
||||
### 📦 Files Updated
|
||||
${FILES_LIST}
|
||||
EOF
|
||||
fi
|
||||
81
third_party/sglang/scripts/release/commit_and_pr_kernel_to_sglang.sh
vendored
Executable file
81
third_party/sglang/scripts/release/commit_and_pr_kernel_to_sglang.sh
vendored
Executable file
@@ -0,0 +1,81 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Script to commit kernel version bump changes to SGLang and create a pull request
|
||||
# Usage: commit_and_pr_kernel_to_sglang.sh <kernel_version> <branch_name>
|
||||
#
|
||||
# Arguments:
|
||||
# kernel_version: The kernel version being synced
|
||||
# branch_name: The git branch name to push to
|
||||
|
||||
KERNEL_VERSION="$1"
|
||||
BRANCH_NAME="$2"
|
||||
|
||||
if [ -z "$KERNEL_VERSION" ] || [ -z "$BRANCH_NAME" ]; then
|
||||
echo "Error: Missing required arguments"
|
||||
echo "Usage: $0 <kernel_version> <branch_name>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get changed files and format them
|
||||
echo "Getting changed files..."
|
||||
FILES_LIST=$(git diff --name-only | sed 's/^/- /')
|
||||
COMMIT_FILES=$(git diff --name-only | sed 's/^/ - /')
|
||||
|
||||
# Commit changes
|
||||
echo "Committing changes..."
|
||||
git add -A
|
||||
git commit -m "chore: bump sglang-kernel version to ${KERNEL_VERSION} in SGLang
|
||||
|
||||
This commit updates the sglang-kernel version across SGLang files to match
|
||||
the version defined in sgl-kernel/pyproject.toml.
|
||||
|
||||
Files updated:
|
||||
${COMMIT_FILES}
|
||||
|
||||
🤖 Generated with GitHub Actions"
|
||||
|
||||
# Push changes
|
||||
echo "Pushing to ${BRANCH_NAME}..."
|
||||
git push origin "${BRANCH_NAME}"
|
||||
|
||||
# Create pull request
|
||||
echo "Creating pull request..."
|
||||
PR_URL=$(gh pr create \
|
||||
--title "chore: bump sglang-kernel version to ${KERNEL_VERSION}" \
|
||||
--body "## Summary
|
||||
|
||||
This PR bumps the \`sglang-kernel\` version to \`${KERNEL_VERSION}\` across SGLang files to match the version defined in \`sgl-kernel/pyproject.toml\`.
|
||||
|
||||
**Kernel Version:** \`${KERNEL_VERSION}\`
|
||||
|
||||
## Files Updated
|
||||
${FILES_LIST}
|
||||
|
||||
## Context
|
||||
|
||||
The kernel version in \`sgl-kernel/pyproject.toml\` has been updated. This PR ensures that all SGLang files referencing the \`sglang-kernel\` dependency are updated accordingly:
|
||||
- \`python/pyproject.toml\` - dependency specification
|
||||
- \`python/sglang/srt/entrypoints/engine.py\` - version check
|
||||
- \`docker/Dockerfile\` - Docker build argument
|
||||
|
||||
🤖 Generated with GitHub Actions" \
|
||||
--base main \
|
||||
--head "${BRANCH_NAME}")
|
||||
|
||||
echo "✓ Pull request created successfully"
|
||||
|
||||
# Add GitHub Actions job summary
|
||||
if [ -n "$GITHUB_STEP_SUMMARY" ]; then
|
||||
cat >> "$GITHUB_STEP_SUMMARY" <<EOF
|
||||
## ✅ Kernel Version Bump Complete
|
||||
|
||||
**Kernel Version:** \`${KERNEL_VERSION}\`
|
||||
|
||||
### 📝 Pull Request Created
|
||||
${PR_URL}
|
||||
|
||||
### 📦 Files Updated
|
||||
${FILES_LIST}
|
||||
EOF
|
||||
fi
|
||||
159
third_party/sglang/scripts/release/test_utils.py
vendored
Executable file
159
third_party/sglang/scripts/release/test_utils.py
vendored
Executable file
@@ -0,0 +1,159 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from utils import compare_versions, normalize_version, parse_version, validate_version
|
||||
|
||||
|
||||
class TestVersionUtils(unittest.TestCase):
|
||||
def test_normalize_version(self):
|
||||
"""Test version normalization removes 'v' prefix."""
|
||||
self.assertEqual(normalize_version("v0.5.3"), "0.5.3")
|
||||
self.assertEqual(normalize_version("0.5.3"), "0.5.3")
|
||||
self.assertEqual(normalize_version("v0.5.3rc0"), "0.5.3rc0")
|
||||
self.assertEqual(normalize_version("0.5.3.post1"), "0.5.3.post1")
|
||||
|
||||
def test_validate_version(self):
|
||||
"""Test version format validation."""
|
||||
# Valid formats
|
||||
self.assertTrue(validate_version("0.5.3"))
|
||||
self.assertTrue(validate_version("0.5.3rc0"))
|
||||
self.assertTrue(validate_version("0.5.3rc1"))
|
||||
self.assertTrue(validate_version("0.5.3rc999"))
|
||||
self.assertTrue(validate_version("0.5.3.post1"))
|
||||
self.assertTrue(validate_version("0.5.3.post10"))
|
||||
self.assertTrue(validate_version("1.2.3"))
|
||||
self.assertTrue(validate_version("10.20.30"))
|
||||
|
||||
# Invalid formats
|
||||
self.assertFalse(validate_version("0.5"))
|
||||
self.assertFalse(validate_version("0.5.3."))
|
||||
self.assertFalse(validate_version("0.5.3rc"))
|
||||
self.assertFalse(validate_version("0.5.3post1"))
|
||||
self.assertFalse(validate_version("0.5.3-rc0"))
|
||||
self.assertFalse(validate_version("v0.5.3"))
|
||||
self.assertFalse(validate_version("0.5.3beta1"))
|
||||
self.assertFalse(validate_version("0.5.3.rc0"))
|
||||
|
||||
def test_parse_version_stable(self):
|
||||
"""Test parsing stable version."""
|
||||
self.assertEqual(parse_version("0.5.3"), (0, 5, 3, 0, 0))
|
||||
self.assertEqual(parse_version("1.2.3"), (1, 2, 3, 0, 0))
|
||||
self.assertEqual(parse_version("10.20.30"), (10, 20, 30, 0, 0))
|
||||
|
||||
def test_parse_version_rc(self):
|
||||
"""Test parsing release candidate versions."""
|
||||
self.assertEqual(parse_version("0.5.3rc0"), (0, 5, 3, -1000, 0))
|
||||
self.assertEqual(parse_version("0.5.3rc1"), (0, 5, 3, -999, 0))
|
||||
self.assertEqual(parse_version("0.5.3rc2"), (0, 5, 3, -998, 0))
|
||||
self.assertEqual(parse_version("0.5.3rc10"), (0, 5, 3, -990, 0))
|
||||
|
||||
def test_parse_version_post(self):
|
||||
"""Test parsing post-release versions."""
|
||||
self.assertEqual(parse_version("0.5.3.post1"), (0, 5, 3, 0, 1))
|
||||
self.assertEqual(parse_version("0.5.3.post2"), (0, 5, 3, 0, 2))
|
||||
self.assertEqual(parse_version("0.5.3.post10"), (0, 5, 3, 0, 10))
|
||||
|
||||
def test_parse_version_invalid(self):
|
||||
"""Test parsing invalid versions raises error."""
|
||||
with self.assertRaises(ValueError):
|
||||
parse_version("0.5")
|
||||
with self.assertRaises(ValueError):
|
||||
parse_version("invalid")
|
||||
with self.assertRaises(ValueError):
|
||||
parse_version("v0.5.3")
|
||||
|
||||
def test_compare_versions_equal(self):
|
||||
"""Test comparing equal versions."""
|
||||
self.assertEqual(compare_versions("0.5.3", "0.5.3"), 0)
|
||||
self.assertEqual(compare_versions("0.5.3rc0", "0.5.3rc0"), 0)
|
||||
self.assertEqual(compare_versions("0.5.3.post1", "0.5.3.post1"), 0)
|
||||
|
||||
def test_compare_versions_rc_ordering(self):
|
||||
"""Test release candidate ordering: rc0 < rc1 < rc2 < stable."""
|
||||
# rc0 < rc1
|
||||
self.assertEqual(compare_versions("0.5.3rc0", "0.5.3rc1"), -1)
|
||||
self.assertEqual(compare_versions("0.5.3rc1", "0.5.3rc0"), 1)
|
||||
|
||||
# rc1 < rc2
|
||||
self.assertEqual(compare_versions("0.5.3rc1", "0.5.3rc2"), -1)
|
||||
self.assertEqual(compare_versions("0.5.3rc2", "0.5.3rc1"), 1)
|
||||
|
||||
# rc < stable
|
||||
self.assertEqual(compare_versions("0.5.3rc0", "0.5.3"), -1)
|
||||
self.assertEqual(compare_versions("0.5.3rc1", "0.5.3"), -1)
|
||||
self.assertEqual(compare_versions("0.5.3", "0.5.3rc0"), 1)
|
||||
|
||||
def test_compare_versions_post_ordering(self):
|
||||
"""Test post-release ordering: stable < post1 < post2."""
|
||||
# stable < post1
|
||||
self.assertEqual(compare_versions("0.5.3", "0.5.3.post1"), -1)
|
||||
self.assertEqual(compare_versions("0.5.3.post1", "0.5.3"), 1)
|
||||
|
||||
# post1 < post2
|
||||
self.assertEqual(compare_versions("0.5.3.post1", "0.5.3.post2"), -1)
|
||||
self.assertEqual(compare_versions("0.5.3.post2", "0.5.3.post1"), 1)
|
||||
|
||||
def test_compare_versions_full_ordering(self):
|
||||
"""Test complete version ordering: rc < stable < post."""
|
||||
# rc < stable < post
|
||||
self.assertEqual(compare_versions("0.5.3rc0", "0.5.3"), -1)
|
||||
self.assertEqual(compare_versions("0.5.3", "0.5.3.post1"), -1)
|
||||
self.assertEqual(compare_versions("0.5.3rc0", "0.5.3.post1"), -1)
|
||||
|
||||
# Verify transitivity: rc0 < rc1 < stable < post1 < post2
|
||||
versions = [
|
||||
"0.5.3rc0",
|
||||
"0.5.3rc1",
|
||||
"0.5.3",
|
||||
"0.5.3.post1",
|
||||
"0.5.3.post2",
|
||||
]
|
||||
for i in range(len(versions) - 1):
|
||||
self.assertEqual(
|
||||
compare_versions(versions[i], versions[i + 1]),
|
||||
-1,
|
||||
f"{versions[i]} should be less than {versions[i + 1]}",
|
||||
)
|
||||
|
||||
def test_compare_versions_different_patch(self):
|
||||
"""Test comparing versions with different patch numbers."""
|
||||
# 0.5.3 < 0.5.4
|
||||
self.assertEqual(compare_versions("0.5.3", "0.5.4"), -1)
|
||||
self.assertEqual(compare_versions("0.5.4", "0.5.3"), 1)
|
||||
|
||||
# rc of higher patch > stable of lower patch
|
||||
self.assertEqual(compare_versions("0.5.4rc0", "0.5.3"), 1)
|
||||
self.assertEqual(compare_versions("0.5.3.post1", "0.5.4rc0"), -1)
|
||||
|
||||
def test_compare_versions_different_minor(self):
|
||||
"""Test comparing versions with different minor numbers."""
|
||||
self.assertEqual(compare_versions("0.4.9", "0.5.0"), -1)
|
||||
self.assertEqual(compare_versions("0.5.0", "0.4.9"), 1)
|
||||
|
||||
def test_compare_versions_different_major(self):
|
||||
"""Test comparing versions with different major numbers."""
|
||||
self.assertEqual(compare_versions("0.9.9", "1.0.0"), -1)
|
||||
self.assertEqual(compare_versions("1.0.0", "0.9.9"), 1)
|
||||
|
||||
def test_real_world_scenarios(self):
|
||||
"""Test real-world version bump scenarios."""
|
||||
# Scenario 1: RC progression
|
||||
self.assertEqual(compare_versions("0.5.3rc0", "0.5.3rc1"), -1)
|
||||
|
||||
# Scenario 2: RC to stable release
|
||||
self.assertEqual(compare_versions("0.5.3rc2", "0.5.3"), -1)
|
||||
|
||||
# Scenario 3: Stable to post-release hotfix
|
||||
self.assertEqual(compare_versions("0.5.3", "0.5.3.post1"), -1)
|
||||
|
||||
# Scenario 4: Post-release to next RC
|
||||
self.assertEqual(compare_versions("0.5.3.post1", "0.5.4rc0"), -1)
|
||||
|
||||
# Scenario 5: Next stable version
|
||||
self.assertEqual(compare_versions("0.5.3", "0.5.4"), -1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
220
third_party/sglang/scripts/release/utils.py
vendored
Normal file
220
third_party/sglang/scripts/release/utils.py
vendored
Normal file
@@ -0,0 +1,220 @@
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
try:
|
||||
import tomllib # Python 3.11+
|
||||
except ImportError:
|
||||
import tomli as tomllib # Fallback for older Python versions
|
||||
|
||||
|
||||
def normalize_version(version: str) -> str:
|
||||
"""Remove 'v' prefix from version string if present."""
|
||||
return version.lstrip("v")
|
||||
|
||||
|
||||
def validate_version(version: str) -> bool:
|
||||
"""Validate version format: X.Y.Z, X.Y.Zrc0, or X.Y.Z.post1"""
|
||||
pattern = r"^\d+\.\d+\.\d+(rc\d+|\.post\d+)?$"
|
||||
return bool(re.match(pattern, version))
|
||||
|
||||
|
||||
def parse_version(version: str) -> Tuple[int, int, int, int, int]:
|
||||
"""
|
||||
Parse version string into comparable components.
|
||||
|
||||
Returns: (major, minor, patch, pre_release, post_release)
|
||||
- pre_release: -1000 + rc_number for rcN, 0 for stable (rc0 < rc1 < stable)
|
||||
- post_release: N for .postN, 0 otherwise
|
||||
|
||||
The pre_release field uses negative numbers to ensure RC versions come before
|
||||
stable versions when tuples are compared. Python compares tuples element by
|
||||
element, so (0, 5, 3, -1000, 0) < (0, 5, 3, 0, 0) ensures rc0 < stable.
|
||||
|
||||
Examples:
|
||||
- "0.5.3rc0" → (0, 5, 3, -1000, 0) # rc0 comes before stable
|
||||
- "0.5.3rc1" → (0, 5, 3, -999, 0) # rc1 comes after rc0
|
||||
- "0.5.3" → (0, 5, 3, 0, 0) # stable version
|
||||
- "0.5.3.post1" → (0, 5, 3, 0, 1) # post comes after stable
|
||||
"""
|
||||
# Match version components
|
||||
match = re.match(r"^(\d+)\.(\d+)\.(\d+)(?:rc(\d+)|\.post(\d+))?$", version)
|
||||
if not match:
|
||||
raise ValueError(f"Invalid version format: {version}")
|
||||
|
||||
major, minor, patch, rc, post = match.groups()
|
||||
major, minor, patch = int(major), int(minor), int(patch)
|
||||
|
||||
if rc is not None:
|
||||
# RC version: pre_release = -1000 + rc_number (ensures rc0 < rc1 < ... < stable)
|
||||
return (major, minor, patch, -1000 + int(rc), 0)
|
||||
elif post is not None:
|
||||
# Post version: post_release = N
|
||||
return (major, minor, patch, 0, int(post))
|
||||
else:
|
||||
# Stable version
|
||||
return (major, minor, patch, 0, 0)
|
||||
|
||||
|
||||
def compare_versions(v1: str, v2: str) -> int:
|
||||
"""
|
||||
Compare two version strings following PEP 440 ordering.
|
||||
|
||||
Returns:
|
||||
- -1 if v1 < v2
|
||||
- 0 if v1 == v2
|
||||
- 1 if v1 > v2
|
||||
|
||||
Version ordering: X.Y.ZrcN < X.Y.Z < X.Y.Z.postN < X.Y.(Z+1)
|
||||
"""
|
||||
parsed_v1 = parse_version(v1)
|
||||
parsed_v2 = parse_version(v2)
|
||||
|
||||
if parsed_v1 < parsed_v2:
|
||||
return -1
|
||||
elif parsed_v1 > parsed_v2:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def get_repo_root() -> Path:
|
||||
return Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
def read_current_version(version_file: Path) -> str:
|
||||
content = version_file.read_text()
|
||||
match = re.search(r'__version__\s*=\s*["\']([^"\']+)["\']', content)
|
||||
if not match:
|
||||
raise ValueError(f"Could not find version in {version_file}")
|
||||
return match.group(1)
|
||||
|
||||
|
||||
def replace_in_file(file_path: Path, old_version: str, new_version: str) -> bool:
|
||||
if not file_path.exists():
|
||||
print(f"Warning: {file_path} does not exist, skipping")
|
||||
return False
|
||||
|
||||
content = file_path.read_text()
|
||||
|
||||
# For TOML files, parse and update only the [project] version field
|
||||
if file_path.suffix == ".toml":
|
||||
try:
|
||||
# Parse TOML to verify structure
|
||||
toml_data = tomllib.loads(content)
|
||||
|
||||
# Check if [project] section exists and has version field
|
||||
if "project" not in toml_data or "version" not in toml_data["project"]:
|
||||
print(
|
||||
f"Warning: {file_path} does not have [project] version field, skipping"
|
||||
)
|
||||
return False
|
||||
|
||||
# Use regex to replace only the version field in [project] section
|
||||
# This pattern matches the version field that comes after [project]
|
||||
# and before any other section marker
|
||||
pattern = r'(\[project\].*?version\s*=\s*)["\']([^"\']+)["\']'
|
||||
new_content = re.sub(
|
||||
pattern, rf'\g<1>"{new_version}"', content, flags=re.DOTALL
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to parse {file_path} as TOML: {e}")
|
||||
print("Falling back to simple string replacement")
|
||||
new_content = content.replace(old_version, new_version)
|
||||
else:
|
||||
# For non-TOML files, use simple string replacement
|
||||
new_content = content.replace(old_version, new_version)
|
||||
|
||||
if content == new_content:
|
||||
print(f"No changes needed in {file_path}")
|
||||
return False
|
||||
|
||||
file_path.write_text(new_content)
|
||||
print(f"✓ Updated {file_path}")
|
||||
return True
|
||||
|
||||
|
||||
def bump_version(
|
||||
new_version: str,
|
||||
version_file: Path,
|
||||
files_to_update: List[Path],
|
||||
) -> None:
|
||||
# Normalize version (remove 'v' prefix if present)
|
||||
new_version = normalize_version(new_version)
|
||||
|
||||
if not validate_version(new_version):
|
||||
print(f"Error: Invalid version format: {new_version}")
|
||||
print("Expected format: X.Y.Z, X.Y.ZrcN, or X.Y.Z.postN")
|
||||
print("Examples: 0.5.4, 0.5.3rc0, 0.5.3.post1")
|
||||
sys.exit(1)
|
||||
|
||||
repo_root = get_repo_root()
|
||||
version_file_abs = repo_root / version_file
|
||||
|
||||
if not version_file_abs.exists():
|
||||
print(f"Error: Version file {version_file_abs} does not exist")
|
||||
sys.exit(1)
|
||||
|
||||
old_version = read_current_version(version_file_abs)
|
||||
print(f"Current version: {old_version}")
|
||||
print(f"New version: {new_version}")
|
||||
print()
|
||||
|
||||
# Compare versions
|
||||
comparison = compare_versions(new_version, old_version)
|
||||
if comparison == 0:
|
||||
print("Error: New version is the same as current version")
|
||||
sys.exit(1)
|
||||
elif comparison < 0:
|
||||
print(
|
||||
f"Error: New version ({new_version}) is older than current version ({old_version})"
|
||||
)
|
||||
print("Version must be greater than the current version")
|
||||
sys.exit(1)
|
||||
|
||||
updated_count = 0
|
||||
for file_rel in files_to_update:
|
||||
file_abs = repo_root / file_rel
|
||||
if replace_in_file(file_abs, old_version, new_version):
|
||||
updated_count += 1
|
||||
|
||||
print()
|
||||
print(f"Successfully updated {updated_count} file(s)")
|
||||
print(f"Version bumped from {old_version} to {new_version}")
|
||||
|
||||
# Validate that all files now contain the new version
|
||||
print("\nValidating version updates...")
|
||||
failed_files = []
|
||||
for file_rel in files_to_update:
|
||||
file_abs = repo_root / file_rel
|
||||
if not file_abs.exists():
|
||||
print(f"Warning: File {file_rel} does not exist, skipping validation.")
|
||||
continue
|
||||
|
||||
content = file_abs.read_text()
|
||||
|
||||
# For TOML files, use regex to specifically check the version field
|
||||
if file_abs.suffix == ".toml":
|
||||
# Match version field with optional quotes
|
||||
pattern = r'version\s*=\s*["\']?' + re.escape(new_version) + r'["\']?'
|
||||
if not re.search(pattern, content):
|
||||
failed_files.append(file_rel)
|
||||
print(f"✗ {file_rel} does not contain version {new_version}")
|
||||
else:
|
||||
print(f"✓ {file_rel} validated")
|
||||
else:
|
||||
# For non-TOML files, use simple string search
|
||||
if new_version not in content:
|
||||
failed_files.append(file_rel)
|
||||
print(f"✗ {file_rel} does not contain version {new_version}")
|
||||
else:
|
||||
print(f"✓ {file_rel} validated")
|
||||
|
||||
if failed_files:
|
||||
print(f"\nError: {len(failed_files)} file(s) were not updated correctly:")
|
||||
for file_rel in failed_files:
|
||||
print(f" - {file_rel}")
|
||||
sys.exit(1)
|
||||
|
||||
print("\nAll files validated successfully!")
|
||||
27
third_party/sglang/scripts/sort_testcases_alphabetically.py
vendored
Normal file
27
third_party/sglang/scripts/sort_testcases_alphabetically.py
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
Sort the test case by name alphabetically for run_suite.py
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestFile:
|
||||
name: str
|
||||
estimated_time: float = 60
|
||||
|
||||
|
||||
suites = {}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for key in suites:
|
||||
cases = suites[key]
|
||||
names = [x.name for x in cases]
|
||||
names.sort()
|
||||
|
||||
print(f' "{key}": [')
|
||||
for name in names:
|
||||
estimated_time = [x.estimated_time for x in cases if x.name == name][0]
|
||||
print(f' TestFile("{name}", {estimated_time}),')
|
||||
print(f" ],\n")
|
||||
92
third_party/sglang/scripts/update_kernel_whl_index.py
vendored
Normal file
92
third_party/sglang/scripts/update_kernel_whl_index.py
vendored
Normal file
@@ -0,0 +1,92 @@
|
||||
# Reference: https://github.com/flashinfer-ai/flashinfer/blob/v0.2.0/scripts/update_whl_index.py
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
# All the CUDA versions that the wheels will cover
|
||||
SUPPORTED_CUDA_VERSIONS = ["129", "130"]
|
||||
DEFAULT_CUDA_VERSION = "129"
|
||||
|
||||
|
||||
def check_wheel_cuda_version(path_name, target_cuda_version):
|
||||
# Pass ROCm wheel
|
||||
if re.search(f"rocm", path_name):
|
||||
return False
|
||||
|
||||
# For other CUDA versions, the wheel path name will contain the cuda version suffix, e.g. sglang_kernel-0.4.0+cu130-cp310-abi3-manylinux2014_x86_64.whl
|
||||
if target_cuda_version != DEFAULT_CUDA_VERSION:
|
||||
return target_cuda_version in path_name
|
||||
|
||||
# For the default CUDA version, the wheel path name will not contain any cuda version suffix, e.g. sglang_kernel-0.4.0-cp310-abi3-manylinux2014_x86_64.whl
|
||||
# So we need to check if the wheel path name contains any other cuda version suffix
|
||||
for cuda_version in SUPPORTED_CUDA_VERSIONS:
|
||||
if cuda_version != DEFAULT_CUDA_VERSION and cuda_version in path_name:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def update_wheel_index(cuda_version=DEFAULT_CUDA_VERSION, rocm_version=None):
|
||||
index_dir = pathlib.Path(f"sgl-whl/cu{cuda_version}/sglang-kernel")
|
||||
index_dir.mkdir(exist_ok=True, parents=True)
|
||||
base_url = "https://github.com/sgl-project/whl/releases/download"
|
||||
|
||||
for path in sorted(pathlib.Path("sgl-kernel/dist").glob("*.whl")):
|
||||
# Skip the wheel if mismatches the passed in cuda_version
|
||||
if not check_wheel_cuda_version(path.name, cuda_version):
|
||||
continue
|
||||
with open(path, "rb") as f:
|
||||
sha256 = hashlib.sha256(f.read()).hexdigest()
|
||||
ver = re.findall(
|
||||
r"sglang_kernel-([0-9.]+(?:\.post[0-9]+)?)(?:\+cu[0-9]+)?-", path.name
|
||||
)[0]
|
||||
full_url = f"{base_url}/v{ver}/{path.name}#sha256={sha256}"
|
||||
with (index_dir / "index.html").open("a") as f:
|
||||
f.write(f'<a href="{full_url}">{path.name}</a><br>\n')
|
||||
|
||||
|
||||
def _update_non_cuda_wheel_index(backend, version):
|
||||
index_dir = pathlib.Path(f"sgl-whl/{backend}{version}/sglang-kernel")
|
||||
index_dir.mkdir(exist_ok=True, parents=True)
|
||||
base_url = "https://github.com/sgl-project/whl/releases/download"
|
||||
|
||||
for path in sorted(pathlib.Path("sgl-kernel/dist").glob("*.whl")):
|
||||
# Skip the wheel if not for this backend
|
||||
if re.search(f"{backend}", path.name) is None:
|
||||
continue
|
||||
with open(path, "rb") as f:
|
||||
sha256 = hashlib.sha256(f.read()).hexdigest()
|
||||
ver = re.findall(
|
||||
rf"sglang_kernel-([0-9.]+(?:\.post[0-9]+)?)(?:\+{backend}[0-9]+)?-",
|
||||
path.name,
|
||||
)[0]
|
||||
full_url = f"{base_url}/v{ver}/{path.name}#sha256={sha256}"
|
||||
with (index_dir / "index.html").open("a") as f:
|
||||
f.write(f'<a href="{full_url}">{path.name}</a><br>\n')
|
||||
|
||||
|
||||
def update_wheel_index_rocm(rocm_version):
|
||||
_update_non_cuda_wheel_index("rocm", rocm_version)
|
||||
|
||||
|
||||
def update_wheel_index_musa(musa_version):
|
||||
_update_non_cuda_wheel_index("musa", musa_version)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--cuda", type=str, default=DEFAULT_CUDA_VERSION)
|
||||
parser.add_argument("--rocm", type=str, default=None)
|
||||
parser.add_argument("--musa", type=str, default=None)
|
||||
args = parser.parse_args()
|
||||
if args.musa is not None:
|
||||
update_wheel_index_musa(args.musa)
|
||||
elif args.rocm is not None:
|
||||
update_wheel_index_rocm(args.rocm)
|
||||
else:
|
||||
update_wheel_index(args.cuda)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
203
third_party/sglang/scripts/update_nightly_whl_index.py
vendored
Executable file
203
third_party/sglang/scripts/update_nightly_whl_index.py
vendored
Executable file
@@ -0,0 +1,203 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Update the wheel index for nightly SGLang releases.
|
||||
|
||||
This script generates a PyPI-compatible index.html file at cu{version}/sglang/index.html
|
||||
containing all historical nightly builds, ordered by commit count (newest first).
|
||||
|
||||
The CUDA version is specified via the --cuda-version argument.
|
||||
|
||||
Reference: https://github.com/flashinfer-ai/flashinfer/blob/v0.2.0/scripts/update_whl_index.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
|
||||
def compute_sha256(file_path: pathlib.Path) -> str:
|
||||
"""Compute SHA256 hash of a file."""
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
for byte_block in iter(lambda: f.read(4096), b""):
|
||||
sha256_hash.update(byte_block)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
|
||||
def update_wheel_index(
|
||||
commit_hash: str, nightly_version: str, cuda_version: str, build_date: str = None
|
||||
):
|
||||
"""Update the wheel index for nightly releases.
|
||||
|
||||
Creates an index at cu{version}/sglang/index.html containing all historical nightlies.
|
||||
|
||||
Args:
|
||||
commit_hash: Short git commit hash (e.g., 'c5f1e86')
|
||||
nightly_version: Full nightly version string (e.g., '0.5.6.post1.dev7716+gc5f1e86')
|
||||
cuda_version: CUDA version string (e.g., '129' or '130')
|
||||
build_date: Build date in YYYY-MM-DD format (e.g., '2025-12-13')
|
||||
"""
|
||||
dist_dir = pathlib.Path("dist")
|
||||
whl_repo_dir = pathlib.Path("sgl-whl")
|
||||
|
||||
if not dist_dir.exists():
|
||||
print(f"Warning: {dist_dir} does not exist, skipping index update")
|
||||
return
|
||||
|
||||
# Format CUDA version with 'cu' prefix if not already present
|
||||
if not cuda_version.startswith("cu"):
|
||||
cuda_version = f"cu{cuda_version}"
|
||||
print(f"Using CUDA version: {cuda_version}")
|
||||
|
||||
# Base URL for wheels stored in GitHub Releases
|
||||
base_url = "https://github.com/sgl-project/whl/releases/download"
|
||||
# Use date-based tag if build_date is provided, otherwise fall back to commit-only
|
||||
if build_date:
|
||||
release_tag = f"nightly-{build_date}-{commit_hash}"
|
||||
else:
|
||||
release_tag = f"nightly-{commit_hash}"
|
||||
|
||||
# Create directory structure following PEP 503
|
||||
# /cu{version}/index.html -> links to sglang/ and sgl-kernel/
|
||||
# /cu{version}/sglang/index.html -> contains wheel links
|
||||
cuda_dir = whl_repo_dir / cuda_version
|
||||
cuda_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
sglang_dir = cuda_dir / "sglang"
|
||||
sglang_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
root_index = cuda_dir / "index.html"
|
||||
package_index = sglang_dir / "index.html"
|
||||
|
||||
print(f"\nUpdating nightly wheel index")
|
||||
print(f" Root index: {root_index}")
|
||||
print(f" Package index: {package_index}")
|
||||
|
||||
# Read existing package index if it exists
|
||||
existing_links = []
|
||||
if package_index.exists():
|
||||
with open(package_index, "r") as f:
|
||||
content = f.read()
|
||||
# Extract existing links (skip header and HTML boilerplate)
|
||||
existing_links = [
|
||||
line for line in content.split("\n") if line.startswith("<a href=")
|
||||
]
|
||||
|
||||
# Generate new links for current wheels
|
||||
new_links = []
|
||||
for wheel_path in sorted(dist_dir.glob("*.whl")):
|
||||
try:
|
||||
filename = wheel_path.name
|
||||
sha256 = compute_sha256(wheel_path)
|
||||
|
||||
# URL format: {base_url}/{release_tag}/{filename}#sha256={hash}
|
||||
wheel_url = f"{base_url}/{release_tag}/{filename}#sha256={sha256}"
|
||||
link = f'<a href="{wheel_url}">{filename}</a><br>'
|
||||
|
||||
new_links.append(link)
|
||||
print(f" Added: {filename}")
|
||||
except Exception as e:
|
||||
print(f" Error processing {wheel_path.name}: {e}")
|
||||
continue
|
||||
|
||||
if not new_links:
|
||||
print(" No new wheels to add")
|
||||
return
|
||||
|
||||
# Combine existing and new links (new links first for latest)
|
||||
all_links = new_links + existing_links
|
||||
|
||||
# Remove duplicates while preserving order (newer first)
|
||||
seen = set()
|
||||
unique_links = []
|
||||
for link in all_links:
|
||||
# Extract filename from link to check for duplicates
|
||||
filename_match = re.search(r">([^<]+\.whl)</a>", link)
|
||||
if filename_match:
|
||||
filename = filename_match.group(1)
|
||||
if filename not in seen:
|
||||
seen.add(filename)
|
||||
unique_links.append(link)
|
||||
|
||||
# Update root index to include both sgl-kernel and sglang
|
||||
# Read existing packages from root index if it exists
|
||||
existing_packages = set()
|
||||
if root_index.exists():
|
||||
with open(root_index, "r") as f:
|
||||
content = f.read()
|
||||
# Extract existing package links
|
||||
for match in re.finditer(r'<a href="([^"]+)/">', content):
|
||||
existing_packages.add(match.group(1))
|
||||
|
||||
# Add sglang to the package list
|
||||
existing_packages.add("sglang")
|
||||
|
||||
# Write root index with all packages (sorted for consistency)
|
||||
with open(root_index, "w") as f:
|
||||
f.write("<!DOCTYPE html>\n")
|
||||
for pkg in sorted(existing_packages):
|
||||
f.write(f'<a href="{pkg}/">{pkg}</a>\n')
|
||||
|
||||
print(f" Written root index: {root_index} (packages: {sorted(existing_packages)})")
|
||||
|
||||
# Write package index in minimal format (matching production sgl-kernel index)
|
||||
with open(package_index, "w") as f:
|
||||
f.write("<!DOCTYPE html>\n")
|
||||
f.write(f"<h1>SGLang Nightly Wheels ({cuda_version})</h1>\n")
|
||||
# Write links only
|
||||
f.write("\n".join(unique_links))
|
||||
f.write("\n")
|
||||
|
||||
print(f" Written {len(unique_links)} total wheels to {package_index}")
|
||||
print(f"\nDone! Users can install with:")
|
||||
print(
|
||||
f" pip install sglang --pre --extra-index-url https://sgl-project.github.io/whl/{cuda_version}/"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Update wheel index for nightly SGLang releases"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--commit-hash",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Short git commit hash (e.g., 'c5f1e86')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--nightly-version",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Full nightly version string (e.g., '0.5.6.post1.dev7716+gc5f1e86')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cuda-version",
|
||||
type=str,
|
||||
default="129",
|
||||
help="CUDA version (e.g., '129' or '130'). Defaults to '129'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--build-date",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Build date in YYYY-MM-DD format (e.g., '2025-12-13')",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Updating nightly wheel index")
|
||||
print(f" Commit: {args.commit_hash}")
|
||||
print(f" Version: {args.nightly_version}")
|
||||
print(f" CUDA version: {args.cuda_version}")
|
||||
if args.build_date:
|
||||
print(f" Build date: {args.build_date}")
|
||||
|
||||
update_wheel_index(
|
||||
args.commit_hash, args.nightly_version, args.cuda_version, args.build_date
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
183
third_party/sglang/scripts/update_pr_whl_index.py
vendored
Executable file
183
third_party/sglang/scripts/update_pr_whl_index.py
vendored
Executable file
@@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Update the wheel index for PR SGLang releases.
|
||||
|
||||
This script generates a single PyPI-compatible index.html file at pr/index.html
|
||||
containing all PR builds, ordered by PR number and commit count (newest first).
|
||||
|
||||
Similar to update_nightly_whl_index.py but for PR builds.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
|
||||
def compute_sha256(file_path: pathlib.Path) -> str:
|
||||
"""Compute SHA256 hash of a file."""
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
for byte_block in iter(lambda: f.read(4096), b""):
|
||||
sha256_hash.update(byte_block)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
|
||||
def update_wheel_index(
|
||||
pr_number: str, commit_hash: str, wheel_version: str, build_date: str
|
||||
):
|
||||
"""Update the wheel index for PR releases.
|
||||
|
||||
Creates a single index at pr/index.html containing all PR builds.
|
||||
|
||||
Args:
|
||||
pr_number: PR number (e.g., '123')
|
||||
commit_hash: Short git commit hash (e.g., 'c5f1e86')
|
||||
wheel_version: Full wheel version string (e.g., '0.5.6.dev7716+pr-123.gc5f1e86')
|
||||
build_date: Build date in YYYY-MM-DD format (e.g., '2025-12-13')
|
||||
"""
|
||||
dist_dir = pathlib.Path("dist")
|
||||
whl_repo_dir = pathlib.Path("sgl-whl")
|
||||
|
||||
if not dist_dir.exists():
|
||||
print(f"Warning: {dist_dir} does not exist, skipping index update")
|
||||
return
|
||||
|
||||
# Base URL for wheels stored in GitHub Releases
|
||||
base_url = "https://github.com/sgl-project/whl/releases/download"
|
||||
release_tag = f"pr-{pr_number}-{build_date}-{commit_hash}"
|
||||
|
||||
# Create pr directory structure following PEP 503
|
||||
# /pr/index.html -> links to sglang/
|
||||
# /pr/sglang/index.html -> contains wheel links
|
||||
pr_dir = whl_repo_dir / "pr"
|
||||
pr_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
sglang_dir = pr_dir / "sglang"
|
||||
sglang_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
root_index = pr_dir / "index.html"
|
||||
package_index = sglang_dir / "index.html"
|
||||
|
||||
print(f"\nUpdating PR wheel index")
|
||||
print(f" Root index: {root_index}")
|
||||
print(f" Package index: {package_index}")
|
||||
|
||||
# Read existing package index if it exists
|
||||
existing_links = []
|
||||
if package_index.exists():
|
||||
with open(package_index, "r") as f:
|
||||
content = f.read()
|
||||
# Extract existing links (skip header and HTML boilerplate)
|
||||
existing_links = [
|
||||
line for line in content.split("\n") if line.startswith("<a href=")
|
||||
]
|
||||
|
||||
# Generate new links for current wheels
|
||||
new_links = []
|
||||
for wheel_path in sorted(dist_dir.glob("*.whl")):
|
||||
try:
|
||||
filename = wheel_path.name
|
||||
sha256 = compute_sha256(wheel_path)
|
||||
|
||||
# URL format: {base_url}/{release_tag}/{filename}#sha256={hash}
|
||||
wheel_url = f"{base_url}/{release_tag}/{filename}#sha256={sha256}"
|
||||
link = f'<a href="{wheel_url}">{filename}</a><br>'
|
||||
|
||||
new_links.append(link)
|
||||
print(f" Added: {filename}")
|
||||
except Exception as e:
|
||||
print(f" Error processing {wheel_path.name}: {e}")
|
||||
continue
|
||||
|
||||
if not new_links:
|
||||
print(" No new wheels to add")
|
||||
return
|
||||
|
||||
# Combine existing and new links (new links first for latest)
|
||||
all_links = new_links + existing_links
|
||||
|
||||
# Remove duplicates while preserving order (newer first)
|
||||
seen = set()
|
||||
unique_links = []
|
||||
for link in all_links:
|
||||
# Extract filename from link to check for duplicates
|
||||
filename_match = re.search(r">([^<]+\.whl)</a>", link)
|
||||
if filename_match:
|
||||
filename = filename_match.group(1)
|
||||
if filename not in seen:
|
||||
seen.add(filename)
|
||||
unique_links.append(link)
|
||||
|
||||
# Write root index (links to sglang package directory)
|
||||
with open(root_index, "w") as f:
|
||||
f.write("<!DOCTYPE html>\n")
|
||||
f.write('<a href="sglang/">sglang</a>\n')
|
||||
|
||||
print(f" Written root index: {root_index}")
|
||||
|
||||
# Write package index in minimal format
|
||||
with open(package_index, "w") as f:
|
||||
f.write("<!DOCTYPE html>\n")
|
||||
f.write("<h1>SGLang PR Wheels</h1>\n")
|
||||
# Write links only
|
||||
f.write("\n".join(unique_links))
|
||||
f.write("\n")
|
||||
|
||||
print(f" Written {len(unique_links)} total wheels to {package_index}")
|
||||
print(f"\nDone! Users can install with:")
|
||||
print(
|
||||
f" pip install sglang --pre --extra-index-url https://sgl-project.github.io/whl/pr/"
|
||||
)
|
||||
print(f"\nOr install specific PR #{pr_number} wheel directly:")
|
||||
if new_links:
|
||||
first_wheel_match = re.search(r'href="([^"]+)"', new_links[0])
|
||||
if first_wheel_match:
|
||||
wheel_url = first_wheel_match.group(1).split("#")[0] # Remove sha256 hash
|
||||
print(f" pip install {wheel_url}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Update wheel index for PR SGLang releases"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pr-number",
|
||||
type=str,
|
||||
required=True,
|
||||
help="PR number (e.g., '123')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--commit-hash",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Short git commit hash (e.g., 'c5f1e86')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wheel-version",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Full wheel version string (e.g., '0.5.6.dev7716+pr-123.gc5f1e86')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--build-date",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Build date in YYYY-MM-DD format (e.g., '2025-12-13')",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Updating PR wheel index")
|
||||
print(f" PR: #{args.pr_number}")
|
||||
print(f" Commit: {args.commit_hash}")
|
||||
print(f" Version: {args.wheel_version}")
|
||||
print(f" Build date: {args.build_date}")
|
||||
|
||||
update_wheel_index(
|
||||
args.pr_number, args.commit_hash, args.wheel_version, args.build_date
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user