Save the current environment to UserStateJson when logging in (#260)

This commit is contained in:
Brian Yencho 2023-11-20 12:47:34 -06:00 committed by Álison Fernandes
parent ef35477083
commit b168f6fb09
8 changed files with 131 additions and 11 deletions

View file

@ -22,6 +22,7 @@ import com.x8bit.bitwarden.data.auth.repository.util.toUserState
import com.x8bit.bitwarden.data.auth.util.KdfParamsConstants.DEFAULT_PBKDF2_ITERATIONS import com.x8bit.bitwarden.data.auth.util.KdfParamsConstants.DEFAULT_PBKDF2_ITERATIONS
import com.x8bit.bitwarden.data.auth.util.toSdkParams import com.x8bit.bitwarden.data.auth.util.toSdkParams
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository
import com.x8bit.bitwarden.data.platform.util.asFailure import com.x8bit.bitwarden.data.platform.util.asFailure
import com.x8bit.bitwarden.data.platform.util.asSuccess import com.x8bit.bitwarden.data.platform.util.asSuccess
import com.x8bit.bitwarden.data.platform.util.flatMap import com.x8bit.bitwarden.data.platform.util.flatMap
@ -47,6 +48,7 @@ class AuthRepositoryImpl constructor(
private val identityService: IdentityService, private val identityService: IdentityService,
private val authSdkSource: AuthSdkSource, private val authSdkSource: AuthSdkSource,
private val authDiskSource: AuthDiskSource, private val authDiskSource: AuthDiskSource,
private val environmentRepository: EnvironmentRepository,
private val vaultRepository: VaultRepository, private val vaultRepository: VaultRepository,
dispatcherManager: DispatcherManager, dispatcherManager: DispatcherManager,
) : AuthRepository { ) : AuthRepository {
@ -126,6 +128,9 @@ class AuthRepositoryImpl constructor(
authDiskSource.userState = it authDiskSource.userState = it
.toUserState( .toUserState(
previousUserState = authDiskSource.userState, previousUserState = authDiskSource.userState,
environmentUrlData = environmentRepository
.environment
.environmentUrlData,
) )
.also { userState -> .also { userState ->
authDiskSource.storeUserKey( authDiskSource.storeUserKey(

View file

@ -8,6 +8,7 @@ import com.x8bit.bitwarden.data.auth.datasource.sdk.AuthSdkSource
import com.x8bit.bitwarden.data.auth.repository.AuthRepository import com.x8bit.bitwarden.data.auth.repository.AuthRepository
import com.x8bit.bitwarden.data.auth.repository.AuthRepositoryImpl import com.x8bit.bitwarden.data.auth.repository.AuthRepositoryImpl
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository
import com.x8bit.bitwarden.data.vault.repository.VaultRepository import com.x8bit.bitwarden.data.vault.repository.VaultRepository
import dagger.Module import dagger.Module
import dagger.Provides import dagger.Provides
@ -32,6 +33,7 @@ object AuthRepositoryModule {
authSdkSource: AuthSdkSource, authSdkSource: AuthSdkSource,
authDiskSource: AuthDiskSource, authDiskSource: AuthDiskSource,
dispatchers: DispatcherManager, dispatchers: DispatcherManager,
environmentRepository: EnvironmentRepository,
vaultRepository: VaultRepository, vaultRepository: VaultRepository,
): AuthRepository = AuthRepositoryImpl( ): AuthRepository = AuthRepositoryImpl(
accountsService = accountsService, accountsService = accountsService,
@ -40,6 +42,7 @@ object AuthRepositoryModule {
authDiskSource = authDiskSource, authDiskSource = authDiskSource,
haveIBeenPwnedService = haveIBeenPwnedService, haveIBeenPwnedService = haveIBeenPwnedService,
dispatcherManager = dispatchers, dispatcherManager = dispatchers,
environmentRepository = environmentRepository,
vaultRepository = vaultRepository, vaultRepository = vaultRepository,
) )
} }

View file

@ -1,15 +1,20 @@
package com.x8bit.bitwarden.data.auth.repository.util package com.x8bit.bitwarden.data.auth.repository.util
import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson
/** /**
* Converts the given [GetTokenResponseJson.Success] to a [UserStateJson], given the current * Converts the given [GetTokenResponseJson.Success] to a [UserStateJson], given the following
* [previousUserState]. * additional information:
*
* - the [previousUserState]
* - the current [environmentUrlData]
*/ */
fun GetTokenResponseJson.Success.toUserState( fun GetTokenResponseJson.Success.toUserState(
previousUserState: UserStateJson?, previousUserState: UserStateJson?,
environmentUrlData: EnvironmentUrlDataJson,
): UserStateJson { ): UserStateJson {
val accessToken = this.accessToken val accessToken = this.accessToken
@ -40,7 +45,7 @@ fun GetTokenResponseJson.Success.toUserState(
refreshToken = this.refreshToken, refreshToken = this.refreshToken,
), ),
settings = AccountJson.Settings( settings = AccountJson.Settings(
environmentUrlData = null, environmentUrlData = environmentUrlData,
), ),
) )

View file

@ -1,5 +1,6 @@
package com.x8bit.bitwarden.data.platform.repository package com.x8bit.bitwarden.data.platform.repository
import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource
import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson
import com.x8bit.bitwarden.data.platform.datasource.disk.EnvironmentDiskSource import com.x8bit.bitwarden.data.platform.datasource.disk.EnvironmentDiskSource
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
@ -8,7 +9,9 @@ import com.x8bit.bitwarden.data.platform.repository.util.toEnvironmentUrls
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.stateIn import kotlinx.coroutines.flow.stateIn
/** /**
@ -16,6 +19,7 @@ import kotlinx.coroutines.flow.stateIn
*/ */
class EnvironmentRepositoryImpl( class EnvironmentRepositoryImpl(
private val environmentDiskSource: EnvironmentDiskSource, private val environmentDiskSource: EnvironmentDiskSource,
private val authDiskSource: AuthDiskSource,
dispatcherManager: DispatcherManager, dispatcherManager: DispatcherManager,
) : EnvironmentRepository { ) : EnvironmentRepository {
@ -38,6 +42,20 @@ class EnvironmentRepositoryImpl(
started = SharingStarted.Lazily, started = SharingStarted.Lazily,
initialValue = Environment.Us, initialValue = Environment.Us,
) )
init {
authDiskSource
.userStateFlow
.onEach { userState ->
// If the active account has environment data, set that as the current value.
userState
?.activeAccount
?.settings
?.environmentUrlData
?.let { environmentDiskSource.preAuthEnvironmentUrlData = it }
}
.launchIn(scope)
}
} }
/** /**

View file

@ -1,5 +1,6 @@
package com.x8bit.bitwarden.data.platform.repository.di package com.x8bit.bitwarden.data.platform.repository.di
import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource
import com.x8bit.bitwarden.data.platform.datasource.disk.EnvironmentDiskSource import com.x8bit.bitwarden.data.platform.datasource.disk.EnvironmentDiskSource
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository
@ -21,10 +22,12 @@ object PlatformRepositoryModule {
@Singleton @Singleton
fun provideEnvironmentRepository( fun provideEnvironmentRepository(
environmentDiskSource: EnvironmentDiskSource, environmentDiskSource: EnvironmentDiskSource,
authDiskSource: AuthDiskSource,
dispatcherManager: DispatcherManager, dispatcherManager: DispatcherManager,
): EnvironmentRepository = ): EnvironmentRepository =
EnvironmentRepositoryImpl( EnvironmentRepositoryImpl(
environmentDiskSource = environmentDiskSource, environmentDiskSource = environmentDiskSource,
authDiskSource = authDiskSource,
dispatcherManager = dispatcherManager, dispatcherManager = dispatcherManager,
) )
} }

View file

@ -5,6 +5,7 @@ import com.bitwarden.core.Kdf
import com.bitwarden.core.RegisterKeyResponse import com.bitwarden.core.RegisterKeyResponse
import com.bitwarden.core.RsaKeyPair import com.bitwarden.core.RsaKeyPair
import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.disk.util.FakeAuthDiskSource import com.x8bit.bitwarden.data.auth.datasource.disk.util.FakeAuthDiskSource
import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson
@ -31,6 +32,8 @@ import com.x8bit.bitwarden.data.auth.repository.util.toUserState
import com.x8bit.bitwarden.data.auth.util.toSdkParams import com.x8bit.bitwarden.data.auth.util.toSdkParams
import com.x8bit.bitwarden.data.platform.base.FakeDispatcherManager import com.x8bit.bitwarden.data.platform.base.FakeDispatcherManager
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
import com.x8bit.bitwarden.data.platform.repository.model.Environment
import com.x8bit.bitwarden.data.platform.repository.util.FakeEnvironmentRepository
import com.x8bit.bitwarden.data.platform.util.asFailure import com.x8bit.bitwarden.data.platform.util.asFailure
import com.x8bit.bitwarden.data.platform.util.asSuccess import com.x8bit.bitwarden.data.platform.util.asSuccess
import com.x8bit.bitwarden.data.vault.repository.VaultRepository import com.x8bit.bitwarden.data.vault.repository.VaultRepository
@ -59,6 +62,11 @@ class AuthRepositoryTest {
private val haveIBeenPwnedService: HaveIBeenPwnedService = mockk() private val haveIBeenPwnedService: HaveIBeenPwnedService = mockk()
private val vaultRepository: VaultRepository = mockk() private val vaultRepository: VaultRepository = mockk()
private val fakeAuthDiskSource = FakeAuthDiskSource() private val fakeAuthDiskSource = FakeAuthDiskSource()
private val fakeEnvironmentRepository =
FakeEnvironmentRepository()
.apply {
environment = Environment.Us
}
private val authSdkSource = mockk<AuthSdkSource> { private val authSdkSource = mockk<AuthSdkSource> {
coEvery { coEvery {
hashPassword( hashPassword(
@ -91,8 +99,9 @@ class AuthRepositoryTest {
haveIBeenPwnedService = haveIBeenPwnedService, haveIBeenPwnedService = haveIBeenPwnedService,
authSdkSource = authSdkSource, authSdkSource = authSdkSource,
authDiskSource = fakeAuthDiskSource, authDiskSource = fakeAuthDiskSource,
dispatcherManager = dispatcherManager, environmentRepository = fakeEnvironmentRepository,
vaultRepository = vaultRepository, vaultRepository = vaultRepository,
dispatcherManager = dispatcherManager,
) )
@BeforeEach @BeforeEach
@ -278,7 +287,10 @@ class AuthRepositoryTest {
vaultRepository.unlockVaultAndSync(masterPassword = PASSWORD) vaultRepository.unlockVaultAndSync(masterPassword = PASSWORD)
} returns VaultUnlockResult.Success } returns VaultUnlockResult.Success
every { every {
GET_TOKEN_RESPONSE_SUCCESS.toUserState(previousUserState = null) GET_TOKEN_RESPONSE_SUCCESS.toUserState(
previousUserState = null,
environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US,
)
} returns SINGLE_USER_STATE_1 } returns SINGLE_USER_STATE_1
val result = repository.login(email = EMAIL, password = PASSWORD, captchaToken = null) val result = repository.login(email = EMAIL, password = PASSWORD, captchaToken = null)
assertEquals(LoginResult.Success, result) assertEquals(LoginResult.Success, result)
@ -684,7 +696,10 @@ class AuthRepositoryTest {
vaultRepository.unlockVaultAndSync(masterPassword = PASSWORD) vaultRepository.unlockVaultAndSync(masterPassword = PASSWORD)
} returns VaultUnlockResult.Success } returns VaultUnlockResult.Success
every { every {
GET_TOKEN_RESPONSE_SUCCESS.toUserState(previousUserState = null) GET_TOKEN_RESPONSE_SUCCESS.toUserState(
previousUserState = null,
environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US,
)
} returns SINGLE_USER_STATE_1 } returns SINGLE_USER_STATE_1
repository.login(email = EMAIL, password = PASSWORD, captchaToken = null) repository.login(email = EMAIL, password = PASSWORD, captchaToken = null)
@ -733,7 +748,10 @@ class AuthRepositoryTest {
vaultRepository.unlockVaultAndSync(masterPassword = PASSWORD) vaultRepository.unlockVaultAndSync(masterPassword = PASSWORD)
} returns VaultUnlockResult.Success } returns VaultUnlockResult.Success
every { every {
GET_TOKEN_RESPONSE_SUCCESS.toUserState(previousUserState = SINGLE_USER_STATE_2) GET_TOKEN_RESPONSE_SUCCESS.toUserState(
previousUserState = SINGLE_USER_STATE_2,
environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US,
)
} returns MULTI_USER_STATE } returns MULTI_USER_STATE
repository.login(email = EMAIL, password = PASSWORD, captchaToken = null) repository.login(email = EMAIL, password = PASSWORD, captchaToken = null)

View file

@ -1,6 +1,7 @@
package com.x8bit.bitwarden.data.auth.repository.util package com.x8bit.bitwarden.data.auth.repository.util
import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson
@ -31,7 +32,10 @@ class GetTokenResponseExtensionsTest {
assertEquals( assertEquals(
SINGLE_USER_STATE_1, SINGLE_USER_STATE_1,
GET_TOKEN_RESPONSE_SUCCESS.toUserState(previousUserState = null), GET_TOKEN_RESPONSE_SUCCESS.toUserState(
previousUserState = null,
environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US,
),
) )
} }
@ -41,7 +45,10 @@ class GetTokenResponseExtensionsTest {
assertEquals( assertEquals(
MULTI_USER_STATE, MULTI_USER_STATE,
GET_TOKEN_RESPONSE_SUCCESS.toUserState(previousUserState = SINGLE_USER_STATE_2), GET_TOKEN_RESPONSE_SUCCESS.toUserState(
previousUserState = SINGLE_USER_STATE_2,
environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US,
),
) )
} }
} }
@ -102,7 +109,7 @@ private val ACCOUNT_1 = AccountJson(
refreshToken = "refreshToken", refreshToken = "refreshToken",
), ),
settings = AccountJson.Settings( settings = AccountJson.Settings(
environmentUrlData = null, environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US,
), ),
) )
private val ACCOUNT_2 = AccountJson( private val ACCOUNT_2 = AccountJson(
@ -127,7 +134,7 @@ private val ACCOUNT_2 = AccountJson(
refreshToken = "refreshToken", refreshToken = "refreshToken",
), ),
settings = AccountJson.Settings( settings = AccountJson.Settings(
environmentUrlData = null, environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US,
), ),
) )
private val SINGLE_USER_STATE_1 = UserStateJson( private val SINGLE_USER_STATE_1 = UserStateJson(

View file

@ -1,7 +1,10 @@
package com.x8bit.bitwarden.data.platform.repository package com.x8bit.bitwarden.data.platform.repository
import app.cash.turbine.test import app.cash.turbine.test
import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.disk.util.FakeAuthDiskSource
import com.x8bit.bitwarden.data.platform.base.FakeDispatcherManager import com.x8bit.bitwarden.data.platform.base.FakeDispatcherManager
import com.x8bit.bitwarden.data.platform.datasource.disk.EnvironmentDiskSource import com.x8bit.bitwarden.data.platform.datasource.disk.EnvironmentDiskSource
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
@ -26,9 +29,11 @@ class EnvironmentRepositoryTest {
private val dispatcherManager: DispatcherManager = FakeDispatcherManager() private val dispatcherManager: DispatcherManager = FakeDispatcherManager()
private val fakeEnvironmentDiskSource = FakeEnvironmentDiskSource() private val fakeEnvironmentDiskSource = FakeEnvironmentDiskSource()
private val fakeAuthDiskSource = FakeAuthDiskSource()
private val repository = EnvironmentRepositoryImpl( private val repository = EnvironmentRepositoryImpl(
environmentDiskSource = fakeEnvironmentDiskSource, environmentDiskSource = fakeEnvironmentDiskSource,
authDiskSource = fakeAuthDiskSource,
dispatcherManager = dispatcherManager, dispatcherManager = dispatcherManager,
) )
@ -42,6 +47,46 @@ class EnvironmentRepositoryTest {
unmockkStatic(ENVIRONMENT_EXTENSIONS_PATH) unmockkStatic(ENVIRONMENT_EXTENSIONS_PATH)
} }
@Test
fun `changes to the active user should update the environment if necessary`() {
assertEquals(
Environment.Us,
repository.environment,
)
assertEquals(
null,
fakeEnvironmentDiskSource.preAuthEnvironmentUrlData,
)
// Updating the environment for the active user to a non-null value triggers an update
// in the saved environment.
fakeAuthDiskSource.userState = getMockUserState(
environmentForActiveUser = EnvironmentUrlDataJson.DEFAULT_EU,
)
assertEquals(
Environment.Eu,
repository.environment,
)
assertEquals(
EnvironmentUrlDataJson.DEFAULT_EU,
fakeEnvironmentDiskSource.preAuthEnvironmentUrlData,
)
// Updating the environment for the active user to a null value leaves the current
// environment unchanged.
fakeAuthDiskSource.userState = getMockUserState(
environmentForActiveUser = null,
)
assertEquals(
Environment.Eu,
repository.environment,
)
assertEquals(
EnvironmentUrlDataJson.DEFAULT_EU,
fakeEnvironmentDiskSource.preAuthEnvironmentUrlData,
)
}
@Test @Test
fun `environment should pull from and update EnvironmentDiskSource`() { fun `environment should pull from and update EnvironmentDiskSource`() {
val environmentUrlDataJson = mockk<EnvironmentUrlDataJson>() val environmentUrlDataJson = mockk<EnvironmentUrlDataJson>()
@ -101,6 +146,22 @@ class EnvironmentRepositoryTest {
assertEquals(environment, awaitItem()) assertEquals(environment, awaitItem())
} }
} }
private fun getMockUserState(
environmentForActiveUser: EnvironmentUrlDataJson?,
): UserStateJson =
UserStateJson(
activeUserId = "activeUserId",
accounts = mapOf(
"activeUserId" to AccountJson(
profile = mockk(),
tokens = mockk(),
settings = AccountJson.Settings(
environmentUrlData = environmentForActiveUser,
),
),
),
)
} }
private const val ENVIRONMENT_EXTENSIONS_PATH = private const val ENVIRONMENT_EXTENSIONS_PATH =