Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions libs/vertexai/langchain_google_vertexai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class VertexAIEmbeddings(BaseModel, Embeddings):
"the environment."
max_retries: int = 6
"""The maximum number of retries to make when generating."""
dimensions: Optional[int] = None
"""Default output dimensionality for embeddings. If not specified, uses the
model's default. Can be overridden per request in embed() method."""

@model_validator(mode="before")
@classmethod
Expand Down Expand Up @@ -122,7 +125,8 @@ def embed(
QUESTION_ANSWERING
FACT_VERIFICATION
dimensions: [int] optional. Output embeddings dimensions.
Only supported on preview models.
Only supported on preview models. If not provided, uses the
default dimensions specified in the constructor.
title: [str] optional. Title for the text. Only applicable when
TaskType is RETRIEVAL_DOCUMENT

Expand All @@ -131,10 +135,11 @@ def embed(
"""
if len(texts) == 0:
return []
effective_dimensions = dimensions if dimensions is not None else self.dimensions
embeddings = self._get_embeddings_with_retry(
texts=texts,
embeddings_type=embeddings_task_type,
dimensions=dimensions,
dimensions=effective_dimensions,
title=title,
)
return embeddings
Expand Down
100 changes: 100 additions & 0 deletions libs/vertexai/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,103 @@ def test_embed_parameters(mock_get_embeddings, mock_client):
dimensions=128,
title="test-title",
)


@patch("langchain_google_vertexai.embeddings.genai.Client")
@patch.object(VertexAIEmbeddings, "_get_embeddings_with_retry")
def test_default_dimensions_used_when_not_specified(mock_get_embeddings, mock_client):
    """Check that embed() falls back to the constructor's ``dimensions``.

    The model is built with ``dimensions=256`` and ``embed()`` is called with
    texts only; the patched ``_get_embeddings_with_retry`` must then receive
    ``dimensions=256``.
    """
    # Whatever the patched genai.Client is asked to build becomes a MagicMock.
    mock_client.return_value = MagicMock()
    inputs = ["hello", "world"]
    model = VertexAIEmbeddings(model="text-embedding-004", dimensions=256)
    mock_get_embeddings.return_value = [[0.001] * 256 for _ in inputs]

    model.embed(inputs)

    expected_kwargs = {
        "texts": inputs,
        "embeddings_type": None,
        "dimensions": 256,
        "title": None,
    }
    mock_get_embeddings.assert_called_once_with(**expected_kwargs)


@patch("langchain_google_vertexai.embeddings.genai.Client")
@patch.object(VertexAIEmbeddings, "_get_embeddings_with_retry")
def test_explicit_dimensions_override_default(mock_get_embeddings, mock_client):
    """Check that a per-call ``dimensions`` beats the constructor default.

    The model is built with ``dimensions=256`` but ``embed()`` is called with
    ``dimensions=512``; the patched ``_get_embeddings_with_retry`` must
    receive 512.
    """
    # Whatever the patched genai.Client is asked to build becomes a MagicMock.
    mock_client.return_value = MagicMock()
    inputs = ["hello", "world"]
    model = VertexAIEmbeddings(model="text-embedding-004", dimensions=256)
    mock_get_embeddings.return_value = [[0.001] * 512 for _ in inputs]

    model.embed(inputs, dimensions=512)

    expected_kwargs = {
        "texts": inputs,
        "embeddings_type": None,
        "dimensions": 512,
        "title": None,
    }
    mock_get_embeddings.assert_called_once_with(**expected_kwargs)


@patch("langchain_google_vertexai.embeddings.genai.Client")
@patch.object(VertexAIEmbeddings, "_get_embeddings_with_retry")
def test_no_default_dimensions_works_as_before(mock_get_embeddings, mock_client):
    """Check backward compatibility when no ``dimensions`` is given anywhere.

    Neither the constructor nor ``embed()`` receives ``dimensions``; the
    patched ``_get_embeddings_with_retry`` must be called with
    ``dimensions=None``.
    """
    # Whatever the patched genai.Client is asked to build becomes a MagicMock.
    mock_client.return_value = MagicMock()
    inputs = ["hello", "world"]
    model = VertexAIEmbeddings(model="text-embedding-004")
    mock_get_embeddings.return_value = [[0.001] * 768 for _ in inputs]

    model.embed(inputs)

    expected_kwargs = {
        "texts": inputs,
        "embeddings_type": None,
        "dimensions": None,
        "title": None,
    }
    mock_get_embeddings.assert_called_once_with(**expected_kwargs)


@patch("langchain_google_vertexai.embeddings.genai.Client")
@patch.object(VertexAIEmbeddings, "_get_embeddings_with_retry")
def test_default_dimensions_used_in_embed_documents(mock_get_embeddings, mock_client):
    """Check that embed_documents() picks up the constructor's ``dimensions``.

    With ``dimensions=128`` set on the model, the patched
    ``_get_embeddings_with_retry`` must be called with ``dimensions=128`` and
    the ``"RETRIEVAL_DOCUMENT"`` embeddings type.
    """
    # Whatever the patched genai.Client is asked to build becomes a MagicMock.
    mock_client.return_value = MagicMock()
    inputs = ["hello", "world"]
    model = VertexAIEmbeddings(model="text-embedding-004", dimensions=128)
    mock_get_embeddings.return_value = [[0.001] * 128 for _ in inputs]

    model.embed_documents(inputs)

    expected_kwargs = {
        "texts": inputs,
        "embeddings_type": "RETRIEVAL_DOCUMENT",
        "dimensions": 128,
        "title": None,
    }
    mock_get_embeddings.assert_called_once_with(**expected_kwargs)


@patch("langchain_google_vertexai.embeddings.genai.Client")
@patch.object(VertexAIEmbeddings, "_get_embeddings_with_retry")
def test_default_dimensions_used_in_embed_query(mock_get_embeddings, mock_client):
    """Check that embed_query() picks up the constructor's ``dimensions``.

    With ``dimensions=128`` set on the model, the patched
    ``_get_embeddings_with_retry`` must be called with the query wrapped in a
    one-element list, ``dimensions=128`` and the ``"RETRIEVAL_QUERY"``
    embeddings type.
    """
    # Whatever the patched genai.Client is asked to build becomes a MagicMock.
    mock_client.return_value = MagicMock()
    query = "hello"
    model = VertexAIEmbeddings(model="text-embedding-004", dimensions=128)
    mock_get_embeddings.return_value = [[0.001] * 128]

    model.embed_query(query)

    expected_kwargs = {
        "texts": [query],
        "embeddings_type": "RETRIEVAL_QUERY",
        "dimensions": 128,
        "title": None,
    }
    mock_get_embeddings.assert_called_once_with(**expected_kwargs)