mirror of
https://github.com/SchildiChat/SchildiChat-android.git
synced 2025-02-17 04:20:00 +03:00
Extract olm cache store
This commit is contained in:
parent
bcdf004082
commit
10ea166b2a
6 changed files with 575 additions and 38 deletions
|
@ -16,7 +16,16 @@
|
|||
|
||||
package org.matrix.android.sdk.account
|
||||
|
||||
import android.util.Log
|
||||
import androidx.test.filters.LargeTest
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Deferred
|
||||
import kotlinx.coroutines.SupervisorJob
|
||||
import kotlinx.coroutines.asCoroutineDispatcher
|
||||
import kotlinx.coroutines.async
|
||||
import kotlinx.coroutines.awaitAll
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.launch
|
||||
import org.junit.FixMethodOrder
|
||||
import org.junit.Ignore
|
||||
import org.junit.Test
|
||||
|
@ -28,6 +37,9 @@ import org.matrix.android.sdk.common.CommonTestHelper
|
|||
import org.matrix.android.sdk.common.CryptoTestHelper
|
||||
import org.matrix.android.sdk.common.SessionTestParams
|
||||
import org.matrix.android.sdk.common.TestConstants
|
||||
import java.util.concurrent.Executors
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
import kotlin.random.Random
|
||||
|
||||
@RunWith(JUnit4::class)
|
||||
@FixMethodOrder(MethodSorters.JVM)
|
||||
|
@ -62,4 +74,144 @@ class AccountCreationTest : InstrumentedTest {
|
|||
|
||||
res.cleanUp(commonTestHelper)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testConcurrentDecrypt() {
|
||||
// val res = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom()
|
||||
|
||||
// =============================
|
||||
// ARRANGE
|
||||
// =============================
|
||||
|
||||
val aliceSession = commonTestHelper.createAccount(TestConstants.USER_ALICE, SessionTestParams(true))
|
||||
val bobSession = commonTestHelper.createAccount(TestConstants.USER_BOB, SessionTestParams(true))
|
||||
cryptoTestHelper.initializeCrossSigning(bobSession)
|
||||
val bobSession2 = commonTestHelper.logIntoAccount(bobSession.myUserId, SessionTestParams(true))
|
||||
|
||||
bobSession2.cryptoService().verificationService().markedLocallyAsManuallyVerified(bobSession.myUserId, bobSession.sessionParams.deviceId ?: "")
|
||||
bobSession.cryptoService().verificationService().markedLocallyAsManuallyVerified(bobSession.myUserId, bobSession2.sessionParams.deviceId ?: "")
|
||||
|
||||
val roomId = cryptoTestHelper.createDM(aliceSession, bobSession)
|
||||
val roomAlicePOV = aliceSession.getRoom(roomId)!!
|
||||
|
||||
// =============================
|
||||
// ACT
|
||||
// =============================
|
||||
|
||||
val timelineEvent = commonTestHelper.sendTextMessage(roomAlicePOV, "Hello Bob", 1).first()
|
||||
val secondEvent = commonTestHelper.sendTextMessage(roomAlicePOV, "Hello Bob 2", 1).first()
|
||||
val thirdEvent = commonTestHelper.sendTextMessage(roomAlicePOV, "Hello Bob 3", 1).first()
|
||||
val forthEvent = commonTestHelper.sendTextMessage(roomAlicePOV, "Hello Bob 4", 1).first()
|
||||
|
||||
// await for bob unverified session to get the message
|
||||
commonTestHelper.waitWithLatch { latch ->
|
||||
commonTestHelper.retryPeriodicallyWithLatch(latch) {
|
||||
bobSession.getRoom(roomId)?.getTimeLineEvent(forthEvent.eventId) != null
|
||||
}
|
||||
}
|
||||
|
||||
val eventBobPOV = bobSession.getRoom(roomId)?.getTimeLineEvent(timelineEvent.eventId)!!
|
||||
val secondEventBobPOV = bobSession.getRoom(roomId)?.getTimeLineEvent(secondEvent.eventId)!!
|
||||
val thirdEventBobPOV = bobSession.getRoom(roomId)?.getTimeLineEvent(thirdEvent.eventId)!!
|
||||
val forthEventBobPOV = bobSession.getRoom(roomId)?.getTimeLineEvent(forthEvent.eventId)!!
|
||||
|
||||
// let's try to decrypt concurrently and check that we are not getting exceptions
|
||||
val dispatcher = Executors
|
||||
.newFixedThreadPool(100)
|
||||
.asCoroutineDispatcher()
|
||||
val coroutineScope = CoroutineScope(SupervisorJob() + dispatcher)
|
||||
|
||||
val eventList = listOf(eventBobPOV, secondEventBobPOV, thirdEventBobPOV, forthEventBobPOV)
|
||||
|
||||
// commonTestHelper.runBlockingTest {
|
||||
// val export = bobSession.cryptoService().exportRoomKeys("foo")
|
||||
|
||||
// }
|
||||
val atomicAsError = AtomicBoolean()
|
||||
val deff = mutableListOf<Deferred<Any>>()
|
||||
// for (i in 1..100) {
|
||||
// GlobalScope.launch {
|
||||
// val index = Random.nextInt(eventList.size)
|
||||
// try {
|
||||
// val event = eventList[index]
|
||||
// bobSession.cryptoService().decryptEvent(event.root, "")
|
||||
// Log.d("#TEST", "Decrypt Success $index :${Thread.currentThread().name}")
|
||||
// } catch (failure: Throwable) {
|
||||
// Log.d("#TEST", "Failed to decrypt $index :$failure")
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
val cryptoService = bobSession.cryptoService()
|
||||
|
||||
coroutineScope.launch {
|
||||
for (spawn in 1..100) {
|
||||
delay((Random.nextFloat() * 1000).toLong())
|
||||
aliceSession.cryptoService().requestRoomKeyForEvent(eventList.random().root)
|
||||
}
|
||||
}
|
||||
|
||||
for (spawn in 1..8000) {
|
||||
eventList.random().let { event ->
|
||||
coroutineScope.async {
|
||||
try {
|
||||
cryptoService.decryptEvent(event.root, "")
|
||||
Log.d("#TEST", "[$spawn] Decrypt Success ${event.eventId} :${Thread.currentThread().name}")
|
||||
} catch (failure: Throwable) {
|
||||
atomicAsError.set(true)
|
||||
Log.e("#TEST", "Failed to decrypt $spawn/${event.eventId} :$failure")
|
||||
}
|
||||
}.let {
|
||||
deff.add(it)
|
||||
}
|
||||
}
|
||||
// coroutineScope.async {
|
||||
// val index = Random.nextInt(eventList.size)
|
||||
// try {
|
||||
// val event = eventList[index]
|
||||
// cryptoService.decryptEvent(event.root, "")
|
||||
// for (other in eventList.indices) {
|
||||
// if (other != index) {
|
||||
// cryptoService.decryptEventAsync(eventList[other].root, "", object : MatrixCallback<MXEventDecryptionResult> {
|
||||
// override fun onFailure(failure: Throwable) {
|
||||
// Log.e("#TEST", "Failed to decrypt $spawn/$index :$failure")
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
// Log.d("#TEST", "[$spawn] Decrypt Success $index :${Thread.currentThread().name}")
|
||||
// } catch (failure: Throwable) {
|
||||
// Log.e("#TEST", "Failed to decrypt $spawn/$index :$failure")
|
||||
// }
|
||||
// }.let {
|
||||
// deff.add(it)
|
||||
// }
|
||||
}
|
||||
|
||||
coroutineScope.launch {
|
||||
for (spawn in 1..100) {
|
||||
delay((Random.nextFloat() * 1000).toLong())
|
||||
bobSession.cryptoService().requestRoomKeyForEvent(eventList.random().root)
|
||||
}
|
||||
}
|
||||
|
||||
commonTestHelper.runBlockingTest(10 * 60_000) {
|
||||
deff.awaitAll()
|
||||
delay(10_000)
|
||||
assert(!atomicAsError.get())
|
||||
// There should be no errors?
|
||||
// deff.map { it.await() }.forEach {
|
||||
// it.fold({
|
||||
// Log.d("#TEST", "Decrypt Success :${it}")
|
||||
// }, {
|
||||
// Log.d("#TEST", "Failed to decrypt :$it")
|
||||
// })
|
||||
// val hasFailure = deff.any { it.await().exceptionOrNull() != null }
|
||||
// assert(!hasFailure)
|
||||
// }
|
||||
|
||||
commonTestHelper.signOutAndClose(aliceSession)
|
||||
commonTestHelper.signOutAndClose(bobSession)
|
||||
commonTestHelper.signOutAndClose(bobSession2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,244 @@
|
|||
/*
|
||||
* Copyright (c) 2022 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.matrix.android.sdk.internal.crypto
|
||||
|
||||
import android.util.Log
|
||||
import androidx.test.filters.LargeTest
|
||||
import kotlinx.coroutines.delay
|
||||
import org.junit.Assert
|
||||
import org.junit.FixMethodOrder
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
import org.junit.runners.JUnit4
|
||||
import org.junit.runners.MethodSorters
|
||||
import org.matrix.android.sdk.InstrumentedTest
|
||||
import org.matrix.android.sdk.api.session.Session
|
||||
import org.matrix.android.sdk.api.session.crypto.MXCryptoError
|
||||
import org.matrix.android.sdk.api.session.events.model.EventType
|
||||
import org.matrix.android.sdk.api.session.events.model.toModel
|
||||
import org.matrix.android.sdk.api.session.room.Room
|
||||
import org.matrix.android.sdk.api.session.room.failure.JoinRoomFailure
|
||||
import org.matrix.android.sdk.api.session.room.model.Membership
|
||||
import org.matrix.android.sdk.api.session.room.model.message.MessageContent
|
||||
import org.matrix.android.sdk.api.session.room.send.SendState
|
||||
import org.matrix.android.sdk.api.session.room.timeline.TimelineSettings
|
||||
import org.matrix.android.sdk.common.CommonTestHelper
|
||||
import org.matrix.android.sdk.common.CryptoTestHelper
|
||||
import org.matrix.android.sdk.common.SessionTestParams
|
||||
|
||||
@RunWith(JUnit4::class)
|
||||
@FixMethodOrder(MethodSorters.JVM)
|
||||
@LargeTest
|
||||
class SimpleE2EEChatTest : InstrumentedTest {
|
||||
|
||||
private val testHelper = CommonTestHelper(context())
|
||||
private val cryptoTestHelper = CryptoTestHelper(testHelper)
|
||||
|
||||
/**
|
||||
* Simple test that create an e2ee room.
|
||||
* Some new members are added, and a message is sent.
|
||||
* We check that the message is e2e and can be decrypted.
|
||||
*
|
||||
* Additional users join, we check that they can't decrypt history
|
||||
*
|
||||
* Alice sends a new message, then check that the new one can be decrypted
|
||||
*/
|
||||
@Test
|
||||
fun testSendingE2EEMessages() {
|
||||
val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true)
|
||||
val aliceSession = cryptoTestData.firstSession
|
||||
val e2eRoomID = cryptoTestData.roomId
|
||||
|
||||
val aliceRoomPOV = aliceSession.getRoom(e2eRoomID)!!
|
||||
|
||||
// add some more users and invite them
|
||||
val otherAccounts = listOf("benoit", "valere", "ganfra") // , "adam", "manu")
|
||||
.map {
|
||||
testHelper.createAccount(it, SessionTestParams(true))
|
||||
}
|
||||
|
||||
Log.v("#E2E TEST", "All accounts created")
|
||||
// we want to invite them in the room
|
||||
otherAccounts.forEach {
|
||||
testHelper.runBlockingTest {
|
||||
Log.v("#E2E TEST", "Alice invites ${it.myUserId}")
|
||||
aliceRoomPOV.invite(it.myUserId)
|
||||
}
|
||||
}
|
||||
|
||||
// All user should accept invite
|
||||
otherAccounts.forEach { otherSession ->
|
||||
waitForAndAcceptInviteInRoom(otherSession, e2eRoomID)
|
||||
Log.v("#E2E TEST", "${otherSession.myUserId} joined room $e2eRoomID")
|
||||
}
|
||||
|
||||
// check that alice see them as joined (not really necessary?)
|
||||
ensureMembersHaveJoined(aliceSession, otherAccounts, e2eRoomID)
|
||||
|
||||
Log.v("#E2E TEST", "All users have joined the room")
|
||||
|
||||
Log.v("#E2E TEST", "Alice is sending the message")
|
||||
|
||||
val text = "This is my message"
|
||||
val sentEventId: String? = sendMessageInRoom(aliceRoomPOV, text)
|
||||
// val sentEvent = testHelper.sendTextMessage(aliceRoomPOV, "Hello all", 1).first()
|
||||
Assert.assertTrue("Message should be sent", sentEventId != null)
|
||||
|
||||
// All should be able to decrypt
|
||||
otherAccounts.forEach { otherSession ->
|
||||
testHelper.waitWithLatch { latch ->
|
||||
testHelper.retryPeriodicallyWithLatch(latch) {
|
||||
val timeLineEvent = otherSession.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId!!)
|
||||
timeLineEvent != null &&
|
||||
timeLineEvent.isEncrypted() &&
|
||||
timeLineEvent.root.getClearType() == EventType.MESSAGE
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add a new user to the room, and check that he can't decrypt
|
||||
val newAccount = listOf("adam") // , "adam", "manu")
|
||||
.map {
|
||||
testHelper.createAccount(it, SessionTestParams(true))
|
||||
}
|
||||
|
||||
newAccount.forEach {
|
||||
testHelper.runBlockingTest {
|
||||
Log.v("#E2E TEST", "Alice invites ${it.myUserId}")
|
||||
aliceRoomPOV.invite(it.myUserId)
|
||||
}
|
||||
}
|
||||
|
||||
newAccount.forEach {
|
||||
waitForAndAcceptInviteInRoom(it, e2eRoomID)
|
||||
}
|
||||
|
||||
ensureMembersHaveJoined(aliceSession, newAccount, e2eRoomID)
|
||||
|
||||
// wait a bit
|
||||
testHelper.runBlockingTest {
|
||||
delay(3_000)
|
||||
}
|
||||
|
||||
// check that messages are encrypted (uisi)
|
||||
newAccount.forEach { otherSession ->
|
||||
testHelper.waitWithLatch { latch ->
|
||||
testHelper.retryPeriodicallyWithLatch(latch) {
|
||||
val timeLineEvent = otherSession.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId!!).also {
|
||||
Log.v("#E2E TEST", "Event seen by new user ${it?.root?.getClearType()}|${it?.root?.mCryptoError}")
|
||||
}
|
||||
timeLineEvent != null &&
|
||||
timeLineEvent.root.getClearType() == EventType.ENCRYPTED &&
|
||||
timeLineEvent.root.mCryptoError == MXCryptoError.ErrorType.UNKNOWN_INBOUND_SESSION_ID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Let alice send a new message
|
||||
Log.v("#E2E TEST", "Alice sends a new message")
|
||||
|
||||
val secondMessage = "2 This is my message"
|
||||
val secondSentEventId: String? = sendMessageInRoom(aliceRoomPOV, secondMessage)
|
||||
|
||||
// new members should be able to decrypt it
|
||||
newAccount.forEach { otherSession ->
|
||||
testHelper.waitWithLatch { latch ->
|
||||
testHelper.retryPeriodicallyWithLatch(latch) {
|
||||
val timeLineEvent = otherSession.getRoom(e2eRoomID)?.getTimeLineEvent(secondSentEventId!!).also {
|
||||
Log.v("#E2E TEST", "Second Event seen by new user ${it?.root?.getClearType()}|${it?.root?.mCryptoError}")
|
||||
}
|
||||
timeLineEvent != null &&
|
||||
timeLineEvent.root.getClearType() == EventType.MESSAGE &&
|
||||
secondMessage.equals(timeLineEvent.root.getClearContent().toModel<MessageContent>()?.body)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
otherAccounts.forEach {
|
||||
testHelper.signOutAndClose(it)
|
||||
}
|
||||
newAccount.forEach { testHelper.signOutAndClose(it) }
|
||||
|
||||
cryptoTestData.cleanUp(testHelper)
|
||||
}
|
||||
|
||||
private fun sendMessageInRoom(aliceRoomPOV: Room, text: String): String? {
|
||||
aliceRoomPOV.sendTextMessage(text)
|
||||
var sentEventId: String? = null
|
||||
testHelper.waitWithLatch(4 * 60_000) {
|
||||
val timeline = aliceRoomPOV.createTimeline(null, TimelineSettings(60))
|
||||
timeline.start()
|
||||
|
||||
testHelper.retryPeriodicallyWithLatch(it) {
|
||||
val decryptedMsg = timeline.getSnapshot()
|
||||
.filter { it.root.getClearType() == EventType.MESSAGE }
|
||||
.also {
|
||||
Log.v("#E2E TEST", "Timeline snapshot is ${it.map { "${it.root.type}|${it.root.sendState}" }.joinToString(",", "[", "]")}")
|
||||
}
|
||||
.filter { it.root.sendState == SendState.SYNCED }
|
||||
.firstOrNull { it.root.getClearContent().toModel<MessageContent>()?.body?.startsWith(text) == true }
|
||||
sentEventId = decryptedMsg?.eventId
|
||||
decryptedMsg != null
|
||||
}
|
||||
|
||||
timeline.dispose()
|
||||
}
|
||||
return sentEventId
|
||||
}
|
||||
|
||||
private fun ensureMembersHaveJoined(aliceSession: Session, otherAccounts: List<Session>, e2eRoomID: String) {
|
||||
testHelper.waitWithLatch {
|
||||
testHelper.retryPeriodicallyWithLatch(it) {
|
||||
otherAccounts.map {
|
||||
aliceSession.getRoomMember(it.myUserId, e2eRoomID)?.membership
|
||||
}.all {
|
||||
it == Membership.JOIN
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun waitForAndAcceptInviteInRoom(otherSession: Session, e2eRoomID: String) {
|
||||
testHelper.waitWithLatch {
|
||||
testHelper.retryPeriodicallyWithLatch(it) {
|
||||
val roomSummary = otherSession.getRoomSummary(e2eRoomID)
|
||||
(roomSummary != null && roomSummary.membership == Membership.INVITE).also {
|
||||
if (it) {
|
||||
Log.v("#E2E TEST", "${otherSession.myUserId} can see the invite from alice")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
testHelper.runBlockingTest(60_000) {
|
||||
Log.v("#E2E TEST", "${otherSession.myUserId} tries to join room $e2eRoomID")
|
||||
try {
|
||||
otherSession.joinRoom(e2eRoomID)
|
||||
} catch (ex: JoinRoomFailure.JoinedWithTimeout) {
|
||||
// it's ok we will wait after
|
||||
}
|
||||
}
|
||||
|
||||
Log.v("#E2E TEST", "${otherSession.myUserId} waiting for join echo ...")
|
||||
testHelper.waitWithLatch {
|
||||
testHelper.retryPeriodicallyWithLatch(it) {
|
||||
val roomSummary = otherSession.getRoomSummary(e2eRoomID)
|
||||
roomSummary != null && roomSummary.membership == Membership.JOIN
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -47,6 +47,7 @@ internal class MXOlmDevice @Inject constructor(
|
|||
* The store where crypto data is saved.
|
||||
*/
|
||||
private val store: IMXCryptoStore,
|
||||
private val olmSessionStore: OlmSessionStore,
|
||||
private val inboundGroupSessionStore: InboundGroupSessionStore
|
||||
) {
|
||||
|
||||
|
@ -190,6 +191,7 @@ internal class MXOlmDevice @Inject constructor(
|
|||
it.groupSession.releaseSession()
|
||||
}
|
||||
outboundGroupSessionCache.clear()
|
||||
olmSessionStore.clear()
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -257,7 +259,8 @@ internal class MXOlmDevice @Inject constructor(
|
|||
// this session
|
||||
olmSessionWrapper.onMessageReceived()
|
||||
|
||||
store.storeSession(olmSessionWrapper, theirIdentityKey)
|
||||
olmSessionStore.storeSession(olmSessionWrapper, theirIdentityKey)
|
||||
// store.storeSession(olmSessionWrapper, theirIdentityKey)
|
||||
|
||||
val sessionIdentifier = olmSession.sessionIdentifier()
|
||||
|
||||
|
@ -324,7 +327,7 @@ internal class MXOlmDevice @Inject constructor(
|
|||
// This counts as a received message: set last received message time to now
|
||||
olmSessionWrapper.onMessageReceived()
|
||||
|
||||
store.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
|
||||
olmSessionStore.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
|
||||
} catch (e: Exception) {
|
||||
Timber.e(e, "## createInboundSession() : decryptMessage failed")
|
||||
}
|
||||
|
@ -357,8 +360,8 @@ internal class MXOlmDevice @Inject constructor(
|
|||
* @param theirDeviceIdentityKey the Curve25519 identity key for the remote device.
|
||||
* @return a list of known session ids for the device.
|
||||
*/
|
||||
fun getSessionIds(theirDeviceIdentityKey: String): List<String>? {
|
||||
return store.getDeviceSessionIds(theirDeviceIdentityKey)
|
||||
fun getSessionIds(theirDeviceIdentityKey: String): List<String> {
|
||||
return olmSessionStore.getDeviceSessionIds(theirDeviceIdentityKey)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -368,7 +371,7 @@ internal class MXOlmDevice @Inject constructor(
|
|||
* @return the session id, or null if no established session.
|
||||
*/
|
||||
fun getSessionId(theirDeviceIdentityKey: String): String? {
|
||||
return store.getLastUsedSessionId(theirDeviceIdentityKey)
|
||||
return olmSessionStore.getLastUsedSessionId(theirDeviceIdentityKey)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -390,7 +393,8 @@ internal class MXOlmDevice @Inject constructor(
|
|||
// Timber.v("## encryptMessage() : payloadString: " + payloadString);
|
||||
|
||||
olmMessage = olmSessionWrapper.olmSession.encryptMessage(payloadString)
|
||||
store.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
|
||||
// store.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
|
||||
olmSessionStore.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
|
||||
res = HashMap()
|
||||
|
||||
res["body"] = olmMessage.mCipherText
|
||||
|
@ -427,7 +431,8 @@ internal class MXOlmDevice @Inject constructor(
|
|||
try {
|
||||
payloadString = olmSessionWrapper.olmSession.decryptMessage(olmMessage)
|
||||
olmSessionWrapper.onMessageReceived()
|
||||
store.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
|
||||
olmSessionStore.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
|
||||
// store.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
|
||||
} catch (e: Exception) {
|
||||
Timber.e(e, "## decryptMessage() : decryptMessage failed")
|
||||
}
|
||||
|
@ -819,7 +824,8 @@ internal class MXOlmDevice @Inject constructor(
|
|||
private fun getSessionForDevice(theirDeviceIdentityKey: String, sessionId: String): OlmSessionWrapper? {
|
||||
// sanity check
|
||||
return if (theirDeviceIdentityKey.isEmpty() || sessionId.isEmpty()) null else {
|
||||
store.getDeviceSession(sessionId, theirDeviceIdentityKey)
|
||||
olmSessionStore.getDeviceSession(sessionId, theirDeviceIdentityKey)
|
||||
// store.getDeviceSession(sessionId, theirDeviceIdentityKey)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
/*
|
||||
* Copyright (c) 2022 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.matrix.android.sdk.internal.crypto
|
||||
|
||||
import org.matrix.android.sdk.api.extensions.tryOrNull
|
||||
import org.matrix.android.sdk.internal.crypto.model.OlmSessionWrapper
|
||||
import org.matrix.android.sdk.internal.crypto.store.IMXCryptoStore
|
||||
import javax.inject.Inject
|
||||
|
||||
/**
|
||||
* Keep the used olm session in memory and load them from the data layer when needed
|
||||
* Access is synchronized for thread safety
|
||||
*/
|
||||
internal class OlmSessionStore @Inject constructor(private val store: IMXCryptoStore) {
|
||||
|
||||
/*
|
||||
* map of device key to list of olm sessions (it is possible to have several active sessions with a device)
|
||||
*/
|
||||
private val olmSessions = HashMap<String, MutableList<OlmSessionWrapper>>()
|
||||
|
||||
/**
|
||||
* Store a session between the logged-in user and another device.
|
||||
* This will be called after creation but also after any use of the ratchet
|
||||
* in order to persist the correct state for next run
|
||||
* @param olmSessionWrapper the end-to-end session.
|
||||
* @param deviceKey the public key of the other device.
|
||||
*/
|
||||
@Synchronized
|
||||
fun storeSession(olmSessionWrapper: OlmSessionWrapper, deviceKey: String) {
|
||||
// This could be a newly created session or one that was just created
|
||||
// Anyhow we should persist ratchet state for futur app lifecycle
|
||||
addNewSessionInCache(olmSessionWrapper, deviceKey)
|
||||
store.storeSession(olmSessionWrapper, deviceKey)
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieve the end-to-end session ids between the logged-in user and another
|
||||
* device.
|
||||
*
|
||||
* @param deviceKey the public key of the other device.
|
||||
* @return A set of sessionId, or empty if device is not known
|
||||
*/
|
||||
@Synchronized
|
||||
fun getDeviceSessionIds(deviceKey: String): List<String> {
|
||||
return internalGetAllSessions(deviceKey)
|
||||
}
|
||||
|
||||
private fun internalGetAllSessions(deviceKey: String): MutableList<String> {
|
||||
// we need to get the persisted ids first
|
||||
val persistedKnownSessions = store.getDeviceSessionIds(deviceKey)
|
||||
.orEmpty()
|
||||
.toMutableList()
|
||||
// Do we have some in cache not yet persisted?
|
||||
olmSessions.getOrPut(deviceKey) { mutableListOf() }.forEach { cached ->
|
||||
tryOrNull("Olm session was released") { cached.olmSession.sessionIdentifier() }?.let { cachedSessionId ->
|
||||
if (!persistedKnownSessions.contains(cachedSessionId)) {
|
||||
persistedKnownSessions.add(cachedSessionId)
|
||||
}
|
||||
}
|
||||
}
|
||||
return persistedKnownSessions
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieve an end-to-end session between the logged-in user and another
|
||||
* device.
|
||||
*
|
||||
* @param sessionId the session Id.
|
||||
* @param deviceKey the public key of the other device.
|
||||
* @return The Base64 end-to-end session, or null if not found
|
||||
*/
|
||||
@Synchronized
|
||||
fun getDeviceSession(sessionId: String, deviceKey: String): OlmSessionWrapper? {
|
||||
// get from cache or load and add to cache
|
||||
return internalGetSession(sessionId, deviceKey)
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieve the last used sessionId, regarding `lastReceivedMessageTs`, or null if no session exist
|
||||
*
|
||||
* @param deviceKey the public key of the other device.
|
||||
* @return last used sessionId, or null if not found
|
||||
*/
|
||||
@Synchronized
|
||||
fun getLastUsedSessionId(deviceKey: String): String? {
|
||||
// We want to avoid to load in memory old session if possible
|
||||
val lastPersistedUsedSession = store.getLastUsedSessionId(deviceKey)
|
||||
var candidate = lastPersistedUsedSession?.let { internalGetSession(it, deviceKey) }
|
||||
// we should check if we have one in cache with a higher last message received?
|
||||
olmSessions[deviceKey].orEmpty().forEach { inCache ->
|
||||
if (inCache.lastReceivedMessageTs > (candidate?.lastReceivedMessageTs ?: 0L)) {
|
||||
candidate = inCache
|
||||
}
|
||||
}
|
||||
|
||||
return candidate?.olmSession?.sessionIdentifier()
|
||||
}
|
||||
|
||||
/**
|
||||
* Release all sessions and clear cache
|
||||
*/
|
||||
@Synchronized
|
||||
fun clear() {
|
||||
olmSessions.entries.onEach { entry ->
|
||||
entry.value.onEach { it.olmSession.releaseSession() }
|
||||
}
|
||||
olmSessions.clear()
|
||||
}
|
||||
|
||||
private fun internalGetSession(sessionId: String, deviceKey: String): OlmSessionWrapper? {
|
||||
return getSessionInCache(sessionId, deviceKey)
|
||||
?: // deserialize from store
|
||||
return store.getDeviceSession(sessionId, deviceKey)?.also {
|
||||
addNewSessionInCache(it, deviceKey)
|
||||
}
|
||||
}
|
||||
|
||||
private fun getSessionInCache(sessionId: String, deviceKey: String): OlmSessionWrapper? {
|
||||
return olmSessions[deviceKey]?.firstOrNull {
|
||||
it.olmSession.isReleased && it.olmSession.sessionIdentifier() == sessionId
|
||||
}
|
||||
}
|
||||
|
||||
private fun addNewSessionInCache(session: OlmSessionWrapper, deviceKey: String) {
|
||||
val sessionId = tryOrNull { session.olmSession.sessionIdentifier() } ?: return
|
||||
olmSessions.getOrPut(deviceKey) { mutableListOf() }.let {
|
||||
val existing = it.firstOrNull { tryOrNull { it.olmSession.sessionIdentifier() } == sessionId }
|
||||
it.add(session)
|
||||
// remove and release if was there but with different instance
|
||||
if (existing != null && existing.olmSession != session.olmSession) {
|
||||
// mm not sure when this could happen
|
||||
// anyhow we should remove and release the one known
|
||||
it.remove(existing)
|
||||
existing.olmSession.releaseSession()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -154,7 +154,7 @@ internal class MXOlmDecryption(
|
|||
* @return payload, if decrypted successfully.
|
||||
*/
|
||||
private fun decryptMessage(message: JsonDict, theirDeviceIdentityKey: String): String? {
|
||||
val sessionIds = olmDevice.getSessionIds(theirDeviceIdentityKey).orEmpty()
|
||||
val sessionIds = olmDevice.getSessionIds(theirDeviceIdentityKey)
|
||||
|
||||
val messageBody = message["body"] as? String ?: return null
|
||||
val messageType = when (val typeAsVoid = message["type"]) {
|
||||
|
|
|
@ -125,7 +125,7 @@ internal class RealmCryptoStore @Inject constructor(
|
|||
private var olmAccount: OlmAccount? = null
|
||||
|
||||
// Cache for OlmSession, to release them properly
|
||||
private val olmSessionsToRelease = HashMap<String, OlmSessionWrapper>()
|
||||
// private val olmSessionsToRelease = HashMap<String, OlmSessionWrapper>()
|
||||
|
||||
// Cache for InboundGroupSession, to release them properly
|
||||
private val inboundGroupSessionToRelease = HashMap<String, OlmInboundGroupSessionWrapper2>()
|
||||
|
@ -213,11 +213,6 @@ internal class RealmCryptoStore @Inject constructor(
|
|||
monarchyWriteAsyncExecutor.awaitTermination(1, TimeUnit.MINUTES)
|
||||
}
|
||||
|
||||
olmSessionsToRelease.forEach {
|
||||
it.value.olmSession.releaseSession()
|
||||
}
|
||||
olmSessionsToRelease.clear()
|
||||
|
||||
inboundGroupSessionToRelease.forEach {
|
||||
it.value.olmInboundGroupSession?.releaseSession()
|
||||
}
|
||||
|
@ -680,13 +675,6 @@ internal class RealmCryptoStore @Inject constructor(
|
|||
if (sessionIdentifier != null) {
|
||||
val key = OlmSessionEntity.createPrimaryKey(sessionIdentifier, deviceKey)
|
||||
|
||||
// Release memory of previously known session, if it is not the same one
|
||||
if (olmSessionsToRelease[key]?.olmSession != olmSessionWrapper.olmSession) {
|
||||
olmSessionsToRelease[key]?.olmSession?.releaseSession()
|
||||
}
|
||||
|
||||
olmSessionsToRelease[key] = olmSessionWrapper
|
||||
|
||||
doRealmTransaction(realmConfiguration) {
|
||||
val realmOlmSession = OlmSessionEntity().apply {
|
||||
primaryKey = key
|
||||
|
@ -703,23 +691,18 @@ internal class RealmCryptoStore @Inject constructor(
|
|||
|
||||
override fun getDeviceSession(sessionId: String, deviceKey: String): OlmSessionWrapper? {
|
||||
val key = OlmSessionEntity.createPrimaryKey(sessionId, deviceKey)
|
||||
|
||||
// If not in cache (or not found), try to read it from realm
|
||||
if (olmSessionsToRelease[key] == null) {
|
||||
doRealmQueryAndCopy(realmConfiguration) {
|
||||
it.where<OlmSessionEntity>()
|
||||
.equalTo(OlmSessionEntityFields.PRIMARY_KEY, key)
|
||||
.findFirst()
|
||||
}
|
||||
?.let {
|
||||
val olmSession = it.getOlmSession()
|
||||
if (olmSession != null && it.sessionId != null) {
|
||||
olmSessionsToRelease[key] = OlmSessionWrapper(olmSession, it.lastReceivedMessageTs)
|
||||
}
|
||||
}
|
||||
return doRealmQueryAndCopy(realmConfiguration) {
|
||||
it.where<OlmSessionEntity>()
|
||||
.equalTo(OlmSessionEntityFields.PRIMARY_KEY, key)
|
||||
.findFirst()
|
||||
}
|
||||
|
||||
return olmSessionsToRelease[key]
|
||||
?.let {
|
||||
val olmSession = it.getOlmSession()
|
||||
if (olmSession != null && it.sessionId != null) {
|
||||
return@let OlmSessionWrapper(olmSession, it.lastReceivedMessageTs)
|
||||
}
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
override fun getLastUsedSessionId(deviceKey: String): String? {
|
||||
|
|
Loading…
Add table
Reference in a new issue