Skip to content

Commit d256a9c

Browse files
authored
mcp: don't persist failed streamable sessions (#605)
If a stateful streamable session fails to initialize, it is unusable. Avoid allocating resources by closing the session immediately. Also: - Fix a bug where InitializeParams is not guarded with its mutex. - Fix a resource leak where sessions are persisted with session id "", even though they are unaddressable. - Add a relevant benchmark, and update tests. Fixes #578
1 parent ae6bda6 commit d256a9c

File tree

4 files changed

+113
-7
lines changed

4 files changed

+113
-7
lines changed

mcp/server.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1145,7 +1145,13 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any,
11451145
return handleReceive(ctx, ss, req)
11461146
}
11471147

1148-
func (ss *ServerSession) InitializeParams() *InitializeParams { return ss.state.InitializeParams }
1148+
// InitializeParams returns the InitializeParams provided during the client's
1149+
// initial connection.
1150+
func (ss *ServerSession) InitializeParams() *InitializeParams {
1151+
ss.mu.Lock()
1152+
defer ss.mu.Unlock()
1153+
return ss.state.InitializeParams
1154+
}
11491155

11501156
func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) {
11511157
if params == nil {

mcp/streamable.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,10 +324,13 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
324324
logger: h.opts.Logger,
325325
}
326326

327+
// Sessions without a session ID are also stateless: there's no way to
328+
// address them.
329+
stateless := h.opts.Stateless || sessionID == ""
327330
// To support stateless mode, we initialize the session with a default
328331
// state, so that it doesn't reject subsequent requests.
329332
var connectOpts *ServerSessionOptions
330-
if h.opts.Stateless {
333+
if stateless {
331334
// Peek at the body to see if it is initialize or initialized.
332335
// We want those to be handled as usual.
333336
var hasInitialize, hasInitialized bool
@@ -405,7 +408,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
405408
transport: transport,
406409
}
407410

408-
if h.opts.Stateless {
411+
if stateless {
409412
// Stateless mode: close the session when the request exits.
410413
defer session.Close() // close the fake session after handling the request
411414
} else {
@@ -424,6 +427,13 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
424427
h.mu.Lock()
425428
h.sessions[transport.SessionID] = sessInfo
426429
h.mu.Unlock()
430+
defer func() {
431+
// If initialization failed, clean up the session (#578).
432+
if session.InitializeParams() == nil {
433+
// Initialization failed.
434+
session.Close()
435+
}
436+
}()
427437
}
428438
}
429439

mcp/streamable_bench_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,15 @@ package mcp_test
66

77
import (
88
"context"
9+
"flag"
10+
"log"
911
"net/http"
1012
"net/http/httptest"
13+
"os"
1114
"reflect"
15+
"runtime"
16+
"runtime/pprof"
17+
"strings"
1218
"testing"
1319

1420
"github.com/google/jsonschema-go/jsonschema"
@@ -65,3 +71,62 @@ func BenchmarkStreamableServing(b *testing.B) {
6571
}
6672
}
6773
}
74+
75+
var streamableHeap = flag.String("streamable_heap", "", "if set, write streamable heap profiles with this prefix")
76+
77+
func BenchmarkStreamableServing_BadSessions(b *testing.B) {
78+
server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil)
79+
80+
handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server {
81+
return server
82+
}, &mcp.StreamableHTTPOptions{JSONResponse: true})
83+
84+
httpServer := httptest.NewServer(handler)
85+
defer httpServer.Close()
86+
87+
ctx, cancel := context.WithCancel(context.Background())
88+
defer cancel()
89+
90+
if *streamableHeap != "" {
91+
writeHeap := func(file string) {
92+
// GC a couple times to ensure accurate heap.
93+
runtime.GC()
94+
runtime.GC()
95+
f, err := os.Create(file)
96+
if err != nil {
97+
log.Fatal("could not create memory profile: ", err)
98+
}
99+
defer func() {
100+
if err := f.Close(); err != nil {
101+
b.Errorf("writing heap file %q: %v", file, err)
102+
}
103+
}()
104+
if err := pprof.Lookup("heap").WriteTo(f, 0); err != nil {
105+
b.Errorf("could not write heap profile: %v", err)
106+
}
107+
}
108+
writeHeap(*streamableHeap + ".before")
109+
defer writeHeap(*streamableHeap + ".after")
110+
}
111+
112+
b.ResetTimer()
113+
for range b.N {
114+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, httpServer.URL, strings.NewReader("{}"))
115+
if err != nil {
116+
b.Fatal(err)
117+
}
118+
req.Header.Add("Accept", "application/json")
119+
req.Header.Add("Accept", "text/event-stream")
120+
resp, err := http.DefaultClient.Do(req)
121+
if err != nil {
122+
b.Fatal(err)
123+
}
124+
if got, want := resp.StatusCode, http.StatusBadRequest; got != want {
125+
b.Fatalf("POST got status %d, want %d", got, want)
126+
}
127+
if got := resp.Header.Get("Mcp-Session-Id"); got != "" {
128+
b.Fatalf("POST got unexpected session ID")
129+
}
130+
resp.Body.Close()
131+
}
132+
}

mcp/streamable_test.go

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -690,10 +690,11 @@ func TestStreamableServerTransport(t *testing.T) {
690690
}
691691

692692
tests := []struct {
693-
name string
694-
replay bool // if set, use a MemoryEventStore to enable stream replay
695-
tool func(*testing.T, context.Context, *ServerSession)
696-
requests []streamableRequest // http requests
693+
name string
694+
replay bool // if set, use a MemoryEventStore to enable stream replay
695+
tool func(*testing.T, context.Context, *ServerSession)
696+
requests []streamableRequest // http requests
697+
wantSessions int // number of sessions expected after the test
697698
}{
698699
{
699700
name: "basic",
@@ -707,6 +708,19 @@ func TestStreamableServerTransport(t *testing.T) {
707708
wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)},
708709
},
709710
},
711+
wantSessions: 1,
712+
},
713+
{
714+
name: "uninitialized",
715+
requests: []streamableRequest{
716+
{
717+
method: "POST",
718+
messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})},
719+
wantStatusCode: http.StatusOK,
720+
wantBodyContaining: "invalid during session initialization",
721+
},
722+
},
723+
wantSessions: 0,
710724
},
711725
{
712726
name: "accept headers",
@@ -741,6 +755,7 @@ func TestStreamableServerTransport(t *testing.T) {
741755
wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{Content: []Content{}}, nil)},
742756
},
743757
},
758+
wantSessions: 1,
744759
},
745760
{
746761
name: "protocol version headers",
@@ -756,6 +771,7 @@ func TestStreamableServerTransport(t *testing.T) {
756771
wantSessionID: false, // could be true, but shouldn't matter
757772
},
758773
},
774+
wantSessions: 1,
759775
},
760776
{
761777
name: "batch rejected on 2025-06-18",
@@ -775,6 +791,7 @@ func TestStreamableServerTransport(t *testing.T) {
775791
wantBodyContaining: "batch",
776792
},
777793
},
794+
wantSessions: 1,
778795
},
779796
{
780797
name: "batch accepted on 2025-03-26",
@@ -797,6 +814,7 @@ func TestStreamableServerTransport(t *testing.T) {
797814
},
798815
},
799816
},
817+
wantSessions: 1,
800818
},
801819
{
802820
name: "tool notification",
@@ -821,6 +839,7 @@ func TestStreamableServerTransport(t *testing.T) {
821839
},
822840
},
823841
},
842+
wantSessions: 1,
824843
},
825844
{
826845
name: "tool upcall",
@@ -853,6 +872,7 @@ func TestStreamableServerTransport(t *testing.T) {
853872
},
854873
},
855874
},
875+
wantSessions: 1,
856876
},
857877
{
858878
name: "background",
@@ -915,6 +935,7 @@ func TestStreamableServerTransport(t *testing.T) {
915935
headers: map[string][]string{"Accept": nil},
916936
},
917937
},
938+
wantSessions: 0, // session deleted
918939
},
919940
{
920941
name: "errors",
@@ -946,6 +967,7 @@ func TestStreamableServerTransport(t *testing.T) {
946967
})},
947968
},
948969
},
970+
wantSessions: 0,
949971
},
950972
}
951973

@@ -972,6 +994,9 @@ func TestStreamableServerTransport(t *testing.T) {
972994
defer handler.closeAll()
973995

974996
testStreamableHandler(t, handler, test.requests)
997+
if got := len(slices.Collect(server.Sessions())); got != test.wantSessions {
998+
t.Errorf("after test, got %d sessions, want %d", got, test.wantSessions)
999+
}
9751000
})
9761001
}
9771002
}

0 commit comments

Comments
 (0)