mirror of
https://github.com/bitwarden/android.git
synced 2025-03-15 18:58:59 +03:00
Lock vault and clear data when logging out (#288)
This commit is contained in:
parent
de6f707964
commit
e41e681700
5 changed files with 112 additions and 0 deletions
|
@ -202,6 +202,7 @@ class AuthRepositoryImpl constructor(
|
|||
|
||||
override fun logout(userId: String) {
|
||||
val currentUserState = authDiskSource.userState ?: return
|
||||
val wasActiveUser = userId == activeUserId
|
||||
|
||||
// Remove the active user from the accounts map
|
||||
val updatedAccounts = currentUserState
|
||||
|
@ -228,6 +229,12 @@ class AuthRepositoryImpl constructor(
|
|||
// Update the user information and log out
|
||||
authDiskSource.userState = null
|
||||
}
|
||||
|
||||
// Lock the vault for the logged out user
|
||||
vaultRepository.lockVaultIfNecessary(userId)
|
||||
|
||||
// Clear the current vault data if the logged out user was the active one.
|
||||
if (wasActiveUser) vaultRepository.clearUnlockedData()
|
||||
}
|
||||
|
||||
@Suppress("ReturnCount", "LongMethod")
|
||||
|
|
|
@ -52,6 +52,11 @@ interface VaultRepository {
|
|||
*/
|
||||
fun getVaultFolderStateFlow(folderId: String): StateFlow<DataState<FolderView?>>
|
||||
|
||||
/**
|
||||
* Locks the vault for the user with the given [userId] if necessary.
|
||||
*/
|
||||
fun lockVaultIfNecessary(userId: String)
|
||||
|
||||
/**
|
||||
* Attempt to unlock the vault and sync the vault data for the currently active user.
|
||||
*/
|
||||
|
|
|
@ -155,6 +155,10 @@ class VaultRepositoryImpl constructor(
|
|||
initialValue = DataState.Loading,
|
||||
)
|
||||
|
||||
override fun lockVaultIfNecessary(userId: String) {
|
||||
setVaultToLocked(userId = userId)
|
||||
}
|
||||
|
||||
@Suppress("ReturnCount")
|
||||
override suspend fun unlockVaultAndSyncForCurrentUser(
|
||||
masterPassword: String,
|
||||
|
@ -233,6 +237,16 @@ class VaultRepositoryImpl constructor(
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: This is temporary. Eventually this needs to be based on the presence of various
|
||||
// user keys but this will likely require SDK updates to support this (BIT-1190).
|
||||
private fun setVaultToLocked(userId: String) {
|
||||
vaultMutableStateFlow.update {
|
||||
it.copy(
|
||||
unlockedVaultUserIds = it.unlockedVaultUserIds - userId,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun storeUserKeyAndPrivateKey(
|
||||
userKey: String?,
|
||||
privateKey: String?,
|
||||
|
|
|
@ -50,6 +50,7 @@ import io.mockk.mockk
|
|||
import io.mockk.mockkStatic
|
||||
import io.mockk.runs
|
||||
import io.mockk.unmockkStatic
|
||||
import io.mockk.verify
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.test.runTest
|
||||
import org.junit.jupiter.api.AfterEach
|
||||
|
@ -69,6 +70,8 @@ class AuthRepositoryTest {
|
|||
private val mutableVaultStateFlow = MutableStateFlow(VAULT_STATE)
|
||||
private val vaultRepository: VaultRepository = mockk() {
|
||||
every { vaultStateFlow } returns mutableVaultStateFlow
|
||||
every { lockVaultIfNecessary(any()) } just runs
|
||||
every { clearUnlockedData() } just runs
|
||||
}
|
||||
private val fakeAuthDiskSource = FakeAuthDiskSource()
|
||||
private val fakeEnvironmentRepository =
|
||||
|
@ -846,6 +849,8 @@ class AuthRepositoryTest {
|
|||
userId = USER_ID_1,
|
||||
userKey = null,
|
||||
)
|
||||
verify { vaultRepository.clearUnlockedData() }
|
||||
verify { vaultRepository.lockVaultIfNecessary(userId = USER_ID_1) }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -907,6 +912,8 @@ class AuthRepositoryTest {
|
|||
userId = USER_ID_1,
|
||||
userKey = null,
|
||||
)
|
||||
verify { vaultRepository.clearUnlockedData() }
|
||||
verify { vaultRepository.lockVaultIfNecessary(userId = USER_ID_1) }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -937,6 +944,8 @@ class AuthRepositoryTest {
|
|||
userId = USER_ID_2,
|
||||
userKey = null,
|
||||
)
|
||||
verify(exactly = 0) { vaultRepository.clearUnlockedData() }
|
||||
verify { vaultRepository.lockVaultIfNecessary(userId = USER_ID_2) }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -463,6 +463,29 @@ class VaultRepositoryTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `lockVaultIfNecessary should lock the given account if it is currently unlocked`() =
|
||||
runTest {
|
||||
val userId = "userId"
|
||||
verifyUnlockedVault(userId = userId)
|
||||
|
||||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = setOf(userId),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
||||
vaultRepository.lockVaultIfNecessary(userId = userId)
|
||||
|
||||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `unlockVaultAndSyncForCurrentUser with unlockVault Success should sync and return Success`() =
|
||||
|
@ -1357,6 +1380,60 @@ class VaultRepositoryTest {
|
|||
vaultSdkSource.decryptSendList(listOf(createMockSdkSend(1)))
|
||||
}
|
||||
}
|
||||
|
||||
//region Helper functions
|
||||
|
||||
/**
|
||||
* Helper to ensures that the vault for the user with the given [userId] is unlocked.
|
||||
*/
|
||||
private suspend fun verifyUnlockedVault(userId: String) {
|
||||
val kdf = MOCK_PROFILE.toSdkParams()
|
||||
val email = MOCK_PROFILE.email
|
||||
val masterPassword = "drowssap"
|
||||
val userKey = "12345"
|
||||
val privateKey = "54321"
|
||||
val organizationalKeys = emptyMap<String, String>()
|
||||
coEvery {
|
||||
vaultSdkSource.initializeCrypto(
|
||||
request = InitUserCryptoRequest(
|
||||
kdfParams = kdf,
|
||||
email = email,
|
||||
privateKey = privateKey,
|
||||
method = InitUserCryptoMethod.Password(
|
||||
password = masterPassword,
|
||||
userKey = userKey,
|
||||
),
|
||||
),
|
||||
)
|
||||
} returns InitializeCryptoResult.Success.asSuccess()
|
||||
|
||||
val result = vaultRepository.unlockVault(
|
||||
userId = userId,
|
||||
masterPassword = masterPassword,
|
||||
kdf = kdf,
|
||||
email = email,
|
||||
userKey = userKey,
|
||||
privateKey = privateKey,
|
||||
organizationalKeys = organizationalKeys,
|
||||
)
|
||||
|
||||
assertEquals(VaultUnlockResult.Success, result)
|
||||
coVerify(exactly = 1) {
|
||||
vaultSdkSource.initializeCrypto(
|
||||
request = InitUserCryptoRequest(
|
||||
kdfParams = kdf,
|
||||
email = email,
|
||||
privateKey = privateKey,
|
||||
method = InitUserCryptoMethod.Password(
|
||||
password = masterPassword,
|
||||
userKey = userKey,
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
//endregion Helper functions
|
||||
}
|
||||
|
||||
private val MOCK_PROFILE = AccountJson.Profile(
|
||||
|
|
Loading…
Add table
Reference in a new issue