BIT-765: Add access token storage (#138)

This commit is contained in:
Brian Yencho 2023-10-20 14:46:01 -05:00 committed by Álison Fernandes
parent e9b8bd2e78
commit f5619d1710
15 changed files with 963 additions and 43 deletions

View file

@ -1,5 +1,8 @@
package com.x8bit.bitwarden.data.auth.datasource.disk
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import kotlinx.coroutines.flow.Flow
/**
* Primary access point for disk information.
*/
@ -8,4 +11,14 @@ interface AuthDiskSource {
* The currently persisted saved email address (or `null` if not set).
*/
var rememberedEmailAddress: String?
/**
* The currently persisted user state information (or `null` if not set).
*/
var userState: UserStateJson?
/**
* Emits updates that track [userState]. This will replay the last known value, if any.
*/
val userStateFlow: Flow<UserStateJson?>
}

View file

@ -1,21 +1,72 @@
package com.x8bit.bitwarden.data.auth.datasource.disk
import android.content.SharedPreferences
import android.content.SharedPreferences.OnSharedPreferenceChangeListener
import androidx.core.content.edit
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.onSubscription
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
private const val REMEMBERED_EMAIL_ADDRESS_KEY = "bwPreferencesStorage:rememberedEmail"
private const val BASE_KEY = "bwPreferencesStorage"
private const val REMEMBERED_EMAIL_ADDRESS_KEY = "$BASE_KEY:rememberedEmail"
private const val STATE_KEY = "$BASE_KEY:state"
/**
* Primary implementation of [AuthDiskSource].
*/
class AuthDiskSourceImpl(
private val sharedPreferences: SharedPreferences,
private val json: Json,
) : AuthDiskSource {
override var rememberedEmailAddress: String?
get() = sharedPreferences.getString(REMEMBERED_EMAIL_ADDRESS_KEY, null)
get() = getString(key = REMEMBERED_EMAIL_ADDRESS_KEY)
set(value) {
sharedPreferences
.edit()
.putString(REMEMBERED_EMAIL_ADDRESS_KEY, value)
.apply()
putString(
key = REMEMBERED_EMAIL_ADDRESS_KEY,
value = value,
)
}
override var userState: UserStateJson?
get() = getString(key = STATE_KEY)?.let { json.decodeFromString(it) }
set(value) {
putString(
key = STATE_KEY,
value = value?.let { json.encodeToString(value) },
)
}
override val userStateFlow: Flow<UserStateJson?>
get() = mutableUserStateFlow
.onSubscription { emit(userState) }
private val mutableUserStateFlow = MutableSharedFlow<UserStateJson?>(
replay = 1,
extraBufferCapacity = Int.MAX_VALUE,
)
private val onSharedPreferenceChangeListener =
OnSharedPreferenceChangeListener { _, key ->
when (key) {
STATE_KEY -> mutableUserStateFlow.tryEmit(userState)
}
}
init {
sharedPreferences
.registerOnSharedPreferenceChangeListener(onSharedPreferenceChangeListener)
}
private fun getString(
key: String,
default: String? = null,
): String? = sharedPreferences.getString(key, default)
private fun putString(
key: String,
value: String?,
): Unit = sharedPreferences.edit { putString(key, value) }
}

View file

@ -7,6 +7,7 @@ import dagger.Module
import dagger.Provides
import dagger.hilt.InstallIn
import dagger.hilt.components.SingletonComponent
import kotlinx.serialization.json.Json
import javax.inject.Singleton
/**
@ -20,6 +21,10 @@ object DiskModule {
@Singleton
fun provideAuthDiskSource(
sharedPreferences: SharedPreferences,
json: Json,
): AuthDiskSource =
AuthDiskSourceImpl(sharedPreferences = sharedPreferences)
AuthDiskSourceImpl(
sharedPreferences = sharedPreferences,
json = json,
)
}

View file

@ -0,0 +1,116 @@
package com.x8bit.bitwarden.data.auth.datasource.disk.model
import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.UserDecryptionOptionsJson
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
/**
* Represents the current account information for a given user.
*
* @property profile Information about a user's personal profile.
* @property tokens Information about a user's access tokens.
* @property settings Information about a user's app settings.
*/
@Serializable
data class AccountJson(
@SerialName("profile")
val profile: Profile,
@SerialName("tokens")
val tokens: Tokens,
@SerialName("settings")
val settings: Settings,
) {
/**
* Represents a user's personal profile.
*
* @property userId The ID of the user.
* @property email The user's email address.
* @property isEmailVerified Whether or not the user's email is verified.
* @property name The user's name (if applicable).
* @property stamp The account's security stamp (if applicable).
* @property organizationId The ID of the associated organization (if applicable).
* @property hasPremium True if the user has a premium account.
* @property avatarColorHex Hex color value for a user's avatar in the "#AARRGGBB" format.
* @property forcePasswordResetReason Describes the reason for a forced password reset.
* @property kdfType The KDF type.
* @property kdfIterations The number of iterations when calculating a user's password.
* @property kdfMemory The amount of memory to use when calculating a password hash (MB).
* @property kdfParallelism The number of threads to use when calculating a password hash.
* @property userDecryptionOptions The options available to a user for decryption.
*/
@Serializable
data class Profile(
@SerialName("userId")
val userId: String,
@SerialName("email")
val email: String,
@SerialName("emailVerified")
val isEmailVerified: Boolean?,
@SerialName("name")
val name: String?,
@SerialName("stamp")
val stamp: String?,
@SerialName("orgIdentifier")
val organizationId: String?,
@SerialName("avatarColor")
val avatarColorHex: String?,
@SerialName("hasPremiumPersonally")
val hasPremium: Boolean?,
@SerialName("forcePasswordResetReason")
val forcePasswordResetReason: ForcePasswordResetReason?,
@SerialName("kdfType")
val kdfType: KdfTypeJson?,
@SerialName("kdfIterations")
val kdfIterations: Int?,
@SerialName("kdfMemory")
val kdfMemory: Int?,
@SerialName("kdfParallelism")
val kdfParallelism: Int?,
@SerialName("accountDecryptionOptions")
val userDecryptionOptions: UserDecryptionOptionsJson?,
)
/**
* Container for the user's API tokens.
*
* @property accessToken The user's primary access token.
* @property refreshToken The user's refresh token.
*/
@Serializable
data class Tokens(
@SerialName("accessToken")
val accessToken: String,
@SerialName("refreshToken")
val refreshToken: String,
)
/**
* Container for various user settings.
*
* @property environmentUrlData Data concerning the current environment associated with the
* current user.
*/
@Serializable
data class Settings(
@SerialName("environmentUrls")
val environmentUrlData: EnvironmentUrlDataJson?,
)
}

View file

@ -0,0 +1,53 @@
package com.x8bit.bitwarden.data.auth.datasource.disk.model
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
/**
* Represents URLs for various Bitwarden domains.
*
* @property base The overall base URL.
* @property api Separate base URL for the "/api" domain (if applicable).
* @property identity Separate base URL for the "/identity" domain (if applicable).
* @property icon Separate base URL for the icon domain (if applicable).
* @property notifications Separate base URL for the notifications domain (if applicable).
* @property webVault Separate base URL for the web vault domain (if applicable).
* @property events Separate base URL for the events domain (if applicable).
*/
@Serializable
data class EnvironmentUrlDataJson(
@SerialName("base")
val base: String,
@SerialName("api")
val api: String? = null,
@SerialName("identity")
val identity: String? = null,
@SerialName("icon")
val icon: String? = null,
@SerialName("notifications")
val notifications: String? = null,
@SerialName("webVault")
val webVault: String? = null,
@SerialName("events")
val events: String? = null,
) {
companion object {
/**
* Default [EnvironmentUrlDataJson] for the US region.
*/
val DEFAULT_US: EnvironmentUrlDataJson =
EnvironmentUrlDataJson(base = "https://vault.bitwarden.com")
/**
* Default [EnvironmentUrlDataJson] for the EU region.
*/
val DEFAULT_EU: EnvironmentUrlDataJson =
EnvironmentUrlDataJson(base = "https://vault.bitwarden.eu")
}
}

View file

@ -0,0 +1,23 @@
package com.x8bit.bitwarden.data.auth.datasource.disk.model
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
/**
* Describes the reason for a forced password reset.
*/
@Serializable
enum class ForcePasswordResetReason {
/**
* An organization admin forced a user to reset their password.
*/
@SerialName("adminForcePasswordReset")
ADMIN_FORCE_PASSWORD_RESET,
/**
* A user logged in with a master password that does not meet an organization's master password
* policy that is enforced on login.
*/
@SerialName("weakMasterPasswordOnLogin")
WEAK_MASTER_PASSWORD_ON_LOGIN,
}

View file

@ -0,0 +1,32 @@
package com.x8bit.bitwarden.data.auth.datasource.disk.model
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
/**
* Represents the overall "user state" of the current active user as well as any users that may be
* switched to.
*
* @property activeUserId The ID of the current active user.
* @property accounts A mapping between user IDs and the [AccountJson] information associated with
* that user.
*/
@Serializable
data class UserStateJson(
@SerialName("activeUserId")
val activeUserId: String,
@SerialName("accounts")
val accounts: Map<String, AccountJson>,
) {
init {
requireNotNull(accounts[activeUserId])
}
/**
* The current active account.
*/
@Suppress("UnsafeCallOnNullableType")
val activeAccount: AccountJson
get() = accounts[activeUserId]!!
}

View file

@ -15,16 +15,20 @@ import com.x8bit.bitwarden.data.auth.repository.model.AuthState
import com.x8bit.bitwarden.data.auth.repository.model.LoginResult
import com.x8bit.bitwarden.data.auth.repository.model.RegisterResult
import com.x8bit.bitwarden.data.auth.repository.util.CaptchaCallbackTokenResult
import com.x8bit.bitwarden.data.auth.repository.util.toUserState
import com.x8bit.bitwarden.data.auth.util.toSdkParams
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor
import com.x8bit.bitwarden.data.platform.util.flatMap
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asSharedFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.stateIn
import javax.inject.Inject
import javax.inject.Singleton
@ -40,10 +44,38 @@ class AuthRepositoryImpl @Inject constructor(
private val authSdkSource: AuthSdkSource,
private val authDiskSource: AuthDiskSource,
private val authTokenInterceptor: AuthTokenInterceptor,
dispatcher: CoroutineDispatcher,
) : AuthRepository {
private val scope = CoroutineScope(dispatcher)
private val mutableAuthStateFlow = MutableStateFlow<AuthState>(AuthState.Unauthenticated)
override val authStateFlow: StateFlow<AuthState> = mutableAuthStateFlow.asStateFlow()
override val authStateFlow: StateFlow<AuthState> = authDiskSource
.userStateFlow
.map { userState ->
userState
?.let {
@Suppress("UnsafeCallOnNullableType")
AuthState.Authenticated(
userState
.activeAccount
.tokens
.accessToken,
)
}
?: AuthState.Unauthenticated
}
.onEach {
// TODO: Create intermediate class for providing auth token to interceptor (BIT-411)
authTokenInterceptor.authToken = when (it) {
is AuthState.Authenticated -> it.accessToken
AuthState.Unauthenticated -> null
AuthState.Uninitialized -> null
}
}
.stateIn(
scope = scope,
started = SharingStarted.Eagerly,
initialValue = AuthState.Uninitialized,
)
private val mutableCaptchaTokenFlow =
MutableSharedFlow<CaptchaCallbackTokenResult>(extraBufferCapacity = Int.MAX_VALUE)
@ -85,10 +117,10 @@ class AuthRepositoryImpl @Inject constructor(
when (it) {
is CaptchaRequired -> LoginResult.CaptchaRequired(it.captchaKey)
is Success -> {
// TODO: Create intermediate class for providing auth token
// to interceptor (BIT-411)
authTokenInterceptor.authToken = it.accessToken
mutableAuthStateFlow.value = AuthState.Authenticated(it.accessToken)
authDiskSource.userState = it
.toUserState(
previousUserState = authDiskSource.userState,
)
LoginResult.Success
}
@ -100,7 +132,29 @@ class AuthRepositoryImpl @Inject constructor(
)
override fun logout() {
mutableAuthStateFlow.update { AuthState.Unauthenticated }
val currentUserState = authDiskSource.userState ?: return
val activeUserId = currentUserState.activeUserId
// Remove the active user from the accounts map
val updatedAccounts = currentUserState
.accounts
.filterKeys { it != activeUserId }
// Check if there is a new active user
if (updatedAccounts.isNotEmpty()) {
val (updatedActiveUserId, updatedActiveAccount) =
updatedAccounts.entries.first()
// Update the user information and emit an updated token
authDiskSource.userState = currentUserState.copy(
activeUserId = updatedActiveUserId,
accounts = updatedAccounts,
)
} else {
// Update the user information and log out
authDiskSource.userState = null
}
}
override suspend fun register(

View file

@ -1,19 +1,40 @@
package com.x8bit.bitwarden.data.auth.repository.di
import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource
import com.x8bit.bitwarden.data.auth.datasource.network.service.AccountsService
import com.x8bit.bitwarden.data.auth.datasource.network.service.IdentityService
import com.x8bit.bitwarden.data.auth.datasource.sdk.AuthSdkSource
import com.x8bit.bitwarden.data.auth.repository.AuthRepository
import com.x8bit.bitwarden.data.auth.repository.AuthRepositoryImpl
import dagger.Binds
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor
import dagger.Module
import dagger.Provides
import dagger.hilt.InstallIn
import dagger.hilt.components.SingletonComponent
import kotlinx.coroutines.Dispatchers
import javax.inject.Singleton
/**
* Provides repositories in the auth package.
*/
@Module
@InstallIn(SingletonComponent::class)
abstract class RepositoryModule {
object RepositoryModule {
@Binds
abstract fun bindsAuthRepository(authRepositoryImpl: AuthRepositoryImpl): AuthRepository
@Provides
@Singleton
fun bindsAuthRepository(
accountsService: AccountsService,
identityService: IdentityService,
authSdkSource: AuthSdkSource,
authDiskSource: AuthDiskSource,
authTokenInterceptor: AuthTokenInterceptor,
): AuthRepository = AuthRepositoryImpl(
accountsService = accountsService,
identityService = identityService,
authSdkSource = authSdkSource,
authDiskSource = authDiskSource,
authTokenInterceptor = authTokenInterceptor,
dispatcher = Dispatchers.IO,
)
}

View file

@ -0,0 +1,62 @@
package com.x8bit.bitwarden.data.auth.repository.util
import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson
/**
* Converts the given [GetTokenResponseJson.Success] to a [UserStateJson], given the current
* [previousUserState].
*/
fun GetTokenResponseJson.Success.toUserState(
previousUserState: UserStateJson?,
): UserStateJson {
val accessToken = this.accessToken
@Suppress("UnsafeCallOnNullableType")
val jwtTokenData = parseJwtTokenDataOrNull(jwtToken = accessToken)!!
val userId = jwtTokenData.userId
// TODO: Update null properties below via sync request (BIT-916)
val account = AccountJson(
profile = AccountJson.Profile(
userId = userId,
email = jwtTokenData.email,
isEmailVerified = jwtTokenData.isEmailVerified,
name = jwtTokenData.name,
stamp = null,
organizationId = null,
avatarColorHex = null,
hasPremium = jwtTokenData.hasPremium,
forcePasswordResetReason = null,
kdfType = this.kdfType,
kdfIterations = this.kdfIterations,
kdfMemory = this.kdfMemory,
kdfParallelism = this.kdfParallelism,
userDecryptionOptions = this.userDecryptionOptions,
),
tokens = AccountJson.Tokens(
accessToken = accessToken,
refreshToken = this.refreshToken,
),
settings = AccountJson.Settings(
environmentUrlData = null,
),
)
// Create a new UserState with the updated info or update the existing one.
return previousUserState
?.copy(
activeUserId = userId,
accounts = previousUserState
.accounts
.toMutableMap()
.apply {
put(userId, account)
},
)
?: UserStateJson(
activeUserId = userId,
accounts = mapOf(userId to account),
)
}

View file

@ -9,6 +9,7 @@ import dagger.Module
import dagger.Provides
import dagger.hilt.InstallIn
import dagger.hilt.components.SingletonComponent
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json
import okhttp3.MediaType.Companion.toMediaType
import okhttp3.OkHttpClient
@ -85,6 +86,7 @@ object NetworkModule {
)
.build()
@OptIn(ExperimentalSerializationApi::class)
@Provides
@Singleton
fun providesJson(): Json = Json {
@ -93,5 +95,8 @@ object NetworkModule {
// ignore them.
// This makes additive server changes non-breaking.
ignoreUnknownKeys = true
// We allow for nullable values to have keys missing in the JSON response.
explicitNulls = false
}
}

View file

@ -1,6 +1,18 @@
package com.x8bit.bitwarden.data.auth.datasource.disk
import app.cash.turbine.test
import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.ForcePasswordResetReason
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.KeyConnectorUserDecryptionOptionsJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.TrustedDeviceUserDecryptionOptionsJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.UserDecryptionOptionsJson
import com.x8bit.bitwarden.data.platform.base.FakeSharedPreferences
import kotlinx.coroutines.test.runTest
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Test
@ -8,8 +20,15 @@ import org.junit.jupiter.api.Test
class AuthDiskSourceTest {
private val fakeSharedPreferences = FakeSharedPreferences()
@OptIn(ExperimentalSerializationApi::class)
private val json = Json {
ignoreUnknownKeys = true
explicitNulls = false
}
private val authDiskSource = AuthDiskSourceImpl(
sharedPreferences = fakeSharedPreferences,
json = json,
)
@Test
@ -31,4 +50,145 @@ class AuthDiskSourceTest {
fakeSharedPreferences.edit().putString(rememberedEmailKey, null).apply()
assertNull(authDiskSource.rememberedEmailAddress)
}
@Test
fun `userState should pull from and update SharedPreferences`() {
val userStateKey = "bwPreferencesStorage:state"
// Shared preferences and the repository start with the same value.
assertNull(authDiskSource.userState)
assertNull(fakeSharedPreferences.getString(userStateKey, null))
// Updating the repository updates shared preferences
authDiskSource.userState = USER_STATE
assertEquals(
json.parseToJsonElement(
USER_STATE_JSON,
),
json.parseToJsonElement(
fakeSharedPreferences.getString(userStateKey, null)!!,
),
)
// Update SharedPreferences updates the repository
fakeSharedPreferences.edit().putString(userStateKey, null).apply()
assertNull(authDiskSource.userState)
}
@Test
fun `userStateFlow should react to changes in userState`() = runTest {
authDiskSource.userStateFlow.test {
// The initial values of the Flow and the property are in sync
assertNull(authDiskSource.userState)
assertNull(awaitItem())
// Updating the repository updates shared preferences
authDiskSource.userState = USER_STATE
assertEquals(USER_STATE, awaitItem())
}
}
}
private const val USER_STATE_JSON = """
{
"activeUserId": "activeUserId",
"accounts": {
"activeUserId": {
"profile": {
"userId": "activeUserId",
"email": "email",
"emailVerified": true,
"name": "name",
"stamp": "stamp",
"orgIdentifier": "organizationId",
"avatarColor": "avatarColorHex",
"hasPremiumPersonally": true,
"forcePasswordResetReason": "adminForcePasswordReset",
"kdfType": 1,
"kdfIterations": 600000,
"kdfMemory": 16,
"kdfParallelism": 4,
"accountDecryptionOptions": {
"HasMasterPassword": true,
"TrustedDeviceOption": {
"EncryptedPrivateKey": "encryptedPrivateKey",
"EncryptedUserKey": "encryptedUserKey",
"HasAdminApproval": true,
"HasLoginApprovingDevice": true,
"HasManageResetPasswordPermission": true
},
"KeyConnectorOption": {
"KeyConnectorUrl": "keyConnectorUrl"
}
}
},
"tokens": {
"accessToken": "accessToken",
"refreshToken": "refreshToken"
},
"settings": {
"environmentUrls": {
"base": "base",
"api": "api",
"identity": "identity",
"icon": "icon",
"notifications": "notifications",
"webVault": "webVault",
"events": "events"
}
}
}
}
}
"""
private val USER_STATE = UserStateJson(
activeUserId = "activeUserId",
accounts = mapOf(
"activeUserId" to AccountJson(
profile = AccountJson.Profile(
userId = "activeUserId",
email = "email",
isEmailVerified = true,
name = "name",
stamp = "stamp",
organizationId = "organizationId",
avatarColorHex = "avatarColorHex",
hasPremium = true,
forcePasswordResetReason = ForcePasswordResetReason.ADMIN_FORCE_PASSWORD_RESET,
kdfType = KdfTypeJson.ARGON2_ID,
kdfIterations = 600000,
kdfMemory = 16,
kdfParallelism = 4,
userDecryptionOptions = UserDecryptionOptionsJson(
hasMasterPassword = true,
trustedDeviceUserDecryptionOptions = TrustedDeviceUserDecryptionOptionsJson(
encryptedPrivateKey = "encryptedPrivateKey",
encryptedUserKey = "encryptedUserKey",
hasAdminApproval = true,
hasLoginApprovingDevice = true,
hasManageResetPasswordPermission = true,
),
keyConnectorUserDecryptionOptions = KeyConnectorUserDecryptionOptionsJson(
keyConnectorUrl = "keyConnectorUrl",
),
),
),
tokens = AccountJson.Tokens(
accessToken = "accessToken",
refreshToken = "refreshToken",
),
settings = AccountJson.Settings(
environmentUrlData = EnvironmentUrlDataJson(
base = "base",
api = "api",
identity = "identity",
icon = "icon",
notifications = "notifications",
webVault = "webVault",
events = "events",
),
),
),
),
)

View file

@ -5,7 +5,10 @@ import com.bitwarden.core.Kdf
import com.bitwarden.core.RegisterKeyResponse
import com.bitwarden.core.RsaKeyPair
import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource
import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson.PBKDF2_SHA256
import com.x8bit.bitwarden.data.auth.datasource.network.model.PreLoginResponseJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.RegisterRequestJson
@ -17,6 +20,7 @@ import com.x8bit.bitwarden.data.auth.repository.model.AuthState
import com.x8bit.bitwarden.data.auth.repository.model.LoginResult
import com.x8bit.bitwarden.data.auth.repository.model.RegisterResult
import com.x8bit.bitwarden.data.auth.repository.util.CaptchaCallbackTokenResult
import com.x8bit.bitwarden.data.auth.repository.util.toUserState
import com.x8bit.bitwarden.data.auth.util.toSdkParams
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor
import io.mockk.clearMocks
@ -24,8 +28,15 @@ import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import io.mockk.verify
import io.mockk.mockkStatic
import io.mockk.unmockkStatic
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.onSubscription
import kotlinx.coroutines.test.UnconfinedTestDispatcher
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.BeforeEach
@ -35,7 +46,7 @@ class AuthRepositoryTest {
private val accountsService: AccountsService = mockk()
private val identityService: IdentityService = mockk()
private val authInterceptor = mockk<AuthTokenInterceptor>()
private val authInterceptor = AuthTokenInterceptor()
private val fakeAuthDiskSource = FakeAuthDiskSource()
private val authSdkSource = mockk<AuthSdkSource> {
coEvery {
@ -63,17 +74,25 @@ class AuthRepositoryTest {
)
}
@OptIn(ExperimentalCoroutinesApi::class)
private val repository = AuthRepositoryImpl(
accountsService = accountsService,
identityService = identityService,
authSdkSource = authSdkSource,
authDiskSource = fakeAuthDiskSource,
authTokenInterceptor = authInterceptor,
dispatcher = UnconfinedTestDispatcher(),
)
@BeforeEach
fun beforeEach() {
clearMocks(identityService, accountsService, authInterceptor)
clearMocks(identityService, accountsService)
mockkStatic(GET_TOKEN_RESPONSE_EXTENSIONS_PATH)
}
@AfterEach
fun tearDown() {
unmockkStatic(GET_TOKEN_RESPONSE_EXTENSIONS_PATH)
}
@Test
@ -162,9 +181,7 @@ class AuthRepositoryTest {
@Test
fun `login get token succeeds should return Success and update AuthState`() = runTest {
val successResponse = mockk<GetTokenResponseJson.Success> {
every { accessToken } returns ACCESS_TOKEN
}
val successResponse = GET_TOKEN_RESPONSE_SUCCESS
coEvery {
accountsService.preLogin(email = EMAIL)
} returns Result.success(PRE_LOGIN_SUCCESS)
@ -176,11 +193,13 @@ class AuthRepositoryTest {
)
}
.returns(Result.success(successResponse))
every { authInterceptor.authToken = ACCESS_TOKEN } returns Unit
every {
GET_TOKEN_RESPONSE_SUCCESS.toUserState(previousUserState = null)
} returns SINGLE_USER_STATE_1
val result = repository.login(email = EMAIL, password = PASSWORD, captchaToken = null)
assertEquals(LoginResult.Success, result)
assertEquals(AuthState.Authenticated(ACCESS_TOKEN), repository.authStateFlow.value)
verify { authInterceptor.authToken = ACCESS_TOKEN }
assertEquals(ACCESS_TOKEN, authInterceptor.authToken)
coVerify { accountsService.preLogin(email = EMAIL) }
coVerify {
identityService.getToken(
@ -365,11 +384,9 @@ class AuthRepositoryTest {
}
@Test
fun `logout should change AuthState to be Unauthenticated`() = runTest {
fun `logout for single account should clear the access token`() = runTest {
// First login:
val successResponse = mockk<GetTokenResponseJson.Success> {
every { accessToken } returns ACCESS_TOKEN
}
val successResponse = GET_TOKEN_RESPONSE_SUCCESS
coEvery {
accountsService.preLogin(email = EMAIL)
} returns Result.success(PRE_LOGIN_SUCCESS)
@ -379,36 +396,189 @@ class AuthRepositoryTest {
passwordHash = PASSWORD_HASH,
captchaToken = null,
)
}
.returns(Result.success(successResponse))
} returns Result.success(successResponse)
every {
GET_TOKEN_RESPONSE_SUCCESS.toUserState(previousUserState = null)
} returns SINGLE_USER_STATE_1
every { authInterceptor.authToken = ACCESS_TOKEN } returns Unit
repository.login(email = EMAIL, password = PASSWORD, captchaToken = null)
assertEquals(AuthState.Authenticated(ACCESS_TOKEN), repository.authStateFlow.value)
assertEquals(ACCESS_TOKEN, authInterceptor.authToken)
assertEquals(SINGLE_USER_STATE_1, fakeAuthDiskSource.userState)
// Then call logout:
repository.authStateFlow.test {
assertEquals(AuthState.Authenticated(ACCESS_TOKEN), awaitItem())
repository.logout()
assertEquals(AuthState.Unauthenticated, awaitItem())
assertNull(authInterceptor.authToken)
assertNull(fakeAuthDiskSource.userState)
}
}
@Test
fun `logout for multiple accounts should update current access token`() = runTest {
// First populate multiple user accounts
fakeAuthDiskSource.userState = SINGLE_USER_STATE_2
// Then login:
val successResponse = GET_TOKEN_RESPONSE_SUCCESS
coEvery {
accountsService.preLogin(email = EMAIL)
} returns Result.success(PRE_LOGIN_SUCCESS)
coEvery {
identityService.getToken(
email = EMAIL,
passwordHash = PASSWORD_HASH,
captchaToken = null,
)
} returns Result.success(successResponse)
every {
GET_TOKEN_RESPONSE_SUCCESS.toUserState(previousUserState = SINGLE_USER_STATE_2)
} returns MULTI_USER_STATE
repository.login(email = EMAIL, password = PASSWORD, captchaToken = null)
assertEquals(AuthState.Authenticated(ACCESS_TOKEN), repository.authStateFlow.value)
assertEquals(ACCESS_TOKEN, authInterceptor.authToken)
assertEquals(MULTI_USER_STATE, fakeAuthDiskSource.userState)
// Then call logout:
repository.authStateFlow.test {
assertEquals(AuthState.Authenticated(ACCESS_TOKEN), awaitItem())
repository.logout()
assertEquals(AuthState.Authenticated(ACCESS_TOKEN_2), awaitItem())
assertEquals(ACCESS_TOKEN_2, authInterceptor.authToken)
assertEquals(SINGLE_USER_STATE_2, fakeAuthDiskSource.userState)
}
}
companion object {
private const val EMAIL = "test@test.com"
private const val GET_TOKEN_RESPONSE_EXTENSIONS_PATH =
"com.x8bit.bitwarden.data.auth.repository.util.GetTokenResponseExtensionsKt"
private const val EMAIL = "test@bitwarden.com"
private const val PASSWORD = "password"
private const val PASSWORD_HASH = "passwordHash"
private const val ACCESS_TOKEN = "accessToken"
private const val ACCESS_TOKEN_2 = "accessToken2"
private const val CAPTCHA_KEY = "captcha"
private const val DEFAULT_KDF_ITERATIONS = 600000
private const val ENCRYPTED_USER_KEY = "encryptedUserKey"
private const val PUBLIC_KEY = "PublicKey"
private const val PRIVATE_KEY = "privateKey"
private const val USER_ID_1 = "2a135b23-e1fb-42c9-bec3-573857bc8181"
private const val USER_ID_2 = "b9d32ec0-6497-4582-9798-b350f53bfa02"
private val PRE_LOGIN_SUCCESS = PreLoginResponseJson(
kdfParams = PreLoginResponseJson.KdfParams.Pbkdf2(iterations = 1u),
)
private val GET_TOKEN_RESPONSE_SUCCESS = GetTokenResponseJson.Success(
accessToken = ACCESS_TOKEN,
refreshToken = "refreshToken",
tokenType = "Bearer",
expiresInSeconds = 3600,
key = "key",
kdfType = KdfTypeJson.ARGON2_ID,
kdfIterations = 600000,
kdfMemory = 16,
kdfParallelism = 4,
privateKey = "privateKey",
shouldForcePasswordReset = true,
shouldResetMasterPassword = true,
masterPasswordPolicyOptions = null,
userDecryptionOptions = null,
)
private val ACCOUNT_1 = AccountJson(
profile = AccountJson.Profile(
userId = USER_ID_1,
email = "test@bitwarden.com",
isEmailVerified = true,
name = "Bitwarden Tester",
hasPremium = false,
stamp = null,
organizationId = null,
avatarColorHex = null,
forcePasswordResetReason = null,
kdfType = KdfTypeJson.ARGON2_ID,
kdfIterations = 600000,
kdfMemory = 16,
kdfParallelism = 4,
userDecryptionOptions = null,
),
tokens = AccountJson.Tokens(
accessToken = ACCESS_TOKEN,
refreshToken = "refreshToken",
),
settings = AccountJson.Settings(
environmentUrlData = null,
),
)
private val ACCOUNT_2 = AccountJson(
profile = AccountJson.Profile(
userId = USER_ID_2,
email = "test2@bitwarden.com",
isEmailVerified = true,
name = "Bitwarden Tester 2",
hasPremium = false,
stamp = null,
organizationId = null,
avatarColorHex = null,
forcePasswordResetReason = null,
kdfType = KdfTypeJson.PBKDF2_SHA256,
kdfIterations = 400000,
kdfMemory = null,
kdfParallelism = null,
userDecryptionOptions = null,
),
tokens = AccountJson.Tokens(
accessToken = ACCESS_TOKEN_2,
refreshToken = "refreshToken",
),
settings = AccountJson.Settings(
environmentUrlData = null,
),
)
private val SINGLE_USER_STATE_1 = UserStateJson(
activeUserId = USER_ID_1,
accounts = mapOf(
USER_ID_1 to ACCOUNT_1,
),
)
private val SINGLE_USER_STATE_2 = UserStateJson(
activeUserId = USER_ID_2,
accounts = mapOf(
USER_ID_2 to ACCOUNT_2,
),
)
private val MULTI_USER_STATE = UserStateJson(
activeUserId = USER_ID_1,
accounts = mapOf(
USER_ID_1 to ACCOUNT_1,
USER_ID_2 to ACCOUNT_2,
),
)
}
}
private class FakeAuthDiskSource : AuthDiskSource {
override var rememberedEmailAddress: String? = null
override var userState: UserStateJson? = null
set(value) {
field = value
mutableUserStateFlow.tryEmit(value)
}
override val userStateFlow: Flow<UserStateJson?>
get() = mutableUserStateFlow.onSubscription { emit(userState) }
private val mutableUserStateFlow =
MutableSharedFlow<UserStateJson?>(
replay = 1,
extraBufferCapacity = Int.MAX_VALUE,
)
}

View file

@ -0,0 +1,151 @@
package com.x8bit.bitwarden.data.auth.repository.util
import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson
import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson
import com.x8bit.bitwarden.data.auth.repository.model.JwtTokenDataJson
import io.mockk.every
import io.mockk.mockkStatic
import io.mockk.unmockkStatic
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
class GetTokenResponseExtensionsTest {
@BeforeEach
fun beforeEach() {
mockkStatic(JWT_TOKEN_UTILS_PATH)
}
@AfterEach
fun tearDown() {
unmockkStatic(JWT_TOKEN_UTILS_PATH)
}
@Test
fun `toUserState with a null previous state creates a new single user state`() {
every { parseJwtTokenDataOrNull(ACCESS_TOKEN_1) } returns JWT_TOKEN_DATA
assertEquals(
SINGLE_USER_STATE_1,
GET_TOKEN_RESPONSE_SUCCESS.toUserState(previousUserState = null),
)
}
@Test
fun `toUserState with a non-null previous state updates the previous state`() {
every { parseJwtTokenDataOrNull(ACCESS_TOKEN_1) } returns JWT_TOKEN_DATA
assertEquals(
MULTI_USER_STATE,
GET_TOKEN_RESPONSE_SUCCESS.toUserState(previousUserState = SINGLE_USER_STATE_2),
)
}
}
private const val ACCESS_TOKEN_1 = "accessToken1"
private const val ACCESS_TOKEN_2 = "accessToken2"
private const val USER_ID_1 = "2a135b23-e1fb-42c9-bec3-573857bc8181"
private const val USER_ID_2 = "b9d32ec0-6497-4582-9798-b350f53bfa02"
private const val JWT_TOKEN_UTILS_PATH =
"com.x8bit.bitwarden.data.auth.repository.util.JwtTokenUtilsKt"
private val JWT_TOKEN_DATA = JwtTokenDataJson(
userId = "2a135b23-e1fb-42c9-bec3-573857bc8181",
email = "test@bitwarden.com",
isEmailVerified = true,
name = "Bitwarden Tester",
expirationAsEpochTime = 1697495714,
hasPremium = false,
authenticationMethodsReference = listOf("Application"),
)
private val GET_TOKEN_RESPONSE_SUCCESS = GetTokenResponseJson.Success(
accessToken = ACCESS_TOKEN_1,
refreshToken = "refreshToken",
tokenType = "Bearer",
expiresInSeconds = 3600,
key = "key",
kdfType = KdfTypeJson.ARGON2_ID,
kdfIterations = 600000,
kdfMemory = 16,
kdfParallelism = 4,
privateKey = "privateKey",
shouldForcePasswordReset = true,
shouldResetMasterPassword = true,
masterPasswordPolicyOptions = null,
userDecryptionOptions = null,
)
private val ACCOUNT_1 = AccountJson(
profile = AccountJson.Profile(
userId = USER_ID_1,
email = "test@bitwarden.com",
isEmailVerified = true,
name = "Bitwarden Tester",
hasPremium = false,
stamp = null,
organizationId = null,
avatarColorHex = null,
forcePasswordResetReason = null,
kdfType = KdfTypeJson.ARGON2_ID,
kdfIterations = 600000,
kdfMemory = 16,
kdfParallelism = 4,
userDecryptionOptions = null,
),
tokens = AccountJson.Tokens(
accessToken = ACCESS_TOKEN_1,
refreshToken = "refreshToken",
),
settings = AccountJson.Settings(
environmentUrlData = null,
),
)
private val ACCOUNT_2 = AccountJson(
profile = AccountJson.Profile(
userId = USER_ID_2,
email = "test2@bitwarden.com",
isEmailVerified = true,
name = "Bitwarden Tester 2",
hasPremium = false,
stamp = null,
organizationId = null,
avatarColorHex = null,
forcePasswordResetReason = null,
kdfType = KdfTypeJson.PBKDF2_SHA256,
kdfIterations = 400000,
kdfMemory = null,
kdfParallelism = null,
userDecryptionOptions = null,
),
tokens = AccountJson.Tokens(
accessToken = ACCESS_TOKEN_2,
refreshToken = "refreshToken",
),
settings = AccountJson.Settings(
environmentUrlData = null,
),
)
private val SINGLE_USER_STATE_1 = UserStateJson(
activeUserId = USER_ID_1,
accounts = mapOf(
USER_ID_1 to ACCOUNT_1,
),
)
private val SINGLE_USER_STATE_2 = UserStateJson(
activeUserId = USER_ID_2,
accounts = mapOf(
USER_ID_2 to ACCOUNT_2,
),
)
private val MULTI_USER_STATE = UserStateJson(
activeUserId = USER_ID_1,
accounts = mapOf(
USER_ID_1 to ACCOUNT_1,
USER_ID_2 to ACCOUNT_2,
),
)

View file

@ -7,6 +7,7 @@ import android.content.SharedPreferences
*/
class FakeSharedPreferences : SharedPreferences {
private val sharedPreferences: MutableMap<String, Any?> = mutableMapOf()
private val listeners = mutableSetOf<SharedPreferences.OnSharedPreferenceChangeListener>()
override fun contains(key: String): Boolean =
sharedPreferences.containsKey(key)
@ -36,17 +37,13 @@ class FakeSharedPreferences : SharedPreferences {
override fun registerOnSharedPreferenceChangeListener(
listener: SharedPreferences.OnSharedPreferenceChangeListener,
) {
throw NotImplementedError(
"registerOnSharedPreferenceChangeListener is not currently implemented.",
)
listeners += listener
}
override fun unregisterOnSharedPreferenceChangeListener(
listener: SharedPreferences.OnSharedPreferenceChangeListener,
) {
throw NotImplementedError(
"unregisterOnSharedPreferenceChangeListener is not currently implemented.",
)
listeners -= listener
}
private inline fun <reified T> getValue(
@ -61,6 +58,13 @@ class FakeSharedPreferences : SharedPreferences {
sharedPreferences.apply {
clear()
putAll(pendingSharedPreferences)
// Notify listeners
listeners.forEach { listener ->
pendingSharedPreferences.keys.forEach { key ->
listener.onSharedPreferenceChanged(this@FakeSharedPreferences, key)
}
}
}
}