Cleanup SecretStoringUtils, and delete keys when user signs out

This commit is contained in:
Benoit Marty 2019-09-16 18:29:06 +02:00
parent c8010561fc
commit 1ba8a58219
6 changed files with 72 additions and 42 deletions

View file

@ -22,12 +22,14 @@ import android.security.KeyPairGeneratorSpec
import android.security.keystore.KeyGenParameterSpec import android.security.keystore.KeyGenParameterSpec
import android.security.keystore.KeyProperties import android.security.keystore.KeyProperties
import androidx.annotation.RequiresApi import androidx.annotation.RequiresApi
import timber.log.Timber
import java.io.* import java.io.*
import java.math.BigInteger import java.math.BigInteger
import java.security.KeyPairGenerator import java.security.KeyPairGenerator
import java.security.KeyStore import java.security.KeyStore
import java.security.KeyStoreException
import java.security.SecureRandom import java.security.SecureRandom
import java.util.Calendar import java.util.*
import javax.crypto.* import javax.crypto.*
import javax.crypto.spec.GCMParameterSpec import javax.crypto.spec.GCMParameterSpec
import javax.crypto.spec.IvParameterSpec import javax.crypto.spec.IvParameterSpec
@ -65,7 +67,7 @@ import javax.security.auth.x500.X500Principal
* val kDecripted = SecretStoringUtils.loadSecureSecret(KEncrypted!!, "myAlias", context) * val kDecripted = SecretStoringUtils.loadSecureSecret(KEncrypted!!, "myAlias", context)
* </code> * </code>
* *
* You can also just use this utility to store a secret key, and use any encryption algorthim that you want. * You can also just use this utility to store a secret key, and use any encryption algorithm that you want.
* *
* Important: Keys stored in the keystore can be wiped out (depends of the OS version, like for example if you * Important: Keys stored in the keystore can be wiped out (depends of the OS version, like for example if you
* add a pin or change the schema); So you might and with a useless pile of bytes. * add a pin or change the schema); So you might and with a useless pile of bytes.
@ -76,11 +78,11 @@ object SecretStoringUtils {
private const val AES_MODE = "AES/GCM/NoPadding"; private const val AES_MODE = "AES/GCM/NoPadding";
private const val RSA_MODE = "RSA/ECB/PKCS1Padding" private const val RSA_MODE = "RSA/ECB/PKCS1Padding"
const val FORMAT_API_M: Byte = 0 private const val FORMAT_API_M: Byte = 0
const val FORMAT_1: Byte = 1 private const val FORMAT_1: Byte = 1
const val FORMAT_2: Byte = 2 private const val FORMAT_2: Byte = 2
val keyStore: KeyStore by lazy { private val keyStore: KeyStore by lazy {
KeyStore.getInstance(ANDROID_KEY_STORE).apply { KeyStore.getInstance(ANDROID_KEY_STORE).apply {
load(null) load(null)
} }
@ -88,11 +90,19 @@ object SecretStoringUtils {
private val secureRandom = SecureRandom() private val secureRandom = SecureRandom()
fun safeDeleteKey(keyAlias: String) {
try {
keyStore.deleteEntry(keyAlias)
} catch (e: KeyStoreException) {
Timber.e(e)
}
}
/** /**
* Encrypt the given secret using the android Keystore. * Encrypt the given secret using the android Keystore.
* On android >= M, will directly use the keystore to generate a symetric key * On android >= M, will directly use the keystore to generate a symmetric key
* On KitKat >= KitKat and <M, as symetric key gen is not available, will use an asymetric key generated * On android >= KitKat and <M, as symmetric key gen is not available, will use an symmetric key generated
* in the keystore to encrypted a random symetric key. The encrypted symetric key is returned * in the keystore to encrypted a random symmetric key. The encrypted symmetric key is returned
* in the bytearray (in can be stored anywhere, it is encrypted) * in the bytearray (in can be stored anywhere, it is encrypted)
* On older version a key in generated from alias with random salt. * On older version a key in generated from alias with random salt.
* *
@ -103,7 +113,7 @@ object SecretStoringUtils {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
return encryptStringM(secret, keyAlias) return encryptStringM(secret, keyAlias)
} else if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.KITKAT) { } else if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.KITKAT) {
return encryptStringJ(secret, keyAlias, context) return encryptStringK(secret, keyAlias, context)
} else { } else {
return encryptForOldDevicesNotGood(secret, keyAlias) return encryptForOldDevicesNotGood(secret, keyAlias)
} }
@ -117,7 +127,7 @@ object SecretStoringUtils {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
return decryptStringM(encrypted, keyAlias) return decryptStringM(encrypted, keyAlias)
} else if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.KITKAT) { } else if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.KITKAT) {
return decryptStringJ(encrypted, keyAlias, context) return decryptStringK(encrypted, keyAlias, context)
} else { } else {
return decryptForOldDevicesNotGood(encrypted, keyAlias) return decryptForOldDevicesNotGood(encrypted, keyAlias)
} }
@ -145,7 +155,7 @@ object SecretStoringUtils {
@RequiresApi(Build.VERSION_CODES.M) @RequiresApi(Build.VERSION_CODES.M)
fun getOrGenerateSymmetricKeyForAlias(alias: String): SecretKey { private fun getOrGenerateSymmetricKeyForAliasM(alias: String): SecretKey {
val secretKeyEntry = (keyStore.getEntry(alias, null) as? KeyStore.SecretKeyEntry) val secretKeyEntry = (keyStore.getEntry(alias, null) as? KeyStore.SecretKeyEntry)
?.secretKey ?.secretKey
if (secretKeyEntry == null) { if (secretKeyEntry == null) {
@ -163,7 +173,6 @@ object SecretStoringUtils {
return secretKeyEntry return secretKeyEntry
} }
/* /*
Symetric Key Generation is only available in M, so before M the idea is to: Symetric Key Generation is only available in M, so before M the idea is to:
- Generate a pair of RSA keys; - Generate a pair of RSA keys;
@ -172,7 +181,7 @@ object SecretStoringUtils {
- Store the encrypted AES - Store the encrypted AES
Generate a key pair for encryption Generate a key pair for encryption
*/ */
@RequiresApi(Build.VERSION_CODES.JELLY_BEAN_MR2) @RequiresApi(Build.VERSION_CODES.KITKAT)
fun getOrGenerateKeyPairForAlias(alias: String, context: Context): KeyStore.PrivateKeyEntry { fun getOrGenerateKeyPairForAlias(alias: String, context: Context): KeyStore.PrivateKeyEntry {
val privateKeyEntry = (keyStore.getEntry(alias, null) as? KeyStore.PrivateKeyEntry) val privateKeyEntry = (keyStore.getEntry(alias, null) as? KeyStore.PrivateKeyEntry)
@ -201,7 +210,7 @@ object SecretStoringUtils {
@RequiresApi(Build.VERSION_CODES.M) @RequiresApi(Build.VERSION_CODES.M)
fun encryptStringM(text: String, keyAlias: String): ByteArray? { fun encryptStringM(text: String, keyAlias: String): ByteArray? {
val secretKey = getOrGenerateSymmetricKeyForAlias(keyAlias) val secretKey = getOrGenerateSymmetricKeyForAliasM(keyAlias)
val cipher = Cipher.getInstance(AES_MODE) val cipher = Cipher.getInstance(AES_MODE)
cipher.init(Cipher.ENCRYPT_MODE, secretKey) cipher.init(Cipher.ENCRYPT_MODE, secretKey)
@ -212,10 +221,10 @@ object SecretStoringUtils {
} }
@RequiresApi(Build.VERSION_CODES.M) @RequiresApi(Build.VERSION_CODES.M)
fun decryptStringM(encryptedChunk: ByteArray, keyAlias: String): String { private fun decryptStringM(encryptedChunk: ByteArray, keyAlias: String): String {
val (iv, encryptedText) = formatMExtract(ByteArrayInputStream(encryptedChunk)) val (iv, encryptedText) = formatMExtract(ByteArrayInputStream(encryptedChunk))
val secretKey = getOrGenerateSymmetricKeyForAlias(keyAlias) val secretKey = getOrGenerateSymmetricKeyForAliasM(keyAlias)
val cipher = Cipher.getInstance(AES_MODE) val cipher = Cipher.getInstance(AES_MODE)
val spec = GCMParameterSpec(128, iv) val spec = GCMParameterSpec(128, iv)
@ -224,8 +233,8 @@ object SecretStoringUtils {
return String(cipher.doFinal(encryptedText), Charsets.UTF_8) return String(cipher.doFinal(encryptedText), Charsets.UTF_8)
} }
@RequiresApi(Build.VERSION_CODES.JELLY_BEAN_MR2) @RequiresApi(Build.VERSION_CODES.KITKAT)
fun encryptStringJ(text: String, keyAlias: String, context: Context): ByteArray? { private fun encryptStringK(text: String, keyAlias: String, context: Context): ByteArray? {
//we generate a random symetric key //we generate a random symetric key
val key = ByteArray(16) val key = ByteArray(16)
secureRandom.nextBytes(key) secureRandom.nextBytes(key)
@ -242,7 +251,7 @@ object SecretStoringUtils {
return format1Make(encryptedKey, iv, encryptedBytes) return format1Make(encryptedKey, iv, encryptedBytes)
} }
fun encryptForOldDevicesNotGood(text: String, keyAlias: String): ByteArray { private fun encryptForOldDevicesNotGood(text: String, keyAlias: String): ByteArray {
val salt = ByteArray(8) val salt = ByteArray(8)
secureRandom.nextBytes(salt) secureRandom.nextBytes(salt)
val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256") val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256")
@ -258,11 +267,11 @@ object SecretStoringUtils {
return format2Make(salt, iv, encryptedBytes) return format2Make(salt, iv, encryptedBytes)
} }
fun decryptForOldDevicesNotGood(data: ByteArray, keyAlias: String): String? { private fun decryptForOldDevicesNotGood(data: ByteArray, keyAlias: String): String? {
val (salt, iv, encrypted) = format2Extract(ByteArrayInputStream(data)) val (salt, iv, encrypted) = format2Extract(ByteArrayInputStream(data))
val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256") val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256")
val spec = PBEKeySpec(keyAlias.toCharArray(), salt, 10000, 128) val spec = PBEKeySpec(keyAlias.toCharArray(), salt, 10_000, 128)
val tmp = factory.generateSecret(spec) val tmp = factory.generateSecret(spec)
val sKey = SecretKeySpec(tmp.encoded, "AES") val sKey = SecretKeySpec(tmp.encoded, "AES")
@ -277,7 +286,7 @@ object SecretStoringUtils {
} }
@RequiresApi(Build.VERSION_CODES.KITKAT) @RequiresApi(Build.VERSION_CODES.KITKAT)
fun decryptStringJ(data: ByteArray, keyAlias: String, context: Context): String? { private fun decryptStringK(data: ByteArray, keyAlias: String, context: Context): String? {
val (encryptedKey, iv, encrypted) = format1Extract(ByteArrayInputStream(data)) val (encryptedKey, iv, encrypted) = format1Extract(ByteArrayInputStream(data))
@ -288,14 +297,12 @@ object SecretStoringUtils {
cipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(sKeyBytes, "AES"), spec) cipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(sKeyBytes, "AES"), spec)
return String(cipher.doFinal(encrypted), Charsets.UTF_8) return String(cipher.doFinal(encrypted), Charsets.UTF_8)
} }
@RequiresApi(Build.VERSION_CODES.M) @RequiresApi(Build.VERSION_CODES.M)
@Throws(IOException::class) @Throws(IOException::class)
fun saveSecureObjectM(keyAlias: String, output: OutputStream, writeObject: Any) { private fun saveSecureObjectM(keyAlias: String, output: OutputStream, writeObject: Any) {
val secretKey = getOrGenerateSymmetricKeyForAlias(keyAlias) val secretKey = getOrGenerateSymmetricKeyForAliasM(keyAlias)
val cipher = Cipher.getInstance(AES_MODE) val cipher = Cipher.getInstance(AES_MODE)
cipher.init(Cipher.ENCRYPT_MODE, secretKey/*, spec*/) cipher.init(Cipher.ENCRYPT_MODE, secretKey/*, spec*/)
@ -314,7 +321,7 @@ object SecretStoringUtils {
} }
@RequiresApi(Build.VERSION_CODES.KITKAT) @RequiresApi(Build.VERSION_CODES.KITKAT)
fun saveSecureObjectK(keyAlias: String, output: OutputStream, writeObject: Any, context: Context) { private fun saveSecureObjectK(keyAlias: String, output: OutputStream, writeObject: Any, context: Context) {
//we generate a random symetric key //we generate a random symetric key
val key = ByteArray(16) val key = ByteArray(16)
secureRandom.nextBytes(key) secureRandom.nextBytes(key)
@ -342,7 +349,7 @@ object SecretStoringUtils {
output.write(bos1.toByteArray()) output.write(bos1.toByteArray())
} }
fun saveSecureObjectOldNotGood(keyAlias: String, output: OutputStream, writeObject: Any) { private fun saveSecureObjectOldNotGood(keyAlias: String, output: OutputStream, writeObject: Any) {
val salt = ByteArray(8) val salt = ByteArray(8)
secureRandom.nextBytes(salt) secureRandom.nextBytes(salt)
val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256") val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256")
@ -387,8 +394,8 @@ object SecretStoringUtils {
@RequiresApi(Build.VERSION_CODES.M) @RequiresApi(Build.VERSION_CODES.M)
@Throws(IOException::class) @Throws(IOException::class)
fun <T> loadSecureObjectM(keyAlias: String, inputStream: InputStream): T? { private fun <T> loadSecureObjectM(keyAlias: String, inputStream: InputStream): T? {
val secretKey = getOrGenerateSymmetricKeyForAlias(keyAlias) val secretKey = getOrGenerateSymmetricKeyForAliasM(keyAlias)
val format = inputStream.read() val format = inputStream.read()
assert(format.toByte() == FORMAT_API_M) assert(format.toByte() == FORMAT_API_M)
@ -411,7 +418,7 @@ object SecretStoringUtils {
@RequiresApi(Build.VERSION_CODES.KITKAT) @RequiresApi(Build.VERSION_CODES.KITKAT)
@Throws(IOException::class) @Throws(IOException::class)
fun <T> loadSecureObjectK(keyAlias: String, inputStream: InputStream, context: Context): T? { private fun <T> loadSecureObjectK(keyAlias: String, inputStream: InputStream, context: Context): T? {
val (encryptedKey, iv, encrypted) = format1Extract(inputStream) val (encryptedKey, iv, encrypted) = format1Extract(inputStream)
@ -432,8 +439,7 @@ object SecretStoringUtils {
} }
@Throws(Exception::class) @Throws(Exception::class)
fun <T> loadSecureObjectOldNotGood(keyAlias: String, inputStream: InputStream): T? { private fun <T> loadSecureObjectOldNotGood(keyAlias: String, inputStream: InputStream): T? {
val (salt, iv, encrypted) = format2Extract(inputStream) val (salt, iv, encrypted) = format2Extract(inputStream)
val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256") val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256")
@ -456,7 +462,7 @@ object SecretStoringUtils {
} }
@RequiresApi(Build.VERSION_CODES.JELLY_BEAN_MR2) @RequiresApi(Build.VERSION_CODES.KITKAT)
@Throws(Exception::class) @Throws(Exception::class)
private fun rsaEncrypt(alias: String, secret: ByteArray, context: Context): ByteArray { private fun rsaEncrypt(alias: String, secret: ByteArray, context: Context): ByteArray {
val privateKeyEntry = getOrGenerateKeyPairForAlias(alias, context) val privateKeyEntry = getOrGenerateKeyPairForAlias(alias, context)
@ -472,7 +478,7 @@ object SecretStoringUtils {
return outputStream.toByteArray() return outputStream.toByteArray()
} }
@RequiresApi(Build.VERSION_CODES.JELLY_BEAN_MR2) @RequiresApi(Build.VERSION_CODES.KITKAT)
@Throws(Exception::class) @Throws(Exception::class)
private fun rsaDecrypt(alias: String, encrypted: InputStream, context: Context): ByteArray { private fun rsaDecrypt(alias: String, encrypted: InputStream, context: Context): ByteArray {
val privateKeyEntry = getOrGenerateKeyPairForAlias(alias, context) val privateKeyEntry = getOrGenerateKeyPairForAlias(alias, context)
@ -504,7 +510,6 @@ object SecretStoringUtils {
} }
private fun format1Extract(bis: InputStream): Triple<ByteArray, ByteArray, ByteArray> { private fun format1Extract(bis: InputStream): Triple<ByteArray, ByteArray, ByteArray> {
val format = bis.read() val format = bis.read()
assert(format.toByte() == FORMAT_1) assert(format.toByte() == FORMAT_1)
@ -548,7 +553,6 @@ object SecretStoringUtils {
} }
private fun format2Extract(bis: InputStream): Triple<ByteArray, ByteArray, ByteArray> { private fun format2Extract(bis: InputStream): Triple<ByteArray, ByteArray, ByteArray> {
val format = bis.read() val format = bis.read()
assert(format.toByte() == FORMAT_2) assert(format.toByte() == FORMAT_2)

View file

@ -33,6 +33,8 @@ internal abstract class AuthModule {
@Module @Module
companion object { companion object {
private const val DB_ALIAS = "matrix-sdk-auth"
@JvmStatic @JvmStatic
@Provides @Provides
@AuthDatabase @AuthDatabase
@ -44,7 +46,7 @@ internal abstract class AuthModule {
return RealmConfiguration.Builder() return RealmConfiguration.Builder()
.apply { .apply {
realmKeysUtils.configureEncryption(this, "matrix-sdk-auth") realmKeysUtils.configureEncryption(this, DB_ALIAS)
} }
.name("matrix-sdk-auth.realm") .name("matrix-sdk-auth.realm")
.modules(AuthRealmModule()) .modules(AuthRealmModule())

View file

@ -45,6 +45,7 @@ internal abstract class CryptoModule {
@Module @Module
companion object { companion object {
internal const val DB_ALIAS_PREFIX = "crypto_module_"
@JvmStatic @JvmStatic
@Provides @Provides
@ -56,7 +57,7 @@ internal abstract class CryptoModule {
return RealmConfiguration.Builder() return RealmConfiguration.Builder()
.directory(directory) .directory(directory)
.apply { .apply {
realmKeysUtils.configureEncryption(this, "crypto_module_$userMd5") realmKeysUtils.configureEncryption(this, "$DB_ALIAS_PREFIX$userMd5")
} }
.name("crypto_store.realm") .name("crypto_store.realm")
.modules(RealmCryptoStoreModule()) .modules(RealmCryptoStoreModule())

View file

@ -101,6 +101,18 @@ internal class RealmKeysUtils @Inject constructor(private val context: Context)
realmConfigurationBuilder.encryptionKey(key) realmConfigurationBuilder.encryptionKey(key)
} }
// Delete elements related to the alias
fun clear(alias: String) {
if (hasKeyForDatabase(alias)) {
SecretStoringUtils.safeDeleteKey(alias)
sharedPreferences
.edit()
.remove("${ENCRYPTED_KEY_PREFIX}_$alias")
.apply()
}
}
companion object { companion object {
private const val ENCRYPTED_KEY_PREFIX = "REALM_ENCRYPTED_KEY" private const val ENCRYPTED_KEY_PREFIX = "REALM_ENCRYPTED_KEY"
} }

View file

@ -50,6 +50,7 @@ internal abstract class SessionModule {
@Module @Module
companion object { companion object {
internal const val DB_ALIAS_PREFIX = "session_db_"
@JvmStatic @JvmStatic
@Provides @Provides
@ -89,7 +90,7 @@ internal abstract class SessionModule {
.directory(directory) .directory(directory)
.name("disk_store.realm") .name("disk_store.realm")
.apply { .apply {
realmKeysUtils.configureEncryption(this, "session_db_$userMd5") realmKeysUtils.configureEncryption(this, "$DB_ALIAS_PREFIX$userMd5")
} }
.modules(SessionRealmModule()) .modules(SessionRealmModule())
.deleteRealmIfMigrationNeeded() .deleteRealmIfMigrationNeeded()

View file

@ -20,10 +20,14 @@ import android.content.Context
import im.vector.matrix.android.api.auth.data.Credentials import im.vector.matrix.android.api.auth.data.Credentials
import im.vector.matrix.android.internal.SessionManager import im.vector.matrix.android.internal.SessionManager
import im.vector.matrix.android.internal.auth.SessionParamsStore import im.vector.matrix.android.internal.auth.SessionParamsStore
import im.vector.matrix.android.internal.crypto.CryptoModule
import im.vector.matrix.android.internal.database.RealmKeysUtils
import im.vector.matrix.android.internal.di.CryptoDatabase import im.vector.matrix.android.internal.di.CryptoDatabase
import im.vector.matrix.android.internal.di.SessionDatabase import im.vector.matrix.android.internal.di.SessionDatabase
import im.vector.matrix.android.internal.di.UserCacheDirectory import im.vector.matrix.android.internal.di.UserCacheDirectory
import im.vector.matrix.android.internal.di.UserMd5
import im.vector.matrix.android.internal.network.executeRequest import im.vector.matrix.android.internal.network.executeRequest
import im.vector.matrix.android.internal.session.SessionModule
import im.vector.matrix.android.internal.session.cache.ClearCacheTask import im.vector.matrix.android.internal.session.cache.ClearCacheTask
import im.vector.matrix.android.internal.task.Task import im.vector.matrix.android.internal.task.Task
import im.vector.matrix.android.internal.worker.WorkManagerUtil import im.vector.matrix.android.internal.worker.WorkManagerUtil
@ -40,7 +44,9 @@ internal class DefaultSignOutTask @Inject constructor(private val context: Conte
private val sessionParamsStore: SessionParamsStore, private val sessionParamsStore: SessionParamsStore,
@SessionDatabase private val clearSessionDataTask: ClearCacheTask, @SessionDatabase private val clearSessionDataTask: ClearCacheTask,
@CryptoDatabase private val clearCryptoDataTask: ClearCacheTask, @CryptoDatabase private val clearCryptoDataTask: ClearCacheTask,
@UserCacheDirectory private val userFile: File) : SignOutTask { @UserCacheDirectory private val userFile: File,
private val realmKeysUtils: RealmKeysUtils,
@UserMd5 private val userMd5: String) : SignOutTask {
override suspend fun execute(params: Unit) { override suspend fun execute(params: Unit) {
Timber.d("SignOut: send request...") Timber.d("SignOut: send request...")
@ -65,5 +71,9 @@ internal class DefaultSignOutTask @Inject constructor(private val context: Conte
Timber.d("SignOut: clear file system") Timber.d("SignOut: clear file system")
userFile.deleteRecursively() userFile.deleteRecursively()
Timber.d("SignOut: clear the database keys")
realmKeysUtils.clear(SessionModule.DB_ALIAS_PREFIX + userMd5)
realmKeysUtils.clear(CryptoModule.DB_ALIAS_PREFIX + userMd5)
} }
} }