Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8262,6 +8262,80 @@ func (f *FailingSubscriptionClient) UniqueRequestID(ctx *resolve.Context, option
return errSubscriptionClientFail
}

type testSubscriptionUpdaterChan struct {
updates chan string
complete chan struct{}
closed chan resolve.SubscriptionCloseKind
}

func newTestSubscriptionUpdaterChan() *testSubscriptionUpdaterChan {
return &testSubscriptionUpdaterChan{
updates: make(chan string),
complete: make(chan struct{}),
closed: make(chan resolve.SubscriptionCloseKind),
}
}

func (t *testSubscriptionUpdaterChan) Heartbeat() {
t.updates <- "{}"
}

func (t *testSubscriptionUpdaterChan) Update(data []byte) {
t.updates <- string(data)
}

func (t *testSubscriptionUpdaterChan) Complete() {
close(t.complete)
}

func (t *testSubscriptionUpdaterChan) Close(kind resolve.SubscriptionCloseKind) {
t.closed <- kind
}

func (t *testSubscriptionUpdaterChan) AwaitUpdateWithT(tt *testing.T, timeout time.Duration, f func(t *testing.T, update string), msgAndArgs ...any) {
tt.Helper()

select {
case args := <-t.updates:
f(tt, args)
case <-time.After(timeout):
require.Fail(tt, "unable to receive update before timeout", msgAndArgs...)
}
}

func (t *testSubscriptionUpdaterChan) AwaitClose(tt *testing.T, timeout time.Duration, msgAndArgs ...any) {
tt.Helper()

select {
case <-t.closed:
case <-time.After(timeout):
require.Fail(tt, "updater not closed before timeout", msgAndArgs...)
}
}

func (t *testSubscriptionUpdaterChan) AwaitCloseKind(tt *testing.T, timeout time.Duration, expectedCloseKind resolve.SubscriptionCloseKind, msgAndArgs ...any) {
tt.Helper()

select {
case closeKind := <-t.closed:
require.Equal(tt, expectedCloseKind, closeKind, msgAndArgs...)
case <-time.After(timeout):
require.Fail(tt, "updater not closed before timeout", msgAndArgs...)
}
}

func (t *testSubscriptionUpdaterChan) AwaitComplete(tt *testing.T, timeout time.Duration, msgAndArgs ...any) {
tt.Helper()

select {
case <-t.complete:
case <-time.After(timeout):
require.Fail(tt, "updater not completed before timeout", msgAndArgs...)
}
}

// !! If you see this in a test you're working on, please replace it with the new testSubscriptionUpdaterChan
// It's faster, more ergonomic and more reliable. See SSE handler tests for usage examples.
type testSubscriptionUpdater struct {
updates []string
done bool
Expand All @@ -8270,6 +8344,8 @@ type testSubscriptionUpdater struct {
}

func (t *testSubscriptionUpdater) AwaitUpdates(tt *testing.T, timeout time.Duration, count int) {
tt.Helper()

ticker := time.NewTicker(timeout)
defer ticker.Stop()
for {
Expand All @@ -8289,6 +8365,8 @@ func (t *testSubscriptionUpdater) AwaitUpdates(tt *testing.T, timeout time.Durat
}

func (t *testSubscriptionUpdater) AwaitDone(tt *testing.T, timeout time.Duration) {
tt.Helper()

ticker := time.NewTicker(timeout)
defer ticker.Stop()
for {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,32 +44,8 @@ func newSSEConnectionHandler(requestContext, engineContext context.Context, conn
}

func (h *gqlSSEConnectionHandler) StartBlocking() {
dataCh := make(chan []byte)
errCh := make(chan []byte)
defer func() {
close(dataCh)
close(errCh)
h.updater.Complete()
}()

go h.subscribe(dataCh, errCh)
defer h.updater.Close(resolve.SubscriptionCloseKindNormal)

for {
select {
case data := <-dataCh:
h.updater.Update(data)
case data := <-errCh:
h.updater.Update(data)
return
case <-h.requestContext.Done():
return
case <-h.engineContext.Done():
return
}
}
}

func (h *gqlSSEConnectionHandler) subscribe(dataCh, errCh chan []byte) {
resp, err := h.performSubscriptionRequest()
if err != nil {
h.log.Error("failed to perform subscription request", log.Error(err))
Expand All @@ -83,6 +59,7 @@ func (h *gqlSSEConnectionHandler) subscribe(dataCh, errCh chan []byte) {

return
}

defer func() {
_ = resp.Body.Close()
}()
Expand All @@ -105,8 +82,7 @@ func (h *gqlSSEConnectionHandler) subscribe(dataCh, errCh chan []byte) {
}

h.log.Error("failed to read event", log.Error(err))

errCh <- []byte(internalError)
h.updater.Update([]byte(internalError))
return
}

Expand All @@ -131,12 +107,13 @@ func (h *gqlSSEConnectionHandler) subscribe(dataCh, errCh chan []byte) {
return
}

dataCh <- data
h.updater.Update(data)
case bytes.HasPrefix(line, headerEvent):
event := trim(line[len(headerEvent):])

switch {
case bytes.Equal(event, eventTypeComplete):
h.updater.Complete()
return
case bytes.Equal(event, eventTypeNext):
continue
Expand Down Expand Up @@ -165,33 +142,31 @@ func (h *gqlSSEConnectionHandler) subscribe(dataCh, errCh chan []byte) {
response, err = jsonparser.Set(response, val, "errors")
if err != nil {
h.log.Error("failed to set errors", log.Error(err))

errCh <- []byte(internalError)
h.updater.Update([]byte(internalError))
return
}
errCh <- response
h.updater.Update(response)
return
case jsonparser.Object:
response := []byte(`{"errors":[]}`)
response, err = jsonparser.Set(response, val, "errors", "[0]")
if err != nil {
h.log.Error("failed to set errors", log.Error(err))

errCh <- []byte(internalError)
h.updater.Update([]byte(internalError))
return
}
errCh <- response
h.updater.Update(response)
return
default:
// don't crash on unexpected payloads from upstream
h.log.Error(fmt.Sprintf("unexpected value type: %d", valueType))
errCh <- []byte(internalError)
h.updater.Update([]byte(internalError))
return
}

default:
h.log.Error("failed to parse errors", log.Error(err))
errCh <- []byte(internalError)
h.updater.Update([]byte(internalError))
return
}
}
Expand All @@ -210,7 +185,6 @@ func trim(data []byte) []byte {
}

func (h *gqlSSEConnectionHandler) performSubscriptionRequest() (*http.Response, error) {

var req *http.Request
var err error

Expand Down
Loading