diff --git a/services/connector/api.go b/services/connector/api.go index 2c5e47174b1..9b9c0d2f8c4 100644 --- a/services/connector/api.go +++ b/services/connector/api.go @@ -10,7 +10,9 @@ import ( ) var ( - ErrInvalidResponseFromForwardedRpc = errors.New("invalid response from forwarded RPC") + ErrInvalidResponseFromForwardedRpc = errors.New("invalid response from forwarded RPC") + ErrCannotOverrideClientIDForHttpConnection = errors.New("cannot override clientId for HTTP connection") + ErrEmptyClientIDFromTrustedConnection = errors.New("trusted connection must provide a clientId") ) type API struct { @@ -71,8 +73,8 @@ func NewAPI(s *Service) *API { } } -func (api *API) forwardRPC(ctx context.Context, URL string, request commands.RPCRequest) (interface{}, error) { - dApp, err := persistence.SelectDAppByUrl(api.s.db, URL) +func (api *API) forwardRPC(ctx context.Context, request commands.RPCRequest) (interface{}, error) { + dApp, err := persistence.SelectDApp(api.s.db, request.URL, request.ClientID) if err != nil { return "", err } @@ -97,6 +99,18 @@ func (api *API) CallRPC(ctx context.Context, inputJSON string) (interface{}, err return "", err } + // This prevents external clients from spoofing ClientID to impersonate trusted clients + if IsUntrustedConnection(ctx) { + if request.ClientID != "" { + return "", ErrCannotOverrideClientIDForHttpConnection + } + } else { + // Trusted connections MUST provide a ClientID + if request.ClientID == "" { + return "", ErrEmptyClientIDFromTrustedConnection + } + } + if command, exists := api.r.GetCommand(request.Method); exists { return command.Execute(ctx, request) } @@ -105,12 +119,17 @@ func (api *API) CallRPC(ctx context.Context, inputJSON string) (interface{}, err return nil, fmt.Errorf("method %s is not allowed", request.Method) } - return api.forwardRPC(ctx, request.URL, request) + return api.forwardRPC(ctx, request) } +// Deprecated: Use RecallDAppPermissionV2 instead func (api *API) RecallDAppPermission(origin string) error { + return api.RecallDAppPermissionV2(origin, "") +} + +func (api *API) RecallDAppPermissionV2(origin string, clientID string) error { // TODO: close the websocket connection - return api.c.RecallDAppPermissions(commands.RecallDAppPermissionsArgs{URL: origin}) + return api.c.RecallDAppPermissions(commands.RecallDAppPermissionsArgs{URL: origin, ClientID: clientID}) } func (api *API) GetPermittedDAppsList() ([]persistence.DApp, error) { diff --git a/services/connector/api_test.go b/services/connector/api_test.go index 3e15b766205..75e4f9ffc68 100644 --- a/services/connector/api_test.go +++ b/services/connector/api_test.go @@ -9,7 +9,7 @@ import ( "github.com/status-im/status-go/services/connector/commands" ) -func TestCallRPC(t *testing.T) { +func TestCallRPC_UntrustedConnection(t *testing.T) { state := setupTests(t) tests := []struct { @@ -36,9 +36,20 @@ func TestCallRPC(t *testing.T) { request: "{\"method\": \"wallet_switchEthereumChain\", \"params\": []}", expectError: commands.ErrRequestMissingDAppData, }, + { + request: `{ + "method": "eth_chainId", + "params": [], + "url": "https://example.com", + "name": "Example DApp", + "iconUrl": "https://example.com/icon.png", + "clientId": "wallet-connect" + }`, + expectError: ErrCannotOverrideClientIDForHttpConnection, + }, } - ctx := context.Background() + ctx := WithConnectionType(context.Background(), ConnectionTypeHTTP) for _, tt := range tests { t.Run(tt.request, func(t *testing.T) { _, err := state.api.CallRPC(ctx, tt.request) @@ -47,3 +58,41 @@ func TestCallRPC(t *testing.T) { }) } } + +func TestCallRPC_TrustedConnectionRequiresClientID(t *testing.T) { + state := setupTests(t) + + // Trusted connection (Internal) without ClientID should fail + ctx := WithConnectionType(context.Background(), ConnectionTypeInternal) + + request := `{ + "method": "eth_chainId", + "params": [], + "url": "https://example.com", + "name": "Example DApp", + "iconUrl": "https://example.com/icon.png" + }` + + _, err := state.api.CallRPC(ctx, request) + require.Error(t, err) + require.Equal(t, ErrEmptyClientIDFromTrustedConnection, err) +} + +func TestCallRPC_TrustedConnectionWithClientID(t *testing.T) { + state := setupTests(t) + + ctx := WithConnectionType(context.Background(), ConnectionTypeInternal) + + request := `{ + "method": "eth_chainId", + "params": [], + "url": "https://example.com", + "name": "Example DApp", + "iconUrl": "https://example.com/icon.png", + "clientId": "status-desktop" + }` + + result, err := state.api.CallRPC(ctx, request) + require.NoError(t, err) + require.NotNil(t, result) +} diff --git a/services/connector/commands/accounts.go b/services/connector/commands/accounts.go index 18beb0f6219..5dfbbe92ed9 100644 --- a/services/connector/commands/accounts.go +++ b/services/connector/commands/accounts.go @@ -23,7 +23,7 @@ func (c *AccountsCommand) Execute(ctx context.Context, request RPCRequest) (inte return "", err } - dApp, err := persistence.SelectDAppByUrl(c.Db, request.URL) + dApp, err := persistence.SelectDApp(c.Db, request.URL, request.ClientID) if err != nil { return "", err } diff --git a/services/connector/commands/chain_id.go b/services/connector/commands/chain_id.go index d3cbcd000b4..1a0e5b93af8 100644 --- a/services/connector/commands/chain_id.go +++ b/services/connector/commands/chain_id.go @@ -21,7 +21,7 @@ func (c *ChainIDCommand) Execute(ctx context.Context, request RPCRequest) (inter return "", err } - dApp, err := persistence.SelectDAppByUrl(c.Db, request.URL) + dApp, err := persistence.SelectDApp(c.Db, request.URL, request.ClientID) if err != nil { return "", err } diff --git a/services/connector/commands/client_handler.go b/services/connector/commands/client_handler.go index 14c704f9261..032ca3b5d7b 100644 --- a/services/connector/commands/client_handler.go +++ b/services/connector/commands/client_handler.go @@ -58,7 +58,7 @@ func NewClientSideHandler(db *sql.DB) *ClientSideHandler { } func (c *ClientSideHandler) generateRequestID(dApp signal.ConnectorDApp) string { - rawID := fmt.Sprintf("%d%s", time.Now().UnixMilli(), dApp.URL) + rawID := fmt.Sprintf("%d%s%s", time.Now().UnixMilli(), dApp.URL, dApp.ClientID) hash := sha256.Sum256([]byte(rawID)) return hex.EncodeToString(hash[:]) } @@ -126,7 +126,7 @@ func (c *ClientSideHandler) RecallDAppPermissions(args RecallDAppPermissionsArgs return ErrEmptyUrl } - dApp, err := persistence.SelectDAppByUrl(c.Db, args.URL) + dApp, err := persistence.SelectDApp(c.Db, args.URL, args.ClientID) if err != nil { return err } @@ -135,15 +135,16 @@ func (c *ClientSideHandler) RecallDAppPermissions(args RecallDAppPermissionsArgs return ErrDAppDoesNotHavePermissions } - err = persistence.DeleteDApp(c.Db, dApp.URL) + err = persistence.DeleteDApp(c.Db, dApp.URL, dApp.ClientID) if err != nil { return err } signal.SendConnectorDAppPermissionRevoked(signal.ConnectorDApp{ - URL: dApp.URL, - Name: dApp.Name, - IconURL: dApp.IconURL, + URL: dApp.URL, + Name: dApp.Name, + IconURL: dApp.IconURL, + ClientID: dApp.ClientID, }) return nil } diff --git a/services/connector/commands/client_handler_test.go b/services/connector/commands/client_handler_test.go index 823f249db77..27775af8568 100644 --- a/services/connector/commands/client_handler_test.go +++ b/services/connector/commands/client_handler_test.go @@ -44,6 +44,7 @@ func TestRecallDAppPermission(t *testing.T) { Name: "Test DApp", URL: "http://testDAppURL", IconURL: "http://testDAppIconUrl", + ClientID: "", SharedAccount: types.HexToAddress("0x1234567890"), ChainID: 0x1, } @@ -51,21 +52,21 @@ func TestRecallDAppPermission(t *testing.T) { err := persistence.UpsertDApp(db, &dapp) assert.NoError(t, err) - persistedDapp, err := persistence.SelectDAppByUrl(db, dapp.URL) + persistedDapp, err := persistence.SelectDApp(db, dapp.URL, dapp.ClientID) assert.Equal(t, persistedDapp, &dapp) assert.NoError(t, err) clientHandler := NewClientSideHandler(db) - err = clientHandler.RecallDAppPermissions(RecallDAppPermissionsArgs{URL: dapp.URL}) + err = clientHandler.RecallDAppPermissions(RecallDAppPermissionsArgs{URL: dapp.URL, ClientID: dapp.ClientID}) assert.NoError(t, err) err = clientHandler.RecallDAppPermissions(RecallDAppPermissionsArgs{}) assert.ErrorIs(t, err, ErrEmptyUrl) - err = clientHandler.RecallDAppPermissions(RecallDAppPermissionsArgs{URL: dapp.URL}) + err = clientHandler.RecallDAppPermissions(RecallDAppPermissionsArgs{URL: dapp.URL, ClientID: dapp.ClientID}) assert.ErrorIs(t, err, ErrDAppDoesNotHavePermissions) - recalledDapp, err := persistence.SelectDAppByUrl(db, dapp.URL) + recalledDapp, err := persistence.SelectDApp(db, dapp.URL, dapp.ClientID) assert.Equal(t, recalledDapp, (*persistence.DApp)(nil)) assert.NoError(t, err) diff --git a/services/connector/commands/request_accounts.go b/services/connector/commands/request_accounts.go index ba60e9d371e..dadb8f31173 100644 --- a/services/connector/commands/request_accounts.go +++ b/services/connector/commands/request_accounts.go @@ -26,7 +26,7 @@ func (c *RequestAccountsCommand) Execute(ctx context.Context, request RPCRequest return "", err } - dApp, err := persistence.SelectDAppByUrl(c.Db, request.URL) + dApp, err := persistence.SelectDApp(c.Db, request.URL, request.ClientID) if err != nil { return "", err } @@ -34,9 +34,10 @@ func (c *RequestAccountsCommand) Execute(ctx context.Context, request RPCRequest // FIXME: this may have a security issue in case some malicious software tries to fake the origin if dApp == nil { connectorDApp := signal.ConnectorDApp{ - URL: request.URL, - Name: request.Name, - IconURL: request.IconURL, + URL: request.URL, + Name: request.Name, + IconURL: request.IconURL, + ClientID: request.ClientID, } account, chainID, err := c.ClientHandler.RequestShareAccountForDApp(connectorDApp) if err != nil { @@ -47,6 +48,7 @@ func (c *RequestAccountsCommand) Execute(ctx context.Context, request RPCRequest URL: request.URL, Name: request.Name, IconURL: request.IconURL, + ClientID: request.ClientID, SharedAccount: account, ChainID: chainID, } diff --git a/services/connector/commands/request_accounts_test.go b/services/connector/commands/request_accounts_test.go index 4aae7b58fc3..5addfe0e875 100644 --- a/services/connector/commands/request_accounts_test.go +++ b/services/connector/commands/request_accounts_test.go @@ -81,10 +81,11 @@ func TestRequestAccountsAcceptedAndRequestAgain(t *testing.T) { assert.Equal(t, expectedResponse, response) // Check dApp in the database - dApp, err := persistence.SelectDAppByUrl(state.walletDb, request.URL) + dApp, err := persistence.SelectDApp(state.walletDb, request.URL, request.ClientID) assert.NoError(t, err) assert.Equal(t, request.Name, dApp.Name) assert.Equal(t, request.IconURL, dApp.IconURL) + assert.Equal(t, request.ClientID, dApp.ClientID) assert.Equal(t, accountAddress, dApp.SharedAccount) assert.Equal(t, walletCommon.EthereumMainnet, dApp.ChainID) diff --git a/services/connector/commands/revoke_permissions.go b/services/connector/commands/revoke_permissions.go index 9e92d5adeb6..7cca895aff1 100644 --- a/services/connector/commands/revoke_permissions.go +++ b/services/connector/commands/revoke_permissions.go @@ -18,7 +18,7 @@ func (c *RevokePermissionsCommand) Execute(ctx context.Context, request RPCReque return "", err } - dApp, err := persistence.SelectDAppByUrl(c.Db, request.URL) + dApp, err := persistence.SelectDApp(c.Db, request.URL, request.ClientID) if err != nil { return "", err } @@ -27,15 +27,16 @@ func (c *RevokePermissionsCommand) Execute(ctx context.Context, request RPCReque return "", ErrDAppIsNotPermittedByUser } - err = persistence.DeleteDApp(c.Db, dApp.URL) + err = persistence.DeleteDApp(c.Db, dApp.URL, dApp.ClientID) if err != nil { return "", err } signal.SendConnectorDAppPermissionRevoked(signal.ConnectorDApp{ - URL: request.URL, - Name: request.Name, - IconURL: request.IconURL, + URL: request.URL, + Name: request.Name, + IconURL: request.IconURL, + ClientID: request.ClientID, }) return nil, nil diff --git a/services/connector/commands/revoke_permissions_test.go b/services/connector/commands/revoke_permissions_test.go index 73610459519..a13dd44737a 100644 --- a/services/connector/commands/revoke_permissions_test.go +++ b/services/connector/commands/revoke_permissions_test.go @@ -65,7 +65,7 @@ func TestRevokePermissionsSucceeded(t *testing.T) { assert.NoError(t, err) assert.Empty(t, result) - dApp, err := persistence.SelectDAppByUrl(state.walletDb, testDAppData.URL) + dApp, err := persistence.SelectDApp(state.walletDb, testDAppData.URL, testDAppData.ClientID) assert.NoError(t, err) assert.Nil(t, dApp) diff --git a/services/connector/commands/rpc_traits.go b/services/connector/commands/rpc_traits.go index 45b6129c146..eac459ae7e4 100644 --- a/services/connector/commands/rpc_traits.go +++ b/services/connector/commands/rpc_traits.go @@ -32,14 +32,15 @@ var ( ) type RPCRequest struct { - JSONRPC string `json:"jsonrpc"` - ID int `json:"id"` - Method string `json:"method"` - Params []interface{} `json:"params"` - URL string `json:"url"` - Name string `json:"name"` - IconURL string `json:"iconUrl"` - ChainID uint64 `json:"chainId"` + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Method string `json:"method"` + Params []interface{} `json:"params"` + URL string `json:"url"` + Name string `json:"name"` + IconURL string `json:"iconUrl"` + ClientID string `json:"clientId"` + ChainID uint64 `json:"chainId"` } type RPCCommand interface { @@ -67,7 +68,8 @@ type RejectedArgs struct { } type RecallDAppPermissionsArgs struct { - URL string `json:"url"` + URL string `json:"url"` + ClientID string `json:"clientId"` } type ClientSideHandlerInterface interface { @@ -107,5 +109,6 @@ func (r *RPCRequest) Validate() error { if r.URL == "" || r.Name == "" { return ErrRequestMissingDAppData } + return nil } diff --git a/services/connector/commands/send_transaction.go b/services/connector/commands/send_transaction.go index 94fbf460f8f..cce9c4083f5 100644 --- a/services/connector/commands/send_transaction.go +++ b/services/connector/commands/send_transaction.go @@ -63,7 +63,7 @@ func (c *SendTransactionCommand) Execute(ctx context.Context, request RPCRequest return "", err } - dApp, err := persistence.SelectDAppByUrl(c.Db, request.URL) + dApp, err := persistence.SelectDApp(c.Db, request.URL, request.ClientID) if err != nil { return "", err } @@ -122,9 +122,10 @@ func (c *SendTransactionCommand) Execute(ctx context.Context, request RPCRequest } hash, err := c.ClientHandler.RequestSendTransaction(signal.ConnectorDApp{ - URL: request.URL, - Name: request.Name, - IconURL: request.IconURL, + URL: request.URL, + Name: request.Name, + IconURL: request.IconURL, + ClientID: request.ClientID, }, dApp.ChainID, params) if err != nil { return "", err diff --git a/services/connector/commands/sign.go b/services/connector/commands/sign.go index b08dc9b36e1..695343e5227 100644 --- a/services/connector/commands/sign.go +++ b/services/connector/commands/sign.go @@ -77,7 +77,7 @@ func (c *SignCommand) Execute(ctx context.Context, request RPCRequest) (interfac return "", err } - dApp, err := persistence.SelectDAppByUrl(c.Db, request.URL) + dApp, err := persistence.SelectDApp(c.Db, request.URL, request.ClientID) if err != nil { return "", err } @@ -87,8 +87,9 @@ func (c *SignCommand) Execute(ctx context.Context, request RPCRequest) (interfac } return c.ClientHandler.RequestSign(signal.ConnectorDApp{ - URL: request.URL, - Name: request.Name, - IconURL: request.IconURL, + URL: request.URL, + Name: request.Name, + IconURL: request.IconURL, + ClientID: request.ClientID, }, params.Challenge, params.Address, params.Method) } diff --git a/services/connector/commands/switch_ethereum_chain.go b/services/connector/commands/switch_ethereum_chain.go index 8824f547e5d..75ca4ca1e58 100644 --- a/services/connector/commands/switch_ethereum_chain.go +++ b/services/connector/commands/switch_ethereum_chain.go @@ -75,7 +75,7 @@ func (c *SwitchEthereumChainCommand) Execute(ctx context.Context, request RPCReq return "", ErrUnsupportedNetwork } - dApp, err := persistence.SelectDAppByUrl(c.Db, request.URL) + dApp, err := persistence.SelectDApp(c.Db, request.URL, request.ClientID) if err != nil { return "", err } diff --git a/services/connector/commands/test_helpers.go b/services/connector/commands/test_helpers.go index 2be29d2da4d..4020a427ed4 100644 --- a/services/connector/commands/test_helpers.go +++ b/services/connector/commands/test_helpers.go @@ -24,9 +24,10 @@ import ( ) var testDAppData = signal.ConnectorDApp{ - URL: "http://testDAppURL", - Name: "testDAppName", - IconURL: "http://testDAppIconUrl", + URL: "http://testDAppURL", + Name: "testDAppName", + IconURL: "http://testDAppIconUrl", + ClientID: "", } type EventType struct { @@ -145,6 +146,7 @@ func PersistDAppData(db *sql.DB, dApp signal.ConnectorDApp, sharedAccount types. URL: dApp.URL, Name: dApp.Name, IconURL: dApp.IconURL, + ClientID: dApp.ClientID, SharedAccount: sharedAccount, ChainID: chainID, } @@ -164,6 +166,7 @@ func ConstructRPCRequest(method string, params []interface{}, dApp *signal.Conne request.URL = dApp.URL request.Name = dApp.Name request.IconURL = dApp.IconURL + request.ClientID = dApp.ClientID } return request, nil diff --git a/services/connector/context.go b/services/connector/context.go new file mode 100644 index 00000000000..654498bbab2 --- /dev/null +++ b/services/connector/context.go @@ -0,0 +1,40 @@ +package connector + +import ( + "context" +) + +// ConnectionType represents the source of the RPC connection +type ConnectionType string + +const ( + ConnectionTypeHTTP ConnectionType = "http" // Untrusted - from WebSocket + ConnectionTypeInternal ConnectionType = "internal" // Trusted - from CallRPC (status-desktop) +) + +// ContextKey is a type used for keys in connector Context. +type ContextKey struct { + Name string +} + +var ( + connectionTypeKey = ContextKey{Name: "connectionType"} +) + +// WithConnectionType adds connection type to context +func WithConnectionType(ctx context.Context, connType ConnectionType) context.Context { + return context.WithValue(ctx, connectionTypeKey, connType) +} + +// GetConnectionType retrieves connection type from context +func GetConnectionType(ctx context.Context) ConnectionType { + if connType, ok := ctx.Value(connectionTypeKey).(ConnectionType); ok { + return connType + } + return ConnectionTypeInternal // default to untrusted +} + +// IsUntrustedConnection checks if the connection is from HTTP/WebSocket +func IsUntrustedConnection(ctx context.Context) bool { + return GetConnectionType(ctx) == ConnectionTypeHTTP +} diff --git a/services/connector/context_test.go b/services/connector/context_test.go new file mode 100644 index 00000000000..ba9c3ad0bae --- /dev/null +++ b/services/connector/context_test.go @@ -0,0 +1,31 @@ +package connector + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConnectionType(t *testing.T) { + ctx := WithConnectionType(context.Background(), ConnectionTypeHTTP) + require.True(t, IsUntrustedConnection(ctx)) + require.Equal(t, ConnectionTypeHTTP, GetConnectionType(ctx)) + + ctx = WithConnectionType(context.Background(), ConnectionTypeInternal) + require.False(t, IsUntrustedConnection(ctx)) + require.Equal(t, ConnectionTypeInternal, GetConnectionType(ctx)) +} + +func TestConnectionTypeDefault(t *testing.T) { + // Context without connection type should default to untrusted + ctx := context.Background() + + if IsUntrustedConnection(ctx) { + t.Error("Default connection type should be untrusted (HTTP)") + } + + if got := GetConnectionType(ctx); got != ConnectionTypeInternal { + t.Errorf("Default GetConnectionType() = %v, want %v", got, ConnectionTypeInternal) + } +} diff --git a/services/connector/database/persistence.go b/services/connector/database/persistence.go index b49e3ebf1d6..ceaf4eed99b 100644 --- a/services/connector/database/persistence.go +++ b/services/connector/database/persistence.go @@ -6,29 +6,32 @@ import ( "github.com/status-im/status-go/crypto/types" ) -const upsertDAppQuery = "INSERT INTO connector_dapps (url, name, icon_url, shared_account, chain_id) VALUES (?, ?, ?, ?, ?) ON CONFLICT(url) DO UPDATE SET name = excluded.name, icon_url = excluded.icon_url, shared_account = excluded.shared_account, chain_id = excluded.chain_id" -const selectDAppByUrlQuery = "SELECT name, icon_url, shared_account, chain_id FROM connector_dapps WHERE url = ?" -const selectDAppsQuery = "SELECT url, name, icon_url, shared_account, chain_id FROM connector_dapps" -const deleteDAppQuery = "DELETE FROM connector_dapps WHERE url = ?" +const upsertDAppQuery = "INSERT INTO connector_dapps (url, name, icon_url, client_id, shared_account, chain_id) VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(url, client_id) DO UPDATE SET name = excluded.name, icon_url = excluded.icon_url, shared_account = excluded.shared_account, chain_id = excluded.chain_id" +const selectDAppQuery = "SELECT name, icon_url, shared_account, chain_id FROM connector_dapps WHERE url = ? AND client_id = ?" +const selectDAppsQuery = "SELECT url, name, icon_url, client_id, shared_account, chain_id FROM connector_dapps" +const deleteDAppQuery = "DELETE FROM connector_dapps WHERE url = ? AND client_id = ?" type DApp struct { URL string `json:"url"` Name string `json:"name"` IconURL string `json:"iconUrl"` + ClientID string `json:"clientId"` SharedAccount types.Address `json:"sharedAccount"` ChainID uint64 `json:"chainId"` } func UpsertDApp(db *sql.DB, dApp *DApp) error { - _, err := db.Exec(upsertDAppQuery, dApp.URL, dApp.Name, dApp.IconURL, dApp.SharedAccount, dApp.ChainID) + _, err := db.Exec(upsertDAppQuery, dApp.URL, dApp.Name, dApp.IconURL, dApp.ClientID, dApp.SharedAccount, dApp.ChainID) return err } -func SelectDAppByUrl(db *sql.DB, url string) (*DApp, error) { +func SelectDApp(db *sql.DB, url string, clientID string) (*DApp, error) { + // clientID can be empty for backward compatibility with browser extension dApp := &DApp{ - URL: url, + URL: url, + ClientID: clientID, } - err := db.QueryRow(selectDAppByUrlQuery, url).Scan(&dApp.Name, &dApp.IconURL, &dApp.SharedAccount, &dApp.ChainID) + err := db.QueryRow(selectDAppQuery, url, clientID).Scan(&dApp.Name, &dApp.IconURL, &dApp.SharedAccount, &dApp.ChainID) if err == sql.ErrNoRows { return nil, nil } @@ -45,7 +48,7 @@ func SelectAllDApps(db *sql.DB) ([]DApp, error) { var dApps []DApp for rows.Next() { dApp := DApp{} - err = rows.Scan(&dApp.URL, &dApp.Name, &dApp.IconURL, &dApp.SharedAccount, &dApp.ChainID) + err = rows.Scan(&dApp.URL, &dApp.Name, &dApp.IconURL, &dApp.ClientID, &dApp.SharedAccount, &dApp.ChainID) if err != nil { return nil, err } @@ -54,7 +57,7 @@ func SelectAllDApps(db *sql.DB) ([]DApp, error) { return dApps, nil } -func DeleteDApp(db *sql.DB, url string) error { - _, err := db.Exec(deleteDAppQuery, url) +func DeleteDApp(db *sql.DB, url string, clientID string) error { + _, err := db.Exec(deleteDAppQuery, url, clientID) return err } diff --git a/services/connector/database/persistence_test.go b/services/connector/database/persistence_test.go index b4b33edd86b..679e81405f5 100644 --- a/services/connector/database/persistence_test.go +++ b/services/connector/database/persistence_test.go @@ -16,6 +16,7 @@ var testDApp = DApp{ Name: "Test DApp", URL: "https://test-dapp-url.com", IconURL: "https://test-dapp-icon-url.com", + ClientID: "", SharedAccount: types.HexToAddress("0x1234567890"), ChainID: 0x1, } @@ -35,7 +36,7 @@ func TestInsertAndSelectDApp(t *testing.T) { err := UpsertDApp(db, &testDApp) require.NoError(t, err) - dAppBack, err := SelectDAppByUrl(db, testDApp.URL) + dAppBack, err := SelectDApp(db, testDApp.URL, testDApp.ClientID) require.NoError(t, err) require.Equal(t, &testDApp, dAppBack) } @@ -48,15 +49,16 @@ func TestInsertAndUpdateDApp(t *testing.T) { require.NoError(t, err) updatedDApp := DApp{ - Name: "Updated Test DApp", - URL: testDApp.URL, - IconURL: "https://updated-test-dapp-icon-url.com", + Name: "Updated Test DApp", + URL: testDApp.URL, + IconURL: "https://updated-test-dapp-icon-url.com", + ClientID: testDApp.ClientID, } err = UpsertDApp(db, &updatedDApp) require.NoError(t, err) - dAppBack, err := SelectDAppByUrl(db, testDApp.URL) + dAppBack, err := SelectDApp(db, testDApp.URL, testDApp.ClientID) require.NoError(t, err) require.Equal(t, &updatedDApp, dAppBack) require.NotEqual(t, &testDApp, dAppBack) @@ -69,14 +71,14 @@ func TestInsertAndRemoveDApp(t *testing.T) { err := UpsertDApp(db, &testDApp) require.NoError(t, err) - dAppBack, err := SelectDAppByUrl(db, testDApp.URL) + dAppBack, err := SelectDApp(db, testDApp.URL, testDApp.ClientID) require.NoError(t, err) require.Equal(t, &testDApp, dAppBack) - err = DeleteDApp(db, testDApp.URL) + err = DeleteDApp(db, testDApp.URL, testDApp.ClientID) require.NoError(t, err) - dAppBack, err = SelectDAppByUrl(db, testDApp.URL) + dAppBack, err = SelectDApp(db, testDApp.URL, testDApp.ClientID) require.NoError(t, err) require.Empty(t, dAppBack) } @@ -93,3 +95,150 @@ func TestSelectAllDApps(t *testing.T) { require.Len(t, dApps, 1) require.Equal(t, testDApp, dApps[0]) } + +func TestMultipleClientsWithSameURL(t *testing.T) { + db, close := setupTestDB(t) + defer close() + + // Create two dApps with same URL but different clientIds + dApp1 := DApp{ + Name: "DApp Client 1", + URL: "https://same-url.com", + IconURL: "https://same-icon.com", + ClientID: "client1", + SharedAccount: types.HexToAddress("0x1111111111"), + ChainID: 0x1, + } + + dApp2 := DApp{ + Name: "DApp Client 2", + URL: "https://same-url.com", // Same URL + IconURL: "https://same-icon.com", + ClientID: "client2", + SharedAccount: types.HexToAddress("0x2222222222"), + ChainID: 0x89, + } + + err := UpsertDApp(db, &dApp1) + require.NoError(t, err) + err = UpsertDApp(db, &dApp2) + require.NoError(t, err) + + retrievedDApp1, err := SelectDApp(db, dApp1.URL, dApp1.ClientID) + require.NoError(t, err) + require.Equal(t, &dApp1, retrievedDApp1) + + retrievedDApp2, err := SelectDApp(db, dApp2.URL, dApp2.ClientID) + require.NoError(t, err) + require.Equal(t, &dApp2, retrievedDApp2) + + allDApps, err := SelectAllDApps(db) + require.NoError(t, err) + require.Len(t, allDApps, 2) + + foundDApp1 := false + foundDApp2 := false + for _, dApp := range allDApps { + if dApp.ClientID == "client1" { + require.Equal(t, dApp1, dApp) + foundDApp1 = true + } + if dApp.ClientID == "client2" { + require.Equal(t, dApp2, dApp) + foundDApp2 = true + } + } + require.True(t, foundDApp1, "client1 dApp not found") + require.True(t, foundDApp2, "client2 dApp not found") +} + +func TestDeleteSpecificClient(t *testing.T) { + db, close := setupTestDB(t) + defer close() + + // Create two dApps with same URL but different clientIds + dApp1 := DApp{ + Name: "DApp Client 1", + URL: "https://test-delete.com", + IconURL: "https://test-icon.com", + ClientID: "client1", + SharedAccount: types.HexToAddress("0x1111111111"), + ChainID: 0x1, + } + + dApp2 := DApp{ + Name: "DApp Client 2", + URL: "https://test-delete.com", + IconURL: "https://test-icon.com", + ClientID: "client2", + SharedAccount: types.HexToAddress("0x2222222222"), + ChainID: 0x89, + } + + err := UpsertDApp(db, &dApp1) + require.NoError(t, err) + err = UpsertDApp(db, &dApp2) + require.NoError(t, err) + + // Delete only client1 + err = DeleteDApp(db, dApp1.URL, dApp1.ClientID) + require.NoError(t, err) + + deletedDApp1, err := SelectDApp(db, dApp1.URL, dApp1.ClientID) + require.NoError(t, err) + require.Nil(t, deletedDApp1) + + // Verify client2 still exists + stillExistsDApp2, err := SelectDApp(db, dApp2.URL, dApp2.ClientID) + require.NoError(t, err) + require.Equal(t, &dApp2, stillExistsDApp2) +} + +func TestBackwardCompatibilityEmptyClientID(t *testing.T) { + db, close := setupTestDB(t) + defer close() + + oldDApp := DApp{ + Name: "Old Client DApp", + URL: "https://old-client.com", + IconURL: "https://old-icon.com", + ClientID: "", + SharedAccount: types.HexToAddress("0x0000000000"), + ChainID: 0x1, + } + + newDApp := DApp{ + Name: "New Client DApp", + URL: "https://old-client.com", // Same URL + IconURL: "https://old-icon.com", + ClientID: "newclient123", + SharedAccount: types.HexToAddress("0x3333333333"), + ChainID: 0x89, + } + + err := UpsertDApp(db, &oldDApp) + require.NoError(t, err) + err = UpsertDApp(db, &newDApp) + require.NoError(t, err) + + retrievedOldDApp, err := SelectDApp(db, oldDApp.URL, "") + require.NoError(t, err) + require.Equal(t, &oldDApp, retrievedOldDApp) + + retrievedNewDApp, err := SelectDApp(db, newDApp.URL, newDApp.ClientID) + require.NoError(t, err) + require.Equal(t, &newDApp, retrievedNewDApp) + + // Delete old client (empty clientId) + err = DeleteDApp(db, oldDApp.URL, "") + require.NoError(t, err) + + deletedOldDApp, err := SelectDApp(db, oldDApp.URL, "") + require.NoError(t, err) + require.Nil(t, deletedOldDApp) + + // Verify new client still exists + stillExistsNewDApp, err := SelectDApp(db, newDApp.URL, newDApp.ClientID) + require.NoError(t, err) + require.Equal(t, &newDApp, stillExistsNewDApp) +} diff --git a/services/connector/service.go b/services/connector/service.go index 6c9a5dd5fbf..5175472b3e9 100644 --- a/services/connector/service.go +++ b/services/connector/service.go @@ -77,6 +77,11 @@ func (s *Service) Start() error { w.WriteHeader(http.StatusNotFound) return } + + // Inject connection type into request context + ctx := WithConnectionType(r.Context(), ConnectionTypeHTTP) + r = r.WithContext(ctx) + // FIXME: this is a temporary solution to allow all origins origins := []string{"*"} wsHandler := s.rpcServer.WebsocketHandler(origins) diff --git a/services/connector/test_helpers_test.go b/services/connector/test_helpers_test.go index cba50bc352e..c562c4d750e 100644 --- a/services/connector/test_helpers_test.go +++ b/services/connector/test_helpers_test.go @@ -52,7 +52,7 @@ func createWalletDB(t *testing.T) (db *sql.DB) { } func setupTests(t *testing.T) (state testState) { - state.ctx = context.Background() + state.ctx = WithConnectionType(context.Background(), ConnectionTypeHTTP) state.db = createDB(t) state.walletDb = createWalletDB(t) diff --git a/signal/events_connector.go b/signal/events_connector.go index 40bae42203d..fb79e9536fb 100644 --- a/signal/events_connector.go +++ b/signal/events_connector.go @@ -14,9 +14,10 @@ const ( ) type ConnectorDApp struct { - URL string `json:"url"` - Name string `json:"name"` - IconURL string `json:"iconUrl"` + URL string `json:"url"` + Name string `json:"name"` + IconURL string `json:"iconUrl"` + ClientID string `json:"clientId"` } // ConnectorSendRequestAccountsSignal is triggered when a request for accounts is sent. diff --git a/tests-functional/clients/connector.py b/tests-functional/clients/connector.py index 361037a9117..cddbb1487e2 100644 --- a/tests-functional/clients/connector.py +++ b/tests-functional/clients/connector.py @@ -92,6 +92,7 @@ def _send(self, method, params=None): "name": self.name, "url": "http://localhost/", "method": method, + "clientId": "tests-functional", } if params is not None: request["params"] = params diff --git a/walletdatabase/migrations/sql/1759312232_add_client_id_to_connector_dapps.up.sql b/walletdatabase/migrations/sql/1759312232_add_client_id_to_connector_dapps.up.sql new file mode 100644 index 00000000000..e18156e2302 --- /dev/null +++ b/walletdatabase/migrations/sql/1759312232_add_client_id_to_connector_dapps.up.sql @@ -0,0 +1,19 @@ +-- Each connection can have its own clientId (status-desktop browser, chrome extension, wallet-connect) +-- allowing independent session state management +CREATE TABLE connector_dapps_new ( + url TEXT NOT NULL, + name TEXT NOT NULL, + shared_account TEXT NOT NULL, + chain_id UNSIGNED BIGINT NOT NULL, + icon_url TEXT, + client_id TEXT NOT NULL DEFAULT '', + PRIMARY KEY (url, client_id) +) WITHOUT ROWID; + +-- Migrate existing data with empty client_id for backward compatibility +INSERT INTO connector_dapps_new (url, name, shared_account, chain_id, icon_url, client_id) +SELECT url, name, shared_account, chain_id, icon_url, '' FROM connector_dapps; + +DROP TABLE connector_dapps; + +ALTER TABLE connector_dapps_new RENAME TO connector_dapps;