Skip to content

Commit e328446

Browse files
committed
sql with module thread again
1 parent d533f9c commit e328446

File tree

2 files changed

+48
-50
lines changed

2 files changed

+48
-50
lines changed

crates/client-api/src/lib.rs

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -72,54 +72,52 @@ impl Host {
7272
body: String,
7373
) -> axum::response::Result<Vec<SqlStmtResult<ProductValue>>> {
7474
let module_host = self
75-
.host_controller
76-
.get_or_launch_module_host(database, self.replica_id)
75+
.module()
7776
.await
7877
.map_err(|_| (StatusCode::NOT_FOUND, "module not found".to_string()))?;
7978

80-
tracing::info!(sql = body);
81-
// We need a header for query results
82-
let mut header = vec![];
83-
let sql_start = std::time::Instant::now();
84-
let sql_span = tracing::trace_span!("execute_sql", total_duration = tracing::field::Empty).entered();
85-
let db = &module_host.module.replica_ctx().relational_db;
86-
87-
let result = sql::execute::run(
88-
// Returns an empty result set for mutations
89-
&db,
90-
&body,
91-
auth,
92-
Some(&module_host),
93-
auth.caller,
94-
&mut header,
95-
)
96-
.await
97-
.map_err(|e| {
98-
log::warn!("{e}");
99-
if let Some(auth_err) = e.get_auth_error() {
100-
(StatusCode::UNAUTHORIZED, auth_err.to_string())
101-
} else {
102-
(StatusCode::BAD_REQUEST, e.to_string())
103-
}
104-
})?;
105-
106-
let total_duration = sql_start.elapsed();
107-
sql_span.record("total_duration", tracing::field::debug(total_duration));
108-
109-
// Turn the header into a `ProductType`
110-
let schema = header
111-
.into_iter()
112-
.map(|(col_name, col_type)| ProductTypeElement::new(col_type, Some(col_name)))
113-
.collect();
114-
115-
let tx_offset = result.tx_offset;
116-
let durable_offset = db.durable_tx_offset();
117-
let json = vec![SqlStmtResult {
118-
schema,
119-
rows: result.rows,
120-
total_duration_micros: total_duration.as_micros() as u64,
121-
stats: SqlStmtStats::from_metrics(&result.metrics),
122-
}];
79+
let (tx_offset, durable_offset, json) = self
80+
.host_controller
81+
.using_database(database, self.replica_id, move |db| async move {
82+
tracing::info!(sql = body);
83+
let mut header = vec![];
84+
let sql_start = std::time::Instant::now();
85+
let sql_span = tracing::trace_span!("execute_sql", total_duration = tracing::field::Empty,);
86+
let _guard = sql_span.enter();
87+
88+
let result = sql::execute::run(&db, &body, auth.clone(), Some(&module_host), auth.caller, &mut header)
89+
.await
90+
.map_err(|e| {
91+
log::warn!("{e}");
92+
if let Some(auth_err) = e.get_auth_error() {
93+
(StatusCode::UNAUTHORIZED, auth_err.to_string())
94+
} else {
95+
(StatusCode::BAD_REQUEST, e.to_string())
96+
}
97+
})?;
98+
99+
let total_duration = sql_start.elapsed();
100+
drop(_guard);
101+
sql_span.record("total_duration", tracing::field::debug(total_duration));
102+
103+
let schema = header
104+
.into_iter()
105+
.map(|(col_name, col_type)| ProductTypeElement::new(col_type, Some(col_name)))
106+
.collect();
107+
108+
Ok::<_, (StatusCode, String)>((
109+
result.tx_offset,
110+
db.durable_tx_offset(),
111+
vec![SqlStmtResult {
112+
schema,
113+
rows: result.rows,
114+
total_duration_micros: total_duration.as_micros() as u64,
115+
stats: SqlStmtStats::from_metrics(&result.metrics),
116+
}],
117+
))
118+
})
119+
.await
120+
.map_err(log_and_500)??;
123121

124122
if confirmed_read {
125123
if let Some(mut durable_offset) = durable_offset {
@@ -130,6 +128,7 @@ impl Host {
130128

131129
Ok(json)
132130
}
131+
133132
pub async fn update(
134133
&self,
135134
database: Database,
@@ -142,7 +141,6 @@ impl Host {
142141
.await
143142
}
144143
}
145-
146144
/// Parameters for publishing a database.
147145
///
148146
/// See [`ControlStateDelegate::publish_database`].

crates/core/src/host/host_controller.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,10 @@ impl HostController {
342342
/// If the computation `F` panics, the host is removed from this controller,
343343
/// releasing its resources.
344344
#[tracing::instrument(level = "trace", skip_all)]
345-
pub async fn using_database<F, T>(&self, database: Database, replica_id: u64, f: F) -> anyhow::Result<T>
345+
pub async fn using_database<F, Fut, T>(&self, database: Database, replica_id: u64, f: F) -> anyhow::Result<T>
346346
where
347-
F: FnOnce(&RelationalDB) -> T + Send + 'static,
347+
F: FnOnce(Arc<RelationalDB>) -> Fut + Send + 'static,
348+
Fut: std::future::Future<Output = T> + Send + 'static,
348349
T: Send + 'static,
349350
{
350351
trace!("using database {}/{}", database.database_identity, replica_id);
@@ -356,10 +357,9 @@ impl HostController {
356357
});
357358

358359
let db = module.replica_ctx().relational_db.clone();
359-
let result = module.on_module_thread("using_database", move || f(&db)).await?;
360+
let result = module.on_module_thread("using_database", move || f(db)).await?.await;
360361
Ok(result)
361362
}
362-
363363
/// Update the [`ModuleHost`] identified by `replica_id` to the given
364364
/// program.
365365
///

0 commit comments

Comments
 (0)