diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 89381f3..d8fae3f 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -73,7 +73,7 @@ jobs: version: "22.2" - name: Install protoc-gen-grpc-gateway run: | - git clone https://github.com/geebytes/grpc-gateway.git + git clone --recursive https://github.com/geebytes/grpc-gateway.git cd grpc-gateway go install ./protoc-gen-grpc-gateway - name: Test diff --git a/.vscode/settings.json b/.vscode/settings.json index 54e84f7..e5cd08e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,5 +2,6 @@ "editor.insertSpaces": true, "editor.detectIndentation": false, "go.formatTool": "goimports", - "editor.formatOnSave": true + "editor.formatOnSave": true, + "go.testTimeout": "60s" } \ No newline at end of file diff --git a/cmd/begonia/endpoint.go b/cmd/begonia/endpoint.go index 45ae1bb..22e1d19 100644 --- a/cmd/begonia/endpoint.go +++ b/cmd/begonia/endpoint.go @@ -67,7 +67,7 @@ func readInitAPP(env string) { addr = gw.Addr } -func RegisterEndpoint(env,name string, endpoints []string, pbFile string, opts ...client.EndpointOption) { +func RegisterEndpoint(env, name string, endpoints []string, pbFile string, opts ...client.EndpointOption) { readInitAPP(env) pb, err := os.ReadFile(pbFile) if err != nil { @@ -103,7 +103,7 @@ func RegisterEndpoint(env,name string, endpoints []string, pbFile string, opts . log.Printf("#####################Add Endpoint Success#####################") log.Printf("#####################ID:%s####################################", rsp.Id) } -func UpdateEndpoint(env,id string, mask []string, opts ...client.EndpointOption) { +func UpdateEndpoint(env, id string, mask []string, opts ...client.EndpointOption) { readInitAPP(env) apiClient := client.NewEndpointAPI(addr, accessKey, secret) log.Printf("#####################Update Endpoint###########################") @@ -120,7 +120,7 @@ func UpdateEndpoint(env,id string, mask []string, opts ...client.EndpointOption) log.Printf("#####################Update Endpoint %s Success#####################", id) } -func DeleteEndpoint(env,id string) { +func DeleteEndpoint(env, id string) { readInitAPP(env) apiClient := client.NewEndpointAPI(addr, accessKey, secret) log.Printf("#####################Delete Endpoint:%s#####################", id) diff --git a/cmd/begonia/main.go b/cmd/begonia/main.go index 8bfbf93..e3e499a 100644 --- a/cmd/begonia/main.go +++ b/cmd/begonia/main.go @@ -77,7 +77,7 @@ func NewEndpointDelCmd() *cobra.Command { id, _ := cmd.Flags().GetString("id") env, _ := cmd.Flags().GetString("env") - DeleteEndpoint(env,id) + DeleteEndpoint(env, id) }, } cmd.Flags().StringP("id", "i", "", "ID Of Your Service") @@ -96,9 +96,9 @@ func NewEndpointAddCmd() *cobra.Command { tags, _ := cmd.Flags().GetStringArray("tags") balance, _ := cmd.Flags().GetString("balance") endpoints, _ := cmd.Flags().GetStringArray("endpoint") - env,_:=cmd.Flags().GetString("env") + env, _ := cmd.Flags().GetString("env") - RegisterEndpoint(env,name, endpoints, desc, client.WithBalance(strings.ToUpper(balance)), client.WithTags(tags)) + RegisterEndpoint(env, name, endpoints, desc, client.WithBalance(strings.ToUpper(balance)), client.WithTags(tags)) }, } cmd = newWriteEndpointCmd(cmd) @@ -134,7 +134,7 @@ func NewEndpointUpdateCmd() *cobra.Command { options = append(options, client.WithName(name)) mask = append(mask, "name") } - + if cmd.Flags().Changed("desc") { options = append(options, client.WithDescription(desc)) mask = append(mask, "description") @@ -158,8 +158,8 @@ func NewEndpointUpdateCmd() *cobra.Command { options = append(options, client.WithEndpoints(meta)) mask = append(mask, "endpoints") } - env,_:=cmd.Flags().GetString("env") - UpdateEndpoint(env,id, mask, options...) + env, _ := cmd.Flags().GetString("env") + UpdateEndpoint(env, id, mask, options...) }, } cmd = newWriteEndpointCmd(cmd) @@ -205,4 +205,14 @@ func main() { if err := cmd.Execute(); err != nil { log.Fatalf("failed to start begonia: %v", err) } + // env, _ := cmd.Flags().GetString("env") + // cnf, err := cmd.Flags().GetString("config") + // if err != nil { + // log.Fatalf("failed to get config: %v", err) + // } + // config := config.ReadConfigWithDir("dev", "/data/work/begonia-org/begonia/config/settings.yml") + // worker := internal.New(config, gateway.Log, "127.0.0.1:12138") + // hd, _ := os.UserHomeDir() + // _ = os.WriteFile(hd+"/.begonia/gateway.json", []byte(fmt.Sprintf(`{"addr":"http://%s"}`, "127.0.0.1:12138")), 0666) + // worker.Start() } diff --git a/config/settings.yml b/config/settings.yml index 0a5b32e..6d524f1 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -11,8 +11,8 @@ file: engines: - name: "FILE_ENGINE_MINIO" endpoint: "127.0.0.1:9000" - accessKey: "7OdVJK1alV8cpRMeBLBW" - secretKey: "2GNNRqqElReC1KnV3kX9jSjyLU4kwOaTZEqDS2vH" + accessKey: "rLV2Jjj2UbMWSJOhTOtZ" + secretKey: "OVyJOILwx4iVE0EVJB4CKB65j7xlhjT5q1aGCv5t" - name: "FILE_ENGINE_LOCAL" endpoint: "/data/work/begonia-org/begonia/upload" protos: @@ -70,11 +70,12 @@ gateway: - "example.com" plugins: local: - logger: 1 - exception: 0 + # 优先级越大越先执行 + exception: 4 + logger: 3 http: 2 - params_validator: 3 - auth: 4 + auth: 1 + params_validator: 0 # only_api_key_auth: 4 rpc: # - server: diff --git a/gateway.json b/gateway.json deleted file mode 100644 index 825d680..0000000 --- a/gateway.json +++ /dev/null @@ -1,299 +0,0 @@ -{ - "/helloworld.Greeter/SayHello": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3 - ], - "Pool": [ - "api", - "v1", - "example", - "post" - ], - "Verb": "", - "Fields": null, - "Template": "/api/v1/example/post" - }, - "HttpMethod": "POST", - "FullMethodName": "/helloworld.Greeter/SayHello", - "HttpUri": "/api/v1/example/post", - "PathParams": [], - "InName": "HelloRequest", - "OutName": "HelloReply", - "IsClientStream": false, - "IsServerStream": false, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ], - "/helloworld.Greeter/SayHelloBody": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3 - ], - "Pool": [ - "api", - "v1", - "example", - "body" - ], - "Verb": "", - "Fields": null, - "Template": "/api/v1/example/body" - }, - "HttpMethod": "POST", - "FullMethodName": "/helloworld.Greeter/SayHelloBody", - "HttpUri": "/api/v1/example/body", - "PathParams": [], - "InName": "HttpBody", - "OutName": "HttpBody", - "IsClientStream": false, - "IsServerStream": false, - "Pkg": "helloworld", - "InPkg": "google.api", - "OutPkg": "google.api" - } - ], - "/helloworld.Greeter/SayHelloClientStream": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3, - 2, - 4 - ], - "Pool": [ - "api", - "v1", - "example", - "client", - "stream" - ], - "Verb": "", - "Fields": null, - "Template": "/api/v1/example/client/stream" - }, - "HttpMethod": "POST", - "FullMethodName": "/helloworld.Greeter/SayHelloClientStream", - "HttpUri": "/api/v1/example/client/stream", - "PathParams": [], - "InName": "HelloRequest", - "OutName": "RepeatedReply", - "IsClientStream": true, - "IsServerStream": false, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ], - "/helloworld.Greeter/SayHelloError": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3, - 2, - 4 - ], - "Pool": [ - "api", - "v1", - "example", - "error", - "test" - ], - "Verb": "", - "Fields": null, - "Template": "/api/v1/example/error/test" - }, - "HttpMethod": "GET", - "FullMethodName": "/helloworld.Greeter/SayHelloError", - "HttpUri": "/api/v1/example/error/test", - "PathParams": [], - "InName": "ErrorRequest", - "OutName": "HelloReply", - "IsClientStream": false, - "IsServerStream": false, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ], - "/helloworld.Greeter/SayHelloGet": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 1, - 0, - 4, - 1, - 5, - 3 - ], - "Pool": [ - "api", - "v1", - "example", - "name" - ], - "Verb": "", - "Fields": [ - "name" - ], - "Template": "/api/v1/example/{name}" - }, - "HttpMethod": "GET", - "FullMethodName": "/helloworld.Greeter/SayHelloGet", - "HttpUri": "/api/v1/example/{name}", - "PathParams": [ - "name" - ], - "InName": "HelloRequest", - "OutName": "HelloReply", - "IsClientStream": false, - "IsServerStream": false, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ], - "/helloworld.Greeter/SayHelloServerSideEvent": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3, - 2, - 4, - 1, - 0, - 4, - 1, - 5, - 5 - ], - "Pool": [ - "api", - "v1", - "example", - "server", - "sse", - "name" - ], - "Verb": "", - "Fields": [ - "name" - ], - "Template": "/api/v1/example/server/sse/{name}" - }, - "HttpMethod": "GET", - "FullMethodName": "/helloworld.Greeter/SayHelloServerSideEvent", - "HttpUri": "/api/v1/example/server/sse/{name}", - "PathParams": [ - "name" - ], - "InName": "HelloRequest", - "OutName": "HelloReply", - "IsClientStream": false, - "IsServerStream": true, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ], - "/helloworld.Greeter/SayHelloWebsocket": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3, - 2, - 4 - ], - "Pool": [ - "api", - "v1", - "example", - "server", - "websocket" - ], - "Verb": "", - "Fields": null, - "Template": "/api/v1/example/server/websocket" - }, - "HttpMethod": "GET", - "FullMethodName": "/helloworld.Greeter/SayHelloWebsocket", - "HttpUri": "/api/v1/example/server/websocket", - "PathParams": [], - "InName": "HelloRequest", - "OutName": "HelloReply", - "IsClientStream": true, - "IsServerStream": true, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ] -} \ No newline at end of file diff --git a/gateway/endpoint.go b/gateway/endpoint.go index a91e85a..219e763 100644 --- a/gateway/endpoint.go +++ b/gateway/endpoint.go @@ -2,6 +2,7 @@ package gateway import ( "context" + "fmt" "strings" loadbalance "github.com/begonia-org/go-loadbalancer" @@ -42,7 +43,7 @@ func (e *httpForwardGrpcEndpointImpl) Request(req GrpcRequest) (proto.Message, r return nil, runtime.ServerMetadata{ HeaderMD: make(map[string][]string), TrailerMD: make(map[string][]string), - }, err + }, fmt.Errorf("get conn error:%v", err) } defer e.pool.Release(req.GetContext(), cc) @@ -51,6 +52,7 @@ func (e *httpForwardGrpcEndpointImpl) Request(req GrpcRequest) (proto.Message, r in := req.GetIn() ctx := req.GetContext() err = conn.Invoke(ctx, req.GetFullMethodName(), in, out, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + // log.Printf("request %s out:%v",req.GetFullMethodName(), out.ProtoReflect().Type().Descriptor().FullName()) return out, metadata, err } diff --git a/gateway/exception.go b/gateway/exception.go index 7f3da65..017c188 100644 --- a/gateway/exception.go +++ b/gateway/exception.go @@ -9,8 +9,10 @@ import ( gosdk "github.com/begonia-org/go-sdk" common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/begonia-org/go-sdk/logger" + "github.com/google/uuid" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" ) type Exception struct { @@ -19,17 +21,34 @@ type Exception struct { name string } +func (e *Exception) setHeader(ctx context.Context) context.Context { + md, ok := metadata.FromIncomingContext(ctx) + reqId := "" + if !ok || len(md.Get(XRequestID)) == 0 { + reqId = uuid.New().String() + if !ok { + md = metadata.New(make(map[string]string)) + } + md.Set(XRequestID, reqId) + ctx = metadata.NewIncomingContext(ctx, md) + + } else { + reqId = md.Get(XRequestID)[0] + } + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(XRequestID, reqId)) + + _ = grpc.SetHeader(ctx, metadata.Pairs(XRequestID, reqId)) + return ctx + +} func (e *Exception) UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { defer func() { if p := recover(); p != nil { - // buf := make([]byte, 1024) - // n := sysRuntime.Stack(buf, false) // false 表示不需要所有goroutine的调用栈 - // stackTrace := string(buf[:n]) - // err = fmt.Errorf("panic: %v\nStack trace: %s", p, stackTrace) - // err = gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "panic") err = e.handlePanic(p) } + }() + ctx = e.setHeader(ctx) resp, err = handler(ctx, req) if err == nil { return resp, err @@ -37,7 +56,7 @@ func (e *Exception) UnaryInterceptor(ctx context.Context, req interface{}, info return nil, err } func (e *Exception) handlePanic(p interface{}) error { - const maxFrames = 10 + const maxFrames = 15 var pcs [maxFrames]uintptr n := runtime.Callers(2, pcs[:]) // skip first 3 frames @@ -60,18 +79,50 @@ func (e *Exception) handlePanic(p interface{}) error { func (e *Exception) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { defer func() { if p := recover(); p != nil { - // buf := make([]byte, 512) - // n := sysRuntime.Stack(buf, false) // false 表示不需要所有goroutine的调用栈 - // stackTrace := string(buf[:n]) - - // err = fmt.Errorf("panic: %v\nStack trace: %s", p, stackTrace) - // err = gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "panic") - // _ = ss.SendMsg(err) err = e.handlePanic(p) } + }() + e.setHeader(ss.Context()) return handler(srv, ss) } +func (e *Exception) wrapHandlerWithPanicRecovery(handler grpc.StreamHandler) grpc.StreamHandler { + return func(srv any, stream grpc.ServerStream) (err error) { + // reqId := "" + defer func() { + if p := recover(); p != nil { + err = e.handlePanic(p) + } + }() + + return handler(srv, stream) + } +} +func (e *Exception) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + reqId := "" + md, ok := metadata.FromOutgoingContext(ctx) + if !ok || len(md.Get(XRequestID)) == 0 { + reqID := uuid.New().String() + reqId = reqID + if !ok { + md = metadata.New(make(map[string]string)) + } + md.Set(XRequestID, reqId) + ctx = metadata.NewOutgoingContext(ctx, md) + + } else { + reqId = md.Get(XRequestID)[0] + } + + _ = grpc.SetHeader(ctx, metadata.Pairs(XRequestID, reqId)) + desc.Handler = e.wrapHandlerWithPanicRecovery(desc.Handler) + ss, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + return nil, err + + } + return ss, nil +} func NewException(log logger.Logger) *Exception { return &Exception{log: log, name: "exception"} } diff --git a/gateway/exception_test.go b/gateway/exception_test.go index be8175a..2f497bd 100644 --- a/gateway/exception_test.go +++ b/gateway/exception_test.go @@ -6,8 +6,10 @@ import ( "testing" "github.com/begonia-org/go-sdk/logger" + "github.com/google/uuid" c "github.com/smartystreets/goconvey/convey" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" ) type MiddlewaresTest struct { @@ -39,6 +41,21 @@ func (e *MiddlewaresTest) Name() string { return e.name } +type testClientStream struct { + ctx context.Context + grpc.ClientStream +} + +func (t *testClientStream) Context() context.Context { + return t.ctx +} +func (t *testClientStream) SendMsg(m interface{}) error { + return nil +} +func (t *testClientStream) RecvMsg(m interface{}) error { + return nil + +} func TestUnaryInterceptor(t *testing.T) { c.Convey("TestUnaryInterceptor", t, func() { mid := NewException(Log) @@ -64,3 +81,33 @@ func TestUnaryInterceptor(t *testing.T) { }) } +func TestExceptionStreamClientInterceptor(t *testing.T) { + c.Convey("TestExceptionStreamClientInterceptor", t, func() { + mid := NewException(Log) + ctx := context.Background() + + desc := &grpc.StreamDesc{ + StreamName: "/INTEGRATION.TESTSERVICE/GET", + ClientStreams: true, + ServerStreams: true, + Handler: func(srv interface{}, ss grpc.ServerStream) error { + panic("test painc") + }, + } + st, err := mid.StreamClientInterceptor(ctx, desc, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: context.Background()}, nil + }) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + + // has request id + st, err = mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs(XRequestID, uuid.New().String())), desc, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: context.Background()}, nil + }) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + err = desc.Handler(nil, nil) + c.So(err, c.ShouldNotBeNil) + + }) +} diff --git a/gateway/gateway.go b/gateway/gateway.go index b5b6eca..926d613 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -18,7 +18,7 @@ import ( ) type GrpcServerOptions struct { - Middlewares []GrpcProxyMiddleware + Middlewares []grpc.StreamClientInterceptor Options []grpc.ServerOption PoolOptions []loadbalance.PoolOptionsBuildOption HttpMiddlewares []runtime.ServeMuxOption @@ -28,6 +28,8 @@ type GatewayConfig struct { GatewayAddr string GrpcProxyAddr string } +type DynamicGrpcItem struct { +} type GatewayServer struct { grpcServer *grpc.Server httpGateway HttpEndpoint @@ -37,14 +39,15 @@ type GatewayServer struct { proxyAddr string opts *GrpcServerOptions mux *sync.Mutex + proxy *GrpcProxy } -func NewGrpcServer(opts *GrpcServerOptions, lb *GrpcLoadBalancer) *grpc.Server { +func NewGrpcProxyServer(opts *GrpcServerOptions, lb *GrpcLoadBalancer) *GrpcProxy { - proxy := NewGrpcProxy(lb, opts.Middlewares...) + proxy := NewGrpcProxy(lb, Log, opts.Middlewares...) - opts.Options = append(opts.Options, grpc.UnknownServiceHandler(proxy.Handler)) - return grpc.NewServer(opts.Options...) + opts.Options = append(opts.Options, grpc.UnknownServiceHandler(proxy.Do)) + return proxy } func NewHttpServer(addr string, poolOpt ...loadbalance.PoolOptionsBuildOption) (HttpEndpoint, error) { @@ -57,7 +60,9 @@ func NewHttpServer(addr string, poolOpt ...loadbalance.PoolOptionsBuildOption) ( } func NewGateway(cfg *GatewayConfig, opts *GrpcServerOptions) *GatewayServer { lb := NewGrpcLoadBalancer() - grpcServer := NewGrpcServer(opts, lb) + gProxy := NewGrpcProxyServer(opts, lb) + opts.Options = append(opts.Options, grpc.UnknownServiceHandler(gProxy.Do)) + grpcServer := grpc.NewServer(opts.Options...) _, port, _ := net.SplitHostPort(cfg.GrpcProxyAddr) proxy := fmt.Sprintf("127.0.0.1:%s", port) @@ -77,6 +82,7 @@ func NewGateway(cfg *GatewayConfig, opts *GrpcServerOptions) *GatewayServer { proxyAddr: cfg.GrpcProxyAddr, opts: opts, mux: &sync.Mutex{}, + proxy: gProxy, } // }) return gatewayS @@ -97,6 +103,13 @@ func (g *GatewayServer) RegisterLocalService(ctx context.Context, pd ProtobufDes g.grpcServer.RegisterService(sd, ss) return g.httpGateway.RegisterHandlerClient(ctx, pd, g.gatewayMux) } +func (g *GatewayServer) RegisterServiceWithProxy(pd ProtobufDescription) { + g.mux.Lock() + defer g.mux.Unlock() + g.proxy.buildServiceDesc(pd) + +} + func (g *GatewayServer) DeleteLocalService(pd ProtobufDescription) { g.mux.Lock() defer g.mux.Unlock() diff --git a/gateway/gateway_test.go b/gateway/gateway_test.go new file mode 100644 index 0000000..1283661 --- /dev/null +++ b/gateway/gateway_test.go @@ -0,0 +1,35 @@ +package gateway + +import ( + "fmt" + "log" + "os" + "path/filepath" + "testing" + + "github.com/begonia-org/begonia/internal/pkg/config" + common "github.com/begonia-org/go-sdk/common/api/v1" +) + +func readDesc(conf *config.Config) (ProtobufDescription, error) { + desc := conf.GetLocalAPIDesc() + log.Printf("read desc file:%s", desc) + bin, err := os.ReadFile(desc) + if err != nil { + return nil, fmt.Errorf("read desc file error:%w", err) + } + pd, err := NewDescriptionFromBinary(bin, filepath.Dir(desc)) + if err != nil { + return nil, err + } + err = pd.SetHttpResponse(common.E_HttpResponse) + if err != nil { + return nil, err + } + return pd, nil +} +func TestRegisterDynamicServices(t *testing.T) { + // pd, _ := readDesc(config.NewConfig(cfg.ReadConfig("test"))) + // _ = &GatewayServer{} + // gw.buildServiceDesc(pd) +} diff --git a/gateway/http.go b/gateway/http.go index fc01fc3..2e28889 100644 --- a/gateway/http.go +++ b/gateway/http.go @@ -5,6 +5,7 @@ import ( "context" "crypto/sha256" "encoding/json" + "errors" "fmt" "io" "log" @@ -13,6 +14,7 @@ import ( "strings" "sync" + gosdk "github.com/begonia-org/go-sdk" "github.com/gorilla/websocket" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "google.golang.org/grpc/codes" @@ -88,6 +90,9 @@ func loadHttpEndpointItem(pd ProtobufDescription, descFile string) ([]*HttpEndpo return nil, fmt.Errorf("Failed to unmarshal %s file: %w,%s", descFile, err, string(data)) } for _, binds := range items { + if len(binds) == 0 { + continue + } item := binds[0] // 设置入参和出参 item.In = pd.GetMessageTypeByName(item.InPkg, item.InName) @@ -208,7 +213,7 @@ func (h *HttpEndpointImpl) serverStreamRequest(ctx context.Context, item *HttpEn dec := marshaler.NewDecoder(req.Body) err := dec.Decode(protoReq) - if err != nil && err != io.EOF { + if err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } } @@ -297,7 +302,7 @@ func (h *HttpEndpointImpl) addHexEncodeSHA256HashV2(req *http.Request) error { if req.ContentLength == 0 { hashStruct.Write([]byte("{}")) hexStr := fmt.Sprintf("%x", hashStruct.Sum(nil)) - req.Header.Set("X-Content-Sha256", hexStr) + req.Header.Set(gosdk.HeaderXContentSha256, hexStr) return nil } @@ -368,6 +373,7 @@ func (h *HttpEndpointImpl) DeleteEndpoint(ctx context.Context, pd ProtobufDescri } return nil } + func (h *HttpEndpointImpl) RegisterHandlerClient(ctx context.Context, pd ProtobufDescription, mux *runtime.ServeMux) error { h.mux.Lock() defer h.mux.Unlock() @@ -384,12 +390,17 @@ func (h *HttpEndpointImpl) RegisterHandlerClient(ctx context.Context, pd Protobu // log.Printf("register endpoint %s: %s %v", strings.ToUpper(item.HttpMethod), item.HttpUri, item.Pattern) mux.Handle(strings.ToUpper(item.HttpMethod), item.Pattern, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { if req.Header.Get("accept") == "" || req.Header.Get("accept") == "*/*" { - req.Header.Set("accept", "application/json") + if item.IsServerStream && !item.IsClientStream { + req.Header.Set("accept", "text/event-stream") + } else if !item.IsClientStream && !item.IsServerStream { + req.Header.Set("accept", "application/json") + } } - log.Printf("request content-type:%s", req.Header.Get("content-type")) + // log.Printf("request content-type:%s", req.Header.Get("content-type")) ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + // log.Printf("outbound marshaler:%s", outboundMarshaler.ContentType(req)) var err error var annotatedContext context.Context // 添加sha256 hash @@ -421,12 +432,12 @@ func (h *HttpEndpointImpl) RegisterHandlerClient(ctx context.Context, pd Protobu return } resp, md, err := h.client.Request(reqInstance) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) if err != nil { runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } + // log.Printf("response marshaler:%s",outboundMarshaler.ContentType(resp)) runtime.ForwardResponseMessage(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) } else if item.IsServerStream && !item.IsClientStream { // 服务端推流,升级为sse服务 @@ -436,9 +447,14 @@ func (h *HttpEndpointImpl) RegisterHandlerClient(ctx context.Context, pd Protobu return } annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") recv := func() (proto.Message, error) { - return resp.Recv() + + rsp, err := resp.Recv() + return rsp, err + } runtime.ForwardResponseStream(annotatedContext, mux, outboundMarshaler, w, req, recv, mux.GetForwardResponseOptions()...) } else if !item.IsServerStream && item.IsClientStream { @@ -460,6 +476,7 @@ func (h *HttpEndpointImpl) RegisterHandlerClient(ctx context.Context, pd Protobu return } // defer ws.Close() + log.Printf("upgrade to websocket:%v", outboundMarshaler.ContentType(req)) stream, md, err := h.stream(annotatedContext, item, inboundMarshaler, ws) annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) diff --git a/gateway/http_test.go b/gateway/http_test.go index 6bb8c5d..bfe6b37 100644 --- a/gateway/http_test.go +++ b/gateway/http_test.go @@ -51,7 +51,7 @@ var eps []loadbalance.Endpoint func newTestServer(gwPort, randomNumber int) (*GrpcServerOptions, *GatewayConfig) { opts := &GrpcServerOptions{ - Middlewares: make([]GrpcProxyMiddleware, 0), + Middlewares: make([]grpc.StreamClientInterceptor, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), @@ -68,6 +68,8 @@ func newTestServer(gwPort, randomNumber int) (*GrpcServerOptions, *GatewayConfig opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithMarshalerOption(gwRuntime.MIMEWildcard, NewRawBinaryUnmarshaler())) opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithMarshalerOption("application/octet-stream", NewRawBinaryUnmarshaler())) opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithMarshalerOption("text/event-stream", NewEventSourceMarshaler())) + opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithStreamErrorHandler(HandleServerStreamError(Log))) + opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithMarshalerOption(ClientStreamContentType, NewProtobufWithLengthPrefix())) opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithMetadata(IncomingHeadersToMetadata)) opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithErrorHandler(HandleErrorWithLogger(Log))) @@ -131,6 +133,7 @@ func testRegisterClient(t *testing.T) { load, err := loadbalance.New(loadbalance.RRBalanceType, endps) c.So(err, c.ShouldBeNil) + gw.RegisterServiceWithProxy(pd) err = gw.RegisterService(context.Background(), pd, load) c.So(err, c.ShouldBeNil) c.So(gw.GetLoadbalanceName(), c.ShouldEqual, loadbalance.RRBalanceType) @@ -194,7 +197,7 @@ func testRequestGet(t *testing.T) { c.So(err, c.ShouldBeNil) c.So(resp2.StatusCode, c.ShouldEqual, http.StatusNotImplemented) - // test appkey + // test appkey url = fmt.Sprintf("http://127.0.0.1:%d/api/v1/example/world?msg=hello", gwPort) r, err = http.NewRequest(http.MethodGet, url, nil) r.Header.Set(XApiKey, "12345678") @@ -319,8 +322,6 @@ func testRequestPost(t *testing.T) { func testServerSideEvent(t *testing.T) { c.Convey("test server side event", t, func() { url := fmt.Sprintf("http://127.0.0.1:%d/api/v1/example/server/sse/world?msg=hello", gwPort) - // t.Logf("url:%s", url) - // time.Sleep(30 * time.Second) client := sse.NewClient(url, func(c *sse.Client) { c.ReconnectStrategy = &backoff.StopBackOff{} }) @@ -364,6 +365,7 @@ func testWebsocket(t *testing.T) { _, message, err := conn.ReadMessage() c.So(err, c.ShouldBeNil) reply := &hello.HelloReply{} + // t.Logf("read message:%s", string(message)) err = json.Unmarshal(message, reply) c.So(err, c.ShouldBeNil) c.So(reply.Message, c.ShouldEqual, fmt.Sprintf("hello-%d-%d", i, i)) @@ -661,6 +663,10 @@ func testRequestError(t *testing.T) { patch: (*GrpcLoadBalancer).Select, output: []interface{}{nil, fmt.Errorf("test select error")}, }, + { + patch: (*goloadbalancer.ConnPool).Get, + output: []interface{}{nil, fmt.Errorf("test get error")}, + }, { patch: (*GrpcProxy).forwardServerToClient, output: []interface{}{errChan}, diff --git a/gateway/mask.go b/gateway/mask.go index f62dfef..6499b65 100644 --- a/gateway/mask.go +++ b/gateway/mask.go @@ -5,7 +5,9 @@ import ( "encoding/json" "io" + common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/known/fieldmaskpb" ) @@ -22,7 +24,6 @@ func SetUpdateMaskFields(message protoreflect.ProtoMessage, fields []string) { // 遍历所有字段 for i := 0; i < md.Fields().Len(); i++ { field := md.Fields().Get(i) - // 检查字段是否是FieldMask类型 if field.Message() != nil && field.Message().FullName() == "google.protobuf.FieldMask" { // 获取字段的值(确保它是FieldMask类型) @@ -42,6 +43,21 @@ type maskDecoder struct { newDecoder func(r io.Reader) runtime.Decoder } +func FilterUnUpdatedFields(message protoreflect.ProtoMessage, mask []string) []string { + md := message.ProtoReflect().Descriptor() + filters := make([]string, 0) + for _, field := range mask { + if fd := md.Fields().ByJSONName(field); fd != nil { + // opt, ok := proto.GetExtension(fd.Options(), common.E_UnUpdatable).(bool) + // log.Printf("field:%v,ok:%v,opt:%v", fd.JSONName(),ok,opt) + + if opt, ok := proto.GetExtension(fd.Options(), common.E_UnUpdatable).(bool); !ok || !opt { + filters = append(filters, field) + } + } + } + return filters +} func NewJsonDecoder(r io.Reader) runtime.Decoder { return runtime.DecoderWrapper{Decoder: json.NewDecoder(r)} } @@ -82,6 +98,8 @@ func (d *maskDecoder) Decode(v interface{}) error { return err } // 设置更新掩码字段 + fields = FilterUnUpdatedFields(message, fields) + // log.Printf("fields:%v", fields) SetUpdateMaskFields(message, fields) return nil diff --git a/gateway/middlewares.go b/gateway/middlewares.go index 0556228..ef6ef24 100644 --- a/gateway/middlewares.go +++ b/gateway/middlewares.go @@ -9,7 +9,7 @@ import ( "time" gosdk "github.com/begonia-org/go-sdk" - _ "github.com/begonia-org/go-sdk/api" + // _ "github.com/begonia-org/go-sdk/api" common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/begonia-org/go-sdk/logger" "github.com/google/uuid" @@ -28,15 +28,15 @@ import ( ) const ( - XRequestID = "x-request-id" - XUID = "x-uid" - XAccessKey = "x-access-key" - XHttpMethod = "x-http-method" - XRemoteAddr = "x-http-forwarded-for" - XProtocol = "x-http-protocol" - XHttpURI = "x-http-uri" - XIdentity = "x-identity" - XApiKey = "x-api-key" + XRequestID = "x-request-id" + XUID = "x-uid" + XAccessKey = "x-access-key" + XHttpMethod = "x-http-method" + XRemoteAddr = "x-http-forwarded-for" + XProtocol = "x-http-protocol" + XHttpURI = "x-http-uri" + XIdentity = "x-identity" + XApiKey = "x-api-key" XIdentityType = "x-identity-type" ) @@ -117,7 +117,7 @@ func IncomingHeadersToMetadata(ctx context.Context, req *http.Request) metadata. md.Set(gosdk.GetMetadataKey(XRequestID), reqID) xuid := md.Get(XUID) accessKey := md.Get(XAccessKey) - apikey:=md.Get(XApiKey) + apikey := md.Get(XApiKey) author := "" // idType := gosdk.UidType if len(xuid) > 0 { @@ -127,7 +127,7 @@ func IncomingHeadersToMetadata(ctx context.Context, req *http.Request) metadata. author = accessKey[0] // idType = gosdk.AccessKeyType } - if author == ""&& len(apikey)>0{ + if author == "" && len(apikey) > 0 { author = apikey[0] // idType = gosdk.ApiKeyType } @@ -182,13 +182,53 @@ func (log *LoggerMiddleware) UnaryInterceptor(ctx context.Context, req interface elapsed := time.Since(now) log.logger(ctx, info.FullMethod, err, elapsed) } - }() + if md, ok := metadata.FromIncomingContext(ctx); ok { + reqId := md.Get(XRequestID) + if len(reqId) > 0 { + _ = grpc.SendHeader(ctx, metadata.Pairs(XRequestID, reqId[0])) + } + } + }() + // fmt.Printf("call logger mid\n") rsp, err = handler(ctx, req) elapsed := time.Since(now) + // fmt.Printf("logger error:%v", err) log.logger(ctx, info.FullMethod, err, elapsed) return } +func (log *LoggerMiddleware) wrapHandlerWithLogger(handler grpc.StreamHandler) grpc.StreamHandler { + return func(srv interface{}, ss grpc.ServerStream) (err error) { + now := time.Now() + defer func() { + if r := recover(); r != nil { + elapsed := time.Since(now) + method, _ := grpc.Method(ss.Context()) + err = fmt.Errorf("handle err:%v", r) + log.logger(ss.Context(), method, fmt.Errorf("handle err:%v", r), elapsed) + } + if md, ok := metadata.FromIncomingContext(ss.Context()); ok { + reqId := md.Get(XRequestID) + if len(reqId) > 0 { + _ = grpc.SendHeader(ss.Context(), metadata.Pairs(XRequestID, reqId[0])) + } + } + }() + err = handler(srv, ss) + elapsed := time.Since(now) + method, _ := grpc.Method(ss.Context()) + + log.logger(ss.Context(), method, err, elapsed) + return err + } + +} +func (log *LoggerMiddleware) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + + desc.Handler = log.wrapHandlerWithLogger(desc.Handler) + return streamer(ctx, desc, cc, method, opts...) + +} func (log *LoggerMiddleware) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { now := time.Now() defer func() { @@ -196,6 +236,12 @@ func (log *LoggerMiddleware) StreamInterceptor(srv interface{}, ss grpc.ServerSt elapsed := time.Since(now) log.logger(ss.Context(), info.FullMethod, err, elapsed) } + if md, ok := metadata.FromIncomingContext(ss.Context()); ok { + reqId := md.Get(XRequestID) + if len(reqId) > 0 { + _ = grpc.SendHeader(ss.Context(), metadata.Pairs(XRequestID, reqId[0])) + } + } }() ctx := ss.Context() @@ -216,6 +262,9 @@ func getClientMessageMap() map[int32]string { if msg := proto.GetExtension(opts, common.E_Msg); msg != nil { codes[int32(v.Number())] = msg.(string) } + if msg := proto.GetExtension(opts, common.E_Description); msg != nil { + codes[int32(v.Number())] = msg.(string) + } } return true }) @@ -237,6 +286,42 @@ func clientMessageFromCode(code codes.Code) string { } } +// HandleServerStreamError handle server stream error +// +// convert error to status.Status, and get error message from error details, +// try to get error message from error details message or ToClientMessage, if not found, get error message from error code, +// if not found, return "internal error" +func HandleServerStreamError(logger logger.Logger) runtime.StreamErrorHandlerFunc { + return func(ctx context.Context, err error) *status.Status { + if st, ok := status.FromError(err); ok { + details := st.Details() + message := clientMessageFromCode(st.Code()) + for _, detail := range details { + var errDetail *common.Errors = new(common.Errors) + + if d, ok := detail.(*common.Errors); ok { + + errDetail = d + } else if anyType, ok := detail.(*anypb.Any); ok { + if err := anyType.UnmarshalTo(errDetail); err != nil { + Log.Errorf(ctx, "unmarshal error details err:%v", err) + continue + } + } + if errDetail.Message != "" { + message = errDetail.Message + } + if errDetail.ToClientMessage != "" { + message = errDetail.ToClientMessage + } + } + return status.New(st.Code(), message) + } + return status.New(codes.Internal, fmt.Sprintf("Unknown error:%v", err)) + + } + +} func HandleErrorWithLogger(logger logger.Logger) runtime.ErrorHandlerFunc { codes := getClientMessageMap() return func(ctx context.Context, mux *runtime.ServeMux, marshaler runtime.Marshaler, w http.ResponseWriter, req *http.Request, err error) { @@ -248,7 +333,7 @@ func HandleErrorWithLogger(logger logger.Logger) runtime.ErrorHandlerFunc { "status": statusCode, }, ) - fmt.Printf("error type:%T, error:%v", err, err) + // fmt.Printf("error type:%T, error:%v", err, err) if _, ok := metadata.FromIncomingContext(ctx); !ok { md := IncomingHeadersToMetadata(ctx, req) ctx = metadata.NewIncomingContext(ctx, md) @@ -257,44 +342,53 @@ func HandleErrorWithLogger(logger logger.Logger) runtime.ErrorHandlerFunc { data := &common.HttpResponse{} data.Code = int32(common.Code_INTERNAL_ERROR) data.Message = "internal error" + // fmt.Printf("sse error type:%T, error:%v", err, err) if st, ok := status.FromError(err); ok { msg := st.Message() details := st.Details() data.Message = clientMessageFromCode(st.Code()) - // data.Data = &structpb.Struct{} for _, detail := range details { - if anyType, ok := detail.(*anypb.Any); ok { - var errDetail common.Errors - if err := anyType.UnmarshalTo(&errDetail); err == nil { - rspCode := float64(errDetail.Code) - log = log.WithFields(logrus.Fields{ - "status": int(rspCode), - "file": errDetail.File, - "line": errDetail.Line, - "fn": errDetail.Fn, - }) - - msg := codes[int32(errDetail.Code)] - if errDetail.ToClientMessage != "" { - msg = errDetail.ToClientMessage - } - - data.Code = errDetail.Code - data.Message = msg - data.Data = &structpb.Struct{} - break + var errDetail *common.Errors = new(common.Errors) + + if d, ok := detail.(*common.Errors); ok { + + errDetail = d + } else if anyType, ok := detail.(*anypb.Any); ok { + if err := anyType.UnmarshalTo(errDetail); err != nil { + log.Errorf(ctx, "error type:%T, error:%v", err, err) + continue } } else { log.Errorf(ctx, "error type:%T, error:%v", err, err) } + if errDetail.Message != "" { + rspCode := float64(errDetail.Code) + log = log.WithFields(logrus.Fields{ + "status": int(rspCode), + "file": errDetail.File, + "line": errDetail.Line, + "fn": errDetail.Fn, + }) + + msg := codes[int32(errDetail.Code)] + if errDetail.ToClientMessage != "" { + msg = errDetail.ToClientMessage + } + data.Code = errDetail.Code + data.Message = msg + data.Data = &structpb.Struct{} + break + } } + // fmt.Printf("error message:%s,err code:%d", data.Message, st.Code()) code = runtime.HTTPStatusFromCode(st.Code()) log.WithField("status", code).Errorf(ctx, msg) w.Header().Set("Content-Type", "application/json") w.WriteHeader(code) + log.Errorf(ctx, "error message:%s,err code:%d", data.Message, data.Code) bData, _ := protojson.Marshal(data) _, _ = w.Write(bData) return @@ -312,8 +406,9 @@ func writeHttpHeaders(w http.ResponseWriter, key string, value []string) { for _, v := range value { w.Header().Del(key) if v != "" { + // log.Printf("http key:%s, value:%s", httpKey, v) if strings.EqualFold(httpKey, "Content-Type") { - if v == "application/grpc" { + if strings.EqualFold(v, "application/grpc") { continue } w.Header().Set(httpKey, v) @@ -328,6 +423,9 @@ func writeHttpHeaders(w http.ResponseWriter, key string, value []string) { return } } + if strings.HasPrefix(strings.ToLower(httpKey), strings.ToLower("Grpc-")) { + continue + } headers = append(headers, http.CanonicalHeaderKey(httpKey)) w.Header().Set("Access-Control-Expose-Headers", strings.Join(headers, ",")) @@ -339,31 +437,37 @@ func writeHttpHeaders(w http.ResponseWriter, key string, value []string) { } func HttpResponseBodyModify(ctx context.Context, w http.ResponseWriter, msg proto.Message) error { httpCode := http.StatusOK - for key, value := range w.Header() { - if strings.HasPrefix(key, "Grpc-Metadata-") { + // log.Printf("response message header:%v", w.Header()) + for key := range w.Header() { + if strings.HasPrefix(http.CanonicalHeaderKey(key), http.CanonicalHeaderKey("Grpc-")) { w.Header().Del(key) + continue } - writeHttpHeaders(w, key, value) - if strings.HasSuffix(http.CanonicalHeaderKey(key), http.CanonicalHeaderKey("X-Http-Code")) { - codeStr := value[0] - code, err := strconv.ParseInt(codeStr, 10, 32) - if err != nil { - Log.Error(ctx, err) - return status.Error(codes.Internal, err.Error()) - } - httpCode = int(code) + } + if out, ok := runtime.ServerMetadataFromContext(ctx); ok { + // log.Printf("response message header from server metadata:%v", out.HeaderMD) + for key, value := range out.HeaderMD { - } + if strings.HasPrefix(strings.ToLower(key), strings.ToLower("Grpc-")) || strings.EqualFold(key, "content-type") { + continue - } + } + writeHttpHeaders(w, key, value) + if strings.HasSuffix(http.CanonicalHeaderKey(key), http.CanonicalHeaderKey("X-Http-Code")) { + codeStr := value[0] + code, err := strconv.ParseInt(codeStr, 10, 32) + if err != nil { + Log.Error(ctx, err) + return status.Error(codes.Internal, err.Error()) + } + httpCode = int(code) + + } - out, ok := metadata.FromIncomingContext(ctx) - if ok { - for k, v := range out { - writeHttpHeaders(w, k, v) } } + if httpCode != http.StatusOK { w.WriteHeader(httpCode) } diff --git a/gateway/middlewares_test.go b/gateway/middlewares_test.go index b23d342..c647f16 100644 --- a/gateway/middlewares_test.go +++ b/gateway/middlewares_test.go @@ -2,14 +2,21 @@ package gateway import ( "context" + "fmt" "net/http" "testing" + "github.com/agiledragon/gomonkey/v2" hello "github.com/begonia-org/go-sdk/api/example/v1" + common "github.com/begonia-org/go-sdk/common/api/v1" + "github.com/google/uuid" + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" c "github.com/smartystreets/goconvey/convey" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/anypb" ) type responseWriter struct { @@ -25,20 +32,20 @@ func (r *responseWriter) Write([]byte) (int, error) { func (r *responseWriter) WriteHeader(int) { } -func TestClientMessageFromCode(t *testing.T){ - c.Convey("TestClientMessageFromCode",t,func(){ - msg:=clientMessageFromCode(codes.NotFound) - c.So(msg,c.ShouldContainSubstring,"not found") +func TestClientMessageFromCode(t *testing.T) { + c.Convey("TestClientMessageFromCode", t, func() { + msg := clientMessageFromCode(codes.NotFound) + c.So(msg, c.ShouldContainSubstring, "not found") msg = clientMessageFromCode(codes.ResourceExhausted) - c.So(msg,c.ShouldContainSubstring,"resource size exceeds") + c.So(msg, c.ShouldContainSubstring, "resource size exceeds") msg = clientMessageFromCode(codes.AlreadyExists) - c.So(msg,c.ShouldContainSubstring,"already exists") + c.So(msg, c.ShouldContainSubstring, "already exists") msg = clientMessageFromCode(codes.DataLoss) - c.So(msg,c.ShouldContainSubstring,"Unknown error") + c.So(msg, c.ShouldContainSubstring, "Unknown error") }) } func TestLoggerMiddlewares(t *testing.T) { - mid :=NewLoggerMiddleware(Log) + mid := NewLoggerMiddleware(Log) c.Convey("TestLoggerMiddlewares panic", t, func() { f := func() { _, _ = mid.UnaryInterceptor(context.TODO(), nil, &grpc.UnaryServerInfo{ @@ -49,11 +56,46 @@ func TestLoggerMiddlewares(t *testing.T) { } c.So(f, c.ShouldNotPanic) f2 := func() { - _ = mid.StreamInterceptor(nil, &streamMock{}, &grpc.StreamServerInfo{FullMethod: "/test"}, func(srv interface{}, ss grpc.ServerStream) error { + _ = mid.StreamInterceptor(nil, &streamMock{ctx: context.Background()}, &grpc.StreamServerInfo{FullMethod: "/test"}, func(srv interface{}, ss grpc.ServerStream) error { panic("test") }) } + // f2() c.So(f2, c.ShouldNotPanic) + + desc := &grpc.StreamDesc{ + StreamName: "/INTEGRATION.TESTSERVICE/GET", + ClientStreams: true, + ServerStreams: true, + Handler: func(srv interface{}, ss grpc.ServerStream) error { + panic("test painc") + }, + } + patch := gomonkey.ApplyFuncReturn(grpc.Method, "/INTEGRATION.TESTSERVICE/GET", true) + defer patch.Reset() + st, err := mid.StreamClientInterceptor(context.Background(), desc, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: metadata.NewOutgoingContext(context.Background(), metadata.New(make(map[string]string)))}, nil + }) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + + // has request id + st, err = mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs(XRequestID, uuid.New().String())), desc, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: metadata.NewOutgoingContext(context.Background(), metadata.New(make(map[string]string)))}, nil + }) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + err = desc.Handler(nil, &streamMock{ctx: context.Background()}) + c.So(err, c.ShouldNotBeNil) + // no request id + st, err = mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs(XRequestID, "test")), desc, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: metadata.NewOutgoingContext(context.Background(), metadata.Pairs(XRequestID, "test"))}, nil + }) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + err = desc.Handler(nil, &streamMock{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs(XRequestID, "test"))}) + c.So(err, c.ShouldNotBeNil) + }) } @@ -71,8 +113,74 @@ func TestHttpResponseBodyModify(t *testing.T) { c.Convey("TestHttpResponseBodyModify", t, func() { resp := &responseWriter{header: make(http.Header)} - ctx:=metadata.NewIncomingContext(context.Background(),metadata.Pairs(XAccessKey,"123456")) + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(XAccessKey, "123456")) + header := metadata.New(map[string]string{"Content-Type": "application/grpc", "Grpc-Metadata-key": "value", "Grpc-Key": "grpc-val"}) + patch := gomonkey.ApplyFuncReturn(runtime.ServerMetadataFromContext, runtime.ServerMetadata{HeaderMD: header, TrailerMD: metadata.New(make(map[string]string))}, true) + defer patch.Reset() resp2 := HttpResponseBodyModify(ctx, resp, &hello.HelloReply{}) c.So(resp2, c.ShouldBeNil) + for k, v := range header { + writeHttpHeaders(resp, k, v) + } + }) +} + +func TestHandleErrorWithLogger(t *testing.T) { + c.Convey("TestHandleErrorWithLogger", t, func() { + f := HandleErrorWithLogger(Log) + resp := &responseWriter{header: make(http.Header)} + req, _ := http.NewRequest("Get", "http://www.example.com", nil) + st := status.New(codes.NotFound, "not found") + srvErr := &common.Errors{ + Code: int32(common.Code_NOT_FOUND), + Message: "not found", + Action: "action", + File: "file", + Line: int32(0), + Fn: "funcName", + } + st, _ = st.WithDetails(srvErr) + f(metadata.NewIncomingContext(context.Background(), metadata.Pairs(XRequestID, "123456")), &runtime.ServeMux{}, nil, resp, req, st.Err()) + st1 := status.New(codes.NotFound, "not found") + st1, _ = st1.WithDetails(&common.APIResponse{}) + f(metadata.NewIncomingContext(context.Background(), metadata.Pairs(XRequestID, "123456")), &runtime.ServeMux{}, nil, resp, req, st1.Err()) + + }) +} + +func TestHandleServerStreamError(t *testing.T) { + c.Convey("TestHandleServerStreamError", t, func() { + f := HandleServerStreamError(Log) + st := status.New(codes.NotFound, "not found") + srvErr := &common.Errors{ + Code: int32(common.Code_NOT_FOUND), + Message: "not found", + Action: "action", + File: "file", + Line: int32(0), + Fn: "funcName", + ToClientMessage: "not found resource", + } + st, _ = st.WithDetails(srvErr) + err := f(context.Background(), st.Err()) + c.So(err, c.ShouldNotBeNil) + c.So(err.Err().Error(), c.ShouldContainSubstring, "not found resource") + + err = f(context.Background(), status.Error(codes.Internal, "internal error")) + c.So(err, c.ShouldNotBeNil) + c.So(err.Err().Error(), c.ShouldContainSubstring, "Unknown error") + err = f(context.Background(), fmt.Errorf("test error")) + c.So(err, c.ShouldNotBeNil) + c.So(err.Err().Error(), c.ShouldContainSubstring, "test error") + + patch := gomonkey.ApplyFuncReturn((*anypb.Any).UnmarshalTo, fmt.Errorf("test error")) + defer patch.Reset() + ay, _ := anypb.New(srvErr) + st = status.New(codes.NotFound, "not found") + + st, _ = st.WithDetails(ay) + err = f(context.Background(), st.Err()) + c.So(err, c.ShouldNotBeNil) + c.So(err.Err().Error(), c.ShouldContainSubstring, "The requested resource is not found.") }) } diff --git a/gateway/protobuf.go b/gateway/protobuf.go index d1b80b9..2bce1d6 100644 --- a/gateway/protobuf.go +++ b/gateway/protobuf.go @@ -159,6 +159,7 @@ func NewDescriptionFromBinary(data []byte, outDir string) (ProtobufDescription, return nil, err } // desc.gatewayJsonSchema = filepath.Join(outDir, "gateway.json") + // log.Printf("GetFileDescriptorSet result is :%v",desc.GetFileDescriptorSet()) contents, err := register.Register(desc.GetFileDescriptorSet(), false, "") if err != nil { return nil, fmt.Errorf("Failed to register: %w", err) @@ -196,6 +197,7 @@ func (p *protobufDescription) GetMessageTypeByFullName(fullName string) protoref v := desc.(protoreflect.MessageDescriptor) return v } + // log.Printf("GetMessageTypeByFullName failed:%s", fullName) return nil } func (p *protobufDescription) GetGatewayJsonSchema() string { diff --git a/gateway/grpc.go b/gateway/proxy.go similarity index 58% rename from gateway/grpc.go rename to gateway/proxy.go index fe59bdf..77bc98b 100644 --- a/gateway/grpc.go +++ b/gateway/proxy.go @@ -2,13 +2,16 @@ package gateway import ( "context" + "errors" "fmt" "io" + "runtime/debug" "strings" "sync" "time" loadbalance "github.com/begonia-org/go-loadbalancer" + "github.com/begonia-org/go-sdk/logger" "github.com/spark-lence/tiga" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -16,7 +19,8 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" ) type grpcEndpointImpl struct { @@ -128,17 +132,34 @@ func (g *GrpcLoadBalancer) Select(method string, args ...interface{}) (loadbalan return nil, loadbalance.ErrNoEndpoint } +type IOType struct { + In protoreflect.MessageDescriptor + Out protoreflect.MessageDescriptor +} type GrpcProxyMiddleware func(srv interface{}, serverStream grpc.ServerStream) error type GrpcProxy struct { - lb *GrpcLoadBalancer - middlewares []GrpcProxyMiddleware + lb *GrpcLoadBalancer + // middlewares []grpc.StreamServerInterceptor + // chainStreamInts []grpc.StreamServerInterceptor + streamInt grpc.StreamClientInterceptor + chainClientStream []grpc.StreamClientInterceptor + ioType map[string]*IOType + log logger.Logger } -func NewGrpcProxy(lb *GrpcLoadBalancer, middlewares ...GrpcProxyMiddleware) *GrpcProxy { - return &GrpcProxy{ - lb: lb, - middlewares: middlewares, +func NewGrpcProxy(lb *GrpcLoadBalancer, log logger.Logger, middlewares ...grpc.StreamClientInterceptor) *GrpcProxy { + g := &GrpcProxy{ + lb: lb, + // middlewares: middlewares, + chainClientStream: middlewares, + ioType: make(map[string]*IOType), + log: log, } + g.chainStreamClientInterceptors() + return g +} +func (g *GrpcProxy) Register(lb loadbalance.LoadBalance, pd ProtobufDescription) { + g.lb.Register(lb, pd) } func (g *GrpcProxy) getClientIP(ctx context.Context) (string, error) { peer, ok := peer.FromContext(ctx) @@ -156,14 +177,53 @@ func (g *GrpcProxy) getXForward(ctx context.Context) []string { return md.Get("X-Forwarded-For") } -func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) error { +// chainStreamServerInterceptors chains all stream server interceptors into one. +func (g *GrpcProxy) chainStreamClientInterceptors() { + // Prepend opts.streamInt to the chaining interceptors if it exists, since streamInt will + // be executed before any other chained interceptors. + interceptors := g.chainClientStream + if g.streamInt != nil { + interceptors = append([]grpc.StreamClientInterceptor{g.streamInt}, g.chainClientStream...) + } - // 执行中间件 - for _, middleware := range g.middlewares { - if err := middleware(srv, serverStream); err != nil { - return err - } + var chainedInt grpc.StreamClientInterceptor + if len(interceptors) == 0 { + chainedInt = nil + } else if len(interceptors) == 1 { + chainedInt = interceptors[0] + } else { + chainedInt = g.chainStreamInterceptors(interceptors) } + + g.streamInt = chainedInt +} + +func (g *GrpcProxy) chainStreamInterceptors(interceptors []grpc.StreamClientInterceptor) grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return interceptors[0](ctx, desc, cc, method, g.getChainStreamHandler(interceptors, 0, streamer)) + } +} + +func (g *GrpcProxy) getChainStreamHandler(interceptors []grpc.StreamClientInterceptor, curr int, finalHandler grpc.Streamer) grpc.Streamer { + if curr == len(interceptors)-1 { + return finalHandler + } + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return interceptors[curr+1](ctx, desc, cc, method, g.getChainStreamHandler(interceptors, curr+1, finalHandler)) + } +} + +func (g *GrpcProxy) Do(srv interface{}, serverStream grpc.ServerStream) error { + + return g.Handler(srv, serverStream) +} +func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) error { + defer func() { + if p := recover(); p != nil { + s := debug.Stack() + g.log.Errorf(serverStream.Context(), "panic recover! p: %v stack:%s", p, s) + } + }() // 获取方法名 fullMethodName, ok := grpc.MethodFromServerStream(serverStream) if !ok { @@ -193,6 +253,7 @@ func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) err defer endpoint.AfterTransform(serverStream.Context(), cn.((loadbalance.Connection))) conn := cn.(loadbalance.Connection).ConnInstance().(*grpc.ClientConn) + clientCtx, clientCancel := context.WithCancel(serverStream.Context()) defer clientCancel() proxyDesc := &grpc.StreamDesc{ @@ -202,38 +263,53 @@ func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) err // 添加本地ip local, _ := tiga.GetLocalIP() xForwards = append(xForwards, local) - md, ok := metadata.FromIncomingContext(clientCtx) + md, ok := metadata.FromIncomingContext(serverStream.Context()) if !ok { md = metadata.MD{} } + md.Set("X-Forwarded-For", strings.Join(xForwards, ",")) clientCtx = metadata.NewOutgoingContext(clientCtx, md) + var clientStream grpc.ClientStream = nil - clientStream, err := grpc.NewClientStream(clientCtx, proxyDesc, conn, fullMethodName) - if err != nil { - return err + if len(g.chainClientStream) > 0 && g.streamInt != nil { + clientStream, err = g.streamInt(clientCtx, proxyDesc, conn, fullMethodName, g.getChainStreamHandler(g.chainClientStream, 0, func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return grpc.NewClientStream(ctx, desc, cc, method, opts...) + })) + if err != nil { + return fmt.Errorf("failed creating client stream from stream client int: %v", err) + } + } else { + clientStream, err = grpc.NewClientStream(clientCtx, proxyDesc, conn, fullMethodName) + if err != nil { + return err + } } + + // proxyStream := &proxyClientStream{ClientStream: clientStream, ctx: clientCtx} // 转发流量 // 从客户端到服务端 s2cErrChan := g.forwardServerToClient(serverStream, clientStream) + // 从服务端到客户端 c2sErrChan := g.forwardClientToServer(clientStream, serverStream) for i := 0; i < 2; i++ { select { case s2cErr := <-s2cErrChan: - if s2cErr == io.EOF { + if errors.Is(s2cErr, io.EOF) { // this is the happy case where the sender has encountered io.EOF, and won't be sending anymore./ // the clientStream>serverStream may continue pumping though. + // log.Printf("s2cErr:%v", s2cErr) err = clientStream.CloseSend() if err != nil { - return status.Errorf(codes.Internal, "failed closing client stream: %v", err) + return fmt.Errorf("failed closing client stream: %w", err) } } else { // however, we may have gotten a receive error (stream disconnected, a read error etc) in which case we need // to cancel the clientStream to the backend, let all of its goroutines be freed up by the CancelFunc and // exit with an error to the stack clientCancel() - return status.Errorf(codes.Internal, "failed proxying s2c: %v", s2cErr) + return fmt.Errorf("failed proxying s2c: %w", s2cErr) } case c2sErr := <-c2sErrChan: // This happens when the clientStream has nothing else to offer (io.EOF), returned a gRPC error. In those two @@ -241,7 +317,8 @@ func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) err // will be nil. serverStream.SetTrailer(clientStream.Trailer()) // c2sErr will contain RPC error from client code. If not io.EOF return the RPC error as server stream error. - if c2sErr != io.EOF { + if !errors.Is(c2sErr, io.EOF) { + // log.Printf("c2sErr:%v", c2sErr) return c2sErr } return nil @@ -253,10 +330,21 @@ func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) err func (g *GrpcProxy) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream) chan error { ret := make(chan error, 1) go func() { - f := &emptypb.Empty{} + defer func() { + if p := recover(); p != nil { + s := debug.Stack() + g.log.Errorf(dst.Context(), "panic recover! p: %v stack:%s", p, s) + ret <- fmt.Errorf("panic recover! p: %v stack:%s", p, s) + } + }() + method, _ := grpc.Method(dst.Context()) + io := g.ioType[method] + f := dynamicpb.NewMessage(io.Out) + // f := &emptypb.Empty{} + for i := 0; ; i++ { if err := src.RecvMsg(f); err != nil { - ret <- err // this can be io.EOF which is happy case + ret <- fmt.Errorf("fail forward client to server:%w", err) // this can be io.EOF which is happy case break } if i == 0 { @@ -264,18 +352,25 @@ func (g *GrpcProxy) forwardClientToServer(src grpc.ClientStream, dst grpc.Server // received but must be written to server stream before the first msg is flushed. // This is the only place to do it nicely // 先转发header + // inMD, _ := metadata.FromIncomingContext(src.Context()) + // outMD, _ := metadata.FromOutgoingContext(src.Context()) + + // md := metadata.Join(inMD, outMD) + // m := make(map[string][]string) + md, err := src.Header() if err != nil { - ret <- err + ret <- fmt.Errorf("failed reading header from client stream: %w", err) break } + if err := dst.SendHeader(md); err != nil { - ret <- err + ret <- fmt.Errorf("failed sending header to server stream: %w", err) break } } if err := dst.SendMsg(f); err != nil { - ret <- err + ret <- fmt.Errorf("failed sending msg to server stream: %w", err) break } } @@ -286,17 +381,46 @@ func (g *GrpcProxy) forwardClientToServer(src grpc.ClientStream, dst grpc.Server func (g *GrpcProxy) forwardServerToClient(src grpc.ServerStream, dst grpc.ClientStream) chan error { ret := make(chan error, 1) go func() { - f := &emptypb.Empty{} + method, _ := grpc.Method(dst.Context()) + io := g.ioType[method] + f := dynamicpb.NewMessage(io.In) + for i := 0; ; i++ { if err := src.RecvMsg(f); err != nil { - ret <- err // this can be io.EOF which is happy case + ret <- fmt.Errorf("recv msg error:%w from client forward", err) // this can be io.EOF which is happy case break } + if err := dst.SendMsg(f); err != nil { - ret <- err + ret <- fmt.Errorf("failed sending msg to client stream: %w", err) break } } }() return ret } + +func (g *GrpcProxy) buildServiceDesc(pd ProtobufDescription) { + fd := pd.GetFileDescriptorSet() + // sds := make([]*grpc.ServiceDesc, 0) + for _, file := range fd.GetFile() { + sd := file.GetService() + + for _, service := range sd { + + methods := service.GetMethod() + for _, method := range methods { + in := method.GetInputType() + inDesc := pd.GetMessageTypeByFullName(strings.TrimPrefix(in, ".")) + outDesc := pd.GetMessageTypeByFullName(strings.TrimPrefix(method.GetOutputType(), ".")) + srv := fmt.Sprintf("/%s.%s/%s", file.GetPackage(), service.GetName(), method.GetName()) + g.ioType[srv] = &IOType{ + In: inDesc, + Out: outDesc, + } + + } + + } + } +} diff --git a/gateway/grpc_test.go b/gateway/proxy_test.go similarity index 63% rename from gateway/grpc_test.go rename to gateway/proxy_test.go index 2bbacbb..2acf50e 100644 --- a/gateway/grpc_test.go +++ b/gateway/proxy_test.go @@ -14,6 +14,8 @@ import ( "time" "github.com/agiledragon/gomonkey/v2" + cfg "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/internal/pkg/config" loadbalance "github.com/begonia-org/go-loadbalancer" api "github.com/begonia-org/go-sdk/api/endpoint/v1" hello "github.com/begonia-org/go-sdk/api/example/v1" @@ -26,8 +28,11 @@ import ( ) type streamMock struct { + ctx context.Context +} +type clientStreamMock struct { + ctx context.Context } -type clientStreamMock struct{} func (*streamMock) SendHeader(md metadata.MD) error { return nil @@ -37,10 +42,11 @@ func (*streamMock) SetHeader(md metadata.MD) error { } func (*streamMock) SetTrailer(md metadata.MD) { } -func (*streamMock) Context() context.Context { - return context.Background() +func (s *streamMock) Context() context.Context { + return s.ctx } func (*streamMock) SendMsg(m interface{}) error { + time.Sleep(1 * time.Second) return nil } func (*streamMock) RecvMsg(m interface{}) error { @@ -94,13 +100,27 @@ func TestGrpcHandleErr(t *testing.T) { load, _ := loadbalance.New(loadbalance.RRBalanceType, endps) lb := NewGrpcLoadBalancer() lb.Register(load, pd) - mid := func(srv interface{}, serverStream grpc.ServerStream) error { - return nil + fullMethod1 := "" + mid := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + fullMethod1 = method + return streamer(ctx, desc, cc, method, opts...) + } + fullMethod2 := "" + mid2 := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + fullMethod2 = method + return streamer(ctx, desc, cc, method, opts...) + } + proxy := NewGrpcProxy(lb, Log, mid, mid2) + proxy.buildServiceDesc(pd) + proxy.Register(load, pd) + + stream := &streamMock{ + ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("uri", "/api/v1/example/server/websocket")), } - proxy := NewGrpcProxy(lb, mid) - stream := &streamMock{} patch := gomonkey.ApplyFuncReturn(grpc.MethodFromServerStream, strings.ToUpper("/helloworld.Greeter/SayHelloWebsocket"), true) - patch.ApplyFuncReturn(grpc.NewClientStream, &clientStreamMock{}, nil) + patch.ApplyFuncReturn(grpc.NewClientStream, &clientStreamMock{ + ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("uri", "/api/v1/example/server/websocket")), + }, nil) addrs, _ := net.InterfaceAddrs() var localAddr net.Addr for _, addr := range addrs { @@ -121,10 +141,15 @@ func TestGrpcHandleErr(t *testing.T) { err error output []interface{} }{ + // { + // patch: metadata.FromIncomingContext, + // err: fmt.Errorf("metadata not exists in context"), + // output: []interface{}{nil, false}, + // }, { patch: mid, err: fmt.Errorf("mid handle err"), - output: []interface{}{fmt.Errorf("mid handle err")}, + output: []interface{}{nil, fmt.Errorf("mid handle err")}, }, { patch: (*clientStreamMock).CloseSend, @@ -153,14 +178,21 @@ func TestGrpcHandleErr(t *testing.T) { return io.EOF }) defer patch3.Reset() + patch6 := gomonkey.ApplyFuncReturn(grpc.Method, "/helloworld.Greeter/SayHelloWebsocket", true) + defer patch6.Reset() for _, caseV := range cases { patch2 := gomonkey.ApplyFuncReturn(caseV.patch, caseV.output...) defer patch2.Reset() - err = proxy.Handler(&hello.HelloRequest{}, stream) + + err = proxy.Do(&hello.HelloRequest{}, stream) + t.Log(caseV.err.Error()) + // t.Logf("err:%v", err.Error()) c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, caseV.err.Error()) patch2.Reset() } + c.So(fullMethod1, c.ShouldEqual, strings.ToUpper("/helloworld.Greeter/SayHelloWebsocket")) + c.So(fullMethod2, c.ShouldEqual, strings.ToUpper("/helloworld.Greeter/SayHelloWebsocket")) patch3.Reset() errChan2 := make(chan error, 3) @@ -168,12 +200,33 @@ func TestGrpcHandleErr(t *testing.T) { errChan2 <- io.EOF errChan2 <- io.EOF patch4 := gomonkey.ApplyFuncReturn((*GrpcProxy).forwardServerToClient, errChan2) - err = proxy.Handler(&hello.HelloRequest{}, stream) + err = proxy.Do(&hello.HelloRequest{}, stream) c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "proxying should never reach") defer patch4.Reset() patch.Reset() + patch4.Reset() + // patch6.Reset() + patch7 := gomonkey.ApplyFuncReturn(grpc.MethodFromServerStream, "/helloworld.Greeter/SayHelloWebsocket", false) + defer patch7.Reset() + + proxy2 := NewGrpcProxy(lb, Log, mid) + err = proxy2.Do(&hello.HelloRequest{}, stream) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "stream not exists in context") + + proxy2.chainStreamClientInterceptors() + f := func() { + proxy.forwardClientToServer(nil, stream) + } + c.So(f, c.ShouldNotPanic) }) } + +func TestProxyDo(t *testing.T) { + pd, _ := readDesc(config.NewConfig(cfg.ReadConfig("test"))) + p := &GrpcProxy{ioType: make(map[string]*IOType), lb: NewGrpcLoadBalancer(), log: Log} + p.buildServiceDesc(pd) +} diff --git a/internal/pkg/routers/routers.go b/gateway/routers.go similarity index 74% rename from internal/pkg/routers/routers.go rename to gateway/routers.go index 3e0d861..b1377e7 100644 --- a/internal/pkg/routers/routers.go +++ b/gateway/routers.go @@ -1,19 +1,10 @@ -package routers +package gateway import ( "fmt" "strings" "sync" - "github.com/begonia-org/begonia/gateway" - _ "github.com/begonia-org/go-sdk/api/app/v1" - _ "github.com/begonia-org/go-sdk/api/endpoint/v1" - _ "github.com/begonia-org/go-sdk/api/example/v1" - _ "github.com/begonia-org/go-sdk/api/iam/v1" - _ "github.com/begonia-org/go-sdk/api/plugin/v1" - _ "github.com/begonia-org/go-sdk/api/sys/v1" - _ "github.com/begonia-org/go-sdk/api/user/v1" - _ "github.com/begonia-org/go-sdk/common/api/v1" common "github.com/begonia-org/go-sdk/common/api/v1" "google.golang.org/genproto/googleapis/api/annotations" "google.golang.org/protobuf/proto" @@ -52,7 +43,7 @@ func NewHttpURIRouteToSrvMethod() *HttpURIRouteToSrvMethod { }) return httpURIRouteToSrvMethod } -func Get() *HttpURIRouteToSrvMethod { +func GetRouter() *HttpURIRouteToSrvMethod { return NewHttpURIRouteToSrvMethod() } @@ -60,7 +51,8 @@ func (r *HttpURIRouteToSrvMethod) AddRoute(uri string, srvMethod *APIMethodDetai r.mux.Lock() defer r.mux.Unlock() r.routers[uri] = srvMethod - r.grpcRouter[srvMethod.GrpcFullRouter] = srvMethod + // log.Printf("add srv method grpc router:%s,pointer:%p", srvMethod.GrpcFullRouter, r) + r.grpcRouter[strings.ToUpper(srvMethod.GrpcFullRouter)] = srvMethod } func (r *HttpURIRouteToSrvMethod) deleteRoute(uri string, grpcFullMethod string) { delete(r.routers, uri) @@ -71,6 +63,7 @@ func (r *HttpURIRouteToSrvMethod) GetRoute(uri string) *APIMethodDetails { return r.routers[uri] } func (r *HttpURIRouteToSrvMethod) GetRouteByGrpcMethod(method string) *APIMethodDetails { + // log.Printf("get grpc method,%v:%s,pointer:%p",r.grpcRouter,strings.ToUpper(method),r) return r.grpcRouter[strings.ToUpper(method)] } func (r *HttpURIRouteToSrvMethod) GetAllRoutes() map[string]*APIMethodDetails { @@ -85,7 +78,14 @@ func (r *HttpURIRouteToSrvMethod) getServiceOptionByExt(service *descriptorpb.Se } return nil } - +func (r *HttpURIRouteToSrvMethod) getMethodOptionByExt(method *descriptorpb.MethodDescriptorProto, ext protoreflect.ExtensionType) interface{} { + if options := method.GetOptions(); options != nil { + if ext := proto.GetExtension(options, ext); ext != nil { + return ext + } + } + return nil +} func (r *HttpURIRouteToSrvMethod) getHttpRule(method *descriptorpb.MethodDescriptorProto) *annotations.HttpRule { if options := method.GetOptions(); options != nil { if ext := proto.GetExtension(options, annotations.E_Http); ext != nil { @@ -97,6 +97,7 @@ func (r *HttpURIRouteToSrvMethod) getHttpRule(method *descriptorpb.MethodDescrip return nil } func (r *HttpURIRouteToSrvMethod) AddLocalSrv(fullMethod string) { + // log.Printf("add local srv:%s", fullMethod) r.localSrv[strings.ToUpper(fullMethod)] = true } func (r *HttpURIRouteToSrvMethod) IsLocalSrv(fullMethod string) bool { @@ -154,7 +155,13 @@ func (r *HttpURIRouteToSrvMethod) addRouterDetails(serviceName string, useJsonRe } } -func (r *HttpURIRouteToSrvMethod) LoadAllRouters(pd gateway.ProtobufDescription) { + +// LoadAllRouters Load all routers from protobuf description +// for service methods, if the method has a google.api.http annotation, then add the router +// to the router list, and set the authRequired flag to true if the method has a pb.auth_required annotation, +// if the method has a pb.http_response annotation, then set the useJsonResponse flag to true, +// if the method has a pb.dont_use_http_response annotation, then set the useJsonResponse flag to false. +func (r *HttpURIRouteToSrvMethod) LoadAllRouters(pd ProtobufDescription) { fds := pd.GetFileDescriptorSet() for _, fd := range fds.File { for _, service := range fd.Service { @@ -165,13 +172,20 @@ func (r *HttpURIRouteToSrvMethod) LoadAllRouters(pd gateway.ProtobufDescription) if authRequiredExt := r.getServiceOptionByExt(service, common.E_AuthReqiured); authRequiredExt != nil { authRequired, _ = authRequiredExt.(bool) } - if httpResponseExt := r.getServiceOptionByExt(service, common.E_HttpResponse); httpResponseExt != nil { + if httpResponseExt := r.getServiceOptionByExt(service, common.E_HttpResponse); httpResponseExt != nil && httpResponseExt.(string) != "" { httpResponse = true } // 遍历服务中的所有方法 for _, method := range service.GetMethod() { key := fmt.Sprintf("/%s.%s/%s", fd.GetPackage(), service.GetName(), method.GetName()) - r.addRouterDetails(strings.ToUpper(key), httpResponse, authRequired, method) + // log.Printf("add router:%s,%v", key, httpResponse) + // do not use HttpResponse for this method if it is set + dontUseHttpResponse := r.getMethodOptionByExt(method, common.E_DontUseHttpResponse) + useHttpResponse := httpResponse + if dontUseHttpResponse != nil && dontUseHttpResponse.(bool) { + useHttpResponse = false + } + r.addRouterDetails(strings.ToUpper(key), useHttpResponse, authRequired, method) } } @@ -179,7 +193,7 @@ func (r *HttpURIRouteToSrvMethod) LoadAllRouters(pd gateway.ProtobufDescription) } -func (h *HttpURIRouteToSrvMethod) DeleteRouters(pd gateway.ProtobufDescription) { +func (h *HttpURIRouteToSrvMethod) DeleteRouters(pd ProtobufDescription) { fds := pd.GetFileDescriptorSet() for _, fd := range fds.File { for _, service := range fd.Service { diff --git a/internal/pkg/routers/routers_test.go b/gateway/routers_test.go similarity index 75% rename from internal/pkg/routers/routers_test.go rename to gateway/routers_test.go index c07cd8d..ed3e8b5 100644 --- a/internal/pkg/routers/routers_test.go +++ b/gateway/routers_test.go @@ -1,4 +1,4 @@ -package routers_test +package gateway_test import ( "path/filepath" @@ -6,15 +6,14 @@ import ( "testing" "github.com/begonia-org/begonia/gateway" - "github.com/begonia-org/begonia/internal/pkg/routers" c "github.com/smartystreets/goconvey/convey" ) func TestLoadAllRouters(t *testing.T) { c.Convey("TestLoadAllRouters", t, func() { - R := routers.NewHttpURIRouteToSrvMethod() + R := gateway.NewHttpURIRouteToSrvMethod() _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filename)), "testdata") pd, err := gateway.NewDescription(pbFile) c.So(err, c.ShouldBeNil) R.LoadAllRouters(pd) @@ -32,14 +31,19 @@ func TestLoadAllRouters(t *testing.T) { d, ok := rs["/test/custom"] c.So(ok, c.ShouldBeTrue) c.So(d.ServiceName, c.ShouldEqual, "/INTEGRATION.TESTSERVICE/CUSTOM") + c.So(d.UseJsonResponse, c.ShouldBeTrue) + + r := rs["/test/body"] + c.So(r, c.ShouldNotBeNil) + c.So(r.UseJsonResponse, c.ShouldBeFalse) }) } func TestDeleteRouters(t *testing.T) { c.Convey("TestDeleteRouters", t, func() { - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") + pbFile := filepath.Join((filepath.Dir(filepath.Dir(filename))), "testdata") pd, err := gateway.NewDescription(pbFile) c.So(err, c.ShouldBeNil) diff --git a/gateway/serialization.go b/gateway/serialization.go index d474925..65f22f2 100644 --- a/gateway/serialization.go +++ b/gateway/serialization.go @@ -15,6 +15,7 @@ import ( common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "google.golang.org/genproto/googleapis/api/httpbody" + spb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" @@ -36,9 +37,6 @@ type BinaryDecoder struct { marshaler runtime.Marshaler } -// var typeOfBytes = reflect.TypeOf([]byte(nil)) -// var typeOfHttpbody = reflect.TypeOf(&httpbody.HttpBody{}) - func (d *BinaryDecoder) Decode(v interface{}) error { if v == nil { return nil @@ -131,6 +129,11 @@ func (m *RawBinaryUnmarshaler) Marshal(v interface{}) ([]byte, error) { } } } + if resp, ok := v.(map[string]interface{}); ok { + if _, ok := resp["result"]; ok { + v = resp["result"] + } + } return m.Marshaler.Marshal(v) } @@ -140,17 +143,24 @@ func (m *EventSourceMarshaler) ContentType(v interface{}) string { func (m *EventSourceMarshaler) Marshal(v interface{}) ([]byte, error) { if response, ok := v.(map[string]interface{}); ok { - // result:=response if _, ok := response["result"]; ok { v = response["result"] } - } - // 在这里定义你的自定义序列化逻辑 + if response, ok := v.(map[string]proto.Message); ok { + if _, ok := response["error"]; ok { + v = response["error"] + } + } + // build event stream format line by line if stream, ok := v.(*common.EventStream); ok { line := fmt.Sprintf("id: %d\nevent: %s\nretry: %d\ndata: %s\n", stream.Id, stream.Event, stream.Retry, stream.Data) return []byte(line), nil - + } + // build error message + if stream, ok := v.(*spb.Status); ok { + line := fmt.Sprintf("id: %d\nevent: %s\nretry: %d\ndata: %s\n", 0, "error", 0, stream.GetMessage()) + return []byte(line), nil } return m.JSONPb.Marshal(v) } @@ -183,7 +193,6 @@ func (m *JSONMarshaler) Marshal(v interface{}) ([]byte, error) { } } - if response, ok := v.(*dynamicpb.Message); ok { // log.Println("实际类型,", response.Type().Descriptor().Name()) byteData, err := m.JSONPb.Marshal(response) @@ -198,3 +207,6 @@ func (m *JSONMarshaler) Marshal(v interface{}) ([]byte, error) { func (m *JSONMarshaler) ContentType(v interface{}) string { return "application/json" } +func (m *JSONMarshaler) NewDecoder(r io.Reader) runtime.Decoder { + return NewMaskDecoder(m.JSONPb.NewDecoder(r)) +} diff --git a/gateway/serialization_test.go b/gateway/serialization_test.go index 4eab9e5..0040e03 100644 --- a/gateway/serialization_test.go +++ b/gateway/serialization_test.go @@ -7,10 +7,15 @@ import ( "testing" "github.com/agiledragon/gomonkey/v2" + common "github.com/begonia-org/go-sdk/common/api/v1" + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" c "github.com/smartystreets/goconvey/convey" "google.golang.org/genproto/googleapis/api/httpbody" + spb "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc/codes" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/dynamicpb" + "google.golang.org/protobuf/types/known/anypb" ) func TestRawBinaryUnmarshaler(t *testing.T) { @@ -108,3 +113,75 @@ func TestRawBinaryDecodeErr(t *testing.T) { }) } +func TestJSONMarshaler(t *testing.T) { + c.Convey("TestJSONMarshaler", t, func() { + marshaler := NewJSONMarshaler() + data := map[string]interface{}{ + "test": "test", + } + buf, err := marshaler.Marshal(data) + c.So(err, c.ShouldBeNil) + c.So(string(buf), c.ShouldEqual, `{"test":"test"}`) + + httpBody := &httpbody.HttpBody{ + ContentType: "application/octet-stream-test", + Data: []byte("test"), + } + msg2 := dynamicpb.NewMessage(httpBody.ProtoReflect().Descriptor()).New() + patch := gomonkey.ApplyFuncReturn((*runtime.JSONPb).Marshal, nil, fmt.Errorf("runtime.JSONPb{}.Marshal: nil")) + defer patch.Reset() + _, err = marshaler.Marshal(msg2) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "runtime.JSONPb{}.Marshal: nil") + }) +} + +func TestEventSourceMarshaler(t *testing.T) { + c.Convey("TestEventSourceMarshaler", t, func() { + marshaler := NewEventSourceMarshaler() + cases := []struct { + data interface{} + err error + exception string + }{ + { + data: map[string]interface{}{ + "result": &common.EventStream{ + Event: "test", + Id: 1, + Data: "test", + Retry: 0, + }, + }, + err: nil, + exception: fmt.Sprintf("id: %d\nevent: %s\nretry: %d\ndata: %s\n", 1, "test", 0, "test"), + }, + { + data: &common.EventStream{ + Event: "test-data", + Id: 1, + Data: "test-data", + Retry: 0, + }, + err: nil, + exception: fmt.Sprintf("id: %d\nevent: %s\nretry: %d\ndata: %s\n", 1, "test-data", 0, "test-data"), + }, + { + data: map[string]proto.Message{ + "error": &spb.Status{ + Message: "test error", + Code: int32(codes.Internal), + Details: []*anypb.Any{}, + }, + }, + err: nil, + exception: fmt.Sprintf("id: %d\nevent: %s\nretry: %d\ndata: %s\n", 0, "error", 0, "test error"), + }, + } + for _, caseV := range cases { + buf, err := marshaler.Marshal(caseV.data) + c.So(err, c.ShouldBeNil) + c.So(string(buf), c.ShouldEqual, caseV.exception) + } + }) +} diff --git a/gateway/types.go b/gateway/types.go index ed309ac..924f5ed 100644 --- a/gateway/types.go +++ b/gateway/types.go @@ -47,7 +47,7 @@ func (x *serverSideStreamClient) buildEventStreamResponse(dpm *dynamicpb.Message return nil, err } - + // log.Printf("buildEventStreamResponse data:%s", string(data)) commonEvent := &common.EventStream{ Event: string(dpm.Descriptor().Name()), Id: atomic.LoadInt64(&x.ID), diff --git a/gateway/utils_test.go b/gateway/utils_test.go index b368ed0..ee16dc0 100644 --- a/gateway/utils_test.go +++ b/gateway/utils_test.go @@ -16,7 +16,7 @@ import ( func TestNewEndpoint(t *testing.T) { opts := &gateway.GrpcServerOptions{ - Middlewares: make([]gateway.GrpcProxyMiddleware, 0), + Middlewares: make([]grpc.StreamClientInterceptor, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), diff --git a/gateway/websocket_test.go b/gateway/websocket_test.go index d678e9b..0291af5 100644 --- a/gateway/websocket_test.go +++ b/gateway/websocket_test.go @@ -14,9 +14,9 @@ func TestWebsocketForwarder(t *testing.T) { wk := &websocketForwarder{ websocket: &websocket.Conn{}, } - patch:=gomonkey.ApplyFuncReturn((*websocket.Conn).WriteMessage, fmt.Errorf("write error")) + patch := gomonkey.ApplyFuncReturn((*websocket.Conn).WriteMessage, fmt.Errorf("write error")) defer patch.Reset() - _,err := wk.Write([]byte("test")) + _, err := wk.Write([]byte("test")) c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "write error") patch.Reset() diff --git a/go.mod b/go.mod index 8517228..c761a1c 100644 --- a/go.mod +++ b/go.mod @@ -10,16 +10,16 @@ require ( github.com/cockroachdb/errors v1.11.1 github.com/google/wire v0.6.0 github.com/smartystreets/goconvey v1.8.1 - github.com/spark-lence/tiga v0.0.0-20240628071333-f5e34bac7593 + github.com/spark-lence/tiga v0.0.0-20240714083240-b1dae0a443cd github.com/spf13/cobra v1.8.0 - google.golang.org/genproto/googleapis/api v0.0.0-20240624140628-dc46fd24d27d - google.golang.org/grpc v1.64.0 + google.golang.org/genproto/googleapis/api v0.0.0-20240730163845-b1a4ccb954bf + google.golang.org/grpc v1.65.0 google.golang.org/protobuf v1.34.2 ) require ( github.com/allegro/bigcache/v3 v3.1.0 // indirect - golang.org/x/net v0.26.0 + golang.org/x/net v0.27.0 ) require ( @@ -50,7 +50,7 @@ require ( github.com/oapi-codegen/runtime v1.1.1 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect - github.com/redis/go-redis/v9 v9.5.3 + github.com/redis/go-redis/v9 v9.6.1 github.com/sagikazarmark/locafero v0.6.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/segmentio/kafka-go v0.4.47 // indirect @@ -65,30 +65,32 @@ require ( github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect - github.com/youmark/pkcs8 v0.0.0-20240424034433-3c2c7870ae76 // indirect + github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect go.mongodb.org/mongo-driver v1.16.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/crypto v0.24.0 // indirect - golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 + golang.org/x/crypto v0.25.0 // indirect + golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 golang.org/x/sync v0.7.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect gorm.io/driver/mysql v1.5.7 // indirect - gorm.io/gorm v1.25.10 + gorm.io/gorm v1.25.11 ) require ( github.com/agiledragon/gomonkey/v2 v2.11.0 github.com/begonia-org/go-loadbalancer v0.0.0-20240519060752-71ca464f0f1a - github.com/begonia-org/go-sdk v0.0.0-20240628071225-2864d45934ea + github.com/begonia-org/go-sdk v0.0.0-20240731081246-9990574a3916 github.com/go-git/go-git/v5 v5.11.0 github.com/go-playground/validator/v10 v10.19.0 github.com/gorilla/websocket v1.5.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 + github.com/iancoleman/strcase v0.3.0 github.com/minio/minio-go/v7 v7.0.71 github.com/r3labs/sse/v2 v2.10.0 - go.etcd.io/etcd/api/v3 v3.5.14 - go.etcd.io/etcd/client/v3 v3.5.14 + go.etcd.io/etcd/api/v3 v3.5.15 + go.etcd.io/etcd/client/v3 v3.5.15 + google.golang.org/genproto/googleapis/rpc v0.0.0-20240730163845-b1a4ccb954bf gopkg.in/cenkalti/backoff.v1 v1.1.0 ) @@ -135,11 +137,10 @@ require ( github.com/skeema/knownhosts v1.2.2 // indirect github.com/smarty/assertions v1.15.0 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect - go.etcd.io/etcd/client/pkg/v3 v3.5.14 // indirect + go.etcd.io/etcd/client/pkg/v3 v3.5.15 // indirect go.uber.org/zap v1.27.0 // indirect - golang.org/x/sys v0.21.0 // indirect + golang.org/x/sys v0.22.0 // indirect golang.org/x/text v0.16.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect gopkg.in/warnings.v0 v0.1.2 // indirect ) diff --git a/go.sum b/go.sum index db5dd88..c2a105a 100644 --- a/go.sum +++ b/go.sum @@ -22,8 +22,8 @@ github.com/begonia-org/go-layered-cache v0.0.0-20240510102605-41bdb7aa07fa h1:DH github.com/begonia-org/go-layered-cache v0.0.0-20240510102605-41bdb7aa07fa/go.mod h1:xEqoca1vNGqH8CV7X9EzhDV5Ihtq9J95p7ZipzUB6pc= github.com/begonia-org/go-loadbalancer v0.0.0-20240519060752-71ca464f0f1a h1:Mpw7T+90KC5QW7yCa8Nn/5psnlvsexipAOrQAcc7YE0= github.com/begonia-org/go-loadbalancer v0.0.0-20240519060752-71ca464f0f1a/go.mod h1:crPS67sfgmgv47psftwfmTMbmTfdepVm8MPeqApINlI= -github.com/begonia-org/go-sdk v0.0.0-20240628071225-2864d45934ea h1:jdDBLZVsGKfmF/V+U4WUrz4Hzd3/62GqMIly9gnSppw= -github.com/begonia-org/go-sdk v0.0.0-20240628071225-2864d45934ea/go.mod h1:I70a3fiAADGrOoOC3lv408rFcTRhTwLt3pwr6cQwB4Y= +github.com/begonia-org/go-sdk v0.0.0-20240731081246-9990574a3916 h1:nXTX0vRd0SlgkmWg9/UIWCJAdm02yXMX5/2ALzR3MlM= +github.com/begonia-org/go-sdk v0.0.0-20240731081246-9990574a3916/go.mod h1:2mHpFudwolu6RHF18EX+lnFYyTNnwDxBD6JcfRcahz8= github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -123,6 +123,8 @@ github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/C github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= +github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/influxdata/influxdb-client-go/v2 v2.13.0 h1:ioBbLmR5NMbAjP4UVA5r9b5xGjpABD7j65pI8kFphDM= @@ -200,8 +202,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/r3labs/sse/v2 v2.10.0 h1:hFEkLLFY4LDifoHdiCN/LlGBAdVJYsANaLqNYa1l/v0= github.com/r3labs/sse/v2 v2.10.0/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktEmkNJ7I= -github.com/redis/go-redis/v9 v9.5.3 h1:fOAp1/uJG+ZtcITgZOfYFmTKPE7n4Vclj1wZFgRciUU= -github.com/redis/go-redis/v9 v9.5.3/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= +github.com/redis/go-redis/v9 v9.6.1 h1:HHDteefn6ZkTtY5fGUE8tj8uy85AHk6zP7CpzIAM0y4= +github.com/redis/go-redis/v9 v9.6.1/go.mod h1:0C0c6ycQsdpVNQpxb1njEQIqkx5UcsM8FJCQLgE9+RA= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= @@ -229,8 +231,8 @@ github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sS github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= -github.com/spark-lence/tiga v0.0.0-20240628071333-f5e34bac7593 h1:wbjjiTqJWiAIQemNAE1Nw71roZQ8sshGJ70XHwBx8b4= -github.com/spark-lence/tiga v0.0.0-20240628071333-f5e34bac7593/go.mod h1:MSL8X9t+qvpQ4Tq3vVPKncq9RJcCzF2XGEWkCuNhm6Q= +github.com/spark-lence/tiga v0.0.0-20240714083240-b1dae0a443cd h1:HCOerDD33LOuD7sA8G5E2CD42AX29hK7pDJID7GSc9c= +github.com/spark-lence/tiga v0.0.0-20240714083240-b1dae0a443cd/go.mod h1:h7BTZeR6xD6+tr3ClEhHC1PeXPOn3jRt7NnThQg1JvE= github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= @@ -268,17 +270,17 @@ github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= -github.com/youmark/pkcs8 v0.0.0-20240424034433-3c2c7870ae76 h1:tBiBTKHnIjovYoLX/TPkcf+OjqqKGQrPtGT3Foz+Pgo= -github.com/youmark/pkcs8 v0.0.0-20240424034433-3c2c7870ae76/go.mod h1:SQliXeA7Dhkt//vS29v3zpbEwoa+zb2Cn5xj5uO4K5U= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.etcd.io/etcd/api/v3 v3.5.14 h1:vHObSCxyB9zlF60w7qzAdTcGaglbJOpSj1Xj9+WGxq0= -go.etcd.io/etcd/api/v3 v3.5.14/go.mod h1:BmtWcRlQvwa1h3G2jvKYwIQy4PkHlDej5t7uLMUdJUU= -go.etcd.io/etcd/client/pkg/v3 v3.5.14 h1:SaNH6Y+rVEdxfpA2Jr5wkEvN6Zykme5+YnbCkxvuWxQ= -go.etcd.io/etcd/client/pkg/v3 v3.5.14/go.mod h1:8uMgAokyG1czCtIdsq+AGyYQMvpIKnSvPjFMunkgeZI= -go.etcd.io/etcd/client/v3 v3.5.14 h1:CWfRs4FDaDoSz81giL7zPpZH2Z35tbOrAJkkjMqOupg= -go.etcd.io/etcd/client/v3 v3.5.14/go.mod h1:k3XfdV/VIHy/97rqWjoUzrj9tk7GgJGH9J8L4dNXmAk= +go.etcd.io/etcd/api/v3 v3.5.15 h1:3KpLJir1ZEBrYuV2v+Twaa/e2MdDCEZ/70H+lzEiwsk= +go.etcd.io/etcd/api/v3 v3.5.15/go.mod h1:N9EhGzXq58WuMllgH9ZvnEr7SI9pS0k0+DHZezGp7jM= +go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5fWlA= +go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU= +go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4= +go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU= go.mongodb.org/mongo-driver v1.16.0 h1:tpRsfBJMROVHKpdGyc1BBEzzjDUWjItxbVSZ8Ls4BQ4= go.mongodb.org/mongo-driver v1.16.0/go.mod h1:oB6AhJQvFQL4LEHyXi6aJzQJtBiTQHiAd83l0GdFaiw= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -297,10 +299,10 @@ golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= -golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY= -golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -325,8 +327,8 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -355,8 +357,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= @@ -366,7 +368,7 @@ golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= +golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -393,12 +395,12 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/api v0.0.0-20240624140628-dc46fd24d27d h1:Aqf0fiIdUQEj0Gn9mKFFXoQfTTEaNopWpfVyYADxiSg= -google.golang.org/genproto/googleapis/api v0.0.0-20240624140628-dc46fd24d27d/go.mod h1:Od4k8V1LQSizPRUK4OzZ7TBE/20k+jPczUDAEyvn69Y= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d h1:k3zyW3BYYR30e8v3x0bTDdE9vpYFjZHK+HcyqkrppWk= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= -google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= -google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= +google.golang.org/genproto/googleapis/api v0.0.0-20240730163845-b1a4ccb954bf h1:GillM0Ef0pkZPIB+5iO6SDK+4T9pf6TpaYR6ICD5rVE= +google.golang.org/genproto/googleapis/api v0.0.0-20240730163845-b1a4ccb954bf/go.mod h1:OFMYQFHJ4TM3JRlWDZhJbZfra2uqc3WLBZiaaqP4DtU= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240730163845-b1a4ccb954bf h1:liao9UHurZLtiEwBgT9LMOnKYsHze6eA6w1KQCMVN2Q= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240730163845-b1a4ccb954bf/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= +google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc= +google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y= @@ -419,5 +421,5 @@ gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= -gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= -gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.11 h1:/Wfyg1B/je1hnDx3sMkX+gAlxrlZpn6X0BXRlwXlvHg= +gorm.io/gorm v1.25.11/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= diff --git a/internal/biz/aksk.go b/internal/biz/aksk.go index 60afdd7..7d81263 100644 --- a/internal/biz/aksk.go +++ b/internal/biz/aksk.go @@ -5,9 +5,9 @@ import ( "strings" "time" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/pkg" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" api "github.com/begonia-org/go-sdk/api/app/v1" common "github.com/begonia-org/go-sdk/common/api/v1" @@ -30,7 +30,7 @@ func NewAccessKeyAuth(app AppRepo, config *config.Config, log logger.Logger) *Ac } func IfNeedValidate(ctx context.Context, fullMethod string) bool { - routersList := routers.Get() + routersList := gateway.GetRouter() router := routersList.GetRouteByGrpcMethod(strings.ToUpper(fullMethod)) if router == nil { return false @@ -124,4 +124,4 @@ func (a *AccessKeyAuth) GetAppOwner(ctx context.Context, accessKey string) (stri return "", gosdk.NewError(err, int32(api.APPSvrCode_APP_UNKNOWN), codes.Unauthenticated, "app_owner") } return app.Owner, nil -} \ No newline at end of file +} diff --git a/internal/biz/aksk_test.go b/internal/biz/aksk_test.go index ef076b3..06aef13 100644 --- a/internal/biz/aksk_test.go +++ b/internal/biz/aksk_test.go @@ -16,7 +16,6 @@ import ( "github.com/begonia-org/begonia/internal/data" "github.com/begonia-org/begonia/internal/pkg" cfg "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" "github.com/begonia-org/begonia/internal/pkg/utils" gosdk "github.com/begonia-org/go-sdk" @@ -30,7 +29,7 @@ import ( var akskAccess = "" var akskSecret = "" var akskAppid = "" -var akskOwner="" +var akskOwner = "" func newGatewayRequest() (*gosdk.GatewayRequest, error) { signer := gosdk.NewAppAuthSigner(akskAccess, akskSecret) @@ -87,8 +86,7 @@ func testGetSecret(t *testing.T) { Description: "test", CreatedAt: timestamppb.New(time.Now()), UpdatedAt: timestamppb.New(time.Now()), - Owner: akskOwner, - + Owner: akskOwner, }) if err != nil { @@ -139,7 +137,7 @@ func testIfNeedValidate(t *testing.T) { ok := biz.IfNeedValidate(context.TODO(), akskAccess) c.So(ok, c.ShouldBeFalse) - patch := gomonkey.ApplyFuncReturn((*routers.HttpURIRouteToSrvMethod).GetRouteByGrpcMethod, &routers.APIMethodDetails{AuthRequired: true}) + patch := gomonkey.ApplyFuncReturn((*gateway.HttpURIRouteToSrvMethod).GetRouteByGrpcMethod, &gateway.APIMethodDetails{AuthRequired: true}) defer patch.Reset() ok = biz.IfNeedValidate(context.TODO(), akskAccess) c.So(ok, c.ShouldBeTrue) diff --git a/internal/biz/app.go b/internal/biz/app.go index 42a9f51..3440dae 100644 --- a/internal/biz/app.go +++ b/internal/biz/app.go @@ -63,12 +63,12 @@ func (a *AppUsecase) CreateApp(ctx context.Context, in *api.AppsRequest, owner s appid := GenerateAppid(a.snowflake) accessKey, err := GenerateAppAccessKey() if err != nil { - return nil, gosdk.NewError(err, int32(api.APPSvrCode_APP_CREATE_ERR), codes.Internal, "generate_app_access_key") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "generate_app_access_key") } secret, err := GenerateAppSecret() if err != nil { - return nil, gosdk.NewError(err, int32(api.APPSvrCode_APP_CREATE_ERR), codes.Internal, "generate_app_secret_key") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "generate_app_secret_key") } app := a.newApp() app.AccessKey = accessKey diff --git a/internal/biz/authz_test.go b/internal/biz/authz_test.go index 6f1f298..4ab34bf 100644 --- a/internal/biz/authz_test.go +++ b/internal/biz/authz_test.go @@ -28,8 +28,8 @@ import ( "github.com/begonia-org/begonia/internal/pkg/crypto" "github.com/begonia-org/begonia/internal/pkg/utils" - v1 "github.com/begonia-org/go-sdk/api/user/v1" app "github.com/begonia-org/go-sdk/api/app/v1" + v1 "github.com/begonia-org/go-sdk/api/user/v1" "github.com/spark-lence/tiga" "google.golang.org/grpc/metadata" @@ -52,7 +52,7 @@ func newAuthzBiz() *biz.AuthzUsecase { crypto := crypto.NewUsersAuth(cnf) app := data.NewAppRepo(config, gateway.Log) - return biz.NewAuthzUsecase(repo, user,app, gateway.Log, crypto, cnf) + return biz.NewAuthzUsecase(repo, user, app, gateway.Log, crypto, cnf) } func testAuthSeed(t *testing.T) { @@ -310,7 +310,7 @@ func testDelToken(t *testing.T) { c.So(err, c.ShouldBeNil) }) } -func testGetIdentity(t *testing.T){ +func testGetIdentity(t *testing.T) { authzBiz := newAuthzBiz() c.Convey("test get identity", t, func() { env := "dev" @@ -319,39 +319,38 @@ func testGetIdentity(t *testing.T){ } config := config.ReadConfig(env) appRepo := data.NewAppRepo(config, gateway.Log) - patch:=gomonkey.ApplyMethodReturn(appRepo,"Get",&app.Apps{Owner: "12345567"},nil) + patch := gomonkey.ApplyMethodReturn(appRepo, "Get", &app.Apps{Owner: "12345567"}, nil) defer patch.Reset() - identity, err := authzBiz.GetIdentity(context.Background(), gosdk.AccessKeyType,"123456") + identity, err := authzBiz.GetIdentity(context.Background(), gosdk.AccessKeyType, "123456") patch.Reset() c.So(err, c.ShouldBeNil) c.So(identity, c.ShouldEqual, "12345567") - identity, err = authzBiz.GetIdentity(context.Background(), gosdk.UidType,"123456") + identity, err = authzBiz.GetIdentity(context.Background(), gosdk.UidType, "123456") c.So(err, c.ShouldBeNil) c.So(identity, c.ShouldEqual, "123456") - cnf:=cfg.NewConfig(config) - apikey:=cnf.GetAdminAPIKey() - identity, err = authzBiz.GetIdentity(context.Background(), gosdk.UidType,"123456") + cnf := cfg.NewConfig(config) + apikey := cnf.GetAdminAPIKey() + identity, err = authzBiz.GetIdentity(context.Background(), gosdk.UidType, "123456") c.So(err, c.ShouldBeNil) c.So(identity, c.ShouldEqual, "123456") - - identity, err = authzBiz.GetIdentity(context.Background(), gosdk.ApiKeyType,apikey) + identity, err = authzBiz.GetIdentity(context.Background(), gosdk.ApiKeyType, apikey) c.So(err, c.ShouldBeNil) c.So(identity, c.ShouldNotBeEmpty) - _, err = authzBiz.GetIdentity(context.Background(), gosdk.ApiKeyType,"123") + _, err = authzBiz.GetIdentity(context.Background(), gosdk.ApiKeyType, "123") c.So(err, c.ShouldNotBeNil) - c.So(err.Error(),c.ShouldContainSubstring,pkg.ErrAPIKeyNotMatch.Error()) - patch2:=gomonkey.ApplyFuncReturn(tiga.MySQLDao.First,fmt.Errorf("error")) + c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrAPIKeyNotMatch.Error()) + patch2 := gomonkey.ApplyFuncReturn(tiga.MySQLDao.First, fmt.Errorf("error")) defer patch2.Reset() - _, err = authzBiz.GetIdentity(context.Background(), gosdk.ApiKeyType,apikey) + _, err = authzBiz.GetIdentity(context.Background(), gosdk.ApiKeyType, apikey) c.So(err, c.ShouldNotBeNil) - c.So(err.Error(),c.ShouldContainSubstring,"error") + c.So(err.Error(), c.ShouldContainSubstring, "error") patch2.Reset() - _, err = authzBiz.GetIdentity(context.Background(), "gosdk.ApiKeyType",apikey) + _, err = authzBiz.GetIdentity(context.Background(), "gosdk.ApiKeyType", apikey) c.So(err, c.ShouldNotBeNil) - c.So(err.Error(),c.ShouldContainSubstring,pkg.ErrIdentityMissing.Error()) + c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrIdentityMissing.Error()) }) } @@ -369,7 +368,7 @@ func testPutBlackList(t *testing.T) { func TestAuthz(t *testing.T) { t.Run("test auth seed", testAuthSeed) t.Run("test login", testLogin) - t.Run("test get identity",testGetIdentity) + t.Run("test get identity", testGetIdentity) t.Run("test logout", testLogout) t.Run("test del token", testDelToken) t.Run("test put black list", testPutBlackList) diff --git a/internal/biz/biz.go b/internal/biz/biz.go index eb8c2f3..81b4497 100644 --- a/internal/biz/biz.go +++ b/internal/biz/biz.go @@ -13,4 +13,7 @@ var ProviderSet = wire.NewSet(NewAuthzUsecase, endpoint.NewEndpointUsecase, NewAppUsecase, endpoint.NewWatcher, - NewDataOperatorUsecase) + NewDataOperatorUsecase, + NewTenantUsecase, + NewBusinessUsecase, +) diff --git a/internal/biz/bussiness.go b/internal/biz/bussiness.go new file mode 100644 index 0000000..149d7d7 --- /dev/null +++ b/internal/biz/bussiness.go @@ -0,0 +1,88 @@ +package biz + +import ( + "context" + "fmt" + "strings" + + gosdk "github.com/begonia-org/go-sdk" + api "github.com/begonia-org/go-sdk/api/user/v1" + common "github.com/begonia-org/go-sdk/common/api/v1" + "github.com/spark-lence/tiga" + "google.golang.org/grpc/codes" +) + +type BusinessRepo interface { + Add(ctx context.Context, business *api.Business) error + Get(ctx context.Context, key string) (*api.Business, error) + Del(ctx context.Context, key string) error + List(ctx context.Context, tags []string, page, pageSize int32) ([]*api.Business, error) + Patch(ctx context.Context, model *api.Business) error +} + +type BusinessUsecase struct { + repo BusinessRepo + snowflake *tiga.Snowflake +} + +func NewBusinessUsecase(repo BusinessRepo) *BusinessUsecase { + snk, _ := tiga.NewSnowflake(1) + return &BusinessUsecase{repo: repo, snowflake: snk} +} + +func (u *BusinessUsecase) Add(ctx context.Context, in *api.PostBusinessRequest, createdBy string) (business *api.Business, err error) { + defer func() { + if err != nil { + // log.Println(err) + if strings.Contains(err.Error(), "Duplicate entry") { + err = gosdk.NewError(err, int32(common.Code_CONFLICT), codes.AlreadyExists, "commit_app") + } else { + err = gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "add_business") + + } + } + }() + // business.BusinessId = u.snowflake.GenerateIDString() + business = &api.Business{ + BusinessId: u.snowflake.GenerateIDString(), + BusinessName: in.BusinessName, + Description: in.Description, + Tags: in.Tags, + CreatedBy: createdBy, + } + + err = u.repo.Add(ctx, business) + return +} +func (u *BusinessUsecase) Get(ctx context.Context, key string) (*api.Business, error) { + business, err := u.repo.Get(ctx, key) + if err != nil { + return nil, gosdk.NewError(fmt.Errorf("get business fail:%w", err), int32(common.Code_INTERNAL_ERROR), codes.NotFound, "get_business") + } + return business, nil +} +func (u *BusinessUsecase) Del(ctx context.Context, key string) error { + err := u.repo.Del(ctx, key) + if err != nil { + return gosdk.NewError(fmt.Errorf("del business fail:%w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "del_business") + } + return nil +} +func (u *BusinessUsecase) List(ctx context.Context, tags []string, page, pageSize int32) ([]*api.Business, error) { + list, err := u.repo.List(ctx, tags, page, pageSize) + if err != nil { + return nil, gosdk.NewError(fmt.Errorf("list business fail:%w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "list_business") + + } + return list, nil +} +func (u *BusinessUsecase) Patch(ctx context.Context, model *api.Business) error { + err := u.repo.Patch(ctx, model) + if err != nil { + if strings.Contains(err.Error(), "Duplicate entry") { + return gosdk.NewError(err, int32(common.Code_CONFLICT), codes.AlreadyExists, "patch_business") + } + return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "patch_business") + } + return nil +} diff --git a/internal/biz/bussiness_test.go b/internal/biz/bussiness_test.go new file mode 100644 index 0000000..5b385a5 --- /dev/null +++ b/internal/biz/bussiness_test.go @@ -0,0 +1,159 @@ +package biz_test + +import ( + "context" + "fmt" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/begonia-org/begonia" + cfg "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/gateway" + "github.com/begonia-org/begonia/internal/biz" + "github.com/begonia-org/begonia/internal/data" + api "github.com/begonia-org/go-sdk/api/user/v1" + c "github.com/smartystreets/goconvey/convey" + "github.com/spark-lence/tiga" + "google.golang.org/protobuf/types/known/fieldmaskpb" +) + +var bid = "" +var bn = "" + +func testAddBusiness(t *testing.T) { + c.Convey("test add business", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + repo := data.NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + snk, _ := tiga.NewSnowflake(2) + + bs := biz.NewBusinessUsecase(repo) + bn = fmt.Sprintf("test-%s", bid) + in := &api.PostBusinessRequest{ + BusinessName: fmt.Sprintf("test-%s", snk.GenerateIDString()), + Description: "test", + Tags: []string{"test"}, + } + business, err := bs.Add(context.Background(), in, snk.GenerateIDString()) + c.So(err, c.ShouldBeNil) + bid = business.BusinessId + bn = business.BusinessName + _, err = bs.Add(context.Background(), in, snk.GenerateIDString()) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Duplicate entry") + in2 := &api.PostBusinessRequest{ + BusinessName: fmt.Sprintf("test-1-%s", bid), + } + + patch := gomonkey.ApplyMethodReturn(repo, "Add", fmt.Errorf("too long")) + defer patch.Reset() + _, err = bs.Add(context.Background(), in2, snk.GenerateIDString()) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "too long") + }) +} + +func testGetBusiness(t *testing.T) { + c.Convey("test get business", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + repo := data.NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + bs := biz.NewBusinessUsecase(repo) + business, err := bs.Get(context.Background(), bid) + c.So(err, c.ShouldBeNil) + c.So(business, c.ShouldNotBeNil) + c.So(business.BusinessName, c.ShouldEqual, bn) + }) +} +func testUpdateBusiness(t *testing.T) { + c.Convey("test update business", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + repo := data.NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + bs := biz.NewBusinessUsecase(repo) + business, err := bs.Get(context.Background(), bid) + c.So(err, c.ShouldBeNil) + c.So(business, c.ShouldNotBeNil) + business.Description = "test update" + business.UpdateMask = &fieldmaskpb.FieldMask{Paths: []string{"description"}} + err = bs.Patch(context.Background(), business) + c.So(err, c.ShouldBeNil) + business, err = bs.Get(context.Background(), bid) + c.So(err, c.ShouldBeNil) + c.So(business, c.ShouldNotBeNil) + c.So(business.Description, c.ShouldEqual, "test update") + + snk, _ := tiga.NewSnowflake(2) + in := &api.PostBusinessRequest{ + BusinessName: fmt.Sprintf("test-%s", snk.GenerateIDString()), + Description: "test", + Tags: []string{"test"}, + } + _, err = bs.Add(context.Background(), in, snk.GenerateIDString()) + c.So(err, c.ShouldBeNil) + bn2 := in.BusinessName + business.BusinessName = bn2 + business.UpdateMask = &fieldmaskpb.FieldMask{Paths: []string{"business_name"}} + err = bs.Patch(context.Background(), business) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Duplicate entry") + + business.BusinessName = bn + business.UpdateMask = &fieldmaskpb.FieldMask{Paths: []string{"business_name"}} + business.BusinessId = snk.GenerateIDString() + err = bs.Patch(context.Background(), business) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "not found") + + }) + +} +func testListBusiness(t *testing.T) { + c.Convey("test list business", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + repo := data.NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + bs := biz.NewBusinessUsecase(repo) + business, err := bs.List(context.Background(), []string{"test"}, 1, 10) + c.So(err, c.ShouldBeNil) + c.So(business, c.ShouldNotBeEmpty) + + patch := gomonkey.ApplyMethodReturn(repo, "List", nil, fmt.Errorf("list error")) + defer patch.Reset() + _, err = bs.List(context.Background(), []string{"test"}, 1, 10) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "list error") + + }) +} +func testDeleteBusiness(t *testing.T) { + c.Convey("test delete business", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + repo := data.NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + bs := biz.NewBusinessUsecase(repo) + err := bs.Del(context.Background(), bid) + c.So(err, c.ShouldBeNil) + _, err = bs.Get(context.Background(), bid) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "not found") + }) +} + +func TestBusinessBiz(t *testing.T) { + t.Run("test add business", testAddBusiness) + t.Run("test get business", testGetBusiness) + t.Run("test update business", testUpdateBusiness) + t.Run("test list business", testListBusiness) + t.Run("test delete business", testDeleteBusiness) +} diff --git a/internal/biz/curd.go b/internal/biz/curd.go index 4dc1d9c..384be26 100644 --- a/internal/biz/curd.go +++ b/internal/biz/curd.go @@ -7,16 +7,18 @@ import ( "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/known/fieldmaskpb" "google.golang.org/protobuf/types/known/timestamppb" + "gorm.io/gorm" ) type CURD interface { - Add(ctx context.Context, model Model, needEncrypt bool) error + Add(ctx context.Context, model Model, needEncrypt bool, tx *gorm.DB) error Get(ctx context.Context, model interface{}, needDecrypt bool, query string, args ...interface{}) error - Update(ctx context.Context, model Model, needEncrypt bool) error - Del(ctx context.Context, model interface{}, needEncrypt bool) error + Update(ctx context.Context, model Model, needEncrypt bool, tx *gorm.DB) error + Del(ctx context.Context, model interface{}, needEncrypt bool, tx *gorm.DB) error List(ctx context.Context, models interface{}, pagination *tiga.Pagination) error + BeginTx(ctx context.Context) *gorm.DB } type SourceType interface { // 获取数据源类型 diff --git a/internal/biz/data.go b/internal/biz/data.go index 04fe9b6..770225c 100644 --- a/internal/biz/data.go +++ b/internal/biz/data.go @@ -129,7 +129,7 @@ func (d *DataOperatorUsecase) loadUsersBlacklist(ctx context.Context) error { // 直接加载远程缓存到本地 // lastUpdate ttl= 3*time.Second { users, err := d.repo.GetAllForbiddenUsers(ctx) d.log.Infof(ctx, "load users:%d", len(users)) diff --git a/internal/biz/data_test.go b/internal/biz/data_test.go index b776cee..91c5958 100644 --- a/internal/biz/data_test.go +++ b/internal/biz/data_test.go @@ -116,7 +116,7 @@ func TestDo(t *testing.T) { _ = cache.Del(context.Background(), "begonia:user:black:lock") _ = cache.Del(context.Background(), "begonia:user:black:last_updated") opts := &gateway.GrpcServerOptions{ - Middlewares: make([]gateway.GrpcProxyMiddleware, 0), + Middlewares: make([]grpc.StreamClientInterceptor, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), @@ -165,15 +165,20 @@ func TestDo(t *testing.T) { } err = appBiz.Put(context.TODO(), app, u2.Uid) c.So(err, c.ShouldBeNil) - patch := gomonkey.ApplyFuncReturn((*cfg.Config).GetUserBlackListExpiration, 3) + patch := gomonkey.ApplyFuncReturn((*cfg.Config).GetUserBlackListExpiration, 6) defer patch.Reset() go dataOperator.Do(context.Background()) go dataOperator.Do(context.Background()) time.Sleep(5 * time.Second) prefix := cnf.GetUserBlackListPrefix() + t.Logf("get blacklist: %s", fmt.Sprintf("%s:%s", prefix, u1.Uid)) + // cache2:=data.NewLayered(config, gateway.Log) + // cache2.SetToLocal(context.TODO(), fmt.Sprintf("%s:%s", prefix, u1.Uid), []byte("1"), 3*time.Second) + val, err := cache.GetFromLocal(context.TODO(), fmt.Sprintf("%s:%s", prefix, u1.Uid)) c.So(err, c.ShouldBeNil) + t.Logf("blacklist value:%s", val) c.So(val, c.ShouldNotBeEmpty) appPrefix := cnf.GetAppPrefix() val, err = cache.GetFromLocal(context.TODO(), fmt.Sprintf("%s:access_key:%s", appPrefix, app.AccessKey)) diff --git a/internal/biz/endpoint/endpoint_test.go b/internal/biz/endpoint/endpoint_test.go index b669b22..549050f 100644 --- a/internal/biz/endpoint/endpoint_test.go +++ b/internal/biz/endpoint/endpoint_test.go @@ -21,7 +21,6 @@ import ( "github.com/begonia-org/begonia/internal/pkg" cfg "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gwRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" goloadbalancer "github.com/begonia-org/go-loadbalancer" @@ -374,7 +373,7 @@ func testWatcherUpdate(t *testing.T) { } val, _ := json.Marshal(value) opts := &gateway.GrpcServerOptions{ - Middlewares: make([]gateway.GrpcProxyMiddleware, 0), + Middlewares: make([]grpc.StreamClientInterceptor, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), @@ -385,12 +384,12 @@ func testWatcherUpdate(t *testing.T) { GrpcProxyAddr: "127.0.0.1:12148", } gateway.New(gwCnf, opts) - routers.NewHttpURIRouteToSrvMethod() + gateway.NewHttpURIRouteToSrvMethod() c.Convey("Test Watcher Update", t, func() { err = watcher.Handle(context.TODO(), mvccpb.PUT, cnf.GetServiceKey(epId), string(val)) c.So(err, c.ShouldBeNil) - r := routers.Get() + r := gateway.GetRouter() detail := r.GetRoute("/api/v1/example/{name}") c.So(detail, c.ShouldNotBeNil) @@ -433,6 +432,27 @@ func testWatcherUpdate(t *testing.T) { c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrUnknownLoadBalancer.Error()) + // SetHttpResponse err + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + conf := config.ReadConfig(env) + cnf := cfg.NewConfig(conf) + outDir := cnf.GetGatewayDescriptionOut() + _, filename, _, _ := runtime.Caller(0) + + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata", "helloworld.pb") + pb, err := os.ReadFile(pbFile) + c.So(err, c.ShouldBeNil) + pd, err := gateway.NewDescriptionFromBinary(pb, filepath.Join(outDir, "tmp-test")) + c.So(err, c.ShouldBeNil) + patch6 := gomonkey.ApplyMethodReturn(pd, "SetHttpResponse", fmt.Errorf("test SetHttpResponse error")) + defer patch6.Reset() + err = watcher.Handle(context.TODO(), mvccpb.PUT, cnf.GetServiceKey(epId), string(val)) + + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "test SetHttpResponse error") }) } func testWatcherDel(t *testing.T) { @@ -453,7 +473,7 @@ func testWatcherDel(t *testing.T) { c.Convey("Test Watcher Del", t, func() { err := watcher.Handle(context.TODO(), mvccpb.DELETE, cnf.GetServiceKey(epId), string(val)) c.So(err, c.ShouldBeNil) - r := routers.Get() + r := gateway.GetRouter() detail := r.GetRoute("/api/v1/example/{name}") c.So(detail, c.ShouldBeNil) }) diff --git a/internal/biz/endpoint/utils.go b/internal/biz/endpoint/utils.go index bcd5d4e..2de7aca 100644 --- a/internal/biz/endpoint/utils.go +++ b/internal/biz/endpoint/utils.go @@ -7,14 +7,13 @@ import ( "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" common "github.com/begonia-org/go-sdk/common/api/v1" "google.golang.org/grpc/codes" ) func deleteAll(ctx context.Context, pd gateway.ProtobufDescription) error { - routersList := routers.Get() + routersList := gateway.GetRouter() routersList.DeleteRouters(pd) gw := gateway.Get() gw.DeleteLoadBalance(pd) diff --git a/internal/biz/endpoint/watcher.go b/internal/biz/endpoint/watcher.go index 8149c29..b8365e5 100644 --- a/internal/biz/endpoint/watcher.go +++ b/internal/biz/endpoint/watcher.go @@ -3,17 +3,17 @@ package endpoint import ( "context" "fmt" + "log" "sync" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/pkg" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" "go.etcd.io/etcd/api/v3/mvccpb" "encoding/json" - "github.com/begonia-org/begonia/gateway" loadbalance "github.com/begonia-org/go-loadbalancer" api "github.com/begonia-org/go-sdk/api/endpoint/v1" common "github.com/begonia-org/go-sdk/common/api/v1" @@ -38,7 +38,7 @@ func (g *EndpointWatcher) Update(ctx context.Context, key string, value string) return nil } endpoint := &api.Endpoints{} - routersList := routers.NewHttpURIRouteToSrvMethod() + routersList := gateway.GetRouter() err := json.Unmarshal([]byte(value), endpoint) if err != nil { return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_endpoint") @@ -62,15 +62,24 @@ func (g *EndpointWatcher) Update(ctx context.Context, key string, value string) } // register routers // log.Print("register router") + err = pd.SetHttpResponse(common.E_HttpResponse) + if err != nil { + return gosdk.NewError(fmt.Errorf("set http response error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "set_http_response") + + } routersList.LoadAllRouters(pd) + // register service to gateway gw := gateway.Get() err = gw.RegisterService(ctx, pd, lb) + if err != nil { return gosdk.NewError(fmt.Errorf("register service error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "register_service") } + gw.RegisterServiceWithProxy(pd) // err = g.repo.PutTags(ctx, endpoint.Key, endpoint.Tags) + log.Printf("register service success") return nil } func (g *EndpointWatcher) Del(ctx context.Context, key string, value string) error { diff --git a/internal/biz/file/file.go b/internal/biz/file/file.go index 632dc9f..68eaeec 100644 --- a/internal/biz/file/file.go +++ b/internal/biz/file/file.go @@ -325,7 +325,7 @@ func (f *FileUsecaseImpl) Upload(ctx context.Context, in *api.UploadFileRequest, if updated { existsObj, err := f.repo.GetFile(ctx, fileObj.Engine, fileObj.Bucket, fileObj.Key) if err != nil { - return nil, gosdk.NewError(fmt.Errorf("get updated file error:%w",err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_file") + return nil, gosdk.NewError(fmt.Errorf("get updated file error:%w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_file") } if existsObj != nil { uid = existsObj.Uid @@ -541,13 +541,13 @@ func (f *FileUsecaseImpl) CompleteMultipartUploadFile(ctx context.Context, in *a // log.Printf("insert %s,%s,%s,%s", fileObj.Uid, fileObj.Bucket, fileObj.Key, fileObj.Engine) updated, err := f.repo.UpsertFile(ctx, fileObj) if err != nil { - return nil, gosdk.NewError(fmt.Errorf("insert or update file err:%w",err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "upsert_file") + return nil, gosdk.NewError(fmt.Errorf("insert or update file err:%w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "upsert_file") } uid := fileObj.Uid if updated { existsObj, err := f.repo.GetFile(ctx, fileObj.Engine, fileObj.Bucket, fileObj.Key) if err != nil { - return nil, gosdk.NewError(fmt.Errorf("get updated file error:%w",err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_file") + return nil, gosdk.NewError(fmt.Errorf("get updated file error:%w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_file") } if existsObj != nil { uid = existsObj.Uid diff --git a/internal/biz/file/file_test.go b/internal/biz/file/file_test.go index 032bc6e..eb75e0e 100644 --- a/internal/biz/file/file_test.go +++ b/internal/biz/file/file_test.go @@ -806,7 +806,7 @@ func testCompleteMultipartUploadFile(t *testing.T) { }) c.So(err, c.ShouldBeNil) - _=uploadParts(bigTmpFile2.path, rsp.UploadId, "test/upload.parts.test1", t) + _ = uploadParts(bigTmpFile2.path, rsp.UploadId, "test/upload.parts.test1", t) rsp2, err := fileBiz.CompleteMultipartUploadFile(context.TODO(), &api.CompleteMultipartUploadRequest{ Key: "test/upload.parts.test1", UploadId: rsp.UploadId, @@ -814,9 +814,9 @@ func testCompleteMultipartUploadFile(t *testing.T) { Bucket: bucket, Engine: api.FileEngine_FILE_ENGINE_LOCAL.String(), }, fileAuthor) - + c.So(err, c.ShouldBeNil) - c.So(rsp2.Uid, c.ShouldEqual,fid) + c.So(rsp2.Uid, c.ShouldEqual, fid) }) c.Convey("test complete parts file update fail", t, func() { bigTmpFile2, _ := generateRandomFile(1024 * 1024 * 12) @@ -827,8 +827,8 @@ func testCompleteMultipartUploadFile(t *testing.T) { }) c.So(err, c.ShouldBeNil) - _=uploadParts(bigTmpFile2.path, rsp.UploadId, "test/upload.parts.test1", t) - patch:=gomonkey.ApplyFuncReturn(tiga.MySQLDao.First,fmt.Errorf("remove error")) + _ = uploadParts(bigTmpFile2.path, rsp.UploadId, "test/upload.parts.test1", t) + patch := gomonkey.ApplyFuncReturn(tiga.MySQLDao.First, fmt.Errorf("remove error")) defer patch.Reset() _, err = fileBiz.CompleteMultipartUploadFile(context.TODO(), &api.CompleteMultipartUploadRequest{ Key: "test/upload.parts.test1", diff --git a/internal/biz/file/minio_test.go b/internal/biz/file/minio_test.go index 590b5f7..cadae4f 100644 --- a/internal/biz/file/minio_test.go +++ b/internal/biz/file/minio_test.go @@ -89,7 +89,7 @@ func testMinioUpload(t *testing.T) { Bucket: minioBucket, Key: "test.txt", Content: []byte("hello"), - Engine: api.FileEngine_FILE_ENGINE_MINIO.String(), + Engine: api.FileEngine_FILE_ENGINE_MINIO.String(), }, minioFileAuthor) c.So(err, c.ShouldBeNil) c.So(rsp, c.ShouldNotBeNil) @@ -106,7 +106,7 @@ func testMinioUpload(t *testing.T) { Bucket: minioBucket, Key: "test.txt", Content: []byte("hello"), - Engine: api.FileEngine_FILE_ENGINE_MINIO.String(), + Engine: api.FileEngine_FILE_ENGINE_MINIO.String(), }, minioFileAuthor) patch.Reset() c.So(err, c.ShouldNotBeNil) @@ -260,7 +260,7 @@ func testMinioInitPartsUpload(t *testing.T) { fileBiz := newFileMinioBiz() c.Convey("test init parts upload success", t, func() { rsp, err := fileBiz.InitiateUploadFile(context.TODO(), &api.InitiateMultipartUploadRequest{ - Key: "test-minio.txt", + Key: "test-minio.txt", Engine: api.FileEngine_FILE_ENGINE_MINIO.String(), }) c.So(err, c.ShouldBeNil) @@ -477,14 +477,14 @@ func testMinioDelete(t *testing.T) { func testMinioList(t *testing.T) { fileBiz := newFileMinioBiz() c.Convey("test list success", t, func() { - t.Logf("minio bucket:%s,engine:%s,author:%s", minioBucket,api.FileEngine_FILE_ENGINE_MINIO.String(),minioFileAuthor) + t.Logf("minio bucket:%s,engine:%s,author:%s", minioBucket, api.FileEngine_FILE_ENGINE_MINIO.String(), minioFileAuthor) rsp, err := fileBiz.List(context.Background(), &api.ListFilesRequest{Bucket: minioBucket, Page: 1, PageSize: 20, Engine: api.FileEngine_FILE_ENGINE_MINIO.String()}, minioFileAuthor) c.So(err, c.ShouldBeNil) c.So(rsp, c.ShouldNotBeNil) c.So(len(rsp), c.ShouldBeGreaterThanOrEqualTo, 1) }) c.Convey("test list fail", t, func() { - _, err := fileBiz.List(context.Background(), &api.ListFilesRequest{Bucket: minioBucket, Page: -1, PageSize: -1,Engine: api.FileEngine_FILE_ENGINE_MINIO.String()}, minioFileAuthor) + _, err := fileBiz.List(context.Background(), &api.ListFilesRequest{Bucket: minioBucket, Page: -1, PageSize: -1, Engine: api.FileEngine_FILE_ENGINE_MINIO.String()}, minioFileAuthor) c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "SQL syntax") }) diff --git a/internal/biz/tenant.go b/internal/biz/tenant.go new file mode 100644 index 0000000..a71c47f --- /dev/null +++ b/internal/biz/tenant.go @@ -0,0 +1,152 @@ +package biz + +import ( + "context" + "fmt" + "strings" + + "github.com/begonia-org/begonia/internal/pkg/config" + gosdk "github.com/begonia-org/go-sdk" + api "github.com/begonia-org/go-sdk/api/user/v1" + common "github.com/begonia-org/go-sdk/common/api/v1" + "github.com/spark-lence/tiga" + "google.golang.org/grpc/codes" +) + +type TenantRepo interface { + Add(ctx context.Context, tenant *api.Tenants) error + Get(ctx context.Context, key string) (*api.Tenants, error) + Del(ctx context.Context, uidOrName string) error + List(ctx context.Context, tags []string, status []api.TENANTS_STATUS, page, pageSize int32) ([]*api.Tenants, error) + Patch(ctx context.Context, model *api.Tenants) error + AddBusiness(ctx context.Context, tenantBusiness *api.TenantsBusiness) error + DelTenantBusiness(ctx context.Context, tenantId, businessId string) error + GetTenantBusiness(ctx context.Context, tenant, business string) (*api.TenantsBusiness, error) + TenantBusinessList(ctx context.Context, tenantId string, page, pageSize int32) ([]*api.TenantsBusiness, error) +} + +type TenantUsecase struct { + repo TenantRepo + business *BusinessUsecase + cfg *config.Config + snowflake *tiga.Snowflake +} + +func NewTenantUsecase(repo TenantRepo, business *BusinessUsecase, cfg *config.Config) *TenantUsecase { + snk, _ := tiga.NewSnowflake(1) + return &TenantUsecase{repo: repo, business: business, cfg: cfg, snowflake: snk} +} + +func (u *TenantUsecase) Get(ctx context.Context, uidOrName string) (*api.Tenants, error) { + user, err := u.repo.Get(ctx, uidOrName) + if err != nil || user == nil || user.TenantId == "" { + return nil, gosdk.NewError(fmt.Errorf("get tenant fail:%w or not found tenant", err), int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "get_tenant") + } + return user, nil +} +func (u *TenantUsecase) Add(ctx context.Context, in *api.PostTenantRequest, createdBy string) (*api.Tenants, error) { + tenants := &api.Tenants{ + TenantName: in.TenantName, + Email: in.Email, + Description: in.Description, + Tags: in.Tags, + TenantId: u.snowflake.GenerateIDString(), + CreatedBy: createdBy, + } + err := u.repo.Add(ctx, tenants) + if err != nil { + if strings.Contains(err.Error(), "Duplicate entry") { + return nil, gosdk.NewError(err, int32(api.UserSvrCode_USER_USERNAME_DUPLICATE_ERR), codes.AlreadyExists, "add_tenant") + } + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "add_tenant") + } + return tenants, nil +} + +func (u *TenantUsecase) Update(ctx context.Context, in *api.PatchTenantRequest) (*api.Tenants, error) { + tenant := &api.Tenants{ + TenantId: in.TenantId, + TenantName: in.TenantName, + Email: in.Email, + Description: in.Description, + Tags: in.Tags, + Status: in.Status, + AdminId: in.AdminId, + UpdateMask: in.UpdateMask, + } + + err := u.repo.Patch(ctx, tenant) + if err != nil { + if strings.Contains(err.Error(), "Duplicate entry") { + return nil, gosdk.NewError(fmt.Errorf("Update tenant error:%w", err), int32(common.Code_CONFLICT), codes.AlreadyExists, "patch_app") + } + if strings.Contains(err.Error(), "not found") { + return nil, gosdk.NewError(fmt.Errorf("Update tenant error:%w", err), int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "get_user") + } + + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_user") + } + return u.Get(ctx, in.TenantId) +} +func (u *TenantUsecase) Delete(ctx context.Context, uidOrName string) error { + err := u.repo.Del(ctx, uidOrName) + if err != nil { + return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_user") + } + return nil +} +func (u *TenantUsecase) List(ctx context.Context, tags []string, status []api.TENANTS_STATUS, page, pageSize int32) ([]*api.Tenants, error) { + tenants, err := u.repo.List(ctx, tags, status, page, pageSize) + if err != nil { + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_tenants") + + } + return tenants, nil +} +func (t *TenantUsecase) AddTenantBusiness(ctx context.Context, tenantId, businessId string, plan, createdBy string) (*api.TenantsBusiness, error) { + tenant, err := t.Get(ctx, tenantId) + if err != nil || tenant == nil || tenant.TenantId == "" { + return nil, gosdk.NewError(fmt.Errorf("get tenant fail:%w or tenant not found", err), int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.InvalidArgument, "get_tenant") + } + business, err := t.business.Get(ctx, businessId) + if err != nil || business == nil || business.BusinessId == "" { + return nil, gosdk.NewError(fmt.Errorf("get business fail:%w or business not found", err), int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.InvalidArgument, "get_business") + } + tb := &api.TenantsBusiness{ + TenantId: tenantId, + BusinessId: businessId, + Plan: plan, + TenantName: tenant.TenantName, + BusinessName: business.BusinessName, + CreatedBy: createdBy, + } + err = t.repo.AddBusiness(ctx, tb) + if err != nil { + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "add_tenants_business") + + } + return tb, nil +} +func (t *TenantUsecase) GetTenantBusiness(ctx context.Context, tenant, business string) (*api.TenantsBusiness, error) { + tb, err := t.repo.GetTenantBusiness(ctx, tenant, business) + if err != nil || tb == nil || tb.TenantId == "" { + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_tenants_business") + } + return tb, nil +} +func (t *TenantUsecase) DelTenantBusiness(ctx context.Context, tenantId, businessId string) error { + err := t.repo.DelTenantBusiness(ctx, tenantId, businessId) + if err != nil { + return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "del_tenants_business") + + } + return nil +} + +func (t *TenantUsecase) ListTenantBusiness(ctx context.Context, tenantId string, page, pageSize int32) ([]*api.TenantsBusiness, error) { + tbs, err := t.repo.TenantBusinessList(ctx, tenantId, page, pageSize) + if err != nil { + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_tenants_business") + } + return tbs, nil +} diff --git a/internal/biz/tenant_test.go b/internal/biz/tenant_test.go new file mode 100644 index 0000000..0a28f73 --- /dev/null +++ b/internal/biz/tenant_test.go @@ -0,0 +1,333 @@ +package biz_test + +import ( + "context" + "fmt" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/begonia-org/begonia" + cfg "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/gateway" + "github.com/begonia-org/begonia/internal/biz" + "github.com/begonia-org/begonia/internal/data" + "github.com/begonia-org/begonia/internal/pkg/config" + api "github.com/begonia-org/go-sdk/api/user/v1" + c "github.com/smartystreets/goconvey/convey" + "github.com/spark-lence/tiga" + "google.golang.org/protobuf/types/known/fieldmaskpb" +) + +var tid = "" +var tn = "" +var tid2 = "" +var tn2 = "" +var tenantBusinessId = "" + +func testAddTenant(t *testing.T) { + c.Convey("test tenant add", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + repo := data.NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + bRepo := data.NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + snk, _ := tiga.NewSnowflake(3) + + bs := biz.NewBusinessUsecase(bRepo) + + tbiz := biz.NewTenantUsecase(repo, bs, config.NewConfig(cfg.ReadConfig(env))) + in := &api.PostTenantRequest{ + TenantName: fmt.Sprintf("test-%s", snk.GenerateIDString()), + Description: "test tenant", + Tags: []string{"test"}, + Email: fmt.Sprintf("%s@example.com", snk.GenerateIDString()), + } + uid := snk.GenerateIDString() + tenant, err := tbiz.Add(context.Background(), in, uid) + c.So(err, c.ShouldBeNil) + c.So(tenant.TenantId, c.ShouldNotBeEmpty) + tid = tenant.TenantId + tn = tenant.TenantName + _, err = tbiz.Add(context.Background(), in, uid) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Duplicate entry") + patch := gomonkey.ApplyMethodReturn(repo, "Add", fmt.Errorf("too long")) + defer patch.Reset() + _, err = tbiz.Add(context.Background(), in, uid) + patch.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "too long") + + in2 := &api.PostTenantRequest{ + TenantName: fmt.Sprintf("test-%s", snk.GenerateIDString()), + Description: "test tenant2", + Tags: []string{"test"}, + Email: fmt.Sprintf("%s@example.com", snk.GenerateIDString()), + } + tenant, err = tbiz.Add(context.Background(), in2, uid) + c.So(err, c.ShouldBeNil) + tid2 = tenant.TenantId + tn2 = tenant.TenantName + + }) +} +func testGetTenant(t *testing.T) { + c.Convey("test tenant get", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + repo := data.NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + bRepo := data.NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + // snk, _ := tiga.NewSnowflake(3) + + bs := biz.NewBusinessUsecase(bRepo) + + tbiz := biz.NewTenantUsecase(repo, bs, config.NewConfig(cfg.ReadConfig(env))) + tenant, err := tbiz.Get(context.Background(), tid) + c.So(err, c.ShouldBeNil) + c.So(tenant.TenantId, c.ShouldEqual, tid) + _, err = tbiz.Get(context.Background(), "not-exist") + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "not found") + }) +} +func testPatchTenant(t *testing.T) { + c.Convey("test tenant patch", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + repo := data.NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + bRepo := data.NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + snk, _ := tiga.NewSnowflake(3) + + bs := biz.NewBusinessUsecase(bRepo) + + tbiz := biz.NewTenantUsecase(repo, bs, config.NewConfig(cfg.ReadConfig(env))) + in := &api.PatchTenantRequest{ + TenantName: fmt.Sprintf("test-%s", snk.GenerateIDString()), + Description: "test tenant patch", + Tags: []string{"test"}, + Email: fmt.Sprintf("%s@example.com", snk.GenerateIDString()), + TenantId: tid, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"description", "tags", "email"}}, + } + _, err := tbiz.Update(context.Background(), in) + c.So(err, c.ShouldBeNil) + tenant, err := tbiz.Get(context.Background(), tid) + c.So(err, c.ShouldBeNil) + c.So(tenant.TenantName, c.ShouldEqual, tn) + + in = &api.PatchTenantRequest{ + TenantName: fmt.Sprintf("test-%s", snk.GenerateIDString()), + Description: "test tenant patch", + Tags: []string{"test"}, + Email: fmt.Sprintf("%s@example.com", snk.GenerateIDString()), + TenantId: snk.GenerateIDString(), + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"description", "tags", "email"}}, + } + _, err = tbiz.Update(context.Background(), in) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "not found") + + patch := gomonkey.ApplyMethodReturn(repo, "Patch", fmt.Errorf("too long")) + defer patch.Reset() + in.TenantId = tid + _, err = tbiz.Update(context.Background(), in) + patch.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "too long") + in.TenantName = tn + in.TenantId = tid2 + in.UpdateMask.Paths = []string{"tenant_name"} + t.Logf("tid:%s,tid2:%s,tn:%s,tn2:%s", tid, tid2, tn, tn2) + _, err = tbiz.Update(context.Background(), in) + c.So(err, c.ShouldNotBeNil) + t.Logf("update tenant err:%s", err.Error()) + c.So(err.Error(), c.ShouldContainSubstring, "Duplicate entry") + + }, + ) +} +func testListTenant(t *testing.T) { + c.Convey("test tenant list", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + repo := data.NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + bRepo := data.NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + // snk, _ := tiga.NewSnowflake(3) + + bs := biz.NewBusinessUsecase(bRepo) + + tbiz := biz.NewTenantUsecase(repo, bs, config.NewConfig(cfg.ReadConfig(env))) + tenants, err := tbiz.List(context.Background(), []string{"test"}, []api.TENANTS_STATUS{api.TENANTS_STATUS_TENANTS_ACTIVE}, 1, 10) + c.So(err, c.ShouldBeNil) + c.So(tenants, c.ShouldNotBeEmpty) + c.So(len(tenants), c.ShouldBeGreaterThan, 0) + + patch := gomonkey.ApplyMethodReturn(repo, "List", nil, fmt.Errorf("list error")) + defer patch.Reset() + _, err = tbiz.List(context.Background(), []string{"test"}, []api.TENANTS_STATUS{api.TENANTS_STATUS_TENANTS_ACTIVE}, 1, 10) + patch.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "list error") + + }) +} +func testDelTenant(t *testing.T) { + c.Convey("test tenant delete", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + repo := data.NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + bRepo := data.NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + // snk, _ := tiga.NewSnowflake(3) + + bs := biz.NewBusinessUsecase(bRepo) + + tbiz := biz.NewTenantUsecase(repo, bs, config.NewConfig(cfg.ReadConfig(env))) + err := tbiz.Delete(context.Background(), tid) + c.So(err, c.ShouldBeNil) + _, err = tbiz.Get(context.Background(), tid) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "not found") + patch := gomonkey.ApplyMethodReturn(repo, "Del", fmt.Errorf("del error")) + defer patch.Reset() + err = tbiz.Delete(context.Background(), tid) + // patch.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "del error") + + }) + +} + +func testAddTenantBusiness(t *testing.T) { + c.Convey("test add tenant business", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + repo := data.NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + bRepo := data.NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + snk, _ := tiga.NewSnowflake(2) + bs := biz.NewBusinessUsecase(bRepo) + tr := biz.NewTenantUsecase(repo, bs, config.NewConfig(cfg.ReadConfig(env))) + in := &api.PostBusinessRequest{ + BusinessName: fmt.Sprintf("test-data-%s", snk.GenerateIDString()), + Description: "test business", + } + business, err := bs.Add(context.Background(), in, snk.GenerateIDString()) + c.So(err, c.ShouldBeNil) + tb, err := tr.AddTenantBusiness(context.Background(), tid, business.BusinessId, "FREE", uid) + c.So(err, c.ShouldBeNil) + tenantBusinessId = business.BusinessId + c.So(tb.TenantId, c.ShouldEqual, tid) + c.So(tb.BusinessId, c.ShouldEqual, business.BusinessId) + + _, err = tr.AddTenantBusiness(context.Background(), tid, snk.GenerateIDString(), "FREE", uid) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "not found business") + + _, err = tr.AddTenantBusiness(context.Background(), snk.GenerateIDString(), business.BusinessId, "FREE", uid) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "not found tenant") + + patch := gomonkey.ApplyMethodReturn(repo, "AddBusiness", fmt.Errorf("add error")) + defer patch.Reset() + _, err = tr.AddTenantBusiness(context.Background(), tid, business.BusinessId, "FREE", uid) + patch.Reset() + c.So(err, c.ShouldNotBeNil) + + }) +} +func testGetTenantBusiness(t *testing.T) { + c.Convey("test get tenant business", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + repo := data.NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + bRepo := data.NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + snk, _ := tiga.NewSnowflake(2) + bs := biz.NewBusinessUsecase(bRepo) + tr := biz.NewTenantUsecase(repo, bs, config.NewConfig(cfg.ReadConfig(env))) + + tbs, err := tr.GetTenantBusiness(context.Background(), tid, tenantBusinessId) + c.So(err, c.ShouldBeNil) + c.So(tbs, c.ShouldNotBeEmpty) + c.So(tbs.BusinessId, c.ShouldEqual, tenantBusinessId) + + _, err = tr.GetTenantBusiness(context.Background(), snk.GenerateIDString(), tenantBusinessId) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "not found") + }) +} +func testListTenantBusiness(t *testing.T) { + c.Convey("test list tenant business", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + repo := data.NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + bRepo := data.NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + snk, _ := tiga.NewSnowflake(2) + bs := biz.NewBusinessUsecase(bRepo) + tr := biz.NewTenantUsecase(repo, bs, config.NewConfig(cfg.ReadConfig(env))) + tbs, err := tr.ListTenantBusiness(context.Background(), tid, 1, 10) + c.So(err, c.ShouldBeNil) + c.So(tbs, c.ShouldNotBeEmpty) + c.So(len(tbs), c.ShouldBeGreaterThan, 0) + + patch := gomonkey.ApplyMethodReturn(repo, "TenantBusinessList", nil, fmt.Errorf("list error")) + defer patch.Reset() + _, err = tr.ListTenantBusiness(context.Background(), snk.GenerateIDString(), 1, 10) + patch.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "list error") + }) +} +func testDelTenantBusiness(t *testing.T) { + c.Convey("test del tenant business", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + repo := data.NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + bRepo := data.NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + // snk, _ := tiga.NewSnowflake(2) + bs := biz.NewBusinessUsecase(bRepo) + tr := biz.NewTenantUsecase(repo, bs, config.NewConfig(cfg.ReadConfig(env))) + + err := tr.DelTenantBusiness(context.Background(), tid, tenantBusinessId) + c.So(err, c.ShouldBeNil) + _, err = tr.GetTenantBusiness(context.Background(), tid, tenantBusinessId) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "not found") + + patch := gomonkey.ApplyMethodReturn(repo, "DelTenantBusiness", fmt.Errorf("del error")) + defer patch.Reset() + err = tr.DelTenantBusiness(context.Background(), tid, tenantBusinessId) + patch.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "del error") + }) +} + +func TestTenant(t *testing.T) { + t.Run("test add tenant", testAddTenant) + t.Run("test get tenant", testGetTenant) + t.Run("test patch tenant", testPatchTenant) + t.Run("test list tenant", testListTenant) + t.Run("test add tenant business", testAddTenantBusiness) + t.Run("test get tenant business", testGetTenantBusiness) + t.Run("test list tenant business", testListTenantBusiness) + t.Run("test del tenant business", testDelTenantBusiness) + t.Run("test del tenant", testDelTenant) + +} diff --git a/internal/biz/user.go b/internal/biz/user.go index ccf06b0..e488e62 100644 --- a/internal/biz/user.go +++ b/internal/biz/user.go @@ -50,6 +50,7 @@ func (u *UserUsecase) Add(ctx context.Context, users *api.Users) (err error) { users.Uid = u.snowflake.GenerateIDString() err = u.repo.Add(ctx, users) + return } func (u *UserUsecase) Get(ctx context.Context, key string) (*api.Users, error) { @@ -62,12 +63,13 @@ func (u *UserUsecase) Get(ctx context.Context, key string) (*api.Users, error) { func (u *UserUsecase) Update(ctx context.Context, model *api.Users) error { err := u.repo.Patch(ctx, model) if err != nil { - if strings.Contains(err.Error(), "not found") { - return gosdk.NewError(err, int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "get_user") - } if strings.Contains(err.Error(), "Duplicate entry") { return gosdk.NewError(err, int32(api.UserSvrCode_USER_USERNAME_DUPLICATE_ERR), codes.AlreadyExists, "patch_app") } + if strings.Contains(err.Error(), "not found") { + return gosdk.NewError(err, int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "get_user") + } + return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_user") } return nil diff --git a/internal/data/app.go b/internal/data/app.go index d1c2ad3..41c3d11 100644 --- a/internal/data/app.go +++ b/internal/data/app.go @@ -23,7 +23,7 @@ func NewAppRepoImpl(curd biz.CURD, local *LayeredCache, cfg *config.Config) biz. func (r *appRepoImpl) Add(ctx context.Context, apps *api.Apps) error { - if err := r.curd.Add(ctx, apps, false); err != nil { + if err := r.curd.Add(ctx, apps, false, nil); err != nil { return fmt.Errorf("add app failed: %w", err) } key := r.cfg.GetAPPAccessKey(apps.AccessKey) @@ -57,11 +57,11 @@ func (r *appRepoImpl) Del(ctx context.Context, key string) error { return err } _ = r.local.Del(ctx, r.cfg.GetAPPAccessKey(app.AccessKey)) - return r.curd.Del(ctx, app, false) + return r.curd.Del(ctx, app, false, nil) } func (r *appRepoImpl) Patch(ctx context.Context, model *api.Apps) error { - return r.curd.Update(ctx, model, false) + return r.curd.Update(ctx, model, false, nil) } func (r *appRepoImpl) List(ctx context.Context, tags []string, status []api.APPStatus, page, pageSize int32) ([]*api.Apps, error) { apps := make([]*api.Apps, 0) @@ -100,7 +100,7 @@ func (a *appRepoImpl) GetSecret(ctx context.Context, accessKey string) (string, cacheKey := a.cfg.GetAPPAccessKey(accessKey) secretBytes, err := a.local.Get(ctx, cacheKey) secret := string(secretBytes) - if err != nil { + if err != nil || secret == "" { apps, err := a.Get(ctx, accessKey) if err != nil || apps.Secret == "" { return "", fmt.Errorf("get app secret failed: %w", err) diff --git a/internal/data/app_test.go b/internal/data/app_test.go index cbc9c9e..9ba296a 100644 --- a/internal/data/app_test.go +++ b/internal/data/app_test.go @@ -188,9 +188,9 @@ func patchTest(t *testing.T) { err = repo.Patch(context.Background(), updated) c.So(err, c.ShouldNotBeNil) - c.So(err.Error(), c.ShouldContainSubstring, "appid can not be updated") + c.So(err.Error(), c.ShouldContainSubstring, "can not be updated") - patch := gomonkey.ApplyFuncReturn(getPrimaryColumnValue, "", nil, fmt.Errorf("getPrimaryColumnValue error")) + patch := gomonkey.ApplyFuncReturn(getPrimaryColumnValue, nil, fmt.Errorf("getPrimaryColumnValue error")) defer patch.Reset() err = repo.Patch(context.Background(), updated) c.So(err, c.ShouldNotBeNil) @@ -316,10 +316,16 @@ func delTest(t *testing.T) { env = begonia.Env } repo := NewAppRepo(cfg.ReadConfig(env), gateway.Log) - - patch := gomonkey.ApplyFuncReturn(getPrimaryColumnValue, "", nil, fmt.Errorf("getPrimaryColumnValue,error")) - defer patch.Reset() + // set boolean err + patch4 := gomonkey.ApplyFuncReturn((*curdImpl).SetBoolean, fmt.Errorf("set boolean error")) + defer patch4.Reset() err := repo.Del(context.TODO(), appid) + patch4.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "set boolean error") + patch := gomonkey.ApplyFuncReturn(getPrimaryColumnValue, nil, fmt.Errorf("getPrimaryColumnValue,error")) + defer patch.Reset() + err = repo.Del(context.TODO(), appid) c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "getPrimaryColumnValue,error") patch.Reset() diff --git a/internal/data/business.go b/internal/data/business.go new file mode 100644 index 0000000..49634ba --- /dev/null +++ b/internal/data/business.go @@ -0,0 +1,63 @@ +package data + +import ( + "context" + "fmt" + + "github.com/begonia-org/begonia/internal/biz" + "github.com/begonia-org/begonia/internal/pkg/config" + api "github.com/begonia-org/go-sdk/api/user/v1" + "github.com/spark-lence/tiga" +) + +type businessRepoImpl struct { + data *Data + cfg *config.Config + curd biz.CURD +} + +func NewBusinessRepoImpl(data *Data, curd biz.CURD, cfg *config.Config) biz.BusinessRepo { + return &businessRepoImpl{data: data, cfg: cfg, curd: curd} +} + +func (b *businessRepoImpl) Add(ctx context.Context, business *api.Business) error { + return b.curd.Add(ctx, business, false, nil) +} +func (b *businessRepoImpl) Get(ctx context.Context, key string) (*api.Business, error) { + business := &api.Business{} + err := b.curd.Get(ctx, business, false, "business_id=? or business_name=?", key, key) + if err != nil || business.BusinessId == "" { + return nil, fmt.Errorf("get business failed: %w or not found business", err) + } + return business, nil +} +func (b *businessRepoImpl) Del(ctx context.Context, key string) error { + business, err := b.Get(ctx, key) + if err != nil { + return err + } + return b.curd.Del(ctx, business, false, nil) +} +func (b *businessRepoImpl) List(ctx context.Context, tags []string, page, pageSize int32) ([]*api.Business, error) { + businesses := make([]*api.Business, 0) + query := "" + conds := make([]interface{}, 0) + if len(tags) > 0 { + query = "json_contains(json_array(?),tags)" + conds = append(conds, tags) + } + pagination := &tiga.Pagination{ + Page: page, + PageSize: pageSize, + Query: query, + Args: conds, + } + err := b.curd.List(ctx, &businesses, pagination) + if err != nil { + return nil, err + } + return businesses, nil +} +func (b *businessRepoImpl) Patch(ctx context.Context, model *api.Business) error { + return b.curd.Update(ctx, model, false, nil) +} diff --git a/internal/data/business_test.go b/internal/data/business_test.go new file mode 100644 index 0000000..496ebd0 --- /dev/null +++ b/internal/data/business_test.go @@ -0,0 +1,117 @@ +package data + +import ( + "context" + "fmt" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/begonia-org/begonia" + cfg "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/gateway" + api "github.com/begonia-org/go-sdk/api/user/v1" + c "github.com/smartystreets/goconvey/convey" + "github.com/spark-lence/tiga" + "google.golang.org/protobuf/types/known/fieldmaskpb" +) + +var bid = "" +var bn = "" + +func testAddBusiness(t *testing.T) { + c.Convey("test add business", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + bs := NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + snk, _ := tiga.NewSnowflake(2) + bid = snk.GenerateIDString() + bn = fmt.Sprintf("test-%s", bid) + business := &api.Business{ + BusinessId: bid, + BusinessName: fmt.Sprintf("test-%s", bid), + Description: "test", + Tags: []string{"test"}, + } + err := bs.Add(context.Background(), business) + c.So(err, c.ShouldBeNil) + }) +} +func testUpdateBusiness(t *testing.T) { + c.Convey("test update business", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + bs := NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + business, err := bs.Get(context.Background(), bid) + c.So(err, c.ShouldBeNil) + business.Description = "update description" + business.UpdateMask = &fieldmaskpb.FieldMask{Paths: []string{"description"}} + err = bs.Patch(context.Background(), business) + c.So(err, c.ShouldBeNil) + }) +} +func testGetBusiness(t *testing.T) { + c.Convey("test get business", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + bs := NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + business, err := bs.Get(context.Background(), bid) + c.So(err, c.ShouldBeNil) + c.So(business.BusinessName, c.ShouldEqual, bn) + + business, err = bs.Get(context.Background(), bn) + c.So(err, c.ShouldBeNil) + c.So(business.BusinessId, c.ShouldEqual, bid) + + _, err = bs.Get(context.Background(), "not found") + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "not found business") + }) +} + +func testListBusiness(t *testing.T) { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + bs := NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + c.Convey("test list business", t, func() { + + businesses, err := bs.List(context.Background(), []string{"test", "test2"}, 1, 10) + c.So(err, c.ShouldBeNil) + c.So(len(businesses), c.ShouldBeGreaterThan, 0) + }) + c.Convey("test list business fail", t, func() { + patch := gomonkey.ApplyFuncReturn(tiga.MySQLDao.Pagination, fmt.Errorf("pagination error")) + defer patch.Reset() + _, err := bs.List(context.Background(), []string{"not found"}, 1, 10) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "pagination error") + }) +} +func testDelBusiness(t *testing.T) { + c.Convey("test del business", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + bs := NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + err := bs.Del(context.Background(), bid) + c.So(err, c.ShouldBeNil) + err = bs.Del(context.Background(), bn) + c.So(err, c.ShouldNotBeNil) + }) +} +func TestBusiness(t *testing.T) { + t.Run("test add business", testAddBusiness) + t.Run("test update business", testUpdateBusiness) + t.Run("test get business", testGetBusiness) + t.Run("test list business", testListBusiness) + t.Run("test del business", testDelBusiness) + +} diff --git a/internal/data/cache.go b/internal/data/cache.go index 599ae93..fce4b3c 100644 --- a/internal/data/cache.go +++ b/internal/data/cache.go @@ -75,11 +75,13 @@ func (l *LayeredCache) Get(ctx context.Context, key string) ([]byte, error) { return l.kv.Get(ctx, key) } func (l *LayeredCache) GetFromLocal(ctx context.Context, key string) ([]byte, error) { + // log.Printf("cache get from local %s ,with %p", key, l.kv) + values, err := l.kv.GetFromLocal(ctx, key) if err != nil { return nil, err } - + // log.Printf("get cache %s from local %v", key, values) for _, val := range values { if val, ok := val.([]byte); ok { return val, nil @@ -91,6 +93,7 @@ func (l *LayeredCache) Del(ctx context.Context, key string) error { return l.kv.Del(ctx, key) } func (l *LayeredCache) SetToLocal(ctx context.Context, key string, value []byte, exp time.Duration) error { + // log.Printf("cache set to local %s,%s,with %p", key,value, l.kv) return l.kv.SetToLocal(ctx, key, value, exp) } diff --git a/internal/data/curd.go b/internal/data/curd.go index 9ee59d6..f4fa34d 100644 --- a/internal/data/curd.go +++ b/internal/data/curd.go @@ -13,6 +13,7 @@ import ( "github.com/spark-lence/tiga" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/known/timestamppb" + "gorm.io/gorm" ) type curdImpl struct { @@ -48,7 +49,10 @@ func (c *curdImpl) SetBoolean(model biz.DeleteModel, name string) error { model.ProtoReflect().Set(field, protoreflect.ValueOfBool(true)) return nil } -func (c *curdImpl) Add(ctx context.Context, model biz.Model, needEncrypt bool) error { +func (c *curdImpl) BeginTx(ctx context.Context) *gorm.DB { + return c.db.Begin() +} +func (c *curdImpl) Add(ctx context.Context, model biz.Model, needEncrypt bool, tx *gorm.DB) error { if err := c.SetDatetimeAt(model, "created_at"); err != nil { return err } @@ -66,7 +70,7 @@ func (c *curdImpl) Add(ctx context.Context, model biz.Model, needEncrypt bool) e } } - return c.db.Create(ctx, model) + return c.db.Create(ctx, model, tx) } func (c *curdImpl) Get(ctx context.Context, model interface{}, needDecrypt bool, query string, args ...interface{}) error { if _, ok := model.(biz.DeleteModel); ok { @@ -89,20 +93,21 @@ func (c *curdImpl) Get(ctx context.Context, model interface{}, needDecrypt bool, } return nil } -func (c *curdImpl) Update(ctx context.Context, model biz.Model, needEncrypt bool) error { +func (c *curdImpl) Update(ctx context.Context, model biz.Model, needEncrypt bool, tx *gorm.DB) error { paths := make([]string, 0) updateMask := model.GetUpdateMask() if updateMask != nil { paths = updateMask.Paths } - key, val, err := getPrimaryColumnValue(model, "primary") + kv, err := getPrimaryColumnValue(model, "primary") if err != nil { return errors.Wrap(err, "get primary column value failed") } for _, path := range paths { - if path == key { - return fmt.Errorf("primary key %s can not be updated", key) + if k, ok := kv[path]; ok { + + return fmt.Errorf("primary key %s can not be updated", k) } } @@ -118,9 +123,14 @@ func (c *curdImpl) Update(ctx context.Context, model biz.Model, needEncrypt bool } } - err = c.db.UpdateSelectColumns(ctx, fmt.Sprintf("%s=%s", key, val), model, paths...) + query := make([]string, 0) + for k, v := range kv { + query = append(query, fmt.Sprintf("%s=%s", k, v)) + } + + err = c.db.UpdateSelectColumns(ctx, strings.Join(query, " and "), model, tx, paths...) if err != nil { - return fmt.Errorf("update model for %s=%v failed: %w", key, val, err) + return fmt.Errorf("update model for %v failed: %w", query, err) } return nil } @@ -133,9 +143,6 @@ func (c *curdImpl) renameUniqueFields(model biz.Model) ([]string, error) { modelVal = modelVal.Elem() } - if modelType.Kind() != reflect.Struct { - return nil, fmt.Errorf("%s not a struct type", modelType.Kind().String()) - } updated := make([]string, 0) // 遍历结构体的字段 for i := 0; i < modelType.NumField(); i++ { @@ -160,11 +167,16 @@ func (c *curdImpl) renameUniqueFields(model biz.Model) ([]string, error) { } return updated, nil } -func (c *curdImpl) Del(ctx context.Context, model interface{}, needEncrypt bool) error { - key, val, err := getPrimaryColumnValue(model, "primary") +func (c *curdImpl) Del(ctx context.Context, model interface{}, needEncrypt bool, tx *gorm.DB) error { + kv, err := getPrimaryColumnValue(model, "primary") if err != nil { return errors.Wrap(err, "get primary column value failed") } + query := []string{} + for k, v := range kv { + query = append(query, fmt.Sprintf("%s='%s'", k, v)) + } + if delModel, ok := c.assertDeletedModel(model); ok { if err := c.SetBoolean(delModel, "is_deleted"); err != nil { return err @@ -187,9 +199,13 @@ func (c *curdImpl) Del(ctx context.Context, model interface{}, needEncrypt bool) } } } - return c.db.UpdateSelectColumns(ctx, fmt.Sprintf("%s=%s", key, val), model, updated...) + err = c.db.UpdateSelectColumns(ctx, strings.Join(query, " and "), delModel, tx, updated...) + if err != nil && !strings.Contains(err.Error(), "no rows affected") { + return err + } + return nil } else { - return c.db.Delete(model, fmt.Sprintf("%s=?", key), val) + return c.db.Delete(model, tx, strings.Join(query, " and ")) } } func (c *curdImpl) assertDeletedModel(model interface{}) (biz.DeleteModel, bool) { diff --git a/internal/data/curd_test.go b/internal/data/curd_test.go index 85c1432..4ec50c9 100644 --- a/internal/data/curd_test.go +++ b/internal/data/curd_test.go @@ -1,15 +1,18 @@ package data import ( + "context" "testing" + "github.com/agiledragon/gomonkey/v2" api "github.com/begonia-org/go-sdk/api/app/v1" c "github.com/smartystreets/goconvey/convey" + "github.com/spark-lence/tiga" ) func TestAssertDeletedModel(t *testing.T) { c.Convey("test assert deleted model", t, func() { - curd := &curdImpl{} + curd := &curdImpl{db: &tiga.MySQLDao{}} v, ok := curd.assertDeletedModel(&struct{}{}) c.So(ok, c.ShouldBeFalse) c.So(v, c.ShouldBeNil) @@ -24,22 +27,15 @@ func TestAssertDeletedModel(t *testing.T) { c.So(err, c.ShouldNotBeNil) err = curd.SetDatetimeAt(&api.Apps{}, "deleted_at_test") c.So(err, c.ShouldNotBeNil) + patch := gomonkey.ApplyFuncReturn(tiga.MySQLDao.Begin, nil) + defer patch.Reset() + curd.BeginTx(context.Background()) }) } func TestGetPrimaryColumnValueErr(t *testing.T) { c.Convey("test get primary column value err", t, func() { - _, _, err := getPrimaryColumnValue(make(map[string]interface{}), "primary") + _, err := getPrimaryColumnValue(make(map[string]interface{}), "primary") c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "not a struct type") - - _, _, err = getPrimaryColumnValue(&struct { - Primary string - Name string - }{ - Primary: "primary", - Name: "name", - }, "primary") - c.So(err, c.ShouldNotBeNil) - c.So(err.Error(), c.ShouldContainSubstring, "not found primary column") }) } diff --git a/internal/data/data.go b/internal/data/data.go index 083322c..2ee90ee 100644 --- a/internal/data/data.go +++ b/internal/data/data.go @@ -59,6 +59,8 @@ var ProviderSet = wire.NewSet(NewMySQL, NewLayeredCache, NewDataLock, + NewBusinessRepoImpl, + NewTenantRepoImpl, NewAuthzRepoImpl, NewUserRepoImpl, NewEndpointRepoImpl, @@ -190,7 +192,7 @@ func NewData(mysql *tiga.MySQLDao, rdb *tiga.RedisDao, etcd *tiga.EtcdDao) *Data // } // return db.Commit().Error // } -func getPrimaryColumnValue(model interface{}, tagName string) (string, interface{}, error) { +func getPrimaryColumnValue(model interface{}, tagName string) (map[string]interface{}, error) { // 获取结构体类型 modelType := reflect.TypeOf(model) modelVal := reflect.ValueOf(model) @@ -200,9 +202,9 @@ func getPrimaryColumnValue(model interface{}, tagName string) (string, interface } if modelType.Kind() != reflect.Struct { - return "", "", fmt.Errorf("%s not a struct type", modelType.Kind().String()) + return nil, fmt.Errorf("%s not a struct type", modelType.Kind().String()) } - + fieldValue := make(map[string]interface{}) // 遍历结构体的字段 for i := 0; i < modelType.NumField(); i++ { field := modelType.Field(i) @@ -214,15 +216,20 @@ func getPrimaryColumnValue(model interface{}, tagName string) (string, interface tagParts := strings.Split(tag, ";") for _, part := range tagParts { kv := strings.Split(part, ":") - if len(kv) == 2 && strings.TrimSpace(kv[0]) == "column" { + if len(kv) == 2 && strings.TrimSpace(kv[0]) == "column" && !strings.Contains(tag, "primaryKey") { value := modelVal.Field(i).Interface() - return strings.TrimSpace(kv[1]), value, nil + fieldValue[strings.TrimSpace(kv[1])] = value + // return strings.TrimSpace(kv[1]), value, nil } } } } - return "", nil, fmt.Errorf("not found primary column") + // if len(fieldValue) == 0 { + // return nil, fmt.Errorf("not found primary column") + + // } + return fieldValue, nil } // func (d *Data) Update(ctx context.Context, model SourceType) error { diff --git a/internal/data/data_test.go b/internal/data/data_test.go index eacd7c1..7596201 100644 --- a/internal/data/data_test.go +++ b/internal/data/data_test.go @@ -8,8 +8,8 @@ import ( "github.com/begonia-org/begonia" cfg "github.com/begonia-org/begonia/config" - "github.com/spark-lence/tiga" user "github.com/begonia-org/go-sdk/api/user/v1" + "github.com/spark-lence/tiga" clientv3 "go.etcd.io/etcd/client/v3" ) @@ -65,7 +65,7 @@ func setup() { } mysql := tiga.NewMySQLDao(conf) mysql.RegisterTimeSerializer() - err=mysql.GetModel(&user.Users{}).Where("`group` = ?", "test-user-01").Delete(&user.Users{}).Error + err = mysql.GetModel(&user.Users{}).Where("`group` = ?", "test-user-01").Delete(&user.Users{}).Error if err != nil { log.Fatalf("Failed to delete keys with prefix %s: %v", prefix, err) } diff --git a/internal/data/endpoint.go b/internal/data/endpoint.go index 066935b..e18e693 100644 --- a/internal/data/endpoint.go +++ b/internal/data/endpoint.go @@ -57,7 +57,7 @@ func (e *endpointRepoImpl) ServiceNameExists(ctx context.Context, name, id strin return fmt.Errorf("unmarshal service name error: %w", err) } if ep["uid"] != id { - return fmt.Errorf("%s service name already exists in %s",ep["uid"],id) + return fmt.Errorf("%s service name already exists in %s", ep["uid"], id) } return nil } @@ -247,7 +247,7 @@ func (e *endpointRepoImpl) PutTags(ctx context.Context, id string, tags []string ops = append(ops, clientv3.OpPut(srvKey, string(updated))) ok, err := e.data.PutEtcdWithTxn(ctx, ops) - if err != nil||!ok { + if err != nil || !ok { return fmt.Errorf("put tags fail: %w", err) } return nil diff --git a/internal/data/endpoint_test.go b/internal/data/endpoint_test.go index 490891d..4d99815 100644 --- a/internal/data/endpoint_test.go +++ b/internal/data/endpoint_test.go @@ -506,7 +506,7 @@ func checkServiceNameExistsTest(t *testing.T) { err = repo.ServiceNameExists(context.Background(), serviceName, endpointId) patch2.Reset() c.So(err, c.ShouldNotBeNil) - c.So(err.Error(),c.ShouldContainSubstring,"get endpoint fail") + c.So(err.Error(), c.ShouldContainSubstring, "get endpoint fail") err = repo.ServiceNameExists(context.Background(), serviceName, endpointId) c.So(err, c.ShouldBeNil) diff --git a/internal/data/file.go b/internal/data/file.go index b40be99..728ad48 100644 --- a/internal/data/file.go +++ b/internal/data/file.go @@ -3,7 +3,6 @@ package data import ( "context" "fmt" - "log" "github.com/begonia-org/begonia/internal/biz" "github.com/begonia-org/begonia/internal/biz/file" @@ -29,11 +28,10 @@ func (f *fileRepoImpl) UpsertFile(ctx context.Context, in *api.Files) (bool, err mask = in.UpdateMask.Paths } // log.Printf("mask:%v", in.Uid) - return f.data.db.Upsert(ctx, in, mask...) + return f.data.db.Upsert(ctx, in, nil, mask...) } func (f *fileRepoImpl) DelFile(ctx context.Context, engine, bucket, key string) error { - // return f.curd.Del(ctx, &api.Files{Uid: fid},false) - return f.data.db.UpdateSelectColumns(ctx, &api.Files{Engine: engine, Bucket: bucket, Key: key}, &api.Files{IsDeleted: true}, "is_deleted") + return f.curd.Del(ctx, &api.Files{Engine: engine, Bucket: bucket, Key: key}, false, nil) } func (f *fileRepoImpl) UpsertBucket(ctx context.Context, bucket *api.Buckets) (bool, error) { bucket.UpdatedAt = timestamppb.Now() @@ -41,10 +39,10 @@ func (f *fileRepoImpl) UpsertBucket(ctx context.Context, bucket *api.Buckets) (b if bucket.UpdateMask != nil { mask = bucket.UpdateMask.Paths } - return f.data.db.Upsert(ctx, bucket, mask...) + return f.data.db.Upsert(ctx, bucket, nil, mask...) } func (f *fileRepoImpl) DelBucket(ctx context.Context, bucketId string) error { - return f.curd.Del(ctx, &api.Buckets{Uid: bucketId}, false) + return f.curd.Del(ctx, &api.Buckets{Uid: bucketId}, false, nil) } func (f *fileRepoImpl) GetFileById(ctx context.Context, fid string) (*api.Files, error) { file := &api.Files{Uid: fid} @@ -82,7 +80,7 @@ func (f *fileRepoImpl) List(ctx context.Context, page, pageSize int32, bucket, e } pagination.Args = append(pagination.Args, engine) } - log.Printf("query:%s,args:%s", pagination.Query, pagination.Args) + // log.Printf("query:%s,args:%s", pagination.Query, pagination.Args) err := f.curd.List(ctx, &files, pagination) if err != nil { diff --git a/internal/data/file_test.go b/internal/data/file_test.go index 15e8a23..ce95b77 100644 --- a/internal/data/file_test.go +++ b/internal/data/file_test.go @@ -66,7 +66,7 @@ func testGetFileById(t *testing.T) { file, err := f.GetFileById(context.Background(), fileFileId) c.So(err, c.ShouldBeNil) c.So(file, c.ShouldNotBeNil) - fk:=fmt.Sprintf("test-%s",fileFileId) + fk := fmt.Sprintf("test-%s", fileFileId) file, err = f.GetFile(context.Background(), fk, "test", fk) c.So(err, c.ShouldBeNil) c.So(file, c.ShouldNotBeNil) @@ -91,7 +91,7 @@ func testUpsertBucket(t *testing.T) { conf := cfg.ReadConfig(env) f := NewFileRepo(conf, gateway.Log) snk, _ := tiga.NewSnowflake(1) - bk:=fmt.Sprintf("test-%s",snk.GenerateIDString()) + bk := fmt.Sprintf("test-%s", snk.GenerateIDString()) bucket := &api.Buckets{ Engine: "test", diff --git a/internal/data/operator.go b/internal/data/operator.go index f0412c0..760482e 100644 --- a/internal/data/operator.go +++ b/internal/data/operator.go @@ -112,6 +112,7 @@ func (d *dataOperatorRepo) FlashUsersCache(ctx context.Context, prefix string, m kv := make([]interface{}, 0) for _, model := range models { key := fmt.Sprintf("%s:%s", prefix, model.Uid) + // log.Printf("缓存用户:%s", key) val, _ := protojson.Marshal(model) kv = append(kv, key, string(val)) } diff --git a/internal/data/tenant.go b/internal/data/tenant.go new file mode 100644 index 0000000..697f363 --- /dev/null +++ b/internal/data/tenant.go @@ -0,0 +1,115 @@ +package data + +import ( + "context" + "fmt" + + "github.com/begonia-org/begonia/internal/biz" + "github.com/begonia-org/begonia/internal/pkg/config" + api "github.com/begonia-org/go-sdk/api/user/v1" + "github.com/spark-lence/tiga" +) + +// type TenantRepo interface { +// Add(ctx context.Context, tenant *api.Tenants) error +// Get(ctx context.Context, key string) (*api.Tenants, error) +// Del(ctx context.Context, uidOrName string) error +// List(ctx context.Context, tags []string, status []api.USER_STATUS, page, pageSize int32) ([]*api.Tenants, error) +// Patch(ctx context.Context, model *api.Tenants) error +// } + +type tenantRepoImpl struct { + data *Data + cfg *config.Config + curd biz.CURD +} + +func (t *tenantRepoImpl) Add(ctx context.Context, tenant *api.Tenants) error { + err := t.curd.Add(ctx, tenant, true, nil) + return err +} +func (t *tenantRepoImpl) Get(ctx context.Context, key string) (*api.Tenants, error) { + tenant := &api.Tenants{} + err := t.curd.Get(ctx, tenant, false, "tenant_id=? or tenant_name=?", key, key) + if err != nil || tenant.TenantId == "" { + return nil, err + } + return tenant, nil +} +func (t *tenantRepoImpl) Del(ctx context.Context, tenantId string) error { + return t.curd.Del(ctx, &api.Tenants{TenantId: tenantId}, false, nil) +} +func (t *tenantRepoImpl) List(ctx context.Context, tags []string, status []api.TENANTS_STATUS, page, pageSize int32) ([]*api.Tenants, error) { + tenants := make([]*api.Tenants, 0) + query := "" + conds := make([]interface{}, 0) + if len(tags) > 0 { + query = "json_contains(json_array(?),tags)" + conds = append(conds, tags) + } + if len(status) > 0 { + if query == "" { + query = "status in (?)" + + } else { + query += " and status in (?)" + } + conds = append(conds, status) + } + pagination := &tiga.Pagination{ + Page: page, + PageSize: pageSize, + Query: query, + Args: conds, + } + err := t.curd.List(ctx, &tenants, pagination) + if err != nil { + return nil, err + } + return tenants, nil +} +func (t *tenantRepoImpl) Patch(ctx context.Context, model *api.Tenants) error { + return t.curd.Update(ctx, model, false, nil) +} +func (t *tenantRepoImpl) AddBusiness(ctx context.Context, tenantBusiness *api.TenantsBusiness) error { + tenant, err := t.Get(ctx, tenantBusiness.TenantId) + if err != nil || tenant == nil || tenant.TenantId == "" { + return fmt.Errorf("get tenant failed: %w or not found tenant before add business", err) + } + bs := &api.Business{BusinessId: tenantBusiness.BusinessId} + err = t.curd.Get(ctx, bs, false, "business_id=?", tenantBusiness.BusinessId) + if err != nil || bs.BusinessId == "" { + return fmt.Errorf("get business failed: %w or not found business before add business", err) + + } + return t.curd.Add(ctx, tenantBusiness, false, nil) +} + +func (t *tenantRepoImpl) DelTenantBusiness(ctx context.Context, tenantId, businessId string) error { + return t.curd.Del(ctx, &api.TenantsBusiness{TenantId: tenantId, BusinessId: businessId}, false, nil) +} +func (t *tenantRepoImpl) GetTenantBusiness(ctx context.Context, tenant, business string) (*api.TenantsBusiness, error) { + tenantBusiness := &api.TenantsBusiness{} + err := t.curd.Get(ctx, tenantBusiness, false, "(tenant_id=? or tenant_name=?) and (business_id=? or business_name=?)", tenant, tenant, business, business) + if err != nil || tenantBusiness.BusinessId == "" || tenantBusiness.TenantId == "" { + return nil, fmt.Errorf("get tenant business failed: %w or not found", err) + } + return tenantBusiness, nil +} +func (t *tenantRepoImpl) TenantBusinessList(ctx context.Context, tenantId string, page, pageSize int32) ([]*api.TenantsBusiness, error) { + tenantBusinesses := make([]*api.TenantsBusiness, 0) + pagination := &tiga.Pagination{ + Page: page, + PageSize: pageSize, + Query: "tenant_id=?", + Args: []interface{}{tenantId}, + } + err := t.curd.List(ctx, &tenantBusinesses, pagination) + if err != nil { + return nil, err + } + return tenantBusinesses, nil +} +func NewTenantRepoImpl(data *Data, cfg *config.Config, curd biz.CURD) biz.TenantRepo { + return &tenantRepoImpl{data: data, cfg: cfg, curd: curd} +} diff --git a/internal/data/tenant_test.go b/internal/data/tenant_test.go new file mode 100644 index 0000000..2f83087 --- /dev/null +++ b/internal/data/tenant_test.go @@ -0,0 +1,225 @@ +package data + +import ( + "context" + "fmt" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/begonia-org/begonia" + cfg "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/gateway" + api "github.com/begonia-org/go-sdk/api/user/v1" + c "github.com/smartystreets/goconvey/convey" + "github.com/spark-lence/tiga" + "google.golang.org/protobuf/types/known/fieldmaskpb" +) + +var tid = "" +var tn = "" +var tsn = "" + +func testAddTenant(t *testing.T) { + c.Convey("test add tenant", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + bs := NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + snk, _ := tiga.NewSnowflake(2) + tid = snk.GenerateIDString() + tn = fmt.Sprintf("test-%s", tid) + tenant := &api.Tenants{ + TenantId: tid, + TenantName: fmt.Sprintf("test-%s", tid), + Description: "test tenant", + Tags: []string{"test"}, + Email: fmt.Sprintf("%s@example.com", tn), + } + err := bs.Add(context.Background(), tenant) + c.So(err, c.ShouldBeNil) + }) +} +func testUpdateTenant(t *testing.T) { + c.Convey("test update tenant", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + bs := NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + tenant, err := bs.Get(context.Background(), tid) + c.So(err, c.ShouldBeNil) + tenant.Description = "update description" + tenant.UpdateMask = &fieldmaskpb.FieldMask{Paths: []string{"description"}} + err = bs.Patch(context.Background(), tenant) + c.So(err, c.ShouldBeNil) + }) +} +func testGetTenant(t *testing.T) { + c.Convey("test get tenant", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + bs := NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + tenant, err := bs.Get(context.Background(), tid) + c.So(err, c.ShouldBeNil) + c.So(tenant.TenantName, c.ShouldEqual, tn) + + tenant, err = bs.Get(context.Background(), tn) + c.So(err, c.ShouldBeNil) + c.So(tenant.TenantId, c.ShouldEqual, tid) + + _, err = bs.Get(context.Background(), "not found") + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "not found") + }) +} + +func testListTenant(t *testing.T) { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + bs := NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + c.Convey("test list tenant", t, func() { + + tenantes, err := bs.List(context.Background(), []string{"test", "test2"}, []api.TENANTS_STATUS{api.TENANTS_STATUS_TENANTS_ACTIVE}, 1, 10) + c.So(err, c.ShouldBeNil) + c.So(len(tenantes), c.ShouldBeGreaterThan, 0) + }) + c.Convey("test list tenant fail", t, func() { + patch := gomonkey.ApplyFuncReturn(tiga.MySQLDao.Pagination, fmt.Errorf("pagination error")) + defer patch.Reset() + _, err := bs.List(context.Background(), []string{"not found"}, []api.TENANTS_STATUS{api.TENANTS_STATUS_TENANTS_DELETED}, 1, 10) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "pagination error") + }) +} +func testAddTenantBusiness(t *testing.T) { + c.Convey("test add tenant", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + tr := NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + bs := NewBusinessRepo(cfg.ReadConfig(env), gateway.Log) + snk, _ := tiga.NewSnowflake(2) + business := &api.Business{ + BusinessId: snk.GenerateIDString(), + BusinessName: fmt.Sprintf("test-data-%s", snk.GenerateIDString()), + Description: "test business", + } + tsn = business.BusinessName + err := bs.Add(context.Background(), business) + c.So(err, c.ShouldBeNil) + tenantBusiness := &api.TenantsBusiness{ + TenantId: tid, + BusinessId: business.BusinessId, + BusinessName: business.BusinessName, + TenantName: tn, + Plan: "Free", + } + err = tr.AddBusiness(context.Background(), tenantBusiness) + c.So(err, c.ShouldBeNil) + tb := &api.TenantsBusiness{ + TenantId: snk.GenerateIDString(), + BusinessId: business.BusinessId, + BusinessName: business.BusinessName, + TenantName: tn, + Plan: "Free", + } + err = tr.AddBusiness(context.Background(), tb) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "not found") + + tb2 := &api.TenantsBusiness{ + TenantId: tid, + BusinessId: snk.GenerateIDString(), + BusinessName: business.BusinessName, + TenantName: tn, + Plan: "Free", + } + err = tr.AddBusiness(context.Background(), tb2) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "not found") + }) +} +func testGetTenantBusiness(t *testing.T) { + c.Convey("test get tenant business", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + snk, _ := tiga.NewSnowflake(2) + tr := NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + tb, err := tr.GetTenantBusiness(context.Background(), tid, tsn) + c.So(err, c.ShouldBeNil) + + c.So(tb.BusinessName, c.ShouldEqual, tsn) + _, err = tr.GetTenantBusiness(context.Background(), snk.GenerateIDString(), snk.GenerateIDString()) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "not found") + + }) +} +func testTenantBusinessList(t *testing.T) { + c.Convey("test tenant business list", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + tr := NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + tbs, err := tr.TenantBusinessList(context.Background(), tid, 1, 10) + c.So(err, c.ShouldBeNil) + c.So(len(tbs), c.ShouldBeGreaterThan, 0) + patch := gomonkey.ApplyFuncReturn(tiga.MySQLDao.Pagination, fmt.Errorf("pagination error")) + defer patch.Reset() + _, err = tr.TenantBusinessList(context.Background(), "not found", 1, 10) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "pagination error") + }) + +} +func testDelTenantBusiness(t *testing.T) { + c.Convey("test del tenant business", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + tr := NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + err := tr.DelTenantBusiness(context.Background(), tid, tsn) + c.So(err, c.ShouldBeNil) + }) + +} +func testDelTenant(t *testing.T) { + c.Convey("test del tenant", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + bs := NewTenantRepo(cfg.ReadConfig(env), gateway.Log) + err := bs.Del(context.Background(), tid) + c.So(err, c.ShouldBeNil) + err = bs.Del(context.Background(), tn) + c.So(err, c.ShouldBeNil) + + err = bs.Del(context.Background(), "not found") + c.So(err, c.ShouldBeNil) + // c.So(err.Error(), c.ShouldContainSubstring, "not found") + + }) +} +func TestTenant(t *testing.T) { + t.Run("test add tenant", testAddTenant) + t.Run("test update tenant", testUpdateTenant) + t.Run("test get tenant", testGetTenant) + t.Run("test list tenant", testListTenant) + t.Run("test add tenant business", testAddTenantBusiness) + t.Run("test get tenant business", testGetTenantBusiness) + t.Run("test tenant business list", testTenantBusinessList) + t.Run("test del tenant business", testDelTenantBusiness) + t.Run("test del tenant", testDelTenant) + +} diff --git a/internal/data/user.go b/internal/data/user.go index ee68480..32320f6 100644 --- a/internal/data/user.go +++ b/internal/data/user.go @@ -25,9 +25,46 @@ func NewUserRepoImpl(data *Data, local *LayeredCache, curd biz.CURD, cfg *config func (r *userRepoImpl) Add(ctx context.Context, user *api.Users) error { - err := r.curd.Add(ctx, user, true) + err := r.curd.Add(ctx, user, true, nil) return err } + +// func (u *userRepoImpl) AddUserWithTenant(ctx context.Context, user *api.Users) error { +// if user.TenantId != "" && user.TenantId != user.Uid { +// ten, err := u.GetTenant(ctx, user.TenantId) +// if err != nil || ten == nil || ten.TenantId == "" { +// return fmt.Errorf("get tenant before add new user failed: %w or tenant not found", err) +// } +// } +// tx := u.curd.BeginTx(ctx) +// defer func() { +// if err := recover(); err != nil { +// tx.Rollback() +// } +// }() +// if user.TenantId==""{ +// user.TenantId = user.Uid +// tenant := &api.Tenants{ +// TenantId: user.TenantId, +// TenantName: user.Name, + +// } +// err := u.curd.Add(ctx, tenant, false, tx) +// if err != nil { +// tx.Rollback() +// return fmt.Errorf("add tenant failed: %w", err) + +// } +// } + +// err := u.curd.Add(ctx, user, false, tx) +// if err != nil { +// tx.Rollback() +// return fmt.Errorf("add user failed: %w", err) +// } + +// return err +// } func (r *userRepoImpl) Get(ctx context.Context, key string) (*api.Users, error) { app := &api.Users{} @@ -43,12 +80,12 @@ func (r *userRepoImpl) Del(ctx context.Context, key string) error { if err != nil { return err } - err = r.curd.Del(ctx, user, true) + err = r.curd.Del(ctx, user, true, nil) return err } func (r *userRepoImpl) Patch(ctx context.Context, model *api.Users) error { - return r.curd.Update(ctx, model, true) + return r.curd.Update(ctx, model, true, nil) } func (r *userRepoImpl) List(ctx context.Context, dept []string, status []api.USER_STATUS, page, pageSize int32) ([]*api.Users, error) { apps := make([]*api.Users, 0) diff --git a/internal/data/wire.go b/internal/data/wire.go index 7ae4055..cb93d0e 100644 --- a/internal/data/wire.go +++ b/internal/data/wire.go @@ -47,3 +47,10 @@ func NewLocker(cfg *tiga.Configuration, log logger.Logger, key string, ttl time. func NewFileRepo(cfg *tiga.Configuration, log logger.Logger) file.FileRepo { panic(wire.Build(ProviderSet, config.NewConfig)) } +func NewBusinessRepo(cfg *tiga.Configuration, log logger.Logger) biz.BusinessRepo { + panic(wire.Build(ProviderSet, config.NewConfig)) +} + +func NewTenantRepo(cfg *tiga.Configuration, log logger.Logger) biz.TenantRepo { + panic(wire.Build(ProviderSet, config.NewConfig)) +} diff --git a/internal/data/wire_gen.go b/internal/data/wire_gen.go index 63a45f2..866d5d9 100644 --- a/internal/data/wire_gen.go +++ b/internal/data/wire_gen.go @@ -105,3 +105,25 @@ func NewFileRepo(cfg *tiga.Configuration, log logger.Logger) file.FileRepo { fileRepo := NewFileRepoImpl(data, curd) return fileRepo } + +func NewBusinessRepo(cfg *tiga.Configuration, log logger.Logger) biz.BusinessRepo { + mySQLDao := NewMySQL(cfg) + redisDao := NewRDB(cfg) + etcdDao := NewEtcd(cfg) + data := NewData(mySQLDao, redisDao, etcdDao) + configConfig := config.NewConfig(cfg) + curd := NewCurdImpl(mySQLDao, configConfig) + businessRepo := NewBusinessRepoImpl(data, curd, configConfig) + return businessRepo +} + +func NewTenantRepo(cfg *tiga.Configuration, log logger.Logger) biz.TenantRepo { + mySQLDao := NewMySQL(cfg) + redisDao := NewRDB(cfg) + etcdDao := NewEtcd(cfg) + data := NewData(mySQLDao, redisDao, etcdDao) + configConfig := config.NewConfig(cfg) + curd := NewCurdImpl(mySQLDao, configConfig) + tenantRepo := NewTenantRepoImpl(data, configConfig, curd) + return tenantRepo +} diff --git a/internal/middleware/auth/ak_test.go b/internal/middleware/auth/ak_test.go index 9517232..62168ae 100644 --- a/internal/middleware/auth/ak_test.go +++ b/internal/middleware/auth/ak_test.go @@ -15,7 +15,6 @@ import ( "github.com/begonia-org/begonia/internal/data" "github.com/begonia-org/begonia/internal/middleware/auth" cfg "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" hello "github.com/begonia-org/go-sdk/api/example/v1" c "github.com/smartystreets/goconvey/convey" @@ -38,7 +37,7 @@ func TestAccessKeyAuthMiddleware(t *testing.T) { ak.SetPriority(1) c.So(ak.Name(), c.ShouldEqual, "ak_auth") c.So(ak.Priority(), c.ShouldEqual, 1) - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -58,11 +57,18 @@ func TestAccessKeyAuthMiddleware(t *testing.T) { return fmt.Errorf("metadata not exists in context") }) c.So(err.Error(), c.ShouldContainSubstring, fmt.Errorf("metadata not exists in context").Error()) - - patch := gomonkey.ApplyFuncReturn((*auth.AccessKeyAuthMiddleware).StreamRequestBefore, nil, nil) + patch := gomonkey.ApplyMethodReturn(akBiz, "AppValidator", "test", nil) + patch = patch.ApplyMethodReturn(akBiz, "GetAppOwner", "test", nil) + // patch := gomonkey.ApplyFuncReturn((*auth.AccessKeyAuthMiddleware).StreamRequestBefore, nil, nil) patch = patch.ApplyFuncReturn((*auth.AccessKeyAuthMiddleware).StreamResponseAfter, fmt.Errorf("StreamResponseAfter err")) defer patch.Reset() - err = ak.StreamInterceptor(context.Background(), &testStream{ctx: context.Background()}, &grpc.StreamServerInfo{FullMethod: "/integration.TestService/Get"}, func(srv any, stream grpc.ServerStream) error { + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXAccessKey, "test")) + err = ak.StreamInterceptor(ctx, &testStream{ctx: ctx}, &grpc.StreamServerInfo{FullMethod: "/integration.TestService/Get"}, func(srv any, ss grpc.ServerStream) error { + md, _ := metadata.FromIncomingContext(ss.Context()) + if len(md.Get(gosdk.HeaderXIdentity)) == 0 || md.Get(gosdk.HeaderXIdentity)[0] == "" { + t.Error("identity not exists in context") + return fmt.Errorf("identity not exists in context") + } return nil }) @@ -76,6 +82,29 @@ func TestAccessKeyAuthMiddleware(t *testing.T) { }) patch2.Reset() c.So(err.Error(), c.ShouldContainSubstring, fmt.Errorf("StreamRequestBefore err").Error()) + // do not need validate + outCTX := metadata.NewOutgoingContext(context.Background(), metadata.Pairs(gosdk.HeaderXAccessKey, "test")) + _, err = ak.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + + // NO CONTEXT + _, err = ak.StreamClientInterceptor(context.Background(), nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "metadata not exists in context") + // get owner error + patch3 := gomonkey.ApplyFuncReturn((*biz.AccessKeyAuth).GetAppOwner, "", fmt.Errorf("get owner error")) + defer patch3.Reset() + // get owner error + in := metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXAccessKey, "test")) + _, err = ak.StreamRequestBefore(in, &testStream{ctx: in}, &grpc.StreamServerInfo{FullMethod: "/integration.TestService/Get"}, nil) + patch3.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "get app owner error") + }) } func TestRequestBeforeErr(t *testing.T) { @@ -133,7 +162,7 @@ func TestValidateStream(t *testing.T) { ak := auth.NewAccessKeyAuth(akBiz, cnf, gateway.Log) patch := gomonkey.ApplyFuncReturn(gosdk.NewGatewayRequestFromGrpc, nil, fmt.Errorf("NewGatewayRequestFromGrpc err")) defer patch.Reset() - _, err := ak.ValidateStream(context.TODO(), nil, "", nil) + _, err := ak.ValidateStream(context.TODO(), nil, "") patch.Reset() c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "NewGatewayRequestFromGrpc err") diff --git a/internal/middleware/auth/aksk.go b/internal/middleware/auth/aksk.go index 006ea6b..905b57b 100644 --- a/internal/middleware/auth/aksk.go +++ b/internal/middleware/auth/aksk.go @@ -4,9 +4,9 @@ import ( "context" "strings" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/biz" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" "github.com/begonia-org/go-sdk/logger" "google.golang.org/grpc" @@ -34,7 +34,7 @@ func NewAccessKeyAuth(app *biz.AccessKeyAuth, config *config.Config, log logger. } func IfNeedValidate(ctx context.Context, fullMethod string) bool { - routersList := routers.Get() + routersList := gateway.GetRouter() router := routersList.GetRouteByGrpcMethod(strings.ToUpper(fullMethod)) if router == nil { return false @@ -63,32 +63,39 @@ func (a *AccessKeyAuthMiddleware) RequestBefore(ctx context.Context, info *grpc. if !ok { md = metadata.MD{} } - // md.Set(gosdk.HeaderXIdentity, owner) - md = metadata.Join(md, metadata.Pairs(gosdk.HeaderXIdentity, owner)) + md.Set(gosdk.HeaderXIdentity, owner) + // md = metadata.Join(md, metadata.Pairs(gosdk.HeaderXIdentity, owner)) ctx = metadata.NewIncomingContext(ctx, md) // md2, _ := metadata.FromIncomingContext(ctx) - return ctx, nil } -func (a *AccessKeyAuthMiddleware) ValidateStream(ctx context.Context, req interface{}, fullName string, headers Header) (context.Context, error) { - ctx,err:= a.RequestBefore(ctx, &grpc.UnaryServerInfo{FullMethod: fullName}, req) - if err!=nil{ - return ctx,err - } - md, _ := metadata.FromIncomingContext(ctx) - if identity := md.Get(gosdk.HeaderXIdentity);len(identity)>0{ - headers.Set(strings.ToLower(gosdk.HeaderXIdentity), identity[0]) +func (a *AccessKeyAuthMiddleware) ValidateStream(ctx context.Context, req interface{}, fullName string) (context.Context, error) { + ctx, err := a.RequestBefore(ctx, &grpc.UnaryServerInfo{FullMethod: fullName}, req) + if err != nil { + return ctx, err } - return ctx,nil - + + return ctx, nil + } func (a *AccessKeyAuthMiddleware) StreamRequestBefore(ctx context.Context, ss grpc.ServerStream, info *grpc.StreamServerInfo, req interface{}) (grpc.ServerStream, error) { - grpcStream := NewGrpcStream(ss, info.FullMethod, ss.Context(), a) - // defer grpcStream.Release() + if in, ok := metadata.FromIncomingContext(ctx); ok { + if ak := in.Get(gosdk.HeaderXAccessKey); len(ak) > 0 { + identity, err := a.app.GetAppOwner(ctx, ak[0]) + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, "get app owner error,%v", err) + } + md, _ := metadata.FromIncomingContext(ctx) + md.Set(gosdk.HeaderXIdentity, identity) + ctx = metadata.NewIncomingContext(ctx, md) + ctx = metadata.NewOutgoingContext(ctx, md) + } + } + grpcStream := NewGrpcStream(ss, info.FullMethod, ctx, a) return grpcStream, nil } @@ -104,6 +111,7 @@ func (a *AccessKeyAuthMiddleware) UnaryInterceptor(ctx context.Context, req any, defer func() { _ = a.ResponseAfter(ctx, info, req, resp) }() + resp, err = handler(ctx, req) return resp, err @@ -112,10 +120,12 @@ func (a *AccessKeyAuthMiddleware) StreamInterceptor(srv interface{}, ss grpc.Ser if !IfNeedValidate(ss.Context(), info.FullMethod) { return handler(srv, ss) } + grpcStream, err := a.StreamRequestBefore(ss.Context(), ss, info, srv) if err != nil { return err } + // log.Printf("AccessKeyAuthMiddleware StreamInterceptor") defer func() { err := a.StreamResponseAfter(ss.Context(), ss, info) if err != nil { @@ -146,3 +156,34 @@ func (a *AccessKeyAuthMiddleware) Priority() int { func (a *AccessKeyAuthMiddleware) Name() string { return a.name } +func (a *AccessKeyAuthMiddleware) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if !IfNeedValidate(ctx, method) { + return streamer(ctx, desc, cc, method, opts...) + } + + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + return nil, status.Errorf(codes.Unauthenticated, "metadata not exists in context") + } + if AK := md.Get(gosdk.HeaderXAccessKey); len(AK) > 0 { + accessKey := AK[0] + identity, err := a.app.GetAppOwner(ctx, accessKey) + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, "get app owner error,%v", err) + } + md.Set(gosdk.HeaderXIdentity, identity) + ctx = metadata.NewOutgoingContext(ctx, md) + in, ok := metadata.FromIncomingContext(ctx) + if !ok { + in = metadata.MD{} + } + in.Set(gosdk.HeaderXIdentity, identity) + ctx = metadata.NewIncomingContext(ctx, in) + // ctx = metadata.NewIncomingContext() + } + st, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + return nil, status.Errorf(codes.Internal, "streamer error,%v", err) + } + return st, nil +} diff --git a/internal/middleware/auth/apikey.go b/internal/middleware/auth/apikey.go index 6af3f66..0039a6e 100644 --- a/internal/middleware/auth/apikey.go +++ b/internal/middleware/auth/apikey.go @@ -3,7 +3,6 @@ package auth import ( "context" "fmt" - "strings" "github.com/begonia-org/begonia/internal/biz" "github.com/begonia-org/begonia/internal/pkg" @@ -48,7 +47,7 @@ func (a *ApiKeyAuthImpl) UnaryInterceptor(ctx context.Context, req any, info *gr if err != nil { return nil, gosdk.NewError(fmt.Errorf("query uid base on apikey get error:%w", err), int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") } - md,_:=metadata.FromIncomingContext(ctx) + md, _ := metadata.FromIncomingContext(ctx) md = metadata.Join(md, metadata.Pairs(gosdk.HeaderXIdentity, identity)) ctx = metadata.NewIncomingContext(ctx, md) @@ -66,22 +65,31 @@ func NewApiKeyAuth(config *config.Config, authz *biz.AuthzUsecase) ApiKeyAuth { } func (a *ApiKeyAuthImpl) check(ctx context.Context) (string, error) { md, ok := metadata.FromIncomingContext(ctx) - if !ok { + out, outOK := metadata.FromOutgoingContext(ctx) + if !ok && !outOK { return "", gosdk.NewError(status.Errorf(codes.Unauthenticated, "metadata not exists in context"), int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") } // authorization := a.GetAuthorizationFromMetadata(md) apikeys := md.Get(gosdk.HeaderXApiKey) - if len(apikeys) == 0 { + outAPIKeys := out.Get(gosdk.HeaderXApiKey) + if len(apikeys) == 0 && len(outAPIKeys) == 0 { return "", gosdk.NewError(status.Errorf(codes.Unauthenticated, "apikey not exists in context"), int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") } - apikey := apikeys[0] + + apikey := "" + if len(apikeys) != 0 { + apikey = apikeys[0] + } else if len(outAPIKeys) != 0 { + apikey = outAPIKeys[0] + } + if apikey != a.config.GetAdminAPIKey() { return "", gosdk.NewError(pkg.ErrAPIKeyNotMatch, int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") } return apikey, nil } -func (a *ApiKeyAuthImpl) ValidateStream(ctx context.Context, req interface{}, fullName string, headers Header) (context.Context, error) { +func (a *ApiKeyAuthImpl) ValidateStream(ctx context.Context, req interface{}, fullName string) (context.Context, error) { apikey := "" var err error if apikey, err = a.check(ctx); err == nil && apikey != "" { @@ -89,11 +97,11 @@ func (a *ApiKeyAuthImpl) ValidateStream(ctx context.Context, req interface{}, fu if err != nil { return ctx, gosdk.NewError(fmt.Errorf("query user id base on apikey err:%w", err), int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") } - md,_:=metadata.FromIncomingContext(ctx) + md, _ := metadata.FromIncomingContext(ctx) md = metadata.Join(md, metadata.Pairs(gosdk.HeaderXIdentity, identity)) - headers.Set(strings.ToLower(gosdk.HeaderXIdentity), identity) - return metadata.NewIncomingContext(ctx,md), err + // headers.Set(strings.ToLower(gosdk.HeaderXIdentity), identity) + return metadata.NewIncomingContext(ctx, md), err } return ctx, err } @@ -102,7 +110,46 @@ func (a *ApiKeyAuthImpl) StreamInterceptor(srv interface{}, ss grpc.ServerStream if !IfNeedValidate(ss.Context(), info.FullMethod) { return handler(srv, ss) } - grpcStream := NewGrpcStream(ss, info.FullMethod, ss.Context(), a) + ctx := ss.Context() + if apikey, err := a.check(ctx); err == nil && apikey != "" { + identity, err := a.authz.GetIdentity(ctx, gosdk.ApiKeyType, apikey) + if err != nil { + return gosdk.NewError(fmt.Errorf("query user id base on apikey err:%w", err), int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") + } + md, _ := metadata.FromIncomingContext(ctx) + md.Set(gosdk.HeaderXIdentity, identity) + ctx = metadata.NewIncomingContext(ctx, md) + ctx = metadata.AppendToOutgoingContext(ctx, gosdk.HeaderXIdentity, identity) + } else { + return gosdk.NewError(pkg.ErrAPIKeyNotMatch, int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") + + } + grpcStream := NewGrpcStream(ss, info.FullMethod, ctx, a) defer grpcStream.Release() - return handler(srv, grpcStream) + err := handler(srv, grpcStream) + + return err +} +func (a *ApiKeyAuthImpl) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if !IfNeedValidate(ctx, method) { + return streamer(ctx, desc, cc, method, opts...) + } + if apikey, err := a.check(ctx); err == nil && apikey != "" { + identity, err := a.authz.GetIdentity(ctx, gosdk.ApiKeyType, apikey) + if err != nil { + return nil, gosdk.NewError(fmt.Errorf("query user id base on apikey err:%w", err), int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") + } + md, _ := metadata.FromOutgoingContext(ctx) + md = metadata.Join(md, metadata.Pairs(gosdk.HeaderXIdentity, identity)) + ctx = metadata.NewOutgoingContext(ctx, md) + in, ok := metadata.FromIncomingContext(ctx) + if !ok { + in = metadata.New(make(map[string]string)) + } + in.Set(gosdk.HeaderXIdentity, identity) + // log.Printf("incoming identity:%s", identity) + ctx = metadata.NewIncomingContext(ctx, in) + return streamer(ctx, desc, cc, method, opts...) + } + return nil, gosdk.NewError(pkg.ErrAPIKeyNotMatch, int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") } diff --git a/internal/middleware/auth/apikey_test.go b/internal/middleware/auth/apikey_test.go index 46312ff..8c23a44 100644 --- a/internal/middleware/auth/apikey_test.go +++ b/internal/middleware/auth/apikey_test.go @@ -17,7 +17,6 @@ import ( "github.com/begonia-org/begonia/internal/pkg" cfg "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/crypto" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" hello "github.com/begonia-org/go-sdk/api/example/v1" c "github.com/smartystreets/goconvey/convey" @@ -39,7 +38,7 @@ func TestAPIKeyUnaryInterceptor(t *testing.T) { c.So(apikey.Name(), c.ShouldEqual, "api_key_auth") c.So(apikey.Priority(), c.ShouldEqual, 1) - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -101,6 +100,18 @@ func TestAPIKeyUnaryInterceptor(t *testing.T) { patch.Reset() c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "get user id base on apikey error") + // ValidateStream error + patch2 := gomonkey.ApplyFuncReturn((*biz.AuthzUsecase).GetIdentity, "", fmt.Errorf("get user id base on apikey error")) + defer patch2.Reset() + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) + _, err = apikey.(*auth.ApiKeyAuthImpl).ValidateStream(ctx, nil, "/integration.TestService/Get") + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "get user id base on apikey error") + patch2.Reset() + ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, "cnf.GetAdminAPIKey()")) + _, err = apikey.(*auth.ApiKeyAuthImpl).ValidateStream(ctx, nil, "/integration.TestService/Get") + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrAPIKeyNotMatch.Error()) }) } func newAuthzBiz() *biz.AuthzUsecase { @@ -131,7 +142,7 @@ func TestApiKeyStreamInterceptor(t *testing.T) { c.So(apikey.Name(), c.ShouldEqual, "api_key_auth") c.So(apikey.Priority(), c.ShouldEqual, 1) - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -155,12 +166,13 @@ func TestApiKeyStreamInterceptor(t *testing.T) { ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())), }}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { - err:=ss.RecvMsg(srv) md, _ := metadata.FromIncomingContext(ss.Context()) if len(md.Get(gosdk.HeaderXIdentity)) == 0 || md.Get(gosdk.HeaderXIdentity)[0] == "" { t.Error("identity not exists in context") return fmt.Errorf("identity not exists in context") } + err := ss.RecvMsg(srv) + return err }) diff --git a/internal/middleware/auth/auth.go b/internal/middleware/auth/auth.go index acc21cd..9a0f8e3 100644 --- a/internal/middleware/auth/auth.go +++ b/internal/middleware/auth/auth.go @@ -31,6 +31,7 @@ func NewAuth(ak *AccessKeyAuthMiddleware, jwt *JWTAuth, apikey ApiKeyAuth) gosdk } func (a *Auth) UnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + // fmt.Print("auth unary interceptor \n") if !IfNeedValidate(ctx, info.FullMethod) { return handler(ctx, req) } @@ -57,13 +58,14 @@ func (a *Auth) UnaryInterceptor(ctx context.Context, req any, info *grpc.UnarySe func (a *Auth) StreamInterceptor(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { if !IfNeedValidate(ss.Context(), info.FullMethod) { + // log.Printf("no need stream validate %s", info.FullMethod) return handler(srv, ss) } md, ok := metadata.FromIncomingContext(ss.Context()) if !ok { return status.Errorf(codes.Unauthenticated, "metadata not exists in context") } - xApiKey := md.Get("x-api-key") + xApiKey := md.Get(gosdk.HeaderXApiKey) if len(xApiKey) != 0 { return a.apikey.StreamInterceptor(srv, ss, info, handler) } @@ -89,3 +91,28 @@ func (a *Auth) Priority() int { func (a *Auth) Name() string { return a.name } + +func (a *Auth) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if !IfNeedValidate(ctx, method) { + // log.Printf("no need stream validate %s", info.FullMethod) + return streamer(ctx, desc, cc, method, opts...) + } + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + return nil, status.Errorf(codes.Unauthenticated, "metadata not exists in context") + } + xApiKey := md.Get("x-api-key") + if len(xApiKey) != 0 { + return a.apikey.StreamClientInterceptor(ctx, desc, cc, method, streamer, opts...) + } + authorization := a.jwt.GetAuthorizationFromMetadata(md) + + if authorization == "" { + return nil, gosdk.NewError(pkg.ErrTokenMissing, int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") + } + if strings.Contains(authorization, "Bearer") { + return a.jwt.StreamClientInterceptor(ctx, desc, cc, method, streamer, opts...) + + } + return a.ak.StreamClientInterceptor(ctx, desc, cc, method, streamer, opts...) +} diff --git a/internal/middleware/auth/auth_test.go b/internal/middleware/auth/auth_test.go index 1abb2ae..69eb004 100644 --- a/internal/middleware/auth/auth_test.go +++ b/internal/middleware/auth/auth_test.go @@ -7,10 +7,12 @@ import ( "crypto/cipher" "crypto/rand" "crypto/rsa" + "crypto/sha256" "encoding/base64" "encoding/hex" "encoding/json" "fmt" + "io" "log" "net/http" "os" @@ -29,7 +31,6 @@ import ( "github.com/begonia-org/begonia/internal/pkg" cfg "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/crypto" - "github.com/begonia-org/begonia/internal/pkg/routers" "github.com/begonia-org/begonia/internal/pkg/utils" gosdk "github.com/begonia-org/go-sdk" api "github.com/begonia-org/go-sdk/api/app/v1" @@ -44,6 +45,21 @@ import ( type testStream struct { ctx context.Context } +type testClientStream struct { + ctx context.Context + grpc.ClientStream +} + +func (t *testClientStream) Context() context.Context { + return t.ctx +} +func (t *testClientStream) SendMsg(m interface{}) error { + return nil +} +func (t *testClientStream) RecvMsg(m interface{}) error { + return nil + +} func (t *testStream) SetHeader(metadata.MD) error { return nil @@ -124,8 +140,8 @@ func getJWT() string { user := data.NewUserRepo(config, gateway.Log) userAuth := crypto.NewUsersAuth(cnf) authzRepo := data.NewAuthzRepo(config, gateway.Log) - appRepo:=data.NewAppRepo(config,gateway.Log) - authz := biz.NewAuthzUsecase(authzRepo, user,appRepo, gateway.Log, userAuth, cnf) + appRepo := data.NewAppRepo(config, gateway.Log) + authz := biz.NewAuthzUsecase(authzRepo, user, appRepo, gateway.Log, userAuth, cnf) adminUser := cnf.GetDefaultAdminName() adminPasswd := cnf.GetDefaultAdminPasswd() _, filename, _, _ := runtime.Caller(0) @@ -217,12 +233,12 @@ func getMid() gosdk.LocalPlugin { user := data.NewUserRepo(config, gateway.Log) userAuth := crypto.NewUsersAuth(cnf) authzRepo := data.NewAuthzRepo(config, gateway.Log) - appRepo:=data.NewAppRepo(config,gateway.Log) - authz := biz.NewAuthzUsecase(authzRepo, user,appRepo, gateway.Log, userAuth, cnf) + appRepo := data.NewAppRepo(config, gateway.Log) + authz := biz.NewAuthzUsecase(authzRepo, user, appRepo, gateway.Log, userAuth, cnf) jwt := auth.NewJWTAuth(cnf, tiga.NewRedisDao(config), authz, gateway.Log) ak := auth.NewAccessKeyAuth(akBiz, cnf, gateway.Log) - apiKey := auth.NewApiKeyAuth(cnf,authz) + apiKey := auth.NewApiKeyAuth(cnf, authz) mid := auth.NewAuth(ak, jwt, apiKey) return mid @@ -238,22 +254,22 @@ func TestUnaryInterceptor(t *testing.T) { mid := getMid() ctx := context.Background() handler := func(ctx context.Context, req interface{}) (interface{}, error) { - md,ok:=metadata.FromIncomingContext(ctx) - if !ok{ - return nil,fmt.Errorf("no metadata") + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, fmt.Errorf("no metadata") } - if identify:=md.Get(gosdk.HeaderXIdentity);len(identify)==0||identify[0]==""{ - return nil,fmt.Errorf("no app identity") + if identify := md.Get(gosdk.HeaderXIdentity); len(identify) == 0 || identify[0] == "" { + return nil, fmt.Errorf("no app identity") } - XAccessKey:=md.Get(gosdk.HeaderXAccessKey) - XApiKey:=md.Get(gosdk.HeaderXApiKey) - XAuthz:=md.Get("authorization") - if len(XAccessKey)==0 && len(XApiKey)==0 && len(XAuthz)==0{ - return nil,fmt.Errorf("no app auth key") + XAccessKey := md.Get(gosdk.HeaderXAccessKey) + XApiKey := md.Get(gosdk.HeaderXApiKey) + XAuthz := md.Get("authorization") + if len(XAccessKey) == 0 && len(XApiKey) == 0 && len(XAuthz) == 0 { + return nil, fmt.Errorf("no app auth key") } return nil, nil } - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -288,8 +304,11 @@ func TestUnaryInterceptor(t *testing.T) { c.Convey("TestUnaryInterceptor aksk", t, func() { u := &v1.Users{} bData, _ := json.Marshal(u) + hash := sha256.Sum256(bData) + hexStr := fmt.Sprintf("%x", hash) req, _ := http.NewRequest(http.MethodPost, "/test/post", bytes.NewReader(bData)) req.Header.Set("Content-Type", "application/json") + req.Header.Set(gosdk.HeaderXContentSha256, hexStr) access, secret, appid := readInitAPP() sgin := gosdk.NewAppAuthSigner(access, secret) gwReq, err := gosdk.NewGatewayRequestFromHttp(req) @@ -342,7 +361,7 @@ func TestStreamInterceptor(t *testing.T) { mid := getMid() ctx := context.Background() - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -352,15 +371,18 @@ func TestStreamInterceptor(t *testing.T) { ctx = metadata.NewIncomingContext(ctx, metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { - err:= ss.RecvMsg(srv) - md,_:=metadata.FromIncomingContext(ss.Context()) - if identify:=md.Get(gosdk.HeaderXIdentity);len(identify)==0||identify[0]==""{ + err := ss.RecvMsg(srv) + if err != nil { + return err + } + md, _ := metadata.FromIncomingContext(ss.Context()) + if identify := md.Get(gosdk.HeaderXIdentity); len(identify) == 0 || identify[0] == "" { return fmt.Errorf("no app identity") } - if xAppKey:=md.Get(gosdk.HeaderXApiKey);len(xAppKey)==0||xAppKey[0]==""{ + if xAppKey := md.Get(gosdk.HeaderXApiKey); len(xAppKey) == 0 || xAppKey[0] == "" { return fmt.Errorf("no app key") } - return err + return ss.RecvMsg(srv) }) c.So(err, c.ShouldBeNil) ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, "cnf.GetAdminAPIKey()")) @@ -368,18 +390,39 @@ func TestStreamInterceptor(t *testing.T) { return ss.RecvMsg(srv) }) c.So(err, c.ShouldNotBeNil) + // get identity err + patch := gomonkey.ApplyFuncReturn((*biz.AuthzUsecase).GetIdentity, "", fmt.Errorf("get identity err")) + defer patch.Reset() + ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) + + err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { + err := ss.RecvMsg(srv) + return err + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "query user id base on apikey") + patch.Reset() + + // check apikey not match + ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, "cnf.GetAdminAPIKey()")) + err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { + err := ss.RecvMsg(srv) + return err + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrAPIKeyNotMatch.Error()) }) c.Convey("TestStreamInterceptor jwt", t, func() { jwt := getJWT() ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorization", "Bearer "+jwt)) err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { - err:= ss.RecvMsg(srv) - md,_:=metadata.FromIncomingContext(ss.Context()) - if identify:=md.Get(gosdk.HeaderXIdentity);len(identify)==0||identify[0]==""{ + err := ss.RecvMsg(srv) + md, _ := metadata.FromIncomingContext(ss.Context()) + if identify := md.Get(gosdk.HeaderXIdentity); len(identify) == 0 || identify[0] == "" { return fmt.Errorf("no app identity") } - if xAuthorization:=md.Get("authorization");len(xAuthorization)==0||xAuthorization[0]==""{ + if xAuthorization := md.Get("authorization"); len(xAuthorization) == 0 || xAuthorization[0] == "" { return fmt.Errorf("no jwt key") } return err @@ -390,13 +433,26 @@ func TestStreamInterceptor(t *testing.T) { return ss.RecvMsg(srv) }) c.So(err, c.ShouldNotBeNil) + + ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorization", "")) + err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { + return ss.RecvMsg(srv) + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrTokenMissing.Error()) }) - c.Convey("TestUnaryInterceptor aksk", t, func() { + c.Convey("TestStreamInterceptor aksk", t, func() { u := &v1.Users{} + bData, _ := json.Marshal(u) + hash := sha256.Sum256(bData) + hexStr := fmt.Sprintf("%x", hash) + + t.Logf("data:%s", string(bData)) req, _ := http.NewRequest(http.MethodPost, "/test/post", bytes.NewReader(bData)) req.Header.Set("Content-Type", "application/json") + req.Header.Set(gosdk.HeaderXContentSha256, hexStr) access, secret, appid := readInitAPP() sgin := gosdk.NewAppAuthSigner(access, secret) gwReq, err := gosdk.NewGatewayRequestFromHttp(req) @@ -417,22 +473,20 @@ func TestStreamInterceptor(t *testing.T) { defer patch.Reset() ctx := metadata.NewIncomingContext(context.Background(), md) err = mid.StreamInterceptor(u, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/POST"}, func(srv interface{}, ss grpc.ServerStream) error { - err:= ss.RecvMsg(srv) - md,ok:=metadata.FromIncomingContext(ss.Context()) - if !ok{ + err := ss.RecvMsg(srv) + md, ok := metadata.FromIncomingContext(ss.Context()) + if !ok { return fmt.Errorf("no metadata") } - if identify:=md.Get(gosdk.HeaderXIdentity);len(identify)==0||identify[0]==""{ + if identify := md.Get(gosdk.HeaderXIdentity); len(identify) == 0 || identify[0] == "" { return fmt.Errorf("no app identity") } - if xAccessKey:=md.Get(gosdk.HeaderXAccessKey);len(xAccessKey)==0||xAccessKey[0]==""{ - t.Logf("error metadata:%v",md) + if xAccessKey := md.Get(gosdk.HeaderXAccessKey); len(xAccessKey) == 0 || xAccessKey[0] == "" { return fmt.Errorf("no app access key") } return err }) c.So(err, c.ShouldBeNil) - sign1 := gosdk.NewAppAuthSigner("ASDASDCASDFQ", "ASDASDCASDFQ") gwReq1, err := gosdk.NewGatewayRequestFromHttp(req) c.So(err, c.ShouldBeNil) @@ -452,13 +506,187 @@ func TestStreamInterceptor(t *testing.T) { }) c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrAppSignatureInvalid.Error()) + + // recv msg err + ctx = metadata.NewIncomingContext(context.Background(), md) + patch4 := gomonkey.ApplyFuncReturn((*testStream).RecvMsg, fmt.Errorf("recv msg err")) + + defer patch4.Reset() + err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/POST"}, func(srv interface{}, ss grpc.ServerStream) error { + err := ss.RecvMsg(srv) + return err + }) + patch4.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "recv msg err") + // recv io.EOF + patch5 := gomonkey.ApplyFuncReturn((*testStream).RecvMsg, io.EOF) + defer patch5.Reset() + err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/POST"}, func(srv interface{}, ss grpc.ServerStream) error { + err := ss.RecvMsg(srv) + return err + }) + patch5.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, io.EOF.Error()) + + }) +} +func TestStreamClientInterceptor(t *testing.T) { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + config := config.ReadConfig(env) + cnf := cfg.NewConfig(config) + mid := getMid() + ctx := context.Background() + + R := gateway.GetRouter() + _, filename, _, _ := runtime.Caller(0) + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") + + pd, _ := gateway.NewDescription(pbFile) + R.LoadAllRouters(pd) + c.Convey("TestStreamClientInterceptor apikey", t, func() { + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) + st, err := mid.StreamClientInterceptor(ctx, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + err = st.SendMsg(&v1.Users{}) + c.So(err, c.ShouldBeNil) + + // no need validate + st, err = mid.StreamClientInterceptor(ctx, nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }, + ) + c.So(err, c.ShouldBeNil) + out, ok := metadata.FromOutgoingContext(st.Context()) + c.So(ok, c.ShouldBeTrue) + c.So(out.Get(gosdk.HeaderXIdentity), c.ShouldBeEmpty) + // no outgoing context + _, err = mid.StreamClientInterceptor(context.Background(), nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "metadata not exists in context") + // do not need validate + _, err = mid.StreamClientInterceptor(ctx, nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + // not match apikey + patch := gomonkey.ApplyFuncReturn((*cfg.Config).GetAdminAPIKey, "") + defer patch.Reset() + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) + _, err = mid.StreamClientInterceptor(ctx, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrAPIKeyNotMatch.Error()) + patch.Reset() + // get owner error + patch2 := gomonkey.ApplyFuncReturn((*biz.AuthzUsecase).GetIdentity, "", fmt.Errorf("get owner err")) + defer patch2.Reset() + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) + _, err = mid.StreamClientInterceptor(ctx, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + patch2.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "query user id base on apikey") + + }) + c.Convey("TestStreamClientInterceptor jwt", t, func() { + jwt := getJWT() + outCTX := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorization", "Bearer "+jwt)) + st, err := mid.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + out, ok := metadata.FromOutgoingContext(st.Context()) + c.So(ok, c.ShouldBeTrue) + c.So(out.Get(gosdk.HeaderXIdentity), c.ShouldNotBeEmpty) + }) + c.Convey("TestStreamClientInterceptor aksk", t, func() { + u := &v1.Users{} + + bData, _ := json.Marshal(u) + hash := sha256.Sum256(bData) + hexStr := fmt.Sprintf("%x", hash) + + t.Logf("data:%s", string(bData)) + req, _ := http.NewRequest(http.MethodPost, "/test/post", bytes.NewReader(bData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set(gosdk.HeaderXContentSha256, hexStr) + access, secret, appid := readInitAPP() + sgin := gosdk.NewAppAuthSigner(access, secret) + gwReq, err := gosdk.NewGatewayRequestFromHttp(req) + c.So(err, c.ShouldBeNil) + err = sgin.SignRequest(gwReq) + c.So(err, c.ShouldBeNil) + md := metadata.New(make(map[string]string)) + + headers := gwReq.Headers + for _, k := range headers.Keys() { + // t.Logf("header:%s,value:%s", k, headers.Get(k)) + md.Append(k, headers.Get(k)) + } + md.Append("uri", "/test/post") + md.Append("x-http-method", http.MethodPost) + patch := gomonkey.ApplyFuncReturn((*biz.AccessKeyAuth).GetSecret, secret, nil) + patch = patch.ApplyFuncReturn((*biz.AccessKeyAuth).GetAppOwner, appid, nil) + defer patch.Reset() + outCTX := metadata.NewOutgoingContext(context.Background(), md) + st, err := mid.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/POST", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + out, ok := metadata.FromOutgoingContext(st.Context()) + c.So(ok, c.ShouldBeTrue) + c.So(out.Get(gosdk.HeaderXIdentity), c.ShouldNotBeEmpty) + // no need validate + st, err = mid.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + out, ok = metadata.FromOutgoingContext(st.Context()) + c.So(ok, c.ShouldBeTrue) + c.So(out.Get(gosdk.HeaderXIdentity), c.ShouldBeEmpty) + // no context + _, err = mid.StreamClientInterceptor(req.Context(), nil, nil, "/INTEGRATION.TESTSERVICE/POST", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "metadata not exists in context") + + // streamer err + _, err = mid.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return nil, fmt.Errorf("streamer err") + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "streamer err") + + // get owner err + + patch2 := gomonkey.ApplyFuncReturn((*biz.AccessKeyAuth).GetAppOwner, "", fmt.Errorf("get owner err")) + defer patch2.Reset() + _, err = mid.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/POST", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "get owner err") + }) } func TestTestUnaryInterceptorErr(t *testing.T) { mid := getMid() ctx := context.Background() - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -485,7 +713,7 @@ func TestTestUnaryInterceptorErr(t *testing.T) { func TestStreamInterceptorErr(t *testing.T) { mid := getMid() - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") diff --git a/internal/middleware/auth/headers.go b/internal/middleware/auth/headers.go index b47c0f1..7a44932 100644 --- a/internal/middleware/auth/headers.go +++ b/internal/middleware/auth/headers.go @@ -1,119 +1,134 @@ package auth -import ( - "context" - "net/http" - "sync" - - "google.golang.org/grpc" - "google.golang.org/grpc/metadata" -) - type Header interface { Set(key, value string) SendHeader(key, value string) } -type GrpcHeader struct { - in metadata.MD - ctx context.Context - out metadata.MD -} -type httpHeader struct { - w http.ResponseWriter - r *http.Request -} -type GrpcStreamHeader struct { - *GrpcHeader - ss grpc.ServerStream -} -var headerPool = &sync.Pool{ - New: func() interface{} { - return &GrpcHeader{} - }, -} -var httpHeaderPool = &sync.Pool{ - New: func() interface{} { - return &httpHeader{} - }, -} +// type GrpcHeader struct { +// in metadata.MD +// ctx context.Context +// out metadata.MD +// } +// type httpHeader struct { +// w http.ResponseWriter +// r *http.Request +// } +// type GrpcStreamHeader struct { +// *GrpcHeader +// ss grpc.ServerStream +// } -var grpcStreamHeaderPool = &sync.Pool{ - New: func() interface{} { - return &GrpcStreamHeader{} - }, -} +// var headerPool = &sync.Pool{ +// New: func() interface{} { +// return &GrpcHeader{} +// }, +// } +// var httpHeaderPool = &sync.Pool{ +// New: func() interface{} { +// return &httpHeader{} +// }, +// } -func (g *GrpcHeader) Release() { - g.ctx = nil - g.in = nil - g.out = nil - headerPool.Put(g) -} -func (g *GrpcHeader) Set(key, value string) { - g.in.Set(key, value) - md, _ := metadata.FromIncomingContext(g.ctx) - newMd := metadata.Join(md, g.in) - g.ctx = metadata.NewIncomingContext(g.ctx, newMd) +// var grpcStreamHeaderPool = &sync.Pool{ +// New: func() interface{} { +// return &GrpcStreamHeader{} +// }, +// } -} -func (g *GrpcHeader) SendHeader(key, value string) { - g.out.Append(key, value) - _ = grpc.SendHeader(g.ctx, g.out) - g.ctx = metadata.NewOutgoingContext(g.ctx, g.out) -} -func (g *httpHeader) Release() { - g.w = nil - g.r = nil - httpHeaderPool.Put(g) -} -func (g *httpHeader) Set(key, value string) { - g.r.Header.Add(key, value) +// func (g *GrpcHeader) Release() { +// g.ctx = nil +// g.in = nil +// g.out = nil +// headerPool.Put(g) +// } +// func (g *GrpcHeader) Set(key, value string) { +// g.in.Set(key, value) +// md, _ := metadata.FromIncomingContext(g.ctx) +// newMd := metadata.Join(md, g.in) +// g.ctx = metadata.NewIncomingContext(g.ctx, newMd) +// log.Printf("grpc header set key:%v value:%v", key, value) -} -func (g *httpHeader) SendHeader(key, value string) { - g.w.Header().Add(key, value) -} -func (g *GrpcStreamHeader) Release() { - g.ctx = nil - g.in = nil - g.out = nil - g.ss = nil - grpcStreamHeaderPool.Put(g) +// } +// func (g *GrpcHeader) SendHeader(key, value string) { +// g.out.Append(key, value) +// _ = grpc.SendHeader(g.ctx, g.out) +// g.ctx = metadata.NewOutgoingContext(g.ctx, g.out) +// } +// func (g *httpHeader) Release() { +// g.w = nil +// g.r = nil +// httpHeaderPool.Put(g) +// } +// func (g *httpHeader) Set(key, value string) { +// g.r.Header.Add(key, value) -} -func (g *GrpcStreamHeader) Set(key, value string) { - g.in.Append(key, value) - newCtx := metadata.NewIncomingContext(g.ctx, g.in) - g.ctx = newCtx - _ = g.ss.SetHeader(g.in) -} -func (g *GrpcStreamHeader) SendHeader(key, value string) { - g.out.Append(key, value) - _ = g.ss.SendHeader(g.out) - g.ctx = metadata.NewOutgoingContext(g.ctx, g.out) -} +// } +// func (g *httpHeader) SendHeader(key, value string) { +// g.w.Header().Add(key, value) +// } -func NewGrpcHeader(in metadata.MD, ctx context.Context, out metadata.MD) *GrpcHeader { - // return &GrpcHeader{in: in, ctx: ctx, out: out} - header := headerPool.Get().(*GrpcHeader) - header.in = in - header.ctx = ctx - header.out = out - return header -} -func NewHttpHeader(w http.ResponseWriter, r *http.Request) *httpHeader { - // return &httpHeader{w: w, r: r} - header := httpHeaderPool.Get().(*httpHeader) - header.w = w - header.r = r - return header -} +// func (g *GrpcStreamHeader) Release() { +// g.ctx = nil +// g.in = nil +// g.out = nil +// g.ss = nil +// grpcStreamHeaderPool.Put(g) -func NewGrpcStreamHeader(in metadata.MD, ctx context.Context, out metadata.MD, ss grpc.ServerStream) *GrpcStreamHeader { - // return &GrpcStreamHeader{&GrpcHeader{in: in, ctx: ctx, out: out}, ss} - header := grpcStreamHeaderPool.Get().(*GrpcStreamHeader) - header.GrpcHeader = NewGrpcHeader(in, ctx, out) - header.ss = ss - return header -} +// } +// func (g *GrpcStreamHeader) Set(key, value string) { +// g.in.Set(key, value) +// newCtx := metadata.NewIncomingContext(g.ctx, g.in) +// g.ctx = newCtx +// _ = g.ss.SetHeader(g.in) +// } +// func (g *GrpcStreamHeader) SendHeader(key, value string) { +// g.out.Append(key, value) +// _ = g.ss.SendHeader(g.out) +// g.ctx = metadata.NewOutgoingContext(g.ctx, g.out) +// } + +// func (g *GrpcClientStreamHeader) Release() { +// g.ctx = nil +// g.in = nil +// g.out = nil +// g.cs = nil +// grpcClientHeaderPool.Put(g) + +// } +// func (g *GrpcClientStreamHeader) Set(key, value string) { +// g.out.Set(key, value) +// newCtx := metadata.NewOutgoingContext(g.ctx, g.in) +// g.ctx = newCtx +// // _ = g.cs.(g.in) + +// } +// func NewGrpcHeader(in metadata.MD, ctx context.Context, out metadata.MD) *GrpcHeader { +// // return &GrpcHeader{in: in, ctx: ctx, out: out} +// header := headerPool.Get().(*GrpcHeader) +// header.in = in +// header.ctx = ctx +// header.out = out +// return header +// } +// func NewHttpHeader(w http.ResponseWriter, r *http.Request) *httpHeader { +// // return &httpHeader{w: w, r: r} +// header := httpHeaderPool.Get().(*httpHeader) +// header.w = w +// header.r = r +// return header +// } + +// func NewGrpcStreamHeader(in metadata.MD, ctx context.Context, out metadata.MD, ss grpc.ServerStream) *GrpcStreamHeader { +// // return &GrpcStreamHeader{&GrpcHeader{in: in, ctx: ctx, out: out}, ss} +// header := grpcStreamHeaderPool.Get().(*GrpcStreamHeader) +// header.GrpcHeader = NewGrpcHeader(in, ctx, out) +// header.ss = ss +// return header +// } +// func NewGrpcClientStreamHeader(in metadata.MD, ctx context.Context, out metadata.MD, cs grpc.ClientStream) *GrpcClientStreamHeader { +// header := grpcClientHeaderPool.Get().(*GrpcClientStreamHeader) +// header.GrpcHeader = NewGrpcHeader(in, ctx, out) +// header.cs = cs +// return header +// } diff --git a/internal/middleware/auth/headers_test.go b/internal/middleware/auth/headers_test.go index 8344ba2..27cdf96 100644 --- a/internal/middleware/auth/headers_test.go +++ b/internal/middleware/auth/headers_test.go @@ -1,32 +1,32 @@ package auth_test -import ( - "net/http" - "testing" +// import ( +// "net/http" +// "testing" - "github.com/begonia-org/begonia/internal/middleware/auth" - c "github.com/smartystreets/goconvey/convey" -) +// "github.com/begonia-org/begonia/internal/middleware/auth" +// c "github.com/smartystreets/goconvey/convey" +// ) -type responseWriter struct { -} +// type responseWriter struct { +// } -func (r *responseWriter) Header() http.Header { - return make(http.Header) -} -func (r *responseWriter) Write([]byte) (int, error) { - return 0, nil -} -func (r *responseWriter) WriteHeader(int) { +// func (r *responseWriter) Header() http.Header { +// return make(http.Header) +// } +// func (r *responseWriter) Write([]byte) (int, error) { +// return 0, nil +// } +// func (r *responseWriter) WriteHeader(int) { -} -func TestHeaders(t *testing.T) { - c.Convey("TestHeaders", t, func() { - req, _ := http.NewRequest("GET", "http://localhost", nil) - h := auth.NewHttpHeader(&responseWriter{}, req) - c.So(h, c.ShouldNotBeNil) - h.Set("key", "value") - h.SendHeader("key", "value") - h.Release() - }) -} +// } +// func TestHeaders(t *testing.T) { +// c.Convey("TestHeaders", t, func() { +// req, _ := http.NewRequest("GET", "http://localhost", nil) +// h := auth.NewHttpHeader(&responseWriter{}, req) +// c.So(h, c.ShouldNotBeNil) +// h.Set("key", "value") +// h.SendHeader("key", "value") +// h.Release() +// }) +// } diff --git a/internal/middleware/auth/jwt.go b/internal/middleware/auth/jwt.go index 00a6b54..c5c35b2 100644 --- a/internal/middleware/auth/jwt.go +++ b/internal/middleware/auth/jwt.go @@ -104,18 +104,21 @@ func (a *JWTAuth) checkJWTItem(ctx context.Context, payload *api.BasicAuth, toke } return true, nil } -func (a *JWTAuth) checkJWT(ctx context.Context, authorization string, rspHeader Header, reqHeader Header) (ok bool, err error) { +func (a *JWTAuth) checkJWT(ctx context.Context, authorization string, io *metadata.MD) (ok bool, err error) { payload, errAuth := a.jwt2BasicAuth(authorization) err = errAuth if err != nil { return false, err } + io.Set("x-uid", payload.Uid) + strArr := strings.Split(authorization, " ") token := strArr[1] ok, err = a.checkJWTItem(ctx, payload, token) if err != nil || !ok { return false, err } + io.Set("x-token", token) left := payload.Expiration - time.Now().Unix() // expiration := a.config.GetJWTExpiration() @@ -156,64 +159,73 @@ func (a *JWTAuth) checkJWT(ctx context.Context, authorization string, rspHeader } // 旧token加入黑名单 go a.biz.PutBlackList(ctx, a.config.GetUserBlackListKey(tiga.GetMd5(token))) - rspHeader.SendHeader("Authorization", fmt.Sprintf("Bearer %s", newToken)) + // rspHeader.Set("Authorization", fmt.Sprintf("Bearer %s", newToken)) + _ = grpc.SetHeader(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", newToken))) token = newToken } - // 设置uid - reqHeader.Set("x-token", token) - reqHeader.Set("x-uid", payload.Uid) - reqHeader.Set(gosdk.HeaderXIdentity, payload.Uid) + io.Set("x-token", token) return true, nil } -func (a *JWTAuth) jwtValidator(ctx context.Context, headers Header) (context.Context, error) { - - - md, _ := metadata.FromIncomingContext(ctx) +func (a *JWTAuth) jwtValidator(ctx context.Context) (context.Context, error) { + in, inOK := metadata.FromIncomingContext(ctx) + if !inOK { + in = metadata.MD{} + } + out, ok := metadata.FromOutgoingContext(ctx) + if !ok { + out = metadata.MD{} + } + md := metadata.Join(in, out) token := a.GetAuthorizationFromMetadata(md) if token == "" { - return nil, status.Errorf(codes.Unauthenticated, "token not exists in context") - } + if token = a.GetAuthorizationFromMetadata(out); token == "" { + return nil, status.Errorf(codes.Unauthenticated, "token not exists in context") + } - ok, err := a.checkJWT(ctx, token, headers, headers) + } + ioMD := metadata.New(make(map[string]string)) + ok, err := a.checkJWT(ctx, token, &ioMD) if err != nil || !ok { return nil, status.Errorf(codes.Unauthenticated, "check token error,%v", err) } + in.Set("x-token", ioMD.Get("x-token")...) + in.Set("x-uid", ioMD.Get("x-uid")...) + in.Set(gosdk.HeaderXIdentity, ioMD.Get("x-uid")...) + newCtx := metadata.NewIncomingContext(ctx, in) + + out.Set("x-token", ioMD.Get("x-token")...) + out.Set("x-uid", ioMD.Get("x-uid")...) + out.Set(gosdk.HeaderXIdentity, ioMD.Get("x-uid")...) - newCtx := metadata.NewIncomingContext(ctx, md) + newCtx = metadata.NewOutgoingContext(newCtx, out) return newCtx, nil // return handler(newCtx, req) } func (a *JWTAuth) RequestBefore(ctx context.Context, info *grpc.UnaryServerInfo, req interface{}) (context.Context, error) { - in, ok := metadata.FromIncomingContext(ctx) - if !ok { - return nil, status.Errorf(codes.Unauthenticated, "metadata not exists in context") - } - out, ok := metadata.FromOutgoingContext(ctx) - if !ok { - out = metadata.MD{} - } - headers := NewGrpcHeader(in, ctx, out) - defer headers.Release() - _, err := a.jwtValidator(ctx, headers) + ctx, err := a.jwtValidator(ctx) if err != nil { return nil, err } - return headers.ctx, nil + return ctx, nil } -func (a *JWTAuth) ValidateStream(ctx context.Context, req interface{}, fullName string, headers Header) (context.Context, error) { +func (a *JWTAuth) ValidateStream(ctx context.Context, req interface{}, fullName string) (context.Context, error) { // headers := NewGrpcStreamHeader(in, ctx, out,ss) - ctx, err := a.jwtValidator(ctx, headers) + ctx, err := a.jwtValidator(ctx) return ctx, err } func (a *JWTAuth) StreamRequestBefore(ctx context.Context, ss grpc.ServerStream, info *grpc.StreamServerInfo, req interface{}) (grpc.ServerStream, error) { - grpcStream := NewGrpcStream(ss, info.FullMethod, ss.Context(), a) + ctx, err := a.jwtValidator(ctx) + if err != nil { + return nil, err + } + grpcStream := NewGrpcStream(ss, info.FullMethod, ctx, a) return grpcStream, nil } @@ -273,3 +285,15 @@ func (jwt *JWTAuth) Priority() int { func (jwt *JWTAuth) Name() string { return jwt.name } + +func (jwt *JWTAuth) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if !IfNeedValidate(ctx, method) { + return streamer(ctx, desc, cc, method, opts...) + } + + ctx, err := jwt.jwtValidator(ctx) + if err != nil { + return nil, err + } + return streamer(ctx, desc, cc, method, opts...) +} diff --git a/internal/middleware/auth/jwt_test.go b/internal/middleware/auth/jwt_test.go index 295030d..a6d5bef 100644 --- a/internal/middleware/auth/jwt_test.go +++ b/internal/middleware/auth/jwt_test.go @@ -19,7 +19,7 @@ import ( "github.com/begonia-org/begonia/internal/pkg" cfg "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/crypto" - "github.com/begonia-org/begonia/internal/pkg/routers" + gosdk "github.com/begonia-org/go-sdk" hello "github.com/begonia-org/go-sdk/api/example/v1" api "github.com/begonia-org/go-sdk/api/user/v1" "github.com/bsm/redislock" @@ -62,7 +62,7 @@ func TestJWTUnaryInterceptor(t *testing.T) { config := config.ReadConfig(env) cnf := cfg.NewConfig(config) - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -71,8 +71,8 @@ func TestJWTUnaryInterceptor(t *testing.T) { user := data.NewUserRepo(config, gateway.Log) userAuth := crypto.NewUsersAuth(cnf) authzRepo := data.NewAuthzRepo(config, gateway.Log) - appRepo:=data.NewAppRepo(config,gateway.Log) - authz := biz.NewAuthzUsecase(authzRepo, user,appRepo, gateway.Log, userAuth, cnf) + appRepo := data.NewAppRepo(config, gateway.Log) + authz := biz.NewAuthzUsecase(authzRepo, user, appRepo, gateway.Log, userAuth, cnf) jwt := auth.NewJWTAuth(cnf, tiga.NewRedisDao(config), authz, gateway.Log) jwt.SetPriority(1) c.So(jwt.Priority(), c.ShouldEqual, 1) @@ -85,15 +85,6 @@ func TestJWTUnaryInterceptor(t *testing.T) { }) c.So(err, c.ShouldBeNil) - _, err = jwt.UnaryInterceptor(context.Background(), &hello.HelloRequest{}, &grpc.UnaryServerInfo{ - FullMethod: "/integration.TestService/Get", - }, func(ctx context.Context, req interface{}) (interface{}, error) { - return nil, nil - - }) - c.So(err, c.ShouldNotBeNil) - c.So(err.Error(), c.ShouldContainSubstring, "metadata not exists in context") - _, err = jwt.UnaryInterceptor(metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test")), &hello.HelloRequest{}, &grpc.UnaryServerInfo{ FullMethod: "/integration.TestService/Get", }, func(ctx context.Context, req interface{}) (interface{}, error) { @@ -289,7 +280,7 @@ func TestJWTStreamInterceptor(t *testing.T) { config := config.ReadConfig(env) cnf := cfg.NewConfig(config) - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -298,8 +289,8 @@ func TestJWTStreamInterceptor(t *testing.T) { user := data.NewUserRepo(config, gateway.Log) userAuth := crypto.NewUsersAuth(cnf) authzRepo := data.NewAuthzRepo(config, gateway.Log) - appRepo:=data.NewAppRepo(config,gateway.Log) - authz := biz.NewAuthzUsecase(authzRepo, user,appRepo, gateway.Log, userAuth, cnf) + appRepo := data.NewAppRepo(config, gateway.Log) + authz := biz.NewAuthzUsecase(authzRepo, user, appRepo, gateway.Log, userAuth, cnf) jwt := auth.NewJWTAuth(cnf, tiga.NewRedisDao(config), authz, gateway.Log) err := jwt.StreamInterceptor(&hello.HelloRequest{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-api-key", cnf.GetAdminAPIKey())), @@ -320,3 +311,40 @@ func TestJWTStreamInterceptor(t *testing.T) { c.So(err, c.ShouldBeNil) }) } + +func TestJWTClientStream(t *testing.T) { + c.Convey("TestJWTClientStream", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + config := config.ReadConfig(env) + cnf := cfg.NewConfig(config) + + R := gateway.GetRouter() + _, filename, _, _ := runtime.Caller(0) + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") + + pd, _ := gateway.NewDescription(pbFile) + R.LoadAllRouters(pd) + user := data.NewUserRepo(config, gateway.Log) + userAuth := crypto.NewUsersAuth(cnf) + authzRepo := data.NewAuthzRepo(config, gateway.Log) + appRepo := data.NewAppRepo(config, gateway.Log) + authz := biz.NewAuthzUsecase(authzRepo, user, appRepo, gateway.Log, userAuth, cnf) + jwt := auth.NewJWTAuth(cnf, tiga.NewRedisDao(config), authz, gateway.Log) + jwt.SetPriority(1) + outCTX := metadata.NewOutgoingContext(context.Background(), metadata.Pairs(gosdk.HeaderXAuthorization, cnf.GetAdminAPIKey())) + _, err := jwt.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + + outCTX = metadata.NewOutgoingContext(context.Background(), metadata.Pairs(gosdk.HeaderXAuthorization, cnf.GetAdminAPIKey())) + _, err = jwt.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldNotBeNil) + + }) +} diff --git a/internal/middleware/auth/stream.go b/internal/middleware/auth/stream.go index 2fde49c..e4a89d8 100644 --- a/internal/middleware/auth/stream.go +++ b/internal/middleware/auth/stream.go @@ -2,24 +2,34 @@ package auth import ( "context" + "errors" + "io" "sync" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) type StreamValidator interface { - ValidateStream(ctx context.Context, req interface{}, fullName string, headers Header) (context.Context, error) + ValidateStream(ctx context.Context, req interface{}, fullName string) (context.Context, error) } type grpcServerStream struct { grpc.ServerStream - fullName string - validate StreamValidator - ctx context.Context + fullName string + validate StreamValidator + ctx context.Context + firstFrame bool } +// type grpcClientStream struct { +// grpc.ClientStream +// fullName string +// ctx context.Context +// validate StreamValidator +// firstFrame bool +// } + var streamPool = &sync.Pool{ New: func() interface{} { return &grpcServerStream{ @@ -28,11 +38,17 @@ var streamPool = &sync.Pool{ }, } +// var clientStreamPool = &sync.Pool{ +// New: func() interface{} { +// return &grpcClientStream{} +// }, +// } + func NewGrpcStream(s grpc.ServerStream, fullName string, ctx context.Context, validator StreamValidator) *grpcServerStream { stream := streamPool.Get().(*grpcServerStream) stream.ServerStream = s stream.fullName = fullName - stream.ctx = s.Context() + stream.ctx = ctx stream.validate = validator return stream } @@ -41,29 +57,28 @@ func (g *grpcServerStream) Release() { g.fullName = "" g.ServerStream = nil g.validate = nil + g.firstFrame = false streamPool.Put(g) } func (g *grpcServerStream) Context() context.Context { return g.ctx } func (s *grpcServerStream) RecvMsg(m interface{}) error { - if err := s.ServerStream.RecvMsg(m); err != nil { - return err + var err error + if err = s.ServerStream.RecvMsg(m); err != nil && !errors.Is(err, io.EOF) { + return status.Errorf(codes.Internal, "recv msg err:%s", err.Error()) } - in, ok := metadata.FromIncomingContext(s.Context()) - if !ok { - return status.Errorf(codes.Unauthenticated, "metadata not exists in context") + if err != nil { + return err } - out, ok := metadata.FromOutgoingContext(s.Context()) - if !ok { - out = metadata.MD{} + if !s.firstFrame { + ctx, err := s.validate.ValidateStream(s.Context(), m, s.fullName) + s.ctx = ctx + s.firstFrame = true + return err } - header := NewGrpcStreamHeader(in, s.Context(), out, s.ServerStream) - _, err := s.validate.ValidateStream(s.Context(), m, s.fullName, header) - s.ctx = header.ctx - header.Release() - return err + return nil } diff --git a/internal/middleware/http.go b/internal/middleware/http.go index 8bf0804..a1b4bf6 100644 --- a/internal/middleware/http.go +++ b/internal/middleware/http.go @@ -3,9 +3,8 @@ package middleware import ( "context" "fmt" - "strings" - "github.com/begonia-org/begonia/internal/pkg/routers" + "github.com/begonia-org/begonia/gateway" gosdk "github.com/begonia-org/go-sdk" _ "github.com/begonia-org/go-sdk/api/app/v1" _ "github.com/begonia-org/go-sdk/api/endpoint/v1" @@ -19,7 +18,6 @@ import ( "google.golang.org/genproto/googleapis/api/httpbody" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" @@ -38,28 +36,24 @@ type HttpStream struct { grpc.ServerStream FullMethod string } + type Http struct { priority int name string } func (s *HttpStream) SendMsg(m interface{}) error { - ctx := s.ServerStream.Context() - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return s.ServerStream.SendMsg(m) - } - if protocol, ok := md["grpcgateway-accept"]; ok { - if !strings.EqualFold(protocol[0], "application/json") { + + routersList := gateway.GetRouter() + router := routersList.GetRouteByGrpcMethod(s.FullMethod) + // 对内置服务的http响应进行格式化 + if routersList.IsLocalSrv(s.FullMethod) || (router != nil && router.UseJsonResponse) { + if _, ok := m.(*httpbody.HttpBody); ok { return s.ServerStream.SendMsg(m) } - routersList := routers.Get() - router := routersList.GetRouteByGrpcMethod(s.FullMethod) - // 对内置服务的http响应进行格式化 - if routersList.IsLocalSrv(s.FullMethod) || router.UseJsonResponse { - rsp, _ := grpcToHttpResponse(m, nil) - return s.ServerStream.SendMsg(rsp) - } + rsp, _ := grpcToHttpResponse(m, s.Context().Err()) + err := s.ServerStream.SendMsg(rsp) + return err } return s.ServerStream.SendMsg(m) } @@ -94,7 +88,6 @@ func toStructMessage(msg protoreflect.ProtoMessage) (*structpb.Struct, error) { return structMsg, nil } func grpcToHttpResponse(rsp interface{}, err error) (*common.HttpResponse, error) { - if err != nil { if st, ok := status.FromError(err); ok { details := st.Details() @@ -102,7 +95,6 @@ func grpcToHttpResponse(rsp interface{}, err error) (*common.HttpResponse, error if anyType, ok := detail.(*anypb.Any); ok { var errDetail common.Errors var stErr = anyType.UnmarshalTo(&errDetail) - if stErr == nil { rspCode := int32(errDetail.Code) codesMap := getClientMessageMap() @@ -125,7 +117,13 @@ func grpcToHttpResponse(rsp interface{}, err error) (*common.HttpResponse, error if st.Code() == codes.Unimplemented { code = int32(common.Code_NOT_FOUND) } + if st.Code() == codes.InvalidArgument { + code = int32(common.Code_PARAMS_ERROR) + } + if st.Code() == codes.AlreadyExists { + code = int32(common.Code_CONFLICT) + } return &common.HttpResponse{ Code: code, Message: st.Message(), @@ -155,25 +153,19 @@ func grpcToHttpResponse(rsp interface{}, err error) (*common.HttpResponse, error }, err } func (h *Http) UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return handler(ctx, req) - } - if protocol, ok := md["grpcgateway-accept"]; ok { - if !strings.EqualFold(protocol[0], "application/json") { - return handler(ctx, req) - } - routersList := routers.Get() - router := routersList.GetRouteByGrpcMethod(info.FullMethod) - // 对内置服务的http响应进行格式化 - if routersList.IsLocalSrv(info.FullMethod) || router.UseJsonResponse { - rsp, err := handler(ctx, req) - if _, ok := rsp.(*httpbody.HttpBody); ok { - return rsp, err - } - return grpcToHttpResponse(rsp, err) + + routersList := gateway.GetRouter() + router := routersList.GetRouteByGrpcMethod(info.FullMethod) + // 对内置服务的http响应进行格式化 + if routersList.IsLocalSrv(info.FullMethod) || router.UseJsonResponse { + // ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("content-type", "application/json")) + rsp, err := handler(ctx, req) + if _, ok := rsp.(*httpbody.HttpBody); ok { + return rsp, err } + return grpcToHttpResponse(rsp, err) } + // } return handler(ctx, req) } @@ -183,7 +175,13 @@ func (h *Http) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *gr return handler(srv, stream) } - +func (h *Http) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + ss, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + return nil, status.Errorf(codes.Internal, "create client stream error:%s", err.Error()) + } + return ss, nil +} func NewHttp() *Http { return &Http{name: "http"} } diff --git a/internal/middleware/http_test.go b/internal/middleware/http_test.go index 805d4c0..be75896 100644 --- a/internal/middleware/http_test.go +++ b/internal/middleware/http_test.go @@ -11,7 +11,6 @@ import ( "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/middleware" "github.com/begonia-org/begonia/internal/pkg" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" hello "github.com/begonia-org/go-sdk/api/example/v1" user "github.com/begonia-org/go-sdk/api/user/v1" @@ -70,7 +69,7 @@ func (x *greeterSayHelloWebsocketServer) Context() context.Context { func TestStreamInterceptor(t *testing.T) { c.Convey("test stream interceptor", t, func() { mid := middleware.NewHttp() - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filename))), "testdata") @@ -84,10 +83,37 @@ func TestStreamInterceptor(t *testing.T) { c.So(err, c.ShouldBeNil) }) } +func TestHttpStreamClientInterceptor(t *testing.T) { + c.Convey("test http stream client interceptor", t, func() { + mid := middleware.NewHttp() + R := gateway.GetRouter() + _, filename, _, _ := runtime.Caller(0) + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filename))), "testdata") + + pd, err := gateway.NewDescription(pbFile) + c.So(err, c.ShouldBeNil) + R.LoadAllRouters(pd) + stream, err := mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs("grpcgateway-accept", "application/json")), nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }, + ) + c.So(err, c.ShouldBeNil) + c.So(stream, c.ShouldNotBeNil) + + stream, err = mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs("grpcgateway-accept", "application/json")), nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return nil, fmt.Errorf("new stream err") + }, + ) + c.So(err, c.ShouldNotBeNil) + c.So(stream, c.ShouldBeNil) + + }) + +} func TestUnaryInterceptor(t *testing.T) { c.Convey("test unary interceptor", t, func() { mid := middleware.NewHttp() - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filename))), "testdata") @@ -120,6 +146,16 @@ func TestUnaryInterceptor(t *testing.T) { }) c.So(err, c.ShouldNotBeNil) c.So(req, c.ShouldNotBeNil) + req, err = mid.UnaryInterceptor(ctx, &hello.HelloRequest{}, &grpc.UnaryServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, status.Error(codes.InvalidArgument, "test") + }) + c.So(err, c.ShouldNotBeNil) + c.So(req, c.ShouldNotBeNil) + req, err = mid.UnaryInterceptor(ctx, &hello.HelloRequest{}, &grpc.UnaryServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, status.Error(codes.AlreadyExists, "test") + }) + c.So(err, c.ShouldNotBeNil) + c.So(req, c.ShouldNotBeNil) req, err = mid.UnaryInterceptor(ctx, &hello.HelloRequest{}, &grpc.UnaryServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(ctx context.Context, req interface{}) (interface{}, error) { return nil, fmt.Errorf("test") diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 957c6dc..e0d69f1 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -31,7 +31,7 @@ func New(config *config.Config, ) *PluginsApply { jwt := auth.NewJWTAuth(config, rdb, user, log) ak := auth.NewAccessKeyAuth(authz, config, log) - apiKey := auth.NewApiKeyAuth(config,user) + apiKey := auth.NewApiKeyAuth(config, user) plugins := map[string]gosdk.LocalPlugin{ "onlyJWT": jwt, "onlyAK": ak, @@ -102,3 +102,11 @@ func (p *PluginsApply) StreamInterceptorChains() []grpc.StreamServerInterceptor } return chains } + +func (p *PluginsApply) StreamClientInterceptorChains() []grpc.StreamClientInterceptor { + chains := make([]grpc.StreamClientInterceptor, 0) + for _, plugin := range p.Plugins { + chains = append(chains, plugin.(gosdk.LocalPlugin).StreamClientInterceptor) + } + return chains +} diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index e6ea5a1..5d631a9 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -28,8 +28,8 @@ func TestMiddlewareUnaryInterceptorChains(t *testing.T) { user := data.NewUserRepo(config, gateway.Log) userAuth := crypto.NewUsersAuth(cnf) authzRepo := data.NewAuthzRepo(config, gateway.Log) - appRepo:=data.NewAppRepo(config,gateway.Log) - authz := biz.NewAuthzUsecase(authzRepo, user,appRepo, gateway.Log, userAuth, cnf) + appRepo := data.NewAppRepo(config, gateway.Log) + authz := biz.NewAuthzUsecase(authzRepo, user, appRepo, gateway.Log, userAuth, cnf) repo := data.NewAppRepo(config, gateway.Log) akBiz := biz.NewAccessKeyAuth(repo, cnf, gateway.Log) @@ -37,6 +37,7 @@ func TestMiddlewareUnaryInterceptorChains(t *testing.T) { // mid.SetPriority(1) c.So(len(mid.StreamInterceptorChains()), c.ShouldBeGreaterThanOrEqualTo, 0) c.So(len(mid.UnaryInterceptorChains()), c.ShouldBeGreaterThanOrEqualTo, 0) + c.So(len(mid.StreamClientInterceptorChains()), c.ShouldBeGreaterThanOrEqualTo, 0) plugins := cnf.GetPlugins() plugins["test"] = 1 diff --git a/internal/middleware/protobuf_validate.go b/internal/middleware/protobuf_validate.go new file mode 100644 index 0000000..ae025d8 --- /dev/null +++ b/internal/middleware/protobuf_validate.go @@ -0,0 +1,187 @@ +package middleware + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/go-playground/validator/v10" + "github.com/iancoleman/strcase" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" +) + +type protobufValidator struct { + validate *validator.Validate +} + +// NewProtobufValidate Create a new protobuf validator +// +// The validate parameter is a validator instance that can be used to validate the structure of the protobuf message +func NewProtobufValidate(validate *validator.Validate) *protobufValidator { + return &protobufValidator{validate: validate} +} + +// getValue Get the value of the field +func (p *protobufValidator) getValue(v protoreflect.Value, k protoreflect.Kind, f protoreflect.FieldDescriptor) interface{} { + switch k { + case protoreflect.BoolKind: + return v.Bool() + case protoreflect.StringKind: + return v.String() + case protoreflect.Int32Kind, protoreflect.Int64Kind: + return v.Int() + case protoreflect.Uint32Kind, protoreflect.Uint64Kind: + return v.Uint() + case protoreflect.FloatKind, protoreflect.DoubleKind: + return v.Float() + case protoreflect.MessageKind: + return v.Message().Interface() + case protoreflect.EnumKind: + return f.Enum().Values().ByNumber(v.Enum()).Name() + case protoreflect.BytesKind: + return v.Bytes() + + } + return nil +} + +// getFieldTag Get the tag of the field +// +// The validateTag parameter is the validate tag of the field +// For Enum fields, the oneof tag is added to the validate tag +// Json tag is added to the field tag +func (p *protobufValidator) getFieldTag(field protoreflect.FieldDescriptor, validateTag interface{}) string { + tag := fmt.Sprintf(`json:"%s"`, field.JSONName()) + if validateTag != nil { + validate := validateTag.(string) + tag = fmt.Sprintf(`json:"%s" validate:"%s"`, field.JSONName(), validate) + } + if field.Enum() != nil && !strings.Contains(tag, "oneof") { + oneOfEnum := make([]string, 0) + for i := 0; i < field.Enum().Values().Len(); i++ { + oneOfEnum = append(oneOfEnum, string(field.Enum().Values().Get(i).Name())) + } + if !field.IsList() && !field.IsMap() { + tag = fmt.Sprintf(`json:"%s" validate:"oneof=%s"`, field.JSONName(), strings.Join(oneOfEnum, " ")) + } + } + return tag +} + +// handleFieldValue Handle the field value, convert the protobuf message field to a struct field by recursion +// +// see: `reflect.StructField.Type` +func (p *protobufValidator) handleFieldValue(field protoreflect.FieldDescriptor, fieldValue protoreflect.Value, ext protoreflect.ExtensionType) interface{} { + if field.IsMap() { + // convert map field to struct map field + mapTyp := make(map[string]interface{}) + mapValue := fieldValue.Map() + mapValue.Range(func(key protoreflect.MapKey, value protoreflect.Value) bool { + if field.MapValue().Kind() == protoreflect.MessageKind { + mapTyp[key.String()] = p.protobufToStructType(value.Message().Interface(), ext).Interface() + } else { + mapTyp[key.String()] = p.getValue(value, field.MapValue().Kind(), field.MapValue()) + } + return true + }) + return mapTyp + } + + if field.Kind() == protoreflect.MessageKind { + if field.IsList() { + + list := make([]interface{}, 0) + for j := 0; j < fieldValue.List().Len(); j++ { + list = append(list, p.protobufToStructType(fieldValue.List().Get(j).Message().Interface(), ext).Interface()) + } + return list + } + return p.protobufToStructType(fieldValue.Message().Interface(), ext).Interface() + } + + if field.IsList() { + + list := make([]interface{}, 0) + for j := 0; j < fieldValue.List().Len(); j++ { + list = append(list, p.getValue(fieldValue.List().Get(j), field.Kind(), field)) + } + return list + } + + return p.getValue(fieldValue, field.Kind(), field) +} + +// isValid Determine whether the field value is valid, +// +// If the field is a list or map, the function will return true if the field is valid +func (p *protobufValidator) isValid(value protoreflect.Value, f protoreflect.FieldDescriptor) bool { + if f.IsList() { + return value.List().IsValid() + } + if f.IsMap() { + return value.Map().IsValid() + } + switch f.Kind() { + case protoreflect.MessageKind: + return value.Message().IsValid() + case protoreflect.StringKind: + return value.String() != "" + default: + return value.IsValid() + } +} + +// protobufToStructType Convert the protobuf message to a struct type +// +// see: `reflect.StructField` +func (p *protobufValidator) protobufToStructType(message proto.Message, ext protoreflect.ExtensionType) reflect.Value { + md := message.ProtoReflect().Descriptor() + fieldsValues := make(map[string]reflect.Value) + structFields := make([]reflect.StructField, 0) + + for i := 0; i < md.Fields().Len(); i++ { + field := md.Fields().Get(i) + fieldName := strcase.ToCamel(string(field.Name())) + fieldValue := message.ProtoReflect().Get(field) + value := p.handleFieldValue(field, fieldValue, ext) + validateTag := proto.GetExtension(field.Options(), ext) + tag := p.getFieldTag(field, validateTag) + + structFields = append(structFields, reflect.StructField{ + Name: fieldName, + Type: reflect.TypeOf(value), + Tag: reflect.StructTag(tag), + }) + fieldsValues[fieldName] = reflect.ValueOf(value) + if !p.isValid(fieldValue, field) { + fieldsValues[fieldName] = reflect.Zero(reflect.TypeOf(value)) + } + } + + structType := reflect.StructOf(structFields) + newTypVal := reflect.New(structType) + for k, v := range fieldsValues { + newTypVal.Elem().FieldByName(k).Set(v) + } + return newTypVal +} + +func (p *protobufValidator) Protobuf(message proto.Message, ext protoreflect.ExtensionType) error { + v := p.protobufToStructType(message, ext).Interface() + return p.validate.Struct(v) +} + +func (p *protobufValidator) NewStructFromProtobuf(message proto.Message, ext protoreflect.ExtensionType) interface{} { + return p.protobufToStructType(message, ext).Interface() +} + +func (p *protobufValidator) ProtobufPartial(message proto.Message, ext protoreflect.ExtensionType, fields ...string) error { + v := p.protobufToStructType(message, ext).Interface() + return p.validate.StructPartial(v, fields...) +} +func (p *protobufValidator) ProtobufPartialCtx(ctx context.Context, message proto.Message, ext protoreflect.ExtensionType, fields ...string) error { + v := p.protobufToStructType(message, ext).Interface() + return p.validate.StructPartialCtx(ctx, v, fields...) +} diff --git a/internal/middleware/rpc.go b/internal/middleware/rpc.go index 010c377..13e7df2 100644 --- a/internal/middleware/rpc.go +++ b/internal/middleware/rpc.go @@ -22,22 +22,11 @@ import ( type RPCPluginCaller interface{} -// type rpcPluginCallerImpl struct { -// plugins gosdk.Plugins -// } - -// func NewRPCPluginCaller() RPCPluginCaller { -// return &rpcPluginCallerImpl{ -// plugins: make(gosdk.Plugins, 0), -// } -// } - type pluginImpl struct { priority int name string timeout time.Duration lb lb.LoadBalance - // api.PluginServiceClient } func (p *pluginImpl) SetPriority(priority int) { @@ -54,17 +43,16 @@ func (p *pluginImpl) Name() string { func (p *pluginImpl) UnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { - md = metadata.New(nil) + md = metadata.New(make(map[string]string)) } - rsp, err := p.Apply(ctx, req, info.FullMethod) + rsp, header, err := p.Apply(ctx, req, info.FullMethod) if err != nil { return nil, gosdk.NewError(fmt.Errorf("call plugin error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "call_plugin") } + for k, v := range header { + md[k] = append(md[k], v...) - for k, v := range rsp.Metadata { - md.Append(k, v) } - newRequest := rsp.NewRequest if newRequest != nil { err = newRequest.UnmarshalTo(req.(proto.Message)) @@ -97,15 +85,15 @@ func (p *pluginImpl) getEndpoint(ctx context.Context) (lb.Endpoint, error) { return endpoint, nil } -func (p *pluginImpl) Apply(ctx context.Context, in interface{}, fullMethodName string) (*api.PluginResponse, error) { +func (p *pluginImpl) Apply(ctx context.Context, in interface{}, fullMethodName string) (*api.PluginResponse, metadata.MD, error) { endpoint, err := p.getEndpoint(ctx) if err != nil { - return nil, err + return nil, nil, err } cn, err := endpoint.Get(ctx) if err != nil { - return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_connection") + return nil, nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_connection") } defer endpoint.AfterTransform(ctx, cn.((goloadbalancer.Connection))) conn := cn.(goloadbalancer.Connection).ConnInstance().(*grpc.ClientConn) @@ -113,14 +101,38 @@ func (p *pluginImpl) Apply(ctx context.Context, in interface{}, fullMethodName s plugin := api.NewPluginServiceClient(conn) anyReq, err := anypb.New(in.(proto.Message)) if err != nil { - return nil, gosdk.NewError(fmt.Errorf("new any to plugin error: %w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "new_any") + return nil, nil, gosdk.NewError(fmt.Errorf("new any to plugin error: %w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "new_any") } - return plugin.Apply(ctx, &api.PluginRequest{ + var header, trailer metadata.MD + rsp, err := plugin.Apply(ctx, &api.PluginRequest{ Request: anyReq, FullMethodName: fullMethodName, - }) - // return plugin.Call(ctx, anyReq, opts...) + }, grpc.Header(&header), grpc.Trailer(&trailer)) + if err != nil { + return nil, nil, gosdk.NewError(fmt.Errorf("call plugin error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "call_plugin") + } + return rsp, header, nil +} +func (p *pluginImpl) Metadata(ctx context.Context, in *emptypb.Empty) (metadata.MD, error) { + endpoint, err := p.getEndpoint(ctx) + if err != nil { + return nil, err + } + cn, err := endpoint.Get(ctx) + if err != nil { + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_connection") + } + defer endpoint.AfterTransform(ctx, cn.((goloadbalancer.Connection))) + conn := cn.(goloadbalancer.Connection).ConnInstance().(*grpc.ClientConn) + plugin := api.NewPluginServiceClient(conn) + var header, trailer metadata.MD + _, err = plugin.Metadata(ctx, in, grpc.Header(&header), grpc.Trailer(&trailer)) + if err != nil { + return nil, gosdk.NewError(fmt.Errorf("call plugin metadata error:%w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "metadata") + } + return header, nil + } func (p *pluginImpl) Info(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*api.PluginInfo, error) { endpoint, err := p.getEndpoint(ctx) @@ -137,14 +149,26 @@ func (p *pluginImpl) Info(ctx context.Context, in *emptypb.Empty, opts ...grpc.C return plugin.Info(ctx, in, opts...) } func (p *pluginImpl) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + md, err := p.Metadata(ss.Context(), &emptypb.Empty{}) + if err != nil { + return err + + } + in, _ := metadata.FromIncomingContext(ss.Context()) - grpcStream := NewGrpcPluginStream(ss, info.FullMethod, ss.Context(), p) + ctx := metadata.NewIncomingContext(ss.Context(), metadata.Join(in, md)) + grpcStream := NewGrpcPluginStream(ss, info.FullMethod, ctx, p) if grpcStream != nil { defer grpcStream.Release() } return handler(srv, grpcStream) +} +func (p *pluginImpl) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + + return streamer(ctx, desc, cc, method, opts...) + } func NewPluginImpl(lb lb.LoadBalance, name string, timeout time.Duration) *pluginImpl { return &pluginImpl{ diff --git a/internal/middleware/rpc_test.go b/internal/middleware/rpc_test.go index dd09e60..ed1505d 100644 --- a/internal/middleware/rpc_test.go +++ b/internal/middleware/rpc_test.go @@ -11,6 +11,7 @@ import ( "github.com/begonia-org/begonia/internal/middleware" goloadbalancer "github.com/begonia-org/go-loadbalancer" hello "github.com/begonia-org/go-sdk/api/example/v1" + api "github.com/begonia-org/go-sdk/api/plugin/v1" "github.com/begonia-org/go-sdk/example" c "github.com/smartystreets/goconvey/convey" "google.golang.org/grpc" @@ -20,6 +21,21 @@ import ( "google.golang.org/protobuf/types/known/emptypb" ) +type testClientStream struct { + ctx context.Context + grpc.ClientStream +} + +func (t *testClientStream) Context() context.Context { + return t.ctx +} +func (t *testClientStream) SendMsg(m interface{}) error { + return nil +} +func (t *testClientStream) RecvMsg(m interface{}) error { + return nil + +} func TestPluginUnaryInterceptor(t *testing.T) { c.Convey("test plugin unary interceptor", t, func() { go example.RunPlugins(":9850") @@ -57,19 +73,23 @@ func TestPluginUnaryInterceptor(t *testing.T) { return ss.RecvMsg(srv) }) c.So(err, c.ShouldBeNil) - patch2 := gomonkey.ApplyFuncSeq(metadata.FromIncomingContext, []gomonkey.OutputCell{{ - Values: gomonkey.Params{metadata.New(map[string]string{"test": "test"}), true}, - Times: 2, - }, - { - Values: gomonkey.Params{nil, false}, - }, - }) - defer patch2.Reset() + // patch2 := gomonkey.ApplyFuncSeq(metadata.FromIncomingContext, []gomonkey.OutputCell{{ + // Values: gomonkey.Params{metadata.New(map[string]string{"test": "test"}), true}, + // Times: 4, + // }, + // { + // Values: gomonkey.Params{nil, false}, + // }, + // }) + // defer patch2.Reset() err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test"))}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { - return ss.RecvMsg(srv) + err := ss.RecvMsg(srv) + if err != nil { + t.Logf("recv msg error: %v", err) + } + return err }) - patch2.Reset() + // patch2.Reset() c.So(err, c.ShouldBeNil) }) @@ -142,20 +162,61 @@ func TestPluginUnaryInterceptorErr(t *testing.T) { patch4.Reset() c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "unmarshal to request error") + err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: context.Background()}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { + return ss.RecvMsg(srv) + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "get metadata from context error") + // select err + patch5 := gomonkey.ApplyMethodReturn(lb, "Select", nil, fmt.Errorf("select endpoint error")) + defer patch5.Reset() + err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test"))}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { + return ss.RecvMsg(srv) + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "select endpoint error") + patch5.Reset() + enp, _ := lb.Select("127.0.0.1") + patch6 := gomonkey.ApplyMethodReturn(enp, "Get", nil, fmt.Errorf("get connection error")) + defer patch6.Reset() + err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test"))}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { + return ss.RecvMsg(srv) + }) + patch6.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "get connection error") + // call plugin metadata err + cli := api.NewPluginServiceClient(nil) + patch7 := gomonkey.ApplyMethodReturn(cli, "Metadata", nil, fmt.Errorf("call test plugin error")) + defer patch7.Reset() + err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test"))}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { + return ss.RecvMsg(srv) + }) + patch7.Reset() + c.So(err.Error(), c.ShouldContainSubstring, "call test plugin error") + // call apply err + patch8 := gomonkey.ApplyMethodReturn(cli, "Apply", nil, fmt.Errorf("call plugin error")) + defer patch8.Reset() + _, err = mid.UnaryInterceptor(metadata.NewIncomingContext(context.Background(), metadata.Pairs("X-Forwarded-For", "127.0.0.1:9090")), &hello.HelloRequest{}, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + patch8.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "call plugin error") }) } func TestPluginStreamInterceptorErr(t *testing.T) { c.Convey("test plugin unary interceptor", t, func() { - go example.RunPlugins(":9850") + go example.RunPlugins(":9851") time.Sleep(2 * time.Second) lb := goloadbalancer.NewGrpcLoadBalance(&goloadbalancer.Server{ Name: "test", Endpoints: []goloadbalancer.EndpointServer{ { - Addr: "127.0.0.1:9850", + Addr: "127.0.0.1:9851", }, }, Pool: &goloadbalancer.PoolConfig{ @@ -176,7 +237,7 @@ func TestPluginStreamInterceptorErr(t *testing.T) { patch.Reset() c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "recv msg error") - patch1 := gomonkey.ApplyMethodReturn(mid, "Apply", nil, fmt.Errorf("call test plugin error")) + patch1 := gomonkey.ApplyMethodReturn(mid, "Apply", nil, nil, fmt.Errorf("call test plugin error")) defer patch1.Reset() err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test"))}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { return ss.RecvMsg(srv) @@ -197,3 +258,34 @@ func TestPluginStreamInterceptorErr(t *testing.T) { }) } + +func TestRPCStreamClientInterceptor(t *testing.T) { + c.Convey("test rpc stream client interceptor", t, func() { + go example.RunPlugins(":9852") + time.Sleep(2 * time.Second) + lb := goloadbalancer.NewGrpcLoadBalance(&goloadbalancer.Server{ + Name: "test", + Endpoints: []goloadbalancer.EndpointServer{ + { + Addr: "127.0.0.1:9852", + }, + }, + Pool: &goloadbalancer.PoolConfig{ + MaxOpenConns: 10, + MaxIdleConns: 5, + MaxActiveConns: 5, + }, + }) + mid := middleware.NewPluginImpl(lb, "test", 3*time.Second) + c.So(mid.Name(), c.ShouldEqual, "test") + mid.SetPriority(3) + st, err := mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs("key", "value")), nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }, + ) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + + }, + ) +} diff --git a/internal/middleware/stream.go b/internal/middleware/stream.go index ec3ecda..fab289a 100644 --- a/internal/middleware/stream.go +++ b/internal/middleware/stream.go @@ -3,6 +3,7 @@ package middleware import ( "context" "fmt" + "log" "sync" gosdk "github.com/begonia-org/go-sdk" @@ -20,6 +21,13 @@ type grpcPluginStream struct { ctx context.Context } +// type grpcPluginClientStream struct { +// grpc.ClientStream +// fullName string +// plugin *pluginImpl +// ctx context.Context +// } + var streamPool = &sync.Pool{ New: func() interface{} { return &grpcPluginStream{ @@ -28,11 +36,19 @@ var streamPool = &sync.Pool{ }, } +// func NewGrpcPluginClientStream(s grpc.ClientStream, fullName string, ctx context.Context, plugin *pluginImpl) *grpcPluginClientStream { +// return &grpcPluginClientStream{ +// ClientStream: s, +// fullName: fullName, +// ctx: ctx, +// plugin: plugin, +// } +// } func NewGrpcPluginStream(s grpc.ServerStream, fullName string, ctx context.Context, plugin *pluginImpl) *grpcPluginStream { stream := streamPool.Get().(*grpcPluginStream) stream.ServerStream = s stream.fullName = fullName - stream.ctx = s.Context() + stream.ctx = ctx stream.plugin = plugin return stream } @@ -46,22 +62,20 @@ func (g *grpcPluginStream) Release() { func (g *grpcPluginStream) Context() context.Context { return g.ctx } + func (s *grpcPluginStream) RecvMsg(m interface{}) error { if err := s.ServerStream.RecvMsg(m); err != nil { return err } - rsp, err := s.plugin.Apply(s.Context(), m, s.fullName) + rsp, header, err := s.plugin.Apply(s.Context(), m, s.fullName) if err != nil { return gosdk.NewError(fmt.Errorf("call %s plugin error: %w", s.plugin.Name(), err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "call_plugin") } md, ok := metadata.FromIncomingContext(s.ctx) if !ok { - md = metadata.New(nil) - } - for k, v := range rsp.Metadata { - md.Append(k, v) + md = metadata.New(make(map[string]string)) } newRequest := rsp.NewRequest if newRequest != nil { @@ -70,6 +84,10 @@ func (s *grpcPluginStream) RecvMsg(m interface{}) error { return gosdk.NewError(fmt.Errorf("unmarshal to request error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_to_request") } } + log.Printf("grpcPluginStream server stream pointer:%p", s) + for k, v := range header { + md[k] = append(md[k], v...) + } s.ctx = metadata.NewIncomingContext(s.ctx, md) // s.ctx = metadata.NewIncomingContext(s.ctx, metadata.New(rsp.Metadata)) diff --git a/internal/middleware/vaildator.go b/internal/middleware/vaildator.go index 0b78362..91536fe 100644 --- a/internal/middleware/vaildator.go +++ b/internal/middleware/vaildator.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "errors" "fmt" "reflect" "strings" @@ -10,8 +11,10 @@ import ( gosdk "github.com/begonia-org/go-sdk" common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/go-playground/validator/v10" + "github.com/iancoleman/strcase" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/known/fieldmaskpb" ) @@ -23,6 +26,16 @@ type validatePluginStream struct { ctx context.Context validator ParamsValidator } +type ValidateError struct { + error + Field string +} + +// type validatePluginClientStream struct { +// grpc.ClientStream +// ctx context.Context +// validator ParamsValidator +// } var validatePluginStreamPool = &sync.Pool{ New: func() interface{} { @@ -39,6 +52,7 @@ type ParamsValidator interface { type ParamsValidatorImpl struct { priority int + validate *validator.Validate } func (p *validatePluginStream) Context() context.Context { @@ -54,6 +68,7 @@ func (p *validatePluginStream) RecvMsg(m interface{}) error { } +// func getFieldNamesFromProto(input interface{}) map[string]string {} func getFieldNamesFromJSONTags(input interface{}) map[string]string { fieldMap := make(map[string]string) @@ -61,7 +76,10 @@ func getFieldNamesFromJSONTags(input interface{}) map[string]string { if val.Kind() == reflect.Ptr { val = val.Elem() } + if !val.IsValid() || val.IsZero() { + return nil + } typ := val.Type() for i := 0; i < val.NumField(); i++ { @@ -78,13 +96,149 @@ func getFieldNamesFromJSONTags(input interface{}) map[string]string { return fieldMap } +// isRequiredField 检查字段是否是必填字段 +// 通过proto文件中的validate标签或者struct tag中的validate标签判断 +func (p *ParamsValidatorImpl) isRequiredField(field interface{}) bool { + if fd, ok := field.(protoreflect.FieldDescriptor); ok && fd != nil { + + if v, ok := proto.GetExtension(fd.Options(), common.E_Validate).(string); ok && strings.Contains(v, "required") { + return true + + } + + } + + if fd, ok := field.(reflect.StructField); ok && fd.Tag.Get("validate") != "" { + if v, ok := fd.Tag.Lookup("validate"); ok && strings.Contains(v, "required") { + return true + + } + + } + return false +} + +// getValidatePath 获取待验证字段的路径 +// 路径格式参考validate.StructPartial +func (p *ParamsValidatorImpl) getValidatePath(message protoreflect.ProtoMessage, field string, parent string) []string { + fieldsName := make([]string, 0) + md := message.ProtoReflect().Descriptor() + // log.Printf("get validate path,field:%s,parent:%s", field, parent) + if fd := md.Fields().ByJSONName(field); fd != nil { + fieldName := strcase.ToCamel(string(fd.Name())) + // fieldName := fd.JSONName() + if parent != "" { + fieldName = parent + "." + fieldName + } + fieldsName = append(fieldsName, fieldName) + + if fd.Kind() == protoreflect.MessageKind { + if fd.IsList() { + list := message.ProtoReflect().Get(fd).List() + for j := 0; j < list.Len(); j++ { + item := list.Get(j).Message().Interface() + fieldsName = append(fieldsName, p.FiltersFields(item, fmt.Sprintf("%s[%d]", fieldName, j))...) + } + } else if fd.IsMap() { + mapValue := message.ProtoReflect().Get(fd).Map() + + mapValue.Range(func(key protoreflect.MapKey, value protoreflect.Value) bool { + + if fd.MapValue().Kind() == protoreflect.MessageKind { + + item := value.Message().Interface() + fieldsName = append(fieldsName, p.FiltersFields(item, fmt.Sprintf("%s[%v]", fieldName, key.Interface()))...) + + } else { + // log.Printf("map key path:%v", fmt.Sprintf("%s[%v]", fieldName, key.Interface())) + fieldsName = append(fieldsName, fmt.Sprintf("%s[%v]", fieldName, key.Interface())) + } + return true + }) + } else { + nestedMessage := message.ProtoReflect().Get(fd).Message().Interface() + fieldsName = append(fieldsName, p.FiltersFields(nestedMessage, fieldName)...) + } + } + } + + return fieldsName +} + +// FiltersFields 从FieldMask中获取过滤字段,获取待验证字段 +// required 字段优先级高于FieldMask +func (p *ParamsValidatorImpl) FiltersMessageFields(v interface{}) []string { + // fieldsMap := getFieldNamesFromJSONTags(v) + requiredFields := make([]string, 0) + maskFields := make([]string, 0) + + if message, ok := v.(protoreflect.ProtoMessage); ok { + md := message.ProtoReflect().Descriptor() + + // 遍历所有字段 + for i := 0; i < md.Fields().Len(); i++ { + field := md.Fields().Get(i) + // require 字段必须校验 + if p.isRequiredField(field) { + // log.Printf("required field:%s", field.JSONName()) + requiredFields = append(requiredFields, p.getValidatePath(message, field.JSONName(), "")...) + } + + // 检查字段是否是FieldMask类型 + if field.Kind() == protoreflect.MessageKind && !field.IsList() && !field.IsMap() { + + // 获取字段的值(确保它是FieldMask类型) + fieldValue := message.ProtoReflect().Get(field).Message() + mask, ok := fieldValue.Interface().(*fieldmaskpb.FieldMask) + if mask == nil || !ok { + continue + } + paths := make([]string, 0) + paths = append(paths, mask.Paths...) + for _, path := range paths { + maskField := strcase.ToCamel(path) + // if parent != "" { + // maskField = fmt.Sprintf("%s.%s", parent, strcase.ToCamel(path)) + // } + maskFields = append(maskFields, maskField) + maskFields = append(maskFields, p.getValidatePath(message, path, "")...) + } + } + } + return append(requiredFields, maskFields...) + } + return nil +} + // FiltersFields 从FieldMask中获取过滤字段,获取待验证字段 +// required 字段优先级高于FieldMask func (p *ParamsValidatorImpl) FiltersFields(v interface{}, parent string) []string { fieldsMap := getFieldNamesFromJSONTags(v) - fieldsName := make([]string, 0) + requiredFields := make([]string, 0) + maskFields := make([]string, 0) + if message, ok := v.(protoreflect.ProtoMessage); ok { md := message.ProtoReflect().Descriptor() + val := reflect.ValueOf(v) + typ := reflect.TypeOf(v) + if val.Kind() == reflect.Ptr { + val = val.Elem() + typ = typ.Elem() + } + isRequired := false + for k := range fieldsMap { + field := md.Fields().ByJSONName(k) + st, ok := typ.FieldByName(fieldsMap[k]) + if val.Kind() == reflect.Struct { + isRequired = ok && p.isRequiredField(st) + } + // 检查字段是否是必填字段 + // 如果是必填字段,将其加入requiredFields,用于检查 + if p.isRequiredField(field) || isRequired { + requiredFields = append(requiredFields, p.getValidatePath(message, k, parent)...) + } + } // 遍历所有字段 for i := 0; i < md.Fields().Len(); i++ { field := md.Fields().Get(i) @@ -97,71 +251,57 @@ func (p *ParamsValidatorImpl) FiltersFields(v interface{}, parent string) []stri if mask == nil || !ok { continue } - for _, path := range mask.Paths { - if fd := message.ProtoReflect().Descriptor().Fields().ByJSONName(path); fd != nil { - fieldName := "" - if parent != "" { - fieldName = parent + "." + fieldsMap[fd.JSONName()] - } else { - fieldName = fieldsMap[fd.JSONName()] - } - if fd.Kind() == protoreflect.MessageKind { - if fd.IsList() { - for j := 0; j < message.ProtoReflect().Get(fd).List().Len(); j++ { - if fd.Kind() == protoreflect.MessageKind { - fieldsName = append(fieldsName, p.FiltersFields(message.ProtoReflect().Get(fd).List().Get(j).Message().Interface(), fmt.Sprintf("%s[%d]", fieldName, j))...) - } else { - fieldsName = append(fieldsName, fmt.Sprintf("%s[%d]", fieldName, j)) - } - } - } else { - fieldsName = append(fieldsName, p.FiltersFields(message.ProtoReflect().Get(fd).Message().Interface(), fieldName)...) - } - // fieldsName = append(fieldsName, p.FiltersFields(message.ProtoReflect().Get(fd).Interface(), fieldName)...) - } else { - fieldsName = append(fieldsName, fieldName) - } + paths := make([]string, 0) + paths = append(paths, mask.Paths...) + for _, path := range paths { + maskField := strcase.ToCamel(path) + if parent != "" { + maskField = fmt.Sprintf("%s.%s", parent, strcase.ToCamel(path)) } - + maskFields = append(maskFields, maskField) + maskFields = append(maskFields, p.getValidatePath(message, path, parent)...) } } } - return fieldsName + return append(requiredFields, maskFields...) } return nil } -func RegisterCustomValidators(v *validator.Validate) { - _=v.RegisterValidation("required_if", requiredIf) -} -// requiredIf 自定义验证器逻辑 -func requiredIf(fl validator.FieldLevel) bool { - param := fl.Param() - field := fl.Field() +func (p *ParamsValidatorImpl) ValidateParams(v interface{}) error { + // p.validate.Struct() + var err error + if message, ok := v.(proto.Message); ok { + filters := p.FiltersMessageFields(v) + duplicateFilters := make([]string, 0) + fieldsSet := make(map[string]struct{}) + for _, f := range filters { + if _, ok := fieldsSet[f]; !ok { + fieldsSet[f] = struct{}{} + duplicateFilters = append(duplicateFilters, f) + } + } - // 获取参数字段值 - paramField := fl.Parent().FieldByName(param) + pv := NewProtobufValidate(p.validate) + // log.Printf("validate fields:%v", duplicateFilters) + err = pv.ProtobufPartial(message, common.E_Validate, duplicateFilters...) + // err = p.ValidateProtoMessage(message, common.E_Validate, fieldsSet, strcase.ToCamel(string(message.ProtoReflect().Descriptor().Name()))+".") - // 如果参数字段为空,当前字段必须非空 - if paramField.String() == "" { - return field.String() != "" - } + } else { - return true -} -func (p *ParamsValidatorImpl) ValidateParams(v interface{}) error { - validate := validator.New() - RegisterCustomValidators(validate) - err := validate.Struct(v) - filters := p.FiltersFields(v, "") - if len(filters) > 0 { - err = validate.StructPartial(v, filters...) + err = gosdk.NewError(fmt.Errorf("params validation failed: params is not a proto.Message"), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "params_validation", gosdk.WithClientMessage("params validation failed: unsupported type")) } - if errs, ok := err.(validator.ValidationErrors); ok { - clientMsg := fmt.Sprintf("params %s validation failed with %v,except %s,%v", errs[0].Field(), errs[0].Value(), errs[0].ActualTag(), filters) - return gosdk.NewError(fmt.Errorf("params %s validation failed: %v", errs[0].Field(), errs[0].Value()), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "params_validation", gosdk.WithClientMessage(clientMsg)) + fieldName := "" + + validateErr := validator.ValidationErrors{} + if errors.As(err, &validateErr) { + if validateErr[0].Namespace() != "" { + fieldName = validateErr[0].Namespace() + } + clientMsg := fmt.Sprintf("params %s validation failed with %v,except %s", fieldName, validateErr[0].Value(), validateErr[0].ActualTag()) + return gosdk.NewError(fmt.Errorf("params %s validation failed: %v due to %v", fieldName, validateErr[0].Value(), validateErr[0].ActualTag()), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "params_validation", gosdk.WithClientMessage(clientMsg)) } - return nil + return err } func (p *ParamsValidatorImpl) SetPriority(priority int) { @@ -176,6 +316,7 @@ func (p *ParamsValidatorImpl) Name() string { } func (p *ParamsValidatorImpl) UnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + // fmt.Print("params validator unary interceptor\n") err = p.ValidateParams(req) if err != nil { return nil, err @@ -193,7 +334,16 @@ func (p *ParamsValidatorImpl) StreamInterceptor(srv interface{}, ss grpc.ServerS err := handler(srv, validateStream) return err } +func (p *ParamsValidatorImpl) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return streamer(ctx, desc, cc, method, opts...) +} func NewParamsValidator() ParamsValidator { - return &ParamsValidatorImpl{} + + v := &ParamsValidatorImpl{ + validate: validator.New(), + } + // RegisterCustomValidators(v.validate) + return v + } diff --git a/internal/middleware/vaildator_test.go b/internal/middleware/vaildator_test.go index 7279f76..93beacb 100644 --- a/internal/middleware/vaildator_test.go +++ b/internal/middleware/vaildator_test.go @@ -6,108 +6,314 @@ import ( "github.com/begonia-org/begonia/internal/middleware" hello "github.com/begonia-org/go-sdk/api/example/v1" + common "github.com/begonia-org/go-sdk/common/api/v1" + "github.com/go-playground/validator/v10" c "github.com/smartystreets/goconvey/convey" + "github.com/spark-lence/tiga" "google.golang.org/grpc" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/dynamicpb" "google.golang.org/protobuf/types/known/fieldmaskpb" ) -func TestValidatorUnaryInterceptor(t *testing.T) { - c.Convey("test validator unary interceptor", t, func() { - validator := middleware.NewParamsValidator() +type HelloSubRequest struct { + SubMsg string `protobuf:"bytes,1,opt,name=sub_msg,proto3" json:"sub_msg,omitempty"` + // @gotags: validate:"required" + SubName string `protobuf:"bytes,2,opt,name=sub_name,proto3" json:"sub_name,omitempty" validate:"required"` + // @gotags: validate:"required,gte=18,lte=35" + SubAge int32 `protobuf:"varint,4,opt,name=sub_age,proto3" json:"sub_age,omitempty" validate:"required,gte=18,lte=35"` + UpdateMask *fieldmaskpb.FieldMask `protobuf:"bytes,3,opt,name=update_mask,proto3" json:"update_mask,omitempty"` +} - _, err := validator.UnaryInterceptor(context.Background(), &hello.HelloRequestWithValidator{ - Name: "test", - Msg: "test", - Sub: &hello.HelloSubRequest{ - SubMsg: "test", - UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_name", "sub_msg"}}, - }, - UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"name", "msg", "sub"}}, - }, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { - return nil, nil - }) - c.So(err, c.ShouldNotBeNil) - c.So(err.Error(), c.ShouldContainSubstring, "validation failed") +type HelloRequestWithValidator struct { - validator.SetPriority(1) - c.So(validator.Priority(), c.ShouldEqual, 1) - c.So(validator.Name(), c.ShouldEqual, "ParamsValidator") + // @gotags: validate:"required" + Msg string `protobuf:"bytes,1,opt,name=msg,proto3" json:"msg,omitempty" validate:"required"` + // @gotags: validate:"required" + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty" validate:"required"` + Age int32 `protobuf:"varint,3,opt,name=age,proto3" json:"age,omitempty" validate:"required,gte=18,lte=35"` + Sub *HelloSubRequest `protobuf:"bytes,4,opt,name=sub,proto3" json:"sub,omitempty"` + // @gotags: validate:"required,dive" + Subs []*HelloSubRequest `protobuf:"bytes,5,rep,name=subs,proto3" json:"subs,omitempty" validate:"required,dive"` + UpdateMask *fieldmaskpb.FieldMask `protobuf:"bytes,6,opt,name=update_mask,proto3" json:"update_mask,omitempty"` + // @gotags: validate:"required,dive" + SubMap map[string]*HelloSubRequest `protobuf:"bytes,7,rep,name=sub_map,proto3" json:"sub_map,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3" validate:"required,dive"` + // @gotags: validate:"required" + SubMap2 map[string]string `protobuf:"bytes,8,rep,name=sub_map2,proto3" json:"sub_map2,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3" validate:"required"` +} - _, err = validator.UnaryInterceptor(context.Background(), &hello.HelloRequestWithValidator{ - Name: "test", - Msg: "test", +func TestValidateDynamicProtoMessage(t *testing.T) { + c.Convey("test dynamic proto message", t, func() { + req := &hello.HelloRequestWithValidator{ + Name: "test", + Msg: "test", + Age: 16, + FloatNum: 0.0, + BoolData: true, + ExEnum: hello.ExampleEnum_EX_RUNNING, + ExEnums: []hello.ExampleEnum{ + hello.ExampleEnum_EX_RUNNING, + }, + EnumMap: map[string]hello.ExampleEnum{ + "test": hello.ExampleEnum_EX_RUNNING, + }, + EnumMap2: map[string]hello.ExampleEnum{ + "test": hello.ExampleEnum_EX_UNKNOWN, + }, + Strs: []string{"test"}, Sub: &hello.HelloSubRequest{ SubMsg: "test", + SubAge: 19, SubName: "test", UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_name", "sub_msg"}}, }, Subs: []*hello.HelloSubRequest{ { + SubAge: 19, + SubName: "test", SubMsg: "test", UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_name", "sub_msg"}}, }, { + SubName: "test", + SubAge: 19, SubMsg: "test", - UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_name", "sub_msg"}}, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_msg"}}, }, }, - UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"name", "msg", "sub", "subs"}}, - }, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { - return nil, nil - }) + SubMap: map[string]*hello.HelloSubRequest{ + "TEST1": { + SubName: "test", + SubAge: 19, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_name", "sub_age"}}, + }, + }, + SubMap2: map[string]string{ + "TEST1": "test", + }, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"name", "msg", "sub", "subs", "sub_map", "sub_map2"}}, + } + + dpb := dynamicpb.NewMessage(req.ProtoReflect().Descriptor()) + b, _ := protojson.Marshal(req) + _ = protojson.Unmarshal(b, dpb) + + pv := middleware.NewProtobufValidate(validator.New()) + err := pv.Protobuf(dpb, common.E_Validate) c.So(err, c.ShouldNotBeNil) - t.Log(err.Error()) - _, err = validator.UnaryInterceptor(context.Background(), &hello.HelloRequestWithValidator{ - Name: "test", - Msg: "test", - Sub: &hello.HelloSubRequest{ - SubMsg: "test", + c.So(err.Error(), c.ShouldContainSubstring, "Age") + }) + +} +func TestValidateProtoMessage(t *testing.T) { + req := &hello.HelloRequestWithValidator{ + Name: "test", + Msg: "test", + Age: 19, + Age2: 19, + FloatNum: 1.1, + BoolData: true, + BytesData: []byte("test"), + ExEnum: hello.ExampleEnum_EX_RUNNING, + ExEnums: []hello.ExampleEnum{ + hello.ExampleEnum_EX_RUNNING, + }, + EnumMap: map[string]hello.ExampleEnum{ + "test": hello.ExampleEnum_EX_RUNNING, + }, + EnumMap2: map[string]hello.ExampleEnum{ + "test": hello.ExampleEnum_EX_UNKNOWN, + }, + Strs: []string{"test"}, + Sub: &hello.HelloSubRequest{ + SubMsg: "test", + SubAge: 19, + SubName: "test", + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_name", "sub_msg"}}, + }, + Subs: []*hello.HelloSubRequest{ + { + SubAge: 19, SubName: "test", + SubMsg: "test", UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_name", "sub_msg"}}, }, - Subs: []*hello.HelloSubRequest{ - { - SubMsg: "test", - SubName: "test", - UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_name", "sub_msg"}}, - }, + { + SubName: "test", + SubAge: 19, + SubMsg: "test", + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_msg"}}, }, - UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"name", "msg", "sub", "subs"}}, - }, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + }, + SubMap: map[string]*hello.HelloSubRequest{ + "TEST1": { + SubName: "test", + SubAge: 19, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_name", "sub_age"}}, + }, + }, + SubMap2: map[string]string{ + "TEST1": "test", + }, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"name", "msg", "sub", "subs", "sub_map", "sub_map2"}}, + } + c.Convey("test validator unary interceptor", t, func() { + validator := middleware.NewParamsValidator() + + validator.SetPriority(1) + c.So(validator.Priority(), c.ShouldEqual, 1) + c.So(validator.Name(), c.ShouldEqual, "ParamsValidator") + + _, err := validator.UnaryInterceptor(context.Background(), req, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { return nil, nil }) c.So(err, c.ShouldBeNil) - _, err = validator.UnaryInterceptor(context.Background(), &struct{}{}, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + req2 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req2.Age = 16 + _, err = validator.UnaryInterceptor(context.Background(), req2, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { return nil, nil }) - c.So(err, c.ShouldBeNil) - }) -} -func TestRequireIf(t *testing.T) { - c.Convey("test require if", t, func() { - v := []struct { - Field string `validate:"required_if=Field2"` - Field2 string - Field3 string `validate:"required_if=Field2"` - }{{ - Field: "", - Field2: "test", - Field3: "", - }, - { - Field: "", - Field2: "", - Field3: "", - }, - } - validator := middleware.NewParamsValidator() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Age") + + req3 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req3.Subs[1].SubAge = 16 + _, err = validator.UnaryInterceptor(context.Background(), req3, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Subs[1].SubAge") + req4 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req4.SubMap["TEST1"].SubAge = 16 + _, err = validator.UnaryInterceptor(context.Background(), req4, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "SubMap[TEST1].SubAge") + req5 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req5.Subs[0] = &hello.HelloSubRequest{ + SubName: "test2", + } + _, err = validator.UnaryInterceptor(context.Background(), req5, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Subs[0].SubAge") + + req6 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req6.Sub.SubAge = 16 + _, err = validator.UnaryInterceptor(context.Background(), req6, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Sub.SubAge") + + req7 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req7.Subs[1] = &hello.HelloSubRequest{ + SubAge: 19, + SubMsg: "test", + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_age"}}, + } + _, err = validator.UnaryInterceptor(context.Background(), req7, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Subs[1].SubName") + + req8 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req8.SubMap2 = nil + _, err = validator.UnaryInterceptor(context.Background(), req8, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "SubMap2") + + req9 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req9.ExEnum = hello.ExampleEnum_EX_UNKNOWN + _, err = validator.UnaryInterceptor(context.Background(), req9, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "ExEnum") + req10 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req10.Sub = nil + _, err = validator.UnaryInterceptor(context.Background(), req10, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Sub") - err:=validator.ValidateParams(v[0]) + req11 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req11.EnumMap = nil + _, err = validator.UnaryInterceptor(context.Background(), req11, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "EnumMap") + + req12 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req12.EnumMap2 = nil + req12.UpdateMask.Paths = append(req12.UpdateMask.Paths, "enum_map2") + _, err = validator.UnaryInterceptor(context.Background(), req12, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) c.So(err, c.ShouldBeNil) - err=validator.ValidateParams(v[1]) - c.So(err,c.ShouldNotBeNil) + + req14 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req14.Subs = nil + _, err = validator.UnaryInterceptor(context.Background(), req14, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Subs") + req15 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req15.Name = "hello" + _, err = validator.UnaryInterceptor(context.Background(), req15, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Sub2") + req16 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req16.Sub = nil + _, err = validator.UnaryInterceptor(context.Background(), req16, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Sub") + + st := struct { + Name string `validate:"required"` + }{} + _, err = validator.UnaryInterceptor(context.Background(), st, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "params is not a proto.Message") + }) + c.Convey("test NewStructFromProtobuf", t, func() { + validate := validator.New() + vd := middleware.NewProtobufValidate(validate) + v := vd.NewStructFromProtobuf(tiga.DeepCopy(req).(*hello.HelloRequestWithValidator), common.E_Validate) + err := validate.Struct(v) + c.So(err, c.ShouldBeNil) + err = vd.ProtobufPartialCtx(context.Background(), tiga.DeepCopy(req).(*hello.HelloRequestWithValidator), common.E_Validate) + c.So(err, c.ShouldBeNil) + }) +} +func TestValidator(t *testing.T) { + st := struct { + IntNum int `validate:"required"` + BoolField bool `validate:"required"` + }{ + IntNum: 1, + BoolField: false, + } + v := validator.New() + err := v.Struct(st) + t.Log(err) } func TestValidatorStreamInterceptor(t *testing.T) { c.Convey("test stream interceptor", t, func() { @@ -129,3 +335,15 @@ func TestValidatorStreamInterceptor(t *testing.T) { c.So(err, c.ShouldNotBeNil) }) } +func TestValidatorStreamClientInterceptor(t *testing.T) { + c.Convey("test stream client interceptor", t, func() { + validator := middleware.NewParamsValidator() + + _, err := validator.StreamClientInterceptor(context.Background(), nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + // return middleware.NewGrpcPluginClientStream(ctx, desc, cc, method, opts...),nil + return nil, nil + }) + c.So(err, c.ShouldBeNil) + + }) +} diff --git a/internal/migrate/admin.go b/internal/migrate/admin.go index 9c97b36..5184cc5 100644 --- a/internal/migrate/admin.go +++ b/internal/migrate/admin.go @@ -44,13 +44,27 @@ func (m *UsersOperator) InitAdminUser(passwd string, aseKey, ivKey string, name, CreatedAt: timestamppb.New(time.Now()), UpdatedAt: timestamppb.New(time.Now()), IsDeleted: false, + TenantId: fmt.Sprintf("%d", uid), } err = tiga.EncryptStructAES([]byte(aseKey), user, ivKey) if err != nil { return "", err } - err = m.mysql.Create(context.Background(), user) + err = m.mysql.Create(context.Background(), user, nil) + tenant := &api.Tenants{ + TenantId: fmt.Sprintf("%d", uid), + TenantName: name, + Description: "Super Admin", + CreatedAt: timestamppb.New(time.Now()), + UpdatedAt: timestamppb.New(time.Now()), + Tags: []string{"admin"}, + AdminId: fmt.Sprintf("%d", uid), + } + if err != nil { + return "", err + } + _, err = m.mysql.Upsert(context.Background(), tenant, nil) return user.Uid, err } return userExist.Uid, nil diff --git a/internal/migrate/app.go b/internal/migrate/app.go index 48d78d2..5140304 100644 --- a/internal/migrate/app.go +++ b/internal/migrate/app.go @@ -23,7 +23,7 @@ type APPOperator struct { func NewAPPOperator(mysql *tiga.MySQLDao) *APPOperator { return &APPOperator{mysql: mysql} } -func dumpInitApp(app *api.Apps,env string) error { +func dumpInitApp(app *api.Apps, env string) error { log.Print("########################################admin-app###############################") log.Printf("Init appid:%s", app.Appid) log.Printf("Init accessKey:%s", app.AccessKey) @@ -36,7 +36,7 @@ func dumpInitApp(app *api.Apps,env string) error { if err := os.MkdirAll(path, os.ModePerm); err != nil { return err } - file, err := os.Create(filepath.Join(path, fmt.Sprintf("admin-app.%s.json",env))) + file, err := os.Create(filepath.Join(path, fmt.Sprintf("admin-app.%s.json", env))) if err != nil { return err } @@ -49,11 +49,11 @@ func dumpInitApp(app *api.Apps,env string) error { log.Print("#################################################################################") return nil } -func (m *APPOperator) InitAdminAPP(owner,env string) (err error) { +func (m *APPOperator) InitAdminAPP(owner, env string) (err error) { app := &api.Apps{} defer func() { if app.Appid != "" { - if errInit := dumpInitApp(app,env); errInit != nil { + if errInit := dumpInitApp(app, env); errInit != nil { err = errInit } } @@ -91,7 +91,7 @@ func (m *APPOperator) InitAdminAPP(owner,env string) (err error) { UpdatedAt: timestamppb.New(time.Now()), Tags: []string{"admin"}, } - err = m.mysql.Create(context.Background(), app) + err = m.mysql.Create(context.Background(), app, nil) return err } return nil diff --git a/internal/migrate/migrate.go b/internal/migrate/migrate.go index 84b5bca..3ada09d 100644 --- a/internal/migrate/migrate.go +++ b/internal/migrate/migrate.go @@ -26,7 +26,7 @@ type MySQLMigrate struct { func NewTableModels() []TableModel { tables := make([]TableModel, 0) - tables = append(tables, api.Users{}, endpoint.Endpoints{}, app.Apps{}, file.Files{}, file.Buckets{}) + tables = append(tables, api.Users{}, endpoint.Endpoints{}, app.Apps{}, file.Files{}, file.Buckets{}, api.Tenants{}, api.Business{}, api.TenantsBusiness{}) return tables } func NewMySQLMigrate(mysql *tiga.MySQLDao, models ...TableModel) *MySQLMigrate { diff --git a/internal/migrate/operator.go b/internal/migrate/operator.go index b453808..c3246a7 100644 --- a/internal/migrate/operator.go +++ b/internal/migrate/operator.go @@ -29,12 +29,12 @@ func (m *InitOperator) Init() error { phone := m.config.GetDefaultAdminPhone() aseKey := m.config.GetAesKey() ivKey := m.config.GetAesIv() - env:=m.config.GetEnv() + env := m.config.GetEnv() uid, err := m.user.InitAdminUser(adminPasswd, aseKey, ivKey, name, email, phone) if err != nil { log.Printf("failed to init admin user: %v", err) return err } - return m.app.InitAdminAPP(uid,env) + return m.app.InitAdminAPP(uid, env) } diff --git a/internal/migrate/operator_test.go b/internal/migrate/operator_test.go index b41cbc2..0a580da 100644 --- a/internal/migrate/operator_test.go +++ b/internal/migrate/operator_test.go @@ -126,7 +126,7 @@ func TestAppOperatorFail(t *testing.T) { } patch := gomonkey.ApplyFuncReturn(tiga.MySQLDao.First, fmt.Errorf("first failed")) defer patch.Reset() - err := operator.InitAdminAPP("test",env) + err := operator.InitAdminAPP("test", env) patch.Reset() c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "first failed") @@ -137,7 +137,7 @@ func TestAppOperatorFail(t *testing.T) { for _, caseV := range cases { patch2 := gomonkey.ApplyFuncReturn(caseV.patch, caseV.output...) defer patch2.Reset() - err := operator.InitAdminAPP("test",env) + err := operator.InitAdminAPP("test", env) patch2.Reset() c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, caseV.err.Error()) diff --git a/internal/pkg/config/config.go b/internal/pkg/config/config.go index bb2790a..d6873ea 100644 --- a/internal/pkg/config/config.go +++ b/internal/pkg/config/config.go @@ -227,7 +227,7 @@ func (c *Config) GetServiceTagsPrefix() string { return fmt.Sprintf("%s/tags", prefix) } func (c *Config) GetServiceKey(key string) string { - if tiga.IsSnowflakeID(key){ + if tiga.IsSnowflakeID(key) { prefix := c.GetServicePrefix() return filepath.Join(prefix, key) } diff --git a/internal/pkg/config/config_test.go b/internal/pkg/config/config_test.go index 641d952..90ad8e1 100644 --- a/internal/pkg/config/config_test.go +++ b/internal/pkg/config/config_test.go @@ -75,7 +75,7 @@ func TestConfig(t *testing.T) { c.So(config.GetServicePrefix(), c.ShouldEndWith, "/service") c.So(config.GetServiceNamePrefix(), c.ShouldEndWith, "/service_name") c.So(config.GetServiceTagsPrefix(), c.ShouldEndWith, "/tags") - snk,_:=tiga.NewSnowflake(1) + snk, _ := tiga.NewSnowflake(1) c.So(config.GetServiceKey("test"), c.ShouldStartWith, config.GetServiceNamePrefix()) c.So(config.GetServiceKey(snk.GenerateIDString()), c.ShouldStartWith, config.GetServicePrefix()) diff --git a/internal/server/server.go b/internal/server/server.go index 4dc9558..8b498fe 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -13,7 +13,6 @@ import ( "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/middleware" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" "github.com/begonia-org/begonia/internal/service" loadbalance "github.com/begonia-org/go-loadbalancer" common "github.com/begonia-org/go-sdk/common/api/v1" @@ -52,7 +51,7 @@ func readDesc(conf *config.Config) (gateway.ProtobufDescription, error) { func NewGateway(cfg *gateway.GatewayConfig, conf *config.Config, services []service.Service, pluginApply *middleware.PluginsApply) *gateway.GatewayServer { // 参数选项 opts := &gateway.GrpcServerOptions{ - Middlewares: make([]gateway.GrpcProxyMiddleware, 0), + Middlewares: make([]grpc.StreamClientInterceptor, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]runtime.ServeMuxOption, 0), @@ -63,10 +62,12 @@ func NewGateway(cfg *gateway.GatewayConfig, conf *config.Config, services []serv opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithMarshalerOption("application/x-www-form-urlencoded", gateway.NewFormUrlEncodedMarshaler())) opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithMarshalerOption(runtime.MIMEWildcard, gateway.NewRawBinaryUnmarshaler())) opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithMarshalerOption("application/octet-stream", gateway.NewRawBinaryUnmarshaler())) + opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithMarshalerOption("text/event-stream", gateway.NewEventSourceMarshaler())) opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithMetadata(gateway.IncomingHeadersToMetadata)) opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithErrorHandler(gateway.HandleErrorWithLogger(gateway.Log))) opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithForwardResponseOption(gateway.HttpResponseBodyModify)) + opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithStreamErrorHandler(gateway.HandleServerStreamError(gateway.Log))) // opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithRoutingErrorHandler(middleware.HandleRoutingError)) // 连接池配置 opts.PoolOptions = append(opts.PoolOptions, loadbalance.WithMaxActiveConns(100)) @@ -74,6 +75,7 @@ func NewGateway(cfg *gateway.GatewayConfig, conf *config.Config, services []serv // 中间件配置 opts.Options = append(opts.Options, grpc.ChainUnaryInterceptor(pluginApply.UnaryInterceptorChains()...)) opts.Options = append(opts.Options, grpc.ChainStreamInterceptor(pluginApply.StreamInterceptorChains()...)) + opts.Middlewares = append(opts.Middlewares, pluginApply.StreamClientInterceptorChains()...) pd, err := readDesc(conf) if err != nil { panic(err) @@ -84,7 +86,7 @@ func NewGateway(cfg *gateway.GatewayConfig, conf *config.Config, services []serv opts.HttpHandlers = append(opts.HttpHandlers, cors.Handle) gw := gateway.New(cfg, opts) - routersList := routers.Get() + routersList := gateway.GetRouter() for _, srv := range services { err := gw.RegisterLocalService(context.Background(), pd, srv.Desc(), srv) if err != nil { diff --git a/internal/service/app_test.go b/internal/service/app_test.go index 60a5934..15d880b 100644 --- a/internal/service/app_test.go +++ b/internal/service/app_test.go @@ -38,9 +38,10 @@ func addApp(t *testing.T) { c.So(err, c.ShouldBeNil) c.So(rsp2.StatusCode, c.ShouldEqual, common.Code_OK) c.So(rsp2.Name, c.ShouldNotBeEmpty) - _, err = apiClient.PostAppConfig(context.Background(), &api.AppsRequest{Name: name, Description: "test"}) + rsp, err = apiClient.PostAppConfig(context.Background(), &api.AppsRequest{Name: name, Description: "test"}) c.So(err, c.ShouldNotBeNil) - c.So(err.Error(), c.ShouldEqual, "duplicate app name") + c.So(rsp.StatusCode, c.ShouldEqual, int(api.APPSvrCode_APP_DUPLICATE_ERR)) + // c.So(err.Error(), c.ShouldEqual, "duplicate app name") // c.So(rsp3.StatusCode, c.ShouldEqual, common.Code_ERR) name2 = fmt.Sprintf("app-service-2-%s", time.Now().Format("20060102150405")) @@ -72,7 +73,7 @@ func testPatchApp(t *testing.T) { func() { apiClient := client.NewAppAPI(apiAddr, accessKey, secret) name := fmt.Sprintf("app-%s", time.Now().Format("20060102150405")) - rsp2, err := apiClient.UpdateAPP(context.Background(), appid, name, "test patch", nil) + rsp2, err := apiClient.UpdateAPP(context.Background(), appid, client.WithPatchParams("name", name), client.WithPatchParams("description", "test patch")) c.So(err, c.ShouldBeNil) c.So(rsp2.StatusCode, c.ShouldEqual, common.Code_OK) rsp2, err = apiClient.GetAPP(context.Background(), appid) @@ -80,9 +81,10 @@ func testPatchApp(t *testing.T) { c.So(rsp2.StatusCode, c.ShouldEqual, common.Code_OK) c.So(rsp2.Name, c.ShouldEqual, name) - _, err = apiClient.UpdateAPP(context.Background(), appid, name2, "test patch", nil) + rsp, err := apiClient.UpdateAPP(context.Background(), appid, client.WithPatchParams("name", name2), client.WithPatchParams("description", "test patch")) c.So(err, c.ShouldNotBeNil) - c.So(err.Error(), c.ShouldEqual, "duplicate app name") + // c.So(err.Error(), c.ShouldEqual, "duplicate app name") + c.So(rsp.StatusCode, c.ShouldEqual, int(api.APPSvrCode_APP_DUPLICATE_ERR)) }, ) @@ -102,9 +104,9 @@ func delApp(t *testing.T) { _, err = apiClient.GetAPP(context.Background(), appid) c.So(err, c.ShouldNotBeNil) - _, err = apiClient.DeleteAPP(context.TODO(), appid) + rsp, err := apiClient.DeleteAPP(context.TODO(), appid) c.So(err, c.ShouldNotBeNil) - c.So(err.Error(), c.ShouldEqual, "app not found") + c.So(rsp.StatusCode, c.ShouldEqual, int(api.APPSvrCode_APP_NOT_FOUND_ERR)) // c.So(rsp3.StatusCode, c.ShouldEqual, common.Code_OK) }) } @@ -149,7 +151,7 @@ func TestApp(t *testing.T) { t.Run("list app", listAPP) t.Run("list app err", testListErr) t.Run("patch app", testPatchApp) - // appid = "442568851213783040" + // // appid = "442568851213783040" t.Run("del app", delApp) } diff --git a/internal/service/base_test.go b/internal/service/base_test.go index f1c0933..c40a376 100644 --- a/internal/service/base_test.go +++ b/internal/service/base_test.go @@ -58,7 +58,7 @@ func readInitAPP() { op := internal.InitOperatorApp(config.ReadConfig(env)) _ = op.Init() path := filepath.Join(homeDir, ".begonia") - path = filepath.Join(path, fmt.Sprintf("admin-app.%s.json",env)) + path = filepath.Join(path, fmt.Sprintf("admin-app.%s.json", env)) file, err := os.Open(path) if err != nil { @@ -159,7 +159,7 @@ func clean() { mysql := tiga.NewMySQLDao(conf) mysql.RegisterTimeSerializer() - err=mysql.GetModel(&user.Users{}).Where("`group` = ?", "test-user-01").Delete(&user.Users{}).Error + err = mysql.GetModel(&user.Users{}).Where("`group` = ?", "test-user-01").Delete(&user.Users{}).Error if err != nil { log.Fatalf("Failed to delete keys with prefix %s: %v", prefix, err) } diff --git a/internal/service/business.go b/internal/service/business.go new file mode 100644 index 0000000..c67f4c2 --- /dev/null +++ b/internal/service/business.go @@ -0,0 +1,75 @@ +package service + +import ( + "context" + "fmt" + + "github.com/begonia-org/begonia/internal/biz" + "github.com/begonia-org/begonia/internal/pkg" + "github.com/begonia-org/begonia/internal/pkg/config" + gosdk "github.com/begonia-org/go-sdk" + api "github.com/begonia-org/go-sdk/api/user/v1" + user "github.com/begonia-org/go-sdk/api/user/v1" + common "github.com/begonia-org/go-sdk/common/api/v1" + "github.com/begonia-org/go-sdk/logger" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +type BusinessService struct { + api.UnimplementedBusinessServiceServer + business *biz.BusinessUsecase + cfg *config.Config + log logger.Logger +} + +func NewBusinessService(business *biz.BusinessUsecase, log logger.Logger, cfg *config.Config) api.BusinessServiceServer { + return &BusinessService{business: business, log: log, cfg: cfg} +} + +func (b *BusinessService) Add(ctx context.Context, in *api.PostBusinessRequest) (*api.Business, error) { + identity := GetIdentity(ctx) + if identity == "" { + return nil, gosdk.NewError(pkg.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") + + } + return b.business.Add(ctx, in, identity) +} +func (b *BusinessService) Get(ctx context.Context, in *api.GetBusinessRequest) (*api.Business, error) { + bs, err := b.business.Get(ctx, in.Business) + + return bs, err +} +func (b *BusinessService) Update(ctx context.Context, in *api.PatchBusinessRequest) (*api.Business, error) { + bs := &api.Business{ + BusinessName: in.BusinessName, + Description: in.Description, + Tags: in.Tags, + BusinessId: in.BusinessId, + UpdateMask: in.UpdateMask, + } + err := b.business.Patch(ctx, bs) + if err != nil { + return nil, err + } + + return b.Get(ctx, &api.GetBusinessRequest{Business: in.BusinessId}) +} +func (b *BusinessService) Delete(ctx context.Context, in *api.DeleteBusinessRequest) (*api.DeleteBusinessResponse, error) { + err := b.business.Del(ctx, in.Business) + if err != nil { + return nil, gosdk.NewError(fmt.Errorf("Delete business %s error:%w", in.Business, err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "delete business failed") + } + return &api.DeleteBusinessResponse{}, nil +} + +func (b *BusinessService) List(ctx context.Context, in *api.ListBusinessRequest) (*api.ListBusinessResponse, error) { + bs, err := b.business.List(ctx, in.Tags, in.Page, in.PageSize) + if err != nil { + return nil, gosdk.NewError(fmt.Errorf("List business error:%w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "list business failed") + } + return &api.ListBusinessResponse{Business: bs}, nil +} +func (b *BusinessService) Desc() *grpc.ServiceDesc { + return &api.BusinessService_ServiceDesc +} diff --git a/internal/service/business_test.go b/internal/service/business_test.go new file mode 100644 index 0000000..ba8c1e9 --- /dev/null +++ b/internal/service/business_test.go @@ -0,0 +1,141 @@ +package service_test + +import ( + "context" + "fmt" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/begonia-org/begonia/internal/service" + "github.com/begonia-org/go-sdk/client" + common "github.com/begonia-org/go-sdk/common/api/v1" + c "github.com/smartystreets/goconvey/convey" + "github.com/spark-lence/tiga" +) + +var bn = "" +var bid = "" + +func testAddBusinessService(t *testing.T) { + c.Convey("test add business", t, func() { + apiClient := client.NewBusinessAPI(apiAddr, accessKey, secret) + snk, _ := tiga.NewSnowflake(2) + bn = fmt.Sprintf("test-service-%s", snk.GenerateIDString()) + rsp, err := apiClient.PostBusiness(context.Background(), bn, "test", []string{"test"}) + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) + c.So(rsp.BusinessId, c.ShouldNotBeEmpty) + bid = rsp.BusinessId + }) + c.Convey("test add business duplicate", t, func() { + apiClient := client.NewBusinessAPI(apiAddr, accessKey, secret) + _, err := apiClient.PostBusiness(context.Background(), bn, "test", []string{"test"}) + c.So(err, c.ShouldNotBeNil) + t.Log(err.Error()) + // c.So(rsp.StatusCode, c.ShouldEqual, common.Code_CONFLICT) + // c.So(err.Error(), c.ShouldContainSubstring, "Duplicate entry") + }) + c.Convey("test no id", t, func() { + patch := gomonkey.ApplyFuncReturn(service.GetIdentity, "") + defer patch.Reset() + apiClient := client.NewBusinessAPI(apiAddr, accessKey, secret) + _, err := apiClient.PostBusiness(context.Background(), bn, "test", []string{"test"}) + c.So(err, c.ShouldNotBeNil) + }) +} + +func testGetBusinessService(t *testing.T) { + apiClient := client.NewBusinessAPI(apiAddr, accessKey, secret) + + c.Convey("test get business by id", t, func() { + rsp, err := apiClient.GetBusiness(context.Background(), bid) + c.So(err, c.ShouldBeNil) + c.So(rsp.BusinessName, c.ShouldEqual, bn) + }) + c.Convey("test get business by name", t, func() { + rsp, err := apiClient.GetBusiness(context.Background(), bn) + c.So(err, c.ShouldBeNil) + c.So(rsp.BusinessName, c.ShouldEqual, bn) + }) + c.Convey("test get business id not found", t, func() { + rsp, err := apiClient.GetBusiness(context.Background(), "not found") + c.So(err, c.ShouldNotBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, int(common.Code_NOT_FOUND)) + }) +} + +func testUpdateBusinessService(t *testing.T) { + apiClient := client.NewBusinessAPI(apiAddr, accessKey, secret) + + c.Convey("test update business", t, func() { + rsp, err := apiClient.PatchBusiness(context.Background(), bid, client.WithPatchParams("description", "update desc")) + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) + c.So(rsp.Description, c.ShouldEqual, "update desc") + c.So(rsp.BusinessName, c.ShouldEqual, bn) + c.So(rsp.Tags[0], c.ShouldEqual, "test") + }) + c.Convey("test update with duplicate name", t, func() { + snk, _ := tiga.NewSnowflake(2) + bn2 := fmt.Sprintf("test-service2-%s", snk.GenerateIDString()) + rsp, err := apiClient.PostBusiness(context.Background(), bn2, "test", []string{"test"}) + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) + // bu := &api.PatchBusinessRequest{} + rsp2, err := apiClient.PatchBusiness(context.Background(), bid, client.WithPatchParams("business_name", bn2), client.WithPatchParams("name", bn2), client.WithPatchParams("description", "update desc")) + c.So(err, c.ShouldNotBeNil) + c.So(rsp2.StatusCode, c.ShouldEqual, int(common.Code_CONFLICT)) + }) + c.Convey("test update with not found id", t, func() { + snk, _ := tiga.NewSnowflake(2) + + rsp, err := apiClient.PatchBusiness(context.Background(), snk.GenerateIDString(), client.WithPatchParams("description", "update desc")) + c.So(err, c.ShouldNotBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, int(common.Code_NOT_FOUND)) + }) +} + +func testListBusinessService(t *testing.T) { + apiClient := client.NewBusinessAPI(apiAddr, accessKey, secret) + + c.Convey("test list business", t, func() { + rsp, err := apiClient.ListBusiness(context.Background(), []string{"test"}, 1, 10) + c.So(err, c.ShouldBeNil) + c.So(len(rsp.Business), c.ShouldBeGreaterThanOrEqualTo, 1) + + rsp, err = apiClient.ListBusiness(context.Background(), []string{"test2"}, 1, 10) + c.So(err, c.ShouldBeNil) + c.So(len(rsp.Business), c.ShouldEqual, 0) + + }) + c.Convey("test list business with error", t, func() { + patch := gomonkey.ApplyFuncReturn(tiga.MySQLDao.Pagination, fmt.Errorf("pagination error")) + defer patch.Reset() + rsp, err := apiClient.ListBusiness(context.Background(), []string{"test"}, 1, 20) + c.So(err, c.ShouldNotBeNil) + c.So(rsp, c.ShouldNotBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, int(common.Code_INTERNAL_ERROR)) + }) +} +func testDeleteBusinessService(t *testing.T) { + apiClient := client.NewBusinessAPI(apiAddr, accessKey, secret) + + c.Convey("test delete business", t, func() { + rsp, err := apiClient.DeleteBusiness(context.Background(), bid) + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) + }) + c.Convey("test delete business with not found id", t, func() { + rsp, err := apiClient.DeleteBusiness(context.Background(), bid) + c.So(err, c.ShouldNotBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, int(common.Code_NOT_FOUND)) + }) +} + +func TestBusinessService(t *testing.T) { + t.Run("test add business", testAddBusinessService) + t.Run("test get business", testGetBusinessService) + t.Run("test update business", testUpdateBusinessService) + t.Run("test list business", testListBusinessService) + t.Run("test delete business", testDeleteBusinessService) +} diff --git a/internal/service/file_test.go b/internal/service/file_test.go index 3ad7ae9..d0e6c36 100644 --- a/internal/service/file_test.go +++ b/internal/service/file_test.go @@ -34,7 +34,8 @@ import ( ) var fileBucket = "" -var localFileId="" +var localFileId = "" + func sumFileSha256(src string) (string, error) { file, err := os.Open(src) if err != nil { @@ -63,25 +64,24 @@ func makeBucket(t *testing.T) { c.Convey("test make bucket", t, func() { fileBucket = fmt.Sprintf("test-service-bucket-%s", time.Now().Format("20060102150405")) apiClient := client.NewFilesAPI(apiAddr, accessKey, secret, api.FileEngine_FILE_ENGINE_LOCAL) - rsp, err := apiClient.CreateBucket(context.Background(), fileBucket, "test", false,true) + rsp, err := apiClient.CreateBucket(context.Background(), fileBucket, "test", false, true) c.So(err, c.ShouldBeNil) c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) + t.Logf("access key:%s", accessKey) minioFile := client.NewFilesAPI(apiAddr, accessKey, secret, api.FileEngine_FILE_ENGINE_MINIO) - rsp, err = minioFile.CreateBucket(context.Background(), fileBucket, "test", false,true) + rsp, err = minioFile.CreateBucket(context.Background(), fileBucket, "test", false, true) c.So(err, c.ShouldBeNil) c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) - // no idend patch := gomonkey.ApplyFuncReturn(service.GetIdentity, "") defer patch.Reset() // apiClient := client.NewFilesAPI(apiAddr, accessKey, secret, api.FileEngine_FILE_ENGINE_LOCAL) - _, err = apiClient.CreateBucket(context.Background(), fileBucket, "test", false,true) + _, err = apiClient.CreateBucket(context.Background(), fileBucket, "test", false, true) c.So(err, c.ShouldNotBeNil) patch.Reset() - }) } func upload(t *testing.T) { @@ -226,7 +226,6 @@ func download(t *testing.T) { t.Log(sha256Str) c.So(sha256Str, c.ShouldEqual, downloadedSha256) - patch2 := gomonkey.ApplyFuncReturn((*service.FileService).Metadata, nil, fmt.Errorf("test metadata error")) defer patch2.Reset() _, err = apiClient.DownloadFile(context.Background(), "test/helloworld.pb", tmp.Name(), "", fileBucket) @@ -236,14 +235,14 @@ func download(t *testing.T) { // query by fid sha256Str, err = apiClient.DownloadFile(context.Background(), localFileId, tmp.Name(), "", fileBucket) - c.So(err,c.ShouldBeNil) - c.So(sha256Str,c.ShouldNotBeEmpty) + c.So(err, c.ShouldBeNil) + c.So(sha256Str, c.ShouldNotBeEmpty) - patch3:=gomonkey.ApplyFuncReturn((*service.FileService).GetFileById, nil, fmt.Errorf("test get file by id error")) + patch3 := gomonkey.ApplyFuncReturn((*service.FileService).GetFileById, nil, fmt.Errorf("test get file by id error")) defer patch3.Reset() _, err = apiClient.DownloadFile(context.Background(), localFileId, tmp.Name(), "", fileBucket) patch3.Reset() - c.So(err,c.ShouldNotBeNil) + c.So(err, c.ShouldNotBeNil) }) } diff --git a/internal/service/service.go b/internal/service/service.go index b3effcc..26746f5 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -24,7 +24,10 @@ var ProviderSet = wire.NewSet(NewAuthzService, NewUserService, NewServices, NewEndpointsService, NewAppService, - NewSysService) + NewSysService, + NewTenantService, + NewBusinessService, +) type ServiceOptions func(*grpc.Server, *runtime.ServeMux, string) error @@ -34,10 +37,20 @@ func NewServices(file file.FileServiceServer, app app.AppsServiceServer, sys sys.SystemServiceServer, users user.UserServiceServer, + business user.BusinessServiceServer, + tenant user.TenantsServiceServer, ) []Service { services := make([]Service, 0) - services = append(services, file.(Service), authz.(Service), ep.(Service), app.(Service), sys.(Service), users.(Service)) + services = append(services, file.(Service), + authz.(Service), + ep.(Service), + app.(Service), + sys.(Service), + users.(Service), + business.(Service), + tenant.(Service), + ) return services } diff --git a/internal/service/tenant.go b/internal/service/tenant.go new file mode 100644 index 0000000..e4dbe93 --- /dev/null +++ b/internal/service/tenant.go @@ -0,0 +1,89 @@ +package service + +import ( + "context" + + "github.com/begonia-org/begonia/internal/biz" + "github.com/begonia-org/begonia/internal/pkg" + "github.com/begonia-org/begonia/internal/pkg/config" + gosdk "github.com/begonia-org/go-sdk" + api "github.com/begonia-org/go-sdk/api/user/v1" + user "github.com/begonia-org/go-sdk/api/user/v1" + "github.com/begonia-org/go-sdk/logger" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +type TenantService struct { + api.UnimplementedTenantsServiceServer + tenant *biz.TenantUsecase + cfg *config.Config + log logger.Logger +} + +func NewTenantService(tenant *biz.TenantUsecase, cfg *config.Config, log logger.Logger) api.TenantsServiceServer { + return &TenantService{tenant: tenant, cfg: cfg, log: log} +} +func (t *TenantService) Register(ctx context.Context, in *api.PostTenantRequest) (*api.Tenants, error) { + // st:=debug.Stack() + // fmt.Printf("TenantService stack:%s\n",st) + identity := GetIdentity(ctx) + if identity == "" { + return nil, gosdk.NewError(pkg.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") + + } + return t.tenant.Add(ctx, in, identity) + +} +func (t *TenantService) Get(ctx context.Context, in *api.GetTenantRequest) (*api.Tenants, error) { + tenant, err := t.tenant.Get(ctx, in.Tenant) + return tenant, err +} +func (t *TenantService) Update(ctx context.Context, in *api.PatchTenantRequest) (*api.Tenants, error) { + + return t.tenant.Update(ctx, in) + +} +func (t *TenantService) List(ctx context.Context, in *api.ListTenantsRequest) (*api.ListTenantsResponse, error) { + tenants, err := t.tenant.List(ctx, in.Tags, in.Status, in.Page, in.PageSize) + if err != nil { + return nil, err + + } + return &api.ListTenantsResponse{Tenants: tenants}, nil + +} +func (t *TenantService) Delete(ctx context.Context, in *api.DeleteTenantRequest) (*api.DeleteTenantResponse, error) { + err := t.tenant.Delete(ctx, in.Tenant) + if err != nil { + return nil, err + } + return &api.DeleteTenantResponse{}, nil +} +func (t *TenantService) AddTenantBusiness(ctx context.Context, in *api.AddTenantBusinessRequest) (*api.TenantsBusiness, error) { + identity := GetIdentity(ctx) + if identity == "" { + return nil, gosdk.NewError(pkg.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") + + } + tb, err := t.tenant.AddTenantBusiness(ctx, in.TenantId, in.BusinessId, in.Plan, identity) + return tb, err + +} +func (t *TenantService) DeleteTenantBusiness(ctx context.Context, in *api.DeleteTenantBusinessRequest) (*api.DeleteTenantBusinessResponse, error) { + err := t.tenant.DelTenantBusiness(ctx, in.TenantId, in.BusinessId) + if err != nil { + return nil, err + } + return &api.DeleteTenantBusinessResponse{}, nil +} +func (t *TenantService) ListTenantBusiness(ctx context.Context, in *api.ListTenantBusinessRequest) (*api.ListTenantBusinessResponse, error) { + tb, err := t.tenant.ListTenantBusiness(ctx, in.Tenant, in.Page, in.PageSize) + if err != nil { + return nil, err + } + return &api.ListTenantBusinessResponse{TenantsBusiness: tb}, nil +} +func (b *TenantService) Desc() *grpc.ServiceDesc { + return &api.TenantsService_ServiceDesc +} diff --git a/internal/service/tenant_test.go b/internal/service/tenant_test.go new file mode 100644 index 0000000..0e3f420 --- /dev/null +++ b/internal/service/tenant_test.go @@ -0,0 +1,170 @@ +package service_test + +import ( + "context" + "fmt" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/begonia-org/begonia/internal/service" + api "github.com/begonia-org/go-sdk/api/user/v1" + "github.com/begonia-org/go-sdk/client" + common "github.com/begonia-org/go-sdk/common/api/v1" + c "github.com/smartystreets/goconvey/convey" + "github.com/spark-lence/tiga" +) + +var tid = "" +var tn = "" +var tbn = "" +var tbid = "" + +func testAddTenant(t *testing.T) { + apiClient := client.NewTenantAPI(apiAddr, accessKey, secret) + snk, _ := tiga.NewSnowflake(2) + + c.Convey("test add tenant", t, func() { + tn = fmt.Sprintf("test-add-tenant-%s", snk.GenerateIDString()) + rsp, err := apiClient.RegisterTenant(context.Background(), tn, "test tenant", fmt.Sprintf("%s@example.com", tn), []string{"test"}) + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, int(common.Code_OK)) + c.So(rsp.TenantId, c.ShouldNotBeEmpty) + tid = rsp.TenantId + + }) + c.Convey("test add tenant no creator", t, func() { + patch := gomonkey.ApplyFuncReturn(service.GetIdentity, "") + defer patch.Reset() + rsp, err := apiClient.RegisterTenant(context.Background(), tn, "test tenant", fmt.Sprintf("%s@example.com", tn), []string{"test"}) + c.So(err, c.ShouldNotBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, int(api.UserSvrCode_USER_IDENTITY_MISSING_ERR)) + + }) +} +func testUpdateTenant(t *testing.T) { + apiClient := client.NewTenantAPI(apiAddr, accessKey, secret) + c.Convey("test update tenant", t, func() { + rsp, err := apiClient.PatchTenant(context.Background(), tid, client.WithPatchParams("description", "update tenant")) + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, int(common.Code_OK)) + c.So(rsp.TenantName, c.ShouldEqual, tn) + c.So(rsp.Description, c.ShouldEqual, "update tenant") + + }) + +} + +func testGetTenant(t *testing.T) { + apiClient := client.NewTenantAPI(apiAddr, accessKey, secret) + c.Convey("test get tenant", t, func() { + rsp, err := apiClient.GetTenant(context.Background(), tid) + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, int(common.Code_OK)) + c.So(rsp.TenantName, c.ShouldEqual, tn) + c.So(rsp.Description, c.ShouldEqual, "update tenant") + }) + +} +func testListTenant(t *testing.T) { + apiClient := client.NewTenantAPI(apiAddr, accessKey, secret) + c.Convey("test list tenant", t, func() { + rsp, err := apiClient.ListTenants(context.Background(), 1, 10, []string{"test"}, []string{api.TENANTS_STATUS_TENANTS_ACTIVE.String()}) + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) + c.So(rsp.Tenants, c.ShouldNotBeEmpty) + c.So(len(rsp.Tenants), c.ShouldBeGreaterThanOrEqualTo, 1) + snk, _ := tiga.NewSnowflake(1) + rsp2, err2 := apiClient.ListTenants(context.Background(), 1, 10, []string{snk.GenerateIDString()}, []string{api.TENANTS_STATUS_TENANTS_ACTIVE.String()}) + c.So(err2, c.ShouldBeNil) + c.So(rsp2.StatusCode, c.ShouldEqual, common.Code_OK) + c.So(rsp2.Tenants, c.ShouldBeEmpty) + c.So(len(rsp2.Tenants), c.ShouldEqual, 0) + + patch := gomonkey.ApplyFuncReturn(tiga.MySQLDao.Pagination, fmt.Errorf("pagination error")) + defer patch.Reset() + rsp3, err3 := apiClient.ListTenants(context.Background(), 1, 10, []string{"test"}, []string{api.TENANTS_STATUS_TENANTS_ACTIVE.String()}) + c.So(err3, c.ShouldNotBeNil) + c.So(rsp3.StatusCode, c.ShouldEqual, int(common.Code_INTERNAL_ERROR)) + }) + +} + +func testAddTenantBusiness(t *testing.T) { + apiClient := client.NewTenantAPI(apiAddr, accessKey, secret) + businessApi := client.NewBusinessAPI(apiAddr, accessKey, secret) + snk, _ := tiga.NewSnowflake(1) + c.Convey("test add tenant business", t, func() { + bRsp, err := businessApi.PostBusiness(context.Background(), fmt.Sprintf("test-business-%s", snk.GenerateIDString()), "test-business", []string{"test-plan"}) + c.So(err, c.ShouldBeNil) + tbid = bRsp.BusinessId + rsp, err := apiClient.AddTenantBusiness(context.Background(), tid, tbid, "test-plan") + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, int(common.Code_OK)) + c.So(rsp.TenantId, c.ShouldEqual, tid) + c.So(rsp.BusinessId, c.ShouldNotBeEmpty) + }) + c.Convey("test add tenant business no creator", t, func() { + patch := gomonkey.ApplyFuncReturn(service.GetIdentity, "") + defer patch.Reset() + rsp, err := apiClient.AddTenantBusiness(context.Background(), tid, tbid, "test-plan") + c.So(err, c.ShouldNotBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, int(api.UserSvrCode_USER_IDENTITY_MISSING_ERR)) + }) +} +func testListTenantBusiness(t *testing.T) { + apiClient := client.NewTenantAPI(apiAddr, accessKey, secret) + c.Convey("test list tenant business", t, func() { + rsp, err := apiClient.ListTenantBusiness(context.Background(), tid, 1, 10) + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) + c.So(rsp.TenantsBusiness, c.ShouldNotBeEmpty) + c.So(len(rsp.TenantsBusiness), c.ShouldBeGreaterThanOrEqualTo, 1) + patch := gomonkey.ApplyFuncReturn(tiga.MySQLDao.Pagination, fmt.Errorf("pagination error")) + defer patch.Reset() + rsp2, err2 := apiClient.ListTenantBusiness(context.Background(), tid, 1, 10) + c.So(err2, c.ShouldNotBeNil) + c.So(rsp2.StatusCode, c.ShouldEqual, int(common.Code_INTERNAL_ERROR)) + }) +} + +func testDeleteTenantBusiness(t *testing.T) { + apiClient := client.NewTenantAPI(apiAddr, accessKey, secret) + c.Convey("test delete tenant business", t, func() { + rsp, err := apiClient.DeleteTenantBusiness(context.Background(), tid, tbid) + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) + }) + c.Convey("test delete tenant error", t, func() { + patch := gomonkey.ApplyFuncReturn(tiga.MySQLDao.UpdateSelectColumns, fmt.Errorf("update delete error")) + defer patch.Reset() + rsp, err := apiClient.DeleteTenantBusiness(context.Background(), tid, tbid) + c.So(err, c.ShouldNotBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, int(common.Code_INTERNAL_ERROR)) + }) +} +func testDeleteTenant(t *testing.T) { + apiClient := client.NewTenantAPI(apiAddr, accessKey, secret) + c.Convey("test delete tenant", t, func() { + rsp, err := apiClient.DeleteTenant(context.Background(), tid) + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) + }) + c.Convey("test delete tenant error", t, func() { + patch := gomonkey.ApplyFuncReturn(tiga.MySQLDao.UpdateSelectColumns, fmt.Errorf("update delete error")) + defer patch.Reset() + rsp, err := apiClient.DeleteTenant(context.Background(), tid) + c.So(err, c.ShouldNotBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, int(common.Code_INTERNAL_ERROR)) + }) +} + +func TestTenant(t *testing.T) { + t.Run("test add tenant", testAddTenant) + t.Run("test update tenant", testUpdateTenant) + t.Run("test get tenant", testGetTenant) + t.Run("test list tenant", testListTenant) + t.Run("test add tenant business", testAddTenantBusiness) + t.Run("test list tenant business", testListTenantBusiness) + t.Run("test delete tenant business", testDeleteTenantBusiness) + t.Run("test delete tenant", testDeleteTenant) +} diff --git a/internal/service/user.go b/internal/service/user.go index 9b8ccc7..467d0a1 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -36,6 +36,7 @@ func (u *UserService) Register(ctx context.Context, in *api.PostUserRequest) (*a Dept: in.Dept, Owner: owner, Avatar: in.Avatar, + TenantId: in.TenantId, } err := u.biz.Add(ctx, user) if err != nil { diff --git a/internal/service/user_test.go b/internal/service/user_test.go index c53dabc..f798c6a 100644 --- a/internal/service/user_test.go +++ b/internal/service/user_test.go @@ -16,6 +16,7 @@ import ( "github.com/begonia-org/go-sdk/client" common "github.com/begonia-org/go-sdk/common/api/v1" c "github.com/smartystreets/goconvey/convey" + "github.com/spark-lence/tiga" ) var uid = "" @@ -26,6 +27,10 @@ func addUser(t *testing.T) { t, func() { apiClient := client.NewUsersAPI(apiAddr, accessKey, secret) + tenantAPI := client.NewTenantAPI(apiAddr, accessKey, secret) + snk, _ := tiga.NewSnowflake(2) + resp, err := tenantAPI.RegisterTenant(context.Background(), snk.GenerateIDString(), "test-tenant", fmt.Sprintf("%s@example.com", snk.GenerateIDString()), []string{"test"}) + c.So(err, c.ShouldBeNil) name := fmt.Sprintf("user-service-test-%s", time.Now().Format("20060102150405")) rsp, err := apiClient.PostUser(context.Background(), &api.PostUserRequest{ Name: name, @@ -36,6 +41,7 @@ func addUser(t *testing.T) { Avatar: "https://www.example.com/avatar.jpg", Owner: "test-user-01", Phone: time.Now().Format("20060102150405"), + TenantId: resp.TenantId, }) c.So(err, c.ShouldBeNil) c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) diff --git a/internal/service/wire.go b/internal/service/wire.go index 0d8fddd..a296652 100644 --- a/internal/service/wire.go +++ b/internal/service/wire.go @@ -11,7 +11,6 @@ import ( app "github.com/begonia-org/go-sdk/api/app/v1" ep "github.com/begonia-org/go-sdk/api/endpoint/v1" file "github.com/begonia-org/go-sdk/api/file/v1" - sys "github.com/begonia-org/go-sdk/api/sys/v1" user "github.com/begonia-org/go-sdk/api/user/v1" "github.com/begonia-org/go-sdk/logger" @@ -32,9 +31,17 @@ func NewEndpointSvrForTest(config *tiga.Configuration, log logger.Logger) ep.End func NewFileSvrForTest(config *tiga.Configuration, log logger.Logger) file.FileServiceServer { panic(wire.Build(biz.ProviderSet, data.ProviderSet, pkg.ProviderSet, NewFileService)) } -func NewSysSvrForTest(config *tiga.Configuration, log logger.Logger) sys.SystemServiceServer { - panic(wire.Build(NewSysService)) -} + +// func NewSysSvrForTest(config *tiga.Configuration, log logger.Logger) sys.SystemServiceServer { +// panic(wire.Build(NewSysService)) +// } func NewUserSvrForTest(config *tiga.Configuration, log logger.Logger) user.UserServiceServer { panic(wire.Build(biz.ProviderSet, pkg.ProviderSet, data.ProviderSet, NewUserService)) } + +// func NewBusinessSvrForTest(config *tiga.Configuration, log logger.Logger) user.BusinessServiceServer { +// panic(wire.Build(biz.ProviderSet, pkg.ProviderSet, data.ProviderSet, NewBusinessService)) +// } +// func NewTenantSvrForTest(config *tiga.Configuration, log logger.Logger) user.TenantsServiceServer { +// panic(wire.Build(biz.ProviderSet, pkg.ProviderSet, data.ProviderSet, NewTenantService)) +// } diff --git a/internal/service/wire_gen.go b/internal/service/wire_gen.go index c703403..df52db5 100644 --- a/internal/service/wire_gen.go +++ b/internal/service/wire_gen.go @@ -16,7 +16,6 @@ import ( v1_2 "github.com/begonia-org/go-sdk/api/app/v1" v1_3 "github.com/begonia-org/go-sdk/api/endpoint/v1" v1_4 "github.com/begonia-org/go-sdk/api/file/v1" - v1_5 "github.com/begonia-org/go-sdk/api/sys/v1" "github.com/begonia-org/go-sdk/api/user/v1" "github.com/begonia-org/go-sdk/logger" "github.com/spark-lence/tiga" @@ -78,11 +77,9 @@ func NewFileSvrForTest(config2 *tiga.Configuration, log logger.Logger) v1_4.File return fileServiceServer } -func NewSysSvrForTest(config2 *tiga.Configuration, log logger.Logger) v1_5.SystemServiceServer { - systemServiceServer := NewSysService() - return systemServiceServer -} - +// func NewSysSvrForTest(config *tiga.Configuration, log logger.Logger) sys.SystemServiceServer { +// panic(wire.Build(NewSysService)) +// } func NewUserSvrForTest(config2 *tiga.Configuration, log logger.Logger) v1.UserServiceServer { mySQLDao := data.NewMySQL(config2) redisDao := data.NewRDB(config2) diff --git a/internal/wire_gen.go b/internal/wire_gen.go index 21fe672..103238e 100644 --- a/internal/wire_gen.go +++ b/internal/wire_gen.go @@ -70,7 +70,13 @@ func New(config2 *tiga.Configuration, log logger.Logger, endpoint2 string) Gatew systemServiceServer := service.NewSysService() userUsecase := biz.NewUserUsecase(userRepo, configConfig) userServiceServer := service.NewUserService(userUsecase, log, configConfig) - v2 := service.NewServices(fileServiceServer, authServiceServer, endpointServiceServer, appsServiceServer, systemServiceServer, userServiceServer) + businessRepo := data.NewBusinessRepoImpl(dataData, curd, configConfig) + businessUsecase := biz.NewBusinessUsecase(businessRepo) + businessServiceServer := service.NewBusinessService(businessUsecase, log, configConfig) + tenantRepo := data.NewTenantRepoImpl(dataData, configConfig, curd) + tenantUsecase := biz.NewTenantUsecase(tenantRepo, businessUsecase, configConfig) + tenantsServiceServer := service.NewTenantService(tenantUsecase, configConfig, log) + v2 := service.NewServices(fileServiceServer, authServiceServer, endpointServiceServer, appsServiceServer, systemServiceServer, userServiceServer, businessServiceServer, tenantsServiceServer) accessKeyAuth := biz.NewAccessKeyAuth(appRepo, configConfig, log) pluginsApply := middleware.New(configConfig, redisDao, authzUsecase, log, accessKeyAuth) gatewayServer := server.NewGateway(gatewayConfig, configConfig, v2, pluginsApply) @@ -167,7 +173,13 @@ func NewWorker(config2 *tiga.Configuration, log logger.Logger, gw string) Gatewa systemServiceServer := service.NewSysService() userUsecase := biz.NewUserUsecase(userRepo, configConfig) userServiceServer := service.NewUserService(userUsecase, log, configConfig) - v2 := service.NewServices(fileServiceServer, authServiceServer, endpointServiceServer, appsServiceServer, systemServiceServer, userServiceServer) + businessRepo := data.NewBusinessRepoImpl(dataData, curd, configConfig) + businessUsecase := biz.NewBusinessUsecase(businessRepo) + businessServiceServer := service.NewBusinessService(businessUsecase, log, configConfig) + tenantRepo := data.NewTenantRepoImpl(dataData, configConfig, curd) + tenantUsecase := biz.NewTenantUsecase(tenantRepo, businessUsecase, configConfig) + tenantsServiceServer := service.NewTenantService(tenantUsecase, configConfig, log) + v2 := service.NewServices(fileServiceServer, authServiceServer, endpointServiceServer, appsServiceServer, systemServiceServer, userServiceServer, businessServiceServer, tenantsServiceServer) accessKeyAuth := biz.NewAccessKeyAuth(appRepo, configConfig, log) pluginsApply := middleware.New(configConfig, redisDao, authzUsecase, log, accessKeyAuth) gatewayServer := server.NewGateway(gatewayConfig, configConfig, v2, pluginsApply) diff --git a/internal/worker.go b/internal/worker.go index 3b412c3..4b87fec 100644 --- a/internal/worker.go +++ b/internal/worker.go @@ -32,5 +32,5 @@ func (g *GatewayWorkerImpl) Start() { g.daemon.Start(context.Background()) time.Sleep(time.Second * 2) g.server.Start() - + } diff --git a/testdata/desc.pb b/testdata/desc.pb index fe34b33..49a730b 100644 Binary files a/testdata/desc.pb and b/testdata/desc.pb differ diff --git a/testdata/gateway.json b/testdata/gateway.json index 14807a4..b829aeb 100644 --- a/testdata/gateway.json +++ b/testdata/gateway.json @@ -1,4 +1,302 @@ { + "/helloworld.Greeter/SayHello": [ + { + "Pattern": {}, + "Template": { + "Version": 1, + "OpCodes": [ + 2, + 0, + 2, + 1, + 2, + 2, + 2, + 3 + ], + "Pool": [ + "api", + "v1", + "example", + "post" + ], + "Verb": "", + "Fields": null, + "Template": "/api/v1/example/post" + }, + "HttpMethod": "POST", + "FullMethodName": "/helloworld.Greeter/SayHello", + "HttpUri": "/api/v1/example/post", + "PathParams": [], + "InName": "HelloRequest", + "OutName": "HelloReply", + "IsClientStream": false, + "IsServerStream": false, + "Pkg": "helloworld", + "InPkg": "helloworld", + "OutPkg": "helloworld" + } + ], + "/helloworld.Greeter/SayHelloBody": [ + { + "Pattern": {}, + "Template": { + "Version": 1, + "OpCodes": [ + 2, + 0, + 2, + 1, + 2, + 2, + 2, + 3 + ], + "Pool": [ + "api", + "v1", + "example", + "body" + ], + "Verb": "", + "Fields": null, + "Template": "/api/v1/example/body" + }, + "HttpMethod": "POST", + "FullMethodName": "/helloworld.Greeter/SayHelloBody", + "HttpUri": "/api/v1/example/body", + "PathParams": [], + "InName": "HttpBody", + "OutName": "HttpBody", + "IsClientStream": false, + "IsServerStream": false, + "Pkg": "helloworld", + "InPkg": "google.api", + "OutPkg": "google.api" + } + ], + "/helloworld.Greeter/SayHelloClientStream": [ + { + "Pattern": {}, + "Template": { + "Version": 1, + "OpCodes": [ + 2, + 0, + 2, + 1, + 2, + 2, + 2, + 3, + 2, + 4 + ], + "Pool": [ + "api", + "v1", + "example", + "client", + "stream" + ], + "Verb": "", + "Fields": null, + "Template": "/api/v1/example/client/stream" + }, + "HttpMethod": "POST", + "FullMethodName": "/helloworld.Greeter/SayHelloClientStream", + "HttpUri": "/api/v1/example/client/stream", + "PathParams": [], + "InName": "HelloRequest", + "OutName": "RepeatedReply", + "IsClientStream": true, + "IsServerStream": false, + "Pkg": "helloworld", + "InPkg": "helloworld", + "OutPkg": "helloworld" + } + ], + "/helloworld.Greeter/SayHelloError": [ + { + "Pattern": {}, + "Template": { + "Version": 1, + "OpCodes": [ + 2, + 0, + 2, + 1, + 2, + 2, + 2, + 3, + 2, + 4 + ], + "Pool": [ + "api", + "v1", + "example", + "error", + "test" + ], + "Verb": "", + "Fields": null, + "Template": "/api/v1/example/error/test" + }, + "HttpMethod": "GET", + "FullMethodName": "/helloworld.Greeter/SayHelloError", + "HttpUri": "/api/v1/example/error/test", + "PathParams": [], + "InName": "ErrorRequest", + "OutName": "HelloReply", + "IsClientStream": false, + "IsServerStream": false, + "Pkg": "helloworld", + "InPkg": "helloworld", + "OutPkg": "helloworld" + } + ], + "/helloworld.Greeter/SayHelloGet": [ + { + "Pattern": {}, + "Template": { + "Version": 1, + "OpCodes": [ + 2, + 0, + 2, + 1, + 2, + 2, + 1, + 0, + 4, + 1, + 5, + 3 + ], + "Pool": [ + "api", + "v1", + "example", + "name" + ], + "Verb": "", + "Fields": [ + "name" + ], + "Template": "/api/v1/example/{name}" + }, + "HttpMethod": "GET", + "FullMethodName": "/helloworld.Greeter/SayHelloGet", + "HttpUri": "/api/v1/example/{name}", + "PathParams": [ + "name" + ], + "InName": "HelloRequest", + "OutName": "HelloReply", + "IsClientStream": false, + "IsServerStream": false, + "Pkg": "helloworld", + "InPkg": "helloworld", + "OutPkg": "helloworld" + } + ], + "/helloworld.Greeter/SayHelloRPC": [], + "/helloworld.Greeter/SayHelloServerSideEvent": [ + { + "Pattern": {}, + "Template": { + "Version": 1, + "OpCodes": [ + 2, + 0, + 2, + 1, + 2, + 2, + 2, + 3, + 2, + 4, + 1, + 0, + 4, + 1, + 5, + 5 + ], + "Pool": [ + "api", + "v1", + "example", + "server", + "sse", + "name" + ], + "Verb": "", + "Fields": [ + "name" + ], + "Template": "/api/v1/example/server/sse/{name}" + }, + "HttpMethod": "GET", + "FullMethodName": "/helloworld.Greeter/SayHelloServerSideEvent", + "HttpUri": "/api/v1/example/server/sse/{name}", + "PathParams": [ + "name" + ], + "InName": "HelloRequest", + "OutName": "HelloReply", + "IsClientStream": false, + "IsServerStream": true, + "Pkg": "helloworld", + "InPkg": "helloworld", + "OutPkg": "helloworld" + } + ], + "/helloworld.Greeter/SayHelloWebsocket": [ + { + "Pattern": {}, + "Template": { + "Version": 1, + "OpCodes": [ + 2, + 0, + 2, + 1, + 2, + 2, + 2, + 3, + 2, + 4 + ], + "Pool": [ + "api", + "v1", + "example", + "server", + "websocket" + ], + "Verb": "", + "Fields": null, + "Template": "/api/v1/example/server/websocket" + }, + "HttpMethod": "GET", + "FullMethodName": "/helloworld.Greeter/SayHelloWebsocket", + "HttpUri": "/api/v1/example/server/websocket", + "PathParams": [], + "InName": "HelloRequest", + "OutName": "HelloReply", + "IsClientStream": true, + "IsServerStream": true, + "Pkg": "helloworld", + "InPkg": "helloworld", + "OutPkg": "helloworld" + } + ], "/integration.TestService/Body": [ { "Pattern": {}, diff --git a/testdata/helloworld.pb b/testdata/helloworld.pb index 82797fb..54336ca 100644 Binary files a/testdata/helloworld.pb and b/testdata/helloworld.pb differ diff --git a/testdata/helloworld.proto b/testdata/helloworld.proto new file mode 100644 index 0000000..58f353c --- /dev/null +++ b/testdata/helloworld.proto @@ -0,0 +1,102 @@ +syntax = "proto3"; + +option go_package = "github.com/begonia-org/begonia/gateway/helloworld"; +option java_multiple_files = true; +option java_package = "io.grpc.examples.helloworld"; +option java_outer_classname = "HelloWorldProto"; +option objc_class_prefix = "HLW"; + +package helloworld; + +import "google/api/annotations.proto"; +import "google/protobuf/field_mask.proto"; +import "google/protobuf/descriptor.proto"; +import "google/api/httpbody.proto"; +// The greeting service definition. +// The greeting service definition. +service Greeter { + // Sends a greeting + rpc SayHello (HelloRequest) returns (HelloReply) { + option (google.api.http)={ + post:"/api/v1/example/post" + body:"*" + }; + } + rpc SayHelloGet (HelloRequest) returns (HelloReply) { + option (google.api.http)={ + get:"/api/v1/example/{name}" + }; + } + rpc SayHelloRPC (HelloRequest) returns (HelloReply) { +} + + rpc SayHelloServerSideEvent (HelloRequest) returns (stream HelloReply) { + option (google.api.http) = { + get: "/api/v1/example/server/sse/{name}" + }; + } + + rpc SayHelloClientStream (stream HelloRequest) returns (RepeatedReply) { + option (google.api.http) = { + post: "/api/v1/example/client/stream" + body: "*" + }; + + } + rpc SayHelloWebsocket (stream HelloRequest) returns (stream HelloReply) { + option (google.api.http) = { + get: "/api/v1/example/server/websocket" + }; + } + rpc SayHelloBody (google.api.HttpBody) returns (google.api.HttpBody) { + option (google.api.http) = { + post: "/api/v1/example/body" + }; + } + rpc SayHelloError (ErrorRequest) returns (HelloReply) { + option (google.api.http) = { + get: "/api/v1/example/error/test" + }; + } + +} + +message ErrorRequest { + string msg = 1; + int32 code = 2; +} +// The request message containing the user's name. +message HelloRequest { + string msg = 1; + string name = 2; +} + +// The response message containing the greetings +message HelloReply { + string message = 1; + string name = 2; +} +message RepeatedReply{ + repeated HelloReply replies = 1; +} +enum EnumAllow { + ALLOW = 0; + DENY = 1; +} +message ExampleMessage { + string message = 1; + int32 code = 2; + double float_num = 3[json_name="float_num"]; + float float_data = 14[json_name="float_data"]; + bytes byte_data = 4[json_name="byte_data"]; + bool bool_data = 5[json_name="bool_data"]; + EnumAllow allow = 6; + repeated int64 repeated_data = 7[json_name="repeated_data"]; + fixed64 fixed_data = 8[json_name="fixed_data"]; + sfixed64 sfixed_data = 9[json_name="sfixed_data"]; + sfixed32 sfixed32_data = 10[json_name="sfixed32_data"]; + fixed32 fixed32_data = 11[json_name="fixed32_data"]; + repeated HelloRequest repeated_msg = 12[json_name="repeated_msg"]; + HelloRequest msg = 13; + google.protobuf.FieldMask mask = 18; +} \ No newline at end of file diff --git a/testdata/options.proto b/testdata/options.proto index d215d6c..8c442bb 100644 --- a/testdata/options.proto +++ b/testdata/options.proto @@ -12,7 +12,7 @@ import "google/protobuf/descriptor.proto"; } extend google.protobuf.FieldOptions { - optional bool jsontag = 50035; + optional bool un_updatable = 50039; } @@ -26,4 +26,8 @@ extend google.protobuf.FileOptions { } extend google.protobuf.ServiceOptions { optional string http_response = 50038; -} \ No newline at end of file +} + +extend google.protobuf.MethodOptions { + bool dont_use_http_response = 50041; + } \ No newline at end of file diff --git a/testdata/test.proto b/testdata/test.proto index 4c5140f..b19af55 100644 --- a/testdata/test.proto +++ b/testdata/test.proto @@ -56,10 +56,13 @@ service TestService{ body: "*" }; } - rpc Body(TestRequest) returns (google.api.HttpBody){ + + rpc Body(TestRequest) returns (google.api.HttpBody){ option (google.api.http) = { get: "/test/body" }; + option (begonia.org.sdk.common.dont_use_http_response) = true; + } rpc Custom(TestRequest) returns (TestRequest) { // 使用 HttpRule_Custom 来定义 HTTP 映射规则