87 changes: 72 additions & 15 deletions autogen/agentchat/conversable_agent.py
@@ -1289,7 +1289,7 @@ async def a_receive(
1. "content": content of the message, can be None.
2. "function_call": a dictionary containing the function name and arguments. (deprecated in favor of "tool_calls")
3. "tool_calls": a list of dictionaries containing the function name and arguments.
4. "role": role of the message, can be "assistant", "user", "function".
4. "role": role of the message, can be "assistant", "user", "function", "tool".
This field is only needed to distinguish between "function" or "assistant"/"user".
5. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name.
6. "context" (dict): the context of the message, which will be passed to
@@ -2201,7 +2201,11 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Option
extracted_response["function_call"]["name"]
)
for tool_call in extracted_response.get("tool_calls") or []:
tool_call["function"]["name"] = self._normalize_name(tool_call["function"]["name"])
if tool_call.get("type") == "custom":
tool_call["custom"]["name"] = self._normalize_name(tool_call["custom"]["name"])
else:
tool_call["function"]["name"] = self._normalize_name(tool_call["function"]["name"])

# Remove id and type if they are not present.
# This is to make the tool call object compatible with Mistral API.
if tool_call.get("id") is None:
@@ -2422,8 +2426,17 @@ def generate_tool_calls_reply(
message = messages[-1]
tool_returns = []
for tool_call in message.get("tool_calls", []):
function_call = tool_call.get("function", {})
tool_call_id = tool_call.get("id", None)

# Handle custom tool calls
if tool_call.get("type") == "custom" and "custom" in tool_call:
function_call = {
"name": tool_call["custom"].get("name"),
"arguments": tool_call["custom"].get("input", "{}"),
}
else:
function_call = tool_call.get("function", {})

func = self._function_map.get(function_call.get("name", None), None)
if inspect.iscoroutinefunction(func):
coro = self.a_execute_function(function_call, call_id=tool_call_id)
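
For reference, a minimal sketch of the two tool-call payload shapes this branch normalizes, assuming the OpenAI-style wire format implied by the diff (ids, tool names, and argument values are illustrative):

# Standard function tool call: arguments arrive as a JSON string.
function_style_call = {
    "id": "call_1",
    "type": "function",
    "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
}

# Custom (free-form) tool call: the model sends raw text under "custom" -> "input".
custom_style_call = {
    "id": "call_2",
    "type": "custom",
    "custom": {"name": "run_sql", "input": "SELECT 1;"},
}

# Both shapes are reduced to the same internal dict before execution:
# {"name": "<tool name>", "arguments": "<JSON string or raw input text>"}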
@@ -2458,7 +2471,17 @@ def generate_tool_calls_reply(

async def _a_execute_tool_call(self, tool_call):
tool_call_id = tool_call["id"]
function_call = tool_call.get("function", {})

# Handle custom tool calls
if tool_call.get("type") == "custom" and "custom" in tool_call:
function_call = {
"name": tool_call["custom"].get("name"),
"arguments": tool_call["custom"].get("input", "{}"),
}
else:
# Handle standard function tool calls
function_call = tool_call.get("function", {})

_, func_return = await self.a_execute_function(function_call, call_id=tool_call_id)
return {
"tool_call_id": tool_call_id,
@@ -3346,14 +3369,19 @@ def register_function(self, function_map: dict[str, Union[Callable[..., Any]]],
self._function_map = {k: v for k, v in self._function_map.items() if v is not None}

def update_function_signature(
self, func_sig: Union[str, dict[str, Any]], is_remove: None, silent_override: bool = False
self,
func_sig: Union[str, dict[str, Any]],
is_remove: None,
silent_override: bool = False,
free_form: bool = False,
):
"""Update a function_signature in the LLM configuration for function_call.

Args:
func_sig (str or dict): description/name of the function to update/remove to the model. See: https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions
is_remove: whether removing the function from llm_config with name 'func_sig'
silent_override: whether to print warnings when overriding functions.
free_form: whether to allow the function to take free-form inputs.

Deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0)
See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call
@@ -3396,17 +3424,27 @@ def update_function_signature(
if len(self.llm_config["functions"]) == 0 and isinstance(self.llm_config, dict):
del self.llm_config["functions"]

if free_form:
func_sig["type"] = "custom"
if "function" in func_sig:
func_sig["custom"] = func_sig.pop("function")

self.client = OpenAIWrapper(**self.llm_config)

def update_tool_signature(
self, tool_sig: Union[str, dict[str, Any]], is_remove: bool, silent_override: bool = False
self,
tool_sig: Union[str, dict[str, Any]],
is_remove: bool,
silent_override: bool = False,
free_form: bool = False,
):
"""Update a tool_signature in the LLM configuration for tool_call.

Args:
tool_sig (str or dict): description/name of the tool to update/remove to the model. See: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools
is_remove: whether removing the tool from llm_config with name 'tool_sig'
silent_override: whether to print warnings when overriding functions.
free_form: whether to allow the tool to take free-form inputs.
"""
if not self.llm_config:
error_msg = "To update a tool signature, agent must have an llm_config"
@@ -3455,6 +3493,10 @@ def update_tool_signature(
# Do this only if llm_config is a dict. If llm_config is LLMConfig, LLMConfig will handle this.
if len(self.llm_config["tools"]) == 0 and isinstance(self.llm_config, dict):
del self.llm_config["tools"]
if free_form:
tool_sig["type"] = "custom"
if "function" in tool_sig:
tool_sig["custom"] = tool_sig.pop("function")

self.client = OpenAIWrapper(**self.llm_config)
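
In effect, when free_form=True the registered signature is re-keyed from the function-tool shape to the custom-tool shape before the client is rebuilt. A before/after sketch of the transformation above (tool name and description are illustrative):

# Before: a standard function-tool signature.
tool_sig = {
    "type": "function",
    "function": {"name": "run_sql", "description": "Run a SQL statement."},
}

# After: type forced to "custom" and the "function" entry re-keyed to "custom".
tool_sig = {
    "type": "custom",
    "custom": {"name": "run_sql", "description": "Run a SQL statement."},
}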

@@ -3510,15 +3552,16 @@ def _create_tool_if_needed(
func_or_tool: Union[F, Tool],
name: Optional[str],
description: Optional[str],
free_form: bool = False,
) -> Tool:
if isinstance(func_or_tool, Tool):
tool: Tool = func_or_tool
# create new tool object if name or description is not None
if name or description:
tool = Tool(func_or_tool=tool, name=name, description=description)
# create new tool object if name, description, or free_form is specified
if free_form or name or description:
tool = Tool(func_or_tool=tool, name=name, description=description, free_form=free_form)
elif inspect.isfunction(func_or_tool):
function: Callable[..., Any] = func_or_tool
tool = Tool(func_or_tool=function, name=name, description=description)
tool = Tool(func_or_tool=function, name=name, description=description, free_form=free_form)
else:
raise TypeError(f"'func_or_tool' must be a function or a Tool object, got '{type(func_or_tool)}' instead.")
return tool
@@ -3530,6 +3573,7 @@ def register_for_llm(
description: Optional[str] = None,
api_style: Literal["function", "tool"] = "tool",
silent_override: bool = False,
free_form: bool = False,
) -> Callable[[Union[F, Tool]], Tool]:
"""Decorator factory for registering a function to be used by an agent.

@@ -3548,6 +3592,7 @@ def register_for_llm(
`"function"` if `"tool"` doesn't work.
See [Azure OpenAI documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/function-calling?tabs=python) for details.
silent_override (bool): whether to suppress any override warning messages.
free_form (bool): whether to allow the function to take free-form inputs.

Returns:
The decorator for registering a function to be used by an agent.
@@ -3591,7 +3636,7 @@ def _decorator(
"""
tool = self._create_tool_if_needed(func_or_tool, name, description)

self._register_for_llm(tool, api_style, silent_override=silent_override)
self._register_for_llm(tool, api_style, silent_override=silent_override, free_form=free_form)
if tool not in self._tools:
self._tools.append(tool)

@@ -3600,7 +3645,12 @@ def _decorator(
return _decorator

def _register_for_llm(
self, tool: Tool, api_style: Literal["tool", "function"], is_remove: bool = False, silent_override: bool = False
self,
tool: Tool,
api_style: Literal["tool", "function"],
is_remove: bool = False,
silent_override: bool = False,
free_form: bool = False,
) -> None:
"""
Register a tool for LLM.
@@ -3610,6 +3660,7 @@ def _register_for_llm(
api_style: the API style for function call ("tool" or "function").
is_remove: whether to remove the function or tool.
silent_override: whether to suppress any override warning messages.
free_form: whether to allow the tool to take free-form inputs.

Returns:
None
@@ -3619,9 +3670,13 @@ def _register_for_llm(
raise RuntimeError("LLM config must be setup before registering a function for LLM.")

if api_style == "function":
self.update_function_signature(tool.function_schema, is_remove=is_remove, silent_override=silent_override)
self.update_function_signature(
tool.function_schema, is_remove=is_remove, silent_override=silent_override, free_form=free_form
)
elif api_style == "tool":
self.update_tool_signature(tool.tool_schema, is_remove=is_remove, silent_override=silent_override)
self.update_tool_signature(
tool.tool_schema, is_remove=is_remove, silent_override=silent_override, free_form=free_form
)
else:
raise ValueError(f"Unsupported API style: {api_style}")

@@ -4065,6 +4120,7 @@ def register_function(
executor: ConversableAgent,
name: Optional[str] = None,
description: str,
free_form: bool = False,
) -> None:
"""Register a function to be proposed by an agent and executed for an executor.

@@ -4079,7 +4135,8 @@ def register_function(
description: description of the function. The description is used by LLM to decode whether the function
is called. Make sure the description is properly describing what the function does or it might not be
called by LLM when needed.
free_form: whether to allow the function to take free-form inputs.

"""
f = caller.register_for_llm(name=name, description=description)(f)
f = caller.register_for_llm(name=name, description=description, free_form=free_form)(f)
executor.register_for_execution(name=name)(f)
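
A usage sketch of the new free_form flag at the module level, assuming a working llm_config; the agent names and the run_script stub are hypothetical:

from autogen import ConversableAgent, register_function  # assumed public import path

llm_config = {"config_list": [{"model": "gpt-4o", "api_key": "sk-..."}]}  # placeholder credentials

assistant = ConversableAgent("assistant", llm_config=llm_config)
executor = ConversableAgent("executor", human_input_mode="NEVER")

def run_script(script: str) -> str:
    """Illustrative stub; a real tool would actually execute the script."""
    return f"ran: {script}"

# free_form=True registers run_script as a "custom" tool that accepts raw text input.
register_function(
    run_script,
    caller=assistant,
    executor=executor,
    description="Execute a free-form script.",
    free_form=True,
)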
30 changes: 25 additions & 5 deletions autogen/events/agent_events.py
@@ -4,7 +4,7 @@

from abc import ABC
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Optional, Union
from uuid import UUID

from pydantic import BaseModel, field_validator, model_serializer
@@ -116,12 +116,13 @@ def print(self, f: Optional[Callable[..., Any]] = None) -> None:
class FunctionCall(BaseModel):
name: Optional[str] = None
arguments: Optional[str] = None
input: Optional[str] = None  # Raw input payload for custom (free-form) tool calls

def print(self, f: Optional[Callable[..., Any]] = None) -> None:
f = f or print

name = self.name or "(No function name found)"
arguments = self.arguments or "(No arguments found)"
arguments = self.input or self.arguments or "(No arguments found)"

func_print = f"***** Suggested function call: {name} *****"
f(colored(func_print, "green"), flush=True)
@@ -152,17 +153,35 @@ def print(self, f: Optional[Callable[..., Any]] = None) -> None:


class ToolCall(BaseModel):
"""
Represents a tool call, which can be a function or a custom tool.

Params:
id (Optional[str]): The unique identifier for the tool call.
function (Optional[FunctionCall]): The function call details if type is 'function'.
custom (Optional[Dict[str, Any]]): The custom tool details if type is 'custom'.
type (str): The type of tool call, e.g., 'function' or 'custom'.
"""

id: Optional[str] = None
function: FunctionCall
function: Optional[FunctionCall] = None
custom: Optional[Dict[str, Any]] = None
type: str

def print(self, f: Optional[Callable[..., Any]] = None) -> None:
f = f or print

id = self.id or "No tool call id found"

name = self.function.name or "(No function name found)"
arguments = self.function.arguments or "(No arguments found)"
if self.type == "function" and self.function:
name = self.function.name or "(No function name found)"
arguments = self.function.arguments or "(No arguments found)"
elif self.type == "custom" and self.custom:
name = self.custom.get("name", "(No custom tool name found)")
arguments = self.custom.get("input", "(No input found)")
else:
name = "(Unknown tool type)"
arguments = "(No arguments found)"

func_print = f"***** Suggested tool call ({id}): {name} *****"
f(colored(func_print, "green"), flush=True)
@@ -254,6 +273,7 @@ def create_received_event_model(
)

if event.get("tool_calls"):
return ToolCallEvent(
**event,
sender=sender.name,
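
A small illustration of the extended event model (values are made up; the import path follows the file shown above):

from autogen.events.agent_events import ToolCall

call = ToolCall(id="call_2", type="custom", custom={"name": "run_sql", "input": "SELECT 1;"})
call.print()  # prints a "Suggested tool call (call_2): run_sql" banner followed by the resolved input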
18 changes: 15 additions & 3 deletions autogen/tools/tool.py
@@ -36,15 +36,19 @@ def __init__(
description: Optional[str] = None,
func_or_tool: Union["Tool", Callable[..., Any]],
parameters_json_schema: Optional[dict[str, Any]] = None,
free_form: bool = False,
) -> None:
"""Create a new Tool object.
"""
Create a new Tool object.

Args:
name (str): The name of the tool.
description (str): The description of the tool.
func_or_tool (Union[Tool, Callable[..., Any]]): The function or Tool instance to create a Tool from.
parameters_json_schema (Optional[dict[str, Any]]): A schema describing the parameters that the function accepts. If None, the schema will be generated from the function signature.
free_form (bool): whether to allow the tool to take free-form inputs.
"""
self._free_form: bool = free_form
if isinstance(func_or_tool, Tool):
self._name: str = name or func_or_tool.name
self._description: str = description or func_or_tool.description
@@ -95,7 +99,7 @@ def register_for_llm(self, agent: "ConversableAgent") -> None:
agent (ConversableAgent): The agent to which the tool will be registered.
"""
if self._func_schema:
agent.update_tool_signature(self._func_schema, is_remove=False)
agent.update_tool_signature(self._func_schema, is_remove=False, free_form=self._free_form)
agent.register_for_llm()(self)

def register_for_execution(self, agent: "ConversableAgent") -> None:
@@ -169,18 +173,26 @@ def realtime_tool_schema(self) -> dict[str, Any]:


@export_module("autogen.tools")
def tool(name: Optional[str] = None, description: Optional[str] = None) -> Callable[[Callable[..., Any]], Tool]:
def tool(
name: Optional[str] = None, description: Optional[str] = None, free_form: bool = False
) -> Callable[[Callable[..., Any]], Tool]:
"""Decorator to create a Tool from a function.

Args:
name (str): The name of the tool.
description (str): The description of the tool.
free_form (bool): whether to allow the tool to take free-form inputs.

Returns:
Callable[[Callable[..., Any]], Tool]: A decorator that creates a Tool from a function.
"""

def decorator(func: Callable[..., Any]) -> Tool:
if free_form:
func_description = description or func.__doc__ or ""
schema = get_function_schema(func, name=name, description=func_description)
schema["type"] = "custom"
return Tool(name=name, description=func_description, func_or_tool=func, parameters_json_schema=schema)
return Tool(name=name, description=description, func_or_tool=func)

return decorator
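
Finally, a sketch of the decorator path; the import follows @export_module("autogen.tools"), and the run_shell stub is hypothetical:

from autogen.tools import tool

@tool(description="Run a free-form shell command.", free_form=True)
def run_shell(command: str) -> str:
    """Illustrative stub; a real tool would execute the command."""
    return f"$ {command}"

# run_shell is now a Tool whose generated schema is tagged "type": "custom",
# so an agent that registers it would advertise a custom (raw text) tool.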