diff --git a/imapserver/conn.go b/imapserver/conn.go index 5ea9ee18..3e47eab3 100644 --- a/imapserver/conn.go +++ b/imapserver/conn.go @@ -101,6 +101,9 @@ func (c *Conn) serve() { c.server.mutex.Unlock() }() + c.server.connWaitGroup.Add(1) + defer c.server.connWaitGroup.Done() + var ( greetingData *GreetingData err error @@ -169,10 +172,27 @@ func (c *Conn) serve() { dec := imapwire.NewDecoder(c.br, imapwire.ConnSideServer) dec.CheckBufferedLiteralFunc = c.checkBufferedLiteral - if c.state == imap.ConnStateLogout || dec.EOF() { + if c.state == imap.ConnStateLogout { break } + if c.br.Buffered() == 0 { + eofCh := make(chan bool, 1) + go func() { + eofCh <- dec.EOF() + }() + + var eof bool + select { + case <-c.server.shutdownCh: + eof = true + case eof = <-eofCh: + } + if eof { + break + } + } + c.setReadTimeout(cmdReadTimeout) if err := c.readCommand(dec); err != nil { if !errors.Is(err, net.ErrClosed) { diff --git a/imapserver/server.go b/imapserver/server.go index fd6eff1f..a3800ebc 100644 --- a/imapserver/server.go +++ b/imapserver/server.go @@ -80,11 +80,13 @@ type Server struct { options Options listenerWaitGroup sync.WaitGroup + connWaitGroup sync.WaitGroup - mutex sync.Mutex - listeners map[net.Listener]struct{} - conns map[*Conn]struct{} - closed bool + mutex sync.Mutex + listeners map[net.Listener]struct{} + conns map[*Conn]struct{} + closed bool + shutdownCh chan struct{} } // New creates a new server. @@ -93,9 +95,10 @@ func New(options *Options) *Server { panic("imapserver: at least IMAP4rev1 must be supported") } return &Server{ - options: *options, - listeners: make(map[net.Listener]struct{}), - conns: make(map[*Conn]struct{}), + options: *options, + listeners: make(map[net.Listener]struct{}), + conns: make(map[*Conn]struct{}), + shutdownCh: make(chan struct{}), } } @@ -220,3 +223,30 @@ func (s *Server) Close() error { return err } + +func (s *Server) Shutdown() error { + var err error + + s.mutex.Lock() + ok := true + select { + case <-s.shutdownCh: + ok = false + default: + close(s.shutdownCh) + for l := range s.listeners { + if closeErr := l.Close(); closeErr != nil && err == nil { + err = closeErr + } + } + } + s.mutex.Unlock() + if !ok { + return errClosed + } + + s.listenerWaitGroup.Wait() + s.connWaitGroup.Wait() + + return err +}