Skip to content

Commit bc4565b

Browse files
committed
Fix potential nil connection closes #132
1 parent 293b2fc commit bc4565b

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

pkg/mux/multiplexer.go

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,11 @@ func ListenWithConfig(network, address string, _c MultiplexerConfig) (*Multiplex
215215
conn.SetDeadline(time.Now().Add(2 * time.Second))
216216

217217
var proto string
218-
conn, proto = m.determineProtocol(conn)
218+
conn, proto, err = m.determineProtocol(conn)
219+
if err != nil {
220+
log.Println("Multiplexing failed: ", err)
221+
return
222+
}
219223

220224
if m.config.TLS && proto == "tls" {
221225

@@ -269,10 +273,10 @@ func ListenWithConfig(network, address string, _c MultiplexerConfig) (*Multiplex
269273

270274
conn.SetDeadline(time.Time{})
271275

272-
functionalConn, proto := m.determineProtocol(conn)
273-
if functionalConn == nil {
276+
functionalConn, proto, err := m.determineProtocol(conn)
277+
if err != nil {
274278
conn.Close()
275-
log.Println("determining functional protocol: ", proto)
279+
log.Println("Error determining functional protocol: ", err)
276280
return
277281
}
278282

@@ -307,10 +311,10 @@ func ListenWithConfig(network, address string, _c MultiplexerConfig) (*Multiplex
307311

308312
select {
309313
case wsConn := <-wsConnChan:
310-
functionalConn, proto = m.determineProtocol(wsConn)
311-
if functionalConn == nil {
314+
functionalConn, proto, err = m.determineProtocol(wsConn)
315+
if err != nil {
312316
wsConn.Close()
313-
log.Println("failed to determine protocol via ws: ", proto)
317+
log.Println("failed to determine protocol via ws: ", err)
314318
return
315319
}
316320

@@ -387,33 +391,33 @@ func isHttp(b []byte) bool {
387391
return false
388392
}
389393

390-
func (m *Multiplexer) determineProtocol(conn net.Conn) (net.Conn, string) {
394+
func (m *Multiplexer) determineProtocol(conn net.Conn) (net.Conn, string, error) {
391395

392396
header := make([]byte, 7)
393397
n, err := conn.Read(header)
394398
if err != nil {
395-
return nil, "failed: " + err.Error()
399+
return nil, "", err
396400
}
397401

398402
c := &bufferedConn{prefix: header[:n], conn: conn}
399403

400404
if bytes.HasPrefix(header, []byte{0x16}) {
401-
return c, "tls"
405+
return c, "tls", nil
402406
}
403407

404408
if bytes.HasPrefix(header, []byte{'S', 'S', 'H'}) {
405-
return c, "ssh"
409+
return c, "ssh", nil
406410
}
407411

408412
if isHttp(header) {
409413
if bytes.HasPrefix(header, []byte("GET /ws")) {
410-
return c, "ws"
414+
return c, "ws", nil
411415
}
412416

413-
return c, "http"
417+
return c, "http", nil
414418
}
415419

416-
return c, "unknown"
420+
return nil, "", errors.New("unknown protocol")
417421
}
418422

419423
func (m *Multiplexer) getProtoListener(proto string) net.Listener {

0 commit comments

Comments
 (0)