12 changes: 9 additions & 3 deletions pymongo_vectorsearch_utils/index.py
@@ -259,7 +259,6 @@ def create_fulltext_search_index(
def wait_for_docs_in_index(
collection: Collection[Any],
index_name: str,
embedding_field: str,
n_docs: int,
) -> bool:
"""Wait until the given number of documents are indexed by the given index.
@@ -270,12 +269,19 @@ def wait_for_docs_in_index(
embedding_field (str): The name of the document field containing embeddings.
n_docs (int): The number of documents to expect in the index.
"""
query_vector = [0.0] * 1024 # Dummy vector
indexes = collection.list_search_indexes(index_name).to_list()
if len(indexes) == 0:
raise ValueError(f"Index {index_name} does not exist in collection {collection.name}")
index = indexes[0]
num_dimensions = index["latestDefinition"]["fields"][0]["numDimensions"]
field = index["latestDefinition"]["fields"][0]["path"]

query_vector = [0.001] * num_dimensions # Dummy vector
query = [
{
"$vectorSearch": {
"index": index_name,
"path": embedding_field,
"path": field,
"queryVector": query_vector,
"numCandidates": n_docs,
"limit": n_docs,
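For reference, the updated helper now sizes its dummy query vector from the index's latestDefinition instead of assuming 1024 dimensions, and probes the path declared in the index rather than a caller-supplied field. A minimal usage sketch, where the connection string, database, collection, and index name are illustrative:

from pymongo import MongoClient
from pymongo_vectorsearch_utils.index import wait_for_docs_in_index

client = MongoClient("mongodb://localhost:27017")  # placeholder URI
coll = client["my_db"]["my_collection"]

# Wait until the named Atlas Vector Search index reports 100 indexed documents.
wait_for_docs_in_index(coll, "my_vector_index", n_docs=100)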
48 changes: 48 additions & 0 deletions pymongo_vectorsearch_utils/operation.py
@@ -7,6 +7,7 @@
from pymongo import ReplaceOne
from pymongo.synchronous.collection import Collection

from pymongo_vectorsearch_utils.pipeline import vector_search_stage
from pymongo_vectorsearch_utils.util import oid_to_str, str_to_oid


@@ -60,3 +61,50 @@ def bulk_embed_and_insert_texts(
result = collection.bulk_write(operations)
assert result.upserted_ids is not None
return [oid_to_str(_id) for _id in result.upserted_ids.values()]


def execute_search_query(
query_vector: list[float],
collection: Collection[Any],
embedding_key: str,
text_key: str,
index_name: str,
k: int = 4,
pre_filter: dict[str, Any] | None = None,
post_filter_pipeline: list[dict[str, Any]] | None = None,
oversampling_factor: int = 10,
include_embeddings: bool = False,
**kwargs: Any,
) -> list[dict[str, Any]]:
"""Execute a MongoDB vector search query."""

# Atlas Vector Search, potentially with filter
pipeline = [
vector_search_stage(
query_vector,
embedding_key,
index_name,
k,
pre_filter,
oversampling_factor,
**kwargs,
),
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
]

# Remove embeddings unless requested.
if not include_embeddings:
pipeline.append({"$project": {embedding_key: 0}})
# Post-processing
if post_filter_pipeline is not None:
pipeline.extend(post_filter_pipeline)

# Execution
cursor = collection.aggregate(pipeline)
docs = []

for doc in cursor:
if text_key not in doc:
continue
docs.append(doc)
return docs
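The new execute_search_query helper assembles a $vectorSearch stage (via vector_search_stage), surfaces the similarity score under a "score" field, optionally strips embeddings, appends any post-filter stages, and drops documents missing the text field. A hedged usage sketch, assuming a collection coll with an Atlas index named "my_vector_index" over an "embedding" field and text stored under "text":

results = execute_search_query(
    query_vector=[1.0, 0.5, 0.0],
    collection=coll,
    embedding_key="embedding",
    text_key="text",
    index_name="my_vector_index",
    k=4,
    pre_filter={"category": "fruit"},                  # optional metadata filter
    post_filter_pipeline=[{"$sort": {"score": -1}}],   # optional extra stages
)
for doc in results:
    print(doc["text"], doc["score"])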
197 changes: 195 additions & 2 deletions tests/test_operation.py
@@ -1,17 +1,20 @@
"""Tests for operation utilities."""

import os
from unittest.mock import Mock
from unittest.mock import Mock, patch

import pytest
from bson import ObjectId
from pymongo import MongoClient
from pymongo.collection import Collection

from pymongo_vectorsearch_utils.operation import bulk_embed_and_insert_texts
from pymongo_vectorsearch_utils import drop_vector_search_index
from pymongo_vectorsearch_utils.index import create_vector_search_index, wait_for_docs_in_index
from pymongo_vectorsearch_utils.operation import bulk_embed_and_insert_texts, execute_search_query

DB_NAME = "vectorsearch_utils_test"
COLLECTION_NAME = "test_operation"
VECTOR_INDEX_NAME = "operation_vector_index"


@pytest.fixture(scope="module")
@@ -22,6 +25,17 @@ def client():
client.close()


@pytest.fixture(scope="module")
def preserved_collection(client):
if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
else:
clxn = client[DB_NAME][COLLECTION_NAME]
clxn.delete_many({})
yield clxn
clxn.delete_many({})


@pytest.fixture
def collection(client):
if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
@@ -266,3 +280,182 @@ def test_custom_field_names(self, collection: Collection, mock_embedding_func):
assert "vector" in doc
assert doc["content"] == texts[0]
assert doc["vector"] == [0.0, 0.0, 0.0]


class TestExecuteSearchQuery:
@pytest.fixture(scope="class", autouse=True)
def vector_search_index(self, client):
coll = client[DB_NAME][COLLECTION_NAME]
if len(coll.list_search_indexes(VECTOR_INDEX_NAME).to_list()) == 0:
create_vector_search_index(
collection=coll,
index_name=VECTOR_INDEX_NAME,
dimensions=3,
path="embedding",
similarity="cosine",
filters=["category", "color", "wheels"],
wait_until_complete=120,
)
yield
drop_vector_search_index(collection=coll, index_name=VECTOR_INDEX_NAME)

@pytest.fixture(scope="class", autouse=True)
def sample_docs(self, preserved_collection: Collection, vector_search_index):
texts = ["apple fruit", "banana fruit", "car vehicle", "bike vehicle"]
metadatas = [
{"category": "fruit", "color": "red"},
{"category": "fruit", "color": "yellow"},
{"category": "vehicle", "wheels": 4},
{"category": "vehicle", "wheels": 2},
]

def embeddings(texts):
mapping = {
"apple fruit": [1.0, 0.5, 0.0],
"banana fruit": [0.5, 0.5, 0.0],
"car vehicle": [0.0, 0.5, 1.0],
"bike vehicle": [0.0, 1.0, 0.5],
}
return [mapping[text] for text in texts]

bulk_embed_and_insert_texts(
texts=texts,
metadatas=metadatas,
embedding_func=embeddings,
collection=preserved_collection,
text_key="text",
embedding_key="embedding",
)
# Add a document that should not be returned in searches
preserved_collection.insert_one(
{
"category": "fruit",
"color": "red",
"embedding": [1.0, 1.0, 1.0],
}
)
wait_for_docs_in_index(preserved_collection, VECTOR_INDEX_NAME, n_docs=5)
return preserved_collection

def test_basic_search_query(self, sample_docs: Collection):
query_vector = [1.0, 0.5, 0.0]

result = execute_search_query(
query_vector=query_vector,
collection=sample_docs,
embedding_key="embedding",
text_key="text",
index_name=VECTOR_INDEX_NAME,
k=2,
)

assert len(result) == 2
assert result[0]["text"] == "apple fruit"
assert result[1]["text"] == "banana fruit"
assert "score" in result[0]
assert "score" in result[1]

def test_search_with_pre_filter(self, sample_docs: Collection):
query_vector = [1.0, 0.5, 1.0]
pre_filter = {"category": "fruit"}

result = execute_search_query(
query_vector=query_vector,
collection=sample_docs,
embedding_key="embedding",
text_key="text",
index_name=VECTOR_INDEX_NAME,
k=4,
pre_filter=pre_filter,
)

assert len(result) == 2
assert result[0]["category"] == "fruit"
assert result[1]["category"] == "fruit"

def test_search_with_post_filter_pipeline(self, sample_docs: Collection):
query_vector = [1.0, 0.5, 0.0]
post_filter_pipeline = [
{"$match": {"score": {"$gte": 0.99}}},
{"$sort": {"score": -1}},
]

result = execute_search_query(
query_vector=query_vector,
collection=sample_docs,
embedding_key="embedding",
text_key="text",
index_name=VECTOR_INDEX_NAME,
k=2,
post_filter_pipeline=post_filter_pipeline,
)

assert len(result) == 1

def test_search_with_embeddings_included(self, sample_docs: Collection):
query_vector = [1.0, 0.5, 0.0]

result = execute_search_query(
query_vector=query_vector,
collection=sample_docs,
embedding_key="embedding",
text_key="text",
index_name=VECTOR_INDEX_NAME,
k=1,
include_embeddings=True,
)

assert len(result) == 1
assert "embedding" in result[0]
assert result[0]["embedding"] == [1.0, 0.5, 0.0]

def test_search_with_custom_field_names(self, sample_docs: Collection):
query_vector = [1.0, 0.5, 0.25]

mock_cursor = [
{
"_id": ObjectId(),
"content": "apple fruit",
"vector": [1.0, 0.5, 0.25],
"score": 0.9,
}
]

with patch.object(sample_docs, "aggregate") as mock_aggregate:
mock_aggregate.return_value = mock_cursor

result = execute_search_query(
query_vector=query_vector,
collection=sample_docs,
embedding_key="vector",
text_key="content",
index_name=VECTOR_INDEX_NAME,
k=1,
)

assert len(result) == 1
assert "content" in result[0]
assert result[0]["content"] == "apple fruit"

pipeline_arg = mock_aggregate.call_args[0][0]
vector_search_stage = pipeline_arg[0]["$vectorSearch"]
assert vector_search_stage["path"] == "vector"
assert {"$project": {"vector": 0}} in pipeline_arg

def test_search_filters_documents_without_text_key(self, sample_docs: Collection):
query_vector = [1.0, 0.5, 0.0]

result = execute_search_query(
query_vector=query_vector,
collection=sample_docs,
embedding_key="embedding",
text_key="text",
index_name=VECTOR_INDEX_NAME,
k=3,
)

# Should only return documents with text field
assert len(result) == 2
assert all("text" in doc for doc in result)
assert result[0]["text"] == "apple fruit"
assert result[1]["text"] == "banana fruit"
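As a sanity check on the 0.99 cutoff used in test_search_with_post_filter_pipeline, the fixture embeddings can be scored by hand. This sketch assumes Atlas's documented normalization for cosine indexes, score = (1 + cosine) / 2:

import math

def atlas_cosine_score(q, d):
    # Cosine similarity, normalized to [0, 1] the way Atlas reports it for cosine indexes (assumption).
    dot = sum(a * b for a, b in zip(q, d))
    cos = dot / (math.sqrt(sum(a * a for a in q)) * math.sqrt(sum(b * b for b in d)))
    return (1 + cos) / 2

query = [1.0, 0.5, 0.0]
print(atlas_cosine_score(query, [1.0, 0.5, 0.0]))  # "apple fruit"  -> 1.0
print(atlas_cosine_score(query, [0.5, 0.5, 0.0]))  # "banana fruit" -> ~0.97, below the 0.99 cutoff

Only "apple fruit" clears the $match threshold, which is why that test expects a single result.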