@@ -27,8 +27,11 @@ import (
2727 "net"
2828 "reflect"
2929 "strings"
30+ "syscall"
3031 "testing"
32+ "time"
3133
34+ "golang.org/x/sys/unix"
3235 core "google.golang.org/grpc/credentials/alts/internal"
3336 "google.golang.org/grpc/internal/grpctest"
3437)
@@ -105,6 +108,94 @@ func newConnPair(rp string, clientProtected []byte, serverProtected []byte) (cli
105108 return clientConn , serverConn
106109}
107110
111+ // newTCPConnPair returns a pair of conns backed by TCP over loopback.
112+ func newTCPConnPair (rp string , clientProtected []byte , serverProtected []byte ) (* conn , * conn , error ) {
113+ const address = "localhost:50935"
114+
115+ // Start the server.
116+ serverChan := make (chan net.Conn )
117+ listenChan := make (chan struct {})
118+ go func () {
119+ listener , err := net .Listen ("tcp4" , address )
120+ if err != nil {
121+ panic (fmt .Sprintf ("failed to listen: %v" , err ))
122+ }
123+ defer listener .Close ()
124+ listenChan <- struct {}{}
125+ conn , err := listener .Accept ()
126+ if err != nil {
127+ panic (fmt .Sprintf ("failed to aceept: %v" , err ))
128+ }
129+ serverChan <- conn
130+ }()
131+
132+ // Ensure the server is listening before trying to connect.
133+ <- listenChan
134+ clientTCP , err := net .DialTimeout ("tcp4" , address , 5 * time .Second )
135+ if err != nil {
136+ return nil , nil , fmt .Errorf ("failed to Dial: %w" , err )
137+ }
138+
139+ // Get the server-side connection returned by Accept().
140+ var serverTCP net.Conn
141+ select {
142+ case serverTCP = <- serverChan :
143+ case <- time .After (5 * time .Second ):
144+ return nil , nil , fmt .Errorf ("timed out waiting for server conn" )
145+ }
146+
147+ // Make the connection behave a little bit like a real one by imposing
148+ // an MTU.
149+ clientTCP = & mtuConn {clientTCP , 1500 }
150+
151+ // 16 arbitrary bytes.
152+ key := []byte {
153+ 0x1f , 0x8b , 0x08 , 0x00 , 0x00 , 0x09 , 0x6e , 0x88 ,
154+ 0x02 , 0xff , 0xe2 , 0xd2 , 0x4c , 0xce , 0x4f , 0x49 ,
155+ }
156+
157+ client , err := NewConn (clientTCP , core .ClientSide , rp , key , clientProtected )
158+ if err != nil {
159+ panic (fmt .Sprintf ("Unexpected error creating test ALTS record connection: %v" , err ))
160+ }
161+ server , err := NewConn (serverTCP , core .ServerSide , rp , key , serverProtected )
162+ if err != nil {
163+ panic (fmt .Sprintf ("Unexpected error creating test ALTS record connection: %v" , err ))
164+ }
165+
166+ return client .(* conn ), server .(* conn ), nil
167+ }
168+
169+ // mtuConn imposes an MTU on writes. It simulates an important quality of real
170+ // network traffic that is lost when using loopback devices. On loopback, even
171+ // large messages (e.g. 512 KiB) when written often arrive at the receiver
172+ // instantaneously as a single payload. By explicitly splitting such writes into
173+ // smaller, MTU-sized paylaods we give the receiver a chance to respond to
174+ // smaller message sizes.
175+ type mtuConn struct {
176+ net.Conn
177+ mtu int
178+ }
179+
180+ // Write implements net.Conn.
181+ func (rc * mtuConn ) Write (buf []byte ) (int , error ) {
182+ var written int
183+ for len (buf ) > 0 {
184+ n , err := rc .Conn .Write (buf [:min (rc .mtu , len (buf ))])
185+ written += n
186+ if err != nil {
187+ return written , err
188+ }
189+ buf = buf [n :]
190+ }
191+ return written , nil
192+ }
193+
194+ // SyscallConn implements syscall.Conn.
195+ func (rc * mtuConn ) SycallConn () (syscall.RawConn , error ) {
196+ return rc .Conn .(syscall.Conn ).SyscallConn ()
197+ }
198+
108199func testPingPong (t * testing.T , rp string ) {
109200 clientConn , serverConn := newConnPair (rp , nil , nil )
110201 clientMsg := []byte ("Client Message" )
@@ -231,6 +322,117 @@ func BenchmarkLargeMessage(b *testing.B) {
231322 }
232323}
233324
325+ // BenchmarkTCP is a simple throughput test that sends payloads over a local TCP
326+ // connection. Run via:
327+ //
328+ // go test -run="^$" -bench="BenchmarkTCP" ./credentials/alts/internal/conn
329+ func BenchmarkTCP (b * testing.B ) {
330+ tcs := []struct {
331+ name string
332+ size int
333+ }{
334+ {"1 KiB" , 1024 },
335+ {"4 KiB" , 4 * 1024 },
336+ {"64 KiB" , 64 * 1024 },
337+ {"512 KiB" , 512 * 1024 },
338+ {"1 MiB" , 1024 * 1024 },
339+ {"4 MiB" , 4 * 1024 * 1024 },
340+ }
341+ for _ , tc := range tcs {
342+ b .Run ("size=" + tc .name , func (b * testing.B ) {
343+ benchmarkTCP (b , tc .size )
344+ })
345+ }
346+ }
347+
348+ // sum makes unwanted compiler optimizations in benchmarkTCP's loop less likely.
349+ var sum int
350+
351+ func benchmarkTCP (b * testing.B , size int ) {
352+ // Initialize the connection.
353+ client , server , err := newTCPConnPair (rekeyRecordProtocol , nil , nil )
354+ if err != nil {
355+ b .Fatalf ("failed to create TCP conn pair: %v" , err )
356+ }
357+ defer client .Close ()
358+ defer server .Close ()
359+
360+ rcvBuf := make ([]byte , size )
361+ sndBuf := make ([]byte , size )
362+ done := make (chan struct {})
363+ errChan := make (chan error )
364+
365+ // Launch a writer goroutine.
366+ go func () {
367+ for {
368+ select {
369+ case <- done :
370+ return
371+ default :
372+ }
373+ n , err := client .Write (sndBuf )
374+ if n != size || err != nil {
375+ errChan <- fmt .Errorf ("Write() = %v, %v; want %v, <nil>" , n , err , size )
376+ return
377+ }
378+ // Act a bit like a real workload that can't just fill
379+ // every buffer immediately.
380+ time .Sleep (10 * time .Millisecond )
381+ }
382+ }()
383+
384+ // Get the initial rusage so we can measure CPU time.
385+ var startUsage unix.Rusage
386+ if err := unix .Getrusage (unix .RUSAGE_SELF , & startUsage ); err != nil {
387+ b .Fatalf ("failed to get initial rusage: %v" , err )
388+ }
389+
390+ // Read as much as possible.
391+ var rcvd uint64
392+ for b .Loop () {
393+ n , err := io .ReadFull (server , rcvBuf )
394+ rcvd += uint64 (n )
395+ if n != size || err != nil {
396+ b .Fatalf ("Read() = %v, %v; want %v, <nil>" , n , err , size )
397+ }
398+ // Act a bit like a real workload and utilize received bytes.
399+ for _ , b := range rcvBuf [:n ] {
400+ sum += int (b )
401+ }
402+ }
403+
404+ // Turn off the writer.
405+ done <- struct {}{}
406+
407+ // Get the ending rusage.
408+ var endUsage unix.Rusage
409+ if err := unix .Getrusage (unix .RUSAGE_SELF , & endUsage ); err != nil {
410+ b .Fatalf ("failed to get final rusage: %v" , err )
411+ }
412+
413+ // Error check the writer goroutine.
414+ select {
415+ case err := <- errChan :
416+ b .Fatal (err )
417+ default :
418+ }
419+
420+ // Emit extra metrics.
421+ utime := timevalDiffUsec (& startUsage .Utime , & endUsage .Utime )
422+ stime := timevalDiffUsec (& startUsage .Stime , & endUsage .Stime )
423+ b .ReportMetric (float64 (utime )/ float64 (b .N ), "usr-usec/op" )
424+ b .ReportMetric (float64 (stime )/ float64 (b .N ), "sys-usec/op" )
425+ b .ReportMetric (float64 (stime + utime )/ float64 (b .N ), "cpu-usec/op" )
426+ b .ReportMetric (float64 (rcvd * 8 / (1024 * 1024 ))/ float64 (b .Elapsed ().Seconds ()), "Mbps" )
427+ }
428+
429+ // timevalDiffUsec returns the difference in microseconds between start and end.
430+ func timevalDiffUsec (start , end * unix.Timeval ) int64 {
431+ // Note: the int64 type conversion is needed because unix.Timeval uses
432+ // 32 bit values on some architectures.
433+ return int64 (1_000_000 * (end .Sec - start .Sec ) + end .Usec - start .Usec )
434+ }
435+
234436func testIncorrectMsgType (t * testing.T , rp string ) {
235437 // framedMsg is an empty ciphertext with correct framing but wrong
236438 // message type.
0 commit comments