BIT-2440, BIT-2441: Clean up the way we update the account info after token refresh (#3416)

This commit is contained in:
David Perez 2024-07-08 10:44:43 -05:00 committed by GitHub
parent 4b0c6ad911
commit e9057cb866
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 32 additions and 234 deletions

View file

@ -65,7 +65,6 @@ import com.x8bit.bitwarden.data.auth.repository.util.activeUserIdChangesFlow
import com.x8bit.bitwarden.data.auth.repository.util.policyInformation
import com.x8bit.bitwarden.data.auth.repository.util.toSdkParams
import com.x8bit.bitwarden.data.auth.repository.util.toUserState
import com.x8bit.bitwarden.data.auth.repository.util.toUserStateJson
import com.x8bit.bitwarden.data.auth.repository.util.toUserStateJsonWithPassword
import com.x8bit.bitwarden.data.auth.repository.util.userAccountTokens
import com.x8bit.bitwarden.data.auth.repository.util.userAccountTokensFlow
@ -625,19 +624,24 @@ class AuthRepositoryImpl(
?: return IllegalStateException("Must be logged in.").asFailure()
return identityService
.refreshTokenSynchronously(refreshToken)
.onSuccess {
.flatMap { refreshTokenResponse ->
// Check to make sure the user is still logged in after making the request
authDiskSource
.userState
?.accounts
?.get(userId)
?.let { refreshTokenResponse.asSuccess() }
?: IllegalStateException("Must be logged in.").asFailure()
}
.onSuccess { refreshTokenResponse ->
// Update the existing UserState with updated token information
authDiskSource.storeAccountTokens(
userId = userId,
accountTokens = AccountTokensJson(
accessToken = it.accessToken,
refreshToken = it.refreshToken,
accessToken = refreshTokenResponse.accessToken,
refreshToken = refreshTokenResponse.refreshToken,
),
)
authDiskSource.userState = it.toUserStateJson(
userId = userId,
previousUserState = requireNotNull(authDiskSource.userState),
)
}
}

View file

@ -1,39 +0,0 @@
package com.x8bit.bitwarden.data.auth.repository.util
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.RefreshTokenResponseJson
/**
* Converts the given [RefreshTokenResponseJson] to a [UserStateJson], given the following
* additional information:
*
* - the [userId]
* - the [previousUserState]
*/
fun RefreshTokenResponseJson.toUserStateJson(
userId: String,
previousUserState: UserStateJson,
): UserStateJson {
val refreshedAccount = requireNotNull(previousUserState.accounts[userId])
val accessToken = this.accessToken
val jwtTokenData = requireNotNull(parseJwtTokenDataOrNull(jwtToken = accessToken))
val account = refreshedAccount.copy(
profile = refreshedAccount.profile.copy(
userId = jwtTokenData.userId,
email = jwtTokenData.email,
isEmailVerified = jwtTokenData.isEmailVerified,
name = jwtTokenData.name,
),
)
// Update the existing UserState.
return previousUserState.copy(
accounts = previousUserState
.accounts
.toMutableMap()
.apply {
put(userId, account)
},
)
}

View file

@ -233,6 +233,13 @@ class FakeAuthDiskSource : AuthDiskSource {
assertEquals(userState, this.userState)
}
/**
* Assert that the [accountTokens] was stored successfully using the [userId].
*/
fun assertAccountTokens(userId: String, accountTokens: AccountTokensJson?) {
assertEquals(accountTokens, this.storedAccountTokens[userId])
}
/**
* Assert that the [lastActiveTimeMillis] was stored successfully using the [userId].
*/

View file

@ -80,7 +80,6 @@ import com.x8bit.bitwarden.data.auth.repository.util.WebAuthResult
import com.x8bit.bitwarden.data.auth.repository.util.toOrganizations
import com.x8bit.bitwarden.data.auth.repository.util.toSdkParams
import com.x8bit.bitwarden.data.auth.repository.util.toUserState
import com.x8bit.bitwarden.data.auth.repository.util.toUserStateJson
import com.x8bit.bitwarden.data.auth.util.YubiKeyResult
import com.x8bit.bitwarden.data.auth.util.toSdkParams
import com.x8bit.bitwarden.data.platform.base.FakeDispatcherManager
@ -242,18 +241,12 @@ class AuthRepositoryTest {
@BeforeEach
fun beforeEach() {
mockkStatic(
GetTokenResponseJson.Success::toUserState,
RefreshTokenResponseJson::toUserStateJson,
)
mockkStatic(GetTokenResponseJson.Success::toUserState)
}
@AfterEach
fun tearDown() {
unmockkStatic(
GetTokenResponseJson.Success::toUserState,
RefreshTokenResponseJson::toUserStateJson,
)
unmockkStatic(GetTokenResponseJson.Success::toUserState)
}
@Test
@ -742,7 +735,7 @@ class AuthRepositoryTest {
}
@Test
fun `refreshTokenSynchronously returns failure if not logged in`() = runTest {
fun `refreshAccessTokenSynchronously returns failure if not logged in`() = runTest {
fakeAuthDiskSource.userState = null
val result = repository.refreshAccessTokenSynchronously(USER_ID_1)
@ -751,7 +744,7 @@ class AuthRepositoryTest {
}
@Test
fun `refreshTokenSynchronously returns failure and logs out on failure`() = runTest {
fun `refreshAccessTokenSynchronously returns failure and logs out on failure`() = runTest {
fakeAuthDiskSource.storeAccountTokens(
userId = USER_ID_1,
accountTokens = ACCOUNT_TOKENS_1,
@ -768,7 +761,11 @@ class AuthRepositoryTest {
}
@Test
fun `refreshTokenSynchronously returns success and update user state on success`() = runTest {
fun `refreshAccessTokenSynchronously returns success and sets account tokens`() = runTest {
val updatedAccountTokens = AccountTokensJson(
accessToken = ACCESS_TOKEN_2,
refreshToken = REFRESH_TOKEN_2,
)
fakeAuthDiskSource.storeAccountTokens(
userId = USER_ID_1,
accountTokens = ACCOUNT_TOKENS_1,
@ -777,22 +774,16 @@ class AuthRepositoryTest {
coEvery {
identityService.refreshTokenSynchronously(REFRESH_TOKEN)
} returns REFRESH_TOKEN_RESPONSE_JSON.asSuccess()
every {
REFRESH_TOKEN_RESPONSE_JSON.toUserStateJson(
userId = USER_ID_1,
previousUserState = SINGLE_USER_STATE_1,
)
} returns SINGLE_USER_STATE_1
val result = repository.refreshAccessTokenSynchronously(USER_ID_1)
assertEquals(REFRESH_TOKEN_RESPONSE_JSON.asSuccess(), result)
fakeAuthDiskSource.assertAccountTokens(
userId = USER_ID_1,
accountTokens = updatedAccountTokens,
)
coVerify(exactly = 1) {
identityService.refreshTokenSynchronously(REFRESH_TOKEN)
REFRESH_TOKEN_RESPONSE_JSON.toUserStateJson(
userId = USER_ID_1,
previousUserState = SINGLE_USER_STATE_1,
)
}
}
@ -4674,12 +4665,6 @@ class AuthRepositoryTest {
coEvery {
identityService.refreshTokenSynchronously(REFRESH_TOKEN)
} returns REFRESH_TOKEN_RESPONSE_JSON.asSuccess()
every {
REFRESH_TOKEN_RESPONSE_JSON.toUserStateJson(
userId = USER_ID_1,
previousUserState = SINGLE_USER_STATE_1,
)
} returns SINGLE_USER_STATE_1
coEvery { vaultRepository.sync() } just runs

View file

@ -1,159 +0,0 @@
package com.x8bit.bitwarden.data.auth.repository.util
import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.RefreshTokenResponseJson
import com.x8bit.bitwarden.data.auth.repository.model.JwtTokenDataJson
import io.mockk.every
import io.mockk.mockkStatic
import io.mockk.unmockkStatic
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
class RefreshTokenResponseJsonTest {
@BeforeEach
fun beforeEach() {
mockkStatic(::parseJwtTokenDataOrNull)
}
@AfterEach
fun tearDown() {
unmockkStatic(::parseJwtTokenDataOrNull)
}
@Test
fun `toUserState updates the previous state`() {
every { parseJwtTokenDataOrNull(ACCESS_TOKEN_UPDATED) } returns JWT_TOKEN_DATA
assertEquals(
SINGLE_USER_STATE_UPDATED,
REFRESH_TOKEN_RESPONSE.toUserStateJson(
userId = USER_ID_1,
previousUserState = SINGLE_USER_STATE,
),
)
}
@Test
fun `toUserState updates the previous state for non-active user`() {
every { parseJwtTokenDataOrNull(ACCESS_TOKEN_UPDATED) } returns JWT_TOKEN_DATA
assertEquals(
MULTI_USER_STATE_UPDATED,
REFRESH_TOKEN_RESPONSE.toUserStateJson(
userId = USER_ID_1,
previousUserState = MULTI_USER_STATE,
),
)
}
}
private const val ACCESS_TOKEN_UPDATED = "updatedAccessToken"
private const val REFRESH_TOKEN_UPDATED = "updatedRefreshToken"
private const val USER_ID_1 = "2a135b23-e1fb-42c9-bec3-573857bc8181"
private const val USER_ID_2 = "b9d32ec0-6497-4582-9798-b350f53bfa02"
private val JWT_TOKEN_DATA = JwtTokenDataJson(
userId = USER_ID_1,
email = "updated@bitwarden.com",
isEmailVerified = false,
name = "Updated Bitwarden Tester",
expirationAsEpochTime = 1697495714,
hasPremium = true,
authenticationMethodsReference = listOf("Application"),
)
private val REFRESH_TOKEN_RESPONSE = RefreshTokenResponseJson(
accessToken = ACCESS_TOKEN_UPDATED,
expiresIn = 3600,
refreshToken = REFRESH_TOKEN_UPDATED,
tokenType = "Bearer",
)
private val ACCOUNT_1 = AccountJson(
profile = AccountJson.Profile(
userId = USER_ID_1,
email = "test@bitwarden.com",
isEmailVerified = true,
name = "Bitwarden Tester",
hasPremium = false,
stamp = null,
organizationId = null,
avatarColorHex = null,
forcePasswordResetReason = null,
kdfType = KdfTypeJson.ARGON2_ID,
kdfIterations = 600000,
kdfMemory = 16,
kdfParallelism = 4,
userDecryptionOptions = null,
),
settings = AccountJson.Settings(
environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US,
),
)
private val ACCOUNT_1_UPDATED = ACCOUNT_1.copy(
profile = ACCOUNT_1.profile.copy(
userId = JWT_TOKEN_DATA.userId,
email = JWT_TOKEN_DATA.email,
isEmailVerified = JWT_TOKEN_DATA.isEmailVerified,
name = JWT_TOKEN_DATA.name,
),
)
private val ACCOUNT_2 = AccountJson(
profile = AccountJson.Profile(
userId = USER_ID_2,
email = "test2@bitwarden.com",
isEmailVerified = true,
name = "Bitwarden Tester 2",
hasPremium = false,
stamp = null,
organizationId = null,
avatarColorHex = null,
forcePasswordResetReason = null,
kdfType = KdfTypeJson.PBKDF2_SHA256,
kdfIterations = 400000,
kdfMemory = null,
kdfParallelism = null,
userDecryptionOptions = null,
),
settings = AccountJson.Settings(
environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US,
),
)
private val SINGLE_USER_STATE = UserStateJson(
activeUserId = USER_ID_1,
accounts = mapOf(
USER_ID_1 to ACCOUNT_1,
),
)
private val SINGLE_USER_STATE_UPDATED = UserStateJson(
activeUserId = USER_ID_1,
accounts = mapOf(
USER_ID_1 to ACCOUNT_1_UPDATED,
),
)
private val MULTI_USER_STATE = UserStateJson(
activeUserId = USER_ID_2,
accounts = mapOf(
USER_ID_1 to ACCOUNT_1,
USER_ID_2 to ACCOUNT_2,
),
)
private val MULTI_USER_STATE_UPDATED = UserStateJson(
activeUserId = USER_ID_2,
accounts = mapOf(
USER_ID_1 to ACCOUNT_1_UPDATED,
USER_ID_2 to ACCOUNT_2,
),
)