Enhance decryption to prevent DUPLICATED_MESSAGE_INDEX when decrypting the same eventId

Improve code format
This commit is contained in:
ariskotsomitopoulos 2022-05-16 13:05:38 +03:00
parent cad9d443be
commit 2e08c07dad
3 changed files with 218 additions and 43 deletions

View file

@ -0,0 +1,160 @@
/*
* Copyright (c) 2022 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.matrix.android.sdk.internal.crypto.replay_attack
import android.util.Log
import androidx.test.filters.LargeTest
import org.junit.Assert
import org.junit.Assert.assertEquals
import org.junit.FixMethodOrder
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.runners.MethodSorters
import org.matrix.android.sdk.InstrumentedTest
import org.matrix.android.sdk.api.session.events.model.EventType
import org.matrix.android.sdk.api.session.events.model.toModel
import org.matrix.android.sdk.api.session.room.Room
import org.matrix.android.sdk.api.session.room.model.message.MessageContent
import org.matrix.android.sdk.api.session.room.send.SendState
import org.matrix.android.sdk.api.session.room.timeline.TimelineSettings
import org.matrix.android.sdk.common.CommonTestHelper
import org.matrix.android.sdk.common.CryptoTestHelper
import org.matrix.android.sdk.common.TestConstants
@RunWith(JUnit4::class)
@FixMethodOrder(MethodSorters.JVM)
@LargeTest
class ReplayAttackTest : InstrumentedTest {
@Test
fun replayAttackTest() {
val testHelper = CommonTestHelper(context())
val cryptoTestHelper = CryptoTestHelper(testHelper)
val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true)
val e2eRoomID = cryptoTestData.roomId
// Alice
val aliceSession = cryptoTestData.firstSession
val aliceRoomPOV = aliceSession.roomService().getRoom(e2eRoomID)!!
// Bob
val bobSession = cryptoTestData.secondSession
val bobRoomPOV = bobSession!!.roomService().getRoom(e2eRoomID)!!
assertEquals(bobRoomPOV.roomSummary()?.joinedMembersCount, 2)
Log.v("##REPLAY ATTACK", "Alice and Bob are in roomId: $e2eRoomID")
val sentEvents = testHelper.sendTextMessage(aliceRoomPOV, "Hello", 20)
// val aliceMessageId: String? = sendMessageInRoom(aliceRoomPOV, "Hello Bob, I am Alice!", testHelper)
Assert.assertTrue("Message should be sent", sentEvents.size == 20)
Log.v("##REPLAY ATTACK", "Alice sent message to roomId: $e2eRoomID")
// Bob should be able to decrypt the message
// testHelper.waitWithLatch { latch ->
// testHelper.retryPeriodicallyWithLatch(latch) {
// val timelineEvent = bobSession.roomService().getRoom(e2eRoomID)?.timelineService()?.getTimelineEvent(aliceMessageId!!)
// (timelineEvent != null &&
// timelineEvent.isEncrypted() &&
// timelineEvent.root.getClearType() == EventType.MESSAGE).also {
// if (it) {
// Log.v("#E2E TEST", "Bob can decrypt the message: ${timelineEvent?.root?.getDecryptedTextSummary()}")
// }
// }
// }
// }
//
// // Create a new user
// val arisSession = testHelper.createAccount("aris", SessionTestParams(true))
// Log.v("#E2E TEST", "Aris user created")
//
// // Alice invites new user to the room
// testHelper.runBlockingTest {
// Log.v("#E2E TEST", "Alice invites ${arisSession.myUserId}")
// aliceRoomPOV.membershipService().invite(arisSession.myUserId)
// }
//
// waitForAndAcceptInviteInRoom(arisSession, e2eRoomID, testHelper)
//
// ensureMembersHaveJoined(aliceSession, arrayListOf(arisSession), e2eRoomID, testHelper)
// Log.v("#E2E TEST", "Aris has joined roomId: $e2eRoomID")
//
// when (roomHistoryVisibility) {
// RoomHistoryVisibility.WORLD_READABLE,
// RoomHistoryVisibility.SHARED,
// null
// -> {
// // Aris should be able to decrypt the message
// testHelper.waitWithLatch { latch ->
// testHelper.retryPeriodicallyWithLatch(latch) {
// val timelineEvent = arisSession.roomService().getRoom(e2eRoomID)?.timelineService()?.getTimelineEvent(aliceMessageId!!)
// (timelineEvent != null &&
// timelineEvent.isEncrypted() &&
// timelineEvent.root.getClearType() == EventType.MESSAGE
// ).also {
// if (it) {
// Log.v("#E2E TEST", "Aris can decrypt the message: ${timelineEvent?.root?.getDecryptedTextSummary()}")
// }
// }
// }
// }
// }
// RoomHistoryVisibility.INVITED,
// RoomHistoryVisibility.JOINED -> {
// // Aris should not even be able to get the message
// testHelper.waitWithLatch { latch ->
// testHelper.retryPeriodicallyWithLatch(latch) {
// val timelineEvent = arisSession.roomService().getRoom(e2eRoomID)
// ?.timelineService()
// ?.getTimelineEvent(aliceMessageId!!)
// timelineEvent == null
// }
// }
// }
// }
// testHelper.signOutAndClose(arisSession)
cryptoTestData.cleanUp(testHelper)
}
private fun sendMessageInRoom(aliceRoomPOV: Room, text: String, testHelper: CommonTestHelper): String? {
aliceRoomPOV.sendService().sendTextMessage(text)
var sentEventId: String? = null
testHelper.waitWithLatch(4 * TestConstants.timeOutMillis) { latch ->
val timeline = aliceRoomPOV.timelineService().createTimeline(null, TimelineSettings(60))
timeline.start()
testHelper.retryPeriodicallyWithLatch(latch) {
val decryptedMsg = timeline.getSnapshot()
.filter { it.root.getClearType() == EventType.MESSAGE }
.also { list ->
val message = list.joinToString(",", "[", "]") { "${it.root.type}|${it.root.sendState}" }
Log.v("#E2E TEST", "Timeline snapshot is $message")
}
.filter { it.root.sendState == SendState.SYNCED }
.firstOrNull { it.root.getClearContent().toModel<MessageContent>()?.body?.startsWith(text) == true }
sentEventId = decryptedMsg?.eventId
decryptedMsg != null
}
timeline.dispose()
}
return sentEventId
}
}

View file

@ -99,6 +99,8 @@ internal class MXOlmDevice @Inject constructor(
// The second level keys are strings of form "<senderKey>|<session_id>|<message_index>" // The second level keys are strings of form "<senderKey>|<session_id>|<message_index>"
private val inboundGroupSessionMessageIndexes: MutableMap<String, MutableSet<String>> = HashMap() private val inboundGroupSessionMessageIndexes: MutableMap<String, MutableSet<String>> = HashMap()
private val replayAttackMap: MutableMap<String, String> = HashMap()
init { init {
// Retrieve the account from the store // Retrieve the account from the store
try { try {
@ -763,59 +765,71 @@ internal class MXOlmDevice @Inject constructor(
suspend fun decryptGroupMessage(body: String, suspend fun decryptGroupMessage(body: String,
roomId: String, roomId: String,
timeline: String?, timeline: String?,
eventId: String,
sessionId: String, sessionId: String,
senderKey: String): OlmDecryptionResult { senderKey: String): OlmDecryptionResult {
val sessionHolder = getInboundGroupSession(sessionId, senderKey, roomId) val sessionHolder = getInboundGroupSession(sessionId, senderKey, roomId)
val wrapper = sessionHolder.wrapper val wrapper = sessionHolder.wrapper
val inboundGroupSession = wrapper.olmInboundGroupSession val inboundGroupSession = wrapper.olmInboundGroupSession
?: throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, "Session is null") ?: throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, "Session is null")
// Check that the room id matches the original one for the session. This stops if (roomId != wrapper.roomId) {
// the HS pretending a message was targeting a different room. // Check that the room id matches the original one for the session. This stops
if (roomId == wrapper.roomId) { // the HS pretending a message was targeting a different room.
val decryptResult = try {
sessionHolder.mutex.withLock {
inboundGroupSession.decryptMessage(body)
}
} catch (e: OlmException) {
Timber.tag(loggerTag.value).e(e, "## decryptGroupMessage () : decryptMessage failed")
throw MXCryptoError.OlmError(e)
}
if (timeline?.isNotBlank() == true) {
val timelineSet = inboundGroupSessionMessageIndexes.getOrPut(timeline) { mutableSetOf() }
val messageIndexKey = senderKey + "|" + sessionId + "|" + decryptResult.mIndex
if (timelineSet.contains(messageIndexKey)) {
val reason = String.format(MXCryptoError.DUPLICATE_MESSAGE_INDEX_REASON, decryptResult.mIndex)
Timber.tag(loggerTag.value).e("## decryptGroupMessage() timelineId=$timeline: $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.DUPLICATED_MESSAGE_INDEX, reason)
}
timelineSet.add(messageIndexKey)
}
inboundGroupSessionStore.storeInBoundGroupSession(sessionHolder, sessionId, senderKey)
val payload = try {
val adapter = MoshiProvider.providesMoshi().adapter<JsonDict>(JSON_DICT_PARAMETERIZED_TYPE)
val payloadString = convertFromUTF8(decryptResult.mDecryptedMessage)
adapter.fromJson(payloadString)
} catch (e: Exception) {
Timber.tag(loggerTag.value).e("## decryptGroupMessage() : fails to parse the payload")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_DECRYPTED_FORMAT, MXCryptoError.BAD_DECRYPTED_FORMAT_TEXT_REASON)
}
return OlmDecryptionResult(
payload,
wrapper.keysClaimed,
senderKey,
wrapper.forwardingCurve25519KeyChain
)
} else {
val reason = String.format(MXCryptoError.INBOUND_SESSION_MISMATCH_ROOM_ID_REASON, roomId, wrapper.roomId) val reason = String.format(MXCryptoError.INBOUND_SESSION_MISMATCH_ROOM_ID_REASON, roomId, wrapper.roomId)
Timber.tag(loggerTag.value).e("## decryptGroupMessage() : $reason") Timber.tag(loggerTag.value).e("## decryptGroupMessage() : $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.INBOUND_SESSION_MISMATCH_ROOM_ID, reason) throw MXCryptoError.Base(MXCryptoError.ErrorType.INBOUND_SESSION_MISMATCH_ROOM_ID, reason)
} }
val decryptResult = try {
sessionHolder.mutex.withLock {
inboundGroupSession.decryptMessage(body)
}
} catch (e: OlmException) {
Timber.tag(loggerTag.value).e(e, "## decryptGroupMessage () : decryptMessage failed")
throw MXCryptoError.OlmError(e)
}
val messageIndexKey = senderKey + "|" + sessionId + "|" + roomId + "|" + decryptResult.mIndex
Timber.tag(loggerTag.value).d("##########################################################")
Timber.tag(loggerTag.value).d("## decryptGroupMessage() timeline: $timeline")
Timber.tag(loggerTag.value).d("## decryptGroupMessage() senderKey: $senderKey")
Timber.tag(loggerTag.value).d("## decryptGroupMessage() sessionId: $sessionId")
Timber.tag(loggerTag.value).d("## decryptGroupMessage() roomId: $roomId")
Timber.tag(loggerTag.value).d("## decryptGroupMessage() eventId: $eventId")
Timber.tag(loggerTag.value).d("## decryptGroupMessage() mIndex: ${decryptResult.mIndex}")
if (timeline?.isNotBlank() == true) {
val timelineSet = inboundGroupSessionMessageIndexes.getOrPut(timeline) { mutableSetOf() }
if (timelineSet.contains(messageIndexKey) && messageIndexKey.alreadyUsed(eventId)) {
val reason = String.format(MXCryptoError.DUPLICATE_MESSAGE_INDEX_REASON, decryptResult.mIndex)
Timber.tag(loggerTag.value).e("## decryptGroupMessage() timelineId=$timeline: $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.DUPLICATED_MESSAGE_INDEX, reason)
}
timelineSet.add(messageIndexKey)
}
replayAttackMap[messageIndexKey] = eventId
inboundGroupSessionStore.storeInBoundGroupSession(sessionHolder, sessionId, senderKey)
val payload = try {
val adapter = MoshiProvider.providesMoshi().adapter<JsonDict>(JSON_DICT_PARAMETERIZED_TYPE)
val payloadString = convertFromUTF8(decryptResult.mDecryptedMessage)
adapter.fromJson(payloadString)
} catch (e: Exception) {
Timber.tag(loggerTag.value).e("## decryptGroupMessage() : fails to parse the payload")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_DECRYPTED_FORMAT, MXCryptoError.BAD_DECRYPTED_FORMAT_TEXT_REASON)
}
return OlmDecryptionResult(
payload,
wrapper.keysClaimed,
senderKey,
wrapper.forwardingCurve25519KeyChain
)
}
/**
* Determines whether or not the messageKey has already been used to decrypt another eventId
*/
private fun String.alreadyUsed(eventId: String): Boolean {
return replayAttackMap[this] != null && replayAttackMap[this] != eventId
} }
/** /**

View file

@ -78,6 +78,7 @@ internal class MXMegolmDecryption(
encryptedEventContent.ciphertext, encryptedEventContent.ciphertext,
event.roomId, event.roomId,
timeline, timeline,
eventId = event.eventId.orEmpty(),
encryptedEventContent.sessionId, encryptedEventContent.sessionId,
encryptedEventContent.senderKey encryptedEventContent.senderKey
) )