@@ -446,10 +446,10 @@ impl<C: Config> Client<C> {
446
446
path : & str ,
447
447
request : I ,
448
448
event_mapper : impl Fn ( eventsource_stream:: Event ) -> Result < O , OpenAIError > + Send + ' static ,
449
- ) -> Pin < Box < dyn Stream < Item = Result < O , OpenAIError > > + Send > >
449
+ ) -> OpenAIEventMappedStream < O >
450
450
where
451
451
I : Serialize ,
452
- O : DeserializeOwned + Send + ' static ,
452
+ O : DeserializeOwned + Send + ' static
453
453
{
454
454
let event_source = self
455
455
. http_client
@@ -460,8 +460,7 @@ impl<C: Config> Client<C> {
460
460
. eventsource ( )
461
461
. unwrap ( ) ;
462
462
463
- // stream_mapped_raw_events(event_source, event_mapper).await
464
- todo ! ( )
463
+ OpenAIEventMappedStream :: new ( event_source, event_mapper)
465
464
}
466
465
467
466
/// Make HTTP GET request to receive SSE
@@ -491,19 +490,21 @@ impl<C: Config> Client<C> {
491
490
/// Request which responds with SSE.
492
491
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
493
492
#[ pin_project]
494
- pub struct OpenAIEventStream < O > {
493
+ pub struct OpenAIEventStream < O : DeserializeOwned + Send + ' static > {
495
494
#[ pin]
496
495
stream : Filter < EventSource , future:: Ready < bool > , fn ( & Result < Event , reqwest_eventsource:: Error > ) -> future:: Ready < bool > > ,
496
+ done : bool ,
497
497
_phantom_data : PhantomData < O > ,
498
498
}
499
499
500
- impl < O > OpenAIEventStream < O > {
500
+ impl < O : DeserializeOwned + Send + ' static > OpenAIEventStream < O > {
501
501
pub ( crate ) fn new ( event_source : EventSource ) -> Self {
502
502
Self {
503
503
stream : event_source. filter ( |result|
504
504
// filter out the first event which is always Event::Open
505
505
future:: ready ( !( result. is_ok ( ) && result. as_ref ( ) . unwrap ( ) . eq ( & Event :: Open ) ) )
506
506
) ,
507
+ done : false ,
507
508
_phantom_data : PhantomData ,
508
509
}
509
510
}
@@ -514,6 +515,9 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
514
515
515
516
fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
516
517
let this = self . project ( ) ;
518
+ if * this. done {
519
+ return Poll :: Ready ( None ) ;
520
+ }
517
521
let stream: Pin < & mut _ > = this. stream ;
518
522
match stream. poll_next ( cx) {
519
523
Poll :: Ready ( response) => {
@@ -524,17 +528,24 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
524
528
Event :: Open => unreachable ! ( ) , // it has been filtered out
525
529
Event :: Message ( message) => {
526
530
if message. data == "[DONE]" {
531
+ * this. done = true ;
527
532
Poll :: Ready ( None ) // end of the stream, defined by OpenAI
528
533
} else {
529
534
// deserialize the data
530
535
match serde_json:: from_str :: < O > ( & message. data ) {
531
- Err ( e) => Poll :: Ready ( Some ( Err ( map_deserialization_error ( e, & message. data . as_bytes ( ) ) ) ) ) ,
536
+ Err ( e) => {
537
+ * this. done = true ;
538
+ Poll :: Ready ( Some ( Err ( map_deserialization_error ( e, & message. data . as_bytes ( ) ) ) ) )
539
+ }
532
540
Ok ( output) => Poll :: Ready ( Some ( Ok ( output) ) ) ,
533
541
}
534
542
}
535
543
}
536
544
}
537
- Err ( e) => Poll :: Ready ( Some ( Err ( OpenAIError :: StreamError ( e. to_string ( ) ) ) ) )
545
+ Err ( e) => {
546
+ * this. done = true ;
547
+ Poll :: Ready ( Some ( Err ( OpenAIError :: StreamError ( e. to_string ( ) ) ) ) )
548
+ }
538
549
}
539
550
}
540
551
}
@@ -543,6 +554,77 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
543
554
}
544
555
}
545
556
557
+ #[ pin_project]
558
+ pub struct OpenAIEventMappedStream < O >
559
+ where O : Send + ' static
560
+ {
561
+ #[ pin]
562
+ stream : Filter < EventSource , future:: Ready < bool > , fn ( & Result < Event , reqwest_eventsource:: Error > ) -> future:: Ready < bool > > ,
563
+ event_mapper : Box < dyn Fn ( eventsource_stream:: Event ) -> Result < O , OpenAIError > + Send + ' static > ,
564
+ done : bool ,
565
+ _phantom_data : PhantomData < O > ,
566
+ }
567
+
568
+ impl < O > OpenAIEventMappedStream < O >
569
+ where O : Send + ' static
570
+ {
571
+ pub ( crate ) fn new < M > ( event_source : EventSource , event_mapper : M ) -> Self
572
+ where M : Fn ( eventsource_stream:: Event ) -> Result < O , OpenAIError > + Send + ' static {
573
+ Self {
574
+ stream : event_source. filter ( |result|
575
+ // filter out the first event which is always Event::Open
576
+ future:: ready ( !( result. is_ok ( ) && result. as_ref ( ) . unwrap ( ) . eq ( & Event :: Open ) ) )
577
+ ) ,
578
+ done : false ,
579
+ event_mapper : Box :: new ( event_mapper) ,
580
+ _phantom_data : PhantomData ,
581
+ }
582
+ }
583
+ }
584
+
585
+
586
+ impl < O > Stream for OpenAIEventMappedStream < O >
587
+ where O : Send + ' static
588
+ {
589
+ type Item = Result < O , OpenAIError > ;
590
+
591
+ fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
592
+ let this = self . project ( ) ;
593
+ if * this. done {
594
+ return Poll :: Ready ( None ) ;
595
+ }
596
+ let stream: Pin < & mut _ > = this. stream ;
597
+ match stream. poll_next ( cx) {
598
+ Poll :: Ready ( response) => {
599
+ match response {
600
+ None => Poll :: Ready ( None ) , // end of the stream
601
+ Some ( result) => match result {
602
+ Ok ( event) => match event {
603
+ Event :: Open => unreachable ! ( ) , // it has been filtered out
604
+ Event :: Message ( message) => {
605
+ if message. data == "[DONE]" {
606
+ * this. done = true ;
607
+ }
608
+ let response = ( this. event_mapper ) ( message) ;
609
+ match response {
610
+ Ok ( output) => Poll :: Ready ( Some ( Ok ( output) ) ) ,
611
+ Err ( _) => Poll :: Ready ( None )
612
+ }
613
+ }
614
+ }
615
+ Err ( e) => {
616
+ * this. done = true ;
617
+ Poll :: Ready ( Some ( Err ( OpenAIError :: StreamError ( e. to_string ( ) ) ) ) )
618
+ }
619
+ }
620
+ }
621
+ }
622
+ Poll :: Pending => Poll :: Pending
623
+ }
624
+ }
625
+ }
626
+
627
+
546
628
// pub(crate) async fn stream_mapped_raw_events<O>(
547
629
// mut event_source: EventSource,
548
630
// event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
0 commit comments