chore: vendor sglang v0.5.10 snapshot

This commit is contained in:
2026-04-24 12:29:36 +00:00
parent 78f0d15221
commit bded08301f
4308 changed files with 1200894 additions and 2 deletions

View File

@@ -0,0 +1,15 @@
[build]
rustflags = []
incremental = true
[target.aarch64-apple-darwin]
rustflags = [
"-C", "link-arg=-undefined",
"-C", "link-arg=dynamic_lookup",
]
[target.x86_64-apple-darwin]
rustflags = [
"-C", "link-arg=-undefined",
"-C", "link-arg=dynamic_lookup",
]

View File

@@ -0,0 +1,198 @@
[package]
name = "sgl-model-gateway"
version = "0.3.2"
edition = "2021"
[features]
default = ["grpc-client"]
grpc-client = []
grpc-server = []
vendored-openssl = ["openssl/vendored"]
[lints.rust]
unused_qualifications = "warn"
[lib]
name = "smg"
crate-type = ["rlib"]
[[bin]]
name = "sgl-model-gateway"
path = "src/main.rs"
[[bin]]
name = "smg"
path = "src/main.rs"
[[bin]]
name = "amg"
path = "src/main.rs"
[dependencies]
clap = { version = "4", features = ["derive", "env"] }
axum = { version = "0.8.6", features = ["macros", "ws", "tracing"] }
axum-server = { version = "0.8.0", default-features = false, features = ["tls-rustls"] }
tower = { version = "0.5", features = ["full"] }
tower-http = { version = "0.6", features = ["trace", "compression-gzip", "cors", "timeout", "limit", "request-id", "util"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0", default-features = false, features = [
"std",
"preserve_order",
] }
bytes = "1.8.0"
http-body = "1.0"
rand = "0.9.2"
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json", "rustls-tls"], default-features = false }
futures-util = "0.3"
futures = "0.3"
dashmap = "6.1.0"
blake3 = "1.5"
xxhash-rust = { version = "0.8", features = ["xxh3"] }
bytemuck = { version = "1.21", features = ["derive"] }
http = "1.1.0"
tokio = { version = "1.42.0", features = ["full"] }
async-trait = "0.1"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "chrono"] }
tracing-log = "0.2"
tracing-appender = "0.2.3"
opentelemetry = "0.27"
opentelemetry_sdk = { version = "0.27", features = ["trace", "rt-tokio"] }
opentelemetry-otlp = { version = "0.27", features = ["trace", "grpc-tonic"] }
tracing-opentelemetry = "0.28"
chrono = "0.4"
kube = { version = "1.1.0", features = ["runtime", "derive"] }
k8s-openapi = { version = "0.25.0", features = ["v1_33"] }
metrics = "0.24.2"
metrics-exporter-prometheus = "0.17.0"
uuid = { version = "1.10", features = ["v4", "serde"] }
parking_lot = "0.12.4"
thiserror = "2.0.12"
regex = "1.10"
memchr = "2.7" # SIMD-optimized byte pattern searching
url = "2.5.4"
validator = { version = "0.20.0", features = ["derive"] }
tokio-stream = { version = "0.1", features = ["sync"] }
anyhow = "1.0"
reasoning-parser = "=1.0.0"
openai-protocol = { version = "=1.0.0", features = ["axum"] }
tool-parser = "=1.0.0"
llm-tokenizer = "=1.0.0"
smg-auth = "=1.0.0"
wfaas = "=1.0.0"
data-connector = "=1.0.0"
smg-mcp = "=1.0.0"
smg-wasm = "=1.0.0"
smg-mesh = "=1.0.0"
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
rustls-pemfile = "2.2"
openssl = "0.10.73"
rmcp = { version = "0.8.3", features = ["client", "server",
"transport-child-process",
"transport-sse-client-reqwest",
"transport-streamable-http-client-reqwest",
"transport-streamable-http-server",
"transport-streamable-http-server-session",
"reqwest",
"auth"] }
serde_yaml = "0.9"
subtle = "2.6"
jsonwebtoken = { version = "9.3", default-features = false, features = ["use_pem"] }
num-traits = "0.2"
num-bigint = "0.4"
base64 = "0.22"
openai-harmony = { git = "https://github.com/openai/harmony", tag = "v0.0.4" }
openmetrics-parser = "0.4.4"
arc-swap = "1.7.1"
# gRPC and Protobuf dependencies
smg-grpc-client = "=1.0.0"
tonic = { version = "0.14.2", features = ["gzip", "transport"] }
prost = "0.14.1"
prost-types = "0.14.1"
tonic-prost = "0.14.2"
bitflags = "2.10.0"
once_cell = "1.21.3"
# CRDT for Mesh state synchronization
crdts = "7.3"
redis = { version = "0.27.6", features = ["tokio-comp", "json", "connection-manager"] }
# wasm dependencies
sha2 = "0.10"
wasmtime = { version = "38.0", features = ["component-model", "async"] }
[build-dependencies]
chrono = { version = "0.4", features = ["clock"] }
toml = "0.9"
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
tower = { version = "0.5", features = ["util"] }
http-body-util = "0.1"
portpicker = "0.1"
tempfile = "3.8"
lazy_static = "1.4"
wasm-encoder = "0.242"
npyz = { version = "0.8", features = ["npz"] } # For reading numpy .npz files in golden tests
opentelemetry-proto = { version = "0.27", features = ["gen-tonic"] }
tonic-v12 = { version = "0.12.3", package = "tonic" }
serial_test = "3.0"
rsa = { version = "0.9", features = ["sha2"] }
[[bench]]
name = "consistent_hash_bench"
harness = false
path = "benches/consistent_hash_bench.rs"
[[bench]]
name = "wasm_middleware_latency"
harness = false
path = "benches/wasm_middleware_latency.rs"
[[bench]]
name = "request_processing"
harness = false
path = "benches/request_processing.rs"
[[bench]]
name = "router_registry_bench"
harness = false
[[bench]]
name = "manual_policy_benchmark"
harness = false
path = "benches/manual_policy_benchmark.rs"
[profile.release]
opt-level = "z" # Optimize for size
lto = "fat" # Full LTO for smaller binaries
codegen-units = 1 # Better optimization, slower compile
strip = true # Strip debug symbols
[profile.ci]
inherits = "release"
opt-level = 2 # Lighter optimization (still fast runtime, much faster compile)
lto = "thin" # Thin LTO - good balance
codegen-units = 16 # More parallelization for faster builds
strip = true
[profile.dev]
opt-level = 0
debug = 1
split-debuginfo = "unpacked"
incremental = true
codegen-units = 256
[profile.dev.package."*"]
opt-level = 2
debug = false
[profile.dev.build-override]
opt-level = 3
codegen-units = 1
[profile.dev-opt]
inherits = "dev"
opt-level = 1

View File

@@ -0,0 +1 @@
../LICENSE

View File

@@ -0,0 +1,202 @@
# Model Gateway Makefile
# Provides convenient shortcuts for common development tasks
# Python bindings directory
PYTHON_DIR := bindings/python
# Auto-detect CPU cores and cap at reasonable limit to avoid thread exhaustion
# Can be overridden: make python-dev JOBS=4
NPROC := $(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 8)
JOBS ?= $(shell echo $$(($(NPROC) > 16 ? 16 : $(NPROC))))
# Check if sccache is available and set RUSTC_WRAPPER accordingly
SCCACHE := $(shell which sccache 2>/dev/null)
ifdef SCCACHE
export RUSTC_WRAPPER := $(SCCACHE)
$(info Using sccache for compilation caching)
else
$(info sccache not found. Install it for faster builds: cargo install sccache)
endif
.PHONY: help build test clean docs check fmt dev-setup pre-commit setup-sccache sccache-stats sccache-clean sccache-stop \
python-dev python-build python-build-release python-install python-clean python-test python-check \
show-version bump-version release-notes
help: ## Show this help message
@echo "Model Gateway Development Commands"
@echo "=================================="
@echo ""
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-20s\033[0m %s\n", $$1, $$2}'
@echo ""
build: ## Build the project in release mode
@echo "Building SGLang Model Gateway..."
@cargo build --release
test: ## Run all tests
@echo "Running tests..."
@cargo test
clean: ## Clean build artifacts
@echo "Cleaning build artifacts..."
@cargo clean
docs: ## Generate and open documentation
@echo "Generating documentation..."
@cargo doc --open
check: ## Run cargo check and clippy
@echo "Running cargo check..."
@cargo check
@echo "Running clippy..."
@cargo clippy --all-targets --all-features -- -D warnings
fmt: ## Format code with rustfmt
@echo "Formatting code..."
@rustup run nightly cargo fmt
# Development workflow shortcuts
dev-setup: build test ## Set up development environment
@echo "Development environment ready!"
pre-commit: fmt check test ## Run pre-commit checks
@echo "Pre-commit checks passed!"
# sccache management targets
setup-sccache: ## Install and configure sccache
@echo "Setting up sccache..."
@./scripts/setup-sccache.sh
sccache-stats: ## Show sccache statistics
@if [ -n "$(SCCACHE)" ]; then \
echo "sccache statistics:"; \
sccache -s; \
else \
echo "sccache not installed. Run 'make setup-sccache' to install it."; \
fi
sccache-clean: ## Clear sccache cache
@if [ -n "$(SCCACHE)" ]; then \
echo "Clearing sccache cache..."; \
sccache -C; \
echo "sccache cache cleared"; \
else \
echo "sccache not installed"; \
fi
sccache-stop: ## Stop the sccache server
@if [ -n "$(SCCACHE)" ]; then \
echo "Stopping sccache server..."; \
sccache --stop-server || true; \
else \
echo "sccache not installed"; \
fi
# Python bindings (maturin) targets
python-dev: ## Build Python bindings in development mode (fast, debug build)
@echo "Building Python bindings in development mode (using $(JOBS) parallel jobs with sccache)..."
@cd $(PYTHON_DIR) && CARGO_BUILD_JOBS=$(JOBS) maturin develop
python-build: ## Build Python wheel (release mode with vendored OpenSSL)
@echo "Building Python wheel (release, vendored OpenSSL, using $(JOBS) parallel jobs with sccache)..."
@cd $(PYTHON_DIR) && CARGO_BUILD_JOBS=$(JOBS) maturin build --release --out dist --features vendored-openssl
python-build-release: python-build ## Alias for python-build
python-install: python-build ## Build and install Python wheel
@echo "Installing Python wheel..."
@pip install --force-reinstall $(PYTHON_DIR)/dist/*.whl
@echo "Python package installed!"
python-clean: ## Clean Python build artifacts
@echo "Cleaning Python build artifacts..."
@rm -rf $(PYTHON_DIR)/dist/
@rm -rf $(PYTHON_DIR)/target/
@rm -rf $(PYTHON_DIR)/sglang_router.egg-info/
@rm -rf $(PYTHON_DIR)/sglang_router/__pycache__/
@find $(PYTHON_DIR) -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
@find $(PYTHON_DIR) -name "*.pyc" -delete 2>/dev/null || true
@echo "Python build artifacts cleaned!"
python-test: ## Run Python tests
@echo "Running Python tests..."
@pytest e2e_test/ -v
python-check: ## Check Python package with twine
@echo "Checking Python package..."
@cd $(PYTHON_DIR) && CARGO_BUILD_JOBS=$(JOBS) maturin build --release --out dist --features vendored-openssl
@pip install twine 2>/dev/null || true
@twine check $(PYTHON_DIR)/dist/*
@echo "Python package check passed!"
# Combined shortcuts
dev: python-dev ## Quick development setup (build Python bindings in dev mode)
install: python-install ## Build and install everything
# Release management
VERSION_FILES := Cargo.toml \
bindings/golang/Cargo.toml \
bindings/python/Cargo.toml \
bindings/python/pyproject.toml \
bindings/python/src/sglang_router/version.py
show-version: ## Show current version across all files
@echo "Current versions:"
@echo " Cargo.toml: $$(grep -m1 '^version = ' Cargo.toml | sed 's/version = "\(.*\)"/\1/')"
@echo " bindings/golang/Cargo.toml: $$(grep -m1 '^version = ' bindings/golang/Cargo.toml | sed 's/version = "\(.*\)"/\1/')"
@echo " bindings/python/Cargo.toml: $$(grep -m1 '^version = ' bindings/python/Cargo.toml | sed 's/version = "\(.*\)"/\1/')"
@echo " bindings/python/pyproject.toml: $$(grep -m1 '^version = ' bindings/python/pyproject.toml | sed 's/version = "\(.*\)"/\1/')"
@echo " bindings/python/.../version.py: $$(grep '__version__' bindings/python/src/sglang_router/version.py | sed 's/__version__ = "\(.*\)"/\1/')"
bump-version: ## Bump version across all files (usage: make bump-version VERSION=0.3.3)
@if [ -z "$(VERSION)" ]; then \
echo "Usage: make bump-version VERSION=<new-version>"; \
echo "Example: make bump-version VERSION=0.3.3"; \
echo ""; \
echo "Current version:"; \
grep -m1 '^version = ' Cargo.toml | sed 's/version = "\(.*\)"/ \1/'; \
exit 1; \
fi
@echo "Bumping version to $(VERSION)..."
@# Update main Cargo.toml (line 3)
@sed -i.bak 's/^version = ".*"/version = "$(VERSION)"/' Cargo.toml && rm -f Cargo.toml.bak
@# Update golang binding Cargo.toml
@sed -i.bak 's/^version = ".*"/version = "$(VERSION)"/' bindings/golang/Cargo.toml && rm -f bindings/golang/Cargo.toml.bak
@# Update python binding Cargo.toml
@sed -i.bak 's/^version = ".*"/version = "$(VERSION)"/' bindings/python/Cargo.toml && rm -f bindings/python/Cargo.toml.bak
@# Update pyproject.toml
@sed -i.bak 's/^version = ".*"/version = "$(VERSION)"/' bindings/python/pyproject.toml && rm -f bindings/python/pyproject.toml.bak
@# Update version.py
@sed -i.bak 's/__version__ = ".*"/__version__ = "$(VERSION)"/' bindings/python/src/sglang_router/version.py && rm -f bindings/python/src/sglang_router/version.py.bak
@echo "Version updated to $(VERSION) in all files:"
@echo " - Cargo.toml"
@echo " - bindings/golang/Cargo.toml"
@echo " - bindings/python/Cargo.toml"
@echo " - bindings/python/pyproject.toml"
@echo " - bindings/python/src/sglang_router/version.py"
@echo ""
@echo "Verify with: make show-version"
release-notes: ## Generate release notes for gateway (usage: make release-notes PREV=gateway-v0.2.2 CURR=gateway-v1.0.0)
@if [ -z "$(PREV)" ] || [ -z "$(CURR)" ]; then \
echo "Usage: make release-notes PREV=<previous-tag> CURR=<current-tag>"; \
echo "Example: make release-notes PREV=gateway-v0.2.2 CURR=gateway-v1.0.0"; \
echo ""; \
echo "Options:"; \
echo " OUTPUT=<file> Save to file (default: stdout)"; \
echo " CREATE_RELEASE=1 Create GitHub draft release via gh CLI (default: draft)"; \
echo " DRAFT=0 Publish release immediately (skip draft)"; \
exit 1; \
fi
@ARGS="$(PREV) $(CURR)"; \
if [ -n "$(OUTPUT)" ]; then \
ARGS="$$ARGS --output $(OUTPUT)"; \
fi; \
if [ "$(CREATE_RELEASE)" = "1" ]; then \
ARGS="$$ARGS --create-release"; \
if [ "$(DRAFT)" = "0" ]; then \
ARGS="$$ARGS --no-draft"; \
fi; \
fi; \
./scripts/generate_gateway_release_notes.sh $$ARGS

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,36 @@
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use smg_mesh::consistent_hash::ConsistentHashRing;
fn setup_ring(node_count: usize) -> ConsistentHashRing {
let mut ring = ConsistentHashRing::new();
for i in 0..node_count {
ring.add_node(&format!("node-{}", i));
}
ring
}
fn bench_consistent_hash(c: &mut Criterion) {
let mut group = c.benchmark_group("ConsistentHashRing");
for size in [10, 100, 500].iter() {
let ring = setup_ring(*size);
let key = "test-request-key-for-rate-limiting";
let node_name = "node-5";
group.bench_with_input(BenchmarkId::new("get_owners", size), size, |b, _| {
b.iter(|| {
black_box(ring.get_owners(key));
});
});
group.bench_with_input(BenchmarkId::new("is_owner", size), size, |b, _| {
b.iter(|| {
black_box(ring.is_owner(key, node_name));
});
});
}
group.finish();
}
criterion_group!(benches, bench_consistent_hash);
criterion_main!(benches);

View File

@@ -0,0 +1,260 @@
use std::sync::Arc;
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use smg::{
core::{BasicWorkerBuilder, Worker, WorkerType},
policies::{LoadBalancingPolicy, ManualPolicy, SelectWorkerInfo},
};
use tokio::runtime::Runtime;
// ============================================================================
// Test Helpers
// ============================================================================
fn create_workers(count: usize) -> Vec<Arc<dyn Worker>> {
(0..count)
.map(|i| {
Arc::new(
BasicWorkerBuilder::new(format!("http://worker-{}:8000", i))
.worker_type(WorkerType::Regular)
.build(),
) as Arc<dyn Worker>
})
.collect()
}
fn select_with_key(
rt: &Runtime,
policy: &ManualPolicy,
workers: &[Arc<dyn Worker>],
key: &str,
) -> Option<usize> {
let mut headers = http::HeaderMap::new();
headers.insert("x-smg-routing-key", key.parse().unwrap());
let info = SelectWorkerInfo {
headers: Some(&headers),
..Default::default()
};
rt.block_on(policy.select_worker(workers, &info))
}
fn warmup_keys(rt: &Runtime, policy: &ManualPolicy, workers: &[Arc<dyn Worker>], keys: &[String]) {
for key in keys {
select_with_key(rt, policy, workers, key);
}
}
fn gen_keys(count: usize, prefix: &str) -> Vec<String> {
(0..count).map(|i| format!("{}{}", prefix, i)).collect()
}
// ============================================================================
// Benchmarks
// ============================================================================
fn bench_fast_path_hit(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let mut group = c.benchmark_group("manual_policy/fast_path");
for worker_count in [4, 16, 64, 256] {
let policy = ManualPolicy::new();
let workers = create_workers(worker_count);
let keys = gen_keys(1000, "user-");
warmup_keys(&rt, &policy, &workers, &keys);
group.throughput(Throughput::Elements(1));
group.bench_with_input(
BenchmarkId::new("workers", worker_count),
&worker_count,
|b, _| {
let mut idx = 0;
b.iter(|| {
let result = select_with_key(&rt, &policy, &workers, &keys[idx % keys.len()]);
idx += 1;
black_box(result)
});
},
);
}
group.finish();
}
fn bench_slow_path_vacant(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let mut group = c.benchmark_group("manual_policy/slow_path_vacant");
for worker_count in [4, 16, 64, 256] {
let workers = create_workers(worker_count);
group.throughput(Throughput::Elements(1));
group.bench_with_input(
BenchmarkId::new("workers", worker_count),
&worker_count,
|b, _| {
let policy = ManualPolicy::new();
let mut idx = 0;
b.iter(|| {
let key = format!("new-user-{}", idx);
let result = select_with_key(&rt, &policy, &workers, &key);
idx += 1;
black_box(result)
});
},
);
}
group.finish();
}
fn bench_no_routing_key(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let mut group = c.benchmark_group("manual_policy/no_routing_key");
for worker_count in [4, 16, 64, 256] {
let policy = ManualPolicy::new();
let workers = create_workers(worker_count);
group.throughput(Throughput::Elements(1));
group.bench_with_input(
BenchmarkId::new("workers", worker_count),
&worker_count,
|b, _| {
let info = SelectWorkerInfo::default();
b.iter(|| black_box(rt.block_on(policy.select_worker(&workers, &info))));
},
);
}
group.finish();
}
fn bench_failover(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let mut group = c.benchmark_group("manual_policy/failover");
group.sample_size(50);
for worker_count in [4, 16, 64] {
group.bench_with_input(
BenchmarkId::new("workers", worker_count),
&worker_count,
|b, &count| {
b.iter_with_setup(
|| {
let policy = ManualPolicy::new();
let workers = create_workers(count);
let idx = select_with_key(&rt, &policy, &workers, "failover-test").unwrap();
workers[idx].set_healthy(false);
(policy, workers)
},
|(policy, workers)| {
black_box(select_with_key(&rt, &policy, &workers, "failover-test"))
},
);
},
);
}
group.finish();
}
fn bench_concurrent(c: &mut Criterion) {
let rt = Arc::new(
tokio::runtime::Builder::new_multi_thread()
.worker_threads(4)
.build()
.unwrap(),
);
let mut group = c.benchmark_group("manual_policy/concurrent");
group.sample_size(50);
for num_threads in [2, 4, 8, 16] {
group.bench_with_input(
BenchmarkId::new("threads", num_threads),
&num_threads,
|b, &threads| {
b.iter(|| {
let policy = Arc::new(ManualPolicy::new());
let workers: Arc<Vec<Arc<dyn Worker>>> = Arc::new(create_workers(16));
rt.block_on(async {
let handles: Vec<_> = (0..threads)
.map(|t| {
let policy = Arc::clone(&policy);
let workers = Arc::clone(&workers);
tokio::spawn(async move {
for i in 0..500 {
let key = if i % 5 == 0 {
format!("thread{}_user{}", t, i)
} else {
format!("shared_user{}", i % 50)
};
let mut headers = http::HeaderMap::new();
headers.insert("x-smg-routing-key", key.parse().unwrap());
let info = SelectWorkerInfo {
headers: Some(&headers),
..Default::default()
};
let _ =
black_box(policy.select_worker(&workers, &info).await);
}
})
})
.collect();
for h in handles {
h.await.unwrap();
}
});
});
},
);
}
group.finish();
}
fn bench_cache_size_impact(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let mut group = c.benchmark_group("manual_policy/cache_size");
for cache_size in [100, 1000, 10000, 100000] {
let policy = ManualPolicy::new();
let workers = create_workers(16);
let keys = gen_keys(cache_size, "user-");
warmup_keys(&rt, &policy, &workers, &keys);
group.throughput(Throughput::Elements(1));
group.bench_with_input(BenchmarkId::new("keys", cache_size), &cache_size, |b, _| {
let mut idx = 0;
b.iter(|| {
let result = select_with_key(&rt, &policy, &workers, &keys[idx % keys.len()]);
idx += 1;
black_box(result)
});
});
}
group.finish();
}
fn bench_comparison_baseline(c: &mut Criterion) {
use rand::Rng;
let mut group = c.benchmark_group("manual_policy/vs_baseline");
let workers = create_workers(16);
// Baseline: raw random selection without any policy overhead
group.bench_function("raw_random", |b| {
let mut rng = rand::rng();
b.iter(|| black_box(rng.random_range(0..workers.len())));
});
group.finish();
}
criterion_group!(
benches,
bench_fast_path_hit,
bench_slow_path_vacant,
bench_no_routing_key,
bench_failover,
bench_concurrent,
bench_cache_size_impact,
bench_comparison_baseline,
);
criterion_main!(benches);

View File

@@ -0,0 +1,670 @@
use std::time::Instant;
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use serde_json::{from_str, to_string, to_value, to_vec};
use smg::{
core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType},
protocols::{
chat::{ChatCompletionRequest, ChatMessage, MessageContent},
common::StringOrArray,
completion::CompletionRequest,
generate::GenerateRequest,
sampling_params::SamplingParams,
},
routers::http::pd_types::{generate_room_id, RequestWithBootstrap},
};
fn create_test_worker() -> BasicWorker {
BasicWorkerBuilder::new("http://test-server:8000")
.worker_type(WorkerType::Prefill {
bootstrap_port: Some(5678),
})
.build()
}
// Helper function to get bootstrap info from worker
fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option<u16>) {
let hostname = worker.bootstrap_host().to_string();
let bootstrap_port = worker.bootstrap_port();
(hostname, bootstrap_port)
}
/// Create a default GenerateRequest for benchmarks with minimal fields set
fn default_generate_request() -> GenerateRequest {
GenerateRequest {
text: None,
model: None,
input_ids: None,
input_embeds: None,
image_data: None,
video_data: None,
audio_data: None,
sampling_params: None,
return_logprob: None,
logprob_start_len: None,
top_logprobs_num: None,
token_ids_logprob: None,
return_text_in_logprobs: false,
stream: false,
log_metrics: true,
return_hidden_states: false,
modalities: None,
session_params: None,
lora_path: None,
lora_id: None,
custom_logit_processor: None,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
bootstrap_pair_key: None,
data_parallel_rank: None,
background: false,
conversation_id: None,
priority: None,
extra_key: None,
no_logs: false,
custom_labels: None,
return_bytes: false,
return_entropy: false,
rid: None,
}
}
/// Create a default ChatCompletionRequest for benchmarks with minimal fields set
#[allow(deprecated)]
fn default_chat_completion_request() -> ChatCompletionRequest {
ChatCompletionRequest {
// Required fields in OpenAI order
messages: vec![],
model: String::new(),
// Use default for all other fields
..Default::default()
}
}
/// Create a default CompletionRequest for benchmarks with minimal fields set
fn default_completion_request() -> CompletionRequest {
CompletionRequest {
model: String::new(),
prompt: StringOrArray::String(String::new()),
suffix: None,
max_tokens: None,
temperature: None,
top_p: None,
n: None,
stream: false,
stream_options: None,
logprobs: None,
echo: false,
stop: None,
presence_penalty: None,
frequency_penalty: None,
best_of: None,
logit_bias: None,
user: None,
seed: None,
// SGLang Extensions
top_k: None,
min_p: None,
min_tokens: None,
repetition_penalty: None,
regex: None,
ebnf: None,
json_schema: None,
stop_token_ids: None,
no_stop_trim: false,
ignore_eos: false,
skip_special_tokens: true,
// SGLang Extensions
lora_path: None,
session_params: None,
return_hidden_states: false,
sampling_seed: None,
other: serde_json::Map::new(),
}
}
// Sample request data for benchmarks
fn create_sample_generate_request() -> GenerateRequest {
GenerateRequest {
text: Some("Write a story about artificial intelligence".to_string()),
sampling_params: Some(SamplingParams {
max_new_tokens: Some(100),
temperature: Some(0.8),
top_p: Some(0.9),
top_k: Some(50),
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
repetition_penalty: Some(1.0),
..Default::default()
}),
..default_generate_request()
}
}
#[allow(deprecated)]
fn create_sample_chat_completion_request() -> ChatCompletionRequest {
ChatCompletionRequest {
model: "gpt-3.5-turbo".to_string(),
messages: vec![
ChatMessage::System {
content: MessageContent::Text("You are a helpful assistant".to_string()),
name: None,
},
ChatMessage::User {
content: MessageContent::Text(
"Explain quantum computing in simple terms".to_string(),
),
name: None,
},
],
max_tokens: Some(150),
max_completion_tokens: Some(150),
temperature: Some(0.7),
top_p: Some(1.0),
n: Some(1),
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
parallel_tool_calls: Some(true),
..default_chat_completion_request()
}
}
fn create_sample_completion_request() -> CompletionRequest {
CompletionRequest {
model: "text-davinci-003".to_string(),
prompt: StringOrArray::String("Complete this sentence: The future of AI is".to_string()),
max_tokens: Some(50),
temperature: Some(0.8),
top_p: Some(1.0),
n: Some(1),
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
best_of: Some(1),
..default_completion_request()
}
}
#[allow(deprecated)]
fn create_large_chat_completion_request() -> ChatCompletionRequest {
let mut messages = vec![ChatMessage::System {
content: MessageContent::Text(
"You are a helpful assistant with extensive knowledge.".to_string(),
),
name: None,
}];
// Add many user/assistant pairs to simulate a long conversation
for i in 0..50 {
messages.push(ChatMessage::User {
content: MessageContent::Text(format!("Question {}: What do you think about topic number {} which involves complex reasoning about multiple interconnected systems and their relationships?", i, i)),
name: None,
});
messages.push(ChatMessage::Assistant {
content: Some(MessageContent::Text(format!("Answer {}: This is a detailed response about topic {} that covers multiple aspects and provides comprehensive analysis of the interconnected systems you mentioned.", i, i))),
name: None,
tool_calls: None,
reasoning_content: None,
});
}
ChatCompletionRequest {
model: "gpt-4".to_string(),
messages,
max_tokens: Some(1000),
max_completion_tokens: Some(1000),
temperature: Some(0.7),
top_p: Some(0.95),
n: Some(1),
presence_penalty: Some(0.1),
frequency_penalty: Some(0.1),
top_logprobs: Some(5),
seed: Some(42),
parallel_tool_calls: Some(true),
..default_chat_completion_request()
}
}
// Benchmark JSON serialization
fn bench_json_serialization(c: &mut Criterion) {
let mut group = c.benchmark_group("json_serialization");
let generate_req = create_sample_generate_request();
let chat_req = create_sample_chat_completion_request();
let completion_req = create_sample_completion_request();
let large_chat_req = create_large_chat_completion_request();
group.bench_function("generate_request", |b| {
b.iter(|| {
let json = to_string(black_box(&generate_req)).unwrap();
black_box(json);
});
});
group.bench_function("chat_completion_request", |b| {
b.iter(|| {
let json = to_string(black_box(&chat_req)).unwrap();
black_box(json);
});
});
group.bench_function("completion_request", |b| {
b.iter(|| {
let json = to_string(black_box(&completion_req)).unwrap();
black_box(json);
});
});
group.bench_function("large_chat_completion_request", |b| {
b.iter(|| {
let json = to_string(black_box(&large_chat_req)).unwrap();
black_box(json);
});
});
group.bench_function("generate_request_to_bytes", |b| {
b.iter(|| {
let bytes = to_vec(black_box(&generate_req)).unwrap();
black_box(bytes);
});
});
group.finish();
}
// Benchmark JSON deserialization
fn bench_json_deserialization(c: &mut Criterion) {
let mut group = c.benchmark_group("json_deserialization");
let generate_json = to_string(&create_sample_generate_request()).unwrap();
let chat_json = to_string(&create_sample_chat_completion_request()).unwrap();
let completion_json = to_string(&create_sample_completion_request()).unwrap();
let large_chat_json = to_string(&create_large_chat_completion_request()).unwrap();
group.bench_function("generate_request", |b| {
b.iter(|| {
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
black_box(req);
});
});
group.bench_function("chat_completion_request", |b| {
b.iter(|| {
let req: ChatCompletionRequest = from_str(black_box(&chat_json)).unwrap();
black_box(req);
});
});
group.bench_function("completion_request", |b| {
b.iter(|| {
let req: CompletionRequest = from_str(black_box(&completion_json)).unwrap();
black_box(req);
});
});
group.bench_function("large_chat_completion_request", |b| {
b.iter(|| {
let req: ChatCompletionRequest = from_str(black_box(&large_chat_json)).unwrap();
black_box(req);
});
});
group.finish();
}
// Benchmark bootstrap injection (replaces request adaptation)
fn bench_bootstrap_injection(c: &mut Criterion) {
let mut group = c.benchmark_group("bootstrap_injection");
let generate_req = create_sample_generate_request();
let chat_req = create_sample_chat_completion_request();
let completion_req = create_sample_completion_request();
let large_chat_req = create_large_chat_completion_request();
let worker = create_test_worker();
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
group.bench_function("generate_bootstrap_injection", |b| {
b.iter(|| {
let request_with_bootstrap = RequestWithBootstrap {
original: &generate_req,
bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
black_box(json);
});
});
group.bench_function("chat_completion_bootstrap_injection", |b| {
b.iter(|| {
let request_with_bootstrap = RequestWithBootstrap {
original: &chat_req,
bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
black_box(json);
});
});
group.bench_function("completion_bootstrap_injection", |b| {
b.iter(|| {
let request_with_bootstrap = RequestWithBootstrap {
original: &completion_req,
bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
black_box(json);
});
});
group.bench_function("large_chat_completion_bootstrap_injection", |b| {
b.iter(|| {
let request_with_bootstrap = RequestWithBootstrap {
original: &large_chat_req,
bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
black_box(json);
});
});
group.finish();
}
// Benchmark direct JSON routing (replaces regular routing)
fn bench_direct_json_routing(c: &mut Criterion) {
let mut group = c.benchmark_group("direct_json_routing");
let generate_req = create_sample_generate_request();
let chat_req = create_sample_chat_completion_request();
let completion_req = create_sample_completion_request();
group.bench_function("generate_to_json", |b| {
b.iter(|| {
let json = to_value(black_box(&generate_req)).unwrap();
black_box(json);
});
});
group.bench_function("generate_to_json_string", |b| {
b.iter(|| {
let json = to_string(black_box(&generate_req)).unwrap();
black_box(json);
});
});
group.bench_function("generate_to_bytes", |b| {
b.iter(|| {
let bytes = to_vec(black_box(&generate_req)).unwrap();
black_box(bytes);
});
});
group.bench_function("chat_completion_to_json", |b| {
b.iter(|| {
let json = to_value(black_box(&chat_req)).unwrap();
black_box(json);
});
});
group.bench_function("chat_completion_to_json_string", |b| {
b.iter(|| {
let json = to_string(black_box(&chat_req)).unwrap();
black_box(json);
});
});
group.bench_function("completion_to_json", |b| {
b.iter(|| {
let json = to_value(black_box(&completion_req)).unwrap();
black_box(json);
});
});
group.finish();
}
// Benchmark throughput with different request sizes
fn bench_throughput_by_size(c: &mut Criterion) {
let mut group = c.benchmark_group("throughput_by_size");
// Create requests of different sizes
let small_generate = GenerateRequest {
text: Some("Hi".to_string()),
..default_generate_request()
};
let medium_generate = GenerateRequest {
text: Some("Write a medium length story about AI".repeat(10)),
..default_generate_request()
};
let large_generate = GenerateRequest {
text: Some("Write a very long and detailed story about artificial intelligence and its impact on society".repeat(100)),
..default_generate_request()
};
let worker = create_test_worker();
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
for (name, req) in [
("small", &small_generate),
("medium", &medium_generate),
("large", &large_generate),
] {
let json = to_string(req).unwrap();
let size_bytes = json.len();
let hostname_clone = hostname.clone();
group.throughput(Throughput::Bytes(size_bytes as u64));
group.bench_with_input(BenchmarkId::new("serialize", name), &req, |b, req| {
b.iter(|| {
let json = to_string(black_box(req)).unwrap();
black_box(json);
});
});
group.bench_with_input(
BenchmarkId::new("deserialize", name),
&json,
|b, json_str| {
b.iter(|| {
let req: GenerateRequest = black_box(from_str(json_str)).unwrap();
black_box(req);
});
},
);
group.bench_with_input(
BenchmarkId::new("bootstrap_inject", name),
&req,
move |b, req| {
let hostname = hostname_clone.clone();
b.iter(|| {
let request_with_bootstrap = RequestWithBootstrap {
original: req,
bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let json = to_value(&request_with_bootstrap).unwrap();
black_box(json);
});
},
);
}
group.finish();
}
// Benchmark full round-trip: deserialize -> inject bootstrap -> serialize
fn bench_full_round_trip(c: &mut Criterion) {
let mut group = c.benchmark_group("full_round_trip");
let generate_json = to_string(&create_sample_generate_request()).unwrap();
let chat_json = to_string(&create_sample_chat_completion_request()).unwrap();
let completion_json = to_string(&create_sample_completion_request()).unwrap();
let worker = create_test_worker();
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
group.bench_function("generate_openai_to_pd_pipeline", |b| {
b.iter(|| {
// Deserialize OpenAI request
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
// Create wrapper with bootstrap fields
let request_with_bootstrap = RequestWithBootstrap {
original: &req,
bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
// Serialize final request
let pd_json = to_string(&request_with_bootstrap).unwrap();
black_box(pd_json);
});
});
group.bench_function("chat_completion_openai_to_pd_pipeline", |b| {
b.iter(|| {
let req: ChatCompletionRequest = from_str(black_box(&chat_json)).unwrap();
let request_with_bootstrap = RequestWithBootstrap {
original: &req,
bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let pd_json = to_string(&request_with_bootstrap).unwrap();
black_box(pd_json);
});
});
group.bench_function("completion_openai_to_pd_pipeline", |b| {
b.iter(|| {
let req: CompletionRequest = from_str(black_box(&completion_json)).unwrap();
let request_with_bootstrap = RequestWithBootstrap {
original: &req,
bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let pd_json = to_string(&request_with_bootstrap).unwrap();
black_box(pd_json);
});
});
group.bench_function("generate_direct_json_pipeline", |b| {
b.iter(|| {
// Deserialize OpenAI request
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
// Convert to JSON for direct routing (no bootstrap injection)
let routing_json = to_value(&req).unwrap();
let json_string = to_string(&routing_json).unwrap();
black_box(json_string);
});
});
group.finish();
}
fn benchmark_summary(c: &mut Criterion) {
let group = c.benchmark_group("benchmark_summary");
println!("\nSGLang Model Gateway Performance Benchmark Suite");
println!("=================================================");
// Quick performance overview
let generate_req = create_sample_generate_request();
let worker = create_test_worker();
println!("\nQuick Performance Overview:");
// Measure serialization
let start = Instant::now();
for _ in 0..1000 {
let _ = black_box(to_string(&generate_req).unwrap());
}
let serialize_time = start.elapsed().as_nanos() / 1000;
println!(" * Serialization (avg): {:>8} ns/req", serialize_time);
// Measure deserialization
let json = to_string(&generate_req).unwrap();
let start = Instant::now();
for _ in 0..1000 {
let _: GenerateRequest = black_box(from_str(&json).unwrap());
}
let deserialize_time = start.elapsed().as_nanos() / 1000;
println!(
" * Deserialization (avg): {:>8} ns/req",
deserialize_time
);
// Measure bootstrap injection (replaces adaptation)
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
let start = Instant::now();
for _ in 0..1000 {
let request_with_bootstrap = RequestWithBootstrap {
original: &generate_req,
bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let _ = black_box(to_value(&request_with_bootstrap).unwrap());
}
let inject_time = start.elapsed().as_nanos() / 1000;
println!(" * Bootstrap Injection (avg): {:>6} ns/req", inject_time);
// Calculate ratios
let total_pipeline = serialize_time + deserialize_time + inject_time;
println!(" * Total Pipeline (avg): {:>8} ns/req", total_pipeline);
println!("\nPerformance Insights:");
if deserialize_time > serialize_time * 2 {
println!(" • Deserialization is significantly faster than serialization");
}
if inject_time < serialize_time / 10 {
println!(
" • Bootstrap injection overhead is negligible ({:.1}% of serialization)",
(inject_time as f64 / serialize_time as f64) * 100.0
);
}
if total_pipeline < 100_000 {
println!(" • Total pipeline latency is excellent (< 100μs)");
}
println!("\nSimplification Benefits:");
println!(" • Eliminated complex type conversion layer");
println!(" • Reduced memory allocations");
println!(" • Automatic field preservation (no manual mapping)");
println!(" • Direct JSON manipulation improves performance");
println!("\nRecommendations:");
if serialize_time > deserialize_time {
println!(" • Focus optimization efforts on serialization rather than deserialization");
}
println!(" • PD mode overhead is minimal - safe to use for latency-sensitive workloads");
println!(" • Consider batching small requests to improve overall throughput");
println!("\n{}", "=".repeat(50));
group.finish();
}
criterion_group!(
benches,
benchmark_summary,
bench_json_serialization,
bench_json_deserialization,
bench_bootstrap_injection,
bench_direct_json_routing,
bench_throughput_by_size,
bench_full_round_trip
);
criterion_main!(benches);

View File

@@ -0,0 +1,59 @@
use std::{collections::HashMap, sync::Arc};
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use smg::core::{BasicWorkerBuilder, CircuitBreakerConfig, WorkerRegistry, WorkerType};
// Helper to populate registry
fn setup_registry(count: usize) -> Arc<WorkerRegistry> {
let registry = Arc::new(WorkerRegistry::new());
for i in 0..count {
let mut labels = HashMap::new();
labels.insert("model_id".to_string(), "benchmark-model".to_string());
let worker_type = if i % 2 == 0 {
WorkerType::Regular
} else {
WorkerType::Decode
};
let worker = BasicWorkerBuilder::new(format!("http://worker-{}:8000", i))
.worker_type(worker_type)
.labels(labels)
.circuit_breaker_config(CircuitBreakerConfig::default())
.build();
registry.register(Arc::from(worker));
}
registry
}
fn bench_optimizations(c: &mut Criterion) {
let mut group = c.benchmark_group("Registry Optimizations");
// We test with 5000 workers to simulate high load
let size = 5000;
let registry = setup_registry(size);
// The OLD method (Slow: Allocates vector + Clones ARCs)
group.bench_function(BenchmarkId::new("Old: get_all()", size), |b| {
b.iter(|| {
black_box(registry.get_all());
});
});
// The NEW method (Fast: O(1) Lookup, Zero Allocation)
group.bench_function(
BenchmarkId::new("New: get_worker_distribution()", size),
|b| {
b.iter(|| {
black_box(registry.get_worker_distribution());
});
},
);
group.finish();
}
criterion_group!(benches, bench_optimizations);
criterion_main!(benches);

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,110 @@
use std::sync::Arc;
use axum::{
body::Body,
http::{HeaderMap, Request, Response, StatusCode},
middleware,
response::IntoResponse,
};
use criterion::{criterion_group, criterion_main, Criterion};
use http_body_util::BodyExt;
use smg::{
app_context::AppContext, config::RouterConfig, middleware::wasm_middleware,
protocols::chat::ChatCompletionRequest, routers::RouterTrait, server::AppState,
};
use tokio::runtime::Runtime;
use tower::{Layer, Service};
#[derive(Debug)]
struct MockRouter;
#[async_trait::async_trait]
impl RouterTrait for MockRouter {
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn route_chat(
&self,
_headers: Option<&HeaderMap>,
_body: &ChatCompletionRequest,
_model_id: Option<&str>,
) -> Response<Body> {
StatusCode::OK.into_response()
}
fn router_type(&self) -> &'static str {
"mock"
}
}
/// Mock service that simulates a streaming response with a 500ms delay.
async fn mock_next_streaming(_req: Request<Body>) -> Response<Body> {
let (tx, rx) = tokio::sync::mpsc::channel(16);
tokio::spawn(async move {
// Send first chunk immediately
let _ = tx
.send(Ok::<_, std::io::Error>(bytes::Bytes::from("chunk 1 ")))
.await;
// Simulate generation delay
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
// Send final chunk
let _ = tx
.send(Ok::<_, std::io::Error>(bytes::Bytes::from("chunk 2")))
.await;
});
Response::new(Body::from_stream(
tokio_stream::wrappers::ReceiverStream::new(rx),
))
}
fn bench_wasm_middleware_buffering(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
// Setup AppContext with WASM enabled
let config = RouterConfig::builder().enable_wasm(true).build_unchecked();
let context = rt.block_on(AppContext::from_config(config, 30)).unwrap();
let app_state = Arc::new(AppState {
router: Arc::new(MockRouter),
context: Arc::new(context),
concurrency_queue_tx: None,
router_manager: None,
mesh_handler: None,
mesh_sync_manager: None,
});
c.bench_function("wasm_middleware_pre_fix_latency", |b| {
b.iter(|| {
rt.block_on(async {
let req = Request::builder()
.uri("/v1/chat/completions")
.body(Body::empty())
.unwrap();
// Create the service by applying the middleware layer to the mock streamer
let mut service =
middleware::from_fn_with_state(app_state.clone(), wasm_middleware).layer(
tower::service_fn(|req: Request<Body>| async move {
Ok::<_, std::convert::Infallible>(mock_next_streaming(req).await)
}),
);
// Explicitly poll the service
let response: Response<Body> =
service.call(req).await.expect("Middleware service failed");
// Measure how long it takes to receive the FIRST frame
let mut body = response.into_body();
let _first_frame = body.frame().await;
});
});
});
}
criterion_group! {
name = benches;
config = Criterion::default().sample_size(10);
targets = bench_wasm_middleware_buffering
}
criterion_main!(benches);

View File

@@ -0,0 +1,24 @@
# Build artifacts
target/
lib/
# Compiled binaries
examples/simple/simple
examples/streaming/streaming
# Go build artifacts
*.o
*.a
*.so
*.dylib
# IDE and editor files
.vscode/
.idea/
*.swp
*.swo
*~
# Environment files
.env
.env.local

View File

@@ -0,0 +1,48 @@
[package]
name = "sgl-model-gateway-golang"
version = "0.3.2"
edition = "2021"
[lib]
name = "sgl_model_gateway_go"
crate-type = ["cdylib"]
[dependencies]
tokio = { version = "1.42.0", features = ["full"] }
serde_json = { version = "1.0", default-features = false, features = [
"std",
"preserve_order",
] }
uuid = { version = "1.10", features = ["v4", "serde"] }
once_cell = "1.21.3"
futures-util = "0.3"
tracing = "0.1"
libc = "0.2.179"
[dependencies.sgl-model-gateway]
path = "../.."
default-features = true
[features]
default = []
vendored-openssl = ["sgl-model-gateway/vendored-openssl"]
[profile.release]
opt-level = "z" # Optimize for size
lto = "fat" # Full LTO for smaller binaries
codegen-units = 1 # Better optimization, slower compile
strip = true # Strip debug symbols
[profile.ci]
inherits = "release"
opt-level = 2 # Lighter optimization (still fast runtime, much faster compile)
lto = "thin" # Thin LTO - good balance
codegen-units = 16 # More parallelization for faster builds
strip = true
[profile.dev]
opt-level = 0
debug = 1
split-debuginfo = "unpacked"
incremental = true
codegen-units = 256

View File

@@ -0,0 +1,103 @@
# Makefile for sgl-model-gateway golang bindings
# This builds the Rust FFI library and provides convenience targets for Go development
# Configuration
CARGO_BUILD_DIR ?= $(shell pwd)/target
BUILD_MODE ?= release
LIB_NAME = libsgl_model_gateway_go
# Detect OS
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S),Linux)
LIB_EXT = .so
LD_LIBRARY_PATH_VAR = LD_LIBRARY_PATH
endif
ifeq ($(UNAME_S),Darwin)
LIB_EXT = .dylib
LD_LIBRARY_PATH_VAR = DYLD_LIBRARY_PATH
endif
# Paths
ROOT_DIR := $(shell pwd)
RUST_SRC_DIR := $(ROOT_DIR)/src
LIB_BUILD_DIR := $(CARGO_BUILD_DIR)/$(BUILD_MODE)
LIB_BUILD_PATH := $(LIB_BUILD_DIR)/$(LIB_NAME)$(LIB_EXT)
LIB_EXPORT_DIR := $(ROOT_DIR)/lib
LIB_EXPORT_PATH := $(LIB_EXPORT_DIR)/$(LIB_NAME)$(LIB_EXT)
# Python LDFLAGS (needed for Rust FFI that depends on Python)
PYTHON_LDFLAGS := $(shell python3-config --ldflags --embed 2>/dev/null || python3-config --ldflags 2>/dev/null || echo "")
# CGO flags - use exported lib directory if available, otherwise build directory
LIB_DIR := $(if $(wildcard $(LIB_EXPORT_PATH)),$(LIB_EXPORT_DIR),$(LIB_BUILD_DIR))
export CGO_LDFLAGS = -L$(LIB_DIR) -lsgl_model_gateway_go $(PYTHON_LDFLAGS) -ldl
export $(LD_LIBRARY_PATH_VAR) := $(LIB_DIR):$($(LD_LIBRARY_PATH_VAR))
.PHONY: all build build-dev lib lib-clean clean test examples help run-simple run-streaming check-lib
help:
@echo "Available targets:"
@echo " build - Build release version of Rust FFI library"
@echo " build-dev - Build debug version of Rust FFI library"
@echo " lib - Copy built library to ./lib directory"
@echo " lib-clean - Clean ./lib directory"
@echo " clean - Clean build artifacts"
@echo " test - Run Go tests"
@echo " examples - Build example programs"
@echo " run-simple - Run simple example"
@echo " run-streaming - Run streaming example"
all: build
build:
@echo "Building Rust FFI library (release mode)..."
@CARGO_TARGET_DIR=$(CARGO_BUILD_DIR) cargo build --release --manifest-path Cargo.toml
@echo "Library built at: $(LIB_BUILD_PATH)"
build-dev:
@echo "Building Rust FFI library (debug mode)..."
@CARGO_TARGET_DIR=$(CARGO_BUILD_DIR) cargo build --manifest-path Cargo.toml
@echo "Library built at: $(LIB_BUILD_DIR)/debug/$(LIB_NAME)$(LIB_EXT)"
lib: build
@echo "Copying library to ./lib directory..."
@mkdir -p $(LIB_EXPORT_DIR)
@cp $(LIB_BUILD_PATH) $(LIB_EXPORT_PATH)
@echo "Library exported at: $(LIB_EXPORT_PATH)"
lib-clean:
@echo "Cleaning ./lib directory..."
@rm -rf $(LIB_EXPORT_DIR)
@echo "Lib directory cleaned"
clean: lib-clean
@echo "Cleaning build artifacts..."
@CARGO_TARGET_DIR=$(CARGO_BUILD_DIR) cargo clean --manifest-path Cargo.toml
@echo "Clean complete"
test: build
@echo "Running Go tests..."
@go test ./...
examples: build
@echo "Building example programs..."
@cd examples/simple && go build -o simple main.go
@cd examples/streaming && go build -o streaming main.go
@echo "Examples built"
run-simple: build
@echo "Running simple example..."
@cd examples/simple && bash run.sh
run-streaming: build
@echo "Running streaming example..."
@cd examples/streaming && bash run.sh
# Check if library exists (either in lib dir or build dir)
check-lib:
@if [ ! -f "$(LIB_EXPORT_PATH)" ] && [ ! -f "$(LIB_BUILD_PATH)" ]; then \
echo "Error: Library not found at $(LIB_EXPORT_PATH) or $(LIB_BUILD_PATH)"; \
echo "Run 'make build' or 'make lib' first"; \
exit 1; \
fi
@echo "Library found at: $(LIB_DIR)/$(LIB_NAME)$(LIB_EXT)"

View File

@@ -0,0 +1,482 @@
# SGLang Go gRPC SDK
A high-level Go SDK for interacting with SGLang gRPC API, designed with an OpenAI-style API for familiarity and ease of use.
**Location**: `sgl-model-gateway/bindings/golang/`
## Table of Contents
- [Features](#features)
- [Installation](#installation)
- [Quick Start](#quick-start)
- [Basic Usage](#basic-usage)
- [Streaming Usage](#streaming-usage)
- [Examples](#examples)
- [Configuration](#configuration)
- [API Reference](#api-reference)
- [Testing](#testing)
- [Unit Tests](#unit-tests)
- [Integration Tests](#integration-tests)
- [Benchmarks](#benchmarks)
- [Documentation](#documentation)
- [Development](#development)
- [Troubleshooting](#troubleshooting)
- [License](#license)
## Features
- **OpenAI-style API**: Familiar interface similar to OpenAI Go SDK
- **Streaming Support**: Real-time streaming chat completions
- **Non-streaming Support**: Simple request/response API
- **Tool Calling**: Support for function calling and tool use
- **Type-safe**: Full Go type definitions for requests and responses
- **Comprehensive Testing**: 18+ unit and integration tests
- **Thread-safe**: All public methods are safe for concurrent use
- **Well-documented**: Full API documentation with examples
## Installation
```bash
go get github.com/sglang/sglang-go-grpc-sdk
```
### Sync Dependencies
```bash
cd sgl-model-gateway/bindings/golang
go mod tidy
```
### Build Requirements
- Go 1.21+, Rust toolchain, Python 3.x
## Quick Start
### Benchmark
Run the OpenAI-compatible server and benchmark:
```bash
# Set environment variables
export SGL_TOKENIZER_PATH="/Users/yangyanbo/tokenizer"
export SGL_GRPC_ENDPOINT="grpc://10.109.185.20:8001"
# Run server
cd examples/oai_server
bash run.sh
# Run E2E benchmark
cd ../..
make e2e E2E_MODEL=/work/models/qwencoder-3b E2E_TOKENIZER=/Users/yangyanbo/tokenizer E2E_INPUT_LEN=1024 E2E_OUTPUT_LEN=512
```
## Examples
The SDK includes several examples in the `examples/` directory:
- **simple**: Basic non-streaming chat completion example
- **streaming**: Real-time streaming with performance metrics
### Running Examples
```bash
# Run simple example
cd bindings/golang/examples/simple
bash run.sh
# Run streaming example
cd bindings/golang/examples/streaming
bash run.sh
# Or use Makefile from bindings/golang directory
cd bindings/golang
make run-simple
make run-streaming
```
### Basic Usage (Non-streaming)
```go
package main
import (
"context"
"fmt"
"log"
"github.com/sglang/sglang-go-grpc-sdk"
)
func main() {
// Create client
client, err := sglang.NewClient(sglang.ClientConfig{
Endpoint: "grpc://localhost:20000",
TokenizerPath: "/path/to/tokenizer",
})
if err != nil {
log.Fatal(err)
}
defer client.Close()
// Create completion
resp, err := client.CreateChatCompletion(context.Background(), sglang.ChatCompletionRequest{
Model: "default",
Messages: []sglang.ChatMessage{
{Role: "user", Content: "Hello!"},
},
Stream: false,
})
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.Choices[0].Message.Content)
fmt.Printf("Usage: Prompt=%d, Completion=%d, Total=%d\n",
resp.Usage.PromptTokens,
resp.Usage.CompletionTokens,
resp.Usage.TotalTokens)
}
```
### Streaming Usage
```go
package main
import (
"context"
"fmt"
"io"
"log"
"github.com/sglang/sglang-go-grpc-sdk"
)
func main() {
// Create client
client, err := sglang.NewClient(sglang.ClientConfig{
Endpoint: "grpc://localhost:20000",
TokenizerPath: "/path/to/tokenizer",
})
if err != nil {
log.Fatal(err)
}
defer client.Close()
// Create streaming completion
ctx := context.Background()
stream, err := client.CreateChatCompletionStream(ctx, sglang.ChatCompletionRequest{
Model: "default",
Messages: []sglang.ChatMessage{
{Role: "user", Content: "Tell me a story"},
},
Stream: true,
MaxCompletionTokens: intPtr(500),
})
if err != nil {
log.Fatal(err)
}
defer stream.Close()
// Read streaming response
for {
chunk, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
log.Fatal(err)
}
for _, choice := range chunk.Choices {
if choice.Delta.Content != "" {
fmt.Print(choice.Delta.Content)
}
}
}
fmt.Println() // newline
}
// Helper functions for optional pointer fields
func intPtr(i int) *int {
return &i
}
func float32Ptr(f float32) *float32 {
return &f
}
```
Examples automatically detect the server endpoint and tokenizer path via environment variables or defaults.
## Configuration
### Environment Variables
- `SGL_GRPC_ENDPOINT`: gRPC server endpoint (default: `grpc://localhost:20000`)
- `SGL_TOKENIZER_PATH`: Path to tokenizer directory (required)
- `CARGO_BUILD_DIR`: Rust build output directory (auto-detected if not set)
### ClientConfig
```go
type ClientConfig struct {
// Endpoint is the gRPC endpoint URL (e.g., "grpc://localhost:20000")
// Required field. Must include the scheme (grpc://) and port number.
Endpoint string
// TokenizerPath is the path to the tokenizer directory containing
// tokenizer configuration files (e.g., tokenizer.json, vocab.json)
// Required field.
TokenizerPath string
}
```
## API Reference
### Client Methods
```go
type Client struct {
// Thread-safe client for SGLang gRPC API
}
// Creates a new client with the given configuration
func NewClient(config ClientConfig) (*Client, error)
// Closes the client and releases all resources
func (c *Client) Close() error
// Creates a non-streaming chat completion
func (c *Client) CreateChatCompletion(ctx context.Context, req ChatCompletionRequest) (*ChatCompletionResponse, error)
// Creates a streaming chat completion
func (c *Client) CreateChatCompletionStream(ctx context.Context, req ChatCompletionRequest) (*ChatCompletionStream, error)
```
### Request Types
- `ChatCompletionRequest`: Main request type for chat completions
- Model, Messages, Stream, Temperature, TopP, MaxCompletionTokens, Tools, etc.
- `ChatMessage`: Individual message in a conversation
- Role, Content
- `Tool`: Tool/function definition for function calling
- Type, Function (name, description, parameters)
### Response Types
- `ChatCompletionResponse`: Non-streaming response
- ID, Model, Created, Choices, Usage
- `ChatCompletionStreamResponse`: Streaming response chunk
- Same structure as above but for incremental updates
- `Message`: Complete message with content and tool calls
- `ToolCall`: Tool call information with function and arguments
- `Usage`: Token usage statistics
- PromptTokens, CompletionTokens, TotalTokens
## Testing
The SDK includes comprehensive testing infrastructure with both unit and integration tests.
### Unit Tests
Unit tests are located in `client_test.go` and test individual components without requiring a server.
#### Running Unit Tests
```bash
# Run all unit tests
go test ./...
# Run with verbose output
go test -v ./...
# Run specific test
go test -run TestClientConfig
# Run tests with race detector (detects concurrency issues)
go test -race ./...
# Run with coverage analysis
go test -cover ./...
# Generate detailed coverage report
go test -coverprofile=coverage.out ./...
go tool cover -html=coverage.out -o coverage.html
```
#### Unit Test Coverage
- Configuration validation, type structures, response handling, concurrent operations, and benchmarks
- `client_test.go` - 10 unit tests covering core functionality
### Integration Tests
Integration tests require a running SGLang server and test the full client-server interaction.
#### Prerequisites
1. Start SGLang server: `python -m sglang.launch_server --model-path <model_path>`
2. Set environment variables:
```bash
export SGL_GRPC_ENDPOINT=grpc://localhost:20000
export SGL_TOKENIZER_PATH=/path/to/tokenizer
```
#### Running Integration Tests
```bash
# Run all integration tests
go test -tags=integration ./...
# Run specific integration test
go test -tags=integration -run TestIntegrationNonStreamingCompletion
# Run with verbose output
go test -tags=integration -v ./...
# Run with race detector
go test -tags=integration -race ./...
```
#### Integration Test Coverage
**Test File**: `integration_test.go` - 4 integration tests
- `TestIntegrationNonStreamingCompletion` - Basic non-streaming request/response
- `TestIntegrationStreamingCompletion` - Streaming response handling
- `TestIntegrationConcurrentRequests` - Multiple simultaneous requests
- `TestIntegrationContextCancellation` - Context timeout and cancellation
### Benchmarks
```bash
go test -bench=. -benchmem ./...
```
## Documentation
All public types and functions include comprehensive documentation with usage examples.
### Key Documented Components
- `Client` - Main client with thread-safety notes
- `ClientConfig` - Configuration requirements and validation rules
- `ChatCompletionRequest` - Request structure with field descriptions
- `ChatCompletionResponse` - Response structure and usage
- `ChatCompletionStreamResponse` - Streaming response format
- `Usage` - Token usage information structure
- `Tool`, `Function`, `ToolCall` - Tool call structures
### Viewing Documentation
```bash
godoc -http=:6060
# Visit: http://localhost:6060/pkg/github.com/sglang/sglang-go-grpc-sdk/
```
## Development
```bash
cd bindings/golang
make build # Build Go bindings
go vet ./... # Check code quality
go fmt ./... # Format code
go test -race ./... # Run tests
```
### Project Structure
```
bindings/golang/
├── client.go # Main client implementation
├── client_test.go # Unit tests
├── integration_test.go # Integration tests
├── README.md # This file
├── Makefile # Build automation
├── Cargo.toml # Rust FFI dependencies
├── examples/ # Example programs
│ ├── simple/ # Non-streaming example
│ └── streaming/ # Streaming example
├── src/ # Rust FFI source
│ ├── client.rs # Client FFI
│ ├── stream.rs # Stream handling
│ ├── grpc_converter.rs # Response conversion
│ └── ...
└── internal/ # Internal packages
└── ffi/ # FFI bindings
```
## Troubleshooting
### Missing Dependencies
Run `go mod tidy` to sync dependencies.
### Connection Errors
Ensure SGLang server is running and check `SGL_GRPC_ENDPOINT`.
### Tokenizer Not Found
Set `SGL_TOKENIZER_PATH` environment variable.
2. Verify path contains required files: `ls $SGL_TOKENIZER_PATH`
3. Files should include: `tokenizer.json`, `vocab.json`, `config.json`
### Build Failures
**Error**: `library 'sgl_model_gateway_go' not found`
**Solution**:
1. Rebuild Rust library: `cd sgl-model-gateway/bindings/golang && make build`
2. Or manually with cargo: `cd sgl-model-gateway/bindings/golang && cargo build --release`
3. Set `CARGO_BUILD_DIR` if using non-standard build location
4. Ensure Rust toolchain is installed: `rustup toolchain list`
### Tests Hanging
**Error**: Tests seem to hang indefinitely
**Solution**:
1. Use timeout for hanging tests: `timeout 30s go test ./...`
2. Run with verbose output to see which test hangs: `go test -v ./...`
3. Ensure server is responsive: `grpcurl -plaintext localhost:20000 list`
### Memory Issues
**Error**: Out of memory during tests
**Solution**:
```bash
# Run with memory limit for long-running tests
GODEBUG=madvdontneed=1 go test -timeout 5m ./...
# Monitor memory during tests
watch -n1 'ps aux | grep test'
```
## Contributing
When adding new features:
1. Add comprehensive documentation to public types/functions
2. Include usage examples for complex APIs
3. Add unit tests covering happy path and error cases
4. Add integration tests if server interaction required
5. Ensure code passes `go vet` and `go test -race`
6. Update this README if adding new features
## License
See LICENSE file for details.
---
**Need Help?**
- Check examples in `examples/` directory
- Run tests to see working code: `go test -v ./...`
- Review function documentation: `godoc` or inline comments
- Check troubleshooting section above

View File

@@ -0,0 +1,483 @@
// Package sglang provides a Go SDK for SGLang gRPC API.
//
// SGLang is a fast language model serving framework. This package provides a Go client
// library for interacting with SGLang's gRPC API, following the style of OpenAI's Go SDK.
//
// Basic usage:
//
// client, err := sglang.NewClient(sglang.ClientConfig{
// Endpoint: "grpc://localhost:20000",
// TokenizerPath: "/path/to/tokenizer",
// })
// if err != nil {
// log.Fatal(err)
// }
// defer client.Close()
//
// resp, err := client.CreateChatCompletion(ctx, sglang.ChatCompletionRequest{
// Model: "default",
// Messages: []sglang.ChatMessage{
// {Role: "user", Content: "Hello"},
// },
// })
//
// For streaming responses, use CreateChatCompletionStream instead.
package sglang
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"sync"
"time"
grpcclient "github.com/sglang/sglang-go-grpc-sdk/internal/grpc"
)
// Client is the main client for interacting with SGLang gRPC API.
// It manages the connection to the SGLang server and handles both streaming
// and non-streaming chat completions.
//
// Thread-safe: All public methods are safe for concurrent use.
type Client struct {
endpoint string
tokenizerPath string
grpcClient *grpcclient.GrpcClient // gRPC-based client
mu sync.RWMutex
}
// ClientConfig holds configuration for creating a new client.
type ClientConfig struct {
// Endpoint is the gRPC endpoint URL (e.g., "grpc://localhost:20000").
// Required field. Must include the scheme (grpc://) and port number.
Endpoint string
// TokenizerPath is the path to the tokenizer directory containing
// tokenizer configuration files (e.g., tokenizer.json, vocab.json).
// Required field.
TokenizerPath string
// ChannelBufferSizes configures buffer sizes for internal channels.
// If nil, default values will be used (optimized for high concurrency).
ChannelBufferSizes *ChannelBufferSizes
// Timeouts configures timeout values for various operations.
// If nil, default values will be used.
Timeouts *Timeouts
}
// ChannelBufferSizes configures buffer sizes for internal channels.
// These affect concurrency and memory usage. Larger buffers allow more
// concurrent operations but use more memory.
type ChannelBufferSizes = grpcclient.ChannelBufferSizes
// Timeouts configures timeout values for various operations.
type Timeouts = grpcclient.Timeouts
// defaultChannelBufferSizes returns default channel buffer sizes optimized for high concurrency (10k+).
// These values are designed to handle thousands of concurrent requests without blocking.
func defaultChannelBufferSizes() ChannelBufferSizes {
return ChannelBufferSizes{
ResultJSONChan: 10000, // Increased for high concurrency: each request may produce 200-500 chunks
ErrChan: 100, // Errors are rare, 100 is sufficient
RecvChan: 2000, // Increased for high concurrency: more gRPC responses to buffer
}
}
// defaultTimeouts returns default timeout values.
func defaultTimeouts() Timeouts {
return Timeouts{
KeepaliveTime: 300 * time.Second, // Increased to reduce ping frequency and avoid "too many pings" errors
KeepaliveTimeout: 20 * time.Second,
CloseTimeout: 5 * time.Second,
}
}
// NewClient creates a new SGLang client with the given configuration.
//
// The client maintains a long-lived connection to the SGLang server and should
// be reused for multiple requests. Call Close() to release resources.
//
// Returns an error if:
// - Endpoint is empty
// - TokenizerPath is empty
// - Connection to the server fails
func NewClient(config ClientConfig) (*Client, error) {
if config.Endpoint == "" {
return nil, errors.New("endpoint is required")
}
if config.TokenizerPath == "" {
return nil, errors.New("tokenizer path is required")
}
bufferSizes := defaultChannelBufferSizes()
if config.ChannelBufferSizes != nil {
if config.ChannelBufferSizes.ResultJSONChan > 0 {
bufferSizes.ResultJSONChan = config.ChannelBufferSizes.ResultJSONChan
}
if config.ChannelBufferSizes.ErrChan > 0 {
bufferSizes.ErrChan = config.ChannelBufferSizes.ErrChan
}
if config.ChannelBufferSizes.RecvChan > 0 {
bufferSizes.RecvChan = config.ChannelBufferSizes.RecvChan
}
}
timeouts := defaultTimeouts()
if config.Timeouts != nil {
if config.Timeouts.KeepaliveTime > 0 {
timeouts.KeepaliveTime = config.Timeouts.KeepaliveTime
}
if config.Timeouts.KeepaliveTimeout > 0 {
timeouts.KeepaliveTimeout = config.Timeouts.KeepaliveTimeout
}
if config.Timeouts.CloseTimeout > 0 {
timeouts.CloseTimeout = config.Timeouts.CloseTimeout
}
}
grpcClient, err := grpcclient.NewGrpcClient(config.Endpoint, config.TokenizerPath, bufferSizes, timeouts)
if err != nil {
return nil, fmt.Errorf("failed to create gRPC client: %w", err)
}
return &Client{
endpoint: config.Endpoint,
tokenizerPath: config.TokenizerPath,
grpcClient: grpcClient,
}, nil
}
// Close closes the client and releases all resources.
//
// After Close() is called, the client cannot be used for further requests.
// Calling Close() multiple times is safe and idempotent.
func (c *Client) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.grpcClient != nil {
if err := c.grpcClient.Close(); err != nil {
return err
}
c.grpcClient = nil
}
return nil
}
// ChatCompletionRequest represents a request for chat completion.
// It follows the OpenAI API style for familiar usage.
type ChatCompletionRequest struct {
// Model specifies the model to use for completion (e.g., "default")
Model string `json:"model"`
// Messages is the list of messages in the conversation
Messages []ChatMessage `json:"messages"`
Temperature *float32 `json:"temperature,omitempty"`
TopP *float32 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Stream bool `json:"stream"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice interface{} `json:"tool_choice,omitempty"`
Stop interface{} `json:"stop,omitempty"`
StopTokenIDs []int `json:"stop_token_ids,omitempty"`
SkipSpecialTokens bool `json:"skip_special_tokens,omitempty"`
FrequencyPenalty *float32 `json:"frequency_penalty,omitempty"`
PresencePenalty *float32 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed *int `json:"seed,omitempty"`
Logprobs bool `json:"logprobs,omitempty"`
TopLogprobs *int `json:"top_logprobs,omitempty"`
User string `json:"user,omitempty"`
}
// ChatMessage represents a single message in a chat conversation
type ChatMessage struct {
Role string `json:"role"`
Content interface{} `json:"content"`
Name string `json:"name,omitempty"`
}
// Tool represents a tool/function that can be called
type Tool struct {
Type string `json:"type"`
Function Function `json:"function"`
}
// Function represents a function definition
type Function struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters map[string]interface{} `json:"parameters"`
}
// ResponseFormat represents the response format
type ResponseFormat struct {
Type string `json:"type"`
}
// ChatCompletionResponse represents a non-streaming chat completion response
type ChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
Choices []Choice `json:"choices"`
Usage Usage `json:"usage"`
}
// Choice represents a choice in the completion response
type Choice struct {
Index int `json:"index"`
Message Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
// Message represents a message in the response
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
// ToolCall represents a tool call in the response
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function FunctionCall `json:"function"`
}
// FunctionCall represents a function call
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
// Usage represents token usage information
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
// ChatCompletionStreamResponse represents a streaming chat completion response
type ChatCompletionStreamResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
Choices []StreamChoice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
}
// StreamChoice represents a choice in a streaming response
type StreamChoice struct {
Index int `json:"index"`
Delta MessageDelta `json:"delta"`
FinishReason string `json:"finish_reason,omitempty"`
}
// MessageDelta represents incremental message updates
type MessageDelta struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
// CreateChatCompletion creates a non-streaming chat completion with context support.
//
// Context Support:
// The ctx parameter is fully supported for cancellation and timeouts:
// - If ctx is cancelled, the request will be interrupted on the next stream.RecvJSON() call
// - If ctx times out, the request will return context.DeadlineExceeded
//
// Example with timeout:
//
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
// resp, err := client.CreateChatCompletion(ctx, req)
//
// Note: Internally, this creates a stream and collects all chunks,
// so context monitoring happens at the chunk level.
func (c *Client) CreateChatCompletion(ctx context.Context, req ChatCompletionRequest) (*ChatCompletionResponse, error) {
// For non-streaming, we'll collect all chunks and return the final response
req.Stream = true // We still use streaming internally, but collect all chunks
if len(req.Tools) == 0 {
req.Tools = nil
}
stream, err := c.CreateChatCompletionStream(ctx, req)
if err != nil {
return nil, err
}
defer stream.Close()
var fullContent strings.Builder
var fullToolCalls []ToolCall
var finishReason string
var usage Usage
var responseID string
var created int64
var model string
var systemFingerprint string
for {
chunkJSON, err := stream.RecvJSON()
if err == io.EOF {
break
}
if err != nil {
return nil, err
}
var chunk ChatCompletionStreamResponse
if err := json.Unmarshal([]byte(chunkJSON), &chunk); err != nil {
return nil, fmt.Errorf("failed to parse chunk: %w", err)
}
if chunk.ID != "" {
responseID = chunk.ID
}
if chunk.Created > 0 {
created = chunk.Created
}
if chunk.Model != "" {
model = chunk.Model
}
if chunk.SystemFingerprint != "" {
systemFingerprint = chunk.SystemFingerprint
}
for _, choice := range chunk.Choices {
if choice.Delta.Content != "" {
fullContent.WriteString(choice.Delta.Content)
}
if len(choice.Delta.ToolCalls) > 0 {
fullToolCalls = append(fullToolCalls, choice.Delta.ToolCalls...)
}
if choice.FinishReason != "" {
finishReason = choice.FinishReason
}
}
if chunk.Usage != nil {
usage = *chunk.Usage
}
}
message := Message{
Role: "assistant",
Content: fullContent.String(),
}
if len(fullToolCalls) > 0 {
message.ToolCalls = fullToolCalls
}
if finishReason == "" {
finishReason = "stop"
}
return &ChatCompletionResponse{
ID: responseID,
Object: "chat.completion",
Created: created,
Model: model,
SystemFingerprint: systemFingerprint,
Choices: []Choice{
{
Index: 0,
Message: message,
FinishReason: finishReason,
},
},
Usage: usage,
}, nil
}
// ChatCompletionStream represents a streaming chat completion
type ChatCompletionStream struct {
grpcStream *grpcclient.GrpcChatCompletionStream
ctx context.Context
cancel context.CancelFunc
}
func (s *ChatCompletionStream) RecvJSON() (string, error) {
return s.grpcStream.RecvJSON()
}
// Close closes the stream and cancels any pending operations.
func (s *ChatCompletionStream) Close() error {
if s.cancel != nil {
s.cancel()
}
if s.grpcStream != nil {
return s.grpcStream.Close()
}
return nil
}
// CreateChatCompletionStream creates a streaming chat completion with context cancellation support.
//
// Context Support:
// The ctx parameter is now fully supported for cancellation and timeouts:
// - If ctx is cancelled, stream.RecvJSON() will return context.Canceled on the next call
// - If ctx times out (WithTimeout), stream.RecvJSON() will return context.DeadlineExceeded
// - Calling stream.Close() also cancels the context
//
// Example with timeout:
//
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
// stream, err := client.CreateChatCompletionStream(ctx, req)
// // Stream will auto-close if 30 seconds elapse
//
// Example with cancellation:
//
// ctx, cancel := context.WithCancel(context.Background())
// stream, err := client.CreateChatCompletionStream(ctx, req)
// go func() {
// time.Sleep(5*time.Second)
// cancel() // Cancel after 5 seconds
// }()
func (c *Client) CreateChatCompletionStream(ctx context.Context, req ChatCompletionRequest) (*ChatCompletionStream, error) {
reqJSON, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
var reqMap map[string]interface{}
if err := json.Unmarshal(reqJSON, &reqMap); err != nil {
return nil, fmt.Errorf("failed to unmarshal request to map: %w", err)
}
if _, exists := reqMap["tools"]; !exists {
reqMap["tools"] = []interface{}{}
}
reqJSON, err = json.Marshal(reqMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request map to JSON: %w", err)
}
if c.grpcClient == nil {
return nil, errors.New("gRPC client is closed")
}
grpcStream, err := c.grpcClient.CreateChatCompletionStream(ctx, string(reqJSON))
if err != nil {
return nil, fmt.Errorf("failed to create gRPC stream: %w", err)
}
streamCtx, cancel := context.WithCancel(ctx)
return &ChatCompletionStream{
grpcStream: grpcStream,
ctx: streamCtx,
cancel: cancel,
}, nil
}

View File

@@ -0,0 +1,325 @@
package sglang
import (
"context"
"testing"
)
// TestClientConfig tests ClientConfig validation
func TestClientConfig(t *testing.T) {
tests := []struct {
name string
config ClientConfig
wantErr bool
}{
{
name: "valid config",
config: ClientConfig{
Endpoint: "grpc://localhost:20000",
TokenizerPath: "/path/to/tokenizer",
},
wantErr: false,
},
{
name: "missing endpoint",
config: ClientConfig{
Endpoint: "",
TokenizerPath: "/path/to/tokenizer",
},
wantErr: true,
},
{
name: "missing tokenizer path",
config: ClientConfig{
Endpoint: "grpc://localhost:20000",
TokenizerPath: "",
},
wantErr: true,
},
{
name: "both missing",
config: ClientConfig{
Endpoint: "",
TokenizerPath: "",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewClient(tt.config)
if (err != nil) != tt.wantErr {
t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
// TestChatMessageTypes tests ChatMessage struct and its variants
func TestChatMessageTypes(t *testing.T) {
msg := ChatMessage{
Role: "user",
Content: "Hello",
}
if msg.Role != "user" {
t.Errorf("Expected role 'user', got '%s'", msg.Role)
}
if msg.Content != "Hello" {
t.Errorf("Expected content 'Hello', got '%s'", msg.Content)
}
}
// TestChatCompletionRequestValidation tests ChatCompletionRequest validation
func TestChatCompletionRequestValidation(t *testing.T) {
// Test valid request
req := ChatCompletionRequest{
Model: "default",
Messages: []ChatMessage{
{Role: "user", Content: "test"},
},
Stream: false,
}
if req.Model == "" {
t.Error("Expected model to be set")
}
if len(req.Messages) == 0 {
t.Error("Expected messages to be non-empty")
}
if req.Messages[0].Role != "user" {
t.Errorf("Expected first message role 'user', got '%s'", req.Messages[0].Role)
}
}
// TestClientClose tests that Close can be called multiple times safely
func TestClientClose(t *testing.T) {
// Create a mock client (note: in real tests, you might want to skip this
// if it requires actual server connection)
config := ClientConfig{
Endpoint: "grpc://localhost:20000",
TokenizerPath: "/path/to/tokenizer",
}
// Skip if connection fails (expected in unit test environment)
client, err := NewClient(config)
if err != nil {
t.Skip("Skipping client close test: server not available")
}
// First close should succeed
if err := client.Close(); err != nil {
t.Errorf("First Close() failed: %v", err)
}
// Second close should also succeed (idempotent)
if err := client.Close(); err != nil {
t.Errorf("Second Close() failed: %v", err)
}
}
// TestChatCompletionResponseTypes tests response type structures
func TestChatCompletionResponseTypes(t *testing.T) {
resp := ChatCompletionResponse{
ID: "test-id",
Model: "default",
Created: 1234567890,
Choices: []Choice{
{
Message: Message{
Role: "assistant",
Content: "Hello",
},
FinishReason: "stop",
},
},
Usage: Usage{
PromptTokens: 10,
CompletionTokens: 20,
TotalTokens: 30,
},
}
if resp.ID != "test-id" {
t.Errorf("Expected ID 'test-id', got '%s'", resp.ID)
}
if len(resp.Choices) != 1 {
t.Errorf("Expected 1 choice, got %d", len(resp.Choices))
}
if resp.Choices[0].Message.Content != "Hello" {
t.Errorf("Expected content 'Hello', got '%s'", resp.Choices[0].Message.Content)
}
if resp.Usage.TotalTokens != 30 {
t.Errorf("Expected total tokens 30, got %d", resp.Usage.TotalTokens)
}
}
// TestStreamingResponseTypes tests streaming response structures
func TestStreamingResponseTypes(t *testing.T) {
chunk := ChatCompletionStreamResponse{
ID: "stream-id",
Created: 1234567890,
Choices: []StreamChoice{
{
Index: 0,
Delta: MessageDelta{
Content: "Hello",
},
FinishReason: "",
},
},
}
if chunk.ID != "stream-id" {
t.Errorf("Expected ID 'stream-id', got '%s'", chunk.ID)
}
if len(chunk.Choices) == 0 {
t.Error("Expected at least one choice")
}
if chunk.Choices[0].Delta.Content != "Hello" {
t.Errorf("Expected delta content 'Hello', got '%s'", chunk.Choices[0].Delta.Content)
}
}
// TestToolCallStructure tests Tool and ToolCall structures
func TestToolCallStructure(t *testing.T) {
tool := Tool{
Type: "function",
Function: Function{
Name: "get_weather",
Description: "Get the weather",
Parameters: map[string]interface{}{
"location": "string",
},
},
}
if tool.Type != "function" {
t.Errorf("Expected tool type 'function', got '%s'", tool.Type)
}
if tool.Function.Name != "get_weather" {
t.Errorf("Expected function name 'get_weather', got '%s'", tool.Function.Name)
}
toolCall := ToolCall{
ID: "call-123",
Type: "function",
Function: FunctionCall{
Name: "get_weather",
Arguments: `{"location": "San Francisco"}`,
},
}
if toolCall.ID != "call-123" {
t.Errorf("Expected tool call ID 'call-123', got '%s'", toolCall.ID)
}
}
// TestConcurrentClientOperations tests thread safety
// This is a basic test that just verifies concurrent calls don't panic
func TestConcurrentClientOperations(t *testing.T) {
config := ClientConfig{
Endpoint: "grpc://localhost:20000",
TokenizerPath: "/path/to/tokenizer",
}
client, err := NewClient(config)
if err != nil {
t.Skip("Skipping concurrent operations test: server not available")
}
defer client.Close()
// Try concurrent Close calls (should not panic or race)
done := make(chan bool, 2)
go func() {
client.Close()
done <- true
}()
go func() {
client.Close()
done <- true
}()
<-done
<-done
}
// BenchmarkChatCompletionRequest benchmarks request creation
func BenchmarkChatCompletionRequest(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = ChatCompletionRequest{
Model: "default",
Messages: []ChatMessage{
{Role: "user", Content: "test message"},
},
Stream: false,
Temperature: floatPtr(0.7),
MaxCompletionTokens: intPtr(100),
}
}
}
// Helper functions for benchmarks
func floatPtr(f float32) *float32 {
return &f
}
func intPtr(i int) *int {
return &i
}
// TestContextCancellation tests that cancelled context is handled gracefully.
//
// NOTE: Currently, the FFI layer is blocking and doesn't actively monitor context cancellation.
// This test verifies that the client at least returns an error rather than panicking or
// hanging indefinitely when a pre-cancelled context is passed.
//
// Future: When FFI supports context cancellation (via signals or async operations),
// this test should be updated to assert that the error is context.Canceled or wrapped
// context cancellation error.
func TestContextCancellation(t *testing.T) {
config := ClientConfig{
Endpoint: "grpc://localhost:20000",
TokenizerPath: "/path/to/tokenizer",
}
client, err := NewClient(config)
if err != nil {
t.Skip("Skipping context cancellation test: server not available")
}
defer client.Close()
// Create a pre-cancelled context
ctx, cancel := context.WithCancel(context.Background())
cancel()
req := ChatCompletionRequest{
Model: "default",
Messages: []ChatMessage{
{Role: "user", Content: "test"},
},
}
// Attempt request with cancelled context
// Since FFI is blocking, we expect either:
// 1. An error from the server/network
// 2. The call to complete normally (FFI doesn't check context)
// What we DON'T expect is a panic or indefinite hang
_, err = client.CreateChatCompletion(ctx, req)
if err != nil {
t.Logf("Request with cancelled context returned error: %v", err)
} else {
t.Logf("Request with cancelled context completed (FFI may not support context cancellation)")
}
}

View File

@@ -0,0 +1,239 @@
# Makefile for OAI Server
# Builds binary, runs tests, and provides basic targets
# Configuration
APP_NAME = oai_server
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
BUILD_TIME := $(shell date -u '+%Y-%m-%d_%H:%M:%S')
GIT_COMMIT := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
# Paths
ROOT_DIR := $(shell pwd)
BINDINGS_DIR := $(shell cd $(ROOT_DIR)/../.. && pwd)
BUILD_DIR := $(ROOT_DIR)/build
BINARY := $(BUILD_DIR)/$(APP_NAME)
# Rust FFI library paths
LIB_DIR := $(BINDINGS_DIR)/lib
LIB_NAME = libsgl_model_gateway_go
# Detect OS
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S),Linux)
LIB_EXT = .so
LD_LIBRARY_PATH_VAR = LD_LIBRARY_PATH
ARCH := $(shell uname -m)
ifeq ($(ARCH),x86_64)
GOARCH = amd64
else ifeq ($(ARCH),aarch64)
GOARCH = arm64
endif
endif
ifeq ($(UNAME_S),Darwin)
LIB_EXT = .dylib
LD_LIBRARY_PATH_VAR = DYLD_LIBRARY_PATH
ARCH := $(shell uname -m)
ifeq ($(ARCH),x86_64)
GOARCH = amd64
else ifeq ($(ARCH),arm64)
GOARCH = arm64
endif
endif
# Build flags
LDFLAGS = -X main.Version=$(VERSION) -X main.BuildTime=$(BUILD_TIME) -X main.GitCommit=$(GIT_COMMIT)
GO_BUILD_FLAGS = -ldflags "$(LDFLAGS)"
# Python LDFLAGS (needed for Rust FFI that depends on Python)
PYTHON_LDFLAGS := $(shell python3-config --ldflags --embed 2>/dev/null || python3-config --ldflags 2>/dev/null || python-config --ldflags --embed 2>/dev/null || python-config --ldflags 2>/dev/null || echo "")
# CGO flags
CGO_LDFLAGS = -L$(LIB_DIR) $(PYTHON_LDFLAGS)
.PHONY: all build build-dev test e2e clean help lib run stream check-rust-lib check-server
# E2E test configuration
E2E_HOST ?= localhost
E2E_PORT ?= 8080
E2E_MODEL ?= default
E2E_TOKENIZER ?= $(shell echo $$SGL_TOKENIZER_PATH || echo "./examples/tokenizer")
E2E_NUM_PROMPTS ?= 100
E2E_INPUT_LEN ?= 1024
E2E_OUTPUT_LEN ?= 512
E2E_REQUEST_RATE ?= 20
E2E_MAX_CONCURRENCY ?= 20
E2E_BASE_URL ?= http://$(E2E_HOST):$(E2E_PORT)
help:
@echo "OAI Server Makefile"
@echo ""
@echo "Available targets:"
@echo " lib - Build Rust FFI library"
@echo " build - Build binary (release mode)"
@echo " build-dev - Build binary (debug mode)"
@echo " test - Run tests"
@echo " e2e - Run end-to-end test with bench_serving.py"
@echo " run - Run the server (development)"
@echo " stream - Run streaming example"
@echo " clean - Clean build artifacts"
@echo ""
@echo "E2E test variables:"
@echo " E2E_HOST - OAI Server host (default: localhost)"
@echo " E2E_PORT - OAI Server port (default: 8080)"
@echo " E2E_MODEL - Model name (default: default)"
@echo " E2E_TOKENIZER - Tokenizer path"
@echo " E2E_NUM_PROMPTS - Number of prompts (default: 100)"
@echo " E2E_INPUT_LEN - Input token length (default: 1024)"
@echo " E2E_OUTPUT_LEN - Output token length (default: 512)"
@echo " E2E_REQUEST_RATE - Request rate per second (default: 20)"
@echo " E2E_MAX_CONCURRENCY - Max concurrent requests (default: 20)"
all: build
# Build Rust FFI library
lib:
@echo "Building Rust FFI library..."
@cd $(BINDINGS_DIR) && $(MAKE) lib
@echo "✓ Rust FFI library built"
# Check if Rust FFI library exists
check-rust-lib:
@if [ ! -f "$(LIB_DIR)/$(LIB_NAME)$(LIB_EXT)" ]; then \
echo "Error: Rust FFI library not found at $(LIB_DIR)/$(LIB_NAME)$(LIB_EXT)"; \
echo "Building Rust library..."; \
cd $(BINDINGS_DIR) && $(MAKE) lib; \
fi
@echo "✓ Rust FFI library found"
# Build binary (release)
build: check-rust-lib
@echo "Building $(APP_NAME) (release mode)..."
@mkdir -p $(BUILD_DIR)
@CGO_ENABLED=1 \
CGO_LDFLAGS="$(CGO_LDFLAGS)" \
GOOS=$(shell go env GOOS) \
GOARCH=$(GOARCH) \
go build $(GO_BUILD_FLAGS) -o $(BINARY) .
@echo "✓ Binary built: $(BINARY)"
# Build binary (debug)
build-dev: check-rust-lib
@echo "Building $(APP_NAME) (debug mode)..."
@mkdir -p $(BUILD_DIR)
@CGO_ENABLED=1 \
CGO_LDFLAGS="$(CGO_LDFLAGS)" \
go build -o $(BINARY) .
@echo "✓ Binary built (debug): $(BINARY)"
# Run tests
test: check-rust-lib
@echo "Running tests..."
@CGO_ENABLED=1 \
CGO_LDFLAGS="$(CGO_LDFLAGS)" \
export $(LD_LIBRARY_PATH_VAR)="$(LIB_DIR):$$$(LD_LIBRARY_PATH_VAR)" && \
go test -v ./...
@echo "✓ Tests completed"
# Check if OAI Server is running
check-server:
@echo "Checking if OAI Server is running at $(E2E_BASE_URL)..."
@if curl -s -f $(E2E_BASE_URL)/health > /dev/null 2>&1; then \
echo "✓ OAI Server is running"; \
exit 0; \
else \
echo "✗ OAI Server is not running at $(E2E_BASE_URL)"; \
echo " Start it with: make run"; \
exit 1; \
fi
# Find sglang project root (4 levels up from oai_server)
SGLANG_ROOT := $(shell cd $(ROOT_DIR)/../../../../.. && pwd)
# Run end-to-end test with bench_serving.py
e2e: check-server
@echo "Checking if bench_serving.py is available..."
@if python -m sglang.bench_serving --help > /dev/null 2>&1; then \
echo "✓ Using installed bench_serving.py module"; \
USE_SGLANG_ROOT=false; \
elif [ -f "$(SGLANG_ROOT)/python/sglang/bench_serving.py" ]; then \
echo "✓ Using bench_serving.py from $(SGLANG_ROOT)"; \
USE_SGLANG_ROOT=true; \
else \
echo "✗ bench_serving.py is not available"; \
echo " Install dependencies: pip install aiohttp numpy datasets transformers tqdm pillow pybase64"; \
exit 1; \
fi
@echo "Running end-to-end test with bench_serving.py..."
@echo "Configuration:"
@echo " Server: $(E2E_BASE_URL)"
@if [ "$(E2E_MODEL)" != "default" ]; then \
echo " Model: $(E2E_MODEL)"; \
fi
@if [ -n "$(E2E_TOKENIZER)" ]; then \
echo " Tokenizer: $(E2E_TOKENIZER)"; \
fi
@echo " Prompts: $(E2E_NUM_PROMPTS)"
@echo " Input/Output: $(E2E_INPUT_LEN)/$(E2E_OUTPUT_LEN) tokens"
@echo " Request rate: $(E2E_REQUEST_RATE) req/s"
@echo " Max concurrency: $(E2E_MAX_CONCURRENCY)"
@echo ""
@TOKENIZER_ABS=$$(cd $(ROOT_DIR) && python3 -c "import os; path='$(E2E_TOKENIZER)'; print(os.path.abspath(path) if not os.path.isabs(path) else path)" 2>/dev/null || echo "$(E2E_TOKENIZER)"); \
if [ -n "$(E2E_TOKENIZER)" ]; then \
if [ -n "$$TOKENIZER_ABS" ] && ([ -d "$$TOKENIZER_ABS" ] || [ -f "$$TOKENIZER_ABS" ]); then \
TOKENIZER_ARG="--tokenizer $$TOKENIZER_ABS"; \
else \
TOKENIZER_ARG="--tokenizer $(E2E_TOKENIZER)"; \
fi; \
else \
TOKENIZER_ARG=""; \
fi; \
if [ "$$USE_SGLANG_ROOT" = "true" ]; then \
cd $(SGLANG_ROOT) && PYTHONPATH=$(SGLANG_ROOT)/python:$$PYTHONPATH python python/sglang/bench_serving.py \
--backend sglang-oai-chat \
--base-url $(E2E_BASE_URL) \
$$([ "$(E2E_MODEL)" != "default" ] && echo "--model $(E2E_MODEL)") \
$$TOKENIZER_ARG \
--dataset-name random \
--num-prompts $(E2E_NUM_PROMPTS) \
--random-input-len $(E2E_INPUT_LEN) \
--random-output-len $(E2E_OUTPUT_LEN) \
--request-rate $(E2E_REQUEST_RATE) \
--max-concurrency $(E2E_MAX_CONCURRENCY) \
--warmup-requests 5 \
--disable-tqdm || (echo "✗ E2E test failed"; exit 1); \
else \
python -m sglang.bench_serving \
--backend sglang-oai-chat \
--base-url $(E2E_BASE_URL) \
$$([ "$(E2E_MODEL)" != "default" ] && echo "--model $(E2E_MODEL)") \
$$TOKENIZER_ARG \
--dataset-name random \
--num-prompts $(E2E_NUM_PROMPTS) \
--random-input-len $(E2E_INPUT_LEN) \
--random-output-len $(E2E_OUTPUT_LEN) \
--request-rate $(E2E_REQUEST_RATE) \
--max-concurrency $(E2E_MAX_CONCURRENCY) \
--warmup-requests 5 \
--disable-tqdm || (echo "✗ E2E test failed"; exit 1); \
fi
@echo ""
@echo "✓ E2E test completed"
# Run the server (development)
run: build-dev
@echo "Running server..."
@export $(LD_LIBRARY_PATH_VAR)="$(LIB_DIR):$$$(LD_LIBRARY_PATH_VAR)" && \
$(BINARY)
# Run streaming example
stream: check-rust-lib
@echo "Running streaming example..."
@cd $(BINDINGS_DIR)/examples/streaming && \
export $(LD_LIBRARY_PATH_VAR)="$(LIB_DIR):$$$(LD_LIBRARY_PATH_VAR)" && \
bash run.sh
# Clean build artifacts
clean:
@echo "Cleaning build artifacts..."
@rm -rf $(BUILD_DIR)
@echo "✓ Clean complete"

View File

@@ -0,0 +1,305 @@
# Go SGLang Router - OpenAI Compatible API Server
Go SGLang Router is a high-performance OpenAI-compatible API server that communicates with the SGLang backend via gRPC and performs efficient preprocessing and postprocessing through Rust FFI.
## Features
-**OpenAI API Compatible**: Fully compatible with OpenAI Chat Completions API
-**High Performance**: Low latency and high throughput using gRPC and Rust FFI
-**Streaming Support**: Server-Sent Events (SSE) streaming responses
-**Thread-Safe**: Pre-created tokenizer handle, lock-free concurrency
-**Graceful Shutdown**: Context cancellation mechanism to avoid resource leaks and panics
-**Configurable**: Supports configuring channel buffer sizes and timeout durations
## Architecture Overview
**Important Note**: gRPC mode **still calls FFI**, which is used for:
- **Preprocessing**: chat_template and tokenization (request phase)
- **Postprocessing**: token decoding and tool parsing (response phase)
gRPC is only used for communication with the SGLang backend, while input/output processing completely relies on Rust FFI.
```
┌─────────────────────────────────────────────────────────────────┐
│ HTTP Client │
│ (OpenAI API Format) │
└────────────────────────────┬────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ FastHTTP Server │
│ handlers/chat.go:HandleChatCompletion │
│ - Parse request JSON │
│ - SetBodyStreamWriter (SSE) │
└────────────────────────────┬────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ SGLang Client (client.go) │
│ CreateChatCompletionStream(ctx, req) │
│ - Wraps gRPC client │
└────────────────────────────┬────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ gRPC Client (internal/grpc/client_grpc.go) │
│ CreateChatCompletionStream(ctx, reqJSON) │
│ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ Step 1: FFI Preprocess (Rust FFI) │ │
│ │ - ffi.PreprocessChatRequestWithTokenizer() │ │
│ │ - chat_template application │ │
│ │ - tokenization │ │
│ │ - tool constraints generation │ │
│ │ Returns: PromptText, TokenIDs, ToolConstraintsJSON, │ │
│ │ PromptTokens │ │
│ └────────────────────┬─────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ Step 2: Build gRPC Request │ │
│ │ - Parse request JSON (model, temperature, etc.) │ │
│ │ - Create proto.GenerateRequest │ │
│ │ - Set TokenizedInput (PromptText, TokenIDs) │ │
│ │ - Set SamplingParams (temperature, top_p, top_k, etc.) │ │
│ │ - Set Constraints (from ToolConstraintsJSON) │ │
│ └────────────────────┬─────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ Step 3: Create gRPC Stream │ │
│ │ - client.Generate(generateReq) → gRPC stream │ │
│ │ - Connects to SGLang Backend (Rust) │ │
│ └────────────────────┬─────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ Step 4: Create Converter & BatchPostprocessor │ │
│ │ - ffi.CreateGrpcResponseConverterWithTokenizer() │ │
│ │ - Uses preprocessed.PromptTokens for initial count │ │
│ │ - ffi.NewBatchPostprocessor(batchSize=1, immediate) │ │
│ └────────────────────┬─────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ Step 5: Start readLoop (Background Goroutine) │ │
│ │ - go grpcStream.readLoop() │ │
│ │ - Returns GrpcChatCompletionStream immediately │ │
│ └────────────────────┬─────────────────────────────────────┘ │
└───────────────────────┼────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ GrpcChatCompletionStream.readLoop() │
│ (Background Goroutine) │
│ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ Recv() Goroutine (Dedicated) │ │
│ │ - Continuously calls stream.Recv() │ │
│ │ - Sends results to recvChan (buffered, 2000) │ │
│ │ - Exits on ctx.Done() or error │ │
│ │ - Calls stream.CloseSend() on ctx.Done() │ │
│ └────────────────────┬─────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ Main Loop │ │
│ │ - Reads from recvChan │ │
│ │ - For each proto.GenerateResponse: │ │
│ │ → go processAndSendResponse() (async) │ │
│ │ - protoToJSON() converts proto to JSON string │ │
│ │ - batchPostprocessor.AddChunk(protoJSON) │ │
│ │ → FFI postprocessing (token decoding, tool parsing)│ │
│ │ → Returns OpenAI-format JSON strings │ │
│ │ - Sends JSON to resultJSONChan (buffered, 10000) │ │
│ │ - All operations check ctx.Done() for cancellation │ │
│ │ - On EOF: flush batch, send remaining results, return │ │
│ │ - On error: send to errChan (buffered, 100) │ │
│ │ - defer: cancel ctx, wait goroutines, close channels │ │
│ └────────────────────┬─────────────────────────────────────┘ │
└───────────────────────┼────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ resultJSONChan (Buffered Channel, 10000) │
│ - Contains OpenAI-format JSON strings │
│ - Ready for consumption │
└────────────────────────────┬────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ ChatCompletionStream.RecvJSON() │
│ (client.go:410) │
│ - Direct wrapper: return grpcStream.RecvJSON() │
│ - No intermediate processing │
└────────────────────────────┬────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ FastHTTP SetBodyStreamWriter │
│ (handlers/chat.go:159) │
│ - Loop: stream.RecvJSON() → format SSE → flush │
│ - Format: "data: {json}\n\n" │
│ - Final: "data: [DONE]\n\n" │
│ - Immediate flush after each chunk │
└────────────────────────────┬────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ HTTP Client │
│ (SSE Stream) │
│ Receives: data: {...}\n\n │
└─────────────────────────────────────────────────────────────────┘
```
## Quick Start
### Start Server
```bash
./run.sh
```
The server will start on port `:8080`.
### Usage Example
```bash
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "/path/to/model",
"messages": [{"role": "user", "content": "Hello!"}],
"stream": true
}'
```
## Key Design
### 1. Thread-Safe Tokenizer
- Pre-create `TokenizerHandle` at startup
- Rust side uses `Arc<dyn TokenizerTrait>`, thread-safe
- Lock-free concurrency, eliminating lock contention
### 2. Context Cancellation Mechanism (Graceful Shutdown)
- Use `context.Context` cancellation mechanism
- In `readLoop`'s `defer`: cancel context first, then wait for all goroutines to complete, finally close channels
- `processAndSendResponse` checks `ctx.Done()` at function start, all `select` statements include `case <-s.ctx.Done()`
- Avoids "send on closed channel" panic
### 3. Cancellable Recv()
- Use dedicated goroutine to execute `Recv()`
- Pass results through `recvChan`
- Call `CloseSend()` when context is cancelled to make `Recv()` return error
### 4. Simplified Channel Design
- `resultJSONChan`: Main data channel (gRPC layer)
- `errChan`: Error channel (gRPC layer)
- `recvChan`: Internal communication channel (gRPC layer)
- Removed redundant channels and duplicate reads
## Configuration
### Channel Buffer Sizes
```go
type ChannelBufferSizes struct {
ResultJSONChan int // Default: 10000
ErrChan int // Default: 100
RecvChan int // Default: 2000
}
```
### Timeout Configuration
```go
type Timeouts struct {
KeepaliveTime time.Duration // Default: 300s
KeepaliveTimeout time.Duration // Default: 20s
CloseTimeout time.Duration // Default: 5s
}
```
## Performance Optimizations
1. **Pre-create Tokenizer**: Created at startup to avoid first request latency
2. **Lock-Free Concurrency**: Tokenizer is thread-safe, no locks needed
3. **Lazy Parsing**: JSON parsing deferred until needed
4. **Direct JSON Passing**: `RecvJSON()` avoids parse/serialize overhead
5. **Immediate Batching**: batchSize=1, no delay
6. **Async Processing**: `readLoop` processes in background, doesn't block request handling
7. **Configurable Buffers**: Adjust channel sizes based on concurrency needs
## File Structure
```
sgl-model-gateway/bindings/golang/
├── client.go # High-level client API
├── internal/
│ ├── grpc/
│ │ └── client_grpc.go # gRPC client implementation
│ ├── ffi/ # FFI bindings (Rust)
│ └── proto/ # Protobuf definitions
└── examples/
└── oai_server/
├── handlers/
│ └── chat.go # HTTP request handling
├── models/
│ └── chat.go # Request/response models
└── service/
└── sglang_service.go # Service layer
```
## Error Handling
### Context Cancellation Mechanism
1. **Client disconnects**`SetBodyStreamWriter` detects flush error
2. **Cancel streamCtx**`readLoop` detects `ctx.Done()`
3. **Call stream.CloseSend()**`Recv()` goroutine returns error
4. **readLoop defer executes**:
- Set `closed` flag
- Cancel context (if not already cancelled)
- Wait for all `processAndSendResponse` goroutines to complete (`processWg.Wait()`)
- Close all channels (`resultJSONChan`, `errChan`, `readLoopDone`)
5. **Clean up resources and exit**
### Channel Blocking and Race Condition Prevention
- **Context cancellation mechanism**: All channel sends use `select` statements with `case <-s.ctx.Done()`
- **Graceful exit**: When context is cancelled, all blocking send operations can return immediately
- **WaitGroup synchronization**: `readLoop`'s `defer` uses `processWg.Wait()` to ensure all goroutines complete before closing channels
- **Avoid panic**: Through context cancellation and WaitGroup synchronization, avoids "send on closed channel" panic
## Key Functions
### CreateChatCompletionStream
**Location**: `internal/grpc/client_grpc.go:108`
- Preprocess request (FFI)
- Build gRPC request
- Create converter and batch processor
- Start `readLoop`
### readLoop
**Location**: `internal/grpc/client_grpc.go:290`
- Start Recv() goroutine (continuously calls `stream.Recv()`)
- Process proto responses
- Asynchronously call `processAndSendResponse` (tracked with `processWg`)
- **Graceful shutdown in defer**:
- Set `closed` flag
- Cancel context (if not already cancelled)
- Wait for all `processAndSendResponse` goroutines to complete (`processWg.Wait()`)
- Close all channels (`resultJSONChan`, `errChan`, `readLoopDone`)
### processAndSendResponse
**Location**: `internal/grpc/client_grpc.go:379`
- Check `ctx.Done()` at function start, return immediately if cancelled
- Convert proto to JSON
- Call FFI batch processor
- All `select` statements include `case <-s.ctx.Done()` for graceful shutdown handling
- Send JSON to channel
### RecvJSON
**Location**:
- `internal/grpc/client_grpc.go:412`: gRPC layer implementation
- `client.go:410`: Client wrapper layer
- Read from `resultJSONChan`
- Directly return JSON string, no parsing needed

View File

@@ -0,0 +1,55 @@
package config
import (
"os"
)
// Config holds the application configuration
type Config struct {
Endpoint string
TokenizerPath string
Port string
LogDir string
LogLevel string
}
// Load loads configuration from environment variables with defaults
func Load() *Config {
// Get tokenizer path from environment or use default
tokenizerPath := os.Getenv("SGL_TOKENIZER_PATH")
if tokenizerPath == "" {
tokenizerPath = "../tokenizer"
}
// Get endpoint from environment or use default
endpoint := os.Getenv("SGL_GRPC_ENDPOINT")
if endpoint == "" {
endpoint = "grpc://localhost:20000"
}
// Get port from environment or use default
port := os.Getenv("PORT")
if port == "" {
port = "8080"
}
// Get log directory from environment or use default
logDir := os.Getenv("LOG_DIR")
if logDir == "" {
logDir = "./logs"
}
// Get log level from environment or use default
logLevel := os.Getenv("LOG_LEVEL")
if logLevel == "" {
logLevel = "info"
}
return &Config{
Endpoint: endpoint,
TokenizerPath: tokenizerPath,
Port: port,
LogDir: logDir,
LogLevel: logLevel,
}
}

View File

@@ -0,0 +1,121 @@
/tmp/ShareGPT_V3_unfiltered_cleaned_split.json: 100%|████████████████████| 642M/642M [10:02<00:00, 1.12MB/s]
#Input tokens: 50561
#Output tokens: 25883
Starting warmup with 5 sequences...
Warmup completed with 5 sequences. Starting main benchmark run...
============ Serving Benchmark Result ============
Backend: sglang-oai-chat
Traffic request rate: 20.0
Max request concurrency: 20
Successful requests: 100
Benchmark duration (s): 107.24
Total input tokens: 50561
Total input text tokens: 50561
Total input vision tokens: 0
Total generated tokens: 25883
Total generated tokens (retokenized): 129591
Request throughput (req/s): 0.93
Input token throughput (tok/s): 471.48
Output token throughput (tok/s): 241.36
Total token throughput (tok/s): 712.84
Concurrency: 16.42
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 17609.46
Median E2E Latency (ms): 12343.82
---------------Time to First Token----------------
Mean TTFT (ms): 190.71
Median TTFT (ms): 164.86
P99 TTFT (ms): 397.72
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 162.55
Median TPOT (ms): 63.51
P99 TPOT (ms): 1337.20
---------------Inter-Token Latency----------------
Mean ITL (ms): 25.85
Median ITL (ms): 24.26
P95 ITL (ms): 48.26
P99 ITL (ms): 119.04
Max ITL (ms): 194.58
==================================================
E2E test completed
## Rust
============ Serving Benchmark Result ============
Backend: sglang-oai-chat
Traffic request rate: 20.0
Max request concurrency: 20
Successful requests: 100
Benchmark duration (s): 37.71
Total input tokens: 50561
Total input text tokens: 50561
Total input vision tokens: 0
Total generated tokens: 25883
Total generated tokens (retokenized): 25599
Request throughput (req/s): 2.65
Input token throughput (tok/s): 1340.75
Output token throughput (tok/s): 686.35
Total token throughput (tok/s): 2027.10
Concurrency: 18.58
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 7008.05
Median E2E Latency (ms): 7061.24
---------------Time to First Token----------------
Mean TTFT (ms): 156.09
Median TTFT (ms): 133.81
P99 TTFT (ms): 318.53
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 26.59
Median TPOT (ms): 26.75
P99 TPOT (ms): 29.18
---------------Inter-Token Latency----------------
Mean ITL (ms): 26.71
Median ITL (ms): 23.61
P95 ITL (ms): 66.11
P99 ITL (ms): 115.30
Max ITL (ms): 201.08
==================================================
## golang
#Input tokens: 50561
#Output tokens: 25883
Starting warmup with 5 sequences...
Warmup completed with 5 sequences. Starting main benchmark run...
============ Serving Benchmark Result ============
Backend: sglang-oai-chat
Traffic request rate: 20.0
Max request concurrency: 20
Successful requests: 100
Benchmark duration (s): 34.22
Total input tokens: 50561
Total input text tokens: 50561
Total input vision tokens: 0
Total generated tokens: 22970
Total generated tokens (retokenized): 31740
Request throughput (req/s): 2.92
Input token throughput (tok/s): 1477.70
Output token throughput (tok/s): 671.32
Total token throughput (tok/s): 2149.03
Concurrency: 18.42
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 6303.33
Median E2E Latency (ms): 6294.46
---------------Time to First Token----------------
Mean TTFT (ms): 157.10
Median TTFT (ms): 149.16
P99 TTFT (ms): 251.98
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 26.49
Median TPOT (ms): 27.15
P99 TPOT (ms): 28.73
---------------Inter-Token Latency----------------
Mean ITL (ms): 26.97
Median ITL (ms): 24.61
P95 ITL (ms): 52.39
P99 ITL (ms): 86.52
Max ITL (ms): 194.55
==================================================

View File

@@ -0,0 +1,60 @@
github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.52.0 h1:wqBQpxH71XW0e2g+Og4dzQM8pk34aFYlA1Ga8db7gU0=
github.com/valyala/fasthttp v1.52.0/go.mod h1:hf5C4QnVMkNXMspnsUlfM3WitlgYflyhHYoKol/szxQ=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82 h1:6/3JGEh1C88g7m+qzzTbl3A0FtsLguXieqofVLU/JAo=
golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 h1:M1rk8KBnUsBDg1oPGHNCxG4vc1f49epmTO7xscSajMk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM=
google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,556 @@
package handlers
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"strings"
"time"
sglang "github.com/sglang/sglang-go-grpc-sdk"
"github.com/valyala/fasthttp"
"go.uber.org/zap"
"oai_server/models"
"oai_server/service"
"oai_server/utils"
)
// ChatHandler handles chat completion requests
type ChatHandler struct {
logger *zap.Logger
service *service.SGLangService
}
// NewChatHandler creates a new chat handler
func NewChatHandler(logger *zap.Logger, svc *service.SGLangService) *ChatHandler {
return &ChatHandler{
logger: logger,
service: svc,
}
}
// recvResult holds the result of a RecvJSON() call
type recvResult struct {
chunkJSON string
err error
}
// HandleChatCompletion handles POST /v1/chat/completions
func (h *ChatHandler) HandleChatCompletion(ctx *fasthttp.RequestCtx) {
var req models.ChatRequest
if err := json.Unmarshal(ctx.PostBody(), &req); err != nil {
h.logger.Warn("Invalid chat completion request", zap.Error(err))
utils.RespondError(ctx, 400, fmt.Sprintf("Invalid request: %v", err), "invalid_request_error")
return
}
path := string(ctx.Path())
defer func() {
statusCode := ctx.Response.StatusCode()
if statusCode == 0 {
statusCode = 200
}
h.logHTTPResponse(statusCode, path)
}()
// Convert to SGLang format
messages := make([]sglang.ChatMessage, len(req.Messages))
for i, msg := range req.Messages {
role, roleOk := msg["role"]
content, contentOk := msg["content"]
// Validate role
if !roleOk || role == "" {
h.logger.Warn("Missing or empty role in message", zap.Int("message_index", i))
utils.RespondError(ctx, 400, "Message role is required and cannot be empty", "invalid_request_error")
return
}
// Ensure content is always a string (not null)
// Chat template requires content field to be present, even if empty
// If content is missing or null, use empty string
contentStr := ""
if contentOk && content != "" {
contentStr = content
}
messages[i] = sglang.ChatMessage{
Role: role,
Content: contentStr,
}
}
sglReq := sglang.ChatCompletionRequest{
Model: req.Model,
Messages: messages,
Stream: req.Stream,
}
if req.Temperature != nil {
temp := float32(*req.Temperature)
sglReq.Temperature = &temp
}
if req.TopP != nil {
topP := float32(*req.TopP)
sglReq.TopP = &topP
}
if req.MaxCompletionTokens != nil {
sglReq.MaxCompletionTokens = req.MaxCompletionTokens
} else if req.MaxTokens != nil {
sglReq.MaxCompletionTokens = req.MaxTokens
}
requestCtx := context.Background()
if req.Stream {
h.handleStreamingCompletion(ctx, requestCtx, sglReq)
} else {
h.handleNonStreamingCompletion(ctx, requestCtx, sglReq)
}
}
// isBrokenPipeError checks if the error is a broken pipe error (client disconnected)
func isBrokenPipeError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
return strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "connection closed") ||
strings.Contains(errStr, "write: connection closed")
}
// logHTTPResponse logs HTTP response with colored output
func (h *ChatHandler) logHTTPResponse(statusCode int, path string) {
var statusText string
var colorCode string
switch {
case statusCode >= 200 && statusCode < 300:
colorCode = "\033[32m" // Green
statusText = "OK"
case statusCode >= 300 && statusCode < 400:
colorCode = "\033[33m" // Yellow
statusText = "Redirect"
case statusCode >= 400 && statusCode < 500:
colorCode = "\033[33m" // Yellow
statusText = "Client Error"
case statusCode >= 500:
colorCode = "\033[31m" // Red
statusText = "Server Error"
default:
colorCode = "\033[37m" // White
statusText = "Unknown"
}
resetCode := "\033[0m"
msg := fmt.Sprintf("%s[%d %s]%s %s", colorCode, statusCode, statusText, resetCode, path)
h.logger.Info(msg)
}
func (h *ChatHandler) handleStreamingCompletion(ctx *fasthttp.RequestCtx, requestCtx context.Context, req sglang.ChatCompletionRequest) {
ctx.SetContentType("text/event-stream")
ctx.Response.Header.Set("Cache-Control", "no-cache")
ctx.Response.Header.Set("Connection", "keep-alive")
ctx.Response.Header.Set("X-Accel-Buffering", "no")
ctx.SetStatusCode(200)
var clientDisconnected bool
// Flush timeout: prevent deadlock if client is slow or disconnected
// This timeout should be longer than typical network latency but shorter than client timeout
const flushTimeout = 5 * time.Second
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
streamCtx, cancel := context.WithCancel(context.Background())
defer cancel()
stream, err := h.service.Client().CreateChatCompletionStream(streamCtx, req)
if err != nil {
h.logger.Error("Failed to create chat completion stream",
zap.Error(err),
zap.String("model", req.Model),
)
// Use sendSSEError to send error in consistent format
errInfo, sendErr := h.sendSSEError(w, err)
if sendErr != nil {
h.logger.Warn("Failed to send SSE error", zap.Error(sendErr))
} else if errInfo.IsTimeout {
h.logger.Error("Stream creation timeout", zap.Error(err))
}
return
}
defer func() {
if closeErr := stream.Close(); closeErr != nil {
h.logger.Warn("Failed to close stream", zap.Error(closeErr))
}
}()
// Use a single dedicated goroutine to continuously call RecvJSON() and send results via channel
recvChan := make(chan recvResult, 20)
recvGoroutineDone := make(chan struct{})
go func() {
defer func() {
close(recvChan)
close(recvGoroutineDone)
}()
for {
// Check context before calling RecvJSON() to avoid blocking if context is cancelled
select {
case <-streamCtx.Done():
return
default:
}
// Call RecvJSON() - this may block, but stream.Close() will unblock it
// when context is cancelled (called from main loop)
chunkJSON, err := stream.RecvJSON()
// Check context again after RecvJSON() returns
select {
case <-streamCtx.Done():
return
default:
}
// Send to channel (may block if channel is full)
// If channel is full, this will block until main loop reads from it
// This is acceptable because main loop should be actively reading
select {
case recvChan <- recvResult{chunkJSON: chunkJSON, err: err}:
if err != nil {
// EOF or other error, stop the goroutine
return
}
case <-streamCtx.Done():
// Context cancelled while sending, stop the goroutine
return
}
}
}()
for {
if clientDisconnected {
cancel()
// Close stream immediately to unblock RecvJSON() calls
stream.Close()
return
}
select {
case <-streamCtx.Done():
// Close stream to ensure RecvJSON() goroutine can exit
stream.Close()
return
case result, ok := <-recvChan:
if !ok {
// Channel closed, stream ended
return
}
if result.err == io.EOF {
if !clientDisconnected {
w.WriteString("data: [DONE]\n\n")
// Flush with timeout to prevent deadlock
flushDone := make(chan error, 1)
go func() {
flushDone <- w.Flush()
}()
flushCtx, flushCancel := context.WithTimeout(streamCtx, flushTimeout)
defer flushCancel()
select {
case flushErr := <-flushDone:
if flushErr != nil && !isBrokenPipeError(flushErr) {
h.logger.Warn("Final flush error", zap.Error(flushErr))
}
case <-flushCtx.Done():
if flushCtx.Err() == context.DeadlineExceeded {
h.logger.Warn("Final flush timeout", zap.Duration("timeout", flushTimeout))
}
case <-streamCtx.Done():
// Context cancelled, skip flush
}
}
return
}
if result.err != nil {
if result.err == context.Canceled || result.err == context.DeadlineExceeded {
return
}
// Send error to client before closing
errInfo, sendErr := h.sendSSEError(w, result.err)
if sendErr != nil {
h.logger.Warn("Failed to send SSE error", zap.Error(sendErr))
}
if errInfo.IsTimeout {
h.logger.Error("Stream timeout error", zap.Error(result.err))
} else {
h.logger.Error("Stream error", zap.Error(result.err))
}
return
}
if result.chunkJSON == "" {
continue
}
w.WriteString("data: ")
w.WriteString(result.chunkJSON)
w.WriteString("\n\n")
// Flush with timeout to prevent deadlock:
// If Flush blocks indefinitely (slow client), RecvJSON goroutine may fill recvChan
// and then block trying to send, causing deadlock
// Note: bufio.Writer.Flush() doesn't have a timeout parameter, so we use
// a goroutine + select pattern to implement timeout behavior
flushDone := make(chan error, 1)
go func() {
flushDone <- w.Flush()
}()
flushCtx, flushCancel := context.WithTimeout(streamCtx, flushTimeout)
defer flushCancel()
select {
case err := <-flushDone:
if err != nil {
if isBrokenPipeError(err) {
clientDisconnected = true
cancel()
// Close stream immediately to unblock RecvJSON() calls
stream.Close()
return
}
h.logger.Warn("Flush error", zap.Error(err))
}
case <-flushCtx.Done():
// Flush timeout: client may be slow or disconnected
// Continue processing to avoid deadlock, but mark as disconnected
if flushCtx.Err() == context.DeadlineExceeded {
h.logger.Warn("Flush timeout, client may be slow or disconnected", zap.Duration("timeout", flushTimeout))
}
clientDisconnected = true
cancel()
stream.Close()
return
case <-streamCtx.Done():
// Context cancelled, stop flushing
return
}
}
}
})
}
func (h *ChatHandler) handleNonStreamingCompletion(ctx *fasthttp.RequestCtx, requestCtx context.Context, req sglang.ChatCompletionRequest) {
resp, err := h.service.Client().CreateChatCompletion(requestCtx, req)
if err != nil {
h.logger.Error("Failed to create chat completion",
zap.Error(err),
zap.String("model", req.Model),
)
utils.RespondError(ctx, 500, fmt.Sprintf("Failed to create completion: %v", err), "server_error")
return
}
// Convert to OpenAI format
response := utils.BuildResponseBase(resp.ID, resp.Created, resp.Model)
response["object"] = "chat.completion"
choices := make([]map[string]interface{}, len(resp.Choices))
for i, choice := range resp.Choices {
choiceMap := map[string]interface{}{
"index": choice.Index,
"message": map[string]interface{}{
"role": choice.Message.Role,
"content": choice.Message.Content,
},
"finish_reason": choice.FinishReason,
}
if len(choice.Message.ToolCalls) > 0 {
toolCalls := make([]map[string]interface{}, len(choice.Message.ToolCalls))
for j, tc := range choice.Message.ToolCalls {
toolCalls[j] = map[string]interface{}{
"id": tc.ID,
"type": tc.Type,
"function": map[string]interface{}{"name": tc.Function.Name, "arguments": tc.Function.Arguments},
}
}
choiceMap["message"].(map[string]interface{})["tool_calls"] = toolCalls
}
choices[i] = choiceMap
}
response["choices"] = choices
// Usage is always present (not a pointer)
response["usage"] = map[string]interface{}{
"prompt_tokens": resp.Usage.PromptTokens,
"completion_tokens": resp.Usage.CompletionTokens,
"total_tokens": resp.Usage.TotalTokens,
}
ctx.SetStatusCode(200)
ctx.SetContentType("application/json")
jsonData, _ := json.Marshal(response)
ctx.Write(jsonData)
}
// StreamErrorInfo holds parsed error information
type StreamErrorInfo struct {
Message string
Type string
Code int
IsTimeout bool
}
// parseStreamError parses error type and code
func parseStreamError(err error) StreamErrorInfo {
if err == nil {
return StreamErrorInfo{}
}
errorMsg := err.Error()
// Check timeout error by message prefix
isTimeout := strings.HasPrefix(errorMsg, "stream.Recv() timeout") || strings.Contains(errorMsg, "timeout after")
errorType := "server_error"
errorCode := 500
if isTimeout {
errorType = "timeout_error"
errorCode = 504
}
return StreamErrorInfo{
Message: errorMsg,
Type: errorType,
Code: errorCode,
IsTimeout: isTimeout,
}
}
// formatErrorJSON formats error as OpenAI JSON
func formatErrorJSON(errInfo StreamErrorInfo) string {
errorObj := map[string]interface{}{
"error": map[string]interface{}{
"message": errInfo.Message,
"type": errInfo.Type,
"code": errInfo.Code,
},
}
jsonBytes, _ := json.Marshal(errorObj)
return string(jsonBytes)
}
// sendSSEError sends SSE error response. Callers should log errors.
func (h *ChatHandler) sendSSEError(w *bufio.Writer, err error) (StreamErrorInfo, error) {
errInfo := parseStreamError(err)
errorJSON := formatErrorJSON(errInfo)
w.WriteString("data: ")
w.WriteString(errorJSON)
w.WriteString("\n\n")
if flushErr := w.Flush(); flushErr != nil && !isBrokenPipeError(flushErr) {
h.logger.Warn("Failed to flush error response", zap.Error(flushErr))
return errInfo, flushErr
}
return errInfo, nil
}
// HandleGenerate handles POST /generate (SGLang native API)
func (h *ChatHandler) HandleGenerate(ctx *fasthttp.RequestCtx) {
path := string(ctx.Path())
defer func() {
statusCode := ctx.Response.StatusCode()
if statusCode == 0 {
statusCode = 200
}
h.logHTTPResponse(statusCode, path)
}()
// Parse request body
var req map[string]interface{}
if err := json.Unmarshal(ctx.PostBody(), &req); err != nil {
h.logger.Warn("Invalid generate request", zap.Error(err))
utils.RespondError(ctx, 400, fmt.Sprintf("Invalid request: %v", err), "invalid_request_error")
return
}
// Extract text and sampling_params
text, ok := req["text"].(string)
if !ok || text == "" {
utils.RespondError(ctx, 400, "Missing or invalid 'text' field", "invalid_request_error")
return
}
samplingParams, _ := req["sampling_params"].(map[string]interface{})
if samplingParams == nil {
samplingParams = make(map[string]interface{})
}
// Convert to chat completion format for processing
chatReq := sglang.ChatCompletionRequest{
Model: "default",
Messages: []sglang.ChatMessage{{Role: "user", Content: text}},
Stream: false,
}
// Copy sampling params
if maxNewTokens, ok := samplingParams["max_new_tokens"].(float64); ok {
tokens := int(maxNewTokens)
chatReq.MaxCompletionTokens = &tokens
}
if temp, ok := samplingParams["temperature"].(float64); ok {
temp32 := float32(temp)
chatReq.Temperature = &temp32
}
if topP, ok := samplingParams["top_p"].(float64); ok {
topP32 := float32(topP)
chatReq.TopP = &topP32
}
if topK, ok := samplingParams["top_k"].(float64); ok {
topKInt := int(topK)
chatReq.TopK = &topKInt
}
requestCtx := context.Background()
// Use non-streaming completion for /generate endpoint
resp, err := h.service.Client().CreateChatCompletion(requestCtx, chatReq)
if err != nil {
h.logger.Error("Failed to create completion",
zap.Error(err),
)
utils.RespondError(ctx, 500, fmt.Sprintf("Failed to create completion: %v", err), "server_error")
return
}
// Convert to SGLang /generate response format
// meta_info must match SGLang's expected format with completion_tokens at top level
finishReason := resp.Choices[0].FinishReason
if finishReason == "" {
finishReason = "stop"
}
response := map[string]interface{}{
"text": resp.Choices[0].Message.Content,
"meta_info": map[string]interface{}{
"id": resp.ID,
"finish_reason": finishReason,
"prompt_tokens": resp.Usage.PromptTokens,
"completion_tokens": resp.Usage.CompletionTokens,
"cached_tokens": 0, // Not available from chat completion API
"weight_version": "", // Not available from chat completion API
},
}
ctx.SetStatusCode(200)
ctx.SetContentType("application/json")
jsonData, _ := json.Marshal(response)
ctx.Write(jsonData)
}

View File

@@ -0,0 +1,33 @@
package handlers
import (
"encoding/json"
"github.com/valyala/fasthttp"
"go.uber.org/zap"
)
// HealthHandler handles health check requests
type HealthHandler struct {
logger *zap.Logger
}
// NewHealthHandler creates a new health handler
func NewHealthHandler(logger *zap.Logger) *HealthHandler {
return &HealthHandler{
logger: logger,
}
}
// Check handles GET /health
func (h *HealthHandler) Check(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(200)
ctx.SetContentType("application/json")
response := map[string]string{
"status": "ok",
}
jsonData, _ := json.Marshal(response)
ctx.Write(jsonData)
}

View File

@@ -0,0 +1,67 @@
package handlers
import (
"encoding/json"
"github.com/valyala/fasthttp"
"go.uber.org/zap"
)
// ModelsHandler handles model list requests
type ModelsHandler struct {
logger *zap.Logger
tokenizerPath string
}
// NewModelsHandler creates a new models handler
func NewModelsHandler(logger *zap.Logger, tokenizerPath string) *ModelsHandler {
return &ModelsHandler{
logger: logger,
tokenizerPath: tokenizerPath,
}
}
// List handles GET /v1/models
func (h *ModelsHandler) List(ctx *fasthttp.RequestCtx) {
// Return a default model for OpenAI compatibility
ctx.SetStatusCode(200)
ctx.SetContentType("application/json")
response := map[string]interface{}{
"object": "list",
"data": []map[string]interface{}{
{
"id": "default",
"object": "model",
"created": 1677610602,
"owned_by": "sglang",
},
},
}
jsonData, _ := json.Marshal(response)
ctx.Write(jsonData)
}
// GetModelInfo handles GET /get_model_info
// Returns model information compatible with SGLang RuntimeEndpoint
func (h *ModelsHandler) GetModelInfo(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(200)
ctx.SetContentType("application/json")
// Return model info compatible with SGLang RuntimeEndpoint expectations
response := map[string]interface{}{
"model_path": h.tokenizerPath, // Use tokenizer path as model path
"tokenizer_path": h.tokenizerPath,
"is_generation": true,
"preferred_sampling_params": "",
"weight_version": "",
"has_image_understanding": false,
"has_audio_understanding": false,
"model_type": "",
"architectures": nil,
}
jsonData, _ := json.Marshal(response)
ctx.Write(jsonData)
}

View File

@@ -0,0 +1,67 @@
package logger
import (
"os"
"path/filepath"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gopkg.in/natefinch/lumberjack.v2"
)
// Init initializes the logger with file and console output
func Init(logDir, logLevel string) (*zap.Logger, error) {
// Ensure log directory exists
if err := os.MkdirAll(logDir, 0755); err != nil {
return nil, err
}
// Parse log level
var level zapcore.Level
if err := level.UnmarshalText([]byte(logLevel)); err != nil {
level = zapcore.InfoLevel
}
// Create log file path with date
logFile := filepath.Join(logDir, "oai_server-"+time.Now().Format("2006-01-02")+".log")
// File writer with rotation
fileWriter := zapcore.AddSync(&lumberjack.Logger{
Filename: logFile,
MaxSize: 100, // megabytes
MaxBackups: 10,
MaxAge: 30, // days
Compress: true,
})
// Console writer
consoleWriter := zapcore.AddSync(os.Stdout)
// Encoder config
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.TimeKey = "timestamp"
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
// Create cores
fileCore := zapcore.NewCore(
zapcore.NewJSONEncoder(encoderConfig),
fileWriter,
level,
)
consoleCore := zapcore.NewCore(
zapcore.NewConsoleEncoder(encoderConfig),
consoleWriter,
level,
)
// Combine cores
core := zapcore.NewTee(fileCore, consoleCore)
// Create logger
logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel))
return logger, nil
}

View File

@@ -0,0 +1,116 @@
// OpenAI-compatible chat server using SGLang Go SDK and fasthttp framework
package main
import (
"fmt"
"net/http"
"os"
_ "net/http/pprof" // Enable pprof endpoints
"github.com/valyala/fasthttp"
"go.uber.org/zap"
"oai_server/config"
"oai_server/handlers"
"oai_server/logger"
"oai_server/service"
)
// Version information (set at build time via ldflags)
var (
Version = "dev"
BuildTime = "unknown"
GitCommit = "unknown"
)
func main() {
// Load configuration
cfg := config.Load()
// Initialize logger
appLogger, err := logger.Init(cfg.LogDir, cfg.LogLevel)
if err != nil {
panic(fmt.Sprintf("Failed to initialize logger: %v", err))
}
defer appLogger.Sync()
appLogger.Info("Starting OpenAI-compatible server",
zap.String("endpoint", cfg.Endpoint),
zap.String("tokenizer", cfg.TokenizerPath),
zap.String("port", cfg.Port),
)
// Initialize SGLang service
sglangService, err := service.NewSGLangService(cfg.Endpoint, cfg.TokenizerPath)
if err != nil {
appLogger.Fatal("Failed to create SGLang client", zap.Error(err))
}
defer sglangService.Close()
appLogger.Info("SGLang client created successfully")
// Enable pprof if requested
if os.Getenv("PPROF_ENABLED") == "true" {
pprofPort := os.Getenv("PPROF_PORT")
if pprofPort == "" {
pprofPort = "6060"
}
go func() {
pprofAddr := ":" + pprofPort
appLogger.Info("Starting pprof server", zap.String("address", pprofAddr))
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
appLogger.Error("pprof server failed", zap.Error(err))
}
}()
appLogger.Info("pprof enabled", zap.String("port", pprofPort), zap.String("endpoint", fmt.Sprintf("http://localhost:%s/debug/pprof/", pprofPort)))
}
// Initialize handlers
healthHandler := handlers.NewHealthHandler(appLogger)
modelsHandler := handlers.NewModelsHandler(appLogger, cfg.TokenizerPath)
chatHandler := handlers.NewChatHandler(appLogger, sglangService)
// Setup fasthttp router
router := func(ctx *fasthttp.RequestCtx) {
path := string(ctx.Path())
method := string(ctx.Method())
switch {
case method == "GET" && path == "/health":
healthHandler.Check(ctx)
case method == "GET" && path == "/v1/models":
modelsHandler.List(ctx)
case method == "GET" && path == "/get_model_info":
modelsHandler.GetModelInfo(ctx)
case method == "POST" && path == "/v1/chat/completions":
chatHandler.HandleChatCompletion(ctx)
case (method == "POST" || method == "PUT") && path == "/generate":
chatHandler.HandleGenerate(ctx)
default:
ctx.Error("Not Found", fasthttp.StatusNotFound)
}
}
// Start server
serverAddr := ":" + cfg.Port
baseURL := fmt.Sprintf("http://localhost:%s", cfg.Port)
appLogger.Info("Server starting",
zap.String("address", serverAddr),
zap.String("base_url", baseURL),
)
// Print available HTTP endpoints (similar to FastAPI startup)
appLogger.Info("Available HTTP endpoints:")
appLogger.Info(fmt.Sprintf(" GET %s/health", baseURL))
appLogger.Info(fmt.Sprintf(" GET %s/v1/models", baseURL))
appLogger.Info(fmt.Sprintf(" GET %s/get_model_info", baseURL))
appLogger.Info(fmt.Sprintf(" POST %s/v1/chat/completions", baseURL))
appLogger.Info(fmt.Sprintf(" POST %s/generate", baseURL))
appLogger.Info(fmt.Sprintf("Application startup complete. Listening on %s", baseURL))
if err := fasthttp.ListenAndServe(serverAddr, router); err != nil {
appLogger.Fatal("Server failed", zap.Error(err))
}
}

View File

@@ -0,0 +1,14 @@
package models
// ChatRequest represents an OpenAI-compatible chat completion request
type ChatRequest struct {
Model string `json:"model" binding:"required"`
Messages []map[string]string `json:"messages" binding:"required"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"` // OpenAI API standard field
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` // SGLang-specific field (used by bench_serving.py)
Tools []map[string]interface{} `json:"tools,omitempty"`
ToolChoice interface{} `json:"tool_choice,omitempty"`
}

View File

@@ -0,0 +1,111 @@
#!/bin/bash
# OpenAI-compatible server runner
# Usage: ./run.sh [tokenizer_path] [endpoint] [port] [--profile] [--pprof-port PORT]
#
# Options:
# --profile Enable pprof profiling (default port: 6060)
# --pprof-port PORT Set pprof port (default: 6060, requires --profile)
# Set library path for Rust FFI library
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
BINDINGS_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)"
LIB_DIR="${BINDINGS_DIR}/lib"
if [ ! -d "$LIB_DIR" ]; then
echo "Error: Library directory not found at $LIB_DIR"
echo "Please run 'make lib' first to build and export the library"
exit 1
fi
# Get Python LDFLAGS (needed for Rust FFI that depends on Python)
PYTHON_LDFLAGS=$(python3-config --ldflags --embed 2>/dev/null || python3-config --ldflags 2>/dev/null || echo "")
# Set CGO_LDFLAGS to link with the Rust library
# Note: -lsgl_model_gateway_go and -ldl are already in the #cgo directive in internal/ffi/client.go
# We only need to add the library path (-L) and Python flags
export CGO_LDFLAGS="-L${LIB_DIR} ${PYTHON_LDFLAGS}"
# macOS uses DYLD_LIBRARY_PATH, Linux uses LD_LIBRARY_PATH
if [[ "$OSTYPE" == "darwin"* ]]; then
export DYLD_LIBRARY_PATH="${LIB_DIR}:${DYLD_LIBRARY_PATH}"
else
export LD_LIBRARY_PATH="${LIB_DIR}:${LD_LIBRARY_PATH}"
fi
# Parse arguments
ENABLE_PROFILE=false
PPROF_PORT="6060"
TOKENIZER_PATH=""
ENDPOINT=""
PORT=""
while [[ $# -gt 0 ]]; do
case $1 in
--profile)
ENABLE_PROFILE=true
shift
;;
--pprof-port)
ENABLE_PROFILE=true
PPROF_PORT="$2"
shift 2
;;
*)
if [[ -z "$TOKENIZER_PATH" ]]; then
TOKENIZER_PATH="$1"
elif [[ -z "$ENDPOINT" ]]; then
ENDPOINT="$1"
elif [[ -z "$PORT" ]]; then
PORT="$1"
fi
shift
;;
esac
done
# Default configuration
DEFAULT_TOKENIZER_PATH="${SGL_TOKENIZER_PATH:-../tokenizer}"
DEFAULT_ENDPOINT="${SGL_GRPC_ENDPOINT:-grpc://localhost:20000}"
DEFAULT_PORT="${PORT:-8080}"
TOKENIZER_PATH="${TOKENIZER_PATH:-${DEFAULT_TOKENIZER_PATH}}"
ENDPOINT="${ENDPOINT:-${DEFAULT_ENDPOINT}}"
PORT="${PORT:-${DEFAULT_PORT}}"
echo "Running OpenAI-compatible server..."
echo "Library path: ${LIB_DIR}"
echo "Tokenizer: $TOKENIZER_PATH"
echo "Endpoint: $ENDPOINT"
echo "Port: $PORT"
echo "Client Mode: gRPC (default)"
echo "FFI Postprocessing: ENABLED (normal mode)"
echo "FFI Preprocessing: ENABLED (normal mode)"
if [[ "$ENABLE_PROFILE" == "true" ]]; then
echo "Profiling: enabled (port: $PPROF_PORT)"
echo " pprof endpoint: http://localhost:$PPROF_PORT/debug/pprof/"
export PPROF_ENABLED=true
export PPROF_PORT="$PPROF_PORT"
else
echo "Profiling: disabled"
fi
echo ""
# Change to script directory
cd "$(dirname "${BASH_SOURCE[0]}")"
# Ensure Go module is properly initialized
if [ ! -f "go.mod" ]; then
echo "Error: go.mod not found in $(pwd)"
exit 1
fi
# Ensure Go modules are enabled
export GO111MODULE=on
# Sync Go module dependencies
echo "Syncing Go module dependencies..."
go mod tidy
# Run the server (use ./main.go to ensure module context is correct)
SGL_TOKENIZER_PATH="$TOKENIZER_PATH" SGL_GRPC_ENDPOINT="$ENDPOINT" PORT="$PORT" go run ./main.go

View File

@@ -0,0 +1,554 @@
#!/bin/bash
# TPOT performance bottleneck analysis script
# Specifically designed to analyze why Go Router is twice as slow as Rust Router
#
# Usage:
# ./scripts/analyze_tpot.sh [options]
#
# Options:
# --duration SECONDS CPU profile duration (default: 60)
# --requests NUM Number of requests (default: 100)
# --concurrency NUM Concurrency level (default: 20)
# --pprof-port PORT pprof port (default: 6060)
# --server-url URL Server URL (default: http://localhost:8080)
set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
PROFILE_DIR="${PROJECT_ROOT}/profiles"
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
OUTPUT_DIR="${PROFILE_DIR}/tpot_analysis_${TIMESTAMP}"
# Colors
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
BLUE='\033[0;34m'
NC='\033[0m'
# Default values
DURATION=${DURATION:-60}
NUM_REQUESTS=${NUM_REQUESTS:-100}
CONCURRENCY=${CONCURRENCY:-20}
PPROF_PORT=${PPROF_PORT:-6060}
SERVER_URL=${SERVER_URL:-http://localhost:8080}
# Parse arguments
while [[ $# -gt 0 ]]; do
case $1 in
--duration)
DURATION="$2"
shift 2
;;
--requests)
NUM_REQUESTS="$2"
shift 2
;;
--concurrency)
CONCURRENCY="$2"
shift 2
;;
--pprof-port)
PPROF_PORT="$2"
shift 2
;;
--server-url)
SERVER_URL="$2"
shift 2
;;
*)
echo "Unknown option: $1"
exit 1
;;
esac
done
mkdir -p "$OUTPUT_DIR"
# Check for graphviz (optional, needed for some pprof visualizations)
HAS_GRAPHVIZ=false
if command -v dot >/dev/null 2>&1; then
HAS_GRAPHVIZ=true
fi
echo -e "${BLUE}========================================${NC}"
echo -e "${BLUE}TPOT Performance Bottleneck Analysis${NC}"
echo -e "${BLUE}========================================${NC}"
echo ""
echo "Configuration:"
echo " Duration: ${DURATION}s"
echo " Requests: $NUM_REQUESTS"
echo " Concurrency: $CONCURRENCY"
echo " Server URL: $SERVER_URL"
echo " pprof Port: $PPROF_PORT"
echo " Output Dir: $OUTPUT_DIR"
if [ "$HAS_GRAPHVIZ" = "false" ]; then
echo ""
echo -e "${YELLOW}Note: graphviz not found. Some pprof visualizations may not work.${NC}"
echo -e "${YELLOW}To install graphviz:${NC}"
echo -e "${YELLOW} macOS: brew install graphviz${NC}"
echo -e "${YELLOW} Ubuntu: sudo apt-get install graphviz${NC}"
echo -e "${YELLOW} CentOS: sudo yum install graphviz${NC}"
echo -e "${YELLOW}Text reports will still be generated without graphviz.${NC}"
fi
echo ""
# Check if server is running
echo -e "${YELLOW}[Check] Verifying server is running...${NC}"
if ! curl -s "${SERVER_URL}/health" > /dev/null 2>&1; then
echo -e "${RED}Error: Server not responding at ${SERVER_URL}${NC}"
echo ""
echo "Please start the server first with profiling enabled:"
echo " ./run.sh --profile --pprof-port $PPROF_PORT"
echo " or"
echo " PPROF_ENABLED=true PPROF_PORT=$PPROF_PORT make run"
exit 1
fi
echo -e "${GREEN}✓ Server is running${NC}"
echo ""
# Check if pprof is enabled
echo -e "${YELLOW}[Check] Verifying pprof is enabled...${NC}"
if ! curl -s "http://localhost:${PPROF_PORT}/debug/pprof/" > /dev/null 2>&1; then
echo -e "${RED}Error: pprof not accessible at http://localhost:${PPROF_PORT}/debug/pprof/${NC}"
echo ""
echo "Please start the server with profiling enabled:"
echo " ./run.sh --profile --pprof-port $PPROF_PORT"
exit 1
fi
echo -e "${GREEN}✓ pprof is enabled${NC}"
echo ""
# ============================================
# Step 1: Collect baseline profiles
# ============================================
echo -e "${GREEN}[Step 1/8] Collecting baseline profiles...${NC}"
# Baseline memory
go tool pprof -proto -output="${OUTPUT_DIR}/heap_before.pb.gz" \
"http://localhost:${PPROF_PORT}/debug/pprof/heap" > /dev/null 2>&1 || true
# Baseline goroutine
go tool pprof -proto -output="${OUTPUT_DIR}/goroutine_before.pb.gz" \
"http://localhost:${PPROF_PORT}/debug/pprof/goroutine" > /dev/null 2>&1 || true
echo -e "${GREEN}✓ Baseline profiles collected${NC}"
echo ""
# ============================================
# Step 2: Start CPU profile collection
# ============================================
echo -e "${GREEN}[Step 2/8] Starting CPU profile collection (${DURATION}s)...${NC}"
go tool pprof -proto -output="${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" \
"http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=${DURATION}" &
CPU_PID=$!
sleep 2
echo -e "${GREEN}✓ CPU profile collection started${NC}"
echo ""
# ============================================
# Step 3: Run load test with streaming requests
# ============================================
echo -e "${GREEN}[Step 3/8] Running load test ($NUM_REQUESTS streaming requests, concurrency=$CONCURRENCY)...${NC}"
# Function to run a single streaming request
run_streaming_request() {
local request_id=$1
local start_time=$(date +%s)
local start_nanos=$(date +%N 2>/dev/null || echo "000000000")
curl -N -s -X POST "${SERVER_URL}/v1/chat/completions" \
-H "Content-Type: application/json" \
-d "{
\"model\": \"default\",
\"messages\": [{\"role\": \"user\", \"content\": \"Write a 500-word story with character dialogue and scene descriptions\"}],
\"stream\": true,
\"max_tokens\": 300,
\"temperature\": 0.7
}" > /dev/null
local end_time=$(date +%s)
local end_nanos=$(date +%N 2>/dev/null || echo "000000000")
local duration=$((end_time - start_time))
echo "$duration" >> "${OUTPUT_DIR}/request_times.txt"
}
# Run requests with controlled concurrency
# Use a temporary file to track job PIDs to avoid conflicts with CPU_PID
JOB_PIDS_FILE="${OUTPUT_DIR}/.job_pids_$$"
> "$JOB_PIDS_FILE"
for i in $(seq 1 $NUM_REQUESTS); do
# Wait if we've reached concurrency limit
while [ $(wc -l < "$JOB_PIDS_FILE" 2>/dev/null || echo 0) -ge $CONCURRENCY ]; do
# Check and remove completed jobs
while IFS= read -r pid; do
if [ -n "$pid" ] && ! kill -0 "$pid" 2>/dev/null; then
# Process completed, remove from file
grep -v "^${pid}$" "$JOB_PIDS_FILE" > "${JOB_PIDS_FILE}.tmp" && \
mv "${JOB_PIDS_FILE}.tmp" "$JOB_PIDS_FILE" || true
fi
done < "$JOB_PIDS_FILE"
sleep 0.1
done
# Start new request
run_streaming_request $i &
echo $! >> "$JOB_PIDS_FILE"
# Progress indicator
if [ $((i % 10)) -eq 0 ]; then
echo " Progress: $i/$NUM_REQUESTS requests sent..."
fi
done
# Wait for all remaining jobs (excluding CPU_PID)
while IFS= read -r pid; do
if [ -n "$pid" ] && [ "$pid" != "$CPU_PID" ]; then
wait "$pid" 2>/dev/null || true
fi
done < "$JOB_PIDS_FILE"
# Clean up
rm -f "$JOB_PIDS_FILE" "${JOB_PIDS_FILE}.tmp" 2>/dev/null || true
echo -e "${GREEN}✓ Load test completed${NC}"
echo ""
# ============================================
# Step 4: Wait for CPU profile to complete
# ============================================
echo -e "${GREEN}[Step 4/8] Waiting for CPU profile to complete...${NC}"
# Wait for the process, but handle the case where it might have already completed
if kill -0 $CPU_PID 2>/dev/null; then
wait $CPU_PID 2>/dev/null || true
else
# Process already completed, just wait a bit to ensure file is written
sleep 1
fi
echo -e "${GREEN}✓ CPU profile collection completed${NC}"
echo ""
# ============================================
# Step 5: Collect final profiles
# ============================================
echo -e "${GREEN}[Step 5/8] Collecting final profiles...${NC}"
# Final memory
go tool pprof -proto -output="${OUTPUT_DIR}/heap_after.pb.gz" \
"http://localhost:${PPROF_PORT}/debug/pprof/heap" > /dev/null 2>&1 || true
# Final goroutine
go tool pprof -proto -output="${OUTPUT_DIR}/goroutine_after.pb.gz" \
"http://localhost:${PPROF_PORT}/debug/pprof/goroutine" > /dev/null 2>&1 || true
# Mutex profile
go tool pprof -proto -output="${OUTPUT_DIR}/mutex.pb.gz" \
"http://localhost:${PPROF_PORT}/debug/pprof/mutex" > /dev/null 2>&1 || true
# Block profile
go tool pprof -proto -output="${OUTPUT_DIR}/block.pb.gz" \
"http://localhost:${PPROF_PORT}/debug/pprof/block" > /dev/null 2>&1 || true
echo -e "${GREEN}✓ Final profiles collected${NC}"
echo ""
# ============================================
# Step 6: Generate analysis reports
# ============================================
echo -e "${GREEN}[Step 6/8] Generating analysis reports...${NC}"
# CPU analysis
echo " Generating CPU reports..."
go tool pprof -top -cum "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" > "${OUTPUT_DIR}/01_cpu_top_cum.txt" 2>&1 || true
go tool pprof -top "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" > "${OUTPUT_DIR}/02_cpu_top_flat.txt" 2>&1 || true
# Memory analysis
echo " Generating memory reports..."
if [ -f "${OUTPUT_DIR}/heap_after.pb.gz" ]; then
go tool pprof -top -alloc_space "${OUTPUT_DIR}/heap_after.pb.gz" > "${OUTPUT_DIR}/03_memory_alloc_space.txt" 2>&1 || true
go tool pprof -top -alloc_objects "${OUTPUT_DIR}/heap_after.pb.gz" > "${OUTPUT_DIR}/04_memory_alloc_objects.txt" 2>&1 || true
go tool pprof -top -inuse_space "${OUTPUT_DIR}/heap_after.pb.gz" > "${OUTPUT_DIR}/05_memory_inuse_space.txt" 2>&1 || true
fi
# Memory growth
if [ -f "${OUTPUT_DIR}/heap_before.pb.gz" ] && [ -f "${OUTPUT_DIR}/heap_after.pb.gz" ]; then
go tool pprof -top -base="${OUTPUT_DIR}/heap_before.pb.gz" \
"${OUTPUT_DIR}/heap_after.pb.gz" > "${OUTPUT_DIR}/06_memory_growth.txt" 2>&1 || true
fi
# FFI/CGO analysis
echo " Analyzing FFI/CGO calls..."
go tool pprof -top "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" 2>&1 | \
grep -iE "(block_on|CGO|FFI|ffi|runtime\.cgo|_Cfunc)" > "${OUTPUT_DIR}/07_ffi_cgo_analysis.txt" || \
echo "No FFI/CGO related functions found" > "${OUTPUT_DIR}/07_ffi_cgo_analysis.txt"
# JSON serialization analysis
echo " Analyzing JSON serialization..."
go tool pprof -top "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" 2>&1 | \
grep -iE "(json|Marshal|Unmarshal|Encode|Decode|sonic|jsoniter)" > "${OUTPUT_DIR}/08_json_analysis.txt" || \
echo "No JSON related functions found" > "${OUTPUT_DIR}/08_json_analysis.txt"
# Goroutine analysis
if [ -f "${OUTPUT_DIR}/goroutine_after.pb.gz" ]; then
echo " Analyzing goroutines..."
go tool pprof -top "${OUTPUT_DIR}/goroutine_after.pb.gz" > "${OUTPUT_DIR}/09_goroutine_analysis.txt" 2>&1 || true
fi
# Mutex analysis
if [ -f "${OUTPUT_DIR}/mutex.pb.gz" ]; then
echo " Analyzing mutex contention..."
go tool pprof -top "${OUTPUT_DIR}/mutex.pb.gz" > "${OUTPUT_DIR}/10_mutex_analysis.txt" 2>&1 || true
fi
# Block analysis
if [ -f "${OUTPUT_DIR}/block.pb.gz" ]; then
echo " Analyzing blocking operations..."
go tool pprof -top "${OUTPUT_DIR}/block.pb.gz" > "${OUTPUT_DIR}/11_block_analysis.txt" 2>&1 || true
fi
# Request timing statistics
if [ -f "${OUTPUT_DIR}/request_times.txt" ] && [ -s "${OUTPUT_DIR}/request_times.txt" ]; then
echo " Calculating request timing statistics..."
{
echo "Request Timing Statistics"
echo "========================"
echo ""
echo "Total requests: $(wc -l < "${OUTPUT_DIR}/request_times.txt" | tr -d ' ')"
echo ""
awk '{
sum+=$1
sumsq+=$1*$1
if(NR==1 || $1<min) min=$1
if(NR==1 || $1>max) max=$1
} END {
if(NR > 0) {
mean=sum/NR
variance=(sumsq/NR - mean*mean)
stddev=sqrt(variance)
print "Min: " min "s"
print "Max: " max "s"
print "Mean: " mean "s"
print "StdDev: " stddev "s"
}
}' "${OUTPUT_DIR}/request_times.txt"
} > "${OUTPUT_DIR}/12_request_timing.txt"
fi
echo -e "${GREEN}✓ Analysis reports generated${NC}"
echo ""
# ============================================
# Step 7: Generate summary report
# ============================================
echo -e "${GREEN}[Step 7/8] Generating summary report...${NC}"
SUMMARY_FILE="${OUTPUT_DIR}/00_SUMMARY.md"
cat > "$SUMMARY_FILE" <<EOF
# TPOT Performance Analysis Summary
**Analysis Date:** $(date)
**Duration:** ${DURATION}s
**Requests:** $NUM_REQUESTS
**Concurrency:** $CONCURRENCY
## Key Findings
### 1. CPU Hotspots (Top 10 Cumulative Time)
\`\`\`
$(head -15 "${OUTPUT_DIR}/01_cpu_top_cum.txt" | tail -10)
\`\`\`
### 2. CPU Hotspots (Top 10 Flat Time)
\`\`\`
$(head -15 "${OUTPUT_DIR}/02_cpu_top_flat.txt" | tail -10)
\`\`\`
### 3. FFI/CGO Overhead
\`\`\`
$(cat "${OUTPUT_DIR}/07_ffi_cgo_analysis.txt")
\`\`\`
### 4. JSON Serialization Overhead
\`\`\`
$(cat "${OUTPUT_DIR}/08_json_analysis.txt")
\`\`\`
### 5. Memory Allocation (Top 10 by Space)
\`\`\`
$(head -15 "${OUTPUT_DIR}/03_memory_alloc_space.txt" | tail -10)
\`\`\`
### 6. Memory Allocation (Top 10 by Objects)
\`\`\`
$(head -15 "${OUTPUT_DIR}/04_memory_alloc_objects.txt" | tail -10)
\`\`\`
### 7. Mutex Contention
\`\`\`
$(head -15 "${OUTPUT_DIR}/10_mutex_analysis.txt" | tail -10 2>/dev/null || echo "No significant mutex contention detected")
\`\`\`
### 8. Blocking Operations
\`\`\`
$(head -15 "${OUTPUT_DIR}/11_block_analysis.txt" | tail -10 2>/dev/null || echo "No significant blocking detected")
\`\`\`
## Performance Bottlenecks Identified
### High Priority Issues
1. **FFI/CGO Overhead**
- Check: \`cat ${OUTPUT_DIR}/07_ffi_cgo_analysis.txt\`
- Impact: FFI calls add overhead compared to native Rust code
- Recommendation: Minimize FFI calls, batch operations
2. **JSON Serialization**
- Check: \`cat ${OUTPUT_DIR}/08_json_analysis.txt\`
- Impact: JSON marshaling/unmarshaling can be expensive
- Recommendation: Use faster JSON library (jsoniter), reduce serialization frequency
3. **Memory Allocations**
- Check: \`cat ${OUTPUT_DIR}/03_memory_alloc_space.txt\`
- Impact: Frequent allocations cause GC pressure
- Recommendation: Use object pools, pre-allocate buffers
### Medium Priority Issues
4. **Goroutine Overhead**
- Check: \`cat ${OUTPUT_DIR}/09_goroutine_analysis.txt\`
- Impact: Too many goroutines can cause scheduling overhead
- Recommendation: Limit goroutine count, use worker pools
5. **Lock Contention**
- Check: \`cat ${OUTPUT_DIR}/10_mutex_analysis.txt\`
- Impact: Lock contention reduces parallelism
- Recommendation: Reduce lock granularity, use lock-free structures
## Comparison with Rust Router
### Expected Differences
1. **FFI Overhead**: Go → Rust FFI calls add ~100-500ns per call
2. **GC Overhead**: Go's GC can cause pauses (usually <1ms)
3. **JSON Library**: Go's standard library is slower than Rust's serde
4. **Memory Layout**: Go's GC affects cache locality
### Optimization Opportunities
1. **Reduce FFI Calls**
- Batch token processing
- Use async FFI (if possible)
- Cache frequently used FFI results
2. **Optimize JSON**
- Use jsoniter (already implemented)
- Pre-allocate JSON buffers
- Reduce serialization frequency
3. **Memory Management**
- Use sync.Pool for frequently allocated objects
- Pre-allocate slices with known capacity
- Avoid unnecessary string copies
4. **Concurrency**
- Use worker pools instead of spawning goroutines per request
- Limit concurrent FFI calls
- Use channels efficiently
## Next Steps
1. Review detailed reports in this directory
2. Use interactive pprof: \`go tool pprof -http=:8081 ${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz\`
3. Compare with Rust router profiles (if available)
4. Implement optimizations based on findings
5. Re-run analysis to measure improvements
## Files Generated
- \`00_SUMMARY.md\` - This summary
- \`01_cpu_top_cum.txt\` - CPU top functions (cumulative)
- \`02_cpu_top_flat.txt\` - CPU top functions (flat)
- \`03_memory_alloc_space.txt\` - Memory allocation by space
- \`04_memory_alloc_objects.txt\` - Memory allocation by objects
- \`05_memory_inuse_space.txt\` - Memory in use by space
- \`06_memory_growth.txt\` - Memory growth during test
- \`07_ffi_cgo_analysis.txt\` - FFI/CGO overhead analysis
- \`08_json_analysis.txt\` - JSON serialization analysis
- \`09_goroutine_analysis.txt\` - Goroutine analysis
- \`10_mutex_analysis.txt\` - Mutex contention analysis
- \`11_block_analysis.txt\` - Blocking operations analysis
- \`12_request_timing.txt\` - Request timing statistics
- \`*.pb.gz\` - Raw profile files for interactive analysis
EOF
echo -e "${GREEN}✓ Summary report generated${NC}"
echo ""
# ============================================
# Step 8: Display summary
# ============================================
echo -e "${GREEN}[Step 8/8] Analysis Complete!${NC}"
echo ""
echo -e "${BLUE}========================================${NC}"
echo -e "${BLUE}Summary${NC}"
echo -e "${BLUE}========================================${NC}"
echo ""
echo -e "${YELLOW}Top CPU Hotspots (Cumulative):${NC}"
head -12 "${OUTPUT_DIR}/01_cpu_top_cum.txt" | tail -10
echo ""
echo -e "${YELLOW}FFI/CGO Overhead:${NC}"
cat "${OUTPUT_DIR}/07_ffi_cgo_analysis.txt"
echo ""
echo -e "${YELLOW}JSON Serialization Overhead:${NC}"
cat "${OUTPUT_DIR}/08_json_analysis.txt"
echo ""
echo -e "${YELLOW}Top Memory Allocations:${NC}"
head -12 "${OUTPUT_DIR}/03_memory_alloc_space.txt" | tail -10
echo ""
if [ -f "${OUTPUT_DIR}/12_request_timing.txt" ]; then
echo -e "${YELLOW}Request Timing:${NC}"
cat "${OUTPUT_DIR}/12_request_timing.txt"
echo ""
fi
echo -e "${GREEN}========================================${NC}"
echo ""
echo -e "${BLUE}Detailed Reports:${NC}"
echo " Summary: cat ${OUTPUT_DIR}/00_SUMMARY.md"
echo " CPU (cum): cat ${OUTPUT_DIR}/01_cpu_top_cum.txt"
echo " CPU (flat): cat ${OUTPUT_DIR}/02_cpu_top_flat.txt"
echo " FFI/CGO: cat ${OUTPUT_DIR}/07_ffi_cgo_analysis.txt"
echo " JSON: cat ${OUTPUT_DIR}/08_json_analysis.txt"
echo " Memory: cat ${OUTPUT_DIR}/03_memory_alloc_space.txt"
echo ""
echo -e "${BLUE}Interactive Analysis:${NC}"
echo " Run: go tool pprof -http=:8081 ${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz"
echo " Then visit:"
echo " - http://localhost:8081/ui/flamegraph (Flame Graph - no graphviz needed)"
echo " - http://localhost:8081/ui/top (Top Functions - no graphviz needed)"
if [ "$HAS_GRAPHVIZ" = "true" ]; then
echo " - http://localhost:8081/ui/graph (Call Graph - requires graphviz)"
else
echo " - http://localhost:8081/ui/graph (Call Graph - requires graphviz, not available)"
fi
echo ""
if [ "$HAS_GRAPHVIZ" = "false" ]; then
echo -e "${YELLOW}Note: Install graphviz to enable call graph visualization:${NC}"
echo -e "${YELLOW} macOS: brew install graphviz${NC}"
echo -e "${YELLOW} Ubuntu: sudo apt-get install graphviz${NC}"
echo -e "${YELLOW} CentOS: sudo yum install graphviz${NC}"
echo ""
fi
echo -e "${GREEN}All files saved to: ${OUTPUT_DIR}${NC}"
echo ""

View File

@@ -0,0 +1,215 @@
#!/bin/bash
# pprof performance analysis script
# Used to analyze performance bottlenecks of Go OpenAI server
set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"
# Configuration
PPROF_PORT=${PPROF_PORT:-6060}
SERVER_PORT=${SERVER_PORT:-8080}
DURATION=${DURATION:-60} # Performance test duration (seconds)
OUTPUT_DIR="./pprof_results"
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
# Create output directory
mkdir -p "$OUTPUT_DIR"
echo "=========================================="
echo "pprof Performance Analysis Tool"
echo "=========================================="
echo "PPROF_PORT: $PPROF_PORT"
echo "SERVER_PORT: $SERVER_PORT"
echo "DURATION: ${DURATION}s"
echo "OUTPUT_DIR: $OUTPUT_DIR"
echo ""
# Check if go tool pprof is available
if ! command -v go &> /dev/null; then
echo "Error: go command not found"
exit 1
fi
# Check if server is running
check_server() {
if curl -s "http://localhost:${SERVER_PORT}/health" > /dev/null 2>&1; then
return 0
else
return 1
fi
}
# Check if pprof is available
check_pprof() {
if curl -s "http://localhost:${PPROF_PORT}/debug/pprof/" > /dev/null 2>&1; then
return 0
else
return 1
fi
}
# Start server (if not running)
if ! check_server; then
echo "Server not running, please start the server first:"
echo " export PPROF_ENABLED=true"
echo " export PPROF_PORT=$PPROF_PORT"
echo " ./oai_server"
echo ""
echo "Or use the following command to start:"
echo " PPROF_ENABLED=true PPROF_PORT=$PPROF_PORT ./oai_server"
echo ""
read -p "Start server now? (y/n) " -n 1 -r
echo
if [[ $REPLY =~ ^[Yy]$ ]]; then
echo "Starting server..."
PPROF_ENABLED=true PPROF_PORT=$PPROF_PORT ./oai_server &
SERVER_PID=$!
echo "Server PID: $SERVER_PID"
# Wait for server to start
echo "Waiting for server to start..."
for i in {1..30}; do
if check_server; then
echo "Server started"
break
fi
sleep 1
done
if ! check_server; then
echo "Error: Server failed to start"
kill $SERVER_PID 2>/dev/null || true
exit 1
fi
else
exit 1
fi
fi
# Check if pprof is available
if ! check_pprof; then
echo "Error: pprof not enabled. Please set environment variables:"
echo " export PPROF_ENABLED=true"
echo " export PPROF_PORT=$PPROF_PORT"
exit 1
fi
echo "Starting to collect performance data..."
echo ""
# 1. CPU Profile (30 seconds)
echo "[1/6] Collecting CPU Profile (30 seconds)..."
go tool pprof -proto -output="$OUTPUT_DIR/cpu_${TIMESTAMP}.pb.gz" \
"http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=30" &
CPU_PID=$!
# 2. Collect Heap Profile simultaneously
echo "[2/6] Collecting Heap Profile..."
go tool pprof -proto -output="$OUTPUT_DIR/heap_${TIMESTAMP}.pb.gz" \
"http://localhost:${PPROF_PORT}/debug/pprof/heap" &
HEAP_PID=$!
# 3. Collect Goroutine Profile
echo "[3/6] Collecting Goroutine Profile..."
go tool pprof -proto -output="$OUTPUT_DIR/goroutine_${TIMESTAMP}.pb.gz" \
"http://localhost:${PPROF_PORT}/debug/pprof/goroutine" &
GOROUTINE_PID=$!
# 4. Collect Mutex Profile
echo "[4/6] Collecting Mutex Profile..."
go tool pprof -proto -output="$OUTPUT_DIR/mutex_${TIMESTAMP}.pb.gz" \
"http://localhost:${PPROF_PORT}/debug/pprof/mutex" &
MUTEX_PID=$!
# 5. Collect Block Profile
echo "[5/6] Collecting Block Profile..."
go tool pprof -proto -output="$OUTPUT_DIR/block_${TIMESTAMP}.pb.gz" \
"http://localhost:${PPROF_PORT}/debug/pprof/block" &
BLOCK_PID=$!
# 6. Run performance test (during CPU profile collection)
echo "[6/6] Running performance test..."
echo "Tip: Please use your performance testing tool (curl, ab, wrk, etc.) to send requests to the server"
echo " CPU profile will collect 30 seconds of performance data"
echo ""
# Wait for CPU profile to complete
wait $CPU_PID
echo "CPU Profile collection completed"
# Wait for other profiles
wait $HEAP_PID
wait $GOROUTINE_PID
wait $MUTEX_PID
wait $BLOCK_PID
echo ""
echo "=========================================="
echo "Performance data collection completed!"
echo "=========================================="
echo ""
echo "Generated analysis files:"
ls -lh "$OUTPUT_DIR"/*_${TIMESTAMP}.* 2>/dev/null || true
echo ""
# Generate analysis report
echo "Generating analysis report..."
echo ""
# CPU Top 20
echo "=== CPU Top 20 (sorted by flat time) ===" > "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
go tool pprof -top -cum "$OUTPUT_DIR/cpu_${TIMESTAMP}.pb.gz" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" 2>&1 || true
echo "" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
# Heap Top 20
echo "=== Heap Top 20 (sorted by allocation size) ===" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
go tool pprof -top "$OUTPUT_DIR/heap_${TIMESTAMP}.pb.gz" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" 2>&1 || true
echo "" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
# Goroutine statistics
echo "=== Goroutine Statistics ===" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
go tool pprof -top "$OUTPUT_DIR/goroutine_${TIMESTAMP}.pb.gz" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" 2>&1 || true
echo "" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
# Mutex statistics
echo "=== Mutex Wait Time ===" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
go tool pprof -top "$OUTPUT_DIR/mutex_${TIMESTAMP}.pb.gz" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" 2>&1 || true
echo "" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
# Block statistics
echo "=== Block Wait Time ===" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
go tool pprof -top "$OUTPUT_DIR/block_${TIMESTAMP}.pb.gz" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" 2>&1 || true
echo "Analysis report saved to: $OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
echo ""
# Display key information
echo "=========================================="
echo "Key Performance Metrics Summary"
echo "=========================================="
echo ""
echo "View detailed report:"
echo " cat $OUTPUT_DIR/analysis_${TIMESTAMP}.txt"
echo ""
echo "Interactive CPU Profile view:"
echo " go tool pprof $OUTPUT_DIR/cpu_${TIMESTAMP}.pb.gz"
echo ""
echo "Interactive Heap Profile view:"
echo " go tool pprof $OUTPUT_DIR/heap_${TIMESTAMP}.pb.gz"
echo ""
echo "Generate flame graph (requires go-torch or pprof):"
echo " go tool pprof -http=:8080 $OUTPUT_DIR/cpu_${TIMESTAMP}.pb.gz"
echo ""
# If server was started, ask if it should be closed
if [ -n "$SERVER_PID" ]; then
read -p "Close server? (y/n) " -n 1 -r
echo
if [[ $REPLY =~ ^[Yy]$ ]]; then
kill $SERVER_PID 2>/dev/null || true
echo "Server closed"
fi
fi

View File

@@ -0,0 +1,52 @@
#!/bin/bash
# Quick pprof analysis script
# Collects 30-second CPU profile and immediately displays top results
set -e
PPROF_PORT=${PPROF_PORT:-6060}
DURATION=${DURATION:-30}
echo "=========================================="
echo "Quick pprof Analysis"
echo "=========================================="
echo "PPROF_PORT: $PPROF_PORT"
echo "DURATION: ${DURATION}s"
echo ""
echo "Tip: During data collection, please send requests to the server"
echo " You can use: ./pprof_test.sh"
echo ""
# Check if pprof is available
if ! curl -s "http://localhost:${PPROF_PORT}/debug/pprof/" > /dev/null 2>&1; then
echo "Error: pprof not enabled. Please set environment variables:"
echo " export PPROF_ENABLED=true"
echo " export PPROF_PORT=$PPROF_PORT"
exit 1
fi
echo "Starting to collect CPU Profile (${DURATION} seconds)..."
echo ""
# Collect CPU profile and directly display top results
go tool pprof -top -cum "http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=${DURATION}"
echo ""
echo "=========================================="
echo "Analysis Complete"
echo "=========================================="
echo ""
echo "More analysis options:"
echo " # Interactive view"
echo " go tool pprof http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=30"
echo ""
echo " # View heap memory"
echo " go tool pprof http://localhost:${PPROF_PORT}/debug/pprof/heap"
echo ""
echo " # View goroutines"
echo " go tool pprof http://localhost:${PPROF_PORT}/debug/pprof/goroutine"
echo ""
echo " # Generate Web UI"
echo " go tool pprof -http=:8080 http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=30"
echo ""

View File

@@ -0,0 +1,87 @@
#!/bin/bash
# Simple performance test script for sending requests while collecting pprof data
set -e
SERVER_URL=${SERVER_URL:-"http://localhost:8080"}
DURATION=${DURATION:-30} # Test duration (seconds)
CONCURRENT=${CONCURRENT:-1} # Number of concurrent requests
echo "=========================================="
echo "Performance Test Script"
echo "=========================================="
echo "SERVER_URL: $SERVER_URL"
echo "DURATION: ${DURATION}s"
echo "CONCURRENT: $CONCURRENT"
echo ""
# Test request JSON
TEST_REQUEST='{
"model": "default",
"messages": [
{"role": "user", "content": "Hello, how are you?"}
],
"stream": true,
"max_tokens": 100
}'
# Check if server is available
if ! curl -s "${SERVER_URL}/health" > /dev/null 2>&1; then
echo "Error: Server not available (${SERVER_URL}/health)"
exit 1
fi
echo "Starting to send test requests..."
echo ""
# Function to send streaming request
send_stream_request() {
local request_num=$1
local start_time=$(date +%s.%N)
curl -s -N -X POST "${SERVER_URL}/v1/chat/completions" \
-H "Content-Type: application/json" \
-d "$TEST_REQUEST" \
> /dev/null 2>&1
local end_time=$(date +%s.%N)
local duration=$(echo "$end_time - $start_time" | bc)
echo "Request $request_num completed, duration: ${duration}s"
}
# Send requests concurrently
if [ "$CONCURRENT" -eq 1 ]; then
# Single-threaded mode: continuously send requests
end_time=$(($(date +%s) + DURATION))
request_count=0
while [ $(date +%s) -lt $end_time ]; do
request_count=$((request_count + 1))
send_stream_request $request_count
done
echo ""
echo "Test completed, sent $request_count requests"
else
# Multi-threaded mode: send requests concurrently
end_time=$(($(date +%s) + DURATION))
request_count=0
while [ $(date +%s) -lt $end_time ]; do
# Start concurrent requests
for i in $(seq 1 $CONCURRENT); do
request_count=$((request_count + 1))
send_stream_request $request_count &
done
# Wait for all requests to complete
wait
# Brief rest to avoid overload
sleep 0.1
done
echo ""
echo "Test completed, sent $request_count requests"
fi

View File

@@ -0,0 +1,140 @@
#!/bin/bash
# TPOT performance analysis script
# Quickly collect and analyze TPOT-related performance data
set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
PROFILE_DIR="${PROJECT_ROOT}/profiles"
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
OUTPUT_DIR="${PROFILE_DIR}/${TIMESTAMP}"
# Colors
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
# Default values
PPROF_PORT=${PPROF_PORT:-6060}
SERVER_URL=${SERVER_URL:-http://localhost:8080}
DURATION=${DURATION:-30}
NUM_REQUESTS=${NUM_REQUESTS:-20}
mkdir -p "$OUTPUT_DIR"
echo -e "${GREEN}TPOT Performance Analysis${NC}"
echo "=========================="
echo "Profile directory: $OUTPUT_DIR"
echo "Duration: ${DURATION}s"
echo "Requests: $NUM_REQUESTS"
echo ""
# Check if server is running
if ! curl -s "${SERVER_URL}/health" > /dev/null 2>&1; then
echo -e "${YELLOW}Warning: Server not responding at ${SERVER_URL}${NC}"
echo "Please start the server first with profiling enabled:"
echo " PPROF_ENABLED=true PPROF_PORT=$PPROF_PORT make run"
exit 1
fi
# Collect baseline memory
echo -e "${GREEN}[1/5] Collecting baseline memory profile...${NC}"
go tool pprof -proto -output="${OUTPUT_DIR}/heap_before.pb.gz" \
"http://localhost:${PPROF_PORT}/debug/pprof/heap" > /dev/null 2>&1 || true
# Start CPU profile collection in background
echo -e "${GREEN}[2/5] Starting CPU profile collection (${DURATION}s)...${NC}"
go tool pprof -proto -output="${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" \
"http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=${DURATION}" &
CPU_PID=$!
# Wait a bit for profile to start
sleep 2
# Run load test
echo -e "${GREEN}[3/5] Running load test ($NUM_REQUESTS requests)...${NC}"
for i in $(seq 1 $NUM_REQUESTS); do
curl -N -s -X POST "${SERVER_URL}/v1/chat/completions" \
-H "Content-Type: application/json" \
-d "{
\"model\": \"default\",
\"messages\": [{\"role\": \"user\", \"content\": \"Write a story\"}],
\"stream\": true,
\"max_tokens\": 200
}" > /dev/null &
# Limit concurrency
if [ $((i % 5)) -eq 0 ]; then
wait
fi
done
wait
# Wait for CPU profile to complete
echo -e "${GREEN}[4/5] Waiting for CPU profile to complete...${NC}"
# Wait for the CPU profile process, but handle the case where it's not a child process
if kill -0 $CPU_PID 2>/dev/null; then
# Process is still running, wait for it
while kill -0 $CPU_PID 2>/dev/null; do
sleep 1
done
else
# Process already completed or not found, just wait a bit
sleep 2
fi
# Collect final memory
echo -e "${GREEN}[5/5] Collecting final memory profile...${NC}"
go tool pprof -proto -output="${OUTPUT_DIR}/heap_after.pb.gz" \
"http://localhost:${PPROF_PORT}/debug/pprof/heap" > /dev/null 2>&1 || true
# Generate reports
echo ""
echo -e "${GREEN}Generating reports...${NC}"
# CPU top (cumulative)
go tool pprof -top -cum "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" > "${OUTPUT_DIR}/cpu_top_cum.txt" 2>&1 || true
# CPU top (flat)
go tool pprof -top "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" > "${OUTPUT_DIR}/cpu_top_flat.txt" 2>&1 || true
# Memory growth
if [ -f "${OUTPUT_DIR}/heap_before.pb.gz" ] && [ -f "${OUTPUT_DIR}/heap_after.pb.gz" ]; then
go tool pprof -top -base="${OUTPUT_DIR}/heap_before.pb.gz" \
"${OUTPUT_DIR}/heap_after.pb.gz" > "${OUTPUT_DIR}/heap_growth.txt" 2>&1 || true
fi
# FFI/CGO related
go tool pprof -top "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" 2>&1 | \
grep -E "(block_on|CGO|FFI|json|Marshal|Unmarshal)" > "${OUTPUT_DIR}/ffi_related.txt" || \
echo "No FFI/CGO related functions found" > "${OUTPUT_DIR}/ffi_related.txt"
# Summary
echo ""
echo -e "${GREEN}=== Analysis Summary ===${NC}"
echo ""
echo -e "${YELLOW}CPU Top (Cumulative) - Top 10:${NC}"
head -12 "${OUTPUT_DIR}/cpu_top_cum.txt" | tail -10 || true
echo ""
echo -e "${YELLOW}CPU Top (Flat) - Top 10:${NC}"
head -12 "${OUTPUT_DIR}/cpu_top_flat.txt" | tail -10 || true
echo ""
echo -e "${YELLOW}FFI/CGO Related Functions:${NC}"
cat "${OUTPUT_DIR}/ffi_related.txt" || true
echo ""
echo -e "${GREEN}=== Detailed Reports ===${NC}"
echo "CPU (cumulative): cat ${OUTPUT_DIR}/cpu_top_cum.txt"
echo "CPU (flat): cat ${OUTPUT_DIR}/cpu_top_flat.txt"
echo "Memory growth: cat ${OUTPUT_DIR}/heap_growth.txt"
echo "FFI related: cat ${OUTPUT_DIR}/ffi_related.txt"
echo ""
echo -e "${GREEN}=== Interactive Analysis ===${NC}"
echo "Run: go tool pprof -http=:8081 ${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz"
echo "Then visit: http://localhost:8081/ui/flamegraph"
echo ""
echo "Profile files saved to: ${OUTPUT_DIR}"

View File

@@ -0,0 +1,37 @@
package service
import (
sglang "github.com/sglang/sglang-go-grpc-sdk"
)
// SGLangService wraps SGLang client
type SGLangService struct {
client *sglang.Client
}
func NewSGLangService(endpoint, tokenizerPath string) (*SGLangService, error) {
client, err := sglang.NewClient(sglang.ClientConfig{
Endpoint: endpoint,
TokenizerPath: tokenizerPath,
})
if err != nil {
return nil, err
}
return &SGLangService{
client: client,
}, nil
}
// Client returns the underlying SGLang client
func (s *SGLangService) Client() *sglang.Client {
return s.client
}
// Close closes the SGLang client
func (s *SGLangService) Close() error {
if s.client != nil {
return s.client.Close()
}
return nil
}

View File

@@ -0,0 +1,34 @@
package utils
import (
"encoding/json"
"github.com/valyala/fasthttp"
)
// RespondError sends an error response in OpenAI format
func RespondError(ctx *fasthttp.RequestCtx, statusCode int, message, errorType string) {
ctx.SetStatusCode(statusCode)
ctx.SetContentType("application/json")
response := map[string]interface{}{
"error": map[string]interface{}{
"message": message,
"type": errorType,
"code": statusCode,
},
}
jsonData, _ := json.Marshal(response)
ctx.Write(jsonData)
}
// BuildResponseBase builds the base response structure for OpenAI-compatible responses
func BuildResponseBase(id string, created int64, model string) map[string]interface{} {
return map[string]interface{}{
"id": id,
"object": "chat.completion",
"created": created,
"model": model,
}
}

View File

@@ -0,0 +1,85 @@
// Simple example demonstrating basic usage of SGLang Go SDK
package main
import (
"context"
"fmt"
"log"
"os"
"github.com/sglang/sglang-go-grpc-sdk"
)
func main() {
// Get configuration from environment or command line
endpoint := os.Getenv("SGL_GRPC_ENDPOINT")
if endpoint == "" {
endpoint = "grpc://localhost:20000"
}
tokenizerPath := os.Getenv("SGL_TOKENIZER_PATH")
if tokenizerPath == "" {
tokenizerPath = "./examples/tokenizer"
}
// Create client
client, err := sglang.NewClient(sglang.ClientConfig{
Endpoint: endpoint,
TokenizerPath: tokenizerPath,
})
if err != nil {
log.Fatalf("Failed to create client: %v", err)
}
defer client.Close()
// Create chat completion request
req := sglang.ChatCompletionRequest{
Model: "default",
Messages: []sglang.ChatMessage{
{
Role: "system",
Content: "You are a helpful assistant.",
},
{
Role: "user",
Content: "写一首歌关于夏天",
},
},
Stream: false,
Temperature: float32Ptr(0.7),
MaxCompletionTokens: intPtr(200),
SkipSpecialTokens: true,
Tools: nil, // Use nil instead of empty slice to avoid template errors
}
// Create completion
ctx := context.Background()
resp, err := client.CreateChatCompletion(ctx, req)
if err != nil {
log.Fatalf("Failed to create completion: %v", err)
}
// Print response
fmt.Println("=== Response ===")
fmt.Printf("ID: %s\n", resp.ID)
fmt.Printf("Model: %s\n", resp.Model)
fmt.Printf("Created: %d\n", resp.Created)
fmt.Println("\nContent:")
for _, choice := range resp.Choices {
fmt.Println(choice.Message.Content)
}
fmt.Printf("\nFinish Reason: %s\n", resp.Choices[0].FinishReason)
fmt.Printf("\nUsage: Prompt=%d, Completion=%d, Total=%d\n",
resp.Usage.PromptTokens,
resp.Usage.CompletionTokens,
resp.Usage.TotalTokens,
)
}
func float32Ptr(f float32) *float32 {
return &f
}
func intPtr(i int) *int {
return &i
}

View File

@@ -0,0 +1,46 @@
#!/bin/bash
# Simple example runner
# Usage: ./run.sh [tokenizer_path] [endpoint]
# Set library path for Rust FFI library
# The library should be in ./lib directory (created by 'make lib')
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
LIB_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)/lib"
# Check if lib directory exists
if [ ! -d "$LIB_DIR" ]; then
echo "Error: Library directory not found at $LIB_DIR"
echo "Please run 'make lib' first to build and export the library"
exit 1
fi
# Get Python LDFLAGS (needed for Rust FFI that depends on Python)
PYTHON_LDFLAGS=$(python3-config --ldflags --embed 2>/dev/null || python3-config --ldflags 2>/dev/null || echo "")
# Set CGO_LDFLAGS to link with the Rust library
export CGO_LDFLAGS="-L${LIB_DIR} -lsgl_model_gateway_go ${PYTHON_LDFLAGS} -ldl"
# macOS uses DYLD_LIBRARY_PATH, Linux uses LD_LIBRARY_PATH
if [[ "$OSTYPE" == "darwin"* ]]; then
export DYLD_LIBRARY_PATH="${LIB_DIR}:${DYLD_LIBRARY_PATH}"
else
export LD_LIBRARY_PATH="${LIB_DIR}:${LD_LIBRARY_PATH}"
fi
# Default configuration (can be overridden by environment variables or command line arguments)
# Tokenizer path: ../tokenizer (relative to this script)
DEFAULT_TOKENIZER_PATH="${SGL_TOKENIZER_PATH:-../tokenizer}"
DEFAULT_ENDPOINT="${SGL_GRPC_ENDPOINT:-grpc://localhost:20000}"
TOKENIZER_PATH="${1:-${DEFAULT_TOKENIZER_PATH}}"
ENDPOINT="${2:-${DEFAULT_ENDPOINT}}"
echo "Running simple example..."
echo "Library path: ${LIB_DIR}"
echo "Tokenizer: $TOKENIZER_PATH"
echo "Endpoint: $ENDPOINT"
echo ""
cd "$(dirname "${BASH_SOURCE[0]}")"
SGL_TOKENIZER_PATH="$TOKENIZER_PATH" SGL_GRPC_ENDPOINT="$ENDPOINT" go run main.go

View File

@@ -0,0 +1,125 @@
// Streaming example demonstrating real-time streaming with SGLang Go SDK
package main
import (
"context"
"fmt"
"io"
"log"
"os"
"strings"
"time"
"github.com/sglang/sglang-go-grpc-sdk"
)
func main() {
// Get configuration from environment or command line
endpoint := os.Getenv("SGL_GRPC_ENDPOINT")
if endpoint == "" {
endpoint = "grpc://localhost:20000"
}
tokenizerPath := os.Getenv("SGL_TOKENIZER_PATH")
if tokenizerPath == "" {
tokenizerPath = "./examples/tokenizer"
}
// Create client
client, err := sglang.NewClient(sglang.ClientConfig{
Endpoint: endpoint,
TokenizerPath: tokenizerPath,
})
if err != nil {
log.Fatalf("Failed to create client: %v", err)
}
defer client.Close()
// Create streaming chat completion request
req := sglang.ChatCompletionRequest{
Model: "default",
Messages: []sglang.ChatMessage{
{
Role: "system",
Content: "You are a helpful assistant.",
},
{
Role: "user",
Content: "写一首春天的诗歌",
},
},
Stream: true,
Temperature: float32Ptr(0.7),
MaxCompletionTokens: intPtr(500),
SkipSpecialTokens: true,
Tools: nil, // Use nil instead of empty slice to avoid template errors
}
// Create streaming completion
ctx := context.Background()
stream, err := client.CreateChatCompletionStream(ctx, req)
if err != nil {
log.Fatalf("Failed to create stream: %v", err)
}
defer stream.Close()
fmt.Println("=== Streaming Response ===")
fmt.Println()
var fullContent strings.Builder
chunkCount := 0
startTime := time.Now()
var firstTokenTime time.Time
firstTokenReceived := false
for {
chunk, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
log.Fatalf("Stream error: %v", err)
}
chunkCount++
// Extract content from delta
for _, choice := range chunk.Choices {
if choice.Delta.Content != "" {
fmt.Print(choice.Delta.Content)
fullContent.WriteString(choice.Delta.Content)
// Track first token time (TTFT)
if !firstTokenReceived {
firstTokenTime = time.Now()
firstTokenReceived = true
ttft := firstTokenTime.Sub(startTime)
fmt.Printf("\n[TTFT: %v]\n", ttft)
}
}
if choice.FinishReason != "" {
fmt.Printf("\n\n[Finished: %s]\n", choice.FinishReason)
}
}
}
// Calculate metrics
if firstTokenReceived {
elapsed := time.Since(startTime)
tokensPerSecond := float64(fullContent.Len()) / elapsed.Seconds()
fmt.Printf("\n=== Metrics ===\n")
fmt.Printf("Total chunks: %d\n", chunkCount)
fmt.Printf("Total content length: %d characters\n", fullContent.Len())
fmt.Printf("Time elapsed: %v\n", elapsed)
fmt.Printf("Tokens per second: %.2f\n", tokensPerSecond)
}
}
func float32Ptr(f float32) *float32 {
return &f
}
func intPtr(i int) *int {
return &i
}

View File

@@ -0,0 +1,46 @@
#!/bin/bash
# Streaming example runner
# Usage: ./run.sh [tokenizer_path] [endpoint]
# Set library path for Rust FFI library
# The library should be in ./lib directory (created by 'make lib')
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
LIB_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)/lib"
# Check if lib directory exists
if [ ! -d "$LIB_DIR" ]; then
echo "Error: Library directory not found at $LIB_DIR"
echo "Please run 'make lib' first to build and export the library"
exit 1
fi
# Get Python LDFLAGS (needed for Rust FFI that depends on Python)
PYTHON_LDFLAGS=$(python3-config --ldflags --embed 2>/dev/null || python3-config --ldflags 2>/dev/null || echo "")
# Set CGO_LDFLAGS to link with the Rust library
export CGO_LDFLAGS="-L${LIB_DIR} -lsgl_model_gateway_go ${PYTHON_LDFLAGS} -ldl"
# macOS uses DYLD_LIBRARY_PATH, Linux uses LD_LIBRARY_PATH
if [[ "$OSTYPE" == "darwin"* ]]; then
export DYLD_LIBRARY_PATH="${LIB_DIR}:${DYLD_LIBRARY_PATH}"
else
export LD_LIBRARY_PATH="${LIB_DIR}:${LD_LIBRARY_PATH}"
fi
# Default configuration (can be overridden by environment variables or command line arguments)
# Tokenizer path: ../tokenizer (relative to this script)
DEFAULT_TOKENIZER_PATH="${SGL_TOKENIZER_PATH:-../tokenizer}"
DEFAULT_ENDPOINT="${SGL_GRPC_ENDPOINT:-grpc://localhost:20000}"
TOKENIZER_PATH="${1:-${DEFAULT_TOKENIZER_PATH}}"
ENDPOINT="${2:-${DEFAULT_ENDPOINT}}"
echo "Running streaming example..."
echo "Library path: ${LIB_DIR}"
echo "Tokenizer: $TOKENIZER_PATH"
echo "Endpoint: $ENDPOINT"
echo ""
cd "$(dirname "${BASH_SOURCE[0]}")"
SGL_TOKENIZER_PATH="$TOKENIZER_PATH" SGL_GRPC_ENDPOINT="$ENDPOINT" go run main.go

View File

@@ -0,0 +1,36 @@
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82 h1:6/3JGEh1C88g7m+qzzTbl3A0FtsLguXieqofVLU/JAo=
golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 h1:M1rk8KBnUsBDg1oPGHNCxG4vc1f49epmTO7xscSajMk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM=
google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=

View File

@@ -0,0 +1,228 @@
//go:build integration
// +build integration
// integration_test.go contains integration tests that require a running SGLang server
//
// To run these tests:
// 1. Start an SGLang server: python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-hf
// 2. Run: go test -tags=integration -run TestIntegration
package sglang
import (
"context"
"io"
"os"
"testing"
"time"
)
// getTestConfig returns test configuration from environment or defaults
func getTestConfig(t *testing.T) ClientConfig {
endpoint := os.Getenv("SGL_GRPC_ENDPOINT")
if endpoint == "" {
endpoint = "grpc://localhost:20000"
}
tokenizerPath := os.Getenv("SGL_TOKENIZER_PATH")
if tokenizerPath == "" {
t.Skip("SGL_TOKENIZER_PATH not set")
}
return ClientConfig{
Endpoint: endpoint,
TokenizerPath: tokenizerPath,
}
}
// TestIntegrationNonStreamingCompletion tests non-streaming chat completion
func TestIntegrationNonStreamingCompletion(t *testing.T) {
config := getTestConfig(t)
client, err := NewClient(config)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
defer client.Close()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req := ChatCompletionRequest{
Model: "default",
Messages: []ChatMessage{
{Role: "user", Content: "Say 'Hello, World!' only"},
},
Stream: false,
Temperature: float32Ptr(0.0),
MaxCompletionTokens: intPtr(50),
}
resp, err := client.CreateChatCompletion(ctx, req)
if err != nil {
t.Fatalf("CreateChatCompletion failed: %v", err)
}
if resp.ID == "" {
t.Error("Response ID is empty")
}
if len(resp.Choices) == 0 {
t.Error("Response has no choices")
}
if resp.Choices[0].Message.Content == "" {
t.Error("Response content is empty")
}
if resp.Usage == nil || resp.Usage.TotalTokens == 0 {
t.Error("Usage information is missing or invalid")
}
t.Logf("Response: %s", resp.Choices[0].Message.Content)
t.Logf("Usage: %+v", resp.Usage)
}
// TestIntegrationStreamingCompletion tests streaming chat completion
func TestIntegrationStreamingCompletion(t *testing.T) {
config := getTestConfig(t)
client, err := NewClient(config)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
defer client.Close()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req := ChatCompletionRequest{
Model: "default",
Messages: []ChatMessage{
{Role: "user", Content: "Count from 1 to 5"},
},
Stream: true,
Temperature: float32Ptr(0.0),
MaxCompletionTokens: intPtr(100),
}
stream, err := client.CreateChatCompletionStream(ctx, req)
if err != nil {
t.Fatalf("CreateChatCompletionStream failed: %v", err)
}
defer stream.Close()
chunkCount := 0
totalContent := ""
for {
chunk, err := stream.Recv()
if err == io.EOF {
// io.EOF is expected at end of stream
break
}
if err != nil {
t.Fatalf("Stream error: %v", err)
}
chunkCount++
for _, choice := range chunk.Choices {
if choice.Delta.Content != "" {
totalContent += choice.Delta.Content
}
}
}
if chunkCount == 0 {
t.Error("Received no chunks from stream")
}
if totalContent == "" {
t.Error("Received no content from stream")
}
t.Logf("Received %d chunks with content: %s", chunkCount, totalContent)
}
// TestIntegrationConcurrentRequests tests multiple concurrent requests
func TestIntegrationConcurrentRequests(t *testing.T) {
config := getTestConfig(t)
client, err := NewClient(config)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
defer client.Close()
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
numRequests := 3
done := make(chan error, numRequests)
for i := 0; i < numRequests; i++ {
go func(idx int) {
req := ChatCompletionRequest{
Model: "default",
Messages: []ChatMessage{
{Role: "user", Content: "Say 'test'"},
},
Stream: false,
MaxCompletionTokens: intPtr(50),
}
_, err := client.CreateChatCompletion(ctx, req)
done <- err
}(i)
}
// Collect results
for i := 0; i < numRequests; i++ {
if err := <-done; err != nil {
t.Errorf("Request %d failed: %v", i, err)
}
}
t.Logf("All %d concurrent requests completed successfully", numRequests)
}
// TestIntegrationContextCancellation tests that context cancellation is handled
func TestIntegrationContextCancellation(t *testing.T) {
config := getTestConfig(t)
client, err := NewClient(config)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
defer client.Close()
// Create a context that cancels immediately
ctx, cancel := context.WithCancel(context.Background())
cancel()
req := ChatCompletionRequest{
Model: "default",
Messages: []ChatMessage{
{Role: "user", Content: "test"},
},
Stream: false,
}
// Should handle cancelled context gracefully
_, err = client.CreateChatCompletion(ctx, req)
if err == nil {
t.Error("Expected error from cancelled context")
}
t.Logf("Cancelled context handled: %v", err)
}
// Helper functions
func float32Ptr(f float32) *float32 {
return &f
}
func intPtr(i int) *int {
return &i
}

View File

@@ -0,0 +1,126 @@
// Package ffi provides Go bindings for SGLang's Rust FFI (Foreign Function Interface).
package ffi
import (
"encoding/json"
"fmt"
"strings"
"time"
)
// BatchPostprocessor handles batch postprocessing of stream chunks to reduce FFI overhead
type BatchPostprocessor struct {
converter *GrpcResponseConverterHandle
buffer []string
batchSize int
flushInterval time.Duration
lastFlush time.Time
timer *time.Timer
}
// NewBatchPostprocessor creates a new batch postprocessor
func NewBatchPostprocessor(converter *GrpcResponseConverterHandle, batchSize int, flushInterval time.Duration) *BatchPostprocessor {
if batchSize <= 0 {
batchSize = 1
}
if flushInterval < 0 {
flushInterval = 0
}
return &BatchPostprocessor{
converter: converter,
buffer: make([]string, 0, batchSize),
batchSize: batchSize,
flushInterval: flushInterval,
lastFlush: time.Now(),
}
}
// AddChunk adds a chunk to the buffer and processes if batch is full
func (b *BatchPostprocessor) AddChunk(chunkJSON string) (results []string, shouldFlush bool, err error) {
if b.batchSize == 1 {
openaiJSON, _, err := PostprocessStreamChunk(b.converter, chunkJSON)
if err != nil {
return nil, false, err
}
return []string{openaiJSON}, false, nil
}
b.buffer = append(b.buffer, chunkJSON)
shouldProcess := len(b.buffer) >= b.batchSize
shouldFlushTimeout := b.flushInterval > 0 && time.Since(b.lastFlush) >= b.flushInterval
if shouldProcess || shouldFlushTimeout {
return b.processBatch()
}
return nil, false, nil
}
// Flush processes any remaining chunks in the buffer
func (b *BatchPostprocessor) Flush() (results []string, err error) {
if len(b.buffer) == 0 {
return nil, nil
}
res, _, err := b.processBatch()
return res, err
}
// processBatch processes the current buffer and returns results
func (b *BatchPostprocessor) processBatch() (results []string, shouldFlush bool, err error) {
if len(b.buffer) == 0 {
return nil, false, nil
}
var sb strings.Builder
sb.Grow(len(b.buffer) * 200)
sb.WriteString(`[`)
for i, chunkJSONStr := range b.buffer {
if i > 0 {
sb.WriteString(`,`)
}
sb.WriteString(chunkJSONStr)
}
sb.WriteString(`]`)
bufferJSON := sb.String()
resultJSON, _, err := PostprocessStreamChunksBatch(
b.converter,
bufferJSON,
b.batchSize*2,
)
if err != nil {
return nil, false, fmt.Errorf("batch postprocessing failed: %w", err)
}
var resultArray []json.RawMessage
if err := json.Unmarshal([]byte(resultJSON), &resultArray); err != nil {
return nil, false, fmt.Errorf("failed to unmarshal results array: %w", err)
}
resultStrings := make([]string, 0, len(resultArray))
for _, rawMsg := range resultArray {
resultStrings = append(resultStrings, string(rawMsg))
}
b.buffer = b.buffer[:0]
b.lastFlush = time.Now()
if b.timer != nil {
b.timer.Stop()
b.timer = nil
}
return resultStrings, false, nil
}
// Reset clears the buffer and resets the postprocessor state
func (b *BatchPostprocessor) Reset() {
b.buffer = b.buffer[:0]
b.lastFlush = time.Now()
if b.timer != nil {
b.timer.Stop()
b.timer = nil
}
}

View File

@@ -0,0 +1,228 @@
// Package ffi provides Go bindings for SGLang's Rust FFI (Foreign Function Interface).
//
// This package wraps the Rust FFI layer of SGLang, providing low-level access to:
// - Client creation and connection management
// - Chat completion streaming
// - Stream reading and response conversion
// - Memory management for C strings
//
// Internal use only: This package is intended for internal use by the sglang package.
// End users should use the public sglang package instead.
package ffi
/*
#cgo LDFLAGS: -lsgl_model_gateway_go -ldl
#include <stdlib.h>
#include <stdint.h>
// Error codes
typedef enum {
SGL_ERROR_SUCCESS = 0,
SGL_ERROR_INVALID_ARGUMENT = 1,
SGL_ERROR_TOKENIZATION_ERROR = 2,
SGL_ERROR_PARSING_ERROR = 3,
SGL_ERROR_MEMORY_ERROR = 4,
SGL_ERROR_UNKNOWN = 99
} SglErrorCode;
// Opaque handles
typedef void* SglangClientHandle;
typedef void* SglangStreamHandle;
// Client SDK functions
SglangClientHandle* sgl_client_create(const char* endpoint, const char* tokenizer_path, char** error_out);
void sgl_client_free(SglangClientHandle* handle);
SglErrorCode sgl_client_chat_completion_stream(SglangClientHandle* client_handle, const char* request_json, SglangStreamHandle** stream_handle_out, char** error_out);
SglErrorCode sgl_stream_read_next(SglangStreamHandle* stream_handle, char** response_json_out, int* is_done_out, char** error_out);
void sgl_stream_free(SglangStreamHandle* handle);
void sgl_free_string(char* s);
*/
import "C"
import (
"fmt"
"unsafe"
)
// ErrorCode represents FFI error codes returned by Rust functions.
//
// These codes indicate the result of FFI operations. Use Error() to get a human-readable
// error message.
type ErrorCode int
const (
// ErrorSuccess indicates the operation completed successfully
ErrorSuccess ErrorCode = 0
// ErrorInvalidArgument indicates invalid arguments were passed to the FFI function
ErrorInvalidArgument ErrorCode = 1
// ErrorTokenizationError indicates an error during tokenization
ErrorTokenizationError ErrorCode = 2
// ErrorParsingError indicates an error parsing the response or request
ErrorParsingError ErrorCode = 3
// ErrorMemoryError indicates a memory allocation error
ErrorMemoryError ErrorCode = 4
// ErrorUnknown indicates an unclassified error
ErrorUnknown ErrorCode = 99
)
// Error implements the error interface for ErrorCode.
func (e ErrorCode) Error() string {
switch e {
case ErrorSuccess:
return "success"
case ErrorInvalidArgument:
return "invalid argument"
case ErrorTokenizationError:
return "tokenization error"
case ErrorParsingError:
return "parsing error"
case ErrorMemoryError:
return "memory error"
case ErrorUnknown:
return "unknown error"
default:
return fmt.Sprintf("unknown error code: %d", e)
}
}
// SglangClientHandle wraps the Rust client SDK FFI handle.
//
// This struct maintains a connection to the SGLang gRPC server and is used
// to create streams and manage the underlying Rust client resources.
type SglangClientHandle struct {
handle *C.SglangClientHandle
}
// NewClient creates a new SGLang client handle via FFI.
//
// This function initializes the Rust client with the given endpoint and tokenizer path.
//
// Parameters:
// - endpoint: gRPC endpoint URL (e.g., "grpc://localhost:20000")
// - tokenizerPath: Path to tokenizer directory
//
// Returns:
// - *SglangClientHandle: A new client handle
// - error: An error if client creation failed
func NewClient(endpoint, tokenizerPath string) (*SglangClientHandle, error) {
cEndpoint := C.CString(endpoint)
defer C.free(unsafe.Pointer(cEndpoint))
cTokenizerPath := C.CString(tokenizerPath)
defer C.free(unsafe.Pointer(cTokenizerPath))
var errorPtr *C.char
handle := C.sgl_client_create(cEndpoint, cTokenizerPath, &errorPtr)
if handle == nil {
errorMsg := ""
if errorPtr != nil {
errorMsg = C.GoString(errorPtr)
C.sgl_free_string(errorPtr)
}
if errorMsg == "" {
errorMsg = "failed to create client"
}
return nil, fmt.Errorf("%s", errorMsg)
}
return &SglangClientHandle{handle: handle}, nil
}
// Free releases the client handle
func (h *SglangClientHandle) Free() {
if h.handle != nil {
C.sgl_client_free(h.handle)
h.handle = nil
}
}
// ChatCompletionStream creates a streaming chat completion request
func (h *SglangClientHandle) ChatCompletionStream(requestJSON string) (*SglangStreamHandle, error) {
if h.handle == nil {
return nil, fmt.Errorf("client handle is nil")
}
cRequestJSON := C.CString(requestJSON)
defer C.free(unsafe.Pointer(cRequestJSON))
var streamHandle *C.SglangStreamHandle
var errorPtr *C.char
result := C.sgl_client_chat_completion_stream(
h.handle,
cRequestJSON,
&streamHandle,
&errorPtr,
)
if ErrorCode(result) != ErrorSuccess {
errorMsg := ""
if errorPtr != nil {
errorMsg = C.GoString(errorPtr)
C.sgl_free_string(errorPtr)
}
if errorMsg == "" {
errorMsg = fmt.Sprintf("error code %d", result)
}
return nil, fmt.Errorf("%s", errorMsg)
}
if streamHandle == nil {
return nil, fmt.Errorf("stream handle is nil")
}
return &SglangStreamHandle{handle: streamHandle}, nil
}
// SglangStreamHandle wraps the Rust stream FFI handle
type SglangStreamHandle struct {
handle *C.SglangStreamHandle
}
// ReadNext reads the next chunk from the stream
// Returns: (responseJSON, isDone, error)
func (h *SglangStreamHandle) ReadNext() (string, bool, error) {
if h.handle == nil {
return "", true, fmt.Errorf("stream handle is nil")
}
var responseJSON *C.char
var isDone C.int
var errorPtr *C.char
result := C.sgl_stream_read_next(
h.handle,
&responseJSON,
&isDone,
&errorPtr,
)
if ErrorCode(result) != ErrorSuccess {
errorMsg := ""
if errorPtr != nil {
errorMsg = C.GoString(errorPtr)
C.sgl_free_string(errorPtr)
}
if errorMsg == "" {
errorMsg = fmt.Sprintf("error code %d", result)
}
return "", isDone == 1, fmt.Errorf("%s", errorMsg)
}
responseStr := ""
if responseJSON != nil {
responseStr = C.GoString(responseJSON)
C.sgl_free_string(responseJSON)
}
return responseStr, isDone == 1, nil
}
// Free releases the stream handle
func (h *SglangStreamHandle) Free() {
if h.handle != nil {
C.sgl_stream_free(h.handle)
h.handle = nil
}
}

View File

@@ -0,0 +1,275 @@
package ffi
/*
#cgo LDFLAGS: -lsgl_model_gateway_go -ldl
#include <stdlib.h>
#include <stdint.h>
// Error codes (must match client.go)
typedef enum {
SGL_ERROR_SUCCESS = 0,
SGL_ERROR_INVALID_ARGUMENT = 1,
SGL_ERROR_TOKENIZATION_ERROR = 2,
SGL_ERROR_PARSING_ERROR = 3,
SGL_ERROR_MEMORY_ERROR = 4,
SGL_ERROR_UNKNOWN = 99
} SglErrorCode;
// Opaque handles
typedef void* TokenizerHandle;
typedef void* GrpcResponseConverterHandle;
// Converter functions
GrpcResponseConverterHandle* sgl_grpc_response_converter_create(
TokenizerHandle* tokenizer_handle,
const char* model,
const char* request_id,
const char* tools_json,
const char* tool_choice_json,
const char* stop,
const char* stop_token_ids,
int skip_special_tokens,
int initial_prompt_tokens,
char** error_out
);
void sgl_grpc_response_converter_free(GrpcResponseConverterHandle* handle);
// Tokenizer functions
TokenizerHandle* sgl_tokenizer_create_from_file(const char* tokenizer_path, char** error_out);
void sgl_tokenizer_free(TokenizerHandle* handle);
// Memory management
void sgl_free_string(char* s);
*/
import "C"
import (
"fmt"
"unsafe"
)
// CreateGrpcResponseConverter creates a gRPC response converter handle
// This function creates a new tokenizer handle each time (for backward compatibility)
// For better performance, use CreateGrpcResponseConverterWithTokenizer with a cached tokenizer
func CreateGrpcResponseConverter(
tokenizerPath string,
model string,
requestID string,
toolsJSON string,
toolChoiceJSON string,
stopJSON string,
stopTokenIDs []uint32,
skipSpecialTokens bool,
initialPromptTokens int32,
) (*GrpcResponseConverterHandle, error) {
// Create tokenizer handle
tokenizerHandle, err := createTokenizerHandle(tokenizerPath)
if err != nil {
return nil, fmt.Errorf("failed to create tokenizer handle: %w", err)
}
defer C.sgl_tokenizer_free(tokenizerHandle)
return createGrpcResponseConverterWithTokenizerHandle(
tokenizerHandle,
model,
requestID,
toolsJSON,
toolChoiceJSON,
stopJSON,
stopTokenIDs,
skipSpecialTokens,
initialPromptTokens,
)
}
// CreateGrpcResponseConverterWithTokenizer creates a gRPC response converter handle using a cached tokenizer
// This is more efficient as it reuses the tokenizer instead of creating a new one each time
func CreateGrpcResponseConverterWithTokenizer(
tokenizerHandle *TokenizerHandle,
model string,
requestID string,
toolsJSON string,
toolChoiceJSON string,
stopJSON string,
stopTokenIDs []uint32,
skipSpecialTokens bool,
initialPromptTokens int32,
) (*GrpcResponseConverterHandle, error) {
if tokenizerHandle == nil || tokenizerHandle.handle == nil {
return nil, fmt.Errorf("invalid tokenizer handle")
}
return createGrpcResponseConverterWithTokenizerHandle(
tokenizerHandle.handle,
model,
requestID,
toolsJSON,
toolChoiceJSON,
stopJSON,
stopTokenIDs,
skipSpecialTokens,
initialPromptTokens,
)
}
// createGrpcResponseConverterWithTokenizerHandle is the internal implementation
func createGrpcResponseConverterWithTokenizerHandle(
tokenizerHandle *C.TokenizerHandle,
model string,
requestID string,
toolsJSON string,
toolChoiceJSON string,
stopJSON string,
stopTokenIDs []uint32,
skipSpecialTokens bool,
initialPromptTokens int32,
) (*GrpcResponseConverterHandle, error) {
// Convert strings to C strings
modelC := C.CString(model)
defer C.free(unsafe.Pointer(modelC))
requestIDC := C.CString(requestID)
defer C.free(unsafe.Pointer(requestIDC))
var toolsJSONC *C.char
if toolsJSON != "" {
toolsJSONC = C.CString(toolsJSON)
defer C.free(unsafe.Pointer(toolsJSONC))
}
var toolChoiceJSONC *C.char
if toolChoiceJSON != "" {
toolChoiceJSONC = C.CString(toolChoiceJSON)
defer C.free(unsafe.Pointer(toolChoiceJSONC))
}
var stopJSONC *C.char
if stopJSON != "" {
stopJSONC = C.CString(stopJSON)
defer C.free(unsafe.Pointer(stopJSONC))
}
// Convert stop_token_ids to JSON string
stopTokenIDsJSON := ""
if len(stopTokenIDs) > 0 {
stopTokenIDsJSON = fmt.Sprintf("[%d", stopTokenIDs[0])
for i := 1; i < len(stopTokenIDs); i++ {
stopTokenIDsJSON += fmt.Sprintf(",%d", stopTokenIDs[i])
}
stopTokenIDsJSON += "]"
}
var stopTokenIDsJSONC *C.char
if stopTokenIDsJSON != "" {
stopTokenIDsJSONC = C.CString(stopTokenIDsJSON)
defer C.free(unsafe.Pointer(stopTokenIDsJSONC))
}
var errorOut *C.char
skipSpecialTokensC := C.int(0)
if skipSpecialTokens {
skipSpecialTokensC = C.int(1)
}
initialPromptTokensC := C.int(initialPromptTokens)
converterHandle := C.sgl_grpc_response_converter_create(
tokenizerHandle,
modelC,
requestIDC,
toolsJSONC,
toolChoiceJSONC,
stopJSONC,
stopTokenIDsJSONC,
skipSpecialTokensC,
initialPromptTokensC,
&errorOut,
)
if converterHandle == nil {
errorMsg := ""
if errorOut != nil {
errorMsg = C.GoString(errorOut)
C.sgl_free_string(errorOut)
}
if errorMsg == "" {
errorMsg = "failed to create converter handle"
}
return nil, fmt.Errorf("%s", errorMsg)
}
return &GrpcResponseConverterHandle{
handle: converterHandle,
}, nil
}
// FreeGrpcResponseConverter frees a gRPC response converter handle
func FreeGrpcResponseConverter(handle *GrpcResponseConverterHandle) {
if handle != nil && handle.handle != nil {
C.sgl_grpc_response_converter_free(handle.handle)
handle.handle = nil
}
}
// TokenizerHandle wraps the Rust tokenizer FFI handle
type TokenizerHandle struct {
handle *C.TokenizerHandle
}
// CreateTokenizerHandle creates a tokenizer handle (exported for caching)
func CreateTokenizerHandle(tokenizerPath string) (*TokenizerHandle, error) {
tokenizerPathC := C.CString(tokenizerPath)
defer C.free(unsafe.Pointer(tokenizerPathC))
var errorOut *C.char
tokenizerHandle := C.sgl_tokenizer_create_from_file(tokenizerPathC, &errorOut)
if tokenizerHandle == nil {
errorMsg := ""
if errorOut != nil {
errorMsg = C.GoString(errorOut)
C.sgl_free_string(errorOut)
}
if errorMsg == "" {
errorMsg = "failed to create tokenizer handle"
}
return nil, fmt.Errorf("%s", errorMsg)
}
return &TokenizerHandle{
handle: tokenizerHandle,
}, nil
}
// FreeTokenizerHandle frees a tokenizer handle
func FreeTokenizerHandle(handle *TokenizerHandle) {
if handle != nil && handle.handle != nil {
C.sgl_tokenizer_free(handle.handle)
handle.handle = nil
}
}
// createTokenizerHandle creates a tokenizer handle (helper function, internal use)
func createTokenizerHandle(tokenizerPath string) (*C.TokenizerHandle, error) {
tokenizerPathC := C.CString(tokenizerPath)
defer C.free(unsafe.Pointer(tokenizerPathC))
var errorOut *C.char
tokenizerHandle := C.sgl_tokenizer_create_from_file(tokenizerPathC, &errorOut)
if tokenizerHandle == nil {
errorMsg := ""
if errorOut != nil {
errorMsg = C.GoString(errorOut)
C.sgl_free_string(errorOut)
}
if errorMsg == "" {
errorMsg = "failed to create tokenizer handle"
}
return nil, fmt.Errorf("%s", errorMsg)
}
return tokenizerHandle, nil
}

View File

@@ -0,0 +1,156 @@
// Package ffi provides Go bindings for SGLang's Rust FFI (Foreign Function Interface).
package ffi
/*
#cgo LDFLAGS: -lsgl_model_gateway_go -ldl
#include <stdlib.h>
#include <stdint.h>
// Error codes (must match client.go)
typedef enum {
SGL_ERROR_SUCCESS = 0,
SGL_ERROR_INVALID_ARGUMENT = 1,
SGL_ERROR_TOKENIZATION_ERROR = 2,
SGL_ERROR_PARSING_ERROR = 3,
SGL_ERROR_MEMORY_ERROR = 4,
SGL_ERROR_UNKNOWN = 99
} SglErrorCode;
// Opaque handle (must match grpc_converter.go)
typedef void* GrpcResponseConverterHandle;
// Postprocessor functions
SglErrorCode sgl_postprocess_stream_chunk(
GrpcResponseConverterHandle* converter_handle,
const char* proto_chunk_json,
char** openai_json_out,
int* is_done_out,
char** error_out
);
SglErrorCode sgl_postprocess_stream_chunks_batch(
GrpcResponseConverterHandle* converter_handle,
const char* proto_chunks_json_array,
int max_chunks,
char** openai_chunks_json_array_out,
int* chunks_count_out,
char** error_out
);
// Memory management
void sgl_free_string(char* s);
*/
import "C"
import (
"fmt"
"unsafe"
)
// GrpcResponseConverterHandle wraps the Rust gRPC response converter FFI handle
type GrpcResponseConverterHandle struct {
handle *C.GrpcResponseConverterHandle
}
// PostprocessStreamChunk postprocesses a gRPC stream chunk to OpenAI format
//
// This function:
// 1. Parses the proto chunk from JSON
// 2. Converts it to OpenAI format using the converter handle
// 3. Returns the OpenAI format JSON
//
// Returns the OpenAI format JSON, is_done flag, and any error.
func PostprocessStreamChunk(converterHandle *GrpcResponseConverterHandle, protoChunkJSON string) (openaiJSON string, isDone bool, err error) {
if converterHandle == nil || converterHandle.handle == nil {
return "", false, fmt.Errorf("invalid converter handle")
}
protoChunkJSONC := C.CString(protoChunkJSON)
defer C.free(unsafe.Pointer(protoChunkJSONC))
var openaiJSONOut *C.char
var isDoneOut C.int
var errorOut *C.char
errorCode := C.sgl_postprocess_stream_chunk(
converterHandle.handle,
protoChunkJSONC,
&openaiJSONOut,
&isDoneOut,
&errorOut,
)
if errorCode != C.SGL_ERROR_SUCCESS {
errorMsg := ""
if errorOut != nil {
errorMsg = C.GoString(errorOut)
C.sgl_free_string(errorOut)
}
return "", false, fmt.Errorf("postprocessing failed: %s", errorMsg)
}
openaiJSON = C.GoString(openaiJSONOut)
isDone = isDoneOut != 0
// Free the C string allocated by Rust
if openaiJSONOut != nil {
C.sgl_free_string(openaiJSONOut)
}
return openaiJSON, isDone, nil
}
// PostprocessStreamChunksBatch postprocesses multiple gRPC stream chunks in batch
//
// This function processes multiple chunks in a single FFI call, significantly reducing
// FFI overhead in streaming scenarios.
//
// Arguments:
// - converterHandle: Converter handle
// - protoChunksJSONArray: JSON array string of proto chunks
// - maxChunks: Maximum number of chunks to process (for safety, typically 10-20)
//
// Returns:
// - openaiChunksJSONArray: JSON array of OpenAI format chunks
// - chunksCount: Number of processed chunks
// - error: Any error that occurred
func PostprocessStreamChunksBatch(converterHandle *GrpcResponseConverterHandle, protoChunksJSONArray string, maxChunks int) (openaiChunksJSONArray string, chunksCount int, err error) {
if converterHandle == nil || converterHandle.handle == nil {
return "", 0, fmt.Errorf("invalid converter handle")
}
protoChunksJSONArrayC := C.CString(protoChunksJSONArray)
defer C.free(unsafe.Pointer(protoChunksJSONArrayC))
var openaiChunksJSONArrayOut *C.char
var chunksCountOut C.int
var errorOut *C.char
errorCode := C.sgl_postprocess_stream_chunks_batch(
converterHandle.handle,
protoChunksJSONArrayC,
C.int(maxChunks),
&openaiChunksJSONArrayOut,
&chunksCountOut,
&errorOut,
)
if errorCode != C.SGL_ERROR_SUCCESS {
errorMsg := ""
if errorOut != nil {
errorMsg = C.GoString(errorOut)
C.sgl_free_string(errorOut)
}
return "", 0, fmt.Errorf("batch postprocessing failed: %s", errorMsg)
}
openaiChunksJSONArray = C.GoString(openaiChunksJSONArrayOut)
chunksCount = int(chunksCountOut)
// Free the C string allocated by Rust
if openaiChunksJSONArrayOut != nil {
C.sgl_free_string(openaiChunksJSONArrayOut)
}
return openaiChunksJSONArray, chunksCount, nil
}

View File

@@ -0,0 +1,246 @@
// Package ffi provides Go bindings for SGLang's Rust FFI (Foreign Function Interface).
package ffi
/*
#cgo LDFLAGS: -lsgl_model_gateway_go -ldl
#include <stdlib.h>
#include <stdint.h>
// Error codes (must match client.go)
typedef enum {
SGL_ERROR_SUCCESS = 0,
SGL_ERROR_INVALID_ARGUMENT = 1,
SGL_ERROR_TOKENIZATION_ERROR = 2,
SGL_ERROR_PARSING_ERROR = 3,
SGL_ERROR_MEMORY_ERROR = 4,
SGL_ERROR_UNKNOWN = 99
} SglErrorCode;
// Preprocessor functions
SglErrorCode sgl_preprocess_chat_request(
const char* request_json,
const char* tokenizer_path,
char** prompt_text_out,
uint32_t** token_ids_out,
size_t* token_ids_len_out,
char** tool_constraints_json_out,
int32_t* prompt_tokens_out,
char** error_out
);
// Opaque handle (must match grpc_converter.go)
typedef void* TokenizerHandle;
SglErrorCode sgl_preprocess_chat_request_with_tokenizer(
const char* request_json,
void* tokenizer_handle,
char** prompt_text_out,
uint32_t** token_ids_out,
size_t* token_ids_len_out,
char** tool_constraints_json_out,
int32_t* prompt_tokens_out,
char** error_out
);
void sgl_preprocessed_request_free(
char* prompt_text,
uint32_t* token_ids,
size_t token_ids_len,
char* tool_constraints_json
);
// Memory management
void sgl_free_string(char* s);
void sgl_free_token_ids(uint32_t* ptr, size_t count);
*/
import "C"
import (
"fmt"
"unsafe"
)
// PreprocessedRequest represents a preprocessed chat request
type PreprocessedRequest struct {
PromptText string
TokenIDs []uint32
ToolConstraintsJSON string
PromptTokens int32
// Internal pointers for memory management
promptTextPtr *C.char
tokenIDsPtr *C.uint32_t
tokenIDsLen uintptr
toolConstraintsJSONPtr *C.char
}
// PreprocessChatRequest preprocesses a chat completion request
//
// This function:
// 1. Applies chat_template to messages
// 2. Tokenizes the processed text
// 3. Generates tool constraints (if tools are present)
//
// Returns the preprocessed request data and any error.
func PreprocessChatRequest(requestJSON, tokenizerPath string) (*PreprocessedRequest, error) {
requestJSONC := C.CString(requestJSON)
defer C.free(unsafe.Pointer(requestJSONC))
tokenizerPathC := C.CString(tokenizerPath)
defer C.free(unsafe.Pointer(tokenizerPathC))
var promptTextOut *C.char
var tokenIDsOut *C.uint32_t
var tokenIDsLenOut C.size_t
var toolConstraintsJSONOut *C.char
var promptTokensOut C.int32_t
var errorOut *C.char
errorCode := C.sgl_preprocess_chat_request(
requestJSONC,
tokenizerPathC,
&promptTextOut,
&tokenIDsOut,
&tokenIDsLenOut,
&toolConstraintsJSONOut,
&promptTokensOut,
&errorOut,
)
if errorCode != C.SGL_ERROR_SUCCESS {
errorMsg := ""
if errorOut != nil {
errorMsg = C.GoString(errorOut)
C.sgl_free_string(errorOut)
}
return nil, fmt.Errorf("preprocessing failed: %s", errorMsg)
}
result := &PreprocessedRequest{
PromptText: C.GoString(promptTextOut),
TokenIDs: make([]uint32, tokenIDsLenOut),
ToolConstraintsJSON: "",
PromptTokens: int32(promptTokensOut),
}
// Copy token IDs
if tokenIDsOut != nil && tokenIDsLenOut > 0 {
tokenIDsSlice := (*[1 << 30]C.uint32_t)(unsafe.Pointer(tokenIDsOut))[:tokenIDsLenOut:tokenIDsLenOut]
for i := range result.TokenIDs {
result.TokenIDs[i] = uint32(tokenIDsSlice[i])
}
}
// Copy tool constraints JSON if present
if toolConstraintsJSONOut != nil {
result.ToolConstraintsJSON = C.GoString(toolConstraintsJSONOut)
}
// Store pointers for later cleanup
result.promptTextPtr = promptTextOut
result.tokenIDsPtr = tokenIDsOut
result.tokenIDsLen = uintptr(tokenIDsLenOut)
result.toolConstraintsJSONPtr = toolConstraintsJSONOut
return result, nil
}
// PreprocessChatRequestWithTokenizer preprocesses a chat completion request using an existing tokenizer handle
//
// This function is similar to PreprocessChatRequest, but accepts a TokenizerHandle
// instead of creating a new tokenizer. This allows reusing a cached tokenizer instance,
// significantly reducing initialization overhead in concurrent scenarios.
//
// Returns the preprocessed request data and any error.
func PreprocessChatRequestWithTokenizer(requestJSON string, tokenizerHandle *TokenizerHandle) (*PreprocessedRequest, error) {
requestJSONC := C.CString(requestJSON)
defer C.free(unsafe.Pointer(requestJSONC))
if tokenizerHandle == nil || tokenizerHandle.handle == nil {
return nil, fmt.Errorf("invalid tokenizer handle")
}
var promptTextOut *C.char
var tokenIDsOut *C.uint32_t
var tokenIDsLenOut C.size_t
var toolConstraintsJSONOut *C.char
var promptTokensOut C.int32_t
var errorOut *C.char
errorCode := C.sgl_preprocess_chat_request_with_tokenizer(
requestJSONC,
unsafe.Pointer(tokenizerHandle.handle), // Convert *C.TokenizerHandle to void*
&promptTextOut,
&tokenIDsOut,
&tokenIDsLenOut,
&toolConstraintsJSONOut,
&promptTokensOut,
&errorOut,
)
if errorCode != C.SGL_ERROR_SUCCESS {
errorMsg := ""
if errorOut != nil {
errorMsg = C.GoString(errorOut)
C.sgl_free_string(errorOut)
}
return nil, fmt.Errorf("preprocessing failed: %s", errorMsg)
}
result := &PreprocessedRequest{
PromptText: C.GoString(promptTextOut),
TokenIDs: make([]uint32, tokenIDsLenOut),
ToolConstraintsJSON: "",
PromptTokens: int32(promptTokensOut),
}
// Copy token IDs
if tokenIDsOut != nil && tokenIDsLenOut > 0 {
tokenIDsSlice := (*[1 << 30]C.uint32_t)(unsafe.Pointer(tokenIDsOut))[:tokenIDsLenOut:tokenIDsLenOut]
for i := range result.TokenIDs {
result.TokenIDs[i] = uint32(tokenIDsSlice[i])
}
}
// Copy tool constraints JSON if present
if toolConstraintsJSONOut != nil {
result.ToolConstraintsJSON = C.GoString(toolConstraintsJSONOut)
}
// Store pointers for later cleanup
result.promptTextPtr = promptTextOut
result.tokenIDsPtr = tokenIDsOut
result.tokenIDsLen = uintptr(tokenIDsLenOut)
result.toolConstraintsJSONPtr = toolConstraintsJSONOut
return result, nil
}
// Free frees the memory allocated for a preprocessed request
func (p *PreprocessedRequest) Free() {
if p.promptTextPtr != nil || p.tokenIDsPtr != nil || p.toolConstraintsJSONPtr != nil {
C.sgl_preprocessed_request_free(
p.promptTextPtr,
p.tokenIDsPtr,
C.size_t(p.tokenIDsLen),
p.toolConstraintsJSONPtr,
)
// Clear pointers to prevent double-free
p.promptTextPtr = nil
p.tokenIDsPtr = nil
p.tokenIDsLen = 0
p.toolConstraintsJSONPtr = nil
}
}
// FreePreprocessedRequest frees the memory allocated for a preprocessed request
// This is a convenience function for direct pointer management
func FreePreprocessedRequest(promptTextPtr *C.char, tokenIDsPtr *C.uint32_t, tokenIDsLen uintptr, toolConstraintsJSONPtr *C.char) {
if promptTextPtr != nil || tokenIDsPtr != nil || toolConstraintsJSONPtr != nil {
C.sgl_preprocessed_request_free(
promptTextPtr,
tokenIDsPtr,
C.size_t(tokenIDsLen),
toolConstraintsJSONPtr,
)
}
}

View File

@@ -0,0 +1,684 @@
// Package grpc provides gRPC client implementation for SGLang
package grpc
import (
"context"
"encoding/json"
"fmt"
"io"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/sglang/sglang-go-grpc-sdk/internal/ffi"
"github.com/sglang/sglang-go-grpc-sdk/internal/proto"
)
type grpcClientStream interface {
Recv() (*proto.GenerateResponse, error)
CloseSend() error
}
// recvResult holds the result of a Recv() call
type recvResult struct {
resp *proto.GenerateResponse
err error
}
type GrpcClient struct {
conn *grpc.ClientConn
client proto.SglangSchedulerClient
tokenizerPath string
tokenizerHandle *ffi.TokenizerHandle
bufferSizes ChannelBufferSizes
timeouts Timeouts
requestCounter uint64 // Atomic counter to ensure unique request IDs
}
type ChannelBufferSizes struct {
ResultJSONChan int
ErrChan int
RecvChan int
}
type Timeouts struct {
KeepaliveTime time.Duration
KeepaliveTimeout time.Duration
CloseTimeout time.Duration
}
func NewGrpcClient(endpoint, tokenizerPath string, bufferSizes ChannelBufferSizes, timeouts Timeouts) (*GrpcClient, error) {
endpoint = strings.TrimPrefix(endpoint, "grpc://")
if !strings.Contains(endpoint, ":") {
return nil, fmt.Errorf("invalid endpoint format: %s (expected grpc://host:port)", endpoint)
}
keepaliveParams := keepalive.ClientParameters{
Time: timeouts.KeepaliveTime,
Timeout: timeouts.KeepaliveTimeout,
PermitWithoutStream: false,
}
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithKeepaliveParams(keepaliveParams),
}
conn, err := grpc.NewClient(endpoint, opts...)
if err != nil {
return nil, fmt.Errorf("failed to connect to gRPC server: %w", err)
}
client := proto.NewSglangSchedulerClient(conn)
tokenizerHandle, err := ffi.CreateTokenizerHandle(tokenizerPath)
if err != nil {
conn.Close()
return nil, fmt.Errorf("failed to create tokenizer handle: %w", err)
}
return &GrpcClient{
conn: conn,
client: client,
tokenizerPath: tokenizerPath,
tokenizerHandle: tokenizerHandle,
bufferSizes: bufferSizes,
timeouts: timeouts,
}, nil
}
func (c *GrpcClient) Close() error {
if c.tokenizerHandle != nil {
ffi.FreeTokenizerHandle(c.tokenizerHandle)
c.tokenizerHandle = nil
}
if c.conn != nil {
return c.conn.Close()
}
return nil
}
func (c *GrpcClient) CreateChatCompletionStream(ctx context.Context, reqJSON string) (*GrpcChatCompletionStream, error) {
if c.tokenizerHandle == nil {
return nil, fmt.Errorf("tokenizer handle is nil (should be created at startup)")
}
preprocessed, err := ffi.PreprocessChatRequestWithTokenizer(reqJSON, c.tokenizerHandle)
if err != nil {
return nil, fmt.Errorf("preprocessing failed: %w", err)
}
defer func() {
if preprocessed != nil {
preprocessed.Free()
}
}()
// Parse request JSON to get parameters
var reqMap map[string]interface{}
if err := json.Unmarshal([]byte(reqJSON), &reqMap); err != nil {
return nil, fmt.Errorf("failed to parse request JSON: %w", err)
}
model, _ := reqMap["model"].(string)
if model == "" {
model = "default"
}
// Build GenerateRequest
// Generate unique request ID using timestamp + atomic counter to avoid collisions
// This matches Rust version's UUID-based approach for uniqueness
counter := atomic.AddUint64(&c.requestCounter, 1)
requestID := fmt.Sprintf("chatcmpl-%d-%d", time.Now().UnixNano(), counter)
generateReq := &proto.GenerateRequest{
RequestId: requestID,
Tokenized: &proto.TokenizedInput{
OriginalText: preprocessed.PromptText,
InputIds: preprocessed.TokenIDs,
},
Stream: true,
}
// Set sampling parameters
samplingParams := &proto.SamplingParams{
Temperature: 1.0,
TopP: 1.0,
TopK: -1,
SkipSpecialTokens: true,
}
if temp, ok := reqMap["temperature"].(float64); ok {
samplingParams.Temperature = float32(temp)
}
if topP, ok := reqMap["top_p"].(float64); ok {
samplingParams.TopP = float32(topP)
}
if topK, ok := reqMap["top_k"].(float64); ok {
samplingParams.TopK = int32(topK)
}
var maxTokensInt *int32
if maxCompletionTokens, ok := reqMap["max_completion_tokens"].(float64); ok {
tokens := int32(maxCompletionTokens)
maxTokensInt = &tokens
} else if maxTokens, ok := reqMap["max_tokens"].(float64); ok {
tokens := int32(maxTokens)
maxTokensInt = &tokens
}
if maxTokensInt != nil {
samplingParams.MaxNewTokens = maxTokensInt
}
// Parse tool constraints if available
if preprocessed.ToolConstraintsJSON != "" {
var toolConstraints map[string]interface{}
if err := json.Unmarshal([]byte(preprocessed.ToolConstraintsJSON), &toolConstraints); err == nil {
if regex, ok := toolConstraints["regex"].(string); ok {
samplingParams.Constraint = &proto.SamplingParams_Regex{Regex: regex}
} else if jsonSchema, ok := toolConstraints["json_schema"].(string); ok {
samplingParams.Constraint = &proto.SamplingParams_JsonSchema{JsonSchema: jsonSchema}
}
}
}
generateReq.SamplingParams = samplingParams
generateReq.Timestamp = timestamppb.Now()
stream, err := c.client.Generate(ctx, generateReq)
if err != nil {
return nil, fmt.Errorf("failed to create gRPC stream: %w", err)
}
toolsJSON := ""
if tools, ok := reqMap["tools"].([]interface{}); ok && len(tools) > 0 {
toolsBytes, _ := json.Marshal(tools)
toolsJSON = string(toolsBytes)
}
toolChoiceJSON := ""
if toolChoice, ok := reqMap["tool_choice"]; ok {
toolChoiceBytes, _ := json.Marshal(toolChoice)
toolChoiceJSON = string(toolChoiceBytes)
}
stopJSON := ""
if stop, ok := reqMap["stop"]; ok {
stopBytes, _ := json.Marshal(stop)
stopJSON = string(stopBytes)
}
stopTokenIDs := []uint32{}
if stopTokenIDsVal, ok := reqMap["stop_token_ids"].([]interface{}); ok {
for _, id := range stopTokenIDsVal {
if idFloat, ok := id.(float64); ok {
stopTokenIDs = append(stopTokenIDs, uint32(idFloat))
}
}
}
skipSpecialTokens := true
if skipSpecialTokensVal, ok := reqMap["skip_special_tokens"].(bool); ok {
skipSpecialTokens = skipSpecialTokensVal
}
if c.tokenizerHandle == nil {
stream.CloseSend()
return nil, fmt.Errorf("tokenizer handle is nil (should be created at startup)")
}
converterHandle, err := ffi.CreateGrpcResponseConverterWithTokenizer(
c.tokenizerHandle,
model,
generateReq.RequestId,
toolsJSON,
toolChoiceJSON,
stopJSON,
stopTokenIDs,
skipSpecialTokens,
preprocessed.PromptTokens, // Pass initial prompt tokens from preprocessing
)
if err != nil {
stream.CloseSend()
return nil, fmt.Errorf("failed to create converter handle: %w", err)
}
batchSize := 1
batchPostprocessor := ffi.NewBatchPostprocessor(converterHandle, batchSize, 0)
streamCtx, cancel := context.WithCancel(ctx)
grpcStream := &GrpcChatCompletionStream{
stream: stream,
converterHandle: converterHandle,
batchPostprocessor: batchPostprocessor,
batchSize: batchSize,
ctx: streamCtx,
cancel: cancel,
resultJSONChan: make(chan string, c.bufferSizes.ResultJSONChan),
errChan: make(chan error, c.bufferSizes.ErrChan),
readLoopDone: make(chan struct{}),
requestID: generateReq.RequestId,
model: model,
processWg: sync.WaitGroup{},
closeTimeout: c.timeouts.CloseTimeout,
bufferSizes: c.bufferSizes,
}
go grpcStream.readLoop()
return grpcStream, nil
}
// GrpcChatCompletionStream represents a streaming chat completion via gRPC
type GrpcChatCompletionStream struct {
stream grpcClientStream
converterHandle *ffi.GrpcResponseConverterHandle
batchPostprocessor *ffi.BatchPostprocessor
batchSize int
ctx context.Context
cancel context.CancelFunc
closed int32
resultJSONChan chan string
errChan chan error
readLoopDone chan struct{}
requestID string
model string
processWg sync.WaitGroup
closeTimeout time.Duration
bufferSizes ChannelBufferSizes
clientDisconnected int32 // Atomic flag: 1 if client disconnected, 0 otherwise
}
func (s *GrpcChatCompletionStream) readLoop() {
defer func() {
atomic.StoreInt32(&s.closed, 1)
s.processWg.Wait()
close(s.resultJSONChan)
close(s.errChan)
close(s.readLoopDone)
// Cancel context after channels are closed to ensure errors are read first
if s.cancel != nil {
s.cancel()
}
}()
recvChan := make(chan recvResult, s.bufferSizes.RecvChan)
const firstRecvTimeout = 60 * time.Second
go func() {
defer close(recvChan)
recvCount := 0
for {
select {
case <-s.ctx.Done():
// Skip CloseSend() if client disconnected
if atomic.LoadInt32(&s.clientDisconnected) == 0 {
_ = s.stream.CloseSend()
}
return
default:
}
recvCount++
var protoResp *proto.GenerateResponse
var err error
// First Recv() with timeout
if recvCount == 1 {
recvDone := make(chan recvResult, 1)
go func() {
resp, recvErr := s.stream.Recv()
recvDone <- recvResult{resp: resp, err: recvErr}
}()
select {
case result := <-recvDone:
protoResp = result.resp
err = result.err
case <-time.After(firstRecvTimeout):
timeoutErr := fmt.Errorf("stream.Recv() timeout after %v: backend may not be responding (request_id=%s)", firstRecvTimeout, s.requestID)
select {
case recvChan <- recvResult{resp: nil, err: timeoutErr}:
case <-s.ctx.Done():
}
return
case <-s.ctx.Done():
return
}
} else {
// Normal Recv()
protoResp, err = s.stream.Recv()
}
if err != nil {
select {
case recvChan <- recvResult{resp: nil, err: err}:
case <-s.ctx.Done():
return
}
return
}
select {
case <-s.ctx.Done():
// Skip CloseSend() if client disconnected
if atomic.LoadInt32(&s.clientDisconnected) == 0 {
_ = s.stream.CloseSend()
}
return
case recvChan <- recvResult{resp: protoResp, err: nil}:
}
}
}()
for {
select {
case <-s.ctx.Done():
// Skip CloseSend() if client disconnected
if atomic.LoadInt32(&s.clientDisconnected) == 0 {
_ = s.stream.CloseSend()
}
return
case result, ok := <-recvChan:
if !ok {
return
}
if result.err != nil {
if result.err == io.EOF {
results, flushErr := s.flushBatch()
if flushErr != nil {
select {
case s.errChan <- fmt.Errorf("failed to flush batch: %w", flushErr):
case <-s.ctx.Done():
}
return
}
for _, resultJSON := range results {
select {
case s.resultJSONChan <- resultJSON:
case <-s.ctx.Done():
return
}
}
return
}
select {
case s.errChan <- result.err:
case <-s.ctx.Done():
}
return
}
if result.resp != nil {
s.processWg.Add(1)
go func(resp *proto.GenerateResponse) {
defer s.processWg.Done()
s.processAndSendResponse(resp)
}(result.resp)
}
}
}
}
func (s *GrpcChatCompletionStream) processAndSendResponse(protoResp *proto.GenerateResponse) {
select {
case <-s.ctx.Done():
return
default:
}
if protoResp == nil {
return
}
protoJSON, err := protoToJSON(protoResp)
if err != nil {
select {
case s.errChan <- fmt.Errorf("failed to convert proto to JSON: %w", err):
case <-s.ctx.Done():
}
return
}
if s.batchPostprocessor == nil {
select {
case s.errChan <- fmt.Errorf("batch postprocessor is nil"):
case <-s.ctx.Done():
}
return
}
results, _, err := s.batchPostprocessor.AddChunk(protoJSON)
if err != nil {
select {
case s.errChan <- fmt.Errorf("batch postprocessing failed: %w", err):
case <-s.ctx.Done():
}
return
}
for _, resultJSON := range results {
select {
case s.resultJSONChan <- resultJSON:
case <-s.ctx.Done():
return
}
}
}
func (s *GrpcChatCompletionStream) RecvJSON() (string, error) {
// Use a loop instead of recursion to avoid stack overflow if there are many empty strings
for {
// Check errChan first to prioritize actual errors over context cancellation
select {
case err, ok := <-s.errChan:
if !ok {
return "", io.EOF
}
return "", err
default:
}
select {
case resultJSON, ok := <-s.resultJSONChan:
if !ok {
return "", io.EOF
}
// Skip empty strings and continue loop instead of recursing
if resultJSON != "" {
return resultJSON, nil
}
// Empty string, continue loop to get next result
continue
case err, ok := <-s.errChan:
if !ok {
return "", io.EOF
}
return "", err
case <-s.ctx.Done():
return "", s.ctx.Err()
}
}
}
// SetClientDisconnected marks that the client has disconnected.
// When Close() is called, it will not call CloseSend() to avoid aborting the request on server side.
func (s *GrpcChatCompletionStream) SetClientDisconnected() {
atomic.StoreInt32(&s.clientDisconnected, 1)
}
func (s *GrpcChatCompletionStream) Close() error {
if !atomic.CompareAndSwapInt32(&s.closed, 0, 1) {
return nil
}
if s.cancel != nil {
s.cancel()
}
clientDisconnected := atomic.LoadInt32(&s.clientDisconnected) == 1
select {
case <-s.readLoopDone:
// readLoop completed
default:
if !clientDisconnected {
// Call CloseSend() if client didn't disconnect
_ = s.stream.CloseSend()
}
select {
case <-s.readLoopDone:
case <-time.After(s.closeTimeout):
}
}
_, _ = s.flushBatch()
if s.converterHandle != nil {
ffi.FreeGrpcResponseConverter(s.converterHandle)
}
return nil
}
func (s *GrpcChatCompletionStream) flushBatch() ([]string, error) {
if s.batchPostprocessor != nil {
results, err := s.batchPostprocessor.Flush()
if err != nil {
return nil, fmt.Errorf("batch flush failed: %w", err)
}
return results, nil
}
return nil, nil
}
func protoToJSON(resp *proto.GenerateResponse) (string, error) {
var sb strings.Builder
sb.Grow(500)
sb.WriteString(`{"request_id":`)
if resp.RequestId == "" {
sb.WriteString(`""`)
} else {
requestIDJSON, err := json.Marshal(resp.RequestId)
if err != nil {
return "", err
}
sb.Write(requestIDJSON)
}
switch r := resp.Response.(type) {
case *proto.GenerateResponse_Chunk:
sb.WriteString(`,"chunk":{`)
sb.WriteString(`"token_ids":`)
tokenIDsJSON, err := json.Marshal(r.Chunk.TokenIds)
if err != nil {
return "", err
}
sb.Write(tokenIDsJSON)
sb.WriteString(`,"prompt_tokens":`)
sb.WriteString(strconv.FormatInt(int64(r.Chunk.PromptTokens), 10))
sb.WriteString(`,"completion_tokens":`)
sb.WriteString(strconv.FormatInt(int64(r.Chunk.CompletionTokens), 10))
sb.WriteString(`,"cached_tokens":`)
sb.WriteString(strconv.FormatInt(int64(r.Chunk.CachedTokens), 10))
sb.WriteString(`,"index":`)
sb.WriteString(strconv.FormatInt(int64(r.Chunk.Index), 10))
sb.WriteString(`}`)
case *proto.GenerateResponse_Complete:
sb.WriteString(`,"complete":{`)
sb.WriteString(`"output_ids":`)
outputIDsJSON, err := json.Marshal(r.Complete.OutputIds)
if err != nil {
return "", err
}
sb.Write(outputIDsJSON)
sb.WriteString(`,"finish_reason":`)
finishReasonJSON, err := json.Marshal(r.Complete.FinishReason)
if err != nil {
return "", err
}
sb.Write(finishReasonJSON)
sb.WriteString(`,"prompt_tokens":`)
sb.WriteString(strconv.FormatInt(int64(r.Complete.PromptTokens), 10))
sb.WriteString(`,"completion_tokens":`)
sb.WriteString(strconv.FormatInt(int64(r.Complete.CompletionTokens), 10))
sb.WriteString(`,"cached_tokens":`)
sb.WriteString(strconv.FormatInt(int64(r.Complete.CachedTokens), 10))
sb.WriteString(`}`)
case *proto.GenerateResponse_Error:
sb.WriteString(`,"error":{`)
sb.WriteString(`"message":`)
messageJSON, err := json.Marshal(r.Error.Message)
if err != nil {
return "", err
}
sb.Write(messageJSON)
sb.WriteString(`,"http_status_code":`)
httpStatusCodeJSON, err := json.Marshal(r.Error.HttpStatusCode)
if err != nil {
return "", err
}
sb.Write(httpStatusCodeJSON)
if r.Error.Details != "" {
sb.WriteString(`,"details":`)
detailsJSON, err := json.Marshal(r.Error.Details)
if err != nil {
return "", err
}
sb.Write(detailsJSON)
}
sb.WriteString(`}`)
}
sb.WriteString(`}`)
return sb.String(), nil
}
type ChatCompletionStreamResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
Choices []StreamChoice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
}
// StreamChoice represents a choice in a streaming response
type StreamChoice struct {
Index int `json:"index"`
Delta MessageDelta `json:"delta"`
FinishReason string `json:"finish_reason,omitempty"`
}
// MessageDelta represents incremental message updates
type MessageDelta struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
// ToolCall represents a tool call in the response
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function FunctionCall `json:"function"`
}
// FunctionCall represents a function call
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
// Usage represents token usage information
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,333 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v3.21.12
// source: sglang_scheduler.proto
package proto
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9
const (
SglangScheduler_Generate_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/Generate"
SglangScheduler_Embed_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/Embed"
SglangScheduler_HealthCheck_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/HealthCheck"
SglangScheduler_Abort_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/Abort"
SglangScheduler_GetModelInfo_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/GetModelInfo"
SglangScheduler_GetServerInfo_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/GetServerInfo"
)
// SglangSchedulerClient is the client API for SglangScheduler service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
//
// Service definition for SGLang scheduler communication
// This protocol bridges the Rust router and Python scheduler
type SglangSchedulerClient interface {
// Submit a generation request (supports streaming)
Generate(ctx context.Context, in *GenerateRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[GenerateResponse], error)
// Submit an embedding request
Embed(ctx context.Context, in *EmbedRequest, opts ...grpc.CallOption) (*EmbedResponse, error)
// Health check and metrics
HealthCheck(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error)
// Abort a running request
Abort(ctx context.Context, in *AbortRequest, opts ...grpc.CallOption) (*AbortResponse, error)
// Get model information
GetModelInfo(ctx context.Context, in *GetModelInfoRequest, opts ...grpc.CallOption) (*GetModelInfoResponse, error)
// Get server information
GetServerInfo(ctx context.Context, in *GetServerInfoRequest, opts ...grpc.CallOption) (*GetServerInfoResponse, error)
}
type sglangSchedulerClient struct {
cc grpc.ClientConnInterface
}
func NewSglangSchedulerClient(cc grpc.ClientConnInterface) SglangSchedulerClient {
return &sglangSchedulerClient{cc}
}
func (c *sglangSchedulerClient) Generate(ctx context.Context, in *GenerateRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[GenerateResponse], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &SglangScheduler_ServiceDesc.Streams[0], SglangScheduler_Generate_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
x := &grpc.GenericClientStream[GenerateRequest, GenerateResponse]{ClientStream: stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type SglangScheduler_GenerateClient = grpc.ServerStreamingClient[GenerateResponse]
func (c *sglangSchedulerClient) Embed(ctx context.Context, in *EmbedRequest, opts ...grpc.CallOption) (*EmbedResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(EmbedResponse)
err := c.cc.Invoke(ctx, SglangScheduler_Embed_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *sglangSchedulerClient) HealthCheck(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(HealthCheckResponse)
err := c.cc.Invoke(ctx, SglangScheduler_HealthCheck_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *sglangSchedulerClient) Abort(ctx context.Context, in *AbortRequest, opts ...grpc.CallOption) (*AbortResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(AbortResponse)
err := c.cc.Invoke(ctx, SglangScheduler_Abort_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *sglangSchedulerClient) GetModelInfo(ctx context.Context, in *GetModelInfoRequest, opts ...grpc.CallOption) (*GetModelInfoResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(GetModelInfoResponse)
err := c.cc.Invoke(ctx, SglangScheduler_GetModelInfo_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *sglangSchedulerClient) GetServerInfo(ctx context.Context, in *GetServerInfoRequest, opts ...grpc.CallOption) (*GetServerInfoResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(GetServerInfoResponse)
err := c.cc.Invoke(ctx, SglangScheduler_GetServerInfo_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// SglangSchedulerServer is the server API for SglangScheduler service.
// All implementations must embed UnimplementedSglangSchedulerServer
// for forward compatibility.
//
// Service definition for SGLang scheduler communication
// This protocol bridges the Rust router and Python scheduler
type SglangSchedulerServer interface {
// Submit a generation request (supports streaming)
Generate(*GenerateRequest, grpc.ServerStreamingServer[GenerateResponse]) error
// Submit an embedding request
Embed(context.Context, *EmbedRequest) (*EmbedResponse, error)
// Health check and metrics
HealthCheck(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error)
// Abort a running request
Abort(context.Context, *AbortRequest) (*AbortResponse, error)
// Get model information
GetModelInfo(context.Context, *GetModelInfoRequest) (*GetModelInfoResponse, error)
// Get server information
GetServerInfo(context.Context, *GetServerInfoRequest) (*GetServerInfoResponse, error)
mustEmbedUnimplementedSglangSchedulerServer()
}
// UnimplementedSglangSchedulerServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedSglangSchedulerServer struct{}
func (UnimplementedSglangSchedulerServer) Generate(*GenerateRequest, grpc.ServerStreamingServer[GenerateResponse]) error {
return status.Errorf(codes.Unimplemented, "method Generate not implemented")
}
func (UnimplementedSglangSchedulerServer) Embed(context.Context, *EmbedRequest) (*EmbedResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Embed not implemented")
}
func (UnimplementedSglangSchedulerServer) HealthCheck(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method HealthCheck not implemented")
}
func (UnimplementedSglangSchedulerServer) Abort(context.Context, *AbortRequest) (*AbortResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Abort not implemented")
}
func (UnimplementedSglangSchedulerServer) GetModelInfo(context.Context, *GetModelInfoRequest) (*GetModelInfoResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetModelInfo not implemented")
}
func (UnimplementedSglangSchedulerServer) GetServerInfo(context.Context, *GetServerInfoRequest) (*GetServerInfoResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetServerInfo not implemented")
}
func (UnimplementedSglangSchedulerServer) mustEmbedUnimplementedSglangSchedulerServer() {}
func (UnimplementedSglangSchedulerServer) testEmbeddedByValue() {}
// UnsafeSglangSchedulerServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to SglangSchedulerServer will
// result in compilation errors.
type UnsafeSglangSchedulerServer interface {
mustEmbedUnimplementedSglangSchedulerServer()
}
func RegisterSglangSchedulerServer(s grpc.ServiceRegistrar, srv SglangSchedulerServer) {
// If the following call pancis, it indicates UnimplementedSglangSchedulerServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&SglangScheduler_ServiceDesc, srv)
}
func _SglangScheduler_Generate_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(GenerateRequest)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(SglangSchedulerServer).Generate(m, &grpc.GenericServerStream[GenerateRequest, GenerateResponse]{ServerStream: stream})
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type SglangScheduler_GenerateServer = grpc.ServerStreamingServer[GenerateResponse]
func _SglangScheduler_Embed_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EmbedRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(SglangSchedulerServer).Embed(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: SglangScheduler_Embed_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(SglangSchedulerServer).Embed(ctx, req.(*EmbedRequest))
}
return interceptor(ctx, in, info, handler)
}
func _SglangScheduler_HealthCheck_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(HealthCheckRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(SglangSchedulerServer).HealthCheck(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: SglangScheduler_HealthCheck_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(SglangSchedulerServer).HealthCheck(ctx, req.(*HealthCheckRequest))
}
return interceptor(ctx, in, info, handler)
}
func _SglangScheduler_Abort_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AbortRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(SglangSchedulerServer).Abort(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: SglangScheduler_Abort_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(SglangSchedulerServer).Abort(ctx, req.(*AbortRequest))
}
return interceptor(ctx, in, info, handler)
}
func _SglangScheduler_GetModelInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetModelInfoRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(SglangSchedulerServer).GetModelInfo(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: SglangScheduler_GetModelInfo_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(SglangSchedulerServer).GetModelInfo(ctx, req.(*GetModelInfoRequest))
}
return interceptor(ctx, in, info, handler)
}
func _SglangScheduler_GetServerInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetServerInfoRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(SglangSchedulerServer).GetServerInfo(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: SglangScheduler_GetServerInfo_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(SglangSchedulerServer).GetServerInfo(ctx, req.(*GetServerInfoRequest))
}
return interceptor(ctx, in, info, handler)
}
// SglangScheduler_ServiceDesc is the grpc.ServiceDesc for SglangScheduler service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var SglangScheduler_ServiceDesc = grpc.ServiceDesc{
ServiceName: "sglang.grpc.scheduler.SglangScheduler",
HandlerType: (*SglangSchedulerServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Embed",
Handler: _SglangScheduler_Embed_Handler,
},
{
MethodName: "HealthCheck",
Handler: _SglangScheduler_HealthCheck_Handler,
},
{
MethodName: "Abort",
Handler: _SglangScheduler_Abort_Handler,
},
{
MethodName: "GetModelInfo",
Handler: _SglangScheduler_GetModelInfo_Handler,
},
{
MethodName: "GetServerInfo",
Handler: _SglangScheduler_GetServerInfo_Handler,
},
},
Streams: []grpc.StreamDesc{
{
StreamName: "Generate",
Handler: _SglangScheduler_Generate_Handler,
ServerStreams: true,
},
},
Metadata: "sglang_scheduler.proto",
}

View File

@@ -0,0 +1,279 @@
//! Client SDK FFI functions
use std::ffi::{CStr, CString};
use std::os::raw::{c_char};
use std::ptr;
use std::sync::Arc;
use tokio::runtime::Runtime;
use once_cell::sync::Lazy;
use uuid::Uuid;
use smg::tokenizer::create_tokenizer_from_file;
use smg::tokenizer::traits::Tokenizer;
use smg_grpc_client::sglang_scheduler::SglangSchedulerClient;
use smg::protocols::chat::ChatCompletionRequest;
use smg::routers::grpc::utils::{process_chat_messages, generate_tool_constraints};
use super::error::{SglErrorCode, set_error_message};
use super::grpc_converter::sgl_grpc_response_converter_create;
use super::tokenizer::TokenizerHandle;
use super::stream::SglangStreamHandle;
/// Global tokio runtime for async operations
static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
Runtime::new().expect("Failed to create tokio runtime for client FFI")
});
/// Handle for complete client SDK (gRPC client + tokenizer)
/// This handle manages the connection to sglang and provides a complete SDK interface
pub struct SglangClientHandle {
pub(crate) client: Arc<SglangSchedulerClient>,
pub(crate) tokenizer: Arc<dyn Tokenizer>,
}
/// Handle for streaming request (includes prompt token count)
#[allow(dead_code)]
pub struct StreamRequestState {
pub(crate) prompt_tokens: i32, // Number of prompt tokens for this request
}
/// Create a new SGLang client handle
///
/// # Arguments
/// * `endpoint` - gRPC endpoint (e.g., "grpc://localhost:20000")
/// * `tokenizer_path` - Path to tokenizer directory
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * Pointer to SglangClientHandle on success, null on failure
#[no_mangle]
pub unsafe extern "C" fn sgl_client_create(
endpoint: *const c_char,
tokenizer_path: *const c_char,
error_out: *mut *mut c_char,
) -> *mut SglangClientHandle {
if endpoint.is_null() || tokenizer_path.is_null() {
set_error_message(error_out, "Invalid arguments: null pointer");
return ptr::null_mut();
}
let endpoint_str = match CStr::from_ptr(endpoint).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in endpoint");
return ptr::null_mut();
}
};
let tokenizer_path_str = match CStr::from_ptr(tokenizer_path).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in tokenizer_path");
return ptr::null_mut();
}
};
// Create tokenizer
let tokenizer = match create_tokenizer_from_file(tokenizer_path_str) {
Ok(t) => t,
Err(e) => {
set_error_message(error_out, &format!("Failed to create tokenizer: {}", e));
return ptr::null_mut();
}
};
// Create gRPC client
let client = match RUNTIME.block_on(async {
SglangSchedulerClient::connect(endpoint_str).await
}) {
Ok(c) => Arc::new(c),
Err(e) => {
set_error_message(error_out, &format!("Failed to connect to endpoint: {}", e));
return ptr::null_mut();
}
};
Box::into_raw(Box::new(SglangClientHandle {
client,
tokenizer,
}))
}
/// Free a client handle
#[no_mangle]
pub unsafe extern "C" fn sgl_client_free(handle: *mut SglangClientHandle) {
if !handle.is_null() {
let _ = Box::from_raw(handle);
}
}
/// Send a chat completion request and start streaming
///
/// # Arguments
/// * `client_handle` - Client handle
/// * `request_json` - OpenAI ChatCompletionRequest as JSON string
/// * `stream_handle_out` - Pointer to receive stream handle
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * SglErrorCode::Success on success, error code on failure
#[no_mangle]
pub unsafe extern "C" fn sgl_client_chat_completion_stream(
client_handle: *mut SglangClientHandle,
request_json: *const c_char,
stream_handle_out: *mut *mut SglangStreamHandle,
error_out: *mut *mut c_char,
) -> SglErrorCode {
if client_handle.is_null() || request_json.is_null() || stream_handle_out.is_null() {
set_error_message(error_out, "Invalid arguments: null pointer");
return SglErrorCode::InvalidArgument;
}
let request_str = match CStr::from_ptr(request_json).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in request_json");
return SglErrorCode::InvalidArgument;
}
};
let client_ref = &*client_handle;
let client = Arc::clone(&client_ref.client);
let tokenizer = Arc::clone(&client_ref.tokenizer);
// Parse OpenAI ChatCompletionRequest
let chat_request: ChatCompletionRequest = match serde_json::from_str(request_str) {
Ok(req) => req,
Err(e) => {
set_error_message(error_out, &format!("Failed to parse request JSON: {}", e));
return SglErrorCode::ParsingError;
}
};
// Process messages and apply chat template
let processed_messages = match process_chat_messages(&chat_request, tokenizer.as_ref()) {
Ok(msgs) => msgs,
Err(e) => {
set_error_message(error_out, &format!("Failed to process messages: {}", e));
return SglErrorCode::TokenizationError;
}
};
// Tokenize
let token_ids = match tokenizer.encode(&processed_messages.text, false) {
Ok(encoding) => encoding.token_ids().to_vec(),
Err(e) => {
set_error_message(error_out, &format!("Failed to tokenize: {}", e));
return SglErrorCode::TokenizationError;
}
};
let prompt_tokens = token_ids.len() as i32; // Save prompt token count
// Generate tool constraints if needed
let tool_constraint = if let Some(tools) = chat_request.tools.as_ref() {
match generate_tool_constraints(tools, &chat_request.tool_choice, &chat_request.model) {
Ok(Some((constraint_type, constraint_value))) => Some((constraint_type, constraint_value)),
Ok(None) => None,
Err(e) => {
set_error_message(error_out, &format!("Failed to generate tool constraints: {}", e));
return SglErrorCode::ParsingError;
}
}
} else {
None
};
// Build GenerateRequest
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
let proto_request = match client.build_generate_request_from_chat(
request_id.clone(),
&chat_request,
processed_messages.text,
token_ids,
processed_messages.multimodal_inputs,
tool_constraint,
) {
Ok(req) => req,
Err(e) => {
set_error_message(error_out, &format!("Failed to build generate request: {}", e));
return SglErrorCode::ParsingError;
}
};
// Send request and get stream
let stream = match RUNTIME.block_on(async {
client.generate(proto_request).await
}) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to send request: {}", e));
return SglErrorCode::UnknownError;
}
};
// Create response converter
let tools_json = chat_request.tools.as_ref()
.and_then(|t| serde_json::to_string(t).ok())
.map(|s| CString::new(s).unwrap().into_raw());
let tool_choice_json = chat_request.tool_choice.as_ref()
.and_then(|tc| serde_json::to_string(tc).ok())
.map(|s| CString::new(s).unwrap().into_raw());
let stop_json = chat_request.stop.as_ref()
.and_then(|s| serde_json::to_string(s).ok())
.map(|s| CString::new(s).unwrap().into_raw());
let stop_token_ids_json = chat_request.stop_token_ids.as_ref()
.and_then(|ids| serde_json::to_string(ids).ok())
.map(|s| CString::new(s).unwrap().into_raw());
// Create tokenizer handle for converter (we'll create a temporary one)
let tokenizer_handle = Box::into_raw(Box::new(TokenizerHandle {
tokenizer: Arc::clone(&tokenizer),
}));
let converter = sgl_grpc_response_converter_create(
tokenizer_handle,
CString::new(chat_request.model.clone()).unwrap().as_ptr(),
CString::new(request_id.clone()).unwrap().as_ptr(),
tools_json.unwrap_or(ptr::null_mut()),
tool_choice_json.unwrap_or(ptr::null_mut()),
stop_json.unwrap_or(ptr::null_mut()),
stop_token_ids_json.unwrap_or(ptr::null_mut()),
if chat_request.skip_special_tokens { 1 } else { 0 },
error_out,
);
// Free temporary tokenizer handle (converter now owns the tokenizer)
let _ = Box::from_raw(tokenizer_handle);
if converter.is_null() {
return SglErrorCode::MemoryError;
}
// Clean up temporary CStrings
if let Some(ptr) = tools_json {
let _ = CString::from_raw(ptr);
}
if let Some(ptr) = tool_choice_json {
let _ = CString::from_raw(ptr);
}
if let Some(ptr) = stop_json {
let _ = CString::from_raw(ptr);
}
if let Some(ptr) = stop_token_ids_json {
let _ = CString::from_raw(ptr);
}
// Create converter handle and set initial_prompt_tokens immediately
let mut converter_handle = *Box::from_raw(converter);
converter_handle.initial_prompt_tokens = Some(prompt_tokens);
// Create stream handle with prompt_tokens
*stream_handle_out = Box::into_raw(Box::new(SglangStreamHandle {
stream: Arc::new(tokio::sync::Mutex::new(stream)),
converter: Arc::new(tokio::sync::Mutex::new(converter_handle)),
client: Arc::clone(&client),
prompt_tokens,
}));
SglErrorCode::Success
}

View File

@@ -0,0 +1,50 @@
//! Error handling for FFI functions
use std::ffi::CString;
use std::os::raw::c_char;
use std::ptr;
/// Error codes returned by FFI functions
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SglErrorCode {
Success = 0,
InvalidArgument = 1,
TokenizationError = 2,
ParsingError = 3,
MemoryError = 4,
UnknownError = 99,
}
/// Helper to set error message in FFI output parameter
pub fn set_error_message(error_out: *mut *mut c_char, message: &str) {
unsafe {
if !error_out.is_null() {
if let Ok(cstr) = CString::new(message) {
*error_out = cstr.into_raw();
} else {
*error_out = ptr::null_mut();
}
}
}
}
/// Helper to set error message from format string
pub fn set_error_message_fmt(error_out: *mut *mut c_char, fmt: std::fmt::Arguments) {
if !error_out.is_null() {
let msg = format!("{}", fmt);
set_error_message(error_out, &msg);
}
}
/// Helper to clear error message
pub fn clear_error_message(error_out: *mut *mut c_char) {
unsafe {
if !error_out.is_null() {
*error_out = ptr::null_mut();
}
}
}
// Helper functions for error handling
// Note: Some helper functions are kept for potential future use

View File

@@ -0,0 +1,758 @@
//! gRPC response converter FFI functions
use std::ffi::{CStr, CString};
use std::os::raw::{c_char, c_int};
use std::ptr;
use std::sync::Arc;
use std::collections::HashMap;
use serde_json::Value;
use tokio::runtime::Runtime;
use once_cell::sync::Lazy;
use smg::tokenizer::traits::Tokenizer;
use smg::tokenizer::stream::DecodeStream;
use smg::tool_parser::ToolParser;
use smg::protocols::common::{Tool, ToolChoice, ToolChoiceValue, ToolCallDelta, FunctionCallDelta, Usage, StringOrArray};
use smg::tokenizer::stop::StopSequenceDecoder;
use smg_grpc_client::sglang_proto as proto;
use super::error::{SglErrorCode, set_error_message, clear_error_message};
use super::tokenizer::TokenizerHandle;
use super::utils::generate_tool_call_id;
/// Global parser factory (initialized once)
// Use the re-exported ParserFactory from tool_parser module
static PARSER_FACTORY: Lazy<smg::tool_parser::ParserFactory> = Lazy::new(|| {
// ParserFactory is re-exported from tool_parser::factory, so we can use it directly
smg::tool_parser::ParserFactory::default()
});
/// Global tokio runtime for async operations
static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
Runtime::new().expect("Failed to create tokio runtime for gRPC converter FFI")
});
/// Handle for gRPC response converter (maintains state for streaming)
#[repr(C)]
pub struct GrpcResponseConverterHandle {
pub(crate) tokenizer: Arc<dyn Tokenizer>,
pub(crate) tool_parser: Option<Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>>,
pub(crate) stop_decoder: Option<Arc<tokio::sync::Mutex<StopSequenceDecoder>>>,
pub(crate) model: String,
pub(crate) request_id: String,
pub(crate) created: u64,
pub(crate) system_fingerprint: Option<String>,
pub(crate) tools: Option<Vec<Tool>>,
pub(crate) tool_choice: Option<ToolChoice>,
pub(crate) history_tool_calls_count: usize,
pub(crate) stream_buffers: HashMap<u32, String>, // Per-index text buffers
pub(crate) decode_streams: HashMap<u32, DecodeStream>, // Per-index incremental decoders
pub(crate) has_tool_calls: HashMap<u32, bool>, // Track if tool calls were emitted
pub(crate) is_first_chunk: HashMap<u32, bool>, // Track first chunk per index
pub(crate) prompt_tokens: HashMap<u32, i32>, // Track prompt tokens per index (from chunks)
pub(crate) completion_tokens: HashMap<u32, i32>, // Track completion tokens per index (cumulative)
pub(crate) initial_prompt_tokens: Option<i32>, // Initial prompt tokens from request (if available)
pub(crate) skip_special_tokens: bool, // Whether to skip special tokens when decoding
}
/// Create a gRPC response converter handle
///
/// # Arguments
/// * `tokenizer_handle` - Tokenizer handle (must be valid)
/// * `model` - Model name
/// * `request_id` - Request ID
/// * `tools_json` - Optional JSON array of tools
/// * `tool_choice_json` - Optional JSON object for tool_choice
/// * `stop` - Optional stop sequences (JSON array)
/// * `stop_token_ids` - Optional stop token IDs (JSON array)
/// * `skip_special_tokens` - Whether to skip special tokens
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * Pointer to GrpcResponseConverterHandle on success, null on failure
#[no_mangle]
pub unsafe extern "C" fn sgl_grpc_response_converter_create(
tokenizer_handle: *mut TokenizerHandle,
model: *const c_char,
request_id: *const c_char,
tools_json: *const c_char,
tool_choice_json: *const c_char,
stop: *const c_char,
stop_token_ids: *const c_char,
skip_special_tokens: c_int,
error_out: *mut *mut c_char,
) -> *mut GrpcResponseConverterHandle {
if tokenizer_handle.is_null() || model.is_null() || request_id.is_null() {
set_error_message(error_out, "Invalid arguments: null pointer");
return ptr::null_mut();
}
let model_str = match CStr::from_ptr(model).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in model");
return ptr::null_mut();
}
};
let request_id_str = match CStr::from_ptr(request_id).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in request_id");
return ptr::null_mut();
}
};
let handle_ref = &*tokenizer_handle;
let tokenizer = Arc::clone(&handle_ref.tokenizer);
// Parse tools if provided
let tools: Option<Vec<Tool>> = if !tools_json.is_null() {
match CStr::from_ptr(tools_json).to_str() {
Ok(s) => serde_json::from_str::<Vec<Tool>>(s).ok(),
Err(_) => None,
}
} else {
None
};
// Parse tool_choice if provided
let tool_choice: Option<ToolChoice> = if !tool_choice_json.is_null() {
match CStr::from_ptr(tool_choice_json).to_str() {
Ok(s) => serde_json::from_str::<ToolChoice>(s).ok(),
Err(_) => None,
}
} else {
None
};
// Parse stop sequences
let stop: Option<StringOrArray> = if !stop.is_null() {
let stop_str = match CStr::from_ptr(stop).to_str() {
Ok(s) => s,
Err(_) => return ptr::null_mut(),
};
serde_json::from_str::<StringOrArray>(stop_str).ok()
} else {
None
};
// Parse stop token IDs
let stop_token_ids: Option<Vec<u32>> = if !stop_token_ids.is_null() {
let ids_str = match CStr::from_ptr(stop_token_ids).to_str() {
Ok(s) => s,
Err(_) => return ptr::null_mut(),
};
serde_json::from_str::<Vec<u32>>(ids_str).ok()
} else {
None
};
// Create stop decoder if needed
let stop_decoder = if stop.is_some() || stop_token_ids.is_some() {
Some(Arc::new(tokio::sync::Mutex::new(
smg::routers::grpc::utils::create_stop_decoder(
&tokenizer,
stop.as_ref(),
stop_token_ids.as_ref(),
skip_special_tokens != 0,
false, // no_stop_trim
),
)))
} else {
None
};
// Create tool parser if tools are provided
let tool_parser = if tools.is_some() {
PARSER_FACTORY.registry().create_for_model(model_str)
.map(|p| Arc::new(tokio::sync::Mutex::new(p)))
} else {
None
};
// Get system fingerprint from model (simplified)
let system_fingerprint = Some("fp_placeholder".to_string()); // TODO: Get actual fingerprint
Box::into_raw(Box::new(GrpcResponseConverterHandle {
tokenizer,
tool_parser,
stop_decoder,
model: model_str.to_string(),
request_id: request_id_str.to_string(),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
system_fingerprint,
tools,
tool_choice,
history_tool_calls_count: 0,
stream_buffers: HashMap::new(),
decode_streams: HashMap::new(),
has_tool_calls: HashMap::new(),
is_first_chunk: HashMap::new(),
prompt_tokens: HashMap::new(),
completion_tokens: HashMap::new(),
initial_prompt_tokens: None, // Will be set from stream handle
skip_special_tokens: skip_special_tokens != 0,
}))
}
/// Convert a gRPC GenerateResponse chunk to OpenAI format
///
/// # Arguments
/// * `handle` - Converter handle
/// * `response_json` - JSON string of proto.GenerateResponse
/// * `result_json_out` - Pointer to receive OpenAI format JSON (must be freed with sgl_free_string)
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * SglErrorCode::Success on success, error code on failure
#[no_mangle]
pub unsafe extern "C" fn sgl_grpc_response_converter_convert_chunk(
handle: *mut GrpcResponseConverterHandle,
response_json: *const c_char,
result_json_out: *mut *mut c_char,
error_out: *mut *mut c_char,
) -> SglErrorCode {
if handle.is_null() || response_json.is_null() || result_json_out.is_null() {
set_error_message(error_out, "Invalid arguments: null pointer");
return SglErrorCode::InvalidArgument;
}
let response_str = match CStr::from_ptr(response_json).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in response_json");
return SglErrorCode::InvalidArgument;
}
};
// Parse proto.GenerateResponse from JSON
let json_value: Value = match serde_json::from_str(response_str) {
Ok(v) => v,
Err(e) => {
set_error_message(error_out, &format!("Failed to parse response JSON: {}", e));
return SglErrorCode::ParsingError;
}
};
// Build proto::GenerateResponse from JSON value
let mut proto_response = proto::GenerateResponse {
request_id: json_value.get("request_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
response: None,
};
// Parse the response oneof field
if let Some(chunk_json) = json_value.get("chunk") {
let chunk = proto::GenerateStreamChunk {
token_ids: chunk_json.get("token_ids")
.and_then(|v| v.as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_u64().map(|n| n as u32)).collect())
.unwrap_or_default(),
prompt_tokens: chunk_json.get("prompt_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
completion_tokens: chunk_json.get("completion_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
cached_tokens: chunk_json.get("cached_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
output_logprobs: None,
hidden_states: vec![],
input_logprobs: None,
index: 0,
};
proto_response.response = Some(proto::generate_response::Response::Chunk(chunk));
} else if let Some(complete_json) = json_value.get("complete") {
let complete = proto::GenerateComplete {
output_ids: complete_json.get("output_ids")
.and_then(|v| v.as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_u64().map(|n| n as u32)).collect())
.unwrap_or_default(),
finish_reason: complete_json.get("finish_reason")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
prompt_tokens: complete_json.get("prompt_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
completion_tokens: complete_json.get("completion_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
cached_tokens: complete_json.get("cached_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
output_logprobs: None,
all_hidden_states: vec![],
input_logprobs: None,
matched_stop: None,
index: 0,
};
proto_response.response = Some(proto::generate_response::Response::Complete(complete));
} else if let Some(error_json) = json_value.get("error") {
let error = proto::GenerateError {
message: error_json.get("message")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
http_status_code: error_json.get("http_status_code")
.and_then(|v| v.as_str())
.unwrap_or("500")
.to_string(),
details: error_json.get("details")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
};
proto_response.response = Some(proto::generate_response::Response::Error(error));
} else {
set_error_message(error_out, "Response JSON must contain 'chunk', 'complete', or 'error' field");
return SglErrorCode::ParsingError;
}
let handle_ref = &mut *handle;
let tokenizer = Arc::clone(&handle_ref.tokenizer);
let model = handle_ref.model.clone();
let request_id = handle_ref.request_id.clone();
let created = handle_ref.created;
let system_fingerprint = handle_ref.system_fingerprint.clone();
// Use tokio runtime to run async code
let result = RUNTIME.block_on(async {
convert_proto_chunk_to_openai(
proto_response,
handle_ref,
&tokenizer,
&model,
&request_id,
created,
system_fingerprint.as_deref(),
)
.await
});
match result {
Ok(Some(openai_response)) => {
// Serialize to JSON
let result_str = match serde_json::to_string(&openai_response) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to serialize response: {}", e));
return SglErrorCode::ParsingError;
}
};
let result_cstr = match CString::new(result_str) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to create result string: {}", e));
return SglErrorCode::MemoryError;
}
};
*result_json_out = result_cstr.into_raw();
clear_error_message(error_out);
SglErrorCode::Success
}
Ok(None) => {
// No response to send (e.g., empty chunk)
let empty = CString::new("").unwrap();
*result_json_out = empty.into_raw();
clear_error_message(error_out);
SglErrorCode::Success
}
Err(e) => {
set_error_message(error_out, &format!("Conversion error: {}", e));
SglErrorCode::ParsingError
}
}
}
/// Helper function to convert proto chunk to OpenAI format
pub(crate) async fn convert_proto_chunk_to_openai(
proto_response: proto::GenerateResponse,
handle: &mut GrpcResponseConverterHandle,
tokenizer: &Arc<dyn Tokenizer>,
model: &str,
request_id: &str,
created: u64,
system_fingerprint: Option<&str>,
) -> Result<Option<smg::protocols::chat::ChatCompletionStreamResponse>, String> {
use smg_grpc_client::sglang_proto::generate_response::Response::*;
use smg::protocols::chat::{ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice};
match proto_response.response {
Some(Chunk(chunk)) => {
let index = chunk.index;
// Mark as not first chunk if we've seen this index before
let is_first = handle.is_first_chunk.entry(index).or_insert(true);
let first_chunk = *is_first;
*is_first = false;
// Track token counts from chunks (cumulative values from proto)
// These are cumulative values, so we always use the latest value
// For prompt_tokens, if chunk value is 0, preserve existing value or use initial_prompt_tokens
// This prevents overwriting valid prompt_tokens with 0
if chunk.prompt_tokens > 0 {
handle.prompt_tokens.insert(index, chunk.prompt_tokens);
} else {
// If chunk.prompt_tokens is 0, try to preserve existing value or use initial_prompt_tokens
if !handle.prompt_tokens.contains_key(&index) {
// No existing value, try to use initial_prompt_tokens
if let Some(initial_prompt) = handle.initial_prompt_tokens {
handle.prompt_tokens.insert(index, initial_prompt);
}
}
// If existing value exists, keep it (don't overwrite with 0)
}
// For completion_tokens, always update (even if 0) as it's cumulative
handle.completion_tokens.insert(index, chunk.completion_tokens);
// Process tokens through stop decoder if available, otherwise use incremental decoder
let chunk_text = if let Some(ref stop_decoder) = handle.stop_decoder {
let mut decoder_guard = stop_decoder.lock().await;
let mut text = String::new();
for &token_id in &chunk.token_ids {
match decoder_guard.process_token(token_id).unwrap_or_else(|_| {
smg::tokenizer::stop::SequenceDecoderOutput::Held
}) {
smg::tokenizer::stop::SequenceDecoderOutput::Text(t) => {
text.push_str(&t);
}
smg::tokenizer::stop::SequenceDecoderOutput::StoppedWithText(t) => {
text.push_str(&t);
break;
}
smg::tokenizer::stop::SequenceDecoderOutput::Stopped => {
break;
}
smg::tokenizer::stop::SequenceDecoderOutput::Held => {}
}
}
text
} else {
// Use incremental decoder to handle multi-byte character boundaries
let decode_stream = handle.decode_streams.entry(index).or_insert_with(|| {
DecodeStream::new(
Arc::clone(&tokenizer),
&[], // No prompt tokens for completion
handle.skip_special_tokens,
)
});
// Process tokens incrementally
let mut text_parts = Vec::new();
for &token_id in &chunk.token_ids {
if let Ok(Some(text)) = decode_stream.step(token_id) {
text_parts.push(text);
}
}
text_parts.join("")
};
if chunk_text.is_empty() {
return Ok(None);
}
// Send first chunk with role
if first_chunk {
let first_response = ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
return Ok(Some(first_response));
}
// Update stream buffer
let stream_buffer = handle.stream_buffers.entry(index).or_default();
stream_buffer.push_str(&chunk_text);
// Handle tool calls if tools are provided
if let (Some(ref tools), Some(ref tool_parser)) = (handle.tools.as_ref(), handle.tool_parser.as_ref()) {
let tool_choice_enabled = !matches!(
handle.tool_choice,
Some(ToolChoice::Value(ToolChoiceValue::None))
);
if tool_choice_enabled {
let mut parser_guard = tool_parser.lock().await;
match parser_guard.parse_incremental(&chunk_text, tools).await {
Ok(streaming_result) => {
if !streaming_result.calls.is_empty() {
handle.has_tool_calls.insert(index, true);
// Convert tool call items to OpenAI format
let tool_call_deltas: Vec<_> = streaming_result
.calls
.into_iter()
.map(|item| {
let id = if let Some(ref name) = item.name {
generate_tool_call_id(
model,
name,
item.tool_index,
handle.history_tool_calls_count,
)
} else {
format!("call_{}", item.tool_index)
};
ToolCallDelta {
index: item.tool_index as u32,
id: Some(id),
tool_type: if item.name.is_some() {
Some("function".to_string())
} else {
None
},
function: Some(FunctionCallDelta {
name: item.name,
arguments: if !item.parameters.is_empty() {
Some(item.parameters)
} else {
None
},
}),
}
})
.collect();
let tool_response = ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(tool_call_deltas),
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
return Ok(Some(tool_response));
}
}
Err(e) => {
// Log error but continue with regular content
tracing::warn!("Tool parser error: {}", e);
}
}
}
}
// Regular content emission
let content_response = ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(chunk_text),
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
Ok(Some(content_response))
}
Some(Complete(complete)) => {
let index = complete.index;
// Flush any remaining text
// Flush any remaining text from decode stream
let mut final_text = handle.stream_buffers.remove(&index).unwrap_or_default();
if let Some(ref mut decode_stream) = handle.decode_streams.get_mut(&index) {
if let Ok(Some(remaining)) = decode_stream.flush() {
final_text.push_str(&remaining);
}
}
handle.decode_streams.remove(&index);
// Determine finish reason - ensure it's never empty
// If finish_reason is empty, try to infer from other fields or use default
let finish_reason = if handle.has_tool_calls.get(&index).copied().unwrap_or(false)
&& (complete.finish_reason == "stop" || complete.finish_reason.is_empty())
{
"tool_calls".to_string()
} else if complete.finish_reason.is_empty() || complete.finish_reason.trim().is_empty() {
// If finish_reason is empty, try to infer from completion_tokens or use default
if complete.completion_tokens > 0 {
// If we have completion tokens, likely stopped normally
"stop".to_string()
} else if !complete.output_ids.is_empty() {
// If we have output_ids, likely stopped normally
"stop".to_string()
} else {
// Default fallback - always ensure we have a value
"stop".to_string()
}
} else {
complete.finish_reason.clone()
};
// Ensure finish_reason is never empty (defensive check)
let finish_reason = if finish_reason.is_empty() || finish_reason.trim().is_empty() {
"stop".to_string()
} else {
finish_reason
};
// Extract matched_stop
let matched_stop = match &complete.matched_stop {
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
Some(Value::Number(serde_json::Number::from(*token_id)))
}
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
Some(Value::String(stop_str.clone()))
}
None => None,
};
// Build usage - prefer values from complete message, but fallback to accumulated values from chunks
// Complete message should have the final values, but sometimes they might be 0 or missing
// Always use the latest cumulative value from chunks if available, otherwise use complete message value
let mut prompt_tokens = handle.prompt_tokens.get(&index)
.copied()
.filter(|&v| v > 0)
.unwrap_or(complete.prompt_tokens);
let mut completion_tokens = handle.completion_tokens.get(&index)
.copied()
.filter(|&v| v > 0)
.unwrap_or(complete.completion_tokens);
// Always try to use initial_prompt_tokens if prompt_tokens is 0 or missing
// This is the most reliable source for prompt tokens since we calculate it from the request
if prompt_tokens == 0 {
if let Some(initial_prompt) = handle.initial_prompt_tokens {
prompt_tokens = initial_prompt;
}
}
// If completion_tokens is 0, try to infer from output_ids or accumulated chunks
if completion_tokens == 0 {
// Try to use completion_tokens from complete message even if 0
// Or calculate from output_ids
if complete.completion_tokens > 0 {
completion_tokens = complete.completion_tokens;
} else if !complete.output_ids.is_empty() {
completion_tokens = complete.output_ids.len() as i32;
} else if let Some(&last_completion) = handle.completion_tokens.get(&index) {
completion_tokens = last_completion;
}
}
// Final fallback: if both are still 0, try to use initial_prompt_tokens for prompt
// and calculate completion from output_ids
if prompt_tokens == 0 && completion_tokens == 0 {
// Try to infer from output_ids if available
let output_ids_len = complete.output_ids.len() as i32;
if output_ids_len > 0 {
completion_tokens = output_ids_len;
// Always try to use initial_prompt_tokens for prompt
if let Some(initial_prompt) = handle.initial_prompt_tokens {
prompt_tokens = initial_prompt;
}
}
}
// Final defensive check: ensure prompt_tokens is set if we have initial_prompt_tokens
if prompt_tokens == 0 {
if let Some(initial_prompt) = handle.initial_prompt_tokens {
prompt_tokens = initial_prompt;
}
}
// Always create usage, even if values are 0 (defensive)
let usage = Some(Usage {
prompt_tokens: prompt_tokens.max(0) as u32,
completion_tokens: completion_tokens.max(0) as u32,
total_tokens: (prompt_tokens.max(0) + completion_tokens.max(0)) as u32,
completion_tokens_details: None,
});
let finish_response = ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: system_fingerprint.map(|s| s.to_string()),
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: if !final_text.is_empty() {
Some(final_text)
} else {
None
},
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: Some(finish_reason),
matched_stop,
}],
usage,
};
Ok(Some(finish_response))
}
Some(Error(error)) => {
Err(format!("Server error: {} (status: {})", error.message, error.http_status_code))
}
None => Ok(None),
}
}
/// Free a gRPC response converter handle
#[no_mangle]
pub unsafe extern "C" fn sgl_grpc_response_converter_free(handle: *mut GrpcResponseConverterHandle) {
if !handle.is_null() {
let _ = Box::from_raw(handle);
}
}

View File

@@ -0,0 +1,103 @@
//! FFI module for exposing sgl-model-gateway preprocessing and postprocessing functions
//! to C-compatible languages (e.g., Golang via cgo)
//!
//! This module provides C-compatible function signatures for:
//! - Tokenizer operations (encode, decode, chat template)
//! - Tool parser operations (parse tool calls)
//! - Tool constraint generation
//! - gRPC client SDK (complete request-response flow)
//!
//! # Safety
//! All functions marked with `#[no_mangle]` and `extern "C"` must be called
//! with valid pointers and follow the documented memory management rules.
// Re-export error types
pub use error::{SglErrorCode, set_error_message, set_error_message_fmt, clear_error_message};
// Re-export memory management functions
pub use memory::{sgl_free_string, sgl_free_token_ids};
// Re-export tokenizer functions
pub use tokenizer::{
TokenizerHandle,
sgl_tokenizer_create_from_file,
sgl_tokenizer_encode,
sgl_tokenizer_apply_chat_template,
sgl_tokenizer_apply_chat_template_with_tools,
sgl_tokenizer_decode,
sgl_tokenizer_free,
};
// Re-export tool parser functions
pub use tool_parser::{
ToolParserHandle,
sgl_tool_parser_create,
sgl_tool_parser_parse_complete,
sgl_tool_parser_parse_incremental,
sgl_tool_parser_reset,
sgl_tool_parser_free,
};
// Re-export gRPC converter functions
pub use grpc_converter::{
GrpcResponseConverterHandle,
sgl_grpc_response_converter_create,
sgl_grpc_response_converter_convert_chunk,
sgl_grpc_response_converter_free,
};
// Re-export client SDK functions
pub use client::{
SglangClientHandle,
sgl_client_create,
sgl_client_free,
};
// Re-export stream functions
pub use stream::{
SglangStreamHandle,
sgl_stream_read_next,
sgl_stream_free,
};
// Re-export client stream function (defined in client.rs but used by stream)
pub use client::sgl_client_chat_completion_stream;
// Re-export preprocessor functions
pub use preprocessor::{
sgl_preprocess_chat_request,
sgl_preprocess_chat_request_with_tokenizer,
sgl_preprocessed_request_free,
};
// Re-export postprocessor functions
pub use postprocessor::{
sgl_postprocess_stream_chunk,
sgl_postprocess_stream_chunks_batch,
};
// Re-export utility functions
pub use utils::sgl_generate_tool_constraints;
// Sub-modules
mod error;
mod memory;
mod tokenizer;
mod tool_parser;
mod grpc_converter;
mod client;
mod stream;
mod utils;
mod preprocessor;
mod postprocessor;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_error_codes() {
assert_eq!(SglErrorCode::Success as i32, 0);
assert_eq!(SglErrorCode::InvalidArgument as i32, 1);
}
}

View File

@@ -0,0 +1,28 @@
//! Memory management for FFI functions
use std::ffi::CString;
use std::os::raw::c_char;
/// Free a C string allocated by Rust
///
/// # Safety
/// This function must only be called with pointers returned by other FFI functions.
/// Calling with arbitrary pointers or multiple times on the same pointer is undefined behavior.
#[no_mangle]
pub unsafe extern "C" fn sgl_free_string(s: *mut c_char) {
if !s.is_null() {
let _ = CString::from_raw(s);
}
}
/// Free token IDs array allocated by Rust
///
/// # Safety
/// This function must only be called with pointers returned by `sgl_tokenizer_encode`.
/// The `count` parameter must match the length of the array.
#[no_mangle]
pub unsafe extern "C" fn sgl_free_token_ids(ptr: *mut u32, count: usize) {
if !ptr.is_null() && count > 0 {
let _ = Vec::from_raw_parts(ptr, count, count);
}
}

View File

@@ -0,0 +1,465 @@
//! Postprocessing FFI functions for gRPC stream chunks
//!
//! This module provides C-compatible functions for postprocessing gRPC stream chunks:
//! - Parse tool calls from model output
//! - Convert proto format to OpenAI format
//! - Handle reasoning content parsing
//!
//! These functions are designed to be called for each stream chunk, but can be optimized
//! with batching in the future.
use std::ffi::{CStr, CString};
use std::os::raw::{c_char, c_int};
use std::ptr;
use std::sync::Arc;
use serde_json::Value;
use smg_grpc_client::sglang_proto as proto;
use super::error::{SglErrorCode, set_error_message};
use super::grpc_converter::GrpcResponseConverterHandle;
use tokio::runtime::Runtime;
use once_cell::sync::Lazy;
/// Global tokio runtime for async operations
static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
Runtime::new().expect("Failed to create tokio runtime for postprocessor FFI")
});
/// Postprocess a gRPC stream chunk to OpenAI format
///
/// This function:
/// 1. Parses the proto chunk from JSON
/// 2. Converts it to OpenAI format using the converter handle
/// 3. Returns the OpenAI format JSON
///
/// # Arguments
/// * `converter_handle` - Converter handle (created with sgl_grpc_response_converter_create)
/// * `proto_chunk_json` - JSON string of proto.GenerateResponse
/// * `openai_json_out` - Pointer to receive OpenAI format JSON (must be freed with sgl_free_string)
/// * `is_done_out` - Pointer to receive is_done flag (1 if stream is complete, 0 otherwise)
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * SglErrorCode::Success on success, error code on failure
#[no_mangle]
pub unsafe extern "C" fn sgl_postprocess_stream_chunk(
converter_handle: *mut GrpcResponseConverterHandle,
proto_chunk_json: *const c_char,
openai_json_out: *mut *mut c_char,
is_done_out: *mut c_int,
error_out: *mut *mut c_char,
) -> SglErrorCode {
if converter_handle.is_null()
|| proto_chunk_json.is_null()
|| openai_json_out.is_null()
|| is_done_out.is_null()
{
set_error_message(error_out, "Invalid arguments: null pointer");
return SglErrorCode::InvalidArgument;
}
let proto_chunk_str = match CStr::from_ptr(proto_chunk_json).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in proto_chunk_json");
return SglErrorCode::InvalidArgument;
}
};
// Parse proto.GenerateResponse from JSON
let json_value: Value = match serde_json::from_str(proto_chunk_str) {
Ok(v) => v,
Err(e) => {
set_error_message(error_out, &format!("Failed to parse proto chunk JSON: {}", e));
return SglErrorCode::ParsingError;
}
};
// Build proto::GenerateResponse from JSON value
let mut proto_response = proto::GenerateResponse {
request_id: json_value
.get("request_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
response: None,
};
// Parse the response oneof field
let is_done = if let Some(chunk_json) = json_value.get("chunk") {
let chunk = proto::GenerateStreamChunk {
token_ids: chunk_json
.get("token_ids")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_u64().map(|n| n as u32))
.collect()
})
.unwrap_or_default(),
prompt_tokens: chunk_json
.get("prompt_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
completion_tokens: chunk_json
.get("completion_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
cached_tokens: chunk_json
.get("cached_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
output_logprobs: None,
hidden_states: vec![],
input_logprobs: None,
index: 0,
};
proto_response.response = Some(proto::generate_response::Response::Chunk(chunk));
false
} else if let Some(complete_json) = json_value.get("complete") {
let complete = proto::GenerateComplete {
output_ids: complete_json
.get("output_ids")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_u64().map(|n| n as u32))
.collect()
})
.unwrap_or_default(),
finish_reason: complete_json
.get("finish_reason")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
prompt_tokens: complete_json
.get("prompt_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
completion_tokens: complete_json
.get("completion_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
cached_tokens: complete_json
.get("cached_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
output_logprobs: None,
all_hidden_states: vec![],
input_logprobs: None,
matched_stop: None,
index: 0,
};
proto_response.response = Some(proto::generate_response::Response::Complete(complete));
true
} else if let Some(error_json) = json_value.get("error") {
let error = proto::GenerateError {
message: error_json
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
http_status_code: error_json
.get("http_status_code")
.and_then(|v| v.as_str())
.unwrap_or("500")
.to_string(),
details: error_json
.get("details")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
};
proto_response.response = Some(proto::generate_response::Response::Error(error));
true
} else {
set_error_message(
error_out,
"Proto chunk JSON must contain 'chunk', 'complete', or 'error' field",
);
return SglErrorCode::ParsingError;
};
// Convert proto chunk to OpenAI format using the converter's convert_chunk function
// We'll use the existing converter API instead of calling the internal function directly
let proto_chunk_json_cstr = match CString::new(proto_chunk_str) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to create C string: {}", e));
return SglErrorCode::MemoryError;
}
};
// Use the existing converter API
let mut openai_json_ptr: *mut c_char = ptr::null_mut();
let result = super::grpc_converter::sgl_grpc_response_converter_convert_chunk(
converter_handle,
proto_chunk_json_cstr.as_ptr(),
&mut openai_json_ptr,
error_out,
);
if result == SglErrorCode::Success {
*openai_json_out = openai_json_ptr;
*is_done_out = if is_done { 1 } else { 0 };
SglErrorCode::Success
} else {
*openai_json_out = ptr::null_mut();
*is_done_out = if is_done { 1 } else { 0 };
result
}
}
/// Postprocess multiple gRPC stream chunks in batch (reduces FFI overhead)
///
/// This function processes multiple chunks in a single FFI call, significantly reducing
/// FFI overhead in streaming scenarios.
///
/// # Arguments
/// * `converter_handle` - Converter handle (created with sgl_grpc_response_converter_create)
/// * `proto_chunks_json_array` - JSON array string of proto.GenerateResponse chunks
/// * `max_chunks` - Maximum number of chunks to process (for safety)
/// * `openai_chunks_json_array_out` - Pointer to receive JSON array of OpenAI format chunks (must be freed with sgl_free_string)
/// * `chunks_count_out` - Pointer to receive number of processed chunks
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * SglErrorCode::Success on success, error code on failure
#[no_mangle]
pub unsafe extern "C" fn sgl_postprocess_stream_chunks_batch(
converter_handle: *mut GrpcResponseConverterHandle,
proto_chunks_json_array: *const c_char,
max_chunks: c_int,
openai_chunks_json_array_out: *mut *mut c_char,
chunks_count_out: *mut c_int,
error_out: *mut *mut c_char,
) -> SglErrorCode {
if converter_handle.is_null()
|| proto_chunks_json_array.is_null()
|| openai_chunks_json_array_out.is_null()
|| chunks_count_out.is_null()
{
set_error_message(error_out, "Invalid arguments: null pointer");
return SglErrorCode::InvalidArgument;
}
let chunks_array_str = match CStr::from_ptr(proto_chunks_json_array).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in proto_chunks_json_array");
return SglErrorCode::InvalidArgument;
}
};
// Parse JSON array of chunks
let chunks_array: Vec<Value> = match serde_json::from_str(chunks_array_str) {
Ok(arr) => arr,
Err(e) => {
set_error_message(
error_out,
&format!("Failed to parse chunks JSON array: {}", e),
);
return SglErrorCode::ParsingError;
}
};
// Limit batch size for safety
let max_chunks_usize = max_chunks as usize;
let chunks_to_process = if chunks_array.len() > max_chunks_usize {
&chunks_array[..max_chunks_usize]
} else {
&chunks_array
};
let handle_ref = &mut *converter_handle;
let tokenizer = Arc::clone(&handle_ref.tokenizer);
let model = handle_ref.model.clone();
let request_id = handle_ref.request_id.clone();
let created = handle_ref.created;
let system_fingerprint = handle_ref.system_fingerprint.clone();
// Process chunks in batch
let mut results = Vec::new();
let mut has_error = false;
let mut error_msg = String::new();
for chunk_json in chunks_to_process {
// Parse proto.GenerateResponse from JSON
let mut proto_response = proto::GenerateResponse {
request_id: chunk_json
.get("request_id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
response: None,
};
// Parse the response oneof field (same logic as single chunk processing)
let _is_done = if let Some(chunk_json) = chunk_json.get("chunk") {
let chunk = proto::GenerateStreamChunk {
token_ids: chunk_json
.get("token_ids")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_u64().map(|n| n as u32))
.collect()
})
.unwrap_or_default(),
prompt_tokens: chunk_json
.get("prompt_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
completion_tokens: chunk_json
.get("completion_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
cached_tokens: chunk_json
.get("cached_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
output_logprobs: None,
hidden_states: vec![],
input_logprobs: None,
index: 0,
};
proto_response.response = Some(proto::generate_response::Response::Chunk(chunk));
false
} else if let Some(complete_json) = chunk_json.get("complete") {
let complete = proto::GenerateComplete {
output_ids: complete_json
.get("output_ids")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_u64().map(|n| n as u32))
.collect()
})
.unwrap_or_default(),
finish_reason: complete_json
.get("finish_reason")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
prompt_tokens: complete_json
.get("prompt_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
completion_tokens: complete_json
.get("completion_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
cached_tokens: complete_json
.get("cached_tokens")
.and_then(|v| v.as_i64())
.map(|n| n as i32)
.unwrap_or(0),
output_logprobs: None,
all_hidden_states: vec![],
input_logprobs: None,
matched_stop: None,
index: 0,
};
proto_response.response = Some(proto::generate_response::Response::Complete(complete));
true
} else if let Some(error_json) = chunk_json.get("error") {
let error = proto::GenerateError {
message: error_json
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
http_status_code: error_json
.get("http_status_code")
.and_then(|v| v.as_str())
.unwrap_or("500")
.to_string(),
details: error_json
.get("details")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
};
proto_response.response = Some(proto::generate_response::Response::Error(error));
true
} else {
error_msg = format!(
"Chunk JSON must contain 'chunk', 'complete', or 'error' field: {}",
chunk_json
);
has_error = true;
break;
};
// Convert proto chunk to OpenAI format
let result = RUNTIME.block_on(async {
super::grpc_converter::convert_proto_chunk_to_openai(
proto_response,
handle_ref,
&tokenizer,
&model,
&request_id,
created,
system_fingerprint.as_deref(),
)
.await
});
match result {
Ok(Some(openai_response)) => {
results.push(openai_response);
}
Ok(None) => {
// Empty response, skip
}
Err(e) => {
error_msg = format!("Postprocessing failed for chunk: {}", e);
has_error = true;
break;
}
}
}
if has_error {
set_error_message(error_out, &error_msg);
return SglErrorCode::ParsingError;
}
// Serialize results to JSON array
let results_json = match serde_json::to_string(&results) {
Ok(s) => s,
Err(e) => {
set_error_message(
error_out,
&format!("Failed to serialize results JSON array: {}", e),
);
return SglErrorCode::ParsingError;
}
};
let results_cstr = match CString::new(results_json) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to create C string: {}", e));
return SglErrorCode::MemoryError;
}
};
*openai_chunks_json_array_out = results_cstr.into_raw();
*chunks_count_out = results.len() as c_int;
SglErrorCode::Success
}

View File

@@ -0,0 +1,372 @@
//! Preprocessing FFI functions for chat requests
//!
//! This module provides C-compatible functions for preprocessing chat completion requests:
//! - Apply chat_template to messages
//! - Tokenize the processed text
//! - Generate tool constraints
//!
//! These functions are designed to be called once per request, reducing FFI overhead.
use std::ffi::{CStr, CString};
use std::os::raw::{c_char, c_int};
use std::ptr;
use std::os::raw::c_uint;
use smg::tokenizer::create_tokenizer_from_file;
use smg::protocols::chat::ChatCompletionRequest;
use smg::routers::grpc::utils::{process_chat_messages, generate_tool_constraints};
use super::error::{SglErrorCode, set_error_message};
use super::memory::{sgl_free_string, sgl_free_token_ids};
use super::tokenizer::TokenizerHandle;
/// Handle for preprocessed request
#[repr(C)]
pub struct PreprocessedRequestHandle {
pub(crate) prompt_text: CString,
pub(crate) token_ids: Vec<i32>,
pub(crate) tool_constraints_json: Option<CString>,
pub(crate) prompt_tokens: i32,
}
/// Preprocess a chat completion request
///
/// This function:
/// 1. Applies chat_template to messages
/// 2. Tokenizes the processed text
/// 3. Generates tool constraints (if tools are present)
///
/// # Arguments
/// * `request_json` - OpenAI ChatCompletionRequest as JSON string
/// * `tokenizer_path` - Path to tokenizer directory
/// * `prompt_text_out` - Pointer to receive prompt text (C string, must be freed with sgl_free_string)
/// * `token_ids_out` - Pointer to receive token IDs array (must be freed with sgl_free_token_ids)
/// * `token_ids_len_out` - Pointer to receive token IDs array length
/// * `tool_constraints_json_out` - Optional pointer to receive tool constraints JSON (must be freed with sgl_free_string)
/// * `prompt_tokens_out` - Pointer to receive prompt token count
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * SglErrorCode::Success on success, error code on failure
#[no_mangle]
pub unsafe extern "C" fn sgl_preprocess_chat_request(
request_json: *const c_char,
tokenizer_path: *const c_char,
prompt_text_out: *mut *mut c_char,
token_ids_out: *mut *mut c_uint,
token_ids_len_out: *mut usize,
tool_constraints_json_out: *mut *mut c_char,
prompt_tokens_out: *mut c_int,
error_out: *mut *mut c_char,
) -> SglErrorCode {
if request_json.is_null()
|| tokenizer_path.is_null()
|| prompt_text_out.is_null()
|| token_ids_out.is_null()
|| token_ids_len_out.is_null()
|| prompt_tokens_out.is_null()
{
set_error_message(error_out, "Invalid arguments: null pointer");
return SglErrorCode::InvalidArgument;
}
// Parse input strings
let request_str = match CStr::from_ptr(request_json).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in request_json");
return SglErrorCode::InvalidArgument;
}
};
let tokenizer_path_str = match CStr::from_ptr(tokenizer_path).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in tokenizer_path");
return SglErrorCode::InvalidArgument;
}
};
// Parse ChatCompletionRequest
let chat_request: ChatCompletionRequest = match serde_json::from_str(request_str) {
Ok(req) => req,
Err(e) => {
set_error_message(error_out, &format!("Failed to parse request JSON: {}", e));
return SglErrorCode::ParsingError;
}
};
// Create tokenizer
let tokenizer = match create_tokenizer_from_file(tokenizer_path_str) {
Ok(t) => t,
Err(e) => {
set_error_message(error_out, &format!("Failed to create tokenizer: {}", e));
return SglErrorCode::TokenizationError;
}
};
// Process chat messages (apply chat_template)
let processed_messages = match process_chat_messages(&chat_request, tokenizer.as_ref()) {
Ok(msgs) => msgs,
Err(e) => {
set_error_message(error_out, &format!("Failed to process chat messages: {}", e));
return SglErrorCode::ParsingError;
}
};
// Tokenize the processed text
let encoding = match tokenizer.encode(&processed_messages.text, false) {
Ok(enc) => enc,
Err(e) => {
set_error_message(error_out, &format!("Tokenization failed: {}", e));
return SglErrorCode::TokenizationError;
}
};
let token_ids_vec: Vec<i32> = encoding
.token_ids()
.iter()
.map(|&id| id as i32)
.collect();
let prompt_tokens = token_ids_vec.len() as i32;
// Generate tool constraints if tools are present
let tool_constraints_json = if let Some(tools) = chat_request.tools.as_ref() {
match generate_tool_constraints(tools, &chat_request.tool_choice, &chat_request.model) {
Ok(Some(constraints)) => {
match serde_json::to_string(&constraints) {
Ok(json_str) => Some(CString::new(json_str).unwrap()),
Err(e) => {
set_error_message(
error_out,
&format!("Failed to serialize tool constraints: {}", e),
);
return SglErrorCode::ParsingError;
}
}
}
Ok(None) => None,
Err(e) => {
set_error_message(
error_out,
&format!("Failed to generate tool constraints: {}", e),
);
return SglErrorCode::ParsingError;
}
}
} else {
None
};
// Allocate memory for outputs
let prompt_text_cstr = match CString::new(processed_messages.text) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to create C string: {}", e));
return SglErrorCode::MemoryError;
}
};
let token_ids_len = token_ids_vec.len();
// Convert i32 to u32 for token IDs (as expected by the memory management functions)
let token_ids_u32: Vec<u32> = token_ids_vec.iter().map(|&id| id as u32).collect();
let token_ids_ptr = if token_ids_u32.is_empty() {
ptr::null_mut()
} else {
let boxed = token_ids_u32.into_boxed_slice();
Box::into_raw(boxed) as *mut c_uint
};
// Set output values
*prompt_text_out = prompt_text_cstr.into_raw();
*token_ids_out = token_ids_ptr;
*token_ids_len_out = token_ids_len;
*prompt_tokens_out = prompt_tokens;
if !tool_constraints_json_out.is_null() {
if let Some(constraints) = tool_constraints_json {
*tool_constraints_json_out = constraints.into_raw();
} else {
*tool_constraints_json_out = ptr::null_mut();
}
}
SglErrorCode::Success
}
/// Preprocess a chat completion request using an existing tokenizer handle
///
/// This function is similar to sgl_preprocess_chat_request, but accepts a TokenizerHandle
/// instead of creating a new tokenizer. This allows reusing a cached tokenizer instance,
/// significantly reducing initialization overhead in concurrent scenarios.
///
/// # Arguments
/// * `request_json` - OpenAI ChatCompletionRequest as JSON string
/// * `tokenizer_handle` - Existing tokenizer handle (must be valid)
/// * `prompt_text_out` - Pointer to receive prompt text (C string, must be freed with sgl_free_string)
/// * `token_ids_out` - Pointer to receive token IDs array (must be freed with sgl_free_token_ids)
/// * `token_ids_len_out` - Pointer to receive token IDs array length
/// * `tool_constraints_json_out` - Optional pointer to receive tool constraints JSON (must be freed with sgl_free_string)
/// * `prompt_tokens_out` - Pointer to receive prompt token count
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * SglErrorCode::Success on success, error code on failure
#[no_mangle]
pub unsafe extern "C" fn sgl_preprocess_chat_request_with_tokenizer(
request_json: *const c_char,
tokenizer_handle: *mut TokenizerHandle,
prompt_text_out: *mut *mut c_char,
token_ids_out: *mut *mut c_uint,
token_ids_len_out: *mut usize,
tool_constraints_json_out: *mut *mut c_char,
prompt_tokens_out: *mut c_int,
error_out: *mut *mut c_char,
) -> SglErrorCode {
if request_json.is_null()
|| tokenizer_handle.is_null()
|| prompt_text_out.is_null()
|| token_ids_out.is_null()
|| token_ids_len_out.is_null()
|| prompt_tokens_out.is_null()
{
set_error_message(error_out, "Invalid arguments: null pointer");
return SglErrorCode::InvalidArgument;
}
// Parse input string
let request_str = match CStr::from_ptr(request_json).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in request_json");
return SglErrorCode::InvalidArgument;
}
};
// Parse ChatCompletionRequest
let chat_request: ChatCompletionRequest = match serde_json::from_str(request_str) {
Ok(req) => req,
Err(e) => {
set_error_message(error_out, &format!("Failed to parse request JSON: {}", e));
return SglErrorCode::ParsingError;
}
};
// Use existing tokenizer from handle (no need to create new one!)
let handle_ref = &*tokenizer_handle;
let tokenizer = &handle_ref.tokenizer;
// Process chat messages (apply chat_template)
let processed_messages = match process_chat_messages(&chat_request, tokenizer.as_ref()) {
Ok(msgs) => msgs,
Err(e) => {
set_error_message(error_out, &format!("Failed to process chat messages: {}", e));
return SglErrorCode::ParsingError;
}
};
// Tokenize the processed text
let encoding = match tokenizer.encode(&processed_messages.text, false) {
Ok(enc) => enc,
Err(e) => {
set_error_message(error_out, &format!("Tokenization failed: {}", e));
return SglErrorCode::TokenizationError;
}
};
let token_ids_vec: Vec<i32> = encoding
.token_ids()
.iter()
.map(|&id| id as i32)
.collect();
let prompt_tokens = token_ids_vec.len() as i32;
// Generate tool constraints if tools are present
let tool_constraints_json = if let Some(tools) = chat_request.tools.as_ref() {
match generate_tool_constraints(tools, &chat_request.tool_choice, &chat_request.model) {
Ok(Some(constraints)) => {
match serde_json::to_string(&constraints) {
Ok(json_str) => Some(CString::new(json_str).unwrap()),
Err(e) => {
set_error_message(
error_out,
&format!("Failed to serialize tool constraints: {}", e),
);
return SglErrorCode::ParsingError;
}
}
}
Ok(None) => None,
Err(e) => {
set_error_message(
error_out,
&format!("Failed to generate tool constraints: {}", e),
);
return SglErrorCode::ParsingError;
}
}
} else {
None
};
// Allocate memory for outputs
let prompt_text_cstr = match CString::new(processed_messages.text) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to create C string: {}", e));
return SglErrorCode::MemoryError;
}
};
let token_ids_len = token_ids_vec.len();
// Convert i32 to u32 for token IDs (as expected by the memory management functions)
let token_ids_u32: Vec<u32> = token_ids_vec.iter().map(|&id| id as u32).collect();
let token_ids_ptr = if token_ids_u32.is_empty() {
ptr::null_mut()
} else {
let boxed = token_ids_u32.into_boxed_slice();
Box::into_raw(boxed) as *mut c_uint
};
// Set output values
*prompt_text_out = prompt_text_cstr.into_raw();
*token_ids_out = token_ids_ptr;
*token_ids_len_out = token_ids_len;
*prompt_tokens_out = prompt_tokens;
if !tool_constraints_json_out.is_null() {
if let Some(constraints) = tool_constraints_json {
*tool_constraints_json_out = constraints.into_raw();
} else {
*tool_constraints_json_out = ptr::null_mut();
}
}
SglErrorCode::Success
}
/// Free a preprocessed request handle (cleanup function)
///
/// This function frees the memory allocated by sgl_preprocess_chat_request.
/// It should be called after the preprocessed data is no longer needed.
#[no_mangle]
pub unsafe extern "C" fn sgl_preprocessed_request_free(
prompt_text: *mut c_char,
token_ids: *mut c_uint,
token_ids_len: usize,
tool_constraints_json: *mut c_char,
) {
if !prompt_text.is_null() {
sgl_free_string(prompt_text);
}
if !token_ids.is_null() && token_ids_len > 0 {
sgl_free_token_ids(token_ids, token_ids_len);
}
if !tool_constraints_json.is_null() {
sgl_free_string(tool_constraints_json);
}
}

View File

@@ -0,0 +1,288 @@
//! Stream handling FFI functions
//!
//! This module provides FFI (Foreign Function Interface) functions for managing
//! streaming responses from the SGLang gRPC API. It handles:
//!
//! - Creating and managing stream handles
//! - Reading chunks from streams and converting them to OpenAI format
//! - Managing automatic abort on stream drop (via AbortOnDropStream)
//! - Thread-safe access to streams and response converters
//!
//! # Safety
//!
//! All FFI functions are marked `unsafe` as per Rust FFI conventions. Callers must:
//! - Pass valid pointers
//! - Ensure proper pointer lifetime management
//! - Call corresponding free functions for cleanup
use std::ffi::CString;
use std::os::raw::{c_char, c_int};
use std::ptr;
use std::sync::Arc;
use tokio::runtime::Runtime;
use once_cell::sync::Lazy;
use futures_util::StreamExt;
use smg_grpc_client::{sglang_proto as proto, sglang_scheduler::{SglangSchedulerClient, AbortOnDropStream}};
use super::error::{SglErrorCode, set_error_message};
use super::grpc_converter::{GrpcResponseConverterHandle, convert_proto_chunk_to_openai};
/// Global tokio runtime for async operations
static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
Runtime::new().expect("Failed to create tokio runtime for stream FFI")
});
/// Handle for an active streaming request.
///
/// This struct manages the stream and response converter for a single request.
/// It is wrapped in Arc and Mutex for thread-safe concurrent access.
///
/// # Fields
///
/// * `stream` - The gRPC stream wrapped in AbortOnDropStream for automatic cleanup
/// * `converter` - Response converter that transforms proto messages to OpenAI format
/// * `client` - The underlying gRPC client connection
/// * `prompt_tokens` - Number of prompt tokens from the original request
pub struct SglangStreamHandle {
pub(crate) stream: Arc<tokio::sync::Mutex<AbortOnDropStream>>,
pub(crate) converter: Arc<tokio::sync::Mutex<GrpcResponseConverterHandle>>,
#[allow(dead_code)]
pub(crate) client: Arc<SglangSchedulerClient>,
#[allow(dead_code)]
pub(crate) prompt_tokens: i32, // Number of prompt tokens for this request
}
/// Read next chunk from stream and convert to OpenAI format.
///
/// This function reads the next chunk from the gRPC stream, converts it from the
/// internal protocol format to OpenAI-compatible JSON format, and returns it via
/// the output parameters.
///
/// # Arguments
///
/// * `stream_handle` - Mutable pointer to the stream handle
/// * `response_json_out` - Pointer to receive OpenAI format JSON string
/// - Caller must free this with `sgl_free_string`
/// - May be NULL if no data available
/// * `is_done_out` - Pointer to receive completion status
/// - 0 = stream has more data
/// - 1 = stream is complete
/// * `error_out` - Optional pointer to receive error message
/// - Only set if function returns an error code
/// - Must be freed with `sgl_free_string` if not NULL
///
/// # Returns
///
/// * `SglErrorCode::Success` - Successfully read a chunk or reached end of stream
/// * Other error codes - See `SglErrorCode` for details
///
/// # Safety
///
/// - All pointers must be valid and properly aligned
/// - `stream_handle` must point to a valid `SglangStreamHandle`
/// - Output pointers must be writable
///
/// # Notes
///
/// - Complete messages are identified by the presence of `proto::GenerateResponse::Complete`
/// - When is_done=1, this may be the last readable chunk or the stream may be ending
/// - Subsequent calls after is_done=1 will mark the stream as complete internally
#[no_mangle]
pub unsafe extern "C" fn sgl_stream_read_next(
stream_handle: *mut SglangStreamHandle,
response_json_out: *mut *mut c_char,
is_done_out: *mut c_int,
error_out: *mut *mut c_char,
) -> SglErrorCode {
if stream_handle.is_null() || response_json_out.is_null() || is_done_out.is_null() {
set_error_message(error_out, "Invalid arguments: null pointer");
return SglErrorCode::InvalidArgument;
}
let handle_ref = &*stream_handle;
let stream = Arc::clone(&handle_ref.stream);
let converter = Arc::clone(&handle_ref.converter);
// Read next chunk from stream
let chunk_result = RUNTIME.block_on(async {
let mut stream_guard = stream.lock().await;
stream_guard.next().await
});
match chunk_result {
Some(Ok(proto_response)) => {
// Convert proto response to OpenAI format
// We need to get the converter lock first
let conversion_result = RUNTIME.block_on(async {
let mut converter_guard = converter.lock().await;
// Clone necessary fields for conversion
let tokenizer = Arc::clone(&converter_guard.tokenizer);
let model = converter_guard.model.clone();
let request_id = converter_guard.request_id.clone();
let created = converter_guard.created;
let system_fingerprint = converter_guard.system_fingerprint.clone();
// Call the conversion function
convert_proto_chunk_to_openai(
proto_response.clone(),
&mut *converter_guard,
&tokenizer,
&model,
&request_id,
created,
system_fingerprint.as_deref(),
)
.await
});
match conversion_result {
Ok(Some(openai_response)) => {
// Serialize to JSON
let result_str = match serde_json::to_string(&openai_response) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to serialize response: {}", e));
return SglErrorCode::ParsingError;
}
};
let result_cstr = match CString::new(result_str) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to create result string: {}", e));
return SglErrorCode::MemoryError;
}
};
// Check if this is a complete response (stream done)
let is_complete = matches!(proto_response.response, Some(proto::generate_response::Response::Complete(_)) | Some(proto::generate_response::Response::Error(_)));
*response_json_out = result_cstr.into_raw();
*is_done_out = if is_complete { 1 } else { 0 };
if is_complete {
// Mark stream as completed
// Ensure mark_completed() completes and is visible before returning
// Use yield_now to ensure Release ordering is fully propagated
RUNTIME.block_on(async {
let stream_guard = stream.lock().await;
stream_guard.mark_completed();
// Keep the guard until mark_completed() is fully executed
drop(stream_guard);
// Yield to ensure Release ordering is propagated before returning
// This prevents race condition where Free() is called immediately
// and Drop might not see the mark_completed() write
tokio::task::yield_now().await;
});
}
SglErrorCode::Success
}
Ok(None) => {
// No response to send (e.g., empty chunk)
// Don't mark as completed - stream might continue
// Just return null and let caller read more
*response_json_out = ptr::null_mut();
*is_done_out = 0; // Keep stream open, not done yet
SglErrorCode::Success
}
Err(e) => {
// Conversion error - don't mark as completed
// Let the stream end naturally or return error without stopping stream
set_error_message(error_out, &format!("Conversion error: {}", e));
*response_json_out = ptr::null_mut();
*is_done_out = 0; // Don't mark as done - let caller decide
SglErrorCode::ParsingError
}
}
}
Some(Err(e)) => {
// Stream error - mark as completed to prevent abort
RUNTIME.block_on(async {
let stream_guard = stream.lock().await;
stream_guard.mark_completed();
drop(stream_guard);
// Yield to ensure Release ordering is propagated
tokio::task::yield_now().await;
});
set_error_message(error_out, &format!("Stream error: {}", e));
*is_done_out = 1;
SglErrorCode::UnknownError
}
None => {
// Stream ended naturally (no more chunks)
// Mark stream as completed before returning to prevent abort
RUNTIME.block_on(async {
let stream_guard = stream.lock().await;
stream_guard.mark_completed();
drop(stream_guard);
// Yield to ensure Release ordering is propagated
tokio::task::yield_now().await;
});
*response_json_out = ptr::null_mut();
*is_done_out = 1;
SglErrorCode::Success
}
}
}
/// Free a stream handle and release all associated resources.
///
/// This function must be called exactly once for each stream handle returned by
/// `sgl_client_chat_completion_stream`. It marks the stream as completed internally
/// to prevent abort signals from being sent when resources are cleaned up.
///
/// # Arguments
///
/// * `handle` - Mutable pointer to the stream handle to free
/// - If NULL, this function does nothing
///
/// # Safety
///
/// - Must be called only once per handle
/// - Handle must not be used after calling this function
/// - After this call, the stream is no longer valid
///
/// # Notes
///
/// - This function internally calls `mark_completed()` before freeing to ensure
/// the stream cleanup doesn't trigger an abort RPC to the server
/// - Memory fences are used to ensure visibility across threads
#[no_mangle]
pub unsafe extern "C" fn sgl_stream_free(handle: *mut SglangStreamHandle) {
if !handle.is_null() {
let handle_ref = Box::from_raw(handle);
// Mark stream as completed to prevent abort on drop
// By this point, the stream should already be completed by ReadNext()
// but we call it again to be safe
RUNTIME.block_on(async {
let stream_guard = handle_ref.stream.lock().await;
stream_guard.mark_completed();
// Keep guard alive to ensure mark_completed() write completes
drop(stream_guard);
// Yield to ensure the atomic write is visible
tokio::task::yield_now().await;
});
// Use a strong memory fence to ensure mark_completed()'s Release write
// is visible before we drop the last Arc reference
std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst);
// Now drop all references - if mark_completed() was called successfully,
// the drop won't send an abort
drop(handle_ref.stream);
// Free converter
let converter = Arc::try_unwrap(handle_ref.converter)
.ok()
.map(|m| m.into_inner());
if let Some(conv) = converter {
super::grpc_converter::sgl_grpc_response_converter_free(Box::into_raw(Box::new(conv)));
}
}
}

View File

@@ -0,0 +1,388 @@
//! Tokenizer FFI functions
use std::ffi::{CStr, CString};
use std::os::raw::{c_char, c_int};
use std::ptr;
use std::sync::Arc;
use serde_json::Value;
use smg::tokenizer::{
create_tokenizer_from_file,
traits::Tokenizer as TokenizerTrait,
chat_template::ChatTemplateParams,
huggingface::HuggingFaceTokenizer,
};
use super::error::{SglErrorCode, set_error_message, clear_error_message};
#[cfg(target_os = "macos")]
type BooleanT = libc::boolean_t;
#[cfg(not(target_os = "macos"))]
type BooleanT = libc::c_int;
/// Opaque handle for a tokenizer instance
#[repr(C)]
pub struct TokenizerHandle {
pub(crate) tokenizer: Arc<dyn TokenizerTrait>,
}
/// Create a tokenizer from a file path
///
/// # Arguments
/// * `path` - Path to tokenizer.json file (null-terminated C string)
/// * `error_out` - Optional pointer to receive error message (must be freed with sgl_free_string)
///
/// # Returns
/// * Pointer to TokenizerHandle on success, null on failure
///
/// # Safety
/// The returned handle must be freed with `sgl_tokenizer_free`.
#[no_mangle]
pub unsafe extern "C" fn sgl_tokenizer_create_from_file(
path: *const c_char,
error_out: *mut *mut c_char,
) -> *mut TokenizerHandle {
if path.is_null() {
set_error_message(error_out, "path cannot be null");
return ptr::null_mut();
}
let path_str = match CStr::from_ptr(path).to_str() {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Invalid UTF-8 in path: {}", e));
return ptr::null_mut();
}
};
match create_tokenizer_from_file(path_str) {
Ok(tokenizer) => {
clear_error_message(error_out);
Box::into_raw(Box::new(TokenizerHandle {
tokenizer,
}))
}
Err(e) => {
set_error_message(error_out, &e.to_string());
ptr::null_mut()
}
}
}
/// Encode text to token IDs
///
/// # Arguments
/// * `handle` - Tokenizer handle (must not be null)
/// * `text` - Input text (null-terminated C string)
/// * `add_special_tokens` - Whether to add special tokens
/// * `token_ids_out` - Pointer to receive array of token IDs (must be freed with sgl_free_token_ids)
/// * `token_count_out` - Pointer to receive token count
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * SglErrorCode::Success on success, error code on failure
///
/// # Safety
/// The token_ids_out array must be freed with sgl_free_token_ids() after use.
#[no_mangle]
pub unsafe extern "C" fn sgl_tokenizer_encode(
handle: *mut TokenizerHandle,
text: *const c_char,
add_special_tokens: BooleanT,
token_ids_out: *mut *mut u32,
token_count_out: *mut usize,
error_out: *mut *mut c_char,
) -> SglErrorCode {
if handle.is_null() || text.is_null() || token_ids_out.is_null() || token_count_out.is_null() {
set_error_message(error_out, "Invalid arguments: null pointer");
return SglErrorCode::InvalidArgument;
}
let text_str = match CStr::from_ptr(text).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in text");
return SglErrorCode::InvalidArgument;
}
};
let add_special_tokens_bool = add_special_tokens != 0;
let tokenizer = &(*handle).tokenizer;
match tokenizer.encode(text_str, add_special_tokens_bool) {
Ok(encoding) => {
let token_ids = encoding.token_ids();
let count = token_ids.len();
// Allocate memory for token IDs using Vec, then leak to give ownership to C
let vec = token_ids.to_vec();
let ptr = vec.as_ptr() as *mut u32;
let _ = std::mem::ManuallyDrop::new(vec);
*token_ids_out = ptr;
*token_count_out = count;
clear_error_message(error_out);
SglErrorCode::Success
}
Err(e) => {
set_error_message(error_out, &e.to_string());
SglErrorCode::TokenizationError
}
}
}
/// Apply chat template to messages with tools support
///
/// # Arguments
/// * `handle` - Tokenizer handle
/// * `messages_json` - JSON string of messages array
/// * `tools_json` - Optional JSON string of tools array (null or empty string for no tools)
/// * `result_out` - Pointer to receive result string (must be freed with sgl_free_string)
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * SglErrorCode::Success on success, error code on failure
#[no_mangle]
pub unsafe extern "C" fn sgl_tokenizer_apply_chat_template_with_tools(
handle: *mut TokenizerHandle,
messages_json: *const c_char,
tools_json: *const c_char,
result_out: *mut *mut c_char,
error_out: *mut *mut c_char,
) -> SglErrorCode {
if handle.is_null() || messages_json.is_null() || result_out.is_null() {
set_error_message(error_out, "Invalid arguments: null pointer");
return SglErrorCode::InvalidArgument;
}
let messages_str = match CStr::from_ptr(messages_json).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in messages_json");
return SglErrorCode::InvalidArgument;
}
};
// Parse JSON messages
let messages: Vec<Value> = match serde_json::from_str(messages_str) {
Ok(msgs) => msgs,
Err(e) => {
set_error_message(error_out, &format!("Failed to parse messages JSON: {}", e));
return SglErrorCode::InvalidArgument;
}
};
// Parse tools JSON if provided
let tools: Option<Vec<Value>> = if tools_json.is_null() {
None
} else {
let tools_str = match CStr::from_ptr(tools_json).to_str() {
Ok(s) => {
if s.is_empty() {
None
} else {
match serde_json::from_str::<Vec<Value>>(s) {
Ok(t) => Some(t),
Err(e) => {
set_error_message(error_out, &format!("Failed to parse tools JSON: {}", e));
return SglErrorCode::InvalidArgument;
}
}
}
}
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in tools_json");
return SglErrorCode::InvalidArgument;
}
};
tools_str
};
// Get the tokenizer from handle
let handle_ref = &*handle;
let tokenizer = &handle_ref.tokenizer;
// Try to downcast to HuggingFaceTokenizer
if let Some(hf_tokenizer) = tokenizer.as_any().downcast_ref::<HuggingFaceTokenizer>() {
// Apply chat template with tools
let empty_docs: [Value; 0] = [];
let tools_slice = tools.as_ref().map(|t| t.as_slice());
let params = ChatTemplateParams {
add_generation_prompt: true,
tools: tools_slice,
documents: Some(&empty_docs),
template_kwargs: None,
};
match hf_tokenizer.apply_chat_template(&messages, params) {
Ok(result) => {
let result_cstr = match CString::new(result) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to create result string: {}", e));
return SglErrorCode::MemoryError;
}
};
*result_out = result_cstr.into_raw();
clear_error_message(error_out);
SglErrorCode::Success
}
Err(e) => {
set_error_message(error_out, &format!("Failed to apply chat template: {}", e));
SglErrorCode::TokenizationError
}
}
} else {
set_error_message(error_out, "Chat template is only supported for HuggingFace tokenizers");
SglErrorCode::TokenizationError
}
}
/// Apply chat template to messages
///
/// # Arguments
/// * `handle` - Tokenizer handle
/// * `messages_json` - JSON string of messages array
/// * `result_out` - Pointer to receive result string (must be freed with sgl_free_string)
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * SglErrorCode::Success on success, error code on failure
#[no_mangle]
pub unsafe extern "C" fn sgl_tokenizer_apply_chat_template(
handle: *mut TokenizerHandle,
messages_json: *const c_char,
result_out: *mut *mut c_char,
error_out: *mut *mut c_char,
) -> SglErrorCode {
if handle.is_null() || messages_json.is_null() || result_out.is_null() {
set_error_message(error_out, "Invalid arguments: null pointer");
return SglErrorCode::InvalidArgument;
}
let messages_str = match CStr::from_ptr(messages_json).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in messages_json");
return SglErrorCode::InvalidArgument;
}
};
// Parse JSON messages
let messages: Vec<Value> = match serde_json::from_str(messages_str) {
Ok(msgs) => msgs,
Err(e) => {
set_error_message(error_out, &format!("Failed to parse messages JSON: {}", e));
return SglErrorCode::InvalidArgument;
}
};
// Get the tokenizer from handle
let handle_ref = &*handle;
let tokenizer = &handle_ref.tokenizer;
// Try to downcast to HuggingFaceTokenizer
if let Some(hf_tokenizer) = tokenizer.as_any().downcast_ref::<HuggingFaceTokenizer>() {
// Apply chat template with default parameters
// Use empty arrays instead of None to avoid template errors
// Set add_generation_prompt to true so the model knows to start generating
let empty_tools: [Value; 0] = [];
let empty_docs: [Value; 0] = [];
let params = ChatTemplateParams {
add_generation_prompt: true, // Important: tells the model to start generating
tools: Some(&empty_tools),
documents: Some(&empty_docs),
template_kwargs: None,
};
match hf_tokenizer.apply_chat_template(&messages, params) {
Ok(result) => {
let result_cstr = match CString::new(result) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to create result string: {}", e));
return SglErrorCode::MemoryError;
}
};
*result_out = result_cstr.into_raw();
clear_error_message(error_out);
SglErrorCode::Success
}
Err(e) => {
set_error_message(error_out, &format!("Failed to apply chat template: {}", e));
SglErrorCode::TokenizationError
}
}
} else {
set_error_message(error_out, "Chat template is only supported for HuggingFace tokenizers");
SglErrorCode::TokenizationError
}
}
/// Decode token IDs to text
///
/// # Arguments
/// * `handle` - Tokenizer handle
/// * `token_ids` - Array of token IDs
/// * `token_count` - Number of tokens
/// * `skip_special_tokens` - Whether to skip special tokens
/// * `result_out` - Pointer to receive result string (must be freed with sgl_free_string)
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * SglErrorCode::Success on success, error code on failure
#[no_mangle]
pub unsafe extern "C" fn sgl_tokenizer_decode(
handle: *mut TokenizerHandle,
token_ids: *const u32,
token_count: usize,
skip_special_tokens: c_int,
result_out: *mut *mut c_char,
error_out: *mut *mut c_char,
) -> SglErrorCode {
if handle.is_null() || token_ids.is_null() || result_out.is_null() {
set_error_message(error_out, "Invalid arguments: null pointer");
return SglErrorCode::InvalidArgument;
}
if token_count == 0 {
let empty = CString::new("").unwrap();
*result_out = empty.into_raw();
clear_error_message(error_out);
return SglErrorCode::Success;
}
// Convert C array to Rust slice
let token_slice = std::slice::from_raw_parts(token_ids, token_count);
let tokenizer = &(*handle).tokenizer;
match tokenizer.decode(token_slice, skip_special_tokens != 0) {
Ok(text) => {
let result_cstr = match CString::new(text) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to create result string: {}", e));
return SglErrorCode::MemoryError;
}
};
*result_out = result_cstr.into_raw();
clear_error_message(error_out);
SglErrorCode::Success
}
Err(e) => {
set_error_message(error_out, &e.to_string());
SglErrorCode::TokenizationError
}
}
}
/// Free a tokenizer handle
///
/// # Safety
/// This function must only be called once per handle, and the handle must not be used after calling.
#[no_mangle]
pub unsafe extern "C" fn sgl_tokenizer_free(handle: *mut TokenizerHandle) {
if !handle.is_null() {
let _ = Box::from_raw(handle);
}
}

View File

@@ -0,0 +1,329 @@
//! Tool parser FFI functions
use std::ffi::{CStr, CString};
use std::os::raw::{c_char};
use std::ptr;
use std::sync::Arc;
use std::collections::HashMap;
use serde_json::{json, Value};
use tokio::runtime::Runtime;
use once_cell::sync::Lazy;
use smg::tool_parser::{ParserFactory, ToolParser};
use smg::protocols::common::Tool;
use super::error::{SglErrorCode, set_error_message, clear_error_message};
use super::utils::generate_tool_call_id;
/// Global parser factory (initialized once)
static PARSER_FACTORY: Lazy<ParserFactory> = Lazy::new(|| ParserFactory::new());
/// Global tokio runtime for async operations
static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
Runtime::new().expect("Failed to create tokio runtime for tool parser FFI")
});
/// Opaque handle for a tool parser instance
/// Note: For streaming, we need mutable access, so we use Arc<Mutex<>> internally
/// Note: This is an opaque handle, C code doesn't access fields directly
pub struct ToolParserHandle {
parser: Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>,
model: String, // Store model name for ID generation
history_tool_calls_count: usize, // Track tool call count for ID generation
tool_index_to_id: HashMap<usize, String>, // Map tool_index to ID for incremental updates
}
/// Create a tool parser
///
/// # Arguments
/// * `parser_type` - Parser type name (e.g., "json", "llama", "mistral") or model name (e.g., "gpt-4")
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * Pointer to ToolParserHandle on success, null on failure
#[no_mangle]
pub unsafe extern "C" fn sgl_tool_parser_create(
parser_type: *const c_char,
error_out: *mut *mut c_char,
) -> *mut ToolParserHandle {
if parser_type.is_null() {
set_error_message(error_out, "parser_type cannot be null");
return ptr::null_mut();
}
let type_str = match CStr::from_ptr(parser_type).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in parser_type");
return ptr::null_mut();
}
};
// Create parser using factory
// The factory will determine the parser type based on model name or use the provided type
let parser = if let Some(parser_box) = PARSER_FACTORY.registry().create_for_model(type_str) {
parser_box
} else if let Some(parser_box) = PARSER_FACTORY.registry().create_parser(type_str) {
parser_box
} else {
set_error_message(error_out, &format!("Unknown parser type: {}", type_str));
return ptr::null_mut();
};
Box::into_raw(Box::new(ToolParserHandle {
parser: Arc::new(tokio::sync::Mutex::new(parser)),
model: type_str.to_string(),
history_tool_calls_count: 0,
tool_index_to_id: HashMap::new(),
}))
}
/// Parse complete tool calls from text
///
/// # Arguments
/// * `handle` - Tool parser handle
/// * `text` - Input text to parse
/// * `result_json_out` - Pointer to receive JSON result (must be freed with sgl_free_string)
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * SglErrorCode::Success on success, error code on failure
#[no_mangle]
pub unsafe extern "C" fn sgl_tool_parser_parse_complete(
handle: *mut ToolParserHandle,
text: *const c_char,
result_json_out: *mut *mut c_char,
error_out: *mut *mut c_char,
) -> SglErrorCode {
if handle.is_null() || text.is_null() || result_json_out.is_null() {
set_error_message(error_out, "Invalid arguments: null pointer");
return SglErrorCode::InvalidArgument;
}
let text_str = match CStr::from_ptr(text).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in text");
return SglErrorCode::InvalidArgument;
}
};
let handle_ref = &*handle;
let parser = Arc::clone(&handle_ref.parser);
let model = handle_ref.model.clone();
let history_count = handle_ref.history_tool_calls_count;
// Use tokio runtime to run async code
let result = RUNTIME.block_on(async {
let parser_guard = parser.lock().await;
parser_guard.parse_complete(text_str).await
});
match result {
Ok((normal_text, tool_calls)) => {
// Convert Rust ToolCall to OpenAI format
let openai_tool_calls: Vec<Value> = tool_calls
.into_iter()
.enumerate()
.map(|(index, tc)| {
// Generate ID for this tool call
let id = generate_tool_call_id(&model, &tc.function.name, index, history_count);
json!({
"id": id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments
}
})
})
.collect();
// Build result JSON
let result_json = json!({
"normal_text": normal_text,
"tool_calls": openai_tool_calls
});
let result_str = match serde_json::to_string(&result_json) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to serialize JSON: {}", e));
return SglErrorCode::ParsingError;
}
};
let result_cstr = match CString::new(result_str) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to create result string: {}", e));
return SglErrorCode::MemoryError;
}
};
*result_json_out = result_cstr.into_raw();
clear_error_message(error_out);
SglErrorCode::Success
}
Err(e) => {
set_error_message(error_out, &format!("Parse error: {}", e));
SglErrorCode::ParsingError
}
}
}
/// Parse tool calls incrementally from streaming chunks
///
/// # Arguments
/// * `handle` - Tool parser handle
/// * `chunk` - New text chunk from stream
/// * `tools_json` - JSON array of available tools (for validation, can be null/empty)
/// * `result_json_out` - Pointer to receive JSON result (must be freed with sgl_free_string)
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * SglErrorCode::Success on success, error code on failure
#[no_mangle]
pub unsafe extern "C" fn sgl_tool_parser_parse_incremental(
handle: *mut ToolParserHandle,
chunk: *const c_char,
tools_json: *const c_char,
result_json_out: *mut *mut c_char,
error_out: *mut *mut c_char,
) -> SglErrorCode {
if handle.is_null() || chunk.is_null() || result_json_out.is_null() {
set_error_message(error_out, "Invalid arguments: null pointer");
return SglErrorCode::InvalidArgument;
}
let chunk_str = match CStr::from_ptr(chunk).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in chunk");
return SglErrorCode::InvalidArgument;
}
};
// Parse tools JSON if provided
let tools: Vec<Tool> = if !tools_json.is_null() {
let tools_str = match CStr::from_ptr(tools_json).to_str() {
Ok(s) => s,
Err(_) => {
set_error_message(error_out, "Invalid UTF-8 in tools_json");
return SglErrorCode::InvalidArgument;
}
};
match serde_json::from_str::<Vec<Tool>>(tools_str) {
Ok(t) => t,
Err(_) => vec![], // If parsing fails, use empty tools
}
} else {
vec![]
};
let handle_ref = &*handle;
let parser = Arc::clone(&handle_ref.parser);
let model = handle_ref.model.clone();
let history_count = handle_ref.history_tool_calls_count;
// Use tokio runtime to run async code
let result = RUNTIME.block_on(async {
let mut parser_guard = parser.lock().await;
parser_guard.parse_incremental(chunk_str, &tools).await
});
match result {
Ok(streaming_result) => {
// Convert StreamingParseResult to OpenAI format
let handle_mut = &mut *handle;
let openai_tool_calls: Vec<Value> = streaming_result
.calls
.into_iter()
.map(|item| {
// For incremental parsing, we may not have complete tool calls yet
// Generate or reuse ID based on tool_index
let id = if let Some(ref name) = item.name {
// New tool call with name - generate ID and store it
let id = generate_tool_call_id(&model, name, item.tool_index, history_count);
handle_mut.tool_index_to_id.insert(item.tool_index, id.clone());
id
} else {
// Parameter update - reuse existing ID for this tool_index
handle_mut.tool_index_to_id
.get(&item.tool_index)
.cloned()
.unwrap_or_else(|| format!("call_{}", item.tool_index))
};
json!({
"id": id,
"type": "function",
"function": {
"name": item.name.unwrap_or_default(),
"arguments": item.parameters
}
})
})
.collect();
// Build result JSON
let result_json = json!({
"normal_text": streaming_result.normal_text,
"tool_calls": openai_tool_calls
});
let result_str = match serde_json::to_string(&result_json) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to serialize JSON: {}", e));
return SglErrorCode::ParsingError;
}
};
let result_cstr = match CString::new(result_str) {
Ok(s) => s,
Err(e) => {
set_error_message(error_out, &format!("Failed to create result string: {}", e));
return SglErrorCode::MemoryError;
}
};
*result_json_out = result_cstr.into_raw();
clear_error_message(error_out);
SglErrorCode::Success
}
Err(e) => {
set_error_message(error_out, &format!("Parse incremental error: {}", e));
SglErrorCode::ParsingError
}
}
}
/// Reset the parser state for reuse
#[no_mangle]
pub unsafe extern "C" fn sgl_tool_parser_reset(handle: *mut ToolParserHandle) {
if handle.is_null() {
return;
}
let handle_ref = &mut *handle;
let parser = Arc::clone(&handle_ref.parser);
// Reset parser state
RUNTIME.block_on(async {
let mut parser_guard = parser.lock().await;
parser_guard.reset();
});
// Reset history count and tool index mapping
handle_ref.history_tool_calls_count = 0;
handle_ref.tool_index_to_id.clear();
}
/// Free a tool parser handle
#[no_mangle]
pub unsafe extern "C" fn sgl_tool_parser_free(handle: *mut ToolParserHandle) {
if !handle.is_null() {
let _ = Box::from_raw(handle);
}
}

View File

@@ -0,0 +1,44 @@
//! Utility functions for FFI
use uuid::Uuid;
/// Helper function to generate tool call ID (matches router implementation)
pub fn generate_tool_call_id(
model: &str,
function_name: &str,
index: usize,
history_tool_calls_count: usize,
) -> String {
if model.to_lowercase().contains("kimi") {
// KimiK2 format: functions.{name}:{global_index}
format!("functions.{}:{}", function_name, history_tool_calls_count + index)
} else {
// Standard OpenAI format: call_{24-char-uuid}
format!("call_{}", &Uuid::new_v4().simple().to_string()[..24])
}
}
/// Generate tool constraints (placeholder implementation)
///
/// # Arguments
/// * `tools_json` - JSON array of tools
/// * `tool_choice_json` - JSON object representing tool_choice
/// * `constraint_type_out` - Pointer to receive constraint type (e.g., "json_schema")
/// * `constraint_schema_out` - Pointer to receive constraint schema JSON
/// * `error_out` - Optional pointer to receive error message
///
/// # Returns
/// * SglErrorCode::Success on success, error code on failure
#[no_mangle]
pub unsafe extern "C" fn sgl_generate_tool_constraints(
_tools_json: *const std::os::raw::c_char,
_tool_choice_json: *const std::os::raw::c_char,
_constraint_type_out: *mut *mut std::os::raw::c_char,
_constraint_schema_out: *mut *mut std::os::raw::c_char,
error_out: *mut *mut std::os::raw::c_char,
) -> super::error::SglErrorCode {
// Implementation would parse JSON and call generate_tool_constraints
// This is a placeholder
super::error::set_error_message(error_out, "Tool constraint generation not yet implemented in FFI");
super::error::SglErrorCode::UnknownError
}

View File

@@ -0,0 +1,9 @@
[run]
source = sglang_router
omit =
*/mini_lb.py
*/cli.py
*/__main__.py
[report]
fail_under = 80

View File

@@ -0,0 +1,28 @@
[package]
name = "sgl-model-gateway-python"
version = "0.3.2"
edition = "2021"
[lib]
name = "sglang_router_rs"
crate-type = ["cdylib"]
[dependencies]
pyo3 = { version = "0.27.1", features = ["extension-module", "abi3-py38"] }
tokio = { version = "1.42.0", features = ["full"] }
once_cell = "1.19"
[dependencies.sgl-model-gateway]
path = "../.."
default-features = true
[features]
default = ["pyo3/extension-module"]
vendored-openssl = ["sgl-model-gateway/vendored-openssl"]
[profile.ci]
inherits = "release"
opt-level = 2 # Lighter optimization (still fast runtime, much faster compile)
lto = "thin" # Thin LTO - good balance
codegen-units = 16 # More parallelization for faster builds
strip = true

View File

@@ -0,0 +1,9 @@
# Must include:
include Cargo.toml # Python bindings Cargo configuration
include ../../Cargo.toml # Main Rust project configuration
include ../../build.rs # Build script for protobuf generation
include ../../LICENSE
recursive-include src *.rs # Python bindings wrapper
recursive-include ../../src *.rs # Main Rust source files
recursive-include ../../src/proto *.proto # Protobuf definitions
recursive-include sglang_router *.py # Python source files

View File

@@ -0,0 +1,77 @@
# SGLang Model Gateway Python Bindings
This directory contains the Python bindings for the SGLang Router, built using [maturin](https://github.com/PyO3/maturin) and [PyO3](https://github.com/PyO3/pyo3).
## Directory Structure
```
bindings/python/
├── src/ # Source code (src layout)
│ ├── lib.rs # Rust/PyO3 bindings implementation
│ └── sglang_router/ # Python source code
│ ├── __init__.py
│ ├── version.py
│ ├── launch_server.py
│ ├── launch_router.py
│ ├── router.py
│ ├── router_args.py
│ └── mini_lb.py
├── tests/ # Python unit tests
│ ├── conftest.py
│ ├── test_validation.py
│ ├── test_arg_parser.py
│ ├── test_router_config.py
│ └── test_startup_sequence.py
├── Cargo.toml # Rust package configuration for bindings
├── pyproject.toml # Python package configuration
├── setup.py # Setup configuration
├── MANIFEST.in # Package manifest
├── .coveragerc # Test coverage configuration
└── README.md # This file
```
## Building
### Development Build
```bash
# Install maturin
pip install maturin
# Build and install in development mode
cd sgl-model-gateway/bindings/python
maturin develop --features vendored-openssl
```
### Production Build
```bash
# Build wheel
cd sgl-model-gateway/bindings/python
maturin build --release --out dist --features vendored-openssl
# Install the built wheel
pip install dist/sglang_router-*.whl
```
## Testing
```bash
# Run Python unit tests (after maturin develop)
cd sgl-model-gateway/bindings/python
pytest tests/
```
## Configuration
- **pyproject.toml**: Defines package metadata, dependencies, and build configuration
- **python-source**: Set to `"src"` indicating Python source uses the src layout
- **module-name**: `sglang_router.sglang_router_rs` - the Rust extension module name
## Notes
- The Rust bindings source code is located in `src/lib.rs`
- The bindings have their own `Cargo.toml` in this directory
- The main sglang-router library is located in `../../` and is used as a dependency
- The package includes both Python code and Rust extensions built with PyO3
- PyO3 types are prefixed with `Py` in Rust but exposed to Python without the prefix using the `name` attribute

View File

@@ -0,0 +1,64 @@
[build-system]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"
[project]
name = "sglang-router"
version = "0.3.2"
description = "High-performance Rust-based load balancer for SGLang with multiple routing algorithms and prefill-decode disaggregation support"
authors = [
{name = "Simo Lin", email = "linsimo.mark@gmail.com"},
{name = "Chang Su", email = "mckvtl@gmail.com"},
{name = "Keyang Ru", email = "rukeyang@gmail.com"},
{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}
]
requires-python = ">=3.8"
readme = "../../README.md"
license = { text = "Apache-2.0" }
classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Rust",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
]
dependencies = [
"setproctitle",
"aiohttp",
"orjson",
"uvicorn",
"fastapi",
]
[project.optional-dependencies]
dev = [
"requests>=2.25.0",
"pytest>=7.0.0",
]
[project.scripts]
smg = "sglang_router.cli:main"
amg = "sglang_router.cli:main"
sglang-router = "sglang_router.cli:main"
[tool.maturin]
python-source = "src"
module-name = "sglang_router.sglang_router_rs"
# Exclude bindings/python/README.md to use root README only
exclude = ["README.md"]
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
markers = [
"unit: mark test as a unit test (no GPU required)",
]

View File

@@ -0,0 +1,28 @@
import os
import warnings
from setuptools import setup
with_rust = os.environ.get("SGLANG_ROUTER_BUILD_WITH_RUST", None)
with_rust = with_rust is None or (not with_rust.lower() in ["0", "false", "no"])
rust_extensions = []
if with_rust:
from setuptools_rust import Binding, RustExtension
rust_extensions.append(
RustExtension(
target="sglang_router_rs",
path="Cargo.toml",
binding=Binding.PyO3,
)
)
else:
warnings.warn(
"Building 'sglang-router' without Rust support. Performance may be degraded."
)
setup(
rust_extensions=rust_extensions,
zip_safe=False,
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,3 @@
from sglang_router.version import __version__
__all__ = ["__version__"]

View File

@@ -0,0 +1,8 @@
"""
Allow running the CLI via: python -m sglang_router
"""
from sglang_router.cli import main
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,107 @@
#!/usr/bin/env python3
"""
SGLang Model Gateway CLI
Provides convenient command-line interface for launching the router and server.
Usage:
smg launch [args] # Launch router only
smg server [args] # Launch router + server
smg --help # Show help
"""
import argparse
import os
import sys
from typing import List, Optional
from sglang_router.sglang_router_rs import (
get_verbose_version_string,
get_version_string,
)
def create_parser() -> argparse.ArgumentParser:
"""Create the main CLI parser with subcommands."""
prog_name = os.path.basename(sys.argv[0]) if sys.argv else "smg"
parser = argparse.ArgumentParser(
prog=prog_name,
description="SGLang Model Gateway - High-performance inference router",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
subparsers = parser.add_subparsers(dest="command", help="Available commands")
# Launch router subcommand
launch_parser = subparsers.add_parser(
"launch",
help="Launch router only (requires existing worker URLs)",
description="Launch the SGLang router with existing worker instances",
add_help=False, # Let router handle --help
)
# Launch server + router subcommand
server_parser = subparsers.add_parser(
"server",
help="Launch router and server processes together",
description="Launch both SGLang router and server processes",
add_help=False, # Let server handle --help
)
return parser
def main(argv: Optional[List[str]] = None) -> None:
"""Main CLI entry point."""
if argv is None:
argv = sys.argv[1:]
# Handle version flags before parsing
if argv and argv[0] in ["--version", "-V", "--version-verbose"]:
if argv[0] == "--version-verbose":
print(get_verbose_version_string())
else:
print(get_version_string())
sys.exit(0)
# Handle empty command - show help
if not argv or argv[0] not in ["launch", "server", "-h", "--help"]:
parser = create_parser()
parser.print_help()
sys.exit(1)
parser = create_parser()
args, unknown = parser.parse_known_args(argv)
if args.command == "launch":
# Import and call launch_router functions directly
from sglang_router.launch_router import launch_router, parse_router_args
# All router args are in unknown
router_args = parse_router_args(unknown)
launch_router(router_args)
elif args.command == "server":
# Import and call launch_server main with proper argv
# Note: launch_server.main() uses argparse internally which reads sys.argv
# We need to temporarily set sys.argv for compatibility
import sglang_router.launch_server as launch_server_module
# Preserve original sys.argv
original_argv = sys.argv
try:
# All server args are in unknown
prog_name = os.path.basename(sys.argv[0]) if sys.argv else "smg"
sys.argv = [f"{prog_name} server"] + unknown
launch_server_module.main()
finally:
# Restore original sys.argv
sys.argv = original_argv
else:
parser.print_help()
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,109 @@
import argparse
import logging
import sys
from typing import List, Optional
import setproctitle
from sglang_router.mini_lb import MiniLoadBalancer
from sglang_router.router_args import RouterArgs
logger = logging.getLogger("router")
try:
from sglang_router.router import Router
except ImportError:
Router = None
logger.warning(
"Rust Router is not installed, only python MiniLB (debugging only) is available"
)
def launch_router(args: argparse.Namespace) -> Optional[Router]:
"""
Launch the SGLang router with the configuration from parsed arguments.
Args:
args: Namespace object containing router configuration
Can be either raw argparse.Namespace or converted RouterArgs
Returns:
Router instance if successful, None if failed
"""
setproctitle.setproctitle("sglang::router")
try:
# Convert to RouterArgs if needed
if not isinstance(args, RouterArgs):
router_args = RouterArgs.from_cli_args(args)
else:
router_args = args
if router_args.mini_lb:
mini_lb = MiniLoadBalancer(router_args)
mini_lb.start()
else:
if Router is None:
raise RuntimeError("Rust Router is not installed")
router_args._validate_router_args()
router = Router.from_args(router_args)
router.start()
except Exception as e:
logger.error(f"Error starting router: {e}")
raise e
class CustomHelpFormatter(
argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter
):
"""Custom formatter that preserves both description formatting and shows defaults"""
pass
def parse_router_args(args: List[str]) -> RouterArgs:
"""Parse command line arguments and return RouterArgs instance."""
parser = argparse.ArgumentParser(
description="""SGLang Router - High-performance request distribution across worker nodes
Usage:
This launcher enables starting a router with individual worker instances. It is useful for
multi-node setups or when you want to start workers and router separately.
Examples:
# Regular mode
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
# PD disaggregated mode with same policy for both
python -m sglang_router.launch_router --pd-disaggregation \\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 \\
--decode http://decode1:8001 --decode http://decode2:8001 \\
--policy cache_aware
# PD mode with optional bootstrap ports
python -m sglang_router.launch_router --pd-disaggregation \\
--prefill http://prefill1:8000 9000 \\ # With bootstrap port
--prefill http://prefill2:8000 none \\ # Explicitly no bootstrap port
--prefill http://prefill3:8000 \\ # Defaults to no bootstrap port
--decode http://decode1:8001 --decode http://decode2:8001
# PD mode with different policies for prefill and decode
python -m sglang_router.launch_router --pd-disaggregation \\
--prefill http://prefill1:8000 --prefill http://prefill2:8000 \\
--decode http://decode1:8001 --decode http://decode2:8001 \\
--prefill-policy cache_aware --decode-policy power_of_two
""",
formatter_class=CustomHelpFormatter,
)
RouterArgs.add_cli_args(parser, use_router_prefix=False)
return RouterArgs.from_cli_args(parser.parse_args(args), use_router_prefix=False)
def main() -> None:
router_args = parse_router_args(sys.argv[1:])
launch_router(router_args)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,213 @@
import argparse
import asyncio
import copy
import logging
import multiprocessing as mp
import os
import random
import signal
import sys
import time
from typing import List
import requests
from setproctitle import setproctitle
from sglang_router.launch_router import RouterArgs, launch_router
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils.network import is_port_available
def setup_logger():
logger = logging.getLogger("router")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
logger = setup_logger()
# Create new process group
def run_server(server_args, dp_rank):
"""
Note:
1. Without os.setpgrp(), all processes share the same PGID. When you press Ctrl+C, the terminal sends SIGINT to all processes in the group simultaneously.
This can cause leaf processes to terminate first, which messes up the cleaning order and produces orphaned processes.
Terminal (PGID=100)
└── Main Python Process (PGID=100)
└── Server Process 1 (PGID=100)
└── Scheduler 1
└── Detokenizer 1
└── Server Process 2 (PGID=100)
└── Scheduler 2
└── Detokenizer 2
2. With os.setpgrp(), the main Python process and its children are in a separate group. Now:
Terminal (PGID=100)
└── Main Python Process (PGID=200)
└── Server Process 1 (PGID=300)
└── Scheduler 1
└── Detokenizer 1
└── Server Process 2 (PGID=400)
└── Scheduler 2
└── Detokenizer 2
"""
# create new process group
os.setpgrp()
setproctitle("sglang::server")
# Set SGLANG_DP_RANK environment variable
os.environ["SGLANG_DP_RANK"] = str(dp_rank)
# Launch server in appropriate mode (HTTP or gRPC)
if server_args.grpc_mode:
from sglang.srt.entrypoints.grpc_server import serve_grpc
asyncio.run(serve_grpc(server_args))
else:
from sglang.srt.entrypoints.http_server import launch_server
launch_server(server_args)
def launch_server_process(
server_args: ServerArgs, worker_port: int, dp_id: int
) -> mp.Process:
"""Launch a single server process with the given args and port."""
server_args = copy.deepcopy(server_args)
server_args.port = worker_port
server_args.base_gpu_id = dp_id * server_args.tp_size
server_args.dp_size = 1
proc = mp.Process(target=run_server, args=(server_args, dp_id))
proc.start()
return proc
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
"""Wait for server to be healthy by checking /health endpoint."""
start_time = time.perf_counter()
url = f"http://{host}:{port}/health"
while time.perf_counter() - start_time < timeout:
try:
response = requests.get(url, timeout=5)
if response.status_code == 200:
return True
except requests.exceptions.RequestException:
pass
time.sleep(1)
return False
def find_available_ports(base_port: int, count: int) -> List[int]:
"""Find consecutive available ports starting from base_port."""
available_ports = []
current_port = base_port
while len(available_ports) < count:
if is_port_available(current_port):
available_ports.append(current_port)
current_port += random.randint(100, 1000)
return available_ports
def cleanup_processes(processes: List[mp.Process]):
for process in processes:
logger.info(f"Terminating process group {process.pid}")
try:
os.killpg(process.pid, signal.SIGTERM)
except ProcessLookupError:
# Process group may already be terminated
pass
# Wait for processes to terminate
for process in processes:
process.join(timeout=5)
if process.is_alive():
logger.warning(
f"Process {process.pid} did not terminate gracefully, forcing kill"
)
try:
os.killpg(process.pid, signal.SIGKILL)
except ProcessLookupError:
pass
logger.info("All process groups terminated")
def main():
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
mp.set_start_method("spawn")
parser = argparse.ArgumentParser(
description="Launch SGLang router and server processes"
)
ServerArgs.add_cli_args(parser)
RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
parser.add_argument(
"--router-dp-worker-base-port",
type=int,
default=31000,
help="Base port number for data parallel workers",
)
# No extra retry/CB flags here; RouterArgs.add_cli_args already defines them with router- prefix
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
# Find available ports for workers
worker_ports = find_available_ports(
args.router_dp_worker_base_port, server_args.dp_size
)
# Start server processes
server_processes = []
for i, worker_port in enumerate(worker_ports):
logger.info(f"Launching DP server process {i} on port {worker_port}")
proc = launch_server_process(server_args, worker_port, i)
server_processes.append(proc)
signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes))
signal.signal(
signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes)
)
signal.signal(
signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes)
)
# Update router args with worker URLs
# Use grpc:// protocol if server is in gRPC mode, otherwise http://
protocol = "grpc" if server_args.grpc_mode else "http"
router_args.worker_urls = [
f"{protocol}://{server_args.host}:{port}" for port in worker_ports
]
# Start the router
try:
launch_router(router_args)
except Exception as e:
logger.error(f"Failed to start router: {e}")
cleanup_processes(server_processes)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,462 @@
"""
Minimal HTTP load balancer for prefill and decode servers for testing.
"""
import asyncio
import ipaddress
import logging
import random
import urllib
import warnings
from http import HTTPStatus
from itertools import chain
from typing import Optional
import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang_router.router_args import RouterArgs
logger = logging.getLogger(__name__)
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
1024 * 64
) # 64KB, to prevent aiohttp's "Chunk too big" error
def maybe_wrap_ipv6_address(address: str) -> str:
try:
ipaddress.IPv6Address(address)
return f"[{address}]"
except ValueError:
return address
class MiniLoadBalancer:
def __init__(
self,
router_args: RouterArgs,
):
self._validate_router_args(router_args)
self.host = router_args.host
self.port = router_args.port
self.timeout = router_args.request_timeout_secs
self.prefill_urls = [url[0] for url in router_args.prefill_urls]
self.prefill_bootstrap_ports = [url[1] for url in router_args.prefill_urls]
self.decode_urls = router_args.decode_urls
self.test_external_dp_routing = router_args.test_external_dp_routing
self.prefill_dp_size = None
self.decode_dp_size = None
def _validate_router_args(self, router_args: RouterArgs):
logger.warning(
"\x1b[33mMiniLB is only for debugging purposes, it only supports random policy!\033[0m"
)
# NOTE: too many arguments unsupported, just validate some important ones
if router_args.policy != "random":
logger.warning("[MiniLB] Overriding policy to random")
router_args.policy = "random"
if not router_args.pd_disaggregation:
raise ValueError("MiniLB only supports PD disaggregation mode")
if len(router_args.prefill_urls) == 0 or len(router_args.decode_urls) == 0:
raise ValueError(
"MiniLB requires at least one prefill and one decode server"
)
def start(self):
global lb
lb = self
uvicorn.run(app, host=self.host, port=self.port)
async def _ensure_dp_sizes(self):
if self.prefill_dp_size is not None:
return
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.prefill_urls[0]}/server_info") as resp:
info = await resp.json()
self.prefill_dp_size = len(info.get("internal_states", [1]))
async with session.get(f"{self.decode_urls[0]}/server_info") as resp:
info = await resp.json()
self.decode_dp_size = len(info.get("internal_states", [1]))
logger.info(
f"[MiniLB] DP sizes: prefill={self.prefill_dp_size}, decode={self.decode_dp_size}"
)
def _fork_dp_requests(self, request):
p_rank = random.randint(0, self.prefill_dp_size - 1)
d_rank = random.randint(0, self.decode_dp_size - 1)
prefill_req = request.copy()
decode_req = request.copy()
prefill_req["routed_dp_rank"] = p_rank
decode_req["routed_dp_rank"] = d_rank
decode_req["disagg_prefill_dp_rank"] = p_rank
return prefill_req, decode_req, d_rank
def select_pair(self):
assert len(self.prefill_urls) > 0, "No prefill servers available"
assert len(self.decode_urls) > 0, "No decode servers available"
pidx = random.randint(0, len(self.prefill_urls) - 1)
didx = random.randint(0, len(self.decode_urls) - 1)
return (
self.prefill_urls[pidx],
self.prefill_bootstrap_ports[pidx],
self.decode_urls[didx],
)
async def generate(
self, modified_request, prefill_server, decode_server, endpoint
) -> ORJSONResponse:
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
expected_decode_dp_rank = None
if self.test_external_dp_routing:
await self._ensure_dp_sizes()
prefill_req, decode_req, expected_decode_dp_rank = self._fork_dp_requests(
modified_request
)
else:
prefill_req = modified_request
decode_req = modified_request
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=self.timeout
) # Add timeout for request reliability
) as session:
tasks = [
session.post(f"{prefill_server}/{endpoint}", json=prefill_req),
session.post(f"{decode_server}/{endpoint}", json=decode_req),
]
# Wait for both responses to complete. Prefill should end first.
prefill_response, decode_response = await asyncio.gather(*tasks)
if "return_logprob" in modified_request:
prefill_json = await prefill_response.json()
ret_json = await decode_response.json()
# merge `meta_info.input_token_logprobs` from prefill to decode
if "meta_info" in ret_json:
if "input_token_logprobs" in ret_json["meta_info"]:
ret_json["meta_info"]["input_token_logprobs"] = (
prefill_json["meta_info"]["input_token_logprobs"]
+ ret_json["meta_info"]["input_token_logprobs"]
)
else:
ret_json = await decode_response.json()
if expected_decode_dp_rank is not None:
actual = ret_json.get("meta_info", {}).get("dp_rank")
if actual != expected_decode_dp_rank:
return ORJSONResponse(
content={
"error": f"DP rank mismatch: expected {expected_decode_dp_rank}, got {actual}"
},
status_code=500,
)
return ORJSONResponse(
content=ret_json,
status_code=decode_response.status,
)
async def generate_stream(
self, modified_request, prefill_server, decode_server, endpoint="generate"
):
if self.test_external_dp_routing:
warnings.warn("--test-external-dp-routing is not supported with streaming")
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
async def stream_results():
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=self.timeout
) # Add timeout for request reliability
) as session:
# Create the tasks for both prefill and decode requests
tasks = [
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
session.post(f"{decode_server}/{endpoint}", json=modified_request),
]
# Wait for both responses to complete. Since this is streaming, they return immediately.
prefill_response, decode_response = await asyncio.gather(*tasks)
if modified_request.get("return_logprob", False):
prefill_chunks = []
async for chunk in prefill_response.content:
prefill_chunks.append(chunk)
first_prefill_chunk = (
prefill_chunks[0].decode("utf-8")[5:].strip("\n")
)
first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
async for chunk in decode_response.content:
# Note: This is inefficient
# merge prefill input_token_logprobs, output_token_logprobs to decode
decoded_chunk = chunk.decode("utf-8")
if (
decoded_chunk
and decoded_chunk.startswith("data:")
and "[DONE]" not in decoded_chunk
):
ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
ret_json["meta_info"]["input_token_logprobs"] = (
first_prefill_chunk_json["meta_info"][
"input_token_logprobs"
]
+ ret_json["meta_info"]["input_token_logprobs"]
)
yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
else:
yield chunk
else:
async for chunk in decode_response.content.iter_chunked(
AIOHTTP_STREAM_READ_CHUNK_SIZE
):
yield chunk
return StreamingResponse(
stream_results(),
media_type="text/event-stream",
)
app = FastAPI()
lb: Optional[MiniLoadBalancer] = None
@app.get("/health")
async def health_check():
return Response(status_code=200)
@app.get("/health_generate")
async def health_generate():
async with aiohttp.ClientSession() as session:
# Create the tasks
tasks = []
for server in chain(lb.prefill_urls, lb.decode_urls):
tasks.append(session.get(f"{server}/health_generate"))
for i, response in enumerate(asyncio.as_completed(tasks)):
await response
return Response(status_code=200)
@app.post("/flush_cache")
async def flush_cache():
async with aiohttp.ClientSession() as session:
# Create the tasks
tasks = []
for server in chain(lb.prefill_urls, lb.decode_urls):
tasks.append(session.post(f"{server}/flush_cache"))
for i, response in enumerate(asyncio.as_completed(tasks)):
await response
return Response(status_code=200)
# TODO: Remove `/get_server_info` alias after one release-cycle deprecation window.
@app.get("/server_info")
@app.get("/get_server_info")
async def get_server_info():
prefill_infos = []
decode_infos = []
all_internal_states = []
async with aiohttp.ClientSession() as session:
for server in lb.prefill_urls:
server_info = await session.get(f"{server}/server_info")
prefill_infos.append(await server_info.json())
for server in lb.decode_urls:
server_info = await session.get(f"{server}/server_info")
info_json = await server_info.json()
decode_infos.append(info_json)
# Extract internal_states from decode servers
if "internal_states" in info_json:
all_internal_states.extend(info_json["internal_states"])
# Return format expected by bench_one_batch_server.py
if all_internal_states:
return {
"internal_states": all_internal_states,
"prefill": prefill_infos,
"decode": decode_infos,
}
else:
# Fallback with dummy data if no internal states found
return {
"internal_states": [
{
"last_gen_throughput": 0.0,
"avg_spec_accept_length": None,
}
],
"prefill": prefill_infos,
"decode": decode_infos,
}
async def _get_model_info_impl():
if not lb or not lb.prefill_urls:
raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
detail="There is no server registered",
)
target_server_url = lb.prefill_urls[0]
endpoint_url = f"{target_server_url}/model_info"
async with aiohttp.ClientSession() as session:
try:
async with session.get(endpoint_url) as response:
if response.status != 200:
error_text = await response.text()
raise HTTPException(
status_code=HTTPStatus.BAD_GATEWAY,
detail=(
f"Failed to get model info from {target_server_url}"
f"Status: {response.status}, Response: {error_text}"
),
)
model_info_json = await response.json()
return ORJSONResponse(content=model_info_json)
except aiohttp.ClientError as e:
raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
detail=f"Failed to get model info from backend",
)
@app.get("/model_info")
async def model_info():
return await _get_model_info_impl()
@app.get("/get_model_info")
async def get_model_info():
return await _get_model_info_impl()
@app.post("/generate")
async def handle_generate_request(request_data: dict):
prefill_server, bootstrap_port, decode_server = lb.select_pair()
# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
modified_request = request_data.copy()
batch_size = _get_request_batch_size(modified_request)
if batch_size is not None:
modified_request.update(
{
"bootstrap_host": [hostname] * batch_size,
"bootstrap_port": [bootstrap_port] * batch_size,
"bootstrap_room": [
_generate_bootstrap_room() for _ in range(batch_size)
],
}
)
else:
modified_request.update(
{
"bootstrap_host": hostname,
"bootstrap_port": bootstrap_port,
"bootstrap_room": _generate_bootstrap_room(),
}
)
if request_data.get("stream", False):
return await lb.generate_stream(
modified_request, prefill_server, decode_server, "generate"
)
else:
return await lb.generate(
modified_request, prefill_server, decode_server, "generate"
)
async def _forward_to_backend(request_data: dict, endpoint_name: str):
prefill_server, bootstrap_port, decode_server = lb.select_pair()
# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
modified_request = request_data.copy()
modified_request.update(
{
"bootstrap_host": hostname,
"bootstrap_port": bootstrap_port,
"bootstrap_room": _generate_bootstrap_room(),
}
)
if request_data.get("stream", False):
return await lb.generate_stream(
modified_request,
prefill_server,
decode_server,
endpoint=endpoint_name,
)
else:
return await lb.generate(
modified_request,
prefill_server,
decode_server,
endpoint=endpoint_name,
)
@app.post("/v1/chat/completions")
async def handle_chat_completion_request(request_data: dict):
return await _forward_to_backend(request_data, "v1/chat/completions")
@app.post("/v1/completions")
async def handle_completion_request(request_data: dict):
return await _forward_to_backend(request_data, "v1/completions")
def _generate_bootstrap_room():
return random.randint(0, 2**63 - 1)
# We may utilize `GenerateReqInput`'s logic later
def _get_request_batch_size(request):
if (text := request.get("text")) is not None:
return None if isinstance(text, str) else len(text)
if (input_ids := request.get("input_ids")) is not None:
return None if isinstance(input_ids[0], int) else len(input_ids)
return None
@app.get("/v1/models")
async def get_models():
prefill_server = lb.prefill_urls[0] # Get the first prefill server
async with aiohttp.ClientSession() as session:
try:
response = await session.get(f"{prefill_server}/v1/models")
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=f"Prefill server error: Status {response.status}",
)
return ORJSONResponse(content=await response.json())
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -0,0 +1,320 @@
from typing import Optional
from sglang_router.router_args import RouterArgs
from sglang_router.sglang_router_rs import (
BackendType,
HistoryBackendType,
PolicyType,
PyApiKeyEntry,
PyControlPlaneAuthConfig,
PyJwtConfig,
PyOracleConfig,
PyPostgresConfig,
PyRedisConfig,
PyRole,
)
from sglang_router.sglang_router_rs import Router as _Router
def policy_from_str(policy_str: Optional[str]) -> PolicyType:
"""Convert policy string to PolicyType enum."""
if policy_str is None:
return None
policy_map = {
"random": PolicyType.Random,
"round_robin": PolicyType.RoundRobin,
"cache_aware": PolicyType.CacheAware,
"power_of_two": PolicyType.PowerOfTwo,
"bucket": PolicyType.Bucket,
"manual": PolicyType.Manual,
"consistent_hashing": PolicyType.ConsistentHashing,
"prefix_hash": PolicyType.PrefixHash,
}
return policy_map[policy_str]
def backend_from_str(backend_str: Optional[str]) -> BackendType:
"""Convert backend string to BackendType enum."""
if isinstance(backend_str, BackendType):
return backend_str
if backend_str is None:
return BackendType.Sglang
backend_map = {"sglang": BackendType.Sglang, "openai": BackendType.Openai}
backend_lower = backend_str.lower()
if backend_lower not in backend_map:
raise ValueError(
f"Unknown backend: {backend_str}. Valid options: {', '.join(backend_map.keys())}"
)
return backend_map[backend_lower]
def history_backend_from_str(backend_str: Optional[str]) -> HistoryBackendType:
"""Convert history backend string to HistoryBackendType enum."""
if isinstance(backend_str, HistoryBackendType):
return backend_str
if backend_str is None:
return HistoryBackendType.Memory
backend_lower = backend_str.lower()
if backend_lower == "memory":
return HistoryBackendType.Memory
elif backend_lower == "none":
# Use getattr to access 'None' which is a Python keyword
return getattr(HistoryBackendType, "None")
elif backend_lower == "oracle":
return HistoryBackendType.Oracle
elif backend_lower == "postgres":
return HistoryBackendType.Postgres
elif backend_lower == "redis":
return HistoryBackendType.Redis
else:
raise ValueError(f"Unknown history backend: {backend_str}")
def role_from_str(role_str: str) -> PyRole:
"""Convert role string to PyRole enum."""
if role_str.lower() == "admin":
return PyRole.Admin
return PyRole.User
def build_control_plane_auth_config(
args_dict: dict,
) -> Optional[PyControlPlaneAuthConfig]:
"""Build control plane auth config from args dict."""
api_keys = args_dict.get("control_plane_api_keys", [])
jwt_issuer = args_dict.get("jwt_issuer")
jwt_audience = args_dict.get("jwt_audience")
audit_enabled = args_dict.get("control_plane_audit_enabled", False)
# Check if any auth is configured
has_api_keys = bool(api_keys)
has_jwt = jwt_issuer is not None and jwt_audience is not None
if not has_api_keys and not has_jwt:
return None
# Build API key entries
py_api_keys = []
for key_tuple in api_keys:
# Tuple format: (id, name, key, role)
key_id, name, key, role = key_tuple
py_api_keys.append(
PyApiKeyEntry(
id=key_id,
name=name,
key=key,
role=role_from_str(role),
)
)
# Build JWT config if present
jwt_config = None
if has_jwt:
jwt_config = PyJwtConfig(
issuer=jwt_issuer,
audience=jwt_audience,
jwks_uri=args_dict.get("jwt_jwks_uri"),
role_mapping=args_dict.get("jwt_role_mapping", {}),
)
return PyControlPlaneAuthConfig(
jwt=jwt_config,
api_keys=py_api_keys,
audit_enabled=audit_enabled,
)
class Router:
"""
A high-performance router for distributing requests across worker nodes.
Args:
worker_urls: List of URLs for worker nodes that will handle requests. Each URL should include
the protocol, host, and port (e.g., ['http://worker1:8000', 'http://worker2:8000'])
policy: Load balancing policy to use. Options:
- PolicyType.Random: Randomly select workers
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
- PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only)
host: Host address to bind the router server. Supports IPv4, IPv6 (e.g., ::, ::1), or 0.0.0.0 for all interfaces. Default: '0.0.0.0'
port: Port number to bind the router server. Default: 3001
worker_startup_timeout_secs: Timeout in seconds for worker startup and registration. Large models can take significant time to load into GPU memory. Default: 1800 (30 minutes)
worker_startup_check_interval: Interval in seconds between checks for worker initialization. Default: 10
cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
if the match rate exceeds threshold, otherwise routes to the worker with the smallest
tree. Default: 0.5
balance_abs_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 32
balance_rel_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60
max_payload_size: Maximum payload size in bytes. Default: 256MB
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
dp_aware: Enable data parallelism aware schedule. Default: False
enable_igw: Enable IGW (Inference-Gateway) mode for multi-model support. When enabled,
the router can manage multiple models simultaneously with per-model load balancing
policies. Default: False
api_key: The api key used for the authorization with the worker.
Useful when the dp aware scheduling strategy is enabled.
Default: None
log_dir: Directory to store log files. If None, logs are only output to console. Default: None
log_level: Logging level. Options: 'debug', 'info', 'warn', 'error'.
service_discovery: Enable Kubernetes service discovery. When enabled, the router will
automatically discover worker pods based on the selector. Default: False
selector: Dictionary mapping of label keys to values for Kubernetes pod selection.
Example: {"app": "sglang-worker"}. Default: {}
service_discovery_port: Port to use for service discovery. The router will generate
worker URLs using this port. Default: 80
service_discovery_namespace: Kubernetes namespace to watch for pods. If not provided,
watches pods across all namespaces (requires cluster-wide permissions). Default: None
prefill_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
for prefill servers (PD mode only). Default: {}
decode_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
for decode servers (PD mode only). Default: {}
prometheus_port: Port to expose Prometheus metrics. Default: None
prometheus_host: Host address to bind the Prometheus metrics server. Default: None
pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
decode_urls: List of URLs for decode servers (PD mode only)
prefill_policy: Specific load balancing policy for prefill nodes (PD mode only).
If not specified, uses the main policy. Default: None
decode_policy: Specific load balancing policy for decode nodes (PD mode only).
If not specified, uses the main policy. Default: None
request_id_headers: List of HTTP headers to check for request IDs. If not specified,
uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id'].
Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None
bootstrap_port_annotation: Kubernetes annotation name for bootstrap port (PD mode).
Default: 'sglang.ai/bootstrap-port'
request_timeout_secs: Request timeout in seconds. Default: 600
max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 256
queue_size: Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately). Default: 100
queue_timeout_secs: Maximum time (in seconds) a request can wait in queue before timing out. Default: 60
rate_limit_tokens_per_second: Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests. Default: None
cors_allowed_origins: List of allowed origins for CORS. Empty list allows all origins. Default: []
health_failure_threshold: Number of consecutive health check failures before marking worker unhealthy. Default: 3
health_success_threshold: Number of consecutive health check successes before marking worker healthy. Default: 2
health_check_timeout_secs: Timeout in seconds for health check requests. Default: 5
health_check_interval_secs: Interval in seconds between runtime health checks. Default: 60
health_check_endpoint: Health check endpoint path. Default: '/health'
model_path: Model path for loading tokenizer (HuggingFace model ID or local path). Default: None
tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
server_cert_path: Path to server TLS certificate (PEM format). Default: None
server_key_path: Path to server TLS private key (PEM format). Default: None
"""
def __init__(self, router: _Router):
self._router = router
@staticmethod
def from_args(args: RouterArgs) -> "Router":
"""Create a router from a RouterArgs instance."""
args_dict = vars(args)
# Convert RouterArgs to _Router parameters
args_dict["worker_urls"] = (
[]
if args_dict["service_discovery"] or args_dict["pd_disaggregation"]
else args_dict["worker_urls"]
)
args_dict["policy"] = policy_from_str(args_dict["policy"])
args_dict["prefill_urls"] = (
args_dict["prefill_urls"] if args_dict["pd_disaggregation"] else None
)
args_dict["decode_urls"] = (
args_dict["decode_urls"] if args_dict["pd_disaggregation"] else None
)
args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])
# Convert backend
args_dict["backend"] = backend_from_str(args_dict.get("backend"))
# Convert history_backend to enum first
history_backend_raw = args_dict.get("history_backend", "memory")
history_backend = history_backend_from_str(history_backend_raw)
# Convert Oracle config if needed
oracle_config = None
if history_backend == HistoryBackendType.Oracle:
# Prioritize TNS alias over connect descriptor
tns_alias = args_dict.get("oracle_tns_alias")
connect_descriptor = args_dict.get("oracle_connect_descriptor")
# Use TNS alias if provided, otherwise use connect descriptor
final_descriptor = tns_alias if tns_alias else connect_descriptor
oracle_config = PyOracleConfig(
password=args_dict.get("oracle_password"),
username=args_dict.get("oracle_username"),
connect_descriptor=final_descriptor,
wallet_path=args_dict.get("oracle_wallet_path"),
pool_min=args_dict.get("oracle_pool_min", 1),
pool_max=args_dict.get("oracle_pool_max", 16),
pool_timeout_secs=args_dict.get("oracle_pool_timeout_secs", 30),
)
args_dict["oracle_config"] = oracle_config
args_dict["history_backend"] = history_backend
# Convert Postgres config if needed
postgres_config = None
if history_backend == HistoryBackendType.Postgres:
postgres_config = PyPostgresConfig(
db_url=args_dict.get("postgres_db_url"),
pool_max=args_dict.get("postgres_pool_max", 16),
)
args_dict["postgres_config"] = postgres_config
# Convert Redis config if needed
redis_config = None
if history_backend == HistoryBackendType.Redis:
retention_days = args_dict.get("redis_retention_days", 30)
# If retention_days is negative, it means persistent storage (None in Rust)
retention_arg = None if retention_days < 0 else retention_days
redis_config = PyRedisConfig(
url=args_dict.get("redis_url"),
pool_max=args_dict.get("redis_pool_max", 16),
retention_days=retention_arg,
)
args_dict["redis_config"] = redis_config
# Build control plane auth config
args_dict["control_plane_auth"] = build_control_plane_auth_config(args_dict)
# Remove fields that shouldn't be passed to Rust Router constructor
fields_to_remove = [
"mini_lb",
"test_external_dp_routing",
"oracle_wallet_path",
"oracle_tns_alias",
"oracle_connect_descriptor",
"oracle_username",
"oracle_password",
"oracle_pool_min",
"oracle_pool_max",
"oracle_pool_timeout_secs",
"postgres_db_url",
"postgres_pool_max",
"redis_url",
"redis_pool_max",
"redis_retention_days",
# Control plane auth fields (converted to control_plane_auth)
"control_plane_api_keys",
"control_plane_audit_enabled",
"jwt_issuer",
"jwt_audience",
"jwt_jwks_uri",
"jwt_role_mapping",
]
for field in fields_to_remove:
args_dict.pop(field, None)
return Router(_Router(**args_dict))
def start(self) -> None:
"""Start the router server.
This method blocks until the server is shut down.
"""
self._router.start()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
__version__ = "0.3.2"

View File

@@ -0,0 +1,14 @@
"""
Pytest configuration for sglang_router Python binding tests.
These are unit tests that run without GPU resources or external dependencies.
"""
import pytest
def pytest_configure(config):
"""Configure pytest markers."""
config.addinivalue_line(
"markers", "unit: mark test as a unit test (no GPU required)"
)

View File

@@ -0,0 +1,637 @@
"""
Unit tests for argument parsing functionality in sglang_router.
These tests focus on testing the argument parsing logic in isolation,
without starting actual router instances.
"""
from types import SimpleNamespace
import pytest
from sglang_router.launch_router import RouterArgs, parse_router_args
from sglang_router.router import policy_from_str
class TestRouterArgs:
"""Test RouterArgs dataclass and its methods."""
def test_default_values(self):
"""Test that RouterArgs has correct default values."""
args = RouterArgs()
# Test basic defaults
assert args.host == "0.0.0.0"
assert args.port == 30000
assert args.policy == "cache_aware"
assert args.worker_urls == []
assert args.pd_disaggregation is False
assert args.prefill_urls == []
assert args.decode_urls == []
# Test PD-specific defaults
assert args.prefill_policy is None
assert args.decode_policy is None
# Test service discovery defaults
assert args.service_discovery is False
assert args.selector == {}
assert args.service_discovery_port == 80
assert args.service_discovery_namespace is None
# Test retry and circuit breaker defaults
assert args.retry_max_retries == 5
assert args.cb_failure_threshold == 10
assert args.disable_retries is False
assert args.disable_circuit_breaker is False
def test_parse_selector_valid(self):
"""Test parsing valid selector arguments."""
# Test single key-value pair
result = RouterArgs._parse_selector(["app=worker"])
assert result == {"app": "worker"}
# Test multiple key-value pairs
result = RouterArgs._parse_selector(["app=worker", "env=prod", "version=v1"])
assert result == {"app": "worker", "env": "prod", "version": "v1"}
# Test empty list
result = RouterArgs._parse_selector([])
assert result == {}
# Test None
result = RouterArgs._parse_selector(None)
assert result == {}
def test_parse_selector_invalid(self):
"""Test parsing invalid selector arguments."""
# Test malformed selector (no equals sign)
result = RouterArgs._parse_selector(["app"])
assert result == {}
# Test multiple equals signs (should use first one)
result = RouterArgs._parse_selector(["app=worker=extra"])
assert result == {"app": "worker=extra"}
def test_parse_prefill_urls_valid(self):
"""Test parsing valid prefill URL arguments."""
# Test with bootstrap port
result = RouterArgs._parse_prefill_urls([["http://prefill1:8000", "9000"]])
assert result == [("http://prefill1:8000", 9000)]
# Test with 'none' bootstrap port
result = RouterArgs._parse_prefill_urls([["http://prefill1:8000", "none"]])
assert result == [("http://prefill1:8000", None)]
# Test without bootstrap port
result = RouterArgs._parse_prefill_urls([["http://prefill1:8000"]])
assert result == [("http://prefill1:8000", None)]
# Test multiple prefill URLs
result = RouterArgs._parse_prefill_urls(
[
["http://prefill1:8000", "9000"],
["http://prefill2:8000", "none"],
["http://prefill3:8000"],
]
)
expected = [
("http://prefill1:8000", 9000),
("http://prefill2:8000", None),
("http://prefill3:8000", None),
]
assert result == expected
# Test empty list
result = RouterArgs._parse_prefill_urls([])
assert result == []
# Test None
result = RouterArgs._parse_prefill_urls(None)
assert result == []
def test_parse_prefill_urls_invalid(self):
"""Test parsing invalid prefill URL arguments."""
# Test invalid bootstrap port
with pytest.raises(ValueError, match="Invalid bootstrap port"):
RouterArgs._parse_prefill_urls([["http://prefill1:8000", "invalid"]])
def test_parse_decode_urls_valid(self):
"""Test parsing valid decode URL arguments."""
# Test single decode URL
result = RouterArgs._parse_decode_urls([["http://decode1:8001"]])
assert result == ["http://decode1:8001"]
# Test multiple decode URLs
result = RouterArgs._parse_decode_urls(
[["http://decode1:8001"], ["http://decode2:8001"]]
)
assert result == ["http://decode1:8001", "http://decode2:8001"]
# Test empty list
result = RouterArgs._parse_decode_urls([])
assert result == []
# Test None
result = RouterArgs._parse_decode_urls(None)
assert result == []
def test_from_cli_args_basic(self):
"""Test creating RouterArgs from basic CLI arguments."""
args = SimpleNamespace(
host="0.0.0.0",
port=30001,
worker_urls=["http://worker1:8000", "http://worker2:8000"],
policy="round_robin",
prefill=None,
decode=None,
router_policy="round_robin",
router_pd_disaggregation=False,
router_prefill_policy=None,
router_decode_policy=None,
router_worker_startup_timeout_secs=300,
router_worker_startup_check_interval=15,
router_cache_threshold=0.7,
router_balance_abs_threshold=128,
router_balance_rel_threshold=2.0,
router_eviction_interval=180,
router_max_tree_size=2**28,
router_max_payload_size=1024 * 1024 * 1024, # 1GB
router_dp_aware=True,
router_api_key="test-key",
router_log_dir="/tmp/logs",
router_log_level="debug",
router_service_discovery=True,
router_selector=["app=worker", "env=test"],
router_service_discovery_port=8080,
router_service_discovery_namespace="default",
router_prefill_selector=["app=prefill"],
router_decode_selector=["app=decode"],
router_prometheus_port=29000,
router_prometheus_host="0.0.0.0",
router_request_id_headers=["x-request-id", "x-trace-id"],
router_request_timeout_secs=1200,
router_max_concurrent_requests=512,
router_queue_size=200,
router_queue_timeout_secs=120,
router_rate_limit_tokens_per_second=100,
router_cors_allowed_origins=["http://localhost:3000"],
router_retry_max_retries=3,
router_retry_initial_backoff_ms=100,
router_retry_max_backoff_ms=10000,
router_retry_backoff_multiplier=2.0,
router_retry_jitter_factor=0.1,
router_cb_failure_threshold=5,
router_cb_success_threshold=2,
router_cb_timeout_duration_secs=30,
router_cb_window_duration_secs=60,
router_disable_retries=False,
router_disable_circuit_breaker=False,
router_health_failure_threshold=2,
router_health_success_threshold=1,
router_health_check_timeout_secs=3,
router_health_check_interval_secs=30,
router_health_check_endpoint="/healthz",
)
router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
# Test basic configuration
assert router_args.host == "0.0.0.0"
assert router_args.port == 30001
assert router_args.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
assert router_args.policy == "round_robin"
# Test PD configuration
assert router_args.pd_disaggregation is False
assert router_args.prefill_urls == []
assert router_args.decode_urls == []
# Test service discovery
assert router_args.service_discovery is True
assert router_args.selector == {"app": "worker", "env": "test"}
assert router_args.service_discovery_port == 8080
assert router_args.service_discovery_namespace == "default"
assert router_args.prefill_selector == {"app": "prefill"}
assert router_args.decode_selector == {"app": "decode"}
# Test other configurations
assert router_args.dp_aware is True
assert router_args.api_key == "test-key"
assert router_args.log_dir == "/tmp/logs"
assert router_args.log_level == "debug"
assert router_args.prometheus_port == 29000
assert router_args.prometheus_host == "0.0.0.0"
assert router_args.request_id_headers == ["x-request-id", "x-trace-id"]
assert router_args.request_timeout_secs == 1200
assert router_args.max_concurrent_requests == 512
assert router_args.queue_size == 200
assert router_args.queue_timeout_secs == 120
assert router_args.rate_limit_tokens_per_second == 100
assert router_args.cors_allowed_origins == ["http://localhost:3000"]
# Test retry configuration
assert router_args.retry_max_retries == 3
assert router_args.retry_initial_backoff_ms == 100
assert router_args.retry_max_backoff_ms == 10000
assert router_args.retry_backoff_multiplier == 2.0
assert router_args.retry_jitter_factor == 0.1
# Test circuit breaker configuration
assert router_args.cb_failure_threshold == 5
assert router_args.cb_success_threshold == 2
assert router_args.cb_timeout_duration_secs == 30
assert router_args.cb_window_duration_secs == 60
assert router_args.disable_retries is False
assert router_args.disable_circuit_breaker is False
# Test health check configuration
assert router_args.health_failure_threshold == 2
assert router_args.health_success_threshold == 1
assert router_args.health_check_timeout_secs == 3
assert router_args.health_check_interval_secs == 30
assert router_args.health_check_endpoint == "/healthz"
# Note: model_path and tokenizer_path are not available in current RouterArgs
def test_from_cli_args_pd_mode(self):
"""Test creating RouterArgs from CLI arguments in PD mode."""
args = SimpleNamespace(
host="127.0.0.1",
port=30000,
worker_urls=[],
policy="cache_aware",
prefill=[
["http://prefill1:8000", "9000"],
["http://prefill2:8000", "none"],
],
decode=[["http://decode1:8001"], ["http://decode2:8001"]],
router_prefill=[
["http://prefill1:8000", "9000"],
["http://prefill2:8000", "none"],
],
router_decode=[["http://decode1:8001"], ["http://decode2:8001"]],
router_policy="cache_aware",
router_pd_disaggregation=True,
router_prefill_policy="power_of_two",
router_decode_policy="round_robin",
# Include all required fields with defaults
router_worker_startup_timeout_secs=600,
router_worker_startup_check_interval=30,
router_cache_threshold=0.3,
router_balance_abs_threshold=64,
router_balance_rel_threshold=1.5,
router_eviction_interval=120,
router_max_tree_size=2**26,
router_max_payload_size=512 * 1024 * 1024,
router_dp_aware=False,
router_api_key=None,
router_log_dir=None,
router_log_level=None,
router_service_discovery=False,
router_selector=None,
router_service_discovery_port=80,
router_service_discovery_namespace=None,
router_prefill_selector=None,
router_decode_selector=None,
router_prometheus_port=None,
router_prometheus_host=None,
router_request_id_headers=None,
router_request_timeout_secs=1800,
router_max_concurrent_requests=256,
router_queue_size=100,
router_queue_timeout_secs=60,
router_rate_limit_tokens_per_second=None,
router_cors_allowed_origins=[],
router_retry_max_retries=5,
router_retry_initial_backoff_ms=50,
router_retry_max_backoff_ms=30000,
router_retry_backoff_multiplier=1.5,
router_retry_jitter_factor=0.2,
router_cb_failure_threshold=10,
router_cb_success_threshold=3,
router_cb_timeout_duration_secs=60,
router_cb_window_duration_secs=120,
router_disable_retries=False,
router_disable_circuit_breaker=False,
router_health_failure_threshold=3,
router_health_success_threshold=2,
router_health_check_timeout_secs=5,
router_health_check_interval_secs=60,
router_health_check_endpoint="/health",
)
router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
# Test PD configuration
assert router_args.pd_disaggregation is True
assert router_args.prefill_urls == [
("http://prefill1:8000", 9000),
("http://prefill2:8000", None),
]
assert router_args.decode_urls == ["http://decode1:8001", "http://decode2:8001"]
assert router_args.prefill_policy == "power_of_two"
assert router_args.decode_policy == "round_robin"
assert router_args.policy == "cache_aware" # Main policy still set
def test_from_cli_args_without_prefix(self):
"""Test creating RouterArgs from CLI arguments without router prefix."""
args = SimpleNamespace(
host="127.0.0.1",
port=30000,
worker_urls=["http://worker1:8000"],
policy="random",
prefill=None,
decode=None,
pd_disaggregation=False,
prefill_policy=None,
decode_policy=None,
worker_startup_timeout_secs=600,
worker_startup_check_interval=30,
cache_threshold=0.3,
balance_abs_threshold=64,
balance_rel_threshold=1.5,
eviction_interval=120,
max_tree_size=2**26,
max_payload_size=512 * 1024 * 1024,
dp_aware=False,
api_key=None,
log_dir=None,
log_level=None,
service_discovery=False,
selector=None,
service_discovery_port=80,
service_discovery_namespace=None,
prefill_selector=None,
decode_selector=None,
prometheus_port=None,
prometheus_host=None,
request_id_headers=None,
request_timeout_secs=1800,
max_concurrent_requests=256,
queue_size=100,
queue_timeout_secs=60,
rate_limit_tokens_per_second=None,
cors_allowed_origins=[],
retry_max_retries=5,
retry_initial_backoff_ms=50,
retry_max_backoff_ms=30000,
retry_backoff_multiplier=1.5,
retry_jitter_factor=0.2,
cb_failure_threshold=10,
cb_success_threshold=3,
cb_timeout_duration_secs=60,
cb_window_duration_secs=120,
disable_retries=False,
disable_circuit_breaker=False,
health_failure_threshold=3,
health_success_threshold=2,
health_check_timeout_secs=5,
health_check_interval_secs=60,
health_check_endpoint="/health",
model_path=None,
tokenizer_path=None,
)
router_args = RouterArgs.from_cli_args(args, use_router_prefix=False)
assert router_args.host == "127.0.0.1"
assert router_args.port == 30000
assert router_args.worker_urls == ["http://worker1:8000"]
assert router_args.policy == "random"
assert router_args.pd_disaggregation is False
class TestPolicyFromStr:
"""Test policy string to enum conversion."""
def test_valid_policies(self):
"""Test conversion of valid policy strings."""
from sglang_router.sglang_router_rs import PolicyType
assert policy_from_str("random") == PolicyType.Random
assert policy_from_str("round_robin") == PolicyType.RoundRobin
assert policy_from_str("cache_aware") == PolicyType.CacheAware
assert policy_from_str("power_of_two") == PolicyType.PowerOfTwo
def test_invalid_policy(self):
"""Test conversion of invalid policy string."""
with pytest.raises(KeyError):
policy_from_str("invalid_policy")
class TestParseRouterArgs:
"""Test the parse_router_args function."""
def test_parse_basic_args(self):
"""Test parsing basic router arguments."""
args = [
"--host",
"0.0.0.0",
"--port",
"30001",
"--worker-urls",
"http://worker1:8000",
"http://worker2:8000",
"--policy",
"round_robin",
]
router_args = parse_router_args(args)
assert router_args.host == "0.0.0.0"
assert router_args.port == 30001
assert router_args.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
assert router_args.policy == "round_robin"
def test_parse_pd_args(self):
"""Test parsing PD disaggregated mode arguments."""
args = [
"--pd-disaggregation",
"--prefill",
"http://prefill1:8000",
"9000",
"--prefill",
"http://prefill2:8000",
"none",
"--decode",
"http://decode1:8001",
"--decode",
"http://decode2:8001",
"--prefill-policy",
"power_of_two",
"--decode-policy",
"round_robin",
]
router_args = parse_router_args(args)
assert router_args.pd_disaggregation is True
assert router_args.prefill_urls == [
("http://prefill1:8000", 9000),
("http://prefill2:8000", None),
]
assert router_args.decode_urls == ["http://decode1:8001", "http://decode2:8001"]
assert router_args.prefill_policy == "power_of_two"
assert router_args.decode_policy == "round_robin"
def test_parse_service_discovery_args(self):
"""Test parsing service discovery arguments."""
args_a = [
"--service-discovery",
"--selector",
"app=worker",
"env=prod",
"--service-discovery-port",
"8080",
"--service-discovery-namespace",
"default",
]
args_b = [
"--service-discovery",
"--selector",
# OME has this style
"app=worker env=prod",
"--service-discovery-port",
"8080",
"--service-discovery-namespace",
"default",
]
for args in [args_a, args_b]:
router_args = parse_router_args(args)
assert router_args.service_discovery is True
assert router_args.selector == {"app": "worker", "env": "prod"}
assert router_args.service_discovery_port == 8080
assert router_args.service_discovery_namespace == "default"
def test_parse_retry_and_circuit_breaker_args(self):
"""Test parsing retry and circuit breaker arguments."""
args = [
"--retry-max-retries",
"3",
"--retry-initial-backoff-ms",
"100",
"--retry-max-backoff-ms",
"10000",
"--retry-backoff-multiplier",
"2.0",
"--retry-jitter-factor",
"0.1",
"--disable-retries",
"--cb-failure-threshold",
"5",
"--cb-success-threshold",
"2",
"--cb-timeout-duration-secs",
"30",
"--cb-window-duration-secs",
"60",
"--disable-circuit-breaker",
]
router_args = parse_router_args(args)
# Test retry configuration
assert router_args.retry_max_retries == 3
assert router_args.retry_initial_backoff_ms == 100
assert router_args.retry_max_backoff_ms == 10000
assert router_args.retry_backoff_multiplier == 2.0
assert router_args.retry_jitter_factor == 0.1
assert router_args.disable_retries is True
# Test circuit breaker configuration
assert router_args.cb_failure_threshold == 5
assert router_args.cb_success_threshold == 2
assert router_args.cb_timeout_duration_secs == 30
assert router_args.cb_window_duration_secs == 60
assert router_args.disable_circuit_breaker is True
def test_parse_rate_limiting_args(self):
"""Test parsing rate limiting arguments."""
args = [
"--max-concurrent-requests",
"512",
"--queue-size",
"200",
"--queue-timeout-secs",
"120",
"--rate-limit-tokens-per-second",
"100",
]
router_args = parse_router_args(args)
assert router_args.max_concurrent_requests == 512
assert router_args.queue_size == 200
assert router_args.queue_timeout_secs == 120
assert router_args.rate_limit_tokens_per_second == 100
def test_parse_health_check_args(self):
"""Test parsing health check arguments."""
args = [
"--health-failure-threshold",
"2",
"--health-success-threshold",
"1",
"--health-check-timeout-secs",
"3",
"--health-check-interval-secs",
"30",
"--health-check-endpoint",
"/healthz",
]
router_args = parse_router_args(args)
assert router_args.health_failure_threshold == 2
assert router_args.health_success_threshold == 1
assert router_args.health_check_timeout_secs == 3
assert router_args.health_check_interval_secs == 30
assert router_args.health_check_endpoint == "/healthz"
def test_parse_cors_args(self):
"""Test parsing CORS arguments."""
args = [
"--cors-allowed-origins",
"http://localhost:3000",
"https://example.com",
]
router_args = parse_router_args(args)
assert router_args.cors_allowed_origins == [
"http://localhost:3000",
"https://example.com",
]
def test_parse_tokenizer_args(self):
"""Test parsing tokenizer arguments."""
# Note: model-path and tokenizer-path arguments are not available in current implementation
# This test is skipped until those arguments are added
pytest.skip("Tokenizer arguments not available in current implementation")
def test_parse_invalid_args(self):
"""Test parsing invalid arguments."""
# Test invalid policy
with pytest.raises(SystemExit):
parse_router_args(["--policy", "invalid_policy"])
# Test invalid bootstrap port
with pytest.raises(ValueError, match="Invalid bootstrap port"):
parse_router_args(
[
"--pd-disaggregation",
"--prefill",
"http://prefill1:8000",
"invalid_port",
]
)
def test_help_output(self):
"""Test that help output is generated correctly."""
with pytest.raises(SystemExit) as exc_info:
parse_router_args(["--help"])
# SystemExit with code 0 indicates help was displayed
assert exc_info.value.code == 0

View File

@@ -0,0 +1,423 @@
"""
Unit tests for router configuration validation and setup.
These tests focus on testing the router configuration logic in isolation,
including validation of configuration parameters and their interactions.
"""
from unittest.mock import MagicMock, patch
import pytest
from sglang_router.launch_router import RouterArgs, launch_router
from sglang_router.router import policy_from_str
from sglang_router.sglang_router_rs import PolicyType
class TestRouterConfigValidation:
"""Test router configuration validation logic."""
def test_valid_basic_config(self):
"""Test that a valid basic configuration passes validation."""
args = RouterArgs(
host="127.0.0.1",
port=30000,
worker_urls=["http://worker1:8000", "http://worker2:8000"],
policy="cache_aware",
)
# Should not raise any exceptions
assert args.host == "127.0.0.1"
assert args.port == 30000
assert args.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
assert args.policy == "cache_aware"
def test_valid_pd_config(self):
"""Test that a valid PD configuration passes validation."""
args = RouterArgs(
host="127.0.0.1",
port=30000,
pd_disaggregation=True,
prefill_urls=[
("http://prefill1:8000", 9000),
("http://prefill2:8000", None),
],
decode_urls=["http://decode1:8001", "http://decode2:8001"],
policy="cache_aware",
)
assert args.pd_disaggregation is True
assert args.prefill_urls == [
("http://prefill1:8000", 9000),
("http://prefill2:8000", None),
]
assert args.decode_urls == ["http://decode1:8001", "http://decode2:8001"]
assert args.policy == "cache_aware"
def test_pd_config_without_urls_allowed(self):
"""Test that PD mode without URLs is now allowed (URLs are optional)."""
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[],
decode_urls=[],
service_discovery=False,
)
# Should not raise validation error - URLs are now optional
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
# This should succeed without raising an error
launch_router(args)
router_mod.from_args.assert_called_once()
def test_pd_config_with_service_discovery_allows_empty_urls(self):
"""Test that PD mode with service discovery allows empty URLs."""
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[],
decode_urls=[],
service_discovery=True,
)
# Should not raise validation error when service discovery is enabled
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_regular_mode_without_workers_allows_empty_urls(self):
"""Test that regular mode allows empty worker URLs."""
args = RouterArgs(worker_urls=[], service_discovery=False)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_cache_threshold_validation(self):
"""Test cache threshold validation."""
# Valid cache threshold
args = RouterArgs(cache_threshold=0.5)
assert args.cache_threshold == 0.5
# Edge cases
args = RouterArgs(cache_threshold=0.0)
assert args.cache_threshold == 0.0
args = RouterArgs(cache_threshold=1.0)
assert args.cache_threshold == 1.0
def test_balance_threshold_validation(self):
"""Test load balancing threshold validation."""
# Valid thresholds
args = RouterArgs(balance_abs_threshold=64, balance_rel_threshold=1.5)
assert args.balance_abs_threshold == 64
assert args.balance_rel_threshold == 1.5
# Edge cases
args = RouterArgs(balance_abs_threshold=0, balance_rel_threshold=1.0)
assert args.balance_abs_threshold == 0
assert args.balance_rel_threshold == 1.0
def test_timeout_validation(self):
"""Test timeout parameter validation."""
# Valid timeouts
args = RouterArgs(
worker_startup_timeout_secs=600,
worker_startup_check_interval=30,
request_timeout_secs=1800,
queue_timeout_secs=60,
)
assert args.worker_startup_timeout_secs == 600
assert args.worker_startup_check_interval == 30
assert args.request_timeout_secs == 1800
assert args.queue_timeout_secs == 60
def test_retry_config_validation(self):
"""Test retry configuration validation."""
# Valid retry config
args = RouterArgs(
retry_max_retries=5,
retry_initial_backoff_ms=50,
retry_max_backoff_ms=30000,
retry_backoff_multiplier=1.5,
retry_jitter_factor=0.2,
disable_retries=False,
)
assert args.retry_max_retries == 5
assert args.retry_initial_backoff_ms == 50
assert args.retry_max_backoff_ms == 30000
assert args.retry_backoff_multiplier == 1.5
assert args.retry_jitter_factor == 0.2
assert args.disable_retries is False
def test_circuit_breaker_config_validation(self):
"""Test circuit breaker configuration validation."""
# Valid circuit breaker config
args = RouterArgs(
cb_failure_threshold=10,
cb_success_threshold=3,
cb_timeout_duration_secs=60,
cb_window_duration_secs=120,
disable_circuit_breaker=False,
)
assert args.cb_failure_threshold == 10
assert args.cb_success_threshold == 3
assert args.cb_timeout_duration_secs == 60
assert args.cb_window_duration_secs == 120
assert args.disable_circuit_breaker is False
def test_health_check_config_validation(self):
"""Test health check configuration validation."""
# Valid health check config
args = RouterArgs(
health_failure_threshold=3,
health_success_threshold=2,
health_check_timeout_secs=5,
health_check_interval_secs=60,
health_check_endpoint="/health",
)
assert args.health_failure_threshold == 3
assert args.health_success_threshold == 2
assert args.health_check_timeout_secs == 5
assert args.health_check_interval_secs == 60
assert args.health_check_endpoint == "/health"
def test_rate_limiting_config_validation(self):
"""Test rate limiting configuration validation."""
# Valid rate limiting config
args = RouterArgs(
max_concurrent_requests=256,
queue_size=100,
queue_timeout_secs=60,
rate_limit_tokens_per_second=100,
)
assert args.max_concurrent_requests == 256
assert args.queue_size == 100
assert args.queue_timeout_secs == 60
assert args.rate_limit_tokens_per_second == 100
def test_service_discovery_config_validation(self):
"""Test service discovery configuration validation."""
# Valid service discovery config
args = RouterArgs(
service_discovery=True,
selector={"app": "worker", "env": "prod"},
service_discovery_port=8080,
service_discovery_namespace="default",
)
assert args.service_discovery is True
assert args.selector == {"app": "worker", "env": "prod"}
assert args.service_discovery_port == 8080
assert args.service_discovery_namespace == "default"
def test_pd_service_discovery_config_validation(self):
"""Test PD service discovery configuration validation."""
# Valid PD service discovery config
args = RouterArgs(
pd_disaggregation=True,
service_discovery=True,
prefill_selector={"app": "prefill"},
decode_selector={"app": "decode"},
bootstrap_port_annotation="sglang.ai/bootstrap-port",
)
assert args.pd_disaggregation is True
assert args.service_discovery is True
assert args.prefill_selector == {"app": "prefill"}
assert args.decode_selector == {"app": "decode"}
assert args.bootstrap_port_annotation == "sglang.ai/bootstrap-port"
def test_prometheus_config_validation(self):
"""Test Prometheus configuration validation."""
# Valid Prometheus config
args = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1")
assert args.prometheus_port == 29000
assert args.prometheus_host == "127.0.0.1"
def test_cors_config_validation(self):
"""Test CORS configuration validation."""
# Valid CORS config
args = RouterArgs(
cors_allowed_origins=["http://localhost:3000", "https://example.com"]
)
assert args.cors_allowed_origins == [
"http://localhost:3000",
"https://example.com",
]
def test_tokenizer_config_validation(self):
"""Test tokenizer configuration validation."""
# Note: model_path and tokenizer_path are not available in current RouterArgs
pytest.skip("Tokenizer configuration not available in current implementation")
def test_dp_aware_config_validation(self):
"""Test data parallelism aware configuration validation."""
# Valid DP aware config
args = RouterArgs(dp_aware=True, api_key="test-api-key")
assert args.dp_aware is True
assert args.api_key == "test-api-key"
def test_request_id_headers_validation(self):
"""Test request ID headers configuration validation."""
# Valid request ID headers config
args = RouterArgs(
request_id_headers=["x-request-id", "x-trace-id", "x-correlation-id"]
)
assert args.request_id_headers == [
"x-request-id",
"x-trace-id",
"x-correlation-id",
]
def test_policy_consistency_validation(self):
"""Test policy consistency validation in PD mode."""
# Test with both prefill and decode policies specified
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[("http://prefill1:8000", None)],
decode_urls=["http://decode1:8001"],
policy="cache_aware",
prefill_policy="power_of_two",
decode_policy="round_robin",
)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_policy_fallback_validation(self):
"""Test policy fallback validation in PD mode."""
# Test with only prefill policy specified
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[("http://prefill1:8000", None)],
decode_urls=["http://decode1:8001"],
policy="cache_aware",
prefill_policy="power_of_two",
decode_policy=None,
)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_policy_enum_conversion(self):
"""Test policy string to enum conversion."""
# Test all valid policy conversions
assert policy_from_str("random") == PolicyType.Random
assert policy_from_str("round_robin") == PolicyType.RoundRobin
assert policy_from_str("cache_aware") == PolicyType.CacheAware
assert policy_from_str("power_of_two") == PolicyType.PowerOfTwo
def test_invalid_policy_enum_conversion(self):
"""Test invalid policy string to enum conversion."""
with pytest.raises(KeyError):
policy_from_str("invalid_policy")
def test_config_immutability(self):
"""Test that configuration objects are properly immutable."""
args = RouterArgs(
host="127.0.0.1", port=30000, worker_urls=["http://worker1:8000"]
)
# Test that we can't modify the configuration after creation
# (This is more of a design test - dataclasses are mutable by default)
original_host = args.host
args.host = "0.0.0.0"
assert args.host == "0.0.0.0" # Dataclasses are mutable
assert args.host != original_host
def test_config_defaults_consistency(self):
"""Test that configuration defaults are consistent."""
args1 = RouterArgs()
args2 = RouterArgs()
# Both instances should have the same defaults
assert args1.host == args2.host
assert args1.port == args2.port
assert args1.policy == args2.policy
assert args1.worker_urls == args2.worker_urls
assert args1.pd_disaggregation == args2.pd_disaggregation
def test_config_serialization(self):
"""Test that configuration can be serialized/deserialized."""
args = RouterArgs(
host="127.0.0.1",
port=30000,
worker_urls=["http://worker1:8000"],
policy="cache_aware",
cache_threshold=0.5,
)
# Test that we can access all attributes
assert hasattr(args, "host")
assert hasattr(args, "port")
assert hasattr(args, "worker_urls")
assert hasattr(args, "policy")
assert hasattr(args, "cache_threshold")
def test_config_with_none_values(self):
"""Test configuration with None values."""
args = RouterArgs(
api_key=None,
log_dir=None,
log_level=None,
prometheus_port=None,
prometheus_host=None,
request_id_headers=None,
rate_limit_tokens_per_second=None,
service_discovery_namespace=None,
)
# All None values should be preserved
assert args.api_key is None
assert args.log_dir is None
assert args.log_level is None
assert args.prometheus_port is None
assert args.prometheus_host is None
assert args.request_id_headers is None
assert args.rate_limit_tokens_per_second is None
assert args.service_discovery_namespace is None
def test_config_with_empty_lists(self):
"""Test configuration with empty lists."""
args = RouterArgs(
worker_urls=[], prefill_urls=[], decode_urls=[], cors_allowed_origins=[]
)
# All empty lists should be preserved
assert args.worker_urls == []
assert args.prefill_urls == []
assert args.decode_urls == []
assert args.cors_allowed_origins == []
def test_config_with_empty_dicts(self):
"""Test configuration with empty dictionaries."""
args = RouterArgs(selector={}, prefill_selector={}, decode_selector={})
# All empty dictionaries should be preserved
assert args.selector == {}
assert args.prefill_selector == {}
assert args.decode_selector == {}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,509 @@
"""
Unit tests for validation logic in sglang_router.
These tests focus on testing the validation logic in isolation,
including parameter validation, URL validation, and configuration validation.
"""
from unittest.mock import MagicMock, patch
import pytest
from sglang_router.launch_router import RouterArgs, launch_router
class TestURLValidation:
"""Test URL validation logic."""
def test_valid_worker_urls(self):
"""Test validation of valid worker URLs."""
valid_urls = [
"http://worker1:8000",
"https://worker2:8000",
"http://localhost:8000",
"http://127.0.0.1:8000",
"http://192.168.1.100:8000",
"http://worker.example.com:8000",
]
for url in valid_urls:
args = RouterArgs(worker_urls=[url])
# Should not raise any validation errors
assert url in args.worker_urls
def test_valid_prefill_urls(self):
"""Test validation of valid prefill URLs."""
valid_prefill_urls = [
("http://prefill1:8000", 9000),
("https://prefill2:8000", None),
("http://localhost:8000", 9000),
("http://127.0.0.1:8000", None),
]
for url, bootstrap_port in valid_prefill_urls:
args = RouterArgs(prefill_urls=[(url, bootstrap_port)])
# Should not raise any validation errors
assert (url, bootstrap_port) in args.prefill_urls
def test_valid_decode_urls(self):
"""Test validation of valid decode URLs."""
valid_decode_urls = [
"http://decode1:8001",
"https://decode2:8001",
"http://localhost:8001",
"http://127.0.0.1:8001",
]
for url in valid_decode_urls:
args = RouterArgs(decode_urls=[url])
# Should not raise any validation errors
assert url in args.decode_urls
def test_malformed_urls(self):
"""Test handling of malformed URLs."""
# Note: The current implementation doesn't validate URL format
# This test documents the current behavior
malformed_urls = [
"not-a-url",
"ftp://worker1:8000", # Wrong protocol
"http://", # Missing host
":8000", # Missing protocol and host
"http://worker1", # Missing port
]
for url in malformed_urls:
args = RouterArgs(worker_urls=[url])
# Currently, malformed URLs are accepted
# This might be something to improve in the future
assert url in args.worker_urls
class TestPortValidation:
"""Test port validation logic."""
def test_valid_ports(self):
"""Test validation of valid port numbers."""
valid_ports = [1, 80, 8000, 30000, 65535]
for port in valid_ports:
args = RouterArgs(port=port)
assert args.port == port
def test_invalid_ports(self):
"""Test handling of invalid port numbers."""
# Note: The current implementation doesn't validate port ranges
# This test documents the current behavior
invalid_ports = [0, -1, 65536, 70000]
for port in invalid_ports:
args = RouterArgs(port=port)
# Currently, invalid ports are accepted
# This might be something to improve in the future
assert args.port == port
def test_bootstrap_port_validation(self):
"""Test validation of bootstrap ports in PD mode."""
valid_bootstrap_ports = [1, 80, 9000, 30000, 65535, None]
for bootstrap_port in valid_bootstrap_ports:
args = RouterArgs(prefill_urls=[("http://prefill1:8000", bootstrap_port)])
assert args.prefill_urls[0][1] == bootstrap_port
class TestParameterValidation:
"""Test parameter validation logic."""
def test_cache_threshold_validation(self):
"""Test cache threshold parameter validation."""
# Valid cache thresholds
valid_thresholds = [0.0, 0.1, 0.5, 0.9, 1.0]
for threshold in valid_thresholds:
args = RouterArgs(cache_threshold=threshold)
assert args.cache_threshold == threshold
def test_balance_threshold_validation(self):
"""Test load balancing threshold parameter validation."""
# Valid absolute thresholds
valid_abs_thresholds = [0, 1, 32, 64, 128, 1000]
for threshold in valid_abs_thresholds:
args = RouterArgs(balance_abs_threshold=threshold)
assert args.balance_abs_threshold == threshold
# Valid relative thresholds
valid_rel_thresholds = [1.0, 1.1, 1.5, 2.0, 10.0]
for threshold in valid_rel_thresholds:
args = RouterArgs(balance_rel_threshold=threshold)
assert args.balance_rel_threshold == threshold
def test_timeout_validation(self):
"""Test timeout parameter validation."""
# Valid timeouts
valid_timeouts = [1, 30, 60, 300, 600, 1800, 3600]
for timeout in valid_timeouts:
args = RouterArgs(
worker_startup_timeout_secs=timeout,
worker_startup_check_interval=timeout,
request_timeout_secs=timeout,
queue_timeout_secs=timeout,
)
assert args.worker_startup_timeout_secs == timeout
assert args.worker_startup_check_interval == timeout
assert args.request_timeout_secs == timeout
assert args.queue_timeout_secs == timeout
def test_retry_parameter_validation(self):
"""Test retry parameter validation."""
# Valid retry parameters
valid_retry_counts = [0, 1, 3, 5, 10]
for count in valid_retry_counts:
args = RouterArgs(retry_max_retries=count)
assert args.retry_max_retries == count
# Valid backoff parameters
valid_backoff_ms = [1, 50, 100, 1000, 30000]
for backoff in valid_backoff_ms:
args = RouterArgs(
retry_initial_backoff_ms=backoff, retry_max_backoff_ms=backoff
)
assert args.retry_initial_backoff_ms == backoff
assert args.retry_max_backoff_ms == backoff
# Valid multiplier parameters
valid_multipliers = [1.0, 1.5, 2.0, 3.0]
for multiplier in valid_multipliers:
args = RouterArgs(retry_backoff_multiplier=multiplier)
assert args.retry_backoff_multiplier == multiplier
# Valid jitter parameters
valid_jitter = [0.0, 0.1, 0.2, 0.5]
for jitter in valid_jitter:
args = RouterArgs(retry_jitter_factor=jitter)
assert args.retry_jitter_factor == jitter
def test_circuit_breaker_parameter_validation(self):
"""Test circuit breaker parameter validation."""
# Valid failure thresholds
valid_failure_thresholds = [1, 3, 5, 10, 20]
for threshold in valid_failure_thresholds:
args = RouterArgs(cb_failure_threshold=threshold)
assert args.cb_failure_threshold == threshold
# Valid success thresholds
valid_success_thresholds = [1, 2, 3, 5]
for threshold in valid_success_thresholds:
args = RouterArgs(cb_success_threshold=threshold)
assert args.cb_success_threshold == threshold
# Valid timeout durations
valid_timeouts = [10, 30, 60, 120, 300]
for timeout in valid_timeouts:
args = RouterArgs(
cb_timeout_duration_secs=timeout, cb_window_duration_secs=timeout
)
assert args.cb_timeout_duration_secs == timeout
assert args.cb_window_duration_secs == timeout
def test_health_check_parameter_validation(self):
"""Test health check parameter validation."""
# Valid failure thresholds
valid_failure_thresholds = [1, 2, 3, 5, 10]
for threshold in valid_failure_thresholds:
args = RouterArgs(health_failure_threshold=threshold)
assert args.health_failure_threshold == threshold
# Valid success thresholds
valid_success_thresholds = [1, 2, 3, 5]
for threshold in valid_success_thresholds:
args = RouterArgs(health_success_threshold=threshold)
assert args.health_success_threshold == threshold
# Valid timeouts and intervals
valid_times = [1, 5, 10, 30, 60, 120]
for time_val in valid_times:
args = RouterArgs(
health_check_timeout_secs=time_val, health_check_interval_secs=time_val
)
assert args.health_check_timeout_secs == time_val
assert args.health_check_interval_secs == time_val
def test_rate_limiting_parameter_validation(self):
"""Test rate limiting parameter validation."""
# Valid concurrent request limits
valid_limits = [1, 10, 64, 256, 512, 1000]
for limit in valid_limits:
args = RouterArgs(max_concurrent_requests=limit)
assert args.max_concurrent_requests == limit
# Valid queue sizes
valid_queue_sizes = [0, 10, 50, 100, 500, 1000]
for size in valid_queue_sizes:
args = RouterArgs(queue_size=size)
assert args.queue_size == size
# Valid token rates
valid_rates = [1, 10, 50, 100, 500, 1000]
for rate in valid_rates:
args = RouterArgs(rate_limit_tokens_per_second=rate)
assert args.rate_limit_tokens_per_second == rate
def test_tree_size_validation(self):
"""Test tree size parameter validation."""
# Valid tree sizes (powers of 2)
valid_sizes = [2**10, 2**20, 2**24, 2**26, 2**28, 2**30]
for size in valid_sizes:
args = RouterArgs(max_tree_size=size)
assert args.max_tree_size == size
def test_payload_size_validation(self):
"""Test payload size parameter validation."""
# Valid payload sizes
valid_sizes = [
1024, # 1KB
1024 * 1024, # 1MB
10 * 1024 * 1024, # 10MB
100 * 1024 * 1024, # 100MB
512 * 1024 * 1024, # 512MB
1024 * 1024 * 1024, # 1GB
]
for size in valid_sizes:
args = RouterArgs(max_payload_size=size)
assert args.max_payload_size == size
class TestConfigurationValidation:
"""Test configuration validation logic."""
def test_pd_mode_validation(self):
"""Test PD mode configuration validation."""
# Valid PD configuration
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[("http://prefill1:8000", 9000)],
decode_urls=["http://decode1:8001"],
)
assert args.pd_disaggregation is True
assert len(args.prefill_urls) > 0
assert len(args.decode_urls) > 0
def test_service_discovery_validation(self):
"""Test service discovery configuration validation."""
# Valid service discovery configuration
args = RouterArgs(
service_discovery=True,
selector={"app": "worker", "env": "prod"},
service_discovery_port=8080,
service_discovery_namespace="default",
)
assert args.service_discovery is True
assert args.selector == {"app": "worker", "env": "prod"}
assert args.service_discovery_port == 8080
assert args.service_discovery_namespace == "default"
def test_pd_service_discovery_validation(self):
"""Test PD service discovery configuration validation."""
# Valid PD service discovery configuration
args = RouterArgs(
pd_disaggregation=True,
service_discovery=True,
prefill_selector={"app": "prefill"},
decode_selector={"app": "decode"},
)
assert args.pd_disaggregation is True
assert args.service_discovery is True
assert args.prefill_selector == {"app": "prefill"}
assert args.decode_selector == {"app": "decode"}
def test_policy_validation(self):
"""Test policy configuration validation."""
# Valid policies
valid_policies = ["random", "round_robin", "cache_aware", "power_of_two"]
for policy in valid_policies:
args = RouterArgs(policy=policy)
assert args.policy == policy
def test_pd_policy_validation(self):
"""Test PD policy configuration validation."""
# Valid PD policies
valid_policies = ["random", "round_robin", "cache_aware", "power_of_two"]
for prefill_policy in valid_policies:
for decode_policy in valid_policies:
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[("http://prefill1:8000", None)],
decode_urls=["http://decode1:8001"],
prefill_policy=prefill_policy,
decode_policy=decode_policy,
)
assert args.prefill_policy == prefill_policy
assert args.decode_policy == decode_policy
def test_cors_validation(self):
"""Test CORS configuration validation."""
# Valid CORS origins
valid_origins = [
[],
["http://localhost:3000"],
["https://example.com"],
["http://localhost:3000", "https://example.com"],
["*"], # Wildcard (if supported)
]
for origins in valid_origins:
args = RouterArgs(cors_allowed_origins=origins)
assert args.cors_allowed_origins == origins
def test_logging_validation(self):
"""Test logging configuration validation."""
# Valid log levels
valid_log_levels = ["debug", "info", "warning", "error", "critical"]
for level in valid_log_levels:
args = RouterArgs(log_level=level)
assert args.log_level == level
def test_prometheus_validation(self):
"""Test Prometheus configuration validation."""
# Valid Prometheus configuration
args = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1")
assert args.prometheus_port == 29000
assert args.prometheus_host == "127.0.0.1"
def test_tokenizer_validation(self):
"""Test tokenizer configuration validation."""
# Note: model_path and tokenizer_path are not available in current RouterArgs
pytest.skip("Tokenizer configuration not available in current implementation")
def test_request_id_headers_validation(self):
"""Test request ID headers configuration validation."""
# Valid request ID headers
valid_headers = [
["x-request-id"],
["x-request-id", "x-trace-id"],
["x-request-id", "x-trace-id", "x-correlation-id"],
["custom-header"],
]
for headers in valid_headers:
args = RouterArgs(request_id_headers=headers)
assert args.request_id_headers == headers
class TestLaunchValidation:
"""Test launch-time validation logic."""
def test_pd_mode_allows_empty_urls(self):
"""Test that PD mode now allows empty URLs (URLs are optional)."""
# PD mode without URLs is now allowed
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[],
decode_urls=[],
service_discovery=False,
)
# Should not raise validation error - URLs are now optional
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
# This should succeed without raising an error
launch_router(args)
router_mod.from_args.assert_called_once()
def test_pd_mode_with_service_discovery_allows_empty_urls(self):
"""Test that PD mode with service discovery allows empty URLs."""
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[],
decode_urls=[],
service_discovery=True,
)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_regular_mode_allows_empty_worker_urls(self):
"""Test that regular mode allows empty worker URLs."""
args = RouterArgs(worker_urls=[], service_discovery=False)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_launch_with_valid_config(self):
"""Test launching with valid configuration."""
args = RouterArgs(
host="127.0.0.1",
port=30000,
worker_urls=["http://worker1:8000"],
policy="cache_aware",
)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_launch_with_pd_config(self):
"""Test launching with valid PD configuration."""
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[("http://prefill1:8000", 9000)],
decode_urls=["http://decode1:8001"],
policy="cache_aware",
)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_launch_with_service_discovery_config(self):
"""Test launching with valid service discovery configuration."""
args = RouterArgs(
service_discovery=True,
selector={"app": "worker"},
service_discovery_port=8080,
)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()

View File

@@ -0,0 +1,108 @@
use std::process::Command;
const DEFAULT_VERSION: &str = "0.0.0";
const DEFAULT_PROJECT_NAME: &str = "sgl-model-gateway";
/// Set a compile-time environment variable with the SGL_MODEL_GATEWAY_ prefix
macro_rules! set_env {
($name:expr, $value:expr) => {
println!("cargo:rustc-env=SGL_MODEL_GATEWAY_{}={}", $name, $value);
};
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Rebuild triggers
println!("cargo:rerun-if-changed=Cargo.toml");
// Set version info environment variables
let version = read_cargo_version().unwrap_or_else(|_| DEFAULT_VERSION.to_string());
let target = std::env::var("TARGET").unwrap_or_else(|_| get_rustc_host().unwrap_or_default());
let profile = std::env::var("PROFILE").unwrap_or_default();
set_env!("PROJECT_NAME", DEFAULT_PROJECT_NAME);
set_env!("VERSION", version);
set_env!(
"BUILD_TIME",
chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC")
);
set_env!(
"BUILD_MODE",
if profile == "release" {
"release"
} else {
"debug"
}
);
set_env!("TARGET_TRIPLE", target);
set_env!(
"GIT_BRANCH",
git_branch().unwrap_or_else(|| "unknown".into())
);
set_env!(
"GIT_COMMIT",
git_commit().unwrap_or_else(|| "unknown".into())
);
set_env!(
"GIT_STATUS",
git_status().unwrap_or_else(|| "unknown".into())
);
set_env!(
"RUSTC_VERSION",
rustc_version().unwrap_or_else(|| "unknown".into())
);
set_env!(
"CARGO_VERSION",
cargo_version().unwrap_or_else(|| "unknown".into())
);
Ok(())
}
fn read_cargo_version() -> Result<String, Box<dyn std::error::Error>> {
let content = std::fs::read_to_string("Cargo.toml")?;
let toml: toml::Value = toml::from_str(&content)?;
toml.get("package")
.and_then(|p| p.get("version"))
.and_then(|v| v.as_str())
.map(String::from)
.ok_or_else(|| "Missing version in Cargo.toml".into())
}
fn run_cmd(cmd: &str, args: &[&str]) -> Option<String> {
Command::new(cmd)
.args(args)
.output()
.ok()
.filter(|o| o.status.success())
.and_then(|o| String::from_utf8(o.stdout).ok())
.map(|s| s.trim().to_string())
}
fn git_branch() -> Option<String> {
run_cmd("git", &["rev-parse", "--abbrev-ref", "HEAD"])
}
fn git_commit() -> Option<String> {
run_cmd("git", &["rev-parse", "--short", "HEAD"])
}
fn git_status() -> Option<String> {
run_cmd("git", &["status", "--porcelain"])
.map(|s| if s.is_empty() { "clean" } else { "dirty" }.into())
}
fn rustc_version() -> Option<String> {
run_cmd("rustc", &["--version"])
}
fn cargo_version() -> Option<String> {
run_cmd("cargo", &["--version"])
}
fn get_rustc_host() -> Option<String> {
run_cmd("rustc", &["-vV"])?
.lines()
.find(|l| l.starts_with("host: "))
.and_then(|l| l.strip_prefix("host: "))
.map(|s| s.trim().to_string())
}

View File

@@ -0,0 +1 @@
"""Test package root for router Python tests."""

View File

@@ -0,0 +1,222 @@
"""Benchmark-specific fixtures."""
from __future__ import annotations
import logging
import os
import shutil
import subprocess
import time
from pathlib import Path
import pytest
from infra import GPUMonitor, should_monitor_gpu, terminate_process
from .results import BenchmarkResult
logger = logging.getLogger(__name__)
def _build_command(
cli: str,
router_url: str,
model_path: str,
experiment_folder: str,
num_concurrency: int,
traffic_scenario: str,
max_requests: int,
) -> list[str]:
"""Build genai-bench command."""
return [
cli,
"benchmark",
"--api-backend",
"openai",
"--api-base",
router_url,
"--api-key",
"dummy-token",
"--api-model-name",
model_path,
"--model-tokenizer",
model_path,
"--task",
"text-to-text",
"--num-concurrency",
str(num_concurrency),
"--traffic-scenario",
traffic_scenario,
"--max-requests-per-run",
str(max_requests),
"--max-time-per-run",
"3",
"--experiment-folder-name",
experiment_folder,
"--experiment-base-dir",
str(Path.cwd()),
]
def _find_results(experiment_folder: str, timeout: int = 10) -> list[Path]:
"""Find benchmark result JSON files."""
base = Path.cwd()
folder = base / experiment_folder
if not folder.is_dir():
# Search for folder
for p in base.rglob(experiment_folder):
if p.is_dir() and p.name == experiment_folder:
folder = p
break
if not folder.is_dir():
raise AssertionError(f"Experiment folder not found: {experiment_folder}")
# Wait for JSON results
for _ in range(timeout):
files = [
p
for p in folder.rglob("*.json")
if "experiment_metadata" not in p.name and "gpu_utilization" not in p.name
]
if files:
return files
time.sleep(1)
raise AssertionError(f"No JSON results found in {folder}")
def _cleanup_procs(procs: list, drain_delay: int) -> None:
"""Terminate processes gracefully."""
if not procs:
return
if drain_delay > 0:
time.sleep(drain_delay)
for p in procs:
try:
proc = getattr(p, "proc", p) if hasattr(p, "proc") else p
if isinstance(proc, subprocess.Popen):
terminate_process(proc)
except Exception:
pass
time.sleep(2)
@pytest.fixture(scope="session")
def genai_bench_runner():
"""Run genai-bench and validate metrics.
Usage:
def test_perf(setup_backend, genai_bench_runner):
backend, model_path, client, gateway = setup_backend
genai_bench_runner(
router_url=gateway.base_url,
model_path=model_path,
experiment_folder="benchmark_results",
thresholds={"ttft_mean_max": 5, "gpu_util_p50_min": 99},
)
"""
def _run(
*,
router_url: str,
model_path: str,
experiment_folder: str,
thresholds: dict | None = None,
timeout_sec: int | None = None,
num_concurrency: int = 32,
traffic_scenario: str = "D(4000,100)",
max_requests_per_run: int | None = None,
kill_procs: list | None = None,
drain_delay_sec: int = 6,
) -> None:
cli = shutil.which("genai-bench")
if not cli:
pytest.fail("genai-bench CLI not found")
# Clean previous results
exp_dir = Path.cwd() / experiment_folder
if exp_dir.exists():
shutil.rmtree(exp_dir, ignore_errors=True)
# Build and run command
max_requests = max_requests_per_run or num_concurrency * 5
cmd = _build_command(
cli,
router_url,
model_path,
experiment_folder,
num_concurrency,
traffic_scenario,
max_requests,
)
timeout = timeout_sec or int(os.environ.get("GENAI_BENCH_TEST_TIMEOUT", "120"))
try:
proc = subprocess.Popen(
cmd,
env=os.environ.copy(),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
except FileNotFoundError:
pytest.fail(f"genai-bench executable not found at {cli}")
except PermissionError:
pytest.fail(f"Permission denied executing {cli}")
except OSError as e:
pytest.fail(f"Failed to start genai-bench: {e}")
# Start GPU monitor if needed
gpu_monitor: GPUMonitor | None = None
if should_monitor_gpu(thresholds):
interval = float(os.environ.get("GPU_UTIL_SAMPLE_INTERVAL", "2.0"))
gpu_monitor = GPUMonitor(output_dir=exp_dir, interval=interval)
gpu_monitor.start(target_pid=proc.pid)
try:
stdout, stderr = proc.communicate(timeout=timeout)
except subprocess.TimeoutExpired:
proc.kill()
stdout, stderr = proc.communicate()
logger.error("genai-bench timed out after %ds", timeout)
# Log output if process failed or for debugging
if proc.returncode != 0:
logger.error(
"genai-bench exited with code %d\nstdout:\n%s\nstderr:\n%s",
proc.returncode,
stdout or "(empty)",
stderr or "(empty)",
)
try:
# Parse and validate results
for path in _find_results(experiment_folder):
result = BenchmarkResult.from_json(path)
result.log(experiment_folder, logger)
if thresholds:
result.validate(thresholds)
# Validate GPU utilization
if gpu_monitor:
gpu_monitor.stop()
gpu_monitor.log_summary()
gpu_monitor.assert_thresholds(thresholds)
except AssertionError:
# Log genai-bench output when results not found
logger.error(
"genai-bench output (returncode=%d):\nstdout:\n%s\nstderr:\n%s",
proc.returncode,
stdout or "(empty)",
stderr or "(empty)",
)
raise
finally:
_cleanup_procs(kill_procs, drain_delay_sec)
if gpu_monitor:
gpu_monitor.stop(timeout=2)
return _run

View File

@@ -0,0 +1,98 @@
"""Benchmark result dataclasses for parsing genai-bench and GPU monitor output."""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
@dataclass
class BenchmarkResult:
"""Parsed benchmark metrics from genai-bench output."""
ttft_mean: float
e2e_latency_mean: float
input_throughput_mean: float
output_throughput_mean: float
file_name: str
@classmethod
def from_json(cls, path: Path) -> "BenchmarkResult":
"""Parse benchmark results from JSON file."""
with path.open() as f:
data = json.load(f)
stats = data.get("aggregated_metrics", {}).get("stats", {})
return cls(
ttft_mean=float(stats.get("ttft", {}).get("mean", float("inf"))),
e2e_latency_mean=float(
stats.get("e2e_latency", {}).get("mean", float("inf"))
),
input_throughput_mean=float(
stats.get("input_throughput", {}).get("mean", 0.0)
),
output_throughput_mean=float(
stats.get("output_throughput", {}).get("mean", 0.0)
),
file_name=path.name,
)
def log(self, experiment: str, logger) -> None:
"""Log benchmark results."""
logger.info(
"genai-bench[%s] %s ttft=%.3fs e2e=%.3fs input=%.1f tok/s output=%.1f tok/s",
experiment,
self.file_name,
self.ttft_mean,
self.e2e_latency_mean,
self.input_throughput_mean,
self.output_throughput_mean,
)
def validate(self, thresholds: dict) -> None:
"""Validate metrics against thresholds."""
checks = [
("ttft_mean_max", self.ttft_mean, "<=", "TTFT"),
("e2e_latency_mean_max", self.e2e_latency_mean, "<=", "E2E latency"),
(
"input_throughput_mean_min",
self.input_throughput_mean,
">=",
"Input throughput",
),
(
"output_throughput_mean_min",
self.output_throughput_mean,
">=",
"Output throughput",
),
]
for key, value, op, name in checks:
if key not in thresholds:
continue
threshold = thresholds[key]
if op == "<=" and value > threshold:
raise AssertionError(f"{name}: {value:.2f} > {threshold}")
if op == ">=" and value < threshold:
raise AssertionError(f"{name}: {value:.2f} < {threshold}")
@dataclass
class GPUUtilization:
"""Parsed GPU utilization metrics from gpu_monitor output."""
overall_mean: float
per_gpu: dict[str, dict[str, float]]
@classmethod
def from_json(cls, path: Path) -> "GPUUtilization | None":
"""Parse GPU utilization from JSON file."""
try:
with path.open() as f:
data = json.load(f)
return cls(
overall_mean=float(data.get("overall", {}).get("mean", 0)),
per_gpu=data.get("per_gpu", {}),
)
except Exception:
return None

View File

@@ -0,0 +1,119 @@
"""Generate benchmark summary for GitHub Actions."""
from __future__ import annotations
import os
import sys
from pathlib import Path
from results import BenchmarkResult, GPUUtilization
def discover_benchmarks(base_dir: Path) -> list[tuple[Path, str]]:
"""Auto-discover benchmark folders and their result JSON files.
Returns list of (json_path, label) tuples sorted by folder name.
"""
results = []
for folder in base_dir.rglob("benchmark_*"):
if not folder.is_dir():
continue
# Find result JSON (exclude metadata and gpu files)
for json_file in folder.glob("*.json"):
if (
"experiment_metadata" not in json_file.name
and "gpu_utilization" not in json_file.name
):
# Generate label from folder name: benchmark_cache_aware_pd_grpc -> cache_aware pd grpc
label = folder.name.replace("benchmark_", "").replace("_", " ")
results.append((json_file, label))
break # One JSON per folder
return sorted(results, key=lambda x: x[0].parent.name)
def find_gpu_utilization(result_path: Path) -> Path | None:
"""Find GPU utilization JSON in same folder as result."""
gpu_json = result_path.parent / "gpu_utilization.json"
return gpu_json if gpu_json.exists() else None
def generate_summary(base_dir: Path) -> str:
"""Generate markdown summary."""
benchmarks = discover_benchmarks(base_dir)
if not benchmarks:
return (
"## Gateway E2E Genai-Bench Results Summary\n\nNo benchmark results found."
)
lines = [
"## Gateway E2E Genai-Bench Results Summary",
"",
"| Scenario | Status | TTFT (s) | E2E Latency (s) | Input Throughput (tok/s) | Output Throughput (tok/s) |",
"|----------|--------|----------|-----------------|--------------------------|---------------------------|",
]
gpu_sections = []
for result_path, label in benchmarks:
try:
result = BenchmarkResult.from_json(result_path)
except Exception as e:
print(f"Warning: Failed to parse {result_path}: {e}", file=sys.stderr)
lines.append(f"| {label} | ❌ Failed | - | - | - | - |")
continue
lines.append(
f"| {label} | ✅ Success | "
f"{result.ttft_mean:.2f} | "
f"{result.e2e_latency_mean:.2f} | "
f"{result.input_throughput_mean:.0f} | "
f"{result.output_throughput_mean:.0f} |"
)
# GPU utilization
gpu_path = find_gpu_utilization(result_path)
if gpu_path:
gpu = GPUUtilization.from_json(gpu_path)
if gpu and gpu.per_gpu:
gpu_lines = [
f"### GPU Utilization — {label}",
"",
f"Overall mean: {gpu.overall_mean:.2f}%",
"",
"| GPU | Mean (%) | p5 | p10 | p25 | p50 | p75 | p90 | p95 |",
"|-----|----------|----|-----|-----|-----|-----|-----|-----|",
]
for gpu_id, stats in sorted(
gpu.per_gpu.items(), key=lambda x: int(x[0])
):
gpu_lines.append(
f"| {gpu_id} | {stats.get('mean', 0):.2f} | "
f"{stats.get('p5', 0):.2f} | {stats.get('p10', 0):.2f} | "
f"{stats.get('p25', 0):.2f} | {stats.get('p50', 0):.2f} | "
f"{stats.get('p75', 0):.2f} | {stats.get('p90', 0):.2f} | "
f"{stats.get('p95', 0):.2f} |"
)
gpu_sections.append("\n".join(gpu_lines))
return "\n".join(lines) + "\n" + "\n\n".join(gpu_sections)
def main() -> None:
"""Main entry point."""
base_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path.cwd()
summary = generate_summary(base_dir)
# Write to GITHUB_STEP_SUMMARY if available
summary_file = os.environ.get("GITHUB_STEP_SUMMARY")
if summary_file:
with open(summary_file, "a") as f:
f.write(summary)
f.write("\n")
print(f"Summary written to {summary_file}")
else:
print(summary)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,26 @@
"""PD (prefill/decode disaggregation) router performance benchmark test."""
import pytest
@pytest.mark.e2e
@pytest.mark.workers(prefill=2, decode=2)
@pytest.mark.parametrize("setup_backend", ["pd"], indirect=True)
class TestPDPerf:
"""Performance benchmark for PD disaggregation router."""
def test_pd_perf(self, setup_backend, genai_bench_runner):
"""Run genai-bench against PD router and validate metrics."""
backend, model_path, client, gateway = setup_backend
genai_bench_runner(
router_url=gateway.base_url,
model_path=model_path,
experiment_folder="benchmark_round_robin_pd",
thresholds={
"ttft_mean_max": 13,
"e2e_latency_mean_max": 16,
"input_throughput_mean_min": 350,
"output_throughput_mean_min": 18,
"gpu_util_p50_min": 99,
},
)

View File

@@ -0,0 +1,27 @@
"""Regular router performance benchmark test."""
import pytest
@pytest.mark.e2e
@pytest.mark.workers(count=4)
@pytest.mark.gateway(policy="cache_aware")
@pytest.mark.parametrize("setup_backend", ["http", "grpc"], indirect=True)
class TestRegularPerf:
"""Performance benchmark for regular (non-PD) router."""
def test_regular_perf(self, setup_backend, genai_bench_runner):
"""Run genai-bench against regular router and validate metrics."""
backend, model_path, client, gateway = setup_backend
genai_bench_runner(
router_url=gateway.base_url,
model_path=model_path,
experiment_folder=f"benchmark_cache_aware_regular_{backend}",
thresholds={
"ttft_mean_max": 6,
"e2e_latency_mean_max": 14,
"input_throughput_mean_min": 800,
"output_throughput_mean_min": 12,
"gpu_util_p50_min": 99,
},
)

View File

@@ -0,0 +1,168 @@
"""Enable Thinking E2E Tests.
Tests for chat completions with enable_thinking feature (Qwen3 reasoning).
Source: Migrated from e2e_grpc/features/test_enable_thinking.py
"""
from __future__ import annotations
import json
import logging
import pytest
import requests
logger = logging.getLogger(__name__)
# API key is not validated by the gateway, but required for OpenAI-compatible headers
API_KEY = "not-used"
# =============================================================================
# Enable Thinking Tests (Qwen 30B)
# =============================================================================
@pytest.mark.model("qwen-30b")
@pytest.mark.gateway(
extra_args=["--reasoning-parser", "qwen3", "--history-backend", "memory"]
)
@pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True)
class TestEnableThinking:
"""Tests for enable_thinking feature with Qwen3 reasoning parser."""
def test_chat_completion_with_reasoning(self, setup_backend):
"""Test non-streaming with enable_thinking=True, reasoning_content should not be empty."""
_, model, client, gateway = setup_backend
response = requests.post(
f"{gateway.base_url}/v1/chat/completions",
headers={"Authorization": f"Bearer {API_KEY}"},
json={
"model": model,
"messages": [{"role": "user", "content": "Hello"}],
"temperature": 0,
"separate_reasoning": True,
"chat_template_kwargs": {"enable_thinking": True},
},
)
assert response.status_code == 200, f"Failed with: {response.text}"
data = response.json()
assert "choices" in data
assert len(data["choices"]) > 0
assert "message" in data["choices"][0]
assert "reasoning_content" in data["choices"][0]["message"]
assert data["choices"][0]["message"]["reasoning_content"] is not None
def test_chat_completion_without_reasoning(self, setup_backend):
"""Test non-streaming with enable_thinking=False, reasoning_content should be empty."""
_, model, client, gateway = setup_backend
response = requests.post(
f"{gateway.base_url}/v1/chat/completions",
headers={"Authorization": f"Bearer {API_KEY}"},
json={
"model": model,
"messages": [{"role": "user", "content": "Hello"}],
"temperature": 0,
"separate_reasoning": True,
"chat_template_kwargs": {"enable_thinking": False},
},
)
assert response.status_code == 200, f"Failed with: {response.text}"
data = response.json()
assert "choices" in data
assert len(data["choices"]) > 0
assert "message" in data["choices"][0]
if "reasoning_content" in data["choices"][0]["message"]:
assert data["choices"][0]["message"]["reasoning_content"] is None
def test_stream_chat_completion_with_reasoning(self, setup_backend):
"""Test streaming with enable_thinking=True, reasoning_content should not be empty."""
_, model, client, gateway = setup_backend
response = requests.post(
f"{gateway.base_url}/v1/chat/completions",
headers={"Authorization": f"Bearer {API_KEY}"},
json={
"model": model,
"messages": [{"role": "user", "content": "Hello"}],
"temperature": 0,
"separate_reasoning": True,
"stream": True,
"chat_template_kwargs": {"enable_thinking": True},
},
stream=True,
)
assert response.status_code == 200, f"Failed with: {response.text}"
has_reasoning = False
has_content = False
for line in response.iter_lines():
if line:
line = line.decode("utf-8")
if line.startswith("data:") and not line.startswith("data: [DONE]"):
data = json.loads(line[6:])
if "choices" in data and len(data["choices"]) > 0:
delta = data["choices"][0].get("delta", {})
if "reasoning_content" in delta and delta["reasoning_content"]:
has_reasoning = True
if "content" in delta and delta["content"]:
has_content = True
assert (
has_reasoning
), "The reasoning content is not included in the stream response"
assert has_content, "The stream response does not contain normal content"
def test_stream_chat_completion_without_reasoning(self, setup_backend):
"""Test streaming with enable_thinking=False, reasoning_content should be empty."""
_, model, client, gateway = setup_backend
response = requests.post(
f"{gateway.base_url}/v1/chat/completions",
headers={"Authorization": f"Bearer {API_KEY}"},
json={
"model": model,
"messages": [{"role": "user", "content": "Hello"}],
"temperature": 0,
"separate_reasoning": True,
"stream": True,
"chat_template_kwargs": {"enable_thinking": False},
},
stream=True,
)
assert response.status_code == 200, f"Failed with: {response.text}"
has_reasoning = False
has_content = False
for line in response.iter_lines():
if line:
line = line.decode("utf-8")
if line.startswith("data:") and not line.startswith("data: [DONE]"):
data = json.loads(line[6:])
if "choices" in data and len(data["choices"]) > 0:
delta = data["choices"][0].get("delta", {})
if "reasoning_content" in delta and delta["reasoning_content"]:
has_reasoning = True
if "content" in delta and delta["content"]:
has_content = True
assert (
not has_reasoning
), "The reasoning content should not be included in the stream response"
assert has_content, "The stream response does not contain normal content"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,316 @@
"""Chat Completions API E2E Tests - OpenAI Server Compatibility.
Tests for OpenAI-compatible chat completions API through the gateway.
Source: Migrated from e2e_grpc/basic/test_openai_server.py
"""
from __future__ import annotations
import json
import logging
import pytest
logger = logging.getLogger(__name__)
# =============================================================================
# Chat Completion Tests (Llama 8B)
# =============================================================================
@pytest.mark.model("llama-8b")
@pytest.mark.gateway(extra_args=["--history-backend", "memory"])
@pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True)
class TestChatCompletion:
"""Tests for OpenAI-compatible chat completions API."""
@pytest.mark.parametrize("logprobs", [None, 5])
@pytest.mark.parametrize("parallel_sample_num", [1, 2])
def test_chat_completion(self, setup_backend, logprobs, parallel_sample_num):
"""Test non-streaming chat completion with logprobs and parallel sampling."""
_, model, client, gateway = setup_backend
self._run_chat_completion(client, model, logprobs, parallel_sample_num)
@pytest.mark.parametrize("logprobs", [None, 5])
@pytest.mark.parametrize("parallel_sample_num", [1, 2])
def test_chat_completion_stream(self, setup_backend, logprobs, parallel_sample_num):
"""Test streaming chat completion with logprobs and parallel sampling."""
_, model, client, gateway = setup_backend
self._run_chat_completion_stream(client, model, logprobs, parallel_sample_num)
def test_regex(self, setup_backend):
"""Test structured output with regex constraint."""
_, model, client, gateway = setup_backend
regex = (
r"""\{\n"""
+ r""" "name": "[\w]+",\n"""
+ r""" "population": [\d]+\n"""
+ r"""\}"""
)
response = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "Introduce the capital of France."},
],
temperature=0,
max_tokens=128,
extra_body={"regex": regex},
)
text = response.choices[0].message.content
try:
js_obj = json.loads(text)
except (TypeError, json.decoder.JSONDecodeError):
raise
assert isinstance(js_obj["name"], str)
assert isinstance(js_obj["population"], int)
def test_penalty(self, setup_backend):
"""Test frequency penalty parameter."""
_, model, client, gateway = setup_backend
response = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "Introduce the capital of France."},
],
temperature=0,
max_tokens=32,
frequency_penalty=1.0,
)
text = response.choices[0].message.content
assert isinstance(text, str)
def test_response_prefill(self, setup_backend):
"""Test assistant message prefill with continue_final_message."""
_, model, client, gateway = setup_backend
response = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{
"role": "user",
"content": """
Extract the name, size, price, and color from this product description as a JSON object:
<description>
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
</description>
""",
},
{
"role": "assistant",
"content": "{\n",
},
],
temperature=0,
extra_body={"continue_final_message": True},
)
assert (
response.choices[0]
.message.content.strip()
.startswith('"name": "SmartHome Mini",')
)
def test_model_list(self, setup_backend):
"""Test listing available models."""
_, model, client, gateway = setup_backend
models = list(client.models.list().data)
assert len(models) == 1
@pytest.mark.skip(
reason="Skipping retrieve model test as it is not supported by the router"
)
def test_retrieve_model(self, setup_backend):
"""Test retrieving a specific model."""
import openai
_, model, client, gateway = setup_backend
retrieved_model = client.models.retrieve(model)
assert retrieved_model.id == model
assert retrieved_model.root == model
with pytest.raises(openai.NotFoundError):
client.models.retrieve("non-existent-model")
# -------------------------------------------------------------------------
# Helper methods
# -------------------------------------------------------------------------
def _run_chat_completion(self, client, model, logprobs, parallel_sample_num):
"""Run a non-streaming chat completion and verify response."""
response = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{
"role": "user",
"content": "What is the capital of France? Answer in a few words.",
},
],
temperature=0,
logprobs=logprobs is not None and logprobs > 0,
top_logprobs=logprobs,
n=parallel_sample_num,
)
if logprobs:
assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
)
ret_num_top_logprobs = len(
response.choices[0].logprobs.content[0].top_logprobs
)
assert (
ret_num_top_logprobs == logprobs
), f"{ret_num_top_logprobs} vs {logprobs}"
assert len(response.choices) == parallel_sample_num
assert response.choices[0].message.role == "assistant"
assert isinstance(response.choices[0].message.content, str)
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def _run_chat_completion_stream(
self, client, model, logprobs, parallel_sample_num=1
):
"""Run a streaming chat completion and verify response chunks."""
generator = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "What is the capital of France?"},
],
temperature=0,
logprobs=logprobs is not None and logprobs > 0,
top_logprobs=logprobs,
stream=True,
stream_options={"include_usage": True},
n=parallel_sample_num,
)
is_firsts = {}
is_finished = {}
finish_reason_counts = {}
for response in generator:
usage = response.usage
if usage is not None:
assert usage.prompt_tokens > 0, "usage.prompt_tokens was zero"
assert usage.completion_tokens > 0, "usage.completion_tokens was zero"
assert usage.total_tokens > 0, "usage.total_tokens was zero"
continue
index = response.choices[0].index
finish_reason = response.choices[0].finish_reason
if finish_reason is not None:
is_finished[index] = True
finish_reason_counts[index] = finish_reason_counts.get(index, 0) + 1
data = response.choices[0].delta
if is_firsts.get(index, True):
assert (
data.role == "assistant"
), "data.role was not 'assistant' for first chunk"
is_firsts[index] = False
continue
if logprobs and not is_finished.get(index, False):
assert response.choices[0].logprobs, "logprobs was not returned"
assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
), "top_logprobs token was not a string"
assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs, list
), "top_logprobs was not a list"
ret_num_top_logprobs = len(
response.choices[0].logprobs.content[0].top_logprobs
)
assert (
ret_num_top_logprobs == logprobs
), f"{ret_num_top_logprobs} vs {logprobs}"
assert (
isinstance(data.content, str)
or isinstance(data.reasoning_content, str)
or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
or response.choices[0].finish_reason
)
assert response.id
assert response.created
for index in range(parallel_sample_num):
assert not is_firsts.get(
index, True
), f"index {index} is not found in the response"
for index in range(parallel_sample_num):
assert (
index in finish_reason_counts
), f"No finish_reason found for index {index}"
assert finish_reason_counts[index] == 1, (
f"Expected 1 finish_reason chunk for index {index}, "
f"got {finish_reason_counts[index]}"
)
# =============================================================================
# Chat Completion Tests (GPT-OSS)
#
# NOTE: Some tests are skipped because they don't work with OSS models:
# - test_regex: OSS models don't support regex constraints
# - test_penalty: OSS models don't support frequency_penalty
# - test_response_prefill: OSS models don't support continue_final_message
# =============================================================================
@pytest.mark.model("gpt-oss")
@pytest.mark.gateway(
extra_args=["--reasoning-parser=gpt-oss", "--history-backend", "memory"]
)
class TestChatCompletionGptOss(TestChatCompletion):
"""Tests for chat completions API with GPT-OSS model.
Inherits from TestChatCompletion and overrides tests that don't work
with OSS models.
"""
@pytest.mark.parametrize("logprobs", [None]) # No logprobs for OSS
@pytest.mark.parametrize("parallel_sample_num", [1, 2])
def test_chat_completion(self, setup_backend, logprobs, parallel_sample_num):
"""Test non-streaming chat completion with parallel sampling (no logprobs)."""
super().test_chat_completion(setup_backend, logprobs, parallel_sample_num)
@pytest.mark.parametrize("logprobs", [None]) # No logprobs for OSS
@pytest.mark.parametrize("parallel_sample_num", [1, 2])
def test_chat_completion_stream(self, setup_backend, logprobs, parallel_sample_num):
"""Test streaming chat completion with parallel sampling (no logprobs)."""
super().test_chat_completion_stream(
setup_backend, logprobs, parallel_sample_num
)
@pytest.mark.skip(reason="OSS models don't support regex constraints")
def test_regex(self, setup_backend):
pass
@pytest.mark.skip(reason="OSS models don't support frequency_penalty")
def test_penalty(self, setup_backend):
pass
@pytest.mark.skip(reason="OSS models don't support continue_final_message")
def test_response_prefill(self, setup_backend):
pass

View File

@@ -0,0 +1,165 @@
"""Reasoning Content E2E Tests.
Tests for chat completions with reasoning content (DeepSeek R1 reasoning parser).
Source: Migrated from e2e_grpc/features/test_reasoning_content.py
"""
from __future__ import annotations
import logging
import pytest
logger = logging.getLogger(__name__)
# =============================================================================
# Reasoning Content API Tests (DeepSeek 7B)
# =============================================================================
@pytest.mark.model("deepseek-7b")
@pytest.mark.gateway(
extra_args=["--reasoning-parser", "deepseek_r1", "--history-backend", "memory"]
)
@pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True)
class TestReasoningContentAPI:
"""Tests for reasoning content API with DeepSeek R1 reasoning parser."""
def test_streaming_separate_reasoning_false(self, setup_backend):
"""Test streaming with separate_reasoning=False, reasoning_content should be empty."""
_, model, client, gateway = setup_backend
response = client.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": "What is 1+3?",
}
],
max_tokens=100,
stream=True,
extra_body={"separate_reasoning": False},
)
reasoning_content = ""
content = ""
for chunk in response:
if chunk.choices[0].delta.content:
content += chunk.choices[0].delta.content
elif chunk.choices[0].delta.reasoning_content:
reasoning_content += chunk.choices[0].delta.reasoning_content
assert len(reasoning_content) == 0
assert len(content) > 0
def test_streaming_separate_reasoning_true(self, setup_backend):
"""Test streaming with separate_reasoning=True, reasoning_content should not be empty."""
_, model, client, gateway = setup_backend
response = client.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": "What is 1+3?",
}
],
max_tokens=100,
stream=True,
extra_body={"separate_reasoning": True},
)
reasoning_content = ""
content = ""
for chunk in response:
if chunk.choices[0].delta.content:
content += chunk.choices[0].delta.content
elif chunk.choices[0].delta.reasoning_content:
reasoning_content += chunk.choices[0].delta.reasoning_content
assert len(reasoning_content) > 0
assert len(content) > 0
def test_streaming_separate_reasoning_true_stream_reasoning_false(
self, setup_backend
):
"""Test streaming with separate_reasoning=True and stream_reasoning=False."""
_, model, client, gateway = setup_backend
response = client.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": "What is 1+3?",
}
],
max_tokens=100,
stream=True,
extra_body={"separate_reasoning": True, "stream_reasoning": False},
)
reasoning_content = ""
content = ""
first_chunk = False
for chunk in response:
if chunk.choices[0].delta.reasoning_content:
reasoning_content = chunk.choices[0].delta.reasoning_content
first_chunk = True
if chunk.choices[0].delta.content:
content += chunk.choices[0].delta.content
if not first_chunk:
reasoning_content = chunk.choices[0].delta.reasoning_content
first_chunk = True
if not first_chunk:
assert (
not chunk.choices[0].delta.reasoning_content
or len(chunk.choices[0].delta.reasoning_content) == 0
)
assert len(reasoning_content) > 0
assert len(content) > 0
def test_nonstreaming_separate_reasoning_false(self, setup_backend):
"""Test non-streaming with separate_reasoning=False, reasoning_content should be empty."""
_, model, client, gateway = setup_backend
response = client.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": "What is 1+3?",
}
],
max_tokens=100,
extra_body={"separate_reasoning": False},
)
assert (
not response.choices[0].message.reasoning_content
or len(response.choices[0].message.reasoning_content) == 0
)
assert len(response.choices[0].message.content) > 0
def test_nonstreaming_separate_reasoning_true(self, setup_backend):
"""Test non-streaming with separate_reasoning=True, reasoning_content should not be empty."""
_, model, client, gateway = setup_backend
response = client.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": "What is 1+3?",
}
],
max_tokens=100,
extra_body={"separate_reasoning": True},
)
assert len(response.choices[0].message.reasoning_content) > 0
assert len(response.choices[0].message.content) > 0

View File

@@ -0,0 +1,167 @@
"""Validation E2E Tests.
Tests for validation features like ignore_eos and large token handling.
Source: Migrated from e2e_grpc/validation/test_openai_server_ignore_eos.py
and e2e_grpc/validation/test_large_max_new_tokens.py
"""
from __future__ import annotations
import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor
import pytest
logger = logging.getLogger(__name__)
# Lazy load tokenizer to avoid import errors if transformers not installed
_tokenizer_cache: dict = {}
_tokenizer_lock = threading.Lock()
def get_tokenizer(model_path: str):
"""Get tokenizer for a model, with caching."""
if model_path not in _tokenizer_cache:
with _tokenizer_lock:
# Re-check after acquiring the lock to handle race conditions
if model_path not in _tokenizer_cache:
from transformers import AutoTokenizer
_tokenizer_cache[model_path] = AutoTokenizer.from_pretrained(model_path)
return _tokenizer_cache[model_path]
# =============================================================================
# Ignore EOS Tests (Llama 8B)
# =============================================================================
@pytest.mark.model("llama-8b")
@pytest.mark.gateway(extra_args=["--history-backend", "memory"])
@pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True)
class TestIgnoreEOS:
"""Tests for ignore_eos feature."""
def test_ignore_eos(self, setup_backend):
"""Test that ignore_eos=True allows generation to continue beyond EOS token.
When ignore_eos=True, the model should generate until max_tokens is reached,
even if it encounters an EOS token.
"""
_, model, client, _ = setup_backend
tokenizer = get_tokenizer(model)
max_tokens = 200
# Request without ignore_eos (default behavior - stops at EOS)
response_default = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Count from 1 to 20."},
],
temperature=0,
max_tokens=max_tokens,
extra_body={"ignore_eos": False},
)
# Request with ignore_eos=True (continues past EOS until max_tokens)
response_ignore_eos = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Count from 1 to 20."},
],
temperature=0,
max_tokens=max_tokens,
extra_body={"ignore_eos": True},
)
default_tokens = len(
tokenizer.encode(response_default.choices[0].message.content)
)
ignore_eos_tokens = len(
tokenizer.encode(response_ignore_eos.choices[0].message.content)
)
# Check if ignore_eos resulted in more tokens or exactly max_tokens
# The ignore_eos response should either:
# 1. Have more tokens than the default response (if default stopped at EOS before max_tokens)
# 2. Have exactly max_tokens (if it reached the max_tokens limit)
assert (
ignore_eos_tokens > default_tokens or ignore_eos_tokens >= max_tokens
), f"ignore_eos did not generate more tokens: {ignore_eos_tokens} vs {default_tokens}"
assert response_ignore_eos.choices[0].finish_reason == "length", (
f"Expected finish_reason='length' for ignore_eos=True, "
f"got {response_ignore_eos.choices[0].finish_reason}"
)
# =============================================================================
# Large Max New Tokens Tests (Llama 8B)
#
# NOTE: This test verifies concurrent request handling with large token limits.
# The original test monitored server logs to verify concurrency, which is not
# possible with the pool-based infrastructure. This simplified version verifies
# that concurrent requests complete successfully.
# =============================================================================
@pytest.mark.model("llama-8b")
@pytest.mark.gateway(extra_args=["--history-backend", "memory"])
@pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True)
class TestLargeMaxNewTokens:
"""Tests for handling large max_new_tokens with concurrent requests."""
def test_concurrent_chat_completions(self, setup_backend):
"""Test that multiple concurrent requests with large token generation complete.
This test sends multiple requests that ask for long outputs concurrently
to verify the server can handle concurrent long-running requests.
"""
_, model, client, _ = setup_backend
num_requests = 4
def run_chat_completion():
response = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{
"role": "user",
"content": "Please repeat the word 'hello' for 100 times.",
},
],
temperature=0,
max_tokens=256, # Reasonable limit for concurrent test
)
return response
# Send concurrent requests
start_time = time.time()
futures = []
with ThreadPoolExecutor(max_workers=num_requests) as executor:
for _ in range(num_requests):
futures.append(executor.submit(run_chat_completion))
# Wait for all to complete and collect results
responses = [f.result() for f in futures]
elapsed = time.time() - start_time
logger.info("Completed %d concurrent requests in %.2fs", num_requests, elapsed)
# Verify all requests completed successfully
assert len(responses) == num_requests
for i, response in enumerate(responses):
assert response.choices[
0
].message.content, f"Request {i} returned empty content"
assert response.choices[0].finish_reason in ("stop", "length"), (
f"Request {i} had unexpected finish_reason: "
f"{response.choices[0].finish_reason}"
)

View File

@@ -0,0 +1,225 @@
"""Pytest configuration for E2E tests.
Parallel Execution
------------------
Tests can run in parallel using pytest-parallel with shared worker processes.
Use --workers 1 --tests-per-worker N for N concurrent test threads:
pytest --workers 1 --tests-per-worker 4 e2e_test/router/
This leverages the thread-safe ModelPool and GPUAllocator classes to enable
true shared-worker parallelism where all threads share the same session-scoped
model_pool fixture. Tests marked with @pytest.mark.thread_unsafe will be
automatically skipped in parallel mode.
Markers
-------
This module defines several pytest markers for configuring E2E tests:
@pytest.mark.model(name)
Specify which model to use for the test.
Args:
name: Model ID from MODEL_SPECS (e.g., "llama-8b", "qwen-7b")
GPU Resource Management:
When GPUs are limited (e.g., 4 GPUs, 6 models), the model pool uses
MRU (Most Recently Used) eviction:
1. Models are pre-launched until GPUs are full
2. When a test needs a model that isn't running, MRU model is evicted
(models just used are likely done, models not yet used are waiting)
3. The needed model is then launched on-demand
Examples:
@pytest.mark.model("llama-8b")
@pytest.mark.model("qwen-72b")
@pytest.mark.workers(count=1, prefill=None, decode=None)
Configure worker topology for the test.
Args:
count: Number of regular workers (default: 1)
prefill: Number of prefill workers for PD disaggregation
decode: Number of decode workers for PD disaggregation
Examples:
@pytest.mark.workers(count=3) # 3 regular workers
@pytest.mark.workers(prefill=2, decode=2) # PD mode
@pytest.mark.gateway(policy="round_robin", timeout=None, extra_args=None)
Configure the gateway/router.
Args:
policy: Routing policy ("round_robin", "random", etc.)
timeout: Startup timeout in seconds
extra_args: Additional CLI arguments for the router
Examples:
@pytest.mark.gateway(policy="random")
@pytest.mark.gateway(extra_args=["--cache-routing"])
@pytest.mark.e2e
Mark test as an end-to-end test requiring GPU workers.
@pytest.mark.slow
Mark test as slow-running.
@pytest.mark.thread_unsafe(reason=None)
Mark test as incompatible with parallel thread execution.
Tests with this marker are automatically skipped when running
with --tests-per-worker > 1.
Args:
reason: Optional explanation of why the test is thread-unsafe.
Examples:
@pytest.mark.thread_unsafe
@pytest.mark.thread_unsafe(reason="Modifies global state")
Fixtures
--------
model_pool: Session-scoped fixture managing SGLang worker processes.
setup_backend: Class-scoped fixture that launches gateway + provides client.
Usage Examples
--------------
Basic test with default model:
@pytest.mark.e2e
@pytest.mark.parametrize("setup_backend", ["http"], indirect=True)
class TestBasic:
def test_chat(self, setup_backend):
backend, model, client, gateway = setup_backend
response = client.chat.completions.create(...)
Test with specific model and multiple backends:
@pytest.mark.e2e
@pytest.mark.model("qwen-7b")
@pytest.mark.parametrize("setup_backend", ["grpc", "http"], indirect=True)
class TestQwen:
def test_generate(self, setup_backend):
...
PD disaggregation mode:
@pytest.mark.e2e
@pytest.mark.workers(prefill=1, decode=1)
@pytest.mark.parametrize("setup_backend", ["pd"], indirect=True)
class TestPD:
def test_pd_inference(self, setup_backend):
...
"""
from __future__ import annotations
import logging
import sys
from importlib.util import find_spec
from pathlib import Path
# ---------------------------------------------------------------------------
# Path setup (must happen before other imports)
# ---------------------------------------------------------------------------
_ROOT = Path(__file__).resolve().parents[1] # sgl-model-gateway/
_E2E_TEST = Path(__file__).resolve().parent # e2e_test/
_SRC = _ROOT / "bindings" / "python"
# Add e2e_test to path so "from infra import ..." works
if str(_E2E_TEST) not in sys.path:
sys.path.insert(0, str(_E2E_TEST))
# Add bindings/python to path if the wheel is not installed (for local development)
_wheel_installed = find_spec("sglang_router.sglang_router_rs") is not None
if not _wheel_installed and str(_SRC) not in sys.path:
sys.path.insert(0, str(_SRC))
# ---------------------------------------------------------------------------
# Logging setup (clean output without pytest's "---- live log ----" dividers)
# ---------------------------------------------------------------------------
def _setup_logging() -> None:
"""Configure clean logging to stdout with timestamps and thread info.
In parallel mode (--tests-per-worker > 1), logs from different threads
would be interleaved. Including thread name helps identify which test
produced each log line.
"""
# Include thread name for parallel execution readability
# MainThread for sequential, Thread-N for parallel workers
fmt = "%(asctime)s.%(msecs)03d [%(threadName)s] [%(name)s] %(message)s"
datefmt = "%H:%M:%S"
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter(fmt, datefmt))
for logger_name in ("e2e_test", "infra", "fixtures"):
log = logging.getLogger(logger_name)
log.setLevel(logging.INFO)
log.addHandler(handler)
log.propagate = False
for logger_name in ("openai", "httpx", "httpcore", "numexpr"):
logging.getLogger(logger_name).setLevel(logging.WARNING)
_setup_logging()
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Test visibility hooks
# ---------------------------------------------------------------------------
def pytest_runtest_logstart(nodeid: str, location: tuple) -> None:
"""Print clear test header at start of each test."""
import threading
from infra import LOG_SEPARATOR_WIDTH
test_name = nodeid.split("::")[-1] if "::" in nodeid else nodeid
thread_name = threading.current_thread().name
print(f"\n{'=' * LOG_SEPARATOR_WIDTH}")
print(f"[{thread_name}] TEST: {test_name}")
print(f"{'=' * LOG_SEPARATOR_WIDTH}")
# ---------------------------------------------------------------------------
# Import pytest hooks and fixtures from fixtures/ package
# ---------------------------------------------------------------------------
# Import fixtures - pytest discovers these by name
# Import hooks - pytest discovers these by name
from fixtures import (
backend_router,
model_base_url,
model_client,
model_pool,
pytest_collection_finish,
pytest_collection_modifyitems,
pytest_configure,
pytest_runtest_setup,
setup_backend,
)
# Re-export for pytest discovery
__all__ = [
# Hooks
"pytest_runtest_logstart",
"pytest_collection_modifyitems",
"pytest_collection_finish",
"pytest_configure",
"pytest_runtest_setup",
# Fixtures
"model_pool",
"model_client",
"model_base_url",
"setup_backend",
"backend_router",
]

View File

@@ -0,0 +1,143 @@
"""Basic embedding API tests.
Tests the embedding functionality through the router with both gRPC and HTTP backends.
Source: Migrated from e2e_grpc/basic/test_embedding_server.py
Usage:
pytest e2e_test/embeddings/test_basic.py -v
pytest e2e_test/embeddings/test_basic.py -v -k "grpc"
"""
from __future__ import annotations
import logging
import pytest
logger = logging.getLogger(__name__)
@pytest.mark.e2e
@pytest.mark.model("embedding")
@pytest.mark.parametrize("setup_backend", ["grpc", "http"], indirect=True)
class TestEmbeddingBasic:
"""Basic embedding API tests using local workers (gRPC and HTTP)."""
def test_embedding_single(self, setup_backend):
"""Test single text embedding.
Verifies that:
- Response object structure is correct
- Embedding is a non-empty list of floats
- Usage statistics are present
"""
backend, model, client, gateway = setup_backend
input_text = "Hello world"
response = client.embeddings.create(
model=model,
input=input_text,
)
assert response.object == "list"
assert len(response.data) == 1
embedding = response.data[0]
assert embedding.object == "embedding"
assert embedding.index == 0
assert len(embedding.embedding) > 0
assert isinstance(embedding.embedding[0], float)
# Verify usage statistics
assert response.usage.prompt_tokens > 0
assert response.usage.total_tokens == response.usage.prompt_tokens
logger.info(
"Single embedding: %d dimensions, %d tokens",
len(embedding.embedding),
response.usage.prompt_tokens,
)
def test_embedding_batch(self, setup_backend):
"""Test batch embedding with multiple texts.
Note: The original test expected len(response.data) == 1 for batch,
which seems incorrect. This might be model-specific behavior.
"""
backend, model, client, gateway = setup_backend
input_texts = ["Hello world", "SGLang is fast"]
response = client.embeddings.create(
model=model,
input=input_texts,
)
# Note: Original test had len(response.data) == 1, which seems like
# a bug or model-specific behavior. Standard behavior should return
# one embedding per input text.
assert len(response.data) >= 1
assert response.data[0].index == 0
assert len(response.data[0].embedding) > 0
logger.info("Batch embedding: %d results", len(response.data))
def test_embedding_dimensions_consistent(self, setup_backend):
"""Test that embedding dimensions are consistent across different inputs.
Verifies that different length inputs produce embeddings with
the same dimensionality.
"""
backend, model, client, gateway = setup_backend
response1 = client.embeddings.create(
model=model,
input="A short text",
)
dim1 = len(response1.data[0].embedding)
response2 = client.embeddings.create(
model=model,
input="A much longer text to ensure dimensions match regardless of input length",
)
dim2 = len(response2.data[0].embedding)
assert dim1 == dim2, f"Dimensions differ: {dim1} vs {dim2}"
logger.info("Embedding dimensions: %d (consistent)", dim1)
def test_embedding_empty_string(self, setup_backend):
"""Test embedding with empty string input.
Some models may handle empty strings differently.
This test verifies the API doesn't crash on empty input.
"""
backend, model, client, gateway = setup_backend
try:
response = client.embeddings.create(
model=model,
input="",
)
# If it succeeds, verify structure
assert len(response.data) >= 1
logger.info("Empty string embedding succeeded")
except Exception as e:
# Some models may reject empty strings - that's acceptable
logger.info("Empty string embedding rejected: %s", e)
def test_embedding_unicode(self, setup_backend):
"""Test embedding with unicode characters.
Verifies that the API handles non-ASCII characters correctly.
"""
backend, model, client, gateway = setup_backend
input_text = "Hello 世界! 🚀 Привет мир"
response = client.embeddings.create(
model=model,
input=input_text,
)
assert len(response.data) == 1
assert len(response.data[0].embedding) > 0
logger.info("Unicode embedding: %d dimensions", len(response.data[0].embedding))

View File

@@ -0,0 +1,262 @@
"""Embedding correctness tests.
Tests that embeddings from the router match HuggingFace reference embeddings.
Validates numerical correctness including tokenization and inference.
Source: Migrated from e2e_grpc/basic/test_embedding_correctness.py
Usage:
pytest e2e_test/embeddings/test_correctness.py -v
pytest e2e_test/embeddings/test_correctness.py -v -k "grpc"
Requirements:
- sentence-transformers (for reference embeddings)
- torch
- numpy
"""
from __future__ import annotations
import logging
import threading
from typing import Any
import numpy as np
import pytest
import torch
import torch.nn.functional as F
logger = logging.getLogger(__name__)
# Thread-safe storage for HF reference embeddings
_hf_embeddings_cache: dict[str, Any] | None = None
_hf_embeddings_lock = threading.Lock()
# Test data for semantic similarity checks
SEMANTIC_TEST_SETS: list[list[str]] = [
[
"The cat sat on the mat.",
"A feline was resting on a rug.",
"Bright stars illuminate the night sky.", # Unrelated sentence
],
[
"The quick brown fox jumps over the lazy dog.",
"A fast, dark-colored fox leaps above a sluggish canine.",
"Ocean waves gently lap against the shore.", # Unrelated sentence
],
[
"An apple a day keeps the doctor away.",
"Eating a daily apple can prevent medical visits.",
"Mountains are vast and often snow-capped.", # Unrelated sentence
],
]
# Test data for relevance scoring
RELEVANCE_TEST_DATA: dict[str, Any] = {
"sample_query": "Why is Oracle launching Cloud Lift Services?",
"sample_reference": [
{
"docid": 466,
"body": "What are some extended benefits of using Oracle Cloud Infrastructure? \nWhen customers migrate their on-premises Oracle applications to Oracle Cloud Infrastructure, they realize the benefits \nof the cloud without needing to rearchitect those applications. Customers can lower total cost of ownership, improve \nagility and increase workload performance. Additional benefits include: \nConsistently low global pricing and lack of hidden charges \nAutomated migration support, leveraging cloud managers and tools for key applications \nFlexible universal credits applied towards any IaaS or PaaS service \nBring Your Own License (BYOL) capabilities \nIs Oracle Cloud Lift available for PAYGO customers? \nOracle Cloud Lift Services are designed for customers who use the UCM credits (Monthly Flex). PAYGO customers can \ncontact their sales representative or cloud engineer to evaluate their eligibility. \nAre any countries excluded from Oracle Cloud Lift Services? \nAmong the countries that Oracle operates in, only China is excluded from the Oracle Cloud Lift Services program.",
},
{
"docid": 636,
"body": "Cloud Lift Services as needed to make our joint customers more successful. Public Sector accounts and partner \nengagements are not currently eligible to participate in this program. \n How can I get started with Oracle Cloud? \nYou can use the Oracle Cloud Free Tier for a free trial and Contact Us for more information.",
},
{
"docid": 545,
"body": "Frequently Asked Questions (FAQs) for \nOracle Cloud Lift Services \n \nWhy is Oracle launching Cloud Lift Services? \n \n \n \nThis program underscores Oracle's intent to better serve its customer base. Cloud Lift Services provide new and \nexisting customers expanded access to cloud engineering tools and resources to quickly migrate workloads at no \nadditional cost.",
},
{
"docid": 716,
"body": "as part of their existing contract. \nWhat happens if I already have a paid services engagement? \nPlease keep proceeding with your existing engagement. Oracle will work with you to identify expansion opportunities \nto leverage Cloud Lift Services for other projects.",
},
],
}
def get_openai_embeddings(
texts: str | list[str],
client,
model: str,
) -> list[list[float]]:
"""Get embeddings from the gateway via OpenAI-compatible API."""
if isinstance(texts, str):
texts = [texts]
embeddings = []
for text in texts:
response = client.embeddings.create(
model=model,
input=text,
)
embeddings.append(response.data[0].embedding)
return embeddings
def get_hf_st_embeddings(texts: str | list[str], model_path: str) -> np.ndarray:
"""Get embeddings using sentence-transformers library.
This handles the correct pooling strategy for each model automatically.
For e5-mistral, it uses last-token pooling (not mean pooling).
Uses CPU to compute reference embeddings to avoid GPU memory conflicts
with the worker being tested.
"""
from sentence_transformers import SentenceTransformer
if isinstance(texts, str):
texts = [texts]
# Force CPU to avoid GPU memory conflicts in CI
model = SentenceTransformer(model_path, trust_remote_code=True, device="cpu")
embeddings = model.encode(texts, normalize_embeddings=True)
return embeddings
def compare_embeddings(
embeddings1: list[list[float]], embeddings2: list[list[float]]
) -> list[float]:
"""Compare two sets of embeddings using cosine similarity."""
similarities = [
F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0).item()
for e1, e2 in zip(embeddings1, embeddings2)
]
return similarities
def get_input_texts(test_json: dict) -> list[str]:
"""Extract document bodies from test JSON."""
return [doc["body"] for doc in test_json["sample_reference"]]
@pytest.fixture(scope="session")
def hf_reference_embeddings(request):
"""Pre-compute HuggingFace reference embeddings on CPU.
This is done once per session with thread-safe initialization to support
pytest-parallel execution. Uses CPU to avoid GPU memory conflicts.
"""
global _hf_embeddings_cache
# Thread-safe initialization - only one thread computes embeddings
with _hf_embeddings_lock:
if _hf_embeddings_cache is not None:
return _hf_embeddings_cache
from infra.model_specs import MODEL_SPECS
# Get model path from MODEL_SPECS for the embedding model
model_path = MODEL_SPECS.get("embedding", {}).get("model")
if model_path is None:
pytest.skip("Embedding model not found in MODEL_SPECS")
logger.info(
"Pre-computing HuggingFace reference embeddings (CPU) for %s", model_path
)
# Flatten all test texts for semantic similarity
all_semantic_texts = []
for text_set in SEMANTIC_TEST_SETS:
all_semantic_texts.extend(text_set)
# Get relevance test texts
query = f"Instruct: Given a search query, retrieve relevant passages that answer the query\nQuery: {RELEVANCE_TEST_DATA['sample_query']}"
docs = get_input_texts(RELEVANCE_TEST_DATA)
# Compute all reference embeddings at once
hf_semantic = get_hf_st_embeddings(all_semantic_texts, model_path)
hf_query = get_hf_st_embeddings(query, model_path)
hf_docs = get_hf_st_embeddings(docs, model_path)
logger.info("Reference embeddings computed on CPU")
_hf_embeddings_cache = {
"semantic": hf_semantic,
"query": hf_query,
"docs": hf_docs,
}
return _hf_embeddings_cache
@pytest.mark.e2e
@pytest.mark.model("embedding")
@pytest.mark.parametrize("setup_backend", ["grpc", "http"], indirect=True)
class TestEmbeddingCorrectness:
"""Test embedding correctness by comparing gateway output against HuggingFace reference.
Strategy: Pre-compute HuggingFace reference embeddings on CPU, then launch the
worker on GPU and compare. Using CPU for reference avoids GPU memory conflicts.
"""
def test_semantic_similarity(self, setup_backend, hf_reference_embeddings):
"""Check if gateway and HF embeddings give similar results.
For each text in the semantic test sets, the gateway embedding should
have >0.98 cosine similarity with the HuggingFace reference embedding.
"""
backend, model_path, client, gateway = setup_backend
tolerance = 1e-2
# Track position in pre-computed embeddings
embed_idx = 0
for i, input_texts in enumerate(SEMANTIC_TEST_SETS):
logger.info("Processing semantic similarity test set %d", i + 1)
embedding_gateway = get_openai_embeddings(input_texts, client, model_path)
# Get pre-computed HF embeddings for this set
num_texts = len(input_texts)
embedding_hf = hf_reference_embeddings["semantic"][
embed_idx : embed_idx + num_texts
].tolist()
embed_idx += num_texts
similarities = compare_embeddings(embedding_gateway, embedding_hf)
logger.info("Similarities: %s", similarities)
# Verify all similarities are close to 1.0
for j, sim in enumerate(similarities):
assert (
abs(sim - 1.0) < tolerance
), f"Set {i+1}, text {j+1}: similarity {sim:.4f} not close to 1.0"
logger.info("Semantic similarity test set %d passed", i + 1)
def test_relevance_scores(self, setup_backend, hf_reference_embeddings):
"""Compare relevance scores between gateway and HF implementations.
The relevance scores (query @ docs) should match between the gateway
and HuggingFace implementations within tolerance.
"""
backend, model_path, client, gateway = setup_backend
tolerance = 0.05
# Format query with instruction (for e5-mistral)
query = f"Instruct: Given a search query, retrieve relevant passages that answer the query\nQuery: {RELEVANCE_TEST_DATA['sample_query']}"
docs = get_input_texts(RELEVANCE_TEST_DATA)
# Get gateway scores
query_embeddings_gateway = get_openai_embeddings(query, client, model_path)
docs_embeddings_gateway = get_openai_embeddings(docs, client, model_path)
scores_gateway = (
np.array(query_embeddings_gateway) @ np.array(docs_embeddings_gateway).T
) * 100
# Use pre-computed HF scores
scores_hf = (
hf_reference_embeddings["query"] @ hf_reference_embeddings["docs"].T
) * 100
logger.info("Gateway relevance scores: %s", scores_gateway)
logger.info("HF relevance scores: %s", scores_hf)
assert np.allclose(
scores_gateway, scores_hf, atol=tolerance
), f"Scores differ beyond tolerance:\nGateway: {scores_gateway}\nHF: {scores_hf}"
logger.info("Relevance scores comparison passed")

Some files were not shown because too many files have changed in this diff Show More