Skip to content

Commit a203fdb

Browse files
committed
feat: add LLMWithGateway for enterprise OAuth support
Add LLMWithGateway subclass to support enterprise API gateways with OAuth 2.0 authentication.

Key features:
- OAuth 2.0 token fetch and automatic refresh
- Thread-safe token caching with TTL
- Custom header injection for gateway-specific requirements
- Template variable replacement for flexible configuration
- Fully generic implementation (no vendor lock-in)

Implementation approach:
- Separate LLMWithGateway class (addresses PR #963 feedback from neubig)
- Focused feature set for OAuth + custom headers (no over-engineering)
- Comprehensive test coverage

This replaces the previous approach of modifying the main LLM class, keeping the codebase cleaner and more maintainable.

Example usage:

```python
llm = LLMWithGateway(
    model="gpt-4",
    base_url=os.environ["GATEWAY_BASE_URL"],
    gateway_auth_url=os.environ["GATEWAY_AUTH_URL"],
    gateway_auth_headers={
        "X-Client-Id": os.environ["GATEWAY_CLIENT_ID"],
        "X-Client-Secret": os.environ["GATEWAY_CLIENT_SECRET"],
    },
    gateway_auth_body={"grant_type": "client_credentials"},
    custom_headers={"X-Gateway-Key": os.environ["GATEWAY_API_KEY"]},
)
```

Files added:
- openhands-sdk/openhands/sdk/llm/llm_with_gateway.py (new class)
- tests/sdk/llm/test_llm_with_gateway.py (comprehensive tests)
1 parent b9860ce commit a203fdb

File tree

3 files changed

+738
-0
lines changed

3 files changed

+738
-0
lines changed

openhands-sdk/openhands/sdk/llm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from openhands.sdk.llm.llm import LLM
22
from openhands.sdk.llm.llm_registry import LLMRegistry, RegistryEvent
33
from openhands.sdk.llm.llm_response import LLMResponse
4+
from openhands.sdk.llm.llm_with_gateway import LLMWithGateway
45
from openhands.sdk.llm.message import (
56
ImageContent,
67
Message,
@@ -23,6 +24,7 @@
2324
__all__ = [
2425
"LLMResponse",
2526
"LLM",
27+
"LLMWithGateway",
2628
"LLMRegistry",
2729
"RouterLLM",
2830
"RegistryEvent",
Lines changed: 333 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
"""LLM subclass with enterprise gateway support.
2+
3+
This module provides LLMWithGateway, which extends the base LLM class to support
4+
OAuth 2.0 authentication flows and custom headers for enterprise API gateways.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import threading
10+
import time
11+
from typing import Any
12+
13+
import httpx
14+
from pydantic import Field, PrivateAttr
15+
16+
from openhands.sdk.llm.llm import LLM
17+
from openhands.sdk.logger import get_logger
18+
19+
logger = get_logger(__name__)
20+
21+
__all__ = ["LLMWithGateway"]
22+
23+
24+
class LLMWithGateway(LLM):
    """LLM subclass with enterprise gateway support.

    Supports OAuth 2.0 token exchange with configurable headers and bodies.
    Designed for enterprise API gateways that require:
    1. Initial OAuth call to get a bearer token
    2. Bearer token included in subsequent LLM API calls
    3. Custom headers for routing/authentication

    Header precedence when a call is prepared (later entries win on key
    collisions): caller-supplied ``extra_headers`` < ``custom_headers`` <
    gateway token header.

    Example usage:
        llm = LLMWithGateway(
            model="gpt-4",
            base_url="https://gateway.company.com/llm/v1",
            gateway_auth_url="https://gateway.company.com/oauth/token",
            gateway_auth_headers={
                "X-Client-Id": os.environ["GATEWAY_CLIENT_ID"],
                "X-Client-Secret": os.environ["GATEWAY_CLIENT_SECRET"],
            },
            gateway_auth_body={"grant_type": "client_credentials"},
            custom_headers={"X-Gateway-Key": os.environ["GATEWAY_API_KEY"]},
        )
    """

    # OAuth configuration
    gateway_auth_url: str | None = Field(
        default=None,
        description="Identity provider URL to fetch gateway tokens (OAuth endpoint).",
    )
    gateway_auth_method: str = Field(
        default="POST",
        description="HTTP method for identity provider requests.",
    )
    gateway_auth_headers: dict[str, str] | None = Field(
        default=None,
        description="Headers to include when calling the identity provider.",
    )
    gateway_auth_body: dict[str, Any] | None = Field(
        default=None,
        description="JSON body to include when calling the identity provider.",
    )
    gateway_auth_token_path: str = Field(
        default="access_token",
        description="Dot-notation path to the token in the OAuth response (e.g., 'access_token' or 'data.token').",
    )
    gateway_auth_token_ttl: int | None = Field(
        default=None,
        description="Token TTL in seconds. If not set, defaults to 300s (5 minutes).",
    )

    # Token header configuration
    gateway_token_header: str = Field(
        default="Authorization",
        description="Header name for the gateway token (defaults to 'Authorization').",
    )
    gateway_token_prefix: str = Field(
        default="Bearer ",
        description="Prefix prepended to the token (e.g., 'Bearer ').",
    )

    # Custom headers for all requests
    custom_headers: dict[str, str] | None = Field(
        default=None,
        description="Custom headers to include with every LLM request.",
    )

    # Private fields for token management.
    # _gateway_lock serializes token refreshes; the other two hold the cached
    # token and the absolute time.time() value at which it expires.
    _gateway_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
    _gateway_token: str | None = PrivateAttr(default=None)
    _gateway_token_expiry: float | None = PrivateAttr(default=None)

    def model_post_init(self, __context: Any) -> None:
        """Initialize private fields after model validation.

        Runs the parent hook first, then resets the lock and the token cache
        so every instance starts with an empty cache.
        """
        super().model_post_init(__context)
        self._gateway_lock = threading.Lock()
        self._gateway_token = None
        self._gateway_token_expiry = None

    def _completion(
        self, messages: list[dict], **kwargs
    ) -> Any:  # Returns ModelResponse
        """Inject gateway headers into ``kwargs``, then delegate to the parent.

        Args:
            messages: Chat messages forwarded unchanged to the parent method.
            **kwargs: Call options; ``extra_headers`` may be added or replaced.

        Returns:
            Whatever the parent ``_completion`` returns.
        """
        self._prepare_gateway_call(kwargs)
        return super()._completion(messages, **kwargs)

    def _responses(
        self, messages: list[dict], **kwargs
    ) -> Any:  # Returns ResponsesAPIResponse
        """Inject gateway headers into ``kwargs``, then delegate to the parent.

        Args:
            messages: Messages forwarded unchanged to the parent method.
            **kwargs: Call options; ``extra_headers`` may be added or replaced.

        Returns:
            Whatever the parent ``_responses`` returns.
        """
        self._prepare_gateway_call(kwargs)
        return super()._responses(messages, **kwargs)

    def _prepare_gateway_call(self, call_kwargs: dict[str, Any]) -> None:
        """Augment call kwargs with gateway headers and token (in place).

        This method:
        1. Starts from any ``extra_headers`` dict already in ``call_kwargs``
        2. Adds custom headers (with template variable replacement)
        3. Fetches/refreshes the OAuth token if needed and adds its header
        4. Writes the merged dict back to ``call_kwargs["extra_headers"]``

        Nothing is touched when neither ``gateway_auth_url`` nor
        ``custom_headers`` is configured.

        Args:
            call_kwargs: Mutable kwargs dict for the outgoing LLM call.

        Raises:
            Exception: Propagated from the token refresh when it fails.
        """
        if not self.gateway_auth_url and not self.custom_headers:
            return

        merged: dict[str, str] = {}

        # NOTE: only a plain dict of existing headers is carried over; any
        # other value under "extra_headers" is replaced by the merged dict.
        caller_headers = call_kwargs.get("extra_headers")
        if isinstance(caller_headers, dict):
            merged.update(caller_headers)

        # Custom headers (with template replacement) override caller headers.
        if self.custom_headers:
            rendered = self._render_templates(self.custom_headers)
            if isinstance(rendered, dict):
                merged.update(rendered)

        # The gateway token header is applied last so it always wins.
        if self.gateway_auth_url:
            merged.update(self._get_gateway_token_headers())

        if merged:
            call_kwargs["extra_headers"] = merged

    def _get_gateway_token_headers(self) -> dict[str, str]:
        """Return a one-entry header dict carrying the gateway token.

        Returns:
            ``{gateway_token_header: gateway_token_prefix + token}``, or an
            empty dict when no token is available (gateway auth not configured).
        """
        token = self._ensure_gateway_token()
        if not token:
            return {}

        prefix = self.gateway_token_prefix
        return {self.gateway_token_header: f"{prefix}{token}" if prefix else token}

    def _cached_gateway_token(self) -> str | None:
        """Return the cached token if it is still usable, else ``None``.

        A token counts as usable only while more than 5 seconds remain before
        its recorded expiry, so a call does not start with a token that is
        about to lapse.
        """
        token = self._gateway_token
        expiry = self._gateway_token_expiry
        if token and expiry and time.time() < expiry - 5:
            return token
        return None

    def _ensure_gateway_token(self) -> str | None:
        """Ensure we have a valid gateway token, refreshing if needed.

        Returns:
            Valid gateway token, or None if gateway auth is not configured.

        Raises:
            Exception: Propagated from ``_refresh_gateway_token`` on failure.
        """
        if not self.gateway_auth_url:
            return None

        # Fast path: no lock needed when the cached token is still fresh.
        token = self._cached_gateway_token()
        if token is not None:
            return token

        # Slow path: take the lock, then re-check because another thread may
        # have refreshed the token while this one was waiting.
        with self._gateway_lock:
            token = self._cached_gateway_token()
            if token is not None:
                return token
            return self._refresh_gateway_token()

    @staticmethod
    def _summarize_payload(payload: Any) -> str:
        """Describe a response payload's shape without exposing its values.

        Used in error messages: identity-provider responses can contain
        credentials, so only top-level key names / sizes / types are reported.
        """
        if isinstance(payload, dict):
            keys = ", ".join(sorted(str(key) for key in payload))
            return f"object with keys [{keys}]"
        if isinstance(payload, list):
            return f"array of length {len(payload)}"
        return f"value of type {type(payload).__name__}"

    def _refresh_gateway_token(self) -> str:
        """Fetch a new gateway token from the identity provider.

        This method is called while holding _gateway_lock.

        The request uses ``gateway_auth_method``, the rendered
        ``gateway_auth_headers`` and the rendered ``gateway_auth_body`` (sent
        as JSON). The token is read from the JSON response at
        ``gateway_auth_token_path`` and cached for ``gateway_auth_token_ttl``
        seconds (300 when unset or 0, never less than 1 second).

        Returns:
            Fresh gateway token (surrounding whitespace stripped).

        Raises:
            Exception: If the HTTP request fails, returns an error status, or
                the response body is not valid JSON (logged, then re-raised).
            ValueError: If the token path cannot be traversed or does not
                resolve to a non-blank string.
        """
        method = self.gateway_auth_method.upper()
        headers = self._render_templates(self.gateway_auth_headers or {})
        body = self._render_templates(self.gateway_auth_body or {})

        logger.debug(
            "Fetching gateway token from %s (method=%s)",
            self.gateway_auth_url,
            method,
        )

        try:
            response = httpx.request(
                method,
                self.gateway_auth_url,
                headers=headers if isinstance(headers, dict) else None,
                json=body if isinstance(body, dict) else None,
                # Falls back to 30 when the inherited `timeout` is unset/falsy.
                timeout=self.timeout or 30,
            )
            response.raise_for_status()
        except Exception as exc:
            logger.error("Gateway auth request failed: %s", exc)
            raise

        try:
            payload = response.json()
        except Exception as exc:
            logger.error("Failed to parse gateway auth response JSON: %s", exc)
            raise

        # Extract token from response
        token_path = self.gateway_auth_token_path
        token_value = self._extract_from_path(payload, token_path)
        if not isinstance(token_value, str) or not token_value.strip():
            # Deliberately do NOT embed the payload itself: it may hold the
            # access token, a refresh token or other credentials, and
            # exception messages routinely end up in logs.
            raise ValueError(
                f"Gateway auth response did not contain token at path "
                f'"{token_path}". Response was an '
                f"{self._summarize_payload(payload)}."
            )

        # Determine the effective TTL (clamped so the expiry is in the future).
        ttl_seconds = max(float(self.gateway_auth_token_ttl or 300), 1.0)

        # Update cache
        self._gateway_token = token_value.strip()
        self._gateway_token_expiry = time.time() + ttl_seconds

        logger.info(
            "Gateway token refreshed successfully (expires in %ss)", ttl_seconds
        )
        return self._gateway_token

    def _render_templates(self, value: Any) -> Any:
        """Replace template variables in strings with actual values.

        Supports:
        - {{llm_model}} -> self.model
        - {{llm_base_url}} -> self.base_url (empty string when unset)
        - {{llm_api_key}} -> self.api_key (only if set; otherwise the
          placeholder is left in the string unchanged)

        Dicts and lists are processed recursively; only dict *values* are
        rendered, keys are kept as-is.

        Args:
            value: String, dict, list, or other value to render.

        Returns:
            Value with templates replaced. Non-string scalars are returned
            unchanged.
        """
        if isinstance(value, str):
            replacements: dict[str, str] = {
                "{{llm_model}}": self.model,
                "{{llm_base_url}}": self.base_url or "",
            }
            if self.api_key:
                replacements["{{llm_api_key}}"] = self.api_key.get_secret_value()

            rendered = value
            for placeholder, actual in replacements.items():
                rendered = rendered.replace(placeholder, actual)
            return rendered

        if isinstance(value, dict):
            return {key: self._render_templates(item) for key, item in value.items()}

        if isinstance(value, list):
            return [self._render_templates(item) for item in value]

        return value

    @staticmethod
    def _extract_from_path(payload: Any, path: str) -> Any:
        """Extract a value from nested dict/list using dot notation.

        Examples:
            _extract_from_path({"a": {"b": "value"}}, "a.b") -> "value"
            _extract_from_path({"data": [{"token": "x"}]}, "data.0.token") -> "x"

        Args:
            payload: Dict or list to traverse.
            path: Dot-separated path (e.g., "data.token" or "items.0.value").
                An empty path returns ``payload`` itself.

        Returns:
            Value at the specified path.

        Raises:
            ValueError: If path cannot be traversed. A dict key that is
                missing or whose value is ``None`` is reported as not found.
        """
        current = payload
        if not path:
            return current

        for part in path.split("."):
            if isinstance(current, dict):
                current = current.get(part)
                # A missing key and an explicit null are treated alike.
                if current is None:
                    raise ValueError(
                        f'Key "{part}" not found in response while traversing path "{path}".'
                    )
            elif isinstance(current, list):
                try:
                    index = int(part)
                except (ValueError, TypeError):
                    raise ValueError(
                        f'Invalid list index "{part}" while traversing response path "{path}".'
                    ) from None
                try:
                    current = current[index]
                except (IndexError, TypeError):
                    raise ValueError(
                        f'Index {index} out of range while traversing response path "{path}".'
                    ) from None
            else:
                # Reached a scalar (or other non-container) with path left over.
                raise ValueError(
                    f'Cannot traverse path "{path}"; segment "{part}" not found or not accessible.'
                )

        return current

0 commit comments

Comments
 (0)