Use MatrixCoroutineDispatchers in OlmMachine

This commit is contained in:
ganfra 2022-04-14 16:33:48 +02:00
parent 91daa1ab90
commit d020d1f6e0
2 changed files with 33 additions and 30 deletions

View file

@ -16,13 +16,13 @@
package org.matrix.android.sdk.internal.crypto package org.matrix.android.sdk.internal.crypto
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import org.matrix.android.sdk.api.MatrixCoroutineDispatchers
import org.matrix.android.sdk.api.auth.UserInteractiveAuthInterceptor import org.matrix.android.sdk.api.auth.UserInteractiveAuthInterceptor
import org.matrix.android.sdk.api.extensions.tryOrNull import org.matrix.android.sdk.api.extensions.tryOrNull
import org.matrix.android.sdk.api.listeners.ProgressListener import org.matrix.android.sdk.api.listeners.ProgressListener
@ -102,6 +102,7 @@ internal class OlmMachine(
device_id: String, device_id: String,
path: File, path: File,
private val requestSender: RequestSender, private val requestSender: RequestSender,
private val coroutineDispatchers: MatrixCoroutineDispatchers
) { ) {
private val inner: InnerMachine = InnerMachine(user_id, device_id, path.toString()) private val inner: InnerMachine = InnerMachine(user_id, device_id, path.toString())
internal val verificationListeners = ArrayList<VerificationService.Listener>() internal val verificationListeners = ArrayList<VerificationService.Listener>()
@ -182,7 +183,7 @@ internal class OlmMachine(
* @return the list of requests that needs to be sent to the homeserver * @return the list of requests that needs to be sent to the homeserver
*/ */
suspend fun outgoingRequests(): List<Request> = suspend fun outgoingRequests(): List<Request> =
withContext(Dispatchers.IO) { inner.outgoingRequests() } withContext(coroutineDispatchers.io) { inner.outgoingRequests() }
/** /**
* Mark a request that was sent to the server as sent. * Mark a request that was sent to the server as sent.
@ -199,7 +200,7 @@ internal class OlmMachine(
requestType: RequestType, requestType: RequestType,
responseBody: String responseBody: String
) = ) =
withContext(Dispatchers.IO) { withContext(coroutineDispatchers.io) {
inner.markRequestAsSent(requestId, requestType, responseBody) inner.markRequestAsSent(requestId, requestType, responseBody)
if (requestType == RequestType.KEYS_QUERY) { if (requestType == RequestType.KEYS_QUERY) {
@ -227,7 +228,7 @@ internal class OlmMachine(
deviceChanges: DeviceListResponse?, deviceChanges: DeviceListResponse?,
keyCounts: DeviceOneTimeKeysCountSyncResponse? keyCounts: DeviceOneTimeKeysCountSyncResponse?
): ToDeviceSyncResponse { ): ToDeviceSyncResponse {
val response = withContext(Dispatchers.IO) { val response = withContext(coroutineDispatchers.io) {
val counts: MutableMap<String, Int> = mutableMapOf() val counts: MutableMap<String, Int> = mutableMapOf()
if (keyCounts?.signedCurve25519 != null) { if (keyCounts?.signedCurve25519 != null) {
@ -260,7 +261,7 @@ internal class OlmMachine(
* @param users The users that should be queued up for a key query. * @param users The users that should be queued up for a key query.
*/ */
suspend fun updateTrackedUsers(users: List<String>) = suspend fun updateTrackedUsers(users: List<String>) =
withContext(Dispatchers.IO) { inner.updateTrackedUsers(users) } withContext(coroutineDispatchers.io) { inner.updateTrackedUsers(users) }
/** /**
* Check if the given user is considered to be tracked. * Check if the given user is considered to be tracked.
@ -286,7 +287,7 @@ internal class OlmMachine(
*/ */
@Throws(CryptoStoreException::class) @Throws(CryptoStoreException::class)
suspend fun getMissingSessions(users: List<String>): Request? = suspend fun getMissingSessions(users: List<String>): Request? =
withContext(Dispatchers.IO) { inner.getMissingSessions(users) } withContext(coroutineDispatchers.io) { inner.getMissingSessions(users) }
/** /**
* Share a room key with the given list of users for the given room. * Share a room key with the given list of users for the given room.
@ -308,7 +309,7 @@ internal class OlmMachine(
*/ */
@Throws(CryptoStoreException::class) @Throws(CryptoStoreException::class)
suspend fun shareRoomKey(roomId: String, users: List<String>): List<Request> = suspend fun shareRoomKey(roomId: String, users: List<String>): List<Request> =
withContext(Dispatchers.IO) { inner.shareRoomKey(roomId, users) } withContext(coroutineDispatchers.io) { inner.shareRoomKey(roomId, users) }
/** /**
* Encrypt the given event with the given type and content for the given room. * Encrypt the given event with the given type and content for the given room.
@ -342,7 +343,7 @@ internal class OlmMachine(
*/ */
@Throws(CryptoStoreException::class) @Throws(CryptoStoreException::class)
suspend fun encrypt(roomId: String, eventType: String, content: Content): Content = suspend fun encrypt(roomId: String, eventType: String, content: Content): Content =
withContext(Dispatchers.IO) { withContext(coroutineDispatchers.io) {
val adapter = MoshiProvider.providesMoshi().adapter<Content>(Map::class.java) val adapter = MoshiProvider.providesMoshi().adapter<Content>(Map::class.java)
val contentString = adapter.toJson(content) val contentString = adapter.toJson(content)
val encrypted = inner.encrypt(roomId, eventType, contentString) val encrypted = inner.encrypt(roomId, eventType, contentString)
@ -360,7 +361,7 @@ internal class OlmMachine(
*/ */
@Throws(MXCryptoError::class) @Throws(MXCryptoError::class)
suspend fun decryptRoomEvent(event: Event): MXEventDecryptionResult = suspend fun decryptRoomEvent(event: Event): MXEventDecryptionResult =
withContext(Dispatchers.IO) { withContext(coroutineDispatchers.io) {
val adapter = MoshiProvider.providesMoshi().adapter(Event::class.java) val adapter = MoshiProvider.providesMoshi().adapter(Event::class.java)
try { try {
if (event.roomId.isNullOrBlank()) { if (event.roomId.isNullOrBlank()) {
@ -400,7 +401,7 @@ internal class OlmMachine(
*/ */
@Throws(DecryptionException::class) @Throws(DecryptionException::class)
suspend fun requestRoomKey(event: Event): KeyRequestPair = suspend fun requestRoomKey(event: Event): KeyRequestPair =
withContext(Dispatchers.IO) { withContext(coroutineDispatchers.io) {
val adapter = MoshiProvider.providesMoshi().adapter(Event::class.java) val adapter = MoshiProvider.providesMoshi().adapter(Event::class.java)
val serializedEvent = adapter.toJson(event) val serializedEvent = adapter.toJson(event)
@ -419,7 +420,7 @@ internal class OlmMachine(
*/ */
@Throws(CryptoStoreException::class) @Throws(CryptoStoreException::class)
suspend fun exportKeys(passphrase: String, rounds: Int): ByteArray = suspend fun exportKeys(passphrase: String, rounds: Int): ByteArray =
withContext(Dispatchers.IO) { inner.exportKeys(passphrase, rounds).toByteArray() } withContext(coroutineDispatchers.io) { inner.exportKeys(passphrase, rounds).toByteArray() }
/** /**
* Import room keys from the given serialized key export. * Import room keys from the given serialized key export.
@ -436,7 +437,7 @@ internal class OlmMachine(
passphrase: String, passphrase: String,
listener: ProgressListener? listener: ProgressListener?
): ImportRoomKeysResult = ): ImportRoomKeysResult =
withContext(Dispatchers.IO) { withContext(coroutineDispatchers.io) {
val decodedKeys = String(keys, Charset.defaultCharset()) val decodedKeys = String(keys, Charset.defaultCharset())
val rustListener = CryptoProgressListener(listener) val rustListener = CryptoProgressListener(listener)
@ -451,7 +452,7 @@ internal class OlmMachine(
keys: List<MegolmSessionData>, keys: List<MegolmSessionData>,
listener: ProgressListener? listener: ProgressListener?
): ImportRoomKeysResult = ): ImportRoomKeysResult =
withContext(Dispatchers.IO) { withContext(coroutineDispatchers.io) {
val adapter = MoshiProvider.providesMoshi().adapter(List::class.java) val adapter = MoshiProvider.providesMoshi().adapter(List::class.java)
// If the key backup is too big we take the risk of causing OOM // If the key backup is too big we take the risk of causing OOM
@ -481,7 +482,7 @@ internal class OlmMachine(
@Throws(CryptoStoreException::class) @Throws(CryptoStoreException::class)
suspend fun getIdentity(userId: String): UserIdentities? { suspend fun getIdentity(userId: String): UserIdentities? {
val identity = withContext(Dispatchers.IO) { val identity = withContext(coroutineDispatchers.io) {
inner.getIdentity(userId) inner.getIdentity(userId)
} }
val adapter = MoshiProvider.providesMoshi().adapter(RestKeyInfo::class.java) val adapter = MoshiProvider.providesMoshi().adapter(RestKeyInfo::class.java)
@ -547,7 +548,7 @@ internal class OlmMachine(
@Throws(CryptoStoreException::class) @Throws(CryptoStoreException::class)
suspend fun getDevice(userId: String, deviceId: String): Device? { suspend fun getDevice(userId: String, deviceId: String): Device? {
val device = withContext(Dispatchers.IO) { val device = withContext(coroutineDispatchers.io) {
inner.getDevice(userId, deviceId) inner.getDevice(userId, deviceId)
} ?: return null } ?: return null
@ -555,7 +556,7 @@ internal class OlmMachine(
} }
suspend fun getUserDevices(userId: String): List<Device> { suspend fun getUserDevices(userId: String): List<Device> {
return withContext(Dispatchers.IO) { return withContext(coroutineDispatchers.io) {
inner.getUserDevices(userId).map { Device(inner, it, requestSender, verificationListeners) } inner.getUserDevices(userId).map { Device(inner, it, requestSender, verificationListeners) }
} }
} }
@ -600,7 +601,7 @@ internal class OlmMachine(
@Throws @Throws
suspend fun forceKeyDownload(userIds: List<String>) { suspend fun forceKeyDownload(userIds: List<String>) {
withContext(Dispatchers.IO) { withContext(coroutineDispatchers.io) {
val requestId = UUID.randomUUID().toString() val requestId = UUID.randomUUID().toString()
val response = requestSender.queryKeys(Request.KeysQuery(requestId, userIds)) val response = requestSender.queryKeys(Request.KeysQuery(requestId, userIds))
markRequestAsSent(requestId, RequestType.KEYS_QUERY, response) markRequestAsSent(requestId, RequestType.KEYS_QUERY, response)
@ -759,7 +760,7 @@ internal class OlmMachine(
} }
suspend fun bootstrapCrossSigning(uiaInterceptor: UserInteractiveAuthInterceptor?) { suspend fun bootstrapCrossSigning(uiaInterceptor: UserInteractiveAuthInterceptor?) {
val requests = withContext(Dispatchers.IO) { val requests = withContext(coroutineDispatchers.io) {
inner.bootstrapCrossSigning() inner.bootstrapCrossSigning()
} }
@ -775,7 +776,7 @@ internal class OlmMachine(
} }
suspend fun exportCrossSigningKeys(): PrivateKeysInfo? { suspend fun exportCrossSigningKeys(): PrivateKeysInfo? {
val export = withContext(Dispatchers.IO) { val export = withContext(coroutineDispatchers.io) {
inner.exportCrossSigningKeys() inner.exportCrossSigningKeys()
} ?: return null } ?: return null
@ -786,7 +787,7 @@ internal class OlmMachine(
val rustExport = CrossSigningKeyExport(export.master, export.selfSigned, export.user) val rustExport = CrossSigningKeyExport(export.master, export.selfSigned, export.user)
var result: UserTrustResult var result: UserTrustResult
withContext(Dispatchers.IO) { withContext(coroutineDispatchers.io) {
result = try { result = try {
inner.importCrossSigningKeys(rustExport) inner.importCrossSigningKeys(rustExport)
@ -803,21 +804,21 @@ internal class OlmMachine(
UserTrustResult.Failure(failure.localizedMessage) UserTrustResult.Failure(failure.localizedMessage)
} }
} }
withContext(Dispatchers.Main) { withContext(coroutineDispatchers.main) {
this@OlmMachine.updateLivePrivateKeys() this@OlmMachine.updateLivePrivateKeys()
} }
return result return result
} }
suspend fun sign(message: String): Map<String, Map<String, String>> { suspend fun sign(message: String): Map<String, Map<String, String>> {
return withContext(Dispatchers.Default) { return withContext(coroutineDispatchers.computation) {
inner.sign(message) inner.sign(message)
} }
} }
@Throws(CryptoStoreException::class) @Throws(CryptoStoreException::class)
suspend fun enableBackupV1(key: String, version: String) { suspend fun enableBackupV1(key: String, version: String) {
return withContext(Dispatchers.Default) { return withContext(coroutineDispatchers.computation) {
val backupKey = MegolmV1BackupKey(key, mapOf(), null, MXCRYPTO_ALGORITHM_MEGOLM_BACKUP) val backupKey = MegolmV1BackupKey(key, mapOf(), null, MXCRYPTO_ALGORITHM_MEGOLM_BACKUP)
inner.enableBackupV1(backupKey, version) inner.enableBackupV1(backupKey, version)
} }
@ -834,28 +835,28 @@ internal class OlmMachine(
@Throws(CryptoStoreException::class) @Throws(CryptoStoreException::class)
suspend fun roomKeyCounts(): RoomKeyCounts { suspend fun roomKeyCounts(): RoomKeyCounts {
return withContext(Dispatchers.Default) { return withContext(coroutineDispatchers.computation) {
inner.roomKeyCounts() inner.roomKeyCounts()
} }
} }
@Throws(CryptoStoreException::class) @Throws(CryptoStoreException::class)
suspend fun getBackupKeys(): BackupKeys? { suspend fun getBackupKeys(): BackupKeys? {
return withContext(Dispatchers.Default) { return withContext(coroutineDispatchers.computation) {
inner.getBackupKeys() inner.getBackupKeys()
} }
} }
@Throws(CryptoStoreException::class) @Throws(CryptoStoreException::class)
suspend fun saveRecoveryKey(key: String?, version: String?) { suspend fun saveRecoveryKey(key: String?, version: String?) {
withContext(Dispatchers.Default) { withContext(coroutineDispatchers.computation) {
inner.saveRecoveryKey(key, version) inner.saveRecoveryKey(key, version)
} }
} }
@Throws(CryptoStoreException::class) @Throws(CryptoStoreException::class)
suspend fun backupRoomKeys(): Request? { suspend fun backupRoomKeys(): Request? {
return withContext(Dispatchers.Default) { return withContext(coroutineDispatchers.computation) {
Timber.d("BACKUP CREATING REQUEST") Timber.d("BACKUP CREATING REQUEST")
val request = inner.backupRoomKeys() val request = inner.backupRoomKeys()
Timber.d("BACKUP CREATED REQUEST: $request") Timber.d("BACKUP CREATED REQUEST: $request")
@ -865,7 +866,7 @@ internal class OlmMachine(
@Throws(CryptoStoreException::class) @Throws(CryptoStoreException::class)
suspend fun checkAuthDataSignature(authData: MegolmBackupAuthData): Boolean { suspend fun checkAuthDataSignature(authData: MegolmBackupAuthData): Boolean {
return withContext(Dispatchers.Default) { return withContext(coroutineDispatchers.computation) {
val adapter = MoshiProvider val adapter = MoshiProvider
.providesMoshi() .providesMoshi()
.newBuilder() .newBuilder()

View file

@ -16,6 +16,7 @@
package org.matrix.android.sdk.internal.crypto package org.matrix.android.sdk.internal.crypto
import org.matrix.android.sdk.api.MatrixCoroutineDispatchers
import org.matrix.android.sdk.internal.di.DeviceId import org.matrix.android.sdk.internal.di.DeviceId
import org.matrix.android.sdk.internal.di.SessionFilesDirectory import org.matrix.android.sdk.internal.di.SessionFilesDirectory
import org.matrix.android.sdk.internal.di.UserId import org.matrix.android.sdk.internal.di.UserId
@ -28,8 +29,9 @@ internal class OlmMachineProvider @Inject constructor(
@UserId private val userId: String, @UserId private val userId: String,
@DeviceId private val deviceId: String?, @DeviceId private val deviceId: String?,
@SessionFilesDirectory private val dataDir: File, @SessionFilesDirectory private val dataDir: File,
requestSender: RequestSender requestSender: RequestSender,
coroutineDispatchers: MatrixCoroutineDispatchers
) { ) {
var olmMachine: OlmMachine = OlmMachine(userId, deviceId!!, dataDir, requestSender) var olmMachine: OlmMachine = OlmMachine(userId, deviceId!!, dataDir, requestSender, coroutineDispatchers)
} }