Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions openhands-sdk/openhands/sdk/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from openhands.sdk.llm.llm import LLM
from openhands.sdk.llm.llm_registry import LLMRegistry, RegistryEvent
from openhands.sdk.llm.llm_response import LLMResponse
from openhands.sdk.llm.llm_with_gateway import LLMWithGateway
from openhands.sdk.llm.message import (
ImageContent,
Message,
Expand All @@ -23,6 +24,7 @@
__all__ = [
"LLMResponse",
"LLM",
"LLMWithGateway",
"LLMRegistry",
"RouterLLM",
"RegistryEvent",
Expand Down
67 changes: 53 additions & 14 deletions openhands-sdk/openhands/sdk/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,14 @@ class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin):
)
ollama_base_url: str | None = Field(default=None)

ssl_verify: bool | str | None = Field(
default=None,
description=(
"TLS verification forwarded to LiteLLM; "
"set to False when corporate proxies break certificate chains."
),
)

drop_params: bool = Field(default=True)
modify_params: bool = Field(
default=True,
Expand Down Expand Up @@ -446,15 +454,19 @@ def completion(
has_tools_flag = bool(cc_tools) and use_native_fc
# Behavior-preserving: delegate to select_chat_options
call_kwargs = select_chat_options(self, kwargs, has_tools=has_tools_flag)
call_kwargs = self._prepare_request_kwargs(call_kwargs)

# 4) optional request logging context (kept small)
assert self._telemetry is not None
log_ctx = None
if self._telemetry.log_enabled:
sanitized_kwargs = {
k: v for k, v in call_kwargs.items() if k != "extra_headers"
}
log_ctx = {
"messages": formatted_messages[:], # already simple dicts
"tools": tools,
"kwargs": {k: v for k, v in call_kwargs.items()},
"kwargs": sanitized_kwargs,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be better done in telemetry.py: excluding all extra_headers also strips Anthropic's extended-thinking headers from telemetry, along with any other non-auth/non-authz headers set by client code 🤔

Maybe we could do it like this, and fix it in a follow-up?

"context_window": self.max_input_tokens or 0,
}
if tools and not use_native_fc:
Expand All @@ -473,7 +485,7 @@ def completion(
def _one_attempt(**retry_kwargs) -> ModelResponse:
assert self._telemetry is not None
# Merge retry-modified kwargs (like temperature) with call_kwargs
final_kwargs = {**call_kwargs, **retry_kwargs}
final_kwargs = self._prepare_request_kwargs({**call_kwargs, **retry_kwargs})
resp = self._transport_call(messages=formatted_messages, **final_kwargs)
raw_resp: ModelResponse | None = None
if use_mock_tools:
Expand Down Expand Up @@ -557,16 +569,20 @@ def responses(
call_kwargs = select_responses_options(
self, kwargs, include=include, store=store
)
call_kwargs = self._prepare_request_kwargs(call_kwargs)

# Optional request logging
assert self._telemetry is not None
log_ctx = None
if self._telemetry.log_enabled:
sanitized_kwargs = {
k: v for k, v in call_kwargs.items() if k != "extra_headers"
}
log_ctx = {
"llm_path": "responses",
"input": input_items[:],
"tools": tools,
"kwargs": {k: v for k, v in call_kwargs.items()},
"kwargs": sanitized_kwargs,
"context_window": self.max_input_tokens or 0,
}
self._telemetry.on_request(log_ctx=log_ctx)
Expand All @@ -581,7 +597,7 @@ def responses(
retry_listener=self.retry_listener,
)
def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse:
final_kwargs = {**call_kwargs, **retry_kwargs}
final_kwargs = self._prepare_request_kwargs({**call_kwargs, **retry_kwargs})
with self._litellm_modify_params_ctx(self.modify_params):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
Expand All @@ -598,15 +614,17 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse:
else None,
api_base=self.base_url,
api_version=self.api_version,
custom_llm_provider=self.custom_llm_provider,
timeout=self.timeout,
ssl_verify=self.ssl_verify,
drop_params=self.drop_params,
seed=self.seed,
**final_kwargs,
)
assert isinstance(ret, ResponsesAPIResponse), (
f"Expected ResponsesAPIResponse, got {type(ret)}"
)
# telemetry (latency, cost). Token usage mapping we handle after.
# telemetry (latency, cost). Token usage handled after.
assert self._telemetry is not None
self._telemetry.on_response(ret)
return ret
Expand Down Expand Up @@ -637,6 +655,11 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse:
# =========================================================================
# Transport + helpers
# =========================================================================
def _prepare_request_kwargs(self, call_kwargs: dict[str, Any]) -> dict[str, Any]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _prepare_request_kwargs(self, call_kwargs: dict[str, Any]) -> dict[str, Any]:
def prepare_request_kwargs(self, call_kwargs: dict[str, Any]) -> dict[str, Any]:

Nit: if it's a hook, it probably should be part of public API

"""Hook for subclasses to adjust final LiteLLM kwargs."""

return call_kwargs

def _transport_call(
self, *, messages: list[dict[str, Any]], **kwargs
) -> ModelResponse:
Expand Down Expand Up @@ -666,7 +689,9 @@ def _transport_call(
api_key=self.api_key.get_secret_value() if self.api_key else None,
base_url=self.base_url,
api_version=self.api_version,
custom_llm_provider=self.custom_llm_provider,
timeout=self.timeout,
ssl_verify=self.ssl_verify,
drop_params=self.drop_params,
seed=self.seed,
messages=messages,
Expand Down Expand Up @@ -928,6 +953,7 @@ def load_from_json(cls, json_path: str) -> LLM:
@classmethod
def load_from_env(cls, prefix: str = "LLM_") -> LLM:
TRUTHY = {"true", "1", "yes", "on"}
FALSY = {"false", "0", "no", "off"}

def _unwrap_type(t: Any) -> Any:
origin = get_origin(t)
Expand All @@ -936,31 +962,44 @@ def _unwrap_type(t: Any) -> Any:
args = [a for a in get_args(t) if a is not type(None)]
return args[0] if args else t

def _cast_value(raw: str, t: Any) -> Any:
t = _unwrap_type(t)
def _cast_value(field_name: str, raw: str, annotation: Any) -> Any:
stripped = raw.strip()
lowered = stripped.lower()
if field_name == "ssl_verify":
if lowered in TRUTHY:
return True
if lowered in FALSY:
return False
return stripped

t = _unwrap_type(annotation)
if t is SecretStr:
return SecretStr(raw)
return SecretStr(stripped)
if t is bool:
return raw.lower() in TRUTHY
if lowered in TRUTHY:
return True
if lowered in FALSY:
return False
return None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry to be dense, could you perhaps tell why we needed None for bool?

It seems that in the subclass, the new field ssl_verify is None, and if that is None, do we still need these changes?

if t is int:
try:
return int(raw)
return int(stripped)
except ValueError:
return None
if t is float:
try:
return float(raw)
return float(stripped)
except ValueError:
return None
origin = get_origin(t)
if (origin in (list, dict, tuple)) or (
isinstance(t, type) and issubclass(t, BaseModel)
):
try:
return json.loads(raw)
return json.loads(stripped)
except Exception:
pass
return raw
return stripped

data: dict[str, Any] = {}
fields: dict[str, Any] = {
Expand All @@ -975,7 +1014,7 @@ def _cast_value(raw: str, t: Any) -> Any:
field_name = key[len(prefix) :].lower()
if field_name not in fields:
continue
v = _cast_value(value, fields[field_name])
v = _cast_value(field_name, value, fields[field_name])
if v is not None:
data[field_name] = v
return cls(**data)
Expand Down
91 changes: 91 additions & 0 deletions openhands-sdk/openhands/sdk/llm/llm_with_gateway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""LLM subclass with enterprise gateway support.

This module provides LLMWithGateway, which extends the base LLM class to support
custom headers for enterprise API gateways.
"""

from __future__ import annotations

from collections.abc import Mapping
from typing import Any

from pydantic import Field

from openhands.sdk.llm.llm import LLM
from openhands.sdk.logger import get_logger


__all__ = ["LLMWithGateway"]


logger = get_logger(__name__)


class LLMWithGateway(LLM):
    """LLM subclass with enterprise gateway support.

    Supports adding static custom headers on each request. Take care not to include
    raw secrets in headers unless the gateway is trusted and headers are never logged.
    """

    # Static headers layered onto every request's ``extra_headers`` kwarg.
    custom_headers: dict[str, str] | None = Field(
        default=None,
        description="Custom headers to include with every LLM request.",
    )

    def _prepare_request_kwargs(self, call_kwargs: dict[str, Any]) -> dict[str, Any]:
        """Merge ``custom_headers`` into the outgoing LiteLLM kwargs.

        Extends the base-class hook: starts from whatever ``extra_headers``
        mapping is already present in *call_kwargs*, layers the configured
        gateway headers on top (case-insensitively), and logs a warning for
        every header whose value gets replaced.

        Args:
            call_kwargs: The kwargs about to be passed to LiteLLM.

        Returns:
            A shallow copy of *call_kwargs* with ``extra_headers`` merged in.
        """
        prepared = dict(super()._prepare_request_kwargs(call_kwargs))

        if not self.custom_headers:
            return prepared

        existing = prepared.get("extra_headers")
        if isinstance(existing, Mapping):
            base_headers: dict[str, Any] = dict(existing)
        else:
            # A non-mapping value here is malformed; warn instead of silently
            # dropping it so misconfiguration is visible in logs.
            if existing is not None:
                logger.warning(
                    "LLMWithGateway ignoring non-mapping extra_headers: %r",
                    existing,
                )
            base_headers = {}

        merged, collisions = self._merge_headers(base_headers, self.custom_headers)
        for header, old_val, new_val in collisions:
            logger.warning(
                "LLMWithGateway overriding header %s (existing=%r, new=%r)",
                header,
                old_val,
                new_val,
            )

        if merged:
            prepared["extra_headers"] = merged

        return prepared

    @staticmethod
    def _merge_headers(
        existing: dict[str, Any], new_headers: dict[str, Any]
    ) -> tuple[dict[str, Any], list[tuple[str, Any, Any]]]:
        """Merge header dictionaries case-insensitively.

        HTTP header names are case-insensitive, so ``Authorization`` and
        ``authorization`` are treated as the same key; the casing already
        present in *existing* wins for the merged key name.

        Args:
            existing: Headers already attached to the request.
            new_headers: Gateway headers to layer on top.

        Returns:
            A ``(merged, collisions)`` tuple, where ``collisions`` lists
            ``(canonical_name, old_value, new_value)`` for every existing
            header that was replaced with a *different* value.
        """
        merged = dict(existing)
        # Map lowercase header name -> the casing used in ``merged``.
        lower_map = {k.lower(): k for k in merged}
        collisions: list[tuple[str, Any, Any]] = []

        for key, value in new_headers.items():
            lower = key.lower()
            if lower in lower_map:
                canonical = lower_map[lower]
                old_value = merged[canonical]
                if old_value != value:
                    collisions.append((canonical, old_value, value))
                merged[canonical] = value
            else:
                merged[key] = value
                lower_map[lower] = key

        return merged, collisions
Loading
Loading