diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py
index a89b0f8af..c6f8439c5 100644
--- a/nemoguardrails/actions/llm/utils.py
+++ b/nemoguardrails/actions/llm/utils.py
@@ -13,9 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import re
 from typing import Any, Dict, List, Optional, Sequence, Union
 
+logger = logging.getLogger(__name__)
+
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
 from langchain_core.runnables import RunnableConfig
@@ -238,15 +241,78 @@ def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
 
 
 def _store_reasoning_traces(response) -> None:
+    """Store reasoning traces from the response in a context variable.
+
+    Extracts reasoning content from response.additional_kwargs["reasoning_content"]
+    if available. Otherwise, falls back to extracting from <think> tags in the
+    response content (and removes the tags from the content).
+
+    Args:
+        response: The LLM response object
+    """
+
+    reasoning_content = _extract_reasoning_content(response)
+
+    if not reasoning_content:
+        # Some LLM providers (e.g., certain NVIDIA models) embed reasoning in <think> tags
+        # instead of properly populating reasoning_content in additional_kwargs, so we need
+        # both extraction methods to support different provider implementations.
+        reasoning_content = _extract_and_remove_think_tags(response)
+
+    if reasoning_content:
+        reasoning_trace_var.set(reasoning_content)
+
+
+def _extract_reasoning_content(response):
     if hasattr(response, "additional_kwargs"):
         additional_kwargs = response.additional_kwargs
         if (
             isinstance(additional_kwargs, dict)
             and "reasoning_content" in additional_kwargs
         ):
-            reasoning_content = additional_kwargs["reasoning_content"]
-            if reasoning_content:
-                reasoning_trace_var.set(reasoning_content)
+            return additional_kwargs["reasoning_content"]
+    return None
+
+
+def _extract_and_remove_think_tags(response) -> Optional[str]:
+    """Extract reasoning from <think> tags and remove them from `response.content`.
+
+    This function looks for <think>...</think> tags in the response content,
+    and if found, extracts the reasoning content inside the tags. It has a side effect:
+    it removes the full reasoning trace and <think> tags from response.content.
+
+    Args:
+        response: The LLM response object
+
+    Returns:
+        The extracted reasoning content, or None if no <think> tags found
+    """
+    if not hasattr(response, "content"):
+        return None
+
+    content = response.content
+    has_opening_tag = "<think>" in content
+    has_closing_tag = "</think>" in content
+
+    if not has_opening_tag and not has_closing_tag:
+        return None
+
+    if has_opening_tag != has_closing_tag:
+        logger.warning(
+            "Malformed <think> tags detected: missing %s tag. "
" + "Skipping reasoning extraction to prevent corrupted content.", + "closing" if has_opening_tag else "opening", + ) + return None + + match = re.search(r"(.*?)", content, re.DOTALL) + if match: + reasoning_content = match.group(1).strip() + response.content = re.sub( + r".*?", "", content, flags=re.DOTALL + ).strip() + return reasoning_content + return None def _store_tool_calls(response) -> None: diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py index cbd48ec62..a4e497762 100644 --- a/nemoguardrails/embeddings/basic.py +++ b/nemoguardrails/embeddings/basic.py @@ -15,9 +15,9 @@ import asyncio import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, cast -from annoy import AnnoyIndex +from annoy import AnnoyIndex # type: ignore from nemoguardrails.embeddings.cache import cache_embeddings from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem @@ -45,26 +45,14 @@ class BasicEmbeddingsIndex(EmbeddingsIndex): max_batch_hold: The maximum time a batch is held before being processed """ - embedding_model: str - embedding_engine: str - embedding_params: Dict[str, Any] - index: AnnoyIndex - embedding_size: int - cache_config: EmbeddingsCacheConfig - embeddings: List[List[float]] - search_threshold: float - use_batching: bool - max_batch_size: int - max_batch_hold: float - def __init__( self, - embedding_model=None, - embedding_engine=None, - embedding_params=None, - index=None, - cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None, - search_threshold: float = None, + embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2", + embedding_engine: str = "SentenceTransformers", + embedding_params: Optional[Dict[str, Any]] = None, + index: Optional[AnnoyIndex] = None, + cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None, + search_threshold: float = float("inf"), use_batching: bool = False, max_batch_size: int = 10, max_batch_hold: float = 0.01, @@ -72,22 +60,23 @@ def __init__( """Initialize the BasicEmbeddingsIndex. Args: - embedding_model (str, optional): The model for computing embeddings. Defaults to None. - embedding_engine (str, optional): The engine for computing embeddings. Defaults to None. - index (AnnoyIndex, optional): The pre-existing index. Defaults to None. - cache_config (EmbeddingsCacheConfig | Dict[str, Any], optional): The cache configuration. Defaults to None. + embedding_model: The model for computing embeddings. + embedding_engine: The engine for computing embeddings. + index: The pre-existing index. + cache_config: The cache configuration. + search_threshold: The threshold for filtering search results. use_batching: Whether to batch requests when computing the embeddings. max_batch_size: The maximum size of a batch. 
max_batch_hold: The maximum time a batch is held before being processed """ self._model: Optional[EmbeddingModel] = None - self._items = [] - self._embeddings = [] + self._items: List[IndexItem] = [] + self._embeddings: List[List[float]] = [] self.embedding_model = embedding_model self.embedding_engine = embedding_engine self.embedding_params = embedding_params or {} self._embedding_size = 0 - self.search_threshold = search_threshold or float("inf") + self.search_threshold = search_threshold if isinstance(cache_config, Dict): self._cache_config = EmbeddingsCacheConfig(**cache_config) else: @@ -95,12 +84,12 @@ def __init__( self._index = index # Data structures for batching embedding requests - self._req_queue = {} - self._req_results = {} - self._req_idx = 0 - self._current_batch_finished_event = None - self._current_batch_full_event = None - self._current_batch_submitted = asyncio.Event() + self._req_queue: Dict[int, str] = {} + self._req_results: Dict[int, List[float]] = {} + self._req_idx: int = 0 + self._current_batch_finished_event: Optional[asyncio.Event] = None + self._current_batch_full_event: Optional[asyncio.Event] = None + self._current_batch_submitted: asyncio.Event = asyncio.Event() # Initialize the batching configuration self.use_batching = use_batching @@ -112,6 +101,11 @@ def embeddings_index(self): """Get the current embedding index""" return self._index + @embeddings_index.setter + def embeddings_index(self, index): + """Setter to allow replacing the index dynamically.""" + self._index = index + @property def cache_config(self): """Get the cache configuration.""" @@ -127,16 +121,14 @@ def embeddings(self): """Get the computed embeddings.""" return self._embeddings - @embeddings_index.setter - def embeddings_index(self, index): - """Setter to allow replacing the index dynamically.""" - self._index = index - def _init_model(self): """Initialize the model used for computing the embeddings.""" + model = self.embedding_model + engine = self.embedding_engine + self._model = init_embedding_model( - embedding_model=self.embedding_model, - embedding_engine=self.embedding_engine, + embedding_model=model, + embedding_engine=engine, embedding_params=self.embedding_params, ) @@ -153,7 +145,9 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]: if self._model is None: self._init_model() - embeddings = await self._model.encode_async(texts) + # self._model can't be None here, or self._init_model() would throw a ValueError + model: EmbeddingModel = cast(EmbeddingModel, self._model) + embeddings = await model.encode_async(texts) return embeddings async def add_item(self, item: IndexItem): @@ -199,6 +193,12 @@ async def _run_batch(self): """Runs the current batch of embeddings.""" # Wait up to `max_batch_hold` time or until `max_batch_size` is reached. + if ( + self._current_batch_full_event is None + or self._current_batch_finished_event is None + ): + raise RuntimeError("Batch events not initialized. 
This should not happen.") + done, pending = await asyncio.wait( [ asyncio.create_task(asyncio.sleep(self.max_batch_hold)), @@ -244,7 +244,10 @@ async def _batch_get_embeddings(self, text: str) -> List[float]: self._req_idx += 1 self._req_queue[req_id] = text - if self._current_batch_finished_event is None: + if ( + self._current_batch_finished_event is None + or self._current_batch_full_event is None + ): self._current_batch_finished_event = asyncio.Event() self._current_batch_full_event = asyncio.Event() self._current_batch_submitted.clear() diff --git a/nemoguardrails/embeddings/cache.py b/nemoguardrails/embeddings/cache.py index 9abeb1de2..cdef48c27 100644 --- a/nemoguardrails/embeddings/cache.py +++ b/nemoguardrails/embeddings/cache.py @@ -20,7 +20,12 @@ from abc import ABC, abstractmethod from functools import singledispatchmethod from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional + +try: + import redis # type: ignore +except ImportError: + redis = None # type: ignore from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig @@ -30,6 +35,8 @@ class KeyGenerator(ABC): """Abstract class for key generators.""" + name: str # Class attribute that should be defined in subclasses + @abstractmethod def generate_key(self, text: str) -> str: pass @@ -76,6 +83,8 @@ def generate_key(self, text: str) -> str: class CacheStore(ABC): """Abstract class for cache stores.""" + name: str + @abstractmethod def get(self, key): """Get a value from the cache.""" @@ -147,7 +156,7 @@ class FilesystemCacheStore(CacheStore): name = "filesystem" - def __init__(self, cache_dir: str = None): + def __init__(self, cache_dir: Optional[str] = None): self._cache_dir = Path(cache_dir or ".cache/embeddings") self._cache_dir.mkdir(parents=True, exist_ok=True) @@ -190,8 +199,10 @@ class RedisCacheStore(CacheStore): name = "redis" def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0): - import redis - + if redis is None: + raise ImportError( + "Could not import redis, please install it with `pip install redis`." + ) self._redis = redis.Redis(host=host, port=port, db=db) def get(self, key): @@ -207,9 +218,9 @@ def clear(self): class EmbeddingsCache: def __init__( self, - key_generator: KeyGenerator = None, - cache_store: CacheStore = None, - store_config: dict = None, + key_generator: KeyGenerator, + cache_store: CacheStore, + store_config: Optional[dict] = None, ): self._key_generator = key_generator self._cache_store = cache_store @@ -218,7 +229,10 @@ def __init__( @classmethod def from_dict(cls, d: Dict[str, str]): key_generator = KeyGenerator.from_name(d.get("key_generator"))() - store_config = d.get("store_config") + store_config_raw = d.get("store_config") + store_config: dict = ( + store_config_raw if isinstance(store_config_raw, dict) else {} + ) cache_store = CacheStore.from_name(d.get("store"))(**store_config) return cls(key_generator=key_generator, cache_store=cache_store) @@ -239,7 +253,7 @@ def get_config(self): def get(self, texts): raise NotImplementedError - @get.register + @get.register(str) def _(self, text: str): key = self._key_generator.generate_key(text) log.info(f"Fetching key {key} for text '{text[:20]}...' 
from cache") @@ -248,7 +262,7 @@ def _(self, text: str): return result - @get.register + @get.register(list) def _(self, texts: list): cached = {} @@ -266,13 +280,13 @@ def _(self, texts: list): def set(self, texts): raise NotImplementedError - @set.register + @set.register(str) def _(self, text: str, value: List[float]): key = self._key_generator.generate_key(text) log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.") self._cache_store.set(key, value) - @set.register + @set.register(list) def _(self, texts: list, values: List[List[float]]): for text, value in zip(texts, values): self.set(text, value) diff --git a/nemoguardrails/embeddings/providers/azureopenai.py b/nemoguardrails/embeddings/providers/azureopenai.py index 5c5906d5d..e77ab481a 100644 --- a/nemoguardrails/embeddings/providers/azureopenai.py +++ b/nemoguardrails/embeddings/providers/azureopenai.py @@ -46,17 +46,16 @@ class AzureEmbeddingModel(EmbeddingModel): def __init__(self, embedding_model: str): try: - from openai import AzureOpenAI + from openai import AzureOpenAI # type: ignore except ImportError: raise ImportError( - "Could not import openai, please install it with " - "`pip install openai`." + "Could not import openai, please install it with `pip install openai`." ) # Set Azure OpenAI API credentials self.client = AzureOpenAI( api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_version=os.getenv("AZURE_OPENAI_API_VERSION"), - azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), # type: ignore ) self.embedding_model = embedding_model diff --git a/nemoguardrails/embeddings/providers/cohere.py b/nemoguardrails/embeddings/providers/cohere.py index 34cee4156..704e0bcd7 100644 --- a/nemoguardrails/embeddings/providers/cohere.py +++ b/nemoguardrails/embeddings/providers/cohere.py @@ -14,7 +14,7 @@ # limitations under the License. import asyncio from contextvars import ContextVar -from typing import List +from typing import TYPE_CHECKING, List from .base import EmbeddingModel @@ -23,6 +23,10 @@ # is changed, it will fail. 
async_client_var: ContextVar = ContextVar("async_client", default=None) +if TYPE_CHECKING: + import cohere + from cohere import AsyncClient, Client + class CohereEmbeddingModel(EmbeddingModel): """ @@ -64,7 +68,7 @@ def __init__( self.model = embedding_model self.input_type = input_type - self.client = cohere.Client(**kwargs) + self.client = cohere.Client(**kwargs) # type: ignore[reportCallIssue] self.embedding_size_dict = { "embed-v4.0": 1536, @@ -120,6 +124,9 @@ def encode(self, documents: List[str]) -> List[List[float]]: """ # Make embedding request to Cohere API - return self.client.embed( + # Since we don't pass embedding_types parameter, the response should be + # EmbeddingsFloatsEmbedResponse with embeddings as List[List[float]] + response = self.client.embed( texts=documents, model=self.model, input_type=self.input_type - ).embeddings + ) + return response.embeddings # type: ignore[return-value] diff --git a/nemoguardrails/embeddings/providers/fastembed.py b/nemoguardrails/embeddings/providers/fastembed.py index 1062e566f..1359f7ab5 100644 --- a/nemoguardrails/embeddings/providers/fastembed.py +++ b/nemoguardrails/embeddings/providers/fastembed.py @@ -42,7 +42,7 @@ class FastEmbedEmbeddingModel(EmbeddingModel): engine_name = "FastEmbed" def __init__(self, embedding_model: str, **kwargs): - from fastembed import TextEmbedding as Embedding + from fastembed import TextEmbedding as Embedding # type: ignore # Enabling a short form model name for all-MiniLM-L6-v2. if embedding_model == "all-MiniLM-L6-v2": diff --git a/nemoguardrails/embeddings/providers/google.py b/nemoguardrails/embeddings/providers/google.py index cf55399af..1f78974e6 100644 --- a/nemoguardrails/embeddings/providers/google.py +++ b/nemoguardrails/embeddings/providers/google.py @@ -46,7 +46,7 @@ class GoogleEmbeddingModel(EmbeddingModel): def __init__(self, embedding_model: str, **kwargs): try: - from google import genai + from google import genai # type: ignore[import] except ImportError: raise ImportError( diff --git a/nemoguardrails/embeddings/providers/nim.py b/nemoguardrails/embeddings/providers/nim.py index dd5690a4d..8ea9c1d0f 100644 --- a/nemoguardrails/embeddings/providers/nim.py +++ b/nemoguardrails/embeddings/providers/nim.py @@ -35,7 +35,7 @@ class NIMEmbeddingModel(EmbeddingModel): def __init__(self, embedding_model: str, **kwargs): try: - from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings + from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings # type: ignore self.model = embedding_model self.document_embedder = NVIDIAEmbeddings(model=embedding_model, **kwargs) diff --git a/nemoguardrails/embeddings/providers/openai.py b/nemoguardrails/embeddings/providers/openai.py index 83f83f8c2..bd12f2333 100644 --- a/nemoguardrails/embeddings/providers/openai.py +++ b/nemoguardrails/embeddings/providers/openai.py @@ -46,14 +46,14 @@ def __init__( **kwargs, ): try: - import openai - from openai import AsyncOpenAI, OpenAI + import openai # type: ignore + from openai import AsyncOpenAI, OpenAI # type: ignore except ImportError: raise ImportError( "Could not import openai, please install it with " "`pip install openai`." ) - if openai.__version__ < "1.0.0": + if openai.__version__ < "1.0.0": # type: ignore raise RuntimeError( "`openai<1.0.0` is no longer supported. " "Please upgrade using `pip install openai>=1.0.0`." 
diff --git a/nemoguardrails/embeddings/providers/sentence_transformers.py b/nemoguardrails/embeddings/providers/sentence_transformers.py index 7ffcec712..cc7ce7be8 100644 --- a/nemoguardrails/embeddings/providers/sentence_transformers.py +++ b/nemoguardrails/embeddings/providers/sentence_transformers.py @@ -43,7 +43,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel): def __init__(self, embedding_model: str, **kwargs): try: - from sentence_transformers import SentenceTransformer + from sentence_transformers import SentenceTransformer # type: ignore except ImportError: raise ImportError( "Could not import sentence-transformers, please install it with " @@ -51,7 +51,7 @@ def __init__(self, embedding_model: str, **kwargs): ) try: - from torch import cuda + from torch import cuda # type: ignore except ImportError: raise ImportError( "Could not import torch, please install it with `pip install torch`." diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py index c3c43c3e2..6769dec1e 100644 --- a/nemoguardrails/server/api.py +++ b/nemoguardrails/server/api.py @@ -22,7 +22,7 @@ import time import warnings from contextlib import asynccontextmanager -from typing import Any, List, Optional +from typing import Any, Callable, List, Optional from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware @@ -42,14 +42,32 @@ logging.basicConfig(level=logging.INFO) log = logging.getLogger(__name__) + +class GuardrailsApp(FastAPI): + """Custom FastAPI subclass with additional attributes for Guardrails server.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Initialize custom attributes + self.default_config_id: Optional[str] = None + self.rails_config_path: str = "" + self.disable_chat_ui: bool = False + self.auto_reload: bool = False + self.stop_signal: bool = False + self.single_config_mode: bool = False + self.single_config_id: Optional[str] = None + self.loop: Optional[asyncio.AbstractEventLoop] = None + self.task: Optional[asyncio.Future] = None + + # The list of registered loggers. Can be used to send logs to various # backends and storage engines. -registered_loggers = [] +registered_loggers: List[Callable] = [] api_description = """Guardrails Sever API.""" # The headers for each request -api_request_headers = contextvars.ContextVar("headers") +api_request_headers: contextvars.ContextVar = contextvars.ContextVar("headers") # The datastore that the Server should use. # This is currently used only for storing threads. @@ -59,7 +77,7 @@ @asynccontextmanager -async def lifespan(app: FastAPI): +async def lifespan(app: GuardrailsApp): # Startup logic here """Register any additional challenges, if available at startup.""" challenges_files = os.path.join(app.rails_config_path, "challenges.json") @@ -82,8 +100,11 @@ async def lifespan(app: FastAPI): if os.path.exists(filepath): filename = os.path.basename(filepath) spec = importlib.util.spec_from_file_location(filename, filepath) - config_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(config_module) + if spec is not None and spec.loader is not None: + config_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config_module) + else: + config_module = None # If there is an `init` function, we call it with the reference to the app. 
if config_module is not None and hasattr(config_module, "init"): @@ -110,6 +131,7 @@ async def root_handler(): if app.auto_reload: app.loop = asyncio.get_running_loop() + # Store the future directly as task app.task = app.loop.run_in_executor(None, start_auto_reload_monitoring) yield @@ -117,14 +139,14 @@ async def root_handler(): # Shutdown logic here if app.auto_reload: app.stop_signal = True - if hasattr(app, "task"): + if hasattr(app, "task") and app.task is not None: app.task.cancel() log.info("Shutting down file observer") else: pass -app = FastAPI( +app = GuardrailsApp( title="Guardrails Server API", description=api_description, version="0.1.0", @@ -186,7 +208,7 @@ class RequestBody(BaseModel): max_length=255, description="The id of an existing thread to which the messages should be added.", ) - messages: List[dict] = Field( + messages: Optional[List[dict]] = Field( default=None, description="The list of messages in the current conversation." ) context: Optional[dict] = Field( @@ -232,7 +254,7 @@ def ensure_config_ids(cls, v, values): class ResponseBody(BaseModel): - messages: List[dict] = Field( + messages: Optional[List[dict]] = Field( default=None, description="The new messages in the conversation" ) llm_output: Optional[dict] = Field( @@ -282,8 +304,8 @@ async def get_rails_configs(): # One instance of LLMRails per config id -llm_rails_instances = {} -llm_rails_events_history_cache = {} +llm_rails_instances: dict[str, LLMRails] = {} +llm_rails_events_history_cache: dict[str, dict] = {} def _generate_cache_key(config_ids: List[str]) -> str: @@ -310,7 +332,7 @@ def _get_rails(config_ids: List[str]) -> LLMRails: # get the same thing. config_ids = [""] - full_llm_rails_config = None + full_llm_rails_config: Optional[RailsConfig] = None for config_id in config_ids: base_path = os.path.abspath(app.rails_config_path) @@ -330,6 +352,9 @@ def _get_rails(config_ids: List[str]) -> LLMRails: else: full_llm_rails_config += rails_config + if full_llm_rails_config is None: + raise ValueError("No valid rails configuration found.") + llm_rails = LLMRails(config=full_llm_rails_config, verbose=True) llm_rails_instances[configs_cache_key] = llm_rails @@ -360,30 +385,33 @@ async def chat_completion(body: RequestBody, request: Request): # Save the request headers in a context variable. api_request_headers.set(request.headers) + # Use Request config_ids if set, otherwise use the FastAPI default config. + # If neither is available we can't generate any completions as we have no config_id config_ids = body.config_ids - if not config_ids and app.default_config_id: - config_ids = [app.default_config_id] - elif not config_ids and not app.default_config_id: - raise GuardrailsConfigurationError( - "No 'config_id' provided and no default configuration is set for the server. " - "You must set a 'config_id' in your request or set use --default-config-id when . " - ) + if not config_ids: + if app.default_config_id: + config_ids = [app.default_config_id] + else: + raise GuardrailsConfigurationError( + "No request config_ids provided and server has no default configuration" + ) + try: llm_rails = _get_rails(config_ids) except ValueError as ex: log.exception(ex) - return { - "messages": [ + return ResponseBody( + messages=[ { "role": "assistant", "content": f"Could not load the {config_ids} guardrails configuration. 
" f"An internal error has occurred.", } ] - } + ) try: - messages = body.messages + messages = body.messages or [] if body.context: messages.insert(0, {"role": "context", "content": body.context}) @@ -396,14 +424,14 @@ async def chat_completion(body: RequestBody, request: Request): # We make sure the `thread_id` meets the minimum complexity requirement. if len(body.thread_id) < 16: - return { - "messages": [ + return ResponseBody( + messages=[ { "role": "assistant", "content": "The `thread_id` must have a minimum length of 16 characters.", } ] - } + ) # Fetch the existing thread messages. For easier management, we prepend # the string `thread-` to all thread keys. @@ -440,32 +468,37 @@ async def chat_completion(body: RequestBody, request: Request): ) if isinstance(res, GenerationResponse): - bot_message = res.response[0] + bot_message_content = res.response[0] + # Ensure bot_message is always a dict + if isinstance(bot_message_content, str): + bot_message = {"role": "assistant", "content": bot_message_content} + else: + bot_message = bot_message_content else: assert isinstance(res, dict) bot_message = res # If we're using threads, we also need to update the data before returning # the message. - if body.thread_id: + if body.thread_id and datastore is not None and datastore_key is not None: await datastore.set(datastore_key, json.dumps(messages + [bot_message])) - result = {"messages": [bot_message]} + result = ResponseBody(messages=[bot_message]) # If we have additional GenerationResponse fields, we return as well if isinstance(res, GenerationResponse): - result["llm_output"] = res.llm_output - result["output_data"] = res.output_data - result["log"] = res.log - result["state"] = res.state + result.llm_output = res.llm_output + result.output_data = res.output_data + result.log = res.log + result.state = res.state return result except Exception as ex: log.exception(ex) - return { - "messages": [{"role": "assistant", "content": "Internal server error."}] - } + return ResponseBody( + messages=[{"role": "assistant", "content": "Internal server error."}] + ) # By default, there are no challenges @@ -498,7 +531,7 @@ def register_datastore(datastore_instance: DataStore): datastore = datastore_instance -def register_logger(logger: callable): +def register_logger(logger: Callable): """Register an additional logger""" registered_loggers.append(logger) @@ -510,8 +543,7 @@ def start_auto_reload_monitoring(): from watchdog.observers import Observer class Handler(FileSystemEventHandler): - @staticmethod - def on_any_event(event): + def on_any_event(self, event): if event.is_directory: return None @@ -521,7 +553,8 @@ def on_any_event(event): ) # Compute the relative path - rel_path = os.path.relpath(event.src_path, app.rails_config_path) + src_path_str = str(event.src_path) + rel_path = os.path.relpath(src_path_str, app.rails_config_path) # The config_id is the first component parts = rel_path.split(os.path.sep) @@ -530,7 +563,7 @@ def on_any_event(event): if ( not parts[-1].startswith(".") and ".ipynb_checkpoints" not in parts - and os.path.isfile(event.src_path) + and os.path.isfile(src_path_str) ): # We just remove the config from the cache so that a new one is used next time if config_id in llm_rails_instances: diff --git a/nemoguardrails/server/datastore/redis_store.py b/nemoguardrails/server/datastore/redis_store.py index 4f6437f96..6e436dbff 100644 --- a/nemoguardrails/server/datastore/redis_store.py +++ b/nemoguardrails/server/datastore/redis_store.py @@ -16,7 +16,10 @@ import asyncio from typing 
import Optional
 
-import aioredis
+try:
+    import aioredis  # type: ignore[import]
+except ImportError:
+    aioredis = None  # type: ignore[assignment]
 
 from nemoguardrails.server.datastore.datastore import DataStore
 
@@ -35,6 +38,11 @@ def __init__(
         self,
             username: [Optional] The username to use for authentication.
             password: [Optional] The password to use for authentication
         """
+        if aioredis is None:
+            raise ImportError(
+                "aioredis is required for RedisStore. Install it with: pip install aioredis"
+            )
+
         self.url = url
         self.username = username
         self.password = password
diff --git a/pyproject.toml b/pyproject.toml
index 6be833997..2e79e544d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -157,10 +157,12 @@ pyright = "^1.1.405"
 
 include = [
     "nemoguardrails/rails/**",
    "nemoguardrails/actions/**",
+    "nemoguardrails/embeddings/**",
     "nemoguardrails/cli/**",
     "nemoguardrails/kb/**",
     "nemoguardrails/logging/**",
     "nemoguardrails/tracing/**",
+    "nemoguardrails/server/**",
     "tests/test_callbacks.py",
 ]
diff --git a/tests/conftest.py b/tests/conftest.py
index 1dc00134b..2e3f0c1d5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -22,5 +22,15 @@
 )
 
 
+@pytest.fixture(autouse=True)
+def reset_reasoning_trace_var():
+    """Reset reasoning_trace_var before each test to prevent state leakage."""
+    from nemoguardrails.context import reasoning_trace_var
+
+    reasoning_trace_var.set(None)
+    yield
+    reasoning_trace_var.set(None)
+
+
 def pytest_configure(config):
     patch("prompt_toolkit.PromptSession", autospec=True).start()
diff --git a/tests/test_actions_llm_utils.py b/tests/test_actions_llm_utils.py
index 9c238dda2..8f0accbd2 100644
--- a/tests/test_actions_llm_utils.py
+++ b/tests/test_actions_llm_utils.py
@@ -13,7 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from nemoguardrails.actions.llm.utils import _infer_provider_from_module
+from nemoguardrails.actions.llm.utils import (
+    _extract_and_remove_think_tags,
+    _infer_provider_from_module,
+    _store_reasoning_traces,
+)
+from nemoguardrails.context import reasoning_trace_var
 
 
 class MockOpenAILLM:
@@ -123,3 +128,179 @@ class Wrapper3(Wrapper2):
     llm = Wrapper3()
     provider = _infer_provider_from_module(llm)
     assert provider == "anthropic"
+
+
+class MockResponse:
+    def __init__(self, content="", additional_kwargs=None):
+        self.content = content
+        self.additional_kwargs = additional_kwargs or {}
+
+
+def test_store_reasoning_traces_from_additional_kwargs():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(
+        content="The answer is 42",
+        additional_kwargs={"reasoning_content": "Let me think about this..."},
+    )
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() == "Let me think about this..."
+
+
+def test_store_reasoning_traces_from_think_tags():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(
+        content="<think>Let me think about this...</think>The answer is 42"
+    )
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() == "Let me think about this..."
+    assert response.content == "The answer is 42"
+
+
+def test_store_reasoning_traces_multiline_think_tags():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(
+        content="<think>Step 1: Analyze the problem\nStep 2: Consider options\nStep 3: Choose solution</think>The answer is 42"
+    )
+
+    _store_reasoning_traces(response)
+
+    assert (
+        reasoning_trace_var.get()
+        == "Step 1: Analyze the problem\nStep 2: Consider options\nStep 3: Choose solution"
+    )
+    assert response.content == "The answer is 42"
+
+
+def test_store_reasoning_traces_prefers_additional_kwargs():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(
+        content="<think>This should not be used</think>The answer is 42",
+        additional_kwargs={"reasoning_content": "This should be used"},
+    )
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() == "This should be used"
+
+
+def test_store_reasoning_traces_no_reasoning_content():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(content="The answer is 42")
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() is None
+
+
+def test_store_reasoning_traces_empty_reasoning_content():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(
+        content="The answer is 42", additional_kwargs={"reasoning_content": ""}
+    )
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() is None
+
+
+def test_store_reasoning_traces_incomplete_think_tags():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(content="<think>This is incomplete")
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() is None
+
+
+def test_store_reasoning_traces_no_content_attribute():
+    reasoning_trace_var.set(None)
+
+    class ResponseWithoutContent:
+        def __init__(self):
+            self.additional_kwargs = {}
+
+    response = ResponseWithoutContent()
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() is None
+
+
+def test_store_reasoning_traces_removes_think_tags_with_whitespace():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(
+        content=" <think> reasoning here </think> \n\n Final answer "
+    )
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() == "reasoning here"
+    assert response.content == "Final answer"
+
+
+def test_extract_and_remove_think_tags_basic():
+    response = MockResponse(content="<think>reasoning</think>answer")
+
+    result = _extract_and_remove_think_tags(response)
+
+    assert result == "reasoning"
+    assert response.content == "answer"
+
+
+def test_extract_and_remove_think_tags_multiline():
+    response = MockResponse(content="<think>line1\nline2\nline3</think>final answer")
+
+    result = _extract_and_remove_think_tags(response)
+
+    assert result == "line1\nline2\nline3"
+    assert response.content == "final answer"
+
+
+def test_extract_and_remove_think_tags_no_tags():
+    response = MockResponse(content="just a normal response")
+
+    result = _extract_and_remove_think_tags(response)
+
+    assert result is None
+    assert response.content == "just a normal response"
+
+
+def test_extract_and_remove_think_tags_incomplete():
+    response = MockResponse(content="<think>incomplete")
+
+    result = _extract_and_remove_think_tags(response)
+
+    assert result is None
+    assert response.content == "<think>incomplete"
+
+
+def test_extract_and_remove_think_tags_no_content_attribute():
+    class ResponseWithoutContent:
+        pass
+
+    response = ResponseWithoutContent()
+
+    result = _extract_and_remove_think_tags(response)
+
+    assert result is None
+
+
+def test_extract_and_remove_think_tags_wrong_order():
+    response = MockResponse(content="</think> text here <think>")
+
+    result = _extract_and_remove_think_tags(response)
+
+    assert result is None
+    assert response.content == "</think> text here <think>"
diff --git a/tests/test_reasoning_trace_extraction.py b/tests/test_reasoning_trace_extraction.py
index b794de0e4..a74892679 100644
--- a/tests/test_reasoning_trace_extraction.py
+++ b/tests/test_reasoning_trace_extraction.py
@@ -304,3 +304,94 @@ async def test_reasoning_content_with_other_additional_kwargs(self):
         assert stored_trace == test_reasoning
 
         reasoning_trace_var.set(None)
+
+    @pytest.mark.asyncio
+    async def test_llm_call_extracts_reasoning_from_think_tags(self):
+        test_reasoning = "Let me analyze this step by step"
+
+        mock_llm = AsyncMock()
+        mock_response = AIMessage(
+            content=f"<think>{test_reasoning}</think>The answer is 42",
+            additional_kwargs={},
+        )
+        mock_llm.ainvoke = AsyncMock(return_value=mock_response)
+
+        from nemoguardrails.actions.llm.utils import llm_call
+
+        reasoning_trace_var.set(None)
+        result = await llm_call(mock_llm, "What is the answer?")
+
+        assert result == "The answer is 42"
+        assert "<think>" not in result
+        stored_trace = reasoning_trace_var.get()
+        assert stored_trace == test_reasoning
+
+        reasoning_trace_var.set(None)
+
+    @pytest.mark.asyncio
+    async def test_llm_call_prefers_additional_kwargs_over_think_tags(self):
+        reasoning_from_kwargs = "This should be used"
+        reasoning_from_tags = "This should be ignored"
+
+        mock_llm = AsyncMock()
+        mock_response = AIMessage(
+            content=f"<think>{reasoning_from_tags}</think>Response",
+            additional_kwargs={"reasoning_content": reasoning_from_kwargs},
+        )
+        mock_llm.ainvoke = AsyncMock(return_value=mock_response)
+
+        from nemoguardrails.actions.llm.utils import llm_call
+
+        reasoning_trace_var.set(None)
+        result = await llm_call(mock_llm, "Query")
+
+        assert result == f"<think>{reasoning_from_tags}</think>Response"
+        stored_trace = reasoning_trace_var.get()
+        assert stored_trace == reasoning_from_kwargs
+
+        reasoning_trace_var.set(None)
+
+    @pytest.mark.asyncio
+    async def test_llm_call_extracts_multiline_reasoning_from_think_tags(self):
+        multiline_reasoning = """Step 1: Understand the question
+Step 2: Break down the problem
+Step 3: Formulate the answer"""
+
+        mock_llm = AsyncMock()
+        mock_response = AIMessage(
+            content=f"<think>{multiline_reasoning}</think>Final answer",
+            additional_kwargs={},
+        )
+        mock_llm.ainvoke = AsyncMock(return_value=mock_response)
+
+        from nemoguardrails.actions.llm.utils import llm_call
+
+        reasoning_trace_var.set(None)
+        result = await llm_call(mock_llm, "Question")
+
+        assert result == "Final answer"
+        assert "<think>" not in result
+        stored_trace = reasoning_trace_var.get()
+        assert stored_trace == multiline_reasoning
+
+        reasoning_trace_var.set(None)
+
+    @pytest.mark.asyncio
+    async def test_llm_call_handles_incomplete_think_tags(self):
+        mock_llm = AsyncMock()
+        mock_response = AIMessage(
+            content="<think>This is incomplete",
+            additional_kwargs={},
+        )
+        mock_llm.ainvoke = AsyncMock(return_value=mock_response)
+
+        from nemoguardrails.actions.llm.utils import llm_call
+
+        reasoning_trace_var.set(None)
+        result = await llm_call(mock_llm, "Query")
+
+        assert result == "<think>This is incomplete"
+        stored_trace = reasoning_trace_var.get()
+        assert stored_trace is None
+
+        reasoning_trace_var.set(None)
diff --git a/tests/v2_x/test_passthroug_mode.py b/tests/v2_x/test_passthroug_mode.py
index c75112f56..9421a1dd1 100644
--- a/tests/v2_x/test_passthroug_mode.py
+++ b/tests/v2_x/test_passthroug_mode.py
@@ -81,9 +81,6 @@ def test_passthrough_llm_action_not_invoked_via_logs(self):
         self.assertIn("content", response)
         self.assertIsInstance(response["content"], str)
 
-    @unittest.skip(
-        reason="Github issue https://github.com/NVIDIA/NeMo-Guardrails/issues/1378"
-    )
     def test_passthrough_llm_action_invoked_via_logs(self):
         chat = TestChat(
             config,