Cleanup SecretStoringUtils, and delete keys when user signs out

This commit is contained in:
Benoit Marty 2019-09-16 18:29:06 +02:00
parent c8010561fc
commit 1ba8a58219
6 changed files with 72 additions and 42 deletions

View file

@ -22,12 +22,14 @@ import android.security.KeyPairGeneratorSpec
import android.security.keystore.KeyGenParameterSpec
import android.security.keystore.KeyProperties
import androidx.annotation.RequiresApi
import timber.log.Timber
import java.io.*
import java.math.BigInteger
import java.security.KeyPairGenerator
import java.security.KeyStore
import java.security.KeyStoreException
import java.security.SecureRandom
import java.util.Calendar
import java.util.*
import javax.crypto.*
import javax.crypto.spec.GCMParameterSpec
import javax.crypto.spec.IvParameterSpec
@ -65,7 +67,7 @@ import javax.security.auth.x500.X500Principal
* val kDecripted = SecretStoringUtils.loadSecureSecret(KEncrypted!!, "myAlias", context)
* </code>
*
* You can also just use this utility to store a secret key, and use any encryption algorthim that you want.
* You can also just use this utility to store a secret key, and use any encryption algorithm that you want.
*
* Important: Keys stored in the keystore can be wiped out (depends of the OS version, like for example if you
* add a pin or change the schema); So you might and with a useless pile of bytes.
@ -76,11 +78,11 @@ object SecretStoringUtils {
private const val AES_MODE = "AES/GCM/NoPadding";
private const val RSA_MODE = "RSA/ECB/PKCS1Padding"
const val FORMAT_API_M: Byte = 0
const val FORMAT_1: Byte = 1
const val FORMAT_2: Byte = 2
private const val FORMAT_API_M: Byte = 0
private const val FORMAT_1: Byte = 1
private const val FORMAT_2: Byte = 2
val keyStore: KeyStore by lazy {
private val keyStore: KeyStore by lazy {
KeyStore.getInstance(ANDROID_KEY_STORE).apply {
load(null)
}
@ -88,11 +90,19 @@ object SecretStoringUtils {
private val secureRandom = SecureRandom()
fun safeDeleteKey(keyAlias: String) {
try {
keyStore.deleteEntry(keyAlias)
} catch (e: KeyStoreException) {
Timber.e(e)
}
}
/**
* Encrypt the given secret using the android Keystore.
* On android >= M, will directly use the keystore to generate a symetric key
* On KitKat >= KitKat and <M, as symetric key gen is not available, will use an asymetric key generated
* in the keystore to encrypted a random symetric key. The encrypted symetric key is returned
* On android >= M, will directly use the keystore to generate a symmetric key
* On android >= KitKat and <M, as symmetric key gen is not available, will use an symmetric key generated
* in the keystore to encrypted a random symmetric key. The encrypted symmetric key is returned
* in the bytearray (in can be stored anywhere, it is encrypted)
* On older version a key in generated from alias with random salt.
*
@ -103,7 +113,7 @@ object SecretStoringUtils {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
return encryptStringM(secret, keyAlias)
} else if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.KITKAT) {
return encryptStringJ(secret, keyAlias, context)
return encryptStringK(secret, keyAlias, context)
} else {
return encryptForOldDevicesNotGood(secret, keyAlias)
}
@ -117,7 +127,7 @@ object SecretStoringUtils {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
return decryptStringM(encrypted, keyAlias)
} else if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.KITKAT) {
return decryptStringJ(encrypted, keyAlias, context)
return decryptStringK(encrypted, keyAlias, context)
} else {
return decryptForOldDevicesNotGood(encrypted, keyAlias)
}
@ -145,7 +155,7 @@ object SecretStoringUtils {
@RequiresApi(Build.VERSION_CODES.M)
fun getOrGenerateSymmetricKeyForAlias(alias: String): SecretKey {
private fun getOrGenerateSymmetricKeyForAliasM(alias: String): SecretKey {
val secretKeyEntry = (keyStore.getEntry(alias, null) as? KeyStore.SecretKeyEntry)
?.secretKey
if (secretKeyEntry == null) {
@ -163,7 +173,6 @@ object SecretStoringUtils {
return secretKeyEntry
}
/*
Symetric Key Generation is only available in M, so before M the idea is to:
- Generate a pair of RSA keys;
@ -172,7 +181,7 @@ object SecretStoringUtils {
- Store the encrypted AES
Generate a key pair for encryption
*/
@RequiresApi(Build.VERSION_CODES.JELLY_BEAN_MR2)
@RequiresApi(Build.VERSION_CODES.KITKAT)
fun getOrGenerateKeyPairForAlias(alias: String, context: Context): KeyStore.PrivateKeyEntry {
val privateKeyEntry = (keyStore.getEntry(alias, null) as? KeyStore.PrivateKeyEntry)
@ -201,7 +210,7 @@ object SecretStoringUtils {
@RequiresApi(Build.VERSION_CODES.M)
fun encryptStringM(text: String, keyAlias: String): ByteArray? {
val secretKey = getOrGenerateSymmetricKeyForAlias(keyAlias)
val secretKey = getOrGenerateSymmetricKeyForAliasM(keyAlias)
val cipher = Cipher.getInstance(AES_MODE)
cipher.init(Cipher.ENCRYPT_MODE, secretKey)
@ -212,10 +221,10 @@ object SecretStoringUtils {
}
@RequiresApi(Build.VERSION_CODES.M)
fun decryptStringM(encryptedChunk: ByteArray, keyAlias: String): String {
private fun decryptStringM(encryptedChunk: ByteArray, keyAlias: String): String {
val (iv, encryptedText) = formatMExtract(ByteArrayInputStream(encryptedChunk))
val secretKey = getOrGenerateSymmetricKeyForAlias(keyAlias)
val secretKey = getOrGenerateSymmetricKeyForAliasM(keyAlias)
val cipher = Cipher.getInstance(AES_MODE)
val spec = GCMParameterSpec(128, iv)
@ -224,8 +233,8 @@ object SecretStoringUtils {
return String(cipher.doFinal(encryptedText), Charsets.UTF_8)
}
@RequiresApi(Build.VERSION_CODES.JELLY_BEAN_MR2)
fun encryptStringJ(text: String, keyAlias: String, context: Context): ByteArray? {
@RequiresApi(Build.VERSION_CODES.KITKAT)
private fun encryptStringK(text: String, keyAlias: String, context: Context): ByteArray? {
//we generate a random symetric key
val key = ByteArray(16)
secureRandom.nextBytes(key)
@ -242,7 +251,7 @@ object SecretStoringUtils {
return format1Make(encryptedKey, iv, encryptedBytes)
}
fun encryptForOldDevicesNotGood(text: String, keyAlias: String): ByteArray {
private fun encryptForOldDevicesNotGood(text: String, keyAlias: String): ByteArray {
val salt = ByteArray(8)
secureRandom.nextBytes(salt)
val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256")
@ -258,11 +267,11 @@ object SecretStoringUtils {
return format2Make(salt, iv, encryptedBytes)
}
fun decryptForOldDevicesNotGood(data: ByteArray, keyAlias: String): String? {
private fun decryptForOldDevicesNotGood(data: ByteArray, keyAlias: String): String? {
val (salt, iv, encrypted) = format2Extract(ByteArrayInputStream(data))
val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256")
val spec = PBEKeySpec(keyAlias.toCharArray(), salt, 10000, 128)
val spec = PBEKeySpec(keyAlias.toCharArray(), salt, 10_000, 128)
val tmp = factory.generateSecret(spec)
val sKey = SecretKeySpec(tmp.encoded, "AES")
@ -277,7 +286,7 @@ object SecretStoringUtils {
}
@RequiresApi(Build.VERSION_CODES.KITKAT)
fun decryptStringJ(data: ByteArray, keyAlias: String, context: Context): String? {
private fun decryptStringK(data: ByteArray, keyAlias: String, context: Context): String? {
val (encryptedKey, iv, encrypted) = format1Extract(ByteArrayInputStream(data))
@ -288,14 +297,12 @@ object SecretStoringUtils {
cipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(sKeyBytes, "AES"), spec)
return String(cipher.doFinal(encrypted), Charsets.UTF_8)
}
@RequiresApi(Build.VERSION_CODES.M)
@Throws(IOException::class)
fun saveSecureObjectM(keyAlias: String, output: OutputStream, writeObject: Any) {
val secretKey = getOrGenerateSymmetricKeyForAlias(keyAlias)
private fun saveSecureObjectM(keyAlias: String, output: OutputStream, writeObject: Any) {
val secretKey = getOrGenerateSymmetricKeyForAliasM(keyAlias)
val cipher = Cipher.getInstance(AES_MODE)
cipher.init(Cipher.ENCRYPT_MODE, secretKey/*, spec*/)
@ -314,7 +321,7 @@ object SecretStoringUtils {
}
@RequiresApi(Build.VERSION_CODES.KITKAT)
fun saveSecureObjectK(keyAlias: String, output: OutputStream, writeObject: Any, context: Context) {
private fun saveSecureObjectK(keyAlias: String, output: OutputStream, writeObject: Any, context: Context) {
//we generate a random symetric key
val key = ByteArray(16)
secureRandom.nextBytes(key)
@ -342,7 +349,7 @@ object SecretStoringUtils {
output.write(bos1.toByteArray())
}
fun saveSecureObjectOldNotGood(keyAlias: String, output: OutputStream, writeObject: Any) {
private fun saveSecureObjectOldNotGood(keyAlias: String, output: OutputStream, writeObject: Any) {
val salt = ByteArray(8)
secureRandom.nextBytes(salt)
val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256")
@ -387,8 +394,8 @@ object SecretStoringUtils {
@RequiresApi(Build.VERSION_CODES.M)
@Throws(IOException::class)
fun <T> loadSecureObjectM(keyAlias: String, inputStream: InputStream): T? {
val secretKey = getOrGenerateSymmetricKeyForAlias(keyAlias)
private fun <T> loadSecureObjectM(keyAlias: String, inputStream: InputStream): T? {
val secretKey = getOrGenerateSymmetricKeyForAliasM(keyAlias)
val format = inputStream.read()
assert(format.toByte() == FORMAT_API_M)
@ -411,7 +418,7 @@ object SecretStoringUtils {
@RequiresApi(Build.VERSION_CODES.KITKAT)
@Throws(IOException::class)
fun <T> loadSecureObjectK(keyAlias: String, inputStream: InputStream, context: Context): T? {
private fun <T> loadSecureObjectK(keyAlias: String, inputStream: InputStream, context: Context): T? {
val (encryptedKey, iv, encrypted) = format1Extract(inputStream)
@ -432,8 +439,7 @@ object SecretStoringUtils {
}
@Throws(Exception::class)
fun <T> loadSecureObjectOldNotGood(keyAlias: String, inputStream: InputStream): T? {
private fun <T> loadSecureObjectOldNotGood(keyAlias: String, inputStream: InputStream): T? {
val (salt, iv, encrypted) = format2Extract(inputStream)
val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256")
@ -456,7 +462,7 @@ object SecretStoringUtils {
}
@RequiresApi(Build.VERSION_CODES.JELLY_BEAN_MR2)
@RequiresApi(Build.VERSION_CODES.KITKAT)
@Throws(Exception::class)
private fun rsaEncrypt(alias: String, secret: ByteArray, context: Context): ByteArray {
val privateKeyEntry = getOrGenerateKeyPairForAlias(alias, context)
@ -472,7 +478,7 @@ object SecretStoringUtils {
return outputStream.toByteArray()
}
@RequiresApi(Build.VERSION_CODES.JELLY_BEAN_MR2)
@RequiresApi(Build.VERSION_CODES.KITKAT)
@Throws(Exception::class)
private fun rsaDecrypt(alias: String, encrypted: InputStream, context: Context): ByteArray {
val privateKeyEntry = getOrGenerateKeyPairForAlias(alias, context)
@ -504,7 +510,6 @@ object SecretStoringUtils {
}
private fun format1Extract(bis: InputStream): Triple<ByteArray, ByteArray, ByteArray> {
val format = bis.read()
assert(format.toByte() == FORMAT_1)
@ -548,7 +553,6 @@ object SecretStoringUtils {
}
private fun format2Extract(bis: InputStream): Triple<ByteArray, ByteArray, ByteArray> {
val format = bis.read()
assert(format.toByte() == FORMAT_2)

View file

@ -33,6 +33,8 @@ internal abstract class AuthModule {
@Module
companion object {
private const val DB_ALIAS = "matrix-sdk-auth"
@JvmStatic
@Provides
@AuthDatabase
@ -44,7 +46,7 @@ internal abstract class AuthModule {
return RealmConfiguration.Builder()
.apply {
realmKeysUtils.configureEncryption(this, "matrix-sdk-auth")
realmKeysUtils.configureEncryption(this, DB_ALIAS)
}
.name("matrix-sdk-auth.realm")
.modules(AuthRealmModule())

View file

@ -45,6 +45,7 @@ internal abstract class CryptoModule {
@Module
companion object {
internal const val DB_ALIAS_PREFIX = "crypto_module_"
@JvmStatic
@Provides
@ -56,7 +57,7 @@ internal abstract class CryptoModule {
return RealmConfiguration.Builder()
.directory(directory)
.apply {
realmKeysUtils.configureEncryption(this, "crypto_module_$userMd5")
realmKeysUtils.configureEncryption(this, "$DB_ALIAS_PREFIX$userMd5")
}
.name("crypto_store.realm")
.modules(RealmCryptoStoreModule())

View file

@ -101,6 +101,18 @@ internal class RealmKeysUtils @Inject constructor(private val context: Context)
realmConfigurationBuilder.encryptionKey(key)
}
// Delete elements related to the alias
fun clear(alias: String) {
if (hasKeyForDatabase(alias)) {
SecretStoringUtils.safeDeleteKey(alias)
sharedPreferences
.edit()
.remove("${ENCRYPTED_KEY_PREFIX}_$alias")
.apply()
}
}
companion object {
private const val ENCRYPTED_KEY_PREFIX = "REALM_ENCRYPTED_KEY"
}

View file

@ -50,6 +50,7 @@ internal abstract class SessionModule {
@Module
companion object {
internal const val DB_ALIAS_PREFIX = "session_db_"
@JvmStatic
@Provides
@ -89,7 +90,7 @@ internal abstract class SessionModule {
.directory(directory)
.name("disk_store.realm")
.apply {
realmKeysUtils.configureEncryption(this, "session_db_$userMd5")
realmKeysUtils.configureEncryption(this, "$DB_ALIAS_PREFIX$userMd5")
}
.modules(SessionRealmModule())
.deleteRealmIfMigrationNeeded()

View file

@ -20,10 +20,14 @@ import android.content.Context
import im.vector.matrix.android.api.auth.data.Credentials
import im.vector.matrix.android.internal.SessionManager
import im.vector.matrix.android.internal.auth.SessionParamsStore
import im.vector.matrix.android.internal.crypto.CryptoModule
import im.vector.matrix.android.internal.database.RealmKeysUtils
import im.vector.matrix.android.internal.di.CryptoDatabase
import im.vector.matrix.android.internal.di.SessionDatabase
import im.vector.matrix.android.internal.di.UserCacheDirectory
import im.vector.matrix.android.internal.di.UserMd5
import im.vector.matrix.android.internal.network.executeRequest
import im.vector.matrix.android.internal.session.SessionModule
import im.vector.matrix.android.internal.session.cache.ClearCacheTask
import im.vector.matrix.android.internal.task.Task
import im.vector.matrix.android.internal.worker.WorkManagerUtil
@ -40,7 +44,9 @@ internal class DefaultSignOutTask @Inject constructor(private val context: Conte
private val sessionParamsStore: SessionParamsStore,
@SessionDatabase private val clearSessionDataTask: ClearCacheTask,
@CryptoDatabase private val clearCryptoDataTask: ClearCacheTask,
@UserCacheDirectory private val userFile: File) : SignOutTask {
@UserCacheDirectory private val userFile: File,
private val realmKeysUtils: RealmKeysUtils,
@UserMd5 private val userMd5: String) : SignOutTask {
override suspend fun execute(params: Unit) {
Timber.d("SignOut: send request...")
@ -65,5 +71,9 @@ internal class DefaultSignOutTask @Inject constructor(private val context: Conte
Timber.d("SignOut: clear file system")
userFile.deleteRecursively()
Timber.d("SignOut: clear the database keys")
realmKeysUtils.clear(SessionModule.DB_ALIAS_PREFIX + userMd5)
realmKeysUtils.clear(CryptoModule.DB_ALIAS_PREFIX + userMd5)
}
}