Suspend: fix flow builders

This commit is contained in:
ganfra 2022-04-25 17:55:17 +02:00
parent 5581b82ab4
commit 309a290cb8
2 changed files with 78 additions and 24 deletions

View file

@ -0,0 +1,44 @@
/*
* Copyright (c) 2022 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.matrix.android.sdk.internal.coroutines.builder
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.channels.ProducerScope
/**
* Use this with a flow builder like [kotlinx.coroutines.flow.channelFlow] to replace [kotlinx.coroutines.channels.awaitClose].
* As awaitClose is at the end of the builder block, it can lead to the block being cancelled before it reaches the awaitClose.
* Example of usage:
*
* return channelFlow {
* val onClose = safeInvokeOnClose {
* // Do stuff on close
* }
* val data = getData()
* send(data)
* onClose.await()
* }
*
*/
fun <T> ProducerScope<T>.safeInvokeOnClose(handler: (cause: Throwable?) -> Unit): CompletableDeferred<Unit> {
val onClose = CompletableDeferred<Unit>()
invokeOnClose {
handler(it)
onClose.complete(Unit)
}
return onClose
}

View file

@ -18,7 +18,6 @@ package org.matrix.android.sdk.internal.crypto
import com.squareup.moshi.Moshi import com.squareup.moshi.Moshi
import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
@ -39,6 +38,7 @@ import org.matrix.android.sdk.api.session.sync.model.ToDeviceSyncResponse
import org.matrix.android.sdk.api.util.JsonDict import org.matrix.android.sdk.api.util.JsonDict
import org.matrix.android.sdk.api.util.Optional import org.matrix.android.sdk.api.util.Optional
import org.matrix.android.sdk.api.util.toOptional import org.matrix.android.sdk.api.util.toOptional
import org.matrix.android.sdk.internal.coroutines.builder.safeInvokeOnClose
import org.matrix.android.sdk.internal.crypto.crosssigning.DeviceTrustLevel import org.matrix.android.sdk.internal.crypto.crosssigning.DeviceTrustLevel
import org.matrix.android.sdk.internal.crypto.crosssigning.UserTrustResult import org.matrix.android.sdk.internal.crypto.crosssigning.UserTrustResult
import org.matrix.android.sdk.internal.crypto.keysbackup.model.MegolmBackupAuthData import org.matrix.android.sdk.internal.crypto.keysbackup.model.MegolmBackupAuthData
@ -48,7 +48,6 @@ import org.matrix.android.sdk.internal.crypto.model.MXUsersDevicesMap
import org.matrix.android.sdk.internal.crypto.model.rest.RestKeyInfo import org.matrix.android.sdk.internal.crypto.model.rest.RestKeyInfo
import org.matrix.android.sdk.internal.crypto.model.rest.UnsignedDeviceInfo import org.matrix.android.sdk.internal.crypto.model.rest.UnsignedDeviceInfo
import org.matrix.android.sdk.internal.crypto.store.PrivateKeysInfo import org.matrix.android.sdk.internal.crypto.store.PrivateKeysInfo
import org.matrix.android.sdk.internal.di.MoshiProvider
import org.matrix.android.sdk.internal.network.parsing.CheckNumberType import org.matrix.android.sdk.internal.network.parsing.CheckNumberType
import timber.log.Timber import timber.log.Timber
import uniffi.olm.BackupKeys import uniffi.olm.BackupKeys
@ -67,7 +66,6 @@ import uniffi.olm.setLogger
import java.io.File import java.io.File
import java.nio.charset.Charset import java.nio.charset.Charset
import java.util.UUID import java.util.UUID
import java.util.concurrent.CopyOnWriteArrayList
import uniffi.olm.OlmMachine as InnerMachine import uniffi.olm.OlmMachine as InnerMachine
import uniffi.olm.ProgressListener as RustProgressListener import uniffi.olm.ProgressListener as RustProgressListener
import uniffi.olm.UserIdentity as RustUserIdentity import uniffi.olm.UserIdentity as RustUserIdentity
@ -89,9 +87,9 @@ private data class DevicesCollector(val userIds: List<String>, val collector: Se
private typealias PrivateKeysCollector = SendChannel<Optional<PrivateKeysInfo>> private typealias PrivateKeysCollector = SendChannel<Optional<PrivateKeysInfo>>
private class FlowCollectors { private class FlowCollectors {
val userIdentityCollectors = CopyOnWriteArrayList<UserIdentityCollector>() val userIdentityCollectors = ArrayList<UserIdentityCollector>()
val privateKeyCollectors = CopyOnWriteArrayList<PrivateKeysCollector>() val privateKeyCollectors = ArrayList<PrivateKeysCollector>()
val deviceCollectors = CopyOnWriteArrayList<DevicesCollector>() val deviceCollectors = ArrayList<DevicesCollector>()
} }
fun setRustLogger() { fun setRustLogger() {
@ -132,21 +130,21 @@ internal class OlmMachine(
private suspend fun updateLiveDevices() { private suspend fun updateLiveDevices() {
for (deviceCollector in flowCollectors.deviceCollectors) { for (deviceCollector in flowCollectors.deviceCollectors) {
val devices = getCryptoDeviceInfo(deviceCollector.userIds) val devices = getCryptoDeviceInfo(deviceCollector.userIds)
deviceCollector.send(devices) deviceCollector.trySend(devices)
} }
} }
private suspend fun updateLiveUserIdentities() { private suspend fun updateLiveUserIdentities() {
for (userIdentityCollector in flowCollectors.userIdentityCollectors) { for (userIdentityCollector in flowCollectors.userIdentityCollectors) {
val identity = getIdentity(userIdentityCollector.userId)?.toMxCrossSigningInfo() val identity = getIdentity(userIdentityCollector.userId)?.toMxCrossSigningInfo()
userIdentityCollector.send(identity.toOptional()) userIdentityCollector.trySend(identity.toOptional())
} }
} }
private suspend fun updateLivePrivateKeys() { private suspend fun updateLivePrivateKeys() {
val keys = exportCrossSigningKeys().toOptional() val keys = exportCrossSigningKeys().toOptional()
for (privateKeyCollector in flowCollectors.privateKeyCollectors) { for (privateKeyCollector in flowCollectors.privateKeyCollectors) {
privateKeyCollector.send(keys) privateKeyCollector.trySend(keys)
} }
} }
@ -204,7 +202,6 @@ internal class OlmMachine(
) = ) =
withContext(coroutineDispatchers.io) { withContext(coroutineDispatchers.io) {
inner.markRequestAsSent(requestId, requestType, responseBody) inner.markRequestAsSent(requestId, requestType, responseBody)
if (requestType == RequestType.KEYS_QUERY) { if (requestType == RequestType.KEYS_QUERY) {
updateLiveDevices() updateLiveDevices()
updateLiveUserIdentities() updateLiveUserIdentities()
@ -652,7 +649,18 @@ internal class OlmMachine(
* The key query request will be retried a few time in case of shaky connection, but could fail. * The key query request will be retried a few time in case of shaky connection, but could fail.
*/ */
suspend fun ensureUserDevicesMap(userIds: List<String>, forceDownload: Boolean = false): MXUsersDevicesMap<CryptoDeviceInfo> { suspend fun ensureUserDevicesMap(userIds: List<String>, forceDownload: Boolean = false): MXUsersDevicesMap<CryptoDeviceInfo> {
val toDownload = if (forceDownload) { ensureUsersKeys(userIds, forceDownload)
return getUserDevicesMap(userIds)
}
/**
* If the user is untracked or forceDownload is set to true, a key query request will be made.
* It will suspend until query response.
*
* The key query request will be retried a few time in case of shaky connection, but could fail.
*/
suspend fun ensureUsersKeys(userIds: List<String>, forceDownload: Boolean = false) {
val userIdsToFetchKeys = if (forceDownload) {
userIds userIds
} else { } else {
userIds.mapNotNull { userId -> userIds.mapNotNull { userId ->
@ -661,32 +669,33 @@ internal class OlmMachine(
updateTrackedUsers(it) updateTrackedUsers(it)
} }
} }
tryOrNull("Failed to download keys for $toDownload") { tryOrNull("Failed to download keys for $userIdsToFetchKeys") {
forceKeyDownload(toDownload) forceKeyDownload(userIdsToFetchKeys)
} }
return getUserDevicesMap(userIds)
} }
fun getLiveUserIdentity(userId: String): Flow<Optional<MXCrossSigningInfo>> { fun getLiveUserIdentity(userId: String): Flow<Optional<MXCrossSigningInfo>> {
return channelFlow { return channelFlow {
val userIdentityCollector = UserIdentityCollector(userId, this) val userIdentityCollector = UserIdentityCollector(userId, this)
val onClose = safeInvokeOnClose {
flowCollectors.userIdentityCollectors.remove(userIdentityCollector)
}
flowCollectors.userIdentityCollectors.add(userIdentityCollector) flowCollectors.userIdentityCollectors.add(userIdentityCollector)
val identity = getIdentity(userId)?.toMxCrossSigningInfo().toOptional() val identity = getIdentity(userId)?.toMxCrossSigningInfo().toOptional()
send(identity) send(identity)
awaitClose { onClose.await()
flowCollectors.userIdentityCollectors.remove(userIdentityCollector)
}
} }
} }
fun getLivePrivateCrossSigningKeys(): Flow<Optional<PrivateKeysInfo>> { fun getLivePrivateCrossSigningKeys(): Flow<Optional<PrivateKeysInfo>> {
return channelFlow { return channelFlow {
val onClose = safeInvokeOnClose {
flowCollectors.privateKeyCollectors.remove(this)
}
flowCollectors.privateKeyCollectors.add(this) flowCollectors.privateKeyCollectors.add(this)
val keys = this@OlmMachine.exportCrossSigningKeys().toOptional() val keys = this@OlmMachine.exportCrossSigningKeys().toOptional()
send(keys) send(keys)
awaitClose { onClose.await()
flowCollectors.privateKeyCollectors.remove(this)
}
} }
} }
@ -698,17 +707,18 @@ internal class OlmMachine(
* *
* @param userIds The ids of the device owners. * @param userIds The ids of the device owners.
* *
* @return The list of Devices or an empty list if there aren't any. * @return The list of Devices or an empty list if there aren't any as a Flow.
*/ */
fun getLiveDevices(userIds: List<String>): Flow<List<CryptoDeviceInfo>> { fun getLiveDevices(userIds: List<String>): Flow<List<CryptoDeviceInfo>> {
return channelFlow { return channelFlow {
val devicesCollector = DevicesCollector(userIds, this) val devicesCollector = DevicesCollector(userIds, this)
val onClose = safeInvokeOnClose {
flowCollectors.deviceCollectors.remove(devicesCollector)
}
flowCollectors.deviceCollectors.add(devicesCollector) flowCollectors.deviceCollectors.add(devicesCollector)
val devices = getCryptoDeviceInfo(userIds) val devices = getCryptoDeviceInfo(userIds)
send(devices) send(devices)
awaitClose { onClose.await()
flowCollectors.deviceCollectors.remove(devicesCollector)
}
} }
} }