Skip to content

Commit 2b35fa5

Browse files
authored
credentials/tls: Revert removal of ALPN flag from #8660 (#8664)
Original PR: #8660 This reverts commit 0037c61. ## Why There are internal users of this flag that need to be updated. Internal issue to track removal: b/454048967. RELEASE NOTES: N/A
1 parent f448a97 commit 2b35fa5

File tree

4 files changed

+302
-170
lines changed

4 files changed

+302
-170
lines changed

credentials/tls.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,15 @@ import (
2828
"net/url"
2929
"os"
3030

31+
"google.golang.org/grpc/grpclog"
3132
credinternal "google.golang.org/grpc/internal/credentials"
33+
"google.golang.org/grpc/internal/envconfig"
3234
)
3335

3436
const alpnFailureHelpMessage = "If you upgraded from a grpc-go version earlier than 1.67, your TLS connections may have stopped working due to ALPN enforcement. For more details, see: https://github.com/grpc/grpc-go/issues/434"
3537

38+
var logger = grpclog.Component("credentials")
39+
3640
// TLSInfo contains the auth information for a TLS authenticated connection.
3741
// It implements the AuthInfo interface.
3842
type TLSInfo struct {
@@ -140,8 +144,11 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
140144
// for using HTTP/2 over TLS. We can terminate the connection immediately.
141145
np := conn.ConnectionState().NegotiatedProtocol
142146
if np == "" {
143-
conn.Close()
144-
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
147+
if envconfig.EnforceALPNEnabled {
148+
conn.Close()
149+
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
150+
}
151+
logger.Warningf("Allowing TLS connection to server %q with ALPN disabled. TLS connections to servers with ALPN disabled will be disallowed in future grpc-go releases", cfg.ServerName)
145152
}
146153
tlsInfo := TLSInfo{
147154
State: conn.ConnectionState(),
@@ -167,8 +174,12 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
167174
// support ALPN. In such cases, we can close the connection since ALPN is required
168175
// for using HTTP/2 over TLS.
169176
if cs.NegotiatedProtocol == "" {
170-
conn.Close()
171-
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
177+
if envconfig.EnforceALPNEnabled {
178+
conn.Close()
179+
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
180+
} else if logger.V(2) {
181+
logger.Info("Allowing TLS connection from client with ALPN disabled. TLS connections with ALPN disabled will be disallowed in future grpc-go releases")
182+
}
172183
}
173184
tlsInfo := TLSInfo{
174185
State: cs,

credentials/tls_ext_test.go

Lines changed: 142 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232
"google.golang.org/grpc"
3333
"google.golang.org/grpc/codes"
3434
"google.golang.org/grpc/credentials"
35+
"google.golang.org/grpc/internal/envconfig"
3536
"google.golang.org/grpc/internal/grpctest"
3637
"google.golang.org/grpc/internal/stubserver"
3738
"google.golang.org/grpc/status"
@@ -410,6 +411,12 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) {
410411
// correctly for a server that doesn't specify the NextProtos field and uses
411412
// GetConfigForClient to provide the TLS config during the handshake.
412413
func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) {
414+
initialVal := envconfig.EnforceALPNEnabled
415+
defer func() {
416+
envconfig.EnforceALPNEnabled = initialVal
417+
}()
418+
envconfig.EnforceALPNEnabled = true
419+
413420
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
414421
defer cancel()
415422

@@ -446,104 +453,156 @@ func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) {
446453
// TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
447454
// connecting to a server that doesn't support ALPN.
448455
func (s) TestTLS_DisabledALPNClient(t *testing.T) {
449-
listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
450-
Certificates: []tls.Certificate{serverCert},
451-
NextProtos: []string{}, // Empty list indicates ALPN is disabled.
452-
})
453-
if err != nil {
454-
t.Fatalf("Error starting TLS server: %v", err)
455-
}
456-
457-
errCh := make(chan error, 1)
458-
go func() {
459-
conn, err := listener.Accept()
460-
if err != nil {
461-
errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
462-
} else {
463-
// The first write to the TLS listener initiates the TLS handshake.
464-
conn.Write([]byte("Hello, World!"))
465-
conn.Close()
466-
}
467-
close(errCh)
456+
initialVal := envconfig.EnforceALPNEnabled
457+
defer func() {
458+
envconfig.EnforceALPNEnabled = initialVal
468459
}()
469460

470-
serverAddr := listener.Addr().String()
471-
conn, err := net.Dial("tcp", serverAddr)
472-
if err != nil {
473-
t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
461+
tests := []struct {
462+
name string
463+
alpnEnforced bool
464+
wantErr bool
465+
}{
466+
{
467+
name: "enforced",
468+
alpnEnforced: true,
469+
wantErr: true,
470+
},
471+
{
472+
name: "not_enforced",
473+
},
474474
}
475-
defer conn.Close()
476475

477-
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
478-
defer cancel()
476+
for _, tc := range tests {
477+
t.Run(tc.name, func(t *testing.T) {
478+
envconfig.EnforceALPNEnabled = tc.alpnEnforced
479479

480-
clientCfg := tls.Config{
481-
ServerName: serverName,
482-
RootCAs: certPool,
483-
NextProtos: []string{"h2"},
484-
}
485-
_, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn)
480+
listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
481+
Certificates: []tls.Certificate{serverCert},
482+
NextProtos: []string{}, // Empty list indicates ALPN is disabled.
483+
})
484+
if err != nil {
485+
t.Fatalf("Error starting TLS server: %v", err)
486+
}
486487

487-
if gotErr, wantErr := (err != nil), true; gotErr != wantErr {
488-
t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, wantErr)
489-
}
488+
errCh := make(chan error, 1)
489+
go func() {
490+
conn, err := listener.Accept()
491+
if err != nil {
492+
errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
493+
} else {
494+
// The first write to the TLS listener initiates the TLS handshake.
495+
conn.Write([]byte("Hello, World!"))
496+
conn.Close()
497+
}
498+
close(errCh)
499+
}()
500+
501+
serverAddr := listener.Addr().String()
502+
conn, err := net.Dial("tcp", serverAddr)
503+
if err != nil {
504+
t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
505+
}
506+
defer conn.Close()
507+
508+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
509+
defer cancel()
510+
511+
clientCfg := tls.Config{
512+
ServerName: serverName,
513+
RootCAs: certPool,
514+
NextProtos: []string{"h2"},
515+
}
516+
_, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn)
517+
518+
if gotErr := (err != nil); gotErr != tc.wantErr {
519+
t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
520+
}
490521

491-
select {
492-
case err := <-errCh:
493-
if err != nil {
494-
t.Fatalf("Unexpected error received from server: %v", err)
495-
}
496-
case <-ctx.Done():
497-
t.Fatalf("Timeout waiting for error from server")
522+
select {
523+
case err := <-errCh:
524+
if err != nil {
525+
t.Fatalf("Unexpected error received from server: %v", err)
526+
}
527+
case <-ctx.Done():
528+
t.Fatalf("Timeout waiting for error from server")
529+
}
530+
})
498531
}
499532
}
500533

501534
// TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
502535
// accepting a request from a client that doesn't support ALPN.
503536
func (s) TestTLS_DisabledALPNServer(t *testing.T) {
504-
listener, err := net.Listen("tcp", "localhost:0")
505-
if err != nil {
506-
t.Fatalf("Error starting server: %v", err)
507-
}
508-
509-
errCh := make(chan error, 1)
510-
go func() {
511-
conn, err := listener.Accept()
512-
if err != nil {
513-
errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
514-
return
515-
}
516-
defer conn.Close()
517-
serverCfg := tls.Config{
518-
Certificates: []tls.Certificate{serverCert},
519-
NextProtos: []string{"h2"},
520-
}
521-
_, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn)
522-
if gotErr, wantErr := (err != nil), true; gotErr != wantErr {
523-
t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, wantErr)
524-
}
525-
close(errCh)
537+
initialVal := envconfig.EnforceALPNEnabled
538+
defer func() {
539+
envconfig.EnforceALPNEnabled = initialVal
526540
}()
527541

528-
serverAddr := listener.Addr().String()
529-
clientCfg := &tls.Config{
530-
Certificates: []tls.Certificate{serverCert},
531-
NextProtos: []string{}, // Empty list indicates ALPN is disabled.
532-
RootCAs: certPool,
533-
ServerName: serverName,
534-
}
535-
conn, err := tls.Dial("tcp", serverAddr, clientCfg)
536-
if err != nil {
537-
t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
542+
tests := []struct {
543+
name string
544+
alpnEnforced bool
545+
wantErr bool
546+
}{
547+
{
548+
name: "enforced",
549+
alpnEnforced: true,
550+
wantErr: true,
551+
},
552+
{
553+
name: "not_enforced",
554+
},
538555
}
539-
defer conn.Close()
540-
541-
select {
542-
case <-time.After(defaultTestTimeout):
543-
t.Fatal("Timed out waiting for completion")
544-
case err := <-errCh:
545-
if err != nil {
546-
t.Fatalf("Unexpected server error: %v", err)
547-
}
556+
557+
for _, tc := range tests {
558+
t.Run(tc.name, func(t *testing.T) {
559+
envconfig.EnforceALPNEnabled = tc.alpnEnforced
560+
561+
listener, err := net.Listen("tcp", "localhost:0")
562+
if err != nil {
563+
t.Fatalf("Error starting server: %v", err)
564+
}
565+
566+
errCh := make(chan error, 1)
567+
go func() {
568+
conn, err := listener.Accept()
569+
if err != nil {
570+
errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
571+
return
572+
}
573+
defer conn.Close()
574+
serverCfg := tls.Config{
575+
Certificates: []tls.Certificate{serverCert},
576+
NextProtos: []string{"h2"},
577+
}
578+
_, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn)
579+
if gotErr := (err != nil); gotErr != tc.wantErr {
580+
t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
581+
}
582+
close(errCh)
583+
}()
584+
585+
serverAddr := listener.Addr().String()
586+
clientCfg := &tls.Config{
587+
Certificates: []tls.Certificate{serverCert},
588+
NextProtos: []string{}, // Empty list indicates ALPN is disabled.
589+
RootCAs: certPool,
590+
ServerName: serverName,
591+
}
592+
conn, err := tls.Dial("tcp", serverAddr, clientCfg)
593+
if err != nil {
594+
t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
595+
}
596+
defer conn.Close()
597+
598+
select {
599+
case <-time.After(defaultTestTimeout):
600+
t.Fatal("Timed out waiting for completion")
601+
case err := <-errCh:
602+
if err != nil {
603+
t.Fatalf("Unexpected server error: %v", err)
604+
}
605+
}
606+
})
548607
}
549608
}

0 commit comments

Comments
 (0)