1
1
use std:: future:: Future ;
2
2
use std:: pin:: Pin ;
3
+ use std:: sync:: Weak ;
3
4
use std:: time:: Duration ;
4
5
use std:: {
5
6
collections:: HashSet ,
@@ -137,7 +138,7 @@ impl MessageCache {
137
138
}
138
139
139
140
pub ( super ) async fn run (
140
- ws_taos : Arc < WsTaos > ,
141
+ ws_taos : Weak < WsTaos > ,
141
142
builder : Arc < TaosBuilder > ,
142
143
mut ws_stream : WsStream ,
143
144
query_sender : WsQuerySender ,
@@ -180,7 +181,9 @@ pub(super) async fn run(
180
181
let _ = close_tx. send( true ) ;
181
182
if !is_disconnect_error( & err) {
182
183
tracing:: error!( "non-disconnect error detected, cleaning up all pending queries" ) ;
183
- ws_taos. set_state( ConnState :: Disconnected ) ;
184
+ if let Some ( taos) = ws_taos. upgrade( ) {
185
+ taos. set_state( ConnState :: Disconnected ) ;
186
+ }
184
187
cleanup_after_disconnect( query_sender. clone( ) ) ;
185
188
return ;
186
189
}
@@ -189,13 +192,14 @@ pub(super) async fn run(
189
192
}
190
193
_ = close_reader. changed( ) => {
191
194
tracing:: info!( "WebSocket received close signal" ) ;
192
- ws_taos. set_state( ConnState :: Disconnected ) ;
193
195
let _ = close_tx. send( true ) ;
194
196
return ;
195
197
}
196
198
}
197
199
198
- ws_taos. set_state ( ConnState :: Reconnecting ) ;
200
+ if let Some ( taos) = ws_taos. upgrade ( ) {
201
+ taos. set_state ( ConnState :: Reconnecting ) ;
202
+ }
199
203
200
204
if let Err ( err) = send_handle. await {
201
205
tracing:: error!( "send messages task failed: {err:?}" ) ;
@@ -213,19 +217,24 @@ pub(super) async fn run(
213
217
}
214
218
Err ( err) => {
215
219
tracing:: error!( "WebSocket reconnection failed: {err}" ) ;
216
- ws_taos. set_state ( ConnState :: Disconnected ) ;
220
+ if let Some ( taos) = ws_taos. upgrade ( ) {
221
+ taos. set_state ( ConnState :: Disconnected ) ;
222
+ }
217
223
cleanup_after_disconnect ( query_sender. clone ( ) ) ;
218
224
return ;
219
225
}
220
226
} ;
221
227
222
228
tracing:: info!( "WebSocket reconnected successfully" ) ;
223
229
224
- ws_taos. wait_for_previous_recover_stmt2 ( ) . await ;
225
- let mut stmt2_req_ids = cleanup_stmt2 ( query_sender. sender . clone ( ) , message_reader. clone ( ) ) ;
226
- ws_taos. clone ( ) . recover_stmt2 ( ) . await ;
227
- stmt2_req_ids. extend ( ws_taos. stmt2_req_ids ( ) . await ) ;
228
- cleanup_after_reconnect ( query_sender. clone ( ) , cache. clone ( ) , stmt2_req_ids) ;
230
+ if let Some ( taos) = ws_taos. upgrade ( ) {
231
+ taos. wait_for_previous_recover_stmt2 ( ) . await ;
232
+ let mut stmt2_req_ids =
233
+ cleanup_stmt2 ( query_sender. sender . clone ( ) , message_reader. clone ( ) ) ;
234
+ taos. clone ( ) . recover_stmt2 ( ) . await ;
235
+ stmt2_req_ids. extend ( taos. stmt2_req_ids ( ) . await ) ;
236
+ cleanup_after_reconnect ( query_sender. clone ( ) , cache. clone ( ) , stmt2_req_ids) ;
237
+ }
229
238
}
230
239
}
231
240
@@ -617,7 +626,7 @@ fn cleanup_stmt2(
617
626
618
627
#[ cfg( test) ]
619
628
mod tests {
620
- use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
629
+ use std:: sync:: atomic:: { AtomicBool , AtomicUsize , Ordering } ;
621
630
use std:: sync:: Arc ;
622
631
use std:: time:: Duration ;
623
632
@@ -944,4 +953,83 @@ mod tests {
944
953
945
954
Ok ( ( ) )
946
955
}
956
+
957
+ #[ tokio:: test]
958
+ async fn test_wstaos_drop_closes_connection ( ) -> anyhow:: Result < ( ) > {
959
+ let _ = tracing_subscriber:: fmt ( )
960
+ . with_file ( true )
961
+ . with_line_number ( true )
962
+ . with_max_level ( tracing:: Level :: INFO )
963
+ . try_init ( ) ;
964
+
965
+ let conn_dropped = Arc :: new ( AtomicBool :: new ( false ) ) ;
966
+ let conn_dropped_clone = conn_dropped. clone ( ) ;
967
+
968
+ let routes = warp:: path ( "ws" )
969
+ . and ( warp:: ws ( ) )
970
+ . map ( move |ws : warp:: ws:: Ws | {
971
+ let conn_dropped = conn_dropped_clone. clone ( ) ;
972
+ ws. on_upgrade ( move |ws| async move {
973
+ let ( mut ws_tx, mut ws_rx) = ws. split ( ) ;
974
+
975
+ while let Some ( res) = ws_rx. next ( ) . await {
976
+ let message = res. unwrap ( ) ;
977
+ tracing:: debug!( "ws recv message: {message:?}" ) ;
978
+
979
+ if message. is_text ( ) {
980
+ let text = message. to_str ( ) . unwrap ( ) ;
981
+ let req: Value = serde_json:: from_str ( text) . unwrap ( ) ;
982
+ let req_id = req
983
+ . get ( "args" )
984
+ . and_then ( |v| v. get ( "req_id" ) )
985
+ . and_then ( Value :: as_u64)
986
+ . unwrap_or ( 0 ) ;
987
+
988
+ if text. contains ( "version" ) {
989
+ let data = json ! ( {
990
+ "code" : 0 ,
991
+ "message" : "message" ,
992
+ "action" : "version" ,
993
+ "req_id" : req_id,
994
+ "version" : "3.0"
995
+ } ) ;
996
+ let _ = ws_tx. send ( Message :: text ( data. to_string ( ) ) ) . await ;
997
+ } else if text. contains ( "conn" ) {
998
+ let data = json ! ( {
999
+ "code" : 0 ,
1000
+ "message" : "message" ,
1001
+ "action" : "conn" ,
1002
+ "req_id" : req_id
1003
+ } ) ;
1004
+ let _ = ws_tx. send ( Message :: text ( data. to_string ( ) ) ) . await ;
1005
+ }
1006
+ } else if message. is_close ( ) {
1007
+ tracing:: info!( "received close message from client" ) ;
1008
+ conn_dropped. store ( true , Ordering :: Relaxed ) ;
1009
+ break ;
1010
+ }
1011
+ }
1012
+ } )
1013
+ } ) ;
1014
+
1015
+ let ( shutdown_tx, shutdown_rx) = tokio:: sync:: oneshot:: channel ( ) ;
1016
+ let ( _, server) =
1017
+ warp:: serve ( routes) . bind_with_graceful_shutdown ( ( [ 127 , 0 , 0 , 1 ] , 9984 ) , async move {
1018
+ let _ = shutdown_rx. await ;
1019
+ } ) ;
1020
+
1021
+ tokio:: spawn ( server) ;
1022
+ tokio:: time:: sleep ( Duration :: from_millis ( 100 ) ) . await ;
1023
+
1024
+ {
1025
+ let _taos = TaosBuilder :: from_dsn ( "ws://127.0.0.1:9984" ) ?
1026
+ . build ( )
1027
+ . await ?;
1028
+ }
1029
+
1030
+ tokio:: time:: sleep ( Duration :: from_millis ( 500 ) ) . await ;
1031
+ assert ! ( conn_dropped. load( Ordering :: Relaxed ) ) ;
1032
+ let _ = shutdown_tx. send ( ( ) ) ;
1033
+ Ok ( ( ) )
1034
+ }
947
1035
}
0 commit comments