Skip to content

Commit ea3805e

Browse files
committed
wip: migrate to thiserror
1 parent b33297c commit ea3805e

File tree

17 files changed

+827
-808
lines changed

17 files changed

+827
-808
lines changed

Cargo.lock

Lines changed: 21 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ serde-toml-merge = { version = "0.3.8"}
5050
jwt = { version = "0.16.0", features = ["openssl"] }
5151
openssl = { version = "0.10.71"}
5252
iota = { version = "0.2.3" }
53-
53+
thiserror = "2.0"
5454

5555
[replace]
5656
'deadpool:0.10.0' = { path = 'patches/deadpool' }

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
.DEFAULT_GOAL := build
2+
.PHONY: build install test
23

34
build:
45
cargo build --release
@@ -9,4 +10,3 @@ install: build
910

1011
test:
1112
cargo test
12-
./tests/tests.sh

src/admin.rs

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@ use log::{debug, error, info};
55
use nix::sys::signal::{self, Signal};
66
use nix::unistd::Pid;
77
use std::collections::HashMap;
8+
use std::marker::Unpin;
89
/// Admin database.
910
use std::sync::atomic::Ordering;
11+
use tokio::io::AsyncWrite;
1012
use tokio::time::Instant;
1113

1214
use crate::config::{get_config, reload_config, VERSION};
13-
use crate::errors::Error;
15+
use crate::errors::{Error, ProtocolSyncError, ServerError};
1416
use crate::messages::*;
1517
use crate::pool::get_all_pools;
1618
use crate::pool::ClientServerMap;
1719
use crate::stats::client::{CLIENT_STATE_ACTIVE, CLIENT_STATE_IDLE};
18-
#[cfg(target_os = "linux")]
19-
use crate::stats::get_socket_states_count;
2020
use crate::stats::server::{SERVER_STATE_ACTIVE, SERVER_STATE_IDLE};
2121
use crate::stats::{
2222
get_client_stats, get_server_stats, CANCEL_CONNECTION_COUNTER, PLAIN_CONNECTION_COUNTER,
@@ -42,15 +42,15 @@ pub async fn handle_admin<T>(
4242
client_server_map: ClientServerMap,
4343
) -> Result<(), Error>
4444
where
45-
T: tokio::io::AsyncWrite + std::marker::Unpin,
45+
T: AsyncWrite + Unpin,
4646
{
47-
let code = query.get_u8() as char;
48-
49-
if code != 'Q' {
50-
return Err(Error::ProtocolSyncError(format!(
51-
"Invalid code, expected 'Q' but got '{}'",
52-
code
53-
)));
47+
let code = query.get_u8();
48+
if code != b'Q' {
49+
return Err(ProtocolSyncError::InvalidCode {
50+
expected: b'Q',
51+
actual: code,
52+
}
53+
.into());
5454
}
5555

5656
let len = query.get_i32() as usize;
@@ -110,7 +110,7 @@ where
110110
/// Column-oriented statistics.
111111
async fn show_lists<T>(stream: &mut T) -> Result<(), Error>
112112
where
113-
T: tokio::io::AsyncWrite + std::marker::Unpin,
113+
T: AsyncWrite + Unpin,
114114
{
115115
let client_stats = get_client_stats();
116116
let server_stats = get_server_stats();
@@ -206,13 +206,13 @@ where
206206
res.put_i32(5);
207207
res.put_u8(b'I');
208208

209-
write_all_half(stream, &res).await
209+
Ok(write_all_half(stream, &res).await?)
210210
}
211211

212212
/// Show PgDoorman version.
213213
async fn show_version<T>(stream: &mut T) -> Result<(), Error>
214214
where
215-
T: tokio::io::AsyncWrite + std::marker::Unpin,
215+
T: AsyncWrite + Unpin,
216216
{
217217
let mut res = BytesMut::new();
218218

@@ -224,13 +224,13 @@ where
224224
res.put_i32(5);
225225
res.put_u8(b'I');
226226

227-
write_all_half(stream, &res).await
227+
Ok(write_all_half(stream, &res).await?)
228228
}
229229

230230
/// Show utilization of connection pools for each pool.
231231
async fn show_pools<T>(stream: &mut T) -> Result<(), Error>
232232
where
233-
T: tokio::io::AsyncWrite + std::marker::Unpin,
233+
T: AsyncWrite + Unpin,
234234
{
235235
let pool_lookup = PoolStats::construct_pool_lookup();
236236
let mut res = BytesMut::new();
@@ -245,13 +245,13 @@ where
245245
res.put_i32(5);
246246
res.put_u8(b'I');
247247

248-
write_all_half(stream, &res).await
248+
Ok(write_all_half(stream, &res).await?)
249249
}
250250

251251
/// Show extended utilization of connection pools for each pool.
252252
async fn show_pools_extended<T>(stream: &mut T) -> Result<(), Error>
253253
where
254-
T: tokio::io::AsyncWrite + std::marker::Unpin,
254+
T: AsyncWrite + Unpin,
255255
{
256256
let pool_lookup = PoolStats::construct_pool_lookup();
257257
let mut res = BytesMut::new();
@@ -268,13 +268,13 @@ where
268268
res.put_i32(5);
269269
res.put_u8(b'I');
270270

271-
write_all_half(stream, &res).await
271+
Ok(write_all_half(stream, &res).await?)
272272
}
273273

274274
/// Show all available options.
275275
async fn show_help<T>(stream: &mut T) -> Result<(), Error>
276276
where
277-
T: tokio::io::AsyncWrite + std::marker::Unpin,
277+
T: AsyncWrite + Unpin,
278278
{
279279
let mut res = BytesMut::new();
280280

@@ -307,13 +307,13 @@ where
307307
res.put_i32(5);
308308
res.put_u8(b'I');
309309

310-
write_all_half(stream, &res).await
310+
Ok(write_all_half(stream, &res).await?)
311311
}
312312

313313
/// Show databases.
314314
async fn show_databases<T>(stream: &mut T) -> Result<(), Error>
315315
where
316-
T: tokio::io::AsyncWrite + std::marker::Unpin,
316+
T: AsyncWrite + Unpin,
317317
{
318318
// Columns
319319
let columns = vec![
@@ -361,22 +361,22 @@ where
361361
res.put_i32(5);
362362
res.put_u8(b'I');
363363

364-
write_all_half(stream, &res).await
364+
Ok(write_all_half(stream, &res).await?)
365365
}
366366

367367
/// Ignore any SET commands the client sends.
368368
/// This is common initialization done by ORMs.
369369
async fn ignore_set<T>(stream: &mut T) -> Result<(), Error>
370370
where
371-
T: tokio::io::AsyncWrite + std::marker::Unpin,
371+
T: AsyncWrite + Unpin,
372372
{
373373
custom_protocol_response_ok(stream, "SET").await
374374
}
375375

376376
/// Reload the configuration file without restarting the process.
377377
async fn reload<T>(stream: &mut T, client_server_map: ClientServerMap) -> Result<(), Error>
378378
where
379-
T: tokio::io::AsyncWrite + std::marker::Unpin,
379+
T: AsyncWrite + Unpin,
380380
{
381381
info!("Reloading config");
382382

@@ -393,13 +393,13 @@ where
393393
res.put_i32(5);
394394
res.put_u8(b'I');
395395

396-
write_all_half(stream, &res).await
396+
Ok(write_all_half(stream, &res).await?)
397397
}
398398

399399
/// Shows current configuration.
400400
async fn show_config<T>(stream: &mut T) -> Result<(), Error>
401401
where
402-
T: tokio::io::AsyncWrite + std::marker::Unpin,
402+
T: AsyncWrite + Unpin,
403403
{
404404
let config = &get_config();
405405
let config: HashMap<String, String> = config.into();
@@ -439,13 +439,13 @@ where
439439
res.put_i32(5);
440440
res.put_u8(b'I');
441441

442-
write_all_half(stream, &res).await
442+
Ok(write_all_half(stream, &res).await?)
443443
}
444444

445445
/// Show stats.
446446
async fn show_stats<T>(stream: &mut T) -> Result<(), Error>
447447
where
448-
T: tokio::io::AsyncWrite + std::marker::Unpin,
448+
T: AsyncWrite + Unpin,
449449
{
450450
let pool_lookup = PoolStats::construct_pool_lookup();
451451
let mut res = BytesMut::new();
@@ -461,13 +461,13 @@ where
461461
res.put_i32(5);
462462
res.put_u8(b'I');
463463

464-
write_all_half(stream, &res).await
464+
Ok(write_all_half(stream, &res).await?)
465465
}
466466

467467
/// Show currently connected clients
468468
async fn show_clients<T>(stream: &mut T) -> Result<(), Error>
469469
where
470-
T: tokio::io::AsyncWrite + std::marker::Unpin,
470+
T: AsyncWrite + Unpin,
471471
{
472472
let columns = vec![
473473
("client_id", DataType::Text),
@@ -517,12 +517,12 @@ where
517517
res.put_i32(5);
518518
res.put_u8(b'I');
519519

520-
write_all_half(stream, &res).await
520+
Ok(write_all_half(stream, &res).await?)
521521
}
522522

523523
async fn show_connections<T>(stream: &mut T) -> Result<(), Error>
524524
where
525-
T: tokio::io::AsyncWrite + std::marker::Unpin,
525+
T: AsyncWrite + Unpin,
526526
{
527527
let columns = vec![
528528
("total", DataType::Numeric),
@@ -556,12 +556,13 @@ where
556556
res.put_i32(5);
557557
res.put_u8(b'I');
558558

559-
write_all_half(stream, &res).await
559+
Ok(write_all_half(stream, &res).await?)
560560
}
561+
561562
/// Show currently connected servers
562563
async fn show_servers<T>(stream: &mut T) -> Result<(), Error>
563564
where
564-
T: tokio::io::AsyncWrite + std::marker::Unpin,
565+
T: AsyncWrite + Unpin,
565566
{
566567
let columns = vec![
567568
("server_id", DataType::Text),
@@ -627,13 +628,13 @@ where
627628
res.put_i32(5);
628629
res.put_u8(b'I');
629630

630-
write_all_half(stream, &res).await
631+
Ok(write_all_half(stream, &res).await?)
631632
}
632633

633634
/// Send response packets for shutdown.
634635
async fn shutdown<T>(stream: &mut T) -> Result<(), Error>
635636
where
636-
T: tokio::io::AsyncWrite + std::marker::Unpin,
637+
T: AsyncWrite + Unpin,
637638
{
638639
let mut res = BytesMut::new();
639640

@@ -655,13 +656,13 @@ where
655656
res.put_i32(5);
656657
res.put_u8(b'I');
657658

658-
write_all_half(stream, &res).await
659+
Ok(write_all_half(stream, &res).await?)
659660
}
660661

661662
/// Show Users.
662663
async fn show_users<T>(stream: &mut T) -> Result<(), Error>
663664
where
664-
T: tokio::io::AsyncWrite + std::marker::Unpin,
665+
T: AsyncWrite + Unpin,
665666
{
666667
let mut res = BytesMut::new();
667668

@@ -684,20 +685,19 @@ where
684685
res.put_i32(5);
685686
res.put_u8(b'I');
686687

687-
write_all_half(stream, &res).await
688+
Ok(write_all_half(stream, &res).await?)
688689
}
689690

690691
#[cfg(target_os = "linux")]
691692
async fn show_sockets<T>(stream: &mut T) -> Result<(), Error>
692693
where
693-
T: tokio::io::AsyncWrite + std::marker::Unpin,
694+
T: AsyncWrite + Unpin,
694695
{
696+
use crate::stats::get_socket_states_count;
697+
695698
let mut res = BytesMut::new();
696699

697-
let sockets_info = match get_socket_states_count(std::process::id()) {
698-
Ok(info) => info,
699-
Err(_) => return Err(Error::ServerError),
700-
};
700+
let sockets_info = get_socket_states_count(std::process::id()).map_err(ServerError::from)?;
701701

702702
res.put(row_description(&vec![
703703
// tcp
@@ -747,5 +747,5 @@ where
747747
res.put_i32(5);
748748
res.put_u8(b'I');
749749

750-
write_all_half(stream, &res).await
750+
Ok(write_all_half(stream, &res).await?)
751751
}

src/auth.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
use std::fmt::{self, Display};
2+
3+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
4+
pub enum AuthMethod {
5+
Sasl,
6+
ClearPassword,
7+
Jwt,
8+
Md5,
9+
}
10+
11+
impl Display for AuthMethod {
12+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13+
f.write_str(match self {
14+
Self::Sasl => "SASL",
15+
Self::ClearPassword => "clear password",
16+
Self::Jwt => "JWT",
17+
Self::Md5 => "MD5-encrypted password",
18+
})
19+
}
20+
}

0 commit comments

Comments
 (0)