Skip to content

Commit 59c21e5

Browse files
authored
Merge pull request #4 from NoahStapp/add-query
Add query helper
2 parents 62a5c5f + 3028828 commit 59c21e5

File tree

3 files changed

+252
-5
lines changed

3 files changed

+252
-5
lines changed

pymongo_vectorsearch_utils/index.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,6 @@ def create_fulltext_search_index(
259259
def wait_for_docs_in_index(
260260
collection: Collection[Any],
261261
index_name: str,
262-
embedding_field: str,
263262
n_docs: int,
264263
) -> bool:
265264
"""Wait until the given number of documents are indexed by the given index.
@@ -270,12 +269,19 @@ def wait_for_docs_in_index(
270269
embedding_field (str): The name of the document field containing embeddings.
271270
n_docs (int): The number of documents to expect in the index.
272271
"""
273-
query_vector = [0.0] * 1024 # Dummy vector
272+
indexes = collection.list_search_indexes(index_name).to_list()
273+
if len(indexes) == 0:
274+
raise ValueError(f"Index {index_name} does not exist in collection {collection.name}")
275+
index = indexes[0]
276+
num_dimensions = index["latestDefinition"]["fields"][0]["numDimensions"]
277+
field = index["latestDefinition"]["fields"][0]["path"]
278+
279+
query_vector = [0.001] * num_dimensions # Dummy vector
274280
query = [
275281
{
276282
"$vectorSearch": {
277283
"index": index_name,
278-
"path": embedding_field,
284+
"path": field,
279285
"queryVector": query_vector,
280286
"numCandidates": n_docs,
281287
"limit": n_docs,

pymongo_vectorsearch_utils/operation.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pymongo import ReplaceOne
88
from pymongo.synchronous.collection import Collection
99

10+
from pymongo_vectorsearch_utils.pipeline import vector_search_stage
1011
from pymongo_vectorsearch_utils.util import oid_to_str, str_to_oid
1112

1213

@@ -60,3 +61,50 @@ def bulk_embed_and_insert_texts(
6061
result = collection.bulk_write(operations)
6162
assert result.upserted_ids is not None
6263
return [oid_to_str(_id) for _id in result.upserted_ids.values()]
64+
65+
66+
def execute_search_query(
    query_vector: list[float],
    collection: Collection[Any],
    embedding_key: str,
    text_key: str,
    index_name: str,
    k: int = 4,
    pre_filter: dict[str, Any] | None = None,
    post_filter_pipeline: list[dict[str, Any]] | None = None,
    oversampling_factor: int = 10,
    include_embeddings: bool = False,
    **kwargs: Any,
) -> list[dict[str, Any]]:
    """Execute a MongoDB vector search query and return the matching documents.

    The aggregation pipeline is assembled in this order:

    1. the stage produced by ``vector_search_stage``;
    2. a ``$set`` stage that stores the ``vectorSearchScore`` metadata in a
       ``score`` field on every document;
    3. unless ``include_embeddings`` is true, a ``$project`` stage that removes
       ``embedding_key`` from the results;
    4. any stages supplied in ``post_filter_pipeline``.

    Args:
        query_vector: The embedding to search with.
        collection: The collection the aggregation is run against.
        embedding_key: Name of the document field holding the embeddings.
        text_key: Name of the document field holding the text. Documents
            returned by the pipeline without this field are discarded.
        index_name: Name of the vector search index to query.
        k: Forwarded to ``vector_search_stage``.
        pre_filter: Forwarded to ``vector_search_stage``.
        post_filter_pipeline: Extra aggregation stages appended after the
            search, scoring and projection stages.
        oversampling_factor: Forwarded to ``vector_search_stage``.
        include_embeddings: Whether to keep ``embedding_key`` in the results.
        **kwargs: Additional keyword arguments forwarded to
            ``vector_search_stage``.

    Returns:
        The documents produced by the pipeline that contain ``text_key``, in
        the order the cursor yields them. Because documents lacking
        ``text_key`` are dropped client-side, the list may be shorter than
        the number of documents the pipeline returned.
    """
    # Vector search (potentially with a pre-filter), then expose the score.
    pipeline: list[dict[str, Any]] = [
        vector_search_stage(
            query_vector,
            embedding_key,
            index_name,
            k,
            pre_filter,
            oversampling_factor,
            **kwargs,
        ),
        {"$set": {"score": {"$meta": "vectorSearchScore"}}},
    ]

    # Remove embeddings unless requested.
    if not include_embeddings:
        pipeline.append({"$project": {embedding_key: 0}})

    # Post-processing stages run after scoring so they can reference "score".
    if post_filter_pipeline is not None:
        pipeline.extend(post_filter_pipeline)

    # Execution: keep only documents that actually carry the text field.
    return [doc for doc in collection.aggregate(pipeline) if text_key in doc]

tests/test_operation.py

Lines changed: 195 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
"""Tests for operation utilities."""
22

33
import os
4-
from unittest.mock import Mock
4+
from unittest.mock import Mock, patch
55

66
import pytest
77
from bson import ObjectId
88
from pymongo import MongoClient
99
from pymongo.collection import Collection
1010

11-
from pymongo_vectorsearch_utils.operation import bulk_embed_and_insert_texts
11+
from pymongo_vectorsearch_utils import drop_vector_search_index
12+
from pymongo_vectorsearch_utils.index import create_vector_search_index, wait_for_docs_in_index
13+
from pymongo_vectorsearch_utils.operation import bulk_embed_and_insert_texts, execute_search_query
1214

1315
DB_NAME = "vectorsearch_utils_test"
1416
COLLECTION_NAME = "test_operation"
17+
VECTOR_INDEX_NAME = "operation_vector_index"
1518

1619

1720
@pytest.fixture(scope="module")
@@ -22,6 +25,17 @@ def client():
2225
client.close()
2326

2427

28+
@pytest.fixture(scope="module")
def preserved_collection(client):
    """Yield the test collection, emptied before and after the module's tests.

    The collection is created if it does not exist yet; it is never dropped,
    only cleared of documents.
    """
    database = client[DB_NAME]
    if COLLECTION_NAME in database.list_collection_names():
        coll = database[COLLECTION_NAME]
    else:
        coll = database.create_collection(COLLECTION_NAME)

    # Start from an empty collection.
    coll.delete_many({})
    yield coll
    # Remove the documents inserted while the fixture was in use.
    coll.delete_many({})
37+
38+
2539
@pytest.fixture
2640
def collection(client):
2741
if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
@@ -266,3 +280,182 @@ def test_custom_field_names(self, collection: Collection, mock_embedding_func):
266280
assert "vector" in doc
267281
assert doc["content"] == texts[0]
268282
assert doc["vector"] == [0.0, 0.0, 0.0]
283+
284+
285+
class TestExecuteSearchQuery:
    """Tests for ``execute_search_query`` using the index built by the class fixtures."""

    @staticmethod
    def _search(
        collection,
        query_vector,
        k,
        *,
        embedding_key="embedding",
        text_key="text",
        **extra,
    ):
        """Call ``execute_search_query`` against ``VECTOR_INDEX_NAME``."""
        return execute_search_query(
            query_vector=query_vector,
            collection=collection,
            embedding_key=embedding_key,
            text_key=text_key,
            index_name=VECTOR_INDEX_NAME,
            k=k,
            **extra,
        )

    @pytest.fixture(scope="class", autouse=True)
    def vector_search_index(self, client):
        """Create the vector search index if it is missing; drop it after the class."""
        target = client[DB_NAME][COLLECTION_NAME]
        existing = target.list_search_indexes(VECTOR_INDEX_NAME).to_list()
        if not existing:
            create_vector_search_index(
                collection=target,
                index_name=VECTOR_INDEX_NAME,
                dimensions=3,
                path="embedding",
                similarity="cosine",
                filters=["category", "color", "wheels"],
                wait_until_complete=120,
            )
        yield
        drop_vector_search_index(collection=target, index_name=VECTOR_INDEX_NAME)

    @pytest.fixture(scope="class", autouse=True)
    def sample_docs(self, preserved_collection: Collection, vector_search_index):
        """Insert four text documents plus one without text, then wait for indexing."""
        # (text, embedding, metadata) for each document inserted with a text field.
        corpus = [
            ("apple fruit", [1.0, 0.5, 0.0], {"category": "fruit", "color": "red"}),
            ("banana fruit", [0.5, 0.5, 0.0], {"category": "fruit", "color": "yellow"}),
            ("car vehicle", [0.0, 0.5, 1.0], {"category": "vehicle", "wheels": 4}),
            ("bike vehicle", [0.0, 1.0, 0.5], {"category": "vehicle", "wheels": 2}),
        ]
        vector_by_text = {text: vector for text, vector, _ in corpus}

        def embeddings(batch):
            return [vector_by_text[item] for item in batch]

        bulk_embed_and_insert_texts(
            texts=[text for text, _, _ in corpus],
            metadatas=[metadata for _, _, metadata in corpus],
            embedding_func=embeddings,
            collection=preserved_collection,
            text_key="text",
            embedding_key="embedding",
        )

        # Add a document that should not be returned in searches
        textless_doc = {
            "category": "fruit",
            "color": "red",
            "embedding": [1.0, 1.0, 1.0],
        }
        preserved_collection.insert_one(textless_doc)

        wait_for_docs_in_index(preserved_collection, VECTOR_INDEX_NAME, n_docs=5)
        return preserved_collection

    def test_basic_search_query(self, sample_docs: Collection):
        """A k=2 search returns apple then banana, each carrying a score."""
        hits = self._search(sample_docs, [1.0, 0.5, 0.0], 2)

        assert len(hits) == 2
        first, second = hits
        assert first["text"] == "apple fruit"
        assert second["text"] == "banana fruit"
        assert "score" in first
        assert "score" in second

    def test_search_with_pre_filter(self, sample_docs: Collection):
        """A category pre-filter restricts the results to fruit documents."""
        hits = self._search(
            sample_docs,
            [1.0, 0.5, 1.0],
            4,
            pre_filter={"category": "fruit"},
        )

        assert len(hits) == 2
        assert hits[0]["category"] == "fruit"
        assert hits[1]["category"] == "fruit"

    def test_search_with_post_filter_pipeline(self, sample_docs: Collection):
        """Post-filter stages can match and sort on the computed score."""
        stages = [
            {"$match": {"score": {"$gte": 0.99}}},
            {"$sort": {"score": -1}},
        ]

        hits = self._search(
            sample_docs,
            [1.0, 0.5, 0.0],
            2,
            post_filter_pipeline=stages,
        )

        assert len(hits) == 1

    def test_search_with_embeddings_included(self, sample_docs: Collection):
        """With include_embeddings=True the embedding field is kept in the result."""
        hits = self._search(
            sample_docs,
            [1.0, 0.5, 0.0],
            1,
            include_embeddings=True,
        )

        assert len(hits) == 1
        top = hits[0]
        assert "embedding" in top
        assert top["embedding"] == [1.0, 0.5, 0.0]

    def test_search_with_custom_field_names(self, sample_docs: Collection):
        """Custom embedding/text keys reach the pipeline and the text filter."""
        fake_results = [
            {
                "_id": ObjectId(),
                "content": "apple fruit",
                "vector": [1.0, 0.5, 0.25],
                "score": 0.9,
            }
        ]

        # aggregate is patched, so only the pipeline construction is exercised.
        with patch.object(sample_docs, "aggregate") as mock_aggregate:
            mock_aggregate.return_value = fake_results

            hits = self._search(
                sample_docs,
                [1.0, 0.5, 0.25],
                1,
                embedding_key="vector",
                text_key="content",
            )

            assert len(hits) == 1
            assert "content" in hits[0]
            assert hits[0]["content"] == "apple fruit"

            sent_pipeline = mock_aggregate.call_args[0][0]
            search_stage = sent_pipeline[0]["$vectorSearch"]
            assert search_stage["path"] == "vector"
            assert {"$project": {"vector": 0}} in sent_pipeline

    def test_search_filters_documents_without_text_key(self, sample_docs: Collection):
        """Documents lacking the text field are dropped from the results."""
        hits = self._search(sample_docs, [1.0, 0.5, 0.0], 3)

        # Should only return documents with text field
        assert len(hits) == 2
        assert all("text" in doc for doc in hits)
        assert [doc["text"] for doc in hits] == ["apple fruit", "banana fruit"]

0 commit comments

Comments
 (0)