mirror of
https://github.com/bitwarden/android.git
synced 2025-03-16 03:08:50 +03:00
Ensure collection IDs are maintained when restoring a cipher (#1445)
This commit is contained in:
parent
b671bf0626
commit
2032f50fef
5 changed files with 79 additions and 15 deletions
|
@ -80,6 +80,7 @@ interface CipherManager {
|
|||
*/
|
||||
suspend fun restoreCipher(
|
||||
cipherId: String,
|
||||
cipherView: CipherView,
|
||||
): RestoreCipherResult
|
||||
|
||||
/**
|
||||
|
|
|
@ -176,11 +176,17 @@ class CipherManagerImpl(
|
|||
|
||||
override suspend fun restoreCipher(
|
||||
cipherId: String,
|
||||
cipherView: CipherView,
|
||||
): RestoreCipherResult {
|
||||
val userId = activeUserId ?: return RestoreCipherResult.Error
|
||||
return ciphersService
|
||||
.restoreCipher(cipherId = cipherId)
|
||||
.onSuccess { vaultDiskSource.saveCipher(userId = userId, cipher = it) }
|
||||
.onSuccess {
|
||||
vaultDiskSource.saveCipher(
|
||||
userId = userId,
|
||||
cipher = it.copy(collectionIds = cipherView.collectionIds),
|
||||
)
|
||||
}
|
||||
.fold(
|
||||
onSuccess = { RestoreCipherResult.Success },
|
||||
onFailure = { RestoreCipherResult.Error },
|
||||
|
|
|
@ -486,14 +486,22 @@ class VaultItemViewModel @Inject constructor(
|
|||
),
|
||||
)
|
||||
}
|
||||
viewModelScope.launch {
|
||||
trySendAction(
|
||||
VaultItemAction.Internal.RestoreCipherReceive(
|
||||
result = vaultRepository.restoreCipher(
|
||||
cipherId = state.vaultItemId,
|
||||
),
|
||||
),
|
||||
)
|
||||
onContent { content ->
|
||||
content
|
||||
.common
|
||||
.currentCipher
|
||||
?.let { cipher ->
|
||||
viewModelScope.launch {
|
||||
trySendAction(
|
||||
VaultItemAction.Internal.RestoreCipherReceive(
|
||||
result = vaultRepository.restoreCipher(
|
||||
cipherId = state.vaultItemId,
|
||||
cipherView = cipher,
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -670,7 +670,10 @@ class CipherManagerTest {
|
|||
fun `restoreCipher with no active user should return RestoreCipherResult Error`() = runTest {
|
||||
fakeAuthDiskSource.userState = null
|
||||
|
||||
val result = cipherManager.restoreCipher(cipherId = "cipherId")
|
||||
val result = cipherManager.restoreCipher(
|
||||
cipherId = "cipherId",
|
||||
cipherView = createMockCipherView(number = 1),
|
||||
)
|
||||
|
||||
assertEquals(RestoreCipherResult.Error, result)
|
||||
}
|
||||
|
@ -681,11 +684,15 @@ class CipherManagerTest {
|
|||
runTest {
|
||||
fakeAuthDiskSource.userState = MOCK_USER_STATE
|
||||
val cipherId = "mockId-1"
|
||||
val cipherView = createMockCipherView(number = 1)
|
||||
coEvery {
|
||||
ciphersService.restoreCipher(cipherId = cipherId)
|
||||
} returns Throwable("Fail").asFailure()
|
||||
|
||||
val result = cipherManager.restoreCipher(cipherId = cipherId)
|
||||
val result = cipherManager.restoreCipher(
|
||||
cipherId = cipherId,
|
||||
cipherView = cipherView,
|
||||
)
|
||||
|
||||
assertEquals(RestoreCipherResult.Error, result)
|
||||
}
|
||||
|
@ -697,11 +704,20 @@ class CipherManagerTest {
|
|||
val userId = "mockId-1"
|
||||
val cipherId = "mockId-1"
|
||||
val cipher = createMockCipher(number = 1)
|
||||
val cipherView = createMockCipherView(number = 1)
|
||||
fakeAuthDiskSource.userState = MOCK_USER_STATE
|
||||
coEvery { ciphersService.restoreCipher(cipherId = cipherId) } returns cipher.asSuccess()
|
||||
coEvery { vaultDiskSource.saveCipher(userId = userId, cipher = cipher) } just runs
|
||||
coEvery {
|
||||
vaultDiskSource.saveCipher(
|
||||
userId = userId,
|
||||
cipher = cipher.copy(collectionIds = cipherView.collectionIds),
|
||||
)
|
||||
} just runs
|
||||
|
||||
val result = cipherManager.restoreCipher(cipherId = cipherId)
|
||||
val result = cipherManager.restoreCipher(
|
||||
cipherId = cipherId,
|
||||
cipherView = cipherView,
|
||||
)
|
||||
|
||||
assertEquals(RestoreCipherResult.Success, result)
|
||||
}
|
||||
|
|
|
@ -377,9 +377,26 @@ class VaultItemViewModelTest : BaseViewModelTest() {
|
|||
@Suppress("MaxLineLength")
|
||||
fun `ConfirmRestoreClick with RestoreCipherResult Success should should ShowToast and NavigateBack`() =
|
||||
runTest {
|
||||
val mockCipherView = mockk<CipherView> {
|
||||
every {
|
||||
toViewState(
|
||||
isPremiumUser = true,
|
||||
hasMasterPassword = true,
|
||||
totpCodeItemData = createTotpCodeData(),
|
||||
)
|
||||
} returns DEFAULT_VIEW_STATE
|
||||
}
|
||||
mutableVaultItemFlow.value = DataState.Loaded(data = mockCipherView)
|
||||
mutableAuthCodeItemFlow.value = DataState.Loaded(
|
||||
data = createVerificationCodeItem(),
|
||||
)
|
||||
|
||||
val viewModel = createViewModel(state = DEFAULT_STATE)
|
||||
coEvery {
|
||||
vaultRepo.restoreCipher(cipherId = VAULT_ITEM_ID)
|
||||
vaultRepo.restoreCipher(
|
||||
cipherId = VAULT_ITEM_ID,
|
||||
cipherView = createMockCipherView(number = 1),
|
||||
)
|
||||
} returns RestoreCipherResult.Success
|
||||
|
||||
viewModel.trySendAction(VaultItemAction.Common.ConfirmRestoreClick)
|
||||
|
@ -399,15 +416,31 @@ class VaultItemViewModelTest : BaseViewModelTest() {
|
|||
@Test
|
||||
@Suppress("MaxLineLength")
|
||||
fun `ConfirmRestoreClick with RestoreCipherResult Failure should should Show generic error`() {
|
||||
val mockCipherView = mockk<CipherView> {
|
||||
every {
|
||||
toViewState(
|
||||
isPremiumUser = true,
|
||||
hasMasterPassword = true,
|
||||
totpCodeItemData = createTotpCodeData(),
|
||||
)
|
||||
} returns DEFAULT_VIEW_STATE
|
||||
}
|
||||
mutableVaultItemFlow.value = DataState.Loaded(data = mockCipherView)
|
||||
mutableAuthCodeItemFlow.value = DataState.Loaded(data = createVerificationCodeItem())
|
||||
|
||||
val viewModel = createViewModel(state = DEFAULT_STATE)
|
||||
coEvery {
|
||||
vaultRepo.restoreCipher(cipherId = VAULT_ITEM_ID)
|
||||
vaultRepo.restoreCipher(
|
||||
cipherId = VAULT_ITEM_ID,
|
||||
cipherView = createMockCipherView(number = 1),
|
||||
)
|
||||
} returns RestoreCipherResult.Error
|
||||
|
||||
viewModel.trySendAction(VaultItemAction.Common.ConfirmRestoreClick)
|
||||
|
||||
assertEquals(
|
||||
DEFAULT_STATE.copy(
|
||||
viewState = DEFAULT_VIEW_STATE,
|
||||
dialog = VaultItemState.DialogState.Generic(
|
||||
message = R.string.generic_error_message.asText(),
|
||||
),
|
||||
|
|
Loading…
Add table
Reference in a new issue