From e3d2186c25eeaf2446f55420250c259dd454485e Mon Sep 17 00:00:00 2001
From: Benoit Marty <benoit@matrix.org>
Date: Fri, 12 Mar 2021 18:41:31 +0100
Subject: [PATCH] Rework UpdateTrustWorker, should have better perf.

---
 .../crypto/crosssigning/UpdateTrustWorker.kt  | 326 +++++++++---------
 .../android/sdk/internal/util/LogUtil.kt      |  12 +
 2 files changed, 176 insertions(+), 162 deletions(-)

diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/crosssigning/UpdateTrustWorker.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/crosssigning/UpdateTrustWorker.kt
index 1660bae0b7..a2a90c49cd 100644
--- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/crosssigning/UpdateTrustWorker.kt
+++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/crosssigning/UpdateTrustWorker.kt
@@ -36,12 +36,14 @@ import org.matrix.android.sdk.internal.crypto.store.db.model.UserEntityFields
 import org.matrix.android.sdk.internal.database.model.RoomMemberSummaryEntity
 import org.matrix.android.sdk.internal.database.model.RoomMemberSummaryEntityFields
 import org.matrix.android.sdk.internal.database.model.RoomSummaryEntity
+import org.matrix.android.sdk.internal.database.model.RoomSummaryEntityFields
 import org.matrix.android.sdk.internal.database.query.where
 import org.matrix.android.sdk.internal.di.CryptoDatabase
 import org.matrix.android.sdk.internal.di.SessionDatabase
 import org.matrix.android.sdk.internal.di.UserId
 import org.matrix.android.sdk.internal.session.SessionComponent
 import org.matrix.android.sdk.internal.session.room.membership.RoomMemberHelper
+import org.matrix.android.sdk.internal.util.logLimit
 import org.matrix.android.sdk.internal.worker.SessionSafeCoroutineWorker
 import org.matrix.android.sdk.internal.worker.SessionWorkerParams
 import timber.log.Timber
@@ -65,11 +67,14 @@ internal class UpdateTrustWorker(context: Context,
     @Inject lateinit var crossSigningService: DefaultCrossSigningService
 
     // It breaks the crypto store contract, but we need to batch things :/
-    @CryptoDatabase @Inject lateinit var realmConfiguration: RealmConfiguration
-    @UserId @Inject lateinit var myUserId: String
+    @CryptoDatabase
+    @Inject lateinit var cryptoRealmConfiguration: RealmConfiguration
+    @SessionDatabase
+    @Inject lateinit var sessionRealmConfiguration: RealmConfiguration
+    @UserId
+    @Inject lateinit var myUserId: String
     @Inject lateinit var crossSigningKeysMapper: CrossSigningKeysMapper
     @Inject lateinit var updateTrustWorkerDataRepository: UpdateTrustWorkerDataRepository
-    @SessionDatabase @Inject lateinit var sessionRealmConfiguration: RealmConfiguration
 
     //    @Inject lateinit var roomSummaryUpdater: RoomSummaryUpdater
     @Inject lateinit var cryptoStore: IMXCryptoStore
@@ -79,118 +84,115 @@ internal class UpdateTrustWorker(context: Context,
     }
 
     override suspend fun doSafeWork(params: Params): Result {
-        var userList = params.filename
+        val userList = params.filename
                 ?.let { updateTrustWorkerDataRepository.getParam(it) }
                 ?.userIds
                 ?: params.updatedUserIds.orEmpty()
 
-        if (userList.isEmpty()) {
-            // This should not happen, but let's avoid go further in case of empty list
-            cleanup(params)
-            return Result.success()
+        // List should not be empty, but let's avoid go further in case of empty list
+        if (userList.isNotEmpty()) {
+            // Unfortunately we don't have much info on what did exactly changed (is it the cross signing keys of that user,
+            // or a new device?) So we check all again :/
+            Timber.d("## CrossSigning - Updating trust for users: ${userList.logLimit()}")
+
+            Realm.getInstance(cryptoRealmConfiguration).use { cryptoRealm ->
+                Realm.getInstance(sessionRealmConfiguration).use { sessionRealm ->
+                    updateTrust(userList, cryptoRealm, sessionRealm)
+                }
+            }
         }
 
-        // Unfortunately we don't have much info on what did exactly changed (is it the cross signing keys of that user,
-        // or a new device?) So we check all again :/
-
-        Timber.d("## CrossSigning - Updating trust for $userList")
+        cleanup(params)
+        return Result.success()
+    }
 
+    private fun updateTrust(userListParam: List<String>,
+                            cRealm: Realm,
+                            sRealm: Realm) {
+        var userList = userListParam
+        var myCrossSigningInfo: MXCrossSigningInfo? = null
         // First we check that the users MSK are trusted by mine
         // After that we check the trust chain for each devices of each users
-        Realm.getInstance(realmConfiguration).use { realm ->
-            realm.executeTransaction {
-                // By mapping here to model, this object is not live
-                // I should update it if needed
-                var myCrossSigningInfo = realm.where(CrossSigningInfoEntity::class.java)
-                        .equalTo(CrossSigningInfoEntityFields.USER_ID, myUserId)
-                        .findFirst()?.let { mapCrossSigningInfoEntity(it) }
+        cRealm.executeTransaction { cryptoRealm ->
+            // By mapping here to model, this object is not live
+            // I should update it if needed
+            myCrossSigningInfo = getCrossSigningInfo(cryptoRealm, myUserId)
 
-                var myTrustResult: UserTrustResult? = null
+            var myTrustResult: UserTrustResult? = null
 
-                if (userList.contains(myUserId)) {
-                    Timber.d("## CrossSigning - Clear all trust as a change on my user was detected")
-                    // i am in the list.. but i don't know exactly the delta of change :/
-                    // If it's my cross signing keys we should refresh all trust
-                    // do it anyway ?
-                    userList = realm.where(CrossSigningInfoEntity::class.java)
-                            .findAll().mapNotNull { it.userId }
-                    Timber.d("## CrossSigning - Updating trust for all $userList")
+            if (userList.contains(myUserId)) {
+                Timber.d("## CrossSigning - Clear all trust as a change on my user was detected")
+                // i am in the list.. but i don't know exactly the delta of change :/
+                // If it's my cross signing keys we should refresh all trust
+                // do it anyway ?
+                userList = cryptoRealm.where(CrossSigningInfoEntity::class.java)
+                        .findAll()
+                        .mapNotNull { it.userId }
 
-                    // check right now my keys and mark it as trusted as other trust depends on it
-                    val myDevices = realm.where<UserEntity>()
-                            .equalTo(UserEntityFields.USER_ID, myUserId)
-                            .findFirst()
-                            ?.devices
-                            ?.map { deviceInfo ->
-                                CryptoMapper.mapToModel(deviceInfo)
-                            }
-                    myTrustResult = crossSigningService.checkSelfTrust(myCrossSigningInfo, myDevices).also {
-                        updateCrossSigningKeysTrust(realm, myUserId, it.isVerified())
-                        // update model reference
-                        myCrossSigningInfo = realm.where(CrossSigningInfoEntity::class.java)
-                                .equalTo(CrossSigningInfoEntityFields.USER_ID, myUserId)
-                                .findFirst()?.let { mapCrossSigningInfoEntity(it) }
-                    }
-                }
+                // check right now my keys and mark it as trusted as other trust depends on it
+                val myDevices = cryptoRealm.where<UserEntity>()
+                        .equalTo(UserEntityFields.USER_ID, myUserId)
+                        .findFirst()
+                        ?.devices
+                        ?.map { CryptoMapper.mapToModel(it) }
 
-                val otherInfos = userList.map {
-                    it to realm.where(CrossSigningInfoEntity::class.java)
-                            .equalTo(CrossSigningInfoEntityFields.USER_ID, it)
-                            .findFirst()?.let { mapCrossSigningInfoEntity(it) }
-                }
-                        .toMap()
+                myTrustResult = crossSigningService.checkSelfTrust(myCrossSigningInfo, myDevices)
+                updateCrossSigningKeysTrust(cryptoRealm, myUserId, myTrustResult.isVerified())
+                // update model reference
+                myCrossSigningInfo = getCrossSigningInfo(cryptoRealm, myUserId)
+            }
 
-                val trusts = otherInfos.map { infoEntry ->
-                    infoEntry.key to when (infoEntry.key) {
-                        myUserId -> myTrustResult
-                        else     -> {
-                            crossSigningService.checkOtherMSKTrusted(myCrossSigningInfo, infoEntry.value).also {
-                                Timber.d("## CrossSigning - user:${infoEntry.key} result:$it")
-                            }
+            val otherInfos = userList.associateWith { userId ->
+                getCrossSigningInfo(cryptoRealm, userId)
+            }
+
+            val trusts = otherInfos.mapValues { entry ->
+                when (entry.key) {
+                    myUserId -> myTrustResult
+                    else     -> {
+                        crossSigningService.checkOtherMSKTrusted(myCrossSigningInfo, entry.value).also {
+                            Timber.d("## CrossSigning - user:${entry.key} result:$it")
                         }
                     }
-                }.toMap()
+                }
+            }
 
-                // TODO! if it's me and my keys has changed... I have to reset trust for everyone!
-                // i have all the new trusts, update DB
-                trusts.forEach {
-                    val verified = it.value?.isVerified() == true
-                    updateCrossSigningKeysTrust(realm, it.key, verified)
+            // TODO! if it's me and my keys has changed... I have to reset trust for everyone!
+            // i have all the new trusts, update DB
+            trusts.forEach {
+                val verified = it.value?.isVerified() == true
+                updateCrossSigningKeysTrust(cryptoRealm, it.key, verified)
+            }
+
+            // Ok so now we have to check device trust for all these users..
+            Timber.v("## CrossSigning - Updating devices cross trust users: ${trusts.keys.logLimit()}")
+            trusts.keys.forEach { userId ->
+                val devicesEntities = cryptoRealm.where<UserEntity>()
+                        .equalTo(UserEntityFields.USER_ID, userId)
+                        .findFirst()
+                        ?.devices
+
+                val trustMap = devicesEntities?.associateWith { device ->
+                    // get up to date from DB has could have been updated
+                    val otherInfo = getCrossSigningInfo(cryptoRealm, userId)
+                    crossSigningService.checkDeviceTrust(myCrossSigningInfo, otherInfo, CryptoMapper.mapToModel(device))
                 }
 
-                // Ok so now we have to check device trust for all these users..
-                Timber.v("## CrossSigning - Updating devices cross trust users ${trusts.keys}")
-                trusts.keys.forEach {
-                    val devicesEntities = realm.where<UserEntity>()
-                            .equalTo(UserEntityFields.USER_ID, it)
-                            .findFirst()
-                            ?.devices
-
-                    val trustMap = devicesEntities?.map { device ->
-                        // get up to date from DB has could have been updated
-                        val otherInfo = realm.where(CrossSigningInfoEntity::class.java)
-                                .equalTo(CrossSigningInfoEntityFields.USER_ID, it)
-                                .findFirst()?.let { mapCrossSigningInfoEntity(it) }
-                        device to crossSigningService.checkDeviceTrust(myCrossSigningInfo, otherInfo, CryptoMapper.mapToModel(device))
-                    }?.toMap()
-
-                    // Update trust if needed
-                    devicesEntities?.forEach { device ->
-                        val crossSignedVerified = trustMap?.get(device)?.isCrossSignedVerified()
-                        Timber.d("## CrossSigning - Trust for ${device.userId}|${device.deviceId} : cross verified: ${trustMap?.get(device)}")
-                        if (device.trustLevelEntity?.crossSignedVerified != crossSignedVerified) {
-                            Timber.d("## CrossSigning - Trust change detected for ${device.userId}|${device.deviceId} : cross verified: $crossSignedVerified")
-                            // need to save
-                            val trustEntity = device.trustLevelEntity
-                            if (trustEntity == null) {
-                                realm.createObject(TrustLevelEntity::class.java).let {
-                                    it.locallyVerified = false
-                                    it.crossSignedVerified = crossSignedVerified
-                                    device.trustLevelEntity = it
-                                }
-                            } else {
-                                trustEntity.crossSignedVerified = crossSignedVerified
+                // Update trust if needed
+                devicesEntities?.forEach { device ->
+                    val crossSignedVerified = trustMap?.get(device)?.isCrossSignedVerified()
+                    Timber.d("## CrossSigning - Trust for ${device.userId}|${device.deviceId} : cross verified: ${trustMap?.get(device)}")
+                    if (device.trustLevelEntity?.crossSignedVerified != crossSignedVerified) {
+                        Timber.d("## CrossSigning - Trust change detected for ${device.userId}|${device.deviceId} : cross verified: $crossSignedVerified")
+                        // need to save
+                        val trustEntity = device.trustLevelEntity
+                        if (trustEntity == null) {
+                            device.trustLevelEntity = cryptoRealm.createObject(TrustLevelEntity::class.java).also {
+                                it.locallyVerified = false
+                                it.crossSignedVerified = crossSignedVerified
                             }
+                        } else {
+                            trustEntity.crossSignedVerified = crossSignedVerified
                         }
                     }
                 }
@@ -201,35 +203,44 @@ internal class UpdateTrustWorker(context: Context,
         // We can now update room shields? in the session DB?
 
         Timber.d("## CrossSigning - Updating shields for impacted rooms...")
-        Realm.getInstance(sessionRealmConfiguration).use { it ->
-            it.executeTransaction { realm ->
-                val distinctRoomIds = realm.where(RoomMemberSummaryEntity::class.java)
-                        .`in`(RoomMemberSummaryEntityFields.USER_ID, userList.toTypedArray())
-                        .distinct(RoomMemberSummaryEntityFields.ROOM_ID)
-                        .findAll()
-                        .map { it.roomId }
-                Timber.d("## CrossSigning -  ... impacted rooms $distinctRoomIds")
-                distinctRoomIds.forEach { roomId ->
-                    val roomSummary = RoomSummaryEntity.where(realm, roomId).findFirst()
-                    if (roomSummary?.isEncrypted == true) {
-                        Timber.d("## CrossSigning - Check shield state for room $roomId")
-                        val allActiveRoomMembers = RoomMemberHelper(realm, roomId).getActiveRoomMemberIds()
-                        try {
-                            val updatedTrust = computeRoomShield(allActiveRoomMembers, roomSummary)
-                            if (roomSummary.roomEncryptionTrustLevel != updatedTrust) {
-                                Timber.d("## CrossSigning - Shield change detected for $roomId -> $updatedTrust")
-                                roomSummary.roomEncryptionTrustLevel = updatedTrust
-                            }
-                        } catch (failure: Throwable) {
-                            Timber.e(failure)
-                        }
+        sRealm.executeTransaction { sessionRealm ->
+            sessionRealm.where(RoomMemberSummaryEntity::class.java)
+                    .`in`(RoomMemberSummaryEntityFields.USER_ID, userList.toTypedArray())
+                    .distinct(RoomMemberSummaryEntityFields.ROOM_ID)
+                    .findAll()
+                    .map { it.roomId }
+                    .also { Timber.d("## CrossSigning -  ... impacted rooms ${it.logLimit()}") }
+                    .forEach { roomId ->
+                        RoomSummaryEntity.where(sessionRealm, roomId)
+                                .equalTo(RoomSummaryEntityFields.IS_ENCRYPTED, true)
+                                .findFirst()
+                                ?.let { roomSummary ->
+                                    Timber.d("## CrossSigning - Check shield state for room $roomId")
+                                    val allActiveRoomMembers = RoomMemberHelper(sessionRealm, roomId).getActiveRoomMemberIds()
+                                    try {
+                                        val updatedTrust = computeRoomShield(
+                                                myCrossSigningInfo,
+                                                cRealm,
+                                                allActiveRoomMembers,
+                                                roomSummary
+                                        )
+                                        if (roomSummary.roomEncryptionTrustLevel != updatedTrust) {
+                                            Timber.d("## CrossSigning - Shield change detected for $roomId -> $updatedTrust")
+                                            roomSummary.roomEncryptionTrustLevel = updatedTrust
+                                        }
+                                    } catch (failure: Throwable) {
+                                        Timber.e(failure)
+                                    }
+                                }
                     }
-                }
-            }
         }
+    }
 
-        cleanup(params)
-        return Result.success()
+    private fun getCrossSigningInfo(cryptoRealm: Realm, userId: String): MXCrossSigningInfo? {
+        return cryptoRealm.where(CrossSigningInfoEntity::class.java)
+                .equalTo(CrossSigningInfoEntityFields.USER_ID, userId)
+                .findFirst()
+                ?.let { mapCrossSigningInfoEntity(it) }
     }
 
     private fun cleanup(params: Params) {
@@ -237,30 +248,34 @@ internal class UpdateTrustWorker(context: Context,
                 ?.let { updateTrustWorkerDataRepository.delete(it) }
     }
 
-    private fun updateCrossSigningKeysTrust(realm: Realm, userId: String, verified: Boolean) {
-        val xInfoEntity = realm.where(CrossSigningInfoEntity::class.java)
+    private fun updateCrossSigningKeysTrust(cryptoRealm: Realm, userId: String, verified: Boolean) {
+        cryptoRealm.where(CrossSigningInfoEntity::class.java)
                 .equalTo(CrossSigningInfoEntityFields.USER_ID, userId)
                 .findFirst()
-        xInfoEntity?.crossSigningKeys?.forEach { info ->
-            // optimization to avoid trigger updates when there is no change..
-            if (info.trustLevelEntity?.isVerified() != verified) {
-                Timber.d("## CrossSigning - Trust change for $userId : $verified")
-                val level = info.trustLevelEntity
-                if (level == null) {
-                    val newLevel = realm.createObject(TrustLevelEntity::class.java)
-                    newLevel.locallyVerified = verified
-                    newLevel.crossSignedVerified = verified
-                    info.trustLevelEntity = newLevel
-                } else {
-                    level.locallyVerified = verified
-                    level.crossSignedVerified = verified
+                ?.crossSigningKeys
+                ?.forEach { info ->
+                    // optimization to avoid trigger updates when there is no change..
+                    if (info.trustLevelEntity?.isVerified() != verified) {
+                        Timber.d("## CrossSigning - Trust change for $userId : $verified")
+                        val level = info.trustLevelEntity
+                        if (level == null) {
+                            info.trustLevelEntity = cryptoRealm.createObject(TrustLevelEntity::class.java).also {
+                                it.locallyVerified = verified
+                                it.crossSignedVerified = verified
+                            }
+                        } else {
+                            level.locallyVerified = verified
+                            level.crossSignedVerified = verified
+                        }
+                    }
                 }
-            }
-        }
     }
 
-    private fun computeRoomShield(activeMemberUserIds: List<String>, roomSummaryEntity: RoomSummaryEntity): RoomEncryptionTrustLevel {
-        Timber.d("## CrossSigning - computeRoomShield ${roomSummaryEntity.roomId} -> $activeMemberUserIds")
+    private fun computeRoomShield(myCrossSigningInfo: MXCrossSigningInfo?,
+                                  cryptoRealm: Realm,
+                                  activeMemberUserIds: List<String>,
+                                  roomSummaryEntity: RoomSummaryEntity): RoomEncryptionTrustLevel {
+        Timber.d("## CrossSigning - computeRoomShield ${roomSummaryEntity.roomId} -> ${activeMemberUserIds.logLimit()}")
         // The set of “all users” depends on the type of room:
         // For regular / topic rooms which have more than 2 members (including yourself) are considered when decorating a room
         // For 1:1 and group DM rooms, all other users (i.e. excluding yourself) are considered when decorating a room
@@ -272,17 +287,8 @@ internal class UpdateTrustWorker(context: Context,
 
         val allTrustedUserIds = listToCheck
                 .filter { userId ->
-                    Realm.getInstance(realmConfiguration).use {
-                        it.where(CrossSigningInfoEntity::class.java)
-                                .equalTo(CrossSigningInfoEntityFields.USER_ID, userId)
-                                .findFirst()?.let { mapCrossSigningInfoEntity(it) }?.isTrusted() == true
-                    }
+                    getCrossSigningInfo(cryptoRealm, userId)?.isTrusted() == true
                 }
-        val myCrossKeys = Realm.getInstance(realmConfiguration).use {
-            it.where(CrossSigningInfoEntity::class.java)
-                    .equalTo(CrossSigningInfoEntityFields.USER_ID, myUserId)
-                    .findFirst()?.let { mapCrossSigningInfoEntity(it) }
-        }
 
         return if (allTrustedUserIds.isEmpty()) {
             RoomEncryptionTrustLevel.Default
@@ -291,21 +297,17 @@ internal class UpdateTrustWorker(context: Context,
             // If all devices of all verified users are trusted -> green
             // else -> black
             allTrustedUserIds
-                    .mapNotNull { uid ->
-                        Realm.getInstance(realmConfiguration).use {
-                            it.where<UserEntity>()
-                                    .equalTo(UserEntityFields.USER_ID, uid)
-                                    .findFirst()
-                                    ?.devices
-                                    ?.map {
-                                        CryptoMapper.mapToModel(it)
-                                    }
-                        }
+                    .mapNotNull { userId ->
+                        cryptoRealm.where<UserEntity>()
+                                .equalTo(UserEntityFields.USER_ID, userId)
+                                .findFirst()
+                                ?.devices
+                                ?.map { CryptoMapper.mapToModel(it) }
                     }
                     .flatten()
                     .let { allDevices ->
-                        Timber.v("## CrossSigning - computeRoomShield ${roomSummaryEntity.roomId} devices ${allDevices.map { it.deviceId }}")
-                        if (myCrossKeys != null) {
+                        Timber.v("## CrossSigning - computeRoomShield ${roomSummaryEntity.roomId} devices ${allDevices.map { it.deviceId }.logLimit()}")
+                        if (myCrossSigningInfo != null) {
                             allDevices.any { !it.trustLevel?.crossSigningVerified.orFalse() }
                         } else {
                             // Legacy method
diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/util/LogUtil.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/util/LogUtil.kt
index fe68b49a5c..bfa723c160 100644
--- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/util/LogUtil.kt
+++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/util/LogUtil.kt
@@ -19,6 +19,18 @@ package org.matrix.android.sdk.internal.util
 import org.matrix.android.sdk.BuildConfig
 import timber.log.Timber
 
+internal fun <T> Collection<T>.logLimit(maxQuantity: Int = 5): String {
+    return buildString {
+        append(size)
+        append(" item(s)")
+        if (size > maxQuantity) {
+            append(", first $maxQuantity items")
+        }
+        append(": ")
+        append(this@logLimit.take(maxQuantity))
+    }
+}
+
 internal suspend fun <T> logDuration(message: String,
                                      block: suspend () -> T): T {
     Timber.v("$message -- BEGIN")