diff --git a/core/src/utcp/plugins/plugin_loader.py b/core/src/utcp/plugins/plugin_loader.py index 18b6b0b..0eb5668 100644 --- a/core/src/utcp/plugins/plugin_loader.py +++ b/core/src/utcp/plugins/plugin_loader.py @@ -9,6 +9,13 @@ def _load_plugins(): from utcp.data.auth_implementations import OAuth2AuthSerializer, BasicAuthSerializer, ApiKeyAuthSerializer from utcp.data.variable_loader_implementations import DotEnvVariableLoaderSerializer from utcp.implementations.post_processors import FilterDictPostProcessorConfigSerializer, LimitStringsPostProcessorConfigSerializer + + # Try to import optional plugin, skip if not installed + try: + from utcp_in_mem_embeddings.in_mem_embeddings_search import InMemEmbeddingsSearchStrategyConfigSerializer + in_mem_embeddings_available = True + except ImportError: + in_mem_embeddings_available = False register_auth("oauth2", OAuth2AuthSerializer()) register_auth("basic", BasicAuthSerializer()) @@ -19,6 +26,10 @@ def _load_plugins(): register_tool_repository(ConcurrentToolRepositoryConfigSerializer.default_repository, InMemToolRepositoryConfigSerializer()) register_tool_search_strategy(ToolSearchStrategyConfigSerializer.default_strategy, TagAndDescriptionWordMatchStrategyConfigSerializer()) + + # Register optional plugin only if available + if in_mem_embeddings_available: + register_tool_search_strategy("in_mem_embeddings", InMemEmbeddingsSearchStrategyConfigSerializer()) register_tool_post_processor("filter_dict", FilterDictPostProcessorConfigSerializer()) register_tool_post_processor("limit_strings", LimitStringsPostProcessorConfigSerializer()) diff --git a/plugins/tool_search/embedding/README.md b/plugins/tool_search/embedding/README.md new file mode 100644 index 0000000..71cfa9e --- /dev/null +++ b/plugins/tool_search/embedding/README.md @@ -0,0 +1,19 @@ +# UTCP Embedding Search Plugin + +This plugin registers the embedding-based semantic search strategy with UTCP 1.0 via entry points. + +## Installation + +```bash +pip install utcp-embedding-search +``` + +Optionally, for high-quality embeddings: + +```bash +pip install "utcp-in-mem-embeddings[embedding]" +``` + +## How it works + +When installed, this package exposes an entry point under `utcp.plugins` so the UTCP core can auto-discover and register the `in_mem_embeddings` strategy. diff --git a/plugins/tool_search/in_mem_embeddings/README.md b/plugins/tool_search/in_mem_embeddings/README.md new file mode 100644 index 0000000..5a844a6 --- /dev/null +++ b/plugins/tool_search/in_mem_embeddings/README.md @@ -0,0 +1,39 @@ +# UTCP In-Memory Embeddings Search Plugin + +This plugin registers the in-memory embedding-based semantic search strategy with UTCP 1.0 via entry points. + +## Installation + +```bash +pip install utcp-in-mem-embeddings +``` + +Optionally, for high-quality embeddings: + +```bash +pip install "utcp-in-mem-embeddings[embedding]" +``` + +Or install the required dependencies directly: + +```bash +pip install "sentence-transformers>=2.2.0" "torch>=1.9.0" +``` + +## Why are sentence-transformers and torch needed? + +While the plugin works without these packages (using a simple character frequency-based fallback), installing them provides significant benefits: + +- **Enhanced Semantic Understanding**: The `sentence-transformers` package provides pre-trained models that convert text into high-quality vector embeddings, capturing the semantic meaning of text rather than just keywords. + +- **Better Search Results**: With these packages installed, the search can understand conceptual similarity between queries and tools, even when they don't share exact keywords. + +- **Performance**: The default model (all-MiniLM-L6-v2) offers a good balance between quality and performance for semantic search applications. + +- **Fallback Mechanism**: Without these packages, the plugin automatically falls back to a simpler text similarity method, which works but with reduced accuracy. + +## How it works + +When installed, this package exposes an entry point under `utcp.plugins` so the UTCP core can auto-discover and register the `in_mem_embeddings` strategy. + +The embeddings are cached in memory for improved performance during repeated searches. diff --git a/plugins/tool_search/in_mem_embeddings/pyproject.toml b/plugins/tool_search/in_mem_embeddings/pyproject.toml new file mode 100644 index 0000000..e77ba4e --- /dev/null +++ b/plugins/tool_search/in_mem_embeddings/pyproject.toml @@ -0,0 +1,38 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "utcp-in-mem-embeddings" +version = "1.0.0" +authors = [ + { name = "UTCP Contributors" }, +] +description = "UTCP plugin providing in-memory embedding-based semantic tool search." +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "utcp>=1.0", +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", + "Operating System :: OS Independent", +] +license = "MPL-2.0" + +[project.optional-dependencies] +embedding = [ + "sentence-transformers>=2.2.0", + "torch>=1.9.0", +] + + +[project.urls] +Homepage = "https://utcp.io" +Source = "https://github.com/universal-tool-calling-protocol/python-utcp" +Issues = "https://github.com/universal-tool-calling-protocol/python-utcp/issues" + +[project.entry-points."utcp.plugins"] +in_mem_embeddings = "utcp_in_mem_embeddings:register" diff --git a/plugins/tool_search/in_mem_embeddings/src/utcp_in_mem_embeddings/__init__.py b/plugins/tool_search/in_mem_embeddings/src/utcp_in_mem_embeddings/__init__.py new file mode 100644 index 0000000..53da84d --- /dev/null +++ b/plugins/tool_search/in_mem_embeddings/src/utcp_in_mem_embeddings/__init__.py @@ -0,0 +1,7 @@ +from utcp.plugins.discovery import register_tool_search_strategy +from utcp_in_mem_embeddings.in_mem_embeddings_search import InMemEmbeddingsSearchStrategyConfigSerializer + + +def register(): + """Entry point function to register the in-memory embeddings search strategy.""" + register_tool_search_strategy("in_mem_embeddings", InMemEmbeddingsSearchStrategyConfigSerializer()) diff --git a/plugins/tool_search/in_mem_embeddings/src/utcp_in_mem_embeddings/in_mem_embeddings_search.py b/plugins/tool_search/in_mem_embeddings/src/utcp_in_mem_embeddings/in_mem_embeddings_search.py new file mode 100644 index 0000000..669748d --- /dev/null +++ b/plugins/tool_search/in_mem_embeddings/src/utcp_in_mem_embeddings/in_mem_embeddings_search.py @@ -0,0 +1,241 @@ +"""In-memory embedding-based semantic search strategy for UTCP tools. + +This module provides a semantic search implementation that uses sentence embeddings +to find tools based on meaning similarity rather than just keyword matching. +Embeddings are cached in memory for improved performance. +""" + +import asyncio +import logging +from typing import List, Tuple, Optional, Literal, Dict, Any +from concurrent.futures import ThreadPoolExecutor +import numpy as np +from pydantic import BaseModel, Field, PrivateAttr + +from utcp.interfaces.tool_search_strategy import ToolSearchStrategy +from utcp.data.tool import Tool +from utcp.interfaces.concurrent_tool_repository import ConcurrentToolRepository +from utcp.interfaces.serializer import Serializer + +logger = logging.getLogger(__name__) + +class InMemEmbeddingsSearchStrategy(ToolSearchStrategy): + """In-memory semantic search strategy using sentence embeddings. + + This strategy converts tool descriptions and search queries into numerical + embeddings and finds the most semantically similar tools using cosine similarity. + Embeddings are cached in memory for improved performance during repeated searches. + """ + + tool_search_strategy_type: Literal["in_mem_embeddings"] = "in_mem_embeddings" + + # Configuration parameters + model_name: str = Field( + default="all-MiniLM-L6-v2", + description="Sentence transformer model name to use for embeddings. " + "Accepts any model from Hugging Face sentence-transformers library. " + "Popular options: 'all-MiniLM-L6-v2' (fast, good quality), " + "'all-mpnet-base-v2' (slower, higher quality), " + "'paraphrase-MiniLM-L6-v2' (paraphrase detection). " + "See https://huggingface.co/sentence-transformers for full list." + ) + similarity_threshold: float = Field(default=0.3, description="Minimum similarity score to consider a match") + max_workers: int = Field(default=4, description="Maximum number of worker threads for embedding generation") + cache_embeddings: bool = Field(default=True, description="Whether to cache tool embeddings for performance") + + # Private attributes + _embedding_model: Optional[Any] = PrivateAttr(default=None) + _tool_embeddings_cache: Dict[str, np.ndarray] = PrivateAttr(default_factory=dict) + _executor: Optional[ThreadPoolExecutor] = PrivateAttr(default=None) + _model_loaded: bool = PrivateAttr(default=False) + + def __init__(self, **data): + super().__init__(**data) + self._executor = ThreadPoolExecutor(max_workers=self.max_workers) + + async def _ensure_model_loaded(self): + """Ensure the embedding model is loaded.""" + if self._model_loaded: + return + + try: + # Import sentence-transformers here to avoid dependency issues + from sentence_transformers import SentenceTransformer + + # Load the model in a thread to avoid blocking + loop = asyncio.get_running_loop() + self._embedding_model = await loop.run_in_executor( + self._executor, + SentenceTransformer, + self.model_name + ) + self._model_loaded = True + logger.info(f"Loaded embedding model: {self.model_name}") + + except ImportError: + logger.warning("sentence-transformers not available, falling back to simple text similarity") + self._embedding_model = None + self._model_loaded = True + except Exception as e: + logger.error(f"Failed to load embedding model: {e}") + self._embedding_model = None + self._model_loaded = True + + async def _get_text_embedding(self, text: str) -> np.ndarray: + """Generate embedding for given text.""" + if not text: + return np.zeros(384) # Default dimension for all-MiniLM-L6-v2 + + if self._embedding_model is None: + # Fallback to simple text similarity + return self._simple_text_embedding(text) + + try: + loop = asyncio.get_event_loop() + embedding = await loop.run_in_executor( + self._executor, + self._embedding_model.encode, + text + ) + return embedding + except Exception as e: + logger.warning(f"Failed to generate embedding for text: {e}") + return self._simple_text_embedding(text) + + def _simple_text_embedding(self, text: str) -> np.ndarray: + """Simple fallback embedding using character frequency.""" + # Create a simple embedding based on character frequency + # This is a fallback when sentence-transformers is not available + embedding = np.zeros(384) + text_lower = text.lower() + + # Simple character frequency-based embedding + for i, char in enumerate(text_lower): + embedding[i % 384] += ord(char) / 1000.0 + + # Normalize + norm = np.linalg.norm(embedding) + if norm > 0: + embedding = embedding / norm + + return embedding + + async def _get_tool_embedding(self, tool: Tool) -> np.ndarray: + """Get or generate embedding for a tool.""" + if not self.cache_embeddings or tool.name not in self._tool_embeddings_cache: + # Create text representation of the tool + tool_text = f"{tool.name} {tool.description} {' '.join(tool.tags)}" + embedding = await self._get_text_embedding(tool_text) + + if self.cache_embeddings: + self._tool_embeddings_cache[tool.name] = embedding + + return embedding + + return self._tool_embeddings_cache[tool.name] + + def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float: + """Calculate cosine similarity between two vectors.""" + try: + dot_product = np.dot(a, b) + norm_a = np.linalg.norm(a) + norm_b = np.linalg.norm(b) + + if norm_a == 0 or norm_b == 0: + return 0.0 + + return dot_product / (norm_a * norm_b) + except Exception as e: + logger.warning(f"Error calculating cosine similarity: {e}") + return 0.0 + + async def search_tools( + self, + tool_repository: ConcurrentToolRepository, + query: str, + limit: int = 10, + any_of_tags_required: Optional[List[str]] = None + ) -> List[Tool]: + """Search for tools using semantic similarity. + + Args: + tool_repository: The tool repository to search within. + query: The search query string. + limit: Maximum number of tools to return. + any_of_tags_required: Optional list of tags where one of them must be present. + + Returns: + List of Tool objects ranked by semantic similarity. + """ + if limit < 0: + raise ValueError("limit must be non-negative") + + # Ensure the embedding model is loaded + await self._ensure_model_loaded() + + # Get all tools + tools: List[Tool] = await tool_repository.get_tools() + + # Filter by required tags if specified + if any_of_tags_required and len(any_of_tags_required) > 0: + any_of_tags_required = [tag.lower() for tag in any_of_tags_required] + tools = [ + tool for tool in tools + if any(tag.lower() in any_of_tags_required for tag in tool.tags) + ] + + if not tools: + return [] + + # Generate query embedding + query_embedding = await self._get_text_embedding(query) + + # Calculate similarity scores for all tools + tool_scores: List[Tuple[Tool, float]] = [] + + for tool in tools: + try: + tool_embedding = await self._get_tool_embedding(tool) + similarity = self._cosine_similarity(query_embedding, tool_embedding) + + if similarity >= self.similarity_threshold: + tool_scores.append((tool, similarity)) + + except Exception as e: + logger.warning(f"Error processing tool {tool.name}: {e}") + continue + + # Sort by similarity score (descending) + sorted_tools = [ + tool for tool, score in sorted( + tool_scores, + key=lambda x: x[1], + reverse=True + ) + ] + + # Return up to 'limit' tools + return sorted_tools[:limit] if limit > 0 else sorted_tools + + async def __aenter__(self): + """Async context manager entry.""" + await self._ensure_model_loaded() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + if self._executor: + self._executor.shutdown(wait=False) + + +class InMemEmbeddingsSearchStrategyConfigSerializer(Serializer[InMemEmbeddingsSearchStrategy]): + """Serializer for InMemEmbeddingsSearchStrategy configuration.""" + + def to_dict(self, obj: InMemEmbeddingsSearchStrategy) -> dict: + return obj.model_dump() + + def validate_dict(self, data: dict) -> InMemEmbeddingsSearchStrategy: + try: + return InMemEmbeddingsSearchStrategy.model_validate(data) + except Exception as e: + raise ValueError(f"Invalid configuration: {e}") from e diff --git a/plugins/tool_search/in_mem_embeddings/test_integration.py b/plugins/tool_search/in_mem_embeddings/test_integration.py new file mode 100644 index 0000000..908be2b --- /dev/null +++ b/plugins/tool_search/in_mem_embeddings/test_integration.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +"""Integration test to verify the plugin works with the core UTCP system.""" + +import sys +import asyncio +from pathlib import Path + +# Add paths +plugin_src = (Path(__file__).parent / "src").resolve() +core_src = (Path(__file__).parent.parent.parent.parent / "core" / "src").resolve() +sys.path.insert(0, str(plugin_src)) +sys.path.insert(0, str(core_src)) + +async def test_integration(): + """Test plugin integration with core system.""" + print("๐Ÿ”— Testing Integration with Core UTCP System...") + + try: + # Test 1: Plugin registration + print("1. Testing plugin registration...") + from utcp_in_mem_embeddings import register + register() + print(" โœ… Plugin registered successfully") + + # Test 2: Core system can discover the plugin + print("2. Testing plugin discovery...") + from utcp.interfaces.tool_search_strategy import ToolSearchStrategyConfigSerializer + strategies = ToolSearchStrategyConfigSerializer.tool_search_strategy_implementations + assert "in_mem_embeddings" in strategies + print(" โœ… Plugin discovered by core system") + + # Test 3: Create strategy through core system + print("3. Testing strategy creation through core...") + from utcp.interfaces.tool_search_strategy import ToolSearchStrategyConfigSerializer + serializer = ToolSearchStrategyConfigSerializer() + + # This should work if the plugin is properly registered + strategy_config = { + "tool_search_strategy_type": "in_mem_embeddings", + "model_name": "all-MiniLM-L6-v2", + "similarity_threshold": 0.3 + } + + strategy = serializer.validate_dict(strategy_config) + print(f" โœ… Strategy created: {strategy.tool_search_strategy_type}") + + # Test 4: Basic functionality test + print("4. Testing basic search functionality...") + from utcp.data.tool import Tool, JsonSchema + from utcp.data.call_template import CallTemplate + from utcp.implementations.in_mem_tool_repository import InMemToolRepository + + # Create sample tools + tools = [ + Tool( + name="test.tool1", + description="A test tool for cooking", + inputs=JsonSchema(), + outputs=JsonSchema(), + tags=["cooking", "test"], + tool_call_template=CallTemplate( + name="test.tool1", + call_template_type="default" + ) + ) + ] + + # Create repository + repo = InMemToolRepository() + + # Create a manual and add it to the repository + from utcp.data.utcp_manual import UtcpManual + manual = UtcpManual(tools=tools) + manual_call_template = CallTemplate(name="test_manual", call_template_type="default") + await repo.save_manual(manual_call_template, manual) + + + # Test search + results = await strategy.search_tools(repo, "cooking", limit=1) + print(f" โœ… Search completed, found {len(results)} results") + + # Validate search results + assert len(results) > 0, "Search should return at least one result for 'cooking' query" + + print("\n๐ŸŽ‰ Integration test passed! Plugin works with core system.") + return True + + except Exception as e: + print(f"โŒ Integration test failed: {e}") + import traceback + traceback.print_exc() + return False + +if __name__ == "__main__": + success = asyncio.run(test_integration()) + sys.exit(0 if success else 1) diff --git a/plugins/tool_search/in_mem_embeddings/test_performance.py b/plugins/tool_search/in_mem_embeddings/test_performance.py new file mode 100644 index 0000000..b447a35 --- /dev/null +++ b/plugins/tool_search/in_mem_embeddings/test_performance.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +"""Performance test for the in-memory embeddings plugin.""" + +import sys +import asyncio +import time +from pathlib import Path + +# Add paths +plugin_src = Path(__file__).parent / "src" +core_src = Path(__file__).parent.parent.parent.parent / "core" / "src" +sys.path.insert(0, str(plugin_src)) +sys.path.insert(0, str(core_src)) + +async def test_performance(): + """Test plugin performance with multiple tools and searches.""" + print("โšก Testing Performance...") + + try: + from utcp_in_mem_embeddings.in_mem_embeddings_search import InMemEmbeddingsSearchStrategy + from utcp.data.tool import Tool, JsonSchema + from utcp.data.call_template import CallTemplate + + # Create strategy + strategy = InMemEmbeddingsSearchStrategy( + model_name="all-MiniLM-L6-v2", + similarity_threshold=0.3, + max_workers=2, + cache_embeddings=True + ) + + # Create many tools + print("1. Creating 100 test tools...") + tools = [] + for i in range(100): + tool = Tool( + name=f"test_tool{i}", + description=f"Test tool {i} for various purposes like cooking, coding, data analysis", + inputs=JsonSchema(), + outputs=JsonSchema(), + tags=["test", f"category{i%5}"], + tool_call_template=CallTemplate( + name=f"test_tool{i}", + description=f"Test tool {i}", + call_template_type="default" + ) + ) + tools.append(tool) + + # Mock repository + class MockRepo: + def __init__(self, tools): + self.tools = tools + async def get_tools(self): + return self.tools + + repo = MockRepo(tools) + + # Test 1: First search (cold start) + print("2. Testing cold start performance...") + start_time = time.perf_counter() + results1 = await strategy.search_tools(repo, "cooking tools", limit=10) + cold_time = time.perf_counter() - start_time + print(f" โฑ๏ธ Cold start: {cold_time:.3f}s, found {len(results1)} results") + + # Test 2: Second search (warm cache) + print("3. Testing warm cache performance...") + start_time = time.perf_counter() + results2 = await strategy.search_tools(repo, "coding tools", limit=10) + warm_time = time.perf_counter() - start_time + print(f" โฑ๏ธ Warm cache: {warm_time:.3f}s, found {len(results2)} results") + + # Test 3: Multiple searches + print("4. Testing multiple searches...") + queries = ["cooking", "programming", "data analysis", "testing", "utilities"] + start_time = time.perf_counter() + + for query in queries: + await strategy.search_tools(repo, query, limit=5) + + total_time = time.perf_counter() - start_time + avg_time = total_time / len(queries) + print(f" โฑ๏ธ Average per search: {avg_time:.3f}s") + + # Performance assertions + assert cold_time < 10.0, f"Cold start too slow: {cold_time}s" # Allow more time for model loading + assert warm_time < 1.0, f"Warm cache too slow: {warm_time}s" + assert avg_time < 0.5, f"Average search too slow: {avg_time}s" + + print("\n๐ŸŽ‰ Performance test passed!") + return True + + except Exception as e: + print(f"โŒ Performance test failed: {e}") + import traceback + traceback.print_exc() + return False + +if __name__ == "__main__": + success = asyncio.run(test_performance()) + sys.exit(0 if success else 1) diff --git a/plugins/tool_search/in_mem_embeddings/test_plugin.py b/plugins/tool_search/in_mem_embeddings/test_plugin.py new file mode 100644 index 0000000..95f11cf --- /dev/null +++ b/plugins/tool_search/in_mem_embeddings/test_plugin.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +"""Simple test script to verify the in-memory embeddings plugin works.""" + +import sys +import os +import asyncio +from pathlib import Path + +# Add the plugin source to Python path +plugin_src = Path(__file__).parent / "src" +sys.path.insert(0, str(plugin_src)) + +# Add core to path for imports +core_src = Path(__file__).parent.parent.parent.parent / "core" / "src" +sys.path.insert(0, str(core_src)) + +async def test_plugin(): + """Test the plugin functionality.""" + print("๐Ÿงช Testing In-Memory Embeddings Plugin...") + + try: + # Test 1: Import the plugin + print("1. Testing imports...") + from utcp_in_mem_embeddings.in_mem_embeddings_search import InMemEmbeddingsSearchStrategy + from utcp_in_mem_embeddings import register + print(" โœ… Imports successful") + + # Test 2: Create strategy instance + print("2. Testing strategy creation...") + strategy = InMemEmbeddingsSearchStrategy( + model_name="all-MiniLM-L6-v2", + similarity_threshold=0.3, + max_workers=2, + cache_embeddings=True + ) + print(f" โœ… Strategy created: {strategy.tool_search_strategy_type}") + + # Test 3: Test registration function + print("3. Testing registration...") + register() + print(" โœ… Registration function works") + + # Test 4: Test basic functionality + print("4. Testing basic functionality...") + + # Create mock tools + from utcp.data.tool import Tool, JsonSchema + from utcp.data.call_template import CallTemplate + + tools = [ + Tool( + name="cooking.spatula", + description="A kitchen utensil for flipping food", + inputs=JsonSchema(), + outputs=JsonSchema(), + tags=["cooking", "kitchen"], + tool_call_template=CallTemplate( + name="cooking.spatula", + description="Spatula tool", + call_template_type="default" + ) + ), + Tool( + name="dev.code_review", + description="Review source code for quality", + inputs=JsonSchema(), + outputs=JsonSchema(), + tags=["programming", "development"], + tool_call_template=CallTemplate( + name="dev.code_review", + description="Code review tool", + call_template_type="default" + ) + ) + ] + + # Create mock repository + class MockRepo: + def __init__(self, tools): + self.tools = tools + async def get_tools(self): + return self.tools + + repo = MockRepo(tools) + + # Test search + results = await strategy.search_tools(repo, "cooking utensils", limit=2) + print(f" โœ… Search completed, found {len(results)} results") + + if results: + print(f" ๐Ÿ“‹ Top result: {results[0].name}") + + print("\n๐ŸŽ‰ All tests passed! Plugin is working correctly.") + + except Exception as e: + print(f"โŒ Test failed: {e}") + import traceback + traceback.print_exc() + return False + + return True + +if __name__ == "__main__": + success = asyncio.run(test_plugin()) + sys.exit(0 if success else 1) diff --git a/plugins/tool_search/in_mem_embeddings/tests/test_in_mem_embeddings_search.py b/plugins/tool_search/in_mem_embeddings/tests/test_in_mem_embeddings_search.py new file mode 100644 index 0000000..d27773e --- /dev/null +++ b/plugins/tool_search/in_mem_embeddings/tests/test_in_mem_embeddings_search.py @@ -0,0 +1,343 @@ +"""Tests for the InMemEmbeddingsSearchStrategy implementation.""" +"""just test""" +import pytest +import numpy as np +import sys +from pathlib import Path +from unittest.mock import patch +from typing import List + +# Add plugin source to path +plugin_src = Path(__file__).parent.parent / "src" +sys.path.insert(0, str(plugin_src)) + +# Add core to path +core_src = Path(__file__).parent.parent.parent.parent.parent / "core" / "src" +sys.path.insert(0, str(core_src)) + +from utcp_in_mem_embeddings.in_mem_embeddings_search import InMemEmbeddingsSearchStrategy +from utcp.data.tool import Tool, JsonSchema +from utcp.data.call_template import CallTemplate + + +class MockToolRepository: + """Simplified mock repository for testing.""" + + def __init__(self, tools: List[Tool]): + self.tools = tools + + async def get_tools(self) -> List[Tool]: + return self.tools + + +@pytest.fixture +def sample_tools(): + """Create sample tools for testing.""" + tools = [] + + # Tool 1: Cooking related + tool1 = Tool( + name="cooking.spatula", + description="A kitchen utensil used for flipping and turning food while cooking", + inputs=JsonSchema(), + outputs=JsonSchema(), + tags=["cooking", "kitchen", "utensil"], + tool_call_template=CallTemplate( + name="cooking.spatula", + description="Spatula tool", + call_template_type="default" + ) + ) + tools.append(tool1) + + # Tool 2: Programming related + tool2 = Tool( + name="dev.code_review", + description="Review and analyze source code for quality and best practices", + inputs=JsonSchema(), + outputs=JsonSchema(), + tags=["programming", "development", "code"], + tool_call_template=CallTemplate( + name="dev.code_review", + description="Code review tool", + call_template_type="default" + ) + ) + tools.append(tool2) + + # Tool 3: Data analysis + tool3 = Tool( + name="data.analyze", + description="Analyze datasets and generate insights from data", + inputs=JsonSchema(), + outputs=JsonSchema(), + tags=["data", "analysis", "insights"], + tool_call_template=CallTemplate( + name="data.analyze", + description="Data analysis tool", + call_template_type="default" + ) + ) + tools.append(tool3) + + return tools + + +@pytest.fixture +def in_mem_embeddings_strategy(): + """Create an in-memory embeddings search strategy instance.""" + return InMemEmbeddingsSearchStrategy( + model_name="all-MiniLM-L6-v2", + similarity_threshold=0.3, + max_workers=2, + cache_embeddings=True + ) + + +@pytest.mark.asyncio +async def test_in_mem_embeddings_strategy_initialization(in_mem_embeddings_strategy): + """Test that the in-memory embeddings strategy initializes correctly.""" + assert in_mem_embeddings_strategy.tool_search_strategy_type == "in_mem_embeddings" + assert in_mem_embeddings_strategy.model_name == "all-MiniLM-L6-v2" + assert in_mem_embeddings_strategy.similarity_threshold == 0.3 + assert in_mem_embeddings_strategy.max_workers == 2 + assert in_mem_embeddings_strategy.cache_embeddings is True + + +@pytest.mark.asyncio +async def test_simple_text_embedding_fallback(in_mem_embeddings_strategy): + """Test the fallback text embedding when sentence-transformers is not available.""" + # Mock the embedding model to be None to trigger fallback + in_mem_embeddings_strategy._embedding_model = None + in_mem_embeddings_strategy._model_loaded = True + + text = "test text" + embedding = await in_mem_embeddings_strategy._get_text_embedding(text) + + assert isinstance(embedding, np.ndarray) + assert embedding.shape == (384,) + assert np.linalg.norm(embedding) > 0 + + +@pytest.mark.asyncio +async def test_cosine_similarity_calculation(in_mem_embeddings_strategy): + """Test cosine similarity calculation.""" + # Test with identical vectors + vec1 = np.array([1.0, 0.0, 0.0]) + vec2 = np.array([1.0, 0.0, 0.0]) + similarity = in_mem_embeddings_strategy._cosine_similarity(vec1, vec2) + assert similarity == pytest.approx(1.0) + + # Test with orthogonal vectors + vec3 = np.array([0.0, 1.0, 0.0]) + similarity = in_mem_embeddings_strategy._cosine_similarity(vec1, vec3) + assert similarity == pytest.approx(0.0) + + # Test with zero vectors + vec4 = np.zeros(3) + similarity = in_mem_embeddings_strategy._cosine_similarity(vec1, vec4) + assert similarity == 0.0 + + +@pytest.mark.asyncio +async def test_tool_embedding_generation(in_mem_embeddings_strategy, sample_tools): + """Test that tool embeddings are generated and cached correctly.""" + tool = sample_tools[0] + + # Mock the text embedding method + with patch.object(in_mem_embeddings_strategy, '_get_text_embedding') as mock_embed: + mock_embed.return_value = np.random.rand(384) + + # First call should generate and cache + embedding1 = await in_mem_embeddings_strategy._get_tool_embedding(tool) + assert tool.name in in_mem_embeddings_strategy._tool_embeddings_cache + + # Second call should use cache + embedding2 = await in_mem_embeddings_strategy._get_tool_embedding(tool) + assert np.array_equal(embedding1, embedding2) + + # Verify the mock was called only once + mock_embed.assert_called_once() + + +@pytest.mark.asyncio +async def test_search_tools_basic(in_mem_embeddings_strategy, sample_tools): + """Test basic search functionality.""" + tool_repo = MockToolRepository(sample_tools) + + # Mock the embedding methods + with patch.object(in_mem_embeddings_strategy, '_get_text_embedding') as mock_query_embed, \ + patch.object(in_mem_embeddings_strategy, '_get_tool_embedding') as mock_tool_embed: + + # Create mock embeddings + query_embedding = np.random.rand(384) + tool_embeddings = [np.random.rand(384) for _ in sample_tools] + + mock_query_embed.return_value = query_embedding + mock_tool_embed.side_effect = tool_embeddings + + # Mock cosine similarity to return high scores + with patch.object(in_mem_embeddings_strategy, '_cosine_similarity') as mock_sim: + mock_sim.return_value = 0.8 # High similarity + + results = await in_mem_embeddings_strategy.search_tools(tool_repo, "cooking", limit=2) + + assert len(results) == 2 + assert all(isinstance(tool, Tool) for tool in results) + + +@pytest.mark.asyncio +async def test_search_tools_with_tag_filtering(in_mem_embeddings_strategy, sample_tools): + """Test search with tag filtering.""" + tool_repo = MockToolRepository(sample_tools) + + with patch.object(in_mem_embeddings_strategy, '_get_text_embedding') as mock_query_embed, \ + patch.object(in_mem_embeddings_strategy, '_get_tool_embedding') as mock_tool_embed, \ + patch.object(in_mem_embeddings_strategy, '_cosine_similarity') as mock_sim: + + mock_query_embed.return_value = np.random.rand(384) + mock_tool_embed.return_value = np.random.rand(384) + mock_sim.return_value = 0.8 + + # Search with required tags + results = await in_mem_embeddings_strategy.search_tools( + tool_repo, + "cooking", + limit=10, + any_of_tags_required=["cooking", "kitchen"] + ) + + # Should only return tools with cooking or kitchen tags + assert all( + any(tag in ["cooking", "kitchen"] for tag in tool.tags) + for tool in results + ) + + +@pytest.mark.asyncio +async def test_search_tools_with_similarity_threshold(in_mem_embeddings_strategy, sample_tools): + """Test that similarity threshold filtering works correctly.""" + tool_repo = MockToolRepository(sample_tools) + + with patch.object(in_mem_embeddings_strategy, '_get_text_embedding') as mock_query_embed, \ + patch.object(in_mem_embeddings_strategy, '_get_tool_embedding') as mock_tool_embed, \ + patch.object(in_mem_embeddings_strategy, '_cosine_similarity') as mock_sim: + + mock_query_embed.return_value = np.random.rand(384) + mock_tool_embed.return_value = np.random.rand(384) + + # Set threshold to 0.5 and return scores below and above + in_mem_embeddings_strategy.similarity_threshold = 0.5 + mock_sim.side_effect = [0.3, 0.7, 0.2] # Only second tool should pass + + results = await in_mem_embeddings_strategy.search_tools(tool_repo, "test", limit=10) + + assert len(results) == 1 # Only one tool above threshold + + +@pytest.mark.asyncio +async def test_search_tools_limit_respected(in_mem_embeddings_strategy, sample_tools): + """Test that the limit parameter is respected.""" + tool_repo = MockToolRepository(sample_tools) + + with patch.object(in_mem_embeddings_strategy, '_get_text_embedding') as mock_query_embed, \ + patch.object(in_mem_embeddings_strategy, '_get_tool_embedding') as mock_tool_embed, \ + patch.object(in_mem_embeddings_strategy, '_cosine_similarity') as mock_sim: + + mock_query_embed.return_value = np.random.rand(384) + mock_tool_embed.return_value = np.random.rand(384) + mock_sim.return_value = 0.8 + + # Test with limit 1 + results = await in_mem_embeddings_strategy.search_tools(tool_repo, "test", limit=1) + assert len(results) == 1 + + # Test with limit 0 (no limit) + results = await in_mem_embeddings_strategy.search_tools(tool_repo, "test", limit=0) + assert len(results) == 3 # All tools + + +@pytest.mark.asyncio +async def test_search_tools_empty_repository(in_mem_embeddings_strategy): + """Test search behavior with empty tool repository.""" + tool_repo = MockToolRepository([]) + + results = await in_mem_embeddings_strategy.search_tools(tool_repo, "test", limit=10) + assert results == [] + + +@pytest.mark.asyncio +async def test_search_tools_invalid_limit(in_mem_embeddings_strategy, sample_tools): + """Test that invalid limit values raise appropriate errors.""" + tool_repo = MockToolRepository(sample_tools) + + with pytest.raises(ValueError, match="limit must be non-negative"): + await in_mem_embeddings_strategy.search_tools(tool_repo, "test", limit=-1) + + +@pytest.mark.asyncio +async def test_context_manager_behavior(in_mem_embeddings_strategy): + """Test async context manager behavior.""" + async with in_mem_embeddings_strategy as strategy: + assert strategy._model_loaded is True + + # Executor should be shut down + assert strategy._executor._shutdown is True + + +@pytest.mark.asyncio +async def test_error_handling_in_search(in_mem_embeddings_strategy, sample_tools): + """Test that errors in search are handled gracefully.""" + tool_repo = MockToolRepository(sample_tools) + + with patch.object(in_mem_embeddings_strategy, '_get_text_embedding') as mock_query_embed, \ + patch.object(in_mem_embeddings_strategy, '_get_tool_embedding') as mock_tool_embed: + + mock_query_embed.return_value = np.random.rand(384) + + # Make the second tool fail + def mock_tool_embed_side_effect(tool): + if tool.name == "dev.code_review": + raise Exception("Simulated error") + return np.random.rand(384) + + mock_tool_embed.side_effect = mock_tool_embed_side_effect + + # Mock cosine similarity + with patch.object(in_mem_embeddings_strategy, '_cosine_similarity') as mock_sim: + mock_sim.return_value = 0.8 + + # Should not crash, just skip the problematic tool + results = await in_mem_embeddings_strategy.search_tools(tool_repo, "test", limit=10) + + # Should return tools that didn't fail + assert len(results) == 2 # One tool failed, so only 2 results + + +@pytest.mark.asyncio +async def test_in_mem_embeddings_strategy_config_serializer(): + """Test the configuration serializer.""" + from utcp_in_mem_embeddings.in_mem_embeddings_search import InMemEmbeddingsSearchStrategyConfigSerializer + + serializer = InMemEmbeddingsSearchStrategyConfigSerializer() + + # Test serialization + strategy = InMemEmbeddingsSearchStrategy( + model_name="test-model", + similarity_threshold=0.5, + max_workers=8, + cache_embeddings=False + ) + + config_dict = serializer.to_dict(strategy) + assert config_dict["model_name"] == "test-model" + assert config_dict["similarity_threshold"] == 0.5 + assert config_dict["max_workers"] == 8 + assert config_dict["cache_embeddings"] is False + + # Test deserialization + restored_strategy = serializer.validate_dict(config_dict) + assert restored_strategy.model_name == "test-model" + assert restored_strategy.similarity_threshold == 0.5 + assert restored_strategy.max_workers == 8 + assert restored_strategy.cache_embeddings is False