diff --git a/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/util/UserStateJsonExtensions.kt b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/util/UserStateJsonExtensions.kt new file mode 100644 index 000000000..f3b1f5fc1 --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/util/UserStateJsonExtensions.kt @@ -0,0 +1,33 @@ +package com.x8bit.bitwarden.data.auth.repository.util + +import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson +import com.x8bit.bitwarden.data.vault.datasource.network.model.SyncResponseJson + +/** + * Updates the given [UserStateJson] with the data from the [syncResponse] to return a new + * [UserStateJson]. The original will be returned if the sync response does not match any accounts + * in the [UserStateJson]. + */ +@Suppress("ReturnCount") +fun UserStateJson.toUpdatedUserStateJson( + syncResponse: SyncResponseJson, +): UserStateJson { + val userId = syncResponse.profile?.id ?: return this + val account = this.accounts[userId] ?: return this + val profile = account.profile + // TODO: Update additional missing UserStateJson properties (BIT-916) + val updatedProfile = profile + .copy( + avatarColorHex = syncResponse.profile.avatarColor, + stamp = syncResponse.profile.securityStamp, + ) + val updatedAccount = account.copy(profile = updatedProfile) + return this + .copy( + accounts = accounts + .toMutableMap() + .apply { + replace(userId, updatedAccount) + }, + ) +} diff --git a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt index 34cca3abd..b5003c2f5 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt @@ -5,6 +5,7 @@ import com.bitwarden.core.FolderView import com.bitwarden.core.InitCryptoRequest import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource import com.x8bit.bitwarden.data.auth.repository.util.toSdkParams +import com.x8bit.bitwarden.data.auth.repository.util.toUpdatedUserStateJson import com.x8bit.bitwarden.data.platform.datasource.network.util.isNoConnectionError import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager import com.x8bit.bitwarden.data.platform.repository.model.DataState @@ -85,6 +86,13 @@ class VaultRepositoryImpl constructor( .sync() .fold( onSuccess = { syncResponse -> + // Update user information with additional information from sync response + authDiskSource.userState = authDiskSource + .userState + ?.toUpdatedUserStateJson( + syncResponse = syncResponse, + ) + storeUserKeyAndPrivateKey( userKey = syncResponse.profile?.key, privateKey = syncResponse.profile?.privateKey, diff --git a/app/src/test/java/com/x8bit/bitwarden/data/auth/datasource/disk/util/FakeAuthDiskSource.kt b/app/src/test/java/com/x8bit/bitwarden/data/auth/datasource/disk/util/FakeAuthDiskSource.kt index 7058bd565..4cc592ee0 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/auth/datasource/disk/util/FakeAuthDiskSource.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/auth/datasource/disk/util/FakeAuthDiskSource.kt @@ -41,6 +41,13 @@ class FakeAuthDiskSource : AuthDiskSource { private val storedPrivateKeys = mutableMapOf() + /** + * Assert that the given [userState] matches the currently tracked value. + */ + fun assertUserState(userState: UserStateJson) { + assertEquals(userState, this.userState) + } + /** * Assert that the [userKey] was stored successfully using the [userId]. */ diff --git a/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/util/UserStateJsonExtensionsTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/util/UserStateJsonExtensionsTest.kt new file mode 100644 index 000000000..938d803fc --- /dev/null +++ b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/util/UserStateJsonExtensionsTest.kt @@ -0,0 +1,85 @@ +package com.x8bit.bitwarden.data.auth.repository.util + +import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson +import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson +import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson +import io.mockk.every +import io.mockk.mockk +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +class UserStateJsonExtensionsTest { + @Test + fun `toUpdatedUserStateJson should do nothing for a non-matching account`() { + val originalUserState = UserStateJson( + activeUserId = "activeUserId", + accounts = mapOf( + "activeUserId" to mockk(), + ), + ) + assertEquals( + originalUserState, + originalUserState + .toUpdatedUserStateJson( + syncResponse = mockk { + every { profile } returns mockk { + every { id } returns "otherUserId" + } + }, + ), + ) + } + + @Test + fun `toUpdatedUserStateJson should update the correct account with new information`() { + val originalProfile = AccountJson.Profile( + userId = "activeUserId", + email = "email", + isEmailVerified = true, + name = "name", + stamp = null, + organizationId = null, + avatarColorHex = null, + hasPremium = true, + forcePasswordResetReason = null, + kdfType = KdfTypeJson.ARGON2_ID, + kdfIterations = 600000, + kdfMemory = 16, + kdfParallelism = 4, + userDecryptionOptions = null, + ) + val originalAccount = AccountJson( + profile = originalProfile, + tokens = mockk(), + settings = mockk(), + ) + assertEquals( + UserStateJson( + activeUserId = "activeUserId", + accounts = mapOf( + "activeUserId" to originalAccount.copy( + profile = originalProfile.copy( + avatarColorHex = "avatarColor", + stamp = "securityStamp", + ), + ), + ), + ), + UserStateJson( + activeUserId = "activeUserId", + accounts = mapOf( + "activeUserId" to originalAccount, + ), + ) + .toUpdatedUserStateJson( + syncResponse = mockk { + every { profile } returns mockk { + every { id } returns "activeUserId" + every { avatarColor } returns "avatarColor" + every { securityStamp } returns "securityStamp" + } + }, + ), + ) + } +} diff --git a/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt index 5c75ac430..4d86cf3ec 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt @@ -74,12 +74,26 @@ class VaultRepositoryTest { vaultRepository.sync() + val updatedUserState = MOCK_USER_STATE + .copy( + accounts = mapOf( + "mockId-1" to MOCK_ACCOUNT.copy( + profile = MOCK_PROFILE.copy( + avatarColorHex = "mockAvatarColor-1", + stamp = "mockSecurityStamp-1", + ), + ), + ), + ) + fakeAuthDiskSource.assertUserState( + userState = updatedUserState, + ) fakeAuthDiskSource.assertUserKey( - userId = "mockUserId", + userId = "mockId-1", userKey = "mockKey-1", ) fakeAuthDiskSource.assertPrivateKey( - userId = "mockUserId", + userId = "mockId-1", privateKey = "mockPrivateKey-1", ) assertEquals( @@ -462,11 +476,11 @@ class VaultRepositoryTest { vaultSdkSource.decryptSendList(listOf(createMockSdkSend(number = 1))) } returns listOf(createMockSendView(number = 1)).asSuccess() fakeAuthDiskSource.storePrivateKey( - userId = "mockUserId", + userId = "mockId-1", privateKey = "mockPrivateKey-1", ) fakeAuthDiskSource.storeUserKey( - userId = "mockUserId", + userId = "mockId-1", userKey = "mockKey-1", ) fakeAuthDiskSource.userState = MOCK_USER_STATE @@ -507,11 +521,11 @@ class VaultRepositoryTest { vaultSdkSource.decryptSendList(listOf(createMockSdkSend(number = 1))) } returns listOf(createMockSendView(number = 1)).asSuccess() fakeAuthDiskSource.storePrivateKey( - userId = "mockUserId", + userId = "mockId-1", privateKey = "mockPrivateKey-1", ) fakeAuthDiskSource.storeUserKey( - userId = "mockUserId", + userId = "mockId-1", userKey = "mockKey-1", ) fakeAuthDiskSource.userState = MOCK_USER_STATE @@ -551,11 +565,11 @@ class VaultRepositoryTest { vaultSdkSource.decryptFolderList(listOf(createMockSdkFolder(1))) } returns listOf(createMockFolderView(number = 1)).asSuccess() fakeAuthDiskSource.storePrivateKey( - userId = "mockUserId", + userId = "mockId-1", privateKey = "mockPrivateKey-1", ) fakeAuthDiskSource.storeUserKey( - userId = "mockUserId", + userId = "mockId-1", userKey = "mockKey-1", ) fakeAuthDiskSource.userState = MOCK_USER_STATE @@ -598,11 +612,11 @@ class VaultRepositoryTest { vaultSdkSource.decryptFolderList(listOf(createMockSdkFolder(1))) } returns mockk() fakeAuthDiskSource.storePrivateKey( - userId = "mockUserId", + userId = "mockId-1", privateKey = "mockPrivateKey-1", ) fakeAuthDiskSource.storeUserKey( - userId = "mockUserId", + userId = "mockId-1", userKey = "mockKey-1", ) fakeAuthDiskSource.userState = MOCK_USER_STATE @@ -639,11 +653,11 @@ class VaultRepositoryTest { vaultSdkSource.decryptFolderList(listOf(createMockSdkFolder(1))) } returns mockk() fakeAuthDiskSource.storePrivateKey( - userId = "mockUserId", + userId = "mockId-1", privateKey = "mockPrivateKey-1", ) fakeAuthDiskSource.storeUserKey( - userId = "mockUserId", + userId = "mockId-1", userKey = "mockKey-1", ) fakeAuthDiskSource.userState = MOCK_USER_STATE @@ -685,11 +699,11 @@ class VaultRepositoryTest { runTest { val result = vaultRepository.unlockVaultAndSync(masterPassword = "") fakeAuthDiskSource.storeUserKey( - userId = "mockUserId", + userId = "mockId-1", userKey = null, ) fakeAuthDiskSource.storePrivateKey( - userId = "mockUserId", + userId = "mockId-1", privateKey = "mockPrivateKey-1", ) fakeAuthDiskSource.userState = MOCK_USER_STATE @@ -704,11 +718,11 @@ class VaultRepositoryTest { runTest { val result = vaultRepository.unlockVaultAndSync(masterPassword = "") fakeAuthDiskSource.storeUserKey( - userId = "mockUserId", + userId = "mockId-1", userKey = "mockKey-1", ) fakeAuthDiskSource.storePrivateKey( - userId = "mockUserId", + userId = "mockId-1", privateKey = null, ) fakeAuthDiskSource.userState = MOCK_USER_STATE @@ -1014,33 +1028,37 @@ class VaultRepositoryTest { } } -private val MOCK_USER_STATE = UserStateJson( - activeUserId = "mockUserId", - accounts = mapOf( - "mockUserId" to AccountJson( - profile = AccountJson.Profile( - userId = "activeUserId", - email = "email", - isEmailVerified = true, - name = null, - stamp = null, - organizationId = null, - avatarColorHex = null, - hasPremium = true, - forcePasswordResetReason = null, - kdfType = null, - kdfIterations = null, - kdfMemory = null, - kdfParallelism = null, - userDecryptionOptions = null, - ), - tokens = AccountJson.Tokens( - accessToken = "accessToken", - refreshToken = "refreshToken", - ), - settings = AccountJson.Settings( - environmentUrlData = null, - ), - ), +private val MOCK_PROFILE = AccountJson.Profile( + userId = "mockId-1", + email = "email", + isEmailVerified = true, + name = null, + stamp = null, + organizationId = null, + avatarColorHex = null, + hasPremium = true, + forcePasswordResetReason = null, + kdfType = null, + kdfIterations = null, + kdfMemory = null, + kdfParallelism = null, + userDecryptionOptions = null, +) + +private val MOCK_ACCOUNT = AccountJson( + profile = MOCK_PROFILE, + tokens = AccountJson.Tokens( + accessToken = "accessToken", + refreshToken = "refreshToken", + ), + settings = AccountJson.Settings( + environmentUrlData = null, + ), +) + +private val MOCK_USER_STATE = UserStateJson( + activeUserId = "mockId-1", + accounts = mapOf( + "mockId-1" to MOCK_ACCOUNT, ), )