Skip to content

Commit dd15ed0

Browse files
jun06tDominic Green
authored andcommitted
Refactor interceptor chain functions (#220)
1 parent 6e1e746 commit dd15ed0

File tree

1 file changed

+36
-100
lines changed

1 file changed

+36
-100
lines changed

chain.go

Lines changed: 36 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -19,35 +19,19 @@ import (
1919
func ChainUnaryServer(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
2020
n := len(interceptors)
2121

22-
if n > 1 {
23-
lastI := n - 1
24-
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
25-
var (
26-
chainHandler grpc.UnaryHandler
27-
curI int
28-
)
29-
30-
chainHandler = func(currentCtx context.Context, currentReq interface{}) (interface{}, error) {
31-
if curI == lastI {
32-
return handler(currentCtx, currentReq)
33-
}
34-
curI++
35-
resp, err := interceptors[curI](currentCtx, currentReq, info, chainHandler)
36-
curI--
37-
return resp, err
22+
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
23+
chainer := func(currentInter grpc.UnaryServerInterceptor, currentHandler grpc.UnaryHandler) grpc.UnaryHandler {
24+
return func(currentCtx context.Context, currentReq interface{}) (interface{}, error) {
25+
return currentInter(currentCtx, currentReq, info, currentHandler)
3826
}
39-
40-
return interceptors[0](ctx, req, info, chainHandler)
4127
}
42-
}
4328

44-
if n == 1 {
45-
return interceptors[0]
46-
}
29+
chainedHandler := handler
30+
for i := n - 1; i >= 0; i-- {
31+
chainedHandler = chainer(interceptors[i], chainedHandler)
32+
}
4733

48-
// n == 0; Dummy interceptor maintained for backward compatibility to avoid returning nil.
49-
return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
50-
return handler(ctx, req)
34+
return chainedHandler(ctx, req)
5135
}
5236
}
5337

@@ -59,35 +43,19 @@ func ChainUnaryServer(interceptors ...grpc.UnaryServerInterceptor) grpc.UnarySer
5943
func ChainStreamServer(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
6044
n := len(interceptors)
6145

62-
if n > 1 {
63-
lastI := n - 1
64-
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
65-
var (
66-
chainHandler grpc.StreamHandler
67-
curI int
68-
)
69-
70-
chainHandler = func(currentSrv interface{}, currentStream grpc.ServerStream) error {
71-
if curI == lastI {
72-
return handler(currentSrv, currentStream)
73-
}
74-
curI++
75-
err := interceptors[curI](currentSrv, currentStream, info, chainHandler)
76-
curI--
77-
return err
46+
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
47+
chainer := func(currentInter grpc.StreamServerInterceptor, currentHandler grpc.StreamHandler) grpc.StreamHandler {
48+
return func(currentSrv interface{}, currentStream grpc.ServerStream) error {
49+
return currentInter(currentSrv, currentStream, info, currentHandler)
7850
}
79-
80-
return interceptors[0](srv, stream, info, chainHandler)
8151
}
82-
}
8352

84-
if n == 1 {
85-
return interceptors[0]
86-
}
53+
chainedHandler := handler
54+
for i := n - 1; i >= 0; i-- {
55+
chainedHandler = chainer(interceptors[i], chainedHandler)
56+
}
8757

88-
// n == 0; Dummy interceptor maintained for backward compatibility to avoid returning nil.
89-
return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
90-
return handler(srv, stream)
58+
return chainedHandler(srv, ss)
9159
}
9260
}
9361

@@ -98,35 +66,19 @@ func ChainStreamServer(interceptors ...grpc.StreamServerInterceptor) grpc.Stream
9866
func ChainUnaryClient(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
9967
n := len(interceptors)
10068

101-
if n > 1 {
102-
lastI := n - 1
103-
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
104-
var (
105-
chainHandler grpc.UnaryInvoker
106-
curI int
107-
)
108-
109-
chainHandler = func(currentCtx context.Context, currentMethod string, currentReq, currentRepl interface{}, currentConn *grpc.ClientConn, currentOpts ...grpc.CallOption) error {
110-
if curI == lastI {
111-
return invoker(currentCtx, currentMethod, currentReq, currentRepl, currentConn, currentOpts...)
112-
}
113-
curI++
114-
err := interceptors[curI](currentCtx, currentMethod, currentReq, currentRepl, currentConn, chainHandler, currentOpts...)
115-
curI--
116-
return err
69+
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
70+
chainer := func(currentInter grpc.UnaryClientInterceptor, currentInvoker grpc.UnaryInvoker) grpc.UnaryInvoker {
71+
return func(currentCtx context.Context, currentMethod string, currentReq, currentRepl interface{}, currentConn *grpc.ClientConn, currentOpts ...grpc.CallOption) error {
72+
return currentInter(currentCtx, currentMethod, currentReq, currentRepl, currentConn, currentInvoker, currentOpts...)
11773
}
118-
119-
return interceptors[0](ctx, method, req, reply, cc, chainHandler, opts...)
12074
}
121-
}
12275

123-
if n == 1 {
124-
return interceptors[0]
125-
}
76+
chainedInvoker := invoker
77+
for i := n - 1; i >= 0; i-- {
78+
chainedInvoker = chainer(interceptors[i], chainedInvoker)
79+
}
12680

127-
// n == 0; Dummy interceptor maintained for backward compatibility to avoid returning nil.
128-
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
129-
return invoker(ctx, method, req, reply, cc, opts...)
81+
return chainedInvoker(ctx, method, req, reply, cc, opts...)
13082
}
13183
}
13284

@@ -137,35 +89,19 @@ func ChainUnaryClient(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryCli
13789
func ChainStreamClient(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor {
13890
n := len(interceptors)
13991

140-
if n > 1 {
141-
lastI := n - 1
142-
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
143-
var (
144-
chainHandler grpc.Streamer
145-
curI int
146-
)
147-
148-
chainHandler = func(currentCtx context.Context, currentDesc *grpc.StreamDesc, currentConn *grpc.ClientConn, currentMethod string, currentOpts ...grpc.CallOption) (grpc.ClientStream, error) {
149-
if curI == lastI {
150-
return streamer(currentCtx, currentDesc, currentConn, currentMethod, currentOpts...)
151-
}
152-
curI++
153-
stream, err := interceptors[curI](currentCtx, currentDesc, currentConn, currentMethod, chainHandler, currentOpts...)
154-
curI--
155-
return stream, err
92+
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
93+
chainer := func(currentInter grpc.StreamClientInterceptor, currentStreamer grpc.Streamer) grpc.Streamer {
94+
return func(currentCtx context.Context, currentDesc *grpc.StreamDesc, currentConn *grpc.ClientConn, currentMethod string, currentOpts ...grpc.CallOption) (grpc.ClientStream, error) {
95+
return currentInter(currentCtx, currentDesc, currentConn, currentMethod, currentStreamer, currentOpts...)
15696
}
157-
158-
return interceptors[0](ctx, desc, cc, method, chainHandler, opts...)
15997
}
160-
}
16198

162-
if n == 1 {
163-
return interceptors[0]
164-
}
99+
chainedStreamer := streamer
100+
for i := n - 1; i >= 0; i-- {
101+
chainedStreamer = chainer(interceptors[i], chainedStreamer)
102+
}
165103

166-
// n == 0; Dummy interceptor maintained for backward compatibility to avoid returning nil.
167-
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
168-
return streamer(ctx, desc, cc, method, opts...)
104+
return chainedStreamer(ctx, desc, cc, method, opts...)
169105
}
170106
}
171107

0 commit comments

Comments
 (0)