From 221ab62c8c521637f87abfa23c72b408c51358d6 Mon Sep 17 00:00:00 2001 From: Dhruv Chauhan <83240266+personal-dc@users.noreply.github.com> Date: Fri, 22 Aug 2025 14:58:51 -0400 Subject: [PATCH] added support for mcp tools for when people use flow handlers --- src/pipecat_flows/manager.py | 583 ++++++++++++++--------------------- 1 file changed, 228 insertions(+), 355 deletions(-) diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index cb2b71c1..64c07184 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -8,7 +8,6 @@ This module provides the FlowManager class which orchestrates conversations across different LLM providers. It supports: - - Static flows with predefined paths - Dynamic flows with runtime-determined transitions - State management and transitions @@ -17,7 +16,6 @@ - Cross-provider compatibility The flow manager coordinates all aspects of a conversation, including: - - LLM context management - Function registration - State transitions @@ -28,7 +26,6 @@ import asyncio import inspect import sys -import warnings from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union, cast from loguru import logger @@ -42,27 +39,19 @@ from pipecat.services.llm_service import FunctionCallParams from pipecat.transports.base_transport import BaseTransport -from pipecat_flows.actions import ActionError, ActionManager -from pipecat_flows.adapters import create_adapter -from pipecat_flows.exceptions import ( - FlowError, - FlowInitializationError, - FlowTransitionError, - InvalidFunctionError, -) -from pipecat_flows.types import ( +from .actions import ActionError, ActionManager +from .adapters import create_adapter +from .exceptions import FlowError, FlowInitializationError, FlowTransitionError +from .types import ( ActionConfig, - ConsolidatedFunctionResult, ContextStrategy, ContextStrategyConfig, FlowArgs, FlowConfig, FlowResult, - FlowsDirectFunctionWrapper, FlowsFunctionSchema, FunctionHandler, NodeConfig, - get_or_generate_node_name, ) if TYPE_CHECKING: @@ -76,16 +65,20 @@ class FlowManager: - """Manages conversation flows supporting both static and dynamic configurations. + """Manages conversation flows, supporting both static and dynamic configurations. The FlowManager orchestrates conversation flows by managing state transitions, function registration, and message handling across different LLM providers. - It supports both predefined static flows and runtime-determined dynamic flows - with comprehensive action handling and error management. - The manager coordinates all aspects of a conversation including LLM context - management, function registration, state transitions, action execution, and - provider-specific format handling. + Attributes: + task: Pipeline task for frame queueing + llm: LLM service instance (OpenAI, Anthropic, or Google) + tts: Optional TTS service for voice actions + state: Shared state dictionary across nodes + current_node: Currently active node identifier + initialized: Whether the manager has been initialized + nodes: Node configurations for static flows + current_functions: Currently registered function names """ def __init__( @@ -102,42 +95,26 @@ def __init__( """Initialize the flow manager. Args: - task: PipelineTask instance for queueing frames. - llm: LLM service instance (e.g., OpenAI, Anthropic, Google). - context_aggregator: Context aggregator for updating user context. - tts: Text-to-speech service for voice actions. - - .. deprecated:: 0.0.18 - The tts parameter is deprecated and will be removed in 1.0.0. - - flow_config: Static flow configuration. If provided, operates in static - mode with predefined nodes. - - .. deprecated:: 0.0.19 - Static flows are deprecated and will be removed in 1.0.0. - Use dynamic flows instead. - - context_strategy: Context strategy configuration for managing conversation - context during transitions. - transport: Transport instance for communication. + task: PipelineTask instance for queueing frames + llm: LLM service instance (e.g., OpenAI, Anthropic) + context_aggregator: Context aggregator for updating user context + tts: Optional TTS service for voice actions + flow_config: Optional static flow configuration. If provided, + operates in static mode with predefined nodes + context_strategy: Optional context strategy configuration + transport: Optional transport Raises: - ValueError: If any transition handler is not a valid async callable. + ValueError: If any transition handler is not a valid async callable """ - if tts is not None: - warnings.warn( - "The 'tts' parameter is deprecated and will be removed in 1.0.0.", - DeprecationWarning, - stacklevel=2, - ) - self.task = task self.llm = llm - self.action_manager = ActionManager(task, flow_manager=self) + self.tts = tts + self.action_manager = ActionManager(task, flow_manager=self, tts=tts) self.adapter = create_adapter(llm) self.initialized = False self._context_aggregator = context_aggregator - self._pending_transition: Optional[Dict[str, Any]] = None + self._pending_function_calls = 0 self._context_strategy = context_strategy or ContextStrategyConfig( strategy=ContextStrategy.APPEND ) @@ -145,11 +122,6 @@ def __init__( # Set up static or dynamic mode if flow_config: - warnings.warn( - "Static flows are deprecated as of 0.0.19 and will be removed in 1.0.0.", - DeprecationWarning, - stacklevel=2, - ) self.nodes = flow_config["nodes"] self.initial_node = flow_config["initial_node"] logger.debug("Initialized in static mode") @@ -162,9 +134,6 @@ def __init__( self.current_functions: Set[str] = set() # Track registered functions self.current_node: Optional[str] = None - self._showed_deprecation_warning_for_transition_fields = False - self._showed_deprecation_warning_for_set_node = False - def _validate_transition_callback(self, name: str, callback: Any) -> None: """Validate a transition callback. @@ -180,32 +149,8 @@ def _validate_transition_callback(self, name: str, callback: Any) -> None: if not inspect.iscoroutinefunction(callback): raise ValueError(f"Transition callback for {name} must be async") - async def initialize(self, initial_node: Optional[NodeConfig] = None) -> None: - """Initialize the flow manager. - - Args: - initial_node: Optional initial node configuration for dynamic flows. - - Raises: - FlowInitializationError: If initialization fails. - - Examples: - Static flow:: - - flow_manager = FlowManager( - ... # Initialization parameters - ) - # Static flow: no initialization args required - flow_manager.initialize() - - Dynamic flow:: - - flow_manager = FlowManager( - ... # Initialization parameters - ) - # Dynamic flow: Initialize with the initial node configuration - flow_manager.initialize(create_initial_node()) - """ + async def initialize(self) -> None: + """Initialize the flow manager.""" if self.initialized: logger.warning(f"{self.__class__.__name__} already initialized") return @@ -214,26 +159,10 @@ async def initialize(self, initial_node: Optional[NodeConfig] = None) -> None: self.initialized = True logger.debug(f"Initialized {self.__class__.__name__}") - # Set initial node - node_name = None - node = None + # If in static mode, set initial node if self.initial_node: - # Static flow: self.initial_node is expected to be there - node_name = self.initial_node - node = self.nodes[self.initial_node] - if not node: - raise ValueError( - f"Initial node '{self.initial_node}' not found in static flow configuration" - ) - else: - # Dynamic flow: initial_node argument may have been provided (otherwise initial node - # will be set later via set_node()) - if initial_node: - node_name = get_or_generate_node_name(initial_node) - node = initial_node - if node_name: - logger.debug(f"Setting initial node: {node_name}") - await self._set_node(node_name, node) + logger.debug(f"Setting initial node: {self.initial_node}") + await self.set_node(self.initial_node, self.nodes[self.initial_node]) except Exception as e: self.initialized = False @@ -247,7 +176,7 @@ def get_current_context(self) -> List[dict]: user messages, and assistant responses. Raises: - FlowError: If context aggregator is not available. + FlowError: If context aggregator is not available """ if not self._context_aggregator: raise FlowError("No context aggregator available") @@ -258,11 +187,10 @@ def register_action(self, action_type: str, handler: Callable) -> None: """Register a handler for a specific action type. Args: - action_type: String identifier for the action (e.g., "tts_say"). - handler: Async or sync function that handles the action. - - Example:: + action_type: String identifier for the action (e.g., "tts_say") + handler: Async or sync function that handles the action + Example: async def custom_notification(action: dict): text = action.get("text", "") await notify_user(text) @@ -275,10 +203,10 @@ def _register_action_from_config(self, action: ActionConfig) -> None: """Register an action handler from action configuration. Args: - action: Action configuration dictionary containing type and optional handler. + action: Action configuration dictionary containing type and optional handler Raises: - ActionError: If action type is not registered and no valid handler provided. + ActionError: If action type is not registered and no valid handler provided """ action_type = action.get("type") handler = action.get("handler") @@ -296,26 +224,30 @@ def _register_action_from_config(self, action: ActionConfig) -> None: "Provide handler in action config or register manually." ) - async def _call_handler( - self, handler: FunctionHandler, args: FlowArgs - ) -> FlowResult | ConsolidatedFunctionResult: + async def _call_handler(self, handler: FunctionHandler, args: FlowArgs) -> FlowResult: """Call handler with appropriate parameters based on its signature. Detects whether the handler can accept a flow_manager parameter and calls it accordingly to maintain backward compatibility with legacy handlers. Args: - handler: The function handler to call (either legacy or modern format). - args: Arguments dictionary from the function call. + handler: The function handler to call (either legacy or modern format) + args: Arguments dictionary from the function call Returns: - The result returned by the handler. + FlowResult: The result returned by the handler """ # Get the function signature sig = inspect.signature(handler) - # Calculate effective parameter count - effective_param_count = len(sig.parameters) + # Check if handler is a method (has self parameter) + is_method = inspect.ismethod(handler) + + # Calculate effective parameter count (excluding 'self' if method) + if is_method: + effective_param_count = len(sig.parameters) - 1 + else: + effective_param_count = len(sig.parameters) # Handle different function signatures if effective_param_count == 0: @@ -331,23 +263,26 @@ async def _call_handler( async def _create_transition_func( self, name: str, - handler: Optional[Callable | FlowsDirectFunctionWrapper], + handler: Optional[Callable], transition_to: Optional[str], transition_callback: Optional[Callable] = None, + is_mcp: bool = False, ) -> Callable: """Create a transition function for the given name and handler. Args: - name: Name of the function being registered. - handler: Optional function to process data. - transition_to: Optional node to transition to (static flows). - transition_callback: Optional callback for dynamic transitions. + name: Name of the function being registered + handler: Optional function to process data + transition_to: Optional node to transition to (static flows) + transition_callback: Optional callback for dynamic transitions + is_mcp: Is this function registered via an MCP server. + Returns: - Async function that handles the tool invocation. + Callable: Async function that handles the tool invocation Raises: - ValueError: If both transition_to and transition_callback are specified. + ValueError: If both transition_to and transition_callback are specified """ if transition_to and transition_callback: raise ValueError( @@ -358,131 +293,135 @@ async def _create_transition_func( if transition_callback: self._validate_transition_callback(name, transition_callback) + # If MCP tool without handler, get it from LLM's registered functions + if is_mcp and handler is None: + if hasattr(self.llm, '_functions') and name in self.llm._functions: + existing_item = self.llm._functions[name] + if existing_item and hasattr(existing_item, 'handler') and existing_item.handler: + handler = existing_item.handler + logger.debug(f"Using MCP-registered handler for: {name}") + + is_edge_function = bool(transition_to) or bool(transition_callback) + + def decrease_pending_function_calls() -> None: + """Decrease the pending function calls counter if greater than zero.""" + if self._pending_function_calls > 0: + self._pending_function_calls -= 1 + logger.debug( + f"Function call completed: {name} (remaining: {self._pending_function_calls})" + ) + + async def on_context_updated_edge( + args: Dict[str, Any], result: Any, result_callback: Callable + ) -> None: + """Handle context updates for edge functions with transitions.""" + try: + decrease_pending_function_calls() + + # Only process transition if this was the last pending call + if self._pending_function_calls == 0: + if transition_to: # Static flow + logger.debug(f"Static transition to: {transition_to}") + await self.set_node(transition_to, self.nodes[transition_to]) + elif transition_callback: # Dynamic flow + logger.debug(f"Dynamic transition for: {name}") + # Check callback signature + sig = inspect.signature(transition_callback) + if len(sig.parameters) == 2: + # Old style: (args, flow_manager) + await transition_callback(args, self) + else: + # New style: (args, result, flow_manager) + await transition_callback(args, result, self) + # Reset counter after transition completes + self._pending_function_calls = 0 + logger.debug("Reset pending function calls counter") + else: + logger.debug( + f"Skipping transition, {self._pending_function_calls} calls still pending" + ) + except Exception as e: + logger.error(f"Error in transition: {str(e)}") + self._pending_function_calls = 0 + await result_callback( + {"status": "error", "error": str(e)}, + properties=None, # Clear properties to prevent further callbacks + ) + raise # Re-raise to prevent further processing + + async def on_context_updated_node() -> None: + """Handle context updates for node functions without transitions.""" + decrease_pending_function_calls() + async def transition_func(params: FunctionCallParams) -> None: """Inner function that handles the actual tool invocation.""" try: - logger.debug(f"Function called: {name}") + # Track pending function call + self._pending_function_calls += 1 + logger.debug( + f"Function call pending: {name} (total: {self._pending_function_calls})" + ) # Execute handler if present - is_transition_only_function = False - acknowledged_result = {"status": "acknowledged"} if handler: - # Invoke the handler with the provided arguments - if isinstance(handler, FlowsDirectFunctionWrapper): - handler_response = await handler.invoke(params.arguments, self) - else: - handler_response = await self._call_handler(handler, params.arguments) - # Support both "consolidated" handlers that return (result, next_node) and handlers - # that return just the result. - if isinstance(handler_response, tuple): - result, next_node = handler_response - if result is None: - result = acknowledged_result - is_transition_only_function = True + if is_mcp: + logger.debug(f"Calling MCP tool: {name}") + + # MCP tools handle their own result callback, so we need to capture it + mcp_result = None + + async def mcp_result_callback(response): + nonlocal mcp_result + mcp_result = response + logger.debug(f"MCP tool {name} returned: {str(response)[:200]}") + + # Call MCP wrapper with its expected signature + mcp_params = FunctionCallParams( + function_name=name, + tool_call_id=getattr(params, 'tool_call_id', f"{name}_call"), + arguments=params.arguments, + llm=self.llm, + context=getattr(params, 'context', None), + result_callback=mcp_result_callback + ) + + await handler(mcp_params) + + # Use the captured result + result = mcp_result if mcp_result is not None else {"status": "completed"} + logger.debug(f"MCP handler completed for {name}") else: - result = handler_response - next_node = None - # FlowsDirectFunctions should always be "consolidated" functions that return a tuple - if isinstance(handler, FlowsDirectFunctionWrapper): - raise InvalidFunctionError( - f"Direct function {name} expected to return a tuple (result, next_node) but got {type(result)}" - ) + # Regular handler - use FlowManager's calling convention + result = await self._call_handler(handler, params.arguments) + logger.debug(f"Handler completed for {name}") else: - result = acknowledged_result - next_node = None - is_transition_only_function = True - - logger.debug( - f"{'Transition-only function called for' if is_transition_only_function else 'Function handler completed for'} {name}" - ) + result = {"status": "acknowledged"} + logger.debug(f"Function called without handler: {name}") + + # For edge functions, prevent LLM completion until transition (run_llm=False) + # For node functions, allow immediate completion (run_llm=True) + async def on_context_updated() -> None: + if is_edge_function: + await on_context_updated_edge( + params.arguments, result, params.result_callback + ) + else: + await on_context_updated_node() - # Determine if this is an edge function - is_edge_function = ( - bool(next_node) or bool(transition_to) or bool(transition_callback) + properties = FunctionCallResultProperties( + run_llm=not is_edge_function, + on_context_updated=on_context_updated, ) - - if is_edge_function: - # Store transition info for coordinated execution - transition_info = { - "next_node": next_node, - "transition_to": transition_to, - "transition_callback": transition_callback, - "function_name": name, - "arguments": params.arguments, - "result": result, - } - self._pending_transition = transition_info - - properties = FunctionCallResultProperties( - run_llm=False, # Don't run LLM until transition completes - on_context_updated=self._check_and_execute_transition, - ) - else: - # Node function - run LLM immediately - properties = FunctionCallResultProperties( - run_llm=True, - on_context_updated=None, - ) - await params.result_callback(result, properties=properties) except Exception as e: logger.error(f"Error in transition function {name}: {str(e)}") + self._pending_function_calls = 0 error_result = {"status": "error", "error": str(e)} await params.result_callback(error_result) return transition_func - async def _check_and_execute_transition(self) -> None: - """Check if all functions are complete and execute transition if so.""" - if not self._pending_transition: - return - - # Check if all function calls are complete using Pipecat's state - assistant_aggregator = self._context_aggregator.assistant() - if not assistant_aggregator.has_function_calls_in_progress: - # All functions complete, execute transition - transition_info = self._pending_transition - self._pending_transition = None - - await self._execute_transition(transition_info) - - async def _execute_transition(self, transition_info: Dict[str, Any]) -> None: - """Execute the stored transition.""" - next_node = transition_info.get("next_node") - transition_to = transition_info.get("transition_to") - transition_callback = transition_info.get("transition_callback") - function_name = transition_info.get("function_name") - arguments = transition_info.get("arguments") - result = transition_info.get("result") - - try: - if next_node: # Function-returned next node (consolidated function) - if isinstance(next_node, str): # Static flow - node_name = next_node - node = self.nodes[next_node] - else: # Dynamic flow - node_name = get_or_generate_node_name(next_node) - node = next_node - logger.debug(f"Transition to function-returned node: {node_name}") - await self._set_node(node_name, node) - elif transition_to: # Static flow (deprecated) - logger.debug(f"Static transition to: {transition_to}") - await self._set_node(transition_to, self.nodes[transition_to]) - elif transition_callback: # Dynamic flow (deprecated) - logger.debug(f"Dynamic transition for: {function_name}") - # Check callback signature - sig = inspect.signature(transition_callback) - if len(sig.parameters) == 2: - # Old style: (args, flow_manager) - await transition_callback(arguments, self) - else: - # New style: (args, result, flow_manager) - await transition_callback(arguments, result, self) - except Exception as e: - logger.error(f"Error executing transition: {str(e)}") - raise - def _lookup_function(self, func_name: str) -> Callable: """Look up a function by name in the main module. @@ -514,22 +453,24 @@ async def _register_function( self, name: str, new_functions: Set[str], - handler: Optional[Callable | FlowsDirectFunctionWrapper], + handler: Optional[Callable], transition_to: Optional[str] = None, transition_callback: Optional[Callable] = None, + is_mcp: bool = False, ) -> None: """Register a function with the LLM if not already registered. Args: - name: Name of the function to register - handler: A callable function handler, a FlowsDirectFunction, or a string. - If string starts with '__function__:', extracts the function name after the prefix. - transition_to: Optional node to transition to (static flows) - transition_callback: Optional transition callback (dynamic flows) + name: Name of the function to register with the LLM new_functions: Set to track newly registered functions for this node + handler: Either a callable function or a string. If string starts with + '__function__:', extracts the function name after the prefix + transition_to: Optional node name to transition to after function execution + transition_callback: Optional callback for dynamic transitions + is_mcp: Is this function registered via an MCP server. Raises: - FlowError: If function registration fails + FlowError: If function registration fails or handler lookup fails """ if name not in self.current_functions: try: @@ -540,7 +481,7 @@ async def _register_function( # Create transition function transition_func = await self._create_transition_func( - name, handler, transition_to, transition_callback + name, handler, transition_to, transition_callback, is_mcp ) # Register function with LLM @@ -555,54 +496,9 @@ async def _register_function( logger.error(f"Failed to register function {name}: {str(e)}") raise FlowError(f"Function registration failed: {str(e)}") from e - async def set_node_from_config(self, node_config: NodeConfig) -> None: - """Set up a new conversation node and transition to it. - - Used to manually transition between nodes in a dynamic flow. - - Args: - node_config: Configuration for the new node. - - Raises: - FlowTransitionError: If manager not initialized. - FlowError: If node setup fails. - """ - await self._set_node(get_or_generate_node_name(node_config), node_config) - async def set_node(self, node_id: str, node_config: NodeConfig) -> None: """Set up a new conversation node and transition to it. - .. deprecated:: 0.0.18 - This method is deprecated and will be removed in 1.0.0. - Use set_node_from_config() instead, or prefer "consolidated" functions - that return a tuple (result, next_node). - - Args: - node_id: Identifier for the new node. - node_config: Configuration for the new node. - - Raises: - FlowTransitionError: If manager not initialized. - FlowError: If node setup fails. - """ - if not self._showed_deprecation_warning_for_set_node: - self._showed_deprecation_warning_for_set_node = True - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - """`set_node()` is deprecated and will be removed in 1.0.0. Instead, do the following for dynamic flows: -- Prefer "consolidated" or "direct" functions that return a tuple (result, next_node) over deprecated `transition_callback`s -- Pass your initial node to `FlowManager.initialize()` -- If you really need to set a node explicitly, use `set_node_from_config()` -In all of these cases, you can provide a `name` in your new node's config for debug logging purposes.""", - DeprecationWarning, - stacklevel=2, - ) - await self._set_node(node_id, node_config) - - async def _set_node(self, node_id: str, node_config: NodeConfig) -> None: - """Set up a new conversation node and transition to it. - Handles the complete node transition process in the following order: 1. Execute pre-actions (if any) 2. Set up messages (role and task) @@ -613,23 +509,17 @@ async def _set_node(self, node_id: str, node_config: NodeConfig) -> None: 7. Execute post-actions (if any) Args: - node_id: Identifier for the new node. - node_config: Complete configuration for the node. + node_id: Identifier for the new node + node_config: Complete configuration for the node Raises: - FlowTransitionError: If manager not initialized. - FlowError: If node setup fails. + FlowTransitionError: If manager not initialized + FlowError: If node setup fails """ if not self.initialized: raise FlowTransitionError(f"{self.__class__.__name__} must be initialized first") try: - # Clear any pending transition state when starting a new node - # This ensures clean state regardless of how we arrived here: - # - Normal transition flow (already cleared in _check_and_execute_transition) - # - Direct calls to set_node/set_node_from_config - self._pending_transition = None - self._validate_node_config(node_id, node_config) logger.debug(f"Setting node: {node_id}") @@ -655,13 +545,10 @@ async def _set_node(self, node_id: str, node_config: NodeConfig) -> None: messages.extend(node_config["task_messages"]) # Register functions and prepare tools - tools: List[FlowsFunctionSchema | FlowsDirectFunctionWrapper] = [] + tools = [] new_functions: Set[str] = set() - # Get functions list with default empty list if not provided - functions_list = node_config.get("functions", []) - - async def register_function_schema(schema: FlowsFunctionSchema): + async def register_function_schema(schema): """Helper to register a single FlowsFunctionSchema.""" tools.append(schema) await self._register_function( @@ -670,26 +557,13 @@ async def register_function_schema(schema: FlowsFunctionSchema): handler=schema.handler, transition_to=schema.transition_to, transition_callback=schema.transition_callback, - ) + is_mcp=is_mcp, - async def register_direct_function(func): - """Helper to register a single direct function.""" - direct_function = FlowsDirectFunctionWrapper(function=func) - tools.append(direct_function) - await self._register_function( - name=direct_function.name, - new_functions=new_functions, - handler=direct_function, - transition_to=None, - transition_callback=None, ) - for func_config in functions_list: - # Handle direct functions - if callable(func_config): - await register_direct_function(func_config) + for func_config in node_config["functions"]: # Handle Gemini's nested function declarations as a special case - elif ( + if ( not isinstance(func_config, FlowsFunctionSchema) and "function_declarations" in func_config ): @@ -698,14 +572,27 @@ async def register_direct_function(func): schema = self.adapter.convert_to_function_schema( {"function_declarations": [declaration]} ) - await register_function_schema(schema) - # Convert to FlowsFunctionSchema if needed and process it + is_mcp = declaration.get("is_mcp", False) + + await register_function_schema(schema, is_mcp) else: + # Convert to FlowsFunctionSchema if needed and process it schema = ( func_config if isinstance(func_config, FlowsFunctionSchema) else self.adapter.convert_to_function_schema(func_config) ) + + # Extract is_mcp flag from config + is_mcp = False + if isinstance(func_config, dict): + # OpenAI format: check in "function" object + if "function" in func_config: + is_mcp = func_config["function"].get("is_mcp", False) + # Other formats: check directly + else: + is_mcp = func_config.get("is_mcp", False) + await register_function_schema(schema) # Create ToolsSchema with standard function schemas @@ -716,7 +603,7 @@ async def register_direct_function(func): # Use provider adapter to format tools, passing original configs for Gemini adapter formatted_tools = self.adapter.format_functions( - standard_functions, original_configs=functions_list + standard_functions, original_configs=node_config["functions"] ) # Update LLM context @@ -740,6 +627,7 @@ async def register_direct_function(func): await self._execute_actions(post_actions=post_actions) else: # Schedule post-actions for execution after first LLM response in this node + print("[pk] Scheduling post-actions for execution after LLM response") self._schedule_deferred_post_actions(post_actions=post_actions) logger.debug(f"Successfully set node: {node_id}") @@ -766,12 +654,12 @@ async def _update_llm_context( """Update LLM context with new messages and functions. Args: - messages: New messages to add to context. - functions: New functions to make available. - strategy: Optional context update configuration. + messages: New messages to add to context + functions: New functions to make available + strategy: Optional context update configuration Raises: - FlowError: If context update fails. + FlowError: If context update fails """ try: update_config = strategy or self._context_strategy @@ -835,8 +723,8 @@ async def _execute_actions( """Execute pre and post actions. Args: - pre_actions: Actions to execute before context update. - post_actions: Actions to execute after context update. + pre_actions: Actions to execute before context update + post_actions: Actions to execute after context update """ if pre_actions: await self.action_manager.execute_actions(pre_actions) @@ -847,34 +735,27 @@ def _validate_node_config(self, node_id: str, config: NodeConfig) -> None: """Validate the configuration of a conversation node. This method ensures that: - 1. Required fields (task_messages) are present + 1. Required fields (task_messages, functions) are present 2. Functions have valid configurations based on their type: - FlowsFunctionSchema objects have proper handler/transition fields - Dictionary format functions have valid handler/transition entries - - Direct functions are valid according to the FlowsDirectFunctions validation 3. Edge functions (matching node names) are allowed without handlers/transitions Args: - node_id: Identifier for the node being validated. - config: Complete node configuration to validate. + node_id: Identifier for the node being validated + config: Complete node configuration to validate Raises: - ValueError: If configuration is invalid or missing required fields. + ValueError: If configuration is invalid or missing required fields """ # Check required fields if "task_messages" not in config: raise ValueError(f"Node '{node_id}' missing required 'task_messages' field") + if "functions" not in config: + raise ValueError(f"Node '{node_id}' missing required 'functions' field") - # Get functions list with default empty list if not provided - functions_list = config.get("functions", []) - - # Validate each function configuration if there are any - for func in functions_list: - # If the function is callable, validate using FlowsDirectFunction - if callable(func): - FlowsDirectFunctionWrapper.validate_function(func) - continue - + # Validate each function configuration + for func in config["functions"]: # Extract function name using adapter (handles all formats) try: name = self.adapter.get_function_name(func) @@ -891,6 +772,7 @@ def _validate_node_config(self, node_id: str, config: NodeConfig) -> None: has_handler = func.handler is not None has_transition_to = func.transition_to is not None has_transition_callback = func.transition_callback is not None + is_mcp = False else: # For dictionary formats, use the provider-specific format checks # OpenAI format @@ -898,40 +780,31 @@ def _validate_node_config(self, node_id: str, config: NodeConfig) -> None: has_handler = "handler" in func["function"] has_transition_to = "transition_to" in func["function"] has_transition_callback = "transition_callback" in func["function"] + is_mcp = func["function"].get("is_mcp", False) + # Anthropic format elif "name" in func and "input_schema" in func: has_handler = "handler" in func has_transition_to = "transition_to" in func has_transition_callback = "transition_callback" in func + is_mcp = func.get("is_mcp", False) + # Gemini format elif "function_declarations" in func and func["function_declarations"]: decl = func["function_declarations"][0] has_handler = "handler" in decl has_transition_to = "transition_to" in decl has_transition_callback = "transition_callback" in decl + is_mcp = decl.get("is_mcp", False) + else: # Unknown format, report error raise ValueError( f"Unknown function format for function '{name}' in node '{node_id}'" ) - # Warn if the function has no handler or transitions - if not has_handler and not has_transition_to and not has_transition_callback: + # Warn if the function has no handler or transitions (unless it is an MCP tool) + if not is_mcp and not has_handler and not has_transition_to and not has_transition_callback: logger.warning( f"Function '{name}' in node '{node_id}' has neither handler, transition_to, nor transition_callback" ) - - # Warn about usage of deprecated transition_to and transition_callback - if ( - has_transition_to - or has_transition_callback - and not self._showed_deprecation_warning_for_transition_fields - ): - self._showed_deprecation_warning_for_transition_fields = True - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - '`transition_to` and `transition_callback` are deprecated and will be removed in 1.0.0. Use a "consolidated" `handler` that returns a tuple (result, next_node) instead.', - DeprecationWarning, - stacklevel=2, - )