Skip to content

Commit 6fe731c

Browse files
committed
explicitely require Unpin stream
instead of doing unsafe operations
1 parent 5164e39 commit 6fe731c

File tree

1 file changed

+16
-19
lines changed

1 file changed

+16
-19
lines changed

src/lib.rs

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,37 +20,34 @@ use std::{
2020
#[cfg(test)]
2121
mod test;
2222

23-
struct StreamWrapper<S> {
23+
struct StreamWrapper<S: Unpin> {
2424
stream: S,
2525
waker: Waker,
2626
}
2727

2828
impl<S> fmt::Debug for StreamWrapper<S>
2929
where
30-
S: fmt::Debug,
30+
S: fmt::Debug + Unpin,
3131
{
3232
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
3333
self.stream.fmt(fmt)
3434
}
3535
}
3636

37-
impl<S> StreamWrapper<S> {
38-
/// # Safety
39-
///
40-
/// The wrapper must be pinned in memory.
41-
unsafe fn parts(&mut self) -> (Pin<&mut S>, Context<'_>) {
42-
let stream = unsafe { Pin::new_unchecked(&mut self.stream) };
37+
impl<S: Unpin> StreamWrapper<S> {
38+
fn parts(&mut self) -> (Pin<&mut S>, Context<'_>) {
39+
let stream = Pin::new(&mut self.stream);
4340
let context = Context::from_waker(&self.waker);
4441
(stream, context)
4542
}
4643
}
4744

4845
impl<S> Read for StreamWrapper<S>
4946
where
50-
S: AsyncRead,
47+
S: AsyncRead + Unpin,
5148
{
5249
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
53-
let (stream, mut cx) = unsafe { self.parts() };
50+
let (stream, mut cx) = self.parts();
5451
match stream.poll_read(&mut cx, buf)? {
5552
Poll::Ready(nread) => Ok(nread),
5653
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
@@ -60,18 +57,18 @@ where
6057

6158
impl<S> Write for StreamWrapper<S>
6259
where
63-
S: AsyncWrite,
60+
S: AsyncWrite + Unpin,
6461
{
6562
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
66-
let (stream, mut cx) = unsafe { self.parts() };
63+
let (stream, mut cx) = self.parts();
6764
match stream.poll_write(&mut cx, buf) {
6865
Poll::Ready(r) => r,
6966
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
7067
}
7168
}
7269

7370
fn flush(&mut self) -> io::Result<()> {
74-
let (stream, mut cx) = unsafe { self.parts() };
71+
let (stream, mut cx) = self.parts();
7572
match stream.poll_flush(&mut cx) {
7673
Poll::Ready(r) => r,
7774
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
@@ -99,11 +96,11 @@ fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
9996

10097
/// An asynchronous version of [`openssl::ssl::SslStream`].
10198
#[derive(Debug)]
102-
pub struct SslStream<S>(ssl::SslStream<StreamWrapper<S>>);
99+
pub struct SslStream<S: Unpin>(ssl::SslStream<StreamWrapper<S>>);
103100

104101
impl<S> SslStream<S>
105102
where
106-
S: AsyncRead + AsyncWrite,
103+
S: AsyncRead + AsyncWrite + Unpin,
107104
{
108105
/// Like [`SslStream::new`](ssl::SslStream::new).
109106
pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
@@ -206,7 +203,7 @@ where
206203
}
207204
}
208205

209-
impl<S> SslStream<S> {
206+
impl<S: Unpin> SslStream<S> {
210207
/// Returns a shared reference to the `Ssl` object associated with this stream.
211208
pub fn ssl(&self) -> &SslRef {
212209
self.0.ssl()
@@ -224,7 +221,7 @@ impl<S> SslStream<S> {
224221

225222
/// Returns a pinned mutable reference to the underlying stream.
226223
pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
227-
unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) }
224+
Pin::new(&mut self.get_mut().0.get_mut().stream)
228225
}
229226

230227
fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
@@ -239,7 +236,7 @@ impl<S> SslStream<S> {
239236

240237
impl<S> AsyncRead for SslStream<S>
241238
where
242-
S: AsyncRead + AsyncWrite,
239+
S: AsyncRead + AsyncWrite + Unpin,
243240
{
244241
fn poll_read(
245242
self: Pin<&mut Self>,
@@ -252,7 +249,7 @@ where
252249

253250
impl<S> AsyncWrite for SslStream<S>
254251
where
255-
S: AsyncRead + AsyncWrite,
252+
S: AsyncRead + AsyncWrite + Unpin,
256253
{
257254
fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
258255
self.with_context(ctx, |s| cvt(s.write(buf)))

0 commit comments

Comments
 (0)