Skip to content

Commit fc60bda

Browse files
authored
add UpgradeV2 (#29)
1 parent 6eaa970 commit fc60bda

File tree

8 files changed

+63
-21
lines changed

8 files changed

+63
-21
lines changed

autobahn/config/fuzzingclient.json

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
{
22
"outdir": "./report/",
33
"servers": [
4+
{
5+
"agent": "global",
6+
"url": "ws://localhost:9001/global",
7+
"options": {
8+
"version": 18
9+
}
10+
},
411
{
512
"agent": "no-context-takeover-decompression-and-compression-no-tls",
613
"url": "ws://localhost:9001/no-context-takeover-decompression-and-compression",
@@ -37,4 +44,4 @@
3744
""
3845
],
3946
"exclude-agent-cases": {}
40-
}
47+
}

autobahn/server/autobahn-server.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,24 @@ func echoReadTime(w http.ResponseWriter, r *http.Request) {
136136
_ = c.ReadLoop()
137137
}
138138

139+
var upgrade = quickws.NewUpgrade(
140+
quickws.WithServerReplyPing(),
141+
quickws.WithServerDecompression(),
142+
quickws.WithServerIgnorePong(),
143+
quickws.WithServerEnableUTF8Check(),
144+
quickws.WithServerReadTimeout(5*time.Second),
145+
)
146+
147+
func global(w http.ResponseWriter, r *http.Request) {
148+
c, err := upgrade.UpgradeV2(w, r, &echoHandler{openWriteTimeout: true})
149+
if err != nil {
150+
fmt.Println("Upgrade fail:", err)
151+
return
152+
}
153+
154+
_ = c.ReadLoop()
155+
}
156+
139157
func startTLSServer(mux *http.ServeMux) {
140158

141159
cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
@@ -167,6 +185,7 @@ func startServer(mux *http.ServeMux) {
167185
func main() {
168186
mux := &http.ServeMux{}
169187
mux.HandleFunc("/timeout", echoReadTime)
188+
mux.HandleFunc("/global", global)
170189
mux.HandleFunc("/no-context-takeover-decompression", echoNoContextDecompression)
171190
mux.HandleFunc("/no-context-takeover-decompression-and-compression", echoNoContextDecompressionAndCompression)
172191
mux.HandleFunc("/context-takeover-decompression", echoContextTakeoverDecompression)

client.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,10 @@ func (d *DialOption) Dial() (wsCon *Conn, err error) {
287287
if err := conn.SetDeadline(time.Time{}); err != nil {
288288
return nil, err
289289
}
290-
wsCon = newConn(conn, true /* client is true*/, &d.Config, fr, br)
290+
if wsCon, err = newConn(conn, true /* client is true*/, &d.Config, fr, br); err != nil {
291+
return nil, err
292+
}
291293
wsCon.pd = pd
294+
wsCon.Callback = d.cb
292295
return wsCon, nil
293296
}

common_options.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import (
2323
// 0. CallbackFunc
2424
func WithClientCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) ClientOption {
2525
return func(o *DialOption) {
26-
o.Callback = &funcToCallback{
26+
o.cb = &funcToCallback{
2727
onOpen: open,
2828
onMessage: m,
2929
onClose: c,
@@ -34,7 +34,7 @@ func WithClientCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) Cli
3434
// 配置服务端回调函数
3535
func WithServerCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) ServerOption {
3636
return func(o *ConnOption) {
37-
o.Callback = &funcToCallback{
37+
o.cb = &funcToCallback{
3838
onOpen: open,
3939
onMessage: m,
4040
onClose: c,
@@ -46,14 +46,14 @@ func WithServerCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) Ser
4646
// 配置客户端callback
4747
func WithClientCallback(cb Callback) ClientOption {
4848
return func(o *DialOption) {
49-
o.Callback = cb
49+
o.cb = cb
5050
}
5151
}
5252

5353
// 配置服务端回调函数
5454
func WithServerCallback(cb Callback) ServerOption {
5555
return func(o *ConnOption) {
56-
o.Callback = cb
56+
o.cb = cb
5757
}
5858
}
5959

@@ -90,14 +90,14 @@ func WithClientEnableUTF8Check() ClientOption {
9090
// 仅仅配置OnMessae函数
9191
func WithServerOnMessageFunc(cb OnMessageFunc) ServerOption {
9292
return func(o *ConnOption) {
93-
o.Callback = OnMessageFunc(cb)
93+
o.cb = OnMessageFunc(cb)
9494
}
9595
}
9696

9797
// 仅仅配置OnMessae函数
9898
func WithClientOnMessageFunc(cb OnMessageFunc) ClientOption {
9999
return func(o *DialOption) {
100-
o.Callback = OnMessageFunc(cb)
100+
o.cb = OnMessageFunc(cb)
101101
}
102102
}
103103

@@ -292,14 +292,14 @@ func WithClientReadTimeout(t time.Duration) ClientOption {
292292
// 17.1 配置服务端OnClose
293293
func WithServerOnCloseFunc(onClose func(c *Conn, err error)) ServerOption {
294294
return func(o *ConnOption) {
295-
o.Callback = OnCloseFunc(onClose)
295+
o.cb = OnCloseFunc(onClose)
296296
}
297297
}
298298

299299
// 17.2 配置客户端OnClose
300300
func WithClientOnCloseFunc(onClose func(c *Conn, err error)) ClientOption {
301301
return func(o *DialOption) {
302-
o.Callback = OnCloseFunc(onClose)
302+
o.cb = OnCloseFunc(onClose)
303303
}
304304
}
305305

config.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ type DialerTimeout interface {
4141
// 一种是声明一个全局的配置,后面不停使用。
4242
// 另外一种是局部声明一个配置,然后使用WithXXX函数设置配置
4343
type Config struct {
44-
Callback
44+
cb Callback
4545
deflate.PermessageDeflateConf // 静态配置, 从WithXXX函数中获取
4646
tcpNoDelay bool
4747
replyPing bool // 开启自动回复
@@ -67,7 +67,7 @@ func (c *Config) initPayloadSize() int {
6767

6868
// 默认设置
6969
func (c *Config) defaultSetting() error {
70-
c.Callback = &DefCallback{}
70+
c.cb = &DefCallback{}
7171
c.maxDelayWriteNum = 10
7272
c.windowsMultipleTimesPayloadSize = 1.0
7373
c.delayWriteInitBufferSize = 8 * 1024

config_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func TestConfig_defaultSetting(t *testing.T) {
6767
for _, tt := range tests {
6868
t.Run(tt.name, func(t *testing.T) {
6969
c := &Config{
70-
Callback: tt.fields.Callback,
70+
cb: tt.fields.Callback,
7171
tcpNoDelay: tt.fields.tcpNoDelay,
7272
replyPing: tt.fields.replyPing,
7373
ignorePong: tt.fields.ignorePong,

conn.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ type delayWrite struct {
5858
type Conn struct {
5959
fr fixedreader.FixedReader // 默认使用windows
6060
c net.Conn // net.Conn
61+
Callback // callback移至conn中
6162
br *bufio.Reader // read和fr同时只能使用一个
6263
*Config // config 可能是全局,也可能是局部初始化得来的
6364
pd deflate.PermessageDeflateConf // permessageDeflate局部配置
@@ -87,18 +88,20 @@ func setNoDelay(c net.Conn, noDelay bool) error {
8788
return nil
8889
}
8990

90-
func newConn(c net.Conn, client bool, conf *Config, fr fixedreader.FixedReader, br *bufio.Reader) *Conn {
91-
_ = setNoDelay(c, conf.tcpNoDelay)
91+
func newConn(c net.Conn, client bool, conf *Config, fr fixedreader.FixedReader, br *bufio.Reader) (wsCon *Conn, err error) {
92+
if err = setNoDelay(c, conf.tcpNoDelay); err != nil {
93+
return nil, err
94+
}
9295

93-
con := &Conn{
96+
wsCon = &Conn{
9497
c: c,
9598
client: client,
9699
Config: conf,
97100
fr: fr,
98101
br: br,
99102
}
100103

101-
return con
104+
return wsCon, err
102105
}
103106

104107
// 返回标准库的net.Conn

upgrade.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,11 @@ func NewUpgrade(opts ...ServerOption) *UpgradeServer {
4343
}
4444

4545
func (u *UpgradeServer) Upgrade(w http.ResponseWriter, r *http.Request) (c *Conn, err error) {
46-
return upgradeInner(w, r, &u.config)
46+
return upgradeInner(w, r, &u.config, nil)
47+
}
48+
49+
func (u *UpgradeServer) UpgradeV2(w http.ResponseWriter, r *http.Request, cb Callback) (c *Conn, err error) {
50+
return upgradeInner(w, r, &u.config, cb)
4751
}
4852

4953
func Upgrade(w http.ResponseWriter, r *http.Request, opts ...ServerOption) (c *Conn, err error) {
@@ -54,10 +58,10 @@ func Upgrade(w http.ResponseWriter, r *http.Request, opts ...ServerOption) (c *C
5458
for _, o := range opts {
5559
o(&conf)
5660
}
57-
return upgradeInner(w, r, &conf.Config)
61+
return upgradeInner(w, r, &conf.Config, nil)
5862
}
5963

60-
func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config) (c *Conn, err error) {
64+
func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config, cb Callback) (wsCon *Conn, err error) {
6165
if ecode, err := checkRequest(r); err != nil {
6266
http.Error(w, err.Error(), ecode)
6367
return nil, err
@@ -125,9 +129,15 @@ func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config) (c *Conn
125129
if err := conn.SetDeadline(time.Time{}); err != nil {
126130
return nil, err
127131
}
128-
wsCon := newConn(conn, false, conf, fr, br)
132+
if wsCon, err = newConn(conn, false, conf, fr, br); err != nil {
133+
return nil, err
134+
}
129135

130136
wsCon.pd = pd
137+
wsCon.Callback = cb
138+
if cb == nil {
139+
wsCon.Callback = conf.cb
140+
}
131141
return wsCon, nil
132142
}
133143

0 commit comments

Comments
 (0)