Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 35 additions & 12 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"compress/flate"
"context"
"crypto/tls"
"encoding/json"
"errors"
Expand Down Expand Up @@ -119,8 +120,7 @@ func NewConn(addr string, config *Config, delegate ConnDelegate) *Conn {
// The logger parameter is an interface that requires the following
// method to be implemented (such as the the stdlib log.Logger):
//
// Output(calldepth int, s string)
//
// Output(calldepth int, s string)
func (c *Conn) SetLogger(l logger, lvl LogLevel, format string) {
c.logGuard.Lock()
defer c.logGuard.Unlock()
Expand Down Expand Up @@ -171,12 +171,20 @@ func (c *Conn) getLogLevel() LogLevel {
// Connect dials and bootstraps the nsqd connection
// (including IDENTIFY) and returns the IdentifyResponse
func (c *Conn) Connect() (*IdentifyResponse, error) {
ctx := context.Background()
return c.ConnectWithContext(ctx)
}

// ConnectWithContext dials and bootstraps the nsqd connection
// (including IDENTIFY) and returns the IdentifyResponse
func (c *Conn) ConnectWithContext(ctx context.Context) (*IdentifyResponse, error) {
dialer := &net.Dialer{
LocalAddr: c.config.LocalAddr,
Timeout: c.config.DialTimeout,
}

conn, err := dialer.Dial("tcp", c.addr)
// the timeout used is smallest of dialer.Timeout (config.DialTimeout) or context timeout
conn, err := dialer.DialContext(ctx, "tcp", c.addr)
if err != nil {
return nil, err
}
Expand All @@ -190,7 +198,7 @@ func (c *Conn) Connect() (*IdentifyResponse, error) {
return nil, fmt.Errorf("[%s] failed to write magic - %s", c.addr, err)
}

resp, err := c.identify()
resp, err := c.identify(ctx)
if err != nil {
return nil, err
}
Expand All @@ -200,7 +208,7 @@ func (c *Conn) Connect() (*IdentifyResponse, error) {
c.log(LogLevelError, "Auth Required")
return nil, errors.New("Auth Required")
}
err := c.auth(c.config.AuthSecret)
err := c.auth(ctx, c.config.AuthSecret)
if err != nil {
c.log(LogLevelError, "Auth Failed %s", err)
return nil, err
Expand Down Expand Up @@ -291,13 +299,28 @@ func (c *Conn) Write(p []byte) (int, error) {
// WriteCommand is a goroutine safe method to write a Command
// to this connection, and flush.
func (c *Conn) WriteCommand(cmd *Command) error {
ctx := context.Background()
return c.WriteCommandWithContext(ctx, cmd)
}

// WriteCommandWithContext is a goroutine safe method to write a Command
// to this connection, and flush.
func (c *Conn) WriteCommandWithContext(ctx context.Context, cmd *Command) error {
c.mtx.Lock()

_, err := cmd.WriteTo(c)
if err != nil {
var err error
select {
case <-ctx.Done():
c.mtx.Unlock()
return ctx.Err()
default:
_, err = cmd.WriteTo(c)
if err != nil {
goto exit
}
err = c.Flush()
goto exit
}
err = c.Flush()

exit:
c.mtx.Unlock()
Expand All @@ -320,7 +343,7 @@ func (c *Conn) Flush() error {
return nil
}

func (c *Conn) identify() (*IdentifyResponse, error) {
func (c *Conn) identify(ctx context.Context) (*IdentifyResponse, error) {
ci := make(map[string]interface{})
ci["client_id"] = c.config.ClientID
ci["hostname"] = c.config.Hostname
Expand Down Expand Up @@ -350,7 +373,7 @@ func (c *Conn) identify() (*IdentifyResponse, error) {
return nil, ErrIdentify{err.Error()}
}

err = c.WriteCommand(cmd)
err = c.WriteCommandWithContext(ctx, cmd)
if err != nil {
return nil, ErrIdentify{err.Error()}
}
Expand Down Expand Up @@ -479,13 +502,13 @@ func (c *Conn) upgradeSnappy() error {
return nil
}

func (c *Conn) auth(secret string) error {
func (c *Conn) auth(ctx context.Context, secret string) error {
cmd, err := Auth(secret)
if err != nil {
return err
}

err = c.WriteCommand(cmd)
err = c.WriteCommandWithContext(ctx, cmd)
if err != nil {
return err
}
Expand Down
Loading