@@ -32,6 +32,7 @@ import (
3232 "google.golang.org/grpc"
3333 "google.golang.org/grpc/codes"
3434 "google.golang.org/grpc/credentials"
35+ "google.golang.org/grpc/internal/envconfig"
3536 "google.golang.org/grpc/internal/grpctest"
3637 "google.golang.org/grpc/internal/stubserver"
3738 "google.golang.org/grpc/status"
@@ -410,6 +411,12 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) {
410411// correctly for a server that doesn't specify the NextProtos field and uses
411412// GetConfigForClient to provide the TLS config during the handshake.
412413func (s ) TestTLS_ServerConfiguresALPNByDefault (t * testing.T ) {
414+ initialVal := envconfig .EnforceALPNEnabled
415+ defer func () {
416+ envconfig .EnforceALPNEnabled = initialVal
417+ }()
418+ envconfig .EnforceALPNEnabled = true
419+
413420 ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
414421 defer cancel ()
415422
@@ -446,104 +453,156 @@ func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) {
446453// TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
447454// connecting to a server that doesn't support ALPN.
448455func (s ) TestTLS_DisabledALPNClient (t * testing.T ) {
449- listener , err := tls .Listen ("tcp" , "localhost:0" , & tls.Config {
450- Certificates : []tls.Certificate {serverCert },
451- NextProtos : []string {}, // Empty list indicates ALPN is disabled.
452- })
453- if err != nil {
454- t .Fatalf ("Error starting TLS server: %v" , err )
455- }
456-
457- errCh := make (chan error , 1 )
458- go func () {
459- conn , err := listener .Accept ()
460- if err != nil {
461- errCh <- fmt .Errorf ("listener.Accept returned error: %v" , err )
462- } else {
463- // The first write to the TLS listener initiates the TLS handshake.
464- conn .Write ([]byte ("Hello, World!" ))
465- conn .Close ()
466- }
467- close (errCh )
456+ initialVal := envconfig .EnforceALPNEnabled
457+ defer func () {
458+ envconfig .EnforceALPNEnabled = initialVal
468459 }()
469460
470- serverAddr := listener .Addr ().String ()
471- conn , err := net .Dial ("tcp" , serverAddr )
472- if err != nil {
473- t .Fatalf ("net.Dial(%s) failed: %v" , serverAddr , err )
461+ tests := []struct {
462+ name string
463+ alpnEnforced bool
464+ wantErr bool
465+ }{
466+ {
467+ name : "enforced" ,
468+ alpnEnforced : true ,
469+ wantErr : true ,
470+ },
471+ {
472+ name : "not_enforced" ,
473+ },
474474 }
475- defer conn .Close ()
476475
477- ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
478- defer cancel ()
476+ for _ , tc := range tests {
477+ t .Run (tc .name , func (t * testing.T ) {
478+ envconfig .EnforceALPNEnabled = tc .alpnEnforced
479479
480- clientCfg := tls.Config {
481- ServerName : serverName ,
482- RootCAs : certPool ,
483- NextProtos : []string {"h2" },
484- }
485- _ , _ , err = credentials .NewTLS (& clientCfg ).ClientHandshake (ctx , serverName , conn )
480+ listener , err := tls .Listen ("tcp" , "localhost:0" , & tls.Config {
481+ Certificates : []tls.Certificate {serverCert },
482+ NextProtos : []string {}, // Empty list indicates ALPN is disabled.
483+ })
484+ if err != nil {
485+ t .Fatalf ("Error starting TLS server: %v" , err )
486+ }
486487
487- if gotErr , wantErr := (err != nil ), true ; gotErr != wantErr {
488- t .Errorf ("ClientHandshake returned unexpected error: got=%v, want=%t" , err , wantErr )
489- }
488+ errCh := make (chan error , 1 )
489+ go func () {
490+ conn , err := listener .Accept ()
491+ if err != nil {
492+ errCh <- fmt .Errorf ("listener.Accept returned error: %v" , err )
493+ } else {
494+ // The first write to the TLS listener initiates the TLS handshake.
495+ conn .Write ([]byte ("Hello, World!" ))
496+ conn .Close ()
497+ }
498+ close (errCh )
499+ }()
500+
501+ serverAddr := listener .Addr ().String ()
502+ conn , err := net .Dial ("tcp" , serverAddr )
503+ if err != nil {
504+ t .Fatalf ("net.Dial(%s) failed: %v" , serverAddr , err )
505+ }
506+ defer conn .Close ()
507+
508+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
509+ defer cancel ()
510+
511+ clientCfg := tls.Config {
512+ ServerName : serverName ,
513+ RootCAs : certPool ,
514+ NextProtos : []string {"h2" },
515+ }
516+ _ , _ , err = credentials .NewTLS (& clientCfg ).ClientHandshake (ctx , serverName , conn )
517+
518+ if gotErr := (err != nil ); gotErr != tc .wantErr {
519+ t .Errorf ("ClientHandshake returned unexpected error: got=%v, want=%t" , err , tc .wantErr )
520+ }
490521
491- select {
492- case err := <- errCh :
493- if err != nil {
494- t .Fatalf ("Unexpected error received from server: %v" , err )
495- }
496- case <- ctx .Done ():
497- t .Fatalf ("Timeout waiting for error from server" )
522+ select {
523+ case err := <- errCh :
524+ if err != nil {
525+ t .Fatalf ("Unexpected error received from server: %v" , err )
526+ }
527+ case <- ctx .Done ():
528+ t .Fatalf ("Timeout waiting for error from server" )
529+ }
530+ })
498531 }
499532}
500533
501534// TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
502535// accepting a request from a client that doesn't support ALPN.
503536func (s ) TestTLS_DisabledALPNServer (t * testing.T ) {
504- listener , err := net .Listen ("tcp" , "localhost:0" )
505- if err != nil {
506- t .Fatalf ("Error starting server: %v" , err )
507- }
508-
509- errCh := make (chan error , 1 )
510- go func () {
511- conn , err := listener .Accept ()
512- if err != nil {
513- errCh <- fmt .Errorf ("listener.Accept returned error: %v" , err )
514- return
515- }
516- defer conn .Close ()
517- serverCfg := tls.Config {
518- Certificates : []tls.Certificate {serverCert },
519- NextProtos : []string {"h2" },
520- }
521- _ , _ , err = credentials .NewTLS (& serverCfg ).ServerHandshake (conn )
522- if gotErr , wantErr := (err != nil ), true ; gotErr != wantErr {
523- t .Errorf ("ServerHandshake returned unexpected error: got=%v, want=%t" , err , wantErr )
524- }
525- close (errCh )
537+ initialVal := envconfig .EnforceALPNEnabled
538+ defer func () {
539+ envconfig .EnforceALPNEnabled = initialVal
526540 }()
527541
528- serverAddr := listener .Addr ().String ()
529- clientCfg := & tls.Config {
530- Certificates : []tls.Certificate {serverCert },
531- NextProtos : []string {}, // Empty list indicates ALPN is disabled.
532- RootCAs : certPool ,
533- ServerName : serverName ,
534- }
535- conn , err := tls .Dial ("tcp" , serverAddr , clientCfg )
536- if err != nil {
537- t .Fatalf ("tls.Dial(%s) failed: %v" , serverAddr , err )
542+ tests := []struct {
543+ name string
544+ alpnEnforced bool
545+ wantErr bool
546+ }{
547+ {
548+ name : "enforced" ,
549+ alpnEnforced : true ,
550+ wantErr : true ,
551+ },
552+ {
553+ name : "not_enforced" ,
554+ },
538555 }
539- defer conn .Close ()
540-
541- select {
542- case <- time .After (defaultTestTimeout ):
543- t .Fatal ("Timed out waiting for completion" )
544- case err := <- errCh :
545- if err != nil {
546- t .Fatalf ("Unexpected server error: %v" , err )
547- }
556+
557+ for _ , tc := range tests {
558+ t .Run (tc .name , func (t * testing.T ) {
559+ envconfig .EnforceALPNEnabled = tc .alpnEnforced
560+
561+ listener , err := net .Listen ("tcp" , "localhost:0" )
562+ if err != nil {
563+ t .Fatalf ("Error starting server: %v" , err )
564+ }
565+
566+ errCh := make (chan error , 1 )
567+ go func () {
568+ conn , err := listener .Accept ()
569+ if err != nil {
570+ errCh <- fmt .Errorf ("listener.Accept returned error: %v" , err )
571+ return
572+ }
573+ defer conn .Close ()
574+ serverCfg := tls.Config {
575+ Certificates : []tls.Certificate {serverCert },
576+ NextProtos : []string {"h2" },
577+ }
578+ _ , _ , err = credentials .NewTLS (& serverCfg ).ServerHandshake (conn )
579+ if gotErr := (err != nil ); gotErr != tc .wantErr {
580+ t .Errorf ("ServerHandshake returned unexpected error: got=%v, want=%t" , err , tc .wantErr )
581+ }
582+ close (errCh )
583+ }()
584+
585+ serverAddr := listener .Addr ().String ()
586+ clientCfg := & tls.Config {
587+ Certificates : []tls.Certificate {serverCert },
588+ NextProtos : []string {}, // Empty list indicates ALPN is disabled.
589+ RootCAs : certPool ,
590+ ServerName : serverName ,
591+ }
592+ conn , err := tls .Dial ("tcp" , serverAddr , clientCfg )
593+ if err != nil {
594+ t .Fatalf ("tls.Dial(%s) failed: %v" , serverAddr , err )
595+ }
596+ defer conn .Close ()
597+
598+ select {
599+ case <- time .After (defaultTestTimeout ):
600+ t .Fatal ("Timed out waiting for completion" )
601+ case err := <- errCh :
602+ if err != nil {
603+ t .Fatalf ("Unexpected server error: %v" , err )
604+ }
605+ }
606+ })
548607 }
549608}
0 commit comments