Skip to content

Commit b04d136

Browse files
committed
Squash merge fix/lazy into refactor/google-prompt-driver
1 parent a7f2b2d commit b04d136

21 files changed

+111
-77
lines changed

docs/griptape-framework/drivers/src/vector_store_drivers_10.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22

3+
from qdrant_client.http.models import Distance, VectorParams
4+
35
from griptape.chunkers import TextChunker
46
from griptape.drivers.embedding.openai import OpenAiEmbeddingDriver
57
from griptape.drivers.vector.qdrant import QdrantVectorStoreDriver
@@ -27,7 +29,7 @@
2729
# Recreate Qdrant collection
2830
vector_store_driver.client.recreate_collection(
2931
collection_name=vector_store_driver.collection_name,
30-
vectors_config={"size": 1536, "distance": vector_store_driver.distance},
32+
vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
3133
)
3234

3335
# Upsert Artifacts into the Vector Store Driver

griptape/drivers/assistant/openai_assistant_driver.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,26 @@ def client(self) -> openai.OpenAI:
5959
)
6060

6161
def try_run(self, *args: BaseArtifact) -> TextArtifact:
62-
if self.thread_id is None and self.auto_create_thread:
63-
self.thread_id = self.client.beta.threads.create().id
64-
response = self._create_run(*args)
62+
if self.thread_id is None:
63+
if self.auto_create_thread:
64+
thread_id = self.client.beta.threads.create().id
65+
self.thread_id = thread_id
66+
else:
67+
raise ValueError("Thread ID is required but not provided and auto_create_thread is disabled.")
68+
else:
69+
thread_id = self.thread_id
70+
71+
response = self._create_run(thread_id, *args)
6572

6673
response.meta.update({"assistant_id": self.assistant_id, "thread_id": self.thread_id})
6774

6875
return response
6976

70-
def _create_run(self, *args: BaseArtifact) -> TextArtifact:
77+
def _create_run(self, thread_id: str, *args: BaseArtifact) -> TextArtifact:
7178
content = "\n".join(arg.value for arg in args)
72-
message_id = self.client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=content)
79+
message_id = self.client.beta.threads.messages.create(thread_id=thread_id, role="user", content=content)
7380
with self.client.beta.threads.runs.stream(
74-
thread_id=self.thread_id,
81+
thread_id=thread_id,
7582
assistant_id=self.assistant_id,
7683
event_handler=self.event_handler,
7784
) as stream:
@@ -80,7 +87,9 @@ def _create_run(self, *args: BaseArtifact) -> TextArtifact:
8087

8188
message_contents = []
8289
for message in last_messages:
83-
message_contents.append("".join(content.text.value for content in message.content))
90+
message_contents.append(
91+
"".join(content.text.value for content in message.content if content.type == "TextContentBlock")
92+
)
8493
message_text = "\n".join(message_contents)
8594

8695
response = TextArtifact(message_text)

griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
if TYPE_CHECKING:
1414
import boto3
15-
from mypy_boto3_bedrock import BedrockClient
15+
from mypy_boto3_bedrock_runtime import BedrockRuntimeClient
1616

1717
from griptape.tokenizers.base_tokenizer import BaseTokenizer
1818

@@ -40,12 +40,12 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
4040
default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
4141
kw_only=True,
4242
)
43-
_client: Optional[BedrockClient] = field(
43+
_client: Optional[BedrockRuntimeClient] = field(
4444
default=None, kw_only=True, alias="client", metadata={"serializable": False}
4545
)
4646

4747
@lazy_property()
48-
def client(self) -> BedrockClient:
48+
def client(self) -> BedrockRuntimeClient:
4949
return self.session.client("bedrock-runtime")
5050

5151
def try_embed_chunk(self, chunk: str) -> list[float]:

griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
if TYPE_CHECKING:
1616
import boto3
17-
from mypy_boto3_bedrock import BedrockClient
17+
from mypy_boto3_bedrock_runtime import BedrockRuntimeClient
1818

1919
from griptape.tokenizers.base_tokenizer import BaseTokenizer
2020

@@ -38,12 +38,12 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
3838
default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
3939
kw_only=True,
4040
)
41-
_client: Optional[BedrockClient] = field(
41+
_client: Optional[BedrockRuntimeClient] = field(
4242
default=None, kw_only=True, alias="client", metadata={"serializable": False}
4343
)
4444

4545
@lazy_property()
46-
def client(self) -> BedrockClient:
46+
def client(self) -> BedrockRuntimeClient:
4747
return self.session.client("bedrock-runtime")
4848

4949
def try_embed_artifact(self, artifact: TextArtifact | ImageArtifact) -> list[float]:

griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
if TYPE_CHECKING:
1313
import boto3
14-
from mypy_boto3_sagemaker import SageMakerClient
14+
from mypy_boto3_sagemaker_runtime import SageMakerRuntimeClient
1515

1616

1717
@define
@@ -20,12 +20,12 @@ class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
2020
endpoint: str = field(kw_only=True, metadata={"serializable": True})
2121
custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
2222
inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
23-
_client: Optional[SageMakerClient] = field(
23+
_client: Optional[SageMakerRuntimeClient] = field(
2424
default=None, kw_only=True, alias="client", metadata={"serializable": False}
2525
)
2626

2727
@lazy_property()
28-
def client(self) -> SageMakerClient:
28+
def client(self) -> SageMakerRuntimeClient:
2929
return self.session.client("sagemaker-runtime")
3030

3131
def try_embed_chunk(self, chunk: str) -> list[float]:

griptape/drivers/embedding/huggingface_hub_embedding_driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,4 @@ def client(self) -> InferenceClient:
3737
def try_embed_chunk(self, chunk: str) -> list[float]:
3838
response = self.client.feature_extraction(chunk)
3939

40-
return response.flatten().tolist()
40+
return [float(val) for val in response.flatten().tolist()]

griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
from griptape.utils.decorators import lazy_property
1111

1212
if TYPE_CHECKING:
13+
from collections.abc import Sequence
14+
1315
import boto3
1416
from mypy_boto3_sqs import SQSClient
17+
from mypy_boto3_sqs.type_defs import SendMessageBatchRequestEntryTypeDef
1518

1619

1720
@define
@@ -28,7 +31,7 @@ def try_publish_event_payload(self, event_payload: dict) -> None:
2831
self.client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload))
2932

3033
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
31-
entries = [
34+
entries: Sequence[SendMessageBatchRequestEntryTypeDef] = [
3235
{"Id": str(event_payload["id"]), "MessageBody": json.dumps(event_payload)}
3336
for event_payload in event_payload_batch
3437
]

griptape/drivers/file_manager/amazon_s3_file_manager_driver.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
if TYPE_CHECKING:
1313
import boto3
1414
from mypy_boto3_s3 import S3Client
15+
from mypy_boto3_s3.type_defs import PaginatorConfigTypeDef
1516

1617

1718
@define
@@ -98,7 +99,7 @@ def _to_dir_full_key(self, path: str) -> str:
9899

99100
def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]:
100101
max_items = kwargs.get("max_items")
101-
pagination_config = {}
102+
pagination_config: PaginatorConfigTypeDef = {}
102103
if max_items is not None:
103104
pagination_config["MaxItems"] = max_items
104105

@@ -112,12 +113,12 @@ def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]:
112113
files_and_dirs = []
113114
for page in pages:
114115
for obj in page.get("CommonPrefixes", []):
115-
prefix = obj.get("Prefix")
116+
prefix = obj.get("Prefix", "")
116117
directory = prefix[len(full_key) :].rstrip("/")
117118
files_and_dirs.append(directory)
118119

119120
for obj in page.get("Contents", []):
120-
key = obj.get("Key")
121+
key = obj.get("Key", "")
121122
file = key[len(full_key) :]
122123
files_and_dirs.append(file)
123124
return files_and_dirs

griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
if TYPE_CHECKING:
1414
import boto3
15-
from mypy_boto3_bedrock import BedrockClient
15+
from mypy_boto3_bedrock_runtime import BedrockRuntimeClient
1616

1717

1818
@define
@@ -32,12 +32,12 @@ class AmazonBedrockImageGenerationDriver(BaseMultiModelImageGenerationDriver):
3232
image_width: int = field(default=512, kw_only=True, metadata={"serializable": True})
3333
image_height: int = field(default=512, kw_only=True, metadata={"serializable": True})
3434
seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
35-
_client: Optional[BedrockClient] = field(
35+
_client: Optional[BedrockRuntimeClient] = field(
3636
default=None, kw_only=True, alias="client", metadata={"serializable": False}
3737
)
3838

3939
@lazy_property()
40-
def client(self) -> BedrockClient:
40+
def client(self) -> BedrockRuntimeClient:
4141
return self.session.client("bedrock-runtime")
4242

4343
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:

griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def load(self) -> tuple[list[Run], dict[str, Any]]:
4545
response = self.table.get_item(Key=self._get_key())
4646

4747
if "Item" in response and self.value_attribute_key in response["Item"]:
48-
memory_dict = json.loads(response["Item"][self.value_attribute_key])
48+
memory_dict = json.loads(str(response["Item"][self.value_attribute_key]))
4949
return self._from_params_dict(memory_dict)
5050
return [], {}
5151

0 commit comments

Comments
 (0)