@@ -14,15 +14,15 @@ use std::{
14
14
fmt, future,
15
15
io:: { self , Read , Write } ,
16
16
pin:: Pin ,
17
- task:: { Context , Poll } ,
17
+ task:: { Context , Poll , Waker } ,
18
18
} ;
19
19
20
20
#[ cfg( test) ]
21
21
mod test;
22
22
23
23
struct StreamWrapper < S > {
24
24
stream : S ,
25
- context : usize ,
25
+ waker : Waker ,
26
26
}
27
27
28
28
impl < S > fmt:: Debug for StreamWrapper < S >
@@ -37,12 +37,10 @@ where
37
37
impl < S > StreamWrapper < S > {
38
38
/// # Safety
39
39
///
40
- /// Must be called with `context` set to a valid pointer to a live `Context` object, and the
41
- /// wrapper must be pinned in memory.
42
- unsafe fn parts ( & mut self ) -> ( Pin < & mut S > , & mut Context < ' _ > ) {
43
- debug_assert_ne ! ( self . context, 0 ) ;
44
- let stream = Pin :: new_unchecked ( & mut self . stream ) ;
45
- let context = & mut * ( self . context as * mut _ ) ;
40
+ /// The wrapper must be pinned in memory.
41
+ unsafe fn parts ( & mut self ) -> ( Pin < & mut S > , Context < ' _ > ) {
42
+ let stream = unsafe { Pin :: new_unchecked ( & mut self . stream ) } ;
43
+ let context = Context :: from_waker ( & self . waker ) ;
46
44
( stream, context)
47
45
}
48
46
}
52
50
S : AsyncRead ,
53
51
{
54
52
fn read ( & mut self , buf : & mut [ u8 ] ) -> io:: Result < usize > {
55
- let ( stream, cx) = unsafe { self . parts ( ) } ;
56
- match stream. poll_read ( cx, buf) ? {
53
+ let ( stream, mut cx) = unsafe { self . parts ( ) } ;
54
+ match stream. poll_read ( & mut cx, buf) ? {
57
55
Poll :: Ready ( nread) => Ok ( nread) ,
58
56
Poll :: Pending => Err ( io:: Error :: from ( io:: ErrorKind :: WouldBlock ) ) ,
59
57
}
@@ -65,16 +63,16 @@ where
65
63
S : AsyncWrite ,
66
64
{
67
65
fn write ( & mut self , buf : & [ u8 ] ) -> io:: Result < usize > {
68
- let ( stream, cx) = unsafe { self . parts ( ) } ;
69
- match stream. poll_write ( cx, buf) {
66
+ let ( stream, mut cx) = unsafe { self . parts ( ) } ;
67
+ match stream. poll_write ( & mut cx, buf) {
70
68
Poll :: Ready ( r) => r,
71
69
Poll :: Pending => Err ( io:: Error :: from ( io:: ErrorKind :: WouldBlock ) ) ,
72
70
}
73
71
}
74
72
75
73
fn flush ( & mut self ) -> io:: Result < ( ) > {
76
- let ( stream, cx) = unsafe { self . parts ( ) } ;
77
- match stream. poll_flush ( cx) {
74
+ let ( stream, mut cx) = unsafe { self . parts ( ) } ;
75
+ match stream. poll_flush ( & mut cx) {
78
76
Poll :: Ready ( r) => r,
79
77
Poll :: Pending => Err ( io:: Error :: from ( io:: ErrorKind :: WouldBlock ) ) ,
80
78
}
@@ -109,7 +107,14 @@ where
109
107
{
110
108
/// Like [`SslStream::new`](ssl::SslStream::new).
111
109
pub fn new ( ssl : Ssl , stream : S ) -> Result < Self , ErrorStack > {
112
- ssl:: SslStream :: new ( ssl, StreamWrapper { stream, context : 0 } ) . map ( SslStream )
110
+ ssl:: SslStream :: new (
111
+ ssl,
112
+ StreamWrapper {
113
+ stream,
114
+ waker : Waker :: noop ( ) . clone ( ) ,
115
+ } ,
116
+ )
117
+ . map ( SslStream )
113
118
}
114
119
115
120
/// Like [`SslStream::connect`](ssl::SslStream::connect).
@@ -227,10 +232,8 @@ impl<S> SslStream<S> {
227
232
F : FnOnce ( & mut ssl:: SslStream < StreamWrapper < S > > ) -> R ,
228
233
{
229
234
let this = unsafe { self . get_unchecked_mut ( ) } ;
230
- this. 0 . get_mut ( ) . context = ctx as * mut _ as usize ;
231
- let r = f ( & mut this. 0 ) ;
232
- this. 0 . get_mut ( ) . context = 0 ;
233
- r
235
+ this. 0 . get_mut ( ) . waker = ctx. waker ( ) . clone ( ) ;
236
+ f ( & mut this. 0 )
234
237
}
235
238
}
236
239
0 commit comments