Improve key decryption perf

This commit is contained in:
Valere 2021-11-26 13:56:36 +01:00
parent 1635c9730a
commit 69e4b6e8a4

View file

@ -23,6 +23,8 @@ import androidx.annotation.WorkerThread
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
@ -486,42 +488,42 @@ internal class RustKeyBackupService @Inject constructor(
val data = getKeys(sessionId, roomId, keysVersionResult.version) val data = getKeys(sessionId, roomId, keysVersionResult.version)
return withContext(coroutineDispatchers.computation) { return withContext(coroutineDispatchers.computation) {
val sessionsData = ArrayList<MegolmSessionData>() withContext(Dispatchers.Main) {
// Restore that data
var sessionsFromHsCount = 0
cryptoCoroutineScope.launch(Dispatchers.Main) {
stepProgressListener?.onStepProgress(StepProgressListener.Step.DecryptingKey(0, data.roomIdToRoomKeysBackupData.size)) stepProgressListener?.onStepProgress(StepProgressListener.Step.DecryptingKey(0, data.roomIdToRoomKeysBackupData.size))
} }
var progressDecryptIndex = 0 // Decrypting by chunk of 500 keys in parallel
// we loose proper progress report but tested 3x faster on big backup
// TODO this is quite long, could we add some concurrency here? val sessionsData = data.roomIdToRoomKeysBackupData
for ((roomIdLoop, backupData) in data.roomIdToRoomKeysBackupData) { .mapValues {
val roomIndex = progressDecryptIndex it.value.sessionIdToKeyBackupData
progressDecryptIndex++ }
cryptoCoroutineScope.launch(Dispatchers.Main) { .flatMap { flat ->
stepProgressListener?.onStepProgress(StepProgressListener.Step.DecryptingKey(roomIndex, data.roomIdToRoomKeysBackupData.size)) flat.value.entries.map { flat.key to it }
} }
for ((sessionIdLoop, keyBackupData) in backupData.sessionIdToKeyBackupData) { .chunked(500)
sessionsFromHsCount++ .map { slice ->
async {
val sessionData = decryptKeyBackupData(keyBackupData, sessionIdLoop, roomIdLoop, recoveryKey) slice.mapNotNull { pair ->
decryptKeyBackupData(pair.second.value, pair.second.key, pair.first, recoveryKey)
// rust is not very lax and will throw if field are missing, ?.takeIf { sessionData ->
// add a check sessionData.isValid().also {
// TODO maybe could be done on rust side? if (!it) {
sessionData?.takeIf { Timber.w("restoreKeysWithRecoveryKey: malformed sessionData $sessionData")
it.isValid().also { }
if (!it) { }
Timber.w("restoreKeysWithRecoveryKey: malformed sessionData $sessionData") }
} }
} }
}?.let {
sessionsData.add(it)
} }
} .awaitAll()
.flatten()
withContext(Dispatchers.Main) {
stepProgressListener?.onStepProgress(StepProgressListener.Step.DecryptingKey(data.roomIdToRoomKeysBackupData.size, data.roomIdToRoomKeysBackupData.size))
} }
Timber.v("restoreKeysWithRecoveryKey: Decrypted ${sessionsData.size} keys out" + Timber.v("restoreKeysWithRecoveryKey: Decrypted ${sessionsData.size} keys out" +
" of $sessionsFromHsCount from the backup store on the homeserver") " of ${data.roomIdToRoomKeysBackupData.size} rooms from the backup store on the homeserver")
// Do not trigger a backup for them if they come from the backup version we are using // Do not trigger a backup for them if they come from the backup version we are using
val backUp = keysVersionResult.version != keysBackupVersion?.version val backUp = keysVersionResult.version != keysBackupVersion?.version
@ -534,7 +536,6 @@ internal class RustKeyBackupService @Inject constructor(
val progressListener = if (stepProgressListener != null) { val progressListener = if (stepProgressListener != null) {
object : ProgressListener { object : ProgressListener {
override fun onProgress(progress: Int, total: Int) { override fun onProgress(progress: Int, total: Int) {
// Note: no need to post to UI thread, importMegolmSessionsData() will do it
cryptoCoroutineScope.launch(Dispatchers.Main) { cryptoCoroutineScope.launch(Dispatchers.Main) {
stepProgressListener.onStepProgress(StepProgressListener.Step.ImportingKey(progress, total)) stepProgressListener.onStepProgress(StepProgressListener.Step.ImportingKey(progress, total))
} }