VaultRepo clears in-memory vault data whenever the active account changes (#1010)

This commit is contained in:
David Perez 2024-02-13 15:42:58 -06:00 committed by Álison Fernandes
parent 5928987a9b
commit 8cc25a57f0
6 changed files with 138 additions and 175 deletions

View file

@ -469,7 +469,6 @@ class AuthRepositoryImpl(
// Attempt to unlock the vault with password if possible.
password?.let {
vaultRepository.clearUnlockedData()
vaultRepository.unlockVault(
userId = userStateJson.activeUserId,
email = userStateJson.activeAccount.profile.email,
@ -504,7 +503,6 @@ class AuthRepositoryImpl(
// Attempt to unlock the vault with auth request if possible.
deviceData?.let { model ->
vaultRepository.clearUnlockedData()
vaultRepository.unlockVault(
userId = userStateJson.activeUserId,
email = userStateJson.activeAccount.profile.email,
@ -582,12 +580,7 @@ class AuthRepositoryImpl(
}
override fun logout(userId: String) {
val wasActiveUser = userId == activeUserId
userLogoutManager.logout(userId = userId)
// Clear the current vault data if the logged out user was the active one.
if (wasActiveUser) vaultRepository.clearUnlockedData()
}
override suspend fun resendVerificationCodeEmail(): ResendEmailResult =
@ -618,9 +611,6 @@ class AuthRepositoryImpl(
// Switch to the new user
authDiskSource.userState = currentUserState.copy(activeUserId = userId)
// Clear data for the previous user
vaultRepository.clearUnlockedData()
// Clear any pending account additions
hasPendingAccountAddition = false

View file

@ -100,11 +100,6 @@ interface VaultRepository : VaultLockManager {
*/
val totpCodeFlow: Flow<TotpCodeResult>
/**
* Clear any previously unlocked, in-memory data (vault, send, etc).
*/
fun clearUnlockedData()
/**
* Completely remove any persisted data from the vault.
*/

View file

@ -15,6 +15,7 @@ import com.bitwarden.crypto.Kdf
import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource
import com.x8bit.bitwarden.data.auth.repository.util.toSdkParams
import com.x8bit.bitwarden.data.auth.repository.util.toUpdatedUserStateJson
import com.x8bit.bitwarden.data.auth.repository.util.userSwitchingChangesFlow
import com.x8bit.bitwarden.data.platform.datasource.disk.SettingsDiskSource
import com.x8bit.bitwarden.data.platform.datasource.network.util.isNoConnectionError
import com.x8bit.bitwarden.data.platform.manager.PushManager
@ -215,6 +216,14 @@ class VaultRepositoryImpl(
get() = mutableSendDataStateFlow.asStateFlow()
init {
authDiskSource
.userSwitchingChangesFlow
.onEach {
syncJob.cancel()
clearUnlockedData()
}
.launchIn(unconfinedScope)
// Setup ciphers MutableStateFlow
mutableCiphersStateFlow
.observeWhenSubscribedAndLoggedIn(authDiskSource.userStateFlow) { activeUserId ->
@ -282,7 +291,7 @@ class VaultRepositoryImpl(
.launchIn(ioScope)
}
override fun clearUnlockedData() {
private fun clearUnlockedData() {
mutableCiphersStateFlow.update { DataState.Loading }
mutableDomainsStateFlow.update { DataState.Loading }
mutableFoldersStateFlow.update { DataState.Loading }

View file

@ -137,7 +137,6 @@ class AuthRepositoryTest {
private val vaultRepository: VaultRepository = mockk {
every { vaultUnlockDataStateFlow } returns mutableVaultUnlockDataStateFlow
every { deleteVaultData(any()) } just runs
every { clearUnlockedData() } just runs
}
private val fakeAuthDiskSource = FakeAuthDiskSource()
private val fakeEnvironmentRepository =
@ -489,7 +488,6 @@ class AuthRepositoryTest {
fakeAuthDiskSource.userState,
)
verify { settingsRepository.setDefaultsIfNecessary(userId = USER_ID_1) }
verify { vaultRepository.clearUnlockedData() }
}
@Test
@ -866,7 +864,6 @@ class AuthRepositoryTest {
fakeAuthDiskSource.userState,
)
verify { settingsRepository.setDefaultsIfNecessary(userId = USER_ID_1) }
verify { vaultRepository.clearUnlockedData() }
}
@Suppress("MaxLineLength")
@ -953,7 +950,6 @@ class AuthRepositoryTest {
)
assertFalse(repository.hasPendingAccountAddition)
verify { settingsRepository.setDefaultsIfNecessary(userId = USER_ID_1) }
verify { vaultRepository.clearUnlockedData() }
}
@Test
@ -1194,7 +1190,6 @@ class AuthRepositoryTest {
fakeAuthDiskSource.userState,
)
verify { settingsRepository.setDefaultsIfNecessary(userId = USER_ID_1) }
verify { vaultRepository.clearUnlockedData() }
}
@Test
@ -2580,18 +2575,6 @@ class AuthRepositoryTest {
assertEquals(PrevalidateSsoResult.Success(token = "token"), result)
}
@Suppress("MaxLineLength")
@Test
fun `logout for the active account should call logout on the UserLogoutManager and clear the user's in memory vault data`() {
val userId = USER_ID_1
fakeAuthDiskSource.userState = MULTI_USER_STATE
repository.logout(userId = userId)
verify { userLogoutManager.logout(userId = userId) }
verify { vaultRepository.clearUnlockedData() }
}
@Suppress("MaxLineLength")
@Test
fun `logout for an inactive account should call logout on the UserLogoutManager`() {
@ -2601,7 +2584,6 @@ class AuthRepositoryTest {
repository.logout(userId = userId)
verify { userLogoutManager.logout(userId = userId) }
verify(exactly = 0) { vaultRepository.clearUnlockedData() }
}
@Test
@ -2683,7 +2665,6 @@ class AuthRepositoryTest {
)
assertNull(repository.userStateFlow.value)
verify(exactly = 0) { vaultRepository.clearUnlockedData() }
}
@Suppress("MaxLineLength")
@ -2714,7 +2695,6 @@ class AuthRepositoryTest {
repository.userStateFlow.value,
)
assertFalse(repository.hasPendingAccountAddition)
verify(exactly = 0) { vaultRepository.clearUnlockedData() }
}
@Suppress("MaxLineLength")
@ -2743,12 +2723,11 @@ class AuthRepositoryTest {
originalUserState,
repository.userStateFlow.value,
)
verify(exactly = 0) { vaultRepository.clearUnlockedData() }
}
@Suppress("MaxLineLength")
@Test
fun `switchAccount when the userId is valid should update the current UserState, clear the previously unlocked data, and reset any pending account additions`() {
fun `switchAccount when the userId is valid should update the current UserState and reset any pending account additions`() {
val updatedUserId = USER_ID_2
val originalUserState = MULTI_USER_STATE.toUserState(
vaultState = VAULT_UNLOCK_DATA,
@ -2774,7 +2753,6 @@ class AuthRepositoryTest {
repository.userStateFlow.value,
)
assertFalse(repository.hasPendingAccountAddition)
verify { vaultRepository.clearUnlockedData() }
}
@Test
@ -4045,9 +4023,8 @@ class AuthRepositoryTest {
)
}
@Suppress("MaxLineLength")
@Test
fun `logOutFlow emission for action account should call logout on the UserLogoutManager and clear the user's in memory vault data`() {
fun `logOutFlow emission for action account should call logout on the UserLogoutManager`() {
val userId = USER_ID_1
fakeAuthDiskSource.userState = MULTI_USER_STATE
@ -4055,7 +4032,6 @@ class AuthRepositoryTest {
coVerify(exactly = 1) {
userLogoutManager.logout(userId = userId)
vaultRepository.clearUnlockedData()
}
}

View file

@ -231,6 +231,130 @@ class VaultRepositoryTest {
unmockkStatic(Cipher::toEncryptedNetworkCipherResponse)
}
@Test
fun `userSwitchingChangesFlow should cancel any pending sync call`() = runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE
coEvery { syncService.sync() } just awaits
vaultRepository.sync()
vaultRepository.sync()
coVerify(exactly = 1) {
// Despite being called twice, we only allow 1 sync
syncService.sync()
}
fakeAuthDiskSource.userState = UserStateJson(
activeUserId = "mockId-2",
accounts = mapOf("mockId-2" to mockk()),
)
vaultRepository.sync()
coVerify(exactly = 2) {
// A second sync should have happened now since it was cancelled by the userState change
syncService.sync()
}
}
@Test
fun `userSwitchingChangesFlow should clear unlocked data`() = runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery {
vaultSdkSource.decryptCipherList(
userId = userId,
cipherList = listOf(createMockSdkCipher(1)),
)
} returns listOf(createMockCipherView(number = 1)).asSuccess()
coEvery {
vaultSdkSource.decryptFolderList(
userId = userId,
folderList = listOf(createMockSdkFolder(1)),
)
} returns listOf(createMockFolderView(number = 1)).asSuccess()
coEvery {
vaultSdkSource.decryptCollectionList(
userId = userId,
collectionList = listOf(createMockSdkCollection(1)),
)
} returns listOf(createMockCollectionView(number = 1)).asSuccess()
coEvery {
vaultSdkSource.decryptSendList(
userId = userId,
sendList = listOf(createMockSdkSend(number = 1)),
)
} returns listOf(createMockSendView(number = 1)).asSuccess()
val ciphersFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Cipher>>()
val collectionsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Collection>>()
val foldersFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Folder>>()
val sendsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Send>>()
val domainsFlow = bufferedMutableSharedFlow<SyncResponseJson.Domains>()
setupVaultDiskSourceFlows(
ciphersFlow = ciphersFlow,
collectionsFlow = collectionsFlow,
foldersFlow = foldersFlow,
sendsFlow = sendsFlow,
domainsFlow = domainsFlow,
)
turbineScope {
val ciphersStateFlow = vaultRepository.ciphersStateFlow.testIn(backgroundScope)
val collectionsStateFlow = vaultRepository.collectionsStateFlow.testIn(backgroundScope)
val foldersStateFlow = vaultRepository.foldersStateFlow.testIn(backgroundScope)
val sendsStateFlow = vaultRepository.sendDataStateFlow.testIn(backgroundScope)
val domainsStateFlow = vaultRepository.domainsStateFlow.testIn(backgroundScope)
assertEquals(DataState.Loading, ciphersStateFlow.awaitItem())
assertEquals(DataState.Loading, collectionsStateFlow.awaitItem())
assertEquals(DataState.Loading, foldersStateFlow.awaitItem())
assertEquals(DataState.Loading, sendsStateFlow.awaitItem())
assertEquals(DataState.Loading, domainsStateFlow.awaitItem())
ciphersFlow.tryEmit(listOf(createMockCipher(number = 1)))
collectionsFlow.tryEmit(listOf(createMockCollection(number = 1)))
foldersFlow.tryEmit(listOf(createMockFolder(number = 1)))
sendsFlow.tryEmit(listOf(createMockSend(number = 1)))
domainsFlow.tryEmit(createMockDomains(number = 1))
// No events received until unlocked
ciphersStateFlow.expectNoEvents()
collectionsStateFlow.expectNoEvents()
foldersStateFlow.expectNoEvents()
sendsStateFlow.expectNoEvents()
// Domains does not care about being unlocked
assertEquals(
DataState.Loaded(createMockDomainsData(number = 1)),
domainsStateFlow.awaitItem(),
)
setVaultToUnlocked(userId = userId)
assertEquals(
DataState.Loaded(listOf(createMockCipherView(number = 1))),
ciphersStateFlow.awaitItem(),
)
assertEquals(
DataState.Loaded(listOf(createMockCollectionView(number = 1))),
collectionsStateFlow.awaitItem(),
)
assertEquals(
DataState.Loaded(listOf(createMockFolderView(number = 1))),
foldersStateFlow.awaitItem(),
)
assertEquals(
DataState.Loaded(SendData(listOf(createMockSendView(number = 1)))),
sendsStateFlow.awaitItem(),
)
// Domain data has not changed
domainsStateFlow.expectNoEvents()
fakeAuthDiskSource.userState = null
assertEquals(DataState.Loading, ciphersStateFlow.awaitItem())
assertEquals(DataState.Loading, collectionsStateFlow.awaitItem())
assertEquals(DataState.Loading, foldersStateFlow.awaitItem())
assertEquals(DataState.Loading, sendsStateFlow.awaitItem())
assertEquals(DataState.Loading, domainsStateFlow.awaitItem())
}
}
@Test
fun `ciphersStateFlow should emit decrypted list of ciphers when decryptCipherList succeeds`() =
runTest {
@ -1329,137 +1453,6 @@ class VaultRepositoryTest {
}
}
@Test
fun `clearUnlockedData should update the vaultDataStateFlow to Loading`() = runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery {
vaultSdkSource.decryptCipherList(
userId = userId,
cipherList = listOf(createMockSdkCipher(1)),
)
} returns listOf(createMockCipherView(number = 1)).asSuccess()
coEvery {
vaultSdkSource.decryptFolderList(
userId = userId,
folderList = listOf(createMockSdkFolder(1)),
)
} returns listOf(createMockFolderView(number = 1)).asSuccess()
coEvery {
vaultSdkSource.decryptCollectionList(
userId = userId,
collectionList = listOf(createMockSdkCollection(1)),
)
} returns listOf(createMockCollectionView(number = 1)).asSuccess()
coEvery {
vaultSdkSource.decryptSendList(
userId = userId,
sendList = listOf(createMockSdkSend(number = 1)),
)
} returns listOf(createMockSendView(number = 1)).asSuccess()
val ciphersFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Cipher>>()
val collectionsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Collection>>()
val foldersFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Folder>>()
val sendsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Send>>()
setupVaultDiskSourceFlows(
ciphersFlow = ciphersFlow,
collectionsFlow = collectionsFlow,
foldersFlow = foldersFlow,
sendsFlow = sendsFlow,
)
vaultRepository.vaultDataStateFlow.test {
assertEquals(DataState.Loading, awaitItem())
ciphersFlow.tryEmit(listOf(createMockCipher(number = 1)))
collectionsFlow.tryEmit(listOf(createMockCollection(number = 1)))
foldersFlow.tryEmit(listOf(createMockFolder(number = 1)))
sendsFlow.tryEmit(listOf(createMockSend(number = 1)))
// No events received until unlocked
expectNoEvents()
setVaultToUnlocked(userId = userId)
assertEquals(
DataState.Loaded(
data = VaultData(
cipherViewList = listOf(createMockCipherView(number = 1)),
collectionViewList = listOf(createMockCollectionView(number = 1)),
folderViewList = listOf(createMockFolderView(number = 1)),
sendViewList = listOf(createMockSendView(number = 1)),
),
),
awaitItem(),
)
vaultRepository.clearUnlockedData()
assertEquals(DataState.Loading, awaitItem())
}
}
@Test
fun `clearUnlockedData should update the sendDataStateFlow to Loading`() = runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
setVaultToUnlocked(userId = userId)
coEvery {
vaultSdkSource.decryptSendList(
userId = userId,
sendList = listOf(createMockSdkSend(number = 1)),
)
} returns listOf(createMockSendView(number = 1)).asSuccess()
val sendsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Send>>()
setupVaultDiskSourceFlows(sendsFlow = sendsFlow)
vaultRepository.sendDataStateFlow.test {
assertEquals(DataState.Loading, awaitItem())
sendsFlow.tryEmit(listOf(createMockSend(number = 1)))
assertEquals(
DataState.Loaded(
data = SendData(sendViewList = listOf(createMockSendView(number = 1))),
),
awaitItem(),
)
vaultRepository.clearUnlockedData()
assertEquals(DataState.Loading, awaitItem())
}
}
@Test
fun `clearUnlockedData should update the domainsStateFlow to Loading`() = runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE
val domainsData = createMockDomainsData(number = 1)
coEvery {
createMockDomains(number = 1).toDomainsData()
} returns domainsData
val domainsFlow = bufferedMutableSharedFlow<SyncResponseJson.Domains>()
setupVaultDiskSourceFlows(
domainsFlow = domainsFlow,
)
vaultRepository.domainsStateFlow.test {
assertEquals(DataState.Loading, awaitItem())
domainsFlow.tryEmit(createMockDomains(number = 1))
assertEquals(
DataState.Loaded(
data = domainsData,
),
awaitItem(),
)
vaultRepository.clearUnlockedData()
assertEquals(DataState.Loading, awaitItem())
}
}
@Test
fun `getVaultItemStateFlow should update to Error when a sync fails generically`() =
runTest {

View file

@ -5,11 +5,11 @@ package com.x8bit.bitwarden.data.vault.repository.model
*/
fun createMockDomainsData(number: Int): DomainsData =
DomainsData(
equivalentDomains = listOf(listOf("mockEquivalentDomains-$number")),
equivalentDomains = listOf(listOf("mockEquivalentDomain-$number")),
globalEquivalentDomains = listOf(
DomainsData.GlobalEquivalentDomain(
isExcluded = false,
domains = listOf("domains-$number"),
domains = listOf("mockDomain-$number"),
type = number,
),
),