Skip to content

Commit a0f628f

Browse files
committed
avoid pointers arythmetics
We only need to carry on the Waker. Start with the noop one so we can consider we always have a Waker. Don't clear the Waker on exit, we overide it before next fn anyways.
1 parent 1763a26 commit a0f628f

File tree

1 file changed

+22
-19
lines changed

1 file changed

+22
-19
lines changed

src/lib.rs

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ use std::{
1414
fmt, future,
1515
io::{self, Read, Write},
1616
pin::Pin,
17-
task::{Context, Poll},
17+
task::{Context, Poll, Waker},
1818
};
1919

2020
#[cfg(test)]
2121
mod test;
2222

2323
struct StreamWrapper<S> {
2424
stream: S,
25-
context: usize,
25+
waker: Waker,
2626
}
2727

2828
impl<S> fmt::Debug for StreamWrapper<S>
@@ -37,12 +37,10 @@ where
3737
impl<S> StreamWrapper<S> {
3838
/// # Safety
3939
///
40-
/// Must be called with `context` set to a valid pointer to a live `Context` object, and the
41-
/// wrapper must be pinned in memory.
42-
unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
43-
debug_assert_ne!(self.context, 0);
44-
let stream = Pin::new_unchecked(&mut self.stream);
45-
let context = &mut *(self.context as *mut _);
40+
/// The wrapper must be pinned in memory.
41+
unsafe fn parts(&mut self) -> (Pin<&mut S>, Context<'_>) {
42+
let stream = unsafe { Pin::new_unchecked(&mut self.stream) };
43+
let context = Context::from_waker(&self.waker);
4644
(stream, context)
4745
}
4846
}
@@ -52,8 +50,8 @@ where
5250
S: AsyncRead,
5351
{
5452
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
55-
let (stream, cx) = unsafe { self.parts() };
56-
match stream.poll_read(cx, buf)? {
53+
let (stream, mut cx) = unsafe { self.parts() };
54+
match stream.poll_read(&mut cx, buf)? {
5755
Poll::Ready(nread) => Ok(nread),
5856
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
5957
}
@@ -65,16 +63,16 @@ where
6563
S: AsyncWrite,
6664
{
6765
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
68-
let (stream, cx) = unsafe { self.parts() };
69-
match stream.poll_write(cx, buf) {
66+
let (stream, mut cx) = unsafe { self.parts() };
67+
match stream.poll_write(&mut cx, buf) {
7068
Poll::Ready(r) => r,
7169
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
7270
}
7371
}
7472

7573
fn flush(&mut self) -> io::Result<()> {
76-
let (stream, cx) = unsafe { self.parts() };
77-
match stream.poll_flush(cx) {
74+
let (stream, mut cx) = unsafe { self.parts() };
75+
match stream.poll_flush(&mut cx) {
7876
Poll::Ready(r) => r,
7977
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
8078
}
@@ -109,7 +107,14 @@ where
109107
{
110108
/// Like [`SslStream::new`](ssl::SslStream::new).
111109
pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
112-
ssl::SslStream::new(ssl, StreamWrapper { stream, context: 0 }).map(SslStream)
110+
ssl::SslStream::new(
111+
ssl,
112+
StreamWrapper {
113+
stream,
114+
waker: Waker::noop().clone(),
115+
},
116+
)
117+
.map(SslStream)
113118
}
114119

115120
/// Like [`SslStream::connect`](ssl::SslStream::connect).
@@ -227,10 +232,8 @@ impl<S> SslStream<S> {
227232
F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
228233
{
229234
let this = unsafe { self.get_unchecked_mut() };
230-
this.0.get_mut().context = ctx as *mut _ as usize;
231-
let r = f(&mut this.0);
232-
this.0.get_mut().context = 0;
233-
r
235+
this.0.get_mut().waker = ctx.waker().clone();
236+
f(&mut this.0)
234237
}
235238
}
236239

0 commit comments

Comments
 (0)