PM-13688: Remove race condition from AuthTokenInterceptor (#4108)

This commit is contained in:
David Perez 2024-10-16 17:01:05 -05:00 committed by GitHub
parent 0d6a8513b2
commit 655beb9dd6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 40 additions and 56 deletions

View file

@ -58,7 +58,9 @@ object PlatformNetworkModule {
@Provides @Provides
@Singleton @Singleton
fun providesAuthTokenInterceptor(): AuthTokenInterceptor = AuthTokenInterceptor() fun providesAuthTokenInterceptor(
authDiskSource: AuthDiskSource,
): AuthTokenInterceptor = AuthTokenInterceptor(authDiskSource = authDiskSource)
@Provides @Provides
@Singleton @Singleton

View file

@ -1,5 +1,6 @@
package com.x8bit.bitwarden.data.platform.datasource.network.interceptor package com.x8bit.bitwarden.data.platform.datasource.network.interceptor
import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource
import com.x8bit.bitwarden.data.platform.datasource.network.util.HEADER_KEY_AUTHORIZATION import com.x8bit.bitwarden.data.platform.datasource.network.util.HEADER_KEY_AUTHORIZATION
import com.x8bit.bitwarden.data.platform.datasource.network.util.HEADER_VALUE_BEARER_PREFIX import com.x8bit.bitwarden.data.platform.datasource.network.util.HEADER_VALUE_BEARER_PREFIX
import okhttp3.Interceptor import okhttp3.Interceptor
@ -11,11 +12,20 @@ import javax.inject.Singleton
* Interceptor responsible for adding the auth token(Bearer) to API requests. * Interceptor responsible for adding the auth token(Bearer) to API requests.
*/ */
@Singleton @Singleton
class AuthTokenInterceptor : Interceptor { class AuthTokenInterceptor(
private val authDiskSource: AuthDiskSource,
) : Interceptor {
/** /**
* The auth token to be added to API requests. * The auth token to be added to API requests.
*
* Note: This is done on demand to ensure that no race conditions can exist when retrieving the
* token.
*/ */
var authToken: String? = null private val authToken: String?
get() = authDiskSource
.userState
?.activeUserId
?.let { userId -> authDiskSource.getAccountTokens(userId = userId)?.accessToken }
private val missingTokenMessage = "Auth token is missing!" private val missingTokenMessage = "Auth token is missing!"

View file

@ -1,9 +1,7 @@
package com.x8bit.bitwarden.data.platform.manager package com.x8bit.bitwarden.data.platform.manager
import com.x8bit.bitwarden.data.auth.repository.AuthRepository import com.x8bit.bitwarden.data.auth.repository.AuthRepository
import com.x8bit.bitwarden.data.auth.repository.model.AuthState
import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository
@ -18,10 +16,8 @@ private const val ENVIRONMENT_DEBOUNCE_TIMEOUT_MS: Long = 500L
/** /**
* Primary implementation of [NetworkConfigManager]. * Primary implementation of [NetworkConfigManager].
*/ */
@Suppress("LongParameterList")
class NetworkConfigManagerImpl( class NetworkConfigManagerImpl(
authRepository: AuthRepository, authRepository: AuthRepository,
private val authTokenInterceptor: AuthTokenInterceptor,
environmentRepository: EnvironmentRepository, environmentRepository: EnvironmentRepository,
serverConfigRepository: ServerConfigRepository, serverConfigRepository: ServerConfigRepository,
private val baseUrlInterceptors: BaseUrlInterceptors, private val baseUrlInterceptors: BaseUrlInterceptors,
@ -32,17 +28,6 @@ class NetworkConfigManagerImpl(
private val collectionScope = CoroutineScope(dispatcherManager.unconfined) private val collectionScope = CoroutineScope(dispatcherManager.unconfined)
init { init {
authRepository
.authStateFlow
.onEach { authState ->
authTokenInterceptor.authToken = when (authState) {
is AuthState.Authenticated -> authState.accessToken
is AuthState.Unauthenticated -> null
is AuthState.Uninitialized -> null
}
}
.launchIn(collectionScope)
@Suppress("OPT_IN_USAGE") @Suppress("OPT_IN_USAGE")
environmentRepository environmentRepository
.environmentStateFlow .environmentStateFlow

View file

@ -1,6 +1,7 @@
package com.x8bit.bitwarden.data.platform.manager package com.x8bit.bitwarden.data.platform.manager
import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource
import com.x8bit.bitwarden.data.auth.repository.util.activeUserIdChangesFlow
import com.x8bit.bitwarden.data.platform.datasource.disk.PushDiskSource import com.x8bit.bitwarden.data.platform.datasource.disk.PushDiskSource
import com.x8bit.bitwarden.data.platform.datasource.network.model.PushTokenRequest import com.x8bit.bitwarden.data.platform.datasource.network.model.PushTokenRequest
import com.x8bit.bitwarden.data.platform.datasource.network.service.PushService import com.x8bit.bitwarden.data.platform.datasource.network.service.PushService
@ -21,7 +22,6 @@ import com.x8bit.bitwarden.data.platform.util.decodeFromStringOrNull
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.SharedFlow import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.asSharedFlow import kotlinx.coroutines.flow.asSharedFlow
import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.launchIn import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.mapNotNull import kotlinx.coroutines.flow.mapNotNull
import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.onEach
@ -100,9 +100,8 @@ class PushManagerImpl @Inject constructor(
init { init {
authDiskSource authDiskSource
.userStateFlow .activeUserIdChangesFlow
.mapNotNull { it?.activeUserId } .mapNotNull { it }
.distinctUntilChanged()
.onEach { registerStoredPushTokenIfNecessary() } .onEach { registerStoredPushTokenIfNecessary() }
.launchIn(unconfinedScope) .launchIn(unconfinedScope)
} }

View file

@ -10,7 +10,6 @@ import com.x8bit.bitwarden.data.platform.datasource.disk.PushDiskSource
import com.x8bit.bitwarden.data.platform.datasource.disk.SettingsDiskSource import com.x8bit.bitwarden.data.platform.datasource.disk.SettingsDiskSource
import com.x8bit.bitwarden.data.platform.datasource.disk.legacy.LegacyAppCenterMigrator import com.x8bit.bitwarden.data.platform.datasource.disk.legacy.LegacyAppCenterMigrator
import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors
import com.x8bit.bitwarden.data.platform.datasource.network.service.EventService import com.x8bit.bitwarden.data.platform.datasource.network.service.EventService
import com.x8bit.bitwarden.data.platform.datasource.network.service.PushService import com.x8bit.bitwarden.data.platform.datasource.network.service.PushService
@ -185,7 +184,6 @@ object PlatformManagerModule {
@Singleton @Singleton
fun provideNetworkConfigManager( fun provideNetworkConfigManager(
authRepository: AuthRepository, authRepository: AuthRepository,
authTokenInterceptor: AuthTokenInterceptor,
environmentRepository: EnvironmentRepository, environmentRepository: EnvironmentRepository,
serverConfigRepository: ServerConfigRepository, serverConfigRepository: ServerConfigRepository,
baseUrlInterceptors: BaseUrlInterceptors, baseUrlInterceptors: BaseUrlInterceptors,
@ -194,7 +192,6 @@ object PlatformManagerModule {
): NetworkConfigManager = ): NetworkConfigManager =
NetworkConfigManagerImpl( NetworkConfigManagerImpl(
authRepository = authRepository, authRepository = authRepository,
authTokenInterceptor = authTokenInterceptor,
environmentRepository = environmentRepository, environmentRepository = environmentRepository,
serverConfigRepository = serverConfigRepository, serverConfigRepository = serverConfigRepository,
baseUrlInterceptors = baseUrlInterceptors, baseUrlInterceptors = baseUrlInterceptors,

View file

@ -1,5 +1,9 @@
package com.x8bit.bitwarden.data.platform.datasource.network.interceptor package com.x8bit.bitwarden.data.platform.datasource.network.interceptor
import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountTokensJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.disk.util.FakeAuthDiskSource
import io.mockk.mockk
import junit.framework.TestCase.assertEquals import junit.framework.TestCase.assertEquals
import okhttp3.Request import okhttp3.Request
import org.junit.Assert.assertThrows import org.junit.Assert.assertThrows
@ -9,8 +13,10 @@ import javax.inject.Singleton
@Singleton @Singleton
class AuthTokenInterceptorTest { class AuthTokenInterceptorTest {
private val interceptor: AuthTokenInterceptor = AuthTokenInterceptor() private val authDiskSource = FakeAuthDiskSource()
private val mockAuthToken = "yourAuthToken" private val interceptor: AuthTokenInterceptor = AuthTokenInterceptor(
authDiskSource = authDiskSource,
)
private val request: Request = Request private val request: Request = Request
.Builder() .Builder()
.url("http://localhost") .url("http://localhost")
@ -18,12 +24,13 @@ class AuthTokenInterceptorTest {
@Test @Test
fun `intercept should add the auth token when set`() { fun `intercept should add the auth token when set`() {
interceptor.authToken = mockAuthToken authDiskSource.userState = USER_STATE
authDiskSource.storeAccountTokens(userId = USER_ID, ACCOUNT_TOKENS)
val response = interceptor.intercept( val response = interceptor.intercept(
chain = FakeInterceptorChain(request = request), chain = FakeInterceptorChain(request = request),
) )
assertEquals( assertEquals(
"Bearer $mockAuthToken", "Bearer $ACCESS_TOKEN",
response.request.header("Authorization"), response.request.header("Authorization"),
) )
} }
@ -41,3 +48,14 @@ class AuthTokenInterceptorTest {
) )
} }
} }
private const val USER_ID: String = "user_id"
private const val ACCESS_TOKEN: String = "access_token"
private val USER_STATE: UserStateJson = UserStateJson(
activeUserId = USER_ID,
accounts = mapOf(USER_ID to mockk()),
)
private val ACCOUNT_TOKENS: AccountTokensJson = AccountTokensJson(
accessToken = ACCESS_TOKEN,
refreshToken = null,
)

View file

@ -4,7 +4,6 @@ import com.x8bit.bitwarden.data.auth.repository.AuthRepository
import com.x8bit.bitwarden.data.auth.repository.model.AuthState import com.x8bit.bitwarden.data.auth.repository.model.AuthState
import com.x8bit.bitwarden.data.platform.base.FakeDispatcherManager import com.x8bit.bitwarden.data.platform.base.FakeDispatcherManager
import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository
@ -19,7 +18,6 @@ import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.test.UnconfinedTestDispatcher import kotlinx.coroutines.test.UnconfinedTestDispatcher
import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
@ -44,7 +42,6 @@ class NetworkConfigManagerTest {
} }
private val refreshAuthenticator = RefreshAuthenticator() private val refreshAuthenticator = RefreshAuthenticator()
private val authTokenInterceptor = AuthTokenInterceptor()
private val baseUrlInterceptors = BaseUrlInterceptors() private val baseUrlInterceptors = BaseUrlInterceptors()
private lateinit var networkConfigManager: NetworkConfigManager private lateinit var networkConfigManager: NetworkConfigManager
@ -53,7 +50,6 @@ class NetworkConfigManagerTest {
fun setUp() { fun setUp() {
networkConfigManager = NetworkConfigManagerImpl( networkConfigManager = NetworkConfigManagerImpl(
authRepository = authRepository, authRepository = authRepository,
authTokenInterceptor = authTokenInterceptor,
environmentRepository = environmentRepository, environmentRepository = environmentRepository,
serverConfigRepository = serverConfigRepository, serverConfigRepository = serverConfigRepository,
baseUrlInterceptors = baseUrlInterceptors, baseUrlInterceptors = baseUrlInterceptors,
@ -62,29 +58,6 @@ class NetworkConfigManagerTest {
) )
} }
@Test
fun `authenticatorProvider should be set on initialization`() {
assertEquals(
authRepository,
refreshAuthenticator.authenticatorProvider,
)
}
@Test
fun `changes in the AuthState should update the AuthTokenInterceptor`() {
mutableAuthStateFlow.value = AuthState.Uninitialized
assertNull(authTokenInterceptor.authToken)
mutableAuthStateFlow.value = AuthState.Authenticated(accessToken = "accessToken")
assertEquals(
"accessToken",
authTokenInterceptor.authToken,
)
mutableAuthStateFlow.value = AuthState.Unauthenticated
assertNull(authTokenInterceptor.authToken)
}
@Test @Test
fun `changes in the Environment should update the BaseUrlInterceptors`() { fun `changes in the Environment should update the BaseUrlInterceptors`() {
mutableEnvironmentStateFlow.value = Environment.Us mutableEnvironmentStateFlow.value = Environment.Us