diff --git a/nemoguardrails/logging/callbacks.py b/nemoguardrails/logging/callbacks.py
index 175d629f6..ae5fac3ab 100644
--- a/nemoguardrails/logging/callbacks.py
+++ b/nemoguardrails/logging/callbacks.py
@@ -15,11 +15,15 @@
 import logging
 import uuid
 from time import time
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, cast
 from uuid import UUID
 
 from langchain.callbacks import StdOutCallbackHandler
-from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
+from langchain.callbacks.base import (
+    AsyncCallbackHandler,
+    BaseCallbackHandler,
+    BaseCallbackManager,
+)
 from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
 from langchain.schema import AgentAction, AgentFinish, AIMessage, BaseMessage, LLMResult
 from langchain_core.outputs import ChatGeneration
@@ -33,7 +37,7 @@
 log = logging.getLogger(__name__)
 
 
-class LoggingCallbackHandler(AsyncCallbackHandler, StdOutCallbackHandler):
+class LoggingCallbackHandler(AsyncCallbackHandler):
     """Async callback handler that can be used to handle callbacks from langchain."""
 
     async def on_llm_start(
@@ -203,10 +207,17 @@ async def on_llm_end(
         )
         log.info("Output Stats :: %s", response.llm_output)
 
-        took = llm_call_info.finished_at - llm_call_info.started_at
-        log.info("--- :: LLM call took %.2f seconds", took)
-        llm_stats.inc("total_time", took)
-        llm_call_info.duration = took
+        if (
+            llm_call_info.finished_at is not None
+            and llm_call_info.started_at is not None
+        ):
+            took = llm_call_info.finished_at - llm_call_info.started_at
+            log.info("--- :: LLM call took %.2f seconds", took)
+            llm_stats.inc("total_time", took)
+            llm_call_info.duration = took
+        else:
+            log.warning("LLM call timing information incomplete")
+            llm_call_info.duration = 0.0
 
         # Update the token usage stats as well
         token_stats_found = False
@@ -278,7 +289,7 @@ async def on_llm_end(
 
     async def on_llm_error(
         self,
-        error: Union[Exception, KeyboardInterrupt],
+        error: BaseException,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -309,7 +320,7 @@ async def on_chain_end(
 
     async def on_chain_error(
         self,
-        error: Union[Exception, KeyboardInterrupt],
+        error: BaseException,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -340,7 +351,7 @@ async def on_tool_end(
 
     async def on_tool_error(
         self,
-        error: Union[Exception, KeyboardInterrupt],
+        error: BaseException,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -381,14 +392,15 @@ async def on_agent_finish(
 
 handlers = [LoggingCallbackHandler()]
 logging_callbacks = BaseCallbackManager(
-    handlers=handlers, inheritable_handlers=handlers
+    handlers=cast(List[BaseCallbackHandler], handlers),
+    inheritable_handlers=cast(List[BaseCallbackHandler], handlers),
 )
 
 logging_callback_manager_for_chain = AsyncCallbackManagerForChainRun(
     run_id=uuid.uuid4(),
     parent_run_id=None,
-    handlers=handlers,
-    inheritable_handlers=handlers,
+    handlers=cast(List[BaseCallbackHandler], handlers),
+    inheritable_handlers=cast(List[BaseCallbackHandler], handlers),
     tags=[],
     inheritable_tags=[],
 )
diff --git a/nemoguardrails/logging/processing_log.py b/nemoguardrails/logging/processing_log.py
index decc50181..4841b1c97 100644
--- a/nemoguardrails/logging/processing_log.py
+++ b/nemoguardrails/logging/processing_log.py
@@ -153,25 +153,36 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
                     action_params=event_data["action_params"],
                     started_at=event["timestamp"],
                 )
-                activated_rail.executed_actions.append(executed_action)
+                if activated_rail is not None:
+                    activated_rail.executed_actions.append(executed_action)
 
             elif event_type == "InternalSystemActionFinished":
                 action_name = event_data["action_name"]
                 if action_name in ignored_actions:
                     continue
 
-                executed_action.finished_at = event["timestamp"]
-                executed_action.duration = (
-                    executed_action.finished_at - executed_action.started_at
-                )
-                executed_action.return_value = event_data["return_value"]
+                if executed_action is not None:
+                    executed_action.finished_at = event["timestamp"]
+                    if (
+                        executed_action.finished_at is not None
+                        and executed_action.started_at is not None
+                    ):
+                        executed_action.duration = (
+                            executed_action.finished_at - executed_action.started_at
+                        )
+                    executed_action.return_value = event_data["return_value"]
                 executed_action = None
 
             elif event_type in ["InputRailFinished", "OutputRailFinished"]:
-                activated_rail.finished_at = event["timestamp"]
-                activated_rail.duration = (
-                    activated_rail.finished_at - activated_rail.started_at
-                )
+                if activated_rail is not None:
+                    activated_rail.finished_at = event["timestamp"]
+                    if (
+                        activated_rail.finished_at is not None
+                        and activated_rail.started_at is not None
+                    ):
+                        activated_rail.duration = (
+                            activated_rail.finished_at - activated_rail.started_at
+                        )
                 activated_rail = None
 
             elif event_type == "InputRailsFinished":
@@ -181,14 +192,21 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
                 output_rails_finished_at = event["timestamp"]
 
         elif event["type"] == "llm_call_info":
-            executed_action.llm_calls.append(event["data"])
+            if executed_action is not None:
+                executed_action.llm_calls.append(event["data"])
 
     # If at the end of the processing we still have an active rail, it is because
    # we have hit a stop. In this case, we take the last timestamp as the timestamp for
    # finishing the rail.
     if activated_rail is not None:
         activated_rail.finished_at = last_timestamp
-        activated_rail.duration = activated_rail.finished_at - activated_rail.started_at
+        if (
+            activated_rail.finished_at is not None
+            and activated_rail.started_at is not None
+        ):
+            activated_rail.duration = (
+                activated_rail.finished_at - activated_rail.started_at
+            )
 
         if activated_rail.type in ["input", "output"]:
             activated_rail.stop = True
@@ -213,9 +231,13 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
         if activated_rail.type in ["dialog", "generation"]:
             next_rail = generation_log.activated_rails[i + 1]
             activated_rail.finished_at = next_rail.started_at
-            activated_rail.duration = (
-                activated_rail.finished_at - activated_rail.started_at
-            )
+            if (
+                activated_rail.finished_at is not None
+                and activated_rail.started_at is not None
+            ):
+                activated_rail.duration = (
+                    activated_rail.finished_at - activated_rail.started_at
+                )
 
     # If we have output rails, we also record the general stats
     if output_rails_started_at:
@@ -257,17 +279,21 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
 
     for executed_action in activated_rail.executed_actions:
         for llm_call in executed_action.llm_calls:
-            generation_log.stats.llm_calls_count += 1
-            generation_log.stats.llm_calls_duration += llm_call.duration
-            generation_log.stats.llm_calls_total_prompt_tokens += (
-                llm_call.prompt_tokens or 0
-            )
-            generation_log.stats.llm_calls_total_completion_tokens += (
-                llm_call.completion_tokens or 0
-            )
-            generation_log.stats.llm_calls_total_tokens += (
-                llm_call.total_tokens or 0
-            )
+            generation_log.stats.llm_calls_count = (
+                generation_log.stats.llm_calls_count or 0
+            ) + 1
+            generation_log.stats.llm_calls_duration = (
+                generation_log.stats.llm_calls_duration or 0
+            ) + (llm_call.duration or 0)
+            generation_log.stats.llm_calls_total_prompt_tokens = (
+                generation_log.stats.llm_calls_total_prompt_tokens or 0
+            ) + (llm_call.prompt_tokens or 0)
+            generation_log.stats.llm_calls_total_completion_tokens = (
+                generation_log.stats.llm_calls_total_completion_tokens or 0
+            ) + (llm_call.completion_tokens or 0)
+            generation_log.stats.llm_calls_total_tokens = (
+                generation_log.stats.llm_calls_total_tokens or 0
+            ) + (llm_call.total_tokens or 0)
 
     generation_log.stats.total_duration = (
         processing_log[-1]["timestamp"] - processing_log[0]["timestamp"]
diff --git a/pyproject.toml b/pyproject.toml
index 343c2deaa..6f841582d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -157,6 +157,7 @@ pyright = "^1.1.405"
 include = [
     "nemoguardrails/rails/**",
     "nemoguardrails/actions/**",
+    "nemoguardrails/logging/**",
     "nemoguardrails/tracing/**",
     "tests/test_callbacks.py",
 ]