Skip to content

Commit c1255ad

Browse files
committed
test grpc basic auth propagation
1 parent 0c96286 commit c1255ad

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

src/transports/grpc/grpc_transport.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package grpc
22

33
import (
44
"context"
5+
"encoding/base64"
56
"errors"
67
"fmt"
78
"io"
@@ -21,9 +22,22 @@ import (
2122
. "github.com/universal-tool-calling-protocol/go-utcp/src/providers/grpc"
2223
"github.com/universal-tool-calling-protocol/go-utcp/src/transports"
2324

25+
. "github.com/universal-tool-calling-protocol/go-utcp/src/auth"
2426
. "github.com/universal-tool-calling-protocol/go-utcp/src/tools"
2527
)
2628

29+
type basicAuthCreds struct {
30+
username string
31+
password string
32+
}
33+
34+
func (b *basicAuthCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
35+
token := base64.StdEncoding.EncodeToString([]byte(b.username + ":" + b.password))
36+
return map[string]string{"authorization": "Basic " + token}, nil
37+
}
38+
39+
func (b *basicAuthCreds) RequireTransportSecurity() bool { return false }
40+
2741
// GRPCClientTransport implements ClientTransport over gRPC using the UTCPService.
2842
// It expects the remote server to implement the grpcpb.UTCPService service.
2943
type GRPCClientTransport struct {
@@ -60,6 +74,14 @@ func (t *GRPCClientTransport) dial(ctx context.Context, prov *GRPCProvider) (*gr
6074
t.logger("Using target '%s' as gRPC authority", prov.Target)
6175
}
6276

77+
if prov.Auth != nil {
78+
authIfc := *prov.Auth
79+
switch a := authIfc.(type) {
80+
case *BasicAuth:
81+
opts = append(opts, grpc.WithPerRPCCredentials(&basicAuthCreds{username: a.Username, password: a.Password}))
82+
}
83+
}
84+
6385
if prov.UseSSL {
6486
// In this example we just use insecure when UseSSL is false.
6587
// Real implementation would configure TLS credentials.
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package grpc
2+
3+
import (
4+
"context"
5+
"encoding/base64"
6+
"net"
7+
"testing"
8+
9+
. "github.com/universal-tool-calling-protocol/go-utcp/src/auth"
10+
"github.com/universal-tool-calling-protocol/go-utcp/src/grpcpb"
11+
. "github.com/universal-tool-calling-protocol/go-utcp/src/providers/base"
12+
. "github.com/universal-tool-calling-protocol/go-utcp/src/providers/grpc"
13+
14+
"google.golang.org/grpc"
15+
"google.golang.org/grpc/codes"
16+
"google.golang.org/grpc/metadata"
17+
"google.golang.org/grpc/status"
18+
)
19+
20+
type authServer struct {
21+
grpcpb.UnimplementedUTCPServiceServer
22+
}
23+
24+
func (s *authServer) GetManual(ctx context.Context, _ *grpcpb.Empty) (*grpcpb.Manual, error) {
25+
return &grpcpb.Manual{}, nil
26+
}
27+
28+
func TestGRPCClientTransport_BasicAuth(t *testing.T) {
29+
const user = "u"
30+
const pass = "p"
31+
expected := "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
32+
33+
interceptor := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
34+
md, _ := metadata.FromIncomingContext(ctx)
35+
if md == nil {
36+
return nil, status.Error(codes.Unauthenticated, "missing metadata")
37+
}
38+
auths := md.Get("authorization")
39+
if len(auths) == 0 || auths[0] != expected {
40+
return nil, status.Error(codes.Unauthenticated, "bad creds")
41+
}
42+
return handler(ctx, req)
43+
}
44+
45+
lis, err := net.Listen("tcp", "127.0.0.1:0")
46+
if err != nil {
47+
t.Fatalf("listen err: %v", err)
48+
}
49+
srv := grpc.NewServer(grpc.UnaryInterceptor(interceptor))
50+
grpcpb.RegisterUTCPServiceServer(srv, &authServer{})
51+
go srv.Serve(lis)
52+
defer srv.Stop()
53+
54+
port := lis.Addr().(*net.TCPAddr).Port
55+
ba := &BasicAuth{AuthType: BasicType, Username: user, Password: pass}
56+
var a Auth = ba
57+
prov := &GRPCProvider{BaseProvider: BaseProvider{Name: "g", ProviderType: ProviderGRPC}, Host: "127.0.0.1", Port: port, Auth: &a}
58+
59+
tr := NewGRPCClientTransport(nil)
60+
if _, err := tr.RegisterToolProvider(context.Background(), prov); err != nil {
61+
t.Fatalf("RegisterToolProvider err: %v", err)
62+
}
63+
}

0 commit comments

Comments
 (0)