55from typing import TYPE_CHECKING , Optional
66
77from attrs import Factory , define , field
8- from google .genai .types import GenerateContentResponse
98from pydantic import BaseModel
109
1110from griptape .artifacts import ActionArtifact , TextArtifact
3534 from collections .abc import Iterator
3635
3736 from google .genai import Client
38- from google .genai .types import Content , ContentDict , Part
37+ from google .genai .types import Content , ContentDict , GenerateContentResponse , Part , Tool
3938
4039 from griptape .drivers .prompt .base_prompt_driver import StructuredOutputStrategy
4140 from griptape .tools import BaseTool
@@ -97,16 +96,14 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
9796 def try_stream (self , prompt_stack : PromptStack ) -> Iterator [DeltaMessage ]:
9897 params = self ._base_params (prompt_stack )
9998 logging .debug (params )
100- response = self .client .models .generate_content_stream (
101- ** params ,
102- )
99+ response = self .client .models .generate_content_stream (** params )
103100
104101 prompt_token_count = None
105102 for chunk in response :
106103 logger .debug (chunk .model_dump ())
107104 usage_metadata = chunk .usage_metadata
108105
109- content = self .__to_prompt_stack_delta_message_content (chunk . parts [ 0 ]) if chunk . parts else None
106+ content = self .__to_prompt_stack_delta_message_content (chunk )
110107 # Only want to output the prompt token count once since it is static each chunk
111108 if prompt_token_count is None :
112109 yield DeltaMessage (
@@ -135,33 +132,35 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
135132 params = {
136133 "model" : self .model ,
137134 "contents" : self .__to_google_messages (prompt_stack ),
138- "config" : types . GenerateContentConfig (
139- ** {
140- "stop_sequences" : [] if self . use_native_tools else self . tokenizer . stop_sequences ,
141- "max_output_tokens" : self .max_tokens ,
142- "temperature" : self .temperature ,
143- "top_p" : self .top_p ,
144- "top_k" : self .top_k ,
145- "system_instruction" : system_instruction ,
146- ** self . extra_params ,
147- } ,
148- ) ,
135+ }
136+
137+ config = {
138+ "stop_sequences" : [] if self . use_native_tools else self .tokenizer . stop_sequences ,
139+ "max_output_tokens" : self .max_tokens ,
140+ "temperature" : self .temperature ,
141+ "top_p" : self .top_p ,
142+ "top_k" : self . top_k ,
143+ "system_instruction" : system_instruction ,
144+ "automatic_function_calling" : types . AutomaticFunctionCallingConfig ( disable = True ) ,
145+ ** self . extra_params ,
149146 }
150147
151148 if (
152149 self .structured_output_strategy == "native"
153150 and isinstance (prompt_stack .output_schema , type )
154151 and issubclass (prompt_stack .output_schema , BaseModel )
155152 ):
156- params [ "config" ]. response_schema = prompt_stack .output_schema
153+ config [ "response_schema" ] = prompt_stack .output_schema
157154
158155 if prompt_stack .tools and self .use_native_tools :
159- params ["tool_config" ] = {"function_calling_config" : {"mode" : self .tool_choice }}
156+ config ["tool_config" ] = {"function_calling_config" : {"mode" : self .tool_choice }}
160157
161158 if prompt_stack .output_schema is not None and self .structured_output_strategy == "tool" :
162- params ["tool_config" ]["function_calling_config" ]["mode" ] = "auto"
159+ config ["tool_config" ]["function_calling_config" ]["mode" ] = "auto"
160+
161+ config ["tools" ] = self .__to_google_tools (prompt_stack .tools )
163162
164- params ["tools" ] = self . __to_google_tools ( prompt_stack . tools )
163+ params ["config" ] = types . GenerateContentConfig ( ** config )
165164
166165 return params
167166
@@ -182,7 +181,7 @@ def __to_google_role(self, message: Message) -> str:
182181 return "model"
183182 return "user"
184183
185- def __to_google_tools (self , tools : list [BaseTool ]) -> list [dict ]:
184+ def __to_google_tools (self , tools : list [BaseTool ]) -> list [Tool ]:
186185 types = import_optional_dependency ("google.genai.types" )
187186
188187 tool_declarations = []
@@ -194,7 +193,7 @@ def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
194193 schema = schema ["properties" ]["values" ]
195194
196195 schema = remove_key_in_dict_recursively (schema , "additionalProperties" )
197- tool_declaration = types .FunctionDeclaration (
196+ function_declaration = types .FunctionDeclaration (
198197 name = tool .to_native_tool_name (activity ),
199198 description = tool .activity_description (activity ),
200199 ** (
@@ -209,8 +208,9 @@ def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
209208 else {}
210209 ),
211210 )
211+ google_tool = types .Tool (function_declarations = [function_declaration ])
212212
213- tool_declarations .append (tool_declaration )
213+ tool_declarations .append (google_tool )
214214
215215 return tool_declarations
216216
@@ -259,11 +259,11 @@ def __to_prompt_stack_message_content(self, content: GenerateContentResponse) ->
259259
260260 return []
261261
262- def __to_prompt_stack_delta_message_content (self , content : Part ) -> BaseDeltaMessageContent :
262+ def __to_prompt_stack_delta_message_content (self , content : GenerateContentResponse ) -> BaseDeltaMessageContent :
263263 if content .text :
264264 return TextDeltaMessageContent (content .text )
265- if content .function_call :
266- function_call = content .function_call
265+ if content .function_calls :
266+ function_call = content .function_calls [ 0 ]
267267
268268 args = function_call .args
269269 return ActionCallDeltaMessageContent (
0 commit comments