mirror of
https://github.com/bitwarden/android.git
synced 2025-03-15 18:58:59 +03:00
VaultRepo clears in-memory vault data whenever the active account changes (#1010)
This commit is contained in:
parent
5928987a9b
commit
8cc25a57f0
6 changed files with 138 additions and 175 deletions
|
@ -469,7 +469,6 @@ class AuthRepositoryImpl(
|
|||
|
||||
// Attempt to unlock the vault with password if possible.
|
||||
password?.let {
|
||||
vaultRepository.clearUnlockedData()
|
||||
vaultRepository.unlockVault(
|
||||
userId = userStateJson.activeUserId,
|
||||
email = userStateJson.activeAccount.profile.email,
|
||||
|
@ -504,7 +503,6 @@ class AuthRepositoryImpl(
|
|||
|
||||
// Attempt to unlock the vault with auth request if possible.
|
||||
deviceData?.let { model ->
|
||||
vaultRepository.clearUnlockedData()
|
||||
vaultRepository.unlockVault(
|
||||
userId = userStateJson.activeUserId,
|
||||
email = userStateJson.activeAccount.profile.email,
|
||||
|
@ -582,12 +580,7 @@ class AuthRepositoryImpl(
|
|||
}
|
||||
|
||||
override fun logout(userId: String) {
|
||||
val wasActiveUser = userId == activeUserId
|
||||
|
||||
userLogoutManager.logout(userId = userId)
|
||||
|
||||
// Clear the current vault data if the logged out user was the active one.
|
||||
if (wasActiveUser) vaultRepository.clearUnlockedData()
|
||||
}
|
||||
|
||||
override suspend fun resendVerificationCodeEmail(): ResendEmailResult =
|
||||
|
@ -618,9 +611,6 @@ class AuthRepositoryImpl(
|
|||
// Switch to the new user
|
||||
authDiskSource.userState = currentUserState.copy(activeUserId = userId)
|
||||
|
||||
// Clear data for the previous user
|
||||
vaultRepository.clearUnlockedData()
|
||||
|
||||
// Clear any pending account additions
|
||||
hasPendingAccountAddition = false
|
||||
|
||||
|
|
|
@ -100,11 +100,6 @@ interface VaultRepository : VaultLockManager {
|
|||
*/
|
||||
val totpCodeFlow: Flow<TotpCodeResult>
|
||||
|
||||
/**
|
||||
* Clear any previously unlocked, in-memory data (vault, send, etc).
|
||||
*/
|
||||
fun clearUnlockedData()
|
||||
|
||||
/**
|
||||
* Completely remove any persisted data from the vault.
|
||||
*/
|
||||
|
|
|
@ -15,6 +15,7 @@ import com.bitwarden.crypto.Kdf
|
|||
import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource
|
||||
import com.x8bit.bitwarden.data.auth.repository.util.toSdkParams
|
||||
import com.x8bit.bitwarden.data.auth.repository.util.toUpdatedUserStateJson
|
||||
import com.x8bit.bitwarden.data.auth.repository.util.userSwitchingChangesFlow
|
||||
import com.x8bit.bitwarden.data.platform.datasource.disk.SettingsDiskSource
|
||||
import com.x8bit.bitwarden.data.platform.datasource.network.util.isNoConnectionError
|
||||
import com.x8bit.bitwarden.data.platform.manager.PushManager
|
||||
|
@ -215,6 +216,14 @@ class VaultRepositoryImpl(
|
|||
get() = mutableSendDataStateFlow.asStateFlow()
|
||||
|
||||
init {
|
||||
authDiskSource
|
||||
.userSwitchingChangesFlow
|
||||
.onEach {
|
||||
syncJob.cancel()
|
||||
clearUnlockedData()
|
||||
}
|
||||
.launchIn(unconfinedScope)
|
||||
|
||||
// Setup ciphers MutableStateFlow
|
||||
mutableCiphersStateFlow
|
||||
.observeWhenSubscribedAndLoggedIn(authDiskSource.userStateFlow) { activeUserId ->
|
||||
|
@ -282,7 +291,7 @@ class VaultRepositoryImpl(
|
|||
.launchIn(ioScope)
|
||||
}
|
||||
|
||||
override fun clearUnlockedData() {
|
||||
private fun clearUnlockedData() {
|
||||
mutableCiphersStateFlow.update { DataState.Loading }
|
||||
mutableDomainsStateFlow.update { DataState.Loading }
|
||||
mutableFoldersStateFlow.update { DataState.Loading }
|
||||
|
|
|
@ -137,7 +137,6 @@ class AuthRepositoryTest {
|
|||
private val vaultRepository: VaultRepository = mockk {
|
||||
every { vaultUnlockDataStateFlow } returns mutableVaultUnlockDataStateFlow
|
||||
every { deleteVaultData(any()) } just runs
|
||||
every { clearUnlockedData() } just runs
|
||||
}
|
||||
private val fakeAuthDiskSource = FakeAuthDiskSource()
|
||||
private val fakeEnvironmentRepository =
|
||||
|
@ -489,7 +488,6 @@ class AuthRepositoryTest {
|
|||
fakeAuthDiskSource.userState,
|
||||
)
|
||||
verify { settingsRepository.setDefaultsIfNecessary(userId = USER_ID_1) }
|
||||
verify { vaultRepository.clearUnlockedData() }
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -866,7 +864,6 @@ class AuthRepositoryTest {
|
|||
fakeAuthDiskSource.userState,
|
||||
)
|
||||
verify { settingsRepository.setDefaultsIfNecessary(userId = USER_ID_1) }
|
||||
verify { vaultRepository.clearUnlockedData() }
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
|
@ -953,7 +950,6 @@ class AuthRepositoryTest {
|
|||
)
|
||||
assertFalse(repository.hasPendingAccountAddition)
|
||||
verify { settingsRepository.setDefaultsIfNecessary(userId = USER_ID_1) }
|
||||
verify { vaultRepository.clearUnlockedData() }
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -1194,7 +1190,6 @@ class AuthRepositoryTest {
|
|||
fakeAuthDiskSource.userState,
|
||||
)
|
||||
verify { settingsRepository.setDefaultsIfNecessary(userId = USER_ID_1) }
|
||||
verify { vaultRepository.clearUnlockedData() }
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -2580,18 +2575,6 @@ class AuthRepositoryTest {
|
|||
assertEquals(PrevalidateSsoResult.Success(token = "token"), result)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `logout for the active account should call logout on the UserLogoutManager and clear the user's in memory vault data`() {
|
||||
val userId = USER_ID_1
|
||||
fakeAuthDiskSource.userState = MULTI_USER_STATE
|
||||
|
||||
repository.logout(userId = userId)
|
||||
|
||||
verify { userLogoutManager.logout(userId = userId) }
|
||||
verify { vaultRepository.clearUnlockedData() }
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `logout for an inactive account should call logout on the UserLogoutManager`() {
|
||||
|
@ -2601,7 +2584,6 @@ class AuthRepositoryTest {
|
|||
repository.logout(userId = userId)
|
||||
|
||||
verify { userLogoutManager.logout(userId = userId) }
|
||||
verify(exactly = 0) { vaultRepository.clearUnlockedData() }
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -2683,7 +2665,6 @@ class AuthRepositoryTest {
|
|||
)
|
||||
|
||||
assertNull(repository.userStateFlow.value)
|
||||
verify(exactly = 0) { vaultRepository.clearUnlockedData() }
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
|
@ -2714,7 +2695,6 @@ class AuthRepositoryTest {
|
|||
repository.userStateFlow.value,
|
||||
)
|
||||
assertFalse(repository.hasPendingAccountAddition)
|
||||
verify(exactly = 0) { vaultRepository.clearUnlockedData() }
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
|
@ -2743,12 +2723,11 @@ class AuthRepositoryTest {
|
|||
originalUserState,
|
||||
repository.userStateFlow.value,
|
||||
)
|
||||
verify(exactly = 0) { vaultRepository.clearUnlockedData() }
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `switchAccount when the userId is valid should update the current UserState, clear the previously unlocked data, and reset any pending account additions`() {
|
||||
fun `switchAccount when the userId is valid should update the current UserState and reset any pending account additions`() {
|
||||
val updatedUserId = USER_ID_2
|
||||
val originalUserState = MULTI_USER_STATE.toUserState(
|
||||
vaultState = VAULT_UNLOCK_DATA,
|
||||
|
@ -2774,7 +2753,6 @@ class AuthRepositoryTest {
|
|||
repository.userStateFlow.value,
|
||||
)
|
||||
assertFalse(repository.hasPendingAccountAddition)
|
||||
verify { vaultRepository.clearUnlockedData() }
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -4045,9 +4023,8 @@ class AuthRepositoryTest {
|
|||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `logOutFlow emission for action account should call logout on the UserLogoutManager and clear the user's in memory vault data`() {
|
||||
fun `logOutFlow emission for action account should call logout on the UserLogoutManager`() {
|
||||
val userId = USER_ID_1
|
||||
fakeAuthDiskSource.userState = MULTI_USER_STATE
|
||||
|
||||
|
@ -4055,7 +4032,6 @@ class AuthRepositoryTest {
|
|||
|
||||
coVerify(exactly = 1) {
|
||||
userLogoutManager.logout(userId = userId)
|
||||
vaultRepository.clearUnlockedData()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -231,6 +231,130 @@ class VaultRepositoryTest {
|
|||
unmockkStatic(Cipher::toEncryptedNetworkCipherResponse)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `userSwitchingChangesFlow should cancel any pending sync call`() = runTest {
|
||||
fakeAuthDiskSource.userState = MOCK_USER_STATE
|
||||
coEvery { syncService.sync() } just awaits
|
||||
|
||||
vaultRepository.sync()
|
||||
vaultRepository.sync()
|
||||
coVerify(exactly = 1) {
|
||||
// Despite being called twice, we only allow 1 sync
|
||||
syncService.sync()
|
||||
}
|
||||
|
||||
fakeAuthDiskSource.userState = UserStateJson(
|
||||
activeUserId = "mockId-2",
|
||||
accounts = mapOf("mockId-2" to mockk()),
|
||||
)
|
||||
vaultRepository.sync()
|
||||
coVerify(exactly = 2) {
|
||||
// A second sync should have happened now since it was cancelled by the userState change
|
||||
syncService.sync()
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `userSwitchingChangesFlow should clear unlocked data`() = runTest {
|
||||
fakeAuthDiskSource.userState = MOCK_USER_STATE
|
||||
val userId = "mockId-1"
|
||||
coEvery {
|
||||
vaultSdkSource.decryptCipherList(
|
||||
userId = userId,
|
||||
cipherList = listOf(createMockSdkCipher(1)),
|
||||
)
|
||||
} returns listOf(createMockCipherView(number = 1)).asSuccess()
|
||||
coEvery {
|
||||
vaultSdkSource.decryptFolderList(
|
||||
userId = userId,
|
||||
folderList = listOf(createMockSdkFolder(1)),
|
||||
)
|
||||
} returns listOf(createMockFolderView(number = 1)).asSuccess()
|
||||
coEvery {
|
||||
vaultSdkSource.decryptCollectionList(
|
||||
userId = userId,
|
||||
collectionList = listOf(createMockSdkCollection(1)),
|
||||
)
|
||||
} returns listOf(createMockCollectionView(number = 1)).asSuccess()
|
||||
coEvery {
|
||||
vaultSdkSource.decryptSendList(
|
||||
userId = userId,
|
||||
sendList = listOf(createMockSdkSend(number = 1)),
|
||||
)
|
||||
} returns listOf(createMockSendView(number = 1)).asSuccess()
|
||||
val ciphersFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Cipher>>()
|
||||
val collectionsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Collection>>()
|
||||
val foldersFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Folder>>()
|
||||
val sendsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Send>>()
|
||||
val domainsFlow = bufferedMutableSharedFlow<SyncResponseJson.Domains>()
|
||||
setupVaultDiskSourceFlows(
|
||||
ciphersFlow = ciphersFlow,
|
||||
collectionsFlow = collectionsFlow,
|
||||
foldersFlow = foldersFlow,
|
||||
sendsFlow = sendsFlow,
|
||||
domainsFlow = domainsFlow,
|
||||
)
|
||||
|
||||
turbineScope {
|
||||
val ciphersStateFlow = vaultRepository.ciphersStateFlow.testIn(backgroundScope)
|
||||
val collectionsStateFlow = vaultRepository.collectionsStateFlow.testIn(backgroundScope)
|
||||
val foldersStateFlow = vaultRepository.foldersStateFlow.testIn(backgroundScope)
|
||||
val sendsStateFlow = vaultRepository.sendDataStateFlow.testIn(backgroundScope)
|
||||
val domainsStateFlow = vaultRepository.domainsStateFlow.testIn(backgroundScope)
|
||||
|
||||
assertEquals(DataState.Loading, ciphersStateFlow.awaitItem())
|
||||
assertEquals(DataState.Loading, collectionsStateFlow.awaitItem())
|
||||
assertEquals(DataState.Loading, foldersStateFlow.awaitItem())
|
||||
assertEquals(DataState.Loading, sendsStateFlow.awaitItem())
|
||||
assertEquals(DataState.Loading, domainsStateFlow.awaitItem())
|
||||
|
||||
ciphersFlow.tryEmit(listOf(createMockCipher(number = 1)))
|
||||
collectionsFlow.tryEmit(listOf(createMockCollection(number = 1)))
|
||||
foldersFlow.tryEmit(listOf(createMockFolder(number = 1)))
|
||||
sendsFlow.tryEmit(listOf(createMockSend(number = 1)))
|
||||
domainsFlow.tryEmit(createMockDomains(number = 1))
|
||||
|
||||
// No events received until unlocked
|
||||
ciphersStateFlow.expectNoEvents()
|
||||
collectionsStateFlow.expectNoEvents()
|
||||
foldersStateFlow.expectNoEvents()
|
||||
sendsStateFlow.expectNoEvents()
|
||||
// Domains does not care about being unlocked
|
||||
assertEquals(
|
||||
DataState.Loaded(createMockDomainsData(number = 1)),
|
||||
domainsStateFlow.awaitItem(),
|
||||
)
|
||||
setVaultToUnlocked(userId = userId)
|
||||
|
||||
assertEquals(
|
||||
DataState.Loaded(listOf(createMockCipherView(number = 1))),
|
||||
ciphersStateFlow.awaitItem(),
|
||||
)
|
||||
assertEquals(
|
||||
DataState.Loaded(listOf(createMockCollectionView(number = 1))),
|
||||
collectionsStateFlow.awaitItem(),
|
||||
)
|
||||
assertEquals(
|
||||
DataState.Loaded(listOf(createMockFolderView(number = 1))),
|
||||
foldersStateFlow.awaitItem(),
|
||||
)
|
||||
assertEquals(
|
||||
DataState.Loaded(SendData(listOf(createMockSendView(number = 1)))),
|
||||
sendsStateFlow.awaitItem(),
|
||||
)
|
||||
// Domain data has not changed
|
||||
domainsStateFlow.expectNoEvents()
|
||||
|
||||
fakeAuthDiskSource.userState = null
|
||||
|
||||
assertEquals(DataState.Loading, ciphersStateFlow.awaitItem())
|
||||
assertEquals(DataState.Loading, collectionsStateFlow.awaitItem())
|
||||
assertEquals(DataState.Loading, foldersStateFlow.awaitItem())
|
||||
assertEquals(DataState.Loading, sendsStateFlow.awaitItem())
|
||||
assertEquals(DataState.Loading, domainsStateFlow.awaitItem())
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `ciphersStateFlow should emit decrypted list of ciphers when decryptCipherList succeeds`() =
|
||||
runTest {
|
||||
|
@ -1329,137 +1453,6 @@ class VaultRepositoryTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `clearUnlockedData should update the vaultDataStateFlow to Loading`() = runTest {
|
||||
fakeAuthDiskSource.userState = MOCK_USER_STATE
|
||||
val userId = "mockId-1"
|
||||
coEvery {
|
||||
vaultSdkSource.decryptCipherList(
|
||||
userId = userId,
|
||||
cipherList = listOf(createMockSdkCipher(1)),
|
||||
)
|
||||
} returns listOf(createMockCipherView(number = 1)).asSuccess()
|
||||
coEvery {
|
||||
vaultSdkSource.decryptFolderList(
|
||||
userId = userId,
|
||||
folderList = listOf(createMockSdkFolder(1)),
|
||||
)
|
||||
} returns listOf(createMockFolderView(number = 1)).asSuccess()
|
||||
coEvery {
|
||||
vaultSdkSource.decryptCollectionList(
|
||||
userId = userId,
|
||||
collectionList = listOf(createMockSdkCollection(1)),
|
||||
)
|
||||
} returns listOf(createMockCollectionView(number = 1)).asSuccess()
|
||||
coEvery {
|
||||
vaultSdkSource.decryptSendList(
|
||||
userId = userId,
|
||||
sendList = listOf(createMockSdkSend(number = 1)),
|
||||
)
|
||||
} returns listOf(createMockSendView(number = 1)).asSuccess()
|
||||
val ciphersFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Cipher>>()
|
||||
val collectionsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Collection>>()
|
||||
val foldersFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Folder>>()
|
||||
val sendsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Send>>()
|
||||
setupVaultDiskSourceFlows(
|
||||
ciphersFlow = ciphersFlow,
|
||||
collectionsFlow = collectionsFlow,
|
||||
foldersFlow = foldersFlow,
|
||||
sendsFlow = sendsFlow,
|
||||
)
|
||||
|
||||
vaultRepository.vaultDataStateFlow.test {
|
||||
assertEquals(DataState.Loading, awaitItem())
|
||||
|
||||
ciphersFlow.tryEmit(listOf(createMockCipher(number = 1)))
|
||||
collectionsFlow.tryEmit(listOf(createMockCollection(number = 1)))
|
||||
foldersFlow.tryEmit(listOf(createMockFolder(number = 1)))
|
||||
sendsFlow.tryEmit(listOf(createMockSend(number = 1)))
|
||||
|
||||
// No events received until unlocked
|
||||
expectNoEvents()
|
||||
setVaultToUnlocked(userId = userId)
|
||||
|
||||
assertEquals(
|
||||
DataState.Loaded(
|
||||
data = VaultData(
|
||||
cipherViewList = listOf(createMockCipherView(number = 1)),
|
||||
collectionViewList = listOf(createMockCollectionView(number = 1)),
|
||||
folderViewList = listOf(createMockFolderView(number = 1)),
|
||||
sendViewList = listOf(createMockSendView(number = 1)),
|
||||
),
|
||||
),
|
||||
awaitItem(),
|
||||
)
|
||||
|
||||
vaultRepository.clearUnlockedData()
|
||||
|
||||
assertEquals(DataState.Loading, awaitItem())
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `clearUnlockedData should update the sendDataStateFlow to Loading`() = runTest {
|
||||
fakeAuthDiskSource.userState = MOCK_USER_STATE
|
||||
val userId = "mockId-1"
|
||||
setVaultToUnlocked(userId = userId)
|
||||
coEvery {
|
||||
vaultSdkSource.decryptSendList(
|
||||
userId = userId,
|
||||
sendList = listOf(createMockSdkSend(number = 1)),
|
||||
)
|
||||
} returns listOf(createMockSendView(number = 1)).asSuccess()
|
||||
val sendsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Send>>()
|
||||
setupVaultDiskSourceFlows(sendsFlow = sendsFlow)
|
||||
|
||||
vaultRepository.sendDataStateFlow.test {
|
||||
assertEquals(DataState.Loading, awaitItem())
|
||||
|
||||
sendsFlow.tryEmit(listOf(createMockSend(number = 1)))
|
||||
|
||||
assertEquals(
|
||||
DataState.Loaded(
|
||||
data = SendData(sendViewList = listOf(createMockSendView(number = 1))),
|
||||
),
|
||||
awaitItem(),
|
||||
)
|
||||
|
||||
vaultRepository.clearUnlockedData()
|
||||
|
||||
assertEquals(DataState.Loading, awaitItem())
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `clearUnlockedData should update the domainsStateFlow to Loading`() = runTest {
|
||||
fakeAuthDiskSource.userState = MOCK_USER_STATE
|
||||
val domainsData = createMockDomainsData(number = 1)
|
||||
coEvery {
|
||||
createMockDomains(number = 1).toDomainsData()
|
||||
} returns domainsData
|
||||
val domainsFlow = bufferedMutableSharedFlow<SyncResponseJson.Domains>()
|
||||
setupVaultDiskSourceFlows(
|
||||
domainsFlow = domainsFlow,
|
||||
)
|
||||
|
||||
vaultRepository.domainsStateFlow.test {
|
||||
assertEquals(DataState.Loading, awaitItem())
|
||||
|
||||
domainsFlow.tryEmit(createMockDomains(number = 1))
|
||||
|
||||
assertEquals(
|
||||
DataState.Loaded(
|
||||
data = domainsData,
|
||||
),
|
||||
awaitItem(),
|
||||
)
|
||||
|
||||
vaultRepository.clearUnlockedData()
|
||||
|
||||
assertEquals(DataState.Loading, awaitItem())
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getVaultItemStateFlow should update to Error when a sync fails generically`() =
|
||||
runTest {
|
||||
|
|
|
@ -5,11 +5,11 @@ package com.x8bit.bitwarden.data.vault.repository.model
|
|||
*/
|
||||
fun createMockDomainsData(number: Int): DomainsData =
|
||||
DomainsData(
|
||||
equivalentDomains = listOf(listOf("mockEquivalentDomains-$number")),
|
||||
equivalentDomains = listOf(listOf("mockEquivalentDomain-$number")),
|
||||
globalEquivalentDomains = listOf(
|
||||
DomainsData.GlobalEquivalentDomain(
|
||||
isExcluded = false,
|
||||
domains = listOf("domains-$number"),
|
||||
domains = listOf("mockDomain-$number"),
|
||||
type = number,
|
||||
),
|
||||
),
|
||||
|
|
Loading…
Add table
Reference in a new issue