|
16 | 16 | from openhands.sdk.llm.llm import LLM |
17 | 17 | from openhands.sdk.logger import get_logger |
18 | 18 |
|
| 19 | + |
19 | 20 | logger = get_logger(__name__) |
20 | 21 |
|
21 | 22 | __all__ = ["LLMWithGateway"] |
@@ -63,7 +64,10 @@ class LLMWithGateway(LLM): |
63 | 64 | ) |
64 | 65 | gateway_auth_token_path: str = Field( |
65 | 66 | default="access_token", |
66 | | - description="Dot-notation path to the token in the OAuth response (e.g., 'access_token' or 'data.token').", |
| 67 | + description=( |
| 68 | + "Dot-notation path to the token in the OAuth response " |
| 69 | + "(e.g., 'access_token' or 'data.token')." |
| 70 | + ), |
67 | 71 | ) |
68 | 72 | gateway_auth_token_ttl: int | None = Field( |
69 | 73 | default=None, |
@@ -98,19 +102,15 @@ def model_post_init(self, __context: Any) -> None: |
98 | 102 | self._gateway_token = None |
99 | 103 | self._gateway_token_expiry = None |
100 | 104 |
|
101 | | - def _completion( |
102 | | - self, messages: list[dict], **kwargs |
103 | | - ) -> Any: # Returns ModelResponse |
| 105 | + def completion(self, *args, **kwargs): |
104 | 106 | """Override to inject gateway authentication before calling LiteLLM.""" |
105 | 107 | self._prepare_gateway_call(kwargs) |
106 | | - return super()._completion(messages, **kwargs) |
| 108 | + return super().completion(*args, **kwargs) |
107 | 109 |
|
108 | | - def _responses( |
109 | | - self, messages: list[dict], **kwargs |
110 | | - ) -> Any: # Returns ResponsesAPIResponse |
| 110 | + def responses(self, *args, **kwargs): |
111 | 111 | """Override to inject gateway authentication before calling LiteLLM.""" |
112 | 112 | self._prepare_gateway_call(kwargs) |
113 | | - return super()._responses(messages, **kwargs) |
| 113 | + return super().responses(*args, **kwargs) |
114 | 114 |
|
115 | 115 | def _prepare_gateway_call(self, call_kwargs: dict[str, Any]) -> None: |
116 | 116 | """Augment LiteLLM kwargs with gateway headers and token. |
@@ -199,13 +199,13 @@ def _refresh_gateway_token(self) -> str: |
199 | 199 | Raises: |
200 | 200 | Exception: If token fetch fails. |
201 | 201 | """ |
| 202 | + assert self.gateway_auth_url is not None, "gateway_auth_url must be set" |
202 | 203 | method = self.gateway_auth_method.upper() |
203 | 204 | headers = self._render_templates(self.gateway_auth_headers or {}) |
204 | 205 | body = self._render_templates(self.gateway_auth_body or {}) |
205 | 206 |
|
206 | 207 | logger.debug( |
207 | | - f"Fetching gateway token from {self.gateway_auth_url} " |
208 | | - f"(method={method})" |
| 208 | + f"Fetching gateway token from {self.gateway_auth_url} (method={method})" |
209 | 209 | ) |
210 | 210 |
|
211 | 211 | try: |
@@ -243,9 +243,7 @@ def _refresh_gateway_token(self) -> str: |
243 | 243 | self._gateway_token = token_value.strip() |
244 | 244 | self._gateway_token_expiry = time.time() + max(ttl_seconds, 1.0) |
245 | 245 |
|
246 | | - logger.info( |
247 | | - f"Gateway token refreshed successfully (expires in {ttl_seconds}s)" |
248 | | - ) |
| 246 | + logger.info(f"Gateway token refreshed successfully (expires in {ttl_seconds}s)") |
249 | 247 | return self._gateway_token |
250 | 248 |
|
251 | 249 | def _render_templates(self, value: Any) -> Any: |
@@ -310,24 +308,28 @@ def _extract_from_path(payload: Any, path: str) -> Any: |
310 | 308 | current = current.get(part) |
311 | 309 | if current is None: |
312 | 310 | raise ValueError( |
313 | | - f'Key "{part}" not found in response while traversing path "{path}".' |
| 311 | + f'Key "{part}" not found in response ' |
| 312 | + f'while traversing path "{path}".' |
314 | 313 | ) |
315 | 314 | elif isinstance(current, list): |
316 | 315 | try: |
317 | 316 | index = int(part) |
318 | 317 | except (ValueError, TypeError): |
319 | 318 | raise ValueError( |
320 | | - f'Invalid list index "{part}" while traversing response path "{path}".' |
| 319 | + f'Invalid list index "{part}" ' |
| 320 | + f'while traversing response path "{path}".' |
321 | 321 | ) from None |
322 | 322 | try: |
323 | 323 | current = current[index] |
324 | 324 | except (IndexError, TypeError): |
325 | 325 | raise ValueError( |
326 | | - f'Index {index} out of range while traversing response path "{path}".' |
| 326 | + f"Index {index} out of range " |
| 327 | + f'while traversing response path "{path}".' |
327 | 328 | ) from None |
328 | 329 | else: |
329 | 330 | raise ValueError( |
330 | | - f'Cannot traverse path "{path}"; segment "{part}" not found or not accessible.' |
| 331 | + f'Cannot traverse path "{path}"; ' |
| 332 | + f'segment "{part}" not found or not accessible.' |
331 | 333 | ) |
332 | 334 |
|
333 | 335 | return current |
0 commit comments