Force the UserState to re-evaluate on authentication change (#1291)

This commit is contained in:
David Perez 2024-04-22 09:32:30 -05:00 committed by Álison Fernandes
parent e6dfaeeab2
commit 77a7cb0e51
7 changed files with 244 additions and 27 deletions

View file

@ -50,6 +50,8 @@ import com.x8bit.bitwarden.data.auth.repository.model.ResendEmailResult
import com.x8bit.bitwarden.data.auth.repository.model.ResetPasswordResult
import com.x8bit.bitwarden.data.auth.repository.model.SetPasswordResult
import com.x8bit.bitwarden.data.auth.repository.model.SwitchAccountResult
import com.x8bit.bitwarden.data.auth.repository.model.UserAccountTokens
import com.x8bit.bitwarden.data.auth.repository.model.UserOrganizations
import com.x8bit.bitwarden.data.auth.repository.model.UserState
import com.x8bit.bitwarden.data.auth.repository.model.ValidatePasswordResult
import com.x8bit.bitwarden.data.auth.repository.model.VaultUnlockType
@ -63,6 +65,8 @@ import com.x8bit.bitwarden.data.auth.repository.util.toSdkParams
import com.x8bit.bitwarden.data.auth.repository.util.toUserState
import com.x8bit.bitwarden.data.auth.repository.util.toUserStateJson
import com.x8bit.bitwarden.data.auth.repository.util.toUserStateJsonWithPassword
import com.x8bit.bitwarden.data.auth.repository.util.userAccountTokens
import com.x8bit.bitwarden.data.auth.repository.util.userAccountTokensFlow
import com.x8bit.bitwarden.data.auth.repository.util.userOrganizationsList
import com.x8bit.bitwarden.data.auth.repository.util.userOrganizationsListFlow
import com.x8bit.bitwarden.data.auth.util.KdfParamsConstants.DEFAULT_PBKDF2_ITERATIONS
@ -82,6 +86,7 @@ import com.x8bit.bitwarden.data.vault.datasource.network.model.PolicyTypeJson
import com.x8bit.bitwarden.data.vault.datasource.network.model.SyncResponseJson
import com.x8bit.bitwarden.data.vault.datasource.sdk.VaultSdkSource
import com.x8bit.bitwarden.data.vault.repository.VaultRepository
import com.x8bit.bitwarden.data.vault.repository.model.VaultUnlockData
import com.x8bit.bitwarden.data.vault.repository.model.VaultUnlockResult
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
@ -214,8 +219,10 @@ class AuthRepositoryImpl(
initialValue = AuthState.Uninitialized,
)
@Suppress("UNCHECKED_CAST", "MagicNumber")
override val userStateFlow: StateFlow<UserState?> = combine(
authDiskSource.userStateFlow,
authDiskSource.userAccountTokensFlow,
authDiskSource.userOrganizationsListFlow,
vaultRepository.vaultUnlockDataStateFlow,
mutableHasPendingAccountAdditionStateFlow,
@ -224,20 +231,19 @@ class AuthRepositoryImpl(
mutableHasPendingAccountDeletionStateFlow,
mutableUserStateTransactionCountStateFlow,
),
) {
userStateJson,
userOrganizationsList,
vaultState,
hasPendingAccountAddition,
_,
->
) { array ->
val userStateJson = array[0] as UserStateJson?
val userAccountTokens = array[1] as List<UserAccountTokens>
val userOrganizationsList = array[2] as List<UserOrganizations>
val vaultState = array[3] as List<VaultUnlockData>
val hasPendingAccountAddition = array[4] as Boolean
userStateJson?.toUserState(
vaultState = vaultState,
userAccountTokens = userAccountTokens,
userOrganizationsList = userOrganizationsList,
hasPendingAccountAddition = hasPendingAccountAddition,
isBiometricsEnabledProvider = ::isBiometricsEnabled,
vaultUnlockTypeProvider = ::getVaultUnlockType,
isLoggedInProvider = ::isUserLoggedIn,
isDeviceTrustedProvider = ::isDeviceTrusted,
)
}
@ -250,11 +256,11 @@ class AuthRepositoryImpl(
.userState
?.toUserState(
vaultState = vaultRepository.vaultUnlockDataStateFlow.value,
userAccountTokens = authDiskSource.userAccountTokens,
userOrganizationsList = authDiskSource.userOrganizationsList,
hasPendingAccountAddition = mutableHasPendingAccountAdditionStateFlow.value,
isBiometricsEnabledProvider = ::isBiometricsEnabled,
vaultUnlockTypeProvider = ::getVaultUnlockType,
isLoggedInProvider = ::isUserLoggedIn,
isDeviceTrustedProvider = ::isDeviceTrusted,
),
)
@ -1136,10 +1142,6 @@ class AuthRepositoryImpl(
userId: String,
): Boolean = authDiskSource.getDeviceKey(userId = userId) != null
private fun isUserLoggedIn(
userId: String,
): Boolean = authDiskSource.getAccountTokens(userId = userId)?.isLoggedIn == true
private fun getVaultUnlockType(
userId: String,
): VaultUnlockType =

View file

@ -0,0 +1,15 @@
package com.x8bit.bitwarden.data.auth.repository.model
/**
* Associates the [accessToken] and [refreshToken] with the given [userId].
*/
data class UserAccountTokens(
val userId: String,
val accessToken: String?,
val refreshToken: String?,
) {
/**
* Returns `true` if the user is logged in, `false otherwise.
*/
val isLoggedIn: Boolean get() = accessToken != null
}

View file

@ -1,6 +1,7 @@
package com.x8bit.bitwarden.data.auth.repository.util
import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource
import com.x8bit.bitwarden.data.auth.repository.model.UserAccountTokens
import com.x8bit.bitwarden.data.auth.repository.model.UserOrganizations
import com.x8bit.bitwarden.data.auth.repository.model.UserSwitchingData
import kotlinx.coroutines.ExperimentalCoroutinesApi
@ -55,6 +56,50 @@ val AuthDiskSource.userOrganizationsListFlow: Flow<List<UserOrganizations>>
}
.distinctUntilChanged()
/**
* Returns the current list of [UserAccountTokens].
*/
val AuthDiskSource.userAccountTokens: List<UserAccountTokens>
get() = this
.userState
?.accounts
.orEmpty()
.map { (userId, _) ->
val accountTokens = this.getAccountTokens(userId = userId)
UserAccountTokens(
userId = userId,
accessToken = accountTokens?.accessToken,
refreshToken = accountTokens?.refreshToken,
)
}
/**
* Returns a [Flow] that emits distinct updates to [UserAccountTokens].
*/
@OptIn(ExperimentalCoroutinesApi::class)
val AuthDiskSource.userAccountTokensFlow: Flow<List<UserAccountTokens>>
get() = this
.userStateFlow
.flatMapLatest { userStateJson ->
combine(
userStateJson
?.accounts
.orEmpty()
.map { (userId, _) ->
this
.getAccountTokensFlow(userId = userId)
.map {
UserAccountTokens(
userId = userId,
accessToken = it?.accessToken,
refreshToken = it?.refreshToken,
)
}
},
) { it.toList() }
}
.distinctUntilChanged()
/**
* Returns a [Flow] that emits every time the active user is changed.
*/

View file

@ -2,6 +2,7 @@ package com.x8bit.bitwarden.data.auth.repository.util
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.UserDecryptionOptionsJson
import com.x8bit.bitwarden.data.auth.repository.model.UserAccountTokens
import com.x8bit.bitwarden.data.auth.repository.model.UserOrganizations
import com.x8bit.bitwarden.data.auth.repository.model.UserState
import com.x8bit.bitwarden.data.auth.repository.model.VaultUnlockType
@ -77,11 +78,11 @@ fun UserStateJson.toUserStateJsonWithPassword(): UserStateJson {
@Suppress("LongParameterList")
fun UserStateJson.toUserState(
vaultState: List<VaultUnlockData>,
userAccountTokens: List<UserAccountTokens>,
userOrganizationsList: List<UserOrganizations>,
hasPendingAccountAddition: Boolean,
isBiometricsEnabledProvider: (userId: String) -> Boolean,
vaultUnlockTypeProvider: (userId: String) -> VaultUnlockType,
isLoggedInProvider: (userId: String) -> Boolean,
isDeviceTrustedProvider: (userId: String) -> Boolean,
): UserState =
UserState(
@ -120,7 +121,9 @@ fun UserStateJson.toUserState(
.environmentUrlData
.toEnvironmentUrlsOrDefault(),
isPremium = profile.hasPremium == true,
isLoggedIn = isLoggedInProvider(userId),
isLoggedIn = userAccountTokens
.find { it.userId == userId }
?.isLoggedIn == true,
isVaultUnlocked = vaultUnlocked,
needsPasswordReset = needsPasswordReset,
organizations = userOrganizationsList

View file

@ -312,11 +312,11 @@ class AuthRepositoryTest {
assertEquals(
SINGLE_USER_STATE_1.toUserState(
vaultState = VAULT_UNLOCK_DATA,
userAccountTokens = emptyList(),
userOrganizationsList = emptyList(),
hasPendingAccountAddition = false,
isBiometricsEnabledProvider = { false },
vaultUnlockTypeProvider = { VaultUnlockType.MASTER_PASSWORD },
isLoggedInProvider = { false },
isDeviceTrustedProvider = { false },
),
repository.userStateFlow.value,
@ -336,11 +336,11 @@ class AuthRepositoryTest {
assertEquals(
MULTI_USER_STATE.toUserState(
vaultState = VAULT_UNLOCK_DATA,
userAccountTokens = emptyList(),
userOrganizationsList = emptyList(),
hasPendingAccountAddition = false,
isBiometricsEnabledProvider = { false },
vaultUnlockTypeProvider = { VaultUnlockType.PIN },
isLoggedInProvider = { false },
isDeviceTrustedProvider = { false },
),
repository.userStateFlow.value,
@ -351,11 +351,11 @@ class AuthRepositoryTest {
assertEquals(
MULTI_USER_STATE.toUserState(
vaultState = emptyVaultState,
userAccountTokens = emptyList(),
userOrganizationsList = emptyList(),
hasPendingAccountAddition = false,
isBiometricsEnabledProvider = { false },
vaultUnlockTypeProvider = { VaultUnlockType.PIN },
isLoggedInProvider = { false },
isDeviceTrustedProvider = { false },
),
repository.userStateFlow.value,
@ -378,11 +378,11 @@ class AuthRepositoryTest {
assertEquals(
MULTI_USER_STATE.toUserState(
vaultState = emptyVaultState,
userAccountTokens = emptyList(),
userOrganizationsList = USER_ORGANIZATIONS,
hasPendingAccountAddition = false,
isBiometricsEnabledProvider = { false },
vaultUnlockTypeProvider = { VaultUnlockType.MASTER_PASSWORD },
isLoggedInProvider = { false },
isDeviceTrustedProvider = { false },
),
repository.userStateFlow.value,
@ -574,20 +574,20 @@ class AuthRepositoryTest {
val hashedMasterPassword = "dlrow olleh"
val originalUserState = SINGLE_USER_STATE_1.toUserState(
vaultState = VAULT_UNLOCK_DATA,
userAccountTokens = emptyList(),
userOrganizationsList = emptyList(),
hasPendingAccountAddition = false,
isBiometricsEnabledProvider = { false },
vaultUnlockTypeProvider = { VaultUnlockType.MASTER_PASSWORD },
isLoggedInProvider = { false },
isDeviceTrustedProvider = { false },
)
val finalUserState = SINGLE_USER_STATE_2.toUserState(
vaultState = VAULT_UNLOCK_DATA,
userAccountTokens = emptyList(),
userOrganizationsList = emptyList(),
hasPendingAccountAddition = false,
isBiometricsEnabledProvider = { false },
vaultUnlockTypeProvider = { VaultUnlockType.MASTER_PASSWORD },
isLoggedInProvider = { false },
isDeviceTrustedProvider = { false },
)
val kdf = SINGLE_USER_STATE_1.activeAccount.profile.toSdkParams()
@ -4193,11 +4193,11 @@ class AuthRepositoryTest {
val originalUserId = USER_ID_1
val originalUserState = SINGLE_USER_STATE_1.toUserState(
vaultState = VAULT_UNLOCK_DATA,
userAccountTokens = emptyList(),
userOrganizationsList = emptyList(),
hasPendingAccountAddition = false,
isBiometricsEnabledProvider = { false },
vaultUnlockTypeProvider = { VaultUnlockType.MASTER_PASSWORD },
isLoggedInProvider = { false },
isDeviceTrustedProvider = { false },
)
fakeAuthDiskSource.userState = SINGLE_USER_STATE_1
@ -4225,11 +4225,11 @@ class AuthRepositoryTest {
val invalidId = "invalidId"
val originalUserState = SINGLE_USER_STATE_1.toUserState(
vaultState = VAULT_UNLOCK_DATA,
userAccountTokens = emptyList(),
userOrganizationsList = emptyList(),
hasPendingAccountAddition = false,
isBiometricsEnabledProvider = { false },
vaultUnlockTypeProvider = { VaultUnlockType.MASTER_PASSWORD },
isLoggedInProvider = { false },
isDeviceTrustedProvider = { false },
)
fakeAuthDiskSource.userState = SINGLE_USER_STATE_1
@ -4255,11 +4255,11 @@ class AuthRepositoryTest {
val updatedUserId = USER_ID_2
val originalUserState = MULTI_USER_STATE.toUserState(
vaultState = VAULT_UNLOCK_DATA,
userAccountTokens = emptyList(),
userOrganizationsList = emptyList(),
hasPendingAccountAddition = false,
isBiometricsEnabledProvider = { false },
vaultUnlockTypeProvider = { VaultUnlockType.MASTER_PASSWORD },
isLoggedInProvider = { false },
isDeviceTrustedProvider = { false },
)
fakeAuthDiskSource.userState = MULTI_USER_STATE

View file

@ -7,6 +7,7 @@ import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountTokensJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.disk.util.FakeAuthDiskSource
import com.x8bit.bitwarden.data.auth.repository.model.Organization
import com.x8bit.bitwarden.data.auth.repository.model.UserAccountTokens
import com.x8bit.bitwarden.data.auth.repository.model.UserOrganizations
import com.x8bit.bitwarden.data.auth.repository.model.UserSwitchingData
import com.x8bit.bitwarden.data.vault.datasource.network.model.createMockOrganization
@ -20,6 +21,138 @@ import org.junit.jupiter.api.Test
class AuthDiskSourceExtensionsTest {
private val authDiskSource: AuthDiskSource = FakeAuthDiskSource()
@Test
fun `userAccountTokens should return data for all available users`() {
val mockAccounts = mapOf(
"userId1" to mockk<AccountJson>(),
"userId2" to mockk<AccountJson>(),
"userId3" to mockk<AccountJson>(),
)
val userStateJson = mockk<UserStateJson> {
every { accounts } returns mockAccounts
}
authDiskSource.apply {
userState = userStateJson
storeAccountTokens(
userId = "userId1",
accountTokens = AccountTokensJson(
accessToken = "accessToken1",
refreshToken = "refreshToken1",
),
)
storeAccountTokens(
userId = "userId2",
accountTokens = AccountTokensJson(
accessToken = "accessToken2",
refreshToken = "refreshToken2",
),
)
storeAccountTokens(
userId = "userId3",
accountTokens = AccountTokensJson(
accessToken = null,
refreshToken = null,
),
)
}
assertEquals(
listOf(
UserAccountTokens(
userId = "userId1",
accessToken = "accessToken1",
refreshToken = "refreshToken1",
),
UserAccountTokens(
userId = "userId2",
accessToken = "accessToken2",
refreshToken = "refreshToken2",
),
UserAccountTokens(
userId = "userId3",
accessToken = null,
refreshToken = null,
),
),
authDiskSource.userAccountTokens,
)
}
@Test
fun `userAccountTokensFlow should emit whenever there are changes to the token data`() =
runTest {
val mockAccounts = mapOf(
"userId1" to mockk<AccountJson>(),
"userId2" to mockk<AccountJson>(),
"userId3" to mockk<AccountJson>(),
)
val userStateJson = mockk<UserStateJson> {
every { accounts } returns mockAccounts
}
authDiskSource.apply {
userState = userStateJson
storeAccountTokens(
userId = "userId1",
accountTokens = AccountTokensJson(
accessToken = "accessToken1",
refreshToken = "refreshToken1",
),
)
}
authDiskSource.userAccountTokensFlow.test {
assertEquals(
listOf(
UserAccountTokens(
userId = "userId1",
accessToken = "accessToken1",
refreshToken = "refreshToken1",
),
UserAccountTokens(
userId = "userId2",
accessToken = null,
refreshToken = null,
),
UserAccountTokens(
userId = "userId3",
accessToken = null,
refreshToken = null,
),
),
awaitItem(),
)
authDiskSource.storeAccountTokens(
userId = "userId2",
accountTokens = AccountTokensJson(
accessToken = "accessToken2",
refreshToken = "refreshToken2",
),
)
assertEquals(
listOf(
UserAccountTokens(
userId = "userId1",
accessToken = "accessToken1",
refreshToken = "refreshToken1",
),
UserAccountTokens(
userId = "userId2",
accessToken = "accessToken2",
refreshToken = "refreshToken2",
),
UserAccountTokens(
userId = "userId3",
accessToken = null,
refreshToken = null,
),
),
awaitItem(),
)
}
}
@Test
fun `userOrganizationsList should return data for all available users`() {
val mockAccounts = mapOf(

View file

@ -10,6 +10,7 @@ import com.x8bit.bitwarden.data.auth.datasource.network.model.KeyConnectorUserDe
import com.x8bit.bitwarden.data.auth.datasource.network.model.TrustedDeviceUserDecryptionOptionsJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.UserDecryptionOptionsJson
import com.x8bit.bitwarden.data.auth.repository.model.Organization
import com.x8bit.bitwarden.data.auth.repository.model.UserAccountTokens
import com.x8bit.bitwarden.data.auth.repository.model.UserOrganizations
import com.x8bit.bitwarden.data.auth.repository.model.UserState
import com.x8bit.bitwarden.data.auth.repository.model.VaultUnlockType
@ -271,6 +272,13 @@ class UserStateJsonExtensionsTest {
status = VaultUnlockData.Status.UNLOCKED,
),
),
userAccountTokens = listOf(
UserAccountTokens(
userId = "activeUserId",
accessToken = "accessToken",
refreshToken = "refreshToken",
),
),
userOrganizationsList = listOf(
UserOrganizations(
userId = "activeUserId",
@ -285,7 +293,6 @@ class UserStateJsonExtensionsTest {
hasPendingAccountAddition = false,
isBiometricsEnabledProvider = { false },
vaultUnlockTypeProvider = { VaultUnlockType.PIN },
isLoggedInProvider = { true },
isDeviceTrustedProvider = { false },
),
)
@ -351,6 +358,13 @@ class UserStateJsonExtensionsTest {
)
.toUserState(
vaultState = emptyList(),
userAccountTokens = listOf(
UserAccountTokens(
userId = "activeUserId",
accessToken = null,
refreshToken = null,
),
),
userOrganizationsList = listOf(
UserOrganizations(
userId = "activeUserId",
@ -365,7 +379,6 @@ class UserStateJsonExtensionsTest {
hasPendingAccountAddition = true,
isBiometricsEnabledProvider = { true },
vaultUnlockTypeProvider = { VaultUnlockType.MASTER_PASSWORD },
isLoggedInProvider = { false },
isDeviceTrustedProvider = { false },
),
)
@ -441,6 +454,13 @@ class UserStateJsonExtensionsTest {
)
.toUserState(
vaultState = emptyList(),
userAccountTokens = listOf(
UserAccountTokens(
userId = "activeUserId",
accessToken = null,
refreshToken = null,
),
),
userOrganizationsList = listOf(
UserOrganizations(
userId = "activeUserId",
@ -455,7 +475,6 @@ class UserStateJsonExtensionsTest {
hasPendingAccountAddition = true,
isBiometricsEnabledProvider = { false },
vaultUnlockTypeProvider = { VaultUnlockType.MASTER_PASSWORD },
isLoggedInProvider = { false },
isDeviceTrustedProvider = { true },
),
)