diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/coroutines/builder/FlowBuilders.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/coroutines/builder/FlowBuilders.kt new file mode 100644 index 0000000000..282b1b2cf3 --- /dev/null +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/coroutines/builder/FlowBuilders.kt @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2022 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.matrix.android.sdk.internal.coroutines.builder + +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.channels.ProducerScope + +/** + * Use this with a flow builder like [kotlinx.coroutines.flow.channelFlow] to replace [kotlinx.coroutines.channels.awaitClose]. + * As awaitClose is at the end of the builder block, it can lead to the block being cancelled before it reaches the awaitClose. + * Example of usage: + * + * return channelFlow { + * val onClose = safeInvokeOnClose { + * // Do stuff on close + * } + * val data = getData() + * send(data) + * onClose.await() + * } + * + */ +fun ProducerScope.safeInvokeOnClose(handler: (cause: Throwable?) -> Unit): CompletableDeferred { + val onClose = CompletableDeferred() + invokeOnClose { + handler(it) + onClose.complete(Unit) + } + return onClose +} diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt index 581a3425d6..46fab21f0b 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt @@ -18,7 +18,6 @@ package org.matrix.android.sdk.internal.crypto import com.squareup.moshi.Moshi import kotlinx.coroutines.channels.SendChannel -import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.channelFlow import kotlinx.coroutines.runBlocking @@ -39,6 +38,7 @@ import org.matrix.android.sdk.api.session.sync.model.ToDeviceSyncResponse import org.matrix.android.sdk.api.util.JsonDict import org.matrix.android.sdk.api.util.Optional import org.matrix.android.sdk.api.util.toOptional +import org.matrix.android.sdk.internal.coroutines.builder.safeInvokeOnClose import org.matrix.android.sdk.internal.crypto.crosssigning.DeviceTrustLevel import org.matrix.android.sdk.internal.crypto.crosssigning.UserTrustResult import org.matrix.android.sdk.internal.crypto.keysbackup.model.MegolmBackupAuthData @@ -48,7 +48,6 @@ import org.matrix.android.sdk.internal.crypto.model.MXUsersDevicesMap import org.matrix.android.sdk.internal.crypto.model.rest.RestKeyInfo import org.matrix.android.sdk.internal.crypto.model.rest.UnsignedDeviceInfo import org.matrix.android.sdk.internal.crypto.store.PrivateKeysInfo -import org.matrix.android.sdk.internal.di.MoshiProvider import org.matrix.android.sdk.internal.network.parsing.CheckNumberType import timber.log.Timber import uniffi.olm.BackupKeys @@ -67,7 +66,6 @@ import uniffi.olm.setLogger import java.io.File import java.nio.charset.Charset import java.util.UUID -import java.util.concurrent.CopyOnWriteArrayList import uniffi.olm.OlmMachine as InnerMachine import uniffi.olm.ProgressListener as RustProgressListener import uniffi.olm.UserIdentity as RustUserIdentity @@ -89,9 +87,9 @@ private data class DevicesCollector(val userIds: List, val collector: Se private typealias PrivateKeysCollector = SendChannel> private class FlowCollectors { - val userIdentityCollectors = CopyOnWriteArrayList() - val privateKeyCollectors = CopyOnWriteArrayList() - val deviceCollectors = CopyOnWriteArrayList() + val userIdentityCollectors = ArrayList() + val privateKeyCollectors = ArrayList() + val deviceCollectors = ArrayList() } fun setRustLogger() { @@ -132,21 +130,21 @@ internal class OlmMachine( private suspend fun updateLiveDevices() { for (deviceCollector in flowCollectors.deviceCollectors) { val devices = getCryptoDeviceInfo(deviceCollector.userIds) - deviceCollector.send(devices) + deviceCollector.trySend(devices) } } private suspend fun updateLiveUserIdentities() { for (userIdentityCollector in flowCollectors.userIdentityCollectors) { val identity = getIdentity(userIdentityCollector.userId)?.toMxCrossSigningInfo() - userIdentityCollector.send(identity.toOptional()) + userIdentityCollector.trySend(identity.toOptional()) } } private suspend fun updateLivePrivateKeys() { val keys = exportCrossSigningKeys().toOptional() for (privateKeyCollector in flowCollectors.privateKeyCollectors) { - privateKeyCollector.send(keys) + privateKeyCollector.trySend(keys) } } @@ -204,7 +202,6 @@ internal class OlmMachine( ) = withContext(coroutineDispatchers.io) { inner.markRequestAsSent(requestId, requestType, responseBody) - if (requestType == RequestType.KEYS_QUERY) { updateLiveDevices() updateLiveUserIdentities() @@ -652,7 +649,18 @@ internal class OlmMachine( * The key query request will be retried a few time in case of shaky connection, but could fail. */ suspend fun ensureUserDevicesMap(userIds: List, forceDownload: Boolean = false): MXUsersDevicesMap { - val toDownload = if (forceDownload) { + ensureUsersKeys(userIds, forceDownload) + return getUserDevicesMap(userIds) + } + + /** + * If the user is untracked or forceDownload is set to true, a key query request will be made. + * It will suspend until query response. + * + * The key query request will be retried a few time in case of shaky connection, but could fail. + */ + suspend fun ensureUsersKeys(userIds: List, forceDownload: Boolean = false) { + val userIdsToFetchKeys = if (forceDownload) { userIds } else { userIds.mapNotNull { userId -> @@ -661,32 +669,33 @@ internal class OlmMachine( updateTrackedUsers(it) } } - tryOrNull("Failed to download keys for $toDownload") { - forceKeyDownload(toDownload) + tryOrNull("Failed to download keys for $userIdsToFetchKeys") { + forceKeyDownload(userIdsToFetchKeys) } - return getUserDevicesMap(userIds) } fun getLiveUserIdentity(userId: String): Flow> { return channelFlow { val userIdentityCollector = UserIdentityCollector(userId, this) + val onClose = safeInvokeOnClose { + flowCollectors.userIdentityCollectors.remove(userIdentityCollector) + } flowCollectors.userIdentityCollectors.add(userIdentityCollector) val identity = getIdentity(userId)?.toMxCrossSigningInfo().toOptional() send(identity) - awaitClose { - flowCollectors.userIdentityCollectors.remove(userIdentityCollector) - } + onClose.await() } } fun getLivePrivateCrossSigningKeys(): Flow> { return channelFlow { + val onClose = safeInvokeOnClose { + flowCollectors.privateKeyCollectors.remove(this) + } flowCollectors.privateKeyCollectors.add(this) val keys = this@OlmMachine.exportCrossSigningKeys().toOptional() send(keys) - awaitClose { - flowCollectors.privateKeyCollectors.remove(this) - } + onClose.await() } } @@ -698,17 +707,18 @@ internal class OlmMachine( * * @param userIds The ids of the device owners. * - * @return The list of Devices or an empty list if there aren't any. + * @return The list of Devices or an empty list if there aren't any as a Flow. */ fun getLiveDevices(userIds: List): Flow> { return channelFlow { val devicesCollector = DevicesCollector(userIds, this) + val onClose = safeInvokeOnClose { + flowCollectors.deviceCollectors.remove(devicesCollector) + } flowCollectors.deviceCollectors.add(devicesCollector) val devices = getCryptoDeviceInfo(userIds) send(devices) - awaitClose { - flowCollectors.deviceCollectors.remove(devicesCollector) - } + onClose.await() } }