Add VaultState.unlockingVaultUserIds and clean up the vault unlock logic (#546)

This commit is contained in:
Brian Yencho 2024-01-08 22:56:35 -06:00 committed by Álison Fernandes
parent 940979599e
commit d95e5df2a7
5 changed files with 169 additions and 67 deletions

View file

@ -91,14 +91,17 @@ class VaultRepositoryImpl(
private var syncJob: Job = Job().apply { complete() }
private var willSyncAfterUnlock = false
private val activeUserId: String? get() = authDiskSource.userState?.activeUserId
private val mutableTotpCodeFlow = bufferedMutableSharedFlow<String>()
private val mutableVaultStateStateFlow =
MutableStateFlow(VaultState(unlockedVaultUserIds = emptySet()))
MutableStateFlow(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
)
private val mutableSendDataStateFlow = MutableStateFlow<DataState<SendData>>(DataState.Loading)
@ -194,8 +197,8 @@ class VaultRepositoryImpl(
}
override fun sync() {
if (!syncJob.isCompleted || willSyncAfterUnlock) return
val userId = activeUserId ?: return
if (!syncJob.isCompleted || isVaultUnlocking(userId)) return
mutableCiphersStateFlow.updateToPendingOrLoading()
mutableFoldersStateFlow.updateToPendingOrLoading()
mutableCollectionsStateFlow.updateToPendingOrLoading()
@ -290,22 +293,15 @@ class VaultRepositoryImpl(
override suspend fun unlockVaultAndSyncForCurrentUser(
masterPassword: String,
): VaultUnlockResult {
val userState = authDiskSource.userState
val userId = activeUserId ?: return VaultUnlockResult.InvalidStateError
val userKey = authDiskSource.getUserKey(userId = userId)
?: return VaultUnlockResult.InvalidStateError
val userKey = authDiskSource.getUserKey(userId = userState.activeUserId)
?: return VaultUnlockResult.InvalidStateError
val privateKey = authDiskSource.getPrivateKey(userId = userState.activeUserId)
?: return VaultUnlockResult.InvalidStateError
val organizationKeys = authDiskSource
.getOrganizationKeys(userId = userState.activeUserId)
return unlockVault(
userId = userState.activeUserId,
masterPassword = masterPassword,
email = userState.activeAccount.profile.email,
kdf = userState.activeAccount.profile.toSdkParams(),
userKey = userKey,
privateKey = privateKey,
organizationKeys = organizationKeys,
return unlockVaultForUser(
userId = userId,
initUserCryptoMethod = InitUserCryptoMethod.Password(
password = masterPassword,
userKey = userKey,
),
)
.also {
if (it is VaultUnlockResult.Success) {
@ -323,53 +319,17 @@ class VaultRepositoryImpl(
privateKey: String,
organizationKeys: Map<String, String>?,
): VaultUnlockResult =
flow {
willSyncAfterUnlock = true
emit(
vaultSdkSource
.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest(
kdfParams = kdf,
email = email,
privateKey = privateKey,
method = InitUserCryptoMethod.Password(
password = masterPassword,
userKey = userKey,
),
),
)
.flatMap { result ->
// Initialize the SDK for organizations if necessary
if (organizationKeys != null &&
result is InitializeCryptoResult.Success
) {
vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest(
organizationKeys = organizationKeys,
),
)
} else {
result.asSuccess()
}
}
.fold(
onFailure = { VaultUnlockResult.GenericError },
onSuccess = { initializeCryptoResult ->
initializeCryptoResult
.toVaultUnlockResult()
.also {
if (it is VaultUnlockResult.Success) {
setVaultToUnlocked(userId = userId)
}
}
},
),
)
}
.onCompletion { willSyncAfterUnlock = false }
.first()
unlockVaultInternal(
userId = userId,
email = email,
kdf = kdf,
privateKey = privateKey,
initUserCryptoMethod = InitUserCryptoMethod.Password(
password = masterPassword,
userKey = userKey,
),
organizationKeys = organizationKeys,
)
override suspend fun createCipher(cipherView: CipherView): CreateCipherResult {
val userId = requireNotNull(activeUserId)
@ -506,6 +466,25 @@ class VaultRepositoryImpl(
}
}
private fun setVaultToUnlocking(userId: String) {
mutableVaultStateStateFlow.update {
it.copy(
unlockingVaultUserIds = it.unlockingVaultUserIds + userId,
)
}
}
private fun setVaultToNotUnlocking(userId: String) {
mutableVaultStateStateFlow.update {
it.copy(
unlockingVaultUserIds = it.unlockingVaultUserIds - userId,
)
}
}
private fun isVaultUnlocking(userId: String) =
userId in mutableVaultStateStateFlow.value.unlockingVaultUserIds
private fun storeProfileData(
syncResponse: SyncResponseJson,
) {
@ -536,6 +515,80 @@ class VaultRepositoryImpl(
}
}
@Suppress("ReturnCount")
private suspend fun unlockVaultForUser(
userId: String,
initUserCryptoMethod: InitUserCryptoMethod,
): VaultUnlockResult {
val account = authDiskSource.userState?.accounts?.get(userId)
?: return VaultUnlockResult.InvalidStateError
val privateKey = authDiskSource.getPrivateKey(userId = userId)
?: return VaultUnlockResult.InvalidStateError
val organizationKeys = authDiskSource
.getOrganizationKeys(userId = userId)
return unlockVaultInternal(
userId = userId,
email = account.profile.email,
kdf = account.profile.toSdkParams(),
privateKey = privateKey,
initUserCryptoMethod = initUserCryptoMethod,
organizationKeys = organizationKeys,
)
}
private suspend fun unlockVaultInternal(
userId: String,
email: String,
kdf: Kdf,
privateKey: String,
initUserCryptoMethod: InitUserCryptoMethod,
organizationKeys: Map<String, String>?,
): VaultUnlockResult =
flow {
setVaultToUnlocking(userId = userId)
emit(
vaultSdkSource
.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest(
kdfParams = kdf,
email = email,
privateKey = privateKey,
method = initUserCryptoMethod,
),
)
.flatMap { result ->
// Initialize the SDK for organizations if necessary
if (organizationKeys != null &&
result is InitializeCryptoResult.Success
) {
vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest(
organizationKeys = organizationKeys,
),
)
} else {
result.asSuccess()
}
}
.fold(
onFailure = { VaultUnlockResult.GenericError },
onSuccess = { initializeCryptoResult ->
initializeCryptoResult
.toVaultUnlockResult()
.also {
if (it is VaultUnlockResult.Success) {
setVaultToUnlocked(userId = userId)
}
}
},
),
)
}
.onCompletion { setVaultToNotUnlocking(userId = userId) }
.first()
private suspend fun unlockVaultForOrganizationsIfNecessary(
syncResponse: SyncResponseJson,
) {

View file

@ -4,7 +4,10 @@ package com.x8bit.bitwarden.data.vault.repository.model
* General description of the vault across multiple users.
*
* @property unlockedVaultUserIds The user IDs for all users that currently have unlocked vaults.
* @property unlockedVaultUserIds The user IDs for all users that are actively unlocking their
* vaults.
*/
data class VaultState(
val unlockedVaultUserIds: Set<String>,
val unlockingVaultUserIds: Set<String>,
)

View file

@ -179,7 +179,10 @@ class AuthRepositoryTest {
repository.userStateFlow.value,
)
val emptyVaultState = VaultState(unlockedVaultUserIds = emptySet())
val emptyVaultState = VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
)
mutableVaultStateFlow.value = emptyVaultState
assertEquals(
MULTI_USER_STATE.toUserState(
@ -1535,6 +1538,7 @@ class AuthRepositoryTest {
)
private val VAULT_STATE = VaultState(
unlockedVaultUserIds = setOf(USER_ID_1),
unlockingVaultUserIds = emptySet(),
)
}
}

View file

@ -135,6 +135,7 @@ class UserStateJsonExtensionsTest {
.toUserState(
vaultState = VaultState(
unlockedVaultUserIds = setOf("activeUserId"),
unlockingVaultUserIds = emptySet(),
),
userOrganizationsList = listOf(
UserOrganizations(
@ -198,6 +199,7 @@ class UserStateJsonExtensionsTest {
.toUserState(
vaultState = VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
userOrganizationsList = listOf(
UserOrganizations(

View file

@ -538,6 +538,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = setOf(userId),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -547,6 +548,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -562,6 +564,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = setOf(userId),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -571,6 +574,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -634,6 +638,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -649,6 +654,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = setOf("mockId-1"),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -794,6 +800,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -809,6 +816,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -861,6 +869,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -876,6 +885,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -913,6 +923,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -925,6 +936,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -974,6 +986,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -986,6 +999,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -999,6 +1013,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1012,6 +1027,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1024,6 +1040,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1045,6 +1062,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1057,6 +1075,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1077,6 +1096,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1114,6 +1134,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1132,6 +1153,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = setOf(userId),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1186,6 +1208,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1204,6 +1227,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1258,6 +1282,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1276,6 +1301,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1328,6 +1354,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1346,6 +1373,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1398,6 +1426,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1416,6 +1445,7 @@ class VaultRepositoryTest {
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
vaultRepository.vaultStateFlow.value,
)
@ -1477,6 +1507,16 @@ class VaultRepositoryTest {
organizationKeys = organizationKeys,
)
}
// The given userId is marked as unlocking
assertEquals(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = setOf(userId),
),
vaultRepository.vaultStateFlow.value,
)
// Does nothing because we are blocking
vaultRepository.sync()
scope.cancel()