Bypass proxies for loopback engines
This commit is contained in:
@@ -5,14 +5,35 @@ import os
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from ipaddress import ip_address
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Iterable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
class HttpClientError(RuntimeError):
|
||||
"""Raised for HTTP client failures."""
|
||||
|
||||
|
||||
def _should_bypass_proxy(url: str) -> bool:
|
||||
host = (urlparse(url).hostname or "").strip()
|
||||
if not host:
|
||||
return False
|
||||
if host == "localhost":
|
||||
return True
|
||||
try:
|
||||
return ip_address(host).is_loopback
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _urlopen(request: urllib.request.Request, *, timeout: float):
|
||||
if _should_bypass_proxy(request.full_url):
|
||||
opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
|
||||
return opener.open(request, timeout=timeout)
|
||||
return urllib.request.urlopen(request, timeout=timeout)
|
||||
|
||||
|
||||
def _auth_headers(api_key_env: str | None) -> dict[str, str]:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key_env:
|
||||
@@ -37,7 +58,7 @@ def wait_for_server(base_url: str, path: str, timeout_s: float) -> None:
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
request = urllib.request.Request(url=url, headers=_auth_headers(None), method="GET")
|
||||
with urllib.request.urlopen(request, timeout=5) as response:
|
||||
with _urlopen(request, timeout=5) as response:
|
||||
if 200 <= response.status < 500:
|
||||
return
|
||||
except Exception as exc: # noqa: BLE001
|
||||
@@ -66,7 +87,7 @@ def chat_completion(
|
||||
method="POST",
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(request, timeout=timeout_s) as response:
|
||||
with _urlopen(request, timeout=timeout_s) as response:
|
||||
return json.loads(response.read().decode("utf-8"))
|
||||
except urllib.error.HTTPError as exc:
|
||||
detail = exc.read().decode("utf-8", errors="replace")
|
||||
@@ -99,7 +120,7 @@ def stream_chat_completion(
|
||||
chunk_token_count = 0
|
||||
completion_tokens: int | None = None
|
||||
try:
|
||||
with urllib.request.urlopen(request, timeout=timeout_s) as response:
|
||||
with _urlopen(request, timeout=timeout_s) as response:
|
||||
for raw in _iter_sse_lines(response):
|
||||
if raw == "[DONE]":
|
||||
break
|
||||
|
||||
@@ -8,7 +8,7 @@ from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
from aituner.cli import main as cli_main
|
||||
from aituner.http_client import _openai_url
|
||||
from aituner.http_client import _openai_url, _should_bypass_proxy
|
||||
from aituner.job import append_job, build_trial_job
|
||||
from aituner.llm import build_prompt, parse_proposal_text
|
||||
from aituner.search import ThresholdProbe, binary_search_max_feasible
|
||||
@@ -833,6 +833,11 @@ class CoreFlowTests(unittest.TestCase):
|
||||
"http://example.com/v1/chat/completions",
|
||||
)
|
||||
|
||||
def test_loopback_urls_bypass_proxy(self) -> None:
|
||||
self.assertTrue(_should_bypass_proxy("http://127.0.0.1:8000/v1/models"))
|
||||
self.assertTrue(_should_bypass_proxy("http://localhost:8000/health"))
|
||||
self.assertFalse(_should_bypass_proxy("http://example.com/v1/models"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user