Bypass proxies for loopback engines

This commit is contained in:
2026-04-04 23:50:42 +08:00
parent 7632de8dad
commit 75a9842f1a
2 changed files with 30 additions and 4 deletions

View File

@@ -5,14 +5,35 @@ import os
import time
import urllib.error
import urllib.request
from ipaddress import ip_address
from dataclasses import dataclass
from typing import Any, Iterable
from urllib.parse import urlparse
class HttpClientError(RuntimeError):
"""Raised for HTTP client failures."""
def _should_bypass_proxy(url: str) -> bool:
host = (urlparse(url).hostname or "").strip()
if not host:
return False
if host == "localhost":
return True
try:
return ip_address(host).is_loopback
except ValueError:
return False
def _urlopen(request: urllib.request.Request, *, timeout: float):
if _should_bypass_proxy(request.full_url):
opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
return opener.open(request, timeout=timeout)
return urllib.request.urlopen(request, timeout=timeout)
def _auth_headers(api_key_env: str | None) -> dict[str, str]:
headers = {"Content-Type": "application/json"}
if api_key_env:
@@ -37,7 +58,7 @@ def wait_for_server(base_url: str, path: str, timeout_s: float) -> None:
while time.monotonic() < deadline:
try:
request = urllib.request.Request(url=url, headers=_auth_headers(None), method="GET")
with urllib.request.urlopen(request, timeout=5) as response:
with _urlopen(request, timeout=5) as response:
if 200 <= response.status < 500:
return
except Exception as exc: # noqa: BLE001
@@ -66,7 +87,7 @@ def chat_completion(
method="POST",
)
try:
with urllib.request.urlopen(request, timeout=timeout_s) as response:
with _urlopen(request, timeout=timeout_s) as response:
return json.loads(response.read().decode("utf-8"))
except urllib.error.HTTPError as exc:
detail = exc.read().decode("utf-8", errors="replace")
@@ -99,7 +120,7 @@ def stream_chat_completion(
chunk_token_count = 0
completion_tokens: int | None = None
try:
with urllib.request.urlopen(request, timeout=timeout_s) as response:
with _urlopen(request, timeout=timeout_s) as response:
for raw in _iter_sse_lines(response):
if raw == "[DONE]":
break

View File

@@ -8,7 +8,7 @@ from pathlib import Path
from unittest import mock
from aituner.cli import main as cli_main
from aituner.http_client import _openai_url
from aituner.http_client import _openai_url, _should_bypass_proxy
from aituner.job import append_job, build_trial_job
from aituner.llm import build_prompt, parse_proposal_text
from aituner.search import ThresholdProbe, binary_search_max_feasible
@@ -833,6 +833,11 @@ class CoreFlowTests(unittest.TestCase):
"http://example.com/v1/chat/completions",
)
def test_loopback_urls_bypass_proxy(self) -> None:
self.assertTrue(_should_bypass_proxy("http://127.0.0.1:8000/v1/models"))
self.assertTrue(_should_bypass_proxy("http://localhost:8000/health"))
self.assertFalse(_should_bypass_proxy("http://example.com/v1/models"))
if __name__ == "__main__":
unittest.main()