diff --git a/crypto/account.go b/crypto/account.go index 0bd09ecf..d105f433 100644 --- a/crypto/account.go +++ b/crypto/account.go @@ -7,19 +7,24 @@ package crypto import ( + "bytes" "encoding/json" + "fmt" "github.com/tidwall/sjson" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/crypto/goolm/account" + "maunium.net/go/mautrix/crypto/libolm" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/id" ) type OlmAccount struct { - Internal olm.Account + InternalLibolm olm.Account + InternalGoolm olm.Account signingKey id.SigningKey identityKey id.IdentityKey Shared bool @@ -27,22 +32,41 @@ type OlmAccount struct { } func NewOlmAccount() *OlmAccount { - account, err := olm.NewAccount() + libolmAccount, err := libolm.NewAccount() + if err != nil { + panic(err) + } + pickled, err := libolmAccount.Pickle([]byte("key")) + if err != nil { + panic(err) + } + goolmAccount, err := account.AccountFromPickled(pickled, []byte("key")) if err != nil { panic(err) } return &OlmAccount{ - Internal: account, + InternalLibolm: libolmAccount, + InternalGoolm: goolmAccount, } } func (account *OlmAccount) Keys() (id.SigningKey, id.IdentityKey) { if len(account.signingKey) == 0 || len(account.identityKey) == 0 { var err error - account.signingKey, account.identityKey, err = account.Internal.IdentityKeys() + account.signingKey, account.identityKey, err = account.InternalLibolm.IdentityKeys() + if err != nil { + panic(err) + } + goolmSigningKey, goolmIdentityKey, err := account.InternalGoolm.IdentityKeys() if err != nil { panic(err) } + if account.signingKey != goolmSigningKey { + panic("account signing keys not equal") + } + if account.identityKey != goolmIdentityKey { + panic("account identity keys not equal") + } } return account.signingKey, account.identityKey } @@ -50,10 +74,20 @@ func (account *OlmAccount) Keys() (id.SigningKey, id.IdentityKey) { func (account *OlmAccount) SigningKey() id.SigningKey { if len(account.signingKey) == 0 { var err error - account.signingKey, account.identityKey, err = account.Internal.IdentityKeys() + account.signingKey, account.identityKey, err = account.InternalLibolm.IdentityKeys() + if err != nil { + panic(err) + } + goolmSigningKey, goolmIdentityKey, err := account.InternalGoolm.IdentityKeys() if err != nil { panic(err) } + if account.signingKey != goolmSigningKey { + panic("account signing keys not equal") + } + if account.identityKey != goolmIdentityKey { + panic("account identity keys not equal") + } } return account.signingKey } @@ -61,10 +95,20 @@ func (account *OlmAccount) SigningKey() id.SigningKey { func (account *OlmAccount) IdentityKey() id.IdentityKey { if len(account.identityKey) == 0 { var err error - account.signingKey, account.identityKey, err = account.Internal.IdentityKeys() + account.signingKey, account.identityKey, err = account.InternalLibolm.IdentityKeys() + if err != nil { + panic(err) + } + goolmSigningKey, goolmIdentityKey, err := account.InternalGoolm.IdentityKeys() if err != nil { panic(err) } + if account.signingKey != goolmSigningKey { + panic("account signing keys not equal") + } + if account.identityKey != goolmIdentityKey { + panic("account identity keys not equal") + } } return account.identityKey } @@ -78,7 +122,15 @@ func (account *OlmAccount) SignJSON(obj any) (string, error) { } objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") - signed, err := account.Internal.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON)) + signed, err := account.InternalLibolm.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON)) + goolmSigned, goolmErr := account.InternalGoolm.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON)) + if err != nil { + if goolmErr == nil { + panic("libolm errored, but goolm did not on account.SignJSON") + } + } else if !bytes.Equal(signed, goolmSigned) { + panic("libolm and goolm signed are not equal in account.SignJSON") + } return string(signed), err } @@ -102,19 +154,36 @@ func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID return deviceKeys } -func (account *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID, currentOTKCount int) map[id.KeyID]mautrix.OneTimeKey { - newCount := int(account.Internal.MaxNumberOfOneTimeKeys()/2) - currentOTKCount +func (a *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID, currentOTKCount int) map[id.KeyID]mautrix.OneTimeKey { + newCount := int(a.InternalLibolm.MaxNumberOfOneTimeKeys()/2) - currentOTKCount if newCount > 0 { - account.Internal.GenOneTimeKeys(uint(newCount)) + a.InternalLibolm.GenOneTimeKeys(uint(newCount)) + + pickled, err := a.InternalLibolm.Pickle([]byte("key")) + if err != nil { + panic(err) + } + a.InternalGoolm, err = account.AccountFromPickled(pickled, []byte("key")) + if err != nil { + panic(err) + } } oneTimeKeys := make(map[id.KeyID]mautrix.OneTimeKey) - internalKeys, err := account.Internal.OneTimeKeys() + internalKeys, err := a.InternalLibolm.OneTimeKeys() + if err != nil { + panic(err) + } + goolmInternalKeys, err := a.InternalGoolm.OneTimeKeys() if err != nil { panic(err) } for keyID, key := range internalKeys { + if goolmInternalKeys[keyID] != key { + panic(fmt.Sprintf("key %s not found in getOneTimeKeys", keyID)) + } + key := mautrix.OneTimeKey{Key: key} - signature, _ := account.SignJSON(key) + signature, _ := a.SignJSON(key) key.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, deviceID.String(), signature) key.IsSigned = true oneTimeKeys[id.NewKeyID(id.KeyAlgorithmSignedCurve25519, keyID)] = key diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 47279474..f6fe0bad 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -7,6 +7,7 @@ package crypto import ( + "bytes" "context" "encoding/json" "errors" @@ -203,7 +204,11 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt") return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex) } - firstKnown := sess.Internal.FirstKnownIndex() + firstKnown := sess.InternalLibolm.FirstKnownIndex() + firstKnownGoolm := sess.InternalGoolm.FirstKnownIndex() + if firstKnown != firstKnownGoolm { + panic(fmt.Sprintf("firstKnown not the same %d != %d", firstKnown, firstKnownGoolm)) + } log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger() if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { log.Debug().Err(err).Msg("Failed to check if message index is duplicate") @@ -228,7 +233,16 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve } else if content.SenderKey != "" && content.SenderKey != sess.SenderKey { return sess, nil, 0, SenderKeyMismatch } - plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext) + plaintextGoolm, messageIndexGoolm, errGoolm := sess.InternalGoolm.Decrypt(content.MegolmCiphertext) + plaintext, messageIndex, err := sess.InternalLibolm.Decrypt(content.MegolmCiphertext) + if !bytes.Equal(plaintextGoolm, plaintext) { + panic("plaintext different") + } else if messageIndexGoolm != messageIndex { + panic(fmt.Sprintf("message index different %d != %d", messageIndexGoolm, messageIndex)) + } else if err != nil && errGoolm == nil { + panic(fmt.Sprintf("goolm didn't error %v", err)) + } + if err != nil { if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt { messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content) @@ -277,7 +291,11 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve if len(sess.RatchetSafety.MissedIndices) > 0 { ratchetTargetIndex = uint32(sess.RatchetSafety.MissedIndices[0]) } - ratchetCurrentIndex := sess.Internal.FirstKnownIndex() + ratchetCurrentIndexGoolm := sess.InternalGoolm.FirstKnownIndex() + ratchetCurrentIndex := sess.InternalLibolm.FirstKnownIndex() + if ratchetCurrentIndexGoolm != ratchetCurrentIndex { + panic(fmt.Sprintf("ratchet current index different %d != %d", ratchetCurrentIndexGoolm, ratchetCurrentIndex)) + } log := zerolog.Ctx(ctx).With(). Uint32("prev_ratchet_index", ratchetCurrentIndex). Uint32("new_ratchet_index", ratchetTargetIndex). diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go index 80b76dc5..c4a66d54 100644 --- a/crypto/encryptolm.go +++ b/crypto/encryptolm.go @@ -180,9 +180,17 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id log.Error().Err(err).Msg("Failed to verify signature of one-time key") } else if !ok { log.Warn().Msg("One-time key has invalid signature from device") - } else if sess, err := mach.account.Internal.NewOutboundSession(identity.IdentityKey, oneTimeKey.Key); err != nil { + } else if sess, err := mach.account.InternalLibolm.NewOutboundSession(identity.IdentityKey, oneTimeKey.Key); err != nil { log.Error().Err(err).Msg("Failed to create outbound session with claimed one-time key") } else { + goolmSess, err := mach.account.InternalGoolm.NewOutboundSession(identity.IdentityKey, oneTimeKey.Key) + if err != nil { + panic("goolm NewOutboundSession errored") + } + if sess.Describe() != goolmSess.Describe() { + panic("goolm NewOutboundSession and libolm NewOutboundSession returned different values") + } + wrapped := wrapSession(sess) err = mach.CryptoStore.AddSession(ctx, identity.IdentityKey, wrapped) if err != nil { diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index 4da08a73..7ee653ae 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -336,8 +336,10 @@ func (a *Account) UnpickleLibOlm(buf []byte) error { } else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 { return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion) } else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair + fmt.Printf("123 %+v\n", err) return err } else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair + fmt.Printf("456 %+v\n", err) return err } diff --git a/crypto/goolm/libolmpickle/pickle.go b/crypto/goolm/libolmpickle/pickle.go index d15358fd..5be9e381 100644 --- a/crypto/goolm/libolmpickle/pickle.go +++ b/crypto/goolm/libolmpickle/pickle.go @@ -28,11 +28,11 @@ func Pickle(key, plaintext []byte) ([]byte, error) { // Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256. func Unpickle(key, input []byte) ([]byte, error) { - ciphertext, err := goolmbase64.Decode(input) + decoded, err := goolmbase64.Decode(input) if err != nil { return nil, err } - ciphertext, mac := ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:] + ciphertext, mac := decoded[:len(decoded)-pickleMACLength], decoded[len(decoded)-pickleMACLength:] if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil { return nil, err } else if verified, err := c.VerifyMAC(ciphertext, mac); err != nil { diff --git a/crypto/keybackup.go b/crypto/keybackup.go index d8b3d715..9558125f 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -11,7 +11,8 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/backup" - "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/libolm" "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -175,7 +176,12 @@ func (mach *OlmMachine) ImportRoomKeyFromBackupWithoutSaving( return nil, fmt.Errorf("%w %s", ErrUnknownAlgorithmInKeyBackup, keyBackupData.Algorithm) } - igsInternal, err := olm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey)) + igsInternalGoolm, err := session.NewMegolmInboundSessionFromExport([]byte(keyBackupData.SessionKey)) + if err != nil { + return nil, err + } + + igsInternal, err := libolm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey)) if err != nil { return nil, fmt.Errorf("failed to import inbound group session: %w", err) } else if igsInternal.ID() != sessionID { @@ -194,8 +200,14 @@ func (mach *OlmMachine) ImportRoomKeyFromBackupWithoutSaving( maxMessages = config.RotationPeriodMessages } + firstKnownIndex := igsInternal.FirstKnownIndex() + if firstKnownIndex > 0 { + log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session") + } + return &InboundGroupSession{ - Internal: igsInternal, + InternalLibolm: igsInternal, + InternalGoolm: igsInternalGoolm, SigningKey: keyBackupData.SenderClaimedKeys.Ed25519, SenderKey: keyBackupData.SenderKey, RoomID: roomID, @@ -221,7 +233,7 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. if err != nil { return nil, err } - firstKnownIndex := imported.Internal.FirstKnownIndex() + firstKnownIndex := imported.InternalLibolm.FirstKnownIndex() if firstKnownIndex > 0 { zerolog.Ctx(ctx).Warn(). Stringer("room_id", roomID). @@ -229,6 +241,10 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. Uint32("first_known_index", firstKnownIndex). Msg("Importing partial session") } + if firstKnownIndex != imported.InternalGoolm.FirstKnownIndex() { + panic("Goolm and libolm first known index mismatch") + } + err = mach.CryptoStore.PutGroupSession(ctx, imported) if err != nil { return nil, fmt.Errorf("%w: %w", ErrFailedToStoreNewInboundGroupSessionFromBackup, err) diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 36ad6b9c..b135c0e4 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -20,7 +20,10 @@ import ( "fmt" "time" - "maunium.net/go/mautrix/crypto/olm" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/libolm" "maunium.net/go/mautrix/id" ) @@ -96,30 +99,34 @@ func decryptKeyExport(passphrase string, exportData []byte) ([]ExportedSession, return sessionsJSON, nil } -func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session ExportedSession) (bool, error) { - if session.Algorithm != id.AlgorithmMegolmV1 { +func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, sess ExportedSession) (bool, error) { + if sess.Algorithm != id.AlgorithmMegolmV1 { return false, ErrInvalidExportedAlgorithm } - igsInternal, err := olm.InboundGroupSessionImport([]byte(session.SessionKey)) + igsInternal, err := libolm.InboundGroupSessionImport([]byte(sess.SessionKey)) if err != nil { return false, fmt.Errorf("failed to import session: %w", err) - } else if igsInternal.ID() != session.SessionID { + } else if igsInternal.ID() != sess.SessionID { return false, ErrMismatchingExportedSessionID } igs := &InboundGroupSession{ - Internal: igsInternal, - SigningKey: session.SenderClaimedKeys.Ed25519, - SenderKey: session.SenderKey, - RoomID: session.RoomID, + InternalLibolm: igsInternal, + InternalGoolm: exerrors.Must(session.NewMegolmInboundSessionFromExport([]byte(sess.SessionKey))), + SigningKey: sess.SenderClaimedKeys.Ed25519, + SenderKey: sess.SenderKey, + RoomID: sess.RoomID, // TODO should we add something here to mark the signing key as unverified like key requests do? - ForwardingChains: session.ForwardingChains, + ForwardingChains: sess.ForwardingChains, ReceivedAt: time.Now().UTC(), } existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) - firstKnownIndex := igs.Internal.FirstKnownIndex() - if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex { + firstKnownIndex := igs.InternalLibolm.FirstKnownIndex() + if firstKnownIndex != igs.InternalGoolm.FirstKnownIndex() { + panic("indexes different") + } + if existingIGS != nil && existingIGS.InternalLibolm.FirstKnownIndex() <= firstKnownIndex { // We already have an equivalent or better session in the store, so don't override it. return false, nil } @@ -127,7 +134,7 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor if err != nil { return false, fmt.Errorf("failed to store imported session: %w", err) } - mach.MarkSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex) + mach.MarkSessionReceived(ctx, sess.RoomID, igs.ID(), firstKnownIndex) return true, nil } diff --git a/crypto/keysharing.go b/crypto/keysharing.go index f1d427af..768895e3 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -8,12 +8,15 @@ package crypto import ( + "bytes" "context" "errors" "time" "github.com/rs/zerolog" + "go.mau.fi/util/exerrors" + "maunium.net/go/mautrix/crypto/goolm/session" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" @@ -178,7 +181,8 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session") } igs := &InboundGroupSession{ - Internal: igsInternal, + InternalLibolm: igsInternal, + InternalGoolm: exerrors.Must(session.NewMegolmInboundSessionFromExport([]byte(content.SessionKey))), SigningKey: evt.Keys.Ed25519, SenderKey: content.SenderKey, RoomID: content.RoomID, @@ -191,7 +195,10 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt IsScheduled: content.IsScheduled, } existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) - if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { + if igs.InternalLibolm.FirstKnownIndex() != igs.InternalGoolm.FirstKnownIndex() { + panic("different indices") + } + if existingIGS != nil && existingIGS.InternalLibolm.FirstKnownIndex() <= igs.InternalLibolm.FirstKnownIndex() { // We already have an equivalent or better session in the store, so don't override it. return false } @@ -339,14 +346,24 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User log = log.With().Stringer("unexpected_session_id", internalID).Logger() } - firstKnownIndex := igs.Internal.FirstKnownIndex() + firstKnownIndex := igs.InternalLibolm.FirstKnownIndex() + if igs.InternalLibolm.FirstKnownIndex() != igs.InternalGoolm.FirstKnownIndex() { + panic("different indices") + } log = log.With().Uint32("first_known_index", firstKnownIndex).Logger() - exportedKey, err := igs.Internal.Export(firstKnownIndex) + exportedKey, err := igs.InternalLibolm.Export(firstKnownIndex) if err != nil { log.Error().Err(err).Msg("Failed to export group session to forward") mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body) return } + exportedKeyGoolm, err := igs.InternalGoolm.Export(firstKnownIndex) + if !bytes.Equal(exportedKey, exportedKeyGoolm) { + panic("keys different") + } + if igs.ForwardingChains == nil { + igs.ForwardingChains = []string{} + } forwardedRoomKey := event.Content{ Parsed: &event.ForwardedRoomKeyEventContent{ diff --git a/crypto/machine.go b/crypto/machine.go index cac91bf8..d15dbc60 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -282,7 +282,7 @@ func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.O mach.receivedOTKsForSelf.Store(true) } - minCount := mach.account.Internal.MaxNumberOfOneTimeKeys() / 2 + minCount := mach.account.InternalLibolm.MaxNumberOfOneTimeKeys() / 2 if otkCount.SignedCurve25519 < int(minCount) { traceID := time.Now().Format("15:04:05.000000") log := mach.Log.With().Str("trace_id", traceID).Logger() @@ -584,7 +584,10 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") return fmt.Errorf("failed to store new inbound group session: %w", err) } - mach.MarkSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex()) + if igs.InternalLibolm.FirstKnownIndex() != igs.InternalGoolm.FirstKnownIndex() { + panic("different index") + } + mach.MarkSessionReceived(ctx, roomID, sessionID, igs.InternalLibolm.FirstKnownIndex()) log.Debug(). Str("session_id", sessionID.String()). Str("sender_key", senderKey.String()). @@ -754,7 +757,8 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro return err } mach.lastOTKUpload = time.Now() - mach.account.Internal.MarkKeysAsPublished() + mach.account.InternalLibolm.MarkKeysAsPublished() + mach.account.InternalGoolm.MarkKeysAsPublished() mach.account.Shared = true return mach.saveAccount(ctx) } diff --git a/crypto/machine_test.go b/crypto/machine_test.go index 59c86236..63219d94 100644 --- a/crypto/machine_test.go +++ b/crypto/machine_test.go @@ -60,10 +60,12 @@ func TestRatchetMegolmSession(t *testing.T) { assert.NoError(t, err) inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", outSess.ID()) require.NoError(t, err) - assert.Equal(t, uint32(0), inSess.Internal.FirstKnownIndex()) + assert.Equal(t, uint32(0), inSess.InternalLibolm.FirstKnownIndex()) + assert.Equal(t, uint32(0), inSess.InternalGoolm.FirstKnownIndex()) err = inSess.RatchetTo(10) assert.NoError(t, err) - assert.Equal(t, uint32(10), inSess.Internal.FirstKnownIndex()) + assert.Equal(t, uint32(10), inSess.InternalLibolm.FirstKnownIndex()) + assert.Equal(t, uint32(10), inSess.InternalGoolm.FirstKnownIndex()) } func TestOlmMachineOlmMegolmSessions(t *testing.T) { @@ -78,10 +80,11 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { otk = otkTmp break } - machineIn.account.Internal.MarkKeysAsPublished() + machineIn.account.InternalLibolm.MarkKeysAsPublished() + machineIn.account.InternalGoolm.MarkKeysAsPublished() // create outbound olm session for sending machine using OTK - olmSession, err := machineOut.account.Internal.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key) + olmSession, err := machineOut.account.InternalLibolm.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key) if err != nil { t.Errorf("Failed to create outbound olm session: %v", err) } diff --git a/crypto/sessions.go b/crypto/sessions.go index aecb0416..34f0e75b 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -7,13 +7,17 @@ package crypto import ( + "bytes" "errors" "fmt" "time" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/libolm" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" ) @@ -69,11 +73,20 @@ func wrapSession(session olm.Session) *OlmSession { } func (account *OlmAccount) NewInboundSessionFrom(senderKey id.Curve25519, ciphertext string) (*OlmSession, error) { - session, err := account.Internal.NewInboundSessionFrom(&senderKey, ciphertext) + session, err := account.InternalLibolm.NewInboundSessionFrom(&senderKey, ciphertext) if err != nil { return nil, err } - _ = account.Internal.RemoveOneTimeKeys(session) + goolmSession, err := account.InternalGoolm.NewInboundSessionFrom(&senderKey, ciphertext) + if err != nil { + return nil, err + } + if !bytes.Equal(exerrors.Must(goolmSession.Pickle([]byte("123"))), exerrors.Must(session.Pickle([]byte("123")))) { + panic("goolm inbound session and libolm inbound session from ciphertext are different") + } + + _ = account.InternalLibolm.RemoveOneTimeKeys(session) + _ = account.InternalGoolm.RemoveOneTimeKeys(goolmSession) return wrapSession(session), nil } @@ -97,7 +110,8 @@ type RatchetSafety struct { } type InboundGroupSession struct { - Internal olm.InboundGroupSession + InternalLibolm olm.InboundGroupSession + InternalGoolm olm.InboundGroupSession SigningKey id.Ed25519 SenderKey id.Curve25519 @@ -116,12 +130,17 @@ type InboundGroupSession struct { } func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string, maxAge time.Duration, maxMessages int, isScheduled bool) (*InboundGroupSession, error) { - igs, err := olm.NewInboundGroupSession([]byte(sessionKey)) + igs, err := libolm.NewInboundGroupSession([]byte(sessionKey)) + if err != nil { + return nil, err + } + igsGoolm, err := session.NewMegolmInboundSession([]byte(sessionKey)) if err != nil { return nil, err } return &InboundGroupSession{ - Internal: igs, + InternalLibolm: igs, + InternalGoolm: igsGoolm, SigningKey: signingKey, SenderKey: senderKey, RoomID: roomID, @@ -135,29 +154,42 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI func (igs *InboundGroupSession) ID() id.SessionID { if igs.id == "" { - igs.id = igs.Internal.ID() + igs.id = igs.InternalLibolm.ID() + if igs.id != igs.InternalGoolm.ID() { + panic(fmt.Sprintf("id different %s %s", igs.id, igs.InternalGoolm.ID())) + } } return igs.id } func (igs *InboundGroupSession) RatchetTo(index uint32) error { - exported, err := igs.Internal.Export(index) + exported, err := igs.InternalLibolm.Export(index) if err != nil { return err } - imported, err := olm.InboundGroupSessionImport(exported) + exportedGoolm, err := igs.InternalGoolm.Export(index) + if err != nil { + panic(err) + } else if !bytes.Equal(exported, exportedGoolm) { + panic("bytes not equal") + } + igs.InternalLibolm, err = libolm.InboundGroupSessionImport(exported) if err != nil { return err } - igs.Internal = imported - return nil + igs.InternalGoolm, err = session.NewMegolmInboundSessionFromExport(exportedGoolm) + return err } func (igs *InboundGroupSession) export() (*ExportedSession, error) { - key, err := igs.Internal.Export(igs.Internal.FirstKnownIndex()) + key, err := igs.InternalLibolm.Export(igs.InternalLibolm.FirstKnownIndex()) if err != nil { return nil, fmt.Errorf("failed to export session: %w", err) } + keyGoolm := exerrors.Must(igs.InternalGoolm.Export(igs.InternalGoolm.FirstKnownIndex())) + if !bytes.Equal(key, keyGoolm) { + panic("keys not equal") + } return &ExportedSession{ Algorithm: id.AlgorithmMegolmV1, ForwardingChains: igs.ForwardingChains, diff --git a/crypto/sql_store.go b/crypto/sql_store.go index b0625763..8f6291cd 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -7,6 +7,7 @@ package crypto import ( + "bytes" "context" "database/sql" "database/sql/driver" @@ -20,9 +21,13 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto/goolm/account" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/libolm" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/sql_store_upgrade" "maunium.net/go/mautrix/event" @@ -128,16 +133,20 @@ func (store *SQLCryptoStore) FindDeviceID(ctx context.Context) (deviceID id.Devi // PutAccount stores an OlmAccount in the database. func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount) error { store.Account = account - bytes, err := account.Internal.Pickle(store.PickleKey) + pickled, err := account.InternalLibolm.Pickle(store.PickleKey) if err != nil { return err } + err = account.InternalGoolm.Unpickle(pickled, store.PickleKey) + if err != nil { + panic(fmt.Sprintf("failed to unpickle account using goolm: %+v", err)) + } _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id, key_backup_version) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token, account=excluded.account, account_id=excluded.account_id, key_backup_version=excluded.key_backup_version - `, store.DeviceID, account.Shared, store.SyncToken, bytes, store.AccountID, account.KeyBackupVersion) + `, store.DeviceID, account.Shared, store.SyncToken, pickled, store.AccountID, account.KeyBackupVersion) return err } @@ -145,7 +154,10 @@ func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) { if store.Account == nil { row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account, key_backup_version FROM crypto_account WHERE account_id=$1", store.AccountID) - acc := &OlmAccount{Internal: olm.NewBlankAccount()} + acc := &OlmAccount{ + InternalLibolm: olm.NewBlankAccount(), + InternalGoolm: &account.Account{}, + } var accountBytes []byte err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes, &acc.KeyBackupVersion) if err == sql.ErrNoRows { @@ -153,10 +165,14 @@ func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error } else if err != nil { return nil, err } - err = acc.Internal.Unpickle(accountBytes, store.PickleKey) + err = acc.InternalLibolm.Unpickle(bytes.Clone(accountBytes), store.PickleKey) if err != nil { return nil, err } + err = acc.InternalGoolm.Unpickle(accountBytes, store.PickleKey) + if err != nil { + panic(fmt.Sprintf("failed to unpickle account using goolm: %+v", err)) + } store.Account = acc } return store.Account, nil @@ -322,10 +338,14 @@ func datePtr(t time.Time) *time.Time { // PutGroupSession stores an inbound Megolm group session for a room, sender and session. func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *InboundGroupSession) error { - sessionBytes, err := session.Internal.Pickle(store.PickleKey) + sessionBytes, err := session.InternalLibolm.Pickle(store.PickleKey) if err != nil { return err } + sessionBytesGoolm := exerrors.Must(session.InternalGoolm.Pickle(store.PickleKey)) + if !bytes.Equal(sessionBytes, sessionBytesGoolm) { + panic("different session bytes") + } if session.ForwardingChains == nil { session.ForwardingChains = []string{} } @@ -393,12 +413,15 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room Reason: withheldReason.String, } } - igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String) + fmt.Printf("got here 1\n") + libolmIgs, goolmIgs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String) if err != nil { + fmt.Printf("got here 2 %+v\n", err) return nil, err } return &InboundGroupSession{ - Internal: igs, + InternalLibolm: libolmIgs, + InternalGoolm: goolmIgs, SigningKey: id.Ed25519(signingKey.String), SenderKey: id.Curve25519(senderKey.String), RoomID: roomID, @@ -504,12 +527,18 @@ func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID }, nil } -func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) { - igs = olm.NewBlankInboundGroupSession() - err = igs.Unpickle(sessionBytes, store.PickleKey) +func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs olm.InboundGroupSession, igsGoolm olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) { + igs = libolm.NewBlankInboundGroupSession() + err = igs.Unpickle(bytes.Clone(sessionBytes), store.PickleKey) + if err != nil { + return + } + + igsGoolm, err = session.MegolmInboundSessionFromPickled(sessionBytes, store.PickleKey) if err != nil { return } + if forwardingChains != "" { chains = strings.Split(forwardingChains, ",") } else { @@ -537,12 +566,13 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In if err != nil { return nil, err } - igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String) + igsLibolm, igsGoolm, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String) if err != nil { return nil, err } return &InboundGroupSession{ - Internal: igs, + InternalLibolm: igsLibolm, + InternalGoolm: igsGoolm, SigningKey: id.Ed25519(signingKey.String), SenderKey: id.Curve25519(senderKey.String), RoomID: roomID, diff --git a/crypto/store_test.go b/crypto/store_test.go index a7c4d75a..b5ae113d 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -13,9 +13,13 @@ import ( "testing" _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/libolm" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -153,33 +157,30 @@ func TestStoreMegolmSession(t *testing.T) { t.Run(storeName, func(t *testing.T) { acc := NewOlmAccount() - internal, err := olm.InboundGroupSessionFromPickled([]byte(groupSession), []byte("test")) - if err != nil { - t.Fatalf("Error creating internal inbound group session: %v", err) - } + internal, err := libolm.InboundGroupSessionFromPickled([]byte(groupSession), []byte("test")) + require.NoError(t, err, "Error creating internal inbound group session") + + internalGoolm, err := session.MegolmInboundSessionFromPickled([]byte(groupSession), []byte("test")) + require.NoError(t, err) igs := &InboundGroupSession{ - Internal: internal, - SigningKey: acc.SigningKey(), - SenderKey: acc.IdentityKey(), - RoomID: "room1", + InternalLibolm: internal, + InternalGoolm: internalGoolm, + SigningKey: acc.SigningKey(), + SenderKey: acc.IdentityKey(), + RoomID: "room1", } err = store.PutGroupSession(context.TODO(), igs) - if err != nil { - t.Errorf("Error storing inbound group session: %v", err) - } + require.NoError(t, err, "Error storing inbound group session") retrieved, err := store.GetGroupSession(context.TODO(), "room1", igs.ID()) - if err != nil { - t.Errorf("Error retrieving inbound group session: %v", err) - } + require.NoError(t, err, "Error retrieving inbound group session") - if pickled, err := retrieved.Internal.Pickle([]byte("test")); err != nil { - t.Fatalf("Error pickling inbound group session: %v", err) - } else if string(pickled) != groupSession { - t.Error("Pickled inbound group session does not match original") - } + pickled, err := retrieved.InternalLibolm.Pickle([]byte("test")) + require.NoError(t, err) + assert.Equal(t, string(pickled), groupSession) + assert.Equal(t, pickled, exerrors.Must(retrieved.InternalGoolm.Pickle([]byte("test")))) }) } }