diff --git a/examples/01_standalone_sdk/02_custom_tools.py b/examples/01_standalone_sdk/02_custom_tools.py index 7ff11a82e7..62292a55ba 100644 --- a/examples/01_standalone_sdk/02_custom_tools.py +++ b/examples/01_standalone_sdk/02_custom_tools.py @@ -26,8 +26,8 @@ ) from openhands.tools.execute_bash import ( BashExecutor, + BashTool, ExecuteBashAction, - execute_bash_tool, ) from openhands.tools.file_editor import FileEditorTool @@ -115,6 +115,42 @@ def __call__(self, action: GrepAction, conversation=None) -> GrepObservation: # * When you are doing an open ended search that may require multiple rounds of globbing and grepping, use the Agent tool instead """ # noqa: E501 + +# --- Tool Definition --- + + +class GrepTool(ToolDefinition[GrepAction, GrepObservation]): + """A custom grep tool that searches file contents using regular expressions.""" + + @classmethod + def create( + cls, conv_state, bash_executor: BashExecutor | None = None + ) -> Sequence[ToolDefinition]: + """Create GrepTool instance with a GrepExecutor. + + Args: + conv_state: Conversation state to get working directory from. + bash_executor: Optional bash executor to reuse. If not provided, + a new one will be created. + + Returns: + A sequence containing a single GrepTool instance. + """ + if bash_executor is None: + bash_executor = BashExecutor(working_dir=conv_state.workspace.working_dir) + grep_executor = GrepExecutor(bash_executor) + + return [ + cls( + name="grep", + description=_GREP_DESCRIPTION, + action_type=GrepAction, + observation_type=GrepObservation, + executor=grep_executor, + ) + ] + + # Configure LLM api_key = os.getenv("LLM_API_KEY") assert api_key is not None, "LLM_API_KEY environment variable is not set." @@ -135,16 +171,11 @@ def _make_bash_and_grep_tools(conv_state) -> list[ToolDefinition]: """Create execute_bash and custom grep tools sharing one executor.""" bash_executor = BashExecutor(working_dir=conv_state.workspace.working_dir) - bash_tool = execute_bash_tool.set_executor(executor=bash_executor) - - grep_executor = GrepExecutor(bash_executor) - grep_tool = ToolDefinition( - name="grep", - description=_GREP_DESCRIPTION, - action_type=GrepAction, - observation_type=GrepObservation, - executor=grep_executor, - ) + # bash_tool = execute_bash_tool.set_executor(executor=bash_executor) + bash_tool = BashTool.create(conv_state, executor=bash_executor)[0] + + # Use the GrepTool.create() method with shared bash_executor + grep_tool = GrepTool.create(conv_state, bash_executor=bash_executor)[0] return [bash_tool, grep_tool] diff --git a/openhands-sdk/openhands/sdk/__init__.py b/openhands-sdk/openhands/sdk/__init__.py index 07ec5f48a1..a07bf60c25 100644 --- a/openhands-sdk/openhands/sdk/__init__.py +++ b/openhands-sdk/openhands/sdk/__init__.py @@ -37,7 +37,6 @@ Action, Observation, Tool, - ToolBase, ToolDefinition, list_registered_tools, register_tool, @@ -67,7 +66,6 @@ "RedactedThinkingBlock", "Tool", "ToolDefinition", - "ToolBase", "AgentBase", "Agent", "Action", diff --git a/openhands-sdk/openhands/sdk/agent/agent.py b/openhands-sdk/openhands/sdk/agent/agent.py index 7f746cdae6..23d65f60dd 100644 --- a/openhands-sdk/openhands/sdk/agent/agent.py +++ b/openhands-sdk/openhands/sdk/agent/agent.py @@ -37,7 +37,6 @@ from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer from openhands.sdk.tool import ( Action, - FinishTool, Observation, ) from openhands.sdk.tool.builtins import FinishAction, ThinkAction @@ -431,6 +430,6 @@ def _execute_action_event( on_event(obs_event) # Set conversation state - if tool.name == FinishTool.name: + if tool.name == "finish": state.agent_status = AgentExecutionStatus.FINISHED return obs_event diff --git a/openhands-sdk/openhands/sdk/agent/base.py b/openhands-sdk/openhands/sdk/agent/base.py index 62b82de3c9..cc47e5ed68 100644 --- a/openhands-sdk/openhands/sdk/agent/base.py +++ b/openhands-sdk/openhands/sdk/agent/base.py @@ -223,7 +223,9 @@ def _initialize(self, state: "ConversationState"): ) # Always include built-in tools; not subject to filtering - tools.extend(BUILT_IN_TOOLS) + # Instantiate built-in tools using their .create() method + for tool_class in BUILT_IN_TOOLS: + tools.extend(tool_class.create(state)) # Check tool types for tool in tools: diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index dd275cda20..7fea0e1335 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: # type hints only, avoid runtime import cycle - from openhands.sdk.tool.tool import ToolBase + from openhands.sdk.tool.tool import ToolDefinition from openhands.sdk.utils.pydantic_diff import pretty_pydantic_diff @@ -425,7 +425,7 @@ def restore_metrics(self, metrics: Metrics) -> None: def completion( self, messages: list[Message], - tools: Sequence[ToolBase] | None = None, + tools: Sequence[ToolDefinition] | None = None, _return_metrics: bool = False, add_security_risk_prediction: bool = False, **kwargs, @@ -562,7 +562,7 @@ def _one_attempt(**retry_kwargs) -> ModelResponse: def responses( self, messages: list[Message], - tools: Sequence[ToolBase] | None = None, + tools: Sequence[ToolDefinition] | None = None, include: list[str] | None = None, store: bool | None = None, _return_metrics: bool = False, diff --git a/openhands-sdk/openhands/sdk/llm/router/base.py b/openhands-sdk/openhands/sdk/llm/router/base.py index e6188a8562..cd908255e6 100644 --- a/openhands-sdk/openhands/sdk/llm/router/base.py +++ b/openhands-sdk/openhands/sdk/llm/router/base.py @@ -11,7 +11,7 @@ from openhands.sdk.llm.llm_response import LLMResponse from openhands.sdk.llm.message import Message from openhands.sdk.logger import get_logger -from openhands.sdk.tool.tool import ToolBase +from openhands.sdk.tool.tool import ToolDefinition logger = get_logger(__name__) @@ -49,7 +49,7 @@ def validate_llms_not_empty(cls, v): def completion( self, messages: list[Message], - tools: Sequence[ToolBase] | None = None, + tools: Sequence[ToolDefinition] | None = None, return_metrics: bool = False, add_security_risk_prediction: bool = False, **kwargs, diff --git a/openhands-sdk/openhands/sdk/mcp/utils.py b/openhands-sdk/openhands/sdk/mcp/utils.py index 5d193b5962..7c3252c37c 100644 --- a/openhands-sdk/openhands/sdk/mcp/utils.py +++ b/openhands-sdk/openhands/sdk/mcp/utils.py @@ -8,7 +8,7 @@ from openhands.sdk.logger import get_logger from openhands.sdk.mcp import MCPClient, MCPToolDefinition -from openhands.sdk.tool.tool import ToolBase +from openhands.sdk.tool.tool import ToolDefinition logger = get_logger(__name__) @@ -30,9 +30,9 @@ async def log_handler(message: LogMessage): logger.log(level, msg, extra=extra) -async def _list_tools(client: MCPClient) -> list[ToolBase]: +async def _list_tools(client: MCPClient) -> list[ToolDefinition]: """List tools from an MCP client.""" - tools: list[ToolBase] = [] + tools: list[ToolDefinition] = [] async with client: assert client.is_connected(), "MCP client is not connected." diff --git a/openhands-sdk/openhands/sdk/tool/__init__.py b/openhands-sdk/openhands/sdk/tool/__init__.py index 76381eabcb..b268574513 100644 --- a/openhands-sdk/openhands/sdk/tool/__init__.py +++ b/openhands-sdk/openhands/sdk/tool/__init__.py @@ -14,7 +14,6 @@ from openhands.sdk.tool.tool import ( ExecutableTool, ToolAnnotations, - ToolBase, ToolDefinition, ToolExecutor, ) @@ -23,7 +22,6 @@ __all__ = [ "Tool", "ToolDefinition", - "ToolBase", "ToolAnnotations", "ToolExecutor", "ExecutableTool", diff --git a/openhands-sdk/openhands/sdk/tool/builtins/finish.py b/openhands-sdk/openhands/sdk/tool/builtins/finish.py index 6d2ac10420..0b0aea9a29 100644 --- a/openhands-sdk/openhands/sdk/tool/builtins/finish.py +++ b/openhands-sdk/openhands/sdk/tool/builtins/finish.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self from pydantic import Field from rich.text import Text @@ -16,6 +16,7 @@ if TYPE_CHECKING: from openhands.sdk.conversation.base import BaseConversation + from openhands.sdk.conversation.state import ConversationState class FinishAction(Action): @@ -67,17 +68,42 @@ def __call__( return FinishObservation(message=action.message) -FinishTool = ToolDefinition( - name="finish", - action_type=FinishAction, - observation_type=FinishObservation, - description=TOOL_DESCRIPTION, - executor=FinishExecutor(), - annotations=ToolAnnotations( - title="finish", - readOnlyHint=True, - destructiveHint=False, - idempotentHint=True, - openWorldHint=False, - ), -) +class FinishTool(ToolDefinition[FinishAction, FinishObservation]): + """Tool for signaling the completion of a task or conversation.""" + + @classmethod + def create( + cls, + conv_state: "ConversationState | None" = None, # noqa: ARG003 + **params, + ) -> Sequence[Self]: + """Create FinishTool instance. + + Args: + conv_state: Optional conversation state (not used by FinishTool). + **params: Additional parameters (none supported). + + Returns: + A sequence containing a single FinishTool instance. + + Raises: + ValueError: If any parameters are provided. + """ + if params: + raise ValueError("FinishTool doesn't accept parameters") + return [ + cls( + name="finish", + action_type=FinishAction, + observation_type=FinishObservation, + description=TOOL_DESCRIPTION, + executor=FinishExecutor(), + annotations=ToolAnnotations( + title="finish", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), + ) + ] diff --git a/openhands-sdk/openhands/sdk/tool/builtins/think.py b/openhands-sdk/openhands/sdk/tool/builtins/think.py index 01d84d6ece..090f66c23d 100644 --- a/openhands-sdk/openhands/sdk/tool/builtins/think.py +++ b/openhands-sdk/openhands/sdk/tool/builtins/think.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self from pydantic import Field from rich.text import Text @@ -16,6 +16,7 @@ if TYPE_CHECKING: from openhands.sdk.conversation.base import BaseConversation + from openhands.sdk.conversation.state import ConversationState class ThinkAction(Action): @@ -83,16 +84,41 @@ def __call__( return ThinkObservation() -ThinkTool = ToolDefinition( - name="think", - description=THINK_DESCRIPTION, - action_type=ThinkAction, - observation_type=ThinkObservation, - executor=ThinkExecutor(), - annotations=ToolAnnotations( - readOnlyHint=True, - destructiveHint=False, - idempotentHint=True, - openWorldHint=False, - ), -) +class ThinkTool(ToolDefinition[ThinkAction, ThinkObservation]): + """Tool for logging thoughts without making changes.""" + + @classmethod + def create( + cls, + conv_state: "ConversationState | None" = None, # noqa: ARG003 + **params, + ) -> Sequence[Self]: + """Create ThinkTool instance. + + Args: + conv_state: Optional conversation state (not used by ThinkTool). + **params: Additional parameters (none supported). + + Returns: + A sequence containing a single ThinkTool instance. + + Raises: + ValueError: If any parameters are provided. + """ + if params: + raise ValueError("ThinkTool doesn't accept parameters") + return [ + cls( + name="think", + description=THINK_DESCRIPTION, + action_type=ThinkAction, + observation_type=ThinkObservation, + executor=ThinkExecutor(), + annotations=ToolAnnotations( + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), + ) + ] diff --git a/openhands-sdk/openhands/sdk/tool/registry.py b/openhands-sdk/openhands/sdk/tool/registry.py index 2a00dc3d6e..135d349f5c 100644 --- a/openhands-sdk/openhands/sdk/tool/registry.py +++ b/openhands-sdk/openhands/sdk/tool/registry.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any from openhands.sdk.tool.spec import Tool -from openhands.sdk.tool.tool import ToolBase, ToolDefinition +from openhands.sdk.tool.tool import ToolDefinition if TYPE_CHECKING: @@ -85,7 +85,7 @@ def _is_abstract_method(cls: type, name: str) -> bool: return getattr(attr, "__isabstractmethod__", False) -def _resolver_from_subclass(_name: str, cls: type[ToolBase]) -> Resolver: +def _resolver_from_subclass(_name: str, cls: type[ToolDefinition]) -> Resolver: create = getattr(cls, "create", None) if create is None or not callable(create) or _is_abstract_method(cls, "create"): @@ -115,14 +115,16 @@ def _resolve( def register_tool( name: str, - factory: ToolDefinition | type[ToolBase] | Callable[..., Sequence[ToolDefinition]], + factory: ToolDefinition + | type[ToolDefinition] + | Callable[..., Sequence[ToolDefinition]], ) -> None: if not isinstance(name, str) or not name.strip(): raise ValueError("ToolDefinition name must be a non-empty string") if isinstance(factory, ToolDefinition): resolver = _resolver_from_instance(name, factory) - elif isinstance(factory, type) and issubclass(factory, ToolBase): + elif isinstance(factory, type) and issubclass(factory, ToolDefinition): resolver = _resolver_from_subclass(name, factory) elif callable(factory): resolver = _resolver_from_callable(name, factory) diff --git a/openhands-sdk/openhands/sdk/tool/tool.py b/openhands-sdk/openhands/sdk/tool/tool.py index c1e8603730..eb93ef9c1e 100644 --- a/openhands-sdk/openhands/sdk/tool/tool.py +++ b/openhands-sdk/openhands/sdk/tool/tool.py @@ -1,6 +1,13 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Protocol, Self, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Protocol, + Self, + TypeVar, +) from litellm import ( ChatCompletionToolParam, @@ -122,27 +129,35 @@ def __call__( ... -class ToolBase[ActionT, ObservationT](DiscriminatedUnionMixin, ABC): - """Base class for tools that agents can use to perform actions. +class ToolDefinition[ActionT, ObservationT](DiscriminatedUnionMixin, ABC): + """Base class for all tool implementations. - Tools wrap executor functions with input/output validation and schema definition. - They provide a standardized interface for agents to interact with external systems, - APIs, or perform specific operations. + This class serves as a base for the discriminated union of all tool types. + All tools must inherit from this class and implement the .create() method for + proper initialization with executors and parameters. Features: - - Normalize input/output schemas (class or dict) into both model+schema - - Validate inputs before execution - - Coerce outputs only if an output model is defined; else return vanilla JSON - - Export MCP (Model Context Protocol) tool descriptions - - Example: - >>> from openhands.sdk.tool import ToolDefinition - >>> tool = ToolDefinition( - ... name="echo", - ... description="Echo the input message", - ... action_type=EchoAction, - ... executor=echo_executor - ... ) + - Normalize input/output schemas (class or dict) into both model+schema. + - Validate inputs before execute. + - Coerce outputs only if an output model is defined; else return vanilla JSON. + - Export MCP tool description. + + Examples: + Simple tool with no parameters: + class FinishTool(ToolDefinition[FinishAction, FinishObservation]): + @classmethod + def create(cls, conv_state=None, **params): + return [cls(name="finish", ..., executor=FinishExecutor())] + + Complex tool with initialization parameters: + class BashTool(ToolDefinition[ExecuteBashAction, ExecuteBashObservation]): + @classmethod + def create(cls, conv_state, **params): + executor = BashExecutor( + working_dir=conv_state.workspace.working_dir, + **params, + ) + return [cls(name="execute_bash", ..., executor=executor)] """ model_config: ClassVar[ConfigDict] = ConfigDict( @@ -165,15 +180,21 @@ class ToolBase[ActionT, ObservationT](DiscriminatedUnionMixin, ABC): @classmethod @abstractmethod def create(cls, *args, **kwargs) -> Sequence[Self]: - """Create a sequence of Tool instances. Placeholder for subclasses. + """Create a sequence of Tool instances. + + This method must be implemented by all subclasses to provide custom + initialization logic, typically initializing the executor with parameters + from conv_state and other optional parameters. - This can be overridden in subclasses to provide custom initialization logic - (e.g., typically initializing the executor with parameters). + Args: + *args: Variable positional arguments (typically conv_state as first arg). + **kwargs: Optional parameters for tool initialization. Returns: A sequence of Tool instances. Even single tools are returned as a sequence to provide a consistent interface and eliminate union return types. """ + raise NotImplementedError("ToolDefinition subclasses must implement .create()") @computed_field(return_type=str, alias="title") @property @@ -380,36 +401,21 @@ def to_responses_tool( @classmethod def resolve_kind(cls, kind: str) -> type: - for subclass in get_known_concrete_subclasses(cls): - if subclass.__name__ == kind: - return subclass - # Fallback to "ToolDefinition" for unknown type - return ToolDefinition - + """Resolve a kind string to its corresponding tool class. -class ToolDefinition[ActionT, ObservationT](ToolBase[ActionT, ObservationT]): - """Concrete tool class that inherits from ToolBase. - - This class serves as a concrete implementation of ToolBase for cases where - you want to create a tool instance directly without implementing a custom - subclass. Built-in tools (like FinishTool, ThinkTool) are instantiated - directly from this class, while more complex tools (like BashTool, - FileEditorTool) inherit from this class and provide their own create() - method implementations. - """ + Args: + kind: The name of the tool class to resolve - @classmethod - def create(cls, *args, **kwargs) -> Sequence[Self]: - """Create a sequence of ToolDefinition instances. + Returns: + The tool class corresponding to the kind - TODO https://github.com/OpenHands/agent-sdk/issues/493 - Refactor this - the ToolDefinition class should not have a concrete create() - implementation. Built-in tools should be refactored to not rely on this - method, and then this should be made abstract with @abstractmethod. + Raises: + ValueError: If the kind is unknown """ - raise NotImplementedError( - "ToolDefinition.create() should be implemented by subclasses" - ) + for subclass in get_known_concrete_subclasses(cls): + if subclass.__name__ == kind: + return subclass + raise ValueError(f"Unknown kind '{kind}' for {cls}") def _create_action_type_with_risk(action_type: type[Schema]) -> type[Schema]: diff --git a/openhands-tools/openhands/tools/browser_use/__init__.py b/openhands-tools/openhands/tools/browser_use/__init__.py index 7c0d05009f..21227e274a 100644 --- a/openhands-tools/openhands/tools/browser_use/__init__.py +++ b/openhands-tools/openhands/tools/browser_use/__init__.py @@ -23,31 +23,10 @@ BrowserToolSet, BrowserTypeAction, BrowserTypeTool, - browser_click_tool, - browser_close_tab_tool, - browser_get_content_tool, - browser_get_state_tool, - browser_go_back_tool, - browser_list_tabs_tool, - browser_navigate_tool, - browser_scroll_tool, - browser_switch_tab_tool, - browser_type_tool, ) __all__ = [ - # Tool objects - "browser_navigate_tool", - "browser_click_tool", - "browser_type_tool", - "browser_get_state_tool", - "browser_get_content_tool", - "browser_scroll_tool", - "browser_go_back_tool", - "browser_list_tabs_tool", - "browser_switch_tab_tool", - "browser_close_tab_tool", # Tool classes "BrowserNavigateTool", "BrowserClickTool", diff --git a/openhands-tools/openhands/tools/browser_use/definition.py b/openhands-tools/openhands/tools/browser_use/definition.py index 39ce262380..8ee077d66e 100644 --- a/openhands-tools/openhands/tools/browser_use/definition.py +++ b/openhands-tools/openhands/tools/browser_use/definition.py @@ -12,7 +12,6 @@ ToolAnnotations, ToolDefinition, ) -from openhands.sdk.tool.tool import ToolBase from openhands.sdk.utils import maybe_truncate @@ -60,10 +59,23 @@ def to_llm_content(self) -> Sequence[TextContent | ImageContent]: return content +# ============================================ +# Base Browser Action +# ============================================ +class BrowserAction(Action): + """Base class for all browser actions. + + This base class serves as the parent for all browser-related actions, + enabling proper type hierarchy and eliminating the need for union types. + """ + + pass + + # ============================================ # `go_to_url` # ============================================ -class BrowserNavigateAction(Action): +class BrowserNavigateAction(BrowserAction): """Schema for browser navigation.""" url: str = Field(description="The URL to navigate to") @@ -85,20 +97,6 @@ class BrowserNavigateAction(Action): - Open GitHub in new tab: url="https://github.com", new_tab=True """ # noqa: E501 -browser_navigate_tool = ToolDefinition( - name="browser_navigate", - action_type=BrowserNavigateAction, - observation_type=BrowserObservation, - description=BROWSER_NAVIGATE_DESCRIPTION, - annotations=ToolAnnotations( - title="browser_navigate", - readOnlyHint=False, - destructiveHint=False, - idempotentHint=False, - openWorldHint=True, - ), -) - class BrowserNavigateTool(ToolDefinition[BrowserNavigateAction, BrowserObservation]): """Tool for browser navigation.""" @@ -107,11 +105,17 @@ class BrowserNavigateTool(ToolDefinition[BrowserNavigateAction, BrowserObservati def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: return [ cls( - name=browser_navigate_tool.name, + name="browser_navigate", description=BROWSER_NAVIGATE_DESCRIPTION, action_type=BrowserNavigateAction, observation_type=BrowserObservation, - annotations=browser_navigate_tool.annotations, + annotations=ToolAnnotations( + title="browser_navigate", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=True, + ), executor=executor, ) ] @@ -120,7 +124,7 @@ def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: # ============================================ # `browser_click` # ============================================ -class BrowserClickAction(Action): +class BrowserClickAction(BrowserAction): """Schema for clicking elements.""" index: int = Field( @@ -144,20 +148,6 @@ class BrowserClickAction(Action): Important: Only use indices that appear in your current browser_get_state output. """ # noqa: E501 -browser_click_tool = ToolDefinition( - name="browser_click", - action_type=BrowserClickAction, - observation_type=BrowserObservation, - description=BROWSER_CLICK_DESCRIPTION, - annotations=ToolAnnotations( - title="browser_click", - readOnlyHint=False, - destructiveHint=False, - idempotentHint=False, - openWorldHint=True, - ), -) - class BrowserClickTool(ToolDefinition[BrowserClickAction, BrowserObservation]): """Tool for clicking browser elements.""" @@ -166,11 +156,17 @@ class BrowserClickTool(ToolDefinition[BrowserClickAction, BrowserObservation]): def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: return [ cls( - name=browser_click_tool.name, + name="browser_click", description=BROWSER_CLICK_DESCRIPTION, action_type=BrowserClickAction, observation_type=BrowserObservation, - annotations=browser_click_tool.annotations, + annotations=ToolAnnotations( + title="browser_click", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=True, + ), executor=executor, ) ] @@ -179,7 +175,7 @@ def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: # ============================================ # `browser_type` # ============================================ -class BrowserTypeAction(Action): +class BrowserTypeAction(BrowserAction): """Schema for typing text into elements.""" index: int = Field( @@ -200,20 +196,6 @@ class BrowserTypeAction(Action): Important: Only use indices that appear in your current browser_get_state output. """ # noqa: E501 -browser_type_tool = ToolDefinition( - name="browser_type", - action_type=BrowserTypeAction, - observation_type=BrowserObservation, - description=BROWSER_TYPE_DESCRIPTION, - annotations=ToolAnnotations( - title="browser_type", - readOnlyHint=False, - destructiveHint=False, - idempotentHint=False, - openWorldHint=True, - ), -) - class BrowserTypeTool(ToolDefinition[BrowserTypeAction, BrowserObservation]): """Tool for typing text into browser elements.""" @@ -222,11 +204,17 @@ class BrowserTypeTool(ToolDefinition[BrowserTypeAction, BrowserObservation]): def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: return [ cls( - name=browser_type_tool.name, + name="browser_type", description=BROWSER_TYPE_DESCRIPTION, action_type=BrowserTypeAction, observation_type=BrowserObservation, - annotations=browser_type_tool.annotations, + annotations=ToolAnnotations( + title="browser_type", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=True, + ), executor=executor, ) ] @@ -235,7 +223,7 @@ def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: # ============================================ # `browser_get_state` # ============================================ -class BrowserGetStateAction(Action): +class BrowserGetStateAction(BrowserAction): """Schema for getting browser state.""" include_screenshot: bool = Field( @@ -253,20 +241,6 @@ class BrowserGetStateAction(Action): - include_screenshot: Whether to include a screenshot (optional, default: False) """ # noqa: E501 -browser_get_state_tool = ToolDefinition( - name="browser_get_state", - action_type=BrowserGetStateAction, - observation_type=BrowserObservation, - description=BROWSER_GET_STATE_DESCRIPTION, - annotations=ToolAnnotations( - title="browser_get_state", - readOnlyHint=True, - destructiveHint=False, - idempotentHint=True, - openWorldHint=True, - ), -) - class BrowserGetStateTool(ToolDefinition[BrowserGetStateAction, BrowserObservation]): """Tool for getting browser state.""" @@ -275,11 +249,17 @@ class BrowserGetStateTool(ToolDefinition[BrowserGetStateAction, BrowserObservati def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: return [ cls( - name=browser_get_state_tool.name, + name="browser_get_state", description=BROWSER_GET_STATE_DESCRIPTION, action_type=BrowserGetStateAction, observation_type=BrowserObservation, - annotations=browser_get_state_tool.annotations, + annotations=ToolAnnotations( + title="browser_get_state", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=True, + ), executor=executor, ) ] @@ -288,7 +268,7 @@ def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: # ============================================ # `browser_get_content` # ============================================ -class BrowserGetContentAction(Action): +class BrowserGetContentAction(BrowserAction): """Schema for getting page content in markdown.""" extract_links: bool = Field( @@ -307,20 +287,6 @@ class BrowserGetContentAction(Action): If the content was truncated and you need more information, use start_from_char parameter to continue from where truncation occurred. """ # noqa: E501 -browser_get_content_tool = ToolDefinition( - name="browser_get_content", - action_type=BrowserGetContentAction, - observation_type=BrowserObservation, - description=BROWSER_GET_CONTENT_DESCRIPTION, - annotations=ToolAnnotations( - title="browser_get_content", - readOnlyHint=True, - destructiveHint=False, - idempotentHint=True, - openWorldHint=True, - ), -) - class BrowserGetContentTool( ToolDefinition[BrowserGetContentAction, BrowserObservation] @@ -331,11 +297,17 @@ class BrowserGetContentTool( def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: return [ cls( - name=browser_get_content_tool.name, + name="browser_get_content", description=BROWSER_GET_CONTENT_DESCRIPTION, action_type=BrowserGetContentAction, observation_type=BrowserObservation, - annotations=browser_get_content_tool.annotations, + annotations=ToolAnnotations( + title="browser_get_content", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=True, + ), executor=executor, ) ] @@ -344,7 +316,7 @@ def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: # ============================================ # `browser_scroll` # ============================================ -class BrowserScrollAction(Action): +class BrowserScrollAction(BrowserAction): """Schema for scrolling the page.""" direction: Literal["up", "down"] = Field( @@ -362,20 +334,6 @@ class BrowserScrollAction(Action): - direction: Direction to scroll - "up" or "down" (optional, default: "down") """ # noqa: E501 -browser_scroll_tool = ToolDefinition( - name="browser_scroll", - action_type=BrowserScrollAction, - observation_type=BrowserObservation, - description=BROWSER_SCROLL_DESCRIPTION, - annotations=ToolAnnotations( - title="browser_scroll", - readOnlyHint=False, - destructiveHint=False, - idempotentHint=False, - openWorldHint=True, - ), -) - class BrowserScrollTool(ToolDefinition[BrowserScrollAction, BrowserObservation]): """Tool for scrolling the browser page.""" @@ -384,11 +342,17 @@ class BrowserScrollTool(ToolDefinition[BrowserScrollAction, BrowserObservation]) def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: return [ cls( - name=browser_scroll_tool.name, + name="browser_scroll", description=BROWSER_SCROLL_DESCRIPTION, action_type=BrowserScrollAction, observation_type=BrowserObservation, - annotations=browser_scroll_tool.annotations, + annotations=ToolAnnotations( + title="browser_scroll", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=True, + ), executor=executor, ) ] @@ -397,7 +361,7 @@ def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: # ============================================ # `browser_go_back` # ============================================ -class BrowserGoBackAction(Action): +class BrowserGoBackAction(BrowserAction): """Schema for going back in browser history.""" pass @@ -409,20 +373,6 @@ class BrowserGoBackAction(Action): browser's back button. """ # noqa: E501 -browser_go_back_tool = ToolDefinition( - name="browser_go_back", - action_type=BrowserGoBackAction, - observation_type=BrowserObservation, - description=BROWSER_GO_BACK_DESCRIPTION, - annotations=ToolAnnotations( - title="browser_go_back", - readOnlyHint=False, - destructiveHint=False, - idempotentHint=False, - openWorldHint=True, - ), -) - class BrowserGoBackTool(ToolDefinition[BrowserGoBackAction, BrowserObservation]): """Tool for going back in browser history.""" @@ -431,11 +381,17 @@ class BrowserGoBackTool(ToolDefinition[BrowserGoBackAction, BrowserObservation]) def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: return [ cls( - name=browser_go_back_tool.name, + name="browser_go_back", description=BROWSER_GO_BACK_DESCRIPTION, action_type=BrowserGoBackAction, observation_type=BrowserObservation, - annotations=browser_go_back_tool.annotations, + annotations=ToolAnnotations( + title="browser_go_back", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=True, + ), executor=executor, ) ] @@ -444,7 +400,7 @@ def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: # ============================================ # `browser_list_tabs` # ============================================ -class BrowserListTabsAction(Action): +class BrowserListTabsAction(BrowserAction): """Schema for listing browser tabs.""" pass @@ -456,20 +412,6 @@ class BrowserListTabsAction(Action): with browser_switch_tab or browser_close_tab. """ # noqa: E501 -browser_list_tabs_tool = ToolDefinition( - name="browser_list_tabs", - action_type=BrowserListTabsAction, - observation_type=BrowserObservation, - description=BROWSER_LIST_TABS_DESCRIPTION, - annotations=ToolAnnotations( - title="browser_list_tabs", - readOnlyHint=True, - destructiveHint=False, - idempotentHint=True, - openWorldHint=False, - ), -) - class BrowserListTabsTool(ToolDefinition[BrowserListTabsAction, BrowserObservation]): """Tool for listing browser tabs.""" @@ -478,11 +420,17 @@ class BrowserListTabsTool(ToolDefinition[BrowserListTabsAction, BrowserObservati def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: return [ cls( - name=browser_list_tabs_tool.name, + name="browser_list_tabs", description=BROWSER_LIST_TABS_DESCRIPTION, action_type=BrowserListTabsAction, observation_type=BrowserObservation, - annotations=browser_list_tabs_tool.annotations, + annotations=ToolAnnotations( + title="browser_list_tabs", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), executor=executor, ) ] @@ -491,7 +439,7 @@ def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: # ============================================ # `browser_switch_tab` # ============================================ -class BrowserSwitchTabAction(Action): +class BrowserSwitchTabAction(BrowserAction): """Schema for switching browser tabs.""" tab_id: str = Field( @@ -508,36 +456,25 @@ class BrowserSwitchTabAction(Action): - tab_id: 4 Character Tab ID of the tab to switch to """ -browser_switch_tab_tool = ToolDefinition( - name="browser_switch_tab", - action_type=BrowserSwitchTabAction, - observation_type=BrowserObservation, - description=BROWSER_SWITCH_TAB_DESCRIPTION, - annotations=ToolAnnotations( - title="browser_switch_tab", - readOnlyHint=False, - destructiveHint=False, - idempotentHint=False, - openWorldHint=False, - ), -) - class BrowserSwitchTabTool(ToolDefinition[BrowserSwitchTabAction, BrowserObservation]): """Tool for switching browser tabs.""" - # Override executor to be non-optional for initialized BrowserSwitchTabTool - # instances - @classmethod def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: return [ cls( - name=browser_switch_tab_tool.name, + name="browser_switch_tab", description=BROWSER_SWITCH_TAB_DESCRIPTION, action_type=BrowserSwitchTabAction, observation_type=BrowserObservation, - annotations=browser_switch_tab_tool.annotations, + annotations=ToolAnnotations( + title="browser_switch_tab", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=False, + ), executor=executor, ) ] @@ -546,7 +483,7 @@ def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: # ============================================ # `browser_close_tab` # ============================================ -class BrowserCloseTabAction(Action): +class BrowserCloseTabAction(BrowserAction): """Schema for closing browser tabs.""" tab_id: str = Field( @@ -562,20 +499,6 @@ class BrowserCloseTabAction(Action): - tab_id: 4 Character Tab ID of the tab to close """ -browser_close_tab_tool = ToolDefinition( - name="browser_close_tab", - action_type=BrowserCloseTabAction, - observation_type=BrowserObservation, - description=BROWSER_CLOSE_TAB_DESCRIPTION, - annotations=ToolAnnotations( - title="browser_close_tab", - readOnlyHint=False, - destructiveHint=True, - idempotentHint=False, - openWorldHint=False, - ), -) - class BrowserCloseTabTool(ToolDefinition[BrowserCloseTabAction, BrowserObservation]): """Tool for closing browser tabs.""" @@ -584,32 +507,23 @@ class BrowserCloseTabTool(ToolDefinition[BrowserCloseTabAction, BrowserObservati def create(cls, executor: "BrowserToolExecutor") -> Sequence[Self]: return [ cls( - name=browser_close_tab_tool.name, + name="browser_close_tab", description=BROWSER_CLOSE_TAB_DESCRIPTION, action_type=BrowserCloseTabAction, observation_type=BrowserObservation, - annotations=browser_close_tab_tool.annotations, + annotations=ToolAnnotations( + title="browser_close_tab", + readOnlyHint=False, + destructiveHint=True, + idempotentHint=False, + openWorldHint=False, + ), executor=executor, ) ] -# Union type for all browser actions -BrowserAction = ( - BrowserNavigateAction - | BrowserClickAction - | BrowserTypeAction - | BrowserGetStateAction - | BrowserGetContentAction - | BrowserScrollAction - | BrowserGoBackAction - | BrowserListTabsAction - | BrowserSwitchTabAction - | BrowserCloseTabAction -) - - -class BrowserToolSet(ToolBase[BrowserAction, BrowserObservation]): +class BrowserToolSet(ToolDefinition[BrowserAction, BrowserObservation]): """A set of all browser tools. This tool set includes all available browser-related tools @@ -623,21 +537,25 @@ class BrowserToolSet(ToolBase[BrowserAction, BrowserObservation]): def create( cls, **executor_config, - ) -> list[ToolBase[BrowserAction, BrowserObservation]]: + ) -> list[ToolDefinition[BrowserAction, BrowserObservation]]: # Import executor only when actually needed to # avoid hanging during module import from openhands.tools.browser_use.impl import BrowserToolExecutor executor = BrowserToolExecutor(**executor_config) - return [ - browser_navigate_tool.set_executor(executor), - browser_click_tool.set_executor(executor), - browser_get_state_tool.set_executor(executor), - browser_get_content_tool.set_executor(executor), - browser_type_tool.set_executor(executor), - browser_scroll_tool.set_executor(executor), - browser_go_back_tool.set_executor(executor), - browser_list_tabs_tool.set_executor(executor), - browser_switch_tab_tool.set_executor(executor), - browser_close_tab_tool.set_executor(executor), - ] + # Each tool.create() returns a Sequence[Self], so we flatten the results + tools: list[ToolDefinition[BrowserAction, BrowserObservation]] = [] + for tool_class in [ + BrowserNavigateTool, + BrowserClickTool, + BrowserGetStateTool, + BrowserGetContentTool, + BrowserTypeTool, + BrowserScrollTool, + BrowserGoBackTool, + BrowserListTabsTool, + BrowserSwitchTabTool, + BrowserCloseTabTool, + ]: + tools.extend(tool_class.create(executor)) + return tools diff --git a/openhands-tools/openhands/tools/delegate/__init__.py b/openhands-tools/openhands/tools/delegate/__init__.py index c6a84b9383..ea9f3f813e 100644 --- a/openhands-tools/openhands/tools/delegate/__init__.py +++ b/openhands-tools/openhands/tools/delegate/__init__.py @@ -4,7 +4,6 @@ DelegateAction, DelegateObservation, DelegateTool, - delegate_tool, ) from openhands.tools.delegate.impl import DelegateExecutor @@ -14,5 +13,4 @@ "DelegateObservation", "DelegateExecutor", "DelegateTool", - "delegate_tool", ] diff --git a/openhands-tools/openhands/tools/delegate/definition.py b/openhands-tools/openhands/tools/delegate/definition.py index 2f8be795a6..f23d87de74 100644 --- a/openhands-tools/openhands/tools/delegate/definition.py +++ b/openhands-tools/openhands/tools/delegate/definition.py @@ -76,20 +76,6 @@ def to_llm_content(self) -> Sequence[TextContent | ImageContent]: - Sub-agents work in the same workspace as the main agent: {workspace_path} """ # noqa -delegate_tool = ToolDefinition( - name="delegate", - action_type=DelegateAction, - observation_type=DelegateObservation, - description=TOOL_DESCRIPTION, - annotations=ToolAnnotations( - title="delegate", - readOnlyHint=False, - destructiveHint=False, - idempotentHint=False, - openWorldHint=True, - ), -) - class DelegateTool(ToolDefinition[DelegateAction, DelegateObservation]): """A ToolDefinition subclass that automatically initializes a DelegateExecutor.""" @@ -124,11 +110,17 @@ def create( # Initialize the parent Tool with the executor return [ cls( - name=delegate_tool.name, - description=tool_description, + name="delegate", action_type=DelegateAction, observation_type=DelegateObservation, - annotations=delegate_tool.annotations, + description=tool_description, + annotations=ToolAnnotations( + title="delegate", + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=True, + ), executor=executor, ) ] diff --git a/openhands-tools/openhands/tools/execute_bash/__init__.py b/openhands-tools/openhands/tools/execute_bash/__init__.py index b72257fd68..0a631e9d23 100644 --- a/openhands-tools/openhands/tools/execute_bash/__init__.py +++ b/openhands-tools/openhands/tools/execute_bash/__init__.py @@ -3,7 +3,6 @@ BashTool, ExecuteBashAction, ExecuteBashObservation, - execute_bash_tool, ) from openhands.tools.execute_bash.impl import BashExecutor @@ -18,7 +17,6 @@ __all__ = [ # === Core Tool Interface === "BashTool", - "execute_bash_tool", "ExecuteBashAction", "ExecuteBashObservation", "BashExecutor", diff --git a/openhands-tools/openhands/tools/execute_bash/definition.py b/openhands-tools/openhands/tools/execute_bash/definition.py index 575976c266..d1a39d18ea 100644 --- a/openhands-tools/openhands/tools/execute_bash/definition.py +++ b/openhands-tools/openhands/tools/execute_bash/definition.py @@ -17,6 +17,7 @@ Observation, ToolAnnotations, ToolDefinition, + ToolExecutor, ) from openhands.sdk.utils import maybe_truncate from openhands.tools.execute_bash.constants import ( @@ -217,21 +218,6 @@ def visualize(self) -> Text: """ # noqa -execute_bash_tool = ToolDefinition( - name="execute_bash", - action_type=ExecuteBashAction, - observation_type=ExecuteBashObservation, - description=TOOL_DESCRIPTION, - annotations=ToolAnnotations( - title="execute_bash", - readOnlyHint=False, - destructiveHint=True, - idempotentHint=False, - openWorldHint=True, - ), -) - - class BashTool(ToolDefinition[ExecuteBashAction, ExecuteBashObservation]): """A ToolDefinition subclass that automatically initializes a BashExecutor with auto-detection.""" # noqa: E501 @@ -242,6 +228,7 @@ def create( username: str | None = None, no_change_timeout_seconds: int | None = None, terminal_type: Literal["tmux", "subprocess"] | None = None, + executor: ToolExecutor | None = None, ) -> Sequence["BashTool"]: """Initialize BashTool with executor parameters. @@ -265,21 +252,28 @@ def create( raise ValueError(f"working_dir '{working_dir}' is not a valid directory") # Initialize the executor - executor = BashExecutor( - working_dir=working_dir, - username=username, - no_change_timeout_seconds=no_change_timeout_seconds, - terminal_type=terminal_type, - ) + if executor is None: + executor = BashExecutor( + working_dir=working_dir, + username=username, + no_change_timeout_seconds=no_change_timeout_seconds, + terminal_type=terminal_type, + ) # Initialize the parent ToolDefinition with the executor return [ cls( - name=execute_bash_tool.name, - description=TOOL_DESCRIPTION, + name="execute_bash", action_type=ExecuteBashAction, observation_type=ExecuteBashObservation, - annotations=execute_bash_tool.annotations, + description=TOOL_DESCRIPTION, + annotations=ToolAnnotations( + title="execute_bash", + readOnlyHint=False, + destructiveHint=True, + idempotentHint=False, + openWorldHint=True, + ), executor=executor, ) ] diff --git a/openhands-tools/openhands/tools/file_editor/__init__.py b/openhands-tools/openhands/tools/file_editor/__init__.py index 961bc2987a..5c5c7b41a4 100644 --- a/openhands-tools/openhands/tools/file_editor/__init__.py +++ b/openhands-tools/openhands/tools/file_editor/__init__.py @@ -2,13 +2,11 @@ FileEditorAction, FileEditorObservation, FileEditorTool, - file_editor_tool, ) from openhands.tools.file_editor.impl import FileEditorExecutor, file_editor __all__ = [ - "file_editor_tool", "FileEditorAction", "FileEditorObservation", "file_editor", diff --git a/openhands-tools/openhands/tools/file_editor/definition.py b/openhands-tools/openhands/tools/file_editor/definition.py index 571baa5660..f3d96f9152 100644 --- a/openhands-tools/openhands/tools/file_editor/definition.py +++ b/openhands-tools/openhands/tools/file_editor/definition.py @@ -187,21 +187,6 @@ def _has_meaningful_diff(self) -> bool: """ # noqa: E501 -file_editor_tool = ToolDefinition( - name="str_replace_editor", - action_type=FileEditorAction, - observation_type=FileEditorObservation, - description=TOOL_DESCRIPTION, - annotations=ToolAnnotations( - title="str_replace_editor", - readOnlyHint=False, - destructiveHint=True, - idempotentHint=False, - openWorldHint=False, - ), -) - - class FileEditorTool(ToolDefinition[FileEditorAction, FileEditorObservation]): """A ToolDefinition subclass that automatically initializes a FileEditorExecutor.""" @@ -236,11 +221,17 @@ def create( # Initialize the parent Tool with the executor return [ cls( - name=file_editor_tool.name, - description=enhanced_description, + name="str_replace_editor", action_type=FileEditorAction, observation_type=FileEditorObservation, - annotations=file_editor_tool.annotations, + description=enhanced_description, + annotations=ToolAnnotations( + title="str_replace_editor", + readOnlyHint=False, + destructiveHint=True, + idempotentHint=False, + openWorldHint=False, + ), executor=executor, ) ] diff --git a/openhands-tools/openhands/tools/task_tracker/__init__.py b/openhands-tools/openhands/tools/task_tracker/__init__.py index ca64609f07..c291f43f9c 100644 --- a/openhands-tools/openhands/tools/task_tracker/__init__.py +++ b/openhands-tools/openhands/tools/task_tracker/__init__.py @@ -3,7 +3,6 @@ TaskTrackerExecutor, TaskTrackerObservation, TaskTrackerTool, - task_tracker_tool, ) @@ -12,5 +11,4 @@ "TaskTrackerExecutor", "TaskTrackerObservation", "TaskTrackerTool", - "task_tracker_tool", ] diff --git a/openhands-tools/openhands/tools/task_tracker/definition.py b/openhands-tools/openhands/tools/task_tracker/definition.py index 0f04f3d9ae..82ad0529b6 100644 --- a/openhands-tools/openhands/tools/task_tracker/definition.py +++ b/openhands-tools/openhands/tools/task_tracker/definition.py @@ -391,20 +391,6 @@ def _save_tasks(self) -> None: systematic approach and ensures comprehensive requirement fulfillment.""" # noqa: E501 -task_tracker_tool = ToolDefinition( - name="task_tracker", - description=TASK_TRACKER_DESCRIPTION, - action_type=TaskTrackerAction, - observation_type=TaskTrackerObservation, - annotations=ToolAnnotations( - readOnlyHint=False, - destructiveHint=False, - idempotentHint=True, - openWorldHint=False, - ), -) - - class TaskTrackerTool(ToolDefinition[TaskTrackerAction, TaskTrackerObservation]): """A ToolDefinition subclass that automatically initializes a TaskTrackerExecutor.""" # noqa: E501 @@ -426,7 +412,12 @@ def create(cls, conv_state: "ConversationState") -> Sequence["TaskTrackerTool"]: description=TASK_TRACKER_DESCRIPTION, action_type=TaskTrackerAction, observation_type=TaskTrackerObservation, - annotations=task_tracker_tool.annotations, + annotations=ToolAnnotations( + readOnlyHint=False, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), executor=executor, ) ] diff --git a/tests/conftest.py b/tests/conftest.py index 654dc755fd..9d19019b61 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,17 @@ """Common test fixtures and utilities.""" +import uuid from unittest.mock import MagicMock import pytest from pydantic import SecretStr +from openhands.sdk import Agent +from openhands.sdk.conversation.state import ConversationState +from openhands.sdk.io import InMemoryFileStore from openhands.sdk.llm import LLM from openhands.sdk.tool import ToolExecutor +from openhands.sdk.workspace import LocalWorkspace @pytest.fixture @@ -22,6 +27,26 @@ def mock_llm(): ) +@pytest.fixture +def mock_conversation_state(mock_llm, tmp_path): + """Create a standard mock ConversationState for testing.""" + agent = Agent(llm=mock_llm) + workspace = LocalWorkspace(working_dir=str(tmp_path)) + + state = ConversationState( + id=uuid.uuid4(), + workspace=workspace, + persistence_dir=str(tmp_path / ".state"), + agent=agent, + ) + + # Set up filestore for state persistence + state._fs = InMemoryFileStore() + state._autosave_enabled = False + + return state + + @pytest.fixture def mock_tool(): """Create a mock tool for testing.""" diff --git a/tests/sdk/agent/test_agent_serialization.py b/tests/sdk/agent/test_agent_serialization.py index 707bc4a017..0636c42978 100644 --- a/tests/sdk/agent/test_agent_serialization.py +++ b/tests/sdk/agent/test_agent_serialization.py @@ -13,7 +13,7 @@ from openhands.sdk.llm import LLM from openhands.sdk.mcp.client import MCPClient from openhands.sdk.mcp.tool import MCPToolDefinition -from openhands.sdk.tool.tool import ToolBase +from openhands.sdk.tool.tool import ToolDefinition from openhands.sdk.utils.models import OpenHandsModel @@ -55,7 +55,7 @@ def test_agent_supports_polymorphic_json_serialization() -> None: def test_mcp_tool_serialization(): tool = create_mock_mcp_tool("test_mcp_tool_serialization") dumped = tool.model_dump_json() - loaded = ToolBase.model_validate_json(dumped) + loaded = ToolDefinition.model_validate_json(dumped) assert loaded.model_dump_json() == dumped diff --git a/tests/sdk/agent/test_agent_tool_init.py b/tests/sdk/agent/test_agent_tool_init.py index 5d4ce971c3..e77a8b86d5 100644 --- a/tests/sdk/agent/test_agent_tool_init.py +++ b/tests/sdk/agent/test_agent_tool_init.py @@ -27,16 +27,24 @@ def __call__(self, action: _Action, conversation=None) -> _Obs: return _Obs(out=action.text.upper()) +class _UpperTool(ToolDefinition[_Action, _Obs]): + """Concrete tool for uppercase testing.""" + + @classmethod + def create(cls, conv_state=None, **params) -> Sequence["_UpperTool"]: + return [ + cls( + name="upper", + description="Uppercase", + action_type=_Action, + observation_type=_Obs, + executor=_Exec(), + ) + ] + + def _make_tool(conv_state=None, **kwargs) -> Sequence[ToolDefinition]: - return [ - ToolDefinition( - name="upper", - description="Uppercase", - action_type=_Action, - observation_type=_Obs, - executor=_Exec(), - ) - ] + return _UpperTool.create(conv_state, **kwargs) def test_agent_initializes_tools_from_toolspec_locally(monkeypatch): diff --git a/tests/sdk/agent/test_message_while_finishing.py b/tests/sdk/agent/test_message_while_finishing.py index 5d6333cace..f6bd7f4779 100644 --- a/tests/sdk/agent/test_message_while_finishing.py +++ b/tests/sdk/agent/test_message_while_finishing.py @@ -129,17 +129,25 @@ def __call__(self, action: SleepAction, conversation=None) -> SleepObservation: return SleepObservation(message=action.message) +class SleepTool(ToolDefinition[SleepAction, SleepObservation]): + """Sleep tool for testing message processing during finish.""" + + @classmethod + def create(cls, conv_state=None, **params) -> Sequence["SleepTool"]: + return [ + cls( + name="sleep_tool", + action_type=SleepAction, + observation_type=SleepObservation, + description="Sleep for specified duration and return a message", + executor=SleepExecutor(), + ) + ] + + def _make_sleep_tool(conv_state=None, **kwargs) -> Sequence[ToolDefinition]: """Create sleep tool for testing.""" - return [ - ToolDefinition( - name="sleep_tool", - action_type=SleepAction, - observation_type=SleepObservation, - description="Sleep for specified duration and return a message", - executor=SleepExecutor(), - ) - ] + return SleepTool.create(conv_state, **kwargs) # Register the tool diff --git a/tests/sdk/conversation/local/test_agent_status_transition.py b/tests/sdk/conversation/local/test_agent_status_transition.py index 9f0e518448..200aaed97e 100644 --- a/tests/sdk/conversation/local/test_agent_status_transition.py +++ b/tests/sdk/conversation/local/test_agent_status_transition.py @@ -78,22 +78,36 @@ def __call__( return StatusTransitionMockObservation(result=f"Executed: {action.command}") -@patch("openhands.sdk.llm.llm.litellm_completion") -def test_agent_status_transitions_to_running_from_idle(mock_completion): - """Test that agent status transitions to RUNNING when run() is called from IDLE.""" - status_during_execution: list[AgentExecutionStatus] = [] +class StatusTransitionTestTool( + ToolDefinition[StatusTransitionMockAction, StatusTransitionMockObservation] +): + """Concrete tool for status transition testing.""" - def _make_tool(conv_state=None, **params) -> Sequence[ToolDefinition]: + @classmethod + def create( + cls, conv_state=None, *, executor: ToolExecutor, **params + ) -> Sequence["StatusTransitionTestTool"]: return [ - ToolDefinition( + cls( name="test_tool", description="A test tool", action_type=StatusTransitionMockAction, observation_type=StatusTransitionMockObservation, - executor=StatusCheckingExecutor(status_during_execution), + executor=executor, ) ] + +@patch("openhands.sdk.llm.llm.litellm_completion") +def test_agent_status_transitions_to_running_from_idle(mock_completion): + """Test that agent status transitions to RUNNING when run() is called from IDLE.""" + status_during_execution: list[AgentExecutionStatus] = [] + + def _make_tool(conv_state=None, **params) -> Sequence[ToolDefinition]: + return StatusTransitionTestTool.create( + executor=StatusCheckingExecutor(status_during_execution) + ) + register_tool("test_tool", _make_tool) llm = LLM(model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test-llm") @@ -137,15 +151,9 @@ def test_agent_status_is_running_during_execution_from_idle(mock_completion): execution_started = threading.Event() def _make_tool(conv_state=None, **params) -> Sequence[ToolDefinition]: - return [ - ToolDefinition( - name="test_tool", - description="A test tool", - action_type=StatusTransitionMockAction, - observation_type=StatusTransitionMockObservation, - executor=StatusCheckingExecutor(status_during_execution), - ) - ] + return StatusTransitionTestTool.create( + executor=StatusCheckingExecutor(status_during_execution) + ) register_tool("test_tool", _make_tool) @@ -293,15 +301,7 @@ def test_agent_status_transitions_from_waiting_for_confirmation(mock_completion) llm = LLM(model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test-llm") def _make_tool(conv_state=None, **params) -> Sequence[ToolDefinition]: - return [ - ToolDefinition( - name="test_tool", - description="A test tool", - action_type=StatusTransitionMockAction, - observation_type=StatusTransitionMockObservation, - executor=StatusCheckingExecutor([]), - ) - ] + return StatusTransitionTestTool.create(executor=StatusCheckingExecutor([])) register_tool("test_tool", _make_tool) diff --git a/tests/sdk/conversation/local/test_confirmation_mode.py b/tests/sdk/conversation/local/test_confirmation_mode.py index 1c863c61f2..39eafbe956 100644 --- a/tests/sdk/conversation/local/test_confirmation_mode.py +++ b/tests/sdk/conversation/local/test_confirmation_mode.py @@ -58,6 +58,42 @@ def to_llm_content(self) -> Sequence[TextContent | ImageContent]: return [TextContent(text=self.result)] +class TestExecutor( + ToolExecutor[MockConfirmationModeAction, MockConfirmationModeObservation] +): + """Test executor for confirmation mode testing.""" + + def __call__( + self, + action: MockConfirmationModeAction, + conversation=None, # noqa: ARG002 + ) -> MockConfirmationModeObservation: + return MockConfirmationModeObservation(result=f"Executed: {action.command}") + + +class ConfirmationTestTool( + ToolDefinition[MockConfirmationModeAction, MockConfirmationModeObservation] +): + """Concrete tool for confirmation mode testing.""" + + @classmethod + def create(cls, conv_state=None, **params) -> Sequence["ConfirmationTestTool"]: + return [ + cls( + name="test_tool", + description="A test tool", + action_type=MockConfirmationModeAction, + observation_type=MockConfirmationModeObservation, + executor=TestExecutor(), + ) + ] + + +def _make_tool(conv_state=None, **params) -> Sequence[ToolDefinition]: + """Factory function for creating test tools.""" + return ConfirmationTestTool.create(conv_state, **params) + + class TestConfirmationMode: """Test suite for confirmation mode functionality.""" @@ -91,29 +127,6 @@ def setup_method(self): ) self.mock_llm.metrics.get_snapshot.return_value = mock_metrics_snapshot - class TestExecutor( - ToolExecutor[MockConfirmationModeAction, MockConfirmationModeObservation] - ): - def __call__( - self, - action: MockConfirmationModeAction, - conversation=None, # noqa: ARG002 - ) -> MockConfirmationModeObservation: - return MockConfirmationModeObservation( - result=f"Executed: {action.command}" - ) - - def _make_tool(conv_state=None, **params) -> Sequence[ToolDefinition]: - return [ - ToolDefinition( - name="test_tool", - description="A test tool", - action_type=MockConfirmationModeAction, - observation_type=MockConfirmationModeObservation, - executor=TestExecutor(), - ) - ] - register_tool("test_tool", _make_tool) self.agent: Agent = Agent( diff --git a/tests/sdk/conversation/local/test_conversation_pause_functionality.py b/tests/sdk/conversation/local/test_conversation_pause_functionality.py index 6a95a8d7b4..be5ed82f57 100644 --- a/tests/sdk/conversation/local/test_conversation_pause_functionality.py +++ b/tests/sdk/conversation/local/test_conversation_pause_functionality.py @@ -77,6 +77,66 @@ def __call__( return PauseFunctionalityMockObservation(result=f"Executed: {action.command}") +class TestExecutor( + ToolExecutor[PauseFunctionalityMockAction, PauseFunctionalityMockObservation] +): + """Test executor for pause functionality testing.""" + + def __call__( + self, + action: PauseFunctionalityMockAction, + conversation: BaseConversation | None = None, + ) -> PauseFunctionalityMockObservation: + return PauseFunctionalityMockObservation(result=f"Executed: {action.command}") + + +class PauseFunctionalityTestTool( + ToolDefinition[PauseFunctionalityMockAction, PauseFunctionalityMockObservation] +): + """Concrete tool for pause functionality testing.""" + + @classmethod + def create( + cls, conv_state=None, **params + ) -> Sequence["PauseFunctionalityTestTool"]: + return [ + cls( + name="test_tool", + description="A test tool", + action_type=PauseFunctionalityMockAction, + observation_type=PauseFunctionalityMockObservation, + executor=TestExecutor(), + ) + ] + + +def _make_tool(conv_state=None, **params) -> Sequence[ToolDefinition]: + """Factory function for creating test tools.""" + return PauseFunctionalityTestTool.create(conv_state, **params) + + +class BlockingTestTool( + ToolDefinition[PauseFunctionalityMockAction, PauseFunctionalityMockObservation] +): + """Concrete tool for blocking pause testing.""" + + @classmethod + def create( + cls, conv_state=None, step_entered=None, **params + ) -> Sequence["BlockingTestTool"]: + if step_entered is None: + raise ValueError("step_entered is required for BlockingTestTool") + return [ + cls( + name="test_tool", + description="Blocking tool for pause test", + action_type=PauseFunctionalityMockAction, + observation_type=PauseFunctionalityMockObservation, + executor=BlockingExecutor(step_entered), + ) + ] + + class TestPauseFunctionality: """Test suite for pause functionality.""" @@ -87,31 +147,6 @@ def setup_method(self): model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test-llm" ) - class TestExecutor( - ToolExecutor[ - PauseFunctionalityMockAction, PauseFunctionalityMockObservation - ] - ): - def __call__( - self, - action: PauseFunctionalityMockAction, - conversation: BaseConversation | None = None, - ) -> PauseFunctionalityMockObservation: - return PauseFunctionalityMockObservation( - result=f"Executed: {action.command}" - ) - - def _make_tool(conv_state=None, **params) -> Sequence[ToolDefinition]: - return [ - ToolDefinition( - name="test_tool", - description="A test tool", - action_type=PauseFunctionalityMockAction, - observation_type=PauseFunctionalityMockObservation, - executor=TestExecutor(), - ) - ] - register_tool("test_tool", _make_tool) self.agent: Agent = Agent( @@ -309,15 +344,9 @@ def test_pause_while_running_continuous_actions(self, mock_completion): step_entered = threading.Event() def _make_blocking_tool(conv_state=None, **kwargs) -> Sequence[ToolDefinition]: - return [ - ToolDefinition( - name="test_tool", - description="Blocking tool for pause test", - action_type=PauseFunctionalityMockAction, - observation_type=PauseFunctionalityMockObservation, - executor=BlockingExecutor(step_entered), - ) - ] + return BlockingTestTool.create( + conv_state, step_entered=step_entered, **kwargs + ) register_tool("test_tool", _make_blocking_tool) agent = Agent( diff --git a/tests/sdk/llm/test_llm_completion.py b/tests/sdk/llm/test_llm_completion.py index 8bf9e8a2ac..dcf93b6661 100644 --- a/tests/sdk/llm/test_llm_completion.py +++ b/tests/sdk/llm/test_llm_completion.py @@ -1,5 +1,6 @@ """Tests for LLM completion functionality, configuration, and metrics tracking.""" +from collections.abc import Sequence from unittest.mock import patch import pytest @@ -19,7 +20,7 @@ TextContent, ) from openhands.sdk.tool.schema import Action -from openhands.sdk.tool.tool import ToolBase, ToolDefinition +from openhands.sdk.tool.tool import ToolDefinition def create_mock_response(content: str = "Test response", response_id: str = "test-id"): @@ -41,6 +42,23 @@ def create_mock_response(content: str = "Test response", response_id: str = "tes ) +# Helper tool classes for testing +class _ArgsBasic(Action): + """Basic action for testing.""" + + param: str + + +class _MockTool(ToolDefinition[_ArgsBasic, None]): + """Mock tool for LLM completion testing.""" + + @classmethod + def create(cls, conv_state=None, **params) -> Sequence["_MockTool"]: + return [ + cls(name="test_tool", description="A test tool", action_type=_ArgsBasic) + ] + + @pytest.fixture def default_config(): return LLM( @@ -127,13 +145,7 @@ def test_llm_completion_with_tools(mock_completion): # Test completion with tools messages = [Message(role="user", content=[TextContent(text="Use the test tool")])] - class _ArgsBasic(Action): - param: str - - tool: ToolBase = ToolDefinition( - name="test_tool", description="A test tool", action_type=_ArgsBasic - ) - tools_list: list[ToolBase] = [tool] + tools_list = list(_MockTool.create()) response = llm.completion(messages=messages, tools=tools_list) @@ -333,16 +345,7 @@ def test_llm_completion_non_function_call_mode(mock_completion): ) ] - class TestNonFCArgs(Action): - param: str - - tools: list[ToolBase] = [ - ToolDefinition( - name="test_tool", - description="A test tool for non-function call mode", - action_type=TestNonFCArgs, - ) - ] + tools = list(_MockTool.create()) # Verify that tools should be mocked (non-function call path) cc_tools = [t.to_openai_tool(add_security_risk_prediction=False) for t in tools] @@ -390,14 +393,7 @@ def test_llm_completion_function_call_vs_non_function_call_mode(mock_completion) mock_response = create_mock_response("Test response") mock_completion.return_value = mock_response - class TestFCArgs(Action): - param: str | None = None - - tools: list[ToolBase] = [ - ToolDefinition( - name="test_tool", description="A test tool", action_type=TestFCArgs - ) - ] + tools = list(_MockTool.create()) messages = [Message(role="user", content=[TextContent(text="Use the test tool")])] # Test with native function calling enabled (default behavior for gpt-4o) diff --git a/tests/sdk/mcp/test_mcp_tool_serialization.py b/tests/sdk/mcp/test_mcp_tool_serialization.py index a073927610..a8f7e0d037 100644 --- a/tests/sdk/mcp/test_mcp_tool_serialization.py +++ b/tests/sdk/mcp/test_mcp_tool_serialization.py @@ -13,7 +13,7 @@ from openhands.sdk.mcp.definition import MCPToolAction, MCPToolObservation from openhands.sdk.mcp.tool import MCPToolDefinition from openhands.sdk.tool.schema import Action -from openhands.sdk.tool.tool import ToolBase +from openhands.sdk.tool.tool import ToolDefinition def create_mock_mcp_tool(name: str) -> mcp.types.Tool: @@ -57,8 +57,8 @@ def test_mcp_tool_polymorphic_behavior() -> None: tools = MCPToolDefinition.create(mock_mcp_tool, mock_client) mcp_tool = tools[0] # Extract single tool from sequence - # Should be instance of ToolBase - assert isinstance(mcp_tool, ToolBase) + # Should be instance of ToolDefinition + assert isinstance(mcp_tool, ToolDefinition) assert isinstance(mcp_tool, MCPToolDefinition) # Check basic properties @@ -99,8 +99,8 @@ def test_mcp_tool_fallback_behavior() -> None: }, } - deserialized_tool = ToolBase.model_validate(tool_data) - assert isinstance(deserialized_tool, ToolBase) + deserialized_tool = ToolDefinition.model_validate(tool_data) + assert isinstance(deserialized_tool, ToolDefinition) assert deserialized_tool.name == "fallback-tool" assert issubclass(deserialized_tool.action_type, Action) assert deserialized_tool.observation_type and issubclass( diff --git a/tests/sdk/tool/test_builtins.py b/tests/sdk/tool/test_builtins.py index ca3da4caa8..d69f0d589d 100644 --- a/tests/sdk/tool/test_builtins.py +++ b/tests/sdk/tool/test_builtins.py @@ -2,13 +2,22 @@ def test_all_tools_property(): - for tool in BUILT_IN_TOOLS: - assert tool.description is not None - assert tool.executor is not None - assert tool.annotations is not None - # Annotations should have specific hints - # Builtin tools should have all these properties - assert tool.annotations.readOnlyHint - assert not tool.annotations.destructiveHint - assert tool.annotations.idempotentHint - assert not tool.annotations.openWorldHint + # BUILT_IN_TOOLS contains tool classes, so we need to instantiate them + for tool_class in BUILT_IN_TOOLS: + # Create tool instances using .create() method + tool_instances = tool_class.create() + assert len(tool_instances) > 0, ( + f"{tool_class.__name__}.create() should return at least one tool" + ) + + # Check properties for all instances (usually just one) + for tool in tool_instances: + assert tool.description is not None + assert tool.executor is not None + assert tool.annotations is not None + # Annotations should have specific hints + # Builtin tools should have all these properties + assert tool.annotations.readOnlyHint + assert not tool.annotations.destructiveHint + assert tool.annotations.idempotentHint + assert not tool.annotations.openWorldHint diff --git a/tests/sdk/tool/test_registry.py b/tests/sdk/tool/test_registry.py index 5cdb176f73..8192544952 100644 --- a/tests/sdk/tool/test_registry.py +++ b/tests/sdk/tool/test_registry.py @@ -69,16 +69,24 @@ def __call__( ] +class _SimpleHelloTool(ToolDefinition[_HelloAction, _HelloObservation]): + """Simple concrete tool for registry testing.""" + + @classmethod + def create(cls, conv_state=None, **params) -> Sequence["_SimpleHelloTool"]: + return [ + cls( + name="say_hello", + description="Says hello", + action_type=_HelloAction, + observation_type=_HelloObservation, + executor=_HelloExec(), + ) + ] + + def _hello_tool_factory(conv_state=None, **params) -> list[ToolDefinition]: - return [ - ToolDefinition( - name="say_hello", - description="Says hello", - action_type=_HelloAction, - observation_type=_HelloObservation, - executor=_HelloExec(), - ) - ] + return list(_SimpleHelloTool.create(conv_state, **params)) def test_register_and_resolve_callable_factory(): diff --git a/tests/sdk/tool/test_to_responses_tool.py b/tests/sdk/tool/test_to_responses_tool.py index c33d90659f..a0b4985faf 100644 --- a/tests/sdk/tool/test_to_responses_tool.py +++ b/tests/sdk/tool/test_to_responses_tool.py @@ -1,5 +1,5 @@ from openhands.sdk.tool.schema import Action, Observation -from openhands.sdk.tool.tool import ToolBase +from openhands.sdk.tool.tool import ToolDefinition class A(Action): @@ -13,7 +13,7 @@ def to_llm_content(self): # type: ignore[override] return [TextContent(text="ok")] -class T(ToolBase[A, Obs]): +class T(ToolDefinition[A, Obs]): @classmethod def create(cls, *args, **kwargs): # pragma: no cover raise NotImplementedError diff --git a/tests/sdk/tool/test_to_responses_tool_security.py b/tests/sdk/tool/test_to_responses_tool_security.py index 87b20c4c97..7eb7ab4eb4 100644 --- a/tests/sdk/tool/test_to_responses_tool_security.py +++ b/tests/sdk/tool/test_to_responses_tool_security.py @@ -1,15 +1,25 @@ +from collections.abc import Sequence + from pydantic import Field -from openhands.sdk.tool import Action, ToolAnnotations, ToolDefinition +from openhands.sdk.tool import Action, Observation, ToolAnnotations, ToolDefinition class TRTSAction(Action): x: int = Field(description="x") +class MockSecurityTool(ToolDefinition[TRTSAction, Observation]): + """Concrete mock tool for security testing.""" + + @classmethod + def create(cls, conv_state=None, **params) -> Sequence["MockSecurityTool"]: + return [cls(**params)] + + def test_to_responses_tool_security_gating(): # readOnlyHint=True -> do not add security_risk even if requested - readonly = ToolDefinition( + readonly = MockSecurityTool( name="t1", description="d", action_type=TRTSAction, @@ -24,7 +34,7 @@ def test_to_responses_tool_security_gating(): assert "security_risk" not in props # readOnlyHint=False -> add when requested - writable = ToolDefinition( + writable = MockSecurityTool( name="t2", description="d", action_type=TRTSAction, @@ -39,7 +49,7 @@ def test_to_responses_tool_security_gating(): assert "security_risk" in props2 # add_security_risk_prediction=False -> never add - noflag = ToolDefinition( + noflag = MockSecurityTool( name="t3", description="d", action_type=TRTSAction, diff --git a/tests/sdk/tool/test_tool_call_output_coercion.py b/tests/sdk/tool/test_tool_call_output_coercion.py index 2d32ba5461..23f3caabbb 100644 --- a/tests/sdk/tool/test_tool_call_output_coercion.py +++ b/tests/sdk/tool/test_tool_call_output_coercion.py @@ -1,3 +1,5 @@ +from collections.abc import Sequence + import pytest from pydantic import Field @@ -19,6 +21,14 @@ def to_llm_content(self): # type: ignore[override] return [TextContent(text=str(self.value))] +class MockCoercionTool(ToolDefinition[OCAAction, OCAObs]): + """Concrete mock tool for output coercion testing.""" + + @classmethod + def create(cls, conv_state=None, **params) -> Sequence["MockCoercionTool"]: + return [cls(**params)] + + def test_tool_call_with_observation_none_result_shapes(): # When observation_type is None, results are wrapped/coerced to Observation # 1) dict -> Observation @@ -26,7 +36,7 @@ class E1(ToolExecutor[OCAAction, dict[str, object]]): def __call__(self, action: OCAAction, conversation=None) -> dict[str, object]: return {"kind": "OCAObs", "value": 1} - t = ToolDefinition( + t = MockCoercionTool( name="t", description="d", action_type=OCAAction, @@ -50,7 +60,7 @@ class E2(ToolExecutor[OCAAction, MObs]): def __call__(self, action: OCAAction, conversation=None) -> MObs: return MObs(value=2) - t2 = ToolDefinition( + t2 = MockCoercionTool( name="t2", description="d", action_type=OCAAction, @@ -66,7 +76,7 @@ class E3(ToolExecutor[OCAAction, list[int]]): def __call__(self, action: OCAAction, conversation=None) -> list[int]: return [1, 2, 3] - t3 = ToolDefinition( + t3 = MockCoercionTool( name="t3", description="d", action_type=OCAAction, diff --git a/tests/sdk/tool/test_tool_definition.py b/tests/sdk/tool/test_tool_definition.py index 9c50a46b68..4550918332 100644 --- a/tests/sdk/tool/test_tool_definition.py +++ b/tests/sdk/tool/test_tool_definition.py @@ -36,12 +36,20 @@ def to_llm_content(self) -> Sequence[TextContent | ImageContent]: return [TextContent(text=self.result)] +class MockTestTool(ToolDefinition[ToolMockAction, ToolMockObservation]): + """Concrete mock tool for testing.""" + + @classmethod + def create(cls, conv_state=None, **params) -> Sequence["MockTestTool"]: + return [cls(**params)] + + class TestTool: """Test cases for the Tool class.""" def test_tool_creation_basic(self): """Test basic tool creation.""" - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -61,7 +69,7 @@ class MockExecutor(ToolExecutor): def __call__(self, action, conversation=None) -> ToolMockObservation: return ToolMockObservation(result=f"Executed: {action.command}") - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -84,7 +92,7 @@ def test_tool_creation_with_annotations(self): destructiveHint=False, ) - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -100,7 +108,7 @@ def test_tool_creation_with_annotations(self): def test_to_mcp_tool_basic(self): """Test conversion to MCP tool format.""" - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -129,7 +137,7 @@ def test_to_mcp_tool_with_annotations(self): readOnlyHint=True, ) - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -147,7 +155,7 @@ def test_to_mcp_tool_with_annotations(self): def test_call_without_executor(self): """Test calling tool without executor raises error.""" - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -167,7 +175,7 @@ class MockExecutor(ToolExecutor): def __call__(self, action, conversation=None) -> ToolMockObservation: return ToolMockObservation(result=f"Processed: {action.command}") - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -193,7 +201,7 @@ class ComplexAction(Action): default_factory=list, description="List of strings" ) - tool = ToolDefinition( + tool = MockTestTool( name="complex_tool", description="Tool with complex types", action_type=ComplexAction, @@ -217,7 +225,7 @@ class MockExecutor(ToolExecutor): def __call__(self, action, conversation=None) -> ToolMockObservation: return ToolMockObservation(result="success") - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -239,7 +247,7 @@ class MockExecutor(ToolExecutor): def __call__(self, action, conversation=None) -> ToolMockObservation: return ToolMockObservation(result="test", extra_field="extra_data") - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -256,7 +264,7 @@ def __call__(self, action, conversation=None) -> ToolMockObservation: def test_action_validation_with_nested_data(self): """Test action validation with nested data structures.""" - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -282,7 +290,7 @@ def test_schema_roundtrip_conversion(self): original_schema = ToolMockAction.to_mcp_schema() # Create tool and get its schema - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -298,7 +306,7 @@ def test_schema_roundtrip_conversion(self): def test_tool_with_no_observation_type(self): """Test tool creation with None observation type.""" - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -321,7 +329,7 @@ def __call__(self, action, conversation=None) -> ToolMockObservation: executor = MockExecutor() - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -339,7 +347,7 @@ def __call__(self, action, conversation=None) -> ToolMockObservation: def test_tool_name_validation(self): """Test tool name validation.""" # Valid names should work - tool = ToolDefinition( + tool = MockTestTool( name="valid_tool_name", description="A test tool", action_type=ToolMockAction, @@ -348,7 +356,7 @@ def test_tool_name_validation(self): assert tool.name == "valid_tool_name" # Empty name should still work (validation might be elsewhere) - tool2 = ToolDefinition( + tool2 = MockTestTool( name="", description="A test tool", action_type=ToolMockAction, @@ -376,7 +384,7 @@ def __call__(self, action, conversation=None) -> ComplexObservation: count=len(action.command) if hasattr(action, "command") else 0, ) - tool = ToolDefinition( + tool = MockTestTool( name="complex_tool", description="Tool with complex observation", action_type=ToolMockAction, @@ -398,7 +406,7 @@ class FailingExecutor(ToolExecutor): def __call__(self, action, conversation=None) -> ToolMockObservation: raise RuntimeError("Executor failed") - tool = ToolDefinition( + tool = MockTestTool( name="failing_tool", description="Tool that fails", action_type=ToolMockAction, @@ -425,7 +433,7 @@ class ValidExecutor(ToolExecutor): def __call__(self, action, conversation=None) -> StrictObservation: return StrictObservation(message="success", value=42) - tool = ToolDefinition( + tool = MockTestTool( name="strict_tool", description="Tool with strict observation", action_type=ToolMockAction, @@ -441,14 +449,14 @@ def __call__(self, action, conversation=None) -> StrictObservation: def test_tool_equality_and_hashing(self): """Test tool equality and hashing behavior.""" - tool1 = ToolDefinition( + tool1 = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, observation_type=ToolMockObservation, ) - tool2 = ToolDefinition( + tool2 = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -469,7 +477,7 @@ class RequiredFieldAction(Action): default=None, description="This field is optional" ) - tool = ToolDefinition( + tool = MockTestTool( name="required_tool", description="Tool with required fields", action_type=RequiredFieldAction, @@ -488,7 +496,7 @@ def test_tool_with_meta_data(self): """Test tool creation with metadata.""" meta_data = {"version": "1.0", "author": "test"} - tool = ToolDefinition( + tool = MockTestTool( name="meta_tool", description="Tool with metadata", action_type=ToolMockAction, @@ -525,7 +533,7 @@ class ComplexNestedAction(Action): default=None, description="Optional array" ) - tool = ToolDefinition( + tool = MockTestTool( name="complex_nested_tool", description="Tool with complex nested types", action_type=ComplexNestedAction, @@ -574,7 +582,7 @@ def test_security_risk_only_added_for_non_readonly_tools(self): readOnlyHint=True, ) - readonly_tool = ToolDefinition( + readonly_tool = MockTestTool( name="readonly_tool", description="A read-only tool", action_type=ToolMockAction, @@ -588,7 +596,7 @@ def test_security_risk_only_added_for_non_readonly_tools(self): readOnlyHint=False, ) - writable_tool = ToolDefinition( + writable_tool = MockTestTool( name="writable_tool", description="A writable tool", action_type=ToolMockAction, @@ -597,7 +605,7 @@ def test_security_risk_only_added_for_non_readonly_tools(self): ) # Test with tool that has no annotations (should be treated as writable) - no_annotations_tool = ToolDefinition( + no_annotations_tool = MockTestTool( name="no_annotations_tool", description="A tool with no annotations", action_type=ToolMockAction, @@ -662,7 +670,7 @@ def test_security_risk_is_required_field_in_schema(self): assert "security_risk" in schema["required"] # Test via to_openai_tool method - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -686,7 +694,7 @@ def test_security_risk_is_required_field_in_schema(self): readOnlyHint=False, ) - writable_tool = ToolDefinition( + writable_tool = MockTestTool( name="writable_tool", description="A writable tool", action_type=ToolMockAction, @@ -713,7 +721,7 @@ def __call__(self, action, conversation=None) -> ToolMockObservation: return ToolMockObservation(result=f"Executed: {action.command}") executor = MockExecutor() - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, @@ -734,7 +742,7 @@ def __call__(self, action, conversation=None) -> ToolMockObservation: def test_as_executable_without_executor(self): """Test as_executable() method with a tool that has no executor.""" - tool = ToolDefinition( + tool = MockTestTool( name="test_tool", description="A test tool", action_type=ToolMockAction, diff --git a/tests/sdk/tool/test_tool_immutability.py b/tests/sdk/tool/test_tool_immutability.py index f94cbaf448..21c9221bf0 100644 --- a/tests/sdk/tool/test_tool_immutability.py +++ b/tests/sdk/tool/test_tool_immutability.py @@ -36,12 +36,22 @@ def to_llm_content(self) -> Sequence[TextContent | ImageContent]: return [TextContent(text=self.result)] +class MockImmutableTool( + ToolDefinition[ToolImmutabilityMockAction, ToolImmutabilityMockObservation] +): + """Concrete mock tool for immutability testing.""" + + @classmethod + def create(cls, conv_state=None, **params) -> Sequence["MockImmutableTool"]: + return [cls(**params)] + + class TestToolImmutability: """Test suite for Tool immutability features.""" def test_tool_is_frozen(self): """Test that Tool instances are frozen and cannot be modified.""" - tool = ToolDefinition( + tool = MockImmutableTool( name="test_tool", description="Test tool", action_type=ToolImmutabilityMockAction, @@ -62,7 +72,7 @@ def test_tool_is_frozen(self): def test_tool_set_executor_returns_new_instance(self): """Test that set_executor returns a new Tool instance.""" - tool = ToolDefinition( + tool = MockImmutableTool( name="test_tool", description="Test tool", action_type=ToolImmutabilityMockAction, @@ -89,7 +99,7 @@ def __call__( def test_tool_model_copy_creates_modified_instance(self): """Test that model_copy can create modified versions of Tool instances.""" - tool = ToolDefinition( + tool = MockImmutableTool( name="test_tool", description="Test tool", action_type=ToolImmutabilityMockAction, @@ -111,7 +121,7 @@ def test_tool_model_copy_creates_modified_instance(self): def test_tool_meta_field_immutability(self): """Test that the meta field works correctly and is immutable.""" meta_data = {"version": "1.0", "author": "test"} - tool = ToolDefinition( + tool = MockImmutableTool( name="test_tool", description="Test tool", action_type=ToolImmutabilityMockAction, @@ -135,7 +145,7 @@ def test_tool_meta_field_immutability(self): def test_tool_constructor_parameter_validation(self): """Test that Tool constructor validates parameters correctly.""" # Test that new parameter names work - tool = ToolDefinition( + tool = MockImmutableTool( name="test_tool", description="Test tool", action_type=ToolImmutabilityMockAction, @@ -146,7 +156,7 @@ def test_tool_constructor_parameter_validation(self): # Test that invalid field types are rejected with pytest.raises(ValidationError): - ToolDefinition( + MockImmutableTool( name="test_tool", description="Test tool", action_type="invalid_type", # type: ignore[arg-type] # Should be a class, not string @@ -161,7 +171,7 @@ def test_tool_annotations_immutability(self): destructiveHint=False, ) - tool = ToolDefinition( + tool = MockImmutableTool( name="test_tool", description="Test tool", action_type=ToolImmutabilityMockAction, diff --git a/tests/sdk/tool/test_tool_serialization.py b/tests/sdk/tool/test_tool_serialization.py index 20c4636611..aa8c861cde 100644 --- a/tests/sdk/tool/test_tool_serialization.py +++ b/tests/sdk/tool/test_tool_serialization.py @@ -3,17 +3,17 @@ import json import pytest -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel from openhands.sdk.tool import ToolDefinition from openhands.sdk.tool.builtins import FinishTool, ThinkTool -from openhands.sdk.tool.tool import ToolBase def test_tool_serialization_deserialization() -> None: """Test that Tool supports polymorphic JSON serialization/deserialization.""" # Use FinishTool which is a simple built-in tool - tool = FinishTool + tool_instances = FinishTool.create() + tool = tool_instances[0] # Serialize to JSON tool_json = tool.model_dump_json() @@ -33,7 +33,8 @@ class Container(BaseModel): tool: ToolDefinition # Create container with tool - tool = FinishTool + tool_instances = FinishTool.create() + tool = tool_instances[0] container = Container(tool=tool) # Serialize to JSON @@ -54,8 +55,10 @@ class NestedContainer(BaseModel): tools: list[ToolDefinition] # Create container with multiple tools - tool1 = FinishTool - tool2 = ThinkTool + tool1_instances = FinishTool.create() + tool1 = tool1_instances[0] + tool2_instances = ThinkTool.create() + tool2 = tool2_instances[0] container = NestedContainer(tools=[tool1, tool2]) # Serialize to JSON @@ -75,7 +78,8 @@ class NestedContainer(BaseModel): def test_tool_model_validate_json_dict() -> None: """Test that Tool.model_validate works with dict from JSON.""" # Create tool - tool = FinishTool + tool_instances = FinishTool.create() + tool = tool_instances[0] # Serialize to JSON, then parse to dict tool_json = tool.model_dump_json() @@ -101,14 +105,15 @@ def test_tool_no_fallback_behavior_json() -> None: } tool_json = json.dumps(tool_dict) - with pytest.raises(ValidationError): - ToolBase.model_validate_json(tool_json) + with pytest.raises(ValueError, match="Unknown kind 'UnknownToolType'"): + ToolDefinition.model_validate_json(tool_json) def test_tool_type_annotation_works_json() -> None: """Test that ToolType annotation works correctly with JSON.""" # Create tool - tool = FinishTool + tool_instances = FinishTool.create() + tool = tool_instances[0] # Use ToolType annotation class TestModel(BaseModel): @@ -130,7 +135,8 @@ class TestModel(BaseModel): def test_tool_kind_field_json() -> None: """Test Tool kind field is correctly set and preserved through JSON.""" # Create tool - tool = FinishTool + tool_instances = FinishTool.create() + tool = tool_instances[0] # Check kind field assert hasattr(tool, "kind") diff --git a/tests/tools/browser_use/test_browser_toolset.py b/tests/tools/browser_use/test_browser_toolset.py index bb51b64e88..ec2c3221a7 100644 --- a/tests/tools/browser_use/test_browser_toolset.py +++ b/tests/tools/browser_use/test_browser_toolset.py @@ -1,20 +1,7 @@ """Test BrowserToolSet functionality.""" from openhands.sdk.tool import ToolDefinition -from openhands.sdk.tool.tool import ToolBase -from openhands.tools.browser_use import ( - BrowserToolSet, - browser_click_tool, - browser_close_tab_tool, - browser_get_content_tool, - browser_get_state_tool, - browser_go_back_tool, - browser_list_tabs_tool, - browser_navigate_tool, - browser_scroll_tool, - browser_switch_tab_tool, - browser_type_tool, -) +from openhands.tools.browser_use import BrowserToolSet from openhands.tools.browser_use.impl import BrowserToolExecutor @@ -39,16 +26,16 @@ def test_browser_toolset_create_includes_all_browser_tools(): # Expected tool names based on the browser tools expected_names = [ - browser_navigate_tool.name, - browser_click_tool.name, - browser_get_state_tool.name, - browser_get_content_tool.name, - browser_type_tool.name, - browser_scroll_tool.name, - browser_go_back_tool.name, - browser_list_tabs_tool.name, - browser_switch_tab_tool.name, - browser_close_tab_tool.name, + "browser_navigate", + "browser_click", + "browser_get_state", + "browser_get_content", + "browser_type", + "browser_scroll", + "browser_go_back", + "browser_list_tabs", + "browser_switch_tab", + "browser_close_tab", ] # Verify all expected tools are present @@ -80,14 +67,14 @@ def test_browser_toolset_create_tools_are_properly_configured(): # Find a specific tool to test (e.g., navigate tool) navigate_tool = None for tool in tools: - if tool.name == browser_navigate_tool.name: + if tool.name == "browser_navigate": navigate_tool = tool break assert navigate_tool is not None - assert navigate_tool.description == browser_navigate_tool.description - assert navigate_tool.action_type == browser_navigate_tool.action_type - assert navigate_tool.observation_type == browser_navigate_tool.observation_type + assert navigate_tool.description is not None + assert navigate_tool.action_type is not None + assert navigate_tool.observation_type is not None assert navigate_tool.executor is not None @@ -134,11 +121,11 @@ def test_browser_toolset_create_no_parameters(): def test_browser_toolset_inheritance(): """Test that BrowserToolSet properly inherits from Tool.""" - assert issubclass(BrowserToolSet, ToolBase) + assert issubclass(BrowserToolSet, ToolDefinition) # BrowserToolSet should not be instantiable directly (it's a factory) # The create method returns a list, not an instance of BrowserToolSet tools = BrowserToolSet.create() for tool in tools: assert not isinstance(tool, BrowserToolSet) - assert isinstance(tool, ToolBase) + assert isinstance(tool, ToolDefinition) diff --git a/tests/tools/execute_bash/test_schema.py b/tests/tools/execute_bash/test_schema.py index 86a821168a..829672d500 100644 --- a/tests/tools/execute_bash/test_schema.py +++ b/tests/tools/execute_bash/test_schema.py @@ -1,9 +1,14 @@ -from openhands.tools.execute_bash import execute_bash_tool +from openhands.tools.execute_bash import BashTool -def test_to_mcp_tool_detailed_type_validation_bash(): +def test_to_mcp_tool_detailed_type_validation_bash(mock_conversation_state): """Test detailed type validation for MCP tool schema generation (execute_bash).""" # noqa: E501 + execute_bash_tool = BashTool.create(conv_state=mock_conversation_state) + assert len(execute_bash_tool) == 1 + execute_bash_tool = execute_bash_tool[0] + assert isinstance(execute_bash_tool, BashTool) + # Test execute_bash tool schema bash_mcp = execute_bash_tool.to_mcp_tool() bash_schema = bash_mcp["inputSchema"] diff --git a/tests/tools/file_editor/test_schema.py b/tests/tools/file_editor/test_schema.py index 3fc526226e..deee278d30 100644 --- a/tests/tools/file_editor/test_schema.py +++ b/tests/tools/file_editor/test_schema.py @@ -1,9 +1,14 @@ -from openhands.tools.file_editor import file_editor_tool +from openhands.tools.file_editor import FileEditorTool -def test_to_mcp_tool_detailed_type_validation_editor(): +def test_to_mcp_tool_detailed_type_validation_editor(mock_conversation_state): """Test detailed type validation for MCP tool schema generation.""" + file_editor_tool = FileEditorTool.create(conv_state=mock_conversation_state) + assert len(file_editor_tool) == 1 + file_editor_tool = file_editor_tool[0] + assert isinstance(file_editor_tool, FileEditorTool) + # Test file_editor tool schema str_editor_mcp = file_editor_tool.to_mcp_tool() str_editor_schema = str_editor_mcp["inputSchema"]