From 57a43ccf46cdce0f7938d3bc0aed9fa024b54605 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 10 Sep 2025 12:27:25 -0400 Subject: [PATCH 1/3] Add tests --- pymongo_vectorsearch_utils/index.py | 9 +- pymongo_vectorsearch_utils/operation.py | 49 ++++++ tests/test_operation.py | 189 +++++++++++++++++++++++- 3 files changed, 242 insertions(+), 5 deletions(-) diff --git a/pymongo_vectorsearch_utils/index.py b/pymongo_vectorsearch_utils/index.py index 54d64cd..bb42f90 100644 --- a/pymongo_vectorsearch_utils/index.py +++ b/pymongo_vectorsearch_utils/index.py @@ -259,7 +259,6 @@ def create_fulltext_search_index( def wait_for_docs_in_index( collection: Collection[Any], index_name: str, - embedding_field: str, n_docs: int, ) -> bool: """Wait until the given number of documents are indexed by the given index. @@ -270,12 +269,16 @@ def wait_for_docs_in_index( embedding_field (str): The name of the document field containing embeddings. n_docs (int): The number of documents to expect in the index. """ - query_vector = [0.0] * 1024 # Dummy vector + index = collection.list_search_indexes(index_name).to_list()[0] + num_dimensions = index["latestDefinition"]["fields"][0]["numDimensions"] + field = index["latestDefinition"]["fields"][0]["path"] + + query_vector = [0.001] * num_dimensions # Dummy vector query = [ { "$vectorSearch": { "index": index_name, - "path": embedding_field, + "path": field, "queryVector": query_vector, "numCandidates": n_docs, "limit": n_docs, diff --git a/pymongo_vectorsearch_utils/operation.py b/pymongo_vectorsearch_utils/operation.py index fd0d232..ff7b4bb 100644 --- a/pymongo_vectorsearch_utils/operation.py +++ b/pymongo_vectorsearch_utils/operation.py @@ -7,6 +7,7 @@ from pymongo import ReplaceOne from pymongo.synchronous.collection import Collection +from pymongo_vectorsearch_utils.pipeline import vector_search_stage from pymongo_vectorsearch_utils.util import oid_to_str, str_to_oid @@ -60,3 +61,51 @@ def bulk_embed_and_insert_texts( result = collection.bulk_write(operations) assert result.upserted_ids is not None return [oid_to_str(_id) for _id in result.upserted_ids.values()] + + +def execute_search_query( + query_vector: list[float], + collection: Collection[Any], + embedding_key: str, + text_key: str, + index_name: str, + k: int = 4, + pre_filter: dict[str, Any] | None = None, + post_filter_pipeline: list[dict[str, Any]] | None = None, + oversampling_factor: int = 10, + include_embeddings: bool = False, + **kwargs: Any, +) -> list[tuple[Any, float]]: + """Execute a MongoDB vector search query.""" + + # Atlas Vector Search, potentially with filter + pipeline = [ + vector_search_stage( + query_vector, + embedding_key, + index_name, + k, + pre_filter, + oversampling_factor, + **kwargs, + ), + {"$set": {"score": {"$meta": "vectorSearchScore"}}}, + ] + + + # Remove embeddings unless requested. + if not include_embeddings: + pipeline.append({"$project": {embedding_key: 0}}) + # Post-processing + if post_filter_pipeline is not None: + pipeline.extend(post_filter_pipeline) + + # Execution + cursor = collection.aggregate(pipeline) + docs = [] + + for doc in cursor: + if text_key not in doc: + continue + docs.append(doc) + return docs diff --git a/tests/test_operation.py b/tests/test_operation.py index 2303a67..2491145 100644 --- a/tests/test_operation.py +++ b/tests/test_operation.py @@ -1,17 +1,20 @@ """Tests for operation utilities.""" import os -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest from bson import ObjectId from pymongo import MongoClient from pymongo.collection import Collection -from pymongo_vectorsearch_utils.operation import bulk_embed_and_insert_texts +from pymongo_vectorsearch_utils import drop_vector_search_index +from pymongo_vectorsearch_utils.index import create_vector_search_index, wait_for_docs_in_index +from pymongo_vectorsearch_utils.operation import bulk_embed_and_insert_texts, execute_search_query DB_NAME = "vectorsearch_utils_test" COLLECTION_NAME = "test_operation" +VECTOR_INDEX_NAME = "operation_vector_index" @pytest.fixture(scope="module") @@ -21,6 +24,15 @@ def client(): yield client client.close() +@pytest.fixture(scope="module") +def preserved_collection(client): + if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): + clxn = client[DB_NAME].create_collection(COLLECTION_NAME) + else: + clxn = client[DB_NAME][COLLECTION_NAME] + clxn.delete_many({}) + yield clxn + clxn.delete_many({}) @pytest.fixture def collection(client): @@ -266,3 +278,176 @@ def test_custom_field_names(self, collection: Collection, mock_embedding_func): assert "vector" in doc assert doc["content"] == texts[0] assert doc["vector"] == [0.0, 0.0, 0.0] + + +class TestExecuteSearchQuery: + @pytest.fixture(scope="class", autouse=True) + def vector_search_index(self, client): + coll = client[DB_NAME][COLLECTION_NAME] + if len(coll.list_search_indexes(VECTOR_INDEX_NAME).to_list()) == 0: + create_vector_search_index( + collection=coll, + index_name=VECTOR_INDEX_NAME, + dimensions=3, + path="embedding", + similarity="cosine", + filters=["category", "color", "wheels"], + wait_until_complete=120, + ) + yield + drop_vector_search_index(collection=coll, index_name=VECTOR_INDEX_NAME) + + @pytest.fixture(scope="class", autouse=True) + def sample_docs(self, preserved_collection: Collection): + texts = ["apple fruit", "banana fruit", "car vehicle", "bike vehicle"] + metadatas = [ + {"category": "fruit", "color": "red"}, + {"category": "fruit", "color": "yellow"}, + {"category": "vehicle", "wheels": 4}, + {"category": "vehicle", "wheels": 2}, + ] + + def embeddings(texts): + mapping = { + "apple fruit": [1.0, 0.5, 0.0], + "banana fruit": [0.5, 0.5, 0.0], + "car vehicle": [0.0, 0.5, 1.0], + "bike vehicle": [0.0, 1.0, 0.5], + } + return [mapping[text] for text in texts] + + bulk_embed_and_insert_texts( + texts=texts, + metadatas=metadatas, + embedding_func=embeddings, + collection=preserved_collection, + text_key="text", + embedding_key="embedding", + ) + # Add a document that should not be returned in searches + preserved_collection.insert_one({'_id': ObjectId('68c1a038fd976373aa4ec19f'), 'category': 'fruit', 'color': 'red', 'embedding': [1.0, 1.0, 1.0]}) + wait_for_docs_in_index(preserved_collection, VECTOR_INDEX_NAME, n_docs=5) + return preserved_collection + + def test_basic_search_query(self, sample_docs: Collection): + query_vector = [1.0, 0.5, 0.0] + + result = execute_search_query( + query_vector=query_vector, + collection=sample_docs, + embedding_key="embedding", + text_key="text", + index_name=VECTOR_INDEX_NAME, + k=2, + ) + + assert len(result) == 2 + assert result[0]["text"] == "apple fruit" + assert result[1]["text"] == "banana fruit" + assert "score" in result[0] + assert "score" in result[1] + + def test_search_with_pre_filter(self, sample_docs: Collection): + query_vector = [1.0, 0.5, 1.0] + pre_filter = {"category": "fruit"} + + result = execute_search_query( + query_vector=query_vector, + collection=sample_docs, + embedding_key="embedding", + text_key="text", + index_name=VECTOR_INDEX_NAME, + k=4, + pre_filter=pre_filter, + ) + + assert len(result) == 2 + assert result[0]["category"] == "fruit" + assert result[1]["category"] == "fruit" + + def test_search_with_post_filter_pipeline(self, sample_docs: Collection): + query_vector = [1.0, 0.5, 0.0] + post_filter_pipeline = [ + {"$match": {"score": {"$gte": 0.99}}}, + {"$sort": {"score": -1}}, + ] + + result = execute_search_query( + query_vector=query_vector, + collection=sample_docs, + embedding_key="embedding", + text_key="text", + index_name=VECTOR_INDEX_NAME, + k=2, + post_filter_pipeline=post_filter_pipeline, + ) + + assert len(result) == 1 + + def test_search_with_embeddings_included(self, sample_docs: Collection): + query_vector = [1.0, 0.5, 0.0] + + result = execute_search_query( + query_vector=query_vector, + collection=sample_docs, + embedding_key="embedding", + text_key="text", + index_name=VECTOR_INDEX_NAME, + k=1, + include_embeddings=True, + ) + + assert len(result) == 1 + assert "embedding" in result[0] + assert result[0]["embedding"] == [1.0, 0.5, 0.0] + + def test_search_with_custom_field_names(self, sample_docs: Collection): + query_vector = [1.0, 0.5, 0.25] + + mock_cursor = [ + { + "_id": ObjectId(), + "content": "apple fruit", + "vector": [1.0, 0.5, 0.25], + "score": 0.9, + } + ] + + with patch.object(sample_docs, "aggregate") as mock_aggregate: + mock_aggregate.return_value = mock_cursor + + result = execute_search_query( + query_vector=query_vector, + collection=sample_docs, + embedding_key="vector", + text_key="content", + index_name=VECTOR_INDEX_NAME, + k=1, + ) + + assert len(result) == 1 + assert "content" in result[0] + assert result[0]["content"] == "apple fruit" + + pipeline_arg = mock_aggregate.call_args[0][0] + vector_search_stage = pipeline_arg[0]["$vectorSearch"] + assert vector_search_stage["path"] == "vector" + assert {"$project": {"vector": 0}} in pipeline_arg + + def test_search_filters_documents_without_text_key(self, sample_docs: Collection): + query_vector = [1.0, 0.5, 0.0] + + result = execute_search_query( + query_vector=query_vector, + collection=sample_docs, + embedding_key="embedding", + text_key="text", + index_name=VECTOR_INDEX_NAME, + k=3, + ) + + # Should only return documents with text field + assert len(result) == 2 + assert all("text" in doc for doc in result) + assert result[0]["text"] == "apple fruit" + assert result[1]["text"] == "banana fruit" From d17d7f076cc270cca4a048756205670050fc1a69 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 10 Sep 2025 12:38:34 -0400 Subject: [PATCH 2/3] Update wait_for_docs_in_index --- pymongo_vectorsearch_utils/index.py | 5 ++++- tests/test_operation.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pymongo_vectorsearch_utils/index.py b/pymongo_vectorsearch_utils/index.py index bb42f90..03ac9c4 100644 --- a/pymongo_vectorsearch_utils/index.py +++ b/pymongo_vectorsearch_utils/index.py @@ -269,7 +269,10 @@ def wait_for_docs_in_index( embedding_field (str): The name of the document field containing embeddings. n_docs (int): The number of documents to expect in the index. """ - index = collection.list_search_indexes(index_name).to_list()[0] + indexes = collection.list_search_indexes(index_name).to_list() + if len(indexes) == 0: + raise ValueError(f"Index {index_name} does not exist in collection {collection.name}") + index = indexes[0] num_dimensions = index["latestDefinition"]["fields"][0]["numDimensions"] field = index["latestDefinition"]["fields"][0]["path"] diff --git a/tests/test_operation.py b/tests/test_operation.py index 2491145..b61fc97 100644 --- a/tests/test_operation.py +++ b/tests/test_operation.py @@ -298,7 +298,7 @@ def vector_search_index(self, client): drop_vector_search_index(collection=coll, index_name=VECTOR_INDEX_NAME) @pytest.fixture(scope="class", autouse=True) - def sample_docs(self, preserved_collection: Collection): + def sample_docs(self, preserved_collection: Collection, vector_search_index): texts = ["apple fruit", "banana fruit", "car vehicle", "bike vehicle"] metadatas = [ {"category": "fruit", "color": "red"}, From 3028828dfc6969a10e12d1e13cbf026bd195a0d2 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 10 Sep 2025 12:40:51 -0400 Subject: [PATCH 3/3] Linting --- pymongo_vectorsearch_utils/operation.py | 1 - tests/test_operation.py | 10 +++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/pymongo_vectorsearch_utils/operation.py b/pymongo_vectorsearch_utils/operation.py index ff7b4bb..e4b7e83 100644 --- a/pymongo_vectorsearch_utils/operation.py +++ b/pymongo_vectorsearch_utils/operation.py @@ -92,7 +92,6 @@ def execute_search_query( {"$set": {"score": {"$meta": "vectorSearchScore"}}}, ] - # Remove embeddings unless requested. if not include_embeddings: pipeline.append({"$project": {embedding_key: 0}}) diff --git a/tests/test_operation.py b/tests/test_operation.py index b61fc97..295a79a 100644 --- a/tests/test_operation.py +++ b/tests/test_operation.py @@ -24,6 +24,7 @@ def client(): yield client client.close() + @pytest.fixture(scope="module") def preserved_collection(client): if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): @@ -34,6 +35,7 @@ def preserved_collection(client): yield clxn clxn.delete_many({}) + @pytest.fixture def collection(client): if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): @@ -325,7 +327,13 @@ def embeddings(texts): embedding_key="embedding", ) # Add a document that should not be returned in searches - preserved_collection.insert_one({'_id': ObjectId('68c1a038fd976373aa4ec19f'), 'category': 'fruit', 'color': 'red', 'embedding': [1.0, 1.0, 1.0]}) + preserved_collection.insert_one( + { + "category": "fruit", + "color": "red", + "embedding": [1.0, 1.0, 1.0], + } + ) wait_for_docs_in_index(preserved_collection, VECTOR_INDEX_NAME, n_docs=5) return preserved_collection