1
- use blitz_traits:: net:: { BoxedHandler , Bytes , NetCallback , NetProvider , Request , SharedCallback } ;
1
+ use blitz_traits:: net:: {
2
+ AbortSignal , BoxedHandler , Bytes , NetCallback , NetProvider , Request , SharedCallback ,
3
+ } ;
2
4
use data_url:: DataUrl ;
3
5
use reqwest:: Client ;
4
- use std:: sync:: Arc ;
6
+ use std:: { marker :: PhantomData , pin :: Pin , sync:: Arc , task :: Poll } ;
5
7
use tokio:: {
6
8
runtime:: Handle ,
7
9
sync:: mpsc:: { UnboundedReceiver , UnboundedSender , unbounded_channel} ,
@@ -75,18 +77,6 @@ impl<D: 'static> Provider<D> {
75
77
} )
76
78
}
77
79
78
- async fn fetch_with_handler (
79
- client : Client ,
80
- doc_id : usize ,
81
- request : Request ,
82
- handler : BoxedHandler < D > ,
83
- res_callback : SharedCallback < D > ,
84
- ) -> Result < ( ) , ProviderError > {
85
- let ( _response_url, bytes) = Self :: fetch_inner ( client, request) . await ?;
86
- handler. bytes ( doc_id, bytes, res_callback) ;
87
- Ok ( ( ) )
88
- }
89
-
90
80
#[ allow( clippy:: type_complexity) ]
91
81
pub fn fetch_with_callback (
92
82
& self ,
@@ -108,24 +98,78 @@ impl<D: 'static> Provider<D> {
108
98
}
109
99
110
100
impl < D : ' static > NetProvider < D > for Provider < D > {
111
- fn fetch ( & self , doc_id : usize , request : Request , handler : BoxedHandler < D > ) {
101
+ fn fetch ( & self , doc_id : usize , mut request : Request , handler : BoxedHandler < D > ) {
112
102
let client = self . client . clone ( ) ;
113
103
let callback = Arc :: clone ( & self . resource_callback ) ;
114
104
println ! ( "Fetching {}" , & request. url) ;
115
105
self . rt . spawn ( async move {
116
106
let url = request. url . to_string ( ) ;
117
- let res = Self :: fetch_with_handler ( client, doc_id, request, handler, callback) . await ;
118
- if let Err ( e) = res {
119
- eprintln ! ( "Error fetching {url}: {e:?}" ) ;
107
+ let signal = request. signal . take ( ) ;
108
+ let result = if let Some ( signal) = signal {
109
+ AbortFetch :: new (
110
+ signal,
111
+ Box :: pin ( async move { Self :: fetch_inner ( client, request) . await } ) ,
112
+ )
113
+ . await
120
114
} else {
121
- println ! ( "Success {url}" ) ;
115
+ Self :: fetch_inner ( client, request) . await
116
+ } ;
117
+
118
+ match result {
119
+ Ok ( ( _response_url, bytes) ) => {
120
+ handler. bytes ( doc_id, bytes, callback) ;
121
+ println ! ( "Success {url}" ) ;
122
+ }
123
+ Err ( e) => {
124
+ eprintln ! ( "Error fetching {url}: {e:?}" ) ;
125
+ }
122
126
}
123
127
} ) ;
124
128
}
125
129
}
126
130
131
+ struct AbortFetch < F , T > {
132
+ signal : AbortSignal ,
133
+ future : F ,
134
+ _rt : PhantomData < T > ,
135
+ }
136
+
137
+ impl < F , T > AbortFetch < F , T > {
138
+ fn new ( signal : AbortSignal , future : F ) -> Self {
139
+ Self {
140
+ signal,
141
+ future,
142
+ _rt : PhantomData ,
143
+ }
144
+ }
145
+ }
146
+
147
+ impl < F , T > Future for AbortFetch < F , T >
148
+ where
149
+ F : Future + Unpin + Send + ' static ,
150
+ F :: Output : Send + Into < Result < T , ProviderError > > + ' static ,
151
+ T : Unpin ,
152
+ {
153
+ type Output = Result < T , ProviderError > ;
154
+
155
+ fn poll (
156
+ mut self : std:: pin:: Pin < & mut Self > ,
157
+ cx : & mut std:: task:: Context < ' _ > ,
158
+ ) -> std:: task:: Poll < Self :: Output > {
159
+ if self . signal . aborted ( ) {
160
+ return Poll :: Ready ( Err ( ProviderError :: Abort ) ) ;
161
+ }
162
+
163
+ match Pin :: new ( & mut self . future ) . poll ( cx) {
164
+ Poll :: Ready ( output) => Poll :: Ready ( output. into ( ) ) ,
165
+ Poll :: Pending => Poll :: Pending ,
166
+ }
167
+ }
168
+ }
169
+
127
170
#[ derive( Debug ) ]
128
171
pub enum ProviderError {
172
+ Abort ,
129
173
Io ( std:: io:: Error ) ,
130
174
DataUrl ( data_url:: DataUrlError ) ,
131
175
DataUrlBase64 ( data_url:: forgiving_base64:: InvalidBase64 ) ,
0 commit comments