Merge pull request #1016 from vector-im/feature/cleanup_quadS

Cleanup quad s and AccountData service
This commit is contained in:
Benoit Marty 2020-02-17 18:52:23 +01:00 committed by GitHub
commit fc740ae347
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 308 additions and 325 deletions

View file

@ -22,6 +22,7 @@
<w>signin</w> <w>signin</w>
<w>signout</w> <w>signout</w>
<w>signup</w> <w>signup</w>
<w>ssss</w>
<w>threepid</w> <w>threepid</w>
</words> </words>
</dictionary> </dictionary>

View file

@ -123,10 +123,10 @@ class RxSession(private val session: Session) {
} }
} }
fun liveAccountData(filter: List<String>): Observable<List<UserAccountDataEvent>> { fun liveAccountData(types: Set<String>): Observable<List<UserAccountDataEvent>> {
return session.getLiveAccountDataEvents(filter).asObservable() return session.getLiveAccountDataEvents(types).asObservable()
.startWithCallable { .startWithCallable {
session.getAccountDataEvents(filter) session.getAccountDataEvents(types)
} }
} }
} }

View file

@ -25,24 +25,26 @@ import im.vector.matrix.android.api.session.Session
import im.vector.matrix.android.api.session.securestorage.Curve25519AesSha2KeySpec import im.vector.matrix.android.api.session.securestorage.Curve25519AesSha2KeySpec
import im.vector.matrix.android.api.session.securestorage.EncryptedSecretContent import im.vector.matrix.android.api.session.securestorage.EncryptedSecretContent
import im.vector.matrix.android.api.session.securestorage.KeySigner import im.vector.matrix.android.api.session.securestorage.KeySigner
import im.vector.matrix.android.api.session.securestorage.SsssKeyCreationInfo
import im.vector.matrix.android.api.session.securestorage.SecretStorageKeyContent import im.vector.matrix.android.api.session.securestorage.SecretStorageKeyContent
import im.vector.matrix.android.api.session.securestorage.SsssKeyCreationInfo
import im.vector.matrix.android.api.util.Optional import im.vector.matrix.android.api.util.Optional
import im.vector.matrix.android.common.CommonTestHelper import im.vector.matrix.android.common.CommonTestHelper
import im.vector.matrix.android.common.CryptoTestHelper
import im.vector.matrix.android.common.SessionTestParams import im.vector.matrix.android.common.SessionTestParams
import im.vector.matrix.android.common.TestConstants import im.vector.matrix.android.common.TestConstants
import im.vector.matrix.android.common.TestMatrixCallback import im.vector.matrix.android.common.TestMatrixCallback
import im.vector.matrix.android.internal.crypto.SSSS_ALGORITHM_CURVE25519_AES_SHA2 import im.vector.matrix.android.internal.crypto.SSSS_ALGORITHM_CURVE25519_AES_SHA2
import im.vector.matrix.android.internal.crypto.crosssigning.toBase64NoPadding import im.vector.matrix.android.internal.crypto.crosssigning.toBase64NoPadding
import im.vector.matrix.android.internal.crypto.secrets.DefaultSharedSecretStorageService import im.vector.matrix.android.internal.crypto.secrets.DefaultSharedSecretStorageService
import im.vector.matrix.android.internal.crypto.tools.withOlmDecryption
import im.vector.matrix.android.internal.session.sync.model.accountdata.UserAccountDataEvent import im.vector.matrix.android.internal.session.sync.model.accountdata.UserAccountDataEvent
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.junit.Assert import org.amshove.kluent.shouldBe
import org.junit.Assert.fail import org.junit.Assert.assertEquals
import org.junit.Assert.assertNotNull
import org.junit.Assert.assertNull
import org.junit.FixMethodOrder import org.junit.FixMethodOrder
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith import org.junit.runner.RunWith
@ -50,54 +52,38 @@ import org.junit.runners.MethodSorters
import java.util.concurrent.CountDownLatch import java.util.concurrent.CountDownLatch
@RunWith(AndroidJUnit4::class) @RunWith(AndroidJUnit4::class)
@FixMethodOrder(MethodSorters.NAME_ASCENDING) @FixMethodOrder(MethodSorters.JVM)
class QuadSTests : InstrumentedTest { class QuadSTests : InstrumentedTest {
private val mTestHelper = CommonTestHelper(context()) private val mTestHelper = CommonTestHelper(context())
private val mCryptoTestHelper = CryptoTestHelper(mTestHelper)
private val emptyKeySigner = object : KeySigner {
override fun sign(canonicalJson: String): Map<String, Map<String, String>>? {
return null
}
}
@Test @Test
fun test_Generate4SKey() { fun test_Generate4SKey() {
val aliceSession = mTestHelper.createAccount(TestConstants.USER_ALICE, SessionTestParams(true)) val aliceSession = mTestHelper.createAccount(TestConstants.USER_ALICE, SessionTestParams(true))
val aliceLatch = CountDownLatch(1)
val quadS = aliceSession.sharedSecretStorageService val quadS = aliceSession.sharedSecretStorageService
val emptyKeySigner = object : KeySigner {
override fun sign(canonicalJson: String): Map<String, Map<String, String>>? {
return null
}
}
var recoveryKey: String? = null
val TEST_KEY_ID = "my.test.Key" val TEST_KEY_ID = "my.test.Key"
quadS.generateKey(TEST_KEY_ID, "Test Key", emptyKeySigner, val ssssKeyCreationInfo = mTestHelper.doSync<SsssKeyCreationInfo> {
object : MatrixCallback<SsssKeyCreationInfo> { quadS.generateKey(TEST_KEY_ID, "Test Key", emptyKeySigner, it)
override fun onSuccess(data: SsssKeyCreationInfo) { }
recoveryKey = data.recoveryKey
aliceLatch.countDown()
}
override fun onFailure(failure: Throwable) {
Assert.fail("onFailure " + failure.localizedMessage)
aliceLatch.countDown()
}
})
mTestHelper.await(aliceLatch)
// Assert Account data is updated // Assert Account data is updated
val accountDataLock = CountDownLatch(1) val accountDataLock = CountDownLatch(1)
var accountData: UserAccountDataEvent? = null var accountData: UserAccountDataEvent? = null
val liveAccountData = runBlocking(Dispatchers.Main) { val liveAccountData = runBlocking(Dispatchers.Main) {
aliceSession.getLiveAccountDataEvent("m.secret_storage.key.$TEST_KEY_ID") aliceSession.getLiveAccountDataEvent("${DefaultSharedSecretStorageService.KEY_ID_BASE}.$TEST_KEY_ID")
} }
val accountDataObserver = Observer<Optional<UserAccountDataEvent>?> { t -> val accountDataObserver = Observer<Optional<UserAccountDataEvent>?> { t ->
if (t?.getOrNull()?.type == "m.secret_storage.key.$TEST_KEY_ID") { if (t?.getOrNull()?.type == "${DefaultSharedSecretStorageService.KEY_ID_BASE}.$TEST_KEY_ID") {
accountData = t.getOrNull() accountData = t.getOrNull()
accountDataLock.countDown() accountDataLock.countDown()
} }
@ -106,19 +92,19 @@ class QuadSTests : InstrumentedTest {
mTestHelper.await(accountDataLock) mTestHelper.await(accountDataLock)
Assert.assertNotNull("Key should be stored in account data", accountData) assertNotNull("Key should be stored in account data", accountData)
val parsed = SecretStorageKeyContent.fromJson(accountData!!.content) val parsed = SecretStorageKeyContent.fromJson(accountData!!.content)
Assert.assertNotNull("Key Content cannot be parsed", parsed) assertNotNull("Key Content cannot be parsed", parsed)
Assert.assertEquals("Unexpected Algorithm", SSSS_ALGORITHM_CURVE25519_AES_SHA2, parsed!!.algorithm) assertEquals("Unexpected Algorithm", SSSS_ALGORITHM_CURVE25519_AES_SHA2, parsed!!.algorithm)
Assert.assertEquals("Unexpected key name", "Test Key", parsed.name) assertEquals("Unexpected key name", "Test Key", parsed.name)
Assert.assertNull("Key was not generated from passphrase", parsed.passphrase) assertNull("Key was not generated from passphrase", parsed.passphrase)
Assert.assertNotNull("Pubkey should be defined", parsed.publicKey) assertNotNull("Pubkey should be defined", parsed.publicKey)
val privateKeySpec = Curve25519AesSha2KeySpec.fromRecoveryKey(recoveryKey!!) val privateKeySpec = Curve25519AesSha2KeySpec.fromRecoveryKey(ssssKeyCreationInfo.recoveryKey)
DefaultSharedSecretStorageService.withOlmDecryption { olmPkDecryption -> val pubKey = withOlmDecryption { olmPkDecryption ->
val pubKey = olmPkDecryption.setPrivateKey(privateKeySpec!!.privateKey) olmPkDecryption.setPrivateKey(privateKeySpec!!.privateKey)
Assert.assertEquals("Unexpected Public Key", pubKey, parsed.publicKey)
} }
assertEquals("Unexpected Public Key", pubKey, parsed.publicKey)
// Set as default key // Set as default key
quadS.setDefaultKey(TEST_KEY_ID, object : MatrixCallback<Unit> {}) quadS.setDefaultKey(TEST_KEY_ID, object : MatrixCallback<Unit> {})
@ -139,8 +125,8 @@ class QuadSTests : InstrumentedTest {
mTestHelper.await(defaultDataLock) mTestHelper.await(defaultDataLock)
Assert.assertNotNull(defaultKeyAccountData?.content) assertNotNull(defaultKeyAccountData?.content)
Assert.assertEquals("Unexpected default key ${defaultKeyAccountData?.content}", TEST_KEY_ID, defaultKeyAccountData?.content?.get("key")) assertEquals("Unexpected default key ${defaultKeyAccountData?.content}", TEST_KEY_ID, defaultKeyAccountData?.content?.get("key"))
mTestHelper.signout(aliceSession) mTestHelper.signout(aliceSession)
} }
@ -152,52 +138,40 @@ class QuadSTests : InstrumentedTest {
val info = generatedSecret(aliceSession, keyId, true) val info = generatedSecret(aliceSession, keyId, true)
// Store a secret // Store a secret
val storeCountDownLatch = CountDownLatch(1)
val clearSecret = Base64.encodeToString("42".toByteArray(), Base64.NO_PADDING or Base64.NO_WRAP) val clearSecret = Base64.encodeToString("42".toByteArray(), Base64.NO_PADDING or Base64.NO_WRAP)
aliceSession.sharedSecretStorageService.storeSecret( mTestHelper.doSync<Unit> {
"secret.of.life", aliceSession.sharedSecretStorageService.storeSecret(
clearSecret, "secret.of.life",
null, // default key clearSecret,
TestMatrixCallback(storeCountDownLatch) null, // default key
) it
)
}
val secretAccountData = assertAccountData(aliceSession, "secret.of.life") val secretAccountData = assertAccountData(aliceSession, "secret.of.life")
val encryptedContent = secretAccountData.content.get("encrypted") as? Map<*, *> val encryptedContent = secretAccountData.content.get("encrypted") as? Map<*, *>
Assert.assertNotNull("Element should be encrypted", encryptedContent) assertNotNull("Element should be encrypted", encryptedContent)
Assert.assertNotNull("Secret should be encrypted with default key", encryptedContent?.get(keyId)) assertNotNull("Secret should be encrypted with default key", encryptedContent?.get(keyId))
val secret = EncryptedSecretContent.fromJson(encryptedContent?.get(keyId)) val secret = EncryptedSecretContent.fromJson(encryptedContent?.get(keyId))
Assert.assertNotNull(secret?.ciphertext) assertNotNull(secret?.ciphertext)
Assert.assertNotNull(secret?.mac) assertNotNull(secret?.mac)
Assert.assertNotNull(secret?.ephemeral) assertNotNull(secret?.ephemeral)
// Try to decrypt?? // Try to decrypt??
val keySpec = Curve25519AesSha2KeySpec.fromRecoveryKey(info.recoveryKey) val keySpec = Curve25519AesSha2KeySpec.fromRecoveryKey(info.recoveryKey)
var decryptedSecret: String? = null val decryptedSecret = mTestHelper.doSync<String> {
aliceSession.sharedSecretStorageService.getSecret("secret.of.life",
null, // default key
keySpec!!,
it
)
}
val decryptCountDownLatch = CountDownLatch(1) assertEquals("Secret mismatch", clearSecret, decryptedSecret)
aliceSession.sharedSecretStorageService.getSecret("secret.of.life",
null, // default key
keySpec!!,
object : MatrixCallback<String> {
override fun onFailure(failure: Throwable) {
fail("Fail to decrypt -> " + failure.localizedMessage)
decryptCountDownLatch.countDown()
}
override fun onSuccess(data: String) {
decryptedSecret = data
decryptCountDownLatch.countDown()
}
}
)
mTestHelper.await(decryptCountDownLatch)
Assert.assertEquals("Secret mismatch", clearSecret, decryptedSecret)
mTestHelper.signout(aliceSession) mTestHelper.signout(aliceSession)
} }
@ -207,24 +181,16 @@ class QuadSTests : InstrumentedTest {
val quadS = aliceSession.sharedSecretStorageService val quadS = aliceSession.sharedSecretStorageService
val emptyKeySigner = object : KeySigner {
override fun sign(canonicalJson: String): Map<String, Map<String, String>>? {
return null
}
}
val TEST_KEY_ID = "my.test.Key" val TEST_KEY_ID = "my.test.Key"
val countDownLatch = CountDownLatch(1) mTestHelper.doSync<SsssKeyCreationInfo> {
quadS.generateKey(TEST_KEY_ID, "Test Key", emptyKeySigner, quadS.generateKey(TEST_KEY_ID, "Test Key", emptyKeySigner, it)
TestMatrixCallback(countDownLatch)) }
mTestHelper.await(countDownLatch)
// Test that we don't need to wait for an account data sync to access directly the keyid from DB // Test that we don't need to wait for an account data sync to access directly the keyid from DB
val defaultLatch = CountDownLatch(1) mTestHelper.doSync<Unit> {
quadS.setDefaultKey(TEST_KEY_ID, TestMatrixCallback(defaultLatch)) quadS.setDefaultKey(TEST_KEY_ID, it)
mTestHelper.await(defaultLatch) }
mTestHelper.signout(aliceSession) mTestHelper.signout(aliceSession)
} }
@ -239,38 +205,39 @@ class QuadSTests : InstrumentedTest {
val mySecretText = "Lorem ipsum dolor sit amet, consectetur adipiscing elit" val mySecretText = "Lorem ipsum dolor sit amet, consectetur adipiscing elit"
val storeLatch = CountDownLatch(1) mTestHelper.doSync<Unit> {
aliceSession.sharedSecretStorageService.storeSecret( aliceSession.sharedSecretStorageService.storeSecret(
"my.secret", "my.secret",
mySecretText.toByteArray().toBase64NoPadding(), mySecretText.toByteArray().toBase64NoPadding(),
listOf(keyId1, keyId2), listOf(keyId1, keyId2),
TestMatrixCallback(storeLatch) it
) )
mTestHelper.await(storeLatch) }
val accountDataEvent = aliceSession.getAccountDataEvent("my.secret") val accountDataEvent = aliceSession.getAccountDataEvent("my.secret")
val encryptedContent = accountDataEvent?.content?.get("encrypted") as? Map<*, *> val encryptedContent = accountDataEvent?.content?.get("encrypted") as? Map<*, *>
Assert.assertEquals("Content should contains two encryptions", 2, encryptedContent?.keys?.size ?: 0) assertEquals("Content should contains two encryptions", 2, encryptedContent?.keys?.size ?: 0)
Assert.assertNotNull(encryptedContent?.get(keyId1)) assertNotNull(encryptedContent?.get(keyId1))
Assert.assertNotNull(encryptedContent?.get(keyId2)) assertNotNull(encryptedContent?.get(keyId2))
// Assert that can decrypt with both keys // Assert that can decrypt with both keys
val decryptCountDownLatch = CountDownLatch(2) mTestHelper.doSync<String> {
aliceSession.sharedSecretStorageService.getSecret("my.secret", aliceSession.sharedSecretStorageService.getSecret("my.secret",
keyId1, keyId1,
Curve25519AesSha2KeySpec.fromRecoveryKey(key1Info.recoveryKey)!!, Curve25519AesSha2KeySpec.fromRecoveryKey(key1Info.recoveryKey)!!,
TestMatrixCallback(decryptCountDownLatch) it
) )
}
aliceSession.sharedSecretStorageService.getSecret("my.secret", mTestHelper.doSync<String> {
keyId2, aliceSession.sharedSecretStorageService.getSecret("my.secret",
Curve25519AesSha2KeySpec.fromRecoveryKey(key2Info.recoveryKey)!!, keyId2,
TestMatrixCallback(decryptCountDownLatch) Curve25519AesSha2KeySpec.fromRecoveryKey(key2Info.recoveryKey)!!,
) it
)
mTestHelper.await(decryptCountDownLatch) }
mTestHelper.signout(aliceSession) mTestHelper.signout(aliceSession)
} }
@ -284,16 +251,17 @@ class QuadSTests : InstrumentedTest {
val mySecretText = "Lorem ipsum dolor sit amet, consectetur adipiscing elit" val mySecretText = "Lorem ipsum dolor sit amet, consectetur adipiscing elit"
val storeLatch = CountDownLatch(1) mTestHelper.doSync<Unit> {
aliceSession.sharedSecretStorageService.storeSecret( aliceSession.sharedSecretStorageService.storeSecret(
"my.secret", "my.secret",
mySecretText.toByteArray().toBase64NoPadding(), mySecretText.toByteArray().toBase64NoPadding(),
listOf(keyId1), listOf(keyId1),
TestMatrixCallback(storeLatch) it
) )
mTestHelper.await(storeLatch) }
val decryptCountDownLatch = CountDownLatch(2) val decryptCountDownLatch = CountDownLatch(1)
var error = false
aliceSession.sharedSecretStorageService.getSecret("my.secret", aliceSession.sharedSecretStorageService.getSecret("my.secret",
keyId1, keyId1,
Curve25519AesSha2KeySpec.fromPassphrase( Curve25519AesSha2KeySpec.fromPassphrase(
@ -304,29 +272,32 @@ class QuadSTests : InstrumentedTest {
object : MatrixCallback<String> { object : MatrixCallback<String> {
override fun onSuccess(data: String) { override fun onSuccess(data: String) {
decryptCountDownLatch.countDown() decryptCountDownLatch.countDown()
fail("Should not be able to decrypt")
} }
override fun onFailure(failure: Throwable) { override fun onFailure(failure: Throwable) {
Assert.assertTrue(true) error = true
decryptCountDownLatch.countDown() decryptCountDownLatch.countDown()
} }
} }
) )
// Now try with correct key
aliceSession.sharedSecretStorageService.getSecret("my.secret",
keyId1,
Curve25519AesSha2KeySpec.fromPassphrase(
passphrase,
key1Info.content?.passphrase?.salt ?: "",
key1Info.content?.passphrase?.iterations ?: 0,
null),
TestMatrixCallback(decryptCountDownLatch)
)
mTestHelper.await(decryptCountDownLatch) mTestHelper.await(decryptCountDownLatch)
error shouldBe true
// Now try with correct key
mTestHelper.doSync<String> {
aliceSession.sharedSecretStorageService.getSecret("my.secret",
keyId1,
Curve25519AesSha2KeySpec.fromPassphrase(
passphrase,
key1Info.content?.passphrase?.salt ?: "",
key1Info.content?.passphrase?.iterations ?: 0,
null),
it
)
}
mTestHelper.signout(aliceSession) mTestHelper.signout(aliceSession)
} }
@ -346,7 +317,7 @@ class QuadSTests : InstrumentedTest {
GlobalScope.launch(Dispatchers.Main) { liveAccountData.observeForever(accountDataObserver) } GlobalScope.launch(Dispatchers.Main) { liveAccountData.observeForever(accountDataObserver) }
mTestHelper.await(accountDataLock) mTestHelper.await(accountDataLock)
Assert.assertNotNull("Account Data type:$type should be found", accountData) assertNotNull("Account Data type:$type should be found", accountData)
return accountData!! return accountData!!
} }
@ -354,78 +325,36 @@ class QuadSTests : InstrumentedTest {
private fun generatedSecret(session: Session, keyId: String, asDefault: Boolean = true): SsssKeyCreationInfo { private fun generatedSecret(session: Session, keyId: String, asDefault: Boolean = true): SsssKeyCreationInfo {
val quadS = session.sharedSecretStorageService val quadS = session.sharedSecretStorageService
val emptyKeySigner = object : KeySigner { val creationInfo = mTestHelper.doSync<SsssKeyCreationInfo> {
override fun sign(canonicalJson: String): Map<String, Map<String, String>>? { quadS.generateKey(keyId, keyId, emptyKeySigner, it)
return null
}
} }
var creationInfo: SsssKeyCreationInfo? = null assertAccountData(session, "${DefaultSharedSecretStorageService.KEY_ID_BASE}.$keyId")
val generateLatch = CountDownLatch(1)
quadS.generateKey(keyId, keyId, emptyKeySigner,
object : MatrixCallback<SsssKeyCreationInfo> {
override fun onSuccess(data: SsssKeyCreationInfo) {
creationInfo = data
generateLatch.countDown()
}
override fun onFailure(failure: Throwable) {
Assert.fail("onFailure " + failure.localizedMessage)
generateLatch.countDown()
}
})
mTestHelper.await(generateLatch)
Assert.assertNotNull(creationInfo)
assertAccountData(session, "m.secret_storage.key.$keyId")
if (asDefault) { if (asDefault) {
val setDefaultLatch = CountDownLatch(1) mTestHelper.doSync<Unit> {
quadS.setDefaultKey(keyId, TestMatrixCallback(setDefaultLatch)) quadS.setDefaultKey(keyId, it)
mTestHelper.await(setDefaultLatch) }
assertAccountData(session, DefaultSharedSecretStorageService.DEFAULT_KEY_ID) assertAccountData(session, DefaultSharedSecretStorageService.DEFAULT_KEY_ID)
} }
return creationInfo!! return creationInfo
} }
private fun generatedSecretFromPassphrase(session: Session, passphrase: String, keyId: String, asDefault: Boolean = true): SsssKeyCreationInfo { private fun generatedSecretFromPassphrase(session: Session, passphrase: String, keyId: String, asDefault: Boolean = true): SsssKeyCreationInfo {
val quadS = session.sharedSecretStorageService val quadS = session.sharedSecretStorageService
val emptyKeySigner = object : KeySigner { val creationInfo = mTestHelper.doSync<SsssKeyCreationInfo> {
override fun sign(canonicalJson: String): Map<String, Map<String, String>>? { quadS.generateKeyWithPassphrase(
return null keyId,
} keyId,
passphrase,
emptyKeySigner,
null,
it)
} }
var creationInfo: SsssKeyCreationInfo? = null assertAccountData(session, "${DefaultSharedSecretStorageService.KEY_ID_BASE}.$keyId")
val generateLatch = CountDownLatch(1)
quadS.generateKeyWithPassphrase(keyId, keyId,
passphrase,
emptyKeySigner,
null,
object : MatrixCallback<SsssKeyCreationInfo> {
override fun onSuccess(data: SsssKeyCreationInfo) {
creationInfo = data
generateLatch.countDown()
}
override fun onFailure(failure: Throwable) {
Assert.fail("onFailure " + failure.localizedMessage)
generateLatch.countDown()
}
})
mTestHelper.await(generateLatch)
Assert.assertNotNull(creationInfo)
assertAccountData(session, "m.secret_storage.key.$keyId")
if (asDefault) { if (asDefault) {
val setDefaultLatch = CountDownLatch(1) val setDefaultLatch = CountDownLatch(1)
quadS.setDefaultKey(keyId, TestMatrixCallback(setDefaultLatch)) quadS.setDefaultKey(keyId, TestMatrixCallback(setDefaultLatch))
@ -433,6 +362,6 @@ class QuadSTests : InstrumentedTest {
assertAccountData(session, DefaultSharedSecretStorageService.DEFAULT_KEY_ID) assertAccountData(session, DefaultSharedSecretStorageService.DEFAULT_KEY_ID)
} }
return creationInfo!! return creationInfo
} }
} }

View file

@ -19,18 +19,35 @@ package im.vector.matrix.android.api.session.accountdata
import androidx.lifecycle.LiveData import androidx.lifecycle.LiveData
import im.vector.matrix.android.api.MatrixCallback import im.vector.matrix.android.api.MatrixCallback
import im.vector.matrix.android.api.session.events.model.Content import im.vector.matrix.android.api.session.events.model.Content
import im.vector.matrix.android.api.util.Cancelable
import im.vector.matrix.android.api.util.Optional import im.vector.matrix.android.api.util.Optional
import im.vector.matrix.android.internal.session.sync.model.accountdata.UserAccountDataEvent import im.vector.matrix.android.internal.session.sync.model.accountdata.UserAccountDataEvent
interface AccountDataService { interface AccountDataService {
/**
* Retrieve the account data with the provided type or null if not found
*/
fun getAccountDataEvent(type: String): UserAccountDataEvent? fun getAccountDataEvent(type: String): UserAccountDataEvent?
/**
* Observe the account data with the provided type
*/
fun getLiveAccountDataEvent(type: String): LiveData<Optional<UserAccountDataEvent>> fun getLiveAccountDataEvent(type: String): LiveData<Optional<UserAccountDataEvent>>
fun getAccountDataEvents(filterType: List<String>): List<UserAccountDataEvent> /**
* Retrieve the account data with the provided types. The return list can have a different size that
* the size of the types set, because some AccountData may not exist.
* If an empty set is provided, all the AccountData are retrieved
*/
fun getAccountDataEvents(types: Set<String>): List<UserAccountDataEvent>
fun getLiveAccountDataEvents(filterType: List<String>): LiveData<List<UserAccountDataEvent>> /**
* Observe the account data with the provided types. If an empty set is provided, all the AccountData are observed
*/
fun getLiveAccountDataEvents(types: Set<String>): LiveData<List<UserAccountDataEvent>>
fun updateAccountData(type: String, content: Content, callback: MatrixCallback<Unit>? = null) /**
* Update the account data with the provided type and the provided account data content
*/
fun updateAccountData(type: String, content: Content, callback: MatrixCallback<Unit>? = null): Cancelable
} }

View file

@ -19,6 +19,7 @@ package im.vector.matrix.android.api.session.securestorage
import com.squareup.moshi.Json import com.squareup.moshi.Json
import com.squareup.moshi.JsonClass import com.squareup.moshi.JsonClass
import im.vector.matrix.android.internal.di.MoshiProvider import im.vector.matrix.android.internal.di.MoshiProvider
import im.vector.matrix.android.internal.session.user.accountdata.AccountDataContent
/** /**
* The account_data will have an encrypted property that is a map from key ID to an object. * The account_data will have an encrypted property that is a map from key ID to an object.
@ -32,7 +33,7 @@ data class EncryptedSecretContent(
@Json(name = "ciphertext") val ciphertext: String? = null, @Json(name = "ciphertext") val ciphertext: String? = null,
@Json(name = "mac") val mac: String? = null, @Json(name = "mac") val mac: String? = null,
@Json(name = "ephemeral") val ephemeral: String? = null @Json(name = "ephemeral") val ephemeral: String? = null
) { ) : AccountDataContent {
companion object { companion object {
/** /**
* Facility method to convert from object which must be comprised of maps, lists, * Facility method to convert from object which must be comprised of maps, lists,

View file

@ -54,25 +54,28 @@ data class SecretStorageKeyContent(
/** Currently support m.secret_storage.v1.curve25519-aes-sha2 */ /** Currently support m.secret_storage.v1.curve25519-aes-sha2 */
@Json(name = "algorithm") val algorithm: String? = null, @Json(name = "algorithm") val algorithm: String? = null,
@Json(name = "name") val name: String? = null, @Json(name = "name") val name: String? = null,
@Json(name = "passphrase") val passphrase: SSSSPassphrase? = null, @Json(name = "passphrase") val passphrase: SsssPassphrase? = null,
@Json(name = "pubkey") val publicKey: String? = null, @Json(name = "pubkey") val publicKey: String? = null,
@Json(name = "signatures") @Json(name = "signatures") val signatures: Map<String, Map<String, String>>? = null
var signatures: Map<String, Map<String, String>>? = null
) { ) {
private fun signalableJSONDictionary(): Map<String, Any> { private fun signalableJSONDictionary(): Map<String, Any> {
val map = HashMap<String, Any>() return mutableMapOf<String, Any>().apply {
algorithm?.let { map["algorithm"] = it } algorithm
name?.let { map["name"] = it } ?.let { this["algorithm"] = it }
publicKey?.let { map["pubkey"] = it } name
passphrase?.let { ssspp -> ?.let { this["name"] = it }
map["passphrase"] = mapOf( publicKey
"algorithm" to ssspp.algorithm, ?.let { this["pubkey"] = it }
"iterations" to ssspp.salt, passphrase
"salt" to ssspp.salt ?.let { ssssPassphrase ->
) this["passphrase"] = mapOf(
"algorithm" to ssssPassphrase.algorithm,
"iterations" to ssssPassphrase.iterations,
"salt" to ssssPassphrase.salt
)
}
} }
return map
} }
fun canonicalSignable(): String { fun canonicalSignable(): String {
@ -93,7 +96,7 @@ data class SecretStorageKeyContent(
} }
@JsonClass(generateAdapter = true) @JsonClass(generateAdapter = true)
data class SSSSPassphrase( data class SsssPassphrase(
@Json(name = "algorithm") val algorithm: String?, @Json(name = "algorithm") val algorithm: String?,
@Json(name = "iterations") val iterations: Int, @Json(name = "iterations") val iterations: Int,
@Json(name = "salt") val salt: String? @Json(name = "salt") val salt: String?

View file

@ -17,7 +17,6 @@
package im.vector.matrix.android.api.session.securestorage package im.vector.matrix.android.api.session.securestorage
sealed class SharedSecretStorageError(message: String?) : Throwable(message) { sealed class SharedSecretStorageError(message: String?) : Throwable(message) {
data class UnknownSecret(val secretName: String) : SharedSecretStorageError("Unknown Secret $secretName") data class UnknownSecret(val secretName: String) : SharedSecretStorageError("Unknown Secret $secretName")
data class UnknownKey(val keyId: String) : SharedSecretStorageError("Unknown key $keyId") data class UnknownKey(val keyId: String) : SharedSecretStorageError("Unknown key $keyId")
data class UnknownAlgorithm(val keyId: String) : SharedSecretStorageError("Unknown algorithm $keyId") data class UnknownAlgorithm(val keyId: String) : SharedSecretStorageError("Unknown algorithm $keyId")

View file

@ -108,5 +108,5 @@ interface SharedSecretStorageService {
* *
*/ */
@Throws @Throws
fun getSecret(name: String, keyId: String?, secretKey: SSSSKeySpec, callback: MatrixCallback<String>) fun getSecret(name: String, keyId: String?, secretKey: SsssKeySpec, callback: MatrixCallback<String>)
} }

View file

@ -21,11 +21,11 @@ import im.vector.matrix.android.internal.crypto.keysbackup.deriveKey
import im.vector.matrix.android.internal.crypto.keysbackup.util.extractCurveKeyFromRecoveryKey import im.vector.matrix.android.internal.crypto.keysbackup.util.extractCurveKeyFromRecoveryKey
/** Tag class */ /** Tag class */
interface SSSSKeySpec interface SsssKeySpec
data class Curve25519AesSha2KeySpec( data class Curve25519AesSha2KeySpec(
val privateKey: ByteArray val privateKey: ByteArray
) : SSSSKeySpec { ) : SsssKeySpec {
companion object { companion object {

View file

@ -25,24 +25,30 @@ import im.vector.matrix.android.api.session.securestorage.EncryptedSecretContent
import im.vector.matrix.android.api.session.securestorage.KeyInfo import im.vector.matrix.android.api.session.securestorage.KeyInfo
import im.vector.matrix.android.api.session.securestorage.KeyInfoResult import im.vector.matrix.android.api.session.securestorage.KeyInfoResult
import im.vector.matrix.android.api.session.securestorage.KeySigner import im.vector.matrix.android.api.session.securestorage.KeySigner
import im.vector.matrix.android.api.session.securestorage.SsssKeyCreationInfo import im.vector.matrix.android.api.session.securestorage.SsssKeySpec
import im.vector.matrix.android.api.session.securestorage.SSSSKeySpec import im.vector.matrix.android.api.session.securestorage.SsssPassphrase
import im.vector.matrix.android.api.session.securestorage.SSSSPassphrase
import im.vector.matrix.android.api.session.securestorage.SecretStorageKeyContent import im.vector.matrix.android.api.session.securestorage.SecretStorageKeyContent
import im.vector.matrix.android.api.session.securestorage.SharedSecretStorageError import im.vector.matrix.android.api.session.securestorage.SharedSecretStorageError
import im.vector.matrix.android.api.session.securestorage.SharedSecretStorageService import im.vector.matrix.android.api.session.securestorage.SharedSecretStorageService
import im.vector.matrix.android.api.session.securestorage.SsssKeyCreationInfo
import im.vector.matrix.android.internal.crypto.SSSS_ALGORITHM_CURVE25519_AES_SHA2 import im.vector.matrix.android.internal.crypto.SSSS_ALGORITHM_CURVE25519_AES_SHA2
import im.vector.matrix.android.internal.crypto.keysbackup.generatePrivateKeyWithPassword import im.vector.matrix.android.internal.crypto.keysbackup.generatePrivateKeyWithPassword
import im.vector.matrix.android.internal.crypto.keysbackup.util.computeRecoveryKey import im.vector.matrix.android.internal.crypto.keysbackup.util.computeRecoveryKey
import im.vector.matrix.android.internal.crypto.tools.withOlmDecryption
import im.vector.matrix.android.internal.crypto.tools.withOlmEncryption
import im.vector.matrix.android.internal.extensions.foldToCallback import im.vector.matrix.android.internal.extensions.foldToCallback
import im.vector.matrix.android.internal.util.MatrixCoroutineDispatchers import im.vector.matrix.android.internal.util.MatrixCoroutineDispatchers
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import org.matrix.olm.OlmPkDecryption
import org.matrix.olm.OlmPkEncryption
import org.matrix.olm.OlmPkMessage import org.matrix.olm.OlmPkMessage
import javax.inject.Inject import javax.inject.Inject
private data class Key(
val publicKey: String,
@Suppress("ArrayInDataClass")
val privateKey: ByteArray
)
internal class DefaultSharedSecretStorageService @Inject constructor( internal class DefaultSharedSecretStorageService @Inject constructor(
private val accountDataService: AccountDataService, private val accountDataService: AccountDataService,
private val coroutineDispatchers: MatrixCoroutineDispatchers, private val coroutineDispatchers: MatrixCoroutineDispatchers,
@ -54,25 +60,22 @@ internal class DefaultSharedSecretStorageService @Inject constructor(
keySigner: KeySigner, keySigner: KeySigner,
callback: MatrixCallback<SsssKeyCreationInfo>) { callback: MatrixCallback<SsssKeyCreationInfo>) {
cryptoCoroutineScope.launch(coroutineDispatchers.main) { cryptoCoroutineScope.launch(coroutineDispatchers.main) {
val pkDecryption = OlmPkDecryption() val key = try {
val pubKey: String withOlmDecryption { olmPkDecryption ->
val privateKey: ByteArray val pubKey = olmPkDecryption.generateKey()
try { val privateKey = olmPkDecryption.privateKey()
pubKey = pkDecryption.generateKey() Key(pubKey, privateKey)
privateKey = pkDecryption.privateKey()
} catch (failure: Throwable) {
return@launch Unit.also {
callback.onFailure(failure)
} }
} finally { } catch (failure: Throwable) {
pkDecryption.releaseDecryption() callback.onFailure(failure)
return@launch
} }
val storageKeyContent = SecretStorageKeyContent( val storageKeyContent = SecretStorageKeyContent(
name = keyName, name = keyName,
algorithm = SSSS_ALGORITHM_CURVE25519_AES_SHA2, algorithm = SSSS_ALGORITHM_CURVE25519_AES_SHA2,
passphrase = null, passphrase = null,
publicKey = pubKey publicKey = key.publicKey
) )
val signedContent = keySigner.sign(storageKeyContent.canonicalSignable())?.let { val signedContent = keySigner.sign(storageKeyContent.canonicalSignable())?.let {
@ -93,7 +96,7 @@ internal class DefaultSharedSecretStorageService @Inject constructor(
callback.onSuccess(SsssKeyCreationInfo( callback.onSuccess(SsssKeyCreationInfo(
keyId = keyId, keyId = keyId,
content = storageKeyContent, content = storageKeyContent,
recoveryKey = computeRecoveryKey(privateKey) recoveryKey = computeRecoveryKey(key.privateKey)
)) ))
} }
} }
@ -110,21 +113,18 @@ internal class DefaultSharedSecretStorageService @Inject constructor(
cryptoCoroutineScope.launch(coroutineDispatchers.main) { cryptoCoroutineScope.launch(coroutineDispatchers.main) {
val privatePart = generatePrivateKeyWithPassword(passphrase, progressListener) val privatePart = generatePrivateKeyWithPassword(passphrase, progressListener)
val pkDecryption = OlmPkDecryption() val pubKey = try {
val pubKey: String withOlmDecryption { olmPkDecryption ->
try { olmPkDecryption.setPrivateKey(privatePart.privateKey)
pubKey = pkDecryption.setPrivateKey(privatePart.privateKey)
} catch (failure: Throwable) {
return@launch Unit.also {
callback.onFailure(failure)
} }
} finally { } catch (failure: Throwable) {
pkDecryption.releaseDecryption() callback.onFailure(failure)
return@launch
} }
val storageKeyContent = SecretStorageKeyContent( val storageKeyContent = SecretStorageKeyContent(
algorithm = SSSS_ALGORITHM_CURVE25519_AES_SHA2, algorithm = SSSS_ALGORITHM_CURVE25519_AES_SHA2,
passphrase = SSSSPassphrase(algorithm = "m.pbkdf2", iterations = privatePart.iterations, salt = privatePart.salt), passphrase = SsssPassphrase(algorithm = "m.pbkdf2", iterations = privatePart.iterations, salt = privatePart.salt),
publicKey = pubKey publicKey = pubKey
) )
@ -192,21 +192,20 @@ internal class DefaultSharedSecretStorageService @Inject constructor(
cryptoCoroutineScope.launch(coroutineDispatchers.main) { cryptoCoroutineScope.launch(coroutineDispatchers.main) {
val encryptedContents = HashMap<String, EncryptedSecretContent>() val encryptedContents = HashMap<String, EncryptedSecretContent>()
try { try {
if (keys == null || keys.isEmpty()) { if (keys.isNullOrEmpty()) {
// use default key // use default key
val key = getDefaultKey() when (val key = getDefaultKey()) {
when (key) {
is KeyInfoResult.Success -> { is KeyInfoResult.Success -> {
if (key.keyInfo.content.algorithm == SSSS_ALGORITHM_CURVE25519_AES_SHA2) { if (key.keyInfo.content.algorithm == SSSS_ALGORITHM_CURVE25519_AES_SHA2) {
withOlmEncryption { olmEncrypt -> val encryptedResult = withOlmEncryption { olmEncrypt ->
olmEncrypt.setRecipientKey(key.keyInfo.content.publicKey) olmEncrypt.setRecipientKey(key.keyInfo.content.publicKey)
val encryptedResult = olmEncrypt.encrypt(secretBase64) olmEncrypt.encrypt(secretBase64)
encryptedContents[key.keyInfo.id] = EncryptedSecretContent(
ciphertext = encryptedResult.mCipherText,
ephemeral = encryptedResult.mEphemeralKey,
mac = encryptedResult.mMac
)
} }
encryptedContents[key.keyInfo.id] = EncryptedSecretContent(
ciphertext = encryptedResult.mCipherText,
ephemeral = encryptedResult.mEphemeralKey,
mac = encryptedResult.mMac
)
} else { } else {
// Unknown algorithm // Unknown algorithm
callback.onFailure(SharedSecretStorageError.UnknownAlgorithm(key.keyInfo.content.algorithm ?: "")) callback.onFailure(SharedSecretStorageError.UnknownAlgorithm(key.keyInfo.content.algorithm ?: ""))
@ -222,19 +221,18 @@ internal class DefaultSharedSecretStorageService @Inject constructor(
keys.forEach { keys.forEach {
val keyId = it val keyId = it
// encrypt the content // encrypt the content
val key = getKey(keyId) when (val key = getKey(keyId)) {
when (key) {
is KeyInfoResult.Success -> { is KeyInfoResult.Success -> {
if (key.keyInfo.content.algorithm == SSSS_ALGORITHM_CURVE25519_AES_SHA2) { if (key.keyInfo.content.algorithm == SSSS_ALGORITHM_CURVE25519_AES_SHA2) {
withOlmEncryption { olmEncrypt -> val encryptedResult = withOlmEncryption { olmEncrypt ->
olmEncrypt.setRecipientKey(key.keyInfo.content.publicKey) olmEncrypt.setRecipientKey(key.keyInfo.content.publicKey)
val encryptedResult = olmEncrypt.encrypt(secretBase64) olmEncrypt.encrypt(secretBase64)
encryptedContents[keyId] = EncryptedSecretContent(
ciphertext = encryptedResult.mCipherText,
ephemeral = encryptedResult.mEphemeralKey,
mac = encryptedResult.mMac
)
} }
encryptedContents[keyId] = EncryptedSecretContent(
ciphertext = encryptedResult.mCipherText,
ephemeral = encryptedResult.mEphemeralKey,
mac = encryptedResult.mMac
)
} else { } else {
// Unknown algorithm // Unknown algorithm
callback.onFailure(SharedSecretStorageError.UnknownAlgorithm(key.keyInfo.content.algorithm ?: "")) callback.onFailure(SharedSecretStorageError.UnknownAlgorithm(key.keyInfo.content.algorithm ?: ""))
@ -279,7 +277,7 @@ internal class DefaultSharedSecretStorageService @Inject constructor(
return results return results
} }
override fun getSecret(name: String, keyId: String?, secretKey: SSSSKeySpec, callback: MatrixCallback<String>) { override fun getSecret(name: String, keyId: String?, secretKey: SsssKeySpec, callback: MatrixCallback<String>) {
val accountData = accountDataService.getAccountDataEvent(name) ?: return Unit.also { val accountData = accountDataService.getAccountDataEvent(name) ?: return Unit.also {
callback.onFailure(SharedSecretStorageError.UnknownSecret(name)) callback.onFailure(SharedSecretStorageError.UnknownSecret(name))
} }
@ -306,20 +304,16 @@ internal class DefaultSharedSecretStorageService @Inject constructor(
} }
cryptoCoroutineScope.launch(coroutineDispatchers.main) { cryptoCoroutineScope.launch(coroutineDispatchers.main) {
kotlin.runCatching { kotlin.runCatching {
// decryt from recovery key // decrypt from recovery key
val keyBytes = keySpec.privateKey withOlmDecryption { olmPkDecryption ->
val decryption = OlmPkDecryption() olmPkDecryption.setPrivateKey(keySpec.privateKey)
try { olmPkDecryption.decrypt(OlmPkMessage()
decryption.setPrivateKey(keyBytes) .apply {
decryption.decrypt(OlmPkMessage().apply { mCipherText = secretContent.ciphertext
mCipherText = secretContent.ciphertext mEphemeralKey = secretContent.ephemeral
mEphemeralKey = secretContent.ephemeral mMac = secretContent.mac
mMac = secretContent.mac }
}) )
} catch (failure: Throwable) {
throw failure
} finally {
decryption.releaseDecryption()
} }
}.foldToCallback(callback) }.foldToCallback(callback)
} }
@ -332,27 +326,5 @@ internal class DefaultSharedSecretStorageService @Inject constructor(
const val KEY_ID_BASE = "m.secret_storage.key" const val KEY_ID_BASE = "m.secret_storage.key"
const val ENCRYPTED = "encrypted" const val ENCRYPTED = "encrypted"
const val DEFAULT_KEY_ID = "m.secret_storage.default_key" const val DEFAULT_KEY_ID = "m.secret_storage.default_key"
fun withOlmEncryption(block: (OlmPkEncryption) -> Unit) {
val olmPkEncryption = OlmPkEncryption()
try {
block(olmPkEncryption)
} catch (failure: Throwable) {
throw failure
} finally {
olmPkEncryption.releaseEncryption()
}
}
fun withOlmDecryption(block: (OlmPkDecryption) -> Unit) {
val olmPkDecryption = OlmPkDecryption()
try {
block(olmPkDecryption)
} catch (failure: Throwable) {
throw failure
} finally {
olmPkDecryption.releaseDecryption()
}
}
} }
} }

View file

@ -0,0 +1,38 @@
/*
* Copyright (c) 2020 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package im.vector.matrix.android.internal.crypto.tools
import org.matrix.olm.OlmPkDecryption
import org.matrix.olm.OlmPkEncryption
fun <T> withOlmEncryption(block: (OlmPkEncryption) -> T): T {
val olmPkEncryption = OlmPkEncryption()
try {
return block(olmPkEncryption)
} finally {
olmPkEncryption.releaseEncryption()
}
}
fun <T> withOlmDecryption(block: (OlmPkDecryption) -> T): T {
val olmPkDecryption = OlmPkDecryption()
try {
return block(olmPkDecryption)
} finally {
olmPkDecryption.releaseDecryption()
}
}

View file

@ -269,7 +269,7 @@ internal abstract class SessionModule {
abstract fun bindHomeServerCapabilitiesService(homeServerCapabilitiesService: DefaultHomeServerCapabilitiesService): HomeServerCapabilitiesService abstract fun bindHomeServerCapabilitiesService(homeServerCapabilitiesService: DefaultHomeServerCapabilitiesService): HomeServerCapabilitiesService
@Binds @Binds
abstract fun bindAccountDataService(accountDataService: DefaultAccountDataService): AccountDataService abstract fun bindAccountDataService(service: DefaultAccountDataService): AccountDataService
@Binds @Binds
abstract fun bindSharedSecretStorageService(service: DefaultSharedSecretStorageService): SharedSecretStorageService abstract fun bindSharedSecretStorageService(service: DefaultSharedSecretStorageService): SharedSecretStorageService

View file

@ -17,8 +17,9 @@
package im.vector.matrix.android.internal.session.sync.model.accountdata package im.vector.matrix.android.internal.session.sync.model.accountdata
import com.squareup.moshi.Json import com.squareup.moshi.Json
import im.vector.matrix.android.internal.session.user.accountdata.AccountDataContent
abstract class UserAccountData { abstract class UserAccountData : AccountDataContent {
@Json(name = "type") abstract val type: String @Json(name = "type") abstract val type: String

View file

@ -0,0 +1,22 @@
/*
* Copyright (c) 2020 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package im.vector.matrix.android.internal.session.user.accountdata
/**
* Tag class to identify every account data content
*/
internal interface AccountDataContent

View file

@ -22,13 +22,13 @@ import com.zhuinden.monarchy.Monarchy
import im.vector.matrix.android.api.MatrixCallback import im.vector.matrix.android.api.MatrixCallback
import im.vector.matrix.android.api.session.accountdata.AccountDataService import im.vector.matrix.android.api.session.accountdata.AccountDataService
import im.vector.matrix.android.api.session.events.model.Content import im.vector.matrix.android.api.session.events.model.Content
import im.vector.matrix.android.api.util.Cancelable
import im.vector.matrix.android.api.util.JSON_DICT_PARAMETERIZED_TYPE import im.vector.matrix.android.api.util.JSON_DICT_PARAMETERIZED_TYPE
import im.vector.matrix.android.api.util.Optional import im.vector.matrix.android.api.util.Optional
import im.vector.matrix.android.api.util.toOptional import im.vector.matrix.android.api.util.toOptional
import im.vector.matrix.android.internal.database.model.UserAccountDataEntity import im.vector.matrix.android.internal.database.model.UserAccountDataEntity
import im.vector.matrix.android.internal.database.model.UserAccountDataEntityFields import im.vector.matrix.android.internal.database.model.UserAccountDataEntityFields
import im.vector.matrix.android.internal.di.MoshiProvider import im.vector.matrix.android.internal.di.MoshiProvider
import im.vector.matrix.android.internal.di.SessionId
import im.vector.matrix.android.internal.session.sync.UserAccountDataSyncHandler import im.vector.matrix.android.internal.session.sync.UserAccountDataSyncHandler
import im.vector.matrix.android.internal.session.sync.model.accountdata.UserAccountDataEvent import im.vector.matrix.android.internal.session.sync.model.accountdata.UserAccountDataEvent
import im.vector.matrix.android.internal.task.TaskExecutor import im.vector.matrix.android.internal.task.TaskExecutor
@ -37,7 +37,6 @@ import javax.inject.Inject
internal class DefaultAccountDataService @Inject constructor( internal class DefaultAccountDataService @Inject constructor(
private val monarchy: Monarchy, private val monarchy: Monarchy,
@SessionId private val sessionId: String,
private val updateUserAccountDataTask: UpdateUserAccountDataTask, private val updateUserAccountDataTask: UpdateUserAccountDataTask,
private val userAccountDataSyncHandler: UserAccountDataSyncHandler, private val userAccountDataSyncHandler: UserAccountDataSyncHandler,
private val taskExecutor: TaskExecutor private val taskExecutor: TaskExecutor
@ -47,39 +46,39 @@ internal class DefaultAccountDataService @Inject constructor(
private val adapter = moshi.adapter<Map<String, Any>>(JSON_DICT_PARAMETERIZED_TYPE) private val adapter = moshi.adapter<Map<String, Any>>(JSON_DICT_PARAMETERIZED_TYPE)
override fun getAccountDataEvent(type: String): UserAccountDataEvent? { override fun getAccountDataEvent(type: String): UserAccountDataEvent? {
return getAccountDataEvents(listOf(type)).firstOrNull() return getAccountDataEvents(setOf(type)).firstOrNull()
} }
override fun getLiveAccountDataEvent(type: String): LiveData<Optional<UserAccountDataEvent>> { override fun getLiveAccountDataEvent(type: String): LiveData<Optional<UserAccountDataEvent>> {
return Transformations.map(getLiveAccountDataEvents(listOf(type))) { return Transformations.map(getLiveAccountDataEvents(setOf(type))) {
it.firstOrNull()?.toOptional() it.firstOrNull()?.toOptional()
} }
} }
override fun getAccountDataEvents(filterType: List<String>): List<UserAccountDataEvent> { override fun getAccountDataEvents(types: Set<String>): List<UserAccountDataEvent> {
return monarchy.fetchAllCopiedSync { realm -> return monarchy.fetchAllCopiedSync { realm ->
realm.where(UserAccountDataEntity::class.java) realm.where(UserAccountDataEntity::class.java)
.apply { .apply {
if (filterType.isNotEmpty()) { if (types.isNotEmpty()) {
`in`(UserAccountDataEntityFields.TYPE, filterType.toTypedArray()) `in`(UserAccountDataEntityFields.TYPE, types.toTypedArray())
} }
} }
}?.mapNotNull { entity -> }.mapNotNull { entity ->
entity.type?.let { type -> entity.type?.let { type ->
UserAccountDataEvent( UserAccountDataEvent(
type = type, type = type,
content = entity.contentStr?.let { adapter.fromJson(it) } ?: emptyMap() content = entity.contentStr?.let { adapter.fromJson(it) } ?: emptyMap()
) )
} }
} ?: emptyList() }
} }
override fun getLiveAccountDataEvents(filterType: List<String>): LiveData<List<UserAccountDataEvent>> { override fun getLiveAccountDataEvents(types: Set<String>): LiveData<List<UserAccountDataEvent>> {
return monarchy.findAllMappedWithChanges({ realm -> return monarchy.findAllMappedWithChanges({ realm ->
realm.where(UserAccountDataEntity::class.java) realm.where(UserAccountDataEntity::class.java)
.apply { .apply {
if (filterType.isNotEmpty()) { if (types.isNotEmpty()) {
`in`(UserAccountDataEntityFields.TYPE, filterType.toTypedArray()) `in`(UserAccountDataEntityFields.TYPE, types.toTypedArray())
} }
} }
}, { entity -> }, { entity ->
@ -90,14 +89,15 @@ internal class DefaultAccountDataService @Inject constructor(
}) })
} }
override fun updateAccountData(type: String, content: Content, callback: MatrixCallback<Unit>?) { override fun updateAccountData(type: String, content: Content, callback: MatrixCallback<Unit>?): Cancelable {
updateUserAccountDataTask.configureWith(UpdateUserAccountDataTask.AnyParams( return updateUserAccountDataTask.configureWith(UpdateUserAccountDataTask.AnyParams(
type = type, type = type,
any = content any = content
)) { )) {
this.retryCount = 5 this.retryCount = 5
this.callback = object : MatrixCallback<Unit> { this.callback = object : MatrixCallback<Unit> {
override fun onSuccess(data: Unit) { override fun onSuccess(data: Unit) {
// TODO Move that to the task (but it created a circular dependencies...)
monarchy.runTransactionSync { realm -> monarchy.runTransactionSync { realm ->
userAccountDataSyncHandler.handleGenericAccountData(realm, type, content) userAccountDataSyncHandler.handleGenericAccountData(realm, type, content)
} }

View file

@ -40,7 +40,7 @@ class AccountDataViewModel @AssistedInject constructor(@Assisted initialState: A
: VectorViewModel<AccountDataViewState, EmptyAction, EmptyViewEvents>(initialState) { : VectorViewModel<AccountDataViewState, EmptyAction, EmptyViewEvents>(initialState) {
init { init {
session.rx().liveAccountData(emptyList()) session.rx().liveAccountData(emptySet())
.execute { .execute {
copy(accountData = it) copy(accountData = it)
} }