-
Notifications
You must be signed in to change notification settings - Fork 47
feat: add enterprise gateway support for LLM providers #963
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ad9752c
4cfd5a4
28eb5a1
ec34c93
1029d4f
10fdfbc
00e9ae8
c52bb3e
4b9b99d
aa373aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -162,6 +162,14 @@ class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin): | |||||
| ) | ||||||
| ollama_base_url: str | None = Field(default=None) | ||||||
|
|
||||||
| ssl_verify: bool | str | None = Field( | ||||||
| default=None, | ||||||
| description=( | ||||||
| "TLS verification forwarded to LiteLLM; " | ||||||
| "set to False when corporate proxies break certificate chains." | ||||||
| ), | ||||||
| ) | ||||||
|
|
||||||
| drop_params: bool = Field(default=True) | ||||||
| modify_params: bool = Field( | ||||||
| default=True, | ||||||
|
|
@@ -446,15 +454,19 @@ def completion( | |||||
| has_tools_flag = bool(cc_tools) and use_native_fc | ||||||
| # Behavior-preserving: delegate to select_chat_options | ||||||
| call_kwargs = select_chat_options(self, kwargs, has_tools=has_tools_flag) | ||||||
| call_kwargs = self._prepare_request_kwargs(call_kwargs) | ||||||
|
|
||||||
| # 4) optional request logging context (kept small) | ||||||
| assert self._telemetry is not None | ||||||
| log_ctx = None | ||||||
| if self._telemetry.log_enabled: | ||||||
| sanitized_kwargs = { | ||||||
| k: v for k, v in call_kwargs.items() if k != "extra_headers" | ||||||
| } | ||||||
| log_ctx = { | ||||||
| "messages": formatted_messages[:], # already simple dicts | ||||||
| "tools": tools, | ||||||
| "kwargs": {k: v for k, v in call_kwargs.items()}, | ||||||
| "kwargs": sanitized_kwargs, | ||||||
| "context_window": self.max_input_tokens or 0, | ||||||
| } | ||||||
| if tools and not use_native_fc: | ||||||
|
|
@@ -473,7 +485,7 @@ def completion( | |||||
| def _one_attempt(**retry_kwargs) -> ModelResponse: | ||||||
| assert self._telemetry is not None | ||||||
| # Merge retry-modified kwargs (like temperature) with call_kwargs | ||||||
| final_kwargs = {**call_kwargs, **retry_kwargs} | ||||||
| final_kwargs = self._prepare_request_kwargs({**call_kwargs, **retry_kwargs}) | ||||||
| resp = self._transport_call(messages=formatted_messages, **final_kwargs) | ||||||
| raw_resp: ModelResponse | None = None | ||||||
| if use_mock_tools: | ||||||
|
|
@@ -557,16 +569,20 @@ def responses( | |||||
| call_kwargs = select_responses_options( | ||||||
| self, kwargs, include=include, store=store | ||||||
| ) | ||||||
| call_kwargs = self._prepare_request_kwargs(call_kwargs) | ||||||
|
|
||||||
| # Optional request logging | ||||||
| assert self._telemetry is not None | ||||||
| log_ctx = None | ||||||
| if self._telemetry.log_enabled: | ||||||
| sanitized_kwargs = { | ||||||
| k: v for k, v in call_kwargs.items() if k != "extra_headers" | ||||||
| } | ||||||
| log_ctx = { | ||||||
| "llm_path": "responses", | ||||||
| "input": input_items[:], | ||||||
| "tools": tools, | ||||||
| "kwargs": {k: v for k, v in call_kwargs.items()}, | ||||||
| "kwargs": sanitized_kwargs, | ||||||
| "context_window": self.max_input_tokens or 0, | ||||||
| } | ||||||
| self._telemetry.on_request(log_ctx=log_ctx) | ||||||
|
|
@@ -581,7 +597,7 @@ def responses( | |||||
| retry_listener=self.retry_listener, | ||||||
| ) | ||||||
| def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse: | ||||||
| final_kwargs = {**call_kwargs, **retry_kwargs} | ||||||
| final_kwargs = self._prepare_request_kwargs({**call_kwargs, **retry_kwargs}) | ||||||
| with self._litellm_modify_params_ctx(self.modify_params): | ||||||
| with warnings.catch_warnings(): | ||||||
| warnings.filterwarnings("ignore", category=DeprecationWarning) | ||||||
|
|
@@ -598,15 +614,17 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse: | |||||
| else None, | ||||||
| api_base=self.base_url, | ||||||
| api_version=self.api_version, | ||||||
| custom_llm_provider=self.custom_llm_provider, | ||||||
| timeout=self.timeout, | ||||||
| ssl_verify=self.ssl_verify, | ||||||
| drop_params=self.drop_params, | ||||||
| seed=self.seed, | ||||||
| **final_kwargs, | ||||||
| ) | ||||||
| assert isinstance(ret, ResponsesAPIResponse), ( | ||||||
| f"Expected ResponsesAPIResponse, got {type(ret)}" | ||||||
| ) | ||||||
| # telemetry (latency, cost). Token usage mapping we handle after. | ||||||
| # telemetry (latency, cost). Token usage handled after. | ||||||
| assert self._telemetry is not None | ||||||
| self._telemetry.on_response(ret) | ||||||
| return ret | ||||||
|
|
@@ -637,6 +655,11 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse: | |||||
| # ========================================================================= | ||||||
| # Transport + helpers | ||||||
| # ========================================================================= | ||||||
| def _prepare_request_kwargs(self, call_kwargs: dict[str, Any]) -> dict[str, Any]: | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
Nit: if it's a hook, it probably should be part of public API |
||||||
| """Hook for subclasses to adjust final LiteLLM kwargs.""" | ||||||
|
|
||||||
| return call_kwargs | ||||||
|
|
||||||
| def _transport_call( | ||||||
| self, *, messages: list[dict[str, Any]], **kwargs | ||||||
| ) -> ModelResponse: | ||||||
|
|
@@ -666,7 +689,9 @@ def _transport_call( | |||||
| api_key=self.api_key.get_secret_value() if self.api_key else None, | ||||||
| base_url=self.base_url, | ||||||
| api_version=self.api_version, | ||||||
| custom_llm_provider=self.custom_llm_provider, | ||||||
| timeout=self.timeout, | ||||||
| ssl_verify=self.ssl_verify, | ||||||
| drop_params=self.drop_params, | ||||||
| seed=self.seed, | ||||||
| messages=messages, | ||||||
|
|
@@ -928,6 +953,7 @@ def load_from_json(cls, json_path: str) -> LLM: | |||||
| @classmethod | ||||||
| def load_from_env(cls, prefix: str = "LLM_") -> LLM: | ||||||
| TRUTHY = {"true", "1", "yes", "on"} | ||||||
| FALSY = {"false", "0", "no", "off"} | ||||||
|
|
||||||
| def _unwrap_type(t: Any) -> Any: | ||||||
| origin = get_origin(t) | ||||||
|
|
@@ -936,31 +962,44 @@ def _unwrap_type(t: Any) -> Any: | |||||
| args = [a for a in get_args(t) if a is not type(None)] | ||||||
| return args[0] if args else t | ||||||
|
|
||||||
| def _cast_value(raw: str, t: Any) -> Any: | ||||||
| t = _unwrap_type(t) | ||||||
| def _cast_value(field_name: str, raw: str, annotation: Any) -> Any: | ||||||
| stripped = raw.strip() | ||||||
| lowered = stripped.lower() | ||||||
| if field_name == "ssl_verify": | ||||||
| if lowered in TRUTHY: | ||||||
| return True | ||||||
| if lowered in FALSY: | ||||||
| return False | ||||||
| return stripped | ||||||
|
|
||||||
| t = _unwrap_type(annotation) | ||||||
| if t is SecretStr: | ||||||
| return SecretStr(raw) | ||||||
| return SecretStr(stripped) | ||||||
| if t is bool: | ||||||
| return raw.lower() in TRUTHY | ||||||
| if lowered in TRUTHY: | ||||||
| return True | ||||||
| if lowered in FALSY: | ||||||
| return False | ||||||
| return None | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry to be dense, could you perhaps tell why we needed None for bool? It seems that in the subclass, the new field |
||||||
| if t is int: | ||||||
| try: | ||||||
| return int(raw) | ||||||
| return int(stripped) | ||||||
| except ValueError: | ||||||
| return None | ||||||
| if t is float: | ||||||
| try: | ||||||
| return float(raw) | ||||||
| return float(stripped) | ||||||
| except ValueError: | ||||||
| return None | ||||||
| origin = get_origin(t) | ||||||
| if (origin in (list, dict, tuple)) or ( | ||||||
| isinstance(t, type) and issubclass(t, BaseModel) | ||||||
| ): | ||||||
| try: | ||||||
| return json.loads(raw) | ||||||
| return json.loads(stripped) | ||||||
| except Exception: | ||||||
| pass | ||||||
| return raw | ||||||
| return stripped | ||||||
|
|
||||||
| data: dict[str, Any] = {} | ||||||
| fields: dict[str, Any] = { | ||||||
|
|
@@ -975,7 +1014,7 @@ def _cast_value(raw: str, t: Any) -> Any: | |||||
| field_name = key[len(prefix) :].lower() | ||||||
| if field_name not in fields: | ||||||
| continue | ||||||
| v = _cast_value(value, fields[field_name]) | ||||||
| v = _cast_value(field_name, value, fields[field_name]) | ||||||
| if v is not None: | ||||||
| data[field_name] = v | ||||||
| return cls(**data) | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| """LLM subclass with enterprise gateway support. | ||
|
|
||
| This module provides LLMWithGateway, which extends the base LLM class to support | ||
| custom headers for enterprise API gateways. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Mapping | ||
| from typing import Any | ||
|
|
||
| from pydantic import Field | ||
|
|
||
| from openhands.sdk.llm.llm import LLM | ||
| from openhands.sdk.logger import get_logger | ||
|
|
||
|
|
||
| __all__ = ["LLMWithGateway"] | ||
|
|
||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
|
|
||
| class LLMWithGateway(LLM): | ||
| """LLM subclass with enterprise gateway support. | ||
|
|
||
| Supports adding static custom headers on each request. Take care not to include | ||
| raw secrets in headers unless the gateway is trusted and headers are never logged. | ||
| """ | ||
|
|
||
| custom_headers: dict[str, str] | None = Field( | ||
| default=None, | ||
| description="Custom headers to include with every LLM request.", | ||
| ) | ||
|
|
||
| def _prepare_request_kwargs(self, call_kwargs: dict[str, Any]) -> dict[str, Any]: | ||
| prepared = dict(super()._prepare_request_kwargs(call_kwargs)) | ||
|
|
||
| if not self.custom_headers: | ||
| return prepared | ||
|
|
||
| existing = prepared.get("extra_headers") | ||
| base_headers: dict[str, Any] | ||
| if isinstance(existing, Mapping): | ||
| base_headers = dict(existing) | ||
| elif existing is None: | ||
| base_headers = {} | ||
| else: | ||
| base_headers = {} | ||
|
|
||
| merged, collisions = self._merge_headers(base_headers, self.custom_headers) | ||
| for header, old_val, new_val in collisions: | ||
| logger.warning( | ||
| "LLMWithGateway overriding header %s (existing=%r, new=%r)", | ||
| header, | ||
| old_val, | ||
| new_val, | ||
| ) | ||
|
|
||
| if merged: | ||
| prepared["extra_headers"] = merged | ||
|
|
||
| return prepared | ||
|
|
||
| @staticmethod | ||
| def _merge_headers( | ||
| existing: dict[str, Any], new_headers: dict[str, Any] | ||
| ) -> tuple[dict[str, Any], list[tuple[str, Any, Any]]]: | ||
| """Merge header dictionaries case-insensitively. | ||
|
|
||
| Returns the merged headers and a list of collisions where an existing | ||
| header was replaced with a different value. | ||
| """ | ||
|
|
||
| merged = dict(existing) | ||
| lower_map = {k.lower(): k for k in merged} | ||
| collisions: list[tuple[str, Any, Any]] = [] | ||
|
|
||
| for key, value in new_headers.items(): | ||
| lower = key.lower() | ||
| if lower in lower_map: | ||
| canonical = lower_map[lower] | ||
| old_value = merged[canonical] | ||
| if old_value != value: | ||
| collisions.append((canonical, old_value, value)) | ||
| merged[canonical] = value | ||
| else: | ||
| merged[key] = value | ||
| lower_map[lower] = key | ||
|
|
||
| return merged, collisions |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be better done in
`telemetry.py`, because excluding all `extra_headers` means also excluding Anthropic's extended thinking from telemetry, and any other non-auth/non-authz client code use 🤔 Maybe we could do it like this, and fix it in a follow-up?