1
1
"""Tests for operation utilities."""
2
2
3
3
import os
4
- from unittest .mock import Mock
4
+ from unittest .mock import Mock , patch
5
5
6
6
import pytest
7
7
from bson import ObjectId
8
8
from pymongo import MongoClient
9
9
from pymongo .collection import Collection
10
10
11
- from pymongo_vectorsearch_utils .operation import bulk_embed_and_insert_texts
11
+ from pymongo_vectorsearch_utils import drop_vector_search_index
12
+ from pymongo_vectorsearch_utils .index import create_vector_search_index , wait_for_docs_in_index
13
+ from pymongo_vectorsearch_utils .operation import bulk_embed_and_insert_texts , execute_search_query
12
14
13
15
DB_NAME = "vectorsearch_utils_test"
14
16
COLLECTION_NAME = "test_operation"
17
+ VECTOR_INDEX_NAME = "operation_vector_index"
15
18
16
19
17
20
@pytest .fixture (scope = "module" )
@@ -22,6 +25,17 @@ def client():
22
25
client .close ()
23
26
24
27
28
@pytest.fixture(scope="module")
def preserved_collection(client):
    """Yield the test collection, emptied once before and once after the module.

    The collection is created when the database does not list it yet. Because
    the fixture is module-scoped, documents inserted while it is active are not
    removed between individual tests.
    """
    database = client[DB_NAME]
    if COLLECTION_NAME in database.list_collection_names():
        target = database[COLLECTION_NAME]
    else:
        target = database.create_collection(COLLECTION_NAME)

    # Start from an empty collection and leave an empty one behind.
    target.delete_many({})
    yield target
    target.delete_many({})
37
+
38
+
25
39
@pytest .fixture
26
40
def collection (client ):
27
41
if COLLECTION_NAME not in client [DB_NAME ].list_collection_names ():
@@ -266,3 +280,182 @@ def test_custom_field_names(self, collection: Collection, mock_embedding_func):
266
280
assert "vector" in doc
267
281
assert doc ["content" ] == texts [0 ]
268
282
assert doc ["vector" ] == [0.0 , 0.0 , 0.0 ]
283
+
284
+
285
class TestExecuteSearchQuery:
    """Tests for ``execute_search_query`` using an index built on the test collection."""

    @staticmethod
    def _run_query(collection, query_vector, k, embedding_key="embedding", text_key="text", **extra):
        """Call ``execute_search_query`` with the arguments shared by every test."""
        return execute_search_query(
            query_vector=query_vector,
            collection=collection,
            embedding_key=embedding_key,
            text_key=text_key,
            index_name=VECTOR_INDEX_NAME,
            k=k,
            **extra,
        )

    @pytest.fixture(scope="class", autouse=True)
    def vector_search_index(self, client):
        """Create the vector search index if it is missing; drop it after the class."""
        target = client[DB_NAME][COLLECTION_NAME]
        existing = target.list_search_indexes(VECTOR_INDEX_NAME).to_list()
        if not existing:
            create_vector_search_index(
                collection=target,
                index_name=VECTOR_INDEX_NAME,
                dimensions=3,
                path="embedding",
                similarity="cosine",
                filters=["category", "color", "wheels"],
                wait_until_complete=120,
            )
        yield
        drop_vector_search_index(collection=target, index_name=VECTOR_INDEX_NAME)

    @pytest.fixture(scope="class", autouse=True)
    def sample_docs(self, preserved_collection: Collection, vector_search_index):
        """Insert four embedded texts plus one text-less document and return the collection.

        ``vector_search_index`` is requested only so that the index fixture is
        set up before any documents are inserted.
        """
        seed = [
            ("apple fruit", {"category": "fruit", "color": "red"}),
            ("banana fruit", {"category": "fruit", "color": "yellow"}),
            ("car vehicle", {"category": "vehicle", "wheels": 4}),
            ("bike vehicle", {"category": "vehicle", "wheels": 2}),
        ]
        texts = [text for text, _ in seed]
        metadatas = [metadata for _, metadata in seed]

        def embeddings(batch):
            # Fixed three-dimensional vectors keyed by the exact text.
            lookup = {
                "apple fruit": [1.0, 0.5, 0.0],
                "banana fruit": [0.5, 0.5, 0.0],
                "car vehicle": [0.0, 0.5, 1.0],
                "bike vehicle": [0.0, 1.0, 0.5],
            }
            return list(map(lookup.__getitem__, batch))

        bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=embeddings,
            collection=preserved_collection,
            text_key="text",
            embedding_key="embedding",
        )

        # This document has an embedding but no "text" field, so searches
        # keyed on "text" should not return it.
        preserved_collection.insert_one(
            {
                "category": "fruit",
                "color": "red",
                "embedding": [1.0, 1.0, 1.0],
            }
        )

        # Four embedded texts plus the text-less document above.
        wait_for_docs_in_index(preserved_collection, VECTOR_INDEX_NAME, n_docs=5)
        return preserved_collection

    def test_basic_search_query(self, sample_docs: Collection):
        """The two nearest texts come back in order, each with a score."""
        hits = self._run_query(sample_docs, [1.0, 0.5, 0.0], 2)

        assert len(hits) == 2
        assert [hit["text"] for hit in hits] == ["apple fruit", "banana fruit"]
        assert all("score" in hit for hit in hits)

    def test_search_with_pre_filter(self, sample_docs: Collection):
        """A category pre-filter restricts the results to fruit documents."""
        hits = self._run_query(
            sample_docs,
            [1.0, 0.5, 1.0],
            4,
            pre_filter={"category": "fruit"},
        )

        assert len(hits) == 2
        assert [hit["category"] for hit in hits] == ["fruit", "fruit"]

    def test_search_with_post_filter_pipeline(self, sample_docs: Collection):
        """Post-filter stages are applied: only one hit scores at least 0.99."""
        stages = [
            {"$match": {"score": {"$gte": 0.99}}},
            {"$sort": {"score": -1}},
        ]

        hits = self._run_query(
            sample_docs,
            [1.0, 0.5, 0.0],
            2,
            post_filter_pipeline=stages,
        )

        assert len(hits) == 1

    def test_search_with_embeddings_included(self, sample_docs: Collection):
        """``include_embeddings=True`` keeps the stored vector in the result."""
        hits = self._run_query(
            sample_docs,
            [1.0, 0.5, 0.0],
            1,
            include_embeddings=True,
        )

        assert len(hits) == 1
        best = hits[0]
        assert "embedding" in best
        assert best["embedding"] == [1.0, 0.5, 0.0]

    def test_search_with_custom_field_names(self, sample_docs: Collection):
        """Custom embedding/text keys are reflected in the aggregation pipeline."""
        fake_rows = [
            {
                "_id": ObjectId(),
                "content": "apple fruit",
                "vector": [1.0, 0.5, 0.25],
                "score": 0.9,
            }
        ]

        # Replace ``aggregate`` so the pipeline built for the custom keys can
        # be inspected without an index on those fields.
        with patch.object(sample_docs, "aggregate", return_value=fake_rows) as aggregate_spy:
            hits = self._run_query(
                sample_docs,
                [1.0, 0.5, 0.25],
                1,
                embedding_key="vector",
                text_key="content",
            )

            assert len(hits) == 1
            first = hits[0]
            assert "content" in first
            assert first["content"] == "apple fruit"

            pipeline = aggregate_spy.call_args.args[0]
            assert pipeline[0]["$vectorSearch"]["path"] == "vector"
            assert {"$project": {"vector": 0}} in pipeline

    def test_search_filters_documents_without_text_key(self, sample_docs: Collection):
        """Documents lacking the text key are dropped even when ``k`` allows more."""
        hits = self._run_query(sample_docs, [1.0, 0.5, 0.0], 3)

        # Should only return documents with text field
        assert len(hits) == 2
        assert all("text" in hit for hit in hits)
        assert [hit["text"] for hit in hits] == ["apple fruit", "banana fruit"]
0 commit comments