Use mutex

This commit is contained in:
Hugh Nimmo-Smith 2022-10-18 08:48:28 +01:00
parent 8a62dfb34a
commit f297117df2

View file

@ -19,6 +19,8 @@ package org.matrix.android.sdk.api.rendezvous.channels
import android.util.Base64
import com.squareup.moshi.Json
import com.squareup.moshi.JsonClass
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import okhttp3.MediaType.Companion.toMediaType
import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.api.rendezvous.RendezvousChannel
@ -71,6 +73,7 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu
@Json val iv: String? = null
)
private var olmSASMutex = Mutex()
private var olmSAS: OlmSAS?
private val ourPublicKey: ByteArray
private val ecdhAdapter = MatrixJsonParser.getMoshi().adapter(ECDHPayload::class.java)
@ -87,7 +90,7 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu
@Throws(RendezvousError::class)
override suspend fun connect(): String {
olmSAS ?.let { olmSAS ->
val sas = olmSAS ?: throw RendezvousError("Channel closed", RendezvousFailureReason.Unknown)
val isInitiator = theirPublicKey == null
if (isInitiator) {
@ -112,20 +115,19 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu
)
}
synchronized(olmSAS) {
olmSAS.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP))
olmSAS.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP))
olmSASMutex.withLock {
sas.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP))
sas.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP))
val initiatorKey = Base64.encodeToString(if (isInitiator) ourPublicKey else theirPublicKey, Base64.NO_WRAP)
val recipientKey = Base64.encodeToString(if (isInitiator) theirPublicKey else ourPublicKey, Base64.NO_WRAP)
val aesInfo = "${SecureRendezvousChannelAlgorithm.ECDH_V1.value}|$initiatorKey|$recipientKey"
aesKey = olmSAS.generateShortCode(aesInfo, 32)
aesKey = sas.generateShortCode(aesInfo, 32)
val rawChecksum = olmSAS.generateShortCode(aesInfo, 5)
val rawChecksum = sas.generateShortCode(aesInfo, 5)
return getDecimalCodeRepresentation(rawChecksum)
}
} ?: throw RuntimeException("Channel closed")
}
private suspend fun send(payload: ECDHPayload) {
@ -154,13 +156,12 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu
}
override suspend fun close() {
olmSAS ?.let {
synchronized(it) {
val sas = olmSAS ?: throw IllegalStateException("Channel already closed")
olmSASMutex.withLock {
// this does a double release check already so we don't re-check ourselves
it.releaseSas()
sas.releaseSas()
olmSAS = null
}
}
transport.close()
}