From c00c8ac022b4bb25d8ec10fec093180b44a27dc1 Mon Sep 17 00:00:00 2001 From: David Perez Date: Mon, 20 Nov 2023 11:34:05 -0600 Subject: [PATCH] Add vault repo methods for getting vault items by ID (#256) --- .../repository/util/DataStateExtensions.kt | 16 ++ .../data/vault/repository/VaultRepository.kt | 14 ++ .../vault/repository/VaultRepositoryImpl.kt | 36 +++ .../vault/repository/VaultRepositoryTest.kt | 214 ++++++++++++++++++ 4 files changed, 280 insertions(+) create mode 100644 app/src/main/java/com/x8bit/bitwarden/data/platform/repository/util/DataStateExtensions.kt diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/repository/util/DataStateExtensions.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/repository/util/DataStateExtensions.kt new file mode 100644 index 000000000..14133ef51 --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/repository/util/DataStateExtensions.kt @@ -0,0 +1,16 @@ +package com.x8bit.bitwarden.data.platform.repository.util + +import com.x8bit.bitwarden.data.platform.repository.model.DataState + +/** + * Maps the data inside a [DataState] with the given [transform]. + */ +inline fun DataState.map( + transform: (T) -> R, +): DataState = when (this) { + is DataState.Loaded -> DataState.Loaded(transform(data)) + is DataState.Loading -> DataState.Loading + is DataState.Pending -> DataState.Pending(transform(data)) + is DataState.Error -> DataState.Error(error, data?.let(transform)) + is DataState.NoNetwork -> DataState.NoNetwork(data?.let(transform)) +} diff --git a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepository.kt b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepository.kt index 70ec45b8b..830dd0edf 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepository.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepository.kt @@ -1,5 +1,7 @@ package com.x8bit.bitwarden.data.vault.repository +import com.bitwarden.core.CipherView +import com.bitwarden.core.FolderView import com.x8bit.bitwarden.data.platform.repository.model.DataState import com.x8bit.bitwarden.data.vault.repository.model.SendData import com.x8bit.bitwarden.data.vault.repository.model.VaultData @@ -31,6 +33,18 @@ interface VaultRepository { */ fun sync() + /** + * Flow that represents the data for a specific vault item as found by ID. This may emit `null` + * if the item cannot be found. + */ + fun getVaultItemStateFlow(itemId: String): StateFlow> + + /** + * Flow that represents the data for a specific vault folder as found by ID. This may emit + * `null` if the folder cannot be found. + */ + fun getVaultFolderStateFlow(folderId: String): StateFlow> + /** * Attempt to initialize crypto and sync the vault data. */ diff --git a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt index 36bff96fe..34cca3abd 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt @@ -1,11 +1,14 @@ package com.x8bit.bitwarden.data.vault.repository +import com.bitwarden.core.CipherView +import com.bitwarden.core.FolderView import com.bitwarden.core.InitCryptoRequest import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource import com.x8bit.bitwarden.data.auth.repository.util.toSdkParams import com.x8bit.bitwarden.data.platform.datasource.network.util.isNoConnectionError import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager import com.x8bit.bitwarden.data.platform.repository.model.DataState +import com.x8bit.bitwarden.data.platform.repository.util.map import com.x8bit.bitwarden.data.platform.util.flatMap import com.x8bit.bitwarden.data.vault.datasource.network.model.SyncResponseJson import com.x8bit.bitwarden.data.vault.datasource.network.service.SyncService @@ -20,12 +23,15 @@ import com.x8bit.bitwarden.data.vault.repository.util.toVaultUnlockResult import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Job import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.onCompletion import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.flow.stateIn import kotlinx.coroutines.flow.update import kotlinx.coroutines.launch @@ -102,6 +108,36 @@ class VaultRepositoryImpl constructor( } } + override fun getVaultItemStateFlow(itemId: String): StateFlow> = + vaultDataStateFlow + .map { dataState -> + dataState.map { vaultData -> + vaultData + .cipherViewList + .find { it.id == itemId } + } + } + .stateIn( + scope = scope, + started = SharingStarted.Lazily, + initialValue = DataState.Loading, + ) + + override fun getVaultFolderStateFlow(folderId: String): StateFlow> = + vaultDataStateFlow + .map { dataState -> + dataState.map { vaultData -> + vaultData + .folderViewList + .find { it.id == folderId } + } + } + .stateIn( + scope = scope, + started = SharingStarted.Lazily, + initialValue = DataState.Loading, + ) + override suspend fun unlockVaultAndSync(masterPassword: String): VaultUnlockResult { return flow { willSyncAfterUnlock = true diff --git a/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt index f0e0d631a..5c75ac430 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt @@ -1,6 +1,8 @@ package com.x8bit.bitwarden.data.vault.repository import app.cash.turbine.test +import com.bitwarden.core.CipherView +import com.bitwarden.core.FolderView import com.bitwarden.core.InitCryptoRequest import com.bitwarden.core.Kdf import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson @@ -798,6 +800,218 @@ class VaultRepositoryTest { ) } } + + @Test + fun `getVaultItemStateFlow should receive updates whenever a sync is called`() = runTest { + val itemId = 1234 + val itemIdString = "mockId-$itemId" + val item = createMockCipherView(itemId) + coEvery { + syncService.sync() + } returns Result.success(createMockSyncResponse(itemId)) + coEvery { + vaultSdkSource.decryptCipherList(listOf(createMockSdkCipher(itemId))) + } returns listOf(item).asSuccess() + coEvery { + vaultSdkSource.decryptFolderList(listOf(createMockSdkFolder(itemId))) + } returns listOf(createMockFolderView(1)).asSuccess() + coEvery { + vaultSdkSource.decryptSendList(listOf(createMockSdkSend(itemId))) + } returns listOf(createMockSendView(itemId)).asSuccess() + + vaultRepository.getVaultItemStateFlow(itemIdString).test { + assertEquals(DataState.Loading, awaitItem()) + vaultRepository.sync() + assertEquals(DataState.Loaded(item), awaitItem()) + vaultRepository.sync() + assertEquals(DataState.Pending(item), awaitItem()) + assertEquals(DataState.Loaded(item), awaitItem()) + } + + coVerify(exactly = 2) { + syncService.sync() + vaultSdkSource.decryptCipherList(listOf(createMockSdkCipher(itemId))) + vaultSdkSource.decryptFolderList(listOf(createMockSdkFolder(itemId))) + vaultSdkSource.decryptSendList(listOf(createMockSdkSend(itemId))) + } + } + + @Test + fun `getVaultItemStateFlow should update to Error when a sync fails generically`() = runTest { + val folderId = 1234 + val folderIdString = "mockId-$folderId" + val throwable = Throwable("Fail") + coEvery { + syncService.sync() + } returns throwable.asFailure() + + vaultRepository.getVaultItemStateFlow(folderIdString).test { + assertEquals(DataState.Loading, awaitItem()) + vaultRepository.sync() + assertEquals(DataState.Error(throwable), awaitItem()) + } + + coVerify(exactly = 1) { + syncService.sync() + } + } + + @Test + fun `getVaultItemStateFlow should update to NoNetwork when a sync fails from no network`() = + runTest { + val itemId = 1234 + val itemIdString = "mockId-$itemId" + coEvery { + syncService.sync() + } returns UnknownHostException().asFailure() + + vaultRepository.getVaultItemStateFlow(itemIdString).test { + assertEquals(DataState.Loading, awaitItem()) + vaultRepository.sync() + assertEquals(DataState.NoNetwork(), awaitItem()) + } + + coVerify(exactly = 1) { + syncService.sync() + } + } + + @Test + fun `getVaultItemStateFlow should update to Loaded with null when a item cannot be found`() = + runTest { + val itemIdString = "mockId-1234" + coEvery { + syncService.sync() + } returns Result.success(createMockSyncResponse(1)) + coEvery { + vaultSdkSource.decryptCipherList(listOf(createMockSdkCipher(1))) + } returns listOf(createMockCipherView(1)).asSuccess() + coEvery { + vaultSdkSource.decryptFolderList(listOf(createMockSdkFolder(1))) + } returns listOf(createMockFolderView(1)).asSuccess() + coEvery { + vaultSdkSource.decryptSendList(listOf(createMockSdkSend(1))) + } returns listOf(createMockSendView(1)).asSuccess() + + vaultRepository.getVaultItemStateFlow(itemIdString).test { + assertEquals(DataState.Loading, awaitItem()) + vaultRepository.sync() + assertEquals(DataState.Loaded(null), awaitItem()) + } + + coVerify(exactly = 1) { + syncService.sync() + vaultSdkSource.decryptCipherList(listOf(createMockSdkCipher(1))) + vaultSdkSource.decryptFolderList(listOf(createMockSdkFolder(1))) + vaultSdkSource.decryptSendList(listOf(createMockSdkSend(1))) + } + } + + @Test + fun `getVaultFolderStateFlow should receive updates whenever a sync is called`() = runTest { + val folderId = 1234 + val folderIdString = "mockId-$folderId" + val folder = createMockFolderView(folderId) + coEvery { + syncService.sync() + } returns Result.success(createMockSyncResponse(folderId)) + coEvery { + vaultSdkSource.decryptCipherList(listOf(createMockSdkCipher(folderId))) + } returns listOf(createMockCipherView(folderId)).asSuccess() + coEvery { + vaultSdkSource.decryptFolderList(listOf(createMockSdkFolder(folderId))) + } returns listOf(createMockFolderView(folderId)).asSuccess() + coEvery { + vaultSdkSource.decryptSendList(listOf(createMockSdkSend(folderId))) + } returns listOf(createMockSendView(folderId)).asSuccess() + + vaultRepository.getVaultFolderStateFlow(folderIdString).test { + assertEquals(DataState.Loading, awaitItem()) + vaultRepository.sync() + assertEquals(DataState.Loaded(folder), awaitItem()) + vaultRepository.sync() + assertEquals(DataState.Pending(folder), awaitItem()) + assertEquals(DataState.Loaded(folder), awaitItem()) + } + + coVerify(exactly = 2) { + syncService.sync() + vaultSdkSource.decryptCipherList(listOf(createMockSdkCipher(folderId))) + vaultSdkSource.decryptFolderList(listOf(createMockSdkFolder(folderId))) + vaultSdkSource.decryptSendList(listOf(createMockSdkSend(folderId))) + } + } + + @Test + fun `getVaultFolderStateFlow should update to NoNetwork when a sync fails from no network`() = + runTest { + val folderId = 1234 + val folderIdString = "mockId-$folderId" + coEvery { + syncService.sync() + } returns UnknownHostException().asFailure() + + vaultRepository.getVaultFolderStateFlow(folderIdString).test { + assertEquals(DataState.Loading, awaitItem()) + vaultRepository.sync() + assertEquals(DataState.NoNetwork(), awaitItem()) + } + + coVerify(exactly = 1) { + syncService.sync() + } + } + + @Test + fun `getVaultFolderStateFlow should update to Error when a sync fails generically`() = runTest { + val folderId = 1234 + val folderIdString = "mockId-$folderId" + val throwable = Throwable("Fail") + coEvery { + syncService.sync() + } returns throwable.asFailure() + + vaultRepository.getVaultFolderStateFlow(folderIdString).test { + assertEquals(DataState.Loading, awaitItem()) + vaultRepository.sync() + assertEquals(DataState.Error(throwable), awaitItem()) + } + + coVerify(exactly = 1) { + syncService.sync() + } + } + + @Test + fun `getVaultFolderStateFlow should update to Loaded with null when a item cannot be found`() = + runTest { + val folderIdString = "mockId-1234" + coEvery { + syncService.sync() + } returns Result.success(createMockSyncResponse(1)) + coEvery { + vaultSdkSource.decryptCipherList(listOf(createMockSdkCipher(1))) + } returns listOf(createMockCipherView(1)).asSuccess() + coEvery { + vaultSdkSource.decryptFolderList(listOf(createMockSdkFolder(1))) + } returns listOf(createMockFolderView(1)).asSuccess() + coEvery { + vaultSdkSource.decryptSendList(listOf(createMockSdkSend(1))) + } returns listOf(createMockSendView(1)).asSuccess() + + vaultRepository.getVaultFolderStateFlow(folderIdString).test { + assertEquals(DataState.Loading, awaitItem()) + vaultRepository.sync() + assertEquals(DataState.Loaded(null), awaitItem()) + } + + coVerify(exactly = 1) { + syncService.sync() + vaultSdkSource.decryptCipherList(listOf(createMockSdkCipher(1))) + vaultSdkSource.decryptFolderList(listOf(createMockSdkFolder(1))) + vaultSdkSource.decryptSendList(listOf(createMockSdkSend(1))) + } + } } private val MOCK_USER_STATE = UserStateJson(