Skip to content

Commit d2057f3

Browse files
committed
Fix tool call merging for streaming APIs without IDs
- Update MessageAggregator to handle tool calls without IDs - When tool call has no ID, merge with last tool call - Add comprehensive tests for streaming patterns Signed-off-by: ultramancode <[email protected]>
1 parent bd1834d commit d2057f3

File tree

2 files changed

+236
-3
lines changed

2 files changed

+236
-3
lines changed

spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@
3838
import org.springframework.util.CollectionUtils;
3939
import org.springframework.util.StringUtils;
4040

41-
import static org.springframework.ai.chat.messages.AssistantMessage.ToolCall;
42-
4341
/**
4442
* Helper that for streaming chat responses, aggregate the chat response messages into a
4543
* single AssistantMessage. Job is performed in parallel to the chat response processing.
@@ -48,6 +46,7 @@
4846
* @author Alexandros Pappas
4947
* @author Thomas Vitale
5048
* @author Heonwoo Kim
49+
* @author Taewoong Kim
5150
* @since 1.0.0
5251
*/
5352
public class MessageAggregator {
@@ -104,7 +103,7 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
104103
}
105104
AssistantMessage outputMessage = chatResponse.getResult().getOutput();
106105
if (!CollectionUtils.isEmpty(outputMessage.getToolCalls())) {
107-
toolCallsRef.get().addAll(outputMessage.getToolCalls());
106+
mergeToolCalls(toolCallsRef.get(), outputMessage.getToolCalls());
108107
}
109108

110109
}
@@ -188,6 +187,74 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
188187
}).doOnError(e -> logger.error("Aggregation Error", e));
189188
}
190189

190+
/**
191+
* Merge tool calls by id to handle streaming responses where tool call data is split
192+
* across multiple chunks. This is common in OpenAI-compatible APIs like Qwen, where
193+
* the first chunk contains the function name and subsequent chunks contain only arguments.
194+
* if a tool call has an ID, it's matched by ID.
195+
* if it has no ID (empty or null), it's merged with the last tool call in the list.
196+
* @param existingToolCalls the list of existing tool calls to merge into
197+
* @param newToolCalls the new tool calls to merge
198+
*/
199+
private void mergeToolCalls(List<ToolCall> existingToolCalls, List<ToolCall> newToolCalls) {
200+
for (ToolCall newCall : newToolCalls) {
201+
if (StringUtils.hasText(newCall.id())) {
202+
// ID present: match by ID or add as new
203+
ToolCall existingMatch = existingToolCalls.stream()
204+
.filter(existing -> newCall.id().equals(existing.id()))
205+
.findFirst()
206+
.orElse(null);
207+
208+
if (existingMatch != null) {
209+
// Merge with existing tool call with same ID
210+
int index = existingToolCalls.indexOf(existingMatch);
211+
ToolCall merged = mergeToolCall(existingMatch, newCall);
212+
existingToolCalls.set(index, merged);
213+
} else {
214+
// New tool call with ID
215+
existingToolCalls.add(newCall);
216+
}
217+
} else {
218+
// No ID: merge with last tool call
219+
ToolCall lastToolCall = existingToolCalls.isEmpty() ? null : existingToolCalls.get(existingToolCalls.size() - 1);
220+
ToolCall merged = mergeToolCall(lastToolCall, newCall);
221+
222+
if (lastToolCall != null) {
223+
existingToolCalls.set(existingToolCalls.size() - 1, merged);
224+
} else {
225+
existingToolCalls.add(merged);
226+
}
227+
}
228+
}
229+
}
230+
231+
/**
232+
* Merge two tool calls into one, combining their properties.
233+
* @param existing the existing tool call
234+
* @param current the current tool call to merge
235+
* @return the merged tool call
236+
*/
237+
private ToolCall mergeToolCall(ToolCall existing, ToolCall current) {
238+
if (existing == null) {
239+
return current;
240+
}
241+
242+
// Use non-empty ID, prefer existing if both present (for consistency)
243+
String mergedId = StringUtils.hasText(existing.id()) ? existing.id() : current.id();
244+
245+
// Use non-empty name, prefer new if both present
246+
String mergedName = StringUtils.hasText(current.name()) ? current.name() : existing.name();
247+
248+
// Use non-empty type, prefer new if both present
249+
String mergedType = StringUtils.hasText(current.type()) ? current.type() : existing.type();
250+
251+
// Concatenate arguments
252+
String mergedArgs = (existing.arguments() != null ? existing.arguments() : "")
253+
+ (current.arguments() != null ? current.arguments() : "");
254+
255+
return new ToolCall(mergedId, mergedType, mergedName, mergedArgs);
256+
}
257+
191258
public record DefaultUsage(Integer promptTokens, Integer completionTokens, Integer totalTokens) implements Usage {
192259

193260
@Override
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.model;
18+
19+
import java.util.List;
20+
import java.util.concurrent.atomic.AtomicReference;
21+
22+
import org.junit.jupiter.api.Test;
23+
import reactor.core.publisher.Flux;
24+
25+
import org.springframework.ai.chat.messages.AssistantMessage;
26+
27+
import static org.assertj.core.api.Assertions.assertThat;
28+
29+
/**
30+
* Tests for {@link MessageAggregator} with streaming tool calls that lack IDs in subsequent chunks.
31+
* This pattern is common in OpenAI-compatible APIs.
32+
* @author Taewoong Kim
33+
*/
34+
class MessageAggregatorTests {
35+
36+
private final MessageAggregator messageAggregator = new MessageAggregator();
37+
38+
/**
39+
* Test merging of tool calls when subsequent chunks have no ID.
40+
* First chunk contains the tool name and ID, subsequent chunks contain only arguments.
41+
*/
42+
@Test
43+
void shouldMergeToolCallsWithoutIds() {
44+
// Chunk 1: ID and name present
45+
ChatResponse chunk1 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
46+
.toolCalls(List.of(new AssistantMessage.ToolCall("chatcmpl-tool-123", "function", "getCurrentWeather", "")))
47+
.build())));
48+
49+
// Chunk 2-5: No ID, only arguments (common streaming pattern)
50+
ChatResponse chunk2 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
51+
.toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "{\"location\": \"")))
52+
.build())));
53+
54+
ChatResponse chunk3 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
55+
.toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "Se")))
56+
.build())));
57+
58+
ChatResponse chunk4 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
59+
.toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "oul")))
60+
.build())));
61+
62+
ChatResponse chunk5 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
63+
.toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "\"}")))
64+
.build())));
65+
66+
Flux<ChatResponse> flux = Flux.just(chunk1, chunk2, chunk3, chunk4, chunk5);
67+
68+
// When: Aggregate the streaming responses
69+
AtomicReference<ChatResponse> finalResponse = new AtomicReference<>();
70+
this.messageAggregator.aggregate(flux, finalResponse::set).blockLast();
71+
72+
// Then: Verify the tool call was properly merged
73+
assertThat(finalResponse.get()).isNotNull();
74+
List<AssistantMessage.ToolCall> toolCalls = finalResponse.get().getResult().getOutput().getToolCalls();
75+
76+
assertThat(toolCalls).hasSize(1);
77+
AssistantMessage.ToolCall mergedToolCall = toolCalls.get(0);
78+
79+
assertThat(mergedToolCall.id()).isEqualTo("chatcmpl-tool-123");
80+
assertThat(mergedToolCall.name()).isEqualTo("getCurrentWeather");
81+
assertThat(mergedToolCall.arguments()).isEqualTo("{\"location\": \"Seoul\"}");
82+
}
83+
84+
/**
85+
* Test multiple tool calls being streamed simultaneously. Each tool call has its own ID in the first chunk,
86+
* and subsequent chunks have no ID but are merged with the last tool call.
87+
*/
88+
@Test
89+
void shouldMergeMultipleToolCallsWithMixedIds() {
90+
// Given: Multiple tool calls being streamed
91+
// Chunk 1: First tool call starts with ID
92+
ChatResponse chunk1 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
93+
.toolCalls(List.of(new AssistantMessage.ToolCall("tool-1", "function", "getWeather", "")))
94+
.build())));
95+
96+
// Chunk 2: Argument for first tool call (no ID)
97+
ChatResponse chunk2 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
98+
.toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "{\"city\":\"Tokyo\"}")))
99+
.build())));
100+
101+
// Chunk 3: Second tool call starts with ID
102+
ChatResponse chunk3 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
103+
.toolCalls(List.of(new AssistantMessage.ToolCall("tool-2", "function", "getTime", "")))
104+
.build())));
105+
106+
// Chunk 4: Argument for second tool call (no ID)
107+
ChatResponse chunk4 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
108+
.toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "{\"timezone\":\"JST\"}")))
109+
.build())));
110+
111+
Flux<ChatResponse> flux = Flux.just(chunk1, chunk2, chunk3, chunk4);
112+
113+
// When: Aggregate the streaming responses
114+
AtomicReference<ChatResponse> finalResponse = new AtomicReference<>();
115+
this.messageAggregator.aggregate(flux, finalResponse::set).blockLast();
116+
117+
// Then: Verify both tool calls were properly merged
118+
assertThat(finalResponse.get()).isNotNull();
119+
List<AssistantMessage.ToolCall> toolCalls = finalResponse.get().getResult().getOutput().getToolCalls();
120+
121+
assertThat(toolCalls).hasSize(2);
122+
123+
AssistantMessage.ToolCall firstToolCall = toolCalls.get(0);
124+
assertThat(firstToolCall.id()).isEqualTo("tool-1");
125+
assertThat(firstToolCall.name()).isEqualTo("getWeather");
126+
assertThat(firstToolCall.arguments()).isEqualTo("{\"city\":\"Tokyo\"}");
127+
128+
AssistantMessage.ToolCall secondToolCall = toolCalls.get(1);
129+
assertThat(secondToolCall.id()).isEqualTo("tool-2");
130+
assertThat(secondToolCall.name()).isEqualTo("getTime");
131+
assertThat(secondToolCall.arguments()).isEqualTo("{\"timezone\":\"JST\"}");
132+
}
133+
134+
/**
135+
* Test that tool calls with IDs are still matched correctly by ID, even when they arrive in different chunks.
136+
*/
137+
@Test
138+
void shouldMergeToolCallsById() {
139+
// Given: Chunks with same ID arriving separately
140+
ChatResponse chunk1 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
141+
.toolCalls(List.of(new AssistantMessage.ToolCall("tool-1", "function", "getWeather", "{\"ci")))
142+
.build())));
143+
144+
ChatResponse chunk2 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
145+
.toolCalls(List.of(new AssistantMessage.ToolCall("tool-1", "function", "", "ty\":\"Paris\"}")))
146+
.build())));
147+
148+
Flux<ChatResponse> flux = Flux.just(chunk1, chunk2);
149+
150+
// When: Aggregate the streaming responses
151+
AtomicReference<ChatResponse> finalResponse = new AtomicReference<>();
152+
this.messageAggregator.aggregate(flux, finalResponse::set).blockLast();
153+
154+
// Then: Verify the tool call was merged by ID
155+
assertThat(finalResponse.get()).isNotNull();
156+
List<AssistantMessage.ToolCall> toolCalls = finalResponse.get().getResult().getOutput().getToolCalls();
157+
158+
assertThat(toolCalls).hasSize(1);
159+
AssistantMessage.ToolCall mergedToolCall = toolCalls.get(0);
160+
assertThat(mergedToolCall.id()).isEqualTo("tool-1");
161+
assertThat(mergedToolCall.name()).isEqualTo("getWeather");
162+
assertThat(mergedToolCall.arguments()).isEqualTo("{\"city\":\"Paris\"}");
163+
}
164+
165+
}
166+

0 commit comments

Comments
 (0)