Skip to content

Commit 1e53cc8

Browse files
committed
Add tests for session ID tagging
1 parent 0507b87 commit 1e53cc8

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

contrib/mark3labs/mcp-go/mcpgo.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ func NewToolHandlerMiddleware() server.ToolHandlerMiddleware {
5959
outputText = string(resultJSON)
6060
}
6161

62+
tagWithSessionID(ctx, toolSpan)
63+
6264
toolSpan.AnnotateTextIO(string(inputJSON), outputText)
6365

6466
if err != nil {
@@ -81,6 +83,7 @@ func newHooks() *hooks {
8183
func (h *hooks) onBeforeInitialize(ctx context.Context, id any, request *mcp.InitializeRequest) {
8284
taskSpan, _ := llmobs.StartTaskSpan(ctx, "mcp.initialize", llmobs.WithIntegration("mark3labs/mcp-go"))
8385
h.spanCache.Store(id, taskSpan)
86+
tagWithSessionID(ctx, taskSpan)
8487
}
8588

8689
func (h *hooks) onAfterInitialize(ctx context.Context, id any, request *mcp.InitializeRequest, result *mcp.InitializeResult) {
@@ -109,6 +112,14 @@ func (h *hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, messa
109112
span.Finish(llmobs.WithError(err))
110113
}
111114

115+
func tagWithSessionID(ctx context.Context, span llmobs.Span) {
116+
session := server.ClientSessionFromContext(ctx)
117+
if session != nil {
118+
sessionID := session.SessionID()
119+
span.Annotate(llmobs.WithAnnotatedTags(map[string]string{"mcp_session_id": sessionID}))
120+
}
121+
}
122+
112123
func finishSpanWithIO[Req any, Res any](h *hooks, id any, request Req, result Res) {
113124
value, ok := h.spanCache.LoadAndDelete(id)
114125
if !ok {

contrib/mark3labs/mcp-go/mcpgo_test.go

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ func TestIntegrationSessionInitialize(t *testing.T) {
5454
server.WithHooks(hooks))
5555

5656
ctx := context.Background()
57+
sessionID := "test-session-init"
58+
session := &mockSession{id: sessionID}
59+
session.Initialize()
60+
ctx = srv.WithContext(ctx, session)
61+
5762
initRequest := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}}`
5863

5964
response := srv.HandleMessage(ctx, []byte(initRequest))
@@ -76,6 +81,8 @@ func TestIntegrationSessionInitialize(t *testing.T) {
7681
assert.Equal(t, "mcp.initialize", taskSpan.Name)
7782
assert.Equal(t, "task", taskSpan.Meta["span.kind"])
7883

84+
assert.Contains(t, taskSpan.Tags, "mcp_session_id:test-session-init")
85+
7986
assert.Contains(t, taskSpan.Meta, "input")
8087
assert.Contains(t, taskSpan.Meta, "output")
8188

@@ -99,7 +106,11 @@ func TestIntegrationToolCallSuccess(t *testing.T) {
99106
tt := testTracer(t)
100107
defer tt.Stop()
101108

109+
hooks := &server.Hooks{}
110+
AddServerHooks(hooks)
111+
102112
srv := server.NewMCPServer("test-server", "1.0.0",
113+
server.WithHooks(hooks),
103114
server.WithToolHandlerMiddleware(NewToolHandlerMiddleware()))
104115

105116
calcTool := mcp.NewTool("calculator",
@@ -131,9 +142,13 @@ func TestIntegrationToolCallSuccess(t *testing.T) {
131142
session.Initialize()
132143
ctx = srv.WithContext(ctx, session)
133144

145+
initRequest := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}}`
146+
response := srv.HandleMessage(ctx, []byte(initRequest))
147+
assert.NotNil(t, response)
148+
134149
toolCallRequest := `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"calculator","arguments":{"operation":"add","x":5,"y":3}}}`
135150

136-
response := srv.HandleMessage(ctx, []byte(toolCallRequest))
151+
response = srv.HandleMessage(ctx, []byte(toolCallRequest))
137152
assert.NotNil(t, response)
138153

139154
responseBytes, err := json.Marshal(response)
@@ -145,10 +160,25 @@ func TestIntegrationToolCallSuccess(t *testing.T) {
145160
assert.Equal(t, "2.0", resp["jsonrpc"])
146161
assert.NotNil(t, resp["result"])
147162

148-
spans := tt.WaitForLLMObsSpans(t, 1)
149-
require.Len(t, spans, 1)
163+
spans := tt.WaitForLLMObsSpans(t, 2)
164+
require.Len(t, spans, 2)
165+
166+
var initSpan, toolSpan *testtracer.LLMObsSpan
167+
for i := range spans {
168+
if spans[i].Name == "mcp.initialize" {
169+
initSpan = &spans[i]
170+
} else if spans[i].Name == "calculator" {
171+
toolSpan = &spans[i]
172+
}
173+
}
174+
175+
require.NotNil(t, initSpan, "initialize span not found")
176+
require.NotNil(t, toolSpan, "tool span not found")
177+
178+
expectedTag := "mcp_session_id:test-session-123"
179+
assert.Contains(t, initSpan.Tags, expectedTag)
180+
assert.Contains(t, toolSpan.Tags, expectedTag)
150181

151-
toolSpan := spans[0]
152182
assert.Equal(t, "calculator", toolSpan.Name)
153183
assert.Equal(t, "tool", toolSpan.Meta["span.kind"])
154184

@@ -215,6 +245,8 @@ func TestIntegrationToolCallError(t *testing.T) {
215245
assert.Equal(t, "error_tool", toolSpan.Name)
216246
assert.Equal(t, "tool", toolSpan.Meta["span.kind"])
217247

248+
assert.Contains(t, toolSpan.Tags, "mcp_session_id:test-session-456")
249+
218250
assert.Contains(t, toolSpan.Meta, "error.message")
219251
assert.Contains(t, toolSpan.Meta["error.message"], "intentional test error")
220252
assert.Contains(t, toolSpan.Meta, "error.type")

0 commit comments

Comments
 (0)