From 604c3932cdba38dc86d0d5c28a249d3c915a4c2c Mon Sep 17 00:00:00 2001
From: valere <valeref@matrix.org>
Date: Fri, 3 Feb 2023 15:38:16 +0100
Subject: [PATCH] Flow collector causing strange NPE in some occasions

---
 .../sdk/internal/crypto/FlowCollectors.kt     | 113 ++++++++++++++++++
 .../android/sdk/internal/crypto/OlmMachine.kt |  45 +++----
 2 files changed, 128 insertions(+), 30 deletions(-)
 create mode 100644 matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/FlowCollectors.kt

diff --git a/matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/FlowCollectors.kt b/matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/FlowCollectors.kt
new file mode 100644
index 0000000000..391c0a2ae7
--- /dev/null
+++ b/matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/FlowCollectors.kt
@@ -0,0 +1,113 @@
+/*
+ * Copyright 2023 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.matrix.android.sdk.internal.crypto
+
+import kotlinx.coroutines.channels.SendChannel
+import kotlinx.coroutines.runBlocking
+import kotlinx.coroutines.sync.Mutex
+import kotlinx.coroutines.sync.withLock
+import org.matrix.android.sdk.api.session.crypto.crosssigning.MXCrossSigningInfo
+import org.matrix.android.sdk.api.session.crypto.crosssigning.PrivateKeysInfo
+import org.matrix.android.sdk.api.session.crypto.model.CryptoDeviceInfo
+import org.matrix.android.sdk.api.util.Optional
+
+internal data class UserIdentityCollector(val userId: String, val collector: SendChannel<Optional<MXCrossSigningInfo>>) :
+        SendChannel<Optional<MXCrossSigningInfo>> by collector
+
+internal data class DevicesCollector(val userIds: List<String>, val collector: SendChannel<List<CryptoDeviceInfo>>) :
+        SendChannel<List<CryptoDeviceInfo>> by collector
+
+private typealias PrivateKeysCollector = SendChannel<Optional<PrivateKeysInfo>>
+
+internal class FlowCollectors {
+    private val userIdentityCollectors = mutableListOf<UserIdentityCollector>()
+    private val privateKeyCollectors = mutableListOf<PrivateKeysCollector>()
+    private val deviceCollectors = ArrayList<DevicesCollector>()
+
+    private val identityLock = Mutex()
+    private val keysLock = Mutex()
+    private val deviceLock = Mutex()
+
+    suspend fun addIdentityCollector(collector: UserIdentityCollector) {
+        identityLock.withLock {
+            userIdentityCollectors.add(collector)
+        }
+    }
+
+    fun removeIdentityCollector(collector: UserIdentityCollector) {
+        // Annoying but it's called when the channel is closed and can't call
+        // something suspendable there :/
+        runBlocking {
+            identityLock.withLock {
+                userIdentityCollectors.remove(collector)
+            }
+        }
+    }
+
+    suspend fun forEachIdentityCollector(block: suspend ((UserIdentityCollector) -> Unit)) {
+        val safeCopy = identityLock.withLock {
+            userIdentityCollectors.toList()
+        }
+        safeCopy.onEach { block(it) }
+    }
+
+    suspend fun addPrivateKeysCollector(collector: PrivateKeysCollector) {
+        keysLock.withLock {
+            privateKeyCollectors.add(collector)
+        }
+    }
+
+    fun removePrivateKeysCollector(collector: PrivateKeysCollector) {
+        // Annoying but it's called when the channel is closed and can't call
+        // something suspendable there :/
+        runBlocking {
+            keysLock.withLock {
+                privateKeyCollectors.remove(collector)
+            }
+        }
+    }
+
+    suspend fun forEachPrivateKeysCollector(block: suspend ((PrivateKeysCollector) -> Unit)) {
+        val safeCopy = keysLock.withLock {
+            privateKeyCollectors.toList()
+        }
+        safeCopy.onEach { block(it) }
+    }
+
+    suspend fun addDevicesCollector(collector: DevicesCollector) {
+        deviceLock.withLock {
+            deviceCollectors.add(collector)
+        }
+    }
+
+    fun removeDevicesCollector(collector: DevicesCollector) {
+        // Annoying but it's called when the channel is closed and can't call
+        // something suspendable there :/
+        runBlocking {
+            deviceLock.withLock {
+                deviceCollectors.remove(collector)
+            }
+        }
+    }
+
+    suspend fun forEachDevicesCollector(block: suspend ((DevicesCollector) -> Unit)) {
+        val safeCopy = deviceLock.withLock {
+            deviceCollectors.toList()
+        }
+        safeCopy.onEach { block(it) }
+    }
+}
diff --git a/matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt b/matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt
index f5b9ec17a1..9c52ef9da5 100644
--- a/matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt
+++ b/matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt
@@ -20,8 +20,6 @@ import androidx.lifecycle.LiveData
 import androidx.lifecycle.asLiveData
 import com.squareup.moshi.Moshi
 import com.squareup.moshi.Types
-import com.squareup.moshi.adapter
-import kotlinx.coroutines.channels.SendChannel
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.channelFlow
 import kotlinx.coroutines.runBlocking
@@ -99,19 +97,6 @@ private class CryptoProgressListener(private val listener: ProgressListener?) :
     }
 }
 
-private data class UserIdentityCollector(val userId: String, val collector: SendChannel<Optional<MXCrossSigningInfo>>) :
-        SendChannel<Optional<MXCrossSigningInfo>> by collector
-
-private data class DevicesCollector(val userIds: List<String>, val collector: SendChannel<List<CryptoDeviceInfo>>) :
-        SendChannel<List<CryptoDeviceInfo>> by collector
-private typealias PrivateKeysCollector = SendChannel<Optional<PrivateKeysInfo>>
-
-private class FlowCollectors {
-    val userIdentityCollectors = ArrayList<UserIdentityCollector>()
-    val privateKeyCollectors = ArrayList<PrivateKeysCollector>()
-    val deviceCollectors = ArrayList<DevicesCollector>()
-}
-
 fun setRustLogger() {
     setLogger(CryptoLogger() as Logger)
 }
@@ -130,7 +115,7 @@ internal class OlmMachine @Inject constructor(
         private val ensureUsersKeys: EnsureUsersKeysUseCase,
         private val matrixConfiguration: MatrixConfiguration,
         private val megolmSessionImportManager: MegolmSessionImportManager,
-        private val rustEncryptionConfiguration: RustEncryptionConfiguration,
+        rustEncryptionConfiguration: RustEncryptionConfiguration,
 ) {
 
     private val inner: InnerMachine
@@ -165,23 +150,23 @@ internal class OlmMachine @Inject constructor(
     }
 
     private suspend fun updateLiveDevices() {
-        for (deviceCollector in flowCollectors.deviceCollectors) {
-            val devices = getCryptoDeviceInfo(deviceCollector.userIds)
-            deviceCollector.trySend(devices)
+        flowCollectors.forEachDevicesCollector {
+            val devices = getCryptoDeviceInfo(it.userIds)
+            it.trySend(devices)
         }
     }
 
     private suspend fun updateLiveUserIdentities() {
-        for (userIdentityCollector in flowCollectors.userIdentityCollectors) {
-            val identity = getIdentity(userIdentityCollector.userId)?.toMxCrossSigningInfo()
-            userIdentityCollector.trySend(identity.toOptional())
+        flowCollectors.forEachIdentityCollector {
+            val identity = getIdentity(it.userId)?.toMxCrossSigningInfo().toOptional()
+            it.trySend(identity)
         }
     }
 
     private suspend fun updateLivePrivateKeys() {
         val keys = exportCrossSigningKeys().toOptional()
-        for (privateKeyCollector in flowCollectors.privateKeyCollectors) {
-            privateKeyCollector.trySend(keys)
+        flowCollectors.forEachPrivateKeysCollector {
+            it.trySend(keys)
         }
     }
 
@@ -699,9 +684,9 @@ internal class OlmMachine @Inject constructor(
         return channelFlow {
             val userIdentityCollector = UserIdentityCollector(userId, this)
             val onClose = safeInvokeOnClose {
-                flowCollectors.userIdentityCollectors.remove(userIdentityCollector)
+                flowCollectors.removeIdentityCollector(userIdentityCollector)
             }
-            flowCollectors.userIdentityCollectors.add(userIdentityCollector)
+            flowCollectors.addIdentityCollector(userIdentityCollector)
             val identity = getIdentity(userId)?.toMxCrossSigningInfo().toOptional()
             send(identity)
             onClose.await()
@@ -719,9 +704,9 @@ internal class OlmMachine @Inject constructor(
     fun getPrivateCrossSigningKeysFlow(): Flow<Optional<PrivateKeysInfo>> {
         return channelFlow {
             val onClose = safeInvokeOnClose {
-                flowCollectors.privateKeyCollectors.remove(this)
+                flowCollectors.removePrivateKeysCollector(this)
             }
-            flowCollectors.privateKeyCollectors.add(this)
+            flowCollectors.addPrivateKeysCollector(this)
             val keys = this@OlmMachine.exportCrossSigningKeys().toOptional()
             send(keys)
             onClose.await()
@@ -746,9 +731,9 @@ internal class OlmMachine @Inject constructor(
         return channelFlow {
             val devicesCollector = DevicesCollector(userIds, this)
             val onClose = safeInvokeOnClose {
-                flowCollectors.deviceCollectors.remove(devicesCollector)
+                flowCollectors.removeDevicesCollector(devicesCollector)
             }
-            flowCollectors.deviceCollectors.add(devicesCollector)
+            flowCollectors.addDevicesCollector(devicesCollector)
             val devices = getCryptoDeviceInfo(userIds)
             send(devices)
             onClose.await()