Protect olm session from concurrent access

This commit is contained in:
Valere 2022-02-28 08:55:03 +01:00
parent 10ea166b2a
commit 33f9bc52cb
17 changed files with 75 additions and 44 deletions

View file

@ -121,7 +121,7 @@ interface CryptoService {
fun discardOutboundSession(roomId: String)
@Throws(MXCryptoError::class)
fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult
suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult
fun decryptEventAsync(event: Event, timeline: String, callback: MatrixCallback<MXEventDecryptionResult>)

View file

@ -723,7 +723,7 @@ internal class DefaultCryptoService @Inject constructor(
* @return the MXEventDecryptionResult data, or throw in case of error
*/
@Throws(MXCryptoError::class)
override fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
override suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
return internalDecryptEvent(event, timeline)
}
@ -746,7 +746,7 @@ internal class DefaultCryptoService @Inject constructor(
* @return the MXEventDecryptionResult data, or null in case of error
*/
@Throws(MXCryptoError::class)
private fun internalDecryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
private suspend fun internalDecryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
return eventDecryptor.decryptEvent(event, timeline)
}

View file

@ -63,7 +63,7 @@ internal class EventDecryptor @Inject constructor(
* @return the MXEventDecryptionResult data, or throw in case of error
*/
@Throws(MXCryptoError::class)
fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
return internalDecryptEvent(event, timeline)
}
@ -91,7 +91,7 @@ internal class EventDecryptor @Inject constructor(
* @return the MXEventDecryptionResult data, or null in case of error
*/
@Throws(MXCryptoError::class)
private fun internalDecryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
private suspend fun internalDecryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
val eventContent = event.content
if (eventContent == null) {
Timber.e("## CRYPTO | decryptEvent : empty event content")

View file

@ -16,6 +16,7 @@
package org.matrix.android.sdk.internal.crypto
import kotlinx.coroutines.sync.withLock
import org.matrix.android.sdk.api.session.crypto.MXCryptoError
import org.matrix.android.sdk.api.util.JSON_DICT_PARAMETERIZED_TYPE
import org.matrix.android.sdk.api.util.JsonDict
@ -382,31 +383,30 @@ internal class MXOlmDevice @Inject constructor(
* @param payloadString the payload to be encrypted and sent
* @return the cipher text
*/
fun encryptMessage(theirDeviceIdentityKey: String, sessionId: String, payloadString: String): Map<String, Any>? {
var res: MutableMap<String, Any>? = null
val olmMessage: OlmMessage
suspend fun encryptMessage(theirDeviceIdentityKey: String, sessionId: String, payloadString: String): Map<String, Any>? {
val olmSessionWrapper = getSessionForDevice(theirDeviceIdentityKey, sessionId)
if (olmSessionWrapper != null) {
try {
Timber.v("## encryptMessage() : olmSession.sessionIdentifier: $sessionId")
// Timber.v("## encryptMessage() : payloadString: " + payloadString);
olmMessage = olmSessionWrapper.olmSession.encryptMessage(payloadString)
// store.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
olmSessionStore.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
res = HashMap()
res["body"] = olmMessage.mCipherText
res["type"] = olmMessage.mType
} catch (e: Exception) {
Timber.e(e, "## encryptMessage() : failed")
val olmMessage = olmSessionWrapper.mutex.withLock {
olmSessionWrapper.olmSession.encryptMessage(payloadString)
}
return mapOf(
"body" to olmMessage.mCipherText,
"type" to olmMessage.mType,
).also {
olmSessionStore.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
}
} catch (e: Throwable) {
Timber.e(e, "## encryptMessage() : failed to encrypt olm with device|session:$theirDeviceIdentityKey|$sessionId")
return null
}
} else {
Timber.e("## encryptMessage() : Failed to encrypt unknown session $sessionId")
return null
}
return res
}
/**
@ -418,7 +418,7 @@ internal class MXOlmDevice @Inject constructor(
* @param sessionId the id of the active session.
* @return the decrypted payload.
*/
fun decryptMessage(ciphertext: String, messageType: Int, sessionId: String, theirDeviceIdentityKey: String): String? {
suspend fun decryptMessage(ciphertext: String, messageType: Int, sessionId: String, theirDeviceIdentityKey: String): String? {
var payloadString: String? = null
val olmSessionWrapper = getSessionForDevice(theirDeviceIdentityKey, sessionId)
@ -429,8 +429,12 @@ internal class MXOlmDevice @Inject constructor(
olmMessage.mType = messageType.toLong()
try {
payloadString = olmSessionWrapper.olmSession.decryptMessage(olmMessage)
olmSessionWrapper.onMessageReceived()
payloadString =
olmSessionWrapper.mutex.withLock {
olmSessionWrapper.olmSession.decryptMessage(olmMessage).also {
olmSessionWrapper.onMessageReceived()
}
}
olmSessionStore.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
// store.storeSession(olmSessionWrapper, theirDeviceIdentityKey)
} catch (e: Exception) {

View file

@ -16,17 +16,20 @@
package org.matrix.android.sdk.internal.crypto
import org.matrix.android.sdk.api.extensions.tryOrNull
import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.internal.crypto.model.OlmSessionWrapper
import org.matrix.android.sdk.internal.crypto.store.IMXCryptoStore
import org.matrix.olm.OlmSession
import timber.log.Timber
import javax.inject.Inject
private val loggerTag = LoggerTag("OlmSessionStore", LoggerTag.CRYPTO)
/**
* Keep the used olm session in memory and load them from the data layer when needed
* Access is synchronized for thread safety
*/
internal class OlmSessionStore @Inject constructor(private val store: IMXCryptoStore) {
/*
* map of device key to list of olm sessions (it is possible to have several active sessions with a device)
*/
@ -66,7 +69,7 @@ internal class OlmSessionStore @Inject constructor(private val store: IMXCryptoS
.toMutableList()
// Do we have some in cache not yet persisted?
olmSessions.getOrPut(deviceKey) { mutableListOf() }.forEach { cached ->
tryOrNull("Olm session was released") { cached.olmSession.sessionIdentifier() }?.let { cachedSessionId ->
getSafeSessionIdentifier(cached.olmSession)?.let { cachedSessionId ->
if (!persistedKnownSessions.contains(cachedSessionId)) {
persistedKnownSessions.add(cachedSessionId)
}
@ -131,14 +134,23 @@ internal class OlmSessionStore @Inject constructor(private val store: IMXCryptoS
private fun getSessionInCache(sessionId: String, deviceKey: String): OlmSessionWrapper? {
return olmSessions[deviceKey]?.firstOrNull {
it.olmSession.isReleased && it.olmSession.sessionIdentifier() == sessionId
getSafeSessionIdentifier(it.olmSession) == sessionId
}
}
private fun getSafeSessionIdentifier(session: OlmSession): String? {
return try {
session.sessionIdentifier()
} catch (throwable: Throwable) {
Timber.tag(loggerTag.value).w("Failed to load sessionId from loaded olm session")
null
}
}
private fun addNewSessionInCache(session: OlmSessionWrapper, deviceKey: String) {
val sessionId = tryOrNull { session.olmSession.sessionIdentifier() } ?: return
val sessionId = getSafeSessionIdentifier(session.olmSession) ?: return
olmSessions.getOrPut(deviceKey) { mutableListOf() }.let {
val existing = it.firstOrNull { tryOrNull { it.olmSession.sessionIdentifier() } == sessionId }
val existing = it.firstOrNull { getSafeSessionIdentifier(it.olmSession) == sessionId }
it.add(session)
// remove and release if was there but with different instance
if (existing != null && existing.olmSession != session.olmSession) {

View file

@ -42,7 +42,7 @@ internal class MessageEncrypter @Inject constructor(
* @param deviceInfos list of device infos to encrypt for.
* @return the content for an m.room.encrypted event.
*/
fun encryptMessage(payloadFields: Content, deviceInfos: List<CryptoDeviceInfo>): EncryptedMessage {
suspend fun encryptMessage(payloadFields: Content, deviceInfos: List<CryptoDeviceInfo>): EncryptedMessage {
val deviceInfoParticipantKey = deviceInfos.associateBy { it.identityKey()!! }
val payloadJson = payloadFields.toMutableMap()

View file

@ -36,7 +36,7 @@ internal interface IMXDecrypting {
* @return the decryption information, or an error
*/
@Throws(MXCryptoError::class)
fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult
suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult
/**
* Handle a key event.

View file

@ -71,7 +71,7 @@ internal class MXMegolmDecryption(private val userId: String,
// private var pendingEvents: MutableMap<String /* senderKey|sessionId */, MutableMap<String /* timelineId */, MutableList<Event>>> = HashMap()
@Throws(MXCryptoError::class)
override fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
override suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
// If cross signing is enabled, we don't send request until the keys are trusted
// There could be a race effect here when xsigning is enabled, we should ensure that keys was downloaded once
val requestOnFail = cryptoStore.getMyCrossSigningInfo()?.isTrusted() == true

View file

@ -38,7 +38,7 @@ internal class MXOlmDecryption(
IMXDecrypting {
@Throws(MXCryptoError::class)
override fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
override suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
val olmEventContent = event.content.toModel<OlmEventContent>() ?: run {
Timber.e("## decryptEvent() : bad event format")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_EVENT_FORMAT,
@ -153,7 +153,7 @@ internal class MXOlmDecryption(
* @param message message object, with 'type' and 'body' fields.
* @return payload, if decrypted successfully.
*/
private fun decryptMessage(message: JsonDict, theirDeviceIdentityKey: String): String? {
private suspend fun decryptMessage(message: JsonDict, theirDeviceIdentityKey: String): String? {
val sessionIds = olmDevice.getSessionIds(theirDeviceIdentityKey)
val messageBody = message["body"] as? String ?: return null

View file

@ -16,6 +16,7 @@
package org.matrix.android.sdk.internal.crypto.model
import kotlinx.coroutines.sync.Mutex
import org.matrix.olm.OlmSession
/**
@ -25,7 +26,9 @@ data class OlmSessionWrapper(
// The associated olm session.
val olmSession: OlmSession,
// Timestamp at which the session last received a message.
var lastReceivedMessageTs: Long = 0) {
var lastReceivedMessageTs: Long = 0,
var mutex: Mutex = Mutex()) {
/**
* Notify that a message has been received on this olm session so that it updates `lastReceivedMessageTs`

View file

@ -156,7 +156,7 @@ internal class DefaultFetchThreadTimelineTask @Inject constructor(
* Invoke the event decryption mechanism for a specific event
*/
private fun decryptIfNeeded(event: Event, roomId: String) {
private suspend fun decryptIfNeeded(event: Event, roomId: String) {
try {
// Event from sync does not have roomId, so add it to the event first
val result = cryptoService.decryptEvent(event.copy(roomId = roomId), "")

View file

@ -18,6 +18,7 @@ package org.matrix.android.sdk.internal.session.room.summary
import io.realm.Realm
import io.realm.kotlin.createObject
import kotlinx.coroutines.runBlocking
import org.matrix.android.sdk.api.extensions.orFalse
import org.matrix.android.sdk.api.extensions.tryOrNull
import org.matrix.android.sdk.api.session.events.model.EventType
@ -165,7 +166,9 @@ internal class RoomSummaryUpdater @Inject constructor(
Timber.v("Should decrypt ${latestPreviewableEvent.eventId}")
// mmm i want to decrypt now or is it ok to do it async?
tryOrNull {
eventDecryptor.decryptEvent(root.asDomain(), "")
runBlocking {
eventDecryptor.decryptEvent(root.asDomain(), "")
}
}
?.let { root.setDecryptionResult(it) }
}

View file

@ -17,6 +17,7 @@ package org.matrix.android.sdk.internal.session.room.timeline
import io.realm.Realm
import io.realm.RealmConfiguration
import kotlinx.coroutines.runBlocking
import org.matrix.android.sdk.api.session.crypto.CryptoService
import org.matrix.android.sdk.api.session.crypto.MXCryptoError
import org.matrix.android.sdk.api.session.events.model.Event
@ -99,7 +100,9 @@ internal class TimelineEventDecryptor @Inject constructor(
}
executor?.execute {
Realm.getInstance(realmConfiguration).use { realm ->
processDecryptRequest(request, realm)
runBlocking {
processDecryptRequest(request, realm)
}
}
}
}
@ -115,7 +118,7 @@ internal class TimelineEventDecryptor @Inject constructor(
threadsAwarenessHandler.makeEventThreadAware(realm, event.roomId, decryptedEvent, eventEntity)
}
}
private fun processDecryptRequest(request: DecryptionRequest, realm: Realm) {
private suspend fun processDecryptRequest(request: DecryptionRequest, realm: Realm) {
val event = request.event
val timelineId = request.timelineId

View file

@ -110,6 +110,7 @@ internal class SyncResponseHandler @Inject constructor(
// Start one big transaction
monarchy.awaitTransaction { realm ->
// IMPORTANT nothing should be suspend here as we are accessing the realm instance (thread local)
measureTimeMillis {
Timber.v("Handle rooms")
reportSubtask(reporter, InitSyncStep.ImportingAccountRoom, 1, 0.7f) {

View file

@ -38,7 +38,7 @@ private val loggerTag = LoggerTag("CryptoSyncHandler", LoggerTag.CRYPTO)
internal class CryptoSyncHandler @Inject constructor(private val cryptoService: DefaultCryptoService,
private val verificationService: DefaultVerificationService) {
fun handleToDevice(toDevice: ToDeviceSyncResponse, progressReporter: ProgressReporter? = null) {
suspend fun handleToDevice(toDevice: ToDeviceSyncResponse, progressReporter: ProgressReporter? = null) {
val total = toDevice.events?.size ?: 0
toDevice.events?.forEachIndexed { index, event ->
progressReporter?.reportProgress(index * 100F / total)
@ -66,7 +66,7 @@ internal class CryptoSyncHandler @Inject constructor(private val cryptoService:
* @param timelineId the timeline identifier
* @return true if the event has been decrypted
*/
private fun decryptToDeviceEvent(event: Event, timelineId: String?): Boolean {
private suspend fun decryptToDeviceEvent(event: Event, timelineId: String?): Boolean {
Timber.v("## CRYPTO | decryptToDeviceEvent")
if (event.getClearType() == EventType.ENCRYPTED) {
var result: MXEventDecryptionResult? = null
@ -80,6 +80,8 @@ internal class CryptoSyncHandler @Inject constructor(private val cryptoService:
it.identityKey() == senderKey
}?.deviceId ?: senderKey
Timber.e("## CRYPTO | Failed to decrypt to device event from ${event.senderId}|$deviceId reason:<${event.mCryptoError ?: exception}>")
} catch (failure: Throwable) {
Timber.e(failure, "## CRYPTO | Failed to decrypt to device event from ${event.senderId}")
}
if (null != result) {

View file

@ -19,6 +19,7 @@ package org.matrix.android.sdk.internal.session.sync.handler.room
import dagger.Lazy
import io.realm.Realm
import io.realm.kotlin.createObject
import kotlinx.coroutines.runBlocking
import org.matrix.android.sdk.api.session.crypto.MXCryptoError
import org.matrix.android.sdk.api.session.events.model.Event
import org.matrix.android.sdk.api.session.events.model.EventType
@ -379,7 +380,9 @@ internal class RoomSyncHandler @Inject constructor(private val readReceiptHandle
val isInitialSync = insertType == EventInsertType.INITIAL_SYNC
if (event.isEncrypted() && !isInitialSync) {
decryptIfNeeded(event, roomId)
runBlocking {
decryptIfNeeded(event, roomId)
}
}
var contentToInject: String? = null
if (!isInitialSync) {
@ -455,7 +458,7 @@ internal class RoomSyncHandler @Inject constructor(private val readReceiptHandle
return chunkEntity
}
private fun decryptIfNeeded(event: Event, roomId: String) {
private suspend fun decryptIfNeeded(event: Event, roomId: String) {
try {
// Event from sync does not have roomId, so add it to the event first
val result = cryptoService.decryptEvent(event.copy(roomId = roomId), "")

View file

@ -190,7 +190,7 @@ class NotifiableEventResolver @Inject constructor(
}
}
private fun TimelineEvent.attemptToDecryptIfNeeded(session: Session) {
private suspend fun TimelineEvent.attemptToDecryptIfNeeded(session: Session) {
if (root.isEncrypted() && root.mxDecryptionResult == null) {
// TODO use a global event decryptor? attache to session and that listen to new sessionId?
// for now decrypt sync