diff --git a/chat.go b/chat.go index 9719f6b92..90a6db2b5 100644 --- a/chat.go +++ b/chat.go @@ -325,6 +325,11 @@ type ChatCompletionRequest struct { // We recommend hashing their username or email address, in order to avoid sending us any identifying information. // https://platform.openai.com/docs/api-reference/chat/create#chat_create-safety_identifier SafetyIdentifier string `json:"safety_identifier,omitempty"` + // ExtraBody provides configuration options for the generation process in Gemini API. + // Additional configuration parameters to control model behavior. Will be passed directly to the Gemini API. + // Such as thinking mode for Gemini. "extra_body": {"google": {"thinking_config": {"include_thoughts": true}}} + // https://ai.google.dev/gemini-api/docs/openai + ExtraBody map[string]any `json:"extra_body,omitempty"` // Embedded struct for non-OpenAI extensions ChatCompletionRequestExtensions } @@ -477,11 +482,28 @@ func (c *Client) CreateChatCompletion( return } + // The body map is used to dynamically construct the request payload for the chat completion API. + // Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields + // based on their presence, avoiding unnecessary or empty fields in the request. + extraBody := request.ExtraBody + request.ExtraBody = nil + + // Serialize request to JSON + jsonData, err := json.Marshal(request) + if err != nil { + return + } + + // Deserialize JSON to map[string]any + var body map[string]any + _ = json.Unmarshal(jsonData, &body) + req, err := c.newRequest( ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), - withBody(request), + withBody(body), // Main request body. + withExtraBody(extraBody), // Merge ExtraBody fields. ) if err != nil { return diff --git a/chat_stream.go b/chat_stream.go index 80d16cc63..56d2b355e 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -2,6 +2,7 @@ package openai import ( "context" + "encoding/json" "net/http" ) @@ -91,11 +92,28 @@ func (c *Client) CreateChatCompletionStream( return } + // The body map is used to dynamically construct the request payload for the chat completion API. + // Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields + // based on their presence, avoiding unnecessary or empty fields in the request. + extraBody := request.ExtraBody + request.ExtraBody = nil + + // Serialize request to JSON + jsonData, err := json.Marshal(request) + if err != nil { + return + } + + // Deserialize JSON to map[string]any + var body map[string]any + _ = json.Unmarshal(jsonData, &body) + req, err := c.newRequest( ctx, http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), - withBody(request), + withBody(body), // Main request body. + withExtraBody(extraBody), // Merge ExtraBody fields. ) if err != nil { return nil, err diff --git a/chat_test.go b/chat_test.go index 236cff736..b36bcb009 100644 --- a/chat_test.go +++ b/chat_test.go @@ -756,6 +756,105 @@ func TestChatCompletionsFunctions(t *testing.T) { }) } +func TestChatCompletionsWithExtraBody(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + // Register a custom handler that checks if ExtraBody fields are properly embedded + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + // Read the request body + reqBody, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + + // Parse the request body into a map to check all fields + var requestBody map[string]any + err = json.Unmarshal(reqBody, &requestBody) + if err != nil { + http.Error( + w, + fmt.Sprintf("could not parse request: %v, body: %s", err, string(reqBody)), http.StatusInternalServerError, + ) + return + } + + // Check that ExtraBody fields are present in the root level + if _, exists := requestBody["custom_field"]; !exists { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + err = json.NewEncoder(w).Encode(map[string]string{"error": "custom_field not found in request body"}) + if err != nil { + http.Error(w, fmt.Sprintf("could not write response: %v", err), http.StatusInternalServerError) + return + } + return + } + + // ExtraBody should not be present in the final request + if _, exists := requestBody["extra_body"]; exists { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + err = json.NewEncoder(w).Encode(map[string]string{"error": "extra_body should not be present in final request"}) + if err != nil { + http.Error(w, fmt.Sprintf("could not write response: %v", err), http.StatusInternalServerError) + return + } + return + } + + // Return a success response + res := openai.ChatCompletionResponse{ + ID: "test-id", + Object: "chat.completion", + Created: time.Now().Unix(), + Model: "gpt-3.5-turbo", + Choices: []openai.ChatCompletionChoice{ + { + Index: 0, + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Content: "Hello!", + }, + FinishReason: openai.FinishReasonStop, + }, + }, + Usage: openai.Usage{ + PromptTokens: 5, + CompletionTokens: 5, + TotalTokens: 10, + }, + } + + resBytes, _ := json.Marshal(res) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err = w.Write(resBytes) + if err != nil { + http.Error(w, fmt.Sprintf("could not write response: %v", err), http.StatusInternalServerError) + return + } + }) + + // Test the ExtraBody functionality + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + ExtraBody: map[string]any{ + "custom_field": "custom_value", + "another_field": 123, + }, + }) + + checks.NoError(t, err, "CreateChatCompletion with ExtraBody error") +} + func TestAzureChatCompletions(t *testing.T) { client, server, teardown := setupAzureTestServer() defer teardown()