Skip to content

Commit a862c55

Browse files
committed
tokio: allow binding to a runtime directly
1 parent e3dbdd3 commit a862c55

File tree

5 files changed

+83
-56
lines changed

5 files changed

+83
-56
lines changed

README.md

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,29 +25,26 @@
2525

2626
```rust
2727
use async_rs::{Runtime, TokioRuntime, traits::*};
28-
use std::{io, sync::Arc, time::Duration};
28+
use std::{io, time::Duration};
2929

30-
async fn get_a(rt: Arc<TokioRuntime>) -> io::Result<u32> {
31-
rt.clone()
32-
.spawn_blocking(move || rt.block_on(async { Ok(12) }))
33-
.await
30+
async fn get_a(rt: &TokioRuntime) -> io::Result<u32> {
31+
rt.spawn_blocking(|| Ok(12)).await
3432
}
3533

36-
async fn get_b(rt: Arc<TokioRuntime>) -> io::Result<u32> {
34+
async fn get_b(rt: &TokioRuntime) -> io::Result<u32> {
3735
rt.spawn(async { Ok(30) }).await
3836
}
3937

40-
async fn tokio_main() -> io::Result<()> {
41-
let rt = Arc::new(Runtime::tokio());
42-
let a = get_a(rt.clone()).await?;
43-
let b = get_b(rt.clone()).await?;
38+
async fn tokio_main(rt: &TokioRuntime) -> io::Result<()> {
39+
let a = get_a(&rt).await?;
40+
let b = get_b(&rt).await?;
4441
rt.sleep(Duration::from_millis(500)).await;
4542
assert_eq!(a + b, 42);
4643
Ok(())
4744
}
4845

49-
#[tokio::main]
50-
async fn main() -> io::Result<()> {
51-
tokio_main().await
46+
fn main() -> io::Result<()> {
47+
let rt = Runtime::tokio()?;
48+
rt.block_on(tokio_main(&rt))
5249
}
5350
```

examples/send-recv.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ fn send(mut stream: impl AsyncRead + AsyncWrite + Unpin) -> io::Result<()> {
2828
}
2929
}
3030

31-
async fn tokio_main() -> io::Result<()> {
32-
let rt = Runtime::tokio();
31+
async fn tokio_main(rt: &TokioRuntime) -> io::Result<()> {
3332
let listener = listener(&rt).await?;
3433
let sender = sender(&rt).await?;
3534
let stream = rt
@@ -47,7 +46,7 @@ async fn tokio_main() -> io::Result<()> {
4746
Ok(())
4847
}
4948

50-
#[tokio::main]
51-
async fn main() -> io::Result<()> {
52-
tokio_main().await
49+
fn main() -> io::Result<()> {
50+
let rt = Runtime::tokio()?;
51+
rt.block_on(tokio_main(&rt))
5352
}

examples/tokio.rs

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,29 @@
11
use async_rs::{Runtime, TokioRuntime, traits::*};
2-
use std::{io, sync::Arc, time::Duration};
2+
use std::{io, time::Duration};
33

4-
async fn get_a(rt: Arc<TokioRuntime>) -> io::Result<u32> {
5-
rt.clone()
6-
.spawn_blocking(move || rt.block_on(async { Ok(12) }))
7-
.await
4+
async fn get_a(rt: &TokioRuntime) -> io::Result<u32> {
5+
rt.spawn_blocking(|| Ok(12)).await
86
}
97

10-
async fn get_b(rt: Arc<TokioRuntime>) -> io::Result<u32> {
8+
async fn get_b(rt: &TokioRuntime) -> io::Result<u32> {
119
rt.spawn(async { Ok(30) }).await
1210
}
1311

14-
async fn tokio_main() -> io::Result<()> {
15-
let rt = Arc::new(Runtime::tokio());
16-
let a = get_a(rt.clone()).await?;
17-
let b = get_b(rt.clone()).await?;
12+
async fn tokio_main(rt: &TokioRuntime) -> io::Result<()> {
13+
let a = get_a(&rt).await?;
14+
let b = get_b(&rt).await?;
1815
rt.sleep(Duration::from_millis(500)).await;
1916
assert_eq!(a + b, 42);
2017
Ok(())
2118
}
2219

23-
#[tokio::main]
24-
async fn main() -> io::Result<()> {
25-
tokio_main().await
20+
fn main() -> io::Result<()> {
21+
let rt = Runtime::tokio()?;
22+
rt.block_on(tokio_main(&rt))
2623
}
2724

28-
#[tokio::test]
29-
async fn tokio() -> io::Result<()> {
30-
tokio_main().await
25+
#[test]
26+
fn tokio() -> io::Result<()> {
27+
let rt = Runtime::tokio()?;
28+
rt.block_on(tokio_main(&rt))
3129
}

src/implementors/tokio.rs

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,43 @@ use std::{
1717
task::{Context, Poll},
1818
time::{Duration, Instant},
1919
};
20-
use tokio::{net::TcpStream, runtime::Handle};
20+
use tokio::{
21+
net::TcpStream,
22+
runtime::{EnterGuard, Handle, Runtime as TokioRT},
23+
};
2124
use tokio_stream::{StreamExt, wrappers::IntervalStream};
2225
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
2326

2427
/// Type alias for the tokio runtime
2528
pub type TokioRuntime = Runtime<Tokio>;
2629

2730
impl TokioRuntime {
31+
/// Create a new TokioRuntime and bind it to this tokio runtime.
32+
pub fn tokio() -> io::Result<Self> {
33+
Ok(Self::tokio_with_runtime(TokioRT::new()?))
34+
}
35+
2836
/// Create a new TokioRuntime and bind it to the current tokio runtime by default.
29-
pub fn tokio() -> Self {
37+
pub fn tokio_current() -> Self {
3038
Self::new(Tokio::current())
3139
}
3240

3341
/// Create a new TokioRuntime and bind it to the tokio runtime associated to this handle by default.
3442
pub fn tokio_with_handle(handle: Handle) -> Self {
3543
Self::new(Tokio::default().with_handle(handle))
3644
}
45+
46+
/// Create a new TokioRuntime and bind it to this tokio runtime.
47+
pub fn tokio_with_runtime(runtime: TokioRT) -> Self {
48+
Self::new(Tokio::default().with_runtime(runtime))
49+
}
3750
}
3851

3952
/// Dummy object implementing async common interfaces on top of tokio
40-
#[derive(Default, Debug, Clone)]
53+
#[derive(Default, Debug)]
4154
pub struct Tokio {
4255
handle: Option<Handle>,
56+
runtime: Option<TokioRT>,
4357
}
4458

4559
impl Tokio {
@@ -49,13 +63,32 @@ impl Tokio {
4963
self
5064
}
5165

66+
/// Bind to the tokio Runtime associated to this handle by default.
67+
pub fn with_runtime(mut self, runtime: TokioRT) -> Self {
68+
let handle = runtime.handle().clone();
69+
self.runtime = Some(runtime);
70+
self.with_handle(handle)
71+
}
72+
5273
/// Bind to the current tokio Runtime by default.
5374
pub fn current() -> Self {
5475
Self::default().with_handle(Handle::current())
5576
}
5677

57-
pub(crate) fn handle(&self) -> Option<Handle> {
58-
Handle::try_current().ok().or_else(|| self.handle.clone())
78+
fn handle(&self) -> Option<Handle> {
79+
self.runtime
80+
.as_ref()
81+
.map(|r| r.handle().clone())
82+
.or_else(|| Handle::try_current().ok())
83+
.or_else(|| self.handle.clone())
84+
}
85+
86+
fn enter(&self) -> Option<EnterGuard<'_>> {
87+
self.runtime
88+
.as_ref()
89+
.map(TokioRT::handle)
90+
.or(self.handle.as_ref())
91+
.map(Handle::enter)
5992
}
6093
}
6194

@@ -65,7 +98,9 @@ impl RuntimeKit for Tokio {}
6598

6699
impl Executor for Tokio {
67100
fn block_on<T, F: Future<Output = T>>(&self, f: F) -> T {
68-
if let Some(handle) = self.handle() {
101+
if let Some(runtime) = self.runtime.as_ref() {
102+
runtime.block_on(f)
103+
} else if let Some(handle) = self.handle() {
69104
handle.block_on(f)
70105
} else {
71106
Handle::current().block_on(f)
@@ -123,7 +158,7 @@ impl Reactor for Tokio {
123158
&self,
124159
socket: H,
125160
) -> io::Result<impl AsyncRead + AsyncWrite + Send + Unpin + 'static> {
126-
let _enter = self.handle().as_ref().map(|handle| handle.enter());
161+
let _enter = self.enter();
127162
cfg_if! {
128163
if #[cfg(unix)] {
129164
Ok(unix::AsyncFdWrapper(
@@ -138,19 +173,20 @@ impl Reactor for Tokio {
138173
}
139174

140175
fn sleep(&self, dur: Duration) -> impl Future<Output = ()> + Send + 'static {
176+
let _enter = self.enter();
141177
tokio::time::sleep(dur)
142178
}
143179

144180
fn interval(&self, dur: Duration) -> impl Stream<Item = Instant> + Send + 'static {
145-
let _enter = self.handle().as_ref().map(|handle| handle.enter());
181+
let _enter = self.enter();
146182
IntervalStream::new(tokio::time::interval(dur)).map(tokio::time::Instant::into_std)
147183
}
148184

149185
fn tcp_connect_addr(
150186
&self,
151187
addr: SocketAddr,
152188
) -> impl Future<Output = io::Result<Self::TcpStream>> + Send + 'static {
153-
let _enter = self.handle().as_ref().map(|handle| handle.enter());
189+
let _enter = self.enter();
154190
async move { Ok(TcpStream::connect(addr).await?.compat()) }
155191
}
156192
}

src/lib.rs

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,30 +14,27 @@
1414
//!
1515
//! ```rust
1616
//! use async_rs::{Runtime, TokioRuntime, traits::*};
17-
//! use std::{io, sync::Arc, time::Duration};
17+
//! use std::{io, time::Duration};
1818
//!
19-
//! async fn get_a(rt: Arc<TokioRuntime>) -> io::Result<u32> {
20-
//! rt.clone()
21-
//! .spawn_blocking(move || rt.block_on(async { Ok(12) }))
22-
//! .await
19+
//! async fn get_a(rt: &TokioRuntime) -> io::Result<u32> {
20+
//! rt.spawn_blocking(|| Ok(12)).await
2321
//! }
2422
//!
25-
//! async fn get_b(rt: Arc<TokioRuntime>) -> io::Result<u32> {
23+
//! async fn get_b(rt: &TokioRuntime) -> io::Result<u32> {
2624
//! rt.spawn(async { Ok(30) }).await
2725
//! }
2826
//!
29-
//! async fn tokio_main() -> io::Result<()> {
30-
//! let rt = Arc::new(Runtime::tokio());
31-
//! let a = get_a(rt.clone()).await?;
32-
//! let b = get_b(rt.clone()).await?;
27+
//! async fn tokio_main(rt: &TokioRuntime) -> io::Result<()> {
28+
//! let a = get_a(&rt).await?;
29+
//! let b = get_b(&rt).await?;
3330
//! rt.sleep(Duration::from_millis(500)).await;
3431
//! assert_eq!(a + b, 42);
3532
//! Ok(())
3633
//! }
3734
//!
38-
//! #[tokio::main]
39-
//! async fn main() -> io::Result<()> {
40-
//! tokio_main().await
35+
//! fn main() -> io::Result<()> {
36+
//! let rt = Runtime::tokio()?;
37+
//! rt.block_on(tokio_main(&rt))
4138
//! }
4239
//! ```
4340

0 commit comments

Comments
 (0)