BIT-1197 Add token refresh handling (#274)

This commit is contained in:
David Perez 2023-11-27 13:40:45 -06:00 committed by Álison Fernandes
parent acfc39ae3c
commit b914f52d0f
16 changed files with 742 additions and 18 deletions

View file

@ -1,18 +1,19 @@
package com.x8bit.bitwarden.data.auth.repository
import com.x8bit.bitwarden.data.auth.datasource.sdk.model.PasswordStrength
import com.x8bit.bitwarden.data.auth.repository.model.UserState
import com.x8bit.bitwarden.data.auth.repository.model.AuthState
import com.x8bit.bitwarden.data.auth.repository.model.LoginResult
import com.x8bit.bitwarden.data.auth.repository.model.RegisterResult
import com.x8bit.bitwarden.data.auth.repository.model.UserState
import com.x8bit.bitwarden.data.auth.repository.util.CaptchaCallbackTokenResult
import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.AuthenticatorProvider
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.StateFlow
/**
* Provides an API for observing an modifying authentication state.
*/
interface AuthRepository {
interface AuthRepository : AuthenticatorProvider {
/**
* Models the current auth state.
*/

View file

@ -5,6 +5,7 @@ import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource
import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson.CaptchaRequired
import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson.Success
import com.x8bit.bitwarden.data.auth.datasource.network.model.RefreshTokenResponseJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.RegisterRequestJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.RegisterResponseJson
import com.x8bit.bitwarden.data.auth.datasource.network.service.AccountsService
@ -20,6 +21,7 @@ import com.x8bit.bitwarden.data.auth.repository.model.UserState
import com.x8bit.bitwarden.data.auth.repository.util.CaptchaCallbackTokenResult
import com.x8bit.bitwarden.data.auth.repository.util.toSdkParams
import com.x8bit.bitwarden.data.auth.repository.util.toUserState
import com.x8bit.bitwarden.data.auth.repository.util.toUserStateJson
import com.x8bit.bitwarden.data.auth.util.KdfParamsConstants.DEFAULT_PBKDF2_ITERATIONS
import com.x8bit.bitwarden.data.auth.util.toSdkParams
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
@ -56,12 +58,13 @@ class AuthRepositoryImpl constructor(
) : AuthRepository {
private val scope = CoroutineScope(dispatcherManager.io)
override val activeUserId: String? get() = authDiskSource.userState?.activeUserId
override val authStateFlow: StateFlow<AuthState> = authDiskSource
.userStateFlow
.map { userState ->
userState
?.let {
@Suppress("UnsafeCallOnNullableType")
AuthState.Authenticated(
userState
.activeAccount
@ -179,21 +182,42 @@ class AuthRepositoryImpl constructor(
},
)
override fun logout() {
val currentUserState = authDiskSource.userState ?: return
override fun refreshAccessTokenSynchronously(userId: String): Result<RefreshTokenResponseJson> {
val refreshAccount = authDiskSource.userState?.accounts?.get(userId)
?: return IllegalStateException("Must be logged in.").asFailure()
return identityService
.refreshTokenSynchronously(refreshAccount.tokens.refreshToken)
.onSuccess {
// Update the existing UserState with updated token information
authDiskSource.userState = it.toUserStateJson(
userId = userId,
previousUserState = requireNotNull(authDiskSource.userState),
)
}
}
val activeUserId = currentUserState.activeUserId
override fun logout() {
activeUserId?.let { userId -> logout(userId) }
}
override fun logout(userId: String) {
val currentUserState = authDiskSource.userState ?: return
// Remove the active user from the accounts map
val updatedAccounts = currentUserState
.accounts
.filterKeys { it != activeUserId }
authDiskSource.storeUserKey(userId = activeUserId, userKey = null)
authDiskSource.storePrivateKey(userId = activeUserId, privateKey = null)
.filterKeys { it != userId }
authDiskSource.storeUserKey(userId = userId, userKey = null)
authDiskSource.storePrivateKey(userId = userId, privateKey = null)
// Check if there is a new active user
if (updatedAccounts.isNotEmpty()) {
val (updatedActiveUserId, updatedActiveAccount) =
updatedAccounts.entries.first()
// If we logged out a non-active user, we want to leave the active user unchanged.
// If we logged out the active user, we want to set the active user to the first one
// in the list.
val updatedActiveUserId = currentUserState
.activeUserId
.takeUnless { it == userId }
?: updatedAccounts.entries.first().key
// Update the user information and emit an updated token
authDiskSource.userState = currentUserState.copy(

View file

@ -0,0 +1,45 @@
package com.x8bit.bitwarden.data.auth.repository.util
import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.RefreshTokenResponseJson
/**
* Converts the given [RefreshTokenResponseJson] to a [UserStateJson], given the following
* additional information:
*
* - the [userId]
* - the [previousUserState]
*/
fun RefreshTokenResponseJson.toUserStateJson(
userId: String,
previousUserState: UserStateJson,
): UserStateJson {
val refreshedAccount = requireNotNull(previousUserState.accounts[userId])
val accessToken = this.accessToken
val jwtTokenData = requireNotNull(parseJwtTokenDataOrNull(jwtToken = accessToken))
val account = refreshedAccount.copy(
profile = refreshedAccount.profile.copy(
userId = jwtTokenData.userId,
email = jwtTokenData.email,
isEmailVerified = jwtTokenData.isEmailVerified,
name = jwtTokenData.name,
hasPremium = jwtTokenData.hasPremium,
),
tokens = AccountJson.Tokens(
accessToken = accessToken,
refreshToken = this.refreshToken,
),
)
// Update the existing UserState.
return previousUserState.copy(
accounts = previousUserState
.accounts
.toMutableMap()
.apply {
put(userId, account)
},
)
}

View file

@ -0,0 +1,27 @@
package com.x8bit.bitwarden.data.platform.datasource.network.authenticator
import com.x8bit.bitwarden.data.auth.datasource.network.model.RefreshTokenResponseJson
/**
* A provider for all the functionality needed to properly refresh the users access token.
*/
interface AuthenticatorProvider {
/**
* The currently active user's ID.
*/
val activeUserId: String?
/**
* Attempts to logout the user based on the [userId].
*/
fun logout(userId: String)
/**
* Attempt to refresh the user's access token based on the [userId].
*
* This call is both synchronous and performs a network request. Make sure that you are calling
* from an appropriate thread.
*/
fun refreshAccessTokenSynchronously(userId: String): Result<RefreshTokenResponseJson>
}

View file

@ -0,0 +1,69 @@
package com.x8bit.bitwarden.data.platform.datasource.network.authenticator
import com.x8bit.bitwarden.data.auth.repository.util.parseJwtTokenDataOrNull
import com.x8bit.bitwarden.data.platform.datasource.network.util.HEADER_KEY_AUTHORIZATION
import com.x8bit.bitwarden.data.platform.datasource.network.util.HEADER_VALUE_BEARER_PREFIX
import okhttp3.Authenticator
import okhttp3.Request
import okhttp3.Response
import okhttp3.Route
import javax.inject.Singleton
/**
* An authenticator used to refresh the access token when a 401 is returned from an API. Upon
* successfully getting a new access token, the original request is retried.
*/
@Singleton
class RefreshAuthenticator : Authenticator {
/**
* A provider required to update tokens.
*/
var authenticatorProvider: AuthenticatorProvider? = null
override fun authenticate(
route: Route?,
response: Response,
): Request? {
val accessToken = requireNotNull(
response
.request
.header(name = HEADER_KEY_AUTHORIZATION)
?.substringAfter(delimiter = HEADER_VALUE_BEARER_PREFIX),
)
return when (val userId = parseJwtTokenDataOrNull(accessToken)?.userId) {
null -> {
// We unable to get the user ID, let's just let the 401 pass through.
null
}
authenticatorProvider?.activeUserId -> {
// In order to prevent potential deadlocks or thread starvation we want the call
// to refresh the access token to be strictly synchronous with no internal thread
// hopping.
authenticatorProvider
?.refreshAccessTokenSynchronously(userId)
?.fold(
onFailure = {
authenticatorProvider?.logout(userId)
null
},
onSuccess = {
response.request
.newBuilder()
.header(
name = HEADER_KEY_AUTHORIZATION,
value = "$HEADER_VALUE_BEARER_PREFIX${it.accessToken}",
)
.build()
},
)
}
else -> {
// We are no longer the active user, let's just cancel.
null
}
}
}
}

View file

@ -1,5 +1,6 @@
package com.x8bit.bitwarden.data.platform.datasource.network.di
import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors
import com.x8bit.bitwarden.data.platform.datasource.network.retrofit.Retrofits
@ -36,16 +37,22 @@ object PlatformNetworkModule {
@Singleton
fun providesAuthTokenInterceptor(): AuthTokenInterceptor = AuthTokenInterceptor()
@Provides
@Singleton
fun providesRefreshAuthenticator(): RefreshAuthenticator = RefreshAuthenticator()
@Provides
@Singleton
fun provideRetrofits(
authTokenInterceptor: AuthTokenInterceptor,
baseUrlInterceptors: BaseUrlInterceptors,
refreshAuthenticator: RefreshAuthenticator,
json: Json,
): Retrofits =
RetrofitsImpl(
authTokenInterceptor = authTokenInterceptor,
baseUrlInterceptors = baseUrlInterceptors,
refreshAuthenticator = refreshAuthenticator,
json = json,
)

View file

@ -1,5 +1,7 @@
package com.x8bit.bitwarden.data.platform.datasource.network.interceptor
import com.x8bit.bitwarden.data.platform.datasource.network.util.HEADER_KEY_AUTHORIZATION
import com.x8bit.bitwarden.data.platform.datasource.network.util.HEADER_VALUE_BEARER_PREFIX
import okhttp3.Interceptor
import okhttp3.Response
import java.io.IOException
@ -22,7 +24,10 @@ class AuthTokenInterceptor : Interceptor {
val request = chain
.request()
.newBuilder()
.addHeader("Authorization", "Bearer $token")
.addHeader(
name = HEADER_KEY_AUTHORIZATION,
value = "$HEADER_VALUE_BEARER_PREFIX$token",
)
.build()
return chain
.proceed(request)

View file

@ -1,6 +1,7 @@
package com.x8bit.bitwarden.data.platform.datasource.network.retrofit
import com.jakewharton.retrofit2.converter.kotlinx.serialization.asConverterFactory
import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator
import com.x8bit.bitwarden.data.platform.datasource.network.core.ResultCallAdapterFactory
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptor
@ -17,6 +18,7 @@ import retrofit2.Retrofit
class RetrofitsImpl(
authTokenInterceptor: AuthTokenInterceptor,
baseUrlInterceptors: BaseUrlInterceptors,
refreshAuthenticator: RefreshAuthenticator,
json: Json,
) : Retrofits {
//region Authenticated Retrofits
@ -73,6 +75,7 @@ class RetrofitsImpl(
private val authenticatedOkHttpClient: OkHttpClient by lazy {
baseOkHttpClient
.newBuilder()
.authenticator(refreshAuthenticator)
.addInterceptor(authTokenInterceptor)
.build()
}

View file

@ -0,0 +1,11 @@
package com.x8bit.bitwarden.data.platform.datasource.network.util
/**
* The bearer prefix used for the 'authorization' headers value.
*/
const val HEADER_VALUE_BEARER_PREFIX: String = "Bearer "
/**
* The key used for the 'authorization' headers.
*/
const val HEADER_KEY_AUTHORIZATION: String = "Authorization"

View file

@ -2,6 +2,7 @@ package com.x8bit.bitwarden.data.platform.manager
import com.x8bit.bitwarden.data.auth.repository.AuthRepository
import com.x8bit.bitwarden.data.auth.repository.model.AuthState
import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
@ -18,6 +19,7 @@ class NetworkConfigManagerImpl(
private val authTokenInterceptor: AuthTokenInterceptor,
private val environmentRepository: EnvironmentRepository,
private val baseUrlInterceptors: BaseUrlInterceptors,
refreshAuthenticator: RefreshAuthenticator,
dispatcherManager: DispatcherManager,
) : NetworkConfigManager {
@ -41,5 +43,7 @@ class NetworkConfigManagerImpl(
baseUrlInterceptors.environment = environment
}
.launchIn(scope)
refreshAuthenticator.authenticatorProvider = authRepository
}
}

View file

@ -1,6 +1,7 @@
package com.x8bit.bitwarden.data.platform.manager.di
import com.x8bit.bitwarden.data.auth.repository.AuthRepository
import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors
import com.x8bit.bitwarden.data.platform.manager.NetworkConfigManager
@ -25,6 +26,7 @@ object PlatformManagerModule {
@Singleton
fun provideBitwardenDispatchers(): DispatcherManager = DispatcherManagerImpl()
@Suppress("LongParameterList")
@Provides
@Singleton
fun provideNetworkConfigManager(
@ -32,6 +34,7 @@ object PlatformManagerModule {
authTokenInterceptor: AuthTokenInterceptor,
environmentRepository: EnvironmentRepository,
baseUrlInterceptors: BaseUrlInterceptors,
refreshAuthenticator: RefreshAuthenticator,
dispatcherManager: DispatcherManager,
): NetworkConfigManager =
NetworkConfigManagerImpl(
@ -39,6 +42,7 @@ object PlatformManagerModule {
authTokenInterceptor = authTokenInterceptor,
environmentRepository = environmentRepository,
baseUrlInterceptors = baseUrlInterceptors,
refreshAuthenticator = refreshAuthenticator,
dispatcherManager = dispatcherManager,
)
}

View file

@ -12,6 +12,7 @@ import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJs
import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson.PBKDF2_SHA256
import com.x8bit.bitwarden.data.auth.datasource.network.model.PreLoginResponseJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.RefreshTokenResponseJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.RegisterRequestJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.RegisterResponseJson
import com.x8bit.bitwarden.data.auth.datasource.network.service.AccountsService
@ -29,6 +30,7 @@ import com.x8bit.bitwarden.data.auth.repository.model.RegisterResult
import com.x8bit.bitwarden.data.auth.repository.util.CaptchaCallbackTokenResult
import com.x8bit.bitwarden.data.auth.repository.util.toSdkParams
import com.x8bit.bitwarden.data.auth.repository.util.toUserState
import com.x8bit.bitwarden.data.auth.repository.util.toUserStateJson
import com.x8bit.bitwarden.data.auth.util.toSdkParams
import com.x8bit.bitwarden.data.platform.base.FakeDispatcherManager
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
@ -114,12 +116,18 @@ class AuthRepositoryTest {
@BeforeEach
fun beforeEach() {
clearMocks(identityService, accountsService, haveIBeenPwnedService)
mockkStatic(GET_TOKEN_RESPONSE_EXTENSIONS_PATH)
mockkStatic(
GET_TOKEN_RESPONSE_EXTENSIONS_PATH,
REFRESH_TOKEN_RESPONSE_EXTENSIONS_PATH,
)
}
@AfterEach
fun tearDown() {
unmockkStatic(GET_TOKEN_RESPONSE_EXTENSIONS_PATH)
unmockkStatic(
GET_TOKEN_RESPONSE_EXTENSIONS_PATH,
REFRESH_TOKEN_RESPONSE_EXTENSIONS_PATH,
)
}
@Test
@ -240,6 +248,54 @@ class AuthRepositoryTest {
}
}
@Test
fun `refreshTokenSynchronously returns failure if not logged in`() = runTest {
fakeAuthDiskSource.userState = null
val result = repository.refreshAccessTokenSynchronously(USER_ID_1)
assertTrue(result.isFailure)
}
@Test
fun `refreshTokenSynchronously returns failure and logs out on failure`() = runTest {
fakeAuthDiskSource.userState = SINGLE_USER_STATE_1
coEvery {
identityService.refreshTokenSynchronously(REFRESH_TOKEN)
} returns Throwable("Fail").asFailure()
assertTrue(repository.refreshAccessTokenSynchronously(USER_ID_1).isFailure)
coVerify(exactly = 1) {
identityService.refreshTokenSynchronously(REFRESH_TOKEN)
}
}
@Test
fun `refreshTokenSynchronously returns success and update user state on success`() = runTest {
fakeAuthDiskSource.userState = SINGLE_USER_STATE_1
coEvery {
identityService.refreshTokenSynchronously(REFRESH_TOKEN)
} returns REFRESH_TOKEN_RESPONSE_JSON.asSuccess()
every {
REFRESH_TOKEN_RESPONSE_JSON.toUserStateJson(
userId = USER_ID_1,
previousUserState = SINGLE_USER_STATE_1,
)
} returns SINGLE_USER_STATE_1
val result = repository.refreshAccessTokenSynchronously(USER_ID_1)
assertEquals(REFRESH_TOKEN_RESPONSE_JSON.asSuccess(), result)
coVerify(exactly = 1) {
identityService.refreshTokenSynchronously(REFRESH_TOKEN)
REFRESH_TOKEN_RESPONSE_JSON.toUserStateJson(
userId = USER_ID_1,
previousUserState = SINGLE_USER_STATE_1,
)
}
}
@Test
fun `login when pre login fails should return Error with no message`() = runTest {
coEvery {
@ -854,6 +910,36 @@ class AuthRepositoryTest {
}
}
@Test
fun `logout for non-active accounts should leave the active user unchanged`() = runTest {
// First populate multiple user accounts and active user is #3
val initialUserState = MULTI_USER_STATE_2
val finalUserState = initialUserState.copy(
accounts = initialUserState.accounts.filter { it.key != USER_ID_2 },
)
fakeAuthDiskSource.userState = initialUserState
assertEquals(initialUserState, fakeAuthDiskSource.userState)
repository.authStateFlow.test {
assertEquals(AuthState.Authenticated(ACCESS_TOKEN_3), awaitItem())
repository.logout(USER_ID_2)
// The auth state does not actually change
expectNoEvents()
assertEquals(finalUserState, fakeAuthDiskSource.userState)
fakeAuthDiskSource.assertPrivateKey(
userId = USER_ID_2,
privateKey = null,
)
fakeAuthDiskSource.assertUserKey(
userId = USER_ID_2,
userKey = null,
)
}
}
@Test
fun `getPasswordStrength should be based on password length`() = runTest {
// TODO: Replace with SDK call (BIT-964)
@ -878,11 +964,17 @@ class AuthRepositoryTest {
companion object {
private const val GET_TOKEN_RESPONSE_EXTENSIONS_PATH =
"com.x8bit.bitwarden.data.auth.repository.util.GetTokenResponseExtensionsKt"
private const val REFRESH_TOKEN_RESPONSE_EXTENSIONS_PATH =
"com.x8bit.bitwarden.data.auth.repository.util.RefreshTokenResponseExtensionsKt"
private const val EMAIL = "test@bitwarden.com"
private const val EMAIL_2 = "test2@bitwarden.com"
private const val PASSWORD = "password"
private const val PASSWORD_HASH = "passwordHash"
private const val ACCESS_TOKEN = "accessToken"
private const val ACCESS_TOKEN_2 = "accessToken2"
private const val ACCESS_TOKEN_3 = "accessToken3"
private const val REFRESH_TOKEN = "refreshToken"
private const val REFRESH_TOKEN_2 = "refreshToken2"
private const val CAPTCHA_KEY = "captcha"
private const val DEFAULT_KDF_ITERATIONS = 600000
private const val ENCRYPTED_USER_KEY = "encryptedUserKey"
@ -890,9 +982,16 @@ class AuthRepositoryTest {
private const val PRIVATE_KEY = "privateKey"
private const val USER_ID_1 = "2a135b23-e1fb-42c9-bec3-573857bc8181"
private const val USER_ID_2 = "b9d32ec0-6497-4582-9798-b350f53bfa02"
private const val USER_ID_3 = "3816ef34-0747-4133-9b7a-ba35d3768a68"
private val PRE_LOGIN_SUCCESS = PreLoginResponseJson(
kdfParams = PreLoginResponseJson.KdfParams.Pbkdf2(iterations = 1u),
)
private val REFRESH_TOKEN_RESPONSE_JSON = RefreshTokenResponseJson(
accessToken = ACCESS_TOKEN_2,
expiresIn = 3600,
refreshToken = REFRESH_TOKEN_2,
tokenType = "Bearer",
)
private val GET_TOKEN_RESPONSE_SUCCESS = GetTokenResponseJson.Success(
accessToken = ACCESS_TOKEN,
refreshToken = "refreshToken",
@ -928,7 +1027,7 @@ class AuthRepositoryTest {
),
tokens = AccountJson.Tokens(
accessToken = ACCESS_TOKEN,
refreshToken = "refreshToken",
refreshToken = REFRESH_TOKEN,
),
settings = AccountJson.Settings(
environmentUrlData = null,
@ -937,7 +1036,7 @@ class AuthRepositoryTest {
private val ACCOUNT_2 = AccountJson(
profile = AccountJson.Profile(
userId = USER_ID_2,
email = "test2@bitwarden.com",
email = EMAIL_2,
isEmailVerified = true,
name = "Bitwarden Tester 2",
hasPremium = false,
@ -959,6 +1058,31 @@ class AuthRepositoryTest {
environmentUrlData = null,
),
)
private val ACCOUNT_3 = AccountJson(
profile = AccountJson.Profile(
userId = USER_ID_3,
email = "test3@bitwarden.com",
isEmailVerified = true,
name = "Bitwarden Tester 3",
hasPremium = false,
stamp = null,
organizationId = null,
avatarColorHex = null,
forcePasswordResetReason = null,
kdfType = KdfTypeJson.PBKDF2_SHA256,
kdfIterations = 400000,
kdfMemory = null,
kdfParallelism = null,
userDecryptionOptions = null,
),
tokens = AccountJson.Tokens(
accessToken = ACCESS_TOKEN_3,
refreshToken = "refreshToken",
),
settings = AccountJson.Settings(
environmentUrlData = null,
),
)
private val SINGLE_USER_STATE_1 = UserStateJson(
activeUserId = USER_ID_1,
accounts = mapOf(
@ -978,6 +1102,14 @@ class AuthRepositoryTest {
USER_ID_2 to ACCOUNT_2,
),
)
private val MULTI_USER_STATE_2 = UserStateJson(
activeUserId = USER_ID_3,
accounts = mapOf(
USER_ID_1 to ACCOUNT_1,
USER_ID_2 to ACCOUNT_2,
USER_ID_3 to ACCOUNT_3,
),
)
private val VAULT_STATE = VaultState(
unlockedVaultUserIds = setOf(USER_ID_1),
)

View file

@ -0,0 +1,177 @@
package com.x8bit.bitwarden.data.auth.repository.util
import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.RefreshTokenResponseJson
import com.x8bit.bitwarden.data.auth.repository.model.JwtTokenDataJson
import io.mockk.every
import io.mockk.mockkStatic
import io.mockk.unmockkStatic
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
class RefreshTokenResponseJsonTest {
@BeforeEach
fun beforeEach() {
mockkStatic(JWT_TOKEN_UTILS_PATH)
}
@AfterEach
fun tearDown() {
unmockkStatic(JWT_TOKEN_UTILS_PATH)
}
@Test
fun `toUserState updates the previous state`() {
every { parseJwtTokenDataOrNull(ACCESS_TOKEN_UPDATED) } returns JWT_TOKEN_DATA
assertEquals(
SINGLE_USER_STATE_UPDATED,
REFRESH_TOKEN_RESPONSE.toUserStateJson(
userId = USER_ID_1,
previousUserState = SINGLE_USER_STATE,
),
)
}
@Test
fun `toUserState updates the previous state for non-active user`() {
every { parseJwtTokenDataOrNull(ACCESS_TOKEN_UPDATED) } returns JWT_TOKEN_DATA
assertEquals(
MULTI_USER_STATE_UPDATED,
REFRESH_TOKEN_RESPONSE.toUserStateJson(
userId = USER_ID_1,
previousUserState = MULTI_USER_STATE,
),
)
}
}
private const val ACCESS_TOKEN = "accessToken"
private const val ACCESS_TOKEN_UPDATED = "updatedAccessToken"
private const val REFRESH_TOKEN = "refreshToken"
private const val REFRESH_TOKEN_UPDATED = "updatedRefreshToken"
private const val USER_ID_1 = "2a135b23-e1fb-42c9-bec3-573857bc8181"
private const val USER_ID_2 = "b9d32ec0-6497-4582-9798-b350f53bfa02"
private const val JWT_TOKEN_UTILS_PATH =
"com.x8bit.bitwarden.data.auth.repository.util.JwtTokenUtilsKt"
private val JWT_TOKEN_DATA = JwtTokenDataJson(
userId = USER_ID_1,
email = "updated@bitwarden.com",
isEmailVerified = false,
name = "Updated Bitwarden Tester",
expirationAsEpochTime = 1697495714,
hasPremium = true,
authenticationMethodsReference = listOf("Application"),
)
private val REFRESH_TOKEN_RESPONSE = RefreshTokenResponseJson(
accessToken = ACCESS_TOKEN_UPDATED,
expiresIn = 3600,
refreshToken = REFRESH_TOKEN_UPDATED,
tokenType = "Bearer",
)
private val ACCOUNT_1 = AccountJson(
profile = AccountJson.Profile(
userId = USER_ID_1,
email = "test@bitwarden.com",
isEmailVerified = true,
name = "Bitwarden Tester",
hasPremium = false,
stamp = null,
organizationId = null,
avatarColorHex = null,
forcePasswordResetReason = null,
kdfType = KdfTypeJson.ARGON2_ID,
kdfIterations = 600000,
kdfMemory = 16,
kdfParallelism = 4,
userDecryptionOptions = null,
),
tokens = AccountJson.Tokens(
accessToken = ACCESS_TOKEN,
refreshToken = REFRESH_TOKEN,
),
settings = AccountJson.Settings(
environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US,
),
)
private val ACCOUNT_1_UPDATED = ACCOUNT_1.copy(
profile = ACCOUNT_1.profile.copy(
userId = JWT_TOKEN_DATA.userId,
email = JWT_TOKEN_DATA.email,
isEmailVerified = JWT_TOKEN_DATA.isEmailVerified,
name = JWT_TOKEN_DATA.name,
hasPremium = JWT_TOKEN_DATA.hasPremium,
),
tokens = AccountJson.Tokens(
accessToken = ACCESS_TOKEN_UPDATED,
refreshToken = REFRESH_TOKEN_UPDATED,
),
)
private val ACCOUNT_2 = AccountJson(
profile = AccountJson.Profile(
userId = USER_ID_2,
email = "test2@bitwarden.com",
isEmailVerified = true,
name = "Bitwarden Tester 2",
hasPremium = false,
stamp = null,
organizationId = null,
avatarColorHex = null,
forcePasswordResetReason = null,
kdfType = KdfTypeJson.PBKDF2_SHA256,
kdfIterations = 400000,
kdfMemory = null,
kdfParallelism = null,
userDecryptionOptions = null,
),
tokens = AccountJson.Tokens(
accessToken = "accessToken2",
refreshToken = "refreshToken2",
),
settings = AccountJson.Settings(
environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US,
),
)
private val SINGLE_USER_STATE = UserStateJson(
activeUserId = USER_ID_1,
accounts = mapOf(
USER_ID_1 to ACCOUNT_1,
),
)
private val SINGLE_USER_STATE_UPDATED = UserStateJson(
activeUserId = USER_ID_1,
accounts = mapOf(
USER_ID_1 to ACCOUNT_1_UPDATED,
),
)
private val MULTI_USER_STATE = UserStateJson(
activeUserId = USER_ID_2,
accounts = mapOf(
USER_ID_1 to ACCOUNT_1,
USER_ID_2 to ACCOUNT_2,
),
)
private val MULTI_USER_STATE_UPDATED = UserStateJson(
activeUserId = USER_ID_2,
accounts = mapOf(
USER_ID_1 to ACCOUNT_1_UPDATED,
USER_ID_2 to ACCOUNT_2,
),
)

View file

@ -0,0 +1,143 @@
package com.x8bit.bitwarden.data.platform.datasource.network.authenticator
import com.x8bit.bitwarden.data.auth.datasource.network.model.RefreshTokenResponseJson
import com.x8bit.bitwarden.data.auth.repository.model.JwtTokenDataJson
import com.x8bit.bitwarden.data.auth.repository.util.parseJwtTokenDataOrNull
import com.x8bit.bitwarden.data.platform.util.asFailure
import com.x8bit.bitwarden.data.platform.util.asSuccess
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
import io.mockk.mockkStatic
import io.mockk.runs
import io.mockk.unmockkStatic
import io.mockk.verify
import okhttp3.Protocol
import okhttp3.Request
import okhttp3.Response
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
class RefreshAuthenticatorTests {
private lateinit var authenticator: RefreshAuthenticator
private val authenticatorProvider: AuthenticatorProvider = mockk()
@BeforeEach
fun setup() {
authenticator = RefreshAuthenticator()
authenticator.authenticatorProvider = authenticatorProvider
mockkStatic(JWT_TOKEN_UTILS_PATH)
}
@AfterEach
fun tearDown() {
unmockkStatic(JWT_TOKEN_UTILS_PATH)
}
@Test
fun `RefreshAuthenticator returns null if the request is for a different user`() {
every { parseJwtTokenDataOrNull(JWT_ACCESS_TOKEN) } returns JTW_TOKEN
every { authenticatorProvider.activeUserId } returns "different_user_id"
assertNull(authenticator.authenticate(null, RESPONSE_401))
verify(exactly = 1) {
authenticatorProvider.activeUserId
}
}
@Test
fun `RefreshAuthenticator returns null if API has no authorization user ID`() {
every { parseJwtTokenDataOrNull(JWT_ACCESS_TOKEN) } returns null
assertNull(authenticator.authenticate(null, RESPONSE_401))
verify(exactly = 0) {
authenticatorProvider.activeUserId
authenticatorProvider.refreshAccessTokenSynchronously(any())
authenticatorProvider.logout(any())
}
}
@Suppress("MaxLineLength")
@Test
fun `RefreshAuthenticator returns null and logs out when request is for active user and refresh is failure`() {
every { parseJwtTokenDataOrNull(JWT_ACCESS_TOKEN) } returns JTW_TOKEN
every { authenticatorProvider.activeUserId } returns USER_ID
every {
authenticatorProvider.refreshAccessTokenSynchronously(USER_ID)
} returns Throwable("Fail").asFailure()
every { authenticatorProvider.logout(USER_ID) } just runs
assertNull(authenticator.authenticate(null, RESPONSE_401))
verify(exactly = 1) {
authenticatorProvider.activeUserId
authenticatorProvider.refreshAccessTokenSynchronously(USER_ID)
authenticatorProvider.logout(USER_ID)
}
}
@Suppress("MaxLineLength")
@Test
fun `RefreshAuthenticator returns updated request when request is for active user and refresh is success`() {
val newAccessToken = "newAccessToken"
val refreshResponse = RefreshTokenResponseJson(
accessToken = newAccessToken,
expiresIn = 3600,
refreshToken = "refreshToken",
tokenType = "Bearer",
)
every { parseJwtTokenDataOrNull(JWT_ACCESS_TOKEN) } returns JTW_TOKEN
every { authenticatorProvider.activeUserId } returns USER_ID
every {
authenticatorProvider.refreshAccessTokenSynchronously(USER_ID)
} returns refreshResponse.asSuccess()
val authenticatedRequest = authenticator.authenticate(null, RESPONSE_401)
// The okhttp3 Request is not a data class and does not implement equals
// so we are manually checking that the correct header is added.
assertEquals(
"Bearer $newAccessToken",
authenticatedRequest!!.header("Authorization"),
)
verify(exactly = 1) {
authenticatorProvider.activeUserId
authenticatorProvider.refreshAccessTokenSynchronously(USER_ID)
}
}
}
private const val JWT_TOKEN_UTILS_PATH =
"com.x8bit.bitwarden.data.auth.repository.util.JwtTokenUtilsKt"
private const val USER_ID = "2a135b23-e1fb-42c9-bec3-573857bc8181"
private val JTW_TOKEN = JwtTokenDataJson(
userId = USER_ID,
email = "test@bitwarden.com",
isEmailVerified = true,
name = "Bitwarden Tester",
expirationAsEpochTime = 1697495714,
hasPremium = false,
authenticationMethodsReference = listOf("Application"),
)
private const val JWT_ACCESS_TOKEN = "jwt"
private val RESPONSE_401 = Response.Builder()
.code(401)
.request(
request = Request.Builder()
.header(name = "Authorization", value = "Bearer $JWT_ACCESS_TOKEN")
.url("https://www.bitwarden.com")
.build(),
)
.protocol(Protocol.HTTP_2)
.message("Unauthenticated")
.build()

View file

@ -1,5 +1,6 @@
package com.x8bit.bitwarden.data.platform.datasource.network.retrofit
import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors
import io.mockk.every
@ -8,6 +9,7 @@ import io.mockk.slot
import kotlinx.coroutines.runBlocking
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import okhttp3.Authenticator
import okhttp3.Interceptor
import okhttp3.mockwebserver.MockResponse
import okhttp3.mockwebserver.MockWebServer
@ -35,12 +37,16 @@ class RetrofitsTest {
mockIntercept { isEventsInterceptorCalled = true }
}
}
private val refreshAuthenticator = mockk<RefreshAuthenticator> {
mockAuthenticate { isRefreshAuthenticatorCalled = true }
}
private val json = Json
private val server = MockWebServer()
private val retrofits = RetrofitsImpl(
authTokenInterceptor = authTokenInterceptor,
baseUrlInterceptors = baseUrlInterceptors,
refreshAuthenticator = refreshAuthenticator,
json = json,
)
@ -48,6 +54,7 @@ class RetrofitsTest {
private var isApiInterceptorCalled = false
private var isIdentityInterceptorCalled = false
private var isEventsInterceptorCalled = false
private var isRefreshAuthenticatorCalled = false
@Before
fun setUp() {
@ -59,6 +66,49 @@ class RetrofitsTest {
server.shutdown()
}
@Test
fun `authenticatedApiRetrofit should not invoke the RefreshAuthenticator on success`() =
runBlocking {
val testApi = retrofits
.authenticatedApiRetrofit
.createMockRetrofit()
.create<TestApi>()
server.enqueue(MockResponse().setBody("""{}"""))
testApi.test()
assertFalse(isRefreshAuthenticatorCalled)
}
@Test
fun `authenticatedApiRetrofit should invoke the RefreshAuthenticator on 401`() = runBlocking {
val testApi = retrofits
.authenticatedApiRetrofit
.createMockRetrofit()
.create<TestApi>()
server.enqueue(MockResponse().setResponseCode(401).setBody("""{}"""))
testApi.test()
assertTrue(isRefreshAuthenticatorCalled)
}
@Test
fun `unauthenticatedApiRetrofit should not invoke the RefreshAuthenticator`() = runBlocking {
val testApi = retrofits
.unauthenticatedApiRetrofit
.createMockRetrofit()
.create<TestApi>()
server.enqueue(MockResponse().setResponseCode(401).setBody("""{}"""))
testApi.test()
assertFalse(isRefreshAuthenticatorCalled)
}
@Test
fun `authenticatedApiRetrofit should invoke the correct interceptors`() = runBlocking {
val testApi = retrofits
@ -138,7 +188,18 @@ class RetrofitsTest {
interface TestApi {
@GET("/test")
suspend fun test(): JsonObject
suspend fun test(): Result<JsonObject>
}
/**
* Mocks the given [Authenticator] such that the [Authenticator.authenticate] is a no-op and
* returns `null` but triggers the [isCalledCallback].
*/
private fun Authenticator.mockAuthenticate(isCalledCallback: () -> Unit) {
every { authenticate(any(), any()) } answers {
isCalledCallback()
null
}
}
/**

View file

@ -3,6 +3,7 @@ package com.x8bit.bitwarden.data.platform.manager
import com.x8bit.bitwarden.data.auth.repository.AuthRepository
import com.x8bit.bitwarden.data.auth.repository.model.AuthState
import com.x8bit.bitwarden.data.platform.base.FakeDispatcherManager
import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
@ -22,7 +23,7 @@ class NetworkConfigManagerTest {
private val mutableAuthStateFlow = MutableStateFlow<AuthState>(AuthState.Uninitialized)
private val mutableEnvironmentStateFlow = MutableStateFlow<Environment>(Environment.Us)
private val authRepository: AuthRepository = mockk() {
private val authRepository: AuthRepository = mockk {
every { authStateFlow } returns mutableAuthStateFlow
}
@ -30,6 +31,7 @@ class NetworkConfigManagerTest {
every { environmentStateFlow } returns mutableEnvironmentStateFlow
}
private val refreshAuthenticator = RefreshAuthenticator()
private val authTokenInterceptor = AuthTokenInterceptor()
private val baseUrlInterceptors = BaseUrlInterceptors()
@ -42,10 +44,19 @@ class NetworkConfigManagerTest {
authTokenInterceptor = authTokenInterceptor,
environmentRepository = environmentRepository,
baseUrlInterceptors = baseUrlInterceptors,
refreshAuthenticator = refreshAuthenticator,
dispatcherManager = dispatcherManager,
)
}
@Test
fun `authenticatorProvider should be set on initialization`() {
assertEquals(
authRepository,
refreshAuthenticator.authenticatorProvider,
)
}
@Test
fun `changes in the AuthState should update the AuthTokenInterceptor`() {
mutableAuthStateFlow.value = AuthState.Uninitialized