Skip to content

Commit 7d8905e

Browse files
committed
fix: fix closing ws connection
1 parent 0c9b05a commit 7d8905e

File tree

2 files changed

+101
-12
lines changed

2 files changed

+101
-12
lines changed

taos-ws/src/query/asyn.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ impl WsTaos {
106106

107107
tokio::spawn(
108108
super::conn::run(
109-
ws_taos.clone(),
109+
Arc::downgrade(&ws_taos),
110110
builder,
111111
ws_stream,
112112
query_sender,
@@ -640,6 +640,7 @@ impl AsyncQueryable for WsTaos {
640640
impl Drop for WsTaos {
641641
fn drop(&mut self) {
642642
tracing::trace!("dropping ws connection, conn_id: {}", self.conn_id);
643+
self.set_state(ConnState::Disconnected);
643644
// Send close signal to reader/writer spawned tasks.
644645
let _ = self.close_signal.send(true);
645646
}

taos-ws/src/query/conn.rs

Lines changed: 99 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::future::Future;
22
use std::pin::Pin;
3+
use std::sync::Weak;
34
use std::time::Duration;
45
use std::{
56
collections::HashSet,
@@ -137,7 +138,7 @@ impl MessageCache {
137138
}
138139

139140
pub(super) async fn run(
140-
ws_taos: Arc<WsTaos>,
141+
ws_taos: Weak<WsTaos>,
141142
builder: Arc<TaosBuilder>,
142143
mut ws_stream: WsStream,
143144
query_sender: WsQuerySender,
@@ -180,7 +181,9 @@ pub(super) async fn run(
180181
let _ = close_tx.send(true);
181182
if !is_disconnect_error(&err) {
182183
tracing::error!("non-disconnect error detected, cleaning up all pending queries");
183-
ws_taos.set_state(ConnState::Disconnected);
184+
if let Some(taos) = ws_taos.upgrade() {
185+
taos.set_state(ConnState::Disconnected);
186+
}
184187
cleanup_after_disconnect(query_sender.clone());
185188
return;
186189
}
@@ -189,13 +192,14 @@ pub(super) async fn run(
189192
}
190193
_ = close_reader.changed() => {
191194
tracing::info!("WebSocket received close signal");
192-
ws_taos.set_state(ConnState::Disconnected);
193195
let _ = close_tx.send(true);
194196
return;
195197
}
196198
}
197199

198-
ws_taos.set_state(ConnState::Reconnecting);
200+
if let Some(taos) = ws_taos.upgrade() {
201+
taos.set_state(ConnState::Reconnecting);
202+
}
199203

200204
if let Err(err) = send_handle.await {
201205
tracing::error!("send messages task failed: {err:?}");
@@ -213,19 +217,24 @@ pub(super) async fn run(
213217
}
214218
Err(err) => {
215219
tracing::error!("WebSocket reconnection failed: {err}");
216-
ws_taos.set_state(ConnState::Disconnected);
220+
if let Some(taos) = ws_taos.upgrade() {
221+
taos.set_state(ConnState::Disconnected);
222+
}
217223
cleanup_after_disconnect(query_sender.clone());
218224
return;
219225
}
220226
};
221227

222228
tracing::info!("WebSocket reconnected successfully");
223229

224-
ws_taos.wait_for_previous_recover_stmt2().await;
225-
let mut stmt2_req_ids = cleanup_stmt2(query_sender.sender.clone(), message_reader.clone());
226-
ws_taos.clone().recover_stmt2().await;
227-
stmt2_req_ids.extend(ws_taos.stmt2_req_ids().await);
228-
cleanup_after_reconnect(query_sender.clone(), cache.clone(), stmt2_req_ids);
230+
if let Some(taos) = ws_taos.upgrade() {
231+
taos.wait_for_previous_recover_stmt2().await;
232+
let mut stmt2_req_ids =
233+
cleanup_stmt2(query_sender.sender.clone(), message_reader.clone());
234+
taos.clone().recover_stmt2().await;
235+
stmt2_req_ids.extend(taos.stmt2_req_ids().await);
236+
cleanup_after_reconnect(query_sender.clone(), cache.clone(), stmt2_req_ids);
237+
}
229238
}
230239
}
231240

@@ -617,7 +626,7 @@ fn cleanup_stmt2(
617626

618627
#[cfg(test)]
619628
mod tests {
620-
use std::sync::atomic::{AtomicUsize, Ordering};
629+
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
621630
use std::sync::Arc;
622631
use std::time::Duration;
623632

@@ -944,4 +953,83 @@ mod tests {
944953

945954
Ok(())
946955
}
956+
957+
#[tokio::test]
958+
async fn test_wstaos_drop_closes_connection() -> anyhow::Result<()> {
959+
let _ = tracing_subscriber::fmt()
960+
.with_file(true)
961+
.with_line_number(true)
962+
.with_max_level(tracing::Level::INFO)
963+
.try_init();
964+
965+
let conn_dropped = Arc::new(AtomicBool::new(false));
966+
let conn_dropped_clone = conn_dropped.clone();
967+
968+
let routes = warp::path("ws")
969+
.and(warp::ws())
970+
.map(move |ws: warp::ws::Ws| {
971+
let conn_dropped = conn_dropped_clone.clone();
972+
ws.on_upgrade(move |ws| async move {
973+
let (mut ws_tx, mut ws_rx) = ws.split();
974+
975+
while let Some(res) = ws_rx.next().await {
976+
let message = res.unwrap();
977+
tracing::debug!("ws recv message: {message:?}");
978+
979+
if message.is_text() {
980+
let text = message.to_str().unwrap();
981+
let req: Value = serde_json::from_str(text).unwrap();
982+
let req_id = req
983+
.get("args")
984+
.and_then(|v| v.get("req_id"))
985+
.and_then(Value::as_u64)
986+
.unwrap_or(0);
987+
988+
if text.contains("version") {
989+
let data = json!({
990+
"code": 0,
991+
"message": "message",
992+
"action": "version",
993+
"req_id": req_id,
994+
"version": "3.0"
995+
});
996+
let _ = ws_tx.send(Message::text(data.to_string())).await;
997+
} else if text.contains("conn") {
998+
let data = json!({
999+
"code": 0,
1000+
"message": "message",
1001+
"action": "conn",
1002+
"req_id": req_id
1003+
});
1004+
let _ = ws_tx.send(Message::text(data.to_string())).await;
1005+
}
1006+
} else if message.is_close() {
1007+
tracing::info!("received close message from client");
1008+
conn_dropped.store(true, Ordering::Relaxed);
1009+
break;
1010+
}
1011+
}
1012+
})
1013+
});
1014+
1015+
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
1016+
let (_, server) =
1017+
warp::serve(routes).bind_with_graceful_shutdown(([127, 0, 0, 1], 9984), async move {
1018+
let _ = shutdown_rx.await;
1019+
});
1020+
1021+
tokio::spawn(server);
1022+
tokio::time::sleep(Duration::from_millis(100)).await;
1023+
1024+
{
1025+
let _taos = TaosBuilder::from_dsn("ws://127.0.0.1:9984")?
1026+
.build()
1027+
.await?;
1028+
}
1029+
1030+
tokio::time::sleep(Duration::from_millis(500)).await;
1031+
assert!(conn_dropped.load(Ordering::Relaxed));
1032+
let _ = shutdown_tx.send(());
1033+
Ok(())
1034+
}
9471035
}

0 commit comments

Comments
 (0)