Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion taos-ws/src/query/asyn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ impl WsTaos {

tokio::spawn(
super::conn::run(
ws_taos.clone(),
Arc::downgrade(&ws_taos),
builder,
ws_stream,
query_sender,
Expand Down Expand Up @@ -640,6 +640,7 @@ impl AsyncQueryable for WsTaos {
impl Drop for WsTaos {
fn drop(&mut self) {
tracing::trace!("dropping ws connection, conn_id: {}", self.conn_id);
self.set_state(ConnState::Disconnected);
// Send close signal to reader/writer spawned tasks.
let _ = self.close_signal.send(true);
}
Expand Down
110 changes: 99 additions & 11 deletions taos-ws/src/query/conn.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::future::Future;
use std::pin::Pin;
use std::sync::Weak;
use std::time::Duration;
use std::{
collections::HashSet,
Expand Down Expand Up @@ -137,7 +138,7 @@ impl MessageCache {
}

pub(super) async fn run(
ws_taos: Arc<WsTaos>,
ws_taos: Weak<WsTaos>,
builder: Arc<TaosBuilder>,
mut ws_stream: WsStream,
query_sender: WsQuerySender,
Expand Down Expand Up @@ -180,7 +181,9 @@ pub(super) async fn run(
let _ = close_tx.send(true);
if !is_disconnect_error(&err) {
tracing::error!("non-disconnect error detected, cleaning up all pending queries");
ws_taos.set_state(ConnState::Disconnected);
if let Some(taos) = ws_taos.upgrade() {
taos.set_state(ConnState::Disconnected);
}
cleanup_after_disconnect(query_sender.clone());
return;
}
Expand All @@ -189,13 +192,14 @@ pub(super) async fn run(
}
_ = close_reader.changed() => {
tracing::info!("WebSocket received close signal");
ws_taos.set_state(ConnState::Disconnected);
let _ = close_tx.send(true);
return;
}
}

ws_taos.set_state(ConnState::Reconnecting);
if let Some(taos) = ws_taos.upgrade() {
taos.set_state(ConnState::Reconnecting);
}

if let Err(err) = send_handle.await {
tracing::error!("send messages task failed: {err:?}");
Expand All @@ -213,19 +217,24 @@ pub(super) async fn run(
}
Err(err) => {
tracing::error!("WebSocket reconnection failed: {err}");
ws_taos.set_state(ConnState::Disconnected);
if let Some(taos) = ws_taos.upgrade() {
taos.set_state(ConnState::Disconnected);
}
cleanup_after_disconnect(query_sender.clone());
return;
}
};

tracing::info!("WebSocket reconnected successfully");

ws_taos.wait_for_previous_recover_stmt2().await;
let mut stmt2_req_ids = cleanup_stmt2(query_sender.sender.clone(), message_reader.clone());
ws_taos.clone().recover_stmt2().await;
stmt2_req_ids.extend(ws_taos.stmt2_req_ids().await);
cleanup_after_reconnect(query_sender.clone(), cache.clone(), stmt2_req_ids);
if let Some(taos) = ws_taos.upgrade() {
taos.wait_for_previous_recover_stmt2().await;
let mut stmt2_req_ids =
cleanup_stmt2(query_sender.sender.clone(), message_reader.clone());
taos.clone().recover_stmt2().await;
stmt2_req_ids.extend(taos.stmt2_req_ids().await);
cleanup_after_reconnect(query_sender.clone(), cache.clone(), stmt2_req_ids);
}
}
}

Expand Down Expand Up @@ -617,7 +626,7 @@ fn cleanup_stmt2(

#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;

Expand Down Expand Up @@ -944,4 +953,83 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_wstaos_drop_closes_connection() -> anyhow::Result<()> {
let _ = tracing_subscriber::fmt()
.with_file(true)
.with_line_number(true)
.with_max_level(tracing::Level::INFO)
.try_init();

let conn_dropped = Arc::new(AtomicBool::new(false));
let conn_dropped_clone = conn_dropped.clone();

let routes = warp::path("ws")
.and(warp::ws())
.map(move |ws: warp::ws::Ws| {
let conn_dropped = conn_dropped_clone.clone();
ws.on_upgrade(move |ws| async move {
let (mut ws_tx, mut ws_rx) = ws.split();

while let Some(res) = ws_rx.next().await {
let message = res.unwrap();
tracing::debug!("ws recv message: {message:?}");

if message.is_text() {
let text = message.to_str().unwrap();
let req: Value = serde_json::from_str(text).unwrap();
let req_id = req
.get("args")
.and_then(|v| v.get("req_id"))
.and_then(Value::as_u64)
.unwrap_or(0);

if text.contains("version") {
let data = json!({
"code": 0,
"message": "message",
"action": "version",
"req_id": req_id,
"version": "3.0"
});
let _ = ws_tx.send(Message::text(data.to_string())).await;
} else if text.contains("conn") {
let data = json!({
"code": 0,
"message": "message",
"action": "conn",
"req_id": req_id
});
let _ = ws_tx.send(Message::text(data.to_string())).await;
}
} else if message.is_close() {
tracing::info!("received close message from client");
conn_dropped.store(true, Ordering::Relaxed);
break;
}
}
})
});

let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let (_, server) =
warp::serve(routes).bind_with_graceful_shutdown(([127, 0, 0, 1], 9984), async move {
let _ = shutdown_rx.await;
});

tokio::spawn(server);
tokio::time::sleep(Duration::from_millis(100)).await;

{
let _taos = TaosBuilder::from_dsn("ws://127.0.0.1:9984")?
.build()
.await?;
}

tokio::time::sleep(Duration::from_millis(500)).await;
assert!(conn_dropped.load(Ordering::Relaxed));
let _ = shutdown_tx.send(());
Ok(())
}
}