Skip to content

Commit f96a85b

Browse files
committed
fixup! (DO NOT MERGE) crypto: allow run goolm side-by-side with libolm
Signed-off-by: Sumner Evans <[email protected]>
1 parent f4a44df commit f96a85b

File tree

11 files changed

+150
-53
lines changed

11 files changed

+150
-53
lines changed

crypto/account.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"maunium.net/go/mautrix"
1717
"maunium.net/go/mautrix/crypto/canonicaljson"
1818
"maunium.net/go/mautrix/crypto/goolm/account"
19+
"maunium.net/go/mautrix/crypto/libolm"
1920
"maunium.net/go/mautrix/crypto/olm"
2021
"maunium.net/go/mautrix/crypto/signatures"
2122
"maunium.net/go/mautrix/id"
@@ -31,7 +32,7 @@ type OlmAccount struct {
3132
}
3233

3334
func NewOlmAccount() *OlmAccount {
34-
libolmAccount, err := olm.NewAccount()
35+
libolmAccount, err := libolm.NewAccount()
3536
if err != nil {
3637
panic(err)
3738
}

crypto/decryptmegolm.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package crypto
88

99
import (
10+
"bytes"
1011
"context"
1112
"encoding/json"
1213
"errors"
@@ -203,7 +204,11 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co
203204
log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt")
204205
return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex)
205206
}
206-
firstKnown := sess.Internal.FirstKnownIndex()
207+
firstKnown := sess.InternalLibolm.FirstKnownIndex()
208+
firstKnownGoolm := sess.InternalGoolm.FirstKnownIndex()
209+
if firstKnown != firstKnownGoolm {
210+
panic(fmt.Sprintf("firstKnown not the same %d != %d", firstKnown, firstKnownGoolm))
211+
}
207212
log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger()
208213
if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
209214
log.Debug().Err(err).Msg("Failed to check if message index is duplicate")
@@ -228,7 +233,16 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
228233
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
229234
return sess, nil, 0, SenderKeyMismatch
230235
}
231-
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
236+
plaintextGoolm, messageIndexGoolm, errGoolm := sess.InternalGoolm.Decrypt(content.MegolmCiphertext)
237+
plaintext, messageIndex, err := sess.InternalLibolm.Decrypt(content.MegolmCiphertext)
238+
if !bytes.Equal(plaintextGoolm, plaintext) {
239+
panic("plaintext different")
240+
} else if messageIndexGoolm != messageIndex {
241+
panic(fmt.Sprintf("message index different %d != %d", messageIndexGoolm, messageIndex))
242+
} else if err != nil && errGoolm == nil {
243+
panic(fmt.Sprintf("goolm didn't error %v", err))
244+
}
245+
232246
if err != nil {
233247
if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
234248
messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content)
@@ -277,7 +291,11 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
277291
if len(sess.RatchetSafety.MissedIndices) > 0 {
278292
ratchetTargetIndex = uint32(sess.RatchetSafety.MissedIndices[0])
279293
}
280-
ratchetCurrentIndex := sess.Internal.FirstKnownIndex()
294+
ratchetCurrentIndexGoolm := sess.InternalGoolm.FirstKnownIndex()
295+
ratchetCurrentIndex := sess.InternalLibolm.FirstKnownIndex()
296+
if ratchetCurrentIndexGoolm != ratchetCurrentIndex {
297+
panic(fmt.Sprintf("ratchet current index different %d != %d", ratchetCurrentIndexGoolm, ratchetCurrentIndex))
298+
}
281299
log := zerolog.Ctx(ctx).With().
282300
Uint32("prev_ratchet_index", ratchetCurrentIndex).
283301
Uint32("new_ratchet_index", ratchetTargetIndex).

crypto/keybackup.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ import (
99

1010
"maunium.net/go/mautrix"
1111
"maunium.net/go/mautrix/crypto/backup"
12-
"maunium.net/go/mautrix/crypto/olm"
12+
"maunium.net/go/mautrix/crypto/goolm/session"
13+
"maunium.net/go/mautrix/crypto/libolm"
1314
"maunium.net/go/mautrix/crypto/signatures"
1415
"maunium.net/go/mautrix/id"
1516
)
@@ -144,7 +145,12 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.
144145
return nil, fmt.Errorf("ignoring room key in backup with weird algorithm %s", keyBackupData.Algorithm)
145146
}
146147

147-
igsInternal, err := olm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey))
148+
igsInternalGoolm, err := session.NewMegolmInboundSessionFromExport([]byte(keyBackupData.SessionKey))
149+
if err != nil {
150+
return nil, err
151+
}
152+
153+
igsInternal, err := libolm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey))
148154
if err != nil {
149155
return nil, fmt.Errorf("failed to import inbound group session: %w", err)
150156
} else if igsInternal.ID() != sessionID {
@@ -169,7 +175,8 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.
169175
}
170176

171177
igs := &InboundGroupSession{
172-
Internal: igsInternal,
178+
InternalLibolm: igsInternal,
179+
InternalGoolm: igsInternalGoolm,
173180
SigningKey: keyBackupData.SenderClaimedKeys.Ed25519,
174181
SenderKey: keyBackupData.SenderKey,
175182
RoomID: roomID,

crypto/keyexport.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"fmt"
2020
"math"
2121

22+
"go.mau.fi/util/exerrors"
2223
"go.mau.fi/util/random"
2324
"golang.org/x/crypto/pbkdf2"
2425

@@ -81,10 +82,14 @@ func makeExportKeys(passphrase string) (encryptionKey, hashKey, salt, iv []byte)
8182
func exportSessions(sessions []*InboundGroupSession) ([]ExportedSession, error) {
8283
export := make([]ExportedSession, len(sessions))
8384
for i, session := range sessions {
84-
key, err := session.Internal.Export(session.Internal.FirstKnownIndex())
85+
key, err := session.InternalLibolm.Export(session.InternalLibolm.FirstKnownIndex())
8586
if err != nil {
8687
return nil, fmt.Errorf("failed to export session: %w", err)
8788
}
89+
keyGoolm := exerrors.Must(session.InternalGoolm.Export(session.InternalGoolm.FirstKnownIndex()))
90+
if !bytes.Equal(key, keyGoolm) {
91+
panic("keys not equal")
92+
}
8893
export[i] = ExportedSession{
8994
Algorithm: id.AlgorithmMegolmV1,
9095
ForwardingChains: session.ForwardingChains,

crypto/keyimport.go

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ import (
2020
"fmt"
2121
"time"
2222

23-
"maunium.net/go/mautrix/crypto/olm"
23+
"go.mau.fi/util/exerrors"
24+
25+
"maunium.net/go/mautrix/crypto/goolm/session"
26+
"maunium.net/go/mautrix/crypto/libolm"
2427
"maunium.net/go/mautrix/id"
2528
)
2629

@@ -92,38 +95,42 @@ func decryptKeyExport(passphrase string, exportData []byte) ([]ExportedSession,
9295
return sessionsJSON, nil
9396
}
9497

95-
func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session ExportedSession) (bool, error) {
96-
if session.Algorithm != id.AlgorithmMegolmV1 {
98+
func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, sess ExportedSession) (bool, error) {
99+
if sess.Algorithm != id.AlgorithmMegolmV1 {
97100
return false, ErrInvalidExportedAlgorithm
98101
}
99102

100-
igsInternal, err := olm.InboundGroupSessionImport([]byte(session.SessionKey))
103+
igsInternal, err := libolm.InboundGroupSessionImport([]byte(sess.SessionKey))
101104
if err != nil {
102105
return false, fmt.Errorf("failed to import session: %w", err)
103-
} else if igsInternal.ID() != session.SessionID {
106+
} else if igsInternal.ID() != sess.SessionID {
104107
return false, ErrMismatchingExportedSessionID
105108
}
106109
igs := &InboundGroupSession{
107-
Internal: igsInternal,
108-
SigningKey: session.SenderClaimedKeys.Ed25519,
109-
SenderKey: session.SenderKey,
110-
RoomID: session.RoomID,
110+
InternalLibolm: igsInternal,
111+
InternalGoolm: exerrors.Must(session.NewMegolmInboundSessionFromExport([]byte(sess.SessionKey))),
112+
SigningKey: sess.SenderClaimedKeys.Ed25519,
113+
SenderKey: sess.SenderKey,
114+
RoomID: sess.RoomID,
111115
// TODO should we add something here to mark the signing key as unverified like key requests do?
112-
ForwardingChains: session.ForwardingChains,
116+
ForwardingChains: sess.ForwardingChains,
113117

114118
ReceivedAt: time.Now().UTC(),
115119
}
116120
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
117-
firstKnownIndex := igs.Internal.FirstKnownIndex()
118-
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex {
121+
firstKnownIndex := igs.InternalLibolm.FirstKnownIndex()
122+
if firstKnownIndex != igs.InternalGoolm.FirstKnownIndex() {
123+
panic("indexes different")
124+
}
125+
if existingIGS != nil && existingIGS.InternalLibolm.FirstKnownIndex() <= firstKnownIndex {
119126
// We already have an equivalent or better session in the store, so don't override it.
120127
return false, nil
121128
}
122129
err = mach.CryptoStore.PutGroupSession(ctx, igs)
123130
if err != nil {
124131
return false, fmt.Errorf("failed to store imported session: %w", err)
125132
}
126-
mach.markSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex)
133+
mach.markSessionReceived(ctx, sess.RoomID, igs.ID(), firstKnownIndex)
127134
return true, nil
128135
}
129136

crypto/keysharing.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@
88
package crypto
99

1010
import (
11+
"bytes"
1112
"context"
1213
"errors"
1314
"time"
1415

1516
"github.com/rs/zerolog"
17+
"go.mau.fi/util/exerrors"
1618

19+
"maunium.net/go/mautrix/crypto/goolm/session"
1720
"maunium.net/go/mautrix/crypto/olm"
1821
"maunium.net/go/mautrix/id"
1922

@@ -178,7 +181,8 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
178181
log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session")
179182
}
180183
igs := &InboundGroupSession{
181-
Internal: igsInternal,
184+
InternalLibolm: igsInternal,
185+
InternalGoolm: exerrors.Must(session.NewMegolmInboundSessionFromExport([]byte(content.SessionKey))),
182186
SigningKey: evt.Keys.Ed25519,
183187
SenderKey: content.SenderKey,
184188
RoomID: content.RoomID,
@@ -191,7 +195,10 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
191195
IsScheduled: content.IsScheduled,
192196
}
193197
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
194-
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
198+
if igs.InternalLibolm.FirstKnownIndex() != igs.InternalGoolm.FirstKnownIndex() {
199+
panic("different indices")
200+
}
201+
if existingIGS != nil && existingIGS.InternalLibolm.FirstKnownIndex() <= igs.InternalLibolm.FirstKnownIndex() {
195202
// We already have an equivalent or better session in the store, so don't override it.
196203
return false
197204
}
@@ -339,14 +346,21 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User
339346
log = log.With().Stringer("unexpected_session_id", internalID).Logger()
340347
}
341348

342-
firstKnownIndex := igs.Internal.FirstKnownIndex()
349+
firstKnownIndex := igs.InternalLibolm.FirstKnownIndex()
350+
if igs.InternalLibolm.FirstKnownIndex() != igs.InternalGoolm.FirstKnownIndex() {
351+
panic("different indices")
352+
}
343353
log = log.With().Uint32("first_known_index", firstKnownIndex).Logger()
344-
exportedKey, err := igs.Internal.Export(firstKnownIndex)
354+
exportedKey, err := igs.InternalLibolm.Export(firstKnownIndex)
345355
if err != nil {
346356
log.Error().Err(err).Msg("Failed to export group session to forward")
347357
mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body)
348358
return
349359
}
360+
exportedKeyGoolm, err := igs.InternalGoolm.Export(firstKnownIndex)
361+
if !bytes.Equal(exportedKey, exportedKeyGoolm) {
362+
panic("keys different")
363+
}
350364
if igs.ForwardingChains == nil {
351365
igs.ForwardingChains = []string{}
352366
}

crypto/machine.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,10 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
579579
log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
580580
return fmt.Errorf("failed to store new inbound group session: %w", err)
581581
}
582-
mach.markSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex())
582+
if igs.InternalLibolm.FirstKnownIndex() != igs.InternalGoolm.FirstKnownIndex() {
583+
panic("different index")
584+
}
585+
mach.markSessionReceived(ctx, roomID, sessionID, igs.InternalLibolm.FirstKnownIndex())
583586
log.Debug().
584587
Str("session_id", sessionID.String()).
585588
Str("sender_key", senderKey.String()).

crypto/machine_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,12 @@ func TestRatchetMegolmSession(t *testing.T) {
6060
assert.NoError(t, err)
6161
inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", outSess.ID())
6262
require.NoError(t, err)
63-
assert.Equal(t, uint32(0), inSess.Internal.FirstKnownIndex())
63+
assert.Equal(t, uint32(0), inSess.InternalLibolm.FirstKnownIndex())
64+
assert.Equal(t, uint32(0), inSess.InternalGoolm.FirstKnownIndex())
6465
err = inSess.RatchetTo(10)
6566
assert.NoError(t, err)
66-
assert.Equal(t, uint32(10), inSess.Internal.FirstKnownIndex())
67+
assert.Equal(t, uint32(10), inSess.InternalLibolm.FirstKnownIndex())
68+
assert.Equal(t, uint32(10), inSess.InternalGoolm.FirstKnownIndex())
6769
}
6870

6971
func TestOlmMachineOlmMegolmSessions(t *testing.T) {

crypto/sessions.go

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@ package crypto
99
import (
1010
"bytes"
1111
"errors"
12+
"fmt"
1213
"time"
1314

1415
"go.mau.fi/util/exerrors"
1516

17+
"maunium.net/go/mautrix/crypto/goolm/session"
18+
"maunium.net/go/mautrix/crypto/libolm"
1619
"maunium.net/go/mautrix/crypto/olm"
1720
"maunium.net/go/mautrix/event"
1821
"maunium.net/go/mautrix/id"
@@ -107,7 +110,8 @@ type RatchetSafety struct {
107110
}
108111

109112
type InboundGroupSession struct {
110-
Internal olm.InboundGroupSession
113+
InternalLibolm olm.InboundGroupSession
114+
InternalGoolm olm.InboundGroupSession
111115

112116
SigningKey id.Ed25519
113117
SenderKey id.Curve25519
@@ -126,12 +130,17 @@ type InboundGroupSession struct {
126130
}
127131

128132
func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string, maxAge time.Duration, maxMessages int, isScheduled bool) (*InboundGroupSession, error) {
129-
igs, err := olm.NewInboundGroupSession([]byte(sessionKey))
133+
igs, err := libolm.NewInboundGroupSession([]byte(sessionKey))
134+
if err != nil {
135+
return nil, err
136+
}
137+
igsGoolm, err := session.NewMegolmInboundSession([]byte(sessionKey))
130138
if err != nil {
131139
return nil, err
132140
}
133141
return &InboundGroupSession{
134-
Internal: igs,
142+
InternalLibolm: igs,
143+
InternalGoolm: igsGoolm,
135144
SigningKey: signingKey,
136145
SenderKey: senderKey,
137146
RoomID: roomID,
@@ -145,22 +154,31 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI
145154

146155
func (igs *InboundGroupSession) ID() id.SessionID {
147156
if igs.id == "" {
148-
igs.id = igs.Internal.ID()
157+
igs.id = igs.InternalLibolm.ID()
158+
if igs.id != igs.InternalGoolm.ID() {
159+
panic(fmt.Sprintf("id different %s %s", igs.id, igs.InternalGoolm.ID()))
160+
}
149161
}
150162
return igs.id
151163
}
152164

153165
func (igs *InboundGroupSession) RatchetTo(index uint32) error {
154-
exported, err := igs.Internal.Export(index)
166+
exported, err := igs.InternalLibolm.Export(index)
155167
if err != nil {
156168
return err
157169
}
158-
imported, err := olm.InboundGroupSessionImport(exported)
170+
exportedGoolm, err := igs.InternalGoolm.Export(index)
171+
if err != nil {
172+
panic(err)
173+
} else if !bytes.Equal(exported, exportedGoolm) {
174+
panic("bytes not equal")
175+
}
176+
igs.InternalLibolm, err = libolm.InboundGroupSessionImport(exported)
159177
if err != nil {
160178
return err
161179
}
162-
igs.Internal = imported
163-
return nil
180+
igs.InternalGoolm, err = session.NewMegolmInboundSessionFromExport(exportedGoolm)
181+
return err
164182
}
165183

166184
type OGSState int

0 commit comments

Comments
 (0)