From 69e4b6e8a4fa9bcf08abefb4cbe53cc096f4f8e6 Mon Sep 17 00:00:00 2001 From: Valere Date: Fri, 26 Nov 2021 13:56:36 +0100 Subject: [PATCH] Improve key decryption perf --- .../crypto/keysbackup/RustKeyBackupService.kt | 61 ++++++++++--------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/keysbackup/RustKeyBackupService.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/keysbackup/RustKeyBackupService.kt index 40d27c926d..3320ec089a 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/keysbackup/RustKeyBackupService.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/keysbackup/RustKeyBackupService.kt @@ -23,6 +23,8 @@ import androidx.annotation.WorkerThread import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll import kotlinx.coroutines.delay import kotlinx.coroutines.launch import kotlinx.coroutines.withContext @@ -486,42 +488,42 @@ internal class RustKeyBackupService @Inject constructor( val data = getKeys(sessionId, roomId, keysVersionResult.version) return withContext(coroutineDispatchers.computation) { - val sessionsData = ArrayList() - // Restore that data - var sessionsFromHsCount = 0 - cryptoCoroutineScope.launch(Dispatchers.Main) { + withContext(Dispatchers.Main) { stepProgressListener?.onStepProgress(StepProgressListener.Step.DecryptingKey(0, data.roomIdToRoomKeysBackupData.size)) } - var progressDecryptIndex = 0 - - // TODO this is quite long, could we add some concurrency here? - for ((roomIdLoop, backupData) in data.roomIdToRoomKeysBackupData) { - val roomIndex = progressDecryptIndex - progressDecryptIndex++ - cryptoCoroutineScope.launch(Dispatchers.Main) { - stepProgressListener?.onStepProgress(StepProgressListener.Step.DecryptingKey(roomIndex, data.roomIdToRoomKeysBackupData.size)) - } - for ((sessionIdLoop, keyBackupData) in backupData.sessionIdToKeyBackupData) { - sessionsFromHsCount++ - - val sessionData = decryptKeyBackupData(keyBackupData, sessionIdLoop, roomIdLoop, recoveryKey) - - // rust is not very lax and will throw if field are missing, - // add a check - // TODO maybe could be done on rust side? - sessionData?.takeIf { - it.isValid().also { - if (!it) { - Timber.w("restoreKeysWithRecoveryKey: malformed sessionData $sessionData") + // Decrypting by chunk of 500 keys in parallel + // we loose proper progress report but tested 3x faster on big backup + val sessionsData = data.roomIdToRoomKeysBackupData + .mapValues { + it.value.sessionIdToKeyBackupData + } + .flatMap { flat -> + flat.value.entries.map { flat.key to it } + } + .chunked(500) + .map { slice -> + async { + slice.mapNotNull { pair -> + decryptKeyBackupData(pair.second.value, pair.second.key, pair.first, recoveryKey) + ?.takeIf { sessionData -> + sessionData.isValid().also { + if (!it) { + Timber.w("restoreKeysWithRecoveryKey: malformed sessionData $sessionData") + } + } + } } } - }?.let { - sessionsData.add(it) } - } + .awaitAll() + .flatten() + + withContext(Dispatchers.Main) { + stepProgressListener?.onStepProgress(StepProgressListener.Step.DecryptingKey(data.roomIdToRoomKeysBackupData.size, data.roomIdToRoomKeysBackupData.size)) } + Timber.v("restoreKeysWithRecoveryKey: Decrypted ${sessionsData.size} keys out" + - " of $sessionsFromHsCount from the backup store on the homeserver") + " of ${data.roomIdToRoomKeysBackupData.size} rooms from the backup store on the homeserver") // Do not trigger a backup for them if they come from the backup version we are using val backUp = keysVersionResult.version != keysBackupVersion?.version @@ -534,7 +536,6 @@ internal class RustKeyBackupService @Inject constructor( val progressListener = if (stepProgressListener != null) { object : ProgressListener { override fun onProgress(progress: Int, total: Int) { - // Note: no need to post to UI thread, importMegolmSessionsData() will do it cryptoCoroutineScope.launch(Dispatchers.Main) { stepProgressListener.onStepProgress(StepProgressListener.Step.ImportingKey(progress, total)) }