Skip to content

Commit 038cd24

Browse files
authored
Handle partial flushes for application/json (#2422)
This adds code fixing the issue #2403 for `application/json` content type.
1 parent b6a5e46 commit 038cd24

File tree

5 files changed

+172
-136
lines changed

5 files changed

+172
-136
lines changed

pkg/mcp/tool_filter.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ func (rw *toolFilterWriter) Flush() {
336336

337337
var b bytes.Buffer
338338
err := processBuffer(rw.config, rw.buffer, mimeType, &b)
339-
if err == errKeepBuffering {
339+
if errors.Is(err, errKeepBuffering) {
340340
logger.Debugf("Buffered %d so far, keep buffering...", len(rw.buffer))
341341
return
342342
}
@@ -386,7 +386,11 @@ func processBuffer(
386386
switch mimeType {
387387
case "application/json":
388388
var toolsListResponse toolsListResponse
389+
var syntaxError *json.SyntaxError
389390
err := json.Unmarshal(buffer, &toolsListResponse)
391+
if errors.As(err, &syntaxError) {
392+
return fmt.Errorf("%w: %v", errKeepBuffering, err)
393+
}
390394
if err == nil && toolsListResponse.Result.Tools != nil {
391395
return processToolsListResponse(config, toolsListResponse, w)
392396
}
@@ -416,7 +420,7 @@ func processEventStream(
416420
w io.Writer,
417421
) error {
418422
if len(buffer) > 1 && buffer[len(buffer)-1] != '\n' && buffer[len(buffer)-1] != '\r' {
419-
return errKeepBuffering
423+
return fmt.Errorf("%w: %v", errKeepBuffering, "event separator not found")
420424
}
421425

422426
// NOTE: this looks uglier, but is more efficient than scanning the whole buffer

pkg/mcp/tool_middleware_test.go

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -117,23 +117,6 @@ func TestNewListToolsMappingMiddleware_Scenarios(t *testing.T) {
117117
{"name": "MyFoo", "description": "Foo tool"},
118118
},
119119
},
120-
{
121-
name: "Filter MyFoo, Override Foo -> MyFoo with connection hang",
122-
serverOpts: []testkit.TestMCPServerOption{
123-
//nolint:goconst
124-
testkit.WithTool("Foo", "Foo tool", func() string { return "Foo" }),
125-
//nolint:goconst
126-
testkit.WithTool("Bar", "Bar tool", func() string { return "Bar" }),
127-
testkit.WithConnectionHang(10 * time.Second),
128-
},
129-
opts: &[]ToolMiddlewareOption{
130-
WithToolsFilter("MyFoo"),
131-
WithToolsOverride("Foo", "MyFoo", ""),
132-
},
133-
expected: &[]map[string]any{
134-
{"name": "MyFoo", "description": "Foo tool"},
135-
},
136-
},
137120
}
138121

139122
for _, tt := range tests {
@@ -631,7 +614,7 @@ func TestNewToolCallMappingMiddleware_ErrorCases(t *testing.T) {
631614
}
632615
}
633616

634-
func TestNewListToolsMappingMiddleware_ConnectionHang(t *testing.T) {
617+
func TestSSEBufferFlushes(t *testing.T) {
635618
t.Parallel()
636619
middlewares := []func(http.Handler) http.Handler{}
637620

@@ -684,4 +667,61 @@ func TestNewListToolsMappingMiddleware_ConnectionHang(t *testing.T) {
684667
require.NoError(t, err)
685668
require.NotNil(t, response.Result)
686669
require.NotNil(t, response.Result.Tools)
670+
require.Len(t, *response.Result.Tools, 1)
671+
}
672+
673+
func TestJSONBufferFlushes(t *testing.T) {
674+
t.Parallel()
675+
middlewares := []func(http.Handler) http.Handler{}
676+
677+
opts := []ToolMiddlewareOption{
678+
WithToolsFilter("MyFoo"),
679+
WithToolsOverride("Foo", "MyFoo", ""),
680+
}
681+
682+
// Create the middleware
683+
toolsListmiddleware, err := NewListToolsMappingMiddleware(opts...)
684+
assert.NoError(t, err)
685+
toolsCallMiddleware, err := NewToolCallMappingMiddleware(opts...)
686+
assert.NoError(t, err)
687+
688+
middlewares = append(middlewares,
689+
toolsCallMiddleware,
690+
toolsListmiddleware,
691+
)
692+
693+
// Create test server
694+
serverOpts := []testkit.TestMCPServerOption{
695+
testkit.WithJSONClientType(),
696+
testkit.WithConnectionHang(10 * time.Second),
697+
testkit.WithMiddlewares(middlewares...),
698+
testkit.WithWithProxy(),
699+
testkit.WithTool("Foo", "Foo tool", func() string { return "Foo" }),
700+
}
701+
702+
for i := 0; i < 100; i++ {
703+
opt := testkit.WithTool(
704+
fmt.Sprintf("Foo%d", i),
705+
strings.Repeat("A", 10*1024),
706+
func() string { return fmt.Sprintf("Foo%d", i) },
707+
)
708+
serverOpts = append(serverOpts, opt)
709+
}
710+
711+
server, client, err := testkit.NewStreamableTestServer(
712+
serverOpts...,
713+
)
714+
require.NoError(t, err)
715+
defer server.Close()
716+
717+
// Make request
718+
respBody, err := client.ToolsList()
719+
require.NoError(t, err)
720+
721+
var response toolsListResponse
722+
err = json.NewDecoder(bytes.NewReader(respBody)).Decode(&response)
723+
require.NoError(t, err)
724+
require.NotNil(t, response.Result)
725+
require.NotNil(t, response.Result.Tools)
726+
require.Len(t, *response.Result.Tools, 1)
687727
}

test/testkit/sse_server.go

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ import (
99
"io"
1010
"net/http"
1111
"net/http/httptest"
12-
"net/http/httputil"
13-
"net/url"
1412
"time"
1513

1614
"github.com/go-chi/chi/v5"
@@ -225,28 +223,10 @@ func NewSSETestServer(
225223
// If the server is configured to use a proxy,create a reverse proxy to
226224
// the backend test server.
227225
if server.withProxy {
228-
backendURL, err := url.Parse(backendServer.URL)
226+
proxyServer, err := wrapBackendWithProxy(backendServer.URL, allMiddlewares)
229227
if err != nil {
230-
return nil, nil, fmt.Errorf("failed to parse backend URL: %w", err)
228+
return nil, nil, fmt.Errorf("failed to wrap backend with proxy: %w", err)
231229
}
232-
233-
// Create a reverse proxy to the backend test server.
234-
// Ideally, this would use ToolHive reverse proxy, but
235-
// it is too tightly coupled with containers and needs
236-
// to be refactored.
237-
proxy := httputil.NewSingleHostReverseProxy(backendURL)
238-
proxy.FlushInterval = -1
239-
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
240-
proxy.ServeHTTP(w, r)
241-
})
242-
243-
// Apply middleware chain in reverse order (last middleware is applied first)
244-
var finalHandler http.Handler = handler
245-
for _, mw := range allMiddlewares {
246-
finalHandler = mw(finalHandler)
247-
}
248-
249-
proxyServer := httptest.NewServer(finalHandler)
250230
testServer = proxyServer
251231
}
252232

@@ -319,7 +299,7 @@ func (s *sseServer) sseHandler(w http.ResponseWriter, _ *http.Request) {
319299
w.Header().Set("Content-Type", "text/event-stream")
320300

321301
// Get flusher for streaming responses
322-
flusher, ok := w.(http.Flusher)
302+
_, ok := w.(http.Flusher)
323303
if !ok {
324304
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
325305
return
@@ -356,18 +336,12 @@ func (s *sseServer) sseHandler(w http.ResponseWriter, _ *http.Request) {
356336
response = "failed to generate response"
357337
}
358338

359-
if _, err := w.Write([]byte("event: random-stuff\ndata: " + response + "\n\n")); err != nil {
360-
http.Error(w, "Error writing response", http.StatusInternalServerError)
361-
return
339+
if s.connHangDuration == 0 {
340+
singleFlushResponse([]byte("event: random-stuff\ndata: "+response+"\n\n"), w)
341+
} else {
342+
staggeredFlushResponse([]byte("event: random-stuff\ndata: "+response+"\n\n"), w, s.connHangDuration)
362343
}
363344
}
364-
365-
// Flush the response immediately
366-
flusher.Flush()
367-
368-
if s.connHangDuration != 0 {
369-
time.Sleep(s.connHangDuration)
370-
}
371345
}
372346
}
373347
}

test/testkit/streamable_server.go

Lines changed: 25 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ import (
99
"io"
1010
"net/http"
1111
"net/http/httptest"
12-
"net/http/httputil"
13-
"net/url"
1412
"time"
1513

1614
"github.com/go-chi/chi/v5"
@@ -202,18 +200,19 @@ func NewStreamableTestServer(
202200
// This precompiles the tools list response based on the provided tools
203201
server.toolsListResponse = makeToolsList(server.tools)
204202

203+
allMiddlewares := append(
204+
[]func(http.Handler) http.Handler{
205+
middleware.RequestID,
206+
middleware.Recoverer,
207+
},
208+
server.middlewares...,
209+
)
210+
205211
router := chi.NewRouter()
206212

207213
// If the server is not configured to use a proxy, apply the middlewares to
208214
// the router directly.
209215
if !server.withProxy {
210-
allMiddlewares := append(
211-
[]func(http.Handler) http.Handler{
212-
middleware.RequestID,
213-
middleware.Recoverer,
214-
},
215-
server.middlewares...,
216-
)
217216
router.Use(allMiddlewares...)
218217
}
219218

@@ -229,35 +228,10 @@ func NewStreamableTestServer(
229228
// If the server is configured to use a proxy,create a reverse proxy to
230229
// the backend test server.
231230
if server.withProxy {
232-
backendURL, err := url.Parse(backendServer.URL)
231+
proxyServer, err := wrapBackendWithProxy(backendServer.URL, allMiddlewares)
233232
if err != nil {
234-
return nil, nil, fmt.Errorf("failed to parse backend URL: %w", err)
235-
}
236-
237-
// Create a reverse proxy to the backend test server.
238-
// Ideally, this would use ToolHive reverse proxy, but
239-
// it is too tightly coupled with containers and needs
240-
// to be refactored.
241-
proxy := httputil.NewSingleHostReverseProxy(backendURL)
242-
proxy.FlushInterval = -1
243-
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
244-
proxy.ServeHTTP(w, r)
245-
})
246-
247-
// Apply middleware chain in reverse order (last middleware is applied first)
248-
allMiddlewares := append(
249-
[]func(http.Handler) http.Handler{
250-
middleware.RequestID,
251-
middleware.Recoverer,
252-
},
253-
server.middlewares...,
254-
)
255-
var finalHandler http.Handler = handler
256-
for _, mw := range allMiddlewares {
257-
finalHandler = mw(finalHandler)
233+
return nil, nil, fmt.Errorf("failed to wrap backend with proxy: %w", err)
258234
}
259-
260-
proxyServer := httptest.NewServer(finalHandler)
261235
testServer = proxyServer
262236
}
263237

@@ -277,7 +251,10 @@ func NewStreamableTestServer(
277251
}
278252
}
279253

280-
func (s *streamableServer) mcpJSONHandler(w http.ResponseWriter, r *http.Request) {
254+
func (s *streamableServer) mcpJSONHandler(
255+
w http.ResponseWriter,
256+
r *http.Request,
257+
) {
281258
// Read the request body
282259
body, err := io.ReadAll(r.Body)
283260
if err != nil {
@@ -319,23 +296,18 @@ func (s *streamableServer) mcpJSONHandler(w http.ResponseWriter, r *http.Request
319296

320297
w.Header().Set("Content-Type", "application/json")
321298
w.WriteHeader(http.StatusOK)
322-
if _, err := w.Write([]byte(response)); err != nil {
323-
http.Error(w, "Error writing response", http.StatusInternalServerError)
324-
return
325-
}
326299

327-
// Flush if available
328-
if flusher, ok := w.(http.Flusher); ok {
329-
flusher.Flush()
330-
}
331-
332-
if s.connHangDuration != 0 {
333-
time.Sleep(s.connHangDuration)
300+
if s.connHangDuration == 0 {
301+
singleFlushResponse([]byte(response), w)
302+
} else {
303+
staggeredFlushResponse([]byte(response), w, s.connHangDuration)
334304
}
335305
}
336306

337-
//nolint:gocyclo
338-
func (s *streamableServer) mcpEventStreamHandler(w http.ResponseWriter, r *http.Request) {
307+
func (s *streamableServer) mcpEventStreamHandler(
308+
w http.ResponseWriter,
309+
r *http.Request,
310+
) {
339311
// Read the request body
340312
body, err := io.ReadAll(r.Body)
341313
if err != nil {
@@ -378,41 +350,11 @@ func (s *streamableServer) mcpEventStreamHandler(w http.ResponseWriter, r *http.
378350
response = "failed to generate response"
379351
}
380352

381-
if _, err := w.Write([]byte("event: random-stuff\ndata: " + response)); err != nil {
382-
http.Error(w, "Error writing response", http.StatusInternalServerError)
383-
return
384-
}
353+
response = "event: random-stuff\ndata: " + response + "\n\n"
385354

386355
if s.connHangDuration == 0 {
387-
_, err := w.Write([]byte("\n\n"))
388-
if err != nil {
389-
http.Error(w, "Error writing response", http.StatusInternalServerError)
390-
return
391-
}
392-
393-
// Flush if available
394-
if flusher, ok := w.(http.Flusher); ok {
395-
flusher.Flush()
396-
}
356+
singleFlushResponse([]byte(response), w)
397357
} else {
398-
// Flush if available
399-
if flusher, ok := w.(http.Flusher); ok {
400-
flusher.Flush()
401-
}
402-
403-
if s.connHangDuration != 0 {
404-
time.Sleep(s.connHangDuration)
405-
}
406-
407-
_, err := w.Write([]byte("\n\n"))
408-
if err != nil {
409-
http.Error(w, "Error writing response", http.StatusInternalServerError)
410-
return
411-
}
412-
413-
// Flush if available
414-
if flusher, ok := w.(http.Flusher); ok {
415-
flusher.Flush()
416-
}
358+
staggeredFlushResponse([]byte(response), w, s.connHangDuration)
417359
}
418360
}

0 commit comments

Comments
 (0)