chore: vendor sglang v0.5.10 snapshot
This commit is contained in:
15
third_party/sglang/sgl-model-gateway/.cargo/config.toml
vendored
Normal file
15
third_party/sglang/sgl-model-gateway/.cargo/config.toml
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
[build]
|
||||
rustflags = []
|
||||
incremental = true
|
||||
|
||||
[target.aarch64-apple-darwin]
|
||||
rustflags = [
|
||||
"-C", "link-arg=-undefined",
|
||||
"-C", "link-arg=dynamic_lookup",
|
||||
]
|
||||
|
||||
[target.x86_64-apple-darwin]
|
||||
rustflags = [
|
||||
"-C", "link-arg=-undefined",
|
||||
"-C", "link-arg=dynamic_lookup",
|
||||
]
|
||||
198
third_party/sglang/sgl-model-gateway/Cargo.toml
vendored
Normal file
198
third_party/sglang/sgl-model-gateway/Cargo.toml
vendored
Normal file
@@ -0,0 +1,198 @@
|
||||
[package]
|
||||
name = "sgl-model-gateway"
|
||||
version = "0.3.2"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["grpc-client"]
|
||||
grpc-client = []
|
||||
grpc-server = []
|
||||
|
||||
vendored-openssl = ["openssl/vendored"]
|
||||
|
||||
[lints.rust]
|
||||
unused_qualifications = "warn"
|
||||
|
||||
[lib]
|
||||
name = "smg"
|
||||
crate-type = ["rlib"]
|
||||
|
||||
[[bin]]
|
||||
name = "sgl-model-gateway"
|
||||
path = "src/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "smg"
|
||||
path = "src/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "amg"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
clap = { version = "4", features = ["derive", "env"] }
|
||||
axum = { version = "0.8.6", features = ["macros", "ws", "tracing"] }
|
||||
axum-server = { version = "0.8.0", default-features = false, features = ["tls-rustls"] }
|
||||
tower = { version = "0.5", features = ["full"] }
|
||||
tower-http = { version = "0.6", features = ["trace", "compression-gzip", "cors", "timeout", "limit", "request-id", "util"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = { version = "1.0", default-features = false, features = [
|
||||
"std",
|
||||
"preserve_order",
|
||||
] }
|
||||
bytes = "1.8.0"
|
||||
http-body = "1.0"
|
||||
rand = "0.9.2"
|
||||
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json", "rustls-tls"], default-features = false }
|
||||
futures-util = "0.3"
|
||||
futures = "0.3"
|
||||
dashmap = "6.1.0"
|
||||
blake3 = "1.5"
|
||||
xxhash-rust = { version = "0.8", features = ["xxh3"] }
|
||||
bytemuck = { version = "1.21", features = ["derive"] }
|
||||
http = "1.1.0"
|
||||
tokio = { version = "1.42.0", features = ["full"] }
|
||||
async-trait = "0.1"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "chrono"] }
|
||||
tracing-log = "0.2"
|
||||
tracing-appender = "0.2.3"
|
||||
opentelemetry = "0.27"
|
||||
opentelemetry_sdk = { version = "0.27", features = ["trace", "rt-tokio"] }
|
||||
opentelemetry-otlp = { version = "0.27", features = ["trace", "grpc-tonic"] }
|
||||
tracing-opentelemetry = "0.28"
|
||||
chrono = "0.4"
|
||||
kube = { version = "1.1.0", features = ["runtime", "derive"] }
|
||||
k8s-openapi = { version = "0.25.0", features = ["v1_33"] }
|
||||
metrics = "0.24.2"
|
||||
metrics-exporter-prometheus = "0.17.0"
|
||||
uuid = { version = "1.10", features = ["v4", "serde"] }
|
||||
parking_lot = "0.12.4"
|
||||
thiserror = "2.0.12"
|
||||
regex = "1.10"
|
||||
memchr = "2.7" # SIMD-optimized byte pattern searching
|
||||
url = "2.5.4"
|
||||
validator = { version = "0.20.0", features = ["derive"] }
|
||||
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||
anyhow = "1.0"
|
||||
reasoning-parser = "=1.0.0"
|
||||
openai-protocol = { version = "=1.0.0", features = ["axum"] }
|
||||
tool-parser = "=1.0.0"
|
||||
llm-tokenizer = "=1.0.0"
|
||||
smg-auth = "=1.0.0"
|
||||
wfaas = "=1.0.0"
|
||||
data-connector = "=1.0.0"
|
||||
smg-mcp = "=1.0.0"
|
||||
smg-wasm = "=1.0.0"
|
||||
smg-mesh = "=1.0.0"
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||
rustls-pemfile = "2.2"
|
||||
openssl = "0.10.73"
|
||||
rmcp = { version = "0.8.3", features = ["client", "server",
|
||||
"transport-child-process",
|
||||
"transport-sse-client-reqwest",
|
||||
"transport-streamable-http-client-reqwest",
|
||||
"transport-streamable-http-server",
|
||||
"transport-streamable-http-server-session",
|
||||
"reqwest",
|
||||
"auth"] }
|
||||
serde_yaml = "0.9"
|
||||
subtle = "2.6"
|
||||
jsonwebtoken = { version = "9.3", default-features = false, features = ["use_pem"] }
|
||||
num-traits = "0.2"
|
||||
num-bigint = "0.4"
|
||||
base64 = "0.22"
|
||||
openai-harmony = { git = "https://github.com/openai/harmony", tag = "v0.0.4" }
|
||||
openmetrics-parser = "0.4.4"
|
||||
arc-swap = "1.7.1"
|
||||
|
||||
# gRPC and Protobuf dependencies
|
||||
smg-grpc-client = "=1.0.0"
|
||||
tonic = { version = "0.14.2", features = ["gzip", "transport"] }
|
||||
prost = "0.14.1"
|
||||
prost-types = "0.14.1"
|
||||
tonic-prost = "0.14.2"
|
||||
bitflags = "2.10.0"
|
||||
once_cell = "1.21.3"
|
||||
|
||||
# CRDT for Mesh state synchronization
|
||||
crdts = "7.3"
|
||||
redis = { version = "0.27.6", features = ["tokio-comp", "json", "connection-manager"] }
|
||||
|
||||
|
||||
# wasm dependencies
|
||||
sha2 = "0.10"
|
||||
wasmtime = { version = "38.0", features = ["component-model", "async"] }
|
||||
|
||||
[build-dependencies]
|
||||
chrono = { version = "0.4", features = ["clock"] }
|
||||
toml = "0.9"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
tower = { version = "0.5", features = ["util"] }
|
||||
http-body-util = "0.1"
|
||||
portpicker = "0.1"
|
||||
tempfile = "3.8"
|
||||
lazy_static = "1.4"
|
||||
wasm-encoder = "0.242"
|
||||
npyz = { version = "0.8", features = ["npz"] } # For reading numpy .npz files in golden tests
|
||||
opentelemetry-proto = { version = "0.27", features = ["gen-tonic"] }
|
||||
tonic-v12 = { version = "0.12.3", package = "tonic" }
|
||||
serial_test = "3.0"
|
||||
rsa = { version = "0.9", features = ["sha2"] }
|
||||
|
||||
[[bench]]
|
||||
name = "consistent_hash_bench"
|
||||
harness = false
|
||||
path = "benches/consistent_hash_bench.rs"
|
||||
[[bench]]
|
||||
name = "wasm_middleware_latency"
|
||||
harness = false
|
||||
path = "benches/wasm_middleware_latency.rs"
|
||||
[[bench]]
|
||||
name = "request_processing"
|
||||
harness = false
|
||||
path = "benches/request_processing.rs"
|
||||
|
||||
[[bench]]
|
||||
name = "router_registry_bench"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "manual_policy_benchmark"
|
||||
harness = false
|
||||
path = "benches/manual_policy_benchmark.rs"
|
||||
|
||||
[profile.release]
|
||||
opt-level = "z" # Optimize for size
|
||||
lto = "fat" # Full LTO for smaller binaries
|
||||
codegen-units = 1 # Better optimization, slower compile
|
||||
strip = true # Strip debug symbols
|
||||
|
||||
[profile.ci]
|
||||
inherits = "release"
|
||||
opt-level = 2 # Lighter optimization (still fast runtime, much faster compile)
|
||||
lto = "thin" # Thin LTO - good balance
|
||||
codegen-units = 16 # More parallelization for faster builds
|
||||
strip = true
|
||||
|
||||
[profile.dev]
|
||||
opt-level = 0
|
||||
debug = 1
|
||||
split-debuginfo = "unpacked"
|
||||
incremental = true
|
||||
codegen-units = 256
|
||||
|
||||
[profile.dev.package."*"]
|
||||
opt-level = 2
|
||||
debug = false
|
||||
|
||||
|
||||
[profile.dev.build-override]
|
||||
opt-level = 3
|
||||
codegen-units = 1
|
||||
|
||||
[profile.dev-opt]
|
||||
inherits = "dev"
|
||||
opt-level = 1
|
||||
1
third_party/sglang/sgl-model-gateway/LICENSE
vendored
Symbolic link
1
third_party/sglang/sgl-model-gateway/LICENSE
vendored
Symbolic link
@@ -0,0 +1 @@
|
||||
../LICENSE
|
||||
202
third_party/sglang/sgl-model-gateway/Makefile
vendored
Normal file
202
third_party/sglang/sgl-model-gateway/Makefile
vendored
Normal file
@@ -0,0 +1,202 @@
|
||||
# Model Gateway Makefile
|
||||
# Provides convenient shortcuts for common development tasks
|
||||
|
||||
# Python bindings directory
|
||||
PYTHON_DIR := bindings/python
|
||||
|
||||
# Auto-detect CPU cores and cap at reasonable limit to avoid thread exhaustion
|
||||
# Can be overridden: make python-dev JOBS=4
|
||||
NPROC := $(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 8)
|
||||
JOBS ?= $(shell echo $$(($(NPROC) > 16 ? 16 : $(NPROC))))
|
||||
|
||||
# Check if sccache is available and set RUSTC_WRAPPER accordingly
|
||||
SCCACHE := $(shell which sccache 2>/dev/null)
|
||||
ifdef SCCACHE
|
||||
export RUSTC_WRAPPER := $(SCCACHE)
|
||||
$(info Using sccache for compilation caching)
|
||||
else
|
||||
$(info sccache not found. Install it for faster builds: cargo install sccache)
|
||||
endif
|
||||
|
||||
.PHONY: help build test clean docs check fmt dev-setup pre-commit setup-sccache sccache-stats sccache-clean sccache-stop \
|
||||
python-dev python-build python-build-release python-install python-clean python-test python-check \
|
||||
show-version bump-version release-notes
|
||||
|
||||
help: ## Show this help message
|
||||
@echo "Model Gateway Development Commands"
|
||||
@echo "=================================="
|
||||
@echo ""
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-20s\033[0m %s\n", $$1, $$2}'
|
||||
@echo ""
|
||||
|
||||
build: ## Build the project in release mode
|
||||
@echo "Building SGLang Model Gateway..."
|
||||
@cargo build --release
|
||||
|
||||
test: ## Run all tests
|
||||
@echo "Running tests..."
|
||||
@cargo test
|
||||
|
||||
clean: ## Clean build artifacts
|
||||
@echo "Cleaning build artifacts..."
|
||||
@cargo clean
|
||||
|
||||
docs: ## Generate and open documentation
|
||||
@echo "Generating documentation..."
|
||||
@cargo doc --open
|
||||
|
||||
check: ## Run cargo check and clippy
|
||||
@echo "Running cargo check..."
|
||||
@cargo check
|
||||
@echo "Running clippy..."
|
||||
@cargo clippy --all-targets --all-features -- -D warnings
|
||||
|
||||
fmt: ## Format code with rustfmt
|
||||
@echo "Formatting code..."
|
||||
@rustup run nightly cargo fmt
|
||||
|
||||
# Development workflow shortcuts
|
||||
dev-setup: build test ## Set up development environment
|
||||
@echo "Development environment ready!"
|
||||
|
||||
pre-commit: fmt check test ## Run pre-commit checks
|
||||
@echo "Pre-commit checks passed!"
|
||||
|
||||
# sccache management targets
|
||||
setup-sccache: ## Install and configure sccache
|
||||
@echo "Setting up sccache..."
|
||||
@./scripts/setup-sccache.sh
|
||||
|
||||
sccache-stats: ## Show sccache statistics
|
||||
@if [ -n "$(SCCACHE)" ]; then \
|
||||
echo "sccache statistics:"; \
|
||||
sccache -s; \
|
||||
else \
|
||||
echo "sccache not installed. Run 'make setup-sccache' to install it."; \
|
||||
fi
|
||||
|
||||
sccache-clean: ## Clear sccache cache
|
||||
@if [ -n "$(SCCACHE)" ]; then \
|
||||
echo "Clearing sccache cache..."; \
|
||||
sccache -C; \
|
||||
echo "sccache cache cleared"; \
|
||||
else \
|
||||
echo "sccache not installed"; \
|
||||
fi
|
||||
|
||||
sccache-stop: ## Stop the sccache server
|
||||
@if [ -n "$(SCCACHE)" ]; then \
|
||||
echo "Stopping sccache server..."; \
|
||||
sccache --stop-server || true; \
|
||||
else \
|
||||
echo "sccache not installed"; \
|
||||
fi
|
||||
|
||||
# Python bindings (maturin) targets
|
||||
python-dev: ## Build Python bindings in development mode (fast, debug build)
|
||||
@echo "Building Python bindings in development mode (using $(JOBS) parallel jobs with sccache)..."
|
||||
@cd $(PYTHON_DIR) && CARGO_BUILD_JOBS=$(JOBS) maturin develop
|
||||
|
||||
python-build: ## Build Python wheel (release mode with vendored OpenSSL)
|
||||
@echo "Building Python wheel (release, vendored OpenSSL, using $(JOBS) parallel jobs with sccache)..."
|
||||
@cd $(PYTHON_DIR) && CARGO_BUILD_JOBS=$(JOBS) maturin build --release --out dist --features vendored-openssl
|
||||
|
||||
python-build-release: python-build ## Alias for python-build
|
||||
|
||||
python-install: python-build ## Build and install Python wheel
|
||||
@echo "Installing Python wheel..."
|
||||
@pip install --force-reinstall $(PYTHON_DIR)/dist/*.whl
|
||||
@echo "Python package installed!"
|
||||
|
||||
python-clean: ## Clean Python build artifacts
|
||||
@echo "Cleaning Python build artifacts..."
|
||||
@rm -rf $(PYTHON_DIR)/dist/
|
||||
@rm -rf $(PYTHON_DIR)/target/
|
||||
@rm -rf $(PYTHON_DIR)/sglang_router.egg-info/
|
||||
@rm -rf $(PYTHON_DIR)/sglang_router/__pycache__/
|
||||
@find $(PYTHON_DIR) -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find $(PYTHON_DIR) -name "*.pyc" -delete 2>/dev/null || true
|
||||
@echo "Python build artifacts cleaned!"
|
||||
|
||||
python-test: ## Run Python tests
|
||||
@echo "Running Python tests..."
|
||||
@pytest e2e_test/ -v
|
||||
|
||||
python-check: ## Check Python package with twine
|
||||
@echo "Checking Python package..."
|
||||
@cd $(PYTHON_DIR) && CARGO_BUILD_JOBS=$(JOBS) maturin build --release --out dist --features vendored-openssl
|
||||
@pip install twine 2>/dev/null || true
|
||||
@twine check $(PYTHON_DIR)/dist/*
|
||||
@echo "Python package check passed!"
|
||||
|
||||
# Combined shortcuts
|
||||
dev: python-dev ## Quick development setup (build Python bindings in dev mode)
|
||||
|
||||
install: python-install ## Build and install everything
|
||||
|
||||
# Release management
|
||||
VERSION_FILES := Cargo.toml \
|
||||
bindings/golang/Cargo.toml \
|
||||
bindings/python/Cargo.toml \
|
||||
bindings/python/pyproject.toml \
|
||||
bindings/python/src/sglang_router/version.py
|
||||
|
||||
show-version: ## Show current version across all files
|
||||
@echo "Current versions:"
|
||||
@echo " Cargo.toml: $$(grep -m1 '^version = ' Cargo.toml | sed 's/version = "\(.*\)"/\1/')"
|
||||
@echo " bindings/golang/Cargo.toml: $$(grep -m1 '^version = ' bindings/golang/Cargo.toml | sed 's/version = "\(.*\)"/\1/')"
|
||||
@echo " bindings/python/Cargo.toml: $$(grep -m1 '^version = ' bindings/python/Cargo.toml | sed 's/version = "\(.*\)"/\1/')"
|
||||
@echo " bindings/python/pyproject.toml: $$(grep -m1 '^version = ' bindings/python/pyproject.toml | sed 's/version = "\(.*\)"/\1/')"
|
||||
@echo " bindings/python/.../version.py: $$(grep '__version__' bindings/python/src/sglang_router/version.py | sed 's/__version__ = "\(.*\)"/\1/')"
|
||||
|
||||
bump-version: ## Bump version across all files (usage: make bump-version VERSION=0.3.3)
|
||||
@if [ -z "$(VERSION)" ]; then \
|
||||
echo "Usage: make bump-version VERSION=<new-version>"; \
|
||||
echo "Example: make bump-version VERSION=0.3.3"; \
|
||||
echo ""; \
|
||||
echo "Current version:"; \
|
||||
grep -m1 '^version = ' Cargo.toml | sed 's/version = "\(.*\)"/ \1/'; \
|
||||
exit 1; \
|
||||
fi
|
||||
@echo "Bumping version to $(VERSION)..."
|
||||
@# Update main Cargo.toml (line 3)
|
||||
@sed -i.bak 's/^version = ".*"/version = "$(VERSION)"/' Cargo.toml && rm -f Cargo.toml.bak
|
||||
@# Update golang binding Cargo.toml
|
||||
@sed -i.bak 's/^version = ".*"/version = "$(VERSION)"/' bindings/golang/Cargo.toml && rm -f bindings/golang/Cargo.toml.bak
|
||||
@# Update python binding Cargo.toml
|
||||
@sed -i.bak 's/^version = ".*"/version = "$(VERSION)"/' bindings/python/Cargo.toml && rm -f bindings/python/Cargo.toml.bak
|
||||
@# Update pyproject.toml
|
||||
@sed -i.bak 's/^version = ".*"/version = "$(VERSION)"/' bindings/python/pyproject.toml && rm -f bindings/python/pyproject.toml.bak
|
||||
@# Update version.py
|
||||
@sed -i.bak 's/__version__ = ".*"/__version__ = "$(VERSION)"/' bindings/python/src/sglang_router/version.py && rm -f bindings/python/src/sglang_router/version.py.bak
|
||||
@echo "Version updated to $(VERSION) in all files:"
|
||||
@echo " - Cargo.toml"
|
||||
@echo " - bindings/golang/Cargo.toml"
|
||||
@echo " - bindings/python/Cargo.toml"
|
||||
@echo " - bindings/python/pyproject.toml"
|
||||
@echo " - bindings/python/src/sglang_router/version.py"
|
||||
@echo ""
|
||||
@echo "Verify with: make show-version"
|
||||
|
||||
release-notes: ## Generate release notes for gateway (usage: make release-notes PREV=gateway-v0.2.2 CURR=gateway-v1.0.0)
|
||||
@if [ -z "$(PREV)" ] || [ -z "$(CURR)" ]; then \
|
||||
echo "Usage: make release-notes PREV=<previous-tag> CURR=<current-tag>"; \
|
||||
echo "Example: make release-notes PREV=gateway-v0.2.2 CURR=gateway-v1.0.0"; \
|
||||
echo ""; \
|
||||
echo "Options:"; \
|
||||
echo " OUTPUT=<file> Save to file (default: stdout)"; \
|
||||
echo " CREATE_RELEASE=1 Create GitHub draft release via gh CLI (default: draft)"; \
|
||||
echo " DRAFT=0 Publish release immediately (skip draft)"; \
|
||||
exit 1; \
|
||||
fi
|
||||
@ARGS="$(PREV) $(CURR)"; \
|
||||
if [ -n "$(OUTPUT)" ]; then \
|
||||
ARGS="$$ARGS --output $(OUTPUT)"; \
|
||||
fi; \
|
||||
if [ "$(CREATE_RELEASE)" = "1" ]; then \
|
||||
ARGS="$$ARGS --create-release"; \
|
||||
if [ "$(DRAFT)" = "0" ]; then \
|
||||
ARGS="$$ARGS --no-draft"; \
|
||||
fi; \
|
||||
fi; \
|
||||
./scripts/generate_gateway_release_notes.sh $$ARGS
|
||||
1107
third_party/sglang/sgl-model-gateway/README.md
vendored
Normal file
1107
third_party/sglang/sgl-model-gateway/README.md
vendored
Normal file
File diff suppressed because it is too large
Load Diff
36
third_party/sglang/sgl-model-gateway/benches/consistent_hash_bench.rs
vendored
Normal file
36
third_party/sglang/sgl-model-gateway/benches/consistent_hash_bench.rs
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use smg_mesh::consistent_hash::ConsistentHashRing;
|
||||
|
||||
fn setup_ring(node_count: usize) -> ConsistentHashRing {
|
||||
let mut ring = ConsistentHashRing::new();
|
||||
for i in 0..node_count {
|
||||
ring.add_node(&format!("node-{}", i));
|
||||
}
|
||||
ring
|
||||
}
|
||||
|
||||
fn bench_consistent_hash(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("ConsistentHashRing");
|
||||
|
||||
for size in [10, 100, 500].iter() {
|
||||
let ring = setup_ring(*size);
|
||||
let key = "test-request-key-for-rate-limiting";
|
||||
let node_name = "node-5";
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("get_owners", size), size, |b, _| {
|
||||
b.iter(|| {
|
||||
black_box(ring.get_owners(key));
|
||||
});
|
||||
});
|
||||
group.bench_with_input(BenchmarkId::new("is_owner", size), size, |b, _| {
|
||||
b.iter(|| {
|
||||
black_box(ring.is_owner(key, node_name));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(benches, bench_consistent_hash);
|
||||
criterion_main!(benches);
|
||||
260
third_party/sglang/sgl-model-gateway/benches/manual_policy_benchmark.rs
vendored
Normal file
260
third_party/sglang/sgl-model-gateway/benches/manual_policy_benchmark.rs
vendored
Normal file
@@ -0,0 +1,260 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use smg::{
|
||||
core::{BasicWorkerBuilder, Worker, WorkerType},
|
||||
policies::{LoadBalancingPolicy, ManualPolicy, SelectWorkerInfo},
|
||||
};
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
// ============================================================================
|
||||
// Test Helpers
|
||||
// ============================================================================
|
||||
|
||||
fn create_workers(count: usize) -> Vec<Arc<dyn Worker>> {
|
||||
(0..count)
|
||||
.map(|i| {
|
||||
Arc::new(
|
||||
BasicWorkerBuilder::new(format!("http://worker-{}:8000", i))
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
) as Arc<dyn Worker>
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn select_with_key(
|
||||
rt: &Runtime,
|
||||
policy: &ManualPolicy,
|
||||
workers: &[Arc<dyn Worker>],
|
||||
key: &str,
|
||||
) -> Option<usize> {
|
||||
let mut headers = http::HeaderMap::new();
|
||||
headers.insert("x-smg-routing-key", key.parse().unwrap());
|
||||
let info = SelectWorkerInfo {
|
||||
headers: Some(&headers),
|
||||
..Default::default()
|
||||
};
|
||||
rt.block_on(policy.select_worker(workers, &info))
|
||||
}
|
||||
|
||||
fn warmup_keys(rt: &Runtime, policy: &ManualPolicy, workers: &[Arc<dyn Worker>], keys: &[String]) {
|
||||
for key in keys {
|
||||
select_with_key(rt, policy, workers, key);
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_keys(count: usize, prefix: &str) -> Vec<String> {
|
||||
(0..count).map(|i| format!("{}{}", prefix, i)).collect()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
fn bench_fast_path_hit(c: &mut Criterion) {
|
||||
let rt = Runtime::new().unwrap();
|
||||
let mut group = c.benchmark_group("manual_policy/fast_path");
|
||||
|
||||
for worker_count in [4, 16, 64, 256] {
|
||||
let policy = ManualPolicy::new();
|
||||
let workers = create_workers(worker_count);
|
||||
let keys = gen_keys(1000, "user-");
|
||||
warmup_keys(&rt, &policy, &workers, &keys);
|
||||
|
||||
group.throughput(Throughput::Elements(1));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("workers", worker_count),
|
||||
&worker_count,
|
||||
|b, _| {
|
||||
let mut idx = 0;
|
||||
b.iter(|| {
|
||||
let result = select_with_key(&rt, &policy, &workers, &keys[idx % keys.len()]);
|
||||
idx += 1;
|
||||
black_box(result)
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_slow_path_vacant(c: &mut Criterion) {
|
||||
let rt = Runtime::new().unwrap();
|
||||
let mut group = c.benchmark_group("manual_policy/slow_path_vacant");
|
||||
|
||||
for worker_count in [4, 16, 64, 256] {
|
||||
let workers = create_workers(worker_count);
|
||||
|
||||
group.throughput(Throughput::Elements(1));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("workers", worker_count),
|
||||
&worker_count,
|
||||
|b, _| {
|
||||
let policy = ManualPolicy::new();
|
||||
let mut idx = 0;
|
||||
b.iter(|| {
|
||||
let key = format!("new-user-{}", idx);
|
||||
let result = select_with_key(&rt, &policy, &workers, &key);
|
||||
idx += 1;
|
||||
black_box(result)
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_no_routing_key(c: &mut Criterion) {
|
||||
let rt = Runtime::new().unwrap();
|
||||
let mut group = c.benchmark_group("manual_policy/no_routing_key");
|
||||
|
||||
for worker_count in [4, 16, 64, 256] {
|
||||
let policy = ManualPolicy::new();
|
||||
let workers = create_workers(worker_count);
|
||||
|
||||
group.throughput(Throughput::Elements(1));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("workers", worker_count),
|
||||
&worker_count,
|
||||
|b, _| {
|
||||
let info = SelectWorkerInfo::default();
|
||||
b.iter(|| black_box(rt.block_on(policy.select_worker(&workers, &info))));
|
||||
},
|
||||
);
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_failover(c: &mut Criterion) {
|
||||
let rt = Runtime::new().unwrap();
|
||||
let mut group = c.benchmark_group("manual_policy/failover");
|
||||
group.sample_size(50);
|
||||
|
||||
for worker_count in [4, 16, 64] {
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("workers", worker_count),
|
||||
&worker_count,
|
||||
|b, &count| {
|
||||
b.iter_with_setup(
|
||||
|| {
|
||||
let policy = ManualPolicy::new();
|
||||
let workers = create_workers(count);
|
||||
let idx = select_with_key(&rt, &policy, &workers, "failover-test").unwrap();
|
||||
workers[idx].set_healthy(false);
|
||||
(policy, workers)
|
||||
},
|
||||
|(policy, workers)| {
|
||||
black_box(select_with_key(&rt, &policy, &workers, "failover-test"))
|
||||
},
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_concurrent(c: &mut Criterion) {
|
||||
let rt = Arc::new(
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.worker_threads(4)
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
let mut group = c.benchmark_group("manual_policy/concurrent");
|
||||
group.sample_size(50);
|
||||
|
||||
for num_threads in [2, 4, 8, 16] {
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("threads", num_threads),
|
||||
&num_threads,
|
||||
|b, &threads| {
|
||||
b.iter(|| {
|
||||
let policy = Arc::new(ManualPolicy::new());
|
||||
let workers: Arc<Vec<Arc<dyn Worker>>> = Arc::new(create_workers(16));
|
||||
|
||||
rt.block_on(async {
|
||||
let handles: Vec<_> = (0..threads)
|
||||
.map(|t| {
|
||||
let policy = Arc::clone(&policy);
|
||||
let workers = Arc::clone(&workers);
|
||||
tokio::spawn(async move {
|
||||
for i in 0..500 {
|
||||
let key = if i % 5 == 0 {
|
||||
format!("thread{}_user{}", t, i)
|
||||
} else {
|
||||
format!("shared_user{}", i % 50)
|
||||
};
|
||||
let mut headers = http::HeaderMap::new();
|
||||
headers.insert("x-smg-routing-key", key.parse().unwrap());
|
||||
let info = SelectWorkerInfo {
|
||||
headers: Some(&headers),
|
||||
..Default::default()
|
||||
};
|
||||
let _ =
|
||||
black_box(policy.select_worker(&workers, &info).await);
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
for h in handles {
|
||||
h.await.unwrap();
|
||||
}
|
||||
});
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_cache_size_impact(c: &mut Criterion) {
|
||||
let rt = Runtime::new().unwrap();
|
||||
let mut group = c.benchmark_group("manual_policy/cache_size");
|
||||
|
||||
for cache_size in [100, 1000, 10000, 100000] {
|
||||
let policy = ManualPolicy::new();
|
||||
let workers = create_workers(16);
|
||||
let keys = gen_keys(cache_size, "user-");
|
||||
warmup_keys(&rt, &policy, &workers, &keys);
|
||||
|
||||
group.throughput(Throughput::Elements(1));
|
||||
group.bench_with_input(BenchmarkId::new("keys", cache_size), &cache_size, |b, _| {
|
||||
let mut idx = 0;
|
||||
b.iter(|| {
|
||||
let result = select_with_key(&rt, &policy, &workers, &keys[idx % keys.len()]);
|
||||
idx += 1;
|
||||
black_box(result)
|
||||
});
|
||||
});
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_comparison_baseline(c: &mut Criterion) {
|
||||
use rand::Rng;
|
||||
|
||||
let mut group = c.benchmark_group("manual_policy/vs_baseline");
|
||||
let workers = create_workers(16);
|
||||
|
||||
// Baseline: raw random selection without any policy overhead
|
||||
group.bench_function("raw_random", |b| {
|
||||
let mut rng = rand::rng();
|
||||
b.iter(|| black_box(rng.random_range(0..workers.len())));
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
bench_fast_path_hit,
|
||||
bench_slow_path_vacant,
|
||||
bench_no_routing_key,
|
||||
bench_failover,
|
||||
bench_concurrent,
|
||||
bench_cache_size_impact,
|
||||
bench_comparison_baseline,
|
||||
);
|
||||
criterion_main!(benches);
|
||||
670
third_party/sglang/sgl-model-gateway/benches/request_processing.rs
vendored
Normal file
670
third_party/sglang/sgl-model-gateway/benches/request_processing.rs
vendored
Normal file
@@ -0,0 +1,670 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use serde_json::{from_str, to_string, to_value, to_vec};
|
||||
use smg::{
|
||||
core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType},
|
||||
protocols::{
|
||||
chat::{ChatCompletionRequest, ChatMessage, MessageContent},
|
||||
common::StringOrArray,
|
||||
completion::CompletionRequest,
|
||||
generate::GenerateRequest,
|
||||
sampling_params::SamplingParams,
|
||||
},
|
||||
routers::http::pd_types::{generate_room_id, RequestWithBootstrap},
|
||||
};
|
||||
|
||||
fn create_test_worker() -> BasicWorker {
|
||||
BasicWorkerBuilder::new("http://test-server:8000")
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(5678),
|
||||
})
|
||||
.build()
|
||||
}
|
||||
|
||||
// Helper function to get bootstrap info from worker
|
||||
fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option<u16>) {
|
||||
let hostname = worker.bootstrap_host().to_string();
|
||||
let bootstrap_port = worker.bootstrap_port();
|
||||
(hostname, bootstrap_port)
|
||||
}
|
||||
|
||||
/// Create a default GenerateRequest for benchmarks with minimal fields set
|
||||
fn default_generate_request() -> GenerateRequest {
|
||||
GenerateRequest {
|
||||
text: None,
|
||||
model: None,
|
||||
input_ids: None,
|
||||
input_embeds: None,
|
||||
image_data: None,
|
||||
video_data: None,
|
||||
audio_data: None,
|
||||
sampling_params: None,
|
||||
return_logprob: None,
|
||||
logprob_start_len: None,
|
||||
top_logprobs_num: None,
|
||||
token_ids_logprob: None,
|
||||
return_text_in_logprobs: false,
|
||||
stream: false,
|
||||
log_metrics: true,
|
||||
return_hidden_states: false,
|
||||
modalities: None,
|
||||
session_params: None,
|
||||
lora_path: None,
|
||||
lora_id: None,
|
||||
custom_logit_processor: None,
|
||||
bootstrap_host: None,
|
||||
bootstrap_port: None,
|
||||
bootstrap_room: None,
|
||||
bootstrap_pair_key: None,
|
||||
data_parallel_rank: None,
|
||||
background: false,
|
||||
conversation_id: None,
|
||||
priority: None,
|
||||
extra_key: None,
|
||||
no_logs: false,
|
||||
custom_labels: None,
|
||||
return_bytes: false,
|
||||
return_entropy: false,
|
||||
rid: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a default ChatCompletionRequest for benchmarks with minimal fields set
|
||||
#[allow(deprecated)]
|
||||
fn default_chat_completion_request() -> ChatCompletionRequest {
|
||||
ChatCompletionRequest {
|
||||
// Required fields in OpenAI order
|
||||
messages: vec![],
|
||||
model: String::new(),
|
||||
|
||||
// Use default for all other fields
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a default CompletionRequest for benchmarks with minimal fields set
|
||||
fn default_completion_request() -> CompletionRequest {
|
||||
CompletionRequest {
|
||||
model: String::new(),
|
||||
prompt: StringOrArray::String(String::new()),
|
||||
suffix: None,
|
||||
max_tokens: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
n: None,
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
logprobs: None,
|
||||
echo: false,
|
||||
stop: None,
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
best_of: None,
|
||||
logit_bias: None,
|
||||
user: None,
|
||||
seed: None,
|
||||
// SGLang Extensions
|
||||
top_k: None,
|
||||
min_p: None,
|
||||
min_tokens: None,
|
||||
repetition_penalty: None,
|
||||
regex: None,
|
||||
ebnf: None,
|
||||
json_schema: None,
|
||||
stop_token_ids: None,
|
||||
no_stop_trim: false,
|
||||
ignore_eos: false,
|
||||
skip_special_tokens: true,
|
||||
// SGLang Extensions
|
||||
lora_path: None,
|
||||
session_params: None,
|
||||
return_hidden_states: false,
|
||||
sampling_seed: None,
|
||||
other: serde_json::Map::new(),
|
||||
}
|
||||
}
|
||||
|
||||
// Sample request data for benchmarks
|
||||
fn create_sample_generate_request() -> GenerateRequest {
|
||||
GenerateRequest {
|
||||
text: Some("Write a story about artificial intelligence".to_string()),
|
||||
sampling_params: Some(SamplingParams {
|
||||
max_new_tokens: Some(100),
|
||||
temperature: Some(0.8),
|
||||
top_p: Some(0.9),
|
||||
top_k: Some(50),
|
||||
frequency_penalty: Some(0.0),
|
||||
presence_penalty: Some(0.0),
|
||||
repetition_penalty: Some(1.0),
|
||||
..Default::default()
|
||||
}),
|
||||
..default_generate_request()
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(deprecated)]
|
||||
fn create_sample_chat_completion_request() -> ChatCompletionRequest {
|
||||
ChatCompletionRequest {
|
||||
model: "gpt-3.5-turbo".to_string(),
|
||||
messages: vec![
|
||||
ChatMessage::System {
|
||||
content: MessageContent::Text("You are a helpful assistant".to_string()),
|
||||
name: None,
|
||||
},
|
||||
ChatMessage::User {
|
||||
content: MessageContent::Text(
|
||||
"Explain quantum computing in simple terms".to_string(),
|
||||
),
|
||||
name: None,
|
||||
},
|
||||
],
|
||||
max_tokens: Some(150),
|
||||
max_completion_tokens: Some(150),
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(1.0),
|
||||
n: Some(1),
|
||||
presence_penalty: Some(0.0),
|
||||
frequency_penalty: Some(0.0),
|
||||
parallel_tool_calls: Some(true),
|
||||
..default_chat_completion_request()
|
||||
}
|
||||
}
|
||||
|
||||
fn create_sample_completion_request() -> CompletionRequest {
|
||||
CompletionRequest {
|
||||
model: "text-davinci-003".to_string(),
|
||||
prompt: StringOrArray::String("Complete this sentence: The future of AI is".to_string()),
|
||||
max_tokens: Some(50),
|
||||
temperature: Some(0.8),
|
||||
top_p: Some(1.0),
|
||||
n: Some(1),
|
||||
presence_penalty: Some(0.0),
|
||||
frequency_penalty: Some(0.0),
|
||||
best_of: Some(1),
|
||||
..default_completion_request()
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(deprecated)]
|
||||
fn create_large_chat_completion_request() -> ChatCompletionRequest {
|
||||
let mut messages = vec![ChatMessage::System {
|
||||
content: MessageContent::Text(
|
||||
"You are a helpful assistant with extensive knowledge.".to_string(),
|
||||
),
|
||||
name: None,
|
||||
}];
|
||||
|
||||
// Add many user/assistant pairs to simulate a long conversation
|
||||
for i in 0..50 {
|
||||
messages.push(ChatMessage::User {
|
||||
content: MessageContent::Text(format!("Question {}: What do you think about topic number {} which involves complex reasoning about multiple interconnected systems and their relationships?", i, i)),
|
||||
name: None,
|
||||
});
|
||||
messages.push(ChatMessage::Assistant {
|
||||
content: Some(MessageContent::Text(format!("Answer {}: This is a detailed response about topic {} that covers multiple aspects and provides comprehensive analysis of the interconnected systems you mentioned.", i, i))),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
});
|
||||
}
|
||||
|
||||
ChatCompletionRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages,
|
||||
max_tokens: Some(1000),
|
||||
max_completion_tokens: Some(1000),
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(0.95),
|
||||
n: Some(1),
|
||||
presence_penalty: Some(0.1),
|
||||
frequency_penalty: Some(0.1),
|
||||
top_logprobs: Some(5),
|
||||
seed: Some(42),
|
||||
parallel_tool_calls: Some(true),
|
||||
..default_chat_completion_request()
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark JSON serialization
|
||||
fn bench_json_serialization(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("json_serialization");
|
||||
|
||||
let generate_req = create_sample_generate_request();
|
||||
let chat_req = create_sample_chat_completion_request();
|
||||
let completion_req = create_sample_completion_request();
|
||||
let large_chat_req = create_large_chat_completion_request();
|
||||
|
||||
group.bench_function("generate_request", |b| {
|
||||
b.iter(|| {
|
||||
let json = to_string(black_box(&generate_req)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("chat_completion_request", |b| {
|
||||
b.iter(|| {
|
||||
let json = to_string(black_box(&chat_req)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("completion_request", |b| {
|
||||
b.iter(|| {
|
||||
let json = to_string(black_box(&completion_req)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("large_chat_completion_request", |b| {
|
||||
b.iter(|| {
|
||||
let json = to_string(black_box(&large_chat_req)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("generate_request_to_bytes", |b| {
|
||||
b.iter(|| {
|
||||
let bytes = to_vec(black_box(&generate_req)).unwrap();
|
||||
black_box(bytes);
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Benchmark JSON deserialization
|
||||
fn bench_json_deserialization(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("json_deserialization");
|
||||
|
||||
let generate_json = to_string(&create_sample_generate_request()).unwrap();
|
||||
let chat_json = to_string(&create_sample_chat_completion_request()).unwrap();
|
||||
let completion_json = to_string(&create_sample_completion_request()).unwrap();
|
||||
let large_chat_json = to_string(&create_large_chat_completion_request()).unwrap();
|
||||
|
||||
group.bench_function("generate_request", |b| {
|
||||
b.iter(|| {
|
||||
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
|
||||
black_box(req);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("chat_completion_request", |b| {
|
||||
b.iter(|| {
|
||||
let req: ChatCompletionRequest = from_str(black_box(&chat_json)).unwrap();
|
||||
black_box(req);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("completion_request", |b| {
|
||||
b.iter(|| {
|
||||
let req: CompletionRequest = from_str(black_box(&completion_json)).unwrap();
|
||||
black_box(req);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("large_chat_completion_request", |b| {
|
||||
b.iter(|| {
|
||||
let req: ChatCompletionRequest = from_str(black_box(&large_chat_json)).unwrap();
|
||||
black_box(req);
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Benchmark bootstrap injection (replaces request adaptation)
|
||||
fn bench_bootstrap_injection(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("bootstrap_injection");
|
||||
|
||||
let generate_req = create_sample_generate_request();
|
||||
let chat_req = create_sample_chat_completion_request();
|
||||
let completion_req = create_sample_completion_request();
|
||||
let large_chat_req = create_large_chat_completion_request();
|
||||
let worker = create_test_worker();
|
||||
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
|
||||
|
||||
group.bench_function("generate_bootstrap_injection", |b| {
|
||||
b.iter(|| {
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &generate_req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("chat_completion_bootstrap_injection", |b| {
|
||||
b.iter(|| {
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &chat_req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("completion_bootstrap_injection", |b| {
|
||||
b.iter(|| {
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &completion_req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("large_chat_completion_bootstrap_injection", |b| {
|
||||
b.iter(|| {
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &large_chat_req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Benchmark direct JSON routing (replaces regular routing)
|
||||
fn bench_direct_json_routing(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("direct_json_routing");
|
||||
|
||||
let generate_req = create_sample_generate_request();
|
||||
let chat_req = create_sample_chat_completion_request();
|
||||
let completion_req = create_sample_completion_request();
|
||||
|
||||
group.bench_function("generate_to_json", |b| {
|
||||
b.iter(|| {
|
||||
let json = to_value(black_box(&generate_req)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("generate_to_json_string", |b| {
|
||||
b.iter(|| {
|
||||
let json = to_string(black_box(&generate_req)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("generate_to_bytes", |b| {
|
||||
b.iter(|| {
|
||||
let bytes = to_vec(black_box(&generate_req)).unwrap();
|
||||
black_box(bytes);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("chat_completion_to_json", |b| {
|
||||
b.iter(|| {
|
||||
let json = to_value(black_box(&chat_req)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("chat_completion_to_json_string", |b| {
|
||||
b.iter(|| {
|
||||
let json = to_string(black_box(&chat_req)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("completion_to_json", |b| {
|
||||
b.iter(|| {
|
||||
let json = to_value(black_box(&completion_req)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Benchmark throughput with different request sizes
|
||||
fn bench_throughput_by_size(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("throughput_by_size");
|
||||
|
||||
// Create requests of different sizes
|
||||
let small_generate = GenerateRequest {
|
||||
text: Some("Hi".to_string()),
|
||||
..default_generate_request()
|
||||
};
|
||||
|
||||
let medium_generate = GenerateRequest {
|
||||
text: Some("Write a medium length story about AI".repeat(10)),
|
||||
..default_generate_request()
|
||||
};
|
||||
|
||||
let large_generate = GenerateRequest {
|
||||
text: Some("Write a very long and detailed story about artificial intelligence and its impact on society".repeat(100)),
|
||||
..default_generate_request()
|
||||
};
|
||||
|
||||
let worker = create_test_worker();
|
||||
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
|
||||
|
||||
for (name, req) in [
|
||||
("small", &small_generate),
|
||||
("medium", &medium_generate),
|
||||
("large", &large_generate),
|
||||
] {
|
||||
let json = to_string(req).unwrap();
|
||||
let size_bytes = json.len();
|
||||
let hostname_clone = hostname.clone();
|
||||
|
||||
group.throughput(Throughput::Bytes(size_bytes as u64));
|
||||
group.bench_with_input(BenchmarkId::new("serialize", name), &req, |b, req| {
|
||||
b.iter(|| {
|
||||
let json = to_string(black_box(req)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("deserialize", name),
|
||||
&json,
|
||||
|b, json_str| {
|
||||
b.iter(|| {
|
||||
let req: GenerateRequest = black_box(from_str(json_str)).unwrap();
|
||||
black_box(req);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("bootstrap_inject", name),
|
||||
&req,
|
||||
move |b, req| {
|
||||
let hostname = hostname_clone.clone();
|
||||
b.iter(|| {
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let json = to_value(&request_with_bootstrap).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Benchmark full round-trip: deserialize -> inject bootstrap -> serialize
|
||||
fn bench_full_round_trip(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("full_round_trip");
|
||||
|
||||
let generate_json = to_string(&create_sample_generate_request()).unwrap();
|
||||
let chat_json = to_string(&create_sample_chat_completion_request()).unwrap();
|
||||
let completion_json = to_string(&create_sample_completion_request()).unwrap();
|
||||
let worker = create_test_worker();
|
||||
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
|
||||
|
||||
group.bench_function("generate_openai_to_pd_pipeline", |b| {
|
||||
b.iter(|| {
|
||||
// Deserialize OpenAI request
|
||||
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
|
||||
// Create wrapper with bootstrap fields
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
// Serialize final request
|
||||
let pd_json = to_string(&request_with_bootstrap).unwrap();
|
||||
black_box(pd_json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("chat_completion_openai_to_pd_pipeline", |b| {
|
||||
b.iter(|| {
|
||||
let req: ChatCompletionRequest = from_str(black_box(&chat_json)).unwrap();
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let pd_json = to_string(&request_with_bootstrap).unwrap();
|
||||
black_box(pd_json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("completion_openai_to_pd_pipeline", |b| {
|
||||
b.iter(|| {
|
||||
let req: CompletionRequest = from_str(black_box(&completion_json)).unwrap();
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let pd_json = to_string(&request_with_bootstrap).unwrap();
|
||||
black_box(pd_json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("generate_direct_json_pipeline", |b| {
|
||||
b.iter(|| {
|
||||
// Deserialize OpenAI request
|
||||
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
|
||||
// Convert to JSON for direct routing (no bootstrap injection)
|
||||
let routing_json = to_value(&req).unwrap();
|
||||
let json_string = to_string(&routing_json).unwrap();
|
||||
black_box(json_string);
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn benchmark_summary(c: &mut Criterion) {
|
||||
let group = c.benchmark_group("benchmark_summary");
|
||||
|
||||
println!("\nSGLang Model Gateway Performance Benchmark Suite");
|
||||
println!("=================================================");
|
||||
|
||||
// Quick performance overview
|
||||
let generate_req = create_sample_generate_request();
|
||||
let worker = create_test_worker();
|
||||
|
||||
println!("\nQuick Performance Overview:");
|
||||
|
||||
// Measure serialization
|
||||
let start = Instant::now();
|
||||
for _ in 0..1000 {
|
||||
let _ = black_box(to_string(&generate_req).unwrap());
|
||||
}
|
||||
let serialize_time = start.elapsed().as_nanos() / 1000;
|
||||
println!(" * Serialization (avg): {:>8} ns/req", serialize_time);
|
||||
|
||||
// Measure deserialization
|
||||
let json = to_string(&generate_req).unwrap();
|
||||
let start = Instant::now();
|
||||
for _ in 0..1000 {
|
||||
let _: GenerateRequest = black_box(from_str(&json).unwrap());
|
||||
}
|
||||
let deserialize_time = start.elapsed().as_nanos() / 1000;
|
||||
println!(
|
||||
" * Deserialization (avg): {:>8} ns/req",
|
||||
deserialize_time
|
||||
);
|
||||
|
||||
// Measure bootstrap injection (replaces adaptation)
|
||||
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
|
||||
let start = Instant::now();
|
||||
for _ in 0..1000 {
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &generate_req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let _ = black_box(to_value(&request_with_bootstrap).unwrap());
|
||||
}
|
||||
let inject_time = start.elapsed().as_nanos() / 1000;
|
||||
println!(" * Bootstrap Injection (avg): {:>6} ns/req", inject_time);
|
||||
|
||||
// Calculate ratios
|
||||
let total_pipeline = serialize_time + deserialize_time + inject_time;
|
||||
println!(" * Total Pipeline (avg): {:>8} ns/req", total_pipeline);
|
||||
|
||||
println!("\nPerformance Insights:");
|
||||
if deserialize_time > serialize_time * 2 {
|
||||
println!(" • Deserialization is significantly faster than serialization");
|
||||
}
|
||||
if inject_time < serialize_time / 10 {
|
||||
println!(
|
||||
" • Bootstrap injection overhead is negligible ({:.1}% of serialization)",
|
||||
(inject_time as f64 / serialize_time as f64) * 100.0
|
||||
);
|
||||
}
|
||||
if total_pipeline < 100_000 {
|
||||
println!(" • Total pipeline latency is excellent (< 100μs)");
|
||||
}
|
||||
|
||||
println!("\nSimplification Benefits:");
|
||||
println!(" • Eliminated complex type conversion layer");
|
||||
println!(" • Reduced memory allocations");
|
||||
println!(" • Automatic field preservation (no manual mapping)");
|
||||
println!(" • Direct JSON manipulation improves performance");
|
||||
|
||||
println!("\nRecommendations:");
|
||||
if serialize_time > deserialize_time {
|
||||
println!(" • Focus optimization efforts on serialization rather than deserialization");
|
||||
}
|
||||
println!(" • PD mode overhead is minimal - safe to use for latency-sensitive workloads");
|
||||
println!(" • Consider batching small requests to improve overall throughput");
|
||||
|
||||
println!("\n{}", "=".repeat(50));
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
benchmark_summary,
|
||||
bench_json_serialization,
|
||||
bench_json_deserialization,
|
||||
bench_bootstrap_injection,
|
||||
bench_direct_json_routing,
|
||||
bench_throughput_by_size,
|
||||
bench_full_round_trip
|
||||
);
|
||||
criterion_main!(benches);
|
||||
59
third_party/sglang/sgl-model-gateway/benches/router_registry_bench.rs
vendored
Normal file
59
third_party/sglang/sgl-model-gateway/benches/router_registry_bench.rs
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use smg::core::{BasicWorkerBuilder, CircuitBreakerConfig, WorkerRegistry, WorkerType};
|
||||
|
||||
// Helper to populate registry
|
||||
fn setup_registry(count: usize) -> Arc<WorkerRegistry> {
|
||||
let registry = Arc::new(WorkerRegistry::new());
|
||||
|
||||
for i in 0..count {
|
||||
let mut labels = HashMap::new();
|
||||
labels.insert("model_id".to_string(), "benchmark-model".to_string());
|
||||
|
||||
let worker_type = if i % 2 == 0 {
|
||||
WorkerType::Regular
|
||||
} else {
|
||||
WorkerType::Decode
|
||||
};
|
||||
|
||||
let worker = BasicWorkerBuilder::new(format!("http://worker-{}:8000", i))
|
||||
.worker_type(worker_type)
|
||||
.labels(labels)
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.build();
|
||||
|
||||
registry.register(Arc::from(worker));
|
||||
}
|
||||
registry
|
||||
}
|
||||
|
||||
fn bench_optimizations(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("Registry Optimizations");
|
||||
|
||||
// We test with 5000 workers to simulate high load
|
||||
let size = 5000;
|
||||
let registry = setup_registry(size);
|
||||
|
||||
// The OLD method (Slow: Allocates vector + Clones ARCs)
|
||||
group.bench_function(BenchmarkId::new("Old: get_all()", size), |b| {
|
||||
b.iter(|| {
|
||||
black_box(registry.get_all());
|
||||
});
|
||||
});
|
||||
|
||||
// The NEW method (Fast: O(1) Lookup, Zero Allocation)
|
||||
group.bench_function(
|
||||
BenchmarkId::new("New: get_worker_distribution()", size),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
black_box(registry.get_worker_distribution());
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(benches, bench_optimizations);
|
||||
criterion_main!(benches);
|
||||
1099
third_party/sglang/sgl-model-gateway/benches/tree_benchmark.rs
vendored
Normal file
1099
third_party/sglang/sgl-model-gateway/benches/tree_benchmark.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
110
third_party/sglang/sgl-model-gateway/benches/wasm_middleware_latency.rs
vendored
Normal file
110
third_party/sglang/sgl-model-gateway/benches/wasm_middleware_latency.rs
vendored
Normal file
@@ -0,0 +1,110 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{HeaderMap, Request, Response, StatusCode},
|
||||
middleware,
|
||||
response::IntoResponse,
|
||||
};
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
use http_body_util::BodyExt;
|
||||
use smg::{
|
||||
app_context::AppContext, config::RouterConfig, middleware::wasm_middleware,
|
||||
protocols::chat::ChatCompletionRequest, routers::RouterTrait, server::AppState,
|
||||
};
|
||||
use tokio::runtime::Runtime;
|
||||
use tower::{Layer, Service};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MockRouter;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl RouterTrait for MockRouter {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
async fn route_chat(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &ChatCompletionRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response<Body> {
|
||||
StatusCode::OK.into_response()
|
||||
}
|
||||
fn router_type(&self) -> &'static str {
|
||||
"mock"
|
||||
}
|
||||
}
|
||||
|
||||
/// Mock service that simulates a streaming response with a 500ms delay.
|
||||
async fn mock_next_streaming(_req: Request<Body>) -> Response<Body> {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(16);
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Send first chunk immediately
|
||||
let _ = tx
|
||||
.send(Ok::<_, std::io::Error>(bytes::Bytes::from("chunk 1 ")))
|
||||
.await;
|
||||
// Simulate generation delay
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||||
// Send final chunk
|
||||
let _ = tx
|
||||
.send(Ok::<_, std::io::Error>(bytes::Bytes::from("chunk 2")))
|
||||
.await;
|
||||
});
|
||||
|
||||
Response::new(Body::from_stream(
|
||||
tokio_stream::wrappers::ReceiverStream::new(rx),
|
||||
))
|
||||
}
|
||||
|
||||
fn bench_wasm_middleware_buffering(c: &mut Criterion) {
|
||||
let rt = Runtime::new().unwrap();
|
||||
|
||||
// Setup AppContext with WASM enabled
|
||||
let config = RouterConfig::builder().enable_wasm(true).build_unchecked();
|
||||
|
||||
let context = rt.block_on(AppContext::from_config(config, 30)).unwrap();
|
||||
let app_state = Arc::new(AppState {
|
||||
router: Arc::new(MockRouter),
|
||||
context: Arc::new(context),
|
||||
concurrency_queue_tx: None,
|
||||
router_manager: None,
|
||||
mesh_handler: None,
|
||||
mesh_sync_manager: None,
|
||||
});
|
||||
|
||||
c.bench_function("wasm_middleware_pre_fix_latency", |b| {
|
||||
b.iter(|| {
|
||||
rt.block_on(async {
|
||||
let req = Request::builder()
|
||||
.uri("/v1/chat/completions")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
// Create the service by applying the middleware layer to the mock streamer
|
||||
let mut service =
|
||||
middleware::from_fn_with_state(app_state.clone(), wasm_middleware).layer(
|
||||
tower::service_fn(|req: Request<Body>| async move {
|
||||
Ok::<_, std::convert::Infallible>(mock_next_streaming(req).await)
|
||||
}),
|
||||
);
|
||||
|
||||
// Explicitly poll the service
|
||||
let response: Response<Body> =
|
||||
service.call(req).await.expect("Middleware service failed");
|
||||
|
||||
// Measure how long it takes to receive the FIRST frame
|
||||
let mut body = response.into_body();
|
||||
let _first_frame = body.frame().await;
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group! {
|
||||
name = benches;
|
||||
config = Criterion::default().sample_size(10);
|
||||
targets = bench_wasm_middleware_buffering
|
||||
}
|
||||
criterion_main!(benches);
|
||||
24
third_party/sglang/sgl-model-gateway/bindings/golang/.gitignore
vendored
Normal file
24
third_party/sglang/sgl-model-gateway/bindings/golang/.gitignore
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
# Build artifacts
|
||||
target/
|
||||
lib/
|
||||
|
||||
# Compiled binaries
|
||||
examples/simple/simple
|
||||
examples/streaming/streaming
|
||||
|
||||
# Go build artifacts
|
||||
*.o
|
||||
*.a
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# IDE and editor files
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Environment files
|
||||
.env
|
||||
.env.local
|
||||
48
third_party/sglang/sgl-model-gateway/bindings/golang/Cargo.toml
vendored
Normal file
48
third_party/sglang/sgl-model-gateway/bindings/golang/Cargo.toml
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
[package]
|
||||
name = "sgl-model-gateway-golang"
|
||||
version = "0.3.2"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "sgl_model_gateway_go"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.42.0", features = ["full"] }
|
||||
serde_json = { version = "1.0", default-features = false, features = [
|
||||
"std",
|
||||
"preserve_order",
|
||||
] }
|
||||
uuid = { version = "1.10", features = ["v4", "serde"] }
|
||||
once_cell = "1.21.3"
|
||||
futures-util = "0.3"
|
||||
tracing = "0.1"
|
||||
libc = "0.2.179"
|
||||
|
||||
[dependencies.sgl-model-gateway]
|
||||
path = "../.."
|
||||
default-features = true
|
||||
|
||||
[features]
|
||||
default = []
|
||||
vendored-openssl = ["sgl-model-gateway/vendored-openssl"]
|
||||
|
||||
[profile.release]
|
||||
opt-level = "z" # Optimize for size
|
||||
lto = "fat" # Full LTO for smaller binaries
|
||||
codegen-units = 1 # Better optimization, slower compile
|
||||
strip = true # Strip debug symbols
|
||||
|
||||
[profile.ci]
|
||||
inherits = "release"
|
||||
opt-level = 2 # Lighter optimization (still fast runtime, much faster compile)
|
||||
lto = "thin" # Thin LTO - good balance
|
||||
codegen-units = 16 # More parallelization for faster builds
|
||||
strip = true
|
||||
|
||||
[profile.dev]
|
||||
opt-level = 0
|
||||
debug = 1
|
||||
split-debuginfo = "unpacked"
|
||||
incremental = true
|
||||
codegen-units = 256
|
||||
103
third_party/sglang/sgl-model-gateway/bindings/golang/Makefile
vendored
Normal file
103
third_party/sglang/sgl-model-gateway/bindings/golang/Makefile
vendored
Normal file
@@ -0,0 +1,103 @@
|
||||
# Makefile for sgl-model-gateway golang bindings
|
||||
# This builds the Rust FFI library and provides convenience targets for Go development
|
||||
|
||||
# Configuration
|
||||
CARGO_BUILD_DIR ?= $(shell pwd)/target
|
||||
BUILD_MODE ?= release
|
||||
LIB_NAME = libsgl_model_gateway_go
|
||||
|
||||
# Detect OS
|
||||
UNAME_S := $(shell uname -s)
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
LIB_EXT = .so
|
||||
LD_LIBRARY_PATH_VAR = LD_LIBRARY_PATH
|
||||
endif
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
LIB_EXT = .dylib
|
||||
LD_LIBRARY_PATH_VAR = DYLD_LIBRARY_PATH
|
||||
endif
|
||||
|
||||
# Paths
|
||||
ROOT_DIR := $(shell pwd)
|
||||
RUST_SRC_DIR := $(ROOT_DIR)/src
|
||||
LIB_BUILD_DIR := $(CARGO_BUILD_DIR)/$(BUILD_MODE)
|
||||
LIB_BUILD_PATH := $(LIB_BUILD_DIR)/$(LIB_NAME)$(LIB_EXT)
|
||||
LIB_EXPORT_DIR := $(ROOT_DIR)/lib
|
||||
LIB_EXPORT_PATH := $(LIB_EXPORT_DIR)/$(LIB_NAME)$(LIB_EXT)
|
||||
|
||||
# Python LDFLAGS (needed for Rust FFI that depends on Python)
|
||||
PYTHON_LDFLAGS := $(shell python3-config --ldflags --embed 2>/dev/null || python3-config --ldflags 2>/dev/null || echo "")
|
||||
|
||||
# CGO flags - use exported lib directory if available, otherwise build directory
|
||||
LIB_DIR := $(if $(wildcard $(LIB_EXPORT_PATH)),$(LIB_EXPORT_DIR),$(LIB_BUILD_DIR))
|
||||
export CGO_LDFLAGS = -L$(LIB_DIR) -lsgl_model_gateway_go $(PYTHON_LDFLAGS) -ldl
|
||||
export $(LD_LIBRARY_PATH_VAR) := $(LIB_DIR):$($(LD_LIBRARY_PATH_VAR))
|
||||
|
||||
.PHONY: all build build-dev lib lib-clean clean test examples help run-simple run-streaming check-lib
|
||||
|
||||
help:
|
||||
@echo "Available targets:"
|
||||
@echo " build - Build release version of Rust FFI library"
|
||||
@echo " build-dev - Build debug version of Rust FFI library"
|
||||
@echo " lib - Copy built library to ./lib directory"
|
||||
@echo " lib-clean - Clean ./lib directory"
|
||||
@echo " clean - Clean build artifacts"
|
||||
@echo " test - Run Go tests"
|
||||
@echo " examples - Build example programs"
|
||||
@echo " run-simple - Run simple example"
|
||||
@echo " run-streaming - Run streaming example"
|
||||
|
||||
all: build
|
||||
|
||||
build:
|
||||
@echo "Building Rust FFI library (release mode)..."
|
||||
@CARGO_TARGET_DIR=$(CARGO_BUILD_DIR) cargo build --release --manifest-path Cargo.toml
|
||||
@echo "Library built at: $(LIB_BUILD_PATH)"
|
||||
|
||||
build-dev:
|
||||
@echo "Building Rust FFI library (debug mode)..."
|
||||
@CARGO_TARGET_DIR=$(CARGO_BUILD_DIR) cargo build --manifest-path Cargo.toml
|
||||
@echo "Library built at: $(LIB_BUILD_DIR)/debug/$(LIB_NAME)$(LIB_EXT)"
|
||||
|
||||
lib: build
|
||||
@echo "Copying library to ./lib directory..."
|
||||
@mkdir -p $(LIB_EXPORT_DIR)
|
||||
@cp $(LIB_BUILD_PATH) $(LIB_EXPORT_PATH)
|
||||
@echo "Library exported at: $(LIB_EXPORT_PATH)"
|
||||
|
||||
lib-clean:
|
||||
@echo "Cleaning ./lib directory..."
|
||||
@rm -rf $(LIB_EXPORT_DIR)
|
||||
@echo "Lib directory cleaned"
|
||||
|
||||
clean: lib-clean
|
||||
@echo "Cleaning build artifacts..."
|
||||
@CARGO_TARGET_DIR=$(CARGO_BUILD_DIR) cargo clean --manifest-path Cargo.toml
|
||||
@echo "Clean complete"
|
||||
|
||||
test: build
|
||||
@echo "Running Go tests..."
|
||||
@go test ./...
|
||||
|
||||
examples: build
|
||||
@echo "Building example programs..."
|
||||
@cd examples/simple && go build -o simple main.go
|
||||
@cd examples/streaming && go build -o streaming main.go
|
||||
@echo "Examples built"
|
||||
|
||||
run-simple: build
|
||||
@echo "Running simple example..."
|
||||
@cd examples/simple && bash run.sh
|
||||
|
||||
run-streaming: build
|
||||
@echo "Running streaming example..."
|
||||
@cd examples/streaming && bash run.sh
|
||||
|
||||
# Check if library exists (either in lib dir or build dir)
|
||||
check-lib:
|
||||
@if [ ! -f "$(LIB_EXPORT_PATH)" ] && [ ! -f "$(LIB_BUILD_PATH)" ]; then \
|
||||
echo "Error: Library not found at $(LIB_EXPORT_PATH) or $(LIB_BUILD_PATH)"; \
|
||||
echo "Run 'make build' or 'make lib' first"; \
|
||||
exit 1; \
|
||||
fi
|
||||
@echo "Library found at: $(LIB_DIR)/$(LIB_NAME)$(LIB_EXT)"
|
||||
482
third_party/sglang/sgl-model-gateway/bindings/golang/README.md
vendored
Normal file
482
third_party/sglang/sgl-model-gateway/bindings/golang/README.md
vendored
Normal file
@@ -0,0 +1,482 @@
|
||||
# SGLang Go gRPC SDK
|
||||
|
||||
A high-level Go SDK for interacting with SGLang gRPC API, designed with an OpenAI-style API for familiarity and ease of use.
|
||||
|
||||
**Location**: `sgl-model-gateway/bindings/golang/`
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Features](#features)
|
||||
- [Installation](#installation)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Basic Usage](#basic-usage)
|
||||
- [Streaming Usage](#streaming-usage)
|
||||
- [Examples](#examples)
|
||||
- [Configuration](#configuration)
|
||||
- [API Reference](#api-reference)
|
||||
- [Testing](#testing)
|
||||
- [Unit Tests](#unit-tests)
|
||||
- [Integration Tests](#integration-tests)
|
||||
- [Benchmarks](#benchmarks)
|
||||
- [Documentation](#documentation)
|
||||
- [Development](#development)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
- [License](#license)
|
||||
|
||||
## Features
|
||||
|
||||
- **OpenAI-style API**: Familiar interface similar to OpenAI Go SDK
|
||||
- **Streaming Support**: Real-time streaming chat completions
|
||||
- **Non-streaming Support**: Simple request/response API
|
||||
- **Tool Calling**: Support for function calling and tool use
|
||||
- **Type-safe**: Full Go type definitions for requests and responses
|
||||
- **Comprehensive Testing**: 18+ unit and integration tests
|
||||
- **Thread-safe**: All public methods are safe for concurrent use
|
||||
- **Well-documented**: Full API documentation with examples
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
go get github.com/sglang/sglang-go-grpc-sdk
|
||||
```
|
||||
|
||||
### Sync Dependencies
|
||||
|
||||
```bash
|
||||
cd sgl-model-gateway/bindings/golang
|
||||
go mod tidy
|
||||
```
|
||||
|
||||
### Build Requirements
|
||||
|
||||
- Go 1.21+, Rust toolchain, Python 3.x
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Benchmark
|
||||
|
||||
Run the OpenAI-compatible server and benchmark:
|
||||
|
||||
```bash
|
||||
# Set environment variables
|
||||
export SGL_TOKENIZER_PATH="/Users/yangyanbo/tokenizer"
|
||||
export SGL_GRPC_ENDPOINT="grpc://10.109.185.20:8001"
|
||||
|
||||
# Run server
|
||||
cd examples/oai_server
|
||||
bash run.sh
|
||||
|
||||
# Run E2E benchmark
|
||||
cd ../..
|
||||
make e2e E2E_MODEL=/work/models/qwencoder-3b E2E_TOKENIZER=/Users/yangyanbo/tokenizer E2E_INPUT_LEN=1024 E2E_OUTPUT_LEN=512
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
The SDK includes several examples in the `examples/` directory:
|
||||
|
||||
- **simple**: Basic non-streaming chat completion example
|
||||
- **streaming**: Real-time streaming with performance metrics
|
||||
|
||||
### Running Examples
|
||||
|
||||
```bash
|
||||
# Run simple example
|
||||
cd bindings/golang/examples/simple
|
||||
bash run.sh
|
||||
|
||||
# Run streaming example
|
||||
cd bindings/golang/examples/streaming
|
||||
bash run.sh
|
||||
|
||||
# Or use Makefile from bindings/golang directory
|
||||
cd bindings/golang
|
||||
make run-simple
|
||||
make run-streaming
|
||||
```
|
||||
|
||||
### Basic Usage (Non-streaming)
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sglang/sglang-go-grpc-sdk"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create client
|
||||
client, err := sglang.NewClient(sglang.ClientConfig{
|
||||
Endpoint: "grpc://localhost:20000",
|
||||
TokenizerPath: "/path/to/tokenizer",
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Create completion
|
||||
resp, err := client.CreateChatCompletion(context.Background(), sglang.ChatCompletionRequest{
|
||||
Model: "default",
|
||||
Messages: []sglang.ChatMessage{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
Stream: false,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println(resp.Choices[0].Message.Content)
|
||||
fmt.Printf("Usage: Prompt=%d, Completion=%d, Total=%d\n",
|
||||
resp.Usage.PromptTokens,
|
||||
resp.Usage.CompletionTokens,
|
||||
resp.Usage.TotalTokens)
|
||||
}
|
||||
```
|
||||
|
||||
### Streaming Usage
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
|
||||
"github.com/sglang/sglang-go-grpc-sdk"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create client
|
||||
client, err := sglang.NewClient(sglang.ClientConfig{
|
||||
Endpoint: "grpc://localhost:20000",
|
||||
TokenizerPath: "/path/to/tokenizer",
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Create streaming completion
|
||||
ctx := context.Background()
|
||||
stream, err := client.CreateChatCompletionStream(ctx, sglang.ChatCompletionRequest{
|
||||
Model: "default",
|
||||
Messages: []sglang.ChatMessage{
|
||||
{Role: "user", Content: "Tell me a story"},
|
||||
},
|
||||
Stream: true,
|
||||
MaxCompletionTokens: intPtr(500),
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
// Read streaming response
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for _, choice := range chunk.Choices {
|
||||
if choice.Delta.Content != "" {
|
||||
fmt.Print(choice.Delta.Content)
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println() // newline
|
||||
}
|
||||
|
||||
// Helper functions for optional pointer fields
|
||||
func intPtr(i int) *int {
|
||||
return &i
|
||||
}
|
||||
|
||||
func float32Ptr(f float32) *float32 {
|
||||
return &f
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
Examples automatically detect the server endpoint and tokenizer path via environment variables or defaults.
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
- `SGL_GRPC_ENDPOINT`: gRPC server endpoint (default: `grpc://localhost:20000`)
|
||||
- `SGL_TOKENIZER_PATH`: Path to tokenizer directory (required)
|
||||
- `CARGO_BUILD_DIR`: Rust build output directory (auto-detected if not set)
|
||||
|
||||
### ClientConfig
|
||||
|
||||
```go
|
||||
type ClientConfig struct {
|
||||
// Endpoint is the gRPC endpoint URL (e.g., "grpc://localhost:20000")
|
||||
// Required field. Must include the scheme (grpc://) and port number.
|
||||
Endpoint string
|
||||
|
||||
// TokenizerPath is the path to the tokenizer directory containing
|
||||
// tokenizer configuration files (e.g., tokenizer.json, vocab.json)
|
||||
// Required field.
|
||||
TokenizerPath string
|
||||
}
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Client Methods
|
||||
|
||||
```go
|
||||
type Client struct {
|
||||
// Thread-safe client for SGLang gRPC API
|
||||
}
|
||||
|
||||
// Creates a new client with the given configuration
|
||||
func NewClient(config ClientConfig) (*Client, error)
|
||||
|
||||
// Closes the client and releases all resources
|
||||
func (c *Client) Close() error
|
||||
|
||||
// Creates a non-streaming chat completion
|
||||
func (c *Client) CreateChatCompletion(ctx context.Context, req ChatCompletionRequest) (*ChatCompletionResponse, error)
|
||||
|
||||
// Creates a streaming chat completion
|
||||
func (c *Client) CreateChatCompletionStream(ctx context.Context, req ChatCompletionRequest) (*ChatCompletionStream, error)
|
||||
```
|
||||
|
||||
### Request Types
|
||||
|
||||
- `ChatCompletionRequest`: Main request type for chat completions
|
||||
- Model, Messages, Stream, Temperature, TopP, MaxCompletionTokens, Tools, etc.
|
||||
- `ChatMessage`: Individual message in a conversation
|
||||
- Role, Content
|
||||
- `Tool`: Tool/function definition for function calling
|
||||
- Type, Function (name, description, parameters)
|
||||
|
||||
### Response Types
|
||||
|
||||
- `ChatCompletionResponse`: Non-streaming response
|
||||
- ID, Model, Created, Choices, Usage
|
||||
- `ChatCompletionStreamResponse`: Streaming response chunk
|
||||
- Same structure as above but for incremental updates
|
||||
- `Message`: Complete message with content and tool calls
|
||||
- `ToolCall`: Tool call information with function and arguments
|
||||
- `Usage`: Token usage statistics
|
||||
- PromptTokens, CompletionTokens, TotalTokens
|
||||
|
||||
## Testing
|
||||
|
||||
The SDK includes comprehensive testing infrastructure with both unit and integration tests.
|
||||
|
||||
### Unit Tests
|
||||
|
||||
Unit tests are located in `client_test.go` and test individual components without requiring a server.
|
||||
|
||||
#### Running Unit Tests
|
||||
|
||||
```bash
|
||||
# Run all unit tests
|
||||
go test ./...
|
||||
|
||||
# Run with verbose output
|
||||
go test -v ./...
|
||||
|
||||
# Run specific test
|
||||
go test -run TestClientConfig
|
||||
|
||||
# Run tests with race detector (detects concurrency issues)
|
||||
go test -race ./...
|
||||
|
||||
# Run with coverage analysis
|
||||
go test -cover ./...
|
||||
|
||||
# Generate detailed coverage report
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -html=coverage.out -o coverage.html
|
||||
```
|
||||
|
||||
#### Unit Test Coverage
|
||||
|
||||
- Configuration validation, type structures, response handling, concurrent operations, and benchmarks
|
||||
- `client_test.go` - 10 unit tests covering core functionality
|
||||
|
||||
### Integration Tests
|
||||
|
||||
Integration tests require a running SGLang server and test the full client-server interaction.
|
||||
|
||||
#### Prerequisites
|
||||
|
||||
1. Start SGLang server: `python -m sglang.launch_server --model-path <model_path>`
|
||||
2. Set environment variables:
|
||||
```bash
|
||||
export SGL_GRPC_ENDPOINT=grpc://localhost:20000
|
||||
export SGL_TOKENIZER_PATH=/path/to/tokenizer
|
||||
```
|
||||
|
||||
#### Running Integration Tests
|
||||
|
||||
```bash
|
||||
# Run all integration tests
|
||||
go test -tags=integration ./...
|
||||
|
||||
# Run specific integration test
|
||||
go test -tags=integration -run TestIntegrationNonStreamingCompletion
|
||||
|
||||
# Run with verbose output
|
||||
go test -tags=integration -v ./...
|
||||
|
||||
# Run with race detector
|
||||
go test -tags=integration -race ./...
|
||||
```
|
||||
|
||||
#### Integration Test Coverage
|
||||
|
||||
**Test File**: `integration_test.go` - 4 integration tests
|
||||
|
||||
- `TestIntegrationNonStreamingCompletion` - Basic non-streaming request/response
|
||||
- `TestIntegrationStreamingCompletion` - Streaming response handling
|
||||
- `TestIntegrationConcurrentRequests` - Multiple simultaneous requests
|
||||
- `TestIntegrationContextCancellation` - Context timeout and cancellation
|
||||
|
||||
### Benchmarks
|
||||
|
||||
```bash
|
||||
go test -bench=. -benchmem ./...
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
All public types and functions include comprehensive documentation with usage examples.
|
||||
|
||||
### Key Documented Components
|
||||
|
||||
- `Client` - Main client with thread-safety notes
|
||||
- `ClientConfig` - Configuration requirements and validation rules
|
||||
- `ChatCompletionRequest` - Request structure with field descriptions
|
||||
- `ChatCompletionResponse` - Response structure and usage
|
||||
- `ChatCompletionStreamResponse` - Streaming response format
|
||||
- `Usage` - Token usage information structure
|
||||
- `Tool`, `Function`, `ToolCall` - Tool call structures
|
||||
|
||||
### Viewing Documentation
|
||||
|
||||
```bash
|
||||
godoc -http=:6060
|
||||
# Visit: http://localhost:6060/pkg/github.com/sglang/sglang-go-grpc-sdk/
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
```bash
|
||||
cd bindings/golang
|
||||
make build # Build Go bindings
|
||||
go vet ./... # Check code quality
|
||||
go fmt ./... # Format code
|
||||
go test -race ./... # Run tests
|
||||
```
|
||||
|
||||
### Project Structure
|
||||
|
||||
```
|
||||
bindings/golang/
|
||||
├── client.go # Main client implementation
|
||||
├── client_test.go # Unit tests
|
||||
├── integration_test.go # Integration tests
|
||||
├── README.md # This file
|
||||
├── Makefile # Build automation
|
||||
├── Cargo.toml # Rust FFI dependencies
|
||||
├── examples/ # Example programs
|
||||
│ ├── simple/ # Non-streaming example
|
||||
│ └── streaming/ # Streaming example
|
||||
├── src/ # Rust FFI source
|
||||
│ ├── client.rs # Client FFI
|
||||
│ ├── stream.rs # Stream handling
|
||||
│ ├── grpc_converter.rs # Response conversion
|
||||
│ └── ...
|
||||
└── internal/ # Internal packages
|
||||
└── ffi/ # FFI bindings
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Missing Dependencies
|
||||
|
||||
Run `go mod tidy` to sync dependencies.
|
||||
|
||||
### Connection Errors
|
||||
|
||||
Ensure SGLang server is running and check `SGL_GRPC_ENDPOINT`.
|
||||
|
||||
### Tokenizer Not Found
|
||||
|
||||
Set `SGL_TOKENIZER_PATH` environment variable.
|
||||
2. Verify path contains required files: `ls $SGL_TOKENIZER_PATH`
|
||||
3. Files should include: `tokenizer.json`, `vocab.json`, `config.json`
|
||||
|
||||
### Build Failures
|
||||
|
||||
**Error**: `library 'sgl_model_gateway_go' not found`
|
||||
|
||||
**Solution**:
|
||||
1. Rebuild Rust library: `cd sgl-model-gateway/bindings/golang && make build`
|
||||
2. Or manually with cargo: `cd sgl-model-gateway/bindings/golang && cargo build --release`
|
||||
3. Set `CARGO_BUILD_DIR` if using non-standard build location
|
||||
4. Ensure Rust toolchain is installed: `rustup toolchain list`
|
||||
|
||||
### Tests Hanging
|
||||
|
||||
**Error**: Tests seem to hang indefinitely
|
||||
|
||||
**Solution**:
|
||||
1. Use timeout for hanging tests: `timeout 30s go test ./...`
|
||||
2. Run with verbose output to see which test hangs: `go test -v ./...`
|
||||
3. Ensure server is responsive: `grpcurl -plaintext localhost:20000 list`
|
||||
|
||||
### Memory Issues
|
||||
|
||||
**Error**: Out of memory during tests
|
||||
|
||||
**Solution**:
|
||||
```bash
|
||||
# Run with memory limit for long-running tests
|
||||
GODEBUG=madvdontneed=1 go test -timeout 5m ./...
|
||||
|
||||
# Monitor memory during tests
|
||||
watch -n1 'ps aux | grep test'
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
When adding new features:
|
||||
|
||||
1. Add comprehensive documentation to public types/functions
|
||||
2. Include usage examples for complex APIs
|
||||
3. Add unit tests covering happy path and error cases
|
||||
4. Add integration tests if server interaction required
|
||||
5. Ensure code passes `go vet` and `go test -race`
|
||||
6. Update this README if adding new features
|
||||
|
||||
## License
|
||||
|
||||
See LICENSE file for details.
|
||||
|
||||
---
|
||||
|
||||
**Need Help?**
|
||||
- Check examples in `examples/` directory
|
||||
- Run tests to see working code: `go test -v ./...`
|
||||
- Review function documentation: `godoc` or inline comments
|
||||
- Check troubleshooting section above
|
||||
483
third_party/sglang/sgl-model-gateway/bindings/golang/client.go
vendored
Normal file
483
third_party/sglang/sgl-model-gateway/bindings/golang/client.go
vendored
Normal file
@@ -0,0 +1,483 @@
|
||||
// Package sglang provides a Go SDK for SGLang gRPC API.
|
||||
//
|
||||
// SGLang is a fast language model serving framework. This package provides a Go client
|
||||
// library for interacting with SGLang's gRPC API, following the style of OpenAI's Go SDK.
|
||||
//
|
||||
// Basic usage:
|
||||
//
|
||||
// client, err := sglang.NewClient(sglang.ClientConfig{
|
||||
// Endpoint: "grpc://localhost:20000",
|
||||
// TokenizerPath: "/path/to/tokenizer",
|
||||
// })
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// defer client.Close()
|
||||
//
|
||||
// resp, err := client.CreateChatCompletion(ctx, sglang.ChatCompletionRequest{
|
||||
// Model: "default",
|
||||
// Messages: []sglang.ChatMessage{
|
||||
// {Role: "user", Content: "Hello"},
|
||||
// },
|
||||
// })
|
||||
//
|
||||
// For streaming responses, use CreateChatCompletionStream instead.
|
||||
package sglang
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
grpcclient "github.com/sglang/sglang-go-grpc-sdk/internal/grpc"
|
||||
)
|
||||
|
||||
// Client is the main client for interacting with SGLang gRPC API.
|
||||
// It manages the connection to the SGLang server and handles both streaming
|
||||
// and non-streaming chat completions.
|
||||
//
|
||||
// Thread-safe: All public methods are safe for concurrent use.
|
||||
type Client struct {
|
||||
endpoint string
|
||||
tokenizerPath string
|
||||
grpcClient *grpcclient.GrpcClient // gRPC-based client
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// ClientConfig holds configuration for creating a new client.
|
||||
type ClientConfig struct {
|
||||
// Endpoint is the gRPC endpoint URL (e.g., "grpc://localhost:20000").
|
||||
// Required field. Must include the scheme (grpc://) and port number.
|
||||
Endpoint string
|
||||
|
||||
// TokenizerPath is the path to the tokenizer directory containing
|
||||
// tokenizer configuration files (e.g., tokenizer.json, vocab.json).
|
||||
// Required field.
|
||||
TokenizerPath string
|
||||
|
||||
// ChannelBufferSizes configures buffer sizes for internal channels.
|
||||
// If nil, default values will be used (optimized for high concurrency).
|
||||
ChannelBufferSizes *ChannelBufferSizes
|
||||
|
||||
// Timeouts configures timeout values for various operations.
|
||||
// If nil, default values will be used.
|
||||
Timeouts *Timeouts
|
||||
}
|
||||
|
||||
// ChannelBufferSizes configures buffer sizes for internal channels.
|
||||
// These affect concurrency and memory usage. Larger buffers allow more
|
||||
// concurrent operations but use more memory.
|
||||
type ChannelBufferSizes = grpcclient.ChannelBufferSizes
|
||||
|
||||
// Timeouts configures timeout values for various operations.
|
||||
type Timeouts = grpcclient.Timeouts
|
||||
|
||||
// defaultChannelBufferSizes returns default channel buffer sizes optimized for high concurrency (10k+).
|
||||
// These values are designed to handle thousands of concurrent requests without blocking.
|
||||
func defaultChannelBufferSizes() ChannelBufferSizes {
|
||||
return ChannelBufferSizes{
|
||||
ResultJSONChan: 10000, // Increased for high concurrency: each request may produce 200-500 chunks
|
||||
ErrChan: 100, // Errors are rare, 100 is sufficient
|
||||
RecvChan: 2000, // Increased for high concurrency: more gRPC responses to buffer
|
||||
}
|
||||
}
|
||||
|
||||
// defaultTimeouts returns default timeout values.
|
||||
func defaultTimeouts() Timeouts {
|
||||
return Timeouts{
|
||||
KeepaliveTime: 300 * time.Second, // Increased to reduce ping frequency and avoid "too many pings" errors
|
||||
KeepaliveTimeout: 20 * time.Second,
|
||||
CloseTimeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// NewClient creates a new SGLang client with the given configuration.
|
||||
//
|
||||
// The client maintains a long-lived connection to the SGLang server and should
|
||||
// be reused for multiple requests. Call Close() to release resources.
|
||||
//
|
||||
// Returns an error if:
|
||||
// - Endpoint is empty
|
||||
// - TokenizerPath is empty
|
||||
// - Connection to the server fails
|
||||
func NewClient(config ClientConfig) (*Client, error) {
|
||||
if config.Endpoint == "" {
|
||||
return nil, errors.New("endpoint is required")
|
||||
}
|
||||
if config.TokenizerPath == "" {
|
||||
return nil, errors.New("tokenizer path is required")
|
||||
}
|
||||
|
||||
bufferSizes := defaultChannelBufferSizes()
|
||||
if config.ChannelBufferSizes != nil {
|
||||
if config.ChannelBufferSizes.ResultJSONChan > 0 {
|
||||
bufferSizes.ResultJSONChan = config.ChannelBufferSizes.ResultJSONChan
|
||||
}
|
||||
if config.ChannelBufferSizes.ErrChan > 0 {
|
||||
bufferSizes.ErrChan = config.ChannelBufferSizes.ErrChan
|
||||
}
|
||||
if config.ChannelBufferSizes.RecvChan > 0 {
|
||||
bufferSizes.RecvChan = config.ChannelBufferSizes.RecvChan
|
||||
}
|
||||
}
|
||||
|
||||
timeouts := defaultTimeouts()
|
||||
if config.Timeouts != nil {
|
||||
if config.Timeouts.KeepaliveTime > 0 {
|
||||
timeouts.KeepaliveTime = config.Timeouts.KeepaliveTime
|
||||
}
|
||||
if config.Timeouts.KeepaliveTimeout > 0 {
|
||||
timeouts.KeepaliveTimeout = config.Timeouts.KeepaliveTimeout
|
||||
}
|
||||
if config.Timeouts.CloseTimeout > 0 {
|
||||
timeouts.CloseTimeout = config.Timeouts.CloseTimeout
|
||||
}
|
||||
}
|
||||
|
||||
grpcClient, err := grpcclient.NewGrpcClient(config.Endpoint, config.TokenizerPath, bufferSizes, timeouts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create gRPC client: %w", err)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
endpoint: config.Endpoint,
|
||||
tokenizerPath: config.TokenizerPath,
|
||||
grpcClient: grpcClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the client and releases all resources.
|
||||
//
|
||||
// After Close() is called, the client cannot be used for further requests.
|
||||
// Calling Close() multiple times is safe and idempotent.
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.grpcClient != nil {
|
||||
if err := c.grpcClient.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.grpcClient = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChatCompletionRequest represents a request for chat completion.
|
||||
// It follows the OpenAI API style for familiar usage.
|
||||
type ChatCompletionRequest struct {
|
||||
// Model specifies the model to use for completion (e.g., "default")
|
||||
Model string `json:"model"`
|
||||
// Messages is the list of messages in the conversation
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
Temperature *float32 `json:"temperature,omitempty"`
|
||||
TopP *float32 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice interface{} `json:"tool_choice,omitempty"`
|
||||
Stop interface{} `json:"stop,omitempty"`
|
||||
StopTokenIDs []int `json:"stop_token_ids,omitempty"`
|
||||
SkipSpecialTokens bool `json:"skip_special_tokens,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
Logprobs bool `json:"logprobs,omitempty"`
|
||||
TopLogprobs *int `json:"top_logprobs,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// ChatMessage represents a single message in a chat conversation
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content interface{} `json:"content"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// Tool represents a tool/function that can be called
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Function Function `json:"function"`
|
||||
}
|
||||
|
||||
// Function represents a function definition
|
||||
type Function struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
}
|
||||
|
||||
// ResponseFormat represents the response format
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// ChatCompletionResponse represents a non-streaming chat completion response
|
||||
type ChatCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
Choices []Choice `json:"choices"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// Choice represents a choice in the completion response
|
||||
type Choice struct {
|
||||
Index int `json:"index"`
|
||||
Message Message `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
// Message represents a message in the response
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ToolCall represents a tool call in the response
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function FunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
// FunctionCall represents a function call
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
// Usage represents token usage information
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// ChatCompletionStreamResponse represents a streaming chat completion response
|
||||
type ChatCompletionStreamResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
Choices []StreamChoice `json:"choices"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// StreamChoice represents a choice in a streaming response
|
||||
type StreamChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta MessageDelta `json:"delta"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
// MessageDelta represents incremental message updates
|
||||
type MessageDelta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// CreateChatCompletion creates a non-streaming chat completion with context support.
|
||||
//
|
||||
// Context Support:
|
||||
// The ctx parameter is fully supported for cancellation and timeouts:
|
||||
// - If ctx is cancelled, the request will be interrupted on the next stream.RecvJSON() call
|
||||
// - If ctx times out, the request will return context.DeadlineExceeded
|
||||
//
|
||||
// Example with timeout:
|
||||
//
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// defer cancel()
|
||||
// resp, err := client.CreateChatCompletion(ctx, req)
|
||||
//
|
||||
// Note: Internally, this creates a stream and collects all chunks,
|
||||
// so context monitoring happens at the chunk level.
|
||||
func (c *Client) CreateChatCompletion(ctx context.Context, req ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
||||
// For non-streaming, we'll collect all chunks and return the final response
|
||||
req.Stream = true // We still use streaming internally, but collect all chunks
|
||||
|
||||
if len(req.Tools) == 0 {
|
||||
req.Tools = nil
|
||||
}
|
||||
|
||||
stream, err := c.CreateChatCompletionStream(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
var fullContent strings.Builder
|
||||
var fullToolCalls []ToolCall
|
||||
var finishReason string
|
||||
var usage Usage
|
||||
var responseID string
|
||||
var created int64
|
||||
var model string
|
||||
var systemFingerprint string
|
||||
|
||||
for {
|
||||
chunkJSON, err := stream.RecvJSON()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var chunk ChatCompletionStreamResponse
|
||||
if err := json.Unmarshal([]byte(chunkJSON), &chunk); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse chunk: %w", err)
|
||||
}
|
||||
|
||||
if chunk.ID != "" {
|
||||
responseID = chunk.ID
|
||||
}
|
||||
if chunk.Created > 0 {
|
||||
created = chunk.Created
|
||||
}
|
||||
if chunk.Model != "" {
|
||||
model = chunk.Model
|
||||
}
|
||||
if chunk.SystemFingerprint != "" {
|
||||
systemFingerprint = chunk.SystemFingerprint
|
||||
}
|
||||
|
||||
for _, choice := range chunk.Choices {
|
||||
if choice.Delta.Content != "" {
|
||||
fullContent.WriteString(choice.Delta.Content)
|
||||
}
|
||||
if len(choice.Delta.ToolCalls) > 0 {
|
||||
fullToolCalls = append(fullToolCalls, choice.Delta.ToolCalls...)
|
||||
}
|
||||
if choice.FinishReason != "" {
|
||||
finishReason = choice.FinishReason
|
||||
}
|
||||
}
|
||||
|
||||
if chunk.Usage != nil {
|
||||
usage = *chunk.Usage
|
||||
}
|
||||
}
|
||||
|
||||
message := Message{
|
||||
Role: "assistant",
|
||||
Content: fullContent.String(),
|
||||
}
|
||||
if len(fullToolCalls) > 0 {
|
||||
message.ToolCalls = fullToolCalls
|
||||
}
|
||||
|
||||
if finishReason == "" {
|
||||
finishReason = "stop"
|
||||
}
|
||||
|
||||
return &ChatCompletionResponse{
|
||||
ID: responseID,
|
||||
Object: "chat.completion",
|
||||
Created: created,
|
||||
Model: model,
|
||||
SystemFingerprint: systemFingerprint,
|
||||
Choices: []Choice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: message,
|
||||
FinishReason: finishReason,
|
||||
},
|
||||
},
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ChatCompletionStream represents a streaming chat completion
|
||||
type ChatCompletionStream struct {
|
||||
grpcStream *grpcclient.GrpcChatCompletionStream
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (s *ChatCompletionStream) RecvJSON() (string, error) {
|
||||
return s.grpcStream.RecvJSON()
|
||||
}
|
||||
|
||||
// Close closes the stream and cancels any pending operations.
|
||||
func (s *ChatCompletionStream) Close() error {
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
if s.grpcStream != nil {
|
||||
return s.grpcStream.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateChatCompletionStream creates a streaming chat completion with context cancellation support.
|
||||
//
|
||||
// Context Support:
|
||||
// The ctx parameter is now fully supported for cancellation and timeouts:
|
||||
// - If ctx is cancelled, stream.RecvJSON() will return context.Canceled on the next call
|
||||
// - If ctx times out (WithTimeout), stream.RecvJSON() will return context.DeadlineExceeded
|
||||
// - Calling stream.Close() also cancels the context
|
||||
//
|
||||
// Example with timeout:
|
||||
//
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// defer cancel()
|
||||
// stream, err := client.CreateChatCompletionStream(ctx, req)
|
||||
// // Stream will auto-close if 30 seconds elapse
|
||||
//
|
||||
// Example with cancellation:
|
||||
//
|
||||
// ctx, cancel := context.WithCancel(context.Background())
|
||||
// stream, err := client.CreateChatCompletionStream(ctx, req)
|
||||
// go func() {
|
||||
// time.Sleep(5*time.Second)
|
||||
// cancel() // Cancel after 5 seconds
|
||||
// }()
|
||||
func (c *Client) CreateChatCompletionStream(ctx context.Context, req ChatCompletionRequest) (*ChatCompletionStream, error) {
|
||||
reqJSON, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
var reqMap map[string]interface{}
|
||||
if err := json.Unmarshal(reqJSON, &reqMap); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal request to map: %w", err)
|
||||
}
|
||||
|
||||
if _, exists := reqMap["tools"]; !exists {
|
||||
reqMap["tools"] = []interface{}{}
|
||||
}
|
||||
|
||||
reqJSON, err = json.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request map to JSON: %w", err)
|
||||
}
|
||||
|
||||
if c.grpcClient == nil {
|
||||
return nil, errors.New("gRPC client is closed")
|
||||
}
|
||||
|
||||
grpcStream, err := c.grpcClient.CreateChatCompletionStream(ctx, string(reqJSON))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create gRPC stream: %w", err)
|
||||
}
|
||||
|
||||
streamCtx, cancel := context.WithCancel(ctx)
|
||||
return &ChatCompletionStream{
|
||||
grpcStream: grpcStream,
|
||||
ctx: streamCtx,
|
||||
cancel: cancel,
|
||||
}, nil
|
||||
}
|
||||
325
third_party/sglang/sgl-model-gateway/bindings/golang/client_test.go
vendored
Normal file
325
third_party/sglang/sgl-model-gateway/bindings/golang/client_test.go
vendored
Normal file
@@ -0,0 +1,325 @@
|
||||
package sglang
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestClientConfig tests ClientConfig validation
|
||||
func TestClientConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config ClientConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: ClientConfig{
|
||||
Endpoint: "grpc://localhost:20000",
|
||||
TokenizerPath: "/path/to/tokenizer",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing endpoint",
|
||||
config: ClientConfig{
|
||||
Endpoint: "",
|
||||
TokenizerPath: "/path/to/tokenizer",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing tokenizer path",
|
||||
config: ClientConfig{
|
||||
Endpoint: "grpc://localhost:20000",
|
||||
TokenizerPath: "",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "both missing",
|
||||
config: ClientConfig{
|
||||
Endpoint: "",
|
||||
TokenizerPath: "",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewClient(tt.config)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatMessageTypes tests ChatMessage struct and its variants
|
||||
func TestChatMessageTypes(t *testing.T) {
|
||||
msg := ChatMessage{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
}
|
||||
|
||||
if msg.Role != "user" {
|
||||
t.Errorf("Expected role 'user', got '%s'", msg.Role)
|
||||
}
|
||||
if msg.Content != "Hello" {
|
||||
t.Errorf("Expected content 'Hello', got '%s'", msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatCompletionRequestValidation tests ChatCompletionRequest validation
|
||||
func TestChatCompletionRequestValidation(t *testing.T) {
|
||||
// Test valid request
|
||||
req := ChatCompletionRequest{
|
||||
Model: "default",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: "test"},
|
||||
},
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
t.Error("Expected model to be set")
|
||||
}
|
||||
|
||||
if len(req.Messages) == 0 {
|
||||
t.Error("Expected messages to be non-empty")
|
||||
}
|
||||
|
||||
if req.Messages[0].Role != "user" {
|
||||
t.Errorf("Expected first message role 'user', got '%s'", req.Messages[0].Role)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientClose tests that Close can be called multiple times safely
|
||||
func TestClientClose(t *testing.T) {
|
||||
// Create a mock client (note: in real tests, you might want to skip this
|
||||
// if it requires actual server connection)
|
||||
config := ClientConfig{
|
||||
Endpoint: "grpc://localhost:20000",
|
||||
TokenizerPath: "/path/to/tokenizer",
|
||||
}
|
||||
|
||||
// Skip if connection fails (expected in unit test environment)
|
||||
client, err := NewClient(config)
|
||||
if err != nil {
|
||||
t.Skip("Skipping client close test: server not available")
|
||||
}
|
||||
|
||||
// First close should succeed
|
||||
if err := client.Close(); err != nil {
|
||||
t.Errorf("First Close() failed: %v", err)
|
||||
}
|
||||
|
||||
// Second close should also succeed (idempotent)
|
||||
if err := client.Close(); err != nil {
|
||||
t.Errorf("Second Close() failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatCompletionResponseTypes tests response type structures
|
||||
func TestChatCompletionResponseTypes(t *testing.T) {
|
||||
resp := ChatCompletionResponse{
|
||||
ID: "test-id",
|
||||
Model: "default",
|
||||
Created: 1234567890,
|
||||
Choices: []Choice{
|
||||
{
|
||||
Message: Message{
|
||||
Role: "assistant",
|
||||
Content: "Hello",
|
||||
},
|
||||
FinishReason: "stop",
|
||||
},
|
||||
},
|
||||
Usage: Usage{
|
||||
PromptTokens: 10,
|
||||
CompletionTokens: 20,
|
||||
TotalTokens: 30,
|
||||
},
|
||||
}
|
||||
|
||||
if resp.ID != "test-id" {
|
||||
t.Errorf("Expected ID 'test-id', got '%s'", resp.ID)
|
||||
}
|
||||
|
||||
if len(resp.Choices) != 1 {
|
||||
t.Errorf("Expected 1 choice, got %d", len(resp.Choices))
|
||||
}
|
||||
|
||||
if resp.Choices[0].Message.Content != "Hello" {
|
||||
t.Errorf("Expected content 'Hello', got '%s'", resp.Choices[0].Message.Content)
|
||||
}
|
||||
|
||||
if resp.Usage.TotalTokens != 30 {
|
||||
t.Errorf("Expected total tokens 30, got %d", resp.Usage.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamingResponseTypes tests streaming response structures
|
||||
func TestStreamingResponseTypes(t *testing.T) {
|
||||
chunk := ChatCompletionStreamResponse{
|
||||
ID: "stream-id",
|
||||
Created: 1234567890,
|
||||
Choices: []StreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: MessageDelta{
|
||||
Content: "Hello",
|
||||
},
|
||||
FinishReason: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if chunk.ID != "stream-id" {
|
||||
t.Errorf("Expected ID 'stream-id', got '%s'", chunk.ID)
|
||||
}
|
||||
|
||||
if len(chunk.Choices) == 0 {
|
||||
t.Error("Expected at least one choice")
|
||||
}
|
||||
|
||||
if chunk.Choices[0].Delta.Content != "Hello" {
|
||||
t.Errorf("Expected delta content 'Hello', got '%s'", chunk.Choices[0].Delta.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolCallStructure tests Tool and ToolCall structures
|
||||
func TestToolCallStructure(t *testing.T) {
|
||||
tool := Tool{
|
||||
Type: "function",
|
||||
Function: Function{
|
||||
Name: "get_weather",
|
||||
Description: "Get the weather",
|
||||
Parameters: map[string]interface{}{
|
||||
"location": "string",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if tool.Type != "function" {
|
||||
t.Errorf("Expected tool type 'function', got '%s'", tool.Type)
|
||||
}
|
||||
|
||||
if tool.Function.Name != "get_weather" {
|
||||
t.Errorf("Expected function name 'get_weather', got '%s'", tool.Function.Name)
|
||||
}
|
||||
|
||||
toolCall := ToolCall{
|
||||
ID: "call-123",
|
||||
Type: "function",
|
||||
Function: FunctionCall{
|
||||
Name: "get_weather",
|
||||
Arguments: `{"location": "San Francisco"}`,
|
||||
},
|
||||
}
|
||||
|
||||
if toolCall.ID != "call-123" {
|
||||
t.Errorf("Expected tool call ID 'call-123', got '%s'", toolCall.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentClientOperations tests thread safety
|
||||
// This is a basic test that just verifies concurrent calls don't panic
|
||||
func TestConcurrentClientOperations(t *testing.T) {
|
||||
config := ClientConfig{
|
||||
Endpoint: "grpc://localhost:20000",
|
||||
TokenizerPath: "/path/to/tokenizer",
|
||||
}
|
||||
|
||||
client, err := NewClient(config)
|
||||
if err != nil {
|
||||
t.Skip("Skipping concurrent operations test: server not available")
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Try concurrent Close calls (should not panic or race)
|
||||
done := make(chan bool, 2)
|
||||
|
||||
go func() {
|
||||
client.Close()
|
||||
done <- true
|
||||
}()
|
||||
|
||||
go func() {
|
||||
client.Close()
|
||||
done <- true
|
||||
}()
|
||||
|
||||
<-done
|
||||
<-done
|
||||
}
|
||||
|
||||
// BenchmarkChatCompletionRequest benchmarks request creation
|
||||
func BenchmarkChatCompletionRequest(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = ChatCompletionRequest{
|
||||
Model: "default",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: "test message"},
|
||||
},
|
||||
Stream: false,
|
||||
Temperature: floatPtr(0.7),
|
||||
MaxCompletionTokens: intPtr(100),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for benchmarks
|
||||
func floatPtr(f float32) *float32 {
|
||||
return &f
|
||||
}
|
||||
|
||||
func intPtr(i int) *int {
|
||||
return &i
|
||||
}
|
||||
|
||||
// TestContextCancellation tests that cancelled context is handled gracefully.
|
||||
//
|
||||
// NOTE: Currently, the FFI layer is blocking and doesn't actively monitor context cancellation.
|
||||
// This test verifies that the client at least returns an error rather than panicking or
|
||||
// hanging indefinitely when a pre-cancelled context is passed.
|
||||
//
|
||||
// Future: When FFI supports context cancellation (via signals or async operations),
|
||||
// this test should be updated to assert that the error is context.Canceled or wrapped
|
||||
// context cancellation error.
|
||||
func TestContextCancellation(t *testing.T) {
|
||||
config := ClientConfig{
|
||||
Endpoint: "grpc://localhost:20000",
|
||||
TokenizerPath: "/path/to/tokenizer",
|
||||
}
|
||||
|
||||
client, err := NewClient(config)
|
||||
if err != nil {
|
||||
t.Skip("Skipping context cancellation test: server not available")
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Create a pre-cancelled context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
req := ChatCompletionRequest{
|
||||
Model: "default",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: "test"},
|
||||
},
|
||||
}
|
||||
|
||||
// Attempt request with cancelled context
|
||||
// Since FFI is blocking, we expect either:
|
||||
// 1. An error from the server/network
|
||||
// 2. The call to complete normally (FFI doesn't check context)
|
||||
// What we DON'T expect is a panic or indefinite hang
|
||||
_, err = client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
t.Logf("Request with cancelled context returned error: %v", err)
|
||||
} else {
|
||||
t.Logf("Request with cancelled context completed (FFI may not support context cancellation)")
|
||||
}
|
||||
}
|
||||
239
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/Makefile
vendored
Normal file
239
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/Makefile
vendored
Normal file
@@ -0,0 +1,239 @@
|
||||
# Makefile for OAI Server
|
||||
# Builds binary, runs tests, and provides basic targets
|
||||
|
||||
# Configuration
|
||||
APP_NAME = oai_server
|
||||
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||
BUILD_TIME := $(shell date -u '+%Y-%m-%d_%H:%M:%S')
|
||||
GIT_COMMIT := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||
|
||||
# Paths
|
||||
ROOT_DIR := $(shell pwd)
|
||||
BINDINGS_DIR := $(shell cd $(ROOT_DIR)/../.. && pwd)
|
||||
BUILD_DIR := $(ROOT_DIR)/build
|
||||
BINARY := $(BUILD_DIR)/$(APP_NAME)
|
||||
|
||||
# Rust FFI library paths
|
||||
LIB_DIR := $(BINDINGS_DIR)/lib
|
||||
LIB_NAME = libsgl_model_gateway_go
|
||||
|
||||
# Detect OS
|
||||
UNAME_S := $(shell uname -s)
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
LIB_EXT = .so
|
||||
LD_LIBRARY_PATH_VAR = LD_LIBRARY_PATH
|
||||
ARCH := $(shell uname -m)
|
||||
ifeq ($(ARCH),x86_64)
|
||||
GOARCH = amd64
|
||||
else ifeq ($(ARCH),aarch64)
|
||||
GOARCH = arm64
|
||||
endif
|
||||
endif
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
LIB_EXT = .dylib
|
||||
LD_LIBRARY_PATH_VAR = DYLD_LIBRARY_PATH
|
||||
ARCH := $(shell uname -m)
|
||||
ifeq ($(ARCH),x86_64)
|
||||
GOARCH = amd64
|
||||
else ifeq ($(ARCH),arm64)
|
||||
GOARCH = arm64
|
||||
endif
|
||||
endif
|
||||
|
||||
# Build flags
|
||||
LDFLAGS = -X main.Version=$(VERSION) -X main.BuildTime=$(BUILD_TIME) -X main.GitCommit=$(GIT_COMMIT)
|
||||
GO_BUILD_FLAGS = -ldflags "$(LDFLAGS)"
|
||||
|
||||
# Python LDFLAGS (needed for Rust FFI that depends on Python)
|
||||
PYTHON_LDFLAGS := $(shell python3-config --ldflags --embed 2>/dev/null || python3-config --ldflags 2>/dev/null || python-config --ldflags --embed 2>/dev/null || python-config --ldflags 2>/dev/null || echo "")
|
||||
|
||||
# CGO flags
|
||||
CGO_LDFLAGS = -L$(LIB_DIR) $(PYTHON_LDFLAGS)
|
||||
|
||||
.PHONY: all build build-dev test e2e clean help lib run stream check-rust-lib check-server
|
||||
|
||||
# E2E test configuration
|
||||
E2E_HOST ?= localhost
|
||||
E2E_PORT ?= 8080
|
||||
E2E_MODEL ?= default
|
||||
E2E_TOKENIZER ?= $(shell echo $$SGL_TOKENIZER_PATH || echo "./examples/tokenizer")
|
||||
E2E_NUM_PROMPTS ?= 100
|
||||
E2E_INPUT_LEN ?= 1024
|
||||
E2E_OUTPUT_LEN ?= 512
|
||||
E2E_REQUEST_RATE ?= 20
|
||||
E2E_MAX_CONCURRENCY ?= 20
|
||||
E2E_BASE_URL ?= http://$(E2E_HOST):$(E2E_PORT)
|
||||
|
||||
help:
|
||||
@echo "OAI Server Makefile"
|
||||
@echo ""
|
||||
@echo "Available targets:"
|
||||
@echo " lib - Build Rust FFI library"
|
||||
@echo " build - Build binary (release mode)"
|
||||
@echo " build-dev - Build binary (debug mode)"
|
||||
@echo " test - Run tests"
|
||||
@echo " e2e - Run end-to-end test with bench_serving.py"
|
||||
@echo " run - Run the server (development)"
|
||||
@echo " stream - Run streaming example"
|
||||
@echo " clean - Clean build artifacts"
|
||||
@echo ""
|
||||
@echo "E2E test variables:"
|
||||
@echo " E2E_HOST - OAI Server host (default: localhost)"
|
||||
@echo " E2E_PORT - OAI Server port (default: 8080)"
|
||||
@echo " E2E_MODEL - Model name (default: default)"
|
||||
@echo " E2E_TOKENIZER - Tokenizer path"
|
||||
@echo " E2E_NUM_PROMPTS - Number of prompts (default: 100)"
|
||||
@echo " E2E_INPUT_LEN - Input token length (default: 1024)"
|
||||
@echo " E2E_OUTPUT_LEN - Output token length (default: 512)"
|
||||
@echo " E2E_REQUEST_RATE - Request rate per second (default: 20)"
|
||||
@echo " E2E_MAX_CONCURRENCY - Max concurrent requests (default: 20)"
|
||||
|
||||
all: build
|
||||
|
||||
# Build Rust FFI library
|
||||
lib:
|
||||
@echo "Building Rust FFI library..."
|
||||
@cd $(BINDINGS_DIR) && $(MAKE) lib
|
||||
@echo "✓ Rust FFI library built"
|
||||
|
||||
# Check if Rust FFI library exists
|
||||
check-rust-lib:
|
||||
@if [ ! -f "$(LIB_DIR)/$(LIB_NAME)$(LIB_EXT)" ]; then \
|
||||
echo "Error: Rust FFI library not found at $(LIB_DIR)/$(LIB_NAME)$(LIB_EXT)"; \
|
||||
echo "Building Rust library..."; \
|
||||
cd $(BINDINGS_DIR) && $(MAKE) lib; \
|
||||
fi
|
||||
@echo "✓ Rust FFI library found"
|
||||
|
||||
# Build binary (release)
|
||||
build: check-rust-lib
|
||||
@echo "Building $(APP_NAME) (release mode)..."
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
@CGO_ENABLED=1 \
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" \
|
||||
GOOS=$(shell go env GOOS) \
|
||||
GOARCH=$(GOARCH) \
|
||||
go build $(GO_BUILD_FLAGS) -o $(BINARY) .
|
||||
@echo "✓ Binary built: $(BINARY)"
|
||||
|
||||
# Build binary (debug)
|
||||
build-dev: check-rust-lib
|
||||
@echo "Building $(APP_NAME) (debug mode)..."
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
@CGO_ENABLED=1 \
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" \
|
||||
go build -o $(BINARY) .
|
||||
@echo "✓ Binary built (debug): $(BINARY)"
|
||||
|
||||
# Run tests
|
||||
test: check-rust-lib
|
||||
@echo "Running tests..."
|
||||
@CGO_ENABLED=1 \
|
||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" \
|
||||
export $(LD_LIBRARY_PATH_VAR)="$(LIB_DIR):$$$(LD_LIBRARY_PATH_VAR)" && \
|
||||
go test -v ./...
|
||||
@echo "✓ Tests completed"
|
||||
|
||||
# Check if OAI Server is running
|
||||
check-server:
|
||||
@echo "Checking if OAI Server is running at $(E2E_BASE_URL)..."
|
||||
@if curl -s -f $(E2E_BASE_URL)/health > /dev/null 2>&1; then \
|
||||
echo "✓ OAI Server is running"; \
|
||||
exit 0; \
|
||||
else \
|
||||
echo "✗ OAI Server is not running at $(E2E_BASE_URL)"; \
|
||||
echo " Start it with: make run"; \
|
||||
exit 1; \
|
||||
fi
|
||||
|
||||
# Find sglang project root (4 levels up from oai_server)
|
||||
SGLANG_ROOT := $(shell cd $(ROOT_DIR)/../../../../.. && pwd)
|
||||
|
||||
# Run end-to-end test with bench_serving.py
|
||||
e2e: check-server
|
||||
@echo "Checking if bench_serving.py is available..."
|
||||
@if python -m sglang.bench_serving --help > /dev/null 2>&1; then \
|
||||
echo "✓ Using installed bench_serving.py module"; \
|
||||
USE_SGLANG_ROOT=false; \
|
||||
elif [ -f "$(SGLANG_ROOT)/python/sglang/bench_serving.py" ]; then \
|
||||
echo "✓ Using bench_serving.py from $(SGLANG_ROOT)"; \
|
||||
USE_SGLANG_ROOT=true; \
|
||||
else \
|
||||
echo "✗ bench_serving.py is not available"; \
|
||||
echo " Install dependencies: pip install aiohttp numpy datasets transformers tqdm pillow pybase64"; \
|
||||
exit 1; \
|
||||
fi
|
||||
@echo "Running end-to-end test with bench_serving.py..."
|
||||
@echo "Configuration:"
|
||||
@echo " Server: $(E2E_BASE_URL)"
|
||||
@if [ "$(E2E_MODEL)" != "default" ]; then \
|
||||
echo " Model: $(E2E_MODEL)"; \
|
||||
fi
|
||||
@if [ -n "$(E2E_TOKENIZER)" ]; then \
|
||||
echo " Tokenizer: $(E2E_TOKENIZER)"; \
|
||||
fi
|
||||
@echo " Prompts: $(E2E_NUM_PROMPTS)"
|
||||
@echo " Input/Output: $(E2E_INPUT_LEN)/$(E2E_OUTPUT_LEN) tokens"
|
||||
@echo " Request rate: $(E2E_REQUEST_RATE) req/s"
|
||||
@echo " Max concurrency: $(E2E_MAX_CONCURRENCY)"
|
||||
@echo ""
|
||||
@TOKENIZER_ABS=$$(cd $(ROOT_DIR) && python3 -c "import os; path='$(E2E_TOKENIZER)'; print(os.path.abspath(path) if not os.path.isabs(path) else path)" 2>/dev/null || echo "$(E2E_TOKENIZER)"); \
|
||||
if [ -n "$(E2E_TOKENIZER)" ]; then \
|
||||
if [ -n "$$TOKENIZER_ABS" ] && ([ -d "$$TOKENIZER_ABS" ] || [ -f "$$TOKENIZER_ABS" ]); then \
|
||||
TOKENIZER_ARG="--tokenizer $$TOKENIZER_ABS"; \
|
||||
else \
|
||||
TOKENIZER_ARG="--tokenizer $(E2E_TOKENIZER)"; \
|
||||
fi; \
|
||||
else \
|
||||
TOKENIZER_ARG=""; \
|
||||
fi; \
|
||||
if [ "$$USE_SGLANG_ROOT" = "true" ]; then \
|
||||
cd $(SGLANG_ROOT) && PYTHONPATH=$(SGLANG_ROOT)/python:$$PYTHONPATH python python/sglang/bench_serving.py \
|
||||
--backend sglang-oai-chat \
|
||||
--base-url $(E2E_BASE_URL) \
|
||||
$$([ "$(E2E_MODEL)" != "default" ] && echo "--model $(E2E_MODEL)") \
|
||||
$$TOKENIZER_ARG \
|
||||
--dataset-name random \
|
||||
--num-prompts $(E2E_NUM_PROMPTS) \
|
||||
--random-input-len $(E2E_INPUT_LEN) \
|
||||
--random-output-len $(E2E_OUTPUT_LEN) \
|
||||
--request-rate $(E2E_REQUEST_RATE) \
|
||||
--max-concurrency $(E2E_MAX_CONCURRENCY) \
|
||||
--warmup-requests 5 \
|
||||
--disable-tqdm || (echo "✗ E2E test failed"; exit 1); \
|
||||
else \
|
||||
python -m sglang.bench_serving \
|
||||
--backend sglang-oai-chat \
|
||||
--base-url $(E2E_BASE_URL) \
|
||||
$$([ "$(E2E_MODEL)" != "default" ] && echo "--model $(E2E_MODEL)") \
|
||||
$$TOKENIZER_ARG \
|
||||
--dataset-name random \
|
||||
--num-prompts $(E2E_NUM_PROMPTS) \
|
||||
--random-input-len $(E2E_INPUT_LEN) \
|
||||
--random-output-len $(E2E_OUTPUT_LEN) \
|
||||
--request-rate $(E2E_REQUEST_RATE) \
|
||||
--max-concurrency $(E2E_MAX_CONCURRENCY) \
|
||||
--warmup-requests 5 \
|
||||
--disable-tqdm || (echo "✗ E2E test failed"; exit 1); \
|
||||
fi
|
||||
@echo ""
|
||||
@echo "✓ E2E test completed"
|
||||
|
||||
# Run the server (development)
|
||||
run: build-dev
|
||||
@echo "Running server..."
|
||||
@export $(LD_LIBRARY_PATH_VAR)="$(LIB_DIR):$$$(LD_LIBRARY_PATH_VAR)" && \
|
||||
$(BINARY)
|
||||
|
||||
# Run streaming example
|
||||
stream: check-rust-lib
|
||||
@echo "Running streaming example..."
|
||||
@cd $(BINDINGS_DIR)/examples/streaming && \
|
||||
export $(LD_LIBRARY_PATH_VAR)="$(LIB_DIR):$$$(LD_LIBRARY_PATH_VAR)" && \
|
||||
bash run.sh
|
||||
|
||||
# Clean build artifacts
|
||||
clean:
|
||||
@echo "Cleaning build artifacts..."
|
||||
@rm -rf $(BUILD_DIR)
|
||||
@echo "✓ Clean complete"
|
||||
305
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/README.md
vendored
Normal file
305
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/README.md
vendored
Normal file
@@ -0,0 +1,305 @@
|
||||
# Go SGLang Router - OpenAI Compatible API Server
|
||||
|
||||
Go SGLang Router is a high-performance OpenAI-compatible API server that communicates with the SGLang backend via gRPC and performs efficient preprocessing and postprocessing through Rust FFI.
|
||||
|
||||
## Features
|
||||
|
||||
- ✅ **OpenAI API Compatible**: Fully compatible with OpenAI Chat Completions API
|
||||
- ✅ **High Performance**: Low latency and high throughput using gRPC and Rust FFI
|
||||
- ✅ **Streaming Support**: Server-Sent Events (SSE) streaming responses
|
||||
- ✅ **Thread-Safe**: Pre-created tokenizer handle, lock-free concurrency
|
||||
- ✅ **Graceful Shutdown**: Context cancellation mechanism to avoid resource leaks and panics
|
||||
- ✅ **Configurable**: Supports configuring channel buffer sizes and timeout durations
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
**Important Note**: gRPC mode **still calls FFI**, which is used for:
|
||||
- **Preprocessing**: chat_template and tokenization (request phase)
|
||||
- **Postprocessing**: token decoding and tool parsing (response phase)
|
||||
|
||||
gRPC is only used for communication with the SGLang backend, while input/output processing completely relies on Rust FFI.
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ HTTP Client │
|
||||
│ (OpenAI API Format) │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ FastHTTP Server │
|
||||
│ handlers/chat.go:HandleChatCompletion │
|
||||
│ - Parse request JSON │
|
||||
│ - SetBodyStreamWriter (SSE) │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ SGLang Client (client.go) │
|
||||
│ CreateChatCompletionStream(ctx, req) │
|
||||
│ - Wraps gRPC client │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ gRPC Client (internal/grpc/client_grpc.go) │
|
||||
│ CreateChatCompletionStream(ctx, reqJSON) │
|
||||
│ │
|
||||
│ ┌──────────────────────────────────────────────────────────┐ │
|
||||
│ │ Step 1: FFI Preprocess (Rust FFI) │ │
|
||||
│ │ - ffi.PreprocessChatRequestWithTokenizer() │ │
|
||||
│ │ - chat_template application │ │
|
||||
│ │ - tokenization │ │
|
||||
│ │ - tool constraints generation │ │
|
||||
│ │ Returns: PromptText, TokenIDs, ToolConstraintsJSON, │ │
|
||||
│ │ PromptTokens │ │
|
||||
│ └────────────────────┬─────────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌──────────────────────────────────────────────────────────┐ │
|
||||
│ │ Step 2: Build gRPC Request │ │
|
||||
│ │ - Parse request JSON (model, temperature, etc.) │ │
|
||||
│ │ - Create proto.GenerateRequest │ │
|
||||
│ │ - Set TokenizedInput (PromptText, TokenIDs) │ │
|
||||
│ │ - Set SamplingParams (temperature, top_p, top_k, etc.) │ │
|
||||
│ │ - Set Constraints (from ToolConstraintsJSON) │ │
|
||||
│ └────────────────────┬─────────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌──────────────────────────────────────────────────────────┐ │
|
||||
│ │ Step 3: Create gRPC Stream │ │
|
||||
│ │ - client.Generate(generateReq) → gRPC stream │ │
|
||||
│ │ - Connects to SGLang Backend (Rust) │ │
|
||||
│ └────────────────────┬─────────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌──────────────────────────────────────────────────────────┐ │
|
||||
│ │ Step 4: Create Converter & BatchPostprocessor │ │
|
||||
│ │ - ffi.CreateGrpcResponseConverterWithTokenizer() │ │
|
||||
│ │ - Uses preprocessed.PromptTokens for initial count │ │
|
||||
│ │ - ffi.NewBatchPostprocessor(batchSize=1, immediate) │ │
|
||||
│ └────────────────────┬─────────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌──────────────────────────────────────────────────────────┐ │
|
||||
│ │ Step 5: Start readLoop (Background Goroutine) │ │
|
||||
│ │ - go grpcStream.readLoop() │ │
|
||||
│ │ - Returns GrpcChatCompletionStream immediately │ │
|
||||
│ └────────────────────┬─────────────────────────────────────┘ │
|
||||
└───────────────────────┼────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ GrpcChatCompletionStream.readLoop() │
|
||||
│ (Background Goroutine) │
|
||||
│ │
|
||||
│ ┌──────────────────────────────────────────────────────────┐ │
|
||||
│ │ Recv() Goroutine (Dedicated) │ │
|
||||
│ │ - Continuously calls stream.Recv() │ │
|
||||
│ │ - Sends results to recvChan (buffered, 2000) │ │
|
||||
│ │ - Exits on ctx.Done() or error │ │
|
||||
│ │ - Calls stream.CloseSend() on ctx.Done() │ │
|
||||
│ └────────────────────┬─────────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌──────────────────────────────────────────────────────────┐ │
|
||||
│ │ Main Loop │ │
|
||||
│ │ - Reads from recvChan │ │
|
||||
│ │ - For each proto.GenerateResponse: │ │
|
||||
│ │ → go processAndSendResponse() (async) │ │
|
||||
│ │ - protoToJSON() converts proto to JSON string │ │
|
||||
│ │ - batchPostprocessor.AddChunk(protoJSON) │ │
|
||||
│ │ → FFI postprocessing (token decoding, tool parsing)│ │
|
||||
│ │ → Returns OpenAI-format JSON strings │ │
|
||||
│ │ - Sends JSON to resultJSONChan (buffered, 10000) │ │
|
||||
│ │ - All operations check ctx.Done() for cancellation │ │
|
||||
│ │ - On EOF: flush batch, send remaining results, return │ │
|
||||
│ │ - On error: send to errChan (buffered, 100) │ │
|
||||
│ │ - defer: cancel ctx, wait goroutines, close channels │ │
|
||||
│ └────────────────────┬─────────────────────────────────────┘ │
|
||||
└───────────────────────┼────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ resultJSONChan (Buffered Channel, 10000) │
|
||||
│ - Contains OpenAI-format JSON strings │
|
||||
│ - Ready for consumption │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ ChatCompletionStream.RecvJSON() │
|
||||
│ (client.go:410) │
|
||||
│ - Direct wrapper: return grpcStream.RecvJSON() │
|
||||
│ - No intermediate processing │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ FastHTTP SetBodyStreamWriter │
|
||||
│ (handlers/chat.go:159) │
|
||||
│ - Loop: stream.RecvJSON() → format SSE → flush │
|
||||
│ - Format: "data: {json}\n\n" │
|
||||
│ - Final: "data: [DONE]\n\n" │
|
||||
│ - Immediate flush after each chunk │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ HTTP Client │
|
||||
│ (SSE Stream) │
|
||||
│ Receives: data: {...}\n\n │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Start Server
|
||||
|
||||
```bash
|
||||
./run.sh
|
||||
```
|
||||
|
||||
The server will start on port `:8080`.
|
||||
|
||||
### Usage Example
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "/path/to/model",
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
## Key Design
|
||||
|
||||
### 1. Thread-Safe Tokenizer
|
||||
- Pre-create `TokenizerHandle` at startup
|
||||
- Rust side uses `Arc<dyn TokenizerTrait>`, thread-safe
|
||||
- Lock-free concurrency, eliminating lock contention
|
||||
|
||||
### 2. Context Cancellation Mechanism (Graceful Shutdown)
|
||||
- Use `context.Context` cancellation mechanism
|
||||
- In `readLoop`'s `defer`: cancel context first, then wait for all goroutines to complete, finally close channels
|
||||
- `processAndSendResponse` checks `ctx.Done()` at function start, all `select` statements include `case <-s.ctx.Done()`
|
||||
- Avoids "send on closed channel" panic
|
||||
|
||||
### 3. Cancellable Recv()
|
||||
- Use dedicated goroutine to execute `Recv()`
|
||||
- Pass results through `recvChan`
|
||||
- Call `CloseSend()` when context is cancelled to make `Recv()` return error
|
||||
|
||||
### 4. Simplified Channel Design
|
||||
- `resultJSONChan`: Main data channel (gRPC layer)
|
||||
- `errChan`: Error channel (gRPC layer)
|
||||
- `recvChan`: Internal communication channel (gRPC layer)
|
||||
- Removed redundant channels and duplicate reads
|
||||
|
||||
## Configuration
|
||||
|
||||
### Channel Buffer Sizes
|
||||
|
||||
```go
|
||||
type ChannelBufferSizes struct {
|
||||
ResultJSONChan int // Default: 10000
|
||||
ErrChan int // Default: 100
|
||||
RecvChan int // Default: 2000
|
||||
}
|
||||
```
|
||||
|
||||
### Timeout Configuration
|
||||
|
||||
```go
|
||||
type Timeouts struct {
|
||||
KeepaliveTime time.Duration // Default: 300s
|
||||
KeepaliveTimeout time.Duration // Default: 20s
|
||||
CloseTimeout time.Duration // Default: 5s
|
||||
}
|
||||
```
|
||||
|
||||
## Performance Optimizations
|
||||
|
||||
1. **Pre-create Tokenizer**: Created at startup to avoid first request latency
|
||||
2. **Lock-Free Concurrency**: Tokenizer is thread-safe, no locks needed
|
||||
3. **Lazy Parsing**: JSON parsing deferred until needed
|
||||
4. **Direct JSON Passing**: `RecvJSON()` avoids parse/serialize overhead
|
||||
5. **Immediate Batching**: batchSize=1, no delay
|
||||
6. **Async Processing**: `readLoop` processes in background, doesn't block request handling
|
||||
7. **Configurable Buffers**: Adjust channel sizes based on concurrency needs
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
sgl-model-gateway/bindings/golang/
|
||||
├── client.go # High-level client API
|
||||
├── internal/
|
||||
│ ├── grpc/
|
||||
│ │ └── client_grpc.go # gRPC client implementation
|
||||
│ ├── ffi/ # FFI bindings (Rust)
|
||||
│ └── proto/ # Protobuf definitions
|
||||
└── examples/
|
||||
└── oai_server/
|
||||
├── handlers/
|
||||
│ └── chat.go # HTTP request handling
|
||||
├── models/
|
||||
│ └── chat.go # Request/response models
|
||||
└── service/
|
||||
└── sglang_service.go # Service layer
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Context Cancellation Mechanism
|
||||
1. **Client disconnects** → `SetBodyStreamWriter` detects flush error
|
||||
2. **Cancel streamCtx** → `readLoop` detects `ctx.Done()`
|
||||
3. **Call stream.CloseSend()** → `Recv()` goroutine returns error
|
||||
4. **readLoop defer executes**:
|
||||
- Set `closed` flag
|
||||
- Cancel context (if not already cancelled)
|
||||
- Wait for all `processAndSendResponse` goroutines to complete (`processWg.Wait()`)
|
||||
- Close all channels (`resultJSONChan`, `errChan`, `readLoopDone`)
|
||||
5. **Clean up resources and exit**
|
||||
|
||||
### Channel Blocking and Race Condition Prevention
|
||||
- **Context cancellation mechanism**: All channel sends use `select` statements with `case <-s.ctx.Done()`
|
||||
- **Graceful exit**: When context is cancelled, all blocking send operations can return immediately
|
||||
- **WaitGroup synchronization**: `readLoop`'s `defer` uses `processWg.Wait()` to ensure all goroutines complete before closing channels
|
||||
- **Avoid panic**: Through context cancellation and WaitGroup synchronization, avoids "send on closed channel" panic
|
||||
|
||||
## Key Functions
|
||||
|
||||
### CreateChatCompletionStream
|
||||
**Location**: `internal/grpc/client_grpc.go:108`
|
||||
- Preprocess request (FFI)
|
||||
- Build gRPC request
|
||||
- Create converter and batch processor
|
||||
- Start `readLoop`
|
||||
|
||||
### readLoop
|
||||
**Location**: `internal/grpc/client_grpc.go:290`
|
||||
- Start Recv() goroutine (continuously calls `stream.Recv()`)
|
||||
- Process proto responses
|
||||
- Asynchronously call `processAndSendResponse` (tracked with `processWg`)
|
||||
- **Graceful shutdown in defer**:
|
||||
- Set `closed` flag
|
||||
- Cancel context (if not already cancelled)
|
||||
- Wait for all `processAndSendResponse` goroutines to complete (`processWg.Wait()`)
|
||||
- Close all channels (`resultJSONChan`, `errChan`, `readLoopDone`)
|
||||
|
||||
### processAndSendResponse
|
||||
**Location**: `internal/grpc/client_grpc.go:379`
|
||||
- Check `ctx.Done()` at function start, return immediately if cancelled
|
||||
- Convert proto to JSON
|
||||
- Call FFI batch processor
|
||||
- All `select` statements include `case <-s.ctx.Done()` for graceful shutdown handling
|
||||
- Send JSON to channel
|
||||
|
||||
### RecvJSON
|
||||
**Location**:
|
||||
- `internal/grpc/client_grpc.go:412`: gRPC layer implementation
|
||||
- `client.go:410`: Client wrapper layer
|
||||
- Read from `resultJSONChan`
|
||||
- Directly return JSON string, no parsing needed
|
||||
55
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/config/config.go
vendored
Normal file
55
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/config/config.go
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// Config holds the application configuration
|
||||
type Config struct {
|
||||
Endpoint string
|
||||
TokenizerPath string
|
||||
Port string
|
||||
LogDir string
|
||||
LogLevel string
|
||||
}
|
||||
|
||||
// Load loads configuration from environment variables with defaults
|
||||
func Load() *Config {
|
||||
// Get tokenizer path from environment or use default
|
||||
tokenizerPath := os.Getenv("SGL_TOKENIZER_PATH")
|
||||
if tokenizerPath == "" {
|
||||
tokenizerPath = "../tokenizer"
|
||||
}
|
||||
|
||||
// Get endpoint from environment or use default
|
||||
endpoint := os.Getenv("SGL_GRPC_ENDPOINT")
|
||||
if endpoint == "" {
|
||||
endpoint = "grpc://localhost:20000"
|
||||
}
|
||||
|
||||
// Get port from environment or use default
|
||||
port := os.Getenv("PORT")
|
||||
if port == "" {
|
||||
port = "8080"
|
||||
}
|
||||
|
||||
// Get log directory from environment or use default
|
||||
logDir := os.Getenv("LOG_DIR")
|
||||
if logDir == "" {
|
||||
logDir = "./logs"
|
||||
}
|
||||
|
||||
// Get log level from environment or use default
|
||||
logLevel := os.Getenv("LOG_LEVEL")
|
||||
if logLevel == "" {
|
||||
logLevel = "info"
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Endpoint: endpoint,
|
||||
TokenizerPath: tokenizerPath,
|
||||
Port: port,
|
||||
LogDir: logDir,
|
||||
LogLevel: logLevel,
|
||||
}
|
||||
}
|
||||
121
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/docs/benchmark_result.md
vendored
Normal file
121
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/docs/benchmark_result.md
vendored
Normal file
@@ -0,0 +1,121 @@
|
||||
/tmp/ShareGPT_V3_unfiltered_cleaned_split.json: 100%|████████████████████| 642M/642M [10:02<00:00, 1.12MB/s]
|
||||
#Input tokens: 50561
|
||||
#Output tokens: 25883
|
||||
Starting warmup with 5 sequences...
|
||||
Warmup completed with 5 sequences. Starting main benchmark run...
|
||||
|
||||
============ Serving Benchmark Result ============
|
||||
Backend: sglang-oai-chat
|
||||
Traffic request rate: 20.0
|
||||
Max request concurrency: 20
|
||||
Successful requests: 100
|
||||
Benchmark duration (s): 107.24
|
||||
Total input tokens: 50561
|
||||
Total input text tokens: 50561
|
||||
Total input vision tokens: 0
|
||||
Total generated tokens: 25883
|
||||
Total generated tokens (retokenized): 129591
|
||||
Request throughput (req/s): 0.93
|
||||
Input token throughput (tok/s): 471.48
|
||||
Output token throughput (tok/s): 241.36
|
||||
Total token throughput (tok/s): 712.84
|
||||
Concurrency: 16.42
|
||||
----------------End-to-End Latency----------------
|
||||
Mean E2E Latency (ms): 17609.46
|
||||
Median E2E Latency (ms): 12343.82
|
||||
---------------Time to First Token----------------
|
||||
Mean TTFT (ms): 190.71
|
||||
Median TTFT (ms): 164.86
|
||||
P99 TTFT (ms): 397.72
|
||||
-----Time per Output Token (excl. 1st token)------
|
||||
Mean TPOT (ms): 162.55
|
||||
Median TPOT (ms): 63.51
|
||||
P99 TPOT (ms): 1337.20
|
||||
---------------Inter-Token Latency----------------
|
||||
Mean ITL (ms): 25.85
|
||||
Median ITL (ms): 24.26
|
||||
P95 ITL (ms): 48.26
|
||||
P99 ITL (ms): 119.04
|
||||
Max ITL (ms): 194.58
|
||||
==================================================
|
||||
|
||||
✓ E2E test completed
|
||||
|
||||
|
||||
## Rust
|
||||
============ Serving Benchmark Result ============
|
||||
Backend: sglang-oai-chat
|
||||
Traffic request rate: 20.0
|
||||
Max request concurrency: 20
|
||||
Successful requests: 100
|
||||
Benchmark duration (s): 37.71
|
||||
Total input tokens: 50561
|
||||
Total input text tokens: 50561
|
||||
Total input vision tokens: 0
|
||||
Total generated tokens: 25883
|
||||
Total generated tokens (retokenized): 25599
|
||||
Request throughput (req/s): 2.65
|
||||
Input token throughput (tok/s): 1340.75
|
||||
Output token throughput (tok/s): 686.35
|
||||
Total token throughput (tok/s): 2027.10
|
||||
Concurrency: 18.58
|
||||
----------------End-to-End Latency----------------
|
||||
Mean E2E Latency (ms): 7008.05
|
||||
Median E2E Latency (ms): 7061.24
|
||||
---------------Time to First Token----------------
|
||||
Mean TTFT (ms): 156.09
|
||||
Median TTFT (ms): 133.81
|
||||
P99 TTFT (ms): 318.53
|
||||
-----Time per Output Token (excl. 1st token)------
|
||||
Mean TPOT (ms): 26.59
|
||||
Median TPOT (ms): 26.75
|
||||
P99 TPOT (ms): 29.18
|
||||
---------------Inter-Token Latency----------------
|
||||
Mean ITL (ms): 26.71
|
||||
Median ITL (ms): 23.61
|
||||
P95 ITL (ms): 66.11
|
||||
P99 ITL (ms): 115.30
|
||||
Max ITL (ms): 201.08
|
||||
==================================================
|
||||
|
||||
|
||||
## golang
|
||||
#Input tokens: 50561
|
||||
#Output tokens: 25883
|
||||
Starting warmup with 5 sequences...
|
||||
Warmup completed with 5 sequences. Starting main benchmark run...
|
||||
|
||||
============ Serving Benchmark Result ============
|
||||
Backend: sglang-oai-chat
|
||||
Traffic request rate: 20.0
|
||||
Max request concurrency: 20
|
||||
Successful requests: 100
|
||||
Benchmark duration (s): 34.22
|
||||
Total input tokens: 50561
|
||||
Total input text tokens: 50561
|
||||
Total input vision tokens: 0
|
||||
Total generated tokens: 22970
|
||||
Total generated tokens (retokenized): 31740
|
||||
Request throughput (req/s): 2.92
|
||||
Input token throughput (tok/s): 1477.70
|
||||
Output token throughput (tok/s): 671.32
|
||||
Total token throughput (tok/s): 2149.03
|
||||
Concurrency: 18.42
|
||||
----------------End-to-End Latency----------------
|
||||
Mean E2E Latency (ms): 6303.33
|
||||
Median E2E Latency (ms): 6294.46
|
||||
---------------Time to First Token----------------
|
||||
Mean TTFT (ms): 157.10
|
||||
Median TTFT (ms): 149.16
|
||||
P99 TTFT (ms): 251.98
|
||||
-----Time per Output Token (excl. 1st token)------
|
||||
Mean TPOT (ms): 26.49
|
||||
Median TPOT (ms): 27.15
|
||||
P99 TPOT (ms): 28.73
|
||||
---------------Inter-Token Latency----------------
|
||||
Mean ITL (ms): 26.97
|
||||
Median ITL (ms): 24.61
|
||||
P95 ITL (ms): 52.39
|
||||
P99 ITL (ms): 86.52
|
||||
Max ITL (ms): 194.55
|
||||
==================================================
|
||||
60
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/go.sum
vendored
Normal file
60
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/go.sum
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
|
||||
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
||||
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.52.0 h1:wqBQpxH71XW0e2g+Og4dzQM8pk34aFYlA1Ga8db7gU0=
|
||||
github.com/valyala/fasthttp v1.52.0/go.mod h1:hf5C4QnVMkNXMspnsUlfM3WitlgYflyhHYoKol/szxQ=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
||||
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
||||
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
|
||||
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
||||
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
|
||||
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
|
||||
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
|
||||
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82 h1:6/3JGEh1C88g7m+qzzTbl3A0FtsLguXieqofVLU/JAo=
|
||||
golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 h1:M1rk8KBnUsBDg1oPGHNCxG4vc1f49epmTO7xscSajMk=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
|
||||
google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM=
|
||||
google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
556
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/handlers/chat.go
vendored
Normal file
556
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/handlers/chat.go
vendored
Normal file
@@ -0,0 +1,556 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
sglang "github.com/sglang/sglang-go-grpc-sdk"
|
||||
"github.com/valyala/fasthttp"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"oai_server/models"
|
||||
"oai_server/service"
|
||||
"oai_server/utils"
|
||||
)
|
||||
|
||||
// ChatHandler handles chat completion requests
|
||||
type ChatHandler struct {
|
||||
logger *zap.Logger
|
||||
service *service.SGLangService
|
||||
}
|
||||
|
||||
// NewChatHandler creates a new chat handler
|
||||
func NewChatHandler(logger *zap.Logger, svc *service.SGLangService) *ChatHandler {
|
||||
return &ChatHandler{
|
||||
logger: logger,
|
||||
service: svc,
|
||||
}
|
||||
}
|
||||
|
||||
// recvResult holds the result of a RecvJSON() call
|
||||
type recvResult struct {
|
||||
chunkJSON string
|
||||
err error
|
||||
}
|
||||
|
||||
// HandleChatCompletion handles POST /v1/chat/completions
|
||||
func (h *ChatHandler) HandleChatCompletion(ctx *fasthttp.RequestCtx) {
|
||||
var req models.ChatRequest
|
||||
if err := json.Unmarshal(ctx.PostBody(), &req); err != nil {
|
||||
h.logger.Warn("Invalid chat completion request", zap.Error(err))
|
||||
utils.RespondError(ctx, 400, fmt.Sprintf("Invalid request: %v", err), "invalid_request_error")
|
||||
return
|
||||
}
|
||||
|
||||
path := string(ctx.Path())
|
||||
|
||||
defer func() {
|
||||
statusCode := ctx.Response.StatusCode()
|
||||
if statusCode == 0 {
|
||||
statusCode = 200
|
||||
}
|
||||
h.logHTTPResponse(statusCode, path)
|
||||
}()
|
||||
|
||||
// Convert to SGLang format
|
||||
messages := make([]sglang.ChatMessage, len(req.Messages))
|
||||
for i, msg := range req.Messages {
|
||||
role, roleOk := msg["role"]
|
||||
content, contentOk := msg["content"]
|
||||
|
||||
// Validate role
|
||||
if !roleOk || role == "" {
|
||||
h.logger.Warn("Missing or empty role in message", zap.Int("message_index", i))
|
||||
utils.RespondError(ctx, 400, "Message role is required and cannot be empty", "invalid_request_error")
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure content is always a string (not null)
|
||||
// Chat template requires content field to be present, even if empty
|
||||
// If content is missing or null, use empty string
|
||||
contentStr := ""
|
||||
if contentOk && content != "" {
|
||||
contentStr = content
|
||||
}
|
||||
|
||||
messages[i] = sglang.ChatMessage{
|
||||
Role: role,
|
||||
Content: contentStr,
|
||||
}
|
||||
}
|
||||
|
||||
sglReq := sglang.ChatCompletionRequest{
|
||||
Model: req.Model,
|
||||
Messages: messages,
|
||||
Stream: req.Stream,
|
||||
}
|
||||
|
||||
if req.Temperature != nil {
|
||||
temp := float32(*req.Temperature)
|
||||
sglReq.Temperature = &temp
|
||||
}
|
||||
if req.TopP != nil {
|
||||
topP := float32(*req.TopP)
|
||||
sglReq.TopP = &topP
|
||||
}
|
||||
if req.MaxCompletionTokens != nil {
|
||||
sglReq.MaxCompletionTokens = req.MaxCompletionTokens
|
||||
} else if req.MaxTokens != nil {
|
||||
sglReq.MaxCompletionTokens = req.MaxTokens
|
||||
}
|
||||
|
||||
requestCtx := context.Background()
|
||||
|
||||
if req.Stream {
|
||||
h.handleStreamingCompletion(ctx, requestCtx, sglReq)
|
||||
} else {
|
||||
h.handleNonStreamingCompletion(ctx, requestCtx, sglReq)
|
||||
}
|
||||
}
|
||||
|
||||
// isBrokenPipeError checks if the error is a broken pipe error (client disconnected)
|
||||
func isBrokenPipeError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "connection closed") ||
|
||||
strings.Contains(errStr, "write: connection closed")
|
||||
}
|
||||
|
||||
// logHTTPResponse logs HTTP response with colored output
|
||||
func (h *ChatHandler) logHTTPResponse(statusCode int, path string) {
|
||||
var statusText string
|
||||
var colorCode string
|
||||
|
||||
switch {
|
||||
case statusCode >= 200 && statusCode < 300:
|
||||
colorCode = "\033[32m" // Green
|
||||
statusText = "OK"
|
||||
case statusCode >= 300 && statusCode < 400:
|
||||
colorCode = "\033[33m" // Yellow
|
||||
statusText = "Redirect"
|
||||
case statusCode >= 400 && statusCode < 500:
|
||||
colorCode = "\033[33m" // Yellow
|
||||
statusText = "Client Error"
|
||||
case statusCode >= 500:
|
||||
colorCode = "\033[31m" // Red
|
||||
statusText = "Server Error"
|
||||
default:
|
||||
colorCode = "\033[37m" // White
|
||||
statusText = "Unknown"
|
||||
}
|
||||
|
||||
resetCode := "\033[0m"
|
||||
msg := fmt.Sprintf("%s[%d %s]%s %s", colorCode, statusCode, statusText, resetCode, path)
|
||||
h.logger.Info(msg)
|
||||
}
|
||||
|
||||
func (h *ChatHandler) handleStreamingCompletion(ctx *fasthttp.RequestCtx, requestCtx context.Context, req sglang.ChatCompletionRequest) {
|
||||
|
||||
ctx.SetContentType("text/event-stream")
|
||||
ctx.Response.Header.Set("Cache-Control", "no-cache")
|
||||
ctx.Response.Header.Set("Connection", "keep-alive")
|
||||
ctx.Response.Header.Set("X-Accel-Buffering", "no")
|
||||
ctx.SetStatusCode(200)
|
||||
|
||||
var clientDisconnected bool
|
||||
// Flush timeout: prevent deadlock if client is slow or disconnected
|
||||
// This timeout should be longer than typical network latency but shorter than client timeout
|
||||
const flushTimeout = 5 * time.Second
|
||||
|
||||
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
|
||||
streamCtx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
stream, err := h.service.Client().CreateChatCompletionStream(streamCtx, req)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to create chat completion stream",
|
||||
zap.Error(err),
|
||||
zap.String("model", req.Model),
|
||||
)
|
||||
// Use sendSSEError to send error in consistent format
|
||||
errInfo, sendErr := h.sendSSEError(w, err)
|
||||
if sendErr != nil {
|
||||
h.logger.Warn("Failed to send SSE error", zap.Error(sendErr))
|
||||
} else if errInfo.IsTimeout {
|
||||
h.logger.Error("Stream creation timeout", zap.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := stream.Close(); closeErr != nil {
|
||||
h.logger.Warn("Failed to close stream", zap.Error(closeErr))
|
||||
}
|
||||
}()
|
||||
|
||||
// Use a single dedicated goroutine to continuously call RecvJSON() and send results via channel
|
||||
recvChan := make(chan recvResult, 20)
|
||||
recvGoroutineDone := make(chan struct{})
|
||||
go func() {
|
||||
defer func() {
|
||||
close(recvChan)
|
||||
close(recvGoroutineDone)
|
||||
}()
|
||||
for {
|
||||
// Check context before calling RecvJSON() to avoid blocking if context is cancelled
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Call RecvJSON() - this may block, but stream.Close() will unblock it
|
||||
// when context is cancelled (called from main loop)
|
||||
chunkJSON, err := stream.RecvJSON()
|
||||
|
||||
// Check context again after RecvJSON() returns
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Send to channel (may block if channel is full)
|
||||
// If channel is full, this will block until main loop reads from it
|
||||
// This is acceptable because main loop should be actively reading
|
||||
select {
|
||||
case recvChan <- recvResult{chunkJSON: chunkJSON, err: err}:
|
||||
if err != nil {
|
||||
// EOF or other error, stop the goroutine
|
||||
return
|
||||
}
|
||||
case <-streamCtx.Done():
|
||||
// Context cancelled while sending, stop the goroutine
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if clientDisconnected {
|
||||
cancel()
|
||||
// Close stream immediately to unblock RecvJSON() calls
|
||||
stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
// Close stream to ensure RecvJSON() goroutine can exit
|
||||
stream.Close()
|
||||
return
|
||||
case result, ok := <-recvChan:
|
||||
if !ok {
|
||||
// Channel closed, stream ended
|
||||
return
|
||||
}
|
||||
if result.err == io.EOF {
|
||||
if !clientDisconnected {
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
// Flush with timeout to prevent deadlock
|
||||
flushDone := make(chan error, 1)
|
||||
go func() {
|
||||
flushDone <- w.Flush()
|
||||
}()
|
||||
flushCtx, flushCancel := context.WithTimeout(streamCtx, flushTimeout)
|
||||
defer flushCancel()
|
||||
select {
|
||||
case flushErr := <-flushDone:
|
||||
if flushErr != nil && !isBrokenPipeError(flushErr) {
|
||||
h.logger.Warn("Final flush error", zap.Error(flushErr))
|
||||
}
|
||||
case <-flushCtx.Done():
|
||||
if flushCtx.Err() == context.DeadlineExceeded {
|
||||
h.logger.Warn("Final flush timeout", zap.Duration("timeout", flushTimeout))
|
||||
}
|
||||
case <-streamCtx.Done():
|
||||
// Context cancelled, skip flush
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if result.err != nil {
|
||||
if result.err == context.Canceled || result.err == context.DeadlineExceeded {
|
||||
return
|
||||
}
|
||||
// Send error to client before closing
|
||||
errInfo, sendErr := h.sendSSEError(w, result.err)
|
||||
if sendErr != nil {
|
||||
h.logger.Warn("Failed to send SSE error", zap.Error(sendErr))
|
||||
}
|
||||
if errInfo.IsTimeout {
|
||||
h.logger.Error("Stream timeout error", zap.Error(result.err))
|
||||
} else {
|
||||
h.logger.Error("Stream error", zap.Error(result.err))
|
||||
}
|
||||
return
|
||||
}
|
||||
if result.chunkJSON == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
w.WriteString("data: ")
|
||||
w.WriteString(result.chunkJSON)
|
||||
w.WriteString("\n\n")
|
||||
|
||||
// Flush with timeout to prevent deadlock:
|
||||
// If Flush blocks indefinitely (slow client), RecvJSON goroutine may fill recvChan
|
||||
// and then block trying to send, causing deadlock
|
||||
// Note: bufio.Writer.Flush() doesn't have a timeout parameter, so we use
|
||||
// a goroutine + select pattern to implement timeout behavior
|
||||
flushDone := make(chan error, 1)
|
||||
go func() {
|
||||
flushDone <- w.Flush()
|
||||
}()
|
||||
|
||||
flushCtx, flushCancel := context.WithTimeout(streamCtx, flushTimeout)
|
||||
defer flushCancel()
|
||||
|
||||
select {
|
||||
case err := <-flushDone:
|
||||
if err != nil {
|
||||
if isBrokenPipeError(err) {
|
||||
clientDisconnected = true
|
||||
cancel()
|
||||
// Close stream immediately to unblock RecvJSON() calls
|
||||
stream.Close()
|
||||
return
|
||||
}
|
||||
h.logger.Warn("Flush error", zap.Error(err))
|
||||
}
|
||||
case <-flushCtx.Done():
|
||||
// Flush timeout: client may be slow or disconnected
|
||||
// Continue processing to avoid deadlock, but mark as disconnected
|
||||
if flushCtx.Err() == context.DeadlineExceeded {
|
||||
h.logger.Warn("Flush timeout, client may be slow or disconnected", zap.Duration("timeout", flushTimeout))
|
||||
}
|
||||
clientDisconnected = true
|
||||
cancel()
|
||||
stream.Close()
|
||||
return
|
||||
case <-streamCtx.Done():
|
||||
// Context cancelled, stop flushing
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ChatHandler) handleNonStreamingCompletion(ctx *fasthttp.RequestCtx, requestCtx context.Context, req sglang.ChatCompletionRequest) {
|
||||
resp, err := h.service.Client().CreateChatCompletion(requestCtx, req)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to create chat completion",
|
||||
zap.Error(err),
|
||||
zap.String("model", req.Model),
|
||||
)
|
||||
utils.RespondError(ctx, 500, fmt.Sprintf("Failed to create completion: %v", err), "server_error")
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to OpenAI format
|
||||
response := utils.BuildResponseBase(resp.ID, resp.Created, resp.Model)
|
||||
response["object"] = "chat.completion"
|
||||
|
||||
choices := make([]map[string]interface{}, len(resp.Choices))
|
||||
for i, choice := range resp.Choices {
|
||||
choiceMap := map[string]interface{}{
|
||||
"index": choice.Index,
|
||||
"message": map[string]interface{}{
|
||||
"role": choice.Message.Role,
|
||||
"content": choice.Message.Content,
|
||||
},
|
||||
"finish_reason": choice.FinishReason,
|
||||
}
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
toolCalls := make([]map[string]interface{}, len(choice.Message.ToolCalls))
|
||||
for j, tc := range choice.Message.ToolCalls {
|
||||
toolCalls[j] = map[string]interface{}{
|
||||
"id": tc.ID,
|
||||
"type": tc.Type,
|
||||
"function": map[string]interface{}{"name": tc.Function.Name, "arguments": tc.Function.Arguments},
|
||||
}
|
||||
}
|
||||
choiceMap["message"].(map[string]interface{})["tool_calls"] = toolCalls
|
||||
}
|
||||
choices[i] = choiceMap
|
||||
}
|
||||
response["choices"] = choices
|
||||
|
||||
// Usage is always present (not a pointer)
|
||||
response["usage"] = map[string]interface{}{
|
||||
"prompt_tokens": resp.Usage.PromptTokens,
|
||||
"completion_tokens": resp.Usage.CompletionTokens,
|
||||
"total_tokens": resp.Usage.TotalTokens,
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(200)
|
||||
ctx.SetContentType("application/json")
|
||||
jsonData, _ := json.Marshal(response)
|
||||
ctx.Write(jsonData)
|
||||
}
|
||||
|
||||
// StreamErrorInfo holds parsed error information
|
||||
type StreamErrorInfo struct {
|
||||
Message string
|
||||
Type string
|
||||
Code int
|
||||
IsTimeout bool
|
||||
}
|
||||
|
||||
// parseStreamError parses error type and code
|
||||
func parseStreamError(err error) StreamErrorInfo {
|
||||
if err == nil {
|
||||
return StreamErrorInfo{}
|
||||
}
|
||||
|
||||
errorMsg := err.Error()
|
||||
// Check timeout error by message prefix
|
||||
isTimeout := strings.HasPrefix(errorMsg, "stream.Recv() timeout") || strings.Contains(errorMsg, "timeout after")
|
||||
|
||||
errorType := "server_error"
|
||||
errorCode := 500
|
||||
if isTimeout {
|
||||
errorType = "timeout_error"
|
||||
errorCode = 504
|
||||
}
|
||||
|
||||
return StreamErrorInfo{
|
||||
Message: errorMsg,
|
||||
Type: errorType,
|
||||
Code: errorCode,
|
||||
IsTimeout: isTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// formatErrorJSON formats error as OpenAI JSON
|
||||
func formatErrorJSON(errInfo StreamErrorInfo) string {
|
||||
errorObj := map[string]interface{}{
|
||||
"error": map[string]interface{}{
|
||||
"message": errInfo.Message,
|
||||
"type": errInfo.Type,
|
||||
"code": errInfo.Code,
|
||||
},
|
||||
}
|
||||
jsonBytes, _ := json.Marshal(errorObj)
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
// sendSSEError sends SSE error response. Callers should log errors.
|
||||
func (h *ChatHandler) sendSSEError(w *bufio.Writer, err error) (StreamErrorInfo, error) {
|
||||
errInfo := parseStreamError(err)
|
||||
errorJSON := formatErrorJSON(errInfo)
|
||||
|
||||
w.WriteString("data: ")
|
||||
w.WriteString(errorJSON)
|
||||
w.WriteString("\n\n")
|
||||
|
||||
if flushErr := w.Flush(); flushErr != nil && !isBrokenPipeError(flushErr) {
|
||||
h.logger.Warn("Failed to flush error response", zap.Error(flushErr))
|
||||
return errInfo, flushErr
|
||||
}
|
||||
|
||||
return errInfo, nil
|
||||
}
|
||||
|
||||
// HandleGenerate handles POST /generate (SGLang native API)
|
||||
func (h *ChatHandler) HandleGenerate(ctx *fasthttp.RequestCtx) {
|
||||
path := string(ctx.Path())
|
||||
|
||||
defer func() {
|
||||
statusCode := ctx.Response.StatusCode()
|
||||
if statusCode == 0 {
|
||||
statusCode = 200
|
||||
}
|
||||
h.logHTTPResponse(statusCode, path)
|
||||
}()
|
||||
|
||||
// Parse request body
|
||||
var req map[string]interface{}
|
||||
if err := json.Unmarshal(ctx.PostBody(), &req); err != nil {
|
||||
h.logger.Warn("Invalid generate request", zap.Error(err))
|
||||
utils.RespondError(ctx, 400, fmt.Sprintf("Invalid request: %v", err), "invalid_request_error")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract text and sampling_params
|
||||
text, ok := req["text"].(string)
|
||||
if !ok || text == "" {
|
||||
utils.RespondError(ctx, 400, "Missing or invalid 'text' field", "invalid_request_error")
|
||||
return
|
||||
}
|
||||
|
||||
samplingParams, _ := req["sampling_params"].(map[string]interface{})
|
||||
if samplingParams == nil {
|
||||
samplingParams = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// Convert to chat completion format for processing
|
||||
chatReq := sglang.ChatCompletionRequest{
|
||||
Model: "default",
|
||||
Messages: []sglang.ChatMessage{{Role: "user", Content: text}},
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
// Copy sampling params
|
||||
if maxNewTokens, ok := samplingParams["max_new_tokens"].(float64); ok {
|
||||
tokens := int(maxNewTokens)
|
||||
chatReq.MaxCompletionTokens = &tokens
|
||||
}
|
||||
if temp, ok := samplingParams["temperature"].(float64); ok {
|
||||
temp32 := float32(temp)
|
||||
chatReq.Temperature = &temp32
|
||||
}
|
||||
if topP, ok := samplingParams["top_p"].(float64); ok {
|
||||
topP32 := float32(topP)
|
||||
chatReq.TopP = &topP32
|
||||
}
|
||||
if topK, ok := samplingParams["top_k"].(float64); ok {
|
||||
topKInt := int(topK)
|
||||
chatReq.TopK = &topKInt
|
||||
}
|
||||
|
||||
requestCtx := context.Background()
|
||||
|
||||
// Use non-streaming completion for /generate endpoint
|
||||
resp, err := h.service.Client().CreateChatCompletion(requestCtx, chatReq)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to create completion",
|
||||
zap.Error(err),
|
||||
)
|
||||
utils.RespondError(ctx, 500, fmt.Sprintf("Failed to create completion: %v", err), "server_error")
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to SGLang /generate response format
|
||||
// meta_info must match SGLang's expected format with completion_tokens at top level
|
||||
finishReason := resp.Choices[0].FinishReason
|
||||
if finishReason == "" {
|
||||
finishReason = "stop"
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"text": resp.Choices[0].Message.Content,
|
||||
"meta_info": map[string]interface{}{
|
||||
"id": resp.ID,
|
||||
"finish_reason": finishReason,
|
||||
"prompt_tokens": resp.Usage.PromptTokens,
|
||||
"completion_tokens": resp.Usage.CompletionTokens,
|
||||
"cached_tokens": 0, // Not available from chat completion API
|
||||
"weight_version": "", // Not available from chat completion API
|
||||
},
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(200)
|
||||
ctx.SetContentType("application/json")
|
||||
jsonData, _ := json.Marshal(response)
|
||||
ctx.Write(jsonData)
|
||||
}
|
||||
33
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/handlers/health.go
vendored
Normal file
33
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/handlers/health.go
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// HealthHandler handles health check requests
|
||||
type HealthHandler struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewHealthHandler creates a new health handler
|
||||
func NewHealthHandler(logger *zap.Logger) *HealthHandler {
|
||||
return &HealthHandler{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Check handles GET /health
|
||||
func (h *HealthHandler) Check(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(200)
|
||||
ctx.SetContentType("application/json")
|
||||
|
||||
response := map[string]string{
|
||||
"status": "ok",
|
||||
}
|
||||
|
||||
jsonData, _ := json.Marshal(response)
|
||||
ctx.Write(jsonData)
|
||||
}
|
||||
67
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/handlers/models.go
vendored
Normal file
67
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/handlers/models.go
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ModelsHandler handles model list requests
|
||||
type ModelsHandler struct {
|
||||
logger *zap.Logger
|
||||
tokenizerPath string
|
||||
}
|
||||
|
||||
// NewModelsHandler creates a new models handler
|
||||
func NewModelsHandler(logger *zap.Logger, tokenizerPath string) *ModelsHandler {
|
||||
return &ModelsHandler{
|
||||
logger: logger,
|
||||
tokenizerPath: tokenizerPath,
|
||||
}
|
||||
}
|
||||
|
||||
// List handles GET /v1/models
|
||||
func (h *ModelsHandler) List(ctx *fasthttp.RequestCtx) {
|
||||
// Return a default model for OpenAI compatibility
|
||||
ctx.SetStatusCode(200)
|
||||
ctx.SetContentType("application/json")
|
||||
|
||||
response := map[string]interface{}{
|
||||
"object": "list",
|
||||
"data": []map[string]interface{}{
|
||||
{
|
||||
"id": "default",
|
||||
"object": "model",
|
||||
"created": 1677610602,
|
||||
"owned_by": "sglang",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, _ := json.Marshal(response)
|
||||
ctx.Write(jsonData)
|
||||
}
|
||||
|
||||
// GetModelInfo handles GET /get_model_info
|
||||
// Returns model information compatible with SGLang RuntimeEndpoint
|
||||
func (h *ModelsHandler) GetModelInfo(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(200)
|
||||
ctx.SetContentType("application/json")
|
||||
|
||||
// Return model info compatible with SGLang RuntimeEndpoint expectations
|
||||
response := map[string]interface{}{
|
||||
"model_path": h.tokenizerPath, // Use tokenizer path as model path
|
||||
"tokenizer_path": h.tokenizerPath,
|
||||
"is_generation": true,
|
||||
"preferred_sampling_params": "",
|
||||
"weight_version": "",
|
||||
"has_image_understanding": false,
|
||||
"has_audio_understanding": false,
|
||||
"model_type": "",
|
||||
"architectures": nil,
|
||||
}
|
||||
|
||||
jsonData, _ := json.Marshal(response)
|
||||
ctx.Write(jsonData)
|
||||
}
|
||||
67
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/logger/logger.go
vendored
Normal file
67
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/logger/logger.go
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
// Init initializes the logger with file and console output
|
||||
func Init(logDir, logLevel string) (*zap.Logger, error) {
|
||||
// Ensure log directory exists
|
||||
if err := os.MkdirAll(logDir, 0755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse log level
|
||||
var level zapcore.Level
|
||||
if err := level.UnmarshalText([]byte(logLevel)); err != nil {
|
||||
level = zapcore.InfoLevel
|
||||
}
|
||||
|
||||
// Create log file path with date
|
||||
logFile := filepath.Join(logDir, "oai_server-"+time.Now().Format("2006-01-02")+".log")
|
||||
|
||||
// File writer with rotation
|
||||
fileWriter := zapcore.AddSync(&lumberjack.Logger{
|
||||
Filename: logFile,
|
||||
MaxSize: 100, // megabytes
|
||||
MaxBackups: 10,
|
||||
MaxAge: 30, // days
|
||||
Compress: true,
|
||||
})
|
||||
|
||||
// Console writer
|
||||
consoleWriter := zapcore.AddSync(os.Stdout)
|
||||
|
||||
// Encoder config
|
||||
encoderConfig := zap.NewProductionEncoderConfig()
|
||||
encoderConfig.TimeKey = "timestamp"
|
||||
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
|
||||
|
||||
// Create cores
|
||||
fileCore := zapcore.NewCore(
|
||||
zapcore.NewJSONEncoder(encoderConfig),
|
||||
fileWriter,
|
||||
level,
|
||||
)
|
||||
|
||||
consoleCore := zapcore.NewCore(
|
||||
zapcore.NewConsoleEncoder(encoderConfig),
|
||||
consoleWriter,
|
||||
level,
|
||||
)
|
||||
|
||||
// Combine cores
|
||||
core := zapcore.NewTee(fileCore, consoleCore)
|
||||
|
||||
// Create logger
|
||||
logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel))
|
||||
|
||||
return logger, nil
|
||||
}
|
||||
116
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/main.go
vendored
Normal file
116
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/main.go
vendored
Normal file
@@ -0,0 +1,116 @@
|
||||
// OpenAI-compatible chat server using SGLang Go SDK and fasthttp framework
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
_ "net/http/pprof" // Enable pprof endpoints
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"oai_server/config"
|
||||
"oai_server/handlers"
|
||||
"oai_server/logger"
|
||||
"oai_server/service"
|
||||
)
|
||||
|
||||
// Version information (set at build time via ldflags)
|
||||
var (
|
||||
Version = "dev"
|
||||
BuildTime = "unknown"
|
||||
GitCommit = "unknown"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load configuration
|
||||
cfg := config.Load()
|
||||
|
||||
// Initialize logger
|
||||
appLogger, err := logger.Init(cfg.LogDir, cfg.LogLevel)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to initialize logger: %v", err))
|
||||
}
|
||||
defer appLogger.Sync()
|
||||
|
||||
appLogger.Info("Starting OpenAI-compatible server",
|
||||
zap.String("endpoint", cfg.Endpoint),
|
||||
zap.String("tokenizer", cfg.TokenizerPath),
|
||||
zap.String("port", cfg.Port),
|
||||
)
|
||||
|
||||
// Initialize SGLang service
|
||||
sglangService, err := service.NewSGLangService(cfg.Endpoint, cfg.TokenizerPath)
|
||||
if err != nil {
|
||||
appLogger.Fatal("Failed to create SGLang client", zap.Error(err))
|
||||
}
|
||||
defer sglangService.Close()
|
||||
|
||||
appLogger.Info("SGLang client created successfully")
|
||||
|
||||
// Enable pprof if requested
|
||||
if os.Getenv("PPROF_ENABLED") == "true" {
|
||||
pprofPort := os.Getenv("PPROF_PORT")
|
||||
if pprofPort == "" {
|
||||
pprofPort = "6060"
|
||||
}
|
||||
go func() {
|
||||
pprofAddr := ":" + pprofPort
|
||||
appLogger.Info("Starting pprof server", zap.String("address", pprofAddr))
|
||||
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
|
||||
appLogger.Error("pprof server failed", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
appLogger.Info("pprof enabled", zap.String("port", pprofPort), zap.String("endpoint", fmt.Sprintf("http://localhost:%s/debug/pprof/", pprofPort)))
|
||||
}
|
||||
|
||||
// Initialize handlers
|
||||
healthHandler := handlers.NewHealthHandler(appLogger)
|
||||
modelsHandler := handlers.NewModelsHandler(appLogger, cfg.TokenizerPath)
|
||||
chatHandler := handlers.NewChatHandler(appLogger, sglangService)
|
||||
|
||||
// Setup fasthttp router
|
||||
router := func(ctx *fasthttp.RequestCtx) {
|
||||
path := string(ctx.Path())
|
||||
method := string(ctx.Method())
|
||||
|
||||
switch {
|
||||
case method == "GET" && path == "/health":
|
||||
healthHandler.Check(ctx)
|
||||
case method == "GET" && path == "/v1/models":
|
||||
modelsHandler.List(ctx)
|
||||
case method == "GET" && path == "/get_model_info":
|
||||
modelsHandler.GetModelInfo(ctx)
|
||||
case method == "POST" && path == "/v1/chat/completions":
|
||||
chatHandler.HandleChatCompletion(ctx)
|
||||
case (method == "POST" || method == "PUT") && path == "/generate":
|
||||
chatHandler.HandleGenerate(ctx)
|
||||
default:
|
||||
ctx.Error("Not Found", fasthttp.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
// Start server
|
||||
serverAddr := ":" + cfg.Port
|
||||
baseURL := fmt.Sprintf("http://localhost:%s", cfg.Port)
|
||||
|
||||
appLogger.Info("Server starting",
|
||||
zap.String("address", serverAddr),
|
||||
zap.String("base_url", baseURL),
|
||||
)
|
||||
|
||||
// Print available HTTP endpoints (similar to FastAPI startup)
|
||||
appLogger.Info("Available HTTP endpoints:")
|
||||
appLogger.Info(fmt.Sprintf(" GET %s/health", baseURL))
|
||||
appLogger.Info(fmt.Sprintf(" GET %s/v1/models", baseURL))
|
||||
appLogger.Info(fmt.Sprintf(" GET %s/get_model_info", baseURL))
|
||||
appLogger.Info(fmt.Sprintf(" POST %s/v1/chat/completions", baseURL))
|
||||
appLogger.Info(fmt.Sprintf(" POST %s/generate", baseURL))
|
||||
appLogger.Info(fmt.Sprintf("Application startup complete. Listening on %s", baseURL))
|
||||
|
||||
if err := fasthttp.ListenAndServe(serverAddr, router); err != nil {
|
||||
appLogger.Fatal("Server failed", zap.Error(err))
|
||||
}
|
||||
}
|
||||
14
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/models/chat.go
vendored
Normal file
14
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/models/chat.go
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
package models
|
||||
|
||||
// ChatRequest represents an OpenAI-compatible chat completion request
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model" binding:"required"`
|
||||
Messages []map[string]string `json:"messages" binding:"required"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"` // OpenAI API standard field
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` // SGLang-specific field (used by bench_serving.py)
|
||||
Tools []map[string]interface{} `json:"tools,omitempty"`
|
||||
ToolChoice interface{} `json:"tool_choice,omitempty"`
|
||||
}
|
||||
111
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/run.sh
vendored
Executable file
111
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/run.sh
vendored
Executable file
@@ -0,0 +1,111 @@
|
||||
#!/bin/bash
|
||||
|
||||
# OpenAI-compatible server runner
|
||||
# Usage: ./run.sh [tokenizer_path] [endpoint] [port] [--profile] [--pprof-port PORT]
|
||||
#
|
||||
# Options:
|
||||
# --profile Enable pprof profiling (default port: 6060)
|
||||
# --pprof-port PORT Set pprof port (default: 6060, requires --profile)
|
||||
|
||||
# Set library path for Rust FFI library
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
BINDINGS_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)"
|
||||
LIB_DIR="${BINDINGS_DIR}/lib"
|
||||
|
||||
if [ ! -d "$LIB_DIR" ]; then
|
||||
echo "Error: Library directory not found at $LIB_DIR"
|
||||
echo "Please run 'make lib' first to build and export the library"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get Python LDFLAGS (needed for Rust FFI that depends on Python)
|
||||
PYTHON_LDFLAGS=$(python3-config --ldflags --embed 2>/dev/null || python3-config --ldflags 2>/dev/null || echo "")
|
||||
|
||||
# Set CGO_LDFLAGS to link with the Rust library
|
||||
# Note: -lsgl_model_gateway_go and -ldl are already in the #cgo directive in internal/ffi/client.go
|
||||
# We only need to add the library path (-L) and Python flags
|
||||
export CGO_LDFLAGS="-L${LIB_DIR} ${PYTHON_LDFLAGS}"
|
||||
|
||||
# macOS uses DYLD_LIBRARY_PATH, Linux uses LD_LIBRARY_PATH
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
export DYLD_LIBRARY_PATH="${LIB_DIR}:${DYLD_LIBRARY_PATH}"
|
||||
else
|
||||
export LD_LIBRARY_PATH="${LIB_DIR}:${LD_LIBRARY_PATH}"
|
||||
fi
|
||||
|
||||
# Parse arguments
|
||||
ENABLE_PROFILE=false
|
||||
PPROF_PORT="6060"
|
||||
TOKENIZER_PATH=""
|
||||
ENDPOINT=""
|
||||
PORT=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--profile)
|
||||
ENABLE_PROFILE=true
|
||||
shift
|
||||
;;
|
||||
--pprof-port)
|
||||
ENABLE_PROFILE=true
|
||||
PPROF_PORT="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
if [[ -z "$TOKENIZER_PATH" ]]; then
|
||||
TOKENIZER_PATH="$1"
|
||||
elif [[ -z "$ENDPOINT" ]]; then
|
||||
ENDPOINT="$1"
|
||||
elif [[ -z "$PORT" ]]; then
|
||||
PORT="$1"
|
||||
fi
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Default configuration
|
||||
DEFAULT_TOKENIZER_PATH="${SGL_TOKENIZER_PATH:-../tokenizer}"
|
||||
DEFAULT_ENDPOINT="${SGL_GRPC_ENDPOINT:-grpc://localhost:20000}"
|
||||
DEFAULT_PORT="${PORT:-8080}"
|
||||
|
||||
TOKENIZER_PATH="${TOKENIZER_PATH:-${DEFAULT_TOKENIZER_PATH}}"
|
||||
ENDPOINT="${ENDPOINT:-${DEFAULT_ENDPOINT}}"
|
||||
PORT="${PORT:-${DEFAULT_PORT}}"
|
||||
|
||||
echo "Running OpenAI-compatible server..."
|
||||
echo "Library path: ${LIB_DIR}"
|
||||
echo "Tokenizer: $TOKENIZER_PATH"
|
||||
echo "Endpoint: $ENDPOINT"
|
||||
echo "Port: $PORT"
|
||||
echo "Client Mode: gRPC (default)"
|
||||
echo "FFI Postprocessing: ENABLED (normal mode)"
|
||||
echo "FFI Preprocessing: ENABLED (normal mode)"
|
||||
if [[ "$ENABLE_PROFILE" == "true" ]]; then
|
||||
echo "Profiling: enabled (port: $PPROF_PORT)"
|
||||
echo " pprof endpoint: http://localhost:$PPROF_PORT/debug/pprof/"
|
||||
export PPROF_ENABLED=true
|
||||
export PPROF_PORT="$PPROF_PORT"
|
||||
else
|
||||
echo "Profiling: disabled"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Change to script directory
|
||||
cd "$(dirname "${BASH_SOURCE[0]}")"
|
||||
|
||||
# Ensure Go module is properly initialized
|
||||
if [ ! -f "go.mod" ]; then
|
||||
echo "Error: go.mod not found in $(pwd)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Ensure Go modules are enabled
|
||||
export GO111MODULE=on
|
||||
|
||||
# Sync Go module dependencies
|
||||
echo "Syncing Go module dependencies..."
|
||||
go mod tidy
|
||||
|
||||
# Run the server (use ./main.go to ensure module context is correct)
|
||||
SGL_TOKENIZER_PATH="$TOKENIZER_PATH" SGL_GRPC_ENDPOINT="$ENDPOINT" PORT="$PORT" go run ./main.go
|
||||
554
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/analyze_tpot.sh
vendored
Executable file
554
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/analyze_tpot.sh
vendored
Executable file
@@ -0,0 +1,554 @@
|
||||
#!/bin/bash
|
||||
|
||||
# TPOT performance bottleneck analysis script
|
||||
# Specifically designed to analyze why Go Router is twice as slow as Rust Router
|
||||
#
|
||||
# Usage:
|
||||
# ./scripts/analyze_tpot.sh [options]
|
||||
#
|
||||
# Options:
|
||||
# --duration SECONDS CPU profile duration (default: 60)
|
||||
# --requests NUM Number of requests (default: 100)
|
||||
# --concurrency NUM Concurrency level (default: 20)
|
||||
# --pprof-port PORT pprof port (default: 6060)
|
||||
# --server-url URL Server URL (default: http://localhost:8080)
|
||||
|
||||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
||||
PROFILE_DIR="${PROJECT_ROOT}/profiles"
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
OUTPUT_DIR="${PROFILE_DIR}/tpot_analysis_${TIMESTAMP}"
|
||||
|
||||
# Colors
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
RED='\033[0;31m'
|
||||
BLUE='\033[0;34m'
|
||||
NC='\033[0m'
|
||||
|
||||
# Default values
|
||||
DURATION=${DURATION:-60}
|
||||
NUM_REQUESTS=${NUM_REQUESTS:-100}
|
||||
CONCURRENCY=${CONCURRENCY:-20}
|
||||
PPROF_PORT=${PPROF_PORT:-6060}
|
||||
SERVER_URL=${SERVER_URL:-http://localhost:8080}
|
||||
|
||||
# Parse arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--duration)
|
||||
DURATION="$2"
|
||||
shift 2
|
||||
;;
|
||||
--requests)
|
||||
NUM_REQUESTS="$2"
|
||||
shift 2
|
||||
;;
|
||||
--concurrency)
|
||||
CONCURRENCY="$2"
|
||||
shift 2
|
||||
;;
|
||||
--pprof-port)
|
||||
PPROF_PORT="$2"
|
||||
shift 2
|
||||
;;
|
||||
--server-url)
|
||||
SERVER_URL="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
# Check for graphviz (optional, needed for some pprof visualizations)
|
||||
HAS_GRAPHVIZ=false
|
||||
if command -v dot >/dev/null 2>&1; then
|
||||
HAS_GRAPHVIZ=true
|
||||
fi
|
||||
|
||||
echo -e "${BLUE}========================================${NC}"
|
||||
echo -e "${BLUE}TPOT Performance Bottleneck Analysis${NC}"
|
||||
echo -e "${BLUE}========================================${NC}"
|
||||
echo ""
|
||||
echo "Configuration:"
|
||||
echo " Duration: ${DURATION}s"
|
||||
echo " Requests: $NUM_REQUESTS"
|
||||
echo " Concurrency: $CONCURRENCY"
|
||||
echo " Server URL: $SERVER_URL"
|
||||
echo " pprof Port: $PPROF_PORT"
|
||||
echo " Output Dir: $OUTPUT_DIR"
|
||||
if [ "$HAS_GRAPHVIZ" = "false" ]; then
|
||||
echo ""
|
||||
echo -e "${YELLOW}Note: graphviz not found. Some pprof visualizations may not work.${NC}"
|
||||
echo -e "${YELLOW}To install graphviz:${NC}"
|
||||
echo -e "${YELLOW} macOS: brew install graphviz${NC}"
|
||||
echo -e "${YELLOW} Ubuntu: sudo apt-get install graphviz${NC}"
|
||||
echo -e "${YELLOW} CentOS: sudo yum install graphviz${NC}"
|
||||
echo -e "${YELLOW}Text reports will still be generated without graphviz.${NC}"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Check if server is running
|
||||
echo -e "${YELLOW}[Check] Verifying server is running...${NC}"
|
||||
if ! curl -s "${SERVER_URL}/health" > /dev/null 2>&1; then
|
||||
echo -e "${RED}Error: Server not responding at ${SERVER_URL}${NC}"
|
||||
echo ""
|
||||
echo "Please start the server first with profiling enabled:"
|
||||
echo " ./run.sh --profile --pprof-port $PPROF_PORT"
|
||||
echo " or"
|
||||
echo " PPROF_ENABLED=true PPROF_PORT=$PPROF_PORT make run"
|
||||
exit 1
|
||||
fi
|
||||
echo -e "${GREEN}✓ Server is running${NC}"
|
||||
echo ""
|
||||
|
||||
# Check if pprof is enabled
|
||||
echo -e "${YELLOW}[Check] Verifying pprof is enabled...${NC}"
|
||||
if ! curl -s "http://localhost:${PPROF_PORT}/debug/pprof/" > /dev/null 2>&1; then
|
||||
echo -e "${RED}Error: pprof not accessible at http://localhost:${PPROF_PORT}/debug/pprof/${NC}"
|
||||
echo ""
|
||||
echo "Please start the server with profiling enabled:"
|
||||
echo " ./run.sh --profile --pprof-port $PPROF_PORT"
|
||||
exit 1
|
||||
fi
|
||||
echo -e "${GREEN}✓ pprof is enabled${NC}"
|
||||
echo ""
|
||||
|
||||
# ============================================
|
||||
# Step 1: Collect baseline profiles
|
||||
# ============================================
|
||||
echo -e "${GREEN}[Step 1/8] Collecting baseline profiles...${NC}"
|
||||
|
||||
# Baseline memory
|
||||
go tool pprof -proto -output="${OUTPUT_DIR}/heap_before.pb.gz" \
|
||||
"http://localhost:${PPROF_PORT}/debug/pprof/heap" > /dev/null 2>&1 || true
|
||||
|
||||
# Baseline goroutine
|
||||
go tool pprof -proto -output="${OUTPUT_DIR}/goroutine_before.pb.gz" \
|
||||
"http://localhost:${PPROF_PORT}/debug/pprof/goroutine" > /dev/null 2>&1 || true
|
||||
|
||||
echo -e "${GREEN}✓ Baseline profiles collected${NC}"
|
||||
echo ""
|
||||
|
||||
# ============================================
|
||||
# Step 2: Start CPU profile collection
|
||||
# ============================================
|
||||
echo -e "${GREEN}[Step 2/8] Starting CPU profile collection (${DURATION}s)...${NC}"
|
||||
go tool pprof -proto -output="${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" \
|
||||
"http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=${DURATION}" &
|
||||
CPU_PID=$!
|
||||
sleep 2
|
||||
echo -e "${GREEN}✓ CPU profile collection started${NC}"
|
||||
echo ""
|
||||
|
||||
# ============================================
|
||||
# Step 3: Run load test with streaming requests
|
||||
# ============================================
|
||||
echo -e "${GREEN}[Step 3/8] Running load test ($NUM_REQUESTS streaming requests, concurrency=$CONCURRENCY)...${NC}"
|
||||
|
||||
# Function to run a single streaming request
|
||||
run_streaming_request() {
|
||||
local request_id=$1
|
||||
local start_time=$(date +%s)
|
||||
local start_nanos=$(date +%N 2>/dev/null || echo "000000000")
|
||||
|
||||
curl -N -s -X POST "${SERVER_URL}/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{
|
||||
\"model\": \"default\",
|
||||
\"messages\": [{\"role\": \"user\", \"content\": \"Write a 500-word story with character dialogue and scene descriptions\"}],
|
||||
\"stream\": true,
|
||||
\"max_tokens\": 300,
|
||||
\"temperature\": 0.7
|
||||
}" > /dev/null
|
||||
|
||||
local end_time=$(date +%s)
|
||||
local end_nanos=$(date +%N 2>/dev/null || echo "000000000")
|
||||
local duration=$((end_time - start_time))
|
||||
echo "$duration" >> "${OUTPUT_DIR}/request_times.txt"
|
||||
}
|
||||
|
||||
# Run requests with controlled concurrency
|
||||
# Use a temporary file to track job PIDs to avoid conflicts with CPU_PID
|
||||
JOB_PIDS_FILE="${OUTPUT_DIR}/.job_pids_$$"
|
||||
> "$JOB_PIDS_FILE"
|
||||
|
||||
for i in $(seq 1 $NUM_REQUESTS); do
|
||||
# Wait if we've reached concurrency limit
|
||||
while [ $(wc -l < "$JOB_PIDS_FILE" 2>/dev/null || echo 0) -ge $CONCURRENCY ]; do
|
||||
# Check and remove completed jobs
|
||||
while IFS= read -r pid; do
|
||||
if [ -n "$pid" ] && ! kill -0 "$pid" 2>/dev/null; then
|
||||
# Process completed, remove from file
|
||||
grep -v "^${pid}$" "$JOB_PIDS_FILE" > "${JOB_PIDS_FILE}.tmp" && \
|
||||
mv "${JOB_PIDS_FILE}.tmp" "$JOB_PIDS_FILE" || true
|
||||
fi
|
||||
done < "$JOB_PIDS_FILE"
|
||||
sleep 0.1
|
||||
done
|
||||
|
||||
# Start new request
|
||||
run_streaming_request $i &
|
||||
echo $! >> "$JOB_PIDS_FILE"
|
||||
|
||||
# Progress indicator
|
||||
if [ $((i % 10)) -eq 0 ]; then
|
||||
echo " Progress: $i/$NUM_REQUESTS requests sent..."
|
||||
fi
|
||||
done
|
||||
|
||||
# Wait for all remaining jobs (excluding CPU_PID)
|
||||
while IFS= read -r pid; do
|
||||
if [ -n "$pid" ] && [ "$pid" != "$CPU_PID" ]; then
|
||||
wait "$pid" 2>/dev/null || true
|
||||
fi
|
||||
done < "$JOB_PIDS_FILE"
|
||||
|
||||
# Clean up
|
||||
rm -f "$JOB_PIDS_FILE" "${JOB_PIDS_FILE}.tmp" 2>/dev/null || true
|
||||
|
||||
echo -e "${GREEN}✓ Load test completed${NC}"
|
||||
echo ""
|
||||
|
||||
# ============================================
|
||||
# Step 4: Wait for CPU profile to complete
|
||||
# ============================================
|
||||
echo -e "${GREEN}[Step 4/8] Waiting for CPU profile to complete...${NC}"
|
||||
# Wait for the process, but handle the case where it might have already completed
|
||||
if kill -0 $CPU_PID 2>/dev/null; then
|
||||
wait $CPU_PID 2>/dev/null || true
|
||||
else
|
||||
# Process already completed, just wait a bit to ensure file is written
|
||||
sleep 1
|
||||
fi
|
||||
echo -e "${GREEN}✓ CPU profile collection completed${NC}"
|
||||
echo ""
|
||||
|
||||
# ============================================
|
||||
# Step 5: Collect final profiles
|
||||
# ============================================
|
||||
echo -e "${GREEN}[Step 5/8] Collecting final profiles...${NC}"
|
||||
|
||||
# Final memory
|
||||
go tool pprof -proto -output="${OUTPUT_DIR}/heap_after.pb.gz" \
|
||||
"http://localhost:${PPROF_PORT}/debug/pprof/heap" > /dev/null 2>&1 || true
|
||||
|
||||
# Final goroutine
|
||||
go tool pprof -proto -output="${OUTPUT_DIR}/goroutine_after.pb.gz" \
|
||||
"http://localhost:${PPROF_PORT}/debug/pprof/goroutine" > /dev/null 2>&1 || true
|
||||
|
||||
# Mutex profile
|
||||
go tool pprof -proto -output="${OUTPUT_DIR}/mutex.pb.gz" \
|
||||
"http://localhost:${PPROF_PORT}/debug/pprof/mutex" > /dev/null 2>&1 || true
|
||||
|
||||
# Block profile
|
||||
go tool pprof -proto -output="${OUTPUT_DIR}/block.pb.gz" \
|
||||
"http://localhost:${PPROF_PORT}/debug/pprof/block" > /dev/null 2>&1 || true
|
||||
|
||||
echo -e "${GREEN}✓ Final profiles collected${NC}"
|
||||
echo ""
|
||||
|
||||
# ============================================
|
||||
# Step 6: Generate analysis reports
|
||||
# ============================================
|
||||
echo -e "${GREEN}[Step 6/8] Generating analysis reports...${NC}"
|
||||
|
||||
# CPU analysis
|
||||
echo " Generating CPU reports..."
|
||||
go tool pprof -top -cum "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" > "${OUTPUT_DIR}/01_cpu_top_cum.txt" 2>&1 || true
|
||||
go tool pprof -top "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" > "${OUTPUT_DIR}/02_cpu_top_flat.txt" 2>&1 || true
|
||||
|
||||
# Memory analysis
|
||||
echo " Generating memory reports..."
|
||||
if [ -f "${OUTPUT_DIR}/heap_after.pb.gz" ]; then
|
||||
go tool pprof -top -alloc_space "${OUTPUT_DIR}/heap_after.pb.gz" > "${OUTPUT_DIR}/03_memory_alloc_space.txt" 2>&1 || true
|
||||
go tool pprof -top -alloc_objects "${OUTPUT_DIR}/heap_after.pb.gz" > "${OUTPUT_DIR}/04_memory_alloc_objects.txt" 2>&1 || true
|
||||
go tool pprof -top -inuse_space "${OUTPUT_DIR}/heap_after.pb.gz" > "${OUTPUT_DIR}/05_memory_inuse_space.txt" 2>&1 || true
|
||||
fi
|
||||
|
||||
# Memory growth
|
||||
if [ -f "${OUTPUT_DIR}/heap_before.pb.gz" ] && [ -f "${OUTPUT_DIR}/heap_after.pb.gz" ]; then
|
||||
go tool pprof -top -base="${OUTPUT_DIR}/heap_before.pb.gz" \
|
||||
"${OUTPUT_DIR}/heap_after.pb.gz" > "${OUTPUT_DIR}/06_memory_growth.txt" 2>&1 || true
|
||||
fi
|
||||
|
||||
# FFI/CGO analysis
|
||||
echo " Analyzing FFI/CGO calls..."
|
||||
go tool pprof -top "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" 2>&1 | \
|
||||
grep -iE "(block_on|CGO|FFI|ffi|runtime\.cgo|_Cfunc)" > "${OUTPUT_DIR}/07_ffi_cgo_analysis.txt" || \
|
||||
echo "No FFI/CGO related functions found" > "${OUTPUT_DIR}/07_ffi_cgo_analysis.txt"
|
||||
|
||||
# JSON serialization analysis
|
||||
echo " Analyzing JSON serialization..."
|
||||
go tool pprof -top "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" 2>&1 | \
|
||||
grep -iE "(json|Marshal|Unmarshal|Encode|Decode|sonic|jsoniter)" > "${OUTPUT_DIR}/08_json_analysis.txt" || \
|
||||
echo "No JSON related functions found" > "${OUTPUT_DIR}/08_json_analysis.txt"
|
||||
|
||||
# Goroutine analysis
|
||||
if [ -f "${OUTPUT_DIR}/goroutine_after.pb.gz" ]; then
|
||||
echo " Analyzing goroutines..."
|
||||
go tool pprof -top "${OUTPUT_DIR}/goroutine_after.pb.gz" > "${OUTPUT_DIR}/09_goroutine_analysis.txt" 2>&1 || true
|
||||
fi
|
||||
|
||||
# Mutex analysis
|
||||
if [ -f "${OUTPUT_DIR}/mutex.pb.gz" ]; then
|
||||
echo " Analyzing mutex contention..."
|
||||
go tool pprof -top "${OUTPUT_DIR}/mutex.pb.gz" > "${OUTPUT_DIR}/10_mutex_analysis.txt" 2>&1 || true
|
||||
fi
|
||||
|
||||
# Block analysis
|
||||
if [ -f "${OUTPUT_DIR}/block.pb.gz" ]; then
|
||||
echo " Analyzing blocking operations..."
|
||||
go tool pprof -top "${OUTPUT_DIR}/block.pb.gz" > "${OUTPUT_DIR}/11_block_analysis.txt" 2>&1 || true
|
||||
fi
|
||||
|
||||
# Request timing statistics
|
||||
if [ -f "${OUTPUT_DIR}/request_times.txt" ] && [ -s "${OUTPUT_DIR}/request_times.txt" ]; then
|
||||
echo " Calculating request timing statistics..."
|
||||
{
|
||||
echo "Request Timing Statistics"
|
||||
echo "========================"
|
||||
echo ""
|
||||
echo "Total requests: $(wc -l < "${OUTPUT_DIR}/request_times.txt" | tr -d ' ')"
|
||||
echo ""
|
||||
awk '{
|
||||
sum+=$1
|
||||
sumsq+=$1*$1
|
||||
if(NR==1 || $1<min) min=$1
|
||||
if(NR==1 || $1>max) max=$1
|
||||
} END {
|
||||
if(NR > 0) {
|
||||
mean=sum/NR
|
||||
variance=(sumsq/NR - mean*mean)
|
||||
stddev=sqrt(variance)
|
||||
print "Min: " min "s"
|
||||
print "Max: " max "s"
|
||||
print "Mean: " mean "s"
|
||||
print "StdDev: " stddev "s"
|
||||
}
|
||||
}' "${OUTPUT_DIR}/request_times.txt"
|
||||
} > "${OUTPUT_DIR}/12_request_timing.txt"
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}✓ Analysis reports generated${NC}"
|
||||
echo ""
|
||||
|
||||
# ============================================
|
||||
# Step 7: Generate summary report
|
||||
# ============================================
|
||||
echo -e "${GREEN}[Step 7/8] Generating summary report...${NC}"
|
||||
|
||||
SUMMARY_FILE="${OUTPUT_DIR}/00_SUMMARY.md"
|
||||
cat > "$SUMMARY_FILE" <<EOF
|
||||
# TPOT Performance Analysis Summary
|
||||
|
||||
**Analysis Date:** $(date)
|
||||
**Duration:** ${DURATION}s
|
||||
**Requests:** $NUM_REQUESTS
|
||||
**Concurrency:** $CONCURRENCY
|
||||
|
||||
## Key Findings
|
||||
|
||||
### 1. CPU Hotspots (Top 10 Cumulative Time)
|
||||
|
||||
\`\`\`
|
||||
$(head -15 "${OUTPUT_DIR}/01_cpu_top_cum.txt" | tail -10)
|
||||
\`\`\`
|
||||
|
||||
### 2. CPU Hotspots (Top 10 Flat Time)
|
||||
|
||||
\`\`\`
|
||||
$(head -15 "${OUTPUT_DIR}/02_cpu_top_flat.txt" | tail -10)
|
||||
\`\`\`
|
||||
|
||||
### 3. FFI/CGO Overhead
|
||||
|
||||
\`\`\`
|
||||
$(cat "${OUTPUT_DIR}/07_ffi_cgo_analysis.txt")
|
||||
\`\`\`
|
||||
|
||||
### 4. JSON Serialization Overhead
|
||||
|
||||
\`\`\`
|
||||
$(cat "${OUTPUT_DIR}/08_json_analysis.txt")
|
||||
\`\`\`
|
||||
|
||||
### 5. Memory Allocation (Top 10 by Space)
|
||||
|
||||
\`\`\`
|
||||
$(head -15 "${OUTPUT_DIR}/03_memory_alloc_space.txt" | tail -10)
|
||||
\`\`\`
|
||||
|
||||
### 6. Memory Allocation (Top 10 by Objects)
|
||||
|
||||
\`\`\`
|
||||
$(head -15 "${OUTPUT_DIR}/04_memory_alloc_objects.txt" | tail -10)
|
||||
\`\`\`
|
||||
|
||||
### 7. Mutex Contention
|
||||
|
||||
\`\`\`
|
||||
$(head -15 "${OUTPUT_DIR}/10_mutex_analysis.txt" | tail -10 2>/dev/null || echo "No significant mutex contention detected")
|
||||
\`\`\`
|
||||
|
||||
### 8. Blocking Operations
|
||||
|
||||
\`\`\`
|
||||
$(head -15 "${OUTPUT_DIR}/11_block_analysis.txt" | tail -10 2>/dev/null || echo "No significant blocking detected")
|
||||
\`\`\`
|
||||
|
||||
## Performance Bottlenecks Identified
|
||||
|
||||
### High Priority Issues
|
||||
|
||||
1. **FFI/CGO Overhead**
|
||||
- Check: \`cat ${OUTPUT_DIR}/07_ffi_cgo_analysis.txt\`
|
||||
- Impact: FFI calls add overhead compared to native Rust code
|
||||
- Recommendation: Minimize FFI calls, batch operations
|
||||
|
||||
2. **JSON Serialization**
|
||||
- Check: \`cat ${OUTPUT_DIR}/08_json_analysis.txt\`
|
||||
- Impact: JSON marshaling/unmarshaling can be expensive
|
||||
- Recommendation: Use faster JSON library (jsoniter), reduce serialization frequency
|
||||
|
||||
3. **Memory Allocations**
|
||||
- Check: \`cat ${OUTPUT_DIR}/03_memory_alloc_space.txt\`
|
||||
- Impact: Frequent allocations cause GC pressure
|
||||
- Recommendation: Use object pools, pre-allocate buffers
|
||||
|
||||
### Medium Priority Issues
|
||||
|
||||
4. **Goroutine Overhead**
|
||||
- Check: \`cat ${OUTPUT_DIR}/09_goroutine_analysis.txt\`
|
||||
- Impact: Too many goroutines can cause scheduling overhead
|
||||
- Recommendation: Limit goroutine count, use worker pools
|
||||
|
||||
5. **Lock Contention**
|
||||
- Check: \`cat ${OUTPUT_DIR}/10_mutex_analysis.txt\`
|
||||
- Impact: Lock contention reduces parallelism
|
||||
- Recommendation: Reduce lock granularity, use lock-free structures
|
||||
|
||||
## Comparison with Rust Router
|
||||
|
||||
### Expected Differences
|
||||
|
||||
1. **FFI Overhead**: Go → Rust FFI calls add ~100-500ns per call
|
||||
2. **GC Overhead**: Go's GC can cause pauses (usually <1ms)
|
||||
3. **JSON Library**: Go's standard library is slower than Rust's serde
|
||||
4. **Memory Layout**: Go's GC affects cache locality
|
||||
|
||||
### Optimization Opportunities
|
||||
|
||||
1. **Reduce FFI Calls**
|
||||
- Batch token processing
|
||||
- Use async FFI (if possible)
|
||||
- Cache frequently used FFI results
|
||||
|
||||
2. **Optimize JSON**
|
||||
- Use jsoniter (already implemented)
|
||||
- Pre-allocate JSON buffers
|
||||
- Reduce serialization frequency
|
||||
|
||||
3. **Memory Management**
|
||||
- Use sync.Pool for frequently allocated objects
|
||||
- Pre-allocate slices with known capacity
|
||||
- Avoid unnecessary string copies
|
||||
|
||||
4. **Concurrency**
|
||||
- Use worker pools instead of spawning goroutines per request
|
||||
- Limit concurrent FFI calls
|
||||
- Use channels efficiently
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Review detailed reports in this directory
|
||||
2. Use interactive pprof: \`go tool pprof -http=:8081 ${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz\`
|
||||
3. Compare with Rust router profiles (if available)
|
||||
4. Implement optimizations based on findings
|
||||
5. Re-run analysis to measure improvements
|
||||
|
||||
## Files Generated
|
||||
|
||||
- \`00_SUMMARY.md\` - This summary
|
||||
- \`01_cpu_top_cum.txt\` - CPU top functions (cumulative)
|
||||
- \`02_cpu_top_flat.txt\` - CPU top functions (flat)
|
||||
- \`03_memory_alloc_space.txt\` - Memory allocation by space
|
||||
- \`04_memory_alloc_objects.txt\` - Memory allocation by objects
|
||||
- \`05_memory_inuse_space.txt\` - Memory in use by space
|
||||
- \`06_memory_growth.txt\` - Memory growth during test
|
||||
- \`07_ffi_cgo_analysis.txt\` - FFI/CGO overhead analysis
|
||||
- \`08_json_analysis.txt\` - JSON serialization analysis
|
||||
- \`09_goroutine_analysis.txt\` - Goroutine analysis
|
||||
- \`10_mutex_analysis.txt\` - Mutex contention analysis
|
||||
- \`11_block_analysis.txt\` - Blocking operations analysis
|
||||
- \`12_request_timing.txt\` - Request timing statistics
|
||||
- \`*.pb.gz\` - Raw profile files for interactive analysis
|
||||
|
||||
EOF
|
||||
|
||||
echo -e "${GREEN}✓ Summary report generated${NC}"
|
||||
echo ""
|
||||
|
||||
# ============================================
|
||||
# Step 8: Display summary
|
||||
# ============================================
|
||||
echo -e "${GREEN}[Step 8/8] Analysis Complete!${NC}"
|
||||
echo ""
|
||||
echo -e "${BLUE}========================================${NC}"
|
||||
echo -e "${BLUE}Summary${NC}"
|
||||
echo -e "${BLUE}========================================${NC}"
|
||||
echo ""
|
||||
echo -e "${YELLOW}Top CPU Hotspots (Cumulative):${NC}"
|
||||
head -12 "${OUTPUT_DIR}/01_cpu_top_cum.txt" | tail -10
|
||||
echo ""
|
||||
echo -e "${YELLOW}FFI/CGO Overhead:${NC}"
|
||||
cat "${OUTPUT_DIR}/07_ffi_cgo_analysis.txt"
|
||||
echo ""
|
||||
echo -e "${YELLOW}JSON Serialization Overhead:${NC}"
|
||||
cat "${OUTPUT_DIR}/08_json_analysis.txt"
|
||||
echo ""
|
||||
echo -e "${YELLOW}Top Memory Allocations:${NC}"
|
||||
head -12 "${OUTPUT_DIR}/03_memory_alloc_space.txt" | tail -10
|
||||
echo ""
|
||||
if [ -f "${OUTPUT_DIR}/12_request_timing.txt" ]; then
|
||||
echo -e "${YELLOW}Request Timing:${NC}"
|
||||
cat "${OUTPUT_DIR}/12_request_timing.txt"
|
||||
echo ""
|
||||
fi
|
||||
echo -e "${GREEN}========================================${NC}"
|
||||
echo ""
|
||||
echo -e "${BLUE}Detailed Reports:${NC}"
|
||||
echo " Summary: cat ${OUTPUT_DIR}/00_SUMMARY.md"
|
||||
echo " CPU (cum): cat ${OUTPUT_DIR}/01_cpu_top_cum.txt"
|
||||
echo " CPU (flat): cat ${OUTPUT_DIR}/02_cpu_top_flat.txt"
|
||||
echo " FFI/CGO: cat ${OUTPUT_DIR}/07_ffi_cgo_analysis.txt"
|
||||
echo " JSON: cat ${OUTPUT_DIR}/08_json_analysis.txt"
|
||||
echo " Memory: cat ${OUTPUT_DIR}/03_memory_alloc_space.txt"
|
||||
echo ""
|
||||
echo -e "${BLUE}Interactive Analysis:${NC}"
|
||||
echo " Run: go tool pprof -http=:8081 ${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz"
|
||||
echo " Then visit:"
|
||||
echo " - http://localhost:8081/ui/flamegraph (Flame Graph - no graphviz needed)"
|
||||
echo " - http://localhost:8081/ui/top (Top Functions - no graphviz needed)"
|
||||
if [ "$HAS_GRAPHVIZ" = "true" ]; then
|
||||
echo " - http://localhost:8081/ui/graph (Call Graph - requires graphviz)"
|
||||
else
|
||||
echo " - http://localhost:8081/ui/graph (Call Graph - requires graphviz, not available)"
|
||||
fi
|
||||
echo ""
|
||||
if [ "$HAS_GRAPHVIZ" = "false" ]; then
|
||||
echo -e "${YELLOW}Note: Install graphviz to enable call graph visualization:${NC}"
|
||||
echo -e "${YELLOW} macOS: brew install graphviz${NC}"
|
||||
echo -e "${YELLOW} Ubuntu: sudo apt-get install graphviz${NC}"
|
||||
echo -e "${YELLOW} CentOS: sudo yum install graphviz${NC}"
|
||||
echo ""
|
||||
fi
|
||||
echo -e "${GREEN}All files saved to: ${OUTPUT_DIR}${NC}"
|
||||
echo ""
|
||||
215
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/pprof_analysis.sh
vendored
Executable file
215
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/pprof_analysis.sh
vendored
Executable file
@@ -0,0 +1,215 @@
|
||||
#!/bin/bash
|
||||
|
||||
# pprof performance analysis script
|
||||
# Used to analyze performance bottlenecks of Go OpenAI server
|
||||
|
||||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
# Configuration
|
||||
PPROF_PORT=${PPROF_PORT:-6060}
|
||||
SERVER_PORT=${SERVER_PORT:-8080}
|
||||
DURATION=${DURATION:-60} # Performance test duration (seconds)
|
||||
OUTPUT_DIR="./pprof_results"
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
|
||||
# Create output directory
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
echo "=========================================="
|
||||
echo "pprof Performance Analysis Tool"
|
||||
echo "=========================================="
|
||||
echo "PPROF_PORT: $PPROF_PORT"
|
||||
echo "SERVER_PORT: $SERVER_PORT"
|
||||
echo "DURATION: ${DURATION}s"
|
||||
echo "OUTPUT_DIR: $OUTPUT_DIR"
|
||||
echo ""
|
||||
|
||||
# Check if go tool pprof is available
|
||||
if ! command -v go &> /dev/null; then
|
||||
echo "Error: go command not found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if server is running
|
||||
check_server() {
|
||||
if curl -s "http://localhost:${SERVER_PORT}/health" > /dev/null 2>&1; then
|
||||
return 0
|
||||
else
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Check if pprof is available
|
||||
check_pprof() {
|
||||
if curl -s "http://localhost:${PPROF_PORT}/debug/pprof/" > /dev/null 2>&1; then
|
||||
return 0
|
||||
else
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Start server (if not running)
|
||||
if ! check_server; then
|
||||
echo "Server not running, please start the server first:"
|
||||
echo " export PPROF_ENABLED=true"
|
||||
echo " export PPROF_PORT=$PPROF_PORT"
|
||||
echo " ./oai_server"
|
||||
echo ""
|
||||
echo "Or use the following command to start:"
|
||||
echo " PPROF_ENABLED=true PPROF_PORT=$PPROF_PORT ./oai_server"
|
||||
echo ""
|
||||
read -p "Start server now? (y/n) " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
||||
echo "Starting server..."
|
||||
PPROF_ENABLED=true PPROF_PORT=$PPROF_PORT ./oai_server &
|
||||
SERVER_PID=$!
|
||||
echo "Server PID: $SERVER_PID"
|
||||
|
||||
# Wait for server to start
|
||||
echo "Waiting for server to start..."
|
||||
for i in {1..30}; do
|
||||
if check_server; then
|
||||
echo "Server started"
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
if ! check_server; then
|
||||
echo "Error: Server failed to start"
|
||||
kill $SERVER_PID 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check if pprof is available
|
||||
if ! check_pprof; then
|
||||
echo "Error: pprof not enabled. Please set environment variables:"
|
||||
echo " export PPROF_ENABLED=true"
|
||||
echo " export PPROF_PORT=$PPROF_PORT"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Starting to collect performance data..."
|
||||
echo ""
|
||||
|
||||
# 1. CPU Profile (30 seconds)
|
||||
echo "[1/6] Collecting CPU Profile (30 seconds)..."
|
||||
go tool pprof -proto -output="$OUTPUT_DIR/cpu_${TIMESTAMP}.pb.gz" \
|
||||
"http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=30" &
|
||||
CPU_PID=$!
|
||||
|
||||
# 2. Collect Heap Profile simultaneously
|
||||
echo "[2/6] Collecting Heap Profile..."
|
||||
go tool pprof -proto -output="$OUTPUT_DIR/heap_${TIMESTAMP}.pb.gz" \
|
||||
"http://localhost:${PPROF_PORT}/debug/pprof/heap" &
|
||||
HEAP_PID=$!
|
||||
|
||||
# 3. Collect Goroutine Profile
|
||||
echo "[3/6] Collecting Goroutine Profile..."
|
||||
go tool pprof -proto -output="$OUTPUT_DIR/goroutine_${TIMESTAMP}.pb.gz" \
|
||||
"http://localhost:${PPROF_PORT}/debug/pprof/goroutine" &
|
||||
GOROUTINE_PID=$!
|
||||
|
||||
# 4. Collect Mutex Profile
|
||||
echo "[4/6] Collecting Mutex Profile..."
|
||||
go tool pprof -proto -output="$OUTPUT_DIR/mutex_${TIMESTAMP}.pb.gz" \
|
||||
"http://localhost:${PPROF_PORT}/debug/pprof/mutex" &
|
||||
MUTEX_PID=$!
|
||||
|
||||
# 5. Collect Block Profile
|
||||
echo "[5/6] Collecting Block Profile..."
|
||||
go tool pprof -proto -output="$OUTPUT_DIR/block_${TIMESTAMP}.pb.gz" \
|
||||
"http://localhost:${PPROF_PORT}/debug/pprof/block" &
|
||||
BLOCK_PID=$!
|
||||
|
||||
# 6. Run performance test (during CPU profile collection)
|
||||
echo "[6/6] Running performance test..."
|
||||
echo "Tip: Please use your performance testing tool (curl, ab, wrk, etc.) to send requests to the server"
|
||||
echo " CPU profile will collect 30 seconds of performance data"
|
||||
echo ""
|
||||
|
||||
# Wait for CPU profile to complete
|
||||
wait $CPU_PID
|
||||
echo "CPU Profile collection completed"
|
||||
|
||||
# Wait for other profiles
|
||||
wait $HEAP_PID
|
||||
wait $GOROUTINE_PID
|
||||
wait $MUTEX_PID
|
||||
wait $BLOCK_PID
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Performance data collection completed!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "Generated analysis files:"
|
||||
ls -lh "$OUTPUT_DIR"/*_${TIMESTAMP}.* 2>/dev/null || true
|
||||
echo ""
|
||||
|
||||
# Generate analysis report
|
||||
echo "Generating analysis report..."
|
||||
echo ""
|
||||
|
||||
# CPU Top 20
|
||||
echo "=== CPU Top 20 (sorted by flat time) ===" > "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
|
||||
go tool pprof -top -cum "$OUTPUT_DIR/cpu_${TIMESTAMP}.pb.gz" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" 2>&1 || true
|
||||
echo "" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
|
||||
|
||||
# Heap Top 20
|
||||
echo "=== Heap Top 20 (sorted by allocation size) ===" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
|
||||
go tool pprof -top "$OUTPUT_DIR/heap_${TIMESTAMP}.pb.gz" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" 2>&1 || true
|
||||
echo "" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
|
||||
|
||||
# Goroutine statistics
|
||||
echo "=== Goroutine Statistics ===" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
|
||||
go tool pprof -top "$OUTPUT_DIR/goroutine_${TIMESTAMP}.pb.gz" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" 2>&1 || true
|
||||
echo "" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
|
||||
|
||||
# Mutex statistics
|
||||
echo "=== Mutex Wait Time ===" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
|
||||
go tool pprof -top "$OUTPUT_DIR/mutex_${TIMESTAMP}.pb.gz" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" 2>&1 || true
|
||||
echo "" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
|
||||
|
||||
# Block statistics
|
||||
echo "=== Block Wait Time ===" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
|
||||
go tool pprof -top "$OUTPUT_DIR/block_${TIMESTAMP}.pb.gz" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" 2>&1 || true
|
||||
|
||||
echo "Analysis report saved to: $OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
|
||||
echo ""
|
||||
|
||||
# Display key information
|
||||
echo "=========================================="
|
||||
echo "Key Performance Metrics Summary"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "View detailed report:"
|
||||
echo " cat $OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
|
||||
echo ""
|
||||
echo "Interactive CPU Profile view:"
|
||||
echo " go tool pprof $OUTPUT_DIR/cpu_${TIMESTAMP}.pb.gz"
|
||||
echo ""
|
||||
echo "Interactive Heap Profile view:"
|
||||
echo " go tool pprof $OUTPUT_DIR/heap_${TIMESTAMP}.pb.gz"
|
||||
echo ""
|
||||
echo "Generate flame graph (requires go-torch or pprof):"
|
||||
echo " go tool pprof -http=:8080 $OUTPUT_DIR/cpu_${TIMESTAMP}.pb.gz"
|
||||
echo ""
|
||||
|
||||
# If server was started, ask if it should be closed
|
||||
if [ -n "$SERVER_PID" ]; then
|
||||
read -p "Close server? (y/n) " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
||||
kill $SERVER_PID 2>/dev/null || true
|
||||
echo "Server closed"
|
||||
fi
|
||||
fi
|
||||
52
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/pprof_quick.sh
vendored
Executable file
52
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/pprof_quick.sh
vendored
Executable file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Quick pprof analysis script
|
||||
# Collects 30-second CPU profile and immediately displays top results
|
||||
|
||||
set -e
|
||||
|
||||
PPROF_PORT=${PPROF_PORT:-6060}
|
||||
DURATION=${DURATION:-30}
|
||||
|
||||
echo "=========================================="
|
||||
echo "Quick pprof Analysis"
|
||||
echo "=========================================="
|
||||
echo "PPROF_PORT: $PPROF_PORT"
|
||||
echo "DURATION: ${DURATION}s"
|
||||
echo ""
|
||||
echo "Tip: During data collection, please send requests to the server"
|
||||
echo " You can use: ./pprof_test.sh"
|
||||
echo ""
|
||||
|
||||
# Check if pprof is available
|
||||
if ! curl -s "http://localhost:${PPROF_PORT}/debug/pprof/" > /dev/null 2>&1; then
|
||||
echo "Error: pprof not enabled. Please set environment variables:"
|
||||
echo " export PPROF_ENABLED=true"
|
||||
echo " export PPROF_PORT=$PPROF_PORT"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Starting to collect CPU Profile (${DURATION} seconds)..."
|
||||
echo ""
|
||||
|
||||
# Collect CPU profile and directly display top results
|
||||
go tool pprof -top -cum "http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=${DURATION}"
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Analysis Complete"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "More analysis options:"
|
||||
echo " # Interactive view"
|
||||
echo " go tool pprof http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=30"
|
||||
echo ""
|
||||
echo " # View heap memory"
|
||||
echo " go tool pprof http://localhost:${PPROF_PORT}/debug/pprof/heap"
|
||||
echo ""
|
||||
echo " # View goroutines"
|
||||
echo " go tool pprof http://localhost:${PPROF_PORT}/debug/pprof/goroutine"
|
||||
echo ""
|
||||
echo " # Generate Web UI"
|
||||
echo " go tool pprof -http=:8080 http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=30"
|
||||
echo ""
|
||||
87
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/pprof_test.sh
vendored
Executable file
87
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/pprof_test.sh
vendored
Executable file
@@ -0,0 +1,87 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Simple performance test script for sending requests while collecting pprof data
|
||||
|
||||
set -e
|
||||
|
||||
SERVER_URL=${SERVER_URL:-"http://localhost:8080"}
|
||||
DURATION=${DURATION:-30} # Test duration (seconds)
|
||||
CONCURRENT=${CONCURRENT:-1} # Number of concurrent requests
|
||||
|
||||
echo "=========================================="
|
||||
echo "Performance Test Script"
|
||||
echo "=========================================="
|
||||
echo "SERVER_URL: $SERVER_URL"
|
||||
echo "DURATION: ${DURATION}s"
|
||||
echo "CONCURRENT: $CONCURRENT"
|
||||
echo ""
|
||||
|
||||
# Test request JSON
|
||||
TEST_REQUEST='{
|
||||
"model": "default",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"}
|
||||
],
|
||||
"stream": true,
|
||||
"max_tokens": 100
|
||||
}'
|
||||
|
||||
# Check if server is available
|
||||
if ! curl -s "${SERVER_URL}/health" > /dev/null 2>&1; then
|
||||
echo "Error: Server not available (${SERVER_URL}/health)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Starting to send test requests..."
|
||||
echo ""
|
||||
|
||||
# Function to send streaming request
|
||||
send_stream_request() {
|
||||
local request_num=$1
|
||||
local start_time=$(date +%s.%N)
|
||||
|
||||
curl -s -N -X POST "${SERVER_URL}/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "$TEST_REQUEST" \
|
||||
> /dev/null 2>&1
|
||||
|
||||
local end_time=$(date +%s.%N)
|
||||
local duration=$(echo "$end_time - $start_time" | bc)
|
||||
echo "Request $request_num completed, duration: ${duration}s"
|
||||
}
|
||||
|
||||
# Send requests concurrently
|
||||
if [ "$CONCURRENT" -eq 1 ]; then
|
||||
# Single-threaded mode: continuously send requests
|
||||
end_time=$(($(date +%s) + DURATION))
|
||||
request_count=0
|
||||
|
||||
while [ $(date +%s) -lt $end_time ]; do
|
||||
request_count=$((request_count + 1))
|
||||
send_stream_request $request_count
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "Test completed, sent $request_count requests"
|
||||
else
|
||||
# Multi-threaded mode: send requests concurrently
|
||||
end_time=$(($(date +%s) + DURATION))
|
||||
request_count=0
|
||||
|
||||
while [ $(date +%s) -lt $end_time ]; do
|
||||
# Start concurrent requests
|
||||
for i in $(seq 1 $CONCURRENT); do
|
||||
request_count=$((request_count + 1))
|
||||
send_stream_request $request_count &
|
||||
done
|
||||
|
||||
# Wait for all requests to complete
|
||||
wait
|
||||
|
||||
# Brief rest to avoid overload
|
||||
sleep 0.1
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "Test completed, sent $request_count requests"
|
||||
fi
|
||||
140
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/profile_tpot.sh
vendored
Executable file
140
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/profile_tpot.sh
vendored
Executable file
@@ -0,0 +1,140 @@
|
||||
#!/bin/bash
|
||||
|
||||
# TPOT performance analysis script
|
||||
# Quickly collect and analyze TPOT-related performance data
|
||||
|
||||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
||||
PROFILE_DIR="${PROJECT_ROOT}/profiles"
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
OUTPUT_DIR="${PROFILE_DIR}/${TIMESTAMP}"
|
||||
|
||||
# Colors
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m'
|
||||
|
||||
# Default values
|
||||
PPROF_PORT=${PPROF_PORT:-6060}
|
||||
SERVER_URL=${SERVER_URL:-http://localhost:8080}
|
||||
DURATION=${DURATION:-30}
|
||||
NUM_REQUESTS=${NUM_REQUESTS:-20}
|
||||
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
echo -e "${GREEN}TPOT Performance Analysis${NC}"
|
||||
echo "=========================="
|
||||
echo "Profile directory: $OUTPUT_DIR"
|
||||
echo "Duration: ${DURATION}s"
|
||||
echo "Requests: $NUM_REQUESTS"
|
||||
echo ""
|
||||
|
||||
# Check if server is running
|
||||
if ! curl -s "${SERVER_URL}/health" > /dev/null 2>&1; then
|
||||
echo -e "${YELLOW}Warning: Server not responding at ${SERVER_URL}${NC}"
|
||||
echo "Please start the server first with profiling enabled:"
|
||||
echo " PPROF_ENABLED=true PPROF_PORT=$PPROF_PORT make run"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Collect baseline memory
|
||||
echo -e "${GREEN}[1/5] Collecting baseline memory profile...${NC}"
|
||||
go tool pprof -proto -output="${OUTPUT_DIR}/heap_before.pb.gz" \
|
||||
"http://localhost:${PPROF_PORT}/debug/pprof/heap" > /dev/null 2>&1 || true
|
||||
|
||||
# Start CPU profile collection in background
|
||||
echo -e "${GREEN}[2/5] Starting CPU profile collection (${DURATION}s)...${NC}"
|
||||
go tool pprof -proto -output="${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" \
|
||||
"http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=${DURATION}" &
|
||||
CPU_PID=$!
|
||||
|
||||
# Wait a bit for profile to start
|
||||
sleep 2
|
||||
|
||||
# Run load test
|
||||
echo -e "${GREEN}[3/5] Running load test ($NUM_REQUESTS requests)...${NC}"
|
||||
for i in $(seq 1 $NUM_REQUESTS); do
|
||||
curl -N -s -X POST "${SERVER_URL}/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{
|
||||
\"model\": \"default\",
|
||||
\"messages\": [{\"role\": \"user\", \"content\": \"Write a story\"}],
|
||||
\"stream\": true,
|
||||
\"max_tokens\": 200
|
||||
}" > /dev/null &
|
||||
|
||||
# Limit concurrency
|
||||
if [ $((i % 5)) -eq 0 ]; then
|
||||
wait
|
||||
fi
|
||||
done
|
||||
wait
|
||||
|
||||
# Wait for CPU profile to complete
|
||||
echo -e "${GREEN}[4/5] Waiting for CPU profile to complete...${NC}"
|
||||
# Wait for the CPU profile process, but handle the case where it's not a child process
|
||||
if kill -0 $CPU_PID 2>/dev/null; then
|
||||
# Process is still running, wait for it
|
||||
while kill -0 $CPU_PID 2>/dev/null; do
|
||||
sleep 1
|
||||
done
|
||||
else
|
||||
# Process already completed or not found, just wait a bit
|
||||
sleep 2
|
||||
fi
|
||||
|
||||
# Collect final memory
|
||||
echo -e "${GREEN}[5/5] Collecting final memory profile...${NC}"
|
||||
go tool pprof -proto -output="${OUTPUT_DIR}/heap_after.pb.gz" \
|
||||
"http://localhost:${PPROF_PORT}/debug/pprof/heap" > /dev/null 2>&1 || true
|
||||
|
||||
# Generate reports
|
||||
echo ""
|
||||
echo -e "${GREEN}Generating reports...${NC}"
|
||||
|
||||
# CPU top (cumulative)
|
||||
go tool pprof -top -cum "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" > "${OUTPUT_DIR}/cpu_top_cum.txt" 2>&1 || true
|
||||
|
||||
# CPU top (flat)
|
||||
go tool pprof -top "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" > "${OUTPUT_DIR}/cpu_top_flat.txt" 2>&1 || true
|
||||
|
||||
# Memory growth
|
||||
if [ -f "${OUTPUT_DIR}/heap_before.pb.gz" ] && [ -f "${OUTPUT_DIR}/heap_after.pb.gz" ]; then
|
||||
go tool pprof -top -base="${OUTPUT_DIR}/heap_before.pb.gz" \
|
||||
"${OUTPUT_DIR}/heap_after.pb.gz" > "${OUTPUT_DIR}/heap_growth.txt" 2>&1 || true
|
||||
fi
|
||||
|
||||
# FFI/CGO related
|
||||
go tool pprof -top "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" 2>&1 | \
|
||||
grep -E "(block_on|CGO|FFI|json|Marshal|Unmarshal)" > "${OUTPUT_DIR}/ffi_related.txt" || \
|
||||
echo "No FFI/CGO related functions found" > "${OUTPUT_DIR}/ffi_related.txt"
|
||||
|
||||
# Summary
|
||||
echo ""
|
||||
echo -e "${GREEN}=== Analysis Summary ===${NC}"
|
||||
echo ""
|
||||
echo -e "${YELLOW}CPU Top (Cumulative) - Top 10:${NC}"
|
||||
head -12 "${OUTPUT_DIR}/cpu_top_cum.txt" | tail -10 || true
|
||||
|
||||
echo ""
|
||||
echo -e "${YELLOW}CPU Top (Flat) - Top 10:${NC}"
|
||||
head -12 "${OUTPUT_DIR}/cpu_top_flat.txt" | tail -10 || true
|
||||
|
||||
echo ""
|
||||
echo -e "${YELLOW}FFI/CGO Related Functions:${NC}"
|
||||
cat "${OUTPUT_DIR}/ffi_related.txt" || true
|
||||
|
||||
echo ""
|
||||
echo -e "${GREEN}=== Detailed Reports ===${NC}"
|
||||
echo "CPU (cumulative): cat ${OUTPUT_DIR}/cpu_top_cum.txt"
|
||||
echo "CPU (flat): cat ${OUTPUT_DIR}/cpu_top_flat.txt"
|
||||
echo "Memory growth: cat ${OUTPUT_DIR}/heap_growth.txt"
|
||||
echo "FFI related: cat ${OUTPUT_DIR}/ffi_related.txt"
|
||||
echo ""
|
||||
echo -e "${GREEN}=== Interactive Analysis ===${NC}"
|
||||
echo "Run: go tool pprof -http=:8081 ${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz"
|
||||
echo "Then visit: http://localhost:8081/ui/flamegraph"
|
||||
echo ""
|
||||
echo "Profile files saved to: ${OUTPUT_DIR}"
|
||||
37
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/service/sglang.go
vendored
Normal file
37
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/service/sglang.go
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
sglang "github.com/sglang/sglang-go-grpc-sdk"
|
||||
)
|
||||
|
||||
// SGLangService wraps SGLang client
|
||||
type SGLangService struct {
|
||||
client *sglang.Client
|
||||
}
|
||||
|
||||
func NewSGLangService(endpoint, tokenizerPath string) (*SGLangService, error) {
|
||||
client, err := sglang.NewClient(sglang.ClientConfig{
|
||||
Endpoint: endpoint,
|
||||
TokenizerPath: tokenizerPath,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SGLangService{
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Client returns the underlying SGLang client
|
||||
func (s *SGLangService) Client() *sglang.Client {
|
||||
return s.client
|
||||
}
|
||||
|
||||
// Close closes the SGLang client
|
||||
func (s *SGLangService) Close() error {
|
||||
if s.client != nil {
|
||||
return s.client.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
34
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/utils/utils.go
vendored
Normal file
34
third_party/sglang/sgl-model-gateway/bindings/golang/examples/oai_server/utils/utils.go
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// RespondError sends an error response in OpenAI format
|
||||
func RespondError(ctx *fasthttp.RequestCtx, statusCode int, message, errorType string) {
|
||||
ctx.SetStatusCode(statusCode)
|
||||
ctx.SetContentType("application/json")
|
||||
|
||||
response := map[string]interface{}{
|
||||
"error": map[string]interface{}{
|
||||
"message": message,
|
||||
"type": errorType,
|
||||
"code": statusCode,
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, _ := json.Marshal(response)
|
||||
ctx.Write(jsonData)
|
||||
}
|
||||
|
||||
// BuildResponseBase builds the base response structure for OpenAI-compatible responses
|
||||
func BuildResponseBase(id string, created int64, model string) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"id": id,
|
||||
"object": "chat.completion",
|
||||
"created": created,
|
||||
"model": model,
|
||||
}
|
||||
}
|
||||
85
third_party/sglang/sgl-model-gateway/bindings/golang/examples/simple/main.go
vendored
Normal file
85
third_party/sglang/sgl-model-gateway/bindings/golang/examples/simple/main.go
vendored
Normal file
@@ -0,0 +1,85 @@
|
||||
// Simple example demonstrating basic usage of SGLang Go SDK
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/sglang/sglang-go-grpc-sdk"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Get configuration from environment or command line
|
||||
endpoint := os.Getenv("SGL_GRPC_ENDPOINT")
|
||||
if endpoint == "" {
|
||||
endpoint = "grpc://localhost:20000"
|
||||
}
|
||||
|
||||
tokenizerPath := os.Getenv("SGL_TOKENIZER_PATH")
|
||||
if tokenizerPath == "" {
|
||||
tokenizerPath = "./examples/tokenizer"
|
||||
}
|
||||
|
||||
// Create client
|
||||
client, err := sglang.NewClient(sglang.ClientConfig{
|
||||
Endpoint: endpoint,
|
||||
TokenizerPath: tokenizerPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Create chat completion request
|
||||
req := sglang.ChatCompletionRequest{
|
||||
Model: "default",
|
||||
Messages: []sglang.ChatMessage{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are a helpful assistant.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "写一首歌关于夏天",
|
||||
},
|
||||
},
|
||||
Stream: false,
|
||||
Temperature: float32Ptr(0.7),
|
||||
MaxCompletionTokens: intPtr(200),
|
||||
SkipSpecialTokens: true,
|
||||
Tools: nil, // Use nil instead of empty slice to avoid template errors
|
||||
}
|
||||
|
||||
// Create completion
|
||||
ctx := context.Background()
|
||||
resp, err := client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create completion: %v", err)
|
||||
}
|
||||
|
||||
// Print response
|
||||
fmt.Println("=== Response ===")
|
||||
fmt.Printf("ID: %s\n", resp.ID)
|
||||
fmt.Printf("Model: %s\n", resp.Model)
|
||||
fmt.Printf("Created: %d\n", resp.Created)
|
||||
fmt.Println("\nContent:")
|
||||
for _, choice := range resp.Choices {
|
||||
fmt.Println(choice.Message.Content)
|
||||
}
|
||||
fmt.Printf("\nFinish Reason: %s\n", resp.Choices[0].FinishReason)
|
||||
fmt.Printf("\nUsage: Prompt=%d, Completion=%d, Total=%d\n",
|
||||
resp.Usage.PromptTokens,
|
||||
resp.Usage.CompletionTokens,
|
||||
resp.Usage.TotalTokens,
|
||||
)
|
||||
}
|
||||
|
||||
func float32Ptr(f float32) *float32 {
|
||||
return &f
|
||||
}
|
||||
|
||||
func intPtr(i int) *int {
|
||||
return &i
|
||||
}
|
||||
46
third_party/sglang/sgl-model-gateway/bindings/golang/examples/simple/run.sh
vendored
Executable file
46
third_party/sglang/sgl-model-gateway/bindings/golang/examples/simple/run.sh
vendored
Executable file
@@ -0,0 +1,46 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Simple example runner
|
||||
# Usage: ./run.sh [tokenizer_path] [endpoint]
|
||||
|
||||
# Set library path for Rust FFI library
|
||||
# The library should be in ./lib directory (created by 'make lib')
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
LIB_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)/lib"
|
||||
|
||||
# Check if lib directory exists
|
||||
if [ ! -d "$LIB_DIR" ]; then
|
||||
echo "Error: Library directory not found at $LIB_DIR"
|
||||
echo "Please run 'make lib' first to build and export the library"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get Python LDFLAGS (needed for Rust FFI that depends on Python)
|
||||
PYTHON_LDFLAGS=$(python3-config --ldflags --embed 2>/dev/null || python3-config --ldflags 2>/dev/null || echo "")
|
||||
|
||||
# Set CGO_LDFLAGS to link with the Rust library
|
||||
export CGO_LDFLAGS="-L${LIB_DIR} -lsgl_model_gateway_go ${PYTHON_LDFLAGS} -ldl"
|
||||
|
||||
# macOS uses DYLD_LIBRARY_PATH, Linux uses LD_LIBRARY_PATH
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
export DYLD_LIBRARY_PATH="${LIB_DIR}:${DYLD_LIBRARY_PATH}"
|
||||
else
|
||||
export LD_LIBRARY_PATH="${LIB_DIR}:${LD_LIBRARY_PATH}"
|
||||
fi
|
||||
|
||||
# Default configuration (can be overridden by environment variables or command line arguments)
|
||||
# Tokenizer path: ../tokenizer (relative to this script)
|
||||
DEFAULT_TOKENIZER_PATH="${SGL_TOKENIZER_PATH:-../tokenizer}"
|
||||
DEFAULT_ENDPOINT="${SGL_GRPC_ENDPOINT:-grpc://localhost:20000}"
|
||||
|
||||
TOKENIZER_PATH="${1:-${DEFAULT_TOKENIZER_PATH}}"
|
||||
ENDPOINT="${2:-${DEFAULT_ENDPOINT}}"
|
||||
|
||||
echo "Running simple example..."
|
||||
echo "Library path: ${LIB_DIR}"
|
||||
echo "Tokenizer: $TOKENIZER_PATH"
|
||||
echo "Endpoint: $ENDPOINT"
|
||||
echo ""
|
||||
|
||||
cd "$(dirname "${BASH_SOURCE[0]}")"
|
||||
SGL_TOKENIZER_PATH="$TOKENIZER_PATH" SGL_GRPC_ENDPOINT="$ENDPOINT" go run main.go
|
||||
125
third_party/sglang/sgl-model-gateway/bindings/golang/examples/streaming/main.go
vendored
Normal file
125
third_party/sglang/sgl-model-gateway/bindings/golang/examples/streaming/main.go
vendored
Normal file
@@ -0,0 +1,125 @@
|
||||
// Streaming example demonstrating real-time streaming with SGLang Go SDK
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sglang/sglang-go-grpc-sdk"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Get configuration from environment or command line
|
||||
endpoint := os.Getenv("SGL_GRPC_ENDPOINT")
|
||||
if endpoint == "" {
|
||||
endpoint = "grpc://localhost:20000"
|
||||
}
|
||||
|
||||
tokenizerPath := os.Getenv("SGL_TOKENIZER_PATH")
|
||||
if tokenizerPath == "" {
|
||||
tokenizerPath = "./examples/tokenizer"
|
||||
}
|
||||
|
||||
// Create client
|
||||
client, err := sglang.NewClient(sglang.ClientConfig{
|
||||
Endpoint: endpoint,
|
||||
TokenizerPath: tokenizerPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Create streaming chat completion request
|
||||
req := sglang.ChatCompletionRequest{
|
||||
Model: "default",
|
||||
Messages: []sglang.ChatMessage{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are a helpful assistant.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "写一首春天的诗歌",
|
||||
},
|
||||
},
|
||||
Stream: true,
|
||||
Temperature: float32Ptr(0.7),
|
||||
MaxCompletionTokens: intPtr(500),
|
||||
SkipSpecialTokens: true,
|
||||
Tools: nil, // Use nil instead of empty slice to avoid template errors
|
||||
}
|
||||
|
||||
// Create streaming completion
|
||||
ctx := context.Background()
|
||||
stream, err := client.CreateChatCompletionStream(ctx, req)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create stream: %v", err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
fmt.Println("=== Streaming Response ===")
|
||||
fmt.Println()
|
||||
|
||||
var fullContent strings.Builder
|
||||
chunkCount := 0
|
||||
startTime := time.Now()
|
||||
var firstTokenTime time.Time
|
||||
firstTokenReceived := false
|
||||
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatalf("Stream error: %v", err)
|
||||
}
|
||||
|
||||
chunkCount++
|
||||
|
||||
// Extract content from delta
|
||||
for _, choice := range chunk.Choices {
|
||||
if choice.Delta.Content != "" {
|
||||
fmt.Print(choice.Delta.Content)
|
||||
fullContent.WriteString(choice.Delta.Content)
|
||||
|
||||
// Track first token time (TTFT)
|
||||
if !firstTokenReceived {
|
||||
firstTokenTime = time.Now()
|
||||
firstTokenReceived = true
|
||||
ttft := firstTokenTime.Sub(startTime)
|
||||
fmt.Printf("\n[TTFT: %v]\n", ttft)
|
||||
}
|
||||
}
|
||||
|
||||
if choice.FinishReason != "" {
|
||||
fmt.Printf("\n\n[Finished: %s]\n", choice.FinishReason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate metrics
|
||||
if firstTokenReceived {
|
||||
elapsed := time.Since(startTime)
|
||||
tokensPerSecond := float64(fullContent.Len()) / elapsed.Seconds()
|
||||
fmt.Printf("\n=== Metrics ===\n")
|
||||
fmt.Printf("Total chunks: %d\n", chunkCount)
|
||||
fmt.Printf("Total content length: %d characters\n", fullContent.Len())
|
||||
fmt.Printf("Time elapsed: %v\n", elapsed)
|
||||
fmt.Printf("Tokens per second: %.2f\n", tokensPerSecond)
|
||||
}
|
||||
}
|
||||
|
||||
func float32Ptr(f float32) *float32 {
|
||||
return &f
|
||||
}
|
||||
|
||||
func intPtr(i int) *int {
|
||||
return &i
|
||||
}
|
||||
46
third_party/sglang/sgl-model-gateway/bindings/golang/examples/streaming/run.sh
vendored
Executable file
46
third_party/sglang/sgl-model-gateway/bindings/golang/examples/streaming/run.sh
vendored
Executable file
@@ -0,0 +1,46 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Streaming example runner
|
||||
# Usage: ./run.sh [tokenizer_path] [endpoint]
|
||||
|
||||
# Set library path for Rust FFI library
|
||||
# The library should be in ./lib directory (created by 'make lib')
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
LIB_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)/lib"
|
||||
|
||||
# Check if lib directory exists
|
||||
if [ ! -d "$LIB_DIR" ]; then
|
||||
echo "Error: Library directory not found at $LIB_DIR"
|
||||
echo "Please run 'make lib' first to build and export the library"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get Python LDFLAGS (needed for Rust FFI that depends on Python)
|
||||
PYTHON_LDFLAGS=$(python3-config --ldflags --embed 2>/dev/null || python3-config --ldflags 2>/dev/null || echo "")
|
||||
|
||||
# Set CGO_LDFLAGS to link with the Rust library
|
||||
export CGO_LDFLAGS="-L${LIB_DIR} -lsgl_model_gateway_go ${PYTHON_LDFLAGS} -ldl"
|
||||
|
||||
# macOS uses DYLD_LIBRARY_PATH, Linux uses LD_LIBRARY_PATH
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
export DYLD_LIBRARY_PATH="${LIB_DIR}:${DYLD_LIBRARY_PATH}"
|
||||
else
|
||||
export LD_LIBRARY_PATH="${LIB_DIR}:${LD_LIBRARY_PATH}"
|
||||
fi
|
||||
|
||||
# Default configuration (can be overridden by environment variables or command line arguments)
|
||||
# Tokenizer path: ../tokenizer (relative to this script)
|
||||
DEFAULT_TOKENIZER_PATH="${SGL_TOKENIZER_PATH:-../tokenizer}"
|
||||
DEFAULT_ENDPOINT="${SGL_GRPC_ENDPOINT:-grpc://localhost:20000}"
|
||||
|
||||
TOKENIZER_PATH="${1:-${DEFAULT_TOKENIZER_PATH}}"
|
||||
ENDPOINT="${2:-${DEFAULT_ENDPOINT}}"
|
||||
|
||||
echo "Running streaming example..."
|
||||
echo "Library path: ${LIB_DIR}"
|
||||
echo "Tokenizer: $TOKENIZER_PATH"
|
||||
echo "Endpoint: $ENDPOINT"
|
||||
echo ""
|
||||
|
||||
cd "$(dirname "${BASH_SOURCE[0]}")"
|
||||
SGL_TOKENIZER_PATH="$TOKENIZER_PATH" SGL_GRPC_ENDPOINT="$ENDPOINT" go run main.go
|
||||
36
third_party/sglang/sgl-model-gateway/bindings/golang/go.sum
vendored
Normal file
36
third_party/sglang/sgl-model-gateway/bindings/golang/go.sum
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
||||
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
||||
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
|
||||
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
||||
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
|
||||
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
|
||||
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
|
||||
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
|
||||
golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82 h1:6/3JGEh1C88g7m+qzzTbl3A0FtsLguXieqofVLU/JAo=
|
||||
golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 h1:M1rk8KBnUsBDg1oPGHNCxG4vc1f49epmTO7xscSajMk=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
|
||||
google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM=
|
||||
google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
228
third_party/sglang/sgl-model-gateway/bindings/golang/integration_test.go
vendored
Normal file
228
third_party/sglang/sgl-model-gateway/bindings/golang/integration_test.go
vendored
Normal file
@@ -0,0 +1,228 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
// integration_test.go contains integration tests that require a running SGLang server
|
||||
//
|
||||
// To run these tests:
|
||||
// 1. Start an SGLang server: python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-hf
|
||||
// 2. Run: go test -tags=integration -run TestIntegration
|
||||
|
||||
package sglang
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// getTestConfig returns test configuration from environment or defaults
|
||||
func getTestConfig(t *testing.T) ClientConfig {
|
||||
endpoint := os.Getenv("SGL_GRPC_ENDPOINT")
|
||||
if endpoint == "" {
|
||||
endpoint = "grpc://localhost:20000"
|
||||
}
|
||||
|
||||
tokenizerPath := os.Getenv("SGL_TOKENIZER_PATH")
|
||||
if tokenizerPath == "" {
|
||||
t.Skip("SGL_TOKENIZER_PATH not set")
|
||||
}
|
||||
|
||||
return ClientConfig{
|
||||
Endpoint: endpoint,
|
||||
TokenizerPath: tokenizerPath,
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegrationNonStreamingCompletion tests non-streaming chat completion
|
||||
func TestIntegrationNonStreamingCompletion(t *testing.T) {
|
||||
config := getTestConfig(t)
|
||||
|
||||
client, err := NewClient(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req := ChatCompletionRequest{
|
||||
Model: "default",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: "Say 'Hello, World!' only"},
|
||||
},
|
||||
Stream: false,
|
||||
Temperature: float32Ptr(0.0),
|
||||
MaxCompletionTokens: intPtr(50),
|
||||
}
|
||||
|
||||
resp, err := client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateChatCompletion failed: %v", err)
|
||||
}
|
||||
|
||||
if resp.ID == "" {
|
||||
t.Error("Response ID is empty")
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
t.Error("Response has no choices")
|
||||
}
|
||||
|
||||
if resp.Choices[0].Message.Content == "" {
|
||||
t.Error("Response content is empty")
|
||||
}
|
||||
|
||||
if resp.Usage == nil || resp.Usage.TotalTokens == 0 {
|
||||
t.Error("Usage information is missing or invalid")
|
||||
}
|
||||
|
||||
t.Logf("Response: %s", resp.Choices[0].Message.Content)
|
||||
t.Logf("Usage: %+v", resp.Usage)
|
||||
}
|
||||
|
||||
// TestIntegrationStreamingCompletion tests streaming chat completion
|
||||
func TestIntegrationStreamingCompletion(t *testing.T) {
|
||||
config := getTestConfig(t)
|
||||
|
||||
client, err := NewClient(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req := ChatCompletionRequest{
|
||||
Model: "default",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: "Count from 1 to 5"},
|
||||
},
|
||||
Stream: true,
|
||||
Temperature: float32Ptr(0.0),
|
||||
MaxCompletionTokens: intPtr(100),
|
||||
}
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateChatCompletionStream failed: %v", err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
chunkCount := 0
|
||||
totalContent := ""
|
||||
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
// io.EOF is expected at end of stream
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Stream error: %v", err)
|
||||
}
|
||||
|
||||
chunkCount++
|
||||
|
||||
for _, choice := range chunk.Choices {
|
||||
if choice.Delta.Content != "" {
|
||||
totalContent += choice.Delta.Content
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if chunkCount == 0 {
|
||||
t.Error("Received no chunks from stream")
|
||||
}
|
||||
|
||||
if totalContent == "" {
|
||||
t.Error("Received no content from stream")
|
||||
}
|
||||
|
||||
t.Logf("Received %d chunks with content: %s", chunkCount, totalContent)
|
||||
}
|
||||
|
||||
// TestIntegrationConcurrentRequests tests multiple concurrent requests
|
||||
func TestIntegrationConcurrentRequests(t *testing.T) {
|
||||
config := getTestConfig(t)
|
||||
|
||||
client, err := NewClient(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
numRequests := 3
|
||||
done := make(chan error, numRequests)
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func(idx int) {
|
||||
req := ChatCompletionRequest{
|
||||
Model: "default",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: "Say 'test'"},
|
||||
},
|
||||
Stream: false,
|
||||
MaxCompletionTokens: intPtr(50),
|
||||
}
|
||||
|
||||
_, err := client.CreateChatCompletion(ctx, req)
|
||||
done <- err
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Collect results
|
||||
for i := 0; i < numRequests; i++ {
|
||||
if err := <-done; err != nil {
|
||||
t.Errorf("Request %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("All %d concurrent requests completed successfully", numRequests)
|
||||
}
|
||||
|
||||
// TestIntegrationContextCancellation tests that context cancellation is handled
|
||||
func TestIntegrationContextCancellation(t *testing.T) {
|
||||
config := getTestConfig(t)
|
||||
|
||||
client, err := NewClient(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Create a context that cancels immediately
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
req := ChatCompletionRequest{
|
||||
Model: "default",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: "test"},
|
||||
},
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
// Should handle cancelled context gracefully
|
||||
_, err = client.CreateChatCompletion(ctx, req)
|
||||
if err == nil {
|
||||
t.Error("Expected error from cancelled context")
|
||||
}
|
||||
|
||||
t.Logf("Cancelled context handled: %v", err)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func float32Ptr(f float32) *float32 {
|
||||
return &f
|
||||
}
|
||||
|
||||
func intPtr(i int) *int {
|
||||
return &i
|
||||
}
|
||||
126
third_party/sglang/sgl-model-gateway/bindings/golang/internal/ffi/batch_postprocessor.go
vendored
Normal file
126
third_party/sglang/sgl-model-gateway/bindings/golang/internal/ffi/batch_postprocessor.go
vendored
Normal file
@@ -0,0 +1,126 @@
|
||||
// Package ffi provides Go bindings for SGLang's Rust FFI (Foreign Function Interface).
|
||||
package ffi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BatchPostprocessor handles batch postprocessing of stream chunks to reduce FFI overhead
|
||||
type BatchPostprocessor struct {
|
||||
converter *GrpcResponseConverterHandle
|
||||
buffer []string
|
||||
batchSize int
|
||||
flushInterval time.Duration
|
||||
lastFlush time.Time
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
// NewBatchPostprocessor creates a new batch postprocessor
|
||||
func NewBatchPostprocessor(converter *GrpcResponseConverterHandle, batchSize int, flushInterval time.Duration) *BatchPostprocessor {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 1
|
||||
}
|
||||
if flushInterval < 0 {
|
||||
flushInterval = 0
|
||||
}
|
||||
|
||||
return &BatchPostprocessor{
|
||||
converter: converter,
|
||||
buffer: make([]string, 0, batchSize),
|
||||
batchSize: batchSize,
|
||||
flushInterval: flushInterval,
|
||||
lastFlush: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// AddChunk adds a chunk to the buffer and processes if batch is full
|
||||
func (b *BatchPostprocessor) AddChunk(chunkJSON string) (results []string, shouldFlush bool, err error) {
|
||||
if b.batchSize == 1 {
|
||||
openaiJSON, _, err := PostprocessStreamChunk(b.converter, chunkJSON)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return []string{openaiJSON}, false, nil
|
||||
}
|
||||
|
||||
b.buffer = append(b.buffer, chunkJSON)
|
||||
shouldProcess := len(b.buffer) >= b.batchSize
|
||||
shouldFlushTimeout := b.flushInterval > 0 && time.Since(b.lastFlush) >= b.flushInterval
|
||||
|
||||
if shouldProcess || shouldFlushTimeout {
|
||||
return b.processBatch()
|
||||
}
|
||||
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// Flush processes any remaining chunks in the buffer
|
||||
func (b *BatchPostprocessor) Flush() (results []string, err error) {
|
||||
if len(b.buffer) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
res, _, err := b.processBatch()
|
||||
return res, err
|
||||
}
|
||||
|
||||
// processBatch processes the current buffer and returns results
|
||||
func (b *BatchPostprocessor) processBatch() (results []string, shouldFlush bool, err error) {
|
||||
if len(b.buffer) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.Grow(len(b.buffer) * 200)
|
||||
sb.WriteString(`[`)
|
||||
for i, chunkJSONStr := range b.buffer {
|
||||
if i > 0 {
|
||||
sb.WriteString(`,`)
|
||||
}
|
||||
sb.WriteString(chunkJSONStr)
|
||||
}
|
||||
sb.WriteString(`]`)
|
||||
bufferJSON := sb.String()
|
||||
|
||||
resultJSON, _, err := PostprocessStreamChunksBatch(
|
||||
b.converter,
|
||||
bufferJSON,
|
||||
b.batchSize*2,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("batch postprocessing failed: %w", err)
|
||||
}
|
||||
|
||||
var resultArray []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(resultJSON), &resultArray); err != nil {
|
||||
return nil, false, fmt.Errorf("failed to unmarshal results array: %w", err)
|
||||
}
|
||||
|
||||
resultStrings := make([]string, 0, len(resultArray))
|
||||
for _, rawMsg := range resultArray {
|
||||
resultStrings = append(resultStrings, string(rawMsg))
|
||||
}
|
||||
|
||||
b.buffer = b.buffer[:0]
|
||||
b.lastFlush = time.Now()
|
||||
|
||||
if b.timer != nil {
|
||||
b.timer.Stop()
|
||||
b.timer = nil
|
||||
}
|
||||
|
||||
return resultStrings, false, nil
|
||||
}
|
||||
|
||||
// Reset clears the buffer and resets the postprocessor state
|
||||
func (b *BatchPostprocessor) Reset() {
|
||||
b.buffer = b.buffer[:0]
|
||||
b.lastFlush = time.Now()
|
||||
if b.timer != nil {
|
||||
b.timer.Stop()
|
||||
b.timer = nil
|
||||
}
|
||||
}
|
||||
228
third_party/sglang/sgl-model-gateway/bindings/golang/internal/ffi/client.go
vendored
Normal file
228
third_party/sglang/sgl-model-gateway/bindings/golang/internal/ffi/client.go
vendored
Normal file
@@ -0,0 +1,228 @@
|
||||
// Package ffi provides Go bindings for SGLang's Rust FFI (Foreign Function Interface).
|
||||
//
|
||||
// This package wraps the Rust FFI layer of SGLang, providing low-level access to:
|
||||
// - Client creation and connection management
|
||||
// - Chat completion streaming
|
||||
// - Stream reading and response conversion
|
||||
// - Memory management for C strings
|
||||
//
|
||||
// Internal use only: This package is intended for internal use by the sglang package.
|
||||
// End users should use the public sglang package instead.
|
||||
package ffi
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -lsgl_model_gateway_go -ldl
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
|
||||
// Error codes
|
||||
typedef enum {
|
||||
SGL_ERROR_SUCCESS = 0,
|
||||
SGL_ERROR_INVALID_ARGUMENT = 1,
|
||||
SGL_ERROR_TOKENIZATION_ERROR = 2,
|
||||
SGL_ERROR_PARSING_ERROR = 3,
|
||||
SGL_ERROR_MEMORY_ERROR = 4,
|
||||
SGL_ERROR_UNKNOWN = 99
|
||||
} SglErrorCode;
|
||||
|
||||
// Opaque handles
|
||||
typedef void* SglangClientHandle;
|
||||
typedef void* SglangStreamHandle;
|
||||
|
||||
// Client SDK functions
|
||||
SglangClientHandle* sgl_client_create(const char* endpoint, const char* tokenizer_path, char** error_out);
|
||||
void sgl_client_free(SglangClientHandle* handle);
|
||||
SglErrorCode sgl_client_chat_completion_stream(SglangClientHandle* client_handle, const char* request_json, SglangStreamHandle** stream_handle_out, char** error_out);
|
||||
SglErrorCode sgl_stream_read_next(SglangStreamHandle* stream_handle, char** response_json_out, int* is_done_out, char** error_out);
|
||||
void sgl_stream_free(SglangStreamHandle* handle);
|
||||
void sgl_free_string(char* s);
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// ErrorCode represents FFI error codes returned by Rust functions.
|
||||
//
|
||||
// These codes indicate the result of FFI operations. Use Error() to get a human-readable
|
||||
// error message.
|
||||
type ErrorCode int
|
||||
|
||||
const (
|
||||
// ErrorSuccess indicates the operation completed successfully
|
||||
ErrorSuccess ErrorCode = 0
|
||||
// ErrorInvalidArgument indicates invalid arguments were passed to the FFI function
|
||||
ErrorInvalidArgument ErrorCode = 1
|
||||
// ErrorTokenizationError indicates an error during tokenization
|
||||
ErrorTokenizationError ErrorCode = 2
|
||||
// ErrorParsingError indicates an error parsing the response or request
|
||||
ErrorParsingError ErrorCode = 3
|
||||
// ErrorMemoryError indicates a memory allocation error
|
||||
ErrorMemoryError ErrorCode = 4
|
||||
// ErrorUnknown indicates an unclassified error
|
||||
ErrorUnknown ErrorCode = 99
|
||||
)
|
||||
|
||||
// Error implements the error interface for ErrorCode.
|
||||
func (e ErrorCode) Error() string {
|
||||
switch e {
|
||||
case ErrorSuccess:
|
||||
return "success"
|
||||
case ErrorInvalidArgument:
|
||||
return "invalid argument"
|
||||
case ErrorTokenizationError:
|
||||
return "tokenization error"
|
||||
case ErrorParsingError:
|
||||
return "parsing error"
|
||||
case ErrorMemoryError:
|
||||
return "memory error"
|
||||
case ErrorUnknown:
|
||||
return "unknown error"
|
||||
default:
|
||||
return fmt.Sprintf("unknown error code: %d", e)
|
||||
}
|
||||
}
|
||||
|
||||
// SglangClientHandle wraps the Rust client SDK FFI handle.
|
||||
//
|
||||
// This struct maintains a connection to the SGLang gRPC server and is used
|
||||
// to create streams and manage the underlying Rust client resources.
|
||||
type SglangClientHandle struct {
|
||||
handle *C.SglangClientHandle
|
||||
}
|
||||
|
||||
// NewClient creates a new SGLang client handle via FFI.
|
||||
//
|
||||
// This function initializes the Rust client with the given endpoint and tokenizer path.
|
||||
//
|
||||
// Parameters:
|
||||
// - endpoint: gRPC endpoint URL (e.g., "grpc://localhost:20000")
|
||||
// - tokenizerPath: Path to tokenizer directory
|
||||
//
|
||||
// Returns:
|
||||
// - *SglangClientHandle: A new client handle
|
||||
// - error: An error if client creation failed
|
||||
func NewClient(endpoint, tokenizerPath string) (*SglangClientHandle, error) {
|
||||
cEndpoint := C.CString(endpoint)
|
||||
defer C.free(unsafe.Pointer(cEndpoint))
|
||||
|
||||
cTokenizerPath := C.CString(tokenizerPath)
|
||||
defer C.free(unsafe.Pointer(cTokenizerPath))
|
||||
|
||||
var errorPtr *C.char
|
||||
handle := C.sgl_client_create(cEndpoint, cTokenizerPath, &errorPtr)
|
||||
|
||||
if handle == nil {
|
||||
errorMsg := ""
|
||||
if errorPtr != nil {
|
||||
errorMsg = C.GoString(errorPtr)
|
||||
C.sgl_free_string(errorPtr)
|
||||
}
|
||||
if errorMsg == "" {
|
||||
errorMsg = "failed to create client"
|
||||
}
|
||||
return nil, fmt.Errorf("%s", errorMsg)
|
||||
}
|
||||
|
||||
return &SglangClientHandle{handle: handle}, nil
|
||||
}
|
||||
|
||||
// Free releases the client handle
|
||||
func (h *SglangClientHandle) Free() {
|
||||
if h.handle != nil {
|
||||
C.sgl_client_free(h.handle)
|
||||
h.handle = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ChatCompletionStream creates a streaming chat completion request
|
||||
func (h *SglangClientHandle) ChatCompletionStream(requestJSON string) (*SglangStreamHandle, error) {
|
||||
if h.handle == nil {
|
||||
return nil, fmt.Errorf("client handle is nil")
|
||||
}
|
||||
|
||||
cRequestJSON := C.CString(requestJSON)
|
||||
defer C.free(unsafe.Pointer(cRequestJSON))
|
||||
|
||||
var streamHandle *C.SglangStreamHandle
|
||||
var errorPtr *C.char
|
||||
|
||||
result := C.sgl_client_chat_completion_stream(
|
||||
h.handle,
|
||||
cRequestJSON,
|
||||
&streamHandle,
|
||||
&errorPtr,
|
||||
)
|
||||
|
||||
if ErrorCode(result) != ErrorSuccess {
|
||||
errorMsg := ""
|
||||
if errorPtr != nil {
|
||||
errorMsg = C.GoString(errorPtr)
|
||||
C.sgl_free_string(errorPtr)
|
||||
}
|
||||
if errorMsg == "" {
|
||||
errorMsg = fmt.Sprintf("error code %d", result)
|
||||
}
|
||||
return nil, fmt.Errorf("%s", errorMsg)
|
||||
}
|
||||
|
||||
if streamHandle == nil {
|
||||
return nil, fmt.Errorf("stream handle is nil")
|
||||
}
|
||||
|
||||
return &SglangStreamHandle{handle: streamHandle}, nil
|
||||
}
|
||||
|
||||
// SglangStreamHandle wraps the Rust stream FFI handle
|
||||
type SglangStreamHandle struct {
|
||||
handle *C.SglangStreamHandle
|
||||
}
|
||||
|
||||
// ReadNext reads the next chunk from the stream
|
||||
// Returns: (responseJSON, isDone, error)
|
||||
func (h *SglangStreamHandle) ReadNext() (string, bool, error) {
|
||||
if h.handle == nil {
|
||||
return "", true, fmt.Errorf("stream handle is nil")
|
||||
}
|
||||
|
||||
var responseJSON *C.char
|
||||
var isDone C.int
|
||||
var errorPtr *C.char
|
||||
|
||||
result := C.sgl_stream_read_next(
|
||||
h.handle,
|
||||
&responseJSON,
|
||||
&isDone,
|
||||
&errorPtr,
|
||||
)
|
||||
|
||||
if ErrorCode(result) != ErrorSuccess {
|
||||
errorMsg := ""
|
||||
if errorPtr != nil {
|
||||
errorMsg = C.GoString(errorPtr)
|
||||
C.sgl_free_string(errorPtr)
|
||||
}
|
||||
if errorMsg == "" {
|
||||
errorMsg = fmt.Sprintf("error code %d", result)
|
||||
}
|
||||
return "", isDone == 1, fmt.Errorf("%s", errorMsg)
|
||||
}
|
||||
|
||||
responseStr := ""
|
||||
if responseJSON != nil {
|
||||
responseStr = C.GoString(responseJSON)
|
||||
C.sgl_free_string(responseJSON)
|
||||
}
|
||||
|
||||
return responseStr, isDone == 1, nil
|
||||
}
|
||||
|
||||
// Free releases the stream handle
|
||||
func (h *SglangStreamHandle) Free() {
|
||||
if h.handle != nil {
|
||||
C.sgl_stream_free(h.handle)
|
||||
h.handle = nil
|
||||
}
|
||||
}
|
||||
275
third_party/sglang/sgl-model-gateway/bindings/golang/internal/ffi/grpc_converter.go
vendored
Normal file
275
third_party/sglang/sgl-model-gateway/bindings/golang/internal/ffi/grpc_converter.go
vendored
Normal file
@@ -0,0 +1,275 @@
|
||||
package ffi
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -lsgl_model_gateway_go -ldl
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
|
||||
// Error codes (must match client.go)
|
||||
typedef enum {
|
||||
SGL_ERROR_SUCCESS = 0,
|
||||
SGL_ERROR_INVALID_ARGUMENT = 1,
|
||||
SGL_ERROR_TOKENIZATION_ERROR = 2,
|
||||
SGL_ERROR_PARSING_ERROR = 3,
|
||||
SGL_ERROR_MEMORY_ERROR = 4,
|
||||
SGL_ERROR_UNKNOWN = 99
|
||||
} SglErrorCode;
|
||||
|
||||
// Opaque handles
|
||||
typedef void* TokenizerHandle;
|
||||
typedef void* GrpcResponseConverterHandle;
|
||||
|
||||
// Converter functions
|
||||
GrpcResponseConverterHandle* sgl_grpc_response_converter_create(
|
||||
TokenizerHandle* tokenizer_handle,
|
||||
const char* model,
|
||||
const char* request_id,
|
||||
const char* tools_json,
|
||||
const char* tool_choice_json,
|
||||
const char* stop,
|
||||
const char* stop_token_ids,
|
||||
int skip_special_tokens,
|
||||
int initial_prompt_tokens,
|
||||
char** error_out
|
||||
);
|
||||
|
||||
void sgl_grpc_response_converter_free(GrpcResponseConverterHandle* handle);
|
||||
|
||||
// Tokenizer functions
|
||||
TokenizerHandle* sgl_tokenizer_create_from_file(const char* tokenizer_path, char** error_out);
|
||||
void sgl_tokenizer_free(TokenizerHandle* handle);
|
||||
|
||||
// Memory management
|
||||
void sgl_free_string(char* s);
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// CreateGrpcResponseConverter creates a gRPC response converter handle
|
||||
// This function creates a new tokenizer handle each time (for backward compatibility)
|
||||
// For better performance, use CreateGrpcResponseConverterWithTokenizer with a cached tokenizer
|
||||
func CreateGrpcResponseConverter(
|
||||
tokenizerPath string,
|
||||
model string,
|
||||
requestID string,
|
||||
toolsJSON string,
|
||||
toolChoiceJSON string,
|
||||
stopJSON string,
|
||||
stopTokenIDs []uint32,
|
||||
skipSpecialTokens bool,
|
||||
initialPromptTokens int32,
|
||||
) (*GrpcResponseConverterHandle, error) {
|
||||
// Create tokenizer handle
|
||||
tokenizerHandle, err := createTokenizerHandle(tokenizerPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create tokenizer handle: %w", err)
|
||||
}
|
||||
defer C.sgl_tokenizer_free(tokenizerHandle)
|
||||
|
||||
return createGrpcResponseConverterWithTokenizerHandle(
|
||||
tokenizerHandle,
|
||||
model,
|
||||
requestID,
|
||||
toolsJSON,
|
||||
toolChoiceJSON,
|
||||
stopJSON,
|
||||
stopTokenIDs,
|
||||
skipSpecialTokens,
|
||||
initialPromptTokens,
|
||||
)
|
||||
}
|
||||
|
||||
// CreateGrpcResponseConverterWithTokenizer creates a gRPC response converter handle using a cached tokenizer
|
||||
// This is more efficient as it reuses the tokenizer instead of creating a new one each time
|
||||
func CreateGrpcResponseConverterWithTokenizer(
|
||||
tokenizerHandle *TokenizerHandle,
|
||||
model string,
|
||||
requestID string,
|
||||
toolsJSON string,
|
||||
toolChoiceJSON string,
|
||||
stopJSON string,
|
||||
stopTokenIDs []uint32,
|
||||
skipSpecialTokens bool,
|
||||
initialPromptTokens int32,
|
||||
) (*GrpcResponseConverterHandle, error) {
|
||||
if tokenizerHandle == nil || tokenizerHandle.handle == nil {
|
||||
return nil, fmt.Errorf("invalid tokenizer handle")
|
||||
}
|
||||
|
||||
return createGrpcResponseConverterWithTokenizerHandle(
|
||||
tokenizerHandle.handle,
|
||||
model,
|
||||
requestID,
|
||||
toolsJSON,
|
||||
toolChoiceJSON,
|
||||
stopJSON,
|
||||
stopTokenIDs,
|
||||
skipSpecialTokens,
|
||||
initialPromptTokens,
|
||||
)
|
||||
}
|
||||
|
||||
// createGrpcResponseConverterWithTokenizerHandle is the internal implementation
|
||||
func createGrpcResponseConverterWithTokenizerHandle(
|
||||
tokenizerHandle *C.TokenizerHandle,
|
||||
model string,
|
||||
requestID string,
|
||||
toolsJSON string,
|
||||
toolChoiceJSON string,
|
||||
stopJSON string,
|
||||
stopTokenIDs []uint32,
|
||||
skipSpecialTokens bool,
|
||||
initialPromptTokens int32,
|
||||
) (*GrpcResponseConverterHandle, error) {
|
||||
|
||||
// Convert strings to C strings
|
||||
modelC := C.CString(model)
|
||||
defer C.free(unsafe.Pointer(modelC))
|
||||
|
||||
requestIDC := C.CString(requestID)
|
||||
defer C.free(unsafe.Pointer(requestIDC))
|
||||
|
||||
var toolsJSONC *C.char
|
||||
if toolsJSON != "" {
|
||||
toolsJSONC = C.CString(toolsJSON)
|
||||
defer C.free(unsafe.Pointer(toolsJSONC))
|
||||
}
|
||||
|
||||
var toolChoiceJSONC *C.char
|
||||
if toolChoiceJSON != "" {
|
||||
toolChoiceJSONC = C.CString(toolChoiceJSON)
|
||||
defer C.free(unsafe.Pointer(toolChoiceJSONC))
|
||||
}
|
||||
|
||||
var stopJSONC *C.char
|
||||
if stopJSON != "" {
|
||||
stopJSONC = C.CString(stopJSON)
|
||||
defer C.free(unsafe.Pointer(stopJSONC))
|
||||
}
|
||||
|
||||
// Convert stop_token_ids to JSON string
|
||||
stopTokenIDsJSON := ""
|
||||
if len(stopTokenIDs) > 0 {
|
||||
stopTokenIDsJSON = fmt.Sprintf("[%d", stopTokenIDs[0])
|
||||
for i := 1; i < len(stopTokenIDs); i++ {
|
||||
stopTokenIDsJSON += fmt.Sprintf(",%d", stopTokenIDs[i])
|
||||
}
|
||||
stopTokenIDsJSON += "]"
|
||||
}
|
||||
|
||||
var stopTokenIDsJSONC *C.char
|
||||
if stopTokenIDsJSON != "" {
|
||||
stopTokenIDsJSONC = C.CString(stopTokenIDsJSON)
|
||||
defer C.free(unsafe.Pointer(stopTokenIDsJSONC))
|
||||
}
|
||||
|
||||
var errorOut *C.char
|
||||
skipSpecialTokensC := C.int(0)
|
||||
if skipSpecialTokens {
|
||||
skipSpecialTokensC = C.int(1)
|
||||
}
|
||||
|
||||
initialPromptTokensC := C.int(initialPromptTokens)
|
||||
|
||||
converterHandle := C.sgl_grpc_response_converter_create(
|
||||
tokenizerHandle,
|
||||
modelC,
|
||||
requestIDC,
|
||||
toolsJSONC,
|
||||
toolChoiceJSONC,
|
||||
stopJSONC,
|
||||
stopTokenIDsJSONC,
|
||||
skipSpecialTokensC,
|
||||
initialPromptTokensC,
|
||||
&errorOut,
|
||||
)
|
||||
|
||||
if converterHandle == nil {
|
||||
errorMsg := ""
|
||||
if errorOut != nil {
|
||||
errorMsg = C.GoString(errorOut)
|
||||
C.sgl_free_string(errorOut)
|
||||
}
|
||||
if errorMsg == "" {
|
||||
errorMsg = "failed to create converter handle"
|
||||
}
|
||||
return nil, fmt.Errorf("%s", errorMsg)
|
||||
}
|
||||
|
||||
return &GrpcResponseConverterHandle{
|
||||
handle: converterHandle,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FreeGrpcResponseConverter frees a gRPC response converter handle
|
||||
func FreeGrpcResponseConverter(handle *GrpcResponseConverterHandle) {
|
||||
if handle != nil && handle.handle != nil {
|
||||
C.sgl_grpc_response_converter_free(handle.handle)
|
||||
handle.handle = nil
|
||||
}
|
||||
}
|
||||
|
||||
// TokenizerHandle wraps the Rust tokenizer FFI handle
|
||||
type TokenizerHandle struct {
|
||||
handle *C.TokenizerHandle
|
||||
}
|
||||
|
||||
// CreateTokenizerHandle creates a tokenizer handle (exported for caching)
|
||||
func CreateTokenizerHandle(tokenizerPath string) (*TokenizerHandle, error) {
|
||||
tokenizerPathC := C.CString(tokenizerPath)
|
||||
defer C.free(unsafe.Pointer(tokenizerPathC))
|
||||
|
||||
var errorOut *C.char
|
||||
tokenizerHandle := C.sgl_tokenizer_create_from_file(tokenizerPathC, &errorOut)
|
||||
|
||||
if tokenizerHandle == nil {
|
||||
errorMsg := ""
|
||||
if errorOut != nil {
|
||||
errorMsg = C.GoString(errorOut)
|
||||
C.sgl_free_string(errorOut)
|
||||
}
|
||||
if errorMsg == "" {
|
||||
errorMsg = "failed to create tokenizer handle"
|
||||
}
|
||||
return nil, fmt.Errorf("%s", errorMsg)
|
||||
}
|
||||
|
||||
return &TokenizerHandle{
|
||||
handle: tokenizerHandle,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FreeTokenizerHandle frees a tokenizer handle
|
||||
func FreeTokenizerHandle(handle *TokenizerHandle) {
|
||||
if handle != nil && handle.handle != nil {
|
||||
C.sgl_tokenizer_free(handle.handle)
|
||||
handle.handle = nil
|
||||
}
|
||||
}
|
||||
|
||||
// createTokenizerHandle creates a tokenizer handle (helper function, internal use)
|
||||
func createTokenizerHandle(tokenizerPath string) (*C.TokenizerHandle, error) {
|
||||
tokenizerPathC := C.CString(tokenizerPath)
|
||||
defer C.free(unsafe.Pointer(tokenizerPathC))
|
||||
|
||||
var errorOut *C.char
|
||||
tokenizerHandle := C.sgl_tokenizer_create_from_file(tokenizerPathC, &errorOut)
|
||||
|
||||
if tokenizerHandle == nil {
|
||||
errorMsg := ""
|
||||
if errorOut != nil {
|
||||
errorMsg = C.GoString(errorOut)
|
||||
C.sgl_free_string(errorOut)
|
||||
}
|
||||
if errorMsg == "" {
|
||||
errorMsg = "failed to create tokenizer handle"
|
||||
}
|
||||
return nil, fmt.Errorf("%s", errorMsg)
|
||||
}
|
||||
|
||||
return tokenizerHandle, nil
|
||||
}
|
||||
156
third_party/sglang/sgl-model-gateway/bindings/golang/internal/ffi/postprocessor.go
vendored
Normal file
156
third_party/sglang/sgl-model-gateway/bindings/golang/internal/ffi/postprocessor.go
vendored
Normal file
@@ -0,0 +1,156 @@
|
||||
// Package ffi provides Go bindings for SGLang's Rust FFI (Foreign Function Interface).
|
||||
package ffi
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -lsgl_model_gateway_go -ldl
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
|
||||
// Error codes (must match client.go)
|
||||
typedef enum {
|
||||
SGL_ERROR_SUCCESS = 0,
|
||||
SGL_ERROR_INVALID_ARGUMENT = 1,
|
||||
SGL_ERROR_TOKENIZATION_ERROR = 2,
|
||||
SGL_ERROR_PARSING_ERROR = 3,
|
||||
SGL_ERROR_MEMORY_ERROR = 4,
|
||||
SGL_ERROR_UNKNOWN = 99
|
||||
} SglErrorCode;
|
||||
|
||||
// Opaque handle (must match grpc_converter.go)
|
||||
typedef void* GrpcResponseConverterHandle;
|
||||
|
||||
// Postprocessor functions
|
||||
SglErrorCode sgl_postprocess_stream_chunk(
|
||||
GrpcResponseConverterHandle* converter_handle,
|
||||
const char* proto_chunk_json,
|
||||
char** openai_json_out,
|
||||
int* is_done_out,
|
||||
char** error_out
|
||||
);
|
||||
|
||||
SglErrorCode sgl_postprocess_stream_chunks_batch(
|
||||
GrpcResponseConverterHandle* converter_handle,
|
||||
const char* proto_chunks_json_array,
|
||||
int max_chunks,
|
||||
char** openai_chunks_json_array_out,
|
||||
int* chunks_count_out,
|
||||
char** error_out
|
||||
);
|
||||
|
||||
// Memory management
|
||||
void sgl_free_string(char* s);
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// GrpcResponseConverterHandle wraps the Rust gRPC response converter FFI handle
|
||||
type GrpcResponseConverterHandle struct {
|
||||
handle *C.GrpcResponseConverterHandle
|
||||
}
|
||||
|
||||
// PostprocessStreamChunk postprocesses a gRPC stream chunk to OpenAI format
|
||||
//
|
||||
// This function:
|
||||
// 1. Parses the proto chunk from JSON
|
||||
// 2. Converts it to OpenAI format using the converter handle
|
||||
// 3. Returns the OpenAI format JSON
|
||||
//
|
||||
// Returns the OpenAI format JSON, is_done flag, and any error.
|
||||
func PostprocessStreamChunk(converterHandle *GrpcResponseConverterHandle, protoChunkJSON string) (openaiJSON string, isDone bool, err error) {
|
||||
if converterHandle == nil || converterHandle.handle == nil {
|
||||
return "", false, fmt.Errorf("invalid converter handle")
|
||||
}
|
||||
|
||||
protoChunkJSONC := C.CString(protoChunkJSON)
|
||||
defer C.free(unsafe.Pointer(protoChunkJSONC))
|
||||
|
||||
var openaiJSONOut *C.char
|
||||
var isDoneOut C.int
|
||||
var errorOut *C.char
|
||||
|
||||
errorCode := C.sgl_postprocess_stream_chunk(
|
||||
converterHandle.handle,
|
||||
protoChunkJSONC,
|
||||
&openaiJSONOut,
|
||||
&isDoneOut,
|
||||
&errorOut,
|
||||
)
|
||||
|
||||
if errorCode != C.SGL_ERROR_SUCCESS {
|
||||
errorMsg := ""
|
||||
if errorOut != nil {
|
||||
errorMsg = C.GoString(errorOut)
|
||||
C.sgl_free_string(errorOut)
|
||||
}
|
||||
return "", false, fmt.Errorf("postprocessing failed: %s", errorMsg)
|
||||
}
|
||||
|
||||
openaiJSON = C.GoString(openaiJSONOut)
|
||||
isDone = isDoneOut != 0
|
||||
|
||||
// Free the C string allocated by Rust
|
||||
if openaiJSONOut != nil {
|
||||
C.sgl_free_string(openaiJSONOut)
|
||||
}
|
||||
|
||||
return openaiJSON, isDone, nil
|
||||
}
|
||||
|
||||
// PostprocessStreamChunksBatch postprocesses multiple gRPC stream chunks in batch
|
||||
//
|
||||
// This function processes multiple chunks in a single FFI call, significantly reducing
|
||||
// FFI overhead in streaming scenarios.
|
||||
//
|
||||
// Arguments:
|
||||
// - converterHandle: Converter handle
|
||||
// - protoChunksJSONArray: JSON array string of proto chunks
|
||||
// - maxChunks: Maximum number of chunks to process (for safety, typically 10-20)
|
||||
//
|
||||
// Returns:
|
||||
// - openaiChunksJSONArray: JSON array of OpenAI format chunks
|
||||
// - chunksCount: Number of processed chunks
|
||||
// - error: Any error that occurred
|
||||
func PostprocessStreamChunksBatch(converterHandle *GrpcResponseConverterHandle, protoChunksJSONArray string, maxChunks int) (openaiChunksJSONArray string, chunksCount int, err error) {
|
||||
if converterHandle == nil || converterHandle.handle == nil {
|
||||
return "", 0, fmt.Errorf("invalid converter handle")
|
||||
}
|
||||
|
||||
protoChunksJSONArrayC := C.CString(protoChunksJSONArray)
|
||||
defer C.free(unsafe.Pointer(protoChunksJSONArrayC))
|
||||
|
||||
var openaiChunksJSONArrayOut *C.char
|
||||
var chunksCountOut C.int
|
||||
var errorOut *C.char
|
||||
|
||||
errorCode := C.sgl_postprocess_stream_chunks_batch(
|
||||
converterHandle.handle,
|
||||
protoChunksJSONArrayC,
|
||||
C.int(maxChunks),
|
||||
&openaiChunksJSONArrayOut,
|
||||
&chunksCountOut,
|
||||
&errorOut,
|
||||
)
|
||||
|
||||
if errorCode != C.SGL_ERROR_SUCCESS {
|
||||
errorMsg := ""
|
||||
if errorOut != nil {
|
||||
errorMsg = C.GoString(errorOut)
|
||||
C.sgl_free_string(errorOut)
|
||||
}
|
||||
return "", 0, fmt.Errorf("batch postprocessing failed: %s", errorMsg)
|
||||
}
|
||||
|
||||
openaiChunksJSONArray = C.GoString(openaiChunksJSONArrayOut)
|
||||
chunksCount = int(chunksCountOut)
|
||||
|
||||
// Free the C string allocated by Rust
|
||||
if openaiChunksJSONArrayOut != nil {
|
||||
C.sgl_free_string(openaiChunksJSONArrayOut)
|
||||
}
|
||||
|
||||
return openaiChunksJSONArray, chunksCount, nil
|
||||
}
|
||||
246
third_party/sglang/sgl-model-gateway/bindings/golang/internal/ffi/preprocessor.go
vendored
Normal file
246
third_party/sglang/sgl-model-gateway/bindings/golang/internal/ffi/preprocessor.go
vendored
Normal file
@@ -0,0 +1,246 @@
|
||||
// Package ffi provides Go bindings for SGLang's Rust FFI (Foreign Function Interface).
|
||||
package ffi
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -lsgl_model_gateway_go -ldl
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
|
||||
// Error codes (must match client.go)
|
||||
typedef enum {
|
||||
SGL_ERROR_SUCCESS = 0,
|
||||
SGL_ERROR_INVALID_ARGUMENT = 1,
|
||||
SGL_ERROR_TOKENIZATION_ERROR = 2,
|
||||
SGL_ERROR_PARSING_ERROR = 3,
|
||||
SGL_ERROR_MEMORY_ERROR = 4,
|
||||
SGL_ERROR_UNKNOWN = 99
|
||||
} SglErrorCode;
|
||||
|
||||
// Preprocessor functions
|
||||
SglErrorCode sgl_preprocess_chat_request(
|
||||
const char* request_json,
|
||||
const char* tokenizer_path,
|
||||
char** prompt_text_out,
|
||||
uint32_t** token_ids_out,
|
||||
size_t* token_ids_len_out,
|
||||
char** tool_constraints_json_out,
|
||||
int32_t* prompt_tokens_out,
|
||||
char** error_out
|
||||
);
|
||||
|
||||
// Opaque handle (must match grpc_converter.go)
|
||||
typedef void* TokenizerHandle;
|
||||
|
||||
SglErrorCode sgl_preprocess_chat_request_with_tokenizer(
|
||||
const char* request_json,
|
||||
void* tokenizer_handle,
|
||||
char** prompt_text_out,
|
||||
uint32_t** token_ids_out,
|
||||
size_t* token_ids_len_out,
|
||||
char** tool_constraints_json_out,
|
||||
int32_t* prompt_tokens_out,
|
||||
char** error_out
|
||||
);
|
||||
|
||||
void sgl_preprocessed_request_free(
|
||||
char* prompt_text,
|
||||
uint32_t* token_ids,
|
||||
size_t token_ids_len,
|
||||
char* tool_constraints_json
|
||||
);
|
||||
|
||||
// Memory management
|
||||
void sgl_free_string(char* s);
|
||||
void sgl_free_token_ids(uint32_t* ptr, size_t count);
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// PreprocessedRequest represents a preprocessed chat request
|
||||
type PreprocessedRequest struct {
|
||||
PromptText string
|
||||
TokenIDs []uint32
|
||||
ToolConstraintsJSON string
|
||||
PromptTokens int32
|
||||
// Internal pointers for memory management
|
||||
promptTextPtr *C.char
|
||||
tokenIDsPtr *C.uint32_t
|
||||
tokenIDsLen uintptr
|
||||
toolConstraintsJSONPtr *C.char
|
||||
}
|
||||
|
||||
// PreprocessChatRequest preprocesses a chat completion request
|
||||
//
|
||||
// This function:
|
||||
// 1. Applies chat_template to messages
|
||||
// 2. Tokenizes the processed text
|
||||
// 3. Generates tool constraints (if tools are present)
|
||||
//
|
||||
// Returns the preprocessed request data and any error.
|
||||
func PreprocessChatRequest(requestJSON, tokenizerPath string) (*PreprocessedRequest, error) {
|
||||
requestJSONC := C.CString(requestJSON)
|
||||
defer C.free(unsafe.Pointer(requestJSONC))
|
||||
|
||||
tokenizerPathC := C.CString(tokenizerPath)
|
||||
defer C.free(unsafe.Pointer(tokenizerPathC))
|
||||
|
||||
var promptTextOut *C.char
|
||||
var tokenIDsOut *C.uint32_t
|
||||
var tokenIDsLenOut C.size_t
|
||||
var toolConstraintsJSONOut *C.char
|
||||
var promptTokensOut C.int32_t
|
||||
var errorOut *C.char
|
||||
|
||||
errorCode := C.sgl_preprocess_chat_request(
|
||||
requestJSONC,
|
||||
tokenizerPathC,
|
||||
&promptTextOut,
|
||||
&tokenIDsOut,
|
||||
&tokenIDsLenOut,
|
||||
&toolConstraintsJSONOut,
|
||||
&promptTokensOut,
|
||||
&errorOut,
|
||||
)
|
||||
|
||||
if errorCode != C.SGL_ERROR_SUCCESS {
|
||||
errorMsg := ""
|
||||
if errorOut != nil {
|
||||
errorMsg = C.GoString(errorOut)
|
||||
C.sgl_free_string(errorOut)
|
||||
}
|
||||
return nil, fmt.Errorf("preprocessing failed: %s", errorMsg)
|
||||
}
|
||||
|
||||
result := &PreprocessedRequest{
|
||||
PromptText: C.GoString(promptTextOut),
|
||||
TokenIDs: make([]uint32, tokenIDsLenOut),
|
||||
ToolConstraintsJSON: "",
|
||||
PromptTokens: int32(promptTokensOut),
|
||||
}
|
||||
|
||||
// Copy token IDs
|
||||
if tokenIDsOut != nil && tokenIDsLenOut > 0 {
|
||||
tokenIDsSlice := (*[1 << 30]C.uint32_t)(unsafe.Pointer(tokenIDsOut))[:tokenIDsLenOut:tokenIDsLenOut]
|
||||
for i := range result.TokenIDs {
|
||||
result.TokenIDs[i] = uint32(tokenIDsSlice[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Copy tool constraints JSON if present
|
||||
if toolConstraintsJSONOut != nil {
|
||||
result.ToolConstraintsJSON = C.GoString(toolConstraintsJSONOut)
|
||||
}
|
||||
|
||||
// Store pointers for later cleanup
|
||||
result.promptTextPtr = promptTextOut
|
||||
result.tokenIDsPtr = tokenIDsOut
|
||||
result.tokenIDsLen = uintptr(tokenIDsLenOut)
|
||||
result.toolConstraintsJSONPtr = toolConstraintsJSONOut
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// PreprocessChatRequestWithTokenizer preprocesses a chat completion request using an existing tokenizer handle
|
||||
//
|
||||
// This function is similar to PreprocessChatRequest, but accepts a TokenizerHandle
|
||||
// instead of creating a new tokenizer. This allows reusing a cached tokenizer instance,
|
||||
// significantly reducing initialization overhead in concurrent scenarios.
|
||||
//
|
||||
// Returns the preprocessed request data and any error.
|
||||
func PreprocessChatRequestWithTokenizer(requestJSON string, tokenizerHandle *TokenizerHandle) (*PreprocessedRequest, error) {
|
||||
requestJSONC := C.CString(requestJSON)
|
||||
defer C.free(unsafe.Pointer(requestJSONC))
|
||||
|
||||
if tokenizerHandle == nil || tokenizerHandle.handle == nil {
|
||||
return nil, fmt.Errorf("invalid tokenizer handle")
|
||||
}
|
||||
|
||||
var promptTextOut *C.char
|
||||
var tokenIDsOut *C.uint32_t
|
||||
var tokenIDsLenOut C.size_t
|
||||
var toolConstraintsJSONOut *C.char
|
||||
var promptTokensOut C.int32_t
|
||||
var errorOut *C.char
|
||||
|
||||
errorCode := C.sgl_preprocess_chat_request_with_tokenizer(
|
||||
requestJSONC,
|
||||
unsafe.Pointer(tokenizerHandle.handle), // Convert *C.TokenizerHandle to void*
|
||||
&promptTextOut,
|
||||
&tokenIDsOut,
|
||||
&tokenIDsLenOut,
|
||||
&toolConstraintsJSONOut,
|
||||
&promptTokensOut,
|
||||
&errorOut,
|
||||
)
|
||||
|
||||
if errorCode != C.SGL_ERROR_SUCCESS {
|
||||
errorMsg := ""
|
||||
if errorOut != nil {
|
||||
errorMsg = C.GoString(errorOut)
|
||||
C.sgl_free_string(errorOut)
|
||||
}
|
||||
return nil, fmt.Errorf("preprocessing failed: %s", errorMsg)
|
||||
}
|
||||
|
||||
result := &PreprocessedRequest{
|
||||
PromptText: C.GoString(promptTextOut),
|
||||
TokenIDs: make([]uint32, tokenIDsLenOut),
|
||||
ToolConstraintsJSON: "",
|
||||
PromptTokens: int32(promptTokensOut),
|
||||
}
|
||||
|
||||
// Copy token IDs
|
||||
if tokenIDsOut != nil && tokenIDsLenOut > 0 {
|
||||
tokenIDsSlice := (*[1 << 30]C.uint32_t)(unsafe.Pointer(tokenIDsOut))[:tokenIDsLenOut:tokenIDsLenOut]
|
||||
for i := range result.TokenIDs {
|
||||
result.TokenIDs[i] = uint32(tokenIDsSlice[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Copy tool constraints JSON if present
|
||||
if toolConstraintsJSONOut != nil {
|
||||
result.ToolConstraintsJSON = C.GoString(toolConstraintsJSONOut)
|
||||
}
|
||||
|
||||
// Store pointers for later cleanup
|
||||
result.promptTextPtr = promptTextOut
|
||||
result.tokenIDsPtr = tokenIDsOut
|
||||
result.tokenIDsLen = uintptr(tokenIDsLenOut)
|
||||
result.toolConstraintsJSONPtr = toolConstraintsJSONOut
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Free frees the memory allocated for a preprocessed request
|
||||
func (p *PreprocessedRequest) Free() {
|
||||
if p.promptTextPtr != nil || p.tokenIDsPtr != nil || p.toolConstraintsJSONPtr != nil {
|
||||
C.sgl_preprocessed_request_free(
|
||||
p.promptTextPtr,
|
||||
p.tokenIDsPtr,
|
||||
C.size_t(p.tokenIDsLen),
|
||||
p.toolConstraintsJSONPtr,
|
||||
)
|
||||
// Clear pointers to prevent double-free
|
||||
p.promptTextPtr = nil
|
||||
p.tokenIDsPtr = nil
|
||||
p.tokenIDsLen = 0
|
||||
p.toolConstraintsJSONPtr = nil
|
||||
}
|
||||
}
|
||||
|
||||
// FreePreprocessedRequest frees the memory allocated for a preprocessed request
|
||||
// This is a convenience function for direct pointer management
|
||||
func FreePreprocessedRequest(promptTextPtr *C.char, tokenIDsPtr *C.uint32_t, tokenIDsLen uintptr, toolConstraintsJSONPtr *C.char) {
|
||||
if promptTextPtr != nil || tokenIDsPtr != nil || toolConstraintsJSONPtr != nil {
|
||||
C.sgl_preprocessed_request_free(
|
||||
promptTextPtr,
|
||||
tokenIDsPtr,
|
||||
C.size_t(tokenIDsLen),
|
||||
toolConstraintsJSONPtr,
|
||||
)
|
||||
}
|
||||
}
|
||||
684
third_party/sglang/sgl-model-gateway/bindings/golang/internal/grpc/client_grpc.go
vendored
Normal file
684
third_party/sglang/sgl-model-gateway/bindings/golang/internal/grpc/client_grpc.go
vendored
Normal file
@@ -0,0 +1,684 @@
|
||||
// Package grpc provides gRPC client implementation for SGLang
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/sglang/sglang-go-grpc-sdk/internal/ffi"
|
||||
"github.com/sglang/sglang-go-grpc-sdk/internal/proto"
|
||||
)
|
||||
|
||||
type grpcClientStream interface {
|
||||
Recv() (*proto.GenerateResponse, error)
|
||||
CloseSend() error
|
||||
}
|
||||
|
||||
// recvResult holds the result of a Recv() call
|
||||
type recvResult struct {
|
||||
resp *proto.GenerateResponse
|
||||
err error
|
||||
}
|
||||
|
||||
type GrpcClient struct {
|
||||
conn *grpc.ClientConn
|
||||
client proto.SglangSchedulerClient
|
||||
tokenizerPath string
|
||||
tokenizerHandle *ffi.TokenizerHandle
|
||||
bufferSizes ChannelBufferSizes
|
||||
timeouts Timeouts
|
||||
requestCounter uint64 // Atomic counter to ensure unique request IDs
|
||||
}
|
||||
|
||||
type ChannelBufferSizes struct {
|
||||
ResultJSONChan int
|
||||
ErrChan int
|
||||
RecvChan int
|
||||
}
|
||||
|
||||
type Timeouts struct {
|
||||
KeepaliveTime time.Duration
|
||||
KeepaliveTimeout time.Duration
|
||||
CloseTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewGrpcClient(endpoint, tokenizerPath string, bufferSizes ChannelBufferSizes, timeouts Timeouts) (*GrpcClient, error) {
|
||||
endpoint = strings.TrimPrefix(endpoint, "grpc://")
|
||||
if !strings.Contains(endpoint, ":") {
|
||||
return nil, fmt.Errorf("invalid endpoint format: %s (expected grpc://host:port)", endpoint)
|
||||
}
|
||||
|
||||
keepaliveParams := keepalive.ClientParameters{
|
||||
Time: timeouts.KeepaliveTime,
|
||||
Timeout: timeouts.KeepaliveTimeout,
|
||||
PermitWithoutStream: false,
|
||||
}
|
||||
|
||||
opts := []grpc.DialOption{
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithKeepaliveParams(keepaliveParams),
|
||||
}
|
||||
|
||||
conn, err := grpc.NewClient(endpoint, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to gRPC server: %w", err)
|
||||
}
|
||||
|
||||
client := proto.NewSglangSchedulerClient(conn)
|
||||
|
||||
tokenizerHandle, err := ffi.CreateTokenizerHandle(tokenizerPath)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("failed to create tokenizer handle: %w", err)
|
||||
}
|
||||
|
||||
return &GrpcClient{
|
||||
conn: conn,
|
||||
client: client,
|
||||
tokenizerPath: tokenizerPath,
|
||||
tokenizerHandle: tokenizerHandle,
|
||||
bufferSizes: bufferSizes,
|
||||
timeouts: timeouts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) Close() error {
|
||||
if c.tokenizerHandle != nil {
|
||||
ffi.FreeTokenizerHandle(c.tokenizerHandle)
|
||||
c.tokenizerHandle = nil
|
||||
}
|
||||
|
||||
if c.conn != nil {
|
||||
return c.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) CreateChatCompletionStream(ctx context.Context, reqJSON string) (*GrpcChatCompletionStream, error) {
|
||||
if c.tokenizerHandle == nil {
|
||||
return nil, fmt.Errorf("tokenizer handle is nil (should be created at startup)")
|
||||
}
|
||||
|
||||
preprocessed, err := ffi.PreprocessChatRequestWithTokenizer(reqJSON, c.tokenizerHandle)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preprocessing failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if preprocessed != nil {
|
||||
preprocessed.Free()
|
||||
}
|
||||
}()
|
||||
|
||||
// Parse request JSON to get parameters
|
||||
var reqMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(reqJSON), &reqMap); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse request JSON: %w", err)
|
||||
}
|
||||
|
||||
model, _ := reqMap["model"].(string)
|
||||
if model == "" {
|
||||
model = "default"
|
||||
}
|
||||
|
||||
// Build GenerateRequest
|
||||
// Generate unique request ID using timestamp + atomic counter to avoid collisions
|
||||
// This matches Rust version's UUID-based approach for uniqueness
|
||||
counter := atomic.AddUint64(&c.requestCounter, 1)
|
||||
requestID := fmt.Sprintf("chatcmpl-%d-%d", time.Now().UnixNano(), counter)
|
||||
generateReq := &proto.GenerateRequest{
|
||||
RequestId: requestID,
|
||||
Tokenized: &proto.TokenizedInput{
|
||||
OriginalText: preprocessed.PromptText,
|
||||
InputIds: preprocessed.TokenIDs,
|
||||
},
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
// Set sampling parameters
|
||||
samplingParams := &proto.SamplingParams{
|
||||
Temperature: 1.0,
|
||||
TopP: 1.0,
|
||||
TopK: -1,
|
||||
SkipSpecialTokens: true,
|
||||
}
|
||||
|
||||
if temp, ok := reqMap["temperature"].(float64); ok {
|
||||
samplingParams.Temperature = float32(temp)
|
||||
}
|
||||
if topP, ok := reqMap["top_p"].(float64); ok {
|
||||
samplingParams.TopP = float32(topP)
|
||||
}
|
||||
if topK, ok := reqMap["top_k"].(float64); ok {
|
||||
samplingParams.TopK = int32(topK)
|
||||
}
|
||||
var maxTokensInt *int32
|
||||
if maxCompletionTokens, ok := reqMap["max_completion_tokens"].(float64); ok {
|
||||
tokens := int32(maxCompletionTokens)
|
||||
maxTokensInt = &tokens
|
||||
} else if maxTokens, ok := reqMap["max_tokens"].(float64); ok {
|
||||
tokens := int32(maxTokens)
|
||||
maxTokensInt = &tokens
|
||||
}
|
||||
if maxTokensInt != nil {
|
||||
samplingParams.MaxNewTokens = maxTokensInt
|
||||
}
|
||||
|
||||
// Parse tool constraints if available
|
||||
if preprocessed.ToolConstraintsJSON != "" {
|
||||
var toolConstraints map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(preprocessed.ToolConstraintsJSON), &toolConstraints); err == nil {
|
||||
if regex, ok := toolConstraints["regex"].(string); ok {
|
||||
samplingParams.Constraint = &proto.SamplingParams_Regex{Regex: regex}
|
||||
} else if jsonSchema, ok := toolConstraints["json_schema"].(string); ok {
|
||||
samplingParams.Constraint = &proto.SamplingParams_JsonSchema{JsonSchema: jsonSchema}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
generateReq.SamplingParams = samplingParams
|
||||
generateReq.Timestamp = timestamppb.Now()
|
||||
|
||||
stream, err := c.client.Generate(ctx, generateReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create gRPC stream: %w", err)
|
||||
}
|
||||
toolsJSON := ""
|
||||
if tools, ok := reqMap["tools"].([]interface{}); ok && len(tools) > 0 {
|
||||
toolsBytes, _ := json.Marshal(tools)
|
||||
toolsJSON = string(toolsBytes)
|
||||
}
|
||||
|
||||
toolChoiceJSON := ""
|
||||
if toolChoice, ok := reqMap["tool_choice"]; ok {
|
||||
toolChoiceBytes, _ := json.Marshal(toolChoice)
|
||||
toolChoiceJSON = string(toolChoiceBytes)
|
||||
}
|
||||
|
||||
stopJSON := ""
|
||||
if stop, ok := reqMap["stop"]; ok {
|
||||
stopBytes, _ := json.Marshal(stop)
|
||||
stopJSON = string(stopBytes)
|
||||
}
|
||||
|
||||
stopTokenIDs := []uint32{}
|
||||
if stopTokenIDsVal, ok := reqMap["stop_token_ids"].([]interface{}); ok {
|
||||
for _, id := range stopTokenIDsVal {
|
||||
if idFloat, ok := id.(float64); ok {
|
||||
stopTokenIDs = append(stopTokenIDs, uint32(idFloat))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
skipSpecialTokens := true
|
||||
if skipSpecialTokensVal, ok := reqMap["skip_special_tokens"].(bool); ok {
|
||||
skipSpecialTokens = skipSpecialTokensVal
|
||||
}
|
||||
|
||||
if c.tokenizerHandle == nil {
|
||||
stream.CloseSend()
|
||||
return nil, fmt.Errorf("tokenizer handle is nil (should be created at startup)")
|
||||
}
|
||||
|
||||
converterHandle, err := ffi.CreateGrpcResponseConverterWithTokenizer(
|
||||
c.tokenizerHandle,
|
||||
model,
|
||||
generateReq.RequestId,
|
||||
toolsJSON,
|
||||
toolChoiceJSON,
|
||||
stopJSON,
|
||||
stopTokenIDs,
|
||||
skipSpecialTokens,
|
||||
preprocessed.PromptTokens, // Pass initial prompt tokens from preprocessing
|
||||
)
|
||||
if err != nil {
|
||||
stream.CloseSend()
|
||||
return nil, fmt.Errorf("failed to create converter handle: %w", err)
|
||||
}
|
||||
|
||||
batchSize := 1
|
||||
batchPostprocessor := ffi.NewBatchPostprocessor(converterHandle, batchSize, 0)
|
||||
|
||||
streamCtx, cancel := context.WithCancel(ctx)
|
||||
grpcStream := &GrpcChatCompletionStream{
|
||||
stream: stream,
|
||||
converterHandle: converterHandle,
|
||||
batchPostprocessor: batchPostprocessor,
|
||||
batchSize: batchSize,
|
||||
ctx: streamCtx,
|
||||
cancel: cancel,
|
||||
resultJSONChan: make(chan string, c.bufferSizes.ResultJSONChan),
|
||||
errChan: make(chan error, c.bufferSizes.ErrChan),
|
||||
readLoopDone: make(chan struct{}),
|
||||
requestID: generateReq.RequestId,
|
||||
model: model,
|
||||
processWg: sync.WaitGroup{},
|
||||
closeTimeout: c.timeouts.CloseTimeout,
|
||||
bufferSizes: c.bufferSizes,
|
||||
}
|
||||
|
||||
go grpcStream.readLoop()
|
||||
|
||||
return grpcStream, nil
|
||||
}
|
||||
|
||||
// GrpcChatCompletionStream represents a streaming chat completion via gRPC
|
||||
type GrpcChatCompletionStream struct {
|
||||
stream grpcClientStream
|
||||
converterHandle *ffi.GrpcResponseConverterHandle
|
||||
batchPostprocessor *ffi.BatchPostprocessor
|
||||
batchSize int
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closed int32
|
||||
resultJSONChan chan string
|
||||
errChan chan error
|
||||
readLoopDone chan struct{}
|
||||
requestID string
|
||||
model string
|
||||
processWg sync.WaitGroup
|
||||
closeTimeout time.Duration
|
||||
bufferSizes ChannelBufferSizes
|
||||
clientDisconnected int32 // Atomic flag: 1 if client disconnected, 0 otherwise
|
||||
}
|
||||
|
||||
func (s *GrpcChatCompletionStream) readLoop() {
|
||||
defer func() {
|
||||
atomic.StoreInt32(&s.closed, 1)
|
||||
s.processWg.Wait()
|
||||
close(s.resultJSONChan)
|
||||
close(s.errChan)
|
||||
close(s.readLoopDone)
|
||||
// Cancel context after channels are closed to ensure errors are read first
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
recvChan := make(chan recvResult, s.bufferSizes.RecvChan)
|
||||
const firstRecvTimeout = 60 * time.Second
|
||||
|
||||
go func() {
|
||||
defer close(recvChan)
|
||||
recvCount := 0
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
// Skip CloseSend() if client disconnected
|
||||
if atomic.LoadInt32(&s.clientDisconnected) == 0 {
|
||||
_ = s.stream.CloseSend()
|
||||
}
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
recvCount++
|
||||
var protoResp *proto.GenerateResponse
|
||||
var err error
|
||||
|
||||
// First Recv() with timeout
|
||||
if recvCount == 1 {
|
||||
recvDone := make(chan recvResult, 1)
|
||||
go func() {
|
||||
resp, recvErr := s.stream.Recv()
|
||||
recvDone <- recvResult{resp: resp, err: recvErr}
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-recvDone:
|
||||
protoResp = result.resp
|
||||
err = result.err
|
||||
case <-time.After(firstRecvTimeout):
|
||||
timeoutErr := fmt.Errorf("stream.Recv() timeout after %v: backend may not be responding (request_id=%s)", firstRecvTimeout, s.requestID)
|
||||
select {
|
||||
case recvChan <- recvResult{resp: nil, err: timeoutErr}:
|
||||
case <-s.ctx.Done():
|
||||
}
|
||||
return
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Normal Recv()
|
||||
protoResp, err = s.stream.Recv()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
select {
|
||||
case recvChan <- recvResult{resp: nil, err: err}:
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
// Skip CloseSend() if client disconnected
|
||||
if atomic.LoadInt32(&s.clientDisconnected) == 0 {
|
||||
_ = s.stream.CloseSend()
|
||||
}
|
||||
return
|
||||
case recvChan <- recvResult{resp: protoResp, err: nil}:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
// Skip CloseSend() if client disconnected
|
||||
if atomic.LoadInt32(&s.clientDisconnected) == 0 {
|
||||
_ = s.stream.CloseSend()
|
||||
}
|
||||
return
|
||||
case result, ok := <-recvChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if result.err != nil {
|
||||
if result.err == io.EOF {
|
||||
results, flushErr := s.flushBatch()
|
||||
if flushErr != nil {
|
||||
select {
|
||||
case s.errChan <- fmt.Errorf("failed to flush batch: %w", flushErr):
|
||||
case <-s.ctx.Done():
|
||||
}
|
||||
return
|
||||
}
|
||||
for _, resultJSON := range results {
|
||||
select {
|
||||
case s.resultJSONChan <- resultJSON:
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case s.errChan <- result.err:
|
||||
case <-s.ctx.Done():
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if result.resp != nil {
|
||||
s.processWg.Add(1)
|
||||
go func(resp *proto.GenerateResponse) {
|
||||
defer s.processWg.Done()
|
||||
s.processAndSendResponse(resp)
|
||||
}(result.resp)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GrpcChatCompletionStream) processAndSendResponse(protoResp *proto.GenerateResponse) {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if protoResp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
protoJSON, err := protoToJSON(protoResp)
|
||||
if err != nil {
|
||||
select {
|
||||
case s.errChan <- fmt.Errorf("failed to convert proto to JSON: %w", err):
|
||||
case <-s.ctx.Done():
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if s.batchPostprocessor == nil {
|
||||
select {
|
||||
case s.errChan <- fmt.Errorf("batch postprocessor is nil"):
|
||||
case <-s.ctx.Done():
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
results, _, err := s.batchPostprocessor.AddChunk(protoJSON)
|
||||
if err != nil {
|
||||
select {
|
||||
case s.errChan <- fmt.Errorf("batch postprocessing failed: %w", err):
|
||||
case <-s.ctx.Done():
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, resultJSON := range results {
|
||||
select {
|
||||
case s.resultJSONChan <- resultJSON:
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GrpcChatCompletionStream) RecvJSON() (string, error) {
|
||||
// Use a loop instead of recursion to avoid stack overflow if there are many empty strings
|
||||
for {
|
||||
// Check errChan first to prioritize actual errors over context cancellation
|
||||
select {
|
||||
case err, ok := <-s.errChan:
|
||||
if !ok {
|
||||
return "", io.EOF
|
||||
}
|
||||
return "", err
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case resultJSON, ok := <-s.resultJSONChan:
|
||||
if !ok {
|
||||
return "", io.EOF
|
||||
}
|
||||
// Skip empty strings and continue loop instead of recursing
|
||||
if resultJSON != "" {
|
||||
return resultJSON, nil
|
||||
}
|
||||
// Empty string, continue loop to get next result
|
||||
continue
|
||||
case err, ok := <-s.errChan:
|
||||
if !ok {
|
||||
return "", io.EOF
|
||||
}
|
||||
return "", err
|
||||
case <-s.ctx.Done():
|
||||
return "", s.ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetClientDisconnected marks that the client has disconnected.
|
||||
// When Close() is called, it will not call CloseSend() to avoid aborting the request on server side.
|
||||
func (s *GrpcChatCompletionStream) SetClientDisconnected() {
|
||||
atomic.StoreInt32(&s.clientDisconnected, 1)
|
||||
}
|
||||
|
||||
func (s *GrpcChatCompletionStream) Close() error {
|
||||
if !atomic.CompareAndSwapInt32(&s.closed, 0, 1) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
clientDisconnected := atomic.LoadInt32(&s.clientDisconnected) == 1
|
||||
|
||||
select {
|
||||
case <-s.readLoopDone:
|
||||
// readLoop completed
|
||||
default:
|
||||
if !clientDisconnected {
|
||||
// Call CloseSend() if client didn't disconnect
|
||||
_ = s.stream.CloseSend()
|
||||
}
|
||||
select {
|
||||
case <-s.readLoopDone:
|
||||
case <-time.After(s.closeTimeout):
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = s.flushBatch()
|
||||
|
||||
if s.converterHandle != nil {
|
||||
ffi.FreeGrpcResponseConverter(s.converterHandle)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GrpcChatCompletionStream) flushBatch() ([]string, error) {
|
||||
if s.batchPostprocessor != nil {
|
||||
results, err := s.batchPostprocessor.Flush()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch flush failed: %w", err)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func protoToJSON(resp *proto.GenerateResponse) (string, error) {
|
||||
var sb strings.Builder
|
||||
sb.Grow(500)
|
||||
|
||||
sb.WriteString(`{"request_id":`)
|
||||
if resp.RequestId == "" {
|
||||
sb.WriteString(`""`)
|
||||
} else {
|
||||
requestIDJSON, err := json.Marshal(resp.RequestId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.Write(requestIDJSON)
|
||||
}
|
||||
|
||||
switch r := resp.Response.(type) {
|
||||
case *proto.GenerateResponse_Chunk:
|
||||
sb.WriteString(`,"chunk":{`)
|
||||
sb.WriteString(`"token_ids":`)
|
||||
tokenIDsJSON, err := json.Marshal(r.Chunk.TokenIds)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.Write(tokenIDsJSON)
|
||||
sb.WriteString(`,"prompt_tokens":`)
|
||||
sb.WriteString(strconv.FormatInt(int64(r.Chunk.PromptTokens), 10))
|
||||
sb.WriteString(`,"completion_tokens":`)
|
||||
sb.WriteString(strconv.FormatInt(int64(r.Chunk.CompletionTokens), 10))
|
||||
sb.WriteString(`,"cached_tokens":`)
|
||||
sb.WriteString(strconv.FormatInt(int64(r.Chunk.CachedTokens), 10))
|
||||
sb.WriteString(`,"index":`)
|
||||
sb.WriteString(strconv.FormatInt(int64(r.Chunk.Index), 10))
|
||||
sb.WriteString(`}`)
|
||||
case *proto.GenerateResponse_Complete:
|
||||
sb.WriteString(`,"complete":{`)
|
||||
sb.WriteString(`"output_ids":`)
|
||||
outputIDsJSON, err := json.Marshal(r.Complete.OutputIds)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.Write(outputIDsJSON)
|
||||
sb.WriteString(`,"finish_reason":`)
|
||||
finishReasonJSON, err := json.Marshal(r.Complete.FinishReason)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.Write(finishReasonJSON)
|
||||
sb.WriteString(`,"prompt_tokens":`)
|
||||
sb.WriteString(strconv.FormatInt(int64(r.Complete.PromptTokens), 10))
|
||||
sb.WriteString(`,"completion_tokens":`)
|
||||
sb.WriteString(strconv.FormatInt(int64(r.Complete.CompletionTokens), 10))
|
||||
sb.WriteString(`,"cached_tokens":`)
|
||||
sb.WriteString(strconv.FormatInt(int64(r.Complete.CachedTokens), 10))
|
||||
sb.WriteString(`}`)
|
||||
case *proto.GenerateResponse_Error:
|
||||
sb.WriteString(`,"error":{`)
|
||||
sb.WriteString(`"message":`)
|
||||
messageJSON, err := json.Marshal(r.Error.Message)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.Write(messageJSON)
|
||||
sb.WriteString(`,"http_status_code":`)
|
||||
httpStatusCodeJSON, err := json.Marshal(r.Error.HttpStatusCode)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.Write(httpStatusCodeJSON)
|
||||
if r.Error.Details != "" {
|
||||
sb.WriteString(`,"details":`)
|
||||
detailsJSON, err := json.Marshal(r.Error.Details)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.Write(detailsJSON)
|
||||
}
|
||||
sb.WriteString(`}`)
|
||||
}
|
||||
|
||||
sb.WriteString(`}`)
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
type ChatCompletionStreamResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
Choices []StreamChoice `json:"choices"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// StreamChoice represents a choice in a streaming response
|
||||
type StreamChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta MessageDelta `json:"delta"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
// MessageDelta represents incremental message updates
|
||||
type MessageDelta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ToolCall represents a tool call in the response
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function FunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
// FunctionCall represents a function call
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
// Usage represents token usage information
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
3325
third_party/sglang/sgl-model-gateway/bindings/golang/internal/proto/sglang_scheduler.pb.go
vendored
Normal file
3325
third_party/sglang/sgl-model-gateway/bindings/golang/internal/proto/sglang_scheduler.pb.go
vendored
Normal file
File diff suppressed because it is too large
Load Diff
333
third_party/sglang/sgl-model-gateway/bindings/golang/internal/proto/sglang_scheduler_grpc.pb.go
vendored
Normal file
333
third_party/sglang/sgl-model-gateway/bindings/golang/internal/proto/sglang_scheduler_grpc.pb.go
vendored
Normal file
@@ -0,0 +1,333 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.5.1
|
||||
// - protoc v3.21.12
|
||||
// source: sglang_scheduler.proto
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
context "context"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.64.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion9
|
||||
|
||||
const (
|
||||
SglangScheduler_Generate_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/Generate"
|
||||
SglangScheduler_Embed_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/Embed"
|
||||
SglangScheduler_HealthCheck_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/HealthCheck"
|
||||
SglangScheduler_Abort_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/Abort"
|
||||
SglangScheduler_GetModelInfo_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/GetModelInfo"
|
||||
SglangScheduler_GetServerInfo_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/GetServerInfo"
|
||||
)
|
||||
|
||||
// SglangSchedulerClient is the client API for SglangScheduler service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
//
|
||||
// Service definition for SGLang scheduler communication
|
||||
// This protocol bridges the Rust router and Python scheduler
|
||||
type SglangSchedulerClient interface {
|
||||
// Submit a generation request (supports streaming)
|
||||
Generate(ctx context.Context, in *GenerateRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[GenerateResponse], error)
|
||||
// Submit an embedding request
|
||||
Embed(ctx context.Context, in *EmbedRequest, opts ...grpc.CallOption) (*EmbedResponse, error)
|
||||
// Health check and metrics
|
||||
HealthCheck(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error)
|
||||
// Abort a running request
|
||||
Abort(ctx context.Context, in *AbortRequest, opts ...grpc.CallOption) (*AbortResponse, error)
|
||||
// Get model information
|
||||
GetModelInfo(ctx context.Context, in *GetModelInfoRequest, opts ...grpc.CallOption) (*GetModelInfoResponse, error)
|
||||
// Get server information
|
||||
GetServerInfo(ctx context.Context, in *GetServerInfoRequest, opts ...grpc.CallOption) (*GetServerInfoResponse, error)
|
||||
}
|
||||
|
||||
type sglangSchedulerClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewSglangSchedulerClient(cc grpc.ClientConnInterface) SglangSchedulerClient {
|
||||
return &sglangSchedulerClient{cc}
|
||||
}
|
||||
|
||||
func (c *sglangSchedulerClient) Generate(ctx context.Context, in *GenerateRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[GenerateResponse], error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &SglangScheduler_ServiceDesc.Streams[0], SglangScheduler_Generate_FullMethodName, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &grpc.GenericClientStream[GenerateRequest, GenerateResponse]{ClientStream: stream}
|
||||
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := x.ClientStream.CloseSend(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type SglangScheduler_GenerateClient = grpc.ServerStreamingClient[GenerateResponse]
|
||||
|
||||
func (c *sglangSchedulerClient) Embed(ctx context.Context, in *EmbedRequest, opts ...grpc.CallOption) (*EmbedResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(EmbedResponse)
|
||||
err := c.cc.Invoke(ctx, SglangScheduler_Embed_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *sglangSchedulerClient) HealthCheck(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(HealthCheckResponse)
|
||||
err := c.cc.Invoke(ctx, SglangScheduler_HealthCheck_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *sglangSchedulerClient) Abort(ctx context.Context, in *AbortRequest, opts ...grpc.CallOption) (*AbortResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(AbortResponse)
|
||||
err := c.cc.Invoke(ctx, SglangScheduler_Abort_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *sglangSchedulerClient) GetModelInfo(ctx context.Context, in *GetModelInfoRequest, opts ...grpc.CallOption) (*GetModelInfoResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(GetModelInfoResponse)
|
||||
err := c.cc.Invoke(ctx, SglangScheduler_GetModelInfo_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *sglangSchedulerClient) GetServerInfo(ctx context.Context, in *GetServerInfoRequest, opts ...grpc.CallOption) (*GetServerInfoResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(GetServerInfoResponse)
|
||||
err := c.cc.Invoke(ctx, SglangScheduler_GetServerInfo_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// SglangSchedulerServer is the server API for SglangScheduler service.
|
||||
// All implementations must embed UnimplementedSglangSchedulerServer
|
||||
// for forward compatibility.
|
||||
//
|
||||
// Service definition for SGLang scheduler communication
|
||||
// This protocol bridges the Rust router and Python scheduler
|
||||
type SglangSchedulerServer interface {
|
||||
// Submit a generation request (supports streaming)
|
||||
Generate(*GenerateRequest, grpc.ServerStreamingServer[GenerateResponse]) error
|
||||
// Submit an embedding request
|
||||
Embed(context.Context, *EmbedRequest) (*EmbedResponse, error)
|
||||
// Health check and metrics
|
||||
HealthCheck(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error)
|
||||
// Abort a running request
|
||||
Abort(context.Context, *AbortRequest) (*AbortResponse, error)
|
||||
// Get model information
|
||||
GetModelInfo(context.Context, *GetModelInfoRequest) (*GetModelInfoResponse, error)
|
||||
// Get server information
|
||||
GetServerInfo(context.Context, *GetServerInfoRequest) (*GetServerInfoResponse, error)
|
||||
mustEmbedUnimplementedSglangSchedulerServer()
|
||||
}
|
||||
|
||||
// UnimplementedSglangSchedulerServer must be embedded to have
|
||||
// forward compatible implementations.
|
||||
//
|
||||
// NOTE: this should be embedded by value instead of pointer to avoid a nil
|
||||
// pointer dereference when methods are called.
|
||||
type UnimplementedSglangSchedulerServer struct{}
|
||||
|
||||
func (UnimplementedSglangSchedulerServer) Generate(*GenerateRequest, grpc.ServerStreamingServer[GenerateResponse]) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Generate not implemented")
|
||||
}
|
||||
func (UnimplementedSglangSchedulerServer) Embed(context.Context, *EmbedRequest) (*EmbedResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Embed not implemented")
|
||||
}
|
||||
func (UnimplementedSglangSchedulerServer) HealthCheck(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method HealthCheck not implemented")
|
||||
}
|
||||
func (UnimplementedSglangSchedulerServer) Abort(context.Context, *AbortRequest) (*AbortResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Abort not implemented")
|
||||
}
|
||||
func (UnimplementedSglangSchedulerServer) GetModelInfo(context.Context, *GetModelInfoRequest) (*GetModelInfoResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetModelInfo not implemented")
|
||||
}
|
||||
func (UnimplementedSglangSchedulerServer) GetServerInfo(context.Context, *GetServerInfoRequest) (*GetServerInfoResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetServerInfo not implemented")
|
||||
}
|
||||
func (UnimplementedSglangSchedulerServer) mustEmbedUnimplementedSglangSchedulerServer() {}
|
||||
func (UnimplementedSglangSchedulerServer) testEmbeddedByValue() {}
|
||||
|
||||
// UnsafeSglangSchedulerServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to SglangSchedulerServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeSglangSchedulerServer interface {
|
||||
mustEmbedUnimplementedSglangSchedulerServer()
|
||||
}
|
||||
|
||||
func RegisterSglangSchedulerServer(s grpc.ServiceRegistrar, srv SglangSchedulerServer) {
|
||||
// If the following call pancis, it indicates UnimplementedSglangSchedulerServer was
|
||||
// embedded by pointer and is nil. This will cause panics if an
|
||||
// unimplemented method is ever invoked, so we test this at initialization
|
||||
// time to prevent it from happening at runtime later due to I/O.
|
||||
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
|
||||
t.testEmbeddedByValue()
|
||||
}
|
||||
s.RegisterService(&SglangScheduler_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _SglangScheduler_Generate_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
m := new(GenerateRequest)
|
||||
if err := stream.RecvMsg(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.(SglangSchedulerServer).Generate(m, &grpc.GenericServerStream[GenerateRequest, GenerateResponse]{ServerStream: stream})
|
||||
}
|
||||
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type SglangScheduler_GenerateServer = grpc.ServerStreamingServer[GenerateResponse]
|
||||
|
||||
func _SglangScheduler_Embed_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(EmbedRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(SglangSchedulerServer).Embed(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: SglangScheduler_Embed_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(SglangSchedulerServer).Embed(ctx, req.(*EmbedRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _SglangScheduler_HealthCheck_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(HealthCheckRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(SglangSchedulerServer).HealthCheck(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: SglangScheduler_HealthCheck_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(SglangSchedulerServer).HealthCheck(ctx, req.(*HealthCheckRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _SglangScheduler_Abort_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(AbortRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(SglangSchedulerServer).Abort(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: SglangScheduler_Abort_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(SglangSchedulerServer).Abort(ctx, req.(*AbortRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _SglangScheduler_GetModelInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(GetModelInfoRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(SglangSchedulerServer).GetModelInfo(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: SglangScheduler_GetModelInfo_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(SglangSchedulerServer).GetModelInfo(ctx, req.(*GetModelInfoRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _SglangScheduler_GetServerInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(GetServerInfoRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(SglangSchedulerServer).GetServerInfo(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: SglangScheduler_GetServerInfo_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(SglangSchedulerServer).GetServerInfo(ctx, req.(*GetServerInfoRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// SglangScheduler_ServiceDesc is the grpc.ServiceDesc for SglangScheduler service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var SglangScheduler_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "sglang.grpc.scheduler.SglangScheduler",
|
||||
HandlerType: (*SglangSchedulerServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Embed",
|
||||
Handler: _SglangScheduler_Embed_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "HealthCheck",
|
||||
Handler: _SglangScheduler_HealthCheck_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "Abort",
|
||||
Handler: _SglangScheduler_Abort_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "GetModelInfo",
|
||||
Handler: _SglangScheduler_GetModelInfo_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "GetServerInfo",
|
||||
Handler: _SglangScheduler_GetServerInfo_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "Generate",
|
||||
Handler: _SglangScheduler_Generate_Handler,
|
||||
ServerStreams: true,
|
||||
},
|
||||
},
|
||||
Metadata: "sglang_scheduler.proto",
|
||||
}
|
||||
279
third_party/sglang/sgl-model-gateway/bindings/golang/src/client.rs
vendored
Normal file
279
third_party/sglang/sgl-model-gateway/bindings/golang/src/client.rs
vendored
Normal file
@@ -0,0 +1,279 @@
|
||||
//! Client SDK FFI functions
|
||||
|
||||
use std::ffi::{CStr, CString};
|
||||
use std::os::raw::{c_char};
|
||||
use std::ptr;
|
||||
use std::sync::Arc;
|
||||
use tokio::runtime::Runtime;
|
||||
use once_cell::sync::Lazy;
|
||||
use uuid::Uuid;
|
||||
|
||||
use smg::tokenizer::create_tokenizer_from_file;
|
||||
use smg::tokenizer::traits::Tokenizer;
|
||||
use smg_grpc_client::sglang_scheduler::SglangSchedulerClient;
|
||||
use smg::protocols::chat::ChatCompletionRequest;
|
||||
use smg::routers::grpc::utils::{process_chat_messages, generate_tool_constraints};
|
||||
|
||||
use super::error::{SglErrorCode, set_error_message};
|
||||
use super::grpc_converter::sgl_grpc_response_converter_create;
|
||||
use super::tokenizer::TokenizerHandle;
|
||||
use super::stream::SglangStreamHandle;
|
||||
|
||||
/// Global tokio runtime for async operations
|
||||
static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
|
||||
Runtime::new().expect("Failed to create tokio runtime for client FFI")
|
||||
});
|
||||
|
||||
/// Handle for complete client SDK (gRPC client + tokenizer)
|
||||
/// This handle manages the connection to sglang and provides a complete SDK interface
|
||||
pub struct SglangClientHandle {
|
||||
pub(crate) client: Arc<SglangSchedulerClient>,
|
||||
pub(crate) tokenizer: Arc<dyn Tokenizer>,
|
||||
}
|
||||
|
||||
/// Handle for streaming request (includes prompt token count)
|
||||
#[allow(dead_code)]
|
||||
pub struct StreamRequestState {
|
||||
pub(crate) prompt_tokens: i32, // Number of prompt tokens for this request
|
||||
}
|
||||
|
||||
/// Create a new SGLang client handle
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `endpoint` - gRPC endpoint (e.g., "grpc://localhost:20000")
|
||||
/// * `tokenizer_path` - Path to tokenizer directory
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * Pointer to SglangClientHandle on success, null on failure
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_client_create(
|
||||
endpoint: *const c_char,
|
||||
tokenizer_path: *const c_char,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> *mut SglangClientHandle {
|
||||
if endpoint.is_null() || tokenizer_path.is_null() {
|
||||
set_error_message(error_out, "Invalid arguments: null pointer");
|
||||
return ptr::null_mut();
|
||||
}
|
||||
|
||||
let endpoint_str = match CStr::from_ptr(endpoint).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in endpoint");
|
||||
return ptr::null_mut();
|
||||
}
|
||||
};
|
||||
|
||||
let tokenizer_path_str = match CStr::from_ptr(tokenizer_path).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in tokenizer_path");
|
||||
return ptr::null_mut();
|
||||
}
|
||||
};
|
||||
|
||||
// Create tokenizer
|
||||
let tokenizer = match create_tokenizer_from_file(tokenizer_path_str) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to create tokenizer: {}", e));
|
||||
return ptr::null_mut();
|
||||
}
|
||||
};
|
||||
|
||||
// Create gRPC client
|
||||
let client = match RUNTIME.block_on(async {
|
||||
SglangSchedulerClient::connect(endpoint_str).await
|
||||
}) {
|
||||
Ok(c) => Arc::new(c),
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to connect to endpoint: {}", e));
|
||||
return ptr::null_mut();
|
||||
}
|
||||
};
|
||||
|
||||
Box::into_raw(Box::new(SglangClientHandle {
|
||||
client,
|
||||
tokenizer,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Free a client handle
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_client_free(handle: *mut SglangClientHandle) {
|
||||
if !handle.is_null() {
|
||||
let _ = Box::from_raw(handle);
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a chat completion request and start streaming
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `client_handle` - Client handle
|
||||
/// * `request_json` - OpenAI ChatCompletionRequest as JSON string
|
||||
/// * `stream_handle_out` - Pointer to receive stream handle
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * SglErrorCode::Success on success, error code on failure
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_client_chat_completion_stream(
|
||||
client_handle: *mut SglangClientHandle,
|
||||
request_json: *const c_char,
|
||||
stream_handle_out: *mut *mut SglangStreamHandle,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> SglErrorCode {
|
||||
if client_handle.is_null() || request_json.is_null() || stream_handle_out.is_null() {
|
||||
set_error_message(error_out, "Invalid arguments: null pointer");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
|
||||
let request_str = match CStr::from_ptr(request_json).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in request_json");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
let client_ref = &*client_handle;
|
||||
let client = Arc::clone(&client_ref.client);
|
||||
let tokenizer = Arc::clone(&client_ref.tokenizer);
|
||||
|
||||
// Parse OpenAI ChatCompletionRequest
|
||||
let chat_request: ChatCompletionRequest = match serde_json::from_str(request_str) {
|
||||
Ok(req) => req,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to parse request JSON: {}", e));
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
};
|
||||
|
||||
// Process messages and apply chat template
|
||||
let processed_messages = match process_chat_messages(&chat_request, tokenizer.as_ref()) {
|
||||
Ok(msgs) => msgs,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to process messages: {}", e));
|
||||
return SglErrorCode::TokenizationError;
|
||||
}
|
||||
};
|
||||
|
||||
// Tokenize
|
||||
let token_ids = match tokenizer.encode(&processed_messages.text, false) {
|
||||
Ok(encoding) => encoding.token_ids().to_vec(),
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to tokenize: {}", e));
|
||||
return SglErrorCode::TokenizationError;
|
||||
}
|
||||
};
|
||||
let prompt_tokens = token_ids.len() as i32; // Save prompt token count
|
||||
|
||||
// Generate tool constraints if needed
|
||||
let tool_constraint = if let Some(tools) = chat_request.tools.as_ref() {
|
||||
match generate_tool_constraints(tools, &chat_request.tool_choice, &chat_request.model) {
|
||||
Ok(Some((constraint_type, constraint_value))) => Some((constraint_type, constraint_value)),
|
||||
Ok(None) => None,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to generate tool constraints: {}", e));
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Build GenerateRequest
|
||||
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
|
||||
let proto_request = match client.build_generate_request_from_chat(
|
||||
request_id.clone(),
|
||||
&chat_request,
|
||||
processed_messages.text,
|
||||
token_ids,
|
||||
processed_messages.multimodal_inputs,
|
||||
tool_constraint,
|
||||
) {
|
||||
Ok(req) => req,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to build generate request: {}", e));
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
};
|
||||
|
||||
// Send request and get stream
|
||||
let stream = match RUNTIME.block_on(async {
|
||||
client.generate(proto_request).await
|
||||
}) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to send request: {}", e));
|
||||
return SglErrorCode::UnknownError;
|
||||
}
|
||||
};
|
||||
|
||||
// Create response converter
|
||||
let tools_json = chat_request.tools.as_ref()
|
||||
.and_then(|t| serde_json::to_string(t).ok())
|
||||
.map(|s| CString::new(s).unwrap().into_raw());
|
||||
let tool_choice_json = chat_request.tool_choice.as_ref()
|
||||
.and_then(|tc| serde_json::to_string(tc).ok())
|
||||
.map(|s| CString::new(s).unwrap().into_raw());
|
||||
let stop_json = chat_request.stop.as_ref()
|
||||
.and_then(|s| serde_json::to_string(s).ok())
|
||||
.map(|s| CString::new(s).unwrap().into_raw());
|
||||
let stop_token_ids_json = chat_request.stop_token_ids.as_ref()
|
||||
.and_then(|ids| serde_json::to_string(ids).ok())
|
||||
.map(|s| CString::new(s).unwrap().into_raw());
|
||||
|
||||
// Create tokenizer handle for converter (we'll create a temporary one)
|
||||
let tokenizer_handle = Box::into_raw(Box::new(TokenizerHandle {
|
||||
tokenizer: Arc::clone(&tokenizer),
|
||||
}));
|
||||
|
||||
let converter = sgl_grpc_response_converter_create(
|
||||
tokenizer_handle,
|
||||
CString::new(chat_request.model.clone()).unwrap().as_ptr(),
|
||||
CString::new(request_id.clone()).unwrap().as_ptr(),
|
||||
tools_json.unwrap_or(ptr::null_mut()),
|
||||
tool_choice_json.unwrap_or(ptr::null_mut()),
|
||||
stop_json.unwrap_or(ptr::null_mut()),
|
||||
stop_token_ids_json.unwrap_or(ptr::null_mut()),
|
||||
if chat_request.skip_special_tokens { 1 } else { 0 },
|
||||
error_out,
|
||||
);
|
||||
|
||||
// Free temporary tokenizer handle (converter now owns the tokenizer)
|
||||
let _ = Box::from_raw(tokenizer_handle);
|
||||
|
||||
if converter.is_null() {
|
||||
return SglErrorCode::MemoryError;
|
||||
}
|
||||
|
||||
// Clean up temporary CStrings
|
||||
if let Some(ptr) = tools_json {
|
||||
let _ = CString::from_raw(ptr);
|
||||
}
|
||||
if let Some(ptr) = tool_choice_json {
|
||||
let _ = CString::from_raw(ptr);
|
||||
}
|
||||
if let Some(ptr) = stop_json {
|
||||
let _ = CString::from_raw(ptr);
|
||||
}
|
||||
if let Some(ptr) = stop_token_ids_json {
|
||||
let _ = CString::from_raw(ptr);
|
||||
}
|
||||
|
||||
// Create converter handle and set initial_prompt_tokens immediately
|
||||
let mut converter_handle = *Box::from_raw(converter);
|
||||
converter_handle.initial_prompt_tokens = Some(prompt_tokens);
|
||||
|
||||
// Create stream handle with prompt_tokens
|
||||
*stream_handle_out = Box::into_raw(Box::new(SglangStreamHandle {
|
||||
stream: Arc::new(tokio::sync::Mutex::new(stream)),
|
||||
converter: Arc::new(tokio::sync::Mutex::new(converter_handle)),
|
||||
client: Arc::clone(&client),
|
||||
prompt_tokens,
|
||||
}));
|
||||
|
||||
SglErrorCode::Success
|
||||
}
|
||||
50
third_party/sglang/sgl-model-gateway/bindings/golang/src/error.rs
vendored
Normal file
50
third_party/sglang/sgl-model-gateway/bindings/golang/src/error.rs
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
//! Error handling for FFI functions
|
||||
|
||||
use std::ffi::CString;
|
||||
use std::os::raw::c_char;
|
||||
use std::ptr;
|
||||
|
||||
/// Error codes returned by FFI functions
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SglErrorCode {
|
||||
Success = 0,
|
||||
InvalidArgument = 1,
|
||||
TokenizationError = 2,
|
||||
ParsingError = 3,
|
||||
MemoryError = 4,
|
||||
UnknownError = 99,
|
||||
}
|
||||
|
||||
/// Helper to set error message in FFI output parameter
|
||||
pub fn set_error_message(error_out: *mut *mut c_char, message: &str) {
|
||||
unsafe {
|
||||
if !error_out.is_null() {
|
||||
if let Ok(cstr) = CString::new(message) {
|
||||
*error_out = cstr.into_raw();
|
||||
} else {
|
||||
*error_out = ptr::null_mut();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to set error message from format string
|
||||
pub fn set_error_message_fmt(error_out: *mut *mut c_char, fmt: std::fmt::Arguments) {
|
||||
if !error_out.is_null() {
|
||||
let msg = format!("{}", fmt);
|
||||
set_error_message(error_out, &msg);
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to clear error message
|
||||
pub fn clear_error_message(error_out: *mut *mut c_char) {
|
||||
unsafe {
|
||||
if !error_out.is_null() {
|
||||
*error_out = ptr::null_mut();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for error handling
|
||||
// Note: Some helper functions are kept for potential future use
|
||||
758
third_party/sglang/sgl-model-gateway/bindings/golang/src/grpc_converter.rs
vendored
Normal file
758
third_party/sglang/sgl-model-gateway/bindings/golang/src/grpc_converter.rs
vendored
Normal file
@@ -0,0 +1,758 @@
|
||||
//! gRPC response converter FFI functions
|
||||
|
||||
use std::ffi::{CStr, CString};
|
||||
use std::os::raw::{c_char, c_int};
|
||||
use std::ptr;
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashMap;
|
||||
use serde_json::Value;
|
||||
use tokio::runtime::Runtime;
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
use smg::tokenizer::traits::Tokenizer;
|
||||
use smg::tokenizer::stream::DecodeStream;
|
||||
use smg::tool_parser::ToolParser;
|
||||
use smg::protocols::common::{Tool, ToolChoice, ToolChoiceValue, ToolCallDelta, FunctionCallDelta, Usage, StringOrArray};
|
||||
use smg::tokenizer::stop::StopSequenceDecoder;
|
||||
use smg_grpc_client::sglang_proto as proto;
|
||||
|
||||
use super::error::{SglErrorCode, set_error_message, clear_error_message};
|
||||
use super::tokenizer::TokenizerHandle;
|
||||
use super::utils::generate_tool_call_id;
|
||||
|
||||
/// Global parser factory (initialized once)
|
||||
// Use the re-exported ParserFactory from tool_parser module
|
||||
static PARSER_FACTORY: Lazy<smg::tool_parser::ParserFactory> = Lazy::new(|| {
|
||||
// ParserFactory is re-exported from tool_parser::factory, so we can use it directly
|
||||
smg::tool_parser::ParserFactory::default()
|
||||
});
|
||||
|
||||
/// Global tokio runtime for async operations
|
||||
static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
|
||||
Runtime::new().expect("Failed to create tokio runtime for gRPC converter FFI")
|
||||
});
|
||||
|
||||
/// Handle for gRPC response converter (maintains state for streaming)
|
||||
#[repr(C)]
|
||||
pub struct GrpcResponseConverterHandle {
|
||||
pub(crate) tokenizer: Arc<dyn Tokenizer>,
|
||||
pub(crate) tool_parser: Option<Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>>,
|
||||
pub(crate) stop_decoder: Option<Arc<tokio::sync::Mutex<StopSequenceDecoder>>>,
|
||||
pub(crate) model: String,
|
||||
pub(crate) request_id: String,
|
||||
pub(crate) created: u64,
|
||||
pub(crate) system_fingerprint: Option<String>,
|
||||
pub(crate) tools: Option<Vec<Tool>>,
|
||||
pub(crate) tool_choice: Option<ToolChoice>,
|
||||
pub(crate) history_tool_calls_count: usize,
|
||||
pub(crate) stream_buffers: HashMap<u32, String>, // Per-index text buffers
|
||||
pub(crate) decode_streams: HashMap<u32, DecodeStream>, // Per-index incremental decoders
|
||||
pub(crate) has_tool_calls: HashMap<u32, bool>, // Track if tool calls were emitted
|
||||
pub(crate) is_first_chunk: HashMap<u32, bool>, // Track first chunk per index
|
||||
pub(crate) prompt_tokens: HashMap<u32, i32>, // Track prompt tokens per index (from chunks)
|
||||
pub(crate) completion_tokens: HashMap<u32, i32>, // Track completion tokens per index (cumulative)
|
||||
pub(crate) initial_prompt_tokens: Option<i32>, // Initial prompt tokens from request (if available)
|
||||
pub(crate) skip_special_tokens: bool, // Whether to skip special tokens when decoding
|
||||
}
|
||||
|
||||
/// Create a gRPC response converter handle
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tokenizer_handle` - Tokenizer handle (must be valid)
|
||||
/// * `model` - Model name
|
||||
/// * `request_id` - Request ID
|
||||
/// * `tools_json` - Optional JSON array of tools
|
||||
/// * `tool_choice_json` - Optional JSON object for tool_choice
|
||||
/// * `stop` - Optional stop sequences (JSON array)
|
||||
/// * `stop_token_ids` - Optional stop token IDs (JSON array)
|
||||
/// * `skip_special_tokens` - Whether to skip special tokens
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * Pointer to GrpcResponseConverterHandle on success, null on failure
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_grpc_response_converter_create(
|
||||
tokenizer_handle: *mut TokenizerHandle,
|
||||
model: *const c_char,
|
||||
request_id: *const c_char,
|
||||
tools_json: *const c_char,
|
||||
tool_choice_json: *const c_char,
|
||||
stop: *const c_char,
|
||||
stop_token_ids: *const c_char,
|
||||
skip_special_tokens: c_int,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> *mut GrpcResponseConverterHandle {
|
||||
if tokenizer_handle.is_null() || model.is_null() || request_id.is_null() {
|
||||
set_error_message(error_out, "Invalid arguments: null pointer");
|
||||
return ptr::null_mut();
|
||||
}
|
||||
|
||||
let model_str = match CStr::from_ptr(model).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in model");
|
||||
return ptr::null_mut();
|
||||
}
|
||||
};
|
||||
|
||||
let request_id_str = match CStr::from_ptr(request_id).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in request_id");
|
||||
return ptr::null_mut();
|
||||
}
|
||||
};
|
||||
|
||||
let handle_ref = &*tokenizer_handle;
|
||||
let tokenizer = Arc::clone(&handle_ref.tokenizer);
|
||||
|
||||
// Parse tools if provided
|
||||
let tools: Option<Vec<Tool>> = if !tools_json.is_null() {
|
||||
match CStr::from_ptr(tools_json).to_str() {
|
||||
Ok(s) => serde_json::from_str::<Vec<Tool>>(s).ok(),
|
||||
Err(_) => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Parse tool_choice if provided
|
||||
let tool_choice: Option<ToolChoice> = if !tool_choice_json.is_null() {
|
||||
match CStr::from_ptr(tool_choice_json).to_str() {
|
||||
Ok(s) => serde_json::from_str::<ToolChoice>(s).ok(),
|
||||
Err(_) => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Parse stop sequences
|
||||
let stop: Option<StringOrArray> = if !stop.is_null() {
|
||||
let stop_str = match CStr::from_ptr(stop).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => return ptr::null_mut(),
|
||||
};
|
||||
serde_json::from_str::<StringOrArray>(stop_str).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Parse stop token IDs
|
||||
let stop_token_ids: Option<Vec<u32>> = if !stop_token_ids.is_null() {
|
||||
let ids_str = match CStr::from_ptr(stop_token_ids).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => return ptr::null_mut(),
|
||||
};
|
||||
serde_json::from_str::<Vec<u32>>(ids_str).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Create stop decoder if needed
|
||||
let stop_decoder = if stop.is_some() || stop_token_ids.is_some() {
|
||||
Some(Arc::new(tokio::sync::Mutex::new(
|
||||
smg::routers::grpc::utils::create_stop_decoder(
|
||||
&tokenizer,
|
||||
stop.as_ref(),
|
||||
stop_token_ids.as_ref(),
|
||||
skip_special_tokens != 0,
|
||||
false, // no_stop_trim
|
||||
),
|
||||
)))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Create tool parser if tools are provided
|
||||
let tool_parser = if tools.is_some() {
|
||||
PARSER_FACTORY.registry().create_for_model(model_str)
|
||||
.map(|p| Arc::new(tokio::sync::Mutex::new(p)))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Get system fingerprint from model (simplified)
|
||||
let system_fingerprint = Some("fp_placeholder".to_string()); // TODO: Get actual fingerprint
|
||||
|
||||
Box::into_raw(Box::new(GrpcResponseConverterHandle {
|
||||
tokenizer,
|
||||
tool_parser,
|
||||
stop_decoder,
|
||||
model: model_str.to_string(),
|
||||
request_id: request_id_str.to_string(),
|
||||
created: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
system_fingerprint,
|
||||
tools,
|
||||
tool_choice,
|
||||
history_tool_calls_count: 0,
|
||||
stream_buffers: HashMap::new(),
|
||||
decode_streams: HashMap::new(),
|
||||
has_tool_calls: HashMap::new(),
|
||||
is_first_chunk: HashMap::new(),
|
||||
prompt_tokens: HashMap::new(),
|
||||
completion_tokens: HashMap::new(),
|
||||
initial_prompt_tokens: None, // Will be set from stream handle
|
||||
skip_special_tokens: skip_special_tokens != 0,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Convert a gRPC GenerateResponse chunk to OpenAI format
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `handle` - Converter handle
|
||||
/// * `response_json` - JSON string of proto.GenerateResponse
|
||||
/// * `result_json_out` - Pointer to receive OpenAI format JSON (must be freed with sgl_free_string)
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * SglErrorCode::Success on success, error code on failure
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_grpc_response_converter_convert_chunk(
|
||||
handle: *mut GrpcResponseConverterHandle,
|
||||
response_json: *const c_char,
|
||||
result_json_out: *mut *mut c_char,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> SglErrorCode {
|
||||
if handle.is_null() || response_json.is_null() || result_json_out.is_null() {
|
||||
set_error_message(error_out, "Invalid arguments: null pointer");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
|
||||
let response_str = match CStr::from_ptr(response_json).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in response_json");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse proto.GenerateResponse from JSON
|
||||
let json_value: Value = match serde_json::from_str(response_str) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to parse response JSON: {}", e));
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
};
|
||||
|
||||
// Build proto::GenerateResponse from JSON value
|
||||
let mut proto_response = proto::GenerateResponse {
|
||||
request_id: json_value.get("request_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
response: None,
|
||||
};
|
||||
|
||||
// Parse the response oneof field
|
||||
if let Some(chunk_json) = json_value.get("chunk") {
|
||||
let chunk = proto::GenerateStreamChunk {
|
||||
token_ids: chunk_json.get("token_ids")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| arr.iter().filter_map(|v| v.as_u64().map(|n| n as u32)).collect())
|
||||
.unwrap_or_default(),
|
||||
prompt_tokens: chunk_json.get("prompt_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
completion_tokens: chunk_json.get("completion_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
cached_tokens: chunk_json.get("cached_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
output_logprobs: None,
|
||||
hidden_states: vec![],
|
||||
input_logprobs: None,
|
||||
index: 0,
|
||||
};
|
||||
proto_response.response = Some(proto::generate_response::Response::Chunk(chunk));
|
||||
} else if let Some(complete_json) = json_value.get("complete") {
|
||||
let complete = proto::GenerateComplete {
|
||||
output_ids: complete_json.get("output_ids")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| arr.iter().filter_map(|v| v.as_u64().map(|n| n as u32)).collect())
|
||||
.unwrap_or_default(),
|
||||
finish_reason: complete_json.get("finish_reason")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
prompt_tokens: complete_json.get("prompt_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
completion_tokens: complete_json.get("completion_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
cached_tokens: complete_json.get("cached_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
output_logprobs: None,
|
||||
all_hidden_states: vec![],
|
||||
input_logprobs: None,
|
||||
matched_stop: None,
|
||||
index: 0,
|
||||
};
|
||||
proto_response.response = Some(proto::generate_response::Response::Complete(complete));
|
||||
} else if let Some(error_json) = json_value.get("error") {
|
||||
let error = proto::GenerateError {
|
||||
message: error_json.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
http_status_code: error_json.get("http_status_code")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("500")
|
||||
.to_string(),
|
||||
details: error_json.get("details")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
};
|
||||
proto_response.response = Some(proto::generate_response::Response::Error(error));
|
||||
} else {
|
||||
set_error_message(error_out, "Response JSON must contain 'chunk', 'complete', or 'error' field");
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
|
||||
let handle_ref = &mut *handle;
|
||||
let tokenizer = Arc::clone(&handle_ref.tokenizer);
|
||||
let model = handle_ref.model.clone();
|
||||
let request_id = handle_ref.request_id.clone();
|
||||
let created = handle_ref.created;
|
||||
let system_fingerprint = handle_ref.system_fingerprint.clone();
|
||||
|
||||
// Use tokio runtime to run async code
|
||||
let result = RUNTIME.block_on(async {
|
||||
convert_proto_chunk_to_openai(
|
||||
proto_response,
|
||||
handle_ref,
|
||||
&tokenizer,
|
||||
&model,
|
||||
&request_id,
|
||||
created,
|
||||
system_fingerprint.as_deref(),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
match result {
|
||||
Ok(Some(openai_response)) => {
|
||||
// Serialize to JSON
|
||||
let result_str = match serde_json::to_string(&openai_response) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to serialize response: {}", e));
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
};
|
||||
|
||||
let result_cstr = match CString::new(result_str) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to create result string: {}", e));
|
||||
return SglErrorCode::MemoryError;
|
||||
}
|
||||
};
|
||||
|
||||
*result_json_out = result_cstr.into_raw();
|
||||
clear_error_message(error_out);
|
||||
SglErrorCode::Success
|
||||
}
|
||||
Ok(None) => {
|
||||
// No response to send (e.g., empty chunk)
|
||||
let empty = CString::new("").unwrap();
|
||||
*result_json_out = empty.into_raw();
|
||||
clear_error_message(error_out);
|
||||
SglErrorCode::Success
|
||||
}
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Conversion error: {}", e));
|
||||
SglErrorCode::ParsingError
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to convert proto chunk to OpenAI format
|
||||
pub(crate) async fn convert_proto_chunk_to_openai(
|
||||
proto_response: proto::GenerateResponse,
|
||||
handle: &mut GrpcResponseConverterHandle,
|
||||
tokenizer: &Arc<dyn Tokenizer>,
|
||||
model: &str,
|
||||
request_id: &str,
|
||||
created: u64,
|
||||
system_fingerprint: Option<&str>,
|
||||
) -> Result<Option<smg::protocols::chat::ChatCompletionStreamResponse>, String> {
|
||||
use smg_grpc_client::sglang_proto::generate_response::Response::*;
|
||||
use smg::protocols::chat::{ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice};
|
||||
|
||||
match proto_response.response {
|
||||
Some(Chunk(chunk)) => {
|
||||
let index = chunk.index;
|
||||
|
||||
// Mark as not first chunk if we've seen this index before
|
||||
let is_first = handle.is_first_chunk.entry(index).or_insert(true);
|
||||
let first_chunk = *is_first;
|
||||
*is_first = false;
|
||||
|
||||
// Track token counts from chunks (cumulative values from proto)
|
||||
// These are cumulative values, so we always use the latest value
|
||||
// For prompt_tokens, if chunk value is 0, preserve existing value or use initial_prompt_tokens
|
||||
// This prevents overwriting valid prompt_tokens with 0
|
||||
if chunk.prompt_tokens > 0 {
|
||||
handle.prompt_tokens.insert(index, chunk.prompt_tokens);
|
||||
} else {
|
||||
// If chunk.prompt_tokens is 0, try to preserve existing value or use initial_prompt_tokens
|
||||
if !handle.prompt_tokens.contains_key(&index) {
|
||||
// No existing value, try to use initial_prompt_tokens
|
||||
if let Some(initial_prompt) = handle.initial_prompt_tokens {
|
||||
handle.prompt_tokens.insert(index, initial_prompt);
|
||||
}
|
||||
}
|
||||
// If existing value exists, keep it (don't overwrite with 0)
|
||||
}
|
||||
// For completion_tokens, always update (even if 0) as it's cumulative
|
||||
handle.completion_tokens.insert(index, chunk.completion_tokens);
|
||||
|
||||
// Process tokens through stop decoder if available, otherwise use incremental decoder
|
||||
let chunk_text = if let Some(ref stop_decoder) = handle.stop_decoder {
|
||||
let mut decoder_guard = stop_decoder.lock().await;
|
||||
let mut text = String::new();
|
||||
for &token_id in &chunk.token_ids {
|
||||
match decoder_guard.process_token(token_id).unwrap_or_else(|_| {
|
||||
smg::tokenizer::stop::SequenceDecoderOutput::Held
|
||||
}) {
|
||||
smg::tokenizer::stop::SequenceDecoderOutput::Text(t) => {
|
||||
text.push_str(&t);
|
||||
}
|
||||
smg::tokenizer::stop::SequenceDecoderOutput::StoppedWithText(t) => {
|
||||
text.push_str(&t);
|
||||
break;
|
||||
}
|
||||
smg::tokenizer::stop::SequenceDecoderOutput::Stopped => {
|
||||
break;
|
||||
}
|
||||
smg::tokenizer::stop::SequenceDecoderOutput::Held => {}
|
||||
}
|
||||
}
|
||||
text
|
||||
} else {
|
||||
// Use incremental decoder to handle multi-byte character boundaries
|
||||
let decode_stream = handle.decode_streams.entry(index).or_insert_with(|| {
|
||||
DecodeStream::new(
|
||||
Arc::clone(&tokenizer),
|
||||
&[], // No prompt tokens for completion
|
||||
handle.skip_special_tokens,
|
||||
)
|
||||
});
|
||||
|
||||
// Process tokens incrementally
|
||||
let mut text_parts = Vec::new();
|
||||
for &token_id in &chunk.token_ids {
|
||||
if let Ok(Some(text)) = decode_stream.step(token_id) {
|
||||
text_parts.push(text);
|
||||
}
|
||||
}
|
||||
text_parts.join("")
|
||||
};
|
||||
|
||||
if chunk_text.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Send first chunk with role
|
||||
if first_chunk {
|
||||
let first_response = ChatCompletionStreamResponse {
|
||||
id: request_id.to_string(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model.to_string(),
|
||||
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
|
||||
choices: vec![ChatStreamChoice {
|
||||
index,
|
||||
delta: ChatMessageDelta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
matched_stop: None,
|
||||
}],
|
||||
usage: None,
|
||||
};
|
||||
return Ok(Some(first_response));
|
||||
}
|
||||
|
||||
// Update stream buffer
|
||||
let stream_buffer = handle.stream_buffers.entry(index).or_default();
|
||||
stream_buffer.push_str(&chunk_text);
|
||||
|
||||
// Handle tool calls if tools are provided
|
||||
if let (Some(ref tools), Some(ref tool_parser)) = (handle.tools.as_ref(), handle.tool_parser.as_ref()) {
|
||||
let tool_choice_enabled = !matches!(
|
||||
handle.tool_choice,
|
||||
Some(ToolChoice::Value(ToolChoiceValue::None))
|
||||
);
|
||||
|
||||
if tool_choice_enabled {
|
||||
let mut parser_guard = tool_parser.lock().await;
|
||||
match parser_guard.parse_incremental(&chunk_text, tools).await {
|
||||
Ok(streaming_result) => {
|
||||
if !streaming_result.calls.is_empty() {
|
||||
handle.has_tool_calls.insert(index, true);
|
||||
// Convert tool call items to OpenAI format
|
||||
let tool_call_deltas: Vec<_> = streaming_result
|
||||
.calls
|
||||
.into_iter()
|
||||
.map(|item| {
|
||||
let id = if let Some(ref name) = item.name {
|
||||
generate_tool_call_id(
|
||||
model,
|
||||
name,
|
||||
item.tool_index,
|
||||
handle.history_tool_calls_count,
|
||||
)
|
||||
} else {
|
||||
format!("call_{}", item.tool_index)
|
||||
};
|
||||
|
||||
ToolCallDelta {
|
||||
index: item.tool_index as u32,
|
||||
id: Some(id),
|
||||
tool_type: if item.name.is_some() {
|
||||
Some("function".to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
function: Some(FunctionCallDelta {
|
||||
name: item.name,
|
||||
arguments: if !item.parameters.is_empty() {
|
||||
Some(item.parameters)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
}),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let tool_response = ChatCompletionStreamResponse {
|
||||
id: request_id.to_string(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model.to_string(),
|
||||
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
|
||||
choices: vec![ChatStreamChoice {
|
||||
index,
|
||||
delta: ChatMessageDelta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: None,
|
||||
tool_calls: Some(tool_call_deltas),
|
||||
reasoning_content: None,
|
||||
},
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
matched_stop: None,
|
||||
}],
|
||||
usage: None,
|
||||
};
|
||||
return Ok(Some(tool_response));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// Log error but continue with regular content
|
||||
tracing::warn!("Tool parser error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Regular content emission
|
||||
let content_response = ChatCompletionStreamResponse {
|
||||
id: request_id.to_string(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model.to_string(),
|
||||
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
|
||||
choices: vec![ChatStreamChoice {
|
||||
index,
|
||||
delta: ChatMessageDelta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: Some(chunk_text),
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
logprobs: None,
|
||||
finish_reason: None,
|
||||
matched_stop: None,
|
||||
}],
|
||||
usage: None,
|
||||
};
|
||||
|
||||
Ok(Some(content_response))
|
||||
}
|
||||
Some(Complete(complete)) => {
|
||||
let index = complete.index;
|
||||
|
||||
// Flush any remaining text
|
||||
// Flush any remaining text from decode stream
|
||||
let mut final_text = handle.stream_buffers.remove(&index).unwrap_or_default();
|
||||
if let Some(ref mut decode_stream) = handle.decode_streams.get_mut(&index) {
|
||||
if let Ok(Some(remaining)) = decode_stream.flush() {
|
||||
final_text.push_str(&remaining);
|
||||
}
|
||||
}
|
||||
handle.decode_streams.remove(&index);
|
||||
|
||||
// Determine finish reason - ensure it's never empty
|
||||
// If finish_reason is empty, try to infer from other fields or use default
|
||||
let finish_reason = if handle.has_tool_calls.get(&index).copied().unwrap_or(false)
|
||||
&& (complete.finish_reason == "stop" || complete.finish_reason.is_empty())
|
||||
{
|
||||
"tool_calls".to_string()
|
||||
} else if complete.finish_reason.is_empty() || complete.finish_reason.trim().is_empty() {
|
||||
// If finish_reason is empty, try to infer from completion_tokens or use default
|
||||
if complete.completion_tokens > 0 {
|
||||
// If we have completion tokens, likely stopped normally
|
||||
"stop".to_string()
|
||||
} else if !complete.output_ids.is_empty() {
|
||||
// If we have output_ids, likely stopped normally
|
||||
"stop".to_string()
|
||||
} else {
|
||||
// Default fallback - always ensure we have a value
|
||||
"stop".to_string()
|
||||
}
|
||||
} else {
|
||||
complete.finish_reason.clone()
|
||||
};
|
||||
|
||||
// Ensure finish_reason is never empty (defensive check)
|
||||
let finish_reason = if finish_reason.is_empty() || finish_reason.trim().is_empty() {
|
||||
"stop".to_string()
|
||||
} else {
|
||||
finish_reason
|
||||
};
|
||||
|
||||
// Extract matched_stop
|
||||
let matched_stop = match &complete.matched_stop {
|
||||
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
|
||||
Some(Value::Number(serde_json::Number::from(*token_id)))
|
||||
}
|
||||
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
|
||||
Some(Value::String(stop_str.clone()))
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
// Build usage - prefer values from complete message, but fallback to accumulated values from chunks
|
||||
// Complete message should have the final values, but sometimes they might be 0 or missing
|
||||
// Always use the latest cumulative value from chunks if available, otherwise use complete message value
|
||||
let mut prompt_tokens = handle.prompt_tokens.get(&index)
|
||||
.copied()
|
||||
.filter(|&v| v > 0)
|
||||
.unwrap_or(complete.prompt_tokens);
|
||||
let mut completion_tokens = handle.completion_tokens.get(&index)
|
||||
.copied()
|
||||
.filter(|&v| v > 0)
|
||||
.unwrap_or(complete.completion_tokens);
|
||||
|
||||
// Always try to use initial_prompt_tokens if prompt_tokens is 0 or missing
|
||||
// This is the most reliable source for prompt tokens since we calculate it from the request
|
||||
if prompt_tokens == 0 {
|
||||
if let Some(initial_prompt) = handle.initial_prompt_tokens {
|
||||
prompt_tokens = initial_prompt;
|
||||
}
|
||||
}
|
||||
|
||||
// If completion_tokens is 0, try to infer from output_ids or accumulated chunks
|
||||
if completion_tokens == 0 {
|
||||
// Try to use completion_tokens from complete message even if 0
|
||||
// Or calculate from output_ids
|
||||
if complete.completion_tokens > 0 {
|
||||
completion_tokens = complete.completion_tokens;
|
||||
} else if !complete.output_ids.is_empty() {
|
||||
completion_tokens = complete.output_ids.len() as i32;
|
||||
} else if let Some(&last_completion) = handle.completion_tokens.get(&index) {
|
||||
completion_tokens = last_completion;
|
||||
}
|
||||
}
|
||||
|
||||
// Final fallback: if both are still 0, try to use initial_prompt_tokens for prompt
|
||||
// and calculate completion from output_ids
|
||||
if prompt_tokens == 0 && completion_tokens == 0 {
|
||||
// Try to infer from output_ids if available
|
||||
let output_ids_len = complete.output_ids.len() as i32;
|
||||
if output_ids_len > 0 {
|
||||
completion_tokens = output_ids_len;
|
||||
// Always try to use initial_prompt_tokens for prompt
|
||||
if let Some(initial_prompt) = handle.initial_prompt_tokens {
|
||||
prompt_tokens = initial_prompt;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final defensive check: ensure prompt_tokens is set if we have initial_prompt_tokens
|
||||
if prompt_tokens == 0 {
|
||||
if let Some(initial_prompt) = handle.initial_prompt_tokens {
|
||||
prompt_tokens = initial_prompt;
|
||||
}
|
||||
}
|
||||
|
||||
// Always create usage, even if values are 0 (defensive)
|
||||
let usage = Some(Usage {
|
||||
prompt_tokens: prompt_tokens.max(0) as u32,
|
||||
completion_tokens: completion_tokens.max(0) as u32,
|
||||
total_tokens: (prompt_tokens.max(0) + completion_tokens.max(0)) as u32,
|
||||
completion_tokens_details: None,
|
||||
});
|
||||
|
||||
let finish_response = ChatCompletionStreamResponse {
|
||||
id: request_id.to_string(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model.to_string(),
|
||||
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
|
||||
choices: vec![ChatStreamChoice {
|
||||
index,
|
||||
delta: ChatMessageDelta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: if !final_text.is_empty() {
|
||||
Some(final_text)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
logprobs: None,
|
||||
finish_reason: Some(finish_reason),
|
||||
matched_stop,
|
||||
}],
|
||||
usage,
|
||||
};
|
||||
|
||||
Ok(Some(finish_response))
|
||||
}
|
||||
Some(Error(error)) => {
|
||||
Err(format!("Server error: {} (status: {})", error.message, error.http_status_code))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Free a gRPC response converter handle
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_grpc_response_converter_free(handle: *mut GrpcResponseConverterHandle) {
|
||||
if !handle.is_null() {
|
||||
let _ = Box::from_raw(handle);
|
||||
}
|
||||
}
|
||||
103
third_party/sglang/sgl-model-gateway/bindings/golang/src/lib.rs
vendored
Normal file
103
third_party/sglang/sgl-model-gateway/bindings/golang/src/lib.rs
vendored
Normal file
@@ -0,0 +1,103 @@
|
||||
//! FFI module for exposing sgl-model-gateway preprocessing and postprocessing functions
|
||||
//! to C-compatible languages (e.g., Golang via cgo)
|
||||
//!
|
||||
//! This module provides C-compatible function signatures for:
|
||||
//! - Tokenizer operations (encode, decode, chat template)
|
||||
//! - Tool parser operations (parse tool calls)
|
||||
//! - Tool constraint generation
|
||||
//! - gRPC client SDK (complete request-response flow)
|
||||
//!
|
||||
//! # Safety
|
||||
//! All functions marked with `#[no_mangle]` and `extern "C"` must be called
|
||||
//! with valid pointers and follow the documented memory management rules.
|
||||
|
||||
// Re-export error types
|
||||
pub use error::{SglErrorCode, set_error_message, set_error_message_fmt, clear_error_message};
|
||||
|
||||
// Re-export memory management functions
|
||||
pub use memory::{sgl_free_string, sgl_free_token_ids};
|
||||
|
||||
// Re-export tokenizer functions
|
||||
pub use tokenizer::{
|
||||
TokenizerHandle,
|
||||
sgl_tokenizer_create_from_file,
|
||||
sgl_tokenizer_encode,
|
||||
sgl_tokenizer_apply_chat_template,
|
||||
sgl_tokenizer_apply_chat_template_with_tools,
|
||||
sgl_tokenizer_decode,
|
||||
sgl_tokenizer_free,
|
||||
};
|
||||
|
||||
// Re-export tool parser functions
|
||||
pub use tool_parser::{
|
||||
ToolParserHandle,
|
||||
sgl_tool_parser_create,
|
||||
sgl_tool_parser_parse_complete,
|
||||
sgl_tool_parser_parse_incremental,
|
||||
sgl_tool_parser_reset,
|
||||
sgl_tool_parser_free,
|
||||
};
|
||||
|
||||
// Re-export gRPC converter functions
|
||||
pub use grpc_converter::{
|
||||
GrpcResponseConverterHandle,
|
||||
sgl_grpc_response_converter_create,
|
||||
sgl_grpc_response_converter_convert_chunk,
|
||||
sgl_grpc_response_converter_free,
|
||||
};
|
||||
|
||||
// Re-export client SDK functions
|
||||
pub use client::{
|
||||
SglangClientHandle,
|
||||
sgl_client_create,
|
||||
sgl_client_free,
|
||||
};
|
||||
|
||||
// Re-export stream functions
|
||||
pub use stream::{
|
||||
SglangStreamHandle,
|
||||
sgl_stream_read_next,
|
||||
sgl_stream_free,
|
||||
};
|
||||
|
||||
// Re-export client stream function (defined in client.rs but used by stream)
|
||||
pub use client::sgl_client_chat_completion_stream;
|
||||
|
||||
// Re-export preprocessor functions
|
||||
pub use preprocessor::{
|
||||
sgl_preprocess_chat_request,
|
||||
sgl_preprocess_chat_request_with_tokenizer,
|
||||
sgl_preprocessed_request_free,
|
||||
};
|
||||
|
||||
// Re-export postprocessor functions
|
||||
pub use postprocessor::{
|
||||
sgl_postprocess_stream_chunk,
|
||||
sgl_postprocess_stream_chunks_batch,
|
||||
};
|
||||
|
||||
// Re-export utility functions
|
||||
pub use utils::sgl_generate_tool_constraints;
|
||||
|
||||
// Sub-modules
|
||||
mod error;
|
||||
mod memory;
|
||||
mod tokenizer;
|
||||
mod tool_parser;
|
||||
mod grpc_converter;
|
||||
mod client;
|
||||
mod stream;
|
||||
mod utils;
|
||||
mod preprocessor;
|
||||
mod postprocessor;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_error_codes() {
|
||||
assert_eq!(SglErrorCode::Success as i32, 0);
|
||||
assert_eq!(SglErrorCode::InvalidArgument as i32, 1);
|
||||
}
|
||||
}
|
||||
28
third_party/sglang/sgl-model-gateway/bindings/golang/src/memory.rs
vendored
Normal file
28
third_party/sglang/sgl-model-gateway/bindings/golang/src/memory.rs
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
//! Memory management for FFI functions
|
||||
|
||||
use std::ffi::CString;
|
||||
use std::os::raw::c_char;
|
||||
|
||||
/// Free a C string allocated by Rust
|
||||
///
|
||||
/// # Safety
|
||||
/// This function must only be called with pointers returned by other FFI functions.
|
||||
/// Calling with arbitrary pointers or multiple times on the same pointer is undefined behavior.
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_free_string(s: *mut c_char) {
|
||||
if !s.is_null() {
|
||||
let _ = CString::from_raw(s);
|
||||
}
|
||||
}
|
||||
|
||||
/// Free token IDs array allocated by Rust
|
||||
///
|
||||
/// # Safety
|
||||
/// This function must only be called with pointers returned by `sgl_tokenizer_encode`.
|
||||
/// The `count` parameter must match the length of the array.
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_free_token_ids(ptr: *mut u32, count: usize) {
|
||||
if !ptr.is_null() && count > 0 {
|
||||
let _ = Vec::from_raw_parts(ptr, count, count);
|
||||
}
|
||||
}
|
||||
465
third_party/sglang/sgl-model-gateway/bindings/golang/src/postprocessor.rs
vendored
Normal file
465
third_party/sglang/sgl-model-gateway/bindings/golang/src/postprocessor.rs
vendored
Normal file
@@ -0,0 +1,465 @@
|
||||
//! Postprocessing FFI functions for gRPC stream chunks
|
||||
//!
|
||||
//! This module provides C-compatible functions for postprocessing gRPC stream chunks:
|
||||
//! - Parse tool calls from model output
|
||||
//! - Convert proto format to OpenAI format
|
||||
//! - Handle reasoning content parsing
|
||||
//!
|
||||
//! These functions are designed to be called for each stream chunk, but can be optimized
|
||||
//! with batching in the future.
|
||||
|
||||
use std::ffi::{CStr, CString};
|
||||
use std::os::raw::{c_char, c_int};
|
||||
use std::ptr;
|
||||
use std::sync::Arc;
|
||||
use serde_json::Value;
|
||||
|
||||
use smg_grpc_client::sglang_proto as proto;
|
||||
|
||||
use super::error::{SglErrorCode, set_error_message};
|
||||
use super::grpc_converter::GrpcResponseConverterHandle;
|
||||
|
||||
use tokio::runtime::Runtime;
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
/// Global tokio runtime for async operations
|
||||
static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
|
||||
Runtime::new().expect("Failed to create tokio runtime for postprocessor FFI")
|
||||
});
|
||||
|
||||
/// Postprocess a gRPC stream chunk to OpenAI format
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Parses the proto chunk from JSON
|
||||
/// 2. Converts it to OpenAI format using the converter handle
|
||||
/// 3. Returns the OpenAI format JSON
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `converter_handle` - Converter handle (created with sgl_grpc_response_converter_create)
|
||||
/// * `proto_chunk_json` - JSON string of proto.GenerateResponse
|
||||
/// * `openai_json_out` - Pointer to receive OpenAI format JSON (must be freed with sgl_free_string)
|
||||
/// * `is_done_out` - Pointer to receive is_done flag (1 if stream is complete, 0 otherwise)
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * SglErrorCode::Success on success, error code on failure
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_postprocess_stream_chunk(
|
||||
converter_handle: *mut GrpcResponseConverterHandle,
|
||||
proto_chunk_json: *const c_char,
|
||||
openai_json_out: *mut *mut c_char,
|
||||
is_done_out: *mut c_int,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> SglErrorCode {
|
||||
if converter_handle.is_null()
|
||||
|| proto_chunk_json.is_null()
|
||||
|| openai_json_out.is_null()
|
||||
|| is_done_out.is_null()
|
||||
{
|
||||
set_error_message(error_out, "Invalid arguments: null pointer");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
|
||||
let proto_chunk_str = match CStr::from_ptr(proto_chunk_json).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in proto_chunk_json");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse proto.GenerateResponse from JSON
|
||||
let json_value: Value = match serde_json::from_str(proto_chunk_str) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to parse proto chunk JSON: {}", e));
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
};
|
||||
|
||||
// Build proto::GenerateResponse from JSON value
|
||||
let mut proto_response = proto::GenerateResponse {
|
||||
request_id: json_value
|
||||
.get("request_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
response: None,
|
||||
};
|
||||
|
||||
// Parse the response oneof field
|
||||
let is_done = if let Some(chunk_json) = json_value.get("chunk") {
|
||||
let chunk = proto::GenerateStreamChunk {
|
||||
token_ids: chunk_json
|
||||
.get("token_ids")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_u64().map(|n| n as u32))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
prompt_tokens: chunk_json
|
||||
.get("prompt_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
completion_tokens: chunk_json
|
||||
.get("completion_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
cached_tokens: chunk_json
|
||||
.get("cached_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
output_logprobs: None,
|
||||
hidden_states: vec![],
|
||||
input_logprobs: None,
|
||||
index: 0,
|
||||
};
|
||||
proto_response.response = Some(proto::generate_response::Response::Chunk(chunk));
|
||||
false
|
||||
} else if let Some(complete_json) = json_value.get("complete") {
|
||||
let complete = proto::GenerateComplete {
|
||||
output_ids: complete_json
|
||||
.get("output_ids")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_u64().map(|n| n as u32))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
finish_reason: complete_json
|
||||
.get("finish_reason")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
prompt_tokens: complete_json
|
||||
.get("prompt_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
completion_tokens: complete_json
|
||||
.get("completion_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
cached_tokens: complete_json
|
||||
.get("cached_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
output_logprobs: None,
|
||||
all_hidden_states: vec![],
|
||||
input_logprobs: None,
|
||||
matched_stop: None,
|
||||
index: 0,
|
||||
};
|
||||
proto_response.response = Some(proto::generate_response::Response::Complete(complete));
|
||||
true
|
||||
} else if let Some(error_json) = json_value.get("error") {
|
||||
let error = proto::GenerateError {
|
||||
message: error_json
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
http_status_code: error_json
|
||||
.get("http_status_code")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("500")
|
||||
.to_string(),
|
||||
details: error_json
|
||||
.get("details")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
};
|
||||
proto_response.response = Some(proto::generate_response::Response::Error(error));
|
||||
true
|
||||
} else {
|
||||
set_error_message(
|
||||
error_out,
|
||||
"Proto chunk JSON must contain 'chunk', 'complete', or 'error' field",
|
||||
);
|
||||
return SglErrorCode::ParsingError;
|
||||
};
|
||||
|
||||
// Convert proto chunk to OpenAI format using the converter's convert_chunk function
|
||||
// We'll use the existing converter API instead of calling the internal function directly
|
||||
let proto_chunk_json_cstr = match CString::new(proto_chunk_str) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to create C string: {}", e));
|
||||
return SglErrorCode::MemoryError;
|
||||
}
|
||||
};
|
||||
|
||||
// Use the existing converter API
|
||||
let mut openai_json_ptr: *mut c_char = ptr::null_mut();
|
||||
let result = super::grpc_converter::sgl_grpc_response_converter_convert_chunk(
|
||||
converter_handle,
|
||||
proto_chunk_json_cstr.as_ptr(),
|
||||
&mut openai_json_ptr,
|
||||
error_out,
|
||||
);
|
||||
|
||||
if result == SglErrorCode::Success {
|
||||
*openai_json_out = openai_json_ptr;
|
||||
*is_done_out = if is_done { 1 } else { 0 };
|
||||
SglErrorCode::Success
|
||||
} else {
|
||||
*openai_json_out = ptr::null_mut();
|
||||
*is_done_out = if is_done { 1 } else { 0 };
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// Postprocess multiple gRPC stream chunks in batch (reduces FFI overhead)
|
||||
///
|
||||
/// This function processes multiple chunks in a single FFI call, significantly reducing
|
||||
/// FFI overhead in streaming scenarios.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `converter_handle` - Converter handle (created with sgl_grpc_response_converter_create)
|
||||
/// * `proto_chunks_json_array` - JSON array string of proto.GenerateResponse chunks
|
||||
/// * `max_chunks` - Maximum number of chunks to process (for safety)
|
||||
/// * `openai_chunks_json_array_out` - Pointer to receive JSON array of OpenAI format chunks (must be freed with sgl_free_string)
|
||||
/// * `chunks_count_out` - Pointer to receive number of processed chunks
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * SglErrorCode::Success on success, error code on failure
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_postprocess_stream_chunks_batch(
|
||||
converter_handle: *mut GrpcResponseConverterHandle,
|
||||
proto_chunks_json_array: *const c_char,
|
||||
max_chunks: c_int,
|
||||
openai_chunks_json_array_out: *mut *mut c_char,
|
||||
chunks_count_out: *mut c_int,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> SglErrorCode {
|
||||
if converter_handle.is_null()
|
||||
|| proto_chunks_json_array.is_null()
|
||||
|| openai_chunks_json_array_out.is_null()
|
||||
|| chunks_count_out.is_null()
|
||||
{
|
||||
set_error_message(error_out, "Invalid arguments: null pointer");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
|
||||
let chunks_array_str = match CStr::from_ptr(proto_chunks_json_array).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in proto_chunks_json_array");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse JSON array of chunks
|
||||
let chunks_array: Vec<Value> = match serde_json::from_str(chunks_array_str) {
|
||||
Ok(arr) => arr,
|
||||
Err(e) => {
|
||||
set_error_message(
|
||||
error_out,
|
||||
&format!("Failed to parse chunks JSON array: {}", e),
|
||||
);
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
};
|
||||
|
||||
// Limit batch size for safety
|
||||
let max_chunks_usize = max_chunks as usize;
|
||||
let chunks_to_process = if chunks_array.len() > max_chunks_usize {
|
||||
&chunks_array[..max_chunks_usize]
|
||||
} else {
|
||||
&chunks_array
|
||||
};
|
||||
|
||||
let handle_ref = &mut *converter_handle;
|
||||
let tokenizer = Arc::clone(&handle_ref.tokenizer);
|
||||
let model = handle_ref.model.clone();
|
||||
let request_id = handle_ref.request_id.clone();
|
||||
let created = handle_ref.created;
|
||||
let system_fingerprint = handle_ref.system_fingerprint.clone();
|
||||
|
||||
// Process chunks in batch
|
||||
let mut results = Vec::new();
|
||||
let mut has_error = false;
|
||||
let mut error_msg = String::new();
|
||||
|
||||
for chunk_json in chunks_to_process {
|
||||
// Parse proto.GenerateResponse from JSON
|
||||
let mut proto_response = proto::GenerateResponse {
|
||||
request_id: chunk_json
|
||||
.get("request_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
response: None,
|
||||
};
|
||||
|
||||
// Parse the response oneof field (same logic as single chunk processing)
|
||||
let _is_done = if let Some(chunk_json) = chunk_json.get("chunk") {
|
||||
let chunk = proto::GenerateStreamChunk {
|
||||
token_ids: chunk_json
|
||||
.get("token_ids")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_u64().map(|n| n as u32))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
prompt_tokens: chunk_json
|
||||
.get("prompt_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
completion_tokens: chunk_json
|
||||
.get("completion_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
cached_tokens: chunk_json
|
||||
.get("cached_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
output_logprobs: None,
|
||||
hidden_states: vec![],
|
||||
input_logprobs: None,
|
||||
index: 0,
|
||||
};
|
||||
proto_response.response = Some(proto::generate_response::Response::Chunk(chunk));
|
||||
false
|
||||
} else if let Some(complete_json) = chunk_json.get("complete") {
|
||||
let complete = proto::GenerateComplete {
|
||||
output_ids: complete_json
|
||||
.get("output_ids")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_u64().map(|n| n as u32))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
finish_reason: complete_json
|
||||
.get("finish_reason")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
prompt_tokens: complete_json
|
||||
.get("prompt_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
completion_tokens: complete_json
|
||||
.get("completion_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
cached_tokens: complete_json
|
||||
.get("cached_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|n| n as i32)
|
||||
.unwrap_or(0),
|
||||
output_logprobs: None,
|
||||
all_hidden_states: vec![],
|
||||
input_logprobs: None,
|
||||
matched_stop: None,
|
||||
index: 0,
|
||||
};
|
||||
proto_response.response = Some(proto::generate_response::Response::Complete(complete));
|
||||
true
|
||||
} else if let Some(error_json) = chunk_json.get("error") {
|
||||
let error = proto::GenerateError {
|
||||
message: error_json
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
http_status_code: error_json
|
||||
.get("http_status_code")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("500")
|
||||
.to_string(),
|
||||
details: error_json
|
||||
.get("details")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
};
|
||||
proto_response.response = Some(proto::generate_response::Response::Error(error));
|
||||
true
|
||||
} else {
|
||||
error_msg = format!(
|
||||
"Chunk JSON must contain 'chunk', 'complete', or 'error' field: {}",
|
||||
chunk_json
|
||||
);
|
||||
has_error = true;
|
||||
break;
|
||||
};
|
||||
|
||||
// Convert proto chunk to OpenAI format
|
||||
let result = RUNTIME.block_on(async {
|
||||
super::grpc_converter::convert_proto_chunk_to_openai(
|
||||
proto_response,
|
||||
handle_ref,
|
||||
&tokenizer,
|
||||
&model,
|
||||
&request_id,
|
||||
created,
|
||||
system_fingerprint.as_deref(),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
match result {
|
||||
Ok(Some(openai_response)) => {
|
||||
results.push(openai_response);
|
||||
}
|
||||
Ok(None) => {
|
||||
// Empty response, skip
|
||||
}
|
||||
Err(e) => {
|
||||
error_msg = format!("Postprocessing failed for chunk: {}", e);
|
||||
has_error = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if has_error {
|
||||
set_error_message(error_out, &error_msg);
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
|
||||
// Serialize results to JSON array
|
||||
let results_json = match serde_json::to_string(&results) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(
|
||||
error_out,
|
||||
&format!("Failed to serialize results JSON array: {}", e),
|
||||
);
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
};
|
||||
|
||||
let results_cstr = match CString::new(results_json) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to create C string: {}", e));
|
||||
return SglErrorCode::MemoryError;
|
||||
}
|
||||
};
|
||||
|
||||
*openai_chunks_json_array_out = results_cstr.into_raw();
|
||||
*chunks_count_out = results.len() as c_int;
|
||||
|
||||
SglErrorCode::Success
|
||||
}
|
||||
372
third_party/sglang/sgl-model-gateway/bindings/golang/src/preprocessor.rs
vendored
Normal file
372
third_party/sglang/sgl-model-gateway/bindings/golang/src/preprocessor.rs
vendored
Normal file
@@ -0,0 +1,372 @@
|
||||
//! Preprocessing FFI functions for chat requests
|
||||
//!
|
||||
//! This module provides C-compatible functions for preprocessing chat completion requests:
|
||||
//! - Apply chat_template to messages
|
||||
//! - Tokenize the processed text
|
||||
//! - Generate tool constraints
|
||||
//!
|
||||
//! These functions are designed to be called once per request, reducing FFI overhead.
|
||||
|
||||
use std::ffi::{CStr, CString};
|
||||
use std::os::raw::{c_char, c_int};
|
||||
use std::ptr;
|
||||
use std::os::raw::c_uint;
|
||||
|
||||
use smg::tokenizer::create_tokenizer_from_file;
|
||||
use smg::protocols::chat::ChatCompletionRequest;
|
||||
use smg::routers::grpc::utils::{process_chat_messages, generate_tool_constraints};
|
||||
|
||||
use super::error::{SglErrorCode, set_error_message};
|
||||
use super::memory::{sgl_free_string, sgl_free_token_ids};
|
||||
use super::tokenizer::TokenizerHandle;
|
||||
|
||||
/// Handle for preprocessed request
|
||||
#[repr(C)]
|
||||
pub struct PreprocessedRequestHandle {
|
||||
pub(crate) prompt_text: CString,
|
||||
pub(crate) token_ids: Vec<i32>,
|
||||
pub(crate) tool_constraints_json: Option<CString>,
|
||||
pub(crate) prompt_tokens: i32,
|
||||
}
|
||||
|
||||
/// Preprocess a chat completion request
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Applies chat_template to messages
|
||||
/// 2. Tokenizes the processed text
|
||||
/// 3. Generates tool constraints (if tools are present)
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `request_json` - OpenAI ChatCompletionRequest as JSON string
|
||||
/// * `tokenizer_path` - Path to tokenizer directory
|
||||
/// * `prompt_text_out` - Pointer to receive prompt text (C string, must be freed with sgl_free_string)
|
||||
/// * `token_ids_out` - Pointer to receive token IDs array (must be freed with sgl_free_token_ids)
|
||||
/// * `token_ids_len_out` - Pointer to receive token IDs array length
|
||||
/// * `tool_constraints_json_out` - Optional pointer to receive tool constraints JSON (must be freed with sgl_free_string)
|
||||
/// * `prompt_tokens_out` - Pointer to receive prompt token count
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * SglErrorCode::Success on success, error code on failure
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_preprocess_chat_request(
|
||||
request_json: *const c_char,
|
||||
tokenizer_path: *const c_char,
|
||||
prompt_text_out: *mut *mut c_char,
|
||||
token_ids_out: *mut *mut c_uint,
|
||||
token_ids_len_out: *mut usize,
|
||||
tool_constraints_json_out: *mut *mut c_char,
|
||||
prompt_tokens_out: *mut c_int,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> SglErrorCode {
|
||||
if request_json.is_null()
|
||||
|| tokenizer_path.is_null()
|
||||
|| prompt_text_out.is_null()
|
||||
|| token_ids_out.is_null()
|
||||
|| token_ids_len_out.is_null()
|
||||
|| prompt_tokens_out.is_null()
|
||||
{
|
||||
set_error_message(error_out, "Invalid arguments: null pointer");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
|
||||
// Parse input strings
|
||||
let request_str = match CStr::from_ptr(request_json).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in request_json");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
let tokenizer_path_str = match CStr::from_ptr(tokenizer_path).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in tokenizer_path");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse ChatCompletionRequest
|
||||
let chat_request: ChatCompletionRequest = match serde_json::from_str(request_str) {
|
||||
Ok(req) => req,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to parse request JSON: {}", e));
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
};
|
||||
|
||||
// Create tokenizer
|
||||
let tokenizer = match create_tokenizer_from_file(tokenizer_path_str) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to create tokenizer: {}", e));
|
||||
return SglErrorCode::TokenizationError;
|
||||
}
|
||||
};
|
||||
|
||||
// Process chat messages (apply chat_template)
|
||||
let processed_messages = match process_chat_messages(&chat_request, tokenizer.as_ref()) {
|
||||
Ok(msgs) => msgs,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to process chat messages: {}", e));
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
};
|
||||
|
||||
// Tokenize the processed text
|
||||
let encoding = match tokenizer.encode(&processed_messages.text, false) {
|
||||
Ok(enc) => enc,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Tokenization failed: {}", e));
|
||||
return SglErrorCode::TokenizationError;
|
||||
}
|
||||
};
|
||||
|
||||
let token_ids_vec: Vec<i32> = encoding
|
||||
.token_ids()
|
||||
.iter()
|
||||
.map(|&id| id as i32)
|
||||
.collect();
|
||||
|
||||
let prompt_tokens = token_ids_vec.len() as i32;
|
||||
|
||||
// Generate tool constraints if tools are present
|
||||
let tool_constraints_json = if let Some(tools) = chat_request.tools.as_ref() {
|
||||
match generate_tool_constraints(tools, &chat_request.tool_choice, &chat_request.model) {
|
||||
Ok(Some(constraints)) => {
|
||||
match serde_json::to_string(&constraints) {
|
||||
Ok(json_str) => Some(CString::new(json_str).unwrap()),
|
||||
Err(e) => {
|
||||
set_error_message(
|
||||
error_out,
|
||||
&format!("Failed to serialize tool constraints: {}", e),
|
||||
);
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None) => None,
|
||||
Err(e) => {
|
||||
set_error_message(
|
||||
error_out,
|
||||
&format!("Failed to generate tool constraints: {}", e),
|
||||
);
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Allocate memory for outputs
|
||||
let prompt_text_cstr = match CString::new(processed_messages.text) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to create C string: {}", e));
|
||||
return SglErrorCode::MemoryError;
|
||||
}
|
||||
};
|
||||
|
||||
let token_ids_len = token_ids_vec.len();
|
||||
// Convert i32 to u32 for token IDs (as expected by the memory management functions)
|
||||
let token_ids_u32: Vec<u32> = token_ids_vec.iter().map(|&id| id as u32).collect();
|
||||
let token_ids_ptr = if token_ids_u32.is_empty() {
|
||||
ptr::null_mut()
|
||||
} else {
|
||||
let boxed = token_ids_u32.into_boxed_slice();
|
||||
Box::into_raw(boxed) as *mut c_uint
|
||||
};
|
||||
|
||||
// Set output values
|
||||
*prompt_text_out = prompt_text_cstr.into_raw();
|
||||
*token_ids_out = token_ids_ptr;
|
||||
*token_ids_len_out = token_ids_len;
|
||||
*prompt_tokens_out = prompt_tokens;
|
||||
|
||||
if !tool_constraints_json_out.is_null() {
|
||||
if let Some(constraints) = tool_constraints_json {
|
||||
*tool_constraints_json_out = constraints.into_raw();
|
||||
} else {
|
||||
*tool_constraints_json_out = ptr::null_mut();
|
||||
}
|
||||
}
|
||||
|
||||
SglErrorCode::Success
|
||||
}
|
||||
|
||||
/// Preprocess a chat completion request using an existing tokenizer handle
|
||||
///
|
||||
/// This function is similar to sgl_preprocess_chat_request, but accepts a TokenizerHandle
|
||||
/// instead of creating a new tokenizer. This allows reusing a cached tokenizer instance,
|
||||
/// significantly reducing initialization overhead in concurrent scenarios.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `request_json` - OpenAI ChatCompletionRequest as JSON string
|
||||
/// * `tokenizer_handle` - Existing tokenizer handle (must be valid)
|
||||
/// * `prompt_text_out` - Pointer to receive prompt text (C string, must be freed with sgl_free_string)
|
||||
/// * `token_ids_out` - Pointer to receive token IDs array (must be freed with sgl_free_token_ids)
|
||||
/// * `token_ids_len_out` - Pointer to receive token IDs array length
|
||||
/// * `tool_constraints_json_out` - Optional pointer to receive tool constraints JSON (must be freed with sgl_free_string)
|
||||
/// * `prompt_tokens_out` - Pointer to receive prompt token count
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * SglErrorCode::Success on success, error code on failure
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_preprocess_chat_request_with_tokenizer(
|
||||
request_json: *const c_char,
|
||||
tokenizer_handle: *mut TokenizerHandle,
|
||||
prompt_text_out: *mut *mut c_char,
|
||||
token_ids_out: *mut *mut c_uint,
|
||||
token_ids_len_out: *mut usize,
|
||||
tool_constraints_json_out: *mut *mut c_char,
|
||||
prompt_tokens_out: *mut c_int,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> SglErrorCode {
|
||||
if request_json.is_null()
|
||||
|| tokenizer_handle.is_null()
|
||||
|| prompt_text_out.is_null()
|
||||
|| token_ids_out.is_null()
|
||||
|| token_ids_len_out.is_null()
|
||||
|| prompt_tokens_out.is_null()
|
||||
{
|
||||
set_error_message(error_out, "Invalid arguments: null pointer");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
|
||||
// Parse input string
|
||||
let request_str = match CStr::from_ptr(request_json).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in request_json");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse ChatCompletionRequest
|
||||
let chat_request: ChatCompletionRequest = match serde_json::from_str(request_str) {
|
||||
Ok(req) => req,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to parse request JSON: {}", e));
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
};
|
||||
|
||||
// Use existing tokenizer from handle (no need to create new one!)
|
||||
let handle_ref = &*tokenizer_handle;
|
||||
let tokenizer = &handle_ref.tokenizer;
|
||||
|
||||
// Process chat messages (apply chat_template)
|
||||
let processed_messages = match process_chat_messages(&chat_request, tokenizer.as_ref()) {
|
||||
Ok(msgs) => msgs,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to process chat messages: {}", e));
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
};
|
||||
|
||||
// Tokenize the processed text
|
||||
let encoding = match tokenizer.encode(&processed_messages.text, false) {
|
||||
Ok(enc) => enc,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Tokenization failed: {}", e));
|
||||
return SglErrorCode::TokenizationError;
|
||||
}
|
||||
};
|
||||
|
||||
let token_ids_vec: Vec<i32> = encoding
|
||||
.token_ids()
|
||||
.iter()
|
||||
.map(|&id| id as i32)
|
||||
.collect();
|
||||
|
||||
let prompt_tokens = token_ids_vec.len() as i32;
|
||||
|
||||
// Generate tool constraints if tools are present
|
||||
let tool_constraints_json = if let Some(tools) = chat_request.tools.as_ref() {
|
||||
match generate_tool_constraints(tools, &chat_request.tool_choice, &chat_request.model) {
|
||||
Ok(Some(constraints)) => {
|
||||
match serde_json::to_string(&constraints) {
|
||||
Ok(json_str) => Some(CString::new(json_str).unwrap()),
|
||||
Err(e) => {
|
||||
set_error_message(
|
||||
error_out,
|
||||
&format!("Failed to serialize tool constraints: {}", e),
|
||||
);
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None) => None,
|
||||
Err(e) => {
|
||||
set_error_message(
|
||||
error_out,
|
||||
&format!("Failed to generate tool constraints: {}", e),
|
||||
);
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Allocate memory for outputs
|
||||
let prompt_text_cstr = match CString::new(processed_messages.text) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to create C string: {}", e));
|
||||
return SglErrorCode::MemoryError;
|
||||
}
|
||||
};
|
||||
|
||||
let token_ids_len = token_ids_vec.len();
|
||||
// Convert i32 to u32 for token IDs (as expected by the memory management functions)
|
||||
let token_ids_u32: Vec<u32> = token_ids_vec.iter().map(|&id| id as u32).collect();
|
||||
let token_ids_ptr = if token_ids_u32.is_empty() {
|
||||
ptr::null_mut()
|
||||
} else {
|
||||
let boxed = token_ids_u32.into_boxed_slice();
|
||||
Box::into_raw(boxed) as *mut c_uint
|
||||
};
|
||||
|
||||
// Set output values
|
||||
*prompt_text_out = prompt_text_cstr.into_raw();
|
||||
*token_ids_out = token_ids_ptr;
|
||||
*token_ids_len_out = token_ids_len;
|
||||
*prompt_tokens_out = prompt_tokens;
|
||||
|
||||
if !tool_constraints_json_out.is_null() {
|
||||
if let Some(constraints) = tool_constraints_json {
|
||||
*tool_constraints_json_out = constraints.into_raw();
|
||||
} else {
|
||||
*tool_constraints_json_out = ptr::null_mut();
|
||||
}
|
||||
}
|
||||
|
||||
SglErrorCode::Success
|
||||
}
|
||||
|
||||
/// Free a preprocessed request handle (cleanup function)
|
||||
///
|
||||
/// This function frees the memory allocated by sgl_preprocess_chat_request.
|
||||
/// It should be called after the preprocessed data is no longer needed.
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_preprocessed_request_free(
|
||||
prompt_text: *mut c_char,
|
||||
token_ids: *mut c_uint,
|
||||
token_ids_len: usize,
|
||||
tool_constraints_json: *mut c_char,
|
||||
) {
|
||||
if !prompt_text.is_null() {
|
||||
sgl_free_string(prompt_text);
|
||||
}
|
||||
|
||||
if !token_ids.is_null() && token_ids_len > 0 {
|
||||
sgl_free_token_ids(token_ids, token_ids_len);
|
||||
}
|
||||
|
||||
if !tool_constraints_json.is_null() {
|
||||
sgl_free_string(tool_constraints_json);
|
||||
}
|
||||
}
|
||||
288
third_party/sglang/sgl-model-gateway/bindings/golang/src/stream.rs
vendored
Normal file
288
third_party/sglang/sgl-model-gateway/bindings/golang/src/stream.rs
vendored
Normal file
@@ -0,0 +1,288 @@
|
||||
//! Stream handling FFI functions
|
||||
//!
|
||||
//! This module provides FFI (Foreign Function Interface) functions for managing
|
||||
//! streaming responses from the SGLang gRPC API. It handles:
|
||||
//!
|
||||
//! - Creating and managing stream handles
|
||||
//! - Reading chunks from streams and converting them to OpenAI format
|
||||
//! - Managing automatic abort on stream drop (via AbortOnDropStream)
|
||||
//! - Thread-safe access to streams and response converters
|
||||
//!
|
||||
//! # Safety
|
||||
//!
|
||||
//! All FFI functions are marked `unsafe` as per Rust FFI conventions. Callers must:
|
||||
//! - Pass valid pointers
|
||||
//! - Ensure proper pointer lifetime management
|
||||
//! - Call corresponding free functions for cleanup
|
||||
|
||||
use std::ffi::CString;
|
||||
use std::os::raw::{c_char, c_int};
|
||||
use std::ptr;
|
||||
use std::sync::Arc;
|
||||
use tokio::runtime::Runtime;
|
||||
use once_cell::sync::Lazy;
|
||||
use futures_util::StreamExt;
|
||||
|
||||
use smg_grpc_client::{sglang_proto as proto, sglang_scheduler::{SglangSchedulerClient, AbortOnDropStream}};
|
||||
|
||||
use super::error::{SglErrorCode, set_error_message};
|
||||
use super::grpc_converter::{GrpcResponseConverterHandle, convert_proto_chunk_to_openai};
|
||||
|
||||
/// Global tokio runtime for async operations
|
||||
static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
|
||||
Runtime::new().expect("Failed to create tokio runtime for stream FFI")
|
||||
});
|
||||
|
||||
/// Handle for an active streaming request.
|
||||
///
|
||||
/// This struct manages the stream and response converter for a single request.
|
||||
/// It is wrapped in Arc and Mutex for thread-safe concurrent access.
|
||||
///
|
||||
/// # Fields
|
||||
///
|
||||
/// * `stream` - The gRPC stream wrapped in AbortOnDropStream for automatic cleanup
|
||||
/// * `converter` - Response converter that transforms proto messages to OpenAI format
|
||||
/// * `client` - The underlying gRPC client connection
|
||||
/// * `prompt_tokens` - Number of prompt tokens from the original request
|
||||
pub struct SglangStreamHandle {
|
||||
pub(crate) stream: Arc<tokio::sync::Mutex<AbortOnDropStream>>,
|
||||
pub(crate) converter: Arc<tokio::sync::Mutex<GrpcResponseConverterHandle>>,
|
||||
#[allow(dead_code)]
|
||||
pub(crate) client: Arc<SglangSchedulerClient>,
|
||||
#[allow(dead_code)]
|
||||
pub(crate) prompt_tokens: i32, // Number of prompt tokens for this request
|
||||
}
|
||||
|
||||
/// Read next chunk from stream and convert to OpenAI format.
|
||||
///
|
||||
/// This function reads the next chunk from the gRPC stream, converts it from the
|
||||
/// internal protocol format to OpenAI-compatible JSON format, and returns it via
|
||||
/// the output parameters.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `stream_handle` - Mutable pointer to the stream handle
|
||||
/// * `response_json_out` - Pointer to receive OpenAI format JSON string
|
||||
/// - Caller must free this with `sgl_free_string`
|
||||
/// - May be NULL if no data available
|
||||
/// * `is_done_out` - Pointer to receive completion status
|
||||
/// - 0 = stream has more data
|
||||
/// - 1 = stream is complete
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
/// - Only set if function returns an error code
|
||||
/// - Must be freed with `sgl_free_string` if not NULL
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `SglErrorCode::Success` - Successfully read a chunk or reached end of stream
|
||||
/// * Other error codes - See `SglErrorCode` for details
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// - All pointers must be valid and properly aligned
|
||||
/// - `stream_handle` must point to a valid `SglangStreamHandle`
|
||||
/// - Output pointers must be writable
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// - Complete messages are identified by the presence of `proto::GenerateResponse::Complete`
|
||||
/// - When is_done=1, this may be the last readable chunk or the stream may be ending
|
||||
/// - Subsequent calls after is_done=1 will mark the stream as complete internally
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_stream_read_next(
|
||||
stream_handle: *mut SglangStreamHandle,
|
||||
response_json_out: *mut *mut c_char,
|
||||
is_done_out: *mut c_int,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> SglErrorCode {
|
||||
if stream_handle.is_null() || response_json_out.is_null() || is_done_out.is_null() {
|
||||
set_error_message(error_out, "Invalid arguments: null pointer");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
|
||||
let handle_ref = &*stream_handle;
|
||||
let stream = Arc::clone(&handle_ref.stream);
|
||||
let converter = Arc::clone(&handle_ref.converter);
|
||||
|
||||
// Read next chunk from stream
|
||||
let chunk_result = RUNTIME.block_on(async {
|
||||
let mut stream_guard = stream.lock().await;
|
||||
stream_guard.next().await
|
||||
});
|
||||
|
||||
match chunk_result {
|
||||
Some(Ok(proto_response)) => {
|
||||
// Convert proto response to OpenAI format
|
||||
// We need to get the converter lock first
|
||||
let conversion_result = RUNTIME.block_on(async {
|
||||
let mut converter_guard = converter.lock().await;
|
||||
|
||||
// Clone necessary fields for conversion
|
||||
let tokenizer = Arc::clone(&converter_guard.tokenizer);
|
||||
let model = converter_guard.model.clone();
|
||||
let request_id = converter_guard.request_id.clone();
|
||||
let created = converter_guard.created;
|
||||
let system_fingerprint = converter_guard.system_fingerprint.clone();
|
||||
|
||||
// Call the conversion function
|
||||
convert_proto_chunk_to_openai(
|
||||
proto_response.clone(),
|
||||
&mut *converter_guard,
|
||||
&tokenizer,
|
||||
&model,
|
||||
&request_id,
|
||||
created,
|
||||
system_fingerprint.as_deref(),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
match conversion_result {
|
||||
Ok(Some(openai_response)) => {
|
||||
// Serialize to JSON
|
||||
let result_str = match serde_json::to_string(&openai_response) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to serialize response: {}", e));
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
};
|
||||
|
||||
let result_cstr = match CString::new(result_str) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to create result string: {}", e));
|
||||
return SglErrorCode::MemoryError;
|
||||
}
|
||||
};
|
||||
|
||||
// Check if this is a complete response (stream done)
|
||||
let is_complete = matches!(proto_response.response, Some(proto::generate_response::Response::Complete(_)) | Some(proto::generate_response::Response::Error(_)));
|
||||
|
||||
*response_json_out = result_cstr.into_raw();
|
||||
*is_done_out = if is_complete { 1 } else { 0 };
|
||||
|
||||
if is_complete {
|
||||
// Mark stream as completed
|
||||
// Ensure mark_completed() completes and is visible before returning
|
||||
// Use yield_now to ensure Release ordering is fully propagated
|
||||
RUNTIME.block_on(async {
|
||||
let stream_guard = stream.lock().await;
|
||||
stream_guard.mark_completed();
|
||||
// Keep the guard until mark_completed() is fully executed
|
||||
drop(stream_guard);
|
||||
// Yield to ensure Release ordering is propagated before returning
|
||||
// This prevents race condition where Free() is called immediately
|
||||
// and Drop might not see the mark_completed() write
|
||||
tokio::task::yield_now().await;
|
||||
});
|
||||
}
|
||||
|
||||
SglErrorCode::Success
|
||||
}
|
||||
Ok(None) => {
|
||||
// No response to send (e.g., empty chunk)
|
||||
// Don't mark as completed - stream might continue
|
||||
// Just return null and let caller read more
|
||||
*response_json_out = ptr::null_mut();
|
||||
*is_done_out = 0; // Keep stream open, not done yet
|
||||
SglErrorCode::Success
|
||||
}
|
||||
Err(e) => {
|
||||
// Conversion error - don't mark as completed
|
||||
// Let the stream end naturally or return error without stopping stream
|
||||
set_error_message(error_out, &format!("Conversion error: {}", e));
|
||||
*response_json_out = ptr::null_mut();
|
||||
*is_done_out = 0; // Don't mark as done - let caller decide
|
||||
SglErrorCode::ParsingError
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
// Stream error - mark as completed to prevent abort
|
||||
RUNTIME.block_on(async {
|
||||
let stream_guard = stream.lock().await;
|
||||
stream_guard.mark_completed();
|
||||
drop(stream_guard);
|
||||
// Yield to ensure Release ordering is propagated
|
||||
tokio::task::yield_now().await;
|
||||
});
|
||||
|
||||
set_error_message(error_out, &format!("Stream error: {}", e));
|
||||
*is_done_out = 1;
|
||||
SglErrorCode::UnknownError
|
||||
}
|
||||
None => {
|
||||
// Stream ended naturally (no more chunks)
|
||||
// Mark stream as completed before returning to prevent abort
|
||||
RUNTIME.block_on(async {
|
||||
let stream_guard = stream.lock().await;
|
||||
stream_guard.mark_completed();
|
||||
drop(stream_guard);
|
||||
// Yield to ensure Release ordering is propagated
|
||||
tokio::task::yield_now().await;
|
||||
});
|
||||
|
||||
*response_json_out = ptr::null_mut();
|
||||
*is_done_out = 1;
|
||||
SglErrorCode::Success
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Free a stream handle and release all associated resources.
|
||||
///
|
||||
/// This function must be called exactly once for each stream handle returned by
|
||||
/// `sgl_client_chat_completion_stream`. It marks the stream as completed internally
|
||||
/// to prevent abort signals from being sent when resources are cleaned up.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `handle` - Mutable pointer to the stream handle to free
|
||||
/// - If NULL, this function does nothing
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// - Must be called only once per handle
|
||||
/// - Handle must not be used after calling this function
|
||||
/// - After this call, the stream is no longer valid
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// - This function internally calls `mark_completed()` before freeing to ensure
|
||||
/// the stream cleanup doesn't trigger an abort RPC to the server
|
||||
/// - Memory fences are used to ensure visibility across threads
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_stream_free(handle: *mut SglangStreamHandle) {
|
||||
if !handle.is_null() {
|
||||
let handle_ref = Box::from_raw(handle);
|
||||
|
||||
// Mark stream as completed to prevent abort on drop
|
||||
// By this point, the stream should already be completed by ReadNext()
|
||||
// but we call it again to be safe
|
||||
RUNTIME.block_on(async {
|
||||
let stream_guard = handle_ref.stream.lock().await;
|
||||
stream_guard.mark_completed();
|
||||
// Keep guard alive to ensure mark_completed() write completes
|
||||
drop(stream_guard);
|
||||
// Yield to ensure the atomic write is visible
|
||||
tokio::task::yield_now().await;
|
||||
});
|
||||
|
||||
// Use a strong memory fence to ensure mark_completed()'s Release write
|
||||
// is visible before we drop the last Arc reference
|
||||
std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst);
|
||||
|
||||
// Now drop all references - if mark_completed() was called successfully,
|
||||
// the drop won't send an abort
|
||||
drop(handle_ref.stream);
|
||||
|
||||
// Free converter
|
||||
let converter = Arc::try_unwrap(handle_ref.converter)
|
||||
.ok()
|
||||
.map(|m| m.into_inner());
|
||||
if let Some(conv) = converter {
|
||||
super::grpc_converter::sgl_grpc_response_converter_free(Box::into_raw(Box::new(conv)));
|
||||
}
|
||||
}
|
||||
}
|
||||
388
third_party/sglang/sgl-model-gateway/bindings/golang/src/tokenizer.rs
vendored
Normal file
388
third_party/sglang/sgl-model-gateway/bindings/golang/src/tokenizer.rs
vendored
Normal file
@@ -0,0 +1,388 @@
|
||||
//! Tokenizer FFI functions
|
||||
|
||||
use std::ffi::{CStr, CString};
|
||||
use std::os::raw::{c_char, c_int};
|
||||
use std::ptr;
|
||||
use std::sync::Arc;
|
||||
use serde_json::Value;
|
||||
|
||||
use smg::tokenizer::{
|
||||
create_tokenizer_from_file,
|
||||
traits::Tokenizer as TokenizerTrait,
|
||||
chat_template::ChatTemplateParams,
|
||||
huggingface::HuggingFaceTokenizer,
|
||||
};
|
||||
|
||||
use super::error::{SglErrorCode, set_error_message, clear_error_message};
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
type BooleanT = libc::boolean_t;
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
type BooleanT = libc::c_int;
|
||||
|
||||
/// Opaque handle for a tokenizer instance
|
||||
#[repr(C)]
|
||||
pub struct TokenizerHandle {
|
||||
pub(crate) tokenizer: Arc<dyn TokenizerTrait>,
|
||||
}
|
||||
|
||||
/// Create a tokenizer from a file path
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - Path to tokenizer.json file (null-terminated C string)
|
||||
/// * `error_out` - Optional pointer to receive error message (must be freed with sgl_free_string)
|
||||
///
|
||||
/// # Returns
|
||||
/// * Pointer to TokenizerHandle on success, null on failure
|
||||
///
|
||||
/// # Safety
|
||||
/// The returned handle must be freed with `sgl_tokenizer_free`.
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_tokenizer_create_from_file(
|
||||
path: *const c_char,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> *mut TokenizerHandle {
|
||||
if path.is_null() {
|
||||
set_error_message(error_out, "path cannot be null");
|
||||
return ptr::null_mut();
|
||||
}
|
||||
|
||||
let path_str = match CStr::from_ptr(path).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Invalid UTF-8 in path: {}", e));
|
||||
return ptr::null_mut();
|
||||
}
|
||||
};
|
||||
|
||||
match create_tokenizer_from_file(path_str) {
|
||||
Ok(tokenizer) => {
|
||||
clear_error_message(error_out);
|
||||
Box::into_raw(Box::new(TokenizerHandle {
|
||||
tokenizer,
|
||||
}))
|
||||
}
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &e.to_string());
|
||||
ptr::null_mut()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode text to token IDs
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `handle` - Tokenizer handle (must not be null)
|
||||
/// * `text` - Input text (null-terminated C string)
|
||||
/// * `add_special_tokens` - Whether to add special tokens
|
||||
/// * `token_ids_out` - Pointer to receive array of token IDs (must be freed with sgl_free_token_ids)
|
||||
/// * `token_count_out` - Pointer to receive token count
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * SglErrorCode::Success on success, error code on failure
|
||||
///
|
||||
/// # Safety
|
||||
/// The token_ids_out array must be freed with sgl_free_token_ids() after use.
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_tokenizer_encode(
|
||||
handle: *mut TokenizerHandle,
|
||||
text: *const c_char,
|
||||
add_special_tokens: BooleanT,
|
||||
token_ids_out: *mut *mut u32,
|
||||
token_count_out: *mut usize,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> SglErrorCode {
|
||||
if handle.is_null() || text.is_null() || token_ids_out.is_null() || token_count_out.is_null() {
|
||||
set_error_message(error_out, "Invalid arguments: null pointer");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
|
||||
let text_str = match CStr::from_ptr(text).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in text");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
let add_special_tokens_bool = add_special_tokens != 0;
|
||||
|
||||
let tokenizer = &(*handle).tokenizer;
|
||||
match tokenizer.encode(text_str, add_special_tokens_bool) {
|
||||
Ok(encoding) => {
|
||||
let token_ids = encoding.token_ids();
|
||||
let count = token_ids.len();
|
||||
|
||||
// Allocate memory for token IDs using Vec, then leak to give ownership to C
|
||||
let vec = token_ids.to_vec();
|
||||
let ptr = vec.as_ptr() as *mut u32;
|
||||
let _ = std::mem::ManuallyDrop::new(vec);
|
||||
|
||||
*token_ids_out = ptr;
|
||||
*token_count_out = count;
|
||||
clear_error_message(error_out);
|
||||
SglErrorCode::Success
|
||||
}
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &e.to_string());
|
||||
SglErrorCode::TokenizationError
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply chat template to messages with tools support
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `handle` - Tokenizer handle
|
||||
/// * `messages_json` - JSON string of messages array
|
||||
/// * `tools_json` - Optional JSON string of tools array (null or empty string for no tools)
|
||||
/// * `result_out` - Pointer to receive result string (must be freed with sgl_free_string)
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * SglErrorCode::Success on success, error code on failure
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_tokenizer_apply_chat_template_with_tools(
|
||||
handle: *mut TokenizerHandle,
|
||||
messages_json: *const c_char,
|
||||
tools_json: *const c_char,
|
||||
result_out: *mut *mut c_char,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> SglErrorCode {
|
||||
if handle.is_null() || messages_json.is_null() || result_out.is_null() {
|
||||
set_error_message(error_out, "Invalid arguments: null pointer");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
|
||||
let messages_str = match CStr::from_ptr(messages_json).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in messages_json");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse JSON messages
|
||||
let messages: Vec<Value> = match serde_json::from_str(messages_str) {
|
||||
Ok(msgs) => msgs,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to parse messages JSON: {}", e));
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse tools JSON if provided
|
||||
let tools: Option<Vec<Value>> = if tools_json.is_null() {
|
||||
None
|
||||
} else {
|
||||
let tools_str = match CStr::from_ptr(tools_json).to_str() {
|
||||
Ok(s) => {
|
||||
if s.is_empty() {
|
||||
None
|
||||
} else {
|
||||
match serde_json::from_str::<Vec<Value>>(s) {
|
||||
Ok(t) => Some(t),
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to parse tools JSON: {}", e));
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in tools_json");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
tools_str
|
||||
};
|
||||
|
||||
// Get the tokenizer from handle
|
||||
let handle_ref = &*handle;
|
||||
let tokenizer = &handle_ref.tokenizer;
|
||||
|
||||
// Try to downcast to HuggingFaceTokenizer
|
||||
if let Some(hf_tokenizer) = tokenizer.as_any().downcast_ref::<HuggingFaceTokenizer>() {
|
||||
// Apply chat template with tools
|
||||
let empty_docs: [Value; 0] = [];
|
||||
let tools_slice = tools.as_ref().map(|t| t.as_slice());
|
||||
let params = ChatTemplateParams {
|
||||
add_generation_prompt: true,
|
||||
tools: tools_slice,
|
||||
documents: Some(&empty_docs),
|
||||
template_kwargs: None,
|
||||
};
|
||||
|
||||
match hf_tokenizer.apply_chat_template(&messages, params) {
|
||||
Ok(result) => {
|
||||
let result_cstr = match CString::new(result) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to create result string: {}", e));
|
||||
return SglErrorCode::MemoryError;
|
||||
}
|
||||
};
|
||||
*result_out = result_cstr.into_raw();
|
||||
clear_error_message(error_out);
|
||||
SglErrorCode::Success
|
||||
}
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to apply chat template: {}", e));
|
||||
SglErrorCode::TokenizationError
|
||||
}
|
||||
}
|
||||
} else {
|
||||
set_error_message(error_out, "Chat template is only supported for HuggingFace tokenizers");
|
||||
SglErrorCode::TokenizationError
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply chat template to messages
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `handle` - Tokenizer handle
|
||||
/// * `messages_json` - JSON string of messages array
|
||||
/// * `result_out` - Pointer to receive result string (must be freed with sgl_free_string)
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * SglErrorCode::Success on success, error code on failure
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_tokenizer_apply_chat_template(
|
||||
handle: *mut TokenizerHandle,
|
||||
messages_json: *const c_char,
|
||||
result_out: *mut *mut c_char,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> SglErrorCode {
|
||||
if handle.is_null() || messages_json.is_null() || result_out.is_null() {
|
||||
set_error_message(error_out, "Invalid arguments: null pointer");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
|
||||
let messages_str = match CStr::from_ptr(messages_json).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in messages_json");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse JSON messages
|
||||
let messages: Vec<Value> = match serde_json::from_str(messages_str) {
|
||||
Ok(msgs) => msgs,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to parse messages JSON: {}", e));
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
// Get the tokenizer from handle
|
||||
let handle_ref = &*handle;
|
||||
let tokenizer = &handle_ref.tokenizer;
|
||||
|
||||
// Try to downcast to HuggingFaceTokenizer
|
||||
if let Some(hf_tokenizer) = tokenizer.as_any().downcast_ref::<HuggingFaceTokenizer>() {
|
||||
// Apply chat template with default parameters
|
||||
// Use empty arrays instead of None to avoid template errors
|
||||
// Set add_generation_prompt to true so the model knows to start generating
|
||||
let empty_tools: [Value; 0] = [];
|
||||
let empty_docs: [Value; 0] = [];
|
||||
let params = ChatTemplateParams {
|
||||
add_generation_prompt: true, // Important: tells the model to start generating
|
||||
tools: Some(&empty_tools),
|
||||
documents: Some(&empty_docs),
|
||||
template_kwargs: None,
|
||||
};
|
||||
|
||||
match hf_tokenizer.apply_chat_template(&messages, params) {
|
||||
Ok(result) => {
|
||||
let result_cstr = match CString::new(result) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to create result string: {}", e));
|
||||
return SglErrorCode::MemoryError;
|
||||
}
|
||||
};
|
||||
*result_out = result_cstr.into_raw();
|
||||
clear_error_message(error_out);
|
||||
SglErrorCode::Success
|
||||
}
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to apply chat template: {}", e));
|
||||
SglErrorCode::TokenizationError
|
||||
}
|
||||
}
|
||||
} else {
|
||||
set_error_message(error_out, "Chat template is only supported for HuggingFace tokenizers");
|
||||
SglErrorCode::TokenizationError
|
||||
}
|
||||
}
|
||||
|
||||
/// Decode token IDs to text
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `handle` - Tokenizer handle
|
||||
/// * `token_ids` - Array of token IDs
|
||||
/// * `token_count` - Number of tokens
|
||||
/// * `skip_special_tokens` - Whether to skip special tokens
|
||||
/// * `result_out` - Pointer to receive result string (must be freed with sgl_free_string)
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * SglErrorCode::Success on success, error code on failure
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_tokenizer_decode(
|
||||
handle: *mut TokenizerHandle,
|
||||
token_ids: *const u32,
|
||||
token_count: usize,
|
||||
skip_special_tokens: c_int,
|
||||
result_out: *mut *mut c_char,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> SglErrorCode {
|
||||
if handle.is_null() || token_ids.is_null() || result_out.is_null() {
|
||||
set_error_message(error_out, "Invalid arguments: null pointer");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
|
||||
if token_count == 0 {
|
||||
let empty = CString::new("").unwrap();
|
||||
*result_out = empty.into_raw();
|
||||
clear_error_message(error_out);
|
||||
return SglErrorCode::Success;
|
||||
}
|
||||
|
||||
// Convert C array to Rust slice
|
||||
let token_slice = std::slice::from_raw_parts(token_ids, token_count);
|
||||
|
||||
let tokenizer = &(*handle).tokenizer;
|
||||
match tokenizer.decode(token_slice, skip_special_tokens != 0) {
|
||||
Ok(text) => {
|
||||
let result_cstr = match CString::new(text) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to create result string: {}", e));
|
||||
return SglErrorCode::MemoryError;
|
||||
}
|
||||
};
|
||||
*result_out = result_cstr.into_raw();
|
||||
clear_error_message(error_out);
|
||||
SglErrorCode::Success
|
||||
}
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &e.to_string());
|
||||
SglErrorCode::TokenizationError
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Free a tokenizer handle
|
||||
///
|
||||
/// # Safety
|
||||
/// This function must only be called once per handle, and the handle must not be used after calling.
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_tokenizer_free(handle: *mut TokenizerHandle) {
|
||||
if !handle.is_null() {
|
||||
let _ = Box::from_raw(handle);
|
||||
}
|
||||
}
|
||||
329
third_party/sglang/sgl-model-gateway/bindings/golang/src/tool_parser.rs
vendored
Normal file
329
third_party/sglang/sgl-model-gateway/bindings/golang/src/tool_parser.rs
vendored
Normal file
@@ -0,0 +1,329 @@
|
||||
//! Tool parser FFI functions
|
||||
|
||||
use std::ffi::{CStr, CString};
|
||||
use std::os::raw::{c_char};
|
||||
use std::ptr;
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashMap;
|
||||
use serde_json::{json, Value};
|
||||
use tokio::runtime::Runtime;
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
use smg::tool_parser::{ParserFactory, ToolParser};
|
||||
use smg::protocols::common::Tool;
|
||||
|
||||
use super::error::{SglErrorCode, set_error_message, clear_error_message};
|
||||
use super::utils::generate_tool_call_id;
|
||||
|
||||
/// Global parser factory (initialized once)
|
||||
static PARSER_FACTORY: Lazy<ParserFactory> = Lazy::new(|| ParserFactory::new());
|
||||
|
||||
/// Global tokio runtime for async operations
|
||||
static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
|
||||
Runtime::new().expect("Failed to create tokio runtime for tool parser FFI")
|
||||
});
|
||||
|
||||
/// Opaque handle for a tool parser instance
|
||||
/// Note: For streaming, we need mutable access, so we use Arc<Mutex<>> internally
|
||||
/// Note: This is an opaque handle, C code doesn't access fields directly
|
||||
pub struct ToolParserHandle {
|
||||
parser: Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>,
|
||||
model: String, // Store model name for ID generation
|
||||
history_tool_calls_count: usize, // Track tool call count for ID generation
|
||||
tool_index_to_id: HashMap<usize, String>, // Map tool_index to ID for incremental updates
|
||||
}
|
||||
|
||||
/// Create a tool parser
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `parser_type` - Parser type name (e.g., "json", "llama", "mistral") or model name (e.g., "gpt-4")
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * Pointer to ToolParserHandle on success, null on failure
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_tool_parser_create(
|
||||
parser_type: *const c_char,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> *mut ToolParserHandle {
|
||||
if parser_type.is_null() {
|
||||
set_error_message(error_out, "parser_type cannot be null");
|
||||
return ptr::null_mut();
|
||||
}
|
||||
|
||||
let type_str = match CStr::from_ptr(parser_type).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in parser_type");
|
||||
return ptr::null_mut();
|
||||
}
|
||||
};
|
||||
|
||||
// Create parser using factory
|
||||
// The factory will determine the parser type based on model name or use the provided type
|
||||
let parser = if let Some(parser_box) = PARSER_FACTORY.registry().create_for_model(type_str) {
|
||||
parser_box
|
||||
} else if let Some(parser_box) = PARSER_FACTORY.registry().create_parser(type_str) {
|
||||
parser_box
|
||||
} else {
|
||||
set_error_message(error_out, &format!("Unknown parser type: {}", type_str));
|
||||
return ptr::null_mut();
|
||||
};
|
||||
|
||||
Box::into_raw(Box::new(ToolParserHandle {
|
||||
parser: Arc::new(tokio::sync::Mutex::new(parser)),
|
||||
model: type_str.to_string(),
|
||||
history_tool_calls_count: 0,
|
||||
tool_index_to_id: HashMap::new(),
|
||||
}))
|
||||
}
|
||||
|
||||
/// Parse complete tool calls from text
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `handle` - Tool parser handle
|
||||
/// * `text` - Input text to parse
|
||||
/// * `result_json_out` - Pointer to receive JSON result (must be freed with sgl_free_string)
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * SglErrorCode::Success on success, error code on failure
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_tool_parser_parse_complete(
|
||||
handle: *mut ToolParserHandle,
|
||||
text: *const c_char,
|
||||
result_json_out: *mut *mut c_char,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> SglErrorCode {
|
||||
if handle.is_null() || text.is_null() || result_json_out.is_null() {
|
||||
set_error_message(error_out, "Invalid arguments: null pointer");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
|
||||
let text_str = match CStr::from_ptr(text).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in text");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
let handle_ref = &*handle;
|
||||
let parser = Arc::clone(&handle_ref.parser);
|
||||
let model = handle_ref.model.clone();
|
||||
let history_count = handle_ref.history_tool_calls_count;
|
||||
|
||||
// Use tokio runtime to run async code
|
||||
let result = RUNTIME.block_on(async {
|
||||
let parser_guard = parser.lock().await;
|
||||
parser_guard.parse_complete(text_str).await
|
||||
});
|
||||
|
||||
match result {
|
||||
Ok((normal_text, tool_calls)) => {
|
||||
// Convert Rust ToolCall to OpenAI format
|
||||
let openai_tool_calls: Vec<Value> = tool_calls
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(index, tc)| {
|
||||
// Generate ID for this tool call
|
||||
let id = generate_tool_call_id(&model, &tc.function.name, index, history_count);
|
||||
json!({
|
||||
"id": id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Build result JSON
|
||||
let result_json = json!({
|
||||
"normal_text": normal_text,
|
||||
"tool_calls": openai_tool_calls
|
||||
});
|
||||
|
||||
let result_str = match serde_json::to_string(&result_json) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to serialize JSON: {}", e));
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
};
|
||||
|
||||
let result_cstr = match CString::new(result_str) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to create result string: {}", e));
|
||||
return SglErrorCode::MemoryError;
|
||||
}
|
||||
};
|
||||
|
||||
*result_json_out = result_cstr.into_raw();
|
||||
clear_error_message(error_out);
|
||||
SglErrorCode::Success
|
||||
}
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Parse error: {}", e));
|
||||
SglErrorCode::ParsingError
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse tool calls incrementally from streaming chunks
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `handle` - Tool parser handle
|
||||
/// * `chunk` - New text chunk from stream
|
||||
/// * `tools_json` - JSON array of available tools (for validation, can be null/empty)
|
||||
/// * `result_json_out` - Pointer to receive JSON result (must be freed with sgl_free_string)
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * SglErrorCode::Success on success, error code on failure
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_tool_parser_parse_incremental(
|
||||
handle: *mut ToolParserHandle,
|
||||
chunk: *const c_char,
|
||||
tools_json: *const c_char,
|
||||
result_json_out: *mut *mut c_char,
|
||||
error_out: *mut *mut c_char,
|
||||
) -> SglErrorCode {
|
||||
if handle.is_null() || chunk.is_null() || result_json_out.is_null() {
|
||||
set_error_message(error_out, "Invalid arguments: null pointer");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
|
||||
let chunk_str = match CStr::from_ptr(chunk).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in chunk");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse tools JSON if provided
|
||||
let tools: Vec<Tool> = if !tools_json.is_null() {
|
||||
let tools_str = match CStr::from_ptr(tools_json).to_str() {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
set_error_message(error_out, "Invalid UTF-8 in tools_json");
|
||||
return SglErrorCode::InvalidArgument;
|
||||
}
|
||||
};
|
||||
match serde_json::from_str::<Vec<Tool>>(tools_str) {
|
||||
Ok(t) => t,
|
||||
Err(_) => vec![], // If parsing fails, use empty tools
|
||||
}
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let handle_ref = &*handle;
|
||||
let parser = Arc::clone(&handle_ref.parser);
|
||||
let model = handle_ref.model.clone();
|
||||
let history_count = handle_ref.history_tool_calls_count;
|
||||
|
||||
// Use tokio runtime to run async code
|
||||
let result = RUNTIME.block_on(async {
|
||||
let mut parser_guard = parser.lock().await;
|
||||
parser_guard.parse_incremental(chunk_str, &tools).await
|
||||
});
|
||||
|
||||
match result {
|
||||
Ok(streaming_result) => {
|
||||
// Convert StreamingParseResult to OpenAI format
|
||||
let handle_mut = &mut *handle;
|
||||
let openai_tool_calls: Vec<Value> = streaming_result
|
||||
.calls
|
||||
.into_iter()
|
||||
.map(|item| {
|
||||
// For incremental parsing, we may not have complete tool calls yet
|
||||
// Generate or reuse ID based on tool_index
|
||||
let id = if let Some(ref name) = item.name {
|
||||
// New tool call with name - generate ID and store it
|
||||
let id = generate_tool_call_id(&model, name, item.tool_index, history_count);
|
||||
handle_mut.tool_index_to_id.insert(item.tool_index, id.clone());
|
||||
id
|
||||
} else {
|
||||
// Parameter update - reuse existing ID for this tool_index
|
||||
handle_mut.tool_index_to_id
|
||||
.get(&item.tool_index)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| format!("call_{}", item.tool_index))
|
||||
};
|
||||
|
||||
json!({
|
||||
"id": id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": item.name.unwrap_or_default(),
|
||||
"arguments": item.parameters
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Build result JSON
|
||||
let result_json = json!({
|
||||
"normal_text": streaming_result.normal_text,
|
||||
"tool_calls": openai_tool_calls
|
||||
});
|
||||
|
||||
let result_str = match serde_json::to_string(&result_json) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to serialize JSON: {}", e));
|
||||
return SglErrorCode::ParsingError;
|
||||
}
|
||||
};
|
||||
|
||||
let result_cstr = match CString::new(result_str) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Failed to create result string: {}", e));
|
||||
return SglErrorCode::MemoryError;
|
||||
}
|
||||
};
|
||||
|
||||
*result_json_out = result_cstr.into_raw();
|
||||
clear_error_message(error_out);
|
||||
SglErrorCode::Success
|
||||
}
|
||||
Err(e) => {
|
||||
set_error_message(error_out, &format!("Parse incremental error: {}", e));
|
||||
SglErrorCode::ParsingError
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the parser state for reuse
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_tool_parser_reset(handle: *mut ToolParserHandle) {
|
||||
if handle.is_null() {
|
||||
return;
|
||||
}
|
||||
|
||||
let handle_ref = &mut *handle;
|
||||
let parser = Arc::clone(&handle_ref.parser);
|
||||
|
||||
// Reset parser state
|
||||
RUNTIME.block_on(async {
|
||||
let mut parser_guard = parser.lock().await;
|
||||
parser_guard.reset();
|
||||
});
|
||||
|
||||
// Reset history count and tool index mapping
|
||||
handle_ref.history_tool_calls_count = 0;
|
||||
handle_ref.tool_index_to_id.clear();
|
||||
}
|
||||
|
||||
/// Free a tool parser handle
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_tool_parser_free(handle: *mut ToolParserHandle) {
|
||||
if !handle.is_null() {
|
||||
let _ = Box::from_raw(handle);
|
||||
}
|
||||
}
|
||||
44
third_party/sglang/sgl-model-gateway/bindings/golang/src/utils.rs
vendored
Normal file
44
third_party/sglang/sgl-model-gateway/bindings/golang/src/utils.rs
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
//! Utility functions for FFI
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Helper function to generate tool call ID (matches router implementation)
|
||||
pub fn generate_tool_call_id(
|
||||
model: &str,
|
||||
function_name: &str,
|
||||
index: usize,
|
||||
history_tool_calls_count: usize,
|
||||
) -> String {
|
||||
if model.to_lowercase().contains("kimi") {
|
||||
// KimiK2 format: functions.{name}:{global_index}
|
||||
format!("functions.{}:{}", function_name, history_tool_calls_count + index)
|
||||
} else {
|
||||
// Standard OpenAI format: call_{24-char-uuid}
|
||||
format!("call_{}", &Uuid::new_v4().simple().to_string()[..24])
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate tool constraints (placeholder implementation)
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tools_json` - JSON array of tools
|
||||
/// * `tool_choice_json` - JSON object representing tool_choice
|
||||
/// * `constraint_type_out` - Pointer to receive constraint type (e.g., "json_schema")
|
||||
/// * `constraint_schema_out` - Pointer to receive constraint schema JSON
|
||||
/// * `error_out` - Optional pointer to receive error message
|
||||
///
|
||||
/// # Returns
|
||||
/// * SglErrorCode::Success on success, error code on failure
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn sgl_generate_tool_constraints(
|
||||
_tools_json: *const std::os::raw::c_char,
|
||||
_tool_choice_json: *const std::os::raw::c_char,
|
||||
_constraint_type_out: *mut *mut std::os::raw::c_char,
|
||||
_constraint_schema_out: *mut *mut std::os::raw::c_char,
|
||||
error_out: *mut *mut std::os::raw::c_char,
|
||||
) -> super::error::SglErrorCode {
|
||||
// Implementation would parse JSON and call generate_tool_constraints
|
||||
// This is a placeholder
|
||||
super::error::set_error_message(error_out, "Tool constraint generation not yet implemented in FFI");
|
||||
super::error::SglErrorCode::UnknownError
|
||||
}
|
||||
9
third_party/sglang/sgl-model-gateway/bindings/python/.coveragerc
vendored
Normal file
9
third_party/sglang/sgl-model-gateway/bindings/python/.coveragerc
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
[run]
|
||||
source = sglang_router
|
||||
omit =
|
||||
*/mini_lb.py
|
||||
*/cli.py
|
||||
*/__main__.py
|
||||
|
||||
[report]
|
||||
fail_under = 80
|
||||
28
third_party/sglang/sgl-model-gateway/bindings/python/Cargo.toml
vendored
Normal file
28
third_party/sglang/sgl-model-gateway/bindings/python/Cargo.toml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
[package]
|
||||
name = "sgl-model-gateway-python"
|
||||
version = "0.3.2"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "sglang_router_rs"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
pyo3 = { version = "0.27.1", features = ["extension-module", "abi3-py38"] }
|
||||
tokio = { version = "1.42.0", features = ["full"] }
|
||||
once_cell = "1.19"
|
||||
|
||||
[dependencies.sgl-model-gateway]
|
||||
path = "../.."
|
||||
default-features = true
|
||||
|
||||
[features]
|
||||
default = ["pyo3/extension-module"]
|
||||
vendored-openssl = ["sgl-model-gateway/vendored-openssl"]
|
||||
|
||||
[profile.ci]
|
||||
inherits = "release"
|
||||
opt-level = 2 # Lighter optimization (still fast runtime, much faster compile)
|
||||
lto = "thin" # Thin LTO - good balance
|
||||
codegen-units = 16 # More parallelization for faster builds
|
||||
strip = true
|
||||
9
third_party/sglang/sgl-model-gateway/bindings/python/MANIFEST.in
vendored
Normal file
9
third_party/sglang/sgl-model-gateway/bindings/python/MANIFEST.in
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
# Must include:
|
||||
include Cargo.toml # Python bindings Cargo configuration
|
||||
include ../../Cargo.toml # Main Rust project configuration
|
||||
include ../../build.rs # Build script for protobuf generation
|
||||
include ../../LICENSE
|
||||
recursive-include src *.rs # Python bindings wrapper
|
||||
recursive-include ../../src *.rs # Main Rust source files
|
||||
recursive-include ../../src/proto *.proto # Protobuf definitions
|
||||
recursive-include sglang_router *.py # Python source files
|
||||
77
third_party/sglang/sgl-model-gateway/bindings/python/README.md
vendored
Normal file
77
third_party/sglang/sgl-model-gateway/bindings/python/README.md
vendored
Normal file
@@ -0,0 +1,77 @@
|
||||
# SGLang Model Gateway Python Bindings
|
||||
|
||||
This directory contains the Python bindings for the SGLang Router, built using [maturin](https://github.com/PyO3/maturin) and [PyO3](https://github.com/PyO3/pyo3).
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
bindings/python/
|
||||
├── src/ # Source code (src layout)
|
||||
│ ├── lib.rs # Rust/PyO3 bindings implementation
|
||||
│ └── sglang_router/ # Python source code
|
||||
│ ├── __init__.py
|
||||
│ ├── version.py
|
||||
│ ├── launch_server.py
|
||||
│ ├── launch_router.py
|
||||
│ ├── router.py
|
||||
│ ├── router_args.py
|
||||
│ └── mini_lb.py
|
||||
├── tests/ # Python unit tests
|
||||
│ ├── conftest.py
|
||||
│ ├── test_validation.py
|
||||
│ ├── test_arg_parser.py
|
||||
│ ├── test_router_config.py
|
||||
│ └── test_startup_sequence.py
|
||||
├── Cargo.toml # Rust package configuration for bindings
|
||||
├── pyproject.toml # Python package configuration
|
||||
├── setup.py # Setup configuration
|
||||
├── MANIFEST.in # Package manifest
|
||||
├── .coveragerc # Test coverage configuration
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Building
|
||||
|
||||
### Development Build
|
||||
|
||||
```bash
|
||||
# Install maturin
|
||||
pip install maturin
|
||||
|
||||
# Build and install in development mode
|
||||
cd sgl-model-gateway/bindings/python
|
||||
maturin develop --features vendored-openssl
|
||||
```
|
||||
|
||||
### Production Build
|
||||
|
||||
```bash
|
||||
# Build wheel
|
||||
cd sgl-model-gateway/bindings/python
|
||||
maturin build --release --out dist --features vendored-openssl
|
||||
|
||||
# Install the built wheel
|
||||
pip install dist/sglang_router-*.whl
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
# Run Python unit tests (after maturin develop)
|
||||
cd sgl-model-gateway/bindings/python
|
||||
pytest tests/
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
- **pyproject.toml**: Defines package metadata, dependencies, and build configuration
|
||||
- **python-source**: Set to `"src"` indicating Python source uses the src layout
|
||||
- **module-name**: `sglang_router.sglang_router_rs` - the Rust extension module name
|
||||
|
||||
## Notes
|
||||
|
||||
- The Rust bindings source code is located in `src/lib.rs`
|
||||
- The bindings have their own `Cargo.toml` in this directory
|
||||
- The main sglang-router library is located in `../../` and is used as a dependency
|
||||
- The package includes both Python code and Rust extensions built with PyO3
|
||||
- PyO3 types are prefixed with `Py` in Rust but exposed to Python without the prefix using the `name` attribute
|
||||
64
third_party/sglang/sgl-model-gateway/bindings/python/pyproject.toml
vendored
Normal file
64
third_party/sglang/sgl-model-gateway/bindings/python/pyproject.toml
vendored
Normal file
@@ -0,0 +1,64 @@
|
||||
[build-system]
|
||||
requires = ["maturin>=1.0,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[project]
|
||||
name = "sglang-router"
|
||||
version = "0.3.2"
|
||||
description = "High-performance Rust-based load balancer for SGLang with multiple routing algorithms and prefill-decode disaggregation support"
|
||||
authors = [
|
||||
{name = "Simo Lin", email = "linsimo.mark@gmail.com"},
|
||||
{name = "Chang Su", email = "mckvtl@gmail.com"},
|
||||
{name = "Keyang Ru", email = "rukeyang@gmail.com"},
|
||||
{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}
|
||||
]
|
||||
requires-python = ">=3.8"
|
||||
readme = "../../README.md"
|
||||
license = { text = "Apache-2.0" }
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: Implementation :: CPython",
|
||||
"Programming Language :: Rust",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Programming Language :: Python :: 3.14",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"setproctitle",
|
||||
"aiohttp",
|
||||
"orjson",
|
||||
"uvicorn",
|
||||
"fastapi",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"requests>=2.25.0",
|
||||
"pytest>=7.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
smg = "sglang_router.cli:main"
|
||||
amg = "sglang_router.cli:main"
|
||||
sglang-router = "sglang_router.cli:main"
|
||||
|
||||
|
||||
[tool.maturin]
|
||||
python-source = "src"
|
||||
module-name = "sglang_router.sglang_router_rs"
|
||||
# Exclude bindings/python/README.md to use root README only
|
||||
exclude = ["README.md"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
python_classes = ["Test*"]
|
||||
python_functions = ["test_*"]
|
||||
markers = [
|
||||
"unit: mark test as a unit test (no GPU required)",
|
||||
]
|
||||
28
third_party/sglang/sgl-model-gateway/bindings/python/setup.py
vendored
Normal file
28
third_party/sglang/sgl-model-gateway/bindings/python/setup.py
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
with_rust = os.environ.get("SGLANG_ROUTER_BUILD_WITH_RUST", None)
|
||||
with_rust = with_rust is None or (not with_rust.lower() in ["0", "false", "no"])
|
||||
|
||||
rust_extensions = []
|
||||
if with_rust:
|
||||
from setuptools_rust import Binding, RustExtension
|
||||
|
||||
rust_extensions.append(
|
||||
RustExtension(
|
||||
target="sglang_router_rs",
|
||||
path="Cargo.toml",
|
||||
binding=Binding.PyO3,
|
||||
)
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
"Building 'sglang-router' without Rust support. Performance may be degraded."
|
||||
)
|
||||
|
||||
setup(
|
||||
rust_extensions=rust_extensions,
|
||||
zip_safe=False,
|
||||
)
|
||||
1029
third_party/sglang/sgl-model-gateway/bindings/python/src/lib.rs
vendored
Normal file
1029
third_party/sglang/sgl-model-gateway/bindings/python/src/lib.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
3
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/__init__.py
vendored
Normal file
3
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/__init__.py
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
from sglang_router.version import __version__
|
||||
|
||||
__all__ = ["__version__"]
|
||||
8
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/__main__.py
vendored
Normal file
8
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/__main__.py
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Allow running the CLI via: python -m sglang_router
|
||||
"""
|
||||
|
||||
from sglang_router.cli import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
107
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/cli.py
vendored
Executable file
107
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/cli.py
vendored
Executable file
@@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
SGLang Model Gateway CLI
|
||||
|
||||
Provides convenient command-line interface for launching the router and server.
|
||||
|
||||
Usage:
|
||||
smg launch [args] # Launch router only
|
||||
smg server [args] # Launch router + server
|
||||
smg --help # Show help
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang_router.sglang_router_rs import (
|
||||
get_verbose_version_string,
|
||||
get_version_string,
|
||||
)
|
||||
|
||||
|
||||
def create_parser() -> argparse.ArgumentParser:
|
||||
"""Create the main CLI parser with subcommands."""
|
||||
prog_name = os.path.basename(sys.argv[0]) if sys.argv else "smg"
|
||||
parser = argparse.ArgumentParser(
|
||||
prog=prog_name,
|
||||
description="SGLang Model Gateway - High-performance inference router",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||
|
||||
# Launch router subcommand
|
||||
launch_parser = subparsers.add_parser(
|
||||
"launch",
|
||||
help="Launch router only (requires existing worker URLs)",
|
||||
description="Launch the SGLang router with existing worker instances",
|
||||
add_help=False, # Let router handle --help
|
||||
)
|
||||
|
||||
# Launch server + router subcommand
|
||||
server_parser = subparsers.add_parser(
|
||||
"server",
|
||||
help="Launch router and server processes together",
|
||||
description="Launch both SGLang router and server processes",
|
||||
add_help=False, # Let server handle --help
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main(argv: Optional[List[str]] = None) -> None:
|
||||
"""Main CLI entry point."""
|
||||
if argv is None:
|
||||
argv = sys.argv[1:]
|
||||
|
||||
# Handle version flags before parsing
|
||||
if argv and argv[0] in ["--version", "-V", "--version-verbose"]:
|
||||
if argv[0] == "--version-verbose":
|
||||
print(get_verbose_version_string())
|
||||
else:
|
||||
print(get_version_string())
|
||||
sys.exit(0)
|
||||
|
||||
# Handle empty command - show help
|
||||
if not argv or argv[0] not in ["launch", "server", "-h", "--help"]:
|
||||
parser = create_parser()
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
parser = create_parser()
|
||||
args, unknown = parser.parse_known_args(argv)
|
||||
|
||||
if args.command == "launch":
|
||||
# Import and call launch_router functions directly
|
||||
from sglang_router.launch_router import launch_router, parse_router_args
|
||||
|
||||
# All router args are in unknown
|
||||
router_args = parse_router_args(unknown)
|
||||
launch_router(router_args)
|
||||
|
||||
elif args.command == "server":
|
||||
# Import and call launch_server main with proper argv
|
||||
# Note: launch_server.main() uses argparse internally which reads sys.argv
|
||||
# We need to temporarily set sys.argv for compatibility
|
||||
import sglang_router.launch_server as launch_server_module
|
||||
|
||||
# Preserve original sys.argv
|
||||
original_argv = sys.argv
|
||||
try:
|
||||
# All server args are in unknown
|
||||
prog_name = os.path.basename(sys.argv[0]) if sys.argv else "smg"
|
||||
sys.argv = [f"{prog_name} server"] + unknown
|
||||
launch_server_module.main()
|
||||
finally:
|
||||
# Restore original sys.argv
|
||||
sys.argv = original_argv
|
||||
|
||||
else:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
109
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/launch_router.py
vendored
Normal file
109
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/launch_router.py
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
import setproctitle
|
||||
from sglang_router.mini_lb import MiniLoadBalancer
|
||||
from sglang_router.router_args import RouterArgs
|
||||
|
||||
logger = logging.getLogger("router")
|
||||
|
||||
try:
|
||||
from sglang_router.router import Router
|
||||
except ImportError:
|
||||
Router = None
|
||||
logger.warning(
|
||||
"Rust Router is not installed, only python MiniLB (debugging only) is available"
|
||||
)
|
||||
|
||||
|
||||
def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
"""
|
||||
Launch the SGLang router with the configuration from parsed arguments.
|
||||
|
||||
Args:
|
||||
args: Namespace object containing router configuration
|
||||
Can be either raw argparse.Namespace or converted RouterArgs
|
||||
|
||||
Returns:
|
||||
Router instance if successful, None if failed
|
||||
"""
|
||||
setproctitle.setproctitle("sglang::router")
|
||||
try:
|
||||
# Convert to RouterArgs if needed
|
||||
if not isinstance(args, RouterArgs):
|
||||
router_args = RouterArgs.from_cli_args(args)
|
||||
else:
|
||||
router_args = args
|
||||
|
||||
if router_args.mini_lb:
|
||||
mini_lb = MiniLoadBalancer(router_args)
|
||||
mini_lb.start()
|
||||
else:
|
||||
if Router is None:
|
||||
raise RuntimeError("Rust Router is not installed")
|
||||
router_args._validate_router_args()
|
||||
router = Router.from_args(router_args)
|
||||
router.start()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting router: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
class CustomHelpFormatter(
|
||||
argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter
|
||||
):
|
||||
"""Custom formatter that preserves both description formatting and shows defaults"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def parse_router_args(args: List[str]) -> RouterArgs:
|
||||
"""Parse command line arguments and return RouterArgs instance."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""SGLang Router - High-performance request distribution across worker nodes
|
||||
|
||||
Usage:
|
||||
This launcher enables starting a router with individual worker instances. It is useful for
|
||||
multi-node setups or when you want to start workers and router separately.
|
||||
|
||||
Examples:
|
||||
# Regular mode
|
||||
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
|
||||
|
||||
# PD disaggregated mode with same policy for both
|
||||
python -m sglang_router.launch_router --pd-disaggregation \\
|
||||
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 \\
|
||||
--decode http://decode1:8001 --decode http://decode2:8001 \\
|
||||
--policy cache_aware
|
||||
|
||||
# PD mode with optional bootstrap ports
|
||||
python -m sglang_router.launch_router --pd-disaggregation \\
|
||||
--prefill http://prefill1:8000 9000 \\ # With bootstrap port
|
||||
--prefill http://prefill2:8000 none \\ # Explicitly no bootstrap port
|
||||
--prefill http://prefill3:8000 \\ # Defaults to no bootstrap port
|
||||
--decode http://decode1:8001 --decode http://decode2:8001
|
||||
|
||||
# PD mode with different policies for prefill and decode
|
||||
python -m sglang_router.launch_router --pd-disaggregation \\
|
||||
--prefill http://prefill1:8000 --prefill http://prefill2:8000 \\
|
||||
--decode http://decode1:8001 --decode http://decode2:8001 \\
|
||||
--prefill-policy cache_aware --decode-policy power_of_two
|
||||
|
||||
""",
|
||||
formatter_class=CustomHelpFormatter,
|
||||
)
|
||||
|
||||
RouterArgs.add_cli_args(parser, use_router_prefix=False)
|
||||
return RouterArgs.from_cli_args(parser.parse_args(args), use_router_prefix=False)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
router_args = parse_router_args(sys.argv[1:])
|
||||
launch_router(router_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
213
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/launch_server.py
vendored
Normal file
213
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/launch_server.py
vendored
Normal file
@@ -0,0 +1,213 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from setproctitle import setproctitle
|
||||
from sglang_router.launch_router import RouterArgs, launch_router
|
||||
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils.network import is_port_available
|
||||
|
||||
|
||||
def setup_logger():
|
||||
logger = logging.getLogger("router")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
"[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# Create new process group
|
||||
def run_server(server_args, dp_rank):
|
||||
"""
|
||||
Note:
|
||||
|
||||
1. Without os.setpgrp(), all processes share the same PGID. When you press Ctrl+C, the terminal sends SIGINT to all processes in the group simultaneously.
|
||||
This can cause leaf processes to terminate first, which messes up the cleaning order and produces orphaned processes.
|
||||
|
||||
Terminal (PGID=100)
|
||||
└── Main Python Process (PGID=100)
|
||||
└── Server Process 1 (PGID=100)
|
||||
└── Scheduler 1
|
||||
└── Detokenizer 1
|
||||
└── Server Process 2 (PGID=100)
|
||||
└── Scheduler 2
|
||||
└── Detokenizer 2
|
||||
|
||||
2. With os.setpgrp(), the main Python process and its children are in a separate group. Now:
|
||||
|
||||
Terminal (PGID=100)
|
||||
└── Main Python Process (PGID=200)
|
||||
└── Server Process 1 (PGID=300)
|
||||
└── Scheduler 1
|
||||
└── Detokenizer 1
|
||||
└── Server Process 2 (PGID=400)
|
||||
└── Scheduler 2
|
||||
└── Detokenizer 2
|
||||
"""
|
||||
# create new process group
|
||||
os.setpgrp()
|
||||
|
||||
setproctitle("sglang::server")
|
||||
# Set SGLANG_DP_RANK environment variable
|
||||
os.environ["SGLANG_DP_RANK"] = str(dp_rank)
|
||||
|
||||
# Launch server in appropriate mode (HTTP or gRPC)
|
||||
if server_args.grpc_mode:
|
||||
from sglang.srt.entrypoints.grpc_server import serve_grpc
|
||||
|
||||
asyncio.run(serve_grpc(server_args))
|
||||
else:
|
||||
from sglang.srt.entrypoints.http_server import launch_server
|
||||
|
||||
launch_server(server_args)
|
||||
|
||||
|
||||
def launch_server_process(
|
||||
server_args: ServerArgs, worker_port: int, dp_id: int
|
||||
) -> mp.Process:
|
||||
"""Launch a single server process with the given args and port."""
|
||||
server_args = copy.deepcopy(server_args)
|
||||
server_args.port = worker_port
|
||||
server_args.base_gpu_id = dp_id * server_args.tp_size
|
||||
server_args.dp_size = 1
|
||||
|
||||
proc = mp.Process(target=run_server, args=(server_args, dp_id))
|
||||
proc.start()
|
||||
return proc
|
||||
|
||||
|
||||
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
|
||||
"""Wait for server to be healthy by checking /health endpoint."""
|
||||
start_time = time.perf_counter()
|
||||
url = f"http://{host}:{port}/health"
|
||||
|
||||
while time.perf_counter() - start_time < timeout:
|
||||
try:
|
||||
response = requests.get(url, timeout=5)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
time.sleep(1)
|
||||
return False
|
||||
|
||||
|
||||
def find_available_ports(base_port: int, count: int) -> List[int]:
|
||||
"""Find consecutive available ports starting from base_port."""
|
||||
available_ports = []
|
||||
current_port = base_port
|
||||
|
||||
while len(available_ports) < count:
|
||||
if is_port_available(current_port):
|
||||
available_ports.append(current_port)
|
||||
current_port += random.randint(100, 1000)
|
||||
|
||||
return available_ports
|
||||
|
||||
|
||||
def cleanup_processes(processes: List[mp.Process]):
|
||||
for process in processes:
|
||||
logger.info(f"Terminating process group {process.pid}")
|
||||
try:
|
||||
os.killpg(process.pid, signal.SIGTERM)
|
||||
except ProcessLookupError:
|
||||
# Process group may already be terminated
|
||||
pass
|
||||
|
||||
# Wait for processes to terminate
|
||||
for process in processes:
|
||||
process.join(timeout=5)
|
||||
if process.is_alive():
|
||||
logger.warning(
|
||||
f"Process {process.pid} did not terminate gracefully, forcing kill"
|
||||
)
|
||||
try:
|
||||
os.killpg(process.pid, signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
|
||||
logger.info("All process groups terminated")
|
||||
|
||||
|
||||
def main():
|
||||
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
|
||||
mp.set_start_method("spawn")
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Launch SGLang router and server processes"
|
||||
)
|
||||
|
||||
ServerArgs.add_cli_args(parser)
|
||||
RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
|
||||
parser.add_argument(
|
||||
"--router-dp-worker-base-port",
|
||||
type=int,
|
||||
default=31000,
|
||||
help="Base port number for data parallel workers",
|
||||
)
|
||||
# No extra retry/CB flags here; RouterArgs.add_cli_args already defines them with router- prefix
|
||||
|
||||
args = parser.parse_args()
|
||||
server_args = ServerArgs.from_cli_args(args)
|
||||
router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
|
||||
|
||||
# Find available ports for workers
|
||||
worker_ports = find_available_ports(
|
||||
args.router_dp_worker_base_port, server_args.dp_size
|
||||
)
|
||||
|
||||
# Start server processes
|
||||
server_processes = []
|
||||
|
||||
for i, worker_port in enumerate(worker_ports):
|
||||
logger.info(f"Launching DP server process {i} on port {worker_port}")
|
||||
proc = launch_server_process(server_args, worker_port, i)
|
||||
server_processes.append(proc)
|
||||
|
||||
signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes))
|
||||
signal.signal(
|
||||
signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes)
|
||||
)
|
||||
signal.signal(
|
||||
signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes)
|
||||
)
|
||||
|
||||
# Update router args with worker URLs
|
||||
# Use grpc:// protocol if server is in gRPC mode, otherwise http://
|
||||
protocol = "grpc" if server_args.grpc_mode else "http"
|
||||
router_args.worker_urls = [
|
||||
f"{protocol}://{server_args.host}:{port}" for port in worker_ports
|
||||
]
|
||||
|
||||
# Start the router
|
||||
try:
|
||||
launch_router(router_args)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start router: {e}")
|
||||
cleanup_processes(server_processes)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
462
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/mini_lb.py
vendored
Normal file
462
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/mini_lb.py
vendored
Normal file
@@ -0,0 +1,462 @@
|
||||
"""
|
||||
Minimal HTTP load balancer for prefill and decode servers for testing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import logging
|
||||
import random
|
||||
import urllib
|
||||
import warnings
|
||||
from http import HTTPStatus
|
||||
from itertools import chain
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
from sglang_router.router_args import RouterArgs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
|
||||
1024 * 64
|
||||
) # 64KB, to prevent aiohttp's "Chunk too big" error
|
||||
|
||||
|
||||
def maybe_wrap_ipv6_address(address: str) -> str:
|
||||
try:
|
||||
ipaddress.IPv6Address(address)
|
||||
return f"[{address}]"
|
||||
except ValueError:
|
||||
return address
|
||||
|
||||
|
||||
class MiniLoadBalancer:
|
||||
def __init__(
|
||||
self,
|
||||
router_args: RouterArgs,
|
||||
):
|
||||
self._validate_router_args(router_args)
|
||||
|
||||
self.host = router_args.host
|
||||
self.port = router_args.port
|
||||
self.timeout = router_args.request_timeout_secs
|
||||
self.prefill_urls = [url[0] for url in router_args.prefill_urls]
|
||||
self.prefill_bootstrap_ports = [url[1] for url in router_args.prefill_urls]
|
||||
self.decode_urls = router_args.decode_urls
|
||||
self.test_external_dp_routing = router_args.test_external_dp_routing
|
||||
self.prefill_dp_size = None
|
||||
self.decode_dp_size = None
|
||||
|
||||
def _validate_router_args(self, router_args: RouterArgs):
|
||||
logger.warning(
|
||||
"\x1b[33mMiniLB is only for debugging purposes, it only supports random policy!\033[0m"
|
||||
)
|
||||
|
||||
# NOTE: too many arguments unsupported, just validate some important ones
|
||||
if router_args.policy != "random":
|
||||
logger.warning("[MiniLB] Overriding policy to random")
|
||||
router_args.policy = "random"
|
||||
|
||||
if not router_args.pd_disaggregation:
|
||||
raise ValueError("MiniLB only supports PD disaggregation mode")
|
||||
|
||||
if len(router_args.prefill_urls) == 0 or len(router_args.decode_urls) == 0:
|
||||
raise ValueError(
|
||||
"MiniLB requires at least one prefill and one decode server"
|
||||
)
|
||||
|
||||
def start(self):
|
||||
global lb
|
||||
lb = self
|
||||
uvicorn.run(app, host=self.host, port=self.port)
|
||||
|
||||
async def _ensure_dp_sizes(self):
|
||||
if self.prefill_dp_size is not None:
|
||||
return
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"{self.prefill_urls[0]}/server_info") as resp:
|
||||
info = await resp.json()
|
||||
self.prefill_dp_size = len(info.get("internal_states", [1]))
|
||||
async with session.get(f"{self.decode_urls[0]}/server_info") as resp:
|
||||
info = await resp.json()
|
||||
self.decode_dp_size = len(info.get("internal_states", [1]))
|
||||
logger.info(
|
||||
f"[MiniLB] DP sizes: prefill={self.prefill_dp_size}, decode={self.decode_dp_size}"
|
||||
)
|
||||
|
||||
def _fork_dp_requests(self, request):
|
||||
p_rank = random.randint(0, self.prefill_dp_size - 1)
|
||||
d_rank = random.randint(0, self.decode_dp_size - 1)
|
||||
|
||||
prefill_req = request.copy()
|
||||
decode_req = request.copy()
|
||||
prefill_req["routed_dp_rank"] = p_rank
|
||||
decode_req["routed_dp_rank"] = d_rank
|
||||
decode_req["disagg_prefill_dp_rank"] = p_rank
|
||||
|
||||
return prefill_req, decode_req, d_rank
|
||||
|
||||
def select_pair(self):
|
||||
assert len(self.prefill_urls) > 0, "No prefill servers available"
|
||||
assert len(self.decode_urls) > 0, "No decode servers available"
|
||||
pidx = random.randint(0, len(self.prefill_urls) - 1)
|
||||
didx = random.randint(0, len(self.decode_urls) - 1)
|
||||
return (
|
||||
self.prefill_urls[pidx],
|
||||
self.prefill_bootstrap_ports[pidx],
|
||||
self.decode_urls[didx],
|
||||
)
|
||||
|
||||
async def generate(
|
||||
self, modified_request, prefill_server, decode_server, endpoint
|
||||
) -> ORJSONResponse:
|
||||
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
||||
|
||||
expected_decode_dp_rank = None
|
||||
if self.test_external_dp_routing:
|
||||
await self._ensure_dp_sizes()
|
||||
prefill_req, decode_req, expected_decode_dp_rank = self._fork_dp_requests(
|
||||
modified_request
|
||||
)
|
||||
else:
|
||||
prefill_req = modified_request
|
||||
decode_req = modified_request
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self.timeout
|
||||
) # Add timeout for request reliability
|
||||
) as session:
|
||||
|
||||
tasks = [
|
||||
session.post(f"{prefill_server}/{endpoint}", json=prefill_req),
|
||||
session.post(f"{decode_server}/{endpoint}", json=decode_req),
|
||||
]
|
||||
|
||||
# Wait for both responses to complete. Prefill should end first.
|
||||
prefill_response, decode_response = await asyncio.gather(*tasks)
|
||||
|
||||
if "return_logprob" in modified_request:
|
||||
|
||||
prefill_json = await prefill_response.json()
|
||||
ret_json = await decode_response.json()
|
||||
|
||||
# merge `meta_info.input_token_logprobs` from prefill to decode
|
||||
if "meta_info" in ret_json:
|
||||
if "input_token_logprobs" in ret_json["meta_info"]:
|
||||
ret_json["meta_info"]["input_token_logprobs"] = (
|
||||
prefill_json["meta_info"]["input_token_logprobs"]
|
||||
+ ret_json["meta_info"]["input_token_logprobs"]
|
||||
)
|
||||
else:
|
||||
ret_json = await decode_response.json()
|
||||
|
||||
if expected_decode_dp_rank is not None:
|
||||
actual = ret_json.get("meta_info", {}).get("dp_rank")
|
||||
if actual != expected_decode_dp_rank:
|
||||
return ORJSONResponse(
|
||||
content={
|
||||
"error": f"DP rank mismatch: expected {expected_decode_dp_rank}, got {actual}"
|
||||
},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
return ORJSONResponse(
|
||||
content=ret_json,
|
||||
status_code=decode_response.status,
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self, modified_request, prefill_server, decode_server, endpoint="generate"
|
||||
):
|
||||
|
||||
if self.test_external_dp_routing:
|
||||
warnings.warn("--test-external-dp-routing is not supported with streaming")
|
||||
|
||||
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
||||
|
||||
async def stream_results():
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self.timeout
|
||||
) # Add timeout for request reliability
|
||||
) as session:
|
||||
# Create the tasks for both prefill and decode requests
|
||||
tasks = [
|
||||
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
||||
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
||||
]
|
||||
|
||||
# Wait for both responses to complete. Since this is streaming, they return immediately.
|
||||
prefill_response, decode_response = await asyncio.gather(*tasks)
|
||||
|
||||
if modified_request.get("return_logprob", False):
|
||||
prefill_chunks = []
|
||||
async for chunk in prefill_response.content:
|
||||
prefill_chunks.append(chunk)
|
||||
|
||||
first_prefill_chunk = (
|
||||
prefill_chunks[0].decode("utf-8")[5:].strip("\n")
|
||||
)
|
||||
first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
|
||||
|
||||
async for chunk in decode_response.content:
|
||||
# Note: This is inefficient
|
||||
# merge prefill input_token_logprobs, output_token_logprobs to decode
|
||||
decoded_chunk = chunk.decode("utf-8")
|
||||
if (
|
||||
decoded_chunk
|
||||
and decoded_chunk.startswith("data:")
|
||||
and "[DONE]" not in decoded_chunk
|
||||
):
|
||||
ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
|
||||
ret_json["meta_info"]["input_token_logprobs"] = (
|
||||
first_prefill_chunk_json["meta_info"][
|
||||
"input_token_logprobs"
|
||||
]
|
||||
+ ret_json["meta_info"]["input_token_logprobs"]
|
||||
)
|
||||
|
||||
yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
|
||||
else:
|
||||
yield chunk
|
||||
else:
|
||||
async for chunk in decode_response.content.iter_chunked(
|
||||
AIOHTTP_STREAM_READ_CHUNK_SIZE
|
||||
):
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(
|
||||
stream_results(),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
lb: Optional[MiniLoadBalancer] = None
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/health_generate")
|
||||
async def health_generate():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create the tasks
|
||||
tasks = []
|
||||
for server in chain(lb.prefill_urls, lb.decode_urls):
|
||||
tasks.append(session.get(f"{server}/health_generate"))
|
||||
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||
await response
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/flush_cache")
|
||||
async def flush_cache():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create the tasks
|
||||
tasks = []
|
||||
for server in chain(lb.prefill_urls, lb.decode_urls):
|
||||
tasks.append(session.post(f"{server}/flush_cache"))
|
||||
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||
await response
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
# TODO: Remove `/get_server_info` alias after one release-cycle deprecation window.
|
||||
@app.get("/server_info")
|
||||
@app.get("/get_server_info")
|
||||
async def get_server_info():
|
||||
prefill_infos = []
|
||||
decode_infos = []
|
||||
all_internal_states = []
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for server in lb.prefill_urls:
|
||||
server_info = await session.get(f"{server}/server_info")
|
||||
prefill_infos.append(await server_info.json())
|
||||
for server in lb.decode_urls:
|
||||
server_info = await session.get(f"{server}/server_info")
|
||||
info_json = await server_info.json()
|
||||
decode_infos.append(info_json)
|
||||
# Extract internal_states from decode servers
|
||||
if "internal_states" in info_json:
|
||||
all_internal_states.extend(info_json["internal_states"])
|
||||
|
||||
# Return format expected by bench_one_batch_server.py
|
||||
if all_internal_states:
|
||||
return {
|
||||
"internal_states": all_internal_states,
|
||||
"prefill": prefill_infos,
|
||||
"decode": decode_infos,
|
||||
}
|
||||
else:
|
||||
# Fallback with dummy data if no internal states found
|
||||
return {
|
||||
"internal_states": [
|
||||
{
|
||||
"last_gen_throughput": 0.0,
|
||||
"avg_spec_accept_length": None,
|
||||
}
|
||||
],
|
||||
"prefill": prefill_infos,
|
||||
"decode": decode_infos,
|
||||
}
|
||||
|
||||
|
||||
async def _get_model_info_impl():
|
||||
if not lb or not lb.prefill_urls:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
detail="There is no server registered",
|
||||
)
|
||||
|
||||
target_server_url = lb.prefill_urls[0]
|
||||
endpoint_url = f"{target_server_url}/model_info"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.get(endpoint_url) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_GATEWAY,
|
||||
detail=(
|
||||
f"Failed to get model info from {target_server_url}"
|
||||
f"Status: {response.status}, Response: {error_text}"
|
||||
),
|
||||
)
|
||||
|
||||
model_info_json = await response.json()
|
||||
return ORJSONResponse(content=model_info_json)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
detail=f"Failed to get model info from backend",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/model_info")
|
||||
async def model_info():
|
||||
return await _get_model_info_impl()
|
||||
|
||||
|
||||
@app.get("/get_model_info")
|
||||
async def get_model_info():
|
||||
return await _get_model_info_impl()
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def handle_generate_request(request_data: dict):
|
||||
prefill_server, bootstrap_port, decode_server = lb.select_pair()
|
||||
|
||||
# Parse and transform prefill_server for bootstrap data
|
||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||
modified_request = request_data.copy()
|
||||
|
||||
batch_size = _get_request_batch_size(modified_request)
|
||||
if batch_size is not None:
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": [hostname] * batch_size,
|
||||
"bootstrap_port": [bootstrap_port] * batch_size,
|
||||
"bootstrap_room": [
|
||||
_generate_bootstrap_room() for _ in range(batch_size)
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": hostname,
|
||||
"bootstrap_port": bootstrap_port,
|
||||
"bootstrap_room": _generate_bootstrap_room(),
|
||||
}
|
||||
)
|
||||
|
||||
if request_data.get("stream", False):
|
||||
return await lb.generate_stream(
|
||||
modified_request, prefill_server, decode_server, "generate"
|
||||
)
|
||||
else:
|
||||
return await lb.generate(
|
||||
modified_request, prefill_server, decode_server, "generate"
|
||||
)
|
||||
|
||||
|
||||
async def _forward_to_backend(request_data: dict, endpoint_name: str):
|
||||
prefill_server, bootstrap_port, decode_server = lb.select_pair()
|
||||
|
||||
# Parse and transform prefill_server for bootstrap data
|
||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||
modified_request = request_data.copy()
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": hostname,
|
||||
"bootstrap_port": bootstrap_port,
|
||||
"bootstrap_room": _generate_bootstrap_room(),
|
||||
}
|
||||
)
|
||||
|
||||
if request_data.get("stream", False):
|
||||
return await lb.generate_stream(
|
||||
modified_request,
|
||||
prefill_server,
|
||||
decode_server,
|
||||
endpoint=endpoint_name,
|
||||
)
|
||||
else:
|
||||
return await lb.generate(
|
||||
modified_request,
|
||||
prefill_server,
|
||||
decode_server,
|
||||
endpoint=endpoint_name,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def handle_chat_completion_request(request_data: dict):
|
||||
return await _forward_to_backend(request_data, "v1/chat/completions")
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def handle_completion_request(request_data: dict):
|
||||
return await _forward_to_backend(request_data, "v1/completions")
|
||||
|
||||
|
||||
def _generate_bootstrap_room():
|
||||
return random.randint(0, 2**63 - 1)
|
||||
|
||||
|
||||
# We may utilize `GenerateReqInput`'s logic later
|
||||
def _get_request_batch_size(request):
|
||||
if (text := request.get("text")) is not None:
|
||||
return None if isinstance(text, str) else len(text)
|
||||
if (input_ids := request.get("input_ids")) is not None:
|
||||
return None if isinstance(input_ids[0], int) else len(input_ids)
|
||||
return None
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def get_models():
|
||||
prefill_server = lb.prefill_urls[0] # Get the first prefill server
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
response = await session.get(f"{prefill_server}/v1/models")
|
||||
if response.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"Prefill server error: Status {response.status}",
|
||||
)
|
||||
return ORJSONResponse(content=await response.json())
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
320
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/router.py
vendored
Normal file
320
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/router.py
vendored
Normal file
@@ -0,0 +1,320 @@
|
||||
from typing import Optional
|
||||
|
||||
from sglang_router.router_args import RouterArgs
|
||||
from sglang_router.sglang_router_rs import (
|
||||
BackendType,
|
||||
HistoryBackendType,
|
||||
PolicyType,
|
||||
PyApiKeyEntry,
|
||||
PyControlPlaneAuthConfig,
|
||||
PyJwtConfig,
|
||||
PyOracleConfig,
|
||||
PyPostgresConfig,
|
||||
PyRedisConfig,
|
||||
PyRole,
|
||||
)
|
||||
from sglang_router.sglang_router_rs import Router as _Router
|
||||
|
||||
|
||||
def policy_from_str(policy_str: Optional[str]) -> PolicyType:
|
||||
"""Convert policy string to PolicyType enum."""
|
||||
if policy_str is None:
|
||||
return None
|
||||
policy_map = {
|
||||
"random": PolicyType.Random,
|
||||
"round_robin": PolicyType.RoundRobin,
|
||||
"cache_aware": PolicyType.CacheAware,
|
||||
"power_of_two": PolicyType.PowerOfTwo,
|
||||
"bucket": PolicyType.Bucket,
|
||||
"manual": PolicyType.Manual,
|
||||
"consistent_hashing": PolicyType.ConsistentHashing,
|
||||
"prefix_hash": PolicyType.PrefixHash,
|
||||
}
|
||||
return policy_map[policy_str]
|
||||
|
||||
|
||||
def backend_from_str(backend_str: Optional[str]) -> BackendType:
|
||||
"""Convert backend string to BackendType enum."""
|
||||
if isinstance(backend_str, BackendType):
|
||||
return backend_str
|
||||
if backend_str is None:
|
||||
return BackendType.Sglang
|
||||
backend_map = {"sglang": BackendType.Sglang, "openai": BackendType.Openai}
|
||||
backend_lower = backend_str.lower()
|
||||
if backend_lower not in backend_map:
|
||||
raise ValueError(
|
||||
f"Unknown backend: {backend_str}. Valid options: {', '.join(backend_map.keys())}"
|
||||
)
|
||||
return backend_map[backend_lower]
|
||||
|
||||
|
||||
def history_backend_from_str(backend_str: Optional[str]) -> HistoryBackendType:
|
||||
"""Convert history backend string to HistoryBackendType enum."""
|
||||
if isinstance(backend_str, HistoryBackendType):
|
||||
return backend_str
|
||||
if backend_str is None:
|
||||
return HistoryBackendType.Memory
|
||||
backend_lower = backend_str.lower()
|
||||
if backend_lower == "memory":
|
||||
return HistoryBackendType.Memory
|
||||
elif backend_lower == "none":
|
||||
# Use getattr to access 'None' which is a Python keyword
|
||||
return getattr(HistoryBackendType, "None")
|
||||
elif backend_lower == "oracle":
|
||||
return HistoryBackendType.Oracle
|
||||
elif backend_lower == "postgres":
|
||||
return HistoryBackendType.Postgres
|
||||
elif backend_lower == "redis":
|
||||
return HistoryBackendType.Redis
|
||||
else:
|
||||
raise ValueError(f"Unknown history backend: {backend_str}")
|
||||
|
||||
|
||||
def role_from_str(role_str: str) -> PyRole:
|
||||
"""Convert role string to PyRole enum."""
|
||||
if role_str.lower() == "admin":
|
||||
return PyRole.Admin
|
||||
return PyRole.User
|
||||
|
||||
|
||||
def build_control_plane_auth_config(
|
||||
args_dict: dict,
|
||||
) -> Optional[PyControlPlaneAuthConfig]:
|
||||
"""Build control plane auth config from args dict."""
|
||||
api_keys = args_dict.get("control_plane_api_keys", [])
|
||||
jwt_issuer = args_dict.get("jwt_issuer")
|
||||
jwt_audience = args_dict.get("jwt_audience")
|
||||
audit_enabled = args_dict.get("control_plane_audit_enabled", False)
|
||||
|
||||
# Check if any auth is configured
|
||||
has_api_keys = bool(api_keys)
|
||||
has_jwt = jwt_issuer is not None and jwt_audience is not None
|
||||
|
||||
if not has_api_keys and not has_jwt:
|
||||
return None
|
||||
|
||||
# Build API key entries
|
||||
py_api_keys = []
|
||||
for key_tuple in api_keys:
|
||||
# Tuple format: (id, name, key, role)
|
||||
key_id, name, key, role = key_tuple
|
||||
py_api_keys.append(
|
||||
PyApiKeyEntry(
|
||||
id=key_id,
|
||||
name=name,
|
||||
key=key,
|
||||
role=role_from_str(role),
|
||||
)
|
||||
)
|
||||
|
||||
# Build JWT config if present
|
||||
jwt_config = None
|
||||
if has_jwt:
|
||||
jwt_config = PyJwtConfig(
|
||||
issuer=jwt_issuer,
|
||||
audience=jwt_audience,
|
||||
jwks_uri=args_dict.get("jwt_jwks_uri"),
|
||||
role_mapping=args_dict.get("jwt_role_mapping", {}),
|
||||
)
|
||||
|
||||
return PyControlPlaneAuthConfig(
|
||||
jwt=jwt_config,
|
||||
api_keys=py_api_keys,
|
||||
audit_enabled=audit_enabled,
|
||||
)
|
||||
|
||||
|
||||
class Router:
|
||||
"""
|
||||
A high-performance router for distributing requests across worker nodes.
|
||||
|
||||
Args:
|
||||
worker_urls: List of URLs for worker nodes that will handle requests. Each URL should include
|
||||
the protocol, host, and port (e.g., ['http://worker1:8000', 'http://worker2:8000'])
|
||||
policy: Load balancing policy to use. Options:
|
||||
- PolicyType.Random: Randomly select workers
|
||||
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
|
||||
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
|
||||
- PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only)
|
||||
host: Host address to bind the router server. Supports IPv4, IPv6 (e.g., ::, ::1), or 0.0.0.0 for all interfaces. Default: '0.0.0.0'
|
||||
port: Port number to bind the router server. Default: 3001
|
||||
worker_startup_timeout_secs: Timeout in seconds for worker startup and registration. Large models can take significant time to load into GPU memory. Default: 1800 (30 minutes)
|
||||
worker_startup_check_interval: Interval in seconds between checks for worker initialization. Default: 10
|
||||
cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
|
||||
if the match rate exceeds threshold, otherwise routes to the worker with the smallest
|
||||
tree. Default: 0.5
|
||||
balance_abs_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
|
||||
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 32
|
||||
balance_rel_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
|
||||
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
|
||||
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
|
||||
routing. Default: 60
|
||||
max_payload_size: Maximum payload size in bytes. Default: 256MB
|
||||
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
|
||||
dp_aware: Enable data parallelism aware schedule. Default: False
|
||||
enable_igw: Enable IGW (Inference-Gateway) mode for multi-model support. When enabled,
|
||||
the router can manage multiple models simultaneously with per-model load balancing
|
||||
policies. Default: False
|
||||
api_key: The api key used for the authorization with the worker.
|
||||
Useful when the dp aware scheduling strategy is enabled.
|
||||
Default: None
|
||||
log_dir: Directory to store log files. If None, logs are only output to console. Default: None
|
||||
log_level: Logging level. Options: 'debug', 'info', 'warn', 'error'.
|
||||
service_discovery: Enable Kubernetes service discovery. When enabled, the router will
|
||||
automatically discover worker pods based on the selector. Default: False
|
||||
selector: Dictionary mapping of label keys to values for Kubernetes pod selection.
|
||||
Example: {"app": "sglang-worker"}. Default: {}
|
||||
service_discovery_port: Port to use for service discovery. The router will generate
|
||||
worker URLs using this port. Default: 80
|
||||
service_discovery_namespace: Kubernetes namespace to watch for pods. If not provided,
|
||||
watches pods across all namespaces (requires cluster-wide permissions). Default: None
|
||||
prefill_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
|
||||
for prefill servers (PD mode only). Default: {}
|
||||
decode_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
|
||||
for decode servers (PD mode only). Default: {}
|
||||
prometheus_port: Port to expose Prometheus metrics. Default: None
|
||||
prometheus_host: Host address to bind the Prometheus metrics server. Default: None
|
||||
pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
|
||||
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
|
||||
decode_urls: List of URLs for decode servers (PD mode only)
|
||||
prefill_policy: Specific load balancing policy for prefill nodes (PD mode only).
|
||||
If not specified, uses the main policy. Default: None
|
||||
decode_policy: Specific load balancing policy for decode nodes (PD mode only).
|
||||
If not specified, uses the main policy. Default: None
|
||||
request_id_headers: List of HTTP headers to check for request IDs. If not specified,
|
||||
uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id'].
|
||||
Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None
|
||||
bootstrap_port_annotation: Kubernetes annotation name for bootstrap port (PD mode).
|
||||
Default: 'sglang.ai/bootstrap-port'
|
||||
request_timeout_secs: Request timeout in seconds. Default: 600
|
||||
max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 256
|
||||
queue_size: Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately). Default: 100
|
||||
queue_timeout_secs: Maximum time (in seconds) a request can wait in queue before timing out. Default: 60
|
||||
rate_limit_tokens_per_second: Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests. Default: None
|
||||
cors_allowed_origins: List of allowed origins for CORS. Empty list allows all origins. Default: []
|
||||
health_failure_threshold: Number of consecutive health check failures before marking worker unhealthy. Default: 3
|
||||
health_success_threshold: Number of consecutive health check successes before marking worker healthy. Default: 2
|
||||
health_check_timeout_secs: Timeout in seconds for health check requests. Default: 5
|
||||
health_check_interval_secs: Interval in seconds between runtime health checks. Default: 60
|
||||
health_check_endpoint: Health check endpoint path. Default: '/health'
|
||||
model_path: Model path for loading tokenizer (HuggingFace model ID or local path). Default: None
|
||||
tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
|
||||
server_cert_path: Path to server TLS certificate (PEM format). Default: None
|
||||
server_key_path: Path to server TLS private key (PEM format). Default: None
|
||||
"""
|
||||
|
||||
def __init__(self, router: _Router):
|
||||
self._router = router
|
||||
|
||||
@staticmethod
|
||||
def from_args(args: RouterArgs) -> "Router":
|
||||
"""Create a router from a RouterArgs instance."""
|
||||
|
||||
args_dict = vars(args)
|
||||
# Convert RouterArgs to _Router parameters
|
||||
args_dict["worker_urls"] = (
|
||||
[]
|
||||
if args_dict["service_discovery"] or args_dict["pd_disaggregation"]
|
||||
else args_dict["worker_urls"]
|
||||
)
|
||||
args_dict["policy"] = policy_from_str(args_dict["policy"])
|
||||
args_dict["prefill_urls"] = (
|
||||
args_dict["prefill_urls"] if args_dict["pd_disaggregation"] else None
|
||||
)
|
||||
args_dict["decode_urls"] = (
|
||||
args_dict["decode_urls"] if args_dict["pd_disaggregation"] else None
|
||||
)
|
||||
args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
|
||||
args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])
|
||||
|
||||
# Convert backend
|
||||
args_dict["backend"] = backend_from_str(args_dict.get("backend"))
|
||||
|
||||
# Convert history_backend to enum first
|
||||
history_backend_raw = args_dict.get("history_backend", "memory")
|
||||
history_backend = history_backend_from_str(history_backend_raw)
|
||||
|
||||
# Convert Oracle config if needed
|
||||
oracle_config = None
|
||||
if history_backend == HistoryBackendType.Oracle:
|
||||
# Prioritize TNS alias over connect descriptor
|
||||
tns_alias = args_dict.get("oracle_tns_alias")
|
||||
connect_descriptor = args_dict.get("oracle_connect_descriptor")
|
||||
|
||||
# Use TNS alias if provided, otherwise use connect descriptor
|
||||
final_descriptor = tns_alias if tns_alias else connect_descriptor
|
||||
|
||||
oracle_config = PyOracleConfig(
|
||||
password=args_dict.get("oracle_password"),
|
||||
username=args_dict.get("oracle_username"),
|
||||
connect_descriptor=final_descriptor,
|
||||
wallet_path=args_dict.get("oracle_wallet_path"),
|
||||
pool_min=args_dict.get("oracle_pool_min", 1),
|
||||
pool_max=args_dict.get("oracle_pool_max", 16),
|
||||
pool_timeout_secs=args_dict.get("oracle_pool_timeout_secs", 30),
|
||||
)
|
||||
args_dict["oracle_config"] = oracle_config
|
||||
args_dict["history_backend"] = history_backend
|
||||
|
||||
# Convert Postgres config if needed
|
||||
postgres_config = None
|
||||
if history_backend == HistoryBackendType.Postgres:
|
||||
postgres_config = PyPostgresConfig(
|
||||
db_url=args_dict.get("postgres_db_url"),
|
||||
pool_max=args_dict.get("postgres_pool_max", 16),
|
||||
)
|
||||
args_dict["postgres_config"] = postgres_config
|
||||
|
||||
# Convert Redis config if needed
|
||||
redis_config = None
|
||||
if history_backend == HistoryBackendType.Redis:
|
||||
retention_days = args_dict.get("redis_retention_days", 30)
|
||||
# If retention_days is negative, it means persistent storage (None in Rust)
|
||||
retention_arg = None if retention_days < 0 else retention_days
|
||||
|
||||
redis_config = PyRedisConfig(
|
||||
url=args_dict.get("redis_url"),
|
||||
pool_max=args_dict.get("redis_pool_max", 16),
|
||||
retention_days=retention_arg,
|
||||
)
|
||||
args_dict["redis_config"] = redis_config
|
||||
|
||||
# Build control plane auth config
|
||||
args_dict["control_plane_auth"] = build_control_plane_auth_config(args_dict)
|
||||
|
||||
# Remove fields that shouldn't be passed to Rust Router constructor
|
||||
fields_to_remove = [
|
||||
"mini_lb",
|
||||
"test_external_dp_routing",
|
||||
"oracle_wallet_path",
|
||||
"oracle_tns_alias",
|
||||
"oracle_connect_descriptor",
|
||||
"oracle_username",
|
||||
"oracle_password",
|
||||
"oracle_pool_min",
|
||||
"oracle_pool_max",
|
||||
"oracle_pool_timeout_secs",
|
||||
"postgres_db_url",
|
||||
"postgres_pool_max",
|
||||
"redis_url",
|
||||
"redis_pool_max",
|
||||
"redis_retention_days",
|
||||
# Control plane auth fields (converted to control_plane_auth)
|
||||
"control_plane_api_keys",
|
||||
"control_plane_audit_enabled",
|
||||
"jwt_issuer",
|
||||
"jwt_audience",
|
||||
"jwt_jwks_uri",
|
||||
"jwt_role_mapping",
|
||||
]
|
||||
for field in fields_to_remove:
|
||||
args_dict.pop(field, None)
|
||||
|
||||
return Router(_Router(**args_dict))
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the router server.
|
||||
|
||||
This method blocks until the server is shut down.
|
||||
"""
|
||||
self._router.start()
|
||||
1104
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/router_args.py
vendored
Normal file
1104
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/router_args.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/version.py
vendored
Normal file
1
third_party/sglang/sgl-model-gateway/bindings/python/src/sglang_router/version.py
vendored
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "0.3.2"
|
||||
14
third_party/sglang/sgl-model-gateway/bindings/python/tests/conftest.py
vendored
Normal file
14
third_party/sglang/sgl-model-gateway/bindings/python/tests/conftest.py
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Pytest configuration for sglang_router Python binding tests.
|
||||
|
||||
These are unit tests that run without GPU resources or external dependencies.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Configure pytest markers."""
|
||||
config.addinivalue_line(
|
||||
"markers", "unit: mark test as a unit test (no GPU required)"
|
||||
)
|
||||
637
third_party/sglang/sgl-model-gateway/bindings/python/tests/test_arg_parser.py
vendored
Normal file
637
third_party/sglang/sgl-model-gateway/bindings/python/tests/test_arg_parser.py
vendored
Normal file
@@ -0,0 +1,637 @@
|
||||
"""
|
||||
Unit tests for argument parsing functionality in sglang_router.
|
||||
|
||||
These tests focus on testing the argument parsing logic in isolation,
|
||||
without starting actual router instances.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from sglang_router.launch_router import RouterArgs, parse_router_args
|
||||
from sglang_router.router import policy_from_str
|
||||
|
||||
|
||||
class TestRouterArgs:
|
||||
"""Test RouterArgs dataclass and its methods."""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test that RouterArgs has correct default values."""
|
||||
args = RouterArgs()
|
||||
|
||||
# Test basic defaults
|
||||
assert args.host == "0.0.0.0"
|
||||
assert args.port == 30000
|
||||
assert args.policy == "cache_aware"
|
||||
assert args.worker_urls == []
|
||||
assert args.pd_disaggregation is False
|
||||
assert args.prefill_urls == []
|
||||
assert args.decode_urls == []
|
||||
|
||||
# Test PD-specific defaults
|
||||
assert args.prefill_policy is None
|
||||
assert args.decode_policy is None
|
||||
|
||||
# Test service discovery defaults
|
||||
assert args.service_discovery is False
|
||||
assert args.selector == {}
|
||||
assert args.service_discovery_port == 80
|
||||
assert args.service_discovery_namespace is None
|
||||
|
||||
# Test retry and circuit breaker defaults
|
||||
assert args.retry_max_retries == 5
|
||||
assert args.cb_failure_threshold == 10
|
||||
assert args.disable_retries is False
|
||||
assert args.disable_circuit_breaker is False
|
||||
|
||||
def test_parse_selector_valid(self):
|
||||
"""Test parsing valid selector arguments."""
|
||||
# Test single key-value pair
|
||||
result = RouterArgs._parse_selector(["app=worker"])
|
||||
assert result == {"app": "worker"}
|
||||
|
||||
# Test multiple key-value pairs
|
||||
result = RouterArgs._parse_selector(["app=worker", "env=prod", "version=v1"])
|
||||
assert result == {"app": "worker", "env": "prod", "version": "v1"}
|
||||
|
||||
# Test empty list
|
||||
result = RouterArgs._parse_selector([])
|
||||
assert result == {}
|
||||
|
||||
# Test None
|
||||
result = RouterArgs._parse_selector(None)
|
||||
assert result == {}
|
||||
|
||||
def test_parse_selector_invalid(self):
|
||||
"""Test parsing invalid selector arguments."""
|
||||
# Test malformed selector (no equals sign)
|
||||
result = RouterArgs._parse_selector(["app"])
|
||||
assert result == {}
|
||||
|
||||
# Test multiple equals signs (should use first one)
|
||||
result = RouterArgs._parse_selector(["app=worker=extra"])
|
||||
assert result == {"app": "worker=extra"}
|
||||
|
||||
def test_parse_prefill_urls_valid(self):
|
||||
"""Test parsing valid prefill URL arguments."""
|
||||
# Test with bootstrap port
|
||||
result = RouterArgs._parse_prefill_urls([["http://prefill1:8000", "9000"]])
|
||||
assert result == [("http://prefill1:8000", 9000)]
|
||||
|
||||
# Test with 'none' bootstrap port
|
||||
result = RouterArgs._parse_prefill_urls([["http://prefill1:8000", "none"]])
|
||||
assert result == [("http://prefill1:8000", None)]
|
||||
|
||||
# Test without bootstrap port
|
||||
result = RouterArgs._parse_prefill_urls([["http://prefill1:8000"]])
|
||||
assert result == [("http://prefill1:8000", None)]
|
||||
|
||||
# Test multiple prefill URLs
|
||||
result = RouterArgs._parse_prefill_urls(
|
||||
[
|
||||
["http://prefill1:8000", "9000"],
|
||||
["http://prefill2:8000", "none"],
|
||||
["http://prefill3:8000"],
|
||||
]
|
||||
)
|
||||
expected = [
|
||||
("http://prefill1:8000", 9000),
|
||||
("http://prefill2:8000", None),
|
||||
("http://prefill3:8000", None),
|
||||
]
|
||||
assert result == expected
|
||||
|
||||
# Test empty list
|
||||
result = RouterArgs._parse_prefill_urls([])
|
||||
assert result == []
|
||||
|
||||
# Test None
|
||||
result = RouterArgs._parse_prefill_urls(None)
|
||||
assert result == []
|
||||
|
||||
def test_parse_prefill_urls_invalid(self):
|
||||
"""Test parsing invalid prefill URL arguments."""
|
||||
# Test invalid bootstrap port
|
||||
with pytest.raises(ValueError, match="Invalid bootstrap port"):
|
||||
RouterArgs._parse_prefill_urls([["http://prefill1:8000", "invalid"]])
|
||||
|
||||
def test_parse_decode_urls_valid(self):
|
||||
"""Test parsing valid decode URL arguments."""
|
||||
# Test single decode URL
|
||||
result = RouterArgs._parse_decode_urls([["http://decode1:8001"]])
|
||||
assert result == ["http://decode1:8001"]
|
||||
|
||||
# Test multiple decode URLs
|
||||
result = RouterArgs._parse_decode_urls(
|
||||
[["http://decode1:8001"], ["http://decode2:8001"]]
|
||||
)
|
||||
assert result == ["http://decode1:8001", "http://decode2:8001"]
|
||||
|
||||
# Test empty list
|
||||
result = RouterArgs._parse_decode_urls([])
|
||||
assert result == []
|
||||
|
||||
# Test None
|
||||
result = RouterArgs._parse_decode_urls(None)
|
||||
assert result == []
|
||||
|
||||
def test_from_cli_args_basic(self):
|
||||
"""Test creating RouterArgs from basic CLI arguments."""
|
||||
args = SimpleNamespace(
|
||||
host="0.0.0.0",
|
||||
port=30001,
|
||||
worker_urls=["http://worker1:8000", "http://worker2:8000"],
|
||||
policy="round_robin",
|
||||
prefill=None,
|
||||
decode=None,
|
||||
router_policy="round_robin",
|
||||
router_pd_disaggregation=False,
|
||||
router_prefill_policy=None,
|
||||
router_decode_policy=None,
|
||||
router_worker_startup_timeout_secs=300,
|
||||
router_worker_startup_check_interval=15,
|
||||
router_cache_threshold=0.7,
|
||||
router_balance_abs_threshold=128,
|
||||
router_balance_rel_threshold=2.0,
|
||||
router_eviction_interval=180,
|
||||
router_max_tree_size=2**28,
|
||||
router_max_payload_size=1024 * 1024 * 1024, # 1GB
|
||||
router_dp_aware=True,
|
||||
router_api_key="test-key",
|
||||
router_log_dir="/tmp/logs",
|
||||
router_log_level="debug",
|
||||
router_service_discovery=True,
|
||||
router_selector=["app=worker", "env=test"],
|
||||
router_service_discovery_port=8080,
|
||||
router_service_discovery_namespace="default",
|
||||
router_prefill_selector=["app=prefill"],
|
||||
router_decode_selector=["app=decode"],
|
||||
router_prometheus_port=29000,
|
||||
router_prometheus_host="0.0.0.0",
|
||||
router_request_id_headers=["x-request-id", "x-trace-id"],
|
||||
router_request_timeout_secs=1200,
|
||||
router_max_concurrent_requests=512,
|
||||
router_queue_size=200,
|
||||
router_queue_timeout_secs=120,
|
||||
router_rate_limit_tokens_per_second=100,
|
||||
router_cors_allowed_origins=["http://localhost:3000"],
|
||||
router_retry_max_retries=3,
|
||||
router_retry_initial_backoff_ms=100,
|
||||
router_retry_max_backoff_ms=10000,
|
||||
router_retry_backoff_multiplier=2.0,
|
||||
router_retry_jitter_factor=0.1,
|
||||
router_cb_failure_threshold=5,
|
||||
router_cb_success_threshold=2,
|
||||
router_cb_timeout_duration_secs=30,
|
||||
router_cb_window_duration_secs=60,
|
||||
router_disable_retries=False,
|
||||
router_disable_circuit_breaker=False,
|
||||
router_health_failure_threshold=2,
|
||||
router_health_success_threshold=1,
|
||||
router_health_check_timeout_secs=3,
|
||||
router_health_check_interval_secs=30,
|
||||
router_health_check_endpoint="/healthz",
|
||||
)
|
||||
|
||||
router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
|
||||
|
||||
# Test basic configuration
|
||||
assert router_args.host == "0.0.0.0"
|
||||
assert router_args.port == 30001
|
||||
assert router_args.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
|
||||
assert router_args.policy == "round_robin"
|
||||
|
||||
# Test PD configuration
|
||||
assert router_args.pd_disaggregation is False
|
||||
assert router_args.prefill_urls == []
|
||||
assert router_args.decode_urls == []
|
||||
|
||||
# Test service discovery
|
||||
assert router_args.service_discovery is True
|
||||
assert router_args.selector == {"app": "worker", "env": "test"}
|
||||
assert router_args.service_discovery_port == 8080
|
||||
assert router_args.service_discovery_namespace == "default"
|
||||
assert router_args.prefill_selector == {"app": "prefill"}
|
||||
assert router_args.decode_selector == {"app": "decode"}
|
||||
|
||||
# Test other configurations
|
||||
assert router_args.dp_aware is True
|
||||
assert router_args.api_key == "test-key"
|
||||
assert router_args.log_dir == "/tmp/logs"
|
||||
assert router_args.log_level == "debug"
|
||||
assert router_args.prometheus_port == 29000
|
||||
assert router_args.prometheus_host == "0.0.0.0"
|
||||
assert router_args.request_id_headers == ["x-request-id", "x-trace-id"]
|
||||
assert router_args.request_timeout_secs == 1200
|
||||
assert router_args.max_concurrent_requests == 512
|
||||
assert router_args.queue_size == 200
|
||||
assert router_args.queue_timeout_secs == 120
|
||||
assert router_args.rate_limit_tokens_per_second == 100
|
||||
assert router_args.cors_allowed_origins == ["http://localhost:3000"]
|
||||
|
||||
# Test retry configuration
|
||||
assert router_args.retry_max_retries == 3
|
||||
assert router_args.retry_initial_backoff_ms == 100
|
||||
assert router_args.retry_max_backoff_ms == 10000
|
||||
assert router_args.retry_backoff_multiplier == 2.0
|
||||
assert router_args.retry_jitter_factor == 0.1
|
||||
|
||||
# Test circuit breaker configuration
|
||||
assert router_args.cb_failure_threshold == 5
|
||||
assert router_args.cb_success_threshold == 2
|
||||
assert router_args.cb_timeout_duration_secs == 30
|
||||
assert router_args.cb_window_duration_secs == 60
|
||||
assert router_args.disable_retries is False
|
||||
assert router_args.disable_circuit_breaker is False
|
||||
|
||||
# Test health check configuration
|
||||
assert router_args.health_failure_threshold == 2
|
||||
assert router_args.health_success_threshold == 1
|
||||
assert router_args.health_check_timeout_secs == 3
|
||||
assert router_args.health_check_interval_secs == 30
|
||||
assert router_args.health_check_endpoint == "/healthz"
|
||||
|
||||
# Note: model_path and tokenizer_path are not available in current RouterArgs
|
||||
|
||||
def test_from_cli_args_pd_mode(self):
|
||||
"""Test creating RouterArgs from CLI arguments in PD mode."""
|
||||
args = SimpleNamespace(
|
||||
host="127.0.0.1",
|
||||
port=30000,
|
||||
worker_urls=[],
|
||||
policy="cache_aware",
|
||||
prefill=[
|
||||
["http://prefill1:8000", "9000"],
|
||||
["http://prefill2:8000", "none"],
|
||||
],
|
||||
decode=[["http://decode1:8001"], ["http://decode2:8001"]],
|
||||
router_prefill=[
|
||||
["http://prefill1:8000", "9000"],
|
||||
["http://prefill2:8000", "none"],
|
||||
],
|
||||
router_decode=[["http://decode1:8001"], ["http://decode2:8001"]],
|
||||
router_policy="cache_aware",
|
||||
router_pd_disaggregation=True,
|
||||
router_prefill_policy="power_of_two",
|
||||
router_decode_policy="round_robin",
|
||||
# Include all required fields with defaults
|
||||
router_worker_startup_timeout_secs=600,
|
||||
router_worker_startup_check_interval=30,
|
||||
router_cache_threshold=0.3,
|
||||
router_balance_abs_threshold=64,
|
||||
router_balance_rel_threshold=1.5,
|
||||
router_eviction_interval=120,
|
||||
router_max_tree_size=2**26,
|
||||
router_max_payload_size=512 * 1024 * 1024,
|
||||
router_dp_aware=False,
|
||||
router_api_key=None,
|
||||
router_log_dir=None,
|
||||
router_log_level=None,
|
||||
router_service_discovery=False,
|
||||
router_selector=None,
|
||||
router_service_discovery_port=80,
|
||||
router_service_discovery_namespace=None,
|
||||
router_prefill_selector=None,
|
||||
router_decode_selector=None,
|
||||
router_prometheus_port=None,
|
||||
router_prometheus_host=None,
|
||||
router_request_id_headers=None,
|
||||
router_request_timeout_secs=1800,
|
||||
router_max_concurrent_requests=256,
|
||||
router_queue_size=100,
|
||||
router_queue_timeout_secs=60,
|
||||
router_rate_limit_tokens_per_second=None,
|
||||
router_cors_allowed_origins=[],
|
||||
router_retry_max_retries=5,
|
||||
router_retry_initial_backoff_ms=50,
|
||||
router_retry_max_backoff_ms=30000,
|
||||
router_retry_backoff_multiplier=1.5,
|
||||
router_retry_jitter_factor=0.2,
|
||||
router_cb_failure_threshold=10,
|
||||
router_cb_success_threshold=3,
|
||||
router_cb_timeout_duration_secs=60,
|
||||
router_cb_window_duration_secs=120,
|
||||
router_disable_retries=False,
|
||||
router_disable_circuit_breaker=False,
|
||||
router_health_failure_threshold=3,
|
||||
router_health_success_threshold=2,
|
||||
router_health_check_timeout_secs=5,
|
||||
router_health_check_interval_secs=60,
|
||||
router_health_check_endpoint="/health",
|
||||
)
|
||||
|
||||
router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
|
||||
|
||||
# Test PD configuration
|
||||
assert router_args.pd_disaggregation is True
|
||||
assert router_args.prefill_urls == [
|
||||
("http://prefill1:8000", 9000),
|
||||
("http://prefill2:8000", None),
|
||||
]
|
||||
assert router_args.decode_urls == ["http://decode1:8001", "http://decode2:8001"]
|
||||
assert router_args.prefill_policy == "power_of_two"
|
||||
assert router_args.decode_policy == "round_robin"
|
||||
assert router_args.policy == "cache_aware" # Main policy still set
|
||||
|
||||
def test_from_cli_args_without_prefix(self):
|
||||
"""Test creating RouterArgs from CLI arguments without router prefix."""
|
||||
args = SimpleNamespace(
|
||||
host="127.0.0.1",
|
||||
port=30000,
|
||||
worker_urls=["http://worker1:8000"],
|
||||
policy="random",
|
||||
prefill=None,
|
||||
decode=None,
|
||||
pd_disaggregation=False,
|
||||
prefill_policy=None,
|
||||
decode_policy=None,
|
||||
worker_startup_timeout_secs=600,
|
||||
worker_startup_check_interval=30,
|
||||
cache_threshold=0.3,
|
||||
balance_abs_threshold=64,
|
||||
balance_rel_threshold=1.5,
|
||||
eviction_interval=120,
|
||||
max_tree_size=2**26,
|
||||
max_payload_size=512 * 1024 * 1024,
|
||||
dp_aware=False,
|
||||
api_key=None,
|
||||
log_dir=None,
|
||||
log_level=None,
|
||||
service_discovery=False,
|
||||
selector=None,
|
||||
service_discovery_port=80,
|
||||
service_discovery_namespace=None,
|
||||
prefill_selector=None,
|
||||
decode_selector=None,
|
||||
prometheus_port=None,
|
||||
prometheus_host=None,
|
||||
request_id_headers=None,
|
||||
request_timeout_secs=1800,
|
||||
max_concurrent_requests=256,
|
||||
queue_size=100,
|
||||
queue_timeout_secs=60,
|
||||
rate_limit_tokens_per_second=None,
|
||||
cors_allowed_origins=[],
|
||||
retry_max_retries=5,
|
||||
retry_initial_backoff_ms=50,
|
||||
retry_max_backoff_ms=30000,
|
||||
retry_backoff_multiplier=1.5,
|
||||
retry_jitter_factor=0.2,
|
||||
cb_failure_threshold=10,
|
||||
cb_success_threshold=3,
|
||||
cb_timeout_duration_secs=60,
|
||||
cb_window_duration_secs=120,
|
||||
disable_retries=False,
|
||||
disable_circuit_breaker=False,
|
||||
health_failure_threshold=3,
|
||||
health_success_threshold=2,
|
||||
health_check_timeout_secs=5,
|
||||
health_check_interval_secs=60,
|
||||
health_check_endpoint="/health",
|
||||
model_path=None,
|
||||
tokenizer_path=None,
|
||||
)
|
||||
|
||||
router_args = RouterArgs.from_cli_args(args, use_router_prefix=False)
|
||||
|
||||
assert router_args.host == "127.0.0.1"
|
||||
assert router_args.port == 30000
|
||||
assert router_args.worker_urls == ["http://worker1:8000"]
|
||||
assert router_args.policy == "random"
|
||||
assert router_args.pd_disaggregation is False
|
||||
|
||||
|
||||
class TestPolicyFromStr:
|
||||
"""Test policy string to enum conversion."""
|
||||
|
||||
def test_valid_policies(self):
|
||||
"""Test conversion of valid policy strings."""
|
||||
from sglang_router.sglang_router_rs import PolicyType
|
||||
|
||||
assert policy_from_str("random") == PolicyType.Random
|
||||
assert policy_from_str("round_robin") == PolicyType.RoundRobin
|
||||
assert policy_from_str("cache_aware") == PolicyType.CacheAware
|
||||
assert policy_from_str("power_of_two") == PolicyType.PowerOfTwo
|
||||
|
||||
def test_invalid_policy(self):
|
||||
"""Test conversion of invalid policy string."""
|
||||
with pytest.raises(KeyError):
|
||||
policy_from_str("invalid_policy")
|
||||
|
||||
|
||||
class TestParseRouterArgs:
|
||||
"""Test the parse_router_args function."""
|
||||
|
||||
def test_parse_basic_args(self):
|
||||
"""Test parsing basic router arguments."""
|
||||
args = [
|
||||
"--host",
|
||||
"0.0.0.0",
|
||||
"--port",
|
||||
"30001",
|
||||
"--worker-urls",
|
||||
"http://worker1:8000",
|
||||
"http://worker2:8000",
|
||||
"--policy",
|
||||
"round_robin",
|
||||
]
|
||||
|
||||
router_args = parse_router_args(args)
|
||||
|
||||
assert router_args.host == "0.0.0.0"
|
||||
assert router_args.port == 30001
|
||||
assert router_args.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
|
||||
assert router_args.policy == "round_robin"
|
||||
|
||||
def test_parse_pd_args(self):
|
||||
"""Test parsing PD disaggregated mode arguments."""
|
||||
args = [
|
||||
"--pd-disaggregation",
|
||||
"--prefill",
|
||||
"http://prefill1:8000",
|
||||
"9000",
|
||||
"--prefill",
|
||||
"http://prefill2:8000",
|
||||
"none",
|
||||
"--decode",
|
||||
"http://decode1:8001",
|
||||
"--decode",
|
||||
"http://decode2:8001",
|
||||
"--prefill-policy",
|
||||
"power_of_two",
|
||||
"--decode-policy",
|
||||
"round_robin",
|
||||
]
|
||||
|
||||
router_args = parse_router_args(args)
|
||||
|
||||
assert router_args.pd_disaggregation is True
|
||||
assert router_args.prefill_urls == [
|
||||
("http://prefill1:8000", 9000),
|
||||
("http://prefill2:8000", None),
|
||||
]
|
||||
assert router_args.decode_urls == ["http://decode1:8001", "http://decode2:8001"]
|
||||
assert router_args.prefill_policy == "power_of_two"
|
||||
assert router_args.decode_policy == "round_robin"
|
||||
|
||||
def test_parse_service_discovery_args(self):
|
||||
"""Test parsing service discovery arguments."""
|
||||
args_a = [
|
||||
"--service-discovery",
|
||||
"--selector",
|
||||
"app=worker",
|
||||
"env=prod",
|
||||
"--service-discovery-port",
|
||||
"8080",
|
||||
"--service-discovery-namespace",
|
||||
"default",
|
||||
]
|
||||
args_b = [
|
||||
"--service-discovery",
|
||||
"--selector",
|
||||
# OME has this style
|
||||
"app=worker env=prod",
|
||||
"--service-discovery-port",
|
||||
"8080",
|
||||
"--service-discovery-namespace",
|
||||
"default",
|
||||
]
|
||||
|
||||
for args in [args_a, args_b]:
|
||||
router_args = parse_router_args(args)
|
||||
|
||||
assert router_args.service_discovery is True
|
||||
assert router_args.selector == {"app": "worker", "env": "prod"}
|
||||
assert router_args.service_discovery_port == 8080
|
||||
assert router_args.service_discovery_namespace == "default"
|
||||
|
||||
def test_parse_retry_and_circuit_breaker_args(self):
|
||||
"""Test parsing retry and circuit breaker arguments."""
|
||||
args = [
|
||||
"--retry-max-retries",
|
||||
"3",
|
||||
"--retry-initial-backoff-ms",
|
||||
"100",
|
||||
"--retry-max-backoff-ms",
|
||||
"10000",
|
||||
"--retry-backoff-multiplier",
|
||||
"2.0",
|
||||
"--retry-jitter-factor",
|
||||
"0.1",
|
||||
"--disable-retries",
|
||||
"--cb-failure-threshold",
|
||||
"5",
|
||||
"--cb-success-threshold",
|
||||
"2",
|
||||
"--cb-timeout-duration-secs",
|
||||
"30",
|
||||
"--cb-window-duration-secs",
|
||||
"60",
|
||||
"--disable-circuit-breaker",
|
||||
]
|
||||
|
||||
router_args = parse_router_args(args)
|
||||
|
||||
# Test retry configuration
|
||||
assert router_args.retry_max_retries == 3
|
||||
assert router_args.retry_initial_backoff_ms == 100
|
||||
assert router_args.retry_max_backoff_ms == 10000
|
||||
assert router_args.retry_backoff_multiplier == 2.0
|
||||
assert router_args.retry_jitter_factor == 0.1
|
||||
assert router_args.disable_retries is True
|
||||
|
||||
# Test circuit breaker configuration
|
||||
assert router_args.cb_failure_threshold == 5
|
||||
assert router_args.cb_success_threshold == 2
|
||||
assert router_args.cb_timeout_duration_secs == 30
|
||||
assert router_args.cb_window_duration_secs == 60
|
||||
assert router_args.disable_circuit_breaker is True
|
||||
|
||||
def test_parse_rate_limiting_args(self):
|
||||
"""Test parsing rate limiting arguments."""
|
||||
args = [
|
||||
"--max-concurrent-requests",
|
||||
"512",
|
||||
"--queue-size",
|
||||
"200",
|
||||
"--queue-timeout-secs",
|
||||
"120",
|
||||
"--rate-limit-tokens-per-second",
|
||||
"100",
|
||||
]
|
||||
|
||||
router_args = parse_router_args(args)
|
||||
|
||||
assert router_args.max_concurrent_requests == 512
|
||||
assert router_args.queue_size == 200
|
||||
assert router_args.queue_timeout_secs == 120
|
||||
assert router_args.rate_limit_tokens_per_second == 100
|
||||
|
||||
def test_parse_health_check_args(self):
|
||||
"""Test parsing health check arguments."""
|
||||
args = [
|
||||
"--health-failure-threshold",
|
||||
"2",
|
||||
"--health-success-threshold",
|
||||
"1",
|
||||
"--health-check-timeout-secs",
|
||||
"3",
|
||||
"--health-check-interval-secs",
|
||||
"30",
|
||||
"--health-check-endpoint",
|
||||
"/healthz",
|
||||
]
|
||||
|
||||
router_args = parse_router_args(args)
|
||||
|
||||
assert router_args.health_failure_threshold == 2
|
||||
assert router_args.health_success_threshold == 1
|
||||
assert router_args.health_check_timeout_secs == 3
|
||||
assert router_args.health_check_interval_secs == 30
|
||||
assert router_args.health_check_endpoint == "/healthz"
|
||||
|
||||
def test_parse_cors_args(self):
|
||||
"""Test parsing CORS arguments."""
|
||||
args = [
|
||||
"--cors-allowed-origins",
|
||||
"http://localhost:3000",
|
||||
"https://example.com",
|
||||
]
|
||||
|
||||
router_args = parse_router_args(args)
|
||||
|
||||
assert router_args.cors_allowed_origins == [
|
||||
"http://localhost:3000",
|
||||
"https://example.com",
|
||||
]
|
||||
|
||||
def test_parse_tokenizer_args(self):
|
||||
"""Test parsing tokenizer arguments."""
|
||||
# Note: model-path and tokenizer-path arguments are not available in current implementation
|
||||
# This test is skipped until those arguments are added
|
||||
pytest.skip("Tokenizer arguments not available in current implementation")
|
||||
|
||||
def test_parse_invalid_args(self):
|
||||
"""Test parsing invalid arguments."""
|
||||
# Test invalid policy
|
||||
with pytest.raises(SystemExit):
|
||||
parse_router_args(["--policy", "invalid_policy"])
|
||||
|
||||
# Test invalid bootstrap port
|
||||
with pytest.raises(ValueError, match="Invalid bootstrap port"):
|
||||
parse_router_args(
|
||||
[
|
||||
"--pd-disaggregation",
|
||||
"--prefill",
|
||||
"http://prefill1:8000",
|
||||
"invalid_port",
|
||||
]
|
||||
)
|
||||
|
||||
def test_help_output(self):
|
||||
"""Test that help output is generated correctly."""
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
parse_router_args(["--help"])
|
||||
|
||||
# SystemExit with code 0 indicates help was displayed
|
||||
assert exc_info.value.code == 0
|
||||
423
third_party/sglang/sgl-model-gateway/bindings/python/tests/test_router_config.py
vendored
Normal file
423
third_party/sglang/sgl-model-gateway/bindings/python/tests/test_router_config.py
vendored
Normal file
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
Unit tests for router configuration validation and setup.
|
||||
|
||||
These tests focus on testing the router configuration logic in isolation,
|
||||
including validation of configuration parameters and their interactions.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sglang_router.launch_router import RouterArgs, launch_router
|
||||
from sglang_router.router import policy_from_str
|
||||
from sglang_router.sglang_router_rs import PolicyType
|
||||
|
||||
|
||||
class TestRouterConfigValidation:
|
||||
"""Test router configuration validation logic."""
|
||||
|
||||
def test_valid_basic_config(self):
|
||||
"""Test that a valid basic configuration passes validation."""
|
||||
args = RouterArgs(
|
||||
host="127.0.0.1",
|
||||
port=30000,
|
||||
worker_urls=["http://worker1:8000", "http://worker2:8000"],
|
||||
policy="cache_aware",
|
||||
)
|
||||
|
||||
# Should not raise any exceptions
|
||||
assert args.host == "127.0.0.1"
|
||||
assert args.port == 30000
|
||||
assert args.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
|
||||
assert args.policy == "cache_aware"
|
||||
|
||||
def test_valid_pd_config(self):
|
||||
"""Test that a valid PD configuration passes validation."""
|
||||
args = RouterArgs(
|
||||
host="127.0.0.1",
|
||||
port=30000,
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[
|
||||
("http://prefill1:8000", 9000),
|
||||
("http://prefill2:8000", None),
|
||||
],
|
||||
decode_urls=["http://decode1:8001", "http://decode2:8001"],
|
||||
policy="cache_aware",
|
||||
)
|
||||
|
||||
assert args.pd_disaggregation is True
|
||||
assert args.prefill_urls == [
|
||||
("http://prefill1:8000", 9000),
|
||||
("http://prefill2:8000", None),
|
||||
]
|
||||
assert args.decode_urls == ["http://decode1:8001", "http://decode2:8001"]
|
||||
assert args.policy == "cache_aware"
|
||||
|
||||
def test_pd_config_without_urls_allowed(self):
|
||||
"""Test that PD mode without URLs is now allowed (URLs are optional)."""
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[],
|
||||
decode_urls=[],
|
||||
service_discovery=False,
|
||||
)
|
||||
|
||||
# Should not raise validation error - URLs are now optional
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
# This should succeed without raising an error
|
||||
launch_router(args)
|
||||
router_mod.from_args.assert_called_once()
|
||||
|
||||
def test_pd_config_with_service_discovery_allows_empty_urls(self):
|
||||
"""Test that PD mode with service discovery allows empty URLs."""
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[],
|
||||
decode_urls=[],
|
||||
service_discovery=True,
|
||||
)
|
||||
|
||||
# Should not raise validation error when service discovery is enabled
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
launch_router(args)
|
||||
|
||||
# Should create router instance via from_args
|
||||
router_mod.from_args.assert_called_once()
|
||||
|
||||
def test_regular_mode_without_workers_allows_empty_urls(self):
|
||||
"""Test that regular mode allows empty worker URLs."""
|
||||
args = RouterArgs(worker_urls=[], service_discovery=False)
|
||||
|
||||
# Should not raise validation error
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
launch_router(args)
|
||||
|
||||
# Should create router instance via from_args
|
||||
router_mod.from_args.assert_called_once()
|
||||
|
||||
def test_cache_threshold_validation(self):
|
||||
"""Test cache threshold validation."""
|
||||
# Valid cache threshold
|
||||
args = RouterArgs(cache_threshold=0.5)
|
||||
assert args.cache_threshold == 0.5
|
||||
|
||||
# Edge cases
|
||||
args = RouterArgs(cache_threshold=0.0)
|
||||
assert args.cache_threshold == 0.0
|
||||
|
||||
args = RouterArgs(cache_threshold=1.0)
|
||||
assert args.cache_threshold == 1.0
|
||||
|
||||
def test_balance_threshold_validation(self):
|
||||
"""Test load balancing threshold validation."""
|
||||
# Valid thresholds
|
||||
args = RouterArgs(balance_abs_threshold=64, balance_rel_threshold=1.5)
|
||||
assert args.balance_abs_threshold == 64
|
||||
assert args.balance_rel_threshold == 1.5
|
||||
|
||||
# Edge cases
|
||||
args = RouterArgs(balance_abs_threshold=0, balance_rel_threshold=1.0)
|
||||
assert args.balance_abs_threshold == 0
|
||||
assert args.balance_rel_threshold == 1.0
|
||||
|
||||
def test_timeout_validation(self):
|
||||
"""Test timeout parameter validation."""
|
||||
# Valid timeouts
|
||||
args = RouterArgs(
|
||||
worker_startup_timeout_secs=600,
|
||||
worker_startup_check_interval=30,
|
||||
request_timeout_secs=1800,
|
||||
queue_timeout_secs=60,
|
||||
)
|
||||
assert args.worker_startup_timeout_secs == 600
|
||||
assert args.worker_startup_check_interval == 30
|
||||
assert args.request_timeout_secs == 1800
|
||||
assert args.queue_timeout_secs == 60
|
||||
|
||||
def test_retry_config_validation(self):
|
||||
"""Test retry configuration validation."""
|
||||
# Valid retry config
|
||||
args = RouterArgs(
|
||||
retry_max_retries=5,
|
||||
retry_initial_backoff_ms=50,
|
||||
retry_max_backoff_ms=30000,
|
||||
retry_backoff_multiplier=1.5,
|
||||
retry_jitter_factor=0.2,
|
||||
disable_retries=False,
|
||||
)
|
||||
assert args.retry_max_retries == 5
|
||||
assert args.retry_initial_backoff_ms == 50
|
||||
assert args.retry_max_backoff_ms == 30000
|
||||
assert args.retry_backoff_multiplier == 1.5
|
||||
assert args.retry_jitter_factor == 0.2
|
||||
assert args.disable_retries is False
|
||||
|
||||
def test_circuit_breaker_config_validation(self):
|
||||
"""Test circuit breaker configuration validation."""
|
||||
# Valid circuit breaker config
|
||||
args = RouterArgs(
|
||||
cb_failure_threshold=10,
|
||||
cb_success_threshold=3,
|
||||
cb_timeout_duration_secs=60,
|
||||
cb_window_duration_secs=120,
|
||||
disable_circuit_breaker=False,
|
||||
)
|
||||
assert args.cb_failure_threshold == 10
|
||||
assert args.cb_success_threshold == 3
|
||||
assert args.cb_timeout_duration_secs == 60
|
||||
assert args.cb_window_duration_secs == 120
|
||||
assert args.disable_circuit_breaker is False
|
||||
|
||||
def test_health_check_config_validation(self):
|
||||
"""Test health check configuration validation."""
|
||||
# Valid health check config
|
||||
args = RouterArgs(
|
||||
health_failure_threshold=3,
|
||||
health_success_threshold=2,
|
||||
health_check_timeout_secs=5,
|
||||
health_check_interval_secs=60,
|
||||
health_check_endpoint="/health",
|
||||
)
|
||||
assert args.health_failure_threshold == 3
|
||||
assert args.health_success_threshold == 2
|
||||
assert args.health_check_timeout_secs == 5
|
||||
assert args.health_check_interval_secs == 60
|
||||
assert args.health_check_endpoint == "/health"
|
||||
|
||||
def test_rate_limiting_config_validation(self):
|
||||
"""Test rate limiting configuration validation."""
|
||||
# Valid rate limiting config
|
||||
args = RouterArgs(
|
||||
max_concurrent_requests=256,
|
||||
queue_size=100,
|
||||
queue_timeout_secs=60,
|
||||
rate_limit_tokens_per_second=100,
|
||||
)
|
||||
assert args.max_concurrent_requests == 256
|
||||
assert args.queue_size == 100
|
||||
assert args.queue_timeout_secs == 60
|
||||
assert args.rate_limit_tokens_per_second == 100
|
||||
|
||||
def test_service_discovery_config_validation(self):
|
||||
"""Test service discovery configuration validation."""
|
||||
# Valid service discovery config
|
||||
args = RouterArgs(
|
||||
service_discovery=True,
|
||||
selector={"app": "worker", "env": "prod"},
|
||||
service_discovery_port=8080,
|
||||
service_discovery_namespace="default",
|
||||
)
|
||||
assert args.service_discovery is True
|
||||
assert args.selector == {"app": "worker", "env": "prod"}
|
||||
assert args.service_discovery_port == 8080
|
||||
assert args.service_discovery_namespace == "default"
|
||||
|
||||
def test_pd_service_discovery_config_validation(self):
|
||||
"""Test PD service discovery configuration validation."""
|
||||
# Valid PD service discovery config
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
service_discovery=True,
|
||||
prefill_selector={"app": "prefill"},
|
||||
decode_selector={"app": "decode"},
|
||||
bootstrap_port_annotation="sglang.ai/bootstrap-port",
|
||||
)
|
||||
assert args.pd_disaggregation is True
|
||||
assert args.service_discovery is True
|
||||
assert args.prefill_selector == {"app": "prefill"}
|
||||
assert args.decode_selector == {"app": "decode"}
|
||||
assert args.bootstrap_port_annotation == "sglang.ai/bootstrap-port"
|
||||
|
||||
def test_prometheus_config_validation(self):
|
||||
"""Test Prometheus configuration validation."""
|
||||
# Valid Prometheus config
|
||||
args = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1")
|
||||
assert args.prometheus_port == 29000
|
||||
assert args.prometheus_host == "127.0.0.1"
|
||||
|
||||
def test_cors_config_validation(self):
|
||||
"""Test CORS configuration validation."""
|
||||
# Valid CORS config
|
||||
args = RouterArgs(
|
||||
cors_allowed_origins=["http://localhost:3000", "https://example.com"]
|
||||
)
|
||||
assert args.cors_allowed_origins == [
|
||||
"http://localhost:3000",
|
||||
"https://example.com",
|
||||
]
|
||||
|
||||
def test_tokenizer_config_validation(self):
|
||||
"""Test tokenizer configuration validation."""
|
||||
# Note: model_path and tokenizer_path are not available in current RouterArgs
|
||||
pytest.skip("Tokenizer configuration not available in current implementation")
|
||||
|
||||
def test_dp_aware_config_validation(self):
|
||||
"""Test data parallelism aware configuration validation."""
|
||||
# Valid DP aware config
|
||||
args = RouterArgs(dp_aware=True, api_key="test-api-key")
|
||||
assert args.dp_aware is True
|
||||
assert args.api_key == "test-api-key"
|
||||
|
||||
def test_request_id_headers_validation(self):
|
||||
"""Test request ID headers configuration validation."""
|
||||
# Valid request ID headers config
|
||||
args = RouterArgs(
|
||||
request_id_headers=["x-request-id", "x-trace-id", "x-correlation-id"]
|
||||
)
|
||||
assert args.request_id_headers == [
|
||||
"x-request-id",
|
||||
"x-trace-id",
|
||||
"x-correlation-id",
|
||||
]
|
||||
|
||||
def test_policy_consistency_validation(self):
|
||||
"""Test policy consistency validation in PD mode."""
|
||||
# Test with both prefill and decode policies specified
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[("http://prefill1:8000", None)],
|
||||
decode_urls=["http://decode1:8001"],
|
||||
policy="cache_aware",
|
||||
prefill_policy="power_of_two",
|
||||
decode_policy="round_robin",
|
||||
)
|
||||
|
||||
# Should not raise validation error
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
launch_router(args)
|
||||
|
||||
# Should create router instance via from_args
|
||||
router_mod.from_args.assert_called_once()
|
||||
|
||||
def test_policy_fallback_validation(self):
|
||||
"""Test policy fallback validation in PD mode."""
|
||||
# Test with only prefill policy specified
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[("http://prefill1:8000", None)],
|
||||
decode_urls=["http://decode1:8001"],
|
||||
policy="cache_aware",
|
||||
prefill_policy="power_of_two",
|
||||
decode_policy=None,
|
||||
)
|
||||
|
||||
# Should not raise validation error
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
launch_router(args)
|
||||
|
||||
# Should create router instance via from_args
|
||||
router_mod.from_args.assert_called_once()
|
||||
|
||||
def test_policy_enum_conversion(self):
|
||||
"""Test policy string to enum conversion."""
|
||||
# Test all valid policy conversions
|
||||
assert policy_from_str("random") == PolicyType.Random
|
||||
assert policy_from_str("round_robin") == PolicyType.RoundRobin
|
||||
assert policy_from_str("cache_aware") == PolicyType.CacheAware
|
||||
assert policy_from_str("power_of_two") == PolicyType.PowerOfTwo
|
||||
|
||||
def test_invalid_policy_enum_conversion(self):
|
||||
"""Test invalid policy string to enum conversion."""
|
||||
with pytest.raises(KeyError):
|
||||
policy_from_str("invalid_policy")
|
||||
|
||||
def test_config_immutability(self):
|
||||
"""Test that configuration objects are properly immutable."""
|
||||
args = RouterArgs(
|
||||
host="127.0.0.1", port=30000, worker_urls=["http://worker1:8000"]
|
||||
)
|
||||
|
||||
# Test that we can't modify the configuration after creation
|
||||
# (This is more of a design test - dataclasses are mutable by default)
|
||||
original_host = args.host
|
||||
args.host = "0.0.0.0"
|
||||
assert args.host == "0.0.0.0" # Dataclasses are mutable
|
||||
assert args.host != original_host
|
||||
|
||||
def test_config_defaults_consistency(self):
|
||||
"""Test that configuration defaults are consistent."""
|
||||
args1 = RouterArgs()
|
||||
args2 = RouterArgs()
|
||||
|
||||
# Both instances should have the same defaults
|
||||
assert args1.host == args2.host
|
||||
assert args1.port == args2.port
|
||||
assert args1.policy == args2.policy
|
||||
assert args1.worker_urls == args2.worker_urls
|
||||
assert args1.pd_disaggregation == args2.pd_disaggregation
|
||||
|
||||
def test_config_serialization(self):
|
||||
"""Test that configuration can be serialized/deserialized."""
|
||||
args = RouterArgs(
|
||||
host="127.0.0.1",
|
||||
port=30000,
|
||||
worker_urls=["http://worker1:8000"],
|
||||
policy="cache_aware",
|
||||
cache_threshold=0.5,
|
||||
)
|
||||
|
||||
# Test that we can access all attributes
|
||||
assert hasattr(args, "host")
|
||||
assert hasattr(args, "port")
|
||||
assert hasattr(args, "worker_urls")
|
||||
assert hasattr(args, "policy")
|
||||
assert hasattr(args, "cache_threshold")
|
||||
|
||||
def test_config_with_none_values(self):
|
||||
"""Test configuration with None values."""
|
||||
args = RouterArgs(
|
||||
api_key=None,
|
||||
log_dir=None,
|
||||
log_level=None,
|
||||
prometheus_port=None,
|
||||
prometheus_host=None,
|
||||
request_id_headers=None,
|
||||
rate_limit_tokens_per_second=None,
|
||||
service_discovery_namespace=None,
|
||||
)
|
||||
|
||||
# All None values should be preserved
|
||||
assert args.api_key is None
|
||||
assert args.log_dir is None
|
||||
assert args.log_level is None
|
||||
assert args.prometheus_port is None
|
||||
assert args.prometheus_host is None
|
||||
assert args.request_id_headers is None
|
||||
assert args.rate_limit_tokens_per_second is None
|
||||
assert args.service_discovery_namespace is None
|
||||
|
||||
def test_config_with_empty_lists(self):
|
||||
"""Test configuration with empty lists."""
|
||||
args = RouterArgs(
|
||||
worker_urls=[], prefill_urls=[], decode_urls=[], cors_allowed_origins=[]
|
||||
)
|
||||
|
||||
# All empty lists should be preserved
|
||||
assert args.worker_urls == []
|
||||
assert args.prefill_urls == []
|
||||
assert args.decode_urls == []
|
||||
assert args.cors_allowed_origins == []
|
||||
|
||||
def test_config_with_empty_dicts(self):
|
||||
"""Test configuration with empty dictionaries."""
|
||||
args = RouterArgs(selector={}, prefill_selector={}, decode_selector={})
|
||||
|
||||
# All empty dictionaries should be preserved
|
||||
assert args.selector == {}
|
||||
assert args.prefill_selector == {}
|
||||
assert args.decode_selector == {}
|
||||
1056
third_party/sglang/sgl-model-gateway/bindings/python/tests/test_startup_sequence.py
vendored
Normal file
1056
third_party/sglang/sgl-model-gateway/bindings/python/tests/test_startup_sequence.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
509
third_party/sglang/sgl-model-gateway/bindings/python/tests/test_validation.py
vendored
Normal file
509
third_party/sglang/sgl-model-gateway/bindings/python/tests/test_validation.py
vendored
Normal file
@@ -0,0 +1,509 @@
|
||||
"""
|
||||
Unit tests for validation logic in sglang_router.
|
||||
|
||||
These tests focus on testing the validation logic in isolation,
|
||||
including parameter validation, URL validation, and configuration validation.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sglang_router.launch_router import RouterArgs, launch_router
|
||||
|
||||
|
||||
class TestURLValidation:
|
||||
"""Test URL validation logic."""
|
||||
|
||||
def test_valid_worker_urls(self):
|
||||
"""Test validation of valid worker URLs."""
|
||||
valid_urls = [
|
||||
"http://worker1:8000",
|
||||
"https://worker2:8000",
|
||||
"http://localhost:8000",
|
||||
"http://127.0.0.1:8000",
|
||||
"http://192.168.1.100:8000",
|
||||
"http://worker.example.com:8000",
|
||||
]
|
||||
|
||||
for url in valid_urls:
|
||||
args = RouterArgs(worker_urls=[url])
|
||||
# Should not raise any validation errors
|
||||
assert url in args.worker_urls
|
||||
|
||||
def test_valid_prefill_urls(self):
|
||||
"""Test validation of valid prefill URLs."""
|
||||
valid_prefill_urls = [
|
||||
("http://prefill1:8000", 9000),
|
||||
("https://prefill2:8000", None),
|
||||
("http://localhost:8000", 9000),
|
||||
("http://127.0.0.1:8000", None),
|
||||
]
|
||||
|
||||
for url, bootstrap_port in valid_prefill_urls:
|
||||
args = RouterArgs(prefill_urls=[(url, bootstrap_port)])
|
||||
# Should not raise any validation errors
|
||||
assert (url, bootstrap_port) in args.prefill_urls
|
||||
|
||||
def test_valid_decode_urls(self):
|
||||
"""Test validation of valid decode URLs."""
|
||||
valid_decode_urls = [
|
||||
"http://decode1:8001",
|
||||
"https://decode2:8001",
|
||||
"http://localhost:8001",
|
||||
"http://127.0.0.1:8001",
|
||||
]
|
||||
|
||||
for url in valid_decode_urls:
|
||||
args = RouterArgs(decode_urls=[url])
|
||||
# Should not raise any validation errors
|
||||
assert url in args.decode_urls
|
||||
|
||||
def test_malformed_urls(self):
|
||||
"""Test handling of malformed URLs."""
|
||||
# Note: The current implementation doesn't validate URL format
|
||||
# This test documents the current behavior
|
||||
malformed_urls = [
|
||||
"not-a-url",
|
||||
"ftp://worker1:8000", # Wrong protocol
|
||||
"http://", # Missing host
|
||||
":8000", # Missing protocol and host
|
||||
"http://worker1", # Missing port
|
||||
]
|
||||
|
||||
for url in malformed_urls:
|
||||
args = RouterArgs(worker_urls=[url])
|
||||
# Currently, malformed URLs are accepted
|
||||
# This might be something to improve in the future
|
||||
assert url in args.worker_urls
|
||||
|
||||
|
||||
class TestPortValidation:
|
||||
"""Test port validation logic."""
|
||||
|
||||
def test_valid_ports(self):
|
||||
"""Test validation of valid port numbers."""
|
||||
valid_ports = [1, 80, 8000, 30000, 65535]
|
||||
|
||||
for port in valid_ports:
|
||||
args = RouterArgs(port=port)
|
||||
assert args.port == port
|
||||
|
||||
def test_invalid_ports(self):
|
||||
"""Test handling of invalid port numbers."""
|
||||
# Note: The current implementation doesn't validate port ranges
|
||||
# This test documents the current behavior
|
||||
invalid_ports = [0, -1, 65536, 70000]
|
||||
|
||||
for port in invalid_ports:
|
||||
args = RouterArgs(port=port)
|
||||
# Currently, invalid ports are accepted
|
||||
# This might be something to improve in the future
|
||||
assert args.port == port
|
||||
|
||||
def test_bootstrap_port_validation(self):
|
||||
"""Test validation of bootstrap ports in PD mode."""
|
||||
valid_bootstrap_ports = [1, 80, 9000, 30000, 65535, None]
|
||||
|
||||
for bootstrap_port in valid_bootstrap_ports:
|
||||
args = RouterArgs(prefill_urls=[("http://prefill1:8000", bootstrap_port)])
|
||||
assert args.prefill_urls[0][1] == bootstrap_port
|
||||
|
||||
|
||||
class TestParameterValidation:
|
||||
"""Test parameter validation logic."""
|
||||
|
||||
def test_cache_threshold_validation(self):
|
||||
"""Test cache threshold parameter validation."""
|
||||
# Valid cache thresholds
|
||||
valid_thresholds = [0.0, 0.1, 0.5, 0.9, 1.0]
|
||||
|
||||
for threshold in valid_thresholds:
|
||||
args = RouterArgs(cache_threshold=threshold)
|
||||
assert args.cache_threshold == threshold
|
||||
|
||||
def test_balance_threshold_validation(self):
|
||||
"""Test load balancing threshold parameter validation."""
|
||||
# Valid absolute thresholds
|
||||
valid_abs_thresholds = [0, 1, 32, 64, 128, 1000]
|
||||
for threshold in valid_abs_thresholds:
|
||||
args = RouterArgs(balance_abs_threshold=threshold)
|
||||
assert args.balance_abs_threshold == threshold
|
||||
|
||||
# Valid relative thresholds
|
||||
valid_rel_thresholds = [1.0, 1.1, 1.5, 2.0, 10.0]
|
||||
for threshold in valid_rel_thresholds:
|
||||
args = RouterArgs(balance_rel_threshold=threshold)
|
||||
assert args.balance_rel_threshold == threshold
|
||||
|
||||
def test_timeout_validation(self):
|
||||
"""Test timeout parameter validation."""
|
||||
# Valid timeouts
|
||||
valid_timeouts = [1, 30, 60, 300, 600, 1800, 3600]
|
||||
|
||||
for timeout in valid_timeouts:
|
||||
args = RouterArgs(
|
||||
worker_startup_timeout_secs=timeout,
|
||||
worker_startup_check_interval=timeout,
|
||||
request_timeout_secs=timeout,
|
||||
queue_timeout_secs=timeout,
|
||||
)
|
||||
assert args.worker_startup_timeout_secs == timeout
|
||||
assert args.worker_startup_check_interval == timeout
|
||||
assert args.request_timeout_secs == timeout
|
||||
assert args.queue_timeout_secs == timeout
|
||||
|
||||
def test_retry_parameter_validation(self):
|
||||
"""Test retry parameter validation."""
|
||||
# Valid retry parameters
|
||||
valid_retry_counts = [0, 1, 3, 5, 10]
|
||||
for count in valid_retry_counts:
|
||||
args = RouterArgs(retry_max_retries=count)
|
||||
assert args.retry_max_retries == count
|
||||
|
||||
# Valid backoff parameters
|
||||
valid_backoff_ms = [1, 50, 100, 1000, 30000]
|
||||
for backoff in valid_backoff_ms:
|
||||
args = RouterArgs(
|
||||
retry_initial_backoff_ms=backoff, retry_max_backoff_ms=backoff
|
||||
)
|
||||
assert args.retry_initial_backoff_ms == backoff
|
||||
assert args.retry_max_backoff_ms == backoff
|
||||
|
||||
# Valid multiplier parameters
|
||||
valid_multipliers = [1.0, 1.5, 2.0, 3.0]
|
||||
for multiplier in valid_multipliers:
|
||||
args = RouterArgs(retry_backoff_multiplier=multiplier)
|
||||
assert args.retry_backoff_multiplier == multiplier
|
||||
|
||||
# Valid jitter parameters
|
||||
valid_jitter = [0.0, 0.1, 0.2, 0.5]
|
||||
for jitter in valid_jitter:
|
||||
args = RouterArgs(retry_jitter_factor=jitter)
|
||||
assert args.retry_jitter_factor == jitter
|
||||
|
||||
def test_circuit_breaker_parameter_validation(self):
|
||||
"""Test circuit breaker parameter validation."""
|
||||
# Valid failure thresholds
|
||||
valid_failure_thresholds = [1, 3, 5, 10, 20]
|
||||
for threshold in valid_failure_thresholds:
|
||||
args = RouterArgs(cb_failure_threshold=threshold)
|
||||
assert args.cb_failure_threshold == threshold
|
||||
|
||||
# Valid success thresholds
|
||||
valid_success_thresholds = [1, 2, 3, 5]
|
||||
for threshold in valid_success_thresholds:
|
||||
args = RouterArgs(cb_success_threshold=threshold)
|
||||
assert args.cb_success_threshold == threshold
|
||||
|
||||
# Valid timeout durations
|
||||
valid_timeouts = [10, 30, 60, 120, 300]
|
||||
for timeout in valid_timeouts:
|
||||
args = RouterArgs(
|
||||
cb_timeout_duration_secs=timeout, cb_window_duration_secs=timeout
|
||||
)
|
||||
assert args.cb_timeout_duration_secs == timeout
|
||||
assert args.cb_window_duration_secs == timeout
|
||||
|
||||
def test_health_check_parameter_validation(self):
|
||||
"""Test health check parameter validation."""
|
||||
# Valid failure thresholds
|
||||
valid_failure_thresholds = [1, 2, 3, 5, 10]
|
||||
for threshold in valid_failure_thresholds:
|
||||
args = RouterArgs(health_failure_threshold=threshold)
|
||||
assert args.health_failure_threshold == threshold
|
||||
|
||||
# Valid success thresholds
|
||||
valid_success_thresholds = [1, 2, 3, 5]
|
||||
for threshold in valid_success_thresholds:
|
||||
args = RouterArgs(health_success_threshold=threshold)
|
||||
assert args.health_success_threshold == threshold
|
||||
|
||||
# Valid timeouts and intervals
|
||||
valid_times = [1, 5, 10, 30, 60, 120]
|
||||
for time_val in valid_times:
|
||||
args = RouterArgs(
|
||||
health_check_timeout_secs=time_val, health_check_interval_secs=time_val
|
||||
)
|
||||
assert args.health_check_timeout_secs == time_val
|
||||
assert args.health_check_interval_secs == time_val
|
||||
|
||||
def test_rate_limiting_parameter_validation(self):
|
||||
"""Test rate limiting parameter validation."""
|
||||
# Valid concurrent request limits
|
||||
valid_limits = [1, 10, 64, 256, 512, 1000]
|
||||
for limit in valid_limits:
|
||||
args = RouterArgs(max_concurrent_requests=limit)
|
||||
assert args.max_concurrent_requests == limit
|
||||
|
||||
# Valid queue sizes
|
||||
valid_queue_sizes = [0, 10, 50, 100, 500, 1000]
|
||||
for size in valid_queue_sizes:
|
||||
args = RouterArgs(queue_size=size)
|
||||
assert args.queue_size == size
|
||||
|
||||
# Valid token rates
|
||||
valid_rates = [1, 10, 50, 100, 500, 1000]
|
||||
for rate in valid_rates:
|
||||
args = RouterArgs(rate_limit_tokens_per_second=rate)
|
||||
assert args.rate_limit_tokens_per_second == rate
|
||||
|
||||
def test_tree_size_validation(self):
|
||||
"""Test tree size parameter validation."""
|
||||
# Valid tree sizes (powers of 2)
|
||||
valid_sizes = [2**10, 2**20, 2**24, 2**26, 2**28, 2**30]
|
||||
|
||||
for size in valid_sizes:
|
||||
args = RouterArgs(max_tree_size=size)
|
||||
assert args.max_tree_size == size
|
||||
|
||||
def test_payload_size_validation(self):
|
||||
"""Test payload size parameter validation."""
|
||||
# Valid payload sizes
|
||||
valid_sizes = [
|
||||
1024, # 1KB
|
||||
1024 * 1024, # 1MB
|
||||
10 * 1024 * 1024, # 10MB
|
||||
100 * 1024 * 1024, # 100MB
|
||||
512 * 1024 * 1024, # 512MB
|
||||
1024 * 1024 * 1024, # 1GB
|
||||
]
|
||||
|
||||
for size in valid_sizes:
|
||||
args = RouterArgs(max_payload_size=size)
|
||||
assert args.max_payload_size == size
|
||||
|
||||
|
||||
class TestConfigurationValidation:
|
||||
"""Test configuration validation logic."""
|
||||
|
||||
def test_pd_mode_validation(self):
|
||||
"""Test PD mode configuration validation."""
|
||||
# Valid PD configuration
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[("http://prefill1:8000", 9000)],
|
||||
decode_urls=["http://decode1:8001"],
|
||||
)
|
||||
|
||||
assert args.pd_disaggregation is True
|
||||
assert len(args.prefill_urls) > 0
|
||||
assert len(args.decode_urls) > 0
|
||||
|
||||
def test_service_discovery_validation(self):
|
||||
"""Test service discovery configuration validation."""
|
||||
# Valid service discovery configuration
|
||||
args = RouterArgs(
|
||||
service_discovery=True,
|
||||
selector={"app": "worker", "env": "prod"},
|
||||
service_discovery_port=8080,
|
||||
service_discovery_namespace="default",
|
||||
)
|
||||
|
||||
assert args.service_discovery is True
|
||||
assert args.selector == {"app": "worker", "env": "prod"}
|
||||
assert args.service_discovery_port == 8080
|
||||
assert args.service_discovery_namespace == "default"
|
||||
|
||||
def test_pd_service_discovery_validation(self):
|
||||
"""Test PD service discovery configuration validation."""
|
||||
# Valid PD service discovery configuration
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
service_discovery=True,
|
||||
prefill_selector={"app": "prefill"},
|
||||
decode_selector={"app": "decode"},
|
||||
)
|
||||
|
||||
assert args.pd_disaggregation is True
|
||||
assert args.service_discovery is True
|
||||
assert args.prefill_selector == {"app": "prefill"}
|
||||
assert args.decode_selector == {"app": "decode"}
|
||||
|
||||
def test_policy_validation(self):
|
||||
"""Test policy configuration validation."""
|
||||
# Valid policies
|
||||
valid_policies = ["random", "round_robin", "cache_aware", "power_of_two"]
|
||||
|
||||
for policy in valid_policies:
|
||||
args = RouterArgs(policy=policy)
|
||||
assert args.policy == policy
|
||||
|
||||
def test_pd_policy_validation(self):
|
||||
"""Test PD policy configuration validation."""
|
||||
# Valid PD policies
|
||||
valid_policies = ["random", "round_robin", "cache_aware", "power_of_two"]
|
||||
|
||||
for prefill_policy in valid_policies:
|
||||
for decode_policy in valid_policies:
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[("http://prefill1:8000", None)],
|
||||
decode_urls=["http://decode1:8001"],
|
||||
prefill_policy=prefill_policy,
|
||||
decode_policy=decode_policy,
|
||||
)
|
||||
assert args.prefill_policy == prefill_policy
|
||||
assert args.decode_policy == decode_policy
|
||||
|
||||
def test_cors_validation(self):
|
||||
"""Test CORS configuration validation."""
|
||||
# Valid CORS origins
|
||||
valid_origins = [
|
||||
[],
|
||||
["http://localhost:3000"],
|
||||
["https://example.com"],
|
||||
["http://localhost:3000", "https://example.com"],
|
||||
["*"], # Wildcard (if supported)
|
||||
]
|
||||
|
||||
for origins in valid_origins:
|
||||
args = RouterArgs(cors_allowed_origins=origins)
|
||||
assert args.cors_allowed_origins == origins
|
||||
|
||||
def test_logging_validation(self):
|
||||
"""Test logging configuration validation."""
|
||||
# Valid log levels
|
||||
valid_log_levels = ["debug", "info", "warning", "error", "critical"]
|
||||
|
||||
for level in valid_log_levels:
|
||||
args = RouterArgs(log_level=level)
|
||||
assert args.log_level == level
|
||||
|
||||
def test_prometheus_validation(self):
|
||||
"""Test Prometheus configuration validation."""
|
||||
# Valid Prometheus configuration
|
||||
args = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1")
|
||||
|
||||
assert args.prometheus_port == 29000
|
||||
assert args.prometheus_host == "127.0.0.1"
|
||||
|
||||
def test_tokenizer_validation(self):
|
||||
"""Test tokenizer configuration validation."""
|
||||
# Note: model_path and tokenizer_path are not available in current RouterArgs
|
||||
pytest.skip("Tokenizer configuration not available in current implementation")
|
||||
|
||||
def test_request_id_headers_validation(self):
|
||||
"""Test request ID headers configuration validation."""
|
||||
# Valid request ID headers
|
||||
valid_headers = [
|
||||
["x-request-id"],
|
||||
["x-request-id", "x-trace-id"],
|
||||
["x-request-id", "x-trace-id", "x-correlation-id"],
|
||||
["custom-header"],
|
||||
]
|
||||
|
||||
for headers in valid_headers:
|
||||
args = RouterArgs(request_id_headers=headers)
|
||||
assert args.request_id_headers == headers
|
||||
|
||||
|
||||
class TestLaunchValidation:
|
||||
"""Test launch-time validation logic."""
|
||||
|
||||
def test_pd_mode_allows_empty_urls(self):
|
||||
"""Test that PD mode now allows empty URLs (URLs are optional)."""
|
||||
# PD mode without URLs is now allowed
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[],
|
||||
decode_urls=[],
|
||||
service_discovery=False,
|
||||
)
|
||||
|
||||
# Should not raise validation error - URLs are now optional
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
# This should succeed without raising an error
|
||||
launch_router(args)
|
||||
router_mod.from_args.assert_called_once()
|
||||
|
||||
def test_pd_mode_with_service_discovery_allows_empty_urls(self):
|
||||
"""Test that PD mode with service discovery allows empty URLs."""
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[],
|
||||
decode_urls=[],
|
||||
service_discovery=True,
|
||||
)
|
||||
|
||||
# Should not raise validation error
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
launch_router(args)
|
||||
|
||||
# Should create router instance via from_args
|
||||
router_mod.from_args.assert_called_once()
|
||||
|
||||
def test_regular_mode_allows_empty_worker_urls(self):
|
||||
"""Test that regular mode allows empty worker URLs."""
|
||||
args = RouterArgs(worker_urls=[], service_discovery=False)
|
||||
|
||||
# Should not raise validation error
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
launch_router(args)
|
||||
|
||||
# Should create router instance via from_args
|
||||
router_mod.from_args.assert_called_once()
|
||||
|
||||
def test_launch_with_valid_config(self):
|
||||
"""Test launching with valid configuration."""
|
||||
args = RouterArgs(
|
||||
host="127.0.0.1",
|
||||
port=30000,
|
||||
worker_urls=["http://worker1:8000"],
|
||||
policy="cache_aware",
|
||||
)
|
||||
|
||||
# Should not raise validation error
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
launch_router(args)
|
||||
|
||||
# Should create router instance via from_args
|
||||
router_mod.from_args.assert_called_once()
|
||||
|
||||
def test_launch_with_pd_config(self):
|
||||
"""Test launching with valid PD configuration."""
|
||||
args = RouterArgs(
|
||||
pd_disaggregation=True,
|
||||
prefill_urls=[("http://prefill1:8000", 9000)],
|
||||
decode_urls=["http://decode1:8001"],
|
||||
policy="cache_aware",
|
||||
)
|
||||
|
||||
# Should not raise validation error
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
launch_router(args)
|
||||
|
||||
# Should create router instance via from_args
|
||||
router_mod.from_args.assert_called_once()
|
||||
|
||||
def test_launch_with_service_discovery_config(self):
|
||||
"""Test launching with valid service discovery configuration."""
|
||||
args = RouterArgs(
|
||||
service_discovery=True,
|
||||
selector={"app": "worker"},
|
||||
service_discovery_port=8080,
|
||||
)
|
||||
|
||||
# Should not raise validation error
|
||||
with patch("sglang_router.launch_router.Router") as router_mod:
|
||||
mock_router_instance = MagicMock()
|
||||
router_mod.from_args = MagicMock(return_value=mock_router_instance)
|
||||
|
||||
launch_router(args)
|
||||
|
||||
# Should create router instance via from_args
|
||||
router_mod.from_args.assert_called_once()
|
||||
108
third_party/sglang/sgl-model-gateway/build.rs
vendored
Normal file
108
third_party/sglang/sgl-model-gateway/build.rs
vendored
Normal file
@@ -0,0 +1,108 @@
|
||||
use std::process::Command;
|
||||
|
||||
const DEFAULT_VERSION: &str = "0.0.0";
|
||||
const DEFAULT_PROJECT_NAME: &str = "sgl-model-gateway";
|
||||
|
||||
/// Set a compile-time environment variable with the SGL_MODEL_GATEWAY_ prefix
|
||||
macro_rules! set_env {
|
||||
($name:expr, $value:expr) => {
|
||||
println!("cargo:rustc-env=SGL_MODEL_GATEWAY_{}={}", $name, $value);
|
||||
};
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Rebuild triggers
|
||||
println!("cargo:rerun-if-changed=Cargo.toml");
|
||||
|
||||
// Set version info environment variables
|
||||
let version = read_cargo_version().unwrap_or_else(|_| DEFAULT_VERSION.to_string());
|
||||
let target = std::env::var("TARGET").unwrap_or_else(|_| get_rustc_host().unwrap_or_default());
|
||||
let profile = std::env::var("PROFILE").unwrap_or_default();
|
||||
|
||||
set_env!("PROJECT_NAME", DEFAULT_PROJECT_NAME);
|
||||
set_env!("VERSION", version);
|
||||
set_env!(
|
||||
"BUILD_TIME",
|
||||
chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC")
|
||||
);
|
||||
set_env!(
|
||||
"BUILD_MODE",
|
||||
if profile == "release" {
|
||||
"release"
|
||||
} else {
|
||||
"debug"
|
||||
}
|
||||
);
|
||||
set_env!("TARGET_TRIPLE", target);
|
||||
set_env!(
|
||||
"GIT_BRANCH",
|
||||
git_branch().unwrap_or_else(|| "unknown".into())
|
||||
);
|
||||
set_env!(
|
||||
"GIT_COMMIT",
|
||||
git_commit().unwrap_or_else(|| "unknown".into())
|
||||
);
|
||||
set_env!(
|
||||
"GIT_STATUS",
|
||||
git_status().unwrap_or_else(|| "unknown".into())
|
||||
);
|
||||
set_env!(
|
||||
"RUSTC_VERSION",
|
||||
rustc_version().unwrap_or_else(|| "unknown".into())
|
||||
);
|
||||
set_env!(
|
||||
"CARGO_VERSION",
|
||||
cargo_version().unwrap_or_else(|| "unknown".into())
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_cargo_version() -> Result<String, Box<dyn std::error::Error>> {
|
||||
let content = std::fs::read_to_string("Cargo.toml")?;
|
||||
let toml: toml::Value = toml::from_str(&content)?;
|
||||
toml.get("package")
|
||||
.and_then(|p| p.get("version"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from)
|
||||
.ok_or_else(|| "Missing version in Cargo.toml".into())
|
||||
}
|
||||
|
||||
fn run_cmd(cmd: &str, args: &[&str]) -> Option<String> {
|
||||
Command::new(cmd)
|
||||
.args(args)
|
||||
.output()
|
||||
.ok()
|
||||
.filter(|o| o.status.success())
|
||||
.and_then(|o| String::from_utf8(o.stdout).ok())
|
||||
.map(|s| s.trim().to_string())
|
||||
}
|
||||
|
||||
fn git_branch() -> Option<String> {
|
||||
run_cmd("git", &["rev-parse", "--abbrev-ref", "HEAD"])
|
||||
}
|
||||
|
||||
fn git_commit() -> Option<String> {
|
||||
run_cmd("git", &["rev-parse", "--short", "HEAD"])
|
||||
}
|
||||
|
||||
fn git_status() -> Option<String> {
|
||||
run_cmd("git", &["status", "--porcelain"])
|
||||
.map(|s| if s.is_empty() { "clean" } else { "dirty" }.into())
|
||||
}
|
||||
|
||||
fn rustc_version() -> Option<String> {
|
||||
run_cmd("rustc", &["--version"])
|
||||
}
|
||||
|
||||
fn cargo_version() -> Option<String> {
|
||||
run_cmd("cargo", &["--version"])
|
||||
}
|
||||
|
||||
fn get_rustc_host() -> Option<String> {
|
||||
run_cmd("rustc", &["-vV"])?
|
||||
.lines()
|
||||
.find(|l| l.starts_with("host: "))
|
||||
.and_then(|l| l.strip_prefix("host: "))
|
||||
.map(|s| s.trim().to_string())
|
||||
}
|
||||
1
third_party/sglang/sgl-model-gateway/e2e_test/__init__.py
vendored
Normal file
1
third_party/sglang/sgl-model-gateway/e2e_test/__init__.py
vendored
Normal file
@@ -0,0 +1 @@
|
||||
"""Test package root for router Python tests."""
|
||||
0
third_party/sglang/sgl-model-gateway/e2e_test/benchmarks/__init__.py
vendored
Normal file
0
third_party/sglang/sgl-model-gateway/e2e_test/benchmarks/__init__.py
vendored
Normal file
222
third_party/sglang/sgl-model-gateway/e2e_test/benchmarks/conftest.py
vendored
Normal file
222
third_party/sglang/sgl-model-gateway/e2e_test/benchmarks/conftest.py
vendored
Normal file
@@ -0,0 +1,222 @@
|
||||
"""Benchmark-specific fixtures."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from infra import GPUMonitor, should_monitor_gpu, terminate_process
|
||||
|
||||
from .results import BenchmarkResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_command(
|
||||
cli: str,
|
||||
router_url: str,
|
||||
model_path: str,
|
||||
experiment_folder: str,
|
||||
num_concurrency: int,
|
||||
traffic_scenario: str,
|
||||
max_requests: int,
|
||||
) -> list[str]:
|
||||
"""Build genai-bench command."""
|
||||
return [
|
||||
cli,
|
||||
"benchmark",
|
||||
"--api-backend",
|
||||
"openai",
|
||||
"--api-base",
|
||||
router_url,
|
||||
"--api-key",
|
||||
"dummy-token",
|
||||
"--api-model-name",
|
||||
model_path,
|
||||
"--model-tokenizer",
|
||||
model_path,
|
||||
"--task",
|
||||
"text-to-text",
|
||||
"--num-concurrency",
|
||||
str(num_concurrency),
|
||||
"--traffic-scenario",
|
||||
traffic_scenario,
|
||||
"--max-requests-per-run",
|
||||
str(max_requests),
|
||||
"--max-time-per-run",
|
||||
"3",
|
||||
"--experiment-folder-name",
|
||||
experiment_folder,
|
||||
"--experiment-base-dir",
|
||||
str(Path.cwd()),
|
||||
]
|
||||
|
||||
|
||||
def _find_results(experiment_folder: str, timeout: int = 10) -> list[Path]:
|
||||
"""Find benchmark result JSON files."""
|
||||
base = Path.cwd()
|
||||
folder = base / experiment_folder
|
||||
|
||||
if not folder.is_dir():
|
||||
# Search for folder
|
||||
for p in base.rglob(experiment_folder):
|
||||
if p.is_dir() and p.name == experiment_folder:
|
||||
folder = p
|
||||
break
|
||||
|
||||
if not folder.is_dir():
|
||||
raise AssertionError(f"Experiment folder not found: {experiment_folder}")
|
||||
|
||||
# Wait for JSON results
|
||||
for _ in range(timeout):
|
||||
files = [
|
||||
p
|
||||
for p in folder.rglob("*.json")
|
||||
if "experiment_metadata" not in p.name and "gpu_utilization" not in p.name
|
||||
]
|
||||
if files:
|
||||
return files
|
||||
time.sleep(1)
|
||||
|
||||
raise AssertionError(f"No JSON results found in {folder}")
|
||||
|
||||
|
||||
def _cleanup_procs(procs: list, drain_delay: int) -> None:
|
||||
"""Terminate processes gracefully."""
|
||||
if not procs:
|
||||
return
|
||||
if drain_delay > 0:
|
||||
time.sleep(drain_delay)
|
||||
for p in procs:
|
||||
try:
|
||||
proc = getattr(p, "proc", p) if hasattr(p, "proc") else p
|
||||
if isinstance(proc, subprocess.Popen):
|
||||
terminate_process(proc)
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def genai_bench_runner():
|
||||
"""Run genai-bench and validate metrics.
|
||||
|
||||
Usage:
|
||||
def test_perf(setup_backend, genai_bench_runner):
|
||||
backend, model_path, client, gateway = setup_backend
|
||||
genai_bench_runner(
|
||||
router_url=gateway.base_url,
|
||||
model_path=model_path,
|
||||
experiment_folder="benchmark_results",
|
||||
thresholds={"ttft_mean_max": 5, "gpu_util_p50_min": 99},
|
||||
)
|
||||
"""
|
||||
|
||||
def _run(
|
||||
*,
|
||||
router_url: str,
|
||||
model_path: str,
|
||||
experiment_folder: str,
|
||||
thresholds: dict | None = None,
|
||||
timeout_sec: int | None = None,
|
||||
num_concurrency: int = 32,
|
||||
traffic_scenario: str = "D(4000,100)",
|
||||
max_requests_per_run: int | None = None,
|
||||
kill_procs: list | None = None,
|
||||
drain_delay_sec: int = 6,
|
||||
) -> None:
|
||||
cli = shutil.which("genai-bench")
|
||||
if not cli:
|
||||
pytest.fail("genai-bench CLI not found")
|
||||
|
||||
# Clean previous results
|
||||
exp_dir = Path.cwd() / experiment_folder
|
||||
if exp_dir.exists():
|
||||
shutil.rmtree(exp_dir, ignore_errors=True)
|
||||
|
||||
# Build and run command
|
||||
max_requests = max_requests_per_run or num_concurrency * 5
|
||||
cmd = _build_command(
|
||||
cli,
|
||||
router_url,
|
||||
model_path,
|
||||
experiment_folder,
|
||||
num_concurrency,
|
||||
traffic_scenario,
|
||||
max_requests,
|
||||
)
|
||||
timeout = timeout_sec or int(os.environ.get("GENAI_BENCH_TEST_TIMEOUT", "120"))
|
||||
|
||||
try:
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
env=os.environ.copy(),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
pytest.fail(f"genai-bench executable not found at {cli}")
|
||||
except PermissionError:
|
||||
pytest.fail(f"Permission denied executing {cli}")
|
||||
except OSError as e:
|
||||
pytest.fail(f"Failed to start genai-bench: {e}")
|
||||
|
||||
# Start GPU monitor if needed
|
||||
gpu_monitor: GPUMonitor | None = None
|
||||
if should_monitor_gpu(thresholds):
|
||||
interval = float(os.environ.get("GPU_UTIL_SAMPLE_INTERVAL", "2.0"))
|
||||
gpu_monitor = GPUMonitor(output_dir=exp_dir, interval=interval)
|
||||
gpu_monitor.start(target_pid=proc.pid)
|
||||
|
||||
try:
|
||||
stdout, stderr = proc.communicate(timeout=timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
stdout, stderr = proc.communicate()
|
||||
logger.error("genai-bench timed out after %ds", timeout)
|
||||
|
||||
# Log output if process failed or for debugging
|
||||
if proc.returncode != 0:
|
||||
logger.error(
|
||||
"genai-bench exited with code %d\nstdout:\n%s\nstderr:\n%s",
|
||||
proc.returncode,
|
||||
stdout or "(empty)",
|
||||
stderr or "(empty)",
|
||||
)
|
||||
|
||||
try:
|
||||
# Parse and validate results
|
||||
for path in _find_results(experiment_folder):
|
||||
result = BenchmarkResult.from_json(path)
|
||||
result.log(experiment_folder, logger)
|
||||
if thresholds:
|
||||
result.validate(thresholds)
|
||||
|
||||
# Validate GPU utilization
|
||||
if gpu_monitor:
|
||||
gpu_monitor.stop()
|
||||
gpu_monitor.log_summary()
|
||||
gpu_monitor.assert_thresholds(thresholds)
|
||||
|
||||
except AssertionError:
|
||||
# Log genai-bench output when results not found
|
||||
logger.error(
|
||||
"genai-bench output (returncode=%d):\nstdout:\n%s\nstderr:\n%s",
|
||||
proc.returncode,
|
||||
stdout or "(empty)",
|
||||
stderr or "(empty)",
|
||||
)
|
||||
raise
|
||||
|
||||
finally:
|
||||
_cleanup_procs(kill_procs, drain_delay_sec)
|
||||
if gpu_monitor:
|
||||
gpu_monitor.stop(timeout=2)
|
||||
|
||||
return _run
|
||||
98
third_party/sglang/sgl-model-gateway/e2e_test/benchmarks/results.py
vendored
Normal file
98
third_party/sglang/sgl-model-gateway/e2e_test/benchmarks/results.py
vendored
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Benchmark result dataclasses for parsing genai-bench and GPU monitor output."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkResult:
|
||||
"""Parsed benchmark metrics from genai-bench output."""
|
||||
|
||||
ttft_mean: float
|
||||
e2e_latency_mean: float
|
||||
input_throughput_mean: float
|
||||
output_throughput_mean: float
|
||||
file_name: str
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, path: Path) -> "BenchmarkResult":
|
||||
"""Parse benchmark results from JSON file."""
|
||||
with path.open() as f:
|
||||
data = json.load(f)
|
||||
stats = data.get("aggregated_metrics", {}).get("stats", {})
|
||||
return cls(
|
||||
ttft_mean=float(stats.get("ttft", {}).get("mean", float("inf"))),
|
||||
e2e_latency_mean=float(
|
||||
stats.get("e2e_latency", {}).get("mean", float("inf"))
|
||||
),
|
||||
input_throughput_mean=float(
|
||||
stats.get("input_throughput", {}).get("mean", 0.0)
|
||||
),
|
||||
output_throughput_mean=float(
|
||||
stats.get("output_throughput", {}).get("mean", 0.0)
|
||||
),
|
||||
file_name=path.name,
|
||||
)
|
||||
|
||||
def log(self, experiment: str, logger) -> None:
|
||||
"""Log benchmark results."""
|
||||
logger.info(
|
||||
"genai-bench[%s] %s ttft=%.3fs e2e=%.3fs input=%.1f tok/s output=%.1f tok/s",
|
||||
experiment,
|
||||
self.file_name,
|
||||
self.ttft_mean,
|
||||
self.e2e_latency_mean,
|
||||
self.input_throughput_mean,
|
||||
self.output_throughput_mean,
|
||||
)
|
||||
|
||||
def validate(self, thresholds: dict) -> None:
|
||||
"""Validate metrics against thresholds."""
|
||||
checks = [
|
||||
("ttft_mean_max", self.ttft_mean, "<=", "TTFT"),
|
||||
("e2e_latency_mean_max", self.e2e_latency_mean, "<=", "E2E latency"),
|
||||
(
|
||||
"input_throughput_mean_min",
|
||||
self.input_throughput_mean,
|
||||
">=",
|
||||
"Input throughput",
|
||||
),
|
||||
(
|
||||
"output_throughput_mean_min",
|
||||
self.output_throughput_mean,
|
||||
">=",
|
||||
"Output throughput",
|
||||
),
|
||||
]
|
||||
for key, value, op, name in checks:
|
||||
if key not in thresholds:
|
||||
continue
|
||||
threshold = thresholds[key]
|
||||
if op == "<=" and value > threshold:
|
||||
raise AssertionError(f"{name}: {value:.2f} > {threshold}")
|
||||
if op == ">=" and value < threshold:
|
||||
raise AssertionError(f"{name}: {value:.2f} < {threshold}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPUUtilization:
|
||||
"""Parsed GPU utilization metrics from gpu_monitor output."""
|
||||
|
||||
overall_mean: float
|
||||
per_gpu: dict[str, dict[str, float]]
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, path: Path) -> "GPUUtilization | None":
|
||||
"""Parse GPU utilization from JSON file."""
|
||||
try:
|
||||
with path.open() as f:
|
||||
data = json.load(f)
|
||||
return cls(
|
||||
overall_mean=float(data.get("overall", {}).get("mean", 0)),
|
||||
per_gpu=data.get("per_gpu", {}),
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
119
third_party/sglang/sgl-model-gateway/e2e_test/benchmarks/summarize.py
vendored
Normal file
119
third_party/sglang/sgl-model-gateway/e2e_test/benchmarks/summarize.py
vendored
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Generate benchmark summary for GitHub Actions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from results import BenchmarkResult, GPUUtilization
|
||||
|
||||
|
||||
def discover_benchmarks(base_dir: Path) -> list[tuple[Path, str]]:
|
||||
"""Auto-discover benchmark folders and their result JSON files.
|
||||
|
||||
Returns list of (json_path, label) tuples sorted by folder name.
|
||||
"""
|
||||
results = []
|
||||
for folder in base_dir.rglob("benchmark_*"):
|
||||
if not folder.is_dir():
|
||||
continue
|
||||
# Find result JSON (exclude metadata and gpu files)
|
||||
for json_file in folder.glob("*.json"):
|
||||
if (
|
||||
"experiment_metadata" not in json_file.name
|
||||
and "gpu_utilization" not in json_file.name
|
||||
):
|
||||
# Generate label from folder name: benchmark_cache_aware_pd_grpc -> cache_aware pd grpc
|
||||
label = folder.name.replace("benchmark_", "").replace("_", " ")
|
||||
results.append((json_file, label))
|
||||
break # One JSON per folder
|
||||
return sorted(results, key=lambda x: x[0].parent.name)
|
||||
|
||||
|
||||
def find_gpu_utilization(result_path: Path) -> Path | None:
|
||||
"""Find GPU utilization JSON in same folder as result."""
|
||||
gpu_json = result_path.parent / "gpu_utilization.json"
|
||||
return gpu_json if gpu_json.exists() else None
|
||||
|
||||
|
||||
def generate_summary(base_dir: Path) -> str:
|
||||
"""Generate markdown summary."""
|
||||
benchmarks = discover_benchmarks(base_dir)
|
||||
|
||||
if not benchmarks:
|
||||
return (
|
||||
"## Gateway E2E Genai-Bench Results Summary\n\nNo benchmark results found."
|
||||
)
|
||||
|
||||
lines = [
|
||||
"## Gateway E2E Genai-Bench Results Summary",
|
||||
"",
|
||||
"| Scenario | Status | TTFT (s) | E2E Latency (s) | Input Throughput (tok/s) | Output Throughput (tok/s) |",
|
||||
"|----------|--------|----------|-----------------|--------------------------|---------------------------|",
|
||||
]
|
||||
|
||||
gpu_sections = []
|
||||
|
||||
for result_path, label in benchmarks:
|
||||
try:
|
||||
result = BenchmarkResult.from_json(result_path)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to parse {result_path}: {e}", file=sys.stderr)
|
||||
lines.append(f"| {label} | ❌ Failed | - | - | - | - |")
|
||||
continue
|
||||
|
||||
lines.append(
|
||||
f"| {label} | ✅ Success | "
|
||||
f"{result.ttft_mean:.2f} | "
|
||||
f"{result.e2e_latency_mean:.2f} | "
|
||||
f"{result.input_throughput_mean:.0f} | "
|
||||
f"{result.output_throughput_mean:.0f} |"
|
||||
)
|
||||
|
||||
# GPU utilization
|
||||
gpu_path = find_gpu_utilization(result_path)
|
||||
if gpu_path:
|
||||
gpu = GPUUtilization.from_json(gpu_path)
|
||||
if gpu and gpu.per_gpu:
|
||||
gpu_lines = [
|
||||
f"### GPU Utilization — {label}",
|
||||
"",
|
||||
f"Overall mean: {gpu.overall_mean:.2f}%",
|
||||
"",
|
||||
"| GPU | Mean (%) | p5 | p10 | p25 | p50 | p75 | p90 | p95 |",
|
||||
"|-----|----------|----|-----|-----|-----|-----|-----|-----|",
|
||||
]
|
||||
for gpu_id, stats in sorted(
|
||||
gpu.per_gpu.items(), key=lambda x: int(x[0])
|
||||
):
|
||||
gpu_lines.append(
|
||||
f"| {gpu_id} | {stats.get('mean', 0):.2f} | "
|
||||
f"{stats.get('p5', 0):.2f} | {stats.get('p10', 0):.2f} | "
|
||||
f"{stats.get('p25', 0):.2f} | {stats.get('p50', 0):.2f} | "
|
||||
f"{stats.get('p75', 0):.2f} | {stats.get('p90', 0):.2f} | "
|
||||
f"{stats.get('p95', 0):.2f} |"
|
||||
)
|
||||
gpu_sections.append("\n".join(gpu_lines))
|
||||
|
||||
return "\n".join(lines) + "\n" + "\n\n".join(gpu_sections)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main entry point."""
|
||||
base_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path.cwd()
|
||||
summary = generate_summary(base_dir)
|
||||
|
||||
# Write to GITHUB_STEP_SUMMARY if available
|
||||
summary_file = os.environ.get("GITHUB_STEP_SUMMARY")
|
||||
if summary_file:
|
||||
with open(summary_file, "a") as f:
|
||||
f.write(summary)
|
||||
f.write("\n")
|
||||
print(f"Summary written to {summary_file}")
|
||||
else:
|
||||
print(summary)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
26
third_party/sglang/sgl-model-gateway/e2e_test/benchmarks/test_pd_perf.py
vendored
Normal file
26
third_party/sglang/sgl-model-gateway/e2e_test/benchmarks/test_pd_perf.py
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
"""PD (prefill/decode disaggregation) router performance benchmark test."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.e2e
|
||||
@pytest.mark.workers(prefill=2, decode=2)
|
||||
@pytest.mark.parametrize("setup_backend", ["pd"], indirect=True)
|
||||
class TestPDPerf:
|
||||
"""Performance benchmark for PD disaggregation router."""
|
||||
|
||||
def test_pd_perf(self, setup_backend, genai_bench_runner):
|
||||
"""Run genai-bench against PD router and validate metrics."""
|
||||
backend, model_path, client, gateway = setup_backend
|
||||
genai_bench_runner(
|
||||
router_url=gateway.base_url,
|
||||
model_path=model_path,
|
||||
experiment_folder="benchmark_round_robin_pd",
|
||||
thresholds={
|
||||
"ttft_mean_max": 13,
|
||||
"e2e_latency_mean_max": 16,
|
||||
"input_throughput_mean_min": 350,
|
||||
"output_throughput_mean_min": 18,
|
||||
"gpu_util_p50_min": 99,
|
||||
},
|
||||
)
|
||||
27
third_party/sglang/sgl-model-gateway/e2e_test/benchmarks/test_regular_perf.py
vendored
Normal file
27
third_party/sglang/sgl-model-gateway/e2e_test/benchmarks/test_regular_perf.py
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Regular router performance benchmark test."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.e2e
|
||||
@pytest.mark.workers(count=4)
|
||||
@pytest.mark.gateway(policy="cache_aware")
|
||||
@pytest.mark.parametrize("setup_backend", ["http", "grpc"], indirect=True)
|
||||
class TestRegularPerf:
|
||||
"""Performance benchmark for regular (non-PD) router."""
|
||||
|
||||
def test_regular_perf(self, setup_backend, genai_bench_runner):
|
||||
"""Run genai-bench against regular router and validate metrics."""
|
||||
backend, model_path, client, gateway = setup_backend
|
||||
genai_bench_runner(
|
||||
router_url=gateway.base_url,
|
||||
model_path=model_path,
|
||||
experiment_folder=f"benchmark_cache_aware_regular_{backend}",
|
||||
thresholds={
|
||||
"ttft_mean_max": 6,
|
||||
"e2e_latency_mean_max": 14,
|
||||
"input_throughput_mean_min": 800,
|
||||
"output_throughput_mean_min": 12,
|
||||
"gpu_util_p50_min": 99,
|
||||
},
|
||||
)
|
||||
0
third_party/sglang/sgl-model-gateway/e2e_test/chat_completions/__init__.py
vendored
Normal file
0
third_party/sglang/sgl-model-gateway/e2e_test/chat_completions/__init__.py
vendored
Normal file
168
third_party/sglang/sgl-model-gateway/e2e_test/chat_completions/test_enable_thinking.py
vendored
Normal file
168
third_party/sglang/sgl-model-gateway/e2e_test/chat_completions/test_enable_thinking.py
vendored
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Enable Thinking E2E Tests.
|
||||
|
||||
Tests for chat completions with enable_thinking feature (Qwen3 reasoning).
|
||||
|
||||
Source: Migrated from e2e_grpc/features/test_enable_thinking.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# API key is not validated by the gateway, but required for OpenAI-compatible headers
|
||||
API_KEY = "not-used"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Enable Thinking Tests (Qwen 30B)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.model("qwen-30b")
|
||||
@pytest.mark.gateway(
|
||||
extra_args=["--reasoning-parser", "qwen3", "--history-backend", "memory"]
|
||||
)
|
||||
@pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True)
|
||||
class TestEnableThinking:
|
||||
"""Tests for enable_thinking feature with Qwen3 reasoning parser."""
|
||||
|
||||
def test_chat_completion_with_reasoning(self, setup_backend):
|
||||
"""Test non-streaming with enable_thinking=True, reasoning_content should not be empty."""
|
||||
_, model, client, gateway = setup_backend
|
||||
|
||||
response = requests.post(
|
||||
f"{gateway.base_url}/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {API_KEY}"},
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"temperature": 0,
|
||||
"separate_reasoning": True,
|
||||
"chat_template_kwargs": {"enable_thinking": True},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200, f"Failed with: {response.text}"
|
||||
data = response.json()
|
||||
|
||||
assert "choices" in data
|
||||
assert len(data["choices"]) > 0
|
||||
assert "message" in data["choices"][0]
|
||||
assert "reasoning_content" in data["choices"][0]["message"]
|
||||
assert data["choices"][0]["message"]["reasoning_content"] is not None
|
||||
|
||||
def test_chat_completion_without_reasoning(self, setup_backend):
|
||||
"""Test non-streaming with enable_thinking=False, reasoning_content should be empty."""
|
||||
_, model, client, gateway = setup_backend
|
||||
|
||||
response = requests.post(
|
||||
f"{gateway.base_url}/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {API_KEY}"},
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"temperature": 0,
|
||||
"separate_reasoning": True,
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200, f"Failed with: {response.text}"
|
||||
data = response.json()
|
||||
|
||||
assert "choices" in data
|
||||
assert len(data["choices"]) > 0
|
||||
assert "message" in data["choices"][0]
|
||||
|
||||
if "reasoning_content" in data["choices"][0]["message"]:
|
||||
assert data["choices"][0]["message"]["reasoning_content"] is None
|
||||
|
||||
def test_stream_chat_completion_with_reasoning(self, setup_backend):
|
||||
"""Test streaming with enable_thinking=True, reasoning_content should not be empty."""
|
||||
_, model, client, gateway = setup_backend
|
||||
|
||||
response = requests.post(
|
||||
f"{gateway.base_url}/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {API_KEY}"},
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"temperature": 0,
|
||||
"separate_reasoning": True,
|
||||
"stream": True,
|
||||
"chat_template_kwargs": {"enable_thinking": True},
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
assert response.status_code == 200, f"Failed with: {response.text}"
|
||||
|
||||
has_reasoning = False
|
||||
has_content = False
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line = line.decode("utf-8")
|
||||
if line.startswith("data:") and not line.startswith("data: [DONE]"):
|
||||
data = json.loads(line[6:])
|
||||
if "choices" in data and len(data["choices"]) > 0:
|
||||
delta = data["choices"][0].get("delta", {})
|
||||
|
||||
if "reasoning_content" in delta and delta["reasoning_content"]:
|
||||
has_reasoning = True
|
||||
|
||||
if "content" in delta and delta["content"]:
|
||||
has_content = True
|
||||
|
||||
assert (
|
||||
has_reasoning
|
||||
), "The reasoning content is not included in the stream response"
|
||||
assert has_content, "The stream response does not contain normal content"
|
||||
|
||||
def test_stream_chat_completion_without_reasoning(self, setup_backend):
|
||||
"""Test streaming with enable_thinking=False, reasoning_content should be empty."""
|
||||
_, model, client, gateway = setup_backend
|
||||
|
||||
response = requests.post(
|
||||
f"{gateway.base_url}/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {API_KEY}"},
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"temperature": 0,
|
||||
"separate_reasoning": True,
|
||||
"stream": True,
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
assert response.status_code == 200, f"Failed with: {response.text}"
|
||||
|
||||
has_reasoning = False
|
||||
has_content = False
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line = line.decode("utf-8")
|
||||
if line.startswith("data:") and not line.startswith("data: [DONE]"):
|
||||
data = json.loads(line[6:])
|
||||
if "choices" in data and len(data["choices"]) > 0:
|
||||
delta = data["choices"][0].get("delta", {})
|
||||
|
||||
if "reasoning_content" in delta and delta["reasoning_content"]:
|
||||
has_reasoning = True
|
||||
|
||||
if "content" in delta and delta["content"]:
|
||||
has_content = True
|
||||
|
||||
assert (
|
||||
not has_reasoning
|
||||
), "The reasoning content should not be included in the stream response"
|
||||
assert has_content, "The stream response does not contain normal content"
|
||||
1529
third_party/sglang/sgl-model-gateway/e2e_test/chat_completions/test_function_calling.py
vendored
Normal file
1529
third_party/sglang/sgl-model-gateway/e2e_test/chat_completions/test_function_calling.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
316
third_party/sglang/sgl-model-gateway/e2e_test/chat_completions/test_openai_server.py
vendored
Normal file
316
third_party/sglang/sgl-model-gateway/e2e_test/chat_completions/test_openai_server.py
vendored
Normal file
@@ -0,0 +1,316 @@
|
||||
"""Chat Completions API E2E Tests - OpenAI Server Compatibility.
|
||||
|
||||
Tests for OpenAI-compatible chat completions API through the gateway.
|
||||
|
||||
Source: Migrated from e2e_grpc/basic/test_openai_server.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Chat Completion Tests (Llama 8B)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.model("llama-8b")
|
||||
@pytest.mark.gateway(extra_args=["--history-backend", "memory"])
|
||||
@pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True)
|
||||
class TestChatCompletion:
|
||||
"""Tests for OpenAI-compatible chat completions API."""
|
||||
|
||||
@pytest.mark.parametrize("logprobs", [None, 5])
|
||||
@pytest.mark.parametrize("parallel_sample_num", [1, 2])
|
||||
def test_chat_completion(self, setup_backend, logprobs, parallel_sample_num):
|
||||
"""Test non-streaming chat completion with logprobs and parallel sampling."""
|
||||
_, model, client, gateway = setup_backend
|
||||
self._run_chat_completion(client, model, logprobs, parallel_sample_num)
|
||||
|
||||
@pytest.mark.parametrize("logprobs", [None, 5])
|
||||
@pytest.mark.parametrize("parallel_sample_num", [1, 2])
|
||||
def test_chat_completion_stream(self, setup_backend, logprobs, parallel_sample_num):
|
||||
"""Test streaming chat completion with logprobs and parallel sampling."""
|
||||
_, model, client, gateway = setup_backend
|
||||
self._run_chat_completion_stream(client, model, logprobs, parallel_sample_num)
|
||||
|
||||
def test_regex(self, setup_backend):
|
||||
"""Test structured output with regex constraint."""
|
||||
_, model, client, gateway = setup_backend
|
||||
|
||||
regex = (
|
||||
r"""\{\n"""
|
||||
+ r""" "name": "[\w]+",\n"""
|
||||
+ r""" "population": [\d]+\n"""
|
||||
+ r"""\}"""
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "Introduce the capital of France."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
extra_body={"regex": regex},
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
|
||||
try:
|
||||
js_obj = json.loads(text)
|
||||
except (TypeError, json.decoder.JSONDecodeError):
|
||||
raise
|
||||
assert isinstance(js_obj["name"], str)
|
||||
assert isinstance(js_obj["population"], int)
|
||||
|
||||
def test_penalty(self, setup_backend):
|
||||
"""Test frequency penalty parameter."""
|
||||
_, model, client, gateway = setup_backend
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "Introduce the capital of France."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
frequency_penalty=1.0,
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
assert isinstance(text, str)
|
||||
|
||||
def test_response_prefill(self, setup_backend):
|
||||
"""Test assistant message prefill with continue_final_message."""
|
||||
_, model, client, gateway = setup_backend
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": """
|
||||
Extract the name, size, price, and color from this product description as a JSON object:
|
||||
|
||||
<description>
|
||||
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
|
||||
</description>
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "{\n",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
extra_body={"continue_final_message": True},
|
||||
)
|
||||
|
||||
assert (
|
||||
response.choices[0]
|
||||
.message.content.strip()
|
||||
.startswith('"name": "SmartHome Mini",')
|
||||
)
|
||||
|
||||
def test_model_list(self, setup_backend):
|
||||
"""Test listing available models."""
|
||||
_, model, client, gateway = setup_backend
|
||||
|
||||
models = list(client.models.list().data)
|
||||
assert len(models) == 1
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Skipping retrieve model test as it is not supported by the router"
|
||||
)
|
||||
def test_retrieve_model(self, setup_backend):
|
||||
"""Test retrieving a specific model."""
|
||||
import openai
|
||||
|
||||
_, model, client, gateway = setup_backend
|
||||
|
||||
retrieved_model = client.models.retrieve(model)
|
||||
assert retrieved_model.id == model
|
||||
assert retrieved_model.root == model
|
||||
|
||||
with pytest.raises(openai.NotFoundError):
|
||||
client.models.retrieve("non-existent-model")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Helper methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _run_chat_completion(self, client, model, logprobs, parallel_sample_num):
|
||||
"""Run a non-streaming chat completion and verify response."""
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the capital of France? Answer in a few words.",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
logprobs=logprobs is not None and logprobs > 0,
|
||||
top_logprobs=logprobs,
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
if logprobs:
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
|
||||
)
|
||||
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.content[0].top_logprobs
|
||||
)
|
||||
assert (
|
||||
ret_num_top_logprobs == logprobs
|
||||
), f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
|
||||
assert len(response.choices) == parallel_sample_num
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert isinstance(response.choices[0].message.content, str)
|
||||
assert response.id
|
||||
assert response.created
|
||||
assert response.usage.prompt_tokens > 0
|
||||
assert response.usage.completion_tokens > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def _run_chat_completion_stream(
|
||||
self, client, model, logprobs, parallel_sample_num=1
|
||||
):
|
||||
"""Run a streaming chat completion and verify response chunks."""
|
||||
generator = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "What is the capital of France?"},
|
||||
],
|
||||
temperature=0,
|
||||
logprobs=logprobs is not None and logprobs > 0,
|
||||
top_logprobs=logprobs,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
is_firsts = {}
|
||||
is_finished = {}
|
||||
finish_reason_counts = {}
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
assert usage.prompt_tokens > 0, "usage.prompt_tokens was zero"
|
||||
assert usage.completion_tokens > 0, "usage.completion_tokens was zero"
|
||||
assert usage.total_tokens > 0, "usage.total_tokens was zero"
|
||||
continue
|
||||
|
||||
index = response.choices[0].index
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
if finish_reason is not None:
|
||||
is_finished[index] = True
|
||||
finish_reason_counts[index] = finish_reason_counts.get(index, 0) + 1
|
||||
|
||||
data = response.choices[0].delta
|
||||
|
||||
if is_firsts.get(index, True):
|
||||
assert (
|
||||
data.role == "assistant"
|
||||
), "data.role was not 'assistant' for first chunk"
|
||||
is_firsts[index] = False
|
||||
continue
|
||||
|
||||
if logprobs and not is_finished.get(index, False):
|
||||
assert response.choices[0].logprobs, "logprobs was not returned"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
|
||||
), "top_logprobs token was not a string"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs, list
|
||||
), "top_logprobs was not a list"
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.content[0].top_logprobs
|
||||
)
|
||||
assert (
|
||||
ret_num_top_logprobs == logprobs
|
||||
), f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
|
||||
assert (
|
||||
isinstance(data.content, str)
|
||||
or isinstance(data.reasoning_content, str)
|
||||
or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
|
||||
or response.choices[0].finish_reason
|
||||
)
|
||||
assert response.id
|
||||
assert response.created
|
||||
|
||||
for index in range(parallel_sample_num):
|
||||
assert not is_firsts.get(
|
||||
index, True
|
||||
), f"index {index} is not found in the response"
|
||||
|
||||
for index in range(parallel_sample_num):
|
||||
assert (
|
||||
index in finish_reason_counts
|
||||
), f"No finish_reason found for index {index}"
|
||||
assert finish_reason_counts[index] == 1, (
|
||||
f"Expected 1 finish_reason chunk for index {index}, "
|
||||
f"got {finish_reason_counts[index]}"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Chat Completion Tests (GPT-OSS)
|
||||
#
|
||||
# NOTE: Some tests are skipped because they don't work with OSS models:
|
||||
# - test_regex: OSS models don't support regex constraints
|
||||
# - test_penalty: OSS models don't support frequency_penalty
|
||||
# - test_response_prefill: OSS models don't support continue_final_message
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.model("gpt-oss")
|
||||
@pytest.mark.gateway(
|
||||
extra_args=["--reasoning-parser=gpt-oss", "--history-backend", "memory"]
|
||||
)
|
||||
class TestChatCompletionGptOss(TestChatCompletion):
|
||||
"""Tests for chat completions API with GPT-OSS model.
|
||||
|
||||
Inherits from TestChatCompletion and overrides tests that don't work
|
||||
with OSS models.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize("logprobs", [None]) # No logprobs for OSS
|
||||
@pytest.mark.parametrize("parallel_sample_num", [1, 2])
|
||||
def test_chat_completion(self, setup_backend, logprobs, parallel_sample_num):
|
||||
"""Test non-streaming chat completion with parallel sampling (no logprobs)."""
|
||||
super().test_chat_completion(setup_backend, logprobs, parallel_sample_num)
|
||||
|
||||
@pytest.mark.parametrize("logprobs", [None]) # No logprobs for OSS
|
||||
@pytest.mark.parametrize("parallel_sample_num", [1, 2])
|
||||
def test_chat_completion_stream(self, setup_backend, logprobs, parallel_sample_num):
|
||||
"""Test streaming chat completion with parallel sampling (no logprobs)."""
|
||||
super().test_chat_completion_stream(
|
||||
setup_backend, logprobs, parallel_sample_num
|
||||
)
|
||||
|
||||
@pytest.mark.skip(reason="OSS models don't support regex constraints")
|
||||
def test_regex(self, setup_backend):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="OSS models don't support frequency_penalty")
|
||||
def test_penalty(self, setup_backend):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="OSS models don't support continue_final_message")
|
||||
def test_response_prefill(self, setup_backend):
|
||||
pass
|
||||
165
third_party/sglang/sgl-model-gateway/e2e_test/chat_completions/test_reasoning_content.py
vendored
Normal file
165
third_party/sglang/sgl-model-gateway/e2e_test/chat_completions/test_reasoning_content.py
vendored
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Reasoning Content E2E Tests.
|
||||
|
||||
Tests for chat completions with reasoning content (DeepSeek R1 reasoning parser).
|
||||
|
||||
Source: Migrated from e2e_grpc/features/test_reasoning_content.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Reasoning Content API Tests (DeepSeek 7B)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.model("deepseek-7b")
|
||||
@pytest.mark.gateway(
|
||||
extra_args=["--reasoning-parser", "deepseek_r1", "--history-backend", "memory"]
|
||||
)
|
||||
@pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True)
|
||||
class TestReasoningContentAPI:
|
||||
"""Tests for reasoning content API with DeepSeek R1 reasoning parser."""
|
||||
|
||||
def test_streaming_separate_reasoning_false(self, setup_backend):
|
||||
"""Test streaming with separate_reasoning=False, reasoning_content should be empty."""
|
||||
_, model, client, gateway = setup_backend
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+3?",
|
||||
}
|
||||
],
|
||||
max_tokens=100,
|
||||
stream=True,
|
||||
extra_body={"separate_reasoning": False},
|
||||
)
|
||||
|
||||
reasoning_content = ""
|
||||
content = ""
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.content:
|
||||
content += chunk.choices[0].delta.content
|
||||
elif chunk.choices[0].delta.reasoning_content:
|
||||
reasoning_content += chunk.choices[0].delta.reasoning_content
|
||||
|
||||
assert len(reasoning_content) == 0
|
||||
assert len(content) > 0
|
||||
|
||||
def test_streaming_separate_reasoning_true(self, setup_backend):
|
||||
"""Test streaming with separate_reasoning=True, reasoning_content should not be empty."""
|
||||
_, model, client, gateway = setup_backend
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+3?",
|
||||
}
|
||||
],
|
||||
max_tokens=100,
|
||||
stream=True,
|
||||
extra_body={"separate_reasoning": True},
|
||||
)
|
||||
|
||||
reasoning_content = ""
|
||||
content = ""
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.content:
|
||||
content += chunk.choices[0].delta.content
|
||||
elif chunk.choices[0].delta.reasoning_content:
|
||||
reasoning_content += chunk.choices[0].delta.reasoning_content
|
||||
|
||||
assert len(reasoning_content) > 0
|
||||
assert len(content) > 0
|
||||
|
||||
def test_streaming_separate_reasoning_true_stream_reasoning_false(
|
||||
self, setup_backend
|
||||
):
|
||||
"""Test streaming with separate_reasoning=True and stream_reasoning=False."""
|
||||
_, model, client, gateway = setup_backend
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+3?",
|
||||
}
|
||||
],
|
||||
max_tokens=100,
|
||||
stream=True,
|
||||
extra_body={"separate_reasoning": True, "stream_reasoning": False},
|
||||
)
|
||||
|
||||
reasoning_content = ""
|
||||
content = ""
|
||||
first_chunk = False
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.reasoning_content:
|
||||
reasoning_content = chunk.choices[0].delta.reasoning_content
|
||||
first_chunk = True
|
||||
if chunk.choices[0].delta.content:
|
||||
content += chunk.choices[0].delta.content
|
||||
if not first_chunk:
|
||||
reasoning_content = chunk.choices[0].delta.reasoning_content
|
||||
first_chunk = True
|
||||
if not first_chunk:
|
||||
assert (
|
||||
not chunk.choices[0].delta.reasoning_content
|
||||
or len(chunk.choices[0].delta.reasoning_content) == 0
|
||||
)
|
||||
|
||||
assert len(reasoning_content) > 0
|
||||
assert len(content) > 0
|
||||
|
||||
def test_nonstreaming_separate_reasoning_false(self, setup_backend):
|
||||
"""Test non-streaming with separate_reasoning=False, reasoning_content should be empty."""
|
||||
_, model, client, gateway = setup_backend
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+3?",
|
||||
}
|
||||
],
|
||||
max_tokens=100,
|
||||
extra_body={"separate_reasoning": False},
|
||||
)
|
||||
|
||||
assert (
|
||||
not response.choices[0].message.reasoning_content
|
||||
or len(response.choices[0].message.reasoning_content) == 0
|
||||
)
|
||||
assert len(response.choices[0].message.content) > 0
|
||||
|
||||
def test_nonstreaming_separate_reasoning_true(self, setup_backend):
|
||||
"""Test non-streaming with separate_reasoning=True, reasoning_content should not be empty."""
|
||||
_, model, client, gateway = setup_backend
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 1+3?",
|
||||
}
|
||||
],
|
||||
max_tokens=100,
|
||||
extra_body={"separate_reasoning": True},
|
||||
)
|
||||
|
||||
assert len(response.choices[0].message.reasoning_content) > 0
|
||||
assert len(response.choices[0].message.content) > 0
|
||||
167
third_party/sglang/sgl-model-gateway/e2e_test/chat_completions/test_validation.py
vendored
Normal file
167
third_party/sglang/sgl-model-gateway/e2e_test/chat_completions/test_validation.py
vendored
Normal file
@@ -0,0 +1,167 @@
|
||||
"""Validation E2E Tests.
|
||||
|
||||
Tests for validation features like ignore_eos and large token handling.
|
||||
|
||||
Source: Migrated from e2e_grpc/validation/test_openai_server_ignore_eos.py
|
||||
and e2e_grpc/validation/test_large_max_new_tokens.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import pytest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lazy load tokenizer to avoid import errors if transformers not installed
|
||||
_tokenizer_cache: dict = {}
|
||||
_tokenizer_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_tokenizer(model_path: str):
|
||||
"""Get tokenizer for a model, with caching."""
|
||||
if model_path not in _tokenizer_cache:
|
||||
with _tokenizer_lock:
|
||||
# Re-check after acquiring the lock to handle race conditions
|
||||
if model_path not in _tokenizer_cache:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
_tokenizer_cache[model_path] = AutoTokenizer.from_pretrained(model_path)
|
||||
return _tokenizer_cache[model_path]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Ignore EOS Tests (Llama 8B)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.model("llama-8b")
|
||||
@pytest.mark.gateway(extra_args=["--history-backend", "memory"])
|
||||
@pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True)
|
||||
class TestIgnoreEOS:
|
||||
"""Tests for ignore_eos feature."""
|
||||
|
||||
def test_ignore_eos(self, setup_backend):
|
||||
"""Test that ignore_eos=True allows generation to continue beyond EOS token.
|
||||
|
||||
When ignore_eos=True, the model should generate until max_tokens is reached,
|
||||
even if it encounters an EOS token.
|
||||
"""
|
||||
_, model, client, _ = setup_backend
|
||||
|
||||
tokenizer = get_tokenizer(model)
|
||||
max_tokens = 200
|
||||
|
||||
# Request without ignore_eos (default behavior - stops at EOS)
|
||||
response_default = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Count from 1 to 20."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=max_tokens,
|
||||
extra_body={"ignore_eos": False},
|
||||
)
|
||||
|
||||
# Request with ignore_eos=True (continues past EOS until max_tokens)
|
||||
response_ignore_eos = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Count from 1 to 20."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=max_tokens,
|
||||
extra_body={"ignore_eos": True},
|
||||
)
|
||||
|
||||
default_tokens = len(
|
||||
tokenizer.encode(response_default.choices[0].message.content)
|
||||
)
|
||||
ignore_eos_tokens = len(
|
||||
tokenizer.encode(response_ignore_eos.choices[0].message.content)
|
||||
)
|
||||
|
||||
# Check if ignore_eos resulted in more tokens or exactly max_tokens
|
||||
# The ignore_eos response should either:
|
||||
# 1. Have more tokens than the default response (if default stopped at EOS before max_tokens)
|
||||
# 2. Have exactly max_tokens (if it reached the max_tokens limit)
|
||||
assert (
|
||||
ignore_eos_tokens > default_tokens or ignore_eos_tokens >= max_tokens
|
||||
), f"ignore_eos did not generate more tokens: {ignore_eos_tokens} vs {default_tokens}"
|
||||
|
||||
assert response_ignore_eos.choices[0].finish_reason == "length", (
|
||||
f"Expected finish_reason='length' for ignore_eos=True, "
|
||||
f"got {response_ignore_eos.choices[0].finish_reason}"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Large Max New Tokens Tests (Llama 8B)
|
||||
#
|
||||
# NOTE: This test verifies concurrent request handling with large token limits.
|
||||
# The original test monitored server logs to verify concurrency, which is not
|
||||
# possible with the pool-based infrastructure. This simplified version verifies
|
||||
# that concurrent requests complete successfully.
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.model("llama-8b")
|
||||
@pytest.mark.gateway(extra_args=["--history-backend", "memory"])
|
||||
@pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True)
|
||||
class TestLargeMaxNewTokens:
|
||||
"""Tests for handling large max_new_tokens with concurrent requests."""
|
||||
|
||||
def test_concurrent_chat_completions(self, setup_backend):
|
||||
"""Test that multiple concurrent requests with large token generation complete.
|
||||
|
||||
This test sends multiple requests that ask for long outputs concurrently
|
||||
to verify the server can handle concurrent long-running requests.
|
||||
"""
|
||||
_, model, client, _ = setup_backend
|
||||
|
||||
num_requests = 4
|
||||
|
||||
def run_chat_completion():
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Please repeat the word 'hello' for 100 times.",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=256, # Reasonable limit for concurrent test
|
||||
)
|
||||
return response
|
||||
|
||||
# Send concurrent requests
|
||||
start_time = time.time()
|
||||
futures = []
|
||||
with ThreadPoolExecutor(max_workers=num_requests) as executor:
|
||||
for _ in range(num_requests):
|
||||
futures.append(executor.submit(run_chat_completion))
|
||||
|
||||
# Wait for all to complete and collect results
|
||||
responses = [f.result() for f in futures]
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info("Completed %d concurrent requests in %.2fs", num_requests, elapsed)
|
||||
|
||||
# Verify all requests completed successfully
|
||||
assert len(responses) == num_requests
|
||||
for i, response in enumerate(responses):
|
||||
assert response.choices[
|
||||
0
|
||||
].message.content, f"Request {i} returned empty content"
|
||||
assert response.choices[0].finish_reason in ("stop", "length"), (
|
||||
f"Request {i} had unexpected finish_reason: "
|
||||
f"{response.choices[0].finish_reason}"
|
||||
)
|
||||
225
third_party/sglang/sgl-model-gateway/e2e_test/conftest.py
vendored
Normal file
225
third_party/sglang/sgl-model-gateway/e2e_test/conftest.py
vendored
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Pytest configuration for E2E tests.
|
||||
|
||||
Parallel Execution
|
||||
------------------
|
||||
Tests can run in parallel using pytest-parallel with shared worker processes.
|
||||
Use --workers 1 --tests-per-worker N for N concurrent test threads:
|
||||
|
||||
pytest --workers 1 --tests-per-worker 4 e2e_test/router/
|
||||
|
||||
This leverages the thread-safe ModelPool and GPUAllocator classes to enable
|
||||
true shared-worker parallelism where all threads share the same session-scoped
|
||||
model_pool fixture. Tests marked with @pytest.mark.thread_unsafe will be
|
||||
automatically skipped in parallel mode.
|
||||
|
||||
Markers
|
||||
-------
|
||||
This module defines several pytest markers for configuring E2E tests:
|
||||
|
||||
@pytest.mark.model(name)
|
||||
Specify which model to use for the test.
|
||||
|
||||
Args:
|
||||
name: Model ID from MODEL_SPECS (e.g., "llama-8b", "qwen-7b")
|
||||
|
||||
GPU Resource Management:
|
||||
When GPUs are limited (e.g., 4 GPUs, 6 models), the model pool uses
|
||||
MRU (Most Recently Used) eviction:
|
||||
1. Models are pre-launched until GPUs are full
|
||||
2. When a test needs a model that isn't running, MRU model is evicted
|
||||
(models just used are likely done, models not yet used are waiting)
|
||||
3. The needed model is then launched on-demand
|
||||
|
||||
Examples:
|
||||
@pytest.mark.model("llama-8b")
|
||||
@pytest.mark.model("qwen-72b")
|
||||
|
||||
@pytest.mark.workers(count=1, prefill=None, decode=None)
|
||||
Configure worker topology for the test.
|
||||
|
||||
Args:
|
||||
count: Number of regular workers (default: 1)
|
||||
prefill: Number of prefill workers for PD disaggregation
|
||||
decode: Number of decode workers for PD disaggregation
|
||||
|
||||
Examples:
|
||||
@pytest.mark.workers(count=3) # 3 regular workers
|
||||
@pytest.mark.workers(prefill=2, decode=2) # PD mode
|
||||
|
||||
@pytest.mark.gateway(policy="round_robin", timeout=None, extra_args=None)
|
||||
Configure the gateway/router.
|
||||
|
||||
Args:
|
||||
policy: Routing policy ("round_robin", "random", etc.)
|
||||
timeout: Startup timeout in seconds
|
||||
extra_args: Additional CLI arguments for the router
|
||||
|
||||
Examples:
|
||||
@pytest.mark.gateway(policy="random")
|
||||
@pytest.mark.gateway(extra_args=["--cache-routing"])
|
||||
|
||||
@pytest.mark.e2e
|
||||
Mark test as an end-to-end test requiring GPU workers.
|
||||
|
||||
@pytest.mark.slow
|
||||
Mark test as slow-running.
|
||||
|
||||
@pytest.mark.thread_unsafe(reason=None)
|
||||
Mark test as incompatible with parallel thread execution.
|
||||
Tests with this marker are automatically skipped when running
|
||||
with --tests-per-worker > 1.
|
||||
|
||||
Args:
|
||||
reason: Optional explanation of why the test is thread-unsafe.
|
||||
|
||||
Examples:
|
||||
@pytest.mark.thread_unsafe
|
||||
@pytest.mark.thread_unsafe(reason="Modifies global state")
|
||||
|
||||
Fixtures
|
||||
--------
|
||||
model_pool: Session-scoped fixture managing SGLang worker processes.
|
||||
setup_backend: Class-scoped fixture that launches gateway + provides client.
|
||||
|
||||
Usage Examples
|
||||
--------------
|
||||
Basic test with default model:
|
||||
|
||||
@pytest.mark.e2e
|
||||
@pytest.mark.parametrize("setup_backend", ["http"], indirect=True)
|
||||
class TestBasic:
|
||||
def test_chat(self, setup_backend):
|
||||
backend, model, client, gateway = setup_backend
|
||||
response = client.chat.completions.create(...)
|
||||
|
||||
Test with specific model and multiple backends:
|
||||
|
||||
@pytest.mark.e2e
|
||||
@pytest.mark.model("qwen-7b")
|
||||
@pytest.mark.parametrize("setup_backend", ["grpc", "http"], indirect=True)
|
||||
class TestQwen:
|
||||
def test_generate(self, setup_backend):
|
||||
...
|
||||
|
||||
PD disaggregation mode:
|
||||
|
||||
@pytest.mark.e2e
|
||||
@pytest.mark.workers(prefill=1, decode=1)
|
||||
@pytest.mark.parametrize("setup_backend", ["pd"], indirect=True)
|
||||
class TestPD:
|
||||
def test_pd_inference(self, setup_backend):
|
||||
...
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from importlib.util import find_spec
|
||||
from pathlib import Path
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Path setup (must happen before other imports)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_ROOT = Path(__file__).resolve().parents[1] # sgl-model-gateway/
|
||||
_E2E_TEST = Path(__file__).resolve().parent # e2e_test/
|
||||
_SRC = _ROOT / "bindings" / "python"
|
||||
|
||||
# Add e2e_test to path so "from infra import ..." works
|
||||
if str(_E2E_TEST) not in sys.path:
|
||||
sys.path.insert(0, str(_E2E_TEST))
|
||||
|
||||
# Add bindings/python to path if the wheel is not installed (for local development)
|
||||
_wheel_installed = find_spec("sglang_router.sglang_router_rs") is not None
|
||||
|
||||
if not _wheel_installed and str(_SRC) not in sys.path:
|
||||
sys.path.insert(0, str(_SRC))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logging setup (clean output without pytest's "---- live log ----" dividers)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _setup_logging() -> None:
|
||||
"""Configure clean logging to stdout with timestamps and thread info.
|
||||
|
||||
In parallel mode (--tests-per-worker > 1), logs from different threads
|
||||
would be interleaved. Including thread name helps identify which test
|
||||
produced each log line.
|
||||
"""
|
||||
# Include thread name for parallel execution readability
|
||||
# MainThread for sequential, Thread-N for parallel workers
|
||||
fmt = "%(asctime)s.%(msecs)03d [%(threadName)s] [%(name)s] %(message)s"
|
||||
datefmt = "%H:%M:%S"
|
||||
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setFormatter(logging.Formatter(fmt, datefmt))
|
||||
|
||||
for logger_name in ("e2e_test", "infra", "fixtures"):
|
||||
log = logging.getLogger(logger_name)
|
||||
log.setLevel(logging.INFO)
|
||||
log.addHandler(handler)
|
||||
log.propagate = False
|
||||
|
||||
for logger_name in ("openai", "httpx", "httpcore", "numexpr"):
|
||||
logging.getLogger(logger_name).setLevel(logging.WARNING)
|
||||
|
||||
|
||||
_setup_logging()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test visibility hooks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def pytest_runtest_logstart(nodeid: str, location: tuple) -> None:
|
||||
"""Print clear test header at start of each test."""
|
||||
import threading
|
||||
|
||||
from infra import LOG_SEPARATOR_WIDTH
|
||||
|
||||
test_name = nodeid.split("::")[-1] if "::" in nodeid else nodeid
|
||||
thread_name = threading.current_thread().name
|
||||
print(f"\n{'=' * LOG_SEPARATOR_WIDTH}")
|
||||
print(f"[{thread_name}] TEST: {test_name}")
|
||||
print(f"{'=' * LOG_SEPARATOR_WIDTH}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import pytest hooks and fixtures from fixtures/ package
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Import fixtures - pytest discovers these by name
|
||||
# Import hooks - pytest discovers these by name
|
||||
from fixtures import (
|
||||
backend_router,
|
||||
model_base_url,
|
||||
model_client,
|
||||
model_pool,
|
||||
pytest_collection_finish,
|
||||
pytest_collection_modifyitems,
|
||||
pytest_configure,
|
||||
pytest_runtest_setup,
|
||||
setup_backend,
|
||||
)
|
||||
|
||||
# Re-export for pytest discovery
|
||||
__all__ = [
|
||||
# Hooks
|
||||
"pytest_runtest_logstart",
|
||||
"pytest_collection_modifyitems",
|
||||
"pytest_collection_finish",
|
||||
"pytest_configure",
|
||||
"pytest_runtest_setup",
|
||||
# Fixtures
|
||||
"model_pool",
|
||||
"model_client",
|
||||
"model_base_url",
|
||||
"setup_backend",
|
||||
"backend_router",
|
||||
]
|
||||
0
third_party/sglang/sgl-model-gateway/e2e_test/embeddings/__init__.py
vendored
Normal file
0
third_party/sglang/sgl-model-gateway/e2e_test/embeddings/__init__.py
vendored
Normal file
143
third_party/sglang/sgl-model-gateway/e2e_test/embeddings/test_basic.py
vendored
Normal file
143
third_party/sglang/sgl-model-gateway/e2e_test/embeddings/test_basic.py
vendored
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Basic embedding API tests.
|
||||
|
||||
Tests the embedding functionality through the router with both gRPC and HTTP backends.
|
||||
|
||||
Source: Migrated from e2e_grpc/basic/test_embedding_server.py
|
||||
|
||||
Usage:
|
||||
pytest e2e_test/embeddings/test_basic.py -v
|
||||
pytest e2e_test/embeddings/test_basic.py -v -k "grpc"
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.e2e
|
||||
@pytest.mark.model("embedding")
|
||||
@pytest.mark.parametrize("setup_backend", ["grpc", "http"], indirect=True)
|
||||
class TestEmbeddingBasic:
|
||||
"""Basic embedding API tests using local workers (gRPC and HTTP)."""
|
||||
|
||||
def test_embedding_single(self, setup_backend):
|
||||
"""Test single text embedding.
|
||||
|
||||
Verifies that:
|
||||
- Response object structure is correct
|
||||
- Embedding is a non-empty list of floats
|
||||
- Usage statistics are present
|
||||
"""
|
||||
backend, model, client, gateway = setup_backend
|
||||
|
||||
input_text = "Hello world"
|
||||
response = client.embeddings.create(
|
||||
model=model,
|
||||
input=input_text,
|
||||
)
|
||||
|
||||
assert response.object == "list"
|
||||
assert len(response.data) == 1
|
||||
|
||||
embedding = response.data[0]
|
||||
assert embedding.object == "embedding"
|
||||
assert embedding.index == 0
|
||||
assert len(embedding.embedding) > 0
|
||||
assert isinstance(embedding.embedding[0], float)
|
||||
|
||||
# Verify usage statistics
|
||||
assert response.usage.prompt_tokens > 0
|
||||
assert response.usage.total_tokens == response.usage.prompt_tokens
|
||||
|
||||
logger.info(
|
||||
"Single embedding: %d dimensions, %d tokens",
|
||||
len(embedding.embedding),
|
||||
response.usage.prompt_tokens,
|
||||
)
|
||||
|
||||
def test_embedding_batch(self, setup_backend):
|
||||
"""Test batch embedding with multiple texts.
|
||||
|
||||
Note: The original test expected len(response.data) == 1 for batch,
|
||||
which seems incorrect. This might be model-specific behavior.
|
||||
"""
|
||||
backend, model, client, gateway = setup_backend
|
||||
|
||||
input_texts = ["Hello world", "SGLang is fast"]
|
||||
response = client.embeddings.create(
|
||||
model=model,
|
||||
input=input_texts,
|
||||
)
|
||||
|
||||
# Note: Original test had len(response.data) == 1, which seems like
|
||||
# a bug or model-specific behavior. Standard behavior should return
|
||||
# one embedding per input text.
|
||||
assert len(response.data) >= 1
|
||||
assert response.data[0].index == 0
|
||||
assert len(response.data[0].embedding) > 0
|
||||
|
||||
logger.info("Batch embedding: %d results", len(response.data))
|
||||
|
||||
def test_embedding_dimensions_consistent(self, setup_backend):
|
||||
"""Test that embedding dimensions are consistent across different inputs.
|
||||
|
||||
Verifies that different length inputs produce embeddings with
|
||||
the same dimensionality.
|
||||
"""
|
||||
backend, model, client, gateway = setup_backend
|
||||
|
||||
response1 = client.embeddings.create(
|
||||
model=model,
|
||||
input="A short text",
|
||||
)
|
||||
dim1 = len(response1.data[0].embedding)
|
||||
|
||||
response2 = client.embeddings.create(
|
||||
model=model,
|
||||
input="A much longer text to ensure dimensions match regardless of input length",
|
||||
)
|
||||
dim2 = len(response2.data[0].embedding)
|
||||
|
||||
assert dim1 == dim2, f"Dimensions differ: {dim1} vs {dim2}"
|
||||
logger.info("Embedding dimensions: %d (consistent)", dim1)
|
||||
|
||||
def test_embedding_empty_string(self, setup_backend):
|
||||
"""Test embedding with empty string input.
|
||||
|
||||
Some models may handle empty strings differently.
|
||||
This test verifies the API doesn't crash on empty input.
|
||||
"""
|
||||
backend, model, client, gateway = setup_backend
|
||||
|
||||
try:
|
||||
response = client.embeddings.create(
|
||||
model=model,
|
||||
input="",
|
||||
)
|
||||
# If it succeeds, verify structure
|
||||
assert len(response.data) >= 1
|
||||
logger.info("Empty string embedding succeeded")
|
||||
except Exception as e:
|
||||
# Some models may reject empty strings - that's acceptable
|
||||
logger.info("Empty string embedding rejected: %s", e)
|
||||
|
||||
def test_embedding_unicode(self, setup_backend):
|
||||
"""Test embedding with unicode characters.
|
||||
|
||||
Verifies that the API handles non-ASCII characters correctly.
|
||||
"""
|
||||
backend, model, client, gateway = setup_backend
|
||||
|
||||
input_text = "Hello 世界! 🚀 Привет мир"
|
||||
response = client.embeddings.create(
|
||||
model=model,
|
||||
input=input_text,
|
||||
)
|
||||
|
||||
assert len(response.data) == 1
|
||||
assert len(response.data[0].embedding) > 0
|
||||
logger.info("Unicode embedding: %d dimensions", len(response.data[0].embedding))
|
||||
262
third_party/sglang/sgl-model-gateway/e2e_test/embeddings/test_correctness.py
vendored
Normal file
262
third_party/sglang/sgl-model-gateway/e2e_test/embeddings/test_correctness.py
vendored
Normal file
@@ -0,0 +1,262 @@
|
||||
"""Embedding correctness tests.
|
||||
|
||||
Tests that embeddings from the router match HuggingFace reference embeddings.
|
||||
Validates numerical correctness including tokenization and inference.
|
||||
|
||||
Source: Migrated from e2e_grpc/basic/test_embedding_correctness.py
|
||||
|
||||
Usage:
|
||||
pytest e2e_test/embeddings/test_correctness.py -v
|
||||
pytest e2e_test/embeddings/test_correctness.py -v -k "grpc"
|
||||
|
||||
Requirements:
|
||||
- sentence-transformers (for reference embeddings)
|
||||
- torch
|
||||
- numpy
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Thread-safe storage for HF reference embeddings
|
||||
_hf_embeddings_cache: dict[str, Any] | None = None
|
||||
_hf_embeddings_lock = threading.Lock()
|
||||
|
||||
|
||||
# Test data for semantic similarity checks
|
||||
SEMANTIC_TEST_SETS: list[list[str]] = [
|
||||
[
|
||||
"The cat sat on the mat.",
|
||||
"A feline was resting on a rug.",
|
||||
"Bright stars illuminate the night sky.", # Unrelated sentence
|
||||
],
|
||||
[
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
"A fast, dark-colored fox leaps above a sluggish canine.",
|
||||
"Ocean waves gently lap against the shore.", # Unrelated sentence
|
||||
],
|
||||
[
|
||||
"An apple a day keeps the doctor away.",
|
||||
"Eating a daily apple can prevent medical visits.",
|
||||
"Mountains are vast and often snow-capped.", # Unrelated sentence
|
||||
],
|
||||
]
|
||||
|
||||
# Test data for relevance scoring
|
||||
RELEVANCE_TEST_DATA: dict[str, Any] = {
|
||||
"sample_query": "Why is Oracle launching Cloud Lift Services?",
|
||||
"sample_reference": [
|
||||
{
|
||||
"docid": 466,
|
||||
"body": "What are some extended benefits of using Oracle Cloud Infrastructure? \nWhen customers migrate their on-premises Oracle applications to Oracle Cloud Infrastructure, they realize the benefits \nof the cloud without needing to rearchitect those applications. Customers can lower total cost of ownership, improve \nagility and increase workload performance. Additional benefits include: \nConsistently low global pricing and lack of hidden charges \nAutomated migration support, leveraging cloud managers and tools for key applications \nFlexible universal credits applied towards any IaaS or PaaS service \nBring Your Own License (BYOL) capabilities \nIs Oracle Cloud Lift available for PAYGO customers? \nOracle Cloud Lift Services are designed for customers who use the UCM credits (Monthly Flex). PAYGO customers can \ncontact their sales representative or cloud engineer to evaluate their eligibility. \nAre any countries excluded from Oracle Cloud Lift Services? \nAmong the countries that Oracle operates in, only China is excluded from the Oracle Cloud Lift Services program.",
|
||||
},
|
||||
{
|
||||
"docid": 636,
|
||||
"body": "Cloud Lift Services as needed to make our joint customers more successful. Public Sector accounts and partner \nengagements are not currently eligible to participate in this program. \n How can I get started with Oracle Cloud? \nYou can use the Oracle Cloud Free Tier for a free trial and Contact Us for more information.",
|
||||
},
|
||||
{
|
||||
"docid": 545,
|
||||
"body": "Frequently Asked Questions (FAQs) for \nOracle Cloud Lift Services \n \nWhy is Oracle launching Cloud Lift Services? \n \n \n \nThis program underscores Oracle's intent to better serve its customer base. Cloud Lift Services provide new and \nexisting customers expanded access to cloud engineering tools and resources to quickly migrate workloads at no \nadditional cost.",
|
||||
},
|
||||
{
|
||||
"docid": 716,
|
||||
"body": "as part of their existing contract. \nWhat happens if I already have a paid services engagement? \nPlease keep proceeding with your existing engagement. Oracle will work with you to identify expansion opportunities \nto leverage Cloud Lift Services for other projects.",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def get_openai_embeddings(
|
||||
texts: str | list[str],
|
||||
client,
|
||||
model: str,
|
||||
) -> list[list[float]]:
|
||||
"""Get embeddings from the gateway via OpenAI-compatible API."""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
response = client.embeddings.create(
|
||||
model=model,
|
||||
input=text,
|
||||
)
|
||||
embeddings.append(response.data[0].embedding)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def get_hf_st_embeddings(texts: str | list[str], model_path: str) -> np.ndarray:
|
||||
"""Get embeddings using sentence-transformers library.
|
||||
|
||||
This handles the correct pooling strategy for each model automatically.
|
||||
For e5-mistral, it uses last-token pooling (not mean pooling).
|
||||
|
||||
Uses CPU to compute reference embeddings to avoid GPU memory conflicts
|
||||
with the worker being tested.
|
||||
"""
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
# Force CPU to avoid GPU memory conflicts in CI
|
||||
model = SentenceTransformer(model_path, trust_remote_code=True, device="cpu")
|
||||
embeddings = model.encode(texts, normalize_embeddings=True)
|
||||
return embeddings
|
||||
|
||||
|
||||
def compare_embeddings(
|
||||
embeddings1: list[list[float]], embeddings2: list[list[float]]
|
||||
) -> list[float]:
|
||||
"""Compare two sets of embeddings using cosine similarity."""
|
||||
similarities = [
|
||||
F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0).item()
|
||||
for e1, e2 in zip(embeddings1, embeddings2)
|
||||
]
|
||||
return similarities
|
||||
|
||||
|
||||
def get_input_texts(test_json: dict) -> list[str]:
|
||||
"""Extract document bodies from test JSON."""
|
||||
return [doc["body"] for doc in test_json["sample_reference"]]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_reference_embeddings(request):
|
||||
"""Pre-compute HuggingFace reference embeddings on CPU.
|
||||
|
||||
This is done once per session with thread-safe initialization to support
|
||||
pytest-parallel execution. Uses CPU to avoid GPU memory conflicts.
|
||||
"""
|
||||
global _hf_embeddings_cache
|
||||
|
||||
# Thread-safe initialization - only one thread computes embeddings
|
||||
with _hf_embeddings_lock:
|
||||
if _hf_embeddings_cache is not None:
|
||||
return _hf_embeddings_cache
|
||||
|
||||
from infra.model_specs import MODEL_SPECS
|
||||
|
||||
# Get model path from MODEL_SPECS for the embedding model
|
||||
model_path = MODEL_SPECS.get("embedding", {}).get("model")
|
||||
if model_path is None:
|
||||
pytest.skip("Embedding model not found in MODEL_SPECS")
|
||||
|
||||
logger.info(
|
||||
"Pre-computing HuggingFace reference embeddings (CPU) for %s", model_path
|
||||
)
|
||||
|
||||
# Flatten all test texts for semantic similarity
|
||||
all_semantic_texts = []
|
||||
for text_set in SEMANTIC_TEST_SETS:
|
||||
all_semantic_texts.extend(text_set)
|
||||
|
||||
# Get relevance test texts
|
||||
query = f"Instruct: Given a search query, retrieve relevant passages that answer the query\nQuery: {RELEVANCE_TEST_DATA['sample_query']}"
|
||||
docs = get_input_texts(RELEVANCE_TEST_DATA)
|
||||
|
||||
# Compute all reference embeddings at once
|
||||
hf_semantic = get_hf_st_embeddings(all_semantic_texts, model_path)
|
||||
hf_query = get_hf_st_embeddings(query, model_path)
|
||||
hf_docs = get_hf_st_embeddings(docs, model_path)
|
||||
|
||||
logger.info("Reference embeddings computed on CPU")
|
||||
|
||||
_hf_embeddings_cache = {
|
||||
"semantic": hf_semantic,
|
||||
"query": hf_query,
|
||||
"docs": hf_docs,
|
||||
}
|
||||
|
||||
return _hf_embeddings_cache
|
||||
|
||||
|
||||
@pytest.mark.e2e
|
||||
@pytest.mark.model("embedding")
|
||||
@pytest.mark.parametrize("setup_backend", ["grpc", "http"], indirect=True)
|
||||
class TestEmbeddingCorrectness:
|
||||
"""Test embedding correctness by comparing gateway output against HuggingFace reference.
|
||||
|
||||
Strategy: Pre-compute HuggingFace reference embeddings on CPU, then launch the
|
||||
worker on GPU and compare. Using CPU for reference avoids GPU memory conflicts.
|
||||
"""
|
||||
|
||||
def test_semantic_similarity(self, setup_backend, hf_reference_embeddings):
|
||||
"""Check if gateway and HF embeddings give similar results.
|
||||
|
||||
For each text in the semantic test sets, the gateway embedding should
|
||||
have >0.98 cosine similarity with the HuggingFace reference embedding.
|
||||
"""
|
||||
backend, model_path, client, gateway = setup_backend
|
||||
tolerance = 1e-2
|
||||
|
||||
# Track position in pre-computed embeddings
|
||||
embed_idx = 0
|
||||
|
||||
for i, input_texts in enumerate(SEMANTIC_TEST_SETS):
|
||||
logger.info("Processing semantic similarity test set %d", i + 1)
|
||||
|
||||
embedding_gateway = get_openai_embeddings(input_texts, client, model_path)
|
||||
|
||||
# Get pre-computed HF embeddings for this set
|
||||
num_texts = len(input_texts)
|
||||
embedding_hf = hf_reference_embeddings["semantic"][
|
||||
embed_idx : embed_idx + num_texts
|
||||
].tolist()
|
||||
embed_idx += num_texts
|
||||
|
||||
similarities = compare_embeddings(embedding_gateway, embedding_hf)
|
||||
logger.info("Similarities: %s", similarities)
|
||||
|
||||
# Verify all similarities are close to 1.0
|
||||
for j, sim in enumerate(similarities):
|
||||
assert (
|
||||
abs(sim - 1.0) < tolerance
|
||||
), f"Set {i+1}, text {j+1}: similarity {sim:.4f} not close to 1.0"
|
||||
|
||||
logger.info("Semantic similarity test set %d passed", i + 1)
|
||||
|
||||
def test_relevance_scores(self, setup_backend, hf_reference_embeddings):
|
||||
"""Compare relevance scores between gateway and HF implementations.
|
||||
|
||||
The relevance scores (query @ docs) should match between the gateway
|
||||
and HuggingFace implementations within tolerance.
|
||||
"""
|
||||
backend, model_path, client, gateway = setup_backend
|
||||
tolerance = 0.05
|
||||
|
||||
# Format query with instruction (for e5-mistral)
|
||||
query = f"Instruct: Given a search query, retrieve relevant passages that answer the query\nQuery: {RELEVANCE_TEST_DATA['sample_query']}"
|
||||
docs = get_input_texts(RELEVANCE_TEST_DATA)
|
||||
|
||||
# Get gateway scores
|
||||
query_embeddings_gateway = get_openai_embeddings(query, client, model_path)
|
||||
docs_embeddings_gateway = get_openai_embeddings(docs, client, model_path)
|
||||
scores_gateway = (
|
||||
np.array(query_embeddings_gateway) @ np.array(docs_embeddings_gateway).T
|
||||
) * 100
|
||||
|
||||
# Use pre-computed HF scores
|
||||
scores_hf = (
|
||||
hf_reference_embeddings["query"] @ hf_reference_embeddings["docs"].T
|
||||
) * 100
|
||||
|
||||
logger.info("Gateway relevance scores: %s", scores_gateway)
|
||||
logger.info("HF relevance scores: %s", scores_hf)
|
||||
|
||||
assert np.allclose(
|
||||
scores_gateway, scores_hf, atol=tolerance
|
||||
), f"Scores differ beyond tolerance:\nGateway: {scores_gateway}\nHF: {scores_hf}"
|
||||
|
||||
logger.info("Relevance scores comparison passed")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user