Skip to content

Commit 1a11372

Browse files
committed
Introduce AbstractClientMcpHandlerRegistry
Signed-off-by: Daniel Garnier-Moiroux <[email protected]>
1 parent 2b1f49b commit 1a11372

File tree

3 files changed

+143
-214
lines changed

3 files changed

+143
-214
lines changed
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
package org.springframework.ai.mcp.annotation.spring;
2+
3+
import java.lang.annotation.Annotation;
4+
import java.util.ArrayList;
5+
import java.util.HashMap;
6+
import java.util.HashSet;
7+
import java.util.LinkedHashSet;
8+
import java.util.List;
9+
import java.util.Map;
10+
import java.util.Set;
11+
import java.util.stream.Collectors;
12+
13+
import io.modelcontextprotocol.spec.McpSchema;
14+
import org.springaicommunity.mcp.annotation.McpElicitation;
15+
import org.springaicommunity.mcp.annotation.McpLogging;
16+
import org.springaicommunity.mcp.annotation.McpProgress;
17+
import org.springaicommunity.mcp.annotation.McpPromptListChanged;
18+
import org.springaicommunity.mcp.annotation.McpResourceListChanged;
19+
import org.springaicommunity.mcp.annotation.McpSampling;
20+
import org.springaicommunity.mcp.annotation.McpToolListChanged;
21+
22+
import org.springframework.beans.BeansException;
23+
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
24+
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
25+
import org.springframework.core.annotation.AnnotationUtils;
26+
import org.springframework.util.ReflectionUtils;
27+
28+
/**
29+
* Base class for sync and async ClientMcpHandlerRegistries. Not intended for public use.
30+
*
31+
* @see ClientMcpAsyncHandlersRegistry
32+
* @see ClientMcpSyncHandlersRegistry
33+
*/
34+
abstract class AbstractClientMcpHandlerRegistry implements BeanFactoryPostProcessor {
35+
36+
protected Map<String, McpSchema.ClientCapabilities> capabilitiesPerClient = new HashMap<>();
37+
38+
protected ConfigurableListableBeanFactory beanFactory;
39+
40+
protected final Set<String> allAnnotatedBeans = new HashSet<>();
41+
42+
static final Class<? extends Annotation>[] CLIENT_MCP_ANNOTATIONS = new Class[] { McpSampling.class,
43+
McpElicitation.class, McpLogging.class, McpProgress.class, McpToolListChanged.class,
44+
McpPromptListChanged.class, McpResourceListChanged.class };
45+
46+
static final McpSchema.ClientCapabilities EMPTY_CAPABILITIES = new McpSchema.ClientCapabilities(null, null, null,
47+
null);
48+
49+
@Override
50+
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
51+
this.beanFactory = beanFactory;
52+
Map<String, List<String>> elicitationClientToAnnotatedBeans = new HashMap<>();
53+
Map<String, List<String>> samplingClientToAnnotatedBeans = new HashMap<>();
54+
for (var beanName : beanFactory.getBeanDefinitionNames()) {
55+
var definition = beanFactory.getBeanDefinition(beanName);
56+
var foundAnnotations = scan(definition.getResolvableType().toClass());
57+
if (!foundAnnotations.isEmpty()) {
58+
this.allAnnotatedBeans.add(beanName);
59+
}
60+
for (var foundAnnotation : foundAnnotations) {
61+
if (foundAnnotation instanceof McpSampling sampling) {
62+
for (var client : sampling.clients()) {
63+
samplingClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName);
64+
}
65+
}
66+
else if (foundAnnotation instanceof McpElicitation elicitation) {
67+
for (var client : elicitation.clients()) {
68+
elicitationClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName);
69+
}
70+
}
71+
}
72+
}
73+
74+
for (var elicitationEntry : elicitationClientToAnnotatedBeans.entrySet()) {
75+
if (elicitationEntry.getValue().size() > 1) {
76+
throw new IllegalArgumentException(
77+
"Found 2 elicitation handlers for client [%s], found in bean with names %s. Only one @McpElicitation handler is allowed per client"
78+
.formatted(elicitationEntry.getKey(), new LinkedHashSet<>(elicitationEntry.getValue())));
79+
}
80+
}
81+
for (var samplingEntry : samplingClientToAnnotatedBeans.entrySet()) {
82+
if (samplingEntry.getValue().size() > 1) {
83+
throw new IllegalArgumentException(
84+
"Found 2 sampling handlers for client [%s], found in bean with names %s. Only one @McpSampling handler is allowed per client"
85+
.formatted(samplingEntry.getKey(), new LinkedHashSet<>(samplingEntry.getValue())));
86+
}
87+
}
88+
89+
Map<String, McpSchema.ClientCapabilities.Builder> capsPerClient = new HashMap<>();
90+
for (var samplingClient : samplingClientToAnnotatedBeans.keySet()) {
91+
capsPerClient.computeIfAbsent(samplingClient, ignored -> McpSchema.ClientCapabilities.builder()).sampling();
92+
}
93+
for (var elicitationClient : elicitationClientToAnnotatedBeans.keySet()) {
94+
capsPerClient.computeIfAbsent(elicitationClient, ignored -> McpSchema.ClientCapabilities.builder())
95+
.elicitation();
96+
}
97+
98+
this.capabilitiesPerClient = capsPerClient.entrySet()
99+
.stream()
100+
.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().build()));
101+
}
102+
103+
protected List<Annotation> scan(Class<?> beanClass) {
104+
List<Annotation> foundAnnotations = new ArrayList<>();
105+
106+
// Scan all methods in the bean class
107+
ReflectionUtils.doWithMethods(beanClass, method -> {
108+
for (var annotationType : CLIENT_MCP_ANNOTATIONS) {
109+
Annotation annotation = AnnotationUtils.findAnnotation(method, annotationType);
110+
if (annotation != null) {
111+
foundAnnotations.add(annotation);
112+
}
113+
}
114+
});
115+
return foundAnnotations;
116+
}
117+
118+
protected Map<Class<? extends Annotation>, Set<Object>> getBeansByAnnotationType() {
119+
// Use a set in case multiple handlers are registered in the same bean
120+
Map<Class<? extends Annotation>, Set<Object>> beansByAnnotation = new HashMap<>();
121+
for (var annotation : CLIENT_MCP_ANNOTATIONS) {
122+
beansByAnnotation.put(annotation, new HashSet<>());
123+
}
124+
125+
for (var beanName : this.allAnnotatedBeans) {
126+
var bean = this.beanFactory.getBean(beanName);
127+
var annotations = scan(bean.getClass());
128+
for (var annotation : annotations) {
129+
beansByAnnotation.computeIfAbsent(annotation.annotationType(), k -> new HashSet<>()).add(bean);
130+
}
131+
}
132+
return beansByAnnotation;
133+
}
134+
135+
}

mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java

Lines changed: 4 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,11 @@
1616

1717
package org.springframework.ai.mcp.annotation.spring;
1818

19-
import java.lang.annotation.Annotation;
2019
import java.util.ArrayList;
2120
import java.util.HashMap;
22-
import java.util.HashSet;
23-
import java.util.LinkedHashSet;
2421
import java.util.List;
2522
import java.util.Map;
26-
import java.util.Set;
2723
import java.util.function.Function;
28-
import java.util.stream.Collectors;
2924

3025
import io.modelcontextprotocol.spec.McpSchema;
3126
import org.springaicommunity.mcp.annotation.McpElicitation;
@@ -38,12 +33,7 @@
3833
import reactor.core.publisher.Flux;
3934
import reactor.core.publisher.Mono;
4035

41-
import org.springframework.beans.BeansException;
4236
import org.springframework.beans.factory.SmartInitializingSingleton;
43-
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
44-
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
45-
import org.springframework.core.annotation.AnnotationUtils;
46-
import org.springframework.util.ReflectionUtils;
4737

4838
/**
4939
* Registry of methods annotated with MCP Client annotations (sampling, logging, etc.).
@@ -70,20 +60,8 @@
7060
* @author Daniel Garnier-Moiroux
7161
* @since 1.1.0
7262
*/
73-
public class ClientMcpAsyncHandlersRegistry implements BeanFactoryPostProcessor, SmartInitializingSingleton {
74-
75-
private static final Class<? extends Annotation>[] CLIENT_MCP_ANNOTATIONS = new Class[] { McpSampling.class,
76-
McpElicitation.class, McpLogging.class, McpProgress.class, McpToolListChanged.class,
77-
McpPromptListChanged.class, McpResourceListChanged.class };
78-
79-
private final McpSchema.ClientCapabilities EMPTY_CAPABILITIES = new McpSchema.ClientCapabilities(null, null, null,
80-
null);
81-
82-
private Map<String, McpSchema.ClientCapabilities> capabilitiesPerClient = new HashMap<>();
83-
84-
private ConfigurableListableBeanFactory beanFactory;
85-
86-
private final Set<String> allAnnotatedBeans = new HashSet<>();
63+
public class ClientMcpAsyncHandlersRegistry extends AbstractClientMcpHandlerRegistry
64+
implements SmartInitializingSingleton {
8765

8866
private final Map<String, Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>>> samplingHandlers = new HashMap<>();
8967

@@ -104,7 +82,7 @@ public class ClientMcpAsyncHandlersRegistry implements BeanFactoryPostProcessor,
10482
* registered with the {@link McpSampling} and {@link McpElicitation} annotations.
10583
*/
10684
public McpSchema.ClientCapabilities getCapabilities(String clientName) {
107-
return this.capabilitiesPerClient.getOrDefault(clientName, this.EMPTY_CAPABILITIES);
85+
return this.capabilitiesPerClient.getOrDefault(clientName, EMPTY_CAPABILITIES);
10886
}
10987

11088
/**
@@ -206,90 +184,9 @@ public Mono<Void> handleResourceListChanged(String name, List<McpSchema.Resource
206184
return Flux.fromIterable(consumers).flatMap(c -> c.apply(updatedResources)).then();
207185
}
208186

209-
@Override
210-
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
211-
this.beanFactory = beanFactory;
212-
Map<String, List<String>> elicitationClientToAnnotatedBeans = new HashMap<>();
213-
Map<String, List<String>> samplingClientToAnnotatedBeans = new HashMap<>();
214-
for (var beanName : beanFactory.getBeanDefinitionNames()) {
215-
var definition = beanFactory.getBeanDefinition(beanName);
216-
var foundAnnotations = scan(definition.getResolvableType().toClass());
217-
if (!foundAnnotations.isEmpty()) {
218-
this.allAnnotatedBeans.add(beanName);
219-
}
220-
for (var foundAnnotation : foundAnnotations) {
221-
if (foundAnnotation instanceof McpSampling sampling) {
222-
for (var client : sampling.clients()) {
223-
samplingClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName);
224-
}
225-
}
226-
else if (foundAnnotation instanceof McpElicitation elicitation) {
227-
for (var client : elicitation.clients()) {
228-
elicitationClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName);
229-
}
230-
}
231-
}
232-
}
233-
234-
for (var elicitationEntry : elicitationClientToAnnotatedBeans.entrySet()) {
235-
if (elicitationEntry.getValue().size() > 1) {
236-
throw new IllegalArgumentException(
237-
"Found 2 elicitation handlers for client [%s], found in bean with names %s. Only one @McpElicitation handler is allowed per client"
238-
.formatted(elicitationEntry.getKey(), new LinkedHashSet<>(elicitationEntry.getValue())));
239-
}
240-
}
241-
for (var samplingEntry : samplingClientToAnnotatedBeans.entrySet()) {
242-
if (samplingEntry.getValue().size() > 1) {
243-
throw new IllegalArgumentException(
244-
"Found 2 sampling handlers for client [%s], found in bean with names %s. Only one @McpSampling handler is allowed per client"
245-
.formatted(samplingEntry.getKey(), new LinkedHashSet<>(samplingEntry.getValue())));
246-
}
247-
}
248-
249-
Map<String, McpSchema.ClientCapabilities.Builder> capsPerClient = new HashMap<>();
250-
for (var samplingClient : samplingClientToAnnotatedBeans.keySet()) {
251-
capsPerClient.computeIfAbsent(samplingClient, ignored -> McpSchema.ClientCapabilities.builder()).sampling();
252-
}
253-
for (var elicitationClient : elicitationClientToAnnotatedBeans.keySet()) {
254-
capsPerClient.computeIfAbsent(elicitationClient, ignored -> McpSchema.ClientCapabilities.builder())
255-
.elicitation();
256-
}
257-
258-
this.capabilitiesPerClient = capsPerClient.entrySet()
259-
.stream()
260-
.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().build()));
261-
}
262-
263-
private List<Annotation> scan(Class<?> beanClass) {
264-
List<Annotation> foundAnnotations = new ArrayList<>();
265-
266-
// Scan all methods in the bean class
267-
ReflectionUtils.doWithMethods(beanClass, method -> {
268-
for (var annotationType : CLIENT_MCP_ANNOTATIONS) {
269-
Annotation annotation = AnnotationUtils.findAnnotation(method, annotationType);
270-
if (annotation != null) {
271-
foundAnnotations.add(annotation);
272-
}
273-
}
274-
});
275-
return foundAnnotations;
276-
}
277-
278187
@Override
279188
public void afterSingletonsInstantiated() {
280-
// Use a set in case multiple handlers are registered in the same bean
281-
Map<Class<? extends Annotation>, Set<Object>> beansByAnnotation = new HashMap<>();
282-
for (var annotation : CLIENT_MCP_ANNOTATIONS) {
283-
beansByAnnotation.put(annotation, new HashSet<>());
284-
}
285-
286-
for (var beanName : this.allAnnotatedBeans) {
287-
var bean = this.beanFactory.getBean(beanName);
288-
var annotations = scan(bean.getClass());
289-
for (var annotation : annotations) {
290-
beansByAnnotation.computeIfAbsent(annotation.annotationType(), k -> new HashSet<>()).add(bean);
291-
}
292-
}
189+
var beansByAnnotation = getBeansByAnnotationType();
293190

294191
var samplingSpecs = AsyncMcpAnnotationProviders
295192
.samplingSpecifications(new ArrayList<>(beansByAnnotation.get(McpSampling.class)));

0 commit comments

Comments
 (0)