1616
1717package org .springframework .ai .mcp .annotation .spring ;
1818
19- import java .lang .annotation .Annotation ;
2019import java .util .ArrayList ;
2120import java .util .HashMap ;
22- import java .util .HashSet ;
23- import java .util .LinkedHashSet ;
2421import java .util .List ;
2522import java .util .Map ;
26- import java .util .Set ;
2723import java .util .function .Function ;
28- import java .util .stream .Collectors ;
2924
3025import io .modelcontextprotocol .spec .McpSchema ;
3126import org .springaicommunity .mcp .annotation .McpElicitation ;
3833import reactor .core .publisher .Flux ;
3934import reactor .core .publisher .Mono ;
4035
41- import org .springframework .beans .BeansException ;
4236import org .springframework .beans .factory .SmartInitializingSingleton ;
43- import org .springframework .beans .factory .config .BeanFactoryPostProcessor ;
44- import org .springframework .beans .factory .config .ConfigurableListableBeanFactory ;
45- import org .springframework .core .annotation .AnnotationUtils ;
46- import org .springframework .util .ReflectionUtils ;
4737
4838/**
4939 * Registry of methods annotated with MCP Client annotations (sampling, logging, etc.).
7060 * @author Daniel Garnier-Moiroux
7161 * @since 1.1.0
7262 */
73- public class ClientMcpAsyncHandlersRegistry implements BeanFactoryPostProcessor , SmartInitializingSingleton {
74-
75- private static final Class <? extends Annotation >[] CLIENT_MCP_ANNOTATIONS = new Class [] { McpSampling .class ,
76- McpElicitation .class , McpLogging .class , McpProgress .class , McpToolListChanged .class ,
77- McpPromptListChanged .class , McpResourceListChanged .class };
78-
79- private final McpSchema .ClientCapabilities EMPTY_CAPABILITIES = new McpSchema .ClientCapabilities (null , null , null ,
80- null );
81-
82- private Map <String , McpSchema .ClientCapabilities > capabilitiesPerClient = new HashMap <>();
83-
84- private ConfigurableListableBeanFactory beanFactory ;
85-
86- private final Set <String > allAnnotatedBeans = new HashSet <>();
63+ public class ClientMcpAsyncHandlersRegistry extends AbstractClientMcpHandlerRegistry
64+ implements SmartInitializingSingleton {
8765
8866 private final Map <String , Function <McpSchema .CreateMessageRequest , Mono <McpSchema .CreateMessageResult >>> samplingHandlers = new HashMap <>();
8967
@@ -104,7 +82,7 @@ public class ClientMcpAsyncHandlersRegistry implements BeanFactoryPostProcessor,
10482 * registered with the {@link McpSampling} and {@link McpElicitation} annotations.
10583 */
10684 public McpSchema .ClientCapabilities getCapabilities (String clientName ) {
107- return this .capabilitiesPerClient .getOrDefault (clientName , this . EMPTY_CAPABILITIES );
85+ return this .capabilitiesPerClient .getOrDefault (clientName , EMPTY_CAPABILITIES );
10886 }
10987
11088 /**
@@ -206,90 +184,9 @@ public Mono<Void> handleResourceListChanged(String name, List<McpSchema.Resource
206184 return Flux .fromIterable (consumers ).flatMap (c -> c .apply (updatedResources )).then ();
207185 }
208186
209- @ Override
210- public void postProcessBeanFactory (ConfigurableListableBeanFactory beanFactory ) throws BeansException {
211- this .beanFactory = beanFactory ;
212- Map <String , List <String >> elicitationClientToAnnotatedBeans = new HashMap <>();
213- Map <String , List <String >> samplingClientToAnnotatedBeans = new HashMap <>();
214- for (var beanName : beanFactory .getBeanDefinitionNames ()) {
215- var definition = beanFactory .getBeanDefinition (beanName );
216- var foundAnnotations = scan (definition .getResolvableType ().toClass ());
217- if (!foundAnnotations .isEmpty ()) {
218- this .allAnnotatedBeans .add (beanName );
219- }
220- for (var foundAnnotation : foundAnnotations ) {
221- if (foundAnnotation instanceof McpSampling sampling ) {
222- for (var client : sampling .clients ()) {
223- samplingClientToAnnotatedBeans .computeIfAbsent (client , c -> new ArrayList <>()).add (beanName );
224- }
225- }
226- else if (foundAnnotation instanceof McpElicitation elicitation ) {
227- for (var client : elicitation .clients ()) {
228- elicitationClientToAnnotatedBeans .computeIfAbsent (client , c -> new ArrayList <>()).add (beanName );
229- }
230- }
231- }
232- }
233-
234- for (var elicitationEntry : elicitationClientToAnnotatedBeans .entrySet ()) {
235- if (elicitationEntry .getValue ().size () > 1 ) {
236- throw new IllegalArgumentException (
237- "Found 2 elicitation handlers for client [%s], found in bean with names %s. Only one @McpElicitation handler is allowed per client"
238- .formatted (elicitationEntry .getKey (), new LinkedHashSet <>(elicitationEntry .getValue ())));
239- }
240- }
241- for (var samplingEntry : samplingClientToAnnotatedBeans .entrySet ()) {
242- if (samplingEntry .getValue ().size () > 1 ) {
243- throw new IllegalArgumentException (
244- "Found 2 sampling handlers for client [%s], found in bean with names %s. Only one @McpSampling handler is allowed per client"
245- .formatted (samplingEntry .getKey (), new LinkedHashSet <>(samplingEntry .getValue ())));
246- }
247- }
248-
249- Map <String , McpSchema .ClientCapabilities .Builder > capsPerClient = new HashMap <>();
250- for (var samplingClient : samplingClientToAnnotatedBeans .keySet ()) {
251- capsPerClient .computeIfAbsent (samplingClient , ignored -> McpSchema .ClientCapabilities .builder ()).sampling ();
252- }
253- for (var elicitationClient : elicitationClientToAnnotatedBeans .keySet ()) {
254- capsPerClient .computeIfAbsent (elicitationClient , ignored -> McpSchema .ClientCapabilities .builder ())
255- .elicitation ();
256- }
257-
258- this .capabilitiesPerClient = capsPerClient .entrySet ()
259- .stream ()
260- .collect (Collectors .toMap (Map .Entry ::getKey , entry -> entry .getValue ().build ()));
261- }
262-
263- private List <Annotation > scan (Class <?> beanClass ) {
264- List <Annotation > foundAnnotations = new ArrayList <>();
265-
266- // Scan all methods in the bean class
267- ReflectionUtils .doWithMethods (beanClass , method -> {
268- for (var annotationType : CLIENT_MCP_ANNOTATIONS ) {
269- Annotation annotation = AnnotationUtils .findAnnotation (method , annotationType );
270- if (annotation != null ) {
271- foundAnnotations .add (annotation );
272- }
273- }
274- });
275- return foundAnnotations ;
276- }
277-
278187 @ Override
279188 public void afterSingletonsInstantiated () {
280- // Use a set in case multiple handlers are registered in the same bean
281- Map <Class <? extends Annotation >, Set <Object >> beansByAnnotation = new HashMap <>();
282- for (var annotation : CLIENT_MCP_ANNOTATIONS ) {
283- beansByAnnotation .put (annotation , new HashSet <>());
284- }
285-
286- for (var beanName : this .allAnnotatedBeans ) {
287- var bean = this .beanFactory .getBean (beanName );
288- var annotations = scan (bean .getClass ());
289- for (var annotation : annotations ) {
290- beansByAnnotation .computeIfAbsent (annotation .annotationType (), k -> new HashSet <>()).add (bean );
291- }
292- }
189+ var beansByAnnotation = getBeansByAnnotationType ();
293190
294191 var samplingSpecs = AsyncMcpAnnotationProviders
295192 .samplingSpecifications (new ArrayList <>(beansByAnnotation .get (McpSampling .class )));
0 commit comments