Skip to content

Commit 255ef56

Browse files
committed
Add RSocketServiceMethod support for suspending functions
See #34868 Signed-off-by: Dmitry Sulman <[email protected]>
1 parent 2faed3c commit 255ef56

File tree

5 files changed

+177
-3
lines changed

5 files changed

+177
-3
lines changed

spring-aop/src/main/java/org/springframework/aop/framework/CoroutinesUtils.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.aop.framework;
1818

1919
import kotlin.coroutines.Continuation;
20+
import kotlinx.coroutines.flow.Flow;
2021
import kotlinx.coroutines.reactive.ReactiveFlowKt;
2122
import kotlinx.coroutines.reactor.MonoKt;
2223
import org.jspecify.annotations.Nullable;
@@ -35,6 +36,9 @@ static Object asFlow(@Nullable Object publisher) {
3536
if (publisher instanceof Publisher<?> rsPublisher) {
3637
return ReactiveFlowKt.asFlow(rsPublisher);
3738
}
39+
else if (publisher instanceof Flow<?>) {
40+
return publisher;
41+
}
3842
else {
3943
throw new IllegalArgumentException("Not a Reactive Streams Publisher: " + publisher);
4044
}

spring-aop/src/test/kotlin/org/springframework/aop/framework/CoroutinesUtilsTests.kt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package org.springframework.aop.framework
1818

1919
import kotlinx.coroutines.CoroutineName
2020
import kotlinx.coroutines.flow.Flow
21+
import kotlinx.coroutines.flow.flowOf
2122
import kotlinx.coroutines.flow.toList
2223
import kotlinx.coroutines.runBlocking
2324
import org.assertj.core.api.Assertions.assertThat
@@ -72,4 +73,16 @@ class CoroutinesUtilsTests {
7273
}
7374
}
7475

76+
@Test
77+
@Suppress("UNCHECKED_CAST")
78+
fun flowAsFlow() {
79+
val value1 = "foo"
80+
val value2 = "bar"
81+
val values = flowOf(value1, value2)
82+
val flow = CoroutinesUtils.asFlow(values) as Flow<String>
83+
runBlocking {
84+
assertThat(flow.toList()).containsExactly(value1, value2)
85+
}
86+
}
87+
7588
}

spring-messaging/src/main/java/org/springframework/messaging/rsocket/service/RSocketServiceMethod.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import reactor.core.publisher.Mono;
3131

3232
import org.springframework.core.DefaultParameterNameDiscoverer;
33+
import org.springframework.core.KotlinDetector;
3334
import org.springframework.core.MethodParameter;
3435
import org.springframework.core.ParameterizedTypeReference;
3536
import org.springframework.core.ReactiveAdapter;
@@ -54,6 +55,8 @@
5455
*/
5556
final class RSocketServiceMethod {
5657

58+
private static final String COROUTINES_FLOW_CLASS_NAME = "kotlinx.coroutines.flow.Flow";
59+
5760
private final Method method;
5861

5962
private final MethodParameter[] parameters;
@@ -82,6 +85,10 @@ private static MethodParameter[] initMethodParameters(Method method) {
8285
if (count == 0) {
8386
return new MethodParameter[0];
8487
}
88+
if (KotlinDetector.isSuspendingFunction(method)) {
89+
count -= 1;
90+
}
91+
8592
DefaultParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();
8693
MethodParameter[] parameters = new MethodParameter[count];
8794
for (int i = 0; i < count; i++) {
@@ -129,10 +136,16 @@ private static Function<RSocketRequestValues, Object> initResponseFunction(
129136

130137
MethodParameter returnParam = new MethodParameter(method, -1);
131138
Class<?> returnType = returnParam.getParameterType();
139+
boolean isUnwrapped = KotlinDetector.isSuspendingFunction(method) &&
140+
!COROUTINES_FLOW_CLASS_NAME.equals(returnParam.getParameterType().getName());
141+
if (isUnwrapped) {
142+
returnType = Mono.class;
143+
}
144+
132145
ReactiveAdapter reactiveAdapter = reactiveRegistry.getAdapter(returnType);
133146

134147
MethodParameter actualParam = (reactiveAdapter != null ? returnParam.nested() : returnParam.nestedIfOptional());
135-
Class<?> actualType = actualParam.getNestedParameterType();
148+
Class<?> actualType = isUnwrapped ? actualParam.getParameterType() : actualParam.getNestedParameterType();
136149

137150
Function<RSocketRequestValues, Publisher<?>> responseFunction;
138151
if (ClassUtils.isVoidType(actualType) || (reactiveAdapter != null && reactiveAdapter.isNoValue())) {
@@ -147,7 +160,8 @@ else if (reactiveAdapter == null) {
147160
}
148161
else {
149162
ParameterizedTypeReference<?> payloadType =
150-
ParameterizedTypeReference.forType(actualParam.getNestedGenericParameterType());
163+
ParameterizedTypeReference.forType(isUnwrapped ? actualParam.getGenericParameterType() :
164+
actualParam.getNestedGenericParameterType());
151165

152166
responseFunction = values -> (
153167
reactiveAdapter.isMultiValue() ?

spring-messaging/src/main/java/org/springframework/messaging/rsocket/service/RSocketServiceProxyFactory.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
import org.springframework.aop.framework.ProxyFactory;
3333
import org.springframework.aop.framework.ReflectiveMethodInvocation;
34+
import org.springframework.core.KotlinDetector;
3435
import org.springframework.core.MethodIntrospector;
3536
import org.springframework.core.ReactiveAdapterRegistry;
3637
import org.springframework.core.annotation.AnnotatedElementUtils;
@@ -246,7 +247,9 @@ private ServiceMethodInterceptor(List<RSocketServiceMethod> methods) {
246247
Method method = invocation.getMethod();
247248
RSocketServiceMethod serviceMethod = this.serviceMethods.get(method);
248249
if (serviceMethod != null) {
249-
return serviceMethod.invoke(invocation.getArguments());
250+
@Nullable Object[] arguments = KotlinDetector.isSuspendingFunction(method) ?
251+
resolveCoroutinesArguments(invocation.getArguments()) : invocation.getArguments();
252+
return serviceMethod.invoke(arguments);
250253
}
251254
if (method.isDefault()) {
252255
if (invocation instanceof ReflectiveMethodInvocation reflectiveMethodInvocation) {
@@ -256,6 +259,12 @@ private ServiceMethodInterceptor(List<RSocketServiceMethod> methods) {
256259
}
257260
throw new IllegalStateException("Unexpected method invocation: " + method);
258261
}
262+
263+
private static Object[] resolveCoroutinesArguments(@Nullable Object[] args) {
264+
Object[] functionArgs = new Object[args.length - 1];
265+
System.arraycopy(args, 0, functionArgs, 0, args.length - 1);
266+
return functionArgs;
267+
}
259268
}
260269

261270
}
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Copyright 2002-present the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.messaging.rsocket.service
18+
19+
import io.rsocket.util.DefaultPayload
20+
import kotlinx.coroutines.flow.Flow
21+
import kotlinx.coroutines.flow.flowOf
22+
import kotlinx.coroutines.flow.map
23+
import kotlinx.coroutines.flow.toList
24+
import kotlinx.coroutines.reactive.asFlow
25+
import kotlinx.coroutines.runBlocking
26+
import org.assertj.core.api.Assertions.assertThat
27+
import org.junit.jupiter.api.BeforeEach
28+
import org.junit.jupiter.api.Test
29+
import org.springframework.messaging.rsocket.RSocketRequester
30+
import org.springframework.messaging.rsocket.RSocketStrategies
31+
import org.springframework.messaging.rsocket.TestRSocket
32+
import org.springframework.util.MimeTypeUtils.TEXT_PLAIN
33+
import reactor.core.publisher.Flux
34+
import reactor.core.publisher.Mono
35+
36+
/**
37+
* Kotlin tests for [RSocketServiceMethod].
38+
*
39+
* @author Dmitry Sulman
40+
*/
41+
class RSocketServiceMethodKotlinTests {
42+
43+
private lateinit var rsocket: TestRSocket
44+
45+
private lateinit var proxyFactory: RSocketServiceProxyFactory
46+
47+
@BeforeEach
48+
fun setUp() {
49+
rsocket = TestRSocket()
50+
val requester = RSocketRequester.wrap(rsocket, TEXT_PLAIN, TEXT_PLAIN, RSocketStrategies.create())
51+
proxyFactory = RSocketServiceProxyFactory.builder(requester).build()
52+
}
53+
54+
@Test
55+
fun fireAndForget(): Unit = runBlocking {
56+
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
57+
58+
val requestPayload = "request"
59+
service.fireAndForget(requestPayload)
60+
61+
assertThat(rsocket.savedMethodName).isEqualTo("fireAndForget")
62+
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("ff")
63+
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
64+
}
65+
66+
@Test
67+
fun requestResponse(): Unit = runBlocking {
68+
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
69+
70+
val requestPayload = "request"
71+
val responsePayload = "response"
72+
rsocket.setPayloadMonoToReturn(Mono.just(DefaultPayload.create(responsePayload)))
73+
val response = service.requestResponse(requestPayload)
74+
75+
assertThat(response).isEqualTo(responsePayload)
76+
assertThat(rsocket.savedMethodName).isEqualTo("requestResponse")
77+
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("rr")
78+
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
79+
}
80+
81+
@Test
82+
fun requestStream(): Unit = runBlocking {
83+
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
84+
85+
val requestPayload = "request"
86+
val responsePayload1 = "response1"
87+
val responsePayload2 = "response2"
88+
rsocket.setPayloadFluxToReturn(
89+
Flux.just(DefaultPayload.create(responsePayload1), DefaultPayload.create(responsePayload2)))
90+
val response = service.requestStream(requestPayload).toList()
91+
92+
assertThat(response).containsExactly(responsePayload1, responsePayload2)
93+
assertThat(rsocket.savedMethodName).isEqualTo("requestStream")
94+
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("rs")
95+
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
96+
}
97+
98+
@Test
99+
fun requestChannel(): Unit = runBlocking {
100+
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
101+
102+
val requestPayload1 = "request1"
103+
val requestPayload2 = "request2"
104+
val responsePayload1 = "response1"
105+
val responsePayload2 = "response2"
106+
rsocket.setPayloadFluxToReturn(
107+
Flux.just(DefaultPayload.create(responsePayload1), DefaultPayload.create(responsePayload2)))
108+
val response = service.requestChannel(flowOf(requestPayload1, requestPayload2)).toList()
109+
110+
assertThat(response).containsExactly(responsePayload1, responsePayload2)
111+
assertThat(rsocket.savedMethodName).isEqualTo("requestChannel")
112+
113+
val savedPayloads = rsocket.savedPayloadFlux
114+
?.asFlow()
115+
?.map { it.dataUtf8 }
116+
?.toList()
117+
assertThat(savedPayloads).containsExactly(requestPayload1, requestPayload2)
118+
}
119+
120+
private interface SuspendingFunctionsService {
121+
122+
@RSocketExchange("ff")
123+
suspend fun fireAndForget(input: String)
124+
125+
@RSocketExchange("rr")
126+
suspend fun requestResponse(input: String): String
127+
128+
@RSocketExchange("rs")
129+
suspend fun requestStream(input: String): Flow<String>
130+
131+
@RSocketExchange("rc")
132+
suspend fun requestChannel(input: Flow<String>): Flow<String>
133+
}
134+
}

0 commit comments

Comments
 (0)