Improve attachment encryption and decryption code

This commit is contained in:
Benoit Marty 2020-01-13 21:06:29 +01:00
parent 96d6b75037
commit 159c96681f
9 changed files with 117 additions and 104 deletions

View file

@ -55,8 +55,6 @@ class AttachmentEncryptionTest {
assertNotNull(decryptedStream)
inputStream.close()
val buffer = ByteArray(100)
val len = decryptedStream!!.read(buffer)

View file

@ -31,7 +31,7 @@ import im.vector.matrix.android.internal.crypto.store.db.RealmCryptoStoreModule
import im.vector.matrix.android.internal.crypto.tasks.*
import im.vector.matrix.android.internal.database.RealmKeysUtils
import im.vector.matrix.android.internal.di.CryptoDatabase
import im.vector.matrix.android.internal.di.UserCacheDirectory
import im.vector.matrix.android.internal.di.SessionFilesDirectory
import im.vector.matrix.android.internal.di.UserMd5
import im.vector.matrix.android.internal.session.SessionScope
import im.vector.matrix.android.internal.session.cache.ClearCacheTask
@ -53,7 +53,7 @@ internal abstract class CryptoModule {
@Provides
@CryptoDatabase
@SessionScope
fun providesRealmConfiguration(@UserCacheDirectory directory: File,
fun providesRealmConfiguration(@SessionFilesDirectory directory: File,
@UserMd5 userMd5: String,
realmKeysUtils: RealmKeysUtils): RealmConfiguration {
return RealmConfiguration.Builder()

View file

@ -0,0 +1,27 @@
/*
* Copyright 2020 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package im.vector.matrix.android.internal.crypto.attachments
import im.vector.matrix.android.internal.crypto.model.rest.EncryptedFileInfo
/**
* Define the result of an encryption file
*/
data class EncryptionResult(
var encryptedFileInfo: EncryptedFileInfo,
var encryptedByteArray: ByteArray
)

View file

@ -35,17 +35,9 @@ object MXEncryptedAttachments {
private const val SECRET_KEY_SPEC_ALGORITHM = "AES"
private const val MESSAGE_DIGEST_ALGORITHM = "SHA-256"
/**
* Define the result of an encryption file
*/
data class EncryptionResult(
var encryptedFileInfo: EncryptedFileInfo,
var encryptedByteArray: ByteArray
)
/***
* Encrypt an attachment stream.
* @param attachmentStream the attachment stream
* @param attachmentStream the attachment stream. Will be closed after this method call.
* @param mimetype the mime type
* @return the encryption file info
*/
@ -67,9 +59,7 @@ object MXEncryptedAttachments {
val key = ByteArray(32)
secureRandom.nextBytes(key)
val outStream = ByteArrayOutputStream()
outStream.use {
ByteArrayOutputStream().use { outputStream ->
val encryptCipher = Cipher.getInstance(CIPHER_ALGORITHM)
val secretKeySpec = SecretKeySpec(key, SECRET_KEY_SPEC_ALGORITHM)
val ivParameterSpec = IvParameterSpec(initVectorBytes)
@ -81,20 +71,22 @@ object MXEncryptedAttachments {
var read: Int
var encodedBytes: ByteArray
read = attachmentStream.read(data)
while (read != -1) {
encodedBytes = encryptCipher.update(data, 0, read)
messageDigest.update(encodedBytes, 0, encodedBytes.size)
outStream.write(encodedBytes)
read = attachmentStream.read(data)
attachmentStream.use { inputStream ->
read = inputStream.read(data)
while (read != -1) {
encodedBytes = encryptCipher.update(data, 0, read)
messageDigest.update(encodedBytes, 0, encodedBytes.size)
outputStream.write(encodedBytes)
read = inputStream.read(data)
}
}
// encrypt the latest chunk
encodedBytes = encryptCipher.doFinal()
messageDigest.update(encodedBytes, 0, encodedBytes.size)
outStream.write(encodedBytes)
outputStream.write(encodedBytes)
val result = EncryptionResult(
return EncryptionResult(
encryptedFileInfo = EncryptedFileInfo(
url = null,
mimetype = mimetype,
@ -109,18 +101,16 @@ object MXEncryptedAttachments {
hashes = mapOf("sha256" to base64ToUnpaddedBase64(Base64.encodeToString(messageDigest.digest(), Base64.DEFAULT))),
v = "v2"
),
encryptedByteArray = outStream.toByteArray()
encryptedByteArray = outputStream.toByteArray()
)
Timber.v("Encrypt in ${System.currentTimeMillis() - t0} ms")
return result
.also { Timber.v("Encrypt in ${System.currentTimeMillis() - t0}ms") }
}
}
/**
* Decrypt an attachment
*
* @param attachmentStream the attachment stream
* @param attachmentStream the attachment stream. Will be closed after this method call.
* @param encryptedFileInfo the encryption file info
* @return the decrypted attachment stream
*/
@ -138,7 +128,7 @@ object MXEncryptedAttachments {
/**
* Decrypt an attachment
*
* @param attachmentStream the attachment stream
* @param attachmentStream the attachment stream. Will be closed after this method call.
* @param elementToDecrypt the elementToDecrypt info
* @return the decrypted attachment stream
*/
@ -151,59 +141,50 @@ object MXEncryptedAttachments {
val t0 = System.currentTimeMillis()
val outStream = ByteArrayOutputStream()
ByteArrayOutputStream().use { outputStream ->
try {
val key = Base64.decode(base64UrlToBase64(elementToDecrypt.k), Base64.DEFAULT)
val initVectorBytes = Base64.decode(elementToDecrypt.iv, Base64.DEFAULT)
try {
val key = Base64.decode(base64UrlToBase64(elementToDecrypt.k), Base64.DEFAULT)
val initVectorBytes = Base64.decode(elementToDecrypt.iv, Base64.DEFAULT)
val decryptCipher = Cipher.getInstance(CIPHER_ALGORITHM)
val secretKeySpec = SecretKeySpec(key, SECRET_KEY_SPEC_ALGORITHM)
val ivParameterSpec = IvParameterSpec(initVectorBytes)
decryptCipher.init(Cipher.DECRYPT_MODE, secretKeySpec, ivParameterSpec)
val decryptCipher = Cipher.getInstance(CIPHER_ALGORITHM)
val secretKeySpec = SecretKeySpec(key, SECRET_KEY_SPEC_ALGORITHM)
val ivParameterSpec = IvParameterSpec(initVectorBytes)
decryptCipher.init(Cipher.DECRYPT_MODE, secretKeySpec, ivParameterSpec)
val messageDigest = MessageDigest.getInstance(MESSAGE_DIGEST_ALGORITHM)
val messageDigest = MessageDigest.getInstance(MESSAGE_DIGEST_ALGORITHM)
var read: Int
val data = ByteArray(CRYPTO_BUFFER_SIZE)
var decodedBytes: ByteArray
var read: Int
val data = ByteArray(CRYPTO_BUFFER_SIZE)
var decodedBytes: ByteArray
attachmentStream.use { inputStream ->
read = inputStream.read(data)
while (read != -1) {
messageDigest.update(data, 0, read)
decodedBytes = decryptCipher.update(data, 0, read)
outputStream.write(decodedBytes)
read = inputStream.read(data)
}
}
read = attachmentStream.read(data)
while (read != -1) {
messageDigest.update(data, 0, read)
decodedBytes = decryptCipher.update(data, 0, read)
outStream.write(decodedBytes)
read = attachmentStream.read(data)
// decrypt the last chunk
decodedBytes = decryptCipher.doFinal()
outputStream.write(decodedBytes)
val currentDigestValue = base64ToUnpaddedBase64(Base64.encodeToString(messageDigest.digest(), Base64.DEFAULT))
if (elementToDecrypt.sha256 != currentDigestValue) {
Timber.e("## decryptAttachment() : Digest value mismatch")
return null
}
return ByteArrayInputStream(outputStream.toByteArray())
.also { Timber.v("Decrypt in ${System.currentTimeMillis() - t0}ms") }
} catch (oom: OutOfMemoryError) {
Timber.e(oom, "## decryptAttachment() : failed ${oom.message}")
} catch (e: Exception) {
Timber.e(e, "## decryptAttachment() : failed ${e.message}")
}
// decrypt the last chunk
decodedBytes = decryptCipher.doFinal()
outStream.write(decodedBytes)
val currentDigestValue = base64ToUnpaddedBase64(Base64.encodeToString(messageDigest.digest(), Base64.DEFAULT))
if (elementToDecrypt.sha256 != currentDigestValue) {
Timber.e("## decryptAttachment() : Digest value mismatch")
outStream.close()
return null
}
val decryptedStream = ByteArrayInputStream(outStream.toByteArray())
outStream.close()
Timber.v("Decrypt in ${System.currentTimeMillis() - t0} ms")
return decryptedStream
} catch (oom: OutOfMemoryError) {
Timber.e(oom, "## decryptAttachment() : failed ${oom.message}")
} catch (e: Exception) {
Timber.e(e, "## decryptAttachment() : failed ${e.message}")
}
try {
outStream.close()
} catch (closeException: Exception) {
Timber.e(closeException, "## decryptAttachment() : fail to close the file")
}
return null

View file

@ -18,8 +18,8 @@ package im.vector.matrix.android.internal.database
import android.content.Context
import im.vector.matrix.android.internal.database.model.SessionRealmModule
import im.vector.matrix.android.internal.di.SessionFilesDirectory
import im.vector.matrix.android.internal.di.SessionId
import im.vector.matrix.android.internal.di.UserCacheDirectory
import im.vector.matrix.android.internal.di.UserMd5
import im.vector.matrix.android.internal.session.SessionModule
import io.realm.Realm
@ -36,11 +36,12 @@ private const val REALM_NAME = "disk_store.realm"
* It will handle corrupted realm by clearing the db file. It allows to just clear cache without losing your crypto keys.
* It's clearly not perfect but there is no way to catch the native crash.
*/
internal class SessionRealmConfigurationFactory @Inject constructor(private val realmKeysUtils: RealmKeysUtils,
@UserCacheDirectory val directory: File,
@SessionId val sessionId: String,
@UserMd5 val userMd5: String,
context: Context) {
internal class SessionRealmConfigurationFactory @Inject constructor(
private val realmKeysUtils: RealmKeysUtils,
@SessionFilesDirectory val directory: File,
@SessionId val sessionId: String,
@UserMd5 val userMd5: String,
context: Context) {
private val sharedPreferences = context.getSharedPreferences("im.vector.matrix.android.realm", Context.MODE_PRIVATE)

View file

@ -14,16 +14,14 @@
* limitations under the License.
*/
/*
* Unfortunatly "ktlint-disable filename" this does not work so this file is renamed to UserCacheDirectory.kt
* If a new qualifier is added, please rename this file ti FileQualifiers.kt...
*/
/* ktlint-disable filename */
package im.vector.matrix.android.internal.di
import javax.inject.Qualifier
@Qualifier
@Retention(AnnotationRetention.RUNTIME)
annotation class UserCacheDirectory
annotation class SessionFilesDirectory
@Qualifier
@Retention(AnnotationRetention.RUNTIME)
annotation class SessionCacheDirectory

View file

@ -16,7 +16,6 @@
package im.vector.matrix.android.internal.session
import android.content.Context
import android.os.Environment
import arrow.core.Try
import im.vector.matrix.android.api.MatrixCallback
@ -25,7 +24,8 @@ import im.vector.matrix.android.api.session.file.FileService
import im.vector.matrix.android.api.util.Cancelable
import im.vector.matrix.android.internal.crypto.attachments.ElementToDecrypt
import im.vector.matrix.android.internal.crypto.attachments.MXEncryptedAttachments
import im.vector.matrix.android.internal.di.SessionId
import im.vector.matrix.android.internal.di.SessionCacheDirectory
import im.vector.matrix.android.internal.di.Unauthenticated
import im.vector.matrix.android.internal.extensions.foldToCallback
import im.vector.matrix.android.internal.util.MatrixCoroutineDispatchers
import im.vector.matrix.android.internal.util.md5
@ -41,12 +41,13 @@ import java.io.File
import java.io.IOException
import javax.inject.Inject
internal class DefaultFileService @Inject constructor(private val context: Context,
@SessionId private val sessionId: String,
private val contentUrlResolver: ContentUrlResolver,
private val coroutineDispatchers: MatrixCoroutineDispatchers) : FileService {
val okHttpClient = OkHttpClient()
internal class DefaultFileService @Inject constructor(
@SessionCacheDirectory
private val cacheDirectory: File,
private val contentUrlResolver: ContentUrlResolver,
@Unauthenticated
private val okHttpClient: OkHttpClient,
private val coroutineDispatchers: MatrixCoroutineDispatchers) : FileService {
/**
* Download file in the cache folder, and eventually decrypt it
@ -103,10 +104,9 @@ internal class DefaultFileService @Inject constructor(private val context: Conte
return when (downloadMode) {
FileService.DownloadMode.FOR_INTERNAL_USE -> {
// Create dir tree (MF stands for Matrix File):
// <cache>/MF/<sessionId>/<md5(id)>/
val tmpFolderRoot = File(context.cacheDir, "MF")
val tmpFolderUser = File(tmpFolderRoot, sessionId)
File(tmpFolderUser, id.md5())
// <cache>/<sessionId>/MF/<md5(id)>/
val tmpFolderSession = File(cacheDirectory, "MF")
File(tmpFolderSession, id.md5())
}
FileService.DownloadMode.TO_EXPORT -> {
Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS)

View file

@ -93,7 +93,7 @@ internal abstract class SessionModule {
@JvmStatic
@Provides
@UserCacheDirectory
@SessionFilesDirectory
fun providesFilesDir(@UserMd5 userMd5: String,
@SessionId sessionId: String,
context: Context): File {
@ -106,6 +106,14 @@ internal abstract class SessionModule {
return File(context.filesDir, sessionId)
}
@JvmStatic
@Provides
@SessionCacheDirectory
fun providesCacheDir(@SessionId sessionId: String,
context: Context): File {
return File(context.cacheDir, sessionId)
}
@JvmStatic
@Provides
@SessionDatabase

View file

@ -52,7 +52,7 @@ internal class DefaultSignOutTask @Inject constructor(
private val sessionParamsStore: SessionParamsStore,
@SessionDatabase private val clearSessionDataTask: ClearCacheTask,
@CryptoDatabase private val clearCryptoDataTask: ClearCacheTask,
@UserCacheDirectory private val userFile: File,
@SessionFilesDirectory private val userFile: File,
private val realmKeysUtils: RealmKeysUtils,
@SessionDatabase private val realmSessionConfiguration: RealmConfiguration,
@CryptoDatabase private val realmCryptoConfiguration: RealmConfiguration,