Skip to content

Commit ec47a08

Browse files
committed
Add tests for session ID tagging
1 parent be367d4 commit ec47a08

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

contrib/mark3labs/mcp-go/mcpgo.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ func NewToolHandlerMiddleware() server.ToolHandlerMiddleware {
5959
outputText = string(resultJSON)
6060
}
6161

62+
tagWithSessionID(ctx, toolSpan)
63+
6264
toolSpan.AnnotateTextIO(string(inputJSON), outputText)
6365

6466
if err != nil {
@@ -86,6 +88,7 @@ func (h *hooks) onBeforeInitialize(ctx context.Context, id any, request *mcp.Ini
8688
taskSpan.Annotate(llmobs.WithAnnotatedTags(map[string]string{"client_name": clientName, "client_version": clientName + "_" + clientVersion}))
8789

8890
h.spanCache.Store(id, taskSpan)
91+
tagWithSessionID(ctx, taskSpan)
8992
}
9093

9194
func (h *hooks) onAfterInitialize(ctx context.Context, id any, request *mcp.InitializeRequest, result *mcp.InitializeResult) {
@@ -114,6 +117,14 @@ func (h *hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, messa
114117
span.Finish(llmobs.WithError(err))
115118
}
116119

120+
func tagWithSessionID(ctx context.Context, span llmobs.Span) {
121+
session := server.ClientSessionFromContext(ctx)
122+
if session != nil {
123+
sessionID := session.SessionID()
124+
span.Annotate(llmobs.WithAnnotatedTags(map[string]string{"mcp_session_id": sessionID}))
125+
}
126+
}
127+
117128
func finishSpanWithIO[Req any, Res any](h *hooks, id any, request Req, result Res) {
118129
value, ok := h.spanCache.LoadAndDelete(id)
119130
if !ok {

contrib/mark3labs/mcp-go/mcpgo_test.go

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ func TestIntegrationSessionInitialize(t *testing.T) {
5252
server.WithHooks(hooks))
5353

5454
ctx := context.Background()
55+
sessionID := "test-session-init"
56+
session := &mockSession{id: sessionID}
57+
session.Initialize()
58+
ctx = srv.WithContext(ctx, session)
59+
5560
initRequest := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}}`
5661

5762
response := srv.HandleMessage(ctx, []byte(initRequest))
@@ -77,6 +82,8 @@ func TestIntegrationSessionInitialize(t *testing.T) {
7782
assert.Contains(t, taskSpan.Tags, "client_name:test-client")
7883
assert.Contains(t, taskSpan.Tags, "client_version:test-client_1.0.0")
7984

85+
assert.Contains(t, taskSpan.Tags, "mcp_session_id:test-session-init")
86+
8087
assert.Contains(t, taskSpan.Meta, "input")
8188
assert.Contains(t, taskSpan.Meta, "output")
8289

@@ -101,7 +108,11 @@ func TestIntegrationToolCallSuccess(t *testing.T) {
101108
tt := testTracer(t)
102109
defer tt.Stop()
103110

111+
hooks := &server.Hooks{}
112+
AddServerHooks(hooks)
113+
104114
srv := server.NewMCPServer("test-server", "1.0.0",
115+
server.WithHooks(hooks),
105116
server.WithToolHandlerMiddleware(NewToolHandlerMiddleware()))
106117

107118
calcTool := mcp.NewTool("calculator",
@@ -133,9 +144,13 @@ func TestIntegrationToolCallSuccess(t *testing.T) {
133144
session.Initialize()
134145
ctx = srv.WithContext(ctx, session)
135146

147+
initRequest := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}}`
148+
response := srv.HandleMessage(ctx, []byte(initRequest))
149+
assert.NotNil(t, response)
150+
136151
toolCallRequest := `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"calculator","arguments":{"operation":"add","x":5,"y":3}}}`
137152

138-
response := srv.HandleMessage(ctx, []byte(toolCallRequest))
153+
response = srv.HandleMessage(ctx, []byte(toolCallRequest))
139154
assert.NotNil(t, response)
140155

141156
responseBytes, err := json.Marshal(response)
@@ -147,10 +162,25 @@ func TestIntegrationToolCallSuccess(t *testing.T) {
147162
assert.Equal(t, "2.0", resp["jsonrpc"])
148163
assert.NotNil(t, resp["result"])
149164

150-
spans := tt.WaitForLLMObsSpans(t, 1)
151-
require.Len(t, spans, 1)
165+
spans := tt.WaitForLLMObsSpans(t, 2)
166+
require.Len(t, spans, 2)
167+
168+
var initSpan, toolSpan *testtracer.LLMObsSpan
169+
for i := range spans {
170+
if spans[i].Name == "mcp.initialize" {
171+
initSpan = &spans[i]
172+
} else if spans[i].Name == "calculator" {
173+
toolSpan = &spans[i]
174+
}
175+
}
176+
177+
require.NotNil(t, initSpan, "initialize span not found")
178+
require.NotNil(t, toolSpan, "tool span not found")
179+
180+
expectedTag := "mcp_session_id:test-session-123"
181+
assert.Contains(t, initSpan.Tags, expectedTag)
182+
assert.Contains(t, toolSpan.Tags, expectedTag)
152183

153-
toolSpan := spans[0]
154184
assert.Equal(t, "calculator", toolSpan.Name)
155185
assert.Equal(t, "tool", toolSpan.Meta["span.kind"])
156186

@@ -218,6 +248,8 @@ func TestIntegrationToolCallError(t *testing.T) {
218248
assert.Equal(t, "error_tool", toolSpan.Name)
219249
assert.Equal(t, "tool", toolSpan.Meta["span.kind"])
220250

251+
assert.Contains(t, toolSpan.Tags, "mcp_session_id:test-session-456")
252+
221253
assert.Contains(t, toolSpan.Meta, "error.message")
222254
assert.Contains(t, toolSpan.Meta["error.message"], "intentional test error")
223255
assert.Contains(t, toolSpan.Meta, "error.type")

0 commit comments

Comments
 (0)