Skip to content

Commit 0c07400

Browse files
authored
Fix server sent event decoding (#81)
* Refactor SSE creation to stop propagating empty events * This works * Remove unnecessary once for closing channels * Simplify String() method * Use existing context instead of TODO * Fix linting error
1 parent eb1270a commit 0c07400

File tree

1 file changed

+44
-21
lines changed

1 file changed

+44
-21
lines changed

stream.go

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,11 @@ type SSEEvent struct {
4141
Data string
4242
}
4343

44-
func (e *SSEEvent) decode(b []byte) error {
44+
// decodeSSEEvent parses the raw SSE event data and returns an SSEEvent pointer and an error.
45+
func decodeSSEEvent(b []byte) (*SSEEvent, error) {
4546
chunks := [][]byte{}
47+
e := &SSEEvent{Type: SSETypeDefault}
48+
4649
for _, line := range bytes.Split(b, []byte("\n")) {
4750
// Parse field and value from line
4851
parts := bytes.SplitN(line, []byte{':'}, 2)
@@ -56,7 +59,7 @@ func (e *SSEEvent) decode(b []byte) error {
5659
if len(parts) == 2 {
5760
value = parts[1]
5861
// Trim leading space if present
59-
value, _ = bytes.CutPrefix(value, []byte(" "))
62+
value = bytes.TrimPrefix(value, []byte(" "))
6063
}
6164

6265
switch field {
@@ -73,16 +76,21 @@ func (e *SSEEvent) decode(b []byte) error {
7376

7477
data := bytes.Join(chunks, []byte("\n"))
7578
if !utf8.Valid(data) {
76-
return ErrInvalidUTF8Data
79+
return nil, ErrInvalidUTF8Data
7780
}
7881
e.Data = string(data)
7982

80-
return nil
83+
// Return nil if event data is empty and event type is not "done"
84+
if e.Data == "" && e.Type != SSETypeDone {
85+
return nil, nil
86+
}
87+
88+
return e, nil
8189
}
8290

8391
func (e *SSEEvent) String() string {
8492
switch e.Type {
85-
case "output":
93+
case SSETypeOutput:
8694
return e.Data
8795
default:
8896
return ""
@@ -126,9 +134,6 @@ func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) (
126134
}
127135

128136
func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, lastEvent *SSEEvent, sseChan chan SSEEvent, errChan chan error) {
129-
g, ctx := errgroup.WithContext(ctx)
130-
done := make(chan struct{})
131-
132137
url := prediction.URLs["stream"]
133138
if url == "" {
134139
errChan <- errors.New("streaming not supported or not enabled for this prediction")
@@ -137,7 +142,10 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
137142

138143
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
139144
if err != nil {
140-
errChan <- fmt.Errorf("failed to create request: %w", err)
145+
select {
146+
case errChan <- fmt.Errorf("failed to create request: %w", err):
147+
default:
148+
}
141149
return
142150
}
143151
req.Header.Set("Accept", "text/event-stream")
@@ -150,22 +158,33 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
150158

151159
resp, err := r.c.Do(req)
152160
if err != nil || resp == nil {
153-
if resp != nil {
154-
resp.Body.Close()
161+
if resp == nil {
162+
err = errors.New("received nil response")
163+
} else {
164+
defer resp.Body.Close()
165+
}
166+
select {
167+
case errChan <- fmt.Errorf("failed to send request: %w", err):
168+
default:
155169
}
156-
errChan <- fmt.Errorf("failed to send request: %w", err)
157170
return
158171
}
159172

160173
if resp.StatusCode != http.StatusOK {
161-
errChan <- fmt.Errorf("received invalid status code: %d", resp.StatusCode)
174+
select {
175+
case errChan <- fmt.Errorf("received invalid status code: %d", resp.StatusCode):
176+
default:
177+
}
162178
return
163179
}
164180

165181
reader := bufio.NewReader(resp.Body)
166182
var buf bytes.Buffer
167183
lineChan := make(chan []byte)
168184

185+
g, ctx := errgroup.WithContext(ctx)
186+
done := make(chan struct{})
187+
169188
g.Go(func() error {
170189
defer close(lineChan)
171190
defer resp.Body.Close()
@@ -208,18 +227,22 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
208227
b := buf.Bytes()
209228
buf.Reset()
210229

211-
event := SSEEvent{Type: SSETypeDefault}
212-
if err := event.decode(b); err != nil {
230+
event, err := decodeSSEEvent(b)
231+
if err != nil {
213232
select {
214233
case errChan <- err:
215234
default:
216235
}
217-
close(done)
218-
return
236+
continue
237+
}
238+
239+
if event == nil {
240+
// Skip empty events
241+
continue
219242
}
220243

221244
select {
222-
case sseChan <- event:
245+
case sseChan <- *event:
223246
case <-done:
224247
return
225248
case <-ctx.Done():
@@ -238,9 +261,6 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
238261
go func() {
239262
err := g.Wait()
240263

241-
defer close(sseChan)
242-
defer close(errChan)
243-
244264
if err != nil {
245265
if errors.Is(err, io.EOF) {
246266
// Attempt to reconnect if the connection was closed before the stream was done
@@ -255,5 +275,8 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l
255275
}
256276
}
257277
}
278+
279+
close(sseChan)
280+
close(errChan)
258281
}()
259282
}

0 commit comments

Comments
 (0)