Use mutex

This commit is contained in:
Hugh Nimmo-Smith 2022-10-18 08:48:28 +01:00
parent 8a62dfb34a
commit f297117df2

View file

@ -19,6 +19,8 @@ package org.matrix.android.sdk.api.rendezvous.channels
import android.util.Base64 import android.util.Base64
import com.squareup.moshi.Json import com.squareup.moshi.Json
import com.squareup.moshi.JsonClass import com.squareup.moshi.JsonClass
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import okhttp3.MediaType.Companion.toMediaType import okhttp3.MediaType.Companion.toMediaType
import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.api.rendezvous.RendezvousChannel import org.matrix.android.sdk.api.rendezvous.RendezvousChannel
@ -71,6 +73,7 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu
@Json val iv: String? = null @Json val iv: String? = null
) )
private var olmSASMutex = Mutex()
private var olmSAS: OlmSAS? private var olmSAS: OlmSAS?
private val ourPublicKey: ByteArray private val ourPublicKey: ByteArray
private val ecdhAdapter = MatrixJsonParser.getMoshi().adapter(ECDHPayload::class.java) private val ecdhAdapter = MatrixJsonParser.getMoshi().adapter(ECDHPayload::class.java)
@ -87,45 +90,44 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu
@Throws(RendezvousError::class) @Throws(RendezvousError::class)
override suspend fun connect(): String { override suspend fun connect(): String {
olmSAS ?.let { olmSAS -> val sas = olmSAS ?: throw RendezvousError("Channel closed", RendezvousFailureReason.Unknown)
val isInitiator = theirPublicKey == null val isInitiator = theirPublicKey == null
if (isInitiator) { if (isInitiator) {
Timber.tag(TAG).i("Waiting for other device to send their public key") Timber.tag(TAG).i("Waiting for other device to send their public key")
val res = this.receiveAsPayload() ?: throw RendezvousError("No reply from other device", RendezvousFailureReason.ProtocolError) val res = this.receiveAsPayload() ?: throw RendezvousError("No reply from other device", RendezvousFailureReason.ProtocolError)
if (res.key == null) { if (res.key == null) {
throw RendezvousError( throw RendezvousError(
"Unsupported algorithm: ${res.algorithm}", "Unsupported algorithm: ${res.algorithm}",
RendezvousFailureReason.UnsupportedAlgorithm, RendezvousFailureReason.UnsupportedAlgorithm,
)
}
theirPublicKey = Base64.decode(res.key, Base64.NO_WRAP)
} else {
// send our public key unencrypted
Timber.tag(TAG).i("Sending public key")
send(
ECDHPayload(
algorithm = SecureRendezvousChannelAlgorithm.ECDH_V1,
key = Base64.encodeToString(ourPublicKey, Base64.NO_WRAP)
)
) )
} }
theirPublicKey = Base64.decode(res.key, Base64.NO_WRAP)
} else {
// send our public key unencrypted
Timber.tag(TAG).i("Sending public key")
send(
ECDHPayload(
algorithm = SecureRendezvousChannelAlgorithm.ECDH_V1,
key = Base64.encodeToString(ourPublicKey, Base64.NO_WRAP)
)
)
}
synchronized(olmSAS) { olmSASMutex.withLock {
olmSAS.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP)) sas.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP))
olmSAS.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP)) sas.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP))
val initiatorKey = Base64.encodeToString(if (isInitiator) ourPublicKey else theirPublicKey, Base64.NO_WRAP) val initiatorKey = Base64.encodeToString(if (isInitiator) ourPublicKey else theirPublicKey, Base64.NO_WRAP)
val recipientKey = Base64.encodeToString(if (isInitiator) theirPublicKey else ourPublicKey, Base64.NO_WRAP) val recipientKey = Base64.encodeToString(if (isInitiator) theirPublicKey else ourPublicKey, Base64.NO_WRAP)
val aesInfo = "${SecureRendezvousChannelAlgorithm.ECDH_V1.value}|$initiatorKey|$recipientKey" val aesInfo = "${SecureRendezvousChannelAlgorithm.ECDH_V1.value}|$initiatorKey|$recipientKey"
aesKey = olmSAS.generateShortCode(aesInfo, 32) aesKey = sas.generateShortCode(aesInfo, 32)
val rawChecksum = olmSAS.generateShortCode(aesInfo, 5) val rawChecksum = sas.generateShortCode(aesInfo, 5)
return getDecimalCodeRepresentation(rawChecksum) return getDecimalCodeRepresentation(rawChecksum)
} }
} ?: throw RuntimeException("Channel closed")
} }
private suspend fun send(payload: ECDHPayload) { private suspend fun send(payload: ECDHPayload) {
@ -154,12 +156,11 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu
} }
override suspend fun close() { override suspend fun close() {
olmSAS ?.let { val sas = olmSAS ?: throw IllegalStateException("Channel already closed")
synchronized(it) { olmSASMutex.withLock {
// this does a double release check already so we don't re-check ourselves // this does a double release check already so we don't re-check ourselves
it.releaseSas() sas.releaseSas()
olmSAS = null olmSAS = null
}
} }
transport.close() transport.close()
} }