Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions graphqlws/http.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package graphqlws

import (
"net/http"

"github.com/gorilla/websocket"
"net/http"

"github.com/graph-gophers/graphql-transport-ws/graphqlws/internal/connection"
)
Expand All @@ -17,6 +16,11 @@ var upgrader = websocket.Upgrader{

// NewHandlerFunc returns an http.HandlerFunc that supports GraphQL over websockets
func NewHandlerFunc(svc connection.GraphQLService, httpHandler http.Handler) http.HandlerFunc {
return NewHandlerFuncWithAuth(svc, httpHandler, nil)
}

// NewHandlerFunc returns an http.HandlerFunc that supports GraphQL over websockets
func NewHandlerFuncWithAuth(svc connection.GraphQLService, httpHandler http.Handler, authFunc connection.AuthenticateFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
for _, subprotocol := range websocket.Subprotocols(r) {
if subprotocol == "graphql-ws" {
Expand All @@ -30,7 +34,7 @@ func NewHandlerFunc(svc connection.GraphQLService, httpHandler http.Handler) htt
return
}

go connection.Connect(ws, svc)
go connection.Connect(ws, svc, connection.Authentication(r, authFunc))
return
}
}
Expand Down
41 changes: 34 additions & 7 deletions graphqlws/internal/connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"time"
)

Expand Down Expand Up @@ -47,18 +48,28 @@ type startMessagePayload struct {
Variables map[string]interface{} `json:"variables"`
}

type initMessagePayload struct{}

// GraphQLService interface
type GraphQLService interface {
Subscribe(ctx context.Context, document string, operationName string, variableValues map[string]interface{}) (payloads <-chan interface{}, err error)
}

type AuthenticateFunc func(ctx context.Context, r *http.Request, payload map[string]interface{}) error

type connection struct {
cancel func()
service GraphQLService
writeTimeout time.Duration
ws wsConnection
cancel func()
service GraphQLService
writeTimeout time.Duration
ws wsConnection
authenticated bool
authenticateFunc AuthenticateFunc
request *http.Request
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need the request as a field of the connection?

}

func Authentication(r *http.Request, f AuthenticateFunc) func(conn *connection) {
return func(conn *connection) {
conn.authenticateFunc = f
conn.request = r
}
}

// ReadLimit limits the maximum size of incoming messages
Expand Down Expand Up @@ -159,15 +170,31 @@ func (conn *connection) readLoop(ctx context.Context, send sendFunc) {

switch msg.Type {
case typeConnectionInit:
var initMsg initMessagePayload

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for empty line here. Please, remove it.

var initMsg map[string]interface{}
if err := json.Unmarshal(msg.Payload, &initMsg); err != nil {
ep := errPayload(fmt.Errorf("invalid payload for type: %s", msg.Type))
send("", typeConnectionError, ep)
continue
}

if conn.authenticateFunc != nil {
if err := conn.authenticateFunc(ctx, conn.request, initMsg); err != nil {
send("", typeConnectionError, errPayload(err))
continue
}
}
conn.authenticated = true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if authenticateFunc is nil the connection is marked as authenticated? Why is that? This doesn't sound ok to me. Am I missing something? Should you move this line

conn.authenticated = true

inside of the if statement above it?

send("", typeConnectionAck, nil)

case typeStart:

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary empty line. Please, remove it.

if !conn.authenticated && conn.authenticateFunc != nil {
ep := errPayload(errors.New("authentication required."))
send("", typeConnectionError, ep)
continue
}

// TODO: check an operation with the same ID hasn't been started already
if msg.ID == "" {
ep := errPayload(errors.New("missing ID for start operation"))
Expand Down