
Commit dfd397b

WIP
1 parent b04d136 commit dfd397b

3 files changed: +32 -31 lines changed


griptape/drivers/prompt/google_prompt_driver.py

Lines changed: 27 additions & 27 deletions
```diff
@@ -5,7 +5,6 @@
 from typing import TYPE_CHECKING, Optional
 
 from attrs import Factory, define, field
-from google.genai.types import GenerateContentResponse
 from pydantic import BaseModel
 
 from griptape.artifacts import ActionArtifact, TextArtifact
@@ -35,7 +34,7 @@
     from collections.abc import Iterator
 
     from google.genai import Client
-    from google.genai.types import Content, ContentDict, Part
+    from google.genai.types import Content, ContentDict, GenerateContentResponse, Part, Tool
 
     from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
     from griptape.tools import BaseTool
```
```diff
@@ -97,16 +96,14 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
     def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
         params = self._base_params(prompt_stack)
         logging.debug(params)
-        response = self.client.models.generate_content_stream(
-            **params,
-        )
+        response = self.client.models.generate_content_stream(**params)
 
         prompt_token_count = None
         for chunk in response:
             logger.debug(chunk.model_dump())
             usage_metadata = chunk.usage_metadata
 
-            content = self.__to_prompt_stack_delta_message_content(chunk.parts[0]) if chunk.parts else None
+            content = self.__to_prompt_stack_delta_message_content(chunk)
             # Only want to output the prompt token count once since it is static each chunk
             if prompt_token_count is None:
                 yield DeltaMessage(
```
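
Because `generate_content_stream` yields full `GenerateContentResponse` chunks, the driver can now pass each chunk straight to its delta helper instead of plucking out the first `Part`. A minimal sketch of that streaming shape outside the driver; the client setup, model name, and prompt are illustrative assumptions, not part of this diff:

```python
from google.genai import Client

client = Client()  # assumes GOOGLE_API_KEY is set in the environment
stream = client.models.generate_content_stream(
    model="gemini-2.0-flash",  # illustrative model name
    contents="Tell me a short story.",
)
for chunk in stream:
    # Each chunk is a GenerateContentResponse, so text and usage
    # metadata are available directly on it.
    if chunk.text:
        print(chunk.text, end="")
    if chunk.usage_metadata is not None:
        prompt_token_count = chunk.usage_metadata.prompt_token_count
```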
```diff
@@ -135,33 +132,35 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
         params = {
             "model": self.model,
             "contents": self.__to_google_messages(prompt_stack),
-            "config": types.GenerateContentConfig(
-                **{
-                    "stop_sequences": [] if self.use_native_tools else self.tokenizer.stop_sequences,
-                    "max_output_tokens": self.max_tokens,
-                    "temperature": self.temperature,
-                    "top_p": self.top_p,
-                    "top_k": self.top_k,
-                    "system_instruction": system_instruction,
-                    **self.extra_params,
-                },
-            ),
+        }
+
+        config = {
+            "stop_sequences": [] if self.use_native_tools else self.tokenizer.stop_sequences,
+            "max_output_tokens": self.max_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "system_instruction": system_instruction,
+            "automatic_function_calling": types.AutomaticFunctionCallingConfig(disable=True),
+            **self.extra_params,
         }
 
         if (
             self.structured_output_strategy == "native"
             and isinstance(prompt_stack.output_schema, type)
             and issubclass(prompt_stack.output_schema, BaseModel)
         ):
-            params["config"].response_schema = prompt_stack.output_schema
+            config["response_schema"] = prompt_stack.output_schema
 
         if prompt_stack.tools and self.use_native_tools:
-            params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}}
+            config["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}}
 
             if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
-                params["tool_config"]["function_calling_config"]["mode"] = "auto"
+                config["tool_config"]["function_calling_config"]["mode"] = "auto"
+
+            config["tools"] = self.__to_google_tools(prompt_stack.tools)
 
-            params["tools"] = self.__to_google_tools(prompt_stack.tools)
+        params["config"] = types.GenerateContentConfig(**config)
 
         return params
```
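
Assembling the options as a plain dict before constructing `types.GenerateContentConfig` makes the conditional keys (`response_schema`, `tool_config`, `tools`) straightforward to attach. A standalone sketch of the same pattern; the model name, prompt, and parameter values are illustrative assumptions:

```python
from google.genai import types

# Collect options in a plain dict so optional keys can be added freely.
config = {
    "max_output_tokens": 1024,
    "temperature": 0.1,
    # Disabling automatic function calling keeps the SDK from executing
    # function calls itself, leaving them to the caller.
    "automatic_function_calling": types.AutomaticFunctionCallingConfig(disable=True),
}

# Only convert to the typed config object once everything is collected.
params = {
    "model": "gemini-2.0-flash",  # illustrative model name
    "contents": "Hello!",
    "config": types.GenerateContentConfig(**config),
}
```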

```diff
@@ -182,7 +181,7 @@ def __to_google_role(self, message: Message) -> str:
             return "model"
         return "user"
 
-    def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
+    def __to_google_tools(self, tools: list[BaseTool]) -> list[Tool]:
         types = import_optional_dependency("google.genai.types")
 
         tool_declarations = []
@@ -194,7 +193,7 @@ def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
                     schema = schema["properties"]["values"]
 
                 schema = remove_key_in_dict_recursively(schema, "additionalProperties")
-                tool_declaration = types.FunctionDeclaration(
+                function_declaration = types.FunctionDeclaration(
                     name=tool.to_native_tool_name(activity),
                     description=tool.activity_description(activity),
                     **(
@@ -209,8 +208,9 @@ def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
                         else {}
                     ),
                 )
+                google_tool = types.Tool(function_declarations=[function_declaration])
 
-                tool_declarations.append(tool_declaration)
+                tool_declarations.append(google_tool)
 
         return tool_declarations
```
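
The helper now returns `types.Tool` objects rather than bare declarations, since google-genai expects tools as `Tool` instances that each carry one or more function declarations. A minimal sketch of that wrapping in isolation; the name, description, and parameter schema are illustrative assumptions:

```python
from google.genai import types

# One declaration per activity, then wrapped in a Tool.
function_declaration = types.FunctionDeclaration(
    name="calculator_calculate",  # illustrative activity name
    description="Evaluate a simple arithmetic expression.",
    parameters={
        "type": "OBJECT",
        "properties": {"expression": {"type": "STRING"}},
    },
)
google_tool = types.Tool(function_declarations=[function_declaration])
```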

```diff
@@ -259,11 +259,11 @@ def __to_prompt_stack_message_content(self, content: GenerateContentResponse) ->
 
         return []
 
-    def __to_prompt_stack_delta_message_content(self, content: Part) -> BaseDeltaMessageContent:
+    def __to_prompt_stack_delta_message_content(self, content: GenerateContentResponse) -> BaseDeltaMessageContent:
         if content.text:
             return TextDeltaMessageContent(content.text)
-        if content.function_call:
-            function_call = content.function_call
+        if content.function_calls:
+            function_call = content.function_calls[0]
 
             args = function_call.args
             return ActionCallDeltaMessageContent(
```
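
The signature change matters because the accessors differ between the two types: a `Part` carries a single `function_call`, while `GenerateContentResponse.function_calls` aggregates the calls from the first candidate's parts. A sketch of the distinction; the response below is constructed by hand purely for illustration:

```python
from google.genai import types

# Build a response containing one function-call part.
part = types.Part.from_function_call(name="lookup", args={"q": "weather"})
response = types.GenerateContentResponse(
    candidates=[types.Candidate(content=types.Content(role="model", parts=[part]))],
)

# The response-level accessor returns a list of calls.
assert response.function_calls is not None
assert response.function_calls[0].name == "lookup"
```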

griptape/utils/dict_utils.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -12,6 +12,8 @@ def remove_null_values_in_dict_recursively(d: dict) -> dict:
 def remove_key_in_dict_recursively(d: dict, key: str) -> dict:
     if isinstance(d, dict):
         return {k: remove_key_in_dict_recursively(v, key) for k, v in d.items() if k != key}
+    if isinstance(d, list):
+        return [remove_key_in_dict_recursively(v, key) for v in d]
     return d
```
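
A quick usage sketch of the fixed helper: before this change, dicts nested inside lists (e.g. under `anyOf`) kept the stripped key because lists were not traversed. The schema below is an illustrative assumption:

```python
from griptape.utils.dict_utils import remove_key_in_dict_recursively

schema = {
    "type": "object",
    "additionalProperties": False,
    "anyOf": [
        {"type": "string"},
        {"type": "object", "additionalProperties": False},
    ],
}

cleaned = remove_key_in_dict_recursively(schema, "additionalProperties")
# The key is now removed from dicts inside the list as well.
assert "additionalProperties" not in cleaned
assert "additionalProperties" not in cleaned["anyOf"][1]
```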

tests/unit/drivers/prompt/test_google_prompt_driver.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -1,8 +1,7 @@
 from unittest.mock import MagicMock, Mock
 
 import pytest
-from google.generativeai.protos import FunctionCall, FunctionResponse, Part
-from google.generativeai.types import ContentDict, GenerationConfig
+from google.genai.types import ContentDict, FunctionCall, FunctionResponse, GenerationConfig, Part
 from google.protobuf.json_format import MessageToDict
 from schema import Schema
 
@@ -51,7 +50,7 @@ class TestGooglePromptDriver:
 
     @pytest.fixture()
     def mock_generative_model(self, mocker):
-        mock_generative_model = mocker.patch("google.generativeai.GenerativeModel")
+        mock_generative_model = mocker.patch("google.genai.GenerativeModel")
         mocker.patch("google.protobuf.json_format.MessageToDict").return_value = {
             "args": {"foo": "bar"},
         }
@@ -71,7 +70,7 @@ def mock_generative_model(self, mocker):
 
     @pytest.fixture()
     def mock_stream_generative_model(self, mocker):
-        mock_generative_model = mocker.patch("google.generativeai.GenerativeModel")
+        mock_generative_model = mocker.patch("google.genai.GenerativeModel")
         mocker.patch("google.protobuf.json_format.MessageToDict").return_value = {
             "args": {"foo": "bar"},
         }
```
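
For context, a standalone sketch of the fixture pattern these tests rely on; the class, fixture name, and patch target below are illustrative, not the repository's actual fixtures:

```python
import pytest

class TestExample:
    @pytest.fixture()
    def mock_client(self, mocker):
        # pytest-mock swaps the SDK entry point for a MagicMock so tests
        # can assert on call arguments without real network calls.
        mock_client = mocker.patch("google.genai.Client")
        mock_client.return_value.models.generate_content.return_value.text = "mocked"
        return mock_client
```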
