Skip to content

Commit 8faa921

Browse files
authored
Sec web socket protocol (#11)
* 回显子协议的值 * 更新 * 更新
1 parent 608a7f8 commit 8faa921

File tree

4 files changed

+96
-24
lines changed

4 files changed

+96
-24
lines changed

client.go

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,35 +36,35 @@ var (
3636
)
3737

3838
type DialOption struct {
39-
Header http.Header
40-
u *url.URL
41-
tlsConfig *tls.Config
42-
dialTimeout time.Duration
39+
Header http.Header
40+
u *url.URL
41+
tlsConfig *tls.Config
42+
dialTimeout time.Duration
43+
bindClientHttpHeader *http.Header // 握手成功之后, 客户端获取http.Header,
4344
Config
4445
}
4546

46-
func ClientOptionToConf(opts ...ClientOption) *Config {
47+
func ClientOptionToConf(opts ...ClientOption) *DialOption {
4748
var dial DialOption
4849
dial.defaultSetting()
4950
for _, o := range opts {
5051
o(&dial)
5152
}
52-
return &dial.Config
53+
return &dial
5354
}
5455

55-
func DialConf(rawUrl string, conf *Config) (*Conn, error) {
56-
var dial DialOption
56+
func DialConf(rawUrl string, conf *DialOption) (*Conn, error) {
5757
u, err := url.Parse(rawUrl)
5858
if err != nil {
5959
return nil, err
6060
}
6161

62-
dial.u = u
63-
dial.dialTimeout = defaultTimeout
64-
if dial.Header == nil {
65-
dial.Header = make(http.Header)
62+
conf.u = u
63+
conf.dialTimeout = defaultTimeout
64+
if conf.Header == nil {
65+
conf.Header = make(http.Header)
6666
}
67-
return dial.Dial()
67+
return conf.Dial()
6868
}
6969

7070
// https://datatracker.ietf.org/doc/html/rfc6455#section-4.1
@@ -222,6 +222,10 @@ func (d *DialOption) Dial() (c *Conn, err error) {
222222
return nil, err
223223
}
224224

225+
if d.bindClientHttpHeader != nil {
226+
*d.bindClientHttpHeader = rsp.Header.Clone()
227+
}
228+
225229
cd := maybeCompressionDecompression(rsp.Header)
226230
if d.decompression {
227231
d.decompression = cd

client_option_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,54 @@ func Test_ClientOption(t *testing.T) {
106106
t.Error("not run server:method fail")
107107
}
108108
})
109+
110+
t.Run("6.1 Dial: WithClientBindHTTPHeader", func(t *testing.T) {
111+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
112+
_, err := Upgrade(w, r)
113+
if err != nil {
114+
t.Error(err)
115+
}
116+
}))
117+
118+
defer ts.Close()
119+
120+
url := strings.ReplaceAll(ts.URL, "http", "ws")
121+
h := make(http.Header)
122+
con, err := Dial(url, WithClientBindHTTPHeader(&h), WithClientHTTPHeader(http.Header{
123+
"Sec-WebSocket-Protocol": []string{"token"},
124+
}))
125+
if err != nil {
126+
t.Error(err)
127+
}
128+
defer con.Close()
129+
130+
if h["Sec-Websocket-Protocol"][0] != "token" {
131+
t.Error("header fail")
132+
}
133+
})
134+
135+
t.Run("6.2 Dial: WithClientBindHTTPHeader", func(t *testing.T) {
136+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
137+
_, err := Upgrade(w, r)
138+
if err != nil {
139+
t.Error(err)
140+
}
141+
}))
142+
143+
defer ts.Close()
144+
145+
url := strings.ReplaceAll(ts.URL, "http", "ws")
146+
h := make(http.Header)
147+
con, err := DialConf(url, ClientOptionToConf(WithClientBindHTTPHeader(&h), WithClientHTTPHeader(http.Header{
148+
"Sec-WebSocket-Protocol": []string{"token"},
149+
})))
150+
if err != nil {
151+
t.Error(err)
152+
}
153+
defer con.Close()
154+
155+
if h["Sec-Websocket-Protocol"][0] != "token" {
156+
t.Error("header fail")
157+
}
158+
})
109159
}

client_options.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,38 +22,45 @@ import (
2222

2323
type ClientOption func(*DialOption)
2424

25-
// 配置tls.config
25+
// 1.配置tls.config
2626
func WithClientTLSConfig(tls *tls.Config) ClientOption {
2727
return func(o *DialOption) {
2828
o.tlsConfig = tls
2929
}
3030
}
3131

32-
// 配置http.Header
32+
// 2.配置http.Header
3333
func WithClientHTTPHeader(h http.Header) ClientOption {
3434
return func(o *DialOption) {
3535
o.Header = h
3636
}
3737
}
3838

39-
// 配置握手时的timeout
39+
// 3.配置握手时的timeout
4040
func WithClientDialTimeout(t time.Duration) ClientOption {
4141
return func(o *DialOption) {
4242
o.dialTimeout = t
4343
}
4444
}
4545

46-
// 配置压缩
46+
// 4.配置压缩
4747
func WithClientCompression() ClientOption {
4848
return func(o *DialOption) {
4949
o.compression = true
5050
}
5151
}
5252

53-
// 配置压缩和解压缩
53+
// 5.配置压缩和解压缩
5454
func WithClientDecompressAndCompress() ClientOption {
5555
return func(o *DialOption) {
5656
o.compression = true
5757
o.decompression = true
5858
}
5959
}
60+
61+
// 6.获取http header
62+
func WithClientBindHTTPHeader(h *http.Header) ClientOption {
63+
return func(o *DialOption) {
64+
o.bindClientHttpHeader = h
65+
}
66+
}

server_handshake.go

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@ import (
2323
)
2424

2525
var (
26-
ErrNotFoundHijacker = errors.New("not found Hijacker")
27-
bytesHeaderUpgrade = []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept:")
28-
bytesHeaderExtensions = []byte("Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n")
29-
bytesCRLF = []byte("\r\n")
30-
bytesColon = []byte(": ")
26+
ErrNotFoundHijacker = errors.New("not found Hijacker")
27+
bytesHeaderUpgrade = []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept:")
28+
bytesHeaderExtensions = []byte("Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n")
29+
bytesCRLF = []byte("\r\n")
30+
strGetSecWebSocketProtocolKey = "Sec-WebSocket-Protocol"
31+
bytesPutSecWebSocketProtocolKey = []byte("Sec-WebSocket-Protocol: ")
3132
)
3233

3334
type ConnOption struct {
@@ -67,6 +68,17 @@ func prepareWriteResponse(r *http.Request, w io.Writer, cnf *Config) (err error)
6768
}
6869
}
6970

71+
v = r.Header.Get(strGetSecWebSocketProtocolKey)
72+
if len(v) > 0 {
73+
if _, err = w.Write(bytesPutSecWebSocketProtocolKey); err != nil {
74+
return
75+
}
76+
77+
if err = writeHeaderVal(w, StringToBytes(v)); err != nil {
78+
return err
79+
}
80+
}
81+
7082
_, err = w.Write(bytesCRLF)
7183
return err
7284
}
@@ -111,7 +123,6 @@ func checkRequest(r *http.Request) (ecode int, err error) {
111123
return http.StatusUpgradeRequired, ErrSecWebSocketVersion
112124
}
113125

114-
// TODO Sec-WebSocket-Protocol
115126
// TODO Sec-WebSocket-Extensions
116127
return 0, nil
117128
}

0 commit comments

Comments
 (0)