mirror of
https://github.com/bitwarden/android.git
synced 2025-03-15 10:48:47 +03:00
Add VaultState.unlockingVaultUserIds and clean up the vault unlock logic (#546)
This commit is contained in:
parent
940979599e
commit
d95e5df2a7
5 changed files with 169 additions and 67 deletions
|
@ -91,14 +91,17 @@ class VaultRepositoryImpl(
|
|||
|
||||
private var syncJob: Job = Job().apply { complete() }
|
||||
|
||||
private var willSyncAfterUnlock = false
|
||||
|
||||
private val activeUserId: String? get() = authDiskSource.userState?.activeUserId
|
||||
|
||||
private val mutableTotpCodeFlow = bufferedMutableSharedFlow<String>()
|
||||
|
||||
private val mutableVaultStateStateFlow =
|
||||
MutableStateFlow(VaultState(unlockedVaultUserIds = emptySet()))
|
||||
MutableStateFlow(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
)
|
||||
|
||||
private val mutableSendDataStateFlow = MutableStateFlow<DataState<SendData>>(DataState.Loading)
|
||||
|
||||
|
@ -194,8 +197,8 @@ class VaultRepositoryImpl(
|
|||
}
|
||||
|
||||
override fun sync() {
|
||||
if (!syncJob.isCompleted || willSyncAfterUnlock) return
|
||||
val userId = activeUserId ?: return
|
||||
if (!syncJob.isCompleted || isVaultUnlocking(userId)) return
|
||||
mutableCiphersStateFlow.updateToPendingOrLoading()
|
||||
mutableFoldersStateFlow.updateToPendingOrLoading()
|
||||
mutableCollectionsStateFlow.updateToPendingOrLoading()
|
||||
|
@ -290,22 +293,15 @@ class VaultRepositoryImpl(
|
|||
override suspend fun unlockVaultAndSyncForCurrentUser(
|
||||
masterPassword: String,
|
||||
): VaultUnlockResult {
|
||||
val userState = authDiskSource.userState
|
||||
val userId = activeUserId ?: return VaultUnlockResult.InvalidStateError
|
||||
val userKey = authDiskSource.getUserKey(userId = userId)
|
||||
?: return VaultUnlockResult.InvalidStateError
|
||||
val userKey = authDiskSource.getUserKey(userId = userState.activeUserId)
|
||||
?: return VaultUnlockResult.InvalidStateError
|
||||
val privateKey = authDiskSource.getPrivateKey(userId = userState.activeUserId)
|
||||
?: return VaultUnlockResult.InvalidStateError
|
||||
val organizationKeys = authDiskSource
|
||||
.getOrganizationKeys(userId = userState.activeUserId)
|
||||
return unlockVault(
|
||||
userId = userState.activeUserId,
|
||||
masterPassword = masterPassword,
|
||||
email = userState.activeAccount.profile.email,
|
||||
kdf = userState.activeAccount.profile.toSdkParams(),
|
||||
userKey = userKey,
|
||||
privateKey = privateKey,
|
||||
organizationKeys = organizationKeys,
|
||||
return unlockVaultForUser(
|
||||
userId = userId,
|
||||
initUserCryptoMethod = InitUserCryptoMethod.Password(
|
||||
password = masterPassword,
|
||||
userKey = userKey,
|
||||
),
|
||||
)
|
||||
.also {
|
||||
if (it is VaultUnlockResult.Success) {
|
||||
|
@ -323,53 +319,17 @@ class VaultRepositoryImpl(
|
|||
privateKey: String,
|
||||
organizationKeys: Map<String, String>?,
|
||||
): VaultUnlockResult =
|
||||
flow {
|
||||
willSyncAfterUnlock = true
|
||||
emit(
|
||||
vaultSdkSource
|
||||
.initializeCrypto(
|
||||
userId = userId,
|
||||
request = InitUserCryptoRequest(
|
||||
kdfParams = kdf,
|
||||
email = email,
|
||||
privateKey = privateKey,
|
||||
method = InitUserCryptoMethod.Password(
|
||||
password = masterPassword,
|
||||
userKey = userKey,
|
||||
),
|
||||
),
|
||||
)
|
||||
.flatMap { result ->
|
||||
// Initialize the SDK for organizations if necessary
|
||||
if (organizationKeys != null &&
|
||||
result is InitializeCryptoResult.Success
|
||||
) {
|
||||
vaultSdkSource.initializeOrganizationCrypto(
|
||||
userId = userId,
|
||||
request = InitOrgCryptoRequest(
|
||||
organizationKeys = organizationKeys,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
result.asSuccess()
|
||||
}
|
||||
}
|
||||
.fold(
|
||||
onFailure = { VaultUnlockResult.GenericError },
|
||||
onSuccess = { initializeCryptoResult ->
|
||||
initializeCryptoResult
|
||||
.toVaultUnlockResult()
|
||||
.also {
|
||||
if (it is VaultUnlockResult.Success) {
|
||||
setVaultToUnlocked(userId = userId)
|
||||
}
|
||||
}
|
||||
},
|
||||
),
|
||||
)
|
||||
}
|
||||
.onCompletion { willSyncAfterUnlock = false }
|
||||
.first()
|
||||
unlockVaultInternal(
|
||||
userId = userId,
|
||||
email = email,
|
||||
kdf = kdf,
|
||||
privateKey = privateKey,
|
||||
initUserCryptoMethod = InitUserCryptoMethod.Password(
|
||||
password = masterPassword,
|
||||
userKey = userKey,
|
||||
),
|
||||
organizationKeys = organizationKeys,
|
||||
)
|
||||
|
||||
override suspend fun createCipher(cipherView: CipherView): CreateCipherResult {
|
||||
val userId = requireNotNull(activeUserId)
|
||||
|
@ -506,6 +466,25 @@ class VaultRepositoryImpl(
|
|||
}
|
||||
}
|
||||
|
||||
private fun setVaultToUnlocking(userId: String) {
|
||||
mutableVaultStateStateFlow.update {
|
||||
it.copy(
|
||||
unlockingVaultUserIds = it.unlockingVaultUserIds + userId,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun setVaultToNotUnlocking(userId: String) {
|
||||
mutableVaultStateStateFlow.update {
|
||||
it.copy(
|
||||
unlockingVaultUserIds = it.unlockingVaultUserIds - userId,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun isVaultUnlocking(userId: String) =
|
||||
userId in mutableVaultStateStateFlow.value.unlockingVaultUserIds
|
||||
|
||||
private fun storeProfileData(
|
||||
syncResponse: SyncResponseJson,
|
||||
) {
|
||||
|
@ -536,6 +515,80 @@ class VaultRepositoryImpl(
|
|||
}
|
||||
}
|
||||
|
||||
@Suppress("ReturnCount")
|
||||
private suspend fun unlockVaultForUser(
|
||||
userId: String,
|
||||
initUserCryptoMethod: InitUserCryptoMethod,
|
||||
): VaultUnlockResult {
|
||||
val account = authDiskSource.userState?.accounts?.get(userId)
|
||||
?: return VaultUnlockResult.InvalidStateError
|
||||
val privateKey = authDiskSource.getPrivateKey(userId = userId)
|
||||
?: return VaultUnlockResult.InvalidStateError
|
||||
val organizationKeys = authDiskSource
|
||||
.getOrganizationKeys(userId = userId)
|
||||
return unlockVaultInternal(
|
||||
userId = userId,
|
||||
email = account.profile.email,
|
||||
kdf = account.profile.toSdkParams(),
|
||||
privateKey = privateKey,
|
||||
initUserCryptoMethod = initUserCryptoMethod,
|
||||
organizationKeys = organizationKeys,
|
||||
)
|
||||
}
|
||||
|
||||
private suspend fun unlockVaultInternal(
|
||||
userId: String,
|
||||
email: String,
|
||||
kdf: Kdf,
|
||||
privateKey: String,
|
||||
initUserCryptoMethod: InitUserCryptoMethod,
|
||||
organizationKeys: Map<String, String>?,
|
||||
): VaultUnlockResult =
|
||||
flow {
|
||||
setVaultToUnlocking(userId = userId)
|
||||
emit(
|
||||
vaultSdkSource
|
||||
.initializeCrypto(
|
||||
userId = userId,
|
||||
request = InitUserCryptoRequest(
|
||||
kdfParams = kdf,
|
||||
email = email,
|
||||
privateKey = privateKey,
|
||||
method = initUserCryptoMethod,
|
||||
),
|
||||
)
|
||||
.flatMap { result ->
|
||||
// Initialize the SDK for organizations if necessary
|
||||
if (organizationKeys != null &&
|
||||
result is InitializeCryptoResult.Success
|
||||
) {
|
||||
vaultSdkSource.initializeOrganizationCrypto(
|
||||
userId = userId,
|
||||
request = InitOrgCryptoRequest(
|
||||
organizationKeys = organizationKeys,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
result.asSuccess()
|
||||
}
|
||||
}
|
||||
.fold(
|
||||
onFailure = { VaultUnlockResult.GenericError },
|
||||
onSuccess = { initializeCryptoResult ->
|
||||
initializeCryptoResult
|
||||
.toVaultUnlockResult()
|
||||
.also {
|
||||
if (it is VaultUnlockResult.Success) {
|
||||
setVaultToUnlocked(userId = userId)
|
||||
}
|
||||
}
|
||||
},
|
||||
),
|
||||
)
|
||||
}
|
||||
.onCompletion { setVaultToNotUnlocking(userId = userId) }
|
||||
.first()
|
||||
|
||||
private suspend fun unlockVaultForOrganizationsIfNecessary(
|
||||
syncResponse: SyncResponseJson,
|
||||
) {
|
||||
|
|
|
@ -4,7 +4,10 @@ package com.x8bit.bitwarden.data.vault.repository.model
|
|||
* General description of the vault across multiple users.
|
||||
*
|
||||
* @property unlockedVaultUserIds The user IDs for all users that currently have unlocked vaults.
|
||||
* @property unlockedVaultUserIds The user IDs for all users that are actively unlocking their
|
||||
* vaults.
|
||||
*/
|
||||
data class VaultState(
|
||||
val unlockedVaultUserIds: Set<String>,
|
||||
val unlockingVaultUserIds: Set<String>,
|
||||
)
|
||||
|
|
|
@ -179,7 +179,10 @@ class AuthRepositoryTest {
|
|||
repository.userStateFlow.value,
|
||||
)
|
||||
|
||||
val emptyVaultState = VaultState(unlockedVaultUserIds = emptySet())
|
||||
val emptyVaultState = VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
)
|
||||
mutableVaultStateFlow.value = emptyVaultState
|
||||
assertEquals(
|
||||
MULTI_USER_STATE.toUserState(
|
||||
|
@ -1535,6 +1538,7 @@ class AuthRepositoryTest {
|
|||
)
|
||||
private val VAULT_STATE = VaultState(
|
||||
unlockedVaultUserIds = setOf(USER_ID_1),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -135,6 +135,7 @@ class UserStateJsonExtensionsTest {
|
|||
.toUserState(
|
||||
vaultState = VaultState(
|
||||
unlockedVaultUserIds = setOf("activeUserId"),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
userOrganizationsList = listOf(
|
||||
UserOrganizations(
|
||||
|
@ -198,6 +199,7 @@ class UserStateJsonExtensionsTest {
|
|||
.toUserState(
|
||||
vaultState = VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
userOrganizationsList = listOf(
|
||||
UserOrganizations(
|
||||
|
|
|
@ -538,6 +538,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = setOf(userId),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -547,6 +548,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -562,6 +564,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = setOf(userId),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -571,6 +574,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -634,6 +638,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -649,6 +654,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = setOf("mockId-1"),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -794,6 +800,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -809,6 +816,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -861,6 +869,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -876,6 +885,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -913,6 +923,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -925,6 +936,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -974,6 +986,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -986,6 +999,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -999,6 +1013,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1012,6 +1027,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1024,6 +1040,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1045,6 +1062,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1057,6 +1075,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1077,6 +1096,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1114,6 +1134,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1132,6 +1153,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = setOf(userId),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1186,6 +1208,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1204,6 +1227,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1258,6 +1282,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1276,6 +1301,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1328,6 +1354,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1346,6 +1373,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1398,6 +1426,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1416,6 +1445,7 @@ class VaultRepositoryTest {
|
|||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = emptySet(),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
@ -1477,6 +1507,16 @@ class VaultRepositoryTest {
|
|||
organizationKeys = organizationKeys,
|
||||
)
|
||||
}
|
||||
|
||||
// The given userId is marked as unlocking
|
||||
assertEquals(
|
||||
VaultState(
|
||||
unlockedVaultUserIds = emptySet(),
|
||||
unlockingVaultUserIds = setOf(userId),
|
||||
),
|
||||
vaultRepository.vaultStateFlow.value,
|
||||
)
|
||||
|
||||
// Does nothing because we are blocking
|
||||
vaultRepository.sync()
|
||||
scope.cancel()
|
||||
|
|
Loading…
Add table
Reference in a new issue