Bypass proxies for loopback engines
This commit is contained in:
@@ -5,14 +5,35 @@ import os
|
|||||||
import time
|
import time
|
||||||
import urllib.error
|
import urllib.error
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
from ipaddress import ip_address
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Iterable
|
from typing import Any, Iterable
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
|
||||||
# Single package-level exception type for this HTTP client module.
class HttpClientError(RuntimeError):
    """Raised for HTTP client failures."""
|
|
||||||
|
def _should_bypass_proxy(url: str) -> bool:
|
||||||
|
host = (urlparse(url).hostname or "").strip()
|
||||||
|
if not host:
|
||||||
|
return False
|
||||||
|
if host == "localhost":
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
return ip_address(host).is_loopback
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _urlopen(request: urllib.request.Request, *, timeout: float):
    """Open *request*, skipping configured proxies for loopback targets.

    For non-loopback URLs this defers to the default ``urlopen`` (which
    honors proxy environment variables); for loopback/localhost URLs it
    builds a dedicated opener with proxying disabled.
    """
    if not _should_bypass_proxy(request.full_url):
        return urllib.request.urlopen(request, timeout=timeout)
    # An empty ProxyHandler mapping disables all proxy lookup for this opener.
    no_proxy_opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
    return no_proxy_opener.open(request, timeout=timeout)
||||||
def _auth_headers(api_key_env: str | None) -> dict[str, str]:
|
def _auth_headers(api_key_env: str | None) -> dict[str, str]:
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
if api_key_env:
|
if api_key_env:
|
||||||
@@ -37,7 +58,7 @@ def wait_for_server(base_url: str, path: str, timeout_s: float) -> None:
|
|||||||
while time.monotonic() < deadline:
|
while time.monotonic() < deadline:
|
||||||
try:
|
try:
|
||||||
request = urllib.request.Request(url=url, headers=_auth_headers(None), method="GET")
|
request = urllib.request.Request(url=url, headers=_auth_headers(None), method="GET")
|
||||||
with urllib.request.urlopen(request, timeout=5) as response:
|
with _urlopen(request, timeout=5) as response:
|
||||||
if 200 <= response.status < 500:
|
if 200 <= response.status < 500:
|
||||||
return
|
return
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
@@ -66,7 +87,7 @@ def chat_completion(
|
|||||||
method="POST",
|
method="POST",
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
with urllib.request.urlopen(request, timeout=timeout_s) as response:
|
with _urlopen(request, timeout=timeout_s) as response:
|
||||||
return json.loads(response.read().decode("utf-8"))
|
return json.loads(response.read().decode("utf-8"))
|
||||||
except urllib.error.HTTPError as exc:
|
except urllib.error.HTTPError as exc:
|
||||||
detail = exc.read().decode("utf-8", errors="replace")
|
detail = exc.read().decode("utf-8", errors="replace")
|
||||||
@@ -99,7 +120,7 @@ def stream_chat_completion(
|
|||||||
chunk_token_count = 0
|
chunk_token_count = 0
|
||||||
completion_tokens: int | None = None
|
completion_tokens: int | None = None
|
||||||
try:
|
try:
|
||||||
with urllib.request.urlopen(request, timeout=timeout_s) as response:
|
with _urlopen(request, timeout=timeout_s) as response:
|
||||||
for raw in _iter_sse_lines(response):
|
for raw in _iter_sse_lines(response):
|
||||||
if raw == "[DONE]":
|
if raw == "[DONE]":
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from pathlib import Path
|
|||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from aituner.cli import main as cli_main
|
from aituner.cli import main as cli_main
|
||||||
from aituner.http_client import _openai_url
|
from aituner.http_client import _openai_url, _should_bypass_proxy
|
||||||
from aituner.job import append_job, build_trial_job
|
from aituner.job import append_job, build_trial_job
|
||||||
from aituner.llm import build_prompt, parse_proposal_text
|
from aituner.llm import build_prompt, parse_proposal_text
|
||||||
from aituner.search import ThresholdProbe, binary_search_max_feasible
|
from aituner.search import ThresholdProbe, binary_search_max_feasible
|
||||||
@@ -833,6 +833,11 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
"http://example.com/v1/chat/completions",
|
"http://example.com/v1/chat/completions",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_loopback_urls_bypass_proxy(self) -> None:
|
||||||
|
self.assertTrue(_should_bypass_proxy("http://127.0.0.1:8000/v1/models"))
|
||||||
|
self.assertTrue(_should_bypass_proxy("http://localhost:8000/health"))
|
||||||
|
self.assertFalse(_should_bypass_proxy("http://example.com/v1/models"))
|
||||||
|
|
||||||
|
|
||||||
# Allow running this test module directly (python <file>.py) in addition
# to discovery via a test runner.
if __name__ == "__main__":
    unittest.main()
|
|||||||
Reference in New Issue
Block a user