Skip to content

Commit 39fc229

Browse files
committed
Add protocol version to initialize response, only enable output schemas on supported protocols
1 parent a960ddf commit 39fc229

File tree

3 files changed

+168
-9
lines changed

3 files changed

+168
-9
lines changed

mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServer.java

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ public final class McpServer implements Server {
8484
private final CountDownLatch done = new CountDownLatch(1);
8585
private final AtomicReference<JsonRpcRequest> initializeRequest = new AtomicReference<>();
8686
private final ToolFilter toolFilter;
87+
private volatile ProtocolVersion protocolVersion;
8788

8889
McpServer(McpServerBuilder builder) {
8990
this.services = builder.services;
@@ -127,9 +128,19 @@ private void handleRequest(JsonRpcRequest req) {
127128
switch (req.getMethod()) {
128129
case "initialize" -> {
129130
this.initializeRequest.set(req);
131+
var maybeVersion = req.getParams().getMember("protocolVersion");
132+
String pv = null;
133+
if (maybeVersion != null) {
134+
protocolVersion = ProtocolVersion.version(maybeVersion.asString());
135+
if (!(protocolVersion instanceof ProtocolVersion.UnknownVersion)) {
136+
pv = protocolVersion.identifier();
137+
}
138+
}
139+
130140
proxies.values().forEach(this::initialize);
131141
writeResponse(req.getId(),
132142
InitializeResult.builder()
143+
.protocolVersion(pv)
133144
.capabilities(Capabilities.builder()
134145
.tools(Tools.builder().listChanged(true).build())
135146
.prompts(Prompts.builder().listChanged(true).build())
@@ -158,14 +169,18 @@ private void handleRequest(JsonRpcRequest req) {
158169
var result = promptProcessor.buildPromptResult(prompt, promptArguments);
159170
writeResponse(req.getId(), result);
160171
}
161-
case "tools/list" -> writeResponse(req.getId(),
162-
ListToolsResult.builder()
163-
.tools(tools.values()
164-
.stream()
165-
.filter(t -> toolFilter.allowTool(t.serverId(), t.toolInfo().getName()))
166-
.map(Tool::toolInfo)
167-
.toList())
168-
.build());
172+
case "tools/list" -> {
173+
boolean supportsOutputSchema = protocolVersion != null
174+
&& protocolVersion.compareTo(ProtocolVersion.v2025_06_18.INSTANCE) >= 0;
175+
writeResponse(req.getId(),
176+
ListToolsResult.builder()
177+
.tools(tools.values()
178+
.stream()
179+
.filter(t -> toolFilter.allowTool(t.serverId(), t.toolInfo().getName()))
180+
.map(tool -> extractToolInfo(tool, supportsOutputSchema))
181+
.toList())
182+
.build());
183+
}
169184
case "tools/call" -> {
170185
var operationName = req.getParams().getMember("name").asString();
171186
var tool = tools.get(operationName);
@@ -225,6 +240,16 @@ private void handleRequest(JsonRpcRequest req) {
225240
}
226241
}
227242

243+
private ToolInfo extractToolInfo(Tool tool, boolean supportsOutput) {
244+
var toolInfo = tool.toolInfo();
245+
if (supportsOutput || toolInfo.getOutputSchema() == null) {
246+
return toolInfo;
247+
}
248+
return toolInfo.toBuilder()
249+
.outputSchema(null)
250+
.build();
251+
}
252+
228253
private void validate(JsonRpcRequest req) {
229254
Document id = req.getId();
230255
boolean isRequest = !req.getMethod().startsWith("notifications/");
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package software.amazon.smithy.java.mcp.server;
7+
8+
public abstract sealed class ProtocolVersion implements Comparable<ProtocolVersion>
9+
permits ProtocolVersion.UnknownVersion, ProtocolVersion.v2024_11_05, ProtocolVersion.v2025_03_26,
10+
ProtocolVersion.v2025_06_18 {
11+
public static final class v2025_06_18 extends ProtocolVersion {
12+
public static final v2025_06_18 INSTANCE = new v2025_06_18();
13+
14+
private v2025_06_18() {
15+
super("2025-06-18");
16+
}
17+
}
18+
19+
public static final class v2025_03_26 extends ProtocolVersion {
20+
public static final v2025_03_26 INSTANCE = new v2025_03_26();
21+
22+
private v2025_03_26() {
23+
super("2025-03-26");
24+
}
25+
}
26+
27+
public static final class v2024_11_05 extends ProtocolVersion {
28+
public static final v2024_11_05 INSTANCE = new v2024_11_05();
29+
30+
private v2024_11_05() {
31+
super("2024-11-05");
32+
}
33+
}
34+
35+
public static final class UnknownVersion extends ProtocolVersion {
36+
private UnknownVersion(String identifier) {
37+
super(identifier);
38+
}
39+
}
40+
41+
private final String identifier;
42+
43+
private ProtocolVersion(String identifier) {
44+
this.identifier = identifier;
45+
}
46+
47+
public String identifier() {
48+
return identifier;
49+
}
50+
51+
@Override
52+
public final int compareTo(ProtocolVersion o) {
53+
if (o instanceof UnknownVersion) {
54+
if (this instanceof UnknownVersion) {
55+
return 0;
56+
}
57+
return 1;
58+
}
59+
60+
return identifier.compareTo(o.identifier);
61+
}
62+
63+
public static ProtocolVersion version(String identifier) {
64+
return switch (identifier) {
65+
case "2024-11-05" -> v2024_11_05.INSTANCE;
66+
case "2025-03-26" -> v2025_03_26.INSTANCE;
67+
case "2025-06-18" -> v2025_06_18.INSTANCE;
68+
default -> new UnknownVersion(identifier);
69+
};
70+
}
71+
}

mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88
import static org.junit.jupiter.api.Assertions.assertEquals;
99
import static org.junit.jupiter.api.Assertions.assertFalse;
1010
import static org.junit.jupiter.api.Assertions.assertNotNull;
11+
import static org.junit.jupiter.api.Assertions.assertNull;
12+
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
1113
import static org.junit.jupiter.api.Assertions.assertTrue;
1214

1315
import java.math.BigDecimal;
1416
import java.math.BigInteger;
1517
import java.nio.charset.StandardCharsets;
18+
import java.time.Duration;
1619
import java.util.List;
1720
import java.util.Map;
1821
import java.util.concurrent.atomic.AtomicReference;
@@ -59,9 +62,48 @@ public void afterEach() {
5962
}
6063
}
6164

65+
private void initializeWithProtocolVersion(ProtocolVersion protocolVersion) {
66+
write("initialize", Document.of(Map.of("protocolVersion", Document.of(protocolVersion.identifier()))));
67+
var pv = read().getResult().getMember("protocolVersion").asString();
68+
assertEquals(protocolVersion.identifier(), pv);
69+
}
70+
71+
@Test
72+
public void noOutputSchemaWithUnsupportedProtocolVersion() {
73+
server = McpServer.builder()
74+
.name("smithy-mcp-server")
75+
.input(input)
76+
.output(output)
77+
.addService("test-mcp",
78+
ProxyService.builder()
79+
.service(ShapeId.from("smithy.test#TestService"))
80+
.proxyEndpoint("http://localhost")
81+
.model(MODEL)
82+
.build())
83+
.build();
84+
85+
server.start();
86+
87+
initializeWithProtocolVersion(ProtocolVersion.v2025_03_26.INSTANCE);
88+
write("tools/list", Document.of(Map.of()));
89+
var response = read();
90+
var tools = response.getResult().asStringMap().get("tools").asList();
91+
92+
var tool = tools.stream()
93+
.filter(t -> t.asStringMap().get("name").asString().equals("NoInputOperation"))
94+
.findFirst()
95+
.orElseThrow()
96+
.asStringMap();
97+
98+
assertEquals("NoInputOperation", tool.get("name").asString());
99+
assertNotNull(tool.get("inputSchema"));
100+
assertNull(tool.get("outputSchema"));
101+
}
102+
62103
@Test
63104
public void validateToolsList() {
64105
server = McpServer.builder()
106+
.name("smithy-mcp-server")
65107
.input(input)
66108
.output(output)
67109
.addService("test-mcp",
@@ -74,6 +116,7 @@ public void validateToolsList() {
74116

75117
server.start();
76118

119+
initializeWithProtocolVersion(ProtocolVersion.v2025_06_18.INSTANCE);
77120
write("tools/list", Document.of(Map.of()));
78121
var response = read();
79122
var result = response.getResult().asStringMap();
@@ -94,6 +137,7 @@ public void validateToolsList() {
94137
@Test
95138
public void validateNoIOOperationTool() {
96139
server = McpServer.builder()
140+
.name("smithy-mcp-server")
97141
.input(input)
98142
.output(output)
99143
.addService("test-mcp",
@@ -106,6 +150,7 @@ public void validateNoIOOperationTool() {
106150

107151
server.start();
108152

153+
initializeWithProtocolVersion(ProtocolVersion.v2025_06_18.INSTANCE);
109154
write("tools/list", Document.of(Map.of()));
110155
var response = read();
111156
var tools = response.getResult().asStringMap().get("tools").asList();
@@ -126,6 +171,7 @@ public void validateNoIOOperationTool() {
126171
@Test
127172
public void validateNoOutputOperationTool() {
128173
server = McpServer.builder()
174+
.name("smithy-mcp-server")
129175
.input(input)
130176
.output(output)
131177
.addService("test-mcp",
@@ -138,6 +184,7 @@ public void validateNoOutputOperationTool() {
138184

139185
server.start();
140186

187+
initializeWithProtocolVersion(ProtocolVersion.v2025_06_18.INSTANCE);
141188
write("tools/list", Document.of(Map.of()));
142189
var response = read();
143190
var tools = response.getResult().asStringMap().get("tools").asList();
@@ -167,6 +214,7 @@ public void validateNoOutputOperationTool() {
167214
@Test
168215
public void validateNoInputOperationTool() {
169216
server = McpServer.builder()
217+
.name("smithy-mcp-server")
170218
.input(input)
171219
.output(output)
172220
.addService("test-mcp",
@@ -179,6 +227,7 @@ public void validateNoInputOperationTool() {
179227

180228
server.start();
181229

230+
initializeWithProtocolVersion(ProtocolVersion.v2025_06_18.INSTANCE);
182231
write("tools/list", Document.of(Map.of()));
183232
var response = read();
184233
var tools = response.getResult().asStringMap().get("tools").asList();
@@ -208,6 +257,7 @@ public void validateNoInputOperationTool() {
208257
@Test
209258
public void validateTestOperationTool() {
210259
server = McpServer.builder()
260+
.name("smithy-mcp-server")
211261
.input(input)
212262
.output(output)
213263
.addService("test-mcp",
@@ -220,6 +270,7 @@ public void validateTestOperationTool() {
220270

221271
server.start();
222272

273+
initializeWithProtocolVersion(ProtocolVersion.v2025_06_18.INSTANCE);
223274
write("tools/list", Document.of(Map.of()));
224275
var response = read();
225276
var tools = response.getResult().asStringMap().get("tools").asList();
@@ -240,6 +291,7 @@ public void validateTestOperationTool() {
240291
@Test
241292
void testNumberAndStringIds() {
242293
server = McpServer.builder()
294+
.name("smithy-mcp-server")
243295
.input(input)
244296
.output(output)
245297
.addService("test-mcp",
@@ -285,6 +337,7 @@ void testNumberAndStringIds() {
285337
@Test
286338
void testInvalidIds() {
287339
server = McpServer.builder()
340+
.name("smithy-mcp-server")
288341
.input(input)
289342
.output(output)
290343
.addService("test-mcp",
@@ -327,6 +380,7 @@ void testInvalidIds() {
327380
@Test
328381
void testRequestsRequireIds() {
329382
server = McpServer.builder()
383+
.name("smithy-mcp-server")
330384
.input(input)
331385
.output(output)
332386
.addService("test-mcp",
@@ -350,6 +404,7 @@ void testRequestsRequireIds() {
350404
void testInputAdaptation() {
351405
AtomicReference<StructDocument> capturedInput = new AtomicReference<>();
352406
server = McpServer.builder()
407+
.name("smithy-mcp-server")
353408
.input(input)
354409
.output(output)
355410
.addService("test-mcp",
@@ -456,6 +511,7 @@ public void readBeforeSerialization(InputHook<?, ?> hook) {
456511
@Test
457512
void testNotificationsDoNotRequireRequestId() {
458513
server = McpServer.builder()
514+
.name("smithy-mcp-server")
459515
.input(input)
460516
.output(output)
461517
.addService("test-mcp",
@@ -485,6 +541,7 @@ void testNotificationsDoNotRequireRequestId() {
485541
@Test
486542
void testPromptsList() {
487543
server = McpServer.builder()
544+
.name("smithy-mcp-server")
488545
.input(input)
489546
.output(output)
490547
.addService("test-mcp",
@@ -528,6 +585,7 @@ void testPromptsList() {
528585
@Test
529586
void testPromptsGetWithValidPrompt() {
530587
server = McpServer.builder()
588+
.name("smithy-mcp-server")
531589
.input(input)
532590
.output(output)
533591
.addService("test-mcp",
@@ -561,6 +619,7 @@ void testPromptsGetWithValidPrompt() {
561619
@Test
562620
void testPromptsGetWithDifferentCasing() {
563621
server = McpServer.builder()
622+
.name("smithy-mcp-server")
564623
.input(input)
565624
.output(output)
566625
.addService("test-mcp",
@@ -631,6 +690,7 @@ void testPromptsGetWithDifferentCasing() {
631690
@Test
632691
void testPromptsGetWithInvalidPrompt() {
633692
server = McpServer.builder()
693+
.name("smithy-mcp-server")
634694
.input(input)
635695
.output(output)
636696
.addService("test-mcp",
@@ -670,6 +730,7 @@ void testPromptsGetWithTemplateArguments() {
670730
.unwrap();
671731

672732
server = McpServer.builder()
733+
.name("smithy-mcp-server")
673734
.input(input)
674735
.output(output)
675736
.addService("test-mcp",
@@ -711,6 +772,7 @@ void testPromptsGetWithMissingRequiredArguments() {
711772
.unwrap();
712773

713774
server = McpServer.builder()
775+
.name("smithy-mcp-server")
714776
.input(input)
715777
.output(output)
716778
.addService("test-mcp",
@@ -748,6 +810,7 @@ void testApplyTemplateArgumentsEdgeCases() {
748810
.unwrap();
749811

750812
server = McpServer.builder()
813+
.name("smithy-mcp-server")
751814
.input(input)
752815
.output(output)
753816
.addService("test-mcp",
@@ -942,7 +1005,7 @@ private void write(String method, Document document, Document requestId) {
9421005
}
9431006

9441007
private JsonRpcResponse read() {
945-
var line = output.read();
1008+
var line = assertTimeoutPreemptively(Duration.ofSeconds(1), output::read, "No response within one second");
9461009
return CODEC.deserializeShape(line, JsonRpcResponse.builder());
9471010
}
9481011

0 commit comments

Comments
 (0)