Skip to content

Commit 62a5c5f

Browse files
authored
Merge pull request #3 from NoahStapp/add-bulk
Add bulk_embed_and_insert_texts
2 parents 0296924 + ae226b3 commit 62a5c5f

File tree

5 files changed

+377
-0
lines changed

5 files changed

+377
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.idea/
22
.DS_STORE
33
__pycache__/
4+
.env

pymongo_vectorsearch_utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
drop_vector_search_index,
66
update_vector_search_index,
77
)
8+
from .operation import bulk_embed_and_insert_texts
89
from .pipeline import (
910
combine_pipelines,
1011
final_hybrid_stage,
@@ -24,4 +25,5 @@
2425
"combine_pipelines",
2526
"reciprocal_rank_stage",
2627
"final_hybrid_stage",
28+
"bulk_embed_and_insert_texts",
2729
]
pymongo_vectorsearch_utils/operation.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""CRUD utilities and helpers."""
2+
3+
from collections.abc import Callable, Generator, Iterable
4+
from typing import Any
5+
6+
from bson import ObjectId
7+
from pymongo import ReplaceOne
8+
from pymongo.synchronous.collection import Collection
9+
10+
from pymongo_vectorsearch_utils.util import oid_to_str, str_to_oid
11+
12+
13+
def bulk_embed_and_insert_texts(
    texts: list[str] | Iterable[str],
    metadatas: list[dict] | Generator[dict, Any, Any],
    embedding_func: Callable[[list[str]], list[list[float]]],
    collection: Collection[Any],
    text_key: str,
    embedding_key: str,
    ids: list[str] | None = None,
    **kwargs: Any,
) -> list[str]:
    """Bulk upsert a single batch of texts, their embeddings, and optionally ids.

    Each text is embedded, combined with its metadata into one document, and
    written with an upserting ``ReplaceOne`` keyed on ``_id``.

    Important notes on ids:
        - If _id or id is a key in the metadatas dicts, one must
          pop them and provide as separate list. Metadata keys are merged
          last, so they override ``_id``, ``text_key`` and ``embedding_key``.
        - They must be unique.
        - If they are not provided, new ObjectIds are generated.
        - Ids are passed through ``str_to_oid`` before being used as ``_id``
          and are converted back to strings in the return value.

    Args:
        texts: Iterable of strings to embed and insert. It is consumed
            exactly once, so one-shot iterators are supported.
        metadatas: Metadata dicts associated with the texts, in the same order.
        embedding_func: A function that generates embedding vectors from the texts.
        collection: The MongoDB collection where documents will be inserted.
        text_key: The field name where the text will be stored in each document.
        embedding_key: The field name where the embedding will be stored in each document.
        ids: Optional list of unique ids to use as document ``_id`` values.
            See note on ids.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        The ids, as strings, that the bulk write reports as upserted.
        Documents that replaced an existing document are not reported as
        upserted. An empty list is returned when there is nothing to write.
    """
    # Materialize once: a generator is always truthy and would already be
    # exhausted by the time it is iterated a second time.
    text_list = list(texts)
    if not text_list:
        return []
    # Compute embedding vectors
    embeddings = embedding_func(text_list)
    if not ids:
        ids = [str(ObjectId()) for _ in text_list]
    # zip stops at the shortest input, so surplus texts/ids/metadatas are dropped.
    docs = [
        {
            "_id": str_to_oid(i),
            text_key: t,
            embedding_key: embedding,
            **m,
        }
        for i, t, m, embedding in zip(ids, text_list, metadatas, embeddings, strict=False)
    ]
    if not docs:
        # Nothing survived the zip; skip the round trip rather than hand
        # bulk_write an empty operation list.
        return []
    operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs]
    # insert the documents in MongoDB Atlas
    result = collection.bulk_write(operations)
    upserted = result.upserted_ids
    if upserted is None:
        # Explicit check instead of ``assert``, which is stripped under -O.
        return []
    return [oid_to_str(_id) for _id in upserted.values()]

pymongo_vectorsearch_utils/util.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import logging
2+
from typing import Any
3+
4+
# Name the logger after the module rather than its file path so that it takes
# part in the package's dotted logger hierarchy.
logger = logging.getLogger(__name__)
5+
6+
7+
def str_to_oid(str_repr: str) -> Any | str:
    """Try to turn the string form of an id into a BSON ObjectId.

    Construction of an ``ObjectId`` from the input is attempted. If that is
    rejected with ``InvalidId``, a debug message is logged and the input is
    handed back untouched, so it can still be used as a plain string ``_id``.

    Args:
        str_repr: id as string.

    Returns:
        The ObjectId built from ``str_repr``, or ``str_repr`` itself when it
        is not a valid ObjectId representation.
    """
    from bson import ObjectId
    from bson.errors import InvalidId

    try:
        converted = ObjectId(str_repr)
    except InvalidId:
        logger.debug(
            "ObjectIds must be 12-character byte or 24-character hex strings. "
            "Examples: b'heres12bytes', '6f6e6568656c6c6f68656768'"
        )
        return str_repr
    else:
        return converted
31+
32+
33+
def oid_to_str(oid: Any) -> str:
    """Render an id coming out of MongoDB as a plain string.

    Thin wrapper around ``str`` that marks the places where ids leave
    MongoDB and are handed back to callers.

    Args:
        oid: The id to convert, typically a bson.ObjectId.

    Returns:
        ``str(oid)``; for an ObjectId this is its 24 character hex string.
    """
    as_text = str(oid)
    return as_text

tests/test_operation.py

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
"""Tests for operation utilities."""
2+
3+
import os
4+
from unittest.mock import Mock
5+
6+
import pytest
7+
from bson import ObjectId
8+
from pymongo import MongoClient
9+
from pymongo.collection import Collection
10+
11+
from pymongo_vectorsearch_utils.operation import bulk_embed_and_insert_texts
12+
13+
# Database and collection used by the fixtures and tests in this module.
DB_NAME = "vectorsearch_utils_test"
COLLECTION_NAME = "test_operation"
15+
16+
17+
@pytest.fixture(scope="module")
def client():
    """Yield one MongoClient per test module and close it afterwards.

    The connection string is read from the ``MONGODB_URI`` environment
    variable, with a local direct connection as the fallback.
    """
    default_uri = "mongodb://127.0.0.1:27017?directConnection=true"
    mongo = MongoClient(os.environ.get("MONGODB_URI", default_uri))
    yield mongo
    mongo.close()
23+
24+
25+
@pytest.fixture
def collection(client):
    """Yield the test collection, emptied before and after each test.

    The collection is created when it is not yet listed in the test database.
    """
    database = client[DB_NAME]
    if COLLECTION_NAME in database.list_collection_names():
        clxn = database[COLLECTION_NAME]
    else:
        clxn = database.create_collection(COLLECTION_NAME)
    # Start every test from an empty collection ...
    clxn.delete_many({})
    yield clxn
    # ... and leave nothing behind for the next one.
    clxn.delete_many({})
34+
35+
36+
@pytest.fixture
def mock_embedding_func():
    """Mock embedding function that returns predictable embeddings."""

    def embedding_func(texts):
        # The vector for position i is [i, i / 2, i / 4] as floats; the text
        # content itself is ignored.
        vectors = []
        for position in range(len(texts)):
            base = float(position)
            vectors.append([base, base * 0.5, base * 0.25])
        return vectors

    return embedding_func
44+
45+
46+
class TestBulkEmbedAndInsertTexts:
    """Integration tests for ``bulk_embed_and_insert_texts`` against a live collection."""

    def test_empty_texts_returns_empty_list(self, collection: Collection, mock_embedding_func):
        """No texts means no ids are returned."""
        result = bulk_embed_and_insert_texts(
            texts=[],
            metadatas=[],
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
        )
        assert result == []

    def test_basic_insertion_with_generated_ids(self, collection: Collection, mock_embedding_func):
        """Without explicit ids, documents get ObjectId ``_id`` values."""
        texts = ["text one", "text two"]
        metadatas = [{"category": "test_1"}, {"category": "test_2"}]

        result = bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="content",
            embedding_key="vector",
        )

        assert len(result) == 2
        assert all(isinstance(id_str, str) for id_str in result)

        # Sort explicitly: the per-index assertions below must not depend on
        # the order in which an unsorted find() happens to return documents.
        docs = list(collection.find({}).sort("category", 1))
        assert len(docs) == 2

        for i, doc in enumerate(docs):
            assert doc["content"] == texts[i]
            assert doc["vector"] == [float(i), float(i) * 0.5, float(i) * 0.25]
            assert doc["category"] == metadatas[i]["category"]
            assert isinstance(doc["_id"], ObjectId)

    def test_insertion_with_custom_ids(self, collection: Collection, mock_embedding_func):
        """An id that is not a valid ObjectId string is stored as a plain string."""
        texts = ["text one"]
        metadatas = [{"type": "custom"}]
        custom_ids = ["custom_id_123"]

        result = bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
            ids=custom_ids,
        )

        assert result == custom_ids

        doc = collection.find_one({"_id": "custom_id_123"})
        assert doc is not None
        assert doc["text"] == texts[0]
        assert doc["type"] == "custom"

    def test_insertion_with_objectid_string_ids(self, collection: Collection, mock_embedding_func):
        """An id given as an ObjectId hex string is stored as an ObjectId."""
        texts = ["text one"]
        metadatas = [{"test": True}]
        object_id_str = str(ObjectId())

        result = bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
            ids=[object_id_str],
        )

        assert result == [object_id_str]

        # Verify document was inserted with ObjectId
        doc = collection.find_one({})
        assert doc is not None
        assert isinstance(doc["_id"], ObjectId)
        assert str(doc["_id"]) == object_id_str

    def test_upsert_behavior(self, collection: Collection, mock_embedding_func):
        """Writing the same id twice replaces the document instead of duplicating it."""
        texts = ["text one"]
        metadatas = [{"version": 1}]
        custom_id = "upsert_id"

        # First insertion
        bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
            ids=[custom_id],
        )

        # Second write with the same id and new content
        new_metadatas = [{"version": 2}]
        bulk_embed_and_insert_texts(
            texts=["updated text"],
            metadatas=new_metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
            ids=[custom_id],
        )

        docs = list(collection.find({}))
        assert len(docs) == 1
        assert docs[0]["text"] == "updated text"
        assert docs[0]["version"] == 2

    def test_with_generator_metadata(self, collection: Collection, mock_embedding_func):
        """Metadata may be supplied by a generator."""

        def metadata_generator():
            yield {"index": 0}
            yield {"index": 1}

        result = bulk_embed_and_insert_texts(
            texts=["text one", "text two"],
            metadatas=metadata_generator(),
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
        )

        assert len(result) == 2
        docs = list(collection.find({}).sort("index", 1))
        assert len(docs) == 2
        assert docs[0]["text"] == "text one"
        assert docs[1]["text"] == "text two"

    def test_embedding_function_called_correctly(self, collection: Collection):
        """The embedding function is called exactly once, with all texts."""
        texts = ["text one", "text two", "text three"]
        metadatas = [{}, {}, {}]

        mock_embedding_func = Mock(return_value=[[1.0], [2.0], [3.0]])

        bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
        )

        mock_embedding_func.assert_called_once_with(texts)

    def test_large_batch_processing(self, collection: Collection, mock_embedding_func):
        """A batch of 100 documents is written in full."""
        num_docs = 100
        texts = [f"text {i}" for i in range(num_docs)]
        metadatas = [{"doc_num": i} for i in range(num_docs)]

        result = bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
        )

        assert len(result) == num_docs
        assert collection.count_documents({}) == num_docs

    def test_with_additional_kwargs(self, collection: Collection, mock_embedding_func):
        """Unknown keyword arguments are accepted."""
        texts = ["text one"]
        metadatas = [{}]

        result = bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
            extra_param="ignored",
        )

        assert len(result) == 1

    def test_mismatched_lengths_handled_gracefully(
        self, collection: Collection, mock_embedding_func
    ):
        """Only as many documents as the shortest input are written."""
        texts = ["text one", "text two"]
        metadatas = [{"meta": 1}]  # Shorter than texts

        result = bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="text",
            embedding_key="embedding",
        )

        assert len(result) == 1
        docs = list(collection.find({}))
        assert len(docs) == 1
        assert docs[0]["text"] == "text one"

    def test_custom_field_names(self, collection: Collection, mock_embedding_func):
        """``text_key`` and ``embedding_key`` decide the stored field names."""
        texts = ["text one"]
        metadatas = [{}]

        bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=mock_embedding_func,
            collection=collection,
            text_key="content",
            embedding_key="vector",
        )

        doc = collection.find_one({})
        assert doc is not None
        assert "content" in doc
        assert "vector" in doc
        assert doc["content"] == texts[0]
        assert doc["vector"] == [0.0, 0.0, 0.0]

0 commit comments

Comments
 (0)