1
- // Copyright 2021-2023 antlabs. All rights reserved.
1
+ // Copyright 2021-2024 antlabs. All rights reserved.
2
2
//
3
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
4
// you may not use this file except in compliance with the License.
@@ -28,6 +28,7 @@ import (
28
28
"github.com/antlabs/wsutil/bytespool"
29
29
"github.com/antlabs/wsutil/enum"
30
30
"github.com/antlabs/wsutil/fixedreader"
31
+ "github.com/antlabs/wsutil/hostname"
31
32
)
32
33
33
34
var (
@@ -46,7 +47,9 @@ type DialOption struct {
46
47
47
48
func ClientOptionToConf (opts ... ClientOption ) * DialOption {
48
49
var dial DialOption
49
- dial .defaultSetting ()
50
+ if err := dial .defaultSetting (); err != nil {
51
+ panic (err .Error ())
52
+ }
50
53
for _ , o := range opts {
51
54
o (& dial )
52
55
}
@@ -82,7 +85,10 @@ func Dial(rawUrl string, opts ...ClientOption) (*Conn, error) {
82
85
dial .Header = make (http.Header )
83
86
}
84
87
85
- dial .defaultSetting ()
88
+ if err := dial .defaultSetting (); err != nil {
89
+ return nil , err
90
+ }
91
+
86
92
for _ , o := range opts {
87
93
o (& dial )
88
94
}
@@ -122,6 +128,10 @@ func (d *DialOption) handshake() (*http.Request, string, error) {
122
128
d .Header .Add ("Sec-WebSocket-Extensions" , strExtensions )
123
129
}
124
130
131
+ if len (d .subProtocols ) > 0 {
132
+ d .Header ["Sec-WebSocket-Protocol" ] = []string {strings .Join (d .subProtocols , ", " )}
133
+ }
134
+
125
135
req .Header = d .Header
126
136
return req , secWebSocket , nil
127
137
}
@@ -178,23 +188,36 @@ func (d *DialOption) tlsConn(c net.Conn) net.Conn {
178
188
}
179
189
180
190
func (d * DialOption ) Dial () (c * Conn , err error ) {
191
+ // scheme ws -> http
192
+ // scheme wss -> https
181
193
req , secWebSocket , err := d .handshake ()
182
194
if err != nil {
183
195
return nil , err
184
196
}
185
197
186
198
var conn net.Conn
187
199
begin := time .Now ()
200
+
201
+ hostName := hostname .GetHostName (d .u )
188
202
// conn, err := net.DialTimeout("tcp", d.u.Host /* TODO 加端号*/, d.dialTimeout)
189
- if d .dialFunc == nil {
190
- conn , err = net .Dial ("tcp" , d .u .Host /* TODO 加端号*/ )
191
- } else {
203
+ dialFunc := net .Dial
204
+ if d .dialFunc != nil {
192
205
dialInterface , err := d .dialFunc ()
193
206
if err != nil {
194
207
return nil , err
195
208
}
196
- conn , err = dialInterface .Dial ( "tcp" , d . u . Host )
209
+ dialFunc = dialInterface .Dial
197
210
}
211
+
212
+ if d .proxyFunc != nil {
213
+ proxyURL , err := d .proxyFunc (req )
214
+ if err != nil {
215
+ return nil , err
216
+ }
217
+ dialFunc = newhttpProxy (proxyURL , dialFunc ).Dial
218
+ }
219
+
220
+ conn , err = dialFunc ("tcp" , hostName )
198
221
if err != nil {
199
222
return nil , err
200
223
}
0 commit comments