diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt index 6dfdc3cfcb..aa6a1abc7d 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt @@ -16,13 +16,13 @@ package org.matrix.android.sdk.internal.crypto -import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.channelFlow import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withContext +import org.matrix.android.sdk.api.MatrixCoroutineDispatchers import org.matrix.android.sdk.api.auth.UserInteractiveAuthInterceptor import org.matrix.android.sdk.api.extensions.tryOrNull import org.matrix.android.sdk.api.listeners.ProgressListener @@ -102,6 +102,7 @@ internal class OlmMachine( device_id: String, path: File, private val requestSender: RequestSender, + private val coroutineDispatchers: MatrixCoroutineDispatchers ) { private val inner: InnerMachine = InnerMachine(user_id, device_id, path.toString()) internal val verificationListeners = ArrayList() @@ -182,7 +183,7 @@ internal class OlmMachine( * @return the list of requests that needs to be sent to the homeserver */ suspend fun outgoingRequests(): List = - withContext(Dispatchers.IO) { inner.outgoingRequests() } + withContext(coroutineDispatchers.io) { inner.outgoingRequests() } /** * Mark a request that was sent to the server as sent. @@ -199,7 +200,7 @@ internal class OlmMachine( requestType: RequestType, responseBody: String ) = - withContext(Dispatchers.IO) { + withContext(coroutineDispatchers.io) { inner.markRequestAsSent(requestId, requestType, responseBody) if (requestType == RequestType.KEYS_QUERY) { @@ -227,7 +228,7 @@ internal class OlmMachine( deviceChanges: DeviceListResponse?, keyCounts: DeviceOneTimeKeysCountSyncResponse? ): ToDeviceSyncResponse { - val response = withContext(Dispatchers.IO) { + val response = withContext(coroutineDispatchers.io) { val counts: MutableMap = mutableMapOf() if (keyCounts?.signedCurve25519 != null) { @@ -260,7 +261,7 @@ internal class OlmMachine( * @param users The users that should be queued up for a key query. */ suspend fun updateTrackedUsers(users: List) = - withContext(Dispatchers.IO) { inner.updateTrackedUsers(users) } + withContext(coroutineDispatchers.io) { inner.updateTrackedUsers(users) } /** * Check if the given user is considered to be tracked. @@ -286,7 +287,7 @@ internal class OlmMachine( */ @Throws(CryptoStoreException::class) suspend fun getMissingSessions(users: List): Request? = - withContext(Dispatchers.IO) { inner.getMissingSessions(users) } + withContext(coroutineDispatchers.io) { inner.getMissingSessions(users) } /** * Share a room key with the given list of users for the given room. @@ -308,7 +309,7 @@ internal class OlmMachine( */ @Throws(CryptoStoreException::class) suspend fun shareRoomKey(roomId: String, users: List): List = - withContext(Dispatchers.IO) { inner.shareRoomKey(roomId, users) } + withContext(coroutineDispatchers.io) { inner.shareRoomKey(roomId, users) } /** * Encrypt the given event with the given type and content for the given room. @@ -342,7 +343,7 @@ internal class OlmMachine( */ @Throws(CryptoStoreException::class) suspend fun encrypt(roomId: String, eventType: String, content: Content): Content = - withContext(Dispatchers.IO) { + withContext(coroutineDispatchers.io) { val adapter = MoshiProvider.providesMoshi().adapter(Map::class.java) val contentString = adapter.toJson(content) val encrypted = inner.encrypt(roomId, eventType, contentString) @@ -360,7 +361,7 @@ internal class OlmMachine( */ @Throws(MXCryptoError::class) suspend fun decryptRoomEvent(event: Event): MXEventDecryptionResult = - withContext(Dispatchers.IO) { + withContext(coroutineDispatchers.io) { val adapter = MoshiProvider.providesMoshi().adapter(Event::class.java) try { if (event.roomId.isNullOrBlank()) { @@ -400,7 +401,7 @@ internal class OlmMachine( */ @Throws(DecryptionException::class) suspend fun requestRoomKey(event: Event): KeyRequestPair = - withContext(Dispatchers.IO) { + withContext(coroutineDispatchers.io) { val adapter = MoshiProvider.providesMoshi().adapter(Event::class.java) val serializedEvent = adapter.toJson(event) @@ -419,7 +420,7 @@ internal class OlmMachine( */ @Throws(CryptoStoreException::class) suspend fun exportKeys(passphrase: String, rounds: Int): ByteArray = - withContext(Dispatchers.IO) { inner.exportKeys(passphrase, rounds).toByteArray() } + withContext(coroutineDispatchers.io) { inner.exportKeys(passphrase, rounds).toByteArray() } /** * Import room keys from the given serialized key export. @@ -436,7 +437,7 @@ internal class OlmMachine( passphrase: String, listener: ProgressListener? ): ImportRoomKeysResult = - withContext(Dispatchers.IO) { + withContext(coroutineDispatchers.io) { val decodedKeys = String(keys, Charset.defaultCharset()) val rustListener = CryptoProgressListener(listener) @@ -451,7 +452,7 @@ internal class OlmMachine( keys: List, listener: ProgressListener? ): ImportRoomKeysResult = - withContext(Dispatchers.IO) { + withContext(coroutineDispatchers.io) { val adapter = MoshiProvider.providesMoshi().adapter(List::class.java) // If the key backup is too big we take the risk of causing OOM @@ -481,7 +482,7 @@ internal class OlmMachine( @Throws(CryptoStoreException::class) suspend fun getIdentity(userId: String): UserIdentities? { - val identity = withContext(Dispatchers.IO) { + val identity = withContext(coroutineDispatchers.io) { inner.getIdentity(userId) } val adapter = MoshiProvider.providesMoshi().adapter(RestKeyInfo::class.java) @@ -547,7 +548,7 @@ internal class OlmMachine( @Throws(CryptoStoreException::class) suspend fun getDevice(userId: String, deviceId: String): Device? { - val device = withContext(Dispatchers.IO) { + val device = withContext(coroutineDispatchers.io) { inner.getDevice(userId, deviceId) } ?: return null @@ -555,7 +556,7 @@ internal class OlmMachine( } suspend fun getUserDevices(userId: String): List { - return withContext(Dispatchers.IO) { + return withContext(coroutineDispatchers.io) { inner.getUserDevices(userId).map { Device(inner, it, requestSender, verificationListeners) } } } @@ -600,7 +601,7 @@ internal class OlmMachine( @Throws suspend fun forceKeyDownload(userIds: List) { - withContext(Dispatchers.IO) { + withContext(coroutineDispatchers.io) { val requestId = UUID.randomUUID().toString() val response = requestSender.queryKeys(Request.KeysQuery(requestId, userIds)) markRequestAsSent(requestId, RequestType.KEYS_QUERY, response) @@ -759,7 +760,7 @@ internal class OlmMachine( } suspend fun bootstrapCrossSigning(uiaInterceptor: UserInteractiveAuthInterceptor?) { - val requests = withContext(Dispatchers.IO) { + val requests = withContext(coroutineDispatchers.io) { inner.bootstrapCrossSigning() } @@ -775,7 +776,7 @@ internal class OlmMachine( } suspend fun exportCrossSigningKeys(): PrivateKeysInfo? { - val export = withContext(Dispatchers.IO) { + val export = withContext(coroutineDispatchers.io) { inner.exportCrossSigningKeys() } ?: return null @@ -786,7 +787,7 @@ internal class OlmMachine( val rustExport = CrossSigningKeyExport(export.master, export.selfSigned, export.user) var result: UserTrustResult - withContext(Dispatchers.IO) { + withContext(coroutineDispatchers.io) { result = try { inner.importCrossSigningKeys(rustExport) @@ -803,21 +804,21 @@ internal class OlmMachine( UserTrustResult.Failure(failure.localizedMessage) } } - withContext(Dispatchers.Main) { + withContext(coroutineDispatchers.main) { this@OlmMachine.updateLivePrivateKeys() } return result } suspend fun sign(message: String): Map> { - return withContext(Dispatchers.Default) { + return withContext(coroutineDispatchers.computation) { inner.sign(message) } } @Throws(CryptoStoreException::class) suspend fun enableBackupV1(key: String, version: String) { - return withContext(Dispatchers.Default) { + return withContext(coroutineDispatchers.computation) { val backupKey = MegolmV1BackupKey(key, mapOf(), null, MXCRYPTO_ALGORITHM_MEGOLM_BACKUP) inner.enableBackupV1(backupKey, version) } @@ -834,28 +835,28 @@ internal class OlmMachine( @Throws(CryptoStoreException::class) suspend fun roomKeyCounts(): RoomKeyCounts { - return withContext(Dispatchers.Default) { + return withContext(coroutineDispatchers.computation) { inner.roomKeyCounts() } } @Throws(CryptoStoreException::class) suspend fun getBackupKeys(): BackupKeys? { - return withContext(Dispatchers.Default) { + return withContext(coroutineDispatchers.computation) { inner.getBackupKeys() } } @Throws(CryptoStoreException::class) suspend fun saveRecoveryKey(key: String?, version: String?) { - withContext(Dispatchers.Default) { + withContext(coroutineDispatchers.computation) { inner.saveRecoveryKey(key, version) } } @Throws(CryptoStoreException::class) suspend fun backupRoomKeys(): Request? { - return withContext(Dispatchers.Default) { + return withContext(coroutineDispatchers.computation) { Timber.d("BACKUP CREATING REQUEST") val request = inner.backupRoomKeys() Timber.d("BACKUP CREATED REQUEST: $request") @@ -865,7 +866,7 @@ internal class OlmMachine( @Throws(CryptoStoreException::class) suspend fun checkAuthDataSignature(authData: MegolmBackupAuthData): Boolean { - return withContext(Dispatchers.Default) { + return withContext(coroutineDispatchers.computation) { val adapter = MoshiProvider .providesMoshi() .newBuilder() diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmMachineProvider.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmMachineProvider.kt index 6aa59afc69..513a9297d8 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmMachineProvider.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmMachineProvider.kt @@ -16,6 +16,7 @@ package org.matrix.android.sdk.internal.crypto +import org.matrix.android.sdk.api.MatrixCoroutineDispatchers import org.matrix.android.sdk.internal.di.DeviceId import org.matrix.android.sdk.internal.di.SessionFilesDirectory import org.matrix.android.sdk.internal.di.UserId @@ -28,8 +29,9 @@ internal class OlmMachineProvider @Inject constructor( @UserId private val userId: String, @DeviceId private val deviceId: String?, @SessionFilesDirectory private val dataDir: File, - requestSender: RequestSender + requestSender: RequestSender, + coroutineDispatchers: MatrixCoroutineDispatchers ) { - var olmMachine: OlmMachine = OlmMachine(userId, deviceId!!, dataDir, requestSender) + var olmMachine: OlmMachine = OlmMachine(userId, deviceId!!, dataDir, requestSender, coroutineDispatchers) }