Skip to content

Commit 660f8a0

Browse files
committed
fix: resolve linting and type errors in LLMWithGateway
- Fix line length (E501) errors by splitting long strings - Fix pyright type errors with proper type assertions and overrides - Fix method overrides to use completion/responses instead of _completion/_responses - Update tests to use proper Message objects instead of dicts - Fix test assertions to match actual error messages - All 25 tests now passing
1 parent a203fdb commit 660f8a0

File tree

2 files changed

+27
-24
lines changed

2 files changed

+27
-24
lines changed

openhands-sdk/openhands/sdk/llm/llm_with_gateway.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from openhands.sdk.llm.llm import LLM
1717
from openhands.sdk.logger import get_logger
1818

19+
1920
logger = get_logger(__name__)
2021

2122
__all__ = ["LLMWithGateway"]
@@ -63,7 +64,10 @@ class LLMWithGateway(LLM):
6364
)
6465
gateway_auth_token_path: str = Field(
6566
default="access_token",
66-
description="Dot-notation path to the token in the OAuth response (e.g., 'access_token' or 'data.token').",
67+
description=(
68+
"Dot-notation path to the token in the OAuth response "
69+
"(e.g., 'access_token' or 'data.token')."
70+
),
6771
)
6872
gateway_auth_token_ttl: int | None = Field(
6973
default=None,
@@ -98,19 +102,15 @@ def model_post_init(self, __context: Any) -> None:
98102
self._gateway_token = None
99103
self._gateway_token_expiry = None
100104

101-
def _completion(
102-
self, messages: list[dict], **kwargs
103-
) -> Any: # Returns ModelResponse
105+
def completion(self, *args, **kwargs):
104106
"""Override to inject gateway authentication before calling LiteLLM."""
105107
self._prepare_gateway_call(kwargs)
106-
return super()._completion(messages, **kwargs)
108+
return super().completion(*args, **kwargs)
107109

108-
def _responses(
109-
self, messages: list[dict], **kwargs
110-
) -> Any: # Returns ResponsesAPIResponse
110+
def responses(self, *args, **kwargs):
111111
"""Override to inject gateway authentication before calling LiteLLM."""
112112
self._prepare_gateway_call(kwargs)
113-
return super()._responses(messages, **kwargs)
113+
return super().responses(*args, **kwargs)
114114

115115
def _prepare_gateway_call(self, call_kwargs: dict[str, Any]) -> None:
116116
"""Augment LiteLLM kwargs with gateway headers and token.
@@ -199,13 +199,13 @@ def _refresh_gateway_token(self) -> str:
199199
Raises:
200200
Exception: If token fetch fails.
201201
"""
202+
assert self.gateway_auth_url is not None, "gateway_auth_url must be set"
202203
method = self.gateway_auth_method.upper()
203204
headers = self._render_templates(self.gateway_auth_headers or {})
204205
body = self._render_templates(self.gateway_auth_body or {})
205206

206207
logger.debug(
207-
f"Fetching gateway token from {self.gateway_auth_url} "
208-
f"(method={method})"
208+
f"Fetching gateway token from {self.gateway_auth_url} (method={method})"
209209
)
210210

211211
try:
@@ -243,9 +243,7 @@ def _refresh_gateway_token(self) -> str:
243243
self._gateway_token = token_value.strip()
244244
self._gateway_token_expiry = time.time() + max(ttl_seconds, 1.0)
245245

246-
logger.info(
247-
f"Gateway token refreshed successfully (expires in {ttl_seconds}s)"
248-
)
246+
logger.info(f"Gateway token refreshed successfully (expires in {ttl_seconds}s)")
249247
return self._gateway_token
250248

251249
def _render_templates(self, value: Any) -> Any:
@@ -310,24 +308,28 @@ def _extract_from_path(payload: Any, path: str) -> Any:
310308
current = current.get(part)
311309
if current is None:
312310
raise ValueError(
313-
f'Key "{part}" not found in response while traversing path "{path}".'
311+
f'Key "{part}" not found in response '
312+
f'while traversing path "{path}".'
314313
)
315314
elif isinstance(current, list):
316315
try:
317316
index = int(part)
318317
except (ValueError, TypeError):
319318
raise ValueError(
320-
f'Invalid list index "{part}" while traversing response path "{path}".'
319+
f'Invalid list index "{part}" '
320+
f'while traversing response path "{path}".'
321321
) from None
322322
try:
323323
current = current[index]
324324
except (IndexError, TypeError):
325325
raise ValueError(
326-
f'Index {index} out of range while traversing response path "{path}".'
326+
f"Index {index} out of range "
327+
f'while traversing response path "{path}".'
327328
) from None
328329
else:
329330
raise ValueError(
330-
f'Cannot traverse path "{path}"; segment "{part}" not found or not accessible.'
331+
f'Cannot traverse path "{path}"; '
332+
f'segment "{part}" not found or not accessible.'
331333
)
332334

333335
return current

tests/sdk/llm/test_llm_with_gateway.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99
from pydantic import SecretStr
1010

11-
from openhands.sdk.llm import LLMWithGateway
11+
from openhands.sdk.llm import LLMWithGateway, Message, TextContent
1212
from tests.conftest import create_mock_litellm_response
1313

1414

@@ -178,7 +178,7 @@ def test_token_fetch_missing_token(self, mock_request, gateway_llm):
178178
mock_response.raise_for_status = Mock()
179179
mock_request.return_value = mock_response
180180

181-
with pytest.raises(ValueError, match="did not contain token"):
181+
with pytest.raises(ValueError, match="not found in response"):
182182
gateway_llm._ensure_gateway_token()
183183

184184

@@ -203,7 +203,7 @@ def test_custom_headers_injected(
203203
)
204204

205205
# Make completion request
206-
messages = [{"role": "user", "content": "test"}]
206+
messages = [Message(role="user", content=[TextContent(text="test")])]
207207
gateway_llm.completion(messages)
208208

209209
# Verify custom headers were passed
@@ -231,7 +231,7 @@ def test_gateway_token_header_injected(
231231
)
232232

233233
# Make completion request
234-
messages = [{"role": "user", "content": "test"}]
234+
messages = [Message(role="user", content=[TextContent(text="test")])]
235235
gateway_llm.completion(messages)
236236

237237
# Verify gateway token was passed
@@ -370,7 +370,7 @@ def test_full_gateway_flow(
370370
)
371371

372372
# Make completion request
373-
messages = [{"role": "user", "content": "Hello"}]
373+
messages = [Message(role="user", content=[TextContent(text="Hello")])]
374374
response = llm.completion(messages)
375375

376376
# Verify OAuth was called
@@ -387,7 +387,8 @@ def test_full_gateway_flow(
387387
assert headers["X-Gateway-Key"] == "gateway789"
388388

389389
# Verify response
390-
assert response.content[0].text == "Hello from gateway!"
390+
assert isinstance(response.message.content[0], TextContent)
391+
assert response.message.content[0].text == "Hello from gateway!"
391392

392393
def test_gateway_disabled_when_no_config(self):
393394
"""Test that gateway logic is skipped when not configured."""

0 commit comments

Comments
 (0)