Implement Embedding Search Plugin #60
Status: Open

Thuraabtech wants to merge 13 commits into universal-tool-calling-protocol:dev from Thuraabtech:issue/embedding-based-search.
Changes from all commits (13 commits):
- d28dc58 Update docstrings and fix README (h3xxit)
- 018806c Merge pull request #52 from universal-tool-calling-protocol/dev (h3xxit)
- 0f2af7e Merge pull request #53 from universal-tool-calling-protocol/dev (h3xxit)
- f12548a Added embedding search feature for utcp 1.0
- ef64c00 embedding search updated for UTCP 1.0
- c8a4777 Update plugins/tool_search/embedding/pyproject.toml (Thuraabtech)
- 2dd5752 Update plugins/tool_search/embedding/README.md (h3xxit)
- 8dc2fe6 To be resolve
- ba3e1fc folder structure to be resolved
- 235d490 Correct folder placement done.
- 144a025 updated pyproject
- eef4809 Description for values accepted by model_name
- 21b548a Resolved cubic suggestions
Files changed:

File: README — UTCP Embedding Search Plugin (new file, +19 lines)
# UTCP Embedding Search Plugin

This plugin registers the embedding-based semantic search strategy with UTCP 1.0 via entry points.

## Installation

```bash
pip install utcp-embedding-search
```

Optionally, for high-quality embeddings:

```bash
pip install "utcp-in-mem-embeddings[embedding]"
```

## How it works

When installed, this package exposes an entry point under `utcp.plugins` so the UTCP core can auto-discover and register the `in_mem_embeddings` strategy.
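For background, plugin discovery via entry points can be sketched roughly as follows. This is a minimal illustration using the standard `importlib.metadata` API; the actual discovery code in the UTCP core may differ:

```python
# Rough sketch of entry-point discovery, not the actual UTCP core code.
from importlib.metadata import entry_points

def load_utcp_plugins() -> None:
    # Iterate over every entry point registered in the "utcp.plugins" group.
    for ep in entry_points(group="utcp.plugins"):
        register = ep.load()  # resolves e.g. "utcp_in_mem_embeddings:register"
        register()            # the plugin registers its search strategy
```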
File: README — UTCP In-Memory Embeddings Search Plugin (new file, +39 lines)
# UTCP In-Memory Embeddings Search Plugin

This plugin registers the in-memory embedding-based semantic search strategy with UTCP 1.0 via entry points.

## Installation

```bash
pip install utcp-in-mem-embeddings
```

Optionally, for high-quality embeddings:

```bash
pip install "utcp-in-mem-embeddings[embedding]"
```

Or install the required dependencies directly:

```bash
pip install "sentence-transformers>=2.2.0" "torch>=1.9.0"
```

## Why are sentence-transformers and torch needed?

While the plugin works without these packages (using a simple character frequency-based fallback), installing them provides significant benefits:

- **Enhanced Semantic Understanding**: The `sentence-transformers` package provides pre-trained models that convert text into high-quality vector embeddings, capturing the semantic meaning of text rather than just keywords.
- **Better Search Results**: With these packages installed, the search can understand conceptual similarity between queries and tools, even when they don't share exact keywords (see the sketch after this list).
- **Performance**: The default model (`all-MiniLM-L6-v2`) offers a good balance between quality and performance for semantic search applications.
- **Fallback Mechanism**: Without these packages, the plugin automatically falls back to a simpler text similarity method, which works but with reduced accuracy.
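To illustrate the difference the optional packages make, here is a small self-contained sketch. It is illustrative only: it requires the `[embedding]` extra and downloads the model on first run, and the two phrases are arbitrary examples:

```python
# Illustrative sketch: semantic similarity without shared keywords.
# Requires: pip install "utcp-in-mem-embeddings[embedding]"
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")

# No keyword overlap, but closely related meanings:
query = model.encode("resize an image", convert_to_tensor=True)
tool = model.encode("change picture dimensions", convert_to_tensor=True)

print(util.cos_sim(query, tool).item())  # typically well above the 0.3 threshold
```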
## How it works

When installed, this package exposes an entry point under `utcp.plugins` so the UTCP core can auto-discover and register the `in_mem_embeddings` strategy.

The embeddings are cached in memory for improved performance during repeated searches.
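As a concrete usage sketch, the strategy can be rebuilt from a plain config dict via the serializer defined in `in_mem_embeddings_search.py` later in this PR (the field names come from that module):

```python
# Sketch: build the strategy from a config dict via its serializer.
from utcp_in_mem_embeddings.in_mem_embeddings_search import (
    InMemEmbeddingsSearchStrategyConfigSerializer,
)

serializer = InMemEmbeddingsSearchStrategyConfigSerializer()
strategy = serializer.validate_dict({
    "tool_search_strategy_type": "in_mem_embeddings",
    "model_name": "all-MiniLM-L6-v2",  # any sentence-transformers model
    "similarity_threshold": 0.3,       # minimum cosine score for a match
    "cache_embeddings": True,          # reuse tool embeddings across searches
})
assert serializer.to_dict(strategy)["model_name"] == "all-MiniLM-L6-v2"
```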
File: pyproject.toml — utcp-in-mem-embeddings (new file, +38 lines)
```toml
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "utcp-in-mem-embeddings"
version = "1.0.0"
authors = [
    { name = "UTCP Contributors" },
]
description = "UTCP plugin providing in-memory embedding-based semantic tool search."
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "utcp>=1.0",
]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "Programming Language :: Python :: 3",
    "Operating System :: OS Independent",
]
license = "MPL-2.0"

[project.optional-dependencies]
embedding = [
    "sentence-transformers>=2.2.0",
    "torch>=1.9.0",
]

[project.urls]
Homepage = "https://utcp.io"
Source = "https://github.com/universal-tool-calling-protocol/python-utcp"
Issues = "https://github.com/universal-tool-calling-protocol/python-utcp/issues"

[project.entry-points."utcp.plugins"]
in_mem_embeddings = "utcp_in_mem_embeddings:register"
```
File: plugins/tool_search/in_mem_embeddings/src/utcp_in_mem_embeddings/__init__.py (new file, +7 lines)
```python
from utcp.plugins.discovery import register_tool_search_strategy
from utcp_in_mem_embeddings.in_mem_embeddings_search import InMemEmbeddingsSearchStrategyConfigSerializer


def register():
    """Entry point function to register the in-memory embeddings search strategy."""
    register_tool_search_strategy("in_mem_embeddings", InMemEmbeddingsSearchStrategyConfigSerializer())
```
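A hypothetical smoke test, not part of this PR, would catch the import problem flagged in the review comment at the end of this page:

```python
# Hypothetical smoke test: importing and calling register() verifies that
# the entry point target resolves and runs without raising ImportError.
def test_register_entry_point():
    from utcp_in_mem_embeddings import register
    register()  # registers the "in_mem_embeddings" strategy with the core
```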
File: plugins/tool_search/in_mem_embeddings/src/utcp_in_mem_embeddings/in_mem_embeddings_search.py (new file, +241 lines)
"""In-memory embedding-based semantic search strategy for UTCP tools. | ||
|
||
This module provides a semantic search implementation that uses sentence embeddings | ||
to find tools based on meaning similarity rather than just keyword matching. | ||
Embeddings are cached in memory for improved performance. | ||
""" | ||
|
||
import asyncio | ||
import logging | ||
from typing import List, Tuple, Optional, Literal, Dict, Any | ||
from concurrent.futures import ThreadPoolExecutor | ||
import numpy as np | ||
from pydantic import BaseModel, Field, PrivateAttr | ||
|
||
from utcp.interfaces.tool_search_strategy import ToolSearchStrategy | ||
from utcp.data.tool import Tool | ||
from utcp.interfaces.concurrent_tool_repository import ConcurrentToolRepository | ||
from utcp.interfaces.serializer import Serializer | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class InMemEmbeddingsSearchStrategy(ToolSearchStrategy): | ||
"""In-memory semantic search strategy using sentence embeddings. | ||
|
||
This strategy converts tool descriptions and search queries into numerical | ||
embeddings and finds the most semantically similar tools using cosine similarity. | ||
Embeddings are cached in memory for improved performance during repeated searches. | ||
""" | ||
|
||
tool_search_strategy_type: Literal["in_mem_embeddings"] = "in_mem_embeddings" | ||
|
||
# Configuration parameters | ||
model_name: str = Field( | ||
default="all-MiniLM-L6-v2", | ||
description="Sentence transformer model name to use for embeddings. " | ||
"Accepts any model from Hugging Face sentence-transformers library. " | ||
"Popular options: 'all-MiniLM-L6-v2' (fast, good quality), " | ||
"'all-mpnet-base-v2' (slower, higher quality), " | ||
"'paraphrase-MiniLM-L6-v2' (paraphrase detection). " | ||
"See https://huggingface.co/sentence-transformers for full list." | ||
) | ||
similarity_threshold: float = Field(default=0.3, description="Minimum similarity score to consider a match") | ||
max_workers: int = Field(default=4, description="Maximum number of worker threads for embedding generation") | ||
cache_embeddings: bool = Field(default=True, description="Whether to cache tool embeddings for performance") | ||
|
||
# Private attributes | ||
_embedding_model: Optional[Any] = PrivateAttr(default=None) | ||
_tool_embeddings_cache: Dict[str, np.ndarray] = PrivateAttr(default_factory=dict) | ||
_executor: Optional[ThreadPoolExecutor] = PrivateAttr(default=None) | ||
_model_loaded: bool = PrivateAttr(default=False) | ||
|
||
def __init__(self, **data): | ||
super().__init__(**data) | ||
self._executor = ThreadPoolExecutor(max_workers=self.max_workers) | ||
|
||
async def _ensure_model_loaded(self): | ||
"""Ensure the embedding model is loaded.""" | ||
if self._model_loaded: | ||
return | ||
|
||
try: | ||
# Import sentence-transformers here to avoid dependency issues | ||
from sentence_transformers import SentenceTransformer | ||
|
||
# Load the model in a thread to avoid blocking | ||
loop = asyncio.get_running_loop() | ||
self._embedding_model = await loop.run_in_executor( | ||
self._executor, | ||
SentenceTransformer, | ||
self.model_name | ||
) | ||
self._model_loaded = True | ||
logger.info(f"Loaded embedding model: {self.model_name}") | ||
|
||
except ImportError: | ||
logger.warning("sentence-transformers not available, falling back to simple text similarity") | ||
self._embedding_model = None | ||
self._model_loaded = True | ||
except Exception as e: | ||
logger.error(f"Failed to load embedding model: {e}") | ||
self._embedding_model = None | ||
self._model_loaded = True | ||
|
||
async def _get_text_embedding(self, text: str) -> np.ndarray: | ||
"""Generate embedding for given text.""" | ||
if not text: | ||
return np.zeros(384) # Default dimension for all-MiniLM-L6-v2 | ||
|
||
if self._embedding_model is None: | ||
# Fallback to simple text similarity | ||
return self._simple_text_embedding(text) | ||
|
||
try: | ||
loop = asyncio.get_event_loop() | ||
embedding = await loop.run_in_executor( | ||
self._executor, | ||
self._embedding_model.encode, | ||
text | ||
) | ||
return embedding | ||
except Exception as e: | ||
logger.warning(f"Failed to generate embedding for text: {e}") | ||
return self._simple_text_embedding(text) | ||
|
||
def _simple_text_embedding(self, text: str) -> np.ndarray: | ||
"""Simple fallback embedding using character frequency.""" | ||
# Create a simple embedding based on character frequency | ||
# This is a fallback when sentence-transformers is not available | ||
embedding = np.zeros(384) | ||
text_lower = text.lower() | ||
|
||
# Simple character frequency-based embedding | ||
for i, char in enumerate(text_lower): | ||
embedding[i % 384] += ord(char) / 1000.0 | ||
|
||
# Normalize | ||
norm = np.linalg.norm(embedding) | ||
if norm > 0: | ||
embedding = embedding / norm | ||
|
||
return embedding | ||
|
||
async def _get_tool_embedding(self, tool: Tool) -> np.ndarray: | ||
"""Get or generate embedding for a tool.""" | ||
if not self.cache_embeddings or tool.name not in self._tool_embeddings_cache: | ||
# Create text representation of the tool | ||
tool_text = f"{tool.name} {tool.description} {' '.join(tool.tags)}" | ||
embedding = await self._get_text_embedding(tool_text) | ||
|
||
if self.cache_embeddings: | ||
self._tool_embeddings_cache[tool.name] = embedding | ||
|
||
return embedding | ||
|
||
return self._tool_embeddings_cache[tool.name] | ||
|
||
def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float: | ||
"""Calculate cosine similarity between two vectors.""" | ||
try: | ||
dot_product = np.dot(a, b) | ||
norm_a = np.linalg.norm(a) | ||
norm_b = np.linalg.norm(b) | ||
|
||
if norm_a == 0 or norm_b == 0: | ||
return 0.0 | ||
|
||
return dot_product / (norm_a * norm_b) | ||
except Exception as e: | ||
logger.warning(f"Error calculating cosine similarity: {e}") | ||
return 0.0 | ||
|
||
async def search_tools( | ||
self, | ||
tool_repository: ConcurrentToolRepository, | ||
query: str, | ||
limit: int = 10, | ||
any_of_tags_required: Optional[List[str]] = None | ||
) -> List[Tool]: | ||
"""Search for tools using semantic similarity. | ||
|
||
Args: | ||
tool_repository: The tool repository to search within. | ||
query: The search query string. | ||
limit: Maximum number of tools to return. | ||
any_of_tags_required: Optional list of tags where one of them must be present. | ||
|
||
Returns: | ||
List of Tool objects ranked by semantic similarity. | ||
""" | ||
if limit < 0: | ||
raise ValueError("limit must be non-negative") | ||
|
||
# Ensure the embedding model is loaded | ||
await self._ensure_model_loaded() | ||
|
||
# Get all tools | ||
tools: List[Tool] = await tool_repository.get_tools() | ||
|
||
# Filter by required tags if specified | ||
if any_of_tags_required and len(any_of_tags_required) > 0: | ||
any_of_tags_required = [tag.lower() for tag in any_of_tags_required] | ||
tools = [ | ||
tool for tool in tools | ||
if any(tag.lower() in any_of_tags_required for tag in tool.tags) | ||
] | ||
|
||
if not tools: | ||
return [] | ||
|
||
# Generate query embedding | ||
query_embedding = await self._get_text_embedding(query) | ||
|
||
# Calculate similarity scores for all tools | ||
tool_scores: List[Tuple[Tool, float]] = [] | ||
|
||
for tool in tools: | ||
try: | ||
tool_embedding = await self._get_tool_embedding(tool) | ||
similarity = self._cosine_similarity(query_embedding, tool_embedding) | ||
|
||
if similarity >= self.similarity_threshold: | ||
tool_scores.append((tool, similarity)) | ||
|
||
except Exception as e: | ||
logger.warning(f"Error processing tool {tool.name}: {e}") | ||
continue | ||
|
||
# Sort by similarity score (descending) | ||
sorted_tools = [ | ||
tool for tool, score in sorted( | ||
tool_scores, | ||
key=lambda x: x[1], | ||
reverse=True | ||
) | ||
] | ||
|
||
# Return up to 'limit' tools | ||
return sorted_tools[:limit] if limit > 0 else sorted_tools | ||
|
||
async def __aenter__(self): | ||
"""Async context manager entry.""" | ||
await self._ensure_model_loaded() | ||
return self | ||
|
||
async def __aexit__(self, exc_type, exc_val, exc_tb): | ||
"""Async context manager exit.""" | ||
if self._executor: | ||
self._executor.shutdown(wait=False) | ||
|
||
|
||
class InMemEmbeddingsSearchStrategyConfigSerializer(Serializer[InMemEmbeddingsSearchStrategy]): | ||
"""Serializer for InMemEmbeddingsSearchStrategy configuration.""" | ||
|
||
def to_dict(self, obj: InMemEmbeddingsSearchStrategy) -> dict: | ||
return obj.model_dump() | ||
|
||
def validate_dict(self, data: dict) -> InMemEmbeddingsSearchStrategy: | ||
try: | ||
return InMemEmbeddingsSearchStrategy.model_validate(data) | ||
except Exception as e: | ||
raise ValueError(f"Invalid configuration: {e}") from e |
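To round this out, a minimal usage sketch. It is hypothetical: it assumes a populated `ConcurrentToolRepository` implementation named `repo` is available; the concrete repository class comes from the UTCP core:

```python
# Minimal usage sketch; `repo` is assumed to be a populated
# ConcurrentToolRepository implementation from the UTCP core.
from utcp_in_mem_embeddings.in_mem_embeddings_search import InMemEmbeddingsSearchStrategy

async def demo(repo):
    strategy = InMemEmbeddingsSearchStrategy(
        model_name="all-MiniLM-L6-v2",
        similarity_threshold=0.3,
    )
    # The strategy is an async context manager: entering loads the model,
    # exiting shuts down its thread pool.
    async with strategy:
        tools = await strategy.search_tools(repo, "convert currencies", limit=5)
        for tool in tools:
            print(tool.name)

# Run with: asyncio.run(demo(repo))
```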
Review comment: Imported symbol appears undefined in the codebase; this import is likely to raise ImportError at plugin load. Verify correct module path or function name.