|
| 1 | +"""LLM subclass with enterprise gateway support. |
| 2 | +
|
| 3 | +This module provides LLMWithGateway, which extends the base LLM class to support |
| 4 | +OAuth 2.0 authentication flows and custom headers for enterprise API gateways. |
| 5 | +""" |
| 6 | + |
| 7 | +from __future__ import annotations |
| 8 | + |
| 9 | +import threading |
| 10 | +import time |
| 11 | +from typing import Any |
| 12 | + |
| 13 | +import httpx |
| 14 | +from pydantic import Field, PrivateAttr |
| 15 | + |
| 16 | +from openhands.sdk.llm.llm import LLM |
| 17 | +from openhands.sdk.logger import get_logger |
| 18 | + |
| 19 | +logger = get_logger(__name__) |
| 20 | + |
| 21 | +__all__ = ["LLMWithGateway"] |
| 22 | + |
| 23 | + |
| 24 | +class LLMWithGateway(LLM): |
| 25 | + """LLM subclass with enterprise gateway support. |
| 26 | +
|
| 27 | + Supports OAuth 2.0 token exchange with configurable headers and bodies. |
| 28 | + Designed for enterprise API gateways that require: |
| 29 | + 1. Initial OAuth call to get a bearer token |
| 30 | + 2. Bearer token included in subsequent LLM API calls |
| 31 | + 3. Custom headers for routing/authentication |
| 32 | +
|
| 33 | + Example usage: |
| 34 | + llm = LLMWithGateway( |
| 35 | + model="gpt-4", |
| 36 | + base_url="https://gateway.company.com/llm/v1", |
| 37 | + gateway_auth_url="https://gateway.company.com/oauth/token", |
| 38 | + gateway_auth_headers={ |
| 39 | + "X-Client-Id": os.environ["GATEWAY_CLIENT_ID"], |
| 40 | + "X-Client-Secret": os.environ["GATEWAY_CLIENT_SECRET"], |
| 41 | + }, |
| 42 | + gateway_auth_body={"grant_type": "client_credentials"}, |
| 43 | + custom_headers={"X-Gateway-Key": os.environ["GATEWAY_API_KEY"]}, |
| 44 | + ) |
| 45 | + """ |
| 46 | + |
| 47 | + # OAuth configuration |
| 48 | + gateway_auth_url: str | None = Field( |
| 49 | + default=None, |
| 50 | + description="Identity provider URL to fetch gateway tokens (OAuth endpoint).", |
| 51 | + ) |
| 52 | + gateway_auth_method: str = Field( |
| 53 | + default="POST", |
| 54 | + description="HTTP method for identity provider requests.", |
| 55 | + ) |
| 56 | + gateway_auth_headers: dict[str, str] | None = Field( |
| 57 | + default=None, |
| 58 | + description="Headers to include when calling the identity provider.", |
| 59 | + ) |
| 60 | + gateway_auth_body: dict[str, Any] | None = Field( |
| 61 | + default=None, |
| 62 | + description="JSON body to include when calling the identity provider.", |
| 63 | + ) |
| 64 | + gateway_auth_token_path: str = Field( |
| 65 | + default="access_token", |
| 66 | + description="Dot-notation path to the token in the OAuth response (e.g., 'access_token' or 'data.token').", |
| 67 | + ) |
| 68 | + gateway_auth_token_ttl: int | None = Field( |
| 69 | + default=None, |
| 70 | + description="Token TTL in seconds. If not set, defaults to 300s (5 minutes).", |
| 71 | + ) |
| 72 | + |
| 73 | + # Token header configuration |
| 74 | + gateway_token_header: str = Field( |
| 75 | + default="Authorization", |
| 76 | + description="Header name for the gateway token (defaults to 'Authorization').", |
| 77 | + ) |
| 78 | + gateway_token_prefix: str = Field( |
| 79 | + default="Bearer ", |
| 80 | + description="Prefix prepended to the token (e.g., 'Bearer ').", |
| 81 | + ) |
| 82 | + |
| 83 | + # Custom headers for all requests |
| 84 | + custom_headers: dict[str, str] | None = Field( |
| 85 | + default=None, |
| 86 | + description="Custom headers to include with every LLM request.", |
| 87 | + ) |
| 88 | + |
| 89 | + # Private fields for token management |
| 90 | + _gateway_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) |
| 91 | + _gateway_token: str | None = PrivateAttr(default=None) |
| 92 | + _gateway_token_expiry: float | None = PrivateAttr(default=None) |
| 93 | + |
| 94 | + def model_post_init(self, __context: Any) -> None: |
| 95 | + """Initialize private fields after model validation.""" |
| 96 | + super().model_post_init(__context) |
| 97 | + self._gateway_lock = threading.Lock() |
| 98 | + self._gateway_token = None |
| 99 | + self._gateway_token_expiry = None |
| 100 | + |
| 101 | + def _completion( |
| 102 | + self, messages: list[dict], **kwargs |
| 103 | + ) -> Any: # Returns ModelResponse |
| 104 | + """Override to inject gateway authentication before calling LiteLLM.""" |
| 105 | + self._prepare_gateway_call(kwargs) |
| 106 | + return super()._completion(messages, **kwargs) |
| 107 | + |
| 108 | + def _responses( |
| 109 | + self, messages: list[dict], **kwargs |
| 110 | + ) -> Any: # Returns ResponsesAPIResponse |
| 111 | + """Override to inject gateway authentication before calling LiteLLM.""" |
| 112 | + self._prepare_gateway_call(kwargs) |
| 113 | + return super()._responses(messages, **kwargs) |
| 114 | + |
| 115 | + def _prepare_gateway_call(self, call_kwargs: dict[str, Any]) -> None: |
| 116 | + """Augment LiteLLM kwargs with gateway headers and token. |
| 117 | +
|
| 118 | + This method: |
| 119 | + 1. Fetches/refreshes OAuth token if needed |
| 120 | + 2. Adds token to headers |
| 121 | + 3. Adds custom headers |
| 122 | + 4. Performs basic template variable replacement |
| 123 | + """ |
| 124 | + if not self.gateway_auth_url and not self.custom_headers: |
| 125 | + return |
| 126 | + |
| 127 | + # Start with existing headers |
| 128 | + headers: dict[str, str] = {} |
| 129 | + existing_headers = call_kwargs.get("extra_headers") |
| 130 | + if isinstance(existing_headers, dict): |
| 131 | + headers.update(existing_headers) |
| 132 | + |
| 133 | + # Add custom headers (with template replacement) |
| 134 | + if self.custom_headers: |
| 135 | + rendered_headers = self._render_templates(self.custom_headers) |
| 136 | + if isinstance(rendered_headers, dict): |
| 137 | + headers.update(rendered_headers) |
| 138 | + |
| 139 | + # Add gateway token if OAuth is configured |
| 140 | + if self.gateway_auth_url: |
| 141 | + token_headers = self._get_gateway_token_headers() |
| 142 | + if token_headers: |
| 143 | + headers.update(token_headers) |
| 144 | + |
| 145 | + # Set headers on the call |
| 146 | + if headers: |
| 147 | + call_kwargs["extra_headers"] = headers |
| 148 | + |
| 149 | + def _get_gateway_token_headers(self) -> dict[str, str]: |
| 150 | + """Get headers containing the gateway token.""" |
| 151 | + token = self._ensure_gateway_token() |
| 152 | + if not token: |
| 153 | + return {} |
| 154 | + |
| 155 | + header_name = self.gateway_token_header |
| 156 | + prefix = self.gateway_token_prefix |
| 157 | + value = f"{prefix}{token}" if prefix else token |
| 158 | + return {header_name: value} |
| 159 | + |
| 160 | + def _ensure_gateway_token(self) -> str | None: |
| 161 | + """Ensure we have a valid gateway token, refreshing if needed. |
| 162 | +
|
| 163 | + Returns: |
| 164 | + Valid gateway token, or None if gateway auth is not configured. |
| 165 | + """ |
| 166 | + if not self.gateway_auth_url: |
| 167 | + return None |
| 168 | + |
| 169 | + # Fast path: check if current token is still valid (with 5s buffer) |
| 170 | + now = time.time() |
| 171 | + if ( |
| 172 | + self._gateway_token |
| 173 | + and self._gateway_token_expiry |
| 174 | + and now < self._gateway_token_expiry - 5 |
| 175 | + ): |
| 176 | + return self._gateway_token |
| 177 | + |
| 178 | + # Slow path: acquire lock and refresh token |
| 179 | + with self._gateway_lock: |
| 180 | + # Double-check after acquiring lock |
| 181 | + if ( |
| 182 | + self._gateway_token |
| 183 | + and self._gateway_token_expiry |
| 184 | + and time.time() < self._gateway_token_expiry - 5 |
| 185 | + ): |
| 186 | + return self._gateway_token |
| 187 | + |
| 188 | + # Refresh token |
| 189 | + return self._refresh_gateway_token() |
| 190 | + |
| 191 | + def _refresh_gateway_token(self) -> str: |
| 192 | + """Fetch a new gateway token from the identity provider. |
| 193 | +
|
| 194 | + This method is called while holding _gateway_lock. |
| 195 | +
|
| 196 | + Returns: |
| 197 | + Fresh gateway token. |
| 198 | +
|
| 199 | + Raises: |
| 200 | + Exception: If token fetch fails. |
| 201 | + """ |
| 202 | + method = self.gateway_auth_method.upper() |
| 203 | + headers = self._render_templates(self.gateway_auth_headers or {}) |
| 204 | + body = self._render_templates(self.gateway_auth_body or {}) |
| 205 | + |
| 206 | + logger.debug( |
| 207 | + f"Fetching gateway token from {self.gateway_auth_url} " |
| 208 | + f"(method={method})" |
| 209 | + ) |
| 210 | + |
| 211 | + try: |
| 212 | + response = httpx.request( |
| 213 | + method, |
| 214 | + self.gateway_auth_url, |
| 215 | + headers=headers if isinstance(headers, dict) else None, |
| 216 | + json=body if isinstance(body, dict) else None, |
| 217 | + timeout=self.timeout or 30, |
| 218 | + ) |
| 219 | + response.raise_for_status() |
| 220 | + except Exception as exc: |
| 221 | + logger.error(f"Gateway auth request failed: {exc}") |
| 222 | + raise |
| 223 | + |
| 224 | + try: |
| 225 | + payload = response.json() |
| 226 | + except Exception as exc: |
| 227 | + logger.error(f"Failed to parse gateway auth response JSON: {exc}") |
| 228 | + raise |
| 229 | + |
| 230 | + # Extract token from response |
| 231 | + token_path = self.gateway_auth_token_path |
| 232 | + token_value = self._extract_from_path(payload, token_path) |
| 233 | + if not isinstance(token_value, str) or not token_value.strip(): |
| 234 | + raise ValueError( |
| 235 | + f"Gateway auth response did not contain token at path " |
| 236 | + f'"{token_path}". Response: {payload}' |
| 237 | + ) |
| 238 | + |
| 239 | + # Determine TTL |
| 240 | + ttl_seconds = float(self.gateway_auth_token_ttl or 300) |
| 241 | + |
| 242 | + # Update cache |
| 243 | + self._gateway_token = token_value.strip() |
| 244 | + self._gateway_token_expiry = time.time() + max(ttl_seconds, 1.0) |
| 245 | + |
| 246 | + logger.info( |
| 247 | + f"Gateway token refreshed successfully (expires in {ttl_seconds}s)" |
| 248 | + ) |
| 249 | + return self._gateway_token |
| 250 | + |
| 251 | + def _render_templates(self, value: Any) -> Any: |
| 252 | + """Replace template variables in strings with actual values. |
| 253 | +
|
| 254 | + Supports: |
| 255 | + - {{llm_model}} -> self.model |
| 256 | + - {{llm_base_url}} -> self.base_url |
| 257 | + - {{llm_api_key}} -> self.api_key (if set) |
| 258 | +
|
| 259 | + Args: |
| 260 | + value: String, dict, list, or other value to render. |
| 261 | +
|
| 262 | + Returns: |
| 263 | + Value with templates replaced. |
| 264 | + """ |
| 265 | + if isinstance(value, str): |
| 266 | + replacements: dict[str, str] = { |
| 267 | + "{{llm_model}}": self.model, |
| 268 | + "{{llm_base_url}}": self.base_url or "", |
| 269 | + } |
| 270 | + if self.api_key: |
| 271 | + replacements["{{llm_api_key}}"] = self.api_key.get_secret_value() |
| 272 | + |
| 273 | + result = value |
| 274 | + for placeholder, actual in replacements.items(): |
| 275 | + result = result.replace(placeholder, actual) |
| 276 | + return result |
| 277 | + |
| 278 | + if isinstance(value, dict): |
| 279 | + return {k: self._render_templates(v) for k, v in value.items()} |
| 280 | + |
| 281 | + if isinstance(value, list): |
| 282 | + return [self._render_templates(v) for v in value] |
| 283 | + |
| 284 | + return value |
| 285 | + |
| 286 | + @staticmethod |
| 287 | + def _extract_from_path(payload: Any, path: str) -> Any: |
| 288 | + """Extract a value from nested dict/list using dot notation. |
| 289 | +
|
| 290 | + Examples: |
| 291 | + _extract_from_path({"a": {"b": "value"}}, "a.b") -> "value" |
| 292 | + _extract_from_path({"data": [{"token": "x"}]}, "data.0.token") -> "x" |
| 293 | +
|
| 294 | + Args: |
| 295 | + payload: Dict or list to traverse. |
| 296 | + path: Dot-separated path (e.g., "data.token" or "items.0.value"). |
| 297 | +
|
| 298 | + Returns: |
| 299 | + Value at the specified path. |
| 300 | +
|
| 301 | + Raises: |
| 302 | + ValueError: If path cannot be traversed. |
| 303 | + """ |
| 304 | + current = payload |
| 305 | + if not path: |
| 306 | + return current |
| 307 | + |
| 308 | + for part in path.split("."): |
| 309 | + if isinstance(current, dict): |
| 310 | + current = current.get(part) |
| 311 | + if current is None: |
| 312 | + raise ValueError( |
| 313 | + f'Key "{part}" not found in response while traversing path "{path}".' |
| 314 | + ) |
| 315 | + elif isinstance(current, list): |
| 316 | + try: |
| 317 | + index = int(part) |
| 318 | + except (ValueError, TypeError): |
| 319 | + raise ValueError( |
| 320 | + f'Invalid list index "{part}" while traversing response path "{path}".' |
| 321 | + ) from None |
| 322 | + try: |
| 323 | + current = current[index] |
| 324 | + except (IndexError, TypeError): |
| 325 | + raise ValueError( |
| 326 | + f'Index {index} out of range while traversing response path "{path}".' |
| 327 | + ) from None |
| 328 | + else: |
| 329 | + raise ValueError( |
| 330 | + f'Cannot traverse path "{path}"; segment "{part}" not found or not accessible.' |
| 331 | + ) |
| 332 | + |
| 333 | + return current |
0 commit comments