@@ -41,8 +41,11 @@ type SSEEvent struct {
4141 Data string
4242}
4343
44- func (e * SSEEvent ) decode (b []byte ) error {
44+ // decodeSSEEvent parses the raw SSE event data and returns an SSEEvent pointer and an error.
45+ func decodeSSEEvent (b []byte ) (* SSEEvent , error ) {
4546 chunks := [][]byte {}
47+ e := & SSEEvent {Type : SSETypeDefault }
48+
4649 for _ , line := range bytes .Split (b , []byte ("\n " )) {
4750 // Parse field and value from line
4851 parts := bytes .SplitN (line , []byte {':' }, 2 )
@@ -56,7 +59,7 @@ func (e *SSEEvent) decode(b []byte) error {
5659 if len (parts ) == 2 {
5760 value = parts [1 ]
5861 // Trim leading space if present
59- value , _ = bytes .CutPrefix (value , []byte (" " ))
62+ value = bytes .TrimPrefix (value , []byte (" " ))
6063 }
6164
6265 switch field {
@@ -73,16 +76,21 @@ func (e *SSEEvent) decode(b []byte) error {
7376
7477 data := bytes .Join (chunks , []byte ("\n " ))
7578 if ! utf8 .Valid (data ) {
76- return ErrInvalidUTF8Data
79+ return nil , ErrInvalidUTF8Data
7780 }
7881 e .Data = string (data )
7982
80- return nil
83+ // Return nil if event data is empty and event type is not "done"
84+ if e .Data == "" && e .Type != SSETypeDone {
85+ return nil , nil
86+ }
87+
88+ return e , nil
8189}
8290
8391func (e * SSEEvent ) String () string {
8492 switch e .Type {
85- case "output" :
93+ case SSETypeOutput :
8694 return e .Data
8795 default :
8896 return ""
@@ -126,9 +134,6 @@ func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) (
126134}
127135
128136func (r * Client ) streamPrediction (ctx context.Context , prediction * Prediction , lastEvent * SSEEvent , sseChan chan SSEEvent , errChan chan error ) {
129- g , ctx := errgroup .WithContext (ctx )
130- done := make (chan struct {})
131-
132137 url := prediction .URLs ["stream" ]
133138 if url == "" {
134139 errChan <- errors .New ("streaming not supported or not enabled for this prediction" )
@@ -137,7 +142,10 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
137142
138143 req , err := http .NewRequestWithContext (ctx , http .MethodGet , url , nil )
139144 if err != nil {
140- errChan <- fmt .Errorf ("failed to create request: %w" , err )
145+ select {
146+ case errChan <- fmt .Errorf ("failed to create request: %w" , err ):
147+ default :
148+ }
141149 return
142150 }
143151 req .Header .Set ("Accept" , "text/event-stream" )
@@ -150,22 +158,33 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
150158
151159 resp , err := r .c .Do (req )
152160 if err != nil || resp == nil {
153- if resp != nil {
154- resp .Body .Close ()
161+ if resp == nil {
162+ err = errors .New ("received nil response" )
163+ } else {
164+ defer resp .Body .Close ()
165+ }
166+ select {
167+ case errChan <- fmt .Errorf ("failed to send request: %w" , err ):
168+ default :
155169 }
156- errChan <- fmt .Errorf ("failed to send request: %w" , err )
157170 return
158171 }
159172
160173 if resp .StatusCode != http .StatusOK {
161- errChan <- fmt .Errorf ("received invalid status code: %d" , resp .StatusCode )
174+ select {
175+ case errChan <- fmt .Errorf ("received invalid status code: %d" , resp .StatusCode ):
176+ default :
177+ }
162178 return
163179 }
164180
165181 reader := bufio .NewReader (resp .Body )
166182 var buf bytes.Buffer
167183 lineChan := make (chan []byte )
168184
185+ g , ctx := errgroup .WithContext (ctx )
186+ done := make (chan struct {})
187+
169188 g .Go (func () error {
170189 defer close (lineChan )
171190 defer resp .Body .Close ()
@@ -208,18 +227,22 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
208227 b := buf .Bytes ()
209228 buf .Reset ()
210229
211- event := SSEEvent { Type : SSETypeDefault }
212- if err := event . decode ( b ); err != nil {
230+ event , err := decodeSSEEvent ( b )
231+ if err != nil {
213232 select {
214233 case errChan <- err :
215234 default :
216235 }
217- close (done )
218- return
236+ continue
237+ }
238+
239+ if event == nil {
240+ // Skip empty events
241+ continue
219242 }
220243
221244 select {
222- case sseChan <- event :
245+ case sseChan <- * event :
223246 case <- done :
224247 return
225248 case <- ctx .Done ():
@@ -238,9 +261,6 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
238261 go func () {
239262 err := g .Wait ()
240263
241- defer close (sseChan )
242- defer close (errChan )
243-
244264 if err != nil {
245265 if errors .Is (err , io .EOF ) {
246266 // Attempt to reconnect if the connection was closed before the stream was done
@@ -255,5 +275,8 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
255275 }
256276 }
257277 }
278+
279+ close (sseChan )
280+ close (errChan )
258281 }()
259282}
0 commit comments