diff --git a/appdatabase/database.go b/appdatabase/database.go index 7b6a3c4baf9..1f93c75a9bc 100644 --- a/appdatabase/database.go +++ b/appdatabase/database.go @@ -44,7 +44,7 @@ func (a DbInitializer) Initialize(path, password string, kdfIterationsNumber int } func doMigration(db *sql.DB) error { - lastMigration, migrationTableExists, err := sqlite.GetLastMigrationVersion(db) + lastMigration, migrationTableExists, err := sqlite.GetLastMigrationVersion(db, sqlite.StatusMigrationTableName()) if err != nil { return err } diff --git a/cmd/push-notification-server/main.go b/cmd/push-notification-server/main.go index c1974576db6..0dc274dc401 100644 --- a/cmd/push-notification-server/main.go +++ b/cmd/push-notification-server/main.go @@ -122,10 +122,8 @@ func main() { messaging, err := messaging.NewCore( messaging.CoreParams{ - Identity: privateKey, - DB: db, - Persistence: protocol.NewMessagingPersistence(db), - NodeKey: nil, + Identity: privateKey, + NodeKey: nil, WakuConfig: params.WakuV2Config{ Enabled: true, Host: "0.0.0.0", @@ -144,6 +142,7 @@ func main() { TimeSource: timesource.DefaultService(), }, messaging.WithLogger(logger.Named("messaging")), + messaging.WithSQLitePersistence(db), ) if err != nil { os.Exit(exitCodeCreateMessengerFailed) diff --git a/go.mod b/go.mod index 3828d453ccf..79034353e5a 100644 --- a/go.mod +++ b/go.mod @@ -52,7 +52,7 @@ require ( github.com/status-im/doubleratchet v3.0.0+incompatible github.com/status-im/markdown v0.0.0-20250825083641-55c1df9bc05d github.com/status-im/migrate/v4 v4.6.2-status.3 - github.com/status-im/mvds v0.0.27-0.20241031073756-b192c603a75d + github.com/status-im/mvds v0.0.27-0.20251022120125-7bdc695d49c4 github.com/status-im/zxcvbn-go v0.0.0-20220311183720-5e8676676857 github.com/stretchr/testify v1.10.0 github.com/syndtr/goleveldb v1.0.1-0.20220614013038-64ee5596c38a // indirect @@ -83,6 +83,7 @@ require ( github.com/btcsuite/btcd/btcutil v1.1.6 github.com/cenkalti/backoff/v4 v4.2.1 github.com/getsentry/sentry-go v0.29.1 + github.com/golang-migrate/migrate/v4 v4.15.2 github.com/gorilla/sessions v1.2.1 github.com/gorilla/websocket v1.5.3 github.com/ipfs/go-log/v2 v2.5.1 @@ -182,7 +183,6 @@ require ( github.com/gofrs/flock v0.12.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt/v4 v4.5.2 // indirect - github.com/golang-migrate/migrate/v4 v4.15.2 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb // indirect diff --git a/go.sum b/go.sum index 5031713e776..3b4cdfd93c8 100644 --- a/go.sum +++ b/go.sum @@ -2104,8 +2104,8 @@ github.com/status-im/markdown v0.0.0-20250825083641-55c1df9bc05d h1:0jQiaymp0t7X github.com/status-im/markdown v0.0.0-20250825083641-55c1df9bc05d/go.mod h1:i31M0FhtYMUeAzWqJafla0Q4+LCGJyorLRCH3EorAWQ= github.com/status-im/migrate/v4 v4.6.2-status.3 h1:Khwjb59NzniloUr5i9s9AtkEyqBbQFt1lkoAu66sAu0= github.com/status-im/migrate/v4 v4.6.2-status.3/go.mod h1:c/kc90n47GZu/58nnz1OMLTf7uE4Da4gZP5qmU+A/v8= -github.com/status-im/mvds v0.0.27-0.20241031073756-b192c603a75d h1:+eBm+aBGFgXsJi6wDdEo6ASLd78kPN3vcKtnXQXn240= -github.com/status-im/mvds v0.0.27-0.20241031073756-b192c603a75d/go.mod h1:2fiAx0q9XYIPKYRq2B1oiO9zZESy/n4D32gWw6lMDsE= +github.com/status-im/mvds v0.0.27-0.20251022120125-7bdc695d49c4 h1:NrTkZgbgEu7hOJ2Ku4P1yj9boRzdpMs7OlZREBy6MyE= +github.com/status-im/mvds v0.0.27-0.20251022120125-7bdc695d49c4/go.mod h1:2fiAx0q9XYIPKYRq2B1oiO9zZESy/n4D32gWw6lMDsE= github.com/status-im/notify v1.0.2-status/go.mod h1:gF3zSOrafR9DQEWSE8TjfI9NkooDxbyT4UgRGKZA0lc= github.com/status-im/zxcvbn-go v0.0.0-20220311183720-5e8676676857 h1:sPkzT7Z7uLmejOsBRlZ0kwDWpqjpHJsp834o5nbhqho= github.com/status-im/zxcvbn-go v0.0.0-20220311183720-5e8676676857/go.mod h1:lq9I5ROto5tcua65GmCE6SIW7VE0ucdEBs1fn4z7uWU= diff --git a/messaging/adapters/persistence.go b/messaging/adapters/persistence.go deleted file mode 100644 index 595d0f9977c..00000000000 --- a/messaging/adapters/persistence.go +++ /dev/null @@ -1,75 +0,0 @@ -package adapters - -import ( - "crypto/ecdsa" - - "github.com/status-im/status-go/messaging/layers/transport" - "github.com/status-im/status-go/messaging/types" - wakupersistence "github.com/status-im/status-go/messaging/waku/persistence" -) - -type KeysPersistence struct { - P types.Persistence -} - -var _ transport.KeysPersistence = (*KeysPersistence)(nil) - -func (kp *KeysPersistence) All() (map[string][]byte, error) { - return kp.P.WakuKeys() -} - -func (kp *KeysPersistence) Add(chatID string, key []byte) error { - return kp.P.AddWakuKey(chatID, key) -} - -type ProcessedMessageIDsCache struct { - P types.Persistence -} - -var _ transport.ProcessedMessageIDsCachePersistence = (*ProcessedMessageIDsCache)(nil) - -func (pm *ProcessedMessageIDsCache) Clear() error { - return pm.P.MessageCacheClear() -} -func (pm *ProcessedMessageIDsCache) Hits(ids []string) (map[string]bool, error) { - return pm.P.MessageCacheHits(ids) -} -func (pm *ProcessedMessageIDsCache) Add(ids []string, timestamp uint64) error { - return pm.P.MessageCacheAdd(ids, timestamp) -} -func (pm *ProcessedMessageIDsCache) Clean(timestamp uint64) error { - return pm.P.MessageCacheClearOlderThan(timestamp) -} - -type WakuProtectedTopics struct { - P types.Persistence -} - -var _ wakupersistence.ProtectedTopics = (*WakuProtectedTopics)(nil) - -func (wpt *WakuProtectedTopics) Insert(pubsubTopic string, privKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey) error { - return wpt.P.WakuInsertProtectedTopic(pubsubTopic, privKey, publicKey) -} - -func (wpt *WakuProtectedTopics) Delete(pubsubTopic string) error { - return wpt.P.WakuDeleteProtectedTopic(pubsubTopic) -} - -func (wpt *WakuProtectedTopics) FetchPrivateKey(topic string) (*ecdsa.PrivateKey, error) { - return wpt.P.WakuFetchPrivateKeyForProtectedTopic(topic) -} - -func (wpt *WakuProtectedTopics) ProtectedTopics() ([]wakupersistence.ProtectedTopic, error) { - pt, err := wpt.P.WakuProtectedTopics() - if err != nil { - return nil, err - } - result := make([]wakupersistence.ProtectedTopic, len(pt)) - for i, p := range pt { - result[i] = wakupersistence.ProtectedTopic{ - PubKey: p.PubKey, - Topic: p.Topic, - } - } - return result, nil -} diff --git a/messaging/api.go b/messaging/api.go index 73aef3e2c7f..76e67992a70 100644 --- a/messaging/api.go +++ b/messaging/api.go @@ -16,6 +16,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/enode" "github.com/status-im/status-go/connection" + cryptotypes "github.com/status-im/status-go/crypto/types" ethtypes "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/messaging/adapters" "github.com/status-im/status-go/messaging/common" @@ -127,7 +128,11 @@ func (a *API) HandleSharedSecrets(secrets []*types.SharedSecret) error { } func (a *API) JoinPublicChat(chatID string) (*types.ChatFilter, error) { - return a.core.sender.JoinPublic(chatID) + f, err := a.core.sender.JoinPublic(chatID) + if err != nil { + return nil, err + } + return adapters.FromTransportFilter(f), nil } func (a *API) JoinPrivateChat(publicKey *ecdsa.PublicKey) (*types.ChatFilter, error) { @@ -195,7 +200,7 @@ func (a *API) GetCurrentKeyForGroup(groupID []byte) (*encryption.HashRatchetKeyC } func (a *API) SaveHashRatchetMessage(groupID []byte, keyID []byte, m *types.ReceivedMessage) error { - return a.core.persistence.SaveHashRatchetMessage(groupID, keyID, m) + return a.core.sender.SaveHashRatchetMessage(groupID, keyID, m) } func (a *API) SendPubsubTopicKey(ctx context.Context, rawMessage *types.RawMessage) ([]byte, error) { @@ -468,6 +473,10 @@ func (a *API) Metrics() string { return a.core.metrics() } +func (a *API) MarkAsConfirmed(dataSyncID []byte, atLeastOne bool) (messageID cryptotypes.HexBytes, err error) { + return a.core.sender.MarkAsConfirmed(dataSyncID, atLeastOne) +} + func ToContentTopic(s string) []byte { return transport.ToTopic(s) } diff --git a/messaging/common/message_segmentation.go b/messaging/common/message_segmentation.go index fa265c42c33..3f55f4bc965 100644 --- a/messaging/common/message_segmentation.go +++ b/messaging/common/message_segmentation.go @@ -165,7 +165,7 @@ func (s *MessageSender) handleSegmentationLayer(message *types.Message) error { zap.Uint32("ParitySegmentIndex", segmentMessage.ParitySegmentIndex), zap.Uint32("ParitySegmentsCount", segmentMessage.ParitySegmentsCount)) - alreadyCompleted, err := s.persistence.IsMessageAlreadyCompleted(segmentMessage.EntireMessageHash) + alreadyCompleted, err := s.segmentationPersistence.IsMessageAlreadyCompleted(segmentMessage.EntireMessageHash) if err != nil { return err } @@ -173,12 +173,12 @@ func (s *MessageSender) handleSegmentationLayer(message *types.Message) error { return ErrMessageSegmentsAlreadyCompleted } - err = s.persistence.SaveMessageSegment(segmentMessage, message.TransportLayer.SigPubKey, time.Now().Unix()) + err = s.segmentationPersistence.SaveMessageSegment(segmentMessage, message.TransportLayer.SigPubKey, time.Now().Unix()) if err != nil { return err } - segments, err := s.persistence.GetMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey) + segments, err := s.segmentationPersistence.GetMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey) if err != nil { return err } @@ -258,7 +258,7 @@ func (s *MessageSender) handleSegmentationLayer(message *types.Message) error { return ErrMessageSegmentsHashMismatch } - err = s.persistence.CompleteMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey, time.Now().Unix()) + err = s.segmentationPersistence.CompleteMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey, time.Now().Unix()) if err != nil { return err } @@ -271,12 +271,12 @@ func (s *MessageSender) handleSegmentationLayer(message *types.Message) error { func (s *MessageSender) CleanupSegments() error { monthAgo := time.Now().AddDate(0, -1, 0).Unix() - err := s.persistence.RemoveMessageSegmentsOlderThan(monthAgo) + err := s.segmentationPersistence.RemoveMessageSegmentsOlderThan(monthAgo) if err != nil { return err } - err = s.persistence.RemoveMessageSegmentsCompletedOlderThan(monthAgo) + err = s.segmentationPersistence.RemoveMessageSegmentsCompletedOlderThan(monthAgo) if err != nil { return err } diff --git a/messaging/common/message_segmentation_test.go b/messaging/common/message_segmentation_test.go index 094b23fe918..741d378fdf8 100644 --- a/messaging/common/message_segmentation_test.go +++ b/messaging/common/message_segmentation_test.go @@ -6,15 +6,16 @@ import ( "testing" "github.com/golang/protobuf/proto" + bindata "github.com/status-im/migrate/v4/source/go_bindata" "github.com/stretchr/testify/suite" "go.uber.org/zap" - "github.com/status-im/status-go/appdatabase" "github.com/status-im/status-go/crypto" + "github.com/status-im/status-go/messaging/layers/segmentation" + segmentationmigrations "github.com/status-im/status-go/messaging/layers/segmentation/migrations" "github.com/status-im/status-go/messaging/types" wakutypes "github.com/status-im/status-go/messaging/waku/types" "github.com/status-im/status-go/protocol/protobuf" - "github.com/status-im/status-go/protocol/sqlite" "github.com/status-im/status-go/t/helpers" ) @@ -41,18 +42,22 @@ func (s *MessageSegmentationSuite) SetupTest() { identity, err := crypto.GenerateKey() s.Require().NoError(err) - database, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) - s.Require().NoError(err) - err = sqlite.Migrate(database) + s.logger, err = zap.NewDevelopment() s.Require().NoError(err) - s.logger, err = zap.NewDevelopment() + db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: segmentationmigrations.AssetNames(), + AssetFunc: segmentationmigrations.Asset, + }, + })) s.Require().NoError(err) s.sender, err = NewMessageSender( identity, - database, - NewStubPersistence(), + nil, + nil, + segmentation.NewSQLitePersistence(db), nil, nil, s.logger, diff --git a/messaging/common/message_sender.go b/messaging/common/message_sender.go index 45db3bfd94f..2263514674f 100644 --- a/messaging/common/message_sender.go +++ b/messaging/common/message_sender.go @@ -3,7 +3,6 @@ package common import ( "context" "crypto/ecdsa" - "database/sql" "math/rand" "sync" "time" @@ -20,12 +19,14 @@ import ( utils "github.com/status-im/status-go/common" "github.com/status-im/status-go/crypto" "github.com/status-im/status-go/crypto/types" + cryptotypes "github.com/status-im/status-go/crypto/types" "github.com/status-im/status-go/messaging/adapters" "github.com/status-im/status-go/messaging/datasync" datasyncpeer "github.com/status-im/status-go/messaging/datasync/peer" messagingevents "github.com/status-im/status-go/messaging/events" "github.com/status-im/status-go/messaging/layers/encryption" "github.com/status-im/status-go/messaging/layers/encryption/sharedsecret" + "github.com/status-im/status-go/messaging/layers/segmentation" "github.com/status-im/status-go/messaging/layers/transport" messagingtypes "github.com/status-im/status-go/messaging/types" wakutypes "github.com/status-im/status-go/messaging/waku/types" @@ -47,14 +48,15 @@ const ( var RekeyCompatibility = true type MessageSender struct { - identity *ecdsa.PrivateKey - datasync *datasync.DataSync - database *sql.DB - transport *transport.Transport - protocol *encryption.Protocol - logger *zap.Logger - persistence messagingtypes.Persistence - publisher *pubsub.Publisher + identity *ecdsa.PrivateKey + datasync *datasync.DataSync + datasyncPersistence datasyncnode.Persistence + transport *transport.Transport + protocol *encryption.Protocol + logger *zap.Logger + persistence messagingtypes.MessageSenderPersistence + segmentationPersistence segmentation.Persistence + publisher *pubsub.Publisher datasyncEnabled bool @@ -69,22 +71,24 @@ type MessageSender struct { func NewMessageSender( identity *ecdsa.PrivateKey, - database *sql.DB, // FIXME - persistence messagingtypes.Persistence, + persistence messagingtypes.MessageSenderPersistence, + datasyncPersistence datasyncnode.Persistence, + segmentationPersistence segmentation.Persistence, transport *transport.Transport, enc *encryption.Protocol, logger *zap.Logger, ) (*MessageSender, error) { p := &MessageSender{ - identity: identity, - database: database, - datasyncEnabled: true, // FIXME - protocol: enc, - persistence: persistence, - publisher: pubsub.NewPublisher(), - transport: transport, - logger: logger, - ephemeralKeys: make(map[string]*ecdsa.PrivateKey), + identity: identity, + datasyncPersistence: datasyncPersistence, + datasyncEnabled: true, // FIXME + protocol: enc, + persistence: persistence, + segmentationPersistence: segmentationPersistence, + publisher: pubsub.NewPublisher(), + transport: transport, + logger: logger, + ephemeralKeys: make(map[string]*ecdsa.PrivateKey), } return p, nil @@ -106,7 +110,7 @@ func (s *MessageSender) StartDatasync(statusChangeEvent chan datasyncnode.PeerSt dataSyncTransport := datasync.NewNodeTransport() dataSyncNode, err := datasyncnode.NewPersistentNode( - s.database, + s.datasyncPersistence, dataSyncTransport, datasyncpeer.PublicKeyToPeerID(s.identity.PublicKey), datasyncnode.BATCH, @@ -1219,12 +1223,12 @@ func (s *MessageSender) notifyOnScheduledMessage(recipient *ecdsa.PublicKey, mes }) } -func (s *MessageSender) JoinPublic(id string) (*messagingtypes.ChatFilter, error) { +func (s *MessageSender) JoinPublic(id string) (*transport.Filter, error) { filter, err := s.transport.JoinPublic(id) if err != nil { return nil, err } - return adapters.FromTransportFilter(filter), nil + return filter, nil } func (s *MessageSender) getRandomEphemeralKey() *ecdsa.PrivateKey { @@ -1280,6 +1284,14 @@ func (s *MessageSender) StopDatasync() { } } +func (s *MessageSender) MarkAsConfirmed(dataSyncID []byte, atLeastOne bool) (messageID cryptotypes.HexBytes, err error) { + return s.persistence.MarkAsConfirmed(dataSyncID, atLeastOne) +} + +func (s *MessageSender) SaveHashRatchetMessage(groupID []byte, keyID []byte, m *messagingtypes.ReceivedMessage) error { + return s.persistence.SaveHashRatchetMessage(groupID, keyID, m) +} + // GetCurrentKeyForGroup returns the latest key timestampID belonging to a key group func (s *MessageSender) GetCurrentKeyForGroup(groupID []byte) (*encryption.HashRatchetKeyCompatibility, error) { return s.protocol.GetCurrentKeyForGroup(groupID) diff --git a/messaging/common/message_sender_test.go b/messaging/common/message_sender_test.go index f47bd75ce14..cde9cf36ec8 100644 --- a/messaging/common/message_sender_test.go +++ b/messaging/common/message_sender_test.go @@ -6,22 +6,25 @@ import ( "github.com/golang/protobuf/proto" "github.com/libp2p/go-libp2p/core/peer" + bindata "github.com/status-im/migrate/v4/source/go_bindata" + mvdsnode "github.com/status-im/mvds/node" + datasyncproto "github.com/status-im/mvds/protobuf" "github.com/stretchr/testify/suite" "go.uber.org/zap" - datasyncproto "github.com/status-im/mvds/protobuf" - - "github.com/status-im/status-go/appdatabase" "github.com/status-im/status-go/crypto" - "github.com/status-im/status-go/messaging/adapters" + messagesendermigrations "github.com/status-im/status-go/messaging/common/migrations" "github.com/status-im/status-go/messaging/datasync" "github.com/status-im/status-go/messaging/layers/encryption" + encryptionmigrations "github.com/status-im/status-go/messaging/layers/encryption/migrations" + "github.com/status-im/status-go/messaging/layers/segmentation" + segmentationmigrations "github.com/status-im/status-go/messaging/layers/segmentation/migrations" "github.com/status-im/status-go/messaging/layers/transport" + transportmigrations "github.com/status-im/status-go/messaging/layers/transport/migrations" messagingtypes "github.com/status-im/status-go/messaging/types" wakuv2 "github.com/status-im/status-go/messaging/waku" wakutypes "github.com/status-im/status-go/messaging/waku/types" "github.com/status-im/status-go/protocol/protobuf" - "github.com/status-im/status-go/protocol/sqlite" v1protocol "github.com/status-im/status-go/protocol/v1" "github.com/status-im/status-go/t/helpers" ) @@ -56,13 +59,28 @@ func (s *MessageSenderSuite) SetupTest() { identity, err := crypto.GenerateKey() s.Require().NoError(err) - database, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) - s.Require().NoError(err) - err = sqlite.Migrate(database) + db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: transportmigrations.AssetNames(), + AssetFunc: transportmigrations.Asset, + }, + { + Names: segmentationmigrations.AssetNames(), + AssetFunc: segmentationmigrations.Asset, + }, + { + Names: encryptionmigrations.AssetNames(), + AssetFunc: encryptionmigrations.Asset, + }, + { + Names: messagesendermigrations.AssetNames(), + AssetFunc: messagesendermigrations.Asset, + }, + })) s.Require().NoError(err) encryptionProtocol := encryption.New( - database, + encryption.NewSQLitePersistence(db), "installation-1", s.logger, ) @@ -80,13 +98,11 @@ func (s *MessageSenderSuite) SetupTest() { s.Require().NoError(err) s.Require().NoError(shh.Start()) - stubPersistence := NewStubPersistence() - transport, err := transport.NewTransport( shh, identity, - &adapters.KeysPersistence{P: stubPersistence}, - &adapters.ProcessedMessageIDsCache{P: stubPersistence}, + transport.NewSQLiteKeysPersistence(db), + transport.NewSQLiteProcessedMessageIDsCachePersistence(db), &transport.EnvelopesMonitorConfig{}, s.logger, ) @@ -94,8 +110,9 @@ func (s *MessageSenderSuite) SetupTest() { s.sender, err = NewMessageSender( identity, - database, - stubPersistence, + NewSQLiteMessageSenderPersistence(db), + mvdsnode.NewSQLitePersistence(db), + segmentation.NewSQLitePersistence(db), transport, encryptionProtocol, s.logger, @@ -196,13 +213,16 @@ func (s *MessageSenderSuite) TestHandleDecodedMessagesDatasyncEncrypted() { s.Require().NoError(err) // Create sender encryption protocol. - senderDatabase, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) - s.Require().NoError(err) - err = sqlite.Migrate(senderDatabase) + senderDatabase, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: encryptionmigrations.AssetNames(), + AssetFunc: encryptionmigrations.Asset, + }, + })) s.Require().NoError(err) senderEncryptionProtocol := encryption.New( - senderDatabase, + encryption.NewSQLitePersistence(senderDatabase), "installation-2", s.logger, ) @@ -246,13 +266,16 @@ func (s *MessageSenderSuite) TestHandleOutOfOrderHashRatchet() { s.Require().NoError(err) // Create sender encryption protocol. - senderDatabase, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) - s.Require().NoError(err) - err = sqlite.Migrate(senderDatabase) + senderDatabase, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: encryptionmigrations.AssetNames(), + AssetFunc: encryptionmigrations.Asset, + }, + })) s.Require().NoError(err) senderEncryptionProtocol := encryption.New( - senderDatabase, + encryption.NewSQLitePersistence(senderDatabase), "installation-2", s.logger, ) diff --git a/protocol/migrations/sqlite/1615374373_add_confirmations.up.sql b/messaging/common/migrations/sqlite/1615374373_add_confirmations.up.sql similarity index 100% rename from protocol/migrations/sqlite/1615374373_add_confirmations.up.sql rename to messaging/common/migrations/sqlite/1615374373_add_confirmations.up.sql diff --git a/protocol/migrations/sqlite/1665484435_add_encrypted_messages.up.sql b/messaging/common/migrations/sqlite/1665484435_add_encrypted_messages.up.sql similarity index 100% rename from protocol/migrations/sqlite/1665484435_add_encrypted_messages.up.sql rename to messaging/common/migrations/sqlite/1665484435_add_encrypted_messages.up.sql diff --git a/protocol/migrations/sqlite/1698414646_add_padding.up.sql b/messaging/common/migrations/sqlite/1698414646_add_padding.up.sql similarity index 100% rename from protocol/migrations/sqlite/1698414646_add_padding.up.sql rename to messaging/common/migrations/sqlite/1698414646_add_padding.up.sql diff --git a/protocol/migrations/sqlite/1710331284_hash_ratchet_encrypted_messages_key_id.up.sql b/messaging/common/migrations/sqlite/1710331284_hash_ratchet_encrypted_messages_key_id.up.sql similarity index 100% rename from protocol/migrations/sqlite/1710331284_hash_ratchet_encrypted_messages_key_id.up.sql rename to messaging/common/migrations/sqlite/1710331284_hash_ratchet_encrypted_messages_key_id.up.sql diff --git a/protocol/migrations/sqlite/1712745141_hash_ratchet_encrypted_messages_key_id.up.sql b/messaging/common/migrations/sqlite/1712745141_hash_ratchet_encrypted_messages_key_id.up.sql similarity index 100% rename from protocol/migrations/sqlite/1712745141_hash_ratchet_encrypted_messages_key_id.up.sql rename to messaging/common/migrations/sqlite/1712745141_hash_ratchet_encrypted_messages_key_id.up.sql diff --git a/protocol/migrations/sqlite/1747333628_drop_unused_waku_message_fields.up.sql b/messaging/common/migrations/sqlite/1747333628_drop_unused_waku_message_fields.up.sql similarity index 100% rename from protocol/migrations/sqlite/1747333628_drop_unused_waku_message_fields.up.sql rename to messaging/common/migrations/sqlite/1747333628_drop_unused_waku_message_fields.up.sql diff --git a/messaging/common/migrations/sqlite/doc.go b/messaging/common/migrations/sqlite/doc.go new file mode 100644 index 00000000000..a26a30c6799 --- /dev/null +++ b/messaging/common/migrations/sqlite/doc.go @@ -0,0 +1,9 @@ +// This file is necessary because "github.com/status-im/migrate/v4" +// can't handle files starting with a prefix. At least that's the case +// for go-bindata. +// If go-bindata is called from the same directory, asset names +// have no prefix and "github.com/status-im/migrate/v4" works as expected. + +package sqlite + +//go:generate go tool go-bindata -modtime=1700000000 -pkg migrations -o ../migrations.go . diff --git a/messaging/common/persistence_stub_test.go b/messaging/common/persistence_stub_test.go deleted file mode 100644 index f939ca3821d..00000000000 --- a/messaging/common/persistence_stub_test.go +++ /dev/null @@ -1,278 +0,0 @@ -package common - -import ( - "crypto/ecdsa" - "encoding/hex" - "sort" - "sync" - - "github.com/jinzhu/copier" - "google.golang.org/protobuf/proto" - - "github.com/status-im/status-go/crypto" - "github.com/status-im/status-go/messaging/types" - "github.com/status-im/status-go/protocol/protobuf" -) - -// StubPersistence is an in-memory implementation of types.Persistence for testing. -type StubPersistence struct { - mu sync.Mutex - - wakuKeys map[string][]byte - - messageCache map[string]uint64 - - hashRatchetMessages map[string]*types.ReceivedMessage // hash -> received message - hashRatchetMessagesByKeyID map[string][]*types.ReceivedMessage // keyID -> received messages - - messageSegments map[string]map[string][]*types.SegmentMessage // hash+pubkey -> segments - completedSegments map[string]struct{} // hash -} - -var _ types.Persistence = (*StubPersistence)(nil) - -func NewStubPersistence() *StubPersistence { - return &StubPersistence{ - wakuKeys: make(map[string][]byte), - messageCache: make(map[string]uint64), - hashRatchetMessages: make(map[string]*types.ReceivedMessage), - hashRatchetMessagesByKeyID: make(map[string][]*types.ReceivedMessage), - messageSegments: make(map[string]map[string][]*types.SegmentMessage), - completedSegments: make(map[string]struct{}), - } -} - -func (s *StubPersistence) WakuKeys() (map[string][]byte, error) { - s.mu.Lock() - defer s.mu.Unlock() - - copy := make(map[string][]byte, len(s.wakuKeys)) - err := copier.Copy(©, s.wakuKeys) - if err != nil { - return nil, err - } - - return copy, nil -} - -func (s *StubPersistence) AddWakuKey(chatID string, key []byte) error { - s.mu.Lock() - defer s.mu.Unlock() - - copy := make([]byte, 0, len(key)) - err := copier.Copy(©, key) - if err != nil { - return err - } - - s.wakuKeys[chatID] = copy - return nil -} - -func (s *StubPersistence) MessageCacheAdd(ids []string, timestamp uint64) error { - s.mu.Lock() - defer s.mu.Unlock() - - for _, id := range ids { - s.messageCache[id] = timestamp - } - return nil -} - -func (s *StubPersistence) MessageCacheClear() error { - s.mu.Lock() - defer s.mu.Unlock() - - s.messageCache = make(map[string]uint64) - return nil -} - -func (s *StubPersistence) MessageCacheClearOlderThan(timestamp uint64) error { - s.mu.Lock() - defer s.mu.Unlock() - - for id, ts := range s.messageCache { - if ts < timestamp { - delete(s.messageCache, id) - } - } - return nil -} - -func (s *StubPersistence) MessageCacheHits(ids []string) (map[string]bool, error) { - s.mu.Lock() - defer s.mu.Unlock() - - hits := make(map[string]bool) - for _, id := range ids { - _, ok := s.messageCache[id] - hits[id] = ok - } - return hits, nil -} - -func (s *StubPersistence) SaveHashRatchetMessage(groupID []byte, keyID []byte, m *types.ReceivedMessage) error { - s.mu.Lock() - defer s.mu.Unlock() - - copy := &types.ReceivedMessage{} - err := copier.Copy(copy, m) - if err != nil { - return err - } - - hash := hex.EncodeToString(copy.Hash) - key := hex.EncodeToString(keyID) - s.hashRatchetMessages[hash] = copy - s.hashRatchetMessagesByKeyID[key] = append(s.hashRatchetMessagesByKeyID[key], copy) - - return nil -} - -func (s *StubPersistence) GetHashRatchetMessages(keyID []byte) ([]*types.ReceivedMessage, error) { - s.mu.Lock() - defer s.mu.Unlock() - - key := hex.EncodeToString(keyID) - msgs := s.hashRatchetMessagesByKeyID[key] - - copy := make([]*types.ReceivedMessage, 0, len(msgs)) - err := copier.Copy(©, msgs) - if err != nil { - return nil, err - } - - return copy, nil -} - -func (s *StubPersistence) DeleteHashRatchetMessages(ids [][]byte) error { - s.mu.Lock() - defer s.mu.Unlock() - - for _, id := range ids { - hash := hex.EncodeToString(id) - msg, ok := s.hashRatchetMessages[hash] - if ok { - // Remove from hashRatchetMessagesByKeyID as well - for key, arr := range s.hashRatchetMessagesByKeyID { - for i, m := range arr { - if m == msg { - s.hashRatchetMessagesByKeyID[key] = append(arr[:i], arr[i+1:]...) - break - } - } - } - delete(s.hashRatchetMessages, hash) - } - } - return nil -} - -func (s *StubPersistence) IsMessageAlreadyCompleted(hash []byte) (bool, error) { - s.mu.Lock() - defer s.mu.Unlock() - - _, exists := s.completedSegments[string(hash)] - return exists, nil -} - -func (s *StubPersistence) SaveMessageSegment(segment *types.SegmentMessage, sigPubKey *ecdsa.PublicKey, timestamp int64) error { - s.mu.Lock() - defer s.mu.Unlock() - - copy := &types.SegmentMessage{ - SegmentMessage: proto.Clone(segment.SegmentMessage).(*protobuf.SegmentMessage), - } - - hash := string(segment.EntireMessageHash) - pubKey := string(crypto.CompressPubkey(sigPubKey)) - if s.messageSegments[hash] == nil { - s.messageSegments[hash] = make(map[string][]*types.SegmentMessage) - } - s.messageSegments[hash][pubKey] = append(s.messageSegments[hash][pubKey], copy) - return nil -} - -func (s *StubPersistence) GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*types.SegmentMessage, error) { - s.mu.Lock() - defer s.mu.Unlock() - - segments := s.messageSegments[string(hash)][string(crypto.CompressPubkey(sigPubKey))] - copy := make([]*types.SegmentMessage, 0, len(segments)) - for _, seg := range segments { - cloned := &types.SegmentMessage{ - SegmentMessage: proto.Clone(seg.SegmentMessage).(*protobuf.SegmentMessage), - } - copy = append(copy, cloned) - } - - // Sort segments: non-parity first, then by index, then by parity index - sort.SliceStable(copy, func(i, j int) bool { - si, sj := copy[i], copy[j] - - // Non-parity segments first - if si.SegmentsCount == 0 && sj.SegmentsCount > 0 { - return false - } - if si.SegmentsCount > 0 && sj.SegmentsCount == 0 { - return true - } - - if si.SegmentsCount > 0 { - return si.Index < sj.Index - } - - return si.ParitySegmentIndex < sj.ParitySegmentIndex - }) - - return copy, nil -} - -func (s *StubPersistence) CompleteMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey, timestamp int64) error { - s.mu.Lock() - defer s.mu.Unlock() - - s.completedSegments[string(hash)] = struct{}{} - - h := string(hash) - pubKey := string(crypto.CompressPubkey(sigPubKey)) - if s.messageSegments[h] != nil { - delete(s.messageSegments[h], pubKey) - if len(s.messageSegments[h]) == 0 { - delete(s.messageSegments, h) - } - } - return nil -} - -func (s *StubPersistence) DeleteHashRatchetMessagesOlderThan(timestamp int64) error { - return nil -} - -func (s *StubPersistence) InsertPendingConfirmation(*types.RawMessageConfirmation) error { - return nil -} - -func (s *StubPersistence) RemoveMessageSegmentsOlderThan(timestamp int64) error { - return nil -} - -func (s *StubPersistence) RemoveMessageSegmentsCompletedOlderThan(timestamp int64) error { - return nil -} - -func (s *StubPersistence) WakuInsertProtectedTopic(pubsubTopic string, privKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey) error { - return nil -} - -func (s *StubPersistence) WakuDeleteProtectedTopic(pubsubTopic string) error { - return nil -} - -func (s *StubPersistence) WakuFetchPrivateKeyForProtectedTopic(topic string) (*ecdsa.PrivateKey, error) { - return nil, nil -} - -func (s *StubPersistence) WakuProtectedTopics() ([]types.ProtectedTopic, error) { - return nil, nil -} diff --git a/messaging/common/sqlite_persistence.go b/messaging/common/sqlite_persistence.go new file mode 100644 index 00000000000..67d67dc5dfd --- /dev/null +++ b/messaging/common/sqlite_persistence.go @@ -0,0 +1,140 @@ +package common + +import ( + "context" + "database/sql" + "strings" + "time" + + cryptotypes "github.com/status-im/status-go/crypto/types" + "github.com/status-im/status-go/messaging/types" +) + +type SQLiteMessageSenderPersistence struct { + db *sql.DB +} + +var _ types.MessageSenderPersistence = (*SQLiteMessageSenderPersistence)(nil) + +func NewSQLiteMessageSenderPersistence(db *sql.DB) *SQLiteMessageSenderPersistence { + return &SQLiteMessageSenderPersistence{db: db} +} + +func (p *SQLiteMessageSenderPersistence) InsertPendingConfirmation(confirmation *types.RawMessageConfirmation) error { + _, err := p.db.Exec(`INSERT INTO raw_message_confirmations + (datasync_id, message_id, public_key) + VALUES + (?,?,?)`, + confirmation.DataSyncID, + confirmation.MessageID, + confirmation.PublicKey, + ) + return err +} + +// MarkAsConfirmed marks all the messages with dataSyncID as confirmed and returns +// the messageIDs that can be considered confirmed. +// If atLeastOne is set it will return messageid if at least once of the messages +// sent has been confirmed +func (p *SQLiteMessageSenderPersistence) MarkAsConfirmed(dataSyncID []byte, atLeastOne bool) (messageID cryptotypes.HexBytes, err error) { + tx, err := p.db.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + return nil, err + } + defer func() { + if err == nil { + err = tx.Commit() + return + } + // don't shadow original error + _ = tx.Rollback() + }() + + confirmedAt := time.Now().Unix() + _, err = tx.Exec(`UPDATE raw_message_confirmations SET confirmed_at = ? WHERE datasync_id = ? AND confirmed_at = 0`, confirmedAt, dataSyncID) + if err != nil { + return + } + + // Select any tuple that has a message_id with a datasync_id = ? and that has just been confirmed + rows, err := tx.Query(`SELECT message_id,confirmed_at FROM raw_message_confirmations WHERE message_id = (SELECT message_id FROM raw_message_confirmations WHERE datasync_id = ? LIMIT 1)`, dataSyncID) + if err != nil { + return + } + defer rows.Close() + + confirmedResult := true + + for rows.Next() { + var confirmedAt int64 + err = rows.Scan(&messageID, &confirmedAt) + if err != nil { + return + } + confirmed := confirmedAt > 0 + + if atLeastOne && confirmed { + // We return, as at least one was confirmed + return + } + + confirmedResult = confirmedResult && confirmed + } + + if !confirmedResult { + messageID = nil + return + } + + return +} + +func (p *SQLiteMessageSenderPersistence) SaveHashRatchetMessage(groupID []byte, keyID []byte, m *types.ReceivedMessage) error { + _, err := p.db.Exec(`INSERT INTO hash_ratchet_encrypted_messages(hash, sig, timestamp, topic, payload, dst, padding, group_id, key_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, m.Hash, m.Sig, m.Timestamp, m.Topic.Bytes(), m.Payload, m.Dst, m.Padding, groupID, keyID) + return err +} + +func (p *SQLiteMessageSenderPersistence) GetHashRatchetMessages(keyID []byte) ([]*types.ReceivedMessage, error) { + var messages []*types.ReceivedMessage + + rows, err := p.db.Query(`SELECT hash, sig, timestamp, topic, payload, dst, padding FROM hash_ratchet_encrypted_messages WHERE key_id = ?`, keyID) + if err != nil { + return nil, err + } + + for rows.Next() { + var topic []byte + message := &types.ReceivedMessage{} + + err := rows.Scan(&message.Hash, &message.Sig, &message.Timestamp, &topic, &message.Payload, &message.Dst, &message.Padding) + if err != nil { + return nil, err + } + + message.Topic = types.BytesToContentTopic(topic) + messages = append(messages, message) + } + + return messages, nil +} + +func (p *SQLiteMessageSenderPersistence) DeleteHashRatchetMessages(ids [][]byte) error { + if len(ids) == 0 { + return nil + } + + idsArgs := make([]interface{}, 0, len(ids)) + for _, id := range ids { + idsArgs = append(idsArgs, id) + } + inVector := strings.Repeat("?, ", len(ids)-1) + "?" + + _, err := p.db.Exec("DELETE FROM hash_ratchet_encrypted_messages WHERE hash IN ("+inVector+")", idsArgs...) // nolint: gosec + + return err +} + +func (p *SQLiteMessageSenderPersistence) DeleteHashRatchetMessagesOlderThan(timestamp int64) error { + _, err := p.db.Exec("DELETE FROM hash_ratchet_encrypted_messages WHERE timestamp < ?", timestamp) + return err +} diff --git a/protocol/messaging_persistence_test.go b/messaging/common/sqlite_persistence_test.go similarity index 64% rename from protocol/messaging_persistence_test.go rename to messaging/common/sqlite_persistence_test.go index f81ba26c622..c07cff6ad45 100644 --- a/protocol/messaging_persistence_test.go +++ b/messaging/common/sqlite_persistence_test.go @@ -1,13 +1,15 @@ -package protocol +package common import ( "testing" + bindata "github.com/status-im/migrate/v4/source/go_bindata" "github.com/stretchr/testify/require" - "github.com/status-im/status-go/crypto" - "github.com/status-im/status-go/crypto/types" - messagingtypes "github.com/status-im/status-go/messaging/types" + cryptotypes "github.com/status-im/status-go/crypto/types" + "github.com/status-im/status-go/messaging/common/migrations" + "github.com/status-im/status-go/messaging/types" + "github.com/status-im/status-go/t/helpers" ) func TestConfirmations(t *testing.T) { @@ -23,32 +25,38 @@ func TestConfirmations(t *testing.T) { publicKey2 := []byte("pk-2") publicKey3 := []byte("pk-3") - db, err := openTestDB() + db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, + }, + })) require.NoError(t, err) - p := NewMessagingPersistence(db) - confirmation1 := &messagingtypes.RawMessageConfirmation{ + p := NewSQLiteMessageSenderPersistence(db) + + confirmation1 := &types.RawMessageConfirmation{ DataSyncID: dataSyncID1, MessageID: messageID1, PublicKey: publicKey1, } // Same datasyncID and same messageID, different pubkey - confirmation2 := &messagingtypes.RawMessageConfirmation{ + confirmation2 := &types.RawMessageConfirmation{ DataSyncID: dataSyncID2, MessageID: messageID1, PublicKey: publicKey2, } // Different datasyncID and same messageID, different pubkey - confirmation3 := &messagingtypes.RawMessageConfirmation{ + confirmation3 := &types.RawMessageConfirmation{ DataSyncID: dataSyncID3, MessageID: messageID1, PublicKey: publicKey3, } // Same dataSyncID, different messageID - confirmation4 := &messagingtypes.RawMessageConfirmation{ + confirmation4 := &types.RawMessageConfirmation{ DataSyncID: dataSyncID4, MessageID: messageID2, PublicKey: publicKey1, @@ -72,7 +80,7 @@ func TestConfirmations(t *testing.T) { // We confirm the third datasync message, messageID1 should be confirmed messageID, err = p.MarkAsConfirmed(dataSyncID3, false) require.NoError(t, err) - require.Equal(t, messageID, types.HexBytes(messageID1)) + require.Equal(t, messageID, cryptotypes.HexBytes(messageID1)) } func TestConfirmationsAtLeastOne(t *testing.T) { @@ -86,25 +94,31 @@ func TestConfirmationsAtLeastOne(t *testing.T) { publicKey2 := []byte("pk-2") publicKey3 := []byte("pk-3") - db, err := openTestDB() + db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, + }, + })) require.NoError(t, err) - p := NewMessagingPersistence(db) - confirmation1 := &messagingtypes.RawMessageConfirmation{ + p := NewSQLiteMessageSenderPersistence(db) + + confirmation1 := &types.RawMessageConfirmation{ DataSyncID: dataSyncID1, MessageID: messageID1, PublicKey: publicKey1, } // Same datasyncID and same messageID, different pubkey - confirmation2 := &messagingtypes.RawMessageConfirmation{ + confirmation2 := &types.RawMessageConfirmation{ DataSyncID: dataSyncID2, MessageID: messageID1, PublicKey: publicKey2, } // Different datasyncID and same messageID, different pubkey - confirmation3 := &messagingtypes.RawMessageConfirmation{ + confirmation3 := &types.RawMessageConfirmation{ DataSyncID: dataSyncID3, MessageID: messageID1, PublicKey: publicKey3, @@ -118,19 +132,25 @@ func TestConfirmationsAtLeastOne(t *testing.T) { messageID, err := p.MarkAsConfirmed(dataSyncID1, true) require.NoError(t, err) require.NotNil(t, messageID) - require.Equal(t, types.HexBytes(messageID1), messageID) + require.Equal(t, cryptotypes.HexBytes(messageID1), messageID) } func TestSaveHashRatchetMessage(t *testing.T) { - db, err := openTestDB() + db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, + }, + })) require.NoError(t, err) - p := NewMessagingPersistence(db) + + p := NewSQLiteMessageSenderPersistence(db) groupID1 := []byte("group-id-1") groupID2 := []byte("group-id-2") keyID := []byte("key-id") - message1 := &messagingtypes.ReceivedMessage{ + message1 := &types.ReceivedMessage{ Hash: []byte{1}, Sig: []byte{2}, Timestamp: 2, @@ -139,10 +159,10 @@ func TestSaveHashRatchetMessage(t *testing.T) { require.NoError(t, p.SaveHashRatchetMessage(groupID1, keyID, message1)) - message2 := &messagingtypes.ReceivedMessage{ + message2 := &types.ReceivedMessage{ Hash: []byte{2}, Sig: []byte{2}, - Topic: messagingtypes.BytesToContentTopic([]byte{5}), + Topic: types.BytesToContentTopic([]byte{5}), Timestamp: 2, Payload: []byte{3}, Dst: []byte{4}, @@ -157,14 +177,20 @@ func TestSaveHashRatchetMessage(t *testing.T) { } func TestDeleteHashRatchetMessage(t *testing.T) { - db, err := openTestDB() + db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, + }, + })) require.NoError(t, err) - p := NewMessagingPersistence(db) + + p := NewSQLiteMessageSenderPersistence(db) groupID := []byte("group-id") keyID := []byte("key-id") - message1 := &messagingtypes.ReceivedMessage{ + message1 := &types.ReceivedMessage{ Hash: []byte{1}, Sig: []byte{2}, Timestamp: 2, @@ -173,10 +199,10 @@ func TestDeleteHashRatchetMessage(t *testing.T) { require.NoError(t, p.SaveHashRatchetMessage(groupID, keyID, message1)) - message2 := &messagingtypes.ReceivedMessage{ + message2 := &types.ReceivedMessage{ Hash: []byte{2}, Sig: []byte{2}, - Topic: messagingtypes.BytesToContentTopic([]byte{5}), + Topic: types.BytesToContentTopic([]byte{5}), Timestamp: 2, Payload: []byte{3}, Dst: []byte{4}, @@ -184,10 +210,10 @@ func TestDeleteHashRatchetMessage(t *testing.T) { require.NoError(t, p.SaveHashRatchetMessage(groupID, keyID, message2)) - message3 := &messagingtypes.ReceivedMessage{ + message3 := &types.ReceivedMessage{ Hash: []byte{3}, Sig: []byte{2}, - Topic: messagingtypes.BytesToContentTopic([]byte{5}), + Topic: types.BytesToContentTopic([]byte{5}), Timestamp: 2, Payload: []byte{3}, Dst: []byte{4}, @@ -207,45 +233,3 @@ func TestDeleteHashRatchetMessage(t *testing.T) { require.NotNil(t, fetchedMessages) require.Len(t, fetchedMessages, 1) } - -func TestWakuProtectedTopicPersistence(t *testing.T) { - db, err := openTestDB() - require.NoError(t, err) - p := NewMessagingPersistence(db) - - // Generate ECDSA keys - privKey, err := crypto.GenerateKey() - require.NoError(t, err) - pubKey := &privKey.PublicKey - - pubsubTopic := "test-topic" - - // Insert protected topic - err = p.WakuInsertProtectedTopic(pubsubTopic, privKey, pubKey) - require.NoError(t, err) - - // Fetch private key for topic - fetchedPrivKey, err := p.WakuFetchPrivateKeyForProtectedTopic(pubsubTopic) - require.NoError(t, err) - require.NotNil(t, fetchedPrivKey) - require.Equal(t, privKey.D.Bytes(), fetchedPrivKey.D.Bytes()) - - // Fetch protected topics - topics, err := p.WakuProtectedTopics() - require.NoError(t, err) - require.Len(t, topics, 1) - require.Equal(t, pubsubTopic, topics[0].Topic) - - // Delete protected topic - err = p.WakuDeleteProtectedTopic(pubsubTopic) - require.NoError(t, err) - - // Ensure topic is deleted - topics, err = p.WakuProtectedTopics() - require.NoError(t, err) - require.Len(t, topics, 0) - - fetchedPrivKey, err = p.WakuFetchPrivateKeyForProtectedTopic(pubsubTopic) - require.NoError(t, err) - require.Nil(t, fetchedPrivKey) -} diff --git a/messaging/core.go b/messaging/core.go index a3df35ce6ba..6c4708512cf 100644 --- a/messaging/core.go +++ b/messaging/core.go @@ -3,7 +3,6 @@ package messaging import ( "context" "crypto/ecdsa" - "database/sql" "sync" "time" @@ -12,9 +11,9 @@ import ( "github.com/pkg/errors" "go.uber.org/zap" - datasyncnode "github.com/status-im/mvds/node" - datasyncproto "github.com/status-im/mvds/protobuf" - "github.com/status-im/mvds/state" + mvdsnode "github.com/status-im/mvds/node" + mvdsproto "github.com/status-im/mvds/protobuf" + mvdsstate "github.com/status-im/mvds/state" gocommon "github.com/status-im/status-go/common" "github.com/status-im/status-go/connection" @@ -43,8 +42,6 @@ var ( type Core struct { config - persistence types.Persistence - identity *ecdsa.PrivateKey waku wakutypes.Waku transport *transport.Transport @@ -55,7 +52,7 @@ type Core struct { quit chan struct{} connectionState connection.State - mvdsStatusChangeEvent chan datasyncnode.PeerStatusChangeEvent + mvdsStatusChangeEvent chan mvdsnode.PeerStatusChangeEvent publisher *pubsub.Publisher @@ -66,9 +63,6 @@ type CoreParams struct { Identity *ecdsa.PrivateKey InstallationID string - DB *sql.DB // FIXME: This should be removed once the database is not needed in the sender - Persistence types.Persistence - NodeKey *ecdsa.PrivateKey WakuConfig params.WakuV2Config ClusterConfig params.ClusterConfig @@ -80,8 +74,8 @@ func newCore(waku wakutypes.Waku, params CoreParams, config *config) (*Core, err transport, err := transport.NewTransport( waku, params.Identity, - &adapters.KeysPersistence{P: params.Persistence}, - &adapters.ProcessedMessageIDsCache{P: params.Persistence}, + config.persistence.TransportStorage().KeysStorage(), + config.persistence.TransportStorage().ProcessedMessageIDsCacheStorage(), config.envelopesMonitorConfig, config.logger, ) @@ -90,15 +84,16 @@ func newCore(waku wakutypes.Waku, params CoreParams, config *config) (*Core, err } encryptor := encryption.New( - params.DB, + config.persistence.EncryptionStorage(), params.InstallationID, config.logger, ) sender, err := common.NewMessageSender( params.Identity, - params.DB, - params.Persistence, + config.persistence.MessageSenderStorage(), + config.persistence.MVDSStorage(), + config.persistence.SegmentationStorage(), transport, encryptor, config.logger, @@ -109,14 +104,13 @@ func newCore(waku wakutypes.Waku, params CoreParams, config *config) (*Core, err return &Core{ config: *config, - persistence: params.Persistence, identity: params.Identity, waku: waku, transport: transport, sender: sender, encryptor: encryptor, quit: make(chan struct{}), - mvdsStatusChangeEvent: make(chan datasyncnode.PeerStatusChangeEvent, 5), + mvdsStatusChangeEvent: make(chan mvdsnode.PeerStatusChangeEvent, 5), publisher: pubsub.NewPublisher(), }, nil } @@ -124,8 +118,12 @@ func newCore(waku wakutypes.Waku, params CoreParams, config *config) (*Core, err func NewCore(params CoreParams, options ...Options) (*Core, error) { config := newConfig(options...) + if config.persistence == nil { + return nil, errors.New("persistence is not configured") + } + waku, err := newWaku(wakuParams{ - persistence: params.Persistence, + persistence: config.persistence.WakuStorage(), identity: params.Identity, nodeKey: params.NodeKey, wakuConfig: params.WakuConfig, @@ -256,7 +254,7 @@ func (c *Core) stop() error { return nil } -func (c *Core) sendDataSync(receiver state.PeerID, payload *datasyncproto.Payload) error { +func (c *Core) sendDataSync(receiver mvdsstate.PeerID, payload *mvdsproto.Payload) error { ctx := context.Background() if !payload.IsValid() { c.logger.Error("payload is invalid") @@ -329,9 +327,9 @@ func (c *Core) connectionChanged(state connection.State) { func (c *Core) resetDatasyncForPeer(publicKey *ecdsa.PublicKey, eventTime uint64) { select { - case c.mvdsStatusChangeEvent <- datasyncnode.PeerStatusChangeEvent{ + case c.mvdsStatusChangeEvent <- mvdsnode.PeerStatusChangeEvent{ PeerID: datasyncpeer.PublicKeyToPeerID(*publicKey), - Status: datasyncnode.OnlineStatus, + Status: mvdsnode.OnlineStatus, EventTime: eventTime, }: default: @@ -369,7 +367,7 @@ func (c *Core) startCleanupLoop(name string, cleanupFunc func() error) { } type wakuParams struct { - persistence types.Persistence + persistence wakuv2.ProtectedTopicsPersistence identity *ecdsa.PrivateKey nodeKey *ecdsa.PrivateKey @@ -425,7 +423,7 @@ func newWaku(params wakuParams) (*wakuv2.Waku, error) { params.nodeKey, cfg, params.logger, - &adapters.WakuProtectedTopics{P: params.persistence}, + params.persistence, params.timeSource, params.onHistoricMessagesRequestFailed, params.onPeerStats, diff --git a/messaging/core_config.go b/messaging/core_config.go index 207034b55ca..1e72b3c6da5 100644 --- a/messaging/core_config.go +++ b/messaging/core_config.go @@ -1,6 +1,8 @@ package messaging import ( + "database/sql" + "go.uber.org/zap" "github.com/libp2p/go-libp2p/core/peer" @@ -16,6 +18,7 @@ type config struct { metricsEnabled bool onHistoricMessagesRequestFailed func([]byte, peer.AddrInfo, error) onPeerStats func(types.ConnStatus) + persistence Persistence } func newConfig(options ...Options) *config { @@ -74,3 +77,18 @@ func WithPeerStatsHandler(onPeerStats func(types.ConnStatus)) Options { c.onPeerStats = onPeerStats } } + +// WithSQLitePersistence sets up the messaging persistence using internal SQLite implementation. +// Migrations must be applied beforehand. See SQLiteMigrate. +func WithSQLitePersistence(db *sql.DB) Options { + return func(c *config) { + c.persistence = newSQLitePersistence(db) + } +} + +// WithPersistence sets up the messaging persistence using the provided implementation. +func WithPersistence(persistence Persistence) Options { + return func(c *config) { + c.persistence = persistence + } +} diff --git a/messaging/layers/encryption/encryption_multi_device_test.go b/messaging/layers/encryption/encryption_multi_device_test.go index e2bf006813d..317b68ed841 100644 --- a/messaging/layers/encryption/encryption_multi_device_test.go +++ b/messaging/layers/encryption/encryption_multi_device_test.go @@ -5,17 +5,15 @@ import ( "fmt" "testing" - "github.com/status-im/status-go/appdatabase" - "github.com/status-im/status-go/protocol/sqlite" - "github.com/status-im/status-go/protocol/tt" - "github.com/status-im/status-go/t/helpers" - + bindata "github.com/status-im/migrate/v4/source/go_bindata" "github.com/stretchr/testify/suite" "go.uber.org/zap" "github.com/status-im/status-go/crypto" - + "github.com/status-im/status-go/messaging/layers/encryption/migrations" "github.com/status-im/status-go/messaging/layers/encryption/multidevice" + "github.com/status-im/status-go/protocol/tt" + "github.com/status-im/status-go/t/helpers" ) const ( @@ -52,17 +50,18 @@ func setupUser(user string, s *EncryptionServiceMultiDeviceSuite, n int) error { for i := 0; i < n; i++ { installationID := fmt.Sprintf("%s%d", user, i+1) - db, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) - if err != nil { - return err - } - err = sqlite.Migrate(db) + db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, + }, + })) if err != nil { return err } protocol := New( - db, + NewSQLitePersistence(db), installationID, s.logger.With(zap.String("user", user)), ) diff --git a/messaging/layers/encryption/encryption_test.go b/messaging/layers/encryption/encryption_test.go index 10baff7ae42..8e2c6bdbebd 100644 --- a/messaging/layers/encryption/encryption_test.go +++ b/messaging/layers/encryption/encryption_test.go @@ -8,15 +8,13 @@ import ( "time" "github.com/golang/protobuf/proto" - - "github.com/status-im/status-go/appdatabase" - "github.com/status-im/status-go/protocol/sqlite" - "github.com/status-im/status-go/t/helpers" - + bindata "github.com/status-im/migrate/v4/source/go_bindata" "github.com/stretchr/testify/suite" "go.uber.org/zap" "github.com/status-im/status-go/crypto" + "github.com/status-im/status-go/messaging/layers/encryption/migrations" + "github.com/status-im/status-go/t/helpers" ) var cleartext = []byte("hello") @@ -38,25 +36,33 @@ type EncryptionServiceTestSuite struct { func (s *EncryptionServiceTestSuite) initDatabases(config encryptorConfig) { var err error - db, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) - s.Require().NoError(err) - err = sqlite.Migrate(db) + db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, + }, + })) s.Require().NoError(err) + config.InstallationID = aliceInstallationID s.alice = NewWithEncryptorConfig( - db, + NewSQLitePersistence(db), aliceInstallationID, config, s.logger.With(zap.String("user", "alice")), ) - db, err = helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) - s.Require().NoError(err) - err = sqlite.Migrate(db) + db, err = helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, + }, + })) s.Require().NoError(err) + config.InstallationID = bobInstallationID s.bob = NewWithEncryptorConfig( - db, + NewSQLitePersistence(db), bobInstallationID, config, s.logger.With(zap.String("user", "bob")), @@ -1014,13 +1020,16 @@ func (s *EncryptionServiceTestSuite) TestHashRatchetCompatibility() { keyID1, err := s.bob.encryptor.GenerateHashRatchetKey(groupID) s.Require().NoError(err) + alicePersistence, ok := s.alice.encryptor.persistence.(*SQLitePersistence) + s.Require().True(ok) + // We replicate the same error condition timestamp32 := keyID1.DeprecatedKeyID() - _, err = s.alice.encryptor.persistence.DB.Exec("INSERT INTO hash_ratchet_encryption(group_id, deprecated_key_id, key, key_id) VALUES(?,?,?,?)", groupID, timestamp32, keyID1.Key, append(groupID, []byte("some-bytes")...)) + _, err = alicePersistence.DB.Exec("INSERT INTO hash_ratchet_encryption(group_id, deprecated_key_id, key, key_id) VALUES(?,?,?,?)", groupID, timestamp32, keyID1.Key, append(groupID, []byte("some-bytes")...)) s.Require().NoError(err) // We migrate - _, err = s.alice.encryptor.persistence.DB.Exec("UPDATE hash_ratchet_encryption SET key_timestamp = deprecated_key_id") + _, err = alicePersistence.DB.Exec("UPDATE hash_ratchet_encryption SET key_timestamp = deprecated_key_id") s.Require().NoError(err) payload1 := []byte("community msg 1") diff --git a/messaging/layers/encryption/encryptor.go b/messaging/layers/encryption/encryptor.go index 87925166fd9..ae4231a452b 100644 --- a/messaging/layers/encryption/encryptor.go +++ b/messaging/layers/encryption/encryptor.go @@ -2,7 +2,6 @@ package encryption import ( "crypto/ecdsa" - "database/sql" "encoding/hex" "errors" "sync" @@ -43,7 +42,7 @@ type confirmationData struct { // encryptor defines a service that is responsible for the encryption aspect of the protocol. type encryptor struct { - persistence *sqlitePersistence + persistence Persistence config encryptorConfig messageIDs map[string]*confirmationData mutex sync.Mutex @@ -84,9 +83,9 @@ func defaultEncryptorConfig(installationID string, logger *zap.Logger) encryptor } // newEncryptor creates a new EncryptionService instance. -func newEncryptor(db *sql.DB, config encryptorConfig) *encryptor { +func newEncryptor(persistence Persistence, config encryptorConfig) *encryptor { return &encryptor{ - persistence: newSQLitePersistence(db), + persistence: persistence, config: config, messageIDs: make(map[string]*confirmationData), logger: config.Logger.With(zap.Namespace("encryptor")), diff --git a/messaging/layers/encryption/hash_ratchet_key_compatibility.go b/messaging/layers/encryption/hash_ratchet_key_compatibility.go new file mode 100644 index 00000000000..6522f68f072 --- /dev/null +++ b/messaging/layers/encryption/hash_ratchet_key_compatibility.go @@ -0,0 +1,68 @@ +package encryption + +import ( + "errors" + + "github.com/status-im/status-go/crypto" +) + +type HashRatchetKeyCompatibility struct { + GroupID []byte + keyID []byte + Timestamp uint64 + Key []byte +} + +func (h *HashRatchetKeyCompatibility) DeprecatedKeyID() uint32 { + return uint32(h.Timestamp) +} + +func (h *HashRatchetKeyCompatibility) IsOldFormat() bool { + return len(h.keyID) == 0 && len(h.Key) == 0 +} + +func (h *HashRatchetKeyCompatibility) GetKeyID() ([]byte, error) { + if len(h.keyID) != 0 { + return h.keyID, nil + } + + if len(h.GroupID) == 0 || h.Timestamp == 0 || len(h.Key) == 0 { + return nil, errors.New("could not create key") + } + + return generateHashRatchetKeyID(h.GroupID, h.Timestamp, h.Key), nil +} + +func (h *HashRatchetKeyCompatibility) GenerateNext() (*HashRatchetKeyCompatibility, error) { + + ratchet := &HashRatchetKeyCompatibility{ + GroupID: h.GroupID, + } + + // Randomly generate a hash ratchet key + hrKey, err := crypto.GenerateKey() + if err != nil { + return nil, err + } + hrKeyBytes := crypto.FromECDSA(hrKey) + + if err != nil { + return nil, err + } + + currentTime := GetCurrentTime() + if h.Timestamp < currentTime { + ratchet.Timestamp = bumpKeyID(currentTime) + } else { + ratchet.Timestamp = h.Timestamp + 1 + } + + ratchet.Key = hrKeyBytes + + _, err = ratchet.GetKeyID() + if err != nil { + return nil, err + } + + return ratchet, nil +} diff --git a/protocol/migrations/sqlite/1698137562_fix_encryption_key_id.up.sql b/messaging/layers/encryption/migrations/sqlite/1698137562_fix_encryption_key_id.up.sql similarity index 100% rename from protocol/migrations/sqlite/1698137562_fix_encryption_key_id.up.sql rename to messaging/layers/encryption/migrations/sqlite/1698137562_fix_encryption_key_id.up.sql diff --git a/messaging/layers/encryption/multidevice/multidevice.go b/messaging/layers/encryption/multidevice/multidevice.go index 1468d96ba8b..5c973b45cab 100644 --- a/messaging/layers/encryption/multidevice/multidevice.go +++ b/messaging/layers/encryption/multidevice/multidevice.go @@ -2,7 +2,6 @@ package multidevice import ( "crypto/ecdsa" - "database/sql" "github.com/status-im/status-go/crypto" ) @@ -42,14 +41,14 @@ type Config struct { } type Multidevice struct { - persistence *sqlitePersistence + persistence Persistence config *Config } -func New(db *sql.DB, config *Config) *Multidevice { +func New(persistence Persistence, config *Config) *Multidevice { return &Multidevice{ config: config, - persistence: newSQLitePersistence(db), + persistence: persistence, } } diff --git a/messaging/layers/encryption/multidevice/persistence.go b/messaging/layers/encryption/multidevice/persistence.go index 37e10f00216..059e64a2324 100644 --- a/messaging/layers/encryption/multidevice/persistence.go +++ b/messaging/layers/encryption/multidevice/persistence.go @@ -1,282 +1,11 @@ package multidevice -import "database/sql" - -type sqlitePersistence struct { - db *sql.DB -} - -func newSQLitePersistence(db *sql.DB) *sqlitePersistence { - return &sqlitePersistence{db: db} -} - -// GetActiveInstallations returns the active installations for a given identity -func (s *sqlitePersistence) GetActiveInstallations(maxInstallations int, identity []byte) ([]*Installation, error) { - stmt, err := s.db.Prepare(`SELECT installation_id, version - FROM installations - WHERE enabled = 1 AND identity = ? - ORDER BY timestamp DESC - LIMIT ?`) - if err != nil { - return nil, err - } - defer stmt.Close() - - var installations []*Installation - rows, err := stmt.Query(identity, maxInstallations) - if err != nil { - return nil, err - } - defer rows.Close() - - for rows.Next() { - var ( - installationID string - version uint32 - ) - err = rows.Scan( - &installationID, - &version, - ) - if err != nil { - return nil, err - } - installations = append(installations, &Installation{ - ID: installationID, - Version: version, - Enabled: true, - }) - - } - - return installations, nil - -} - -// GetInstallations returns all the installations for a given identity -// we both return the installations & the metadata -// metadata is currently stored in a separate table, as in some cases we -// might have metadata for a device, but no other information on the device -func (s *sqlitePersistence) GetInstallations(identity []byte) ([]*Installation, error) { - installationMap := make(map[string]*Installation) - var installations []*Installation - - // We query both tables as sqlite does not support full outer joins - installationsStmt, err := s.db.Prepare(`SELECT installation_id, version, enabled, timestamp FROM installations WHERE identity = ?`) - if err != nil { - return nil, err - } - defer installationsStmt.Close() - - installationRows, err := installationsStmt.Query(identity) - if err != nil { - return nil, err - } - - for installationRows.Next() { - var installation Installation - err = installationRows.Scan( - &installation.ID, - &installation.Version, - &installation.Enabled, - &installation.Timestamp, - ) - if err != nil { - return nil, err - } - // We initialized to empty in this case as we want to - // return metadata as well in this endpoint, but not in others - installation.InstallationMetadata = &InstallationMetadata{} - installationMap[installation.ID] = &installation - } - - metadataStmt, err := s.db.Prepare(`SELECT installation_id, name, device_type, fcm_token FROM installation_metadata WHERE identity = ?`) - if err != nil { - return nil, err - } - defer metadataStmt.Close() - - metadataRows, err := metadataStmt.Query(identity) - if err != nil { - return nil, err - } - - for metadataRows.Next() { - var ( - installationID string - name sql.NullString - deviceType sql.NullString - fcmToken sql.NullString - installation *Installation - ) - err = metadataRows.Scan( - &installationID, - &name, - &deviceType, - &fcmToken, - ) - if err != nil { - return nil, err - } - if _, ok := installationMap[installationID]; ok { - installation = installationMap[installationID] - } else { - installation = &Installation{ID: installationID} - } - installation.InstallationMetadata = &InstallationMetadata{ - Name: name.String, - DeviceType: deviceType.String, - FCMToken: fcmToken.String, - } - installationMap[installationID] = installation - } - - for _, installation := range installationMap { - installations = append(installations, installation) - } - - return installations, nil -} - -// AddInstallations adds the installations for a given identity, maintaining the enabled flag -func (s *sqlitePersistence) AddInstallations(identity []byte, timestamp int64, installations []*Installation, defaultEnabled bool) ([]*Installation, error) { - tx, err := s.db.Begin() - if err != nil { - return nil, err - } - - var insertedInstallations []*Installation - - for _, installation := range installations { - stmt, err := tx.Prepare(`SELECT enabled, version - FROM installations - WHERE identity = ? AND installation_id = ? - LIMIT 1`) - if err != nil { - return nil, err - } - defer stmt.Close() - - var oldEnabled bool - // We don't override version once we saw one - var oldVersion uint32 - latestVersion := installation.Version - - err = stmt.QueryRow(identity, installation.ID).Scan(&oldEnabled, &oldVersion) - if err != nil && err != sql.ErrNoRows { - return nil, err - } - - if err == sql.ErrNoRows { - stmt, err = tx.Prepare(`INSERT INTO installations(identity, installation_id, timestamp, enabled, version) - VALUES (?, ?, ?, ?, ?)`) - if err != nil { - return nil, err - } - defer stmt.Close() - - _, err = stmt.Exec( - identity, - installation.ID, - timestamp, - defaultEnabled, - latestVersion, - ) - if err != nil { - return nil, err - } - insertedInstallations = append(insertedInstallations, installation) - } else { - // We update timestamp if present without changing enabled, only if this is a new bundle - // and we set the version to the latest we ever saw - if oldVersion > installation.Version { - latestVersion = oldVersion - } - - stmt, err = tx.Prepare(`UPDATE installations - SET timestamp = ?, enabled = ?, version = ? - WHERE identity = ? - AND installation_id = ? - AND timestamp < ?`) - if err != nil { - return nil, err - } - defer stmt.Close() - - _, err = stmt.Exec( - timestamp, - oldEnabled, - latestVersion, - identity, - installation.ID, - timestamp, - ) - if err != nil { - return nil, err - } - } - - } - - if err := tx.Commit(); err != nil { - _ = tx.Rollback() - return nil, err - } - - return insertedInstallations, nil - -} - -// EnableInstallation enables the installation -func (s *sqlitePersistence) EnableInstallation(identity []byte, installationID string) error { - stmt, err := s.db.Prepare(`UPDATE installations - SET enabled = 1 - WHERE identity = ? AND installation_id = ?`) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(identity, installationID) - return err - -} - -// DisableInstallation disable the installation -func (s *sqlitePersistence) DisableInstallation(identity []byte, installationID string) error { - stmt, err := s.db.Prepare(`UPDATE installations - SET enabled = 0 - WHERE identity = ? AND installation_id = ?`) - if err != nil { - return err - } - defer stmt.Close() - _, err = stmt.Exec(identity, installationID) - return err -} - -// SetInstallationMetadata sets the metadata for a given installation -func (s *sqlitePersistence) SetInstallationMetadata(identity []byte, installationID string, metadata *InstallationMetadata) error { - stmt, err := s.db.Prepare(`INSERT INTO installation_metadata(name, device_type, fcm_token, identity, installation_id) VALUES(?,?,?,?,?)`) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(metadata.Name, metadata.DeviceType, metadata.FCMToken, identity, installationID) - return err -} - -// SetInstallationName sets the only the name in metadata for a given installation -func (s *sqlitePersistence) SetInstallationName(identity []byte, installationID string, name string) error { - stmt, err := s.db.Prepare(`UPDATE installation_metadata - SET name = ? - WHERE identity = ? AND installation_id = ?`) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(name, identity, installationID) - return err +type Persistence interface { + GetActiveInstallations(maxInstallations int, identity []byte) ([]*Installation, error) + GetInstallations(identity []byte) ([]*Installation, error) + AddInstallations(identity []byte, timestamp int64, installations []*Installation, defaultEnabled bool) ([]*Installation, error) + EnableInstallation(identity []byte, installationID string) error + DisableInstallation(identity []byte, installationID string) error + SetInstallationMetadata(identity []byte, installationID string, metadata *InstallationMetadata) error + SetInstallationName(identity []byte, installationID string, name string) error } diff --git a/messaging/layers/encryption/multidevice/sqlite_persistence.go b/messaging/layers/encryption/multidevice/sqlite_persistence.go new file mode 100644 index 00000000000..9c6252eaa0a --- /dev/null +++ b/messaging/layers/encryption/multidevice/sqlite_persistence.go @@ -0,0 +1,282 @@ +package multidevice + +import "database/sql" + +type SQLitePersistence struct { + db *sql.DB +} + +func NewSQLitePersistence(db *sql.DB) *SQLitePersistence { + return &SQLitePersistence{db: db} +} + +// GetActiveInstallations returns the active installations for a given identity +func (s *SQLitePersistence) GetActiveInstallations(maxInstallations int, identity []byte) ([]*Installation, error) { + stmt, err := s.db.Prepare(`SELECT installation_id, version + FROM installations + WHERE enabled = 1 AND identity = ? + ORDER BY timestamp DESC + LIMIT ?`) + if err != nil { + return nil, err + } + defer stmt.Close() + + var installations []*Installation + rows, err := stmt.Query(identity, maxInstallations) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var ( + installationID string + version uint32 + ) + err = rows.Scan( + &installationID, + &version, + ) + if err != nil { + return nil, err + } + installations = append(installations, &Installation{ + ID: installationID, + Version: version, + Enabled: true, + }) + + } + + return installations, nil + +} + +// GetInstallations returns all the installations for a given identity +// we both return the installations & the metadata +// metadata is currently stored in a separate table, as in some cases we +// might have metadata for a device, but no other information on the device +func (s *SQLitePersistence) GetInstallations(identity []byte) ([]*Installation, error) { + installationMap := make(map[string]*Installation) + var installations []*Installation + + // We query both tables as sqlite does not support full outer joins + installationsStmt, err := s.db.Prepare(`SELECT installation_id, version, enabled, timestamp FROM installations WHERE identity = ?`) + if err != nil { + return nil, err + } + defer installationsStmt.Close() + + installationRows, err := installationsStmt.Query(identity) + if err != nil { + return nil, err + } + + for installationRows.Next() { + var installation Installation + err = installationRows.Scan( + &installation.ID, + &installation.Version, + &installation.Enabled, + &installation.Timestamp, + ) + if err != nil { + return nil, err + } + // We initialized to empty in this case as we want to + // return metadata as well in this endpoint, but not in others + installation.InstallationMetadata = &InstallationMetadata{} + installationMap[installation.ID] = &installation + } + + metadataStmt, err := s.db.Prepare(`SELECT installation_id, name, device_type, fcm_token FROM installation_metadata WHERE identity = ?`) + if err != nil { + return nil, err + } + defer metadataStmt.Close() + + metadataRows, err := metadataStmt.Query(identity) + if err != nil { + return nil, err + } + + for metadataRows.Next() { + var ( + installationID string + name sql.NullString + deviceType sql.NullString + fcmToken sql.NullString + installation *Installation + ) + err = metadataRows.Scan( + &installationID, + &name, + &deviceType, + &fcmToken, + ) + if err != nil { + return nil, err + } + if _, ok := installationMap[installationID]; ok { + installation = installationMap[installationID] + } else { + installation = &Installation{ID: installationID} + } + installation.InstallationMetadata = &InstallationMetadata{ + Name: name.String, + DeviceType: deviceType.String, + FCMToken: fcmToken.String, + } + installationMap[installationID] = installation + } + + for _, installation := range installationMap { + installations = append(installations, installation) + } + + return installations, nil +} + +// AddInstallations adds the installations for a given identity, maintaining the enabled flag +func (s *SQLitePersistence) AddInstallations(identity []byte, timestamp int64, installations []*Installation, defaultEnabled bool) ([]*Installation, error) { + tx, err := s.db.Begin() + if err != nil { + return nil, err + } + + var insertedInstallations []*Installation + + for _, installation := range installations { + stmt, err := tx.Prepare(`SELECT enabled, version + FROM installations + WHERE identity = ? AND installation_id = ? + LIMIT 1`) + if err != nil { + return nil, err + } + defer stmt.Close() + + var oldEnabled bool + // We don't override version once we saw one + var oldVersion uint32 + latestVersion := installation.Version + + err = stmt.QueryRow(identity, installation.ID).Scan(&oldEnabled, &oldVersion) + if err != nil && err != sql.ErrNoRows { + return nil, err + } + + if err == sql.ErrNoRows { + stmt, err = tx.Prepare(`INSERT INTO installations(identity, installation_id, timestamp, enabled, version) + VALUES (?, ?, ?, ?, ?)`) + if err != nil { + return nil, err + } + defer stmt.Close() + + _, err = stmt.Exec( + identity, + installation.ID, + timestamp, + defaultEnabled, + latestVersion, + ) + if err != nil { + return nil, err + } + insertedInstallations = append(insertedInstallations, installation) + } else { + // We update timestamp if present without changing enabled, only if this is a new bundle + // and we set the version to the latest we ever saw + if oldVersion > installation.Version { + latestVersion = oldVersion + } + + stmt, err = tx.Prepare(`UPDATE installations + SET timestamp = ?, enabled = ?, version = ? + WHERE identity = ? + AND installation_id = ? + AND timestamp < ?`) + if err != nil { + return nil, err + } + defer stmt.Close() + + _, err = stmt.Exec( + timestamp, + oldEnabled, + latestVersion, + identity, + installation.ID, + timestamp, + ) + if err != nil { + return nil, err + } + } + + } + + if err := tx.Commit(); err != nil { + _ = tx.Rollback() + return nil, err + } + + return insertedInstallations, nil + +} + +// EnableInstallation enables the installation +func (s *SQLitePersistence) EnableInstallation(identity []byte, installationID string) error { + stmt, err := s.db.Prepare(`UPDATE installations + SET enabled = 1 + WHERE identity = ? AND installation_id = ?`) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec(identity, installationID) + return err + +} + +// DisableInstallation disable the installation +func (s *SQLitePersistence) DisableInstallation(identity []byte, installationID string) error { + stmt, err := s.db.Prepare(`UPDATE installations + SET enabled = 0 + WHERE identity = ? AND installation_id = ?`) + if err != nil { + return err + } + defer stmt.Close() + _, err = stmt.Exec(identity, installationID) + return err +} + +// SetInstallationMetadata sets the metadata for a given installation +func (s *SQLitePersistence) SetInstallationMetadata(identity []byte, installationID string, metadata *InstallationMetadata) error { + stmt, err := s.db.Prepare(`INSERT INTO installation_metadata(name, device_type, fcm_token, identity, installation_id) VALUES(?,?,?,?,?)`) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec(metadata.Name, metadata.DeviceType, metadata.FCMToken, identity, installationID) + return err +} + +// SetInstallationName sets the only the name in metadata for a given installation +func (s *SQLitePersistence) SetInstallationName(identity []byte, installationID string, name string) error { + stmt, err := s.db.Prepare(`UPDATE installation_metadata + SET name = ? + WHERE identity = ? AND installation_id = ?`) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec(name, identity, installationID) + return err +} diff --git a/messaging/layers/encryption/multidevice/persistence_test.go b/messaging/layers/encryption/multidevice/sqlite_persistence_test.go similarity index 94% rename from messaging/layers/encryption/multidevice/persistence_test.go rename to messaging/layers/encryption/multidevice/sqlite_persistence_test.go index 3e272d2532e..f8c07417146 100644 --- a/messaging/layers/encryption/multidevice/persistence_test.go +++ b/messaging/layers/encryption/multidevice/sqlite_persistence_test.go @@ -3,10 +3,10 @@ package multidevice import ( "testing" + bindata "github.com/status-im/migrate/v4/source/go_bindata" "github.com/stretchr/testify/suite" - "github.com/status-im/status-go/appdatabase" - "github.com/status-im/status-go/protocol/sqlite" + "github.com/status-im/status-go/messaging/layers/encryption/migrations" "github.com/status-im/status-go/t/helpers" ) @@ -16,16 +16,19 @@ func TestSQLLitePersistenceTestSuite(t *testing.T) { type SQLLitePersistenceTestSuite struct { suite.Suite - service *sqlitePersistence + service *SQLitePersistence } func (s *SQLLitePersistenceTestSuite) SetupTest() { - db, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) - s.Require().NoError(err) - err = sqlite.Migrate(db) + db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, + }, + })) s.Require().NoError(err) - s.service = newSQLitePersistence(db) + s.service = NewSQLitePersistence(db) } func (s *SQLLitePersistenceTestSuite) TestAddInstallations() { diff --git a/messaging/layers/encryption/persistence.go b/messaging/layers/encryption/persistence.go index 53e9d45fc40..722882051eb 100644 --- a/messaging/layers/encryption/persistence.go +++ b/messaging/layers/encryption/persistence.go @@ -1,17 +1,12 @@ package encryption import ( - "context" "crypto/ecdsa" - "database/sql" - "errors" - "strings" dr "github.com/status-im/doubleratchet" - "github.com/status-im/status-go/crypto" - "github.com/status-im/status-go/messaging/layers/encryption/multidevice" + "github.com/status-im/status-go/messaging/layers/encryption/sharedsecret" ) // RatchetInfo holds the current ratchet state. @@ -26,710 +21,6 @@ type RatchetInfo struct { InstallationID string } -// A safe max number of rows. -const maxNumberOfRows = 100000000 - -type sqlitePersistence struct { - DB *sql.DB - keysStorage dr.KeysStorage - sessionStorage dr.SessionStorage -} - -func newSQLitePersistence(db *sql.DB) *sqlitePersistence { - return &sqlitePersistence{ - DB: db, - keysStorage: newSQLiteKeysStorage(db), - sessionStorage: newSQLiteSessionStorage(db), - } -} - -// GetKeysStorage returns the associated double ratchet KeysStorage object -func (s *sqlitePersistence) KeysStorage() dr.KeysStorage { - return s.keysStorage -} - -// GetSessionStorage returns the associated double ratchet SessionStorage object -func (s *sqlitePersistence) SessionStorage() dr.SessionStorage { - return s.sessionStorage -} - -// AddPrivateBundle adds the specified BundleContainer to the database -func (s *sqlitePersistence) AddPrivateBundle(bc *BundleContainer) error { - tx, err := s.DB.Begin() - if err != nil { - return err - } - - for installationID, signedPreKey := range bc.GetBundle().GetSignedPreKeys() { - var version uint32 - stmt, err := tx.Prepare(`SELECT version - FROM bundles - WHERE installation_id = ? AND identity = ? - ORDER BY version DESC - LIMIT 1`) - if err != nil { - return err - } - - defer stmt.Close() - - err = stmt.QueryRow(installationID, bc.GetBundle().GetIdentity()).Scan(&version) - if err != nil && err != sql.ErrNoRows { - return err - } - - stmt, err = tx.Prepare(`INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, version, timestamp) - VALUES(?, ?, ?, ?, ?, ?)`) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec( - bc.GetBundle().GetIdentity(), - bc.GetPrivateSignedPreKey(), - signedPreKey.GetSignedPreKey(), - installationID, - version+1, - bc.GetBundle().GetTimestamp(), - ) - if err != nil { - _ = tx.Rollback() - return err - } - } - - if err := tx.Commit(); err != nil { - _ = tx.Rollback() - return err - } - - return nil -} - -// AddPublicBundle adds the specified Bundle to the database -func (s *sqlitePersistence) AddPublicBundle(b *Bundle) error { - tx, err := s.DB.Begin() - - if err != nil { - return err - } - - for installationID, signedPreKeyContainer := range b.GetSignedPreKeys() { - signedPreKey := signedPreKeyContainer.GetSignedPreKey() - version := signedPreKeyContainer.GetVersion() - insertStmt, err := tx.Prepare(`INSERT INTO bundles(identity, signed_pre_key, installation_id, version, timestamp) - VALUES( ?, ?, ?, ?, ?)`) - if err != nil { - return err - } - defer insertStmt.Close() - - _, err = insertStmt.Exec( - b.GetIdentity(), - signedPreKey, - installationID, - version, - b.GetTimestamp(), - ) - if err != nil { - _ = tx.Rollback() - return err - } - // Mark old bundles as expired - updateStmt, err := tx.Prepare(`UPDATE bundles - SET expired = 1 - WHERE identity = ? AND installation_id = ? AND version < ?`) - if err != nil { - return err - } - defer updateStmt.Close() - - _, err = updateStmt.Exec( - b.GetIdentity(), - installationID, - version, - ) - if err != nil { - _ = tx.Rollback() - return err - } - - } - - return tx.Commit() -} - -// GetAnyPrivateBundle retrieves any bundle from the database containing a private key -func (s *sqlitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installations []*multidevice.Installation) (*BundleContainer, error) { - - versions := make(map[string]uint32) - /* #nosec */ - statement := `SELECT identity, private_key, signed_pre_key, installation_id, timestamp, version - FROM bundles - WHERE expired = 0 AND identity = ? AND installation_id IN (?` + strings.Repeat(",?", len(installations)-1) + ")" - stmt, err := s.DB.Prepare(statement) - if err != nil { - return nil, err - } - defer stmt.Close() - - var timestamp int64 - var identity []byte - var privateKey []byte - var version uint32 - - args := make([]interface{}, len(installations)+1) - args[0] = myIdentityKey - for i, installation := range installations { - // Lookup up map for versions - versions[installation.ID] = installation.Version - - args[i+1] = installation.ID - } - - rows, err := stmt.Query(args...) - rowCount := 0 - - if err != nil { - return nil, err - } - - defer rows.Close() - - bundle := &Bundle{ - SignedPreKeys: make(map[string]*SignedPreKey), - } - - bundleContainer := &BundleContainer{ - Bundle: bundle, - } - - for rows.Next() { - var signedPreKey []byte - var installationID string - rowCount++ - err = rows.Scan( - &identity, - &privateKey, - &signedPreKey, - &installationID, - ×tamp, - &version, - ) - if err != nil { - return nil, err - } - // If there is a private key, we set the timestamp of the bundle container - if privateKey != nil { - bundle.Timestamp = timestamp - } - - bundle.SignedPreKeys[installationID] = &SignedPreKey{ - SignedPreKey: signedPreKey, - Version: version, - ProtocolVersion: versions[installationID], - } - bundle.Identity = identity - } - - // If no records are found or no record with private key, return nil - if rowCount == 0 || bundleContainer.GetBundle().Timestamp == 0 { - return nil, nil - } - - return bundleContainer, nil - -} - -// GetPrivateKeyBundle retrieves a private key for a bundle from the database -func (s *sqlitePersistence) GetPrivateKeyBundle(bundleID []byte) ([]byte, error) { - stmt, err := s.DB.Prepare(`SELECT private_key - FROM bundles - WHERE signed_pre_key = ? LIMIT 1`) - if err != nil { - return nil, err - } - defer stmt.Close() - - var privateKey []byte - - err = stmt.QueryRow(bundleID).Scan(&privateKey) - switch err { - case sql.ErrNoRows: - return nil, nil - case nil: - return privateKey, nil - default: - return nil, err - } -} - -// MarkBundleExpired expires any private bundle for a given identity -func (s *sqlitePersistence) MarkBundleExpired(identity []byte) error { - stmt, err := s.DB.Prepare(`UPDATE bundles - SET expired = 1 - WHERE identity = ? AND private_key IS NOT NULL`) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(identity) - - return err -} - -// GetPublicBundle retrieves an existing Bundle for the specified public key from the database -func (s *sqlitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, installations []*multidevice.Installation) (*Bundle, error) { - - if len(installations) == 0 { - return nil, nil - } - - versions := make(map[string]uint32) - identity := crypto.CompressPubkey(publicKey) - - /* #nosec */ - statement := `SELECT signed_pre_key,installation_id, version - FROM bundles - WHERE expired = 0 AND identity = ? AND installation_id IN (?` + strings.Repeat(",?", len(installations)-1) + `) - ORDER BY version DESC` - stmt, err := s.DB.Prepare(statement) - if err != nil { - return nil, err - } - defer stmt.Close() - - args := make([]interface{}, len(installations)+1) - args[0] = identity - for i, installation := range installations { - // Lookup up map for versions - versions[installation.ID] = installation.Version - args[i+1] = installation.ID - } - - rows, err := stmt.Query(args...) - rowCount := 0 - - if err != nil { - return nil, err - } - - defer rows.Close() - - bundle := &Bundle{ - Identity: identity, - SignedPreKeys: make(map[string]*SignedPreKey), - } - - for rows.Next() { - var signedPreKey []byte - var installationID string - var version uint32 - rowCount++ - err = rows.Scan( - &signedPreKey, - &installationID, - &version, - ) - if err != nil { - return nil, err - } - - bundle.SignedPreKeys[installationID] = &SignedPreKey{ - SignedPreKey: signedPreKey, - Version: version, - ProtocolVersion: versions[installationID], - } - - } - - if rowCount == 0 { - return nil, nil - } - - return bundle, nil - -} - -// AddRatchetInfo persists the specified ratchet info into the database -func (s *sqlitePersistence) AddRatchetInfo(key []byte, identity []byte, bundleID []byte, ephemeralKey []byte, installationID string) error { - stmt, err := s.DB.Prepare(`INSERT INTO ratchet_info_v2(symmetric_key, identity, bundle_id, ephemeral_key, installation_id) - VALUES(?, ?, ?, ?, ?)`) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec( - key, - identity, - bundleID, - ephemeralKey, - installationID, - ) - - return err -} - -// GetRatchetInfo retrieves the existing RatchetInfo for a specified bundle ID and interlocutor public key from the database -func (s *sqlitePersistence) GetRatchetInfo(bundleID []byte, theirIdentity []byte, installationID string) (*RatchetInfo, error) { - stmt, err := s.DB.Prepare(`SELECT ratchet_info_v2.identity, ratchet_info_v2.symmetric_key, bundles.private_key, bundles.signed_pre_key, ratchet_info_v2.ephemeral_key, ratchet_info_v2.installation_id - FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key - WHERE ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? AND bundle_id = ? - LIMIT 1`) - if err != nil { - return nil, err - } - defer stmt.Close() - - ratchetInfo := &RatchetInfo{ - BundleID: bundleID, - } - - err = stmt.QueryRow(theirIdentity, installationID, bundleID).Scan( - &ratchetInfo.Identity, - &ratchetInfo.Sk, - &ratchetInfo.PrivateKey, - &ratchetInfo.PublicKey, - &ratchetInfo.EphemeralKey, - &ratchetInfo.InstallationID, - ) - switch err { - case sql.ErrNoRows: - return nil, nil - case nil: - ratchetInfo.ID = append(bundleID, []byte(ratchetInfo.InstallationID)...) - return ratchetInfo, nil - default: - return nil, err - } -} - -// GetAnyRatchetInfo retrieves any existing RatchetInfo for a specified interlocutor public key from the database -func (s *sqlitePersistence) GetAnyRatchetInfo(identity []byte, installationID string) (*RatchetInfo, error) { - stmt, err := s.DB.Prepare(`SELECT symmetric_key, bundles.private_key, signed_pre_key, bundle_id, ephemeral_key - FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key - WHERE expired = 0 AND ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? - LIMIT 1`) - if err != nil { - return nil, err - } - defer stmt.Close() - - ratchetInfo := &RatchetInfo{ - Identity: identity, - InstallationID: installationID, - } - - err = stmt.QueryRow(identity, installationID).Scan( - &ratchetInfo.Sk, - &ratchetInfo.PrivateKey, - &ratchetInfo.PublicKey, - &ratchetInfo.BundleID, - &ratchetInfo.EphemeralKey, - ) - switch err { - case sql.ErrNoRows: - return nil, nil - case nil: - ratchetInfo.ID = append(ratchetInfo.BundleID, []byte(installationID)...) - return ratchetInfo, nil - default: - return nil, err - } -} - -// RatchetInfoConfirmed clears the ephemeral key in the RatchetInfo -// associated with the specified bundle ID and interlocutor identity public key -func (s *sqlitePersistence) RatchetInfoConfirmed(bundleID []byte, theirIdentity []byte, installationID string) error { - stmt, err := s.DB.Prepare(`UPDATE ratchet_info_v2 - SET ephemeral_key = NULL - WHERE identity = ? AND bundle_id = ? AND installation_id = ?`) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec( - theirIdentity, - bundleID, - installationID, - ) - - return err -} - -type sqliteKeysStorage struct { - db *sql.DB -} - -func newSQLiteKeysStorage(db *sql.DB) *sqliteKeysStorage { - return &sqliteKeysStorage{ - db: db, - } -} - -// Get retrieves the message key for a specified public key and message number -func (s *sqliteKeysStorage) Get(pubKey dr.Key, msgNum uint) (dr.Key, bool, error) { - var key []byte - stmt, err := s.db.Prepare(`SELECT message_key - FROM keys - WHERE public_key = ? AND msg_num = ? - LIMIT 1`) - - if err != nil { - return key, false, err - } - defer stmt.Close() - - err = stmt.QueryRow(pubKey, msgNum).Scan(&key) - switch err { - case sql.ErrNoRows: - return key, false, nil - case nil: - return key, true, nil - default: - return key, false, err - } -} - -// Put stores a key with the specified public key, message number and message key -func (s *sqliteKeysStorage) Put(sessionID []byte, pubKey dr.Key, msgNum uint, mk dr.Key, seqNum uint) error { - stmt, err := s.db.Prepare(`INSERT INTO keys(session_id, public_key, msg_num, message_key, seq_num) - VALUES(?, ?, ?, ?, ?)`) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec( - sessionID, - pubKey, - msgNum, - mk, - seqNum, - ) - - return err -} - -// DeleteOldMks caps remove any key < seq_num, included -func (s *sqliteKeysStorage) DeleteOldMks(sessionID []byte, deleteUntil uint) error { - stmt, err := s.db.Prepare(`DELETE FROM keys - WHERE session_id = ? AND seq_num <= ?`) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec( - sessionID, - deleteUntil, - ) - - return err -} - -// TruncateMks caps the number of keys to maxKeysPerSession deleting them in FIFO fashion -func (s *sqliteKeysStorage) TruncateMks(sessionID []byte, maxKeysPerSession int) error { - stmt, err := s.db.Prepare(`DELETE FROM keys - WHERE rowid IN (SELECT rowid FROM keys WHERE session_id = ? ORDER BY seq_num DESC LIMIT ? OFFSET ?)`) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec( - sessionID, - // We LIMIT to the max number of rows here, as OFFSET can't be used without a LIMIT - maxNumberOfRows, - maxKeysPerSession, - ) - - return err -} - -// DeleteMk deletes the key with the specified public key and message key -func (s *sqliteKeysStorage) DeleteMk(pubKey dr.Key, msgNum uint) error { - stmt, err := s.db.Prepare(`DELETE FROM keys - WHERE public_key = ? AND msg_num = ?`) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec( - pubKey, - msgNum, - ) - - return err -} - -// Count returns the count of keys with the specified public key -func (s *sqliteKeysStorage) Count(pubKey dr.Key) (uint, error) { - stmt, err := s.db.Prepare(`SELECT COUNT(1) - FROM keys - WHERE public_key = ?`) - if err != nil { - return 0, err - } - defer stmt.Close() - - var count uint - err = stmt.QueryRow(pubKey).Scan(&count) - if err != nil { - return 0, err - } - - return count, nil -} - -// CountAll returns the count of keys with the specified public key -func (s *sqliteKeysStorage) CountAll() (uint, error) { - stmt, err := s.db.Prepare(`SELECT COUNT(1) - FROM keys`) - if err != nil { - return 0, err - } - defer stmt.Close() - - var count uint - err = stmt.QueryRow().Scan(&count) - if err != nil { - return 0, err - } - - return count, nil -} - -// All returns nil -func (s *sqliteKeysStorage) All() (map[string]map[uint]dr.Key, error) { - return nil, nil -} - -type sqliteSessionStorage struct { - db *sql.DB -} - -func newSQLiteSessionStorage(db *sql.DB) *sqliteSessionStorage { - return &sqliteSessionStorage{ - db: db, - } -} - -// Save persists the specified double ratchet state -func (s *sqliteSessionStorage) Save(id []byte, state *dr.State) error { - dhr := state.DHr - dhs := state.DHs - dhsPublic := dhs.PublicKey() - dhsPrivate := dhs.PrivateKey() - pn := state.PN - step := state.Step - keysCount := state.KeysCount - - rootChainKey := state.RootCh.CK - - sendChainKey := state.SendCh.CK - sendChainN := state.SendCh.N - - recvChainKey := state.RecvCh.CK - recvChainN := state.RecvCh.N - - stmt, err := s.db.Prepare(`INSERT INTO sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count) - VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec( - id, - dhr, - dhsPublic, - dhsPrivate, - rootChainKey, - sendChainKey, - sendChainN, - recvChainKey, - recvChainN, - pn, - step, - keysCount, - ) - - return err -} - -// Load retrieves the double ratchet state for a given ID -func (s *sqliteSessionStorage) Load(id []byte) (*dr.State, error) { - stmt, err := s.db.Prepare(`SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count - FROM sessions - WHERE id = ?`) - if err != nil { - return nil, err - } - - defer stmt.Close() - - var ( - dhr []byte - dhsPublic []byte - dhsPrivate []byte - rootChainKey []byte - sendChainKey []byte - sendChainN uint - recvChainKey []byte - recvChainN uint - pn uint - step uint - keysCount uint - ) - - err = stmt.QueryRow(id).Scan( - &dhr, - &dhsPublic, - &dhsPrivate, - &rootChainKey, - &sendChainKey, - &sendChainN, - &recvChainKey, - &recvChainN, - &pn, - &step, - &keysCount, - ) - switch err { - case sql.ErrNoRows: - return nil, nil - case nil: - state := dr.DefaultState(rootChainKey) - - state.PN = uint32(pn) - state.Step = step - state.KeysCount = keysCount - - state.DHs = crypto.DHPair{ - PrvKey: dhsPrivate, - PubKey: dhsPublic, - } - - state.DHr = dhr - - state.SendCh.CK = sendChainKey - state.SendCh.N = uint32(sendChainN) - - state.RecvCh.CK = recvChainKey - state.RecvCh.N = uint32(recvChainN) - - return &state, nil - default: - return nil, err - } -} - type HRCache struct { GroupID []byte KeyID []byte @@ -739,273 +30,26 @@ type HRCache struct { SeqNo uint32 } -// GetHashRatchetCache retrieves a hash ratchet key by group ID and seqNo. -// If cache data with given seqNo (e.g. 0) is not found, -// then the query will return the cache data with the latest seqNo -func (s *sqlitePersistence) GetHashRatchetCache(ratchet *HashRatchetKeyCompatibility, seqNo uint32) (*HRCache, error) { - tx, err := s.DB.BeginTx(context.Background(), &sql.TxOptions{}) - if err != nil { - return nil, err - } - defer func() { - if err == nil { - err = tx.Commit() - return - } - // don't shadow original error - _ = tx.Rollback() - }() - - var key, keyID []byte - if !ratchet.IsOldFormat() { - keyID, err = ratchet.GetKeyID() - if err != nil { - return nil, err - } - } - - err = tx.QueryRow("SELECT key FROM hash_ratchet_encryption WHERE key_id = ? OR (deprecated_key_id = ? AND group_id = ?)", - keyID, - ratchet.DeprecatedKeyID(), - ratchet.GroupID, - ).Scan(&key) - if err == sql.ErrNoRows { - return nil, nil - } - if err != nil { - return nil, err - } - - args := make([]interface{}, 0) - args = append(args, ratchet.GroupID) - args = append(args, keyID) - args = append(args, ratchet.DeprecatedKeyID()) - var query string - if seqNo == 0 { - query = "SELECT seq_no, hash FROM hash_ratchet_encryption_cache WHERE group_id = ? AND (key_id = ? OR key_id = ?) ORDER BY seq_no DESC limit 1" - } else { - query = "SELECT seq_no, hash FROM hash_ratchet_encryption_cache WHERE group_id = ? AND (key_id = ? OR key_id = ?) AND seq_no == ? ORDER BY seq_no DESC limit 1" - args = append(args, seqNo) - } - - var hash []byte - var seqNoPtr *uint32 - - err = tx.QueryRow(query, args...).Scan(&seqNoPtr, &hash) //nolint: ineffassign,staticcheck - switch err { - case sql.ErrNoRows, nil: - var seqNoResult uint32 - if seqNoPtr == nil { - seqNoResult = 0 - } else { - seqNoResult = *seqNoPtr - } - - ratchet.Key = key - keyID, err := ratchet.GetKeyID() - - if err != nil { - return nil, err - } - - res := &HRCache{ - KeyID: keyID, - Key: key, - Hash: hash, - SeqNo: seqNoResult, - } - - return res, nil - default: - return nil, err - } -} - -type HashRatchetKeyCompatibility struct { - GroupID []byte - keyID []byte - Timestamp uint64 - Key []byte -} - -func (h *HashRatchetKeyCompatibility) DeprecatedKeyID() uint32 { - return uint32(h.Timestamp) -} - -func (h *HashRatchetKeyCompatibility) IsOldFormat() bool { - return len(h.keyID) == 0 && len(h.Key) == 0 -} - -func (h *HashRatchetKeyCompatibility) GetKeyID() ([]byte, error) { - if len(h.keyID) != 0 { - return h.keyID, nil - } - - if len(h.GroupID) == 0 || h.Timestamp == 0 || len(h.Key) == 0 { - return nil, errors.New("could not create key") - } - - return generateHashRatchetKeyID(h.GroupID, h.Timestamp, h.Key), nil -} - -func (h *HashRatchetKeyCompatibility) GenerateNext() (*HashRatchetKeyCompatibility, error) { - - ratchet := &HashRatchetKeyCompatibility{ - GroupID: h.GroupID, - } - - // Randomly generate a hash ratchet key - hrKey, err := crypto.GenerateKey() - if err != nil { - return nil, err - } - hrKeyBytes := crypto.FromECDSA(hrKey) - - if err != nil { - return nil, err - } - - currentTime := GetCurrentTime() - if h.Timestamp < currentTime { - ratchet.Timestamp = bumpKeyID(currentTime) - } else { - ratchet.Timestamp = h.Timestamp + 1 - } - - ratchet.Key = hrKeyBytes - - _, err = ratchet.GetKeyID() - if err != nil { - return nil, err - } - - return ratchet, nil -} - -// GetCurrentKeyForGroup retrieves a key ID for given group ID -// (with an assumption that key ids are shared in the group, and -// at any given time there is a single key used) -func (s *sqlitePersistence) GetCurrentKeyForGroup(groupID []byte) (*HashRatchetKeyCompatibility, error) { - ratchet := &HashRatchetKeyCompatibility{ - GroupID: groupID, - } - - stmt, err := s.DB.Prepare(`SELECT key_id, key_timestamp, key - FROM hash_ratchet_encryption - WHERE group_id = ? order by key_timestamp desc limit 1`) - if err != nil { - return nil, err - } - defer stmt.Close() - - var keyID, key []byte - var timestamp uint64 - err = stmt.QueryRow(groupID).Scan(&keyID, ×tamp, &key) - - switch err { - case sql.ErrNoRows: - return ratchet, nil - case nil: - ratchet.Key = key - ratchet.Timestamp = timestamp - _, err = ratchet.GetKeyID() - if err != nil { - return nil, err - } - return ratchet, nil - default: - return nil, err - } -} - -// GetKeysForGroup retrieves all key IDs for given group ID -func (s *sqlitePersistence) GetKeysForGroup(groupID []byte) ([]*HashRatchetKeyCompatibility, error) { - - var ratchets []*HashRatchetKeyCompatibility - stmt, err := s.DB.Prepare(`SELECT key_id, key_timestamp, key - FROM hash_ratchet_encryption - WHERE group_id = ? order by key_timestamp desc`) - if err != nil { - return nil, err - } - defer stmt.Close() - - rows, err := stmt.Query(groupID) - if err != nil { - return nil, err - } - - for rows.Next() { - ratchet := &HashRatchetKeyCompatibility{GroupID: groupID} - err := rows.Scan(&ratchet.keyID, &ratchet.Timestamp, &ratchet.Key) - if err != nil { - return nil, err - } - ratchets = append(ratchets, ratchet) - } - - return ratchets, nil -} - -// SaveHashRatchetKeyHash saves a hash ratchet key cache data -func (s *sqlitePersistence) SaveHashRatchetKeyHash( - ratchet *HashRatchetKeyCompatibility, - hash []byte, - seqNo uint32, -) error { - - stmt, err := s.DB.Prepare(`INSERT INTO hash_ratchet_encryption_cache(group_id, key_id, hash, seq_no) - VALUES(?, ?, ?, ?)`) - if err != nil { - return err - } - defer stmt.Close() - - keyID, err := ratchet.GetKeyID() - if err != nil { - return err - } - - _, err = stmt.Exec(ratchet.GroupID, keyID, hash, seqNo) - - return err -} - -// SaveHashRatchetKey saves a hash ratchet key -func (s *sqlitePersistence) SaveHashRatchetKey(ratchet *HashRatchetKeyCompatibility) error { - stmt, err := s.DB.Prepare(`INSERT INTO hash_ratchet_encryption(group_id, key_id, key_timestamp, deprecated_key_id, key) - VALUES(?,?,?,?,?)`) - if err != nil { - return err - } - defer stmt.Close() - - keyID, err := ratchet.GetKeyID() - if err != nil { - return err - } - - _, err = stmt.Exec(ratchet.GroupID, keyID, ratchet.Timestamp, ratchet.DeprecatedKeyID(), ratchet.Key) - - return err -} - -func (s *sqlitePersistence) GetHashRatchetKeyByID(keyID []byte) (*HashRatchetKeyCompatibility, error) { - ratchet := &HashRatchetKeyCompatibility{ - keyID: keyID, - } - - err := s.DB.QueryRow(` - SELECT group_id, key_timestamp, key - FROM hash_ratchet_encryption - WHERE key_id = ?`, keyID).Scan(&ratchet.GroupID, &ratchet.Timestamp, &ratchet.Key) - - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, err - } - - return ratchet, nil +type Persistence interface { + KeysStorage() dr.KeysStorage + SessionStorage() dr.SessionStorage + SharedSecretStorage() sharedsecret.Persistence + MultideviceStorage() multidevice.Persistence + + AddPrivateBundle(bc *BundleContainer) error + AddPublicBundle(b *Bundle) error + GetAnyPrivateBundle(myIdentityKey []byte, installations []*multidevice.Installation) (*BundleContainer, error) + GetPrivateKeyBundle(bundleID []byte) ([]byte, error) + MarkBundleExpired(identity []byte) error + GetPublicBundle(publicKey *ecdsa.PublicKey, installations []*multidevice.Installation) (*Bundle, error) + AddRatchetInfo(key []byte, identity []byte, bundleID []byte, ephemeralKey []byte, installationID string) error + GetRatchetInfo(bundleID []byte, theirIdentity []byte, installationID string) (*RatchetInfo, error) + GetAnyRatchetInfo(identity []byte, installationID string) (*RatchetInfo, error) + RatchetInfoConfirmed(bundleID []byte, theirIdentity []byte, installationID string) error + GetHashRatchetCache(ratchet *HashRatchetKeyCompatibility, seqNo uint32) (*HRCache, error) + GetCurrentKeyForGroup(groupID []byte) (*HashRatchetKeyCompatibility, error) + GetKeysForGroup(groupID []byte) ([]*HashRatchetKeyCompatibility, error) + SaveHashRatchetKeyHash(ratchet *HashRatchetKeyCompatibility, hash []byte, seqNo uint32) error + SaveHashRatchetKey(ratchet *HashRatchetKeyCompatibility) error + GetHashRatchetKeyByID(keyID []byte) (*HashRatchetKeyCompatibility, error) } diff --git a/messaging/layers/encryption/protocol.go b/messaging/layers/encryption/protocol.go index 6f8ee8b6d2d..2f233a424da 100644 --- a/messaging/layers/encryption/protocol.go +++ b/messaging/layers/encryption/protocol.go @@ -4,7 +4,6 @@ import ( "bytes" "crypto/ecdsa" "crypto/rand" - "database/sql" "fmt" "go.uber.org/zap" @@ -89,12 +88,12 @@ var ( // New creates a new ProtocolService instance func New( - db *sql.DB, + persistence Persistence, installationID string, logger *zap.Logger, ) *Protocol { return NewWithEncryptorConfig( - db, + persistence, installationID, defaultEncryptorConfig(installationID, logger), logger, @@ -104,15 +103,15 @@ func New( // DB and migrations are shared between encryption package // and its sub-packages. func NewWithEncryptorConfig( - db *sql.DB, + persistence Persistence, installationID string, encryptorConfig encryptorConfig, logger *zap.Logger, ) *Protocol { return &Protocol{ - encryptor: newEncryptor(db, encryptorConfig), - secret: sharedsecret.New(db, logger), - multidevice: multidevice.New(db, &multidevice.Config{ + encryptor: newEncryptor(persistence, encryptorConfig), + secret: sharedsecret.New(persistence.SharedSecretStorage(), logger), + multidevice: multidevice.New(persistence.MultideviceStorage(), &multidevice.Config{ MaxInstallations: 3, ProtocolVersion: protocolVersion, InstallationID: installationID, diff --git a/messaging/layers/encryption/protocol_test.go b/messaging/layers/encryption/protocol_test.go index 25763ea2ace..f92e7b2b36f 100644 --- a/messaging/layers/encryption/protocol_test.go +++ b/messaging/layers/encryption/protocol_test.go @@ -3,15 +3,14 @@ package encryption import ( "testing" - "github.com/status-im/status-go/appdatabase" - "github.com/status-im/status-go/protocol/sqlite" - "github.com/status-im/status-go/protocol/tt" - "github.com/status-im/status-go/t/helpers" - + bindata "github.com/status-im/migrate/v4/source/go_bindata" "github.com/stretchr/testify/suite" "go.uber.org/zap" "github.com/status-im/status-go/crypto" + "github.com/status-im/status-go/messaging/layers/encryption/migrations" + "github.com/status-im/status-go/protocol/tt" + "github.com/status-im/status-go/t/helpers" ) func TestProtocolServiceTestSuite(t *testing.T) { @@ -30,22 +29,30 @@ func (s *ProtocolServiceTestSuite) SetupTest() { s.logger = tt.MustCreateTestLogger() - db, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) - s.Require().NoError(err) - err = sqlite.Migrate(db) + db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, + }, + })) s.Require().NoError(err) + s.alice = New( - db, + NewSQLitePersistence(db), "1", s.logger.With(zap.String("user", "alice")), ) - db, err = helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) - s.Require().NoError(err) - err = sqlite.Migrate(db) + db, err = helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, + }, + })) s.Require().NoError(err) + s.bob = New( - db, + NewSQLitePersistence(db), "2", s.logger.With(zap.String("user", "bob")), ) diff --git a/messaging/layers/encryption/sharedsecret/persistence.go b/messaging/layers/encryption/sharedsecret/persistence.go index 1df38853633..6790e1ab7d9 100644 --- a/messaging/layers/encryption/sharedsecret/persistence.go +++ b/messaging/layers/encryption/sharedsecret/persistence.go @@ -1,120 +1,12 @@ package sharedsecret -import ( - "database/sql" - "strings" -) - type Response struct { - secret []byte - installationIDs map[string]bool -} - -type sqlitePersistence struct { - db *sql.DB -} - -func newSQLitePersistence(db *sql.DB) *sqlitePersistence { - return &sqlitePersistence{db: db} -} - -func (s *sqlitePersistence) Add(identity []byte, secret []byte, installationID string) error { - tx, err := s.db.Begin() - if err != nil { - return err - } - - insertSecretStmt, err := tx.Prepare("INSERT INTO secrets(identity, secret) VALUES (?, ?)") - if err != nil { - _ = tx.Rollback() - return err - } - defer insertSecretStmt.Close() - - _, err = insertSecretStmt.Exec(identity, secret) - if err != nil { - _ = tx.Rollback() - return err - } - - insertInstallationIDStmt, err := tx.Prepare("INSERT INTO secret_installation_ids(id, identity_id) VALUES (?, ?)") - if err != nil { - _ = tx.Rollback() - return err - } - defer insertInstallationIDStmt.Close() - - _, err = insertInstallationIDStmt.Exec(installationID, identity) - if err != nil { - _ = tx.Rollback() - return err - } - return tx.Commit() + Secret []byte + InstallationIDs map[string]bool } -func (s *sqlitePersistence) Get(identity []byte, installationIDs []string) (*Response, error) { - response := &Response{ - installationIDs: make(map[string]bool), - } - args := make([]interface{}, len(installationIDs)+1) - args[0] = identity - for i, installationID := range installationIDs { - args[i+1] = installationID - } - - /* #nosec */ - query := `SELECT secret, id - FROM secrets t - JOIN - secret_installation_ids tid - ON t.identity = tid.identity_id - WHERE - t.identity = ? - AND - tid.id IN (?` + strings.Repeat(",?", len(installationIDs)-1) + `)` - - rows, err := s.db.Query(query, args...) - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - for rows.Next() { - var installationID string - var secret []byte - err = rows.Scan(&secret, &installationID) - if err != nil { - return nil, err - } - - response.secret = secret - response.installationIDs[installationID] = true - } - - return response, nil -} - -func (s *sqlitePersistence) All() ([][][]byte, error) { - query := "SELECT identity, secret FROM secrets" - - var secrets [][][]byte - - rows, err := s.db.Query(query) - if err != nil { - return nil, err - } - defer rows.Close() - - for rows.Next() { - var secret []byte - var identity []byte - err = rows.Scan(&identity, &secret) - if err != nil { - return nil, err - } - - secrets = append(secrets, [][]byte{identity, secret}) - } - - return secrets, nil +type Persistence interface { + Add(identity []byte, secret []byte, installationID string) error + Get(identity []byte, installationIDs []string) (*Response, error) + All() ([][][]byte, error) } diff --git a/messaging/layers/encryption/sharedsecret/service_test.go b/messaging/layers/encryption/sharedsecret/service_test.go index abe70d29ad0..25fad4fe334 100644 --- a/messaging/layers/encryption/sharedsecret/service_test.go +++ b/messaging/layers/encryption/sharedsecret/service_test.go @@ -3,15 +3,15 @@ package sharedsecret import ( "testing" - "github.com/status-im/status-go/appdatabase" - "github.com/status-im/status-go/protocol/sqlite" - "github.com/status-im/status-go/protocol/tt" - "github.com/status-im/status-go/t/helpers" - "github.com/stretchr/testify/suite" "go.uber.org/zap" + bindata "github.com/status-im/migrate/v4/source/go_bindata" + "github.com/status-im/status-go/crypto" + "github.com/status-im/status-go/messaging/layers/encryption/migrations" + "github.com/status-im/status-go/protocol/tt" + "github.com/status-im/status-go/t/helpers" ) func TestServiceTestSuite(t *testing.T) { @@ -27,12 +27,15 @@ type SharedSecretTestSuite struct { func (s *SharedSecretTestSuite) SetupTest() { s.logger = tt.MustCreateTestLogger() - db, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) - s.Require().NoError(err) - err = sqlite.Migrate(db) + db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, + }, + })) s.Require().NoError(err) - s.service = New(db, s.logger) + s.service = New(NewSQLitePersistence(db), s.logger) } func (s *SharedSecretTestSuite) TearDownTest() { diff --git a/messaging/layers/encryption/sharedsecret/sharedsecret.go b/messaging/layers/encryption/sharedsecret/sharedsecret.go index 8797bdfb64d..9eb9c5c0943 100644 --- a/messaging/layers/encryption/sharedsecret/sharedsecret.go +++ b/messaging/layers/encryption/sharedsecret/sharedsecret.go @@ -3,7 +3,6 @@ package sharedsecret import ( "bytes" "crypto/ecdsa" - "database/sql" "errors" "go.uber.org/zap" @@ -21,17 +20,17 @@ type Secret struct { // are compressed. // TODO: make compression of public keys a responsibility of sqlitePersistence instead of SharedSecret. type SharedSecret struct { - persistence *sqlitePersistence + persistence Persistence logger *zap.Logger } -func New(db *sql.DB, logger *zap.Logger) *SharedSecret { +func New(persistence Persistence, logger *zap.Logger) *SharedSecret { if logger == nil { logger = zap.NewNop() } return &SharedSecret{ - persistence: newSQLitePersistence(db), + persistence: persistence, logger: logger.With(zap.Namespace("SharedSecret")), } } @@ -73,12 +72,12 @@ func (s *SharedSecret) Agreed(myPrivateKey *ecdsa.PrivateKey, myInstallationID s } for _, installationID := range theirInstallationIDs { - if !response.installationIDs[installationID] { + if !response.InstallationIDs[installationID] { return secret, false, nil } } - if !bytes.Equal(secret.Key, response.secret) { + if !bytes.Equal(secret.Key, response.Secret) { return nil, false, errors.New("computed and saved secrets are different for a given identity") } diff --git a/messaging/layers/encryption/sharedsecret/sqlite_persistence.go b/messaging/layers/encryption/sharedsecret/sqlite_persistence.go new file mode 100644 index 00000000000..6c3277a61ca --- /dev/null +++ b/messaging/layers/encryption/sharedsecret/sqlite_persistence.go @@ -0,0 +1,117 @@ +package sharedsecret + +import ( + "database/sql" + "strings" +) + +type SQLitePersistence struct { + db *sql.DB +} + +var _ Persistence = (*SQLitePersistence)(nil) + +func NewSQLitePersistence(db *sql.DB) *SQLitePersistence { + return &SQLitePersistence{db: db} +} + +func (s *SQLitePersistence) Add(identity []byte, secret []byte, installationID string) error { + tx, err := s.db.Begin() + if err != nil { + return err + } + + insertSecretStmt, err := tx.Prepare("INSERT INTO secrets(identity, secret) VALUES (?, ?)") + if err != nil { + _ = tx.Rollback() + return err + } + defer insertSecretStmt.Close() + + _, err = insertSecretStmt.Exec(identity, secret) + if err != nil { + _ = tx.Rollback() + return err + } + + insertInstallationIDStmt, err := tx.Prepare("INSERT INTO secret_installation_ids(id, identity_id) VALUES (?, ?)") + if err != nil { + _ = tx.Rollback() + return err + } + defer insertInstallationIDStmt.Close() + + _, err = insertInstallationIDStmt.Exec(installationID, identity) + if err != nil { + _ = tx.Rollback() + return err + } + return tx.Commit() +} + +func (s *SQLitePersistence) Get(identity []byte, installationIDs []string) (*Response, error) { + response := &Response{ + InstallationIDs: make(map[string]bool), + } + args := make([]interface{}, len(installationIDs)+1) + args[0] = identity + for i, installationID := range installationIDs { + args[i+1] = installationID + } + + /* #nosec */ + query := `SELECT secret, id + FROM secrets t + JOIN + secret_installation_ids tid + ON t.identity = tid.identity_id + WHERE + t.identity = ? + AND + tid.id IN (?` + strings.Repeat(",?", len(installationIDs)-1) + `)` + + rows, err := s.db.Query(query, args...) + if err != nil && err != sql.ErrNoRows { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var installationID string + var secret []byte + err = rows.Scan(&secret, &installationID) + if err != nil { + return nil, err + } + + response.Secret = secret + response.InstallationIDs[installationID] = true + } + + return response, nil +} + +func (s *SQLitePersistence) All() ([][][]byte, error) { + query := "SELECT identity, secret FROM secrets" + + var secrets [][][]byte + + rows, err := s.db.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var secret []byte + var identity []byte + err = rows.Scan(&identity, &secret) + if err != nil { + return nil, err + } + + secrets = append(secrets, [][]byte{identity, secret}) + } + + return secrets, nil +} diff --git a/messaging/layers/encryption/sqlite_persistence.go b/messaging/layers/encryption/sqlite_persistence.go new file mode 100644 index 00000000000..00cbf114b35 --- /dev/null +++ b/messaging/layers/encryption/sqlite_persistence.go @@ -0,0 +1,943 @@ +package encryption + +import ( + "context" + "crypto/ecdsa" + "database/sql" + "strings" + + dr "github.com/status-im/doubleratchet" + + "github.com/status-im/status-go/crypto" + + "github.com/status-im/status-go/messaging/layers/encryption/multidevice" + "github.com/status-im/status-go/messaging/layers/encryption/sharedsecret" +) + +// A safe max number of rows. +const maxNumberOfRows = 100000000 + +type SQLitePersistence struct { + DB *sql.DB + keysStorage dr.KeysStorage + sessionStorage dr.SessionStorage + sharedSecretStorage sharedsecret.Persistence + multideviceStorage multidevice.Persistence +} + +var _ Persistence = (*SQLitePersistence)(nil) + +func NewSQLitePersistence(db *sql.DB) *SQLitePersistence { + return &SQLitePersistence{ + DB: db, + keysStorage: newSQLiteKeysStorage(db), + sessionStorage: newSQLiteSessionStorage(db), + sharedSecretStorage: sharedsecret.NewSQLitePersistence(db), + multideviceStorage: multidevice.NewSQLitePersistence(db), + } +} + +// GetKeysStorage returns the associated double ratchet KeysStorage object +func (s *SQLitePersistence) KeysStorage() dr.KeysStorage { + return s.keysStorage +} + +// GetSessionStorage returns the associated double ratchet SessionStorage object +func (s *SQLitePersistence) SessionStorage() dr.SessionStorage { + return s.sessionStorage +} + +func (s *SQLitePersistence) SharedSecretStorage() sharedsecret.Persistence { + return s.sharedSecretStorage +} + +func (s *SQLitePersistence) MultideviceStorage() multidevice.Persistence { + return s.multideviceStorage +} + +// AddPrivateBundle adds the specified encryption.BundleContainer to the database +func (s *SQLitePersistence) AddPrivateBundle(bc *BundleContainer) error { + tx, err := s.DB.Begin() + if err != nil { + return err + } + + for installationID, signedPreKey := range bc.GetBundle().GetSignedPreKeys() { + var version uint32 + stmt, err := tx.Prepare(`SELECT version + FROM bundles + WHERE installation_id = ? AND identity = ? + ORDER BY version DESC + LIMIT 1`) + if err != nil { + return err + } + + defer stmt.Close() + + err = stmt.QueryRow(installationID, bc.GetBundle().GetIdentity()).Scan(&version) + if err != nil && err != sql.ErrNoRows { + return err + } + + stmt, err = tx.Prepare(`INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, version, timestamp) + VALUES(?, ?, ?, ?, ?, ?)`) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec( + bc.GetBundle().GetIdentity(), + bc.GetPrivateSignedPreKey(), + signedPreKey.GetSignedPreKey(), + installationID, + version+1, + bc.GetBundle().GetTimestamp(), + ) + if err != nil { + _ = tx.Rollback() + return err + } + } + + if err := tx.Commit(); err != nil { + _ = tx.Rollback() + return err + } + + return nil +} + +// AddPublicBundle adds the specified Bundle to the database +func (s *SQLitePersistence) AddPublicBundle(b *Bundle) error { + tx, err := s.DB.Begin() + + if err != nil { + return err + } + + for installationID, signedPreKeyContainer := range b.GetSignedPreKeys() { + signedPreKey := signedPreKeyContainer.GetSignedPreKey() + version := signedPreKeyContainer.GetVersion() + insertStmt, err := tx.Prepare(`INSERT INTO bundles(identity, signed_pre_key, installation_id, version, timestamp) + VALUES( ?, ?, ?, ?, ?)`) + if err != nil { + return err + } + defer insertStmt.Close() + + _, err = insertStmt.Exec( + b.GetIdentity(), + signedPreKey, + installationID, + version, + b.GetTimestamp(), + ) + if err != nil { + _ = tx.Rollback() + return err + } + // Mark old bundles as expired + updateStmt, err := tx.Prepare(`UPDATE bundles + SET expired = 1 + WHERE identity = ? AND installation_id = ? AND version < ?`) + if err != nil { + return err + } + defer updateStmt.Close() + + _, err = updateStmt.Exec( + b.GetIdentity(), + installationID, + version, + ) + if err != nil { + _ = tx.Rollback() + return err + } + + } + + return tx.Commit() +} + +// GetAnyPrivateBundle retrieves any bundle from the database containing a private key +func (s *SQLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installations []*multidevice.Installation) (*BundleContainer, error) { + + versions := make(map[string]uint32) + /* #nosec */ + statement := `SELECT identity, private_key, signed_pre_key, installation_id, timestamp, version + FROM bundles + WHERE expired = 0 AND identity = ? AND installation_id IN (?` + strings.Repeat(",?", len(installations)-1) + ")" + stmt, err := s.DB.Prepare(statement) + if err != nil { + return nil, err + } + defer stmt.Close() + + var timestamp int64 + var identity []byte + var privateKey []byte + var version uint32 + + args := make([]interface{}, len(installations)+1) + args[0] = myIdentityKey + for i, installation := range installations { + // Lookup up map for versions + versions[installation.ID] = installation.Version + + args[i+1] = installation.ID + } + + rows, err := stmt.Query(args...) + rowCount := 0 + + if err != nil { + return nil, err + } + + defer rows.Close() + + bundle := &Bundle{ + SignedPreKeys: make(map[string]*SignedPreKey), + } + + bundleContainer := &BundleContainer{ + Bundle: bundle, + } + + for rows.Next() { + var signedPreKey []byte + var installationID string + rowCount++ + err = rows.Scan( + &identity, + &privateKey, + &signedPreKey, + &installationID, + ×tamp, + &version, + ) + if err != nil { + return nil, err + } + // If there is a private key, we set the timestamp of the bundle container + if privateKey != nil { + bundle.Timestamp = timestamp + } + + bundle.SignedPreKeys[installationID] = &SignedPreKey{ + SignedPreKey: signedPreKey, + Version: version, + ProtocolVersion: versions[installationID], + } + bundle.Identity = identity + } + + // If no records are found or no record with private key, return nil + if rowCount == 0 || bundleContainer.GetBundle().Timestamp == 0 { + return nil, nil + } + + return bundleContainer, nil + +} + +// GetPrivateKeyBundle retrieves a private key for a bundle from the database +func (s *SQLitePersistence) GetPrivateKeyBundle(bundleID []byte) ([]byte, error) { + stmt, err := s.DB.Prepare(`SELECT private_key + FROM bundles + WHERE signed_pre_key = ? LIMIT 1`) + if err != nil { + return nil, err + } + defer stmt.Close() + + var privateKey []byte + + err = stmt.QueryRow(bundleID).Scan(&privateKey) + switch err { + case sql.ErrNoRows: + return nil, nil + case nil: + return privateKey, nil + default: + return nil, err + } +} + +// MarkBundleExpired expires any private bundle for a given identity +func (s *SQLitePersistence) MarkBundleExpired(identity []byte) error { + stmt, err := s.DB.Prepare(`UPDATE bundles + SET expired = 1 + WHERE identity = ? AND private_key IS NOT NULL`) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec(identity) + + return err +} + +// GetPublicBundle retrieves an existing Bundle for the specified public key from the database +func (s *SQLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, installations []*multidevice.Installation) (*Bundle, error) { + + if len(installations) == 0 { + return nil, nil + } + + versions := make(map[string]uint32) + identity := crypto.CompressPubkey(publicKey) + + /* #nosec */ + statement := `SELECT signed_pre_key,installation_id, version + FROM bundles + WHERE expired = 0 AND identity = ? AND installation_id IN (?` + strings.Repeat(",?", len(installations)-1) + `) + ORDER BY version DESC` + stmt, err := s.DB.Prepare(statement) + if err != nil { + return nil, err + } + defer stmt.Close() + + args := make([]interface{}, len(installations)+1) + args[0] = identity + for i, installation := range installations { + // Lookup up map for versions + versions[installation.ID] = installation.Version + args[i+1] = installation.ID + } + + rows, err := stmt.Query(args...) + rowCount := 0 + + if err != nil { + return nil, err + } + + defer rows.Close() + + bundle := &Bundle{ + Identity: identity, + SignedPreKeys: make(map[string]*SignedPreKey), + } + + for rows.Next() { + var signedPreKey []byte + var installationID string + var version uint32 + rowCount++ + err = rows.Scan( + &signedPreKey, + &installationID, + &version, + ) + if err != nil { + return nil, err + } + + bundle.SignedPreKeys[installationID] = &SignedPreKey{ + SignedPreKey: signedPreKey, + Version: version, + ProtocolVersion: versions[installationID], + } + + } + + if rowCount == 0 { + return nil, nil + } + + return bundle, nil + +} + +// AddRatchetInfo persists the specified ratchet info into the database +func (s *SQLitePersistence) AddRatchetInfo(key []byte, identity []byte, bundleID []byte, ephemeralKey []byte, installationID string) error { + stmt, err := s.DB.Prepare(`INSERT INTO ratchet_info_v2(symmetric_key, identity, bundle_id, ephemeral_key, installation_id) + VALUES(?, ?, ?, ?, ?)`) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec( + key, + identity, + bundleID, + ephemeralKey, + installationID, + ) + + return err +} + +// GetRatchetInfo retrieves the existing encryption.RatchetInfo for a specified bundle ID and interlocutor public key from the database +func (s *SQLitePersistence) GetRatchetInfo(bundleID []byte, theirIdentity []byte, installationID string) (*RatchetInfo, error) { + stmt, err := s.DB.Prepare(`SELECT ratchet_info_v2.identity, ratchet_info_v2.symmetric_key, bundles.private_key, bundles.signed_pre_key, ratchet_info_v2.ephemeral_key, ratchet_info_v2.installation_id + FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key + WHERE ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? AND bundle_id = ? + LIMIT 1`) + if err != nil { + return nil, err + } + defer stmt.Close() + + ratchetInfo := &RatchetInfo{ + BundleID: bundleID, + } + + err = stmt.QueryRow(theirIdentity, installationID, bundleID).Scan( + &ratchetInfo.Identity, + &ratchetInfo.Sk, + &ratchetInfo.PrivateKey, + &ratchetInfo.PublicKey, + &ratchetInfo.EphemeralKey, + &ratchetInfo.InstallationID, + ) + switch err { + case sql.ErrNoRows: + return nil, nil + case nil: + ratchetInfo.ID = append(bundleID, []byte(ratchetInfo.InstallationID)...) + return ratchetInfo, nil + default: + return nil, err + } +} + +// GetAnyRatchetInfo retrieves any existing encryption.RatchetInfo for a specified interlocutor public key from the database +func (s *SQLitePersistence) GetAnyRatchetInfo(identity []byte, installationID string) (*RatchetInfo, error) { + stmt, err := s.DB.Prepare(`SELECT symmetric_key, bundles.private_key, signed_pre_key, bundle_id, ephemeral_key + FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key + WHERE expired = 0 AND ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? + LIMIT 1`) + if err != nil { + return nil, err + } + defer stmt.Close() + + ratchetInfo := &RatchetInfo{ + Identity: identity, + InstallationID: installationID, + } + + err = stmt.QueryRow(identity, installationID).Scan( + &ratchetInfo.Sk, + &ratchetInfo.PrivateKey, + &ratchetInfo.PublicKey, + &ratchetInfo.BundleID, + &ratchetInfo.EphemeralKey, + ) + switch err { + case sql.ErrNoRows: + return nil, nil + case nil: + ratchetInfo.ID = append(ratchetInfo.BundleID, []byte(installationID)...) + return ratchetInfo, nil + default: + return nil, err + } +} + +// RatchetInfoConfirmed clears the ephemeral key in the encryption.RatchetInfo +// associated with the specified bundle ID and interlocutor identity public key +func (s *SQLitePersistence) RatchetInfoConfirmed(bundleID []byte, theirIdentity []byte, installationID string) error { + stmt, err := s.DB.Prepare(`UPDATE ratchet_info_v2 + SET ephemeral_key = NULL + WHERE identity = ? AND bundle_id = ? AND installation_id = ?`) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec( + theirIdentity, + bundleID, + installationID, + ) + + return err +} + +type sqliteKeysStorage struct { + db *sql.DB +} + +func newSQLiteKeysStorage(db *sql.DB) *sqliteKeysStorage { + return &sqliteKeysStorage{ + db: db, + } +} + +// Get retrieves the message key for a specified public key and message number +func (s *sqliteKeysStorage) Get(pubKey dr.Key, msgNum uint) (dr.Key, bool, error) { + var key []byte + stmt, err := s.db.Prepare(`SELECT message_key + FROM keys + WHERE public_key = ? AND msg_num = ? + LIMIT 1`) + + if err != nil { + return key, false, err + } + defer stmt.Close() + + err = stmt.QueryRow(pubKey, msgNum).Scan(&key) + switch err { + case sql.ErrNoRows: + return key, false, nil + case nil: + return key, true, nil + default: + return key, false, err + } +} + +// Put stores a key with the specified public key, message number and message key +func (s *sqliteKeysStorage) Put(sessionID []byte, pubKey dr.Key, msgNum uint, mk dr.Key, seqNum uint) error { + stmt, err := s.db.Prepare(`INSERT INTO keys(session_id, public_key, msg_num, message_key, seq_num) + VALUES(?, ?, ?, ?, ?)`) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec( + sessionID, + pubKey, + msgNum, + mk, + seqNum, + ) + + return err +} + +// DeleteOldMks caps remove any key < seq_num, included +func (s *sqliteKeysStorage) DeleteOldMks(sessionID []byte, deleteUntil uint) error { + stmt, err := s.db.Prepare(`DELETE FROM keys + WHERE session_id = ? AND seq_num <= ?`) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec( + sessionID, + deleteUntil, + ) + + return err +} + +// TruncateMks caps the number of keys to maxKeysPerSession deleting them in FIFO fashion +func (s *sqliteKeysStorage) TruncateMks(sessionID []byte, maxKeysPerSession int) error { + stmt, err := s.db.Prepare(`DELETE FROM keys + WHERE rowid IN (SELECT rowid FROM keys WHERE session_id = ? ORDER BY seq_num DESC LIMIT ? OFFSET ?)`) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec( + sessionID, + // We LIMIT to the max number of rows here, as OFFSET can't be used without a LIMIT + maxNumberOfRows, + maxKeysPerSession, + ) + + return err +} + +// DeleteMk deletes the key with the specified public key and message key +func (s *sqliteKeysStorage) DeleteMk(pubKey dr.Key, msgNum uint) error { + stmt, err := s.db.Prepare(`DELETE FROM keys + WHERE public_key = ? AND msg_num = ?`) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec( + pubKey, + msgNum, + ) + + return err +} + +// Count returns the count of keys with the specified public key +func (s *sqliteKeysStorage) Count(pubKey dr.Key) (uint, error) { + stmt, err := s.db.Prepare(`SELECT COUNT(1) + FROM keys + WHERE public_key = ?`) + if err != nil { + return 0, err + } + defer stmt.Close() + + var count uint + err = stmt.QueryRow(pubKey).Scan(&count) + if err != nil { + return 0, err + } + + return count, nil +} + +// CountAll returns the count of keys with the specified public key +func (s *sqliteKeysStorage) CountAll() (uint, error) { + stmt, err := s.db.Prepare(`SELECT COUNT(1) + FROM keys`) + if err != nil { + return 0, err + } + defer stmt.Close() + + var count uint + err = stmt.QueryRow().Scan(&count) + if err != nil { + return 0, err + } + + return count, nil +} + +// All returns nil +func (s *sqliteKeysStorage) All() (map[string]map[uint]dr.Key, error) { + return nil, nil +} + +type sqliteSessionStorage struct { + db *sql.DB +} + +func newSQLiteSessionStorage(db *sql.DB) *sqliteSessionStorage { + return &sqliteSessionStorage{ + db: db, + } +} + +// Save persists the specified double ratchet state +func (s *sqliteSessionStorage) Save(id []byte, state *dr.State) error { + dhr := state.DHr + dhs := state.DHs + dhsPublic := dhs.PublicKey() + dhsPrivate := dhs.PrivateKey() + pn := state.PN + step := state.Step + keysCount := state.KeysCount + + rootChainKey := state.RootCh.CK + + sendChainKey := state.SendCh.CK + sendChainN := state.SendCh.N + + recvChainKey := state.RecvCh.CK + recvChainN := state.RecvCh.N + + stmt, err := s.db.Prepare(`INSERT INTO sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count) + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec( + id, + dhr, + dhsPublic, + dhsPrivate, + rootChainKey, + sendChainKey, + sendChainN, + recvChainKey, + recvChainN, + pn, + step, + keysCount, + ) + + return err +} + +// Load retrieves the double ratchet state for a given ID +func (s *sqliteSessionStorage) Load(id []byte) (*dr.State, error) { + stmt, err := s.db.Prepare(`SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count + FROM sessions + WHERE id = ?`) + if err != nil { + return nil, err + } + + defer stmt.Close() + + var ( + dhr []byte + dhsPublic []byte + dhsPrivate []byte + rootChainKey []byte + sendChainKey []byte + sendChainN uint + recvChainKey []byte + recvChainN uint + pn uint + step uint + keysCount uint + ) + + err = stmt.QueryRow(id).Scan( + &dhr, + &dhsPublic, + &dhsPrivate, + &rootChainKey, + &sendChainKey, + &sendChainN, + &recvChainKey, + &recvChainN, + &pn, + &step, + &keysCount, + ) + switch err { + case sql.ErrNoRows: + return nil, nil + case nil: + state := dr.DefaultState(rootChainKey) + + state.PN = uint32(pn) + state.Step = step + state.KeysCount = keysCount + + state.DHs = crypto.DHPair{ + PrvKey: dhsPrivate, + PubKey: dhsPublic, + } + + state.DHr = dhr + + state.SendCh.CK = sendChainKey + state.SendCh.N = uint32(sendChainN) + + state.RecvCh.CK = recvChainKey + state.RecvCh.N = uint32(recvChainN) + + return &state, nil + default: + return nil, err + } +} + +// GetHashRatchetCache retrieves a hash ratchet key by group ID and seqNo. +// If cache data with given seqNo (e.g. 0) is not found, +// then the query will return the cache data with the latest seqNo +func (s *SQLitePersistence) GetHashRatchetCache(ratchet *HashRatchetKeyCompatibility, seqNo uint32) (*HRCache, error) { + tx, err := s.DB.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + return nil, err + } + defer func() { + if err == nil { + err = tx.Commit() + return + } + // don't shadow original error + _ = tx.Rollback() + }() + + var key, keyID []byte + if !ratchet.IsOldFormat() { + keyID, err = ratchet.GetKeyID() + if err != nil { + return nil, err + } + } + + err = tx.QueryRow("SELECT key FROM hash_ratchet_encryption WHERE key_id = ? OR (deprecated_key_id = ? AND group_id = ?)", + keyID, + ratchet.DeprecatedKeyID(), + ratchet.GroupID, + ).Scan(&key) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + + args := make([]interface{}, 0) + args = append(args, ratchet.GroupID) + args = append(args, keyID) + args = append(args, ratchet.DeprecatedKeyID()) + var query string + if seqNo == 0 { + query = "SELECT seq_no, hash FROM hash_ratchet_encryption_cache WHERE group_id = ? AND (key_id = ? OR key_id = ?) ORDER BY seq_no DESC limit 1" + } else { + query = "SELECT seq_no, hash FROM hash_ratchet_encryption_cache WHERE group_id = ? AND (key_id = ? OR key_id = ?) AND seq_no == ? ORDER BY seq_no DESC limit 1" + args = append(args, seqNo) + } + + var hash []byte + var seqNoPtr *uint32 + + err = tx.QueryRow(query, args...).Scan(&seqNoPtr, &hash) //nolint: ineffassign,staticcheck + switch err { + case sql.ErrNoRows, nil: + var seqNoResult uint32 + if seqNoPtr == nil { + seqNoResult = 0 + } else { + seqNoResult = *seqNoPtr + } + + ratchet.Key = key + keyID, err := ratchet.GetKeyID() + + if err != nil { + return nil, err + } + + res := &HRCache{ + KeyID: keyID, + Key: key, + Hash: hash, + SeqNo: seqNoResult, + } + + return res, nil + default: + return nil, err + } +} + +// GetCurrentKeyForGroup retrieves a key ID for given group ID +// (with an assumption that key ids are shared in the group, and +// at any given time there is a single key used) +func (s *SQLitePersistence) GetCurrentKeyForGroup(groupID []byte) (*HashRatchetKeyCompatibility, error) { + ratchet := &HashRatchetKeyCompatibility{ + GroupID: groupID, + } + + stmt, err := s.DB.Prepare(`SELECT key_id, key_timestamp, key + FROM hash_ratchet_encryption + WHERE group_id = ? order by key_timestamp desc limit 1`) + if err != nil { + return nil, err + } + defer stmt.Close() + + var keyID, key []byte + var timestamp uint64 + err = stmt.QueryRow(groupID).Scan(&keyID, ×tamp, &key) + + switch err { + case sql.ErrNoRows: + return ratchet, nil + case nil: + ratchet.Key = key + ratchet.Timestamp = timestamp + _, err = ratchet.GetKeyID() + if err != nil { + return nil, err + } + return ratchet, nil + default: + return nil, err + } +} + +// GetKeysForGroup retrieves all key IDs for given group ID +func (s *SQLitePersistence) GetKeysForGroup(groupID []byte) ([]*HashRatchetKeyCompatibility, error) { + + var ratchets []*HashRatchetKeyCompatibility + stmt, err := s.DB.Prepare(`SELECT key_id, key_timestamp, key + FROM hash_ratchet_encryption + WHERE group_id = ? order by key_timestamp desc`) + if err != nil { + return nil, err + } + defer stmt.Close() + + rows, err := stmt.Query(groupID) + if err != nil { + return nil, err + } + + for rows.Next() { + ratchet := &HashRatchetKeyCompatibility{GroupID: groupID} + err := rows.Scan(&ratchet.keyID, &ratchet.Timestamp, &ratchet.Key) + if err != nil { + return nil, err + } + ratchets = append(ratchets, ratchet) + } + + return ratchets, nil +} + +// SaveHashRatchetKeyHash saves a hash ratchet key cache data +func (s *SQLitePersistence) SaveHashRatchetKeyHash( + ratchet *HashRatchetKeyCompatibility, + hash []byte, + seqNo uint32, +) error { + + stmt, err := s.DB.Prepare(`INSERT INTO hash_ratchet_encryption_cache(group_id, key_id, hash, seq_no) + VALUES(?, ?, ?, ?)`) + if err != nil { + return err + } + defer stmt.Close() + + keyID, err := ratchet.GetKeyID() + if err != nil { + return err + } + + _, err = stmt.Exec(ratchet.GroupID, keyID, hash, seqNo) + + return err +} + +// SaveHashRatchetKey saves a hash ratchet key +func (s *SQLitePersistence) SaveHashRatchetKey(ratchet *HashRatchetKeyCompatibility) error { + stmt, err := s.DB.Prepare(`INSERT INTO hash_ratchet_encryption(group_id, key_id, key_timestamp, deprecated_key_id, key) + VALUES(?,?,?,?,?)`) + if err != nil { + return err + } + defer stmt.Close() + + keyID, err := ratchet.GetKeyID() + if err != nil { + return err + } + + _, err = stmt.Exec(ratchet.GroupID, keyID, ratchet.Timestamp, ratchet.DeprecatedKeyID(), ratchet.Key) + + return err +} + +func (s *SQLitePersistence) GetHashRatchetKeyByID(keyID []byte) (*HashRatchetKeyCompatibility, error) { + ratchet := &HashRatchetKeyCompatibility{ + keyID: keyID, + } + + err := s.DB.QueryRow(` + SELECT group_id, key_timestamp, key + FROM hash_ratchet_encryption + WHERE key_id = ?`, keyID).Scan(&ratchet.GroupID, &ratchet.Timestamp, &ratchet.Key) + + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + + return ratchet, nil +} diff --git a/messaging/layers/encryption/persistence_keys_storage_test.go b/messaging/layers/encryption/sqlite_persistence_keys_storage_test.go similarity index 94% rename from messaging/layers/encryption/persistence_keys_storage_test.go rename to messaging/layers/encryption/sqlite_persistence_keys_storage_test.go index 7f6d681060b..179e191573c 100644 --- a/messaging/layers/encryption/persistence_keys_storage_test.go +++ b/messaging/layers/encryption/sqlite_persistence_keys_storage_test.go @@ -4,10 +4,10 @@ import ( "testing" dr "github.com/status-im/doubleratchet" + bindata "github.com/status-im/migrate/v4/source/go_bindata" "github.com/stretchr/testify/suite" - "github.com/status-im/status-go/appdatabase" - "github.com/status-im/status-go/protocol/sqlite" + "github.com/status-im/status-go/messaging/layers/encryption/migrations" "github.com/status-im/status-go/t/helpers" ) @@ -31,13 +31,15 @@ type SQLLitePersistenceKeysStorageTestSuite struct { } func (s *SQLLitePersistenceKeysStorageTestSuite) SetupTest() { - db, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) - s.Require().NoError(err) - err = sqlite.Migrate(db) + db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, + }, + })) s.Require().NoError(err) - p := newSQLitePersistence(db) - s.service = p.KeysStorage() + s.service = newSQLiteKeysStorage(db) } func (s *SQLLitePersistenceKeysStorageTestSuite) TestKeysStorageSqlLiteGetMissing() { diff --git a/messaging/layers/encryption/persistence_test.go b/messaging/layers/encryption/sqlite_persistence_test.go similarity index 97% rename from messaging/layers/encryption/persistence_test.go rename to messaging/layers/encryption/sqlite_persistence_test.go index 108f15c2d25..ae267c2664d 100644 --- a/messaging/layers/encryption/persistence_test.go +++ b/messaging/layers/encryption/sqlite_persistence_test.go @@ -6,12 +6,12 @@ import ( "github.com/stretchr/testify/suite" - "github.com/status-im/status-go/appdatabase" - "github.com/status-im/status-go/crypto" - "github.com/status-im/status-go/t/helpers" + bindata "github.com/status-im/migrate/v4/source/go_bindata" + "github.com/status-im/status-go/crypto" + "github.com/status-im/status-go/messaging/layers/encryption/migrations" "github.com/status-im/status-go/messaging/layers/encryption/multidevice" - "github.com/status-im/status-go/protocol/sqlite" + "github.com/status-im/status-go/t/helpers" ) func TestSQLLitePersistenceTestSuite(t *testing.T) { @@ -20,16 +20,18 @@ func TestSQLLitePersistenceTestSuite(t *testing.T) { type SQLLitePersistenceTestSuite struct { suite.Suite - service *sqlitePersistence + service Persistence } func (s *SQLLitePersistenceTestSuite) SetupTest() { - db, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) - s.Require().NoError(err) - err = sqlite.Migrate(db) - s.Require().NoError(err) - - s.service = newSQLitePersistence(db) + db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, + }, + })) + s.Require().NoError(err) + s.service = NewSQLitePersistence(db) } func (s *SQLLitePersistenceTestSuite) TestPrivateBundle() { diff --git a/protocol/migrations/sqlite/1699554099_message_segments.up.sql b/messaging/layers/segmentation/migrations/sqlite/1699554099_message_segments.up.sql similarity index 100% rename from protocol/migrations/sqlite/1699554099_message_segments.up.sql rename to messaging/layers/segmentation/migrations/sqlite/1699554099_message_segments.up.sql diff --git a/protocol/migrations/sqlite/1700044186_message_segments_timestamp.up.sql b/messaging/layers/segmentation/migrations/sqlite/1700044186_message_segments_timestamp.up.sql similarity index 100% rename from protocol/migrations/sqlite/1700044186_message_segments_timestamp.up.sql rename to messaging/layers/segmentation/migrations/sqlite/1700044186_message_segments_timestamp.up.sql diff --git a/messaging/layers/segmentation/migrations/sqlite/1712905223_add_parity_to_message_segments.up.sql b/messaging/layers/segmentation/migrations/sqlite/1712905223_add_parity_to_message_segments.up.sql new file mode 100644 index 00000000000..87a3b69f78c --- /dev/null +++ b/messaging/layers/segmentation/migrations/sqlite/1712905223_add_parity_to_message_segments.up.sql @@ -0,0 +1,45 @@ +ALTER TABLE message_segments RENAME TO old_message_segments; + +CREATE TABLE message_segments ( + hash BLOB NOT NULL, + segment_index INTEGER NOT NULL, + segments_count INTEGER NOT NULL, + payload BLOB NOT NULL, + sig_pub_key BLOB NOT NULL, + timestamp INTEGER NOT NULL, + parity_segment_index INTEGER NOT NULL, + parity_segments_count INTEGER NOT NULL, + PRIMARY KEY ( + hash, + sig_pub_key, + segment_index, + segments_count, + parity_segment_index, + parity_segments_count + ) ON CONFLICT + REPLACE +); + +INSERT INTO + message_segments ( + hash, + segment_index, + segments_count, + payload, + sig_pub_key, + timestamp, + parity_segment_index, + parity_segments_count + ) +SELECT + hash, + segment_index, + segments_count, + payload, + sig_pub_key, + timestamp, + 0, + 0 +FROM old_message_segments; + +DROP TABLE old_message_segments; diff --git a/messaging/layers/segmentation/migrations/sqlite/doc.go b/messaging/layers/segmentation/migrations/sqlite/doc.go new file mode 100644 index 00000000000..a26a30c6799 --- /dev/null +++ b/messaging/layers/segmentation/migrations/sqlite/doc.go @@ -0,0 +1,9 @@ +// This file is necessary because "github.com/status-im/migrate/v4" +// can't handle files starting with a prefix. At least that's the case +// for go-bindata. +// If go-bindata is called from the same directory, asset names +// have no prefix and "github.com/status-im/migrate/v4" works as expected. + +package sqlite + +//go:generate go tool go-bindata -modtime=1700000000 -pkg migrations -o ../migrations.go . diff --git a/messaging/layers/segmentation/persistence.go b/messaging/layers/segmentation/persistence.go new file mode 100644 index 00000000000..2b6085eea24 --- /dev/null +++ b/messaging/layers/segmentation/persistence.go @@ -0,0 +1,16 @@ +package segmentation + +import ( + "crypto/ecdsa" + + "github.com/status-im/status-go/messaging/types" +) + +type Persistence interface { + IsMessageAlreadyCompleted(hash []byte) (bool, error) + SaveMessageSegment(segment *types.SegmentMessage, sigPubKey *ecdsa.PublicKey, timestamp int64) error + GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*types.SegmentMessage, error) + CompleteMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey, timestamp int64) error + RemoveMessageSegmentsOlderThan(timestamp int64) error + RemoveMessageSegmentsCompletedOlderThan(timestamp int64) error +} diff --git a/messaging/layers/segmentation/sqlite_persistence.go b/messaging/layers/segmentation/sqlite_persistence.go new file mode 100644 index 00000000000..28263481d8e --- /dev/null +++ b/messaging/layers/segmentation/sqlite_persistence.go @@ -0,0 +1,118 @@ +package segmentation + +import ( + "context" + "crypto/ecdsa" + "database/sql" + + "github.com/ethereum/go-ethereum/crypto" + + "github.com/status-im/status-go/messaging/types" + "github.com/status-im/status-go/protocol/protobuf" +) + +type SQLitePersistence struct { + db *sql.DB +} + +func NewSQLitePersistence(db *sql.DB) *SQLitePersistence { + return &SQLitePersistence{db: db} +} + +func (s *SQLitePersistence) IsMessageAlreadyCompleted(hash []byte) (bool, error) { + var alreadyCompleted int + err := s.db.QueryRow("SELECT COUNT(*) FROM message_segments_completed WHERE hash = ?", hash).Scan(&alreadyCompleted) + if err != nil { + return false, err + } + return alreadyCompleted > 0, nil +} + +func (s *SQLitePersistence) SaveMessageSegment(segment *types.SegmentMessage, sigPubKey *ecdsa.PublicKey, timestamp int64) error { + sigPubKeyBlob := crypto.CompressPubkey(sigPubKey) + + _, err := s.db.Exec("INSERT INTO message_segments (hash, segment_index, segments_count, parity_segment_index, parity_segments_count, sig_pub_key, payload, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + segment.EntireMessageHash, segment.Index, segment.SegmentsCount, segment.ParitySegmentIndex, segment.ParitySegmentsCount, sigPubKeyBlob, segment.Payload, timestamp) + + return err +} + +// Get ordered message segments for given hash +func (s *SQLitePersistence) GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*types.SegmentMessage, error) { + sigPubKeyBlob := crypto.CompressPubkey(sigPubKey) + + rows, err := s.db.Query(` + SELECT + hash, segment_index, segments_count, parity_segment_index, parity_segments_count, payload + FROM + message_segments + WHERE + hash = ? AND sig_pub_key = ? + ORDER BY + (segments_count = 0) ASC, -- Prioritize segments_count > 0 + segment_index ASC, + parity_segment_index ASC`, + hash, sigPubKeyBlob) + if err != nil { + return nil, err + } + defer rows.Close() + + var segments []*types.SegmentMessage + for rows.Next() { + segment := &types.SegmentMessage{ + SegmentMessage: &protobuf.SegmentMessage{}, + } + err := rows.Scan(&segment.EntireMessageHash, &segment.Index, &segment.SegmentsCount, &segment.ParitySegmentIndex, &segment.ParitySegmentsCount, &segment.Payload) + if err != nil { + return nil, err + } + segments = append(segments, segment) + } + err = rows.Err() + if err != nil { + return nil, err + } + + return segments, nil +} + +func (s *SQLitePersistence) RemoveMessageSegmentsOlderThan(timestamp int64) error { + _, err := s.db.Exec("DELETE FROM message_segments WHERE timestamp < ?", timestamp) + return err +} + +func (s *SQLitePersistence) CompleteMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey, timestamp int64) error { + tx, err := s.db.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + return err + } + + defer func() { + if err == nil { + err = tx.Commit() + return + } + // don't shadow original error + _ = tx.Rollback() + }() + + sigPubKeyBlob := crypto.CompressPubkey(sigPubKey) + + _, err = tx.Exec("DELETE FROM message_segments WHERE hash = ? AND sig_pub_key = ?", hash, sigPubKeyBlob) + if err != nil { + return err + } + + _, err = tx.Exec("INSERT INTO message_segments_completed (hash, sig_pub_key, timestamp) VALUES (?,?,?)", hash, sigPubKeyBlob, timestamp) + if err != nil { + return err + } + + return err +} + +func (s *SQLitePersistence) RemoveMessageSegmentsCompletedOlderThan(timestamp int64) error { + _, err := s.db.Exec("DELETE FROM message_segments_completed WHERE timestamp < ?", timestamp) + return err +} diff --git a/messaging/layers/transport/filters_manager.go b/messaging/layers/transport/filters_manager.go index 44fe617e6bc..0e90bbc2414 100644 --- a/messaging/layers/transport/filters_manager.go +++ b/messaging/layers/transport/filters_manager.go @@ -28,11 +28,6 @@ type RawFilter struct { SymKeyID string } -type KeysPersistence interface { - All() (map[string][]byte, error) - Add(chatID string, key []byte) error -} - type FiltersService interface { AddKeyPair(key *ecdsa.PrivateKey) (string, error) DeleteKeyPair(keyID string) bool diff --git a/protocol/migrations/sqlite/1610117927_add_message_cache.up.sql b/messaging/layers/transport/migrations/sqlite/1610117927_add_message_cache.up.sql similarity index 100% rename from protocol/migrations/sqlite/1610117927_add_message_cache.up.sql rename to messaging/layers/transport/migrations/sqlite/1610117927_add_message_cache.up.sql diff --git a/protocol/migrations/sqlite/1616691080_add_wakuV2_keys.up.sql b/messaging/layers/transport/migrations/sqlite/1616691080_add_wakuV2_keys.up.sql similarity index 100% rename from protocol/migrations/sqlite/1616691080_add_wakuV2_keys.up.sql rename to messaging/layers/transport/migrations/sqlite/1616691080_add_wakuV2_keys.up.sql diff --git a/protocol/migrations/sqlite/1634723014_add_wakuV2_keys.up.sql b/messaging/layers/transport/migrations/sqlite/1634723014_add_wakuV2_keys.up.sql similarity index 100% rename from protocol/migrations/sqlite/1634723014_add_wakuV2_keys.up.sql rename to messaging/layers/transport/migrations/sqlite/1634723014_add_wakuV2_keys.up.sql diff --git a/messaging/layers/transport/migrations/sqlite/doc.go b/messaging/layers/transport/migrations/sqlite/doc.go new file mode 100644 index 00000000000..a26a30c6799 --- /dev/null +++ b/messaging/layers/transport/migrations/sqlite/doc.go @@ -0,0 +1,9 @@ +// This file is necessary because "github.com/status-im/migrate/v4" +// can't handle files starting with a prefix. At least that's the case +// for go-bindata. +// If go-bindata is called from the same directory, asset names +// have no prefix and "github.com/status-im/migrate/v4" works as expected. + +package sqlite + +//go:generate go tool go-bindata -modtime=1700000000 -pkg migrations -o ../migrations.go . diff --git a/messaging/layers/transport/processed_message_ids_cache.go b/messaging/layers/transport/persistence.go similarity index 65% rename from messaging/layers/transport/processed_message_ids_cache.go rename to messaging/layers/transport/persistence.go index 82de1ea3789..2ee30cca323 100644 --- a/messaging/layers/transport/processed_message_ids_cache.go +++ b/messaging/layers/transport/persistence.go @@ -1,5 +1,10 @@ package transport +type KeysPersistence interface { + All() (map[string][]byte, error) + Add(chatID string, key []byte) error +} + type ProcessedMessageIDsCachePersistence interface { Clear() error Hits(ids []string) (map[string]bool, error) diff --git a/messaging/layers/transport/sqlite_persistence.go b/messaging/layers/transport/sqlite_persistence.go new file mode 100644 index 00000000000..24d6cebf6c7 --- /dev/null +++ b/messaging/layers/transport/sqlite_persistence.go @@ -0,0 +1,150 @@ +package transport + +import ( + "context" + "database/sql" + "strings" +) + +type SQLiteKeysPersistence struct { + db *sql.DB +} + +func NewSQLiteKeysPersistence(db *sql.DB) *SQLiteKeysPersistence { + return &SQLiteKeysPersistence{db: db} +} + +func (s *SQLiteKeysPersistence) Add(chatID string, key []byte) error { + stmt, err := s.db.Prepare("INSERT INTO wakuv2_keys(chat_id, key) VALUES(?, ?)") + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec(chatID, key) + return err +} + +func (s *SQLiteKeysPersistence) All() (map[string][]byte, error) { + keys := make(map[string][]byte) + + stmt, err := s.db.Prepare("SELECT chat_id, key FROM wakuv2_keys") + if err != nil { + return nil, err + } + defer stmt.Close() + + rows, err := stmt.Query() + if err != nil && err != sql.ErrNoRows { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var ( + chatID string + key []byte + ) + + err := rows.Scan(&chatID, &key) + if err != nil { + return nil, err + } + keys[chatID] = key + } + + return keys, nil +} + +type SQLiteProcessedMessageIDsCachePersistence struct { + db *sql.DB +} + +func NewSQLiteProcessedMessageIDsCachePersistence(db *sql.DB) *SQLiteProcessedMessageIDsCachePersistence { + return &SQLiteProcessedMessageIDsCachePersistence{db: db} +} + +func (c *SQLiteProcessedMessageIDsCachePersistence) Clear() error { + _, err := c.db.Exec("DELETE FROM transport_message_cache") + return err +} + +func (c *SQLiteProcessedMessageIDsCachePersistence) Hits(ids []string) (map[string]bool, error) { + hits := make(map[string]bool) + + // Split the results into batches of 999 items. + // To prevent excessive memory allocations, the maximum value of a host parameter number + // is SQLITE_MAX_VARIABLE_NUMBER, which defaults to 999 + batch := 999 + for i := 0; i < len(ids); i += batch { + j := i + batch + if j > len(ids) { + j = len(ids) + } + + currentBatch := ids[i:j] + + idsArgs := make([]interface{}, 0, len(currentBatch)) + for _, id := range currentBatch { + idsArgs = append(idsArgs, id) + } + + inVector := strings.Repeat("?, ", len(currentBatch)-1) + "?" + query := "SELECT id FROM transport_message_cache WHERE id IN (" + inVector + ")" // nolint: gosec + + rows, err := c.db.Query(query, idsArgs...) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var id string + err := rows.Scan(&id) + if err != nil { + return nil, err + } + hits[id] = true + } + } + + return hits, nil +} + +func (c *SQLiteProcessedMessageIDsCachePersistence) Add(ids []string, timestamp uint64) (err error) { + var tx *sql.Tx + tx, err = c.db.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + return + } + + defer func() { + if err == nil { + err = tx.Commit() + return + } + // don't shadow original error + _ = tx.Rollback() + }() + + for _, id := range ids { + + var stmt *sql.Stmt + stmt, err = tx.Prepare(`INSERT INTO transport_message_cache(id,timestamp) VALUES (?, ?)`) + if err != nil { + return + } + + _, err = stmt.Exec(id, timestamp) + if err != nil { + return + } + } + + return +} + +func (c *SQLiteProcessedMessageIDsCachePersistence) Clean(timestamp uint64) error { + _, err := c.db.Exec(`DELETE FROM transport_message_cache WHERE timestamp < ?`, timestamp) + return err +} diff --git a/messaging/layers/transport/transport_test.go b/messaging/layers/transport/transport_test.go index 0924d1e5027..492e16bceb2 100644 --- a/messaging/layers/transport/transport_test.go +++ b/messaging/layers/transport/transport_test.go @@ -3,54 +3,26 @@ package transport import ( "testing" - "github.com/status-im/status-go/appdatabase" - "github.com/status-im/status-go/protocol/sqlite" - "github.com/status-im/status-go/t/helpers" - + bindata "github.com/status-im/migrate/v4/source/go_bindata" "github.com/stretchr/testify/require" + "github.com/status-im/status-go/messaging/layers/transport/migrations" "github.com/status-im/status-go/protocol/tt" + "github.com/status-im/status-go/t/helpers" ) -type keysPersistenceMock struct { -} - -func (p *keysPersistenceMock) All() (map[string][]byte, error) { - return map[string][]byte{}, nil -} - -func (p *keysPersistenceMock) Add(chatID string, key []byte) error { - return nil -} - -type processedMessageIDsCacheMock struct { -} - -func (p *processedMessageIDsCacheMock) Clear() error { - return nil -} -func (p *processedMessageIDsCacheMock) Hits(ids []string) (map[string]bool, error) { - return map[string]bool{}, nil -} -func (p *processedMessageIDsCacheMock) Add(ids []string, timestamp uint64) error { - return nil -} -func (p *processedMessageIDsCacheMock) Clean(timestamp uint64) error { - return nil -} - func TestNewTransport(t *testing.T) { - db, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{}) - require.NoError(t, err) - err = sqlite.Migrate(db) - require.NoError(t, err) - + db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, + }, + })) require.NoError(t, err) logger := tt.MustCreateTestLogger() - require.NoError(t, err) defer func() { _ = logger.Sync() }() - _, err = NewTransport(nil, nil, &keysPersistenceMock{}, &processedMessageIDsCacheMock{}, nil, logger) + _, err = NewTransport(nil, nil, NewSQLiteKeysPersistence(db), NewSQLiteProcessedMessageIDsCachePersistence(db), nil, logger) require.NoError(t, err) } diff --git a/messaging/persistence.go b/messaging/persistence.go new file mode 100644 index 00000000000..ac21b00d368 --- /dev/null +++ b/messaging/persistence.go @@ -0,0 +1,25 @@ +package messaging + +import ( + mvdsnode "github.com/status-im/mvds/node" + + "github.com/status-im/status-go/messaging/layers/encryption" + "github.com/status-im/status-go/messaging/layers/segmentation" + "github.com/status-im/status-go/messaging/layers/transport" + "github.com/status-im/status-go/messaging/types" + wakuv2 "github.com/status-im/status-go/messaging/waku" +) + +type Persistence interface { + WakuStorage() wakuv2.ProtectedTopicsPersistence + TransportStorage() TransportPersistence + SegmentationStorage() segmentation.Persistence + MVDSStorage() mvdsnode.Persistence + EncryptionStorage() encryption.Persistence + MessageSenderStorage() types.MessageSenderPersistence +} + +type TransportPersistence interface { + KeysStorage() transport.KeysPersistence + ProcessedMessageIDsCacheStorage() transport.ProcessedMessageIDsCachePersistence +} diff --git a/messaging/sqlite_persistence.go b/messaging/sqlite_persistence.go new file mode 100644 index 00000000000..0c632b2780d --- /dev/null +++ b/messaging/sqlite_persistence.go @@ -0,0 +1,155 @@ +package messaging + +import ( + "database/sql" + "fmt" + + "github.com/pkg/errors" + bindata "github.com/status-im/migrate/v4/source/go_bindata" + mvdsnode "github.com/status-im/mvds/node" + mvdsmigrations "github.com/status-im/mvds/persistenceutil" + + "github.com/status-im/status-go/messaging/common" + messagesendermigrations "github.com/status-im/status-go/messaging/common/migrations" + "github.com/status-im/status-go/messaging/layers/encryption" + encryptionmigrations "github.com/status-im/status-go/messaging/layers/encryption/migrations" + "github.com/status-im/status-go/messaging/layers/segmentation" + segmentationmigrations "github.com/status-im/status-go/messaging/layers/segmentation/migrations" + "github.com/status-im/status-go/messaging/layers/transport" + transportmigrations "github.com/status-im/status-go/messaging/layers/transport/migrations" + "github.com/status-im/status-go/messaging/types" + wakuv2 "github.com/status-im/status-go/messaging/waku" + wakumigrations "github.com/status-im/status-go/messaging/waku/migrations" + "github.com/status-im/status-go/sqlite" +) + +type migrationsMetadata struct { + *bindata.AssetSource + MigrationTableName string +} + +var migrations = []migrationsMetadata{ + { + AssetSource: &bindata.AssetSource{ + Names: wakumigrations.AssetNames(), + AssetFunc: wakumigrations.Asset, + }, + MigrationTableName: "status_schema_migrations_waku", + }, + { + AssetSource: &bindata.AssetSource{ + Names: transportmigrations.AssetNames(), + AssetFunc: transportmigrations.Asset, + }, + MigrationTableName: "status_schema_migrations_transport", + }, + { + AssetSource: &bindata.AssetSource{ + Names: segmentationmigrations.AssetNames(), + AssetFunc: segmentationmigrations.Asset, + }, + MigrationTableName: "status_schema_migrations_segmentation", + }, + { + AssetSource: &bindata.AssetSource{ + Names: encryptionmigrations.AssetNames(), + AssetFunc: encryptionmigrations.Asset, + }, + MigrationTableName: "status_schema_migrations_encryption", + }, + { + AssetSource: &bindata.AssetSource{ + Names: messagesendermigrations.AssetNames(), + AssetFunc: messagesendermigrations.Asset, + }, + MigrationTableName: "status_schema_migrations_message_sender", + }, +} + +// SQLiteMigrate applies necessary migrations to the SQLite database schema. +func SQLiteMigrate(database *sql.DB, maxVersion uint) error { + if maxVersion > 0 { + err := createMigrationTables(database, maxVersion) + if err != nil { + return errors.Wrap(err, "failed to update migration tables") + } + } + + err := mvdsmigrations.Migrate(database) + if err != nil { + return errors.Wrap(err, "failed to apply mvds migrations") + } + + for _, m := range migrations { + err := sqlite.Migrate(database, m.AssetSource, sqlite.MigrateOptions{MigrationTableName: m.MigrationTableName}) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("failed to apply %s migrations", m.MigrationTableName)) + } + } + + return nil +} + +// Migration tables were transitioned from a single shared table in the client (status_protocol_go) +// to dedicated pre-component tables. To maintain migration consistency and prevent reapplication +// of migrations already executed in the client, it is essential to initialize any newly created +// migration tables with the latest version. This ensures that components introduced into the client +// do not re-run migrations that have previously been applied. +func createMigrationTables(database *sql.DB, maxVersion uint) error { + for _, m := range migrations { + err := sqlite.UpdateMigrationTableVersion(database, m.MigrationTableName, m.Names, maxVersion) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("failed to update migration table %s", m.MigrationTableName)) + } + } + + return nil +} + +type sqlitePersistence struct { + db *sql.DB +} + +var _ Persistence = (*sqlitePersistence)(nil) + +func newSQLitePersistence(db *sql.DB) Persistence { + return &sqlitePersistence{db: db} +} + +func (p *sqlitePersistence) WakuStorage() wakuv2.ProtectedTopicsPersistence { + return wakuv2.NewSQLiteProtectedTopicsPersistence(p.db) +} + +type sqliteTransportPersistence struct { + db *sql.DB +} + +var _ TransportPersistence = (*sqliteTransportPersistence)(nil) + +func (p *sqliteTransportPersistence) KeysStorage() transport.KeysPersistence { + return transport.NewSQLiteKeysPersistence(p.db) +} + +func (p *sqliteTransportPersistence) ProcessedMessageIDsCacheStorage() transport.ProcessedMessageIDsCachePersistence { + return transport.NewSQLiteProcessedMessageIDsCachePersistence(p.db) +} + +func (p *sqlitePersistence) TransportStorage() TransportPersistence { + return &sqliteTransportPersistence{db: p.db} +} + +func (p *sqlitePersistence) SegmentationStorage() segmentation.Persistence { + return segmentation.NewSQLitePersistence(p.db) +} + +func (p *sqlitePersistence) MVDSStorage() mvdsnode.Persistence { + return mvdsnode.NewSQLitePersistence(p.db) +} + +func (p *sqlitePersistence) EncryptionStorage() encryption.Persistence { + return encryption.NewSQLitePersistence(p.db) +} + +func (p *sqlitePersistence) MessageSenderStorage() types.MessageSenderPersistence { + return common.NewSQLiteMessageSenderPersistence(p.db) +} diff --git a/messaging/types/message_sender_persistence.go b/messaging/types/message_sender_persistence.go new file mode 100644 index 00000000000..565ae170f97 --- /dev/null +++ b/messaging/types/message_sender_persistence.go @@ -0,0 +1,14 @@ +package types + +import ( + cryptotypes "github.com/status-im/status-go/crypto/types" +) + +type MessageSenderPersistence interface { + InsertPendingConfirmation(confirmation *RawMessageConfirmation) error + MarkAsConfirmed(dataSyncID []byte, atLeastOne bool) (messageID cryptotypes.HexBytes, err error) + SaveHashRatchetMessage(groupID []byte, keyID []byte, m *ReceivedMessage) error + GetHashRatchetMessages(keyID []byte) ([]*ReceivedMessage, error) + DeleteHashRatchetMessages(ids [][]byte) error + DeleteHashRatchetMessagesOlderThan(timestamp int64) error +} diff --git a/messaging/types/persistence.go b/messaging/types/persistence.go deleted file mode 100644 index 2d06b799d58..00000000000 --- a/messaging/types/persistence.go +++ /dev/null @@ -1,39 +0,0 @@ -package types - -import "crypto/ecdsa" - -type Persistence interface { - wakuPersistence - - MessageCacheAdd(ids []string, timestamp uint64) error - MessageCacheClear() error - MessageCacheClearOlderThan(timestamp uint64) error - MessageCacheHits(ids []string) (map[string]bool, error) - - InsertPendingConfirmation(confirmation *RawMessageConfirmation) error - SaveHashRatchetMessage(groupID []byte, keyID []byte, m *ReceivedMessage) error - GetHashRatchetMessages(keyID []byte) ([]*ReceivedMessage, error) - DeleteHashRatchetMessages(ids [][]byte) error - DeleteHashRatchetMessagesOlderThan(timestamp int64) error - - IsMessageAlreadyCompleted(hash []byte) (bool, error) - SaveMessageSegment(segment *SegmentMessage, sigPubKey *ecdsa.PublicKey, timestamp int64) error - GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*SegmentMessage, error) - CompleteMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey, timestamp int64) error - RemoveMessageSegmentsOlderThan(timestamp int64) error - RemoveMessageSegmentsCompletedOlderThan(timestamp int64) error -} - -type ProtectedTopic struct { - PubKey *ecdsa.PublicKey - Topic string -} - -type wakuPersistence interface { - WakuKeys() (map[string][]byte, error) - AddWakuKey(chatID string, key []byte) error - WakuInsertProtectedTopic(pubsubTopic string, privKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey) error - WakuDeleteProtectedTopic(pubsubTopic string) error - WakuFetchPrivateKeyForProtectedTopic(topic string) (*ecdsa.PrivateKey, error) - WakuProtectedTopics() ([]ProtectedTopic, error) -} diff --git a/messaging/waku/gowaku.go b/messaging/waku/gowaku.go index 8bc2028bb38..d5d0b132d5e 100644 --- a/messaging/waku/gowaku.go +++ b/messaging/waku/gowaku.go @@ -29,7 +29,6 @@ import ( "context" "crypto/ecdsa" "crypto/sha256" - "database/sql" "errors" "fmt" "math" @@ -91,7 +90,6 @@ import ( "github.com/status-im/status-go/internal/timesource" "github.com/status-im/status-go/logutils" "github.com/status-im/status-go/messaging/waku/common" - "github.com/status-im/status-go/messaging/waku/persistence" "github.com/status-im/status-go/messaging/waku/types" ) @@ -133,8 +131,7 @@ type IMetricsHandler interface { // Waku represents a dark communication interface through the Ethereum // network, using its very own P2P communication layer. type Waku struct { - node *node.WakuNode // reference to a libp2p waku node - appDB *sql.DB + node *node.WakuNode // reference to a libp2p waku node dnsAddressCache map[string][]dnsdisc.DiscoveredNode // Map to store the multiaddresses returned by dns discovery dnsAddressCacheLock *sync.RWMutex // lock to handle access to the map @@ -153,7 +150,7 @@ type Waku struct { bandwidthCounter *metrics.BandwidthCounter - protectedTopicStore persistence.ProtectedTopics + protectedTopicStore ProtectedTopicsPersistence sendQueue *publish.MessageQueue @@ -220,7 +217,7 @@ func newTTLCache() *ttlcache.Cache[gethcommon.Hash, bool] { } // New creates a WakuV2 client ready to communicate through the LibP2P network. -func New(nodeKey *ecdsa.PrivateKey, cfg *Config, logger *zap.Logger, protectedTopicsPersistence persistence.ProtectedTopics, ts timesource.Provider, onHistoricMessagesRequestFailed func([]byte, peer.AddrInfo, error), onPeerStats func(types.ConnStatus)) (*Waku, error) { +func New(nodeKey *ecdsa.PrivateKey, cfg *Config, logger *zap.Logger, protectedTopicsPersistence ProtectedTopicsPersistence, ts timesource.Provider, onHistoricMessagesRequestFailed func([]byte, peer.AddrInfo, error), onPeerStats func(types.ConnStatus)) (*Waku, error) { var err error if logger == nil { logger, err = zap.NewDevelopment() @@ -1359,7 +1356,7 @@ func (w *Waku) setupRelaySubscriptions() error { } if w.protectedTopicStore != nil { - protectedTopics, err := w.protectedTopicStore.ProtectedTopics() + protectedTopics, err := w.protectedTopicStore.All() if err != nil { return err } diff --git a/messaging/waku/migrations/sqlite/0001_mailserver_topics.up.sql b/messaging/waku/migrations/sqlite/0001_mailserver_topics.up.sql new file mode 100644 index 00000000000..bc9cc73dea6 --- /dev/null +++ b/messaging/waku/migrations/sqlite/0001_mailserver_topics.up.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS mailserver_topics ( + topic VARCHAR PRIMARY KEY, + chat_ids VARCHAR, + last_request INTEGER DEFAULT 1, + discovery BOOLEAN DEFAULT FALSE, + negotiated BOOLEAN DEFAULT FALSE +) WITHOUT ROWID; diff --git a/appdatabase/migrations/sql/1691753800_pubsubtopic_key.up.sql b/messaging/waku/migrations/sqlite/1691753800_pubsubtopic_key.up.sql similarity index 100% rename from appdatabase/migrations/sql/1691753800_pubsubtopic_key.up.sql rename to messaging/waku/migrations/sqlite/1691753800_pubsubtopic_key.up.sql diff --git a/messaging/waku/migrations/sqlite/doc.go b/messaging/waku/migrations/sqlite/doc.go new file mode 100644 index 00000000000..4f1c2a3f66f --- /dev/null +++ b/messaging/waku/migrations/sqlite/doc.go @@ -0,0 +1,3 @@ +package sql + +//go:generate go tool go-bindata -modtime=1700000000 -pkg migrations -o ../migrations.go ./ diff --git a/messaging/waku/nwaku.go b/messaging/waku/nwaku.go index 4c01e133022..b8741621c24 100644 --- a/messaging/waku/nwaku.go +++ b/messaging/waku/nwaku.go @@ -63,7 +63,6 @@ import ( "github.com/status-im/status-go/internal/timesource" "github.com/status-im/status-go/logutils" "github.com/status-im/status-go/messaging/waku/common" - "github.com/status-im/status-go/messaging/waku/persistence" "github.com/status-im/status-go/messaging/waku/types" "github.com/waku-org/go-waku/waku/v2/node" @@ -123,7 +122,7 @@ type Waku struct { bandwidthCounter *metrics.BandwidthCounter - protectedTopicStore persistence.ProtectedTopics + protectedTopicStore ProtectedTopicsPersistence sendQueue *publish.MessageQueue @@ -203,7 +202,7 @@ func newTTLCache() *ttlcache.Cache[gethcommon.Hash, bool] { } // New creates a WakuV2 client ready to communicate through the LibP2P network. -func New(nodeKey *ecdsa.PrivateKey, cfg *Config, logger *zap.Logger, protectedTopicsPersistence persistence.ProtectedTopics, ts timesource.Provider, onHistoricMessagesRequestFailed func([]byte, peer.AddrInfo, error), onPeerStats func(types.ConnStatus)) (*Waku, error) { +func New(nodeKey *ecdsa.PrivateKey, cfg *Config, logger *zap.Logger, protectedTopicsPersistence ProtectedTopicsPersistence, ts timesource.Provider, onHistoricMessagesRequestFailed func([]byte, peer.AddrInfo, error), onPeerStats func(types.ConnStatus)) (*Waku, error) { var err error if logger == nil { logger, err = zap.NewDevelopment() @@ -1034,7 +1033,7 @@ func (w *Waku) setupRelaySubscriptions() error { } if w.protectedTopicStore != nil { - protectedTopics, err := w.protectedTopicStore.ProtectedTopics() + protectedTopics, err := w.protectedTopicStore.All() if err != nil { return err } diff --git a/messaging/waku/persistence/protected_topics.go b/messaging/waku/persistence.go similarity index 67% rename from messaging/waku/persistence/protected_topics.go rename to messaging/waku/persistence.go index bf641e0f594..cb055dcbbf4 100644 --- a/messaging/waku/persistence/protected_topics.go +++ b/messaging/waku/persistence.go @@ -1,14 +1,12 @@ -package persistence +package wakuv2 -import ( - "crypto/ecdsa" -) +import "crypto/ecdsa" -type ProtectedTopics interface { +type ProtectedTopicsPersistence interface { Insert(pubsubTopic string, privKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey) error Delete(pubsubTopic string) error FetchPrivateKey(topic string) (*ecdsa.PrivateKey, error) - ProtectedTopics() ([]ProtectedTopic, error) + All() ([]ProtectedTopic, error) } type ProtectedTopic struct { diff --git a/messaging/waku/persistence_test.go b/messaging/waku/persistence_test.go new file mode 100644 index 00000000000..f980521ffd1 --- /dev/null +++ b/messaging/waku/persistence_test.go @@ -0,0 +1,61 @@ +package wakuv2 + +import ( + "testing" + + bindata "github.com/status-im/migrate/v4/source/go_bindata" + "github.com/stretchr/testify/require" + + "github.com/ethereum/go-ethereum/crypto" + + "github.com/status-im/status-go/messaging/waku/migrations" + "github.com/status-im/status-go/t/helpers" +) + +func TestProtectedTopicsPersistence(t *testing.T) { + db, err := helpers.SetupTestMemorySQLDB(helpers.NewTestDBInitializer([]*bindata.AssetSource{ + { + Names: migrations.AssetNames(), + AssetFunc: migrations.Asset, + }, + })) + require.NoError(t, err) + + p := NewSQLiteProtectedTopicsPersistence(db) + + // Generate ECDSA keys + privKey, err := crypto.GenerateKey() + require.NoError(t, err) + pubKey := &privKey.PublicKey + + pubsubTopic := "test-topic" + + // Insert protected topic + err = p.Insert(pubsubTopic, privKey, pubKey) + require.NoError(t, err) + + // Fetch private key for topic + fetchedPrivKey, err := p.FetchPrivateKey(pubsubTopic) + require.NoError(t, err) + require.NotNil(t, fetchedPrivKey) + require.Equal(t, privKey.D.Bytes(), fetchedPrivKey.D.Bytes()) + + // Fetch protected topics + topics, err := p.All() + require.NoError(t, err) + require.Len(t, topics, 1) + require.Equal(t, pubsubTopic, topics[0].Topic) + + // Delete protected topic + err = p.Delete(pubsubTopic) + require.NoError(t, err) + + // Ensure topic is deleted + topics, err = p.All() + require.NoError(t, err) + require.Len(t, topics, 0) + + fetchedPrivKey, err = p.FetchPrivateKey(pubsubTopic) + require.NoError(t, err) + require.Nil(t, fetchedPrivKey) +} diff --git a/messaging/waku/sqlite_persistence.go b/messaging/waku/sqlite_persistence.go new file mode 100644 index 00000000000..4e795cbeaa2 --- /dev/null +++ b/messaging/waku/sqlite_persistence.go @@ -0,0 +1,78 @@ +package wakuv2 + +import ( + "crypto/ecdsa" + "database/sql" + "errors" + + "github.com/status-im/status-go/crypto" +) + +type SQLiteProtectedTopicsPersistence struct { + db *sql.DB +} + +var _ ProtectedTopicsPersistence = (*SQLiteProtectedTopicsPersistence)(nil) + +func NewSQLiteProtectedTopicsPersistence(db *sql.DB) *SQLiteProtectedTopicsPersistence { + return &SQLiteProtectedTopicsPersistence{db: db} +} + +func (s *SQLiteProtectedTopicsPersistence) Insert(pubsubTopic string, privKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey) error { + var privKeyBytes []byte + if privKey != nil { + privKeyBytes = crypto.FromECDSA(privKey) + } + pubKeyBytes := crypto.FromECDSAPub(publicKey) + + _, err := s.db.Exec("INSERT OR REPLACE INTO pubsubtopic_signing_key (topic, priv_key, pub_key) VALUES (?, ?, ?)", + pubsubTopic, privKeyBytes, pubKeyBytes) + return err +} + +func (s *SQLiteProtectedTopicsPersistence) Delete(pubsubTopic string) error { + _, err := s.db.Exec("DELETE FROM pubsubtopic_signing_key WHERE topic = ?", pubsubTopic) + return err +} + +func (s *SQLiteProtectedTopicsPersistence) FetchPrivateKey(topic string) (*ecdsa.PrivateKey, error) { + var privKeyBytes []byte + err := s.db.QueryRow("SELECT priv_key FROM pubsubtopic_signing_key WHERE topic = ?", topic).Scan(&privKeyBytes) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return crypto.ToECDSA(privKeyBytes) +} + +func (s *SQLiteProtectedTopicsPersistence) All() ([]ProtectedTopic, error) { + rows, err := s.db.Query("SELECT pub_key, topic FROM pubsubtopic_signing_key") + if err != nil { + return nil, err + } + defer rows.Close() + + var result []ProtectedTopic + for rows.Next() { + var pubKeyBytes []byte + var topic string + err := rows.Scan(&pubKeyBytes, &topic) + if err != nil { + return nil, err + } + + pubk, err := crypto.UnmarshalPubkey(pubKeyBytes) + if err != nil { + return nil, err + } + + result = append(result, ProtectedTopic{ + PubKey: pubk, + Topic: topic, + }) + } + + return result, nil +} diff --git a/protocol/messaging_persistence.go b/protocol/messaging_persistence.go deleted file mode 100644 index bf1169d5365..00000000000 --- a/protocol/messaging_persistence.go +++ /dev/null @@ -1,436 +0,0 @@ -package protocol - -import ( - "context" - "crypto/ecdsa" - "database/sql" - "errors" - "fmt" - "strings" - "time" - - "github.com/status-im/status-go/crypto" - cryptotypes "github.com/status-im/status-go/crypto/types" - "github.com/status-im/status-go/messaging/types" - messagingtypes "github.com/status-im/status-go/messaging/types" - "github.com/status-im/status-go/protocol/protobuf" -) - -const tableName = "wakuv2_keys" - -type messagingPersistence struct { - db *sql.DB -} - -var _ types.Persistence = (*messagingPersistence)(nil) - -func NewMessagingPersistence(db *sql.DB) *messagingPersistence { - return &messagingPersistence{db: db} -} - -func (s *messagingPersistence) AddWakuKey(chatID string, key []byte) error { - statement := fmt.Sprintf("INSERT INTO %s(chat_id, key) VALUES(?, ?)", tableName) // nolint:gosec - stmt, err := s.db.Prepare(statement) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(chatID, key) - return err -} - -func (s *messagingPersistence) WakuKeys() (map[string][]byte, error) { - keys := make(map[string][]byte) - - statement := fmt.Sprintf("SELECT chat_id, key FROM %s", tableName) // nolint: gosec - - stmt, err := s.db.Prepare(statement) - if err != nil { - return nil, err - } - defer stmt.Close() - - rows, err := stmt.Query() - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - for rows.Next() { - var ( - chatID string - key []byte - ) - - err := rows.Scan(&chatID, &key) - if err != nil { - return nil, err - } - keys[chatID] = key - } - - return keys, nil -} - -func (c *messagingPersistence) MessageCacheClear() error { - _, err := c.db.Exec("DELETE FROM transport_message_cache") - return err -} - -func (c *messagingPersistence) MessageCacheHits(ids []string) (map[string]bool, error) { - hits := make(map[string]bool) - - // Split the results into batches of 999 items. - // To prevent excessive memory allocations, the maximum value of a host parameter number - // is SQLITE_MAX_VARIABLE_NUMBER, which defaults to 999 - batch := 999 - for i := 0; i < len(ids); i += batch { - j := i + batch - if j > len(ids) { - j = len(ids) - } - - currentBatch := ids[i:j] - - idsArgs := make([]interface{}, 0, len(currentBatch)) - for _, id := range currentBatch { - idsArgs = append(idsArgs, id) - } - - inVector := strings.Repeat("?, ", len(currentBatch)-1) + "?" - query := "SELECT id FROM transport_message_cache WHERE id IN (" + inVector + ")" // nolint: gosec - - rows, err := c.db.Query(query, idsArgs...) - if err != nil { - return nil, err - } - defer rows.Close() - - for rows.Next() { - var id string - err := rows.Scan(&id) - if err != nil { - return nil, err - } - hits[id] = true - } - } - - return hits, nil -} - -func (c *messagingPersistence) MessageCacheAdd(ids []string, timestamp uint64) (err error) { - var tx *sql.Tx - tx, err = c.db.BeginTx(context.Background(), &sql.TxOptions{}) - if err != nil { - return - } - - defer func() { - if err == nil { - err = tx.Commit() - return - } - // don't shadow original error - _ = tx.Rollback() - }() - - for _, id := range ids { - - var stmt *sql.Stmt - stmt, err = tx.Prepare(`INSERT INTO transport_message_cache(id,timestamp) VALUES (?, ?)`) - if err != nil { - return - } - defer stmt.Close() - - _, err = stmt.Exec(id, timestamp) - if err != nil { - return - } - } - - return -} - -func (c *messagingPersistence) MessageCacheClearOlderThan(timestamp uint64) error { - _, err := c.db.Exec(`DELETE FROM transport_message_cache WHERE timestamp < ?`, timestamp) - return err -} - -func (c *messagingPersistence) InsertPendingConfirmation(confirmation *messagingtypes.RawMessageConfirmation) error { - _, err := c.db.Exec(`INSERT INTO raw_message_confirmations - (datasync_id, message_id, public_key) - VALUES - (?,?,?)`, - confirmation.DataSyncID, - confirmation.MessageID, - confirmation.PublicKey, - ) - return err -} - -func (c *messagingPersistence) SaveHashRatchetMessage(groupID []byte, keyID []byte, m *messagingtypes.ReceivedMessage) error { - _, err := c.db.Exec(`INSERT INTO hash_ratchet_encrypted_messages(hash, sig, timestamp, topic, payload, dst, padding, group_id, key_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, m.Hash, m.Sig, m.Timestamp, m.Topic.Bytes(), m.Payload, m.Dst, m.Padding, groupID, keyID) - return err -} - -func (c *messagingPersistence) GetHashRatchetMessages(keyID []byte) ([]*messagingtypes.ReceivedMessage, error) { - var messages []*messagingtypes.ReceivedMessage - - rows, err := c.db.Query(`SELECT hash, sig, timestamp, topic, payload, dst, padding FROM hash_ratchet_encrypted_messages WHERE key_id = ?`, keyID) - if err != nil { - return nil, err - } - - for rows.Next() { - var topic []byte - message := &messagingtypes.ReceivedMessage{} - - err := rows.Scan(&message.Hash, &message.Sig, &message.Timestamp, &topic, &message.Payload, &message.Dst, &message.Padding) - if err != nil { - return nil, err - } - - message.Topic = messagingtypes.BytesToContentTopic(topic) - messages = append(messages, message) - } - - return messages, nil -} - -func (c *messagingPersistence) DeleteHashRatchetMessages(ids [][]byte) error { - if len(ids) == 0 { - return nil - } - - idsArgs := make([]interface{}, 0, len(ids)) - for _, id := range ids { - idsArgs = append(idsArgs, id) - } - inVector := strings.Repeat("?, ", len(ids)-1) + "?" - - _, err := c.db.Exec("DELETE FROM hash_ratchet_encrypted_messages WHERE hash IN ("+inVector+")", idsArgs...) // nolint: gosec - - return err -} - -func (c *messagingPersistence) DeleteHashRatchetMessagesOlderThan(timestamp int64) error { - _, err := c.db.Exec("DELETE FROM hash_ratchet_encrypted_messages WHERE timestamp < ?", timestamp) - return err -} - -func (c *messagingPersistence) IsMessageAlreadyCompleted(hash []byte) (bool, error) { - var alreadyCompleted int - err := c.db.QueryRow("SELECT COUNT(*) FROM message_segments_completed WHERE hash = ?", hash).Scan(&alreadyCompleted) - if err != nil { - return false, err - } - return alreadyCompleted > 0, nil -} - -func (c *messagingPersistence) SaveMessageSegment(segment *messagingtypes.SegmentMessage, sigPubKey *ecdsa.PublicKey, timestamp int64) error { - sigPubKeyBlob := crypto.CompressPubkey(sigPubKey) - - _, err := c.db.Exec("INSERT INTO message_segments (hash, segment_index, segments_count, parity_segment_index, parity_segments_count, sig_pub_key, payload, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", - segment.EntireMessageHash, segment.Index, segment.SegmentsCount, segment.ParitySegmentIndex, segment.ParitySegmentsCount, sigPubKeyBlob, segment.Payload, timestamp) - - return err -} - -// Get ordered message segments for given hash -func (c *messagingPersistence) GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*messagingtypes.SegmentMessage, error) { - sigPubKeyBlob := crypto.CompressPubkey(sigPubKey) - - rows, err := c.db.Query(` - SELECT - hash, segment_index, segments_count, parity_segment_index, parity_segments_count, payload - FROM - message_segments - WHERE - hash = ? AND sig_pub_key = ? - ORDER BY - (segments_count = 0) ASC, -- Prioritize segments_count > 0 - segment_index ASC, - parity_segment_index ASC`, - hash, sigPubKeyBlob) - if err != nil { - return nil, err - } - defer rows.Close() - - var segments []*messagingtypes.SegmentMessage - for rows.Next() { - segment := &messagingtypes.SegmentMessage{ - SegmentMessage: &protobuf.SegmentMessage{}, - } - err := rows.Scan(&segment.EntireMessageHash, &segment.Index, &segment.SegmentsCount, &segment.ParitySegmentIndex, &segment.ParitySegmentsCount, &segment.Payload) - if err != nil { - return nil, err - } - segments = append(segments, segment) - } - err = rows.Err() - if err != nil { - return nil, err - } - - return segments, nil -} - -func (c *messagingPersistence) RemoveMessageSegmentsOlderThan(timestamp int64) error { - _, err := c.db.Exec("DELETE FROM message_segments WHERE timestamp < ?", timestamp) - return err -} - -func (c *messagingPersistence) CompleteMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey, timestamp int64) error { - tx, err := c.db.BeginTx(context.Background(), &sql.TxOptions{}) - if err != nil { - return err - } - - defer func() { - if err == nil { - err = tx.Commit() - return - } - // don't shadow original error - _ = tx.Rollback() - }() - - sigPubKeyBlob := crypto.CompressPubkey(sigPubKey) - - _, err = tx.Exec("DELETE FROM message_segments WHERE hash = ? AND sig_pub_key = ?", hash, sigPubKeyBlob) - if err != nil { - return err - } - - _, err = tx.Exec("INSERT INTO message_segments_completed (hash, sig_pub_key, timestamp) VALUES (?,?,?)", hash, sigPubKeyBlob, timestamp) - if err != nil { - return err - } - - return err -} - -func (c *messagingPersistence) RemoveMessageSegmentsCompletedOlderThan(timestamp int64) error { - _, err := c.db.Exec("DELETE FROM message_segments_completed WHERE timestamp < ?", timestamp) - return err -} - -// MarkAsConfirmed marks all the messages with dataSyncID as confirmed and returns -// the messageIDs that can be considered confirmed. -// If atLeastOne is set it will return messageid if at least once of the messages -// sent has been confirmed -func (c *messagingPersistence) MarkAsConfirmed(dataSyncID []byte, atLeastOne bool) (messageID cryptotypes.HexBytes, err error) { - tx, err := c.db.BeginTx(context.Background(), &sql.TxOptions{}) - if err != nil { - return nil, err - } - defer func() { - if err == nil { - err = tx.Commit() - return - } - // don't shadow original error - _ = tx.Rollback() - }() - - confirmedAt := time.Now().Unix() - _, err = tx.Exec(`UPDATE raw_message_confirmations SET confirmed_at = ? WHERE datasync_id = ? AND confirmed_at = 0`, confirmedAt, dataSyncID) - if err != nil { - return - } - - // Select any tuple that has a message_id with a datasync_id = ? and that has just been confirmed - rows, err := tx.Query(`SELECT message_id,confirmed_at FROM raw_message_confirmations WHERE message_id = (SELECT message_id FROM raw_message_confirmations WHERE datasync_id = ? LIMIT 1)`, dataSyncID) - if err != nil { - return - } - defer rows.Close() - - confirmedResult := true - - for rows.Next() { - var confirmedAt int64 - err = rows.Scan(&messageID, &confirmedAt) - if err != nil { - return - } - confirmed := confirmedAt > 0 - - if atLeastOne && confirmed { - // We return, as at least one was confirmed - return - } - - confirmedResult = confirmedResult && confirmed - } - - if !confirmedResult { - messageID = nil - return - } - - return -} - -func (c *messagingPersistence) WakuInsertProtectedTopic(pubsubTopic string, privKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey) error { - var privKeyBytes []byte - if privKey != nil { - privKeyBytes = crypto.FromECDSA(privKey) - } - pubKeyBytes := crypto.FromECDSAPub(publicKey) - - _, err := c.db.Exec("INSERT OR REPLACE INTO pubsubtopic_signing_key (topic, priv_key, pub_key) VALUES (?, ?, ?)", - pubsubTopic, privKeyBytes, pubKeyBytes) - return err -} - -func (c *messagingPersistence) WakuDeleteProtectedTopic(pubsubTopic string) error { - _, err := c.db.Exec("DELETE FROM pubsubtopic_signing_key WHERE topic = ?", pubsubTopic) - return err -} - -func (c *messagingPersistence) WakuFetchPrivateKeyForProtectedTopic(topic string) (*ecdsa.PrivateKey, error) { - var privKeyBytes []byte - err := c.db.QueryRow("SELECT priv_key FROM pubsubtopic_signing_key WHERE topic = ?", topic).Scan(&privKeyBytes) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - return nil, err - } - return crypto.ToECDSA(privKeyBytes) -} - -func (c *messagingPersistence) WakuProtectedTopics() ([]types.ProtectedTopic, error) { - rows, err := c.db.Query("SELECT pub_key, topic FROM pubsubtopic_signing_key") - if err != nil { - return nil, err - } - defer rows.Close() - - var result []types.ProtectedTopic - for rows.Next() { - var pubKeyBytes []byte - var topic string - err := rows.Scan(&pubKeyBytes, &topic) - if err != nil { - return nil, err - } - - pubk, err := crypto.UnmarshalPubkey(pubKeyBytes) - if err != nil { - return nil, err - } - - result = append(result, types.ProtectedTopic{ - PubKey: pubk, - Topic: topic, - }) - } - - return result, nil -} diff --git a/protocol/messenger.go b/protocol/messenger.go index d7a1bde7942..76fac737cb9 100644 --- a/protocol/messenger.go +++ b/protocol/messenger.go @@ -95,7 +95,6 @@ type Messenger struct { identity *ecdsa.PrivateKey signer communities.MessageSigner messaging *messaging.API - messagingPersistence *messagingPersistence persistence *sqlitePersistence ensVerifier *ens.Verifier pushNotificationClient *pushnotificationclient.Client @@ -397,7 +396,6 @@ func NewMessenger( config: &c, identity: identity, messaging: messaging, - messagingPersistence: NewMessagingPersistence(database), persistence: sqlitePersistence, communityTokensService: c.communityTokensService, pushNotificationClient: pushNotificationClient, diff --git a/protocol/messenger_builder_test.go b/protocol/messenger_builder_test.go index 62042b33912..4505f29b0f4 100644 --- a/protocol/messenger_builder_test.go +++ b/protocol/messenger_builder_test.go @@ -127,12 +127,11 @@ func newTestMessenger(messagingEnv *messaging.TestMessagingEnvironment, config t messaging, err := messagingEnv.NewTestCore( messaging.CoreParams{ Identity: config.privateKey, - DB: appDb, - Persistence: NewMessagingPersistence(appDb), InstallationID: installationID, TimeSource: &testTimeSource{}, }, messaging.WithLogger(config.logger), + messaging.WithSQLitePersistence(appDb), ) if err != nil { return nil, err diff --git a/protocol/messenger_peersyncing.go b/protocol/messenger_peersyncing.go index 8ab49d35424..03f2a074cd0 100644 --- a/protocol/messenger_peersyncing.go +++ b/protocol/messenger_peersyncing.go @@ -27,7 +27,7 @@ func (m *Messenger) markDeliveredMessages(acks [][]byte) { for _, ack := range acks { //get message ID from database by datasync ID, with at-least-one // semantic - messageIDBytes, err := m.messagingPersistence.MarkAsConfirmed(ack, true) + messageIDBytes, err := m.messaging.MarkAsConfirmed(ack, true) if err != nil { m.logger.Info("got datasync acknowledge for message we don't have in db", zap.String("ack", hex.EncodeToString(ack))) continue diff --git a/protocol/migrations/sqlite/1561059284_add_waku_keys.down.sql b/protocol/migrations/sqlite/1561059284_add_waku_keys.down.sql deleted file mode 100644 index 95d756b9427..00000000000 --- a/protocol/migrations/sqlite/1561059284_add_waku_keys.down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE waku_keys; diff --git a/protocol/migrations/sqlite/1561059284_add_waku_keys.up.sql b/protocol/migrations/sqlite/1561059284_add_waku_keys.up.sql deleted file mode 100644 index 8746cabcef3..00000000000 --- a/protocol/migrations/sqlite/1561059284_add_waku_keys.up.sql +++ /dev/null @@ -1,4 +0,0 @@ -CREATE TABLE waku_keys ( - chat_id TEXT PRIMARY KEY ON CONFLICT IGNORE, - key BLOB NOT NULL -) WITHOUT ROWID; diff --git a/protocol/migrations/sqlite/1616691080_add_wakuV2_keys.down.sql b/protocol/migrations/sqlite/1616691080_add_wakuV2_keys.down.sql deleted file mode 100644 index 290723651ff..00000000000 --- a/protocol/migrations/sqlite/1616691080_add_wakuV2_keys.down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE wakuv2_keys; diff --git a/protocol/migrations/sqlite/1712905223_add_parity_to_message_segments.up.sql b/protocol/migrations/sqlite/1712905223_add_parity_to_message_segments.up.sql deleted file mode 100644 index d42158e7843..00000000000 --- a/protocol/migrations/sqlite/1712905223_add_parity_to_message_segments.up.sql +++ /dev/null @@ -1,19 +0,0 @@ -ALTER TABLE message_segments RENAME TO old_message_segments; - -CREATE TABLE message_segments ( - hash BLOB NOT NULL, - segment_index INTEGER NOT NULL, - segments_count INTEGER NOT NULL, - payload BLOB NOT NULL, - sig_pub_key BLOB NOT NULL, - timestamp INTEGER NOT NULL, - parity_segment_index INTEGER NOT NULL, - parity_segments_count INTEGER NOT NULL, - PRIMARY KEY (hash, sig_pub_key, segment_index, segments_count, parity_segment_index, parity_segments_count) ON CONFLICT REPLACE -); - -INSERT INTO message_segments (hash, segment_index, segments_count, payload, sig_pub_key, timestamp, parity_segment_index, parity_segments_count) -SELECT hash, segment_index, segments_count, payload, sig_pub_key, timestamp, 0, 0 -FROM old_message_segments; - -DROP TABLE old_message_segments; diff --git a/protocol/sqlite/db.go b/protocol/sqlite/db.go index 9ea76c1962e..149c9cd0ed6 100644 --- a/protocol/sqlite/db.go +++ b/protocol/sqlite/db.go @@ -8,18 +8,22 @@ import ( _ "github.com/mutecomm/go-sqlcipher/v4" // We require go sqlcipher that overrides default implementation "github.com/status-im/migrate/v4/database/sqlcipher" bindata "github.com/status-im/migrate/v4/source/go_bindata" - mvdsmigrations "github.com/status-im/mvds/persistenceutil" + "github.com/status-im/status-go/messaging" "github.com/status-im/status-go/sqlite" ) var migrationsTable = "status_protocol_go_" + sqlcipher.DefaultMigrationsTable func Migrate(database *sql.DB) error { - // Apply migrations for all components. - err := mvdsmigrations.Migrate(database) + lastProtocolMigrationVersion, _, err := sqlite.GetLastMigrationVersion(database, migrationsTable) if err != nil { - return errors.Wrap(err, "failed to apply mvds migrations") + return errors.Wrap(err, "failed to get current migration version") + } + + err = messaging.SQLiteMigrate(database, lastProtocolMigrationVersion) + if err != nil { + return errors.Wrap(err, "failed to apply messaging migrations") } migrationNames, migrationGetter, err := prepareMigrations(defaultMigrations) @@ -27,6 +31,12 @@ func Migrate(database *sql.DB) error { return errors.Wrap(err, "failed to prepare status-go/protocol migrations") } + // Ensure the protocol migration table version is synchronized with the latest migration after migrations are moved between modules. + err = sqlite.UpdateMigrationTableVersion(database, migrationsTable, migrationNames, lastProtocolMigrationVersion) + if err != nil { + return errors.Wrap(err, "failed to update migration table version") + } + options := sqlite.MigrateOptions{ MigrationTableName: migrationsTable, } diff --git a/protocol/sqlite/prepare_migrations.go b/protocol/sqlite/prepare_migrations.go index 51df3cba71c..302dd61b355 100644 --- a/protocol/sqlite/prepare_migrations.go +++ b/protocol/sqlite/prepare_migrations.go @@ -5,7 +5,6 @@ import ( "github.com/pkg/errors" - encryptmigrations "github.com/status-im/status-go/messaging/layers/encryption/migrations" appmigrations "github.com/status-im/status-go/protocol/migrations" push_notification_client_migrations "github.com/status-im/status-go/protocol/pushnotificationclient/migrations" push_notification_server_migrations "github.com/status-im/status-go/protocol/pushnotificationserver/migrations" @@ -19,10 +18,6 @@ type migrationsWithGetter struct { } var defaultMigrations = []migrationsWithGetter{ - { - Names: encryptmigrations.AssetNames(), - Getter: encryptmigrations.Asset, - }, { Names: appmigrations.AssetNames(), Getter: appmigrations.Asset, diff --git a/services/ext/service.go b/services/ext/service.go index 5e854366b85..0e310d4d396 100644 --- a/services/ext/service.go +++ b/services/ext/service.go @@ -163,14 +163,13 @@ func (s *Service) InitProtocol(params InitProtocolParams) error { messaging, err := messaging.NewCore( messaging.CoreParams{ Identity: params.Identity, - DB: params.AppDB, - Persistence: protocol.NewMessagingPersistence(params.AppDB), NodeKey: nodeKey, WakuConfig: s.config.WakuV2Config, ClusterConfig: s.config.ClusterConfig, InstallationID: s.config.ShhextConfig.InstallationID, TimeSource: params.TimeSource, }, + messaging.WithSQLitePersistence(params.AppDB), messaging.WithLogger(s.logger), messaging.WithEnvelopeEventsConfig(envelopeEventsConfig), messaging.WithHistoricMessagesRequestFailedHandler(signal.SendHistoricMessagesRequestFailed), diff --git a/sqlite/migrate.go b/sqlite/migrate.go index 62a345e116c..56c130262ee 100644 --- a/sqlite/migrate.go +++ b/sqlite/migrate.go @@ -5,6 +5,7 @@ import ( "fmt" "sort" + "github.com/golang-migrate/migrate/v4/source" "github.com/status-im/migrate/v4" "github.com/status-im/migrate/v4/database/sqlcipher" bindata "github.com/status-im/migrate/v4/source/go_bindata" @@ -18,7 +19,9 @@ type PostStep struct { RollBackVersion uint } -var migrationTable = "status_go_" + sqlcipher.DefaultMigrationsTable +func StatusMigrationTableName() string { + return "status_go_" + sqlcipher.DefaultMigrationsTable +} type MigrateOptions struct { MigrationTableName string @@ -52,7 +55,7 @@ func Migrate(db *sql.DB, resources *bindata.AssetSource, options MigrateOptions) migrationTableName := options.MigrationTableName if len(migrationTableName) == 0 { - migrationTableName = migrationTable + migrationTableName = StatusMigrationTableName() } driver, err := sqlcipher.WithInstance(db, &sqlcipher.Config{ MigrationsTable: migrationTableName, @@ -164,7 +167,7 @@ func getCurrentVersion(m *migrate.Migrate, db *sql.DB) (uint, error) { return 0, fmt.Errorf("DB is dirty after migration version %d", lastVersion) } if err == migrate.ErrNilVersion { - lastVersion, _, err = GetLastMigrationVersion(db) + lastVersion, _, err = GetLastMigrationVersion(db, StatusMigrationTableName()) return lastVersion, err } return lastVersion, nil @@ -172,9 +175,9 @@ func getCurrentVersion(m *migrate.Migrate, db *sql.DB) (uint, error) { // GetLastMigrationVersion returns the last migration version stored in the migration table. // Returns 0 for version in case migrationTableExists is true -func GetLastMigrationVersion(db *sql.DB) (version uint, migrationTableExists bool, err error) { +func GetLastMigrationVersion(db *sql.DB, migrationTableName string) (version uint, migrationTableExists bool, err error) { // Check if the migration table exists - row := db.QueryRow("SELECT exists(SELECT name FROM sqlite_master WHERE type='table' AND name=?)", migrationTable) + row := db.QueryRow("SELECT exists(SELECT name FROM sqlite_master WHERE type='table' AND name=?)", migrationTableName) migrationTableExists = false err = row.Scan(&migrationTableExists) if err != nil && err != sql.ErrNoRows { @@ -183,7 +186,7 @@ func GetLastMigrationVersion(db *sql.DB) (version uint, migrationTableExists boo var lastMigration uint64 = 0 if migrationTableExists { - row = db.QueryRow("SELECT version FROM status_go_schema_migrations") + row = db.QueryRow(fmt.Sprintf("SELECT version FROM %s", migrationTableName)) err = row.Scan(&lastMigration) if err != nil && err != sql.ErrNoRows { return 0, true, err @@ -191,3 +194,76 @@ func GetLastMigrationVersion(db *sql.DB) (version uint, migrationTableExists boo } return uint(lastMigration), migrationTableExists, nil } + +// UpdateMigrationTableVersion migrates from one migration table to another, ensuring the table reflects the correct migration state. +// Ensures migrationTableName exists and records the latest version from assetNames, capped at maxVersion. +// Intended for migration table transitions and version synchronization. +func UpdateMigrationTableVersion(db *sql.DB, migrationTableName string, assetNames []string, maxVersion uint) error { + tx, err := db.Begin() + if err != nil { + return err + } + + defer func() { + if err != nil { + rollbackErr := tx.Rollback() + if rollbackErr != nil { + err = fmt.Errorf("failed to rollback transaction: %w; original error: %v", rollbackErr, err) + } + } + }() + + row := tx.QueryRow("SELECT exists(SELECT name FROM sqlite_master WHERE type='table' AND name=?)", migrationTableName) + exists := false + err = row.Scan(&exists) + if err != nil && err != sql.ErrNoRows { + return err + } + + if !exists { + createTable := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool);`, migrationTableName) + _, err = tx.Exec(createTable) + if err != nil { + return err + } + createIndex := fmt.Sprintf(`CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);`, migrationTableName) + _, err = tx.Exec(createIndex) + if err != nil { + return err + } + } + + version := getMaxMigrationVersion(assetNames, maxVersion) + if version > 0 { + // #nosec G201 -- migrationTableName is a trusted constant, not user input + deleteQuery := fmt.Sprintf("DELETE FROM %s", migrationTableName) + _, err = tx.Exec(deleteQuery) + if err != nil { + return err + } + + insertVersion := fmt.Sprintf(`INSERT INTO %s (version, dirty)`, migrationTableName) + `VALUES (?, ?)` + _, err = tx.Exec(insertVersion, version, false) + if err != nil { + return err + } + } + + err = tx.Commit() + + return err +} + +func getMaxMigrationVersion(assetNames []string, max uint) uint { + floor := uint(0) + for _, name := range assetNames { + m, err := source.DefaultParse(name) + if err != nil { + continue // ignore files that we can't parse + } + if m.Version <= max && m.Version > floor { + floor = m.Version + } + } + return floor +} diff --git a/t/helpers/db.go b/t/helpers/db.go index 4b71f5a3658..9a88b2e737e 100644 --- a/t/helpers/db.go +++ b/t/helpers/db.go @@ -2,11 +2,16 @@ package helpers import ( "database/sql" + "fmt" "io/ioutil" "os" + "github.com/google/uuid" + bindata "github.com/status-im/migrate/v4/source/go_bindata" + "github.com/status-im/status-go/common/dbsetup" "github.com/status-im/status-go/multiaccounts" + "github.com/status-im/status-go/sqlite" ) const kdfIterationsNumberForTests = 1 @@ -80,3 +85,31 @@ func ColumnExists(db *sql.DB, tableName string, columnName string) (bool, error) return false, nil } + +type TestDBInitializer struct { + assetSources []*bindata.AssetSource +} + +func NewTestDBInitializer(assetSource []*bindata.AssetSource) TestDBInitializer { + return TestDBInitializer{ + assetSources: assetSource, + } +} + +func (dbi TestDBInitializer) Initialize(dbPath string, password string, kdfIterations int) (*sql.DB, error) { + db, err := sqlite.OpenDB(dbPath, password, kdfIterations) + if err != nil { + return nil, err + } + + for _, as := range dbi.assetSources { + err = sqlite.Migrate(db, as, sqlite.MigrateOptions{ + MigrationTableName: "status_schema_migrations_" + fmt.Sprintf("%x", uuid.New()), + }) + if err != nil { + return nil, err + } + } + + return db, nil +} diff --git a/vendor/github.com/status-im/mvds/node/epoch_persistency.go b/vendor/github.com/status-im/mvds/node/epoch_persistency.go index b89f0d0a5db..a02224eb631 100644 --- a/vendor/github.com/status-im/mvds/node/epoch_persistency.go +++ b/vendor/github.com/status-im/mvds/node/epoch_persistency.go @@ -6,15 +6,20 @@ import ( "github.com/status-im/mvds/state" ) -type epochSQLitePersistence struct { +type EpochPersistence interface { + Get(nodeID state.PeerID) (epoch int64, err error) + Set(nodeID state.PeerID, epoch int64) error +} + +type EpochSQLitePersistence struct { db *sql.DB } -func newEpochSQLitePersistence(db *sql.DB) *epochSQLitePersistence { - return &epochSQLitePersistence{db: db} +func NewEpochSQLitePersistence(db *sql.DB) *EpochSQLitePersistence { + return &EpochSQLitePersistence{db: db} } -func (p *epochSQLitePersistence) Get(nodeID state.PeerID) (epoch int64, err error) { +func (p *EpochSQLitePersistence) Get(nodeID state.PeerID) (epoch int64, err error) { row := p.db.QueryRow(`SELECT epoch FROM mvds_epoch WHERE peer_id = ?`, nodeID[:]) err = row.Scan(&epoch) if err == sql.ErrNoRows { @@ -23,7 +28,7 @@ func (p *epochSQLitePersistence) Get(nodeID state.PeerID) (epoch int64, err erro return } -func (p *epochSQLitePersistence) Set(nodeID state.PeerID, epoch int64) error { +func (p *EpochSQLitePersistence) Set(nodeID state.PeerID, epoch int64) error { _, err := p.db.Exec(` INSERT OR REPLACE INTO mvds_epoch (peer_id, epoch) VALUES (?, ?)`, nodeID[:], diff --git a/vendor/github.com/status-im/mvds/node/node.go b/vendor/github.com/status-im/mvds/node/node.go index 9c00f12d365..c080f6c5c0b 100644 --- a/vendor/github.com/status-im/mvds/node/node.go +++ b/vendor/github.com/status-im/mvds/node/node.go @@ -68,7 +68,7 @@ type Node struct { ID state.PeerID - epochPersistence *epochSQLitePersistence + epochPersistence EpochPersistence mode Mode subscription chan protobuf.Message @@ -78,8 +78,41 @@ type Node struct { logger *zap.Logger } +type Persistence interface { + MessageStore() store.MessageStore + PeersStore() peers.Persistence + StateStore() state.SyncState + EpochStore() EpochPersistence +} + +type sqlitePersistence struct { + db *sql.DB +} + +var _ Persistence = (*sqlitePersistence)(nil) + +func NewSQLitePersistence(db *sql.DB) Persistence { + return &sqlitePersistence{db: db} +} + +func (p *sqlitePersistence) MessageStore() store.MessageStore { + return store.NewPersistentMessageStore(p.db) +} + +func (p *sqlitePersistence) PeersStore() peers.Persistence { + return peers.NewSQLitePersistence(p.db) +} + +func (p *sqlitePersistence) StateStore() state.SyncState { + return state.NewPersistentSyncState(p.db) +} + +func (p *sqlitePersistence) EpochStore() EpochPersistence { + return NewEpochSQLitePersistence(p.db) +} + func NewPersistentNode( - db *sql.DB, + persistence Persistence, st transport.Transport, id state.PeerID, mode Mode, @@ -96,12 +129,12 @@ func NewPersistentNode( ID: id, ctx: ctx, cancel: cancel, - store: store.NewPersistentMessageStore(db), + store: persistence.MessageStore(), transport: st, - peers: peers.NewSQLitePersistence(db), - syncState: state.NewPersistentSyncState(db), + peers: persistence.PeersStore(), + syncState: persistence.StateStore(), payloads: newPayloads(), - epochPersistence: newEpochSQLitePersistence(db), + epochPersistence: persistence.EpochStore(), nextEpoch: nextEpoch, peerStatusChangeEvent: peerStatusChangeEvent, logger: logger.With(zap.Namespace("mvds")), diff --git a/vendor/modules.txt b/vendor/modules.txt index 0eb241ee7a5..98034d51707 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1220,7 +1220,7 @@ github.com/status-im/migrate/v4 github.com/status-im/migrate/v4/database/sqlcipher github.com/status-im/migrate/v4/internal/url github.com/status-im/migrate/v4/source/go_bindata -# github.com/status-im/mvds v0.0.27-0.20241031073756-b192c603a75d +# github.com/status-im/mvds v0.0.27-0.20251022120125-7bdc695d49c4 ## explicit; go 1.19 github.com/status-im/mvds/node github.com/status-im/mvds/node/migrations