@@ -20,37 +20,34 @@ use std::{
20
20
#[ cfg( test) ]
21
21
mod test;
22
22
23
- struct StreamWrapper < S > {
23
+ struct StreamWrapper < S : Unpin > {
24
24
stream : S ,
25
25
waker : Waker ,
26
26
}
27
27
28
28
impl < S > fmt:: Debug for StreamWrapper < S >
29
29
where
30
- S : fmt:: Debug ,
30
+ S : fmt:: Debug + Unpin ,
31
31
{
32
32
fn fmt ( & self , fmt : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
33
33
self . stream . fmt ( fmt)
34
34
}
35
35
}
36
36
37
- impl < S > StreamWrapper < S > {
38
- /// # Safety
39
- ///
40
- /// The wrapper must be pinned in memory.
41
- unsafe fn parts ( & mut self ) -> ( Pin < & mut S > , Context < ' _ > ) {
42
- let stream = unsafe { Pin :: new_unchecked ( & mut self . stream ) } ;
37
+ impl < S : Unpin > StreamWrapper < S > {
38
+ fn parts ( & mut self ) -> ( Pin < & mut S > , Context < ' _ > ) {
39
+ let stream = Pin :: new ( & mut self . stream ) ;
43
40
let context = Context :: from_waker ( & self . waker ) ;
44
41
( stream, context)
45
42
}
46
43
}
47
44
48
45
impl < S > Read for StreamWrapper < S >
49
46
where
50
- S : AsyncRead ,
47
+ S : AsyncRead + Unpin ,
51
48
{
52
49
fn read ( & mut self , buf : & mut [ u8 ] ) -> io:: Result < usize > {
53
- let ( stream, mut cx) = unsafe { self . parts ( ) } ;
50
+ let ( stream, mut cx) = self . parts ( ) ;
54
51
match stream. poll_read ( & mut cx, buf) ? {
55
52
Poll :: Ready ( nread) => Ok ( nread) ,
56
53
Poll :: Pending => Err ( io:: Error :: from ( io:: ErrorKind :: WouldBlock ) ) ,
@@ -60,18 +57,18 @@ where
60
57
61
58
impl < S > Write for StreamWrapper < S >
62
59
where
63
- S : AsyncWrite ,
60
+ S : AsyncWrite + Unpin ,
64
61
{
65
62
fn write ( & mut self , buf : & [ u8 ] ) -> io:: Result < usize > {
66
- let ( stream, mut cx) = unsafe { self . parts ( ) } ;
63
+ let ( stream, mut cx) = self . parts ( ) ;
67
64
match stream. poll_write ( & mut cx, buf) {
68
65
Poll :: Ready ( r) => r,
69
66
Poll :: Pending => Err ( io:: Error :: from ( io:: ErrorKind :: WouldBlock ) ) ,
70
67
}
71
68
}
72
69
73
70
fn flush ( & mut self ) -> io:: Result < ( ) > {
74
- let ( stream, mut cx) = unsafe { self . parts ( ) } ;
71
+ let ( stream, mut cx) = self . parts ( ) ;
75
72
match stream. poll_flush ( & mut cx) {
76
73
Poll :: Ready ( r) => r,
77
74
Poll :: Pending => Err ( io:: Error :: from ( io:: ErrorKind :: WouldBlock ) ) ,
@@ -99,11 +96,11 @@ fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
99
96
100
97
/// An asynchronous version of [`openssl::ssl::SslStream`].
101
98
#[ derive( Debug ) ]
102
- pub struct SslStream < S > ( ssl:: SslStream < StreamWrapper < S > > ) ;
99
+ pub struct SslStream < S : Unpin > ( ssl:: SslStream < StreamWrapper < S > > ) ;
103
100
104
101
impl < S > SslStream < S >
105
102
where
106
- S : AsyncRead + AsyncWrite ,
103
+ S : AsyncRead + AsyncWrite + Unpin ,
107
104
{
108
105
/// Like [`SslStream::new`](ssl::SslStream::new).
109
106
pub fn new ( ssl : Ssl , stream : S ) -> Result < Self , ErrorStack > {
@@ -206,7 +203,7 @@ where
206
203
}
207
204
}
208
205
209
- impl < S > SslStream < S > {
206
+ impl < S : Unpin > SslStream < S > {
210
207
/// Returns a shared reference to the `Ssl` object associated with this stream.
211
208
pub fn ssl ( & self ) -> & SslRef {
212
209
self . 0 . ssl ( )
@@ -224,7 +221,7 @@ impl<S> SslStream<S> {
224
221
225
222
/// Returns a pinned mutable reference to the underlying stream.
226
223
pub fn get_pin_mut ( self : Pin < & mut Self > ) -> Pin < & mut S > {
227
- unsafe { Pin :: new_unchecked ( & mut self . get_unchecked_mut ( ) . 0 . get_mut ( ) . stream ) }
224
+ Pin :: new ( & mut self . get_mut ( ) . 0 . get_mut ( ) . stream )
228
225
}
229
226
230
227
fn with_context < F , R > ( self : Pin < & mut Self > , ctx : & mut Context < ' _ > , f : F ) -> R
@@ -239,7 +236,7 @@ impl<S> SslStream<S> {
239
236
240
237
impl < S > AsyncRead for SslStream < S >
241
238
where
242
- S : AsyncRead + AsyncWrite ,
239
+ S : AsyncRead + AsyncWrite + Unpin ,
243
240
{
244
241
fn poll_read (
245
242
self : Pin < & mut Self > ,
@@ -252,7 +249,7 @@ where
252
249
253
250
impl < S > AsyncWrite for SslStream < S >
254
251
where
255
- S : AsyncRead + AsyncWrite ,
252
+ S : AsyncRead + AsyncWrite + Unpin ,
256
253
{
257
254
fn poll_write ( self : Pin < & mut Self > , ctx : & mut Context , buf : & [ u8 ] ) -> Poll < io:: Result < usize > > {
258
255
self . with_context ( ctx, |s| cvt ( s. write ( buf) ) )
0 commit comments