diff --git a/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt index d43f4bdf3..ffa8d045f 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt @@ -22,6 +22,7 @@ import com.x8bit.bitwarden.data.auth.repository.util.toUserState import com.x8bit.bitwarden.data.auth.util.KdfParamsConstants.DEFAULT_PBKDF2_ITERATIONS import com.x8bit.bitwarden.data.auth.util.toSdkParams import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager +import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository import com.x8bit.bitwarden.data.platform.util.asFailure import com.x8bit.bitwarden.data.platform.util.asSuccess import com.x8bit.bitwarden.data.platform.util.flatMap @@ -47,6 +48,7 @@ class AuthRepositoryImpl constructor( private val identityService: IdentityService, private val authSdkSource: AuthSdkSource, private val authDiskSource: AuthDiskSource, + private val environmentRepository: EnvironmentRepository, private val vaultRepository: VaultRepository, dispatcherManager: DispatcherManager, ) : AuthRepository { @@ -126,6 +128,9 @@ class AuthRepositoryImpl constructor( authDiskSource.userState = it .toUserState( previousUserState = authDiskSource.userState, + environmentUrlData = environmentRepository + .environment + .environmentUrlData, ) .also { userState -> authDiskSource.storeUserKey( diff --git a/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/di/AuthRepositoryModule.kt b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/di/AuthRepositoryModule.kt index 31a3baf1b..106c4b17e 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/di/AuthRepositoryModule.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/di/AuthRepositoryModule.kt @@ -8,6 +8,7 @@ import com.x8bit.bitwarden.data.auth.datasource.sdk.AuthSdkSource import com.x8bit.bitwarden.data.auth.repository.AuthRepository import com.x8bit.bitwarden.data.auth.repository.AuthRepositoryImpl import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager +import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository import com.x8bit.bitwarden.data.vault.repository.VaultRepository import dagger.Module import dagger.Provides @@ -32,6 +33,7 @@ object AuthRepositoryModule { authSdkSource: AuthSdkSource, authDiskSource: AuthDiskSource, dispatchers: DispatcherManager, + environmentRepository: EnvironmentRepository, vaultRepository: VaultRepository, ): AuthRepository = AuthRepositoryImpl( accountsService = accountsService, @@ -40,6 +42,7 @@ object AuthRepositoryModule { authDiskSource = authDiskSource, haveIBeenPwnedService = haveIBeenPwnedService, dispatcherManager = dispatchers, + environmentRepository = environmentRepository, vaultRepository = vaultRepository, ) } diff --git a/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/util/GetTokenResponseExtensions.kt b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/util/GetTokenResponseExtensions.kt index e72b0f81e..735a34477 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/util/GetTokenResponseExtensions.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/util/GetTokenResponseExtensions.kt @@ -1,15 +1,20 @@ package com.x8bit.bitwarden.data.auth.repository.util import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson +import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson /** - * Converts the given [GetTokenResponseJson.Success] to a [UserStateJson], given the current - * [previousUserState]. + * Converts the given [GetTokenResponseJson.Success] to a [UserStateJson], given the following + * additional information: + * + * - the [previousUserState] + * - the current [environmentUrlData] */ fun GetTokenResponseJson.Success.toUserState( previousUserState: UserStateJson?, + environmentUrlData: EnvironmentUrlDataJson, ): UserStateJson { val accessToken = this.accessToken @@ -40,7 +45,7 @@ fun GetTokenResponseJson.Success.toUserState( refreshToken = this.refreshToken, ), settings = AccountJson.Settings( - environmentUrlData = null, + environmentUrlData = environmentUrlData, ), ) diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/repository/EnvironmentRepositoryImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/repository/EnvironmentRepositoryImpl.kt index fe4bf943f..ab9af19a6 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/platform/repository/EnvironmentRepositoryImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/repository/EnvironmentRepositoryImpl.kt @@ -1,5 +1,6 @@ package com.x8bit.bitwarden.data.platform.repository +import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson import com.x8bit.bitwarden.data.platform.datasource.disk.EnvironmentDiskSource import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager @@ -8,7 +9,9 @@ import com.x8bit.bitwarden.data.platform.repository.util.toEnvironmentUrls import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.launchIn import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.stateIn /** @@ -16,6 +19,7 @@ import kotlinx.coroutines.flow.stateIn */ class EnvironmentRepositoryImpl( private val environmentDiskSource: EnvironmentDiskSource, + private val authDiskSource: AuthDiskSource, dispatcherManager: DispatcherManager, ) : EnvironmentRepository { @@ -38,6 +42,20 @@ class EnvironmentRepositoryImpl( started = SharingStarted.Lazily, initialValue = Environment.Us, ) + + init { + authDiskSource + .userStateFlow + .onEach { userState -> + // If the active account has environment data, set that as the current value. + userState + ?.activeAccount + ?.settings + ?.environmentUrlData + ?.let { environmentDiskSource.preAuthEnvironmentUrlData = it } + } + .launchIn(scope) + } } /** diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/repository/di/PlatformRepositoryModule.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/repository/di/PlatformRepositoryModule.kt index fff21bd23..5c56f5c87 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/platform/repository/di/PlatformRepositoryModule.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/repository/di/PlatformRepositoryModule.kt @@ -1,5 +1,6 @@ package com.x8bit.bitwarden.data.platform.repository.di +import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource import com.x8bit.bitwarden.data.platform.datasource.disk.EnvironmentDiskSource import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository @@ -21,10 +22,12 @@ object PlatformRepositoryModule { @Singleton fun provideEnvironmentRepository( environmentDiskSource: EnvironmentDiskSource, + authDiskSource: AuthDiskSource, dispatcherManager: DispatcherManager, ): EnvironmentRepository = EnvironmentRepositoryImpl( environmentDiskSource = environmentDiskSource, + authDiskSource = authDiskSource, dispatcherManager = dispatcherManager, ) } diff --git a/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt index 3d182b76a..4a0b6074e 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt @@ -5,6 +5,7 @@ import com.bitwarden.core.Kdf import com.bitwarden.core.RegisterKeyResponse import com.bitwarden.core.RsaKeyPair import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson +import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson import com.x8bit.bitwarden.data.auth.datasource.disk.util.FakeAuthDiskSource import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson @@ -31,6 +32,8 @@ import com.x8bit.bitwarden.data.auth.repository.util.toUserState import com.x8bit.bitwarden.data.auth.util.toSdkParams import com.x8bit.bitwarden.data.platform.base.FakeDispatcherManager import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager +import com.x8bit.bitwarden.data.platform.repository.model.Environment +import com.x8bit.bitwarden.data.platform.repository.util.FakeEnvironmentRepository import com.x8bit.bitwarden.data.platform.util.asFailure import com.x8bit.bitwarden.data.platform.util.asSuccess import com.x8bit.bitwarden.data.vault.repository.VaultRepository @@ -59,6 +62,11 @@ class AuthRepositoryTest { private val haveIBeenPwnedService: HaveIBeenPwnedService = mockk() private val vaultRepository: VaultRepository = mockk() private val fakeAuthDiskSource = FakeAuthDiskSource() + private val fakeEnvironmentRepository = + FakeEnvironmentRepository() + .apply { + environment = Environment.Us + } private val authSdkSource = mockk { coEvery { hashPassword( @@ -91,8 +99,9 @@ class AuthRepositoryTest { haveIBeenPwnedService = haveIBeenPwnedService, authSdkSource = authSdkSource, authDiskSource = fakeAuthDiskSource, - dispatcherManager = dispatcherManager, + environmentRepository = fakeEnvironmentRepository, vaultRepository = vaultRepository, + dispatcherManager = dispatcherManager, ) @BeforeEach @@ -278,7 +287,10 @@ class AuthRepositoryTest { vaultRepository.unlockVaultAndSync(masterPassword = PASSWORD) } returns VaultUnlockResult.Success every { - GET_TOKEN_RESPONSE_SUCCESS.toUserState(previousUserState = null) + GET_TOKEN_RESPONSE_SUCCESS.toUserState( + previousUserState = null, + environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US, + ) } returns SINGLE_USER_STATE_1 val result = repository.login(email = EMAIL, password = PASSWORD, captchaToken = null) assertEquals(LoginResult.Success, result) @@ -684,7 +696,10 @@ class AuthRepositoryTest { vaultRepository.unlockVaultAndSync(masterPassword = PASSWORD) } returns VaultUnlockResult.Success every { - GET_TOKEN_RESPONSE_SUCCESS.toUserState(previousUserState = null) + GET_TOKEN_RESPONSE_SUCCESS.toUserState( + previousUserState = null, + environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US, + ) } returns SINGLE_USER_STATE_1 repository.login(email = EMAIL, password = PASSWORD, captchaToken = null) @@ -733,7 +748,10 @@ class AuthRepositoryTest { vaultRepository.unlockVaultAndSync(masterPassword = PASSWORD) } returns VaultUnlockResult.Success every { - GET_TOKEN_RESPONSE_SUCCESS.toUserState(previousUserState = SINGLE_USER_STATE_2) + GET_TOKEN_RESPONSE_SUCCESS.toUserState( + previousUserState = SINGLE_USER_STATE_2, + environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US, + ) } returns MULTI_USER_STATE repository.login(email = EMAIL, password = PASSWORD, captchaToken = null) diff --git a/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/util/GetTokenResponseExtensionsTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/util/GetTokenResponseExtensionsTest.kt index 32885df0b..1755aca94 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/util/GetTokenResponseExtensionsTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/util/GetTokenResponseExtensionsTest.kt @@ -1,6 +1,7 @@ package com.x8bit.bitwarden.data.auth.repository.util import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson +import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson @@ -31,7 +32,10 @@ class GetTokenResponseExtensionsTest { assertEquals( SINGLE_USER_STATE_1, - GET_TOKEN_RESPONSE_SUCCESS.toUserState(previousUserState = null), + GET_TOKEN_RESPONSE_SUCCESS.toUserState( + previousUserState = null, + environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US, + ), ) } @@ -41,7 +45,10 @@ class GetTokenResponseExtensionsTest { assertEquals( MULTI_USER_STATE, - GET_TOKEN_RESPONSE_SUCCESS.toUserState(previousUserState = SINGLE_USER_STATE_2), + GET_TOKEN_RESPONSE_SUCCESS.toUserState( + previousUserState = SINGLE_USER_STATE_2, + environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US, + ), ) } } @@ -102,7 +109,7 @@ private val ACCOUNT_1 = AccountJson( refreshToken = "refreshToken", ), settings = AccountJson.Settings( - environmentUrlData = null, + environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US, ), ) private val ACCOUNT_2 = AccountJson( @@ -127,7 +134,7 @@ private val ACCOUNT_2 = AccountJson( refreshToken = "refreshToken", ), settings = AccountJson.Settings( - environmentUrlData = null, + environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US, ), ) private val SINGLE_USER_STATE_1 = UserStateJson( diff --git a/app/src/test/java/com/x8bit/bitwarden/data/platform/repository/EnvironmentRepositoryTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/platform/repository/EnvironmentRepositoryTest.kt index 8929bac24..20d55a6f8 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/platform/repository/EnvironmentRepositoryTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/platform/repository/EnvironmentRepositoryTest.kt @@ -1,7 +1,10 @@ package com.x8bit.bitwarden.data.platform.repository import app.cash.turbine.test +import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson +import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson +import com.x8bit.bitwarden.data.auth.datasource.disk.util.FakeAuthDiskSource import com.x8bit.bitwarden.data.platform.base.FakeDispatcherManager import com.x8bit.bitwarden.data.platform.datasource.disk.EnvironmentDiskSource import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager @@ -26,9 +29,11 @@ class EnvironmentRepositoryTest { private val dispatcherManager: DispatcherManager = FakeDispatcherManager() private val fakeEnvironmentDiskSource = FakeEnvironmentDiskSource() + private val fakeAuthDiskSource = FakeAuthDiskSource() private val repository = EnvironmentRepositoryImpl( environmentDiskSource = fakeEnvironmentDiskSource, + authDiskSource = fakeAuthDiskSource, dispatcherManager = dispatcherManager, ) @@ -42,6 +47,46 @@ class EnvironmentRepositoryTest { unmockkStatic(ENVIRONMENT_EXTENSIONS_PATH) } + @Test + fun `changes to the active user should update the environment if necessary`() { + assertEquals( + Environment.Us, + repository.environment, + ) + assertEquals( + null, + fakeEnvironmentDiskSource.preAuthEnvironmentUrlData, + ) + + // Updating the environment for the active user to a non-null value triggers an update + // in the saved environment. + fakeAuthDiskSource.userState = getMockUserState( + environmentForActiveUser = EnvironmentUrlDataJson.DEFAULT_EU, + ) + assertEquals( + Environment.Eu, + repository.environment, + ) + assertEquals( + EnvironmentUrlDataJson.DEFAULT_EU, + fakeEnvironmentDiskSource.preAuthEnvironmentUrlData, + ) + + // Updating the environment for the active user to a null value leaves the current + // environment unchanged. + fakeAuthDiskSource.userState = getMockUserState( + environmentForActiveUser = null, + ) + assertEquals( + Environment.Eu, + repository.environment, + ) + assertEquals( + EnvironmentUrlDataJson.DEFAULT_EU, + fakeEnvironmentDiskSource.preAuthEnvironmentUrlData, + ) + } + @Test fun `environment should pull from and update EnvironmentDiskSource`() { val environmentUrlDataJson = mockk() @@ -101,6 +146,22 @@ class EnvironmentRepositoryTest { assertEquals(environment, awaitItem()) } } + + private fun getMockUserState( + environmentForActiveUser: EnvironmentUrlDataJson?, + ): UserStateJson = + UserStateJson( + activeUserId = "activeUserId", + accounts = mapOf( + "activeUserId" to AccountJson( + profile = mockk(), + tokens = mockk(), + settings = AccountJson.Settings( + environmentUrlData = environmentForActiveUser, + ), + ), + ), + ) } private const val ENVIRONMENT_EXTENSIONS_PATH =