Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions services/connector/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (
)

var (
ErrInvalidResponseFromForwardedRpc = errors.New("invalid response from forwarded RPC")
ErrInvalidResponseFromForwardedRpc = errors.New("invalid response from forwarded RPC")
ErrCannotOverrideClientIDForHttpConnection = errors.New("cannot override clientId for HTTP connection")
ErrEmptyClientIDFromTrustedConnection = errors.New("trusted connection must provide a clientId")
)

type API struct {
Expand Down Expand Up @@ -71,8 +73,8 @@ func NewAPI(s *Service) *API {
}
}

func (api *API) forwardRPC(ctx context.Context, URL string, request commands.RPCRequest) (interface{}, error) {
dApp, err := persistence.SelectDAppByUrl(api.s.db, URL)
func (api *API) forwardRPC(ctx context.Context, request commands.RPCRequest) (interface{}, error) {
dApp, err := persistence.SelectDApp(api.s.db, request.URL, request.ClientID)
if err != nil {
return "", err
}
Expand All @@ -97,6 +99,18 @@ func (api *API) CallRPC(ctx context.Context, inputJSON string) (interface{}, err
return "", err
}

// This prevents external clients from spoofing ClientID to impersonate trusted clients
if IsUntrustedConnection(ctx) {
if request.ClientID != "" {
return "", ErrCannotOverrideClientIDForHttpConnection
}
} else {
// Trusted connections MUST provide a ClientID
if request.ClientID == "" {
return "", ErrEmptyClientIDFromTrustedConnection
}
}

if command, exists := api.r.GetCommand(request.Method); exists {
return command.Execute(ctx, request)
}
Expand All @@ -105,12 +119,17 @@ func (api *API) CallRPC(ctx context.Context, inputJSON string) (interface{}, err
return nil, fmt.Errorf("method %s is not allowed", request.Method)
}

return api.forwardRPC(ctx, request.URL, request)
return api.forwardRPC(ctx, request)
}

// Deprecated: Use RecallDAppPermissionV2 instead
func (api *API) RecallDAppPermission(origin string) error {
return api.RecallDAppPermissionV2(origin, "")
}

func (api *API) RecallDAppPermissionV2(origin string, clientID string) error {
// TODO: close the websocket connection
return api.c.RecallDAppPermissions(commands.RecallDAppPermissionsArgs{URL: origin})
return api.c.RecallDAppPermissions(commands.RecallDAppPermissionsArgs{URL: origin, ClientID: clientID})
}

func (api *API) GetPermittedDAppsList() ([]persistence.DApp, error) {
Expand Down
53 changes: 51 additions & 2 deletions services/connector/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/status-im/status-go/services/connector/commands"
)

func TestCallRPC(t *testing.T) {
func TestCallRPC_UntrustedConnection(t *testing.T) {
state := setupTests(t)

tests := []struct {
Expand All @@ -36,9 +36,20 @@ func TestCallRPC(t *testing.T) {
request: "{\"method\": \"wallet_switchEthereumChain\", \"params\": []}",
expectError: commands.ErrRequestMissingDAppData,
},
{
request: `{
"method": "eth_chainId",
"params": [],
"url": "https://example.com",
"name": "Example DApp",
"iconUrl": "https://example.com/icon.png",
"clientId": "wallet-connect"
}`,
expectError: ErrCannotOverrideClientIDForHttpConnection,
},
}

ctx := context.Background()
ctx := WithConnectionType(context.Background(), ConnectionTypeHTTP)
for _, tt := range tests {
t.Run(tt.request, func(t *testing.T) {
_, err := state.api.CallRPC(ctx, tt.request)
Expand All @@ -47,3 +58,41 @@ func TestCallRPC(t *testing.T) {
})
}
}

func TestCallRPC_TrustedConnectionRequiresClientID(t *testing.T) {
state := setupTests(t)

// Trusted connection (Internal) without ClientID should fail
ctx := WithConnectionType(context.Background(), ConnectionTypeInternal)

request := `{
"method": "eth_chainId",
"params": [],
"url": "https://example.com",
"name": "Example DApp",
"iconUrl": "https://example.com/icon.png"
}`

_, err := state.api.CallRPC(ctx, request)
require.Error(t, err)
require.Equal(t, ErrEmptyClientIDFromTrustedConnection, err)
}

func TestCallRPC_TrustedConnectionWithClientID(t *testing.T) {
state := setupTests(t)

ctx := WithConnectionType(context.Background(), ConnectionTypeInternal)

request := `{
"method": "eth_chainId",
"params": [],
"url": "https://example.com",
"name": "Example DApp",
"iconUrl": "https://example.com/icon.png",
"clientId": "status-desktop"
}`

result, err := state.api.CallRPC(ctx, request)
require.NoError(t, err)
require.NotNil(t, result)
}
2 changes: 1 addition & 1 deletion services/connector/commands/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (c *AccountsCommand) Execute(ctx context.Context, request RPCRequest) (inte
return "", err
}

dApp, err := persistence.SelectDAppByUrl(c.Db, request.URL)
dApp, err := persistence.SelectDApp(c.Db, request.URL, request.ClientID)
if err != nil {
return "", err
}
Expand Down
2 changes: 1 addition & 1 deletion services/connector/commands/chain_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (c *ChainIDCommand) Execute(ctx context.Context, request RPCRequest) (inter
return "", err
}

dApp, err := persistence.SelectDAppByUrl(c.Db, request.URL)
dApp, err := persistence.SelectDApp(c.Db, request.URL, request.ClientID)
if err != nil {
return "", err
}
Expand Down
13 changes: 7 additions & 6 deletions services/connector/commands/client_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func NewClientSideHandler(db *sql.DB) *ClientSideHandler {
}

func (c *ClientSideHandler) generateRequestID(dApp signal.ConnectorDApp) string {
rawID := fmt.Sprintf("%d%s", time.Now().UnixMilli(), dApp.URL)
rawID := fmt.Sprintf("%d%s%s", time.Now().UnixMilli(), dApp.URL, dApp.ClientID)
hash := sha256.Sum256([]byte(rawID))
return hex.EncodeToString(hash[:])
}
Expand Down Expand Up @@ -126,7 +126,7 @@ func (c *ClientSideHandler) RecallDAppPermissions(args RecallDAppPermissionsArgs
return ErrEmptyUrl
}

dApp, err := persistence.SelectDAppByUrl(c.Db, args.URL)
dApp, err := persistence.SelectDApp(c.Db, args.URL, args.ClientID)
if err != nil {
return err
}
Expand All @@ -135,15 +135,16 @@ func (c *ClientSideHandler) RecallDAppPermissions(args RecallDAppPermissionsArgs
return ErrDAppDoesNotHavePermissions
}

err = persistence.DeleteDApp(c.Db, dApp.URL)
err = persistence.DeleteDApp(c.Db, dApp.URL, dApp.ClientID)
if err != nil {
return err
}

signal.SendConnectorDAppPermissionRevoked(signal.ConnectorDApp{
URL: dApp.URL,
Name: dApp.Name,
IconURL: dApp.IconURL,
URL: dApp.URL,
Name: dApp.Name,
IconURL: dApp.IconURL,
ClientID: dApp.ClientID,
})
return nil
}
Expand Down
9 changes: 5 additions & 4 deletions services/connector/commands/client_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,28 +44,29 @@ func TestRecallDAppPermission(t *testing.T) {
Name: "Test DApp",
URL: "http://testDAppURL",
IconURL: "http://testDAppIconUrl",
ClientID: "",
SharedAccount: types.HexToAddress("0x1234567890"),
ChainID: 0x1,
}

err := persistence.UpsertDApp(db, &dapp)
assert.NoError(t, err)

persistedDapp, err := persistence.SelectDAppByUrl(db, dapp.URL)
persistedDapp, err := persistence.SelectDApp(db, dapp.URL, dapp.ClientID)
assert.Equal(t, persistedDapp, &dapp)
assert.NoError(t, err)

clientHandler := NewClientSideHandler(db)
err = clientHandler.RecallDAppPermissions(RecallDAppPermissionsArgs{URL: dapp.URL})
err = clientHandler.RecallDAppPermissions(RecallDAppPermissionsArgs{URL: dapp.URL, ClientID: dapp.ClientID})
assert.NoError(t, err)

err = clientHandler.RecallDAppPermissions(RecallDAppPermissionsArgs{})
assert.ErrorIs(t, err, ErrEmptyUrl)

err = clientHandler.RecallDAppPermissions(RecallDAppPermissionsArgs{URL: dapp.URL})
err = clientHandler.RecallDAppPermissions(RecallDAppPermissionsArgs{URL: dapp.URL, ClientID: dapp.ClientID})
assert.ErrorIs(t, err, ErrDAppDoesNotHavePermissions)

recalledDapp, err := persistence.SelectDAppByUrl(db, dapp.URL)
recalledDapp, err := persistence.SelectDApp(db, dapp.URL, dapp.ClientID)

assert.Equal(t, recalledDapp, (*persistence.DApp)(nil))
assert.NoError(t, err)
Expand Down
10 changes: 6 additions & 4 deletions services/connector/commands/request_accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,18 @@ func (c *RequestAccountsCommand) Execute(ctx context.Context, request RPCRequest
return "", err
}

dApp, err := persistence.SelectDAppByUrl(c.Db, request.URL)
dApp, err := persistence.SelectDApp(c.Db, request.URL, request.ClientID)
if err != nil {
return "", err
}

// FIXME: this may have a security issue in case some malicious software tries to fake the origin
if dApp == nil {
connectorDApp := signal.ConnectorDApp{
URL: request.URL,
Name: request.Name,
IconURL: request.IconURL,
URL: request.URL,
Name: request.Name,
IconURL: request.IconURL,
ClientID: request.ClientID,
}
account, chainID, err := c.ClientHandler.RequestShareAccountForDApp(connectorDApp)
if err != nil {
Expand All @@ -47,6 +48,7 @@ func (c *RequestAccountsCommand) Execute(ctx context.Context, request RPCRequest
URL: request.URL,
Name: request.Name,
IconURL: request.IconURL,
ClientID: request.ClientID,
SharedAccount: account,
ChainID: chainID,
}
Expand Down
3 changes: 2 additions & 1 deletion services/connector/commands/request_accounts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,11 @@ func TestRequestAccountsAcceptedAndRequestAgain(t *testing.T) {
assert.Equal(t, expectedResponse, response)

// Check dApp in the database
dApp, err := persistence.SelectDAppByUrl(state.walletDb, request.URL)
dApp, err := persistence.SelectDApp(state.walletDb, request.URL, request.ClientID)
assert.NoError(t, err)
assert.Equal(t, request.Name, dApp.Name)
assert.Equal(t, request.IconURL, dApp.IconURL)
assert.Equal(t, request.ClientID, dApp.ClientID)
assert.Equal(t, accountAddress, dApp.SharedAccount)
assert.Equal(t, walletCommon.EthereumMainnet, dApp.ChainID)

Expand Down
11 changes: 6 additions & 5 deletions services/connector/commands/revoke_permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func (c *RevokePermissionsCommand) Execute(ctx context.Context, request RPCReque
return "", err
}

dApp, err := persistence.SelectDAppByUrl(c.Db, request.URL)
dApp, err := persistence.SelectDApp(c.Db, request.URL, request.ClientID)
if err != nil {
return "", err
}
Expand All @@ -27,15 +27,16 @@ func (c *RevokePermissionsCommand) Execute(ctx context.Context, request RPCReque
return "", ErrDAppIsNotPermittedByUser
}

err = persistence.DeleteDApp(c.Db, dApp.URL)
err = persistence.DeleteDApp(c.Db, dApp.URL, dApp.ClientID)
if err != nil {
return "", err
}

signal.SendConnectorDAppPermissionRevoked(signal.ConnectorDApp{
URL: request.URL,
Name: request.Name,
IconURL: request.IconURL,
URL: request.URL,
Name: request.Name,
IconURL: request.IconURL,
ClientID: request.ClientID,
})

return nil, nil
Expand Down
2 changes: 1 addition & 1 deletion services/connector/commands/revoke_permissions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestRevokePermissionsSucceeded(t *testing.T) {
assert.NoError(t, err)
assert.Empty(t, result)

dApp, err := persistence.SelectDAppByUrl(state.walletDb, testDAppData.URL)
dApp, err := persistence.SelectDApp(state.walletDb, testDAppData.URL, testDAppData.ClientID)
assert.NoError(t, err)
assert.Nil(t, dApp)

Expand Down
21 changes: 12 additions & 9 deletions services/connector/commands/rpc_traits.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@ var (
)

type RPCRequest struct {
JSONRPC string `json:"jsonrpc"`
ID int `json:"id"`
Method string `json:"method"`
Params []interface{} `json:"params"`
URL string `json:"url"`
Name string `json:"name"`
IconURL string `json:"iconUrl"`
ChainID uint64 `json:"chainId"`
JSONRPC string `json:"jsonrpc"`
ID int `json:"id"`
Method string `json:"method"`
Params []interface{} `json:"params"`
URL string `json:"url"`
Name string `json:"name"`
IconURL string `json:"iconUrl"`
ClientID string `json:"clientId"`
ChainID uint64 `json:"chainId"`
}

type RPCCommand interface {
Expand Down Expand Up @@ -67,7 +68,8 @@ type RejectedArgs struct {
}

type RecallDAppPermissionsArgs struct {
URL string `json:"url"`
URL string `json:"url"`
ClientID string `json:"clientId"`
}

type ClientSideHandlerInterface interface {
Expand Down Expand Up @@ -107,5 +109,6 @@ func (r *RPCRequest) Validate() error {
if r.URL == "" || r.Name == "" {
return ErrRequestMissingDAppData
}

return nil
}
9 changes: 5 additions & 4 deletions services/connector/commands/send_transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (c *SendTransactionCommand) Execute(ctx context.Context, request RPCRequest
return "", err
}

dApp, err := persistence.SelectDAppByUrl(c.Db, request.URL)
dApp, err := persistence.SelectDApp(c.Db, request.URL, request.ClientID)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -122,9 +122,10 @@ func (c *SendTransactionCommand) Execute(ctx context.Context, request RPCRequest
}

hash, err := c.ClientHandler.RequestSendTransaction(signal.ConnectorDApp{
URL: request.URL,
Name: request.Name,
IconURL: request.IconURL,
URL: request.URL,
Name: request.Name,
IconURL: request.IconURL,
ClientID: request.ClientID,
}, dApp.ChainID, params)
if err != nil {
return "", err
Expand Down
9 changes: 5 additions & 4 deletions services/connector/commands/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (c *SignCommand) Execute(ctx context.Context, request RPCRequest) (interfac
return "", err
}

dApp, err := persistence.SelectDAppByUrl(c.Db, request.URL)
dApp, err := persistence.SelectDApp(c.Db, request.URL, request.ClientID)
if err != nil {
return "", err
}
Expand All @@ -87,8 +87,9 @@ func (c *SignCommand) Execute(ctx context.Context, request RPCRequest) (interfac
}

return c.ClientHandler.RequestSign(signal.ConnectorDApp{
URL: request.URL,
Name: request.Name,
IconURL: request.IconURL,
URL: request.URL,
Name: request.Name,
IconURL: request.IconURL,
ClientID: request.ClientID,
}, params.Challenge, params.Address, params.Method)
}
Loading
Loading