@@ -14,9 +14,11 @@ use bytes::BytesMut;
14
14
use futures_util:: { sink:: SinkExt , stream:: StreamExt , TryStreamExt } ;
15
15
use http:: { header:: HOST , HeaderMap } ;
16
16
use std:: sync:: Arc ;
17
+ use tokio:: net:: TcpStream ;
17
18
use tokio_tungstenite:: {
18
19
connect_async,
19
- tungstenite:: { protocol:: CloseFrame , Message as MsgTng } ,
20
+ tungstenite:: { self , protocol:: CloseFrame , Message as MsgTng } ,
21
+ MaybeTlsStream , WebSocketStream ,
20
22
} ;
21
23
use tower_http:: trace:: TraceLayer ;
22
24
@@ -144,6 +146,121 @@ fn make_outbound_request(
144
146
Ok ( request)
145
147
}
146
148
149
+ /// Attempts to create a websocket connection.
150
+ /// If `insecure` is true, it will allow websocket connections over insecure TLS.
151
+ /// Will fail if `insecure` is true and none of the `native-tls` or `rustls` features are enabled.
152
+ async fn connect_webscoket (
153
+ outbound_request : http:: Request < ( ) > ,
154
+ insecure : bool ,
155
+ ) -> anyhow:: Result < (
156
+ WebSocketStream < MaybeTlsStream < TcpStream > > ,
157
+ tungstenite:: handshake:: client:: Response ,
158
+ ) > {
159
+ if insecure {
160
+ #[ cfg( any( feature = "native-tls" , feature = "rustls" ) ) ]
161
+ {
162
+ use tokio_tungstenite:: connect_async_tls_with_config;
163
+ connect_async_tls_with_config ( outbound_request, None , false , make_insecure_connector ( ) )
164
+ . await
165
+ . map_err ( |err| {
166
+ anyhow:: anyhow!( "error establishing insecure WebSocket connection: {err}" )
167
+ } )
168
+ }
169
+ #[ cfg( not( any( feature = "native-tls" , feature = "rustls" ) ) ) ]
170
+ {
171
+ Err ( anyhow:: anyhow!(
172
+ "Insecure WebSockets requires the `native-tls` or `rustls` to be feature enabled."
173
+ ) )
174
+ }
175
+ } else {
176
+ connect_async ( outbound_request)
177
+ . await
178
+ . map_err ( |err| anyhow:: anyhow!( "error establishing secure WebSocket connection: {err}" ) )
179
+ }
180
+ }
181
+
182
+ /// Create a connector which does not verify TLS certificates.
183
+ /// Defaults to a `rustls` connector if both the `rustls` and `native-tls` features are enabled.
184
+ #[ cfg( any( feature = "native-tls" , feature = "rustls" ) ) ]
185
+ fn make_insecure_connector ( ) -> Option < tokio_tungstenite:: Connector > {
186
+ #[ cfg( feature = "rustls" ) ]
187
+ {
188
+ use rustls:: {
189
+ client:: danger:: { HandshakeSignatureValid , ServerCertVerified , ServerCertVerifier } ,
190
+ pki_types:: { CertificateDer , ServerName , UnixTime } ,
191
+ ClientConfig , DigitallySignedStruct , SignatureScheme ,
192
+ } ;
193
+
194
+ /// A `rustls` certificate verifier that allows insecure certificates.
195
+ #[ derive( Debug ) ]
196
+ struct NoCertVerification ;
197
+
198
+ impl ServerCertVerifier for NoCertVerification {
199
+ fn verify_server_cert (
200
+ & self ,
201
+ _: & CertificateDer < ' _ > ,
202
+ _: & [ CertificateDer < ' _ > ] ,
203
+ _: & ServerName < ' _ > ,
204
+ _: & [ u8 ] ,
205
+ _: UnixTime ,
206
+ ) -> Result < ServerCertVerified , rustls:: Error > {
207
+ Ok ( ServerCertVerified :: assertion ( ) )
208
+ }
209
+
210
+ fn verify_tls12_signature (
211
+ & self ,
212
+ _: & [ u8 ] ,
213
+ _: & CertificateDer < ' _ > ,
214
+ _: & DigitallySignedStruct ,
215
+ ) -> Result < HandshakeSignatureValid , rustls:: Error > {
216
+ Ok ( HandshakeSignatureValid :: assertion ( ) )
217
+ }
218
+
219
+ fn verify_tls13_signature (
220
+ & self ,
221
+ _: & [ u8 ] ,
222
+ _: & CertificateDer < ' _ > ,
223
+ _: & DigitallySignedStruct ,
224
+ ) -> Result < HandshakeSignatureValid , rustls:: Error > {
225
+ Ok ( HandshakeSignatureValid :: assertion ( ) )
226
+ }
227
+
228
+ fn supported_verify_schemes ( & self ) -> Vec < SignatureScheme > {
229
+ vec ! [
230
+ SignatureScheme :: ED25519 ,
231
+ SignatureScheme :: ECDSA_NISTP256_SHA256 ,
232
+ SignatureScheme :: ECDSA_NISTP384_SHA384 ,
233
+ SignatureScheme :: ECDSA_NISTP521_SHA512 ,
234
+ SignatureScheme :: RSA_PSS_SHA256 ,
235
+ SignatureScheme :: RSA_PSS_SHA384 ,
236
+ SignatureScheme :: RSA_PSS_SHA512 ,
237
+ SignatureScheme :: ED448 ,
238
+ ]
239
+ }
240
+ }
241
+
242
+ Some ( tokio_tungstenite:: Connector :: Rustls ( std:: sync:: Arc :: new (
243
+ ClientConfig :: builder ( )
244
+ . dangerous ( )
245
+ . with_custom_certificate_verifier ( Arc :: new ( NoCertVerification { } ) )
246
+ . with_no_client_auth ( ) ,
247
+ ) ) )
248
+ }
249
+ #[ cfg( all( feature = "native-tls" , not( feature = "rustls" ) ) ) ]
250
+ {
251
+ match native_tls:: TlsConnector :: builder ( )
252
+ . danger_accept_invalid_certs ( true )
253
+ . build ( )
254
+ {
255
+ Ok ( connector) => Some ( tokio_tungstenite:: Connector :: NativeTls ( connector) ) ,
256
+ Err ( err) => {
257
+ tracing:: error!( error = ?err, "error building native TLS connector" ) ;
258
+ None
259
+ }
260
+ }
261
+ }
262
+ }
263
+
147
264
impl ProxyHandlerHttp {
148
265
/// Construct a new instance.
149
266
pub fn new (
@@ -243,6 +360,8 @@ pub struct ProxyHandlerWebSocket {
243
360
rewrite : Option < String > ,
244
361
/// The headers to inject with the request
245
362
request_headers : HeaderMap ,
363
+ /// Allow insecure TLS websocket connections.
364
+ insecure : bool ,
246
365
}
247
366
248
367
impl ProxyHandlerWebSocket {
@@ -252,12 +371,14 @@ impl ProxyHandlerWebSocket {
252
371
backend : Uri ,
253
372
headers : HeaderMap ,
254
373
rewrite : Option < String > ,
374
+ insecure : bool ,
255
375
) -> Arc < Self > {
256
376
Arc :: new ( Self {
257
377
proto,
258
378
backend,
259
379
rewrite,
260
380
request_headers : headers,
381
+ insecure,
261
382
} )
262
383
}
263
384
@@ -337,14 +458,15 @@ impl ProxyHandlerWebSocket {
337
458
}
338
459
} ;
339
460
340
- // Establish WS connection to backend.
341
- let ( backend, _res) = match connect_async ( outbound_request) . await {
461
+ // Try to astablish a websocket connection to the backend and handle potential errors
462
+ let ( backend, _res) = match connect_webscoket ( outbound_request, self . insecure ) . await {
342
463
Ok ( backend) => backend,
343
464
Err ( err) => {
344
465
tracing:: error!( error = ?err, "error establishing WebSocket connection to backend {:?} for proxy" , & outbound_uri) ;
345
466
return ;
346
467
}
347
468
} ;
469
+
348
470
let ( mut backend_sink, mut backend_stream) = backend. split ( ) ;
349
471
let ( mut frontend_sink, mut frontend_stream) = ws. split ( ) ;
350
472
0 commit comments