Skip to content

Commit c67e1e4

Browse files
committed
feat: add LLMWithGateway for enterprise OAuth support
Add LLMWithGateway class that extends LLM with enterprise gateway support: - OAuth 2.0 token fetching and automatic refresh with caching - Configurable token paths and TTL for various OAuth response formats - Custom header injection for routing and additional authentication - Template variable support ({{llm_model}}, {{llm_base_url}}, etc.) - Thread-safe token management - Works with both completion() and responses() APIs The class maintains full API compatibility with the base LLM class while transparently handling gateway authentication flows behind the scenes. Includes comprehensive test coverage (25 tests) covering: - OAuth token lifecycle (fetch, cache, refresh) - Header injection and custom headers - Template replacement - Nested path extraction from OAuth responses - Error handling and edge cases
1 parent b9860ce commit c67e1e4

File tree

3 files changed

+741
-0
lines changed

3 files changed

+741
-0
lines changed

openhands-sdk/openhands/sdk/llm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from openhands.sdk.llm.llm import LLM
22
from openhands.sdk.llm.llm_registry import LLMRegistry, RegistryEvent
33
from openhands.sdk.llm.llm_response import LLMResponse
4+
from openhands.sdk.llm.llm_with_gateway import LLMWithGateway
45
from openhands.sdk.llm.message import (
56
ImageContent,
67
Message,
@@ -23,6 +24,7 @@
2324
__all__ = [
2425
"LLMResponse",
2526
"LLM",
27+
"LLMWithGateway",
2628
"LLMRegistry",
2729
"RouterLLM",
2830
"RegistryEvent",
Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
"""LLM subclass with enterprise gateway support.
2+
3+
This module provides LLMWithGateway, which extends the base LLM class to support
4+
OAuth 2.0 authentication flows and custom headers for enterprise API gateways.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import threading
10+
import time
11+
from typing import Any
12+
13+
import httpx
14+
from pydantic import Field, PrivateAttr
15+
16+
from openhands.sdk.llm.llm import LLM
17+
from openhands.sdk.logger import get_logger
18+
19+
20+
logger = get_logger(__name__)
21+
22+
__all__ = ["LLMWithGateway"]
23+
24+
25+
class LLMWithGateway(LLM):
26+
"""LLM subclass with enterprise gateway support.
27+
28+
Supports OAuth 2.0 token exchange with configurable headers and bodies.
29+
Designed for enterprise API gateways that require:
30+
1. Initial OAuth call to get a bearer token
31+
2. Bearer token included in subsequent LLM API calls
32+
3. Custom headers for routing/authentication
33+
34+
Example usage:
35+
llm = LLMWithGateway(
36+
model="gpt-4",
37+
base_url="https://gateway.company.com/llm/v1",
38+
gateway_auth_url="https://gateway.company.com/oauth/token",
39+
gateway_auth_headers={
40+
"X-Client-Id": os.environ["GATEWAY_CLIENT_ID"],
41+
"X-Client-Secret": os.environ["GATEWAY_CLIENT_SECRET"],
42+
},
43+
gateway_auth_body={"grant_type": "client_credentials"},
44+
custom_headers={"X-Gateway-Key": os.environ["GATEWAY_API_KEY"]},
45+
)
46+
"""
47+
48+
# OAuth configuration
49+
gateway_auth_url: str | None = Field(
50+
default=None,
51+
description="Identity provider URL to fetch gateway tokens (OAuth endpoint).",
52+
)
53+
gateway_auth_method: str = Field(
54+
default="POST",
55+
description="HTTP method for identity provider requests.",
56+
)
57+
gateway_auth_headers: dict[str, str] | None = Field(
58+
default=None,
59+
description="Headers to include when calling the identity provider.",
60+
)
61+
gateway_auth_body: dict[str, Any] | None = Field(
62+
default=None,
63+
description="JSON body to include when calling the identity provider.",
64+
)
65+
gateway_auth_token_path: str = Field(
66+
default="access_token",
67+
description=(
68+
"Dot-notation path to the token in the OAuth response "
69+
"(e.g., 'access_token' or 'data.token')."
70+
),
71+
)
72+
gateway_auth_token_ttl: int | None = Field(
73+
default=None,
74+
description="Token TTL in seconds. If not set, defaults to 300s (5 minutes).",
75+
)
76+
77+
# Token header configuration
78+
gateway_token_header: str = Field(
79+
default="Authorization",
80+
description="Header name for the gateway token (defaults to 'Authorization').",
81+
)
82+
gateway_token_prefix: str = Field(
83+
default="Bearer ",
84+
description="Prefix prepended to the token (e.g., 'Bearer ').",
85+
)
86+
87+
# Custom headers for all requests
88+
custom_headers: dict[str, str] | None = Field(
89+
default=None,
90+
description="Custom headers to include with every LLM request.",
91+
)
92+
93+
# Private fields for token management
94+
_gateway_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
95+
_gateway_token: str | None = PrivateAttr(default=None)
96+
_gateway_token_expiry: float | None = PrivateAttr(default=None)
97+
98+
def model_post_init(self, __context: Any) -> None:
99+
"""Initialize private fields after model validation."""
100+
super().model_post_init(__context)
101+
self._gateway_lock = threading.Lock()
102+
self._gateway_token = None
103+
self._gateway_token_expiry = None
104+
105+
def completion(self, *args, **kwargs):
106+
"""Override to inject gateway authentication before calling LiteLLM."""
107+
self._prepare_gateway_call(kwargs)
108+
return super().completion(*args, **kwargs)
109+
110+
def responses(self, *args, **kwargs):
111+
"""Override to inject gateway authentication before calling LiteLLM."""
112+
self._prepare_gateway_call(kwargs)
113+
return super().responses(*args, **kwargs)
114+
115+
def _prepare_gateway_call(self, call_kwargs: dict[str, Any]) -> None:
116+
"""Augment LiteLLM kwargs with gateway headers and token.
117+
118+
This method:
119+
1. Fetches/refreshes OAuth token if needed
120+
2. Adds token to headers
121+
3. Adds custom headers
122+
4. Performs basic template variable replacement
123+
"""
124+
if not self.gateway_auth_url and not self.custom_headers:
125+
return
126+
127+
# Start with existing headers
128+
headers: dict[str, str] = {}
129+
existing_headers = call_kwargs.get("extra_headers")
130+
if isinstance(existing_headers, dict):
131+
headers.update(existing_headers)
132+
133+
# Add custom headers (with template replacement)
134+
if self.custom_headers:
135+
rendered_headers = self._render_templates(self.custom_headers)
136+
if isinstance(rendered_headers, dict):
137+
headers.update(rendered_headers)
138+
139+
# Add gateway token if OAuth is configured
140+
if self.gateway_auth_url:
141+
token_headers = self._get_gateway_token_headers()
142+
if token_headers:
143+
headers.update(token_headers)
144+
145+
# Set headers on the call
146+
if headers:
147+
call_kwargs["extra_headers"] = headers
148+
149+
def _get_gateway_token_headers(self) -> dict[str, str]:
150+
"""Get headers containing the gateway token."""
151+
token = self._ensure_gateway_token()
152+
if not token:
153+
return {}
154+
155+
header_name = self.gateway_token_header
156+
prefix = self.gateway_token_prefix
157+
value = f"{prefix}{token}" if prefix else token
158+
return {header_name: value}
159+
160+
def _ensure_gateway_token(self) -> str | None:
161+
"""Ensure we have a valid gateway token, refreshing if needed.
162+
163+
Returns:
164+
Valid gateway token, or None if gateway auth is not configured.
165+
"""
166+
if not self.gateway_auth_url:
167+
return None
168+
169+
# Fast path: check if current token is still valid (with 5s buffer)
170+
now = time.time()
171+
if (
172+
self._gateway_token
173+
and self._gateway_token_expiry
174+
and now < self._gateway_token_expiry - 5
175+
):
176+
return self._gateway_token
177+
178+
# Slow path: acquire lock and refresh token
179+
with self._gateway_lock:
180+
# Double-check after acquiring lock
181+
if (
182+
self._gateway_token
183+
and self._gateway_token_expiry
184+
and time.time() < self._gateway_token_expiry - 5
185+
):
186+
return self._gateway_token
187+
188+
# Refresh token
189+
return self._refresh_gateway_token()
190+
191+
def _refresh_gateway_token(self) -> str:
192+
"""Fetch a new gateway token from the identity provider.
193+
194+
This method is called while holding _gateway_lock.
195+
196+
Returns:
197+
Fresh gateway token.
198+
199+
Raises:
200+
Exception: If token fetch fails.
201+
"""
202+
assert self.gateway_auth_url is not None, "gateway_auth_url must be set"
203+
method = self.gateway_auth_method.upper()
204+
headers = self._render_templates(self.gateway_auth_headers or {})
205+
body = self._render_templates(self.gateway_auth_body or {})
206+
207+
logger.debug(
208+
f"Fetching gateway token from {self.gateway_auth_url} (method={method})"
209+
)
210+
211+
try:
212+
response = httpx.request(
213+
method,
214+
self.gateway_auth_url,
215+
headers=headers if isinstance(headers, dict) else None,
216+
json=body if isinstance(body, dict) else None,
217+
timeout=self.timeout or 30,
218+
)
219+
response.raise_for_status()
220+
except Exception as exc:
221+
logger.error(f"Gateway auth request failed: {exc}")
222+
raise
223+
224+
try:
225+
payload = response.json()
226+
except Exception as exc:
227+
logger.error(f"Failed to parse gateway auth response JSON: {exc}")
228+
raise
229+
230+
# Extract token from response
231+
token_path = self.gateway_auth_token_path
232+
token_value = self._extract_from_path(payload, token_path)
233+
if not isinstance(token_value, str) or not token_value.strip():
234+
raise ValueError(
235+
f"Gateway auth response did not contain token at path "
236+
f'"{token_path}". Response: {payload}'
237+
)
238+
239+
# Determine TTL
240+
ttl_seconds = float(self.gateway_auth_token_ttl or 300)
241+
242+
# Update cache
243+
self._gateway_token = token_value.strip()
244+
self._gateway_token_expiry = time.time() + max(ttl_seconds, 1.0)
245+
246+
logger.info(f"Gateway token refreshed successfully (expires in {ttl_seconds}s)")
247+
return self._gateway_token
248+
249+
def _render_templates(self, value: Any) -> Any:
250+
"""Replace template variables in strings with actual values.
251+
252+
Supports:
253+
- {{llm_model}} -> self.model
254+
- {{llm_base_url}} -> self.base_url
255+
- {{llm_api_key}} -> self.api_key (if set)
256+
257+
Args:
258+
value: String, dict, list, or other value to render.
259+
260+
Returns:
261+
Value with templates replaced.
262+
"""
263+
if isinstance(value, str):
264+
replacements: dict[str, str] = {
265+
"{{llm_model}}": self.model,
266+
"{{llm_base_url}}": self.base_url or "",
267+
}
268+
if self.api_key:
269+
replacements["{{llm_api_key}}"] = self.api_key.get_secret_value()
270+
271+
result = value
272+
for placeholder, actual in replacements.items():
273+
result = result.replace(placeholder, actual)
274+
return result
275+
276+
if isinstance(value, dict):
277+
return {k: self._render_templates(v) for k, v in value.items()}
278+
279+
if isinstance(value, list):
280+
return [self._render_templates(v) for v in value]
281+
282+
return value
283+
284+
@staticmethod
285+
def _extract_from_path(payload: Any, path: str) -> Any:
286+
"""Extract a value from nested dict/list using dot notation.
287+
288+
Examples:
289+
_extract_from_path({"a": {"b": "value"}}, "a.b") -> "value"
290+
_extract_from_path({"data": [{"token": "x"}]}, "data.0.token") -> "x"
291+
292+
Args:
293+
payload: Dict or list to traverse.
294+
path: Dot-separated path (e.g., "data.token" or "items.0.value").
295+
296+
Returns:
297+
Value at the specified path.
298+
299+
Raises:
300+
ValueError: If path cannot be traversed.
301+
"""
302+
current = payload
303+
if not path:
304+
return current
305+
306+
for part in path.split("."):
307+
if isinstance(current, dict):
308+
current = current.get(part)
309+
if current is None:
310+
raise ValueError(
311+
f'Key "{part}" not found in response '
312+
f'while traversing path "{path}".'
313+
)
314+
elif isinstance(current, list):
315+
try:
316+
index = int(part)
317+
except (ValueError, TypeError):
318+
raise ValueError(
319+
f'Invalid list index "{part}" '
320+
f'while traversing response path "{path}".'
321+
) from None
322+
try:
323+
current = current[index]
324+
except (IndexError, TypeError):
325+
raise ValueError(
326+
f"Index {index} out of range "
327+
f'while traversing response path "{path}".'
328+
) from None
329+
else:
330+
raise ValueError(
331+
f'Cannot traverse path "{path}"; '
332+
f'segment "{part}" not found or not accessible.'
333+
)
334+
335+
return current

0 commit comments

Comments
 (0)