Suspend: fix flow builders

This commit is contained in:
ganfra 2022-04-25 17:55:17 +02:00
parent 5581b82ab4
commit 309a290cb8
2 changed files with 78 additions and 24 deletions

View file

@ -0,0 +1,44 @@
/*
* Copyright (c) 2022 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.matrix.android.sdk.internal.coroutines.builder
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.channels.ProducerScope
/**
* Use this with a flow builder like [kotlinx.coroutines.flow.channelFlow] to replace [kotlinx.coroutines.channels.awaitClose].
* As awaitClose is at the end of the builder block, it can lead to the block being cancelled before it reaches the awaitClose.
* Example of usage:
*
* return channelFlow {
* val onClose = safeInvokeOnClose {
* // Do stuff on close
* }
* val data = getData()
* send(data)
* onClose.await()
* }
*
*/
fun <T> ProducerScope<T>.safeInvokeOnClose(handler: (cause: Throwable?) -> Unit): CompletableDeferred<Unit> {
val onClose = CompletableDeferred<Unit>()
invokeOnClose {
handler(it)
onClose.complete(Unit)
}
return onClose
}

View file

@ -18,7 +18,6 @@ package org.matrix.android.sdk.internal.crypto
import com.squareup.moshi.Moshi
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.runBlocking
@ -39,6 +38,7 @@ import org.matrix.android.sdk.api.session.sync.model.ToDeviceSyncResponse
import org.matrix.android.sdk.api.util.JsonDict
import org.matrix.android.sdk.api.util.Optional
import org.matrix.android.sdk.api.util.toOptional
import org.matrix.android.sdk.internal.coroutines.builder.safeInvokeOnClose
import org.matrix.android.sdk.internal.crypto.crosssigning.DeviceTrustLevel
import org.matrix.android.sdk.internal.crypto.crosssigning.UserTrustResult
import org.matrix.android.sdk.internal.crypto.keysbackup.model.MegolmBackupAuthData
@ -48,7 +48,6 @@ import org.matrix.android.sdk.internal.crypto.model.MXUsersDevicesMap
import org.matrix.android.sdk.internal.crypto.model.rest.RestKeyInfo
import org.matrix.android.sdk.internal.crypto.model.rest.UnsignedDeviceInfo
import org.matrix.android.sdk.internal.crypto.store.PrivateKeysInfo
import org.matrix.android.sdk.internal.di.MoshiProvider
import org.matrix.android.sdk.internal.network.parsing.CheckNumberType
import timber.log.Timber
import uniffi.olm.BackupKeys
@ -67,7 +66,6 @@ import uniffi.olm.setLogger
import java.io.File
import java.nio.charset.Charset
import java.util.UUID
import java.util.concurrent.CopyOnWriteArrayList
import uniffi.olm.OlmMachine as InnerMachine
import uniffi.olm.ProgressListener as RustProgressListener
import uniffi.olm.UserIdentity as RustUserIdentity
@ -89,9 +87,9 @@ private data class DevicesCollector(val userIds: List<String>, val collector: Se
private typealias PrivateKeysCollector = SendChannel<Optional<PrivateKeysInfo>>
private class FlowCollectors {
val userIdentityCollectors = CopyOnWriteArrayList<UserIdentityCollector>()
val privateKeyCollectors = CopyOnWriteArrayList<PrivateKeysCollector>()
val deviceCollectors = CopyOnWriteArrayList<DevicesCollector>()
val userIdentityCollectors = ArrayList<UserIdentityCollector>()
val privateKeyCollectors = ArrayList<PrivateKeysCollector>()
val deviceCollectors = ArrayList<DevicesCollector>()
}
fun setRustLogger() {
@ -132,21 +130,21 @@ internal class OlmMachine(
private suspend fun updateLiveDevices() {
for (deviceCollector in flowCollectors.deviceCollectors) {
val devices = getCryptoDeviceInfo(deviceCollector.userIds)
deviceCollector.send(devices)
deviceCollector.trySend(devices)
}
}
private suspend fun updateLiveUserIdentities() {
for (userIdentityCollector in flowCollectors.userIdentityCollectors) {
val identity = getIdentity(userIdentityCollector.userId)?.toMxCrossSigningInfo()
userIdentityCollector.send(identity.toOptional())
userIdentityCollector.trySend(identity.toOptional())
}
}
private suspend fun updateLivePrivateKeys() {
val keys = exportCrossSigningKeys().toOptional()
for (privateKeyCollector in flowCollectors.privateKeyCollectors) {
privateKeyCollector.send(keys)
privateKeyCollector.trySend(keys)
}
}
@ -204,7 +202,6 @@ internal class OlmMachine(
) =
withContext(coroutineDispatchers.io) {
inner.markRequestAsSent(requestId, requestType, responseBody)
if (requestType == RequestType.KEYS_QUERY) {
updateLiveDevices()
updateLiveUserIdentities()
@ -652,7 +649,18 @@ internal class OlmMachine(
* The key query request will be retried a few time in case of shaky connection, but could fail.
*/
suspend fun ensureUserDevicesMap(userIds: List<String>, forceDownload: Boolean = false): MXUsersDevicesMap<CryptoDeviceInfo> {
val toDownload = if (forceDownload) {
ensureUsersKeys(userIds, forceDownload)
return getUserDevicesMap(userIds)
}
/**
* If the user is untracked or forceDownload is set to true, a key query request will be made.
* It will suspend until query response.
*
* The key query request will be retried a few time in case of shaky connection, but could fail.
*/
suspend fun ensureUsersKeys(userIds: List<String>, forceDownload: Boolean = false) {
val userIdsToFetchKeys = if (forceDownload) {
userIds
} else {
userIds.mapNotNull { userId ->
@ -661,32 +669,33 @@ internal class OlmMachine(
updateTrackedUsers(it)
}
}
tryOrNull("Failed to download keys for $toDownload") {
forceKeyDownload(toDownload)
tryOrNull("Failed to download keys for $userIdsToFetchKeys") {
forceKeyDownload(userIdsToFetchKeys)
}
return getUserDevicesMap(userIds)
}
fun getLiveUserIdentity(userId: String): Flow<Optional<MXCrossSigningInfo>> {
return channelFlow {
val userIdentityCollector = UserIdentityCollector(userId, this)
val onClose = safeInvokeOnClose {
flowCollectors.userIdentityCollectors.remove(userIdentityCollector)
}
flowCollectors.userIdentityCollectors.add(userIdentityCollector)
val identity = getIdentity(userId)?.toMxCrossSigningInfo().toOptional()
send(identity)
awaitClose {
flowCollectors.userIdentityCollectors.remove(userIdentityCollector)
}
onClose.await()
}
}
fun getLivePrivateCrossSigningKeys(): Flow<Optional<PrivateKeysInfo>> {
return channelFlow {
val onClose = safeInvokeOnClose {
flowCollectors.privateKeyCollectors.remove(this)
}
flowCollectors.privateKeyCollectors.add(this)
val keys = this@OlmMachine.exportCrossSigningKeys().toOptional()
send(keys)
awaitClose {
flowCollectors.privateKeyCollectors.remove(this)
}
onClose.await()
}
}
@ -698,17 +707,18 @@ internal class OlmMachine(
*
* @param userIds The ids of the device owners.
*
* @return The list of Devices or an empty list if there aren't any.
* @return The list of Devices or an empty list if there aren't any as a Flow.
*/
fun getLiveDevices(userIds: List<String>): Flow<List<CryptoDeviceInfo>> {
return channelFlow {
val devicesCollector = DevicesCollector(userIds, this)
val onClose = safeInvokeOnClose {
flowCollectors.deviceCollectors.remove(devicesCollector)
}
flowCollectors.deviceCollectors.add(devicesCollector)
val devices = getCryptoDeviceInfo(userIds)
send(devices)
awaitClose {
flowCollectors.deviceCollectors.remove(devicesCollector)
}
onClose.await()
}
}